diff --git a/crates/meilisearch/tests/search/hybrid.rs b/crates/meilisearch/tests/search/hybrid.rs index 3282a357a..3a8fb2b4c 100644 --- a/crates/meilisearch/tests/search/hybrid.rs +++ b/crates/meilisearch/tests/search/hybrid.rs @@ -76,6 +76,48 @@ static SINGLE_DOCUMENT_VEC: Lazy = Lazy::new(|| { }]) }); +static TEST_DISTINCT_DOCUMENTS: Lazy = Lazy::new(|| { + // for query "Captain Marvel" and vector [1.0, 1.0] + json!([ + { + "id": 0, + "search": "Captain Planet", + "desc": "#2 for keyword search, #3 for hybrid search", + "_vectors": { + "default": [-1.0, 0.0], + }, + "distinct": 0 + }, + { + "id": 1, + "search": "Captain Marvel", + "desc": "#1 for keyword search, #4 for hybrid search", + "_vectors": { + "default": [-1.0, -1.0], + }, + "distinct": 1 + }, + { + "id": 2, + "search": "Some Captain at least", + "desc": "#3 for keyword search, #1 for hybrid search", + "_vectors": { + "default": [1.0, 1.0], + }, + "distinct": 0 + }, + { + "id": 3, + "search": "Irrelevant Capitaine", + "desc": "#4 for keyword search, #2 for hybrid search", + "_vectors": { + "default": [1.0, 0.0], + }, + "distinct": 1 + }, + ]) +}); + static SIMPLE_SEARCH_DOCUMENTS: Lazy = Lazy::new(|| { json!([ { @@ -493,6 +535,50 @@ async fn query_combination() { snapshot!(response["semanticHitCount"], @"0"); } +// see +#[actix_rt::test] +async fn distinct_is_applied() { + let server = Server::new().await; + let index = index_with_documents_user_provided(&server, &TEST_DISTINCT_DOCUMENTS).await; + + let (response, code) = index.update_settings(json!({ "distinctAttribute": "distinct" } )).await; + assert_eq!(202, code, "{:?}", response); + index.wait_task(response.uid()).await.succeeded(); + + // pure keyword + let (response, code) = index + .search_post( + json!({"q": "Captain Marvel", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.0, "embedder": "default"}}), + ) + .await; + snapshot!(code, @"200 OK"); + snapshot!(response["hits"], @r###"[{"id":1,"search":"Captain Marvel","desc":"#1 for keyword search, #4 for hybrid search","distinct":1},{"id":0,"search":"Captain Planet","desc":"#2 for keyword search, #3 for hybrid search","distinct":0}]"###); + snapshot!(response["semanticHitCount"], @"null"); + snapshot!(response["estimatedTotalHits"], @"2"); + + // pure semantic + let (response, code) = index + .search_post( + json!({"q": "Captain Marvel", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 1.0, "embedder": "default"}}), + ) + .await; + snapshot!(code, @"200 OK"); + snapshot!(response["hits"], @r###"[{"id":2,"search":"Some Captain at least","desc":"#3 for keyword search, #1 for hybrid search","distinct":0},{"id":3,"search":"Irrelevant Capitaine","desc":"#4 for keyword search, #2 for hybrid search","distinct":1}]"###); + snapshot!(response["semanticHitCount"], @"2"); + snapshot!(response["estimatedTotalHits"], @"2"); + + // hybrid + let (response, code) = index + .search_post( + json!({"q": "Captain Marvel", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.5, "embedder": "default"}}), + ) + .await; + snapshot!(code, @"200 OK"); + snapshot!(response["hits"], @r###"[{"id":2,"search":"Some Captain at least","desc":"#3 for keyword search, #1 for hybrid search","distinct":0},{"id":1,"search":"Captain Marvel","desc":"#1 for keyword search, #4 for hybrid search","distinct":1}]"###); + snapshot!(response["semanticHitCount"], @"1"); + snapshot!(response["estimatedTotalHits"], @"2"); +} + #[actix_rt::test] async fn retrieve_vectors() { let server = Server::new().await; diff --git a/crates/milli/src/search/hybrid.rs b/crates/milli/src/search/hybrid.rs index e07f886c9..b63f6288f 100644 --- a/crates/milli/src/search/hybrid.rs +++ b/crates/milli/src/search/hybrid.rs @@ -1,11 +1,13 @@ use std::cmp::Ordering; +use heed::RoTxn; use itertools::Itertools; use roaring::RoaringBitmap; use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; +use crate::search::new::{distinct_fid, distinct_single_docid}; use crate::search::SemanticSearch; -use crate::{MatchingWords, Result, Search, SearchResult}; +use crate::{Index, MatchingWords, Result, Search, SearchResult}; struct ScoreWithRatioResult { matching_words: MatchingWords, @@ -91,7 +93,10 @@ impl ScoreWithRatioResult { keyword_results: Self, from: usize, length: usize, - ) -> (SearchResult, u32) { + distinct: Option<&str>, + index: &Index, + rtxn: &RoTxn<'_>, + ) -> Result<(SearchResult, u32)> { #[derive(Clone, Copy)] enum ResultSource { Semantic, @@ -106,8 +111,9 @@ impl ScoreWithRatioResult { vector_results.document_scores.len() + keyword_results.document_scores.len(), ); - let mut documents_seen = RoaringBitmap::new(); - for ((docid, (main_score, _sub_score)), source) in vector_results + let distinct_fid = distinct_fid(distinct, index, rtxn)?; + let mut excluded_documents = RoaringBitmap::new(); + for res in vector_results .document_scores .into_iter() .zip(std::iter::repeat(ResultSource::Semantic)) @@ -121,13 +127,33 @@ impl ScoreWithRatioResult { compare_scores(left, right).is_ge() }, ) - // remove documents we already saw - .filter(|((docid, _), _)| documents_seen.insert(*docid)) + // remove documents we already saw and apply distinct rule + .filter_map(|item @ ((docid, _), _)| { + if !excluded_documents.insert(docid) { + // the document was already added, or is indistinct from an already-added document. + return None; + } + + if let Some(distinct_fid) = distinct_fid { + if let Err(error) = distinct_single_docid( + index, + rtxn, + distinct_fid, + docid, + &mut excluded_documents, + ) { + return Some(Err(error)); + } + } + + Some(Ok(item)) + }) // start skipping **after** the filter .skip(from) // take **after** skipping .take(length) { + let ((docid, (main_score, _sub_score)), source) = res?; if let ResultSource::Semantic = source { semantic_hit_count += 1; } @@ -136,10 +162,24 @@ impl ScoreWithRatioResult { document_scores.push(main_score); } - ( + // compute the set of candidates from both sets + let candidates = vector_results.candidates | keyword_results.candidates; + let must_remove_redundant_candidates = distinct_fid.is_some(); + let candidates = if must_remove_redundant_candidates { + // patch-up the candidates to remove the indistinct documents, then add back the actual hits + let mut candidates = candidates - excluded_documents; + for docid in &documents_ids { + candidates.insert(*docid); + } + candidates + } else { + candidates + }; + + Ok(( SearchResult { matching_words: keyword_results.matching_words, - candidates: vector_results.candidates | keyword_results.candidates, + candidates, documents_ids, document_scores, degraded: vector_results.degraded | keyword_results.degraded, @@ -147,7 +187,7 @@ impl ScoreWithRatioResult { | keyword_results.used_negative_operator, }, semantic_hit_count, - ) + )) } } @@ -226,8 +266,15 @@ impl Search<'_> { let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio); let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio); - let (merge_results, semantic_hit_count) = - ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit); + let (merge_results, semantic_hit_count) = ScoreWithRatioResult::merge( + vector_results, + keyword_results, + self.offset, + self.limit, + search.distinct.as_deref(), + search.index, + search.rtxn, + )?; assert!(merge_results.documents_ids.len() <= self.limit); Ok((merge_results, Some(semantic_hit_count))) } diff --git a/crates/milli/src/search/new/bucket_sort.rs b/crates/milli/src/search/new/bucket_sort.rs index ca7a4a986..3c26cad5c 100644 --- a/crates/milli/src/search/new/bucket_sort.rs +++ b/crates/milli/src/search/new/bucket_sort.rs @@ -4,7 +4,9 @@ use super::logger::SearchLogger; use super::ranking_rules::{BoxRankingRule, RankingRuleQueryTrait}; use super::SearchContext; use crate::score_details::{ScoreDetails, ScoringStrategy}; -use crate::search::new::distinct::{apply_distinct_rule, distinct_single_docid, DistinctOutput}; +use crate::search::new::distinct::{ + apply_distinct_rule, distinct_fid, distinct_single_docid, DistinctOutput, +}; use crate::{Result, TimeBudget}; pub struct BucketSortOutput { @@ -35,16 +37,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>( logger.ranking_rules(&ranking_rules); logger.initial_universe(universe); - let distinct_field = match distinct { - Some(distinct) => Some(distinct), - None => ctx.index.distinct_field(ctx.txn)?, - }; - - let distinct_fid = if let Some(field) = distinct_field { - ctx.index.fields_ids_map(ctx.txn)?.id(field) - } else { - None - }; + let distinct_fid = distinct_fid(distinct, ctx.index, ctx.txn)?; if universe.len() < from as u64 { return Ok(BucketSortOutput { diff --git a/crates/milli/src/search/new/distinct.rs b/crates/milli/src/search/new/distinct.rs index 17859b6f8..36172302a 100644 --- a/crates/milli/src/search/new/distinct.rs +++ b/crates/milli/src/search/new/distinct.rs @@ -9,7 +9,7 @@ use crate::heed_codec::facet::{ FacetGroupKey, FacetGroupKeyCodec, FacetGroupValueCodec, FieldDocIdFacetCodec, }; use crate::heed_codec::BytesRefCodec; -use crate::{Index, Result, SearchContext}; +use crate::{FieldId, Index, Result, SearchContext}; pub struct DistinctOutput { pub remaining: RoaringBitmap, @@ -121,3 +121,18 @@ pub fn facet_string_values<'a>( fn facet_values_prefix_key(distinct: u16, id: u32) -> [u8; FID_SIZE + DOCID_SIZE] { concat_arrays::concat_arrays!(distinct.to_be_bytes(), id.to_be_bytes()) } + +pub fn distinct_fid( + query_distinct_field: Option<&str>, + index: &Index, + rtxn: &RoTxn<'_>, +) -> Result> { + let distinct_field = match query_distinct_field { + Some(distinct) => Some(distinct), + None => index.distinct_field(rtxn)?, + }; + + let distinct_fid = + if let Some(field) = distinct_field { index.fields_ids_map(rtxn)?.id(field) } else { None }; + Ok(distinct_fid) +} diff --git a/crates/milli/src/search/new/mod.rs b/crates/milli/src/search/new/mod.rs index 0a3bc1b04..a65b4076b 100644 --- a/crates/milli/src/search/new/mod.rs +++ b/crates/milli/src/search/new/mod.rs @@ -28,6 +28,7 @@ use std::time::Duration; use bucket_sort::{bucket_sort, BucketSortOutput}; use charabia::{Language, TokenizerBuilder}; use db_cache::DatabaseCache; +pub use distinct::{distinct_fid, distinct_single_docid}; use exact_attribute::ExactAttribute; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use heed::RoTxn; @@ -47,8 +48,7 @@ use sort::Sort; use self::distinct::facet_string_values; use self::geo_sort::GeoSort; -pub use self::geo_sort::Parameter as GeoSortParameter; -pub use self::geo_sort::Strategy as GeoSortStrategy; +pub use self::geo_sort::{Parameter as GeoSortParameter, Strategy as GeoSortStrategy}; use self::graph_based_ranking_rule::Words; use self::interner::Interned; use self::vector_sort::VectorSort;