diff --git a/src/best_proximity.rs b/src/best_proximity.rs index 6f822ee6d..1eb60f167 100644 --- a/src/best_proximity.rs +++ b/src/best_proximity.rs @@ -1,8 +1,9 @@ use std::cmp; use std::time::Instant; -use log::debug; use crate::iter_shortest_paths::astar_bag; +use log::debug; +use roaring::RoaringBitmap; const ONE_ATTRIBUTE: u32 = 1000; const MAX_DISTANCE: u32 = 8; @@ -47,21 +48,21 @@ impl Node { // TODO we must skip the successors that have already been seen // TODO we must skip the successors that doesn't return any documents // this way we are able to skip entire paths - fn successors(&self, positions: &[Vec], best_proximity: u32) -> Vec<(Node, u32)> { + fn successors(&self, positions: &[RoaringBitmap], best_proximity: u32) -> Vec<(Node, u32)> { match self { Node::Uninit => { - positions[0].iter().map(|p| { - (Node::Init { layer: 0, position: *p, acc_proximity: 0, parent_position: 0 }, 0) + positions[0].iter().map(|position| { + (Node::Init { layer: 0, position, acc_proximity: 0, parent_position: 0 }, 0) }).collect() }, // We reached the highest layer n @ Node::Init { .. } if n.is_complete(positions) => vec![], Node::Init { layer, position, acc_proximity, .. } => { positions[layer + 1].iter().filter_map(|p| { - let proximity = positions_proximity(*position, *p); + let proximity = positions_proximity(*position, p); let node = Node::Init { layer: layer + 1, - position: *p, + position: p, acc_proximity: acc_proximity + proximity, parent_position: *position, }; @@ -76,7 +77,7 @@ impl Node { } } - fn is_complete(&self, positions: &[Vec]) -> bool { + fn is_complete(&self, positions: &[RoaringBitmap]) -> bool { match self { Node::Uninit => false, Node::Init { layer, .. } => *layer == positions.len() - 1, @@ -121,19 +122,19 @@ impl Node { } pub struct BestProximity { - positions: Vec>, + positions: Vec, best_proximity: u32, } impl BestProximity { - pub fn new(positions: Vec>) -> BestProximity { + pub fn new(positions: Vec) -> BestProximity { let best_proximity = (positions.len() as u32).saturating_sub(1); BestProximity { positions, best_proximity } } } impl BestProximity { - pub fn next(&mut self, mut contains_documents: F) -> Option<(u32, Vec>)> + pub fn next(&mut self, mut contains_documents: F) -> Option<(u32, Vec)> where F: FnMut((usize, u32), (usize, u32)) -> bool, { let before = Instant::now(); @@ -176,6 +177,7 @@ impl BestProximity { #[cfg(test)] mod tests { use super::*; + use std::iter::FromIterator; fn sort(mut val: (u32, Vec)) -> (u32, Vec) { val.1.sort_unstable(); @@ -185,37 +187,37 @@ mod tests { #[test] fn same_attribute() { let positions = vec![ - vec![0, 2, 3, 4 ], - vec![ 1, ], - vec![ 3, 6], + RoaringBitmap::from_iter(vec![0, 2, 3, 4 ]), + RoaringBitmap::from_iter(vec![ 1, ]), + RoaringBitmap::from_iter(vec![ 3, 6]), ]; let mut iter = BestProximity::new(positions); let f = |_, _| true; - assert_eq!(iter.next(f), Some((1+2, vec![vec![0, 1, 3]]))); // 3 - assert_eq!(iter.next(f), Some((2+2, vec![vec![2, 1, 3]]))); // 4 - assert_eq!(iter.next(f), Some((3+2, vec![vec![3, 1, 3]]))); // 5 - assert_eq!(iter.next(f).map(sort), Some((1+5, vec![vec![0, 1, 6], vec![4, 1, 3]]))); // 6 - assert_eq!(iter.next(f), Some((2+5, vec![vec![2, 1, 6]]))); // 7 - assert_eq!(iter.next(f), Some((3+5, vec![vec![3, 1, 6]]))); // 8 - assert_eq!(iter.next(f), Some((4+5, vec![vec![4, 1, 6]]))); // 9 + assert_eq!(iter.next(f), Some((1+2, vec![RoaringBitmap::from_iter(vec![0, 1, 3])]))); // 3 + assert_eq!(iter.next(f), Some((2+2, vec![RoaringBitmap::from_iter(vec![2, 1, 3])]))); // 4 + assert_eq!(iter.next(f), Some((3+2, vec![RoaringBitmap::from_iter(vec![3, 1, 3])]))); // 5 + assert_eq!(iter.next(f), Some((1+5, vec![RoaringBitmap::from_iter(vec![0, 1, 6]), RoaringBitmap::from_iter(vec![4, 1, 3])]))); // 6 + assert_eq!(iter.next(f), Some((2+5, vec![RoaringBitmap::from_iter(vec![2, 1, 6])]))); // 7 + assert_eq!(iter.next(f), Some((3+5, vec![RoaringBitmap::from_iter(vec![3, 1, 6])]))); // 8 + assert_eq!(iter.next(f), Some((4+5, vec![RoaringBitmap::from_iter(vec![4, 1, 6])]))); // 9 assert_eq!(iter.next(f), None); } #[test] fn different_attributes() { let positions = vec![ - vec![0, 2, 1000, 1001, 2000 ], - vec![ 1, 1000, 2001 ], - vec![ 3, 6, 2002, 3000], + RoaringBitmap::from_iter(vec![0, 2, 1000, 1001, 2000 ]), + RoaringBitmap::from_iter(vec![ 1, 1000, 2001 ]), + RoaringBitmap::from_iter(vec![ 3, 6, 2002, 3000]), ]; let mut iter = BestProximity::new(positions); let f = |_, _| true; - assert_eq!(iter.next(f), Some((1+1, vec![vec![2000, 2001, 2002]]))); // 2 - assert_eq!(iter.next(f), Some((1+2, vec![vec![0, 1, 3]]))); // 3 - assert_eq!(iter.next(f), Some((2+2, vec![vec![2, 1, 3]]))); // 4 - assert_eq!(iter.next(f), Some((1+5, vec![vec![0, 1, 6]]))); // 6 + assert_eq!(iter.next(f), Some((1+1, vec![RoaringBitmap::from_iter(vec![2000, 2001, 2002])]))); // 2 + assert_eq!(iter.next(f), Some((1+2, vec![RoaringBitmap::from_iter(vec![0, 1, 3])]))); // 3 + assert_eq!(iter.next(f), Some((2+2, vec![RoaringBitmap::from_iter(vec![2, 1, 3])]))); // 4 + assert_eq!(iter.next(f), Some((1+5, vec![RoaringBitmap::from_iter(vec![0, 1, 6])]))); // 6 // We ignore others here... } diff --git a/src/bin/search.rs b/src/bin/search.rs index bd1adfd65..2618c376f 100644 --- a/src/bin/search.rs +++ b/src/bin/search.rs @@ -62,12 +62,13 @@ fn main() -> anyhow::Result<()> { let before = Instant::now(); let query = result?; - let (_, documents_ids) = index.search(&rtxn, &query)?; + let result = index.search(&rtxn).query(query).execute().unwrap(); + let headers = match index.headers(&rtxn)? { Some(headers) => headers, None => return Ok(()), }; - let documents = index.documents(documents_ids.iter().cloned())?; + let documents = index.documents(result.documents_ids.iter().cloned())?; let mut stdout = io::stdout(); stdout.write_all(&headers)?; @@ -76,7 +77,7 @@ fn main() -> anyhow::Result<()> { stdout.write_all(&content)?; } - debug!("Took {:.02?} to find {} documents", before.elapsed(), documents_ids.len()); + debug!("Took {:.02?} to find {} documents", before.elapsed(), result.documents_ids.len()); } Ok(()) diff --git a/src/bin/serve.rs b/src/bin/serve.rs index 7f20bf8e2..3b91aa2c2 100644 --- a/src/bin/serve.rs +++ b/src/bin/serve.rs @@ -13,7 +13,7 @@ use slice_group_by::StrGroupBy; use structopt::StructOpt; use warp::{Filter, http::Response}; -use milli::Index; +use milli::{Index, SearchResult}; #[cfg(target_os = "linux")] #[global_allocator] @@ -183,7 +183,10 @@ async fn main() -> anyhow::Result<()> { let before_search = Instant::now(); let rtxn = env_cloned.read_txn().unwrap(); - let (words, documents_ids) = index.search(&rtxn, &query.query).unwrap(); + let SearchResult { found_words, documents_ids } = index.search(&rtxn) + .query(query.query) + .execute() + .unwrap(); let mut body = Vec::new(); if let Some(headers) = index.headers(&rtxn).unwrap() { @@ -196,7 +199,7 @@ async fn main() -> anyhow::Result<()> { let content = if disable_highlighting { Cow::from(content) } else { - Cow::from(highlight_string(content, &words)) + Cow::from(highlight_string(content, &found_words)) }; body.extend_from_slice(content.as_bytes()); diff --git a/src/lib.rs b/src/lib.rs index 63d0ed6c6..d24f5f64e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -3,38 +3,27 @@ mod criterion; mod heed_codec; mod iter_shortest_paths; mod query_tokens; +mod search; mod transitive_arc; -use std::collections::{HashSet, HashMap}; +use std::collections::HashMap; use std::fs::{File, OpenOptions}; use std::hash::BuildHasherDefault; use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::time::Instant; use anyhow::Context; -use cow_utils::CowUtils; -use fst::{IntoStreamer, Streamer}; use fxhash::{FxHasher32, FxHasher64}; use heed::types::*; use heed::{PolyDatabase, Database}; -use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder; -use log::debug; use memmap::Mmap; -use once_cell::sync::Lazy; use oxidized_mtbl as omtbl; -use roaring::RoaringBitmap; -use self::best_proximity::BestProximity; +pub use self::search::{Search, SearchResult}; +pub use self::criterion::{Criterion, default_criteria}; use self::heed_codec::RoaringBitmapCodec; -use self::query_tokens::{QueryTokens, QueryToken}; use self::transitive_arc::TransitiveArc; -// Building these factories is not free. -static LEVDIST0: Lazy = Lazy::new(|| LevBuilder::new(0, true)); -static LEVDIST1: Lazy = Lazy::new(|| LevBuilder::new(1, true)); -static LEVDIST2: Lazy = Lazy::new(|| LevBuilder::new(2, true)); - pub type FastMap4 = HashMap>; pub type FastMap8 = HashMap>; pub type SmallString32 = smallstr::SmallString<[u8; 32]>; @@ -138,257 +127,7 @@ impl Index { self.documents.metadata().count_entries as usize } - pub fn search(&self, rtxn: &heed::RoTxn, query: &str) -> anyhow::Result<(HashSet, Vec)> { - let fst = match self.fst(rtxn)? { - Some(fst) => fst, - None => return Ok(Default::default()), - }; - - let (lev0, lev1, lev2) = (&LEVDIST0, &LEVDIST1, &LEVDIST2); - - let words: Vec<_> = QueryTokens::new(query).collect(); - let ends_with_whitespace = query.chars().last().map_or(false, char::is_whitespace); - let number_of_words = words.len(); - let dfas = words.into_iter().enumerate().map(|(i, word)| { - let (word, quoted) = match word { - QueryToken::Free(word) => (word.cow_to_lowercase(), word.len() <= 3), - QueryToken::Quoted(word) => (word.cow_to_lowercase(), true), - }; - let is_last = i + 1 == number_of_words; - let is_prefix = is_last && !ends_with_whitespace && !quoted; - let lev = match word.len() { - 0..=4 => if quoted { lev0 } else { lev0 }, - 5..=8 => if quoted { lev0 } else { lev1 }, - _ => if quoted { lev0 } else { lev2 }, - }; - - let dfa = if is_prefix { - lev.build_prefix_dfa(&word) - } else { - lev.build_dfa(&word) - }; - - (word, is_prefix, dfa) - }); - - let mut words = Vec::new(); - let mut positions = Vec::new(); - let before = Instant::now(); - - for (word, _is_prefix, dfa) in dfas { - let before = Instant::now(); - - let mut count = 0; - let mut union_positions = RoaringBitmap::default(); - let mut derived_words = Vec::new(); - let mut stream = fst.search_with_state(&dfa).into_stream(); - while let Some((word, state)) = stream.next() { - let word = std::str::from_utf8(word)?; - let distance = dfa.distance(state); - debug!("found {:?} at distance of {}", word, distance.to_u8()); - if let Some(positions) = self.word_positions.get(rtxn, word)? { - union_positions.union_with(&positions); - derived_words.push((word.as_bytes().to_vec(), distance.to_u8(), positions)); - count += 1; - } - } - - debug!("{} words for {:?} we have found positions {:?} in {:.02?}", - count, word, union_positions, before.elapsed()); - words.push(derived_words); - positions.push(union_positions.iter().collect()); - } - - // We compute the docids candidates for these words (and derived words). - // We do a union between all the docids of each of the words and derived words, - // we got N unions (where N is the number of query words), we then intersect them. - // TODO we must store the words documents ids to avoid these unions. - let mut candidates = RoaringBitmap::new(); - let number_of_attributes = self.number_of_attributes(rtxn)?.map_or(0, |n| n as u32); - for (i, derived_words) in words.iter().enumerate() { - let mut union_docids = RoaringBitmap::new(); - for (word, _distance, _positions) in derived_words { - for attr in 0..number_of_attributes { - let mut key = word.to_vec(); - key.extend_from_slice(&attr.to_be_bytes()); - if let Some(right) = self.word_attribute_docids.get(rtxn, &key)? { - union_docids.union_with(&right); - } - } - } - if i == 0 { - candidates = union_docids; - } else { - candidates.intersect_with(&union_docids); - } - } - - debug!("The candidates are {:?}", candidates); - debug!("Retrieving words positions took {:.02?}", before.elapsed()); - - // Returns the union of the same position for all the derived words. - let unions_word_pos = |word: usize, pos: u32| { - let mut union_docids = RoaringBitmap::new(); - for (word, _distance, attrs) in &words[word] { - if attrs.contains(pos) { - let mut key = word.clone(); - key.extend_from_slice(&pos.to_be_bytes()); - if let Some(right) = self.word_position_docids.get(rtxn, &key).unwrap() { - union_docids.union_with(&right); - } - } - } - union_docids - }; - - // Returns the union of the same attribute for all the derived words. - let unions_word_attr = |word: usize, attr: u32| { - let mut union_docids = RoaringBitmap::new(); - for (word, _distance, _) in &words[word] { - let mut key = word.clone(); - key.extend_from_slice(&attr.to_be_bytes()); - if let Some(right) = self.word_attribute_docids.get(rtxn, &key).unwrap() { - union_docids.union_with(&right); - } - } - union_docids - }; - - let mut union_cache = HashMap::new(); - let mut intersect_cache = HashMap::new(); - - let mut attribute_union_cache = HashMap::new(); - let mut attribute_intersect_cache = HashMap::new(); - - // Returns `true` if there is documents in common between the two words and positions given. - let mut contains_documents = |(lword, lpos), (rword, rpos), union_cache: &mut HashMap<_, _>, candidates: &RoaringBitmap| { - if lpos == rpos { return false } - - let (lattr, _) = best_proximity::extract_position(lpos); - let (rattr, _) = best_proximity::extract_position(rpos); - - if lattr == rattr { - // We retrieve or compute the intersection between the two given words and positions. - *intersect_cache.entry(((lword, lpos), (rword, rpos))).or_insert_with(|| { - // We retrieve or compute the unions for the two words and positions. - union_cache.entry((lword, lpos)).or_insert_with(|| unions_word_pos(lword, lpos)); - union_cache.entry((rword, rpos)).or_insert_with(|| unions_word_pos(rword, rpos)); - - // TODO is there a way to avoid this double gets? - let lunion_docids = union_cache.get(&(lword, lpos)).unwrap(); - let runion_docids = union_cache.get(&(rword, rpos)).unwrap(); - - // We first check that the docids of these unions are part of the candidates. - if lunion_docids.is_disjoint(candidates) { return false } - if runion_docids.is_disjoint(candidates) { return false } - - !lunion_docids.is_disjoint(&runion_docids) - }) - } else { - *attribute_intersect_cache.entry(((lword, lattr), (rword, rattr))).or_insert_with(|| { - // We retrieve or compute the unions for the two words and positions. - attribute_union_cache.entry((lword, lattr)).or_insert_with(|| unions_word_attr(lword, lattr)); - attribute_union_cache.entry((rword, rattr)).or_insert_with(|| unions_word_attr(rword, rattr)); - - // TODO is there a way to avoid this double gets? - let lunion_docids = attribute_union_cache.get(&(lword, lattr)).unwrap(); - let runion_docids = attribute_union_cache.get(&(rword, rattr)).unwrap(); - - // We first check that the docids of these unions are part of the candidates. - if lunion_docids.is_disjoint(candidates) { return false } - if runion_docids.is_disjoint(candidates) { return false } - - !lunion_docids.is_disjoint(&runion_docids) - }) - } - }; - - let mut documents = Vec::new(); - let mut iter = BestProximity::new(positions); - while let Some((proximity, mut positions)) = iter.next(|l, r| contains_documents(l, r, &mut union_cache, &candidates)) { - positions.sort_unstable(); - - let same_prox_before = Instant::now(); - let mut same_proximity_union = RoaringBitmap::default(); - - for positions in positions { - let before = Instant::now(); - - // Precompute the potentially missing unions - positions.iter().enumerate().for_each(|(word, pos)| { - union_cache.entry((word, *pos)).or_insert_with(|| unions_word_pos(word, *pos)); - }); - - // Retrieve the unions along with the popularity of it. - let mut to_intersect: Vec<_> = positions.iter() - .enumerate() - .map(|(word, pos)| { - let docids = union_cache.get(&(word, *pos)).unwrap(); - (docids.len(), docids) - }) - .collect(); - - // Sort the unions by popuarity to help reduce - // the number of documents as soon as possible. - to_intersect.sort_unstable_by_key(|(l, _)| *l); - let elapsed_retrieving = before.elapsed(); - - let before_intersect = Instant::now(); - let intersect_docids: Option = to_intersect.into_iter() - .fold(None, |acc, (_, union_docids)| { - match acc { - Some(mut left) => { - left.intersect_with(&union_docids); - Some(left) - }, - None => Some(union_docids.clone()), - } - }); - - debug!("retrieving words took {:.02?} and took {:.02?} to intersect", - elapsed_retrieving, before_intersect.elapsed()); - - debug!("for proximity {:?} {:?} we took {:.02?} to find {} documents", - proximity, positions, before.elapsed(), - intersect_docids.as_ref().map_or(0, |rb| rb.len())); - - if let Some(intersect_docids) = intersect_docids { - same_proximity_union.union_with(&intersect_docids); - } - - // We found enough documents we can stop here - if documents.iter().map(RoaringBitmap::len).sum::() + same_proximity_union.len() >= 20 { - debug!("proximity {} took a total of {:.02?}", proximity, same_prox_before.elapsed()); - break; - } - } - - // We achieve to find valid documents ids so we remove them from the candidates list. - candidates.difference_with(&same_proximity_union); - - documents.push(same_proximity_union); - - // We remove the double occurences of documents. - for i in 0..documents.len() { - if let Some((docs, others)) = documents[..=i].split_last_mut() { - others.iter().for_each(|other| docs.difference_with(other)); - } - } - documents.retain(|rb| !rb.is_empty()); - - debug!("documents: {:?}", documents); - debug!("proximity {} took a total of {:.02?}", proximity, same_prox_before.elapsed()); - - // We found enough documents we can stop here. - if documents.iter().map(RoaringBitmap::len).sum::() >= 20 { - break; - } - } - - debug!("{} final candidates", documents.iter().map(RoaringBitmap::len).sum::()); - let words = words.into_iter().flatten().map(|(w, _distance, _)| String::from_utf8(w).unwrap()).collect(); - let documents = documents.iter().flatten().take(20).collect(); - - Ok((words, documents)) + pub fn search<'a>(&'a self, rtxn: &'a heed::RoTxn) -> Search<'a> { + Search::new(rtxn, self) } } diff --git a/src/search.rs b/src/search.rs new file mode 100644 index 000000000..91be3776d --- /dev/null +++ b/src/search.rs @@ -0,0 +1,361 @@ +use std::collections::{HashMap, HashSet}; + +use fst::{IntoStreamer, Streamer}; +use levenshtein_automata::DFA; +use levenshtein_automata::LevenshteinAutomatonBuilder as LevBuilder; +use once_cell::sync::Lazy; +use roaring::RoaringBitmap; + +use crate::query_tokens::{QueryTokens, QueryToken}; +use crate::{Index, DocumentId, Position, Attribute}; +use crate::best_proximity::{self, BestProximity}; + +// Building these factories is not free. +static LEVDIST0: Lazy = Lazy::new(|| LevBuilder::new(0, true)); +static LEVDIST1: Lazy = Lazy::new(|| LevBuilder::new(1, true)); +static LEVDIST2: Lazy = Lazy::new(|| LevBuilder::new(2, true)); + +pub struct Search<'a> { + query: Option, + offset: usize, + limit: usize, + rtxn: &'a heed::RoTxn, + index: &'a Index, +} + +impl<'a> Search<'a> { + pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { + Search { + query: None, + offset: 0, + limit: 20, + rtxn, + index, + } + } + + pub fn query(&mut self, query: impl Into) -> &mut Search<'a> { + self.query = Some(query.into()); + self + } + + pub fn offset(&mut self, offset: usize) -> &mut Search<'a> { + self.offset = offset; + self + } + + pub fn limit(&mut self, limit: usize) -> &mut Search<'a> { + self.limit = limit; + self + } + + /// Extracts the query words from the query string and returns the DFAs accordingly. + /// TODO introduce settings for the number of typos regarding the words lengths. + fn generate_query_dfas(query: &str) -> Vec<(String, bool, DFA)> { + let (lev0, lev1, lev2) = (&LEVDIST0, &LEVDIST1, &LEVDIST2); + + let words: Vec<_> = QueryTokens::new(query).collect(); + let ends_with_whitespace = query.chars().last().map_or(false, char::is_whitespace); + let number_of_words = words.len(); + + words.into_iter().enumerate().map(|(i, word)| { + let (word, quoted) = match word { + QueryToken::Free(word) => (word.to_lowercase(), word.len() <= 3), + QueryToken::Quoted(word) => (word.to_lowercase(), true), + }; + let is_last = i + 1 == number_of_words; + let is_prefix = is_last && !ends_with_whitespace && !quoted; + let lev = match word.len() { + 0..=4 => if quoted { lev0 } else { lev0 }, + 5..=8 => if quoted { lev0 } else { lev1 }, + _ => if quoted { lev0 } else { lev2 }, + }; + + let dfa = if is_prefix { + lev.build_prefix_dfa(&word) + } else { + lev.build_dfa(&word) + }; + + (word, is_prefix, dfa) + }) + .collect() + } + + /// Fetch the words from the given FST related to the given DFAs along with the associated + /// positions and the unions of those positions where the words found appears in the documents. + fn fetch_words_positions( + rtxn: &heed::RoTxn, + index: &Index, + fst: &fst::Set<&[u8]>, + dfas: Vec<(String, bool, DFA)>, + ) -> anyhow::Result<(Vec>, Vec)> + { + // A Vec storing all the derived words from the original query words, associated + // with the distance from the original word and the positions it appears at. + // The index the derived words appears in the Vec corresponds to the original query + // word position. + let mut derived_words = Vec::>::with_capacity(dfas.len()); + // A Vec storing the unions of all of each of the derived words positions. The index + // the union appears in the Vec corresponds to the original query word position. + let mut union_positions = Vec::::with_capacity(dfas.len()); + + for (_word, _is_prefix, dfa) in dfas { + + let mut acc_derived_words = Vec::new(); + let mut acc_union_positions = RoaringBitmap::new(); + let mut stream = fst.search_with_state(&dfa).into_stream(); + while let Some((word, state)) = stream.next() { + + let word = std::str::from_utf8(word)?; + let positions = index.word_positions.get(rtxn, word)?.unwrap(); + let distance = dfa.distance(state); + acc_union_positions.union_with(&positions); + acc_derived_words.push((word.to_string(), distance.to_u8(), positions)); + } + derived_words.push(acc_derived_words); + union_positions.push(acc_union_positions); + } + + Ok((derived_words, union_positions)) + } + + /// Returns the set of docids that contains all of the query words. + fn compute_candidates( + rtxn: &heed::RoTxn, + index: &Index, + derived_words: &[Vec<(String, u8, RoaringBitmap)>], + ) -> anyhow::Result + { + // we do a union between all the docids of each of the derived words, + // we got N unions (the number of original query words), we then intersect them. + // TODO we must store the words documents ids to avoid these unions. + let mut candidates = RoaringBitmap::new(); + let number_of_attributes = index.number_of_attributes(rtxn)?.map_or(0, |n| n as u32); + + for (i, derived_words) in derived_words.iter().enumerate() { + + let mut union_docids = RoaringBitmap::new(); + for (word, _distance, _positions) in derived_words { + for attr in 0..number_of_attributes { + + let mut key = word.clone().into_bytes(); + key.extend_from_slice(&attr.to_be_bytes()); + if let Some(docids) = index.word_attribute_docids.get(rtxn, &key)? { + union_docids.union_with(&docids); + } + } + } + + if i == 0 { + candidates = union_docids; + } else { + candidates.intersect_with(&union_docids); + } + } + + Ok(candidates) + } + + /// Returns the union of the same position for all the given words. + fn union_word_position( + rtxn: &heed::RoTxn, + index: &Index, + words: &[(String, u8, RoaringBitmap)], + position: Position, + ) -> anyhow::Result + { + let mut union_docids = RoaringBitmap::new(); + for (word, _distance, positions) in words { + if positions.contains(position) { + let mut key = word.clone().into_bytes(); + key.extend_from_slice(&position.to_be_bytes()); + if let Some(docids) = index.word_position_docids.get(rtxn, &key)? { + union_docids.union_with(&docids); + } + } + } + Ok(union_docids) + } + + /// Returns the union of the same attribute for all the given words. + fn union_word_attribute( + rtxn: &heed::RoTxn, + index: &Index, + words: &[(String, u8, RoaringBitmap)], + attribute: Attribute, + ) -> anyhow::Result + { + let mut union_docids = RoaringBitmap::new(); + for (word, _distance, _positions) in words { + let mut key = word.clone().into_bytes(); + key.extend_from_slice(&attribute.to_be_bytes()); + if let Some(docids) = index.word_attribute_docids.get(rtxn, &key)? { + union_docids.union_with(&docids); + } + } + Ok(union_docids) + } + + pub fn execute(&self) -> anyhow::Result { + let rtxn = self.rtxn; + let index = self.index; + + let fst = match index.fst(rtxn)? { + Some(fst) => fst, + None => return Ok(Default::default()), + }; + + // Construct the DFAs related to the query words. + // TODO do a placeholder search when query string isn't present. + let dfas = match &self.query { + Some(q) => Self::generate_query_dfas(q), + None => return Ok(Default::default()), + }; + + let (derived_words, union_positions) = Self::fetch_words_positions(rtxn, index, &fst, dfas)?; + let mut candidates = Self::compute_candidates(rtxn, index, &derived_words)?; + + let mut union_cache = HashMap::new(); + let mut intersect_cache = HashMap::new(); + + let mut attribute_union_cache = HashMap::new(); + let mut attribute_intersect_cache = HashMap::new(); + + // Returns `true` if there is documents in common between the two words and positions given. + let mut contains_documents = |(lword, lpos), (rword, rpos), union_cache: &mut HashMap<_, _>, candidates: &RoaringBitmap| { + if lpos == rpos { return false } + + let (lattr, _) = best_proximity::extract_position(lpos); + let (rattr, _) = best_proximity::extract_position(rpos); + + if lattr == rattr { + // We retrieve or compute the intersection between the two given words and positions. + *intersect_cache.entry(((lword, lpos), (rword, rpos))).or_insert_with(|| { + // We retrieve or compute the unions for the two words and positions. + union_cache.entry((lword, lpos)).or_insert_with(|| { + let words: &Vec<_> = &derived_words[lword]; + Self::union_word_position(rtxn, index, words, lpos).unwrap() + }); + union_cache.entry((rword, rpos)).or_insert_with(|| { + let words: &Vec<_> = &derived_words[rword]; + Self::union_word_position(rtxn, index, words, rpos).unwrap() + }); + + // TODO is there a way to avoid this double gets? + let lunion_docids = union_cache.get(&(lword, lpos)).unwrap(); + let runion_docids = union_cache.get(&(rword, rpos)).unwrap(); + + // We first check that the docids of these unions are part of the candidates. + if lunion_docids.is_disjoint(candidates) { return false } + if runion_docids.is_disjoint(candidates) { return false } + + !lunion_docids.is_disjoint(&runion_docids) + }) + } else { + *attribute_intersect_cache.entry(((lword, lattr), (rword, rattr))).or_insert_with(|| { + // We retrieve or compute the unions for the two words and positions. + attribute_union_cache.entry((lword, lattr)).or_insert_with(|| { + let words: &Vec<_> = &derived_words[lword]; + Self::union_word_attribute(rtxn, index, words, lattr).unwrap() + }); + attribute_union_cache.entry((rword, rattr)).or_insert_with(|| { + let words: &Vec<_> = &derived_words[rword]; + Self::union_word_attribute(rtxn, index, words, rattr).unwrap() + }); + + // TODO is there a way to avoid this double gets? + let lunion_docids = attribute_union_cache.get(&(lword, lattr)).unwrap(); + let runion_docids = attribute_union_cache.get(&(rword, rattr)).unwrap(); + + // We first check that the docids of these unions are part of the candidates. + if lunion_docids.is_disjoint(candidates) { return false } + if runion_docids.is_disjoint(candidates) { return false } + + !lunion_docids.is_disjoint(&runion_docids) + }) + } + }; + + let mut documents = Vec::new(); + let mut iter = BestProximity::new(union_positions); + while let Some((_proximity, mut positions)) = iter.next(|l, r| contains_documents(l, r, &mut union_cache, &candidates)) { + positions.sort_unstable_by(|a, b| a.iter().cmp(b.iter())); + + let mut same_proximity_union = RoaringBitmap::default(); + for positions in positions { + + // Precompute the potentially missing unions + positions.iter().enumerate().for_each(|(word, pos)| { + union_cache.entry((word, pos)).or_insert_with(|| { + let words = &derived_words[word]; + Self::union_word_position(rtxn, index, words, pos).unwrap() + }); + }); + + // Retrieve the unions along with the popularity of it. + let mut to_intersect: Vec<_> = positions.iter() + .enumerate() + .map(|(word, pos)| { + let docids = union_cache.get(&(word, pos)).unwrap(); + (docids.len(), docids) + }) + .collect(); + + // Sort the unions by popularity to help reduce + // the number of documents as soon as possible. + to_intersect.sort_unstable_by_key(|(l, _)| *l); + + let intersect_docids: Option = to_intersect.into_iter() + .fold(None, |acc, (_, union_docids)| { + match acc { + Some(mut left) => { + left.intersect_with(&union_docids); + Some(left) + }, + None => Some(union_docids.clone()), + } + }); + + if let Some(intersect_docids) = intersect_docids { + same_proximity_union.union_with(&intersect_docids); + } + + // We found enough documents we can stop here + if documents.iter().map(RoaringBitmap::len).sum::() + same_proximity_union.len() >= 20 { + break; + } + } + + // We achieve to find valid documents ids so we remove them from the candidates list. + candidates.difference_with(&same_proximity_union); + + documents.push(same_proximity_union); + + // We remove the double occurences of documents. + for i in 0..documents.len() { + if let Some((docs, others)) = documents[..=i].split_last_mut() { + others.iter().for_each(|other| docs.difference_with(other)); + } + } + documents.retain(|rb| !rb.is_empty()); + + // We found enough documents we can stop here. + if documents.iter().map(RoaringBitmap::len).sum::() >= 20 { + break; + } + } + + let found_words = derived_words.into_iter().flatten().map(|(w, _, _)| w).collect(); + let documents_ids = documents.iter().flatten().take(20).collect(); + + Ok(SearchResult { found_words, documents_ids }) + } +} + +#[derive(Default)] +pub struct SearchResult { + pub found_words: HashSet, + // TODO those documents ids should be associated with their criteria scores. + pub documents_ids: Vec, +}