diff --git a/meilidb-core/src/levenshtein.rs b/meilidb-core/src/levenshtein.rs new file mode 100644 index 000000000..6e781b550 --- /dev/null +++ b/meilidb-core/src/levenshtein.rs @@ -0,0 +1,134 @@ +use std::cmp::min; +use std::collections::BTreeMap; +use std::ops::{Index, IndexMut}; + +// A simple wrapper around vec so we can get contiguous but index it like it's 2D array. +struct N2Array { + y_size: usize, + buf: Vec, +} + +impl N2Array { + fn new(x: usize, y: usize, value: T) -> N2Array { + N2Array { + y_size: y, + buf: vec![value; x * y], + } + } +} + +impl Index<(usize, usize)> for N2Array { + type Output = T; + + #[inline] + fn index(&self, (x, y): (usize, usize)) -> &T { + &self.buf[(x * self.y_size) + y] + } +} + +impl IndexMut<(usize, usize)> for N2Array { + #[inline] + fn index_mut(&mut self, (x, y): (usize, usize)) -> &mut T { + &mut self.buf[(x * self.y_size) + y] + } +} + +pub fn prefix_damerau_levenshtein(source: &[u8], target: &[u8]) -> (u32, usize) { + let (n, m) = (source.len(), target.len()); + + assert!( + n <= m, + "the source string must be shorter than the target one" + ); + + if n == 0 { + return (m as u32, 0); + } + if m == 0 { + return (n as u32, 0); + } + + if n == m && source == target { + return (0, m); + } + + let inf = n + m; + let mut matrix = N2Array::new(n + 2, m + 2, 0); + + matrix[(0, 0)] = inf; + for i in 0..n + 1 { + matrix[(i + 1, 0)] = inf; + matrix[(i + 1, 1)] = i; + } + for j in 0..m + 1 { + matrix[(0, j + 1)] = inf; + matrix[(1, j + 1)] = j; + } + + let mut last_row = BTreeMap::new(); + + for (row, char_s) in source.iter().enumerate() { + let mut last_match_col = 0; + let row = row + 1; + + for (col, char_t) in target.iter().enumerate() { + let col = col + 1; + let last_match_row = *last_row.get(&char_t).unwrap_or(&0); + let cost = if char_s == char_t { 0 } else { 1 }; + + let dist_add = matrix[(row, col + 1)] + 1; + let dist_del = matrix[(row + 1, col)] + 1; + let dist_sub = matrix[(row, col)] + cost; + let dist_trans = matrix[(last_match_row, last_match_col)] + + (row - last_match_row - 1) + + 1 + + (col - last_match_col - 1); + + let dist = min(min(dist_add, dist_del), min(dist_sub, dist_trans)); + + matrix[(row + 1, col + 1)] = dist; + + if cost == 0 { + last_match_col = col; + } + } + + last_row.insert(char_s, row); + } + + let mut minimum = (u32::max_value(), 0); + + for x in n..=m { + let dist = matrix[(n + 1, x + 1)] as u32; + if dist < minimum.0 { + minimum = (dist, x) + } + } + + minimum +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn matched_length() { + let query = "Levenste"; + let text = "Levenshtein"; + + let (dist, length) = prefix_damerau_levenshtein(query.as_bytes(), text.as_bytes()); + assert_eq!(dist, 1); + assert_eq!(&text[..length], "Levenshte"); + } + + #[test] + #[should_panic] + fn matched_length_panic() { + let query = "Levenshtein"; + let text = "Levenste"; + + // this function will panic if source if longer than target + prefix_damerau_levenshtein(query.as_bytes(), text.as_bytes()); + } +} diff --git a/meilidb-core/src/lib.rs b/meilidb-core/src/lib.rs index 6beef461e..9a9c3ed80 100644 --- a/meilidb-core/src/lib.rs +++ b/meilidb-core/src/lib.rs @@ -7,6 +7,7 @@ pub mod criterion; mod database; mod distinct_map; mod error; +mod levenshtein; mod number; mod query_builder; mod ranked_map; diff --git a/meilidb-core/src/query_builder.rs b/meilidb-core/src/query_builder.rs index 2328a4844..236c6b699 100644 --- a/meilidb-core/src/query_builder.rs +++ b/meilidb-core/src/query_builder.rs @@ -11,6 +11,7 @@ use slice_group_by::{GroupBy, GroupByMut}; use crate::automaton::{Automaton, AutomatonGroup, AutomatonProducer, QueryEnhancer}; use crate::distinct_map::{BufferedDistinctMap, DistinctMap}; +use crate::levenshtein::prefix_damerau_levenshtein; use crate::raw_document::{raw_documents_from, RawDocument}; use crate::{criterion::Criteria, Document, DocumentId, Highlight, TmpMatch}; use crate::{reordered_attrs::ReorderedAttrs, store, MResult}; @@ -162,6 +163,7 @@ fn fetch_raw_documents( index, is_exact, query_len, + query, .. } = automaton; let dfa = automaton.dfa(); @@ -176,6 +178,12 @@ fn fetch_raw_documents( let distance = dfa.eval(input).to_u8(); let is_exact = *is_exact && distance == 0 && input.len() == *query_len; + let covered_area = if query.len() > input.len() { + query.len() + } else { + prefix_damerau_levenshtein(query.as_bytes(), input).1 + }; + let doc_indexes = match postings_lists_store.postings_list(reader, input)? { Some(doc_indexes) => doc_indexes, None => continue, @@ -197,7 +205,7 @@ fn fetch_raw_documents( let highlight = Highlight { attribute: di.attribute, char_index: di.char_index, - char_length: u16::try_from(*query_len).unwrap_or(u16::max_value()), + char_length: u16::try_from(covered_area).unwrap_or(u16::max_value()), }; tmp_matches.push((di.document_id, id, match_, highlight));