Lazily embed, don't fail hybrid search on embedding failure

This commit is contained in:
Louis Dureuil 2024-03-28 11:50:53 +01:00
parent fabc9cf14a
commit 6ebb6b55a6
No known key found for this signature in database
11 changed files with 237 additions and 203 deletions

View file

@ -12,6 +12,7 @@ use tracing::debug;
use crate::analytics::{Analytics, FacetSearchAggregator};
use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData;
use crate::routes::indexes::search::search_kind;
use crate::search::{
add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery,
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
@ -73,9 +74,10 @@ pub async fn search(
let index = index_scheduler.index(&index_uid)?;
let features = index_scheduler.features();
let search_kind = search_kind(&search_query, &index_scheduler, &index)?;
let _permit = search_queue.try_get_search_permit().await?;
let search_result = tokio::task::spawn_blocking(move || {
perform_facet_search(&index, search_query, facet_query, facet_name, features)
perform_facet_search(&index, search_query, facet_query, facet_name, features, search_kind)
})
.await?;

View file

@ -8,19 +8,19 @@ use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::error::ResponseError;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli;
use meilisearch_types::milli::vector::DistributionShift;
use meilisearch_types::serde_cs::vec::CS;
use serde_json::Value;
use tracing::{debug, warn};
use tracing::debug;
use crate::analytics::{Analytics, SearchAggregator};
use crate::error::MeilisearchHttpError;
use crate::extractors::authentication::policies::*;
use crate::extractors::authentication::GuardedData;
use crate::extractors::sequential_extractor::SeqHandler;
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
use crate::search::{
add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio,
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchKind, SearchQuery,
SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO,
};
use crate::search_queue::SearchQueue;
@ -204,11 +204,11 @@ pub async fn search_with_url_query(
let index = index_scheduler.index(&index_uid)?;
let features = index_scheduler.features();
let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?;
let search_kind = search_kind(&query, index_scheduler.get_ref(), &index)?;
let _permit = search_queue.try_get_search_permit().await?;
let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
tokio::task::spawn_blocking(move || perform_search(&index, query, features, search_kind))
.await?;
if let Ok(ref search_result) = search_result {
aggregate.succeed(search_result);
@ -245,11 +245,11 @@ pub async fn search_with_post(
let features = index_scheduler.features();
let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?;
let search_kind = search_kind(&query, index_scheduler.get_ref(), &index)?;
let _permit = search_queue.try_get_search_permit().await?;
let search_result =
tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
tokio::task::spawn_blocking(move || perform_search(&index, query, features, search_kind))
.await?;
if let Ok(ref search_result) = search_result {
aggregate.succeed(search_result);
@ -265,76 +265,49 @@ pub async fn search_with_post(
Ok(HttpResponse::Ok().json(search_result))
}
pub fn embed(
query: &mut SearchQuery,
pub fn search_kind(
query: &SearchQuery,
index_scheduler: &IndexScheduler,
index: &milli::Index,
) -> Result<Option<DistributionShift>, ResponseError> {
match (&query.hybrid, &query.vector, &query.q) {
(Some(HybridQuery { semantic_ratio: _, embedder }), None, Some(q))
if !q.trim().is_empty() =>
{
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedders = index_scheduler.embedders(embedder_configs)?;
let embedder = if let Some(embedder_name) = embedder {
embedders.get(embedder_name)
} else {
embedders.get_default()
};
let embedder = embedder
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
.map_err(milli::Error::from)?
.0;
let distribution = embedder.distribution();
let embeddings = embedder
.embed(vec![q.to_owned()])
.map_err(milli::vector::Error::from)
.map_err(milli::Error::from)?
.pop()
.expect("No vector returned from embedding");
if embeddings.iter().nth(1).is_some() {
warn!("Ignoring embeddings past the first one in long search query");
query.vector = Some(embeddings.iter().next().unwrap().to_vec());
} else {
query.vector = Some(embeddings.into_inner());
}
Ok(distribution)
) -> Result<SearchKind, ResponseError> {
// regardless of anything, always do a keyword search when we don't have a vector and the query is whitespace or missing
if query.vector.is_none() {
match &query.q {
Some(q) if q.trim().is_empty() => return Ok(SearchKind::KeywordOnly),
None => return Ok(SearchKind::KeywordOnly),
_ => {}
}
(Some(hybrid), vector, _) => {
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedders = index_scheduler.embedders(embedder_configs)?;
}
let embedder = if let Some(embedder_name) = &hybrid.embedder {
embedders.get(embedder_name)
} else {
embedders.get_default()
};
let embedder = embedder
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
.map_err(milli::Error::from)?
.0;
if let Some(vector) = vector {
if vector.len() != embedder.dimensions() {
return Err(meilisearch_types::milli::Error::UserError(
meilisearch_types::milli::UserError::InvalidVectorDimensions {
expected: embedder.dimensions(),
found: vector.len(),
},
)
.into());
}
}
Ok(embedder.distribution())
match &query.hybrid {
Some(HybridQuery { semantic_ratio, embedder }) if **semantic_ratio == 1.0 => {
Ok(SearchKind::semantic(
index_scheduler,
index,
embedder.as_deref(),
query.vector.as_ref().map(Vec::len),
)?)
}
_ => Ok(None),
Some(HybridQuery { semantic_ratio, embedder: _ }) if **semantic_ratio == 0.0 => {
Ok(SearchKind::KeywordOnly)
}
Some(HybridQuery { semantic_ratio, embedder }) => Ok(SearchKind::hybrid(
index_scheduler,
index,
embedder.as_deref(),
**semantic_ratio,
query.vector.as_ref().map(Vec::len),
)?),
None => match (query.q.as_deref(), query.vector.as_deref()) {
(_query, None) => Ok(SearchKind::KeywordOnly),
(None, Some(_vector)) => Ok(SearchKind::semantic(
index_scheduler,
index,
None,
query.vector.as_ref().map(Vec::len),
)?),
(Some(_), Some(_)) => Err(MeilisearchHttpError::MissingSearchHybrid.into()),
},
}
}

View file

@ -13,7 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator};
use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::{AuthenticationError, GuardedData};
use crate::extractors::sequential_extractor::SeqHandler;
use crate::routes::indexes::search::embed;
use crate::routes::indexes::search::search_kind;
use crate::search::{
add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex,
};
@ -81,11 +81,11 @@ pub async fn multi_search_with_post(
})
.with_index(query_index)?;
let distribution =
embed(&mut query, index_scheduler.get_ref(), &index).with_index(query_index)?;
let search_kind =
search_kind(&query, index_scheduler.get_ref(), &index).with_index(query_index)?;
let search_result = tokio::task::spawn_blocking(move || {
perform_search(&index, query, features, distribution)
perform_search(&index, query, features, search_kind)
})
.await
.with_index(query_index)?;

View file

@ -1,6 +1,7 @@
use std::cmp::min;
use std::collections::{BTreeMap, BTreeSet, HashSet};
use std::str::FromStr;
use std::sync::Arc;
use std::time::{Duration, Instant};
use deserr::Deserr;
@ -10,10 +11,11 @@ use indexmap::IndexMap;
use meilisearch_auth::IndexSearchRules;
use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::error::ResponseError;
use meilisearch_types::heed::RoTxn;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
use meilisearch_types::milli::vector::DistributionShift;
use meilisearch_types::milli::vector::Embedder;
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, TimeBudget};
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
use meilisearch_types::{milli, Document};
@ -90,13 +92,75 @@ pub struct SearchQuery {
/// Payload of the `hybrid` search parameter.
///
/// Deserialized with camelCase field names; unknown fields are rejected.
#[derive(Debug, Clone, Default, PartialEq, Deserr)]
#[deserr(error = DeserrJsonError<InvalidHybridQuery>, rename_all = camelCase, deny_unknown_fields)]
pub struct HybridQuery {
/// Ratio used to pick the kind of search: `1.0` selects a semantic-only search,
/// `0.0` a keyword-only search, and any other value a hybrid search.
///
/// TODO validate that semantic ratio is between 0.0 and 1.0
#[deserr(default, error = DeserrJsonError<InvalidSearchSemanticRatio>, default)]
pub semantic_ratio: SemanticRatio,
/// Name of the embedder to use; when `None`, the index's default embedder name is used.
#[deserr(default, error = DeserrJsonError<InvalidEmbedder>, default)]
pub embedder: Option<String>,
}
/// The way a search request is going to be executed, resolved once from the query
/// before the search is prepared.
pub enum SearchKind {
/// Only the textual query (`q`), if any, is handed to the search.
KeywordOnly,
/// Only a vector is searched: either the one supplied in the query or, if absent,
/// the embedding of `q` computed with `embedder`.
SemanticOnly { embedder_name: String, embedder: Arc<Embedder> },
/// Both the textual query and the semantic side are set up; the search is executed
/// as a hybrid search using `semantic_ratio`.
Hybrid { embedder_name: String, embedder: Arc<Embedder>, semantic_ratio: f32 },
}
impl SearchKind {
pub(crate) fn semantic(
index_scheduler: &index_scheduler::IndexScheduler,
index: &Index,
embedder_name: Option<&str>,
vector_len: Option<usize>,
) -> Result<Self, ResponseError> {
let (embedder_name, embedder) =
Self::embedder(index_scheduler, index, embedder_name, vector_len)?;
Ok(Self::SemanticOnly { embedder_name, embedder })
}
pub(crate) fn hybrid(
index_scheduler: &index_scheduler::IndexScheduler,
index: &Index,
embedder_name: Option<&str>,
semantic_ratio: f32,
vector_len: Option<usize>,
) -> Result<Self, ResponseError> {
let (embedder_name, embedder) =
Self::embedder(index_scheduler, index, embedder_name, vector_len)?;
Ok(Self::Hybrid { embedder_name, embedder, semantic_ratio })
}
fn embedder(
index_scheduler: &index_scheduler::IndexScheduler,
index: &Index,
embedder_name: Option<&str>,
vector_len: Option<usize>,
) -> Result<(String, Arc<Embedder>), ResponseError> {
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedders = index_scheduler.embedders(embedder_configs)?;
let embedder_name = embedder_name.unwrap_or_else(|| embedders.get_default_embedder_name());
let embedder = embedders.get(embedder_name);
let embedder = embedder
.ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned()))
.map_err(milli::Error::from)?
.0;
if let Some(vector_len) = vector_len {
if vector_len != embedder.dimensions() {
return Err(meilisearch_types::milli::Error::UserError(
meilisearch_types::milli::UserError::InvalidVectorDimensions {
expected: embedder.dimensions(),
found: vector_len,
},
)
.into());
}
}
Ok((embedder_name.to_owned(), embedder))
}
}
/// Newtype around the `f32` semantic ratio carried by a `HybridQuery`.
///
/// Deserialization goes through `TryFrom<f32>`; a failed conversion is reported as
/// `InvalidSearchSemanticRatio`.
// NOTE(review): the accepted range is enforced by the `TryFrom` impl, which is not
// visible here — presumably [0.0, 1.0]; confirm against that impl.
#[derive(Debug, Clone, Copy, PartialEq, Deserr)]
#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)]
pub struct SemanticRatio(f32);
@ -385,7 +449,7 @@ fn prepare_search<'t>(
rtxn: &'t RoTxn,
query: &'t SearchQuery,
features: RoFeatures,
distribution: Option<DistributionShift>,
search_kind: &SearchKind,
time_budget: TimeBudget,
) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
let mut search = index.search(rtxn);
@ -399,32 +463,30 @@ fn prepare_search<'t>(
features.check_vector("Passing `hybrid` as a query parameter")?;
}
if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() {
return Err(MeilisearchHttpError::MissingSearchHybrid);
}
search.distribution_shift(distribution);
if let Some(ref vector) = query.vector {
match &query.hybrid {
// If semantic ratio is 0.0, only the query search will impact the search results,
// skip the vector
Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (),
_otherwise => {
search.vector(vector.clone());
}
}
}
if let Some(ref q) = query.q {
match &query.hybrid {
// If semantic ratio is 1.0, only the vector search will impact the search results,
// skip the query
Some(hybrid) if *hybrid.semantic_ratio == 1.0 => (),
_otherwise => {
match search_kind {
SearchKind::KeywordOnly => {
if let Some(q) = &query.q {
search.query(q);
}
}
SearchKind::SemanticOnly { embedder_name, embedder } => {
let vector = match query.vector.clone() {
Some(vector) => vector,
None => embedder
.embed_one(query.q.clone().unwrap())
.map_err(milli::vector::Error::from)
.map_err(milli::Error::from)?,
};
search.semantic(embedder_name.clone(), embedder.clone(), Some(vector));
}
SearchKind::Hybrid { embedder_name, embedder, semantic_ratio: _ } => {
if let Some(q) = &query.q {
search.query(q);
}
// will be embedded in hybrid search if necessary
search.semantic(embedder_name.clone(), embedder.clone(), query.vector.clone());
}
}
if let Some(ref searchable) = query.attributes_to_search_on {
@ -447,10 +509,6 @@ fn prepare_search<'t>(
ScoringStrategy::Skip
});
if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid {
search.embedder_name(embedder);
}
// compute the offset on the limit depending on the pagination mode.
let (offset, limit) = if is_finite_pagination {
let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
@ -494,7 +552,7 @@ pub fn perform_search(
index: &Index,
query: SearchQuery,
features: RoFeatures,
distribution: Option<DistributionShift>,
search_kind: SearchKind,
) -> Result<SearchResult, MeilisearchHttpError> {
let before_search = Instant::now();
let rtxn = index.read_txn()?;
@ -504,7 +562,7 @@ pub fn perform_search(
};
let (search, is_finite_pagination, max_total_hits, offset) =
prepare_search(index, &rtxn, &query, features, distribution, time_budget)?;
prepare_search(index, &rtxn, &query, features, &search_kind, time_budget)?;
let milli::SearchResult {
documents_ids,
@ -514,12 +572,9 @@ pub fn perform_search(
degraded,
used_negative_operator,
..
} = match &query.hybrid {
Some(hybrid) => match *hybrid.semantic_ratio {
ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?,
ratio => search.execute_hybrid(ratio)?,
},
None => search.execute()?,
} = match &search_kind {
SearchKind::KeywordOnly | SearchKind::SemanticOnly { .. } => search.execute()?,
SearchKind::Hybrid { semantic_ratio, .. } => search.execute_hybrid(*semantic_ratio)?,
};
let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
@ -726,6 +781,7 @@ pub fn perform_facet_search(
facet_query: Option<String>,
facet_name: String,
features: RoFeatures,
search_kind: SearchKind,
) -> Result<FacetSearchResult, MeilisearchHttpError> {
let before_search = Instant::now();
let rtxn = index.read_txn()?;
@ -735,9 +791,12 @@ pub fn perform_facet_search(
};
let (search, _, _, _) =
prepare_search(index, &rtxn, &search_query, features, None, time_budget)?;
let mut facet_search =
SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some());
prepare_search(index, &rtxn, &search_query, features, &search_kind, time_budget)?;
let mut facet_search = SearchForFacetValues::new(
facet_name,
search,
matches!(search_kind, SearchKind::Hybrid { .. }),
);
if let Some(facet_query) = &facet_query {
facet_search.query(facet_query);
}