4537: Expose distribution shift in settings r=ManyTheFish a=dureuill

See [usage page](https://meilisearch.notion.site/v1-8-AI-search-API-usage-135552d6e85a4a52bc7109be82aeca42#d652adc0890445658aaf36352dbc8802)

# Changes

- Distribution shift added to all embedders.
- Exposed in settings
- Changed the reindexing logic to not trigger a reindex operation when only the distribution shift or API key change

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
meili-bors[bot] 2024-04-03 09:08:58 +00:00 committed by GitHub
commit 56bf8503db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 231 additions and 95 deletions

View File

@ -6,7 +6,7 @@ source: index-scheduler/src/lib.rs
[]
----------------------------------------------------------------------
### All Tasks:
0 {uid: 0, status: enqueued, details: { settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> } }, kind: SettingsUpdate { index_uid: "doggos", new_settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> }, is_deletion: false, allow_index_creation: true }}
0 {uid: 0, status: enqueued, details: { settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet, distribution: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> } }, kind: SettingsUpdate { index_uid: "doggos", new_settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet, distribution: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> }, is_deletion: false, allow_index_creation: true }}
----------------------------------------------------------------------
### Status:
enqueued [0,]

View File

@ -6,7 +6,7 @@ source: index-scheduler/src/lib.rs
[]
----------------------------------------------------------------------
### All Tasks:
0 {uid: 0, status: succeeded, details: { settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> } }, kind: SettingsUpdate { index_uid: "doggos", new_settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> }, is_deletion: false, allow_index_creation: true }}
0 {uid: 0, status: succeeded, details: { settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet, distribution: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> } }, kind: SettingsUpdate { index_uid: "doggos", new_settings: Settings { displayed_attributes: NotSet, searchable_attributes: NotSet, filterable_attributes: NotSet, sortable_attributes: NotSet, ranking_rules: NotSet, stop_words: NotSet, non_separator_tokens: NotSet, separator_tokens: NotSet, dictionary: NotSet, synonyms: NotSet, distinct_attribute: NotSet, proximity_precision: NotSet, typo_tolerance: NotSet, faceting: NotSet, pagination: NotSet, embedders: Set({"default": Set(EmbeddingSettings { source: Set(Rest), model: NotSet, revision: NotSet, api_key: Set("My super secret"), dimensions: NotSet, document_template: NotSet, url: Set("http://localhost:7777"), query: NotSet, input_field: NotSet, path_to_embeddings: NotSet, embedding_object: NotSet, input_type: NotSet, distribution: NotSet })}), search_cutoff_ms: NotSet, _kind: PhantomData<meilisearch_types::settings::Unchecked> }, is_deletion: false, allow_index_creation: true }}
----------------------------------------------------------------------
### Status:
enqueued []

View File

@ -87,6 +87,38 @@ async fn simple_search() {
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###);
}
#[actix_rt::test]
async fn distribution_shift() {
let server = Server::new().await;
let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await;
let search = json!({"q": "Captain", "vector": [1.0, 1.0], "showRankingScore": true, "hybrid": {"semanticRatio": 1.0}});
let (response, code) = index.search_post(search.clone()).await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.990290343761444,"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":0.974341630935669,"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":0.9472135901451112,"_semanticScore":0.9472136}]"###);
let (response, code) = index
.update_settings(json!({
"embedders": {
"default": {
"distribution": {
"mean": 0.998,
"sigma": 0.01
}
}
}
}))
.await;
snapshot!(code, @"202 Accepted");
let response = server.wait_task(response.uid()).await;
snapshot!(response["details"], @r###"{"embedders":{"default":{"distribution":{"mean":0.998,"sigma":0.01}}}}"###);
let (response, code) = index.search_post(search).await;
snapshot!(code, @"200 OK");
snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_rankingScore":0.19161224365234375,"_semanticScore":0.19161224},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_rankingScore":1.1920928955078125e-7,"_semanticScore":1.1920929e-7},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_rankingScore":1.1920928955078125e-7,"_semanticScore":1.1920929e-7}]"###);
}
#[actix_rt::test]
async fn highlighter() {
let server = Server::new().await;

View File

@ -2652,6 +2652,7 @@ mod tests {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: Setting::NotSet,
}),
);
settings.set_embedder_settings(embedders);

