Add ranking_score_threshold to milli

This commit is contained in:
Louis Dureuil 2024-04-11 19:04:06 +02:00
parent 75d5c0ae1f
commit aac1d769a7
No known key found for this signature in database
5 changed files with 46 additions and 0 deletions

View File

@ -169,6 +169,7 @@ impl<'a> Search<'a> {
index: self.index, index: self.index,
semantic: self.semantic.clone(), semantic: self.semantic.clone(),
time_budget: self.time_budget.clone(), time_budget: self.time_budget.clone(),
ranking_score_threshold: self.ranking_score_threshold,
}; };
let semantic = search.semantic.take(); let semantic = search.semantic.take();

View File

@ -50,6 +50,7 @@ pub struct Search<'a> {
index: &'a Index, index: &'a Index,
semantic: Option<SemanticSearch>, semantic: Option<SemanticSearch>,
time_budget: TimeBudget, time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
} }
impl<'a> Search<'a> { impl<'a> Search<'a> {
@ -70,6 +71,7 @@ impl<'a> Search<'a> {
index, index,
semantic: None, semantic: None,
time_budget: TimeBudget::max(), time_budget: TimeBudget::max(),
ranking_score_threshold: None,
} }
} }
@ -146,6 +148,14 @@ impl<'a> Search<'a> {
self self
} }
pub fn ranking_score_threshold(
&mut self,
ranking_score_threshold: Option<f64>,
) -> &mut Search<'a> {
self.ranking_score_threshold = ranking_score_threshold;
self
}
pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> { pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> {
if has_vector_search { if has_vector_search {
let ctx = SearchContext::new(self.index, self.rtxn)?; let ctx = SearchContext::new(self.index, self.rtxn)?;
@ -184,6 +194,7 @@ impl<'a> Search<'a> {
embedder_name, embedder_name,
embedder, embedder,
self.time_budget.clone(), self.time_budget.clone(),
self.ranking_score_threshold,
)? )?
} }
_ => execute_search( _ => execute_search(
@ -201,6 +212,7 @@ impl<'a> Search<'a> {
&mut DefaultSearchLogger, &mut DefaultSearchLogger,
&mut DefaultSearchLogger, &mut DefaultSearchLogger,
self.time_budget.clone(), self.time_budget.clone(),
self.ranking_score_threshold,
)?, )?,
}; };
@ -239,6 +251,7 @@ impl fmt::Debug for Search<'_> {
index: _, index: _,
semantic, semantic,
time_budget, time_budget,
ranking_score_threshold,
} = self; } = self;
f.debug_struct("Search") f.debug_struct("Search")
.field("query", query) .field("query", query)
@ -257,6 +270,7 @@ impl fmt::Debug for Search<'_> {
&semantic.as_ref().map(|semantic| &semantic.embedder_name), &semantic.as_ref().map(|semantic| &semantic.embedder_name),
) )
.field("time_budget", time_budget) .field("time_budget", time_budget)
.field("ranking_score_threshold", ranking_score_threshold)
.finish() .finish()
} }
} }

View File

