diff --git a/milli/src/update/word_prefix_pair_proximity_docids/mod.rs b/milli/src/update/word_prefix_pair_proximity_docids/mod.rs index 119c0c53e..5b073bb95 100644 --- a/milli/src/update/word_prefix_pair_proximity_docids/mod.rs +++ b/milli/src/update/word_prefix_pair_proximity_docids/mod.rs @@ -5,6 +5,7 @@ use heed::BytesDecode; use log::debug; use std::borrow::Cow; +use std::cmp::Ordering; use std::collections::HashSet; use std::io::BufReader; use std::time::Instant; @@ -72,6 +73,84 @@ impl<'t, 'u, 'i> WordPrefixPairProximityDocids<'t, 'u, 'i> { del_prefix_fst_words: &HashSet>, ) -> Result<()> { debug!("Computing and writing the word prefix pair proximity docids into LMDB on disk..."); + let mut allocations = Allocations::default(); + + let mut count = 0; + + let prefixes = PrefixTrieNode::from_sorted_prefixes( + common_prefix_fst_words + .into_iter() + .map(|s| s.into_iter()) + .flatten() + .map(|s| s.as_str()) + .filter(|s| s.len() <= self.max_prefix_length), + ); + + if !prefixes.is_empty() { + let mut cursor = new_word_pair_proximity_docids.into_cursor()?; + Self::execute_on_word_pairs_and_prefixes( + &mut cursor, + |cursor| { + if let Some((key, value)) = cursor.move_on_next()? { + let (word1, word2, proximity) = UncheckedStrStrU8Codec::bytes_decode(key) + .ok_or(heed::Error::Decoding)?; + Ok(Some(((word1, word2, proximity), value))) + } else { + Ok(None) + } + }, + &prefixes, + &mut allocations, + self.max_proximity, + |key, value| { + count += 1; + insert_into_database( + &mut self.wtxn, + *self.index.word_prefix_pair_proximity_docids.as_polymorph(), + key, + value, + ) + }, + )?; + } + dbg!(count); + + let prefixes = PrefixTrieNode::from_sorted_prefixes( + new_prefix_fst_words + .into_iter() + .map(|s| s.as_str()) + .filter(|s| s.len() <= self.max_prefix_length), + ); + + if !prefixes.is_empty() { + let mut db_iter = self + .index + .word_pair_proximity_docids + .remap_key_type::() + .remap_data_type::() + .iter(self.wtxn)?; + + let mut writer = create_writer( + self.chunk_compression_type, + self.chunk_compression_level, + tempfile::tempfile()?, + ); + + Self::execute_on_word_pairs_and_prefixes( + &mut db_iter, + |db_iter| db_iter.next().transpose().map_err(|e| e.into()), + &prefixes, + &mut allocations, + self.max_proximity, + |key, value| writer.insert(key, value).map_err(|e| e.into()), + )?; + drop(db_iter); + writer_into_lmdb_database( + self.wtxn, + *self.index.word_prefix_pair_proximity_docids.as_polymorph(), + writer, + )?; + } // All of the word prefix pairs in the database that have a w2 // that is contained in the `suppr_pw` set must be removed as well. @@ -89,131 +168,71 @@ impl<'t, 'u, 'i> WordPrefixPairProximityDocids<'t, 'u, 'i> { } } - // We construct a Trie of all the prefixes that are smaller than the max prefix length - // This is an optimisation that allows us to iterate over all prefixes of a word quickly. - let new_prefix_fst_words = PrefixTrieNode::from_sorted_prefixes( - new_prefix_fst_words - .into_iter() - .map(|s| s.as_str()) - .filter(|s| s.len() <= self.max_prefix_length), - ); - - let common_prefix_fst_words = PrefixTrieNode::from_sorted_prefixes( - common_prefix_fst_words - .into_iter() - .map(|s| s.into_iter()) - .flatten() - .map(|s| s.as_str()) - .filter(|s| s.len() <= self.max_prefix_length), - ); - - let mut allocations = Allocations::default(); - let mut batch = PrefixAndProximityBatch::default(); - - if !common_prefix_fst_words.is_empty() { - let mut cursor = new_word_pair_proximity_docids.into_cursor()?; - - while let Some((key, data)) = cursor.move_on_next()? { - let (word1, word2, proximity) = - UncheckedStrStrU8Codec::bytes_decode(key).ok_or(heed::Error::Decoding)?; - - if proximity <= self.max_proximity { - batch.flush_if_necessary( - word1, - word2, - &mut allocations, - &mut |key, value| { - insert_into_database( - &mut self.wtxn, - *self.index.word_prefix_pair_proximity_docids.as_polymorph(), - key, - value, - ) - }, - )?; - self.insert_word_prefix_pair_proximity_docids_into_batch( - word2, - proximity, - data, - &common_prefix_fst_words, - &mut batch, - &mut allocations, - )?; - } - } - batch.flush(&mut allocations, &mut |key, value| { - insert_into_database( - &mut self.wtxn, - *self.index.word_prefix_pair_proximity_docids.as_polymorph(), - key, - value, - ) - })?; - } - - if !new_prefix_fst_words.is_empty() { - let mut db_iter = self - .index - .word_pair_proximity_docids - .remap_key_type::() - .remap_data_type::() - .iter(self.wtxn)?; - - let mut writer = create_writer( - self.chunk_compression_type, - self.chunk_compression_level, - tempfile::tempfile()?, - ); - - while let Some(((word1, word2, proximity), data)) = db_iter.next().transpose()? { - if proximity <= self.max_proximity { - batch.flush_if_necessary( - word1, - word2, - &mut allocations, - &mut |key, value| writer.insert(key, value).map_err(|e| e.into()), - )?; - self.insert_word_prefix_pair_proximity_docids_into_batch( - word2, - proximity, - data, - &new_prefix_fst_words, - &mut batch, - &mut allocations, - )?; - } - } - batch.flush(&mut allocations, &mut |key, value| { - writer.insert(key, value).map_err(|e| e.into()) - })?; - - drop(db_iter); - writer_into_lmdb_database( - self.wtxn, - *self.index.word_prefix_pair_proximity_docids.as_polymorph(), - writer, - )?; - } - Ok(()) } - fn insert_word_prefix_pair_proximity_docids_into_batch<'b, 'c>( - &self, - word2: &[u8], - proximity: u8, - data: &'b [u8], - prefixes: &'c PrefixTrieNode, - writer: &'b mut PrefixAndProximityBatch, + fn execute_on_word_pairs_and_prefixes( + iter: &mut Iter, + mut next_word_pair_proximity: impl for<'a> FnMut( + &'a mut Iter, + ) -> Result< + Option<((&'a [u8], &'a [u8], u8), &'a [u8])>, + >, + prefixes: &PrefixTrieNode, allocations: &mut Allocations, + max_proximity: u8, + mut insert: impl for<'a> FnMut(&'a [u8], &'a [u8]) -> Result<()>, ) -> Result<()> { + let mut batch = PrefixAndProximityBatch::default(); + let mut prev_word2_start = 0; + + let mut prefix_search_start = PrefixTrieNodeSearchStart(0); + let mut empty_prefixes = false; + let mut prefix_buffer = allocations.take_byte_vector(); - prefixes.for_each_prefix_of(word2, &mut prefix_buffer, |prefix| { - let mut value = allocations.take_byte_vector(); - value.extend_from_slice(&data); - writer.insert(prefix, proximity, value, allocations); - }); - allocations.reclaim_byte_vector(prefix_buffer); + + while let Some(((word1, word2, proximity), data)) = next_word_pair_proximity(iter)? { + if proximity > max_proximity { + continue; + }; + let word2_start_different_than_prev = word2[0] != prev_word2_start; + if empty_prefixes && !word2_start_different_than_prev { + continue; + } + let word1_different_than_prev = word1 != batch.word1; + if word1_different_than_prev || word2_start_different_than_prev { + batch.flush(allocations, &mut insert)?; + if word1_different_than_prev { + prefix_search_start.0 = 0; + batch.word1.clear(); + batch.word1.extend_from_slice(word1); + } + if word2_start_different_than_prev { + // word2_start_different_than_prev == true + prev_word2_start = word2[0]; + } + empty_prefixes = !prefixes.set_search_start(word2, &mut prefix_search_start); + } + + if !empty_prefixes { + prefixes.for_each_prefix_of( + word2, + &mut prefix_buffer, + &prefix_search_start, + |prefix_buffer| { + let mut value = allocations.take_byte_vector(); + value.extend_from_slice(&data); + let prefix_len = prefix_buffer.len(); + prefix_buffer.push(0); + prefix_buffer.push(proximity); + batch.insert(&prefix_buffer, value, allocations); + prefix_buffer.truncate(prefix_len); + }, + ); + prefix_buffer.clear(); + } + } + batch.flush(allocations, &mut insert)?; Ok(()) } } @@ -224,10 +243,6 @@ The keys are sorted and conflicts are resolved by merging the vectors of bitstri It is used to ensure that all ((word1, prefix, proximity), docids) are inserted into the database in sorted order and efficiently. -A batch is valid only for a specific `word1`. Also, all prefixes stored in the batch start with the same letter. Make sure to -call [`self.flush_if_necessary`](Self::flush_if_necessary) before inserting a list of sorted `(prefix, proximity)` (and where each -`prefix` starts with the same letter) in order to uphold these invariants. - The batch is flushed as often as possible, when we are sure that every (word1, prefix, proximity) key derived from its content can be inserted into the database in sorted order. When it is flushed, it calls a user-provided closure with the following arguments: - key : (word1, prefix, proximity) as bytes @@ -235,91 +250,95 @@ can be inserted into the database in sorted order. When it is flushed, it calls */ #[derive(Default)] struct PrefixAndProximityBatch { - batch: Vec<(Vec, Vec>)>, word1: Vec, - word2_start: u8, + batch: Vec<(Vec, Vec>)>, } impl PrefixAndProximityBatch { - fn insert( - &mut self, - new_prefix: &[u8], - new_proximity: u8, - new_value: Vec, - allocations: &mut Allocations, - ) { - let mut key = allocations.take_byte_vector(); - key.extend_from_slice(new_prefix); - key.push(0); - key.push(new_proximity); - - if let Some(position) = self.batch.iter().position(|(k, _)| k >= &key) { - let (existing_key, existing_data) = &mut self.batch[position]; - if existing_key == &key { - existing_data.push(Cow::Owned(new_value)); - } else { + fn insert(&mut self, new_key: &[u8], new_value: Vec, allocations: &mut Allocations) { + // this is a macro instead of a closure because the borrow checker will complain + // about the closure moving `new_value` + macro_rules! insert_new_key_value { + () => { + let mut key = allocations.take_byte_vector(); + key.extend_from_slice(new_key); let mut mergeable_data = allocations.take_mergeable_data_vector(); mergeable_data.push(Cow::Owned(new_value)); - self.batch.insert(position, (key, mergeable_data)); + self.batch.push((key, mergeable_data)); + }; + ($idx:expr) => { + let mut key = allocations.take_byte_vector(); + key.extend_from_slice(new_key); + let mut mergeable_data = allocations.take_mergeable_data_vector(); + mergeable_data.push(Cow::Owned(new_value)); + self.batch.insert($idx, (key, mergeable_data)); + }; + } + + if self.batch.is_empty() { + insert_new_key_value!(); + } else if self.batch.len() == 1 { + let (existing_key, existing_data) = &mut self.batch[0]; + match new_key.cmp(&existing_key) { + Ordering::Less => { + insert_new_key_value!(0); + } + Ordering::Equal => { + existing_data.push(Cow::Owned(new_value)); + } + Ordering::Greater => { + insert_new_key_value!(); + } } } else { - let mut mergeable_data = allocations.take_mergeable_data_vector(); - mergeable_data.push(Cow::Owned(new_value)); - self.batch.push((key, mergeable_data)); - } - } - - /// Call [`self.flush`](Self::flush) if `word1` changed or if `word2` begins with a different letter than the - /// previous word2. Update `prev_word1` and `prev_word2_start` with the new values from `word1` and `word2`. - fn flush_if_necessary( - &mut self, - word1: &[u8], - word2: &[u8], - allocations: &mut Allocations, - insert: &mut impl for<'buffer> FnMut(&'buffer [u8], &'buffer [u8]) -> Result<()>, - ) -> Result<()> { - let word2_start = word2[0]; - if word1 != self.word1 { - self.flush(allocations, insert)?; - self.word1.clear(); - self.word1.extend_from_slice(word1); - if word2_start != self.word2_start { - self.word2_start = word2_start; + match self.batch.binary_search_by_key(&new_key, |(k, _)| k.as_slice()) { + Ok(position) => { + self.batch[position].1.push(Cow::Owned(new_value)); + } + Err(position) => { + insert_new_key_value!(position); + } } } - if word2_start != self.word2_start { - self.flush(allocations, insert)?; - self.word2_start = word2_start; - } - Ok(()) } /// Empties the batch, calling `insert` on each element. /// - /// The key given to insert is `(word1, prefix, proximity)` and the value is the associated merged roaring bitmap. + /// The key given to `insert` is `(word1, prefix, proximity)` and the value is the associated merged roaring bitmap. fn flush( &mut self, allocations: &mut Allocations, insert: &mut impl for<'buffer> FnMut(&'buffer [u8], &'buffer [u8]) -> Result<()>, ) -> Result<()> { - let PrefixAndProximityBatch { batch, word1: prev_word1, word2_start: _ } = self; + let PrefixAndProximityBatch { word1, batch } = self; + if batch.is_empty() { + return Ok(()); + } + let mut buffer = allocations.take_byte_vector(); - buffer.extend_from_slice(prev_word1.as_slice()); + buffer.extend_from_slice(word1); buffer.push(0); for (key, mergeable_data) in batch.drain(..) { - buffer.truncate(prev_word1.len() + 1); + buffer.truncate(word1.len() + 1); buffer.extend_from_slice(key.as_slice()); - let data = merge_cbo_roaring_bitmaps(&buffer, &mergeable_data)?; - insert(buffer.as_slice(), &data)?; - + let merged; + let data = if mergeable_data.len() > 1 { + merged = merge_cbo_roaring_bitmaps(&buffer, &mergeable_data)?; + &merged + } else { + &mergeable_data[0] + }; + insert(buffer.as_slice(), data)?; allocations.reclaim_byte_vector(key); allocations.reclaim_mergeable_data_vector(mergeable_data); } + Ok(()) } } +// This is adapted from `sorter_into_lmdb_database` fn insert_into_database( wtxn: &mut heed::RwTxn, database: heed::PolyDatabase, @@ -356,7 +375,8 @@ pub fn writer_into_lmdb_database( ) -> Result<()> { let file = writer.into_inner()?; let reader = grenad::Reader::new(BufReader::new(file))?; - + let len = reader.len(); + dbg!(len); let before = Instant::now(); if database.is_empty(wtxn)? { @@ -413,10 +433,44 @@ struct PrefixTrieNode { is_end_node: bool, } +#[derive(Debug)] +struct PrefixTrieNodeSearchStart(usize); + impl PrefixTrieNode { fn is_empty(&self) -> bool { self.children.is_empty() } + + /// Returns false if the trie does not contain a prefix of the given word. + /// Returns true if the trie *may* contain a prefix of the given word. + /// + /// Moves the search start to the first node equal to the first letter of the word, + /// or to 0 otherwise. + fn set_search_start(&self, word: &[u8], search_start: &mut PrefixTrieNodeSearchStart) -> bool { + let byte = word[0]; + if self.children[search_start.0].1 == byte { + return true; + } else if let Some(position) = + self.children[search_start.0..].iter().position(|(_, c)| *c >= byte) + { + let (_, c) = self.children[search_start.0 + position]; + // dbg!(position, c, byte); + if c == byte { + // dbg!(); + search_start.0 += position; + true + } else { + // dbg!(); + search_start.0 = 0; + false + } + } else { + // dbg!(); + search_start.0 = 0; + false + } + } + fn from_sorted_prefixes<'a>(prefixes: impl Iterator) -> Self { let mut node = PrefixTrieNode::default(); for prefix in prefixes { @@ -439,17 +493,41 @@ impl PrefixTrieNode { self.is_end_node = true; } } - fn for_each_prefix_of(&self, word: &[u8], buffer: &mut Vec, mut do_fn: impl FnMut(&[u8])) { + fn for_each_prefix_of( + &self, + word: &[u8], + buffer: &mut Vec, + search_start: &PrefixTrieNodeSearchStart, + mut do_fn: impl FnMut(&mut Vec), + ) { + let first_byte = word[0]; let mut cur_node = self; - for &byte in word { - buffer.push(byte); - if let Some((child_node, _)) = cur_node.children.iter().find(|(_, c)| *c == byte) { + buffer.push(first_byte); + if let Some((child_node, c)) = + cur_node.children[search_start.0..].iter().find(|(_, c)| *c >= first_byte) + { + if *c == first_byte { cur_node = child_node; if cur_node.is_end_node { - do_fn(buffer.as_slice()); + do_fn(buffer); + } + for &byte in &word[1..] { + buffer.push(byte); + if let Some((child_node, c)) = + cur_node.children.iter().find(|(_, c)| *c >= byte) + { + if *c == byte { + cur_node = child_node; + if cur_node.is_end_node { + do_fn(buffer); + } + } else { + break; + } + } else { + break; + } } - } else { - break; } } } @@ -466,3 +544,66 @@ impl PrefixTrieNode { // } // } } +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_trie() { + let trie = PrefixTrieNode::from_sorted_prefixes(IntoIterator::into_iter([ + "1", "19", "2", "a", "ab", "ac", "ad", "al", "am", "an", "ap", "ar", "as", "at", "au", + "b", "ba", "bar", "be", "bi", "bl", "bla", "bo", "br", "bra", "bri", "bro", "bu", "c", + "ca", "car", "ce", "ch", "cha", "che", "chi", "ci", "cl", "cla", "co", "col", "com", + "comp", "con", "cons", "cont", "cor", "cou", "cr", "cu", "d", "da", "de", "dec", "des", + "di", "dis", "do", "dr", "du", "e", "el", "em", "en", "es", "ev", "ex", "exp", "f", + "fa", "fe", "fi", "fl", "fo", "for", "fr", "fra", "fre", "fu", "g", "ga", "ge", "gi", + "gl", "go", "gr", "gra", "gu", "h", "ha", "har", "he", "hea", "hi", "ho", "hu", "i", + "im", "imp", "in", "ind", "ins", "int", "inte", "j", "ja", "je", "jo", "ju", "k", "ka", + "ke", "ki", "ko", "l", "la", "le", "li", "lo", "lu", "m", "ma", "mal", "man", "mar", + "mat", "mc", "me", "mi", "min", "mis", "mo", "mon", "mor", "mu", "n", "na", "ne", "ni", + "no", "o", "or", "ou", "ov", "ove", "over", "p", "pa", "par", "pe", "per", "ph", "pi", + "pl", "po", "pr", "pre", "pro", "pu", "q", "qu", "r", "ra", "re", "rec", "rep", "res", + "ri", "ro", "ru", "s", "sa", "san", "sc", "sch", "se", "sh", "sha", "shi", "sho", "si", + "sk", "sl", "sn", "so", "sp", "st", "sta", "ste", "sto", "str", "su", "sup", "sw", "t", + "ta", "te", "th", "ti", "to", "tr", "tra", "tri", "tu", "u", "un", "v", "va", "ve", + "vi", "vo", "w", "wa", "we", "wh", "wi", "wo", "y", "yo", "z", + ])); + // let mut buffer = String::new(); + // trie.print(&mut buffer, 0); + // buffer.clear(); + let mut search_start = PrefixTrieNodeSearchStart(0); + let mut buffer = vec![]; + + let is_empty = !trie.set_search_start("affair".as_bytes(), &mut search_start); + println!("{search_start:?}"); + println!("is empty: {is_empty}"); + trie.for_each_prefix_of("affair".as_bytes(), &mut buffer, &search_start, |x| { + let s = std::str::from_utf8(x).unwrap(); + println!("{s}"); + }); + buffer.clear(); + trie.for_each_prefix_of("trans".as_bytes(), &mut buffer, &search_start, |x| { + let s = std::str::from_utf8(x).unwrap(); + println!("{s}"); + }); + buffer.clear(); + + trie.for_each_prefix_of("affair".as_bytes(), &mut buffer, &search_start, |x| { + let s = std::str::from_utf8(x).unwrap(); + println!("{s}"); + }); + buffer.clear(); + // trie.for_each_prefix_of("1", |x| { + // println!("{x}"); + // }); + // trie.for_each_prefix_of("19", |x| { + // println!("{x}"); + // }); + // trie.for_each_prefix_of("21", |x| { + // println!("{x}"); + // }); + // let mut buffer = vec![]; + // trie.for_each_prefix_of("integ", &mut buffer, |x| { + // println!("{x}"); + // }); + } +}