Add distribution to all embedders

This commit is contained in:
Louis Dureuil 2024-03-27 11:50:22 +01:00
parent 9a95ed619d
commit afd1da5642
No known key found for this signature in database
7 changed files with 41 additions and 18 deletions

View File

@@ -2652,6 +2652,7 @@ mod tests {
path_to_embeddings: Setting::NotSet, path_to_embeddings: Setting::NotSet,
embedding_object: Setting::NotSet, embedding_object: Setting::NotSet,
input_type: Setting::NotSet, input_type: Setting::NotSet,
distribution: Setting::NotSet,
}), }),
); );
settings.set_embedder_settings(embedders); settings.set_embedder_settings(embedders);

View File

@@ -1146,6 +1146,7 @@ fn validate_prompt(
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
}) => { }) => {
// validate // validate
let template = crate::prompt::Prompt::new(template) let template = crate::prompt::Prompt::new(template)
@@ -1165,6 +1166,7 @@ fn validate_prompt(
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
})) }))
} }
new => Ok(new), new => Ok(new),
@@ -1190,6 +1192,7 @@ pub fn validate_embedding_settings(
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
} = settings; } = settings;
if let Some(0) = dimensions.set() { if let Some(0) = dimensions.set() {
@@ -1221,6 +1224,7 @@ pub fn validate_embedding_settings(
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
})); }));
}; };
match inferred_source { match inferred_source {
@@ -1365,6 +1369,7 @@ pub fn validate_embedding_settings(
path_to_embeddings, path_to_embeddings,
embedding_object, embedding_object,
input_type, input_type,
distribution,
})) }))
} }

View File

@@ -33,6 +33,7 @@ enum WeightSource {
pub struct EmbedderOptions { pub struct EmbedderOptions {
pub model: String, pub model: String,
pub revision: Option<String>, pub revision: Option<String>,
pub distribution: Option<DistributionShift>,
} }
impl EmbedderOptions { impl EmbedderOptions {
@@ -40,6 +41,7 @@ impl EmbedderOptions {
Self { Self {
model: "BAAI/bge-base-en-v1.5".to_string(), model: "BAAI/bge-base-en-v1.5".to_string(),
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()), revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
distribution: None,
} }
} }
} }
@@ -193,6 +195,7 @@ impl Embedder {
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
self.options.distribution.or_else(|| {
if self.options.model == "BAAI/bge-base-en-v1.5" { if self.options.model == "BAAI/bge-base-en-v1.5" {
Some(DistributionShift { Some(DistributionShift {
current_mean: ordered_float::OrderedFloat(0.85), current_mean: ordered_float::OrderedFloat(0.85),
@@ -201,5 +204,6 @@ impl Embedder {
} else { } else {
None None
} }
})
} }
} }

View File

@@ -1,19 +1,21 @@
use super::error::EmbedError; use super::error::EmbedError;
use super::Embeddings; use super::{DistributionShift, Embeddings};
#[derive(Debug, Clone, Copy)] #[derive(Debug, Clone, Copy)]
pub struct Embedder { pub struct Embedder {
dimensions: usize, dimensions: usize,
distribution: Option<DistributionShift>,
} }
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions { pub struct EmbedderOptions {
pub dimensions: usize, pub dimensions: usize,
pub distribution: Option<DistributionShift>,
} }
impl Embedder { impl Embedder {
pub fn new(options: EmbedderOptions) -> Self { pub fn new(options: EmbedderOptions) -> Self {
Self { dimensions: options.dimensions } Self { dimensions: options.dimensions, distribution: options.distribution }
} }
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
@@ -31,4 +33,8 @@ impl Embedder {
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> { ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
} }
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution
}
} }

View File

@@ -1,6 +1,7 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use deserr::{DeserializeError, Deserr};
use ordered_float::OrderedFloat; use ordered_float::OrderedFloat;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -292,7 +293,7 @@ impl Embedder {
Embedder::HuggingFace(embedder) => embedder.distribution(), Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(), Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(), Embedder::Ollama(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None, Embedder::UserProvided(embedder) => embedder.distribution(),
Embedder::Rest(embedder) => embedder.distribution(), Embedder::Rest(embedder) => embedder.distribution(),
} }
} }

View File

@@ -14,11 +14,12 @@ pub struct EmbedderOptions {
pub embedding_model: String, pub embedding_model: String,
pub url: Option<String>, pub url: Option<String>,
pub api_key: Option<String>, pub api_key: Option<String>,
pub distribution: Option<DistributionShift>,
} }
impl EmbedderOptions { impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>, url: Option<String>) -> Self { pub fn with_default_model(api_key: Option<String>, url: Option<String>) -> Self {
Self { embedding_model: "nomic-embed-text".into(), api_key, url } Self { embedding_model: "nomic-embed-text".into(), api_key, url, distribution: None }
} }
} }
@@ -27,8 +28,8 @@ impl Embedder {
let model = options.embedding_model.as_str(); let model = options.embedding_model.as_str();
let rest_embedder = match RestEmbedder::new(RestEmbedderOptions { let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
api_key: options.api_key, api_key: options.api_key,
distribution: None,
dimensions: None, dimensions: None,
distribution: options.distribution,
url: options.url.unwrap_or_else(get_ollama_path), url: options.url.unwrap_or_else(get_ollama_path),
query: serde_json::json!({ query: serde_json::json!({
"model": model, "model": model,
@@ -90,7 +91,7 @@ impl Embedder {
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
None self.rest_embedder.distribution()
} }
} }

View File

@@ -11,6 +11,7 @@ pub struct EmbedderOptions {
pub api_key: Option<String>, pub api_key: Option<String>,
pub embedding_model: EmbeddingModel, pub embedding_model: EmbeddingModel,
pub dimensions: Option<usize>, pub dimensions: Option<usize>,
pub distribution: Option<DistributionShift>,
} }
impl EmbedderOptions { impl EmbedderOptions {
@@ -37,6 +38,10 @@ impl EmbedderOptions {
query query
} }
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution.or(self.embedding_model.distribution())
}
} }
#[derive( #[derive(
@@ -139,11 +144,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions { impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>) -> Self { pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default(), dimensions: None } Self { api_key, embedding_model: Default::default(), dimensions: None, distribution: None }
} }
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self { pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model, dimensions: None } Self { api_key, embedding_model, dimensions: None, distribution: None }
} }
} }
@@ -170,7 +175,7 @@ impl Embedder {
let rest_embedder = RestEmbedder::new(RestEmbedderOptions { let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
api_key: Some(api_key.clone()), api_key: Some(api_key.clone()),
distribution: options.embedding_model.distribution(), distribution: None,
dimensions: Some(options.dimensions()), dimensions: Some(options.dimensions()),
url: OPENAI_EMBEDDINGS_URL.to_owned(), url: OPENAI_EMBEDDINGS_URL.to_owned(),
query: options.query(), query: options.query(),
@@ -256,6 +261,6 @@ impl Embedder {
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution() self.options.distribution()
} }
} }