diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index c2b6ca3fc..c474d285e 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -9,6 +9,7 @@ use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli; +use meilisearch_types::milli::vector::DistributionShift; use meilisearch_types::serde_cs::vec::CS; use serde_json::Value; @@ -200,10 +201,11 @@ pub async fn search_with_url_query( let index = index_scheduler.index(&index_uid)?; let features = index_scheduler.features(); - embed(&mut query, index_scheduler.get_ref(), &index).await?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; + tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) + .await?; if let Ok(ref search_result) = search_result { aggregate.succeed(search_result); } @@ -238,10 +240,11 @@ pub async fn search_with_post( let features = index_scheduler.features(); - embed(&mut query, index_scheduler.get_ref(), &index).await?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; + tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) + .await?; if let Ok(ref search_result) = search_result { aggregate.succeed(search_result); } @@ -257,39 +260,74 @@ pub async fn embed( query: &mut SearchQuery, index_scheduler: &IndexScheduler, index: &milli::Index, -) -> Result<(), ResponseError> { - if let (None, Some(q), Some(HybridQuery { semantic_ratio: _, embedder })) = - (&query.vector, &query.q, &query.hybrid) - { - let embedder_configs = index.embedding_configs(&index.read_txn()?)?; - let embedders = index_scheduler.embedders(embedder_configs)?; +) -> Result, ResponseError> { + match (&query.hybrid, &query.vector, &query.q) { + (Some(HybridQuery { semantic_ratio: _, embedder }), None, Some(q)) + if !q.trim().is_empty() => + { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedders = index_scheduler.embedders(embedder_configs)?; - let embedder = if let Some(embedder_name) = embedder { - embedders.get(embedder_name) - } else { - embedders.get_default() - }; + let embedder = if let Some(embedder_name) = embedder { + embedders.get(embedder_name) + } else { + embedders.get_default() + }; - let embedder = embedder - .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) - .map_err(milli::Error::from)? - .0; - let embeddings = embedder - .embed(vec![q.to_owned()]) - .await - .map_err(milli::vector::Error::from) - .map_err(milli::Error::from)? - .pop() - .expect("No vector returned from embedding"); + let embedder = embedder + .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) + .map_err(milli::Error::from)? + .0; - if embeddings.iter().nth(1).is_some() { - warn!("Ignoring embeddings past the first one in long search query"); - query.vector = Some(embeddings.iter().next().unwrap().to_vec()); - } else { - query.vector = Some(embeddings.into_inner()); + let distribution = embedder.distribution(); + + let embeddings = embedder + .embed(vec![q.to_owned()]) + .await + .map_err(milli::vector::Error::from) + .map_err(milli::Error::from)? + .pop() + .expect("No vector returned from embedding"); + + if embeddings.iter().nth(1).is_some() { + warn!("Ignoring embeddings past the first one in long search query"); + query.vector = Some(embeddings.iter().next().unwrap().to_vec()); + } else { + query.vector = Some(embeddings.into_inner()); + } + Ok(distribution) } + (Some(hybrid), vector, _) => { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedders = index_scheduler.embedders(embedder_configs)?; + + let embedder = if let Some(embedder_name) = &hybrid.embedder { + embedders.get(embedder_name) + } else { + embedders.get_default() + }; + + let embedder = embedder + .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) + .map_err(milli::Error::from)? + .0; + + if let Some(vector) = vector { + if vector.len() != embedder.dimensions() { + return Err(meilisearch_types::milli::Error::UserError( + meilisearch_types::milli::UserError::InvalidVectorDimensions { + expected: embedder.dimensions(), + found: vector.len(), + }, + ) + .into()); + } + } + + Ok(embedder.distribution()) + } + _ => Ok(None), } - Ok(()) } #[cfg(test)] diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs index 4e578572d..8e81688e6 100644 --- a/meilisearch/src/routes/multi_search.rs +++ b/meilisearch/src/routes/multi_search.rs @@ -75,12 +75,15 @@ pub async fn multi_search_with_post( }) .with_index(query_index)?; - embed(&mut query, index_scheduler.get_ref(), &index).await.with_index(query_index)?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index) + .await + .with_index(query_index)?; - let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, features)) - .await - .with_index(query_index)?; + let search_result = tokio::task::spawn_blocking(move || { + perform_search(&index, query, features, distribution) + }) + .await + .with_index(query_index)?; search_results.push(SearchResultWithIndex { index_uid: index_uid.into_inner(), diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 267a404c0..b5dba8a58 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -13,6 +13,7 @@ use meilisearch_types::error::deserr_codes::*; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; +use meilisearch_types::milli::vector::DistributionShift; use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; @@ -90,16 +91,22 @@ pub struct SearchQuery { #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] pub struct HybridQuery { /// TODO validate that sementic ratio is between 0.0 and 1,0 - #[deserr(default, error = DeserrJsonError)] + #[deserr(default, error = DeserrJsonError, default)] pub semantic_ratio: SemanticRatio, #[deserr(default, error = DeserrJsonError, default)] pub embedder: Option, } -#[derive(Debug, Clone, Copy, Default, PartialEq, Deserr)] +#[derive(Debug, Clone, Copy, PartialEq, Deserr)] #[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)] pub struct SemanticRatio(f32); +impl Default for SemanticRatio { + fn default() -> Self { + DEFAULT_SEMANTIC_RATIO() + } +} + impl std::convert::TryFrom for SemanticRatio { type Error = InvalidSearchSemanticRatio; @@ -374,6 +381,7 @@ fn prepare_search<'t>( rtxn: &'t RoTxn, query: &'t SearchQuery, features: RoFeatures, + distribution: Option, ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { let mut search = index.search(rtxn); @@ -389,6 +397,8 @@ fn prepare_search<'t>( return Err(MeilisearchHttpError::MissingSearchHybrid); } + search.distribution_shift(distribution); + if let Some(ref vector) = query.vector { match &query.hybrid { // If semantic ratio is 0.0, only the query search will impact the search results, @@ -482,12 +492,13 @@ pub fn perform_search( index: &Index, query: SearchQuery, features: RoFeatures, + distribution: Option, ) -> Result { let before_search = Instant::now(); let rtxn = index.read_txn()?; let (search, is_finite_pagination, max_total_hits, offset) = - prepare_search(index, &rtxn, &query, features)?; + prepare_search(index, &rtxn, &query, features, distribution)?; let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = match &query.hybrid { @@ -718,8 +729,9 @@ pub fn perform_facet_search( let before_search = Instant::now(); let rtxn = index.read_txn()?; - let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features)?; - let mut facet_search = SearchForFacetValues::new(facet_name, search); + let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features, None)?; + let mut facet_search = + SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some()); if let Some(facet_query) = &facet_query { facet_search.query(facet_query); } diff --git a/meilisearch/tests/search/hybrid.rs b/meilisearch/tests/search/hybrid.rs index 7986091b0..c3534c110 100644 --- a/meilisearch/tests/search/hybrid.rs +++ b/meilisearch/tests/search/hybrid.rs @@ -27,11 +27,11 @@ async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Inde ) .await; assert_eq!(202, code, "{:?}", response); - index.wait_task(0).await; + index.wait_task(response.uid()).await; let (response, code) = index.add_documents(documents.clone(), None).await; assert_eq!(202, code, "{:?}", response); - index.wait_task(1).await; + index.wait_task(response.uid()).await; index } @@ -68,7 +68,7 @@ async fn simple_search() { ) .await; snapshot!(code, @"200 OK"); - snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); + snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]}}]"###); let (response, code) = index .search_post( diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 0fb24be84..3e4849578 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -154,6 +154,15 @@ impl<'a> Search<'a> { self } + pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result { + if has_vector_search { + let ctx = SearchContext::new(self.index, self.rtxn); + filtered_universe(&ctx, &self.filter) + } else { + Ok(self.execute()?.candidates) + } + } + pub fn execute(&self) -> Result { let embedder_name; let embedder_name = match &self.embedder_name { @@ -297,11 +306,16 @@ pub struct SearchForFacetValues<'a> { query: Option, facet: String, search_query: Search<'a>, + is_hybrid: bool, } impl<'a> SearchForFacetValues<'a> { - pub fn new(facet: String, search_query: Search<'a>) -> SearchForFacetValues<'a> { - SearchForFacetValues { query: None, facet, search_query } + pub fn new( + facet: String, + search_query: Search<'a>, + is_hybrid: bool, + ) -> SearchForFacetValues<'a> { + SearchForFacetValues { query: None, facet, search_query, is_hybrid } } pub fn query(&mut self, query: impl Into) -> &mut Self { @@ -351,7 +365,9 @@ impl<'a> SearchForFacetValues<'a> { None => return Ok(vec![]), }; - let search_candidates = self.search_query.execute()?.candidates; + let search_candidates = self + .search_query + .execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?; match self.query.as_ref() { Some(query) => { diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 6906bbcd3..ffc3f6b3a 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -509,7 +509,7 @@ where // We write the primary key field id into the main database self.index.put_primary_key(self.wtxn, &primary_key)?; let number_of_documents = self.index.number_of_documents(self.wtxn)?; - let mut rng = rand::rngs::StdRng::from_entropy(); + let mut rng = rand::rngs::StdRng::seed_from_u64(42); for (embedder_name, dimension) in dimension { let wtxn = &mut *self.wtxn; diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index 3162dadec..0a6bcbe93 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType}; use tokenizers::{PaddingParams, Tokenizer}; pub use super::error::{EmbedError, Error, NewEmbedderError}; -use super::{Embedding, Embeddings}; +use super::{DistributionShift, Embedding, Embeddings}; #[derive( Debug, @@ -184,4 +184,12 @@ impl Embedder { pub fn dimensions(&self) -> usize { self.dimensions } + + pub fn distribution(&self) -> Option { + if self.options.model == "BAAI/bge-base-en-v1.5" { + Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 }) + } else { + None + } + } } diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index df5750e77..81c4cf4a1 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -202,6 +202,14 @@ impl Embedder { Embedder::UserProvided(embedder) => embedder.dimensions(), } } + + pub fn distribution(&self) -> Option { + match self { + Embedder::HuggingFace(embedder) => embedder.distribution(), + Embedder::OpenAi(embedder) => embedder.distribution(), + Embedder::UserProvided(_embedder) => None, + } + } } #[derive(Debug, Clone, Copy)] diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 7ae626494..c11e6ddc6 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -4,7 +4,7 @@ use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use super::error::{EmbedError, NewEmbedderError}; -use super::{Embedding, Embeddings}; +use super::{DistributionShift, Embedding, Embeddings}; #[derive(Debug)] pub struct Embedder { @@ -65,6 +65,14 @@ impl EmbeddingModel { _ => None, } } + + fn distribution(&self) -> Option { + match self { + EmbeddingModel::TextEmbeddingAda002 => { + Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) + } + } + } } pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; @@ -326,6 +334,10 @@ impl Embedder { pub fn dimensions(&self) -> usize { self.options.embedding_model.dimensions() } + + pub fn distribution(&self) -> Option { + self.options.embedding_model.distribution() + } } // retrying in case of failure