diff --git a/filter-parser/src/error.rs b/filter-parser/src/error.rs index e28685c7a..8a628156a 100644 --- a/filter-parser/src/error.rs +++ b/filter-parser/src/error.rs @@ -65,6 +65,7 @@ pub enum ErrorKind<'a> { MalformedValue, InOpeningBracket, InClosingBracket, + NonFiniteFloat, InExpectedValue(ExpectedValueKind), ReservedKeyword(String), MissingClosingDelimiter(char), @@ -167,6 +168,9 @@ impl<'a> Display for Error<'a> { ErrorKind::InClosingBracket => { writeln!(f, "Expected matching `]` after the list of field names given to `IN[`")? } + ErrorKind::NonFiniteFloat => { + writeln!(f, "Non finite floats are not supported")? + } ErrorKind::InExpectedValue(ExpectedValueKind::ReservedKeyword) => { writeln!(f, "Expected only comma-separated field names inside `IN[..]` but instead found `{escaped_input}`, which is a keyword. To use `{escaped_input}` as a field name or a value, surround it by quotes.")? } diff --git a/filter-parser/src/lib.rs b/filter-parser/src/lib.rs index 8c1431d93..a9bd9b3d7 100644 --- a/filter-parser/src/lib.rs +++ b/filter-parser/src/lib.rs @@ -44,7 +44,6 @@ mod error; mod value; use std::fmt::Debug; -use std::str::FromStr; pub use condition::{parse_condition, parse_to, Condition}; use condition::{parse_exists, parse_not_exists}; @@ -100,12 +99,13 @@ impl<'a> Token<'a> { Error::new_from_external(self.span, error) } - pub fn parse(&self) -> Result - where - T: FromStr, - T::Err: std::error::Error, - { - self.span.parse().map_err(|e| self.as_external_error(e)) + pub fn parse_finite_float(&self) -> Result { + let value: f64 = self.span.parse().map_err(|e| self.as_external_error(e))?; + if value.is_finite() { + Ok(value) + } else { + Err(Error::new_from_kind(self.span, ErrorKind::NonFiniteFloat)) + } } } diff --git a/milli/src/search/facet/filter.rs b/milli/src/search/facet/filter.rs index 5da1ba7fd..9b87353b0 100644 --- a/milli/src/search/facet/filter.rs +++ b/milli/src/search/facet/filter.rs @@ -169,11 +169,19 @@ impl<'a> Filter<'a> { // field id and the level. let (left, right) = match operator { - Condition::GreaterThan(val) => (Excluded(val.parse()?), Included(f64::MAX)), - Condition::GreaterThanOrEqual(val) => (Included(val.parse()?), Included(f64::MAX)), - Condition::LowerThan(val) => (Included(f64::MIN), Excluded(val.parse()?)), - Condition::LowerThanOrEqual(val) => (Included(f64::MIN), Included(val.parse()?)), - Condition::Between { from, to } => (Included(from.parse()?), Included(to.parse()?)), + Condition::GreaterThan(val) => { + (Excluded(val.parse_finite_float()?), Included(f64::MAX)) + } + Condition::GreaterThanOrEqual(val) => { + (Included(val.parse_finite_float()?), Included(f64::MAX)) + } + Condition::LowerThan(val) => (Included(f64::MIN), Excluded(val.parse_finite_float()?)), + Condition::LowerThanOrEqual(val) => { + (Included(f64::MIN), Included(val.parse_finite_float()?)) + } + Condition::Between { from, to } => { + (Included(from.parse_finite_float()?), Included(to.parse_finite_float()?)) + } Condition::Exists => { let exist = index.exists_faceted_documents_ids(rtxn, field_id)?; return Ok(exist); @@ -190,7 +198,7 @@ impl<'a> Filter<'a> { )? .map(|v| v.bitmap) .unwrap_or_default(); - let number = val.parse::().ok(); + let number = val.parse_finite_float().ok(); let number_docids = match number { Some(n) => { let n = Included(n); @@ -389,7 +397,8 @@ impl<'a> Filter<'a> { } FilterCondition::GeoLowerThan { point, radius } => { if filterable_fields.contains("_geo") { - let base_point: [f64; 2] = [point[0].parse()?, point[1].parse()?]; + let base_point: [f64; 2] = + [point[0].parse_finite_float()?, point[1].parse_finite_float()?]; if !(-90.0..=90.0).contains(&base_point[0]) { return Err( point[0].as_external_error(FilterError::BadGeoLat(base_point[0])) @@ -400,7 +409,7 @@ impl<'a> Filter<'a> { point[1].as_external_error(FilterError::BadGeoLng(base_point[1])) )?; } - let radius = radius.parse()?; + let radius = radius.parse_finite_float()?; let rtree = match index.geo_rtree(rtxn)? { Some(rtree) => rtree, None => return Ok(RoaringBitmap::new()), @@ -689,4 +698,60 @@ mod tests { let option = Filter::from_str(" ").unwrap(); assert_eq!(option, None); } + + #[test] + fn non_finite_float() { + let index = TempIndex::new(); + + index + .update_settings(|settings| { + settings.set_searchable_fields(vec![S("price")]); // to keep the fields order + settings.set_filterable_fields(hashset! { S("price") }); + }) + .unwrap(); + index + .add_documents(documents!([ + { + "id": "test_1", + "price": "inf" + }, + { + "id": "test_2", + "price": "2000" + }, + { + "id": "test_3", + "price": "infinity" + }, + ])) + .unwrap(); + + let rtxn = index.read_txn().unwrap(); + let filter = Filter::from_str("price = inf").unwrap().unwrap(); + let result = filter.evaluate(&rtxn, &index).unwrap(); + assert!(result.contains(0)); + let filter = Filter::from_str("price < inf").unwrap().unwrap(); + assert!(matches!( + filter.evaluate(&rtxn, &index), + Err(crate::Error::UserError(crate::error::UserError::InvalidFilter(_))) + )); + + let filter = Filter::from_str("price = NaN").unwrap().unwrap(); + let result = filter.evaluate(&rtxn, &index).unwrap(); + assert!(result.is_empty()); + let filter = Filter::from_str("price < NaN").unwrap().unwrap(); + assert!(matches!( + filter.evaluate(&rtxn, &index), + Err(crate::Error::UserError(crate::error::UserError::InvalidFilter(_))) + )); + + let filter = Filter::from_str("price = infinity").unwrap().unwrap(); + let result = filter.evaluate(&rtxn, &index).unwrap(); + assert!(result.contains(2)); + let filter = Filter::from_str("price < infinity").unwrap().unwrap(); + assert!(matches!( + filter.evaluate(&rtxn, &index), + Err(crate::Error::UserError(crate::error::UserError::InvalidFilter(_))) + )); + } }