Allow actually passing dimensions for OpenAI source

-> make sure the settings change is rejected or the settings task fails when the specified model doesn't support
overriding `dimensions` and the passed `dimensions` differs from the model's default dimensions.
This commit is contained in:
Louis Dureuil 2024-02-07 10:39:19 +01:00
parent 9ac5750096
commit 517f5332d6
No known key found for this signature in database
4 changed files with 47 additions and 9 deletions

View file

@ -1,6 +1,7 @@
use deserr::Deserr;
use serde::{Deserialize, Serialize};
use super::openai;
use crate::prompt::PromptData;
use crate::update::Setting;
use crate::vector::EmbeddingConfig;
@ -82,7 +83,7 @@ impl EmbeddingSettings {
Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
Self::REVISION => &[EmbedderSource::HuggingFace],
Self::API_KEY => &[EmbedderSource::OpenAi],
Self::DIMENSIONS => &[EmbedderSource::UserProvided],
Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided],
Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
_other => unreachable!("unknown field"),
}
@ -90,9 +91,13 @@ impl EmbeddingSettings {
pub fn allowed_fields_for_source(source: EmbedderSource) -> &'static [&'static str] {
match source {
EmbedderSource::OpenAi => {
&[Self::SOURCE, Self::MODEL, Self::API_KEY, Self::DOCUMENT_TEMPLATE]
}
EmbedderSource::OpenAi => &[
Self::SOURCE,
Self::MODEL,
Self::API_KEY,
Self::DOCUMENT_TEMPLATE,
Self::DIMENSIONS,
],
EmbedderSource::HuggingFace => {
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
}
@ -109,6 +114,17 @@ impl EmbeddingSettings {
*source = Setting::Set(EmbedderSource::default())
}
}
pub(crate) fn apply_default_openai_model(setting: &mut Setting<EmbeddingSettings>) {
if let Setting::Set(EmbeddingSettings {
source: Setting::Set(EmbedderSource::OpenAi),
model: model @ (Setting::NotSet | Setting::Reset),
..
}) = setting
{
*model = Setting::Set(openai::EmbeddingModel::default().name().to_owned())
}
}
}
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]