diff --git a/milli/src/search/similar.rs b/milli/src/search/similar.rs index 49b7c876f..bf5cc323f 100644 --- a/milli/src/search/similar.rs +++ b/milli/src/search/similar.rs @@ -17,6 +17,7 @@ pub struct Similar<'a> { index: &'a Index, embedder_name: String, embedder: Arc, + ranking_score_threshold: Option, } impl<'a> Similar<'a> { @@ -29,7 +30,17 @@ impl<'a> Similar<'a> { embedder_name: String, embedder: Arc, ) -> Self { - Self { id, filter: None, offset, limit, rtxn, index, embedder_name, embedder } + Self { + id, + filter: None, + offset, + limit, + rtxn, + index, + embedder_name, + embedder, + ranking_score_threshold: None, + } } pub fn filter(&mut self, filter: Filter<'a>) -> &mut Self { @@ -37,8 +48,18 @@ impl<'a> Similar<'a> { self } + pub fn ranking_score_threshold(&mut self, ranking_score_threshold: f64) -> &mut Self { + self.ranking_score_threshold = Some(ranking_score_threshold); + self + } + pub fn execute(&self) -> Result { - let universe = filtered_universe(self.index, self.rtxn, &self.filter)?; + let mut universe = filtered_universe(self.index, self.rtxn, &self.filter)?; + + // we never want to receive the docid + universe.remove(self.id); + + let universe = universe; let embedder_index = self.index @@ -77,6 +98,8 @@ impl<'a> Similar<'a> { let mut documents_seen = RoaringBitmap::new(); documents_seen.insert(self.id); + let mut candidates = universe; + for (docid, distance) in results .into_iter() // skip documents we've already seen & mark that we saw the current document @@ -85,8 +108,6 @@ impl<'a> Similar<'a> { // take **after** filter and skip so that we get exactly limit elements if available .take(self.limit) { - documents_ids.push(docid); - let score = 1.0 - distance; let score = self .embedder @@ -94,14 +115,28 @@ impl<'a> Similar<'a> { .map(|distribution| distribution.shift(score)) .unwrap_or(score); - let score = ScoreDetails::Vector(score_details::Vector { similarity: Some(score) }); + let score_details = + vec![ScoreDetails::Vector(score_details::Vector { similarity: Some(score) })]; - document_scores.push(vec![score]); + let score = ScoreDetails::global_score(score_details.iter()); + + if let Some(ranking_score_threshold) = &self.ranking_score_threshold { + if score < *ranking_score_threshold { + // this document is no longer a candidate + candidates.remove(docid); + // any document after this one is no longer a candidate either, so restrict the set to documents already seen. + candidates &= documents_seen; + break; + } + } + + documents_ids.push(docid); + document_scores.push(score_details); } Ok(SearchResult { matching_words: Default::default(), - candidates: universe, + candidates, documents_ids, document_scores, degraded: false,