diff --git a/Cargo.lock b/Cargo.lock index 465a55817..c3fb8f29f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -85,6 +85,15 @@ dependencies = [ "warp", ] +[[package]] +name = "astar-iter" +version = "0.1.0" +source = "git+https://github.com/Kerollmops/astar-iter#87cb97a11c701f1a6025b72b673a8bfd0ca249a5" +dependencies = [ + "indexmap", + "num-traits", +] + [[package]] name = "atty" version = "0.2.11" @@ -636,6 +645,15 @@ dependencies = [ "tokio-util", ] +[[package]] +name = "hashbrown" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b62f79061a0bc2e046024cb7ba44b08419ed238ecbd9adbd787434b9e8c25" +dependencies = [ + "autocfg 1.0.0", +] + [[package]] name = "headers" version = "0.3.2" @@ -785,11 +803,12 @@ dependencies = [ [[package]] name = "indexmap" -version = "1.4.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c398b2b113b55809ceb9ee3e753fcbac793f1956663f3c36549c1346015c2afe" +checksum = "86b45e59b16c76b11bf9738fd5d38879d3bd28ad292d7b313608becb17ae2df9" dependencies = [ "autocfg 1.0.0", + "hashbrown", ] [[package]] @@ -981,6 +1000,7 @@ dependencies = [ "arc-cache", "askama", "askama_warp", + "astar-iter", "bitpacking", "byteorder", "cow-utils", diff --git a/Cargo.toml b/Cargo.toml index 87c195b27..c34dd6511 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ default-run = "indexer" [dependencies] anyhow = "1.0.28" arc-cache = { git = "https://github.com/Kerollmops/rust-arc-cache.git", rev = "56530f2" } +astar-iter = { git = "https://github.com/Kerollmops/astar-iter" } bitpacking = "0.8.2" byteorder = "1.3.4" cow-utils = "0.1.2" @@ -34,7 +35,7 @@ log = "0.4.8" stderrlog = "0.4.3" # best proximity -indexmap = "1.4.0" +indexmap = "1.5.1" # to implement internally itertools = "0.9.0" diff --git a/src/best_proximity.rs b/src/best_proximity.rs deleted file mode 100644 index 1eb60f167..000000000 --- a/src/best_proximity.rs +++ /dev/null @@ -1,232 +0,0 @@ -use 
std::cmp; -use std::time::Instant; - -use crate::iter_shortest_paths::astar_bag; -use log::debug; -use roaring::RoaringBitmap; - -const ONE_ATTRIBUTE: u32 = 1000; -const MAX_DISTANCE: u32 = 8; - -fn index_proximity(lhs: u32, rhs: u32) -> u32 { - if lhs <= rhs { - cmp::min(rhs - lhs, MAX_DISTANCE) - } else { - cmp::min((lhs - rhs) + 1, MAX_DISTANCE) - } -} - -pub fn positions_proximity(lhs: u32, rhs: u32) -> u32 { - let (lhs_attr, lhs_index) = extract_position(lhs); - let (rhs_attr, rhs_index) = extract_position(rhs); - if lhs_attr != rhs_attr { MAX_DISTANCE } - else { index_proximity(lhs_index, rhs_index) } -} - -// Returns the attribute and index parts. -pub fn extract_position(position: u32) -> (u32, u32) { - (position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE) -} - -#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -enum Node { - // Is this node is the first node. - Uninit, - Init { - // The layer where this node located. - layer: usize, - // The position where this node is located. - position: u32, - // The total accumulated proximity until this node, used for skipping nodes. - acc_proximity: u32, - // The parent position from the above layer. - parent_position: u32, - }, -} - -impl Node { - // TODO we must skip the successors that have already been seen - // TODO we must skip the successors that doesn't return any documents - // this way we are able to skip entire paths - fn successors(&self, positions: &[RoaringBitmap], best_proximity: u32) -> Vec<(Node, u32)> { - match self { - Node::Uninit => { - positions[0].iter().map(|position| { - (Node::Init { layer: 0, position, acc_proximity: 0, parent_position: 0 }, 0) - }).collect() - }, - // We reached the highest layer - n @ Node::Init { .. } if n.is_complete(positions) => vec![], - Node::Init { layer, position, acc_proximity, .. 
} => { - positions[layer + 1].iter().filter_map(|p| { - let proximity = positions_proximity(*position, p); - let node = Node::Init { - layer: layer + 1, - position: p, - acc_proximity: acc_proximity + proximity, - parent_position: *position, - }; - // We do not produce the nodes we have already seen in previous iterations loops. - if node.is_complete(positions) && acc_proximity + proximity < best_proximity { - None - } else { - Some((node, proximity)) - } - }).collect() - } - } - } - - fn is_complete(&self, positions: &[RoaringBitmap]) -> bool { - match self { - Node::Uninit => false, - Node::Init { layer, .. } => *layer == positions.len() - 1, - } - } - - fn position(&self) -> Option { - match self { - Node::Uninit => None, - Node::Init { position, .. } => Some(*position), - } - } - - fn proximity(&self) -> u32 { - match self { - Node::Uninit => 0, - Node::Init { layer, position, acc_proximity, parent_position } => { - if layer.checked_sub(1).is_some() { - acc_proximity + positions_proximity(*position, *parent_position) - } else { - 0 - } - }, - } - } - - fn is_reachable(&self, contains_documents: &mut F) -> bool - where F: FnMut((usize, u32), (usize, u32)) -> bool, - { - match self { - Node::Uninit => true, - Node::Init { layer, position, parent_position, .. 
} => { - match layer.checked_sub(1) { - Some(parent_layer) => { - (contains_documents)((parent_layer, *parent_position), (*layer, *position)) - }, - None => true, - } - }, - } - } -} - -pub struct BestProximity { - positions: Vec, - best_proximity: u32, -} - -impl BestProximity { - pub fn new(positions: Vec) -> BestProximity { - let best_proximity = (positions.len() as u32).saturating_sub(1); - BestProximity { positions, best_proximity } - } -} - -impl BestProximity { - pub fn next(&mut self, mut contains_documents: F) -> Option<(u32, Vec)> - where F: FnMut((usize, u32), (usize, u32)) -> bool, - { - let before = Instant::now(); - - if self.best_proximity == self.positions.len() as u32 * MAX_DISTANCE { - return None; - } - - let BestProximity { positions, best_proximity } = self; - - let result = astar_bag( - &Node::Uninit, // start - |n| n.successors(&positions, *best_proximity), - |_| 0, // heuristic - |n| { // success - let c = n.is_complete(&positions) && n.proximity() >= *best_proximity; - if n.is_reachable(&mut contains_documents) { Some(c) } else { None } - }, - ); - - debug!("BestProximity::next() took {:.02?}", before.elapsed()); - - match result { - Some((paths, proximity)) => { - self.best_proximity = proximity + 1; - // We retrieve the last path that we convert into a Vec - let paths: Vec<_> = paths.map(|p| p.iter().filter_map(Node::position).collect()).collect(); - debug!("result: {} {:?}", proximity, paths); - Some((proximity, paths)) - }, - None => { - debug!("result: {:?}", None as Option<()>); - self.best_proximity += 1; - None - }, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::iter::FromIterator; - - fn sort(mut val: (u32, Vec)) -> (u32, Vec) { - val.1.sort_unstable(); - val - } - - #[test] - fn same_attribute() { - let positions = vec![ - RoaringBitmap::from_iter(vec![0, 2, 3, 4 ]), - RoaringBitmap::from_iter(vec![ 1, ]), - RoaringBitmap::from_iter(vec![ 3, 6]), - ]; - let mut iter = BestProximity::new(positions); - let f = 
|_, _| true; - - assert_eq!(iter.next(f), Some((1+2, vec![RoaringBitmap::from_iter(vec![0, 1, 3])]))); // 3 - assert_eq!(iter.next(f), Some((2+2, vec![RoaringBitmap::from_iter(vec![2, 1, 3])]))); // 4 - assert_eq!(iter.next(f), Some((3+2, vec![RoaringBitmap::from_iter(vec![3, 1, 3])]))); // 5 - assert_eq!(iter.next(f), Some((1+5, vec![RoaringBitmap::from_iter(vec![0, 1, 6]), RoaringBitmap::from_iter(vec![4, 1, 3])]))); // 6 - assert_eq!(iter.next(f), Some((2+5, vec![RoaringBitmap::from_iter(vec![2, 1, 6])]))); // 7 - assert_eq!(iter.next(f), Some((3+5, vec![RoaringBitmap::from_iter(vec![3, 1, 6])]))); // 8 - assert_eq!(iter.next(f), Some((4+5, vec![RoaringBitmap::from_iter(vec![4, 1, 6])]))); // 9 - assert_eq!(iter.next(f), None); - } - - #[test] - fn different_attributes() { - let positions = vec![ - RoaringBitmap::from_iter(vec![0, 2, 1000, 1001, 2000 ]), - RoaringBitmap::from_iter(vec![ 1, 1000, 2001 ]), - RoaringBitmap::from_iter(vec![ 3, 6, 2002, 3000]), - ]; - let mut iter = BestProximity::new(positions); - let f = |_, _| true; - - assert_eq!(iter.next(f), Some((1+1, vec![RoaringBitmap::from_iter(vec![2000, 2001, 2002])]))); // 2 - assert_eq!(iter.next(f), Some((1+2, vec![RoaringBitmap::from_iter(vec![0, 1, 3])]))); // 3 - assert_eq!(iter.next(f), Some((2+2, vec![RoaringBitmap::from_iter(vec![2, 1, 3])]))); // 4 - assert_eq!(iter.next(f), Some((1+5, vec![RoaringBitmap::from_iter(vec![0, 1, 6])]))); // 6 - // We ignore others here... 
- } - - #[test] - fn easy_proximities() { - fn slice_proximity(positions: &[u32]) -> u32 { - positions.windows(2).map(|ps| positions_proximity(ps[0], ps[1])).sum::() - } - - assert_eq!(slice_proximity(&[1000, 1000, 2002]), 8); - } -} diff --git a/src/iter_shortest_paths.rs b/src/iter_shortest_paths.rs deleted file mode 100644 index f993ea674..000000000 --- a/src/iter_shortest_paths.rs +++ /dev/null @@ -1,204 +0,0 @@ -use std::cmp::Ordering; -use std::collections::{BinaryHeap, HashSet}; -use std::hash::Hash; -use std::usize; - -use indexmap::map::Entry::{Occupied, Vacant}; -use indexmap::IndexMap; - -pub fn astar_bag( - start: &N, - mut successors: FN, - mut heuristic: FH, - mut success: FS, -) -> Option<(AstarSolution, u32)> -where - N: Eq + Hash + Clone, - FN: FnMut(&N) -> IN, - IN: IntoIterator, - FH: FnMut(&N) -> u32, - FS: FnMut(&N) -> Option, -{ - let mut to_see = BinaryHeap::new(); - let mut min_cost = None; - let mut sinks = HashSet::new(); - to_see.push(SmallestCostHolder { - estimated_cost: heuristic(start), - cost: 0, - index: 0, - }); - let mut parents: IndexMap, u32)> = IndexMap::new(); - parents.insert(start.clone(), (HashSet::new(), 0)); - while let Some(SmallestCostHolder { cost, index, estimated_cost, .. }) = to_see.pop() { - if let Some(min_cost) = min_cost { - if estimated_cost > min_cost { - break; - } - } - let successors = { - let (node, &(_, c)) = parents.get_index(index).unwrap(); - // We check that the node is even reachable and if so if it is an answer. - // If this node is unreachable we skip it. - match success(node) { - Some(success) => if success { - min_cost = Some(cost); - sinks.insert(index); - }, - None => continue, - } - - // We may have inserted a node several time into the binary heap if we found - // a better way to access it. Ensure that we are currently dealing with the - // best path and discard the others. 
- if cost > c { - continue; - } - successors(node) - }; - for (successor, move_cost) in successors { - let new_cost = cost + move_cost; - let h; // heuristic(&successor) - let n; // index for successor - match parents.entry(successor) { - Vacant(e) => { - h = heuristic(e.key()); - n = e.index(); - let mut p = HashSet::new(); - p.insert(index); - e.insert((p, new_cost)); - } - Occupied(mut e) => { - if e.get().1 > new_cost { - h = heuristic(e.key()); - n = e.index(); - let s = e.get_mut(); - s.0.clear(); - s.0.insert(index); - s.1 = new_cost; - } else { - if e.get().1 == new_cost { - // New parent with an identical cost, this is not - // considered as an insertion. - e.get_mut().0.insert(index); - } - continue; - } - } - } - - to_see.push(SmallestCostHolder { - estimated_cost: new_cost + h, - cost: new_cost, - index: n, - }); - } - } - - min_cost.map(|cost| { - let parents = parents - .into_iter() - .map(|(k, (ps, _))| (k, ps.into_iter().collect())) - .collect(); - ( - AstarSolution { - sinks: sinks.into_iter().collect(), - parents, - current: vec![], - terminated: false, - }, - cost, - ) - }) -} - -struct SmallestCostHolder { - estimated_cost: K, - cost: K, - index: usize, -} - -impl PartialEq for SmallestCostHolder { - fn eq(&self, other: &Self) -> bool { - self.estimated_cost.eq(&other.estimated_cost) && self.cost.eq(&other.cost) - } -} - -impl Eq for SmallestCostHolder {} - -impl PartialOrd for SmallestCostHolder { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for SmallestCostHolder { - fn cmp(&self, other: &Self) -> Ordering { - match other.estimated_cost.cmp(&self.estimated_cost) { - Ordering::Equal => self.cost.cmp(&other.cost), - s => s, - } - } -} - -/// Iterator structure created by the `astar_bag` function. 
-#[derive(Clone)] -pub struct AstarSolution { - sinks: Vec, - parents: Vec<(N, Vec)>, - current: Vec>, - terminated: bool, -} - -impl AstarSolution { - fn complete(&mut self) { - loop { - let ps = match self.current.last() { - None => self.sinks.clone(), - Some(last) => { - let &top = last.last().unwrap(); - self.parents(top).clone() - } - }; - if ps.is_empty() { - break; - } - self.current.push(ps); - } - } - - fn next_vec(&mut self) { - while self.current.last().map(Vec::len) == Some(1) { - self.current.pop(); - } - self.current.last_mut().map(Vec::pop); - } - - fn node(&self, i: usize) -> &N { - &self.parents[i].0 - } - - fn parents(&self, i: usize) -> &Vec { - &self.parents[i].1 - } -} - -impl Iterator for AstarSolution { - type Item = Vec; - - fn next(&mut self) -> Option { - if self.terminated { - return None; - } - self.complete(); - let path = self - .current - .iter() - .rev() - .map(|v| v.last().cloned().unwrap()) - .map(|i| self.node(i).clone()) - .collect::>(); - self.next_vec(); - self.terminated = self.current.is_empty(); - Some(path) - } -} diff --git a/src/lib.rs b/src/lib.rs index d24f5f64e..533d342ec 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,6 @@ -mod best_proximity; mod criterion; mod heed_codec; -mod iter_shortest_paths; +mod node; mod query_tokens; mod search; mod transitive_arc; diff --git a/src/node.rs b/src/node.rs new file mode 100644 index 000000000..1779c821c --- /dev/null +++ b/src/node.rs @@ -0,0 +1,104 @@ +use std::cmp; +use roaring::RoaringBitmap; + +const ONE_ATTRIBUTE: u32 = 1000; +const MAX_DISTANCE: u32 = 8; + +fn index_proximity(lhs: u32, rhs: u32) -> u32 { + if lhs <= rhs { + cmp::min(rhs - lhs, MAX_DISTANCE) + } else { + cmp::min((lhs - rhs) + 1, MAX_DISTANCE) + } +} + +pub fn positions_proximity(lhs: u32, rhs: u32) -> u32 { + let (lhs_attr, lhs_index) = extract_position(lhs); + let (rhs_attr, rhs_index) = extract_position(rhs); + if lhs_attr != rhs_attr { MAX_DISTANCE } + else { index_proximity(lhs_index, rhs_index) 
} +} + +// Returns the attribute and index parts. +pub fn extract_position(position: u32) -> (u32, u32) { + (position / ONE_ATTRIBUTE, position % ONE_ATTRIBUTE) +} + +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Node { + // This node is the first node. + Uninit, + Init { + // The layer where this node is located. + layer: usize, + // The position where this node is located. + position: u32, + // The parent position from the above layer. + parent_position: u32, + }, +} + +impl Node { + // TODO we must skip the successors that have already been seen + // TODO we must skip the successors that don't return any documents + // this way we are able to skip entire paths + pub fn successors(&self, positions: &[RoaringBitmap], contains_documents: &mut F) -> Vec<(Node, u32)> + where F: FnMut((usize, u32), (usize, u32)) -> bool, + { + match self { + Node::Uninit => { + positions[0].iter().map(|position| { + (Node::Init { layer: 0, position, parent_position: 0 }, 0) + }).collect() + }, + // We reached the highest layer + n @ Node::Init { .. } if n.is_complete(positions) => vec![], + Node::Init { layer, position, .. } => { + positions[layer + 1].iter().filter_map(|p| { + let proximity = positions_proximity(*position, p); + let node = Node::Init { + layer: layer + 1, + position: p, + parent_position: *position, + }; + // We only produce the nodes that are reachable from their parent position. + if node.is_reachable(contains_documents) { + Some((node, proximity)) + } else { + None + } + }).collect() + } + } + } + + pub fn is_complete(&self, positions: &[RoaringBitmap]) -> bool { + match self { + Node::Uninit => false, + Node::Init { layer, .. } => *layer == positions.len() - 1, + } + } + + pub fn position(&self) -> Option { + match self { + Node::Uninit => None, + Node::Init { position, .. 
} => Some(*position), + } + } + + pub fn is_reachable(&self, contains_documents: &mut F) -> bool + where F: FnMut((usize, u32), (usize, u32)) -> bool, + { + match self { + Node::Uninit => true, + Node::Init { layer, position, parent_position, .. } => { + match layer.checked_sub(1) { + Some(parent_layer) => { + (contains_documents)((parent_layer, *parent_position), (*layer, *position)) + }, + None => true, + } + }, + } + } +} diff --git a/src/search.rs b/src/search.rs index 91be3776d..34772dbcb 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,14 +1,18 @@ +use std::cell::RefCell; use std::collections::{HashMap, HashSet}; +use std::rc::Rc; +use astar_iter::AstarBagIter; use fst::{IntoStreamer, Streamer}; use levenshtein_automata::DFA; use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder; +use log::debug; use once_cell::sync::Lazy; use roaring::RoaringBitmap; +use crate::node::{self, Node}; use crate::query_tokens::{QueryTokens, QueryToken}; use crate::{Index, DocumentId, Position, Attribute}; -use crate::best_proximity::{self, BestProximity}; // Building these factories is not free. static LEVDIST0: Lazy = Lazy::new(|| LevBuilder::new(0, true)); @@ -214,20 +218,30 @@ impl<'a> Search<'a> { }; let (derived_words, union_positions) = Self::fetch_words_positions(rtxn, index, &fst, dfas)?; - let mut candidates = Self::compute_candidates(rtxn, index, &derived_words)?; + let candidates = Self::compute_candidates(rtxn, index, &derived_words)?; - let mut union_cache = HashMap::new(); + let union_cache = HashMap::new(); let mut intersect_cache = HashMap::new(); let mut attribute_union_cache = HashMap::new(); let mut attribute_intersect_cache = HashMap::new(); + let candidates = Rc::new(RefCell::new(candidates)); + let union_cache = Rc::new(RefCell::new(union_cache)); + // Returns `true` if there is documents in common between the two words and positions given. 
- let mut contains_documents = |(lword, lpos), (rword, rpos), union_cache: &mut HashMap<_, _>, candidates: &RoaringBitmap| { + // TODO move this closure to a better place. + let candidates_cloned = candidates.clone(); + let union_cache_cloned = union_cache.clone(); + let mut contains_documents = |(lword, lpos), (rword, rpos)| { if lpos == rpos { return false } - let (lattr, _) = best_proximity::extract_position(lpos); - let (rattr, _) = best_proximity::extract_position(rpos); + // TODO move this function to a better place. + let (lattr, _) = node::extract_position(lpos); + let (rattr, _) = node::extract_position(rpos); + + let candidates = &candidates_cloned.borrow(); + let mut union_cache = union_cache_cloned.borrow_mut(); if lattr == rattr { // We retrieve or compute the intersection between the two given words and positions. @@ -277,19 +291,33 @@ impl<'a> Search<'a> { } }; + // We instantiate an astar bag Iterator that returns the best paths incrementally, + // it means that it will first return the best paths then the next best paths... 
+ let astar_iter = AstarBagIter::new( + Node::Uninit, // start + |n| n.successors(&union_positions, &mut contains_documents), // successors + |_| 0, // heuristic + |n| n.is_complete(&union_positions), // success + ); + let mut documents = Vec::new(); - let mut iter = BestProximity::new(union_positions); - while let Some((_proximity, mut positions)) = iter.next(|l, r| contains_documents(l, r, &mut union_cache, &candidates)) { - positions.sort_unstable_by(|a, b| a.iter().cmp(b.iter())); + for (paths, proximity) in astar_iter { + + let mut union_cache = union_cache.borrow_mut(); + let mut candidates = candidates.borrow_mut(); + + let mut positions: Vec> = paths.map(|p| p.iter().filter_map(Node::position).collect()).collect(); + positions.sort_unstable(); + + debug!("Found {} positions with a proximity of {}", positions.len(), proximity); let mut same_proximity_union = RoaringBitmap::default(); for positions in positions { - // Precompute the potentially missing unions positions.iter().enumerate().for_each(|(word, pos)| { - union_cache.entry((word, pos)).or_insert_with(|| { + union_cache.entry((word, *pos)).or_insert_with(|| { let words = &derived_words[word]; - Self::union_word_position(rtxn, index, words, pos).unwrap() + Self::union_word_position(rtxn, index, words, *pos).unwrap() }); }); @@ -297,7 +325,7 @@ impl<'a> Search<'a> { let mut to_intersect: Vec<_> = positions.iter() .enumerate() .map(|(word, pos)| { - let docids = union_cache.get(&(word, pos)).unwrap(); + let docids = union_cache.get(&(word, *pos)).unwrap(); (docids.len(), docids) }) .collect();