diff --git a/Cargo.lock b/Cargo.lock index f32d2f133..330588564 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -617,6 +617,7 @@ dependencies = [ "maplit", "memmap", "near-proximity", + "num-traits", "obkv", "once_cell", "ordered-float", @@ -675,9 +676,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.12" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac267bcc07f48ee5f8935ab0d24f316fb722d7a1292e2913f0cc196b29ffd611" +checksum = "9a64b1ec5cda2586e284722486d802acf1f7dbdc623e2bfc57e65ca1cd099290" dependencies = [ "autocfg", ] diff --git a/Cargo.toml b/Cargo.toml index bee4ebddc..2510cb245 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ levenshtein_automata = { version = "0.2.0", features = ["fst_automaton"] } linked-hash-map = "0.5.3" memmap = "0.7.0" near-proximity = { git = "https://github.com/Kerollmops/plane-sweep-proximity", rev = "6608205" } +num-traits = "0.2.14" obkv = "0.1.0" once_cell = "1.4.0" ordered-float = "2.0.0" diff --git a/src/search.rs b/src/search.rs index e7100e148..e6fcefc62 100644 --- a/src/search.rs +++ b/src/search.rs @@ -1,7 +1,9 @@ use std::borrow::Cow; use std::collections::{HashMap, HashSet}; -use std::fmt; +use std::error::Error as StdError; +use std::fmt::{self, Debug}; use std::ops::Bound::{self, Unbounded, Included, Excluded}; +use std::str::FromStr; use anyhow::{bail, ensure, Context}; use fst::{IntoStreamer, Streamer}; @@ -9,11 +11,12 @@ use heed::types::DecodeIgnore; use levenshtein_automata::DFA; use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder; use log::debug; +use num_traits::Bounded; use once_cell::sync::Lazy; use roaring::bitmap::RoaringBitmap; use crate::facet::FacetType; -use crate::heed_codec::facet::FacetLevelValueI64Codec; +use crate::heed_codec::facet::{FacetLevelValueI64Codec, FacetLevelValueF64Codec}; use crate::mdfs::Mdfs; use crate::query_tokens::{QueryTokens, QueryToken}; use crate::{Index, DocumentId}; @@ -24,20 +27,21 @@ static LEVDIST1: Lazy = Lazy::new(|| LevBuilder::new(1, true)); static LEVDIST2: Lazy = Lazy::new(|| LevBuilder::new(2, true)); // TODO support also floats -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum FacetOperator { - GreaterThan(i64), - GreaterThanOrEqual(i64), - LowerThan(i64), - LowerThanOrEqual(i64), - Equal(i64), - Between(i64, i64), +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum FacetOperator { + GreaterThan(T), + GreaterThanOrEqual(T), + LowerThan(T), + LowerThanOrEqual(T), + Equal(T), + Between(T, T), } // TODO also support ANDs, ORs, NOTs. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Copy, Clone, PartialEq)] pub enum FacetCondition { - Operator(u8, FacetOperator), + OperatorI64(u8, FacetOperator), + OperatorF64(u8, FacetOperator), } impl FacetCondition { @@ -48,7 +52,6 @@ impl FacetCondition { ) -> anyhow::Result> { use FacetCondition::*; - use FacetOperator::*; let fields_ids_map = index.fields_ids_map(rtxn)?; let faceted_fields = index.faceted_fields(rtxn)?; @@ -64,33 +67,44 @@ impl FacetCondition { let field_id = fields_ids_map.id(&field_name).with_context(|| format!("field {} not found", field_name))?; let field_type = faceted_fields.get(&field_id).with_context(|| format!("field {} is not faceted", field_name))?; - ensure!(*field_type == FacetType::Integer, "Only conditions on integer facets"); + match field_type { + FacetType::Integer => Self::parse_condition(iter).map(|op| Some(OperatorI64(field_id, op))), + FacetType::Float => Self::parse_condition(iter).map(|op| Some(OperatorF64(field_id, op))), + FacetType::String => bail!("invalid facet type"), + } + } + fn parse_condition<'a, T: FromStr>( + mut iter: impl Iterator, + ) -> anyhow::Result> + where T::Err: Send + Sync + StdError + 'static, + { + use FacetOperator::*; match iter.next() { Some(">") => { let param = iter.next().context("missing parameter")?; let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; - Ok(Some(Operator(field_id, GreaterThan(value)))) + Ok(GreaterThan(value)) }, Some(">=") => { let param = iter.next().context("missing parameter")?; let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; - Ok(Some(Operator(field_id, GreaterThanOrEqual(value)))) + Ok(GreaterThanOrEqual(value)) }, Some("<") => { let param = iter.next().context("missing parameter")?; let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; - Ok(Some(Operator(field_id, LowerThan(value)))) + Ok(LowerThan(value)) }, Some("<=") => { let param = iter.next().context("missing parameter")?; let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; - Ok(Some(Operator(field_id, LowerThanOrEqual(value)))) + Ok(LowerThanOrEqual(value)) }, Some("=") => { let param = iter.next().context("missing parameter")?; let value = param.parse().with_context(|| format!("invalid parameter ({:?})", param))?; - Ok(Some(Operator(field_id, Equal(value)))) + Ok(Equal(value)) }, Some(otherwise) => { // BETWEEN or X TO Y (both inclusive) @@ -98,7 +112,7 @@ impl FacetCondition { ensure!(iter.next().map_or(false, |s| s.eq_ignore_ascii_case("to")), "TO keyword missing or invalid"); let next = iter.next().context("missing second TO parameter")?; let right_param = next.parse().with_context(|| format!("invalid second TO parameter ({:?})", next))?; - Ok(Some(Operator(field_id, Between(left_param, right_param)))) + Ok(Between(left_param, right_param)) }, None => bail!("missing facet filter first parameter"), } @@ -229,19 +243,23 @@ impl<'a> Search<'a> { /// Aggregates the documents ids that are part of the specified range automatically /// going deeper through the levels. - fn explore_facet_levels( + fn explore_facet_levels( &self, field_id: u8, level: u8, - left: Bound, - right: Bound, + left: Bound, + right: Bound, output: &mut RoaringBitmap, ) -> anyhow::Result<()> + where + T: Copy + PartialEq + PartialOrd + Bounded + Debug, + KC: heed::BytesDecode<'a, DItem = (u8, u8, T, T)>, + KC: for<'x> heed::BytesEncode<'x, EItem = (u8, u8, T, T)>, { match (left, right) { // If the request is an exact value we must go directly to the deepest level. (Included(l), Included(r)) if l == r && level > 0 => { - return self.explore_facet_levels(field_id, 0, left, right, output); + return self.explore_facet_levels::(field_id, 0, left, right, output); }, // lower TO upper when lower > upper must return no result (Included(l), Included(r)) if l > r => return Ok(()), @@ -257,12 +275,12 @@ impl<'a> Search<'a> { // We must create a custom iterator to be able to iterate over the // requested range as the range iterator cannot express some conditions. let left_bound = match left { - Included(left) => Included((field_id, level, left, i64::MIN)), - Excluded(left) => Excluded((field_id, level, left, i64::MIN)), + Included(left) => Included((field_id, level, left, T::min_value())), + Excluded(left) => Excluded((field_id, level, left, T::min_value())), Unbounded => Unbounded, }; - let right_bound = Included((field_id, level, i64::MAX, i64::MAX)); - let db = self.index.facet_field_id_value_docids.remap_key_type::(); + let right_bound = Included((field_id, level, T::max_value(), T::max_value())); + let db = self.index.facet_field_id_value_docids.remap_key_type::(); let iter = db .range(self.rtxn, &(left_bound, right_bound))? .take_while(|r| r.as_ref().map_or(true, |((.., r), _)| { @@ -277,7 +295,7 @@ impl<'a> Search<'a> { for (i, result) in iter.enumerate() { let ((_fid, _level, l, r), docids) = result?; - debug!("{} to {} (level {}) found {} documents", l, r, _level, docids.len()); + debug!("{:?} to {:?} (level {}) found {} documents", l, r, _level, docids.len()); output.union_with(&docids); // We save the leftest and rightest bounds we actually found at this level. if i == 0 { left_found = Some(l); } @@ -298,18 +316,18 @@ impl<'a> Search<'a> { if !matches!(left, Included(l) if l == left_found) { let sub_right = Excluded(left_found); debug!("calling left with {:?} to {:?} (level {})", left, sub_right, deeper_level); - self.explore_facet_levels(field_id, deeper_level, left, sub_right, output)?; + self.explore_facet_levels::(field_id, deeper_level, left, sub_right, output)?; } if !matches!(right, Included(r) if r == right_found) { let sub_left = Excluded(right_found); debug!("calling right with {:?} to {:?} (level {})", sub_left, right, deeper_level); - self.explore_facet_levels(field_id, deeper_level, sub_left, right, output)?; + self.explore_facet_levels::(field_id, deeper_level, sub_left, right, output)?; } }, None => { // If we found nothing at this level it means that we must find // the same bounds but at a deeper, more precise level. - self.explore_facet_levels(field_id, deeper_level, left, right, output)?; + self.explore_facet_levels::(field_id, deeper_level, left, right, output)?; }, } @@ -327,10 +345,10 @@ impl<'a> Search<'a> { }; // We create the original candidates with the facet conditions results. + use FacetOperator::*; let facet_candidates = match self.facet_condition { - Some(FacetCondition::Operator(fid, operator)) => { - use FacetOperator::*; - + // TODO make that generic over floats and integers. + Some(FacetCondition::OperatorI64(fid, operator)) => { // Make sure we always bound the ranges with the field id and the level, // as the facets values are all in the same database and prefixed by the // field id and the level. @@ -357,7 +375,40 @@ impl<'a> Search<'a> { match biggest_level { Some(level) => { let mut output = RoaringBitmap::new(); - self.explore_facet_levels(fid, level, left, right, &mut output)?; + self.explore_facet_levels::(fid, level, left, right, &mut output)?; + Some(output) + }, + None => None, + } + }, + Some(FacetCondition::OperatorF64(fid, operator)) => { + // Make sure we always bound the ranges with the field id and the level, + // as the facets values are all in the same database and prefixed by the + // field id and the level. + let (left, right) = match operator { + GreaterThan(val) => (Excluded(val), Included(f64::MAX)), + GreaterThanOrEqual(val) => (Included(val), Included(f64::MAX)), + LowerThan(val) => (Included(f64::MIN), Excluded(val)), + LowerThanOrEqual(val) => (Included(f64::MIN), Included(val)), + Equal(val) => (Included(val), Included(val)), + Between(left, right) => (Included(left), Included(right)), + }; + + let db = self.index + .facet_field_id_value_docids + .remap_key_type::(); + + // Ask for the biggest value that can exist for this specific field, if it exists + // that's fine if it don't, the value just before will be returned instead. + let biggest_level = db + .remap_data_type::() + .get_lower_than_or_equal_to(self.rtxn, &(fid, u8::MAX, f64::MAX, f64::MAX))? + .and_then(|((id, level, _, _), _)| if id == fid { Some(level) } else { None }); + + match biggest_level { + Some(level) => { + let mut output = RoaringBitmap::new(); + self.explore_facet_levels::(fid, level, left, right, &mut output)?; Some(output) }, None => None, diff --git a/src/update/facet_levels.rs b/src/update/facet_levels.rs index bc8f7121f..4a7769b7a 100644 --- a/src/update/facet_levels.rs +++ b/src/update/facet_levels.rs @@ -6,10 +6,12 @@ use heed::types::{ByteSlice, DecodeIgnore}; use heed::{BytesEncode, Error}; use itertools::Itertools; use log::debug; +use num_traits::{Bounded, Zero}; use roaring::RoaringBitmap; use crate::facet::FacetType; -use crate::heed_codec::{facet::FacetLevelValueI64Codec, CboRoaringBitmapCodec}; +use crate::heed_codec::CboRoaringBitmapCodec; +use crate::heed_codec::facet::{FacetLevelValueI64Codec, FacetLevelValueF64Codec}; use crate::Index; use crate::update::index_documents::WriteMethod; use crate::update::index_documents::{create_writer, writer_into_reader, write_into_lmdb_database}; @@ -68,26 +70,47 @@ impl<'t, 'u, 'i> FacetLevels<'t, 'u, 'i> { debug!("Computing and writing the facet values levels docids into LMDB on disk..."); for (field_id, facet_type) in faceted_fields { - if facet_type == FacetType::String { continue } + let content = match facet_type { + FacetType::Integer => { + clear_field_levels::( + self.wtxn, + self.index.facet_field_id_value_docids, + field_id, + )?; - clear_field_levels( - self.wtxn, - self.index.facet_field_id_value_docids, - field_id, - )?; + compute_facet_levels::( + self.wtxn, + self.index.facet_field_id_value_docids, + self.chunk_compression_type, + self.chunk_compression_level, + self.chunk_fusing_shrink_size, + self.last_level_size, + self.number_of_levels, + self.easing_function, + field_id, + )? + }, + FacetType::Float => { + clear_field_levels::( + self.wtxn, + self.index.facet_field_id_value_docids, + field_id, + )?; - let content = compute_facet_levels( - self.wtxn, - self.index.facet_field_id_value_docids, - self.chunk_compression_type, - self.chunk_compression_level, - self.chunk_fusing_shrink_size, - self.last_level_size, - self.number_of_levels, - self.easing_function, - field_id, - facet_type, - )?; + compute_facet_levels::( + self.wtxn, + self.index.facet_field_id_value_docids, + self.chunk_compression_type, + self.chunk_compression_level, + self.chunk_fusing_shrink_size, + self.last_level_size, + self.number_of_levels, + self.easing_function, + field_id, + )? + }, + FacetType::String => continue, + }; write_into_lmdb_database( self.wtxn, @@ -102,20 +125,26 @@ impl<'t, 'u, 'i> FacetLevels<'t, 'u, 'i> { } } -fn clear_field_levels( - wtxn: &mut heed::RwTxn, +fn clear_field_levels<'t, T: 't, KC>( + wtxn: &'t mut heed::RwTxn, db: heed::Database, field_id: u8, ) -> heed::Result<()> +where + T: Copy + Bounded, + KC: heed::BytesDecode<'t, DItem = (u8, u8, T, T)>, + KC: for<'x> heed::BytesEncode<'x, EItem = (u8, u8, T, T)>, { - let range = (field_id, 1, i64::MIN, i64::MIN)..=(field_id, u8::MAX, i64::MAX, i64::MAX); - db.remap_key_type::() + let left = (field_id, 1, T::min_value(), T::min_value()); + let right = (field_id, u8::MAX, T::max_value(), T::max_value()); + let range = left..=right; + db.remap_key_type::() .delete_range(wtxn, &range) .map(drop) } -fn compute_facet_levels( - rtxn: &heed::RoTxn, +fn compute_facet_levels<'t, T: 't, KC>( + rtxn: &'t heed::RoTxn, db: heed::Database, compression_type: CompressionType, compression_level: Option, @@ -124,8 +153,11 @@ fn compute_facet_levels( number_of_levels: NonZeroUsize, easing_function: EasingName, field_id: u8, - facet_type: FacetType, ) -> anyhow::Result> +where + T: Copy + PartialEq + PartialOrd + Bounded + Zero, + KC: heed::BytesDecode<'t, DItem = (u8, u8, T, T)>, + KC: for<'x> heed::BytesEncode<'x, EItem = (u8, u8, T, T)>, { let first_level_size = db.prefix_iter(rtxn, &[field_id])? .remap_types::() @@ -137,7 +169,12 @@ fn compute_facet_levels( create_writer(compression_type, compression_level, file) })?; - let level_0_range = (field_id, 0, i64::MIN, i64::MIN)..=(field_id, 0, i64::MAX, i64::MAX); + let level_0_range = { + let left = (field_id, 0, T::min_value(), T::min_value()); + let right = (field_id, 0, T::max_value(), T::max_value()); + left..=right + }; + let level_sizes_iter = levels_iterator(first_level_size, last_level_size.get(), number_of_levels.get(), easing_function) .map(|size| (first_level_size as f64 / size as f64).ceil() as usize) @@ -147,13 +184,11 @@ fn compute_facet_levels( // TODO we must not create levels with identical group sizes. for (level, level_entry_sizes) in level_sizes_iter { - let mut left = 0; - let mut right = 0; + let mut left = T::zero(); + let mut right = T::zero(); let mut group_docids = RoaringBitmap::new(); - dbg!(level, level_entry_sizes, first_level_size); - - let db = db.remap_key_type::(); + let db = db.remap_key_type::(); for (i, result) in db.range(rtxn, &level_0_range)?.enumerate() { let ((_field_id, _level, value, _right), docids) = result?; @@ -162,7 +197,7 @@ fn compute_facet_levels( } else if i % level_entry_sizes == 0 { // we found the first bound of the next group, we must store the left // and right bounds associated with the docids. - write_entry(&mut writer, field_id, level as u8, left, right, &group_docids)?; + write_entry::(&mut writer, field_id, level as u8, left, right, &group_docids)?; // We save the left bound for the new group and also reset the docids. group_docids = RoaringBitmap::new(); @@ -175,24 +210,26 @@ fn compute_facet_levels( } if !group_docids.is_empty() { - write_entry(&mut writer, field_id, level as u8, left, right, &group_docids)?; + write_entry::(&mut writer, field_id, level as u8, left, right, &group_docids)?; } } writer_into_reader(writer, shrink_size) } -fn write_entry( +fn write_entry( writer: &mut Writer, field_id: u8, level: u8, - left: i64, - right: i64, + left: T, + right: T, ids: &RoaringBitmap, ) -> anyhow::Result<()> +where + KC: for<'x> heed::BytesEncode<'x, EItem = (u8, u8, T, T)>, { let key = (field_id, level, left, right); - let key = FacetLevelValueI64Codec::bytes_encode(&key).ok_or(Error::Encoding)?; + let key = KC::bytes_encode(&key).ok_or(Error::Encoding)?; let data = CboRoaringBitmapCodec::bytes_encode(&ids).ok_or(Error::Encoding)?; writer.insert(&key, &data)?; Ok(())