4666: Add a score threshold search parameter r=ManyTheFish a=dureuill

# Pull Request

## Related issue
Fixes https://github.com/meilisearch/meilisearch/issues/4609

## What does this PR do?
- See [usage](https://meilisearch.notion.site/Filter-by-score-usage-224a183ce7b24ca99b6a9a8da755668a?pvs=25#95b76ded400342ba9ab3d67c734836f0) and [the known limitation](https://meilisearch.notion.site/Filter-by-score-usage-224a183ce7b24ca99b6a9a8da755668a?pvs=25#e4e32195bf0e4195b5daecdbb7a97a17)


Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
meili-bors[bot] 2024-06-03 08:42:44 +00:00 committed by GitHub
commit fc584f1db3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 804 additions and 32 deletions

View File

@ -189,4 +189,6 @@ merge_with_error_impl_take_error_message!(ParseTaskKindError);
merge_with_error_impl_take_error_message!(ParseTaskStatusError); merge_with_error_impl_take_error_message!(ParseTaskStatusError);
merge_with_error_impl_take_error_message!(IndexUidFormatError); merge_with_error_impl_take_error_message!(IndexUidFormatError);
merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio); merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio);
merge_with_error_impl_take_error_message!(InvalidSearchRankingScoreThreshold);
merge_with_error_impl_take_error_message!(InvalidSimilarRankingScoreThreshold);
merge_with_error_impl_take_error_message!(InvalidSimilarId); merge_with_error_impl_take_error_message!(InvalidSimilarId);

View File

