mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-24 13:54:26 +01:00
294 lines
11 KiB
Rust
294 lines
11 KiB
Rust
use std::cmp::Reverse;
|
|
use std::collections::BinaryHeap;
|
|
use std::ops::ControlFlow;
|
|
|
|
use heed::Result;
|
|
use roaring::RoaringBitmap;
|
|
|
|
use super::{get_first_facet_value, get_highest_level};
|
|
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupKeyCodec, FacetGroupValueCodec};
|
|
use crate::heed_codec::BytesRefCodec;
|
|
use crate::DocumentId;
|
|
|
|
/// Call the given closure on the facet distribution of the candidate documents.
|
|
///
|
|
/// The arguments to the closure are:
|
|
/// - the facet value, as a byte slice
|
|
/// - the number of documents among the candidates that contain this facet value
|
|
/// - the id of a document which contains the facet value. Note that this document
|
|
/// is not necessarily from the list of candidates, it is simply *any* document which
|
|
/// contains this facet value.
|
|
///
|
|
/// The return value of the closure is a `ControlFlow<()>` which indicates whether we should
|
|
/// keep iterating over the different facet values or stop.
|
|
pub fn lexicographically_iterate_over_facet_distribution<'t, CB>(
|
|
rtxn: &'t heed::RoTxn<'t>,
|
|
db: heed::Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
|
|
field_id: u16,
|
|
candidates: &RoaringBitmap,
|
|
callback: CB,
|
|
) -> Result<()>
|
|
where
|
|
CB: FnMut(&'t [u8], u64, DocumentId) -> Result<ControlFlow<()>>,
|
|
{
|
|
let mut fd = LexicographicFacetDistribution { rtxn, db, field_id, callback };
|
|
let highest_level = get_highest_level(
|
|
rtxn,
|
|
db.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>(),
|
|
field_id,
|
|
)?;
|
|
|
|
if let Some(first_bound) = get_first_facet_value::<BytesRefCodec>(rtxn, db, field_id)? {
|
|
fd.iterate(candidates, highest_level, first_bound, usize::MAX)?;
|
|
Ok(())
|
|
} else {
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
pub fn count_iterate_over_facet_distribution<'t, CB>(
|
|
rtxn: &'t heed::RoTxn<'t>,
|
|
db: heed::Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
|
|
field_id: u16,
|
|
candidates: &RoaringBitmap,
|
|
mut callback: CB,
|
|
) -> Result<()>
|
|
where
|
|
CB: FnMut(&'t [u8], u64, DocumentId) -> Result<ControlFlow<()>>,
|
|
{
|
|
/// # Important
|
|
/// The order of the fields determines the order in which the facet values will be returned.
|
|
/// This struct is inserted in a BinaryHeap and popped later on.
|
|
#[derive(Debug, PartialOrd, Ord, PartialEq, Eq)]
|
|
struct LevelEntry<'t> {
|
|
/// The number of candidates in this entry.
|
|
count: u64,
|
|
/// The key level of the entry.
|
|
level: Reverse<u8>,
|
|
/// The left bound key.
|
|
left_bound: &'t [u8],
|
|
/// The number of keys we must look for after `left_bound`.
|
|
group_size: u8,
|
|
/// Any docid in the set of matching documents. Used to find the original facet string.
|
|
any_docid: u32,
|
|
}
|
|
|
|
// Represents the list of keys that we must explore.
|
|
let mut heap = BinaryHeap::new();
|
|
let highest_level = get_highest_level(
|
|
rtxn,
|
|
db.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>(),
|
|
field_id,
|
|
)?;
|
|
|
|
if let Some(first_bound) = get_first_facet_value::<BytesRefCodec>(rtxn, db, field_id)? {
|
|
// We first fill the heap with values from the highest level
|
|
let starting_key =
|
|
FacetGroupKey { field_id, level: highest_level, left_bound: first_bound };
|
|
for el in db.range(rtxn, &(&starting_key..))?.take(usize::MAX) {
|
|
let (key, value) = el?;
|
|
// The range is unbounded on the right and the group size for the highest level is MAX,
|
|
// so we need to check that we are not iterating over the next field id
|
|
if key.field_id != field_id {
|
|
break;
|
|
}
|
|
let intersection = value.bitmap & candidates;
|
|
let count = intersection.len();
|
|
if count != 0 {
|
|
heap.push(LevelEntry {
|
|
count,
|
|
level: Reverse(key.level),
|
|
left_bound: key.left_bound,
|
|
group_size: value.size,
|
|
any_docid: intersection.min().unwrap(),
|
|
});
|
|
}
|
|
}
|
|
|
|
while let Some(LevelEntry { count, level, left_bound, group_size, any_docid }) = heap.pop()
|
|
{
|
|
if let Reverse(0) = level {
|
|
match (callback)(left_bound, count, any_docid)? {
|
|
ControlFlow::Continue(_) => (),
|
|
ControlFlow::Break(_) => return Ok(()),
|
|
}
|
|
} else {
|
|
let starting_key = FacetGroupKey { field_id, level: level.0 - 1, left_bound };
|
|
for el in db.range(rtxn, &(&starting_key..))?.take(group_size as usize) {
|
|
let (key, value) = el?;
|
|
// The range is unbounded on the right and the group size for the highest level is MAX,
|
|
// so we need to check that we are not iterating over the next field id
|
|
if key.field_id != field_id {
|
|
break;
|
|
}
|
|
let intersection = value.bitmap & candidates;
|
|
let count = intersection.len();
|
|
if count != 0 {
|
|
heap.push(LevelEntry {
|
|
count,
|
|
level: Reverse(key.level),
|
|
left_bound: key.left_bound,
|
|
group_size: value.size,
|
|
any_docid: intersection.min().unwrap(),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Iterate over the facets values by lexicographic order.
|
|
struct LexicographicFacetDistribution<'t, CB>
|
|
where
|
|
CB: FnMut(&'t [u8], u64, DocumentId) -> Result<ControlFlow<()>>,
|
|
{
|
|
rtxn: &'t heed::RoTxn<'t>,
|
|
db: heed::Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
|
|
field_id: u16,
|
|
callback: CB,
|
|
}
|
|
|
|
impl<'t, CB> LexicographicFacetDistribution<'t, CB>
|
|
where
|
|
CB: FnMut(&'t [u8], u64, DocumentId) -> Result<ControlFlow<()>>,
|
|
{
|
|
fn iterate_level_0(
|
|
&mut self,
|
|
candidates: &RoaringBitmap,
|
|
starting_bound: &'t [u8],
|
|
group_size: usize,
|
|
) -> Result<ControlFlow<()>> {
|
|
let starting_key =
|
|
FacetGroupKey { field_id: self.field_id, level: 0, left_bound: starting_bound };
|
|
let iter = self.db.range(self.rtxn, &(starting_key..))?.take(group_size);
|
|
for el in iter {
|
|
let (key, value) = el?;
|
|
// The range is unbounded on the right and the group size for the highest level is MAX,
|
|
// so we need to check that we are not iterating over the next field id
|
|
if key.field_id != self.field_id {
|
|
return Ok(ControlFlow::Break(()));
|
|
}
|
|
let docids_in_common = value.bitmap & candidates;
|
|
if !docids_in_common.is_empty() {
|
|
let any_docid_in_common = docids_in_common.min().unwrap();
|
|
match (self.callback)(key.left_bound, docids_in_common.len(), any_docid_in_common)?
|
|
{
|
|
ControlFlow::Continue(_) => (),
|
|
ControlFlow::Break(_) => return Ok(ControlFlow::Break(())),
|
|
}
|
|
}
|
|
}
|
|
Ok(ControlFlow::Continue(()))
|
|
}
|
|
|
|
fn iterate(
|
|
&mut self,
|
|
candidates: &RoaringBitmap,
|
|
level: u8,
|
|
starting_bound: &'t [u8],
|
|
group_size: usize,
|
|
) -> Result<ControlFlow<()>> {
|
|
if level == 0 {
|
|
return self.iterate_level_0(candidates, starting_bound, group_size);
|
|
}
|
|
let starting_key =
|
|
FacetGroupKey { field_id: self.field_id, level, left_bound: starting_bound };
|
|
let iter = self.db.range(self.rtxn, &(&starting_key..))?.take(group_size);
|
|
|
|
for el in iter {
|
|
let (key, value) = el?;
|
|
// The range is unbounded on the right and the group size for the highest level is MAX,
|
|
// so we need to check that we are not iterating over the next field id
|
|
if key.field_id != self.field_id {
|
|
return Ok(ControlFlow::Break(()));
|
|
}
|
|
let docids_in_common = value.bitmap & candidates;
|
|
if !docids_in_common.is_empty() {
|
|
let cf = self.iterate(
|
|
&docids_in_common,
|
|
level - 1,
|
|
key.left_bound,
|
|
value.size as usize,
|
|
)?;
|
|
match cf {
|
|
ControlFlow::Continue(_) => (),
|
|
ControlFlow::Break(_) => return Ok(ControlFlow::Break(())),
|
|
}
|
|
}
|
|
}
|
|
Ok(ControlFlow::Continue(()))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use std::ops::ControlFlow;
|
|
|
|
use heed::BytesDecode;
|
|
use roaring::RoaringBitmap;
|
|
|
|
use super::lexicographically_iterate_over_facet_distribution;
|
|
use crate::heed_codec::facet::OrderedF64Codec;
|
|
use crate::milli_snap;
|
|
use crate::search::facet::tests::{get_random_looking_index, get_simple_index};
|
|
|
|
#[test]
|
|
fn filter_distribution_all() {
|
|
let indexes = [get_simple_index(), get_random_looking_index()];
|
|
for (i, index) in indexes.iter().enumerate() {
|
|
let txn = index.env.read_txn().unwrap();
|
|
let candidates = (0..=255).collect::<RoaringBitmap>();
|
|
let mut results = String::new();
|
|
lexicographically_iterate_over_facet_distribution(
|
|
&txn,
|
|
index.content,
|
|
0,
|
|
&candidates,
|
|
|facet, count, _| {
|
|
let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
|
|
results.push_str(&format!("{facet}: {count}\n"));
|
|
Ok(ControlFlow::Continue(()))
|
|
},
|
|
)
|
|
.unwrap();
|
|
milli_snap!(results, i);
|
|
|
|
txn.commit().unwrap();
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn filter_distribution_all_stop_early() {
|
|
let indexes = [get_simple_index(), get_random_looking_index()];
|
|
for (i, index) in indexes.iter().enumerate() {
|
|
let txn = index.env.read_txn().unwrap();
|
|
let candidates = (0..=255).collect::<RoaringBitmap>();
|
|
let mut results = String::new();
|
|
let mut nbr_facets = 0;
|
|
lexicographically_iterate_over_facet_distribution(
|
|
&txn,
|
|
index.content,
|
|
0,
|
|
&candidates,
|
|
|facet, count, _| {
|
|
let facet = OrderedF64Codec::bytes_decode(facet).unwrap();
|
|
if nbr_facets == 100 {
|
|
Ok(ControlFlow::Break(()))
|
|
} else {
|
|
nbr_facets += 1;
|
|
results.push_str(&format!("{facet}: {count}\n"));
|
|
Ok(ControlFlow::Continue(()))
|
|
}
|
|
},
|
|
)
|
|
.unwrap();
|
|
milli_snap!(results, i);
|
|
|
|
txn.commit().unwrap();
|
|
}
|
|
}
|
|
}
|