Implement geo sort on documents

This commit is contained in:
Mubelotix 2025-06-30 13:57:30 +02:00
parent e35d58b531
commit f86f4f619f
No known key found for this signature in database
GPG key ID: 0406DF6C3A69B942
5 changed files with 152 additions and 48 deletions

View file

@ -66,7 +66,6 @@ impl GeoSortStrategy {
}
}
// TODO: Make it take a mut reference to cache
#[allow(clippy::too_many_arguments)]
pub fn fill_cache(
index: &Index,

View file

@ -1,4 +1,7 @@
use std::collections::VecDeque;
use crate::{
documents::{geo_sort::next_bucket, GeoSortParameter},
heed_codec::{
facet::{FacetGroupKeyCodec, FacetGroupValueCodec},
BytesRefCodec,
@ -12,38 +15,64 @@ use crate::{
use heed::Database;
use roaring::RoaringBitmap;
/// A sort criterion whose field names have already been resolved to field ids.
///
/// One value is produced per sort criterion by `recursive_facet_sort`, and the
/// iterator builder consumes them one level at a time.
#[derive(Debug, Clone, Copy)]
enum AscDescId {
/// Sort on the facet values of the field `field_id`.
Facet { field_id: u16, ascending: bool },
/// Sort on the distance to `target_point`.
/// `field_ids` holds the ids of `_geo.lat` and `_geo.lng`, in that order.
// NOTE(review): `target_point` is presumably `[lat, lng]` as well — confirm against `next_bucket`.
Geo { field_ids: [u16; 2], target_point: [f64; 2], ascending: bool },
}
/// Builder for a [`SortedDocumentsIterator`].
/// Most builders won't ever be built, because pagination will skip them.
///
/// A builder holds everything needed to sort `candidates` on the first entry of
/// `fields`; the builders it creates for its children receive `&fields[1..]`.
pub struct SortedDocumentsIteratorBuilder<'ctx> {
/// Index the documents belong to; forwarded to `next_bucket` for geo sorts.
index: &'ctx crate::Index,
/// Read transaction used for every database access made while sorting.
rtxn: &'ctx heed::RoTxn<'ctx>,
/// Facet database for numeric values (`facet_id_f64_docids` with a remapped key codec).
number_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
/// Facet database for string values (`facet_id_string_docids` with a remapped key codec).
string_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
// NOTE(review): `fields` is declared twice below with different types; this looks like
// the pre-change and post-change lines of a diff shown together — confirm which is current.
/// Sort criteria still to be applied, most significant first.
fields: &'ctx [(u16, bool)],
fields: &'ctx [AscDescId],
/// Documents to sort at this level.
candidates: RoaringBitmap,
/// Geo-faceted document ids, as loaded once by `recursive_facet_sort`.
geo_candidates: &'ctx RoaringBitmap,
}
impl<'ctx> SortedDocumentsIteratorBuilder<'ctx> {
/// Performs the sort and builds a [`SortedDocumentsIterator`].
///
/// Returns a `Leaf` when there is at most one candidate or no sort criterion
/// left; otherwise dispatches on the kind of the first criterion.
// NOTE(review): two `fn build` signatures and two variants of several statements appear
// below; this looks like pre-change and post-change diff lines shown together — confirm
// which lines are current before editing the logic.
fn build(self) -> heed::Result<SortedDocumentsIterator<'ctx>> {
let SortedDocumentsIteratorBuilder { rtxn, number_db, string_db, fields, candidates } =
self;
let size = candidates.len() as usize;
fn build(self) -> crate::Result<SortedDocumentsIterator<'ctx>> {
let size = self.candidates.len() as usize;
// There is no point sorting a 1-element array
if size <= 1 {
return Ok(SortedDocumentsIterator::Leaf {
size,
values: Box::new(candidates.into_iter()),
values: Box::new(self.candidates.into_iter()),
});
}
// There is no variable to sort on
let Some((field_id, ascending)) = fields.first().copied() else {
return Ok(SortedDocumentsIterator::Leaf {
// Dispatch on the first remaining criterion; with none left the candidates
// are returned as a leaf, in the bitmap's own iteration order.
match self.fields.first().copied() {
Some(AscDescId::Facet { field_id, ascending }) => self.build_facet(field_id, ascending),
Some(AscDescId::Geo { field_ids, target_point, ascending }) => {
self.build_geo(field_ids, target_point, ascending)
}
None => Ok(SortedDocumentsIterator::Leaf {
size,
values: Box::new(candidates.into_iter()),
});
};
values: Box::new(self.candidates.into_iter()),
}),
}
}
fn build_facet(
self,
field_id: u16,
ascending: bool,
) -> crate::Result<SortedDocumentsIterator<'ctx>> {
let SortedDocumentsIteratorBuilder {
index,
rtxn,
number_db,
string_db,
fields,
candidates,
geo_candidates,
} = self;
let size = candidates.len() as usize;
// Perform the sort on the first field
let (number_iter, string_iter) = if ascending {
@ -62,25 +91,29 @@ impl<'ctx> SortedDocumentsIteratorBuilder<'ctx> {
let number_db2 = number_db;
let string_db2 = string_db;
let number_iter =
number_iter.map(move |r| -> heed::Result<SortedDocumentsIteratorBuilder> {
number_iter.map(move |r| -> crate::Result<SortedDocumentsIteratorBuilder> {
let (docids, _bytes) = r?;
Ok(SortedDocumentsIteratorBuilder {
index,
rtxn,
number_db,
string_db,
fields: &fields[1..],
candidates: docids,
geo_candidates,
})
});
let string_iter =
string_iter.map(move |r| -> heed::Result<SortedDocumentsIteratorBuilder> {
string_iter.map(move |r| -> crate::Result<SortedDocumentsIteratorBuilder> {
let (docids, _bytes) = r?;
Ok(SortedDocumentsIteratorBuilder {
index,
rtxn,
number_db: number_db2,
string_db: string_db2,
fields: &fields[1..],
candidates: docids,
geo_candidates,
})
});
@ -90,6 +123,60 @@ impl<'ctx> SortedDocumentsIteratorBuilder<'ctx> {
next_children: Box::new(number_iter.chain(string_iter)),
})
}
/// Builds a [`SortedDocumentsIterator::Branch`] whose children are the buckets
/// returned by successive calls to `next_bucket` for a geo sort criterion.
///
/// - `field_ids`: ids of the two geo fields, forwarded to `next_bucket`.
/// - `target_point`: point the distance is measured from.
/// - `ascending`: sort direction, forwarded to `next_bucket`.
///
/// Each bucket becomes a child builder that sorts on the remaining criteria
/// (`&fields[1..]`). Nothing is fetched here: buckets are requested lazily,
/// one per call to the child iterator's `next`.
fn build_geo(
    self,
    field_ids: [u16; 2],
    target_point: [f64; 2],
    ascending: bool,
) -> crate::Result<SortedDocumentsIterator<'ctx>> {
    let SortedDocumentsIteratorBuilder {
        index,
        rtxn,
        number_db,
        string_db,
        fields: sort_fields,
        candidates: universe,
        geo_candidates,
    } = self;

    // TODO: confirm all candidates will be included
    let next_children_size = universe.len() as usize;

    // Mutable state that `next_bucket` carries from one call to the next.
    let mut rtree = None;
    let mut cached_points = VecDeque::new();

    let buckets = std::iter::from_fn(move || {
        // `transpose` turns `Ok(None)` into `None`, which `?` uses to end the
        // iteration; `Ok(Some(_))` and `Err(_)` both continue as `Some(_)`.
        let bucket = next_bucket(
            index,
            rtxn,
            &universe,
            ascending,
            target_point,
            &Some(field_ids),
            &mut rtree,
            &mut cached_points,
            geo_candidates,
            GeoSortParameter::default(),
        )
        .transpose()?;

        Some(bucket.map(move |(docids, _point)| SortedDocumentsIteratorBuilder {
            index,
            rtxn,
            number_db,
            string_db,
            fields: &sort_fields[1..],
            candidates: docids,
            geo_candidates,
        }))
    });

    Ok(SortedDocumentsIterator::Branch {
        current_child: None,
        next_children_size,
        next_children: Box::new(buckets),
    })
}
}
/// A [`SortedDocumentsIterator`] allows efficient access to a continuous range of sorted documents.
@ -108,7 +195,7 @@ pub enum SortedDocumentsIterator<'ctx> {
next_children_size: usize,
/// Iterators to become the current child once it is exhausted
next_children:
Box<dyn Iterator<Item = heed::Result<SortedDocumentsIteratorBuilder<'ctx>>> + 'ctx>,
Box<dyn Iterator<Item = crate::Result<SortedDocumentsIteratorBuilder<'ctx>>> + 'ctx>,
},
}
@ -118,9 +205,9 @@ impl SortedDocumentsIterator<'_> {
current_child: &mut Option<Box<SortedDocumentsIterator<'ctx>>>,
next_children_size: &mut usize,
next_children: &mut Box<
dyn Iterator<Item = heed::Result<SortedDocumentsIteratorBuilder<'ctx>>> + 'ctx,
dyn Iterator<Item = crate::Result<SortedDocumentsIteratorBuilder<'ctx>>> + 'ctx,
>,
) -> heed::Result<()> {
) -> crate::Result<()> {
if current_child.is_none() {
*current_child = match next_children.next() {
Some(Ok(builder)) => {
@ -137,7 +224,7 @@ impl SortedDocumentsIterator<'_> {
}
impl Iterator for SortedDocumentsIterator<'_> {
type Item = heed::Result<DocumentId>;
type Item = crate::Result<DocumentId>;
fn nth(&mut self, n: usize) -> Option<Self::Item> {
// If it's at the leaf level, just forward the call to the values iterator
@ -241,21 +328,25 @@ impl Iterator for SortedDocumentsIterator<'_> {
/// A structure owning the data needed during the lifetime of a [`SortedDocumentsIterator`].
///
/// Created by `recursive_facet_sort`; call `iter` to obtain the iterator.
pub struct SortedDocuments<'ctx> {
/// Index the documents belong to.
index: &'ctx crate::Index,
/// Read transaction shared by every iterator created from this value.
rtxn: &'ctx heed::RoTxn<'ctx>,
// NOTE(review): `fields` is declared twice below with different element types; this looks
// like pre-change and post-change diff lines shown together — confirm which is current.
/// Resolved sort criteria, most significant first.
fields: Vec<(u16, bool)>,
fields: Vec<AscDescId>,
/// Facet database for numeric values.
number_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
/// Facet database for string values.
string_db: Database<FacetGroupKeyCodec<BytesRefCodec>, FacetGroupValueCodec>,
/// Documents to sort.
candidates: &'ctx RoaringBitmap,
/// Geo-faceted document ids, owned here so iterators can borrow them.
geo_candidates: RoaringBitmap,
}
impl<'ctx> SortedDocuments<'ctx> {
/// Creates an iterator over the candidates in sorted order.
///
/// Builds the root [`SortedDocumentsIteratorBuilder`] from the owned data and
/// immediately builds it. The candidates bitmap is cloned, so `self` is left
/// untouched.
// NOTE(review): two signatures appear below (`heed::Result` and `crate::Result`); this looks
// like pre-change and post-change diff lines shown together — confirm which is current.
pub fn iter(&'ctx self) -> heed::Result<SortedDocumentsIterator<'ctx>> {
pub fn iter(&'ctx self) -> crate::Result<SortedDocumentsIterator<'ctx>> {
let builder = SortedDocumentsIteratorBuilder {
index: self.index,
rtxn: self.rtxn,
number_db: self.number_db,
string_db: self.string_db,
fields: &self.fields,
candidates: self.candidates.clone(),
geo_candidates: &self.geo_candidates,
};
builder.build()
}
@ -264,28 +355,55 @@ impl<'ctx> SortedDocuments<'ctx> {
/// Prepares a sort of `candidates` according to the `sort` criteria.
///
/// The criteria are first validated with `check_sort_criteria`, then each one is
/// resolved to field ids through the index's fields-ids map. The returned
/// [`SortedDocuments`] owns the resolved criteria; call its `iter` method to get
/// the documents in order.
///
/// # Errors
///
/// Propagates errors from `check_sort_criteria`, from reading the fields-ids map
/// and from reading the geo-faceted document ids.
// NOTE(review): several lines below appear in two variants (two `sort` parameters, two
// `check_sort_criteria` calls, two `match sort` heads, two final `Ok(...)`); this looks like
// pre-change and post-change diff lines shown together — confirm which lines are current.
pub fn recursive_facet_sort<'ctx>(
index: &'ctx crate::Index,
rtxn: &'ctx heed::RoTxn<'ctx>,
sort: &[AscDesc],
sort: Vec<AscDesc>,
candidates: &'ctx RoaringBitmap,
) -> crate::Result<SortedDocuments<'ctx>> {
check_sort_criteria(index, rtxn, Some(sort))?;
check_sort_criteria(index, rtxn, Some(&sort))?;
let mut fields = Vec::new();
let fields_ids_map = index.fields_ids_map(rtxn)?;
// Loaded unconditionally, even when no criterion is a geo sort (see TODO).
let geo_candidates = index.geo_faceted_documents_ids(rtxn)?; // TODO: skip when no geo sort
for sort in sort {
let (field_id, ascending) = match sort {
AscDesc::Asc(Member::Field(field)) => (fields_ids_map.id(field), true),
AscDesc::Desc(Member::Field(field)) => (fields_ids_map.id(field), false),
AscDesc::Asc(Member::Geo(_)) => todo!(),
AscDesc::Desc(Member::Geo(_)) => todo!(),
// Resolve the criterion to field ids. A criterion whose field is absent
// from the fields-ids map is dropped without an error.
match sort {
AscDesc::Asc(Member::Field(field)) => {
if let Some(field_id) = fields_ids_map.id(&field) {
fields.push(AscDescId::Facet { field_id, ascending: true });
}
}
AscDesc::Desc(Member::Field(field)) => {
if let Some(field_id) = fields_ids_map.id(&field) {
fields.push(AscDescId::Facet { field_id, ascending: false });
}
}
// Geo criteria need both `_geo.lat` and `_geo.lng` to be known fields.
AscDesc::Asc(Member::Geo(target_point)) => {
if let (Some(lat), Some(lng)) =
(fields_ids_map.id("_geo.lat"), fields_ids_map.id("_geo.lng"))
{
fields.push(AscDescId::Geo {
field_ids: [lat, lng],
target_point,
ascending: true,
});
}
}
AscDesc::Desc(Member::Geo(target_point)) => {
if let (Some(lat), Some(lng)) =
(fields_ids_map.id("_geo.lat"), fields_ids_map.id("_geo.lng"))
{
fields.push(AscDescId::Geo {
field_ids: [lat, lng],
target_point,
ascending: false,
});
}
}
};
if let Some(field_id) = field_id {
fields.push((field_id, ascending)); // FIXME: Should this return an error if the field is not found?
}
// FIXME: Should this return an error if the field is not found?
}
// Both facet databases are viewed through the same key codec.
let number_db = index.facet_id_f64_docids.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
let string_db =
index.facet_id_string_docids.remap_key_type::<FacetGroupKeyCodec<BytesRefCodec>>();
Ok(SortedDocuments { rtxn, fields, number_db, string_db, candidates })
Ok(SortedDocuments { index, rtxn, fields, number_db, string_db, candidates, geo_candidates })
}

View file

@ -10,7 +10,6 @@ pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FAC
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
use self::new::{execute_vector_search, PartialSearchResult, VectorStoreStats};
use crate::documents::GeoSortParameter;
use crate::documents::GeoSortStrategy;
use crate::filterable_attributes_rules::{filtered_matching_patterns, matching_features};
use crate::index::MatchingStrategy;
use crate::score_details::{ScoreDetails, ScoringStrategy};
@ -147,7 +146,7 @@ impl<'a> Search<'a> {
}
/// Overrides the geo sort strategy stored in `geo_param` and returns `self`
/// so that calls can be chained. Only compiled for tests.
// NOTE(review): two signatures appear below, differing only in the path used for
// `GeoSortStrategy`; this looks like pre-change and post-change diff lines shown
// together — confirm which is current.
#[cfg(test)]
pub fn geo_sort_strategy(&mut self, strategy: GeoSortStrategy) -> &mut Search<'a> {
pub fn geo_sort_strategy(&mut self, strategy: crate::GeoSortStrategy) -> &mut Search<'a> {
self.geo_param.strategy = strategy;
self
}

View file

@ -1,25 +1,13 @@
use std::collections::VecDeque;
use heed::types::{Bytes, Unit};
use heed::{RoPrefix, RoTxn};
use roaring::RoaringBitmap;
use rstar::RTree;
use super::facet_string_values;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::documents::geo_sort::{fill_cache, next_bucket};
use crate::documents::{GeoSortParameter, GeoSortStrategy};
use crate::heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec};
use crate::score_details::{self, ScoreDetails};
use crate::{GeoPoint, Index, Result, SearchContext, SearchLogger};
/// Size in bytes of a big-endian encoded field id (`u16`).
const FID_SIZE: usize = 2;
/// Size in bytes of a big-endian encoded document id (`u32`).
const DOCID_SIZE: usize = 4;

/// Builds the key prefix made of the field id followed by the document id,
/// both encoded as big-endian bytes.
///
/// Uses plain slice copies instead of an array-concatenation macro, so no
/// clippy lint has to be silenced.
fn facet_values_prefix_key(distinct: u16, id: u32) -> [u8; FID_SIZE + DOCID_SIZE] {
    let mut key = [0u8; FID_SIZE + DOCID_SIZE];
    let (fid_bytes, docid_bytes) = key.split_at_mut(FID_SIZE);
    // Both lengths are fixed by the integer types, so these copies cannot panic.
    fid_bytes.copy_from_slice(&distinct.to_be_bytes());
    docid_bytes.copy_from_slice(&id.to_be_bytes());
    key
}
use crate::{GeoPoint, Result, SearchContext, SearchLogger};
pub struct GeoSort<Q: RankingRuleQueryTrait> {
query: Option<Q>,