From 6a399556b5581d146bec5a03e0b0cd87b9ab5fc6 Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Wed, 4 Sep 2024 10:20:18 +0200 Subject: [PATCH] Implement more searchable extractor --- milli/src/update/new/extract/mod.rs | 7 +- .../extract/searchable/extract_word_docids.rs | 100 +++++ .../src/update/new/extract/searchable/mod.rs | 156 ++++++++ .../extract/searchable/tokenize_document.rs | 364 ++++++++++++++++++ 4 files changed, 623 insertions(+), 4 deletions(-) create mode 100644 milli/src/update/new/extract/searchable/extract_word_docids.rs create mode 100644 milli/src/update/new/extract/searchable/mod.rs create mode 100644 milli/src/update/new/extract/searchable/tokenize_document.rs diff --git a/milli/src/update/new/extract/mod.rs b/milli/src/update/new/extract/mod.rs index 3124068d9..5e6c02c65 100644 --- a/milli/src/update/new/extract/mod.rs +++ b/milli/src/update/new/extract/mod.rs @@ -1,6 +1,5 @@ mod cache; -mod extract_word_docids; -mod tokenize_document; +mod searchable; -pub use extract_word_docids::SearchableExtractor; -pub use extract_word_docids::WordDocidsExtractor; +pub use searchable::SearchableExtractor; +pub use searchable::WordDocidsExtractor; diff --git a/milli/src/update/new/extract/searchable/extract_word_docids.rs b/milli/src/update/new/extract/searchable/extract_word_docids.rs new file mode 100644 index 000000000..f8b495538 --- /dev/null +++ b/milli/src/update/new/extract/searchable/extract_word_docids.rs @@ -0,0 +1,100 @@ +use std::borrow::Cow; + +use heed::RoTxn; + +use super::SearchableExtractor; +use crate::{bucketed_position, FieldId, Index, Result}; + +pub struct WordDocidsExtractor; +impl SearchableExtractor for WordDocidsExtractor { + fn attributes_to_extract<'a>( + rtxn: &'a RoTxn, + index: &'a Index, + ) -> Result>> { + index.user_defined_searchable_fields(rtxn).map_err(Into::into) + } + + fn attributes_to_skip<'a>(rtxn: &'a RoTxn, index: &'a Index) -> Result> { + // exact attributes must be skipped and stored in a separate DB, see `ExactWordDocidsExtractor`. + index.exact_attributes(rtxn).map_err(Into::into) + } + + fn build_key<'a>(_field_id: FieldId, _position: u16, word: &'a str) -> Cow<'a, [u8]> { + Cow::Borrowed(word.as_bytes()) + } +} + +pub struct ExactWordDocidsExtractor; +impl SearchableExtractor for ExactWordDocidsExtractor { + fn attributes_to_extract<'a>( + rtxn: &'a RoTxn, + index: &'a Index, + ) -> Result>> { + let exact_attributes = index.exact_attributes(rtxn)?; + // If there are no user-defined searchable fields, we return all exact attributes. + // Otherwise, we return the intersection of exact attributes and user-defined searchable fields. + if let Some(searchable_attributes) = index.user_defined_searchable_fields(rtxn)? 
{ + let attributes = exact_attributes + .into_iter() + .filter(|attr| searchable_attributes.contains(attr)) + .collect(); + Ok(Some(attributes)) + } else { + Ok(Some(exact_attributes)) + } + } + + fn attributes_to_skip<'a>(_rtxn: &'a RoTxn, _index: &'a Index) -> Result> { + Ok(vec![]) + } + + fn build_key<'a>(_field_id: FieldId, _position: u16, word: &'a str) -> Cow<'a, [u8]> { + Cow::Borrowed(word.as_bytes()) + } +} + +pub struct WordFidDocidsExtractor; +impl SearchableExtractor for WordFidDocidsExtractor { + fn attributes_to_extract<'a>( + rtxn: &'a RoTxn, + index: &'a Index, + ) -> Result>> { + index.user_defined_searchable_fields(rtxn).map_err(Into::into) + } + + fn attributes_to_skip<'a>(_rtxn: &'a RoTxn, _index: &'a Index) -> Result> { + Ok(vec![]) + } + + fn build_key<'a>(field_id: FieldId, _position: u16, word: &'a str) -> Cow<'a, [u8]> { + let mut key = Vec::new(); + key.extend_from_slice(word.as_bytes()); + key.push(0); + key.extend_from_slice(&field_id.to_be_bytes()); + Cow::Owned(key) + } +} + +pub struct WordPositionDocidsExtractor; +impl SearchableExtractor for WordPositionDocidsExtractor { + fn attributes_to_extract<'a>( + rtxn: &'a RoTxn, + index: &'a Index, + ) -> Result>> { + index.user_defined_searchable_fields(rtxn).map_err(Into::into) + } + + fn attributes_to_skip<'a>(_rtxn: &'a RoTxn, _index: &'a Index) -> Result> { + Ok(vec![]) + } + + fn build_key<'a>(_field_id: FieldId, position: u16, word: &'a str) -> Cow<'a, [u8]> { + // position must be bucketed to reduce the number of keys in the DB. + let position = bucketed_position(position); + let mut key = Vec::new(); + key.extend_from_slice(word.as_bytes()); + key.push(0); + key.extend_from_slice(&position.to_be_bytes()); + Cow::Owned(key) + } +} diff --git a/milli/src/update/new/extract/searchable/mod.rs b/milli/src/update/new/extract/searchable/mod.rs new file mode 100644 index 000000000..106455a7b --- /dev/null +++ b/milli/src/update/new/extract/searchable/mod.rs @@ -0,0 +1,156 @@ +mod extract_word_docids; +mod tokenize_document; + +pub use extract_word_docids::{ + ExactWordDocidsExtractor, WordDocidsExtractor, WordFidDocidsExtractor, + WordPositionDocidsExtractor, +}; +use std::borrow::Cow; +use std::fs::File; + +use grenad::Merger; +use heed::RoTxn; +use rayon::iter::{IntoParallelIterator, ParallelIterator}; + +use super::cache::CachedSorter; +use crate::update::new::{DocumentChange, ItemsPool}; +use crate::update::{create_sorter, GrenadParameters, MergeDeladdCboRoaringBitmaps}; +use crate::{FieldId, GlobalFieldsIdsMap, Index, Result, MAX_POSITION_PER_ATTRIBUTE}; +use tokenize_document::{tokenizer_builder, DocumentTokenizer}; + +pub trait SearchableExtractor { + fn run_extraction( + index: &Index, + fields_ids_map: &GlobalFieldsIdsMap, + indexer: GrenadParameters, + document_changes: impl IntoParallelIterator>, + ) -> Result> { + let max_memory = indexer.max_memory_by_thread(); + + let rtxn = index.read_txn()?; + let stop_words = index.stop_words(&rtxn)?; + let allowed_separators = index.allowed_separators(&rtxn)?; + let allowed_separators: Option> = + allowed_separators.as_ref().map(|s| s.iter().map(String::as_str).collect()); + let dictionary = index.dictionary(&rtxn)?; + let dictionary: Option> = + dictionary.as_ref().map(|s| s.iter().map(String::as_str).collect()); + let builder = tokenizer_builder( + stop_words.as_ref(), + allowed_separators.as_deref(), + dictionary.as_deref(), + ); + let tokenizer = builder.into_tokenizer(); + + let attributes_to_extract = Self::attributes_to_extract(&rtxn, index)?; + let 
attributes_to_skip = Self::attributes_to_skip(&rtxn, index)?; + let localized_attributes_rules = + index.localized_attributes_rules(&rtxn)?.unwrap_or_default(); + + let document_tokenizer = DocumentTokenizer { + tokenizer: &tokenizer, + attribute_to_extract: attributes_to_extract.as_deref(), + attribute_to_skip: attributes_to_skip.as_slice(), + localized_attributes_rules: &localized_attributes_rules, + max_positions_per_attributes: MAX_POSITION_PER_ATTRIBUTE, + }; + + let context_pool = ItemsPool::new(|| { + Ok(( + index.read_txn()?, + &document_tokenizer, + fields_ids_map.clone(), + CachedSorter::new( + // TODO use a better value + 100.try_into().unwrap(), + create_sorter( + grenad::SortAlgorithm::Stable, + MergeDeladdCboRoaringBitmaps, + indexer.chunk_compression_type, + indexer.chunk_compression_level, + indexer.max_nb_chunks, + max_memory, + ), + ), + )) + }); + + document_changes.into_par_iter().try_for_each(|document_change| { + context_pool.with(|(rtxn, document_tokenizer, fields_ids_map, cached_sorter)| { + Self::extract_document_change( + &*rtxn, + index, + document_tokenizer, + fields_ids_map, + cached_sorter, + document_change?, + ) + }) + })?; + + let mut builder = grenad::MergerBuilder::new(MergeDeladdCboRoaringBitmaps); + for (_rtxn, _tokenizer, _fields_ids_map, cache) in context_pool.into_items() { + let sorter = cache.into_sorter()?; + let readers = sorter.into_reader_cursors()?; + builder.extend(readers); + } + + Ok(builder.build()) + } + + fn extract_document_change( + rtxn: &RoTxn, + index: &Index, + document_tokenizer: &DocumentTokenizer, + fields_ids_map: &mut GlobalFieldsIdsMap, + cached_sorter: &mut CachedSorter, + document_change: DocumentChange, + ) -> Result<()> { + match document_change { + DocumentChange::Deletion(inner) => { + let mut token_fn = |fid, pos: u16, word: &str| { + let key = Self::build_key(fid, pos, word); + cached_sorter.insert_del_u32(&key, inner.docid()).unwrap(); + }; + document_tokenizer.tokenize_document( + inner.current(rtxn, index)?.unwrap(), + fields_ids_map, + &mut token_fn, + )?; + } + DocumentChange::Update(inner) => { + let mut token_fn = |fid, pos, word: &str| { + let key = Self::build_key(fid, pos, word); + cached_sorter.insert_del_u32(&key, inner.docid()).unwrap(); + }; + document_tokenizer.tokenize_document( + inner.current(rtxn, index)?.unwrap(), + fields_ids_map, + &mut token_fn, + )?; + + let mut token_fn = |fid, pos, word: &str| { + let key = Self::build_key(fid, pos, word); + cached_sorter.insert_add_u32(&key, inner.docid()).unwrap(); + }; + document_tokenizer.tokenize_document(inner.new(), fields_ids_map, &mut token_fn)?; + } + DocumentChange::Insertion(inner) => { + let mut token_fn = |fid, pos, word: &str| { + let key = Self::build_key(fid, pos, word); + cached_sorter.insert_add_u32(&key, inner.docid()).unwrap(); + }; + document_tokenizer.tokenize_document(inner.new(), fields_ids_map, &mut token_fn)?; + } + } + + Ok(()) + } + + fn attributes_to_extract<'a>(rtxn: &'a RoTxn, index: &'a Index) + -> Result>>; + + fn attributes_to_skip<'a>(rtxn: &'a RoTxn, index: &'a Index) -> Result>; + + fn build_key<'a>(field_id: FieldId, position: u16, word: &'a str) -> Cow<'a, [u8]>; +} diff --git a/milli/src/update/new/extract/searchable/tokenize_document.rs b/milli/src/update/new/extract/searchable/tokenize_document.rs new file mode 100644 index 000000000..e20e52406 --- /dev/null +++ b/milli/src/update/new/extract/searchable/tokenize_document.rs @@ -0,0 +1,364 @@ +use std::collections::HashMap; + +use charabia::{SeparatorKind, Token, 
TokenKind, Tokenizer, TokenizerBuilder}; +use heed::RoTxn; +use serde_json::Value; + +use crate::update::new::KvReaderFieldId; +use crate::{ + FieldId, FieldsIdsMap, GlobalFieldsIdsMap, Index, InternalError, LocalizedAttributesRule, + Result, MAX_POSITION_PER_ATTRIBUTE, MAX_WORD_LENGTH, +}; + +pub struct DocumentTokenizer<'a> { + pub tokenizer: &'a Tokenizer<'a>, + pub attribute_to_extract: Option<&'a [&'a str]>, + pub attribute_to_skip: &'a [&'a str], + pub localized_attributes_rules: &'a [LocalizedAttributesRule], + pub max_positions_per_attributes: u32, +} + +impl<'a> DocumentTokenizer<'a> { + pub fn tokenize_document( + &self, + obkv: &KvReaderFieldId, + field_id_map: &mut GlobalFieldsIdsMap, + token_fn: &mut impl FnMut(FieldId, u16, &str), + ) -> Result<()> { + let mut field_position = HashMap::new(); + let mut field_name = String::new(); + for (field_id, field_bytes) in obkv { + let Some(field_name) = field_id_map.name(field_id).map(|s| { + field_name.clear(); + field_name.push_str(s); + &field_name + }) else { + unreachable!("field id not found in field id map"); + }; + + let mut tokenize_field = |name: &str, value: &Value| { + let Some(field_id) = field_id_map.id_or_insert(name) else { + /// TODO: better error + panic!("it's over 9000"); + }; + + let position = + field_position.entry(field_id).and_modify(|counter| *counter += 8).or_insert(0); + if *position as u32 >= self.max_positions_per_attributes { + return; + } + + match value { + Value::Number(n) => { + let token = n.to_string(); + if let Ok(position) = (*position).try_into() { + token_fn(field_id, position, token.as_str()); + } + } + Value::String(text) => { + // create an iterator of token with their positions. + let locales = self + .localized_attributes_rules + .iter() + .find(|rule| rule.match_str(field_name)) + .map(|rule| rule.locales()); + let tokens = process_tokens( + *position, + self.tokenizer.tokenize_with_allow_list(text.as_str(), locales), + ) + .take_while(|(p, _)| (*p as u32) < self.max_positions_per_attributes); + + for (index, token) in tokens { + // keep a word only if it is not empty and fit in a LMDB key. + let token = token.lemma().trim(); + if !token.is_empty() && token.len() <= MAX_WORD_LENGTH { + *position = index; + if let Ok(position) = (*position).try_into() { + token_fn(field_id, position, token); + } + } + } + } + _ => (), + } + }; + + // if the current field is searchable or contains a searchable attribute + if perm_json_p::select_field( + &field_name, + self.attribute_to_extract.as_deref(), + self.attribute_to_skip, + ) { + // parse json. + match serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson)? { + Value::Object(object) => perm_json_p::seek_leaf_values_in_object( + &object, + self.attribute_to_extract.as_deref(), + self.attribute_to_skip, + &field_name, + &mut tokenize_field, + ), + Value::Array(array) => perm_json_p::seek_leaf_values_in_array( + &array, + self.attribute_to_extract.as_deref(), + self.attribute_to_skip, + &field_name, + &mut tokenize_field, + ), + value => tokenize_field(&field_name, &value), + } + } + } + Ok(()) + } +} + +/// take an iterator on tokens and compute their relative position depending on separator kinds +/// if it's an `Hard` separator we add an additional relative proximity of 8 between words, +/// else we keep the standard proximity of 1 between words. 
+fn process_tokens<'a>( + start_offset: usize, + tokens: impl Iterator>, +) -> impl Iterator)> { + tokens + .skip_while(|token| token.is_separator()) + .scan((start_offset, None), |(offset, prev_kind), mut token| { + match token.kind { + TokenKind::Word | TokenKind::StopWord if !token.lemma().is_empty() => { + *offset += match *prev_kind { + Some(TokenKind::Separator(SeparatorKind::Hard)) => 8, + Some(_) => 1, + None => 0, + }; + *prev_kind = Some(token.kind) + } + TokenKind::Separator(SeparatorKind::Hard) => { + *prev_kind = Some(token.kind); + } + TokenKind::Separator(SeparatorKind::Soft) + if *prev_kind != Some(TokenKind::Separator(SeparatorKind::Hard)) => + { + *prev_kind = Some(token.kind); + } + _ => token.kind = TokenKind::Unknown, + } + Some((*offset, token)) + }) + .filter(|(_, t)| t.is_word()) +} + +/// Factorize tokenizer building. +pub fn tokenizer_builder<'a>( + stop_words: Option<&'a fst::Set<&'a [u8]>>, + allowed_separators: Option<&'a [&str]>, + dictionary: Option<&'a [&str]>, +) -> TokenizerBuilder<'a, &'a [u8]> { + let mut tokenizer_builder = TokenizerBuilder::new(); + if let Some(stop_words) = stop_words { + tokenizer_builder.stop_words(stop_words); + } + if let Some(dictionary) = dictionary { + tokenizer_builder.words_dict(dictionary); + } + if let Some(separators) = allowed_separators { + tokenizer_builder.separators(separators); + } + + tokenizer_builder +} + +/// TODO move in permissive json pointer +mod perm_json_p { + use serde_json::{Map, Value}; + const SPLIT_SYMBOL: char = '.'; + + /// Returns `true` if the `selector` match the `key`. + /// + /// ```text + /// Example: + /// `animaux` match `animaux` + /// `animaux.chien` match `animaux` + /// `animaux.chien` match `animaux` + /// `animaux.chien.nom` match `animaux` + /// `animaux.chien.nom` match `animaux.chien` + /// ----------------------------------------- + /// `animaux` doesn't match `animaux.chien` + /// `animaux.` doesn't match `animaux` + /// `animaux.ch` doesn't match `animaux.chien` + /// `animau` doesn't match `animaux` + /// ``` + pub fn contained_in(selector: &str, key: &str) -> bool { + selector.starts_with(key) + && selector[key.len()..].chars().next().map(|c| c == SPLIT_SYMBOL).unwrap_or(true) + } + + pub fn seek_leaf_values_in_object( + value: &Map, + selectors: Option<&[&str]>, + skip_selectors: &[&str], + base_key: &str, + seeker: &mut impl FnMut(&str, &Value), + ) { + for (key, value) in value.iter() { + let base_key = if base_key.is_empty() { + key.to_string() + } else { + format!("{}{}{}", base_key, SPLIT_SYMBOL, key) + }; + + // here if the user only specified `doggo` we need to iterate in all the fields of `doggo` + // so we check the contained_in on both side + let should_continue = select_field(&base_key, selectors, skip_selectors); + if should_continue { + match value { + Value::Object(object) => seek_leaf_values_in_object( + object, + selectors, + skip_selectors, + &base_key, + seeker, + ), + Value::Array(array) => seek_leaf_values_in_array( + array, + selectors, + skip_selectors, + &base_key, + seeker, + ), + value => seeker(&base_key, value), + } + } + } + } + + pub fn seek_leaf_values_in_array( + values: &[Value], + selectors: Option<&[&str]>, + skip_selectors: &[&str], + base_key: &str, + seeker: &mut impl FnMut(&str, &Value), + ) { + for value in values { + match value { + Value::Object(object) => { + seek_leaf_values_in_object(object, selectors, skip_selectors, base_key, seeker) + } + Value::Array(array) => { + seek_leaf_values_in_array(array, selectors, skip_selectors, 
base_key, seeker) + } + value => seeker(base_key, value), + } + } + } + + pub fn select_field( + field_name: &str, + selectors: Option<&[&str]>, + skip_selectors: &[&str], + ) -> bool { + selectors.map_or(true, |selectors| { + selectors.iter().any(|selector| { + contained_in(selector, &field_name) || contained_in(&field_name, selector) + }) + }) && !skip_selectors.iter().any(|skip_selector| { + contained_in(skip_selector, &field_name) || contained_in(&field_name, skip_selector) + }) + } +} + +#[cfg(test)] +mod test { + use charabia::TokenizerBuilder; + use meili_snap::snapshot; + use obkv::KvReader; + use serde_json::json; + + use super::*; + #[test] + fn test_tokenize_document() { + let mut fields_ids_map = FieldsIdsMap::new(); + + let field_1 = json!({ + "name": "doggo", + "age": 10, + }); + + let field_2 = json!({ + "catto": { + "name": "pesti", + "age": 23, + } + }); + + let field_3 = json!(["doggo", "catto"]); + let field_4 = json!("UNSEARCHABLE"); + let field_5 = json!({"nope": "unsearchable"}); + + let mut obkv = obkv::KvWriter::memory(); + let field_1_id = fields_ids_map.insert("doggo").unwrap(); + let field_1 = serde_json::to_string(&field_1).unwrap(); + obkv.insert(field_1_id, field_1.as_bytes()).unwrap(); + let field_2_id = fields_ids_map.insert("catto").unwrap(); + let field_2 = serde_json::to_string(&field_2).unwrap(); + obkv.insert(field_2_id, field_2.as_bytes()).unwrap(); + let field_3_id = fields_ids_map.insert("doggo.name").unwrap(); + let field_3 = serde_json::to_string(&field_3).unwrap(); + obkv.insert(field_3_id, field_3.as_bytes()).unwrap(); + let field_4_id = fields_ids_map.insert("not-me").unwrap(); + let field_4 = serde_json::to_string(&field_4).unwrap(); + obkv.insert(field_4_id, field_4.as_bytes()).unwrap(); + let field_5_id = fields_ids_map.insert("me-nether").unwrap(); + let field_5 = serde_json::to_string(&field_5).unwrap(); + obkv.insert(field_5_id, field_5.as_bytes()).unwrap(); + let value = obkv.into_inner().unwrap(); + let obkv = KvReader::from_slice(value.as_slice()); + + let mut tb = TokenizerBuilder::default(); + let document_tokenizer = DocumentTokenizer { + tokenizer: &tb.build(), + attribute_to_extract: None, + attribute_to_skip: &["not-me", "me-nether.nope"], + localized_attributes_rules: &[], + max_positions_per_attributes: 1000, + }; + + let fields_ids_map_lock = std::sync::RwLock::new(fields_ids_map); + let mut global_fields_ids_map = GlobalFieldsIdsMap::new(&fields_ids_map_lock); + + let mut words = std::collections::BTreeMap::new(); + document_tokenizer + .tokenize_document(obkv, &mut global_fields_ids_map, &mut |fid, pos, word| { + words.insert([fid, pos], word.to_string()); + }) + .unwrap(); + + snapshot!(format!("{:#?}", words), @r###" + { + [ + 2, + 0, + ]: "doggo", + [ + 2, + 8, + ]: "doggo", + [ + 2, + 16, + ]: "catto", + [ + 3, + 0, + ]: "10", + [ + 4, + 0, + ]: "pesti", + [ + 5, + 0, + ]: "23", + } + "###); + } +}
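
Note on key layouts: the four extractors introduced above differ only in the key returned by `build_key`. Below is a minimal standalone sketch of those layouts; the `demo_*` helpers are hypothetical (not part of the patch), `FieldId` is assumed to be a `u16`, and the position passed in is assumed to already be bucketed by the crate's `bucketed_position`.

use std::borrow::Cow;

type FieldId = u16;

// `WordDocidsExtractor` / `ExactWordDocidsExtractor`: the key is the word itself.
fn demo_word_key(word: &str) -> Cow<'_, [u8]> {
    Cow::Borrowed(word.as_bytes())
}

// `WordFidDocidsExtractor`: the word, a 0 byte, then the field id in big endian.
fn demo_word_fid_key(field_id: FieldId, word: &str) -> Cow<'_, [u8]> {
    let mut key = Vec::with_capacity(word.len() + 1 + 2);
    key.extend_from_slice(word.as_bytes());
    key.push(0);
    key.extend_from_slice(&field_id.to_be_bytes());
    Cow::Owned(key)
}

// `WordPositionDocidsExtractor`: the word, a 0 byte, then the *bucketed*
// position in big endian (the bucketing itself is done by `bucketed_position`).
fn demo_word_position_key(bucketed_position: u16, word: &str) -> Cow<'_, [u8]> {
    let mut key = Vec::with_capacity(word.len() + 1 + 2);
    key.extend_from_slice(word.as_bytes());
    key.push(0);
    key.extend_from_slice(&bucketed_position.to_be_bytes());
    Cow::Owned(key)
}

fn main() {
    assert_eq!(demo_word_key("doggo").as_ref(), &b"doggo"[..]);
    assert_eq!(demo_word_fid_key(2, "doggo").as_ref(), &b"doggo\0\0\x02"[..]);
    assert_eq!(demo_word_position_key(8, "doggo").as_ref(), &b"doggo\0\0\x08"[..]);
}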
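
Note on the del/add flow: in `extract_document_change`, a deletion tokenizes the current version of the document into `insert_del_u32` calls, an insertion tokenizes the new version into `insert_add_u32` calls, and an update does both so the merger can later build del/add bitmaps per key. A toy illustration of that bookkeeping, with a hypothetical `ToyCache` standing in for `CachedSorter`:

use std::collections::{BTreeMap, BTreeSet};

#[derive(Default)]
struct ToyCache {
    del: BTreeMap<String, BTreeSet<u32>>,
    add: BTreeMap<String, BTreeSet<u32>>,
}

impl ToyCache {
    fn insert_del_u32(&mut self, key: &str, docid: u32) {
        self.del.entry(key.to_owned()).or_default().insert(docid);
    }
    fn insert_add_u32(&mut self, key: &str, docid: u32) {
        self.add.entry(key.to_owned()).or_default().insert(docid);
    }
}

fn main() {
    let mut cache = ToyCache::default();
    let docid = 0;

    // Update: the current version of the document is tokenized into deletions...
    for word in ["kefir", "le", "dog"] {
        cache.insert_del_u32(word, docid);
    }
    // ...and the new version into additions, under the same docid.
    for word in ["kefir", "the", "dog"] {
        cache.insert_add_u32(word, docid);
    }

    assert!(cache.del.contains_key("le"));
    assert!(cache.add.contains_key("the"));
    assert!(cache.del.contains_key("kefir") && cache.add.contains_key("kefir"));
}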
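
Note on positions: the doc comment on `process_tokens` states the rule, consecutive words are 1 position apart and a hard separator adds 8. A hedged sketch of that arithmetic on a toy token type (it does not use charabia; `ToyToken` and `toy_positions` are hypothetical names, and leading separators are simply ignored, as the real iterator skips them):

enum ToyToken<'a> {
    Word(&'a str),
    SoftSeparator,
    HardSeparator,
}

fn toy_positions<'a>(start: usize, tokens: &[ToyToken<'a>]) -> Vec<(usize, &'a str)> {
    let mut out = Vec::new();
    let mut offset = start;
    let mut prev_hard = false;
    let mut first = true;
    for token in tokens {
        match token {
            ToyToken::Word(word) => {
                if !first {
                    // A hard separator adds 8 to the relative position, anything else adds 1.
                    offset += if prev_hard { 8 } else { 1 };
                }
                first = false;
                prev_hard = false;
                out.push((offset, *word));
            }
            ToyToken::HardSeparator => prev_hard = true,
            ToyToken::SoftSeparator => {}
        }
    }
    out
}

fn main() {
    use ToyToken::*;
    let tokens = [Word("doggo"), SoftSeparator, Word("catto"), HardSeparator, Word("pesti")];
    // "doggo" at 0, "catto" at 1 (soft separator), "pesti" at 9 (hard separator adds 8).
    assert_eq!(toy_positions(0, &tokens), vec![(0, "doggo"), (1, "catto"), (9, "pesti")]);
}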
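
Note on field selection: `perm_json_p::select_field` keeps a field when its dotted name is under one of the selectors, or is a prefix of one (so intermediate objects are still traversed), and is not under a skip selector. The sketch below copies `contained_in` and a simplified `select_field` from the patch so the doc-comment examples can be checked standalone:

const SPLIT_SYMBOL: char = '.';

fn contained_in(selector: &str, key: &str) -> bool {
    selector.starts_with(key)
        && selector[key.len()..].chars().next().map(|c| c == SPLIT_SYMBOL).unwrap_or(true)
}

fn select_field(field_name: &str, selectors: Option<&[&str]>, skip_selectors: &[&str]) -> bool {
    selectors.map_or(true, |selectors| {
        selectors.iter().any(|selector| {
            contained_in(selector, field_name) || contained_in(field_name, selector)
        })
    }) && !skip_selectors.iter().any(|skip| {
        contained_in(skip, field_name) || contained_in(field_name, skip)
    })
}

fn main() {
    // Matches from the doc comment of `contained_in`.
    assert!(contained_in("animaux.chien.nom", "animaux"));
    assert!(!contained_in("animaux", "animaux.chien"));
    assert!(!contained_in("animau", "animaux"));

    // A nested field under a selector is kept, and so is its parent object.
    let only_doggo: &[&str] = &["doggo"];
    assert!(select_field("doggo.name", Some(only_doggo), &[]));
    let only_name: &[&str] = &["doggo.name"];
    assert!(select_field("doggo", Some(only_name), &[]));
    // A field under a skip selector is always dropped.
    assert!(!select_field("doggo.age", None, &["doggo.age"]));
}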