From 89d075871362222de1205af4a1ebfa6a80a29d08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Mon, 23 Aug 2021 11:37:18 +0200 Subject: [PATCH] Revert "Revert "Sort at query time"" --- benchmarks/benches/search_songs.rs | 8 +- http-ui/src/main.rs | 4 +- milli/Cargo.toml | 1 - milli/src/criterion.rs | 69 ++++-- milli/src/error.rs | 10 + milli/src/index.rs | 36 ++- milli/src/lib.rs | 2 +- milli/src/search/criteria/asc_desc.rs | 102 ++++++-- milli/src/search/criteria/mod.rs | 26 +- milli/src/search/facet/facet_string.rs | 328 +++++++++++++++++++++---- milli/src/search/mod.rs | 34 ++- milli/src/update/settings.rs | 30 ++- milli/tests/assets/test_set.ndjson | 34 +-- milli/tests/search/distinct.rs | 2 +- milli/tests/search/filters.rs | 2 +- milli/tests/search/mod.rs | 23 +- milli/tests/search/query_criteria.rs | 138 +++++++++-- 17 files changed, 701 insertions(+), 148 deletions(-) diff --git a/benchmarks/benches/search_songs.rs b/benchmarks/benches/search_songs.rs index 726040692..6b11799ec 100644 --- a/benchmarks/benches/search_songs.rs +++ b/benchmarks/benches/search_songs.rs @@ -52,9 +52,9 @@ fn bench_songs(c: &mut criterion::Criterion) { milli::default_criteria().iter().map(|criteria| criteria.to_string()).collect(); let default_criterion = default_criterion.iter().map(|s| s.as_str()); let asc_default: Vec<&str> = - std::iter::once("asc(released-timestamp)").chain(default_criterion.clone()).collect(); + std::iter::once("released-timestamp:asc").chain(default_criterion.clone()).collect(); let desc_default: Vec<&str> = - std::iter::once("desc(released-timestamp)").chain(default_criterion.clone()).collect(); + std::iter::once("released-timestamp:desc").chain(default_criterion.clone()).collect(); let basic_with_quote: Vec = BASE_CONF .queries @@ -118,12 +118,12 @@ fn bench_songs(c: &mut criterion::Criterion) { }, utils::Conf { group_name: "asc", - criterion: Some(&["asc(released-timestamp)"]), + criterion: Some(&["released-timestamp:desc"]), ..BASE_CONF }, utils::Conf { group_name: "desc", - criterion: Some(&["desc(released-timestamp)"]), + criterion: Some(&["released-timestamp:desc"]), ..BASE_CONF }, diff --git a/http-ui/src/main.rs b/http-ui/src/main.rs index ee32882c0..b34418465 100644 --- a/http-ui/src/main.rs +++ b/http-ui/src/main.rs @@ -1030,7 +1030,7 @@ mod tests { displayed_attributes: Setting::Set(vec!["name".to_string()]), searchable_attributes: Setting::Set(vec!["age".to_string()]), filterable_attributes: Setting::Set(hashset! { "age".to_string() }), - criteria: Setting::Set(vec!["asc(age)".to_string()]), + criteria: Setting::Set(vec!["age:asc".to_string()]), stop_words: Setting::Set(btreeset! { "and".to_string() }), synonyms: Setting::Set(hashmap! { "alex".to_string() => vec!["alexey".to_string()] }), }; @@ -1058,7 +1058,7 @@ mod tests { Token::Str("criteria"), Token::Some, Token::Seq { len: Some(1) }, - Token::Str("asc(age)"), + Token::Str("age:asc"), Token::SeqEnd, Token::Str("stopWords"), Token::Some, diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 4bec3d69d..0c6fc6763 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -25,7 +25,6 @@ obkv = "0.2.0" once_cell = "1.5.2" ordered-float = "2.1.1" rayon = "1.5.0" -regex = "1.4.3" roaring = "0.6.6" serde = { version = "1.0.123", features = ["derive"] } serde_json = { version = "1.0.62", features = ["preserve_order"] } diff --git a/milli/src/criterion.rs b/milli/src/criterion.rs index cc1fca01f..47eb7c7dc 100644 --- a/milli/src/criterion.rs +++ b/milli/src/criterion.rs @@ -1,15 +1,10 @@ use std::fmt; use std::str::FromStr; -use once_cell::sync::Lazy; -use regex::Regex; use serde::{Deserialize, Serialize}; use crate::error::{Error, UserError}; -static ASC_DESC_REGEX: Lazy = - Lazy::new(|| Regex::new(r#"(asc|desc)\(([\w_-]+)\)"#).unwrap()); - #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub enum Criterion { /// Sorted by decreasing number of matched query terms. @@ -17,10 +12,13 @@ pub enum Criterion { Words, /// Sorted by increasing number of typos. Typo, + /// Dynamically sort at query time the documents. None, one or multiple Asc/Desc sortable + /// attributes can be used in place of this criterion at query time. + Sort, /// Sorted by increasing distance between matched query terms. Proximity, /// Documents with quey words contained in more important - /// attributes are considred better. + /// attributes are considered better. Attribute, /// Sorted by the similarity of the matched words with the query words. Exactness, @@ -43,29 +41,46 @@ impl Criterion { impl FromStr for Criterion { type Err = Error; - fn from_str(txt: &str) -> Result { - match txt { + fn from_str(text: &str) -> Result { + match text { "words" => Ok(Criterion::Words), "typo" => Ok(Criterion::Typo), + "sort" => Ok(Criterion::Sort), "proximity" => Ok(Criterion::Proximity), "attribute" => Ok(Criterion::Attribute), "exactness" => Ok(Criterion::Exactness), - text => { - let caps = ASC_DESC_REGEX - .captures(text) - .ok_or_else(|| UserError::InvalidCriterionName { name: text.to_string() })?; - let order = caps.get(1).unwrap().as_str(); - let field_name = caps.get(2).unwrap().as_str(); - match order { - "asc" => Ok(Criterion::Asc(field_name.to_string())), - "desc" => Ok(Criterion::Desc(field_name.to_string())), - text => { - return Err( - UserError::InvalidCriterionName { name: text.to_string() }.into() - ) - } - } - } + text => match AscDesc::from_str(text) { + Ok(AscDesc::Asc(field)) => Ok(Criterion::Asc(field)), + Ok(AscDesc::Desc(field)) => Ok(Criterion::Desc(field)), + Err(error) => Err(error.into()), + }, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +pub enum AscDesc { + Asc(String), + Desc(String), +} + +impl AscDesc { + pub fn field(&self) -> &str { + match self { + AscDesc::Asc(field) => field, + AscDesc::Desc(field) => field, + } + } +} + +impl FromStr for AscDesc { + type Err = UserError; + + fn from_str(text: &str) -> Result { + match text.rsplit_once(':') { + Some((field_name, "asc")) => Ok(AscDesc::Asc(field_name.to_string())), + Some((field_name, "desc")) => Ok(AscDesc::Desc(field_name.to_string())), + _ => Err(UserError::InvalidCriterionName { name: text.to_string() }), } } } @@ -74,6 +89,7 @@ pub fn default_criteria() -> Vec { vec![ Criterion::Words, Criterion::Typo, + Criterion::Sort, Criterion::Proximity, Criterion::Attribute, Criterion::Exactness, @@ -87,11 +103,12 @@ impl fmt::Display for Criterion { match self { Words => f.write_str("words"), Typo => f.write_str("typo"), + Sort => f.write_str("sort"), Proximity => f.write_str("proximity"), Attribute => f.write_str("attribute"), Exactness => f.write_str("exactness"), - Asc(attr) => write!(f, "asc({})", attr), - Desc(attr) => write!(f, "desc({})", attr), + Asc(attr) => write!(f, "{}:asc", attr), + Desc(attr) => write!(f, "{}:desc", attr), } } } diff --git a/milli/src/error.rs b/milli/src/error.rs index 713935869..9bda74631 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -58,6 +58,7 @@ pub enum UserError { InvalidFacetsDistribution { invalid_facets_name: HashSet }, InvalidFilter(pest::error::Error), InvalidFilterAttribute(pest::error::Error), + InvalidSortableAttribute { field: String, valid_fields: HashSet }, InvalidStoreFile, MaxDatabaseSizeReached, MissingDocumentId { document: Object }, @@ -226,6 +227,15 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco ) } Self::InvalidFilterAttribute(error) => error.fmt(f), + Self::InvalidSortableAttribute { field, valid_fields } => { + let valid_names = + valid_fields.iter().map(AsRef::as_ref).collect::>().join(", "); + write!( + f, + "Attribute {} is not sortable, available sortable attributes are: {}", + field, valid_names + ) + } Self::MissingDocumentId { document } => { let json = serde_json::to_string(document).unwrap(); write!(f, "document doesn't have an identifier {}", json) diff --git a/milli/src/index.rs b/milli/src/index.rs index 120bcbadf..e2ab51a1c 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -28,6 +28,7 @@ pub mod main_key { pub const DISTINCT_FIELD_KEY: &str = "distinct-field-key"; pub const DOCUMENTS_IDS_KEY: &str = "documents-ids"; pub const FILTERABLE_FIELDS_KEY: &str = "filterable-fields"; + pub const SORTABLE_FIELDS_KEY: &str = "sortable-fields"; pub const FIELD_DISTRIBUTION_KEY: &str = "fields-distribution"; pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map"; pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids"; @@ -446,13 +447,45 @@ impl Index { Ok(fields_ids) } + /* sortable fields */ + + /// Writes the sortable fields names in the database. + pub(crate) fn put_sortable_fields( + &self, + wtxn: &mut RwTxn, + fields: &HashSet, + ) -> heed::Result<()> { + self.main.put::<_, Str, SerdeJson<_>>(wtxn, main_key::SORTABLE_FIELDS_KEY, fields) + } + + /// Deletes the sortable fields ids in the database. + pub(crate) fn delete_sortable_fields(&self, wtxn: &mut RwTxn) -> heed::Result { + self.main.delete::<_, Str>(wtxn, main_key::SORTABLE_FIELDS_KEY) + } + + /// Returns the sortable fields names. + pub fn sortable_fields(&self, rtxn: &RoTxn) -> heed::Result> { + Ok(self + .main + .get::<_, Str, SerdeJson<_>>(rtxn, main_key::SORTABLE_FIELDS_KEY)? + .unwrap_or_default()) + } + + /// Identical to `sortable_fields`, but returns ids instead. + pub fn sortable_fields_ids(&self, rtxn: &RoTxn) -> Result> { + let fields = self.sortable_fields(rtxn)?; + let fields_ids_map = self.fields_ids_map(rtxn)?; + Ok(fields.into_iter().filter_map(|name| fields_ids_map.id(&name)).collect()) + } + /* faceted documents ids */ /// Returns the faceted fields names. /// - /// Faceted fields are the union of all the filterable, distinct, and Asc/Desc fields. + /// Faceted fields are the union of all the filterable, sortable, distinct, and Asc/Desc fields. pub fn faceted_fields(&self, rtxn: &RoTxn) -> Result> { let filterable_fields = self.filterable_fields(rtxn)?; + let sortable_fields = self.sortable_fields(rtxn)?; let distinct_field = self.distinct_field(rtxn)?; let asc_desc_fields = self.criteria(rtxn)?.into_iter().filter_map(|criterion| match criterion { @@ -461,6 +494,7 @@ impl Index { }); let mut faceted_fields = filterable_fields; + faceted_fields.extend(sortable_fields); faceted_fields.extend(asc_desc_fields); if let Some(field) = distinct_field { faceted_fields.insert(field.to_owned()); diff --git a/milli/src/lib.rs b/milli/src/lib.rs index f3bababf6..2b0bd2ed4 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -22,7 +22,7 @@ use std::result::Result as StdResult; use fxhash::{FxHasher32, FxHasher64}; use serde_json::{Map, Value}; -pub use self::criterion::{default_criteria, Criterion}; +pub use self::criterion::{default_criteria, AscDesc, Criterion}; pub use self::error::{ Error, FieldIdMapMissingEntry, InternalError, SerializationError, UserError, }; diff --git a/milli/src/search/criteria/asc_desc.rs b/milli/src/search/criteria/asc_desc.rs index 4a664d042..6d50c1bb5 100644 --- a/milli/src/search/criteria/asc_desc.rs +++ b/milli/src/search/criteria/asc_desc.rs @@ -7,7 +7,7 @@ use roaring::RoaringBitmap; use super::{Criterion, CriterionParameters, CriterionResult}; use crate::search::criteria::{resolve_query_tree, CriteriaBuilder}; -use crate::search::facet::FacetNumberIter; +use crate::search::facet::{FacetNumberIter, FacetStringIter}; use crate::search::query_tree::Operation; use crate::{FieldId, Index, Result}; @@ -20,7 +20,7 @@ pub struct AscDesc<'t> { rtxn: &'t heed::RoTxn<'t>, field_name: String, field_id: Option, - ascending: bool, + is_ascending: bool, query_tree: Option, candidates: Box> + 't>, allowed_candidates: RoaringBitmap, @@ -53,12 +53,16 @@ impl<'t> AscDesc<'t> { rtxn: &'t heed::RoTxn, parent: Box, field_name: String, - ascending: bool, + is_ascending: bool, ) -> Result { let fields_ids_map = index.fields_ids_map(rtxn)?; let field_id = fields_ids_map.id(&field_name); let faceted_candidates = match field_id { - Some(field_id) => index.number_faceted_documents_ids(rtxn, field_id)?, + Some(field_id) => { + let number_faceted = index.number_faceted_documents_ids(rtxn, field_id)?; + let string_faceted = index.string_faceted_documents_ids(rtxn, field_id)?; + number_faceted | string_faceted + } None => RoaringBitmap::default(), }; @@ -67,7 +71,7 @@ impl<'t> AscDesc<'t> { rtxn, field_name, field_id, - ascending, + is_ascending, query_tree: None, candidates: Box::new(std::iter::empty()), allowed_candidates: RoaringBitmap::new(), @@ -87,7 +91,7 @@ impl<'t> Criterion for AscDesc<'t> { loop { debug!( "Facet {}({}) iteration", - if self.ascending { "Asc" } else { "Desc" }, + if self.is_ascending { "Asc" } else { "Desc" }, self.field_name ); @@ -136,7 +140,7 @@ impl<'t> Criterion for AscDesc<'t> { self.index, self.rtxn, field_id, - self.ascending, + self.is_ascending, candidates & &self.faceted_candidates, )?, None => Box::new(std::iter::empty()), @@ -167,31 +171,49 @@ fn facet_ordered<'t>( index: &'t Index, rtxn: &'t heed::RoTxn, field_id: FieldId, - ascending: bool, + is_ascending: bool, candidates: RoaringBitmap, ) -> Result> + 't>> { if candidates.len() <= CANDIDATES_THRESHOLD { - let iter = iterative_facet_ordered_iter(index, rtxn, field_id, ascending, candidates)?; - Ok(Box::new(iter.map(Ok)) as Box>) + let number_iter = iterative_facet_number_ordered_iter( + index, + rtxn, + field_id, + is_ascending, + candidates.clone(), + )?; + let string_iter = + iterative_facet_string_ordered_iter(index, rtxn, field_id, is_ascending, candidates)?; + Ok(Box::new(number_iter.chain(string_iter).map(Ok)) as Box>) } else { - let facet_fn = if ascending { + let facet_number_fn = if is_ascending { FacetNumberIter::new_reducing } else { FacetNumberIter::new_reverse_reducing }; - let iter = facet_fn(rtxn, index, field_id, candidates)?; - Ok(Box::new(iter.map(|res| res.map(|(_, docids)| docids)))) + let number_iter = facet_number_fn(rtxn, index, field_id, candidates.clone())? + .map(|res| res.map(|(_, docids)| docids)); + + let facet_string_fn = if is_ascending { + FacetStringIter::new_reducing + } else { + FacetStringIter::new_reverse_reducing + }; + let string_iter = facet_string_fn(rtxn, index, field_id, candidates)? + .map(|res| res.map(|(_, _, docids)| docids)); + + Ok(Box::new(number_iter.chain(string_iter))) } } -/// Fetch the whole list of candidates facet values one by one and order them by it. +/// Fetch the whole list of candidates facet number values one by one and order them by it. /// /// This function is fast when the amount of candidates to rank is small. -fn iterative_facet_ordered_iter<'t>( +fn iterative_facet_number_ordered_iter<'t>( index: &'t Index, rtxn: &'t heed::RoTxn, field_id: FieldId, - ascending: bool, + is_ascending: bool, candidates: RoaringBitmap, ) -> Result + 't> { let mut docids_values = Vec::with_capacity(candidates.len() as usize); @@ -199,14 +221,14 @@ fn iterative_facet_ordered_iter<'t>( let left = (field_id, docid, f64::MIN); let right = (field_id, docid, f64::MAX); let mut iter = index.field_id_docid_facet_f64s.range(rtxn, &(left..=right))?; - let entry = if ascending { iter.next() } else { iter.last() }; + let entry = if is_ascending { iter.next() } else { iter.last() }; if let Some(((_, _, value), ())) = entry.transpose()? { docids_values.push((docid, OrderedFloat(value))); } } docids_values.sort_unstable_by_key(|(_, v)| *v); let iter = docids_values.into_iter(); - let iter = if ascending { + let iter = if is_ascending { Box::new(iter) as Box> } else { Box::new(iter.rev()) @@ -216,7 +238,49 @@ fn iterative_facet_ordered_iter<'t>( // required to collect the result into an owned collection (a Vec). // https://github.com/rust-itertools/itertools/issues/499 let vec: Vec<_> = iter - .group_by(|(_, v)| v.clone()) + .group_by(|(_, v)| *v) + .into_iter() + .map(|(_, ids)| ids.map(|(id, _)| id).collect()) + .collect(); + + Ok(vec.into_iter()) +} + +/// Fetch the whole list of candidates facet string values one by one and order them by it. +/// +/// This function is fast when the amount of candidates to rank is small. +fn iterative_facet_string_ordered_iter<'t>( + index: &'t Index, + rtxn: &'t heed::RoTxn, + field_id: FieldId, + is_ascending: bool, + candidates: RoaringBitmap, +) -> Result + 't> { + let mut docids_values = Vec::with_capacity(candidates.len() as usize); + for docid in candidates.iter() { + let left = (field_id, docid, ""); + let right = (field_id, docid.saturating_add(1), ""); + // FIXME Doing this means that it will never be possible to retrieve + // the document with id 2^32, not sure this is a real problem. + let mut iter = index.field_id_docid_facet_strings.range(rtxn, &(left..right))?; + let entry = if is_ascending { iter.next() } else { iter.last() }; + if let Some(((_, _, value), _)) = entry.transpose()? { + docids_values.push((docid, value)); + } + } + docids_values.sort_unstable_by_key(|(_, v)| *v); + let iter = docids_values.into_iter(); + let iter = if is_ascending { + Box::new(iter) as Box> + } else { + Box::new(iter.rev()) + }; + + // The itertools GroupBy iterator doesn't provide an owned version, we are therefore + // required to collect the result into an owned collection (a Vec). + // https://github.com/rust-itertools/itertools/issues/499 + let vec: Vec<_> = iter + .group_by(|(_, v)| *v) .into_iter() .map(|(_, ids)| ids.map(|(id, _)| id).collect()) .collect(); diff --git a/milli/src/search/criteria/mod.rs b/milli/src/search/criteria/mod.rs index 2ba3b388f..61b0fe049 100644 --- a/milli/src/search/criteria/mod.rs +++ b/milli/src/search/criteria/mod.rs @@ -12,6 +12,7 @@ use self::r#final::Final; use self::typo::Typo; use self::words::Words; use super::query_tree::{Operation, PrimitiveQueryPart, Query, QueryKind}; +use crate::criterion::AscDesc as AscDescName; use crate::search::{word_derivations, WordDerivationsCache}; use crate::{DocumentId, FieldId, Index, Result, TreeLevel}; @@ -273,6 +274,7 @@ impl<'t> CriteriaBuilder<'t> { query_tree: Option, primitive_query: Option>, filtered_candidates: Option, + sort_criteria: Option>, ) -> Result> { use crate::criterion::Criterion as Name; @@ -282,8 +284,30 @@ impl<'t> CriteriaBuilder<'t> { Box::new(Initial::new(query_tree, filtered_candidates)) as Box; for name in self.index.criteria(&self.rtxn)? { criterion = match name { - Name::Typo => Box::new(Typo::new(self, criterion)), Name::Words => Box::new(Words::new(self, criterion)), + Name::Typo => Box::new(Typo::new(self, criterion)), + Name::Sort => match sort_criteria { + Some(ref sort_criteria) => { + for asc_desc in sort_criteria { + criterion = match asc_desc { + AscDescName::Asc(field) => Box::new(AscDesc::asc( + &self.index, + &self.rtxn, + criterion, + field.to_string(), + )?), + AscDescName::Desc(field) => Box::new(AscDesc::desc( + &self.index, + &self.rtxn, + criterion, + field.to_string(), + )?), + }; + } + criterion + } + None => criterion, + }, Name::Proximity => Box::new(Proximity::new(self, criterion)), Name::Attribute => Box::new(Attribute::new(self, criterion)), Name::Exactness => Box::new(Exactness::new(self, criterion, &primitive_query)?), diff --git a/milli/src/search/facet/facet_string.rs b/milli/src/search/facet/facet_string.rs index ed5322607..927602c98 100644 --- a/milli/src/search/facet/facet_string.rs +++ b/milli/src/search/facet/facet_string.rs @@ -131,7 +131,7 @@ use std::ops::Bound::{Excluded, Included, Unbounded}; use either::{Either, Left, Right}; use heed::types::{ByteSlice, DecodeIgnore}; -use heed::{Database, LazyDecode, RoRange}; +use heed::{Database, LazyDecode, RoRange, RoRevRange}; use roaring::RoaringBitmap; use crate::heed_codec::facet::{ @@ -206,6 +206,65 @@ impl<'t> Iterator for FacetStringGroupRange<'t> { } } +pub struct FacetStringGroupRevRange<'t> { + iter: RoRevRange< + 't, + FacetLevelValueU32Codec, + LazyDecode>, + >, + end: Bound, +} + +impl<'t> FacetStringGroupRevRange<'t> { + pub fn new( + rtxn: &'t heed::RoTxn, + db: Database, + field_id: FieldId, + level: NonZeroU8, + left: Bound, + right: Bound, + ) -> heed::Result> { + let db = db.remap_types::< + FacetLevelValueU32Codec, + FacetStringZeroBoundsValueCodec, + >(); + let left_bound = match left { + Included(left) => Included((field_id, level, left, u32::MIN)), + Excluded(left) => Excluded((field_id, level, left, u32::MIN)), + Unbounded => Included((field_id, level, u32::MIN, u32::MIN)), + }; + let right_bound = Included((field_id, level, u32::MAX, u32::MAX)); + let iter = db.lazily_decode_data().rev_range(rtxn, &(left_bound, right_bound))?; + Ok(FacetStringGroupRevRange { iter, end: right }) + } +} + +impl<'t> Iterator for FacetStringGroupRevRange<'t> { + type Item = heed::Result<((NonZeroU8, u32, u32), (Option<(&'t str, &'t str)>, RoaringBitmap))>; + + fn next(&mut self) -> Option { + match self.iter.next() { + Some(Ok(((_fid, level, left, right), docids))) => { + let must_be_returned = match self.end { + Included(end) => right <= end, + Excluded(end) => right < end, + Unbounded => true, + }; + if must_be_returned { + match docids.decode() { + Ok((bounds, docids)) => Some(Ok(((level, left, right), (bounds, docids)))), + Err(e) => Some(Err(e)), + } + } else { + None + } + } + Some(Err(e)) => Some(Err(e)), + None => None, + } + } +} + /// An iterator that is used to explore the level 0 of the facets string database. /// /// It yields the facet string and the roaring bitmap associated with it. @@ -280,6 +339,81 @@ impl<'t> Iterator for FacetStringLevelZeroRange<'t> { } } +pub struct FacetStringLevelZeroRevRange<'t> { + iter: RoRevRange< + 't, + FacetStringLevelZeroCodec, + FacetStringLevelZeroValueCodec, + >, +} + +impl<'t> FacetStringLevelZeroRevRange<'t> { + pub fn new( + rtxn: &'t heed::RoTxn, + db: Database, + field_id: FieldId, + left: Bound<&str>, + right: Bound<&str>, + ) -> heed::Result> { + fn encode_value<'a>(buffer: &'a mut Vec, field_id: FieldId, value: &str) -> &'a [u8] { + buffer.extend_from_slice(&field_id.to_be_bytes()); + buffer.push(0); + buffer.extend_from_slice(value.as_bytes()); + &buffer[..] + } + + let mut left_buffer = Vec::new(); + let left_bound = match left { + Included(value) => Included(encode_value(&mut left_buffer, field_id, value)), + Excluded(value) => Excluded(encode_value(&mut left_buffer, field_id, value)), + Unbounded => { + left_buffer.extend_from_slice(&field_id.to_be_bytes()); + left_buffer.push(0); + Included(&left_buffer[..]) + } + }; + + let mut right_buffer = Vec::new(); + let right_bound = match right { + Included(value) => Included(encode_value(&mut right_buffer, field_id, value)), + Excluded(value) => Excluded(encode_value(&mut right_buffer, field_id, value)), + Unbounded => { + right_buffer.extend_from_slice(&field_id.to_be_bytes()); + right_buffer.push(1); // we must only get the level 0 + Excluded(&right_buffer[..]) + } + }; + + let iter = db + .remap_key_type::() + .rev_range(rtxn, &(left_bound, right_bound))? + .remap_types::< + FacetStringLevelZeroCodec, + FacetStringLevelZeroValueCodec + >(); + + Ok(FacetStringLevelZeroRevRange { iter }) + } +} + +impl<'t> Iterator for FacetStringLevelZeroRevRange<'t> { + type Item = heed::Result<(&'t str, &'t str, RoaringBitmap)>; + + fn next(&mut self) -> Option { + match self.iter.next() { + Some(Ok(((_fid, normalized), (original, docids)))) => { + Some(Ok((normalized, original, docids))) + } + Some(Err(e)) => Some(Err(e)), + None => None, + } + } +} + +type EitherStringRange<'t> = Either, FacetStringLevelZeroRange<'t>>; +type EitherStringRevRange<'t> = + Either, FacetStringLevelZeroRevRange<'t>>; + /// An iterator that is used to explore the facet strings level by level, /// it will only return facets strings that are associated with the /// candidates documents ids given. @@ -287,12 +421,45 @@ pub struct FacetStringIter<'t> { rtxn: &'t heed::RoTxn<'t>, db: Database, field_id: FieldId, - level_iters: - Vec<(RoaringBitmap, Either, FacetStringLevelZeroRange<'t>>)>, + level_iters: Vec<(RoaringBitmap, Either, EitherStringRevRange<'t>>)>, must_reduce: bool, } impl<'t> FacetStringIter<'t> { + pub fn new_reducing( + rtxn: &'t heed::RoTxn, + index: &'t Index, + field_id: FieldId, + documents_ids: RoaringBitmap, + ) -> heed::Result> { + let db = index.facet_id_string_docids.remap_types::(); + let highest_iter = Self::highest_iter(rtxn, index, db, field_id)?; + Ok(FacetStringIter { + rtxn, + db, + field_id, + level_iters: vec![(documents_ids, Left(highest_iter))], + must_reduce: true, + }) + } + + pub fn new_reverse_reducing( + rtxn: &'t heed::RoTxn, + index: &'t Index, + field_id: FieldId, + documents_ids: RoaringBitmap, + ) -> heed::Result> { + let db = index.facet_id_string_docids.remap_types::(); + let highest_reverse_iter = Self::highest_reverse_iter(rtxn, index, db, field_id)?; + Ok(FacetStringIter { + rtxn, + db, + field_id, + level_iters: vec![(documents_ids, Right(highest_reverse_iter))], + must_reduce: true, + }) + } + pub fn new_non_reducing( rtxn: &'t heed::RoTxn, index: &'t Index, @@ -300,30 +467,12 @@ impl<'t> FacetStringIter<'t> { documents_ids: RoaringBitmap, ) -> heed::Result> { let db = index.facet_id_string_docids.remap_types::(); - let highest_level = Self::highest_level(rtxn, db, field_id)?.unwrap_or(0); - let highest_iter = match NonZeroU8::new(highest_level) { - Some(highest_level) => Left(FacetStringGroupRange::new( - rtxn, - index.facet_id_string_docids, - field_id, - highest_level, - Unbounded, - Unbounded, - )?), - None => Right(FacetStringLevelZeroRange::new( - rtxn, - index.facet_id_string_docids, - field_id, - Unbounded, - Unbounded, - )?), - }; - + let highest_iter = Self::highest_iter(rtxn, index, db, field_id)?; Ok(FacetStringIter { rtxn, db, field_id, - level_iters: vec![(documents_ids, highest_iter)], + level_iters: vec![(documents_ids, Left(highest_iter))], must_reduce: false, }) } @@ -340,6 +489,62 @@ impl<'t> FacetStringIter<'t> { .transpose()? .map(|(key_bytes, _)| key_bytes[2])) // the level is the third bit } + + fn highest_iter( + rtxn: &'t heed::RoTxn, + index: &'t Index, + db: Database, + field_id: FieldId, + ) -> heed::Result, FacetStringLevelZeroRange<'t>>> { + let highest_level = Self::highest_level(rtxn, db, field_id)?.unwrap_or(0); + match NonZeroU8::new(highest_level) { + Some(highest_level) => FacetStringGroupRange::new( + rtxn, + index.facet_id_string_docids, + field_id, + highest_level, + Unbounded, + Unbounded, + ) + .map(Left), + None => FacetStringLevelZeroRange::new( + rtxn, + index.facet_id_string_docids, + field_id, + Unbounded, + Unbounded, + ) + .map(Right), + } + } + + fn highest_reverse_iter( + rtxn: &'t heed::RoTxn, + index: &'t Index, + db: Database, + field_id: FieldId, + ) -> heed::Result, FacetStringLevelZeroRevRange<'t>>> { + let highest_level = Self::highest_level(rtxn, db, field_id)?.unwrap_or(0); + match NonZeroU8::new(highest_level) { + Some(highest_level) => FacetStringGroupRevRange::new( + rtxn, + index.facet_id_string_docids, + field_id, + highest_level, + Unbounded, + Unbounded, + ) + .map(Left), + None => FacetStringLevelZeroRevRange::new( + rtxn, + index.facet_id_string_docids, + field_id, + Unbounded, + Unbounded, + ) + .map(Right), + } + } } impl<'t> Iterator for FacetStringIter<'t> { @@ -348,6 +553,21 @@ impl<'t> Iterator for FacetStringIter<'t> { fn next(&mut self) -> Option { 'outer: loop { let (documents_ids, last) = self.level_iters.last_mut()?; + let is_ascending = last.is_left(); + + // We remap the different iterator types to make + // the algorithm less complex to understand. + let last = match last { + Left(ascending) => match ascending { + Left(last) => Left(Left(last)), + Right(last) => Right(Left(last)), + }, + Right(descending) => match descending { + Left(last) => Left(Right(last)), + Right(last) => Right(Right(last)), + }, + }; + match last { Left(last) => { for result in last { @@ -359,24 +579,50 @@ impl<'t> Iterator for FacetStringIter<'t> { *documents_ids -= &docids; } - let result = match string_bounds { - Some((left, right)) => FacetStringLevelZeroRange::new( - self.rtxn, - self.db, - self.field_id, - Included(left), - Included(right), - ) - .map(Right), - None => FacetStringGroupRange::new( - self.rtxn, - self.db, - self.field_id, - NonZeroU8::new(level.get() - 1).unwrap(), - Included(left), - Included(right), - ) - .map(Left), + let result = if is_ascending { + match string_bounds { + Some((left, right)) => { + FacetStringLevelZeroRevRange::new( + self.rtxn, + self.db, + self.field_id, + Included(left), + Included(right), + ) + .map(Right) + } + None => FacetStringGroupRevRange::new( + self.rtxn, + self.db, + self.field_id, + NonZeroU8::new(level.get() - 1).unwrap(), + Included(left), + Included(right), + ) + .map(Left), + } + .map(Right) + } else { + match string_bounds { + Some((left, right)) => FacetStringLevelZeroRange::new( + self.rtxn, + self.db, + self.field_id, + Included(left), + Included(right), + ) + .map(Right), + None => FacetStringGroupRange::new( + self.rtxn, + self.db, + self.field_id, + NonZeroU8::new(level.get() - 1).unwrap(), + Included(left), + Included(right), + ) + .map(Left), + } + .map(Left) }; match result { diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 871f464ef..23e5c1834 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -18,6 +18,8 @@ pub(crate) use self::facet::ParserRule; pub use self::facet::{FacetDistribution, FacetNumberIter, FilterCondition, Operator}; pub use self::matching_words::MatchingWords; use self::query_tree::QueryTreeBuilder; +use crate::criterion::AscDesc; +use crate::error::UserError; use crate::search::criteria::r#final::{Final, FinalResult}; use crate::{DocumentId, Index, Result}; @@ -37,6 +39,7 @@ pub struct Search<'a> { filter: Option, offset: usize, limit: usize, + sort_criteria: Option>, optional_words: bool, authorize_typos: bool, words_limit: usize, @@ -51,6 +54,7 @@ impl<'a> Search<'a> { filter: None, offset: 0, limit: 20, + sort_criteria: None, optional_words: true, authorize_typos: true, words_limit: 10, @@ -74,6 +78,11 @@ impl<'a> Search<'a> { self } + pub fn sort_criteria(&mut self, criteria: Vec) -> &mut Search<'a> { + self.sort_criteria = Some(criteria); + self + } + pub fn optional_words(&mut self, value: bool) -> &mut Search<'a> { self.optional_words = value; self @@ -134,8 +143,29 @@ impl<'a> Search<'a> { None => MatchingWords::default(), }; + // We check that we are allowed to use the sort criteria, we check + // that they are declared in the sortable fields. + let sortable_fields = self.index.sortable_fields(self.rtxn)?; + if let Some(sort_criteria) = &self.sort_criteria { + for asc_desc in sort_criteria { + let field = asc_desc.field(); + if !sortable_fields.contains(field) { + return Err(UserError::InvalidSortableAttribute { + field: field.to_string(), + valid_fields: sortable_fields, + } + .into()); + } + } + } + let criteria_builder = criteria::CriteriaBuilder::new(self.rtxn, self.index)?; - let criteria = criteria_builder.build(query_tree, primitive_query, filtered_candidates)?; + let criteria = criteria_builder.build( + query_tree, + primitive_query, + filtered_candidates, + self.sort_criteria.clone(), + )?; match self.index.distinct_field(self.rtxn)? { None => self.perform_sort(NoopDistinct, matching_words, criteria), @@ -199,6 +229,7 @@ impl fmt::Debug for Search<'_> { filter, offset, limit, + sort_criteria, optional_words, authorize_typos, words_limit, @@ -210,6 +241,7 @@ impl fmt::Debug for Search<'_> { .field("filter", filter) .field("offset", offset) .field("limit", limit) + .field("sort_criteria", sort_criteria) .field("optional_words", optional_words) .field("authorize_typos", authorize_typos) .field("words_limit", words_limit) diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 07bdfd6fa..c0b5e4549 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -75,6 +75,7 @@ pub struct Settings<'a, 't, 'u, 'i> { searchable_fields: Setting>, displayed_fields: Setting>, filterable_fields: Setting>, + sortable_fields: Setting>, criteria: Setting>, stop_words: Setting>, distinct_field: Setting, @@ -102,6 +103,7 @@ impl<'a, 't, 'u, 'i> Settings<'a, 't, 'u, 'i> { searchable_fields: Setting::NotSet, displayed_fields: Setting::NotSet, filterable_fields: Setting::NotSet, + sortable_fields: Setting::NotSet, criteria: Setting::NotSet, stop_words: Setting::NotSet, distinct_field: Setting::NotSet, @@ -135,6 +137,10 @@ impl<'a, 't, 'u, 'i> Settings<'a, 't, 'u, 'i> { self.filterable_fields = Setting::Set(names); } + pub fn set_sortable_fields(&mut self, names: HashSet) { + self.sortable_fields = Setting::Set(names); + } + pub fn reset_criteria(&mut self) { self.criteria = Setting::Reset; } @@ -392,6 +398,23 @@ impl<'a, 't, 'u, 'i> Settings<'a, 't, 'u, 'i> { Ok(()) } + fn update_sortable(&mut self) -> Result<()> { + match self.sortable_fields { + Setting::Set(ref fields) => { + let mut new_fields = HashSet::new(); + for name in fields { + new_fields.insert(name.clone()); + } + self.index.put_sortable_fields(self.wtxn, &new_fields)?; + } + Setting::Reset => { + self.index.delete_sortable_fields(self.wtxn)?; + } + Setting::NotSet => (), + } + Ok(()) + } + fn update_criteria(&mut self) -> Result<()> { match self.criteria { Setting::Set(ref fields) => { @@ -446,6 +469,7 @@ impl<'a, 't, 'u, 'i> Settings<'a, 't, 'u, 'i> { self.update_displayed()?; self.update_filterable()?; + self.update_sortable()?; self.update_distinct_field()?; self.update_criteria()?; self.update_primary_key()?; @@ -719,7 +743,7 @@ mod tests { let mut builder = Settings::new(&mut wtxn, &index, 0); // Don't display the generated `id` field. builder.set_displayed_fields(vec![S("name")]); - builder.set_criteria(vec![S("asc(age)")]); + builder.set_criteria(vec![S("age:asc")]); builder.execute(|_, _| ()).unwrap(); // Then index some documents. @@ -953,7 +977,7 @@ mod tests { let mut builder = Settings::new(&mut wtxn, &index, 0); builder.set_displayed_fields(vec!["hello".to_string()]); builder.set_filterable_fields(hashset! { S("age"), S("toto") }); - builder.set_criteria(vec!["asc(toto)".to_string()]); + builder.set_criteria(vec!["toto:asc".to_string()]); builder.execute(|_, _| ()).unwrap(); wtxn.commit().unwrap(); @@ -990,7 +1014,7 @@ mod tests { let mut builder = Settings::new(&mut wtxn, &index, 0); builder.set_displayed_fields(vec!["hello".to_string()]); // It is only Asc(toto), there is a facet database but it is denied to filter with toto. - builder.set_criteria(vec!["asc(toto)".to_string()]); + builder.set_criteria(vec!["toto:asc".to_string()]); builder.execute(|_, _| ()).unwrap(); wtxn.commit().unwrap(); diff --git a/milli/tests/assets/test_set.ndjson b/milli/tests/assets/test_set.ndjson index 599d479ed..89d9f1109 100644 --- a/milli/tests/assets/test_set.ndjson +++ b/milli/tests/assets/test_set.ndjson @@ -1,17 +1,17 @@ -{"id":"A","word_rank":0,"typo_rank":1,"proximity_rank":15,"attribute_rank":505,"exact_rank":5,"asc_desc_rank":0,"title":"hell o","description":"hell o is the fourteenth episode of the american television series glee performing songs with this word","tag":"blue","":""} -{"id":"B","word_rank":2,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":4,"asc_desc_rank":1,"title":"hello","description":"hello is a song recorded by english singer songwriter adele","tag":"red","":""} -{"id":"C","word_rank":0,"typo_rank":1,"proximity_rank":8,"attribute_rank":336,"exact_rank":4,"asc_desc_rank":2,"title":"hell on earth","description":"hell on earth is the third studio album by american hip hop duo mobb deep","tag":"blue","":""} -{"id":"D","word_rank":0,"typo_rank":1,"proximity_rank":10,"attribute_rank":757,"exact_rank":4,"asc_desc_rank":3,"title":"hell on wheels tv series","description":"the construction of the first transcontinental railroad across the united states in the world","tag":"red","":""} -{"id":"E","word_rank":2,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":4,"asc_desc_rank":4,"title":"hello kitty","description":"also known by her full name kitty white is a fictional character produced by the japanese company sanrio","tag":"green","":""} -{"id":"F","word_rank":2,"typo_rank":1,"proximity_rank":0,"attribute_rank":1017,"exact_rank":5,"asc_desc_rank":5,"title":"laptop orchestra","description":"a laptop orchestra lork or lo is a chamber music ensemble consisting primarily of laptops like helo huddersfield experimental laptop orchestra","tag":"blue","":""} -{"id":"G","word_rank":1,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":3,"asc_desc_rank":5,"title":"hello world film","description":"hello world is a 2019 japanese animated sci fi romantic drama film directed by tomohiko ito and produced by graphinica","tag":"red","":""} -{"id":"H","word_rank":1,"typo_rank":0,"proximity_rank":1,"attribute_rank":0,"exact_rank":3,"asc_desc_rank":4,"title":"world hello day","description":"holiday observed on november 21 to express that conflicts should be resolved through communication rather than the use of force","tag":"green","":""} -{"id":"I","word_rank":0,"typo_rank":0,"proximity_rank":8,"attribute_rank":338,"exact_rank":3,"asc_desc_rank":3,"title":"hello world song","description":"hello world is a song written by tom douglas tony lane and david lee and recorded by american country music group lady antebellum","tag":"blue","":""} -{"id":"J","word_rank":1,"typo_rank":0,"proximity_rank":1,"attribute_rank":1,"exact_rank":3,"asc_desc_rank":2,"title":"hello cruel world","description":"hello cruel world is an album by new zealand band tall dwarfs","tag":"green","":""} -{"id":"K","word_rank":0,"typo_rank":2,"proximity_rank":9,"attribute_rank":670,"exact_rank":5,"asc_desc_rank":1,"title":"ello creation system","description":"in few word ello was a construction toy created by the american company mattel to engage girls in construction play","tag":"red","":""} -{"id":"L","word_rank":0,"typo_rank":0,"proximity_rank":2,"attribute_rank":250,"exact_rank":4,"asc_desc_rank":0,"title":"good morning world","description":"good morning world is an american sitcom broadcast on cbs tv during the 1967 1968 season","tag":"blue","":""} -{"id":"M","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":0,"asc_desc_rank":0,"title":"hello world america","description":"a perfect match for a perfect engine using the query hello world america","tag":"red","":""} -{"id":"N","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":1,"asc_desc_rank":4,"title":"hello world america unleashed","description":"a very good match for a very good engine using the query hello world america","tag":"green","":""} -{"id":"O","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":10,"exact_rank":0,"asc_desc_rank":6,"title":"a perfect match for a perfect engine using the query hello world america","description":"hello world america","tag":"blue","":""} -{"id":"P","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":12,"exact_rank":1,"asc_desc_rank":3,"title":"a very good match for a very good engine using the query hello world america","description":"hello world america unleashed","tag":"red","":""} -{"id":"Q","word_rank":1,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":3,"asc_desc_rank":2,"title":"hello world","description":"a hello world program generally is a computer program that outputs or displays the message hello world","tag":"green","":""} +{"id":"A","word_rank":0,"typo_rank":1,"proximity_rank":15,"attribute_rank":505,"exact_rank":5,"asc_desc_rank":0,"sort_by_rank":0,"title":"hell o","description":"hell o is the fourteenth episode of the american television series glee performing songs with this word","tag":"blue","":""} +{"id":"B","word_rank":2,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":4,"asc_desc_rank":1,"sort_by_rank":2,"title":"hello","description":"hello is a song recorded by english singer songwriter adele","tag":"red","":""} +{"id":"C","word_rank":0,"typo_rank":1,"proximity_rank":8,"attribute_rank":336,"exact_rank":4,"asc_desc_rank":2,"sort_by_rank":0,"title":"hell on earth","description":"hell on earth is the third studio album by american hip hop duo mobb deep","tag":"blue","":""} +{"id":"D","word_rank":0,"typo_rank":1,"proximity_rank":10,"attribute_rank":757,"exact_rank":4,"asc_desc_rank":3,"sort_by_rank":2,"title":"hell on wheels tv series","description":"the construction of the first transcontinental railroad across the united states in the world","tag":"red","":""} +{"id":"E","word_rank":2,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":4,"asc_desc_rank":4,"sort_by_rank":1,"title":"hello kitty","description":"also known by her full name kitty white is a fictional character produced by the japanese company sanrio","tag":"green","":""} +{"id":"F","word_rank":2,"typo_rank":1,"proximity_rank":0,"attribute_rank":1017,"exact_rank":5,"asc_desc_rank":5,"sort_by_rank":0,"title":"laptop orchestra","description":"a laptop orchestra lork or lo is a chamber music ensemble consisting primarily of laptops like helo huddersfield experimental laptop orchestra","tag":"blue","":""} +{"id":"G","word_rank":1,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":3,"asc_desc_rank":5,"sort_by_rank":2,"title":"hello world film","description":"hello world is a 2019 japanese animated sci fi romantic drama film directed by tomohiko ito and produced by graphinica","tag":"red","":""} +{"id":"H","word_rank":1,"typo_rank":0,"proximity_rank":1,"attribute_rank":0,"exact_rank":3,"asc_desc_rank":4,"sort_by_rank":1,"title":"world hello day","description":"holiday observed on november 21 to express that conflicts should be resolved through communication rather than the use of force","tag":"green","":""} +{"id":"I","word_rank":0,"typo_rank":0,"proximity_rank":8,"attribute_rank":338,"exact_rank":3,"asc_desc_rank":3,"sort_by_rank":0,"title":"hello world song","description":"hello world is a song written by tom douglas tony lane and david lee and recorded by american country music group lady antebellum","tag":"blue","":""} +{"id":"J","word_rank":1,"typo_rank":0,"proximity_rank":1,"attribute_rank":1,"exact_rank":3,"asc_desc_rank":2,"sort_by_rank":1,"title":"hello cruel world","description":"hello cruel world is an album by new zealand band tall dwarfs","tag":"green","":""} +{"id":"K","word_rank":0,"typo_rank":2,"proximity_rank":9,"attribute_rank":670,"exact_rank":5,"asc_desc_rank":1,"sort_by_rank":2,"title":"ello creation system","description":"in few word ello was a construction toy created by the american company mattel to engage girls in construction play","tag":"red","":""} +{"id":"L","word_rank":0,"typo_rank":0,"proximity_rank":2,"attribute_rank":250,"exact_rank":4,"asc_desc_rank":0,"sort_by_rank":0,"title":"good morning world","description":"good morning world is an american sitcom broadcast on cbs tv during the 1967 1968 season","tag":"blue","":""} +{"id":"M","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":0,"asc_desc_rank":0,"sort_by_rank":2,"title":"hello world america","description":"a perfect match for a perfect engine using the query hello world america","tag":"red","":""} +{"id":"N","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":1,"asc_desc_rank":4,"sort_by_rank":1,"title":"hello world america unleashed","description":"a very good match for a very good engine using the query hello world america","tag":"green","":""} +{"id":"O","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":10,"exact_rank":0,"asc_desc_rank":6,"sort_by_rank":0,"title":"a perfect match for a perfect engine using the query hello world america","description":"hello world america","tag":"blue","":""} +{"id":"P","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":12,"exact_rank":1,"asc_desc_rank":3,"sort_by_rank":2,"title":"a very good match for a very good engine using the query hello world america","description":"hello world america unleashed","tag":"red","":""} +{"id":"Q","word_rank":1,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":3,"asc_desc_rank":2,"sort_by_rank":1,"title":"hello world","description":"a hello world program generally is a computer program that outputs or displays the message hello world","tag":"green","":""} diff --git a/milli/tests/search/distinct.rs b/milli/tests/search/distinct.rs index ef5af3272..f044756eb 100644 --- a/milli/tests/search/distinct.rs +++ b/milli/tests/search/distinct.rs @@ -32,7 +32,7 @@ macro_rules! test_distinct { let SearchResult { documents_ids, .. } = search.execute().unwrap(); let mut distinct_values = HashSet::new(); - let expected_external_ids: Vec<_> = search::expected_order(&criteria, true, true) + let expected_external_ids: Vec<_> = search::expected_order(&criteria, true, true, &[]) .into_iter() .filter_map(|d| { if distinct_values.contains(&d.$distinct) { diff --git a/milli/tests/search/filters.rs b/milli/tests/search/filters.rs index 318197ea3..c810b47af 100644 --- a/milli/tests/search/filters.rs +++ b/milli/tests/search/filters.rs @@ -29,7 +29,7 @@ macro_rules! test_filter { let SearchResult { documents_ids, .. } = search.execute().unwrap(); let filtered_ids = search::expected_filtered_ids($filter); - let expected_external_ids: Vec<_> = search::expected_order(&criteria, true, true) + let expected_external_ids: Vec<_> = search::expected_order(&criteria, true, true, &[]) .into_iter() .filter_map(|d| if filtered_ids.contains(&d.id) { Some(d.id) } else { None }) .collect(); diff --git a/milli/tests/search/mod.rs b/milli/tests/search/mod.rs index c5724a921..7d4043ff1 100644 --- a/milli/tests/search/mod.rs +++ b/milli/tests/search/mod.rs @@ -1,3 +1,4 @@ +use std::cmp::Reverse; use std::collections::HashSet; use big_s::S; @@ -5,7 +6,7 @@ use either::{Either, Left, Right}; use heed::EnvOpenOptions; use maplit::{hashmap, hashset}; use milli::update::{IndexDocuments, Settings, UpdateFormat}; -use milli::{Criterion, DocumentId, Index}; +use milli::{AscDesc, Criterion, DocumentId, Index}; use serde::Deserialize; use slice_group_by::GroupBy; @@ -36,6 +37,10 @@ pub fn setup_search_index_with_criteria(criteria: &[Criterion]) -> Index { S("tag"), S("asc_desc_rank"), }); + builder.set_sortable_fields(hashset! { + S("tag"), + S("asc_desc_rank"), + }); builder.set_synonyms(hashmap! { S("hello") => vec![S("good morning")], S("world") => vec![S("earth")], @@ -67,6 +72,7 @@ pub fn expected_order( criteria: &[Criterion], authorize_typo: bool, optional_words: bool, + sort_by: &[AscDesc], ) -> Vec { let dataset = serde_json::Deserializer::from_str(CONTENT).into_iter().map(|r| r.unwrap()).collect(); @@ -90,6 +96,14 @@ pub fn expected_order( new_groups .extend(group.linear_group_by_key(|d| d.proximity_rank).map(Vec::from)); } + Criterion::Sort if sort_by == [AscDesc::Asc(S("tag"))] => { + group.sort_by_key(|d| d.sort_by_rank); + new_groups.extend(group.linear_group_by_key(|d| d.sort_by_rank).map(Vec::from)); + } + Criterion::Sort if sort_by == [AscDesc::Desc(S("tag"))] => { + group.sort_by_key(|d| Reverse(d.sort_by_rank)); + new_groups.extend(group.linear_group_by_key(|d| d.sort_by_rank).map(Vec::from)); + } Criterion::Typo => { group.sort_by_key(|d| d.typo_rank); new_groups.extend(group.linear_group_by_key(|d| d.typo_rank).map(Vec::from)); @@ -104,11 +118,13 @@ pub fn expected_order( .extend(group.linear_group_by_key(|d| d.asc_desc_rank).map(Vec::from)); } Criterion::Desc(field_name) if field_name == "asc_desc_rank" => { - group.sort_by_key(|d| std::cmp::Reverse(d.asc_desc_rank)); + group.sort_by_key(|d| Reverse(d.asc_desc_rank)); new_groups .extend(group.linear_group_by_key(|d| d.asc_desc_rank).map(Vec::from)); } - Criterion::Asc(_) | Criterion::Desc(_) => new_groups.push(group.clone()), + Criterion::Asc(_) | Criterion::Desc(_) | Criterion::Sort => { + new_groups.push(group.clone()) + } } } groups = std::mem::take(&mut new_groups); @@ -185,6 +201,7 @@ pub struct TestDocument { pub attribute_rank: u32, pub exact_rank: u32, pub asc_desc_rank: u32, + pub sort_by_rank: u32, pub title: String, pub description: String, pub tag: String, diff --git a/milli/tests/search/query_criteria.rs b/milli/tests/search/query_criteria.rs index f814508f5..1723c1d6f 100644 --- a/milli/tests/search/query_criteria.rs +++ b/milli/tests/search/query_criteria.rs @@ -1,6 +1,6 @@ use big_s::S; use milli::update::Settings; -use milli::{Criterion, Search, SearchResult}; +use milli::{AscDesc, Criterion, Search, SearchResult}; use Criterion::*; use crate::search::{self, EXTERNAL_DOCUMENTS_IDS}; @@ -11,7 +11,7 @@ const ALLOW_OPTIONAL_WORDS: bool = true; const DISALLOW_OPTIONAL_WORDS: bool = false; macro_rules! test_criterion { - ($func:ident, $optional_word:ident, $authorize_typos:ident, $criteria:expr) => { + ($func:ident, $optional_word:ident, $authorize_typos:ident, $criteria:expr, $sort_criteria:expr) => { #[test] fn $func() { let criteria = $criteria; @@ -23,82 +23,168 @@ macro_rules! test_criterion { search.limit(EXTERNAL_DOCUMENTS_IDS.len()); search.authorize_typos($authorize_typos); search.optional_words($optional_word); + search.sort_criteria($sort_criteria); let SearchResult { documents_ids, .. } = search.execute().unwrap(); - let expected_external_ids: Vec<_> = - search::expected_order(&criteria, $authorize_typos, $optional_word) - .into_iter() - .map(|d| d.id) - .collect(); + let expected_external_ids: Vec<_> = search::expected_order( + &criteria, + $authorize_typos, + $optional_word, + &$sort_criteria[..], + ) + .into_iter() + .map(|d| d.id) + .collect(); let documents_ids = search::internal_to_external_ids(&index, &documents_ids); assert_eq!(documents_ids, expected_external_ids); } }; } -test_criterion!(none_allow_typo, ALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, vec![]); -test_criterion!(none_disallow_typo, DISALLOW_OPTIONAL_WORDS, DISALLOW_TYPOS, vec![]); -test_criterion!(words_allow_typo, ALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, vec![Words]); -test_criterion!(attribute_allow_typo, DISALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, vec![Attribute]); -test_criterion!(attribute_disallow_typo, DISALLOW_OPTIONAL_WORDS, DISALLOW_TYPOS, vec![Attribute]); -test_criterion!(exactness_allow_typo, DISALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, vec![Exactness]); -test_criterion!(exactness_disallow_typo, DISALLOW_OPTIONAL_WORDS, DISALLOW_TYPOS, vec![Exactness]); -test_criterion!(proximity_allow_typo, DISALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, vec![Proximity]); -test_criterion!(proximity_disallow_typo, DISALLOW_OPTIONAL_WORDS, DISALLOW_TYPOS, vec![Proximity]); +test_criterion!(none_allow_typo, ALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, vec![], vec![]); +test_criterion!(none_disallow_typo, DISALLOW_OPTIONAL_WORDS, DISALLOW_TYPOS, vec![], vec![]); +test_criterion!(words_allow_typo, ALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, vec![Words], vec![]); +test_criterion!( + attribute_allow_typo, + DISALLOW_OPTIONAL_WORDS, + ALLOW_TYPOS, + vec![Attribute], + vec![] +); +test_criterion!( + attribute_disallow_typo, + DISALLOW_OPTIONAL_WORDS, + DISALLOW_TYPOS, + vec![Attribute], + vec![] +); +test_criterion!( + exactness_allow_typo, + DISALLOW_OPTIONAL_WORDS, + ALLOW_TYPOS, + vec![Exactness], + vec![] +); +test_criterion!( + exactness_disallow_typo, + DISALLOW_OPTIONAL_WORDS, + DISALLOW_TYPOS, + vec![Exactness], + vec![] +); +test_criterion!( + proximity_allow_typo, + DISALLOW_OPTIONAL_WORDS, + ALLOW_TYPOS, + vec![Proximity], + vec![] +); +test_criterion!( + proximity_disallow_typo, + DISALLOW_OPTIONAL_WORDS, + DISALLOW_TYPOS, + vec![Proximity], + vec![] +); test_criterion!( asc_allow_typo, DISALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, - vec![Asc(S("asc_desc_rank"))] + vec![Asc(S("asc_desc_rank"))], + vec![] ); test_criterion!( asc_disallow_typo, DISALLOW_OPTIONAL_WORDS, DISALLOW_TYPOS, - vec![Asc(S("asc_desc_rank"))] + vec![Asc(S("asc_desc_rank"))], + vec![] ); test_criterion!( desc_allow_typo, DISALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, - vec![Desc(S("asc_desc_rank"))] + vec![Desc(S("asc_desc_rank"))], + vec![] ); test_criterion!( desc_disallow_typo, DISALLOW_OPTIONAL_WORDS, DISALLOW_TYPOS, - vec![Desc(S("asc_desc_rank"))] + vec![Desc(S("asc_desc_rank"))], + vec![] ); test_criterion!( asc_unexisting_field_allow_typo, DISALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, - vec![Asc(S("unexisting_field"))] + vec![Asc(S("unexisting_field"))], + vec![] ); test_criterion!( asc_unexisting_field_disallow_typo, DISALLOW_OPTIONAL_WORDS, DISALLOW_TYPOS, - vec![Asc(S("unexisting_field"))] + vec![Asc(S("unexisting_field"))], + vec![] ); test_criterion!( desc_unexisting_field_allow_typo, DISALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, - vec![Desc(S("unexisting_field"))] + vec![Desc(S("unexisting_field"))], + vec![] ); test_criterion!( desc_unexisting_field_disallow_typo, DISALLOW_OPTIONAL_WORDS, DISALLOW_TYPOS, - vec![Desc(S("unexisting_field"))] + vec![Desc(S("unexisting_field"))], + vec![] +); +test_criterion!(empty_sort_by_allow_typo, DISALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, vec![Sort], vec![]); +test_criterion!( + empty_sort_by_disallow_typo, + DISALLOW_OPTIONAL_WORDS, + DISALLOW_TYPOS, + vec![Sort], + vec![] +); +test_criterion!( + sort_by_asc_allow_typo, + DISALLOW_OPTIONAL_WORDS, + ALLOW_TYPOS, + vec![Sort], + vec![AscDesc::Asc(S("tag"))] +); +test_criterion!( + sort_by_asc_disallow_typo, + DISALLOW_OPTIONAL_WORDS, + DISALLOW_TYPOS, + vec![Sort], + vec![AscDesc::Asc(S("tag"))] +); +test_criterion!( + sort_by_desc_allow_typo, + DISALLOW_OPTIONAL_WORDS, + ALLOW_TYPOS, + vec![Sort], + vec![AscDesc::Desc(S("tag"))] +); +test_criterion!( + sort_by_desc_disallow_typo, + DISALLOW_OPTIONAL_WORDS, + DISALLOW_TYPOS, + vec![Sort], + vec![AscDesc::Desc(S("tag"))] ); test_criterion!( default_criteria_order, ALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, - vec![Words, Typo, Proximity, Attribute, Exactness] + vec![Words, Typo, Proximity, Attribute, Exactness], + vec![] ); #[test] @@ -262,7 +348,7 @@ fn criteria_mixup() { let SearchResult { documents_ids, .. } = search.execute().unwrap(); let expected_external_ids: Vec<_> = - search::expected_order(&criteria, ALLOW_OPTIONAL_WORDS, ALLOW_TYPOS) + search::expected_order(&criteria, ALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, &[]) .into_iter() .map(|d| d.id) .collect();