hybrid search uses semantic ratio, error handling

This commit is contained in:
Louis Dureuil 2023-12-14 12:42:37 +01:00
parent 1b7c164a55
commit 217105b7da
No known key found for this signature in database
10 changed files with 89 additions and 316 deletions

View file

@ -692,7 +692,7 @@ impl SearchAggregator {
ret.max_terms_number = q.split_whitespace().count();
}
if let Some(meilisearch_types::milli::VectorQuery::Vector(ref vector)) = vector {
if let Some(ref vector) = vector {
ret.max_vector_size = vector.len();
}

View file

@ -51,6 +51,8 @@ pub enum MeilisearchHttpError {
DocumentFormat(#[from] DocumentFormatError),
#[error(transparent)]
Join(#[from] JoinError),
#[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")]
MissingSearchHybrid,
}
impl ErrorCode for MeilisearchHttpError {
@ -74,6 +76,7 @@ impl ErrorCode for MeilisearchHttpError {
MeilisearchHttpError::FileStore(_) => Code::Internal,
MeilisearchHttpError::DocumentFormat(e) => e.error_code(),
MeilisearchHttpError::Join(_) => Code::Internal,
MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid,
}
}
}

View file

@ -7,7 +7,6 @@ use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::error::ResponseError;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::VectorQuery;
use serde_json::Value;
use crate::analytics::{Analytics, FacetSearchAggregator};
@ -121,7 +120,7 @@ impl From<FacetSearchQuery> for SearchQuery {
highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(),
crop_marker: DEFAULT_CROP_MARKER(),
matching_strategy,
vector: vector.map(VectorQuery::Vector),
vector,
attributes_to_search_on,
hybrid,
}

View file

@ -8,7 +8,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::error::ResponseError;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::{self, VectorQuery};
use meilisearch_types::milli;
use meilisearch_types::serde_cs::vec::CS;
use serde_json::Value;
@ -128,7 +128,7 @@ impl From<SearchQueryGet> for SearchQuery {
Self {
q: other.q,
vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector),
vector: other.vector.map(CS::into_inner),
offset: other.offset.0,
limit: other.limit.0,
page: other.page.as_deref().copied(),
@ -258,49 +258,37 @@ pub async fn embed(
index_scheduler: &IndexScheduler,
index: &milli::Index,
) -> Result<(), ResponseError> {
match query.vector.take() {
Some(VectorQuery::String(prompt)) => {
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedders = index_scheduler.embedders(embedder_configs)?;
if let (None, Some(q), Some(HybridQuery { semantic_ratio: _, embedder })) =
(&query.vector, &query.q, &query.hybrid)
{
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
let embedders = index_scheduler.embedders(embedder_configs)?;
let embedder_name =
if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) =
&query.hybrid
{
Some(embedder)
} else {
None
};
let embedder = if let Some(embedder_name) = embedder {
embedders.get(embedder_name)
} else {
embedders.get_default()
};
let embedder = if let Some(embedder_name) = embedder_name {
embedders.get(embedder_name)
} else {
embedders.get_default()
};
let embedder = embedder
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
.map_err(milli::Error::from)?
.0;
let embeddings = embedder
.embed(vec![q.to_owned()])
.await
.map_err(milli::vector::Error::from)
.map_err(milli::Error::from)?
.pop()
.expect("No vector returned from embedding");
let embedder = embedder
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
.map_err(milli::Error::from)?
.0;
let embeddings = embedder
.embed(vec![prompt])
.await
.map_err(milli::vector::Error::from)
.map_err(milli::Error::from)?
.pop()
.expect("No vector returned from embedding");
if embeddings.iter().nth(1).is_some() {
warn!("Ignoring embeddings past the first one in long search query");
query.vector =
Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec()));
} else {
query.vector = Some(VectorQuery::Vector(embeddings.into_inner()));
}
if embeddings.iter().nth(1).is_some() {
warn!("Ignoring embeddings past the first one in long search query");
query.vector = Some(embeddings.iter().next().unwrap().to_vec());
} else {
query.vector = Some(embeddings.into_inner());
}
Some(vector) => query.vector = Some(vector),
None => {}
};
}
Ok(())
}

View file

@ -7,14 +7,13 @@ use deserr::Deserr;
use either::Either;
use index_scheduler::RoFeatures;
use indexmap::IndexMap;
use log::warn;
use meilisearch_auth::IndexSearchRules;
use meilisearch_types::deserr::DeserrJsonError;
use meilisearch_types::error::deserr_codes::*;
use meilisearch_types::heed::RoTxn;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery};
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues};
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
use meilisearch_types::{milli, Document};
use milli::tokenizer::TokenizerBuilder;
@ -44,7 +43,7 @@ pub struct SearchQuery {
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
pub vector: Option<milli::VectorQuery>,
pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
@ -105,6 +104,8 @@ impl std::convert::TryFrom<f32> for SemanticRatio {
type Error = InvalidSearchSemanticRatio;
fn try_from(f: f32) -> Result<Self, Self::Error> {
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
#[allow(clippy::manual_range_contains)]
if f > 1.0 || f < 0.0 {
Err(InvalidSearchSemanticRatio)
} else {
@ -139,7 +140,7 @@ pub struct SearchQueryWithIndex {
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
pub q: Option<String>,
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
pub vector: Option<VectorQuery>,
pub vector: Option<Vec<f32>>,
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
pub hybrid: Option<HybridQuery>,
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
@ -376,8 +377,16 @@ fn prepare_search<'t>(
) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
let mut search = index.search(rtxn);
if query.vector.is_some() && query.q.is_some() {
warn!("Attempting hybrid search");
if query.vector.is_some() {
features.check_vector("Passing `vector` as a query parameter")?;
}
if query.hybrid.is_some() {
features.check_vector("Passing `hybrid` as a query parameter")?;
}
if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() {
return Err(MeilisearchHttpError::MissingSearchHybrid);
}
if let Some(ref vector) = query.vector {
@ -385,14 +394,9 @@ fn prepare_search<'t>(
// If semantic ratio is 0.0, only the query search will impact the search results,
// skip the vector
Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (),
_otherwise => match vector {
VectorQuery::Vector(vector) => {
search.vector(vector.clone());
}
VectorQuery::String(_) => {
panic!("Failed while preparing search; caller did not generate embedding for query")
}
},
_otherwise => {
search.vector(vector.clone());
}
}
}
@ -431,10 +435,6 @@ fn prepare_search<'t>(
features.check_score_details()?;
}
if query.vector.is_some() {
features.check_vector("Passing `vector` as a query parameter")?;
}
if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid {
search.embedder_name(embedder);
}
@ -492,7 +492,7 @@ pub fn perform_search(
let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
match &query.hybrid {
Some(hybrid) => match *hybrid.semantic_ratio {
0.0 | 1.0 => search.execute()?,
ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?,
ratio => search.execute_hybrid(ratio)?,
},
None => search.execute()?,
@ -700,10 +700,7 @@ pub fn perform_search(
hits: documents,
hits_info,
query: query.q.unwrap_or_default(),
vector: match query.vector {
Some(VectorQuery::Vector(vector)) => Some(vector),
_ => None,
},
vector: query.vector,
processing_time_ms: before_search.elapsed().as_millis(),
facet_distribution,
facet_stats,