4537: Expose distribution shift in settings r=ManyTheFish a=dureuill

See [usage page](https://meilisearch.notion.site/v1-8-AI-search-API-usage-135552d6e85a4a52bc7109be82aeca42#d652adc0890445658aaf36352dbc8802)

# Changes

- Distribution shift added to all embedders.
- Exposed in settings
- Changed the reindexing logic to not trigger a reindex operation when only the distribution shift or API key change

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
meili-bors[bot] 2024-04-03 09:08:58 +00:00 committed by GitHub
commit 56bf8503db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 231 additions and 95 deletions

View File

@ -6,7 +6,7 @@ source: index-scheduler/src/lib.rs
[] []
---------------------------------------------------------------------- ----------------------------------------------------------------------
### All Tasks: ### All Tasks:
0 {uid: 0, status: enqueued, details: { settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> } }, kind: SettingsUpdate { index_uid: "doggos", new_settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> }, is_deletion: false, allow_index_creation: true }} 0 {uid: 0, status: enqueued, details: { settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet, distribution: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> } }, kind: SettingsUpdate { index_uid: "doggos", new_settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet, distribution: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> }, is_deletion: false, allow_index_creation: true }}
---------------------------------------------------------------------- ----------------------------------------------------------------------
### Status: ### Status:
enqueued [0,] enqueued [0,]

View File

@ -6,7 +6,7 @@ source: index-scheduler/src/lib.rs
[] []
---------------------------------------------------------------------- ----------------------------------------------------------------------
### All Tasks: ### All Tasks:
0 {uid: 0, status: succeeded, details: { settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> } }, kind: SettingsUpdate { index_uid: "doggos", new_settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> }, is_deletion: false, allow_index_creation: true }} 0 {uid: 0, status: succeeded, details: { settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet, distribution: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> } }, kind: SettingsUpdate { index_uid: "doggos", new_settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet, distribution: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> }, is_deletion: false, allow_index_creation: true }}
---------------------------------------------------------------------- ----------------------------------------------------------------------
### Status: ### Status:
enqueued [] enqueued []

View File

@ -87,6 +87,38 @@ async fn simple_search() {
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###);
} }
#[actix_rt::test]
async fn distribution_shift() {
let server = Server::new().await;
let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await;
let search = json!({"q": "Captain", "vector": [1.0, 1.0], "showRankingScore": true, "hybrid": {"semanticRatio": 1.0}});
let (response, code) = index.search_post(search.clone()).await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.990290343761444,"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":0.974341630935669,"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":0.9472135901451112,"_semanticScore":0.9472136}]"###);
let (response, code) = index
.update_settings(json!({
"embedders": {
"default": {
"distribution": {
"mean": 0.998,
"sigma": 0.01
}
}
}
}))
.await;
snapshot!(code, @"202 Accepted");
let response = server.wait_task(response.uid()).await;
snapshot!(response["details"], @r###"{"embedders":{"default":{"distribution":{"mean":0.998,"sigma":0.01}}}}"###);
let (response, code) = index.search_post(search).await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.19161224365234375,"_semanticScore":0.19161224},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":1.1920928955078125e-7,"_semanticScore":1.1920929e-7},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.1920928955078125e-7,"_semanticScore":1.1920929e-7}]"###);
}
#[actix_rt::test] #[actix_rt::test]
async fn highlighter() { async fn highlighter() {
let server = Server::new().await; let server = Server::new().await;

View File

@ -2652,6 +2652,7 @@ mod tests {
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet, embedding_object: Setting::NotSet,
input_type: Setting::NotSet, input_type: Setting::NotSet,
distribution: Setting::NotSet,
}), }),
); );
settings.set_embedder_settings(embedders); settings.set_embedder_settings(embedders);

View File

