diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 594cc60e0..3792204e9 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -39,7 +39,8 @@ tempfile = "3.2.0" uuid = { version = "0.8.2", features = ["v4"] } # facet filter parser -nom = "7.0.0" +pest = { git = "https://github.com/pest-parser/pest.git", rev = "51fd1d49f1041f7839975664ef71fe15c7dcaf67" } +pest_derive = "2.1.0" # documents words self-join itertools = "0.10.0" diff --git a/milli/src/error.rs b/milli/src/error.rs index c0ce101c8..1f1cc5264 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -7,6 +7,7 @@ use heed::{Error as HeedError, MdbError}; use rayon::ThreadPoolBuildError; use serde_json::{Map, Value}; +use crate::search::ParserRule; use crate::{CriterionError, DocumentId, FieldId, SortError}; pub type Object = Map; @@ -58,9 +59,9 @@ pub enum UserError { DocumentLimitReached, InvalidDocumentId { document_id: Value }, InvalidFacetsDistribution { invalid_facets_name: HashSet }, + InvalidFilter(pest::error::Error), + InvalidFilterAttribute(pest::error::Error), InvalidGeoField { document_id: Value, object: Value }, - InvalidFilter { input: String }, - InvalidSortName { name: String }, InvalidSortableAttribute { field: String, valid_fields: HashSet }, SortRankingRuleMissing, InvalidStoreFile, @@ -207,7 +208,6 @@ impl StdError for InternalError {} impl fmt::Display for UserError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - Self::InvalidFilter { input } => write!(f, "parser error {}", input), Self::AttributeLimitReached => f.write_str("maximum number of attributes reached"), Self::CriterionError(error) => write!(f, "{}", error), Self::DocumentLimitReached => f.write_str("maximum number of documents reached"), @@ -220,6 +220,7 @@ impl fmt::Display for UserError { name_list ) } + Self::InvalidFilter(error) => error.fmt(f), Self::InvalidGeoField { document_id, object } => write!( f, "the document with the id: {} contains an invalid _geo field: {}", @@ -235,9 +236,7 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco json ) } - Self::InvalidSortName { name } => { - write!(f, "Invalid syntax for the sort parameter: {}", name) - } + Self::InvalidFilterAttribute(error) => error.fmt(f), Self::InvalidSortableAttribute { field, valid_fields } => { let valid_names = valid_fields.iter().map(AsRef::as_ref).collect::>().join(", "); diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 6fe5947f5..781cedb2c 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -1,3 +1,6 @@ +#[macro_use] +extern crate pest_derive; + #[macro_use] pub mod documents; diff --git a/milli/src/search/facet/filter_condition.rs b/milli/src/search/facet/filter_condition.rs index 4fedeee69..f1055b2f8 100644 --- a/milli/src/search/facet/filter_condition.rs +++ b/milli/src/search/facet/filter_condition.rs @@ -1,20 +1,60 @@ +use std::collections::HashSet; use std::fmt::Debug; use std::ops::Bound::{self, Excluded, Included}; +use std::result::Result as StdResult; +use std::str::FromStr; use either::Either; use heed::types::DecodeIgnore; +use itertools::Itertools; use log::debug; -use nom::error::{convert_error, VerboseError}; +use pest::error::{Error as PestError, ErrorVariant}; +use pest::iterators::{Pair, Pairs}; +use pest::Parser; use roaring::RoaringBitmap; use self::FilterCondition::*; -use super::filter_parser::{Operator, ParseContext}; +use self::Operator::*; +use super::parser::{FilterParser, Rule, PREC_CLIMBER}; use super::FacetNumberRange; -use crate::error::{Error, UserError}; +use crate::error::UserError; use crate::heed_codec::facet::{ FacetLevelValueF64Codec, FacetStringLevelZeroCodec, FacetStringLevelZeroValueCodec, }; -use crate::{distance_between_two_points, CboRoaringBitmapCodec, FieldId, Index, Result}; +use crate::{ + distance_between_two_points, CboRoaringBitmapCodec, FieldId, FieldsIdsMap, Index, Result, +}; + +#[derive(Debug, Clone, PartialEq)] +pub enum Operator { + GreaterThan(f64), + GreaterThanOrEqual(f64), + Equal(Option, String), + NotEqual(Option, String), + LowerThan(f64), + LowerThanOrEqual(f64), + Between(f64, f64), + GeoLowerThan([f64; 2], f64), + GeoGreaterThan([f64; 2], f64), +} + +impl Operator { + /// This method can return two operations in case it must express + /// an OR operation for the between case (i.e. `TO`). + fn negate(self) -> (Self, Option) { + match self { + GreaterThan(n) => (LowerThanOrEqual(n), None), + GreaterThanOrEqual(n) => (LowerThan(n), None), + Equal(n, s) => (NotEqual(n, s), None), + NotEqual(n, s) => (Equal(n, s), None), + LowerThan(n) => (GreaterThanOrEqual(n), None), + LowerThanOrEqual(n) => (GreaterThan(n), None), + Between(n, m) => (LowerThan(n), Some(GreaterThan(m))), + GeoLowerThan(point, distance) => (GeoGreaterThan(point, distance), None), + GeoGreaterThan(point, distance) => (GeoLowerThan(point, distance), None), + } + } +} #[derive(Debug, Clone, PartialEq)] pub enum FilterCondition { @@ -36,7 +76,7 @@ impl FilterCondition { A: AsRef, B: AsRef, { - let mut ands: Option = None; + let mut ands = None; for either in array { match either { @@ -77,23 +117,41 @@ impl FilterCondition { ) -> Result { let fields_ids_map = index.fields_ids_map(rtxn)?; let filterable_fields = index.filterable_fields(rtxn)?; - let ctx = - ParseContext { fields_ids_map: &fields_ids_map, filterable_fields: &filterable_fields }; - match ctx.parse_expression::>(expression) { - Ok((_, fc)) => Ok(fc), - Err(e) => { - let ve = match e { - nom::Err::Error(x) => x, - nom::Err::Failure(x) => x, - _ => unreachable!(), - }; - Err(Error::UserError(UserError::InvalidFilter { - input: convert_error(expression, ve).to_string(), - })) - } - } + let lexed = + FilterParser::parse(Rule::prgm, expression).map_err(UserError::InvalidFilter)?; + FilterCondition::from_pairs(&fields_ids_map, &filterable_fields, lexed) } - pub fn negate(self) -> FilterCondition { + + fn from_pairs( + fim: &FieldsIdsMap, + ff: &HashSet, + expression: Pairs, + ) -> Result { + PREC_CLIMBER.climb( + expression, + |pair: Pair| match pair.as_rule() { + Rule::greater => Ok(Self::greater_than(fim, ff, pair)?), + Rule::geq => Ok(Self::greater_than_or_equal(fim, ff, pair)?), + Rule::eq => Ok(Self::equal(fim, ff, pair)?), + Rule::neq => Ok(Self::equal(fim, ff, pair)?.negate()), + Rule::leq => Ok(Self::lower_than_or_equal(fim, ff, pair)?), + Rule::less => Ok(Self::lower_than(fim, ff, pair)?), + Rule::between => Ok(Self::between(fim, ff, pair)?), + Rule::geo_radius => Ok(Self::geo_radius(fim, ff, pair)?), + Rule::not => Ok(Self::from_pairs(fim, ff, pair.into_inner())?.negate()), + Rule::prgm => Self::from_pairs(fim, ff, pair.into_inner()), + Rule::term => Self::from_pairs(fim, ff, pair.into_inner()), + _ => unreachable!(), + }, + |lhs: Result, op: Pair, rhs: Result| match op.as_rule() { + Rule::or => Ok(Or(Box::new(lhs?), Box::new(rhs?))), + Rule::and => Ok(And(Box::new(lhs?), Box::new(rhs?))), + _ => unreachable!(), + }, + ) + } + + fn negate(self) -> FilterCondition { match self { Operator(fid, op) => match op.negate() { (op, None) => Operator(fid, op), @@ -104,6 +162,189 @@ impl FilterCondition { Empty => Empty, } } + + fn geo_radius( + fields_ids_map: &FieldsIdsMap, + filterable_fields: &HashSet, + item: Pair, + ) -> Result { + if !filterable_fields.contains("_geo") { + return Err(UserError::InvalidFilterAttribute(PestError::new_from_span( + ErrorVariant::CustomError { + message: format!( + "attribute `_geo` is not filterable, available filterable attributes are: {}", + filterable_fields.iter().join(", "), + ), + }, + item.as_span(), + )))?; + } + let mut items = item.into_inner(); + let fid = match fields_ids_map.id("_geo") { + Some(fid) => fid, + None => return Ok(Empty), + }; + let parameters_item = items.next().unwrap(); + // We don't need more than 3 parameters, but to handle errors correctly we are still going + // to extract the first 4 parameters + let param_span = parameters_item.as_span(); + let parameters = parameters_item + .into_inner() + .take(4) + .map(|param| (param.clone(), param.as_span())) + .map(|(param, span)| pest_parse(param).0.map(|arg| (arg, span))) + .collect::, _>>() + .map_err(UserError::InvalidFilter)?; + if parameters.len() != 3 { + return Err(UserError::InvalidFilter(PestError::new_from_span( + ErrorVariant::CustomError { + message: format!("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`"), + }, + // we want to point to the last parameters and if there was no parameters we + // point to the parenthesis + parameters.last().map(|param| param.1.clone()).unwrap_or(param_span), + )))?; + } + let (lat, lng, distance) = (¶meters[0], ¶meters[1], parameters[2].0); + if !(-90.0..=90.0).contains(&lat.0) { + return Err(UserError::InvalidFilter(PestError::new_from_span( + ErrorVariant::CustomError { + message: format!("Latitude must be contained between -90 and 90 degrees."), + }, + lat.1.clone(), + )))?; + } else if !(-180.0..=180.0).contains(&lng.0) { + return Err(UserError::InvalidFilter(PestError::new_from_span( + ErrorVariant::CustomError { + message: format!("Longitude must be contained between -180 and 180 degrees."), + }, + lng.1.clone(), + )))?; + } + Ok(Operator(fid, GeoLowerThan([lat.0, lng.0], distance))) + } + + fn between( + fields_ids_map: &FieldsIdsMap, + filterable_fields: &HashSet, + item: Pair, + ) -> Result { + let mut items = item.into_inner(); + let fid = match field_id(fields_ids_map, filterable_fields, &mut items) + .map_err(UserError::InvalidFilterAttribute)? + { + Some(fid) => fid, + None => return Ok(Empty), + }; + + let (lresult, _) = pest_parse(items.next().unwrap()); + let (rresult, _) = pest_parse(items.next().unwrap()); + + let lvalue = lresult.map_err(UserError::InvalidFilter)?; + let rvalue = rresult.map_err(UserError::InvalidFilter)?; + + Ok(Operator(fid, Between(lvalue, rvalue))) + } + + fn equal( + fields_ids_map: &FieldsIdsMap, + filterable_fields: &HashSet, + item: Pair, + ) -> Result { + let mut items = item.into_inner(); + let fid = match field_id(fields_ids_map, filterable_fields, &mut items) + .map_err(UserError::InvalidFilterAttribute)? + { + Some(fid) => fid, + None => return Ok(Empty), + }; + + let value = items.next().unwrap(); + let (result, svalue) = pest_parse(value); + + let svalue = svalue.to_lowercase(); + Ok(Operator(fid, Equal(result.ok(), svalue))) + } + + fn greater_than( + fields_ids_map: &FieldsIdsMap, + filterable_fields: &HashSet, + item: Pair, + ) -> Result { + let mut items = item.into_inner(); + let fid = match field_id(fields_ids_map, filterable_fields, &mut items) + .map_err(UserError::InvalidFilterAttribute)? + { + Some(fid) => fid, + None => return Ok(Empty), + }; + + let value = items.next().unwrap(); + let (result, _svalue) = pest_parse(value); + let value = result.map_err(UserError::InvalidFilter)?; + + Ok(Operator(fid, GreaterThan(value))) + } + + fn greater_than_or_equal( + fields_ids_map: &FieldsIdsMap, + filterable_fields: &HashSet, + item: Pair, + ) -> Result { + let mut items = item.into_inner(); + let fid = match field_id(fields_ids_map, filterable_fields, &mut items) + .map_err(UserError::InvalidFilterAttribute)? + { + Some(fid) => fid, + None => return Ok(Empty), + }; + + let value = items.next().unwrap(); + let (result, _svalue) = pest_parse(value); + let value = result.map_err(UserError::InvalidFilter)?; + + Ok(Operator(fid, GreaterThanOrEqual(value))) + } + + fn lower_than( + fields_ids_map: &FieldsIdsMap, + filterable_fields: &HashSet, + item: Pair, + ) -> Result { + let mut items = item.into_inner(); + let fid = match field_id(fields_ids_map, filterable_fields, &mut items) + .map_err(UserError::InvalidFilterAttribute)? + { + Some(fid) => fid, + None => return Ok(Empty), + }; + + let value = items.next().unwrap(); + let (result, _svalue) = pest_parse(value); + let value = result.map_err(UserError::InvalidFilter)?; + + Ok(Operator(fid, LowerThan(value))) + } + + fn lower_than_or_equal( + fields_ids_map: &FieldsIdsMap, + filterable_fields: &HashSet, + item: Pair, + ) -> Result { + let mut items = item.into_inner(); + let fid = match field_id(fields_ids_map, filterable_fields, &mut items) + .map_err(UserError::InvalidFilterAttribute)? + { + Some(fid) => fid, + None => return Ok(Empty), + }; + + let value = items.next().unwrap(); + let (result, _svalue) = pest_parse(value); + let value = result.map_err(UserError::InvalidFilter)?; + + Ok(Operator(fid, LowerThanOrEqual(value))) + } } impl FilterCondition { @@ -227,9 +468,9 @@ impl FilterCondition { // as the facets values are all in the same database and prefixed by the // field id and the level. let (left, right) = match operator { - Operator::GreaterThan(val) => (Excluded(*val), Included(f64::MAX)), - Operator::GreaterThanOrEqual(val) => (Included(*val), Included(f64::MAX)), - Operator::Equal(number, string) => { + GreaterThan(val) => (Excluded(*val), Included(f64::MAX)), + GreaterThanOrEqual(val) => (Included(*val), Included(f64::MAX)), + Equal(number, string) => { let (_original_value, string_docids) = strings_db.get(rtxn, &(field_id, &string))?.unwrap_or_default(); let number_docids = match number { @@ -251,23 +492,23 @@ impl FilterCondition { }; return Ok(string_docids | number_docids); } - Operator::NotEqual(number, string) => { + NotEqual(number, string) => { let all_numbers_ids = if number.is_some() { index.number_faceted_documents_ids(rtxn, field_id)? } else { RoaringBitmap::new() }; let all_strings_ids = index.string_faceted_documents_ids(rtxn, field_id)?; - let operator = Operator::Equal(*number, string.clone()); + let operator = Equal(*number, string.clone()); let docids = Self::evaluate_operator( rtxn, index, numbers_db, strings_db, field_id, &operator, )?; return Ok((all_numbers_ids | all_strings_ids) - docids); } - Operator::LowerThan(val) => (Included(f64::MIN), Excluded(*val)), - Operator::LowerThanOrEqual(val) => (Included(f64::MIN), Included(*val)), - Operator::Between(left, right) => (Included(*left), Included(*right)), - Operator::GeoLowerThan(base_point, distance) => { + LowerThan(val) => (Included(f64::MIN), Excluded(*val)), + LowerThanOrEqual(val) => (Included(f64::MIN), Included(*val)), + Between(left, right) => (Included(*left), Included(*right)), + GeoLowerThan(base_point, distance) => { let rtree = match index.geo_rtree(rtxn)? { Some(rtree) => rtree, None => return Ok(RoaringBitmap::new()), @@ -283,14 +524,14 @@ impl FilterCondition { return Ok(result); } - Operator::GeoGreaterThan(point, distance) => { + GeoGreaterThan(point, distance) => { let result = Self::evaluate_operator( rtxn, index, numbers_db, strings_db, field_id, - &Operator::GeoLowerThan(point.clone(), *distance), + &GeoLowerThan(point.clone(), *distance), )?; let geo_faceted_doc_ids = index.geo_faceted_documents_ids(rtxn)?; return Ok(geo_faceted_doc_ids - result); @@ -344,3 +585,406 @@ impl FilterCondition { } } } + +/// Retrieve the field id base on the pest value. +/// +/// Returns an error if the given value is not filterable. +/// +/// Returns Ok(None) if the given value is filterable, but is not yet ascociated to a field_id. +/// +/// The pest pair is simply a string associated with a span, a location to highlight in +/// the error message. +fn field_id( + fields_ids_map: &FieldsIdsMap, + filterable_fields: &HashSet, + items: &mut Pairs, +) -> StdResult, PestError> { + // lexing ensures that we at least have a key + let key = items.next().unwrap(); + if key.as_rule() == Rule::reserved { + let message = match key.as_str() { + key if key.starts_with("_geoPoint") => { + format!( + "`_geoPoint` is a reserved keyword and thus can't be used as a filter expression. \ + Use the `_geoRadius(latitude, longitude, distance)` built-in rule to filter on `_geo` field coordinates.", + ) + } + key @ "_geo" => { + format!( + "`{}` is a reserved keyword and thus can't be used as a filter expression. \ + Use the `_geoRadius(latitude, longitude, distance)` built-in rule to filter on `_geo` field coordinates.", + key + ) + } + key => format!( + "`{}` is a reserved keyword and thus can't be used as a filter expression.", + key + ), + }; + return Err(PestError::new_from_span(ErrorVariant::CustomError { message }, key.as_span())); + } + + if !filterable_fields.contains(key.as_str()) { + return Err(PestError::new_from_span( + ErrorVariant::CustomError { + message: format!( + "attribute `{}` is not filterable, available filterable attributes are: {}.", + key.as_str(), + filterable_fields.iter().join(", "), + ), + }, + key.as_span(), + )); + } + + Ok(fields_ids_map.id(key.as_str())) +} + +/// Tries to parse the pest pair into the type `T` specified, always returns +/// the original string that we tried to parse. +/// +/// Returns the parsing error associated with the span if the conversion fails. +fn pest_parse(pair: Pair) -> (StdResult>, String) +where + T: FromStr, + T::Err: ToString, +{ + let result = match pair.as_str().parse::() { + Ok(value) => Ok(value), + Err(e) => Err(PestError::::new_from_span( + ErrorVariant::CustomError { message: e.to_string() }, + pair.as_span(), + )), + }; + + (result, pair.as_str().to_string()) +} + +#[cfg(test)] +mod tests { + use big_s::S; + use heed::EnvOpenOptions; + use maplit::hashset; + + use super::*; + use crate::update::Settings; + + #[test] + fn string() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut map = index.fields_ids_map(&wtxn).unwrap(); + map.insert("channel"); + index.put_fields_ids_map(&mut wtxn, &map).unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_filterable_fields(hashset! { S("channel") }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + // Test that the facet condition is correctly generated. + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_str(&rtxn, &index, "channel = Ponce").unwrap(); + let expected = Operator(0, Operator::Equal(None, S("ponce"))); + assert_eq!(condition, expected); + + let condition = FilterCondition::from_str(&rtxn, &index, "channel != ponce").unwrap(); + let expected = Operator(0, Operator::NotEqual(None, S("ponce"))); + assert_eq!(condition, expected); + + let condition = FilterCondition::from_str(&rtxn, &index, "NOT channel = ponce").unwrap(); + let expected = Operator(0, Operator::NotEqual(None, S("ponce"))); + assert_eq!(condition, expected); + } + + #[test] + fn number() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut map = index.fields_ids_map(&wtxn).unwrap(); + map.insert("timestamp"); + index.put_fields_ids_map(&mut wtxn, &map).unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_filterable_fields(hashset! { "timestamp".into() }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + // Test that the facet condition is correctly generated. + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_str(&rtxn, &index, "timestamp 22 TO 44").unwrap(); + let expected = Operator(0, Between(22.0, 44.0)); + assert_eq!(condition, expected); + + let condition = FilterCondition::from_str(&rtxn, &index, "NOT timestamp 22 TO 44").unwrap(); + let expected = + Or(Box::new(Operator(0, LowerThan(22.0))), Box::new(Operator(0, GreaterThan(44.0)))); + assert_eq!(condition, expected); + } + + #[test] + fn parentheses() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_searchable_fields(vec![S("channel"), S("timestamp")]); // to keep the fields order + builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + // Test that the facet condition is correctly generated. + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_str( + &rtxn, + &index, + "channel = gotaga OR (timestamp 22 TO 44 AND channel != ponce)", + ) + .unwrap(); + let expected = Or( + Box::new(Operator(0, Operator::Equal(None, S("gotaga")))), + Box::new(And( + Box::new(Operator(1, Between(22.0, 44.0))), + Box::new(Operator(0, Operator::NotEqual(None, S("ponce")))), + )), + ); + assert_eq!(condition, expected); + + let condition = FilterCondition::from_str( + &rtxn, + &index, + "channel = gotaga OR NOT (timestamp 22 TO 44 AND channel != ponce)", + ) + .unwrap(); + let expected = Or( + Box::new(Operator(0, Operator::Equal(None, S("gotaga")))), + Box::new(Or( + Box::new(Or( + Box::new(Operator(1, LowerThan(22.0))), + Box::new(Operator(1, GreaterThan(44.0))), + )), + Box::new(Operator(0, Operator::Equal(None, S("ponce")))), + )), + ); + assert_eq!(condition, expected); + } + + #[test] + fn reserved_field_names() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + let rtxn = index.read_txn().unwrap(); + + let error = FilterCondition::from_str(&rtxn, &index, "_geo = 12").unwrap_err(); + assert!(error + .to_string() + .contains("`_geo` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` built-in rule to filter on `_geo` field coordinates."), + "{}", + error.to_string() + ); + + let error = + FilterCondition::from_str(&rtxn, &index, r#"_geoDistance <= 1000"#).unwrap_err(); + assert!(error + .to_string() + .contains("`_geoDistance` is a reserved keyword and thus can't be used as a filter expression."), + "{}", + error.to_string() + ); + + let error = FilterCondition::from_str(&rtxn, &index, r#"_geoPoint > 5"#).unwrap_err(); + assert!(error + .to_string() + .contains("`_geoPoint` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` built-in rule to filter on `_geo` field coordinates."), + "{}", + error.to_string() + ); + + let error = + FilterCondition::from_str(&rtxn, &index, r#"_geoPoint(12, 16) > 5"#).unwrap_err(); + assert!(error + .to_string() + .contains("`_geoPoint` is a reserved keyword and thus can't be used as a filter expression. Use the `_geoRadius(latitude, longitude, distance)` built-in rule to filter on `_geo` field coordinates."), + "{}", + error.to_string() + ); + } + + #[test] + fn geo_radius() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_searchable_fields(vec![S("_geo"), S("price")]); // to keep the fields order + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + let rtxn = index.read_txn().unwrap(); + // _geo is not filterable + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(12, 12, 10)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error + .to_string() + .contains("attribute `_geo` is not filterable, available filterable attributes are:"),); + + let mut wtxn = index.write_txn().unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_filterable_fields(hashset! { S("_geo"), S("price") }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + let rtxn = index.read_txn().unwrap(); + // basic test + let condition = + FilterCondition::from_str(&rtxn, &index, "_geoRadius(12, 13.0005, 2000)").unwrap(); + let expected = Operator(0, GeoLowerThan([12., 13.0005], 2000.)); + assert_eq!(condition, expected); + + // basic test with latitude and longitude at the max angle + let condition = + FilterCondition::from_str(&rtxn, &index, "_geoRadius(90, 180, 2000)").unwrap(); + let expected = Operator(0, GeoLowerThan([90., 180.], 2000.)); + assert_eq!(condition, expected); + + // basic test with latitude and longitude at the min angle + let condition = + FilterCondition::from_str(&rtxn, &index, "_geoRadius(-90, -180, 2000)").unwrap(); + let expected = Operator(0, GeoLowerThan([-90., -180.], 2000.)); + assert_eq!(condition, expected); + + // test the negation of the GeoLowerThan + let condition = + FilterCondition::from_str(&rtxn, &index, "NOT _geoRadius(50, 18, 2000.500)").unwrap(); + let expected = Operator(0, GeoGreaterThan([50., 18.], 2000.500)); + assert_eq!(condition, expected); + + // composition of multiple operations + let condition = FilterCondition::from_str( + &rtxn, + &index, + "(NOT _geoRadius(1, 2, 300) AND _geoRadius(1.001, 2.002, 1000.300)) OR price <= 10", + ) + .unwrap(); + let expected = Or( + Box::new(And( + Box::new(Operator(0, GeoGreaterThan([1., 2.], 300.))), + Box::new(Operator(0, GeoLowerThan([1.001, 2.002], 1000.300))), + )), + Box::new(Operator(1, LowerThanOrEqual(10.))), + ); + assert_eq!(condition, expected); + + // georadius don't have any parameters + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); + + // georadius don't have any parameters + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius()"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); + + // georadius don't have enough parameters + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(1, 2)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); + + // georadius have too many parameters + let result = + FilterCondition::from_str(&rtxn, &index, "_geoRadius(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); + + // georadius have a bad latitude + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-100, 150, 10)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error + .to_string() + .contains("Latitude must be contained between -90 and 90 degrees.")); + + // georadius have a bad latitude + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-90.0000001, 150, 10)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error + .to_string() + .contains("Latitude must be contained between -90 and 90 degrees.")); + + // georadius have a bad longitude + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-10, 250, 10)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error + .to_string() + .contains("Longitude must be contained between -180 and 180 degrees.")); + + // georadius have a bad longitude + let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-10, 180.000001, 10)"); + assert!(result.is_err()); + let error = result.unwrap_err(); + assert!(error + .to_string() + .contains("Longitude must be contained between -180 and 180 degrees.")); + } + + #[test] + fn from_array() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut builder = Settings::new(&mut wtxn, &index, 0); + builder.set_searchable_fields(vec![S("channel"), S("timestamp")]); // to keep the fields order + builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") }); + builder.execute(|_, _| ()).unwrap(); + wtxn.commit().unwrap(); + + // Test that the facet condition is correctly generated. + let rtxn = index.read_txn().unwrap(); + let condition = FilterCondition::from_array( + &rtxn, + &index, + vec![ + Either::Right("channel = gotaga"), + Either::Left(vec!["timestamp = 44", "channel != ponce"]), + ], + ) + .unwrap() + .unwrap(); + let expected = FilterCondition::from_str( + &rtxn, + &index, + "channel = gotaga AND (timestamp = 44 OR channel != ponce)", + ) + .unwrap(); + assert_eq!(condition, expected); + } +} diff --git a/milli/src/search/facet/filter_parser.rs b/milli/src/search/facet/filter_parser.rs deleted file mode 100644 index 4d8a54987..000000000 --- a/milli/src/search/facet/filter_parser.rs +++ /dev/null @@ -1,622 +0,0 @@ -use std::collections::HashSet; -use std::fmt::Debug; -use std::result::Result as StdResult; - -use nom::branch::alt; -use nom::bytes::complete::{tag, take_while1}; -use nom::character::complete::{char, multispace0}; -use nom::combinator::map; -use nom::error::{ContextError, ErrorKind, VerboseError}; -use nom::multi::{many0, separated_list1}; -use nom::sequence::{delimited, preceded, tuple}; -use nom::IResult; - -use self::Operator::*; -use super::FilterCondition; -use crate::{FieldId, FieldsIdsMap}; -#[derive(Debug, Clone, PartialEq)] -pub enum Operator { - GreaterThan(f64), - GreaterThanOrEqual(f64), - Equal(Option, String), - NotEqual(Option, String), - LowerThan(f64), - LowerThanOrEqual(f64), - Between(f64, f64), - GeoLowerThan([f64; 2], f64), - GeoGreaterThan([f64; 2], f64), -} - -impl Operator { - /// This method can return two operations in case it must express - /// an OR operation for the between case (i.e. `TO`). - pub fn negate(self) -> (Self, Option) { - match self { - GreaterThan(n) => (LowerThanOrEqual(n), None), - GreaterThanOrEqual(n) => (LowerThan(n), None), - Equal(n, s) => (NotEqual(n, s), None), - NotEqual(n, s) => (Equal(n, s), None), - LowerThan(n) => (GreaterThanOrEqual(n), None), - LowerThanOrEqual(n) => (GreaterThan(n), None), - Between(n, m) => (LowerThan(n), Some(GreaterThan(m))), - GeoLowerThan(point, distance) => (GeoGreaterThan(point, distance), None), - GeoGreaterThan(point, distance) => (GeoLowerThan(point, distance), None), - } - } -} - -pub trait FilterParserError<'a>: - nom::error::ParseError<&'a str> + ContextError<&'a str> + std::fmt::Debug -{ -} - -impl<'a> FilterParserError<'a> for VerboseError<&'a str> {} - -pub struct ParseContext<'a> { - pub fields_ids_map: &'a FieldsIdsMap, - pub filterable_fields: &'a HashSet, -} - -impl<'a> ParseContext<'a> { - fn parse_or(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: FilterParserError<'a>, - { - let (input, lhs) = self.parse_and(input)?; - let (input, ors) = many0(preceded(self.ws(tag("OR")), |c| Self::parse_or(self, c)))(input)?; - - let expr = ors - .into_iter() - .fold(lhs, |acc, branch| FilterCondition::Or(Box::new(acc), Box::new(branch))); - Ok((input, expr)) - } - - fn parse_and(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: FilterParserError<'a>, - { - let (input, lhs) = self.parse_not(input)?; - let (input, ors) = - many0(preceded(self.ws(tag("AND")), |c| Self::parse_and(self, c)))(input)?; - let expr = ors - .into_iter() - .fold(lhs, |acc, branch| FilterCondition::And(Box::new(acc), Box::new(branch))); - Ok((input, expr)) - } - - fn parse_not(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: FilterParserError<'a>, - { - alt(( - map( - preceded(alt((self.ws(tag("!")), self.ws(tag("NOT")))), |c| { - Self::parse_condition_expression(self, c) - }), - |e| e.negate(), - ), - |c| Self::parse_condition_expression(self, c), - ))(input) - } - - fn ws(&'a self, inner: F) -> impl FnMut(&'a str) -> IResult<&'a str, O, E> - where - F: Fn(&'a str) -> IResult<&'a str, O, E>, - E: FilterParserError<'a>, - { - delimited(multispace0, inner, multispace0) - } - - fn parse_simple_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: FilterParserError<'a>, - { - let operator = alt((tag("<="), tag(">="), tag(">"), tag("="), tag("<"), tag("!="))); - let k = tuple((self.ws(|c| self.parse_key(c)), operator, self.ws(|c| self.parse_key(c))))( - input, - ); - let (input, (key, op, value)) = match k { - Ok(o) => o, - Err(e) => { - return Err(e); - } - }; - - let fid = self.parse_fid(input, key)?; - let r: StdResult>> = self.parse_numeric(value); - let k = match op { - "=" => FilterCondition::Operator(fid, Equal(r.ok(), value.to_string().to_lowercase())), - "!=" => { - FilterCondition::Operator(fid, NotEqual(r.ok(), value.to_string().to_lowercase())) - } - ">" | "<" | "<=" | ">=" => return self.parse_numeric_unary_condition(op, fid, value), - _ => unreachable!(), - }; - Ok((input, k)) - } - - fn parse_numeric(&'a self, input: &'a str) -> StdResult> - where - E: FilterParserError<'a>, - T: std::str::FromStr, - { - match input.parse::() { - Ok(n) => Ok(n), - Err(_) => { - return match input.chars().nth(0) { - Some(ch) => Err(nom::Err::Failure(E::from_char(input, ch))), - None => Err(nom::Err::Failure(E::from_error_kind(input, ErrorKind::Eof))), - }; - } - } - } - - fn parse_numeric_unary_condition( - &'a self, - input: &'a str, - fid: u16, - value: &'a str, - ) -> IResult<&'a str, FilterCondition, E> - where - E: FilterParserError<'a>, - { - let numeric: f64 = self.parse_numeric(value)?; - let k = match input { - ">" => FilterCondition::Operator(fid, GreaterThan(numeric)), - "<" => FilterCondition::Operator(fid, LowerThan(numeric)), - "<=" => FilterCondition::Operator(fid, LowerThanOrEqual(numeric)), - ">=" => FilterCondition::Operator(fid, GreaterThanOrEqual(numeric)), - _ => unreachable!(), - }; - Ok((input, k)) - } - - fn parse_fid(&'a self, input: &'a str, key: &'a str) -> StdResult> - where - E: FilterParserError<'a>, - { - let error = match input.chars().nth(0) { - Some(ch) => Err(nom::Err::Failure(E::from_char(input, ch))), - None => Err(nom::Err::Failure(E::from_error_kind(input, ErrorKind::Eof))), - }; - if !self.filterable_fields.contains(key) { - return error; - } - match self.fields_ids_map.id(key) { - Some(fid) => Ok(fid), - None => error, - } - } - - fn parse_range_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: FilterParserError<'a>, - { - let (input, (key, from, _, to)) = tuple(( - self.ws(|c| self.parse_key(c)), - self.ws(|c| self.parse_key(c)), - tag("TO"), - self.ws(|c| self.parse_key(c)), - ))(input)?; - - let fid = self.parse_fid(input, key)?; - let numeric_from: f64 = self.parse_numeric(from)?; - let numeric_to: f64 = self.parse_numeric(to)?; - let res = FilterCondition::Operator(fid, Between(numeric_from, numeric_to)); - - Ok((input, res)) - } - - fn parse_geo_radius(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: FilterParserError<'a>, - { - let err_msg_args_incomplete= "_geoRadius. The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`"; - let err_msg_latitude_invalid = - "_geoRadius. Latitude must be contained between -90 and 90 degrees."; - - let err_msg_longitude_invalid = - "_geoRadius. Longitude must be contained between -180 and 180 degrees."; - - let (input, args): (&str, Vec<&str>) = match preceded( - tag("_geoRadius"), - delimited( - char('('), - separated_list1(tag(","), self.ws(|c| self.parse_value::(c))), - char(')'), - ), - )(input) - { - Ok(e) => e, - Err(_e) => { - return Err(nom::Err::Failure(E::add_context( - input, - err_msg_args_incomplete, - E::from_char(input, '('), - ))); - } - }; - - if args.len() != 3 { - let e = E::from_char(input, '('); - return Err(nom::Err::Failure(E::add_context(input, err_msg_args_incomplete, e))); - } - let lat = self.parse_numeric(args[0])?; - let lng = self.parse_numeric(args[1])?; - let dis = self.parse_numeric(args[2])?; - - let fid = match self.fields_ids_map.id("_geo") { - Some(fid) => fid, - None => return Ok((input, FilterCondition::Empty)), - }; - - if !(-90.0..=90.0).contains(&lat) { - return Err(nom::Err::Failure(E::add_context( - input, - err_msg_latitude_invalid, - E::from_char(input, '('), - ))); - } else if !(-180.0..=180.0).contains(&lng) { - return Err(nom::Err::Failure(E::add_context( - input, - err_msg_longitude_invalid, - E::from_char(input, '('), - ))); - } - - let res = FilterCondition::Operator(fid, GeoLowerThan([lat, lng], dis)); - Ok((input, res)) - } - - fn parse_condition(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: FilterParserError<'a>, - { - let l1 = |c| self.parse_simple_condition(c); - let l2 = |c| self.parse_range_condition(c); - let l3 = |c| self.parse_geo_radius(c); - alt((l1, l2, l3))(input) - } - - fn parse_condition_expression(&'a self, input: &'a str) -> IResult<&str, FilterCondition, E> - where - E: FilterParserError<'a>, - { - alt(( - delimited(self.ws(char('(')), |c| Self::parse_expression(self, c), self.ws(char(')'))), - |c| Self::parse_condition(self, c), - ))(input) - } - - fn parse_key(&'a self, input: &'a str) -> IResult<&'a str, &'a str, E> - where - E: FilterParserError<'a>, - { - let key = |input| take_while1(Self::is_key_component)(input); - alt((key, delimited(char('"'), key, char('"'))))(input) - } - - fn parse_value(&'a self, input: &'a str) -> IResult<&'a str, &'a str, E> - where - E: FilterParserError<'a>, - { - let key = |input| take_while1(Self::is_key_component)(input); - alt((key, delimited(char('"'), key, char('"'))))(input) - } - - fn is_key_component(c: char) -> bool { - c.is_alphanumeric() || ['_', '-', '.'].contains(&c) - } - - pub fn parse_expression(&'a self, input: &'a str) -> IResult<&'a str, FilterCondition, E> - where - E: FilterParserError<'a>, - { - self.parse_or(input) - } -} - -#[cfg(test)] -mod tests { - use big_s::S; - use either::Either; - use heed::EnvOpenOptions; - use maplit::hashset; - - use super::*; - use crate::update::Settings; - use crate::Index; - - #[test] - fn string() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut map = index.fields_ids_map(&wtxn).unwrap(); - map.insert("channel"); - index.put_fields_ids_map(&mut wtxn, &map).unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_filterable_fields(hashset! { S("channel") }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - // Test that the facet condition is correctly generated. - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_str(&rtxn, &index, "channel = Ponce").unwrap(); - let expected = FilterCondition::Operator(0, Operator::Equal(None, S("ponce"))); - assert_eq!(condition, expected); - - let condition = FilterCondition::from_str(&rtxn, &index, "channel != ponce").unwrap(); - let expected = FilterCondition::Operator(0, Operator::NotEqual(None, S("ponce"))); - assert_eq!(condition, expected); - - let condition = FilterCondition::from_str(&rtxn, &index, "NOT channel = ponce").unwrap(); - let expected = FilterCondition::Operator(0, Operator::NotEqual(None, S("ponce"))); - assert_eq!(condition, expected); - } - - #[test] - fn number() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut map = index.fields_ids_map(&wtxn).unwrap(); - map.insert("timestamp"); - index.put_fields_ids_map(&mut wtxn, &map).unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_filterable_fields(hashset! { "timestamp".into() }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - // Test that the facet condition is correctly generated. - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_str(&rtxn, &index, "timestamp 22 TO 44").unwrap(); - let expected = FilterCondition::Operator(0, Between(22.0, 44.0)); - assert_eq!(condition, expected); - - let condition = FilterCondition::from_str(&rtxn, &index, "NOT timestamp 22 TO 44").unwrap(); - let expected = FilterCondition::Or( - Box::new(FilterCondition::Operator(0, LowerThan(22.0))), - Box::new(FilterCondition::Operator(0, GreaterThan(44.0))), - ); - assert_eq!(condition, expected); - } - - #[test] - fn compare() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - let mut wtxn = index.write_txn().unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_searchable_fields(vec![S("channel"), S("timestamp"), S("id")]); // to keep the fields order - builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") ,S("id")}); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_str(&rtxn, &index, "channel < 20").unwrap(); - let expected = FilterCondition::Operator(0, LowerThan(20.0)); - assert_eq!(condition, expected); - - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_str(&rtxn, &index, "id < 200").unwrap(); - let expected = FilterCondition::Operator(2, LowerThan(200.0)); - assert_eq!(condition, expected); - } - - #[test] - fn parentheses() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_searchable_fields(vec![S("channel"), S("timestamp")]); // to keep the fields order - builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - // Test that the facet condition is correctly generated. - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_str( - &rtxn, - &index, - "channel = gotaga OR (timestamp 22 TO 44 AND channel != ponce)", - ) - .unwrap(); - let expected = FilterCondition::Or( - Box::new(FilterCondition::Operator(0, Operator::Equal(None, S("gotaga")))), - Box::new(FilterCondition::And( - Box::new(FilterCondition::Operator(1, Between(22.0, 44.0))), - Box::new(FilterCondition::Operator(0, Operator::NotEqual(None, S("ponce")))), - )), - ); - assert_eq!(condition, expected); - - let condition = FilterCondition::from_str( - &rtxn, - &index, - "channel = gotaga OR NOT (timestamp 22 TO 44 AND channel != ponce)", - ) - .unwrap(); - let expected = FilterCondition::Or( - Box::new(FilterCondition::Operator(0, Operator::Equal(None, S("gotaga")))), - Box::new(FilterCondition::Or( - Box::new(FilterCondition::Or( - Box::new(FilterCondition::Operator(1, LowerThan(22.0))), - Box::new(FilterCondition::Operator(1, GreaterThan(44.0))), - )), - Box::new(FilterCondition::Operator(0, Operator::Equal(None, S("ponce")))), - )), - ); - assert_eq!(condition, expected); - } - - #[test] - fn from_array() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_searchable_fields(vec![S("channel"), S("timestamp")]); // to keep the fields order - builder.set_filterable_fields(hashset! { S("channel"), S("timestamp") }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - // Test that the facet condition is correctly generated. - let rtxn = index.read_txn().unwrap(); - let condition = FilterCondition::from_array( - &rtxn, - &index, - vec![ - Either::Right("channel = gotaga"), - Either::Left(vec!["timestamp = 44", "channel != ponce"]), - ], - ) - .unwrap() - .unwrap(); - let expected = FilterCondition::from_str( - &rtxn, - &index, - "channel = gotaga AND (timestamp = 44 OR channel != ponce)", - ) - .unwrap(); - assert_eq!(condition, expected); - } - #[test] - fn geo_radius() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_searchable_fields(vec![S("_geo"), S("price")]); // to keep the fields order - builder.set_filterable_fields(hashset! { S("_geo"), S("price") }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - let rtxn = index.read_txn().unwrap(); - // basic test - let condition = - FilterCondition::from_str(&rtxn, &index, "_geoRadius(12, 13.0005, 2000)").unwrap(); - let expected = FilterCondition::Operator(0, GeoLowerThan([12., 13.0005], 2000.)); - assert_eq!(condition, expected); - - // test the negation of the GeoLowerThan - let condition = - FilterCondition::from_str(&rtxn, &index, "NOT _geoRadius(50, 18, 2000.500)").unwrap(); - let expected = FilterCondition::Operator(0, GeoGreaterThan([50., 18.], 2000.500)); - assert_eq!(condition, expected); - - // composition of multiple operations - let condition = FilterCondition::from_str( - &rtxn, - &index, - "(NOT _geoRadius(1, 2, 300) AND _geoRadius(1.001, 2.002, 1000.300)) OR price <= 10", - ) - .unwrap(); - let expected = FilterCondition::Or( - Box::new(FilterCondition::And( - Box::new(FilterCondition::Operator(0, GeoGreaterThan([1., 2.], 300.))), - Box::new(FilterCondition::Operator(0, GeoLowerThan([1.001, 2.002], 1000.300))), - )), - Box::new(FilterCondition::Operator(1, LowerThanOrEqual(10.))), - ); - assert_eq!(condition, expected); - } - - #[test] - fn geo_radius_error() { - let path = tempfile::tempdir().unwrap(); - let mut options = EnvOpenOptions::new(); - options.map_size(10 * 1024 * 1024); // 10 MB - let index = Index::new(options, &path).unwrap(); - - // Set the filterable fields to be the channel. - let mut wtxn = index.write_txn().unwrap(); - let mut builder = Settings::new(&mut wtxn, &index, 0); - builder.set_searchable_fields(vec![S("_geo"), S("price")]); // to keep the fields order - builder.set_filterable_fields(hashset! { S("_geo"), S("price") }); - builder.execute(|_, _| ()).unwrap(); - wtxn.commit().unwrap(); - - let rtxn = index.read_txn().unwrap(); - - // georadius don't have any parameters - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); - - // georadius don't have any parameters - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius()"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); - - // georadius don't have enough parameters - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(1, 2)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); - - // georadius have too many parameters - let result = - FilterCondition::from_str(&rtxn, &index, "_geoRadius(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error.to_string().contains("The `_geoRadius` filter expect three arguments: `_geoRadius(latitude, longitude, radius)`")); - - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-100, 150, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("Latitude must be contained between -90 and 90 degrees.")); - - // georadius have a bad latitude - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-90.0000001, 150, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("Latitude must be contained between -90 and 90 degrees.")); - - // georadius have a bad longitude - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-10, 250, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("Longitude must be contained between -180 and 180 degrees.")); - - // georadius have a bad longitude - let result = FilterCondition::from_str(&rtxn, &index, "_geoRadius(-10, 180.000001, 10)"); - assert!(result.is_err()); - let error = result.unwrap_err(); - assert!(error - .to_string() - .contains("Longitude must be contained between -180 and 180 degrees.")); - } -} diff --git a/milli/src/search/facet/grammar.pest b/milli/src/search/facet/grammar.pest new file mode 100644 index 000000000..8bfdeb667 --- /dev/null +++ b/milli/src/search/facet/grammar.pest @@ -0,0 +1,33 @@ +key = _{reserved | quoted | word } +value = _{quoted | word } +quoted = _{ (PUSH("'") | PUSH("\"")) ~ string ~ POP } +string = {char*} +word = ${(LETTER | NUMBER | "_" | "-" | ".")+} + +char = _{ !(PEEK | "\\") ~ ANY + | "\\" ~ (PEEK | "\\" | "/" | "b" | "f" | "n" | "r" | "t") + | "\\" ~ ("u" ~ ASCII_HEX_DIGIT{4})} + +reserved = { "_geoDistance" | ("_geoPoint" ~ parameters) | "_geo" } +// we deliberately choose to allow empty parameters to generate more specific error message later +parameters = {("(" ~ (value ~ ",")* ~ value? ~ ")") | ""} +condition = _{between | eq | greater | less | geq | leq | neq} +between = {key ~ value ~ "TO" ~ value} +geq = {key ~ ">=" ~ value} +leq = {key ~ "<=" ~ value} +neq = {key ~ "!=" ~ value} +eq = {key ~ "=" ~ value} +greater = {key ~ ">" ~ value} +less = {key ~ "<" ~ value} +geo_radius = {"_geoRadius" ~ parameters } + +prgm = {SOI ~ expr ~ EOI} +expr = _{ ( term ~ (operation ~ term)* ) } +term = { ("(" ~ expr ~ ")") | condition | not | geo_radius } +operation = _{ and | or } +and = {"AND"} +or = {"OR"} + +not = {"NOT" ~ term} + +WHITESPACE = _{ " " } diff --git a/milli/src/search/facet/mod.rs b/milli/src/search/facet/mod.rs index 3efa0262f..ddf710e32 100644 --- a/milli/src/search/facet/mod.rs +++ b/milli/src/search/facet/mod.rs @@ -1,10 +1,11 @@ pub use self::facet_distribution::FacetDistribution; pub use self::facet_number::{FacetNumberIter, FacetNumberRange, FacetNumberRevRange}; pub use self::facet_string::FacetStringIter; -pub use self::filter_condition::FilterCondition; +pub use self::filter_condition::{FilterCondition, Operator}; +pub(crate) use self::parser::Rule as ParserRule; mod facet_distribution; mod facet_number; mod facet_string; mod filter_condition; -mod filter_parser; +mod parser; diff --git a/milli/src/search/facet/parser.rs b/milli/src/search/facet/parser.rs new file mode 100644 index 000000000..1bff27cfb --- /dev/null +++ b/milli/src/search/facet/parser.rs @@ -0,0 +1,12 @@ +use once_cell::sync::Lazy; +use pest::prec_climber::{Assoc, Operator, PrecClimber}; + +pub static PREC_CLIMBER: Lazy> = Lazy::new(|| { + use Assoc::*; + use Rule::*; + pest::prec_climber::PrecClimber::new(vec![Operator::new(or, Left), Operator::new(and, Left)]) +}); + +#[derive(Parser)] +#[grammar = "search/facet/grammar.pest"] +pub struct FilterParser; diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 9b76ca851..bec059d46 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -14,7 +14,8 @@ use meilisearch_tokenizer::{Analyzer, AnalyzerConfig}; use once_cell::sync::Lazy; use roaring::bitmap::RoaringBitmap; -pub use self::facet::{FacetDistribution, FacetNumberIter, FilterCondition}; +pub(crate) use self::facet::ParserRule; +pub use self::facet::{FacetDistribution, FacetNumberIter, FilterCondition, Operator}; pub use self::matching_words::MatchingWords; use self::query_tree::QueryTreeBuilder; use crate::error::UserError;