diff --git a/src/bin/infos.rs b/src/bin/infos.rs index 270892e90..e68f52357 100644 --- a/src/bin/infos.rs +++ b/src/bin/infos.rs @@ -118,10 +118,7 @@ fn most_common_words(index: &Index, rtxn: &heed::RoTxn, limit: usize) -> anyhow: for result in index.word_position_docids.iter(rtxn)? { if limit == 0 { break } - let (bytes, postings) = result?; - let (word, _position) = bytes.split_at(bytes.len() - 4); - let word = str::from_utf8(word)?; - + let ((word, _position), postings) = result?; match prev.as_mut() { Some((prev_word, freq, docids)) if prev_word == word => { *freq += postings.len(); @@ -153,6 +150,9 @@ fn most_common_words(index: &Index, rtxn: &heed::RoTxn, limit: usize) -> anyhow: } fn words_frequencies(index: &Index, rtxn: &heed::RoTxn, words: Vec) -> anyhow::Result<()> { + use heed::BytesDecode; + use heed::types::ByteSlice; + use milli::heed_codec::{RoaringBitmapCodec, StrBEU32Codec}; use roaring::RoaringBitmap; let stdout = io::stdout(); @@ -162,13 +162,14 @@ fn words_frequencies(index: &Index, rtxn: &heed::RoTxn, words: Vec) -> a for word in words { let mut document_frequency = RoaringBitmap::new(); let mut frequency = 0; - for result in index.word_position_docids.prefix_iter(rtxn, word.as_bytes())? { + let db = index.word_position_docids.as_polymorph(); + for result in db.prefix_iter::<_, ByteSlice, RoaringBitmapCodec>(rtxn, word.as_bytes())? { let (bytes, postings) = result?; - let (w, _position) = bytes.split_at(bytes.len() - 4); + let (w, _position) = StrBEU32Codec::bytes_decode(bytes).unwrap(); // if the word is not exactly the word we requested then it means // we found a word that *starts with* the requested word and we must stop. - if word.as_bytes() != w { break } + if word != w { break } document_frequency.union_with(&postings); frequency += postings.len(); @@ -182,8 +183,9 @@ fn words_frequencies(index: &Index, rtxn: &heed::RoTxn, words: Vec) -> a fn biggest_value_sizes(index: &Index, rtxn: &heed::RoTxn, limit: usize) -> anyhow::Result<()> { use std::cmp::Reverse; use std::collections::BinaryHeap; - use std::convert::TryInto; + use heed::BytesDecode; use heed::types::{Str, ByteSlice}; + use milli::heed_codec::StrBEU32Codec; let main_name = "main"; let word_positions_name = "word_positions"; @@ -206,10 +208,7 @@ fn biggest_value_sizes(index: &Index, rtxn: &heed::RoTxn, limit: usize) -> anyho for result in index.word_position_docids.as_polymorph().iter::<_, ByteSlice, ByteSlice>(rtxn)? { let (key_bytes, value) = result?; - let (word, position) = key_bytes.split_at(key_bytes.len() - 4); - let word = str::from_utf8(word)?; - let position = position.try_into().map(u32::from_be_bytes)?; - + let (word, position) = StrBEU32Codec::bytes_decode(key_bytes).unwrap(); let key = format!("{} {}", word, position); heap.push(Reverse((value.len(), key, word_position_docids_name))); if heap.len() > limit { heap.pop(); } @@ -217,10 +216,7 @@ fn biggest_value_sizes(index: &Index, rtxn: &heed::RoTxn, limit: usize) -> anyho for result in index.word_attribute_docids.as_polymorph().iter::<_, ByteSlice, ByteSlice>(rtxn)? { let (key_bytes, value) = result?; - let (word, attribute) = key_bytes.split_at(key_bytes.len() - 4); - let word = str::from_utf8(word)?; - let attribute = attribute.try_into().map(u32::from_be_bytes)?; - + let (word, attribute) = StrBEU32Codec::bytes_decode(key_bytes).unwrap(); let key = format!("{} {}", word, attribute); heap.push(Reverse((value.len(), key, word_attribute_docids_name))); if heap.len() > limit { heap.pop(); } @@ -239,7 +235,9 @@ fn biggest_value_sizes(index: &Index, rtxn: &heed::RoTxn, limit: usize) -> anyho } fn word_position_doc_ids(index: &Index, rtxn: &heed::RoTxn, debug: bool, words: Vec) -> anyhow::Result<()> { - use std::convert::TryInto; + use heed::BytesDecode; + use heed::types::ByteSlice; + use milli::heed_codec::{RoaringBitmapCodec, StrBEU32Codec}; let stdout = io::stdout(); let mut wtr = csv::Writer::from_writer(stdout.lock()); @@ -247,14 +245,14 @@ fn word_position_doc_ids(index: &Index, rtxn: &heed::RoTxn, debug: bool, words: let mut non_debug = Vec::new(); for word in words { - for result in index.word_position_docids.prefix_iter(rtxn, word.as_bytes())? { + let db = index.word_position_docids.as_polymorph(); + for result in db.prefix_iter::<_, ByteSlice, RoaringBitmapCodec>(rtxn, word.as_bytes())? { let (bytes, postings) = result?; - let (w, position) = bytes.split_at(bytes.len() - 4); - let position = position.try_into().map(u32::from_be_bytes)?; + let (w, position) = StrBEU32Codec::bytes_decode(bytes).unwrap(); // if the word is not exactly the word we requested then it means // we found a word that *starts with* the requested word and we must stop. - if word.as_bytes() != w { break } + if word != w { break } let postings_string = if debug { format!("{:?}", postings) diff --git a/src/heed_codec/mod.rs b/src/heed_codec/mod.rs index 559f71633..bb75cdc15 100644 --- a/src/heed_codec/mod.rs +++ b/src/heed_codec/mod.rs @@ -1,3 +1,5 @@ mod roaring_bitmap_codec; +mod str_beu32_codec; pub use self::roaring_bitmap_codec::RoaringBitmapCodec; +pub use self::str_beu32_codec::StrBEU32Codec; diff --git a/src/heed_codec/str_beu32_codec.rs b/src/heed_codec/str_beu32_codec.rs new file mode 100644 index 000000000..95836ec4e --- /dev/null +++ b/src/heed_codec/str_beu32_codec.rs @@ -0,0 +1,28 @@ +use std::borrow::Cow; +use std::convert::TryInto; +use std::str; + +pub struct StrBEU32Codec; + +impl<'a> heed::BytesDecode<'a> for StrBEU32Codec { + type DItem = (&'a str, u32); + + fn bytes_decode(bytes: &'a [u8]) -> Option { + let str_len = bytes.len().checked_sub(4)?; + let (str_bytes, n_bytes) = bytes.split_at(str_len); + let s = str::from_utf8(str_bytes).ok()?; + let n = n_bytes.try_into().map(u32::from_be_bytes).ok()?; + Some((s, n)) + } +} + +impl<'a> heed::BytesEncode<'a> for StrBEU32Codec { + type EItem = (&'a str, u32); + + fn bytes_encode((s, n): &Self::EItem) -> Option> { + let mut bytes = Vec::with_capacity(s.len() + 4); + bytes.extend_from_slice(s.as_bytes()); + bytes.extend_from_slice(&n.to_be_bytes()); + Some(Cow::Owned(bytes)) + } +} diff --git a/src/lib.rs b/src/lib.rs index 1e21331dc..c45638c6a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,9 +1,9 @@ mod criterion; -mod heed_codec; mod node; mod query_tokens; mod search; mod transitive_arc; +pub mod heed_codec; pub mod lexer; use std::collections::HashMap; @@ -21,7 +21,7 @@ use oxidized_mtbl as omtbl; pub use self::search::{Search, SearchResult}; pub use self::criterion::{Criterion, default_criteria}; -use self::heed_codec::RoaringBitmapCodec; +use self::heed_codec::{RoaringBitmapCodec, StrBEU32Codec}; use self::transitive_arc::TransitiveArc; pub type FastMap4 = HashMap>; @@ -44,10 +44,10 @@ pub struct Index { pub word_positions: Database, pub prefix_word_positions: Database, /// Maps a word at a position (u32) and all the documents ids where it appears. - pub word_position_docids: Database, - pub prefix_word_position_docids: Database, + pub word_position_docids: Database, + pub prefix_word_position_docids: Database, /// Maps a word and an attribute (u32) to all the documents ids that it appears in. - pub word_attribute_docids: Database, + pub word_attribute_docids: Database, /// The MTBL store that contains the documents content. documents: omtbl::Reader>, } diff --git a/src/search.rs b/src/search.rs index e5b8c5fc7..a9d2610ff 100644 --- a/src/search.rs +++ b/src/search.rs @@ -138,14 +138,10 @@ impl<'a> Search<'a> { let number_of_attributes = index.number_of_attributes(rtxn)?.map_or(0, |n| n as u32); for (i, derived_words) in derived_words.iter().enumerate() { - let mut union_docids = RoaringBitmap::new(); for (word, _distance, _positions) in derived_words { for attr in 0..number_of_attributes { - - let mut key = word.clone().into_bytes(); - key.extend_from_slice(&attr.to_be_bytes()); - if let Some(docids) = index.word_attribute_docids.get(rtxn, &key)? { + if let Some(docids) = index.word_attribute_docids.get(rtxn, &(word, attr))? { union_docids.union_with(&docids); } } @@ -172,9 +168,7 @@ impl<'a> Search<'a> { let mut union_docids = RoaringBitmap::new(); for (word, _distance, positions) in words { if positions.contains(position) { - let mut key = word.clone().into_bytes(); - key.extend_from_slice(&position.to_be_bytes()); - if let Some(docids) = index.word_position_docids.get(rtxn, &key)? { + if let Some(docids) = index.word_position_docids.get(rtxn, &(word, position))? { union_docids.union_with(&docids); } } @@ -192,9 +186,7 @@ impl<'a> Search<'a> { { let mut union_docids = RoaringBitmap::new(); for (word, _distance, _positions) in words { - let mut key = word.clone().into_bytes(); - key.extend_from_slice(&attribute.to_be_bytes()); - if let Some(docids) = index.word_attribute_docids.get(rtxn, &key)? { + if let Some(docids) = index.word_attribute_docids.get(rtxn, &(word, attribute))? { union_docids.union_with(&docids); } }