search: introduce hitmaker

This commit is contained in:
Louis Dureuil 2024-07-11 16:35:59 +02:00
parent 2123d76089
commit d3a6d2a6fa
No known key found for this signature in database

View File

@ -1,6 +1,6 @@
use core::fmt;
use std::cmp::min;
use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
@ -913,8 +913,13 @@ pub fn perform_search(
show_ranking_score_details,
};
let documents =
make_hits(index, &rtxn, format, matching_words, documents_ids, document_scores)?;
let documents = make_hits(
index,
&rtxn,
format,
matching_words,
documents_ids.iter().copied().zip(document_scores.iter()),
)?;
let number_of_hits = min(candidates.len() as usize, max_total_hits);
let hits_info = if is_finite_pagination {
@ -1043,19 +1048,72 @@ impl RetrieveVectors {
}
}
fn make_hits(
index: &Index,
rtxn: &RoTxn<'_>,
format: AttributesFormat,
matching_words: milli::MatchingWords,
documents_ids: Vec<u32>,
document_scores: Vec<Vec<ScoreDetails>>,
) -> Result<Vec<SearchHit>, MeilisearchHttpError> {
let fields_ids_map = index.fields_ids_map(rtxn).unwrap();
let displayed_ids =
index.displayed_fields_ids(rtxn)?.map(|fields| fields.into_iter().collect::<BTreeSet<_>>());
struct HitMaker<'a> {
index: &'a Index,
rtxn: &'a RoTxn<'a>,
fields_ids_map: FieldsIdsMap,
displayed_ids: BTreeSet<FieldId>,
vectors_fid: Option<FieldId>,
retrieve_vectors: RetrieveVectors,
to_retrieve_ids: BTreeSet<FieldId>,
embedding_configs: Vec<milli::index::IndexEmbeddingConfig>,
formatter_builder: MatcherBuilder<'a>,
formatted_options: BTreeMap<FieldId, FormatOptions>,
show_ranking_score: bool,
show_ranking_score_details: bool,
sort: Option<Vec<String>>,
show_matches_position: bool,
}
let vectors_fid = fields_ids_map.id(milli::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME);
impl<'a> HitMaker<'a> {
pub fn tokenizer<'b>(
script_lang_map: &'b HashMap<milli::tokenizer::Script, Vec<milli::tokenizer::Language>>,
dictionary: Option<&'b [&'b str]>,
separators: Option<&'b [&'b str]>,
) -> milli::tokenizer::Tokenizer<'b> {
let mut tokenizer_builder = TokenizerBuilder::default();
tokenizer_builder.create_char_map(true);
if !script_lang_map.is_empty() {
tokenizer_builder.allow_list(script_lang_map);
}
if let Some(separators) = separators {
tokenizer_builder.separators(separators);
}
if let Some(dictionary) = dictionary {
tokenizer_builder.words_dict(dictionary);
}
tokenizer_builder.into_tokenizer()
}
pub fn formatter_builder(
matching_words: milli::MatchingWords,
tokenizer: milli::tokenizer::Tokenizer<'_>,
) -> MatcherBuilder<'_> {
let formatter_builder = MatcherBuilder::new(matching_words, tokenizer);
formatter_builder
}
pub fn new(
index: &'a Index,
rtxn: &'a RoTxn<'a>,
format: AttributesFormat,
mut formatter_builder: MatcherBuilder<'a>,
) -> Result<Self, MeilisearchHttpError> {
formatter_builder.crop_marker(format.crop_marker);
formatter_builder.highlight_prefix(format.highlight_pre_tag);
formatter_builder.highlight_suffix(format.highlight_post_tag);
let fields_ids_map = index.fields_ids_map(rtxn)?;
let displayed_ids = index
.displayed_fields_ids(rtxn)?
.map(|fields| fields.into_iter().collect::<BTreeSet<_>>());
let vectors_fid =
fields_ids_map.id(milli::vector::parsed_vectors::RESERVED_VECTORS_FIELD_NAME);
let vectors_is_hidden = match (&displayed_ids, vectors_fid) {
// displayed_ids is a wildcard, so `_vectors` can be displayed regardless of its fid
@ -1066,6 +1124,9 @@ fn make_hits(
(Some(map), Some(vectors_fid)) => map.contains(&vectors_fid),
};
let displayed_ids =
displayed_ids.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect());
let retrieve_vectors = if let RetrieveVectors::Retrieve = format.retrieve_vectors {
if vectors_is_hidden {
RetrieveVectors::Hide
@ -1076,8 +1137,6 @@ fn make_hits(
format.retrieve_vectors
};
let displayed_ids =
displayed_ids.unwrap_or_else(|| fields_ids_map.iter().map(|(id, _)| id).collect());
let fids = |attrs: &BTreeSet<String>| {
let mut ids = BTreeSet::new();
for attr in attrs {
@ -1111,63 +1170,69 @@ fn make_hits(
&fields_ids_map,
&displayed_ids,
);
let mut tokenizer_builder = TokenizerBuilder::default();
tokenizer_builder.create_char_map(true);
let script_lang_map = index.script_language(rtxn)?;
if !script_lang_map.is_empty() {
tokenizer_builder.allow_list(&script_lang_map);
}
let separators = index.allowed_separators(rtxn)?;
let separators: Option<Vec<_>> =
separators.as_ref().map(|x| x.iter().map(String::as_str).collect());
if let Some(ref separators) = separators {
tokenizer_builder.separators(separators);
}
let dictionary = index.dictionary(rtxn)?;
let dictionary: Option<Vec<_>> =
dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect());
if let Some(ref dictionary) = dictionary {
tokenizer_builder.words_dict(dictionary);
}
let mut formatter_builder = MatcherBuilder::new(matching_words, tokenizer_builder.build());
formatter_builder.crop_marker(format.crop_marker);
formatter_builder.highlight_prefix(format.highlight_pre_tag);
formatter_builder.highlight_suffix(format.highlight_post_tag);
let mut documents = Vec::new();
let embedding_configs = index.embedding_configs(rtxn)?;
let documents_iter = index.documents(rtxn, documents_ids)?;
for ((id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) {
Ok(Self {
index,
rtxn,
fields_ids_map,
displayed_ids,
vectors_fid,
retrieve_vectors,
to_retrieve_ids,
embedding_configs,
formatter_builder,
formatted_options,
show_ranking_score: format.show_ranking_score,
show_ranking_score_details: format.show_ranking_score_details,
show_matches_position: format.show_matches_position,
sort: format.sort,
})
}
pub fn make_hit(
&self,
id: u32,
score: &[ScoreDetails],
) -> Result<SearchHit, MeilisearchHttpError> {
let (_, obkv) =
self.index.iter_documents(self.rtxn, std::iter::once(id))?.next().unwrap()?;
// First generate a document with all the displayed fields
let displayed_document = make_document(&displayed_ids, &fields_ids_map, obkv)?;
let displayed_document = make_document(&self.displayed_ids, &self.fields_ids_map, obkv)?;
let add_vectors_fid =
vectors_fid.filter(|_fid| retrieve_vectors == RetrieveVectors::Retrieve);
self.vectors_fid.filter(|_fid| self.retrieve_vectors == RetrieveVectors::Retrieve);
// select the attributes to retrieve
let attributes_to_retrieve = to_retrieve_ids
let attributes_to_retrieve = self
.to_retrieve_ids
.iter()
// skip the vectors_fid if RetrieveVectors::Hide
.filter(|fid| match vectors_fid {
.filter(|fid| match self.vectors_fid {
Some(vectors_fid) => {
!(retrieve_vectors == RetrieveVectors::Hide && **fid == vectors_fid)
!(self.retrieve_vectors == RetrieveVectors::Hide && **fid == vectors_fid)
}
None => true,
})
// need to retrieve the existing `_vectors` field if the `RetrieveVectors::Retrieve`
.chain(add_vectors_fid.iter())
.map(|&fid| fields_ids_map.name(fid).expect("Missing field name"));
.map(|&fid| self.fields_ids_map.name(fid).expect("Missing field name"));
let mut document =
permissive_json_pointer::select_values(&displayed_document, attributes_to_retrieve);
if retrieve_vectors == RetrieveVectors::Retrieve {
if self.retrieve_vectors == RetrieveVectors::Retrieve {
// Clippy is wrong
#[allow(clippy::manual_unwrap_or_default)]
let mut vectors = match document.remove("_vectors") {
Some(Value::Object(map)) => map,
_ => Default::default(),
};
for (name, vector) in index.embeddings(rtxn, id)? {
let user_provided = embedding_configs
for (name, vector) in self.index.embeddings(self.rtxn, id)? {
let user_provided = self
.embedding_configs
.iter()
.find(|conf| conf.name == name)
.is_some_and(|conf| conf.user_provided.contains(id));
@ -1180,21 +1245,21 @@ fn make_hits(
let (matches_position, formatted) = format_fields(
&displayed_document,
&fields_ids_map,
&formatter_builder,
&formatted_options,
format.show_matches_position,
&displayed_ids,
&self.fields_ids_map,
&self.formatter_builder,
&self.formatted_options,
self.show_matches_position,
&self.displayed_ids,
)?;
if let Some(sort) = format.sort.as_ref() {
if let Some(sort) = self.sort.as_ref() {
insert_geo_distance(sort, &mut document);
}
let ranking_score =
format.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
self.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
let ranking_score_details =
format.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter()));
self.show_ranking_score_details.then(|| ScoreDetails::to_json_map(score.iter()));
let hit = SearchHit {
document,
@ -1203,7 +1268,38 @@ fn make_hits(
ranking_score_details,
ranking_score,
};
documents.push(hit);
Ok(hit)
}
}
fn make_hits<'a>(
index: &Index,
rtxn: &RoTxn<'_>,
format: AttributesFormat,
matching_words: milli::MatchingWords,
documents_ids_scores: impl Iterator<Item = (u32, &'a Vec<ScoreDetails>)> + 'a,
) -> Result<Vec<SearchHit>, MeilisearchHttpError> {
let mut documents = Vec::new();
let script_lang_map = index.script_language(rtxn)?;
let dictionary = index.dictionary(rtxn)?;
let dictionary: Option<Vec<_>> =
dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect());
let separators = index.allowed_separators(rtxn)?;
let separators: Option<Vec<_>> =
separators.as_ref().map(|x| x.iter().map(String::as_str).collect());
let tokenizer =
HitMaker::tokenizer(&script_lang_map, dictionary.as_deref(), separators.as_deref());
let formatter_builder = HitMaker::formatter_builder(matching_words, tokenizer);
let hit_maker = HitMaker::new(index, rtxn, format, formatter_builder)?;
for (id, score) in documents_ids_scores {
documents.push(hit_maker.make_hit(id, score)?);
}
Ok(documents)
}
@ -1319,7 +1415,13 @@ pub fn perform_similar(
show_ranking_score_details,
};
let hits = make_hits(index, &rtxn, format, Default::default(), documents_ids, document_scores)?;
let hits = make_hits(
index,
&rtxn,
format,
Default::default(),
documents_ids.iter().copied().zip(document_scores.iter()),
)?;
let max_total_hits = index
.pagination_max_total_hits(&rtxn)
@ -1492,10 +1594,10 @@ fn make_document(
Ok(document)
}
fn format_fields<'a>(
fn format_fields(
document: &Document,
field_ids_map: &FieldsIdsMap,
builder: &'a MatcherBuilder<'a>,
builder: &MatcherBuilder<'_>,
formatted_options: &BTreeMap<FieldId, FormatOptions>,
compute_matches: bool,
displayable_ids: &BTreeSet<FieldId>,
@ -1550,9 +1652,9 @@ fn format_fields<'a>(
Ok((matches_position, document))
}
fn format_value<'a>(
fn format_value(
value: Value,
builder: &'a MatcherBuilder<'a>,
builder: &MatcherBuilder<'_>,
format_options: Option<FormatOptions>,
infos: &mut Vec<MatchBounds>,
compute_matches: bool,