diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 62591e991..2182b1836 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -344,7 +344,10 @@ impl ErrorCode for milli::Error { Code::InvalidDocumentId } UserError::MissingDocumentField(_) => Code::InvalidDocumentFields, - UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, + UserError::InvalidFieldForSource { .. } + | UserError::MissingFieldForSource { .. } + | UserError::InvalidOpenAiModel { .. } + | UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders, UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, diff --git a/meilisearch-types/src/settings.rs b/meilisearch-types/src/settings.rs index b0dee69a3..244fbffa2 100644 --- a/meilisearch-types/src/settings.rs +++ b/meilisearch-types/src/settings.rs @@ -318,6 +318,21 @@ impl Settings { _kind: PhantomData, } } + + pub fn validate(self) -> Result { + self.validate_embedding_settings() + } + + fn validate_embedding_settings(mut self) -> Result { + let Setting::Set(mut configs) = self.embedders else { return Ok(self) }; + for (name, config) in configs.iter_mut() { + let config_to_check = std::mem::take(config); + let checked_config = milli::update::validate_embedding_settings(config_to_check, name)?; + *config = checked_config + } + self.embedders = Setting::Set(configs); + Ok(self) + } } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/meilisearch/src/routes/indexes/settings.rs b/meilisearch/src/routes/indexes/settings.rs index 290cab2e0..feb174a1b 100644 --- a/meilisearch/src/routes/indexes/settings.rs +++ b/meilisearch/src/routes/indexes/settings.rs @@ -90,6 +90,8 @@ macro_rules! make_setting_route { ..Default::default() }; + let new_settings = new_settings.validate()?; + let allow_index_creation = index_scheduler.filters().allow_index_creation(&index_uid); @@ -582,13 +584,13 @@ fn embedder_analytics( for source in s .values() .filter_map(|config| config.clone().set()) - .filter_map(|config| config.embedder_options.set()) + .filter_map(|config| config.source.set()) { - use meilisearch_types::milli::vector::settings::EmbedderSettings; + use meilisearch_types::milli::vector::settings::EmbedderSource; match source { - EmbedderSettings::OpenAi(_) => sources.insert("openAi"), - EmbedderSettings::HuggingFace(_) => sources.insert("huggingFace"), - EmbedderSettings::UserProvided(_) => sources.insert("userProvided"), + EmbedderSource::OpenAi => sources.insert("openAi"), + EmbedderSource::HuggingFace => sources.insert("huggingFace"), + EmbedderSource::UserProvided => sources.insert("userProvided"), }; } }; @@ -651,6 +653,7 @@ pub async fn update_all( let index_uid = IndexUid::try_from(index_uid.into_inner())?; let new_settings = body.into_inner(); + let new_settings = new_settings.validate()?; analytics.publish( "Settings Updated".to_string(), diff --git a/meilisearch/tests/search/hybrid.rs b/meilisearch/tests/search/hybrid.rs index fb6fe297f..77a29d4a3 100644 --- a/meilisearch/tests/search/hybrid.rs +++ b/meilisearch/tests/search/hybrid.rs @@ -21,9 +21,9 @@ async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Inde "###); let (response, code) = index - .update_settings( - json!({ "embedders": {"default": {"source": {"userProvided": {"dimensions": 2}}}} }), - ) + .update_settings(json!({ "embedders": {"default": { + "source": "userProvided", + "dimensions": 2}}} )) .await; assert_eq!(202, code, "{:?}", response); index.wait_task(response.uid()).await; diff --git a/meilisearch/tests/search/mod.rs b/meilisearch/tests/search/mod.rs index 133a143fd..9b7b01029 100644 --- a/meilisearch/tests/search/mod.rs +++ b/meilisearch/tests/search/mod.rs @@ -890,13 +890,21 @@ async fn experimental_feature_vector_store() { let (response, code) = index .update_settings(json!({"embedders": { "manual": { - "source": { - "userProvided": {"dimensions": 3} - } + "source": "userProvided", + "dimensions": 3, } }})) .await; + meili_snap::snapshot!(response, @r###" + { + "taskUid": 1, + "indexUid": "test", + "status": "enqueued", + "type": "settingsUpdate", + "enqueuedAt": "[date]" + } + "###); meili_snap::snapshot!(code, @"202 Accepted"); let response = index.wait_task(response.uid()).await; diff --git a/milli/src/error.rs b/milli/src/error.rs index 9c5d8f416..539861e73 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -192,7 +192,7 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco MissingDocumentField(#[from] crate::prompt::error::RenderPromptError), #[error(transparent)] InvalidPrompt(#[from] crate::prompt::error::NewPromptError), - #[error("Invalid prompt in for embeddings with name '{0}': {1}.")] + #[error("`.embedders.{0}.documentTemplate`: Invalid template: {1}.")] InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), #[error("Too many embedders in the configuration. Found {0}, but limited to 256.")] TooManyEmbedders(usize), @@ -200,6 +200,33 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco InvalidEmbedder(String), #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")] TooManyVectors(String, usize), + #[error("`.embedders.{embedder_name}`: Field `{field}` unavailable for source `{source_}` (only available for sources: {}). Available fields: {}", + allowed_sources_for_field + .iter() + .map(|accepted| format!("`{}`", accepted)) + .collect::>() + .join(", "), + allowed_fields_for_source + .iter() + .map(|accepted| format!("`{}`", accepted)) + .collect::>() + .join(", ") + )] + InvalidFieldForSource { + embedder_name: String, + source_: crate::vector::settings::EmbedderSource, + field: &'static str, + allowed_fields_for_source: &'static [&'static str], + allowed_sources_for_field: &'static [crate::vector::settings::EmbedderSource], + }, + #[error("`.embedders.{embedder_name}.model`: Invalid model `{model}` for OpenAI. Supported models: {:?}", crate::vector::openai::EmbeddingModel::supported_models())] + InvalidOpenAiModel { embedder_name: String, model: String }, + #[error("`.embedders.{embedder_name}`: Missing field `{field}` (note: this field is mandatory for source {source_})")] + MissingFieldForSource { + field: &'static str, + source_: crate::vector::settings::EmbedderSource, + embedder_name: String, + }, } impl From for Error { diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index ffc3f6b3a..738cfeb38 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -2553,7 +2553,7 @@ mod tests { /// Vectors must be of the same length. #[test] fn test_multiple_vectors() { - use crate::vector::settings::{EmbedderSettings, EmbeddingSettings}; + use crate::vector::settings::EmbeddingSettings; let index = TempIndex::new(); index @@ -2562,9 +2562,11 @@ mod tests { embedders.insert( "manual".to_string(), Setting::Set(EmbeddingSettings { - embedder_options: Setting::Set(EmbedderSettings::UserProvided( - crate::vector::settings::UserProvidedSettings { dimensions: 3 }, - )), + source: Setting::Set(crate::vector::settings::EmbedderSource::UserProvided), + model: Setting::NotSet, + revision: Setting::NotSet, + api_key: Setting::NotSet, + dimensions: Setting::Set(3), document_template: Setting::NotSet, }), ); @@ -2579,10 +2581,10 @@ mod tests { .unwrap(); index.add_documents(documents!([{"id": 1, "_vectors": { "manual": [6, 7, 8] }}])).unwrap(); index - .add_documents( - documents!([{"id": 2, "_vectors": { "manual": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }}]), - ) - .unwrap(); + .add_documents( + documents!([{"id": 2, "_vectors": { "manual": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }}]), + ) + .unwrap(); let rtxn = index.read_txn().unwrap(); let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap(); diff --git a/milli/src/update/mod.rs b/milli/src/update/mod.rs index eb2b6e69a..66c52a52f 100644 --- a/milli/src/update/mod.rs +++ b/milli/src/update/mod.rs @@ -8,7 +8,7 @@ pub use self::index_documents::{ MergeFn, }; pub use self::indexer_config::IndexerConfig; -pub use self::settings::{Setting, Settings}; +pub use self::settings::{validate_embedding_settings, Setting, Settings}; pub use self::update_step::UpdateIndexingStep; pub use self::word_prefix_docids::WordPrefixDocids; pub use self::words_prefix_integer_docids::WordPrefixIntegerDocids; diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index d406c121c..d770bcd74 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -17,7 +17,7 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS use crate::proximity::ProximityPrecision; use crate::update::index_documents::IndexDocumentsMethod; use crate::update::{IndexDocuments, UpdateIndexingStep}; -use crate::vector::settings::{EmbeddingSettings, PromptSettings}; +use crate::vector::settings::{check_set, check_unset, EmbedderSource, EmbeddingSettings}; use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs}; use crate::{FieldsIdsMap, Index, OrderBy, Result}; @@ -78,11 +78,19 @@ impl Setting { } } - pub fn apply(&mut self, new: Self) { + /// Returns `true` if applying the new setting changed this setting + pub fn apply(&mut self, new: Self) -> bool + where + T: PartialEq + Eq, + { if let Setting::NotSet = new { - return; + return false; + } + if self == &new { + return false; } *self = new; + true } } @@ -950,17 +958,23 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { .merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right)) { match joined { + // updated config EitherOrBoth::Both((name, mut old), (_, new)) => { - old.apply(new); - let new = validate_prompt(&name, old)?; - changed = true; + changed |= old.apply(new); + let new = validate_embedding_settings(old, &name)?; new_configs.insert(name, new); } + // unchanged config EitherOrBoth::Left((name, setting)) => { new_configs.insert(name, setting); } - EitherOrBoth::Right((name, setting)) => { - let setting = validate_prompt(&name, setting)?; + // new config + EitherOrBoth::Right((name, mut setting)) => { + // apply the default source in case the source was not set so that it gets validated + crate::vector::settings::EmbeddingSettings::apply_default_source( + &mut setting, + ); + let setting = validate_embedding_settings(setting, &name)?; changed = true; new_configs.insert(name, setting); } @@ -1072,8 +1086,12 @@ fn validate_prompt( ) -> Result> { match new { Setting::Set(EmbeddingSettings { - embedder_options, - document_template: Setting::Set(PromptSettings { template: Setting::Set(template) }), + source, + model, + revision, + api_key, + dimensions, + document_template: Setting::Set(template), }) => { // validate let template = crate::prompt::Prompt::new(template) @@ -1081,16 +1099,71 @@ fn validate_prompt( .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; Ok(Setting::Set(EmbeddingSettings { - embedder_options, - document_template: Setting::Set(PromptSettings { - template: Setting::Set(template), - }), + source, + model, + revision, + api_key, + dimensions, + document_template: Setting::Set(template), })) } new => Ok(new), } } +pub fn validate_embedding_settings( + settings: Setting, + name: &str, +) -> Result> { + let settings = validate_prompt(name, settings)?; + let Setting::Set(settings) = settings else { return Ok(settings) }; + let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = + settings; + let Some(inferred_source) = source.set() else { + return Ok(Setting::Set(EmbeddingSettings { + source, + model, + revision, + api_key, + dimensions, + document_template, + })); + }; + match inferred_source { + EmbedderSource::OpenAi => { + check_unset(&revision, "revision", inferred_source, name)?; + check_unset(&dimensions, "dimensions", inferred_source, name)?; + if let Setting::Set(model) = &model { + crate::vector::openai::EmbeddingModel::from_name(model.as_str()).ok_or( + crate::error::UserError::InvalidOpenAiModel { + embedder_name: name.to_owned(), + model: model.clone(), + }, + )?; + } + } + EmbedderSource::HuggingFace => { + check_unset(&api_key, "apiKey", inferred_source, name)?; + check_unset(&dimensions, "dimensions", inferred_source, name)?; + } + EmbedderSource::UserProvided => { + check_unset(&model, "model", inferred_source, name)?; + check_unset(&revision, "revision", inferred_source, name)?; + check_unset(&api_key, "apiKey", inferred_source, name)?; + check_unset(&document_template, "documentTemplate", inferred_source, name)?; + check_set(&dimensions, "dimensions", inferred_source, name)?; + } + } + Ok(Setting::Set(EmbeddingSettings { + source, + model, + revision, + api_key, + dimensions, + document_template, + })) +} + #[cfg(test)] mod tests { use big_s::S; diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index c11e6ddc6..53e8a041b 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -34,6 +34,9 @@ pub struct EmbedderOptions { #[serde(deny_unknown_fields, rename_all = "camelCase")] #[deserr(rename_all = camelCase, deny_unknown_fields)] pub enum EmbeddingModel { + // # WARNING + // + // If ever adding a model, make sure to add it to the list of supported models below. #[default] #[serde(rename = "text-embedding-ada-002")] #[deserr(rename = "text-embedding-ada-002")] @@ -41,6 +44,10 @@ pub enum EmbeddingModel { } impl EmbeddingModel { + pub fn supported_models() -> &'static [&'static str] { + &["text-embedding-ada-002"] + } + pub fn max_token(&self) -> usize { match self { EmbeddingModel::TextEmbeddingAda002 => 8191, @@ -59,7 +66,7 @@ impl EmbeddingModel { } } - pub fn from_name(name: &'static str) -> Option { + pub fn from_name(name: &str) -> Option { match name { "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), _ => None, diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 1826c040d..37fb80452 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -4,32 +4,189 @@ use serde::{Deserialize, Serialize}; use crate::prompt::PromptData; use crate::update::Setting; use crate::vector::EmbeddingConfig; +use crate::UserError; #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] #[serde(deny_unknown_fields, rename_all = "camelCase")] #[deserr(rename_all = camelCase, deny_unknown_fields)] pub struct EmbeddingSettings { - #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")] - #[deserr(default, rename = "source")] - pub embedder_options: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] - pub document_template: Setting, + pub source: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub model: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub revision: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub api_key: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub dimensions: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub document_template: Setting, +} + +pub fn check_unset( + key: &Setting, + field: &'static str, + source: EmbedderSource, + embedder_name: &str, +) -> Result<(), UserError> { + if matches!(key, Setting::NotSet) { + Ok(()) + } else { + Err(UserError::InvalidFieldForSource { + embedder_name: embedder_name.to_owned(), + source_: source, + field, + allowed_fields_for_source: EmbeddingSettings::allowed_fields_for_source(source), + allowed_sources_for_field: EmbeddingSettings::allowed_sources_for_field(field), + }) + } +} + +pub fn check_set( + key: &Setting, + field: &'static str, + source: EmbedderSource, + embedder_name: &str, +) -> Result<(), UserError> { + if matches!(key, Setting::Set(_)) { + Ok(()) + } else { + Err(UserError::MissingFieldForSource { + field, + source_: source, + embedder_name: embedder_name.to_owned(), + }) + } +} + +impl EmbeddingSettings { + pub const SOURCE: &'static str = "source"; + pub const MODEL: &'static str = "model"; + pub const REVISION: &'static str = "revision"; + pub const API_KEY: &'static str = "apiKey"; + pub const DIMENSIONS: &'static str = "dimensions"; + pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate"; + + pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] { + match field { + Self::SOURCE => { + &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided] + } + Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi], + Self::REVISION => &[EmbedderSource::HuggingFace], + Self::API_KEY => &[EmbedderSource::OpenAi], + Self::DIMENSIONS => &[EmbedderSource::UserProvided], + Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi], + _other => unreachable!("unknown field"), + } + } + + pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] { + match source { + EmbedderSource::OpenAi => { + &[Self::SOURCE, Self::MODEL, Self::API_KEY, Self::DOCUMENT_TEMPLATE] + } + EmbedderSource::HuggingFace => { + &[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE] + } + EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], + } + } + + pub(crate) fn apply_default_source(setting: &mut Setting) { + if let Setting::Set(EmbeddingSettings { + source: source @ (Setting::NotSet | Setting::Reset), + .. + }) = setting + { + *source = Setting::Set(EmbedderSource::default()) + } + } +} + +#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub enum EmbedderSource { + #[default] + OpenAi, + HuggingFace, + UserProvided, +} + +impl std::fmt::Display for EmbedderSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + EmbedderSource::OpenAi => "openAi", + EmbedderSource::HuggingFace => "huggingFace", + EmbedderSource::UserProvided => "userProvided", + }; + f.write_str(s) + } } impl EmbeddingSettings { pub fn apply(&mut self, new: Self) { - let EmbeddingSettings { embedder_options, document_template: prompt } = new; - self.embedder_options.apply(embedder_options); - self.document_template.apply(prompt); + let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = + new; + let old_source = self.source; + self.source.apply(source); + // Reinitialize the whole setting object on a source change + if old_source != self.source { + *self = EmbeddingSettings { + source, + model, + revision, + api_key, + dimensions, + document_template, + }; + return; + } + + self.model.apply(model); + self.revision.apply(revision); + self.api_key.apply(api_key); + self.dimensions.apply(dimensions); + self.document_template.apply(document_template); } } impl From for EmbeddingSettings { fn from(value: EmbeddingConfig) -> Self { - Self { - embedder_options: Setting::Set(value.embedder_options.into()), - document_template: Setting::Set(value.prompt.into()), + let EmbeddingConfig { embedder_options, prompt } = value; + match embedder_options { + super::EmbedderOptions::HuggingFace(options) => Self { + source: Setting::Set(EmbedderSource::HuggingFace), + model: Setting::Set(options.model), + revision: options.revision.map(Setting::Set).unwrap_or_default(), + api_key: Setting::NotSet, + dimensions: Setting::NotSet, + document_template: Setting::Set(prompt.template), + }, + super::EmbedderOptions::OpenAi(options) => Self { + source: Setting::Set(EmbedderSource::OpenAi), + model: Setting::Set(options.embedding_model.name().to_owned()), + revision: Setting::NotSet, + api_key: options.api_key.map(Setting::Set).unwrap_or_default(), + dimensions: Setting::NotSet, + document_template: Setting::Set(prompt.template), + }, + super::EmbedderOptions::UserProvided(options) => Self { + source: Setting::Set(EmbedderSource::UserProvided), + model: Setting::NotSet, + revision: Setting::NotSet, + api_key: Setting::NotSet, + dimensions: Setting::Set(options.dimensions), + document_template: Setting::NotSet, + }, } } } @@ -37,262 +194,51 @@ impl From for EmbeddingSettings { impl From for EmbeddingConfig { fn from(value: EmbeddingSettings) -> Self { let mut this = Self::default(); - let EmbeddingSettings { embedder_options, document_template: prompt } = value; - if let Some(embedder_options) = embedder_options.set() { - this.embedder_options = embedder_options.into(); - } - if let Some(prompt) = prompt.set() { - this.prompt = prompt.into(); - } - this - } -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(rename_all = camelCase, deny_unknown_fields)] -pub struct PromptSettings { - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub template: Setting, -} - -impl PromptSettings { - pub fn apply(&mut self, new: Self) { - let PromptSettings { template } = new; - self.template.apply(template); - } -} - -impl From for PromptSettings { - fn from(value: PromptData) -> Self { - Self { template: Setting::Set(value.template) } - } -} - -impl From for PromptData { - fn from(value: PromptSettings) -> Self { - let mut this = PromptData::default(); - let PromptSettings { template } = value; - if let Some(template) = template.set() { - this.template = template; - } - this - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -pub enum EmbedderSettings { - HuggingFace(Setting), - OpenAi(Setting), - UserProvided(UserProvidedSettings), -} - -impl Deserr for EmbedderSettings -where - E: deserr::DeserializeError, -{ - fn deserialize_from_value( - value: deserr::Value, - location: deserr::ValuePointerRef, - ) -> Result { - match value { - deserr::Value::Map(map) => { - if deserr::Map::len(&map) != 1 { - return Err(deserr::take_cf_content(E::error::( - None, - deserr::ErrorKind::Unexpected { - msg: format!( - "Expected a single field, got {} fields", - deserr::Map::len(&map) - ), - }, - location, - ))); + let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = + value; + if let Some(source) = source.set() { + match source { + EmbedderSource::OpenAi => { + let mut options = super::openai::EmbedderOptions::with_default_model(None); + if let Some(model) = model.set() { + if let Some(model) = super::openai::EmbeddingModel::from_name(&model) { + options.embedding_model = model; + } + } + if let Some(api_key) = api_key.set() { + options.api_key = Some(api_key); + } + this.embedder_options = super::EmbedderOptions::OpenAi(options); } - let mut it = deserr::Map::into_iter(map); - let (k, v) = it.next().unwrap(); - - match k.as_str() { - "huggingFace" => Ok(EmbedderSettings::HuggingFace(Setting::Set( - HfEmbedderSettings::deserialize_from_value( - v.into_value(), - location.push_key(&k), - )?, - ))), - "openAi" => Ok(EmbedderSettings::OpenAi(Setting::Set( - OpenAiEmbedderSettings::deserialize_from_value( - v.into_value(), - location.push_key(&k), - )?, - ))), - "userProvided" => Ok(EmbedderSettings::UserProvided( - UserProvidedSettings::deserialize_from_value( - v.into_value(), - location.push_key(&k), - )?, - )), - other => Err(deserr::take_cf_content(E::error::( - None, - deserr::ErrorKind::UnknownKey { - key: other, - accepted: &["huggingFace", "openAi", "userProvided"], - }, - location, - ))), + EmbedderSource::HuggingFace => { + let mut options = super::hf::EmbedderOptions::default(); + if let Some(model) = model.set() { + options.model = model; + // Reset the revision if we are setting the model. + // This allows the following: + // "huggingFace": {} -> default model with default revision + // "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision + // "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision + options.revision = None; + } + if let Some(revision) = revision.set() { + options.revision = Some(revision); + } + this.embedder_options = super::EmbedderOptions::HuggingFace(options); + } + EmbedderSource::UserProvided => { + this.embedder_options = + super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions { + dimensions: dimensions.set().unwrap(), + }); } } - _ => Err(deserr::take_cf_content(E::error::( - None, - deserr::ErrorKind::IncorrectValueKind { - actual: value, - accepted: &[deserr::ValueKind::Map], - }, - location, - ))), } - } -} -impl Default for EmbedderSettings { - fn default() -> Self { - Self::OpenAi(Default::default()) - } -} - -impl From for EmbedderSettings { - fn from(value: crate::vector::EmbedderOptions) -> Self { - match value { - crate::vector::EmbedderOptions::HuggingFace(hf) => { - Self::HuggingFace(Setting::Set(hf.into())) - } - crate::vector::EmbedderOptions::OpenAi(openai) => { - Self::OpenAi(Setting::Set(openai.into())) - } - crate::vector::EmbedderOptions::UserProvided(user_provided) => { - Self::UserProvided(user_provided.into()) - } + if let Setting::Set(template) = document_template { + this.prompt = PromptData { template } } - } -} -impl From for crate::vector::EmbedderOptions { - fn from(value: EmbedderSettings) -> Self { - match value { - EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()), - EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()), - EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()), - EmbedderSettings::OpenAi(_setting) => { - Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None)) - } - EmbedderSettings::UserProvided(user_provided) => { - Self::UserProvided(user_provided.into()) - } - } - } -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(rename_all = camelCase, deny_unknown_fields)] -pub struct HfEmbedderSettings { - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub model: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub revision: Setting, -} - -impl HfEmbedderSettings { - pub fn apply(&mut self, new: Self) { - let HfEmbedderSettings { model, revision } = new; - self.model.apply(model); - self.revision.apply(revision); - } -} - -impl From for HfEmbedderSettings { - fn from(value: crate::vector::hf::EmbedderOptions) -> Self { - Self { - model: Setting::Set(value.model), - revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet), - } - } -} - -impl From for crate::vector::hf::EmbedderOptions { - fn from(value: HfEmbedderSettings) -> Self { - let HfEmbedderSettings { model, revision } = value; - let mut this = Self::default(); - if let Some(model) = model.set() { - this.model = model; - // Reset the revision if we are setting the model. - // This allows the following: - // "huggingFace": {} -> default model with default revision - // "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision - // "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision - this.revision = None; - } - if let Some(revision) = revision.set() { - this.revision = Some(revision); - } this } } - -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(rename_all = camelCase, deny_unknown_fields)] -pub struct OpenAiEmbedderSettings { - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub api_key: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "model")] - #[deserr(default, rename = "model")] - pub embedding_model: Setting, -} - -impl OpenAiEmbedderSettings { - pub fn apply(&mut self, new: Self) { - let Self { api_key, embedding_model: embedding_mode } = new; - self.api_key.apply(api_key); - self.embedding_model.apply(embedding_mode); - } -} - -impl From for OpenAiEmbedderSettings { - fn from(value: crate::vector::openai::EmbedderOptions) -> Self { - Self { - api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset), - embedding_model: Setting::Set(value.embedding_model), - } - } -} - -impl From for crate::vector::openai::EmbedderOptions { - fn from(value: OpenAiEmbedderSettings) -> Self { - let OpenAiEmbedderSettings { api_key, embedding_model } = value; - Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() } - } -} - -#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(rename_all = camelCase, deny_unknown_fields)] -pub struct UserProvidedSettings { - pub dimensions: usize, -} - -impl From for crate::vector::manual::EmbedderOptions { - fn from(value: UserProvidedSettings) -> Self { - Self { dimensions: value.dimensions } - } -} - -impl From for UserProvidedSettings { - fn from(value: crate::vector::manual::EmbedderOptions) -> Self { - Self { dimensions: value.dimensions } - } -}