From 90c0a6db7ddf5b0d8535e1a6a20987776f7b3e82 Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Tue, 23 Jul 2024 14:09:27 +0200 Subject: [PATCH] Implement localized search --- meilisearch-types/src/error.rs | 1 + meilisearch-types/src/lib.rs | 1 + meilisearch-types/src/locales.rs | 132 ++++++++++++++++++ .../src/analytics/segment_analytics.rs | 22 ++- meilisearch/src/routes/indexes/search.rs | 4 + meilisearch/src/search/federated.rs | 10 +- meilisearch/src/search/mod.rs | 61 ++++++-- milli/examples/search.rs | 1 + milli/src/search/facet/search.rs | 24 +++- milli/src/search/hybrid.rs | 1 + milli/src/search/mod.rs | 11 ++ milli/src/search/new/matches/mod.rs | 84 ++++++----- milli/src/search/new/mod.rs | 8 +- .../src/search/new/query_term/parse_query.rs | 2 +- 14 files changed, 292 insertions(+), 70 deletions(-) create mode 100644 meilisearch-types/src/locales.rs diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index d27d6cd3d..e56949b57 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -256,6 +256,7 @@ InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ; +InvalidSearchLocales , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; InvalidSimilarId , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; diff --git a/meilisearch-types/src/lib.rs b/meilisearch-types/src/lib.rs index e4f5cbeb4..d6049e667 100644 --- a/meilisearch-types/src/lib.rs +++ b/meilisearch-types/src/lib.rs @@ -7,6 +7,7 @@ pub mod features; pub mod index_uid; pub mod index_uid_pattern; pub mod keys; +pub mod locales; pub mod settings; pub mod star_or; pub mod task_view; diff --git a/meilisearch-types/src/locales.rs b/meilisearch-types/src/locales.rs new file mode 100644 index 000000000..14972fc33 --- /dev/null +++ b/meilisearch-types/src/locales.rs @@ -0,0 +1,132 @@ +use deserr::Deserr; +use serde::{Deserialize, Serialize}; +use serde_json::json; + +use milli::LocalizedAttributesRule; + +/// Generate a Locale enum and its From and Into implementations for milli::tokenizer::Language. +/// +/// this enum implements `Deserr` in order to be used in the API. +macro_rules! make_locale { + + ($($language:tt), +) => { + #[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr, Serialize, Deserialize, Ord, PartialOrd)] + #[deserr(rename_all = camelCase)] + #[serde(rename_all = "camelCase")] + pub enum Locale { + $($language),+, + } + + impl From for Locale { + fn from(other: milli::tokenizer::Language) -> Locale { + match other { + $(milli::tokenizer::Language::$language => Locale::$language), + + } + } + } + + impl From for milli::tokenizer::Language { + fn from(other: Locale) -> milli::tokenizer::Language { + match other { + $(Locale::$language => milli::tokenizer::Language::$language), +, + } + } + } + + #[derive(Debug)] + pub struct LocaleFormatError { + pub invalid_locale: String, + } + + impl std::fmt::Display for LocaleFormatError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let valid_locales = [$(Locale::$language),+].iter().map(|l| format!("`{}`", json!(l).as_str().unwrap())).collect::>().join(", "); + write!(f, "Unknown value `{}`, expected one of {}", self.invalid_locale, valid_locales) + } + } + + impl std::error::Error for LocaleFormatError {} + + impl std::str::FromStr for Locale { + type Err = LocaleFormatError; + + fn from_str(s: &str) -> Result { + milli::tokenizer::Language::from_code(s).map(Self::from).ok_or(LocaleFormatError { + invalid_locale: s.to_string(), + }) + } + } + }; +} + +make_locale! { + Epo, + Eng, + Rus, + Cmn, + Spa, + Por, + Ita, + Ben, + Fra, + Deu, + Ukr, + Kat, + Ara, + Hin, + Jpn, + Heb, + Yid, + Pol, + Amh, + Jav, + Kor, + Nob, + Dan, + Swe, + Fin, + Tur, + Nld, + Hun, + Ces, + Ell, + Bul, + Bel, + Mar, + Kan, + Ron, + Slv, + Hrv, + Srp, + Mkd, + Lit, + Lav, + Est, + Tam, + Vie, + Urd, + Tha, + Guj, + Uzb, + Pan, + Aze, + Ind, + Tel, + Pes, + Mal, + Ori, + Mya, + Nep, + Sin, + Khm, + Tuk, + Aka, + Zul, + Sna, + Afr, + Lat, + Slk, + Cat, + Tgl, + Hye +} diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index 487eaf003..407b90658 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -1,4 +1,4 @@ -use std::collections::{BinaryHeap, HashMap, HashSet}; +use std::collections::{BTreeSet, BinaryHeap, HashMap, HashSet}; use std::fs; use std::mem::take; use std::path::{Path, PathBuf}; @@ -10,6 +10,7 @@ use actix_web::HttpRequest; use byte_unit::Byte; use index_scheduler::IndexScheduler; use meilisearch_auth::{AuthController, AuthFilter}; +use meilisearch_types::locales::Locale; use meilisearch_types::InstanceUid; use once_cell::sync::Lazy; use regex::Regex; @@ -653,6 +654,9 @@ pub struct SearchAggregator { // every time a search is done, we increment the counter linked to the used settings matching_strategy: HashMap, + // List of the unique Locales passed as parameter + locales: BTreeSet, + // pagination max_limit: usize, max_offset: usize, @@ -707,6 +711,7 @@ impl SearchAggregator { attributes_to_search_on, hybrid, ranking_score_threshold, + locales, } = query; let mut ret = Self::default(); @@ -774,6 +779,10 @@ impl SearchAggregator { ret.matching_strategy.insert(format!("{:?}", matching_strategy), 1); + if let Some(locales) = locales { + ret.locales = locales.into_iter().copied().collect(); + } + ret.highlight_pre_tag = *highlight_pre_tag != DEFAULT_HIGHLIGHT_PRE_TAG(); ret.highlight_post_tag = *highlight_post_tag != DEFAULT_HIGHLIGHT_POST_TAG(); ret.crop_marker = *crop_marker != DEFAULT_CROP_MARKER(); @@ -859,6 +868,7 @@ impl SearchAggregator { total_degraded, total_used_negative_operator, ranking_score_threshold, + ref mut locales, } = other; if self.timestamp.is_none() { @@ -947,6 +957,9 @@ impl SearchAggregator { self.show_ranking_score |= show_ranking_score; self.show_ranking_score_details |= show_ranking_score_details; self.ranking_score_threshold |= ranking_score_threshold; + + // locales + self.locales.append(locales); } pub fn into_event(self, user: &User, event_name: &str) -> Option { @@ -991,6 +1004,7 @@ impl SearchAggregator { total_degraded, total_used_negative_operator, ranking_score_threshold, + locales, } = self; if total_received == 0 { @@ -1060,6 +1074,7 @@ impl SearchAggregator { "matching_strategy": { "most_used_strategy": matching_strategy.iter().max_by_key(|(_, v)| *v).map(|(k, _)| json!(k)).unwrap_or_else(|| json!(null)), }, + "locales": locales, "scoring": { "show_ranking_score": show_ranking_score, "show_ranking_score_details": show_ranking_score_details, @@ -1150,6 +1165,7 @@ impl MultiSearchAggregator { attributes_to_search_on: _, hybrid: _, ranking_score_threshold: _, + locales: _, } = query; index_uid.as_str() @@ -1307,6 +1323,7 @@ impl FacetSearchAggregator { attributes_to_search_on, hybrid, ranking_score_threshold, + locales, } = query; let mut ret = Self::default(); @@ -1322,7 +1339,8 @@ impl FacetSearchAggregator { || *matching_strategy != MatchingStrategy::default() || attributes_to_search_on.is_some() || hybrid.is_some() - || ranking_score_threshold.is_some(); + || ranking_score_threshold.is_some() + || locales.is_some(); ret } diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 836b96147..e60f95948 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -7,6 +7,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::locales::Locale; use meilisearch_types::milli; use meilisearch_types::serde_cs::vec::CS; use serde_json::Value; @@ -89,6 +90,8 @@ pub struct SearchQueryGet { pub hybrid_semantic_ratio: Option, #[deserr(default, error = DeserrQueryParamError)] pub ranking_score_threshold: Option, + #[deserr(default, error = DeserrQueryParamError)] + pub locales: Option>, } #[derive(Debug, Clone, Copy, PartialEq, deserr::Deserr)] @@ -175,6 +178,7 @@ impl From for SearchQuery { attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()), hybrid, ranking_score_threshold: other.ranking_score_threshold.map(|o| o.0), + locales: other.locales.map(|o| o.into_iter().collect()), } } } diff --git a/meilisearch/src/search/federated.rs b/meilisearch/src/search/federated.rs index 0c623d9cb..58005ec53 100644 --- a/meilisearch/src/search/federated.rs +++ b/meilisearch/src/search/federated.rs @@ -380,9 +380,6 @@ pub fn perform_federated_search( let criteria = index.criteria(&rtxn)?; - // stuff we need for the hitmaker - let script_lang_map = index.script_language(&rtxn)?; - let dictionary = index.dictionary(&rtxn)?; let dictionary: Option> = dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); @@ -494,6 +491,7 @@ pub fn perform_federated_search( sort: query.sort, show_ranking_score: query.show_ranking_score, show_ranking_score_details: query.show_ranking_score_details, + locales: query.locales.map(|l| l.iter().copied().map(Into::into).collect()), }; let milli::SearchResult { @@ -509,11 +507,7 @@ pub fn perform_federated_search( degraded |= query_degraded; used_negative_operator |= query_used_negative_operator; - let tokenizer = HitMaker::tokenizer( - &script_lang_map, - dictionary.as_deref(), - separators.as_deref(), - ); + let tokenizer = HitMaker::tokenizer(dictionary.as_deref(), separators.as_deref()); let formatter_builder = HitMaker::formatter_builder(matching_words, tokenizer); diff --git a/meilisearch/src/search/mod.rs b/meilisearch/src/search/mod.rs index 6624188ce..d28d888aa 100644 --- a/meilisearch/src/search/mod.rs +++ b/meilisearch/src/search/mod.rs @@ -1,6 +1,6 @@ use core::fmt; use std::cmp::min; -use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashSet}; use std::str::FromStr; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -15,16 +15,17 @@ use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::{Code, ResponseError}; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::locales::Locale; use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; use meilisearch_types::milli::vector::parsed_vectors::ExplicitVectors; use meilisearch_types::milli::vector::Embedder; use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, TimeBudget}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; -use milli::tokenizer::TokenizerBuilder; +use milli::tokenizer::{Language, TokenizerBuilder}; use milli::{ - AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, - SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, + AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, LocalizedAttributesRule, + MatchBounds, MatcherBuilder, SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, }; use regex::Regex; use serde::Serialize; @@ -100,6 +101,8 @@ pub struct SearchQuery { pub attributes_to_search_on: Option>, #[deserr(default, error = DeserrJsonError, default)] pub ranking_score_threshold: Option, + #[deserr(default, error = DeserrJsonError, default)] + pub locales: Option>, } #[derive(Debug, Clone, Copy, PartialEq, Deserr)] @@ -169,6 +172,7 @@ impl fmt::Debug for SearchQuery { matching_strategy, attributes_to_search_on, ranking_score_threshold, + locales, } = self; let mut debug = f.debug_struct("SearchQuery"); @@ -250,6 +254,10 @@ impl fmt::Debug for SearchQuery { debug.field("ranking_score_threshold", &ranking_score_threshold); } + if let Some(locales) = locales { + debug.field("locales", &locales); + } + debug.finish() } } @@ -425,6 +433,8 @@ pub struct SearchQueryWithIndex { pub attributes_to_search_on: Option>, #[deserr(default, error = DeserrJsonError, default)] pub ranking_score_threshold: Option, + #[deserr(default, error = DeserrJsonError, default)] + pub locales: Option>, #[deserr(default)] pub federation_options: Option, @@ -477,6 +487,7 @@ impl SearchQueryWithIndex { attributes_to_search_on, hybrid, ranking_score_threshold, + locales, } = self; ( index_uid, @@ -506,6 +517,7 @@ impl SearchQueryWithIndex { attributes_to_search_on, hybrid, ranking_score_threshold, + locales, // do not use ..Default::default() here, // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` }, @@ -866,6 +878,10 @@ fn prepare_search<'t>( search.sort_criteria(sort); } + if let Some(ref locales) = query.locales { + search.locales(locales.iter().copied().map(Into::into).collect()); + } + Ok((search, is_finite_pagination, max_total_hits, offset)) } @@ -917,6 +933,7 @@ pub fn perform_search( highlight_pre_tag, highlight_post_tag, crop_marker, + locales, // already used in prepare_search vector: _, hybrid: _, @@ -941,6 +958,7 @@ pub fn perform_search( sort, show_ranking_score, show_ranking_score_details, + locales: locales.map(|l| l.iter().copied().map(Into::into).collect()), }; let documents = make_hits( @@ -1046,6 +1064,7 @@ struct AttributesFormat { sort: Option>, show_ranking_score: bool, show_ranking_score_details: bool, + locales: Option>, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -1093,19 +1112,16 @@ struct HitMaker<'a> { show_ranking_score_details: bool, sort: Option>, show_matches_position: bool, + locales: Option>, } impl<'a> HitMaker<'a> { pub fn tokenizer<'b>( - script_lang_map: &'b HashMap>, dictionary: Option<&'b [&'b str]>, separators: Option<&'b [&'b str]>, ) -> milli::tokenizer::Tokenizer<'b> { let mut tokenizer_builder = TokenizerBuilder::default(); tokenizer_builder.create_char_map(true); - if !script_lang_map.is_empty() { - tokenizer_builder.allow_list(script_lang_map); - } if let Some(separators) = separators { tokenizer_builder.separators(separators); @@ -1218,6 +1234,7 @@ impl<'a> HitMaker<'a> { show_ranking_score_details: format.show_ranking_score_details, show_matches_position: format.show_matches_position, sort: format.sort, + locales: format.locales, }) } @@ -1280,6 +1297,7 @@ impl<'a> HitMaker<'a> { &self.formatted_options, self.show_matches_position, &self.displayed_ids, + self.locales.as_deref(), )?; if let Some(sort) = self.sort.as_ref() { @@ -1312,8 +1330,6 @@ fn make_hits<'a>( ) -> Result, MeilisearchHttpError> { let mut documents = Vec::new(); - let script_lang_map = index.script_language(rtxn)?; - let dictionary = index.dictionary(rtxn)?; let dictionary: Option> = dictionary.as_ref().map(|x| x.iter().map(String::as_str).collect()); @@ -1321,8 +1337,7 @@ fn make_hits<'a>( let separators: Option> = separators.as_ref().map(|x| x.iter().map(String::as_str).collect()); - let tokenizer = - HitMaker::tokenizer(&script_lang_map, dictionary.as_deref(), separators.as_deref()); + let tokenizer = HitMaker::tokenizer(dictionary.as_deref(), separators.as_deref()); let formatter_builder = HitMaker::formatter_builder(matching_words, tokenizer); @@ -1341,6 +1356,7 @@ pub fn perform_facet_search( facet_name: String, search_kind: SearchKind, features: RoFeatures, + locales: Option>, ) -> Result { let before_search = Instant::now(); let rtxn = index.read_txn()?; @@ -1363,6 +1379,10 @@ pub fn perform_facet_search( facet_search.max_values(max_facets as usize); } + if let Some(locales) = locales { + facet_search.locales(locales); + } + Ok(FacetSearchResult { facet_hits: facet_search.execute()?, facet_query, @@ -1443,6 +1463,7 @@ pub fn perform_similar( sort: None, show_ranking_score, show_ranking_score_details, + locales: None, }; let hits = make_hits( @@ -1631,6 +1652,7 @@ fn format_fields( formatted_options: &BTreeMap, compute_matches: bool, displayable_ids: &BTreeSet, + locales: Option<&[Language]>, ) -> Result<(Option, Document), MeilisearchHttpError> { let mut matches_position = compute_matches.then(BTreeMap::new); let mut document = document.clone(); @@ -1664,6 +1686,14 @@ fn format_fields( let mut infos = Vec::new(); *value = format_value(std::mem::take(value), builder, format, &mut infos, compute_matches); + *value = format_value( + std::mem::take(value), + builder, + format, + &mut infos, + compute_matches, + locales, + ); if let Some(matches) = matches_position.as_mut() { if !infos.is_empty() { @@ -1688,10 +1718,11 @@ fn format_value( format_options: Option, infos: &mut Vec, compute_matches: bool, + locales: Option<&[Language]>, ) -> Value { match value { Value::String(old_string) => { - let mut matcher = builder.build(&old_string); + let mut matcher = builder.build(&old_string, locales); if compute_matches { let matches = matcher.matches(); infos.extend_from_slice(&matches[..]); @@ -1718,6 +1749,7 @@ fn format_value( }), infos, compute_matches, + locales, ) }) .collect(), @@ -1737,6 +1769,7 @@ fn format_value( }), infos, compute_matches, + locales, ), ) }) @@ -1745,7 +1778,7 @@ fn format_value( Value::Number(number) => { let s = number.to_string(); - let mut matcher = builder.build(&s); + let mut matcher = builder.build(&s, locales); if compute_matches { let matches = matcher.matches(); infos.extend_from_slice(&matches[..]); diff --git a/milli/examples/search.rs b/milli/examples/search.rs index 87020994a..bb374f629 100644 --- a/milli/examples/search.rs +++ b/milli/examples/search.rs @@ -68,6 +68,7 @@ fn main() -> Result<(), Box> { logger, TimeBudget::max(), None, + None, )?; if let Some((logger, dir)) = detailed_logger { logger.finish(&mut ctx, Path::new(dir))?; diff --git a/milli/src/search/facet/search.rs b/milli/src/search/facet/search.rs index a6756a7af..6ef62e39a 100644 --- a/milli/src/search/facet/search.rs +++ b/milli/src/search/facet/search.rs @@ -3,7 +3,7 @@ use std::collections::BinaryHeap; use std::ops::ControlFlow; use charabia::normalizer::NormalizerOption; -use charabia::Normalize; +use charabia::{Language, Normalize, StrDetection, Token}; use fst::automaton::{Automaton, Str}; use fst::{IntoStreamer, Streamer}; use roaring::RoaringBitmap; @@ -23,6 +23,7 @@ pub struct SearchForFacetValues<'a> { search_query: Search<'a>, max_values: usize, is_hybrid: bool, + locales: Option>, } impl<'a> SearchForFacetValues<'a> { @@ -37,6 +38,7 @@ impl<'a> SearchForFacetValues<'a> { search_query, max_values: DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET, is_hybrid, + locales: None, } } @@ -50,6 +52,11 @@ impl<'a> SearchForFacetValues<'a> { self } + pub fn locales(&mut self, locales: Vec) -> &mut Self { + self.locales = Some(locales); + self + } + fn one_original_value_of( &self, field_id: FieldId, @@ -109,8 +116,7 @@ impl<'a> SearchForFacetValues<'a> { match self.query.as_ref() { Some(query) => { - let options = NormalizerOption { lossy: true, ..Default::default() }; - let query = query.normalize(&options); + let query = normalize_facet_string(query, self.locales.as_deref()); let query = query.as_ref(); let authorize_typos = self.search_query.index.authorize_typos(rtxn)?; @@ -330,3 +336,15 @@ impl ValuesCollection { } } } +fn normalize_facet_string(facet_string: &str, locales: Option<&[Language]>) -> String { + let options = NormalizerOption { lossy: true, ..Default::default() }; + let mut detection = StrDetection::new(facet_string, locales); + let token = Token { + lemma: std::borrow::Cow::Borrowed(facet_string), + script: detection.script(), + language: detection.language(), + ..Default::default() + }; + + token.normalize(&options).lemma.to_string() +} diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs index 2102bf479..e08111473 100644 --- a/milli/src/search/hybrid.rs +++ b/milli/src/search/hybrid.rs @@ -174,6 +174,7 @@ impl<'a> Search<'a> { semantic: self.semantic.clone(), time_budget: self.time_budget.clone(), ranking_score_threshold: self.ranking_score_threshold, + locales: self.locales.clone(), }; let semantic = search.semantic.take(); diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 2b2afa607..0f5eb23e1 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -1,6 +1,7 @@ use std::fmt; use std::sync::Arc; +use charabia::Language; use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; use once_cell::sync::Lazy; use roaring::bitmap::RoaringBitmap; @@ -52,6 +53,7 @@ pub struct Search<'a> { semantic: Option, time_budget: TimeBudget, ranking_score_threshold: Option, + locales: Option>, } impl<'a> Search<'a> { @@ -72,6 +74,7 @@ impl<'a> Search<'a> { rtxn, index, semantic: None, + locales: None, time_budget: TimeBudget::max(), ranking_score_threshold: None, } @@ -160,6 +163,11 @@ impl<'a> Search<'a> { self } + pub fn locales(&mut self, locales: Vec) -> &mut Search<'a> { + self.locales = Some(locales); + self + } + pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result { if has_vector_search { let ctx = SearchContext::new(self.index, self.rtxn)?; @@ -232,6 +240,7 @@ impl<'a> Search<'a> { &mut DefaultSearchLogger, self.time_budget.clone(), self.ranking_score_threshold, + self.locales.as_ref(), )?, }; @@ -272,6 +281,7 @@ impl fmt::Debug for Search<'_> { semantic, time_budget, ranking_score_threshold, + locales, } = self; f.debug_struct("Search") .field("query", query) @@ -292,6 +302,7 @@ impl fmt::Debug for Search<'_> { ) .field("time_budget", time_budget) .field("ranking_score_threshold", ranking_score_threshold) + .field("locales", locales) .finish() } } diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index 7bc4d9c5d..4688b8f32 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -1,6 +1,6 @@ use std::borrow::Cow; -use charabia::{SeparatorKind, Token, Tokenizer}; +use charabia::{Language, SeparatorKind, Token, Tokenizer}; pub use matching_words::MatchingWords; use matching_words::{MatchType, PartialMatch, WordId}; use serde::Serialize; @@ -46,7 +46,11 @@ impl<'m> MatcherBuilder<'m> { self } - pub fn build<'t>(&self, text: &'t str) -> Matcher<'t, 'm, '_> { + pub fn build<'t, 'lang>( + &self, + text: &'t str, + locales: Option<&'lang [Language]>, + ) -> Matcher<'t, 'm, '_, 'lang> { let crop_marker = match &self.crop_marker { Some(marker) => marker.as_str(), None => DEFAULT_CROP_MARKER, @@ -68,6 +72,7 @@ impl<'m> MatcherBuilder<'m> { highlight_prefix, highlight_suffix, matches: None, + locales, } } } @@ -107,17 +112,18 @@ pub struct MatchBounds { /// Structure used to analyze a string, compute words that match, /// and format the source string, returning a highlighted and cropped sub-string. -pub struct Matcher<'t, 'tokenizer, 'b> { +pub struct Matcher<'t, 'tokenizer, 'b, 'lang> { text: &'t str, matching_words: &'b MatchingWords, tokenizer: &'b Tokenizer<'tokenizer>, + locales: Option<&'lang [Language]>, crop_marker: &'b str, highlight_prefix: &'b str, highlight_suffix: &'b str, matches: Option<(Vec>, Vec)>, } -impl<'t, 'tokenizer> Matcher<'t, 'tokenizer, '_> { +impl<'t, 'tokenizer> Matcher<'t, 'tokenizer, '_, '_> { /// Iterates over tokens and save any of them that matches the query. fn compute_matches(&mut self) -> &mut Self { /// some words are counted as matches only if they are close together and in the good order, @@ -173,7 +179,8 @@ impl<'t, 'tokenizer> Matcher<'t, 'tokenizer, '_> { false } - let tokens: Vec<_> = self.tokenizer.tokenize(self.text).collect(); + let tokens: Vec<_> = + self.tokenizer.tokenize_with_allow_list(self.text, self.locales).collect(); let mut matches = Vec::new(); let mut words_positions = tokens @@ -530,6 +537,7 @@ mod tests { &mut crate::DefaultSearchLogger, TimeBudget::max(), None, + None, ) .unwrap(); @@ -553,19 +561,19 @@ mod tests { // Text without any match. let text = "A quick brown fox can not jump 32 feet, right? Brr, it is cold!"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no crop and no highlight should return complete text. assert_eq!(&matcher.format(format_options), &text); // Text containing all matches. let text = "Natalie risk her future to build a world with the boy she loves. Emily Henry: The Love That Split The World."; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no crop and no highlight should return complete text. assert_eq!(&matcher.format(format_options), &text); // Text containing some matches. let text = "Natalie risk her future to build a world with the boy she loves."; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no crop and no highlight should return complete text. assert_eq!(&matcher.format(format_options), &text); } @@ -580,23 +588,23 @@ mod tests { // empty text. let text = ""; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); assert_eq!(&matcher.format(format_options), ""); // text containing only separators. let text = ":-)"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); assert_eq!(&matcher.format(format_options), ":-)"); // Text without any match. let text = "A quick brown fox can not jump 32 feet, right? Brr, it is cold!"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no crop should return complete text, because there is no matches. assert_eq!(&matcher.format(format_options), &text); // Text containing all matches. let text = "Natalie risk her future to build a world with the boy she loves. Emily Henry: The Love That Split The World."; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no crop should return complete text with highlighted matches. insta::assert_snapshot!( matcher.format(format_options), @@ -605,7 +613,7 @@ mod tests { // Text containing some matches. let text = "Natalie risk her future to build a world with the boy she loves."; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no crop should return complete text with highlighted matches. insta::assert_snapshot!( matcher.format(format_options), @@ -622,7 +630,7 @@ mod tests { // Text containing prefix match. let text = "Ŵôřlḑôle"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no crop should return complete text with highlighted matches. insta::assert_snapshot!( matcher.format(format_options), @@ -631,7 +639,7 @@ mod tests { // Text containing unicode match. let text = "Ŵôřlḑ"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no crop should return complete text with highlighted matches. insta::assert_snapshot!( matcher.format(format_options), @@ -643,7 +651,7 @@ mod tests { // Text containing unicode match. let text = "Westfália"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no crop should return complete text with highlighted matches. insta::assert_snapshot!( matcher.format(format_options), @@ -661,7 +669,7 @@ mod tests { // empty text. let text = ""; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); insta::assert_snapshot!( matcher.format(format_options), @"" @@ -669,7 +677,7 @@ mod tests { // text containing only separators. let text = ":-)"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); insta::assert_snapshot!( matcher.format(format_options), @":-)" @@ -677,7 +685,7 @@ mod tests { // Text without any match. let text = "A quick brown fox can not jump 32 feet, right? Brr, it is cold!"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no highlight should return 10 first words with a marker at the end. insta::assert_snapshot!( matcher.format(format_options), @@ -686,7 +694,7 @@ mod tests { // Text without any match starting by a separator. let text = "(A quick brown fox can not jump 32 feet, right? Brr, it is cold!)"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no highlight should return 10 first words with a marker at the end. insta::assert_snapshot!( matcher.format(format_options), @@ -695,7 +703,7 @@ mod tests { // Test phrase propagation let text = "Natalie risk her future. Split The World is a book written by Emily Henry. I never read it."; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // should crop the phrase instead of croping around the match. insta::assert_snapshot!( matcher.format(format_options), @@ -704,7 +712,7 @@ mod tests { // Text containing some matches. let text = "Natalie risk her future to build a world with the boy she loves."; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no highlight should return 10 last words with a marker at the start. insta::assert_snapshot!( matcher.format(format_options), @@ -713,7 +721,7 @@ mod tests { // Text containing all matches. let text = "Natalie risk her future to build a world with the boy she loves. Emily Henry: The Love That Split The World."; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // no highlight should return 10 last words with a marker at the start. insta::assert_snapshot!( matcher.format(format_options), @@ -722,7 +730,7 @@ mod tests { // Text containing a match unordered and a match ordered. let text = "The world split void void void void void void void void void split the world void void"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // crop should return 10 last words with a marker at the start. insta::assert_snapshot!( matcher.format(format_options), @@ -731,7 +739,7 @@ mod tests { // Text containing matches with different density. let text = "split void the void void world void void void void void void void void void void split the world void void"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // crop should return 10 last words with a marker at the start. insta::assert_snapshot!( matcher.format(format_options), @@ -740,7 +748,7 @@ mod tests { // Text containing matches with same word. let text = "split split split split split split void void void void void void void void void void split the world void void"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // crop should return 10 last words with a marker at the start. insta::assert_snapshot!( matcher.format(format_options), @@ -758,7 +766,7 @@ mod tests { // empty text. let text = ""; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); insta::assert_snapshot!( matcher.format(format_options), @"" @@ -766,7 +774,7 @@ mod tests { // text containing only separators. let text = ":-)"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); insta::assert_snapshot!( matcher.format(format_options), @":-)" @@ -774,7 +782,7 @@ mod tests { // Text without any match. let text = "A quick brown fox can not jump 32 feet, right? Brr, it is cold!"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // both should return 10 first words with a marker at the end. insta::assert_snapshot!( matcher.format(format_options), @@ -783,7 +791,7 @@ mod tests { // Text containing some matches. let text = "Natalie risk her future to build a world with the boy she loves."; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // both should return 10 last words with a marker at the start and highlighted matches. insta::assert_snapshot!( matcher.format(format_options), @@ -792,7 +800,7 @@ mod tests { // Text containing all matches. let text = "Natalie risk her future to build a world with the boy she loves. Emily Henry: The Love That Split The World."; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // both should return 10 last words with a marker at the start and highlighted matches. insta::assert_snapshot!( matcher.format(format_options), @@ -801,7 +809,7 @@ mod tests { // Text containing a match unordered and a match ordered. let text = "The world split void void void void void void void void void split the world void void"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // crop should return 10 last words with a marker at the start. insta::assert_snapshot!( matcher.format(format_options), @@ -824,7 +832,7 @@ mod tests { let text = "The groundbreaking invention had the power to split the world between those who embraced progress and those who resisted change!"; let builder = MatcherBuilder::new_test(&rtxn, &temp_index, "\"the world\""); - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // should return 10 words with a marker at the start as well the end, and the highlighted matches. insta::assert_snapshot!( matcher.format(format_options), @@ -832,7 +840,7 @@ mod tests { ); let builder = MatcherBuilder::new_test(&rtxn, &temp_index, "those \"and those\""); - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // should highlight "those" and the phrase "and those". insta::assert_snapshot!( matcher.format(format_options), @@ -851,7 +859,7 @@ mod tests { // set a smaller crop size let format_options = FormatOptions { highlight: false, crop: Some(2) }; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // because crop size < query size, partially format matches. insta::assert_snapshot!( matcher.format(format_options), @@ -860,7 +868,7 @@ mod tests { // set a smaller crop size let format_options = FormatOptions { highlight: false, crop: Some(1) }; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // because crop size < query size, partially format matches. insta::assert_snapshot!( matcher.format(format_options), @@ -869,7 +877,7 @@ mod tests { // set crop size to 0 let format_options = FormatOptions { highlight: false, crop: Some(0) }; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); // because crop size is 0, crop is ignored. insta::assert_snapshot!( matcher.format(format_options), @@ -889,7 +897,7 @@ mod tests { let format_options = FormatOptions { highlight: true, crop: None }; let text = "the do or die can't be he do and or isn't he"; - let mut matcher = builder.build(text); + let mut matcher = builder.build(text, None); insta::assert_snapshot!( matcher.format(format_options), @"_the_ _do_ _or_ die can't be he do and or isn'_t_ _he_" diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 78b7a0446..577e12a39 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -24,7 +24,7 @@ mod tests; use std::collections::HashSet; use bucket_sort::{bucket_sort, BucketSortOutput}; -use charabia::TokenizerBuilder; +use charabia::{Language, TokenizerBuilder}; use db_cache::DatabaseCache; use exact_attribute::ExactAttribute; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; @@ -639,6 +639,7 @@ pub fn execute_search( query_graph_logger: &mut dyn SearchLogger, time_budget: TimeBudget, ranking_score_threshold: Option, + locales: Option<&Vec>, ) -> Result { check_sort_criteria(ctx, sort_criteria.as_ref())?; @@ -670,9 +671,8 @@ pub fn execute_search( tokbuilder.words_dict(dictionary); } - let languages = ctx.index.languages(ctx.txn)?; - if !languages.is_empty() { - tokbuilder.allow_list(&languages); + if let Some(locales) = locales { + tokbuilder.allow_list(locales); } let tokenizer = tokbuilder.build(); diff --git a/milli/src/search/new/query_term/parse_query.rs b/milli/src/search/new/query_term/parse_query.rs index d4c1c2f95..bb98f19ce 100644 --- a/milli/src/search/new/query_term/parse_query.rs +++ b/milli/src/search/new/query_term/parse_query.rs @@ -24,7 +24,7 @@ pub struct ExtractedTokens { #[tracing::instrument(level = "trace", skip_all, target = "search::query")] pub fn located_query_terms_from_tokens( ctx: &mut SearchContext<'_>, - query: NormalizedTokenIter<'_, '_>, + query: NormalizedTokenIter<'_, '_, '_, '_>, words_limit: Option, ) -> Result { let nbr_typos = number_of_typos_allowed(ctx)?;