mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-12-24 05:30:16 +01:00
Add distribution to all embedders
This commit is contained in:
parent
9a95ed619d
commit
afd1da5642
@ -2652,6 +2652,7 @@ mod tests {
|
|||||||
path_to_embeddings: Setting::NotSet,
|
path_to_embeddings: Setting::NotSet,
|
||||||
embedding_object: Setting::NotSet,
|
embedding_object: Setting::NotSet,
|
||||||
input_type: Setting::NotSet,
|
input_type: Setting::NotSet,
|
||||||
|
distribution: Setting::NotSet,
|
||||||
}),
|
}),
|
||||||
);
|
);
|
||||||
settings.set_embedder_settings(embedders);
|
settings.set_embedder_settings(embedders);
|
||||||
|
@ -1146,6 +1146,7 @@ fn validate_prompt(
|
|||||||
path_to_embeddings,
|
path_to_embeddings,
|
||||||
embedding_object,
|
embedding_object,
|
||||||
input_type,
|
input_type,
|
||||||
|
distribution,
|
||||||
}) => {
|
}) => {
|
||||||
// validate
|
// validate
|
||||||
let template = crate::prompt::Prompt::new(template)
|
let template = crate::prompt::Prompt::new(template)
|
||||||
@ -1165,6 +1166,7 @@ fn validate_prompt(
|
|||||||
path_to_embeddings,
|
path_to_embeddings,
|
||||||
embedding_object,
|
embedding_object,
|
||||||
input_type,
|
input_type,
|
||||||
|
distribution,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
new => Ok(new),
|
new => Ok(new),
|
||||||
@ -1190,6 +1192,7 @@ pub fn validate_embedding_settings(
|
|||||||
path_to_embeddings,
|
path_to_embeddings,
|
||||||
embedding_object,
|
embedding_object,
|
||||||
input_type,
|
input_type,
|
||||||
|
distribution,
|
||||||
} = settings;
|
} = settings;
|
||||||
|
|
||||||
if let Some(0) = dimensions.set() {
|
if let Some(0) = dimensions.set() {
|
||||||
@ -1221,6 +1224,7 @@ pub fn validate_embedding_settings(
|
|||||||
path_to_embeddings,
|
path_to_embeddings,
|
||||||
embedding_object,
|
embedding_object,
|
||||||
input_type,
|
input_type,
|
||||||
|
distribution,
|
||||||
}));
|
}));
|
||||||
};
|
};
|
||||||
match inferred_source {
|
match inferred_source {
|
||||||
@ -1365,6 +1369,7 @@ pub fn validate_embedding_settings(
|
|||||||
path_to_embeddings,
|
path_to_embeddings,
|
||||||
embedding_object,
|
embedding_object,
|
||||||
input_type,
|
input_type,
|
||||||
|
distribution,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -33,6 +33,7 @@ enum WeightSource {
|
|||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub revision: Option<String>,
|
pub revision: Option<String>,
|
||||||
|
pub distribution: Option<DistributionShift>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedderOptions {
|
impl EmbedderOptions {
|
||||||
@ -40,6 +41,7 @@ impl EmbedderOptions {
|
|||||||
Self {
|
Self {
|
||||||
model: "BAAI/bge-base-en-v1.5".to_string(),
|
model: "BAAI/bge-base-en-v1.5".to_string(),
|
||||||
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
|
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
|
||||||
|
distribution: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -193,6 +195,7 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
|
self.options.distribution.or_else(|| {
|
||||||
if self.options.model == "BAAI/bge-base-en-v1.5" {
|
if self.options.model == "BAAI/bge-base-en-v1.5" {
|
||||||
Some(DistributionShift {
|
Some(DistributionShift {
|
||||||
current_mean: ordered_float::OrderedFloat(0.85),
|
current_mean: ordered_float::OrderedFloat(0.85),
|
||||||
@ -201,5 +204,6 @@ impl Embedder {
|
|||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,19 +1,21 @@
|
|||||||
use super::error::EmbedError;
|
use super::error::EmbedError;
|
||||||
use super::Embeddings;
|
use super::{DistributionShift, Embeddings};
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub struct Embedder {
|
pub struct Embedder {
|
||||||
dimensions: usize,
|
dimensions: usize,
|
||||||
|
distribution: Option<DistributionShift>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
pub dimensions: usize,
|
pub dimensions: usize,
|
||||||
|
pub distribution: Option<DistributionShift>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Embedder {
|
impl Embedder {
|
||||||
pub fn new(options: EmbedderOptions) -> Self {
|
pub fn new(options: EmbedderOptions) -> Self {
|
||||||
Self { dimensions: options.dimensions }
|
Self { dimensions: options.dimensions, distribution: options.distribution }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub fn embed(&self, mut texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
@ -31,4 +33,8 @@ impl Embedder {
|
|||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
|
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
|
self.distribution
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use deserr::{DeserializeError, Deserr};
|
||||||
use ordered_float::OrderedFloat;
|
use ordered_float::OrderedFloat;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
@ -292,7 +293,7 @@ impl Embedder {
|
|||||||
Embedder::HuggingFace(embedder) => embedder.distribution(),
|
Embedder::HuggingFace(embedder) => embedder.distribution(),
|
||||||
Embedder::OpenAi(embedder) => embedder.distribution(),
|
Embedder::OpenAi(embedder) => embedder.distribution(),
|
||||||
Embedder::Ollama(embedder) => embedder.distribution(),
|
Embedder::Ollama(embedder) => embedder.distribution(),
|
||||||
Embedder::UserProvided(_embedder) => None,
|
Embedder::UserProvided(embedder) => embedder.distribution(),
|
||||||
Embedder::Rest(embedder) => embedder.distribution(),
|
Embedder::Rest(embedder) => embedder.distribution(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14,11 +14,12 @@ pub struct EmbedderOptions {
|
|||||||
pub embedding_model: String,
|
pub embedding_model: String,
|
||||||
pub url: Option<String>,
|
pub url: Option<String>,
|
||||||
pub api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
|
pub distribution: Option<DistributionShift>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedderOptions {
|
impl EmbedderOptions {
|
||||||
pub fn with_default_model(api_key: Option<String>, url: Option<String>) -> Self {
|
pub fn with_default_model(api_key: Option<String>, url: Option<String>) -> Self {
|
||||||
Self { embedding_model: "nomic-embed-text".into(), api_key, url }
|
Self { embedding_model: "nomic-embed-text".into(), api_key, url, distribution: None }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -27,8 +28,8 @@ impl Embedder {
|
|||||||
let model = options.embedding_model.as_str();
|
let model = options.embedding_model.as_str();
|
||||||
let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
|
let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
|
||||||
api_key: options.api_key,
|
api_key: options.api_key,
|
||||||
distribution: None,
|
|
||||||
dimensions: None,
|
dimensions: None,
|
||||||
|
distribution: options.distribution,
|
||||||
url: options.url.unwrap_or_else(get_ollama_path),
|
url: options.url.unwrap_or_else(get_ollama_path),
|
||||||
query: serde_json::json!({
|
query: serde_json::json!({
|
||||||
"model": model,
|
"model": model,
|
||||||
@ -90,7 +91,7 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
None
|
self.rest_embedder.distribution()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11,6 +11,7 @@ pub struct EmbedderOptions {
|
|||||||
pub api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
pub embedding_model: EmbeddingModel,
|
pub embedding_model: EmbeddingModel,
|
||||||
pub dimensions: Option<usize>,
|
pub dimensions: Option<usize>,
|
||||||
|
pub distribution: Option<DistributionShift>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedderOptions {
|
impl EmbedderOptions {
|
||||||
@ -37,6 +38,10 @@ impl EmbedderOptions {
|
|||||||
|
|
||||||
query
|
query
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
|
self.distribution.or(self.embedding_model.distribution())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(
|
#[derive(
|
||||||
@ -139,11 +144,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
|
|||||||
|
|
||||||
impl EmbedderOptions {
|
impl EmbedderOptions {
|
||||||
pub fn with_default_model(api_key: Option<String>) -> Self {
|
pub fn with_default_model(api_key: Option<String>) -> Self {
|
||||||
Self { api_key, embedding_model: Default::default(), dimensions: None }
|
Self { api_key, embedding_model: Default::default(), dimensions: None, distribution: None }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
|
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
|
||||||
Self { api_key, embedding_model, dimensions: None }
|
Self { api_key, embedding_model, dimensions: None, distribution: None }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,7 +175,7 @@ impl Embedder {
|
|||||||
|
|
||||||
let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
|
let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
|
||||||
api_key: Some(api_key.clone()),
|
api_key: Some(api_key.clone()),
|
||||||
distribution: options.embedding_model.distribution(),
|
distribution: None,
|
||||||
dimensions: Some(options.dimensions()),
|
dimensions: Some(options.dimensions()),
|
||||||
url: OPENAI_EMBEDDINGS_URL.to_owned(),
|
url: OPENAI_EMBEDDINGS_URL.to_owned(),
|
||||||
query: options.query(),
|
query: options.query(),
|
||||||
@ -256,6 +261,6 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
self.options.embedding_model.distribution()
|
self.options.distribution()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user