diff --git a/milli/src/asc_desc.rs b/milli/src/asc_desc.rs new file mode 100644 index 000000000..9a3bda934 --- /dev/null +++ b/milli/src/asc_desc.rs @@ -0,0 +1,228 @@ +//! This module provides the `AscDesc` type and defines all the errors related to this type. + +use std::fmt; +use std::str::FromStr; + +use serde::{Deserialize, Serialize}; + +use crate::error::is_reserved_keyword; +use crate::CriterionError; + +/// This error type is never supposed to be shown to the end user. +/// You must always cast it to a sort error or a criterion error. +#[derive(Debug)] +pub enum AscDescError { + InvalidSyntax { name: String }, + ReservedKeyword { name: String }, +} + +impl fmt::Display for AscDescError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::InvalidSyntax { name } => { + write!(f, "invalid asc/desc syntax for {}", name) + } + Self::ReservedKeyword { name } => { + write!( + f, + "{} is a reserved keyword and thus can't be used as a asc/desc rule", + name + ) + } + } + } +} + +impl From for CriterionError { + fn from(error: AscDescError) -> Self { + match error { + AscDescError::InvalidSyntax { name } => CriterionError::InvalidName { name }, + AscDescError::ReservedKeyword { name } if name.starts_with("_geoPoint") => { + CriterionError::ReservedNameForSort { name: "_geoPoint".to_string() } + } + AscDescError::ReservedKeyword { name } if name.starts_with("_geoRadius") => { + CriterionError::ReservedNameForFilter { name: "_geoRadius".to_string() } + } + AscDescError::ReservedKeyword { name } => CriterionError::ReservedName { name }, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub enum Member { + Field(String), + Geo([f64; 2]), +} + +impl FromStr for Member { + type Err = AscDescError; + + fn from_str(text: &str) -> Result { + match text.strip_prefix("_geoPoint(").and_then(|text| text.strip_suffix(")")) { + Some(point) => { + let (lat, long) = point + .split_once(',') + .ok_or_else(|| AscDescError::ReservedKeyword { name: text.to_string() }) + .and_then(|(lat, long)| { + lat.trim() + .parse() + .and_then(|lat| long.trim().parse().map(|long| (lat, long))) + .map_err(|_| AscDescError::ReservedKeyword { name: text.to_string() }) + })?; + Ok(Member::Geo([lat, long])) + } + None => { + if is_reserved_keyword(text) || text.starts_with("_geoRadius(") { + return Err(AscDescError::ReservedKeyword { name: text.to_string() })?; + } + Ok(Member::Field(text.to_string())) + } + } + } +} + +impl fmt::Display for Member { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Member::Field(name) => f.write_str(name), + Member::Geo([lat, lng]) => write!(f, "_geoPoint({}, {})", lat, lng), + } + } +} + +impl Member { + pub fn field(&self) -> Option<&str> { + match self { + Member::Field(field) => Some(field), + Member::Geo(_) => None, + } + } + + pub fn geo_point(&self) -> Option<&[f64; 2]> { + match self { + Member::Geo(point) => Some(point), + Member::Field(_) => None, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub enum AscDesc { + Asc(Member), + Desc(Member), +} + +impl AscDesc { + pub fn member(&self) -> &Member { + match self { + AscDesc::Asc(member) => member, + AscDesc::Desc(member) => member, + } + } + + pub fn field(&self) -> Option<&str> { + self.member().field() + } +} + +impl FromStr for AscDesc { + type Err = AscDescError; + + /// Since we don't know if this was deserialized for a criterion or a sort we just return a + /// string and let the caller create his own error. + fn from_str(text: &str) -> Result { + match text.rsplit_once(':') { + Some((left, "asc")) => Ok(AscDesc::Asc(left.parse()?)), + Some((left, "desc")) => Ok(AscDesc::Desc(left.parse()?)), + _ => Err(AscDescError::InvalidSyntax { name: text.to_string() }), + } + } +} + +#[cfg(test)] +mod tests { + use big_s::S; + use AscDesc::*; + use AscDescError::*; + use Member::*; + + use super::*; + + #[test] + fn parse_asc_desc() { + let valid_req = [ + ("truc:asc", Asc(Field(S("truc")))), + ("bidule:desc", Desc(Field(S("bidule")))), + ("a-b:desc", Desc(Field(S("a-b")))), + ("a:b:desc", Desc(Field(S("a:b")))), + ("a12:asc", Asc(Field(S("a12")))), + ("42:asc", Asc(Field(S("42")))), + ("_geoPoint(42, 59):asc", Asc(Geo([42., 59.]))), + ("_geoPoint(42.459, 59):desc", Desc(Geo([42.459, 59.]))), + ("_geoPoint(42, 59.895):desc", Desc(Geo([42., 59.895]))), + ("_geoPoint(42, 59.895):desc", Desc(Geo([42., 59.895]))), + ("_geoPoint(42.0002, 59.895):desc", Desc(Geo([42.0002, 59.895]))), + ("_geoPoint(42., 59.):desc", Desc(Geo([42., 59.]))), + ("truc(12, 13):desc", Desc(Field(S("truc(12, 13)")))), + ]; + + for (req, expected) in valid_req { + let res = req.parse::(); + assert!( + res.is_ok(), + "Failed to parse `{}`, was expecting `{:?}` but instead got `{:?}`", + req, + expected, + res + ); + assert_eq!(res.unwrap(), expected); + } + + let invalid_req = [ + ("truc:machin", InvalidSyntax { name: S("truc:machin") }), + ("truc:deesc", InvalidSyntax { name: S("truc:deesc") }), + ("truc:asc:deesc", InvalidSyntax { name: S("truc:asc:deesc") }), + ("42desc", InvalidSyntax { name: S("42desc") }), + ("_geoPoint:asc", ReservedKeyword { name: S("_geoPoint") }), + ("_geoDistance:asc", ReservedKeyword { name: S("_geoDistance") }), + ("_geoPoint(42.12 , 59.598)", InvalidSyntax { name: S("_geoPoint(42.12 , 59.598)") }), + ( + "_geoPoint(42.12 , 59.598):deesc", + InvalidSyntax { name: S("_geoPoint(42.12 , 59.598):deesc") }, + ), + ( + "_geoPoint(42.12 , 59.598):machin", + InvalidSyntax { name: S("_geoPoint(42.12 , 59.598):machin") }, + ), + ( + "_geoPoint(42.12 , 59.598):asc:aasc", + InvalidSyntax { name: S("_geoPoint(42.12 , 59.598):asc:aasc") }, + ), + ( + "_geoPoint(42,12 , 59,598):desc", + ReservedKeyword { name: S("_geoPoint(42,12 , 59,598)") }, + ), + ("_geoPoint(35, 85, 75):asc", ReservedKeyword { name: S("_geoPoint(35, 85, 75)") }), + ("_geoPoint(18):asc", ReservedKeyword { name: S("_geoPoint(18)") }), + ]; + + for (req, expected_error) in invalid_req { + let res = req.parse::(); + assert!( + res.is_err(), + "Should no be able to parse `{}`, was expecting an error but instead got: `{:?}`", + req, + res, + ); + let res = res.unwrap_err(); + assert_eq!( + res.to_string(), + expected_error.to_string(), + "Bad error for input {}: got `{:?}` instead of `{:?}`", + req, + res, + expected_error + ); + } + } +} diff --git a/milli/src/criterion.rs b/milli/src/criterion.rs index c526a7e32..aff7fcf68 100644 --- a/milli/src/criterion.rs +++ b/milli/src/criterion.rs @@ -3,7 +3,49 @@ use std::str::FromStr; use serde::{Deserialize, Serialize}; -use crate::error::{is_reserved_keyword, Error, UserError}; +use crate::error::Error; +use crate::{AscDesc, Member, UserError}; + +#[derive(Debug)] +pub enum CriterionError { + InvalidName { name: String }, + ReservedName { name: String }, + ReservedNameForSort { name: String }, + ReservedNameForFilter { name: String }, +} + +impl fmt::Display for CriterionError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::InvalidName { name } => write!(f, "invalid ranking rule {}", name), + Self::ReservedName { name } => { + write!(f, "{} is a reserved keyword and thus can't be used as a ranking rule", name) + } + Self::ReservedNameForSort { name } => { + write!( + f, + "{} is a reserved keyword and thus can't be used as a ranking rule. \ +{} can only be used for sorting at search time", + name, name + ) + } + Self::ReservedNameForFilter { name } => { + write!( + f, + "{} is a reserved keyword and thus can't be used as a ranking rule. \ +{} can only be used for filtering at search time", + name, name + ) + } + } + } +} + +impl From for Error { + fn from(error: CriterionError) -> Self { + Self::UserError(UserError::CriterionError(error)) + } +} #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] pub enum Criterion { @@ -39,7 +81,7 @@ impl Criterion { } impl FromStr for Criterion { - type Err = Error; + type Err = CriterionError; fn from_str(text: &str) -> Result { match text { @@ -49,118 +91,17 @@ impl FromStr for Criterion { "attribute" => Ok(Criterion::Attribute), "sort" => Ok(Criterion::Sort), "exactness" => Ok(Criterion::Exactness), - text => match AscDesc::from_str(text) { - Ok(AscDesc::Asc(Member::Field(field))) => Ok(Criterion::Asc(field)), - Ok(AscDesc::Desc(Member::Field(field))) => Ok(Criterion::Desc(field)), - Ok(AscDesc::Asc(Member::Geo(_))) | Ok(AscDesc::Desc(Member::Geo(_))) => { - Err(UserError::InvalidRankingRuleName { name: text.to_string() })? - } - Err(UserError::InvalidAscDescSyntax { name }) => { - Err(UserError::InvalidRankingRuleName { name }.into()) - } - Err(error) => { - Err(UserError::InvalidRankingRuleName { name: error.to_string() }.into()) + text => match AscDesc::from_str(text)? { + AscDesc::Asc(Member::Field(field)) => Ok(Criterion::Asc(field)), + AscDesc::Desc(Member::Field(field)) => Ok(Criterion::Desc(field)), + AscDesc::Asc(Member::Geo(_)) | AscDesc::Desc(Member::Geo(_)) => { + Err(CriterionError::ReservedNameForSort { name: "_geoPoint".to_string() })? } }, } } } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub enum Member { - Field(String), - Geo([f64; 2]), -} - -impl FromStr for Member { - type Err = UserError; - - fn from_str(text: &str) -> Result { - match text.strip_prefix("_geoPoint(").and_then(|text| text.strip_suffix(")")) { - Some(point) => { - let (lat, long) = point - .split_once(',') - .ok_or_else(|| UserError::InvalidRankingRuleName { name: text.to_string() }) - .and_then(|(lat, long)| { - lat.trim() - .parse() - .and_then(|lat| long.trim().parse().map(|long| (lat, long))) - .map_err(|_| UserError::InvalidRankingRuleName { - name: text.to_string(), - }) - })?; - Ok(Member::Geo([lat, long])) - } - None => { - if is_reserved_keyword(text) { - return Err(UserError::InvalidReservedRankingRuleName { - name: text.to_string(), - })?; - } - Ok(Member::Field(text.to_string())) - } - } - } -} - -impl fmt::Display for Member { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Member::Field(name) => f.write_str(name), - Member::Geo([lat, lng]) => write!(f, "_geoPoint({}, {})", lat, lng), - } - } -} - -impl Member { - pub fn field(&self) -> Option<&str> { - match self { - Member::Field(field) => Some(field), - Member::Geo(_) => None, - } - } - - pub fn geo_point(&self) -> Option<&[f64; 2]> { - match self { - Member::Geo(point) => Some(point), - Member::Field(_) => None, - } - } -} - -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] -pub enum AscDesc { - Asc(Member), - Desc(Member), -} - -impl AscDesc { - pub fn member(&self) -> &Member { - match self { - AscDesc::Asc(member) => member, - AscDesc::Desc(member) => member, - } - } - - pub fn field(&self) -> Option<&str> { - self.member().field() - } -} - -impl FromStr for AscDesc { - type Err = UserError; - - /// Since we don't know if this was deserialized for a criterion or a sort we just return a - /// string and let the caller create his own error - fn from_str(text: &str) -> Result { - match text.rsplit_once(':') { - Some((left, "asc")) => Ok(AscDesc::Asc(left.parse()?)), - Some((left, "desc")) => Ok(AscDesc::Desc(left.parse()?)), - _ => Err(UserError::InvalidAscDescSyntax { name: text.to_string() }), - } - } -} - pub fn default_criteria() -> Vec { vec![ Criterion::Words, @@ -191,59 +132,74 @@ impl fmt::Display for Criterion { #[cfg(test)] mod tests { + use big_s::S; + use CriterionError::*; + use super::*; #[test] - fn parse_asc_desc() { - use big_s::S; - use AscDesc::*; - use Member::*; - - let valid_req = [ - ("truc:asc", Asc(Field(S("truc")))), - ("bidule:desc", Desc(Field(S("bidule")))), - ("a-b:desc", Desc(Field(S("a-b")))), - ("a:b:desc", Desc(Field(S("a:b")))), - ("a12:asc", Asc(Field(S("a12")))), - ("42:asc", Asc(Field(S("42")))), - ("_geoPoint(42, 59):asc", Asc(Geo([42., 59.]))), - ("_geoPoint(42.459, 59):desc", Desc(Geo([42.459, 59.]))), - ("_geoPoint(42, 59.895):desc", Desc(Geo([42., 59.895]))), - ("_geoPoint(42, 59.895):desc", Desc(Geo([42., 59.895]))), - ("_geoPoint(42.0002, 59.895):desc", Desc(Geo([42.0002, 59.895]))), - ("_geoPoint(42., 59.):desc", Desc(Geo([42., 59.]))), - ("truc(12, 13):desc", Desc(Field(S("truc(12, 13)")))), + fn parse_criterion() { + let valid_criteria = [ + ("words", Criterion::Words), + ("typo", Criterion::Typo), + ("proximity", Criterion::Proximity), + ("attribute", Criterion::Attribute), + ("sort", Criterion::Sort), + ("exactness", Criterion::Exactness), + ("price:asc", Criterion::Asc(S("price"))), + ("price:desc", Criterion::Desc(S("price"))), + ("price:asc:desc", Criterion::Desc(S("price:asc"))), + ("truc:machin:desc", Criterion::Desc(S("truc:machin"))), + ("hello-world!:desc", Criterion::Desc(S("hello-world!"))), + ("it's spacy over there:asc", Criterion::Asc(S("it's spacy over there"))), ]; - for (req, expected) in valid_req { - let res = req.parse(); - assert!(res.is_ok(), "Failed to parse `{}`, was expecting `{:?}`", req, expected); - assert_eq!(expected, res.unwrap()); + for (input, expected) in valid_criteria { + let res = input.parse::(); + assert!( + res.is_ok(), + "Failed to parse `{}`, was expecting `{:?}` but instead got `{:?}`", + input, + expected, + res + ); + assert_eq!(res.unwrap(), expected); } - let invalid_req = [ - "truc:machin", - "truc:deesc", - "truc:asc:deesc", - "42desc", - "_geoPoint:asc", - "_geoDistance:asc", - "_geoPoint(42.12 , 59.598)", - "_geoPoint(42.12 , 59.598):deesc", - "_geoPoint(42.12 , 59.598):machin", - "_geoPoint(42.12 , 59.598):asc:aasc", - "_geoPoint(42,12 , 59,598):desc", - "_geoPoint(35, 85, 75):asc", - "_geoPoint(18):asc", + let invalid_criteria = [ + ("words suffix", InvalidName { name: S("words suffix") }), + ("prefix typo", InvalidName { name: S("prefix typo") }), + ("proximity attribute", InvalidName { name: S("proximity attribute") }), + ("price", InvalidName { name: S("price") }), + ("asc:price", InvalidName { name: S("asc:price") }), + ("price:deesc", InvalidName { name: S("price:deesc") }), + ("price:aasc", InvalidName { name: S("price:aasc") }), + ("price:asc and desc", InvalidName { name: S("price:asc and desc") }), + ("price:asc:truc", InvalidName { name: S("price:asc:truc") }), + ("_geo:asc", ReservedName { name: S("_geo") }), + ("_geoDistance:asc", ReservedName { name: S("_geoDistance") }), + ("_geoPoint:asc", ReservedNameForSort { name: S("_geoPoint") }), + ("_geoPoint(42, 75):asc", ReservedNameForSort { name: S("_geoPoint") }), + ("_geoRadius:asc", ReservedNameForFilter { name: S("_geoRadius") }), + ("_geoRadius(42, 75, 59):asc", ReservedNameForFilter { name: S("_geoRadius") }), ]; - for req in invalid_req { - let res = req.parse::(); + for (input, expected) in invalid_criteria { + let res = input.parse::(); assert!( res.is_err(), "Should no be able to parse `{}`, was expecting an error but instead got: `{:?}`", - req, + input, + res + ); + let res = res.unwrap_err(); + assert_eq!( + res.to_string(), + expected.to_string(), + "Bad error for input {}: got `{:?}` instead of `{:?}`", + input, res, + expected ); } } diff --git a/milli/src/error.rs b/milli/src/error.rs index fe0ac2cf7..bd4f02b99 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -8,7 +8,7 @@ use rayon::ThreadPoolBuildError; use serde_json::{Map, Value}; use crate::search::ParserRule; -use crate::{DocumentId, FieldId}; +use crate::{CriterionError, DocumentId, FieldId}; pub type Object = Map; @@ -55,16 +55,15 @@ pub enum FieldIdMapMissingEntry { #[derive(Debug)] pub enum UserError { AttributeLimitReached, + CriterionError(CriterionError), DocumentLimitReached, - InvalidAscDescSyntax { name: String }, InvalidDocumentId { document_id: Value }, InvalidFacetsDistribution { invalid_facets_name: HashSet }, InvalidFilter(pest::error::Error), InvalidFilterAttribute(pest::error::Error), InvalidSortName { name: String }, + InvalidReservedSortName { name: String }, InvalidGeoField { document_id: Value, object: Value }, - InvalidRankingRuleName { name: String }, - InvalidReservedRankingRuleName { name: String }, InvalidSortableAttribute { field: String, valid_fields: HashSet }, SortRankingRuleMissing, InvalidStoreFile, @@ -211,6 +210,7 @@ impl fmt::Display for UserError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { Self::AttributeLimitReached => f.write_str("maximum number of attributes reached"), + Self::CriterionError(error) => write!(f, "{}", error), Self::DocumentLimitReached => f.write_str("maximum number of documents reached"), Self::InvalidFacetsDistribution { invalid_facets_name } => { let name_list = @@ -222,17 +222,17 @@ impl fmt::Display for UserError { ) } Self::InvalidFilter(error) => error.fmt(f), - Self::InvalidAscDescSyntax { name } => { - write!(f, "invalid asc/desc syntax for {}", name) - } Self::InvalidGeoField { document_id, object } => write!( f, "the document with the id: {} contains an invalid _geo field: {}", document_id, object ), - Self::InvalidRankingRuleName { name } => write!(f, "invalid criterion {}", name), - Self::InvalidReservedRankingRuleName { name } => { - write!(f, "{} is a reserved keyword and thus can't be used as a ranking rule", name) + Self::InvalidReservedSortName { name } => { + write!( + f, + "{} is a reserved keyword and thus can't be used as a sort expression", + name + ) } Self::InvalidDocumentId { document_id } => { let json = serde_json::to_string(document_id).unwrap(); diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 550e7f13d..8a54bbbdf 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -4,6 +4,7 @@ extern crate pest_derive; #[macro_use] pub mod documents; +mod asc_desc; mod criterion; mod error; mod external_documents_ids; @@ -24,7 +25,8 @@ use fxhash::{FxHasher32, FxHasher64}; pub use grenad::CompressionType; use serde_json::{Map, Value}; -pub use self::criterion::{default_criteria, AscDesc, Criterion, Member}; +pub use self::asc_desc::{AscDesc, AscDescError, Member}; +pub use self::criterion::{default_criteria, Criterion, CriterionError}; pub use self::error::{ Error, FieldIdMapMissingEntry, InternalError, SerializationError, UserError, }; diff --git a/milli/src/search/criteria/mod.rs b/milli/src/search/criteria/mod.rs index c2de55de5..a23e5acf9 100644 --- a/milli/src/search/criteria/mod.rs +++ b/milli/src/search/criteria/mod.rs @@ -12,10 +12,9 @@ use self::r#final::Final; use self::typo::Typo; use self::words::Words; use super::query_tree::{Operation, PrimitiveQueryPart, Query, QueryKind}; -use crate::criterion::{AscDesc as AscDescName, Member}; use crate::search::criteria::geo::Geo; use crate::search::{word_derivations, WordDerivationsCache}; -use crate::{DocumentId, FieldId, Index, Result, TreeLevel}; +use crate::{AscDesc as AscDescName, DocumentId, FieldId, Index, Member, Result, TreeLevel}; mod asc_desc; mod attribute; diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index f752f5822..bec059d46 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -18,10 +18,9 @@ pub(crate) use self::facet::ParserRule; pub use self::facet::{FacetDistribution, FacetNumberIter, FilterCondition, Operator}; pub use self::matching_words::MatchingWords; use self::query_tree::QueryTreeBuilder; -use crate::criterion::{AscDesc, Criterion}; use crate::error::UserError; use crate::search::criteria::r#final::{Final, FinalResult}; -use crate::{DocumentId, Index, Result}; +use crate::{AscDesc, Criterion, DocumentId, Index, Member, Result}; // Building these factories is not free. static LEVDIST0: Lazy = Lazy::new(|| LevBuilder::new(0, true)); @@ -148,15 +147,20 @@ impl<'a> Search<'a> { if let Some(sort_criteria) = &self.sort_criteria { let sortable_fields = self.index.sortable_fields(self.rtxn)?; for asc_desc in sort_criteria { - // we are not supposed to find any geoPoint in the criterion - if let Some(field) = asc_desc.field() { - if !sortable_fields.contains(field) { + match asc_desc.member() { + Member::Field(ref field) if !sortable_fields.contains(field) => { return Err(UserError::InvalidSortableAttribute { field: field.to_string(), valid_fields: sortable_fields, - } - .into()); + })? } + Member::Geo(_) if !sortable_fields.contains("_geo") => { + return Err(UserError::InvalidSortableAttribute { + field: "_geo".to_string(), + valid_fields: sortable_fields, + })? + } + _ => (), } } }