View File

@ -976,7 +976,12 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
match joined {
// updated config
EitherOrBoth::Both((name, mut old), (_, new)) => {
changed |= old.apply(new);
changed |= EmbeddingSettings::apply_and_need_reindex(&mut old, new);
if changed {
tracing::debug!(embedder = name, "need reindex");
} else {
tracing::debug!(embedder = name, "skip reindex");
}
let new = validate_embedding_settings(old, &name)?;
new_configs.insert(name, new);
}
@ -1169,6 +1174,7 @@ fn validate_prompt(
path_to_embeddings,
embedding_object,
input_type,
distribution,
}) => {
// validate
let template = crate::prompt::Prompt::new(template)
@ -1188,6 +1194,7 @@ fn validate_prompt(
path_to_embeddings,
embedding_object,
input_type,
distribution,
}))
}
new => Ok(new),
@ -1213,6 +1220,7 @@ pub fn validate_embedding_settings(
path_to_embeddings,
embedding_object,
input_type,
distribution,
} = settings;
if let Some(0) = dimensions.set() {
@ -1244,6 +1252,7 @@ pub fn validate_embedding_settings(
path_to_embeddings,
embedding_object,
input_type,
distribution,
}));
};
match inferred_source {
@ -1388,6 +1397,7 @@ pub fn validate_embedding_settings(
path_to_embeddings,
embedding_object,
input_type,
distribution,
}))
}

View File

@ -33,6 +33,7 @@ enum WeightSource {
pub struct EmbedderOptions {
pub model: String,
pub revision: Option<String>,
pub distribution: Option<DistributionShift>,
}
impl EmbedderOptions {
@ -40,6 +41,7 @@ impl EmbedderOptions {
Self {
model: "BAAI/bge-base-en-v1.5".to_string(),
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
distribution: None,
}
}
}
@ -193,13 +195,15 @@ impl Embedder {
}
pub fn distribution(&self) -> Option<DistributionShift> {
if self.options.model == "BAAI/bge-base-en-v1.5" {
Some(DistributionShift {
current_mean: ordered_float::OrderedFloat(0.85),
current_sigma: ordered_float::OrderedFloat(0.1),
})
} else {
None
}
self.options.distribution.or_else(|| {
if self.options.model == "BAAI/bge-base-en-v1.5" {
Some(DistributionShift {
current_mean: ordered_float::OrderedFloat(0.85),
current_sigma: ordered_float::OrderedFloat(0.1),
})
} else {
None
}
})
}
}

View File

@ -1,19 +1,21 @@
use super::error::EmbedError;
use super::Embeddings;
use super::{DistributionShift, Embeddings};
#[derive(Debug, Clone, Copy)]
pub struct Embedder {
dimensions: usize,
distribution: Option<DistributionShift>,
}
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub dimensions: usize,
pub distribution: Option<DistributionShift>,
}
impl Embedder {
pub fn new(options: EmbedderOptions) -> Self {
Self { dimensions: options.dimensions }
Self { dimensions: options.dimensions, distribution: options.distribution }
}
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
@ -31,4 +33,8 @@ impl Embedder {
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution
}
}

View File

