mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-30 00:34:26 +01:00
Add ranking_score_threshold to milli
This commit is contained in:
parent
75d5c0ae1f
commit
aac1d769a7
@ -169,6 +169,7 @@ impl<'a> Search<'a> {
|
|||||||
index: self.index,
|
index: self.index,
|
||||||
semantic: self.semantic.clone(),
|
semantic: self.semantic.clone(),
|
||||||
time_budget: self.time_budget.clone(),
|
time_budget: self.time_budget.clone(),
|
||||||
|
ranking_score_threshold: self.ranking_score_threshold,
|
||||||
};
|
};
|
||||||
|
|
||||||
let semantic = search.semantic.take();
|
let semantic = search.semantic.take();
|
||||||
|
@ -50,6 +50,7 @@ pub struct Search<'a> {
|
|||||||
index: &'a Index,
|
index: &'a Index,
|
||||||
semantic: Option<SemanticSearch>,
|
semantic: Option<SemanticSearch>,
|
||||||
time_budget: TimeBudget,
|
time_budget: TimeBudget,
|
||||||
|
ranking_score_threshold: Option<f64>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> Search<'a> {
|
impl<'a> Search<'a> {
|
||||||
@ -70,6 +71,7 @@ impl<'a> Search<'a> {
|
|||||||
index,
|
index,
|
||||||
semantic: None,
|
semantic: None,
|
||||||
time_budget: TimeBudget::max(),
|
time_budget: TimeBudget::max(),
|
||||||
|
ranking_score_threshold: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -146,6 +148,14 @@ impl<'a> Search<'a> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn ranking_score_threshold(
|
||||||
|
&mut self,
|
||||||
|
ranking_score_threshold: Option<f64>,
|
||||||
|
) -> &mut Search<'a> {
|
||||||
|
self.ranking_score_threshold = ranking_score_threshold;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> {
|
pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> {
|
||||||
if has_vector_search {
|
if has_vector_search {
|
||||||
let ctx = SearchContext::new(self.index, self.rtxn)?;
|
let ctx = SearchContext::new(self.index, self.rtxn)?;
|
||||||
@ -184,6 +194,7 @@ impl<'a> Search<'a> {
|
|||||||
embedder_name,
|
embedder_name,
|
||||||
embedder,
|
embedder,
|
||||||
self.time_budget.clone(),
|
self.time_budget.clone(),
|
||||||
|
self.ranking_score_threshold,
|
||||||
)?
|
)?
|
||||||
}
|
}
|
||||||
_ => execute_search(
|
_ => execute_search(
|
||||||
@ -201,6 +212,7 @@ impl<'a> Search<'a> {
|
|||||||
&mut DefaultSearchLogger,
|
&mut DefaultSearchLogger,
|
||||||
&mut DefaultSearchLogger,
|
&mut DefaultSearchLogger,
|
||||||
self.time_budget.clone(),
|
self.time_budget.clone(),
|
||||||
|
self.ranking_score_threshold,
|
||||||
)?,
|
)?,
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -239,6 +251,7 @@ impl fmt::Debug for Search<'_> {
|
|||||||
index: _,
|
index: _,
|
||||||
semantic,
|
semantic,
|
||||||
time_budget,
|
time_budget,
|
||||||
|
ranking_score_threshold,
|
||||||
} = self;
|
} = self;
|
||||||
f.debug_struct("Search")
|
f.debug_struct("Search")
|
||||||
.field("query", query)
|
.field("query", query)
|
||||||
@ -257,6 +270,7 @@ impl fmt::Debug for Search<'_> {
|
|||||||
&semantic.as_ref().map(|semantic| &semantic.embedder_name),
|
&semantic.as_ref().map(|semantic| &semantic.embedder_name),
|
||||||
)
|
)
|
||||||
.field("time_budget", time_budget)
|
.field("time_budget", time_budget)
|
||||||
|
.field("ranking_score_threshold", ranking_score_threshold)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -28,6 +28,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|
|||||||
scoring_strategy: ScoringStrategy,
|
scoring_strategy: ScoringStrategy,
|
||||||
logger: &mut dyn SearchLogger<Q>,
|
logger: &mut dyn SearchLogger<Q>,
|
||||||
time_budget: TimeBudget,
|
time_budget: TimeBudget,
|
||||||
|
ranking_score_threshold: Option<f64>,
|
||||||
) -> Result<BucketSortOutput> {
|
) -> Result<BucketSortOutput> {
|
||||||
logger.initial_query(query);
|
logger.initial_query(query);
|
||||||
logger.ranking_rules(&ranking_rules);
|
logger.ranking_rules(&ranking_rules);
|
||||||
@ -144,6 +145,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|
|||||||
ctx,
|
ctx,
|
||||||
from,
|
from,
|
||||||
length,
|
length,
|
||||||
|
ranking_score_threshold,
|
||||||
logger,
|
logger,
|
||||||
&mut valid_docids,
|
&mut valid_docids,
|
||||||
&mut valid_scores,
|
&mut valid_scores,
|
||||||
@ -164,7 +166,9 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|
|||||||
loop {
|
loop {
|
||||||
let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]);
|
let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]);
|
||||||
ranking_rule_scores.push(ScoreDetails::Skipped);
|
ranking_rule_scores.push(ScoreDetails::Skipped);
|
||||||
|
|
||||||
maybe_add_to_results!(bucket);
|
maybe_add_to_results!(bucket);
|
||||||
|
|
||||||
ranking_rule_scores.pop();
|
ranking_rule_scores.pop();
|
||||||
|
|
||||||
if cur_ranking_rule_index == 0 {
|
if cur_ranking_rule_index == 0 {
|
||||||
@ -220,6 +224,17 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|
|||||||
debug_assert!(
|
debug_assert!(
|
||||||
ranking_rule_universes[cur_ranking_rule_index].is_superset(&next_bucket.candidates)
|
ranking_rule_universes[cur_ranking_rule_index].is_superset(&next_bucket.candidates)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
if let Some(ranking_score_threshold) = ranking_score_threshold {
|
||||||
|
let current_score = ScoreDetails::global_score(ranking_rule_scores.iter());
|
||||||
|
if current_score < ranking_score_threshold {
|
||||||
|
all_candidates -=
|
||||||
|
next_bucket.candidates | &ranking_rule_universes[cur_ranking_rule_index];
|
||||||
|
back!();
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates;
|
ranking_rule_universes[cur_ranking_rule_index] -= &next_bucket.candidates;
|
||||||
|
|
||||||
if cur_ranking_rule_index == ranking_rules_len - 1
|
if cur_ranking_rule_index == ranking_rules_len - 1
|
||||||
@ -262,6 +277,7 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>(
|
|||||||
ctx: &mut SearchContext<'ctx>,
|
ctx: &mut SearchContext<'ctx>,
|
||||||
from: usize,
|
from: usize,
|
||||||
length: usize,
|
length: usize,
|
||||||
|
ranking_score_threshold: Option<f64>,
|
||||||
logger: &mut dyn SearchLogger<Q>,
|
logger: &mut dyn SearchLogger<Q>,
|
||||||
|
|
||||||
valid_docids: &mut Vec<u32>,
|
valid_docids: &mut Vec<u32>,
|
||||||
@ -279,6 +295,15 @@ fn maybe_add_to_results<'ctx, Q: RankingRuleQueryTrait>(
|
|||||||
ranking_rule_scores: &[ScoreDetails],
|
ranking_rule_scores: &[ScoreDetails],
|
||||||
candidates: RoaringBitmap,
|
candidates: RoaringBitmap,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
|
// remove candidates from the universe without adding them to result if their score is below the threshold
|
||||||
|
if let Some(ranking_score_threshold) = ranking_score_threshold {
|
||||||
|
let score = ScoreDetails::global_score(ranking_rule_scores.iter());
|
||||||
|
if score < ranking_score_threshold {
|
||||||
|
*all_candidates -= candidates | &ranking_rule_universes[cur_ranking_rule_index];
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// First apply the distinct rule on the candidates, reducing the universes if necessary
|
// First apply the distinct rule on the candidates, reducing the universes if necessary
|
||||||
let candidates = if let Some(distinct_fid) = distinct_fid {
|
let candidates = if let Some(distinct_fid) = distinct_fid {
|
||||||
let DistinctOutput { remaining, excluded } =
|
let DistinctOutput { remaining, excluded } =
|
||||||
|
@ -523,6 +523,7 @@ mod tests {
|
|||||||
&mut crate::DefaultSearchLogger,
|
&mut crate::DefaultSearchLogger,
|
||||||
&mut crate::DefaultSearchLogger,
|
&mut crate::DefaultSearchLogger,
|
||||||
TimeBudget::max(),
|
TimeBudget::max(),
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -568,6 +568,7 @@ pub fn execute_vector_search(
|
|||||||
embedder_name: &str,
|
embedder_name: &str,
|
||||||
embedder: &Embedder,
|
embedder: &Embedder,
|
||||||
time_budget: TimeBudget,
|
time_budget: TimeBudget,
|
||||||
|
ranking_score_threshold: Option<f64>,
|
||||||
) -> Result<PartialSearchResult> {
|
) -> Result<PartialSearchResult> {
|
||||||
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
||||||
|
|
||||||
@ -597,6 +598,7 @@ pub fn execute_vector_search(
|
|||||||
scoring_strategy,
|
scoring_strategy,
|
||||||
placeholder_search_logger,
|
placeholder_search_logger,
|
||||||
time_budget,
|
time_budget,
|
||||||
|
ranking_score_threshold,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(PartialSearchResult {
|
Ok(PartialSearchResult {
|
||||||
@ -626,6 +628,7 @@ pub fn execute_search(
|
|||||||
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
|
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
|
||||||
query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
|
query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
|
||||||
time_budget: TimeBudget,
|
time_budget: TimeBudget,
|
||||||
|
ranking_score_threshold: Option<f64>,
|
||||||
) -> Result<PartialSearchResult> {
|
) -> Result<PartialSearchResult> {
|
||||||
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
||||||
|
|
||||||
@ -714,6 +717,7 @@ pub fn execute_search(
|
|||||||
scoring_strategy,
|
scoring_strategy,
|
||||||
query_graph_logger,
|
query_graph_logger,
|
||||||
time_budget,
|
time_budget,
|
||||||
|
ranking_score_threshold,
|
||||||
)?
|
)?
|
||||||
} else {
|
} else {
|
||||||
let ranking_rules =
|
let ranking_rules =
|
||||||
@ -728,6 +732,7 @@ pub fn execute_search(
|
|||||||
scoring_strategy,
|
scoring_strategy,
|
||||||
placeholder_search_logger,
|
placeholder_search_logger,
|
||||||
time_budget,
|
time_budget,
|
||||||
|
ranking_score_threshold,
|
||||||
)?
|
)?
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user