fix ranking rules after _geo do not work

This commit is contained in:
hdt3213 2025-04-02 20:09:51 +08:00
parent a500fa053c
commit e4733dcd42

View File

@ -1,6 +1,4 @@
use std::collections::VecDeque; use std::collections::VecDeque;
use std::iter::FromIterator;
use heed::types::{Bytes, Unit}; use heed::types::{Bytes, Unit};
use heed::{RoPrefix, RoTxn}; use heed::{RoPrefix, RoTxn};
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
@ -245,7 +243,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
) -> Result<Option<RankingRuleOutput<Q>>> { ) -> Result<Option<RankingRuleOutput<Q>>> {
let query = self.query.as_ref().unwrap().clone(); let query = self.query.as_ref().unwrap().clone();
let geo_candidates = &self.geo_candidates & universe; let mut geo_candidates = &self.geo_candidates & universe;
if geo_candidates.is_empty() { if geo_candidates.is_empty() {
return Ok(Some(RankingRuleOutput { return Ok(Some(RankingRuleOutput {
@ -267,24 +265,79 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
cache.pop_back() cache.pop_back()
} }
}; };
while let Some((id, point)) = next(&mut self.cached_sorted_docids) { let put_back = |cache: &mut VecDeque<_>, x: _| {
if geo_candidates.contains(id) { if ascending {
return Ok(Some(RankingRuleOutput { cache.push_front(x)
query, } else {
candidates: RoaringBitmap::from_iter([id]), cache.push_back(x)
score: ScoreDetails::GeoSort(score_details::GeoSort { }
target_point: self.point, };
ascending: self.ascending,
value: Some(point), let mut current_bucket = RoaringBitmap::new();
}), // current_distance stores the first point and distance in current bucket
})); // The farthest distance between two points on earth is about 2e7 meters, u32 is big enough to hold any distance.
let mut current_distance: Option<([f64; 2], u32)> = None;
loop {
// The loop will only exit when we have found all points with equal distance or have exhausted the candidates.
if let Some((id, point)) = next(&mut self.cached_sorted_docids) {
if geo_candidates.contains(id) {
let distance = distance_between_two_points(&self.point, &point).round() as u32;
if let Some((point0, bucket_distance)) = current_distance.as_ref() {
if bucket_distance != &distance {
// different distance, point belongs to next bucket
put_back(&mut self.cached_sorted_docids, (id, point));
return Ok(Some(RankingRuleOutput {
query,
candidates: current_bucket,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point0.to_owned()),
}),
}));
} else {
// same distance, point belongs to current bucket
current_bucket.push(id);
// remove from cadidates to prevent it from being added to the cache again
geo_candidates.remove(id);
}
} else {
// first doc in current bucket
current_distance = Some((point, distance));
current_bucket.push(id);
geo_candidates.remove(id);
}
}
} else {
// cache exhausted, we need to refill it
self.fill_buffer(ctx, &geo_candidates)?;
if self.cached_sorted_docids.is_empty() {
// candidates exhausted, exit
if let Some((point0, _)) = current_distance.as_ref() {
return Ok(Some(RankingRuleOutput {
query,
candidates: current_bucket,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point0.to_owned()),
}),
}));
} else {
return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: None,
}),
}));
}
}
} }
} }
// if we got out of this loop it means we've exhausted our cache.
// we need to refill it and run the function again.
self.fill_buffer(ctx, &geo_candidates)?;
self.next_bucket(ctx, logger, universe)
} }
#[tracing::instrument(level = "trace", skip_all, target = "search::geo_sort")] #[tracing::instrument(level = "trace", skip_all, target = "search::geo_sort")]