diff --git a/src/bin/search.rs b/src/bin/search.rs index 86abe752c..c3fd7cd66 100644 --- a/src/bin/search.rs +++ b/src/bin/search.rs @@ -58,7 +58,7 @@ fn main() -> anyhow::Result<()> { let before = Instant::now(); let query = result?; - let documents_ids = index.search(&rtxn, &query)?; + let (_, documents_ids) = index.search(&rtxn, &query)?; let headers = match index.headers(&rtxn)? { Some(headers) => headers, None => return Ok(()), diff --git a/src/bin/serve.rs b/src/bin/serve.rs index f8e84555d..f9e18b315 100644 --- a/src/bin/serve.rs +++ b/src/bin/serve.rs @@ -152,14 +152,21 @@ async fn main() -> anyhow::Result<()> { let before_search = Instant::now(); let rtxn = env_cloned.read_txn().unwrap(); - let documents_ids = index.search(&rtxn, &query.query).unwrap(); + let (words, documents_ids) = index.search(&rtxn, &query.query).unwrap(); let mut body = Vec::new(); if let Some(headers) = index.headers(&rtxn).unwrap() { // We write the headers body.extend_from_slice(headers); - let re = Regex::new(r"(?i)(hello)").unwrap(); + let mut regex = format!(r"(?i)\b("); + let number_of_words = words.len(); + words.into_iter().enumerate().for_each(|(i, w)| { + regex.push_str(&w); + if i != number_of_words - 1 { regex.push('|') } + }); + regex.push_str(r")\b"); + let re = Regex::new(®ex).unwrap(); for id in documents_ids { let content = index.documents.get(&rtxn, &BEU32::new(id)).unwrap(); diff --git a/src/lib.rs b/src/lib.rs index 6af75e875..9384155b7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,7 @@ mod iter_shortest_paths; mod query_tokens; use std::borrow::Cow; -use std::collections::HashMap; +use std::collections::{HashSet, HashMap}; use std::hash::BuildHasherDefault; use std::time::Instant; @@ -96,10 +96,10 @@ impl Index { } } - pub fn search(&self, rtxn: &heed::RoTxn, query: &str) -> anyhow::Result> { + pub fn search(&self, rtxn: &heed::RoTxn, query: &str) -> anyhow::Result<(HashSet, Vec)> { let fst = match self.fst(rtxn)? { Some(fst) => fst, - None => return Ok(vec![]), + None => return Ok(Default::default()), }; let (lev0, lev1, lev2) = (&LEVDIST0, &LEVDIST1, &LEVDIST2); @@ -342,7 +342,10 @@ impl Index { } } - debug!("{} candidates", documents.iter().map(RoaringBitmap::len).sum::()); - Ok(documents.iter().flatten().take(20).collect()) + debug!("{} final candidates", documents.iter().map(RoaringBitmap::len).sum::()); + let words = words.into_iter().flatten().map(|(w, _)| String::from_utf8(w).unwrap()).collect(); + let documents = documents.iter().flatten().take(20).collect(); + + Ok((words, documents)) } }