diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 913fbc881..dbacb4002 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -2652,6 +2652,7 @@ mod tests { path_to_embeddings: Setting::NotSet, embedding_object: Setting::NotSet, input_type: Setting::NotSet, + distribution: Setting::NotSet, }), ); settings.set_embedder_settings(embedders); diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 2b1be9453..9f47768c1 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1146,6 +1146,7 @@ fn validate_prompt( path_to_embeddings, embedding_object, input_type, + distribution, }) => { // validate let template = crate::prompt::Prompt::new(template) @@ -1165,6 +1166,7 @@ fn validate_prompt( path_to_embeddings, embedding_object, input_type, + distribution, })) } new => Ok(new), @@ -1190,6 +1192,7 @@ pub fn validate_embedding_settings( path_to_embeddings, embedding_object, input_type, + distribution, } = settings; if let Some(0) = dimensions.set() { @@ -1221,6 +1224,7 @@ pub fn validate_embedding_settings( path_to_embeddings, embedding_object, input_type, + distribution, })); }; match inferred_source { @@ -1365,6 +1369,7 @@ pub fn validate_embedding_settings( path_to_embeddings, embedding_object, input_type, + distribution, })) } diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index e341a553e..725d702ec 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -33,6 +33,7 @@ enum WeightSource { pub struct EmbedderOptions { pub model: String, pub revision: Option, + pub distribution: Option, } impl EmbedderOptions { @@ -40,6 +41,7 @@ impl EmbedderOptions { Self { model: "BAAI/bge-base-en-v1.5".to_string(), revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()), + distribution: None, } } } @@ -193,13 +195,15 @@ impl Embedder { } pub fn distribution(&self) -> Option { - if self.options.model == "BAAI/bge-base-en-v1.5" { - Some(DistributionShift { - current_mean: ordered_float::OrderedFloat(0.85), - current_sigma: ordered_float::OrderedFloat(0.1), - }) - } else { - None - } + self.options.distribution.or_else(|| { + if self.options.model == "BAAI/bge-base-en-v1.5" { + Some(DistributionShift { + current_mean: ordered_float::OrderedFloat(0.85), + current_sigma: ordered_float::OrderedFloat(0.1), + }) + } else { + None + } + }) } } diff --git a/milli/src/vector/manual.rs b/milli/src/vector/manual.rs index 7ed48a251..e5d3689c0 100644 --- a/milli/src/vector/manual.rs +++ b/milli/src/vector/manual.rs @@ -1,19 +1,21 @@ use super::error::EmbedError; -use super::Embeddings; +use super::{DistributionShift, Embeddings}; #[derive(Debug, Clone, Copy)] pub struct Embedder { dimensions: usize, + distribution: Option, } #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { pub dimensions: usize, + pub distribution: Option, } impl Embedder { pub fn new(options: EmbedderOptions) -> Self { - Self { dimensions: options.dimensions } + Self { dimensions: options.dimensions, distribution: options.distribution } } pub fn embed(&self, mut texts: Vec) -> Result>, EmbedError> { @@ -31,4 +33,8 @@ impl Embedder { ) -> Result>>, EmbedError> { text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() } + + pub fn distribution(&self) -> Option { + self.distribution + } } diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 8b25de56d..4a3a9920e 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -1,6 +1,7 @@ use std::collections::HashMap; use std::sync::Arc; +use deserr::{DeserializeError, Deserr}; use ordered_float::OrderedFloat; use serde::{Deserialize, Serialize}; @@ -292,7 +293,7 @@ impl Embedder { Embedder::HuggingFace(embedder) => embedder.distribution(), Embedder::OpenAi(embedder) => embedder.distribution(), Embedder::Ollama(embedder) => embedder.distribution(), - Embedder::UserProvided(_embedder) => None, + Embedder::UserProvided(embedder) => embedder.distribution(), Embedder::Rest(embedder) => embedder.distribution(), } } diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs index 578b6c8e2..cf5030fb4 100644 --- a/milli/src/vector/ollama.rs +++ b/milli/src/vector/ollama.rs @@ -14,11 +14,12 @@ pub struct EmbedderOptions { pub embedding_model: String, pub url: Option, pub api_key: Option, + pub distribution: Option, } impl EmbedderOptions { pub fn with_default_model(api_key: Option, url: Option) -> Self { - Self { embedding_model: "nomic-embed-text".into(), api_key, url } + Self { embedding_model: "nomic-embed-text".into(), api_key, url, distribution: None } } } @@ -27,8 +28,8 @@ impl Embedder { let model = options.embedding_model.as_str(); let rest_embedder = match RestEmbedder::new(RestEmbedderOptions { api_key: options.api_key, - distribution: None, dimensions: None, + distribution: options.distribution, url: options.url.unwrap_or_else(get_ollama_path), query: serde_json::json!({ "model": model, @@ -90,7 +91,7 @@ impl Embedder { } pub fn distribution(&self) -> Option { - None + self.rest_embedder.distribution() } } diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 24e94a9f7..141de486b 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -11,6 +11,7 @@ pub struct EmbedderOptions { pub api_key: Option, pub embedding_model: EmbeddingModel, pub dimensions: Option, + pub distribution: Option, } impl EmbedderOptions { @@ -37,6 +38,10 @@ impl EmbedderOptions { query } + + pub fn distribution(&self) -> Option { + self.distribution.or(self.embedding_model.distribution()) + } } #[derive( @@ -139,11 +144,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; impl EmbedderOptions { pub fn with_default_model(api_key: Option) -> Self { - Self { api_key, embedding_model: Default::default(), dimensions: None } + Self { api_key, embedding_model: Default::default(), dimensions: None, distribution: None } } pub fn with_embedding_model(api_key: Option, embedding_model: EmbeddingModel) -> Self { - Self { api_key, embedding_model, dimensions: None } + Self { api_key, embedding_model, dimensions: None, distribution: None } } } @@ -170,7 +175,7 @@ impl Embedder { let rest_embedder = RestEmbedder::new(RestEmbedderOptions { api_key: Some(api_key.clone()), - distribution: options.embedding_model.distribution(), + distribution: None, dimensions: Some(options.dimensions()), url: OPENAI_EMBEDDINGS_URL.to_owned(), query: options.query(), @@ -256,6 +261,6 @@ impl Embedder { } pub fn distribution(&self) -> Option { - self.options.embedding_model.distribution() + self.options.distribution() } }