mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-12-22 20:50:04 +01:00
hybrid search uses semantic ratio, error handling
This commit is contained in:
parent
1b7c164a55
commit
217105b7da
@ -235,7 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSearchFacets , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSearchFilter , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ;
|
||||
@ -299,6 +299,7 @@ MissingFacetSearchFacetName , InvalidRequest , BAD_REQUEST ;
|
||||
MissingIndexUid , InvalidRequest , BAD_REQUEST ;
|
||||
MissingMasterKey , Auth , UNAUTHORIZED ;
|
||||
MissingPayload , InvalidRequest , BAD_REQUEST ;
|
||||
MissingSearchHybrid , InvalidRequest , BAD_REQUEST ;
|
||||
MissingSwapIndexes , InvalidRequest , BAD_REQUEST ;
|
||||
MissingTaskFilters , InvalidRequest , BAD_REQUEST ;
|
||||
NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENTITY;
|
||||
|
@ -692,7 +692,7 @@ impl SearchAggregator {
|
||||
ret.max_terms_number = q.split_whitespace().count();
|
||||
}
|
||||
|
||||
if let Some(meilisearch_types::milli::VectorQuery::Vector(ref vector)) = vector {
|
||||
if let Some(ref vector) = vector {
|
||||
ret.max_vector_size = vector.len();
|
||||
}
|
||||
|
||||
|
@ -51,6 +51,8 @@ pub enum MeilisearchHttpError {
|
||||
DocumentFormat(#[from] DocumentFormatError),
|
||||
#[error(transparent)]
|
||||
Join(#[from] JoinError),
|
||||
#[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")]
|
||||
MissingSearchHybrid,
|
||||
}
|
||||
|
||||
impl ErrorCode for MeilisearchHttpError {
|
||||
@ -74,6 +76,7 @@ impl ErrorCode for MeilisearchHttpError {
|
||||
MeilisearchHttpError::FileStore(_) => Code::Internal,
|
||||
MeilisearchHttpError::DocumentFormat(e) => e.error_code(),
|
||||
MeilisearchHttpError::Join(_) => Code::Internal,
|
||||
MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ use meilisearch_types::deserr::DeserrJsonError;
|
||||
use meilisearch_types::error::deserr_codes::*;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::VectorQuery;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::analytics::{Analytics, FacetSearchAggregator};
|
||||
@ -121,7 +120,7 @@ impl From<FacetSearchQuery> for SearchQuery {
|
||||
highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(),
|
||||
crop_marker: DEFAULT_CROP_MARKER(),
|
||||
matching_strategy,
|
||||
vector: vector.map(VectorQuery::Vector),
|
||||
vector,
|
||||
attributes_to_search_on,
|
||||
hybrid,
|
||||
}
|
||||
|
@ -8,7 +8,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
|
||||
use meilisearch_types::error::deserr_codes::*;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::{self, VectorQuery};
|
||||
use meilisearch_types::milli;
|
||||
use meilisearch_types::serde_cs::vec::CS;
|
||||
use serde_json::Value;
|
||||
|
||||
@ -128,7 +128,7 @@ impl From<SearchQueryGet> for SearchQuery {
|
||||
|
||||
Self {
|
||||
q: other.q,
|
||||
vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector),
|
||||
vector: other.vector.map(CS::into_inner),
|
||||
offset: other.offset.0,
|
||||
limit: other.limit.0,
|
||||
page: other.page.as_deref().copied(),
|
||||
@ -258,49 +258,37 @@ pub async fn embed(
|
||||
index_scheduler: &IndexScheduler,
|
||||
index: &milli::Index,
|
||||
) -> Result<(), ResponseError> {
|
||||
match query.vector.take() {
|
||||
Some(VectorQuery::String(prompt)) => {
|
||||
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
|
||||
let embedders = index_scheduler.embedders(embedder_configs)?;
|
||||
if let (None, Some(q), Some(HybridQuery { semantic_ratio: _, embedder })) =
|
||||
(&query.vector, &query.q, &query.hybrid)
|
||||
{
|
||||
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
|
||||
let embedders = index_scheduler.embedders(embedder_configs)?;
|
||||
|
||||
let embedder_name =
|
||||
if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) =
|
||||
&query.hybrid
|
||||
{
|
||||
Some(embedder)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let embedder = if let Some(embedder_name) = embedder {
|
||||
embedders.get(embedder_name)
|
||||
} else {
|
||||
embedders.get_default()
|
||||
};
|
||||
|
||||
let embedder = if let Some(embedder_name) = embedder_name {
|
||||
embedders.get(embedder_name)
|
||||
} else {
|
||||
embedders.get_default()
|
||||
};
|
||||
let embedder = embedder
|
||||
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
|
||||
.map_err(milli::Error::from)?
|
||||
.0;
|
||||
let embeddings = embedder
|
||||
.embed(vec![q.to_owned()])
|
||||
.await
|
||||
.map_err(milli::vector::Error::from)
|
||||
.map_err(milli::Error::from)?
|
||||
.pop()
|
||||
.expect("No vector returned from embedding");
|
||||
|
||||
let embedder = embedder
|
||||
.ok_or(milli::UserError::InvalidEmbedder("default".to_owned()))
|
||||
.map_err(milli::Error::from)?
|
||||
.0;
|
||||
let embeddings = embedder
|
||||
.embed(vec![prompt])
|
||||
.await
|
||||
.map_err(milli::vector::Error::from)
|
||||
.map_err(milli::Error::from)?
|
||||
.pop()
|
||||
.expect("No vector returned from embedding");
|
||||
|
||||
if embeddings.iter().nth(1).is_some() {
|
||||
warn!("Ignoring embeddings past the first one in long search query");
|
||||
query.vector =
|
||||
Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec()));
|
||||
} else {
|
||||
query.vector = Some(VectorQuery::Vector(embeddings.into_inner()));
|
||||
}
|
||||
if embeddings.iter().nth(1).is_some() {
|
||||
warn!("Ignoring embeddings past the first one in long search query");
|
||||
query.vector = Some(embeddings.iter().next().unwrap().to_vec());
|
||||
} else {
|
||||
query.vector = Some(embeddings.into_inner());
|
||||
}
|
||||
Some(vector) => query.vector = Some(vector),
|
||||
None => {}
|
||||
};
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -7,14 +7,13 @@ use deserr::Deserr;
|
||||
use either::Either;
|
||||
use index_scheduler::RoFeatures;
|
||||
use indexmap::IndexMap;
|
||||
use log::warn;
|
||||
use meilisearch_auth::IndexSearchRules;
|
||||
use meilisearch_types::deserr::DeserrJsonError;
|
||||
use meilisearch_types::error::deserr_codes::*;
|
||||
use meilisearch_types::heed::RoTxn;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
|
||||
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery};
|
||||
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues};
|
||||
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
|
||||
use meilisearch_types::{milli, Document};
|
||||
use milli::tokenizer::TokenizerBuilder;
|
||||
@ -44,7 +43,7 @@ pub struct SearchQuery {
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
||||
pub q: Option<String>,
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
|
||||
pub vector: Option<milli::VectorQuery>,
|
||||
pub vector: Option<Vec<f32>>,
|
||||
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
|
||||
pub hybrid: Option<HybridQuery>,
|
||||
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
||||
@ -105,6 +104,8 @@ impl std::convert::TryFrom<f32> for SemanticRatio {
|
||||
type Error = InvalidSearchSemanticRatio;
|
||||
|
||||
fn try_from(f: f32) -> Result<Self, Self::Error> {
|
||||
// the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable
|
||||
#[allow(clippy::manual_range_contains)]
|
||||
if f > 1.0 || f < 0.0 {
|
||||
Err(InvalidSearchSemanticRatio)
|
||||
} else {
|
||||
@ -139,7 +140,7 @@ pub struct SearchQueryWithIndex {
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
||||
pub q: Option<String>,
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
||||
pub vector: Option<VectorQuery>,
|
||||
pub vector: Option<Vec<f32>>,
|
||||
#[deserr(default, error = DeserrJsonError<InvalidHybridQuery>)]
|
||||
pub hybrid: Option<HybridQuery>,
|
||||
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
||||
@ -376,8 +377,16 @@ fn prepare_search<'t>(
|
||||
) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
|
||||
let mut search = index.search(rtxn);
|
||||
|
||||
if query.vector.is_some() && query.q.is_some() {
|
||||
warn!("Attempting hybrid search");
|
||||
if query.vector.is_some() {
|
||||
features.check_vector("Passing `vector` as a query parameter")?;
|
||||
}
|
||||
|
||||
if query.hybrid.is_some() {
|
||||
features.check_vector("Passing `hybrid` as a query parameter")?;
|
||||
}
|
||||
|
||||
if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() {
|
||||
return Err(MeilisearchHttpError::MissingSearchHybrid);
|
||||
}
|
||||
|
||||
if let Some(ref vector) = query.vector {
|
||||
@ -385,14 +394,9 @@ fn prepare_search<'t>(
|
||||
// If semantic ratio is 0.0, only the query search will impact the search results,
|
||||
// skip the vector
|
||||
Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (),
|
||||
_otherwise => match vector {
|
||||
VectorQuery::Vector(vector) => {
|
||||
search.vector(vector.clone());
|
||||
}
|
||||
VectorQuery::String(_) => {
|
||||
panic!("Failed while preparing search; caller did not generate embedding for query")
|
||||
}
|
||||
},
|
||||
_otherwise => {
|
||||
search.vector(vector.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -431,10 +435,6 @@ fn prepare_search<'t>(
|
||||
features.check_score_details()?;
|
||||
}
|
||||
|
||||
if query.vector.is_some() {
|
||||
features.check_vector("Passing `vector` as a query parameter")?;
|
||||
}
|
||||
|
||||
if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid {
|
||||
search.embedder_name(embedder);
|
||||
}
|
||||
@ -492,7 +492,7 @@ pub fn perform_search(
|
||||
let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
|
||||
match &query.hybrid {
|
||||
Some(hybrid) => match *hybrid.semantic_ratio {
|
||||
0.0 | 1.0 => search.execute()?,
|
||||
ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?,
|
||||
ratio => search.execute_hybrid(ratio)?,
|
||||
},
|
||||
None => search.execute()?,
|
||||
@ -700,10 +700,7 @@ pub fn perform_search(
|
||||
hits: documents,
|
||||
hits_info,
|
||||
query: query.q.unwrap_or_default(),
|
||||
vector: match query.vector {
|
||||
Some(VectorQuery::Vector(vector)) => Some(vector),
|
||||
_ => None,
|
||||
},
|
||||
vector: query.vector,
|
||||
processing_time_ms: before_search.elapsed().as_millis(),
|
||||
facet_distribution,
|
||||
facet_stats,
|
||||
|
@ -1,4 +1,4 @@
|
||||
use meili_snap::{json_string, snapshot};
|
||||
use meili_snap::snapshot;
|
||||
use once_cell::sync::Lazy;
|
||||
|
||||
use crate::common::index::Index;
|
||||
|
@ -59,7 +59,7 @@ pub use self::index::Index;
|
||||
pub use self::search::{
|
||||
FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder,
|
||||
MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy,
|
||||
VectorQuery, DEFAULT_VALUES_PER_FACET,
|
||||
DEFAULT_VALUES_PER_FACET,
|
||||
};
|
||||
|
||||
pub type Result<T> = std::result::Result<T, error::Error>;
|
||||
|
@ -1,49 +1,37 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use itertools::Itertools;
|
||||
use roaring::RoaringBitmap;
|
||||
|
||||
use super::new::{execute_vector_search, PartialSearchResult};
|
||||
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
|
||||
use crate::{
|
||||
execute_search, DefaultSearchLogger, MatchingWords, Result, Search, SearchContext, SearchResult,
|
||||
};
|
||||
use crate::{MatchingWords, Result, Search, SearchResult};
|
||||
|
||||
struct CombinedSearchResult {
|
||||
struct ScoreWithRatioResult {
|
||||
matching_words: MatchingWords,
|
||||
candidates: RoaringBitmap,
|
||||
document_scores: Vec<(u32, CombinedScore)>,
|
||||
document_scores: Vec<(u32, ScoreWithRatio)>,
|
||||
}
|
||||
|
||||
type CombinedScore = (Vec<ScoreDetails>, Option<Vec<ScoreDetails>>);
|
||||
type ScoreWithRatio = (Vec<ScoreDetails>, f32);
|
||||
|
||||
fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering {
|
||||
let mut left_main_it = ScoreDetails::score_values(left.0.iter());
|
||||
let mut left_sub_it =
|
||||
ScoreDetails::score_values(left.1.as_ref().map(|x| x.iter()).into_iter().flatten());
|
||||
|
||||
let mut right_main_it = ScoreDetails::score_values(right.0.iter());
|
||||
let mut right_sub_it =
|
||||
ScoreDetails::score_values(right.1.as_ref().map(|x| x.iter()).into_iter().flatten());
|
||||
|
||||
let mut left_main = left_main_it.next();
|
||||
let mut left_sub = left_sub_it.next();
|
||||
let mut right_main = right_main_it.next();
|
||||
let mut right_sub = right_sub_it.next();
|
||||
fn compare_scores(
|
||||
&(ref left_scores, left_ratio): &ScoreWithRatio,
|
||||
&(ref right_scores, right_ratio): &ScoreWithRatio,
|
||||
) -> Ordering {
|
||||
let mut left_it = ScoreDetails::score_values(left_scores.iter());
|
||||
let mut right_it = ScoreDetails::score_values(right_scores.iter());
|
||||
|
||||
loop {
|
||||
let left =
|
||||
take_best_score(&mut left_main, &mut left_sub, &mut left_main_it, &mut left_sub_it);
|
||||
|
||||
let right =
|
||||
take_best_score(&mut right_main, &mut right_sub, &mut right_main_it, &mut right_sub_it);
|
||||
let left = left_it.next();
|
||||
let right = right_it.next();
|
||||
|
||||
match (left, right) {
|
||||
(None, None) => return Ordering::Equal,
|
||||
(None, Some(_)) => return Ordering::Less,
|
||||
(Some(_), None) => return Ordering::Greater,
|
||||
(Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => {
|
||||
let left = left * left_ratio as f64;
|
||||
let right = right * right_ratio as f64;
|
||||
if (left - right).abs() <= f64::EPSILON {
|
||||
continue;
|
||||
}
|
||||
@ -72,94 +60,17 @@ fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering {
|
||||
}
|
||||
}
|
||||
|
||||
fn take_best_score<'a>(
|
||||
main_score: &mut Option<ScoreValue<'a>>,
|
||||
sub_score: &mut Option<ScoreValue<'a>>,
|
||||
main_it: &mut impl Iterator<Item = ScoreValue<'a>>,
|
||||
sub_it: &mut impl Iterator<Item = ScoreValue<'a>>,
|
||||
) -> Option<ScoreValue<'a>> {
|
||||
match (*main_score, *sub_score) {
|
||||
(Some(main), None) => {
|
||||
*main_score = main_it.next();
|
||||
Some(main)
|
||||
}
|
||||
(None, Some(sub)) => {
|
||||
*sub_score = sub_it.next();
|
||||
Some(sub)
|
||||
}
|
||||
(main @ Some(ScoreValue::Score(main_f)), sub @ Some(ScoreValue::Score(sub_v))) => {
|
||||
// take max, both advance
|
||||
*main_score = main_it.next();
|
||||
*sub_score = sub_it.next();
|
||||
if main_f >= sub_v {
|
||||
main
|
||||
} else {
|
||||
sub
|
||||
}
|
||||
}
|
||||
(main @ Some(ScoreValue::Score(_)), _) => {
|
||||
*main_score = main_it.next();
|
||||
main
|
||||
}
|
||||
(_, sub @ Some(ScoreValue::Score(_))) => {
|
||||
*sub_score = sub_it.next();
|
||||
sub
|
||||
}
|
||||
(main @ Some(ScoreValue::GeoSort(main_geo)), sub @ Some(ScoreValue::GeoSort(sub_geo))) => {
|
||||
// take best advance both
|
||||
*main_score = main_it.next();
|
||||
*sub_score = sub_it.next();
|
||||
if main_geo >= sub_geo {
|
||||
main
|
||||
} else {
|
||||
sub
|
||||
}
|
||||
}
|
||||
(main @ Some(ScoreValue::Sort(main_sort)), sub @ Some(ScoreValue::Sort(sub_sort))) => {
|
||||
// take best advance both
|
||||
*main_score = main_it.next();
|
||||
*sub_score = sub_it.next();
|
||||
if main_sort >= sub_sort {
|
||||
main
|
||||
} else {
|
||||
sub
|
||||
}
|
||||
}
|
||||
(
|
||||
Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)),
|
||||
Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)),
|
||||
) => None,
|
||||
|
||||
(None, None) => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl CombinedSearchResult {
|
||||
fn new(main_results: SearchResult, ancillary_results: PartialSearchResult) -> Self {
|
||||
let mut docid_scores = HashMap::new();
|
||||
for (docid, score) in
|
||||
main_results.documents_ids.iter().zip(main_results.document_scores.into_iter())
|
||||
{
|
||||
docid_scores.insert(*docid, (score, None));
|
||||
}
|
||||
|
||||
for (docid, score) in ancillary_results
|
||||
impl ScoreWithRatioResult {
|
||||
fn new(results: SearchResult, ratio: f32) -> Self {
|
||||
let document_scores = results
|
||||
.documents_ids
|
||||
.iter()
|
||||
.zip(ancillary_results.document_scores.into_iter())
|
||||
{
|
||||
docid_scores
|
||||
.entry(*docid)
|
||||
.and_modify(|(_main_score, ancillary_score)| *ancillary_score = Some(score));
|
||||
}
|
||||
|
||||
let mut document_scores: Vec<_> = docid_scores.into_iter().collect();
|
||||
|
||||
document_scores.sort_by(|(_, left), (_, right)| compare_scores(left, right).reverse());
|
||||
.into_iter()
|
||||
.zip(results.document_scores.into_iter().map(|scores| (scores, ratio)))
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
matching_words: main_results.matching_words,
|
||||
candidates: main_results.candidates,
|
||||
matching_words: results.matching_words,
|
||||
candidates: results.candidates,
|
||||
document_scores,
|
||||
}
|
||||
}
|
||||
@ -200,7 +111,7 @@ impl CombinedSearchResult {
|
||||
}
|
||||
|
||||
impl<'a> Search<'a> {
|
||||
pub fn execute_hybrid(&self) -> Result<SearchResult> {
|
||||
pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result<SearchResult> {
|
||||
// TODO: find classier way to achieve that than to reset vector and query params
|
||||
// create separate keyword and semantic searches
|
||||
let mut search = Search {
|
||||
@ -223,8 +134,6 @@ impl<'a> Search<'a> {
|
||||
};
|
||||
|
||||
let vector_query = search.vector.take();
|
||||
let keyword_query = self.query.as_deref();
|
||||
|
||||
let keyword_results = search.execute()?;
|
||||
|
||||
// skip semantic search if we don't have a vector query (placeholder search)
|
||||
@ -233,7 +142,7 @@ impl<'a> Search<'a> {
|
||||
};
|
||||
|
||||
// completely skip semantic search if the results of the keyword search are good enough
|
||||
if self.results_good_enough(&keyword_results) {
|
||||
if self.results_good_enough(&keyword_results, semantic_ratio) {
|
||||
return Ok(keyword_results);
|
||||
}
|
||||
|
||||
@ -243,94 +152,18 @@ impl<'a> Search<'a> {
|
||||
// TODO: would be better to have two distinct functions at this point
|
||||
let vector_results = search.execute()?;
|
||||
|
||||
// Compute keyword scores for vector_results
|
||||
let keyword_results_for_vector =
|
||||
self.keyword_results_for_vector(keyword_query, &vector_results)?;
|
||||
|
||||
// compute vector scores for keyword_results
|
||||
let vector_results_for_keyword =
|
||||
// can unwrap because we returned already if there was no vector query
|
||||
self.vector_results_for_keyword(search.vector.as_ref().unwrap(), &keyword_results)?;
|
||||
|
||||
/// TODO apply sementic ratio
|
||||
let keyword_results =
|
||||
CombinedSearchResult::new(keyword_results, vector_results_for_keyword);
|
||||
let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector);
|
||||
let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio);
|
||||
let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio);
|
||||
|
||||
let merge_results =
|
||||
CombinedSearchResult::merge(vector_results, keyword_results, self.offset, self.limit);
|
||||
ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit);
|
||||
assert!(merge_results.documents_ids.len() <= self.limit);
|
||||
Ok(merge_results)
|
||||
}
|
||||
|
||||
fn vector_results_for_keyword(
|
||||
&self,
|
||||
vector: &[f32],
|
||||
keyword_results: &SearchResult,
|
||||
) -> Result<PartialSearchResult> {
|
||||
let embedder_name;
|
||||
let embedder_name = match &self.embedder_name {
|
||||
Some(embedder_name) => embedder_name,
|
||||
None => {
|
||||
embedder_name = self.index.default_embedding_name(self.rtxn)?;
|
||||
&embedder_name
|
||||
}
|
||||
};
|
||||
|
||||
let mut ctx = SearchContext::new(self.index, self.rtxn);
|
||||
|
||||
if let Some(searchable_attributes) = self.searchable_attributes {
|
||||
ctx.searchable_attributes(searchable_attributes)?;
|
||||
}
|
||||
|
||||
let universe = keyword_results.documents_ids.iter().collect();
|
||||
|
||||
execute_vector_search(
|
||||
&mut ctx,
|
||||
vector,
|
||||
ScoringStrategy::Detailed,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
0,
|
||||
self.limit + self.offset,
|
||||
self.distribution_shift,
|
||||
embedder_name,
|
||||
)
|
||||
}
|
||||
|
||||
fn keyword_results_for_vector(
|
||||
&self,
|
||||
query: Option<&str>,
|
||||
vector_results: &SearchResult,
|
||||
) -> Result<PartialSearchResult> {
|
||||
let mut ctx = SearchContext::new(self.index, self.rtxn);
|
||||
|
||||
if let Some(searchable_attributes) = self.searchable_attributes {
|
||||
ctx.searchable_attributes(searchable_attributes)?;
|
||||
}
|
||||
|
||||
let universe = vector_results.documents_ids.iter().collect();
|
||||
|
||||
execute_search(
|
||||
&mut ctx,
|
||||
query,
|
||||
self.terms_matching_strategy,
|
||||
ScoringStrategy::Detailed,
|
||||
self.exhaustive_number_hits,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
0,
|
||||
self.limit + self.offset,
|
||||
Some(self.words_limit),
|
||||
&mut DefaultSearchLogger,
|
||||
&mut DefaultSearchLogger,
|
||||
)
|
||||
}
|
||||
|
||||
fn results_good_enough(&self, keyword_results: &SearchResult) -> bool {
|
||||
const GOOD_ENOUGH_SCORE: f64 = 0.9;
|
||||
fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool {
|
||||
// A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5
|
||||
const GOOD_ENOUGH_SCORE: f64 = 0.45;
|
||||
|
||||
// 1. we check that we got a sufficient number of results
|
||||
if keyword_results.document_scores.len() < self.limit + self.offset {
|
||||
@ -341,7 +174,7 @@ impl<'a> Search<'a> {
|
||||
// we need to check all results because due to sort like rules, they're not necessarily in relevancy order
|
||||
for score in &keyword_results.document_scores {
|
||||
let score = ScoreDetails::global_score(score.iter());
|
||||
if score < GOOD_ENOUGH_SCORE {
|
||||
if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ use std::ops::ControlFlow;
|
||||
|
||||
use charabia::normalizer::NormalizerOption;
|
||||
use charabia::Normalize;
|
||||
use deserr::{DeserializeError, Deserr, Sequence};
|
||||
use fst::automaton::{Automaton, Str};
|
||||
use fst::{IntoStreamer, Streamer};
|
||||
use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
|
||||
@ -57,53 +56,6 @@ pub struct Search<'a> {
|
||||
embedder_name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum VectorQuery {
|
||||
Vector(Vec<f32>),
|
||||
String(String),
|
||||
}
|
||||
|
||||
impl<E> Deserr<E> for VectorQuery
|
||||
where
|
||||
E: DeserializeError,
|
||||
{
|
||||
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: deserr::ValuePointerRef,
|
||||
) -> std::result::Result<Self, E> {
|
||||
match value {
|
||||
deserr::Value::String(s) => Ok(VectorQuery::String(s)),
|
||||
deserr::Value::Sequence(seq) => {
|
||||
let v: std::result::Result<Vec<f32>, _> = seq
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, v)| match v.into_value() {
|
||||
deserr::Value::Float(f) => Ok(f as f32),
|
||||
deserr::Value::Integer(i) => Ok(i as f32),
|
||||
v => Err(deserr::take_cf_content(E::error::<V>(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: v,
|
||||
accepted: &[deserr::ValueKind::Float, deserr::ValueKind::Integer],
|
||||
},
|
||||
location.push_index(index),
|
||||
))),
|
||||
})
|
||||
.collect();
|
||||
Ok(VectorQuery::Vector(v?))
|
||||
}
|
||||
_ => Err(deserr::take_cf_content(E::error::<V>(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: value,
|
||||
accepted: &[deserr::ValueKind::String, deserr::ValueKind::Sequence],
|
||||
},
|
||||
location,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Search<'a> {
|
||||
pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
|
||||
Search {
|
||||
|
Loading…
x
Reference in New Issue
Block a user