diff --git a/meilisearch-types/src/deserr/mod.rs b/meilisearch-types/src/deserr/mod.rs index c593c50fb..1c1b0e987 100644 --- a/meilisearch-types/src/deserr/mod.rs +++ b/meilisearch-types/src/deserr/mod.rs @@ -189,4 +189,6 @@ merge_with_error_impl_take_error_message!(ParseTaskKindError); merge_with_error_impl_take_error_message!(ParseTaskStatusError); merge_with_error_impl_take_error_message!(IndexUidFormatError); merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio); +merge_with_error_impl_take_error_message!(InvalidSearchRankingScoreThreshold); +merge_with_error_impl_take_error_message!(InvalidSimilarRankingScoreThreshold); merge_with_error_impl_take_error_message!(InvalidSimilarId); diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index d2218807f..150c56b9d 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -241,6 +241,8 @@ InvalidSearchAttributesToCrop , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToHighlight , InvalidRequest , BAD_REQUEST ; InvalidSimilarAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; +InvalidSearchRankingScoreThreshold , InvalidRequest , BAD_REQUEST ; +InvalidSimilarRankingScoreThreshold , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; @@ -505,6 +507,21 @@ impl fmt::Display for deserr_codes::InvalidSimilarId { } } +impl fmt::Display for deserr_codes::InvalidSearchRankingScoreThreshold { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`." + ) + } +} + +impl fmt::Display for deserr_codes::InvalidSimilarRankingScoreThreshold { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + deserr_codes::InvalidSearchRankingScoreThreshold.fmt(f) + } +} + #[macro_export] macro_rules! internal_error { ($target:ty : $($other:path), *) => { diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index add430893..aed29e612 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -648,6 +648,7 @@ pub struct SearchAggregator { // scoring show_ranking_score: bool, show_ranking_score_details: bool, + ranking_score_threshold: bool, } impl SearchAggregator { @@ -676,6 +677,7 @@ impl SearchAggregator { matching_strategy, attributes_to_search_on, hybrid, + ranking_score_threshold, } = query; let mut ret = Self::default(); @@ -748,6 +750,7 @@ impl SearchAggregator { ret.show_ranking_score = *show_ranking_score; ret.show_ranking_score_details = *show_ranking_score_details; + ret.ranking_score_threshold = ranking_score_threshold.is_some(); if let Some(hybrid) = hybrid { ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO(); @@ -821,6 +824,7 @@ impl SearchAggregator { hybrid, total_degraded, total_used_negative_operator, + ranking_score_threshold, } = other; if self.timestamp.is_none() { @@ -904,6 +908,7 @@ impl SearchAggregator { // scoring self.show_ranking_score |= show_ranking_score; self.show_ranking_score_details |= show_ranking_score_details; + self.ranking_score_threshold |= ranking_score_threshold; } pub fn into_event(self, user: &User, event_name: &str) -> Option { @@ -945,6 +950,7 @@ impl SearchAggregator { hybrid, total_degraded, total_used_negative_operator, + ranking_score_threshold, } = self; if total_received == 0 { @@ -1015,6 +1021,7 @@ impl SearchAggregator { "scoring": { "show_ranking_score": show_ranking_score, "show_ranking_score_details": show_ranking_score_details, + "ranking_score_threshold": ranking_score_threshold, }, }); @@ -1087,6 +1094,7 @@ impl MultiSearchAggregator { matching_strategy: _, attributes_to_search_on: _, hybrid: _, + ranking_score_threshold: _, } = query; index_uid.as_str() @@ -1234,6 +1242,7 @@ impl FacetSearchAggregator { matching_strategy, attributes_to_search_on, hybrid, + ranking_score_threshold, } = query; let mut ret = Self::default(); @@ -1248,7 +1257,8 @@ impl FacetSearchAggregator { || filter.is_some() || *matching_strategy != MatchingStrategy::default() || attributes_to_search_on.is_some() - || hybrid.is_some(); + || hybrid.is_some() + || ranking_score_threshold.is_some(); ret } @@ -1624,6 +1634,7 @@ pub struct SimilarAggregator { // scoring show_ranking_score: bool, show_ranking_score_details: bool, + ranking_score_threshold: bool, } impl SimilarAggregator { @@ -1638,6 +1649,7 @@ impl SimilarAggregator { show_ranking_score, show_ranking_score_details, filter, + ranking_score_threshold, } = query; let mut ret = Self::default(); @@ -1675,6 +1687,7 @@ impl SimilarAggregator { ret.show_ranking_score = *show_ranking_score; ret.show_ranking_score_details = *show_ranking_score_details; + ret.ranking_score_threshold = ranking_score_threshold.is_some(); ret.embedder = embedder.is_some(); @@ -1708,6 +1721,7 @@ impl SimilarAggregator { show_ranking_score, show_ranking_score_details, embedder, + ranking_score_threshold, } = other; if self.timestamp.is_none() { @@ -1749,6 +1763,7 @@ impl SimilarAggregator { // scoring self.show_ranking_score |= show_ranking_score; self.show_ranking_score_details |= show_ranking_score_details; + self.ranking_score_threshold |= ranking_score_threshold; } pub fn into_event(self, user: &User, event_name: &str) -> Option { @@ -1769,6 +1784,7 @@ impl SimilarAggregator { show_ranking_score, show_ranking_score_details, embedder, + ranking_score_threshold, } = self; if total_received == 0 { @@ -1808,6 +1824,7 @@ impl SimilarAggregator { "scoring": { "show_ranking_score": show_ranking_score, "show_ranking_score_details": show_ranking_score_details, + "ranking_score_threshold": ranking_score_threshold, }, }); diff --git a/meilisearch/src/routes/indexes/facet_search.rs b/meilisearch/src/routes/indexes/facet_search.rs index 3f05fa846..10b371f2d 100644 --- a/meilisearch/src/routes/indexes/facet_search.rs +++ b/meilisearch/src/routes/indexes/facet_search.rs @@ -14,8 +14,8 @@ use crate::extractors::authentication::policies::*; use crate::extractors::authentication::GuardedData; use crate::routes::indexes::search::search_kind; use crate::search::{ - add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery, - DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, + add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, RankingScoreThreshold, + SearchQuery, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, }; use crate::search_queue::SearchQueue; @@ -46,6 +46,8 @@ pub struct FacetSearchQuery { pub matching_strategy: MatchingStrategy, #[deserr(default, error = DeserrJsonError, default)] pub attributes_to_search_on: Option>, + #[deserr(default, error = DeserrJsonError, default)] + pub ranking_score_threshold: Option, } pub async fn search( @@ -103,6 +105,7 @@ impl From for SearchQuery { matching_strategy, attributes_to_search_on, hybrid, + ranking_score_threshold, } = value; SearchQuery { @@ -128,6 +131,7 @@ impl From for SearchQuery { vector, attributes_to_search_on, hybrid, + ranking_score_threshold, } } } diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 8628da6d9..348d8295c 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -19,9 +19,10 @@ use crate::extractors::authentication::GuardedData; use crate::extractors::sequential_extractor::SeqHandler; use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; use crate::search::{ - add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchKind, SearchQuery, - SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, - DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, + add_search_rules, perform_search, HybridQuery, MatchingStrategy, RankingScoreThreshold, + SearchKind, SearchQuery, SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, + DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, + DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, }; use crate::search_queue::SearchQueue; @@ -82,6 +83,21 @@ pub struct SearchQueryGet { pub hybrid_embedder: Option, #[deserr(default, error = DeserrQueryParamError)] pub hybrid_semantic_ratio: Option, + #[deserr(default, error = DeserrQueryParamError)] + pub ranking_score_threshold: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)] +#[deserr(try_from(String) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)] +pub struct RankingScoreThresholdGet(RankingScoreThreshold); + +impl std::convert::TryFrom for RankingScoreThresholdGet { + type Error = InvalidSearchRankingScoreThreshold; + + fn try_from(s: String) -> Result { + let f: f64 = s.parse().map_err(|_| InvalidSearchRankingScoreThreshold)?; + Ok(RankingScoreThresholdGet(RankingScoreThreshold::try_from(f)?)) + } } #[derive(Debug, Clone, Copy, Default, PartialEq, deserr::Deserr)] @@ -152,6 +168,7 @@ impl From for SearchQuery { matching_strategy: other.matching_strategy, attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()), hybrid, + ranking_score_threshold: other.ranking_score_threshold.map(|o| o.0), } } } diff --git a/meilisearch/src/routes/indexes/similar.rs b/meilisearch/src/routes/indexes/similar.rs index da73dd63b..518fedab7 100644 --- a/meilisearch/src/routes/indexes/similar.rs +++ b/meilisearch/src/routes/indexes/similar.rs @@ -6,8 +6,8 @@ use meilisearch_types::deserr::query_params::Param; use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::error::deserr_codes::{ InvalidEmbedder, InvalidSimilarAttributesToRetrieve, InvalidSimilarFilter, InvalidSimilarId, - InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarShowRankingScore, - InvalidSimilarShowRankingScoreDetails, + InvalidSimilarLimit, InvalidSimilarOffset, InvalidSimilarRankingScoreThreshold, + InvalidSimilarShowRankingScore, InvalidSimilarShowRankingScoreDetails, }; use meilisearch_types::error::{ErrorCode as _, ResponseError}; use meilisearch_types::index_uid::IndexUid; @@ -21,8 +21,8 @@ use crate::analytics::{Analytics, SimilarAggregator}; use crate::extractors::authentication::GuardedData; use crate::extractors::sequential_extractor::SeqHandler; use crate::search::{ - add_search_rules, perform_similar, SearchKind, SimilarQuery, SimilarResult, - DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, + add_search_rules, perform_similar, RankingScoreThresholdSimilar, SearchKind, SimilarQuery, + SimilarResult, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, }; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -42,9 +42,7 @@ pub async fn similar_get( ) -> Result { let index_uid = IndexUid::try_from(index_uid.into_inner())?; - let query = params.0.try_into().map_err(|code: InvalidSimilarId| { - ResponseError::from_msg(code.to_string(), code.error_code()) - })?; + let query = params.0.try_into()?; let mut aggregate = SimilarAggregator::from_query(&query, &req); @@ -130,12 +128,27 @@ pub struct SimilarQueryGet { show_ranking_score: Param, #[deserr(default, error = DeserrQueryParamError)] show_ranking_score_details: Param, + #[deserr(default, error = DeserrQueryParamError, default)] + pub ranking_score_threshold: Option, #[deserr(default, error = DeserrQueryParamError)] pub embedder: Option, } +#[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)] +#[deserr(try_from(String) = TryFrom::try_from -> InvalidSimilarRankingScoreThreshold)] +pub struct RankingScoreThresholdGet(RankingScoreThresholdSimilar); + +impl std::convert::TryFrom for RankingScoreThresholdGet { + type Error = InvalidSimilarRankingScoreThreshold; + + fn try_from(s: String) -> Result { + let f: f64 = s.parse().map_err(|_| InvalidSimilarRankingScoreThreshold)?; + Ok(RankingScoreThresholdGet(RankingScoreThresholdSimilar::try_from(f)?)) + } +} + impl TryFrom for SimilarQuery { - type Error = InvalidSimilarId; + type Error = ResponseError; fn try_from( SimilarQueryGet { @@ -147,6 +160,7 @@ impl TryFrom for SimilarQuery { show_ranking_score, show_ranking_score_details, embedder, + ranking_score_threshold, }: SimilarQueryGet, ) -> Result { let filter = match filter { @@ -158,7 +172,9 @@ impl TryFrom for SimilarQuery { }; Ok(SimilarQuery { - id: id.0.try_into()?, + id: id.0.try_into().map_err(|code: InvalidSimilarId| { + ResponseError::from_msg(code.to_string(), code.error_code()) + })?, offset: offset.0, limit: limit.0, filter, @@ -166,6 +182,7 @@ impl TryFrom for SimilarQuery { attributes_to_retrieve: attributes_to_retrieve.map(|o| o.into_iter().collect()), show_ranking_score: show_ranking_score.0, show_ranking_score_details: show_ranking_score_details.0, + ranking_score_threshold: ranking_score_threshold.map(|x| x.0), }) } } diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 0c2c49452..05b3c1aff 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -87,6 +87,44 @@ pub struct SearchQuery { pub matching_strategy: MatchingStrategy, #[deserr(default, error = DeserrJsonError, default)] pub attributes_to_search_on: Option>, + #[deserr(default, error = DeserrJsonError, default)] + pub ranking_score_threshold: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Deserr)] +#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)] +pub struct RankingScoreThreshold(f64); + +impl std::convert::TryFrom for RankingScoreThreshold { + type Error = InvalidSearchRankingScoreThreshold; + + fn try_from(f: f64) -> Result { + // the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable + #[allow(clippy::manual_range_contains)] + if f > 1.0 || f < 0.0 { + Err(InvalidSearchRankingScoreThreshold) + } else { + Ok(RankingScoreThreshold(f)) + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Deserr)] +#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSimilarRankingScoreThreshold)] +pub struct RankingScoreThresholdSimilar(f64); + +impl std::convert::TryFrom for RankingScoreThresholdSimilar { + type Error = InvalidSimilarRankingScoreThreshold; + + fn try_from(f: f64) -> Result { + // the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable + #[allow(clippy::manual_range_contains)] + if f > 1.0 || f < 0.0 { + Err(InvalidSimilarRankingScoreThreshold) + } else { + Ok(Self(f)) + } + } } // Since this structure is logged A LOT we're going to reduce the number of things it logs to the bare minimum. @@ -117,6 +155,7 @@ impl fmt::Debug for SearchQuery { crop_marker, matching_strategy, attributes_to_search_on, + ranking_score_threshold, } = self; let mut debug = f.debug_struct("SearchQuery"); @@ -188,6 +227,9 @@ impl fmt::Debug for SearchQuery { debug.field("highlight_pre_tag", &highlight_pre_tag); debug.field("highlight_post_tag", &highlight_post_tag); debug.field("crop_marker", &crop_marker); + if let Some(ranking_score_threshold) = ranking_score_threshold { + debug.field("ranking_score_threshold", &ranking_score_threshold); + } debug.finish() } @@ -356,6 +398,8 @@ pub struct SearchQueryWithIndex { pub matching_strategy: MatchingStrategy, #[deserr(default, error = DeserrJsonError, default)] pub attributes_to_search_on: Option>, + #[deserr(default, error = DeserrJsonError, default)] + pub ranking_score_threshold: Option, } impl SearchQueryWithIndex { @@ -384,6 +428,7 @@ impl SearchQueryWithIndex { matching_strategy, attributes_to_search_on, hybrid, + ranking_score_threshold, } = self; ( index_uid, @@ -410,6 +455,7 @@ impl SearchQueryWithIndex { matching_strategy, attributes_to_search_on, hybrid, + ranking_score_threshold, // do not use ..Default::default() here, // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` }, @@ -436,6 +482,8 @@ pub struct SimilarQuery { pub show_ranking_score: bool, #[deserr(default, error = DeserrJsonError, default)] pub show_ranking_score_details: bool, + #[deserr(default, error = DeserrJsonError, default)] + pub ranking_score_threshold: Option, } #[derive(Debug, Clone, PartialEq, Deserr)] @@ -664,6 +712,9 @@ fn prepare_search<'t>( ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { let mut search = index.search(rtxn); search.time_budget(time_budget); + if let Some(ranking_score_threshold) = query.ranking_score_threshold { + search.ranking_score_threshold(ranking_score_threshold.0); + } match search_kind { SearchKind::KeywordOnly => { @@ -705,11 +756,16 @@ fn prepare_search<'t>( .unwrap_or(DEFAULT_PAGINATION_MAX_TOTAL_HITS); search.exhaustive_number_hits(is_finite_pagination); - search.scoring_strategy(if query.show_ranking_score || query.show_ranking_score_details { - ScoringStrategy::Detailed - } else { - ScoringStrategy::Skip - }); + search.scoring_strategy( + if query.show_ranking_score + || query.show_ranking_score_details + || query.ranking_score_threshold.is_some() + { + ScoringStrategy::Detailed + } else { + ScoringStrategy::Skip + }, + ); // compute the offset on the limit depending on the pagination mode. let (offset, limit) = if is_finite_pagination { @@ -787,10 +843,6 @@ pub fn perform_search( let SearchQuery { q, - vector: _, - hybrid: _, - // already computed from prepare_search - offset: _, limit, page, hits_per_page, @@ -801,14 +853,19 @@ pub fn perform_search( show_matches_position, show_ranking_score, show_ranking_score_details, - filter: _, sort, facets, highlight_pre_tag, highlight_post_tag, crop_marker, + // already used in prepare_search + vector: _, + hybrid: _, + offset: _, + ranking_score_threshold: _, matching_strategy: _, attributes_to_search_on: _, + filter: _, } = query; let format = AttributesFormat { @@ -1070,6 +1127,7 @@ pub fn perform_similar( attributes_to_retrieve, show_ranking_score, show_ranking_score_details, + ranking_score_threshold, } = query; // using let-else rather than `?` so that the borrow checker identifies we're always returning here, @@ -1093,6 +1151,10 @@ pub fn perform_similar( } } + if let Some(ranking_score_threshold) = ranking_score_threshold { + similar.ranking_score_threshold(ranking_score_threshold.0); + } + let milli::SearchResult { documents_ids, matching_words: _, diff --git a/meilisearch/tests/search/errors.rs b/meilisearch/tests/search/errors.rs index cce1a86e7..53d516c44 100644 --- a/meilisearch/tests/search/errors.rs +++ b/meilisearch/tests/search/errors.rs @@ -321,6 +321,40 @@ async fn search_bad_facets() { // Can't make the `attributes_to_highlight` fail with a get search since it'll accept anything as an array of strings. } +#[actix_rt::test] +async fn search_bad_threshold() { + let server = Server::new().await; + let index = server.index("test"); + + let (response, code) = index.search_post(json!({"rankingScoreThreshold": "doggo"})).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value type at `.rankingScoreThreshold`: expected a number, but found a string: `\"doggo\"`", + "code": "invalid_search_ranking_score_threshold", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_search_ranking_score_threshold" + } + "###); +} + +#[actix_rt::test] +async fn search_invalid_threshold() { + let server = Server::new().await; + let index = server.index("test"); + + let (response, code) = index.search_post(json!({"rankingScoreThreshold": 42})).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value at `.rankingScoreThreshold`: the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`.", + "code": "invalid_search_ranking_score_threshold", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_search_ranking_score_threshold" + } + "###); +} + #[actix_rt::test] async fn search_non_filterable_facets() { let server = Server::new().await; diff --git a/meilisearch/tests/search/mod.rs b/meilisearch/tests/search/mod.rs index 284b68a15..b65c0dc42 100644 --- a/meilisearch/tests/search/mod.rs +++ b/meilisearch/tests/search/mod.rs @@ -48,6 +48,31 @@ static DOCUMENTS: Lazy = Lazy::new(|| { ]) }); +static SCORE_DOCUMENTS: Lazy = Lazy::new(|| { + json!([ + { + "title": "Batman the dark knight returns: Part 1", + "id": "A", + }, + { + "title": "Batman the dark knight returns: Part 2", + "id": "B", + }, + { + "title": "Batman Returns", + "id": "C", + }, + { + "title": "Batman", + "id": "D", + }, + { + "title": "Badman", + "id": "E", + } + ]) +}); + static NESTED_DOCUMENTS: Lazy = Lazy::new(|| { json!([ { @@ -960,6 +985,213 @@ async fn test_score_details() { .await; } +#[actix_rt::test] +async fn test_score() { + let server = Server::new().await; + let index = server.index("test"); + + let documents = SCORE_DOCUMENTS.clone(); + + let res = index.add_documents(json!(documents), None).await; + index.wait_task(res.0.uid()).await; + + index + .search( + json!({ + "q": "Badman the dark knight returns 1", + "showRankingScore": true, + }), + |response, code| { + meili_snap::snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###" + [ + { + "title": "Batman the dark knight returns: Part 1", + "id": "A", + "_rankingScore": 0.9746605609456898 + }, + { + "title": "Batman the dark knight returns: Part 2", + "id": "B", + "_rankingScore": 0.8055252965383685 + }, + { + "title": "Badman", + "id": "E", + "_rankingScore": 0.16666666666666666 + }, + { + "title": "Batman Returns", + "id": "C", + "_rankingScore": 0.07702020202020202 + }, + { + "title": "Batman", + "id": "D", + "_rankingScore": 0.07702020202020202 + } + ] + "###); + }, + ) + .await; +} + +#[actix_rt::test] +async fn test_score_threshold() { + let query = "Badman dark returns 1"; + let server = Server::new().await; + let index = server.index("test"); + + let documents = SCORE_DOCUMENTS.clone(); + + let res = index.add_documents(json!(documents), None).await; + index.wait_task(res.0.uid()).await; + + index + .search( + json!({ + "q": query, + "showRankingScore": true, + "rankingScoreThreshold": 0.0 + }), + |response, code| { + meili_snap::snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @"5"); + meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###" + [ + { + "title": "Batman the dark knight returns: Part 1", + "id": "A", + "_rankingScore": 0.93430081300813 + }, + { + "title": "Batman the dark knight returns: Part 2", + "id": "B", + "_rankingScore": 0.6685627880184332 + }, + { + "title": "Badman", + "id": "E", + "_rankingScore": 0.25 + }, + { + "title": "Batman Returns", + "id": "C", + "_rankingScore": 0.11553030303030302 + }, + { + "title": "Batman", + "id": "D", + "_rankingScore": 0.11553030303030302 + } + ] + "###); + }, + ) + .await; + + index + .search( + json!({ + "q": query, + "showRankingScore": true, + "rankingScoreThreshold": 0.2 + }), + |response, code| { + meili_snap::snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @r###"3"###); + meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###" + [ + { + "title": "Batman the dark knight returns: Part 1", + "id": "A", + "_rankingScore": 0.93430081300813 + }, + { + "title": "Batman the dark knight returns: Part 2", + "id": "B", + "_rankingScore": 0.6685627880184332 + }, + { + "title": "Badman", + "id": "E", + "_rankingScore": 0.25 + } + ] + "###); + }, + ) + .await; + + index + .search( + json!({ + "q": query, + "showRankingScore": true, + "rankingScoreThreshold": 0.5 + }), + |response, code| { + meili_snap::snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @r###"2"###); + meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###" + [ + { + "title": "Batman the dark knight returns: Part 1", + "id": "A", + "_rankingScore": 0.93430081300813 + }, + { + "title": "Batman the dark knight returns: Part 2", + "id": "B", + "_rankingScore": 0.6685627880184332 + } + ] + "###); + }, + ) + .await; + + index + .search( + json!({ + "q": query, + "showRankingScore": true, + "rankingScoreThreshold": 0.8 + }), + |response, code| { + meili_snap::snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @r###"1"###); + meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###" + [ + { + "title": "Batman the dark knight returns: Part 1", + "id": "A", + "_rankingScore": 0.93430081300813 + } + ] + "###); + }, + ) + .await; + + index + .search( + json!({ + "q": query, + "showRankingScore": true, + "rankingScoreThreshold": 1.0 + }), + |response, code| { + meili_snap::snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @r###"0"###); + // nobody is perfect + meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @"[]"); + }, + ) + .await; +} + #[actix_rt::test] async fn test_degraded_score_details() { let server = Server::new().await; diff --git a/meilisearch/tests/similar/errors.rs b/meilisearch/tests/similar/errors.rs index 64386a7bf..7765b9a85 100644 --- a/meilisearch/tests/similar/errors.rs +++ b/meilisearch/tests/similar/errors.rs @@ -87,6 +87,68 @@ async fn similar_bad_id() { "###); } +#[actix_rt::test] +async fn similar_bad_ranking_score_threshold() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let (response, code) = index.similar_post(json!({"rankingScoreThreshold": ["doggo"]})).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value type at `.rankingScoreThreshold`: expected a number, but found an array: `[\"doggo\"]`", + "code": "invalid_similar_ranking_score_threshold", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_ranking_score_threshold" + } + "###); +} + +#[actix_rt::test] +async fn similar_invalid_ranking_score_threshold() { + let server = Server::new().await; + let index = server.index("test"); + server.set_features(json!({"vectorStore": true})).await; + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let (response, code) = index.similar_post(json!({"rankingScoreThreshold": 42})).await; + snapshot!(code, @"400 Bad Request"); + snapshot!(json_string!(response), @r###" + { + "message": "Invalid value at `.rankingScoreThreshold`: the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`.", + "code": "invalid_similar_ranking_score_threshold", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_similar_ranking_score_threshold" + } + "###); +} + #[actix_rt::test] async fn similar_invalid_id() { let server = Server::new().await; diff --git a/meilisearch/tests/similar/mod.rs b/meilisearch/tests/similar/mod.rs index ee78917cb..bde23b67f 100644 --- a/meilisearch/tests/similar/mod.rs +++ b/meilisearch/tests/similar/mod.rs @@ -194,6 +194,235 @@ async fn basic() { .await; } +#[actix_rt::test] +async fn ranking_score_threshold() { + let server = Server::new().await; + let index = server.index("test"); + let (value, code) = server.set_features(json!({"vectorStore": true})).await; + snapshot!(code, @"200 OK"); + snapshot!(value, @r###" + { + "vectorStore": true, + "metrics": false, + "logsRoute": false + } + "###); + + let (response, code) = index + .update_settings(json!({ + "embedders": { + "manual": { + "source": "userProvided", + "dimensions": 3, + } + }, + "filterableAttributes": ["title"]})) + .await; + snapshot!(code, @"202 Accepted"); + server.wait_task(response.uid()).await; + + let documents = DOCUMENTS.clone(); + let (value, code) = index.add_documents(documents, None).await; + snapshot!(code, @"202 Accepted"); + index.wait_task(value.uid()).await; + + index + .similar( + json!({"id": 143, "showRankingScore": true, "rankingScoreThreshold": 0}), + |response, code| { + snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @"4"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "title": "Escape Room", + "release_year": 2019, + "id": "522681", + "_vectors": { + "manual": [ + 0.1, + 0.6, + 0.8 + ] + }, + "_rankingScore": 0.890957772731781 + }, + { + "title": "Captain Marvel", + "release_year": 2019, + "id": "299537", + "_vectors": { + "manual": [ + 0.6, + 0.8, + -0.2 + ] + }, + "_rankingScore": 0.39060014486312866 + }, + { + "title": "How to Train Your Dragon: The Hidden World", + "release_year": 2019, + "id": "166428", + "_vectors": { + "manual": [ + 0.7, + 0.7, + -0.4 + ] + }, + "_rankingScore": 0.2819308042526245 + }, + { + "title": "Shazam!", + "release_year": 2019, + "id": "287947", + "_vectors": { + "manual": [ + 0.8, + 0.4, + -0.5 + ] + }, + "_rankingScore": 0.1662663221359253 + } + ] + "###); + }, + ) + .await; + + index + .similar( + json!({"id": 143, "showRankingScore": true, "rankingScoreThreshold": 0.2}), + |response, code| { + snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @"3"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "title": "Escape Room", + "release_year": 2019, + "id": "522681", + "_vectors": { + "manual": [ + 0.1, + 0.6, + 0.8 + ] + }, + "_rankingScore": 0.890957772731781 + }, + { + "title": "Captain Marvel", + "release_year": 2019, + "id": "299537", + "_vectors": { + "manual": [ + 0.6, + 0.8, + -0.2 + ] + }, + "_rankingScore": 0.39060014486312866 + }, + { + "title": "How to Train Your Dragon: The Hidden World", + "release_year": 2019, + "id": "166428", + "_vectors": { + "manual": [ + 0.7, + 0.7, + -0.4 + ] + }, + "_rankingScore": 0.2819308042526245 + } + ] + "###); + }, + ) + .await; + + index + .similar( + json!({"id": 143, "showRankingScore": true, "rankingScoreThreshold": 0.3}), + |response, code| { + snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @"2"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "title": "Escape Room", + "release_year": 2019, + "id": "522681", + "_vectors": { + "manual": [ + 0.1, + 0.6, + 0.8 + ] + }, + "_rankingScore": 0.890957772731781 + }, + { + "title": "Captain Marvel", + "release_year": 2019, + "id": "299537", + "_vectors": { + "manual": [ + 0.6, + 0.8, + -0.2 + ] + }, + "_rankingScore": 0.39060014486312866 + } + ] + "###); + }, + ) + .await; + + index + .similar( + json!({"id": 143, "showRankingScore": true, "rankingScoreThreshold": 0.6}), + |response, code| { + snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response["estimatedTotalHits"]), @"1"); + snapshot!(json_string!(response["hits"]), @r###" + [ + { + "title": "Escape Room", + "release_year": 2019, + "id": "522681", + "_vectors": { + "manual": [ + 0.1, + 0.6, + 0.8 + ] + }, + "_rankingScore": 0.890957772731781 + } + ] + "###); + }, + ) + .await; + + index + .similar( + json!({"id": 143, "showRankingScore": true, "rankingScoreThreshold": 0.9}), + |response, code| { + snapshot!(code, @"200 OK"); + snapshot!(json_string!(response["hits"]), @"[]"); + }, + ) + .await; +} + #[actix_rt::test] async fn filter() { let server = Server::new().await; diff --git a/milli/examples/search.rs b/milli/examples/search.rs index 2779f5b15..0195c396f 100644 --- a/milli/examples/search.rs +++ b/milli/examples/search.rs @@ -66,6 +66,7 @@ fn main() -> Result<(), Box> { &mut DefaultSearchLogger, logger, TimeBudget::max(), + None, )?; if let Some((logger, dir)) = detailed_logger { logger.finish(&mut ctx, Path::new(dir))?; diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs index fc13a5e1e..87f922c4c 100644 --- a/milli/src/search/hybrid.rs +++ b/milli/src/search/hybrid.rs @@ -169,6 +169,7 @@ impl<'a> Search<'a> { index: self.index, semantic: self.semantic.clone(), time_budget: self.time_budget.clone(), + ranking_score_threshold: self.ranking_score_threshold, }; let semantic = search.semantic.take(); diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 19d5ff358..49d73ff31 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -50,6 +50,7 @@ pub struct Search<'a> { index: &'a Index, semantic: Option, time_budget: TimeBudget, + ranking_score_threshold: Option, } impl<'a> Search<'a> { @@ -70,6 +71,7 @@ impl<'a> Search<'a> { index, semantic: None, time_budget: TimeBudget::max(), + ranking_score_threshold: None, } } @@ -146,6 +148,11 @@ impl<'a> Search<'a> { self } + pub fn ranking_score_threshold(&mut self, ranking_score_threshold: f64) -> &mut Search<'a> { + self.ranking_score_threshold = Some(ranking_score_threshold); + self + } + pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result { if has_vector_search { let ctx = SearchContext::new(self.index, self.rtxn)?; @@ -184,6 +191,7 @@ impl<'a> Search<'a> { embedder_name, embedder, self.time_budget.clone(), + self.ranking_score_threshold, )? } _ => execute_search( @@ -201,6 +209,7 @@ impl<'a> Search<'a> { &mut DefaultSearchLogger, &mut DefaultSearchLogger, self.time_budget.clone(), + self.ranking_score_threshold, )?, }; @@ -239,6 +248,7 @@ impl fmt::Debug for Search<'_> { index: _, semantic, time_budget, + ranking_score_threshold, } = self; f.debug_struct("Search") .field("query", query) @@ -257,6 +267,7 @@ impl fmt::Debug for Search<'_> { &semantic.as_ref().map(|semantic| &semantic.embedder_name), ) .field("time_budget", time_budget) + .field("ranking_score_threshold", ranking_score_threshold) .finish() } } diff --git a/milli/src/search/new/bucket_sort.rs b/milli/src/search/new/bucket_sort.rs index e9bc5449d..d937c78bf 100644 --- a/milli/src/search/new/bucket_sort.rs +++ b/milli/src/search/new/bucket_sort.rs @@ -28,6 +28,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( scoring_strategy: ScoringStrategy, logger: &mut dyn SearchLogger, time_budget: TimeBudget, + ranking_score_threshold: Option, ) -> Result { logger.initial_query(query); logger.ranking_rules(&ranking_rules); @@ -164,7 +165,19 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( loop { let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]); ranking_rule_scores.push(ScoreDetails::Skipped); + + // remove candidates from the universe without adding them to result if their score is below the threshold + if let Some(ranking_score_threshold) = ranking_score_threshold { + let current_score = ScoreDetails::global_score(ranking_rule_scores.iter()); + if current_score < ranking_score_threshold { + all_candidates -= bucket | &ranking_rule_universes[cur_ranking_rule_index]; + back!(); + continue; + } + } + maybe_add_to_results!(bucket); + ranking_rule_scores.pop(); if cur_ranking_rule_index == 0 { @@ -220,6 +233,18 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( debug_assert!( ranking_rule_universes[cur_ranking_rule_index].is_superset(&next_bucket.candidates) ); + + // remove candidates from the universe without adding them to result if their score is below the threshold + if let Some(ranking_score_threshold) = ranking_score_threshold { + let current_score = ScoreDetails::global_score(ranking_rule_scores.iter()); + if current_score < ranking_score_threshold { + all_candidates -= + next_bucket.candidates | &ranking_rule_universes[cur_ranking_rule_index]; + back!(); + continue; + } + } + ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates; if cur_ranking_rule_index == ranking_rules_len - 1 diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index f121971b8..87ddb2915 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -523,6 +523,7 @@ mod tests { &mut crate::DefaultSearchLogger, &mut crate::DefaultSearchLogger, TimeBudget::max(), + None, ) .unwrap(); diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index f1d1db6a9..623c72567 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -573,6 +573,7 @@ pub fn execute_vector_search( embedder_name: &str, embedder: &Embedder, time_budget: TimeBudget, + ranking_score_threshold: Option, ) -> Result { check_sort_criteria(ctx, sort_criteria.as_ref())?; @@ -602,6 +603,7 @@ pub fn execute_vector_search( scoring_strategy, placeholder_search_logger, time_budget, + ranking_score_threshold, )?; Ok(PartialSearchResult { @@ -631,6 +633,7 @@ pub fn execute_search( placeholder_search_logger: &mut dyn SearchLogger, query_graph_logger: &mut dyn SearchLogger, time_budget: TimeBudget, + ranking_score_threshold: Option, ) -> Result { check_sort_criteria(ctx, sort_criteria.as_ref())?; @@ -719,6 +722,7 @@ pub fn execute_search( scoring_strategy, query_graph_logger, time_budget, + ranking_score_threshold, )? } else { let ranking_rules = @@ -733,6 +737,7 @@ pub fn execute_search( scoring_strategy, placeholder_search_logger, time_budget, + ranking_score_threshold, )? }; diff --git a/milli/src/search/similar.rs b/milli/src/search/similar.rs index 49b7c876f..bf5cc323f 100644 --- a/milli/src/search/similar.rs +++ b/milli/src/search/similar.rs @@ -17,6 +17,7 @@ pub struct Similar<'a> { index: &'a Index, embedder_name: String, embedder: Arc, + ranking_score_threshold: Option, } impl<'a> Similar<'a> { @@ -29,7 +30,17 @@ impl<'a> Similar<'a> { embedder_name: String, embedder: Arc, ) -> Self { - Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder } + Self { + id, + filter: None, + offset, + limit, + rtxn, + index, + embedder_name, + embedder, + ranking_score_threshold: None, + } } pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self { @@ -37,8 +48,18 @@ impl<'a> Similar<'a> { self } + pub fn ranking_score_threshold(&mut self, ranking_score_threshold: f64) -> &mut Self { + self.ranking_score_threshold = Some(ranking_score_threshold); + self + } + pub fn execute(&self) -> Result { - let universe = filtered_universe(self.index, self.rtxn, &self.filter)?; + let mut universe = filtered_universe(self.index, self.rtxn, &self.filter)?; + + // we never want to receive the docid + universe.remove(self.id); + + let universe = universe; let embedder_index = self.index @@ -77,6 +98,8 @@ impl<'a> Similar<'a> { let mut documents_seen = RoaringBitmap::new(); documents_seen.insert(self.id); + let mut candidates = universe; + for (docid, distance) in results .into_iter() // skip documents we've already seen & mark that we saw the current document @@ -85,8 +108,6 @@ impl<'a> Similar<'a> { // take **after** filter and skip so that we get exactly limit elements if available .take(self.limit) { - documents_ids.push(docid); - let score = 1.0 - distance; let score = self .embedder @@ -94,14 +115,28 @@ impl<'a> Similar<'a> { .map(|distribution| distribution.shift(score)) .unwrap_or(score); - let score = ScoreDetails::Vector(score_details::Vector { similarity: Some(score) }); + let score_details = + vec![ScoreDetails::Vector(score_details::Vector { similarity: Some(score) })]; - document_scores.push(vec![score]); + let score = ScoreDetails::global_score(score_details.iter()); + + if let Some(ranking_score_threshold) = &self.ranking_score_threshold { + if score < *ranking_score_threshold { + // this document is no longer a candidate + candidates.remove(docid); + // any document after this one is no longer a candidate either, so restrict the set to documents already seen. + candidates &= documents_seen; + break; + } + } + + documents_ids.push(docid); + document_scores.push(score_details); } Ok(SearchResult { matching_words: Default::default(), - candidates: universe, + candidates, documents_ids, document_scores, degraded: false,