diff --git a/src/lib.rs b/src/lib.rs index 5285cf426..e160d13b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,48 +117,62 @@ impl Index { positions.push(union_positions.iter().collect()); } - // let positions = BestProximity::new(positions).next().unwrap_or_default(); - let _positions: Vec> = positions; - let positions = vec![0u32]; - eprintln!("best proximity {:?}", positions); + let mut documents = Vec::new(); - let mut intersect_docids: Option = None; - for ((word, is_prefix, dfa), pos) in words_positions.into_iter().zip(positions) { - let mut count = 0; - let mut union_docids = RoaringBitmap::default(); + for (_proximity, positions) in BestProximity::new(positions) { + let mut same_proximity_union = RoaringBitmap::default(); - if false && word.len() <= 4 && is_prefix { - let mut key = word.as_bytes()[..word.len().min(5)].to_vec(); - key.extend_from_slice(&pos.to_be_bytes()); - if let Some(ids) = self.prefix_postings_ids.get(rtxn, &key)? { - let right = RoaringBitmap::deserialize_from(ids)?; - union_docids.union_with(&right); - count = 1; - } - } else { - let mut stream = fst.search(dfa).into_stream(); - while let Some(word) = stream.next() { - let word = std::str::from_utf8(word)?; - let mut key = word.as_bytes().to_vec(); - key.extend_from_slice(&pos.to_be_bytes()); - if let Some(attrs) = self.postings_ids.get(rtxn, &key)? { - let right = RoaringBitmap::deserialize_from(attrs)?; - union_docids.union_with(&right); - count += 1; + for positions in positions { + let mut intersect_docids: Option = None; + for ((word, is_prefix, dfa), pos) in words_positions.iter().zip(positions) { + let mut count = 0; + let mut union_docids = RoaringBitmap::default(); + + // TODO re-enable the prefixes system + if false && word.len() <= 4 && *is_prefix { + let mut key = word.as_bytes()[..word.len().min(5)].to_vec(); + key.extend_from_slice(&pos.to_be_bytes()); + if let Some(ids) = self.prefix_postings_ids.get(rtxn, &key)? { + let right = RoaringBitmap::deserialize_from(ids)?; + union_docids.union_with(&right); + count = 1; + } + } else { + let mut stream = fst.search(dfa).into_stream(); + while let Some(word) = stream.next() { + let word = std::str::from_utf8(word)?; + let mut key = word.as_bytes().to_vec(); + key.extend_from_slice(&pos.to_be_bytes()); + if let Some(attrs) = self.postings_ids.get(rtxn, &key)? { + let right = RoaringBitmap::deserialize_from(attrs)?; + union_docids.union_with(&right); + count += 1; + } + } } + + let _ = count; + + match &mut intersect_docids { + Some(left) => left.intersect_with(&union_docids), + None => intersect_docids = Some(union_docids), + } + } + + if let Some(intersect_docids) = intersect_docids { + same_proximity_union.union_with(&intersect_docids); } } - let _ = count; + documents.push(same_proximity_union); - match &mut intersect_docids { - Some(left) => left.intersect_with(&union_docids), - None => intersect_docids = Some(union_docids), + // We found enough documents we can stop here + if documents.iter().map(RoaringBitmap::len).sum::() >= 20 { + break } } - eprintln!("{} candidates", intersect_docids.as_ref().map_or(0, |r| r.len())); - - Ok(intersect_docids.unwrap_or_default().iter().take(20).collect()) + eprintln!("{} candidates", documents.iter().map(RoaringBitmap::len).sum::()); + Ok(documents.iter().flatten().take(20).collect()) } }