From 0f07cfed143fd67e797fd4fe2fdaf6cabdc6b6b0 Mon Sep 17 00:00:00 2001 From: hdt3213 Date: Fri, 4 Apr 2025 16:35:34 +0800 Subject: [PATCH] GeoSort support max_bucket_size and distance_error_margin configuration --- crates/milli/src/search/new/geo_sort.rs | 42 ++- crates/milli/src/search/new/tests/geo_sort.rs | 37 +- ...o_sort_with_following_ranking_rules-2.snap | 356 ++++++++++++++++++ ...o_sort_with_following_ranking_rules-4.snap | 356 ++++++++++++++++++ 4 files changed, 784 insertions(+), 7 deletions(-) create mode 100644 crates/milli/src/search/new/tests/snapshots/milli__search__new__tests__geo_sort__geo_sort_with_following_ranking_rules-2.snap create mode 100644 crates/milli/src/search/new/tests/snapshots/milli__search__new__tests__geo_sort__geo_sort_with_following_ranking_rules-4.snap diff --git a/crates/milli/src/search/new/geo_sort.rs b/crates/milli/src/search/new/geo_sort.rs index 1a2bd2edc..ca5a4ab8b 100644 --- a/crates/milli/src/search/new/geo_sort.rs +++ b/crates/milli/src/search/new/geo_sort.rs @@ -82,6 +82,11 @@ pub struct GeoSort { cached_sorted_docids: VecDeque<(u32, [f64; 2])>, geo_candidates: RoaringBitmap, + + // Limit the number of docs in a single bucket to avoid unexpectedly large overhead + max_bucket_size: u64, + // Considering the errors of GPS and geographical calculations, distances less than distance_error_margin will be treated as equal + distance_error_margin: f64, } impl GeoSort { @@ -100,6 +105,8 @@ impl GeoSort { field_ids: None, rtree: None, cached_sorted_docids: VecDeque::new(), + max_bucket_size: 1000, + distance_error_margin: 1.0, }) } @@ -275,15 +282,14 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { let mut current_bucket = RoaringBitmap::new(); // current_distance stores the first point and distance in current bucket - // The farthest distance between two points on earth is about 2e7 meters, u32 is big enough to hold any distance. - let mut current_distance: Option<([f64; 2], u32)> = None; + let mut current_distance: Option<([f64; 2], f64)> = None; loop { // The loop will only exit when we have found all points with equal distance or have exhausted the candidates. if let Some((id, point)) = next(&mut self.cached_sorted_docids) { if geo_candidates.contains(id) { - let distance = distance_between_two_points(&self.point, &point).round() as u32; + let distance = distance_between_two_points(&self.point, &point); if let Some((point0, bucket_distance)) = current_distance.as_ref() { - if bucket_distance != &distance { + if (bucket_distance - &distance).abs() > self.distance_error_margin { // different distance, point belongs to next bucket put_back(&mut self.cached_sorted_docids, (id, point)); return Ok(Some(RankingRuleOutput { @@ -297,15 +303,39 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { })); } else { // same distance, point belongs to current bucket - current_bucket.push(id); + current_bucket.insert(id); // remove from cadidates to prevent it from being added to the cache again geo_candidates.remove(id); + // current bucket size reaches limit, force return + if current_bucket.len() == self.max_bucket_size { + return Ok(Some(RankingRuleOutput { + query, + candidates: current_bucket, + score: ScoreDetails::GeoSort(score_details::GeoSort { + target_point: self.point, + ascending: self.ascending, + value: Some(point0.to_owned()), + }), + })); + } } } else { // first doc in current bucket current_distance = Some((point, distance)); - current_bucket.push(id); + current_bucket.insert(id); geo_candidates.remove(id); + // current bucket size reaches limit, force return + if current_bucket.len() == self.max_bucket_size { + return Ok(Some(RankingRuleOutput { + query, + candidates: current_bucket, + score: ScoreDetails::GeoSort(score_details::GeoSort { + target_point: self.point, + ascending: self.ascending, + value: Some(point.to_owned()), + }), + })); + } } } } else { diff --git a/crates/milli/src/search/new/tests/geo_sort.rs b/crates/milli/src/search/new/tests/geo_sort.rs index 2eda39ba1..3d89f5d2f 100644 --- a/crates/milli/src/search/new/tests/geo_sort.rs +++ b/crates/milli/src/search/new/tests/geo_sort.rs @@ -18,7 +18,7 @@ fn create_index() -> TempIndex { index .update_settings(|s| { s.set_primary_key("id".to_owned()); - s.set_sortable_fields(hashset! { S(RESERVED_GEO_FIELD_NAME) }); + s.set_sortable_fields(hashset! { S(RESERVED_GEO_FIELD_NAME), S("score") }); s.set_criteria(vec![Criterion::Words, Criterion::Sort]); }) .unwrap(); @@ -95,6 +95,41 @@ fn test_geo_sort() { insta::assert_snapshot!(format!("{scores:#?}")); } +#[test] +fn test_geo_sort_with_following_ranking_rules() { + let index = create_index(); + + index + .add_documents(documents!([ + { "id": 1 }, { "id": 4 }, { "id": 3 }, { "id": 2 }, { "id": 5 }, + { "id": 6, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 10 }, + { "id": 7, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 9 }, + { "id": 8, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 8 }, + { "id": 9, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 7 }, + { "id": 10, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score":6 }, + { "id": 11, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 5 }, + { "id": 12, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 10 }, + { "id": 13, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 9 }, + { "id": 14, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 8 }, + { "id": 15, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 7 }, + ])) + .unwrap(); + + let rtxn = index.read_txn().unwrap(); + + let mut s = Search::new(&rtxn, &index); + s.scoring_strategy(crate::score_details::ScoringStrategy::Detailed); + s.sort_criteria(vec![AscDesc::Asc(Member::Geo([0., 0.])), AscDesc::Desc(Member::Field("score".to_string()))]); + let (ids, scores) = execute_iterative_and_rtree_returns_the_same(&rtxn, &index, &mut s); + insta::assert_snapshot!(format!("{ids:?}"), @"[6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 1, 4, 3, 2, 5]"); + insta::assert_snapshot!(format!("{scores:#?}")); + + s.sort_criteria(vec![AscDesc::Desc(Member::Geo([0., 0.])), AscDesc::Desc(Member::Field("score".to_string()))]); + let (ids, scores) = execute_iterative_and_rtree_returns_the_same(&rtxn, &index, &mut s); + insta::assert_snapshot!(format!("{ids:?}"), @"[12, 13, 14, 15, 6, 7, 8, 9, 10, 11, 1, 4, 3, 2, 5]"); + insta::assert_snapshot!(format!("{scores:#?}")); +} + #[test] fn test_geo_sort_around_the_edge_of_the_flat_earth() { let index = create_index(); diff --git a/crates/milli/src/search/new/tests/snapshots/milli__search__new__tests__geo_sort__geo_sort_with_following_ranking_rules-2.snap b/crates/milli/src/search/new/tests/snapshots/milli__search__new__tests__geo_sort__geo_sort_with_following_ranking_rules-2.snap new file mode 100644 index 000000000..b8b6e33e3 --- /dev/null +++ b/crates/milli/src/search/new/tests/snapshots/milli__search__new__tests__geo_sort__geo_sort_with_following_ranking_rules-2.snap @@ -0,0 +1,356 @@ +--- +source: crates/milli/src/search/new/tests/geo_sort.rs +expression: "format!(\"{scores:#?}\")" +--- +[ + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(10.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(9.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(8.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(7.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(6.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(5.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: Some( + [ + 5.0, + 5.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(10.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: Some( + [ + 5.0, + 5.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(9.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: Some( + [ + 5.0, + 5.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(8.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: Some( + [ + 5.0, + 5.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(7.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: None, + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Null, + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: None, + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Null, + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: None, + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Null, + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: None, + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Null, + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: true, + value: None, + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Null, + }, + ), + ], +] diff --git a/crates/milli/src/search/new/tests/snapshots/milli__search__new__tests__geo_sort__geo_sort_with_following_ranking_rules-4.snap b/crates/milli/src/search/new/tests/snapshots/milli__search__new__tests__geo_sort__geo_sort_with_following_ranking_rules-4.snap new file mode 100644 index 000000000..82124cd90 --- /dev/null +++ b/crates/milli/src/search/new/tests/snapshots/milli__search__new__tests__geo_sort__geo_sort_with_following_ranking_rules-4.snap @@ -0,0 +1,356 @@ +--- +source: crates/milli/src/search/new/tests/geo_sort.rs +expression: "format!(\"{scores:#?}\")" +--- +[ + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: Some( + [ + 5.0, + 5.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(10.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: Some( + [ + 5.0, + 5.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(9.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: Some( + [ + 5.0, + 5.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(8.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: Some( + [ + 5.0, + 5.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(7.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(10.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(9.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(8.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(7.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(6.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: Some( + [ + 2.0, + 2.0, + ], + ), + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Number(5.0), + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: None, + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Null, + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: None, + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Null, + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: None, + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Null, + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: None, + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Null, + }, + ), + ], + [ + GeoSort( + GeoSort { + target_point: [ + 0.0, + 0.0, + ], + ascending: false, + value: None, + }, + ), + Sort( + Sort { + field_name: "score", + ascending: false, + redacted: false, + value: Null, + }, + ), + ], +]