WIP multi embedders

fixed template bugs
This commit is contained in:
Louis Dureuil 2023-12-12 21:19:48 +01:00
parent abbe131084
commit 922a640188
No known key found for this signature in database
20 changed files with 438 additions and 158 deletions

View file

@ -3,6 +3,7 @@ use crate::prompt::PromptData;
pub mod error;
pub mod hf;
pub mod manual;
pub mod openai;
pub mod settings;
@ -67,6 +68,7 @@ impl<F> Embeddings<F> {
pub enum Embedder {
HuggingFace(hf::Embedder),
OpenAi(openai::Embedder),
UserProvided(manual::Embedder),
}
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
@ -80,6 +82,7 @@ pub struct EmbeddingConfig {
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
}
impl Default for EmbedderOptions {
@ -93,7 +96,7 @@ impl EmbedderOptions {
Self::HuggingFace(hf::EmbedderOptions::new())
}
pub fn openai(api_key: String) -> Self {
pub fn openai(api_key: Option<String>) -> Self {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
}
}
@ -103,6 +106,9 @@ impl Embedder {
Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
})
}
@ -111,8 +117,9 @@ impl Embedder {
texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts).await,
Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => embedder.embed(texts).await,
Embedder::UserProvided(embedder) => embedder.embed(texts),
}
}
@ -121,8 +128,9 @@ impl Embedder {
text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await,
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
}
}
@ -130,6 +138,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
}
}
@ -137,6 +146,15 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
}
}
pub fn dimensions(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
}
}
}