diff --git a/milli/src/criterion.rs b/milli/src/criterion.rs
index ea3214c8e..29c477473 100644
--- a/milli/src/criterion.rs
+++ b/milli/src/criterion.rs
@@ -3,7 +3,7 @@ use std::str::FromStr;
 
 use serde::{Deserialize, Serialize};
 
-use crate::error::{Error, UserError};
+use crate::error::{is_reserved_keyword, Error, UserError};
 
 #[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)]
 pub enum Criterion {
@@ -50,18 +50,20 @@
             "sort" => Ok(Criterion::Sort),
             "exactness" => Ok(Criterion::Exactness),
             text => match AscDesc::from_str(text) {
-                Ok(AscDesc::Asc(field)) => Ok(Criterion::Asc(field)),
-                Ok(AscDesc::Desc(field)) => Ok(Criterion::Desc(field)),
+                Ok(AscDesc::Asc(Member::Field(field))) if is_reserved_keyword(&field) => {
+                    Err(UserError::InvalidReservedRankingRuleName { name: text.to_string() })?
+                }
+                Ok(AscDesc::Asc(Member::Field(field))) => Ok(Criterion::Asc(field)),
+                Ok(AscDesc::Desc(Member::Field(field))) => Ok(Criterion::Desc(field)),
+                Ok(AscDesc::Asc(Member::Geo(_))) | Ok(AscDesc::Desc(Member::Geo(_))) => {
+                    Err(UserError::InvalidRankingRuleName { name: text.to_string() })?
+                }
                 Err(UserError::InvalidAscDescSyntax { name }) => {
                     Err(UserError::InvalidCriterionName { name }.into())
                 }
                 Err(error) => {
                     Err(UserError::InvalidCriterionName { name: error.to_string() }.into())
                 }
-                Ok(AscDesc::Asc(Member::Geo(_))) | Ok(AscDesc::Desc(Member::Geo(_))) => {
-                    Err(UserError::AttributeLimitReached)? // TODO: TAMO: use a real error
-                }
-                Err(error) => Err(error.into()),
             },
         }
     }
@@ -81,12 +83,12 @@ impl FromStr for Member {
             let point = text
                 .strip_prefix("_geoPoint(")
                 .and_then(|point| point.strip_suffix(")"))
-                .ok_or_else(|| UserError::InvalidCriterionName { name: text.to_string() })?;
+                .ok_or_else(|| UserError::InvalidRankingRuleName { name: text.to_string() })?;
             let point = point
                 .split(',')
                 .map(|el| el.trim().parse())
                 .collect::<Result<Vec<f64>, _>>()
-                .map_err(|_| UserError::InvalidCriterionName { name: text.to_string() })?;
+                .map_err(|_| UserError::InvalidRankingRuleName { name: text.to_string() })?;
             Ok(Member::Geo([point[0], point[1]]))
         } else {
             Ok(Member::Field(text.to_string()))
@@ -147,7 +149,7 @@ impl FromStr for AscDesc {
         match text.rsplit_once(':') {
             Some((left, "asc")) => Ok(AscDesc::Asc(left.parse()?)),
             Some((left, "desc")) => Ok(AscDesc::Desc(left.parse()?)),
-            _ => Err(UserError::InvalidCriterionName { name: text.to_string() }),
+            _ => Err(UserError::InvalidRankingRuleName { name: text.to_string() }),
         }
     }
 }
diff --git a/milli/src/error.rs b/milli/src/error.rs
index 3f473a673..f4601ea9a 100644
--- a/milli/src/error.rs
+++ b/milli/src/error.rs
@@ -12,6 +12,12 @@ use crate::{DocumentId, FieldId};
 
 pub type Object = Map<String, Value>;
 
