Make attribute criterion typo/prefix tolerant

This commit is contained in:
many 2021-04-01 14:42:23 +02:00
parent 59f58c15f7
commit 1eee0029a8
No known key found for this signature in database
GPG Key ID: 2CEF23B75189EACA

View File

@ -1,4 +1,4 @@
use std::{cmp::{self, Ordering}, collections::BinaryHeap};
use std::{borrow::Cow, cmp::{self, Ordering}, collections::BinaryHeap};
use std::collections::{BTreeMap, HashMap, btree_map};
use std::mem::take;
@ -7,7 +7,7 @@ use roaring::RoaringBitmap;
use crate::{TreeLevel, search::build_dfa};
use crate::search::criteria::Query;
use crate::search::query_tree::{Operation, QueryKind};
use crate::search::WordDerivationsCache;
use crate::search::{word_derivations, WordDerivationsCache};
use super::{Criterion, CriterionResult, Context, resolve_query_tree};
pub struct Attribute<'t> {
@ -71,7 +71,7 @@ impl<'t> Criterion for Attribute<'t> {
},
}
} else {
set_compute_candidates(self.ctx, flattened_query_tree, candidates)?
set_compute_candidates(self.ctx, flattened_query_tree, candidates, wdcache)?
};
candidates.difference_with(&found_candidates);
@ -122,21 +122,18 @@ struct WordLevelIterator<'t, 'q> {
inner: Box<dyn Iterator<Item =heed::Result<((&'t str, TreeLevel, u32, u32), RoaringBitmap)>> + 't>,
level: TreeLevel,
interval_size: u32,
word: &'q str,
word: Cow<'q, str>,
in_prefix_cache: bool,
inner_next: Option<(u32, u32, RoaringBitmap)>,
current_interval: Option<(u32, u32)>,
}
impl<'t, 'q> WordLevelIterator<'t, 'q> {
fn new(ctx: &'t dyn Context<'t>, query: &'q Query) -> heed::Result<Option<Self>> {
// TODO make it typo/prefix tolerant
let word = query.kind.word();
let in_prefix_cache = query.prefix && ctx.in_prefix_cache(word);
match ctx.word_position_last_level(word, in_prefix_cache)? {
fn new(ctx: &'t dyn Context<'t>, word: Cow<'q, str>, in_prefix_cache: bool) -> heed::Result<Option<Self>> {
match ctx.word_position_last_level(&word, in_prefix_cache)? {
Some(level) => {
let interval_size = 4u32.pow(Into::<u8>::into(level.clone()) as u32);
let inner = ctx.word_position_iterator(word, level, in_prefix_cache, None, None)?;
let inner = ctx.word_position_iterator(&word, level, in_prefix_cache, None, None)?;
Ok(Some(Self { inner, level, interval_size, word, in_prefix_cache, inner_next: None, current_interval: None }))
},
None => Ok(None),
@ -146,11 +143,11 @@ impl<'t, 'q> WordLevelIterator<'t, 'q> {
fn dig(&self, ctx: &'t dyn Context<'t>, level: &TreeLevel) -> heed::Result<Self> {
let level = level.min(&self.level).clone();
let interval_size = 4u32.pow(Into::<u8>::into(level.clone()) as u32);
let word = self.word;
let word = self.word.clone();
let in_prefix_cache = self.in_prefix_cache;
// TODO try to dig starting from the current interval
// let left = self.current_interval.map(|(left, _)| left);
let inner = ctx.word_position_iterator(word, level, in_prefix_cache, None, None)?;
let inner = ctx.word_position_iterator(&word, level, in_prefix_cache, None, None)?;
Ok(Self {inner, level, interval_size, word, in_prefix_cache, inner_next: None, current_interval: None})
}
@ -193,12 +190,34 @@ struct QueryLevelIterator<'t, 'q> {
}
impl<'t, 'q> QueryLevelIterator<'t, 'q> {
fn new(ctx: &'t dyn Context<'t>, queries: &'q Vec<Query>) -> heed::Result<Option<Self>> {
fn new(ctx: &'t dyn Context<'t>, queries: &'q Vec<Query>, wdcache: &mut WordDerivationsCache) -> anyhow::Result<Option<Self>> {
let mut inner = Vec::with_capacity(queries.len());
for query in queries {
if let Some(word_level_iterator) = WordLevelIterator::new(ctx, query)? {
match &query.kind {
QueryKind::Exact { word, .. } => {
if !query.prefix || ctx.in_prefix_cache(&word) {
let word = Cow::Borrowed(query.kind.word());
if let Some(word_level_iterator) = WordLevelIterator::new(ctx, word, query.prefix)? {
inner.push(word_level_iterator);
}
} else {
for (word, _) in word_derivations(&word, true, 0, ctx.words_fst(), wdcache)? {
let word = Cow::Owned(word.to_owned());
if let Some(word_level_iterator) = WordLevelIterator::new(ctx, word, false)? {
inner.push(word_level_iterator);
}
}
}
},
QueryKind::Tolerant { typo, word } => {
for (word, _) in word_derivations(&word, query.prefix, *typo, ctx.words_fst(), wdcache)? {
let word = Cow::Owned(word.to_owned());
if let Some(word_level_iterator) = WordLevelIterator::new(ctx, word, false)? {
inner.push(word_level_iterator);
}
}
}
}
}
let highest = inner.iter().max_by_key(|wli| wli.level).map(|wli| wli.level.clone());
@ -346,13 +365,14 @@ impl<'t, 'q> Eq for Branch<'t, 'q> {}
fn initialize_query_level_iterators<'t, 'q>(
ctx: &'t dyn Context<'t>,
branches: &'q Vec<Vec<Vec<Query>>>,
) -> heed::Result<BinaryHeap<Branch<'t, 'q>>> {
wdcache: &mut WordDerivationsCache,
) -> anyhow::Result<BinaryHeap<Branch<'t, 'q>>> {
let mut positions = BinaryHeap::with_capacity(branches.len());
for branch in branches {
let mut branch_positions = Vec::with_capacity(branch.len());
for query in branch {
match QueryLevelIterator::new(ctx, query)? {
match QueryLevelIterator::new(ctx, query, wdcache)? {
Some(qli) => branch_positions.push(qli),
None => {
// the branch seems to be invalid, so we skip it.
@ -393,9 +413,10 @@ fn set_compute_candidates<'t>(
ctx: &'t dyn Context<'t>,
branches: &Vec<Vec<Vec<Query>>>,
allowed_candidates: &RoaringBitmap,
wdcache: &mut WordDerivationsCache,
) -> anyhow::Result<RoaringBitmap>
{
let mut branches_heap = initialize_query_level_iterators(ctx, branches)?;
let mut branches_heap = initialize_query_level_iterators(ctx, branches, wdcache)?;
let lowest_level = TreeLevel::min_value();
while let Some(mut branch) = branches_heap.peek_mut() {