From 13c78e5aa2dc9413f02c2df3d28d6c700f88fb93 Mon Sep 17 00:00:00 2001 From: Tamo Date: Mon, 30 Aug 2021 18:22:52 +0200 Subject: [PATCH] Implement the _geoPoint in the sortable --- milli/src/criterion.rs | 78 +++++++++++++++++++++++---- milli/src/search/criteria/asc_desc.rs | 75 +++++++++++++++++++------- milli/src/search/criteria/mod.rs | 10 ++-- milli/src/search/mod.rs | 14 ++--- 4 files changed, 138 insertions(+), 39 deletions(-) diff --git a/milli/src/criterion.rs b/milli/src/criterion.rs index d91d4a7e1..2bca6948b 100644 --- a/milli/src/criterion.rs +++ b/milli/src/criterion.rs @@ -58,24 +58,84 @@ impl FromStr for Criterion { Err(error) => { Err(UserError::InvalidCriterionName { name: error.to_string() }.into()) } + Ok(AscDesc::Asc(Member::Geo(_))) | Ok(AscDesc::Desc(Member::Geo(_))) => { + Err(UserError::AttributeLimitReached)? // TODO: TAMO: use a real error + } + Err(error) => Err(error.into()), }, } } } -#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] +pub enum Member { + Field(String), + Geo([f64; 2]), +} + +impl FromStr for Member { + type Err = UserError; + + fn from_str(text: &str) -> Result { + if text.starts_with("_geoPoint(") { + let point = + text.strip_prefix("_geoPoint(") + .and_then(|point| point.strip_suffix(")")) + .ok_or_else(|| UserError::InvalidCriterionName { name: text.to_string() })?; + let point = point + .split(',') + .map(|el| el.parse()) + .collect::, _>>() + .map_err(|_| UserError::InvalidCriterionName { name: text.to_string() })?; + Ok(Member::Geo([point[0], point[1]])) + } else { + Ok(Member::Field(text.to_string())) + } + } +} + +impl fmt::Display for Member { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Member::Field(name) => write!(f, "{}", name), + Member::Geo([lat, lng]) => write!(f, "_geoPoint({}, {})", lat, lng), + } + } +} + +impl Member { + pub fn field(&self) -> Option<&str> { + match self { + Member::Field(field) => Some(field), + Member::Geo(_) => None, + } + } + + pub fn geo_point(&self) -> Option<&[f64; 2]> { + match self { + Member::Geo(point) => Some(point), + Member::Field(_) => None, + } + } +} + +#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)] pub enum AscDesc { - Asc(String), - Desc(String), + Asc(Member), + Desc(Member), } impl AscDesc { - pub fn field(&self) -> &str { + pub fn member(&self) -> &Member { match self { - AscDesc::Asc(field) => field, - AscDesc::Desc(field) => field, + AscDesc::Asc(member) => member, + AscDesc::Desc(member) => member, } } + + pub fn field(&self) -> Option<&str> { + self.member().field() + } } impl FromStr for AscDesc { @@ -85,9 +145,9 @@ impl FromStr for AscDesc { /// string and let the caller create his own error fn from_str(text: &str) -> Result { match text.rsplit_once(':') { - Some((field_name, "asc")) => Ok(AscDesc::Asc(field_name.to_string())), - Some((field_name, "desc")) => Ok(AscDesc::Desc(field_name.to_string())), - _ => Err(UserError::InvalidAscDescSyntax { name: text.to_string() }), + Some((left, "asc")) => Ok(AscDesc::Asc(left.parse()?)), + Some((left, "desc")) => Ok(AscDesc::Desc(left.parse()?)), + _ => Err(UserError::InvalidCriterionName { name: text.to_string() }), } } } diff --git a/milli/src/search/criteria/asc_desc.rs b/milli/src/search/criteria/asc_desc.rs index 6d50c1bb5..b0951f655 100644 --- a/milli/src/search/criteria/asc_desc.rs +++ b/milli/src/search/criteria/asc_desc.rs @@ -4,12 +4,14 @@ use itertools::Itertools; use log::debug; use ordered_float::OrderedFloat; use roaring::RoaringBitmap; +use rstar::RTree; use super::{Criterion, CriterionParameters, CriterionResult}; +use crate::criterion::Member; use crate::search::criteria::{resolve_query_tree, CriteriaBuilder}; use crate::search::facet::{FacetNumberIter, FacetStringIter}; use crate::search::query_tree::Operation; -use crate::{FieldId, Index, Result}; +use crate::{FieldId, GeoPoint, Index, Result}; /// Threshold on the number of candidates that will make /// the system to choose between one algorithm or another. @@ -18,10 +20,11 @@ const CANDIDATES_THRESHOLD: u64 = 1000; pub struct AscDesc<'t> { index: &'t Index, rtxn: &'t heed::RoTxn<'t>, - field_name: String, + member: Member, field_id: Option, is_ascending: bool, query_tree: Option, + rtree: Option>, candidates: Box> + 't>, allowed_candidates: RoaringBitmap, bucket_candidates: RoaringBitmap, @@ -34,29 +37,29 @@ impl<'t> AscDesc<'t> { index: &'t Index, rtxn: &'t heed::RoTxn, parent: Box, - field_name: String, + member: Member, ) -> Result { - Self::new(index, rtxn, parent, field_name, true) + Self::new(index, rtxn, parent, member, true) } pub fn desc( index: &'t Index, rtxn: &'t heed::RoTxn, parent: Box, - field_name: String, + member: Member, ) -> Result { - Self::new(index, rtxn, parent, field_name, false) + Self::new(index, rtxn, parent, member, false) } fn new( index: &'t Index, rtxn: &'t heed::RoTxn, parent: Box, - field_name: String, + member: Member, is_ascending: bool, ) -> Result { let fields_ids_map = index.fields_ids_map(rtxn)?; - let field_id = fields_ids_map.id(&field_name); + let field_id = member.field().and_then(|field| fields_ids_map.id(&field)); let faceted_candidates = match field_id { Some(field_id) => { let number_faceted = index.number_faceted_documents_ids(rtxn, field_id)?; @@ -65,14 +68,16 @@ impl<'t> AscDesc<'t> { } None => RoaringBitmap::default(), }; + let rtree = index.geo_rtree(rtxn)?; Ok(AscDesc { index, rtxn, - field_name, + member, field_id, is_ascending, query_tree: None, + rtree, candidates: Box::new(std::iter::empty()), allowed_candidates: RoaringBitmap::new(), faceted_candidates, @@ -92,7 +97,7 @@ impl<'t> Criterion for AscDesc<'t> { debug!( "Facet {}({}) iteration", if self.is_ascending { "Asc" } else { "Desc" }, - self.field_name + self.member ); match self.candidates.next().transpose()? { @@ -135,15 +140,31 @@ impl<'t> Criterion for AscDesc<'t> { } self.allowed_candidates = &candidates - params.excluded_candidates; - self.candidates = match self.field_id { - Some(field_id) => facet_ordered( - self.index, - self.rtxn, - field_id, - self.is_ascending, - candidates & &self.faceted_candidates, - )?, - None => Box::new(std::iter::empty()), + + match &self.member { + Member::Field(field_name) => { + self.candidates = match self.field_id { + Some(field_id) => facet_ordered( + self.index, + self.rtxn, + field_id, + self.is_ascending, + candidates & &self.faceted_candidates, + )?, + None => Box::new(std::iter::empty()), + } + } + Member::Geo(point) => { + self.candidates = match &self.rtree { + Some(rtree) => { + // TODO: TAMO how to remove that? + let rtree = Box::new(rtree.clone()); + let rtree = Box::leak(rtree); + geo_point(rtree, candidates, point.clone())? + } + None => Box::new(std::iter::empty()), + } + } }; } None => return Ok(None), @@ -163,6 +184,22 @@ impl<'t> Criterion for AscDesc<'t> { } } +fn geo_point<'t>( + rtree: &'t RTree, + candidates: RoaringBitmap, + point: [f64; 2], +) -> Result> + 't>> { + Ok(Box::new( + rtree + .nearest_neighbor_iter_with_distance_2(&point) + .filter_map(move |(point, _distance)| { + candidates.contains(point.data).then(|| point.data) + }) + .map(|id| std::iter::once(id).collect::()) + .map(Ok), + )) +} + /// Returns an iterator over groups of the given candidates in ascending or descending order. /// /// It will either use an iterative or a recursive method on the whole facet database depending diff --git a/milli/src/search/criteria/mod.rs b/milli/src/search/criteria/mod.rs index 2a883de67..92c0d284a 100644 --- a/milli/src/search/criteria/mod.rs +++ b/milli/src/search/criteria/mod.rs @@ -12,7 +12,7 @@ use self::r#final::Final; use self::typo::Typo; use self::words::Words; use super::query_tree::{Operation, PrimitiveQueryPart, Query, QueryKind}; -use crate::criterion::AscDesc as AscDescName; +use crate::criterion::{AscDesc as AscDescName, Member}; use crate::search::{word_derivations, WordDerivationsCache}; use crate::{DocumentId, FieldId, Index, Result, TreeLevel}; @@ -294,13 +294,13 @@ impl<'t> CriteriaBuilder<'t> { &self.index, &self.rtxn, criterion, - field.to_string(), + field.clone(), )?), AscDescName::Desc(field) => Box::new(AscDesc::desc( &self.index, &self.rtxn, criterion, - field.to_string(), + field.clone(), )?), }; } @@ -312,10 +312,10 @@ impl<'t> CriteriaBuilder<'t> { Name::Attribute => Box::new(Attribute::new(self, criterion)), Name::Exactness => Box::new(Exactness::new(self, criterion, &primitive_query)?), Name::Asc(field) => { - Box::new(AscDesc::asc(&self.index, &self.rtxn, criterion, field)?) + Box::new(AscDesc::asc(&self.index, &self.rtxn, criterion, Member::Field(field))?) } Name::Desc(field) => { - Box::new(AscDesc::desc(&self.index, &self.rtxn, criterion, field)?) + Box::new(AscDesc::desc(&self.index, &self.rtxn, criterion, Member::Field(field))?) } }; } diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 207f46f8a..f752f5822 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -148,13 +148,15 @@ impl<'a> Search<'a> { if let Some(sort_criteria) = &self.sort_criteria { let sortable_fields = self.index.sortable_fields(self.rtxn)?; for asc_desc in sort_criteria { - let field = asc_desc.field(); - if !sortable_fields.contains(field) { - return Err(UserError::InvalidSortableAttribute { - field: field.to_string(), - valid_fields: sortable_fields, + // we are not supposed to find any geoPoint in the criterion + if let Some(field) = asc_desc.field() { + if !sortable_fields.contains(field) { + return Err(UserError::InvalidSortableAttribute { + field: field.to_string(), + valid_fields: sortable_fields, + } + .into()); } - .into()); } } }