diff --git a/Cargo.lock b/Cargo.lock index 156e3d146..4c1124f8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3323,6 +3323,7 @@ dependencies = [ "rayon", "regex", "reqwest", + "roaring", "rustls 0.21.12", "rustls-pemfile", "segment", diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index f529238e4..166e412e3 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -243,6 +243,8 @@ InvalidSearchAttributesToHighlight , InvalidRequest , BAD_REQUEST ; InvalidSimilarAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSimilarRetrieveVectors , InvalidRequest , BAD_REQUEST ; InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; +InvalidSearchWeight , InvalidRequest , BAD_REQUEST ; +InvalidSearchFederated , InvalidRequest , BAD_REQUEST ; InvalidSearchRankingScoreThreshold , InvalidRequest , BAD_REQUEST ; InvalidSimilarRankingScoreThreshold , InvalidRequest , BAD_REQUEST ; InvalidSearchRetrieveVectors , InvalidRequest , BAD_REQUEST ; diff --git a/meilisearch/Cargo.toml b/meilisearch/Cargo.toml index ce73ebdcf..fcd2330f7 100644 --- a/meilisearch/Cargo.toml +++ b/meilisearch/Cargo.toml @@ -106,6 +106,7 @@ tracing-subscriber = { version = "0.3.18", features = ["json"] } tracing-trace = { version = "0.1.0", path = "../tracing-trace" } tracing-actix-web = "0.7.10" build-info = { version = "1.7.0", path = "../build-info" } +roaring = "0.10.2" [dev-dependencies] actix-rt = "2.9.0" diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index 94e4684d5..980dcb83f 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -35,8 +35,8 @@ use crate::routes::indexes::documents::UpdateDocumentsQuery; use crate::routes::indexes::facet_search::FacetSearchQuery; use crate::routes::{create_all_stats, Stats}; use crate::search::{ - FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, 
SearchResult, - SimilarQuery, SimilarResult, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, + FacetSearchResult, FederatedSearch, MatchingStrategy, SearchQuery, SearchQueryWithIndex, + SearchResult, SimilarQuery, SimilarResult, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEMANTIC_RATIO, }; @@ -1075,22 +1075,33 @@ pub struct MultiSearchAggregator { show_ranking_score: bool, show_ranking_score_details: bool, + // federation + use_federation: bool, + // context user_agents: HashSet, } impl MultiSearchAggregator { - pub fn from_queries(query: &[SearchQueryWithIndex], request: &HttpRequest) -> Self { + pub fn from_federated_search( + federated_search: &FederatedSearch, + request: &HttpRequest, + ) -> Self { let timestamp = Some(OffsetDateTime::now_utc()); let user_agents = extract_user_agents(request).into_iter().collect(); - let distinct_indexes: HashSet<_> = query + let use_federation = federated_search.federation.is_some(); + + let distinct_indexes: HashSet<_> = federated_search + .queries .iter() .map(|query| { + let query = &query; // make sure we get a compilation error if a field gets added to / removed from SearchQueryWithIndex let SearchQueryWithIndex { index_uid, + federated: _, q: _, vector: _, offset: _, @@ -1122,8 +1133,10 @@ impl MultiSearchAggregator { }) .collect(); - let show_ranking_score = query.iter().any(|query| query.show_ranking_score); - let show_ranking_score_details = query.iter().any(|query| query.show_ranking_score_details); + let show_ranking_score = + federated_search.queries.iter().any(|query| query.show_ranking_score); + let show_ranking_score_details = + federated_search.queries.iter().any(|query| query.show_ranking_score_details); Self { timestamp, @@ -1131,10 +1144,11 @@ impl MultiSearchAggregator { total_succeeded: 0, total_distinct_index_count: distinct_indexes.len(), total_single_index: if distinct_indexes.len() == 1 { 1 } else { 0 }, - total_search_count: 
query.len(), + total_search_count: federated_search.queries.len(), show_ranking_score, show_ranking_score_details, user_agents, + use_federation, } } @@ -1160,6 +1174,7 @@ impl MultiSearchAggregator { let show_ranking_score_details = this.show_ranking_score_details || other.show_ranking_score_details; let mut user_agents = this.user_agents; + let use_federation = this.use_federation || other.use_federation; for user_agent in other.user_agents.into_iter() { user_agents.insert(user_agent); @@ -1176,6 +1191,7 @@ impl MultiSearchAggregator { user_agents, show_ranking_score, show_ranking_score_details, + use_federation, // do not add _ or ..Default::default() here }; @@ -1194,6 +1210,7 @@ impl MultiSearchAggregator { user_agents, show_ranking_score, show_ranking_score_details, + use_federation, } = self; if total_received == 0 { @@ -1218,6 +1235,9 @@ impl MultiSearchAggregator { "scoring": { "show_ranking_score": show_ranking_score, "show_ranking_score_details": show_ranking_score_details, + }, + "federation": { + "use_federation": use_federation, } }); diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs index 1d697dac6..92ce59e54 100644 --- a/meilisearch/src/routes/multi_search.rs +++ b/meilisearch/src/routes/multi_search.rs @@ -15,7 +15,8 @@ use crate::extractors::authentication::{AuthenticationError, GuardedData}; use crate::extractors::sequential_extractor::SeqHandler; use crate::routes::indexes::search::search_kind; use crate::search::{ - add_search_rules, perform_search, RetrieveVectors, SearchQueryWithIndex, SearchResultWithIndex, + add_search_rules, perform_federated_search, perform_search, FederatedSearch, RetrieveVectors, + SearchQueryWithIndex, SearchResultWithIndex, }; use crate::search_queue::SearchQueue; @@ -28,85 +29,44 @@ struct SearchResults { results: Vec, } -#[derive(Debug, deserr::Deserr)] -#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] -pub struct SearchQueries { - queries: 
Vec, -} - pub async fn multi_search_with_post( index_scheduler: GuardedData, Data>, search_queue: Data, - params: AwebJson, + params: AwebJson, req: HttpRequest, analytics: web::Data, ) -> Result { - let queries = params.into_inner().queries; - - let mut multi_aggregate = MultiSearchAggregator::from_queries(&queries, &req); - let features = index_scheduler.features(); - // Since we don't want to process half of the search requests and then get a permit refused // we're going to get one permit for the whole duration of the multi-search request. let _permit = search_queue.try_get_search_permit().await?; - // Explicitly expect a `(ResponseError, usize)` for the error type rather than `ResponseError` only, - // so that `?` doesn't work if it doesn't use `with_index`, ensuring that it is not forgotten in case of code - // changes. - let search_results: Result<_, (ResponseError, usize)> = async { - let mut search_results = Vec::with_capacity(queries.len()); - for (query_index, (index_uid, mut query)) in - queries.into_iter().map(SearchQueryWithIndex::into_index_query).enumerate() - { - debug!(on_index = query_index, parameters = ?query, "Multi-search"); + let federated_search = params.into_inner(); + let mut multi_aggregate = MultiSearchAggregator::from_federated_search(&federated_search, &req); + + let FederatedSearch { mut queries, federation } = federated_search; + + let features = index_scheduler.features(); + + // regardless of federation, check authorization and apply search rules + let auth = 'check_authorization: { + for (query_index, federated_query) in queries.iter_mut().enumerate() { + let index_uid = federated_query.index_uid.as_str(); // Check index from API key - if !index_scheduler.filters().is_index_authorized(&index_uid) { - return Err(AuthenticationError::InvalidToken).with_index(query_index); + if !index_scheduler.filters().is_index_authorized(index_uid) { + break 'check_authorization Err(AuthenticationError::InvalidToken) + .with_index(query_index); } 
// Apply search rules from tenant token - if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) + if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(index_uid) { - add_search_rules(&mut query.filter, search_rules); + add_search_rules(&mut federated_query.filter, search_rules); } - - let index = index_scheduler - .index(&index_uid) - .map_err(|err| { - let mut err = ResponseError::from(err); - // Patch the HTTP status code to 400 as it defaults to 404 for `index_not_found`, but - // here the resource not found is not part of the URL. - err.code = StatusCode::BAD_REQUEST; - err - }) - .with_index(query_index)?; - - let search_kind = search_kind(&query, index_scheduler.get_ref(), &index, features) - .with_index(query_index)?; - let retrieve_vector = - RetrieveVectors::new(query.retrieve_vectors, features).with_index(query_index)?; - - let search_result = tokio::task::spawn_blocking(move || { - perform_search(&index, query, search_kind, retrieve_vector) - }) - .await - .with_index(query_index)?; - - search_results.push(SearchResultWithIndex { - index_uid: index_uid.into_inner(), - result: search_result.with_index(query_index)?, - }); } - Ok(search_results) - } - .await; + Ok(()) + }; - if search_results.is_ok() { - multi_aggregate.succeed(); - } - analytics.post_multi_search(multi_aggregate); - - let search_results = search_results.map_err(|(mut err, query_index)| { + auth.map_err(|(mut err, query_index)| { // Add the query index that failed as context for the error message. // We're doing it only here and not directly in the `WithIndex` trait so that the `with_index` function returns a different type // of result and we can benefit from static typing. 
@@ -114,9 +74,90 @@ pub async fn multi_search_with_post( err })?; - debug!(returns = ?search_results, "Multi-search"); + let response = match federation { + Some(federation) => { + let search_result = tokio::task::spawn_blocking(move || { + perform_federated_search(&index_scheduler, queries, federation, features) + }) + .await; - Ok(HttpResponse::Ok().json(SearchResults { results: search_results })) + if let Ok(Ok(_)) = search_result { + multi_aggregate.succeed(); + } + + analytics.post_multi_search(multi_aggregate); + HttpResponse::Ok().json(search_result??) + } + None => { + // Explicitly expect a `(ResponseError, usize)` for the error type rather than `ResponseError` only, + // so that `?` doesn't work if it doesn't use `with_index`, ensuring that it is not forgotten in case of code + // changes. + let search_results: Result<_, (ResponseError, usize)> = async { + let mut search_results = Vec::with_capacity(queries.len()); + for (query_index, (index_uid, query, federated)) in queries + .into_iter() + .map(SearchQueryWithIndex::into_index_query_federated) + .enumerate() + { + debug!(on_index = query_index, parameters = ?query, "Multi-search"); + + if federated.is_some() { + // FIXME: add error case + panic!("federated is some in a non-federated query") + } + + let index = index_scheduler + .index(&index_uid) + .map_err(|err| { + let mut err = ResponseError::from(err); + // Patch the HTTP status code to 400 as it defaults to 404 for `index_not_found`, but + // here the resource not found is not part of the URL. 
+ err.code = StatusCode::BAD_REQUEST; + err + }) + .with_index(query_index)?; + + let search_kind = + search_kind(&query, index_scheduler.get_ref(), &index, features) + .with_index(query_index)?; + let retrieve_vector = RetrieveVectors::new(query.retrieve_vectors, features) + .with_index(query_index)?; + + let search_result = tokio::task::spawn_blocking(move || { + perform_search(&index, query, search_kind, retrieve_vector) + }) + .await + .with_index(query_index)?; + + search_results.push(SearchResultWithIndex { + index_uid: index_uid.into_inner(), + result: search_result.with_index(query_index)?, + }); + } + Ok(search_results) + } + .await; + + if search_results.is_ok() { + multi_aggregate.succeed(); + } + analytics.post_multi_search(multi_aggregate); + + let search_results = search_results.map_err(|(mut err, query_index)| { + // Add the query index that failed as context for the error message. + // We're doing it only here and not directly in the `WithIndex` trait so that the `with_index` function returns a different type + // of result and we can benefit from static typing. + err.message = format!("Inside `.queries[{query_index}]`: {}", err.message); + err + })?; + + debug!(returns = ?search_results, "Multi-search"); + + HttpResponse::Ok().json(SearchResults { results: search_results }) + } + }; + + Ok(response) } /// Local `Result` extension trait to avoid `map_err` boilerplate. 
diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 375060889..491d28dd1 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -1,6 +1,6 @@ use core::fmt; use std::cmp::min; -use std::collections::{BTreeMap, BTreeSet, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -31,6 +31,9 @@ use serde_json::{json, Value}; use crate::error::MeilisearchHttpError; +mod federated; +pub use federated::{perform_federated_search, Federated, FederatedSearch, Federation}; + type MatchesPosition = BTreeMap>; pub const DEFAULT_SEARCH_OFFSET: fn() -> usize = || 0; @@ -257,11 +260,13 @@ pub struct HybridQuery { pub embedder: Option, } +#[derive(Clone)] pub enum SearchKind { KeywordOnly, SemanticOnly { embedder_name: String, embedder: Arc }, Hybrid { embedder_name: String, embedder: Arc, semantic_ratio: f32 }, } + impl SearchKind { pub(crate) fn semantic( index_scheduler: &index_scheduler::IndexScheduler, @@ -358,7 +363,7 @@ impl SearchQuery { } } -/// A `SearchQuery` + an index UID. +/// A `SearchQuery` + an index UID and an optional Federated option. // This struct contains the fields of `SearchQuery` inline. // This is because neither deserr nor serde support `flatten` when using `deny_unknown_fields. // The `From` implementation ensures both structs remain up to date. 
@@ -373,10 +378,10 @@ pub struct SearchQueryWithIndex { pub vector: Option>, #[deserr(default, error = DeserrJsonError)] pub hybrid: Option, - #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] - pub offset: usize, - #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] - pub limit: usize, + #[deserr(default, error = DeserrJsonError)] + pub offset: Option, + #[deserr(default, error = DeserrJsonError)] + pub limit: Option, #[deserr(default, error = DeserrJsonError)] pub page: Option, #[deserr(default, error = DeserrJsonError)] @@ -417,12 +422,34 @@ pub struct SearchQueryWithIndex { pub attributes_to_search_on: Option>, #[deserr(default, error = DeserrJsonError, default)] pub ranking_score_threshold: Option, + + #[deserr(default)] + pub federated: Option, } impl SearchQueryWithIndex { - pub fn into_index_query(self) -> (IndexUid, SearchQuery) { + pub fn is_federated(&self) -> bool { + self.federated.is_some() + } + + pub fn has_pagination(&self) -> Option<&'static str> { + if self.offset.is_some() { + Some("offset") + } else if self.limit.is_some() { + Some("limit") + } else if self.page.is_some() { + Some("page") + } else if self.hits_per_page.is_some() { + Some("hitsPerPage") + } else { + None + } + } + + pub fn into_index_query_federated(self) -> (IndexUid, SearchQuery, Option) { let SearchQueryWithIndex { index_uid, + federated, q, vector, offset, @@ -454,8 +481,8 @@ impl SearchQueryWithIndex { SearchQuery { q, vector, - offset, - limit, + offset: offset.unwrap_or(DEFAULT_SEARCH_OFFSET()), + limit: limit.unwrap_or(DEFAULT_SEARCH_LIMIT()), page, hits_per_page, attributes_to_retrieve, @@ -480,6 +507,7 @@ impl SearchQueryWithIndex { // do not use ..Default::default() here, // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` }, + federated, ) } } @@ -864,15 +892,7 @@ pub fn perform_search( used_negative_operator, }, semantic_hit_count, - ) = match &search_kind { - SearchKind::KeywordOnly => 
(search.execute()?, None), - SearchKind::SemanticOnly { .. } => { - let results = search.execute()?; - let semantic_hit_count = results.document_scores.len() as u32; - (results, Some(semantic_hit_count)) - } - SearchKind::Hybrid { semantic_ratio, .. } => search.execute_hybrid(*semantic_ratio)?, - }; + ) = search_from_kind(search_kind, search)?; let SearchQuery { q, @@ -919,8 +939,13 @@ pub fn perform_search( show_ranking_score_details, }; - let documents = - make_hits(index, &rtxn, format, matching_words, documents_ids, document_scores)?; + let documents = make_hits( + index, + &rtxn, + format, + matching_words, + documents_ids.iter().copied().zip(document_scores.iter()), + )?; let number_of_hits = min(candidates.len() as usize, max_total_hits); let hits_info = if is_finite_pagination { @@ -988,6 +1013,22 @@ pub fn perform_search( Ok(result) } +pub fn search_from_kind( + search_kind: SearchKind, + search: milli::Search<'_>, +) -> Result<(milli::SearchResult, Option), MeilisearchHttpError> { + let (milli_result, semantic_hit_count) = match &search_kind { + SearchKind::KeywordOnly => (search.execute()?, None), + SearchKind::SemanticOnly { .. } => { + let results = search.execute()?; + let semantic_hit_count = results.document_scores.len() as u32; + (results, Some(semantic_hit_count)) + } + SearchKind::Hybrid { semantic_ratio, .. 
} => search.execute_hybrid(*semantic_ratio)?, + }; + Ok((milli_result, semantic_hit_count)) +} + struct AttributesFormat { attributes_to_retrieve: Option>, retrieve_vectors: RetrieveVectors, @@ -1033,129 +1074,189 @@ impl RetrieveVectors { } } -fn make_hits( - index: &Index, - rtxn: &RoTxn<'_>, - format: AttributesFormat, - matching_words: milli::MatchingWords, - documents_ids: Vec, - document_scores: Vec>, -) -> Result, MeilisearchHttpError> { - let fields_ids_map = index.fields_ids_map(rtxn).unwrap(); - let displayed_ids = - index.displayed_fields_ids(rtxn)?.map(|fields| fields.into_iter().collect::>()); +struct HitMaker<'a> { + index: &'a Index, + rtxn: &'a RoTxn<'a>, + fields_ids_map: FieldsIdsMap, + displayed_ids: BTreeSet, + vectors_fid: Option, + retrieve_vectors: RetrieveVectors, + to_retrieve_ids: BTreeSet, + embedding_configs: Vec, + formatter_builder: MatcherBuilder<'a>, + formatted_options: BTreeMap, + show_ranking_score: bool, + show_ranking_score_details: bool, + sort: Option>, + show_matches_position: bool, +} - let vectors_fid = fields_ids_map.id(milli::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME); +impl<'a> HitMaker<'a> { + pub fn tokenizer<'b>( + script_lang_map: &'b HashMap>, + dictionary: Option<&'b [&'b str]>, + separators: Option<&'b [&'b str]>, + ) -> milli::tokenizer::Tokenizer<'b> { + let mut tokenizer_builder = TokenizerBuilder::default(); + tokenizer_builder.create_char_map(true); + if !script_lang_map.is_empty() { + tokenizer_builder.allow_list(script_lang_map); + } - let vectors_is_hidden = match (&displayed_ids, vectors_fid) { - // displayed_ids is a wildcard, so `_vectors` can be displayed regardless of its fid - (None, _) => false, - // displayed_ids is a finite list, and `_vectors` cannot be part of it because it is not an existing field - (Some(_), None) => true, - // displayed_ids is a finit list, so hide if `_vectors` is not part of it - (Some(map), Some(vectors_fid)) => map.contains(&vectors_fid), - }; + if let 
Some(separators) = separators { + tokenizer_builder.separators(separators); + } - let retrieve_vectors = if let RetrieveVectors::Retrieve = format.retrieve_vectors { - if vectors_is_hidden { - RetrieveVectors::Hide + if let Some(dictionary) = dictionary { + tokenizer_builder.words_dict(dictionary); + } + + tokenizer_builder.into_tokenizer() + } + + pub fn formatter_builder( + matching_words: milli::MatchingWords, + tokenizer: milli::tokenizer::Tokenizer<'_>, + ) -> MatcherBuilder<'_> { + let formatter_builder = MatcherBuilder::new(matching_words, tokenizer); + + formatter_builder + } + + pub fn new( + index: &'a Index, + rtxn: &'a RoTxn<'a>, + format: AttributesFormat, + mut formatter_builder: MatcherBuilder<'a>, + ) -> Result { + formatter_builder.crop_marker(format.crop_marker); + formatter_builder.highlight_prefix(format.highlight_pre_tag); + formatter_builder.highlight_suffix(format.highlight_post_tag); + + let fields_ids_map = index.fields_ids_map(rtxn)?; + let displayed_ids = index + .displayed_fields_ids(rtxn)? 
+ .map(|fields| fields.into_iter().collect::>()); + + let vectors_fid = + fields_ids_map.id(milli::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME); + + let vectors_is_hidden = match (&displayed_ids, vectors_fid) { + // displayed_ids is a wildcard, so `_vectors` can be displayed regardless of its fid + (None, _) => false, + // displayed_ids is a finite list, and `_vectors` cannot be part of it because it is not an existing field + (Some(_), None) => true, + // displayed_ids is a finit list, so hide if `_vectors` is not part of it + (Some(map), Some(vectors_fid)) => map.contains(&vectors_fid), + }; + + let displayed_ids = + displayed_ids.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); + + let retrieve_vectors = if let RetrieveVectors::Retrieve = format.retrieve_vectors { + if vectors_is_hidden { + RetrieveVectors::Hide + } else { + RetrieveVectors::Retrieve + } } else { - RetrieveVectors::Retrieve - } - } else { - format.retrieve_vectors - }; + format.retrieve_vectors + }; - let displayed_ids = - displayed_ids.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); - let fids = |attrs: &BTreeSet| { - let mut ids = BTreeSet::new(); - for attr in attrs { - if attr == "*" { - ids.clone_from(&displayed_ids); - break; + let fids = |attrs: &BTreeSet| { + let mut ids = BTreeSet::new(); + for attr in attrs { + if attr == "*" { + ids.clone_from(&displayed_ids); + break; + } + + if let Some(id) = fields_ids_map.id(attr) { + ids.insert(id); + } } + ids + }; + let to_retrieve_ids: BTreeSet<_> = format + .attributes_to_retrieve + .as_ref() + .map(fids) + .unwrap_or_else(|| displayed_ids.clone()) + .intersection(&displayed_ids) + .cloned() + .collect(); - if let Some(id) = fields_ids_map.id(attr) { - ids.insert(id); - } - } - ids - }; - let to_retrieve_ids: BTreeSet<_> = format - .attributes_to_retrieve - .as_ref() - .map(fids) - .unwrap_or_else(|| displayed_ids.clone()) - .intersection(&displayed_ids) - .cloned() - .collect(); + let 
attr_to_highlight = format.attributes_to_highlight.unwrap_or_default(); + let attr_to_crop = format.attributes_to_crop.unwrap_or_default(); + let formatted_options = compute_formatted_options( + &attr_to_highlight, + &attr_to_crop, + format.crop_length, + &to_retrieve_ids, + &fields_ids_map, + &displayed_ids, + ); - let attr_to_highlight = format.attributes_to_highlight.unwrap_or_default(); - let attr_to_crop = format.attributes_to_crop.unwrap_or_default(); - let formatted_options = compute_formatted_options( - &attr_to_highlight, - &attr_to_crop, - format.crop_length, - &to_retrieve_ids, - &fields_ids_map, - &displayed_ids, - ); - let mut tokenizer_builder = TokenizerBuilder::default(); - tokenizer_builder.create_char_map(true); - let script_lang_map = index.script_language(rtxn)?; - if !script_lang_map.is_empty() { - tokenizer_builder.allow_list(&script_lang_map); + let embedding_configs = index.embedding_configs(rtxn)?; + + Ok(Self { + index, + rtxn, + fields_ids_map, + displayed_ids, + vectors_fid, + retrieve_vectors, + to_retrieve_ids, + embedding_configs, + formatter_builder, + formatted_options, + show_ranking_score: format.show_ranking_score, + show_ranking_score_details: format.show_ranking_score_details, + show_matches_position: format.show_matches_position, + sort: format.sort, + }) } - let separators = index.allowed_separators(rtxn)?; - let separators: Option> = - separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); - if let Some(ref separators) = separators { - tokenizer_builder.separators(separators); - } - let dictionary = index.dictionary(rtxn)?; - let dictionary: Option> = - dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); - if let Some(ref dictionary) = dictionary { - tokenizer_builder.words_dict(dictionary); - } - let mut formatter_builder = MatcherBuilder::new(matching_words, tokenizer_builder.build()); - formatter_builder.crop_marker(format.crop_marker); - 
formatter_builder.highlight_prefix(format.highlight_pre_tag); - formatter_builder.highlight_suffix(format.highlight_post_tag); - let mut documents = Vec::new(); - let embedding_configs = index.embedding_configs(rtxn)?; - let documents_iter = index.documents(rtxn, documents_ids)?; - for ((id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { + + pub fn make_hit( + &self, + id: u32, + score: &[ScoreDetails], + ) -> Result { + let (_, obkv) = + self.index.iter_documents(self.rtxn, std::iter::once(id))?.next().unwrap()?; + // First generate a document with all the displayed fields - let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; + let displayed_document = make_document(&self.displayed_ids, &self.fields_ids_map, obkv)?; let add_vectors_fid = - vectors_fid.filter(|_fid| retrieve_vectors == RetrieveVectors::Retrieve); + self.vectors_fid.filter(|_fid| self.retrieve_vectors == RetrieveVectors::Retrieve); // select the attributes to retrieve - let attributes_to_retrieve = to_retrieve_ids + let attributes_to_retrieve = self + .to_retrieve_ids .iter() // skip the vectors_fid if RetrieveVectors::Hide - .filter(|fid| match vectors_fid { + .filter(|fid| match self.vectors_fid { Some(vectors_fid) => { - !(retrieve_vectors == RetrieveVectors::Hide && **fid == vectors_fid) + !(self.retrieve_vectors == RetrieveVectors::Hide && **fid == vectors_fid) } None => true, }) // need to retrieve the existing `_vectors` field if the `RetrieveVectors::Retrieve` .chain(add_vectors_fid.iter()) - .map(|&fid| fields_ids_map.name(fid).expect("Missing field name")); + .map(|&fid| self.fields_ids_map.name(fid).expect("Missing field name")); + let mut document = permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); - if retrieve_vectors == RetrieveVectors::Retrieve { + if self.retrieve_vectors == RetrieveVectors::Retrieve { let mut vectors = match document.remove("_vectors") { Some(Value::Object(map)) => map, _ 
=> Default::default(), }; - for (name, vector) in index.embeddings(rtxn, id)? { - let user_provided = embedding_configs + for (name, vector) in self.index.embeddings(self.rtxn, id)? { + let user_provided = self + .embedding_configs .iter() .find(|conf| conf.name == name) .is_some_and(|conf| conf.user_provided.contains(id)); @@ -1168,21 +1269,21 @@ fn make_hits( let (matches_position, formatted) = format_fields( &displayed_document, - &fields_ids_map, - &formatter_builder, - &formatted_options, - format.show_matches_position, - &displayed_ids, + &self.fields_ids_map, + &self.formatter_builder, + &self.formatted_options, + self.show_matches_position, + &self.displayed_ids, )?; - if let Some(sort) = format.sort.as_ref() { + if let Some(sort) = self.sort.as_ref() { insert_geo_distance(sort, &mut document); } let ranking_score = - format.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); + self.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); let ranking_score_details = - format.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter())); + self.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter())); let hit = SearchHit { document, @@ -1191,7 +1292,38 @@ fn make_hits( ranking_score_details, ranking_score, }; - documents.push(hit); + + Ok(hit) + } +} + +fn make_hits<'a>( + index: &Index, + rtxn: &RoTxn<'_>, + format: AttributesFormat, + matching_words: milli::MatchingWords, + documents_ids_scores: impl Iterator)> + 'a, +) -> Result, MeilisearchHttpError> { + let mut documents = Vec::new(); + + let script_lang_map = index.script_language(rtxn)?; + + let dictionary = index.dictionary(rtxn)?; + let dictionary: Option> = + dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); + let separators = index.allowed_separators(rtxn)?; + let separators: Option> = + separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); + + let tokenizer = + HitMaker::tokenizer(&script_lang_map, 
dictionary.as_deref(), separators.as_deref()); + + let formatter_builder = HitMaker::formatter_builder(matching_words, tokenizer); + + let hit_maker = HitMaker::new(index, rtxn, format, formatter_builder)?; + + for (id, score) in documents_ids_scores { + documents.push(hit_maker.make_hit(id, score)?); } Ok(documents) } @@ -1307,7 +1439,13 @@ pub fn perform_similar( show_ranking_score_details, }; - let hits = make_hits(index, &rtxn, format, Default::default(), documents_ids, document_scores)?; + let hits = make_hits( + index, + &rtxn, + format, + Default::default(), + documents_ids.iter().copied().zip(document_scores.iter()), + )?; let max_total_hits = index .pagination_max_total_hits(&rtxn) @@ -1480,10 +1618,10 @@ fn make_document( Ok(document) } -fn format_fields<'a>( +fn format_fields( document: &Document, field_ids_map: &FieldsIdsMap, - builder: &'a MatcherBuilder<'a>, + builder: &MatcherBuilder<'_>, formatted_options: &BTreeMap, compute_matches: bool, displayable_ids: &BTreeSet, @@ -1538,9 +1676,9 @@ fn format_fields<'a>( Ok((matches_position, document)) } -fn format_value<'a>( +fn format_value( value: Value, - builder: &'a MatcherBuilder<'a>, + builder: &MatcherBuilder<'_>, format_options: Option, infos: &mut Vec, compute_matches: bool, diff --git a/meilisearch/src/search/federated.rs b/meilisearch/src/search/federated.rs new file mode 100644 index 000000000..5547cf7ac --- /dev/null +++ b/meilisearch/src/search/federated.rs @@ -0,0 +1,521 @@ +use std::cmp::Ordering; +use std::collections::BTreeMap; +use std::fmt; +use std::iter::Zip; +use std::rc::Rc; +use std::time::Duration; +use std::vec::{IntoIter, Vec}; + +use http::StatusCode; +use index_scheduler::{IndexScheduler, RoFeatures}; +use meilisearch_types::deserr::DeserrJsonError; +use meilisearch_types::error::deserr_codes::{ + InvalidSearchLimit, InvalidSearchOffset, InvalidSearchWeight, +}; +use meilisearch_types::error::ResponseError; +use meilisearch_types::milli; +use 
meilisearch_types::milli::score_details::{ScoreDetails, ScoreValue}; +use meilisearch_types::milli::{DocumentId, TimeBudget}; +use roaring::RoaringBitmap; +use serde::Serialize; + +use super::{ + prepare_search, AttributesFormat, HitMaker, HitsInfo, RetrieveVectors, SearchHit, SearchKind, + SearchQuery, SearchQueryWithIndex, +}; +use crate::routes::indexes::search::search_kind; + +pub const DEFAULT_FEDERATED_WEIGHT: fn() -> f64 = || 1.0; + +#[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct Federated { + #[deserr(default = DEFAULT_FEDERATED_WEIGHT(), error = DeserrJsonError)] + pub weight: f64, +} + +impl Default for Federated { + fn default() -> Self { + Self { weight: DEFAULT_FEDERATED_WEIGHT() } + } +} + +#[derive(Debug, deserr::Deserr)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct Federation { + #[deserr(default = super::DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] + pub limit: usize, + #[deserr(default = super::DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] + pub offset: usize, +} + +#[derive(Debug, deserr::Deserr)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct FederatedSearch { + pub queries: Vec, + #[deserr(default)] + pub federation: Option, +} + +#[derive(Serialize, Clone, PartialEq)] +#[serde(rename_all = "camelCase")] +pub struct FederatedSearchResult { + pub hits: Vec, + pub processing_time_ms: u128, + #[serde(flatten)] + pub hits_info: HitsInfo, + + #[serde(skip_serializing_if = "Option::is_none")] + pub semantic_hit_count: Option, + + // These fields are only used for analytics purposes + #[serde(skip)] + pub degraded: bool, + #[serde(skip)] + pub used_negative_operator: bool, +} + +impl fmt::Debug for FederatedSearchResult { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let FederatedSearchResult { + hits, + processing_time_ms, + 
hits_info, + semantic_hit_count, + degraded, + used_negative_operator, + } = self; + + let mut debug = f.debug_struct("SearchResult"); + // The most important thing when looking at a search result is the time it took to process + debug.field("processing_time_ms", &processing_time_ms); + debug.field("hits", &format!("[{} hits returned]", hits.len())); + debug.field("hits_info", &hits_info); + if *used_negative_operator { + debug.field("used_negative_operator", used_negative_operator); + } + if *degraded { + debug.field("degraded", degraded); + } + if let Some(semantic_hit_count) = semantic_hit_count { + debug.field("semantic_hit_count", &semantic_hit_count); + } + + debug.finish() + } +} + +struct WeightedScore<'a> { + details: &'a [ScoreDetails], + weight: f64, +} + +impl<'a> WeightedScore<'a> { + pub fn new(details: &'a [ScoreDetails], weight: f64) -> Self { + Self { details, weight } + } + + pub fn weighted_global_score(&self) -> f64 { + ScoreDetails::global_score(self.details.iter()) * self.weight + } + + pub fn compare_weighted_global_scores(&self, other: &Self) -> Ordering { + self.weighted_global_score() + .partial_cmp(&other.weighted_global_score()) + // both are numbers, possibly infinite + .unwrap() + } + + pub fn compare(&self, other: &Self) -> Ordering { + let mut left_it = ScoreDetails::score_values(self.details.iter()); + let mut right_it = ScoreDetails::score_values(other.details.iter()); + + loop { + let left = left_it.next(); + let right = right_it.next(); + + match (left, right) { + (None, None) => return Ordering::Equal, + (None, Some(_)) => return Ordering::Less, + (Some(_), None) => return Ordering::Greater, + (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { + let left = left * self.weight; + let right = right * other.weight; + if (left - right).abs() <= f64::EPSILON { + continue; + } + return left.partial_cmp(&right).unwrap(); + } + (Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => { + match 
left.partial_cmp(right) { + Some(Ordering::Equal) => continue, + Some(order) => return order, + None => return self.compare_weighted_global_scores(other), + } + } + (Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => { + match left.partial_cmp(right) { + Some(Ordering::Equal) => continue, + Some(order) => return order, + None => { + return self.compare_weighted_global_scores(other); + } + } + } + // not comparable details, use global + (Some(ScoreValue::Score(_)), Some(_)) + | (Some(_), Some(ScoreValue::Score(_))) + | (Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_))) + | (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => { + return self.compare_weighted_global_scores(other); + } + } + } + } +} + +struct QueryByIndex { + query: SearchQuery, + federated: Federated, + query_index: usize, +} + +struct SearchResultByQuery<'a> { + documents_ids: Vec, + document_scores: Vec>, + federated: Federated, + hit_maker: HitMaker<'a>, + query_index: usize, +} + +struct SearchResultByQueryIter<'a> { + it: Zip, IntoIter>>, + federated: Federated, + hit_maker: Rc>, + query_index: usize, +} + +impl<'a> SearchResultByQueryIter<'a> { + fn new( + SearchResultByQuery { documents_ids, document_scores, federated, hit_maker, query_index }: SearchResultByQuery<'a>, + ) -> Self { + let it = documents_ids.into_iter().zip(document_scores); + Self { it, federated, hit_maker: Rc::new(hit_maker), query_index } + } +} + +struct SearchResultByQueryIterItem<'a> { + docid: DocumentId, + score: Vec, + federated: Federated, + hit_maker: Rc>, + query_index: usize, +} + +fn merge_index_local_results( + results_by_query: Vec>, +) -> impl Iterator + '_ { + itertools::kmerge_by( + results_by_query.into_iter().map(SearchResultByQueryIter::new), + |left: &SearchResultByQueryIterItem, right: &SearchResultByQueryIterItem| { + let left_score = WeightedScore::new(&left.score, left.federated.weight); + let right_score = WeightedScore::new(&right.score, right.federated.weight); + 
+ match left_score.compare(&right_score) { + // the biggest score goes first + Ordering::Greater => true, + // break ties using query index + Ordering::Equal => left.query_index < right.query_index, + Ordering::Less => false, + } + }, + ) +} + +fn merge_index_global_results( + results_by_index: Vec, +) -> impl Iterator { + itertools::kmerge_by( + results_by_index.into_iter().map(|result_by_index| result_by_index.hits.into_iter()), + |left: &SearchHitByIndex, right: &SearchHitByIndex| { + let left_score = WeightedScore::new(&left.score, left.federated.weight); + let right_score = WeightedScore::new(&right.score, right.federated.weight); + + match left_score.compare(&right_score) { + // the biggest score goes first + Ordering::Greater => true, + // break ties using query index + Ordering::Equal => left.query_index < right.query_index, + Ordering::Less => false, + } + }, + ) +} + +impl<'a> Iterator for SearchResultByQueryIter<'a> { + type Item = SearchResultByQueryIterItem<'a>; + + fn next(&mut self) -> Option { + let (docid, score) = self.it.next()?; + Some(SearchResultByQueryIterItem { + docid, + score, + federated: self.federated, + hit_maker: Rc::clone(&self.hit_maker), + query_index: self.query_index, + }) + } +} + +struct SearchHitByIndex { + hit: SearchHit, + score: Vec, + federated: Federated, + query_index: usize, +} + +struct SearchResultByIndex { + hits: Vec, + candidates: RoaringBitmap, + degraded: bool, + used_negative_operator: bool, +} + +pub fn perform_federated_search( + index_scheduler: &IndexScheduler, + queries: Vec, + federation: Federation, + features: RoFeatures, +) -> Result { + let before_search = std::time::Instant::now(); + + // this implementation partition the queries by index to guarantee an important property: + // - all the queries to a particular index use the same read transaction. + // This is an important property, otherwise we cannot guarantee the self-consistency of the results. + + // 1. 
partition queries by index + let mut queries_by_index: BTreeMap> = Default::default(); + for (query_index, federated_query) in queries.into_iter().enumerate() { + if let Some(pagination_field) = federated_query.has_pagination() { + /// FIXME: proper error + panic!("using pagination with a federated query") + } + + let (index_uid, query, federated) = federated_query.into_index_query_federated(); + + queries_by_index.entry(index_uid.into_inner()).or_default().push(QueryByIndex { + query, + federated: federated.unwrap_or_default(), + query_index, + }) + } + + // 2. perform queries, merge and make hits index by index + let required_hit_count = federation.limit + federation.offset; + // In step (2), semantic_hit_count will be set to Some(0) if any search kind uses semantic + // Then in step (3), we'll update its value if there is any semantic search + let mut semantic_hit_count = None; + let mut results_by_index = Vec::with_capacity(queries_by_index.len()); + for (index_uid, queries) in queries_by_index { + let index = index_scheduler.index(&index_uid).map_err(|err| { + let mut err = ResponseError::from(err); + // Patch the HTTP status code to 400 as it defaults to 404 for `index_not_found`, but + // here the resource not found is not part of the URL. 
+ err.code = StatusCode::BAD_REQUEST; + err + })?; + + // Important: this is the only transaction we'll use for this index during this federated search + let rtxn = index.read_txn()?; + + // stuff we need for the hitmaker + let script_lang_map = index.script_language(&rtxn)?; + + let dictionary = index.dictionary(&rtxn)?; + let dictionary: Option> = + dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); + let separators = index.allowed_separators(&rtxn)?; + let separators: Option> = + separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); + + // each query gets its individual cutoff + let cutoff = index.search_cutoff(&rtxn)?; + + let mut degraded = false; + let mut used_negative_operator = false; + let mut candidates = RoaringBitmap::new(); + + // 2.1. Compute all candidates for each query in the index + let mut results_by_query = Vec::with_capacity(queries.len()); + for QueryByIndex { query, federated, query_index } in queries { + // use an immediately invoked lambda to capture the result without returning from the function + + let res: Result<(), ResponseError> = (|| { + let search_kind = search_kind(&query, index_scheduler, &index, features)?; + + match search_kind { + SearchKind::KeywordOnly => {} + _ => semantic_hit_count = Some(0), + } + + let retrieve_vectors = RetrieveVectors::new(query.retrieve_vectors, features)?; + + let time_budget = match cutoff { + Some(cutoff) => TimeBudget::new(Duration::from_millis(cutoff)), + None => TimeBudget::default(), + }; + + let (mut search, _is_finite_pagination, _max_total_hits, _offset) = + prepare_search(&index, &rtxn, &query, &search_kind, time_budget)?; + + search.scoring_strategy(milli::score_details::ScoringStrategy::Detailed); + search.offset(0); + search.limit(required_hit_count); + + let (result, _semantic_hit_count) = super::search_from_kind(search_kind, search)?; + let format = AttributesFormat { + attributes_to_retrieve: query.attributes_to_retrieve, + retrieve_vectors, + 
attributes_to_highlight: query.attributes_to_highlight, + attributes_to_crop: query.attributes_to_crop, + crop_length: query.crop_length, + crop_marker: query.crop_marker, + highlight_pre_tag: query.highlight_pre_tag, + highlight_post_tag: query.highlight_post_tag, + show_matches_position: query.show_matches_position, + sort: query.sort, + show_ranking_score: query.show_ranking_score, + show_ranking_score_details: query.show_ranking_score_details, + }; + + let milli::SearchResult { + matching_words, + candidates: query_candidates, + documents_ids, + document_scores, + degraded: query_degraded, + used_negative_operator: query_used_negative_operator, + } = result; + + candidates |= query_candidates; + degraded |= query_degraded; + used_negative_operator |= query_used_negative_operator; + + let tokenizer = HitMaker::tokenizer( + &script_lang_map, + dictionary.as_deref(), + separators.as_deref(), + ); + + let formatter_builder = HitMaker::formatter_builder(matching_words, tokenizer); + + let hit_maker = HitMaker::new(&index, &rtxn, format, formatter_builder)?; + + results_by_query.push(SearchResultByQuery { + federated, + hit_maker, + query_index, + documents_ids, + document_scores, + }); + Ok(()) + })(); + + if let Err(mut error) = res { + error.message = format!("Inside `.queries[{query_index}]`: {}", error.message); + return Err(error); + } + } + // 2.2. merge inside index + let mut documents_seen = RoaringBitmap::new(); + let merged_result: Result, ResponseError> = + merge_index_local_results(results_by_query) + // skip documents we've already seen & mark that we saw the current document + .filter(|SearchResultByQueryIterItem { docid, .. 
}| documents_seen.insert(*docid)) + .take(required_hit_count) + // 2.3 make hits + .map( + |SearchResultByQueryIterItem { + docid, + score, + federated, + hit_maker, + query_index, + }| { + let mut hit = hit_maker.make_hit(docid, &score)?; + let weighted_score = + ScoreDetails::global_score(score.iter()) * federated.weight; + + let _federation = serde_json::json!( + { + "indexUid": index_uid, + "sourceQuery": query_index, + "weightedRankingScore": weighted_score, + } + ); + hit.document.insert("_federation".to_string(), _federation); + Ok(SearchHitByIndex { hit, score, federated, query_index }) + }, + ) + .collect(); + + let merged_result = merged_result?; + results_by_index.push(SearchResultByIndex { + hits: merged_result, + candidates, + degraded, + used_negative_operator, + }); + } + + // 3. merge hits and metadata across indexes + // 3.1 merge metadata + let (estimated_total_hits, degraded, used_negative_operator) = { + let mut estimated_total_hits = 0; + let mut degraded = false; + let mut used_negative_operator = false; + + for SearchResultByIndex { + hits: _, + candidates, + degraded: degraded_by_index, + used_negative_operator: used_negative_operator_by_index, + } in &results_by_index + { + estimated_total_hits += candidates.len() as usize; + degraded |= *degraded_by_index; + used_negative_operator |= *used_negative_operator_by_index; + } + + (estimated_total_hits, degraded, used_negative_operator) + }; + + // 3.2 merge hits + let merged_hits: Vec<_> = merge_index_global_results(results_by_index) + .skip(federation.offset) + .take(federation.limit) + .inspect(|hit| { + if let Some(semantic_hit_count) = &mut semantic_hit_count { + if hit.score.iter().any(|score| matches!(&score, ScoreDetails::Vector(_))) { + *semantic_hit_count += 1; + } + } + }) + .map(|hit| hit.hit) + .collect(); + + let search_result = FederatedSearchResult { + hits: merged_hits, + processing_time_ms: before_search.elapsed().as_millis(), + hits_info: HitsInfo::OffsetLimit { + limit: 
federation.limit, + offset: federation.offset, + estimated_total_hits, + }, + semantic_hit_count, + degraded, + used_negative_operator, + }; + + Ok(search_result) +} diff --git a/milli/src/score_details.rs b/milli/src/score_details.rs index 0a9b77e2b..1efa3b8e6 100644 --- a/milli/src/score_details.rs +++ b/milli/src/score_details.rs @@ -425,9 +425,6 @@ pub struct Sort { impl PartialOrd for Sort { fn partial_cmp(&self, other: &Self) -> Option { - if self.field_name != other.field_name { - return None; - } if self.ascending != other.ascending { return None; } @@ -466,9 +463,6 @@ pub struct GeoSort { impl PartialOrd for GeoSort { fn partial_cmp(&self, other: &Self) -> Option { - if self.target_point != other.target_point { - return None; - } if self.ascending != other.ascending { return None; } diff --git a/milli/src/search/federated.rs b/milli/src/search/federated.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/milli/src/search/federated.rs @@ -0,0 +1 @@ + diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs index 2102bf479..dce4329b7 100644 --- a/milli/src/search/hybrid.rs +++ b/milli/src/search/hybrid.rs @@ -174,6 +174,7 @@ impl<'a> Search<'a> { semantic: self.semantic.clone(), time_budget: self.time_budget.clone(), ranking_score_threshold: self.ranking_score_threshold, + input_candidates: self.input_candidates, }; let semantic = search.semantic.take(); diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 8ae1ebb0f..b50243d81 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -21,6 +21,7 @@ static LEVDIST1: Lazy = Lazy::new(|| LevBuilder::new(1, true)); static LEVDIST2: Lazy = Lazy::new(|| LevBuilder::new(2, true)); pub mod facet; +pub mod federated; mod fst_utils; pub mod hybrid; pub mod new; @@ -52,6 +53,7 @@ pub struct Search<'a> { semantic: Option, time_budget: TimeBudget, ranking_score_threshold: Option, + input_candidates: Option<&'a RoaringBitmap>, } impl<'a> Search<'a> { @@ -74,6 
+76,7 @@ impl<'a> Search<'a> { semantic: None, time_budget: TimeBudget::max(), ranking_score_threshold: None, + input_candidates: None, } } @@ -137,6 +140,11 @@ impl<'a> Search<'a> { self } + pub fn input_candidates(&mut self, input_candidates: &'a RoaringBitmap) -> &mut Search<'a> { + self.input_candidates = Some(input_candidates); + self + } + #[cfg(test)] pub fn geo_sort_strategy(&mut self, strategy: new::GeoSortStrategy) -> &mut Search<'a> { self.geo_strategy = strategy; @@ -163,7 +171,11 @@ impl<'a> Search<'a> { pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result { if has_vector_search { let ctx = SearchContext::new(self.index, self.rtxn)?; - filtered_universe(ctx.index, ctx.txn, &self.filter) + let filtered_universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?; + Ok(match self.input_candidates { + Some(input_candidates) => filtered_universe & input_candidates, + None => filtered_universe, + }) } else { Ok(self.execute()?.candidates) } @@ -189,7 +201,10 @@ impl<'a> Search<'a> { } } - let universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?; + let mut universe = filtered_universe(ctx.index, ctx.txn, &self.filter)?; + if let Some(input_candidates) = self.input_candidates { + universe &= input_candidates; + } let PartialSearchResult { located_query_terms, candidates, @@ -272,6 +287,7 @@ impl fmt::Debug for Search<'_> { semantic, time_budget, ranking_score_threshold, + input_candidates: _, } = self; f.debug_struct("Search") .field("query", query) diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index 77ae5fcd5..88f1a10f1 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -46,7 +46,7 @@ impl<'m> MatcherBuilder<'m> { self } - pub fn build<'t>(&'m self, text: &'t str) -> Matcher<'t, 'm> { + pub fn build<'t>(&self, text: &'t str) -> Matcher<'t, 'm, '_> { let crop_marker = match &self.crop_marker { Some(marker) => marker.as_str(), None => 
DEFAULT_CROP_MARKER, @@ -105,19 +105,19 @@ pub struct MatchBounds { pub length: usize, } -/// Structure used to analize a string, compute words that match, +/// Structure used to analyze a string, compute words that match, /// and format the source string, returning a highlighted and cropped sub-string. -pub struct Matcher<'t, 'm> { +pub struct Matcher<'t, 'tokenizer, 'b> { text: &'t str, - matching_words: &'m MatchingWords, - tokenizer: &'m Tokenizer<'m>, - crop_marker: &'m str, - highlight_prefix: &'m str, - highlight_suffix: &'m str, + matching_words: &'b MatchingWords, + tokenizer: &'b Tokenizer<'tokenizer>, + crop_marker: &'b str, + highlight_prefix: &'b str, + highlight_suffix: &'b str, matches: Option<(Vec>, Vec)>, } -impl<'t> Matcher<'t, '_> { +impl<'t, 'tokenizer> Matcher<'t, 'tokenizer, '_> { /// Iterates over tokens and save any of them that matches the query. fn compute_matches(&mut self) -> &mut Self { /// some words are counted as matches only if they are close together and in the good order,