diff --git a/src/search.rs b/src/search.rs index a6eb3b3d4..36b67b502 100644 --- a/src/search.rs +++ b/src/search.rs @@ -80,8 +80,7 @@ impl<'a> Search<'a> { /// Fetch the words from the given FST related to the given DFAs along with /// the associated documents ids. fn fetch_words_docids( - rtxn: &heed::RoTxn, - index: &Index, + &self, fst: &fst::Set<&[u8]>, dfas: Vec<(String, bool, DFA)>, ) -> anyhow::Result, RoaringBitmap)>> @@ -98,7 +97,7 @@ impl<'a> Search<'a> { while let Some((word, state)) = stream.next() { let word = std::str::from_utf8(word)?; - let docids = index.word_docids.get(rtxn, word)?.unwrap(); + let docids = self.index.word_docids.get(self.rtxn, word)?.unwrap(); let distance = dfa.distance(state); unions_docids.union_with(&docids); acc_derived_words.insert(word.to_string(), (distance.to_u8(), docids)); @@ -134,8 +133,7 @@ impl<'a> Search<'a> { } fn fecth_keywords( - rtxn: &heed::RoTxn, - index: &Index, + &self, derived_words: &[(HashMap, RoaringBitmap)], candidate: DocumentId, ) -> anyhow::Result> @@ -148,7 +146,7 @@ impl<'a> Search<'a> { for (word, (_distance, docids)) in words { if !docids.contains(candidate) { continue; } - if let Some(positions) = index.docid_word_positions.get(rtxn, &(candidate, word))? { + if let Some(positions) = self.index.docid_word_positions.get(self.rtxn, &(candidate, word))? { union_positions.union_with(&positions); } } @@ -158,12 +156,73 @@ impl<'a> Search<'a> { Ok(keywords) } + fn words_pair_combinations<'h>( + w1: &'h HashMap, + w2: &'h HashMap, + ) -> Vec<(&'h str, &'h str)> + { + let mut pairs = Vec::new(); + for (w1, (_typos, docids1)) in w1 { + for (w2, (_typos, docids2)) in w2 { + if !docids1.is_disjoint(&docids2) { + pairs.push((w1.as_str(), w2.as_str())); + } + } + } + pairs + } + + fn depth_first_search( + &self, + words: &[(HashMap, RoaringBitmap)], + candidates: &RoaringBitmap, + parent_docids: Option<&RoaringBitmap>, + union_cache: &mut HashMap<(usize, u8), RoaringBitmap>, + ) -> anyhow::Result> + { + let (words1, words2) = (&words[0].0, &words[1].0); + let pairs = Self::words_pair_combinations(words1, words2); + + for proximity in 1..=8 { + let mut docids = match union_cache.entry((words.len(), proximity)) { + Occupied(entry) => entry.get().clone(), + Vacant(entry) => { + let mut docids = RoaringBitmap::new(); + if proximity == 8 { + docids = candidates.clone(); + } else { + for (w1, w2) in pairs.iter().cloned() { + let key = (w1, w2, proximity); + if let Some(di) = self.index.word_pair_proximity_docids.get(self.rtxn, &key)? { + docids.union_with(&di); + } + } + } + entry.insert(docids).clone() + } + }; + + if let Some(parent_docids) = &parent_docids { + docids.intersect_with(parent_docids); + } + + if !docids.is_empty() { + let words = &words[1..]; + // We are the last word. + if words.len() < 2 { return Ok(Some(docids)) } + if let Some(di) = self.depth_first_search(words, candidates, Some(&docids), union_cache)? { + return Ok(Some(di)) + } + } + } + + Ok(None) + } + pub fn execute(&self) -> anyhow::Result { - let rtxn = self.rtxn; - let index = self.index; let limit = self.limit; - let fst = match index.fst(rtxn)? { + let fst = match self.index.fst(self.rtxn)? { Some(fst) => fst, None => return Ok(Default::default()), }; @@ -179,8 +238,8 @@ impl<'a> Search<'a> { return Ok(Default::default()); } - let derived_words = Self::fetch_words_docids(rtxn, index, &fst, dfas)?; - let candidates = Self::compute_candidates(&derived_words); + let derived_words = self.fetch_words_docids(&fst, dfas)?; + let mut candidates = Self::compute_candidates(&derived_words); debug!("candidates: {:?}", candidates); @@ -191,73 +250,7 @@ impl<'a> Search<'a> { return Ok(SearchResult { found_words, documents_ids }); } - fn words_pair_combinations<'a>( - w1: &'a HashMap, - w2: &'a HashMap, - ) -> Vec<(&'a str, &'a str)> - { - let mut pairs = Vec::new(); - for (w1, (_typos, docids1)) in w1 { - for (w2, (_typos, docids2)) in w2 { - if !docids1.is_disjoint(&docids2) { - pairs.push((w1.as_str(), w2.as_str())); - } - } - } - pairs - } - - fn depth_first_search( - index: &Index, - rtxn: &heed::RoTxn, - words: &[(HashMap, RoaringBitmap)], - candidates: &RoaringBitmap, - parent_docids: Option<&RoaringBitmap>, - union_cache: &mut HashMap<(usize, u8), RoaringBitmap>, - ) -> anyhow::Result> - { - let (words1, words2) = (&words[0].0, &words[1].0); - let pairs = words_pair_combinations(words1, words2); - - for proximity in 1..=8 { - let mut docids = match union_cache.entry((words.len(), proximity)) { - Occupied(entry) => entry.get().clone(), - Vacant(entry) => { - let mut docids = RoaringBitmap::new(); - if proximity == 8 { - docids = candidates.clone(); - } else { - for (w1, w2) in pairs.iter().cloned() { - let key = (w1, w2, proximity); - if let Some(di) = index.word_pair_proximity_docids.get(rtxn, &key)? { - docids.union_with(&di); - } - } - } - entry.insert(docids).clone() - } - }; - - if let Some(parent_docids) = &parent_docids { - docids.intersect_with(parent_docids); - } - - if !docids.is_empty() { - let words = &words[1..]; - // We are the last word. - if words.len() < 2 { return Ok(Some(docids)) } - if let Some(di) = depth_first_search(index, rtxn, words, candidates, Some(&docids), union_cache)? { - return Ok(Some(di)) - } - } - } - - Ok(None) - } - let mut union_cache = HashMap::new(); - let answer = depth_first_search(index, rtxn, &derived_words, &candidates, None, &mut union_cache)?; - let mut documents = Vec::new(); if let Some(answer) = answer { documents.push(answer);