Implement the _geoPoint in the sortable

This commit is contained in:
Tamo 2021-08-30 18:22:52 +02:00
parent 5bb175fc90
commit 13c78e5aa2
No known key found for this signature in database
GPG Key ID: 20CD8020AFA88D69
4 changed files with 138 additions and 39 deletions

View File

@ -58,24 +58,84 @@ impl FromStr for Criterion {
Err(error) => { Err(error) => {
Err(UserError::InvalidCriterionName { name: error.to_string() }.into()) Err(UserError::InvalidCriterionName { name: error.to_string() }.into())
} }
Ok(AscDesc::Asc(Member::Geo(_))) | Ok(AscDesc::Desc(Member::Geo(_))) => {
Err(UserError::AttributeLimitReached)? // TODO: TAMO: use a real error
}
Err(error) => Err(error.into()),
}, },
} }
} }
} }
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Eq)] #[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum Member {
Field(String),
Geo([f64; 2]),
}
impl FromStr for Member {
type Err = UserError;
fn from_str(text: &str) -> Result<Member, Self::Err> {
if text.starts_with("_geoPoint(") {
let point =
text.strip_prefix("_geoPoint(")
.and_then(|point| point.strip_suffix(")"))
.ok_or_else(|| UserError::InvalidCriterionName { name: text.to_string() })?;
let point = point
.split(',')
.map(|el| el.parse())
.collect::<Result<Vec<f64>, _>>()
.map_err(|_| UserError::InvalidCriterionName { name: text.to_string() })?;
Ok(Member::Geo([point[0], point[1]]))
} else {
Ok(Member::Field(text.to_string()))
}
}
}
impl fmt::Display for Member {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Member::Field(name) => write!(f, "{}", name),
Member::Geo([lat, lng]) => write!(f, "_geoPoint({}, {})", lat, lng),
}
}
}
impl Member {
pub fn field(&self) -> Option<&str> {
match self {
Member::Field(field) => Some(field),
Member::Geo(_) => None,
}
}
pub fn geo_point(&self) -> Option<&[f64; 2]> {
match self {
Member::Geo(point) => Some(point),
Member::Field(_) => None,
}
}
}
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
pub enum AscDesc { pub enum AscDesc {
Asc(String), Asc(Member),
Desc(String), Desc(Member),
} }
impl AscDesc { impl AscDesc {
pub fn field(&self) -> &str { pub fn member(&self) -> &Member {
match self { match self {
AscDesc::Asc(field) => field, AscDesc::Asc(member) => member,
AscDesc::Desc(field) => field, AscDesc::Desc(member) => member,
} }
} }
pub fn field(&self) -> Option<&str> {
self.member().field()
}
} }
impl FromStr for AscDesc { impl FromStr for AscDesc {
@ -85,9 +145,9 @@ impl FromStr for AscDesc {
/// string and let the caller create his own error /// string and let the caller create his own error
fn from_str(text: &str) -> Result<AscDesc, Self::Err> { fn from_str(text: &str) -> Result<AscDesc, Self::Err> {
match text.rsplit_once(':') { match text.rsplit_once(':') {
Some((field_name, "asc")) => Ok(AscDesc::Asc(field_name.to_string())), Some((left, "asc")) => Ok(AscDesc::Asc(left.parse()?)),
Some((field_name, "desc")) => Ok(AscDesc::Desc(field_name.to_string())), Some((left, "desc")) => Ok(AscDesc::Desc(left.parse()?)),
_ => Err(UserError::InvalidAscDescSyntax { name: text.to_string() }), _ => Err(UserError::InvalidCriterionName { name: text.to_string() }),
} }
} }
} }

View File

@ -4,12 +4,14 @@ use itertools::Itertools;
use log::debug; use log::debug;
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use rstar::RTree;
use super::{Criterion, CriterionParameters, CriterionResult}; use super::{Criterion, CriterionParameters, CriterionResult};
use crate::criterion::Member;
use crate::search::criteria::{resolve_query_tree, CriteriaBuilder}; use crate::search::criteria::{resolve_query_tree, CriteriaBuilder};
use crate::search::facet::{FacetNumberIter, FacetStringIter}; use crate::search::facet::{FacetNumberIter, FacetStringIter};
use crate::search::query_tree::Operation; use crate::search::query_tree::Operation;
use crate::{FieldId, Index, Result}; use crate::{FieldId, GeoPoint, Index, Result};
/// Threshold on the number of candidates that will make /// Threshold on the number of candidates that will make
/// the system to choose between one algorithm or another. /// the system to choose between one algorithm or another.
@ -18,10 +20,11 @@ const CANDIDATES_THRESHOLD: u64 = 1000;
pub struct AscDesc<'t> { pub struct AscDesc<'t> {
index: &'t Index, index: &'t Index,
rtxn: &'t heed::RoTxn<'t>, rtxn: &'t heed::RoTxn<'t>,
field_name: String, member: Member,
field_id: Option<FieldId>, field_id: Option<FieldId>,
is_ascending: bool, is_ascending: bool,
query_tree: Option<Operation>, query_tree: Option<Operation>,
rtree: Option<RTree<GeoPoint>>,
candidates: Box<dyn Iterator<Item = heed::Result<RoaringBitmap>> + 't>, candidates: Box<dyn Iterator<Item = heed::Result<RoaringBitmap>> + 't>,
allowed_candidates: RoaringBitmap, allowed_candidates: RoaringBitmap,
bucket_candidates: RoaringBitmap, bucket_candidates: RoaringBitmap,
@ -34,29 +37,29 @@ impl<'t> AscDesc<'t> {
index: &'t Index, index: &'t Index,
rtxn: &'t heed::RoTxn, rtxn: &'t heed::RoTxn,
parent: Box<dyn Criterion + 't>, parent: Box<dyn Criterion + 't>,
field_name: String, member: Member,
) -> Result<Self> { ) -> Result<Self> {
Self::new(index, rtxn, parent, field_name, true) Self::new(index, rtxn, parent, member, true)
} }
pub fn desc( pub fn desc(
index: &'t Index, index: &'t Index,
rtxn: &'t heed::RoTxn, rtxn: &'t heed::RoTxn,
parent: Box<dyn Criterion + 't>, parent: Box<dyn Criterion + 't>,
field_name: String, member: Member,
) -> Result<Self> { ) -> Result<Self> {
Self::new(index, rtxn, parent, field_name, false) Self::new(index, rtxn, parent, member, false)
} }
fn new( fn new(
index: &'t Index, index: &'t Index,
rtxn: &'t heed::RoTxn, rtxn: &'t heed::RoTxn,
parent: Box<dyn Criterion + 't>, parent: Box<dyn Criterion + 't>,
field_name: String, member: Member,
is_ascending: bool, is_ascending: bool,
) -> Result<Self> { ) -> Result<Self> {
let fields_ids_map = index.fields_ids_map(rtxn)?; let fields_ids_map = index.fields_ids_map(rtxn)?;
let field_id = fields_ids_map.id(&field_name); let field_id = member.field().and_then(|field| fields_ids_map.id(&field));
let faceted_candidates = match field_id { let faceted_candidates = match field_id {
Some(field_id) => { Some(field_id) => {
let number_faceted = index.number_faceted_documents_ids(rtxn, field_id)?; let number_faceted = index.number_faceted_documents_ids(rtxn, field_id)?;
@ -65,14 +68,16 @@ impl<'t> AscDesc<'t> {
} }
None => RoaringBitmap::default(), None => RoaringBitmap::default(),
}; };
let rtree = index.geo_rtree(rtxn)?;
Ok(AscDesc { Ok(AscDesc {
index, index,
rtxn, rtxn,
field_name, member,
field_id, field_id,
is_ascending, is_ascending,
query_tree: None, query_tree: None,
rtree,
candidates: Box::new(std::iter::empty()), candidates: Box::new(std::iter::empty()),
allowed_candidates: RoaringBitmap::new(), allowed_candidates: RoaringBitmap::new(),
faceted_candidates, faceted_candidates,
@ -92,7 +97,7 @@ impl<'t> Criterion for AscDesc<'t> {
debug!( debug!(
"Facet {}({}) iteration", "Facet {}({}) iteration",
if self.is_ascending { "Asc" } else { "Desc" }, if self.is_ascending { "Asc" } else { "Desc" },
self.field_name self.member
); );
match self.candidates.next().transpose()? { match self.candidates.next().transpose()? {
@ -135,15 +140,31 @@ impl<'t> Criterion for AscDesc<'t> {
} }
self.allowed_candidates = &candidates - params.excluded_candidates; self.allowed_candidates = &candidates - params.excluded_candidates;
self.candidates = match self.field_id {
Some(field_id) => facet_ordered( match &self.member {
self.index, Member::Field(field_name) => {
self.rtxn, self.candidates = match self.field_id {
field_id, Some(field_id) => facet_ordered(
self.is_ascending, self.index,
candidates & &self.faceted_candidates, self.rtxn,
)?, field_id,
None => Box::new(std::iter::empty()), self.is_ascending,
candidates & &self.faceted_candidates,
)?,
None => Box::new(std::iter::empty()),
}
}
Member::Geo(point) => {
self.candidates = match &self.rtree {
Some(rtree) => {
// TODO: TAMO how to remove that?
let rtree = Box::new(rtree.clone());
let rtree = Box::leak(rtree);
geo_point(rtree, candidates, point.clone())?
}
None => Box::new(std::iter::empty()),
}
}
}; };
} }
None => return Ok(None), None => return Ok(None),
@ -163,6 +184,22 @@ impl<'t> Criterion for AscDesc<'t> {
} }
} }
fn geo_point<'t>(
rtree: &'t RTree<GeoPoint>,
candidates: RoaringBitmap,
point: [f64; 2],
) -> Result<Box<dyn Iterator<Item = heed::Result<RoaringBitmap>> + 't>> {
Ok(Box::new(
rtree
.nearest_neighbor_iter_with_distance_2(&point)
.filter_map(move |(point, _distance)| {
candidates.contains(point.data).then(|| point.data)
})
.map(|id| std::iter::once(id).collect::<RoaringBitmap>())
.map(Ok),
))
}
/// Returns an iterator over groups of the given candidates in ascending or descending order. /// Returns an iterator over groups of the given candidates in ascending or descending order.
/// ///
/// It will either use an iterative or a recursive method on the whole facet database depending /// It will either use an iterative or a recursive method on the whole facet database depending

View File

@ -12,7 +12,7 @@ use self::r#final::Final;
use self::typo::Typo; use self::typo::Typo;
use self::words::Words; use self::words::Words;
use super::query_tree::{Operation, PrimitiveQueryPart, Query, QueryKind}; use super::query_tree::{Operation, PrimitiveQueryPart, Query, QueryKind};
use crate::criterion::AscDesc as AscDescName; use crate::criterion::{AscDesc as AscDescName, Member};
use crate::search::{word_derivations, WordDerivationsCache}; use crate::search::{word_derivations, WordDerivationsCache};
use crate::{DocumentId, FieldId, Index, Result, TreeLevel}; use crate::{DocumentId, FieldId, Index, Result, TreeLevel};
@ -294,13 +294,13 @@ impl<'t> CriteriaBuilder<'t> {
&self.index, &self.index,
&self.rtxn, &self.rtxn,
criterion, criterion,
field.to_string(), field.clone(),
)?), )?),
AscDescName::Desc(field) => Box::new(AscDesc::desc( AscDescName::Desc(field) => Box::new(AscDesc::desc(
&self.index, &self.index,
&self.rtxn, &self.rtxn,
criterion, criterion,
field.to_string(), field.clone(),
)?), )?),
}; };
} }
@ -312,10 +312,10 @@ impl<'t> CriteriaBuilder<'t> {
Name::Attribute => Box::new(Attribute::new(self, criterion)), Name::Attribute => Box::new(Attribute::new(self, criterion)),
Name::Exactness => Box::new(Exactness::new(self, criterion, &primitive_query)?), Name::Exactness => Box::new(Exactness::new(self, criterion, &primitive_query)?),
Name::Asc(field) => { Name::Asc(field) => {
Box::new(AscDesc::asc(&self.index, &self.rtxn, criterion, field)?) Box::new(AscDesc::asc(&self.index, &self.rtxn, criterion, Member::Field(field))?)
} }
Name::Desc(field) => { Name::Desc(field) => {
Box::new(AscDesc::desc(&self.index, &self.rtxn, criterion, field)?) Box::new(AscDesc::desc(&self.index, &self.rtxn, criterion, Member::Field(field))?)
} }
}; };
} }

View File

@ -148,13 +148,15 @@ impl<'a> Search<'a> {
if let Some(sort_criteria) = &self.sort_criteria { if let Some(sort_criteria) = &self.sort_criteria {
let sortable_fields = self.index.sortable_fields(self.rtxn)?; let sortable_fields = self.index.sortable_fields(self.rtxn)?;
for asc_desc in sort_criteria { for asc_desc in sort_criteria {
let field = asc_desc.field(); // we are not supposed to find any geoPoint in the criterion
if !sortable_fields.contains(field) { if let Some(field) = asc_desc.field() {
return Err(UserError::InvalidSortableAttribute { if !sortable_fields.contains(field) {
field: field.to_string(), return Err(UserError::InvalidSortableAttribute {
valid_fields: sortable_fields, field: field.to_string(),
valid_fields: sortable_fields,
}
.into());
} }
.into());
} }
} }
} }