From 6d04a285dca479a29b4c1761462ebd6246127562 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Tue, 11 Aug 2020 15:18:02 +0200 Subject: [PATCH] Retrieve and display the distances of the words found --- src/lib.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 3ec1cbcec..6ec02433e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,6 @@ mod iter_shortest_paths; mod query_tokens; mod transitive_arc; -use std::borrow::Cow; use std::collections::{HashSet, HashMap}; use std::fs::{File, OpenOptions}; use std::hash::BuildHasherDefault; @@ -181,13 +180,14 @@ impl Index { let mut count = 0; let mut union_positions = RoaringBitmap::default(); let mut derived_words = Vec::new(); - // TODO re-enable the prefixes system - let mut stream = fst.search(&dfa).into_stream(); - while let Some(word) = stream.next() { + let mut stream = fst.search_with_state(&dfa).into_stream(); + while let Some((word, state)) = stream.next() { let word = std::str::from_utf8(word)?; + let distance = dfa.distance(state); + debug!("found {:?} at distance of {}", word, distance.to_u8()); if let Some(positions) = self.word_positions.get(rtxn, word)? { union_positions.union_with(&positions); - derived_words.push((word.as_bytes().to_vec(), positions)); + derived_words.push((word.as_bytes().to_vec(), distance.to_u8(), positions)); count += 1; } } @@ -198,7 +198,7 @@ impl Index { positions.push(union_positions.iter().collect()); } - // We compute the docids candiate for these words (and derived words). + // We compute the docids candidates for these words (and derived words). // We do a union between all the docids of each of the words and derived words, // we got N unions (where N is the number of query words), we then intersect them. // TODO we must store the words documents ids to avoid these unions. @@ -206,7 +206,7 @@ impl Index { let number_of_attributes = self.number_of_attributes(rtxn)?.map_or(0, |n| n as u32); for (i, derived_words) in words.iter().enumerate() { let mut union_docids = RoaringBitmap::new(); - for (word, _positions) in derived_words { + for (word, _distance, _positions) in derived_words { for attr in 0..number_of_attributes { let mut key = word.to_vec(); key.extend_from_slice(&attr.to_be_bytes()); @@ -228,7 +228,7 @@ impl Index { // Returns the union of the same position for all the derived words. let unions_word_pos = |word: usize, pos: u32| { let mut union_docids = RoaringBitmap::new(); - for (word, attrs) in &words[word] { + for (word, _distance, attrs) in &words[word] { if attrs.contains(pos) { let mut key = word.clone(); key.extend_from_slice(&pos.to_be_bytes()); @@ -243,7 +243,7 @@ impl Index { // Returns the union of the same attribute for all the derived words. let unions_word_attr = |word: usize, attr: u32| { let mut union_docids = RoaringBitmap::new(); - for (word, _) in &words[word] { + for (word, _distance, _) in &words[word] { let mut key = word.clone(); key.extend_from_slice(&attr.to_be_bytes()); if let Some(right) = self.word_attribute_docids.get(rtxn, &key).unwrap() { @@ -385,7 +385,7 @@ impl Index { } debug!("{} final candidates", documents.iter().map(RoaringBitmap::len).sum::()); - let words = words.into_iter().flatten().map(|(w, _)| String::from_utf8(w).unwrap()).collect(); + let words = words.into_iter().flatten().map(|(w, _distance, _)| String::from_utf8(w).unwrap()).collect(); let documents = documents.iter().flatten().take(20).collect(); Ok((words, documents))