From 642c01d0dc4f0f55ce6ffaa1e1184da11c698cac Mon Sep 17 00:00:00 2001 From: mpostma Date: Thu, 20 Jan 2022 18:34:54 +0100 Subject: [PATCH 1/6] set max typos on ngram to 1 --- milli/src/search/query_tree.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/milli/src/search/query_tree.rs b/milli/src/search/query_tree.rs index 0744231ae..d4cc338c8 100644 --- a/milli/src/search/query_tree.rs +++ b/milli/src/search/query_tree.rs @@ -260,12 +260,12 @@ fn split_best_frequency(ctx: &impl Context, word: &str) -> heed::Result QueryKind { +fn typos(word: String, authorize_typos: bool, max_typos: u8) -> QueryKind { if authorize_typos { match word.chars().count() { 0..=4 => QueryKind::exact(word), - 5..=8 => QueryKind::tolerant(1, word), - _ => QueryKind::tolerant(2, word), + 5..=8 => QueryKind::tolerant(1.min(max_typos), word), + _ => QueryKind::tolerant(2.min(max_typos), word), } } else { QueryKind::exact(word) @@ -316,8 +316,10 @@ fn create_query_tree( if let Some(child) = split_best_frequency(ctx, &word)? { children.push(child); } - children - .push(Operation::Query(Query { prefix, kind: typos(word, authorize_typos) })); + children.push(Operation::Query(Query { + prefix, + kind: typos(word, authorize_typos, 2), + })); Ok(Operation::or(false, children)) } // create a CONSECUTIVE operation wrapping all word in the phrase @@ -363,8 +365,9 @@ fn create_query_tree( .collect(); let mut operations = synonyms(ctx, &words)?.unwrap_or_default(); let concat = words.concat(); - let query = - Query { prefix: is_prefix, kind: typos(concat, authorize_typos) }; + let query = Query { prefix: is_prefix, kind: typos(concat, true, 1) }; + // let query = + // Query { prefix: is_prefix, kind: typos(concat, authorize_typos) }; operations.push(Operation::Query(query)); and_op_children.push(Operation::or(false, operations)); } From 55e6cb9c7b179181e1e131265b0a66da76a76250 Mon Sep 17 00:00:00 2001 From: mpostma Date: Thu, 20 Jan 2022 18:35:11 +0100 Subject: [PATCH 2/6] typos on first letter counts as 2 --- Cargo.toml | 3 +++ milli/src/search/mod.rs | 37 +++++++++++++++++++++++++++++-------- 2 files changed, 32 insertions(+), 8 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6b3e12f07..9b97dee88 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,3 +18,6 @@ opt-level = 3 opt-level = 3 [profile.test.build-override] opt-level = 3 + +[patch.crates-io] +fst = { path = "/Users/mpostma/Documents/code/rust/fst/" } diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 7c8722187..6b2e50c94 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -7,7 +7,8 @@ use std::str::Utf8Error; use std::time::Instant; use distinct::{Distinct, DocIter, FacetDistinct, NoopDistinct}; -use fst::{IntoStreamer, Streamer}; +use fst::automaton::Str; +use fst::{Automaton, IntoStreamer, Streamer}; use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; use log::debug; use meilisearch_tokenizer::{Analyzer, AnalyzerConfig}; @@ -285,19 +286,39 @@ pub fn word_derivations<'c>( Entry::Vacant(entry) => { let mut derived_words = Vec::new(); let dfa = build_dfa(word, max_typo, is_prefix); - let mut stream = fst.search_with_state(&dfa).into_stream(); + if max_typo == 1 { + let starts = Str::new(get_first(word)); + let mut stream = fst.search_with_state(starts.intersection(&dfa)).into_stream(); - while let Some((word, state)) = stream.next() { - let word = std::str::from_utf8(word)?; - let distance = dfa.distance(state); - derived_words.push((word.to_string(), distance.to_u8())); + while let Some((word, state)) = stream.next() { + let word = std::str::from_utf8(word)?; + let distance = dfa.distance(state.1); + derived_words.push((word.to_string(), distance.to_u8())); + } + + Ok(entry.insert(derived_words)) + } else { + let mut stream = fst.search_with_state(&dfa).into_stream(); + + while let Some((word, state)) = stream.next() { + let word = std::str::from_utf8(word)?; + let distance = dfa.distance(state); + derived_words.push((word.to_string(), distance.to_u8())); + } + + Ok(entry.insert(derived_words)) } - - Ok(entry.insert(derived_words)) } } } +fn get_first(s: &str) -> &str { + match s.chars().next() { + Some(c) => &s[..c.len_utf8()], + None => s, + } +} + pub fn build_dfa(word: &str, typos: u8, is_prefix: bool) -> DFA { let lev = match typos { 0 => &LEVDIST0, From d0aabde502f3450b8c26a8cd8e6ee0240bd7cf1a Mon Sep 17 00:00:00 2001 From: mpostma Date: Thu, 20 Jan 2022 23:23:07 +0100 Subject: [PATCH 3/6] optimize 2 typos case --- milli/src/search/mod.rs | 54 ++++++++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 17 deletions(-) diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 6b2e50c94..cf596fa7a 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -285,29 +285,49 @@ pub fn word_derivations<'c>( Entry::Occupied(entry) => Ok(entry.into_mut()), Entry::Vacant(entry) => { let mut derived_words = Vec::new(); - let dfa = build_dfa(word, max_typo, is_prefix); - if max_typo == 1 { - let starts = Str::new(get_first(word)); - let mut stream = fst.search_with_state(starts.intersection(&dfa)).into_stream(); + if max_typo == 0 { + if is_prefix { + let prefix = Str::new(word).starts_with(); + let mut stream = fst.search(prefix).into_stream(); - while let Some((word, state)) = stream.next() { - let word = std::str::from_utf8(word)?; - let distance = dfa.distance(state.1); - derived_words.push((word.to_string(), distance.to_u8())); + while let Some(word) = stream.next() { + let word = std::str::from_utf8(word)?; + derived_words.push((word.to_string(), 0)); + } + } else { + let automaton = Str::new(word); + let mut stream = fst.search(automaton).into_stream(); + while let Some(word) = stream.next() { + let word = std::str::from_utf8(word)?; + derived_words.push((word.to_string(), 0)); + } } - - Ok(entry.insert(derived_words)) } else { - let mut stream = fst.search_with_state(&dfa).into_stream(); + if max_typo == 1 { + let dfa = build_dfa(word, 1, is_prefix); + let starts = Str::new(get_first(word)).starts_with(); + let mut stream = fst.search_with_state(starts.intersection(&dfa)).into_stream(); - while let Some((word, state)) = stream.next() { - let word = std::str::from_utf8(word)?; - let distance = dfa.distance(state); - derived_words.push((word.to_string(), distance.to_u8())); + while let Some((word, state)) = stream.next() { + let word = std::str::from_utf8(word)?; + let distance = dfa.distance(state.1); + derived_words.push((word.to_string(), distance.to_u8())); + } + } else { + let starts = Str::new(get_first(word)).starts_with(); + let first = build_dfa(word, 1, is_prefix).intersection((&starts).complement()); + let second = build_dfa(word, 2, is_prefix).intersection(&starts); + let automaton = first.union(second); + + let mut stream = fst.search(automaton).into_stream(); + + while let Some(word) = stream.next() { + let word = std::str::from_utf8(word)?; + derived_words.push((word.to_string(), 2)); + } } - - Ok(entry.insert(derived_words)) } + Ok(entry.insert(derived_words)) } } } From 7541ab99cdcc0f60fd92895f697a7e628d983b58 Mon Sep 17 00:00:00 2001 From: mpostma Date: Tue, 25 Jan 2022 10:06:27 +0100 Subject: [PATCH 4/6] review changes --- Cargo.toml | 3 --- milli/src/search/mod.rs | 18 ++++++------------ milli/src/search/query_tree.rs | 2 -- 3 files changed, 6 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 9b97dee88..6b3e12f07 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,3 @@ opt-level = 3 opt-level = 3 [profile.test.build-override] opt-level = 3 - -[patch.crates-io] -fst = { path = "/Users/mpostma/Documents/code/rust/fst/" } diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index cf596fa7a..67b86d6bf 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -294,24 +294,18 @@ pub fn word_derivations<'c>( let word = std::str::from_utf8(word)?; derived_words.push((word.to_string(), 0)); } - } else { - let automaton = Str::new(word); - let mut stream = fst.search(automaton).into_stream(); - while let Some(word) = stream.next() { - let word = std::str::from_utf8(word)?; - derived_words.push((word.to_string(), 0)); - } + } else if fst.contains(word) { + derived_words.push((word.to_string(), 0)); } } else { if max_typo == 1 { let dfa = build_dfa(word, 1, is_prefix); let starts = Str::new(get_first(word)).starts_with(); - let mut stream = fst.search_with_state(starts.intersection(&dfa)).into_stream(); + let mut stream = fst.search(starts.intersection(&dfa)).into_stream(); - while let Some((word, state)) = stream.next() { + while let Some(word) = stream.next() { let word = std::str::from_utf8(word)?; - let distance = dfa.distance(state.1); - derived_words.push((word.to_string(), distance.to_u8())); + derived_words.push((word.to_string(), 1)); } } else { let starts = Str::new(get_first(word)).starts_with(); @@ -335,7 +329,7 @@ pub fn word_derivations<'c>( fn get_first(s: &str) -> &str { match s.chars().next() { Some(c) => &s[..c.len_utf8()], - None => s, + None => panic!("unexpected empty query"), } } diff --git a/milli/src/search/query_tree.rs b/milli/src/search/query_tree.rs index d4cc338c8..355e42663 100644 --- a/milli/src/search/query_tree.rs +++ b/milli/src/search/query_tree.rs @@ -366,8 +366,6 @@ fn create_query_tree( let mut operations = synonyms(ctx, &words)?.unwrap_or_default(); let concat = words.concat(); let query = Query { prefix: is_prefix, kind: typos(concat, true, 1) }; - // let query = - // Query { prefix: is_prefix, kind: typos(concat, authorize_typos) }; operations.push(Operation::Query(query)); and_op_children.push(Operation::or(false, operations)); } From 628c835a220c4b29f12bc23b40ff90ca5292a620 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 2 Feb 2022 18:45:11 +0100 Subject: [PATCH 5/6] fix tests --- Cargo.toml | 3 +++ milli/src/search/mod.rs | 30 +++++++++++++++++++--------- milli/src/search/query_tree.rs | 17 +++++++++------- milli/tests/assets/test_set.ndjson | 2 +- milli/tests/search/query_criteria.rs | 1 + 5 files changed, 36 insertions(+), 17 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 6b3e12f07..52599b1bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,3 +18,6 @@ opt-level = 3 opt-level = 3 [profile.test.build-override] opt-level = 3 + +[patch.crates-io] +fst = { git = "https://github.com/MarinPostma/fst.git", rev = "e6c606b7507e8cb5e502d1609f9b909b8690bac5" } diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 67b86d6bf..bfe5e023c 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -70,6 +70,7 @@ impl<'a> Search<'a> { pub fn offset(&mut self, offset: usize) -> &mut Search<'a> { self.offset = offset; + self } @@ -301,23 +302,34 @@ pub fn word_derivations<'c>( if max_typo == 1 { let dfa = build_dfa(word, 1, is_prefix); let starts = Str::new(get_first(word)).starts_with(); - let mut stream = fst.search(starts.intersection(&dfa)).into_stream(); + let mut stream = fst.search_with_state(starts.intersection(&dfa)).into_stream(); - while let Some(word) = stream.next() { + while let Some((word, state)) = stream.next() { let word = std::str::from_utf8(word)?; - derived_words.push((word.to_string(), 1)); + let d = dfa.distance(state.1); + derived_words.push((word.to_string(), d.to_u8())); } } else { let starts = Str::new(get_first(word)).starts_with(); let first = build_dfa(word, 1, is_prefix).intersection((&starts).complement()); - let second = build_dfa(word, 2, is_prefix).intersection(&starts); - let automaton = first.union(second); + let second_dfa = build_dfa(word, 2, is_prefix); + let second = (&second_dfa).intersection(&starts); + let automaton = first.union(&second); - let mut stream = fst.search(automaton).into_stream(); + let mut stream = fst.search_with_state(automaton).into_stream(); - while let Some(word) = stream.next() { - let word = std::str::from_utf8(word)?; - derived_words.push((word.to_string(), 2)); + while let Some((found_word, state)) = stream.next() { + let found_word = std::str::from_utf8(found_word)?; + // in the case the typo is on the first letter, we know the number of typo + // is two + if get_first(found_word) != get_first(word) { + derived_words.push((word.to_string(), 2)); + } else { + // Else, we know that it is the second dfa that matched and compute the + // correct distance + let d = second_dfa.distance((state.1).0); + derived_words.push((word.to_string(), d.to_u8())); + } } } } diff --git a/milli/src/search/query_tree.rs b/milli/src/search/query_tree.rs index 355e42663..a7285ccaa 100644 --- a/milli/src/search/query_tree.rs +++ b/milli/src/search/query_tree.rs @@ -365,7 +365,10 @@ fn create_query_tree( .collect(); let mut operations = synonyms(ctx, &words)?.unwrap_or_default(); let concat = words.concat(); - let query = Query { prefix: is_prefix, kind: typos(concat, true, 1) }; + let query = Query { + prefix: is_prefix, + kind: typos(concat, authorize_typos, 1), + }; operations.push(Operation::Query(query)); and_op_children.push(Operation::or(false, operations)); } @@ -657,7 +660,7 @@ mod test { ]), Operation::Query(Query { prefix: true, - kind: QueryKind::tolerant(2, "heyfriends".to_string()), + kind: QueryKind::tolerant(1, "heyfriends".to_string()), }), ], ); @@ -690,7 +693,7 @@ mod test { ]), Operation::Query(Query { prefix: false, - kind: QueryKind::tolerant(2, "heyfriends".to_string()), + kind: QueryKind::tolerant(1, "heyfriends".to_string()), }), ], ); @@ -755,7 +758,7 @@ mod test { ]), Operation::Query(Query { prefix: false, - kind: QueryKind::tolerant(2, "helloworld".to_string()), + kind: QueryKind::tolerant(1, "helloworld".to_string()), }), ], ); @@ -853,7 +856,7 @@ mod test { ]), Operation::Query(Query { prefix: false, - kind: QueryKind::tolerant(2, "newyorkcity".to_string()), + kind: QueryKind::tolerant(1, "newyorkcity".to_string()), }), ], ), @@ -927,7 +930,7 @@ mod test { ]), Operation::Query(Query { prefix: false, - kind: QueryKind::tolerant(2, "wordsplitfish".to_string()), + kind: QueryKind::tolerant(1, "wordsplitfish".to_string()), }), ], ); @@ -1047,7 +1050,7 @@ mod test { ]), Operation::Query(Query { prefix: false, - kind: QueryKind::tolerant(2, "heymyfriend".to_string()), + kind: QueryKind::tolerant(1, "heymyfriend".to_string()), }), ], ), diff --git a/milli/tests/assets/test_set.ndjson b/milli/tests/assets/test_set.ndjson index 9a0fe5b0a..6383d274e 100644 --- a/milli/tests/assets/test_set.ndjson +++ b/milli/tests/assets/test_set.ndjson @@ -8,7 +8,7 @@ {"id":"H","word_rank":1,"typo_rank":0,"proximity_rank":1,"attribute_rank":0,"exact_rank":3,"asc_desc_rank":4,"sort_by_rank":1,"geo_rank":202182,"title":"world hello day","description":"holiday observed on november 21 to express that conflicts should be resolved through communication rather than the use of force","tag":"green","_geo": { "lat": 48.875617484531965, "lng": 2.346747821504194 },"":""} {"id":"I","word_rank":0,"typo_rank":0,"proximity_rank":8,"attribute_rank":338,"exact_rank":3,"asc_desc_rank":3,"sort_by_rank":0,"geo_rank":740667,"title":"hello world song","description":"hello world is a song written by tom douglas tony lane and david lee and recorded by american country music group lady antebellum","tag":"blue","_geo": { "lat": 43.973998070351065, "lng": 3.4661837318345032 },"":""} {"id":"J","word_rank":1,"typo_rank":0,"proximity_rank":1,"attribute_rank":1,"exact_rank":3,"asc_desc_rank":2,"sort_by_rank":1,"geo_rank":739020,"title":"hello cruel world","description":"hello cruel world is an album by new zealand band tall dwarfs","tag":"green","_geo": { "lat": 43.98920130353838, "lng": 3.480519311627928 },"":""} -{"id":"K","word_rank":0,"typo_rank":2,"proximity_rank":9,"attribute_rank":670,"exact_rank":5,"asc_desc_rank":1,"sort_by_rank":2,"geo_rank":738830,"title":"ello creation system","description":"in few word ello was a construction toy created by the american company mattel to engage girls in construction play","tag":"red","_geo": { "lat": 43.99155030238669, "lng": 3.503453528249425 },"":""} +{"id":"K","word_rank":0,"typo_rank":2,"proximity_rank":9,"attribute_rank":670,"exact_rank":5,"asc_desc_rank":1,"sort_by_rank":2,"geo_rank":738830,"title":"hallo creation system","description":"in few word hallo was a construction toy created by the american company mattel to engage girls in construction play","tag":"red","_geo": { "lat": 43.99155030238669, "lng": 3.503453528249425 },"":""} {"id":"L","word_rank":0,"typo_rank":0,"proximity_rank":2,"attribute_rank":250,"exact_rank":4,"asc_desc_rank":0,"sort_by_rank":0,"geo_rank":737861,"title":"good morning world","description":"good morning world is an american sitcom broadcast on cbs tv during the 1967 1968 season","tag":"blue","_geo": { "lat": 44.000507750283695, "lng": 3.5116812040621572 },"":""} {"id":"M","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":0,"asc_desc_rank":0,"sort_by_rank":2,"geo_rank":739203,"title":"hello world america","description":"a perfect match for a perfect engine using the query hello world america","tag":"red","_geo": { "lat": 43.99150729038736, "lng": 3.606143957295055 },"":""} {"id":"N","word_rank":0,"typo_rank":0,"proximity_rank":0,"attribute_rank":0,"exact_rank":1,"asc_desc_rank":4,"sort_by_rank":1,"geo_rank":9499586,"title":"hello world america unleashed","description":"a very good match for a very good engine using the query hello world america","tag":"green","_geo": { "lat": 35.511540843367115, "lng": 138.764368875787 },"":""} diff --git a/milli/tests/search/query_criteria.rs b/milli/tests/search/query_criteria.rs index 0dcbf660e..ef080db9f 100644 --- a/milli/tests/search/query_criteria.rs +++ b/milli/tests/search/query_criteria.rs @@ -61,6 +61,7 @@ test_criterion!( vec![Attribute], vec![] ); +test_criterion!(typo, DISALLOW_OPTIONAL_WORDS, ALLOW_TYPOS, vec![Typo], vec![]); test_criterion!( attribute_disallow_typo, DISALLOW_OPTIONAL_WORDS, From 3f24555c3d16b3078ef0182980341e2fbdc3ea43 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 15 Mar 2022 17:28:57 +0100 Subject: [PATCH 6/6] custom fst automatons --- Cargo.toml | 3 - milli/src/search/fst_utils.rs | 187 ++++++++++++++++++++++++++++++++++ milli/src/search/mod.rs | 16 +-- 3 files changed, 196 insertions(+), 10 deletions(-) create mode 100644 milli/src/search/fst_utils.rs diff --git a/Cargo.toml b/Cargo.toml index 52599b1bd..6b3e12f07 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,6 +18,3 @@ opt-level = 3 opt-level = 3 [profile.test.build-override] opt-level = 3 - -[patch.crates-io] -fst = { git = "https://github.com/MarinPostma/fst.git", rev = "e6c606b7507e8cb5e502d1609f9b909b8690bac5" } diff --git a/milli/src/search/fst_utils.rs b/milli/src/search/fst_utils.rs new file mode 100644 index 000000000..b488e6c19 --- /dev/null +++ b/milli/src/search/fst_utils.rs @@ -0,0 +1,187 @@ +/// This mod is necessary until https://github.com/BurntSushi/fst/pull/137 gets merged. +/// All credits for this code go to BurntSushi. +use fst::Automaton; + +pub struct StartsWith(pub A); + +/// The `Automaton` state for `StartsWith`. +pub struct StartsWithState(pub StartsWithStateKind); + +impl Clone for StartsWithState +where + A::State: Clone, +{ + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +/// The inner state of a `StartsWithState`. +pub enum StartsWithStateKind { + /// Sink state that is reached when the automaton has matched the prefix. + Done, + /// State in which the automaton is while it hasn't matched the prefix. + Running(A::State), +} + +impl Clone for StartsWithStateKind +where + A::State: Clone, +{ + fn clone(&self) -> Self { + match self { + StartsWithStateKind::Done => StartsWithStateKind::Done, + StartsWithStateKind::Running(inner) => StartsWithStateKind::Running(inner.clone()), + } + } +} + +impl Automaton for StartsWith { + type State = StartsWithState; + + fn start(&self) -> StartsWithState { + StartsWithState({ + let inner = self.0.start(); + if self.0.is_match(&inner) { + StartsWithStateKind::Done + } else { + StartsWithStateKind::Running(inner) + } + }) + } + fn is_match(&self, state: &StartsWithState) -> bool { + match state.0 { + StartsWithStateKind::Done => true, + StartsWithStateKind::Running(_) => false, + } + } + fn can_match(&self, state: &StartsWithState) -> bool { + match state.0 { + StartsWithStateKind::Done => true, + StartsWithStateKind::Running(ref inner) => self.0.can_match(inner), + } + } + fn will_always_match(&self, state: &StartsWithState) -> bool { + match state.0 { + StartsWithStateKind::Done => true, + StartsWithStateKind::Running(_) => false, + } + } + fn accept(&self, state: &StartsWithState, byte: u8) -> StartsWithState { + StartsWithState(match state.0 { + StartsWithStateKind::Done => StartsWithStateKind::Done, + StartsWithStateKind::Running(ref inner) => { + let next_inner = self.0.accept(inner, byte); + if self.0.is_match(&next_inner) { + StartsWithStateKind::Done + } else { + StartsWithStateKind::Running(next_inner) + } + } + }) + } +} +/// An automaton that matches when one of its component automata match. +#[derive(Clone, Debug)] +pub struct Union(pub A, pub B); + +/// The `Automaton` state for `Union`. +pub struct UnionState(pub A::State, pub B::State); + +impl Clone for UnionState +where + A::State: Clone, + B::State: Clone, +{ + fn clone(&self) -> Self { + Self(self.0.clone(), self.1.clone()) + } +} + +impl Automaton for Union { + type State = UnionState; + fn start(&self) -> UnionState { + UnionState(self.0.start(), self.1.start()) + } + fn is_match(&self, state: &UnionState) -> bool { + self.0.is_match(&state.0) || self.1.is_match(&state.1) + } + fn can_match(&self, state: &UnionState) -> bool { + self.0.can_match(&state.0) || self.1.can_match(&state.1) + } + fn will_always_match(&self, state: &UnionState) -> bool { + self.0.will_always_match(&state.0) || self.1.will_always_match(&state.1) + } + fn accept(&self, state: &UnionState, byte: u8) -> UnionState { + UnionState(self.0.accept(&state.0, byte), self.1.accept(&state.1, byte)) + } +} +/// An automaton that matches when both of its component automata match. +#[derive(Clone, Debug)] +pub struct Intersection(pub A, pub B); + +/// The `Automaton` state for `Intersection`. +pub struct IntersectionState(pub A::State, pub B::State); + +impl Clone for IntersectionState +where + A::State: Clone, + B::State: Clone, +{ + fn clone(&self) -> Self { + Self(self.0.clone(), self.1.clone()) + } +} + +impl Automaton for Intersection { + type State = IntersectionState; + fn start(&self) -> IntersectionState { + IntersectionState(self.0.start(), self.1.start()) + } + fn is_match(&self, state: &IntersectionState) -> bool { + self.0.is_match(&state.0) && self.1.is_match(&state.1) + } + fn can_match(&self, state: &IntersectionState) -> bool { + self.0.can_match(&state.0) && self.1.can_match(&state.1) + } + fn will_always_match(&self, state: &IntersectionState) -> bool { + self.0.will_always_match(&state.0) && self.1.will_always_match(&state.1) + } + fn accept(&self, state: &IntersectionState, byte: u8) -> IntersectionState { + IntersectionState(self.0.accept(&state.0, byte), self.1.accept(&state.1, byte)) + } +} +/// An automaton that matches exactly when the automaton it wraps does not. +#[derive(Clone, Debug)] +pub struct Complement(pub A); + +/// The `Automaton` state for `Complement`. +pub struct ComplementState(pub A::State); + +impl Clone for ComplementState +where + A::State: Clone, +{ + fn clone(&self) -> Self { + Self(self.0.clone()) + } +} + +impl Automaton for Complement { + type State = ComplementState; + fn start(&self) -> ComplementState { + ComplementState(self.0.start()) + } + fn is_match(&self, state: &ComplementState) -> bool { + !self.0.is_match(&state.0) + } + fn can_match(&self, state: &ComplementState) -> bool { + !self.0.will_always_match(&state.0) + } + fn will_always_match(&self, state: &ComplementState) -> bool { + !self.0.can_match(&state.0) + } + fn accept(&self, state: &ComplementState, byte: u8) -> ComplementState { + ComplementState(self.0.accept(&state.0, byte)) + } +} diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index bfe5e023c..40e4bca24 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -16,6 +16,7 @@ use once_cell::sync::Lazy; use roaring::bitmap::RoaringBitmap; pub use self::facet::{FacetDistribution, FacetNumberIter, Filter}; +use self::fst_utils::{Complement, Intersection, StartsWith, Union}; pub use self::matching_words::MatchingWords; use self::query_tree::QueryTreeBuilder; use crate::error::UserError; @@ -30,6 +31,7 @@ static LEVDIST2: Lazy = Lazy::new(|| LevBuilder::new(2, true)); mod criteria; mod distinct; mod facet; +mod fst_utils; mod matching_words; mod query_tree; @@ -70,7 +72,6 @@ impl<'a> Search<'a> { pub fn offset(&mut self, offset: usize) -> &mut Search<'a> { self.offset = offset; - self } @@ -301,8 +302,9 @@ pub fn word_derivations<'c>( } else { if max_typo == 1 { let dfa = build_dfa(word, 1, is_prefix); - let starts = Str::new(get_first(word)).starts_with(); - let mut stream = fst.search_with_state(starts.intersection(&dfa)).into_stream(); + let starts = StartsWith(Str::new(get_first(word))); + let mut stream = + fst.search_with_state(Intersection(starts, &dfa)).into_stream(); while let Some((word, state)) = stream.next() { let word = std::str::from_utf8(word)?; @@ -310,11 +312,11 @@ pub fn word_derivations<'c>( derived_words.push((word.to_string(), d.to_u8())); } } else { - let starts = Str::new(get_first(word)).starts_with(); - let first = build_dfa(word, 1, is_prefix).intersection((&starts).complement()); + let starts = StartsWith(Str::new(get_first(word))); + let first = Intersection(build_dfa(word, 1, is_prefix), Complement(&starts)); let second_dfa = build_dfa(word, 2, is_prefix); - let second = (&second_dfa).intersection(&starts); - let automaton = first.union(&second); + let second = Intersection(&second_dfa, &starts); + let automaton = Union(first, &second); let mut stream = fst.search_with_state(automaton).into_stream();