diff --git a/milli/src/update/mod.rs b/milli/src/update/mod.rs index eb2b6e69a..66c52a52f 100644 --- a/milli/src/update/mod.rs +++ b/milli/src/update/mod.rs @@ -8,7 +8,7 @@ pub use self::index_documents::{ MergeFn, }; pub use self::indexer_config::IndexerConfig; -pub use self::settings::{Setting, Settings}; +pub use self::settings::{validate_embedding_settings, Setting, Settings}; pub use self::update_step::UpdateIndexingStep; pub use self::word_prefix_docids::WordPrefixDocids; pub use self::words_prefix_integer_docids::WordPrefixIntegerDocids; diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index e47a5ad52..d770bcd74 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -17,7 +17,7 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS use crate::proximity::ProximityPrecision; use crate::update::index_documents::IndexDocumentsMethod; use crate::update::{IndexDocuments, UpdateIndexingStep}; -use crate::vector::settings::{EmbeddingSettings, PromptSettings}; +use crate::vector::settings::{check_set, check_unset, EmbedderSource, EmbeddingSettings}; use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs}; use crate::{FieldsIdsMap, Index, OrderBy, Result}; @@ -958,17 +958,23 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { .merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right)) { match joined { + // updated config EitherOrBoth::Both((name, mut old), (_, new)) => { - old.apply(new); - let new = validate_prompt(&name, old)?; - changed = true; + changed |= old.apply(new); + let new = validate_embedding_settings(old, &name)?; new_configs.insert(name, new); } + // unchanged config EitherOrBoth::Left((name, setting)) => { new_configs.insert(name, setting); } - EitherOrBoth::Right((name, setting)) => { - let setting = validate_prompt(&name, setting)?; + // new config + EitherOrBoth::Right((name, mut setting)) => { + // apply the default source in case the source was not set so that it gets validated + crate::vector::settings::EmbeddingSettings::apply_default_source( + &mut setting, + ); + let setting = validate_embedding_settings(setting, &name)?; changed = true; new_configs.insert(name, setting); } @@ -1080,8 +1086,12 @@ fn validate_prompt( ) -> Result> { match new { Setting::Set(EmbeddingSettings { - embedder_options, - document_template: Setting::Set(PromptSettings { template: Setting::Set(template) }), + source, + model, + revision, + api_key, + dimensions, + document_template: Setting::Set(template), }) => { // validate let template = crate::prompt::Prompt::new(template) @@ -1089,16 +1099,71 @@ fn validate_prompt( .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; Ok(Setting::Set(EmbeddingSettings { - embedder_options, - document_template: Setting::Set(PromptSettings { - template: Setting::Set(template), - }), + source, + model, + revision, + api_key, + dimensions, + document_template: Setting::Set(template), })) } new => Ok(new), } } +pub fn validate_embedding_settings( + settings: Setting, + name: &str, +) -> Result> { + let settings = validate_prompt(name, settings)?; + let Setting::Set(settings) = settings else { return Ok(settings) }; + let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = + settings; + let Some(inferred_source) = source.set() else { + return Ok(Setting::Set(EmbeddingSettings { + source, + model, + revision, + api_key, + dimensions, + document_template, + })); + }; + match inferred_source { + EmbedderSource::OpenAi => { + check_unset(&revision, "revision", inferred_source, name)?; + check_unset(&dimensions, "dimensions", inferred_source, name)?; + if let Setting::Set(model) = &model { + crate::vector::openai::EmbeddingModel::from_name(model.as_str()).ok_or( + crate::error::UserError::InvalidOpenAiModel { + embedder_name: name.to_owned(), + model: model.clone(), + }, + )?; + } + } + EmbedderSource::HuggingFace => { + check_unset(&api_key, "apiKey", inferred_source, name)?; + check_unset(&dimensions, "dimensions", inferred_source, name)?; + } + EmbedderSource::UserProvided => { + check_unset(&model, "model", inferred_source, name)?; + check_unset(&revision, "revision", inferred_source, name)?; + check_unset(&api_key, "apiKey", inferred_source, name)?; + check_unset(&document_template, "documentTemplate", inferred_source, name)?; + check_set(&dimensions, "dimensions", inferred_source, name)?; + } + } + Ok(Setting::Set(EmbeddingSettings { + source, + model, + revision, + api_key, + dimensions, + document_template, + })) +} + #[cfg(test)] mod tests { use big_s::S;