diff --git a/benchmarks/benches/utils.rs b/benchmarks/benches/utils.rs index 8c556b383..a240ce299 100644 --- a/benchmarks/benches/utils.rs +++ b/benchmarks/benches/utils.rs @@ -11,7 +11,7 @@ use milli::heed::EnvOpenOptions; use milli::update::{ IndexDocuments, IndexDocumentsConfig, IndexDocumentsMethod, IndexerConfig, Settings, }; -use milli::{Filter, Index, Object}; +use milli::{Filter, Index, Object, TermsMatchingStrategy}; use serde_json::Value; pub struct Conf<'a> { @@ -119,7 +119,7 @@ pub fn run_benches(c: &mut criterion::Criterion, confs: &[Conf]) { b.iter(|| { let rtxn = index.read_txn().unwrap(); let mut search = index.search(&rtxn); - search.query(query).optional_words(conf.optional_words); + search.query(query).terms_matching_strategy(TermsMatchingStrategy::default()); if let Some(filter) = conf.filter { let filter = Filter::from_str(filter).unwrap().unwrap(); search.filter(filter); diff --git a/milli/src/lib.rs b/milli/src/lib.rs index ac88ebdab..517d28ccc 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -42,7 +42,7 @@ pub use self::heed_codec::{ pub use self::index::Index; pub use self::search::{ FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWord, - MatchingWords, Search, SearchResult, DEFAULT_VALUES_PER_FACET, + MatchingWords, Search, SearchResult, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, }; pub type Result = std::result::Result; diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 1930091ef..7145c1445 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -44,7 +44,7 @@ pub struct Search<'a> { offset: usize, limit: usize, sort_criteria: Option>, - optional_words: bool, + terms_matching_strategy: TermsMatchingStrategy, authorize_typos: bool, words_limit: usize, rtxn: &'a heed::RoTxn<'a>, @@ -59,7 +59,7 @@ impl<'a> Search<'a> { offset: 0, limit: 20, sort_criteria: None, - optional_words: true, + terms_matching_strategy: TermsMatchingStrategy::default(), authorize_typos: true, words_limit: 10, rtxn, @@ -87,8 +87,8 @@ impl<'a> Search<'a> { self } - pub fn optional_words(&mut self, value: bool) -> &mut Search<'a> { - self.optional_words = value; + pub fn terms_matching_strategy(&mut self, value: TermsMatchingStrategy) -> &mut Search<'a> { + self.terms_matching_strategy = value; self } @@ -119,7 +119,7 @@ impl<'a> Search<'a> { let (query_tree, primitive_query, matching_words) = match self.query.as_ref() { Some(query) => { let mut builder = QueryTreeBuilder::new(self.rtxn, self.index)?; - builder.optional_words(self.optional_words); + builder.terms_matching_strategy(self.terms_matching_strategy); builder.authorize_typos(self.is_typo_authorized()?); @@ -259,7 +259,7 @@ impl fmt::Debug for Search<'_> { offset, limit, sort_criteria, - optional_words, + terms_matching_strategy, authorize_typos, words_limit, rtxn: _, @@ -271,7 +271,7 @@ impl fmt::Debug for Search<'_> { .field("offset", offset) .field("limit", limit) .field("sort_criteria", sort_criteria) - .field("optional_words", optional_words) + .field("terms_matching_strategy", terms_matching_strategy) .field("authorize_typos", authorize_typos) .field("words_limit", words_limit) .finish() @@ -286,6 +286,28 @@ pub struct SearchResult { pub documents_ids: Vec, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum TermsMatchingStrategy { + // remove last word first + Last, + // remove first word first + First, + // remove more frequent word first + Frequency, + // remove smallest word first + Size, + // only one of the word is mandatory + Any, + // all words are mandatory + All, +} + +impl Default for TermsMatchingStrategy { + fn default() -> Self { + Self::Last + } +} + pub type WordDerivationsCache = HashMap<(String, bool, u8), Vec<(String, u8)>>; pub fn word_derivations<'c>( diff --git a/milli/src/search/query_tree.rs b/milli/src/search/query_tree.rs index 617d9e4d9..51774d8b4 100644 --- a/milli/src/search/query_tree.rs +++ b/milli/src/search/query_tree.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::cmp::min; use std::{cmp, fmt, mem}; use charabia::classifier::ClassifiedTokenIter; @@ -8,6 +9,7 @@ use roaring::RoaringBitmap; use slice_group_by::GroupBy; use crate::search::matches::matching_words::{MatchingWord, PrimitiveWordId}; +use crate::search::TermsMatchingStrategy; use crate::{Index, MatchingWords, Result}; type IsOptionalWord = bool; @@ -62,6 +64,13 @@ impl Operation { if ops.len() == 1 { ops.pop().unwrap() } else { + let ops = ops + .into_iter() + .flat_map(|o| match o { + Operation::Or(wb, children) if wb == word_branch => children, + op => vec![op], + }) + .collect(); Self::Or(word_branch, ops) } } @@ -153,7 +162,7 @@ trait Context { pub struct QueryTreeBuilder<'a> { rtxn: &'a heed::RoTxn<'a>, index: &'a Index, - optional_words: bool, + terms_matching_strategy: TermsMatchingStrategy, authorize_typos: bool, words_limit: Option, exact_words: Option>>, @@ -190,19 +199,22 @@ impl<'a> QueryTreeBuilder<'a> { Ok(Self { rtxn, index, - optional_words: true, + terms_matching_strategy: TermsMatchingStrategy::default(), authorize_typos: true, words_limit: None, exact_words: index.exact_words(rtxn)?, }) } - /// if `optional_words` is set to `false` the query tree will be + /// if `terms_matching_strategy` is set to `All` the query tree will be /// generated forcing all query words to be present in each matching documents /// (the criterion `words` will be ignored). - /// default value if not called: `true` - pub fn optional_words(&mut self, optional_words: bool) -> &mut Self { - self.optional_words = optional_words; + /// default value if not called: `Last` + pub fn terms_matching_strategy( + &mut self, + terms_matching_strategy: TermsMatchingStrategy, + ) -> &mut Self { + self.terms_matching_strategy = terms_matching_strategy; self } @@ -223,7 +235,7 @@ impl<'a> QueryTreeBuilder<'a> { } /// Build the query tree: - /// - if `optional_words` is set to `false` the query tree will be + /// - if `terms_matching_strategy` is set to `All` the query tree will be /// generated forcing all query words to be present in each matching documents /// (the criterion `words` will be ignored) /// - if `authorize_typos` is set to `false` the query tree will be generated @@ -238,7 +250,7 @@ impl<'a> QueryTreeBuilder<'a> { if !primitive_query.is_empty() { let qt = create_query_tree( self, - self.optional_words, + self.terms_matching_strategy, self.authorize_typos, &primitive_query, )?; @@ -323,7 +335,7 @@ fn synonyms(ctx: &impl Context, word: &[&str]) -> heed::Result Result { @@ -363,6 +375,7 @@ fn create_query_tree( ctx: &impl Context, authorize_typos: bool, query: &[PrimitiveQueryPart], + any_words: bool, ) -> Result { const MAX_NGRAM: usize = 3; let mut op_children = Vec::new(); @@ -415,57 +428,93 @@ fn create_query_tree( } if !is_last { - let ngrams = ngrams(ctx, authorize_typos, tail)?; + let ngrams = ngrams(ctx, authorize_typos, tail, any_words)?; and_op_children.push(ngrams); } - or_op_children.push(Operation::and(and_op_children)); + + if any_words { + or_op_children.push(Operation::or(false, and_op_children)); + } else { + or_op_children.push(Operation::and(and_op_children)); + } } } op_children.push(Operation::or(false, or_op_children)); } - Ok(Operation::and(op_children)) - } - - /// Create a new branch removing the last non-phrase query parts. - fn optional_word( - ctx: &impl Context, - authorize_typos: bool, - query: PrimitiveQuery, - ) -> Result { - let number_phrases = query.iter().filter(|p| p.is_phrase()).count(); - let mut operation_children = Vec::new(); - - let start = number_phrases + (number_phrases == 0) as usize; - for len in start..=query.len() { - let mut word_count = len - number_phrases; - let query: Vec<_> = query - .iter() - .filter(|p| { - if p.is_phrase() { - true - } else if word_count != 0 { - word_count -= 1; - true - } else { - false - } - }) - .cloned() - .collect(); - - let ngrams = ngrams(ctx, authorize_typos, &query)?; - operation_children.push(ngrams); + if any_words { + Ok(Operation::or(false, op_children)) + } else { + Ok(Operation::and(op_children)) } - - Ok(Operation::or(true, operation_children)) } - if optional_words { - optional_word(ctx, authorize_typos, query.to_vec()) - } else { - ngrams(ctx, authorize_typos, query) + let number_phrases = query.iter().filter(|p| p.is_phrase()).count(); + let remove_count = query.len() - min(number_phrases, 1); + if remove_count == 0 { + return ngrams(ctx, authorize_typos, query, false); } + + let mut operation_children = Vec::new(); + let mut query = query.to_vec(); + for _ in 0..remove_count { + let pos = match terms_matching_strategy { + TermsMatchingStrategy::All => return ngrams(ctx, authorize_typos, &query, false), + TermsMatchingStrategy::Any => { + let operation = Operation::Or( + true, + vec![ + // branch allowing matching documents to contains any query word. + ngrams(ctx, authorize_typos, &query, true)?, + // branch forcing matching documents to contains all the query words, + // keeping this documents of the top of the resulted list. + ngrams(ctx, authorize_typos, &query, false)?, + ], + ); + + return Ok(operation); + } + TermsMatchingStrategy::Last => query + .iter() + .enumerate() + .filter(|(_, part)| !part.is_phrase()) + .last() + .map(|(pos, _)| pos), + TermsMatchingStrategy::First => { + query.iter().enumerate().find(|(_, part)| !part.is_phrase()).map(|(pos, _)| pos) + } + TermsMatchingStrategy::Size => query + .iter() + .enumerate() + .filter(|(_, part)| !part.is_phrase()) + .min_by_key(|(_, part)| match part { + PrimitiveQueryPart::Word(s, _) => s.len(), + _ => unreachable!(), + }) + .map(|(pos, _)| pos), + TermsMatchingStrategy::Frequency => query + .iter() + .enumerate() + .filter(|(_, part)| !part.is_phrase()) + .max_by_key(|(_, part)| match part { + PrimitiveQueryPart::Word(s, _) => { + ctx.word_documents_count(s).unwrap_or_default().unwrap_or(u64::max_value()) + } + _ => unreachable!(), + }) + .map(|(pos, _)| pos), + }; + + // compute and push the current branch on the front + operation_children.insert(0, ngrams(ctx, authorize_typos, &query, false)?); + // remove word from query before creating an new branch + match pos { + Some(pos) => query.remove(pos), + None => break, + }; + } + + Ok(Operation::or(true, operation_children)) } /// Main function that matchings words used for crop and highlight. @@ -750,15 +799,19 @@ mod test { impl TestContext { fn build>( &self, - optional_words: bool, + terms_matching_strategy: TermsMatchingStrategy, authorize_typos: bool, words_limit: Option, query: ClassifiedTokenIter, ) -> Result> { let primitive_query = create_primitive_query(query, None, words_limit); if !primitive_query.is_empty() { - let qt = - create_query_tree(self, optional_words, authorize_typos, &primitive_query)?; + let qt = create_query_tree( + self, + terms_matching_strategy, + authorize_typos, + &primitive_query, + )?; Ok(Some((qt, primitive_query))) } else { Ok(None) @@ -852,8 +905,10 @@ mod test { let query = "hey friends"; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(false, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::All, true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" OR @@ -869,8 +924,10 @@ mod test { let query = "hey friends "; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(false, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::All, true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" OR @@ -886,8 +943,10 @@ mod test { let query = "hello world "; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(false, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::All, true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" OR @@ -911,8 +970,10 @@ mod test { let query = "new york city "; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(false, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::All, true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" OR @@ -932,12 +993,11 @@ mod test { Exact { word: "city" } Tolerant { word: "newyork", max typo: 1 } Exact { word: "city" } - OR - Exact { word: "nyc" } - AND - Exact { word: "new" } - Exact { word: "york" } - Tolerant { word: "newyorkcity", max typo: 1 } + Exact { word: "nyc" } + AND + Exact { word: "new" } + Exact { word: "york" } + Tolerant { word: "newyorkcity", max typo: 1 } "###); } @@ -946,8 +1006,10 @@ mod test { let query = "n grams "; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(false, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::All, true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" OR @@ -963,8 +1025,10 @@ mod test { let query = "wordsplit fish "; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(false, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::All, true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" OR @@ -982,8 +1046,10 @@ mod test { let query = "\"hey friends\" \" \" \"wooop"; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(false, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::All, true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" AND @@ -997,8 +1063,10 @@ mod test { let query = "\"hey friends. wooop wooop\""; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(false, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::All, true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" AND @@ -1012,8 +1080,10 @@ mod test { let query = "hey my friend "; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(true, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::default(), true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" OR(WORD) @@ -1043,8 +1113,10 @@ mod test { let query = "\"hey my\""; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(true, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::default(), true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" PHRASE ["hey", "my"] @@ -1056,8 +1128,10 @@ mod test { let query = r#""hey" my good "friend""#; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(true, true, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::default(), true, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" OR(WORD) @@ -1084,8 +1158,10 @@ mod test { let query = "hey friends "; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(false, false, None, tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::All, false, None, tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" OR @@ -1101,8 +1177,10 @@ mod test { let query = "\"hey my\" good friend"; let tokens = query.tokenize(); - let (query_tree, _) = - TestContext::default().build(false, false, Some(2), tokens).unwrap().unwrap(); + let (query_tree, _) = TestContext::default() + .build(TermsMatchingStrategy::All, false, Some(2), tokens) + .unwrap() + .unwrap(); insta::assert_debug_snapshot!(query_tree, @r###" AND @@ -1145,7 +1223,8 @@ mod test { let exact_words = fst::Set::from_iter(Some("goodbye")).unwrap().into_fst().into_inner(); let exact_words = Some(fst::Set::new(exact_words).unwrap().map_data(Cow::Owned).unwrap()); let context = TestContext { exact_words, ..Default::default() }; - let (query_tree, _) = context.build(false, true, Some(2), tokens).unwrap().unwrap(); + let (query_tree, _) = + context.build(TermsMatchingStrategy::All, true, Some(2), tokens).unwrap().unwrap(); assert!(matches!( query_tree, diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index f5e04435d..23618b478 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -613,6 +613,7 @@ mod tests { use super::*; use crate::documents::documents_batch_reader_from_objects; use crate::index::tests::TempIndex; + use crate::search::TermsMatchingStrategy; use crate::update::DeleteDocuments; use crate::BEU16; @@ -1207,7 +1208,7 @@ mod tests { let mut search = crate::Search::new(&rtxn, &index); search.query("document"); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); // all documents should be returned let crate::SearchResult { documents_ids, .. } = search.execute().unwrap(); assert_eq!(documents_ids.len(), 4); @@ -1313,7 +1314,7 @@ mod tests { let mut search = crate::Search::new(&rtxn, &index); search.query("document"); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); // all documents should be returned let crate::SearchResult { documents_ids, .. } = search.execute().unwrap(); assert_eq!(documents_ids.len(), 4); @@ -1512,7 +1513,7 @@ mod tests { let mut search = crate::Search::new(&rtxn, &index); search.query("化妆包"); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); // only 1 document should be returned let crate::SearchResult { documents_ids, .. } = search.execute().unwrap(); diff --git a/milli/tests/search/distinct.rs b/milli/tests/search/distinct.rs index 022724fde..64dd16f09 100644 --- a/milli/tests/search/distinct.rs +++ b/milli/tests/search/distinct.rs @@ -2,7 +2,7 @@ use std::collections::HashSet; use big_s::S; use milli::update::Settings; -use milli::{Criterion, Search, SearchResult}; +use milli::{Criterion, Search, SearchResult, TermsMatchingStrategy}; use Criterion::*; use crate::search::{self, EXTERNAL_DOCUMENTS_IDS}; @@ -28,24 +28,25 @@ macro_rules! test_distinct { search.query(search::TEST_QUERY); search.limit(EXTERNAL_DOCUMENTS_IDS.len()); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let SearchResult { documents_ids, candidates, .. } = search.execute().unwrap(); assert_eq!(candidates.len(), $n_res); let mut distinct_values = HashSet::new(); - let expected_external_ids: Vec<_> = search::expected_order(&criteria, true, true, &[]) - .into_iter() - .filter_map(|d| { - if distinct_values.contains(&d.$distinct) { - None - } else { - distinct_values.insert(d.$distinct.to_owned()); - Some(d.id) - } - }) - .collect(); + let expected_external_ids: Vec<_> = + search::expected_order(&criteria, true, TermsMatchingStrategy::default(), &[]) + .into_iter() + .filter_map(|d| { + if distinct_values.contains(&d.$distinct) { + None + } else { + distinct_values.insert(d.$distinct.to_owned()); + Some(d.id) + } + }) + .collect(); let documents_ids = search::internal_to_external_ids(&index, &documents_ids); assert_eq!(documents_ids, expected_external_ids); diff --git a/milli/tests/search/filters.rs b/milli/tests/search/filters.rs index 5451a9076..18de24ac3 100644 --- a/milli/tests/search/filters.rs +++ b/milli/tests/search/filters.rs @@ -1,5 +1,5 @@ use either::{Either, Left, Right}; -use milli::{Criterion, Filter, Search, SearchResult}; +use milli::{Criterion, Filter, Search, SearchResult, TermsMatchingStrategy}; use Criterion::*; use crate::search::{self, EXTERNAL_DOCUMENTS_IDS}; @@ -19,16 +19,17 @@ macro_rules! test_filter { search.query(search::TEST_QUERY); search.limit(EXTERNAL_DOCUMENTS_IDS.len()); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); search.filter(filter_conditions); let SearchResult { documents_ids, .. } = search.execute().unwrap(); let filtered_ids = search::expected_filtered_ids($filter); - let expected_external_ids: Vec<_> = search::expected_order(&criteria, true, true, &[]) - .into_iter() - .filter_map(|d| if filtered_ids.contains(&d.id) { Some(d.id) } else { None }) - .collect(); + let expected_external_ids: Vec<_> = + search::expected_order(&criteria, true, TermsMatchingStrategy::default(), &[]) + .into_iter() + .filter_map(|d| if filtered_ids.contains(&d.id) { Some(d.id) } else { None }) + .collect(); let documents_ids = search::internal_to_external_ids(&index, &documents_ids); assert_eq!(documents_ids, expected_external_ids); diff --git a/milli/tests/search/mod.rs b/milli/tests/search/mod.rs index 0e1d43d2a..4ec1aeb83 100644 --- a/milli/tests/search/mod.rs +++ b/milli/tests/search/mod.rs @@ -8,7 +8,7 @@ use heed::EnvOpenOptions; use maplit::{hashmap, hashset}; use milli::documents::{DocumentsBatchBuilder, DocumentsBatchReader}; use milli::update::{IndexDocuments, IndexDocumentsConfig, IndexerConfig, Settings}; -use milli::{AscDesc, Criterion, DocumentId, Index, Member, Object}; +use milli::{AscDesc, Criterion, DocumentId, Index, Member, Object, TermsMatchingStrategy}; use serde::{Deserialize, Deserializer}; use slice_group_by::GroupBy; @@ -96,7 +96,7 @@ pub fn internal_to_external_ids(index: &Index, internal_ids: &[DocumentId]) -> V pub fn expected_order( criteria: &[Criterion], authorize_typo: bool, - optional_words: bool, + optional_words: TermsMatchingStrategy, sort_by: &[AscDesc], ) -> Vec { let dataset = @@ -155,9 +155,9 @@ pub fn expected_order( groups = std::mem::take(&mut new_groups); } - if authorize_typo && optional_words { + if authorize_typo && optional_words == TermsMatchingStrategy::default() { groups.into_iter().flatten().collect() - } else if optional_words { + } else if optional_words == TermsMatchingStrategy::default() { groups.into_iter().flatten().filter(|d| d.typo_rank == 0).collect() } else if authorize_typo { groups.into_iter().flatten().filter(|d| d.word_rank == 0).collect() diff --git a/milli/tests/search/query_criteria.rs b/milli/tests/search/query_criteria.rs index a96366f5e..8b72c8420 100644 --- a/milli/tests/search/query_criteria.rs +++ b/milli/tests/search/query_criteria.rs @@ -7,7 +7,7 @@ use itertools::Itertools; use maplit::hashset; use milli::documents::{DocumentsBatchBuilder, DocumentsBatchReader}; use milli::update::{IndexDocuments, IndexDocumentsConfig, IndexerConfig, Settings}; -use milli::{AscDesc, Criterion, Index, Member, Search, SearchResult}; +use milli::{AscDesc, Criterion, Index, Member, Search, SearchResult, TermsMatchingStrategy}; use rand::Rng; use Criterion::*; @@ -15,8 +15,8 @@ use crate::search::{self, EXTERNAL_DOCUMENTS_IDS}; const ALLOW_TYPOS: bool = true; const DISALLOW_TYPOS: bool = false; -const ALLOW_OPTIONAL_WORDS: bool = true; -const DISALLOW_OPTIONAL_WORDS: bool = false; +const ALLOW_OPTIONAL_WORDS: TermsMatchingStrategy = TermsMatchingStrategy::Last; +const DISALLOW_OPTIONAL_WORDS: TermsMatchingStrategy = TermsMatchingStrategy::All; const ASC_DESC_CANDIDATES_THRESHOLD: usize = 1000; macro_rules! test_criterion { @@ -31,7 +31,7 @@ macro_rules! test_criterion { search.query(search::TEST_QUERY); search.limit(EXTERNAL_DOCUMENTS_IDS.len()); search.authorize_typos($authorize_typos); - search.optional_words($optional_word); + search.terms_matching_strategy($optional_word); search.sort_criteria($sort_criteria); let SearchResult { documents_ids, .. } = search.execute().unwrap(); @@ -353,13 +353,13 @@ fn criteria_mixup() { let mut search = Search::new(&mut rtxn, &index); search.query(search::TEST_QUERY); search.limit(EXTERNAL_DOCUMENTS_IDS.len()); - search.optional_words(ALLOW_OPTIONAL_WORDS); + search.terms_matching_strategy(ALLOW_OPTIONAL_WORDS); search.authorize_typos(ALLOW_TYPOS); let SearchResult { documents_ids, .. } = search.execute().unwrap(); let expected_external_ids: Vec<_> = - search::expected_order(&criteria, ALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, &[]) + search::expected_order(&criteria, ALLOW_TYPOS, ALLOW_OPTIONAL_WORDS, &[]) .into_iter() .map(|d| d.id) .collect(); diff --git a/milli/tests/search/sort.rs b/milli/tests/search/sort.rs index 86404bb99..16d21eac8 100644 --- a/milli/tests/search/sort.rs +++ b/milli/tests/search/sort.rs @@ -1,6 +1,6 @@ use big_s::S; use milli::Criterion::{Attribute, Exactness, Proximity, Typo, Words}; -use milli::{AscDesc, Error, Member, Search, UserError}; +use milli::{AscDesc, Error, Member, Search, TermsMatchingStrategy, UserError}; use crate::search::{self, EXTERNAL_DOCUMENTS_IDS}; @@ -15,7 +15,7 @@ fn sort_ranking_rule_missing() { search.query(search::TEST_QUERY); search.limit(EXTERNAL_DOCUMENTS_IDS.len()); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); search.sort_criteria(vec![AscDesc::Asc(Member::Field(S("tag")))]); let result = search.execute(); diff --git a/milli/tests/search/typo_tolerance.rs b/milli/tests/search/typo_tolerance.rs index 7c4cf8971..7dc6b0c4f 100644 --- a/milli/tests/search/typo_tolerance.rs +++ b/milli/tests/search/typo_tolerance.rs @@ -2,7 +2,7 @@ use std::collections::BTreeSet; use heed::EnvOpenOptions; use milli::update::{IndexDocuments, IndexDocumentsConfig, IndexerConfig, Settings}; -use milli::{Criterion, Index, Search}; +use milli::{Criterion, Index, Search, TermsMatchingStrategy}; use serde_json::json; use tempfile::tempdir; use Criterion::*; @@ -20,7 +20,7 @@ fn test_typo_tolerance_one_typo() { search.query("zeal"); search.limit(10); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let result = search.execute().unwrap(); assert_eq!(result.documents_ids.len(), 1); @@ -29,7 +29,7 @@ fn test_typo_tolerance_one_typo() { search.query("zean"); search.limit(10); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let result = search.execute().unwrap(); assert_eq!(result.documents_ids.len(), 0); @@ -47,7 +47,7 @@ fn test_typo_tolerance_one_typo() { search.query("zean"); search.limit(10); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let result = search.execute().unwrap(); assert_eq!(result.documents_ids.len(), 1); @@ -66,7 +66,7 @@ fn test_typo_tolerance_two_typo() { search.query("zealand"); search.limit(10); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let result = search.execute().unwrap(); assert_eq!(result.documents_ids.len(), 1); @@ -75,7 +75,7 @@ fn test_typo_tolerance_two_typo() { search.query("zealemd"); search.limit(10); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let result = search.execute().unwrap(); assert_eq!(result.documents_ids.len(), 0); @@ -93,7 +93,7 @@ fn test_typo_tolerance_two_typo() { search.query("zealemd"); search.limit(10); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let result = search.execute().unwrap(); assert_eq!(result.documents_ids.len(), 1); @@ -142,7 +142,7 @@ fn test_typo_disabled_on_word() { search.query("zealand"); search.limit(10); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let result = search.execute().unwrap(); assert_eq!(result.documents_ids.len(), 2); @@ -162,7 +162,7 @@ fn test_typo_disabled_on_word() { search.query("zealand"); search.limit(10); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let result = search.execute().unwrap(); assert_eq!(result.documents_ids.len(), 1); @@ -182,7 +182,7 @@ fn test_disable_typo_on_attribute() { search.query("antebelum"); search.limit(10); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let result = search.execute().unwrap(); assert_eq!(result.documents_ids.len(), 1); @@ -200,7 +200,7 @@ fn test_disable_typo_on_attribute() { search.query("antebelum"); search.limit(10); search.authorize_typos(true); - search.optional_words(true); + search.terms_matching_strategy(TermsMatchingStrategy::default()); let result = search.execute().unwrap(); assert_eq!(result.documents_ids.len(), 0);