From 007e647462febad05ab689dd34706aaf68ca3587 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Thu, 1 Oct 2020 16:28:49 +0200 Subject: [PATCH] Introduce the Mdfs Iterator that explore the proximity graph using a mana DFS --- src/lib.rs | 1 + src/mdfs.rs | 158 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/search.rs | 136 +++++-------------------------------------- 3 files changed, 172 insertions(+), 123 deletions(-) create mode 100644 src/mdfs.rs diff --git a/src/lib.rs b/src/lib.rs index 547189aa3..6f90dc287 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,5 @@ mod criterion; +mod mdfs; mod query_tokens; mod search; pub mod heed_codec; diff --git a/src/mdfs.rs b/src/mdfs.rs new file mode 100644 index 000000000..b5e4b19e0 --- /dev/null +++ b/src/mdfs.rs @@ -0,0 +1,158 @@ +use std::collections::hash_map::Entry::{Occupied, Vacant}; +use std::collections::HashMap; +use std::mem; + +use roaring::RoaringBitmap; +use crate::Index; + +/// A mana depth first search implementation. +pub struct Mdfs<'a> { + index: &'a Index, + rtxn: &'a heed::RoTxn, + words: &'a [(HashMap, RoaringBitmap)], + union_cache: HashMap<(usize, u8), RoaringBitmap>, + candidates: RoaringBitmap, + mana: u32, + max_mana: u32, +} + +impl<'a> Mdfs<'a> { + pub fn new( + index: &'a Index, + rtxn: &'a heed::RoTxn, + words: &'a [(HashMap, RoaringBitmap)], + candidates: RoaringBitmap, + ) -> Mdfs<'a> + { + // Compute the number of pairs (windows) we have for this list of words. + let mana = words.len().checked_sub(1).unwrap_or(0) as u32; + let max_mana = mana * 8; + Mdfs { index, rtxn, words, union_cache: HashMap::new(), candidates, mana, max_mana } + } +} + +impl<'a> Iterator for Mdfs<'a> { + type Item = anyhow::Result; + + fn next(&mut self) -> Option { + // If there is less or only one word therefore the only + // possible documents that we can return are the candidates. + if self.words.len() <= 1 { + if self.candidates.is_empty() { return None } + return Some(Ok(mem::take(&mut self.candidates))); + } + + let mut answer = RoaringBitmap::new(); + while self.mana <= self.max_mana { + let result = mdfs_step( + &self.index, + &self.rtxn, + self.mana, + self.words, + &self.candidates, + &self.candidates, + &mut self.union_cache, + ); + + match result { + Ok(Some(a)) => { + // We remove the answered documents from the list of + // candidates to be sure we don't search for them again. + self.candidates.difference_with(&a); + answer.union_with(&a); + }, + Ok(None) => { + // We found the last iteration for this amount of mana that gives nothing, + // we can now store that the next mana to use for the loop is incremented. + self.mana = self.mana + 1; + // If the answer is empty it means that we found nothing for this amount + // of mana therefore we continue with a bigger mana. + if !answer.is_empty() { + // Otherwise we return the answer. + return Some(Ok(answer)); + } + }, + Err(e) => return Some(Err(e)), + } + } + + None + } +} + +fn mdfs_step( + index: &Index, + rtxn: &heed::RoTxn, + mana: u32, + words: &[(HashMap, RoaringBitmap)], + candidates: &RoaringBitmap, + parent_docids: &RoaringBitmap, + union_cache: &mut HashMap<(usize, u8), RoaringBitmap>, +) -> anyhow::Result> +{ + use std::cmp::{min, max}; + + let (words1, words2) = (&words[0].0, &words[1].0); + let pairs = words_pair_combinations(words1, words2); + let tail = &words[1..]; + let nb_children = tail.len() as u32 - 1; + + // The minimum amount of mana that you must consume is at least 1 and the + // amount of mana that your children can consume. Because the last child must + // consume the remaining mana, it is mandatory that there not too much at the end. + let min_proximity = max(1, mana.saturating_sub(nb_children * 8)) as u8; + + // The maximum amount of mana that you can use is 8 or the remaining amount of + // mana minus your children, as you can't just consume all the mana, + // your children must have at least 1 mana. + let max_proximity = min(8, mana - nb_children) as u8; + + for proximity in min_proximity..=max_proximity { + let mut docids = match union_cache.entry((words.len(), proximity)) { + Occupied(entry) => entry.get().clone(), + Vacant(entry) => { + let mut docids = RoaringBitmap::new(); + if proximity == 8 { + docids = candidates.clone(); + } else { + for (w1, w2) in pairs.iter().cloned() { + let key = (w1, w2, proximity); + if let Some(di) = index.word_pair_proximity_docids.get(rtxn, &key)? { + docids.union_with(&di); + } + } + } + entry.insert(docids).clone() + } + }; + + docids.intersect_with(parent_docids); + + if !docids.is_empty() { + let mana = mana.checked_sub(proximity as u32).unwrap(); + // We are the last pair, we return without recursing as we don't have any child. + if tail.len() < 2 { return Ok(Some(docids)) } + if let Some(di) = mdfs_step(index, rtxn, mana, tail, candidates, &docids, union_cache)? { + return Ok(Some(di)) + } + } + } + + Ok(None) +} + +fn words_pair_combinations<'h>( + w1: &'h HashMap, + w2: &'h HashMap, +) -> Vec<(&'h str, &'h str)> +{ + let mut pairs = Vec::new(); + for (w1, (_typos, docids1)) in w1 { + for (w2, (_typos, docids2)) in w2 { + if !docids1.is_disjoint(&docids2) { + pairs.push((w1.as_str(), w2.as_str())); + } + } + } + pairs +} diff --git a/src/search.rs b/src/search.rs index 5998e3aec..60b89d678 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,5 +1,4 @@ use std::collections::{HashMap, HashSet}; -use std::collections::hash_map::Entry::{Occupied, Vacant}; use fst::{IntoStreamer, Streamer}; use levenshtein_automata::DFA; @@ -9,6 +8,7 @@ use once_cell::sync::Lazy; use roaring::bitmap::RoaringBitmap; use crate::query_tokens::{QueryTokens, QueryToken}; +use crate::mdfs::Mdfs; use crate::{Index, DocumentId}; // Building these factories is not free. @@ -132,111 +132,6 @@ impl<'a> Search<'a> { candidates } - // TODO Move this elsewhere! - fn mana_depth_first_search( - &self, - words: &[(HashMap, RoaringBitmap)], - candidates: &RoaringBitmap, - union_cache: &mut HashMap<(usize, u8), RoaringBitmap>, - ) -> anyhow::Result> - { - fn words_pair_combinations<'h>( - w1: &'h HashMap, - w2: &'h HashMap, - ) -> Vec<(&'h str, &'h str)> - { - let mut pairs = Vec::new(); - for (w1, (_typos, docids1)) in w1 { - for (w2, (_typos, docids2)) in w2 { - if !docids1.is_disjoint(&docids2) { - pairs.push((w1.as_str(), w2.as_str())); - } - } - } - pairs - } - - fn mdfs( - index: &Index, - rtxn: &heed::RoTxn, - mana: u32, - words: &[(HashMap, RoaringBitmap)], - candidates: &RoaringBitmap, - parent_docids: &RoaringBitmap, - union_cache: &mut HashMap<(usize, u8), RoaringBitmap>, - ) -> anyhow::Result> - { - use std::cmp::{min, max}; - - let (words1, words2) = (&words[0].0, &words[1].0); - let pairs = words_pair_combinations(words1, words2); - let tail = &words[1..]; - let nb_children = tail.len() as u32 - 1; - - // The minimum amount of mana that you must consume is at least 1 and the - // amount of mana that your children can consume. Because the last child must - // consume the remaining mana, it is mandatory that there not too much at the end. - let min_proximity = max(1, mana.saturating_sub(nb_children * 8)) as u8; - - // The maximum amount of mana that you can use is 8 or the remaining amount of - // mana minus your children, as you can't just consume all the mana, - // your children must have at least 1 mana. - let max_proximity = min(8, mana - nb_children) as u8; - - for proximity in min_proximity..=max_proximity { - let mut docids = match union_cache.entry((words.len(), proximity)) { - Occupied(entry) => entry.get().clone(), - Vacant(entry) => { - let mut docids = RoaringBitmap::new(); - if proximity == 8 { - docids = candidates.clone(); - } else { - for (w1, w2) in pairs.iter().cloned() { - let key = (w1, w2, proximity); - if let Some(di) = index.word_pair_proximity_docids.get(rtxn, &key)? { - docids.union_with(&di); - } - } - } - entry.insert(docids).clone() - } - }; - - docids.intersect_with(parent_docids); - - if !docids.is_empty() { - let mana = mana.checked_sub(proximity as u32).unwrap(); - // We are the last pair, we return without recursing as we don't have any child. - if tail.len() < 2 { return Ok(Some(docids)) } - if let Some(di) = mdfs(index, rtxn, mana, tail, candidates, &docids, union_cache)? { - return Ok(Some(di)) - } - } - } - - Ok(None) - } - - // Compute the number of pairs (windows) we have for this list of words. - // If there only is one word therefore the only possible documents are the candidates. - let initial_mana = match words.len().checked_sub(1) { - Some(nb_windows) if nb_windows != 0 => nb_windows as u32, - _ => return Ok(Some(candidates.clone())), - }; - - // TODO We must keep track of where we are in terms of mana and that should either be - // handled by an Iterator or by the caller. Keeping track of the amount of mana - // is an optimization, it makes this mdfs to only be called with the next valid - // mana and not called with all of the previous mana values. - for mana in initial_mana..=initial_mana * 8 { - if let Some(answer) = mdfs(&self.index, &self.rtxn, mana, words, candidates, candidates, union_cache)? { - return Ok(Some(answer)); - } - } - - Ok(None) - } - pub fn execute(&self) -> anyhow::Result { let limit = self.limit; @@ -257,29 +152,24 @@ impl<'a> Search<'a> { } let derived_words = self.fetch_words_docids(&fst, dfas)?; - let mut candidates = Self::compute_candidates(&derived_words); + let candidates = Self::compute_candidates(&derived_words); debug!("candidates: {:?}", candidates); + // The mana depth first search is a revised DFS that explore + // solutions in the order of their proximities. + let mut mdfs = Mdfs::new(self.index, self.rtxn, &derived_words, candidates); let mut documents = Vec::new(); - let mut union_cache = HashMap::new(); - // We execute the DFS until we find enough documents, we run it with the - // candidates list and remove the found documents from this list at each iteration. + // We execute the Mdfs iterator until we find enough documents. while documents.iter().map(RoaringBitmap::len).sum::() < limit as u64 { - let answer = self.mana_depth_first_search(&derived_words, &candidates, &mut union_cache)?; - - let answer = match answer { - Some(answer) if !answer.is_empty() => answer, - _ => break, - }; - - debug!("answer: {:?}", answer); - - // We remove the answered documents from the list of - // candidates to be sure we don't search for them again. - candidates.difference_with(&answer); - documents.push(answer); + match mdfs.next().transpose()? { + Some(answer) => { + debug!("answer: {:?}", answer); + documents.push(answer); + }, + None => break, + } } let found_words = derived_words.into_iter().flat_map(|(w, _)| w).map(|(w, _)| w).collect();