diff --git a/milli/src/update/index_documents/extract/extract_docid_word_positions.rs b/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
index 2d51fcc1a..5a103f1e0 100644
--- a/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
+++ b/milli/src/update/index_documents/extract/extract_docid_word_positions.rs
@@ -3,12 +3,14 @@ use std::convert::TryInto;
 use std::fs::File;
 use std::{io, mem, str};
 
-use charabia::{Language, Script, SeparatorKind, Token, TokenKind, TokenizerBuilder};
+use charabia::{Language, Script, SeparatorKind, Token, TokenKind, Tokenizer, TokenizerBuilder};
+use obkv::KvReader;
 use roaring::RoaringBitmap;
 use serde_json::Value;
 
 use super::helpers::{concat_u32s_array, create_sorter, sorter_into_reader, GrenadParameters};
 use crate::error::{InternalError, SerializationError};
+use crate::update::index_documents::MergeFn;
 use crate::{
     absolute_from_relative_position, FieldId, Result, MAX_POSITION_PER_ATTRIBUTE, MAX_WORD_LENGTH,
 };
@@ -33,7 +35,7 @@ pub fn extract_docid_word_positions(
     let max_memory = indexer.max_memory_by_thread();
 
     let mut documents_ids = RoaringBitmap::new();
-    let mut script_language_pair = HashMap::new();
+    let mut script_language_docids = HashMap::new();
     let mut docid_word_positions_sorter = create_sorter(
         grenad::SortAlgorithm::Stable,
         concat_u32s_array,
@@ -45,11 +47,11 @@ pub fn extract_docid_word_positions(
 
     let mut key_buffer = Vec::new();
     let mut field_buffer = String::new();
-    let mut builder = TokenizerBuilder::new();
+    let mut tokenizer_builder = TokenizerBuilder::new();
     if let Some(stop_words) = stop_words {
-        builder.stop_words(stop_words);
+        tokenizer_builder.stop_words(stop_words);
     }
-    let tokenizer = builder.build();
+    let tokenizer = tokenizer_builder.build();
 
     let mut cursor = obkv_documents.into_cursor()?;
     while let Some((key, value)) = cursor.move_on_next()? {
@@ -57,49 +59,120 @@ pub fn extract_docid_word_positions(
             .try_into()
             .map(u32::from_be_bytes)
             .map_err(|_| SerializationError::InvalidNumberSerialization)?;
-        let obkv = obkv::KvReader::<FieldId>::new(value);
+        let obkv = KvReader::<FieldId>::new(value);
 
         documents_ids.push(document_id);
         key_buffer.clear();
        key_buffer.extend_from_slice(&document_id.to_be_bytes());
 
-        for (field_id, field_bytes) in obkv.iter() {
-            if searchable_fields.as_ref().map_or(true, |sf| sf.contains(&field_id)) {
-                let value =
-                    serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson)?;
-                field_buffer.clear();
-                if let Some(field) = json_to_string(&value, &mut field_buffer) {
-                    let tokens = process_tokens(tokenizer.tokenize(field))
-                        .take_while(|(p, _)| (*p as u32) < max_positions_per_attributes);
+        let mut script_language_word_count = HashMap::new();
 
-                    for (index, token) in tokens {
-                        if let Some(language) = token.language {
-                            let script = token.script;
-                            let entry = script_language_pair
-                                .entry((script, language))
-                                .or_insert_with(RoaringBitmap::new);
-                            entry.push(document_id);
-                        }
-                        let token = token.lemma().trim();
-                        if !token.is_empty() && token.len() <= MAX_WORD_LENGTH {
-                            key_buffer.truncate(mem::size_of::<u32>());
-                            key_buffer.extend_from_slice(token.as_bytes());
+        extract_tokens_from_document(
+            &obkv,
+            searchable_fields,
+            &tokenizer,
+            max_positions_per_attributes,
+            &mut key_buffer,
+            &mut field_buffer,
+            &mut script_language_word_count,
+            &mut docid_word_positions_sorter,
+        )?;
 
-                            let position: u16 = index
-                                .try_into()
-                                .map_err(|_| SerializationError::InvalidNumberSerialization)?;
-                            let position = absolute_from_relative_position(field_id, position);
-                            docid_word_positions_sorter
-                                .insert(&key_buffer, position.to_ne_bytes())?;
+        // if we detect a potential mistake in the language detection,
+        // we rerun the extraction forcing the tokenizer to detect the most frequently detected Languages.
+        // context: https://github.com/meilisearch/meilisearch/issues/3565
+        if script_language_word_count.values().any(potential_language_detection_error) {
+            // build an allow list with the most frequent detected languages in the document.
+            let script_language: HashMap<_, _> =
+                script_language_word_count.iter().filter_map(most_frequent_languages).collect();
+
+            // if the allow list is empty, meaning that no Language is considered frequent,
+            // then we don't rerun the extraction.
+            if !script_language.is_empty() {
+                // build a new temporary tokenizer including the allow list.
+                let mut tokenizer_builder = TokenizerBuilder::new();
+                if let Some(stop_words) = stop_words {
+                    tokenizer_builder.stop_words(stop_words);
+                }
+                tokenizer_builder.allow_list(&script_language);
+                let tokenizer = tokenizer_builder.build();
+
+                script_language_word_count.clear();
+
+                // rerun the extraction.
+                extract_tokens_from_document(
+                    &obkv,
+                    searchable_fields,
+                    &tokenizer,
+                    max_positions_per_attributes,
+                    &mut key_buffer,
+                    &mut field_buffer,
+                    &mut script_language_word_count,
+                    &mut docid_word_positions_sorter,
+                )?;
+            }
+        }
+
+        for (script, languages_frequency) in script_language_word_count {
+            for (language, _) in languages_frequency {
+                let entry = script_language_docids
+                    .entry((script, language))
+                    .or_insert_with(RoaringBitmap::new);
+                entry.push(document_id);
+            }
+        }
+    }
+
+    sorter_into_reader(docid_word_positions_sorter, indexer)
+        .map(|reader| (documents_ids, reader, script_language_docids))
+}
+
+fn extract_tokens_from_document<T: AsRef<[u8]>>(
+    obkv: &KvReader<FieldId>,
+    searchable_fields: &Option<HashSet<FieldId>>,
+    tokenizer: &Tokenizer<T>,
+    max_positions_per_attributes: u32,
+    key_buffer: &mut Vec<u8>,
+    field_buffer: &mut String,
+    script_language_word_count: &mut HashMap<Script, Vec<(Language, usize)>>,
+    docid_word_positions_sorter: &mut grenad::Sorter<MergeFn>,
+) -> Result<()> {
+    for (field_id, field_bytes) in obkv.iter() {
+        if searchable_fields.as_ref().map_or(true, |sf| sf.contains(&field_id)) {
+            let value = serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson)?;
+            field_buffer.clear();
+            if let Some(field) = json_to_string(&value, field_buffer) {
+                let tokens = process_tokens(tokenizer.tokenize(field))
+                    .take_while(|(p, _)| (*p as u32) < max_positions_per_attributes);
+
+                for (index, token) in tokens {
+                    // if a language has been detected for the token, we update the counter.
+                    if let Some(language) = token.language {
+                        let script = token.script;
+                        let entry =
+                            script_language_word_count.entry(script).or_insert_with(Vec::new);
+                        match entry.iter_mut().find(|(l, _)| *l == language) {
+                            Some((_, n)) => *n += 1,
+                            None => entry.push((language, 1)),
                         }
                     }
+                    let token = token.lemma().trim();
+                    if !token.is_empty() && token.len() <= MAX_WORD_LENGTH {
+                        key_buffer.truncate(mem::size_of::<u32>());
+                        key_buffer.extend_from_slice(token.as_bytes());
+
+                        let position: u16 = index
+                            .try_into()
+                            .map_err(|_| SerializationError::InvalidNumberSerialization)?;
+                        let position = absolute_from_relative_position(field_id, position);
+                        docid_word_positions_sorter.insert(&key_buffer, position.to_ne_bytes())?;
+                    }
                 }
             }
         }
     }
 
-    sorter_into_reader(docid_word_positions_sorter, indexer)
-        .map(|reader| (documents_ids, reader, script_language_pair))
+    Ok(())
 }
 
 /// Transform a JSON value into a string that can be indexed.
@@ -183,3 +256,39 @@ fn process_tokens<'a>(
         })
         .filter(|(_, t)| t.is_word())
 }
+
+fn potential_language_detection_error(languages_frequency: &Vec<(Language, usize)>) -> bool {
+    if languages_frequency.len() > 1 {
+        let threshold = compute_language_frequency_threshold(languages_frequency);
+        languages_frequency.iter().any(|(_, c)| *c <= threshold)
+    } else {
+        false
+    }
+}
+
+fn most_frequent_languages(
+    (script, languages_frequency): (&Script, &Vec<(Language, usize)>),
+) -> Option<(Script, Vec<Language>)> {
+    if languages_frequency.len() > 1 {
+        let threshold = compute_language_frequency_threshold(languages_frequency);
+
+        let languages: Vec<_> = languages_frequency
+            .iter()
+            .filter(|(_, c)| *c > threshold)
+            .map(|(l, _)| l.clone())
+            .collect();
+
+        if languages.is_empty() {
+            None
+        } else {
+            Some((script.clone(), languages))
+        }
+    } else {
+        None
+    }
+}
+
+fn compute_language_frequency_threshold(languages_frequency: &Vec<(Language, usize)>) -> usize {
+    let total: usize = languages_frequency.iter().map(|(_, c)| c).sum();
+    total / 20 // 5% is a completely arbitrary value.
+}
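To make the frequency heuristic above easier to review in isolation, here is a minimal standalone sketch of the 5% threshold logic. It is not part of the patch: the helper names are made up, and plain `&str` values stand in for charabia's `Language` type.

```rust
/// 5% of the total detected words, mirroring `compute_language_frequency_threshold` above.
fn threshold(freqs: &[(&str, usize)]) -> usize {
    let total: usize = freqs.iter().map(|(_, c)| c).sum();
    total / 20
}

/// A script looks suspicious when more than one language was detected and at least
/// one of them sits at or below the threshold.
fn potential_detection_error(freqs: &[(&str, usize)]) -> bool {
    freqs.len() > 1 && freqs.iter().any(|(_, c)| *c <= threshold(freqs))
}

/// Languages kept for the second pass: only those strictly above the threshold.
fn frequent_languages<'a>(freqs: &[(&'a str, usize)]) -> Vec<&'a str> {
    if freqs.len() <= 1 {
        return Vec::new();
    }
    let t = threshold(freqs);
    freqs.iter().filter(|(_, c)| *c > t).map(|(l, _)| *l).collect()
}

fn main() {
    // A mostly-Japanese document where a few tokens were misdetected as Chinese:
    // total = 100 words, threshold = 100 / 20 = 5, and "Cmn" (4 <= 5) trips the check,
    // so only "Jpn" ends up in the allow list used to rerun the extraction.
    let freqs = [("Jpn", 96), ("Cmn", 4)];
    assert!(potential_detection_error(&freqs));
    assert_eq!(frequent_languages(&freqs), vec!["Jpn"]);
}
```

Note that the patch's `most_frequent_languages` additionally returns `None` when every language falls at or below the threshold, so a document with no clear majority keeps the result of the first extraction pass as-is.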
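And for readers unfamiliar with charabia's allow-list mechanism, here is a rough sketch of the two passes the patch performs. It only reuses builder calls already present in the diff (`TokenizerBuilder::new`, `allow_list`, `build`, `tokenize`); the sample text and the `Script::Cj` / `Language::Jpn` variant names are assumptions rather than something taken from the patch.

```rust
use std::collections::HashMap;

use charabia::{Language, Script, TokenizerBuilder};

fn main() {
    let text = "東京タワーの近くにあるカフェ";

    // First pass: no allow list, so charabia guesses the language token by token,
    // which is where the Chinese/Japanese confusion of issue #3565 can creep in.
    let mut builder = TokenizerBuilder::new();
    let tokenizer = builder.build();
    for token in tokenizer.tokenize(text) {
        println!("{:?}\t{:?}\t{:?}", token.lemma(), token.script, token.language);
    }

    // Second pass: the CJ script is pinned to the language that dominated the first
    // pass, which is what the patch does when the per-script counts look suspicious.
    // (Script::Cj and Language::Jpn are assumed variant names.)
    let allow_list: HashMap<Script, Vec<Language>> =
        HashMap::from([(Script::Cj, vec![Language::Jpn])]);
    let mut builder = TokenizerBuilder::new();
    builder.allow_list(&allow_list);
    let tokenizer = builder.build();
    for token in tokenizer.tokenize(text) {
        println!("{:?}\t{:?}\t{:?}", token.lemma(), token.script, token.language);
    }
}
```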