From 65d0c32aa7bed61597cb2ffd78836267a8ed2f22 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Mon, 15 Jul 2024 16:20:19 +0200 Subject: [PATCH] Allow overriding OpenAI's url --- milli/src/update/settings.rs | 1 - milli/src/vector/openai.rs | 17 +++++++++++------ milli/src/vector/settings.rs | 22 ++++++++++++++++++---- 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 3ad6e658c..54a25abd5 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1574,7 +1574,6 @@ pub fn validate_embedding_settings( EmbedderSource::OpenAi => { check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?; - check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?; check_unset(&query, EmbeddingSettings::QUERY, inferred_source, name)?; check_unset(&input_field, EmbeddingSettings::INPUT_FIELD, inferred_source, name)?; check_unset( diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index e180aedaa..ea7ea97f6 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -10,6 +10,7 @@ use crate::ThreadPoolNoAbort; #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { + pub url: Option, pub api_key: Option, pub embedding_model: EmbeddingModel, pub dimensions: Option, @@ -146,11 +147,13 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; impl EmbedderOptions { pub fn with_default_model(api_key: Option) -> Self { - Self { api_key, embedding_model: Default::default(), dimensions: None, distribution: None } - } - - pub fn with_embedding_model(api_key: Option, embedding_model: EmbeddingModel) -> Self { - Self { api_key, embedding_model, dimensions: None, distribution: None } + Self { + api_key, + embedding_model: Default::default(), + dimensions: None, + distribution: None, + url: None, + } } } @@ -175,11 +178,13 @@ impl Embedder { &inferred_api_key }); + let url = options.url.as_deref().unwrap_or(OPENAI_EMBEDDINGS_URL).to_owned(); + let rest_embedder = RestEmbedder::new(RestEmbedderOptions { api_key: Some(api_key.clone()), distribution: None, dimensions: Some(options.dimensions()), - url: OPENAI_EMBEDDINGS_URL.to_owned(), + url, query: options.query(), input_field: vec!["input".to_owned()], input_type: crate::vector::rest::InputType::TextArray, diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 9c7fb09b1..4b04e3370 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -166,7 +166,16 @@ impl SettingsDiff { ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); } if url.apply(new_url) { - ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); + match source { + // do not regenerate on an url change in OpenAI + Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {} + _ => { + ReindexAction::push_action( + &mut reindex_action, + ReindexAction::FullReindex, + ); + } + } } if query.apply(new_query) { ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); @@ -271,7 +280,7 @@ fn apply_default_for_source( *model = Setting::Reset; *revision = Setting::NotSet; *dimensions = Setting::NotSet; - *url = Setting::NotSet; + *url = Setting::Reset; *query = Setting::NotSet; *input_field = Setting::NotSet; *path_to_embeddings = Setting::NotSet; @@ -364,7 +373,7 @@ impl EmbeddingSettings { EmbedderSource::Ollama, EmbedderSource::Rest, ], - Self::URL => &[EmbedderSource::Ollama, EmbedderSource::Rest], + Self::URL => &[EmbedderSource::Ollama, EmbedderSource::Rest, EmbedderSource::OpenAi], Self::QUERY => &[EmbedderSource::Rest], Self::INPUT_FIELD => &[EmbedderSource::Rest], Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest], @@ -390,6 +399,7 @@ impl EmbeddingSettings { Self::DOCUMENT_TEMPLATE, Self::DIMENSIONS, Self::DISTRIBUTION, + Self::URL, ], EmbedderSource::HuggingFace => &[ Self::SOURCE, @@ -494,6 +504,7 @@ impl From for EmbeddingSettings { distribution: distribution.map(Setting::Set).unwrap_or_default(), }, super::EmbedderOptions::OpenAi(super::openai::EmbedderOptions { + url, api_key, embedding_model, dimensions, @@ -505,7 +516,7 @@ impl From for EmbeddingSettings { api_key: api_key.map(Setting::Set).unwrap_or_default(), dimensions: dimensions.map(Setting::Set).unwrap_or_default(), document_template: Setting::Set(prompt.template), - url: Setting::NotSet, + url: url.map(Setting::Set).unwrap_or_default(), query: Setting::NotSet, input_field: Setting::NotSet, path_to_embeddings: Setting::NotSet, @@ -608,6 +619,9 @@ impl From for EmbeddingConfig { options.embedding_model = model; } } + if let Some(url) = url.set() { + options.url = Some(url); + } if let Some(api_key) = api_key.set() { options.api_key = Some(api_key); }