mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-30 00:34:26 +01:00
Refactor the paths_of_cost algorithm
Support conditions that require certain nodes to be skipped
This commit is contained in:
parent
01e24dd630
commit
aa9592455c
@ -9,141 +9,202 @@ use crate::search::new::query_graph::QueryNode;
|
||||
use crate::search::new::small_bitmap::SmallBitmap;
|
||||
use crate::Result;
|
||||
|
||||
impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
|
||||
pub fn visit_paths_of_cost(
|
||||
&mut self,
|
||||
from: Interned<QueryNode>,
|
||||
cost: u16,
|
||||
all_distances: &MappedInterner<QueryNode, Vec<u16>>,
|
||||
dead_ends_cache: &mut DeadEndsCache<G::Condition>,
|
||||
mut visit: impl FnMut(
|
||||
&[Interned<G::Condition>],
|
||||
&mut Self,
|
||||
&mut DeadEndsCache<G::Condition>,
|
||||
) -> Result<ControlFlow<()>>,
|
||||
) -> Result<()> {
|
||||
let _ = self.visit_paths_of_cost_rec(
|
||||
from,
|
||||
cost,
|
||||
all_distances,
|
||||
dead_ends_cache,
|
||||
&mut visit,
|
||||
&mut vec![],
|
||||
&mut SmallBitmap::for_interned_values_in(&self.conditions_interner),
|
||||
dead_ends_cache.forbidden.clone(),
|
||||
)?;
|
||||
type VisitFn<'f, G> = &'f mut dyn FnMut(
|
||||
&[Interned<<G as RankingRuleGraphTrait>::Condition>],
|
||||
&mut RankingRuleGraph<G>,
|
||||
&mut DeadEndsCache<<G as RankingRuleGraphTrait>::Condition>,
|
||||
) -> Result<ControlFlow<()>>;
|
||||
|
||||
struct VisitorContext<'a, G: RankingRuleGraphTrait> {
|
||||
graph: &'a mut RankingRuleGraph<G>,
|
||||
all_costs_from_node: &'a MappedInterner<QueryNode, Vec<u64>>,
|
||||
dead_ends_cache: &'a mut DeadEndsCache<G::Condition>,
|
||||
}
|
||||
|
||||
struct VisitorState<G: RankingRuleGraphTrait> {
|
||||
remaining_cost: u64,
|
||||
|
||||
path: Vec<Interned<G::Condition>>,
|
||||
|
||||
visited_conditions: SmallBitmap<G::Condition>,
|
||||
visited_nodes: SmallBitmap<QueryNode>,
|
||||
|
||||
forbidden_conditions: SmallBitmap<G::Condition>,
|
||||
forbidden_conditions_to_nodes: SmallBitmap<QueryNode>,
|
||||
}
|
||||
|
||||
pub struct PathVisitor<'a, G: RankingRuleGraphTrait> {
|
||||
state: VisitorState<G>,
|
||||
ctx: VisitorContext<'a, G>,
|
||||
}
|
||||
impl<'a, G: RankingRuleGraphTrait> PathVisitor<'a, G> {
|
||||
pub fn new(
|
||||
cost: u64,
|
||||
graph: &'a mut RankingRuleGraph<G>,
|
||||
all_costs_from_node: &'a MappedInterner<QueryNode, Vec<u64>>,
|
||||
dead_ends_cache: &'a mut DeadEndsCache<G::Condition>,
|
||||
) -> Self {
|
||||
Self {
|
||||
state: VisitorState {
|
||||
remaining_cost: cost,
|
||||
path: vec![],
|
||||
visited_conditions: SmallBitmap::for_interned_values_in(&graph.conditions_interner),
|
||||
visited_nodes: SmallBitmap::for_interned_values_in(&graph.query_graph.nodes),
|
||||
forbidden_conditions: SmallBitmap::for_interned_values_in(
|
||||
&graph.conditions_interner,
|
||||
),
|
||||
forbidden_conditions_to_nodes: SmallBitmap::for_interned_values_in(
|
||||
&graph.query_graph.nodes,
|
||||
),
|
||||
},
|
||||
ctx: VisitorContext { graph, all_costs_from_node, dead_ends_cache },
|
||||
}
|
||||
}
|
||||
|
||||
pub fn visit_paths(mut self, visit: VisitFn<G>) -> Result<()> {
|
||||
let _ =
|
||||
self.state.visit_node(self.ctx.graph.query_graph.root_node, visit, &mut self.ctx)?;
|
||||
Ok(())
|
||||
}
|
||||
pub fn visit_paths_of_cost_rec(
|
||||
}
|
||||
|
||||
impl<G: RankingRuleGraphTrait> VisitorState<G> {
|
||||
fn visit_node(
|
||||
&mut self,
|
||||
from: Interned<QueryNode>,
|
||||
cost: u16,
|
||||
all_distances: &MappedInterner<QueryNode, Vec<u16>>,
|
||||
dead_ends_cache: &mut DeadEndsCache<G::Condition>,
|
||||
visit: &mut impl FnMut(
|
||||
&[Interned<G::Condition>],
|
||||
&mut Self,
|
||||
&mut DeadEndsCache<G::Condition>,
|
||||
) -> Result<ControlFlow<()>>,
|
||||
prev_conditions: &mut Vec<Interned<G::Condition>>,
|
||||
cur_path: &mut SmallBitmap<G::Condition>,
|
||||
mut forbidden_conditions: SmallBitmap<G::Condition>,
|
||||
) -> Result<bool> {
|
||||
from_node: Interned<QueryNode>,
|
||||
visit: VisitFn<G>,
|
||||
ctx: &mut VisitorContext<G>,
|
||||
) -> Result<ControlFlow<(), bool>> {
|
||||
let mut any_valid = false;
|
||||
|
||||
let edges = self.edges_of_node.get(from).clone();
|
||||
'edges_loop: for edge_idx in edges.iter() {
|
||||
let Some(edge) = self.edges_store.get(edge_idx).as_ref() else { continue };
|
||||
if cost < edge.cost as u16 {
|
||||
let edges = ctx.graph.edges_of_node.get(from_node).clone();
|
||||
for edge_idx in edges.iter() {
|
||||
let Some(edge) = ctx.graph.edges_store.get(edge_idx).clone() else { continue };
|
||||
|
||||
if self.remaining_cost < edge.cost as u64 {
|
||||
continue;
|
||||
}
|
||||
let next_any_valid = match edge.condition {
|
||||
None => {
|
||||
if edge.dest_node == self.query_graph.end_node {
|
||||
any_valid = true;
|
||||
let control_flow = visit(prev_conditions, self, dead_ends_cache)?;
|
||||
match control_flow {
|
||||
ControlFlow::Continue(_) => {}
|
||||
ControlFlow::Break(_) => return Ok(true),
|
||||
}
|
||||
true
|
||||
} else {
|
||||
self.visit_paths_of_cost_rec(
|
||||
edge.dest_node,
|
||||
cost - edge.cost as u16,
|
||||
all_distances,
|
||||
dead_ends_cache,
|
||||
visit,
|
||||
prev_conditions,
|
||||
cur_path,
|
||||
forbidden_conditions.clone(),
|
||||
)?
|
||||
}
|
||||
}
|
||||
Some(condition) => {
|
||||
if forbidden_conditions.contains(condition)
|
||||
|| all_distances
|
||||
.get(edge.dest_node)
|
||||
.iter()
|
||||
.all(|next_cost| *next_cost != cost - edge.cost as u16)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
cur_path.insert(condition);
|
||||
prev_conditions.push(condition);
|
||||
let mut new_forbidden_conditions = forbidden_conditions.clone();
|
||||
if let Some(next_forbidden) =
|
||||
dead_ends_cache.forbidden_conditions_after_prefix(prev_conditions)
|
||||
{
|
||||
new_forbidden_conditions.union(&next_forbidden);
|
||||
}
|
||||
|
||||
let next_any_valid = if edge.dest_node == self.query_graph.end_node {
|
||||
any_valid = true;
|
||||
let control_flow = visit(prev_conditions, self, dead_ends_cache)?;
|
||||
match control_flow {
|
||||
ControlFlow::Continue(_) => {}
|
||||
ControlFlow::Break(_) => return Ok(true),
|
||||
}
|
||||
true
|
||||
} else {
|
||||
self.visit_paths_of_cost_rec(
|
||||
edge.dest_node,
|
||||
cost - edge.cost as u16,
|
||||
all_distances,
|
||||
dead_ends_cache,
|
||||
visit,
|
||||
prev_conditions,
|
||||
cur_path,
|
||||
new_forbidden_conditions,
|
||||
)?
|
||||
};
|
||||
cur_path.remove(condition);
|
||||
prev_conditions.pop();
|
||||
next_any_valid
|
||||
}
|
||||
self.remaining_cost -= edge.cost as u64;
|
||||
let cf = match edge.condition {
|
||||
Some(condition) => self.visit_condition(
|
||||
condition,
|
||||
edge.dest_node,
|
||||
&edge.nodes_to_skip,
|
||||
visit,
|
||||
ctx,
|
||||
)?,
|
||||
None => self.visit_no_condition(edge.dest_node, &edge.nodes_to_skip, visit, ctx)?,
|
||||
};
|
||||
any_valid |= next_any_valid;
|
||||
self.remaining_cost += edge.cost as u64;
|
||||
|
||||
let ControlFlow::Continue(next_any_valid) = cf else {
|
||||
return Ok(ControlFlow::Break(()));
|
||||
};
|
||||
if next_any_valid {
|
||||
forbidden_conditions =
|
||||
dead_ends_cache.forbidden_conditions_for_all_prefixes_up_to(prev_conditions);
|
||||
if cur_path.intersects(&forbidden_conditions) {
|
||||
break 'edges_loop;
|
||||
self.forbidden_conditions = ctx
|
||||
.dead_ends_cache
|
||||
.forbidden_conditions_for_all_prefixes_up_to(self.path.iter().copied());
|
||||
if self.visited_conditions.intersects(&self.forbidden_conditions) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
any_valid |= next_any_valid;
|
||||
}
|
||||
|
||||
Ok(any_valid)
|
||||
Ok(ControlFlow::Continue(any_valid))
|
||||
}
|
||||
|
||||
pub fn initialize_distances_with_necessary_edges(&self) -> MappedInterner<QueryNode, Vec<u16>> {
|
||||
let mut distances_to_end = self.query_graph.nodes.map(|_| vec![]);
|
||||
fn visit_no_condition(
|
||||
&mut self,
|
||||
dest_node: Interned<QueryNode>,
|
||||
edge_forbidden_nodes: &SmallBitmap<QueryNode>,
|
||||
visit: VisitFn<G>,
|
||||
ctx: &mut VisitorContext<G>,
|
||||
) -> Result<ControlFlow<(), bool>> {
|
||||
if ctx
|
||||
.all_costs_from_node
|
||||
.get(dest_node)
|
||||
.iter()
|
||||
.all(|next_cost| *next_cost != self.remaining_cost)
|
||||
{
|
||||
return Ok(ControlFlow::Continue(false));
|
||||
}
|
||||
if dest_node == ctx.graph.query_graph.end_node {
|
||||
let control_flow = visit(&self.path, ctx.graph, ctx.dead_ends_cache)?;
|
||||
match control_flow {
|
||||
ControlFlow::Continue(_) => Ok(ControlFlow::Continue(true)),
|
||||
ControlFlow::Break(_) => Ok(ControlFlow::Break(())),
|
||||
}
|
||||
} else {
|
||||
let old_fbct = self.forbidden_conditions_to_nodes.clone();
|
||||
self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes);
|
||||
let cf = self.visit_node(dest_node, visit, ctx)?;
|
||||
self.forbidden_conditions_to_nodes = old_fbct;
|
||||
Ok(cf)
|
||||
}
|
||||
}
|
||||
fn visit_condition(
|
||||
&mut self,
|
||||
condition: Interned<G::Condition>,
|
||||
dest_node: Interned<QueryNode>,
|
||||
edge_forbidden_nodes: &SmallBitmap<QueryNode>,
|
||||
visit: VisitFn<G>,
|
||||
ctx: &mut VisitorContext<G>,
|
||||
) -> Result<ControlFlow<(), bool>> {
|
||||
assert!(dest_node != ctx.graph.query_graph.end_node);
|
||||
|
||||
if self.forbidden_conditions_to_nodes.contains(dest_node)
|
||||
|| edge_forbidden_nodes.intersects(&self.visited_nodes)
|
||||
{
|
||||
return Ok(ControlFlow::Continue(false));
|
||||
}
|
||||
if self.forbidden_conditions.contains(condition) {
|
||||
return Ok(ControlFlow::Continue(false));
|
||||
}
|
||||
|
||||
if ctx
|
||||
.all_costs_from_node
|
||||
.get(dest_node)
|
||||
.iter()
|
||||
.all(|next_cost| *next_cost != self.remaining_cost)
|
||||
{
|
||||
return Ok(ControlFlow::Continue(false));
|
||||
}
|
||||
|
||||
self.path.push(condition);
|
||||
self.visited_nodes.insert(dest_node);
|
||||
self.visited_conditions.insert(condition);
|
||||
|
||||
let old_fc = self.forbidden_conditions.clone();
|
||||
if let Some(next_forbidden) =
|
||||
ctx.dead_ends_cache.forbidden_conditions_after_prefix(self.path.iter().copied())
|
||||
{
|
||||
self.forbidden_conditions.union(&next_forbidden);
|
||||
}
|
||||
let old_fctn = self.forbidden_conditions_to_nodes.clone();
|
||||
self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes);
|
||||
|
||||
let cf = self.visit_node(dest_node, visit, ctx)?;
|
||||
|
||||
self.forbidden_conditions_to_nodes = old_fctn;
|
||||
self.forbidden_conditions = old_fc;
|
||||
|
||||
self.visited_conditions.remove(condition);
|
||||
self.visited_nodes.remove(dest_node);
|
||||
self.path.pop();
|
||||
|
||||
Ok(cf)
|
||||
}
|
||||
}
|
||||
|
||||
impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
|
||||
pub fn find_all_costs_to_end(&self) -> MappedInterner<QueryNode, Vec<u64>> {
|
||||
let mut costs_to_end = self.query_graph.nodes.map(|_| vec![]);
|
||||
let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len());
|
||||
|
||||
let mut node_stack = VecDeque::new();
|
||||
|
||||
*distances_to_end.get_mut(self.query_graph.end_node) = vec![0];
|
||||
*costs_to_end.get_mut(self.query_graph.end_node) = vec![0];
|
||||
|
||||
for prev_node in self.query_graph.nodes.get(self.query_graph.end_node).predecessors.iter() {
|
||||
node_stack.push_back(prev_node);
|
||||
@ -151,22 +212,22 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
|
||||
}
|
||||
|
||||
while let Some(cur_node) = node_stack.pop_front() {
|
||||
let mut self_distances = BTreeSet::<u16>::new();
|
||||
let mut self_costs = BTreeSet::<u64>::new();
|
||||
|
||||
let cur_node_edges = &self.edges_of_node.get(cur_node);
|
||||
for edge_idx in cur_node_edges.iter() {
|
||||
let edge = self.edges_store.get(edge_idx).as_ref().unwrap();
|
||||
let succ_node = edge.dest_node;
|
||||
let succ_distances = distances_to_end.get(succ_node);
|
||||
for succ_distance in succ_distances {
|
||||
self_distances.insert(edge.cost as u16 + succ_distance);
|
||||
let succ_costs = costs_to_end.get(succ_node);
|
||||
for succ_distance in succ_costs {
|
||||
self_costs.insert(edge.cost as u64 + succ_distance);
|
||||
}
|
||||
}
|
||||
let distances_to_end_cur_node = distances_to_end.get_mut(cur_node);
|
||||
for cost in self_distances.iter() {
|
||||
distances_to_end_cur_node.push(*cost);
|
||||
let costs_to_end_cur_node = costs_to_end.get_mut(cur_node);
|
||||
for cost in self_costs.iter() {
|
||||
costs_to_end_cur_node.push(*cost);
|
||||
}
|
||||
*distances_to_end.get_mut(cur_node) = self_distances.into_iter().collect();
|
||||
*costs_to_end.get_mut(cur_node) = self_costs.into_iter().collect();
|
||||
for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() {
|
||||
if !enqueued.contains(prev_node) {
|
||||
node_stack.push_back(prev_node);
|
||||
@ -174,6 +235,6 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
|
||||
}
|
||||
}
|
||||
}
|
||||
distances_to_end
|
||||
costs_to_end
|
||||
}
|
||||
}
|
||||
|
@ -36,12 +36,12 @@ impl<T> DeadEndsCache<T> {
|
||||
}
|
||||
pub fn forbidden_conditions_for_all_prefixes_up_to(
|
||||
&mut self,
|
||||
prefix: &[Interned<T>],
|
||||
prefix: impl Iterator<Item = Interned<T>>,
|
||||
) -> SmallBitmap<T> {
|
||||
let mut forbidden = self.forbidden.clone();
|
||||
let mut cursor = self;
|
||||
for c in prefix.iter() {
|
||||
if let Some(next) = cursor.advance(*c) {
|
||||
for c in prefix {
|
||||
if let Some(next) = cursor.advance(c) {
|
||||
cursor = next;
|
||||
forbidden.union(&cursor.forbidden);
|
||||
} else {
|
||||
@ -52,11 +52,11 @@ impl<T> DeadEndsCache<T> {
|
||||
}
|
||||
pub fn forbidden_conditions_after_prefix(
|
||||
&mut self,
|
||||
prefix: &[Interned<T>],
|
||||
prefix: impl Iterator<Item = Interned<T>>,
|
||||
) -> Option<SmallBitmap<T>> {
|
||||
let mut cursor = self;
|
||||
for c in prefix.iter() {
|
||||
if let Some(next) = cursor.advance(*c) {
|
||||
for c in prefix {
|
||||
if let Some(next) = cursor.advance(c) {
|
||||
cursor = next;
|
||||
} else {
|
||||
return None;
|
||||
|
Loading…
Reference in New Issue
Block a user