diff --git a/meilisearch-core/src/error.rs b/meilisearch-core/src/error.rs index eca70843f..370147d59 100644 --- a/meilisearch-core/src/error.rs +++ b/meilisearch-core/src/error.rs @@ -33,6 +33,28 @@ impl From for Error { } } +impl From> for Error { + fn from(error: PestError) -> Error { + Error::FilterParseError(error.renamed_rules(|r| { + let s = match r { + Rule::or => "OR", + Rule::and => "AND", + Rule::not => "NOT", + Rule::string => "string", + Rule::word => "word", + Rule::greater => "field>value", + Rule::less => "field "field:value", + Rule::leq => "field<=value", + Rule::geq => "field>=value", + Rule::key => "key", + _ => "other", + }; + s.to_string() + })) + } +} + impl From for Error { fn from(error: meilisearch_schema::Error) -> Error { Error::Schema(error) diff --git a/meilisearch-core/src/filters/condition.rs b/meilisearch-core/src/filters/condition.rs new file mode 100644 index 000000000..fe364ebcb --- /dev/null +++ b/meilisearch-core/src/filters/condition.rs @@ -0,0 +1,200 @@ +use std::str::FromStr; +use std::cmp::Ordering; + +use crate::error::Error; +use crate::{store::Index, DocumentId, MainT}; +use heed::RoTxn; +use meilisearch_schema::{FieldId, Schema}; +use pest::error::{Error as PestError, ErrorVariant}; +use pest::iterators::Pair; +use serde_json::{Value, Number}; +use super::parser::Rule; + +#[derive(Debug)] +enum ConditionType { + Greater, + Less, + Equal, + LessEqual, + GreaterEqual, + NotEqual, +} + +/// We need to infer type when the filter is constructed +/// and match every possible types it can be parsed into. +#[derive(Debug)] +struct ConditionValue<'a> { + string: &'a str, + boolean: Option, + number: Option +} + +impl<'a> ConditionValue<'a> { + pub fn new(value: &Pair<'a, Rule>) -> Self { + let value = match value.as_rule() { + Rule::string | Rule::word => { + let string = value.as_str(); + let boolean = match value.as_str() { + "true" => Some(true), + "false" => Some(false), + _ => None, + }; + let number = Number::from_str(value.as_str()).ok(); + ConditionValue { string, boolean, number } + }, + _ => unreachable!(), + }; + value + } + + pub fn as_str(&self) -> &str { + self.string.as_ref() + } + + pub fn as_number(&self) -> Option<&Number> { + self.number.as_ref() + } + + pub fn as_bool(&self) -> Option { + self.boolean + } +} + +#[derive(Debug)] +pub struct Condition<'a> { + field: FieldId, + condition: ConditionType, + value: ConditionValue<'a> +} + +fn get_field_value<'a>(schema: &Schema, pair: Pair<'a, Rule>) -> Result<(FieldId, ConditionValue<'a>), Error> { + let mut items = pair.into_inner(); + // lexing ensures that we at least have a key + let key = items.next().unwrap(); + let field = schema + .id(key.as_str()) + .ok_or::>(PestError::new_from_span( + ErrorVariant::CustomError { + message: format!( + "attribute `{}` not found, available attributes are: {}", + key.as_str(), + schema.names().collect::>().join(", ") + ), + }, + key.as_span()))?; + let value = ConditionValue::new(&items.next().unwrap()); + Ok((field, value)) +} + +// undefined behavior with big numbers +fn compare_numbers(lhs: &Number, rhs: &Number) -> Option { + match (lhs.as_i64(), lhs.as_u64(), lhs.as_f64(), + rhs.as_i64(), rhs.as_u64(), rhs.as_f64()) { + // i64 u64 f64 i64 u64 f64 + (Some(lhs), _, _, Some(rhs), _, _) => lhs.partial_cmp(&rhs), + (_, Some(lhs), _, _, Some(rhs), _) => lhs.partial_cmp(&rhs), + (_, _, Some(lhs), _, _, Some(rhs)) => lhs.partial_cmp(&rhs), + (_, _, _, _, _, _) => None, + } +} + +impl<'a> Condition<'a> { + pub fn less( + item: Pair<'a, Rule>, + schema: &'a Schema, + ) -> Result { + let (field, value) = get_field_value(schema, item)?; + let condition = ConditionType::Less; + Ok(Self { field, condition, value }) + } + + pub fn greater( + item: Pair<'a, Rule>, + schema: &'a Schema, + ) -> Result { + let (field, value) = get_field_value(schema, item)?; + let condition = ConditionType::Greater; + Ok(Self { field, condition, value }) + } + + pub fn neq( + item: Pair<'a, Rule>, + schema: &'a Schema, + ) -> Result { + let (field, value) = get_field_value(schema, item)?; + let condition = ConditionType::NotEqual; + Ok(Self { field, condition, value }) + } + + pub fn geq( + item: Pair<'a, Rule>, + schema: &'a Schema, + ) -> Result { + let (field, value) = get_field_value(schema, item)?; + let condition = ConditionType::GreaterEqual; + Ok(Self { field, condition, value }) + } + + pub fn leq( + item: Pair<'a, Rule>, + schema: &'a Schema, + ) -> Result { + let (field, value) = get_field_value(schema, item)?; + let condition = ConditionType::LessEqual; + Ok(Self { field, condition, value }) + } + + pub fn eq( + item: Pair<'a, Rule>, + schema: &'a Schema, + ) -> Result { + let (field, value) = get_field_value(schema, item)?; + let condition = ConditionType::Equal; + Ok(Self { field, condition, value }) + } + + pub fn test( + &self, + reader: &RoTxn, + index: &Index, + document_id: DocumentId, + ) -> Result { + match index.document_attribute::(reader, document_id, self.field)? { + Some(Value::String(s)) => { + let value = self.value.as_str(); + match self.condition { + ConditionType::Equal => Ok(unicase::eq(value, &s)), + ConditionType::NotEqual => Ok(!unicase::eq(value, &s)), + _ => Ok(false) + } + }, + Some(Value::Number(n)) => { + if let Some(value) = self.value.as_number() { + if let Some(ord) = compare_numbers(&n, value) { + let res = match self.condition { + ConditionType::Equal => ord == Ordering::Equal, + ConditionType::NotEqual => ord != Ordering::Equal, + ConditionType::GreaterEqual => ord != Ordering::Less, + ConditionType::LessEqual => ord != Ordering::Greater, + ConditionType::Greater => ord == Ordering::Greater, + ConditionType::Less => ord == Ordering::Less, + }; + return Ok(res) + } + } + Ok(false) + }, + Some(Value::Bool(b)) => { + if let Some(value) = self.value.as_bool() { + return match self.condition { + ConditionType::Equal => Ok(b == value), + ConditionType::NotEqual => Ok(b != value), + _ => Ok(false) + } + } + Ok(false) + }, + _ => Ok(false), + } + } +} diff --git a/meilisearch-core/src/filters/mod.rs b/meilisearch-core/src/filters/mod.rs new file mode 100644 index 000000000..3ee3b6497 --- /dev/null +++ b/meilisearch-core/src/filters/mod.rs @@ -0,0 +1,125 @@ +mod parser; +mod condition; + +pub(crate) use parser::Rule; + +use std::ops::Not; + +use condition::Condition; +use crate::error::Error; +use crate::{DocumentId, MainT, store::Index}; +use heed::RoTxn; +use meilisearch_schema::Schema; +use parser::{PREC_CLIMBER, FilterParser}; +use pest::iterators::{Pair, Pairs}; +use pest::Parser; + +type FilterResult<'a> = Result, Error>; + +#[derive(Debug)] +pub enum Filter<'a> { + Condition(Condition<'a>), + Or(Box, Box), + And(Box, Box), + Not(Box), +} + +impl<'a> Filter<'a> { + pub fn parse(expr: &'a str, schema: &'a Schema) -> FilterResult<'a> { + let mut lexed = FilterParser::parse(Rule::prgm, expr.as_ref())?; + Self::build(lexed.next().unwrap().into_inner(), schema) + } + + pub fn test( + &self, + reader: &RoTxn, + index: &Index, + document_id: DocumentId, + ) -> Result { + use Filter::*; + match self { + Condition(c) => c.test(reader, index, document_id), + Or(lhs, rhs) => Ok( + lhs.test(reader, index, document_id)? || rhs.test(reader, index, document_id)? + ), + And(lhs, rhs) => Ok( + lhs.test(reader, index, document_id)? && rhs.test(reader, index, document_id)? + ), + Not(op) => op.test(reader, index, document_id).map(bool::not), + } + } + + fn build(expression: Pairs<'a, Rule>, schema: &'a Schema) -> FilterResult<'a> { + PREC_CLIMBER.climb( + expression, + |pair: Pair| match pair.as_rule() { + Rule::eq => Ok(Filter::Condition(Condition::eq(pair, schema)?)), + Rule::greater => Ok(Filter::Condition(Condition::greater(pair, schema)?)), + Rule::less => Ok(Filter::Condition(Condition::less(pair, schema)?)), + Rule::neq => Ok(Filter::Condition(Condition::neq(pair, schema)?)), + Rule::geq => Ok(Filter::Condition(Condition::geq(pair, schema)?)), + Rule::leq => Ok(Filter::Condition(Condition::leq(pair, schema)?)), + Rule::prgm => Self::build(pair.into_inner(), schema), + Rule::not => Ok(Filter::Not(Box::new(Self::build( + pair.into_inner(), + schema, + )?))), + _ => unreachable!(), + }, + |lhs: FilterResult, op: Pair, rhs: FilterResult| match op.as_rule() { + Rule::or => Ok(Filter::Or(Box::new(lhs?), Box::new(rhs?))), + Rule::and => Ok(Filter::And(Box::new(lhs?), Box::new(rhs?))), + _ => unreachable!(), + }, + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn invalid_syntax() { + assert!(FilterParser::parse(Rule::prgm, "field : id").is_err()); + assert!(FilterParser::parse(Rule::prgm, "field=hello hello").is_err()); + assert!(FilterParser::parse(Rule::prgm, "field=hello OR OR").is_err()); + assert!(FilterParser::parse(Rule::prgm, "OR field:hello").is_err()); + assert!(FilterParser::parse(Rule::prgm, r#"field="hello world"#).is_err()); + assert!(FilterParser::parse(Rule::prgm, r#"field='hello world"#).is_err()); + assert!(FilterParser::parse(Rule::prgm, "NOT field=").is_err()); + assert!(FilterParser::parse(Rule::prgm, "N").is_err()); + assert!(FilterParser::parse(Rule::prgm, "(field=1").is_err()); + assert!(FilterParser::parse(Rule::prgm, "(field=1))").is_err()); + assert!(FilterParser::parse(Rule::prgm, "field=1ORfield=2").is_err()); + assert!(FilterParser::parse(Rule::prgm, "field=1 ( OR field=2)").is_err()); + assert!(FilterParser::parse(Rule::prgm, "hello world=1").is_err()); + assert!(FilterParser::parse(Rule::prgm, "").is_err()); + assert!(FilterParser::parse(Rule::prgm, r#"((((((hello=world)))))"#).is_err()); + } + + #[test] + fn valid_syntax() { + assert!(FilterParser::parse(Rule::prgm, "field = id").is_ok()); + assert!(FilterParser::parse(Rule::prgm, "field=id").is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field >= 10"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field <= 10"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field="hello world""#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field='hello world'"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field > 10"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field < 10"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field < 10 AND NOT field=5"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field=true OR NOT field=5"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"NOT field=true OR NOT field=5"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field='hello world' OR ( NOT field=true OR NOT field=5 )"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field='hello \'worl\'d' OR ( NOT field=true OR NOT field=5 )"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"field="hello \"worl\"d" OR ( NOT field=true OR NOT field=5 )"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"((((((hello=world))))))"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#""foo bar" > 10"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#""foo bar" = 10"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"'foo bar' = 10"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"'foo bar' <= 10"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"'foo bar' != 10"#).is_ok()); + assert!(FilterParser::parse(Rule::prgm, r#"bar != 10"#).is_ok()); + } +}