4275: Flatten settings r=dureuill a=dureuill

# Pull Request

## Related issue
Initial internal feedback seems to indicate that the current shape of the `embedders` setting is undesirable: it has too much depth.

This PR changes this by flattening the structure of the embedders to the following:

```json5
// NEW structure
"embedders": {
  // still starts with the embedder name
  "default": {
    "source": "huggingFace", // now a string
    // properties of the source are all at the same level as the source
    "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "revision": "a9c555277f9bcf24f28fa5e092e665fc6f7c49cd",
    "documentTemplate": "A product titled '{{doc.title}}'" // now a string
  }
}
```

By comparison, the old structure was:

```json5
// PREVIOUS version, no longer working with this PR
"embedders": {
  // still starts with the embedder name
  "default": {
    "source": {
      "huggingFace": {
        "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        "revision": "a9c555277f9bcf24f28fa5e092e665fc6f7c49cd"
      },
    "documentTemplate": { 
      "template": "A product titled '{{doc.title}}'" // now a string
    }
  }
}
```

The fields that are accepted in the new version of the `embedders` setting are depending on the value of the `source` field:

```json5
// huggingFace
"embedders": {
   "default": {
    "source": "huggingFace",
    "model": "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
    "revision": "a9c555277f9bcf24f28fa5e092e665fc6f7c49cd",
    "documentTemplate": "A product titled '{{doc.title}}'"
  }
}

// openAi
"embedders": {
   "default": {
    "source": "openAi",
    "model": "text-embedding-ada-002",
    "apiKey": "open_ai_api_key",
    "documentTemplate": "A product titled '{{doc.title}}'"
  }
}

// userProvided
"embedders": {
   "default": {
    "source": "userProvided",
    "dimensions": 42, // mandatory
  }
}
```

## What does this PR do?
- Flatten the settings structure
- Validate the prompt earlier to return a synchronous error on setting change rather than in the failing task
- Make it an error to pass a field for the wrong source (see above for allowed fields for each source)
- Not changed: It is still an error not to pass `dimensions` to the `userProvided` embedder
- If `source` was specified in the settings, validate the setting early to return a synchronous error in case of a missing mandatory field for the userProvided source (dimensions) or a forbidden field for the specified source.
- If `source` was not specified in the settings, still validate the setting, but only at indexing time, by using the source stored in the DB.
- Resets all values if the source changes, even if the user did not reset them explicitly.

## PR checklist
Please check if your PR fulfills the following requirements:
- [ ] Change the public facing guide for using the API
- [ ] Change examples of use in the changelog


Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
meili-bors[bot] 2023-12-21 09:58:01 +00:00 committed by GitHub
commit d4cb0a885b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 380 additions and 296 deletions

View File

@ -344,7 +344,10 @@ impl ErrorCode for milli::Error {
Code::InvalidDocumentId Code::InvalidDocumentId
} }
UserError::MissingDocumentField(_) => Code::InvalidDocumentFields, UserError::MissingDocumentField(_) => Code::InvalidDocumentFields,
UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, UserError::InvalidFieldForSource { .. }
| UserError::MissingFieldForSource { .. }
| UserError::InvalidOpenAiModel { .. }
| UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders, UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound,

View File

@ -318,6 +318,21 @@ impl Settings<Unchecked> {
_kind: PhantomData, _kind: PhantomData,
} }
} }
pub fn validate(self) -> Result<Self, milli::Error> {
self.validate_embedding_settings()
}
fn validate_embedding_settings(mut self) -> Result<Self, milli::Error> {
let Setting::Set(mut configs) = self.embedders else { return Ok(self) };
for (name, config) in configs.iter_mut() {
let config_to_check = std::mem::take(config);
let checked_config = milli::update::validate_embedding_settings(config_to_check, name)?;
*config = checked_config
}
self.embedders = Setting::Set(configs);
Ok(self)
}
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]

View File