@ -1,6 +1,7 @@
use std::collections::HashMap;
use std::sync::Arc;
use deserr::{DeserializeError, Deserr};
use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize};
@ -292,7 +293,7 @@ impl Embedder {
Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None,
Embedder::UserProvided(embedder) => embedder.distribution(),
Embedder::Rest(embedder) => embedder.distribution(),
}
}
@ -317,10 +318,50 @@ pub struct DistributionShift {
pub current_sigma: OrderedFloat<f32>,
}
#[derive(Serialize, Deserialize)]
impl<E> Deserr<E> for DistributionShift
where
E: DeserializeError,
{
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef,
) -> Result<Self, E> {
let value = DistributionShiftSerializable::deserialize_from_value(value, location)?;
if value.mean < 0. || value.mean > 1. {
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"the distribution mean must be in the range [0, 1], got {}",
value.mean
),
},
location,
)));
}
if value.sigma <= 0. || value.sigma > 1. {
return Err(deserr::take_cf_content(E::error::<std::convert::Infallible>(
None,
deserr::ErrorKind::Unexpected {
msg: format!(
"the distribution sigma must be in the range ]0, 1], got {}",
value.sigma
),
},
location,
)));
}
Ok(value.into())
}
}
#[derive(Serialize, Deserialize, Deserr)]
#[serde(deny_unknown_fields)]
#[deserr(deny_unknown_fields)]
struct DistributionShiftSerializable {
current_mean: f32,
current_sigma: f32,
mean: f32,
sigma: f32,
}
impl From<DistributionShift> for DistributionShiftSerializable {
@ -330,18 +371,13 @@ impl From<DistributionShift> for DistributionShiftSerializable {
current_sigma: OrderedFloat(current_sigma),
}: DistributionShift,
) -> Self {
Self { current_mean, current_sigma }
Self { mean: current_mean, sigma: current_sigma }
}
}
impl From<DistributionShiftSerializable> for DistributionShift {
fn from(
DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable,
) -> Self {
Self {
current_mean: OrderedFloat(current_mean),
current_sigma: OrderedFloat(current_sigma),
}
fn from(DistributionShiftSerializable { mean, sigma }: DistributionShiftSerializable) -> Self {
Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) }
}
}

View File

@ -14,11 +14,12 @@ pub struct EmbedderOptions {
pub embedding_model: String,
pub url: Option<String>,
pub api_key: Option<String>,
pub distribution: Option<DistributionShift>,
}
impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>, url: Option<String>) -> Self {
Self { embedding_model: "nomic-embed-text".into(), api_key, url }
Self { embedding_model: "nomic-embed-text".into(), api_key, url, distribution: None }
}
}
@ -27,8 +28,8 @@ impl Embedder {
let model = options.embedding_model.as_str();
let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
api_key: options.api_key,
distribution: None,
dimensions: None,
distribution: options.distribution,
url: options.url.unwrap_or_else(get_ollama_path),
query: serde_json::json!({
"model": model,
@ -90,7 +91,7 @@ impl Embedder {
}
pub fn distribution(&self) -> Option<DistributionShift> {
None
self.rest_embedder.distribution()
}
}

View File

@ -11,6 +11,7 @@ pub struct EmbedderOptions {
pub api_key: Option<String>,
pub embedding_model: EmbeddingModel,
pub dimensions: Option<usize>,
pub distribution: Option<DistributionShift>,
}
impl EmbedderOptions {
@ -37,6 +38,10 @@ impl EmbedderOptions {
query
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution.or(self.embedding_model.distribution())
}
}
#[derive(
@ -139,11 +144,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default(), dimensions: None }
Self { api_key, embedding_model: Default::default(), dimensions: None, distribution: None }
}
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model, dimensions: None }
Self { api_key, embedding_model, dimensions: None, distribution: None }
}
}
@ -170,7 +175,7 @@ impl Embedder {
let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
api_key: Some(api_key.clone()),
distribution: options.embedding_model.distribution(),
distribution: None,
dimensions: Some(options.dimensions()),
url: OPENAI_EMBEDDINGS_URL.to_owned(),
query: options.query(),
@ -256,6 +261,6 @@ impl Embedder {
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution()
self.options.distribution()
}
}

View File

