diff --git a/milli/src/search/new/graph_based_ranking_rule.rs b/milli/src/search/new/graph_based_ranking_rule.rs index a25d9c155..fa3c0b3d0 100644 --- a/milli/src/search/new/graph_based_ranking_rule.rs +++ b/milli/src/search/new/graph_based_ranking_rule.rs @@ -50,6 +50,7 @@ use super::ranking_rule_graph::{ }; use super::small_bitmap::SmallBitmap; use super::{QueryGraph, RankingRule, RankingRuleOutput, SearchContext}; +use crate::score_details::Rank; use crate::search::new::query_term::LocatedQueryTermSubset; use crate::search::new::ranking_rule_graph::PathVisitor; use crate::{Result, TermsMatchingStrategy}; @@ -118,6 +119,8 @@ pub struct GraphBasedRankingRuleState { all_costs: MappedInterner>, /// An index in the first element of `all_distances`, giving the cost of the next bucket cur_cost: u64, + /// One above the highest possible cost for this rule + next_max_cost: u64, } impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBasedRankingRule { @@ -131,7 +134,20 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase _universe: &RoaringBitmap, query_graph: &QueryGraph, ) -> Result<()> { + // the `next_max_cost` is the successor integer to the maximum cost of the paths in the graph. + // + // When there is a matching strategy, it also factors the additional costs of: + // 1. The words that are matched in phrases + // 2. Skipping words (by adding them to the paths with a cost) + let mut next_max_cost = 1; let removal_cost = if let Some(terms_matching_strategy) = self.terms_matching_strategy { + // add the cost of the phrase to the next_max_cost + next_max_cost += query_graph + .words_in_phrases_count(ctx) + // remove 1 from the words in phrases count, because when there is a phrase we can now have a document + // where only the phrase is matching, and none of the non-phrase words. + // With the `1` that `next_max_cost` is initialized with, this gets counted twice. + .saturating_sub(1) as u64; match terms_matching_strategy { TermsMatchingStrategy::Last => { let removal_order = @@ -161,12 +177,16 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase // Then pre-compute the cost of all paths from each node to the end node let all_costs = graph.find_all_costs_to_end(); + next_max_cost += + all_costs.get(graph.query_graph.root_node).iter().copied().max().unwrap_or(0); + let state = GraphBasedRankingRuleState { graph, conditions_cache: condition_docids_cache, dead_ends_cache, all_costs, cur_cost: 0, + next_max_cost, }; self.state = Some(state); @@ -180,17 +200,13 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase logger: &mut dyn SearchLogger, universe: &RoaringBitmap, ) -> Result>> { - // If universe.len() <= 1, the bucket sort algorithm - // should not have called this function. - assert!(universe.len() > 1); // Will crash if `next_bucket` is called before `start_iteration` or after `end_iteration`, // should never happen let mut state = self.state.take().unwrap(); + let all_costs = state.all_costs.get(state.graph.query_graph.root_node); // Retrieve the cost of the paths to compute - let Some(&cost) = state - .all_costs - .get(state.graph.query_graph.root_node) + let Some(&cost) = all_costs .iter() .find(|c| **c >= state.cur_cost) else { self.state = None; @@ -206,8 +222,12 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase dead_ends_cache, all_costs, cur_cost: _, + next_max_cost, } = &mut state; + let rank = *next_max_cost - cost; + let score = G::rank_to_score(Rank { rank: rank as u32, max_rank: *next_max_cost as u32 }); + let mut universe = universe.clone(); let mut used_conditions = SmallBitmap::for_interned_values_in(&graph.conditions_interner); @@ -322,7 +342,7 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase self.state = Some(state); - Ok(Some(RankingRuleOutput { query: next_query_graph, candidates: bucket })) + Ok(Some(RankingRuleOutput { query: next_query_graph, candidates: bucket, score })) } fn end_iteration( diff --git a/milli/src/search/new/query_graph.rs b/milli/src/search/new/query_graph.rs index 114eb8c4e..f1f02b69c 100644 --- a/milli/src/search/new/query_graph.rs +++ b/milli/src/search/new/query_graph.rs @@ -342,6 +342,25 @@ impl QueryGraph { } res } + + /// Number of words in the phrases in this query graph + pub(crate) fn words_in_phrases_count(&self, ctx: &SearchContext) -> usize { + let mut word_count = 0; + for (_, node) in self.nodes.iter() { + match &node.data { + QueryNodeData::Term(term) => { + let Some(phrase) = term.term_subset.original_phrase(ctx) + else { + continue + }; + let phrase = ctx.phrase_interner.get(phrase); + word_count += phrase.words.iter().copied().filter(|a| a.is_some()).count() + } + _ => continue, + } + } + word_count + } } fn add_node(nodes_data: &mut Vec, node_data: QueryNodeData) -> u16 {