mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-22 21:04:27 +01:00
Introduce the AstarBagIter that iterates through best paths
This commit is contained in:
parent
7dc594ba4d
commit
1e358e3ae8
24
Cargo.lock
generated
24
Cargo.lock
generated
@ -85,6 +85,15 @@ dependencies = [
|
|||||||
"warp",
|
"warp",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "astar-iter"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "git+https://github.com/Kerollmops/astar-iter#87cb97a11c701f1a6025b72b673a8bfd0ca249a5"
|
||||||
|
dependencies = [
|
||||||
|
"indexmap",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "atty"
|
name = "atty"
|
||||||
version = "0.2.11"
|
version = "0.2.11"
|
||||||
@ -636,6 +645,15 @@ dependencies = [
|
|||||||
"tokio-util",
|
"tokio-util",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hashbrown"
|
||||||
|
version = "0.8.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e91b62f79061a0bc2e046024cb7ba44b08419ed238ecbd9adbd787434b9e8c25"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg 1.0.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "headers"
|
name = "headers"
|
||||||
version = "0.3.2"
|
version = "0.3.2"
|
||||||
@ -785,11 +803,12 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indexmap"
|
name = "indexmap"
|
||||||
version = "1.4.0"
|
version = "1.5.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c398b2b113b55809ceb9ee3e753fcbac793f1956663f3c36549c1346015c2afe"
|
checksum = "86b45e59b16c76b11bf9738fd5d38879d3bd28ad292d7b313608becb17ae2df9"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"autocfg 1.0.0",
|
"autocfg 1.0.0",
|
||||||
|
"hashbrown",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -981,6 +1000,7 @@ dependencies = [
|
|||||||
"arc-cache",
|
"arc-cache",
|
||||||
"askama",
|
"askama",
|
||||||
"askama_warp",
|
"askama_warp",
|
||||||
|
"astar-iter",
|
||||||
"bitpacking",
|
"bitpacking",
|
||||||
"byteorder",
|
"byteorder",
|
||||||
"cow-utils",
|
"cow-utils",
|
||||||
|
@ -8,6 +8,7 @@ default-run = "indexer"
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.28"
|
anyhow = "1.0.28"
|
||||||
arc-cache = { git = "https://github.com/Kerollmops/rust-arc-cache.git", rev = "56530f2" }
|
arc-cache = { git = "https://github.com/Kerollmops/rust-arc-cache.git", rev = "56530f2" }
|
||||||
|
astar-iter = { git = "https://github.com/Kerollmops/astar-iter" }
|
||||||
bitpacking = "0.8.2"
|
bitpacking = "0.8.2"
|
||||||
byteorder = "1.3.4"
|
byteorder = "1.3.4"
|
||||||
cow-utils = "0.1.2"
|
cow-utils = "0.1.2"
|
||||||
@ -34,7 +35,7 @@ log = "0.4.8"
|
|||||||
stderrlog = "0.4.3"
|
stderrlog = "0.4.3"
|
||||||
|
|
||||||
# best proximity
|
# best proximity
|
||||||
indexmap = "1.4.0"
|
indexmap = "1.5.1"
|
||||||
|
|
||||||
# to implement internally
|
# to implement internally
|
||||||
itertools = "0.9.0"
|
itertools = "0.9.0"
|
||||||
|
@ -1,232 +0,0 @@
|
|||||||
use std::cmp;
|
|
||||||
use std::time::Instant;
|
|
||||||
|
|
||||||
use crate::iter_shortest_paths::astar_bag;
|
|
||||||
use log::debug;
|
|
||||||
use roaring::RoaringBitmap;
|
|
||||||
|
|
||||||
const ONE_ATTRIBUTE: u32 = 1000;
|
|
||||||
const MAX_DISTANCE: u32 = 8;
|
|
||||||
|
|
||||||
fn index_proximity(lhs: u32, rhs: u32) -> u32 {
|
|
||||||
if lhs <= rhs {
|
|
||||||
cmp::min(rhs - lhs, MAX_DISTANCE)
|
|
||||||
} else {
|
|
||||||
cmp::min((lhs - rhs) + 1, MAX_DISTANCE)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn positions_proximity(lhs: u32, rhs: u32) -> u32 {
|
|
||||||
let (lhs_attr, lhs_index) = extract_position(lhs);
|
|
||||||
let (rhs_attr, rhs_index) = extract_position(rhs);
|
|
||||||
if lhs_attr != rhs_attr { MAX_DISTANCE }
|
|
||||||
else { index_proximity(lhs_index, rhs_index) }
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns the attribute and index parts.
|
|
||||||
pub fn extract_position(position: u32) -> (u32, u32) {
|
|
||||||
(position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
|
||||||
enum Node {
|
|
||||||
// Is this node is the first node.
|
|
||||||
Uninit,
|
|
||||||
Init {
|
|
||||||
// The layer where this node located.
|
|
||||||
layer: usize,
|
|
||||||
// The position where this node is located.
|
|
||||||
position: u32,
|
|
||||||
// The total accumulated proximity until this node, used for skipping nodes.
|
|
||||||
acc_proximity: u32,
|
|
||||||
// The parent position from the above layer.
|
|
||||||
parent_position: u32,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Node {
|
|
||||||
// TODO we must skip the successors that have already been seen
|
|
||||||
// TODO we must skip the successors that doesn't return any documents
|
|
||||||
// this way we are able to skip entire paths
|
|
||||||
fn successors(&self, positions: &[RoaringBitmap], best_proximity: u32) -> Vec<(Node, u32)> {
|
|
||||||
match self {
|
|
||||||
Node::Uninit => {
|
|
||||||
positions[0].iter().map(|position| {
|
|
||||||
(Node::Init { layer: 0, position, acc_proximity: 0, parent_position: 0 }, 0)
|
|
||||||
}).collect()
|
|
||||||
},
|
|
||||||
// We reached the highest layer
|
|
||||||
n @ Node::Init { .. } if n.is_complete(positions) => vec![],
|
|
||||||
Node::Init { layer, position, acc_proximity, .. } => {
|
|
||||||
positions[layer + 1].iter().filter_map(|p| {
|
|
||||||
let proximity = positions_proximity(*position, p);
|
|
||||||
let node = Node::Init {
|
|
||||||
layer: layer + 1,
|
|
||||||
position: p,
|
|
||||||
acc_proximity: acc_proximity + proximity,
|
|
||||||
parent_position: *position,
|
|
||||||
};
|
|
||||||
// We do not produce the nodes we have already seen in previous iterations loops.
|
|
||||||
if node.is_complete(positions) && acc_proximity + proximity < best_proximity {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
Some((node, proximity))
|
|
||||||
}
|
|
||||||
}).collect()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_complete(&self, positions: &[RoaringBitmap]) -> bool {
|
|
||||||
match self {
|
|
||||||
Node::Uninit => false,
|
|
||||||
Node::Init { layer, .. } => *layer == positions.len() - 1,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn position(&self) -> Option<u32> {
|
|
||||||
match self {
|
|
||||||
Node::Uninit => None,
|
|
||||||
Node::Init { position, .. } => Some(*position),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn proximity(&self) -> u32 {
|
|
||||||
match self {
|
|
||||||
Node::Uninit => 0,
|
|
||||||
Node::Init { layer, position, acc_proximity, parent_position } => {
|
|
||||||
if layer.checked_sub(1).is_some() {
|
|
||||||
acc_proximity + positions_proximity(*position, *parent_position)
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_reachable<F>(&self, contains_documents: &mut F) -> bool
|
|
||||||
where F: FnMut((usize, u32), (usize, u32)) -> bool,
|
|
||||||
{
|
|
||||||
match self {
|
|
||||||
Node::Uninit => true,
|
|
||||||
Node::Init { layer, position, parent_position, .. } => {
|
|
||||||
match layer.checked_sub(1) {
|
|
||||||
Some(parent_layer) => {
|
|
||||||
(contains_documents)((parent_layer, *parent_position), (*layer, *position))
|
|
||||||
},
|
|
||||||
None => true,
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub struct BestProximity {
|
|
||||||
positions: Vec<RoaringBitmap>,
|
|
||||||
best_proximity: u32,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BestProximity {
|
|
||||||
pub fn new(positions: Vec<RoaringBitmap>) -> BestProximity {
|
|
||||||
let best_proximity = (positions.len() as u32).saturating_sub(1);
|
|
||||||
BestProximity { positions, best_proximity }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl BestProximity {
|
|
||||||
pub fn next<F>(&mut self, mut contains_documents: F) -> Option<(u32, Vec<RoaringBitmap>)>
|
|
||||||
where F: FnMut((usize, u32), (usize, u32)) -> bool,
|
|
||||||
{
|
|
||||||
let before = Instant::now();
|
|
||||||
|
|
||||||
if self.best_proximity == self.positions.len() as u32 * MAX_DISTANCE {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
|
|
||||||
let BestProximity { positions, best_proximity } = self;
|
|
||||||
|
|
||||||
let result = astar_bag(
|
|
||||||
&Node::Uninit, // start
|
|
||||||
|n| n.successors(&positions, *best_proximity),
|
|
||||||
|_| 0, // heuristic
|
|
||||||
|n| { // success
|
|
||||||
let c = n.is_complete(&positions) && n.proximity() >= *best_proximity;
|
|
||||||
if n.is_reachable(&mut contains_documents) { Some(c) } else { None }
|
|
||||||
},
|
|
||||||
);
|
|
||||||
|
|
||||||
debug!("BestProximity::next() took {:.02?}", before.elapsed());
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Some((paths, proximity)) => {
|
|
||||||
self.best_proximity = proximity + 1;
|
|
||||||
// We retrieve the last path that we convert into a Vec
|
|
||||||
let paths: Vec<_> = paths.map(|p| p.iter().filter_map(Node::position).collect()).collect();
|
|
||||||
debug!("result: {} {:?}", proximity, paths);
|
|
||||||
Some((proximity, paths))
|
|
||||||
},
|
|
||||||
None => {
|
|
||||||
debug!("result: {:?}", None as Option<()>);
|
|
||||||
self.best_proximity += 1;
|
|
||||||
None
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
use std::iter::FromIterator;
|
|
||||||
|
|
||||||
fn sort<T: Ord>(mut val: (u32, Vec<T>)) -> (u32, Vec<T>) {
|
|
||||||
val.1.sort_unstable();
|
|
||||||
val
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn same_attribute() {
|
|
||||||
let positions = vec![
|
|
||||||
RoaringBitmap::from_iter(vec![0, 2, 3, 4 ]),
|
|
||||||
RoaringBitmap::from_iter(vec![ 1, ]),
|
|
||||||
RoaringBitmap::from_iter(vec![ 3, 6]),
|
|
||||||
];
|
|
||||||
let mut iter = BestProximity::new(positions);
|
|
||||||
let f = |_, _| true;
|
|
||||||
|
|
||||||
assert_eq!(iter.next(f), Some((1+2, vec![RoaringBitmap::from_iter(vec![0, 1, 3])]))); // 3
|
|
||||||
assert_eq!(iter.next(f), Some((2+2, vec![RoaringBitmap::from_iter(vec![2, 1, 3])]))); // 4
|
|
||||||
assert_eq!(iter.next(f), Some((3+2, vec![RoaringBitmap::from_iter(vec![3, 1, 3])]))); // 5
|
|
||||||
assert_eq!(iter.next(f), Some((1+5, vec![RoaringBitmap::from_iter(vec![0, 1, 6]), RoaringBitmap::from_iter(vec![4, 1, 3])]))); // 6
|
|
||||||
assert_eq!(iter.next(f), Some((2+5, vec![RoaringBitmap::from_iter(vec![2, 1, 6])]))); // 7
|
|
||||||
assert_eq!(iter.next(f), Some((3+5, vec![RoaringBitmap::from_iter(vec![3, 1, 6])]))); // 8
|
|
||||||
assert_eq!(iter.next(f), Some((4+5, vec![RoaringBitmap::from_iter(vec![4, 1, 6])]))); // 9
|
|
||||||
assert_eq!(iter.next(f), None);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn different_attributes() {
|
|
||||||
let positions = vec![
|
|
||||||
RoaringBitmap::from_iter(vec![0, 2, 1000, 1001, 2000 ]),
|
|
||||||
RoaringBitmap::from_iter(vec![ 1, 1000, 2001 ]),
|
|
||||||
RoaringBitmap::from_iter(vec![ 3, 6, 2002, 3000]),
|
|
||||||
];
|
|
||||||
let mut iter = BestProximity::new(positions);
|
|
||||||
let f = |_, _| true;
|
|
||||||
|
|
||||||
assert_eq!(iter.next(f), Some((1+1, vec![RoaringBitmap::from_iter(vec![2000, 2001, 2002])]))); // 2
|
|
||||||
assert_eq!(iter.next(f), Some((1+2, vec![RoaringBitmap::from_iter(vec![0, 1, 3])]))); // 3
|
|
||||||
assert_eq!(iter.next(f), Some((2+2, vec![RoaringBitmap::from_iter(vec![2, 1, 3])]))); // 4
|
|
||||||
assert_eq!(iter.next(f), Some((1+5, vec![RoaringBitmap::from_iter(vec![0, 1, 6])]))); // 6
|
|
||||||
// We ignore others here...
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn easy_proximities() {
|
|
||||||
fn slice_proximity(positions: &[u32]) -> u32 {
|
|
||||||
positions.windows(2).map(|ps| positions_proximity(ps[0], ps[1])).sum::<u32>()
|
|
||||||
}
|
|
||||||
|
|
||||||
assert_eq!(slice_proximity(&[1000, 1000, 2002]), 8);
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,204 +0,0 @@
|
|||||||
use std::cmp::Ordering;
|
|
||||||
use std::collections::{BinaryHeap, HashSet};
|
|
||||||
use std::hash::Hash;
|
|
||||||
use std::usize;
|
|
||||||
|
|
||||||
use indexmap::map::Entry::{Occupied, Vacant};
|
|
||||||
use indexmap::IndexMap;
|
|
||||||
|
|
||||||
pub fn astar_bag<N, FN, IN, FH, FS>(
|
|
||||||
start: &N,
|
|
||||||
mut successors: FN,
|
|
||||||
mut heuristic: FH,
|
|
||||||
mut success: FS,
|
|
||||||
) -> Option<(AstarSolution<N>, u32)>
|
|
||||||
where
|
|
||||||
N: Eq + Hash + Clone,
|
|
||||||
FN: FnMut(&N) -> IN,
|
|
||||||
IN: IntoIterator<Item = (N, u32)>,
|
|
||||||
FH: FnMut(&N) -> u32,
|
|
||||||
FS: FnMut(&N) -> Option<bool>,
|
|
||||||
{
|
|
||||||
let mut to_see = BinaryHeap::new();
|
|
||||||
let mut min_cost = None;
|
|
||||||
let mut sinks = HashSet::new();
|
|
||||||
to_see.push(SmallestCostHolder {
|
|
||||||
estimated_cost: heuristic(start),
|
|
||||||
cost: 0,
|
|
||||||
index: 0,
|
|
||||||
});
|
|
||||||
let mut parents: IndexMap<N, (HashSet<usize>, u32)> = IndexMap::new();
|
|
||||||
parents.insert(start.clone(), (HashSet::new(), 0));
|
|
||||||
while let Some(SmallestCostHolder { cost, index, estimated_cost, .. }) = to_see.pop() {
|
|
||||||
if let Some(min_cost) = min_cost {
|
|
||||||
if estimated_cost > min_cost {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let successors = {
|
|
||||||
let (node, &(_, c)) = parents.get_index(index).unwrap();
|
|
||||||
// We check that the node is even reachable and if so if it is an answer.
|
|
||||||
// If this node is unreachable we skip it.
|
|
||||||
match success(node) {
|
|
||||||
Some(success) => if success {
|
|
||||||
min_cost = Some(cost);
|
|
||||||
sinks.insert(index);
|
|
||||||
},
|
|
||||||
None => continue,
|
|
||||||
}
|
|
||||||
|
|
||||||
// We may have inserted a node several time into the binary heap if we found
|
|
||||||
// a better way to access it. Ensure that we are currently dealing with the
|
|
||||||
// best path and discard the others.
|
|
||||||
if cost > c {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
successors(node)
|
|
||||||
};
|
|
||||||
for (successor, move_cost) in successors {
|
|
||||||
let new_cost = cost + move_cost;
|
|
||||||
let h; // heuristic(&successor)
|
|
||||||
let n; // index for successor
|
|
||||||
match parents.entry(successor) {
|
|
||||||
Vacant(e) => {
|
|
||||||
h = heuristic(e.key());
|
|
||||||
n = e.index();
|
|
||||||
let mut p = HashSet::new();
|
|
||||||
p.insert(index);
|
|
||||||
e.insert((p, new_cost));
|
|
||||||
}
|
|
||||||
Occupied(mut e) => {
|
|
||||||
if e.get().1 > new_cost {
|
|
||||||
h = heuristic(e.key());
|
|
||||||
n = e.index();
|
|
||||||
let s = e.get_mut();
|
|
||||||
s.0.clear();
|
|
||||||
s.0.insert(index);
|
|
||||||
s.1 = new_cost;
|
|
||||||
} else {
|
|
||||||
if e.get().1 == new_cost {
|
|
||||||
// New parent with an identical cost, this is not
|
|
||||||
// considered as an insertion.
|
|
||||||
e.get_mut().0.insert(index);
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
to_see.push(SmallestCostHolder {
|
|
||||||
estimated_cost: new_cost + h,
|
|
||||||
cost: new_cost,
|
|
||||||
index: n,
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
min_cost.map(|cost| {
|
|
||||||
let parents = parents
|
|
||||||
.into_iter()
|
|
||||||
.map(|(k, (ps, _))| (k, ps.into_iter().collect()))
|
|
||||||
.collect();
|
|
||||||
(
|
|
||||||
AstarSolution {
|
|
||||||
sinks: sinks.into_iter().collect(),
|
|
||||||
parents,
|
|
||||||
current: vec![],
|
|
||||||
terminated: false,
|
|
||||||
},
|
|
||||||
cost,
|
|
||||||
)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
struct SmallestCostHolder<K> {
|
|
||||||
estimated_cost: K,
|
|
||||||
cost: K,
|
|
||||||
index: usize,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<K: PartialEq> PartialEq for SmallestCostHolder<K> {
|
|
||||||
fn eq(&self, other: &Self) -> bool {
|
|
||||||
self.estimated_cost.eq(&other.estimated_cost) && self.cost.eq(&other.cost)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<K: PartialEq> Eq for SmallestCostHolder<K> {}
|
|
||||||
|
|
||||||
impl<K: Ord> PartialOrd for SmallestCostHolder<K> {
|
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
|
||||||
Some(self.cmp(other))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<K: Ord> Ord for SmallestCostHolder<K> {
|
|
||||||
fn cmp(&self, other: &Self) -> Ordering {
|
|
||||||
match other.estimated_cost.cmp(&self.estimated_cost) {
|
|
||||||
Ordering::Equal => self.cost.cmp(&other.cost),
|
|
||||||
s => s,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Iterator structure created by the `astar_bag` function.
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct AstarSolution<N> {
|
|
||||||
sinks: Vec<usize>,
|
|
||||||
parents: Vec<(N, Vec<usize>)>,
|
|
||||||
current: Vec<Vec<usize>>,
|
|
||||||
terminated: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<N: Clone + Eq + Hash> AstarSolution<N> {
|
|
||||||
fn complete(&mut self) {
|
|
||||||
loop {
|
|
||||||
let ps = match self.current.last() {
|
|
||||||
None => self.sinks.clone(),
|
|
||||||
Some(last) => {
|
|
||||||
let &top = last.last().unwrap();
|
|
||||||
self.parents(top).clone()
|
|
||||||
}
|
|
||||||
};
|
|
||||||
if ps.is_empty() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
self.current.push(ps);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn next_vec(&mut self) {
|
|
||||||
while self.current.last().map(Vec::len) == Some(1) {
|
|
||||||
self.current.pop();
|
|
||||||
}
|
|
||||||
self.current.last_mut().map(Vec::pop);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn node(&self, i: usize) -> &N {
|
|
||||||
&self.parents[i].0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn parents(&self, i: usize) -> &Vec<usize> {
|
|
||||||
&self.parents[i].1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<N: Clone + Eq + Hash> Iterator for AstarSolution<N> {
|
|
||||||
type Item = Vec<N>;
|
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
|
||||||
if self.terminated {
|
|
||||||
return None;
|
|
||||||
}
|
|
||||||
self.complete();
|
|
||||||
let path = self
|
|
||||||
.current
|
|
||||||
.iter()
|
|
||||||
.rev()
|
|
||||||
.map(|v| v.last().cloned().unwrap())
|
|
||||||
.map(|i| self.node(i).clone())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
self.next_vec();
|
|
||||||
self.terminated = self.current.is_empty();
|
|
||||||
Some(path)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,7 +1,6 @@
|
|||||||
mod best_proximity;
|
|
||||||
mod criterion;
|
mod criterion;
|
||||||
mod heed_codec;
|
mod heed_codec;
|
||||||
mod iter_shortest_paths;
|
mod node;
|
||||||
mod query_tokens;
|
mod query_tokens;
|
||||||
mod search;
|
mod search;
|
||||||
mod transitive_arc;
|
mod transitive_arc;
|
||||||
|
104
src/node.rs
Normal file
104
src/node.rs
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
use std::cmp;
|
||||||
|
use roaring::RoaringBitmap;
|
||||||
|
|
||||||
|
const ONE_ATTRIBUTE: u32 = 1000;
|
||||||
|
const MAX_DISTANCE: u32 = 8;
|
||||||
|
|
||||||
|
fn index_proximity(lhs: u32, rhs: u32) -> u32 {
|
||||||
|
if lhs <= rhs {
|
||||||
|
cmp::min(rhs - lhs, MAX_DISTANCE)
|
||||||
|
} else {
|
||||||
|
cmp::min((lhs - rhs) + 1, MAX_DISTANCE)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn positions_proximity(lhs: u32, rhs: u32) -> u32 {
|
||||||
|
let (lhs_attr, lhs_index) = extract_position(lhs);
|
||||||
|
let (rhs_attr, rhs_index) = extract_position(rhs);
|
||||||
|
if lhs_attr != rhs_attr { MAX_DISTANCE }
|
||||||
|
else { index_proximity(lhs_index, rhs_index) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns the attribute and index parts.
|
||||||
|
pub fn extract_position(position: u32) -> (u32, u32) {
|
||||||
|
(position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||||
|
pub enum Node {
|
||||||
|
// Is this node is the first node.
|
||||||
|
Uninit,
|
||||||
|
Init {
|
||||||
|
// The layer where this node located.
|
||||||
|
layer: usize,
|
||||||
|
// The position where this node is located.
|
||||||
|
position: u32,
|
||||||
|
// The parent position from the above layer.
|
||||||
|
parent_position: u32,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Node {
|
||||||
|
// TODO we must skip the successors that have already been seen
|
||||||
|
// TODO we must skip the successors that doesn't return any documents
|
||||||
|
// this way we are able to skip entire paths
|
||||||
|
pub fn successors<F>(&self, positions: &[RoaringBitmap], contains_documents: &mut F) -> Vec<(Node, u32)>
|
||||||
|
where F: FnMut((usize, u32), (usize, u32)) -> bool,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Node::Uninit => {
|
||||||
|
positions[0].iter().map(|position| {
|
||||||
|
(Node::Init { layer: 0, position, parent_position: 0 }, 0)
|
||||||
|
}).collect()
|
||||||
|
},
|
||||||
|
// We reached the highest layer
|
||||||
|
n @ Node::Init { .. } if n.is_complete(positions) => vec![],
|
||||||
|
Node::Init { layer, position, .. } => {
|
||||||
|
positions[layer + 1].iter().filter_map(|p| {
|
||||||
|
let proximity = positions_proximity(*position, p);
|
||||||
|
let node = Node::Init {
|
||||||
|
layer: layer + 1,
|
||||||
|
position: p,
|
||||||
|
parent_position: *position,
|
||||||
|
};
|
||||||
|
// We do not produce the nodes we have already seen in previous iterations loops.
|
||||||
|
if node.is_reachable(contains_documents) {
|
||||||
|
Some((node, proximity))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}).collect()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_complete(&self, positions: &[RoaringBitmap]) -> bool {
|
||||||
|
match self {
|
||||||
|
Node::Uninit => false,
|
||||||
|
Node::Init { layer, .. } => *layer == positions.len() - 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn position(&self) -> Option<u32> {
|
||||||
|
match self {
|
||||||
|
Node::Uninit => None,
|
||||||
|
Node::Init { position, .. } => Some(*position),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_reachable<F>(&self, contains_documents: &mut F) -> bool
|
||||||
|
where F: FnMut((usize, u32), (usize, u32)) -> bool,
|
||||||
|
{
|
||||||
|
match self {
|
||||||
|
Node::Uninit => true,
|
||||||
|
Node::Init { layer, position, parent_position, .. } => {
|
||||||
|
match layer.checked_sub(1) {
|
||||||
|
Some(parent_layer) => {
|
||||||
|
(contains_documents)((parent_layer, *parent_position), (*layer, *position))
|
||||||
|
},
|
||||||
|
None => true,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,14 +1,18 @@
|
|||||||
|
use std::cell::RefCell;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::rc::Rc;
|
||||||
|
|
||||||
|
use astar_iter::AstarBagIter;
|
||||||
use fst::{IntoStreamer, Streamer};
|
use fst::{IntoStreamer, Streamer};
|
||||||
use levenshtein_automata::DFA;
|
use levenshtein_automata::DFA;
|
||||||
use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder;
|
use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder;
|
||||||
|
use log::debug;
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use roaring::RoaringBitmap;
|
use roaring::RoaringBitmap;
|
||||||
|
|
||||||
|
use crate::node::{self, Node};
|
||||||
use crate::query_tokens::{QueryTokens, QueryToken};
|
use crate::query_tokens::{QueryTokens, QueryToken};
|
||||||
use crate::{Index, DocumentId, Position, Attribute};
|
use crate::{Index, DocumentId, Position, Attribute};
|
||||||
use crate::best_proximity::{self, BestProximity};
|
|
||||||
|
|
||||||
// Building these factories is not free.
|
// Building these factories is not free.
|
||||||
static LEVDIST0: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(0, true));
|
static LEVDIST0: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(0, true));
|
||||||
@ -214,20 +218,30 @@ impl<'a> Search<'a> {
|
|||||||
};
|
};
|
||||||
|
|
||||||
let (derived_words, union_positions) = Self::fetch_words_positions(rtxn, index, &fst, dfas)?;
|
let (derived_words, union_positions) = Self::fetch_words_positions(rtxn, index, &fst, dfas)?;
|
||||||
let mut candidates = Self::compute_candidates(rtxn, index, &derived_words)?;
|
let candidates = Self::compute_candidates(rtxn, index, &derived_words)?;
|
||||||
|
|
||||||
let mut union_cache = HashMap::new();
|
let union_cache = HashMap::new();
|
||||||
let mut intersect_cache = HashMap::new();
|
let mut intersect_cache = HashMap::new();
|
||||||
|
|
||||||
let mut attribute_union_cache = HashMap::new();
|
let mut attribute_union_cache = HashMap::new();
|
||||||
let mut attribute_intersect_cache = HashMap::new();
|
let mut attribute_intersect_cache = HashMap::new();
|
||||||
|
|
||||||
|
let candidates = Rc::new(RefCell::new(candidates));
|
||||||
|
let union_cache = Rc::new(RefCell::new(union_cache));
|
||||||
|
|
||||||
// Returns `true` if there is documents in common between the two words and positions given.
|
// Returns `true` if there is documents in common between the two words and positions given.
|
||||||
let mut contains_documents = |(lword, lpos), (rword, rpos), union_cache: &mut HashMap<_, _>, candidates: &RoaringBitmap| {
|
// TODO move this closure to a better place.
|
||||||
|
let candidates_cloned = candidates.clone();
|
||||||
|
let union_cache_cloned = union_cache.clone();
|
||||||
|
let mut contains_documents = |(lword, lpos), (rword, rpos)| {
|
||||||
if lpos == rpos { return false }
|
if lpos == rpos { return false }
|
||||||
|
|
||||||
let (lattr, _) = best_proximity::extract_position(lpos);
|
// TODO move this function to a better place.
|
||||||
let (rattr, _) = best_proximity::extract_position(rpos);
|
let (lattr, _) = node::extract_position(lpos);
|
||||||
|
let (rattr, _) = node::extract_position(rpos);
|
||||||
|
|
||||||
|
let candidates = &candidates_cloned.borrow();
|
||||||
|
let mut union_cache = union_cache_cloned.borrow_mut();
|
||||||
|
|
||||||
if lattr == rattr {
|
if lattr == rattr {
|
||||||
// We retrieve or compute the intersection between the two given words and positions.
|
// We retrieve or compute the intersection between the two given words and positions.
|
||||||
@ -277,19 +291,33 @@ impl<'a> Search<'a> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// We instantiate an astar bag Iterator that returns the best paths incrementally,
|
||||||
|
// it means that it will first return the best paths then the next best paths...
|
||||||
|
let astar_iter = AstarBagIter::new(
|
||||||
|
Node::Uninit, // start
|
||||||
|
|n| n.successors(&union_positions, &mut contains_documents), // successors
|
||||||
|
|_| 0, // heuristic
|
||||||
|
|n| n.is_complete(&union_positions), // success
|
||||||
|
);
|
||||||
|
|
||||||
let mut documents = Vec::new();
|
let mut documents = Vec::new();
|
||||||
let mut iter = BestProximity::new(union_positions);
|
for (paths, proximity) in astar_iter {
|
||||||
while let Some((_proximity, mut positions)) = iter.next(|l, r| contains_documents(l, r, &mut union_cache, &candidates)) {
|
|
||||||
positions.sort_unstable_by(|a, b| a.iter().cmp(b.iter()));
|
let mut union_cache = union_cache.borrow_mut();
|
||||||
|
let mut candidates = candidates.borrow_mut();
|
||||||
|
|
||||||
|
let mut positions: Vec<Vec<_>> = paths.map(|p| p.iter().filter_map(Node::position).collect()).collect();
|
||||||
|
positions.sort_unstable();
|
||||||
|
|
||||||
|
debug!("Found {} positions with a proximity of {}", positions.len(), proximity);
|
||||||
|
|
||||||
let mut same_proximity_union = RoaringBitmap::default();
|
let mut same_proximity_union = RoaringBitmap::default();
|
||||||
for positions in positions {
|
for positions in positions {
|
||||||
|
|
||||||
// Precompute the potentially missing unions
|
// Precompute the potentially missing unions
|
||||||
positions.iter().enumerate().for_each(|(word, pos)| {
|
positions.iter().enumerate().for_each(|(word, pos)| {
|
||||||
union_cache.entry((word, pos)).or_insert_with(|| {
|
union_cache.entry((word, *pos)).or_insert_with(|| {
|
||||||
let words = &derived_words[word];
|
let words = &derived_words[word];
|
||||||
Self::union_word_position(rtxn, index, words, pos).unwrap()
|
Self::union_word_position(rtxn, index, words, *pos).unwrap()
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -297,7 +325,7 @@ impl<'a> Search<'a> {
|
|||||||
let mut to_intersect: Vec<_> = positions.iter()
|
let mut to_intersect: Vec<_> = positions.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.map(|(word, pos)| {
|
.map(|(word, pos)| {
|
||||||
let docids = union_cache.get(&(word, pos)).unwrap();
|
let docids = union_cache.get(&(word, *pos)).unwrap();
|
||||||
(docids.len(), docids)
|
(docids.len(), docids)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
Loading…
Reference in New Issue
Block a user