diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 54dc6b0b7..2a684817a 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -1,6 +1,6 @@ use core::fmt; use std::cmp::min; -use std::collections::{BTreeMap, BTreeSet, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -913,8 +913,13 @@ pub fn perform_search( show_ranking_score_details, }; - let documents = - make_hits(index, &rtxn, format, matching_words, documents_ids, document_scores)?; + let documents = make_hits( + index, + &rtxn, + format, + matching_words, + documents_ids.iter().copied().zip(document_scores.iter()), + )?; let number_of_hits = min(candidates.len() as usize, max_total_hits); let hits_info = if is_finite_pagination { @@ -1043,131 +1048,191 @@ impl RetrieveVectors { } } -fn make_hits( - index: &Index, - rtxn: &RoTxn<'_>, - format: AttributesFormat, - matching_words: milli::MatchingWords, - documents_ids: Vec, - document_scores: Vec>, -) -> Result, MeilisearchHttpError> { - let fields_ids_map = index.fields_ids_map(rtxn).unwrap(); - let displayed_ids = - index.displayed_fields_ids(rtxn)?.map(|fields| fields.into_iter().collect::>()); +struct HitMaker<'a> { + index: &'a Index, + rtxn: &'a RoTxn<'a>, + fields_ids_map: FieldsIdsMap, + displayed_ids: BTreeSet, + vectors_fid: Option, + retrieve_vectors: RetrieveVectors, + to_retrieve_ids: BTreeSet, + embedding_configs: Vec, + formatter_builder: MatcherBuilder<'a>, + formatted_options: BTreeMap, + show_ranking_score: bool, + show_ranking_score_details: bool, + sort: Option>, + show_matches_position: bool, +} - let vectors_fid = fields_ids_map.id(milli::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME); +impl<'a> HitMaker<'a> { + pub fn tokenizer<'b>( + script_lang_map: &'b HashMap>, + dictionary: Option<&'b [&'b str]>, + separators: Option<&'b [&'b str]>, + ) -> milli::tokenizer::Tokenizer<'b> { + let mut tokenizer_builder = TokenizerBuilder::default(); + tokenizer_builder.create_char_map(true); + if !script_lang_map.is_empty() { + tokenizer_builder.allow_list(script_lang_map); + } - let vectors_is_hidden = match (&displayed_ids, vectors_fid) { - // displayed_ids is a wildcard, so `_vectors` can be displayed regardless of its fid - (None, _) => false, - // displayed_ids is a finite list, and `_vectors` cannot be part of it because it is not an existing field - (Some(_), None) => true, - // displayed_ids is a finit list, so hide if `_vectors` is not part of it - (Some(map), Some(vectors_fid)) => map.contains(&vectors_fid), - }; + if let Some(separators) = separators { + tokenizer_builder.separators(separators); + } - let retrieve_vectors = if let RetrieveVectors::Retrieve = format.retrieve_vectors { - if vectors_is_hidden { - RetrieveVectors::Hide + if let Some(dictionary) = dictionary { + tokenizer_builder.words_dict(dictionary); + } + + tokenizer_builder.into_tokenizer() + } + + pub fn formatter_builder( + matching_words: milli::MatchingWords, + tokenizer: milli::tokenizer::Tokenizer<'_>, + ) -> MatcherBuilder<'_> { + let formatter_builder = MatcherBuilder::new(matching_words, tokenizer); + + formatter_builder + } + + pub fn new( + index: &'a Index, + rtxn: &'a RoTxn<'a>, + format: AttributesFormat, + mut formatter_builder: MatcherBuilder<'a>, + ) -> Result { + formatter_builder.crop_marker(format.crop_marker); + formatter_builder.highlight_prefix(format.highlight_pre_tag); + formatter_builder.highlight_suffix(format.highlight_post_tag); + + let fields_ids_map = index.fields_ids_map(rtxn)?; + let displayed_ids = index + .displayed_fields_ids(rtxn)? + .map(|fields| fields.into_iter().collect::>()); + + let vectors_fid = + fields_ids_map.id(milli::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME); + + let vectors_is_hidden = match (&displayed_ids, vectors_fid) { + // displayed_ids is a wildcard, so `_vectors` can be displayed regardless of its fid + (None, _) => false, + // displayed_ids is a finite list, and `_vectors` cannot be part of it because it is not an existing field + (Some(_), None) => true, + // displayed_ids is a finit list, so hide if `_vectors` is not part of it + (Some(map), Some(vectors_fid)) => map.contains(&vectors_fid), + }; + + let displayed_ids = + displayed_ids.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); + + let retrieve_vectors = if let RetrieveVectors::Retrieve = format.retrieve_vectors { + if vectors_is_hidden { + RetrieveVectors::Hide + } else { + RetrieveVectors::Retrieve + } } else { - RetrieveVectors::Retrieve - } - } else { - format.retrieve_vectors - }; + format.retrieve_vectors + }; - let displayed_ids = - displayed_ids.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect()); - let fids = |attrs: &BTreeSet| { - let mut ids = BTreeSet::new(); - for attr in attrs { - if attr == "*" { - ids.clone_from(&displayed_ids); - break; + let fids = |attrs: &BTreeSet| { + let mut ids = BTreeSet::new(); + for attr in attrs { + if attr == "*" { + ids.clone_from(&displayed_ids); + break; + } + + if let Some(id) = fields_ids_map.id(attr) { + ids.insert(id); + } } + ids + }; + let to_retrieve_ids: BTreeSet<_> = format + .attributes_to_retrieve + .as_ref() + .map(fids) + .unwrap_or_else(|| displayed_ids.clone()) + .intersection(&displayed_ids) + .cloned() + .collect(); - if let Some(id) = fields_ids_map.id(attr) { - ids.insert(id); - } - } - ids - }; - let to_retrieve_ids: BTreeSet<_> = format - .attributes_to_retrieve - .as_ref() - .map(fids) - .unwrap_or_else(|| displayed_ids.clone()) - .intersection(&displayed_ids) - .cloned() - .collect(); + let attr_to_highlight = format.attributes_to_highlight.unwrap_or_default(); + let attr_to_crop = format.attributes_to_crop.unwrap_or_default(); + let formatted_options = compute_formatted_options( + &attr_to_highlight, + &attr_to_crop, + format.crop_length, + &to_retrieve_ids, + &fields_ids_map, + &displayed_ids, + ); - let attr_to_highlight = format.attributes_to_highlight.unwrap_or_default(); - let attr_to_crop = format.attributes_to_crop.unwrap_or_default(); - let formatted_options = compute_formatted_options( - &attr_to_highlight, - &attr_to_crop, - format.crop_length, - &to_retrieve_ids, - &fields_ids_map, - &displayed_ids, - ); - let mut tokenizer_builder = TokenizerBuilder::default(); - tokenizer_builder.create_char_map(true); - let script_lang_map = index.script_language(rtxn)?; - if !script_lang_map.is_empty() { - tokenizer_builder.allow_list(&script_lang_map); + let embedding_configs = index.embedding_configs(rtxn)?; + + Ok(Self { + index, + rtxn, + fields_ids_map, + displayed_ids, + vectors_fid, + retrieve_vectors, + to_retrieve_ids, + embedding_configs, + formatter_builder, + formatted_options, + show_ranking_score: format.show_ranking_score, + show_ranking_score_details: format.show_ranking_score_details, + show_matches_position: format.show_matches_position, + sort: format.sort, + }) } - let separators = index.allowed_separators(rtxn)?; - let separators: Option> = - separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); - if let Some(ref separators) = separators { - tokenizer_builder.separators(separators); - } - let dictionary = index.dictionary(rtxn)?; - let dictionary: Option> = - dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); - if let Some(ref dictionary) = dictionary { - tokenizer_builder.words_dict(dictionary); - } - let mut formatter_builder = MatcherBuilder::new(matching_words, tokenizer_builder.build()); - formatter_builder.crop_marker(format.crop_marker); - formatter_builder.highlight_prefix(format.highlight_pre_tag); - formatter_builder.highlight_suffix(format.highlight_post_tag); - let mut documents = Vec::new(); - let embedding_configs = index.embedding_configs(rtxn)?; - let documents_iter = index.documents(rtxn, documents_ids)?; - for ((id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { + + pub fn make_hit( + &self, + id: u32, + score: &[ScoreDetails], + ) -> Result { + let (_, obkv) = + self.index.iter_documents(self.rtxn, std::iter::once(id))?.next().unwrap()?; + // First generate a document with all the displayed fields - let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; + let displayed_document = make_document(&self.displayed_ids, &self.fields_ids_map, obkv)?; let add_vectors_fid = - vectors_fid.filter(|_fid| retrieve_vectors == RetrieveVectors::Retrieve); + self.vectors_fid.filter(|_fid| self.retrieve_vectors == RetrieveVectors::Retrieve); // select the attributes to retrieve - let attributes_to_retrieve = to_retrieve_ids + let attributes_to_retrieve = self + .to_retrieve_ids .iter() // skip the vectors_fid if RetrieveVectors::Hide - .filter(|fid| match vectors_fid { + .filter(|fid| match self.vectors_fid { Some(vectors_fid) => { - !(retrieve_vectors == RetrieveVectors::Hide && **fid == vectors_fid) + !(self.retrieve_vectors == RetrieveVectors::Hide && **fid == vectors_fid) } None => true, }) // need to retrieve the existing `_vectors` field if the `RetrieveVectors::Retrieve` .chain(add_vectors_fid.iter()) - .map(|&fid| fields_ids_map.name(fid).expect("Missing field name")); + .map(|&fid| self.fields_ids_map.name(fid).expect("Missing field name")); + let mut document = permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); - if retrieve_vectors == RetrieveVectors::Retrieve { + if self.retrieve_vectors == RetrieveVectors::Retrieve { // Clippy is wrong #[allow(clippy::manual_unwrap_or_default)] let mut vectors = match document.remove("_vectors") { Some(Value::Object(map)) => map, _ => Default::default(), }; - for (name, vector) in index.embeddings(rtxn, id)? { - let user_provided = embedding_configs + for (name, vector) in self.index.embeddings(self.rtxn, id)? { + let user_provided = self + .embedding_configs .iter() .find(|conf| conf.name == name) .is_some_and(|conf| conf.user_provided.contains(id)); @@ -1180,21 +1245,21 @@ fn make_hits( let (matches_position, formatted) = format_fields( &displayed_document, - &fields_ids_map, - &formatter_builder, - &formatted_options, - format.show_matches_position, - &displayed_ids, + &self.fields_ids_map, + &self.formatter_builder, + &self.formatted_options, + self.show_matches_position, + &self.displayed_ids, )?; - if let Some(sort) = format.sort.as_ref() { + if let Some(sort) = self.sort.as_ref() { insert_geo_distance(sort, &mut document); } let ranking_score = - format.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); + self.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); let ranking_score_details = - format.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter())); + self.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter())); let hit = SearchHit { document, @@ -1203,7 +1268,38 @@ fn make_hits( ranking_score_details, ranking_score, }; - documents.push(hit); + + Ok(hit) + } +} + +fn make_hits<'a>( + index: &Index, + rtxn: &RoTxn<'_>, + format: AttributesFormat, + matching_words: milli::MatchingWords, + documents_ids_scores: impl Iterator)> + 'a, +) -> Result, MeilisearchHttpError> { + let mut documents = Vec::new(); + + let script_lang_map = index.script_language(rtxn)?; + + let dictionary = index.dictionary(rtxn)?; + let dictionary: Option> = + dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); + let separators = index.allowed_separators(rtxn)?; + let separators: Option> = + separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); + + let tokenizer = + HitMaker::tokenizer(&script_lang_map, dictionary.as_deref(), separators.as_deref()); + + let formatter_builder = HitMaker::formatter_builder(matching_words, tokenizer); + + let hit_maker = HitMaker::new(index, rtxn, format, formatter_builder)?; + + for (id, score) in documents_ids_scores { + documents.push(hit_maker.make_hit(id, score)?); } Ok(documents) } @@ -1319,7 +1415,13 @@ pub fn perform_similar( show_ranking_score_details, }; - let hits = make_hits(index, &rtxn, format, Default::default(), documents_ids, document_scores)?; + let hits = make_hits( + index, + &rtxn, + format, + Default::default(), + documents_ids.iter().copied().zip(document_scores.iter()), + )?; let max_total_hits = index .pagination_max_total_hits(&rtxn) @@ -1492,10 +1594,10 @@ fn make_document( Ok(document) } -fn format_fields<'a>( +fn format_fields( document: &Document, field_ids_map: &FieldsIdsMap, - builder: &'a MatcherBuilder<'a>, + builder: &MatcherBuilder<'_>, formatted_options: &BTreeMap, compute_matches: bool, displayable_ids: &BTreeSet, @@ -1550,9 +1652,9 @@ fn format_fields<'a>( Ok((matches_position, document)) } -fn format_value<'a>( +fn format_value( value: Value, - builder: &'a MatcherBuilder<'a>, + builder: &MatcherBuilder<'_>, format_options: Option, infos: &mut Vec, compute_matches: bool,