Revert "Revert "Merge remote-tracking branch 'origin/main' into release-v1.7.1""

This commit is contained in:
Tamo 2024-03-20 10:08:28 +01:00 committed by GitHub
parent c495c8eb33
commit c5322df519
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 1784 additions and 610 deletions

View file

@ -10,6 +10,8 @@ pub mod manual;
pub mod openai;
pub mod settings;
pub mod ollama;
pub use self::error::Error;
pub type Embedding = Vec<f32>;
@ -76,6 +78,7 @@ pub enum Embedder {
HuggingFace(hf::Embedder),
OpenAi(openai::Embedder),
UserProvided(manual::Embedder),
Ollama(ollama::Embedder),
}
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
@ -127,6 +130,7 @@ impl IntoIterator for EmbeddingConfigs {
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions),
Ollama(ollama::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
}
@ -144,6 +148,10 @@ impl EmbedderOptions {
pub fn openai(api_key: Option<String>) -> Self {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
}
pub fn ollama() -> Self {
Self::Ollama(ollama::EmbedderOptions::with_default_model())
}
}
impl Embedder {
@ -151,6 +159,7 @@ impl Embedder {
Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
@ -167,6 +176,10 @@ impl Embedder {
let client = embedder.new_client()?;
embedder.embed(texts, &client).await
}
Embedder::Ollama(embedder) => {
let client = embedder.new_client()?;
embedder.embed(texts, &client).await
}
Embedder::UserProvided(embedder) => embedder.embed(texts),
}
}
@ -181,6 +194,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks),
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
}
}
@ -189,6 +203,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
}
}
@ -197,6 +212,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
}
}
@ -205,6 +221,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::Ollama(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
}
}
@ -213,6 +230,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.distribution(),
Embedder::OpenAi(embedder) => embedder.distribution(),
Embedder::Ollama(embedder) => embedder.distribution(),
Embedder::UserProvided(_embedder) => None,
}
}