Implement useful conversion strategies and clean up the code

This commit is contained in:
Kerollmops 2025-05-30 10:54:32 +02:00 committed by Clément Renault
parent 2821163b95
commit 50fafbbc8b
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
8 changed files with 224 additions and 203 deletions

View file

@ -16,6 +16,7 @@ use meilisearch_types::error::{Code, ResponseError};
use meilisearch_types::heed::RoTxn;
use meilisearch_types::index_uid::IndexUid;
use meilisearch_types::locales::Locale;
use meilisearch_types::milli::index::{self, SearchParameters};
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
use meilisearch_types::milli::vector::parsed_vectors::ExplicitVectors;
use meilisearch_types::milli::vector::Embedder;
@ -56,7 +57,7 @@ pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5);
#[derive(Clone, Default, PartialEq, Deserr, ToSchema)]
#[derive(Clone, PartialEq, Deserr, ToSchema)]
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
pub struct SearchQuery {
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
@ -119,6 +120,69 @@ pub struct SearchQuery {
pub locales: Option<Vec<Locale>>,
}
impl From<SearchParameters> for SearchQuery {
fn from(parameters: SearchParameters) -> Self {
let SearchParameters {
hybrid,
limit,
sort,
distinct,
matching_strategy,
attributes_to_search_on,
ranking_score_threshold,
} = parameters;
SearchQuery {
hybrid: hybrid.map(|index::HybridQuery { semantic_ratio, embedder }| HybridQuery {
semantic_ratio: SemanticRatio::try_from(semantic_ratio)
.ok()
.unwrap_or_else(DEFAULT_SEMANTIC_RATIO),
embedder,
}),
limit: limit.unwrap_or_else(DEFAULT_SEARCH_LIMIT),
sort,
distinct,
matching_strategy: matching_strategy.map(MatchingStrategy::from).unwrap_or_default(),
attributes_to_search_on,
ranking_score_threshold: ranking_score_threshold.map(RankingScoreThreshold::from),
..Default::default()
}
}
}
impl Default for SearchQuery {
fn default() -> Self {
SearchQuery {
q: None,
vector: None,
hybrid: None,
offset: DEFAULT_SEARCH_OFFSET(),
limit: DEFAULT_SEARCH_LIMIT(),
page: None,
hits_per_page: None,
attributes_to_retrieve: None,
retrieve_vectors: false,
attributes_to_crop: None,
crop_length: DEFAULT_CROP_LENGTH(),
attributes_to_highlight: None,
show_matches_position: false,
show_ranking_score: false,
show_ranking_score_details: false,
filter: None,
sort: None,
distinct: None,
facets: None,
highlight_pre_tag: DEFAULT_HIGHLIGHT_PRE_TAG(),
highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(),
crop_marker: DEFAULT_CROP_MARKER(),
matching_strategy: Default::default(),
attributes_to_search_on: None,
ranking_score_threshold: None,
locales: None,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize)]
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)]
pub struct RankingScoreThreshold(f64);
@ -137,6 +201,14 @@ impl std::convert::TryFrom<f64> for RankingScoreThreshold {
}
}
impl From<index::RankingScoreThreshold> for RankingScoreThreshold {
fn from(threshold: index::RankingScoreThreshold) -> Self {
let threshold = threshold.as_f64();
assert!(threshold >= 0.0 && threshold <= 1.0);
RankingScoreThreshold(threshold)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Deserr)]
#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSimilarRankingScoreThreshold)]
pub struct RankingScoreThresholdSimilar(f64);
@ -718,6 +790,16 @@ impl From<MatchingStrategy> for TermsMatchingStrategy {
}
}
impl From<index::MatchingStrategy> for MatchingStrategy {
fn from(other: index::MatchingStrategy) -> Self {
match other {
index::MatchingStrategy::Last => Self::Last,
index::MatchingStrategy::All => Self::All,
index::MatchingStrategy::Frequency => Self::Frequency,
}
}
}
#[derive(Debug, Default, Clone, PartialEq, Eq, Deserr)]
#[deserr(rename_all = camelCase)]
pub enum FacetValuesSort {
@ -1261,7 +1343,7 @@ struct HitMaker<'a> {
vectors_fid: Option<FieldId>,
retrieve_vectors: RetrieveVectors,
to_retrieve_ids: BTreeSet<FieldId>,
embedding_configs: Vec<milli::index::IndexEmbeddingConfig>,
embedding_configs: Vec<index::IndexEmbeddingConfig>,
formatter_builder: MatcherBuilder<'a>,
formatted_options: BTreeMap<FieldId, FormatOptions>,
show_ranking_score: bool,