@ -28,6 +28,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
scoring_strategy: ScoringStrategy, scoring_strategy: ScoringStrategy,
logger: &mut dyn SearchLogger<Q>, logger: &mut dyn SearchLogger<Q>,
time_budget: TimeBudget, time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
) -> Result<BucketSortOutput> { ) -> Result<BucketSortOutput> {
logger.initial_query(query); logger.initial_query(query);
logger.ranking_rules(&ranking_rules); logger.ranking_rules(&ranking_rules);
@ -144,6 +145,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
ctx, ctx,
from, from,
length, length,
ranking_score_threshold,
logger, logger,
&mut valid_docids, &mut valid_docids,
&mut valid_scores, &mut valid_scores,
@ -164,7 +166,9 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
loop { loop {
let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]); let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]);
ranking_rule_scores.push(ScoreDetails::Skipped); ranking_rule_scores.push(ScoreDetails::Skipped);
maybe_add_to_results!(bucket); maybe_add_to_results!(bucket);
ranking_rule_scores.pop(); ranking_rule_scores.pop();
if cur_ranking_rule_index == 0 { if cur_ranking_rule_index == 0 {
@ -220,6 +224,17 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
debug_assert!( debug_assert!(
ranking_rule_universes[cur_ranking_rule_index].is_superset(&next_bucket.candidates) ranking_rule_universes[cur_ranking_rule_index].is_superset(&next_bucket.candidates)
); );
if let Some(ranking_score_threshold) = ranking_score_threshold {
let current_score = ScoreDetails::global_score(ranking_rule_scores.iter());
if current_score < ranking_score_threshold {
all_candidates -=
next_bucket.candidates | &ranking_rule_universes[cur_ranking_rule_index];
back!();
continue;
}
}
ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates; ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates;
if cur_ranking_rule_index == ranking_rules_len - 1 if cur_ranking_rule_index == ranking_rules_len - 1
@ -262,6 +277,7 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>(
ctx: &mut SearchContext<'ctx>, ctx: &mut SearchContext<'ctx>,
from: usize, from: usize,
length: usize, length: usize,
ranking_score_threshold: Option<f64>,
logger: &mut dyn SearchLogger<Q>, logger: &mut dyn SearchLogger<Q>,
valid_docids: &mut Vec<u32>, valid_docids: &mut Vec<u32>,
@ -279,6 +295,15 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>(
ranking_rule_scores: &[ScoreDetails], ranking_rule_scores: &[ScoreDetails],
candidates: RoaringBitmap, candidates: RoaringBitmap,
) -> Result<()> { ) -> Result<()> {
// remove candidates from the universe without adding them to result if their score is below the threshold
if let Some(ranking_score_threshold) = ranking_score_threshold {
let score = ScoreDetails::global_score(ranking_rule_scores.iter());
if score < ranking_score_threshold {
*all_candidates -= candidates | &ranking_rule_universes[cur_ranking_rule_index];
return Ok(());
}
}
// First apply the distinct rule on the candidates, reducing the universes if necessary // First apply the distinct rule on the candidates, reducing the universes if necessary
let candidates = if let Some(distinct_fid) = distinct_fid { let candidates = if let Some(distinct_fid) = distinct_fid {
let DistinctOutput { remaining, excluded } = let DistinctOutput { remaining, excluded } =

View File

@ -523,6 +523,7 @@ mod tests {
&mut crate::DefaultSearchLogger, &mut crate::DefaultSearchLogger,
&mut crate::DefaultSearchLogger, &mut crate::DefaultSearchLogger,
TimeBudget::max(), TimeBudget::max(),
None,
) )
.unwrap(); .unwrap();

View File

@ -568,6 +568,7 @@ pub fn execute_vector_search(
embedder_name: &str, embedder_name: &str,
embedder: &Embedder, embedder: &Embedder,
time_budget: TimeBudget, time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
) -> Result<PartialSearchResult> { ) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?; check_sort_criteria(ctx, sort_criteria.as_ref())?;
@ -597,6 +598,7 @@ pub fn execute_vector_search(
scoring_strategy, scoring_strategy,
placeholder_search_logger, placeholder_search_logger,
time_budget, time_budget,
ranking_score_threshold,
)?; )?;
Ok(PartialSearchResult { Ok(PartialSearchResult {
@ -626,6 +628,7 @@ pub fn execute_search(
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>, placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
query_graph_logger: &mut dyn SearchLogger<QueryGraph>, query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
time_budget: TimeBudget, time_budget: TimeBudget,
ranking_score_threshold: Option<f64>,
) -> Result<PartialSearchResult> { ) -> Result<PartialSearchResult> {
check_sort_criteria(ctx, sort_criteria.as_ref())?; check_sort_criteria(ctx, sort_criteria.as_ref())?;
@ -714,6 +717,7 @@ pub fn execute_search(
scoring_strategy, scoring_strategy,
query_graph_logger, query_graph_logger,
time_budget, time_budget,
ranking_score_threshold,
)? )?
} else { } else {
let ranking_rules = let ranking_rules =
@ -728,6 +732,7 @@ pub fn execute_search(
scoring_strategy, scoring_strategy,
placeholder_search_logger, placeholder_search_logger,
time_budget, time_budget,
ranking_score_threshold,
)? )?
}; };