@ -976,7 +976,12 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
match joined { match joined {
// updated config // updated config
EitherOrBoth::Both((name, mut old), (_, new)) => { EitherOrBoth::Both((name, mut old), (_, new)) => {
changed |= old.apply(new); changed |= EmbeddingSettings::apply_and_need_reindex(&mut old, new);
if changed {
tracing::debug!(embedder = name, "need reindex");
} else {
tracing::debug!(embedder = name, "skip reindex");
}
let new = validate_embedding_settings(old, &name)?; let new = validate_embedding_settings(old, &name)?;
new_configs.insert(name, new); new_configs.insert(name, new);
} }
@ -1169,6 +1174,7 @@ fn validate_prompt(
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
}) => { }) => {
// validate // validate
let template = crate::prompt::Prompt::new(template) let template = crate::prompt::Prompt::new(template)
@ -1188,6 +1194,7 @@ fn validate_prompt(
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
})) }))
} }
new => Ok(new), new => Ok(new),
@ -1213,6 +1220,7 @@ pub fn validate_embedding_settings(
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
} = settings; } = settings;
if let Some(0) = dimensions.set() { if let Some(0) = dimensions.set() {
@ -1244,6 +1252,7 @@ pub fn validate_embedding_settings(
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
})); }));
}; };
match inferred_source { match inferred_source {
@ -1388,6 +1397,7 @@ pub fn validate_embedding_settings(
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
})) }))
} }

View File

@ -33,6 +33,7 @@ enum WeightSource {
pub struct EmbedderOptions { pub struct EmbedderOptions {
pub model: String, pub model: String,
pub revision: Option<String>, pub revision: Option<String>,
pub distribution: Option<DistributionShift>,
} }
impl EmbedderOptions { impl EmbedderOptions {
@ -40,6 +41,7 @@ impl EmbedderOptions {
Self { Self {
model: "BAAI/bge-base-en-v1.5".to_string(), model: "BAAI/bge-base-en-v1.5".to_string(),
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()), revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
distribution: None,
} }
} }
} }
@ -193,13 +195,15 @@ impl Embedder {
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
if self.options.model == "BAAI/bge-base-en-v1.5" { self.options.distribution.or_else(|| {
Some(DistributionShift { if self.options.model == "BAAI/bge-base-en-v1.5" {
current_mean: ordered_float::OrderedFloat(0.85), Some(DistributionShift {
current_sigma: ordered_float::OrderedFloat(0.1), current_mean: ordered_float::OrderedFloat(0.85),
}) current_sigma: ordered_float::OrderedFloat(0.1),
} else { })
None } else {
} None
}
})
} }
} }

View File

@ -1,19 +1,21 @@
use super::error::EmbedError; use super::error::EmbedError;
use super::Embeddings; use super::{DistributionShift, Embeddings};
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct Embedder { pub struct Embedder {
dimensions: usize, dimensions: usize,
distribution: Option<DistributionShift>,
} }
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions { pub struct EmbedderOptions {
pub dimensions: usize, pub dimensions: usize,
pub distribution: Option<DistributionShift>,
} }
impl Embedder { impl Embedder {
pub fn new(options: EmbedderOptions) -> Self { pub fn new(options: EmbedderOptions) -> Self {
Self { dimensions: options.dimensions } Self { dimensions: options.dimensions, distribution: options.distribution }
} }
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
@ -31,4 +33,8 @@ impl Embedder {
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
} }
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution
}
} }

View File

