From 93dcbf598d67a37879b0e2ae7ecb7fc63145c30f Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Thu, 14 Dec 2023 10:21:10 +0100 Subject: [PATCH] Deserialize semantic ratio --- meilisearch-types/src/deserr/mod.rs | 1 + meilisearch-types/src/error.rs | 11 +++++++++- meilisearch/src/routes/indexes/search.rs | 8 +++---- meilisearch/src/search.rs | 28 ++++++++++++++++++------ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/meilisearch-types/src/deserr/mod.rs b/meilisearch-types/src/deserr/mod.rs index df304cc2f..537b24574 100644 --- a/meilisearch-types/src/deserr/mod.rs +++ b/meilisearch-types/src/deserr/mod.rs @@ -188,3 +188,4 @@ merge_with_error_impl_take_error_message!(ParseOffsetDateTimeError); merge_with_error_impl_take_error_message!(ParseTaskKindError); merge_with_error_impl_take_error_message!(ParseTaskStatusError); merge_with_error_impl_take_error_message!(IndexUidFormatError); +merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio); diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 9df41b68f..1dc33b140 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -235,7 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; -InvalidSemanticRatio , InvalidRequest , BAD_REQUEST ; +InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; @@ -459,6 +459,15 @@ impl fmt::Display for DeserrParseIntError { } } +impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "the value of `semanticRatio` is invalid, expected a value between `0.0` and `1.0`." + ) + } +} + #[macro_export] macro_rules! internal_error { ($target:ty : $($other:path), *) => { diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 7a9a14687..ad7f0dc89 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -17,7 +17,7 @@ use crate::extractors::authentication::policies::*; use crate::extractors::authentication::GuardedData; use crate::extractors::sequential_extractor::SeqHandler; use crate::search::{ - add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, + add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, }; @@ -75,10 +75,10 @@ pub struct SearchQueryGet { matching_strategy: MatchingStrategy, #[deserr(default, error = DeserrQueryParamError)] pub attributes_to_search_on: Option>, - #[deserr(default, error = DeserrQueryParamError)] + #[deserr(default, error = DeserrQueryParamError)] pub hybrid_embedder: Option, - #[deserr(default, error = DeserrQueryParamError)] - pub hybrid_semantic_ratio: Option, + #[deserr(default, error = DeserrQueryParamError)] + pub hybrid_semantic_ratio: Option, } impl From for SearchQuery { diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 7bf8ea160..674b6e25e 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -36,7 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10; pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string(); pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "".to_string(); pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "".to_string(); -pub const DEFAULT_SEMANTIC_RATIO: fn() -> f32 = || 0.5; +pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5); #[derive(Debug, Clone, Default, PartialEq, Deserr)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] @@ -91,12 +91,27 @@ pub struct SearchQuery { #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] pub struct HybridQuery { /// TODO validate that sementic ratio is between 0.0 and 1,0 - #[deserr(default, error = DeserrJsonError, default = DEFAULT_SEMANTIC_RATIO())] - pub semantic_ratio: f32, + #[deserr(default, error = DeserrJsonError)] + pub semantic_ratio: SemanticRatio, #[deserr(default, error = DeserrJsonError, default)] pub embedder: Option, } +#[derive(Debug, Clone, Copy, Default, PartialEq, Deserr)] +#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)] +pub struct SemanticRatio(f32); +impl std::convert::TryFrom for SemanticRatio { + type Error = InvalidSearchSemanticRatio; + + fn try_from(f: f32) -> Result { + if f > 1.0 || f < 0.0 { + Err(InvalidSearchSemanticRatio) + } else { + Ok(SemanticRatio(f)) + } + } +} + impl SearchQuery { pub fn is_finite_pagination(&self) -> bool { self.page.or(self.hits_per_page).is_some() @@ -457,10 +472,9 @@ pub fn perform_search( /// + < 1.0 or remove q /// + > 0.0 or remove vector let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = - if query.q.is_some() && query.vector.is_some() { - search.execute_hybrid()? - } else { - search.execute()? + match query.hybrid { + Some(_) => search.execute_hybrid()?, + None => search.execute()?, }; let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();