Mirror of https://github.com/meilisearch/MeiliSearch, synced 2024-11-22 21:04:27 +01:00
Introduce the AstarBagIter that iterates through best paths
This commit is contained in:
parent 7dc594ba4d
commit 1e358e3ae8
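In short, this commit swaps the hand-rolled BestProximity/astar_bag machinery for the external astar-iter crate: the proximity Node type moves to src/node.rs, the old src/best_proximity.rs and src/iter_shortest_paths.rs modules are deleted, and the search loop now drives an AstarBagIter that yields bags of equally good paths in increasing proximity order. A rough sketch of the new loop shape, paraphrased from the src/search.rs hunk further down (this is not code added by the commit itself):

    let astar_iter = AstarBagIter::new(
        Node::Uninit,                                                 // start node, before any query word is placed
        |n| n.successors(&union_positions, &mut contains_documents),  // successor nodes and their proximity costs
        |_| 0,                                                        // zero heuristic: plain Dijkstra-like exploration
        |n| n.is_complete(&union_positions),                          // success: the path covers every query word
    );
    for (paths, proximity) in astar_iter {
        // `paths` iterates over every optimal path for this proximity;
        // each path is mapped back to word positions and then to document ids.
    }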
24 Cargo.lock (generated)
@@ -85,6 +85,15 @@ dependencies = [
 "warp",
]

[[package]]
name = "astar-iter"
version = "0.1.0"
source = "git+https://github.com/Kerollmops/astar-iter#87cb97a11c701f1a6025b72b673a8bfd0ca249a5"
dependencies = [
 "indexmap",
 "num-traits",
]

[[package]]
name = "atty"
version = "0.2.11"
@@ -636,6 +645,15 @@ dependencies = [
 "tokio-util",
]

[[package]]
name = "hashbrown"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e91b62f79061a0bc2e046024cb7ba44b08419ed238ecbd9adbd787434b9e8c25"
dependencies = [
 "autocfg 1.0.0",
]

[[package]]
name = "headers"
version = "0.3.2"
@@ -785,11 +803,12 @@ dependencies = [

[[package]]
name = "indexmap"
version = "1.4.0"
version = "1.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c398b2b113b55809ceb9ee3e753fcbac793f1956663f3c36549c1346015c2afe"
checksum = "86b45e59b16c76b11bf9738fd5d38879d3bd28ad292d7b313608becb17ae2df9"
dependencies = [
 "autocfg 1.0.0",
 "hashbrown",
]

[[package]]
@@ -981,6 +1000,7 @@ dependencies = [
 "arc-cache",
 "askama",
 "askama_warp",
 "astar-iter",
 "bitpacking",
 "byteorder",
 "cow-utils",
Cargo.toml
@@ -8,6 +8,7 @@ default-run = "indexer"
[dependencies]
anyhow = "1.0.28"
arc-cache = { git = "https://github.com/Kerollmops/rust-arc-cache.git", rev = "56530f2" }
astar-iter = { git = "https://github.com/Kerollmops/astar-iter" }
bitpacking = "0.8.2"
byteorder = "1.3.4"
cow-utils = "0.1.2"
@@ -34,7 +35,7 @@ log = "0.4.8"
stderrlog = "0.4.3"

# best proximity
indexmap = "1.4.0"
indexmap = "1.5.1"

# to implement internally
itertools = "0.9.0"
src/best_proximity.rs (deleted)
@@ -1,232 +0,0 @@
use std::cmp;
use std::time::Instant;

use crate::iter_shortest_paths::astar_bag;
use log::debug;
use roaring::RoaringBitmap;

const ONE_ATTRIBUTE: u32 = 1000;
const MAX_DISTANCE: u32 = 8;

fn index_proximity(lhs: u32, rhs: u32) -> u32 {
    if lhs <= rhs {
        cmp::min(rhs - lhs, MAX_DISTANCE)
    } else {
        cmp::min((lhs - rhs) + 1, MAX_DISTANCE)
    }
}

pub fn positions_proximity(lhs: u32, rhs: u32) -> u32 {
    let (lhs_attr, lhs_index) = extract_position(lhs);
    let (rhs_attr, rhs_index) = extract_position(rhs);
    if lhs_attr != rhs_attr { MAX_DISTANCE }
    else { index_proximity(lhs_index, rhs_index) }
}

// Returns the attribute and index parts.
pub fn extract_position(position: u32) -> (u32, u32) {
    (position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE)
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
enum Node {
    // Is this node is the first node.
    Uninit,
    Init {
        // The layer where this node located.
        layer: usize,
        // The position where this node is located.
        position: u32,
        // The total accumulated proximity until this node, used for skipping nodes.
        acc_proximity: u32,
        // The parent position from the above layer.
        parent_position: u32,
    },
}

impl Node {
    // TODO we must skip the successors that have already been seen
    // TODO we must skip the successors that doesn't return any documents
    //      this way we are able to skip entire paths
    fn successors(&self, positions: &[RoaringBitmap], best_proximity: u32) -> Vec<(Node, u32)> {
        match self {
            Node::Uninit => {
                positions[0].iter().map(|position| {
                    (Node::Init { layer: 0, position, acc_proximity: 0, parent_position: 0 }, 0)
                }).collect()
            },
            // We reached the highest layer
            n @ Node::Init { .. } if n.is_complete(positions) => vec![],
            Node::Init { layer, position, acc_proximity, .. } => {
                positions[layer + 1].iter().filter_map(|p| {
                    let proximity = positions_proximity(*position, p);
                    let node = Node::Init {
                        layer: layer + 1,
                        position: p,
                        acc_proximity: acc_proximity + proximity,
                        parent_position: *position,
                    };
                    // We do not produce the nodes we have already seen in previous iterations loops.
                    if node.is_complete(positions) && acc_proximity + proximity < best_proximity {
                        None
                    } else {
                        Some((node, proximity))
                    }
                }).collect()
            }
        }
    }

    fn is_complete(&self, positions: &[RoaringBitmap]) -> bool {
        match self {
            Node::Uninit => false,
            Node::Init { layer, .. } => *layer == positions.len() - 1,
        }
    }

    fn position(&self) -> Option<u32> {
        match self {
            Node::Uninit => None,
            Node::Init { position, .. } => Some(*position),
        }
    }

    fn proximity(&self) -> u32 {
        match self {
            Node::Uninit => 0,
            Node::Init { layer, position, acc_proximity, parent_position } => {
                if layer.checked_sub(1).is_some() {
                    acc_proximity + positions_proximity(*position, *parent_position)
                } else {
                    0
                }
            },
        }
    }

    fn is_reachable<F>(&self, contains_documents: &mut F) -> bool
    where F: FnMut((usize, u32), (usize, u32)) -> bool,
    {
        match self {
            Node::Uninit => true,
            Node::Init { layer, position, parent_position, .. } => {
                match layer.checked_sub(1) {
                    Some(parent_layer) => {
                        (contains_documents)((parent_layer, *parent_position), (*layer, *position))
                    },
                    None => true,
                }
            },
        }
    }
}

pub struct BestProximity {
    positions: Vec<RoaringBitmap>,
    best_proximity: u32,
}

impl BestProximity {
    pub fn new(positions: Vec<RoaringBitmap>) -> BestProximity {
        let best_proximity = (positions.len() as u32).saturating_sub(1);
        BestProximity { positions, best_proximity }
    }
}

impl BestProximity {
    pub fn next<F>(&mut self, mut contains_documents: F) -> Option<(u32, Vec<RoaringBitmap>)>
    where F: FnMut((usize, u32), (usize, u32)) -> bool,
    {
        let before = Instant::now();

        if self.best_proximity == self.positions.len() as u32 * MAX_DISTANCE {
            return None;
        }

        let BestProximity { positions, best_proximity } = self;

        let result = astar_bag(
            &Node::Uninit, // start
            |n| n.successors(&positions, *best_proximity),
            |_| 0, // heuristic
            |n| { // success
                let c = n.is_complete(&positions) && n.proximity() >= *best_proximity;
                if n.is_reachable(&mut contains_documents) { Some(c) } else { None }
            },
        );

        debug!("BestProximity::next() took {:.02?}", before.elapsed());

        match result {
            Some((paths, proximity)) => {
                self.best_proximity = proximity + 1;
                // We retrieve the last path that we convert into a Vec
                let paths: Vec<_> = paths.map(|p| p.iter().filter_map(Node::position).collect()).collect();
                debug!("result: {} {:?}", proximity, paths);
                Some((proximity, paths))
            },
            None => {
                debug!("result: {:?}", None as Option<()>);
                self.best_proximity += 1;
                None
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::iter::FromIterator;

    fn sort<T: Ord>(mut val: (u32, Vec<T>)) -> (u32, Vec<T>) {
        val.1.sort_unstable();
        val
    }

    #[test]
    fn same_attribute() {
        let positions = vec![
            RoaringBitmap::from_iter(vec![0, 2, 3, 4 ]),
            RoaringBitmap::from_iter(vec![ 1, ]),
            RoaringBitmap::from_iter(vec![ 3, 6]),
        ];
        let mut iter = BestProximity::new(positions);
        let f = |_, _| true;

        assert_eq!(iter.next(f), Some((1+2, vec![RoaringBitmap::from_iter(vec![0, 1, 3])]))); // 3
        assert_eq!(iter.next(f), Some((2+2, vec![RoaringBitmap::from_iter(vec![2, 1, 3])]))); // 4
        assert_eq!(iter.next(f), Some((3+2, vec![RoaringBitmap::from_iter(vec![3, 1, 3])]))); // 5
        assert_eq!(iter.next(f), Some((1+5, vec![RoaringBitmap::from_iter(vec![0, 1, 6]), RoaringBitmap::from_iter(vec![4, 1, 3])]))); // 6
        assert_eq!(iter.next(f), Some((2+5, vec![RoaringBitmap::from_iter(vec![2, 1, 6])]))); // 7
        assert_eq!(iter.next(f), Some((3+5, vec![RoaringBitmap::from_iter(vec![3, 1, 6])]))); // 8
        assert_eq!(iter.next(f), Some((4+5, vec![RoaringBitmap::from_iter(vec![4, 1, 6])]))); // 9
        assert_eq!(iter.next(f), None);
    }

    #[test]
    fn different_attributes() {
        let positions = vec![
            RoaringBitmap::from_iter(vec![0, 2, 1000, 1001, 2000 ]),
            RoaringBitmap::from_iter(vec![ 1, 1000, 2001 ]),
            RoaringBitmap::from_iter(vec![ 3, 6, 2002, 3000]),
        ];
        let mut iter = BestProximity::new(positions);
        let f = |_, _| true;

        assert_eq!(iter.next(f), Some((1+1, vec![RoaringBitmap::from_iter(vec![2000, 2001, 2002])]))); // 2
        assert_eq!(iter.next(f), Some((1+2, vec![RoaringBitmap::from_iter(vec![0, 1, 3])]))); // 3
        assert_eq!(iter.next(f), Some((2+2, vec![RoaringBitmap::from_iter(vec![2, 1, 3])]))); // 4
        assert_eq!(iter.next(f), Some((1+5, vec![RoaringBitmap::from_iter(vec![0, 1, 6])]))); // 6
        // We ignore others here...
    }

    #[test]
    fn easy_proximities() {
        fn slice_proximity(positions: &[u32]) -> u32 {
            positions.windows(2).map(|ps| positions_proximity(ps[0], ps[1])).sum::<u32>()
        }

        assert_eq!(slice_proximity(&[1000, 1000, 2002]), 8);
    }
}
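For reference, the deleted module encodes a word position as attribute * 1000 + index and caps any proximity at MAX_DISTANCE = 8, which is also the cost charged when two positions sit in different attributes. Below is a minimal standalone sketch that mirrors those two helpers and checks a few values consistent with the tests above; it is illustrative only and not part of the commit:

    // Standalone illustration of the position encoding used above.
    const ONE_ATTRIBUTE: u32 = 1000;
    const MAX_DISTANCE: u32 = 8;

    fn extract_position(position: u32) -> (u32, u32) {
        (position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE)
    }

    fn positions_proximity(lhs: u32, rhs: u32) -> u32 {
        let (lattr, lidx) = extract_position(lhs);
        let (rattr, ridx) = extract_position(rhs);
        if lattr != rattr {
            MAX_DISTANCE
        } else if lidx <= ridx {
            (ridx - lidx).min(MAX_DISTANCE)
        } else {
            (lidx - ridx + 1).min(MAX_DISTANCE)
        }
    }

    fn main() {
        assert_eq!(extract_position(2001), (2, 1));      // attribute 2, word index 1
        assert_eq!(positions_proximity(1000, 1001), 1);  // adjacent words in the same attribute
        assert_eq!(positions_proximity(1001, 1000), 2);  // swapped order costs one more
        assert_eq!(positions_proximity(2, 1000), 8);     // different attributes saturate at MAX_DISTANCE
    }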
src/iter_shortest_paths.rs (deleted)
@@ -1,204 +0,0 @@
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
use std::hash::Hash;
use std::usize;

use indexmap::map::Entry::{Occupied, Vacant};
use indexmap::IndexMap;

pub fn astar_bag<N, FN, IN, FH, FS>(
    start: &N,
    mut successors: FN,
    mut heuristic: FH,
    mut success: FS,
) -> Option<(AstarSolution<N>, u32)>
where
    N: Eq + Hash + Clone,
    FN: FnMut(&N) -> IN,
    IN: IntoIterator<Item = (N, u32)>,
    FH: FnMut(&N) -> u32,
    FS: FnMut(&N) -> Option<bool>,
{
    let mut to_see = BinaryHeap::new();
    let mut min_cost = None;
    let mut sinks = HashSet::new();
    to_see.push(SmallestCostHolder {
        estimated_cost: heuristic(start),
        cost: 0,
        index: 0,
    });
    let mut parents: IndexMap<N, (HashSet<usize>, u32)> = IndexMap::new();
    parents.insert(start.clone(), (HashSet::new(), 0));
    while let Some(SmallestCostHolder { cost, index, estimated_cost, .. }) = to_see.pop() {
        if let Some(min_cost) = min_cost {
            if estimated_cost > min_cost {
                break;
            }
        }
        let successors = {
            let (node, &(_, c)) = parents.get_index(index).unwrap();
            // We check that the node is even reachable and if so if it is an answer.
            // If this node is unreachable we skip it.
            match success(node) {
                Some(success) => if success {
                    min_cost = Some(cost);
                    sinks.insert(index);
                },
                None => continue,
            }

            // We may have inserted a node several time into the binary heap if we found
            // a better way to access it. Ensure that we are currently dealing with the
            // best path and discard the others.
            if cost > c {
                continue;
            }
            successors(node)
        };
        for (successor, move_cost) in successors {
            let new_cost = cost + move_cost;
            let h; // heuristic(&successor)
            let n; // index for successor
            match parents.entry(successor) {
                Vacant(e) => {
                    h = heuristic(e.key());
                    n = e.index();
                    let mut p = HashSet::new();
                    p.insert(index);
                    e.insert((p, new_cost));
                }
                Occupied(mut e) => {
                    if e.get().1 > new_cost {
                        h = heuristic(e.key());
                        n = e.index();
                        let s = e.get_mut();
                        s.0.clear();
                        s.0.insert(index);
                        s.1 = new_cost;
                    } else {
                        if e.get().1 == new_cost {
                            // New parent with an identical cost, this is not
                            // considered as an insertion.
                            e.get_mut().0.insert(index);
                        }
                        continue;
                    }
                }
            }

            to_see.push(SmallestCostHolder {
                estimated_cost: new_cost + h,
                cost: new_cost,
                index: n,
            });
        }
    }

    min_cost.map(|cost| {
        let parents = parents
            .into_iter()
            .map(|(k, (ps, _))| (k, ps.into_iter().collect()))
            .collect();
        (
            AstarSolution {
                sinks: sinks.into_iter().collect(),
                parents,
                current: vec![],
                terminated: false,
            },
            cost,
        )
    })
}

struct SmallestCostHolder<K> {
    estimated_cost: K,
    cost: K,
    index: usize,
}

impl<K: PartialEq> PartialEq for SmallestCostHolder<K> {
    fn eq(&self, other: &Self) -> bool {
        self.estimated_cost.eq(&other.estimated_cost) && self.cost.eq(&other.cost)
    }
}

impl<K: PartialEq> Eq for SmallestCostHolder<K> {}

impl<K: Ord> PartialOrd for SmallestCostHolder<K> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<K: Ord> Ord for SmallestCostHolder<K> {
    fn cmp(&self, other: &Self) -> Ordering {
        match other.estimated_cost.cmp(&self.estimated_cost) {
            Ordering::Equal => self.cost.cmp(&other.cost),
            s => s,
        }
    }
}

/// Iterator structure created by the `astar_bag` function.
#[derive(Clone)]
pub struct AstarSolution<N> {
    sinks: Vec<usize>,
    parents: Vec<(N, Vec<usize>)>,
    current: Vec<Vec<usize>>,
    terminated: bool,
}

impl<N: Clone + Eq + Hash> AstarSolution<N> {
    fn complete(&mut self) {
        loop {
            let ps = match self.current.last() {
                None => self.sinks.clone(),
                Some(last) => {
                    let &top = last.last().unwrap();
                    self.parents(top).clone()
                }
            };
            if ps.is_empty() {
                break;
            }
            self.current.push(ps);
        }
    }

    fn next_vec(&mut self) {
        while self.current.last().map(Vec::len) == Some(1) {
            self.current.pop();
        }
        self.current.last_mut().map(Vec::pop);
    }

    fn node(&self, i: usize) -> &N {
        &self.parents[i].0
    }

    fn parents(&self, i: usize) -> &Vec<usize> {
        &self.parents[i].1
    }
}

impl<N: Clone + Eq + Hash> Iterator for AstarSolution<N> {
    type Item = Vec<N>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.terminated {
            return None;
        }
        self.complete();
        let path = self
            .current
            .iter()
            .rev()
            .map(|v| v.last().cloned().unwrap())
            .map(|i| self.node(i).clone())
            .collect::<Vec<_>>();
        self.next_vec();
        self.terminated = self.current.is_empty();
        Some(path)
    }
}
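To make the removed helper's contract concrete: astar_bag explores the graph in best-first order and, once a success node has been reached at minimal cost, returns that cost together with an AstarSolution iterator over every optimal path (start node included). A minimal sketch of how it could have been exercised on a toy graph (the graph, costs and node type here are made up for illustration; it assumes the module above is still in scope):

    use crate::iter_shortest_paths::astar_bag;

    fn two_best_paths() {
        // Toy graph over u32 nodes: i -> i + 1 costs 1, plus a shortcut 0 -> 2 costing 2,
        // so there are exactly two optimal paths from 0 to 3, both of total cost 3.
        let result = astar_bag(
            &0u32,
            |&n| {
                let mut edges = Vec::new();
                if n < 3 { edges.push((n + 1, 1)); }
                if n == 0 { edges.push((2, 2)); }
                edges
            },
            |_| 0,             // admissible heuristic: always zero
            |&n| Some(n == 3), // every node is reachable, only 3 is a success
        );

        let (solution, cost) = result.unwrap();
        assert_eq!(cost, 3);
        let mut paths: Vec<Vec<u32>> = solution.collect();
        paths.sort();
        assert_eq!(paths, vec![vec![0, 1, 2, 3], vec![0, 2, 3]]);
    }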
@@ -1,7 +1,6 @@
mod best_proximity;
mod criterion;
mod heed_codec;
mod iter_shortest_paths;
mod node;
mod query_tokens;
mod search;
mod transitive_arc;
104 src/node.rs (new file)
@@ -0,0 +1,104 @@
use std::cmp;
use roaring::RoaringBitmap;

const ONE_ATTRIBUTE: u32 = 1000;
const MAX_DISTANCE: u32 = 8;

fn index_proximity(lhs: u32, rhs: u32) -> u32 {
    if lhs <= rhs {
        cmp::min(rhs - lhs, MAX_DISTANCE)
    } else {
        cmp::min((lhs - rhs) + 1, MAX_DISTANCE)
    }
}

pub fn positions_proximity(lhs: u32, rhs: u32) -> u32 {
    let (lhs_attr, lhs_index) = extract_position(lhs);
    let (rhs_attr, rhs_index) = extract_position(rhs);
    if lhs_attr != rhs_attr { MAX_DISTANCE }
    else { index_proximity(lhs_index, rhs_index) }
}

// Returns the attribute and index parts.
pub fn extract_position(position: u32) -> (u32, u32) {
    (position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE)
}

#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Node {
    // Is this node the first node.
    Uninit,
    Init {
        // The layer where this node is located.
        layer: usize,
        // The position where this node is located.
        position: u32,
        // The parent position from the above layer.
        parent_position: u32,
    },
}

impl Node {
    // TODO we must skip the successors that have already been seen
    // TODO we must skip the successors that don't return any documents,
    //      this way we are able to skip entire paths
    pub fn successors<F>(&self, positions: &[RoaringBitmap], contains_documents: &mut F) -> Vec<(Node, u32)>
    where F: FnMut((usize, u32), (usize, u32)) -> bool,
    {
        match self {
            Node::Uninit => {
                positions[0].iter().map(|position| {
                    (Node::Init { layer: 0, position, parent_position: 0 }, 0)
                }).collect()
            },
            // We reached the highest layer
            n @ Node::Init { .. } if n.is_complete(positions) => vec![],
            Node::Init { layer, position, .. } => {
                positions[layer + 1].iter().filter_map(|p| {
                    let proximity = positions_proximity(*position, p);
                    let node = Node::Init {
                        layer: layer + 1,
                        position: p,
                        parent_position: *position,
                    };
                    // We do not produce the nodes we have already seen in previous iterations.
                    if node.is_reachable(contains_documents) {
                        Some((node, proximity))
                    } else {
                        None
                    }
                }).collect()
            }
        }
    }

    pub fn is_complete(&self, positions: &[RoaringBitmap]) -> bool {
        match self {
            Node::Uninit => false,
            Node::Init { layer, .. } => *layer == positions.len() - 1,
        }
    }

    pub fn position(&self) -> Option<u32> {
        match self {
            Node::Uninit => None,
            Node::Init { position, .. } => Some(*position),
        }
    }

    pub fn is_reachable<F>(&self, contains_documents: &mut F) -> bool
    where F: FnMut((usize, u32), (usize, u32)) -> bool,
    {
        match self {
            Node::Uninit => true,
            Node::Init { layer, position, parent_position, .. } => {
                match layer.checked_sub(1) {
                    Some(parent_layer) => {
                        (contains_documents)((parent_layer, *parent_position), (*layer, *position))
                    },
                    None => true,
                }
            },
        }
    }
}
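The graph spanned by this Node type is layered: layer i holds the candidate positions of the i-th query word, an edge's cost is the proximity between the parent and child positions, and the contains_documents callback lets the caller prune edges whose word pair matches no document. A quick sketch of one expansion step (hypothetical position values; assumes this node module and the roaring crate are in scope):

    use std::iter::FromIterator;
    use roaring::RoaringBitmap;
    use crate::node::{Node, positions_proximity};

    fn expand_once() {
        // One bitmap of candidate positions per query word (one per layer).
        let positions = vec![
            RoaringBitmap::from_iter(vec![0u32, 2]), // first word: positions 0 and 2
            RoaringBitmap::from_iter(vec![1u32]),    // second word: position 1
        ];
        // Never prune: accept every (word, position) pair.
        let mut contains_documents = |_: (usize, u32), _: (usize, u32)| true;

        // The virtual start node fans out to every position of the first word, at cost 0.
        let firsts = Node::Uninit.successors(&positions, &mut contains_documents);
        assert_eq!(firsts.len(), 2);
        assert!(firsts.iter().all(|(_, cost)| *cost == 0));

        // A layer-0 node then connects to the second word's positions,
        // weighted by the proximity between the two positions.
        let (first, _) = &firsts[0];
        let nexts = first.successors(&positions, &mut contains_documents);
        assert_eq!(nexts.len(), 1);
        assert_eq!(nexts[0].1, positions_proximity(0, 1)); // proximity of 1
    }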
src/search.rs
@@ -1,14 +1,18 @@
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;

use astar_iter::AstarBagIter;
use fst::{IntoStreamer, Streamer};
use levenshtein_automata::DFA;
use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder;
use log::debug;
use once_cell::sync::Lazy;
use roaring::RoaringBitmap;

use crate::node::{self, Node};
use crate::query_tokens::{QueryTokens, QueryToken};
use crate::{Index, DocumentId, Position, Attribute};
use crate::best_proximity::{self, BestProximity};

// Building these factories is not free.
static LEVDIST0: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(0, true));
@@ -214,20 +218,30 @@ impl<'a> Search<'a> {
        };

        let (derived_words, union_positions) = Self::fetch_words_positions(rtxn, index, &fst, dfas)?;
        let mut candidates = Self::compute_candidates(rtxn, index, &derived_words)?;
        let candidates = Self::compute_candidates(rtxn, index, &derived_words)?;

        let mut union_cache = HashMap::new();
        let union_cache = HashMap::new();
        let mut intersect_cache = HashMap::new();

        let mut attribute_union_cache = HashMap::new();
        let mut attribute_intersect_cache = HashMap::new();

        let candidates = Rc::new(RefCell::new(candidates));
        let union_cache = Rc::new(RefCell::new(union_cache));

        // Returns `true` if there are documents in common between the two words and positions given.
        let mut contains_documents = |(lword, lpos), (rword, rpos), union_cache: &mut HashMap<_, _>, candidates: &RoaringBitmap| {
        // TODO move this closure to a better place.
        let candidates_cloned = candidates.clone();
        let union_cache_cloned = union_cache.clone();
        let mut contains_documents = |(lword, lpos), (rword, rpos)| {
            if lpos == rpos { return false }

            let (lattr, _) = best_proximity::extract_position(lpos);
            let (rattr, _) = best_proximity::extract_position(rpos);
            // TODO move this function to a better place.
            let (lattr, _) = node::extract_position(lpos);
            let (rattr, _) = node::extract_position(rpos);

            let candidates = &candidates_cloned.borrow();
            let mut union_cache = union_cache_cloned.borrow_mut();

            if lattr == rattr {
                // We retrieve or compute the intersection between the two given words and positions.
@@ -277,19 +291,33 @@ impl<'a> Search<'a> {
            }
        };

        // We instantiate an astar bag iterator that returns the best paths incrementally:
        // it first returns the best paths, then the next best paths, and so on...
        let astar_iter = AstarBagIter::new(
            Node::Uninit, // start
            |n| n.successors(&union_positions, &mut contains_documents), // successors
            |_| 0, // heuristic
            |n| n.is_complete(&union_positions), // success
        );

        let mut documents = Vec::new();
        let mut iter = BestProximity::new(union_positions);
        while let Some((_proximity, mut positions)) = iter.next(|l, r| contains_documents(l, r, &mut union_cache, &candidates)) {
        positions.sort_unstable_by(|a, b| a.iter().cmp(b.iter()));
        for (paths, proximity) in astar_iter {

            let mut union_cache = union_cache.borrow_mut();
            let mut candidates = candidates.borrow_mut();

            let mut positions: Vec<Vec<_>> = paths.map(|p| p.iter().filter_map(Node::position).collect()).collect();
            positions.sort_unstable();

            debug!("Found {} positions with a proximity of {}", positions.len(), proximity);

            let mut same_proximity_union = RoaringBitmap::default();
            for positions in positions {

                // Precompute the potentially missing unions
                positions.iter().enumerate().for_each(|(word, pos)| {
                    union_cache.entry((word, pos)).or_insert_with(|| {
                    union_cache.entry((word, *pos)).or_insert_with(|| {
                        let words = &derived_words[word];
                        Self::union_word_position(rtxn, index, words, pos).unwrap()
                        Self::union_word_position(rtxn, index, words, *pos).unwrap()
                    });
                });

@@ -297,7 +325,7 @@ impl<'a> Search<'a> {
                let mut to_intersect: Vec<_> = positions.iter()
                    .enumerate()
                    .map(|(word, pos)| {
                        let docids = union_cache.get(&(word, pos)).unwrap();
                        let docids = union_cache.get(&(word, *pos)).unwrap();
                        (docids.len(), docids)
                    })
                    .collect();
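One structural consequence of moving to AstarBagIter is visible above: contains_documents can no longer receive the caches as extra arguments (the successor callback only gets the two (word, position) pairs), so candidates and union_cache are wrapped in Rc<RefCell<...>> and borrowed both inside the closure and in the loop body. A minimal sketch of that sharing pattern, using only the standard library and names invented for the example:

    use std::cell::RefCell;
    use std::collections::HashMap;
    use std::rc::Rc;

    fn main() {
        let cache: Rc<RefCell<HashMap<u32, u64>>> = Rc::new(RefCell::new(HashMap::new()));

        // The closure owns one handle to the shared map...
        let cache_cloned = cache.clone();
        let mut lookup = move |key: u32| -> u64 {
            let mut cache = cache_cloned.borrow_mut();
            *cache.entry(key).or_insert_with(|| u64::from(key) * 10)
        };

        // ...while the surrounding code keeps another handle to the very same map.
        assert_eq!(lookup(3), 30);
        assert_eq!(cache.borrow().len(), 1);
    }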