Rerun extraction when language detection might have failed

This commit is contained in:
ManyTheFish 2023-03-07 18:35:26 +01:00
parent 370d88f626
commit da48506f15

View File

@ -3,12 +3,14 @@ use std::convert::TryInto;
use std::fs::File;
use std::{io, mem, str};
use charabia::{Language, Script, SeparatorKind, Token, TokenKind, TokenizerBuilder};
use charabia::{Language, Script, SeparatorKind, Token, TokenKind, Tokenizer, TokenizerBuilder};
use obkv::KvReader;
use roaring::RoaringBitmap;
use serde_json::Value;
use super::helpers::{concat_u32s_array, create_sorter, sorter_into_reader, GrenadParameters};
use crate::error::{InternalError, SerializationError};
use crate::update::index_documents::MergeFn;
use crate::{
absolute_from_relative_position, FieldId, Result, MAX_POSITION_PER_ATTRIBUTE, MAX_WORD_LENGTH,
};
@ -33,7 +35,7 @@ pub fn extract_docid_word_positions<R: io::Read + io::Seek>(
let max_memory = indexer.max_memory_by_thread();
let mut documents_ids = RoaringBitmap::new();
let mut script_language_pair = HashMap::new();
let mut script_language_docids = HashMap::new();
let mut docid_word_positions_sorter = create_sorter(
grenad::SortAlgorithm::Stable,
concat_u32s_array,
@ -45,11 +47,11 @@ pub fn extract_docid_word_positions<R: io::Read + io::Seek>(
let mut key_buffer = Vec::new();
let mut field_buffer = String::new();
let mut builder = TokenizerBuilder::new();
let mut tokenizer_builder = TokenizerBuilder::new();
if let Some(stop_words) = stop_words {
builder.stop_words(stop_words);
tokenizer_builder.stop_words(stop_words);
}
let tokenizer = builder.build();
let tokenizer = tokenizer_builder.build();
let mut cursor = obkv_documents.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
@ -57,49 +59,120 @@ pub fn extract_docid_word_positions<R: io::Read + io::Seek>(
.try_into()
.map(u32::from_be_bytes)
.map_err(|_| SerializationError::InvalidNumberSerialization)?;
let obkv = obkv::KvReader::<FieldId>::new(value);
let obkv = KvReader::<FieldId>::new(value);
documents_ids.push(document_id);
key_buffer.clear();
key_buffer.extend_from_slice(&document_id.to_be_bytes());
for (field_id, field_bytes) in obkv.iter() {
if searchable_fields.as_ref().map_or(true, |sf| sf.contains(&field_id)) {
let value =
serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson)?;
field_buffer.clear();
if let Some(field) = json_to_string(&value, &mut field_buffer) {
let tokens = process_tokens(tokenizer.tokenize(field))
.take_while(|(p, _)| (*p as u32) < max_positions_per_attributes);
let mut script_language_word_count = HashMap::new();
for (index, token) in tokens {
if let Some(language) = token.language {
let script = token.script;
let entry = script_language_pair
.entry((script, language))
.or_insert_with(RoaringBitmap::new);
entry.push(document_id);
}
let token = token.lemma().trim();
if !token.is_empty() && token.len() <= MAX_WORD_LENGTH {
key_buffer.truncate(mem::size_of::<u32>());
key_buffer.extend_from_slice(token.as_bytes());
extract_tokens_from_document(
&obkv,
searchable_fields,
&tokenizer,
max_positions_per_attributes,
&mut key_buffer,
&mut field_buffer,
&mut script_language_word_count,
&mut docid_word_positions_sorter,
)?;
let position: u16 = index
.try_into()
.map_err(|_| SerializationError::InvalidNumberSerialization)?;
let position = absolute_from_relative_position(field_id, position);
docid_word_positions_sorter
.insert(&key_buffer, position.to_ne_bytes())?;
// if we detect a potetial mistake in the language detection,
// we rerun the extraction forcing the tokenizer to detect the most frequently detected Languages.
// context: https://github.com/meilisearch/meilisearch/issues/3565
if script_language_word_count.values().any(potential_language_detection_error) {
// build an allow list with the most frequent detected languages in the document.
let script_language: HashMap<_, _> =
script_language_word_count.iter().filter_map(most_frequent_languages).collect();
// if the allow list is empty, meaning that no Language is considered frequent,
// then we don't rerun the extraction.
if !script_language.is_empty() {
// build a new temporar tokenizer including the allow list.
let mut tokenizer_builder = TokenizerBuilder::new();
if let Some(stop_words) = stop_words {
tokenizer_builder.stop_words(stop_words);
}
tokenizer_builder.allow_list(&script_language);
let tokenizer = tokenizer_builder.build();
script_language_word_count.clear();
// rerun the extraction.
extract_tokens_from_document(
&obkv,
searchable_fields,
&tokenizer,
max_positions_per_attributes,
&mut key_buffer,
&mut field_buffer,
&mut script_language_word_count,
&mut docid_word_positions_sorter,
)?;
}
}
for (script, languages_frequency) in script_language_word_count {
for (language, _) in languages_frequency {
let entry = script_language_docids
.entry((script, language))
.or_insert_with(RoaringBitmap::new);
entry.push(document_id);
}
}
}
sorter_into_reader(docid_word_positions_sorter, indexer)
.map(|reader| (documents_ids, reader, script_language_docids))
}
fn extract_tokens_from_document<T: AsRef<[u8]>>(
obkv: &KvReader<FieldId>,
searchable_fields: &Option<HashSet<FieldId>>,
tokenizer: &Tokenizer<T>,
max_positions_per_attributes: u32,
key_buffer: &mut Vec<u8>,
field_buffer: &mut String,
script_language_word_count: &mut HashMap<Script, Vec<(Language, usize)>>,
docid_word_positions_sorter: &mut grenad::Sorter<MergeFn>,
) -> Result<()> {
for (field_id, field_bytes) in obkv.iter() {
if searchable_fields.as_ref().map_or(true, |sf| sf.contains(&field_id)) {
let value = serde_json::from_slice(field_bytes).map_err(InternalError::SerdeJson)?;
field_buffer.clear();
if let Some(field) = json_to_string(&value, field_buffer) {
let tokens = process_tokens(tokenizer.tokenize(field))
.take_while(|(p, _)| (*p as u32) < max_positions_per_attributes);
for (index, token) in tokens {
// if a language has been detected for the token, we update the counter.
if let Some(language) = token.language {
let script = token.script;
let entry =
script_language_word_count.entry(script).or_insert_with(Vec::new);
match entry.iter_mut().find(|(l, _)| *l == language) {
Some((_, n)) => *n += 1,
None => entry.push((language, 1)),
}
}
let token = token.lemma().trim();
if !token.is_empty() && token.len() <= MAX_WORD_LENGTH {
key_buffer.truncate(mem::size_of::<u32>());
key_buffer.extend_from_slice(token.as_bytes());
let position: u16 = index
.try_into()
.map_err(|_| SerializationError::InvalidNumberSerialization)?;
let position = absolute_from_relative_position(field_id, position);
docid_word_positions_sorter.insert(&key_buffer, position.to_ne_bytes())?;
}
}
}
}
}
sorter_into_reader(docid_word_positions_sorter, indexer)
.map(|reader| (documents_ids, reader, script_language_pair))
Ok(())
}
/// Transform a JSON value into a string that can be indexed.
@ -183,3 +256,39 @@ fn process_tokens<'a>(
})
.filter(|(_, t)| t.is_word())
}
fn potential_language_detection_error(languages_frequency: &Vec<(Language, usize)>) -> bool {
if languages_frequency.len() > 1 {
let threshold = compute_laguage_frequency_threshold(languages_frequency);
languages_frequency.iter().any(|(_, c)| *c <= threshold)
} else {
false
}
}
fn most_frequent_languages(
(script, languages_frequency): (&Script, &Vec<(Language, usize)>),
) -> Option<(Script, Vec<Language>)> {
if languages_frequency.len() > 1 {
let threshold = compute_laguage_frequency_threshold(languages_frequency);
let languages: Vec<_> = languages_frequency
.iter()
.filter(|(_, c)| *c > threshold)
.map(|(l, _)| l.clone())
.collect();
if languages.is_empty() {
None
} else {
Some((script.clone(), languages))
}
} else {
None
}
}
fn compute_laguage_frequency_threshold(languages_frequency: &Vec<(Language, usize)>) -> usize {
let total: usize = languages_frequency.iter().map(|(_, c)| c).sum();
total / 20 // 5% is a completely arbitrar value.
}