diff --git a/src/search.rs b/src/search.rs index e41300cca..a4cf5ead3 100644 --- a/src/search.rs +++ b/src/search.rs @@ -195,22 +195,39 @@ impl<'a> Search<'a> { return Ok(SearchResult { found_words, documents_ids }); } + fn words_pair_combinations<'a>( + w1: &'a HashMap, + w2: &'a HashMap, + ) -> Vec<(&'a str, &'a str)> + { + let mut pairs = Vec::new(); + for (w1, (_typos, docids1)) in w1 { + for (w2, (_typos, docids2)) in w2 { + if !docids1.is_disjoint(&docids2) { + pairs.push((w1.as_str(), w2.as_str())); + } + } + } + pairs + } + let mut answer = RoaringBitmap::new(); for (i, words) in derived_words.windows(2).enumerate() { - let w1: Vec<_> = words[0].0.keys().collect(); - let w2: Vec<_> = words[1].0.keys().collect(); + let pairs = words_pair_combinations(&words[0].0, &words[1].0); + eprintln!("found pairs {:?}", pairs); - let key = (w1[0].as_str(), w2[0].as_str(), 1); - match index.word_pair_proximity_docids.get(rtxn, &key)? { - Some(docids) => if i == 0 { - answer = docids; - } else { - answer.intersect_with(&docids); - }, - None => { - answer = RoaringBitmap::new(); - break; - }, + let mut pairs_union = RoaringBitmap::new(); + for (w1, w2) in pairs { + let key = (w1, w2, 1); + if let Some(docids) = index.word_pair_proximity_docids.get(rtxn, &key)? { + pairs_union.union_with(&docids); + } + } + + if i == 0 { + answer = pairs_union; + } else { + answer.intersect_with(&pairs_union); } }