// MeiliSearch/milli/src/search/facet/facet_distribution_iter.rs
use std::ops::ControlFlow;
use heed::Result;
use roaring::RoaringBitmap;
use super::{get_first_facet_value, get_highest_level};
2022-09-01 11:40:29 +02:00
use crate::heed_codec::facet::new::{FacetGroupValueCodec, FacetKey, FacetKeyCodec, MyByteSlice};
pub fn iterate_over_facet_distribution<'t, CB>(
rtxn: &'t heed::RoTxn<'t>,
2022-08-31 09:36:19 +02:00
db: heed::Database<FacetKeyCodec<MyByteSlice>, FacetGroupValueCodec>,
field_id: u16,
candidates: &RoaringBitmap,
callback: CB,
) -> Result<()>
where
CB: FnMut(&'t [u8], u64) -> ControlFlow<()>,
{
let mut fd = FacetDistribution { rtxn, db, field_id, callback };
let highest_level =
2022-08-31 09:36:19 +02:00
get_highest_level(rtxn, db.remap_key_type::<FacetKeyCodec<MyByteSlice>>(), field_id)?;
if let Some(first_bound) = get_first_facet_value::<MyByteSlice>(rtxn, db, field_id)? {
fd.iterate(candidates, highest_level, first_bound, usize::MAX)?;
return Ok(());
} else {
return Ok(());
}
}
struct FacetDistribution<'t, CB>
where
CB: FnMut(&'t [u8], u64) -> ControlFlow<()>,
{
rtxn: &'t heed::RoTxn<'t>,
2022-08-31 09:36:19 +02:00
db: heed::Database<FacetKeyCodec<MyByteSlice>, FacetGroupValueCodec>,
field_id: u16,
callback: CB,
}
impl<'t, CB> FacetDistribution<'t, CB>
where
    CB: FnMut(&'t [u8], u64) -> ControlFlow<()>,
{
    /// Visits at most `group_size` level-0 entries of `self.field_id`, starting at
    /// `starting_bound`, and calls the callback for each entry whose bitmap intersects
    /// `candidates`.
    ///
    /// The callback receives the entry's left bound and the size of the intersection.
    ///
    /// Returns `ControlFlow::Break(())` when the callback asked to stop or when an entry
    /// belonging to another field id was reached, `ControlFlow::Continue(())` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates errors from opening the database range and from reading its entries.
    fn iterate_level_0(
        &mut self,
        candidates: &RoaringBitmap,
        starting_bound: &'t [u8],
        group_size: usize,
    ) -> Result<ControlFlow<()>> {
        let starting_key =
            FacetKey { field_id: self.field_id, level: 0, left_bound: starting_bound };
        let iter = self.db.range(self.rtxn, &(starting_key..))?.take(group_size);
        for el in iter {
            let (key, value) = el?;
            // The range is unbounded on the right and the group size for the highest level is MAX,
            // so we need to check that we are not iterating over the next field id
            if key.field_id != self.field_id {
                return Ok(ControlFlow::Break(()));
            }
            let docids_in_common = value.bitmap.intersection_len(candidates);
            if docids_in_common > 0 {
                let flow = (self.callback)(key.left_bound, docids_in_common);
                if let ControlFlow::Break(()) = flow {
                    return Ok(ControlFlow::Break(()));
                }
            }
        }
        Ok(ControlFlow::Continue(()))
    }

    /// Visits at most `group_size` entries of the given `level`, starting at `starting_bound`,
    /// and recurses into level `level - 1` for every entry whose bitmap intersects
    /// `candidates`. Level 0 is delegated to `iterate_level_0`.
    ///
    /// The recursion is given the intersection as its candidates, the entry's left bound as
    /// its starting bound and the entry's `size` as its group size (presumably the number of
    /// child entries the group covers — confirm against `FacetGroupValueCodec`).
    ///
    /// Returns `ControlFlow::Break(())` when a nested call returned `Break` or when an entry
    /// belonging to another field id was reached, `ControlFlow::Continue(())` otherwise.
    ///
    /// # Errors
    ///
    /// Propagates errors from opening the database range and from reading its entries,
    /// instead of panicking on them.
    fn iterate(
        &mut self,
        candidates: &RoaringBitmap,
        level: u8,
        starting_bound: &'t [u8],
        group_size: usize,
    ) -> Result<ControlFlow<()>> {
        if level == 0 {
            return self.iterate_level_0(candidates, starting_bound, group_size);
        }
        let starting_key = FacetKey { field_id: self.field_id, level, left_bound: starting_bound };
        let iter = self.db.range(self.rtxn, &(&starting_key..))?.take(group_size);
        for el in iter {
            let (key, value) = el?;
            // The range is unbounded on the right and the group size for the highest level is MAX,
            // so we need to check that we are not iterating over the next field id
            if key.field_id != self.field_id {
                return Ok(ControlFlow::Break(()));
            }
            let docids_in_common = value.bitmap & candidates;
            if !docids_in_common.is_empty() {
                let flow = self.iterate(
                    &docids_in_common,
                    level - 1,
                    key.left_bound,
                    value.size as usize,
                )?;
                if let ControlFlow::Break(()) = flow {
                    return Ok(ControlFlow::Break(()));
                }
            }
        }
        Ok(ControlFlow::Continue(()))
    }
}
#[cfg(test)]
mod tests {
    use std::ops::ControlFlow;

    use heed::BytesDecode;
    use rand::{Rng, SeedableRng};
    use roaring::RoaringBitmap;

    use super::iterate_over_facet_distribution;
    use crate::heed_codec::facet::new::ordered_f64_codec::OrderedF64Codec;
    use crate::milli_snap;
    use crate::search::facet::test::FacetIndex;

    /// Builds an index on field id 0 where each facet value `i` in `0..256` is associated
    /// with the single document id `i`.
    fn get_simple_index() -> FacetIndex<OrderedF64Codec> {
        let index = FacetIndex::<OrderedF64Codec>::new(4, 8);
        let mut wtxn = index.env.write_txn().unwrap();
        for value in 0..256u16 {
            let docids = std::iter::once(u32::from(value)).collect::<RoaringBitmap>();
            index.insert(&mut wtxn, 0, &f64::from(value), &docids);
        }
        wtxn.commit().unwrap();
        index
    }

    /// Builds an index on field id 0 from 128 keys drawn in `0..256` with a fixed-seed RNG;
    /// each insertion associates the key with the document ids `key` and `key + 100`.
    fn get_random_looking_index() -> FacetIndex<OrderedF64Codec> {
        let index = FacetIndex::<OrderedF64Codec>::new(4, 8);
        let mut wtxn = index.env.write_txn().unwrap();

        let mut rng = rand::rngs::SmallRng::from_seed([0; 32]);
        for _ in 0..128 {
            let key: u32 = rng.gen_range(0..256);
            let docids = [key, key + 100].iter().copied().collect::<RoaringBitmap>();
            index.insert(&mut wtxn, 0, &f64::from(key), &docids);
        }
        wtxn.commit().unwrap();
        index
    }

    /// Snapshots the display representation of the random-looking index.
    #[test]
    fn random_looking_index_snap() {
        let index = get_random_looking_index();
        milli_snap!(format!("{index}"));
    }

    /// Snapshots the full distribution reported for candidates `0..=255` on both indexes.
    #[test]
    fn filter_distribution_all() {
        let indexes = [get_simple_index(), get_random_looking_index()];
        for (i, index) in indexes.iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates: RoaringBitmap = (0..=255).collect();
            let mut results = String::new();
            let on_facet = |facet, count| {
                let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
                results += &format!("{facet}: {count}\n");
                ControlFlow::Continue(())
            };
            iterate_over_facet_distribution(&txn, index.db.content, 0, &candidates, on_facet)
                .unwrap();
            milli_snap!(results, i);

            txn.commit().unwrap();
        }
    }

    /// Same as `filter_distribution_all`, but the callback breaks once 100 facet values
    /// have been recorded.
    #[test]
    fn filter_distribution_all_stop_early() {
        let indexes = [get_simple_index(), get_random_looking_index()];
        for (i, index) in indexes.iter().enumerate() {
            let txn = index.env.read_txn().unwrap();
            let candidates: RoaringBitmap = (0..=255).collect();
            let mut results = String::new();
            let mut nbr_facets = 0;
            let on_facet = |facet, count| {
                let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
                if nbr_facets == 100 {
                    return ControlFlow::Break(());
                }
                nbr_facets += 1;
                results += &format!("{facet}: {count}\n");
                ControlFlow::Continue(())
            };
            iterate_over_facet_distribution(&txn, index.db.content, 0, &candidates, on_facet)
                .unwrap();
            milli_snap!(results, i);

            txn.commit().unwrap();
        }
    }
}