Move sorting code out of search

This commit is contained in:
Mubelotix 2025-06-30 11:59:59 +02:00
parent 340d9e6edc
commit 63827bbee0
No known key found for this signature in database
GPG key ID: 0406DF6C3A69B942
6 changed files with 227 additions and 178 deletions

View file

@ -0,0 +1,182 @@
use std::collections::VecDeque;
use heed::RoTxn;
use roaring::RoaringBitmap;
use rstar::RTree;
use crate::{
distance_between_two_points, lat_lng_to_xyz,
search::new::geo_sort::{geo_value, opposite_of},
GeoPoint, GeoSortStrategy, Index,
};
// TODO: Make it take a mut reference to cache
#[allow(clippy::too_many_arguments)]
pub fn fill_cache(
index: &Index,
txn: &RoTxn<heed::AnyTls>,
strategy: GeoSortStrategy,
ascending: bool,
target_point: [f64; 2],
field_ids: &Option<[u16; 2]>,
rtree: &mut Option<RTree<GeoPoint>>,
geo_candidates: &RoaringBitmap,
cached_sorted_docids: &mut VecDeque<(u32, [f64; 2])>,
) -> crate::Result<()> {
debug_assert!(cached_sorted_docids.is_empty());
// lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree`
let rtree = if strategy.use_rtree(geo_candidates.len() as usize) {
if let Some(rtree) = rtree.as_ref() {
// get rtree from cache
Some(rtree)
} else {
let rtree2 = index.geo_rtree(txn)?.expect("geo candidates but no rtree");
// insert rtree in cache and returns it.
// Can't use `get_or_insert_with` because getting the rtree from the DB is a fallible operation.
Some(&*rtree.insert(rtree2))
}
} else {
None
};
let cache_size = strategy.cache_size();
if let Some(rtree) = rtree {
if ascending {
let point = lat_lng_to_xyz(&target_point);
for point in rtree.nearest_neighbor_iter(&point) {
if geo_candidates.contains(point.data.0) {
cached_sorted_docids.push_back(point.data);
if cached_sorted_docids.len() >= cache_size {
break;
}
}
}
} else {
// in the case of the desc geo sort we look for the closest point to the opposite of the queried point
// and we insert the points in reverse order they get reversed when emptying the cache later on
let point = lat_lng_to_xyz(&opposite_of(target_point));
for point in rtree.nearest_neighbor_iter(&point) {
if geo_candidates.contains(point.data.0) {
cached_sorted_docids.push_front(point.data);
if cached_sorted_docids.len() >= cache_size {
break;
}
}
}
}
} else {
// the iterative version
let [lat, lng] = field_ids.expect("fill_buffer can't be called without the lat&lng");
let mut documents = geo_candidates
.iter()
.map(|id| -> crate::Result<_> { Ok((id, geo_value(id, lat, lng, index, txn)?)) })
.collect::<crate::Result<Vec<(u32, [f64; 2])>>>()?;
// computing the distance between two points is expensive thus we cache the result
documents.sort_by_cached_key(|(_, p)| distance_between_two_points(&target_point, p) as usize);
cached_sorted_docids.extend(documents);
};
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub fn next_bucket(
index: &Index,
txn: &RoTxn<heed::AnyTls>,
universe: &RoaringBitmap,
strategy: GeoSortStrategy,
ascending: bool,
target_point: [f64; 2],
field_ids: &Option<[u16; 2]>,
rtree: &mut Option<RTree<GeoPoint>>,
cached_sorted_docids: &mut VecDeque<(u32, [f64; 2])>,
geo_candidates: &RoaringBitmap,
// Limit the number of docs in a single bucket to avoid unexpectedly large overhead
max_bucket_size: u64,
// Considering the errors of GPS and geographical calculations, distances less than distance_error_margin will be treated as equal
distance_error_margin: f64,
) -> crate::Result<Option<(RoaringBitmap, Option<[f64; 2]>)>> {
let mut geo_candidates = geo_candidates & universe;
if geo_candidates.is_empty() {
return Ok(Some((universe.clone(), None)));
}
let next = |cache: &mut VecDeque<_>| {
if ascending {
cache.pop_front()
} else {
cache.pop_back()
}
};
let put_back = |cache: &mut VecDeque<_>, x: _| {
if ascending {
cache.push_front(x)
} else {
cache.push_back(x)
}
};
let mut current_bucket = RoaringBitmap::new();
// current_distance stores the first point and distance in current bucket
let mut current_distance: Option<([f64; 2], f64)> = None;
loop {
// The loop will only exit when we have found all points with equal distance or have exhausted the candidates.
if let Some((id, point)) = next(cached_sorted_docids) {
if geo_candidates.contains(id) {
let distance = distance_between_two_points(&target_point, &point);
if let Some((point0, bucket_distance)) = current_distance.as_ref() {
if (bucket_distance - distance).abs() > distance_error_margin {
// different distance, point belongs to next bucket
put_back(cached_sorted_docids, (id, point));
return Ok(Some((current_bucket, Some(point0.to_owned()))));
} else {
// same distance, point belongs to current bucket
current_bucket.insert(id);
// remove from candidates to prevent it from being added to the cache again
geo_candidates.remove(id);
// current bucket size reaches limit, force return
if current_bucket.len() == max_bucket_size {
return Ok(Some((current_bucket, Some(point0.to_owned()))));
}
}
} else {
// first doc in current bucket
current_distance = Some((point, distance));
current_bucket.insert(id);
geo_candidates.remove(id);
// current bucket size reaches limit, force return
if current_bucket.len() == max_bucket_size {
return Ok(Some((current_bucket, Some(point.to_owned()))));
}
}
}
} else {
// cache exhausted, we need to refill it
fill_cache(
index,
txn,
strategy,
ascending,
target_point,
field_ids,
rtree,
&geo_candidates,
cached_sorted_docids,
)?;
if cached_sorted_docids.is_empty() {
// candidates exhausted, exit
if let Some((point0, _)) = current_distance.as_ref() {
return Ok(Some((current_bucket, Some(point0.to_owned()))));
} else {
return Ok(Some((universe.clone(), None)));
}
}
}
}
}

View file

@ -3,6 +3,7 @@ mod enriched;
mod primary_key;
mod reader;
mod serde_impl;
pub mod geo_sort;
use std::fmt::Debug;
use std::io;

View file

@ -82,7 +82,7 @@ fn facet_value_docids(
}
/// Return an iterator over each number value in the given field of the given document.
fn facet_number_values<'a>(
pub(crate) fn facet_number_values<'a>(
docid: u32,
field_id: u16,
index: &Index,

View file

@ -7,12 +7,10 @@ use rstar::RTree;
use super::facet_string_values;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::documents::geo_sort::{fill_cache, next_bucket};
use crate::heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec};
use crate::score_details::{self, ScoreDetails};
use crate::{
distance_between_two_points, lat_lng_to_xyz, GeoPoint, Index, Result, SearchContext,
SearchLogger,
};
use crate::{GeoPoint, Index, Result, SearchContext, SearchLogger};
const FID_SIZE: usize = 2;
const DOCID_SIZE: usize = 4;
@ -134,62 +132,17 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
ctx: &mut SearchContext<'_>,
geo_candidates: &RoaringBitmap,
) -> Result<()> {
debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng");
debug_assert!(self.cached_sorted_docids.is_empty());
// lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree`
let rtree = if self.strategy.use_rtree(geo_candidates.len() as usize) {
if let Some(rtree) = self.rtree.as_ref() {
// get rtree from cache
Some(rtree)
} else {
let rtree = ctx.index.geo_rtree(ctx.txn)?.expect("geo candidates but no rtree");
// insert rtree in cache and returns it.
// Can't use `get_or_insert_with` because getting the rtree from the DB is a fallible operation.
Some(&*self.rtree.insert(rtree))
}
} else {
None
};
let cache_size = self.strategy.cache_size();
if let Some(rtree) = rtree {
if self.ascending {
let point = lat_lng_to_xyz(&self.point);
for point in rtree.nearest_neighbor_iter(&point) {
if geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_back(point.data);
if self.cached_sorted_docids.len() >= cache_size {
break;
}
}
}
} else {
// in the case of the desc geo sort we look for the closest point to the opposite of the queried point
// and we insert the points in reverse order they get reversed when emptying the cache later on
let point = lat_lng_to_xyz(&opposite_of(self.point));
for point in rtree.nearest_neighbor_iter(&point) {
if geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_front(point.data);
if self.cached_sorted_docids.len() >= cache_size {
break;
}
}
}
}
} else {
// the iterative version
let [lat, lng] = self.field_ids.unwrap();
let mut documents = geo_candidates
.iter()
.map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) })
.collect::<Result<Vec<(u32, [f64; 2])>>>()?;
// computing the distance between two points is expensive thus we cache the result
documents
.sort_by_cached_key(|(_, p)| distance_between_two_points(&self.point, p) as usize);
self.cached_sorted_docids.extend(documents);
};
fill_cache(
ctx.index,
ctx.txn,
self.strategy,
self.ascending,
self.point,
&self.field_ids,
&mut self.rtree,
geo_candidates,
&mut self.cached_sorted_docids,
)?;
Ok(())
}
@ -199,7 +152,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
///
/// If it is not able to find it in the facet number index it will extract it
/// from the facet string index and parse it as f64 (as the geo extraction behaves).
fn geo_value(
pub(crate) fn geo_value(
docid: u32,
field_lat: u16,
field_lng: u16,
@ -267,124 +220,31 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
) -> Result<Option<RankingRuleOutput<Q>>> {
let query = self.query.as_ref().unwrap().clone();
let mut geo_candidates = &self.geo_candidates & universe;
if geo_candidates.is_empty() {
return Ok(Some(RankingRuleOutput {
next_bucket(
ctx.index,
ctx.txn,
universe,
self.strategy,
self.ascending,
self.point,
&self.field_ids,
&mut self.rtree,
&mut self.cached_sorted_docids,
&self.geo_candidates,
self.max_bucket_size,
self.distance_error_margin,
)
.map(|o| {
o.map(|(candidates, point)| RankingRuleOutput {
query,
candidates: universe.clone(),
candidates,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: None,
value: point,
}),
}));
}
let ascending = self.ascending;
let next = |cache: &mut VecDeque<_>| {
if ascending {
cache.pop_front()
} else {
cache.pop_back()
}
};
let put_back = |cache: &mut VecDeque<_>, x: _| {
if ascending {
cache.push_front(x)
} else {
cache.push_back(x)
}
};
let mut current_bucket = RoaringBitmap::new();
// current_distance stores the first point and distance in current bucket
let mut current_distance: Option<([f64; 2], f64)> = None;
loop {
// The loop will only exit when we have found all points with equal distance or have exhausted the candidates.
if let Some((id, point)) = next(&mut self.cached_sorted_docids) {
if geo_candidates.contains(id) {
let distance = distance_between_two_points(&self.point, &point);
if let Some((point0, bucket_distance)) = current_distance.as_ref() {
if (bucket_distance - distance).abs() > self.distance_error_margin {
// different distance, point belongs to next bucket
put_back(&mut self.cached_sorted_docids, (id, point));
return Ok(Some(RankingRuleOutput {
query,
candidates: current_bucket,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point0.to_owned()),
}),
}));
} else {
// same distance, point belongs to current bucket
current_bucket.insert(id);
// remove from cadidates to prevent it from being added to the cache again
geo_candidates.remove(id);
// current bucket size reaches limit, force return
if current_bucket.len() == self.max_bucket_size {
return Ok(Some(RankingRuleOutput {
query,
candidates: current_bucket,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point0.to_owned()),
}),
}));
}
}
} else {
// first doc in current bucket
current_distance = Some((point, distance));
current_bucket.insert(id);
geo_candidates.remove(id);
// current bucket size reaches limit, force return
if current_bucket.len() == self.max_bucket_size {
return Ok(Some(RankingRuleOutput {
query,
candidates: current_bucket,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point.to_owned()),
}),
}));
}
}
}
} else {
// cache exhausted, we need to refill it
self.fill_buffer(ctx, &geo_candidates)?;
if self.cached_sorted_docids.is_empty() {
// candidates exhausted, exit
if let Some((point0, _)) = current_distance.as_ref() {
return Ok(Some(RankingRuleOutput {
query,
candidates: current_bucket,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point0.to_owned()),
}),
}));
} else {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: None,
}),
}));
}
}
}
}
})
})
}
#[tracing::instrument(level = "trace", skip_all, target = "search::geo_sort")]
@ -396,7 +256,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
}
/// Compute the antipodal coordinate of `coord`
fn opposite_of(mut coord: [f64; 2]) -> [f64; 2] {
pub(crate) fn opposite_of(mut coord: [f64; 2]) -> [f64; 2] {
coord[0] *= -1.;
// in the case of x,0 we want to return x,180
if coord[1] > 0. {

View file

@ -1,7 +1,7 @@
mod bucket_sort;
mod db_cache;
mod distinct;
mod geo_sort;
pub(crate) mod geo_sort;
mod graph_based_ranking_rule;
mod interner;
mod limits;