From 30fb1153cc6b6a66487267687edbce41fc64e856 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Lecrenier?= Date: Mon, 1 May 2023 15:33:28 +0200 Subject: [PATCH] Speed up graph based ranking rule when a lot of different costs exist --- .../search/new/graph_based_ranking_rule.rs | 15 ++--- milli/src/search/new/query_graph.rs | 3 - .../new/ranking_rule_graph/cheapest_paths.rs | 63 ++++++++----------- 3 files changed, 35 insertions(+), 46 deletions(-) diff --git a/milli/src/search/new/graph_based_ranking_rule.rs b/milli/src/search/new/graph_based_ranking_rule.rs index f5918517b..379a0b2ab 100644 --- a/milli/src/search/new/graph_based_ranking_rule.rs +++ b/milli/src/search/new/graph_based_ranking_rule.rs @@ -309,11 +309,6 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase Ok(ControlFlow::Continue(())) } })?; - // if at_least_one { - // unsafe { - // println!("\n===== {id} COST: {cost} ==== PATHS: {COUNT_PATHS} ==== NODES: {COUNT_VISITED_NODES} ===== UNIVERSE: {universe}", id=self.id, universe=universe.len()); - // } - // } logger.log_internal_state(graph); logger.log_internal_state(&good_paths); @@ -337,8 +332,14 @@ impl<'ctx, G: RankingRuleGraphTrait> RankingRule<'ctx, QueryGraph> for GraphBase let next_query_graph = QueryGraph::build_from_paths(paths); - if !nodes_with_removed_outgoing_conditions.is_empty() { - graph.update_all_costs_before_nodes(&nodes_with_removed_outgoing_conditions, all_costs); + #[allow(clippy::comparison_chain)] + if nodes_with_removed_outgoing_conditions.len() == 1 { + graph.update_all_costs_before_node( + *nodes_with_removed_outgoing_conditions.first().unwrap(), + all_costs, + ); + } else if nodes_with_removed_outgoing_conditions.len() > 1 { + *all_costs = graph.find_all_costs_to_end(); } self.state = Some(state); diff --git a/milli/src/search/new/query_graph.rs b/milli/src/search/new/query_graph.rs index faa52d0b9..0c3191390 100644 --- a/milli/src/search/new/query_graph.rs +++ b/milli/src/search/new/query_graph.rs @@ -8,7 +8,6 @@ use crate::search::new::interner::Interner; use crate::Result; use fxhash::{FxHashMap, FxHasher}; use std::cmp::Ordering; -use std::collections::hash_map::Entry; use std::collections::BTreeMap; use std::hash::{Hash, Hasher}; @@ -364,8 +363,6 @@ impl QueryGraph { └──│ b2 │──│ c2 │───│ d │───│ e2 │ └────┘ └────┘ └────┘ └────┘ ``` - But we accept the first representation as it reduces the size - of the graph and shouldn't cause much problems. */ pub fn build_from_paths( paths: Vec, LocatedQueryTermSubset)>>, diff --git a/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs b/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs index 4a104df69..c065cc706 100644 --- a/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs +++ b/milli/src/search/new/ranking_rule_graph/cheapest_paths.rs @@ -1,8 +1,11 @@ #![allow(clippy::too_many_arguments)] use std::collections::{BTreeSet, VecDeque}; +use std::iter::FromIterator; use std::ops::ControlFlow; +use fxhash::FxHashSet; + use super::{DeadEndsCache, RankingRuleGraph, RankingRuleGraphTrait}; use crate::search::new::interner::{Interned, MappedInterner}; use crate::search::new::query_graph::QueryNode; @@ -112,9 +115,6 @@ impl VisitorState { } } } - // if there wasn't any valid path from this node to the end node, then - // this node is a dead end **for this specific cost**. - // we could encode this in the dead-ends cache Ok(ControlFlow::Continue(any_valid)) } @@ -126,11 +126,11 @@ impl VisitorState { visit: VisitFn, ctx: &mut VisitorContext, ) -> Result> { - if ctx + if !ctx .all_costs_from_node .get(dest_node) .iter() - .all(|next_cost| *next_cost != self.remaining_cost) + .any(|next_cost| *next_cost == self.remaining_cost) { return Ok(ControlFlow::Continue(false)); } @@ -158,14 +158,12 @@ impl VisitorState { ) -> Result> { assert!(dest_node != ctx.graph.query_graph.end_node); - if self.forbidden_conditions_to_nodes.contains(dest_node) + if self.forbidden_conditions.contains(condition) + || self.forbidden_conditions_to_nodes.contains(dest_node) || edge_new_nodes_to_skip.intersects(&self.visited_nodes) { return Ok(ControlFlow::Continue(false)); } - if self.forbidden_conditions.contains(condition) { - return Ok(ControlFlow::Continue(false)); - } // Checking that from the destination node, there is at least // one cost that we can visit that corresponds to our remaining budget. @@ -244,48 +242,41 @@ impl RankingRuleGraph { costs_to_end } - pub fn update_all_costs_before_nodes( + pub fn update_all_costs_before_node( &self, - removed_nodes: &BTreeSet>, + node_with_removed_outgoing_conditions: Interned, costs: &mut MappedInterner>, ) { - // unsafe { - // FIND_ALL_COSTS_INC_COUNT += 1; - // println!( - // "update_all_costs_after_removing_edge incrementally count: {}", - // FIND_ALL_COSTS_INC_COUNT - // ); - // } - let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len()); let mut node_stack = VecDeque::new(); - for node in removed_nodes.iter() { - enqueued.insert(*node); - node_stack.push_back(*node); - } + enqueued.insert(node_with_removed_outgoing_conditions); + node_stack.push_back(node_with_removed_outgoing_conditions); - while let Some(cur_node) = node_stack.pop_front() { - let mut self_costs = BTreeSet::::new(); + 'main_loop: while let Some(cur_node) = node_stack.pop_front() { + let mut costs_to_remove = FxHashSet::default(); + for c in costs.get(cur_node) { + costs_to_remove.insert(*c); + } let cur_node_edges = &self.edges_of_node.get(cur_node); for edge_idx in cur_node_edges.iter() { let edge = self.edges_store.get(edge_idx).as_ref().unwrap(); - let succ_node = edge.dest_node; - let succ_costs = costs.get(succ_node); - for succ_distance in succ_costs { - self_costs.insert(edge.cost as u64 + succ_distance); + for cost in costs.get(edge.dest_node).iter() { + costs_to_remove.remove(&(*cost + edge.cost as u64)); + if costs_to_remove.is_empty() { + continue 'main_loop; + } } } - let costs_to_end_cur_node = costs.get_mut(cur_node); - for cost in self_costs.iter() { - costs_to_end_cur_node.push(*cost); + if costs_to_remove.is_empty() { + continue 'main_loop; } - let self_costs = self_costs.into_iter().collect::>(); - if &self_costs == costs.get(cur_node) { - continue; + let mut new_costs = BTreeSet::from_iter(costs.get(cur_node).iter().copied()); + for c in costs_to_remove { + new_costs.remove(&c); } - *costs.get_mut(cur_node) = self_costs; + *costs.get_mut(cur_node) = new_costs.into_iter().collect(); for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() { if !enqueued.contains(prev_node) {