diff --git a/crates/meilisearch/src/routes/multi_search.rs b/crates/meilisearch/src/routes/multi_search.rs index fcc3cd700..46c931b10 100644 --- a/crates/meilisearch/src/routes/multi_search.rs +++ b/crates/meilisearch/src/routes/multi_search.rs @@ -48,6 +48,7 @@ pub struct SearchResults { /// Bundle multiple search queries in a single API request. Use this endpoint to search through multiple indexes at once. #[utoipa::path( post, + request_body = FederatedSearch, path = "", tag = "Multi-search", security(("Bearer" = ["search", "*"])), diff --git a/crates/meilisearch/src/search/federated.rs b/crates/meilisearch/src/search/federated.rs deleted file mode 100644 index 1b3fa3b26..000000000 --- a/crates/meilisearch/src/search/federated.rs +++ /dev/null @@ -1,923 +0,0 @@ -use std::cmp::Ordering; -use std::collections::BTreeMap; -use std::fmt; -use std::iter::Zip; -use std::rc::Rc; -use std::str::FromStr as _; -use std::time::Duration; -use std::vec::{IntoIter, Vec}; - -use actix_http::StatusCode; -use index_scheduler::{IndexScheduler, RoFeatures}; -use indexmap::IndexMap; -use meilisearch_types::deserr::DeserrJsonError; -use meilisearch_types::error::deserr_codes::{ - InvalidMultiSearchFacetsByIndex, InvalidMultiSearchMaxValuesPerFacet, - InvalidMultiSearchMergeFacets, InvalidMultiSearchWeight, InvalidSearchLimit, - InvalidSearchOffset, -}; -use meilisearch_types::error::ResponseError; -use meilisearch_types::index_uid::IndexUid; -use meilisearch_types::milli::score_details::{ScoreDetails, ScoreValue}; -use meilisearch_types::milli::{self, DocumentId, OrderBy, TimeBudget}; -use roaring::RoaringBitmap; -use serde::Serialize; -use utoipa::ToSchema; - -use super::ranking_rules::{self, RankingRules}; -use super::{ - compute_facet_distribution_stats, prepare_search, AttributesFormat, ComputedFacets, FacetStats, - HitMaker, HitsInfo, RetrieveVectors, SearchHit, SearchKind, SearchQuery, SearchQueryWithIndex, -}; -use crate::error::MeilisearchHttpError; -use crate::routes::indexes::search::search_kind; - -pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0; - -#[derive(Debug, Default, Clone, Copy, PartialEq, deserr::Deserr, ToSchema)] -#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] -pub struct FederationOptions { - #[deserr(default, error = DeserrJsonError)] - #[schema(value_type = f64)] - pub weight: Weight, -} - -#[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)] -#[deserr(try_from(f64) = TryFrom::try_from -> InvalidMultiSearchWeight)] -pub struct Weight(f64); - -impl Default for Weight { - fn default() -> Self { - Weight(DEFAULT_FEDERATED_WEIGHT) - } -} - -impl std::convert::TryFrom for Weight { - type Error = InvalidMultiSearchWeight; - - fn try_from(f: f64) -> Result { - if f < 0.0 { - Err(InvalidMultiSearchWeight) - } else { - Ok(Weight(f)) - } - } -} - -impl std::ops::Deref for Weight { - type Target = f64; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - -#[derive(Debug, deserr::Deserr, ToSchema)] -#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] -#[schema(rename_all = "camelCase")] -pub struct Federation { - #[deserr(default = super::DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] - pub limit: usize, - #[deserr(default = super::DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] - pub offset: usize, - #[deserr(default, error = DeserrJsonError)] - pub facets_by_index: BTreeMap>>, - #[deserr(default, error = DeserrJsonError)] - pub merge_facets: Option, -} - -#[derive(Copy, Clone, Debug, deserr::Deserr, Default, ToSchema)] -#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] -#[schema(rename_all = "camelCase")] -pub struct MergeFacets { - #[deserr(default, error = DeserrJsonError)] - pub max_values_per_facet: Option, -} - -#[derive(Debug, deserr::Deserr, ToSchema)] -#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] -#[schema(rename_all = "camelCase")] -pub struct FederatedSearch { - pub queries: Vec, - #[deserr(default)] - pub federation: Option, -} - -#[derive(Serialize, Clone, ToSchema)] -#[serde(rename_all = "camelCase")] -#[schema(rename_all = "camelCase")] -pub struct FederatedSearchResult { - pub hits: Vec, - pub processing_time_ms: u128, - #[serde(flatten)] - pub hits_info: HitsInfo, - - #[serde(skip_serializing_if = "Option::is_none")] - pub semantic_hit_count: Option, - - #[serde(skip_serializing_if = "Option::is_none")] - #[schema(value_type = Option>>)] - pub facet_distribution: Option>>, - #[serde(skip_serializing_if = "Option::is_none")] - pub facet_stats: Option>, - #[serde(skip_serializing_if = "FederatedFacets::is_empty")] - pub facets_by_index: FederatedFacets, - - // These fields are only used for analytics purposes - #[serde(skip)] - pub degraded: bool, - #[serde(skip)] - pub used_negative_operator: bool, -} - -impl fmt::Debug for FederatedSearchResult { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let FederatedSearchResult { - hits, - processing_time_ms, - hits_info, - semantic_hit_count, - degraded, - used_negative_operator, - facet_distribution, - facet_stats, - facets_by_index, - } = self; - - let mut debug = f.debug_struct("SearchResult"); - // The most important thing when looking at a search result is the time it took to process - debug.field("processing_time_ms", &processing_time_ms); - debug.field("hits", &format!("[{} hits returned]", hits.len())); - debug.field("hits_info", &hits_info); - if *used_negative_operator { - debug.field("used_negative_operator", used_negative_operator); - } - if *degraded { - debug.field("degraded", degraded); - } - if let Some(facet_distribution) = facet_distribution { - debug.field("facet_distribution", &facet_distribution); - } - if let Some(facet_stats) = facet_stats { - debug.field("facet_stats", &facet_stats); - } - if let Some(semantic_hit_count) = semantic_hit_count { - debug.field("semantic_hit_count", &semantic_hit_count); - } - if !facets_by_index.is_empty() { - debug.field("facets_by_index", &facets_by_index); - } - - debug.finish() - } -} - -struct WeightedScore<'a> { - details: &'a [ScoreDetails], - weight: f64, -} - -impl<'a> WeightedScore<'a> { - pub fn new(details: &'a [ScoreDetails], weight: f64) -> Self { - Self { details, weight } - } - - pub fn weighted_global_score(&self) -> f64 { - ScoreDetails::global_score(self.details.iter()) * self.weight - } - - pub fn compare_weighted_global_scores(&self, other: &Self) -> Ordering { - self.weighted_global_score() - .partial_cmp(&other.weighted_global_score()) - // both are numbers, possibly infinite - .unwrap() - } - - pub fn compare(&self, other: &Self) -> Ordering { - let mut left_it = ScoreDetails::score_values(self.details.iter()); - let mut right_it = ScoreDetails::score_values(other.details.iter()); - - loop { - let left = left_it.next(); - let right = right_it.next(); - - match (left, right) { - (None, None) => return Ordering::Equal, - (None, Some(_)) => return Ordering::Less, - (Some(_), None) => return Ordering::Greater, - (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { - let left = left * self.weight; - let right = right * other.weight; - if (left - right).abs() <= f64::EPSILON { - continue; - } - return left.partial_cmp(&right).unwrap(); - } - (Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => { - match left.partial_cmp(right) { - Some(Ordering::Equal) => continue, - Some(order) => return order, - None => return self.compare_weighted_global_scores(other), - } - } - (Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => { - match left.partial_cmp(right) { - Some(Ordering::Equal) => continue, - Some(order) => return order, - None => { - return self.compare_weighted_global_scores(other); - } - } - } - // not comparable details, use global - (Some(ScoreValue::Score(_)), Some(_)) - | (Some(_), Some(ScoreValue::Score(_))) - | (Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_))) - | (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => { - let left_count = left_it.count(); - let right_count = right_it.count(); - // compare how many remaining groups of rules each side has. - // the group with the most remaining groups wins. - return left_count - .cmp(&right_count) - // breaks ties with the global ranking score - .then_with(|| self.compare_weighted_global_scores(other)); - } - } - } - } -} - -struct QueryByIndex { - query: SearchQuery, - federation_options: FederationOptions, - query_index: usize, -} - -struct SearchResultByQuery<'a> { - documents_ids: Vec, - document_scores: Vec>, - federation_options: FederationOptions, - hit_maker: HitMaker<'a>, - query_index: usize, -} - -struct SearchResultByQueryIter<'a> { - it: Zip, IntoIter>>, - federation_options: FederationOptions, - hit_maker: Rc>, - query_index: usize, -} - -impl<'a> SearchResultByQueryIter<'a> { - fn new( - SearchResultByQuery { - documents_ids, - document_scores, - federation_options, - hit_maker, - query_index, - }: SearchResultByQuery<'a>, - ) -> Self { - let it = documents_ids.into_iter().zip(document_scores); - Self { it, federation_options, hit_maker: Rc::new(hit_maker), query_index } - } -} - -struct SearchResultByQueryIterItem<'a> { - docid: DocumentId, - score: Vec, - federation_options: FederationOptions, - hit_maker: Rc>, - query_index: usize, -} - -fn merge_index_local_results( - results_by_query: Vec>, -) -> impl Iterator + '_ { - itertools::kmerge_by( - results_by_query.into_iter().map(SearchResultByQueryIter::new), - |left: &SearchResultByQueryIterItem, right: &SearchResultByQueryIterItem| { - let left_score = WeightedScore::new(&left.score, *left.federation_options.weight); - let right_score = WeightedScore::new(&right.score, *right.federation_options.weight); - - match left_score.compare(&right_score) { - // the biggest score goes first - Ordering::Greater => true, - // break ties using query index - Ordering::Equal => left.query_index < right.query_index, - Ordering::Less => false, - } - }, - ) -} - -fn merge_index_global_results( - results_by_index: Vec, -) -> impl Iterator { - itertools::kmerge_by( - results_by_index.into_iter().map(|result_by_index| result_by_index.hits.into_iter()), - |left: &SearchHitByIndex, right: &SearchHitByIndex| { - let left_score = WeightedScore::new(&left.score, *left.federation_options.weight); - let right_score = WeightedScore::new(&right.score, *right.federation_options.weight); - - match left_score.compare(&right_score) { - // the biggest score goes first - Ordering::Greater => true, - // break ties using query index - Ordering::Equal => left.query_index < right.query_index, - Ordering::Less => false, - } - }, - ) -} - -impl<'a> Iterator for SearchResultByQueryIter<'a> { - type Item = SearchResultByQueryIterItem<'a>; - - fn next(&mut self) -> Option { - let (docid, score) = self.it.next()?; - Some(SearchResultByQueryIterItem { - docid, - score, - federation_options: self.federation_options, - hit_maker: Rc::clone(&self.hit_maker), - query_index: self.query_index, - }) - } -} - -struct SearchHitByIndex { - hit: SearchHit, - score: Vec, - federation_options: FederationOptions, - query_index: usize, -} - -struct SearchResultByIndex { - index: String, - hits: Vec, - estimated_total_hits: usize, - degraded: bool, - used_negative_operator: bool, - facets: Option, -} - -#[derive(Debug, Clone, Default, Serialize, ToSchema)] -pub struct FederatedFacets(pub BTreeMap); - -impl FederatedFacets { - pub fn insert(&mut self, index: String, facets: Option) { - if let Some(facets) = facets { - self.0.insert(index, facets); - } - } - - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } - - pub fn merge( - self, - MergeFacets { max_values_per_facet }: MergeFacets, - facet_order: BTreeMap, - ) -> Option { - if self.is_empty() { - return None; - } - - let mut distribution: BTreeMap = Default::default(); - let mut stats: BTreeMap = Default::default(); - - for facets_by_index in self.0.into_values() { - for (facet, index_distribution) in facets_by_index.distribution { - match distribution.entry(facet) { - std::collections::btree_map::Entry::Vacant(entry) => { - entry.insert(index_distribution); - } - std::collections::btree_map::Entry::Occupied(mut entry) => { - let distribution = entry.get_mut(); - - for (value, index_count) in index_distribution { - distribution - .entry(value) - .and_modify(|count| *count += index_count) - .or_insert(index_count); - } - } - } - } - - for (facet, index_stats) in facets_by_index.stats { - match stats.entry(facet) { - std::collections::btree_map::Entry::Vacant(entry) => { - entry.insert(index_stats); - } - std::collections::btree_map::Entry::Occupied(mut entry) => { - let stats = entry.get_mut(); - - stats.min = f64::min(stats.min, index_stats.min); - stats.max = f64::max(stats.max, index_stats.max); - } - } - } - } - - // fixup order - for (facet, values) in &mut distribution { - let order_by = facet_order.get(facet).map(|(_, order)| *order).unwrap_or_default(); - - match order_by { - OrderBy::Lexicographic => { - values.sort_unstable_by(|left, _, right, _| left.cmp(right)) - } - OrderBy::Count => { - values.sort_unstable_by(|_, left, _, right| { - left.cmp(right) - // biggest first - .reverse() - }) - } - } - - if let Some(max_values_per_facet) = max_values_per_facet { - values.truncate(max_values_per_facet) - }; - } - - Some(ComputedFacets { distribution, stats }) - } -} - -pub fn perform_federated_search( - index_scheduler: &IndexScheduler, - queries: Vec, - mut federation: Federation, - features: RoFeatures, -) -> Result { - let before_search = std::time::Instant::now(); - - // this implementation partition the queries by index to guarantee an important property: - // - all the queries to a particular index use the same read transaction. - // This is an important property, otherwise we cannot guarantee the self-consistency of the results. - - // 1. partition queries by index - let mut queries_by_index: BTreeMap> = Default::default(); - for (query_index, federated_query) in queries.into_iter().enumerate() { - if let Some(pagination_field) = federated_query.has_pagination() { - return Err(MeilisearchHttpError::PaginationInFederatedQuery( - query_index, - pagination_field, - ) - .into()); - } - - if let Some(facets) = federated_query.has_facets() { - let facets = facets.to_owned(); - return Err(MeilisearchHttpError::FacetsInFederatedQuery( - query_index, - federated_query.index_uid.into_inner(), - facets, - ) - .into()); - } - - let (index_uid, query, federation_options) = federated_query.into_index_query_federation(); - - queries_by_index.entry(index_uid.into_inner()).or_default().push(QueryByIndex { - query, - federation_options: federation_options.unwrap_or_default(), - query_index, - }) - } - - // 2. perform queries, merge and make hits index by index - let required_hit_count = federation.limit + federation.offset; - - // In step (2), semantic_hit_count will be set to Some(0) if any search kind uses semantic - // Then in step (3), we'll update its value if there is any semantic search - let mut semantic_hit_count = None; - let mut results_by_index = Vec::with_capacity(queries_by_index.len()); - let mut previous_query_data: Option<(RankingRules, usize, String)> = None; - - // remember the order and name of first index for each facet when merging with index settings - // to detect if the order is inconsistent for a facet. - let mut facet_order: Option> = match federation.merge_facets - { - Some(MergeFacets { .. }) => Some(Default::default()), - _ => None, - }; - - for (index_uid, queries) in queries_by_index { - let first_query_index = queries.first().map(|query| query.query_index); - - let index = match index_scheduler.index(&index_uid) { - Ok(index) => index, - Err(err) => { - let mut err = ResponseError::from(err); - // Patch the HTTP status code to 400 as it defaults to 404 for `index_not_found`, but - // here the resource not found is not part of the URL. - err.code = StatusCode::BAD_REQUEST; - if let Some(query_index) = first_query_index { - err.message = format!("Inside `.queries[{}]`: {}", query_index, err.message); - } - return Err(err); - } - }; - - // Important: this is the only transaction we'll use for this index during this federated search - let rtxn = index.read_txn()?; - - let criteria = index.criteria(&rtxn)?; - - let dictionary = index.dictionary(&rtxn)?; - let dictionary: Option> = - dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); - let separators = index.allowed_separators(&rtxn)?; - let separators: Option> = - separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); - - // each query gets its individual cutoff - let cutoff = index.search_cutoff(&rtxn)?; - - let mut degraded = false; - let mut used_negative_operator = false; - let mut candidates = RoaringBitmap::new(); - - let facets_by_index = federation.facets_by_index.remove(&index_uid).flatten(); - - // TODO: recover the max size + facets_by_index as return value of this function so as not to ask it for all queries - if let Err(mut error) = - check_facet_order(&mut facet_order, &index_uid, &facets_by_index, &index, &rtxn) - { - error.message = format!( - "Inside `.federation.facetsByIndex.{index_uid}`: {error}{}", - if let Some(query_index) = first_query_index { - format!("\n - Note: index `{index_uid}` used in `.queries[{query_index}]`") - } else { - Default::default() - } - ); - return Err(error); - } - - // 2.1. Compute all candidates for each query in the index - let mut results_by_query = Vec::with_capacity(queries.len()); - - for QueryByIndex { query, federation_options, query_index } in queries { - // use an immediately invoked lambda to capture the result without returning from the function - - let res: Result<(), ResponseError> = (|| { - let search_kind = - search_kind(&query, index_scheduler, index_uid.to_string(), &index)?; - - let canonicalization_kind = match (&search_kind, &query.q) { - (SearchKind::SemanticOnly { .. }, _) => { - ranking_rules::CanonicalizationKind::Vector - } - (_, Some(q)) if !q.is_empty() => ranking_rules::CanonicalizationKind::Keyword, - _ => ranking_rules::CanonicalizationKind::Placeholder, - }; - - let sort = if let Some(sort) = &query.sort { - let sorts: Vec<_> = - match sort.iter().map(|s| milli::AscDesc::from_str(s)).collect() { - Ok(sorts) => sorts, - Err(asc_desc_error) => { - return Err(milli::Error::from(milli::SortError::from( - asc_desc_error, - )) - .into()) - } - }; - Some(sorts) - } else { - None - }; - - let ranking_rules = ranking_rules::RankingRules::new( - criteria.clone(), - sort, - query.matching_strategy.into(), - canonicalization_kind, - ); - - if let Some((previous_ranking_rules, previous_query_index, previous_index_uid)) = - previous_query_data.take() - { - if let Err(error) = ranking_rules.is_compatible_with(&previous_ranking_rules) { - return Err(error.to_response_error( - &ranking_rules, - &previous_ranking_rules, - query_index, - previous_query_index, - &index_uid, - &previous_index_uid, - )); - } - previous_query_data = if previous_ranking_rules.constraint_count() - > ranking_rules.constraint_count() - { - Some((previous_ranking_rules, previous_query_index, previous_index_uid)) - } else { - Some((ranking_rules, query_index, index_uid.clone())) - }; - } else { - previous_query_data = Some((ranking_rules, query_index, index_uid.clone())); - } - - match search_kind { - SearchKind::KeywordOnly => {} - _ => semantic_hit_count = Some(0), - } - - let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors); - - let time_budget = match cutoff { - Some(cutoff) => TimeBudget::new(Duration::from_millis(cutoff)), - None => TimeBudget::default(), - }; - - let (mut search, _is_finite_pagination, _max_total_hits, _offset) = - prepare_search(&index, &rtxn, &query, &search_kind, time_budget, features)?; - - search.scoring_strategy(milli::score_details::ScoringStrategy::Detailed); - search.offset(0); - search.limit(required_hit_count); - - let (result, _semantic_hit_count) = - super::search_from_kind(index_uid.to_string(), search_kind, search)?; - let format = AttributesFormat { - attributes_to_retrieve: query.attributes_to_retrieve, - retrieve_vectors, - attributes_to_highlight: query.attributes_to_highlight, - attributes_to_crop: query.attributes_to_crop, - crop_length: query.crop_length, - crop_marker: query.crop_marker, - highlight_pre_tag: query.highlight_pre_tag, - highlight_post_tag: query.highlight_post_tag, - show_matches_position: query.show_matches_position, - sort: query.sort, - show_ranking_score: query.show_ranking_score, - show_ranking_score_details: query.show_ranking_score_details, - locales: query.locales.map(|l| l.iter().copied().map(Into::into).collect()), - }; - - let milli::SearchResult { - matching_words, - candidates: query_candidates, - documents_ids, - document_scores, - degraded: query_degraded, - used_negative_operator: query_used_negative_operator, - } = result; - - candidates |= query_candidates; - degraded |= query_degraded; - used_negative_operator |= query_used_negative_operator; - - let tokenizer = HitMaker::tokenizer(dictionary.as_deref(), separators.as_deref()); - - let formatter_builder = HitMaker::formatter_builder(matching_words, tokenizer); - - let hit_maker = - HitMaker::new(&index, &rtxn, format, formatter_builder).map_err(|e| { - MeilisearchHttpError::from_milli(e, Some(index_uid.to_string())) - })?; - - results_by_query.push(SearchResultByQuery { - federation_options, - hit_maker, - query_index, - documents_ids, - document_scores, - }); - Ok(()) - })(); - - if let Err(mut error) = res { - error.message = format!("Inside `.queries[{query_index}]`: {}", error.message); - return Err(error); - } - } - // 2.2. merge inside index - let mut documents_seen = RoaringBitmap::new(); - let merged_result: Result, ResponseError> = - merge_index_local_results(results_by_query) - // skip documents we've already seen & mark that we saw the current document - .filter(|SearchResultByQueryIterItem { docid, .. }| documents_seen.insert(*docid)) - .take(required_hit_count) - // 2.3 make hits - .map( - |SearchResultByQueryIterItem { - docid, - score, - federation_options, - hit_maker, - query_index, - }| { - let mut hit = hit_maker.make_hit(docid, &score)?; - let weighted_score = - ScoreDetails::global_score(score.iter()) * (*federation_options.weight); - - let _federation = serde_json::json!( - { - "indexUid": index_uid, - "queriesPosition": query_index, - "weightedRankingScore": weighted_score, - } - ); - hit.document.insert("_federation".to_string(), _federation); - Ok(SearchHitByIndex { hit, score, federation_options, query_index }) - }, - ) - .collect(); - - let merged_result = merged_result?; - - let estimated_total_hits = candidates.len() as usize; - - let facets = facets_by_index - .map(|facets_by_index| { - compute_facet_distribution_stats( - &facets_by_index, - &index, - &rtxn, - candidates, - super::Route::MultiSearch, - ) - }) - .transpose() - .map_err(|mut error| { - error.message = format!( - "Inside `.federation.facetsByIndex.{index_uid}`: {}{}", - error.message, - if let Some(query_index) = first_query_index { - format!("\n - Note: index `{index_uid}` used in `.queries[{query_index}]`") - } else { - Default::default() - } - ); - error - })?; - - results_by_index.push(SearchResultByIndex { - index: index_uid, - hits: merged_result, - estimated_total_hits, - degraded, - used_negative_operator, - facets, - }); - } - - // bonus step, make sure to return an error if an index wants a non-faceted field, even if no query actually uses that index. - for (index_uid, facets) in federation.facets_by_index { - let index = match index_scheduler.index(&index_uid) { - Ok(index) => index, - Err(err) => { - let mut err = ResponseError::from(err); - // Patch the HTTP status code to 400 as it defaults to 404 for `index_not_found`, but - // here the resource not found is not part of the URL. - err.code = StatusCode::BAD_REQUEST; - err.message = format!( - "Inside `.federation.facetsByIndex.{index_uid}`: {}\n - Note: index `{index_uid}` is not used in queries", - err.message - ); - return Err(err); - } - }; - - // Important: this is the only transaction we'll use for this index during this federated search - let rtxn = index.read_txn()?; - - if let Err(mut error) = - check_facet_order(&mut facet_order, &index_uid, &facets, &index, &rtxn) - { - error.message = format!( - "Inside `.federation.facetsByIndex.{index_uid}`: {error}\n - Note: index `{index_uid}` is not used in queries", - ); - return Err(error); - } - - if let Some(facets) = facets { - if let Err(mut error) = compute_facet_distribution_stats( - &facets, - &index, - &rtxn, - Default::default(), - super::Route::MultiSearch, - ) { - error.message = - format!("Inside `.federation.facetsByIndex.{index_uid}`: {}\n - Note: index `{index_uid}` is not used in queries", error.message); - return Err(error); - } - } - } - - // 3. merge hits and metadata across indexes - // 3.1 merge metadata - let (estimated_total_hits, degraded, used_negative_operator, facets) = { - let mut estimated_total_hits = 0; - let mut degraded = false; - let mut used_negative_operator = false; - - let mut facets: FederatedFacets = FederatedFacets::default(); - - for SearchResultByIndex { - index, - hits: _, - estimated_total_hits: estimated_total_hits_by_index, - facets: facets_by_index, - degraded: degraded_by_index, - used_negative_operator: used_negative_operator_by_index, - } in &mut results_by_index - { - estimated_total_hits += *estimated_total_hits_by_index; - degraded |= *degraded_by_index; - used_negative_operator |= *used_negative_operator_by_index; - - let facets_by_index = std::mem::take(facets_by_index); - let index = std::mem::take(index); - - facets.insert(index, facets_by_index); - } - - (estimated_total_hits, degraded, used_negative_operator, facets) - }; - - // 3.2 merge hits - let merged_hits: Vec<_> = merge_index_global_results(results_by_index) - .skip(federation.offset) - .take(federation.limit) - .inspect(|hit| { - if let Some(semantic_hit_count) = &mut semantic_hit_count { - if hit.score.iter().any(|score| matches!(&score, ScoreDetails::Vector(_))) { - *semantic_hit_count += 1; - } - } - }) - .map(|hit| hit.hit) - .collect(); - - let (facet_distribution, facet_stats, facets_by_index) = - match federation.merge_facets.zip(facet_order) { - Some((merge_facets, facet_order)) => { - let facets = facets.merge(merge_facets, facet_order); - - let (facet_distribution, facet_stats) = facets - .map(|ComputedFacets { distribution, stats }| (distribution, stats)) - .unzip(); - - (facet_distribution, facet_stats, FederatedFacets::default()) - } - None => (None, None, facets), - }; - - let search_result = FederatedSearchResult { - hits: merged_hits, - processing_time_ms: before_search.elapsed().as_millis(), - hits_info: HitsInfo::OffsetLimit { - limit: federation.limit, - offset: federation.offset, - estimated_total_hits, - }, - semantic_hit_count, - degraded, - used_negative_operator, - facet_distribution, - facet_stats, - facets_by_index, - }; - - Ok(search_result) -} - -fn check_facet_order( - facet_order: &mut Option>, - current_index: &str, - facets_by_index: &Option>, - index: &milli::Index, - rtxn: &milli::heed::RoTxn<'_>, -) -> Result<(), ResponseError> { - if let (Some(facet_order), Some(facets_by_index)) = (facet_order, facets_by_index) { - let index_facet_order = index.sort_facet_values_by(rtxn)?; - for facet in facets_by_index { - let index_facet_order = index_facet_order.get(facet); - let (previous_index, previous_facet_order) = facet_order - .entry(facet.to_owned()) - .or_insert_with(|| (current_index.to_owned(), index_facet_order)); - if previous_facet_order != &index_facet_order { - return Err(MeilisearchHttpError::InconsistentFacetOrder { - facet: facet.clone(), - previous_facet_order: *previous_facet_order, - previous_uid: previous_index.clone(), - current_uid: current_index.to_owned(), - index_facet_order, - } - .into()); - } - } - }; - Ok(()) -} diff --git a/crates/meilisearch/src/search/federated/mod.rs b/crates/meilisearch/src/search/federated/mod.rs new file mode 100644 index 000000000..40204c591 --- /dev/null +++ b/crates/meilisearch/src/search/federated/mod.rs @@ -0,0 +1,10 @@ +mod perform; +mod proxy; +mod types; +mod weighted_scores; + +pub use perform::perform_federated_search; +pub use proxy::{PROXY_SEARCH_HEADER, PROXY_SEARCH_HEADER_VALUE}; +pub use types::{ + FederatedSearch, FederatedSearchResult, Federation, FederationOptions, MergeFacets, +}; diff --git a/crates/meilisearch/src/search/federated/perform.rs b/crates/meilisearch/src/search/federated/perform.rs new file mode 100644 index 000000000..9092b3dbf --- /dev/null +++ b/crates/meilisearch/src/search/federated/perform.rs @@ -0,0 +1,1068 @@ +use std::cmp::Ordering; +use std::collections::BTreeMap; +use std::iter::Zip; +use std::rc::Rc; +use std::str::FromStr as _; +use std::time::{Duration, Instant}; +use std::vec::{IntoIter, Vec}; + +use actix_http::StatusCode; +use index_scheduler::{IndexScheduler, RoFeatures}; +use itertools::Itertools; +use meilisearch_types::error::ResponseError; +use meilisearch_types::features::{Network, Remote}; +use meilisearch_types::milli::order_by_map::OrderByMap; +use meilisearch_types::milli::score_details::{ScoreDetails, WeightedScoreValue}; +use meilisearch_types::milli::{self, DocumentId, OrderBy, TimeBudget, DEFAULT_VALUES_PER_FACET}; +use roaring::RoaringBitmap; +use tokio::task::JoinHandle; + +use super::super::ranking_rules::{self, RankingRules}; +use super::super::{ + compute_facet_distribution_stats, prepare_search, AttributesFormat, ComputedFacets, HitMaker, + HitsInfo, RetrieveVectors, SearchHit, SearchKind, SearchQuery, SearchQueryWithIndex, +}; +use super::proxy::{proxy_search, ProxySearchError, ProxySearchParams}; +use super::types::{ + FederatedFacets, FederatedSearchResult, Federation, FederationOptions, MergeFacets, Weight, + FEDERATION_HIT, FEDERATION_REMOTE, WEIGHTED_SCORE_VALUES, +}; +use super::weighted_scores; +use crate::error::MeilisearchHttpError; +use crate::routes::indexes::search::search_kind; +use crate::search::federated::types::{INDEX_UID, QUERIES_POSITION, WEIGHTED_RANKING_SCORE}; + +pub async fn perform_federated_search( + index_scheduler: &IndexScheduler, + queries: Vec, + federation: Federation, + features: RoFeatures, + is_proxy: bool, +) -> Result { + if is_proxy { + features.check_proxy_search("Performing a proxy search")?; + } + let before_search = std::time::Instant::now(); + let deadline = before_search + std::time::Duration::from_secs(9); + + let required_hit_count = federation.limit + federation.offset; + + let network = index_scheduler.network(); + + // this implementation partition the queries by index to guarantee an important property: + // - all the queries to a particular index use the same read transaction. + // This is an important property, otherwise we cannot guarantee the self-consistency of the results. + + // 1. partition queries by host and index + let mut partitioned_queries = PartitionedQueries::new(); + for (query_index, federated_query) in queries.into_iter().enumerate() { + partitioned_queries.partition(federated_query, query_index, &network)? + } + partitioned_queries.check_features(features)?; + + // 2. perform queries, merge and make hits index by index + // 2.1. start remote queries + let remote_search = + RemoteSearch::start(partitioned_queries.remote_queries_by_host, &federation, deadline); + + // 2.2. concurrently execute local queries + let params = SearchByIndexParams { + index_scheduler, + features, + is_proxy, + network: &network, + has_remote: partitioned_queries.has_remote, + required_hit_count, + }; + let mut search_by_index = SearchByIndex::new( + federation, + partitioned_queries.local_queries_by_index.len(), + params.has_remote, + ); + + for (index_uid, queries) in partitioned_queries.local_queries_by_index { + // note: this is the only place we open `index_uid` + search_by_index.execute(index_uid, queries, ¶ms)?; + } + + // bonus step, make sure to return an error if an index wants a non-faceted field, even if no query actually uses that index. + search_by_index.check_unused_facets(index_scheduler)?; + + let SearchByIndex { + federation, + mut semantic_hit_count, + mut results_by_index, + previous_query_data: _, + facet_order, + } = search_by_index; + + // 2.3. Wait for proxy search requests to complete + let (mut remote_results, remote_errors) = remote_search.finish().await; + + // 3. merge hits and metadata across indexes and hosts + // 3.1. merge metadata + let (estimated_total_hits, degraded, used_negative_operator, facets) = + merge_metadata(&mut results_by_index, &remote_results); + + // 3.2. merge hits + let merged_hits: Vec<_> = merge_index_global_results(results_by_index, &mut remote_results) + .skip(federation.offset) + .take(federation.limit) + .inspect(|hit| { + if let Some(semantic_hit_count) = &mut semantic_hit_count { + if hit.to_score().0.any(|score| matches!(&score, WeightedScoreValue::VectorSort(_))) + { + *semantic_hit_count += 1; + } + } + }) + .map(|hit| hit.hit()) + .collect(); + + // 3.3. merge facets + let (facet_distribution, facet_stats, facets_by_index) = + facet_order.merge(federation.merge_facets, remote_results, facets); + + Ok(FederatedSearchResult { + hits: merged_hits, + processing_time_ms: before_search.elapsed().as_millis(), + hits_info: HitsInfo::OffsetLimit { + limit: federation.limit, + offset: federation.offset, + estimated_total_hits, + }, + semantic_hit_count, + degraded, + used_negative_operator, + facet_distribution, + facet_stats, + facets_by_index, + remote_errors: partitioned_queries.has_remote.then_some(remote_errors), + }) +} + +struct QueryByIndex { + query: SearchQuery, + weight: Weight, + query_index: usize, +} + +struct SearchResultByQuery<'a> { + documents_ids: Vec, + document_scores: Vec>, + weight: Weight, + hit_maker: HitMaker<'a>, + query_index: usize, +} + +struct SearchResultByQueryIter<'a> { + it: Zip, IntoIter>>, + weight: Weight, + hit_maker: Rc>, + query_index: usize, +} + +impl<'a> SearchResultByQueryIter<'a> { + fn new( + SearchResultByQuery { + documents_ids, + document_scores, + weight, + hit_maker, + query_index, + }: SearchResultByQuery<'a>, + ) -> Self { + let it = documents_ids.into_iter().zip(document_scores); + Self { it, weight, hit_maker: Rc::new(hit_maker), query_index } + } +} + +struct SearchResultByQueryIterItem<'a> { + docid: DocumentId, + score: Vec, + weight: Weight, + hit_maker: Rc>, + query_index: usize, +} + +fn merge_index_local_results( + results_by_query: Vec>, +) -> impl Iterator + '_ { + itertools::kmerge_by( + results_by_query.into_iter().map(SearchResultByQueryIter::new), + |left: &SearchResultByQueryIterItem, right: &SearchResultByQueryIterItem| { + match weighted_scores::compare( + ScoreDetails::weighted_score_values(left.score.iter(), *left.weight), + ScoreDetails::global_score(left.score.iter()) * *left.weight, + ScoreDetails::weighted_score_values(right.score.iter(), *right.weight), + ScoreDetails::global_score(right.score.iter()) * *right.weight, + ) { + // the biggest score goes first + Ordering::Greater => true, + // break ties using query index + Ordering::Equal => left.query_index < right.query_index, + Ordering::Less => false, + } + }, + ) +} + +fn merge_index_global_results( + results_by_index: Vec, + remote_results: &mut [FederatedSearchResult], +) -> impl Iterator + '_ { + itertools::kmerge_by( + // local results + results_by_index + .into_iter() + .map(|result_by_index| { + either::Either::Left(result_by_index.hits.into_iter().map(MergedSearchHit::Local)) + }) + // remote results + .chain(remote_results.iter_mut().map(|x| either::Either::Right(iter_remote_hits(x)))), + |left: &MergedSearchHit, right: &MergedSearchHit| { + let (left_it, left_weighted_global_score, left_query_index) = left.to_score(); + let (right_it, right_weighted_global_score, right_query_index) = right.to_score(); + + match weighted_scores::compare( + left_it, + left_weighted_global_score, + right_it, + right_weighted_global_score, + ) { + // the biggest score goes first + Ordering::Greater => true, + // break ties using query index + Ordering::Equal => left_query_index < right_query_index, + Ordering::Less => false, + } + }, + ) +} + +enum MergedSearchHit { + Local(SearchHitByIndex), + Remote { + hit: SearchHit, + score: Vec, + global_weighted_score: f64, + query_index: usize, + }, +} + +impl MergedSearchHit { + fn remote(mut hit: SearchHit) -> Result { + let federation = hit + .document + .get_mut(FEDERATION_HIT) + .ok_or(ProxySearchError::MissingPathInResponse("._federation"))?; + let federation = match federation.as_object_mut() { + Some(federation) => federation, + None => { + return Err(ProxySearchError::UnexpectedValueInPath { + path: "._federation", + expected_type: "map", + received_value: federation.to_string(), + }); + } + }; + + let global_weighted_score = federation + .get(WEIGHTED_RANKING_SCORE) + .ok_or(ProxySearchError::MissingPathInResponse("._federation.weightedRankingScore"))?; + let global_weighted_score = global_weighted_score.as_f64().ok_or_else(|| { + ProxySearchError::UnexpectedValueInPath { + path: "._federation.weightedRankingScore", + expected_type: "number", + received_value: global_weighted_score.to_string(), + } + })?; + + let score: Vec = + serde_json::from_value(federation.remove(WEIGHTED_SCORE_VALUES).ok_or( + ProxySearchError::MissingPathInResponse("._federation.weightedScoreValues"), + )?) + .map_err(ProxySearchError::CouldNotParseWeightedScoreValues)?; + + let query_index = federation + .get(QUERIES_POSITION) + .ok_or(ProxySearchError::MissingPathInResponse("._federation.queriesPosition"))?; + let query_index = + query_index.as_u64().ok_or_else(|| ProxySearchError::UnexpectedValueInPath { + path: "._federation.queriesPosition", + expected_type: "integer", + received_value: query_index.to_string(), + })? as usize; + + Ok(Self::Remote { hit, score, global_weighted_score, query_index }) + } + + fn hit(self) -> SearchHit { + match self { + MergedSearchHit::Local(search_hit_by_index) => search_hit_by_index.hit, + MergedSearchHit::Remote { hit, .. } => hit, + } + } + + fn to_score(&self) -> (impl Iterator + '_, f64, usize) { + match self { + MergedSearchHit::Local(search_hit_by_index) => ( + either::Left(ScoreDetails::weighted_score_values( + search_hit_by_index.score.iter(), + *search_hit_by_index.weight, + )), + ScoreDetails::global_score(search_hit_by_index.score.iter()) + * *search_hit_by_index.weight, + search_hit_by_index.query_index, + ), + MergedSearchHit::Remote { hit: _, score, global_weighted_score, query_index } => { + let global_weighted_score = *global_weighted_score; + let query_index = *query_index; + (either::Right(score.iter().cloned()), global_weighted_score, query_index) + } + } + } +} + +fn iter_remote_hits( + results_by_host: &mut FederatedSearchResult, +) -> impl Iterator + '_ { + // have a per node registry of failed hits + results_by_host.hits.drain(..).filter_map(|hit| match MergedSearchHit::remote(hit) { + Ok(hit) => Some(hit), + Err(err) => { + tracing::warn!("skipping remote hit due to error: {err}"); + None + } + }) +} + +impl<'a> Iterator for SearchResultByQueryIter<'a> { + type Item = SearchResultByQueryIterItem<'a>; + + fn next(&mut self) -> Option { + let (docid, score) = self.it.next()?; + Some(SearchResultByQueryIterItem { + docid, + score, + weight: self.weight, + hit_maker: Rc::clone(&self.hit_maker), + query_index: self.query_index, + }) + } +} + +struct SearchHitByIndex { + hit: SearchHit, + score: Vec, + weight: Weight, + query_index: usize, +} + +struct SearchResultByIndex { + index: String, + hits: Vec, + estimated_total_hits: usize, + degraded: bool, + used_negative_operator: bool, + facets: Option, +} + +fn merge_metadata( + results_by_index: &mut Vec, + remote_results: &Vec, +) -> (usize, bool, bool, FederatedFacets) { + let mut estimated_total_hits = 0; + let mut degraded = false; + let mut used_negative_operator = false; + let mut facets: FederatedFacets = FederatedFacets::default(); + for SearchResultByIndex { + index, + hits: _, + estimated_total_hits: estimated_total_hits_by_index, + facets: facets_by_index, + degraded: degraded_by_index, + used_negative_operator: used_negative_operator_by_index, + } in results_by_index + { + estimated_total_hits += *estimated_total_hits_by_index; + degraded |= *degraded_by_index; + used_negative_operator |= *used_negative_operator_by_index; + + let facets_by_index = std::mem::take(facets_by_index); + let index = std::mem::take(index); + + facets.insert(index, facets_by_index); + } + for FederatedSearchResult { + hits: _, + processing_time_ms: _, + hits_info, + semantic_hit_count: _, + facet_distribution: _, + facet_stats: _, + facets_by_index: _, + degraded: degraded_for_host, + used_negative_operator: host_used_negative_operator, + remote_errors: _, + } in remote_results + { + estimated_total_hits += match hits_info { + HitsInfo::Pagination { total_hits: estimated_total_hits, .. } + | HitsInfo::OffsetLimit { estimated_total_hits, .. } => estimated_total_hits, + }; + degraded |= degraded_for_host; + used_negative_operator |= host_used_negative_operator; + } + (estimated_total_hits, degraded, used_negative_operator, facets) +} + +type LocalQueriesByIndex = BTreeMap>; +type RemoteQueriesByHost = BTreeMap)>; + +struct PartitionedQueries { + local_queries_by_index: LocalQueriesByIndex, + remote_queries_by_host: RemoteQueriesByHost, + has_remote: bool, + has_query_position: bool, +} + +impl PartitionedQueries { + fn new() -> PartitionedQueries { + PartitionedQueries { + local_queries_by_index: Default::default(), + remote_queries_by_host: Default::default(), + has_remote: false, + has_query_position: false, + } + } + + fn partition( + &mut self, + federated_query: SearchQueryWithIndex, + query_index: usize, + network: &Network, + ) -> Result<(), ResponseError> { + if let Some(pagination_field) = federated_query.has_pagination() { + return Err(MeilisearchHttpError::PaginationInFederatedQuery( + query_index, + pagination_field, + ) + .into()); + } + + if let Some(facets) = federated_query.has_facets() { + let facets = facets.to_owned(); + return Err(MeilisearchHttpError::FacetsInFederatedQuery( + query_index, + federated_query.index_uid.into_inner(), + facets, + ) + .into()); + } + + let (index_uid, query, federation_options) = federated_query.into_index_query_federation(); + + let federation_options = federation_options.unwrap_or_default(); + + // local or remote node? + 'local_query: { + let queries_by_index = match federation_options.remote { + None => self.local_queries_by_index.entry(index_uid.into_inner()).or_default(), + Some(remote_name) => { + self.has_remote = true; + match &network.local { + Some(local) if local == &remote_name => { + self.local_queries_by_index.entry(index_uid.into_inner()).or_default() + } + _ => { + // node from the network + let Some(remote) = network.remotes.get(&remote_name) else { + return Err(ResponseError::from_msg(format!("Invalid `queries[{query_index}].federation_options.remote`: remote `{remote_name}` is not registered"), + meilisearch_types::error::Code::InvalidMultiSearchRemote)); + }; + let query = SearchQueryWithIndex::from_index_query_federation( + index_uid, + query, + Some(FederationOptions { + weight: federation_options.weight, + // do not pass the `remote` to not require the remote instance to have itself has a local node + remote: None, + // pass an explicit query index + query_position: Some(query_index), + }), + ); + + self.remote_queries_by_host + .entry(remote_name) + .or_insert_with(|| (remote.clone(), Default::default())) + .1 + .push(query); + break 'local_query; + } + } + } + }; + + queries_by_index.push(QueryByIndex { + query, + weight: federation_options.weight, + // override query index here with the one in federation. + // this will fix-up error messages to refer to the global query index of the original request. + query_index: if let Some(query_index) = federation_options.query_position { + self.has_query_position = true; + query_index + } else { + query_index + }, + }) + } + Ok(()) + } + + fn check_features(&self, features: RoFeatures) -> Result<(), ResponseError> { + if self.has_remote { + features.check_proxy_search("Performing a proxy search")?; + } + + if self.has_query_position { + features.check_proxy_search("Using `federationOptions.queryPosition`")?; + } + Ok(()) + } +} + +struct RemoteSearch { + in_flight_remote_queries: + BTreeMap>>, +} + +impl RemoteSearch { + fn start(queries: RemoteQueriesByHost, federation: &Federation, deadline: Instant) -> Self { + let mut in_flight_remote_queries = BTreeMap::new(); + let client = reqwest::ClientBuilder::new() + .connect_timeout(std::time::Duration::from_millis(200)) + .build() + .unwrap(); + let params = + ProxySearchParams { deadline: Some(deadline), try_count: 3, client: client.clone() }; + for (node_name, (node, queries)) in queries { + // spawn one task per host + in_flight_remote_queries.insert( + node_name, + tokio::spawn({ + let mut federation = federation.clone(); + // never merge distant facets + federation.merge_facets = None; + let params = params.clone(); + async move { proxy_search(&node, queries, federation, ¶ms).await } + }), + ); + } + Self { in_flight_remote_queries } + } + + async fn finish(self) -> (Vec, BTreeMap) { + let mut remote_results = Vec::with_capacity(self.in_flight_remote_queries.len()); + let mut remote_errors: BTreeMap = BTreeMap::new(); + 'remote_queries: for (node_name, handle) in self.in_flight_remote_queries { + match handle.await { + Ok(Ok(mut res)) => { + for hit in &mut res.hits { + let Some(federation) = hit.document.get_mut(FEDERATION_HIT) else { + let error = ProxySearchError::MissingPathInResponse("._federation"); + remote_errors.insert(node_name, error.as_response_error()); + continue 'remote_queries; + }; + let Some(federation) = federation.as_object_mut() else { + let error = ProxySearchError::UnexpectedValueInPath { + path: "._federation", + expected_type: "map", + received_value: federation.to_string(), + }; + remote_errors.insert(node_name, error.as_response_error()); + continue 'remote_queries; + }; + federation.insert( + FEDERATION_REMOTE.to_string(), + serde_json::Value::String(node_name.clone()), + ); + } + + remote_results.push(res); + } + Ok(Err(error)) => { + remote_errors.insert(node_name, error.as_response_error()); + } + Err(panic) => match panic.try_into_panic() { + Ok(panic) => { + let msg = match panic.downcast_ref::<&'static str>() { + Some(s) => *s, + None => match panic.downcast_ref::() { + Some(s) => &s[..], + None => "Box", + }, + }; + remote_errors.insert( + node_name, + ResponseError::from_msg( + msg.to_string(), + meilisearch_types::error::Code::Internal, + ), + ); + } + Err(_) => tracing::error!("proxy search task was unexpectedly cancelled"), + }, + } + } + (remote_results, remote_errors) + } +} + +struct SearchByIndexParams<'a> { + index_scheduler: &'a IndexScheduler, + required_hit_count: usize, + features: RoFeatures, + is_proxy: bool, + has_remote: bool, + network: &'a Network, +} + +struct SearchByIndex { + federation: Federation, + // During search by index, semantic_hit_count will be set to Some(0) if any search kind uses semantic + // Then when merging, we'll update its value if there is any semantic hit + semantic_hit_count: Option, + results_by_index: Vec, + previous_query_data: Option<(RankingRules, usize, String)>, + // remember the order and name of first index for each facet when merging with index settings + // to detect if the order is inconsistent for a facet. + facet_order: FacetOrder, +} + +impl SearchByIndex { + fn new(federation: Federation, index_count: usize, has_remote: bool) -> Self { + SearchByIndex { + facet_order: match (federation.merge_facets, has_remote) { + (None, true) => FacetOrder::ByIndex(Default::default()), + (None, false) => FacetOrder::None, + (Some(_), _) => FacetOrder::ByFacet(Default::default()), + }, + federation, + semantic_hit_count: None, + results_by_index: Vec::with_capacity(index_count), + previous_query_data: None, + } + } + + fn execute( + &mut self, + index_uid: String, + queries: Vec, + params: &SearchByIndexParams<'_>, + ) -> Result<(), ResponseError> { + let first_query_index = queries.first().map(|query| query.query_index); + let index = match params.index_scheduler.index(&index_uid) { + Ok(index) => index, + Err(err) => { + let mut err = ResponseError::from(err); + // Patch the HTTP status code to 400 as it defaults to 404 for `index_not_found`, but + // here the resource not found is not part of the URL. + err.code = StatusCode::BAD_REQUEST; + if let Some(query_index) = first_query_index { + err.message = format!("Inside `.queries[{}]`: {}", query_index, err.message); + } + return Err(err); + } + }; + let rtxn = index.read_txn()?; + let criteria = index.criteria(&rtxn)?; + let dictionary = index.dictionary(&rtxn)?; + let dictionary: Option> = + dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); + let separators = index.allowed_separators(&rtxn)?; + let separators: Option> = + separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); + let cutoff = index.search_cutoff(&rtxn)?; + let mut degraded = false; + let mut used_negative_operator = false; + let mut candidates = RoaringBitmap::new(); + let facets_by_index = self.federation.facets_by_index.remove(&index_uid).flatten(); + if let Err(mut error) = + self.facet_order.check_facet_order(&index_uid, &facets_by_index, &index, &rtxn) + { + error.message = format!( + "Inside `.federation.facetsByIndex.{index_uid}`: {error}{}", + if let Some(query_index) = first_query_index { + format!("\n - Note: index `{index_uid}` used in `.queries[{query_index}]`") + } else { + Default::default() + } + ); + return Err(error); + } + let mut results_by_query = Vec::with_capacity(queries.len()); + for QueryByIndex { query, weight, query_index } in queries { + // use an immediately invoked lambda to capture the result without returning from the function + + let res: Result<(), ResponseError> = (|| { + let search_kind = + search_kind(&query, params.index_scheduler, index_uid.to_string(), &index)?; + + let canonicalization_kind = match (&search_kind, &query.q) { + (SearchKind::SemanticOnly { .. }, _) => { + ranking_rules::CanonicalizationKind::Vector + } + (_, Some(q)) if !q.is_empty() => ranking_rules::CanonicalizationKind::Keyword, + _ => ranking_rules::CanonicalizationKind::Placeholder, + }; + + let sort = if let Some(sort) = &query.sort { + let sorts: Vec<_> = + match sort.iter().map(|s| milli::AscDesc::from_str(s)).collect() { + Ok(sorts) => sorts, + Err(asc_desc_error) => { + return Err(milli::Error::from(milli::SortError::from( + asc_desc_error, + )) + .into()) + } + }; + Some(sorts) + } else { + None + }; + + let ranking_rules = ranking_rules::RankingRules::new( + criteria.clone(), + sort, + query.matching_strategy.into(), + canonicalization_kind, + ); + + if let Some((previous_ranking_rules, previous_query_index, previous_index_uid)) = + self.previous_query_data.take() + { + if let Err(error) = ranking_rules.is_compatible_with(&previous_ranking_rules) { + return Err(error.to_response_error( + &ranking_rules, + &previous_ranking_rules, + query_index, + previous_query_index, + &index_uid, + &previous_index_uid, + )); + } + self.previous_query_data = if previous_ranking_rules.constraint_count() + > ranking_rules.constraint_count() + { + Some((previous_ranking_rules, previous_query_index, previous_index_uid)) + } else { + Some((ranking_rules, query_index, index_uid.clone())) + }; + } else { + self.previous_query_data = + Some((ranking_rules, query_index, index_uid.clone())); + } + + match search_kind { + SearchKind::KeywordOnly => {} + _ => self.semantic_hit_count = Some(0), + } + + let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors); + + let time_budget = match cutoff { + Some(cutoff) => TimeBudget::new(Duration::from_millis(cutoff)), + None => TimeBudget::default(), + }; + + let (mut search, _is_finite_pagination, _max_total_hits, _offset) = prepare_search( + &index, + &rtxn, + &query, + &search_kind, + time_budget, + params.features, + )?; + + search.scoring_strategy(milli::score_details::ScoringStrategy::Detailed); + search.offset(0); + search.limit(params.required_hit_count); + + let (result, _semantic_hit_count) = + super::super::search_from_kind(index_uid.to_string(), search_kind, search)?; + let format = AttributesFormat { + attributes_to_retrieve: query.attributes_to_retrieve, + retrieve_vectors, + attributes_to_highlight: query.attributes_to_highlight, + attributes_to_crop: query.attributes_to_crop, + crop_length: query.crop_length, + crop_marker: query.crop_marker, + highlight_pre_tag: query.highlight_pre_tag, + highlight_post_tag: query.highlight_post_tag, + show_matches_position: query.show_matches_position, + sort: query.sort, + show_ranking_score: query.show_ranking_score, + show_ranking_score_details: query.show_ranking_score_details, + locales: query.locales.map(|l| l.iter().copied().map(Into::into).collect()), + }; + + let milli::SearchResult { + matching_words, + candidates: query_candidates, + documents_ids, + document_scores, + degraded: query_degraded, + used_negative_operator: query_used_negative_operator, + } = result; + + candidates |= query_candidates; + degraded |= query_degraded; + used_negative_operator |= query_used_negative_operator; + + let tokenizer = HitMaker::tokenizer(dictionary.as_deref(), separators.as_deref()); + + let formatter_builder = HitMaker::formatter_builder(matching_words, tokenizer); + + let hit_maker = + HitMaker::new(&index, &rtxn, format, formatter_builder).map_err(|e| { + MeilisearchHttpError::from_milli(e, Some(index_uid.to_string())) + })?; + + results_by_query.push(SearchResultByQuery { + weight, + hit_maker, + query_index, + documents_ids, + document_scores, + }); + Ok(()) + })(); + + if let Err(mut error) = res { + error.message = format!("Inside `.queries[{query_index}]`: {}", error.message); + return Err(error); + } + } + let mut documents_seen = RoaringBitmap::new(); + let merged_result: Result, ResponseError> = + merge_index_local_results(results_by_query) + // skip documents we've already seen & mark that we saw the current document + .filter(|SearchResultByQueryIterItem { docid, .. }| documents_seen.insert(*docid)) + .take(params.required_hit_count) + // 2.3 make hits + .map( + |SearchResultByQueryIterItem { + docid, + score, + weight, + hit_maker, + query_index, + }| { + let mut hit = hit_maker.make_hit(docid, &score)?; + let weighted_score = ScoreDetails::global_score(score.iter()) * (*weight); + + let mut _federation = serde_json::json!( + { + INDEX_UID: index_uid, + QUERIES_POSITION: query_index, + WEIGHTED_RANKING_SCORE: weighted_score, + } + ); + if params.has_remote && !params.is_proxy { + _federation.as_object_mut().unwrap().insert( + FEDERATION_REMOTE.to_string(), + params.network.local.clone().into(), + ); + } + if params.is_proxy { + _federation.as_object_mut().unwrap().insert( + WEIGHTED_SCORE_VALUES.to_string(), + serde_json::json!(ScoreDetails::weighted_score_values( + score.iter(), + *weight + ) + .collect_vec()), + ); + } + hit.document.insert(FEDERATION_HIT.to_string(), _federation); + Ok(SearchHitByIndex { hit, score, weight, query_index }) + }, + ) + .collect(); + let merged_result = merged_result?; + let estimated_total_hits = candidates.len() as usize; + let facets = facets_by_index + .map(|facets_by_index| { + compute_facet_distribution_stats( + &facets_by_index, + &index, + &rtxn, + candidates, + super::super::Route::MultiSearch, + ) + }) + .transpose() + .map_err(|mut error| { + error.message = format!( + "Inside `.federation.facetsByIndex.{index_uid}`: {}{}", + error.message, + if let Some(query_index) = first_query_index { + format!("\n - Note: index `{index_uid}` used in `.queries[{query_index}]`") + } else { + Default::default() + } + ); + error + })?; + self.results_by_index.push(SearchResultByIndex { + index: index_uid, + hits: merged_result, + estimated_total_hits, + degraded, + used_negative_operator, + facets, + }); + Ok(()) + } + + fn check_unused_facets( + &mut self, + index_scheduler: &IndexScheduler, + ) -> Result<(), ResponseError> { + for (index_uid, facets) in std::mem::take(&mut self.federation.facets_by_index) { + let index = match index_scheduler.index(&index_uid) { + Ok(index) => index, + Err(err) => { + let mut err = ResponseError::from(err); + // Patch the HTTP status code to 400 as it defaults to 404 for `index_not_found`, but + // here the resource not found is not part of the URL. + err.code = StatusCode::BAD_REQUEST; + err.message = format!( + "Inside `.federation.facetsByIndex.{index_uid}`: {}\n - Note: index `{index_uid}` is not used in queries", + err.message + ); + return Err(err); + } + }; + + // Important: this is the only transaction we'll use for this index during this federated search + let rtxn = index.read_txn()?; + + if let Err(mut error) = + self.facet_order.check_facet_order(&index_uid, &facets, &index, &rtxn) + { + error.message = format!( + "Inside `.federation.facetsByIndex.{index_uid}`: {error}\n - Note: index `{index_uid}` is not used in queries", + ); + return Err(error); + } + + if let Some(facets) = facets { + if let Err(mut error) = compute_facet_distribution_stats( + &facets, + &index, + &rtxn, + Default::default(), + super::super::Route::MultiSearch, + ) { + error.message = + format!("Inside `.federation.facetsByIndex.{index_uid}`: {}\n - Note: index `{index_uid}` is not used in queries", error.message); + return Err(error); + } + } + } + Ok(()) + } +} + +enum FacetOrder { + ByFacet(BTreeMap), + ByIndex(BTreeMap), + None, +} + +type FacetDistributions = BTreeMap>; +type FacetStats = BTreeMap; + +impl FacetOrder { + fn check_facet_order( + &mut self, + current_index: &str, + facets_by_index: &Option>, + index: &milli::Index, + rtxn: &milli::heed::RoTxn<'_>, + ) -> Result<(), ResponseError> { + match self { + FacetOrder::ByFacet(facet_order) => { + if let Some(facets_by_index) = facets_by_index { + let index_facet_order = index.sort_facet_values_by(rtxn)?; + for facet in facets_by_index { + let index_facet_order = index_facet_order.get(facet); + let (previous_index, previous_facet_order) = facet_order + .entry(facet.to_owned()) + .or_insert_with(|| (current_index.to_owned(), index_facet_order)); + if previous_facet_order != &index_facet_order { + return Err(MeilisearchHttpError::InconsistentFacetOrder { + facet: facet.clone(), + previous_facet_order: *previous_facet_order, + previous_uid: previous_index.clone(), + current_uid: current_index.to_owned(), + index_facet_order, + } + .into()); + } + } + } + } + FacetOrder::ByIndex(order_by_index) => { + let max_values_per_facet = index + .max_values_per_facet(rtxn)? + .map(|x| x as usize) + .unwrap_or(DEFAULT_VALUES_PER_FACET); + order_by_index.insert( + current_index.to_owned(), + (index.sort_facet_values_by(rtxn)?, max_values_per_facet), + ); + } + FacetOrder::None => {} + } + Ok(()) + } + + fn merge( + self, + merge_facets: Option, + remote_results: Vec, + mut facets: FederatedFacets, + ) -> (Option, Option, FederatedFacets) { + let (facet_distribution, facet_stats, facets_by_index) = match (self, merge_facets) { + (FacetOrder::ByFacet(facet_order), Some(merge_facets)) => { + for remote_facets_by_index in + remote_results.into_iter().map(|result| result.facets_by_index) + { + facets.append(remote_facets_by_index); + } + let facets = facets.merge(merge_facets, facet_order); + + let (facet_distribution, facet_stats) = facets + .map(|ComputedFacets { distribution, stats }| (distribution, stats)) + .unzip(); + + (facet_distribution, facet_stats, FederatedFacets::default()) + } + (FacetOrder::ByIndex(facet_order), _) => { + for remote_facets_by_index in + remote_results.into_iter().map(|result| result.facets_by_index) + { + facets.append(remote_facets_by_index); + } + facets.sort_and_truncate(facet_order); + (None, None, facets) + } + _ => (None, None, facets), + }; + (facet_distribution, facet_stats, facets_by_index) + } +} diff --git a/crates/meilisearch/src/search/federated/proxy.rs b/crates/meilisearch/src/search/federated/proxy.rs new file mode 100644 index 000000000..a2d9bb96c --- /dev/null +++ b/crates/meilisearch/src/search/federated/proxy.rs @@ -0,0 +1,268 @@ +pub use error::ProxySearchError; +use error::ReqwestErrorWithoutUrl; +use meilisearch_types::features::Remote; +use rand::Rng as _; +use reqwest::{Client, Response, StatusCode}; +use serde::de::DeserializeOwned; +use serde_json::Value; + +use super::types::{FederatedSearch, FederatedSearchResult, Federation}; +use crate::search::SearchQueryWithIndex; + +pub const PROXY_SEARCH_HEADER: &str = "Meili-Proxy-Search"; +pub const PROXY_SEARCH_HEADER_VALUE: &str = "true"; + +mod error { + use meilisearch_types::error::ResponseError; + use reqwest::StatusCode; + + #[derive(Debug, thiserror::Error)] + pub enum ProxySearchError { + #[error("{0}")] + CouldNotSendRequest(ReqwestErrorWithoutUrl), + #[error("could not authenticate against the remote host")] + AuthenticationError, + #[error( + "could not parse response from the remote host as a federated search response{}", + response_from_remote(response) + )] + CouldNotParseResponse { response: Result }, + #[error("remote host responded with code {}{}", status_code.as_u16(), response_from_remote(response))] + BadRequest { status_code: StatusCode, response: Result }, + #[error("remote host did not answer before the deadline")] + Timeout, + #[error("remote hit does not contain `{0}`")] + MissingPathInResponse(&'static str), + #[error("remote host responded with code {}{}", status_code.as_u16(), response_from_remote(response))] + RemoteError { status_code: StatusCode, response: Result }, + #[error("remote hit contains an unexpected value at path `{path}`: expected {expected_type}, received `{received_value}`")] + UnexpectedValueInPath { + path: &'static str, + expected_type: &'static str, + received_value: String, + }, + #[error("could not parse weighted score values in the remote hit: {0}")] + CouldNotParseWeightedScoreValues(serde_json::Error), + } + + impl ProxySearchError { + pub fn as_response_error(&self) -> ResponseError { + use meilisearch_types::error::Code; + let message = self.to_string(); + let code = match self { + ProxySearchError::CouldNotSendRequest(_) => Code::ProxyCouldNotSendRequest, + ProxySearchError::AuthenticationError => Code::ProxyInvalidApiKey, + ProxySearchError::BadRequest { .. } => Code::ProxyBadRequest, + ProxySearchError::Timeout => Code::ProxyTimeout, + ProxySearchError::RemoteError { .. } => Code::ProxyRemoteError, + ProxySearchError::CouldNotParseResponse { .. } + | ProxySearchError::MissingPathInResponse(_) + | ProxySearchError::UnexpectedValueInPath { .. } + | ProxySearchError::CouldNotParseWeightedScoreValues(_) => Code::ProxyBadResponse, + }; + ResponseError::from_msg(message, code) + } + } + + #[derive(Debug, thiserror::Error)] + #[error(transparent)] + pub struct ReqwestErrorWithoutUrl(reqwest::Error); + impl ReqwestErrorWithoutUrl { + pub fn new(inner: reqwest::Error) -> Self { + Self(inner.without_url()) + } + } + + fn response_from_remote(response: &Result) -> String { + match response { + Ok(response) => { + // unwrap: to_string of a value should not fail + format!(":\n - response from remote: {}", response) + } + Err(error) => { + format!(":\n - additionally, could not retrieve response from remote: {error}") + } + } + } +} + +#[derive(Clone)] +pub struct ProxySearchParams { + pub deadline: Option, + pub try_count: u32, + pub client: reqwest::Client, +} + +/// Performs a federated search on a remote host and returns the results +pub async fn proxy_search( + node: &Remote, + queries: Vec, + federation: Federation, + params: &ProxySearchParams, +) -> Result { + let url = format!("{}/multi-search", node.url); + + let federated = FederatedSearch { queries, federation: Some(federation) }; + + let search_api_key = node.search_api_key.as_deref(); + + let max_deadline = std::time::Instant::now() + std::time::Duration::from_secs(5); + + let deadline = if let Some(deadline) = params.deadline { + std::time::Instant::min(deadline, max_deadline) + } else { + max_deadline + }; + + for i in 0..=params.try_count { + match try_proxy_search(&url, search_api_key, &federated, ¶ms.client, deadline).await { + Ok(response) => return Ok(response), + Err(retry) => { + let duration = retry.into_duration(i)?; + tokio::time::sleep(duration).await; + } + } + } + try_proxy_search(&url, search_api_key, &federated, ¶ms.client, deadline) + .await + .map_err(Retry::into_error) +} + +async fn try_proxy_search( + url: &str, + search_api_key: Option<&str>, + federated: &FederatedSearch, + client: &Client, + deadline: std::time::Instant, +) -> Result { + let timeout = deadline.saturating_duration_since(std::time::Instant::now()); + + let request = client.post(url).json(&federated).timeout(timeout); + let request = if let Some(search_api_key) = search_api_key { + request.bearer_auth(search_api_key) + } else { + request + }; + let request = request.header(PROXY_SEARCH_HEADER, PROXY_SEARCH_HEADER_VALUE); + + let response = request.send().await; + let response = match response { + Ok(response) => response, + Err(error) if error.is_timeout() => return Err(Retry::give_up(ProxySearchError::Timeout)), + Err(error) => { + return Err(Retry::retry_later(ProxySearchError::CouldNotSendRequest( + ReqwestErrorWithoutUrl::new(error), + ))) + } + }; + + match response.status() { + status_code if status_code.is_success() => (), + StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { + return Err(Retry::give_up(ProxySearchError::AuthenticationError)) + } + status_code if status_code.is_client_error() => { + let response = parse_error(response).await; + return Err(Retry::give_up(ProxySearchError::BadRequest { status_code, response })); + } + status_code if status_code.is_server_error() => { + let response = parse_error(response).await; + return Err(Retry::retry_later(ProxySearchError::RemoteError { + status_code, + response, + })); + } + status_code => { + tracing::debug!( + status_code = status_code.as_u16(), + "remote replied with unexpected status code" + ); + } + } + + let response = match parse_response(response).await { + Ok(response) => response, + Err(response) => { + return Err(Retry::retry_later(ProxySearchError::CouldNotParseResponse { response })) + } + }; + + Ok(response) +} + +/// Always parse the body of the response of a failed request as JSON. +async fn parse_error(response: Response) -> Result { + let bytes = match response.bytes().await { + Ok(bytes) => bytes, + Err(error) => return Err(ReqwestErrorWithoutUrl::new(error)), + }; + + Ok(parse_bytes_as_error(&bytes)) +} + +fn parse_bytes_as_error(bytes: &[u8]) -> String { + match serde_json::from_slice::(bytes) { + Ok(value) => value.to_string(), + Err(_) => String::from_utf8_lossy(bytes).into_owned(), + } +} + +async fn parse_response( + response: Response, +) -> Result> { + let bytes = match response.bytes().await { + Ok(bytes) => bytes, + Err(error) => return Err(Err(ReqwestErrorWithoutUrl::new(error))), + }; + + match serde_json::from_slice::(&bytes) { + Ok(value) => Ok(value), + Err(_) => Err(Ok(parse_bytes_as_error(&bytes))), + } +} + +pub struct Retry { + error: ProxySearchError, + strategy: RetryStrategy, +} + +pub enum RetryStrategy { + GiveUp, + Retry, +} + +impl Retry { + pub fn give_up(error: ProxySearchError) -> Self { + Self { error, strategy: RetryStrategy::GiveUp } + } + + pub fn retry_later(error: ProxySearchError) -> Self { + Self { error, strategy: RetryStrategy::Retry } + } + + pub fn into_duration(self, attempt: u32) -> Result { + match self.strategy { + RetryStrategy::GiveUp => Err(self.error), + RetryStrategy::Retry => { + let retry_duration = std::time::Duration::from_millis((10u64).pow(attempt)); + let retry_duration = retry_duration.min(std::time::Duration::from_secs(1)); // don't wait more than a minute + + // randomly up to double the retry duration + let retry_duration = retry_duration + + rand::thread_rng().gen_range(std::time::Duration::ZERO..retry_duration); + + tracing::warn!( + "Attempt #{}, failed with {}, retrying after {}ms.", + attempt, + self.error, + retry_duration.as_millis() + ); + Ok(retry_duration) + } + } + } + + pub fn into_error(self) -> ProxySearchError { + self.error + } +} diff --git a/crates/meilisearch/src/search/federated/types.rs b/crates/meilisearch/src/search/federated/types.rs new file mode 100644 index 000000000..d08f5a0b4 --- /dev/null +++ b/crates/meilisearch/src/search/federated/types.rs @@ -0,0 +1,322 @@ +use std::collections::btree_map::Entry; +use std::collections::BTreeMap; +use std::fmt; +use std::vec::Vec; + +use indexmap::IndexMap; +use meilisearch_types::deserr::DeserrJsonError; +use meilisearch_types::error::deserr_codes::{ + InvalidMultiSearchFacetsByIndex, InvalidMultiSearchMaxValuesPerFacet, + InvalidMultiSearchMergeFacets, InvalidMultiSearchQueryPosition, InvalidMultiSearchRemote, + InvalidMultiSearchWeight, InvalidSearchLimit, InvalidSearchOffset, +}; +use meilisearch_types::error::ResponseError; +use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::milli::order_by_map::OrderByMap; +use meilisearch_types::milli::OrderBy; +use serde::{Deserialize, Serialize}; +use utoipa::ToSchema; + +use super::super::{ComputedFacets, FacetStats, HitsInfo, SearchHit, SearchQueryWithIndex}; + +pub const DEFAULT_FEDERATED_WEIGHT: f64 = 1.0; + +// fields in the response +pub const FEDERATION_HIT: &str = "_federation"; +pub const INDEX_UID: &str = "indexUid"; +pub const QUERIES_POSITION: &str = "queriesPosition"; +pub const WEIGHTED_RANKING_SCORE: &str = "weightedRankingScore"; +pub const WEIGHTED_SCORE_VALUES: &str = "weightedScoreValues"; +pub const FEDERATION_REMOTE: &str = "remote"; + +#[derive(Debug, Default, Clone, PartialEq, Serialize, deserr::Deserr, ToSchema)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +#[serde(rename_all = "camelCase")] + +pub struct FederationOptions { + #[deserr(default, error = DeserrJsonError)] + #[schema(value_type = f64)] + pub weight: Weight, + + #[deserr(default, error = DeserrJsonError)] + pub remote: Option, + + #[deserr(default, error = DeserrJsonError)] + pub query_position: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Serialize, deserr::Deserr)] +#[deserr(try_from(f64) = TryFrom::try_from -> InvalidMultiSearchWeight)] +pub struct Weight(f64); + +impl Default for Weight { + fn default() -> Self { + Weight(DEFAULT_FEDERATED_WEIGHT) + } +} + +impl std::convert::TryFrom for Weight { + type Error = InvalidMultiSearchWeight; + + fn try_from(f: f64) -> Result { + if f < 0.0 { + Err(InvalidMultiSearchWeight) + } else { + Ok(Weight(f)) + } + } +} + +impl std::ops::Deref for Weight { + type Target = f64; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Clone, deserr::Deserr, Serialize, ToSchema)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +#[schema(rename_all = "camelCase")] +#[serde(rename_all = "camelCase")] +pub struct Federation { + #[deserr(default = super::super::DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] + pub limit: usize, + #[deserr(default = super::super::DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] + pub offset: usize, + #[deserr(default, error = DeserrJsonError)] + pub facets_by_index: BTreeMap>>, + #[deserr(default, error = DeserrJsonError)] + pub merge_facets: Option, +} + +#[derive(Copy, Clone, Debug, deserr::Deserr, Serialize, Default, ToSchema)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +#[schema(rename_all = "camelCase")] +#[serde(rename_all = "camelCase")] +pub struct MergeFacets { + #[deserr(default, error = DeserrJsonError)] + pub max_values_per_facet: Option, +} + +#[derive(Debug, deserr::Deserr, Serialize, ToSchema)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +#[schema(rename_all = "camelCase")] +#[serde(rename_all = "camelCase")] +pub struct FederatedSearch { + pub queries: Vec, + #[deserr(default)] + pub federation: Option, +} + +#[derive(Serialize, Deserialize, Clone, ToSchema)] +#[serde(rename_all = "camelCase")] +#[schema(rename_all = "camelCase")] +pub struct FederatedSearchResult { + pub hits: Vec, + pub processing_time_ms: u128, + #[serde(flatten)] + pub hits_info: HitsInfo, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub semantic_hit_count: Option, + + #[serde(default, skip_serializing_if = "Option::is_none")] + #[schema(value_type = Option>>)] + pub facet_distribution: Option>>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub facet_stats: Option>, + #[serde(default, skip_serializing_if = "FederatedFacets::is_empty")] + pub facets_by_index: FederatedFacets, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub remote_errors: Option>, + + // These fields are only used for analytics purposes + #[serde(skip)] + pub degraded: bool, + #[serde(skip)] + pub used_negative_operator: bool, +} + +impl fmt::Debug for FederatedSearchResult { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let FederatedSearchResult { + hits, + processing_time_ms, + hits_info, + semantic_hit_count, + degraded, + used_negative_operator, + facet_distribution, + facet_stats, + facets_by_index, + remote_errors, + } = self; + + let mut debug = f.debug_struct("SearchResult"); + // The most important thing when looking at a search result is the time it took to process + debug.field("processing_time_ms", &processing_time_ms); + debug.field("hits", &format!("[{} hits returned]", hits.len())); + debug.field("hits_info", &hits_info); + if *used_negative_operator { + debug.field("used_negative_operator", used_negative_operator); + } + if *degraded { + debug.field("degraded", degraded); + } + if let Some(facet_distribution) = facet_distribution { + debug.field("facet_distribution", &facet_distribution); + } + if let Some(facet_stats) = facet_stats { + debug.field("facet_stats", &facet_stats); + } + if let Some(semantic_hit_count) = semantic_hit_count { + debug.field("semantic_hit_count", &semantic_hit_count); + } + if !facets_by_index.is_empty() { + debug.field("facets_by_index", &facets_by_index); + } + if !remote_errors.is_none() { + debug.field("remote_errors", &remote_errors); + } + + debug.finish() + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, ToSchema)] +pub struct FederatedFacets(pub BTreeMap); + +impl FederatedFacets { + pub fn insert(&mut self, index: String, facets: Option) { + if let Some(facets) = facets { + self.0.insert(index, facets); + } + } + + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + + pub fn merge( + self, + MergeFacets { max_values_per_facet }: MergeFacets, + facet_order: BTreeMap, + ) -> Option { + if self.is_empty() { + return None; + } + + let mut distribution: BTreeMap = Default::default(); + let mut stats: BTreeMap = Default::default(); + + for facets_by_index in self.0.into_values() { + for (facet, index_distribution) in facets_by_index.distribution { + match distribution.entry(facet) { + Entry::Vacant(entry) => { + entry.insert(index_distribution); + } + Entry::Occupied(mut entry) => { + let distribution = entry.get_mut(); + + for (value, index_count) in index_distribution { + distribution + .entry(value) + .and_modify(|count| *count += index_count) + .or_insert(index_count); + } + } + } + } + + for (facet, index_stats) in facets_by_index.stats { + match stats.entry(facet) { + Entry::Vacant(entry) => { + entry.insert(index_stats); + } + Entry::Occupied(mut entry) => { + let stats = entry.get_mut(); + + stats.min = f64::min(stats.min, index_stats.min); + stats.max = f64::max(stats.max, index_stats.max); + } + } + } + } + + // fixup order + for (facet, values) in &mut distribution { + let order_by = facet_order.get(facet).map(|(_, order)| *order).unwrap_or_default(); + + match order_by { + OrderBy::Lexicographic => { + values.sort_unstable_by(|left, _, right, _| left.cmp(right)) + } + OrderBy::Count => { + values.sort_unstable_by(|_, left, _, right| { + left.cmp(right) + // biggest first + .reverse() + }) + } + } + + if let Some(max_values_per_facet) = max_values_per_facet { + values.truncate(max_values_per_facet) + }; + } + + Some(ComputedFacets { distribution, stats }) + } + + pub(crate) fn append(&mut self, FederatedFacets(remote_facets_by_index): FederatedFacets) { + for (index, remote_facets) in remote_facets_by_index { + let merged_facets = self.0.entry(index).or_default(); + + for (remote_facet, remote_stats) in remote_facets.stats { + match merged_facets.stats.entry(remote_facet) { + Entry::Vacant(vacant_entry) => { + vacant_entry.insert(remote_stats); + } + Entry::Occupied(mut occupied_entry) => { + let stats = occupied_entry.get_mut(); + stats.min = f64::min(stats.min, remote_stats.min); + stats.max = f64::max(stats.max, remote_stats.max); + } + } + } + + for (remote_facet, remote_values) in remote_facets.distribution { + let merged_facet = merged_facets.distribution.entry(remote_facet).or_default(); + for (remote_value, remote_count) in remote_values { + let count = merged_facet.entry(remote_value).or_default(); + *count += remote_count; + } + } + } + } + + pub fn sort_and_truncate(&mut self, facet_order: BTreeMap) { + for (index, facets) in &mut self.0 { + let Some((order_by, max_values_per_facet)) = facet_order.get(index) else { + continue; + }; + for (facet, values) in &mut facets.distribution { + match order_by.get(facet) { + OrderBy::Lexicographic => { + values.sort_unstable_by(|left, _, right, _| left.cmp(right)) + } + OrderBy::Count => { + values.sort_unstable_by(|_, left, _, right| { + left.cmp(right) + // biggest first + .reverse() + }) + } + } + values.truncate(*max_values_per_facet); + } + } + } +} diff --git a/crates/meilisearch/src/search/federated/weighted_scores.rs b/crates/meilisearch/src/search/federated/weighted_scores.rs new file mode 100644 index 000000000..899940a31 --- /dev/null +++ b/crates/meilisearch/src/search/federated/weighted_scores.rs @@ -0,0 +1,88 @@ +use std::cmp::Ordering; + +use meilisearch_types::milli::score_details::{self, WeightedScoreValue}; + +pub fn compare( + mut left_it: impl Iterator, + left_weighted_global_score: f64, + mut right_it: impl Iterator, + right_weighted_global_score: f64, +) -> Ordering { + loop { + let left = left_it.next(); + let right = right_it.next(); + + match (left, right) { + (None, None) => return Ordering::Equal, + (None, Some(_)) => return Ordering::Less, + (Some(_), None) => return Ordering::Greater, + ( + Some( + WeightedScoreValue::WeightedScore(left) | WeightedScoreValue::VectorSort(left), + ), + Some( + WeightedScoreValue::WeightedScore(right) + | WeightedScoreValue::VectorSort(right), + ), + ) => { + if (left - right).abs() <= f64::EPSILON { + continue; + } + return left.partial_cmp(&right).unwrap(); + } + ( + Some(WeightedScoreValue::Sort { asc: left_asc, value: left }), + Some(WeightedScoreValue::Sort { asc: right_asc, value: right }), + ) => { + if left_asc != right_asc { + return left_weighted_global_score + .partial_cmp(&right_weighted_global_score) + .unwrap(); + } + match score_details::compare_sort_values(left_asc, &left, &right) { + Ordering::Equal => continue, + order => return order, + } + } + ( + Some(WeightedScoreValue::GeoSort { asc: left_asc, distance: left }), + Some(WeightedScoreValue::GeoSort { asc: right_asc, distance: right }), + ) => { + if left_asc != right_asc { + continue; + } + match (left, right) { + (None, None) => continue, + (None, Some(_)) => return Ordering::Less, + (Some(_), None) => return Ordering::Greater, + (Some(left), Some(right)) => { + if (left - right).abs() <= f64::EPSILON { + continue; + } + return left.partial_cmp(&right).unwrap(); + } + } + } + // not comparable details, use global + (Some(WeightedScoreValue::WeightedScore(_)), Some(_)) + | (Some(_), Some(WeightedScoreValue::WeightedScore(_))) + | (Some(WeightedScoreValue::VectorSort(_)), Some(_)) + | (Some(_), Some(WeightedScoreValue::VectorSort(_))) + | (Some(WeightedScoreValue::GeoSort { .. }), Some(WeightedScoreValue::Sort { .. })) + | (Some(WeightedScoreValue::Sort { .. }), Some(WeightedScoreValue::GeoSort { .. })) => { + let left_count = left_it.count(); + let right_count = right_it.count(); + // compare how many remaining groups of rules each side has. + // the group with the most remaining groups wins. + return left_count + .cmp(&right_count) + // breaks ties with the global ranking score + .then_with(|| { + left_weighted_global_score + .partial_cmp(&right_weighted_global_score) + .unwrap() + }); + } + } + } +}