Allow overriding OpenAI's url

This commit is contained in:
Louis Dureuil 2024-07-15 16:20:19 +02:00
parent 82647bcded
commit 65d0c32aa7
No known key found for this signature in database
3 changed files with 29 additions and 11 deletions

View File

@ -1574,7 +1574,6 @@ pub fn validate_embedding_settings(
EmbedderSource::OpenAi => { EmbedderSource::OpenAi => {
check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?; check_unset(&revision, EmbeddingSettings::REVISION, inferred_source, name)?;
check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?;
check_unset(&query, EmbeddingSettings::QUERY, inferred_source, name)?; check_unset(&query, EmbeddingSettings::QUERY, inferred_source, name)?;
check_unset(&input_field, EmbeddingSettings::INPUT_FIELD, inferred_source, name)?; check_unset(&input_field, EmbeddingSettings::INPUT_FIELD, inferred_source, name)?;
check_unset( check_unset(

View File

@ -10,6 +10,7 @@ use crate::ThreadPoolNoAbort;
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions { pub struct EmbedderOptions {
pub url: Option<String>,
pub api_key: Option<String>, pub api_key: Option<String>,
pub embedding_model: EmbeddingModel, pub embedding_model: EmbeddingModel,
pub dimensions: Option<usize>, pub dimensions: Option<usize>,
@ -146,11 +147,13 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions { impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>) -> Self { pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default(), dimensions: None, distribution: None } Self {
api_key,
embedding_model: Default::default(),
dimensions: None,
distribution: None,
url: None,
} }
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model, dimensions: None, distribution: None }
} }
} }
@ -175,11 +178,13 @@ impl Embedder {
&inferred_api_key &inferred_api_key
}); });
let url = options.url.as_deref().unwrap_or(OPENAI_EMBEDDINGS_URL).to_owned();
let rest_embedder = RestEmbedder::new(RestEmbedderOptions { let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
api_key: Some(api_key.clone()), api_key: Some(api_key.clone()),
distribution: None, distribution: None,
dimensions: Some(options.dimensions()), dimensions: Some(options.dimensions()),
url: OPENAI_EMBEDDINGS_URL.to_owned(), url,
query: options.query(), query: options.query(),
input_field: vec!["input".to_owned()], input_field: vec!["input".to_owned()],
input_type: crate::vector::rest::InputType::TextArray, input_type: crate::vector::rest::InputType::TextArray,

View File

@ -166,7 +166,16 @@ impl SettingsDiff {
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
} }
if url.apply(new_url) { if url.apply(new_url) {
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); match source {
// do not regenerate on an url change in OpenAI
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {}
_ => {
ReindexAction::push_action(
&mut reindex_action,
ReindexAction::FullReindex,
);
}
}
} }
if query.apply(new_query) { if query.apply(new_query) {
ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex); ReindexAction::push_action(&mut reindex_action, ReindexAction::FullReindex);
@ -271,7 +280,7 @@ fn apply_default_for_source(
*model = Setting::Reset; *model = Setting::Reset;
*revision = Setting::NotSet; *revision = Setting::NotSet;
*dimensions = Setting::NotSet; *dimensions = Setting::NotSet;
*url = Setting::NotSet; *url = Setting::Reset;
*query = Setting::NotSet; *query = Setting::NotSet;
*input_field = Setting::NotSet; *input_field = Setting::NotSet;
*path_to_embeddings = Setting::NotSet; *path_to_embeddings = Setting::NotSet;
@ -364,7 +373,7 @@ impl EmbeddingSettings {
EmbedderSource::Ollama, EmbedderSource::Ollama,
EmbedderSource::Rest, EmbedderSource::Rest,
], ],
Self::URL => &[EmbedderSource::Ollama, EmbedderSource::Rest], Self::URL => &[EmbedderSource::Ollama, EmbedderSource::Rest, EmbedderSource::OpenAi],
Self::QUERY => &[EmbedderSource::Rest], Self::QUERY => &[EmbedderSource::Rest],
Self::INPUT_FIELD => &[EmbedderSource::Rest], Self::INPUT_FIELD => &[EmbedderSource::Rest],
Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest], Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
@ -390,6 +399,7 @@ impl EmbeddingSettings {
Self::DOCUMENT_TEMPLATE, Self::DOCUMENT_TEMPLATE,
Self::DIMENSIONS, Self::DIMENSIONS,
Self::DISTRIBUTION, Self::DISTRIBUTION,
Self::URL,
], ],
EmbedderSource::HuggingFace => &[ EmbedderSource::HuggingFace => &[
Self::SOURCE, Self::SOURCE,
@ -494,6 +504,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
distribution: distribution.map(Setting::Set).unwrap_or_default(), distribution: distribution.map(Setting::Set).unwrap_or_default(),
}, },
super::EmbedderOptions::OpenAi(super::openai::EmbedderOptions { super::EmbedderOptions::OpenAi(super::openai::EmbedderOptions {
url,
api_key, api_key,
embedding_model, embedding_model,
dimensions, dimensions,
@ -505,7 +516,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
api_key: api_key.map(Setting::Set).unwrap_or_default(), api_key: api_key.map(Setting::Set).unwrap_or_default(),
dimensions: dimensions.map(Setting::Set).unwrap_or_default(), dimensions: dimensions.map(Setting::Set).unwrap_or_default(),
document_template: Setting::Set(prompt.template), document_template: Setting::Set(prompt.template),
url: Setting::NotSet, url: url.map(Setting::Set).unwrap_or_default(),
query: Setting::NotSet, query: Setting::NotSet,
input_field: Setting::NotSet, input_field: Setting::NotSet,
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
@ -608,6 +619,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
options.embedding_model = model; options.embedding_model = model;
} }
} }
if let Some(url) = url.set() {
options.url = Some(url);
}
if let Some(api_key) = api_key.set() { if let Some(api_key) = api_key.set() {
options.api_key = Some(api_key); options.api_key = Some(api_key);
} }