From 045a1b1e756d9911332e49cc1e38394dcde671c0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Thu, 22 May 2025 15:34:49 +0200 Subject: [PATCH] Introduce a lot of search parameters and make Deserr happy --- crates/meilisearch-auth/src/lib.rs | 1 + crates/meilisearch-types/src/deserr/mod.rs | 18 +- crates/meilisearch-types/src/settings.rs | 2 +- crates/meilisearch-types/src/task_view.rs | 4 +- crates/meilisearch-types/src/tasks.rs | 2 +- crates/meilisearch/src/routes/chat.rs | 65 ++++-- crates/meilisearch/src/search/mod.rs | 5 +- crates/milli/src/index.rs | 43 +++- crates/milli/src/update/chat.rs | 222 ++++++++++++++++++++- crates/milli/src/update/settings.rs | 86 +++++++- 10 files changed, 411 insertions(+), 37 deletions(-) diff --git a/crates/meilisearch-auth/src/lib.rs b/crates/meilisearch-auth/src/lib.rs index d72ba386c..a19ad7b8c 100644 --- a/crates/meilisearch-auth/src/lib.rs +++ b/crates/meilisearch-auth/src/lib.rs @@ -165,6 +165,7 @@ impl AuthController { } } +#[derive(Debug)] pub struct AuthFilter { search_rules: Option, key_authorized_indexes: SearchRules, diff --git a/crates/meilisearch-types/src/deserr/mod.rs b/crates/meilisearch-types/src/deserr/mod.rs index f5ad18d5c..f1470c201 100644 --- a/crates/meilisearch-types/src/deserr/mod.rs +++ b/crates/meilisearch-types/src/deserr/mod.rs @@ -4,9 +4,12 @@ use std::marker::PhantomData; use std::ops::ControlFlow; use deserr::errors::{JsonError, QueryParamError}; -use deserr::{take_cf_content, DeserializeError, IntoValue, MergeWithError, ValuePointerRef}; +use deserr::{ + take_cf_content, DeserializeError, Deserr, IntoValue, MergeWithError, ValuePointerRef, +}; +use milli::update::ChatSettings; -use crate::error::deserr_codes::*; +use crate::error::deserr_codes::{self, *}; use crate::error::{ Code, DeserrParseBoolError, DeserrParseIntError, ErrorCode, InvalidTaskDateError, ParseOffsetDateTimeError, @@ -33,6 +36,7 @@ pub struct DeserrError { pub code: Code, _phantom: PhantomData<(Format, C)>, } + impl DeserrError { pub fn new(msg: String, code: Code) -> Self { Self { msg, code, _phantom: PhantomData } @@ -117,6 +121,16 @@ impl DeserializeError for DeserrQueryParamError { } } +impl Deserr> for ChatSettings { + fn deserialize_from_value( + value: deserr::Value, + location: ValuePointerRef, + ) -> Result> { + Deserr::::deserialize_from_value(value, location) + .map_err(|e| DeserrError::new(e.to_string(), InvalidSettingsIndexChat.error_code())) + } +} + pub fn immutable_field_error(field: &str, accepted: &[&str], code: Code) -> DeserrJsonError { let msg = format!( "Immutable field `{field}`: expected one of {}", diff --git a/crates/meilisearch-types/src/settings.rs b/crates/meilisearch-types/src/settings.rs index 3f92c1c52..48c8c9769 100644 --- a/crates/meilisearch-types/src/settings.rs +++ b/crates/meilisearch-types/src/settings.rs @@ -186,7 +186,7 @@ impl Deserr for SettingEmbeddingSettings { /// Holds all the settings for an index. `T` can either be `Checked` if they represents settings /// whose validity is guaranteed, or `Unchecked` if they need to be validated. In the later case, a /// call to `check` will return a `Settings` from a `Settings`. -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr, ToSchema)] +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)] #[serde( deny_unknown_fields, rename_all = "camelCase", diff --git a/crates/meilisearch-types/src/task_view.rs b/crates/meilisearch-types/src/task_view.rs index 7a6faee39..86a00426b 100644 --- a/crates/meilisearch-types/src/task_view.rs +++ b/crates/meilisearch-types/src/task_view.rs @@ -8,7 +8,7 @@ use crate::error::ResponseError; use crate::settings::{Settings, Unchecked}; use crate::tasks::{serialize_duration, Details, IndexSwap, Kind, Status, Task, TaskId}; -#[derive(Debug, Clone, PartialEq, Eq, Serialize, ToSchema)] +#[derive(Debug, Clone, PartialEq, Serialize, ToSchema)] #[serde(rename_all = "camelCase")] #[schema(rename_all = "camelCase")] pub struct TaskView { @@ -67,7 +67,7 @@ impl TaskView { } } -#[derive(Default, Debug, PartialEq, Eq, Clone, Serialize, Deserialize, ToSchema)] +#[derive(Default, Debug, PartialEq, Clone, Serialize, Deserialize, ToSchema)] #[serde(rename_all = "camelCase")] #[schema(rename_all = "camelCase")] pub struct DetailsView { diff --git a/crates/meilisearch-types/src/tasks.rs b/crates/meilisearch-types/src/tasks.rs index 6e10f2606..102f6c3e1 100644 --- a/crates/meilisearch-types/src/tasks.rs +++ b/crates/meilisearch-types/src/tasks.rs @@ -597,7 +597,7 @@ impl fmt::Display for ParseTaskKindError { } impl std::error::Error for ParseTaskKindError {} -#[derive(Debug, PartialEq, Eq, Clone, Serialize, Deserialize)] +#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] pub enum Details { DocumentAdditionOrUpdate { received_documents: u64, diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 0f50aafac..b8c3eee29 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -26,7 +26,7 @@ use meilisearch_auth::AuthController; use meilisearch_types::error::ResponseError; use meilisearch_types::heed::RoTxn; use meilisearch_types::keys::actions; -use meilisearch_types::milli::index::ChatConfig; +use meilisearch_types::milli::index::{self, ChatConfig, SearchParameters}; use meilisearch_types::milli::prompt::{Prompt, PromptData}; use meilisearch_types::milli::update::new::document::DocumentFromDb; use meilisearch_types::milli::update::Setting; @@ -46,12 +46,12 @@ use crate::extractors::authentication::{extract_token_from_request, GuardedData, use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; use crate::routes::indexes::search::search_kind; use crate::search::{ - add_search_rules, prepare_search, search_from_kind, HybridQuery, MatchingStrategy, SearchQuery, - SemanticRatio, + add_search_rules, prepare_search, search_from_kind, HybridQuery, MatchingStrategy, + RankingScoreThreshold, SearchQuery, SemanticRatio, DEFAULT_SEARCH_LIMIT, + DEFAULT_SEMANTIC_RATIO, }; use crate::search_queue::SearchQueue; -const EMBEDDER_NAME: &str = "openai"; const SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex"; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -168,14 +168,43 @@ async fn process_search_request( index_uid: String, q: Option, ) -> Result<(Index, String), ResponseError> { + // TBD + // let mut aggregate = SearchAggregator::::from_query(&query); + + let index = index_scheduler.index(&index_uid)?; + let rtxn = index.static_read_txn()?; + let ChatConfig { description: _, prompt: _, search_parameters } = index.chat_config(&rtxn)?; + let SearchParameters { + hybrid, + limit, + sort, + distinct, + matching_strategy, + attributes_to_search_on, + ranking_score_threshold, + } = search_parameters; + let mut query = SearchQuery { q, - hybrid: Some(HybridQuery { - semantic_ratio: SemanticRatio::default(), - embedder: EMBEDDER_NAME.to_string(), + hybrid: hybrid.map(|index::HybridQuery { semantic_ratio, embedder }| HybridQuery { + semantic_ratio: SemanticRatio::try_from(semantic_ratio) + .ok() + .unwrap_or_else(DEFAULT_SEMANTIC_RATIO), + embedder, }), - limit: 20, - matching_strategy: MatchingStrategy::Frequency, + limit: limit.unwrap_or_else(DEFAULT_SEARCH_LIMIT), + sort: sort, + distinct: distinct, + matching_strategy: matching_strategy + .map(|ms| match ms { + index::MatchingStrategy::Last => MatchingStrategy::Last, + index::MatchingStrategy::All => MatchingStrategy::All, + index::MatchingStrategy::Frequency => MatchingStrategy::Frequency, + }) + .unwrap_or(MatchingStrategy::Frequency), + attributes_to_search_on: attributes_to_search_on, + ranking_score_threshold: ranking_score_threshold + .and_then(|rst| RankingScoreThreshold::try_from(rst).ok()), ..Default::default() }; @@ -189,19 +218,13 @@ async fn process_search_request( if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) { add_search_rules(&mut query.filter, search_rules); } - - // TBD - // let mut aggregate = SearchAggregator::::from_query(&query); - - let index = index_scheduler.index(&index_uid)?; let search_kind = search_kind(&query, index_scheduler.get_ref(), index_uid.to_string(), &index)?; let permit = search_queue.try_get_search_permit().await?; let features = index_scheduler.features(); let index_cloned = index.clone(); - let search_result = tokio::task::spawn_blocking(move || -> Result<_, ResponseError> { - let rtxn = index_cloned.read_txn()?; + let output = tokio::task::spawn_blocking(move || -> Result<_, ResponseError> { let time_budget = match index_cloned .search_cutoff(&rtxn) .map_err(|e| MeilisearchHttpError::from_milli(e, Some(index_uid.clone())))? @@ -214,14 +237,14 @@ async fn process_search_request( prepare_search(&index_cloned, &rtxn, &query, &search_kind, time_budget, features)?; search_from_kind(index_uid, search_kind, search) - .map(|(search_results, _)| search_results) + .map(|(search_results, _)| (rtxn, search_results)) .map_err(ResponseError::from) }) .await; permit.drop().await; - let search_result = search_result?; - if let Ok(ref search_result) = search_result { + let output = output?; + if let Ok((_, ref search_result)) = output { // aggregate.succeed(search_result); if search_result.degraded { MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); @@ -229,8 +252,8 @@ async fn process_search_request( } // analytics.publish(aggregate, &req); - let search_result = search_result?; - let rtxn = index.read_txn()?; + let (rtxn, search_result) = output?; + // let rtxn = index.read_txn()?; let render_alloc = Bump::new(); let formatted = format_documents(&rtxn, &index, &render_alloc, search_result.documents_ids)?; let text = formatted.join("\n"); diff --git a/crates/meilisearch/src/search/mod.rs b/crates/meilisearch/src/search/mod.rs index 16d04cd58..848591a4f 100644 --- a/crates/meilisearch/src/search/mod.rs +++ b/crates/meilisearch/src/search/mod.rs @@ -122,6 +122,7 @@ pub struct SearchQuery { #[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize)] #[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)] pub struct RankingScoreThreshold(f64); + impl std::convert::TryFrom for RankingScoreThreshold { type Error = InvalidSearchRankingScoreThreshold; @@ -279,8 +280,8 @@ impl fmt::Debug for SearchQuery { #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] #[serde(rename_all = "camelCase")] pub struct HybridQuery { - #[deserr(default, error = DeserrJsonError, default)] - #[schema(value_type = f32, default)] + #[deserr(default, error = DeserrJsonError)] + #[schema(default, value_type = f32)] #[serde(default)] pub semantic_ratio: SemanticRatio, #[deserr(error = DeserrJsonError)] diff --git a/crates/milli/src/index.rs b/crates/milli/src/index.rs index a5145cb0b..e6f28d02e 100644 --- a/crates/milli/src/index.rs +++ b/crates/milli/src/index.rs @@ -1695,7 +1695,7 @@ impl Index { pub fn chat_config(&self, txn: &RoTxn<'_>) -> heed::Result { self.main - .remap_types::>() + .remap_types::>() .get(txn, main_key::CHAT) .map(|o| o.unwrap_or_default()) } @@ -1705,7 +1705,7 @@ impl Index { txn: &mut RwTxn<'_>, val: &ChatConfig, ) -> heed::Result<()> { - self.main.remap_types::>().put(txn, main_key::CHAT, &val) + self.main.remap_types::>().put(txn, main_key::CHAT, &val) } pub(crate) fn delete_chat_config(&self, txn: &mut RwTxn<'_>) -> heed::Result { @@ -1943,15 +1943,54 @@ pub struct ChatConfig { pub description: String, /// Contains the document template and max template length. pub prompt: PromptData, + pub search_parameters: SearchParameters, +} + +#[derive(Debug, Default, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct SearchParameters { + #[serde(skip_serializing_if = "Option::is_none")] + pub hybrid: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub limit: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub sort: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub distinct: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub matching_strategy: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub attributes_to_search_on: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub ranking_score_threshold: Option, +} + +#[derive(Debug, Default, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct HybridQuery { + pub semantic_ratio: f32, + pub embedder: String, } #[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] pub struct PrefixSettings { pub prefix_count_threshold: usize, pub max_prefix_length: usize, pub compute_prefixes: PrefixSearch, } +#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub enum MatchingStrategy { + /// Remove query words from last to first + Last, + /// All query words are mandatory + All, + /// Remove query words from the most frequent to the least + Frequency, +} + #[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)] #[serde(rename_all = "camelCase")] pub enum PrefixSearch { diff --git a/crates/milli/src/update/chat.rs b/crates/milli/src/update/chat.rs index 44e646f6d..b8fbc582d 100644 --- a/crates/milli/src/update/chat.rs +++ b/crates/milli/src/update/chat.rs @@ -1,14 +1,19 @@ +use std::error::Error; +use std::fmt; + +use deserr::errors::JsonError; use deserr::Deserr; use serde::{Deserialize, Serialize}; use utoipa::ToSchema; -use crate::index::ChatConfig; +use crate::index::{self, ChatConfig, SearchParameters}; use crate::prompt::{default_max_bytes, PromptData}; use crate::update::Setting; +use crate::TermsMatchingStrategy; -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr, ToSchema)] +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)] #[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(deny_unknown_fields, rename_all = camelCase)] +#[deserr(error = JsonError, deny_unknown_fields, rename_all = camelCase)] pub struct ChatSettings { #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] @@ -29,17 +34,226 @@ pub struct ChatSettings { #[deserr(default)] #[schema(value_type = Option)] pub document_template_max_bytes: Setting, + + /// The search parameters to use for the LLM. + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + #[schema(value_type = Option)] + pub search_parameters: Setting, } impl From for ChatSettings { fn from(config: ChatConfig) -> Self { - let ChatConfig { description, prompt: PromptData { template, max_bytes } } = config; + let ChatConfig { + description, + prompt: PromptData { template, max_bytes }, + search_parameters, + } = config; ChatSettings { description: Setting::Set(description), document_template: Setting::Set(template), document_template_max_bytes: Setting::Set( max_bytes.unwrap_or(default_max_bytes()).get(), ), + search_parameters: Setting::Set({ + let SearchParameters { + hybrid, + limit, + sort, + distinct, + matching_strategy, + attributes_to_search_on, + ranking_score_threshold, + } = search_parameters; + + let hybrid = hybrid.map(|index::HybridQuery { semantic_ratio, embedder }| { + HybridQuery { semantic_ratio: SemanticRatio(semantic_ratio), embedder } + }); + + let matching_strategy = matching_strategy.map(|ms| match ms { + index::MatchingStrategy::Last => MatchingStrategy::Last, + index::MatchingStrategy::All => MatchingStrategy::All, + index::MatchingStrategy::Frequency => MatchingStrategy::Frequency, + }); + + let ranking_score_threshold = ranking_score_threshold.map(RankingScoreThreshold); + + ChatSearchParams { + hybrid: Setting::some_or_not_set(hybrid), + limit: Setting::some_or_not_set(limit), + sort: Setting::some_or_not_set(sort), + distinct: Setting::some_or_not_set(distinct), + matching_strategy: Setting::some_or_not_set(matching_strategy), + attributes_to_search_on: Setting::some_or_not_set(attributes_to_search_on), + ranking_score_threshold: Setting::some_or_not_set(ranking_score_threshold), + } + }), } } } + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Deserr, ToSchema)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(error = JsonError, deny_unknown_fields, rename_all = camelCase)] +pub struct ChatSearchParams { + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + #[schema(value_type = Option)] + pub hybrid: Setting, + + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default = Setting::Set(20))] + #[schema(value_type = Option)] + pub limit: Setting, + + // #[serde(default, skip_serializing_if = "Setting::is_not_set")] + // #[deserr(default)] + // pub attributes_to_retrieve: Option>, + + // #[serde(default, skip_serializing_if = "Setting::is_not_set")] + // #[deserr(default)] + // pub filter: Option, + // + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + #[schema(value_type = Option>)] + pub sort: Setting>, + + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + #[schema(value_type = Option)] + pub distinct: Setting, + + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + #[schema(value_type = Option)] + pub matching_strategy: Setting, + + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + #[schema(value_type = Option>)] + pub attributes_to_search_on: Setting>, + + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + #[schema(value_type = Option)] + pub ranking_score_threshold: Setting, +} + +#[derive(Debug, Clone, Default, Deserr, ToSchema, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +#[deserr(error = JsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct HybridQuery { + #[deserr(default)] + #[serde(default)] + #[schema(default, value_type = f32)] + pub semantic_ratio: SemanticRatio, + #[schema(value_type = String)] + pub embedder: String, +} + +#[derive(Debug, Clone, Copy, Deserr, ToSchema, PartialEq, Serialize, Deserialize)] +#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)] +pub struct SemanticRatio(f32); + +impl Default for SemanticRatio { + fn default() -> Self { + SemanticRatio(0.5) + } +} + +impl std::convert::TryFrom for SemanticRatio { + type Error = InvalidSearchSemanticRatio; + + fn try_from(f: f32) -> Result { + // the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable + #[allow(clippy::manual_range_contains)] + if f > 1.0 || f < 0.0 { + Err(InvalidSearchSemanticRatio) + } else { + Ok(SemanticRatio(f)) + } + } +} + +#[derive(Debug)] +pub struct InvalidSearchSemanticRatio; + +impl Error for InvalidSearchSemanticRatio {} + +impl fmt::Display for InvalidSearchSemanticRatio { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`." + ) + } +} + +impl std::ops::Deref for SemanticRatio { + type Target = f32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Deserr, ToSchema, Serialize, Deserialize)] +#[deserr(rename_all = camelCase)] +#[serde(rename_all = "camelCase")] +pub enum MatchingStrategy { + /// Remove query words from last to first + Last, + /// All query words are mandatory + All, + /// Remove query words from the most frequent to the least + Frequency, +} + +impl Default for MatchingStrategy { + fn default() -> Self { + Self::Last + } +} + +impl From for TermsMatchingStrategy { + fn from(other: MatchingStrategy) -> Self { + match other { + MatchingStrategy::Last => Self::Last, + MatchingStrategy::All => Self::All, + MatchingStrategy::Frequency => Self::Frequency, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Deserr, ToSchema, Serialize, Deserialize)] +#[deserr(try_from(f64) = TryFrom::try_from -> InvalidSearchRankingScoreThreshold)] +pub struct RankingScoreThreshold(pub f64); + +impl std::convert::TryFrom for RankingScoreThreshold { + type Error = InvalidSearchRankingScoreThreshold; + + fn try_from(f: f64) -> Result { + // the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable + #[allow(clippy::manual_range_contains)] + if f > 1.0 || f < 0.0 { + Err(InvalidSearchRankingScoreThreshold) + } else { + Ok(RankingScoreThreshold(f)) + } + } +} + +#[derive(Debug)] +pub struct InvalidSearchRankingScoreThreshold; + +impl Error for InvalidSearchRankingScoreThreshold {} + +impl fmt::Display for InvalidSearchRankingScoreThreshold { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "the value of `rankingScoreThreshold` is invalid, expected a float between `0.0` and `1.0`." + ) + } +} diff --git a/crates/milli/src/update/settings.rs b/crates/milli/src/update/settings.rs index 50220bd27..c9cb72eaa 100644 --- a/crates/milli/src/update/settings.rs +++ b/crates/milli/src/update/settings.rs @@ -11,6 +11,7 @@ use roaring::RoaringBitmap; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use time::OffsetDateTime; +use super::chat::{ChatSearchParams, RankingScoreThreshold}; use super::del_add::{DelAdd, DelAddOperation}; use super::index_documents::{IndexDocumentsConfig, Transform}; use super::{ChatSettings, IndexerConfig}; @@ -22,8 +23,8 @@ use crate::error::UserError; use crate::fields_ids_map::metadata::{FieldIdMapWithMetadata, MetadataBuilder}; use crate::filterable_attributes_rules::match_faceted_field; use crate::index::{ - ChatConfig, IndexEmbeddingConfig, PrefixSearch, DEFAULT_MIN_WORD_LEN_ONE_TYPO, - DEFAULT_MIN_WORD_LEN_TWO_TYPOS, + ChatConfig, IndexEmbeddingConfig, MatchingStrategy, PrefixSearch, + DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS, }; use crate::order_by_map::OrderByMap; use crate::prompt::{default_max_bytes, PromptData}; @@ -1264,11 +1265,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { description: new_description, document_template: new_document_template, document_template_max_bytes: new_document_template_max_bytes, + search_parameters: new_search_parameters, }) => { let mut old = self.index.chat_config(self.wtxn)?; let ChatConfig { ref mut description, prompt: PromptData { ref mut template, ref mut max_bytes }, + ref mut search_parameters, } = old; match new_description { @@ -1289,6 +1292,85 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { Setting::NotSet => (), } + match new_search_parameters { + Setting::Set(sp) => { + let ChatSearchParams { + hybrid, + limit, + sort, + distinct, + matching_strategy, + attributes_to_search_on, + ranking_score_threshold, + } = sp; + + match hybrid { + Setting::Set(hybrid) => { + search_parameters.hybrid = Some(crate::index::HybridQuery { + semantic_ratio: *hybrid.semantic_ratio, + embedder: hybrid.embedder.clone(), + }) + } + Setting::Reset => search_parameters.hybrid = None, + Setting::NotSet => (), + } + + match limit { + Setting::Set(limit) => search_parameters.limit = Some(*limit), + Setting::Reset => search_parameters.limit = None, + Setting::NotSet => (), + } + + match sort { + Setting::Set(sort) => search_parameters.sort = Some(sort.clone()), + Setting::Reset => search_parameters.sort = None, + Setting::NotSet => (), + } + + match distinct { + Setting::Set(distinct) => { + search_parameters.distinct = Some(distinct.clone()) + } + Setting::Reset => search_parameters.distinct = None, + Setting::NotSet => (), + } + + match matching_strategy { + Setting::Set(matching_strategy) => { + let strategy = match matching_strategy { + super::chat::MatchingStrategy::Last => MatchingStrategy::Last, + super::chat::MatchingStrategy::All => MatchingStrategy::All, + super::chat::MatchingStrategy::Frequency => { + MatchingStrategy::Frequency + } + }; + search_parameters.matching_strategy = Some(strategy) + } + Setting::Reset => search_parameters.matching_strategy = None, + Setting::NotSet => (), + } + + match attributes_to_search_on { + Setting::Set(attributes_to_search_on) => { + search_parameters.attributes_to_search_on = + Some(attributes_to_search_on.clone()) + } + Setting::Reset => search_parameters.attributes_to_search_on = None, + Setting::NotSet => (), + } + + match ranking_score_threshold { + Setting::Set(RankingScoreThreshold(score)) => { + search_parameters.ranking_score_threshold = Some(*score) + } + Setting::Reset => search_parameters.ranking_score_threshold = None, + Setting::NotSet => (), + } + } + Setting::Reset => *search_parameters = Default::default(), + Setting::NotSet => (), + } + self.index.put_chat_config(self.wtxn, &old)?; Ok(true) }