Score for geosort

This commit is contained in:
Louis Dureuil 2023-06-15 17:32:51 +02:00
parent 2ea8194c18
commit 59c5b992c2
No known key found for this signature in database

View File

@ -8,6 +8,7 @@ use rstar::RTree;
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
use crate::heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec}; use crate::heed_codec::facet::{FieldDocIdFacetCodec, OrderedF64Codec};
use crate::score_details::{self, ScoreDetails};
use crate::{ use crate::{
distance_between_two_points, lat_lng_to_xyz, GeoPoint, Index, Result, SearchContext, distance_between_two_points, lat_lng_to_xyz, GeoPoint, Index, Result, SearchContext,
SearchLogger, SearchLogger,
@ -80,7 +81,7 @@ pub struct GeoSort<Q: RankingRuleQueryTrait> {
field_ids: Option<[u16; 2]>, field_ids: Option<[u16; 2]>,
rtree: Option<RTree<GeoPoint>>, rtree: Option<RTree<GeoPoint>>,
cached_sorted_docids: VecDeque<u32>, cached_sorted_docids: VecDeque<(u32, [f64; 2])>,
geo_candidates: RoaringBitmap, geo_candidates: RoaringBitmap,
} }
@ -130,7 +131,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
let point = lat_lng_to_xyz(&self.point); let point = lat_lng_to_xyz(&self.point);
for point in rtree.nearest_neighbor_iter(&point) { for point in rtree.nearest_neighbor_iter(&point) {
if self.geo_candidates.contains(point.data.0) { if self.geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_back(point.data.0); self.cached_sorted_docids.push_back(point.data);
if self.cached_sorted_docids.len() >= cache_size { if self.cached_sorted_docids.len() >= cache_size {
break; break;
} }
@ -142,7 +143,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
let point = lat_lng_to_xyz(&opposite_of(self.point)); let point = lat_lng_to_xyz(&opposite_of(self.point));
for point in rtree.nearest_neighbor_iter(&point) { for point in rtree.nearest_neighbor_iter(&point) {
if self.geo_candidates.contains(point.data.0) { if self.geo_candidates.contains(point.data.0) {
self.cached_sorted_docids.push_front(point.data.0); self.cached_sorted_docids.push_front(point.data);
if self.cached_sorted_docids.len() >= cache_size { if self.cached_sorted_docids.len() >= cache_size {
break; break;
} }
@ -177,7 +178,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
// computing the distance between two points is expensive thus we cache the result // computing the distance between two points is expensive thus we cache the result
documents documents
.sort_by_cached_key(|(_, p)| distance_between_two_points(&self.point, p) as usize); .sort_by_cached_key(|(_, p)| distance_between_two_points(&self.point, p) as usize);
self.cached_sorted_docids.extend(documents.into_iter().map(|(doc_id, _)| doc_id)); self.cached_sorted_docids.extend(documents.into_iter());
}; };
Ok(()) Ok(())
@ -220,12 +221,19 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
logger: &mut dyn SearchLogger<Q>, logger: &mut dyn SearchLogger<Q>,
universe: &RoaringBitmap, universe: &RoaringBitmap,
) -> Result<Option<RankingRuleOutput<Q>>> { ) -> Result<Option<RankingRuleOutput<Q>>> {
assert!(universe.len() > 1);
let query = self.query.as_ref().unwrap().clone(); let query = self.query.as_ref().unwrap().clone();
self.geo_candidates &= universe; self.geo_candidates &= universe;
if self.geo_candidates.is_empty() { if self.geo_candidates.is_empty() {
return Ok(Some(RankingRuleOutput { query, candidates: universe.clone() })); return Ok(Some(RankingRuleOutput {
query,
candidates: universe.clone(),
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: None,
}),
}));
} }
let ascending = self.ascending; let ascending = self.ascending;
@ -236,11 +244,16 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
cache.pop_back() cache.pop_back()
} }
}; };
while let Some(id) = next(&mut self.cached_sorted_docids) { while let Some((id, point)) = next(&mut self.cached_sorted_docids) {
if self.geo_candidates.contains(id) { if self.geo_candidates.contains(id) {
return Ok(Some(RankingRuleOutput { return Ok(Some(RankingRuleOutput {
query, query,
candidates: RoaringBitmap::from_iter([id]), candidates: RoaringBitmap::from_iter([id]),
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point),
}),
})); }));
} }
} }