mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-22 21:04:27 +01:00
Add custom headers for REST embedder
This commit is contained in:
parent
22ef2d877f
commit
4654d51e05
@ -2744,6 +2744,7 @@ mod tests {
|
|||||||
request: Setting::NotSet,
|
request: Setting::NotSet,
|
||||||
response: Setting::NotSet,
|
response: Setting::NotSet,
|
||||||
distribution: Setting::NotSet,
|
distribution: Setting::NotSet,
|
||||||
|
headers: Setting::NotSet,
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
settings.set_embedder_settings(embedders);
|
settings.set_embedder_settings(embedders);
|
||||||
|
@ -1544,6 +1544,7 @@ fn validate_prompt(
|
|||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
distribution,
|
distribution,
|
||||||
|
headers,
|
||||||
}) => {
|
}) => {
|
||||||
// validate
|
// validate
|
||||||
let template = crate::prompt::Prompt::new(template)
|
let template = crate::prompt::Prompt::new(template)
|
||||||
@ -1561,6 +1562,7 @@ fn validate_prompt(
|
|||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
distribution,
|
distribution,
|
||||||
|
headers,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
new => Ok(new),
|
new => Ok(new),
|
||||||
@ -1584,6 +1586,7 @@ pub fn validate_embedding_settings(
|
|||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
distribution,
|
distribution,
|
||||||
|
headers,
|
||||||
} = settings;
|
} = settings;
|
||||||
|
|
||||||
if let Some(0) = dimensions.set() {
|
if let Some(0) = dimensions.set() {
|
||||||
@ -1622,6 +1625,7 @@ pub fn validate_embedding_settings(
|
|||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
distribution,
|
distribution,
|
||||||
|
headers,
|
||||||
}));
|
}));
|
||||||
};
|
};
|
||||||
match inferred_source {
|
match inferred_source {
|
||||||
@ -1630,6 +1634,7 @@ pub fn validate_embedding_settings(
|
|||||||
|
|
||||||
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
||||||
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
||||||
|
check_unset(&headers, EmbeddingSettings::HEADERS, inferred_source, name)?;
|
||||||
|
|
||||||
if let Setting::Set(model) = &model {
|
if let Setting::Set(model) = &model {
|
||||||
let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str())
|
let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str())
|
||||||
@ -1669,6 +1674,7 @@ pub fn validate_embedding_settings(
|
|||||||
|
|
||||||
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
||||||
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
||||||
|
check_unset(&headers, EmbeddingSettings::HEADERS, inferred_source, name)?;
|
||||||
}
|
}
|
||||||
EmbedderSource::HuggingFace => {
|
EmbedderSource::HuggingFace => {
|
||||||
check_unset(&api_key, EmbeddingSettings::API_KEY, inferred_source, name)?;
|
check_unset(&api_key, EmbeddingSettings::API_KEY, inferred_source, name)?;
|
||||||
@ -1677,6 +1683,7 @@ pub fn validate_embedding_settings(
|
|||||||
check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?;
|
check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?;
|
||||||
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
||||||
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
||||||
|
check_unset(&headers, EmbeddingSettings::HEADERS, inferred_source, name)?;
|
||||||
}
|
}
|
||||||
EmbedderSource::UserProvided => {
|
EmbedderSource::UserProvided => {
|
||||||
check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?;
|
check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?;
|
||||||
@ -1693,6 +1700,7 @@ pub fn validate_embedding_settings(
|
|||||||
check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?;
|
check_unset(&url, EmbeddingSettings::URL, inferred_source, name)?;
|
||||||
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
check_unset(&request, EmbeddingSettings::REQUEST, inferred_source, name)?;
|
||||||
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
check_unset(&response, EmbeddingSettings::RESPONSE, inferred_source, name)?;
|
||||||
|
check_unset(&headers, EmbeddingSettings::HEADERS, inferred_source, name)?;
|
||||||
}
|
}
|
||||||
EmbedderSource::Rest => {
|
EmbedderSource::Rest => {
|
||||||
check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?;
|
check_unset(&model, EmbeddingSettings::MODEL, inferred_source, name)?;
|
||||||
@ -1713,6 +1721,7 @@ pub fn validate_embedding_settings(
|
|||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
distribution,
|
distribution,
|
||||||
|
headers,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -41,6 +41,7 @@ impl Embedder {
|
|||||||
response: serde_json::json!({
|
response: serde_json::json!({
|
||||||
"embedding": super::rest::RESPONSE_PLACEHOLDER,
|
"embedding": super::rest::RESPONSE_PLACEHOLDER,
|
||||||
}),
|
}),
|
||||||
|
headers: Default::default(),
|
||||||
},
|
},
|
||||||
super::rest::ConfigurationSource::Ollama,
|
super::rest::ConfigurationSource::Ollama,
|
||||||
) {
|
) {
|
||||||
|
@ -195,6 +195,7 @@ impl Embedder {
|
|||||||
super::rest::REPEAT_PLACEHOLDER
|
super::rest::REPEAT_PLACEHOLDER
|
||||||
]
|
]
|
||||||
}),
|
}),
|
||||||
|
headers: Default::default(),
|
||||||
},
|
},
|
||||||
super::rest::ConfigurationSource::OpenAi,
|
super::rest::ConfigurationSource::OpenAi,
|
||||||
)?;
|
)?;
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use deserr::Deserr;
|
use deserr::Deserr;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
@ -80,6 +82,7 @@ pub struct Embedder {
|
|||||||
struct EmbedderData {
|
struct EmbedderData {
|
||||||
client: ureq::Agent,
|
client: ureq::Agent,
|
||||||
bearer: Option<String>,
|
bearer: Option<String>,
|
||||||
|
headers: BTreeMap<String, String>,
|
||||||
url: String,
|
url: String,
|
||||||
request: Request,
|
request: Request,
|
||||||
response: Response,
|
response: Response,
|
||||||
@ -94,6 +97,7 @@ pub struct EmbedderOptions {
|
|||||||
pub url: String,
|
pub url: String,
|
||||||
pub request: serde_json::Value,
|
pub request: serde_json::Value,
|
||||||
pub response: serde_json::Value,
|
pub response: serde_json::Value,
|
||||||
|
pub headers: BTreeMap<String, String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::hash::Hash for EmbedderOptions {
|
impl std::hash::Hash for EmbedderOptions {
|
||||||
@ -138,6 +142,7 @@ impl Embedder {
|
|||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
configuration_source,
|
configuration_source,
|
||||||
|
headers: options.headers,
|
||||||
};
|
};
|
||||||
|
|
||||||
let dimensions = if let Some(dimensions) = options.dimensions {
|
let dimensions = if let Some(dimensions) = options.dimensions {
|
||||||
@ -223,7 +228,10 @@ where
|
|||||||
} else {
|
} else {
|
||||||
request
|
request
|
||||||
};
|
};
|
||||||
let request = request.set("Content-Type", "application/json");
|
let mut request = request.set("Content-Type", "application/json");
|
||||||
|
for (header, value) in &data.headers {
|
||||||
|
request = request.set(header.as_str(), value.as_str());
|
||||||
|
}
|
||||||
|
|
||||||
let body = data.request.inject_texts(inputs);
|
let body = data.request.inject_texts(inputs);
|
||||||
|
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use deserr::Deserr;
|
use deserr::Deserr;
|
||||||
use roaring::RoaringBitmap;
|
use roaring::RoaringBitmap;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
@ -41,6 +43,9 @@ pub struct EmbeddingSettings {
|
|||||||
pub response: Setting<serde_json::Value>,
|
pub response: Setting<serde_json::Value>,
|
||||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
#[deserr(default)]
|
#[deserr(default)]
|
||||||
|
pub headers: Setting<BTreeMap<String, String>>,
|
||||||
|
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||||
|
#[deserr(default)]
|
||||||
pub distribution: Setting<DistributionShift>,
|
pub distribution: Setting<DistributionShift>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -105,6 +110,7 @@ impl SettingsDiff {
|
|||||||
mut request,
|
mut request,
|
||||||
mut response,
|
mut response,
|
||||||
mut distribution,
|
mut distribution,
|
||||||
|
mut headers,
|
||||||
} = old;
|
} = old;
|
||||||
|
|
||||||
let EmbeddingSettings {
|
let EmbeddingSettings {
|
||||||
@ -118,6 +124,7 @@ impl SettingsDiff {
|
|||||||
request: new_request,
|
request: new_request,
|
||||||
response: new_response,
|
response: new_response,
|
||||||
distribution: new_distribution,
|
distribution: new_distribution,
|
||||||
|
headers: new_headers,
|
||||||
} = new;
|
} = new;
|
||||||
|
|
||||||
let mut reindex_action = None;
|
let mut reindex_action = None;
|
||||||
@ -135,6 +142,7 @@ impl SettingsDiff {
|
|||||||
&mut request,
|
&mut request,
|
||||||
&mut response,
|
&mut response,
|
||||||
&mut document_template,
|
&mut document_template,
|
||||||
|
&mut headers,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if model.apply(new_model) {
|
if model.apply(new_model) {
|
||||||
@ -173,6 +181,7 @@ impl SettingsDiff {
|
|||||||
|
|
||||||
distribution.apply(new_distribution);
|
distribution.apply(new_distribution);
|
||||||
api_key.apply(new_api_key);
|
api_key.apply(new_api_key);
|
||||||
|
headers.apply(new_headers);
|
||||||
|
|
||||||
let updated_settings = EmbeddingSettings {
|
let updated_settings = EmbeddingSettings {
|
||||||
source,
|
source,
|
||||||
@ -185,6 +194,7 @@ impl SettingsDiff {
|
|||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
distribution,
|
distribution,
|
||||||
|
headers,
|
||||||
};
|
};
|
||||||
|
|
||||||
match reindex_action {
|
match reindex_action {
|
||||||
@ -218,6 +228,7 @@ fn apply_default_for_source(
|
|||||||
request: &mut Setting<serde_json::Value>,
|
request: &mut Setting<serde_json::Value>,
|
||||||
response: &mut Setting<serde_json::Value>,
|
response: &mut Setting<serde_json::Value>,
|
||||||
document_template: &mut Setting<String>,
|
document_template: &mut Setting<String>,
|
||||||
|
headers: &mut Setting<BTreeMap<String, String>>,
|
||||||
) {
|
) {
|
||||||
match source {
|
match source {
|
||||||
Setting::Set(EmbedderSource::HuggingFace) => {
|
Setting::Set(EmbedderSource::HuggingFace) => {
|
||||||
@ -227,6 +238,7 @@ fn apply_default_for_source(
|
|||||||
*url = Setting::NotSet;
|
*url = Setting::NotSet;
|
||||||
*request = Setting::NotSet;
|
*request = Setting::NotSet;
|
||||||
*response = Setting::NotSet;
|
*response = Setting::NotSet;
|
||||||
|
*headers = Setting::NotSet;
|
||||||
}
|
}
|
||||||
Setting::Set(EmbedderSource::Ollama) => {
|
Setting::Set(EmbedderSource::Ollama) => {
|
||||||
*model = Setting::Reset;
|
*model = Setting::Reset;
|
||||||
@ -235,6 +247,7 @@ fn apply_default_for_source(
|
|||||||
*url = Setting::NotSet;
|
*url = Setting::NotSet;
|
||||||
*request = Setting::NotSet;
|
*request = Setting::NotSet;
|
||||||
*response = Setting::NotSet;
|
*response = Setting::NotSet;
|
||||||
|
*headers = Setting::NotSet;
|
||||||
}
|
}
|
||||||
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {
|
Setting::Set(EmbedderSource::OpenAi) | Setting::Reset => {
|
||||||
*model = Setting::Reset;
|
*model = Setting::Reset;
|
||||||
@ -243,6 +256,7 @@ fn apply_default_for_source(
|
|||||||
*url = Setting::Reset;
|
*url = Setting::Reset;
|
||||||
*request = Setting::NotSet;
|
*request = Setting::NotSet;
|
||||||
*response = Setting::NotSet;
|
*response = Setting::NotSet;
|
||||||
|
*headers = Setting::NotSet;
|
||||||
}
|
}
|
||||||
Setting::Set(EmbedderSource::Rest) => {
|
Setting::Set(EmbedderSource::Rest) => {
|
||||||
*model = Setting::NotSet;
|
*model = Setting::NotSet;
|
||||||
@ -251,6 +265,7 @@ fn apply_default_for_source(
|
|||||||
*url = Setting::Reset;
|
*url = Setting::Reset;
|
||||||
*request = Setting::Reset;
|
*request = Setting::Reset;
|
||||||
*response = Setting::Reset;
|
*response = Setting::Reset;
|
||||||
|
*headers = Setting::Reset;
|
||||||
}
|
}
|
||||||
Setting::Set(EmbedderSource::UserProvided) => {
|
Setting::Set(EmbedderSource::UserProvided) => {
|
||||||
*model = Setting::NotSet;
|
*model = Setting::NotSet;
|
||||||
@ -260,6 +275,7 @@ fn apply_default_for_source(
|
|||||||
*request = Setting::NotSet;
|
*request = Setting::NotSet;
|
||||||
*response = Setting::NotSet;
|
*response = Setting::NotSet;
|
||||||
*document_template = Setting::NotSet;
|
*document_template = Setting::NotSet;
|
||||||
|
*headers = Setting::NotSet;
|
||||||
}
|
}
|
||||||
Setting::NotSet => {}
|
Setting::NotSet => {}
|
||||||
}
|
}
|
||||||
@ -293,6 +309,7 @@ impl EmbeddingSettings {
|
|||||||
pub const URL: &'static str = "url";
|
pub const URL: &'static str = "url";
|
||||||
pub const REQUEST: &'static str = "request";
|
pub const REQUEST: &'static str = "request";
|
||||||
pub const RESPONSE: &'static str = "response";
|
pub const RESPONSE: &'static str = "response";
|
||||||
|
pub const HEADERS: &'static str = "headers";
|
||||||
|
|
||||||
pub const DISTRIBUTION: &'static str = "distribution";
|
pub const DISTRIBUTION: &'static str = "distribution";
|
||||||
|
|
||||||
@ -324,6 +341,7 @@ impl EmbeddingSettings {
|
|||||||
Self::URL => &[EmbedderSource::Ollama, EmbedderSource::Rest, EmbedderSource::OpenAi],
|
Self::URL => &[EmbedderSource::Ollama, EmbedderSource::Rest, EmbedderSource::OpenAi],
|
||||||
Self::REQUEST => &[EmbedderSource::Rest],
|
Self::REQUEST => &[EmbedderSource::Rest],
|
||||||
Self::RESPONSE => &[EmbedderSource::Rest],
|
Self::RESPONSE => &[EmbedderSource::Rest],
|
||||||
|
Self::HEADERS => &[EmbedderSource::Rest],
|
||||||
Self::DISTRIBUTION => &[
|
Self::DISTRIBUTION => &[
|
||||||
EmbedderSource::HuggingFace,
|
EmbedderSource::HuggingFace,
|
||||||
EmbedderSource::Ollama,
|
EmbedderSource::Ollama,
|
||||||
@ -370,6 +388,7 @@ impl EmbeddingSettings {
|
|||||||
Self::URL,
|
Self::URL,
|
||||||
Self::REQUEST,
|
Self::REQUEST,
|
||||||
Self::RESPONSE,
|
Self::RESPONSE,
|
||||||
|
Self::HEADERS,
|
||||||
Self::DISTRIBUTION,
|
Self::DISTRIBUTION,
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@ -440,6 +459,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
url: Setting::NotSet,
|
url: Setting::NotSet,
|
||||||
request: Setting::NotSet,
|
request: Setting::NotSet,
|
||||||
response: Setting::NotSet,
|
response: Setting::NotSet,
|
||||||
|
headers: Setting::NotSet,
|
||||||
distribution: distribution.map(Setting::Set).unwrap_or_default(),
|
distribution: distribution.map(Setting::Set).unwrap_or_default(),
|
||||||
},
|
},
|
||||||
super::EmbedderOptions::OpenAi(super::openai::EmbedderOptions {
|
super::EmbedderOptions::OpenAi(super::openai::EmbedderOptions {
|
||||||
@ -458,6 +478,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
url: url.map(Setting::Set).unwrap_or_default(),
|
url: url.map(Setting::Set).unwrap_or_default(),
|
||||||
request: Setting::NotSet,
|
request: Setting::NotSet,
|
||||||
response: Setting::NotSet,
|
response: Setting::NotSet,
|
||||||
|
headers: Setting::NotSet,
|
||||||
distribution: distribution.map(Setting::Set).unwrap_or_default(),
|
distribution: distribution.map(Setting::Set).unwrap_or_default(),
|
||||||
},
|
},
|
||||||
super::EmbedderOptions::Ollama(super::ollama::EmbedderOptions {
|
super::EmbedderOptions::Ollama(super::ollama::EmbedderOptions {
|
||||||
@ -475,6 +496,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
url: url.map(Setting::Set).unwrap_or_default(),
|
url: url.map(Setting::Set).unwrap_or_default(),
|
||||||
request: Setting::NotSet,
|
request: Setting::NotSet,
|
||||||
response: Setting::NotSet,
|
response: Setting::NotSet,
|
||||||
|
headers: Setting::NotSet,
|
||||||
distribution: distribution.map(Setting::Set).unwrap_or_default(),
|
distribution: distribution.map(Setting::Set).unwrap_or_default(),
|
||||||
},
|
},
|
||||||
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
|
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
|
||||||
@ -490,6 +512,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
url: Setting::NotSet,
|
url: Setting::NotSet,
|
||||||
request: Setting::NotSet,
|
request: Setting::NotSet,
|
||||||
response: Setting::NotSet,
|
response: Setting::NotSet,
|
||||||
|
headers: Setting::NotSet,
|
||||||
distribution: distribution.map(Setting::Set).unwrap_or_default(),
|
distribution: distribution.map(Setting::Set).unwrap_or_default(),
|
||||||
},
|
},
|
||||||
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
|
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
|
||||||
@ -499,6 +522,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
distribution,
|
distribution,
|
||||||
|
headers,
|
||||||
}) => Self {
|
}) => Self {
|
||||||
source: Setting::Set(EmbedderSource::Rest),
|
source: Setting::Set(EmbedderSource::Rest),
|
||||||
model: Setting::NotSet,
|
model: Setting::NotSet,
|
||||||
@ -510,6 +534,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
|||||||
request: Setting::Set(request),
|
request: Setting::Set(request),
|
||||||
response: Setting::Set(response),
|
response: Setting::Set(response),
|
||||||
distribution: distribution.map(Setting::Set).unwrap_or_default(),
|
distribution: distribution.map(Setting::Set).unwrap_or_default(),
|
||||||
|
headers: Setting::Set(headers),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -529,6 +554,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
|
|||||||
request,
|
request,
|
||||||
response,
|
response,
|
||||||
distribution,
|
distribution,
|
||||||
|
headers,
|
||||||
} = value;
|
} = value;
|
||||||
|
|
||||||
if let Some(source) = source.set() {
|
if let Some(source) = source.set() {
|
||||||
@ -598,6 +624,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
|
|||||||
request: request.set().unwrap(),
|
request: request.set().unwrap(),
|
||||||
response: response.set().unwrap(),
|
response: response.set().unwrap(),
|
||||||
distribution: distribution.set(),
|
distribution: distribution.set(),
|
||||||
|
headers: headers.set().unwrap_or_default(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user