diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 1df31fff2..39919d94a 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -2744,6 +2744,7 @@ mod tests { request: Setting::NotSet, response: Setting::NotSet, distribution: Setting::NotSet, + headers: Setting::NotSet, }), ); settings.set_embedder_settings(embedders); diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index e423852f1..2836f4bc9 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1544,6 +1544,7 @@ fn validate_prompt( request, response, distribution, + headers, }) => { // validate let template = crate::prompt::Prompt::new(template) @@ -1561,6 +1562,7 @@ fn validate_prompt( request, response, distribution, + headers, })) } new => Ok(new), @@ -1584,6 +1586,7 @@ pub fn validate_embedding_settings( request, response, distribution, + headers, } = settings; if let Some(0) = dimensions.set() { @@ -1622,6 +1625,7 @@ pub fn validate_embedding_settings( request, response, distribution, + headers, })); }; match inferred_source { @@ -1630,6 +1634,7 @@ pub fn validate_embedding_settings( check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?; check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?; + check_unset(&headers, EmbeddingSettings::HEADERS, inferred_source, name)?; if let Setting::Set(model) = &model { let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str()) @@ -1669,6 +1674,7 @@ pub fn validate_embedding_settings( check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?; check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?; + check_unset(&headers, EmbeddingSettings::HEADERS, inferred_source, name)?; } EmbedderSource::HuggingFace => { check_unset(&api_key, EmbeddingSettings::API_KEY, inferred_source, name)?; @@ -1677,6 +1683,7 @@ pub fn validate_embedding_settings( check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?; check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?; check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?; + check_unset(&headers, EmbeddingSettings::HEADERS, inferred_source, name)?; } EmbedderSource::UserProvided => { check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?; @@ -1693,6 +1700,7 @@ pub fn validate_embedding_settings( check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?; check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?; check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?; + check_unset(&headers, EmbeddingSettings::HEADERS, inferred_source, name)?; } EmbedderSource::Rest => { check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?; @@ -1713,6 +1721,7 @@ pub fn validate_embedding_settings( request, response, distribution, + headers, })) } diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs index 84baac1ba..d8b75342b 100644 --- a/milli/src/vector/ollama.rs +++ b/milli/src/vector/ollama.rs @@ -41,6 +41,7 @@ impl Embedder { response: serde_json::json!({ "embedding": super::rest::RESPONSE_PLACEHOLDER, }), + headers: Default::default(), }, super::rest::ConfigurationSource::Ollama, ) { diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 514ad4a3b..ce63e69d7 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -195,6 +195,7 @@ impl Embedder { super::rest::REPEAT_PLACEHOLDER ] }), + headers: Default::default(), }, super::rest::ConfigurationSource::OpenAi, )?; diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs index 35a7ebc41..593d2b509 100644 --- a/milli/src/vector/rest.rs +++ b/milli/src/vector/rest.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use deserr::Deserr; use rand::Rng; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; @@ -80,6 +82,7 @@ pub struct Embedder { struct EmbedderData { client: ureq::Agent, bearer: Option, + headers: BTreeMap, url: String, request: Request, response: Response, @@ -94,6 +97,7 @@ pub struct EmbedderOptions { pub url: String, pub request: serde_json::Value, pub response: serde_json::Value, + pub headers: BTreeMap, } impl std::hash::Hash for EmbedderOptions { @@ -138,6 +142,7 @@ impl Embedder { request, response, configuration_source, + headers: options.headers, }; let dimensions = if let Some(dimensions) = options.dimensions { @@ -223,7 +228,10 @@ where } else { request }; - let request = request.set("Content-Type", "application/json"); + let mut request = request.set("Content-Type", "application/json"); + for (header, value) in &data.headers { + request = request.set(header.as_str(), value.as_str()); + } let body = data.request.inject_texts(inputs); diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index e15999d4f..ef0c8f7ff 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -1,3 +1,5 @@ +use std::collections::BTreeMap; + use deserr::Deserr; use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; @@ -41,6 +43,9 @@ pub struct EmbeddingSettings { pub response: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] + pub headers: Setting>, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] pub distribution: Setting, } @@ -105,6 +110,7 @@ impl SettingsDiff { mut request, mut response, mut distribution, + mut headers, } = old; let EmbeddingSettings { @@ -118,6 +124,7 @@ impl SettingsDiff { request: new_request, response: new_response, distribution: new_distribution, + headers: new_headers, } = new; let mut reindex_action = None; @@ -135,6 +142,7 @@ impl SettingsDiff { &mut request, &mut response, &mut document_template, + &mut headers, ) } if model.apply(new_model) { @@ -173,6 +181,7 @@ impl SettingsDiff { distribution.apply(new_distribution); api_key.apply(new_api_key); + headers.apply(new_headers); let updated_settings = EmbeddingSettings { source, @@ -185,6 +194,7 @@ impl SettingsDiff { request, response, distribution, + headers, }; match reindex_action { @@ -218,6 +228,7 @@ fn apply_default_for_source( request: &mut Setting, response: &mut Setting, document_template: &mut Setting, + headers: &mut Setting>, ) { match source { Setting::Set(EmbedderSource::HuggingFace) => { @@ -227,6 +238,7 @@ fn apply_default_for_source( *url = Setting::NotSet; *request = Setting::NotSet; *response = Setting::NotSet; + *headers = Setting::NotSet; } Setting::Set(EmbedderSource::Ollama) => { *model = Setting::Reset; @@ -235,6 +247,7 @@ fn apply_default_for_source( *url = Setting::NotSet; *request = Setting::NotSet; *response = Setting::NotSet; + *headers = Setting::NotSet; } Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => { *model = Setting::Reset; @@ -243,6 +256,7 @@ fn apply_default_for_source( *url = Setting::Reset; *request = Setting::NotSet; *response = Setting::NotSet; + *headers = Setting::NotSet; } Setting::Set(EmbedderSource::Rest) => { *model = Setting::NotSet; @@ -251,6 +265,7 @@ fn apply_default_for_source( *url = Setting::Reset; *request = Setting::Reset; *response = Setting::Reset; + *headers = Setting::Reset; } Setting::Set(EmbedderSource::UserProvided) => { *model = Setting::NotSet; @@ -260,6 +275,7 @@ fn apply_default_for_source( *request = Setting::NotSet; *response = Setting::NotSet; *document_template = Setting::NotSet; + *headers = Setting::NotSet; } Setting::NotSet => {} } @@ -293,6 +309,7 @@ impl EmbeddingSettings { pub const URL: &'static str = "url"; pub const REQUEST: &'static str = "request"; pub const RESPONSE: &'static str = "response"; + pub const HEADERS: &'static str = "headers"; pub const DISTRIBUTION: &'static str = "distribution"; @@ -324,6 +341,7 @@ impl EmbeddingSettings { Self::URL => &[EmbedderSource::Ollama, EmbedderSource::Rest, EmbedderSource::OpenAi], Self::REQUEST => &[EmbedderSource::Rest], Self::RESPONSE => &[EmbedderSource::Rest], + Self::HEADERS => &[EmbedderSource::Rest], Self::DISTRIBUTION => &[ EmbedderSource::HuggingFace, EmbedderSource::Ollama, @@ -370,6 +388,7 @@ impl EmbeddingSettings { Self::URL, Self::REQUEST, Self::RESPONSE, + Self::HEADERS, Self::DISTRIBUTION, ], } @@ -440,6 +459,7 @@ impl From for EmbeddingSettings { url: Setting::NotSet, request: Setting::NotSet, response: Setting::NotSet, + headers: Setting::NotSet, distribution: distribution.map(Setting::Set).unwrap_or_default(), }, super::EmbedderOptions::OpenAi(super::openai::EmbedderOptions { @@ -458,6 +478,7 @@ impl From for EmbeddingSettings { url: url.map(Setting::Set).unwrap_or_default(), request: Setting::NotSet, response: Setting::NotSet, + headers: Setting::NotSet, distribution: distribution.map(Setting::Set).unwrap_or_default(), }, super::EmbedderOptions::Ollama(super::ollama::EmbedderOptions { @@ -475,6 +496,7 @@ impl From for EmbeddingSettings { url: url.map(Setting::Set).unwrap_or_default(), request: Setting::NotSet, response: Setting::NotSet, + headers: Setting::NotSet, distribution: distribution.map(Setting::Set).unwrap_or_default(), }, super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions { @@ -490,6 +512,7 @@ impl From for EmbeddingSettings { url: Setting::NotSet, request: Setting::NotSet, response: Setting::NotSet, + headers: Setting::NotSet, distribution: distribution.map(Setting::Set).unwrap_or_default(), }, super::EmbedderOptions::Rest(super::rest::EmbedderOptions { @@ -499,6 +522,7 @@ impl From for EmbeddingSettings { request, response, distribution, + headers, }) => Self { source: Setting::Set(EmbedderSource::Rest), model: Setting::NotSet, @@ -510,6 +534,7 @@ impl From for EmbeddingSettings { request: Setting::Set(request), response: Setting::Set(response), distribution: distribution.map(Setting::Set).unwrap_or_default(), + headers: Setting::Set(headers), }, } } @@ -529,6 +554,7 @@ impl From for EmbeddingConfig { request, response, distribution, + headers, } = value; if let Some(source) = source.set() { @@ -598,6 +624,7 @@ impl From for EmbeddingConfig { request: request.set().unwrap(), response: response.set().unwrap(), distribution: distribution.set(), + headers: headers.set().unwrap_or_default(), }) } }