Add distribution to all embedders

This commit is contained in:
Louis Dureuil 2024-03-27 11:50:22 +01:00
parent 9a95ed619d
commit afd1da5642
No known key found for this signature in database
7 changed files with 41 additions and 18 deletions

View file

@ -11,6 +11,7 @@ pub struct EmbedderOptions {
pub api_key: Option<String>,
pub embedding_model: EmbeddingModel,
pub dimensions: Option<usize>,
pub distribution: Option<DistributionShift>,
}
impl EmbedderOptions {
@ -37,6 +38,10 @@ impl EmbedderOptions {
query
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.distribution.or(self.embedding_model.distribution())
}
}
#[derive(
@ -139,11 +144,11 @@ pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions {
pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default(), dimensions: None }
Self { api_key, embedding_model: Default::default(), dimensions: None, distribution: None }
}
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model, dimensions: None }
Self { api_key, embedding_model, dimensions: None, distribution: None }
}
}
@ -170,7 +175,7 @@ impl Embedder {
let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
api_key: Some(api_key.clone()),
distribution: options.embedding_model.distribution(),
distribution: None,
dimensions: Some(options.dimensions()),
url: OPENAI_EMBEDDINGS_URL.to_owned(),
query: options.query(),
@ -256,6 +261,6 @@ impl Embedder {
}
pub fn distribution(&self) -> Option<DistributionShift> {
self.options.embedding_model.distribution()
self.options.distribution()
}
}