@ -2,7 +2,7 @@ use deserr::Deserr;
use serde::{Deserialize, Serialize};
use super::rest::InputType;
use super::{ollama, openai};
use super::{ollama, openai, DistributionShift};
use crate::prompt::PromptData;
use crate::update::Setting;
use crate::vector::EmbeddingConfig;
@ -48,6 +48,9 @@ pub struct EmbeddingSettings {
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub input_type: Setting<InputType>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub distribution: Setting<DistributionShift>,
}
pub fn check_unset<T>(
@ -101,6 +104,8 @@ impl EmbeddingSettings {
pub const EMBEDDING_OBJECT: &'static str = "embeddingObject";
pub const INPUT_TYPE: &'static str = "inputType";
pub const DISTRIBUTION: &'static str = "distribution";
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
match field {
Self::SOURCE => &[
@ -132,6 +137,13 @@ impl EmbeddingSettings {
Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest],
Self::INPUT_TYPE => &[EmbedderSource::Rest],
Self::DISTRIBUTION => &[
EmbedderSource::HuggingFace,
EmbedderSource::Ollama,
EmbedderSource::OpenAi,
EmbedderSource::Rest,
EmbedderSource::UserProvided,
],
_other => unreachable!("unknown field"),
}
}
@ -144,14 +156,24 @@ impl EmbeddingSettings {
Self::API_KEY,
Self::DOCUMENT_TEMPLATE,
Self::DIMENSIONS,
Self::DISTRIBUTION,
],
EmbedderSource::HuggingFace => {
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
}
EmbedderSource::Ollama => {
&[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE, Self::URL, Self::API_KEY]
}
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
EmbedderSource::HuggingFace => &[
Self::SOURCE,
Self::MODEL,
Self::REVISION,
Self::DOCUMENT_TEMPLATE,
Self::DISTRIBUTION,
],
EmbedderSource::Ollama => &[
Self::SOURCE,
Self::MODEL,
Self::DOCUMENT_TEMPLATE,
Self::URL,
Self::API_KEY,
Self::DISTRIBUTION,
],
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS, Self::DISTRIBUTION],
EmbedderSource::Rest => &[
Self::SOURCE,
Self::API_KEY,
@ -163,6 +185,7 @@ impl EmbeddingSettings {
Self::PATH_TO_EMBEDDINGS,
Self::EMBEDDING_OBJECT,
Self::INPUT_TYPE,
Self::DISTRIBUTION,
],
}
}
@ -187,6 +210,66 @@ impl EmbeddingSettings {
*model = Setting::Set(openai::EmbeddingModel::default().name().to_owned())
}
}
pub(crate) fn apply_and_need_reindex(
old: &mut Setting<EmbeddingSettings>,
new: Setting<EmbeddingSettings>,
) -> bool {
match (old, new) {
(
Setting::Set(EmbeddingSettings {
source: old_source,
model: old_model,
revision: old_revision,
api_key: old_api_key,
dimensions: old_dimensions,
document_template: old_document_template,
url: old_url,
query: old_query,
input_field: old_input_field,
path_to_embeddings: old_path_to_embeddings,
embedding_object: old_embedding_object,
input_type: old_input_type,
distribution: old_distribution,
}),
Setting::Set(EmbeddingSettings {
source: new_source,
model: new_model,
revision: new_revision,
api_key: new_api_key,
dimensions: new_dimensions,
document_template: new_document_template,
url: new_url,
query: new_query,
input_field: new_input_field,
path_to_embeddings: new_path_to_embeddings,
embedding_object: new_embedding_object,
input_type: new_input_type,
distribution: new_distribution,
}),
) => {
let mut needs_reindex = false;
needs_reindex |= old_source.apply(new_source);
needs_reindex |= old_model.apply(new_model);
needs_reindex |= old_revision.apply(new_revision);
needs_reindex |= old_dimensions.apply(new_dimensions);
needs_reindex |= old_document_template.apply(new_document_template);
needs_reindex |= old_url.apply(new_url);
needs_reindex |= old_query.apply(new_query);
needs_reindex |= old_input_field.apply(new_input_field);
needs_reindex |= old_path_to_embeddings.apply(new_path_to_embeddings);
needs_reindex |= old_embedding_object.apply(new_embedding_object);
needs_reindex |= old_input_type.apply(new_input_type);
old_distribution.apply(new_distribution);
old_api_key.apply(new_api_key);
needs_reindex
}
(Setting::Reset, Setting::Reset) | (_, Setting::NotSet) => false,
_ => true,
}
}
}
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
@ -214,58 +297,6 @@ impl std::fmt::Display for EmbedderSource {
}
}
impl EmbeddingSettings {
pub fn apply(&mut self, new: Self) {
let EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
} = new;
let old_source = self.source;
self.source.apply(source);
// Reinitialize the whole setting object on a source change
if old_source != self.source {
*self = EmbeddingSettings {
source,
model,
revision,
api_key,
dimensions,
document_template,
url,
query,
input_field,
path_to_embeddings,
embedding_object,
input_type,
};
return;
}
self.model.apply(model);
self.revision.apply(revision);
self.api_key.apply(api_key);
self.dimensions.apply(dimensions);
self.document_template.apply(document_template);
self.url.apply(url);
self.query.apply(query);
self.input_field.apply(input_field);
self.path_to_embeddings.apply(path_to_embeddings);
self.embedding_object.apply(embedding_object);
self.input_type.apply(input_type);
}
}
impl From<EmbeddingConfig> for EmbeddingSettings {
fn from(value: EmbeddingConfig) -> Self {
let EmbeddingConfig { embedder_options, prompt } = value;
@ -283,6 +314,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::OpenAi(options) => Self {
source: Setting::Set(EmbedderSource::OpenAi),
@ -297,6 +329,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::Ollama(options) => Self {
source: Setting::Set(EmbedderSource::Ollama),
@ -311,6 +344,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::UserProvided(options) => Self {
source: Setting::Set(EmbedderSource::UserProvided),
@ -325,11 +359,10 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet,
input_type: Setting::NotSet,
distribution: options.distribution.map(Setting::Set).unwrap_or_default(),
},
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key,
// TODO: support distribution
distribution: _,
dimensions,
url,
query,
@ -337,6 +370,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings,
embedding_object,
input_type,
distribution,
}) => Self {
source: Setting::Set(EmbedderSource::Rest),
model: Setting::NotSet,
@ -350,6 +384,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
path_to_embeddings: Setting::Set(path_to_embeddings),
embedding_object: Setting::Set(embedding_object),
input_type: Setting::Set(input_type),
distribution: distribution.map(Setting::Set).unwrap_or_default(),
},
}
}
@ -371,7 +406,9 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
path_to_embeddings,
embedding_object,
input_type,
distribution,
} = value;
if let Some(source) = source.set() {
match source {
EmbedderSource::OpenAi => {
@ -387,6 +424,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(dimensions) = dimensions.set() {
options.dimensions = Some(dimensions);
}
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::OpenAi(options);
}
EmbedderSource::Ollama => {
@ -399,6 +437,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
options.embedding_model = model;
}
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::Ollama(options);
}
EmbedderSource::HuggingFace => {
@ -415,12 +454,14 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
if let Some(revision) = revision.set() {
options.revision = Some(revision);
}
options.distribution = distribution.set();
this.embedder_options = super::EmbedderOptions::HuggingFace(options);
}
EmbedderSource::UserProvided => {
this.embedder_options =
super::EmbedderOptions::UserProvided(super::manual::EmbedderOptions {
dimensions: dimensions.set().unwrap(),
distribution: distribution.set(),
});
}
EmbedderSource::Rest => {
@ -429,7 +470,6 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
this.embedder_options =
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
api_key: api_key.set(),
distribution: None,
dimensions: dimensions.set(),
url: url.set().unwrap(),
query: query.set().unwrap_or(embedder_options.query),
@ -441,6 +481,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
.set()
.unwrap_or(embedder_options.embedding_object),
input_type: input_type.set().unwrap_or(embedder_options.input_type),
distribution: distribution.set(),
})
}
}