diff --git a/crates/milli/src/search/hybrid.rs b/crates/milli/src/search/hybrid.rs index 81f74fdad..e07f886c9 100644 --- a/crates/milli/src/search/hybrid.rs +++ b/crates/milli/src/search/hybrid.rs @@ -164,7 +164,7 @@ impl Search<'_> { sort_criteria: self.sort_criteria.clone(), distinct: self.distinct.clone(), searchable_attributes: self.searchable_attributes, - geo_strategy: self.geo_strategy, + geo_param: self.geo_param, terms_matching_strategy: self.terms_matching_strategy, scoring_strategy: ScoringStrategy::Detailed, words_limit: self.words_limit, diff --git a/crates/milli/src/search/mod.rs b/crates/milli/src/search/mod.rs index 0dd639c59..def00ec92 100644 --- a/crates/milli/src/search/mod.rs +++ b/crates/milli/src/search/mod.rs @@ -45,7 +45,7 @@ pub struct Search<'a> { sort_criteria: Option>, distinct: Option, searchable_attributes: Option<&'a [String]>, - geo_strategy: new::GeoSortStrategy, + geo_param: new::GeoSortParameter, terms_matching_strategy: TermsMatchingStrategy, scoring_strategy: ScoringStrategy, words_limit: usize, @@ -68,7 +68,7 @@ impl<'a> Search<'a> { sort_criteria: None, distinct: None, searchable_attributes: None, - geo_strategy: new::GeoSortStrategy::default(), + geo_param: new::GeoSortParameter::default(), terms_matching_strategy: TermsMatchingStrategy::default(), scoring_strategy: Default::default(), exhaustive_number_hits: false, @@ -145,7 +145,13 @@ impl<'a> Search<'a> { #[cfg(test)] pub fn geo_sort_strategy(&mut self, strategy: new::GeoSortStrategy) -> &mut Search<'a> { - self.geo_strategy = strategy; + self.geo_param.strategy = strategy; + self + } + + #[cfg(test)] + pub fn geo_max_bucket_size(&mut self, max_size: u64) -> &mut Search<'a> { + self.geo_param.max_bucket_size = max_size; self } @@ -232,7 +238,7 @@ impl<'a> Search<'a> { universe, &self.sort_criteria, &self.distinct, - self.geo_strategy, + self.geo_param, self.offset, self.limit, embedder_name, @@ -251,7 +257,7 @@ impl<'a> Search<'a> { universe, &self.sort_criteria, &self.distinct, - self.geo_strategy, + self.geo_param, self.offset, self.limit, Some(self.words_limit), @@ -290,7 +296,7 @@ impl fmt::Debug for Search<'_> { sort_criteria, distinct, searchable_attributes, - geo_strategy: _, + geo_param: _, terms_matching_strategy, scoring_strategy, words_limit, diff --git a/crates/milli/src/search/new/geo_sort.rs b/crates/milli/src/search/new/geo_sort.rs index 1c52b0a5b..7f1c0feff 100644 --- a/crates/milli/src/search/new/geo_sort.rs +++ b/crates/milli/src/search/new/geo_sort.rs @@ -39,6 +39,22 @@ fn facet_number_values<'a>( Ok(iter) } +#[derive(Debug, Clone, Copy)] + +pub struct Parameter { + // Define the strategy used by the geo sort + pub strategy: Strategy, + // Limit the number of docs in a single bucket to avoid unexpectedly large overhead + pub max_bucket_size: u64, + // Considering the errors of GPS and geographical calculations, distances less than distance_error_margin will be treated as equal + pub distance_error_margin: f64, +} + +impl Default for Parameter { + fn default() -> Self { + Self { strategy: Strategy::default(), max_bucket_size: 1000, distance_error_margin: 1.0 } + } +} /// Define the strategy used by the geo sort. /// The parameter represents the cache size, and, in the case of the Dynamic strategy, /// the point where we move from using the iterative strategy to the rtree. @@ -71,26 +87,6 @@ impl Strategy { } } -#[cfg(not(test))] -fn default_max_bucket_size() -> u64 { - 1000 -} - -#[cfg(test)] -static DEFAULT_MAX_BUCKET_SIZE: std::sync::Mutex = std::sync::Mutex::new(1000); - -#[cfg(test)] -pub fn set_default_max_bucket_size(n: u64) { - let mut size = DEFAULT_MAX_BUCKET_SIZE.lock().unwrap(); - *size = n; -} - -#[cfg(test)] -fn default_max_bucket_size() -> u64 { - let max_size = *(DEFAULT_MAX_BUCKET_SIZE.lock().unwrap()); - max_size -} - pub struct GeoSort { query: Option, @@ -111,22 +107,23 @@ pub struct GeoSort { impl GeoSort { pub fn new( - strategy: Strategy, + parameter: &Parameter, geo_faceted_docids: RoaringBitmap, point: [f64; 2], ascending: bool, ) -> Result { + let Parameter { strategy, max_bucket_size, distance_error_margin } = parameter; Ok(Self { query: None, - strategy, + strategy: *strategy, ascending, point, geo_candidates: geo_faceted_docids, field_ids: None, rtree: None, cached_sorted_docids: VecDeque::new(), - max_bucket_size: default_max_bucket_size(), - distance_error_margin: 1.0, + max_bucket_size: *max_bucket_size, + distance_error_margin: *distance_error_margin, }) } diff --git a/crates/milli/src/search/new/matches/mod.rs b/crates/milli/src/search/new/matches/mod.rs index e30f11e94..2d6f2cf17 100644 --- a/crates/milli/src/search/new/matches/mod.rs +++ b/crates/milli/src/search/new/matches/mod.rs @@ -513,7 +513,7 @@ mod tests { universe, &None, &None, - crate::search::new::GeoSortStrategy::default(), + crate::search::new::GeoSortParameter::default(), 0, 100, Some(10), diff --git a/crates/milli/src/search/new/mod.rs b/crates/milli/src/search/new/mod.rs index b9161b417..5042fb3b7 100644 --- a/crates/milli/src/search/new/mod.rs +++ b/crates/milli/src/search/new/mod.rs @@ -45,6 +45,7 @@ use sort::Sort; use self::distinct::facet_string_values; use self::geo_sort::GeoSort; +pub use self::geo_sort::Parameter as GeoSortParameter; pub use self::geo_sort::Strategy as GeoSortStrategy; use self::graph_based_ranking_rule::Words; use self::interner::Interned; @@ -274,7 +275,7 @@ fn resolve_negative_phrases( fn get_ranking_rules_for_placeholder_search<'ctx>( ctx: &SearchContext<'ctx>, sort_criteria: &Option>, - geo_strategy: geo_sort::Strategy, + geo_param: geo_sort::Parameter, ) -> Result>> { let mut sort = false; let mut sorted_fields = HashSet::new(); @@ -299,7 +300,7 @@ fn get_ranking_rules_for_placeholder_search<'ctx>( &mut ranking_rules, &mut sorted_fields, &mut geo_sorted, - geo_strategy, + &geo_param, )?; sort = true; } @@ -326,7 +327,7 @@ fn get_ranking_rules_for_placeholder_search<'ctx>( fn get_ranking_rules_for_vector<'ctx>( ctx: &SearchContext<'ctx>, sort_criteria: &Option>, - geo_strategy: geo_sort::Strategy, + geo_param: geo_sort::Parameter, limit_plus_offset: usize, target: &[f32], embedder_name: &str, @@ -375,7 +376,7 @@ fn get_ranking_rules_for_vector<'ctx>( &mut ranking_rules, &mut sorted_fields, &mut geo_sorted, - geo_strategy, + &geo_param, )?; sort = true; } @@ -403,7 +404,7 @@ fn get_ranking_rules_for_vector<'ctx>( fn get_ranking_rules_for_query_graph_search<'ctx>( ctx: &SearchContext<'ctx>, sort_criteria: &Option>, - geo_strategy: geo_sort::Strategy, + geo_param: geo_sort::Parameter, terms_matching_strategy: TermsMatchingStrategy, ) -> Result>> { // query graph search @@ -477,7 +478,7 @@ fn get_ranking_rules_for_query_graph_search<'ctx>( &mut ranking_rules, &mut sorted_fields, &mut geo_sorted, - geo_strategy, + &geo_param, )?; sort = true; } @@ -514,7 +515,7 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( ranking_rules: &mut Vec>, sorted_fields: &mut HashSet, geo_sorted: &mut bool, - geo_strategy: geo_sort::Strategy, + geo_param: &geo_sort::Parameter, ) -> Result<()> { let sort_criteria = sort_criteria.clone().unwrap_or_default(); ranking_rules.reserve(sort_criteria.len()); @@ -540,7 +541,7 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( } let geo_faceted_docids = ctx.index.geo_faceted_documents_ids(ctx.txn)?; ranking_rules.push(Box::new(GeoSort::new( - geo_strategy, + geo_param, geo_faceted_docids, point, true, @@ -552,7 +553,7 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( } let geo_faceted_docids = ctx.index.geo_faceted_documents_ids(ctx.txn)?; ranking_rules.push(Box::new(GeoSort::new( - geo_strategy, + geo_param, geo_faceted_docids, point, false, @@ -584,7 +585,7 @@ pub fn execute_vector_search( universe: RoaringBitmap, sort_criteria: &Option>, distinct: &Option, - geo_strategy: geo_sort::Strategy, + geo_param: geo_sort::Parameter, from: usize, length: usize, embedder_name: &str, @@ -600,7 +601,7 @@ pub fn execute_vector_search( let ranking_rules = get_ranking_rules_for_vector( ctx, sort_criteria, - geo_strategy, + geo_param, from + length, vector, embedder_name, @@ -647,7 +648,7 @@ pub fn execute_search( mut universe: RoaringBitmap, sort_criteria: &Option>, distinct: &Option, - geo_strategy: geo_sort::Strategy, + geo_param: geo_sort::Parameter, from: usize, length: usize, words_limit: Option, @@ -761,7 +762,7 @@ pub fn execute_search( let ranking_rules = get_ranking_rules_for_query_graph_search( ctx, sort_criteria, - geo_strategy, + geo_param, terms_matching_strategy, )?; @@ -783,7 +784,7 @@ pub fn execute_search( )? } else { let ranking_rules = - get_ranking_rules_for_placeholder_search(ctx, sort_criteria, geo_strategy)?; + get_ranking_rules_for_placeholder_search(ctx, sort_criteria, geo_param)?; bucket_sort( ctx, ranking_rules, diff --git a/crates/milli/src/search/new/tests/geo_sort.rs b/crates/milli/src/search/new/tests/geo_sort.rs index 6d3df5262..f67d3ffdf 100644 --- a/crates/milli/src/search/new/tests/geo_sort.rs +++ b/crates/milli/src/search/new/tests/geo_sort.rs @@ -157,10 +157,10 @@ fn test_geo_sort_reached_max_bucket_size() { ])) .unwrap(); - crate::search::new::geo_sort::set_default_max_bucket_size(2); let rtxn = index.read_txn().unwrap(); let mut s = Search::new(&rtxn, &index); + s.geo_max_bucket_size(2); s.scoring_strategy(crate::score_details::ScoringStrategy::Detailed); s.sort_criteria(vec![ AscDesc::Asc(Member::Geo([0., 0.])), @@ -200,9 +200,6 @@ fn test_geo_sort_reached_max_bucket_size() { } let no_geo_ids = rtree_ids[10..].iter().collect_vec(); insta::assert_snapshot!(format!("{no_geo_ids:?}"), @r#"["1", "4", "3", "2", "5"]"#); - - // recover settings - crate::search::new::geo_sort::set_default_max_bucket_size(1000); } #[test]