diff --git a/filter-parser/src/lib.rs b/filter-parser/src/lib.rs index ed36b1bf4..0e49e00e9 100644 --- a/filter-parser/src/lib.rs +++ b/filter-parser/src/lib.rs @@ -113,6 +113,24 @@ pub enum FilterCondition<'a> { } impl<'a> FilterCondition<'a> { + /// Returns the first token found at the specified depth, `None` if no token at this depth. + pub fn token_at_depth(&self, depth: usize) -> Option<&Token> { + match self { + FilterCondition::Condition { fid, .. } if depth == 0 => Some(fid), + FilterCondition::Or(left, right) => { + let depth = depth.saturating_sub(1); + right.token_at_depth(depth).or_else(|| left.token_at_depth(depth)) + } + FilterCondition::And(left, right) => { + let depth = depth.saturating_sub(1); + right.token_at_depth(depth).or_else(|| left.token_at_depth(depth)) + } + FilterCondition::GeoLowerThan { point: [point, _], .. } if depth == 0 => Some(point), + FilterCondition::GeoGreaterThan { point: [point, _], .. } if depth == 0 => Some(point), + _ => None, + } + } + pub fn negate(self) -> FilterCondition<'a> { use FilterCondition::*; @@ -584,4 +602,10 @@ pub mod tests { assert!(filter.starts_with(expected), "Filter `{:?}` was supposed to return the following error:\n{}\n, but instead returned\n{}\n.", input, expected, filter); } } + + #[test] + fn depth() { + let filter = FilterCondition::parse("account_ids=1 OR account_ids=2 OR account_ids=3 OR account_ids=4 OR account_ids=5 OR account_ids=6").unwrap(); + assert!(filter.token_at_depth(5).is_some()); + } } diff --git a/milli/src/search/facet/filter.rs b/milli/src/search/facet/filter.rs index e994f36d9..9d9d16de5 100644 --- a/milli/src/search/facet/filter.rs +++ b/milli/src/search/facet/filter.rs @@ -15,6 +15,9 @@ use crate::heed_codec::facet::{ }; use crate::{distance_between_two_points, CboRoaringBitmapCodec, FieldId, Index, Result}; +/// The maximum number of filters the filter AST can process. +const MAX_FILTER_DEPTH: usize = 2000; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct Filter<'a> { condition: FilterCondition<'a>, @@ -27,6 +30,7 @@ enum FilterError<'a> { BadGeoLat(f64), BadGeoLng(f64), Reserved(&'a str), + TooDeep, InternalError, } impl<'a> std::error::Error for FilterError<'a> {} @@ -40,6 +44,10 @@ impl<'a> Display for FilterError<'a> { attribute, filterable, ), + Self::TooDeep => write!(f, + "Too many filter conditions, can't process more than {} filters.", + MAX_FILTER_DEPTH + ), Self::Reserved(keyword) => write!( f, "`{}` is a reserved keyword and thus can't be used as a filter expression.", @@ -108,6 +116,10 @@ impl<'a> Filter<'a> { } } + if let Some(token) = ands.as_ref().and_then(|fc| fc.token_at_depth(MAX_FILTER_DEPTH)) { + return Err(token.as_external_error(FilterError::TooDeep).into()); + } + Ok(ands.map(|ands| Self { condition: ands })) } @@ -116,6 +128,11 @@ impl<'a> Filter<'a> { Ok(fc) => Ok(fc), Err(e) => Err(Error::UserError(UserError::InvalidFilter(e.to_string()))), }?; + + if let Some(token) = condition.token_at_depth(MAX_FILTER_DEPTH) { + return Err(token.as_external_error(FilterError::TooDeep).into()); + } + Ok(Self { condition }) } } @@ -419,6 +436,8 @@ impl<'a> From> for Filter<'a> { #[cfg(test)] mod tests { + use std::fmt::Write; + use big_s::S; use either::Either; use heed::EnvOpenOptions; @@ -586,4 +605,37 @@ mod tests { "Bad longitude `180.000001`. Longitude must be contained between -180 and 180 degrees." )); } + + #[test] + fn filter_depth() { + let path = tempfile::tempdir().unwrap(); + let mut options = EnvOpenOptions::new(); + options.map_size(10 * 1024 * 1024); // 10 MB + let index = Index::new(options, &path).unwrap(); + + // Set the filterable fields to be the channel. + let mut wtxn = index.write_txn().unwrap(); + let mut builder = Settings::new(&mut wtxn, &index); + builder.set_searchable_fields(vec![S("account_ids")]); + builder.set_filterable_fields(hashset! { S("account_ids") }); + builder.execute(|_| ()).unwrap(); + wtxn.commit().unwrap(); + + // generates a big (2 MiB) filter with too much of ORs. + let tipic_filter = "account_ids=14361 OR "; + let mut filter_string = String::with_capacity(tipic_filter.len() * 14360); + for i in 1..=14361 { + let _ = write!(&mut filter_string, "account_ids={}", i); + if i != 14361 { + let _ = write!(&mut filter_string, " OR "); + } + } + + let error = Filter::from_str(&filter_string).unwrap_err(); + assert!( + error.to_string().starts_with("Too many filter conditions"), + "{}", + error.to_string() + ); + } }