fix ranking rules after _geo do not work

This commit is contained in:
hdt3213 2025-04-02 20:09:51 +08:00
parent a500fa053c
commit e4733dcd42

View File

@ -1,6 +1,4 @@
use std::collections::VecDeque;
use std::iter::FromIterator;
use heed::types::{Bytes, Unit};
use heed::{RoPrefix, RoTxn};
use roaring::RoaringBitmap;
@ -245,7 +243,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
) -> Result<Option<RankingRuleOutput<Q>>> {
let query = self.query.as_ref().unwrap().clone();
let geo_candidates = &self.geo_candidates & universe;
let mut geo_candidates = &self.geo_candidates & universe;
if geo_candidates.is_empty() {
return Ok(Some(RankingRuleOutput {
@ -267,24 +265,79 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
cache.pop_back()
}
};
while let Some((id, point)) = next(&mut self.cached_sorted_docids) {
if geo_candidates.contains(id) {
return Ok(Some(RankingRuleOutput {
query,
candidates: RoaringBitmap::from_iter([id]),
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point),
}),
}));
let put_back = |cache: &mut VecDeque<_>, x: _| {
if ascending {
cache.push_front(x)
} else {
cache.push_back(x)
}
};
let mut current_bucket = RoaringBitmap::new();
// current_distance stores the first point and distance in current bucket
// The farthest distance between two points on earth is about 2e7 meters, u32 is big enough to hold any distance.
let mut current_distance: Option<([f64; 2], u32)> = None;
loop {
// The loop will only exit when we have found all points with equal distance or have exhausted the candidates.
if let Some((id, point)) = next(&mut self.cached_sorted_docids) {
if geo_candidates.contains(id) {
let distance = distance_between_two_points(&self.point, &point).round() as u32;
if let Some((point0, bucket_distance)) = current_distance.as_ref() {
if bucket_distance != &distance {
// different distance, point belongs to next bucket
put_back(&mut self.cached_sorted_docids, (id, point));
return Ok(Some(RankingRuleOutput {
query,
candidates: current_bucket,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point0.to_owned()),
}),
}));
} else {
// same distance, point belongs to current bucket
current_bucket.push(id);
// remove from cadidates to prevent it from being added to the cache again
geo_candidates.remove(id);
}
} else {
// first doc in current bucket
current_distance = Some((point, distance));
current_bucket.push(id);
geo_candidates.remove(id);
}
}
} else {
// cache exhausted, we need to refill it
self.fill_buffer(ctx, &geo_candidates)?;
if self.cached_sorted_docids.is_empty() {
// candidates exhausted, exit
if let Some((point0, _)) = current_distance.as_ref() {
return Ok(Some(RankingRuleOutput {
query,
candidates: current_bucket,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point0.to_owned()),
}),
}));
} else {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: None,
}),
}));
}
}
}
}
// if we got out of this loop it means we've exhausted our cache.
// we need to refill it and run the function again.
self.fill_buffer(ctx, &geo_candidates)?;
self.next_bucket(ctx, logger, universe)
}
#[tracing::instrument(level = "trace", skip_all, target = "search::geo_sort")]