diff --git a/milli/src/search/facet/facet_distribution.rs b/milli/src/search/facet/facet_distribution.rs index 45ce13a28..f6ef51ccd 100644 --- a/milli/src/search/facet/facet_distribution.rs +++ b/milli/src/search/facet/facet_distribution.rs @@ -240,42 +240,49 @@ impl<'a> FacetDistribution<'a> { } fn facet_values(&self, field_id: FieldId) -> heed::Result> { - use FacetType::{Number, String}; + // use FacetType::{Number, String}; - match self.candidates { - Some(ref candidates) => { - // Classic search, candidates were specified, we must return facet values only related - // to those candidates. We also enter here for facet strings for performance reasons. - let mut distribution = BTreeMap::new(); - if candidates.len() <= CANDIDATES_THRESHOLD { - self.facet_distribution_from_documents( - field_id, - Number, - candidates, - &mut distribution, - )?; - self.facet_distribution_from_documents( - field_id, - String, - candidates, - &mut distribution, - )?; - } else { - self.facet_numbers_distribution_from_facet_levels( - field_id, - candidates, - &mut distribution, - )?; - self.facet_strings_distribution_from_facet_levels( - field_id, - candidates, - &mut distribution, - )?; - } - Ok(distribution) - } - None => self.facet_values_from_raw_facet_database(field_id), + let candidates = match self.candidates.as_ref() { + Some(candidates) => candidates.clone(), + None => todo!("fetch candidates"), + }; + + let mut distribution = BTreeMap::new(); + + let number_distribution = facet_distribution_iter::count_iterate_over_facet_distribution( + self.rtxn, + self.index + .facet_id_f64_docids + .remap_key_type::>(), + field_id, + &candidates, + )?; + + for (count, facet_key, _) in number_distribution { + let facet_key = OrderedF64Codec::bytes_decode(facet_key).unwrap(); + distribution.insert(facet_key.to_string(), count); } + + let string_distribution = facet_distribution_iter::count_iterate_over_facet_distribution( + self.rtxn, + self.index + .facet_id_string_docids + .remap_key_type::>(), + field_id, + &candidates, + )?; + + for (count, facet_key, any_docid) in string_distribution { + let facet_key = StrRefCodec::bytes_decode(facet_key).unwrap(); + + let key: (FieldId, _, &str) = (field_id, any_docid, facet_key); + let original_string = + self.index.field_id_docid_facet_strings.get(self.rtxn, &key)?.unwrap().to_owned(); + + distribution.insert(original_string, count); + } + + Ok(distribution) } pub fn compute_stats(&self) -> Result> { diff --git a/milli/src/search/facet/facet_distribution_iter.rs b/milli/src/search/facet/facet_distribution_iter.rs index 0e1efaa0e..acd936eff 100644 --- a/milli/src/search/facet/facet_distribution_iter.rs +++ b/milli/src/search/facet/facet_distribution_iter.rs @@ -1,5 +1,5 @@ use std::cmp::Reverse; -use std::collections::{BTreeMap, BinaryHeap}; +use std::collections::BinaryHeap; use std::ops::ControlFlow; use heed::Result; @@ -46,15 +46,12 @@ where } } -pub fn count_iterate_over_facet_distribution<'t, CB>( +pub fn count_iterate_over_facet_distribution<'t>( rtxn: &'t heed::RoTxn<'t>, db: heed::Database, FacetGroupValueCodec>, field_id: u16, candidates: &RoaringBitmap, -) -> Result> -where - CB: FnMut(&'t [u8], u64, DocumentId) -> Result>, -{ +) -> Result> { #[derive(Debug, PartialOrd, Ord, PartialEq, Eq)] struct LevelEntry<'t> { /// The number of candidates in this entry. @@ -65,6 +62,8 @@ where left_bound: &'t [u8], /// The number of keys we must look for after `left_bound`. group_size: u8, + /// Any docid in the set of matching documents. Used to find the original facet string. + any_docid: u32, } // Represents the list of keys that we must explore. @@ -88,20 +87,23 @@ where if key.field_id != field_id { break; } - let count = value.bitmap.intersection_len(&candidates); + let intersection = value.bitmap & candidates; + let count = intersection.len(); if count != 0 { heap.push(LevelEntry { count, level: Reverse(key.level), left_bound: key.left_bound, group_size: value.size, + any_docid: intersection.min().unwrap(), }); } } - while let Some(LevelEntry { count, level, left_bound, group_size }) = heap.pop() { + while let Some(LevelEntry { count, level, left_bound, group_size, any_docid }) = heap.pop() + { if let Reverse(0) = level { - results.push((count, left_bound)); + results.push((count, left_bound, any_docid)); // TODO better just call the user callback and ask for a ControlFlow if results.len() == 20 { break; @@ -116,13 +118,15 @@ where if key.field_id != field_id { break; } - let count = value.bitmap.intersection_len(&candidates); + let intersection = value.bitmap & candidates; + let count = intersection.len(); if count != 0 { heap.push(LevelEntry { count, level: Reverse(key.level), left_bound: key.left_bound, group_size: value.size, + any_docid: intersection.min().unwrap(), }); } }