@ -1,6 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use deserr::{DeserializeError, Deserr};
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -292,7 +293,7 @@ impl Embedder {
Embedder::HuggingFace(embedder) => embedder.distribution(), Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(), Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(), Embedder::Ollama(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None, Embedder::UserProvided(embedder) => embedder.distribution(),
Embedder::Rest(embedder) => embedder.distribution(), Embedder::Rest(embedder) => embedder.distribution(),
} }
} }
@ -317,10 +318,50 @@ pub struct DistributionShift {
pub current_sigma: OrderedFloat<f32>, pub current_sigma: OrderedFloat<f32>,
} }
#[derive(Serialize, Deserialize)] impl<E> Deserr<E> for DistributionShift
where
E: DeserializeError,
{
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef,
) -> Result<Self, E> {
let value = DistributionShiftSerializable::deserialize_from_value(value, location)?;
if value.mean < 0. || value.mean > 1. {
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"the distribution mean must be in the range [0, 1], got {}",
value.mean
),
},
location,
)));
}
if value.sigma <= 0. || value.sigma > 1. {
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"the distribution sigma must be in the range ]0, 1], got {}",
value.sigma
),
},
location,
)));
}
Ok(value.into())
}
}
#[derive(Serialize, Deserialize, Deserr)]
#[serde(deny_unknown_fields)]
#[deserr(deny_unknown_fields)]
struct DistributionShiftSerializable { struct DistributionShiftSerializable {
current_mean: f32, mean: f32,
current_sigma: f32, sigma: f32,
} }
impl From<DistributionShift> for DistributionShiftSerializable { impl From<DistributionShift> for DistributionShiftSerializable {
@ -330,18 +371,13 @@ impl From<DistributionShift> for DistributionShiftSerializable {
current_sigma: OrderedFloat(current_sigma), current_sigma: OrderedFloat(current_sigma),
}: DistributionShift, }: DistributionShift,
) -> Self { ) -> Self {
Self { current_mean, current_sigma } Self { mean: current_mean, sigma: current_sigma }
} }
} }
impl From<DistributionShiftSerializable> for DistributionShift { impl From<DistributionShiftSerializable> for DistributionShift {
fn from( fn from(DistributionShiftSerializable { mean, sigma }: DistributionShiftSerializable) -> Self {
DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable, Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) }
) -> Self {
Self {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}
} }
} }

View File

@ -14,11 +14,12 @@ pub struct EmbedderOptions {
pub embedding_model: String, pub embedding_model: String,
pub url: Option<String>, pub url: Option<String>,
pub api_key: Option<String>, pub api_key: Option<String>,
pub distribution: Option<DistributionShift>,
} }
impl EmbedderOptions { impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>, url: Option<String>) -> Self { pub fn with_default_model(api_key: Option<String>, url: Option<String>) -> Self {
Self { embedding_model: "nomic-embed-text".into(), api_key, url } Self { embedding_model: "nomic-embed-text".into(), api_key, url, distribution: None }
} }
} }
@ -27,8 +28,8 @@ impl Embedder {
let model = options.embedding_model.as_str(); let model = options.embedding_model.as_str();
let rest_embedder = match RestEmbedder::new(RestEmbedderOptions { let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
api_key: options.api_key, api_key: options.api_key,
distribution: None,
dimensions: None, dimensions: None,
distribution: options.distribution,
url: options.url.unwrap_or_else(get_ollama_path), url: options.url.unwrap_or_else(get_ollama_path),
query: serde_json::json!({ query: serde_json::json!({
"model": model, "model": model,
@ -90,7 +91,7 @@ impl Embedder {
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
None self.rest_embedder.distribution()
} }
} }

View File

@ -11,6 +11,7 @@ pub struct EmbedderOptions {
pub api_key: Option<String>, pub api_key: Option<String>,
pub embedding_model: EmbeddingModel, pub embedding_model: EmbeddingModel,
pub dimensions: Option<usize>, pub dimensions: Option<usize>,
pub distribution: Option<DistributionShift>,
} }
impl EmbedderOptions { impl EmbedderOptions {
@ -37,6 +38,10 @@ impl EmbedderOptions {
query query
} }
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution.or(self.embedding_model.distribution())
}
} }
#[derive( #[derive(
@ -139,11 +144,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions { impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>) -> Self { pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default(), dimensions: None } Self { api_key, embedding_model: Default::default(), dimensions: None, distribution: None }
} }
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self { pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model, dimensions: None } Self { api_key, embedding_model, dimensions: None, distribution: None }
} }
} }
@ -170,7 +175,7 @@ impl Embedder {
let rest_embedder = RestEmbedder::new(RestEmbedderOptions { let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
api_key: Some(api_key.clone()), api_key: Some(api_key.clone()),
distribution: options.embedding_model.distribution(), distribution: None,
dimensions: Some(options.dimensions()), dimensions: Some(options.dimensions()),
url: OPENAI_EMBEDDINGS_URL.to_owned(), url: OPENAI_EMBEDDINGS_URL.to_owned(),
query: options.query(), query: options.query(),
@ -256,6 +261,6 @@ impl Embedder {
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution() self.options.distribution()
} }
} }

