test geo sort reached max_bucket_size

This commit is contained in:
hdt3213 2025-04-14 22:50:32 +08:00
parent c4a8b84dc0
commit 5da92a3d53
2 changed files with 87 additions and 1 deletions

View File

@ -71,6 +71,26 @@ impl Strategy {
} }
} }
/// Bucket-size limit used by `default_max_bucket_size` unless a test overrides it.
const INITIAL_MAX_BUCKET_SIZE: u64 = 1000;

/// Returns the `max_bucket_size` a newly constructed `GeoSort` starts with.
///
/// Outside of test builds the limit is fixed.
#[cfg(not(test))]
fn default_max_bucket_size() -> u64 {
    INITIAL_MAX_BUCKET_SIZE
}

/// Test-only, process-wide override of the default bucket-size limit.
///
/// An atomic is enough for a single integer: there is no lock to take and
/// therefore no poisoned-lock case to handle.
#[cfg(test)]
static DEFAULT_MAX_BUCKET_SIZE: std::sync::atomic::AtomicU64 =
    std::sync::atomic::AtomicU64::new(INITIAL_MAX_BUCKET_SIZE);

/// Overrides the `max_bucket_size` that `GeoSort` instances created afterwards start with.
///
/// The value lives in a `static`, so the override is visible to every test
/// running in the same process; callers should restore the previous value
/// once they are done.
#[cfg(test)]
pub fn set_default_max_bucket_size(n: u64) {
    DEFAULT_MAX_BUCKET_SIZE.store(n, std::sync::atomic::Ordering::SeqCst);
}

/// Returns the `max_bucket_size` a newly constructed `GeoSort` starts with.
///
/// In test builds this is the value last passed to
/// `set_default_max_bucket_size`, or the initial limit if it was never called.
#[cfg(test)]
fn default_max_bucket_size() -> u64 {
    DEFAULT_MAX_BUCKET_SIZE.load(std::sync::atomic::Ordering::SeqCst)
}
pub struct GeoSort<Q: RankingRuleQueryTrait> { pub struct GeoSort<Q: RankingRuleQueryTrait> {
query: Option<Q>, query: Option<Q>,
@ -105,7 +125,7 @@ impl<Q: RankingRuleQueryTrait> GeoSort<Q> {
field_ids: None, field_ids: None,
rtree: None, rtree: None,
cached_sorted_docids: VecDeque::new(), cached_sorted_docids: VecDeque::new(),
max_bucket_size: 1000, max_bucket_size: default_max_bucket_size(),
distance_error_margin: 1.0, distance_error_margin: 1.0,
}) })
} }

View File

@ -4,6 +4,7 @@ This module tests the `geo_sort` ranking rule
use big_s::S; use big_s::S;
use heed::RoTxn; use heed::RoTxn;
use itertools::Itertools;
use maplit::hashset; use maplit::hashset;
use crate::constants::RESERVED_GEO_FIELD_NAME; use crate::constants::RESERVED_GEO_FIELD_NAME;
@ -136,6 +137,71 @@ fn test_geo_sort_with_following_ranking_rules() {
insta::assert_snapshot!(format!("{scores:#?}")); insta::assert_snapshot!(format!("{scores:#?}"));
} }
/// Checks geo sorting when a geo bucket holds more documents than `max_bucket_size`.
///
/// The index contains six documents at (2, 2) (ids 6..=11), four documents at
/// (5, 5) (ids 12..=15) and five documents without a geo field. With the
/// bucket-size limit lowered to 2, both the iterative and the rtree strategy
/// must still return the (2, 2) group first, then the (5, 5) group, then the
/// documents without geo data.
#[test]
fn test_geo_sort_reached_max_bucket_size() {
    /// Puts the default bucket-size limit back to 1000 when dropped, so the
    /// override below does not leak into tests that run later in the same
    /// process, even if an assertion in this test panics.
    struct RestoreMaxBucketSize;

    impl Drop for RestoreMaxBucketSize {
        fn drop(&mut self) {
            crate::search::new::geo_sort::set_default_max_bucket_size(1000);
        }
    }

    let index = create_index();

    index
        .add_documents(documents!([
            { "id": 1 }, { "id": 4 }, { "id": 3 }, { "id": 2 }, { "id": 5 },
            { "id": 6, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 10 },
            { "id": 7, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 9 },
            { "id": 8, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 8 },
            { "id": 9, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 7 },
            { "id": 10, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 6 },
            { "id": 11, RESERVED_GEO_FIELD_NAME: { "lat": 2, "lng": 2 }, "score": 5 },
            { "id": 12, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 10 },
            { "id": 13, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 9 },
            { "id": 14, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 8 },
            { "id": 15, RESERVED_GEO_FIELD_NAME: { "lat": 5, "lng": 5 }, "score": 7 },
        ]))
        .unwrap();

    // Each geo point holds more than two documents, so the limit is reached.
    crate::search::new::geo_sort::set_default_max_bucket_size(2);
    let _restore = RestoreMaxBucketSize;

    let rtxn = index.read_txn().unwrap();

    let mut s = Search::new(&rtxn, &index);
    s.scoring_strategy(crate::score_details::ScoringStrategy::Detailed);
    s.sort_criteria(vec![
        AscDesc::Asc(Member::Geo([0., 0.])),
        AscDesc::Desc(Member::Field("score".to_string())),
    ]);

    // Once the bucket-size limit is reached, the results are not expected to
    // obey the ranking rules that follow the geo sort (`score` descending here),
    // and the iterative and rtree strategies are not expected to order the
    // documents of one bucket identically. The checks below therefore only
    // verify which group of documents occupies each range of positions.

    s.geo_sort_strategy(GeoSortStrategy::AlwaysIterative(1000));
    let SearchResult { documents_ids, .. } = s.execute().unwrap();
    let iterative_ids = collect_field_values(&index, &rtxn, "id", &documents_ids);
    assert_eq!(iterative_ids.len(), 15);
    // Closest group first: the six documents located at (2, 2).
    for id_str in &iterative_ids[0..6] {
        let id = id_str.parse::<u32>().unwrap();
        assert!((6..=11).contains(&id));
    }
    // Then the four documents located at (5, 5).
    for id_str in &iterative_ids[6..10] {
        let id = id_str.parse::<u32>().unwrap();
        assert!((12..=15).contains(&id));
    }
    // Documents without geo data come last.
    let no_geo_ids = iterative_ids[10..].iter().collect_vec();
    insta::assert_snapshot!(format!("{no_geo_ids:?}"), @r#"["1", "4", "3", "2", "5"]"#);

    s.geo_sort_strategy(GeoSortStrategy::AlwaysRtree(1000));
    let SearchResult { documents_ids, .. } = s.execute().unwrap();
    let rtree_ids = collect_field_values(&index, &rtxn, "id", &documents_ids);
    assert_eq!(rtree_ids.len(), 15);
    // Same grouping is expected from the rtree strategy.
    for id_str in &rtree_ids[0..6] {
        let id = id_str.parse::<u32>().unwrap();
        assert!((6..=11).contains(&id));
    }
    for id_str in &rtree_ids[6..10] {
        let id = id_str.parse::<u32>().unwrap();
        assert!((12..=15).contains(&id));
    }
    let no_geo_ids = rtree_ids[10..].iter().collect_vec();
    insta::assert_snapshot!(format!("{no_geo_ids:?}"), @r#"["1", "4", "3", "2", "5"]"#);
}
#[test] #[test]
fn test_geo_sort_around_the_edge_of_the_flat_earth() { fn test_geo_sort_around_the_edge_of_the_flat_earth() {
let index = create_index(); let index = create_index();