diff --git a/src/search.rs b/src/search.rs index b55b38904..eadb656ea 100644 --- a/src/search.rs +++ b/src/search.rs @@ -201,6 +201,74 @@ impl<'a> Search<'a> { Ok(union_docids) } + // Returns `true` if there is documents in common between the two words and positions given. + fn contains_documents( + rtxn: &heed::RoTxn, + index: &Index, + (lword, lpos): (usize, u32), + (rword, rpos): (usize, u32), + candidates: &RoaringBitmap, + derived_words: &[Vec<(String, u8, RoaringBitmap)>], + union_cache: &mut HashMap<(usize, u32), RoaringBitmap>, + non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>, + attribute_union_cache: &mut HashMap<(usize, u32), RoaringBitmap>, + attribute_non_disjoint_cache: &mut HashMap<((usize, u32), (usize, u32)), bool>, + ) -> bool + { + if lpos == rpos { return false } + + // TODO move this function to a better place. + let (lattr, _) = node::extract_position(lpos); + let (rattr, _) = node::extract_position(rpos); + + if lattr == rattr { + // We retrieve or compute the intersection between the two given words and positions. + *non_disjoint_cache.entry(((lword, lpos), (rword, rpos))).or_insert_with(|| { + // We retrieve or compute the unions for the two words and positions. + union_cache.entry((lword, lpos)).or_insert_with(|| { + let words: &Vec<_> = &derived_words[lword]; + Self::union_word_position(rtxn, index, words, lpos).unwrap() + }); + union_cache.entry((rword, rpos)).or_insert_with(|| { + let words: &Vec<_> = &derived_words[rword]; + Self::union_word_position(rtxn, index, words, rpos).unwrap() + }); + + // TODO is there a way to avoid this double gets? + let lunion_docids = union_cache.get(&(lword, lpos)).unwrap(); + let runion_docids = union_cache.get(&(rword, rpos)).unwrap(); + + // We first check that the docids of these unions are part of the candidates. + if lunion_docids.is_disjoint(candidates) { return false } + if runion_docids.is_disjoint(candidates) { return false } + + !lunion_docids.is_disjoint(&runion_docids) + }) + } else { + *attribute_non_disjoint_cache.entry(((lword, lattr), (rword, rattr))).or_insert_with(|| { + // We retrieve or compute the unions for the two words and positions. + attribute_union_cache.entry((lword, lattr)).or_insert_with(|| { + let words: &Vec<_> = &derived_words[lword]; + Self::union_word_attribute(rtxn, index, words, lattr).unwrap() + }); + attribute_union_cache.entry((rword, rattr)).or_insert_with(|| { + let words: &Vec<_> = &derived_words[rword]; + Self::union_word_attribute(rtxn, index, words, rattr).unwrap() + }); + + // TODO is there a way to avoid this double gets? + let lunion_docids = attribute_union_cache.get(&(lword, lattr)).unwrap(); + let runion_docids = attribute_union_cache.get(&(rword, rattr)).unwrap(); + + // We first check that the docids of these unions are part of the candidates. + if lunion_docids.is_disjoint(candidates) { return false } + if runion_docids.is_disjoint(candidates) { return false } + + !lunion_docids.is_disjoint(&runion_docids) + }) + } + } + pub fn execute(&self) -> anyhow::Result { let rtxn = self.rtxn; let index = self.index; @@ -225,74 +293,27 @@ impl<'a> Search<'a> { let candidates = Self::compute_candidates(rtxn, index, &derived_words)?; let union_cache = HashMap::new(); - let mut intersect_cache = HashMap::new(); + let mut non_disjoint_cache = HashMap::new(); let mut attribute_union_cache = HashMap::new(); - let mut attribute_intersect_cache = HashMap::new(); + let mut attribute_non_disjoint_cache = HashMap::new(); let candidates = Rc::new(RefCell::new(candidates)); let union_cache = Rc::new(RefCell::new(union_cache)); - // Returns `true` if there is documents in common between the two words and positions given. - // TODO move this closure to a better place. let candidates_cloned = candidates.clone(); let union_cache_cloned = union_cache.clone(); - let mut contains_documents = |(lword, lpos), (rword, rpos)| { - if lpos == rpos { return false } - - // TODO move this function to a better place. - let (lattr, _) = node::extract_position(lpos); - let (rattr, _) = node::extract_position(rpos); - - let candidates = &candidates_cloned.borrow(); - let mut union_cache = union_cache_cloned.borrow_mut(); - - if lattr == rattr { - // We retrieve or compute the intersection between the two given words and positions. - *intersect_cache.entry(((lword, lpos), (rword, rpos))).or_insert_with(|| { - // We retrieve or compute the unions for the two words and positions. - union_cache.entry((lword, lpos)).or_insert_with(|| { - let words: &Vec<_> = &derived_words[lword]; - Self::union_word_position(rtxn, index, words, lpos).unwrap() - }); - union_cache.entry((rword, rpos)).or_insert_with(|| { - let words: &Vec<_> = &derived_words[rword]; - Self::union_word_position(rtxn, index, words, rpos).unwrap() - }); - - // TODO is there a way to avoid this double gets? - let lunion_docids = union_cache.get(&(lword, lpos)).unwrap(); - let runion_docids = union_cache.get(&(rword, rpos)).unwrap(); - - // We first check that the docids of these unions are part of the candidates. - if lunion_docids.is_disjoint(candidates) { return false } - if runion_docids.is_disjoint(candidates) { return false } - - !lunion_docids.is_disjoint(&runion_docids) - }) - } else { - *attribute_intersect_cache.entry(((lword, lattr), (rword, rattr))).or_insert_with(|| { - // We retrieve or compute the unions for the two words and positions. - attribute_union_cache.entry((lword, lattr)).or_insert_with(|| { - let words: &Vec<_> = &derived_words[lword]; - Self::union_word_attribute(rtxn, index, words, lattr).unwrap() - }); - attribute_union_cache.entry((rword, rattr)).or_insert_with(|| { - let words: &Vec<_> = &derived_words[rword]; - Self::union_word_attribute(rtxn, index, words, rattr).unwrap() - }); - - // TODO is there a way to avoid this double gets? - let lunion_docids = attribute_union_cache.get(&(lword, lattr)).unwrap(); - let runion_docids = attribute_union_cache.get(&(rword, rattr)).unwrap(); - - // We first check that the docids of these unions are part of the candidates. - if lunion_docids.is_disjoint(candidates) { return false } - if runion_docids.is_disjoint(candidates) { return false } - - !lunion_docids.is_disjoint(&runion_docids) - }) - } + let mut contains_documents = |left, right| { + Self::contains_documents( + rtxn, index, + left, right, + &candidates_cloned.borrow(), + &derived_words, + &mut union_cache_cloned.borrow_mut(), + &mut non_disjoint_cache, + &mut attribute_union_cache, + &mut attribute_non_disjoint_cache, + ) }; // We instantiate an astar bag Iterator that returns the best paths incrementally, @@ -320,7 +341,8 @@ impl<'a> Search<'a> { // Precompute the potentially missing unions positions.iter().enumerate().for_each(|(word, pos)| { union_cache.entry((word, *pos)).or_insert_with(|| { - let words = &derived_words[word]; + let words = &&derived_words[word]; + Self::union_word_position(rtxn, index, words, *pos).unwrap() }); });