@ -90,6 +90,8 @@ macro_rules! make_setting_route {
..Default::default() ..Default::default()
}; };
let new_settings = new_settings.validate()?;
let allow_index_creation = let allow_index_creation =
index_scheduler.filters().allow_index_creation(&index_uid); index_scheduler.filters().allow_index_creation(&index_uid);
@ -582,13 +584,13 @@ fn embedder_analytics(
for source in s for source in s
.values() .values()
.filter_map(|config| config.clone().set()) .filter_map(|config| config.clone().set())
.filter_map(|config| config.embedder_options.set()) .filter_map(|config| config.source.set())
{ {
use meilisearch_types::milli::vector::settings::EmbedderSettings; use meilisearch_types::milli::vector::settings::EmbedderSource;
match source { match source {
EmbedderSettings::OpenAi(_) => sources.insert("openAi"), EmbedderSource::OpenAi => sources.insert("openAi"),
EmbedderSettings::HuggingFace(_) => sources.insert("huggingFace"), EmbedderSource::HuggingFace => sources.insert("huggingFace"),
EmbedderSettings::UserProvided(_) => sources.insert("userProvided"), EmbedderSource::UserProvided => sources.insert("userProvided"),
}; };
} }
}; };
@ -651,6 +653,7 @@ pub async fn update_all(
let index_uid = IndexUid::try_from(index_uid.into_inner())?; let index_uid = IndexUid::try_from(index_uid.into_inner())?;
let new_settings = body.into_inner(); let new_settings = body.into_inner();
let new_settings = new_settings.validate()?;
analytics.publish( analytics.publish(
"Settings Updated".to_string(), "Settings Updated".to_string(),

View File

@ -21,9 +21,9 @@ async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Inde
"###); "###);
let (response, code) = index let (response, code) = index
.update_settings( .update_settings(json!({ "embedders": {"default": {
json!({ "embedders": {"default": {"source": {"userProvided": {"dimensions": 2}}}} }), "source": "userProvided",
) "dimensions": 2}}} ))
.await; .await;
assert_eq!(202, code, "{:?}", response); assert_eq!(202, code, "{:?}", response);
index.wait_task(response.uid()).await; index.wait_task(response.uid()).await;

View File

@ -890,13 +890,21 @@ async fn experimental_feature_vector_store() {
let (response, code) = index let (response, code) = index
.update_settings(json!({"embedders": { .update_settings(json!({"embedders": {
"manual": { "manual": {
"source": { "source": "userProvided",
"userProvided": {"dimensions": 3} "dimensions": 3,
}
} }
}})) }}))
.await; .await;
meili_snap::snapshot!(response, @r###"
{
"taskUid": 1,
"indexUid": "test",
"status": "enqueued",
"type": "settingsUpdate",
"enqueuedAt": "[date]"
}
"###);
meili_snap::snapshot!(code, @"202 Accepted"); meili_snap::snapshot!(code, @"202 Accepted");
let response = index.wait_task(response.uid()).await; let response = index.wait_task(response.uid()).await;

View File

@ -192,7 +192,7 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
MissingDocumentField(#[from] crate::prompt::error::RenderPromptError), MissingDocumentField(#[from] crate::prompt::error::RenderPromptError),
#[error(transparent)] #[error(transparent)]
InvalidPrompt(#[from] crate::prompt::error::NewPromptError), InvalidPrompt(#[from] crate::prompt::error::NewPromptError),
#[error("Invalid prompt in for embeddings with name '{0}': {1}.")] #[error("`.embedders.{0}.documentTemplate`: Invalid template: {1}.")]
InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError),
#[error("Too many embedders in the configuration. Found {0}, but limited to 256.")] #[error("Too many embedders in the configuration. Found {0}, but limited to 256.")]
TooManyEmbedders(usize), TooManyEmbedders(usize),
@ -200,6 +200,33 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
InvalidEmbedder(String), InvalidEmbedder(String),
#[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")] #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")]
TooManyVectors(String, usize), TooManyVectors(String, usize),
#[error("`.embedders.{embedder_name}`: Field `{field}` unavailable for source `{source_}` (only available for sources: {}). Available fields: {}",
allowed_sources_for_field
.iter()
.map(|accepted| format!("`{}`", accepted))
.collect::<Vec<String>>()
.join(", "),
allowed_fields_for_source
.iter()
.map(|accepted| format!("`{}`", accepted))
.collect::<Vec<String>>()
.join(", ")
)]
InvalidFieldForSource {
embedder_name: String,
source_: crate::vector::settings::EmbedderSource,
field: &'static str,
allowed_fields_for_source: &'static [&'static str],
allowed_sources_for_field: &'static [crate::vector::settings::EmbedderSource],
},
#[error("`.embedders.{embedder_name}.model`: Invalid model `{model}` for OpenAI. Supported models: {:?}", crate::vector::openai::EmbeddingModel::supported_models())]
InvalidOpenAiModel { embedder_name: String, model: String },
#[error("`.embedders.{embedder_name}`: Missing field `{field}` (note: this field is mandatory for source {source_})")]
MissingFieldForSource {
field: &'static str,
source_: crate::vector::settings::EmbedderSource,
embedder_name: String,
},
} }
impl From<crate::vector::Error> for Error { impl From<crate::vector::Error> for Error {

View File

@ -2553,7 +2553,7 @@ mod tests {
/// Vectors must be of the same length. /// Vectors must be of the same length.
#[test] #[test]
fn test_multiple_vectors() { fn test_multiple_vectors() {
use crate::vector::settings::{EmbedderSettings, EmbeddingSettings}; use crate::vector::settings::EmbeddingSettings;
let index = TempIndex::new(); let index = TempIndex::new();
index index
@ -2562,9 +2562,11 @@ mod tests {
embedders.insert( embedders.insert(
"manual".to_string(), "manual".to_string(),
Setting::Set(EmbeddingSettings { Setting::Set(EmbeddingSettings {
embedder_options: Setting::Set(EmbedderSettings::UserProvided( source: Setting::Set(crate::vector::settings::EmbedderSource::UserProvided),
crate::vector::settings::UserProvidedSettings { dimensions: 3 }, model: Setting::NotSet,
)), revision: Setting::NotSet,
api_key: Setting::NotSet,
dimensions: Setting::Set(3),
document_template: Setting::NotSet, document_template: Setting::NotSet,
}), }),
); );

View File

@ -8,7 +8,7 @@ pub use self::index_documents::{
MergeFn, MergeFn,
}; };
pub use self::indexer_config::IndexerConfig; pub use self::indexer_config::IndexerConfig;
pub use self::settings::{Setting, Settings}; pub use self::settings::{validate_embedding_settings, Setting, Settings};
pub use self::update_step::UpdateIndexingStep; pub use self::update_step::UpdateIndexingStep;
pub use self::word_prefix_docids::WordPrefixDocids; pub use self::word_prefix_docids::WordPrefixDocids;
pub use self::words_prefix_integer_docids::WordPrefixIntegerDocids; pub use self::words_prefix_integer_docids::WordPrefixIntegerDocids;

View File

@ -17,7 +17,7 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS
use crate::proximity::ProximityPrecision; use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod; use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep}; use crate::update::{IndexDocuments, UpdateIndexingStep};
use crate::vector::settings::{EmbeddingSettings, PromptSettings}; use crate::vector::settings::{check_set, check_unset, EmbedderSource, EmbeddingSettings};
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs}; use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
use crate::{FieldsIdsMap, Index, OrderBy, Result}; use crate::{FieldsIdsMap, Index, OrderBy, Result};
@ -78,11 +78,19 @@ impl<T> Setting<T> {
} }
} }
pub fn apply(&mut self, new: Self) { /// Returns `true` if applying the new setting changed this setting
pub fn apply(&mut self, new: Self) -> bool
where
T: PartialEq + Eq,
{
if let Setting::NotSet = new { if let Setting::NotSet = new {
return; return false;
}
if self == &new {
return false;
} }
*self = new; *self = new;
true
} }
} }
@ -950,17 +958,23 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
.merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right)) .merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right))
{ {
match joined { match joined {
// updated config
EitherOrBoth::Both((name, mut old), (_, new)) => { EitherOrBoth::Both((name, mut old), (_, new)) => {
old.apply(new); changed |= old.apply(new);
let new = validate_prompt(&name, old)?; let new = validate_embedding_settings(old, &name)?;
changed = true;
new_configs.insert(name, new); new_configs.insert(name, new);
} }
// unchanged config
EitherOrBoth::Left((name, setting)) => { EitherOrBoth::Left((name, setting)) => {
new_configs.insert(name, setting); new_configs.insert(name, setting);
} }
EitherOrBoth::Right((name, setting)) => { // new config
let setting = validate_prompt(&name, setting)?; EitherOrBoth::Right((name, mut setting)) => {
// apply the default source in case the source was not set so that it gets validated
crate::vector::settings::EmbeddingSettings::apply_default_source(
&mut setting,
);
let setting = validate_embedding_settings(setting, &name)?;
changed = true; changed = true;
new_configs.insert(name, setting); new_configs.insert(name, setting);
} }
@ -1072,8 +1086,12 @@ fn validate_prompt(
) -> Result<Setting<EmbeddingSettings>> { ) -> Result<Setting<EmbeddingSettings>> {
match new { match new {
Setting::Set(EmbeddingSettings { Setting::Set(EmbeddingSettings {
embedder_options, source,
document_template: Setting::Set(PromptSettings { template: Setting::Set(template) }), model,
revision,
api_key,
dimensions,
document_template: Setting::Set(template),
}) => { }) => {
// validate // validate
let template = crate::prompt::Prompt::new(template) let template = crate::prompt::Prompt::new(template)
@ -1081,16 +1099,71 @@ fn validate_prompt(
.map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?;
Ok(Setting::Set(EmbeddingSettings { Ok(Setting::Set(EmbeddingSettings {
embedder_options, source,
document_template: Setting::Set(PromptSettings { model,
template: Setting::Set(template), revision,
}), api_key,
dimensions,
document_template: Setting::Set(template),
})) }))
} }
new => Ok(new), new => Ok(new),
} }
} }
pub fn validate_embedding_settings(
settings: Setting<EmbeddingSettings>,
name: &str,
) -> Result<Setting<EmbeddingSettings>> {
let settings = validate_prompt(name, settings)?;
let Setting::Set(settings) = settings else { return Ok(settings) };
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
settings;
let Some(inferred_source) = source.set() else {
return Ok(Setting::Set(EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
}));
};
match inferred_source {
EmbedderSource::OpenAi => {
check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&dimensions, "dimensions", inferred_source, name)?;
if let Setting::Set(model) = &model {
crate::vector::openai::EmbeddingModel::from_name(model.as_str()).ok_or(
crate::error::UserError::InvalidOpenAiModel {
embedder_name: name.to_owned(),
model: model.clone(),
},
)?;
}
}
EmbedderSource::HuggingFace => {
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&dimensions, "dimensions", inferred_source, name)?;
}
EmbedderSource::UserProvided => {
check_unset(&model, "model", inferred_source, name)?;
check_unset(&revision, "revision", inferred_source, name)?;
check_unset(&api_key, "apiKey", inferred_source, name)?;
check_unset(&document_template, "documentTemplate", inferred_source, name)?;
check_set(&dimensions, "dimensions", inferred_source, name)?;
}
}
Ok(Setting::Set(EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
}))
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use big_s::S; use big_s::S;

View File

@ -34,6 +34,9 @@ pub struct EmbedderOptions {
#[serde(deny_unknown_fields, rename_all = "camelCase")] #[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)] #[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbeddingModel { pub enum EmbeddingModel {
// # WARNING
//
// If ever adding a model, make sure to add it to the list of supported models below.
#[default] #[default]
#[serde(rename = "text-embedding-ada-002")] #[serde(rename = "text-embedding-ada-002")]
#[deserr(rename = "text-embedding-ada-002")] #[deserr(rename = "text-embedding-ada-002")]
@ -41,6 +44,10 @@ pub enum EmbeddingModel {
} }
impl EmbeddingModel { impl EmbeddingModel {
pub fn supported_models() -> &'static [&'static str] {
&["text-embedding-ada-002"]
}
pub fn max_token(&self) -> usize { pub fn max_token(&self) -> usize {
match self { match self {
EmbeddingModel::TextEmbeddingAda002 => 8191, EmbeddingModel::TextEmbeddingAda002 => 8191,
@ -59,7 +66,7 @@ impl EmbeddingModel {
} }
} }
pub fn from_name(name: &'static str) -> Option<Self> { pub fn from_name(name: &str) -> Option<Self> {
match name { match name {
"text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
_ => None, _ => None,

View File

@ -4,32 +4,189 @@ use serde::{Deserialize, Serialize};
use crate::prompt::PromptData; use crate::prompt::PromptData;
use crate::update::Setting; use crate::update::Setting;
use crate::vector::EmbeddingConfig; use crate::vector::EmbeddingConfig;
use crate::UserError;
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")] #[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)] #[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct EmbeddingSettings { pub struct EmbeddingSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")]
#[deserr(default, rename = "source")]
pub embedder_options: Setting<EmbedderSettings>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")] #[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)] #[deserr(default)]
pub document_template: Setting<PromptSettings>, pub source: Setting<EmbedderSource>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub model: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub revision: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub api_key: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub dimensions: Setting<usize>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub document_template: Setting<String>,
}
pub fn check_unset<T>(
key: &Setting<T>,
field: &'static str,
source: EmbedderSource,
embedder_name: &str,
) -> Result<(), UserError> {
if matches!(key, Setting::NotSet) {
Ok(())
} else {
Err(UserError::InvalidFieldForSource {
embedder_name: embedder_name.to_owned(),
source_: source,
field,
allowed_fields_for_source: EmbeddingSettings::allowed_fields_for_source(source),
allowed_sources_for_field: EmbeddingSettings::allowed_sources_for_field(field),
})
}
}
pub fn check_set<T>(
key: &Setting<T>,
field: &'static str,
source: EmbedderSource,
embedder_name: &str,
) -> Result<(), UserError> {
if matches!(key, Setting::Set(_)) {
Ok(())
} else {
Err(UserError::MissingFieldForSource {
field,
source_: source,
embedder_name: embedder_name.to_owned(),
})
}
}
impl EmbeddingSettings {
pub const SOURCE: &'static str = "source";
pub const MODEL: &'static str = "model";
pub const REVISION: &'static str = "revision";
pub const API_KEY: &'static str = "apiKey";
pub const DIMENSIONS: &'static str = "dimensions";
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
match field {
Self::SOURCE => {
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided]
}
Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
Self::REVISION => &[EmbedderSource::HuggingFace],
Self::API_KEY => &[EmbedderSource::OpenAi],
Self::DIMENSIONS => &[EmbedderSource::UserProvided],
Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
_other => unreachable!("unknown field"),
}
}
pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] {
match source {
EmbedderSource::OpenAi => {
&[Self::SOURCE, Self::MODEL, Self::API_KEY, Self::DOCUMENT_TEMPLATE]
}
EmbedderSource::HuggingFace => {
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
}
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
}
}
pub(crate) fn apply_default_source(setting: &mut Setting<EmbeddingSettings>) {
if let Setting::Set(EmbeddingSettings {
source: source @ (Setting::NotSet | Setting::Reset),
..
}) = setting
{
*source = Setting::Set(EmbedderSource::default())
}
}
}
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum EmbedderSource {
#[default]
OpenAi,
HuggingFace,
UserProvided,
}
impl std::fmt::Display for EmbedderSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
EmbedderSource::OpenAi => "openAi",
EmbedderSource::HuggingFace => "huggingFace",
EmbedderSource::UserProvided => "userProvided",
};
f.write_str(s)
}
} }
impl EmbeddingSettings { impl EmbeddingSettings {
pub fn apply(&mut self, new: Self) { pub fn apply(&mut self, new: Self) {
let EmbeddingSettings { embedder_options, document_template: prompt } = new; let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
self.embedder_options.apply(embedder_options); new;
self.document_template.apply(prompt); let old_source = self.source;
self.source.apply(source);
// Reinitialize the whole setting object on a source change
if old_source != self.source {
*self = EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
};
return;
}
self.model.apply(model);
self.revision.apply(revision);
self.api_key.apply(api_key);
self.dimensions.apply(dimensions);
self.document_template.apply(document_template);
} }
} }
impl From<EmbeddingConfig> for EmbeddingSettings { impl From<EmbeddingConfig> for EmbeddingSettings {
fn from(value: EmbeddingConfig) -> Self { fn from(value: EmbeddingConfig) -> Self {
Self { let EmbeddingConfig { embedder_options, prompt } = value;
embedder_options: Setting::Set(value.embedder_options.into()), match embedder_options {
document_template: Setting::Set(value.prompt.into()), super::EmbedderOptions::HuggingFace(options) => Self {
source: Setting::Set(EmbedderSource::HuggingFace),
model: Setting::Set(options.model),
revision: options.revision.map(Setting::Set).unwrap_or_default(),
api_key: Setting::NotSet,
dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template),
},
super::EmbedderOptions::OpenAi(options) => Self {
source: Setting::Set(EmbedderSource::OpenAi),
model: Setting::Set(options.embedding_model.name().to_owned()),
revision: Setting::NotSet,
api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
dimensions: Setting::NotSet,
document_template: Setting::Set(prompt.template),
},
super::EmbedderOptions::UserProvided(options) => Self {
source: Setting::Set(EmbedderSource::UserProvided),
model: Setting::NotSet,
revision: Setting::NotSet,
api_key: Setting::NotSet,
dimensions: Setting::Set(options.dimensions),
document_template: Setting::NotSet,
},
} }
} }
} }
@ -37,262 +194,51 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
impl From<EmbeddingSettings> for EmbeddingConfig { impl From<EmbeddingSettings> for EmbeddingConfig {
fn from(value: EmbeddingSettings) -> Self { fn from(value: EmbeddingSettings) -> Self {
let mut this = Self::default(); let mut this = Self::default();
let EmbeddingSettings { embedder_options, document_template: prompt } = value; let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
if let Some(embedder_options) = embedder_options.set() { value;
this.embedder_options = embedder_options.into(); if let Some(source) = source.set() {
} match source {
if let Some(prompt) = prompt.set() { EmbedderSource::OpenAi => {
this.prompt = prompt.into(); let mut options = super::openai::EmbedderOptions::with_default_model(None);
}
this
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct PromptSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub template: Setting<String>,
}
impl PromptSettings {
pub fn apply(&mut self, new: Self) {
let PromptSettings { template } = new;
self.template.apply(template);
}
}
impl From<PromptData> for PromptSettings {
fn from(value: PromptData) -> Self {
Self { template: Setting::Set(value.template) }
}
}
impl From<PromptSettings> for PromptData {
fn from(value: PromptSettings) -> Self {
let mut this = PromptData::default();
let PromptSettings { template } = value;
if let Some(template) = template.set() {
this.template = template;
}
this
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub enum EmbedderSettings {
HuggingFace(Setting<HfEmbedderSettings>),
OpenAi(Setting<OpenAiEmbedderSettings>),
UserProvided(UserProvidedSettings),
}
impl<E> Deserr<E> for EmbedderSettings
where
E: deserr::DeserializeError,
{
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef,
) -> Result<Self, E> {
match value {
deserr::Value::Map(map) => {
if deserr::Map::len(&map) != 1 {
return Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"Expected a single field, got {} fields",
deserr::Map::len(&map)
),
},
location,
)));
}
let mut it = deserr::Map::into_iter(map);
let (k, v) = it.next().unwrap();
match k.as_str() {
"huggingFace" => Ok(EmbedderSettings::HuggingFace(Setting::Set(
HfEmbedderSettings::deserialize_from_value(
v.into_value(),
location.push_key(&k),
)?,
))),
"openAi" => Ok(EmbedderSettings::OpenAi(Setting::Set(
OpenAiEmbedderSettings::deserialize_from_value(
v.into_value(),
location.push_key(&k),
)?,
))),
"userProvided" => Ok(EmbedderSettings::UserProvided(
UserProvidedSettings::deserialize_from_value(
v.into_value(),
location.push_key(&k),
)?,
)),
other => Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::UnknownKey {
key: other,
accepted: &["huggingFace", "openAi", "userProvided"],
},
location,
))),
}
}
_ => Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: value,
accepted: &[deserr::ValueKind::Map],
},
location,
))),
}
}
}
impl Default for EmbedderSettings {
fn default() -> Self {
Self::OpenAi(Default::default())
}
}
impl From<crate::vector::EmbedderOptions> for EmbedderSettings {
fn from(value: crate::vector::EmbedderOptions) -> Self {
match value {
crate::vector::EmbedderOptions::HuggingFace(hf) => {
Self::HuggingFace(Setting::Set(hf.into()))
}
crate::vector::EmbedderOptions::OpenAi(openai) => {
Self::OpenAi(Setting::Set(openai.into()))
}
crate::vector::EmbedderOptions::UserProvided(user_provided) => {
Self::UserProvided(user_provided.into())
}
}
}
}
impl From<EmbedderSettings> for crate::vector::EmbedderOptions {
fn from(value: EmbedderSettings) -> Self {
match value {
EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()),
EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()),
EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()),
EmbedderSettings::OpenAi(_setting) => {
Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None))
}
EmbedderSettings::UserProvided(user_provided) => {
Self::UserProvided(user_provided.into())
}
}
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct HfEmbedderSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub model: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub revision: Setting<String>,
}
impl HfEmbedderSettings {
pub fn apply(&mut self, new: Self) {
let HfEmbedderSettings { model, revision } = new;
self.model.apply(model);
self.revision.apply(revision);
}
}
impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings {
fn from(value: crate::vector::hf::EmbedderOptions) -> Self {
Self {
model: Setting::Set(value.model),
revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet),
}
}
}
impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
fn from(value: HfEmbedderSettings) -> Self {
let HfEmbedderSettings { model, revision } = value;
let mut this = Self::default();
if let Some(model) = model.set() { if let Some(model) = model.set() {
this.model = model; if let Some(model) = super::openai::EmbeddingModel::from_name(&model) {
options.embedding_model = model;
}
}
if let Some(api_key) = api_key.set() {
options.api_key = Some(api_key);
}
this.embedder_options = super::EmbedderOptions::OpenAi(options);
}
EmbedderSource::HuggingFace => {
let mut options = super::hf::EmbedderOptions::default();
if let Some(model) = model.set() {
options.model = model;
// Reset the revision if we are setting the model. // Reset the revision if we are setting the model.
// This allows the following: // This allows the following:
// "huggingFace": {} -> default model with default revision // "huggingFace": {} -> default model with default revision
// "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision // "huggingFace": { "model": "name-of-the-default-model" } -> default model without a revision
// "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision // "huggingFace": { "model": "some-other-model" } -> most importantly, other model without a revision
this.revision = None; options.revision = None;
} }
if let Some(revision) = revision.set() { if let Some(revision) = revision.set() {
this.revision = Some(revision); options.revision = Some(revision);
} }
this.embedder_options = super::EmbedderOptions::HuggingFace(options);
}
EmbedderSource::UserProvided => {
this.embedder_options =
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
dimensions: dimensions.set().unwrap(),
});
}
}
}
if let Setting::Set(template) = document_template {
this.prompt = PromptData { template }
}
this this
} }
} }
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct OpenAiEmbedderSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub api_key: Setting<String>,
#[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "model")]
#[deserr(default, rename = "model")]
pub embedding_model: Setting<crate::vector::openai::EmbeddingModel>,
}
impl OpenAiEmbedderSettings {
pub fn apply(&mut self, new: Self) {
let Self { api_key, embedding_model: embedding_mode } = new;
self.api_key.apply(api_key);
self.embedding_model.apply(embedding_mode);
}
}
impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
fn from(value: crate::vector::openai::EmbedderOptions) -> Self {
Self {
api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset),
embedding_model: Setting::Set(value.embedding_model),
}
}
}
impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions {
fn from(value: OpenAiEmbedderSettings) -> Self {
let OpenAiEmbedderSettings { api_key, embedding_model } = value;
Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() }
}
}
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct UserProvidedSettings {
pub dimensions: usize,
}
impl From<UserProvidedSettings> for crate::vector::manual::EmbedderOptions {
fn from(value: UserProvidedSettings) -> Self {
Self { dimensions: value.dimensions }
}
}
impl From<crate::vector::manual::EmbedderOptions> for UserProvidedSettings {
fn from(value: crate::vector::manual::EmbedderOptions) -> Self {
Self { dimensions: value.dimensions }
}
}