GeoSort support max_bucket_size and distance_error_margin configuration

This commit is contained in:
hdt3213 2025-04-04 16:35:34 +08:00
parent 326a728434
commit 0f07cfed14
4 changed files with 784 additions and 7 deletions

View File

@ -82,6 +82,11 @@ pub struct GeoSort<Q: RankingRuleQueryTrait> {
cached_sorted_docids: VecDeque<(u32, [f64; 2])>, cached_sorted_docids: VecDeque<(u32, [f64; 2])>,
geo_candidates: RoaringBitmap, geo_candidates: RoaringBitmap,
// Limit the number of docs in a single bucket to avoid unexpectedly large overhead
max_bucket_size: u64,
// Considering the errors of GPS and geographical calculations, distances less than distance_error_margin will be treated as equal
distance_error_margin: f64,
} }
impl<Q: RankingRuleQueryTrait> GeoSort<Q> { impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
@ -100,6 +105,8 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
field_ids: None, field_ids: None,
rtree: None, rtree: None,
cached_sorted_docids: VecDeque::new(), cached_sorted_docids: VecDeque::new(),
max_bucket_size: 1000,
distance_error_margin: 1.0,
}) })
} }
@ -275,15 +282,14 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
let mut current_bucket = RoaringBitmap::new(); let mut current_bucket = RoaringBitmap::new();
// current_distance stores the first point and distance in current bucket // current_distance stores the first point and distance in current bucket
// The farthest distance between two points on earth is about 2e7 meters, u32 is big enough to hold any distance. let mut current_distance: Option<([f64; 2], f64)> = None;
let mut current_distance: Option<([f64; 2], u32)> = None;
loop { loop {
// The loop will only exit when we have found all points with equal distance or have exhausted the candidates. // The loop will only exit when we have found all points with equal distance or have exhausted the candidates.
if let Some((id, point)) = next(&mut self.cached_sorted_docids) { if let Some((id, point)) = next(&mut self.cached_sorted_docids) {
if geo_candidates.contains(id) { if geo_candidates.contains(id) {
let distance = distance_between_two_points(&self.point, &point).round() as u32; let distance = distance_between_two_points(&self.point, &point);
if let Some((point0, bucket_distance)) = current_distance.as_ref() { if let Some((point0, bucket_distance)) = current_distance.as_ref() {
if bucket_distance != &distance { if (bucket_distance - &distance).abs() > self.distance_error_margin {
// different distance, point belongs to next bucket // different distance, point belongs to next bucket
put_back(&mut self.cached_sorted_docids, (id, point)); put_back(&mut self.cached_sorted_docids, (id, point));
return Ok(Some(RankingRuleOutput { return Ok(Some(RankingRuleOutput {
@ -297,15 +303,39 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort<Q> {
})); }));
} else { } else {
// same distance, point belongs to current bucket // same distance, point belongs to current bucket
current_bucket.push(id); current_bucket.insert(id);
// remove from cadidates to prevent it from being added to the cache again // remove from cadidates to prevent it from being added to the cache again
geo_candidates.remove(id); geo_candidates.remove(id);
// current bucket size reaches limit, force return
if current_bucket.len() == self.max_bucket_size {
return Ok(Some(RankingRuleOutput {
query,
candidates: current_bucket,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point0.to_owned()),
}),
}));
}
} }
} else { } else {
// first doc in current bucket // first doc in current bucket
current_distance = Some((point, distance)); current_distance = Some((point, distance));
current_bucket.push(id); current_bucket.insert(id);
geo_candidates.remove(id); geo_candidates.remove(id);
// current bucket size reaches limit, force return
if current_bucket.len() == self.max_bucket_size {
return Ok(Some(RankingRuleOutput {
query,
candidates: current_bucket,
score: ScoreDetails::GeoSort(score_details::GeoSort {
target_point: self.point,
ascending: self.ascending,
value: Some(point.to_owned()),
}),
}));
}
} }
} }
} else { } else {

View File

@ -18,7 +18,7 @@ fn create_index() -> TempIndex {
index index
.update_settings(|s| { .update_settings(|s| {
s.set_primary_key("id".to_owned()); s.set_primary_key("id".to_owned());
s.set_sortable_fields(hashset! { S(RESERVED_GEO_FIELD_NAME) }); s.set_sortable_fields(hashset! { S(RESERVED_GEO_FIELD_NAME), S("score") });
s.set_criteria(vec![Criterion::Words, Criterion::Sort]); s.set_criteria(vec![Criterion::Words, Criterion::Sort]);
}) })
.unwrap(); .unwrap();
@ -95,6 +95,41 @@ fn test_geo_sort() {
insta::assert_snapshot!(format!("{scores:#?}")); insta::assert_snapshot!(format!("{scores:#?}"));
} }
#[test]
fn test_geo_sort_with_following_ranking_rules() {
let index = create_index();
index
.add_documents(documents!([
{ "id": 1 }, { "id": 4 }, { "id": 3 }, { "id": 2 }, { "id": 5 },
{ "id": 6, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 10 },
{ "id": 7, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 9 },
{ "id": 8, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 8 },
{ "id": 9, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 7 },
{ "id": 10, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score":6 },
{ "id": 11, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 5 },
{ "id": 12, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 10 },
{ "id": 13, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 9 },
{ "id": 14, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 8 },
{ "id": 15, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 7 },
]))
.unwrap();
let rtxn = index.read_txn().unwrap();
let mut s = Search::new(&rtxn, &index);
s.scoring_strategy(crate::score_details::ScoringStrategy::Detailed);
s.sort_criteria(vec![AscDesc::Asc(Member::Geo([0., 0.])), AscDesc::Desc(Member::Field("score".to_string()))]);
let (ids, scores) = execute_iterative_and_rtree_returns_the_same(&rtxn, &index, &mut s);
insta::assert_snapshot!(format!("{ids:?}"), @"[6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 4, 3, 2, 5]");
insta::assert_snapshot!(format!("{scores:#?}"));
s.sort_criteria(vec![AscDesc::Desc(Member::Geo([0., 0.])), AscDesc::Desc(Member::Field("score".to_string()))]);
let (ids, scores) = execute_iterative_and_rtree_returns_the_same(&rtxn, &index, &mut s);
insta::assert_snapshot!(format!("{ids:?}"), @"[12, 13, 14, 15, 6, 7, 8, 9, 10, 11, 1, 4, 3, 2, 5]");
insta::assert_snapshot!(format!("{scores:#?}"));
}
#[test] #[test]
fn test_geo_sort_around_the_edge_of_the_flat_earth() { fn test_geo_sort_around_the_edge_of_the_flat_earth() {
let index = create_index(); let index = create_index();

View File

@ -0,0 +1,356 @@
---
source: crates/milli/src/search/new/tests/geo_sort.rs
expression: "format!(\"{scores:#?}\")"
---
[
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(10.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(9.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(8.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(7.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(6.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(5.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: Some(
[
5.0,
5.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(10.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: Some(
[
5.0,
5.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(9.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: Some(
[
5.0,
5.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(8.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: Some(
[
5.0,
5.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(7.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: None,
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Null,
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: None,
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Null,
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: None,
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Null,
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: None,
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Null,
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: true,
value: None,
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Null,
},
),
],
]

View File

@ -0,0 +1,356 @@
---
source: crates/milli/src/search/new/tests/geo_sort.rs
expression: "format!(\"{scores:#?}\")"
---
[
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: Some(
[
5.0,
5.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(10.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: Some(
[
5.0,
5.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(9.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: Some(
[
5.0,
5.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(8.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: Some(
[
5.0,
5.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(7.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(10.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(9.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(8.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(7.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(6.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: Some(
[
2.0,
2.0,
],
),
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Number(5.0),
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: None,
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Null,
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: None,
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Null,
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: None,
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Null,
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: None,
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Null,
},
),
],
[
GeoSort(
GeoSort {
target_point: [
0.0,
0.0,
],
ascending: false,
value: None,
},
),
Sort(
Sort {
field_name: "score",
ascending: false,
redacted: false,
value: Null,
},
),
],
]