View File

@ -2,7 +2,7 @@ use deserr::Deserr;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::rest::InputType; use super::rest::InputType;
use super::{ollama, openai}; use super::{ollama, openai, DistributionShift};
use crate::prompt::PromptData; use crate::prompt::PromptData;
use crate::update::Setting; use crate::update::Setting;
use crate::vector::EmbeddingConfig; use crate::vector::EmbeddingConfig;
@ -48,6 +48,9 @@ pub struct EmbeddingSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")] #[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)] #[deserr(default)]
pub input_type: Setting<InputType>, pub input_type: Setting<InputType>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub distribution: Setting<DistributionShift>,
} }
pub fn check_unset<T>( pub fn check_unset<T>(
@ -101,6 +104,8 @@ impl EmbeddingSettings {
pub const EMBEDDING_OBJECT: &'static str = "embeddingObject"; pub const EMBEDDING_OBJECT: &'static str = "embeddingObject";
pub const INPUT_TYPE: &'static str = "inputType"; pub const INPUT_TYPE: &'static str = "inputType";
pub const DISTRIBUTION: &'static str = "distribution";
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] { pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
match field { match field {
Self::SOURCE => &[ Self::SOURCE => &[
@ -132,6 +137,13 @@ impl EmbeddingSettings {
Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest], Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest], Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest],
Self::INPUT_TYPE => &[EmbedderSource::Rest], Self::INPUT_TYPE => &[EmbedderSource::Rest],
Self::DISTRIBUTION => &[
EmbedderSource::HuggingFace,
EmbedderSource::Ollama,
EmbedderSource::OpenAi,
EmbedderSource::Rest,
EmbedderSource::UserProvided,
],
_other => unreachable!("unknown field"), _other => unreachable!("unknown field"),
} }
} }
@ -144,14 +156,24 @@ impl EmbeddingSettings {
Self::API_KEY, Self::API_KEY,
Self::DOCUMENT_TEMPLATE, Self::DOCUMENT_TEMPLATE,
Self::DIMENSIONS, Self::DIMENSIONS,
Self::DISTRIBUTION,
], ],
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => &[
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE] Self::SOURCE,
} Self::MODEL,
EmbedderSource::Ollama => { Self::REVISION,
&[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE, Self::URL, Self::API_KEY] Self::DOCUMENT_TEMPLATE,
} Self::DISTRIBUTION,
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], ],
EmbedderSource::Ollama => &[
Self::SOURCE,
Self::MODEL,
Self::DOCUMENT_TEMPLATE,
Self::URL,
Self::API_KEY,
Self::DISTRIBUTION,
],
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS, Self::DISTRIBUTION],
EmbedderSource::Rest => &[ EmbedderSource::Rest => &[
Self::SOURCE, Self::SOURCE,
Self::API_KEY, Self::API_KEY,
@ -163,6 +185,7 @@ impl EmbeddingSettings {
Self::PATH_TO_EMBEDDINGS, Self::PATH_TO_EMBEDDINGS,
Self::EMBEDDING_OBJECT, Self::EMBEDDING_OBJECT,
Self::INPUT_TYPE, Self::INPUT_TYPE,
Self::DISTRIBUTION,
], ],
} }
} }
@ -187,6 +210,66 @@ impl EmbeddingSettings {
*model = Setting::Set(openai::EmbeddingModel::default().name().to_owned()) *model = Setting::Set(openai::EmbeddingModel::default().name().to_owned())
} }
} }
pub(crate) fn apply_and_need_reindex(
old: &mut Setting<EmbeddingSettings>,
new: Setting<EmbeddingSettings>,
) -> bool {
match (old, new) {
(
Setting::Set(EmbeddingSettings {
source: old_source,
model: old_model,
revision: old_revision,
api_key: old_api_key,
dimensions: old_dimensions,
document_template: old_document_template,
url: old_url,
query: old_query,
input_field: old_input_field,
path_to_embeddings: old_path_to_embeddings,
embedding_object: old_embedding_object,
input_type: old_input_type,
distribution: old_distribution,
}),
Setting::Set(EmbeddingSettings {
source: new_source,
model: new_model,
revision: new_revision,
api_key: new_api_key,
dimensions: new_dimensions,
document_template: new_document_template,
url: new_url,
query: new_query,
input_field: new_input_field,
path_to_embeddings: new_path_to_embeddings,
embedding_object: new_embedding_object,
input_type: new_input_type,
distribution: new_distribution,
}),
) => {
let mut needs_reindex = false;
needs_reindex |= old_source.apply(new_source);
needs_reindex |= old_model.apply(new_model);
needs_reindex |= old_revision.apply(new_revision);
needs_reindex |= old_dimensions.apply(new_dimensions);
needs_reindex |= old_document_template.apply(new_document_template);
needs_reindex |= old_url.apply(new_url);
needs_reindex |= old_query.apply(new_query);
needs_reindex |= old_input_field.apply(new_input_field);
needs_reindex |= old_path_to_embeddings.apply(new_path_to_embeddings);
needs_reindex |= old_embedding_object.apply(new_embedding_object);
needs_reindex |= old_input_type.apply(new_input_type);
old_distribution.apply(new_distribution);
old_api_key.apply(new_api_key);
needs_reindex
}
(Setting::Reset, Setting::Reset) | (_, Setting::NotSet) => false,
_ => true,
}
}
} }
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] #[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
@ -214,58 +297,6 @@ impl std::fmt::Display for EmbedderSource {
} }
} }
impl EmbeddingSettings {
pub fn apply(&mut self, new: Self) {
let EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
} = new;
let old_source = self.source;
self.source.apply(source);
// Reinitialize the whole setting object on a source change
if old_source != self.source {
*self = EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
};
return;
}
self.model.apply(model);
self.revision.apply(revision);
self.api_key.apply(api_key);
self.dimensions.apply(dimensions);
self.document_template.apply(document_template);
self.url.apply(url);
self.query.apply(query);
self.input_field.apply(input_field);
self.path_to_embeddings.apply(path_to_embeddings);
self.embedding_object.apply(embedding_object);
self.input_type.apply(input_type);
}
}
impl From<EmbeddingConfig> for EmbeddingSettings { impl From<EmbeddingConfig> for EmbeddingSettings {
fn from(value: EmbeddingConfig) -> Self { fn from(value: EmbeddingConfig) -> Self {
let EmbeddingConfig { embedder_options, prompt } = value; let EmbeddingConfig { embedder_options, prompt } = value;
@ -283,6 +314,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet, embedding_object: Setting::NotSet,
input_type: Setting::NotSet, input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
}, },
super::EmbedderOptions::OpenAi(options) => Self { super::EmbedderOptions::OpenAi(options) => Self {
source: Setting::Set(EmbedderSource::OpenAi), source: Setting::Set(EmbedderSource::OpenAi),
@ -297,6 +329,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet, embedding_object: Setting::NotSet,
input_type: Setting::NotSet, input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
}, },
super::EmbedderOptions::Ollama(options) => Self { super::EmbedderOptions::Ollama(options) => Self {
source: Setting::Set(EmbedderSource::Ollama), source: Setting::Set(EmbedderSource::Ollama),
@ -311,6 +344,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet, embedding_object: Setting::NotSet,
input_type: Setting::NotSet, input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
}, },
super::EmbedderOptions::UserProvided(options) => Self { super::EmbedderOptions::UserProvided(options) => Self {
source: Setting::Set(EmbedderSource::UserProvided), source: Setting::Set(EmbedderSource::UserProvided),
@ -325,11 +359,10 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet, embedding_object: Setting::NotSet,
input_type: Setting::NotSet, input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
}, },
super::EmbedderOptions::Rest(super::rest::EmbedderOptions { super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key, api_key,
// TODO: support distribution
distribution: _,
dimensions, dimensions,
url, url,
query, query,
@ -337,6 +370,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
}) => Self { }) => Self {
source: Setting::Set(EmbedderSource::Rest), source: Setting::Set(EmbedderSource::Rest),
model: Setting::NotSet, model: Setting::NotSet,
@ -350,6 +384,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::Set(path_to_embeddings), path_to_embeddings: Setting::Set(path_to_embeddings),
embedding_object: Setting::Set(embedding_object), embedding_object: Setting::Set(embedding_object),
input_type: Setting::Set(input_type), input_type: Setting::Set(input_type),
distribution: distribution.map(Setting::Set).unwrap_or_default(),
}, },
} }
} }
@ -371,7 +406,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
} = value; } = value;
if let Some(source) = source.set() { if let Some(source) = source.set() {
match source { match source {
EmbedderSource::OpenAi => { EmbedderSource::OpenAi => {
@ -387,6 +424,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(dimensions) = dimensions.set() { if let Some(dimensions) = dimensions.set() {
options.dimensions = Some(dimensions); options.dimensions = Some(dimensions);
} }
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::OpenAi(options); this.embedder_options = super::EmbedderOptions::OpenAi(options);
} }
EmbedderSource::Ollama => { EmbedderSource::Ollama => {
@ -399,6 +437,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
options.embedding_model = model; options.embedding_model = model;
} }
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::Ollama(options); this.embedder_options = super::EmbedderOptions::Ollama(options);
} }
EmbedderSource::HuggingFace => { EmbedderSource::HuggingFace => {
@ -415,12 +454,14 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(revision) = revision.set() { if let Some(revision) = revision.set() {
options.revision = Some(revision); options.revision = Some(revision);
} }
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::HuggingFace(options); this.embedder_options = super::EmbedderOptions::HuggingFace(options);
} }
EmbedderSource::UserProvided => { EmbedderSource::UserProvided => {
this.embedder_options = this.embedder_options =
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions { super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
dimensions: dimensions.set().unwrap(), dimensions: dimensions.set().unwrap(),
distribution: distribution.set(),
}); });
} }
EmbedderSource::Rest => { EmbedderSource::Rest => {
@ -429,7 +470,6 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
this.embedder_options = this.embedder_options =
super::EmbedderOptions::Rest(super::rest::EmbedderOptions { super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key: api_key.set(), api_key: api_key.set(),
distribution: None,
dimensions: dimensions.set(), dimensions: dimensions.set(),
url: url.set().unwrap(), url: url.set().unwrap(),
query: query.set().unwrap_or(embedder_options.query), query: query.set().unwrap_or(embedder_options.query),
@ -441,6 +481,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
.set() .set()
.unwrap_or(embedder_options.embedding_object), .unwrap_or(embedder_options.embedding_object),
input_type: input_type.set().unwrap_or(embedder_options.input_type), input_type: input_type.set().unwrap_or(embedder_options.input_type),
distribution: distribution.set(),
}) })
} }
} }