mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-30 00:34:26 +01:00
Refactor the paths_of_cost algorithm
Support conditions that require certain nodes to be skipped
This commit is contained in:
parent
01e24dd630
commit
aa9592455c
@ -9,141 +9,202 @@ use crate::search::new::query_graph::QueryNode;
|
|||||||
use crate::search::new::small_bitmap::SmallBitmap;
|
use crate::search::new::small_bitmap::SmallBitmap;
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
|
|
||||||
impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
|
type VisitFn<'f, G> = &'f mut dyn FnMut(
|
||||||
pub fn visit_paths_of_cost(
|
&[Interned<<G as RankingRuleGraphTrait>::Condition>],
|
||||||
&mut self,
|
&mut RankingRuleGraph<G>,
|
||||||
from: Interned<QueryNode>,
|
&mut DeadEndsCache<<G as RankingRuleGraphTrait>::Condition>,
|
||||||
cost: u16,
|
) -> Result<ControlFlow<()>>;
|
||||||
all_distances: &MappedInterner<QueryNode, Vec<u16>>,
|
|
||||||
dead_ends_cache: &mut DeadEndsCache<G::Condition>,
|
struct VisitorContext<'a, G: RankingRuleGraphTrait> {
|
||||||
mut visit: impl FnMut(
|
graph: &'a mut RankingRuleGraph<G>,
|
||||||
&[Interned<G::Condition>],
|
all_costs_from_node: &'a MappedInterner<QueryNode, Vec<u64>>,
|
||||||
&mut Self,
|
dead_ends_cache: &'a mut DeadEndsCache<G::Condition>,
|
||||||
&mut DeadEndsCache<G::Condition>,
|
}
|
||||||
) -> Result<ControlFlow<()>>,
|
|
||||||
) -> Result<()> {
|
struct VisitorState<G: RankingRuleGraphTrait> {
|
||||||
let _ = self.visit_paths_of_cost_rec(
|
remaining_cost: u64,
|
||||||
from,
|
|
||||||
cost,
|
path: Vec<Interned<G::Condition>>,
|
||||||
all_distances,
|
|
||||||
dead_ends_cache,
|
visited_conditions: SmallBitmap<G::Condition>,
|
||||||
&mut visit,
|
visited_nodes: SmallBitmap<QueryNode>,
|
||||||
&mut vec![],
|
|
||||||
&mut SmallBitmap::for_interned_values_in(&self.conditions_interner),
|
forbidden_conditions: SmallBitmap<G::Condition>,
|
||||||
dead_ends_cache.forbidden.clone(),
|
forbidden_conditions_to_nodes: SmallBitmap<QueryNode>,
|
||||||
)?;
|
}
|
||||||
|
|
||||||
|
pub struct PathVisitor<'a, G: RankingRuleGraphTrait> {
|
||||||
|
state: VisitorState<G>,
|
||||||
|
ctx: VisitorContext<'a, G>,
|
||||||
|
}
|
||||||
|
impl<'a, G: RankingRuleGraphTrait> PathVisitor<'a, G> {
|
||||||
|
pub fn new(
|
||||||
|
cost: u64,
|
||||||
|
graph: &'a mut RankingRuleGraph<G>,
|
||||||
|
all_costs_from_node: &'a MappedInterner<QueryNode, Vec<u64>>,
|
||||||
|
dead_ends_cache: &'a mut DeadEndsCache<G::Condition>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
state: VisitorState {
|
||||||
|
remaining_cost: cost,
|
||||||
|
path: vec![],
|
||||||
|
visited_conditions: SmallBitmap::for_interned_values_in(&graph.conditions_interner),
|
||||||
|
visited_nodes: SmallBitmap::for_interned_values_in(&graph.query_graph.nodes),
|
||||||
|
forbidden_conditions: SmallBitmap::for_interned_values_in(
|
||||||
|
&graph.conditions_interner,
|
||||||
|
),
|
||||||
|
forbidden_conditions_to_nodes: SmallBitmap::for_interned_values_in(
|
||||||
|
&graph.query_graph.nodes,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
ctx: VisitorContext { graph, all_costs_from_node, dead_ends_cache },
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn visit_paths(mut self, visit: VisitFn<G>) -> Result<()> {
|
||||||
|
let _ =
|
||||||
|
self.state.visit_node(self.ctx.graph.query_graph.root_node, visit, &mut self.ctx)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
pub fn visit_paths_of_cost_rec(
|
}
|
||||||
|
|
||||||
|
impl<G: RankingRuleGraphTrait> VisitorState<G> {
|
||||||
|
fn visit_node(
|
||||||
&mut self,
|
&mut self,
|
||||||
from: Interned<QueryNode>,
|
from_node: Interned<QueryNode>,
|
||||||
cost: u16,
|
visit: VisitFn<G>,
|
||||||
all_distances: &MappedInterner<QueryNode, Vec<u16>>,
|
ctx: &mut VisitorContext<G>,
|
||||||
dead_ends_cache: &mut DeadEndsCache<G::Condition>,
|
) -> Result<ControlFlow<(), bool>> {
|
||||||
visit: &mut impl FnMut(
|
|
||||||
&[Interned<G::Condition>],
|
|
||||||
&mut Self,
|
|
||||||
&mut DeadEndsCache<G::Condition>,
|
|
||||||
) -> Result<ControlFlow<()>>,
|
|
||||||
prev_conditions: &mut Vec<Interned<G::Condition>>,
|
|
||||||
cur_path: &mut SmallBitmap<G::Condition>,
|
|
||||||
mut forbidden_conditions: SmallBitmap<G::Condition>,
|
|
||||||
) -> Result<bool> {
|
|
||||||
let mut any_valid = false;
|
let mut any_valid = false;
|
||||||
|
|
||||||
let edges = self.edges_of_node.get(from).clone();
|
let edges = ctx.graph.edges_of_node.get(from_node).clone();
|
||||||
'edges_loop: for edge_idx in edges.iter() {
|
for edge_idx in edges.iter() {
|
||||||
let Some(edge) = self.edges_store.get(edge_idx).as_ref() else { continue };
|
let Some(edge) = ctx.graph.edges_store.get(edge_idx).clone() else { continue };
|
||||||
if cost < edge.cost as u16 {
|
|
||||||
|
if self.remaining_cost < edge.cost as u64 {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let next_any_valid = match edge.condition {
|
self.remaining_cost -= edge.cost as u64;
|
||||||
None => {
|
let cf = match edge.condition {
|
||||||
if edge.dest_node == self.query_graph.end_node {
|
Some(condition) => self.visit_condition(
|
||||||
any_valid = true;
|
condition,
|
||||||
let control_flow = visit(prev_conditions, self, dead_ends_cache)?;
|
|
||||||
match control_flow {
|
|
||||||
ControlFlow::Continue(_) => {}
|
|
||||||
ControlFlow::Break(_) => return Ok(true),
|
|
||||||
}
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
self.visit_paths_of_cost_rec(
|
|
||||||
edge.dest_node,
|
edge.dest_node,
|
||||||
cost - edge.cost as u16,
|
&edge.nodes_to_skip,
|
||||||
all_distances,
|
|
||||||
dead_ends_cache,
|
|
||||||
visit,
|
visit,
|
||||||
prev_conditions,
|
ctx,
|
||||||
cur_path,
|
)?,
|
||||||
forbidden_conditions.clone(),
|
None => self.visit_no_condition(edge.dest_node, &edge.nodes_to_skip, visit, ctx)?,
|
||||||
)?
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Some(condition) => {
|
|
||||||
if forbidden_conditions.contains(condition)
|
|
||||||
|| all_distances
|
|
||||||
.get(edge.dest_node)
|
|
||||||
.iter()
|
|
||||||
.all(|next_cost| *next_cost != cost - edge.cost as u16)
|
|
||||||
{
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
cur_path.insert(condition);
|
|
||||||
prev_conditions.push(condition);
|
|
||||||
let mut new_forbidden_conditions = forbidden_conditions.clone();
|
|
||||||
if let Some(next_forbidden) =
|
|
||||||
dead_ends_cache.forbidden_conditions_after_prefix(prev_conditions)
|
|
||||||
{
|
|
||||||
new_forbidden_conditions.union(&next_forbidden);
|
|
||||||
}
|
|
||||||
|
|
||||||
let next_any_valid = if edge.dest_node == self.query_graph.end_node {
|
|
||||||
any_valid = true;
|
|
||||||
let control_flow = visit(prev_conditions, self, dead_ends_cache)?;
|
|
||||||
match control_flow {
|
|
||||||
ControlFlow::Continue(_) => {}
|
|
||||||
ControlFlow::Break(_) => return Ok(true),
|
|
||||||
}
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
self.visit_paths_of_cost_rec(
|
|
||||||
edge.dest_node,
|
|
||||||
cost - edge.cost as u16,
|
|
||||||
all_distances,
|
|
||||||
dead_ends_cache,
|
|
||||||
visit,
|
|
||||||
prev_conditions,
|
|
||||||
cur_path,
|
|
||||||
new_forbidden_conditions,
|
|
||||||
)?
|
|
||||||
};
|
};
|
||||||
cur_path.remove(condition);
|
self.remaining_cost += edge.cost as u64;
|
||||||
prev_conditions.pop();
|
|
||||||
next_any_valid
|
|
||||||
}
|
|
||||||
};
|
|
||||||
any_valid |= next_any_valid;
|
|
||||||
|
|
||||||
|
let ControlFlow::Continue(next_any_valid) = cf else {
|
||||||
|
return Ok(ControlFlow::Break(()));
|
||||||
|
};
|
||||||
if next_any_valid {
|
if next_any_valid {
|
||||||
forbidden_conditions =
|
self.forbidden_conditions = ctx
|
||||||
dead_ends_cache.forbidden_conditions_for_all_prefixes_up_to(prev_conditions);
|
.dead_ends_cache
|
||||||
if cur_path.intersects(&forbidden_conditions) {
|
.forbidden_conditions_for_all_prefixes_up_to(self.path.iter().copied());
|
||||||
break 'edges_loop;
|
if self.visited_conditions.intersects(&self.forbidden_conditions) {
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
any_valid |= next_any_valid;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ControlFlow::Continue(any_valid))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_no_condition(
|
||||||
|
&mut self,
|
||||||
|
dest_node: Interned<QueryNode>,
|
||||||
|
edge_forbidden_nodes: &SmallBitmap<QueryNode>,
|
||||||
|
visit: VisitFn<G>,
|
||||||
|
ctx: &mut VisitorContext<G>,
|
||||||
|
) -> Result<ControlFlow<(), bool>> {
|
||||||
|
if ctx
|
||||||
|
.all_costs_from_node
|
||||||
|
.get(dest_node)
|
||||||
|
.iter()
|
||||||
|
.all(|next_cost| *next_cost != self.remaining_cost)
|
||||||
|
{
|
||||||
|
return Ok(ControlFlow::Continue(false));
|
||||||
|
}
|
||||||
|
if dest_node == ctx.graph.query_graph.end_node {
|
||||||
|
let control_flow = visit(&self.path, ctx.graph, ctx.dead_ends_cache)?;
|
||||||
|
match control_flow {
|
||||||
|
ControlFlow::Continue(_) => Ok(ControlFlow::Continue(true)),
|
||||||
|
ControlFlow::Break(_) => Ok(ControlFlow::Break(())),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let old_fbct = self.forbidden_conditions_to_nodes.clone();
|
||||||
|
self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes);
|
||||||
|
let cf = self.visit_node(dest_node, visit, ctx)?;
|
||||||
|
self.forbidden_conditions_to_nodes = old_fbct;
|
||||||
|
Ok(cf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn visit_condition(
|
||||||
|
&mut self,
|
||||||
|
condition: Interned<G::Condition>,
|
||||||
|
dest_node: Interned<QueryNode>,
|
||||||
|
edge_forbidden_nodes: &SmallBitmap<QueryNode>,
|
||||||
|
visit: VisitFn<G>,
|
||||||
|
ctx: &mut VisitorContext<G>,
|
||||||
|
) -> Result<ControlFlow<(), bool>> {
|
||||||
|
assert!(dest_node != ctx.graph.query_graph.end_node);
|
||||||
|
|
||||||
|
if self.forbidden_conditions_to_nodes.contains(dest_node)
|
||||||
|
|| edge_forbidden_nodes.intersects(&self.visited_nodes)
|
||||||
|
{
|
||||||
|
return Ok(ControlFlow::Continue(false));
|
||||||
|
}
|
||||||
|
if self.forbidden_conditions.contains(condition) {
|
||||||
|
return Ok(ControlFlow::Continue(false));
|
||||||
|
}
|
||||||
|
|
||||||
|
if ctx
|
||||||
|
.all_costs_from_node
|
||||||
|
.get(dest_node)
|
||||||
|
.iter()
|
||||||
|
.all(|next_cost| *next_cost != self.remaining_cost)
|
||||||
|
{
|
||||||
|
return Ok(ControlFlow::Continue(false));
|
||||||
|
}
|
||||||
|
|
||||||
|
self.path.push(condition);
|
||||||
|
self.visited_nodes.insert(dest_node);
|
||||||
|
self.visited_conditions.insert(condition);
|
||||||
|
|
||||||
|
let old_fc = self.forbidden_conditions.clone();
|
||||||
|
if let Some(next_forbidden) =
|
||||||
|
ctx.dead_ends_cache.forbidden_conditions_after_prefix(self.path.iter().copied())
|
||||||
|
{
|
||||||
|
self.forbidden_conditions.union(&next_forbidden);
|
||||||
|
}
|
||||||
|
let old_fctn = self.forbidden_conditions_to_nodes.clone();
|
||||||
|
self.forbidden_conditions_to_nodes.union(edge_forbidden_nodes);
|
||||||
|
|
||||||
|
let cf = self.visit_node(dest_node, visit, ctx)?;
|
||||||
|
|
||||||
|
self.forbidden_conditions_to_nodes = old_fctn;
|
||||||
|
self.forbidden_conditions = old_fc;
|
||||||
|
|
||||||
|
self.visited_conditions.remove(condition);
|
||||||
|
self.visited_nodes.remove(dest_node);
|
||||||
|
self.path.pop();
|
||||||
|
|
||||||
|
Ok(cf)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(any_valid)
|
impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
|
||||||
}
|
pub fn find_all_costs_to_end(&self) -> MappedInterner<QueryNode, Vec<u64>> {
|
||||||
|
let mut costs_to_end = self.query_graph.nodes.map(|_| vec![]);
|
||||||
pub fn initialize_distances_with_necessary_edges(&self) -> MappedInterner<QueryNode, Vec<u16>> {
|
|
||||||
let mut distances_to_end = self.query_graph.nodes.map(|_| vec![]);
|
|
||||||
let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len());
|
let mut enqueued = SmallBitmap::new(self.query_graph.nodes.len());
|
||||||
|
|
||||||
let mut node_stack = VecDeque::new();
|
let mut node_stack = VecDeque::new();
|
||||||
|
|
||||||
*distances_to_end.get_mut(self.query_graph.end_node) = vec![0];
|
*costs_to_end.get_mut(self.query_graph.end_node) = vec![0];
|
||||||
|
|
||||||
for prev_node in self.query_graph.nodes.get(self.query_graph.end_node).predecessors.iter() {
|
for prev_node in self.query_graph.nodes.get(self.query_graph.end_node).predecessors.iter() {
|
||||||
node_stack.push_back(prev_node);
|
node_stack.push_back(prev_node);
|
||||||
@ -151,22 +212,22 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
while let Some(cur_node) = node_stack.pop_front() {
|
while let Some(cur_node) = node_stack.pop_front() {
|
||||||
let mut self_distances = BTreeSet::<u16>::new();
|
let mut self_costs = BTreeSet::<u64>::new();
|
||||||
|
|
||||||
let cur_node_edges = &self.edges_of_node.get(cur_node);
|
let cur_node_edges = &self.edges_of_node.get(cur_node);
|
||||||
for edge_idx in cur_node_edges.iter() {
|
for edge_idx in cur_node_edges.iter() {
|
||||||
let edge = self.edges_store.get(edge_idx).as_ref().unwrap();
|
let edge = self.edges_store.get(edge_idx).as_ref().unwrap();
|
||||||
let succ_node = edge.dest_node;
|
let succ_node = edge.dest_node;
|
||||||
let succ_distances = distances_to_end.get(succ_node);
|
let succ_costs = costs_to_end.get(succ_node);
|
||||||
for succ_distance in succ_distances {
|
for succ_distance in succ_costs {
|
||||||
self_distances.insert(edge.cost as u16 + succ_distance);
|
self_costs.insert(edge.cost as u64 + succ_distance);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let distances_to_end_cur_node = distances_to_end.get_mut(cur_node);
|
let costs_to_end_cur_node = costs_to_end.get_mut(cur_node);
|
||||||
for cost in self_distances.iter() {
|
for cost in self_costs.iter() {
|
||||||
distances_to_end_cur_node.push(*cost);
|
costs_to_end_cur_node.push(*cost);
|
||||||
}
|
}
|
||||||
*distances_to_end.get_mut(cur_node) = self_distances.into_iter().collect();
|
*costs_to_end.get_mut(cur_node) = self_costs.into_iter().collect();
|
||||||
for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() {
|
for prev_node in self.query_graph.nodes.get(cur_node).predecessors.iter() {
|
||||||
if !enqueued.contains(prev_node) {
|
if !enqueued.contains(prev_node) {
|
||||||
node_stack.push_back(prev_node);
|
node_stack.push_back(prev_node);
|
||||||
@ -174,6 +235,6 @@ impl<G: RankingRuleGraphTrait> RankingRuleGraph<G> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
distances_to_end
|
costs_to_end
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -36,12 +36,12 @@ impl<T> DeadEndsCache<T> {
|
|||||||
}
|
}
|
||||||
pub fn forbidden_conditions_for_all_prefixes_up_to(
|
pub fn forbidden_conditions_for_all_prefixes_up_to(
|
||||||
&mut self,
|
&mut self,
|
||||||
prefix: &[Interned<T>],
|
prefix: impl Iterator<Item = Interned<T>>,
|
||||||
) -> SmallBitmap<T> {
|
) -> SmallBitmap<T> {
|
||||||
let mut forbidden = self.forbidden.clone();
|
let mut forbidden = self.forbidden.clone();
|
||||||
let mut cursor = self;
|
let mut cursor = self;
|
||||||
for c in prefix.iter() {
|
for c in prefix {
|
||||||
if let Some(next) = cursor.advance(*c) {
|
if let Some(next) = cursor.advance(c) {
|
||||||
cursor = next;
|
cursor = next;
|
||||||
forbidden.union(&cursor.forbidden);
|
forbidden.union(&cursor.forbidden);
|
||||||
} else {
|
} else {
|
||||||
@ -52,11 +52,11 @@ impl<T> DeadEndsCache<T> {
|
|||||||
}
|
}
|
||||||
pub fn forbidden_conditions_after_prefix(
|
pub fn forbidden_conditions_after_prefix(
|
||||||
&mut self,
|
&mut self,
|
||||||
prefix: &[Interned<T>],
|
prefix: impl Iterator<Item = Interned<T>>,
|
||||||
) -> Option<SmallBitmap<T>> {
|
) -> Option<SmallBitmap<T>> {
|
||||||
let mut cursor = self;
|
let mut cursor = self;
|
||||||
for c in prefix.iter() {
|
for c in prefix {
|
||||||
if let Some(next) = cursor.advance(*c) {
|
if let Some(next) = cursor.advance(c) {
|
||||||
cursor = next;
|
cursor = next;
|
||||||
} else {
|
} else {
|
||||||
return None;
|
return None;
|
||||||
|
Loading…
Reference in New Issue
Block a user