+const RESERVED_KEYWORD: &[&'static str] = &["_geo", "_geoDistance"];
+
+pub fn is_reserved_keyword(keyword: &str) -> bool {
+    RESERVED_KEYWORD.contains(&keyword)
+}
+
 #[derive(Debug)]
 pub enum Error {
     InternalError(InternalError),
@@ -60,6 +66,9 @@ pub enum UserError {
     InvalidFilter(pest::error::Error<ParserRule>),
     InvalidFilterAttribute(pest::error::Error<ParserRule>),
     InvalidSortName { name: String },
+    InvalidGeoField { document_id: Value, object: Value },
+    InvalidRankingRuleName { name: String },
+    InvalidReservedRankingRuleName { name: String },
     InvalidSortableAttribute { field: String, valid_fields: HashSet<String> },
     SortRankingRuleMissing,
     InvalidStoreFile,
@@ -222,6 +231,15 @@
                 write!(f, "invalid asc/desc syntax for {}", name)
             }
             Self::InvalidCriterionName { name } => write!(f, "invalid criterion {}", name),
+            Self::InvalidGeoField { document_id, object } => write!(
+                f,
+                "the document with the id: {} contains an invalid _geo field: {}",
+                document_id, object
+            ),
+            Self::InvalidRankingRuleName { name } => write!(f, "invalid criterion {}", name),
+            Self::InvalidReservedRankingRuleName { name } => {
+                write!(f, "{} is a reserved keyword and thus can't be used as a ranking rule", name)
+            }
             Self::InvalidDocumentId { document_id } => {
                 let json = serde_json::to_string(document_id).unwrap();
                 write!(
diff --git a/milli/src/update/index_documents/extract/extract_geo_points.rs b/milli/src/update/index_documents/extract/extract_geo_points.rs
index 88ae7c177..c4bdce211 100644
--- a/milli/src/update/index_documents/extract/extract_geo_points.rs
+++ b/milli/src/update/index_documents/extract/extract_geo_points.rs
@@ -2,11 +2,10 @@ use std::fs::File;
 use std::io;
 
 use concat_arrays::concat_arrays;
-use log::warn;
 use serde_json::Value;
 
 use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
-use crate::{FieldId, InternalError, Result};
+use crate::{FieldId, InternalError, Result, UserError};
 
 /// Extracts the geographical coordinates contained in each document under the `_geo` field.
 ///
@@ -14,6 +13,7 @@ use crate::{FieldId, InternalError, Result};
 pub fn extract_geo_points<R: io::Read>(
     mut obkv_documents: grenad::Reader<R>,
     indexer: GrenadParameters,
+    primary_key_id: FieldId,
     geo_field_id: FieldId,
 ) -> Result<grenad::Reader<File>> {
     let mut writer = tempfile::tempfile().and_then(|file| {
@@ -33,9 +33,10 @@ pub fn extract_geo_points<R: io::Read>(
             let bytes: [u8; 16] = concat_arrays![lat.to_ne_bytes(), lng.to_ne_bytes()];
             writer.insert(docid_bytes, bytes)?;
         } else {
-            // TAMO: improve the warn
-            warn!("Malformed `_geo` field");
-            continue;
+            let primary_key = obkv.get(primary_key_id).unwrap(); // TODO: TAMO: is this valid?
+            let primary_key =
+                serde_json::from_slice(primary_key).map_err(InternalError::SerdeJson)?;
+            Err(UserError::InvalidGeoField { document_id: primary_key, object: point })?
         }
     }
 
diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs
index 4cb21c8e4..36e3c870f 100644
--- a/milli/src/update/index_documents/extract/mod.rs
+++ b/milli/src/update/index_documents/extract/mod.rs
@@ -39,6 +39,7 @@ pub(crate) fn data_from_obkv_documents(
     lmdb_writer_sx: Sender<Result<TypedChunk>>,
     searchable_fields: Option<HashSet<FieldId>>,
     faceted_fields: HashSet<FieldId>,
+    primary_key_id: FieldId,
     geo_field_id: Option<FieldId>,
     stop_words: Option<fst::Set<Vec<u8>>>,
 ) -> Result<()> {
@@ -51,6 +52,7 @@ pub(crate) fn data_from_obkv_documents(
                     lmdb_writer_sx.clone(),
                     &searchable_fields,
                     &faceted_fields,
+                    primary_key_id,
                     geo_field_id,
                     &stop_words,
                 )
@@ -172,6 +174,7 @@ fn extract_documents_data(
     lmdb_writer_sx: Sender<Result<TypedChunk>>,
     searchable_fields: &Option<HashSet<FieldId>>,
     faceted_fields: &HashSet<FieldId>,
+    primary_key_id: FieldId,
     geo_field_id: Option<FieldId>,
     stop_words: &Option<fst::Set<Vec<u8>>>,
 ) -> Result<(
@@ -186,7 +189,12 @@ fn extract_documents_data(
         let documents_chunk_cloned = documents_chunk.clone();
         let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
         rayon::spawn(move || {
-            let _ = match extract_geo_points(documents_chunk_cloned, indexer, geo_field_id) {
+            let _ = match extract_geo_points(
+                documents_chunk_cloned,
+                indexer,
+                primary_key_id,
+                geo_field_id,
+            ) {
                 Ok(geo_points) => lmdb_writer_sx_cloned.send(Ok(TypedChunk::GeoPoints(geo_points))),
                 Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
             };
diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs
index d4fd3570e..38eea954b 100644
--- a/milli/src/update/index_documents/mod.rs
+++ b/milli/src/update/index_documents/mod.rs
@@ -228,6 +228,9 @@ impl<'t, 'u, 'i, 'a> IndexDocuments<'t, 'u, 'i, 'a> {
             Receiver<Result<TypedChunk>>,
         ) = crossbeam_channel::unbounded();
 
+        // get the primary key field id
+        let primary_key_id = fields_ids_map.id(&primary_key).unwrap(); // TODO: TAMO: is this unwrap 100% valid?
+
         // get searchable fields for word databases
         let searchable_fields =
             self.index.searchable_fields_ids(self.wtxn)?.map(HashSet::from_iter);
@@ -269,6 +272,7 @@ impl<'t, 'u, 'i, 'a> IndexDocuments<'t, 'u, 'i, 'a> {
                     lmdb_writer_sx.clone(),
                     searchable_fields,
                     faceted_fields,
+                    primary_key_id,
                     geo_field_id,
                     stop_words,
                 )
diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs
index b09bee213..5c27c195f 100644
--- a/milli/src/update/index_documents/typed_chunk.rs
+++ b/milli/src/update/index_documents/typed_chunk.rs
@@ -180,7 +180,7 @@ pub(crate) fn write_typed_chunk_into_index(
             is_merged_database = true;
         }
         TypedChunk::GeoPoints(mut geo_points) => {
-            // TODO: TAMO: we should create the rtree with the `RTree::bulk_load` function
+            // TODO: we should create the rtree with the `RTree::bulk_load` function
             let mut rtree = index.geo_rtree(wtxn)?.unwrap_or_default();
             let mut doc_ids = index.geo_faceted_documents_ids(wtxn)?;
 