diff --git a/milli/src/search/new/geo_sort.rs b/milli/src/search/new/geo_sort.rs index e94ed33d1..dddb7f426 100644 --- a/milli/src/search/new/geo_sort.rs +++ b/milli/src/search/new/geo_sort.rs @@ -8,6 +8,7 @@ use rstar::RTree; use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; use crate::heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec}; +use crate::score_details::{self, ScoreDetails}; use crate::{ distance_between_two_points, lat_lng_to_xyz, GeoPoint, Index, Result, SearchContext, SearchLogger, @@ -80,7 +81,7 @@ pub struct GeoSort { field_ids: Option<[u16; 2]>, rtree: Option>, - cached_sorted_docids: VecDeque, + cached_sorted_docids: VecDeque<(u32, [f64; 2])>, geo_candidates: RoaringBitmap, } @@ -130,7 +131,7 @@ impl GeoSort { let point = lat_lng_to_xyz(&self.point); for point in rtree.nearest_neighbor_iter(&point) { if self.geo_candidates.contains(point.data.0) { - self.cached_sorted_docids.push_back(point.data.0); + self.cached_sorted_docids.push_back(point.data); if self.cached_sorted_docids.len() >= cache_size { break; } @@ -142,7 +143,7 @@ impl GeoSort { let point = lat_lng_to_xyz(&opposite_of(self.point)); for point in rtree.nearest_neighbor_iter(&point) { if self.geo_candidates.contains(point.data.0) { - self.cached_sorted_docids.push_front(point.data.0); + self.cached_sorted_docids.push_front(point.data); if self.cached_sorted_docids.len() >= cache_size { break; } @@ -177,7 +178,7 @@ impl GeoSort { // computing the distance between two points is expensive thus we cache the result documents .sort_by_cached_key(|(_, p)| distance_between_two_points(&self.point, p) as usize); - self.cached_sorted_docids.extend(documents.into_iter().map(|(doc_id, _)| doc_id)); + self.cached_sorted_docids.extend(documents.into_iter()); }; Ok(()) @@ -220,12 +221,19 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { logger: &mut dyn SearchLogger, universe: &RoaringBitmap, ) -> Result>> { - assert!(universe.len() > 1); let query = self.query.as_ref().unwrap().clone(); self.geo_candidates &= universe; if self.geo_candidates.is_empty() { - return Ok(Some(RankingRuleOutput { query, candidates: universe.clone() })); + return Ok(Some(RankingRuleOutput { + query, + candidates: universe.clone(), + score: ScoreDetails::GeoSort(score_details::GeoSort { + target_point: self.point, + ascending: self.ascending, + value: None, + }), + })); } let ascending = self.ascending; @@ -236,11 +244,16 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { cache.pop_back() } }; - while let Some(id) = next(&mut self.cached_sorted_docids) { + while let Some((id, point)) = next(&mut self.cached_sorted_docids) { if self.geo_candidates.contains(id) { return Ok(Some(RankingRuleOutput { query, candidates: RoaringBitmap::from_iter([id]), + score: ScoreDetails::GeoSort(score_details::GeoSort { + target_point: self.point, + ascending: self.ascending, + value: Some(point), + }), })); } }