WIP multi embedders

fixed template bugs
Louis Dureuil 2023-12-12 21:19:48 +01:00
parent abbe131084
commit 922a640188
20 changed files with 438 additions and 158 deletions


@@ -65,6 +65,8 @@ pub enum EmbedErrorKind {
OpenAiTooManyTokens(OpenAiError),
#[error("received unhandled HTTP status code {0} from OpenAI")]
OpenAiUnhandledStatusCode(u16),
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
ManualEmbed(String),
}
impl EmbedError {
@@ -111,6 +113,10 @@ impl EmbedError {
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
}
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
}
}
#[derive(Debug, thiserror::Error)]
@@ -170,6 +176,13 @@ impl NewEmbedderError {
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
}
pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
Self {
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
fault: FaultSource::Runtime,
}
}
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
}
@@ -219,6 +232,8 @@ pub enum NewEmbedderErrorKind {
NewApiFail(ApiError),
#[error("fetching file from HF_HUB failed: {0}")]
ApiGet(ApiError),
#[error("could not determine model dimensions: test embedding failed with {0}")]
CouldNotDetermineDimension(EmbedError),
#[error("loading model failed: {0}")]
LoadModel(candle_core::Error),
// openai
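
The new ManualEmbed error and its embed_on_manual_embedder constructor are consumed by the user-provided embedder introduced by this commit. vector/manual.rs is not part of this excerpt, so the following is only a hypothetical sketch of how that embedder could use them; every name, path, and field below is an assumption for illustration.

use super::error::EmbedError;
use super::Embeddings;

/// Hypothetical user-provided embedder: it never computes embeddings itself.
pub struct Embedder {
    options: EmbedderOptions,
}

#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    pub dimensions: usize,
}

impl Embedder {
    pub fn new(options: EmbedderOptions) -> Self {
        Self { options }
    }

    pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
        // Embeddings must be supplied by the user, so computing them is a user fault.
        let text = texts.into_iter().next().unwrap_or_default();
        Err(EmbedError::embed_on_manual_embedder(text))
    }

    pub fn embed_chunks(
        &self,
        text_chunks: Vec<Vec<String>>,
    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
        text_chunks.into_iter().map(|chunk| self.embed(chunk)).collect()
    }

    pub fn dimensions(&self) -> usize {
        self.options.dimensions
    }
}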


@@ -62,6 +62,7 @@ pub struct Embedder {
model: BertModel,
tokenizer: Tokenizer,
options: EmbedderOptions,
dimensions: usize,
}
impl std::fmt::Debug for Embedder {
@@ -126,10 +127,17 @@ impl Embedder {
tokenizer.with_padding(Some(pp));
}
Ok(Self { model, tokenizer, options })
let mut this = Self { model, tokenizer, options, dimensions: 0 };
let embeddings = this
.embed(vec!["test".into()])
.map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
this.dimensions = embeddings.first().unwrap().dimension();
Ok(this)
}
pub async fn embed(
pub fn embed(
&self,
mut texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
@@ -170,12 +178,11 @@ impl Embedder {
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
}
pub async fn embed_chunks(
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
.await
text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
}
pub fn chunk_count_hint(&self) -> usize {
@@ -185,6 +192,10 @@ impl Embedder {
pub fn prompt_count_in_chunk_hint(&self) -> usize {
std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
}
pub fn dimensions(&self) -> usize {
self.dimensions
}
}
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
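
Since embed and embed_chunks are now synchronous, callers of the Hugging Face embedder no longer need an async runtime and can get parallelism from threads instead, which is what the available_parallelism-based prompt_count_in_chunk_hint above points at. A minimal caller-side sketch, assuming rayon is available to the caller and using illustrative crate paths (neither appears in this diff):

use rayon::prelude::*;

use crate::vector::error::EmbedError;
use crate::vector::{hf, Embeddings};

// Hypothetical caller: embed each chunk of prompts on its own worker thread.
fn embed_all_chunks(
    embedder: &hf::Embedder,
    text_chunks: Vec<Vec<String>>,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
    text_chunks
        .into_par_iter()
        .map(|prompts| embedder.embed(prompts))
        .collect()
}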


@@ -3,6 +3,7 @@ use crate::prompt::PromptData;
pub mod error;
pub mod hf;
pub mod manual;
pub mod openai;
pub mod settings;
@@ -67,6 +68,7 @@ impl<F> Embeddings<F> {
pub enum Embedder {
HuggingFace(hf::Embedder),
OpenAi(openai::Embedder),
UserProvided(manual::Embedder),
}
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
@@ -80,6 +82,7 @@ pub struct EmbeddingConfig {
pub enum EmbedderOptions {
HuggingFace(hf::EmbedderOptions),
OpenAi(openai::EmbedderOptions),
UserProvided(manual::EmbedderOptions),
}
impl Default for EmbedderOptions {
@@ -93,7 +96,7 @@ impl EmbedderOptions {
Self::HuggingFace(hf::EmbedderOptions::new())
}
pub fn openai(api_key: String) -> Self {
pub fn openai(api_key: Option<String>) -> Self {
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
}
}
@@ -103,6 +106,9 @@ impl Embedder {
Ok(match options {
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options))
}
})
}
@@ -111,8 +117,9 @@ impl Embedder {
texts: Vec<String>,
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed(texts).await,
Embedder::HuggingFace(embedder) => embedder.embed(texts),
Embedder::OpenAi(embedder) => embedder.embed(texts).await,
Embedder::UserProvided(embedder) => embedder.embed(texts),
}
}
@@ -121,8 +128,9 @@ impl Embedder {
text_chunks: Vec<Vec<String>>,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await,
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
}
}
@@ -130,6 +138,7 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
Embedder::UserProvided(_) => 1,
}
}
@@ -137,6 +146,15 @@ impl Embedder {
match self {
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
Embedder::UserProvided(_) => 1,
}
}
pub fn dimensions(&self) -> usize {
match self {
Embedder::HuggingFace(embedder) => embedder.dimensions(),
Embedder::OpenAi(embedder) => embedder.dimensions(),
Embedder::UserProvided(embedder) => embedder.dimensions(),
}
}
}
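
A short usage sketch of the new UserProvided variant (the dimensions value, the error type, and the struct-literal construction of manual::EmbedderOptions are assumptions for illustration):

// Hypothetical smoke test: no model is loaded for a user-provided embedder,
// dimensions() just echoes the configured value, and embedding text fails.
async fn check_user_provided() -> Result<(), NewEmbedderError> {
    let options = EmbedderOptions::UserProvided(manual::EmbedderOptions { dimensions: 768 });
    let embedder = Embedder::new(options)?;

    assert_eq!(embedder.dimensions(), 768);

    // Rejected with EmbedErrorKind::ManualEmbed: embeddings must come from the user.
    assert!(embedder.embed(vec!["hello".into()]).await.is_err());
    Ok(())
}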


@@ -15,7 +15,7 @@ pub struct Embedder {
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
pub api_key: String,
pub api_key: Option<String>,
pub embedding_model: EmbeddingModel,
}
@@ -68,11 +68,11 @@ impl EmbeddingModel {
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
impl EmbedderOptions {
pub fn with_default_model(api_key: String) -> Self {
pub fn with_default_model(api_key: Option<String>) -> Self {
Self { api_key, embedding_model: Default::default() }
}
pub fn with_embedding_model(api_key: String, embedding_model: EmbeddingModel) -> Self {
pub fn with_embedding_model(api_key: Option<String>, embedding_model: EmbeddingModel) -> Self {
Self { api_key, embedding_model }
}
}
@@ -80,9 +80,14 @@ impl EmbedderOptions {
impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
let mut headers = reqwest::header::HeaderMap::new();
let mut inferred_api_key = Default::default();
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
inferred_api_key = infer_api_key();
&inferred_api_key
});
headers.insert(
reqwest::header::AUTHORIZATION,
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", &options.api_key))
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
.map_err(NewEmbedderError::openai_invalid_api_key_format)?,
);
headers.insert(
@@ -315,6 +320,10 @@ impl Embedder {
pub fn prompt_count_in_chunk_hint(&self) -> usize {
10
}
pub fn dimensions(&self) -> usize {
self.options.embedding_model.dimensions()
}
}
// retrying in case of failure
@@ -414,3 +423,9 @@ struct OpenAiEmbedding {
// object: String,
// index: usize,
}
fn infer_api_key() -> String {
std::env::var("MEILI_OPENAI_API_KEY")
.or_else(|_| std::env::var("OPENAI_API_KEY"))
.unwrap_or_default()
}
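
With the key now optional, resolution order is: the explicitly configured key if present, then MEILI_OPENAI_API_KEY, then OPENAI_API_KEY, then the empty string. A caller-side sketch of the two cases (the function name and key value are placeholders):

// Hypothetical example of the two ways to configure the OpenAI API key.
fn openai_key_examples() -> Result<(), NewEmbedderError> {
    // None: Embedder::new falls back to infer_api_key(), i.e. the environment
    // variables, when building the Authorization header.
    let _from_env = Embedder::new(EmbedderOptions::with_default_model(None))?;

    // Some(key): an explicitly configured key takes precedence over the environment.
    let _explicit = Embedder::new(EmbedderOptions::with_default_model(Some(
        "sk-placeholder".to_string(),
    )))?;
    Ok(())
}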


@@ -15,14 +15,14 @@ pub struct EmbeddingSettings {
pub embedder_options: Setting<EmbedderSettings>,
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
#[deserr(default)]
pub prompt: Setting<PromptSettings>,
pub document_template: Setting<PromptSettings>,
}
impl EmbeddingSettings {
pub fn apply(&mut self, new: Self) {
let EmbeddingSettings { embedder_options, prompt } = new;
let EmbeddingSettings { embedder_options, document_template: prompt } = new;
self.embedder_options.apply(embedder_options);
self.prompt.apply(prompt);
self.document_template.apply(prompt);
}
}
@@ -30,7 +30,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
fn from(value: EmbeddingConfig) -> Self {
Self {
embedder_options: Setting::Set(value.embedder_options.into()),
prompt: Setting::Set(value.prompt.into()),
document_template: Setting::Set(value.prompt.into()),
}
}
}
@@ -38,7 +38,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
impl From<EmbeddingSettings> for EmbeddingConfig {
fn from(value: EmbeddingSettings) -> Self {
let mut this = Self::default();
let EmbeddingSettings { embedder_options, prompt } = value;
let EmbeddingSettings { embedder_options, document_template: prompt } = value;
if let Some(embedder_options) = embedder_options.set() {
this.embedder_options = embedder_options.into();
}
@@ -105,6 +105,7 @@ impl From<PromptSettings> for PromptData {
pub enum EmbedderSettings {
HuggingFace(Setting<HfEmbedderSettings>),
OpenAi(Setting<OpenAiEmbedderSettings>),
UserProvided(UserProvidedSettings),
}
impl<E> Deserr<E> for EmbedderSettings
@@ -145,11 +146,17 @@ where
location.push_key(&k),
)?,
))),
"userProvided" => Ok(EmbedderSettings::UserProvided(
UserProvidedSettings::deserialize_from_value(
v.into_value(),
location.push_key(&k),
)?,
)),
other => Err(deserr::take_cf_content(E::error::<V>(
None,
deserr::ErrorKind::UnknownKey {
key: other,
accepted: &["huggingFace", "openAi"],
accepted: &["huggingFace", "openAi", "userProvided"],
},
location,
))),
@@ -182,6 +189,9 @@ impl From<crate::vector::EmbedderOptions> for EmbedderSettings {
crate::vector::EmbedderOptions::OpenAi(openai) => {
Self::OpenAi(Setting::Set(openai.into()))
}
crate::vector::EmbedderOptions::UserProvided(user_provided) => {
Self::UserProvided(user_provided.into())
}
}
}
}
@@ -192,9 +202,12 @@ impl From<EmbedderSettings> for crate::vector::EmbedderOptions {
EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()),
EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()),
EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()),
EmbedderSettings::OpenAi(_setting) => Self::OpenAi(
crate::vector::openai::EmbedderOptions::with_default_model(infer_api_key()),
),
EmbedderSettings::OpenAi(_setting) => {
Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None))
}
EmbedderSettings::UserProvided(user_provided) => {
Self::UserProvided(user_provided.into())
}
}
}
}
@@ -286,7 +299,7 @@ impl OpenAiEmbedderSettings {
impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
fn from(value: crate::vector::openai::EmbedderOptions) -> Self {
Self {
api_key: Setting::Set(value.api_key),
api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset),
embedding_model: Setting::Set(value.embedding_model),
}
}
@@ -295,14 +308,25 @@ impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions {
fn from(value: OpenAiEmbedderSettings) -> Self {
let OpenAiEmbedderSettings { api_key, embedding_model } = value;
Self {
api_key: api_key.set().unwrap_or_else(infer_api_key),
embedding_model: embedding_model.set().unwrap_or_default(),
}
Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() }
}
}
fn infer_api_key() -> String {
/// FIXME: get key from instance options?
std::env::var("MEILI_OPENAI_API_KEY").unwrap_or_default()
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
pub struct UserProvidedSettings {
pub dimensions: usize,
}
impl From<UserProvidedSettings> for crate::vector::manual::EmbedderOptions {
fn from(value: UserProvidedSettings) -> Self {
Self { dimensions: value.dimensions }
}
}
impl From<crate::vector::manual::EmbedderOptions> for UserProvidedSettings {
fn from(value: crate::vector::manual::EmbedderOptions) -> Self {
Self { dimensions: value.dimensions }
}
}
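
Together with the rename of prompt to documentTemplate above, this adds a userProvided embedder source to the settings. For reference, a sketch of the camelCase payload the new struct accepts (serde_json is used here only for illustration):

#[test]
fn user_provided_settings_shape() {
    // deny_unknown_fields + rename_all = "camelCase": only `dimensions` is accepted.
    let settings: UserProvidedSettings =
        serde_json::from_str(r#"{ "dimensions": 512 }"#).unwrap();
    assert_eq!(settings.dimensions, 512);

    // One level up, the deserr arm added earlier keys the embedder variant on
    // "userProvided", e.g. { "userProvided": { "dimensions": 512 } }.
}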