Refactor word prefix pair proximity indexation further

This commit is contained in:
Loïc Lecrenier 2022-07-13 19:35:17 +02:00
parent 306593144d
commit 86807ca848

View File

@ -5,6 +5,7 @@ use heed::BytesDecode;
use log::debug; use log::debug;
use std::borrow::Cow; use std::borrow::Cow;
use std::cmp::Ordering;
use std::collections::HashSet; use std::collections::HashSet;
use std::io::BufReader; use std::io::BufReader;
use std::time::Instant; use std::time::Instant;
@ -72,6 +73,84 @@ impl<'t, 'u, 'i> WordPrefixPairProximityDocids<'t, 'u, 'i> {
del_prefix_fst_words: &HashSet<Vec<u8>>, del_prefix_fst_words: &HashSet<Vec<u8>>,
) -> Result<()> { ) -> Result<()> {
debug!("Computing and writing the word prefix pair proximity docids into LMDB on disk..."); debug!("Computing and writing the word prefix pair proximity docids into LMDB on disk...");
let mut allocations = Allocations::default();
let mut count = 0;
let prefixes = PrefixTrieNode::from_sorted_prefixes(
common_prefix_fst_words
.into_iter()
.map(|s| s.into_iter())
.flatten()
.map(|s| s.as_str())
.filter(|s| s.len() <= self.max_prefix_length),
);
if !prefixes.is_empty() {
let mut cursor = new_word_pair_proximity_docids.into_cursor()?;
Self::execute_on_word_pairs_and_prefixes(
&mut cursor,
|cursor| {
if let Some((key, value)) = cursor.move_on_next()? {
let (word1, word2, proximity) = UncheckedStrStrU8Codec::bytes_decode(key)
.ok_or(heed::Error::Decoding)?;
Ok(Some(((word1, word2, proximity), value)))
} else {
Ok(None)
}
},
&prefixes,
&mut allocations,
self.max_proximity,
|key, value| {
count += 1;
insert_into_database(
&mut self.wtxn,
*self.index.word_prefix_pair_proximity_docids.as_polymorph(),
key,
value,
)
},
)?;
}
dbg!(count);
let prefixes = PrefixTrieNode::from_sorted_prefixes(
new_prefix_fst_words
.into_iter()
.map(|s| s.as_str())
.filter(|s| s.len() <= self.max_prefix_length),
);
if !prefixes.is_empty() {
let mut db_iter = self
.index
.word_pair_proximity_docids
.remap_key_type::<UncheckedStrStrU8Codec>()
.remap_data_type::<ByteSlice>()
.iter(self.wtxn)?;
let mut writer = create_writer(
self.chunk_compression_type,
self.chunk_compression_level,
tempfile::tempfile()?,
);
Self::execute_on_word_pairs_and_prefixes(
&mut db_iter,
|db_iter| db_iter.next().transpose().map_err(|e| e.into()),
&prefixes,
&mut allocations,
self.max_proximity,
|key, value| writer.insert(key, value).map_err(|e| e.into()),
)?;
drop(db_iter);
writer_into_lmdb_database(
self.wtxn,
*self.index.word_prefix_pair_proximity_docids.as_polymorph(),
writer,
)?;
}
// All of the word prefix pairs in the database that have a w2 // All of the word prefix pairs in the database that have a w2
// that is contained in the `suppr_pw` set must be removed as well. // that is contained in the `suppr_pw` set must be removed as well.
@ -89,131 +168,71 @@ impl<'t, 'u, 'i> WordPrefixPairProximityDocids<'t, 'u, 'i> {
} }
} }
// We construct a Trie of all the prefixes that are smaller than the max prefix length
// This is an optimisation that allows us to iterate over all prefixes of a word quickly.
let new_prefix_fst_words = PrefixTrieNode::from_sorted_prefixes(
new_prefix_fst_words
.into_iter()
.map(|s| s.as_str())
.filter(|s| s.len() <= self.max_prefix_length),
);
let common_prefix_fst_words = PrefixTrieNode::from_sorted_prefixes(
common_prefix_fst_words
.into_iter()
.map(|s| s.into_iter())
.flatten()
.map(|s| s.as_str())
.filter(|s| s.len() <= self.max_prefix_length),
);
let mut allocations = Allocations::default();
let mut batch = PrefixAndProximityBatch::default();
if !common_prefix_fst_words.is_empty() {
let mut cursor = new_word_pair_proximity_docids.into_cursor()?;
while let Some((key, data)) = cursor.move_on_next()? {
let (word1, word2, proximity) =
UncheckedStrStrU8Codec::bytes_decode(key).ok_or(heed::Error::Decoding)?;
if proximity <= self.max_proximity {
batch.flush_if_necessary(
word1,
word2,
&mut allocations,
&mut |key, value| {
insert_into_database(
&mut self.wtxn,
*self.index.word_prefix_pair_proximity_docids.as_polymorph(),
key,
value,
)
},
)?;
self.insert_word_prefix_pair_proximity_docids_into_batch(
word2,
proximity,
data,
&common_prefix_fst_words,
&mut batch,
&mut allocations,
)?;
}
}
batch.flush(&mut allocations, &mut |key, value| {
insert_into_database(
&mut self.wtxn,
*self.index.word_prefix_pair_proximity_docids.as_polymorph(),
key,
value,
)
})?;
}
if !new_prefix_fst_words.is_empty() {
let mut db_iter = self
.index
.word_pair_proximity_docids
.remap_key_type::<UncheckedStrStrU8Codec>()
.remap_data_type::<ByteSlice>()
.iter(self.wtxn)?;
let mut writer = create_writer(
self.chunk_compression_type,
self.chunk_compression_level,
tempfile::tempfile()?,
);
while let Some(((word1, word2, proximity), data)) = db_iter.next().transpose()? {
if proximity <= self.max_proximity {
batch.flush_if_necessary(
word1,
word2,
&mut allocations,
&mut |key, value| writer.insert(key, value).map_err(|e| e.into()),
)?;
self.insert_word_prefix_pair_proximity_docids_into_batch(
word2,
proximity,
data,
&new_prefix_fst_words,
&mut batch,
&mut allocations,
)?;
}
}
batch.flush(&mut allocations, &mut |key, value| {
writer.insert(key, value).map_err(|e| e.into())
})?;
drop(db_iter);
writer_into_lmdb_database(
self.wtxn,
*self.index.word_prefix_pair_proximity_docids.as_polymorph(),
writer,
)?;
}
Ok(()) Ok(())
} }
fn insert_word_prefix_pair_proximity_docids_into_batch<'b, 'c>( fn execute_on_word_pairs_and_prefixes<Iter>(
&self, iter: &mut Iter,
word2: &[u8], mut next_word_pair_proximity: impl for<'a> FnMut(
proximity: u8, &'a mut Iter,
data: &'b [u8], ) -> Result<
prefixes: &'c PrefixTrieNode, Option<((&'a [u8], &'a [u8], u8), &'a [u8])>,
writer: &'b mut PrefixAndProximityBatch, >,
prefixes: &PrefixTrieNode,
allocations: &mut Allocations, allocations: &mut Allocations,
max_proximity: u8,
mut insert: impl for<'a> FnMut(&'a [u8], &'a [u8]) -> Result<()>,
) -> Result<()> { ) -> Result<()> {
let mut batch = PrefixAndProximityBatch::default();
let mut prev_word2_start = 0;
let mut prefix_search_start = PrefixTrieNodeSearchStart(0);
let mut empty_prefixes = false;
let mut prefix_buffer = allocations.take_byte_vector(); let mut prefix_buffer = allocations.take_byte_vector();
prefixes.for_each_prefix_of(word2, &mut prefix_buffer, |prefix| {
let mut value = allocations.take_byte_vector(); while let Some(((word1, word2, proximity), data)) = next_word_pair_proximity(iter)? {
value.extend_from_slice(&data); if proximity > max_proximity {
writer.insert(prefix, proximity, value, allocations); continue;
}); };
allocations.reclaim_byte_vector(prefix_buffer); let word2_start_different_than_prev = word2[0] != prev_word2_start;
if empty_prefixes && !word2_start_different_than_prev {
continue;
}
let word1_different_than_prev = word1 != batch.word1;
if word1_different_than_prev || word2_start_different_than_prev {
batch.flush(allocations, &mut insert)?;
if word1_different_than_prev {
prefix_search_start.0 = 0;
batch.word1.clear();
batch.word1.extend_from_slice(word1);
}
if word2_start_different_than_prev {
// word2_start_different_than_prev == true
prev_word2_start = word2[0];
}
empty_prefixes = !prefixes.set_search_start(word2, &mut prefix_search_start);
}
if !empty_prefixes {
prefixes.for_each_prefix_of(
word2,
&mut prefix_buffer,
&prefix_search_start,
|prefix_buffer| {
let mut value = allocations.take_byte_vector();
value.extend_from_slice(&data);
let prefix_len = prefix_buffer.len();
prefix_buffer.push(0);
prefix_buffer.push(proximity);
batch.insert(&prefix_buffer, value, allocations);
prefix_buffer.truncate(prefix_len);
},
);
prefix_buffer.clear();
}
}
batch.flush(allocations, &mut insert)?;
Ok(()) Ok(())
} }
} }
@ -224,10 +243,6 @@ The keys are sorted and conflicts are resolved by merging the vectors of bitstri
It is used to ensure that all ((word1, prefix, proximity), docids) are inserted into the database in sorted order and efficiently. It is used to ensure that all ((word1, prefix, proximity), docids) are inserted into the database in sorted order and efficiently.
A batch is valid only for a specific `word1`. Also, all prefixes stored in the batch start with the same letter. Make sure to
call [`self.flush_if_necessary`](Self::flush_if_necessary) before inserting a list of sorted `(prefix, proximity)` (and where each
`prefix` starts with the same letter) in order to uphold these invariants.
The batch is flushed as often as possible, when we are sure that every (word1, prefix, proximity) key derived from its content The batch is flushed as often as possible, when we are sure that every (word1, prefix, proximity) key derived from its content
can be inserted into the database in sorted order. When it is flushed, it calls a user-provided closure with the following arguments: can be inserted into the database in sorted order. When it is flushed, it calls a user-provided closure with the following arguments:
- key : (word1, prefix, proximity) as bytes - key : (word1, prefix, proximity) as bytes
@ -235,91 +250,95 @@ can be inserted into the database in sorted order. When it is flushed, it calls
*/ */
#[derive(Default)] #[derive(Default)]
struct PrefixAndProximityBatch { struct PrefixAndProximityBatch {
batch: Vec<(Vec<u8>, Vec<Cow<'static, [u8]>>)>,
word1: Vec<u8>, word1: Vec<u8>,
word2_start: u8, batch: Vec<(Vec<u8>, Vec<Cow<'static, [u8]>>)>,
} }
impl PrefixAndProximityBatch { impl PrefixAndProximityBatch {
fn insert( fn insert(&mut self, new_key: &[u8], new_value: Vec<u8>, allocations: &mut Allocations) {
&mut self, // this is a macro instead of a closure because the borrow checker will complain
new_prefix: &[u8], // about the closure moving `new_value`
new_proximity: u8, macro_rules! insert_new_key_value {
new_value: Vec<u8>, () => {
allocations: &mut Allocations, let mut key = allocations.take_byte_vector();
) { key.extend_from_slice(new_key);
let mut key = allocations.take_byte_vector();
key.extend_from_slice(new_prefix);
key.push(0);
key.push(new_proximity);
if let Some(position) = self.batch.iter().position(|(k, _)| k >= &key) {
let (existing_key, existing_data) = &mut self.batch[position];
if existing_key == &key {
existing_data.push(Cow::Owned(new_value));
} else {
let mut mergeable_data = allocations.take_mergeable_data_vector(); let mut mergeable_data = allocations.take_mergeable_data_vector();
mergeable_data.push(Cow::Owned(new_value)); mergeable_data.push(Cow::Owned(new_value));
self.batch.insert(position, (key, mergeable_data)); self.batch.push((key, mergeable_data));
};
($idx:expr) => {
let mut key = allocations.take_byte_vector();
key.extend_from_slice(new_key);
let mut mergeable_data = allocations.take_mergeable_data_vector();
mergeable_data.push(Cow::Owned(new_value));
self.batch.insert($idx, (key, mergeable_data));
};
}
if self.batch.is_empty() {
insert_new_key_value!();
} else if self.batch.len() == 1 {
let (existing_key, existing_data) = &mut self.batch[0];
match new_key.cmp(&existing_key) {
Ordering::Less => {
insert_new_key_value!(0);
}
Ordering::Equal => {
existing_data.push(Cow::Owned(new_value));
}
Ordering::Greater => {
insert_new_key_value!();
}
} }
} else { } else {
let mut mergeable_data = allocations.take_mergeable_data_vector(); match self.batch.binary_search_by_key(&new_key, |(k, _)| k.as_slice()) {
mergeable_data.push(Cow::Owned(new_value)); Ok(position) => {
self.batch.push((key, mergeable_data)); self.batch[position].1.push(Cow::Owned(new_value));
} }
} Err(position) => {
insert_new_key_value!(position);
/// Call [`self.flush`](Self::flush) if `word1` changed or if `word2` begins with a different letter than the }
/// previous word2. Update `prev_word1` and `prev_word2_start` with the new values from `word1` and `word2`.
fn flush_if_necessary(
&mut self,
word1: &[u8],
word2: &[u8],
allocations: &mut Allocations,
insert: &mut impl for<'buffer> FnMut(&'buffer [u8], &'buffer [u8]) -> Result<()>,
) -> Result<()> {
let word2_start = word2[0];
if word1 != self.word1 {
self.flush(allocations, insert)?;
self.word1.clear();
self.word1.extend_from_slice(word1);
if word2_start != self.word2_start {
self.word2_start = word2_start;
} }
} }
if word2_start != self.word2_start {
self.flush(allocations, insert)?;
self.word2_start = word2_start;
}
Ok(())
} }
/// Empties the batch, calling `insert` on each element. /// Empties the batch, calling `insert` on each element.
/// ///
/// The key given to insert is `(word1, prefix, proximity)` and the value is the associated merged roaring bitmap. /// The key given to `insert` is `(word1, prefix, proximity)` and the value is the associated merged roaring bitmap.
fn flush( fn flush(
&mut self, &mut self,
allocations: &mut Allocations, allocations: &mut Allocations,
insert: &mut impl for<'buffer> FnMut(&'buffer [u8], &'buffer [u8]) -> Result<()>, insert: &mut impl for<'buffer> FnMut(&'buffer [u8], &'buffer [u8]) -> Result<()>,
) -> Result<()> { ) -> Result<()> {
let PrefixAndProximityBatch { batch, word1: prev_word1, word2_start: _ } = self; let PrefixAndProximityBatch { word1, batch } = self;
if batch.is_empty() {
return Ok(());
}
let mut buffer = allocations.take_byte_vector(); let mut buffer = allocations.take_byte_vector();
buffer.extend_from_slice(prev_word1.as_slice()); buffer.extend_from_slice(word1);
buffer.push(0); buffer.push(0);
for (key, mergeable_data) in batch.drain(..) { for (key, mergeable_data) in batch.drain(..) {
buffer.truncate(prev_word1.len() + 1); buffer.truncate(word1.len() + 1);
buffer.extend_from_slice(key.as_slice()); buffer.extend_from_slice(key.as_slice());
let data = merge_cbo_roaring_bitmaps(&buffer, &mergeable_data)?; let merged;
insert(buffer.as_slice(), &data)?; let data = if mergeable_data.len() > 1 {
merged = merge_cbo_roaring_bitmaps(&buffer, &mergeable_data)?;
&merged
} else {
&mergeable_data[0]
};
insert(buffer.as_slice(), data)?;
allocations.reclaim_byte_vector(key); allocations.reclaim_byte_vector(key);
allocations.reclaim_mergeable_data_vector(mergeable_data); allocations.reclaim_mergeable_data_vector(mergeable_data);
} }
Ok(()) Ok(())
} }
} }
// This is adapted from `sorter_into_lmdb_database`
fn insert_into_database( fn insert_into_database(
wtxn: &mut heed::RwTxn, wtxn: &mut heed::RwTxn,
database: heed::PolyDatabase, database: heed::PolyDatabase,
@ -356,7 +375,8 @@ pub fn writer_into_lmdb_database(
) -> Result<()> { ) -> Result<()> {
let file = writer.into_inner()?; let file = writer.into_inner()?;
let reader = grenad::Reader::new(BufReader::new(file))?; let reader = grenad::Reader::new(BufReader::new(file))?;
let len = reader.len();
dbg!(len);
let before = Instant::now(); let before = Instant::now();
if database.is_empty(wtxn)? { if database.is_empty(wtxn)? {
@ -413,10 +433,44 @@ struct PrefixTrieNode {
is_end_node: bool, is_end_node: bool,
} }
#[derive(Debug)]
struct PrefixTrieNodeSearchStart(usize);
impl PrefixTrieNode { impl PrefixTrieNode {
fn is_empty(&self) -> bool { fn is_empty(&self) -> bool {
self.children.is_empty() self.children.is_empty()
} }
/// Returns false if the trie does not contain a prefix of the given word.
/// Returns true if the trie *may* contain a prefix of the given word.
///
/// Moves the search start to the first node equal to the first letter of the word,
/// or to 0 otherwise.
fn set_search_start(&self, word: &[u8], search_start: &mut PrefixTrieNodeSearchStart) -> bool {
let byte = word[0];
if self.children[search_start.0].1 == byte {
return true;
} else if let Some(position) =
self.children[search_start.0..].iter().position(|(_, c)| *c >= byte)
{
let (_, c) = self.children[search_start.0 + position];
// dbg!(position, c, byte);
if c == byte {
// dbg!();
search_start.0 += position;
true
} else {
// dbg!();
search_start.0 = 0;
false
}
} else {
// dbg!();
search_start.0 = 0;
false
}
}
fn from_sorted_prefixes<'a>(prefixes: impl Iterator<Item = &'a str>) -> Self { fn from_sorted_prefixes<'a>(prefixes: impl Iterator<Item = &'a str>) -> Self {
let mut node = PrefixTrieNode::default(); let mut node = PrefixTrieNode::default();
for prefix in prefixes { for prefix in prefixes {
@ -439,17 +493,41 @@ impl PrefixTrieNode {
self.is_end_node = true; self.is_end_node = true;
} }
} }
fn for_each_prefix_of(&self, word: &[u8], buffer: &mut Vec<u8>, mut do_fn: impl FnMut(&[u8])) { fn for_each_prefix_of(
&self,
word: &[u8],
buffer: &mut Vec<u8>,
search_start: &PrefixTrieNodeSearchStart,
mut do_fn: impl FnMut(&mut Vec<u8>),
) {
let first_byte = word[0];
let mut cur_node = self; let mut cur_node = self;
for &byte in word { buffer.push(first_byte);
buffer.push(byte); if let Some((child_node, c)) =
if let Some((child_node, _)) = cur_node.children.iter().find(|(_, c)| *c == byte) { cur_node.children[search_start.0..].iter().find(|(_, c)| *c >= first_byte)
{
if *c == first_byte {
cur_node = child_node; cur_node = child_node;
if cur_node.is_end_node { if cur_node.is_end_node {
do_fn(buffer.as_slice()); do_fn(buffer);
}
for &byte in &word[1..] {
buffer.push(byte);
if let Some((child_node, c)) =
cur_node.children.iter().find(|(_, c)| *c >= byte)
{
if *c == byte {
cur_node = child_node;
if cur_node.is_end_node {
do_fn(buffer);
}
} else {
break;
}
} else {
break;
}
} }
} else {
break;
} }
} }
} }
@ -466,3 +544,66 @@ impl PrefixTrieNode {
// } // }
// } // }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_trie() {
let trie = PrefixTrieNode::from_sorted_prefixes(IntoIterator::into_iter([
"1", "19", "2", "a", "ab", "ac", "ad", "al", "am", "an", "ap", "ar", "as", "at", "au",
"b", "ba", "bar", "be", "bi", "bl", "bla", "bo", "br", "bra", "bri", "bro", "bu", "c",
"ca", "car", "ce", "ch", "cha", "che", "chi", "ci", "cl", "cla", "co", "col", "com",
"comp", "con", "cons", "cont", "cor", "cou", "cr", "cu", "d", "da", "de", "dec", "des",
"di", "dis", "do", "dr", "du", "e", "el", "em", "en", "es", "ev", "ex", "exp", "f",
"fa", "fe", "fi", "fl", "fo", "for", "fr", "fra", "fre", "fu", "g", "ga", "ge", "gi",
"gl", "go", "gr", "gra", "gu", "h", "ha", "har", "he", "hea", "hi", "ho", "hu", "i",
"im", "imp", "in", "ind", "ins", "int", "inte", "j", "ja", "je", "jo", "ju", "k", "ka",
"ke", "ki", "ko", "l", "la", "le", "li", "lo", "lu", "m", "ma", "mal", "man", "mar",
"mat", "mc", "me", "mi", "min", "mis", "mo", "mon", "mor", "mu", "n", "na", "ne", "ni",
"no", "o", "or", "ou", "ov", "ove", "over", "p", "pa", "par", "pe", "per", "ph", "pi",
"pl", "po", "pr", "pre", "pro", "pu", "q", "qu", "r", "ra", "re", "rec", "rep", "res",
"ri", "ro", "ru", "s", "sa", "san", "sc", "sch", "se", "sh", "sha", "shi", "sho", "si",
"sk", "sl", "sn", "so", "sp", "st", "sta", "ste", "sto", "str", "su", "sup", "sw", "t",
"ta", "te", "th", "ti", "to", "tr", "tra", "tri", "tu", "u", "un", "v", "va", "ve",
"vi", "vo", "w", "wa", "we", "wh", "wi", "wo", "y", "yo", "z",
]));
// let mut buffer = String::new();
// trie.print(&mut buffer, 0);
// buffer.clear();
let mut search_start = PrefixTrieNodeSearchStart(0);
let mut buffer = vec![];
let is_empty = !trie.set_search_start("affair".as_bytes(), &mut search_start);
println!("{search_start:?}");
println!("is empty: {is_empty}");
trie.for_each_prefix_of("affair".as_bytes(), &mut buffer, &search_start, |x| {
let s = std::str::from_utf8(x).unwrap();
println!("{s}");
});
buffer.clear();
trie.for_each_prefix_of("trans".as_bytes(), &mut buffer, &search_start, |x| {
let s = std::str::from_utf8(x).unwrap();
println!("{s}");
});
buffer.clear();
trie.for_each_prefix_of("affair".as_bytes(), &mut buffer, &search_start, |x| {
let s = std::str::from_utf8(x).unwrap();
println!("{s}");
});
buffer.clear();
// trie.for_each_prefix_of("1", |x| {
// println!("{x}");
// });
// trie.for_each_prefix_of("19", |x| {
// println!("{x}");
// });
// trie.for_each_prefix_of("21", |x| {
// println!("{x}");
// });
// let mut buffer = vec![];
// trie.for_each_prefix_of("integ", &mut buffer, |x| {
// println!("{x}");
// });
}
}