@ -241,6 +241,8 @@ InvalidSearchAttributesToCrop , InvalidRequest , BAD_REQUEST ;
InvalidSearchAttributesToHighlight , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToHighlight , InvalidRequest , BAD_REQUEST ;
InvalidSimilarAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSimilarAttributesToRetrieve , InvalidRequest , BAD_REQUEST ;
InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ;
InvalidSearchRankingScoreThreshold , InvalidRequest , BAD_REQUEST ;
InvalidSimilarRankingScoreThreshold , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ;
InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
@ -505,6 +507,21 @@ impl fmt::Display for deserr_codes::InvalidSimilarId {
} }
} }
impl fmt::Display for deserr_codes::InvalidSearchRankingScoreThreshold {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`."
)
}
}
impl fmt::Display for deserr_codes::InvalidSimilarRankingScoreThreshold {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
deserr_codes::InvalidSearchRankingScoreThreshold.fmt(f)
}
}
#[macro_export] #[macro_export]
macro_rules! internal_error { macro_rules! internal_error {
($target:ty : $($other:path), *) => { ($target:ty : $($other:path), *) => {

View File

@ -648,6 +648,7 @@ pub struct SearchAggregator {
// scoring // scoring
show_ranking_score: bool, show_ranking_score: bool,
show_ranking_score_details: bool, show_ranking_score_details: bool,
ranking_score_threshold: bool,
} }
impl SearchAggregator { impl SearchAggregator {
@ -676,6 +677,7 @@ impl SearchAggregator {
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid, hybrid,
ranking_score_threshold,
} = query; } = query;
let mut ret = Self::default(); let mut ret = Self::default();
@ -748,6 +750,7 @@ impl SearchAggregator {
ret.show_ranking_score = *show_ranking_score; ret.show_ranking_score = *show_ranking_score;
ret.show_ranking_score_details = *show_ranking_score_details; ret.show_ranking_score_details = *show_ranking_score_details;
ret.ranking_score_threshold = ranking_score_threshold.is_some();
if let Some(hybrid) = hybrid { if let Some(hybrid) = hybrid {
ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO(); ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO();
@ -821,6 +824,7 @@ impl SearchAggregator {
hybrid, hybrid,
total_degraded, total_degraded,
total_used_negative_operator, total_used_negative_operator,
ranking_score_threshold,
} = other; } = other;
if self.timestamp.is_none() { if self.timestamp.is_none() {
@ -904,6 +908,7 @@ impl SearchAggregator {
// scoring // scoring
self.show_ranking_score |= show_ranking_score; self.show_ranking_score |= show_ranking_score;
self.show_ranking_score_details |= show_ranking_score_details; self.show_ranking_score_details |= show_ranking_score_details;
self.ranking_score_threshold |= ranking_score_threshold;
} }
pub fn into_event(self, user: &User, event_name: &str) -> Option<Track> { pub fn into_event(self, user: &User, event_name: &str) -> Option<Track> {
@ -945,6 +950,7 @@ impl SearchAggregator {
hybrid, hybrid,
total_degraded, total_degraded,
total_used_negative_operator, total_used_negative_operator,
ranking_score_threshold,
} = self; } = self;
if total_received == 0 { if total_received == 0 {
@ -1015,6 +1021,7 @@ impl SearchAggregator {
"scoring": { "scoring": {
"show_ranking_score": show_ranking_score, "show_ranking_score": show_ranking_score,
"show_ranking_score_details": show_ranking_score_details, "show_ranking_score_details": show_ranking_score_details,
"ranking_score_threshold": ranking_score_threshold,
}, },
}); });
@ -1087,6 +1094,7 @@ impl MultiSearchAggregator {
matching_strategy: _, matching_strategy: _,
attributes_to_search_on: _, attributes_to_search_on: _,
hybrid: _, hybrid: _,
ranking_score_threshold: _,
} = query; } = query;
index_uid.as_str() index_uid.as_str()
@ -1234,6 +1242,7 @@ impl FacetSearchAggregator {
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid, hybrid,
ranking_score_threshold,
} = query; } = query;
let mut ret = Self::default(); let mut ret = Self::default();
@ -1248,7 +1257,8 @@ impl FacetSearchAggregator {
|| filter.is_some() || filter.is_some()
|| *matching_strategy != MatchingStrategy::default() || *matching_strategy != MatchingStrategy::default()
|| attributes_to_search_on.is_some() || attributes_to_search_on.is_some()
|| hybrid.is_some(); || hybrid.is_some()
|| ranking_score_threshold.is_some();
ret ret
} }
@ -1624,6 +1634,7 @@ pub struct SimilarAggregator {
// scoring // scoring
show_ranking_score: bool, show_ranking_score: bool,
show_ranking_score_details: bool, show_ranking_score_details: bool,
ranking_score_threshold: bool,
} }
impl SimilarAggregator { impl SimilarAggregator {
@ -1638,6 +1649,7 @@ impl SimilarAggregator {
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
filter, filter,
ranking_score_threshold,
} = query; } = query;
let mut ret = Self::default(); let mut ret = Self::default();
@ -1675,6 +1687,7 @@ impl SimilarAggregator {
ret.show_ranking_score = *show_ranking_score; ret.show_ranking_score = *show_ranking_score;
ret.show_ranking_score_details = *show_ranking_score_details; ret.show_ranking_score_details = *show_ranking_score_details;
ret.ranking_score_threshold = ranking_score_threshold.is_some();
ret.embedder = embedder.is_some(); ret.embedder = embedder.is_some();
@ -1708,6 +1721,7 @@ impl SimilarAggregator {
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
embedder, embedder,
ranking_score_threshold,
} = other; } = other;
if self.timestamp.is_none() { if self.timestamp.is_none() {
@ -1749,6 +1763,7 @@ impl SimilarAggregator {
// scoring // scoring
self.show_ranking_score |= show_ranking_score; self.show_ranking_score |= show_ranking_score;
self.show_ranking_score_details |= show_ranking_score_details; self.show_ranking_score_details |= show_ranking_score_details;
self.ranking_score_threshold |= ranking_score_threshold;
} }
pub fn into_event(self, user: &User, event_name: &str) -> Option<Track> { pub fn into_event(self, user: &User, event_name: &str) -> Option<Track> {
@ -1769,6 +1784,7 @@ impl SimilarAggregator {
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
embedder, embedder,
ranking_score_threshold,
} = self; } = self;
if total_received == 0 { if total_received == 0 {
@ -1808,6 +1824,7 @@ impl SimilarAggregator {
"scoring": { "scoring": {
"show_ranking_score": show_ranking_score, "show_ranking_score": show_ranking_score,
"show_ranking_score_details": show_ranking_score_details, "show_ranking_score_details": show_ranking_score_details,
"ranking_score_threshold": ranking_score_threshold,
}, },
}); });

View File

@ -14,8 +14,8 @@ use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData; use crate::extractors::authentication::GuardedData;
use crate::routes::indexes::search::search_kind; use crate::routes::indexes::search::search_kind;
use crate::search::{ use crate::search::{
add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery, add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, RankingScoreThreshold,
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, SearchQuery, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
}; };
use crate::search_queue::SearchQueue; use crate::search_queue::SearchQueue;
@ -46,6 +46,8 @@ pub struct FacetSearchQuery {
pub matching_strategy: MatchingStrategy, pub matching_strategy: MatchingStrategy,
#[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToSearchOn>, default)] #[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToSearchOn>, default)]
pub attributes_to_search_on: Option<Vec<String>>, pub attributes_to_search_on: Option<Vec<String>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchRankingScoreThreshold>, default)]
pub ranking_score_threshold: Option<RankingScoreThreshold>,
} }
pub async fn search( pub async fn search(
@ -103,6 +105,7 @@ impl From<FacetSearchQuery> for SearchQuery {
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid, hybrid,
ranking_score_threshold,
} = value; } = value;
SearchQuery { SearchQuery {
@ -128,6 +131,7 @@ impl From<FacetSearchQuery> for SearchQuery {
vector, vector,
attributes_to_search_on, attributes_to_search_on,
hybrid, hybrid,
ranking_score_threshold,
} }
} }
} }

View File

@ -19,9 +19,10 @@ use crate::extractors::authentication::GuardedData;
use crate::extractors::sequential_extractor::SeqHandler; use crate::extractors::sequential_extractor::SeqHandler;
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
use crate::search::{ use crate::search::{
add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchKind, SearchQuery, add_search_rules, perform_search, HybridQuery, MatchingStrategy, RankingScoreThreshold,
SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, SearchKind, SearchQuery, SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT,
DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO,
}; };
use crate::search_queue::SearchQueue; use crate::search_queue::SearchQueue;
@ -82,6 +83,21 @@ pub struct SearchQueryGet {
pub hybrid_embedder: Option<String>, pub hybrid_embedder: Option<String>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchSemanticRatio>)] #[deserr(default, error = DeserrQueryParamError<InvalidSearchSemanticRatio>)]
pub hybrid_semantic_ratio: Option<SemanticRatioGet>, pub hybrid_semantic_ratio: Option<SemanticRatioGet>,
#[deserr(default, error = DeserrQueryParamError<InvalidSearchRankingScoreThreshold>)]
pub ranking_score_threshold: Option<RankingScoreThresholdGet>,
}
#[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)]
#[deserr(try_from(String) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
pub struct RankingScoreThresholdGet(RankingScoreThreshold);
impl std::convert::TryFrom<String> for RankingScoreThresholdGet {
type Error = InvalidSearchRankingScoreThreshold;
fn try_from(s: String) -> Result<Self, Self::Error> {
let f: f64 = s.parse().map_err(|_| InvalidSearchRankingScoreThreshold)?;
Ok(RankingScoreThresholdGet(RankingScoreThreshold::try_from(f)?))
}
} }
#[derive(Debug, Clone, Copy, Default, PartialEq, deserr::Deserr)] #[derive(Debug, Clone, Copy, Default, PartialEq, deserr::Deserr)]
@ -152,6 +168,7 @@ impl From<SearchQueryGet> for SearchQuery {
matching_strategy: other.matching_strategy, matching_strategy: other.matching_strategy,
attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()), attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()),
hybrid, hybrid,
ranking_score_threshold: other.ranking_score_threshold.map(|o| o.0),
} }
} }
} }

View File

@ -6,8 +6,8 @@ use meilisearch_types::deserr::query_params::Param;
use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
use meilisearch_types::error::deserr_codes::{ use meilisearch_types::error::deserr_codes::{
InvalidEmbedder, InvalidSimilarAttributesToRetrieve, InvalidSimilarFilter, InvalidSimilarId, InvalidEmbedder, InvalidSimilarAttributesToRetrieve, InvalidSimilarFilter, InvalidSimilarId,
InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarShowRankingScore, InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarRankingScoreThreshold,
InvalidSimilarShowRankingScoreDetails, InvalidSimilarShowRankingScore, InvalidSimilarShowRankingScoreDetails,
}; };
use meilisearch_types::error::{ErrorCode as _, ResponseError}; use meilisearch_types::error::{ErrorCode as _, ResponseError};
use meilisearch_types::index_uid::IndexUid; use meilisearch_types::index_uid::IndexUid;
@ -21,8 +21,8 @@ use crate::analytics::{Analytics, SimilarAggregator};
use crate::extractors::authentication::GuardedData; use crate::extractors::authentication::GuardedData;
use crate::extractors::sequential_extractor::SeqHandler; use crate::extractors::sequential_extractor::SeqHandler;
use crate::search::{ use crate::search::{
add_search_rules, perform_similar, SearchKind, SimilarQuery, SimilarResult, add_search_rules, perform_similar, RankingScoreThresholdSimilar, SearchKind, SimilarQuery,
DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, SimilarResult, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET,
}; };
pub fn configure(cfg: &mut web::ServiceConfig) { pub fn configure(cfg: &mut web::ServiceConfig) {
@ -42,9 +42,7 @@ pub async fn similar_get(
) -> Result<HttpResponse, ResponseError> { ) -> Result<HttpResponse, ResponseError> {
let index_uid = IndexUid::try_from(index_uid.into_inner())?; let index_uid = IndexUid::try_from(index_uid.into_inner())?;
let query = params.0.try_into().map_err(|code: InvalidSimilarId| { let query = params.0.try_into()?;
ResponseError::from_msg(code.to_string(), code.error_code())
})?;
let mut aggregate = SimilarAggregator::from_query(&query, &req); let mut aggregate = SimilarAggregator::from_query(&query, &req);
@ -130,12 +128,27 @@ pub struct SimilarQueryGet {
show_ranking_score: Param<bool>, show_ranking_score: Param<bool>,
#[deserr(default, error = DeserrQueryParamError<InvalidSimilarShowRankingScoreDetails>)] #[deserr(default, error = DeserrQueryParamError<InvalidSimilarShowRankingScoreDetails>)]
show_ranking_score_details: Param<bool>, show_ranking_score_details: Param<bool>,
#[deserr(default, error = DeserrQueryParamError<InvalidSimilarRankingScoreThreshold>, default)]
pub ranking_score_threshold: Option<RankingScoreThresholdGet>,
#[deserr(default, error = DeserrQueryParamError<InvalidEmbedder>)] #[deserr(default, error = DeserrQueryParamError<InvalidEmbedder>)]
pub embedder: Option<String>, pub embedder: Option<String>,
} }
#[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)]
#[deserr(try_from(String) = TryFrom::try_from -> InvalidSimilarRankingScoreThreshold)]
pub struct RankingScoreThresholdGet(RankingScoreThresholdSimilar);
impl std::convert::TryFrom<String> for RankingScoreThresholdGet {
type Error = InvalidSimilarRankingScoreThreshold;
fn try_from(s: String) -> Result<Self, Self::Error> {
let f: f64 = s.parse().map_err(|_| InvalidSimilarRankingScoreThreshold)?;
Ok(RankingScoreThresholdGet(RankingScoreThresholdSimilar::try_from(f)?))
}
}
impl TryFrom<SimilarQueryGet> for SimilarQuery { impl TryFrom<SimilarQueryGet> for SimilarQuery {
type Error = InvalidSimilarId; type Error = ResponseError;
fn try_from( fn try_from(
SimilarQueryGet { SimilarQueryGet {
@ -147,6 +160,7 @@ impl TryFrom<SimilarQueryGet> for SimilarQuery {
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
embedder, embedder,
ranking_score_threshold,
}: SimilarQueryGet, }: SimilarQueryGet,
) -> Result<Self, Self::Error> { ) -> Result<Self, Self::Error> {
let filter = match filter { let filter = match filter {
@ -158,7 +172,9 @@ impl TryFrom<SimilarQueryGet> for SimilarQuery {
}; };
Ok(SimilarQuery { Ok(SimilarQuery {
id: id.0.try_into()?, id: id.0.try_into().map_err(|code: InvalidSimilarId| {
ResponseError::from_msg(code.to_string(), code.error_code())
})?,
offset: offset.0, offset: offset.0,
limit: limit.0, limit: limit.0,
filter, filter,
@ -166,6 +182,7 @@ impl TryFrom<SimilarQueryGet> for SimilarQuery {
attributes_to_retrieve: attributes_to_retrieve.map(|o| o.into_iter().collect()), attributes_to_retrieve: attributes_to_retrieve.map(|o| o.into_iter().collect()),
show_ranking_score: show_ranking_score.0, show_ranking_score: show_ranking_score.0,
show_ranking_score_details: show_ranking_score_details.0, show_ranking_score_details: show_ranking_score_details.0,
ranking_score_threshold: ranking_score_threshold.map(|x| x.0),
}) })
} }
} }

View File

@ -87,6 +87,44 @@ pub struct SearchQuery {
pub matching_strategy: MatchingStrategy, pub matching_strategy: MatchingStrategy,
#[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToSearchOn>, default)] #[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToSearchOn>, default)]
pub attributes_to_search_on: Option<Vec<String>>, pub attributes_to_search_on: Option<Vec<String>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchRankingScoreThreshold>, default)]
pub ranking_score_threshold: Option<RankingScoreThreshold>,
}
#[derive(Debug, Clone, Copy, PartialEq, Deserr)]
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
pub struct RankingScoreThreshold(f64);
impl std::convert::TryFrom<f64> for RankingScoreThreshold {
type Error = InvalidSearchRankingScoreThreshold;
fn try_from(f: f64) -> Result<Self, Self::Error> {
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 {
Err(InvalidSearchRankingScoreThreshold)
} else {
Ok(RankingScoreThreshold(f))
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Deserr)]
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSimilarRankingScoreThreshold)]
pub struct RankingScoreThresholdSimilar(f64);
impl std::convert::TryFrom<f64> for RankingScoreThresholdSimilar {
type Error = InvalidSimilarRankingScoreThreshold;
fn try_from(f: f64) -> Result<Self, Self::Error> {
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 {
Err(InvalidSimilarRankingScoreThreshold)
} else {
Ok(Self(f))
}
}
} }
// Since this structure is logged A LOT we're going to reduce the number of things it logs to the bare minimum. // Since this structure is logged A LOT we're going to reduce the number of things it logs to the bare minimum.
@ -117,6 +155,7 @@ impl fmt::Debug for SearchQuery {
crop_marker, crop_marker,
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
ranking_score_threshold,
} = self; } = self;
let mut debug = f.debug_struct("SearchQuery"); let mut debug = f.debug_struct("SearchQuery");
@ -188,6 +227,9 @@ impl fmt::Debug for SearchQuery {
debug.field("highlight_pre_tag", &highlight_pre_tag); debug.field("highlight_pre_tag", &highlight_pre_tag);
debug.field("highlight_post_tag", &highlight_post_tag); debug.field("highlight_post_tag", &highlight_post_tag);
debug.field("crop_marker", &crop_marker); debug.field("crop_marker", &crop_marker);
if let Some(ranking_score_threshold) = ranking_score_threshold {
debug.field("ranking_score_threshold", &ranking_score_threshold);
}
debug.finish() debug.finish()
} }
@ -356,6 +398,8 @@ pub struct SearchQueryWithIndex {
pub matching_strategy: MatchingStrategy, pub matching_strategy: MatchingStrategy,
#[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToSearchOn>, default)] #[deserr(default, error = DeserrJsonError<InvalidSearchAttributesToSearchOn>, default)]
pub attributes_to_search_on: Option<Vec<String>>, pub attributes_to_search_on: Option<Vec<String>>,
#[deserr(default, error = DeserrJsonError<InvalidSearchRankingScoreThreshold>, default)]
pub ranking_score_threshold: Option<RankingScoreThreshold>,
} }
impl SearchQueryWithIndex { impl SearchQueryWithIndex {
@ -384,6 +428,7 @@ impl SearchQueryWithIndex {
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid, hybrid,
ranking_score_threshold,
} = self; } = self;
( (
index_uid, index_uid,
@ -410,6 +455,7 @@ impl SearchQueryWithIndex {
matching_strategy, matching_strategy,
attributes_to_search_on, attributes_to_search_on,
hybrid, hybrid,
ranking_score_threshold,
// do not use ..Default::default() here, // do not use ..Default::default() here,
// rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex`
}, },
@ -436,6 +482,8 @@ pub struct SimilarQuery {
pub show_ranking_score: bool, pub show_ranking_score: bool,
#[deserr(default, error = DeserrJsonError<InvalidSimilarShowRankingScoreDetails>, default)] #[deserr(default, error = DeserrJsonError<InvalidSimilarShowRankingScoreDetails>, default)]
pub show_ranking_score_details: bool, pub show_ranking_score_details: bool,
#[deserr(default, error = DeserrJsonError<InvalidSimilarRankingScoreThreshold>, default)]
pub ranking_score_threshold: Option<RankingScoreThresholdSimilar>,
} }
#[derive(Debug, Clone, PartialEq, Deserr)] #[derive(Debug, Clone, PartialEq, Deserr)]
@ -664,6 +712,9 @@ fn prepare_search<'t>(
) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
let mut search = index.search(rtxn); let mut search = index.search(rtxn);
search.time_budget(time_budget); search.time_budget(time_budget);
if let Some(ranking_score_threshold) = query.ranking_score_threshold {
search.ranking_score_threshold(ranking_score_threshold.0);
}
match search_kind { match search_kind {
SearchKind::KeywordOnly => { SearchKind::KeywordOnly => {
@ -705,11 +756,16 @@ fn prepare_search<'t>(
.unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS); .unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS);
search.exhaustive_number_hits(is_finite_pagination); search.exhaustive_number_hits(is_finite_pagination);
search.scoring_strategy(if query.show_ranking_score || query.show_ranking_score_details { search.scoring_strategy(
if query.show_ranking_score
|| query.show_ranking_score_details
|| query.ranking_score_threshold.is_some()
{
ScoringStrategy::Detailed ScoringStrategy::Detailed
} else { } else {
ScoringStrategy::Skip ScoringStrategy::Skip
}); },
);
// compute the offset on the limit depending on the pagination mode. // compute the offset on the limit depending on the pagination mode.
let (offset, limit) = if is_finite_pagination { let (offset, limit) = if is_finite_pagination {
@ -787,10 +843,6 @@ pub fn perform_search(
let SearchQuery { let SearchQuery {
q, q,
vector: _,
hybrid: _,
// already computed from prepare_search
offset: _,
limit, limit,
page, page,
hits_per_page, hits_per_page,
@ -801,14 +853,19 @@ pub fn perform_search(
show_matches_position, show_matches_position,
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
filter: _,
sort, sort,
facets, facets,
highlight_pre_tag, highlight_pre_tag,
highlight_post_tag, highlight_post_tag,
crop_marker, crop_marker,
// already used in prepare_search
vector: _,
hybrid: _,
offset: _,
ranking_score_threshold: _,
matching_strategy: _, matching_strategy: _,
attributes_to_search_on: _, attributes_to_search_on: _,
filter: _,
} = query; } = query;
let format = AttributesFormat { let format = AttributesFormat {
@ -1070,6 +1127,7 @@ pub fn perform_similar(
attributes_to_retrieve, attributes_to_retrieve,
show_ranking_score, show_ranking_score,
show_ranking_score_details, show_ranking_score_details,
ranking_score_threshold,
} = query; } = query;
// using let-else rather than `?` so that the borrow checker identifies we're always returning here, // using let-else rather than `?` so that the borrow checker identifies we're always returning here,
@ -1093,6 +1151,10 @@ pub fn perform_similar(
} }
} }
if let Some(ranking_score_threshold) = ranking_score_threshold {
similar.ranking_score_threshold(ranking_score_threshold.0);
}
let milli::SearchResult { let milli::SearchResult {
documents_ids, documents_ids,
matching_words: _, matching_words: _,

View File

@ -321,6 +321,40 @@ async fn search_bad_facets() {
// Can't make the `attributes_to_highlight` fail with a get search since it'll accept anything as an array of strings. // Can't make the `attributes_to_highlight` fail with a get search since it'll accept anything as an array of strings.
} }
#[actix_rt::test]
async fn search_bad_threshold() {
let server = Server::new().await;
let index = server.index("test");
let (response, code) = index.search_post(json!({"rankingScoreThreshold": "doggo"})).await;
snapshot!(code, @"400 Bad Request");
snapshot!(json_string!(response), @r###"
{
"message": "Invalid value type at `.rankingScoreThreshold`: expected a number, but found a string: `\"doggo\"`",
"code": "invalid_search_ranking_score_threshold",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_ranking_score_threshold"
}
"###);
}
#[actix_rt::test]
async fn search_invalid_threshold() {
let server = Server::new().await;
let index = server.index("test");
let (response, code) = index.search_post(json!({"rankingScoreThreshold": 42})).await;
snapshot!(code, @"400 Bad Request");
snapshot!(json_string!(response), @r###"
{
"message": "Invalid value at `.rankingScoreThreshold`: the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`.",
"code": "invalid_search_ranking_score_threshold",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_search_ranking_score_threshold"
}
"###);
}
#[actix_rt::test] #[actix_rt::test]
async fn search_non_filterable_facets() { async fn search_non_filterable_facets() {
let server = Server::new().await; let server = Server::new().await;

View File

@ -48,6 +48,31 @@ static DOCUMENTS: Lazy<Value> = Lazy::new(|| {
]) ])
}); });
static SCORE_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
json!([
{
"title": "Batman the dark knight returns: Part 1",
"id": "A",
},
{
"title": "Batman the dark knight returns: Part 2",
"id": "B",
},
{
"title": "Batman Returns",
"id": "C",
},
{
"title": "Batman",
"id": "D",
},
{
"title": "Badman",
"id": "E",
}
])
});
static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| { static NESTED_DOCUMENTS: Lazy<Value> = Lazy::new(|| {
json!([ json!([
{ {
@ -960,6 +985,213 @@ async fn test_score_details() {
.await; .await;
} }
#[actix_rt::test]
async fn test_score() {
let server = Server::new().await;
let index = server.index("test");
let documents = SCORE_DOCUMENTS.clone();
let res = index.add_documents(json!(documents), None).await;
index.wait_task(res.0.uid()).await;
index
.search(
json!({
"q": "Badman the dark knight returns 1",
"showRankingScore": true,
}),
|response, code| {
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###"
[
{
"title": "Batman the dark knight returns: Part 1",
"id": "A",
"_rankingScore": 0.9746605609456898
},
{
"title": "Batman the dark knight returns: Part 2",
"id": "B",
"_rankingScore": 0.8055252965383685
},
{
"title": "Badman",
"id": "E",
"_rankingScore": 0.16666666666666666
},
{
"title": "Batman Returns",
"id": "C",
"_rankingScore": 0.07702020202020202
},
{
"title": "Batman",
"id": "D",
"_rankingScore": 0.07702020202020202
}
]
"###);
},
)
.await;
}
#[actix_rt::test]
async fn test_score_threshold() {
let query = "Badman dark returns 1";
let server = Server::new().await;
let index = server.index("test");
let documents = SCORE_DOCUMENTS.clone();
let res = index.add_documents(json!(documents), None).await;
index.wait_task(res.0.uid()).await;
index
.search(
json!({
"q": query,
"showRankingScore": true,
"rankingScoreThreshold": 0.0
}),
|response, code| {
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @"5");
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###"
[
{
"title": "Batman the dark knight returns: Part 1",
"id": "A",
"_rankingScore": 0.93430081300813
},
{
"title": "Batman the dark knight returns: Part 2",
"id": "B",
"_rankingScore": 0.6685627880184332
},
{
"title": "Badman",
"id": "E",
"_rankingScore": 0.25
},
{
"title": "Batman Returns",
"id": "C",
"_rankingScore": 0.11553030303030302
},
{
"title": "Batman",
"id": "D",
"_rankingScore": 0.11553030303030302
}
]
"###);
},
)
.await;
index
.search(
json!({
"q": query,
"showRankingScore": true,
"rankingScoreThreshold": 0.2
}),
|response, code| {
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @r###"3"###);
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###"
[
{
"title": "Batman the dark knight returns: Part 1",
"id": "A",
"_rankingScore": 0.93430081300813
},
{
"title": "Batman the dark knight returns: Part 2",
"id": "B",
"_rankingScore": 0.6685627880184332
},
{
"title": "Badman",
"id": "E",
"_rankingScore": 0.25
}
]
"###);
},
)
.await;
index
.search(
json!({
"q": query,
"showRankingScore": true,
"rankingScoreThreshold": 0.5
}),
|response, code| {
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @r###"2"###);
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###"
[
{
"title": "Batman the dark knight returns: Part 1",
"id": "A",
"_rankingScore": 0.93430081300813
},
{
"title": "Batman the dark knight returns: Part 2",
"id": "B",
"_rankingScore": 0.6685627880184332
}
]
"###);
},
)
.await;
index
.search(
json!({
"q": query,
"showRankingScore": true,
"rankingScoreThreshold": 0.8
}),
|response, code| {
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @r###"1"###);
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###"
[
{
"title": "Batman the dark knight returns: Part 1",
"id": "A",
"_rankingScore": 0.93430081300813
}
]
"###);
},
)
.await;
index
.search(
json!({
"q": query,
"showRankingScore": true,
"rankingScoreThreshold": 1.0
}),
|response, code| {
meili_snap::snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @r###"0"###);
// nobody is perfect
meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @"[]");
},
)
.await;
}
#[actix_rt::test] #[actix_rt::test]
async fn test_degraded_score_details() { async fn test_degraded_score_details() {
let server = Server::new().await; let server = Server::new().await;

View File

@ -87,6 +87,68 @@ async fn similar_bad_id() {
"###); "###);
} }
#[actix_rt::test]
async fn similar_bad_ranking_score_threshold() {
let server = Server::new().await;
let index = server.index("test");
server.set_features(json!({"vectorStore": true})).await;
let (response, code) = index
.update_settings(json!({
"embedders": {
"manual": {
"source": "userProvided",
"dimensions": 3,
}
},
"filterableAttributes": ["title"]}))
.await;
snapshot!(code, @"202 Accepted");
server.wait_task(response.uid()).await;
let (response, code) = index.similar_post(json!({"rankingScoreThreshold": ["doggo"]})).await;
snapshot!(code, @"400 Bad Request");
snapshot!(json_string!(response), @r###"
{
"message": "Invalid value type at `.rankingScoreThreshold`: expected a number, but found an array: `[\"doggo\"]`",
"code": "invalid_similar_ranking_score_threshold",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_similar_ranking_score_threshold"
}
"###);
}
#[actix_rt::test]
async fn similar_invalid_ranking_score_threshold() {
let server = Server::new().await;
let index = server.index("test");
server.set_features(json!({"vectorStore": true})).await;
let (response, code) = index
.update_settings(json!({
"embedders": {
"manual": {
"source": "userProvided",
"dimensions": 3,
}
},
"filterableAttributes": ["title"]}))
.await;
snapshot!(code, @"202 Accepted");
server.wait_task(response.uid()).await;
let (response, code) = index.similar_post(json!({"rankingScoreThreshold": 42})).await;
snapshot!(code, @"400 Bad Request");
snapshot!(json_string!(response), @r###"
{
"message": "Invalid value at `.rankingScoreThreshold`: the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`.",
"code": "invalid_similar_ranking_score_threshold",
"type": "invalid_request",
"link": "https://docs.meilisearch.com/errors#invalid_similar_ranking_score_threshold"
}
"###);
}
#[actix_rt::test] #[actix_rt::test]
async fn similar_invalid_id() { async fn similar_invalid_id() {
let server = Server::new().await; let server = Server::new().await;

View File

@ -194,6 +194,235 @@ async fn basic() {
.await; .await;
} }
#[actix_rt::test]
async fn ranking_score_threshold() {
let server = Server::new().await;
let index = server.index("test");
let (value, code) = server.set_features(json!({"vectorStore": true})).await;
snapshot!(code, @"200 OK");
snapshot!(value, @r###"
{
"vectorStore": true,
"metrics": false,
"logsRoute": false
}
"###);
let (response, code) = index
.update_settings(json!({
"embedders": {
"manual": {
"source": "userProvided",
"dimensions": 3,
}
},
"filterableAttributes": ["title"]}))
.await;
snapshot!(code, @"202 Accepted");
server.wait_task(response.uid()).await;
let documents = DOCUMENTS.clone();
let (value, code) = index.add_documents(documents, None).await;
snapshot!(code, @"202 Accepted");
index.wait_task(value.uid()).await;
index
.similar(
json!({"id": 143, "showRankingScore": true, "rankingScoreThreshold": 0}),
|response, code| {
snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @"4");
snapshot!(json_string!(response["hits"]), @r###"
[
{
"title": "Escape Room",
"release_year": 2019,
"id": "522681",
"_vectors": {
"manual": [
0.1,
0.6,
0.8
]
},
"_rankingScore": 0.890957772731781
},
{
"title": "Captain Marvel",
"release_year": 2019,
"id": "299537",
"_vectors": {
"manual": [
0.6,
0.8,
-0.2
]
},
"_rankingScore": 0.39060014486312866
},
{
"title": "How to Train Your Dragon: The Hidden World",
"release_year": 2019,
"id": "166428",
"_vectors": {
"manual": [
0.7,
0.7,
-0.4
]
},
"_rankingScore": 0.2819308042526245
},
{
"title": "Shazam!",
"release_year": 2019,
"id": "287947",
"_vectors": {
"manual": [
0.8,
0.4,
-0.5
]
},
"_rankingScore": 0.1662663221359253
}
]
"###);
},
)
.await;
index
.similar(
json!({"id": 143, "showRankingScore": true, "rankingScoreThreshold": 0.2}),
|response, code| {
snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @"3");
snapshot!(json_string!(response["hits"]), @r###"
[
{
"title": "Escape Room",
"release_year": 2019,
"id": "522681",
"_vectors": {
"manual": [
0.1,
0.6,
0.8
]
},
"_rankingScore": 0.890957772731781
},
{
"title": "Captain Marvel",
"release_year": 2019,
"id": "299537",
"_vectors": {
"manual": [
0.6,
0.8,
-0.2
]
},
"_rankingScore": 0.39060014486312866
},
{
"title": "How to Train Your Dragon: The Hidden World",
"release_year": 2019,
"id": "166428",
"_vectors": {
"manual": [
0.7,
0.7,
-0.4
]
},
"_rankingScore": 0.2819308042526245
}
]
"###);
},
)
.await;
index
.similar(
json!({"id": 143, "showRankingScore": true, "rankingScoreThreshold": 0.3}),
|response, code| {
snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @"2");
snapshot!(json_string!(response["hits"]), @r###"
[
{
"title": "Escape Room",
"release_year": 2019,
"id": "522681",
"_vectors": {
"manual": [
0.1,
0.6,
0.8
]
},
"_rankingScore": 0.890957772731781
},
{
"title": "Captain Marvel",
"release_year": 2019,
"id": "299537",
"_vectors": {
"manual": [
0.6,
0.8,
-0.2
]
},
"_rankingScore": 0.39060014486312866
}
]
"###);
},
)
.await;
index
.similar(
json!({"id": 143, "showRankingScore": true, "rankingScoreThreshold": 0.6}),
|response, code| {
snapshot!(code, @"200 OK");
meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @"1");
snapshot!(json_string!(response["hits"]), @r###"
[
{
"title": "Escape Room",
"release_year": 2019,
"id": "522681",
"_vectors": {
"manual": [
0.1,
0.6,
0.8
]
},
"_rankingScore": 0.890957772731781
}
]
"###);
},
)
.await;
index
.similar(
json!({"id": 143, "showRankingScore": true, "rankingScoreThreshold": 0.9}),
|response, code| {
snapshot!(code, @"200 OK");
snapshot!(json_string!(response["hits"]), @"[]");
},
)
.await;
}
#[actix_rt::test] #[actix_rt::test]
async fn filter() { async fn filter() {
let server = Server::new().await; let server = Server::new().await;

View File

@ -66,6 +66,7 @@ fn main() -> Result<(), Box<dyn Error>> {
&mut DefaultSearchLogger, &mut DefaultSearchLogger,
logger, logger,
TimeBudget::max(), TimeBudget::max(),
None,
)?; )?;
if let Some((logger, dir)) = detailed_logger { if let Some((logger, dir)) = detailed_logger {
logger.finish(&mut ctx, Path::new(dir))?; logger.finish(&mut ctx, Path::new(dir))?;

View File

@ -169,6 +169,7 @@ impl<'a> Search<'a> {
index: self.index, index: self.index,
semantic: self.semantic.clone(), semantic: self.semantic.clone(),
time_budget: self.time_budget.clone(), time_budget: self.time_budget.clone(),
ranking_score_threshold: self.ranking_score_threshold,
}; };
let semantic = search.semantic.take(); let semantic = search.semantic.take();

View File

@ -50,6 +50,7 @@ pub struct Search<'a> {
index: &'a Index, index: &'a Index,
semantic: Option<SemanticSearch>, semantic: Option<SemanticSearch>,
time_budget: TimeBudget, time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
} }
impl<'a> Search<'a> { impl<'a> Search<'a> {
@ -70,6 +71,7 @@ impl<'a> Search<'a> {
index, index,
semantic: None, semantic: None,
time_budget: TimeBudget::max(), time_budget: TimeBudget::max(),
ranking_score_threshold: None,
} }
} }
@ -146,6 +148,11 @@ impl<'a> Search<'a> {
self self
} }
pub fn ranking_score_threshold(&mut self, ranking_score_threshold: f64) -> &mut Search<'a> {
self.ranking_score_threshold = Some(ranking_score_threshold);
self
}
pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> { pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> {
if has_vector_search { if has_vector_search {
let ctx = SearchContext::new(self.index, self.rtxn)?; let ctx = SearchContext::new(self.index, self.rtxn)?;
@ -184,6 +191,7 @@ impl<'a> Search<'a> {
embedder_name, embedder_name,
embedder, embedder,
self.time_budget.clone(), self.time_budget.clone(),
self.ranking_score_threshold,
)? )?
} }
_ => execute_search( _ => execute_search(
@ -201,6 +209,7 @@ impl<'a> Search<'a> {
&mut DefaultSearchLogger, &mut DefaultSearchLogger,
&mut DefaultSearchLogger, &mut DefaultSearchLogger,
self.time_budget.clone(), self.time_budget.clone(),
self.ranking_score_threshold,
)?, )?,
}; };
@ -239,6 +248,7 @@ impl fmt::Debug for Search<'_> {
index: _, index: _,
semantic, semantic,
time_budget, time_budget,
ranking_score_threshold,
} = self; } = self;
f.debug_struct("Search") f.debug_struct("Search")
.field("query", query) .field("query", query)
@ -257,6 +267,7 @@ impl fmt::Debug for Search<'_> {
&semantic.as_ref().map(|semantic| &semantic.embedder_name), &semantic.as_ref().map(|semantic| &semantic.embedder_name),
) )
.field("time_budget", time_budget) .field("time_budget", time_budget)
.field("ranking_score_threshold", ranking_score_threshold)
.finish() .finish()
} }
} }

View File

@ -28,6 +28,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
scoring_strategy: ScoringStrategy, scoring_strategy: ScoringStrategy,
logger: &mut dyn SearchLogger<Q>, logger: &mut dyn SearchLogger<Q>,
time_budget: TimeBudget, time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
) -> Result<BucketSortOutput> { ) -> Result<BucketSortOutput> {
logger.initial_query(query); logger.initial_query(query);
logger.ranking_rules(&ranking_rules); logger.ranking_rules(&ranking_rules);
@ -164,7 +165,19 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
loop { loop {
let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]); let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]);
ranking_rule_scores.push(ScoreDetails::Skipped); ranking_rule_scores.push(ScoreDetails::Skipped);
// remove candidates from the universe without adding them to result if their score is below the threshold
if let Some(ranking_score_threshold) = ranking_score_threshold {
let current_score = ScoreDetails::global_score(ranking_rule_scores.iter());
if current_score < ranking_score_threshold {
all_candidates -= bucket | &ranking_rule_universes[cur_ranking_rule_index];
back!();
continue;
}
}
maybe_add_to_results!(bucket); maybe_add_to_results!(bucket);
ranking_rule_scores.pop(); ranking_rule_scores.pop();
if cur_ranking_rule_index == 0 { if cur_ranking_rule_index == 0 {
@ -220,6 +233,18 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
debug_assert!( debug_assert!(
ranking_rule_universes[cur_ranking_rule_index].is_superset(&next_bucket.candidates) ranking_rule_universes[cur_ranking_rule_index].is_superset(&next_bucket.candidates)
); );
// remove candidates from the universe without adding them to result if their score is below the threshold
if let Some(ranking_score_threshold) = ranking_score_threshold {
let current_score = ScoreDetails::global_score(ranking_rule_scores.iter());
if current_score < ranking_score_threshold {
all_candidates -=
next_bucket.candidates | &ranking_rule_universes[cur_ranking_rule_index];
back!();
continue;
}
}
ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates; ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates;
if cur_ranking_rule_index == ranking_rules_len - 1 if cur_ranking_rule_index == ranking_rules_len - 1

View File

@ -523,6 +523,7 @@ mod tests {
&mut crate::DefaultSearchLogger, &mut crate::DefaultSearchLogger,
&mut crate::DefaultSearchLogger, &mut crate::DefaultSearchLogger,
TimeBudget::max(), TimeBudget::max(),
None,
) )
.unwrap(); .unwrap();

View File

@ -573,6 +573,7 @@ pub fn execute_vector_search(
embedder_name: &str, embedder_name: &str,
embedder: &Embedder, embedder: &Embedder,
time_budget: TimeBudget, time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
) -> Result<PartialSearchResult> { ) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?; check_sort_criteria(ctx, sort_criteria.as_ref())?;
@ -602,6 +603,7 @@ pub fn execute_vector_search(
scoring_strategy, scoring_strategy,
placeholder_search_logger, placeholder_search_logger,
time_budget, time_budget,
ranking_score_threshold,
)?; )?;
Ok(PartialSearchResult { Ok(PartialSearchResult {
@ -631,6 +633,7 @@ pub fn execute_search(
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>, placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
query_graph_logger: &mut dyn SearchLogger<QueryGraph>, query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
time_budget: TimeBudget, time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
) -> Result<PartialSearchResult> { ) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?; check_sort_criteria(ctx, sort_criteria.as_ref())?;
@ -719,6 +722,7 @@ pub fn execute_search(
scoring_strategy, scoring_strategy,
query_graph_logger, query_graph_logger,
time_budget, time_budget,
ranking_score_threshold,
)? )?
} else { } else {
let ranking_rules = let ranking_rules =
@ -733,6 +737,7 @@ pub fn execute_search(
scoring_strategy, scoring_strategy,
placeholder_search_logger, placeholder_search_logger,
time_budget, time_budget,
ranking_score_threshold,
)? )?
}; };

