From bf3c2c372554829e0ef314cd55eb09fb0378c07d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Fri, 2 Aug 2019 12:07:23 +0200 Subject: [PATCH] feat: Move the multi-word rewriting algorithm into its own function --- meilidb-core/Cargo.toml | 2 +- meilidb-core/src/criterion/sum_of_typos.rs | 2 +- meilidb-core/src/query_builder.rs | 222 ++++++++++-------- meilidb-core/src/query_enhancer.rs | 15 +- meilidb-core/src/raw_document.rs | 4 +- .../src/database/synonyms_addition.rs | 4 +- meilidb/examples/create-database.rs | 72 +++++- 7 files changed, 204 insertions(+), 117 deletions(-) diff --git a/meilidb-core/Cargo.toml b/meilidb-core/Cargo.toml index 037a7788c..25fb57119 100644 --- a/meilidb-core/Cargo.toml +++ b/meilidb-core/Cargo.toml @@ -14,7 +14,7 @@ meilidb-tokenizer = { path = "../meilidb-tokenizer", version = "0.1.0" } rayon = "1.0.3" sdset = "0.3.2" serde = { version = "1.0.88", features = ["derive"] } -slice-group-by = "0.2.4" +slice-group-by = "0.2.6" zerocopy = "0.2.2" [dependencies.fst] diff --git a/meilidb-core/src/criterion/sum_of_typos.rs b/meilidb-core/src/criterion/sum_of_typos.rs index d5cd75f08..6736e6caa 100644 --- a/meilidb-core/src/criterion/sum_of_typos.rs +++ b/meilidb-core/src/criterion/sum_of_typos.rs @@ -21,7 +21,7 @@ fn custom_log10(n: u8) -> f32 { #[inline] fn sum_matches_typos(query_index: &[u32], distance: &[u8]) -> usize { - let mut number_words = 0; + let mut number_words: usize = 0; let mut sum_typos = 0.0; let mut index = 0; diff --git a/meilidb-core/src/query_builder.rs b/meilidb-core/src/query_builder.rs index c5a0ac847..7c3183ff4 100644 --- a/meilidb-core/src/query_builder.rs +++ b/meilidb-core/src/query_builder.rs @@ -197,6 +197,110 @@ impl<'c, S, FI> QueryBuilder<'c, S, FI> } } +fn multiword_rewrite_matches( + mut matches: Vec<(DocumentId, TmpMatch)>, + query_enhancer: &QueryEnhancer, +) -> SetBuf<(DocumentId, TmpMatch)> +{ + let mut padded_matches = Vec::with_capacity(matches.len()); + + // we sort the matches by word index to make them rewritable + let start = Instant::now(); + matches.par_sort_unstable_by_key(|(id, match_)| (*id, match_.attribute, match_.word_index)); + info!("rewrite sort by word_index took {:.2?}", start.elapsed()); + + let start = Instant::now(); + // for each attribute of each document + for same_document_attribute in matches.linear_group_by_key(|(id, m)| (*id, m.attribute)) { + + // padding will only be applied + // to word indices in the same attribute + let mut padding = 0; + let mut iter = same_document_attribute.linear_group_by_key(|(_, m)| m.word_index); + + // for each match at the same position + // in this document attribute + while let Some(same_word_index) = iter.next() { + + // find the biggest padding + let mut biggest = 0; + for (id, match_) in same_word_index { + + let mut replacement = query_enhancer.replacement(match_.query_index); + let replacement_len = replacement.len(); + let nexts = iter.remainder().linear_group_by_key(|(_, m)| m.word_index); + + if let Some(query_index) = replacement.next() { + let word_index = match_.word_index + padding as u16; + let match_ = TmpMatch { query_index, word_index, ..match_.clone() }; + padded_matches.push((*id, match_)); + } + + let mut found = false; + + // look ahead and if there already is a match + // corresponding to this padding word, abort the padding + 'padding: for (x, next_group) in nexts.enumerate() { + + for (i, query_index) in replacement.clone().enumerate().skip(x) { + let word_index = match_.word_index + padding as u16 + (i + 1) as u16; + let padmatch = TmpMatch { query_index, word_index, ..match_.clone() }; + + for (_, nmatch_) in next_group { + let mut rep = query_enhancer.replacement(nmatch_.query_index); + let query_index = rep.next().unwrap(); + if query_index == padmatch.query_index { + + if !found { + // if we find a corresponding padding for the + // first time we must push preceding paddings + for (i, query_index) in replacement.clone().enumerate().take(i) { + let word_index = match_.word_index + padding as u16 + (i + 1) as u16; + let match_ = TmpMatch { query_index, word_index, ..match_.clone() }; + padded_matches.push((*id, match_)); + biggest = biggest.max(i + 1); + } + } + + padded_matches.push((*id, padmatch)); + found = true; + continue 'padding; + } + } + } + + // if we do not find a corresponding padding in the + // next groups so stop here and pad what was found + break + } + + if !found { + // if no padding was found in the following matches + // we must insert the entire padding + for (i, query_index) in replacement.enumerate() { + let word_index = match_.word_index + padding as u16 + (i + 1) as u16; + let match_ = TmpMatch { query_index, word_index, ..match_.clone() }; + padded_matches.push((*id, match_)); + } + + biggest = biggest.max(replacement_len - 1); + } + } + + padding += biggest; + } + } + info!("main multiword rewrite took {:.2?}", start.elapsed()); + + let start = Instant::now(); + for document_matches in padded_matches.linear_group_by_key_mut(|(id, _)| *id) { + document_matches.sort_unstable(); + } + info!("final rewrite sort took {:.2?}", start.elapsed()); + + SetBuf::new_unchecked(padded_matches) +} + impl<'c, S, FI> QueryBuilder<'c, S, FI> where S: Store, { @@ -217,22 +321,26 @@ where S: Store, let mut matches = Vec::new(); let mut highlights = Vec::new(); + let mut query_db = std::time::Duration::default(); + + let start = Instant::now(); while let Some((input, indexed_values)) = stream.next() { for iv in indexed_values { let Automaton { is_exact, query_len, ref dfa } = automatons[iv.index]; let distance = dfa.eval(input).to_u8(); let is_exact = is_exact && distance == 0 && input.len() == query_len; + let start = Instant::now(); let doc_indexes = self.store.word_indexes(input)?; let doc_indexes = match doc_indexes { Some(doc_indexes) => doc_indexes, None => continue, }; + query_db += start.elapsed(); for di in doc_indexes.as_slice() { let attribute = searchables.map_or(Some(di.attribute), |r| r.get(di.attribute)); if let Some(attribute) = attribute { - let match_ = TmpMatch { query_index: iv.index as u32, distance, @@ -253,118 +361,28 @@ where S: Store, } } } + info!("main query all took {:.2?} (get indexes {:.2?})", start.elapsed(), query_db); - // we sort the matches to make them rewritable - matches.par_sort_unstable_by_key(|(id, match_)| (*id, match_.attribute, match_.word_index)); + info!("{} total matches to rewrite", matches.len()); - let mut padded_matches = Vec::with_capacity(matches.len()); - for same_document in matches.linear_group_by(|a, b| a.0 == b.0) { - - for same_attribute in same_document.linear_group_by(|a, b| a.1.attribute == b.1.attribute) { - - let mut padding = 0; - let mut iter = same_attribute.linear_group_by(|a, b| a.1.word_index == b.1.word_index); - while let Some(same_word_index) = iter.next() { - - let mut biggest = 0; - for (id, match_) in same_word_index { - - let mut replacement = query_enhancer.replacement(match_.query_index); - let replacement_len = replacement.len() - 1; - let nexts = iter.remainder().linear_group_by(|a, b| a.1.word_index == b.1.word_index); - - if let Some(query_index) = replacement.next() { - let match_ = TmpMatch { - query_index, - word_index: match_.word_index + padding as u16, - ..match_.clone() - }; - padded_matches.push((*id, match_)); - } - - let mut found = false; - - // look ahead and if there already is a match - // corresponding to this padding word, abort the padding - 'padding: for (x, next_group) in nexts.enumerate() { - - for (i, query_index) in replacement.clone().enumerate().skip(x) { - let padmatch_ = TmpMatch { - query_index, - word_index: match_.word_index + padding as u16 + (i + 1) as u16, - ..match_.clone() - }; - - for (_, nmatch_) in next_group { - let mut rep = query_enhancer.replacement(nmatch_.query_index); - let query_index = rep.next().unwrap(); - let nmatch_ = TmpMatch { query_index, ..nmatch_.clone() }; - if nmatch_.query_index == padmatch_.query_index { - - if !found { - // if we find a corresponding padding for the - // first time we must push preceding paddings - for (i, query_index) in replacement.clone().enumerate().take(i) { - let match_ = TmpMatch { - query_index, - word_index: match_.word_index + padding as u16 + (i + 1) as u16, - ..match_.clone() - }; - padded_matches.push((*id, match_)); - biggest = biggest.max(i + 1); - } - } - - padded_matches.push((*id, padmatch_)); - found = true; - continue 'padding; - } - } - } - - // if we do not find a corresponding padding in the - // next groups so stop here and pad what was found - break - } - - if !found { - // if no padding was found in the following matches - // we must insert the entire padding - for (i, query_index) in replacement.enumerate() { - let match_ = TmpMatch { - query_index, - word_index: match_.word_index + padding as u16 + (i + 1) as u16, - ..match_.clone() - }; - padded_matches.push((*id, match_)); - } - - biggest = biggest.max(replacement_len); - } - } - - padding += biggest; - } - } - - } - - - let matches = { - padded_matches.par_sort_unstable(); - SetBuf::new_unchecked(padded_matches) - }; + let start = Instant::now(); + let matches = multiword_rewrite_matches(matches, &query_enhancer); + info!("multiword rewrite took {:.2?}", start.elapsed()); + let start = Instant::now(); let highlights = { highlights.par_sort_unstable_by_key(|(id, _)| *id); SetBuf::new_unchecked(highlights) }; + info!("sorting highlights took {:.2?}", start.elapsed()); - let total_matches = matches.len(); + info!("{} total matches to classify", matches.len()); + + let start = Instant::now(); let raw_documents = raw_documents_from(matches, highlights); + info!("making raw documents took {:.2?}", start.elapsed()); info!("{} total documents to classify", raw_documents.len()); - info!("{} total matches to classify", total_matches); Ok(raw_documents) } diff --git a/meilidb-core/src/query_enhancer.rs b/meilidb-core/src/query_enhancer.rs index 6280ae11e..165c1b094 100644 --- a/meilidb-core/src/query_enhancer.rs +++ b/meilidb-core/src/query_enhancer.rs @@ -52,17 +52,20 @@ where S: AsRef, !original.map(AsRef::as_ref).eq(words.iter().map(AsRef::as_ref)) } +type Origin = usize; +type RealLength = usize; + struct FakeIntervalTree { - intervals: Vec<(Range, (usize, usize))>, // origin, real_length + intervals: Vec<(Range, (Origin, RealLength))>, } impl FakeIntervalTree { - fn new(mut intervals: Vec<(Range, (usize, usize))>) -> FakeIntervalTree { + fn new(mut intervals: Vec<(Range, (Origin, RealLength))>) -> FakeIntervalTree { intervals.sort_unstable_by_key(|(r, _)| (r.start, r.end)); FakeIntervalTree { intervals } } - fn query(&self, point: usize) -> Option<(Range, (usize, usize))> { + fn query(&self, point: usize) -> Option<(Range, (Origin, RealLength))> { let element = self.intervals.binary_search_by(|(r, _)| { if point >= r.start { if point < r.end { Equal } else { Less } @@ -81,7 +84,7 @@ impl FakeIntervalTree { pub struct QueryEnhancerBuilder<'a, S> { query: &'a [S], origins: Vec, - real_to_origin: Vec<(Range, (usize, usize))>, + real_to_origin: Vec<(Range, (Origin, RealLength))>, } impl> QueryEnhancerBuilder<'_, S> { @@ -147,8 +150,8 @@ impl QueryEnhancer { // query the fake interval tree with the real query index let (range, (origin, real_length)) = self.real_to_origin - .query(real) - .expect("real has never been declared"); + .query(real) + .expect("real has never been declared"); // if `real` is the end bound of the range if (range.start + real_length - 1) == real { diff --git a/meilidb-core/src/raw_document.rs b/meilidb-core/src/raw_document.rs index 5d449a74a..3567c3fd1 100644 --- a/meilidb-core/src/raw_document.rs +++ b/meilidb-core/src/raw_document.rs @@ -74,8 +74,8 @@ pub fn raw_documents_from( let mut docs_ranges: Vec<(_, Range, _)> = Vec::new(); let mut matches2 = Matches::with_capacity(matches.len()); - let matches = matches.linear_group_by(|(a, _), (b, _)| a == b); - let highlights = highlights.linear_group_by(|(a, _), (b, _)| a == b); + let matches = matches.linear_group_by_key(|(id, _)| *id); + let highlights = highlights.linear_group_by_key(|(id, _)| *id); for (mgroup, hgroup) in matches.zip(highlights) { debug_assert_eq!(mgroup[0].0, hgroup[0].0); diff --git a/meilidb-data/src/database/synonyms_addition.rs b/meilidb-data/src/database/synonyms_addition.rs index 6e16ab97b..c37f0475a 100644 --- a/meilidb-data/src/database/synonyms_addition.rs +++ b/meilidb-data/src/database/synonyms_addition.rs @@ -21,10 +21,10 @@ impl<'a> SynonymsAddition<'a> { pub fn add_synonym(&mut self, synonym: S, alternatives: I) where S: AsRef, T: AsRef, - I: Iterator, + I: IntoIterator, { let synonym = normalize_str(synonym.as_ref()); - let alternatives = alternatives.map(|s| s.as_ref().to_lowercase()); + let alternatives = alternatives.into_iter().map(|s| s.as_ref().to_lowercase()); self.synonyms.entry(synonym).or_insert_with(Vec::new).extend(alternatives); } diff --git a/meilidb/examples/create-database.rs b/meilidb/examples/create-database.rs index ed07e3742..d8e553ed3 100644 --- a/meilidb/examples/create-database.rs +++ b/meilidb/examples/create-database.rs @@ -31,9 +31,13 @@ pub struct Opt { #[structopt(long = "schema", parse(from_os_str))] pub schema_path: PathBuf, + /// The file with the synonyms. + #[structopt(long = "synonyms", parse(from_os_str))] + pub synonyms: Option, + /// The path to the list of stop words (one by line). #[structopt(long = "stop-words", parse(from_os_str))] - pub stop_words_path: Option, + pub stop_words: Option, #[structopt(long = "update-group-size")] pub update_group_size: Option, @@ -45,12 +49,40 @@ struct Document<'a> ( HashMap, Cow<'a, str>> ); +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Synonym { + OneWay(SynonymOneWay), + MultiWay { synonyms: Vec }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SynonymOneWay { + pub search_terms: String, + pub synonyms: Synonyms, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Synonyms { + Multiple(Vec), + Single(String), +} + +fn read_synomys(path: &Path) -> Result, Box> { + let file = File::open(path)?; + let synonyms = serde_json::from_reader(file)?; + Ok(synonyms) +} + fn index( schema: Schema, database_path: &Path, csv_data_path: &Path, update_group_size: Option, stop_words: &HashSet, + synonyms: Vec, ) -> Result> { let database = Database::start_default(database_path)?; @@ -62,6 +94,28 @@ fn index( let index = database.create_index("test", schema.clone())?; + let mut synonyms_adder = index.synonyms_addition(); + for synonym in synonyms { + match synonym { + Synonym::OneWay(SynonymOneWay { search_terms, synonyms }) => { + let alternatives = match synonyms { + Synonyms::Multiple(alternatives) => alternatives, + Synonyms::Single(alternative) => vec![alternative], + }; + synonyms_adder.add_synonym(search_terms, alternatives); + }, + Synonym::MultiWay { mut synonyms } => { + for _ in 0..synonyms.len() { + if let Some((synonym, alternatives)) = synonyms.split_first() { + synonyms_adder.add_synonym(synonym, alternatives); + } + synonyms.rotate_left(1); + } + }, + } + } + synonyms_adder.finalize()?; + let mut rdr = csv::Reader::from_path(csv_data_path)?; let mut raw_record = csv::StringRecord::new(); let headers = rdr.headers()?.clone(); @@ -133,13 +187,25 @@ fn main() -> Result<(), Box> { Schema::from_toml(file)? }; - let stop_words = match opt.stop_words_path { + let stop_words = match opt.stop_words { Some(ref path) => retrieve_stop_words(path)?, None => HashSet::new(), }; + let synonyms = match opt.synonyms { + Some(ref path) => read_synomys(path)?, + None => Vec::new(), + }; + let start = Instant::now(); - let result = index(schema, &opt.database_path, &opt.csv_data_path, opt.update_group_size, &stop_words); + let result = index( + schema, + &opt.database_path, + &opt.csv_data_path, + opt.update_group_size, + &stop_words, + synonyms, + ); if let Err(e) = result { return Err(e.into())