WIP multi embedders

fixed template bugs
This commit is contained in:
Louis Dureuil 2023-12-12 21:19:48 +01:00
parent abbe131084
commit 922a640188
No known key found for this signature in database
20 changed files with 438 additions and 158 deletions

View file

@ -15,7 +15,7 @@ pub struct Embedder {
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub api_key: String,
pub api_key: Option<String>,
pub embedding_model: EmbeddingModel,
}
@ -68,11 +68,11 @@ impl EmbeddingModel {
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions {
pub fn with_default_model(api_key: String) -> Self {
pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default() }
}
pub fn with_embedding_model(api_key: String, embedding_model: EmbeddingModel) -> Self {
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model }
}
}
@ -80,9 +80,14 @@ impl EmbedderOptions {
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut headers = reqwest::header::HeaderMap::new();
let mut inferred_api_key = Default::default();
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
inferred_api_key = infer_api_key();
&inferred_api_key
});
headers.insert(
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", &options.api_key))
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
.map_err(NewEmbedderError::openai_invalid_api_key_format)?,
);
headers.insert(
@ -315,6 +320,10 @@ impl Embedder {
pub fn prompt_count_in_chunk_hint(&self) -> usize {
10
}
pub fn dimensions(&self) -> usize {
self.options.embedding_model.dimensions()
}
}
// retrying in case of failure
@ -414,3 +423,9 @@ struct OpenAiEmbedding {
// object: String,
// index: usize,
}
fn infer_api_key() -> String {
std::env::var("MEILI_OPENAI_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.unwrap_or_default()
}