MeiliSearch/meilisearch-core/src/filters/condition.rs

277 lines
9.3 KiB
Rust

use std::str::FromStr;
use std::cmp::Ordering;
use crate::error::Error;
use crate::{store::Index, DocumentId, MainT};
use heed::RoTxn;
use meilisearch_schema::{FieldId, Schema};
use pest::error::{Error as PestError, ErrorVariant};
use pest::iterators::Pair;
use serde_json::{Value, Number};
use super::parser::Rule;
#[derive(Debug, PartialEq)]
enum ConditionType {
Greater,
Less,
Equal,
LessEqual,
GreaterEqual,
NotEqual,
}
/// We need to infer type when the filter is constructed
/// and match every possible types it can be parsed into.
#[derive(Debug)]
struct ConditionValue<'a> {
string: &'a str,
boolean: Option<bool>,
number: Option<Number>
}
impl<'a> ConditionValue<'a> {
pub fn new(value: &Pair<'a, Rule>) -> Self {
match value.as_rule() {
Rule::string | Rule::word => {
let string = value.as_str();
let boolean = match value.as_str() {
"true" => Some(true),
"false" => Some(false),
_ => None,
};
let number = Number::from_str(value.as_str()).ok();
ConditionValue { string, boolean, number }
},
_ => unreachable!(),
}
}
pub fn as_str(&self) -> &str {
self.string
}
pub fn as_number(&self) -> Option<&Number> {
self.number.as_ref()
}
pub fn as_bool(&self) -> Option<bool> {
self.boolean
}
}
#[derive(Debug)]
pub struct Condition<'a> {
field: FieldId,
condition: ConditionType,
value: ConditionValue<'a>
}
fn get_field_value<'a>(schema: &Schema, pair: Pair<'a, Rule>) -> Result<(FieldId, ConditionValue<'a>), Error> {
let mut items = pair.into_inner();
// lexing ensures that we at least have a key
let key = items.next().unwrap();
let field = schema
.id(key.as_str())
.ok_or_else(|| PestError::new_from_span(
ErrorVariant::CustomError {
message: format!(
"attribute `{}` not found, available attributes are: {}",
key.as_str(),
schema.names().collect::<Vec<_>>().join(", ")
),
},
key.as_span()))?;
let value = ConditionValue::new(&items.next().unwrap());
Ok((field, value))
}
// undefined behavior with big numbers
fn compare_numbers(lhs: &Number, rhs: &Number) -> Option<Ordering> {
match (lhs.as_i64(), lhs.as_u64(), lhs.as_f64(),
rhs.as_i64(), rhs.as_u64(), rhs.as_f64()) {
// i64 u64 f64 i64 u64 f64
(Some(lhs), _, _, Some(rhs), _, _) => lhs.partial_cmp(&rhs),
(_, Some(lhs), _, _, Some(rhs), _) => lhs.partial_cmp(&rhs),
(_, _, Some(lhs), _, _, Some(rhs)) => lhs.partial_cmp(&rhs),
(_, _, _, _, _, _) => None,
}
}
impl<'a> Condition<'a> {
pub fn less(
item: Pair<'a, Rule>,
schema: &'a Schema,
) -> Result<Self, Error> {
let (field, value) = get_field_value(schema, item)?;
let condition = ConditionType::Less;
Ok(Self { field, condition, value })
}
pub fn greater(
item: Pair<'a, Rule>,
schema: &'a Schema,
) -> Result<Self, Error> {
let (field, value) = get_field_value(schema, item)?;
let condition = ConditionType::Greater;
Ok(Self { field, condition, value })
}
pub fn neq(
item: Pair<'a, Rule>,
schema: &'a Schema,
) -> Result<Self, Error> {
let (field, value) = get_field_value(schema, item)?;
let condition = ConditionType::NotEqual;
Ok(Self { field, condition, value })
}
pub fn geq(
item: Pair<'a, Rule>,
schema: &'a Schema,
) -> Result<Self, Error> {
let (field, value) = get_field_value(schema, item)?;
let condition = ConditionType::GreaterEqual;
Ok(Self { field, condition, value })
}
pub fn leq(
item: Pair<'a, Rule>,
schema: &'a Schema,
) -> Result<Self, Error> {
let (field, value) = get_field_value(schema, item)?;
let condition = ConditionType::LessEqual;
Ok(Self { field, condition, value })
}
pub fn eq(
item: Pair<'a, Rule>,
schema: &'a Schema,
) -> Result<Self, Error> {
let (field, value) = get_field_value(schema, item)?;
let condition = ConditionType::Equal;
Ok(Self { field, condition, value })
}
pub fn test(
&self,
reader: &RoTxn<MainT>,
index: &Index,
document_id: DocumentId,
) -> Result<bool, Error> {
match index.document_attribute::<Value>(reader, document_id, self.field)? {
Some(Value::Array(values)) => Ok(values.iter().any(|v| self.match_value(Some(v)))),
other => Ok(self.match_value(other.as_ref())),
}
}
fn match_value(&self, value: Option<&Value>) -> bool {
match value {
Some(Value::String(s)) => {
let value = self.value.as_str();
match self.condition {
ConditionType::Equal => unicase::eq(value, &s),
ConditionType::NotEqual => !unicase::eq(value, &s),
_ => false
}
},
Some(Value::Number(n)) => {
if let Some(value) = self.value.as_number() {
if let Some(ord) = compare_numbers(&n, value) {
let res = match self.condition {
ConditionType::Equal => ord == Ordering::Equal,
ConditionType::NotEqual => ord != Ordering::Equal,
ConditionType::GreaterEqual => ord != Ordering::Less,
ConditionType::LessEqual => ord != Ordering::Greater,
ConditionType::Greater => ord == Ordering::Greater,
ConditionType::Less => ord == Ordering::Less,
};
return res
}
}
false
},
Some(Value::Bool(b)) => {
if let Some(value) = self.value.as_bool() {
let res = match self.condition {
ConditionType::Equal => *b == value,
ConditionType::NotEqual => *b != value,
_ => false
};
return res
}
false
},
// if field is not supported (or not found), all values are different from it,
// so != should always return true in this case.
_ => self.condition == ConditionType::NotEqual,
}
}
}
#[cfg(test)]
mod test {
use super::*;
use serde_json::Number;
use std::cmp::Ordering;
#[test]
fn test_number_comp() {
// test both u64
let n1 = Number::from(1u64);
let n2 = Number::from(2u64);
assert_eq!(Some(Ordering::Less), compare_numbers(&n1, &n2));
assert_eq!(Some(Ordering::Greater), compare_numbers(&n2, &n1));
let n1 = Number::from(1u64);
let n2 = Number::from(1u64);
assert_eq!(Some(Ordering::Equal), compare_numbers(&n1, &n2));
// test both i64
let n1 = Number::from(1i64);
let n2 = Number::from(2i64);
assert_eq!(Some(Ordering::Less), compare_numbers(&n1, &n2));
assert_eq!(Some(Ordering::Greater), compare_numbers(&n2, &n1));
let n1 = Number::from(1i64);
let n2 = Number::from(1i64);
assert_eq!(Some(Ordering::Equal), compare_numbers(&n1, &n2));
// test both f64
let n1 = Number::from_f64(1f64).unwrap();
let n2 = Number::from_f64(2f64).unwrap();
assert_eq!(Some(Ordering::Less), compare_numbers(&n1, &n2));
assert_eq!(Some(Ordering::Greater), compare_numbers(&n2, &n1));
let n1 = Number::from_f64(1f64).unwrap();
let n2 = Number::from_f64(1f64).unwrap();
assert_eq!(Some(Ordering::Equal), compare_numbers(&n1, &n2));
// test one u64 and one f64
let n1 = Number::from_f64(1f64).unwrap();
let n2 = Number::from(2u64);
assert_eq!(Some(Ordering::Less), compare_numbers(&n1, &n2));
assert_eq!(Some(Ordering::Greater), compare_numbers(&n2, &n1));
// equality
let n1 = Number::from_f64(1f64).unwrap();
let n2 = Number::from(1u64);
assert_eq!(Some(Ordering::Equal), compare_numbers(&n1, &n2));
assert_eq!(Some(Ordering::Equal), compare_numbers(&n2, &n1));
// float is neg
let n1 = Number::from_f64(-1f64).unwrap();
let n2 = Number::from(1u64);
assert_eq!(Some(Ordering::Less), compare_numbers(&n1, &n2));
assert_eq!(Some(Ordering::Greater), compare_numbers(&n2, &n1));
// float is too big
let n1 = Number::from_f64(std::f64::MAX).unwrap();
let n2 = Number::from(1u64);
assert_eq!(Some(Ordering::Greater), compare_numbers(&n1, &n2));
assert_eq!(Some(Ordering::Less), compare_numbers(&n2, &n1));
// misc
let n1 = Number::from_f64(std::f64::MAX).unwrap();
let n2 = Number::from(std::u64::MAX);
assert_eq!(Some(Ordering::Greater), compare_numbers(&n1, &n2));
assert_eq!(Some( Ordering::Less ), compare_numbers(&n2, &n1));
}
}