search: introduce hitmaker

This commit is contained in:
Louis Dureuil 2024-07-11 16:35:59 +02:00
parent 2123d76089
commit d3a6d2a6fa
No known key found for this signature in database

View File

@ -1,6 +1,6 @@
use core::fmt; use core::fmt;
use std::cmp::min; use std::cmp::min;
use std::collections::{BTreeMap, BTreeSet, HashSet}; use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::str::FromStr; use std::str::FromStr;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
@ -913,8 +913,13 @@ pub fn perform_search(
show_ranking_score_details, show_ranking_score_details,
}; };
let documents = let documents = make_hits(
make_hits(index, &rtxn, format, matching_words, documents_ids, document_scores)?; index,
&rtxn,
format,
matching_words,
documents_ids.iter().copied().zip(document_scores.iter()),
)?;
let number_of_hits = min(candidates.len() as usize, max_total_hits); let number_of_hits = min(candidates.len() as usize, max_total_hits);
let hits_info = if is_finite_pagination { let hits_info = if is_finite_pagination {
@ -1043,19 +1048,72 @@ impl RetrieveVectors {
} }
} }
fn make_hits( struct HitMaker<'a> {
index: &Index, index: &'a Index,
rtxn: &RoTxn<'_>, rtxn: &'a RoTxn<'a>,
format: AttributesFormat, fields_ids_map: FieldsIdsMap,
matching_words: milli::MatchingWords, displayed_ids: BTreeSet<FieldId>,
documents_ids: Vec<u32>, vectors_fid: Option<FieldId>,
document_scores: Vec<Vec<ScoreDetails>>, retrieve_vectors: RetrieveVectors,
) -> Result<Vec<SearchHit>, MeilisearchHttpError> { to_retrieve_ids: BTreeSet<FieldId>,
let fields_ids_map = index.fields_ids_map(rtxn).unwrap(); embedding_configs: Vec<milli::index::IndexEmbeddingConfig>,
let displayed_ids = formatter_builder: MatcherBuilder<'a>,
index.displayed_fields_ids(rtxn)?.map(|fields| fields.into_iter().collect::<BTreeSet<_>>()); formatted_options: BTreeMap<FieldId, FormatOptions>,
show_ranking_score: bool,
show_ranking_score_details: bool,
sort: Option<Vec<String>>,
show_matches_position: bool,
}
let vectors_fid = fields_ids_map.id(milli::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME); impl<'a> HitMaker<'a> {
pub fn tokenizer<'b>(
script_lang_map: &'b HashMap<milli::tokenizer::Script, Vec<milli::tokenizer::Language>>,
dictionary: Option<&'b [&'b str]>,
separators: Option<&'b [&'b str]>,
) -> milli::tokenizer::Tokenizer<'b> {
let mut tokenizer_builder = TokenizerBuilder::default();
tokenizer_builder.create_char_map(true);
if !script_lang_map.is_empty() {
tokenizer_builder.allow_list(script_lang_map);
}
if let Some(separators) = separators {
tokenizer_builder.separators(separators);
}
if let Some(dictionary) = dictionary {
tokenizer_builder.words_dict(dictionary);
}
tokenizer_builder.into_tokenizer()
}
pub fn formatter_builder(
matching_words: milli::MatchingWords,
tokenizer: milli::tokenizer::Tokenizer<'_>,
) -> MatcherBuilder<'_> {
let formatter_builder = MatcherBuilder::new(matching_words, tokenizer);
formatter_builder
}
pub fn new(
index: &'a Index,
rtxn: &'a RoTxn<'a>,
format: AttributesFormat,
mut formatter_builder: MatcherBuilder<'a>,
) -> Result<Self, MeilisearchHttpError> {
formatter_builder.crop_marker(format.crop_marker);
formatter_builder.highlight_prefix(format.highlight_pre_tag);
formatter_builder.highlight_suffix(format.highlight_post_tag);
let fields_ids_map = index.fields_ids_map(rtxn)?;
let displayed_ids = index
.displayed_fields_ids(rtxn)?
.map(|fields| fields.into_iter().collect::<BTreeSet<_>>());
let vectors_fid =
fields_ids_map.id(milli::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME);
let vectors_is_hidden = match (&displayed_ids, vectors_fid) { let vectors_is_hidden = match (&displayed_ids, vectors_fid) {
// displayed_ids is a wildcard, so `_vectors` can be displayed regardless of its fid // displayed_ids is a wildcard, so `_vectors` can be displayed regardless of its fid
@ -1066,6 +1124,9 @@ fn make_hits(
(Some(map), Some(vectors_fid)) => map.contains(&vectors_fid), (Some(map), Some(vectors_fid)) => map.contains(&vectors_fid),
}; };
let displayed_ids =
displayed_ids.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect());
let retrieve_vectors = if let RetrieveVectors::Retrieve = format.retrieve_vectors { let retrieve_vectors = if let RetrieveVectors::Retrieve = format.retrieve_vectors {
if vectors_is_hidden { if vectors_is_hidden {
RetrieveVectors::Hide RetrieveVectors::Hide
@ -1076,8 +1137,6 @@ fn make_hits(
format.retrieve_vectors format.retrieve_vectors
}; };
let displayed_ids =
displayed_ids.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect());
let fids = |attrs: &BTreeSet<String>| { let fids = |attrs: &BTreeSet<String>| {
let mut ids = BTreeSet::new(); let mut ids = BTreeSet::new();
for attr in attrs { for attr in attrs {
@ -1111,63 +1170,69 @@ fn make_hits(
&fields_ids_map, &fields_ids_map,
&displayed_ids, &displayed_ids,
); );
let mut tokenizer_builder = TokenizerBuilder::default();
tokenizer_builder.create_char_map(true);
let script_lang_map = index.script_language(rtxn)?;
if !script_lang_map.is_empty() {
tokenizer_builder.allow_list(&script_lang_map);
}
let separators = index.allowed_separators(rtxn)?;
let separators: Option<Vec<_>> =
separators.as_ref().map(|x| x.iter().map(String::as_str).collect());
if let Some(ref separators) = separators {
tokenizer_builder.separators(separators);
}
let dictionary = index.dictionary(rtxn)?;
let dictionary: Option<Vec<_>> =
dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect());
if let Some(ref dictionary) = dictionary {
tokenizer_builder.words_dict(dictionary);
}
let mut formatter_builder = MatcherBuilder::new(matching_words, tokenizer_builder.build());
formatter_builder.crop_marker(format.crop_marker);
formatter_builder.highlight_prefix(format.highlight_pre_tag);
formatter_builder.highlight_suffix(format.highlight_post_tag);
let mut documents = Vec::new();
let embedding_configs = index.embedding_configs(rtxn)?; let embedding_configs = index.embedding_configs(rtxn)?;
let documents_iter = index.documents(rtxn, documents_ids)?;
for ((id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { Ok(Self {
index,
rtxn,
fields_ids_map,
displayed_ids,
vectors_fid,
retrieve_vectors,
to_retrieve_ids,
embedding_configs,
formatter_builder,
formatted_options,
show_ranking_score: format.show_ranking_score,
show_ranking_score_details: format.show_ranking_score_details,
show_matches_position: format.show_matches_position,
sort: format.sort,
})
}
pub fn make_hit(
&self,
id: u32,
score: &[ScoreDetails],
) -> Result<SearchHit, MeilisearchHttpError> {
let (_, obkv) =
self.index.iter_documents(self.rtxn, std::iter::once(id))?.next().unwrap()?;
// First generate a document with all the displayed fields // First generate a document with all the displayed fields
let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?; let displayed_document = make_document(&self.displayed_ids, &self.fields_ids_map, obkv)?;
let add_vectors_fid = let add_vectors_fid =
vectors_fid.filter(|_fid| retrieve_vectors == RetrieveVectors::Retrieve); self.vectors_fid.filter(|_fid| self.retrieve_vectors == RetrieveVectors::Retrieve);
// select the attributes to retrieve // select the attributes to retrieve
let attributes_to_retrieve = to_retrieve_ids let attributes_to_retrieve = self
.to_retrieve_ids
.iter() .iter()
// skip the vectors_fid if RetrieveVectors::Hide // skip the vectors_fid if RetrieveVectors::Hide
.filter(|fid| match vectors_fid { .filter(|fid| match self.vectors_fid {
Some(vectors_fid) => { Some(vectors_fid) => {
!(retrieve_vectors == RetrieveVectors::Hide && **fid == vectors_fid) !(self.retrieve_vectors == RetrieveVectors::Hide && **fid == vectors_fid)
} }
None => true, None => true,
}) })
// need to retrieve the existing `_vectors` field if the `RetrieveVectors::Retrieve` // need to retrieve the existing `_vectors` field if the `RetrieveVectors::Retrieve`
.chain(add_vectors_fid.iter()) .chain(add_vectors_fid.iter())
.map(|&fid| fields_ids_map.name(fid).expect("Missing field name")); .map(|&fid| self.fields_ids_map.name(fid).expect("Missing field name"));
let mut document = let mut document =
permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve); permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve);
if retrieve_vectors == RetrieveVectors::Retrieve { if self.retrieve_vectors == RetrieveVectors::Retrieve {
// Clippy is wrong // Clippy is wrong
#[allow(clippy::manual_unwrap_or_default)] #[allow(clippy::manual_unwrap_or_default)]
let mut vectors = match document.remove("_vectors") { let mut vectors = match document.remove("_vectors") {
Some(Value::Object(map)) => map, Some(Value::Object(map)) => map,
_ => Default::default(), _ => Default::default(),
}; };
for (name, vector) in index.embeddings(rtxn, id)? { for (name, vector) in self.index.embeddings(self.rtxn, id)? {
let user_provided = embedding_configs let user_provided = self
.embedding_configs
.iter() .iter()
.find(|conf| conf.name == name) .find(|conf| conf.name == name)
.is_some_and(|conf| conf.user_provided.contains(id)); .is_some_and(|conf| conf.user_provided.contains(id));
@ -1180,21 +1245,21 @@ fn make_hits(
let (matches_position, formatted) = format_fields( let (matches_position, formatted) = format_fields(
&displayed_document, &displayed_document,
&fields_ids_map, &self.fields_ids_map,
&formatter_builder, &self.formatter_builder,
&formatted_options, &self.formatted_options,
format.show_matches_position, self.show_matches_position,
&displayed_ids, &self.displayed_ids,
)?; )?;
if let Some(sort) = format.sort.as_ref() { if let Some(sort) = self.sort.as_ref() {
insert_geo_distance(sort, &mut document); insert_geo_distance(sort, &mut document);
} }
let ranking_score = let ranking_score =
format.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); self.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
let ranking_score_details = let ranking_score_details =
format.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter())); self.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter()));
let hit = SearchHit { let hit = SearchHit {
document, document,
@ -1203,7 +1268,38 @@ fn make_hits(
ranking_score_details, ranking_score_details,
ranking_score, ranking_score,
}; };
documents.push(hit);
Ok(hit)
}
}
fn make_hits<'a>(
index: &Index,
rtxn: &RoTxn<'_>,
format: AttributesFormat,
matching_words: milli::MatchingWords,
documents_ids_scores: impl Iterator<Item = (u32, &'a Vec<ScoreDetails>)> + 'a,
) -> Result<Vec<SearchHit>, MeilisearchHttpError> {
let mut documents = Vec::new();
let script_lang_map = index.script_language(rtxn)?;
let dictionary = index.dictionary(rtxn)?;
let dictionary: Option<Vec<_>> =
dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect());
let separators = index.allowed_separators(rtxn)?;
let separators: Option<Vec<_>> =
separators.as_ref().map(|x| x.iter().map(String::as_str).collect());
let tokenizer =
HitMaker::tokenizer(&script_lang_map, dictionary.as_deref(), separators.as_deref());
let formatter_builder = HitMaker::formatter_builder(matching_words, tokenizer);
let hit_maker = HitMaker::new(index, rtxn, format, formatter_builder)?;
for (id, score) in documents_ids_scores {
documents.push(hit_maker.make_hit(id, score)?);
} }
Ok(documents) Ok(documents)
} }
@ -1319,7 +1415,13 @@ pub fn perform_similar(
show_ranking_score_details, show_ranking_score_details,
}; };
let hits = make_hits(index, &rtxn, format, Default::default(), documents_ids, document_scores)?; let hits = make_hits(
index,
&rtxn,
format,
Default::default(),
documents_ids.iter().copied().zip(document_scores.iter()),
)?;
let max_total_hits = index let max_total_hits = index
.pagination_max_total_hits(&rtxn) .pagination_max_total_hits(&rtxn)
@ -1492,10 +1594,10 @@ fn make_document(
Ok(document) Ok(document)
} }
fn format_fields<'a>( fn format_fields(
document: &Document, document: &Document,
field_ids_map: &FieldsIdsMap, field_ids_map: &FieldsIdsMap,
builder: &'a MatcherBuilder<'a>, builder: &MatcherBuilder<'_>,
formatted_options: &BTreeMap<FieldId, FormatOptions>, formatted_options: &BTreeMap<FieldId, FormatOptions>,
compute_matches: bool, compute_matches: bool,
displayable_ids: &BTreeSet<FieldId>, displayable_ids: &BTreeSet<FieldId>,
@ -1550,9 +1652,9 @@ fn format_fields<'a>(
Ok((matches_position, document)) Ok((matches_position, document))
} }
fn format_value<'a>( fn format_value(
value: Value, value: Value,
builder: &'a MatcherBuilder<'a>, builder: &MatcherBuilder<'_>,
format_options: Option<FormatOptions>, format_options: Option<FormatOptions>,
infos: &mut Vec<MatchBounds>, infos: &mut Vec<MatchBounds>,
compute_matches: bool, compute_matches: bool,