Various changes

- use a fixed seed for arroy
- check vector dimensions as soon as a vector is provided to a search
- don't embed whitespace-only queries (see the sketch below)
Louis Dureuil 2023-12-14 16:01:35 +01:00
parent 217105b7da
commit 87bba98bd8
9 changed files with 148 additions and 51 deletions
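The last bullet boils down to a guard on the embedding call: a query that is empty after trimming never reaches the embedder. A minimal, self-contained sketch of that guard, assuming a stand-in fake_embed function rather than the real Meilisearch embedder API:

// Sketch: skip the embedding call when the query is empty or whitespace-only,
// mirroring the `if !q.trim().is_empty()` guard added in this commit.
// `fake_embed` is a hypothetical stand-in for an expensive embedder call.
fn fake_embed(q: &str) -> Vec<f32> {
    vec![q.len() as f32; 3]
}

fn maybe_embed(q: Option<&str>) -> Option<Vec<f32>> {
    match q {
        Some(q) if !q.trim().is_empty() => Some(fake_embed(q)),
        // empty or whitespace-only queries never reach the embedder
        _ => None,
    }
}

fn main() {
    assert!(maybe_embed(Some("   ")).is_none());
    assert!(maybe_embed(Some("captain marvel")).is_some());
    assert!(maybe_embed(None).is_none());
}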

View File

@@ -9,6 +9,7 @@ use meilisearch_types::error::deserr_codes::*;
 use meilisearch_types::error::ResponseError;
 use meilisearch_types::index_uid::IndexUid;
 use meilisearch_types::milli;
+use meilisearch_types::milli::vector::DistributionShift;
 use meilisearch_types::serde_cs::vec::CS;
 use serde_json::Value;
@@ -200,10 +201,11 @@ pub async fn search_with_url_query(
     let index = index_scheduler.index(&index_uid)?;
     let features = index_scheduler.features();
-    embed(&mut query, index_scheduler.get_ref(), &index).await?;
+    let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?;
     let search_result =
-        tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?;
+        tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
+            .await?;
     if let Ok(ref search_result) = search_result {
         aggregate.succeed(search_result);
     }
@@ -238,10 +240,11 @@ pub async fn search_with_post(
     let features = index_scheduler.features();
-    embed(&mut query, index_scheduler.get_ref(), &index).await?;
+    let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?;
     let search_result =
-        tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?;
+        tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
+            .await?;
     if let Ok(ref search_result) = search_result {
         aggregate.succeed(search_result);
     }
@@ -257,9 +260,10 @@ pub async fn embed(
     query: &mut SearchQuery,
     index_scheduler: &IndexScheduler,
     index: &milli::Index,
-) -> Result<(), ResponseError> {
-    if let (None, Some(q), Some(HybridQuery { semantic_ratio: _, embedder })) =
-        (&query.vector, &query.q, &query.hybrid)
+) -> Result<Option<DistributionShift>, ResponseError> {
+    match (&query.hybrid, &query.vector, &query.q) {
+        (Some(HybridQuery { semantic_ratio: _, embedder }), None, Some(q))
+            if !q.trim().is_empty() =>
         {
             let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
             let embedders = index_scheduler.embedders(embedder_configs)?;
@@ -274,6 +278,9 @@ pub async fn embed(
                 .ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
                 .map_err(milli::Error::from)?
                 .0;
+            let distribution = embedder.distribution();
             let embeddings = embedder
                 .embed(vec![q.to_owned()])
                 .await
@@ -288,8 +295,39 @@ pub async fn embed(
             } else {
                 query.vector = Some(embeddings.into_inner());
             }
+            Ok(distribution)
+        }
+        (Some(hybrid), vector, _) => {
+            let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
+            let embedders = index_scheduler.embedders(embedder_configs)?;
+            let embedder = if let Some(embedder_name) = &hybrid.embedder {
+                embedders.get(embedder_name)
+            } else {
+                embedders.get_default()
+            };
+            let embedder = embedder
+                .ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
+                .map_err(milli::Error::from)?
+                .0;
+            if let Some(vector) = vector {
+                if vector.len() != embedder.dimensions() {
+                    return Err(meilisearch_types::milli::Error::UserError(
+                        meilisearch_types::milli::UserError::InvalidVectorDimensions {
+                            expected: embedder.dimensions(),
+                            found: vector.len(),
+                        },
+                    )
+                    .into());
+                }
+            }
+            Ok(embedder.distribution())
+        }
+        _ => Ok(None),
     }
-    Ok(())
 }
 #[cfg(test)]

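The hunk above also adds the eager dimension check from the second bullet: when a raw vector is supplied alongside a hybrid query, its length is compared against the embedder's configured dimensions before the search runs. A hedged sketch of the same idea, with hypothetical EmbedderInfo and DimensionError types standing in for milli's embedder and UserError::InvalidVectorDimensions:

// Illustrative only: `EmbedderInfo` and `DimensionError` are hypothetical
// stand-ins for milli's embedder type and `UserError::InvalidVectorDimensions`.
struct EmbedderInfo {
    dimensions: usize,
}

#[derive(Debug)]
struct DimensionError {
    expected: usize,
    found: usize,
}

// Reject a user-provided vector up front, before any search work happens.
fn check_dimensions(vector: &[f32], embedder: &EmbedderInfo) -> Result<(), DimensionError> {
    if vector.len() != embedder.dimensions {
        return Err(DimensionError { expected: embedder.dimensions, found: vector.len() });
    }
    Ok(())
}

fn main() {
    let embedder = EmbedderInfo { dimensions: 768 };
    assert!(check_dimensions(&vec![0.0; 768], &embedder).is_ok());
    assert!(check_dimensions(&vec![0.0, 1.0], &embedder).is_err());
}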
View File

@@ -75,10 +75,13 @@ pub async fn multi_search_with_post(
            })
            .with_index(query_index)?;
-           embed(&mut query, index_scheduler.get_ref(), &index).await.with_index(query_index)?;
+           let distribution = embed(&mut query, index_scheduler.get_ref(), &index)
+               .await
+               .with_index(query_index)?;
-           let search_result =
-               tokio::task::spawn_blocking(move || perform_search(&index, query, features))
+           let search_result = tokio::task::spawn_blocking(move || {
+               perform_search(&index, query, features, distribution)
+           })
            .await
            .with_index(query_index)?;

View File

@@ -13,6 +13,7 @@ use meilisearch_types::error::deserr_codes::*;
 use meilisearch_types::heed::RoTxn;
 use meilisearch_types::index_uid::IndexUid;
 use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
+use meilisearch_types::milli::vector::DistributionShift;
 use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues};
 use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
 use meilisearch_types::{milli, Document};
@@ -90,16 +91,22 @@ pub struct SearchQuery {
 #[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
 pub struct HybridQuery {
     /// TODO validate that sementic ratio is between 0.0 and 1,0
-    #[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>)]
+    #[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)]
     pub semantic_ratio: SemanticRatio,
     #[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)]
     pub embedder: Option<String>,
 }

-#[derive(Debug, Clone, Copy, Default, PartialEq, Deserr)]
+#[derive(Debug, Clone, Copy, PartialEq, Deserr)]
 #[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
 pub struct SemanticRatio(f32);
+
+impl Default for SemanticRatio {
+    fn default() -> Self {
+        DEFAULT_SEMANTIC_RATIO()
+    }
+}

 impl std::convert::TryFrom<f32> for SemanticRatio {
     type Error = InvalidSearchSemanticRatio;
@@ -374,6 +381,7 @@ fn prepare_search<'t>(
     rtxn: &'t RoTxn,
     query: &'t SearchQuery,
     features: RoFeatures,
+    distribution: Option<DistributionShift>,
 ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
     let mut search = index.search(rtxn);
@@ -389,6 +397,8 @@ fn prepare_search<'t>(
         return Err(MeilisearchHttpError::MissingSearchHybrid);
     }
+
+    search.distribution_shift(distribution);
     if let Some(ref vector) = query.vector {
         match &query.hybrid {
             // If semantic ratio is 0.0, only the query search will impact the search results,
@@ -482,12 +492,13 @@ pub fn perform_search(
     index: &Index,
     query: SearchQuery,
     features: RoFeatures,
+    distribution: Option<DistributionShift>,
 ) -> Result<SearchResult, MeilisearchHttpError> {
     let before_search = Instant::now();
     let rtxn = index.read_txn()?;
     let (search, is_finite_pagination, max_total_hits, offset) =
-        prepare_search(index, &rtxn, &query, features)?;
+        prepare_search(index, &rtxn, &query, features, distribution)?;
     let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
         match &query.hybrid {
@@ -718,8 +729,9 @@ pub fn perform_facet_search(
     let before_search = Instant::now();
     let rtxn = index.read_txn()?;
-    let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features)?;
+    let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features, None)?;
-    let mut facet_search = SearchForFacetValues::new(facet_name, search);
+    let mut facet_search =
+        SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some());
     if let Some(facet_query) = &facet_query {
         facet_search.query(facet_query);
     }

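SemanticRatio now pairs a function-based Default with a fallible TryFrom<f32> conversion. The sketch below shows that pattern in isolation; the 0.5 default value and the exact range check are assumptions for illustration (the diff only shows the TryFrom scaffolding, the DEFAULT_SEMANTIC_RATIO() call, and the TODO about the 0.0 to 1.0 range):

use std::convert::TryFrom;

#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SemanticRatio(f32);

// Stand-in for DEFAULT_SEMANTIC_RATIO(); the 0.5 value is an assumption.
pub fn default_semantic_ratio() -> SemanticRatio {
    SemanticRatio(0.5)
}

impl Default for SemanticRatio {
    fn default() -> Self {
        default_semantic_ratio()
    }
}

impl TryFrom<f32> for SemanticRatio {
    type Error = String;

    // Validate the 0.0..=1.0 range at construction time (the TODO in the diff).
    fn try_from(value: f32) -> Result<Self, Self::Error> {
        if (0.0..=1.0).contains(&value) {
            Ok(SemanticRatio(value))
        } else {
            Err(format!("semantic ratio must be between 0.0 and 1.0, got {}", value))
        }
    }
}

fn main() {
    assert_eq!(SemanticRatio::default(), SemanticRatio(0.5));
    assert!(SemanticRatio::try_from(0.7).is_ok());
    assert!(SemanticRatio::try_from(1.5).is_err());
}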
View File

@@ -27,11 +27,11 @@ async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Inde
         )
         .await;
     assert_eq!(202, code, "{:?}", response);
-    index.wait_task(0).await;
+    index.wait_task(response.uid()).await;
     let (response, code) = index.add_documents(documents.clone(), None).await;
     assert_eq!(202, code, "{:?}", response);
-    index.wait_task(1).await;
+    index.wait_task(response.uid()).await;
     index
 }
@@ -68,7 +68,7 @@ async fn simple_search() {
         )
         .await;
     snapshot!(code, @"200 OK");
-    snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###);
+    snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]}}]"###);
     let (response, code) = index
         .search_post(

View File

@@ -154,6 +154,15 @@ impl<'a> Search<'a> {
         self
     }

+    pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> {
+        if has_vector_search {
+            let ctx = SearchContext::new(self.index, self.rtxn);
+            filtered_universe(&ctx, &self.filter)
+        } else {
+            Ok(self.execute()?.candidates)
+        }
+    }
+
     pub fn execute(&self) -> Result<SearchResult> {
         let embedder_name;
         let embedder_name = match &self.embedder_name {
@@ -297,11 +306,16 @@ pub struct SearchForFacetValues<'a> {
     query: Option<String>,
     facet: String,
     search_query: Search<'a>,
+    is_hybrid: bool,
 }

 impl<'a> SearchForFacetValues<'a> {
-    pub fn new(facet: String, search_query: Search<'a>) -> SearchForFacetValues<'a> {
-        SearchForFacetValues { query: None, facet, search_query }
+    pub fn new(
+        facet: String,
+        search_query: Search<'a>,
+        is_hybrid: bool,
+    ) -> SearchForFacetValues<'a> {
+        SearchForFacetValues { query: None, facet, search_query, is_hybrid }
     }

     pub fn query(&mut self, query: impl Into<String>) -> &mut Self {
@@ -351,7 +365,9 @@ impl<'a> SearchForFacetValues<'a> {
             None => return Ok(vec![]),
         };
-        let search_candidates = self.search_query.execute()?.candidates;
+        let search_candidates = self
+            .search_query
+            .execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
         match self.query.as_ref() {
             Some(query) => {

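execute_for_candidates exists so facet search can obtain the candidate document set without running the full ranking pipeline when vector or hybrid search is involved. A simplified sketch of that shortcut, with toy types standing in for milli's SearchContext and RoaringBitmap:

use std::collections::HashSet;

struct SearchResultLite {
    candidates: HashSet<u32>,
}

struct MiniSearch {
    universe: HashSet<u32>,
}

impl MiniSearch {
    fn filtered_universe(&self) -> HashSet<u32> {
        // in milli this would apply the search filter to the document universe
        self.universe.clone()
    }

    fn execute(&self) -> SearchResultLite {
        // stand-in for the full (expensive) ranking pipeline
        SearchResultLite { candidates: self.universe.clone() }
    }

    // Facet search only needs the candidate set; when vector/hybrid search is
    // in play, skip ranking and return the filtered universe directly.
    fn execute_for_candidates(&self, has_vector_search: bool) -> HashSet<u32> {
        if has_vector_search {
            self.filtered_universe()
        } else {
            self.execute().candidates
        }
    }
}

fn main() {
    let search = MiniSearch { universe: [1u32, 2, 3].into_iter().collect() };
    assert_eq!(search.execute_for_candidates(true), search.execute_for_candidates(false));
}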
View File

@@ -509,7 +509,7 @@ where
         // We write the primary key field id into the main database
         self.index.put_primary_key(self.wtxn, &primary_key)?;
         let number_of_documents = self.index.number_of_documents(self.wtxn)?;
-        let mut rng = rand::rngs::StdRng::from_entropy();
+        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
         for (embedder_name, dimension) in dimension {
             let wtxn = &mut *self.wtxn;

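Replacing StdRng::from_entropy() with StdRng::seed_from_u64(42) is what the "fixed seed for arroy" bullet refers to: the RNG used while building the vector index now produces the same sequence on every run. A small sketch of that property using the rand crate, as in the diff:

use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};

fn main() {
    // Two RNGs created from the same constant seed produce identical sequences,
    // which is what makes repeated vector-index builds reproducible.
    let mut a = StdRng::seed_from_u64(42);
    let mut b = StdRng::seed_from_u64(42);
    let xs: Vec<u32> = (0..5).map(|_| a.gen()).collect();
    let ys: Vec<u32> = (0..5).map(|_| b.gen()).collect();
    assert_eq!(xs, ys);
    // The previous `StdRng::from_entropy()` would differ from run to run.
}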
View File

@@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType};
 use tokenizers::{PaddingParams, Tokenizer};

 pub use super::error::{EmbedError, Error, NewEmbedderError};
-use super::{Embedding, Embeddings};
+use super::{DistributionShift, Embedding, Embeddings};

 #[derive(
     Debug,
@@ -184,4 +184,12 @@ impl Embedder {
     pub fn dimensions(&self) -> usize {
         self.dimensions
     }
+
+    pub fn distribution(&self) -> Option<DistributionShift> {
+        if self.options.model == "BAAI/bge-base-en-v1.5" {
+            Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 })
+        } else {
+            None
+        }
+    }
 }

View File

@@ -202,6 +202,14 @@ impl Embedder {
             Embedder::UserProvided(embedder) => embedder.dimensions(),
         }
     }
+
+    pub fn distribution(&self) -> Option<DistributionShift> {
+        match self {
+            Embedder::HuggingFace(embedder) => embedder.distribution(),
+            Embedder::OpenAi(embedder) => embedder.distribution(),
+            Embedder::UserProvided(_embedder) => None,
+        }
+    }
 }

 #[derive(Debug, Clone, Copy)]

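The DistributionShift { current_mean, current_sigma } values recorded per embedding model suggest that similarity scores bunched near 1.0 get re-spread before being reported. The actual formula is not part of this commit; the sketch below is one plausible affine rescaling and is purely an assumption:

// Hypothetical illustration only: the real shift formula is not shown in this
// commit. Here we assume scores are re-centred from the embedder's observed
// (mean, sigma) towards a wider, more uniform spread.
#[derive(Debug, Clone, Copy)]
struct DistributionShift {
    current_mean: f32,
    current_sigma: f32,
}

fn shift_score(score: f32, shift: Option<DistributionShift>) -> f32 {
    match shift {
        Some(DistributionShift { current_mean, current_sigma }) => {
            // Assumed target distribution: mean 0.5, sigma 0.3 (illustrative numbers).
            let z = (score - current_mean) / current_sigma;
            (z * 0.3 + 0.5).clamp(0.0, 1.0)
        }
        None => score,
    }
}

fn main() {
    let shift = Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 });
    // A score sitting exactly at the model's typical mean maps to the middle of
    // the target range instead of looking like a near-perfect match.
    assert!((shift_score(0.85, shift) - 0.5).abs() < 1e-6);
}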
View File

@@ -4,7 +4,7 @@ use reqwest::StatusCode;
 use serde::{Deserialize, Serialize};

 use super::error::{EmbedError, NewEmbedderError};
-use super::{Embedding, Embeddings};
+use super::{DistributionShift, Embedding, Embeddings};

 #[derive(Debug)]
 pub struct Embedder {
@@ -65,6 +65,14 @@ impl EmbeddingModel {
             _ => None,
         }
     }
+
+    fn distribution(&self) -> Option<DistributionShift> {
+        match self {
+            EmbeddingModel::TextEmbeddingAda002 => {
+                Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
+            }
+        }
+    }
 }

 pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
@@ -326,6 +334,10 @@ impl Embedder {
     pub fn dimensions(&self) -> usize {
         self.options.embedding_model.dimensions()
     }
+
+    pub fn distribution(&self) -> Option<DistributionShift> {
+        self.options.embedding_model.distribution()
+    }
 }

 // retrying in case of failure