diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs index fc13a5e1e..87f922c4c 100644 --- a/milli/src/search/hybrid.rs +++ b/milli/src/search/hybrid.rs @@ -169,6 +169,7 @@ impl<'a> Search<'a> { index: self.index, semantic: self.semantic.clone(), time_budget: self.time_budget.clone(), + ranking_score_threshold: self.ranking_score_threshold, }; let semantic = search.semantic.take(); diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index bab67e6bd..73c811049 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -49,6 +49,7 @@ pub struct Search<'a> { index: &'a Index, semantic: Option, time_budget: TimeBudget, + ranking_score_threshold: Option, } impl<'a> Search<'a> { @@ -69,6 +70,7 @@ impl<'a> Search<'a> { index, semantic: None, time_budget: TimeBudget::max(), + ranking_score_threshold: None, } } @@ -145,6 +147,14 @@ impl<'a> Search<'a> { self } + pub fn ranking_score_threshold( + &mut self, + ranking_score_threshold: Option, + ) -> &mut Search<'a> { + self.ranking_score_threshold = ranking_score_threshold; + self + } + pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result { if has_vector_search { let ctx = SearchContext::new(self.index, self.rtxn); @@ -183,6 +193,7 @@ impl<'a> Search<'a> { embedder_name, embedder, self.time_budget.clone(), + self.ranking_score_threshold, )? } _ => execute_search( @@ -200,6 +211,7 @@ impl<'a> Search<'a> { &mut DefaultSearchLogger, &mut DefaultSearchLogger, self.time_budget.clone(), + self.ranking_score_threshold, )?, }; @@ -238,6 +250,7 @@ impl fmt::Debug for Search<'_> { index: _, semantic, time_budget, + ranking_score_threshold, } = self; f.debug_struct("Search") .field("query", query) @@ -256,6 +269,7 @@ impl fmt::Debug for Search<'_> { &semantic.as_ref().map(|semantic| &semantic.embedder_name), ) .field("time_budget", time_budget) + .field("ranking_score_threshold", ranking_score_threshold) .finish() } } diff --git a/milli/src/search/new/bucket_sort.rs b/milli/src/search/new/bucket_sort.rs index 521fcb983..7fcfd10f6 100644 --- a/milli/src/search/new/bucket_sort.rs +++ b/milli/src/search/new/bucket_sort.rs @@ -28,6 +28,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( scoring_strategy: ScoringStrategy, logger: &mut dyn SearchLogger, time_budget: TimeBudget, + ranking_score_threshold: Option, ) -> Result { logger.initial_query(query); logger.ranking_rules(&ranking_rules); @@ -144,6 +145,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( ctx, from, length, + ranking_score_threshold, logger, &mut valid_docids, &mut valid_scores, @@ -164,7 +166,9 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( loop { let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]); ranking_rule_scores.push(ScoreDetails::Skipped); + maybe_add_to_results!(bucket); + ranking_rule_scores.pop(); if cur_ranking_rule_index == 0 { @@ -220,6 +224,17 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( debug_assert!( ranking_rule_universes[cur_ranking_rule_index].is_superset(&next_bucket.candidates) ); + + if let Some(ranking_score_threshold) = ranking_score_threshold { + let current_score = ScoreDetails::global_score(ranking_rule_scores.iter()); + if current_score < ranking_score_threshold { + all_candidates -= + next_bucket.candidates | &ranking_rule_universes[cur_ranking_rule_index]; + back!(); + continue; + } + } + ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates; if cur_ranking_rule_index == ranking_rules_len - 1 @@ -262,6 +277,7 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>( ctx: &mut SearchContext<'ctx>, from: usize, length: usize, + ranking_score_threshold: Option, logger: &mut dyn SearchLogger, valid_docids: &mut Vec, @@ -279,6 +295,15 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>( ranking_rule_scores: &[ScoreDetails], candidates: RoaringBitmap, ) -> Result<()> { + // remove candidates from the universe without adding them to result if their score is below the threshold + if let Some(ranking_score_threshold) = ranking_score_threshold { + let score = ScoreDetails::global_score(ranking_rule_scores.iter()); + if score < ranking_score_threshold { + *all_candidates -= candidates | &ranking_rule_universes[cur_ranking_rule_index]; + return Ok(()); + } + } + // First apply the distinct rule on the candidates, reducing the universes if necessary let candidates = if let Some(distinct_fid) = distinct_fid { let DistinctOutput { remaining, excluded } = diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index 2913f206d..5e65a10aa 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -523,6 +523,7 @@ mod tests { &mut crate::DefaultSearchLogger, &mut crate::DefaultSearchLogger, TimeBudget::max(), + None, ) .unwrap(); diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 617068ef8..09ba66f8c 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -551,6 +551,7 @@ pub fn execute_vector_search( embedder_name: &str, embedder: &Embedder, time_budget: TimeBudget, + ranking_score_threshold: Option, ) -> Result { check_sort_criteria(ctx, sort_criteria.as_ref())?; @@ -580,6 +581,7 @@ pub fn execute_vector_search( scoring_strategy, placeholder_search_logger, time_budget, + ranking_score_threshold, )?; Ok(PartialSearchResult { @@ -609,6 +611,7 @@ pub fn execute_search( placeholder_search_logger: &mut dyn SearchLogger, query_graph_logger: &mut dyn SearchLogger, time_budget: TimeBudget, + ranking_score_threshold: Option, ) -> Result { check_sort_criteria(ctx, sort_criteria.as_ref())?; @@ -697,6 +700,7 @@ pub fn execute_search( scoring_strategy, query_graph_logger, time_budget, + ranking_score_threshold, )? } else { let ranking_rules = @@ -711,6 +715,7 @@ pub fn execute_search( scoring_strategy, placeholder_search_logger, time_budget, + ranking_score_threshold, )? };