View File

@ -17,6 +17,7 @@ pub struct Similar<'a> {
index: &'a Index, index: &'a Index,
embedder_name: String, embedder_name: String,
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
ranking_score_threshold: Option<f64>,
} }
impl<'a> Similar<'a> { impl<'a> Similar<'a> {
@ -29,7 +30,17 @@ impl<'a> Similar<'a> {
embedder_name: String, embedder_name: String,
embedder: Arc<Embedder>, embedder: Arc<Embedder>,
) -> Self { ) -> Self {
Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder } Self {
id,
filter: None,
offset,
limit,
rtxn,
index,
embedder_name,
embedder,
ranking_score_threshold: None,
}
} }
pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self { pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self {
@ -37,8 +48,18 @@ impl<'a> Similar<'a> {
self self
} }
pub fn ranking_score_threshold(&mut self, ranking_score_threshold: f64) -> &mut Self {
self.ranking_score_threshold = Some(ranking_score_threshold);
self
}
pub fn execute(&self) -> Result<SearchResult> { pub fn execute(&self) -> Result<SearchResult> {
let universe = filtered_universe(self.index, self.rtxn, &self.filter)?; let mut universe = filtered_universe(self.index, self.rtxn, &self.filter)?;
// we never want to receive the docid
universe.remove(self.id);
let universe = universe;
let embedder_index = let embedder_index =
self.index self.index
@ -77,6 +98,8 @@ impl<'a> Similar<'a> {
let mut documents_seen = RoaringBitmap::new(); let mut documents_seen = RoaringBitmap::new();
documents_seen.insert(self.id); documents_seen.insert(self.id);
let mut candidates = universe;
for (docid, distance) in results for (docid, distance) in results
.into_iter() .into_iter()
// skip documents we've already seen & mark that we saw the current document // skip documents we've already seen & mark that we saw the current document
@ -85,8 +108,6 @@ impl<'a> Similar<'a> {
// take **after** filter and skip so that we get exactly limit elements if available // take **after** filter and skip so that we get exactly limit elements if available
.take(self.limit) .take(self.limit)
{ {
documents_ids.push(docid);
let score = 1.0 - distance; let score = 1.0 - distance;
let score = self let score = self
.embedder .embedder
@ -94,14 +115,28 @@ impl<'a> Similar<'a> {
.map(|distribution| distribution.shift(score)) .map(|distribution| distribution.shift(score))
.unwrap_or(score); .unwrap_or(score);
let score = ScoreDetails::Vector(score_details::Vector { similarity: Some(score) }); let score_details =
vec![ScoreDetails::Vector(score_details::Vector { similarity: Some(score) })];
document_scores.push(vec![score]); let score = ScoreDetails::global_score(score_details.iter());
if let Some(ranking_score_threshold) = &self.ranking_score_threshold {
if score < *ranking_score_threshold {
// this document is no longer a candidate
candidates.remove(docid);
// any document after this one is no longer a candidate either, so restrict the set to documents already seen.
candidates &= documents_seen;
break;
}
}
documents_ids.push(docid);
document_scores.push(score_details);
} }
Ok(SearchResult { Ok(SearchResult {
matching_words: Default::default(), matching_words: Default::default(),
candidates: universe, candidates,
documents_ids, documents_ids,
document_scores, document_scores,
degraded: false, degraded: false,