From bc58e8a310aa0774265c1e51ec38162cff526da2 Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Tue, 12 Mar 2024 15:00:26 +0100
Subject: [PATCH 1/9] Documentation for the vector module

---
 milli/src/vector/mod.rs | 66 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 66 insertions(+)

diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index 035ac555e..aeb0be1ca 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -16,46 +16,62 @@ pub use self::error::Error;
 
 pub type Embedding = Vec<f32>;
 
+/// One or multiple embeddings stored consecutively in a flat vector.
 pub struct Embeddings<F> {
     data: Vec<F>,
     dimension: usize,
 }
 
 impl<F> Embeddings<F> {
+    /// Declares an empty vector of embeddings of the specified dimension.
     pub fn new(dimension: usize) -> Self {
         Self { data: Default::default(), dimension }
     }
 
+    /// Declares a vector of embeddings containing a single element.
+    ///
+    /// The dimension is inferred from the length of the passed embedding.
     pub fn from_single_embedding(embedding: Vec<F>) -> Self {
         Self { dimension: embedding.len(), data: embedding }
     }
 
+    /// Declares a vector of embeddings from its components.
+    ///
+    /// `data.len()` must be a multiple of `dimension`, otherwise an error is returned.
     pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
         let mut this = Self::new(dimension);
         this.append(data)?;
         Ok(this)
     }
 
+    /// Returns the number of embeddings in this vector of embeddings.
     pub fn embedding_count(&self) -> usize {
         self.data.len() / self.dimension
     }
 
+    /// Dimension of a single embedding.
     pub fn dimension(&self) -> usize {
         self.dimension
     }
 
+    /// Deconstructs self into the inner flat vector.
     pub fn into_inner(self) -> Vec<F> {
         self.data
     }
 
+    /// A reference to the inner flat vector.
     pub fn as_inner(&self) -> &[F] {
         &self.data
     }
 
+    /// Iterates over the embeddings contained in the flat vector.
     pub fn iter(&self) -> impl Iterator<Item = &[F]> + '_ {
         self.data.as_slice().chunks_exact(self.dimension)
     }
 
+    /// Push an embedding at the end of the embeddings.
+    ///
+    /// If `embedding.len() != self.dimension`, then the push operation fails.
     pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
         if embedding.len() != self.dimension {
             return Err(embedding);
         }
@@ -64,6 +80,9 @@ impl Embeddings<F> {
         Ok(())
     }
 
+    /// Append a flat vector of embeddings at the end of the embeddings.
+    ///
+    /// If `embeddings.len() % self.dimension != 0`, then the append operation fails.
     pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
         if embeddings.len() % self.dimension != 0 {
             return Err(embeddings);
         }
@@ -73,37 +92,57 @@ impl Embeddings<F> {
     }
 }
 
+/// An embedder can be used to transform text into embeddings.
 #[derive(Debug)]
 pub enum Embedder {
+    /// An embedder based on running local models, fetched from the Hugging Face Hub.
     HuggingFace(hf::Embedder),
+    /// An embedder based on making embedding queries against the OpenAI API.
     OpenAi(openai::Embedder),
+    /// An embedder based on the user providing the embeddings in the documents and queries.
     UserProvided(manual::Embedder),
     Ollama(ollama::Embedder),
 }
 
+/// Configuration for an embedder.
 #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
 pub struct EmbeddingConfig {
+    /// Options of the embedder, specific to each kind of embedder.
     pub embedder_options: EmbedderOptions,
+    /// Document template.
     pub prompt: PromptData,
     // TODO: add metrics and anything needed
 }
 
+/// Map of embedder configurations.
+///
+/// Each configuration is mapped to a name.
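+///
+/// A sketch of the intended usage, assuming an `embedder: Embedder` and a
+/// `prompt: Prompt` built beforehand (marked `ignore` since those values
+/// cannot be constructed in a doctest):
+///
+/// ```ignore
+/// let mut data = HashMap::new();
+/// data.insert("default".to_string(), (Arc::new(embedder), Arc::new(prompt)));
+/// let configs = EmbeddingConfigs::new(data);
+/// // a lone configuration is always the default one
+/// assert_eq!(configs.get_default_embedder_name(), Some("default".to_string()));
+/// ```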
 #[derive(Clone, Default)]
 pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);
 
 impl EmbeddingConfigs {
+    /// Create the map from its internal components.
     pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
         Self(data)
     }
 
+    /// Get an embedder configuration and template from its name.
     pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
         self.0.get(name).cloned()
     }
 
+    /// Get the default embedder configuration, if any.
     pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
         self.get_default_embedder_name().and_then(|default| self.get(&default))
     }
 
+    /// Get the name of the default embedder configuration.
+    ///
+    /// The default embedder is determined as follows:
+    ///
+    /// - If there is only one embedder, it is always the default.
+    /// - If there are multiple embedders and one of them is called `default`, then that one is the default embedder.
+    /// - In all other cases, there is no default embedder.
     pub fn get_default_embedder_name(&self) -> Option<String> {
         let mut it = self.0.keys();
         let first_name = it.next();
@@ -126,6 +165,7 @@ impl IntoIterator for EmbeddingConfigs {
     }
 }
 
+/// Options of an embedder, specific to each kind of embedder.
 #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
 pub enum EmbedderOptions {
     HuggingFace(hf::EmbedderOptions),
@@ -141,10 +181,12 @@ impl Default for EmbedderOptions {
 }
 
 impl EmbedderOptions {
+    /// Default options for the Hugging Face embedder.
     pub fn huggingface() -> Self {
         Self::HuggingFace(hf::EmbedderOptions::new())
     }
 
+    /// Default options for the OpenAI embedder.
     pub fn openai(api_key: Option<String>) -> Self {
         Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
     }
@@ -155,6 +197,7 @@ impl EmbedderOptions {
 }
 
 impl Embedder {
+    /// Spawns a new embedder built from its options.
     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
         Ok(match options {
             EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
@@ -166,6 +209,9 @@ impl Embedder {
         })
     }
 
+    /// Embed one or multiple texts.
+    ///
+    /// Each text can be embedded as one or multiple embeddings.
     pub async fn embed(
         &self,
         texts: Vec<String>,
@@ -184,6 +230,10 @@ impl Embedder {
         }
     }
 
+    /// Embed multiple chunks of texts.
+    ///
+    /// Each chunk is composed of one or multiple texts.
+    ///
     /// # Panics
     ///
     /// - if called from an asynchronous context
@@ -199,6 +249,7 @@ impl Embedder {
         }
     }
 
+    /// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`].
     pub fn chunk_count_hint(&self) -> usize {
         match self {
             Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
@@ -208,6 +259,7 @@ impl Embedder {
         }
     }
 
+    /// Indicates the preferred number of texts in a single chunk passed to [`Self::embed`].
     pub fn prompt_count_in_chunk_hint(&self) -> usize {
         match self {
             Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
@@ -217,6 +269,7 @@ impl Embedder {
         }
     }
 
+    /// Indicates the dimensions of a single embedding produced by the embedder.
     pub fn dimensions(&self) -> usize {
         match self {
             Embedder::HuggingFace(embedder) => embedder.dimensions(),
@@ -226,6 +279,7 @@ impl Embedder {
         }
     }
 
+    /// An optional distribution used to apply an affine transformation to the similarity score of a document.
     pub fn distribution(&self) -> Option<DistributionShift> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.distribution(),
@@ -236,9 +290,20 @@ impl Embedder {
     }
 }
 
+/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
+///
+/// The intended use is to make the similarity score more comparable to the regular ranking score.
+/// This makes it possible to correct effects where results are too "packed" around a certain value.
 #[derive(Debug, Clone, Copy)]
 pub struct DistributionShift {
+    /// Value where the results are "packed".
+    ///
+    /// Similarity scores are translated so that they are packed around 0.5 instead.
     pub current_mean: f32,
+
+    /// Standard deviation of a similarity score.
+    ///
+    /// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
     pub current_sigma: f32,
 }
 
@@ -280,6 +345,7 @@ impl DistributionShift {
     }
 }
 
+/// Whether CUDA is supported in this version of Meilisearch.
 pub const fn is_cuda_enabled() -> bool {
     cfg!(feature = "cuda")
 }

From c3d02f092dddf5f3e0f336f7774cca72eb8ed0bb Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Thu, 14 Mar 2024 11:14:31 +0100
Subject: [PATCH 2/9] OpenAI sync

---
 Cargo.lock                 |   1 +
 milli/Cargo.toml           |   1 +
 milli/src/vector/error.rs  |  28 +-
 milli/src/vector/hf.rs     |   2 +-
 milli/src/vector/mod.rs    |   9 +-
 milli/src/vector/openai.rs | 554 +++++++++++++++++--------------------
 6 files changed, 274 insertions(+), 321 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index b44b151d1..60d0e4c0e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3378,6 +3378,7 @@ dependencies = [
  "tokenizers",
  "tokio",
  "tracing",
+ "ureq",
  "uuid",
 ]
 
diff --git a/milli/Cargo.toml b/milli/Cargo.toml
index fa4215404..59b3699cc 100644
--- a/milli/Cargo.toml
+++ b/milli/Cargo.toml
@@ -91,6 +91,7 @@ liquid = "0.26.4"
 arroy = "0.2.0"
 rand = "0.8.5"
 tracing = "0.1.40"
+ureq = { version = "2.9.6", features = ["json"] }
 
 [dev-dependencies]
 mimalloc = { version = "0.1.39", default-features = false }
diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs
index 9bbdeaa90..1def4f7a9 100644
--- a/milli/src/vector/error.rs
+++ b/milli/src/vector/error.rs
@@ -53,17 +53,17 @@ pub enum EmbedErrorKind {
     #[error("could not run model: {0}")]
     ModelForward(candle_core::Error),
     #[error("could not reach OpenAI: {0}")]
-    OpenAiNetwork(reqwest::Error),
+    OpenAiNetwork(ureq::Transport),
     #[error("unexpected response from OpenAI: {0}")]
-    OpenAiUnexpected(reqwest::Error),
-    #[error("could not authenticate against OpenAI: {0}")]
-    OpenAiAuth(OpenAiError),
-    #[error("sent too many requests to OpenAI: {0}")]
-    OpenAiTooManyRequests(OpenAiError),
+    OpenAiUnexpected(ureq::Error),
+    #[error("could not authenticate against OpenAI: {0:?}")]
+    OpenAiAuth(Option<OpenAiError>),
+    #[error("sent too many requests to OpenAI: {0:?}")]
+    OpenAiTooManyRequests(Option<OpenAiError>),
     #[error("received internal error from OpenAI: {0:?}")]
     OpenAiInternalServerError(Option<OpenAiError>),
-    #[error("sent too many tokens in a request to OpenAI: {0}")]
-    OpenAiTooManyTokens(OpenAiError),
+    #[error("sent too many tokens in a request to OpenAI: {0:?}")]
+    OpenAiTooManyTokens(Option<OpenAiError>),
     #[error("received unhandled HTTP status code {0} from OpenAI")]
     OpenAiUnhandledStatusCode(u16),
     #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
@@ -102,19 +102,19 @@ impl EmbedError {
         Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
     }
 
-    pub fn openai_network(inner: reqwest::Error) -> Self {
+    pub fn openai_network(inner: ureq::Transport) -> Self {
         Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
     }
 
-    pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
+    pub fn openai_unexpected(inner: ureq::Error) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
     }
 
-    pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
+    pub(crate) fn openai_auth_error(inner: Option<OpenAiError>) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
     }
 
-    pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
+    pub(crate) fn openai_too_many_requests(inner: Option<OpenAiError>) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
     }
 
@@ -122,7 +122,7 @@ impl EmbedError {
         Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
     }
 
-    pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
+    pub(crate) fn openai_too_many_tokens(inner: Option<OpenAiError>) -> EmbedError {
         Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
     }
 
@@ -220,7 +220,7 @@ impl NewEmbedderError {
         Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
     }
 
-    pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
+    pub fn could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
         Self {
             kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
             fault: FaultSource::Runtime,
diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs
index 04e169c71..939b6210a 100644
--- a/milli/src/vector/hf.rs
+++ b/milli/src/vector/hf.rs
@@ -131,7 +131,7 @@ impl Embedder {
 
         let embeddings = this
             .embed(vec!["test".into()])
-            .map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
+            .map_err(NewEmbedderError::could_not_determine_dimension)?;
         this.dimensions = embeddings.first().unwrap().dimension();
 
         Ok(this)
diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index aeb0be1ca..86dde8ad4 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -98,7 +98,7 @@ pub enum Embedder {
     /// An embedder based on running local models, fetched from the Hugging Face Hub.
     HuggingFace(hf::Embedder),
     /// An embedder based on making embedding queries against the OpenAI API.
-    OpenAi(openai::Embedder),
+    OpenAi(openai::sync::Embedder),
     /// An embedder based on the user providing the embeddings in the documents and queries.
UserProvided(manual::Embedder), Ollama(ollama::Embedder), @@ -201,7 +201,7 @@ impl Embedder { pub fn new(options: EmbedderOptions) -> std::result::Result { Ok(match options { EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), - EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), + EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::sync::Embedder::new(options)?), EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?), EmbedderOptions::UserProvided(options) => { Self::UserProvided(manual::Embedder::new(options)) @@ -218,10 +218,7 @@ impl Embedder { ) -> std::result::Result>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed(texts), - Embedder::OpenAi(embedder) => { - let client = embedder.new_client()?; - embedder.embed(texts, &client).await - } + Embedder::OpenAi(embedder) => embedder.embed(texts), Embedder::Ollama(embedder) => { let client = embedder.new_client()?; embedder.embed(texts, &client).await diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index dcf3f4c89..5d13d5ee2 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -1,18 +1,10 @@ use std::fmt::Display; -use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use super::error::{EmbedError, NewEmbedderError}; use super::{DistributionShift, Embedding, Embeddings}; -#[derive(Debug)] -pub struct Embedder { - headers: reqwest::header::HeaderMap, - tokenizer: tiktoken_rs::CoreBPE, - options: EmbedderOptions, -} - #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { pub api_key: Option, @@ -125,298 +117,6 @@ impl EmbedderOptions { } } -impl Embedder { - pub fn new_client(&self) -> Result { - reqwest::ClientBuilder::new() - .default_headers(self.headers.clone()) - .build() - .map_err(EmbedError::openai_initialize_web_client) - } - - pub fn new(options: EmbedderOptions) -> Result { - let mut headers = reqwest::header::HeaderMap::new(); - let mut inferred_api_key = Default::default(); - let api_key = options.api_key.as_ref().unwrap_or_else(|| { - inferred_api_key = infer_api_key(); - &inferred_api_key - }); - headers.insert( - reqwest::header::AUTHORIZATION, - reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key)) - .map_err(NewEmbedderError::openai_invalid_api_key_format)?, - ); - headers.insert( - reqwest::header::CONTENT_TYPE, - reqwest::header::HeaderValue::from_static("application/json"), - ); - - // looking at the code it is very unclear that this can actually fail. 
- let tokenizer = tiktoken_rs::cl100k_base().unwrap(); - - Ok(Self { options, headers, tokenizer }) - } - - pub async fn embed( - &self, - texts: Vec, - client: &reqwest::Client, - ) -> Result>, EmbedError> { - let mut tokenized = false; - - for attempt in 0..7 { - let result = if tokenized { - self.try_embed_tokenized(&texts, client).await - } else { - self.try_embed(&texts, client).await - }; - - let retry_duration = match result { - Ok(embeddings) => return Ok(embeddings), - Err(retry) => { - tracing::warn!("Failed: {}", retry.error); - tokenized |= retry.must_tokenize(); - retry.into_duration(attempt) - } - }?; - - let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute - tracing::warn!( - "Attempt #{}, retrying after {}ms.", - attempt, - retry_duration.as_millis() - ); - tokio::time::sleep(retry_duration).await; - } - - let result = if tokenized { - self.try_embed_tokenized(&texts, client).await - } else { - self.try_embed(&texts, client).await - }; - - result.map_err(Retry::into_error) - } - - async fn check_response(response: reqwest::Response) -> Result { - if !response.status().is_success() { - match response.status() { - StatusCode::UNAUTHORIZED => { - let error_response: OpenAiErrorResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - return Err(Retry::give_up(EmbedError::openai_auth_error( - error_response.error, - ))); - } - StatusCode::TOO_MANY_REQUESTS => { - let error_response: OpenAiErrorResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - return Err(Retry::rate_limited(EmbedError::openai_too_many_requests( - error_response.error, - ))); - } - StatusCode::INTERNAL_SERVER_ERROR - | StatusCode::BAD_GATEWAY - | StatusCode::SERVICE_UNAVAILABLE => { - let error_response: Result = response.json().await; - return Err(Retry::retry_later(EmbedError::openai_internal_server_error( - error_response.ok().map(|error_response| error_response.error), - ))); - } - StatusCode::BAD_REQUEST => { - // Most probably, one text contained too many tokens - let error_response: OpenAiErrorResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. 
For best performance, limit the size of your prompt."); - - return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens( - error_response.error, - ))); - } - code => { - return Err(Retry::retry_later(EmbedError::openai_unhandled_status_code( - code.as_u16(), - ))); - } - } - } - Ok(response) - } - - async fn try_embed + serde::Serialize>( - &self, - texts: &[S], - client: &reqwest::Client, - ) -> Result>, Retry> { - for text in texts { - tracing::trace!("Received prompt: {}", text.as_ref()) - } - let request = OpenAiRequest { - model: self.options.embedding_model.name(), - input: texts, - dimensions: self.overriden_dimensions(), - }; - let response = client - .post(OPENAI_EMBEDDINGS_URL) - .json(&request) - .send() - .await - .map_err(EmbedError::openai_network) - .map_err(Retry::retry_later)?; - - let response = Self::check_response(response).await?; - - let response: OpenAiResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - tracing::trace!("response: {:?}", response.data); - - Ok(response - .data - .into_iter() - .map(|data| Embeddings::from_single_embedding(data.embedding)) - .collect()) - } - - async fn try_embed_tokenized( - &self, - text: &[String], - client: &reqwest::Client, - ) -> Result>, Retry> { - pub const OVERLAP_SIZE: usize = 200; - let mut all_embeddings = Vec::with_capacity(text.len()); - for text in text { - let max_token_count = self.options.embedding_model.max_token(); - let encoded = self.tokenizer.encode_ordinary(text.as_str()); - let len = encoded.len(); - if len < max_token_count { - all_embeddings.append(&mut self.try_embed(&[text], client).await?); - continue; - } - - let mut tokens = encoded.as_slice(); - let mut embeddings_for_prompt = Embeddings::new(self.dimensions()); - while tokens.len() > max_token_count { - let window = &tokens[..max_token_count]; - embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap(); - - tokens = &tokens[max_token_count - OVERLAP_SIZE..]; - } - - // end of text - embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap(); - - all_embeddings.push(embeddings_for_prompt); - } - Ok(all_embeddings) - } - - async fn embed_tokens( - &self, - tokens: &[usize], - client: &reqwest::Client, - ) -> Result { - for attempt in 0..9 { - let duration = match self.try_embed_tokens(tokens, client).await { - Ok(embedding) => return Ok(embedding), - Err(retry) => retry.into_duration(attempt), - } - .map_err(Retry::retry_later)?; - - tokio::time::sleep(duration).await; - } - - self.try_embed_tokens(tokens, client) - .await - .map_err(|retry| Retry::give_up(retry.into_error())) - } - - async fn try_embed_tokens( - &self, - tokens: &[usize], - client: &reqwest::Client, - ) -> Result { - let request = OpenAiTokensRequest { - model: self.options.embedding_model.name(), - input: tokens, - dimensions: self.overriden_dimensions(), - }; - let response = client - .post(OPENAI_EMBEDDINGS_URL) - .json(&request) - .send() - .await - .map_err(EmbedError::openai_network) - .map_err(Retry::retry_later)?; - - let response = Self::check_response(response).await?; - - let mut response: OpenAiResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default()) - } - - pub fn embed_chunks( - &self, - text_chunks: Vec>, - ) -> Result>>, EmbedError> { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - 
.build() - .map_err(EmbedError::openai_runtime_init)?; - let client = self.new_client()?; - rt.block_on(futures::future::try_join_all( - text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)), - )) - } - - pub fn chunk_count_hint(&self) -> usize { - 10 - } - - pub fn prompt_count_in_chunk_hint(&self) -> usize { - 10 - } - - pub fn dimensions(&self) -> usize { - if self.options.embedding_model.supports_overriding_dimensions() { - self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions()) - } else { - self.options.embedding_model.default_dimensions() - } - } - - pub fn distribution(&self) -> Option { - self.options.embedding_model.distribution() - } - - fn overriden_dimensions(&self) -> Option { - if self.options.embedding_model.supports_overriding_dimensions() { - self.options.dimensions - } else { - None - } - } -} - // retrying in case of failure pub struct Retry { @@ -524,3 +224,257 @@ fn infer_api_key() -> String { .or_else(|_| std::env::var("OPENAI_API_KEY")) .unwrap_or_default() } + +pub mod sync { + use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; + + use super::{ + EmbedError, Embedding, Embeddings, NewEmbedderError, OpenAiErrorResponse, OpenAiRequest, + OpenAiResponse, OpenAiTokensRequest, Retry, OPENAI_EMBEDDINGS_URL, + }; + use crate::vector::DistributionShift; + + const REQUEST_PARALLELISM: usize = 10; + + #[derive(Debug)] + pub struct Embedder { + tokenizer: tiktoken_rs::CoreBPE, + options: super::EmbedderOptions, + bearer: String, + threads: rayon::ThreadPool, + } + + impl Embedder { + pub fn new(options: super::EmbedderOptions) -> Result { + let mut inferred_api_key = Default::default(); + let api_key = options.api_key.as_ref().unwrap_or_else(|| { + inferred_api_key = super::infer_api_key(); + &inferred_api_key + }); + let bearer = format!("Bearer {api_key}"); + + // looking at the code it is very unclear that this can actually fail. 
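+            // (`cl100k_base()` only parses a BPE vocabulary bundled with the
+            // tiktoken-rs crate, so in practice a failure here would indicate a
+            // corrupted build rather than a runtime condition.)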
+            let tokenizer = tiktoken_rs::cl100k_base().unwrap();
+
+            // FIXME: unwrap
+            let threads = rayon::ThreadPoolBuilder::new()
+                .num_threads(REQUEST_PARALLELISM)
+                .thread_name(|index| format!("embedder-chunk-{index}"))
+                .build()
+                .unwrap();
+
+            Ok(Self { options, bearer, tokenizer, threads })
+        }
+
+        pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+            let mut tokenized = false;
+
+            let client = ureq::agent();
+
+            for attempt in 0..7 {
+                let result = if tokenized {
+                    self.try_embed_tokenized(&texts, &client)
+                } else {
+                    self.try_embed(&texts, &client)
+                };
+
+                let retry_duration = match result {
+                    Ok(embeddings) => return Ok(embeddings),
+                    Err(retry) => {
+                        tracing::warn!("Failed: {}", retry.error);
+                        tokenized |= retry.must_tokenize();
+                        retry.into_duration(attempt)
+                    }
+                }?;
+
+                let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
+                tracing::warn!(
+                    "Attempt #{}, retrying after {}ms.",
+                    attempt,
+                    retry_duration.as_millis()
+                );
+                std::thread::sleep(retry_duration);
+            }
+
+            let result = if tokenized {
+                self.try_embed_tokenized(&texts, &client)
+            } else {
+                self.try_embed(&texts, &client)
+            };
+
+            result.map_err(Retry::into_error)
+        }
+
+        fn check_response(
+            response: Result<ureq::Response, ureq::Error>,
+        ) -> Result<ureq::Response, Retry> {
+            match response {
+                Ok(response) => Ok(response),
+                Err(ureq::Error::Status(code, response)) => {
+                    let error_response: Option<OpenAiErrorResponse> = response.into_json().ok();
+                    let error = error_response.map(|response| response.error);
+                    Err(match code {
+                        401 => Retry::give_up(EmbedError::openai_auth_error(error)),
+                        429 => Retry::rate_limited(EmbedError::openai_too_many_requests(error)),
+                        400 => {
+                            tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
+
+                            Retry::retry_tokenized(EmbedError::openai_too_many_tokens(error))
+                        }
+                        500..=599 => {
+                            Retry::retry_later(EmbedError::openai_internal_server_error(error))
+                        }
+                        _ => Retry::retry_later(EmbedError::openai_unhandled_status_code(code)),
+                    })
+                }
+                Err(ureq::Error::Transport(transport)) => {
+                    Err(Retry::retry_later(EmbedError::openai_network(transport)))
+                }
+            }
+        }
+
+        fn try_embed<S: AsRef<str> + serde::Serialize>(
+            &self,
+            texts: &[S],
+            client: &ureq::Agent,
+        ) -> Result<Vec<Embeddings<f32>>, Retry> {
+            for text in texts {
+                tracing::trace!("Received prompt: {}", text.as_ref())
+            }
+            let request = OpenAiRequest {
+                model: self.options.embedding_model.name(),
+                input: texts,
+                dimensions: self.overriden_dimensions(),
+            };
+            let response = client
+                .post(OPENAI_EMBEDDINGS_URL)
+                .set("Authorization", &self.bearer)
+                .send_json(&request);
+
+            let response = Self::check_response(response)?;
+
+            let response: OpenAiResponse = response
+                .into_json()
+                .map_err(EmbedError::openai_unexpected)
+                .map_err(Retry::retry_later)?;
+
+            tracing::trace!("response: {:?}", response.data);
+
+            Ok(response
+                .data
+                .into_iter()
+                .map(|data| Embeddings::from_single_embedding(data.embedding))
+                .collect())
+        }
+
+        fn try_embed_tokenized(
+            &self,
+            text: &[String],
+            client: &ureq::Agent,
+        ) -> Result<Vec<Embeddings<f32>>, Retry> {
+            pub const OVERLAP_SIZE: usize = 200;
+            let mut all_embeddings = Vec::with_capacity(text.len());
+            for text in text {
+                let max_token_count = self.options.embedding_model.max_token();
+                let encoded = self.tokenizer.encode_ordinary(text.as_str());
+                let len = encoded.len();
+                if len < max_token_count {
+                    all_embeddings.append(&mut self.try_embed(&[text], client)?);
+                    continue;
+                }
+
+                let mut tokens = encoded.as_slice();
+                let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
+                while tokens.len() > max_token_count {
+                    let window = &tokens[..max_token_count];
+                    embeddings_for_prompt.push(self.embed_tokens(window, client)?).unwrap();
+
+                    tokens = &tokens[max_token_count - OVERLAP_SIZE..];
+                }
+
+                // end of text
+                embeddings_for_prompt.push(self.embed_tokens(tokens, client)?).unwrap();
+
+                all_embeddings.push(embeddings_for_prompt);
+            }
+            Ok(all_embeddings)
+        }
+
+        fn embed_tokens(&self, tokens: &[usize], client: &ureq::Agent) -> Result<Embedding, Retry> {
+            for attempt in 0..9 {
+                let duration = match self.try_embed_tokens(tokens, client) {
+                    Ok(embedding) => return Ok(embedding),
+                    Err(retry) => retry.into_duration(attempt),
+                }
+                .map_err(Retry::retry_later)?;
+
+                std::thread::sleep(duration);
+            }
+
+            self.try_embed_tokens(tokens, client)
+                .map_err(|retry| Retry::give_up(retry.into_error()))
+        }
+
+        fn try_embed_tokens(
+            &self,
+            tokens: &[usize],
+            client: &ureq::Agent,
+        ) -> Result<Embedding, Retry> {
+            let request = OpenAiTokensRequest {
+                model: self.options.embedding_model.name(),
+                input: tokens,
+                dimensions: self.overriden_dimensions(),
+            };
+            let response = client
+                .post(OPENAI_EMBEDDINGS_URL)
+                .set("Authorization", &self.bearer)
+                .send_json(&request);
+
+            let response = Self::check_response(response)?;
+
+            let mut response: OpenAiResponse = response
+                .into_json()
+                .map_err(EmbedError::openai_unexpected)
+                .map_err(Retry::retry_later)?;
+
+            Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
+        }
+
+        pub fn embed_chunks(
+            &self,
+            text_chunks: Vec<Vec<String>>,
+        ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
+            self.threads
+                .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk)).collect())
+        }
+
+        pub fn chunk_count_hint(&self) -> usize {
+            10
+        }
+
+        pub fn prompt_count_in_chunk_hint(&self) -> usize {
+            10
+        }
+
+        pub fn dimensions(&self) -> usize {
+            if self.options.embedding_model.supports_overriding_dimensions() {
+                self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
+            } else {
+                self.options.embedding_model.default_dimensions()
+            }
+        }
+
+        pub fn distribution(&self) -> Option<DistributionShift> {
+            self.options.embedding_model.distribution()
+        }
+
+        fn overriden_dimensions(&self) -> Option<usize> {
+            if self.options.embedding_model.supports_overriding_dimensions() {
+                self.options.dimensions
+            } else {
+                None
+            }
+        }
+    }
+}

From 8708cbef2538d28c65b7511e9706b9c1a093762a Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Thu, 14 Mar 2024 14:44:43 +0100
Subject: [PATCH 3/9] Add RestEmbedder

---
 milli/src/vector/error.rs | 109 ++++++++++++++++++++++
 milli/src/vector/mod.rs   |   1 +
 milli/src/vector/rest.rs  | 185 ++++++++++++++++++++++++++++++++++++++
 3 files changed, 295 insertions(+)
 create mode 100644 milli/src/vector/rest.rs

diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs
index 1def4f7a9..b2eb37e81 100644
--- a/milli/src/vector/error.rs
+++ b/milli/src/vector/error.rs
@@ -83,6 +83,32 @@ pub enum EmbedErrorKind {
     OllamaModelNotFoundError(OllamaError),
     #[error("received unhandled HTTP status code {0} from Ollama")]
     OllamaUnhandledStatusCode(u16),
+    #[error("error serializing template context: {0}")]
+    RestTemplateContextSerialization(liquid::Error),
+    #[error(
+        "error rendering request template: {0}. 
Hint: available variables in the context: {{{{input}}}}"
+    )]
+    RestTemplateError(liquid::Error),
+    #[error("error deserializing the response body as JSON: {0}")]
+    RestResponseDeserialization(std::io::Error),
+    #[error("component `{0}` not found in path `{1}` in response: `{2}`")]
+    RestResponseMissingEmbeddings(String, String, String),
+    #[error("expected a response parseable as a vector or an array of vectors: {0}")]
+    RestResponseFormat(serde_json::Error),
+    #[error("expected a response containing {0} embeddings, got only {1}")]
+    RestResponseEmbeddingCount(usize, usize),
+    #[error("could not authenticate against embedding server: {0:?}")]
+    RestUnauthorized(Option<String>),
+    #[error("sent too many requests to embedding server: {0:?}")]
+    RestTooManyRequests(Option<String>),
+    #[error("sent a bad request to embedding server: {0:?}")]
+    RestBadRequest(Option<String>),
+    #[error("received internal error from embedding server: {1:?}")]
+    RestInternalServerError(u16, Option<String>),
+    #[error("received HTTP {0} from embedding server: {1:?}")]
+    RestOtherStatusCode(u16, Option<String>),
+    #[error("could not reach embedding server: {0}")]
+    RestNetwork(ureq::Transport),
 }
 
 impl EmbedError {
+    pub(crate) fn rest_template_context_serialization(error: liquid::Error) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestTemplateContextSerialization(error),
+            fault: FaultSource::Bug,
+        }
+    }
+
+    pub(crate) fn rest_template_render(error: liquid::Error) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestTemplateError(error), fault: FaultSource::User }
+    }
+
+    pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestResponseDeserialization(error),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_response_missing_embeddings<S: AsRef<str>>(
+        response: serde_json::Value,
+        component: &str,
+        response_field: &[S],
+    ) -> EmbedError {
+        let response_field: Vec<&str> = response_field.iter().map(AsRef::as_ref).collect();
+        let response_field = response_field.join(".");
+
+        Self {
+            kind: EmbedErrorKind::RestResponseMissingEmbeddings(
+                component.to_owned(),
+                response_field,
+                serde_json::to_string_pretty(&response).unwrap_or_default(),
+            ),
+            fault: FaultSource::Undecided,
+        }
+    }
+
+    pub(crate) fn rest_response_format(error: serde_json::Error) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestResponseFormat(error), fault: FaultSource::Undecided }
+    }
+
+    pub(crate) fn rest_response_embedding_count(expected: usize, got: usize) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestResponseEmbeddingCount(expected, got),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_unauthorized(error_response: Option<String>) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestUnauthorized(error_response), fault: FaultSource::User }
+    }
+
+    pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestTooManyRequests(error_response),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_bad_request(error_response: Option<String>) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestBadRequest(error_response), fault: FaultSource::User }
+    }
+
+    pub(crate) fn rest_internal_server_error(
+        code: u16,
+        error_response: Option<String>,
+    ) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestInternalServerError(code, error_response),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestOtherStatusCode(code, error_response),
+            fault: FaultSource::Undecided,
+        }
+    }
+
+    pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime }
+    }
 }
 
 #[derive(Debug, thiserror::Error)]
diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index 86dde8ad4..7eef3d442 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -11,6 +11,7 @@ pub mod openai;
 pub mod settings;
 
 pub mod ollama;
+pub mod rest;
 
 pub use self::error::Error;
 
diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs
new file mode 100644
index 000000000..975bd3790
--- /dev/null
+++ b/milli/src/vector/rest.rs
@@ -0,0 +1,185 @@
+use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
+
+use super::openai::Retry;
+use super::{DistributionShift, EmbedError, Embeddings, NewEmbedderError};
+use crate::VectorOrArrayOfVectors;
+
+pub struct Embedder {
+    client: ureq::Agent,
+    options: EmbedderOptions,
+    bearer: Option<String>,
+    dimensions: usize,
+}
+
+pub struct EmbedderOptions {
+    api_key: Option<String>,
+    distribution: Option<DistributionShift>,
+    dimensions: Option<usize>,
+    url: String,
+    query: liquid::Template,
+    response_field: Vec<String>,
+}
+
+impl Embedder {
+    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
+        let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
+
+        let client = ureq::agent();
+
+        let dimensions = if let Some(dimensions) = options.dimensions {
+            dimensions
+        } else {
+            infer_dimensions(&client, &options, bearer.as_deref())?
+        };
+
+        Ok(Self { client, dimensions, options, bearer })
+    }
+
+    pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+        embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice())
+    }
+
+    pub fn embed_chunks(
+        &self,
+        text_chunks: Vec<Vec<String>>,
+        threads: &rayon::ThreadPool,
+    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
+        threads
+            .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk)).collect())
+    }
+
+    pub fn chunk_count_hint(&self) -> usize {
+        10
+    }
+
+    pub fn prompt_count_in_chunk_hint(&self) -> usize {
+        10
+    }
+
+    pub fn dimensions(&self) -> usize {
+        self.dimensions
+    }
+
+    pub fn distribution(&self) -> Option<DistributionShift> {
+        self.options.distribution
+    }
+}
+
+fn infer_dimensions(
+    client: &ureq::Agent,
+    options: &EmbedderOptions,
+    bearer: Option<&str>,
+) -> Result<usize, NewEmbedderError> {
+    let v = embed(client, options, bearer, ["test"].as_slice())
+        .map_err(NewEmbedderError::could_not_determine_dimension)?;
+    // unwrap: guaranteed that v.len() == ["test"].len() == 1, otherwise the previous line terminated in error
+    Ok(v.first().unwrap().dimension())
+}
+
+fn embed<S>(
+    client: &ureq::Agent,
+    options: &EmbedderOptions,
+    bearer: Option<&str>,
+    inputs: &[S],
+) -> Result<Vec<Embeddings<f32>>, EmbedError>
+where
+    S: serde::Serialize,
+{
+    let request = client.post(&options.url);
+    let request =
+        if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
+    let request = request.set("Content-Type", "application/json");
+
+    let body = options
+        .query
+        .render(
+            &liquid::to_object(&serde_json::json!({
+                "input": inputs,
+            }))
+            .map_err(EmbedError::rest_template_context_serialization)?,
+        )
+        .map_err(EmbedError::rest_template_render)?;
+
+    for attempt in 0..7 {
+        // clone the request so that it can be sent again on the next attempt
+        let response = request.clone().send_string(&body);
+        let result = check_response(response);
+
+        let retry_duration = match result {
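+            // A successful response short-circuits the retry loop; on failure,
+            // the Retry value encodes both the error and how long to back off
+            // before the next attempt.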
+            Ok(response) => {
+                return response_to_embedding(response, &options.response_field, inputs.len())
+            }
+            Err(retry) => {
+                tracing::warn!("Failed: {}", retry.error);
+                retry.into_duration(attempt)
+            }
+        }?;
+
+        let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
+        tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
+        std::thread::sleep(retry_duration);
+    }
+
+    let response = request.send_string(&body);
+    let result = check_response(response);
+    result
+        .map_err(Retry::into_error)
+        .and_then(|response| response_to_embedding(response, &options.response_field, inputs.len()))
+}
+
+fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
+    match response {
+        Ok(response) => Ok(response),
+        Err(ureq::Error::Status(code, response)) => {
+            let error_response: Option<String> = response.into_string().ok();
+            Err(match code {
+                401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
+                429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
+                400 => Retry::give_up(EmbedError::rest_bad_request(error_response)),
+                500..=599 => {
+                    Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
+                }
+                _ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
+            })
+        }
+        Err(ureq::Error::Transport(transport)) => {
+            Err(Retry::retry_later(EmbedError::rest_network(transport)))
+        }
+    }
+}
+
+fn response_to_embedding<S: AsRef<str>>(
+    response: ureq::Response,
+    response_field: &[S],
+    expected_count: usize,
+) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+    let response: serde_json::Value =
+        response.into_json().map_err(EmbedError::rest_response_deserialization)?;
+
+    let mut current_value = &response;
+    for component in response_field {
+        let component = component.as_ref();
+        current_value = current_value.get(component).ok_or_else(|| {
+            EmbedError::rest_response_missing_embeddings(response.clone(), component, response_field)
+        })?;
+    }
+
+    let embeddings = current_value.to_owned();
+
+    let embeddings: VectorOrArrayOfVectors =
+        serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
+
+    let embeddings = embeddings.into_array_of_vectors();
+
+    let embeddings: Vec<Embeddings<f32>> = embeddings
+        .into_iter()
+        .flatten()
+        .map(|embedding| Embeddings::from_single_embedding(embedding))
+        .collect();
+
+    if embeddings.len() != expected_count {
+        return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
+    }
+
+    Ok(embeddings)
+}

From ac52c857e8f5ecf85a42e32abe7a14450fdfdd66 Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Tue, 19 Mar 2024 15:41:37 +0100
Subject: [PATCH 4/9] Update ollama and openai impls to use the rest embedder
 internally

---
 .../extract/extract_vector_points.rs          |  10 +-
 .../src/update/index_documents/extract/mod.rs |  15 +-
 milli/src/vector/error.rs                     | 116 +----
 milli/src/vector/mod.rs                       |  22 +-
 milli/src/vector/ollama.rs                    | 307 ++----------
 milli/src/vector/openai.rs                    | 452 +++++------------
 milli/src/vector/rest.rs                      | 247 ++++++++--
 milli/src/vector/settings.rs                  |   4 +-
 8 files changed, 394 insertions(+), 779 deletions(-)

diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs
index ece841659..40b32bf9c 100644
--- a/milli/src/update/index_documents/extract/extract_vector_points.rs
+++ b/milli/src/update/index_documents/extract/extract_vector_points.rs
@@ -339,6 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
     prompt_reader: grenad::Reader<R>,
     indexer: GrenadParameters,
     embedder: Arc<Embedder>,
+    request_threads: &rayon::ThreadPool,
 ) -> Result<grenad::Reader<BufReader<File>>> {
     puffin::profile_function!();
     let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
@@ -376,7 +377,10 @@ pub fn extract_embeddings(
         if chunks.len() == chunks.capacity() {
             let chunked_embeds = embedder
-                .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
+                .embed_chunks(
+                    std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)),
+                    request_threads,
+                )
                 .map_err(crate::vector::Error::from)
                 .map_err(crate::Error::from)?;
 
@@ -394,7 +398,7 @@ pub fn extract_embeddings(
     // send last chunk
     if !chunks.is_empty() {
         let chunked_embeds = embedder
-            .embed_chunks(std::mem::take(&mut chunks))
+            .embed_chunks(std::mem::take(&mut chunks), request_threads)
             .map_err(crate::vector::Error::from)
             .map_err(crate::Error::from)?;
         for (docid, embeddings) in chunks_ids
@@ -408,7 +412,7 @@ pub fn extract_embeddings(
     if !current_chunk.is_empty() {
         let embeds = embedder
-            .embed_chunks(vec![std::mem::take(&mut current_chunk)])
+            .embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads)
             .map_err(crate::vector::Error::from)
             .map_err(crate::Error::from)?;
 
diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs
index 43f3f4947..5689bb04f 100644
--- a/milli/src/update/index_documents/extract/mod.rs
+++ b/milli/src/update/index_documents/extract/mod.rs
@@ -238,7 +238,15 @@ fn send_original_documents_data(
     let documents_chunk_cloned = original_documents_chunk.clone();
     let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
+
+    let request_threads = rayon::ThreadPoolBuilder::new()
+        .num_threads(crate::vector::REQUEST_PARALLELISM)
+        .thread_name(|index| format!("embedding-request-{index}"))
+        .build()
+        .unwrap();
+
     rayon::spawn(move || {
+        // FIXME: unwrap
         for (name, (embedder, prompt)) in embedders {
             let result = extract_vector_points(
                 documents_chunk_cloned.clone(),
@@ -249,7 +257,12 @@
             );
             match result {
                 Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
-                    let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
+                    let embeddings = match extract_embeddings(
+                        prompts,
+                        indexer,
+                        embedder.clone(),
+                        &request_threads,
+                    ) {
                         Ok(results) => Some(results),
                         Err(error) => {
                             let _ = lmdb_writer_sx_cloned.send(Err(error));
diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs
index b2eb37e81..92f077924 100644
--- a/milli/src/vector/error.rs
+++ b/milli/src/vector/error.rs
@@ -2,9 +2,7 @@ use std::path::PathBuf;
 
 use hf_hub::api::sync::ApiError;
 
-use super::ollama::OllamaError;
 use crate::error::FaultSource;
-use crate::vector::openai::OpenAiError;
 
 #[derive(Debug, thiserror::Error)]
 #[error("Error while generating embeddings: {inner}")]
@@ -52,43 +50,12 @@ pub enum EmbedErrorKind {
     TensorValue(candle_core::Error),
     #[error("could not run model: {0}")]
     ModelForward(candle_core::Error),
-    #[error("could not reach OpenAI: {0}")]
-    OpenAiNetwork(ureq::Transport),
-    #[error("unexpected response from OpenAI: {0}")]
-    OpenAiUnexpected(ureq::Error),
-    #[error("could not authenticate against OpenAI: {0:?}")]
-    OpenAiAuth(Option<OpenAiError>),
-    #[error("sent too many requests to OpenAI: {0:?}")]
-    OpenAiTooManyRequests(Option<OpenAiError>),
-    #[error("received internal error from OpenAI: {0:?}")]
-    OpenAiInternalServerError(Option<OpenAiError>),
-    #[error("sent too many tokens in a request to OpenAI: {0:?}")]
-    OpenAiTooManyTokens(Option<OpenAiError>),
-    #[error("received unhandled HTTP status code {0} 
from OpenAI")] - OpenAiUnhandledStatusCode(u16), #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] ManualEmbed(String), #[error("could not initialize asynchronous runtime: {0}")] OpenAiRuntimeInit(std::io::Error), - #[error("initializing web client for sending embedding requests failed: {0}")] - InitWebClient(reqwest::Error), - // Dedicated Ollama error kinds, might have to merge them into one cohesive error type for all backends. - #[error("unexpected response from Ollama: {0}")] - OllamaUnexpected(reqwest::Error), - #[error("sent too many requests to Ollama: {0}")] - OllamaTooManyRequests(OllamaError), - #[error("received internal error from Ollama: {0}")] - OllamaInternalServerError(OllamaError), - #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0}")] - OllamaModelNotFoundError(OllamaError), - #[error("received unhandled HTTP status code {0} from Ollama")] - OllamaUnhandledStatusCode(u16), - #[error("error serializing template context: {0}")] - RestTemplateContextSerialization(liquid::Error), - #[error( - "error rendering request template: {0}. Hint: available variable in the context: {{{{input}}}}'" - )] - RestTemplateError(liquid::Error), + #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")] + OllamaModelNotFoundError(Option), #[error("error deserialization the response body as JSON: {0}")] RestResponseDeserialization(std::io::Error), #[error("component `{0}` not found in path `{1}` in response: `{2}`")] @@ -128,77 +95,14 @@ impl EmbedError { Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime } } - pub fn openai_network(inner: ureq::Transport) -> Self { - Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime } - } - - pub fn openai_unexpected(inner: ureq::Error) -> EmbedError { - Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug } - } - - pub(crate) fn openai_auth_error(inner: Option) -> EmbedError { - Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User } - } - - pub(crate) fn openai_too_many_requests(inner: Option) -> EmbedError { - Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime } - } - - pub(crate) fn openai_internal_server_error(inner: Option) -> EmbedError { - Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime } - } - - pub(crate) fn openai_too_many_tokens(inner: Option) -> EmbedError { - Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug } - } - - pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError { - Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug } - } - pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError { Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User } } - pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError { - Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime } - } - - pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { - Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } - } - - pub(crate) fn ollama_unexpected(inner: reqwest::Error) -> EmbedError { - Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug } - } - - 
pub(crate) fn ollama_model_not_found(inner: OllamaError) -> EmbedError { + pub(crate) fn ollama_model_not_found(inner: Option) -> EmbedError { Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User } } - pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError { - Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime } - } - - pub(crate) fn ollama_internal_server_error(inner: OllamaError) -> EmbedError { - Self { kind: EmbedErrorKind::OllamaInternalServerError(inner), fault: FaultSource::Runtime } - } - - pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError { - Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug } - } - - pub(crate) fn rest_template_context_serialization(error: liquid::Error) -> EmbedError { - Self { - kind: EmbedErrorKind::RestTemplateContextSerialization(error), - fault: FaultSource::Bug, - } - } - - pub(crate) fn rest_template_render(error: liquid::Error) -> EmbedError { - Self { kind: EmbedErrorKind::RestTemplateError(error), fault: FaultSource::User } - } - pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError { Self { kind: EmbedErrorKind::RestResponseDeserialization(error), @@ -335,17 +239,6 @@ impl NewEmbedderError { fault: FaultSource::Runtime, } } - - pub fn ollama_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError { - Self { - kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner), - fault: FaultSource::User, - } - } - - pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { - Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } - } } #[derive(Debug, thiserror::Error)] @@ -392,7 +285,4 @@ pub enum NewEmbedderErrorKind { CouldNotDetermineDimension(EmbedError), #[error("loading model failed: {0}")] LoadModel(candle_core::Error), - // openai - #[error("The API key passed to Authorization error was in an invalid format: {0}")] - InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue), } diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 7eef3d442..39232e387 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -17,6 +17,8 @@ pub use self::error::Error; pub type Embedding = Vec; +pub const REQUEST_PARALLELISM: usize = 40; + /// One or multiple embeddings stored consecutively in a flat vector. pub struct Embeddings { data: Vec, @@ -99,7 +101,7 @@ pub enum Embedder { /// An embedder based on running local models, fetched from the Hugging Face Hub. HuggingFace(hf::Embedder), /// An embedder based on making embedding queries against the OpenAI API. - OpenAi(openai::sync::Embedder), + OpenAi(openai::Embedder), /// An embedder based on the user providing the embeddings in the documents and queries. UserProvided(manual::Embedder), Ollama(ollama::Embedder), @@ -202,7 +204,7 @@ impl Embedder { pub fn new(options: EmbedderOptions) -> std::result::Result { Ok(match options { EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), - EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::sync::Embedder::new(options)?), + EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?), EmbedderOptions::UserProvided(options) => { Self::UserProvided(manual::Embedder::new(options)) @@ -213,17 +215,14 @@ impl Embedder { /// Embed one or multiple texts. 
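+    /// (This call is blocking: embedding requests now run synchronously on the calling thread.)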
/// /// Each text can be embedded as one or multiple embeddings. - pub async fn embed( + pub fn embed( &self, texts: Vec, ) -> std::result::Result>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed(texts), Embedder::OpenAi(embedder) => embedder.embed(texts), - Embedder::Ollama(embedder) => { - let client = embedder.new_client()?; - embedder.embed(texts, &client).await - } + Embedder::Ollama(embedder) => embedder.embed(texts), Embedder::UserProvided(embedder) => embedder.embed(texts), } } @@ -231,18 +230,15 @@ impl Embedder { /// Embed multiple chunks of texts. /// /// Each chunk is composed of one or multiple texts. - /// - /// # Panics - /// - /// - if called from an asynchronous context pub fn embed_chunks( &self, text_chunks: Vec>, + threads: &rayon::ThreadPool, ) -> std::result::Result>>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), - Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks), - Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks), + Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads), + Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads), Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), } } diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs index 76988f70b..9c44e8052 100644 --- a/milli/src/vector/ollama.rs +++ b/milli/src/vector/ollama.rs @@ -1,293 +1,94 @@ -// Copied from "openai.rs" with the sections I actually understand changed for Ollama. -// The common components of the Ollama and OpenAI interfaces might need to be extracted. +use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; -use std::fmt::Display; - -use reqwest::StatusCode; - -use super::error::{EmbedError, NewEmbedderError}; -use super::openai::Retry; -use super::{DistributionShift, Embedding, Embeddings}; +use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; +use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; +use super::{DistributionShift, Embeddings}; #[derive(Debug)] pub struct Embedder { - headers: reqwest::header::HeaderMap, - options: EmbedderOptions, + rest_embedder: RestEmbedder, } #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { - pub embedding_model: EmbeddingModel, -} - -#[derive( - Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr, -)] -#[deserr(deny_unknown_fields)] -pub struct EmbeddingModel { - name: String, - dimensions: usize, -} - -#[derive(Debug, serde::Serialize)] -struct OllamaRequest<'a> { - model: &'a str, - prompt: &'a str, -} - -#[derive(Debug, serde::Deserialize)] -struct OllamaResponse { - embedding: Embedding, -} - -#[derive(Debug, serde::Deserialize)] -pub struct OllamaError { - error: String, -} - -impl EmbeddingModel { - pub fn max_token(&self) -> usize { - // this might not be the same for all models - 8192 - } - - pub fn default_dimensions(&self) -> usize { - // Dimensions for nomic-embed-text - 768 - } - - pub fn name(&self) -> String { - self.name.clone() - } - - pub fn from_name(name: &str) -> Self { - Self { name: name.to_string(), dimensions: 0 } - } - - pub fn supports_overriding_dimensions(&self) -> bool { - false - } -} - -impl Default for EmbeddingModel { - fn default() -> Self { - Self { name: "nomic-embed-text".to_string(), dimensions: 0 } - } + pub embedding_model: String, } impl EmbedderOptions { 
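+    /// Options using the default model, `nomic-embed-text`.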
pub fn with_default_model() -> Self { - Self { embedding_model: Default::default() } + Self { embedding_model: "nomic-embed-text".into() } } - pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self { + pub fn with_embedding_model(embedding_model: String) -> Self { Self { embedding_model } } } impl Embedder { - pub fn new_client(&self) -> Result { - reqwest::ClientBuilder::new() - .default_headers(self.headers.clone()) - .build() - .map_err(EmbedError::openai_initialize_web_client) - } - pub fn new(options: EmbedderOptions) -> Result { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - reqwest::header::CONTENT_TYPE, - reqwest::header::HeaderValue::from_static("application/json"), - ); - - let mut embedder = Self { options, headers }; - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .map_err(EmbedError::openai_runtime_init) - .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; - - // Get dimensions from Ollama - let request = - OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" }; - // TODO: Refactor into shared error type - let client = embedder - .new_client() - .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; - - rt.block_on(async move { - let response = client - .post(get_ollama_path()) - .json(&request) - .send() - .await - .map_err(EmbedError::ollama_unexpected) - .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; - - // Process error in case model not found - let response = Self::check_response(response).await.map_err(|_err| { - let e = EmbedError::ollama_model_not_found(OllamaError { - error: format!("model: {}", embedder.options.embedding_model.name()), - }); - NewEmbedderError::ollama_could_not_determine_dimension(e) - })?; - - let response: OllamaResponse = response - .json() - .await - .map_err(EmbedError::ollama_unexpected) - .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; - - let embedding = Embeddings::from_single_embedding(response.embedding); - - embedder.options.embedding_model.dimensions = embedding.dimension(); - - tracing::info!( - "ollama model {} with dimensionality {} added", - embedder.options.embedding_model.name(), - embedding.dimension() - ); - - Ok(embedder) - }) - } - - async fn check_response(response: reqwest::Response) -> Result { - if !response.status().is_success() { - // Not the same number of possible error cases covered as with OpenAI. 
-            match response.status() {
-                StatusCode::TOO_MANY_REQUESTS => {
-                    let error_response: OllamaError = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::ollama_unexpected)
-                        .map_err(Retry::retry_later)?;
-
-                    return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests(
-                        OllamaError { error: error_response.error },
-                    )));
-                }
-                StatusCode::SERVICE_UNAVAILABLE => {
-                    let error_response: OllamaError = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::ollama_unexpected)
-                        .map_err(Retry::retry_later)?;
-                    return Err(Retry::retry_later(EmbedError::ollama_internal_server_error(
-                        OllamaError { error: error_response.error },
-                    )));
-                }
-                StatusCode::NOT_FOUND => {
-                    let error_response: OllamaError = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::ollama_unexpected)
-                        .map_err(Retry::give_up)?;
-
-                    return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError {
-                        error: error_response.error,
-                    })));
-                }
-                code => {
-                    return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code(
-                        code.as_u16(),
-                    )));
-                }
+        let model = options.embedding_model.as_str();
+        let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
+            api_key: None,
+            distribution: None,
+            dimensions: None,
+            url: get_ollama_path(),
+            query: serde_json::json!({
+                "model": model,
+            }),
+            input_field: vec!["prompt".to_owned()],
+            path_to_embeddings: Default::default(),
+            embedding_object: vec!["embedding".to_owned()],
+            input_type: super::rest::InputType::Text,
+        }) {
+            Ok(embedder) => embedder,
+            Err(NewEmbedderError {
+                kind:
+                    NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError {
+                        kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error),
+                        fault: _,
+                    }),
+                fault: _,
+            }) => {
+                return Err(NewEmbedderError::could_not_determine_dimension(
+                    EmbedError::ollama_model_not_found(error),
+                ))
             }
-        }
-        Ok(response)
+            Err(error) => return Err(error),
+        };
+
+        Ok(Self { rest_embedder })
     }

-    pub async fn embed(
-        &self,
-        texts: Vec<String>,
-        client: &reqwest::Client,
-    ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
-        // Ollama only embedds one document at a time.
- let mut results = Vec::with_capacity(texts.len()); - - // The retry loop is inside the texts loop, might have to switch that around - for text in texts { - // Retries copied from openai.rs - for attempt in 0..7 { - let retry_duration = match self.try_embed(&text, client).await { - Ok(result) => { - results.push(result); - break; - } - Err(retry) => { - tracing::warn!("Failed: {}", retry.error); - retry.into_duration(attempt) - } - }?; - tracing::warn!( - "Attempt #{}, retrying after {}ms.", - attempt, - retry_duration.as_millis() - ); - tokio::time::sleep(retry_duration).await; + pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { + match self.rest_embedder.embed(texts) { + Ok(embeddings) => Ok(embeddings), + Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => { + Err(EmbedError::ollama_model_not_found(error)) } + Err(error) => Err(error), } - - Ok(results) - } - - async fn try_embed( - &self, - text: &str, - client: &reqwest::Client, - ) -> Result, Retry> { - let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text }; - let response = client - .post(get_ollama_path()) - .json(&request) - .send() - .await - .map_err(EmbedError::openai_network) - .map_err(Retry::retry_later)?; - - let response = Self::check_response(response).await?; - - let response: OllamaResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - tracing::trace!("response: {:?}", response.embedding); - - let embedding = Embeddings::from_single_embedding(response.embedding); - Ok(embedding) } pub fn embed_chunks( &self, text_chunks: Vec>, + threads: &rayon::ThreadPool, ) -> Result>>, EmbedError> { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .map_err(EmbedError::openai_runtime_init)?; - let client = self.new_client()?; - rt.block_on(futures::future::try_join_all( - text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)), - )) + threads.install(move || { + text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + }) } - // Defaults copied from openai.rs pub fn chunk_count_hint(&self) -> usize { - 10 + self.rest_embedder.chunk_count_hint() } pub fn prompt_count_in_chunk_hint(&self) -> usize { - 10 + self.rest_embedder.prompt_count_in_chunk_hint() } pub fn dimensions(&self) -> usize { - self.options.embedding_model.dimensions + self.rest_embedder.dimensions() } pub fn distribution(&self) -> Option { @@ -295,12 +96,6 @@ impl Embedder { } } -impl Display for OllamaError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.error) - } -} - fn get_ollama_path() -> String { // Important: Hostname not enough, has to be entire path to embeddings endpoint std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string()) diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 5d13d5ee2..b2638966e 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -1,9 +1,9 @@ -use std::fmt::Display; - -use serde::{Deserialize, Serialize}; +use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; use super::error::{EmbedError, NewEmbedderError}; -use super::{DistributionShift, Embedding, Embeddings}; +use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; +use super::{DistributionShift, Embeddings}; +use crate::vector::error::EmbedErrorKind; #[derive(Debug, Clone, Hash, PartialEq, Eq, 
serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { @@ -12,6 +12,32 @@ pub struct EmbedderOptions { pub dimensions: Option, } +impl EmbedderOptions { + pub fn dimensions(&self) -> usize { + if self.embedding_model.supports_overriding_dimensions() { + self.dimensions.unwrap_or(self.embedding_model.default_dimensions()) + } else { + self.embedding_model.default_dimensions() + } + } + + pub fn query(&self) -> serde_json::Value { + let model = self.embedding_model.name(); + + let mut query = serde_json::json!({ + "model": model, + }); + + if self.embedding_model.supports_overriding_dimensions() { + if let Some(dimensions) = self.dimensions { + query["dimensions"] = dimensions.into(); + } + } + + query + } +} + #[derive( Debug, Clone, @@ -117,364 +143,112 @@ impl EmbedderOptions { } } -// retrying in case of failure - -pub struct Retry { - pub error: EmbedError, - strategy: RetryStrategy, -} - -pub enum RetryStrategy { - GiveUp, - Retry, - RetryTokenized, - RetryAfterRateLimit, -} - -impl Retry { - pub fn give_up(error: EmbedError) -> Self { - Self { error, strategy: RetryStrategy::GiveUp } - } - - pub fn retry_later(error: EmbedError) -> Self { - Self { error, strategy: RetryStrategy::Retry } - } - - pub fn retry_tokenized(error: EmbedError) -> Self { - Self { error, strategy: RetryStrategy::RetryTokenized } - } - - pub fn rate_limited(error: EmbedError) -> Self { - Self { error, strategy: RetryStrategy::RetryAfterRateLimit } - } - - pub fn into_duration(self, attempt: u32) -> Result { - match self.strategy { - RetryStrategy::GiveUp => Err(self.error), - RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))), - RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)), - RetryStrategy::RetryAfterRateLimit => { - Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt))) - } - } - } - - pub fn must_tokenize(&self) -> bool { - matches!(self.strategy, RetryStrategy::RetryTokenized) - } - - pub fn into_error(self) -> EmbedError { - self.error - } -} - -// openai api structs - -#[derive(Debug, Serialize)] -struct OpenAiRequest<'a, S: AsRef + serde::Serialize> { - model: &'a str, - input: &'a [S], - #[serde(skip_serializing_if = "Option::is_none")] - dimensions: Option, -} - -#[derive(Debug, Serialize)] -struct OpenAiTokensRequest<'a> { - model: &'a str, - input: &'a [usize], - #[serde(skip_serializing_if = "Option::is_none")] - dimensions: Option, -} - -#[derive(Debug, Deserialize)] -struct OpenAiResponse { - data: Vec, -} - -#[derive(Debug, Deserialize)] -struct OpenAiErrorResponse { - error: OpenAiError, -} - -#[derive(Debug, Deserialize)] -pub struct OpenAiError { - message: String, - // type: String, - code: Option, -} - -impl Display for OpenAiError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.code { - Some(code) => write!(f, "{} ({})", self.message, code), - None => write!(f, "{}", self.message), - } - } -} - -#[derive(Debug, Deserialize)] -struct OpenAiEmbedding { - embedding: Embedding, - // object: String, - // index: usize, -} - fn infer_api_key() -> String { std::env::var("MEILI_OPENAI_API_KEY") .or_else(|_| std::env::var("OPENAI_API_KEY")) .unwrap_or_default() } -pub mod sync { - use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; +#[derive(Debug)] +pub struct Embedder { + tokenizer: tiktoken_rs::CoreBPE, + rest_embedder: RestEmbedder, + options: EmbedderOptions, +} - use super::{ - EmbedError, Embedding, Embeddings, NewEmbedderError, OpenAiErrorResponse, 
OpenAiRequest, - OpenAiResponse, OpenAiTokensRequest, Retry, OPENAI_EMBEDDINGS_URL, - }; - use crate::vector::DistributionShift; +impl Embedder { + pub fn new(options: EmbedderOptions) -> Result { + let mut inferred_api_key = Default::default(); + let api_key = options.api_key.as_ref().unwrap_or_else(|| { + inferred_api_key = infer_api_key(); + &inferred_api_key + }); - const REQUEST_PARALLELISM: usize = 10; + let rest_embedder = RestEmbedder::new(RestEmbedderOptions { + api_key: Some(api_key.clone()), + distribution: options.embedding_model.distribution(), + dimensions: Some(options.dimensions()), + url: OPENAI_EMBEDDINGS_URL.to_owned(), + query: options.query(), + input_field: vec!["input".to_owned()], + input_type: crate::vector::rest::InputType::TextArray, + path_to_embeddings: vec!["data".to_owned()], + embedding_object: vec!["embedding".to_owned()], + })?; - #[derive(Debug)] - pub struct Embedder { - tokenizer: tiktoken_rs::CoreBPE, - options: super::EmbedderOptions, - bearer: String, - threads: rayon::ThreadPool, + // looking at the code it is very unclear that this can actually fail. + let tokenizer = tiktoken_rs::cl100k_base().unwrap(); + + Ok(Self { options, rest_embedder, tokenizer }) } - impl Embedder { - pub fn new(options: super::EmbedderOptions) -> Result { - let mut inferred_api_key = Default::default(); - let api_key = options.api_key.as_ref().unwrap_or_else(|| { - inferred_api_key = super::infer_api_key(); - &inferred_api_key - }); - let bearer = format!("Bearer {api_key}"); - - // looking at the code it is very unclear that this can actually fail. - let tokenizer = tiktoken_rs::cl100k_base().unwrap(); - - // FIXME: unwrap - let threads = rayon::ThreadPoolBuilder::new() - .num_threads(REQUEST_PARALLELISM) - .thread_name(|index| format!("embedder-chunk-{index}")) - .build() - .unwrap(); - - Ok(Self { options, bearer, tokenizer, threads }) + pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { + match self.rest_embedder.embed_ref(&texts) { + Ok(embeddings) => Ok(embeddings), + Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error), fault: _ }) => { + tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. 
For best performance, limit the size of your document template."); + self.try_embed_tokenized(&texts) + } + Err(error) => Err(error), } + } - pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { - let mut tokenized = false; - - let client = ureq::agent(); - - for attempt in 0..7 { - let result = if tokenized { - self.try_embed_tokenized(&texts, &client) - } else { - self.try_embed(&texts, &client) - }; - - let retry_duration = match result { - Ok(embeddings) => return Ok(embeddings), - Err(retry) => { - tracing::warn!("Failed: {}", retry.error); - tokenized |= retry.must_tokenize(); - retry.into_duration(attempt) - } - }?; - - let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute - tracing::warn!( - "Attempt #{}, retrying after {}ms.", - attempt, - retry_duration.as_millis() - ); - std::thread::sleep(retry_duration); + fn try_embed_tokenized(&self, text: &[String]) -> Result>, EmbedError> { + pub const OVERLAP_SIZE: usize = 200; + let mut all_embeddings = Vec::with_capacity(text.len()); + for text in text { + let max_token_count = self.options.embedding_model.max_token(); + let encoded = self.tokenizer.encode_ordinary(text.as_str()); + let len = encoded.len(); + if len < max_token_count { + all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?); + continue; } - let result = if tokenized { - self.try_embed_tokenized(&texts, &client) - } else { - self.try_embed(&texts, &client) - }; + let mut tokens = encoded.as_slice(); + let mut embeddings_for_prompt = Embeddings::new(self.dimensions()); + while tokens.len() > max_token_count { + let window = &tokens[..max_token_count]; + let embedding = self.rest_embedder.embed_tokens(window)?; + /// FIXME: unwrap + embeddings_for_prompt.append(embedding.into_inner()).unwrap(); - result.map_err(Retry::into_error) - } - - fn check_response( - response: Result, - ) -> Result { - match response { - Ok(response) => Ok(response), - Err(ureq::Error::Status(code, response)) => { - let error_response: Option = response.into_json().ok(); - let error = error_response.map(|response| response.error); - Err(match code { - 401 => Retry::give_up(EmbedError::openai_auth_error(error)), - 429 => Retry::rate_limited(EmbedError::openai_too_many_requests(error)), - 400 => { - tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. 
For best performance, limit the size of your document template."); - - Retry::retry_tokenized(EmbedError::openai_too_many_tokens(error)) - } - 500..=599 => { - Retry::retry_later(EmbedError::openai_internal_server_error(error)) - } - x => Retry::retry_later(EmbedError::openai_unhandled_status_code(code)), - }) - } - Err(ureq::Error::Transport(transport)) => { - Err(Retry::retry_later(EmbedError::openai_network(transport))) - } - } - } - - fn try_embed + serde::Serialize>( - &self, - texts: &[S], - client: &ureq::Agent, - ) -> Result>, Retry> { - for text in texts { - tracing::trace!("Received prompt: {}", text.as_ref()) - } - let request = OpenAiRequest { - model: self.options.embedding_model.name(), - input: texts, - dimensions: self.overriden_dimensions(), - }; - let response = client - .post(OPENAI_EMBEDDINGS_URL) - .set("Authorization", &self.bearer) - .send_json(&request); - - let response = Self::check_response(response)?; - - let response: OpenAiResponse = response - .into_json() - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - tracing::trace!("response: {:?}", response.data); - - Ok(response - .data - .into_iter() - .map(|data| Embeddings::from_single_embedding(data.embedding)) - .collect()) - } - - fn try_embed_tokenized( - &self, - text: &[String], - client: &ureq::Agent, - ) -> Result>, Retry> { - pub const OVERLAP_SIZE: usize = 200; - let mut all_embeddings = Vec::with_capacity(text.len()); - for text in text { - let max_token_count = self.options.embedding_model.max_token(); - let encoded = self.tokenizer.encode_ordinary(text.as_str()); - let len = encoded.len(); - if len < max_token_count { - all_embeddings.append(&mut self.try_embed(&[text], client)?); - continue; - } - - let mut tokens = encoded.as_slice(); - let mut embeddings_for_prompt = Embeddings::new(self.dimensions()); - while tokens.len() > max_token_count { - let window = &tokens[..max_token_count]; - embeddings_for_prompt.push(self.embed_tokens(window, client)?).unwrap(); - - tokens = &tokens[max_token_count - OVERLAP_SIZE..]; - } - - // end of text - embeddings_for_prompt.push(self.embed_tokens(tokens, client)?).unwrap(); - - all_embeddings.push(embeddings_for_prompt); - } - Ok(all_embeddings) - } - - fn embed_tokens(&self, tokens: &[usize], client: &ureq::Agent) -> Result { - for attempt in 0..9 { - let duration = match self.try_embed_tokens(tokens, client) { - Ok(embedding) => return Ok(embedding), - Err(retry) => retry.into_duration(attempt), - } - .map_err(Retry::retry_later)?; - - std::thread::sleep(duration); + tokens = &tokens[max_token_count - OVERLAP_SIZE..]; } - self.try_embed_tokens(tokens, client) - .map_err(|retry| Retry::give_up(retry.into_error())) + // end of text + let embedding = self.rest_embedder.embed_tokens(tokens)?; + /// FIXME: unwrap + embeddings_for_prompt.append(embedding.into_inner()).unwrap(); + + all_embeddings.push(embeddings_for_prompt); } + Ok(all_embeddings) + } - fn try_embed_tokens( - &self, - tokens: &[usize], - client: &ureq::Agent, - ) -> Result { - let request = OpenAiTokensRequest { - model: self.options.embedding_model.name(), - input: tokens, - dimensions: self.overriden_dimensions(), - }; - let response = client - .post(OPENAI_EMBEDDINGS_URL) - .set("Authorization", &self.bearer) - .send_json(&request); + pub fn embed_chunks( + &self, + text_chunks: Vec>, + threads: &rayon::ThreadPool, + ) -> Result>>, EmbedError> { + threads.install(move || { + text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + }) + } - let 
response = Self::check_response(response)?; + pub fn chunk_count_hint(&self) -> usize { + self.rest_embedder.chunk_count_hint() + } - let mut response: OpenAiResponse = response - .into_json() - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; + pub fn prompt_count_in_chunk_hint(&self) -> usize { + self.rest_embedder.prompt_count_in_chunk_hint() + } - Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default()) - } + pub fn dimensions(&self) -> usize { + self.options.dimensions() + } - pub fn embed_chunks( - &self, - text_chunks: Vec>, - ) -> Result>>, EmbedError> { - self.threads - .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk))) - .collect() - } - - pub fn chunk_count_hint(&self) -> usize { - 10 - } - - pub fn prompt_count_in_chunk_hint(&self) -> usize { - 10 - } - - pub fn dimensions(&self) -> usize { - if self.options.embedding_model.supports_overriding_dimensions() { - self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions()) - } else { - self.options.embedding_model.default_dimensions() - } - } - - pub fn distribution(&self) -> Option { - self.options.embedding_model.distribution() - } - - fn overriden_dimensions(&self) -> Option { - if self.options.embedding_model.supports_overriding_dimensions() { - self.options.dimensions - } else { - None - } - } + pub fn distribution(&self) -> Option { + self.options.embedding_model.distribution() } } diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs index 975bd3790..6fd47d882 100644 --- a/milli/src/vector/rest.rs +++ b/milli/src/vector/rest.rs @@ -1,9 +1,62 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; +use serde::Serialize; -use super::openai::Retry; -use super::{DistributionShift, EmbedError, Embeddings, NewEmbedderError}; -use crate::VectorOrArrayOfVectors; +use super::{ + DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, +}; +// retrying in case of failure + +pub struct Retry { + pub error: EmbedError, + strategy: RetryStrategy, +} + +pub enum RetryStrategy { + GiveUp, + Retry, + RetryTokenized, + RetryAfterRateLimit, +} + +impl Retry { + pub fn give_up(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::GiveUp } + } + + pub fn retry_later(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::Retry } + } + + pub fn retry_tokenized(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::RetryTokenized } + } + + pub fn rate_limited(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::RetryAfterRateLimit } + } + + pub fn into_duration(self, attempt: u32) -> Result { + match self.strategy { + RetryStrategy::GiveUp => Err(self.error), + RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))), + RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)), + RetryStrategy::RetryAfterRateLimit => { + Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt))) + } + } + } + + pub fn must_tokenize(&self) -> bool { + matches!(self.strategy, RetryStrategy::RetryTokenized) + } + + pub fn into_error(self) -> EmbedError { + self.error + } +} + +#[derive(Debug)] pub struct Embedder { client: ureq::Agent, options: EmbedderOptions, @@ -11,20 +64,35 @@ pub struct Embedder { dimensions: usize, } +#[derive(Debug)] pub struct EmbedderOptions { - api_key: Option, - distribution: Option, - dimensions: Option, - url: String, - query: liquid::Template, - response_field: Vec, + pub api_key: 
Option, + pub distribution: Option, + pub dimensions: Option, + pub url: String, + pub query: serde_json::Value, + pub input_field: Vec, + // path to the array of embeddings + pub path_to_embeddings: Vec, + // shape of a single embedding + pub embedding_object: Vec, + pub input_type: InputType, +} + +#[derive(Debug)] +pub enum InputType { + Text, + TextArray, } impl Embedder { pub fn new(options: EmbedderOptions) -> Result { - let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer: {api_key}")); + let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}")); - let client = ureq::agent(); + let client = ureq::AgentBuilder::new() + .max_idle_connections(REQUEST_PARALLELISM * 2) + .max_idle_connections_per_host(REQUEST_PARALLELISM * 2) + .build(); let dimensions = if let Some(dimensions) = options.dimensions { dimensions @@ -36,7 +104,20 @@ impl Embedder { } pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { - embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice()) + embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len()) + } + + pub fn embed_ref(&self, texts: &[S]) -> Result>, EmbedError> + where + S: AsRef + Serialize, + { + embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len()) + } + + pub fn embed_tokens(&self, tokens: &[usize]) -> Result, EmbedError> { + let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?; + // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error + Ok(embeddings.pop().unwrap()) } pub fn embed_chunks( @@ -44,17 +125,20 @@ impl Embedder { text_chunks: Vec>, threads: &rayon::ThreadPool, ) -> Result>>, EmbedError> { - threads - .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk))) - .collect() + threads.install(move || { + text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + }) } pub fn chunk_count_hint(&self) -> usize { - 10 + super::REQUEST_PARALLELISM } pub fn prompt_count_in_chunk_hint(&self) -> usize { - 10 + match self.options.input_type { + InputType::Text => 1, + InputType::TextArray => 10, + } } pub fn dimensions(&self) -> usize { @@ -71,9 +155,9 @@ fn infer_dimensions( options: &EmbedderOptions, bearer: Option<&str>, ) -> Result { - let v = embed(client, options, bearer, ["test"].as_slice()) + let v = embed(client, options, bearer, ["test"].as_slice(), 1) .map_err(NewEmbedderError::could_not_determine_dimension)?; - // unwrap: guaranteed that v.len() == ["test"].len() == 1, otherwise the previous line terminated in error + // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error Ok(v.first().unwrap().dimension()) } @@ -82,33 +166,57 @@ fn embed( options: &EmbedderOptions, bearer: Option<&str>, inputs: &[S], + expected_count: usize, ) -> Result>, EmbedError> where - S: serde::Serialize, + S: Serialize, { let request = client.post(&options.url); let request = if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request }; let request = request.set("Content-Type", "application/json"); - let body = options - .query - .render( - &liquid::to_object(&serde_json::json!({ - "input": inputs, - })) - .map_err(EmbedError::rest_template_context_serialization)?, - ) - .map_err(EmbedError::rest_template_render)?; + let input_value = match options.input_type { + InputType::Text => serde_json::json!(inputs.first()), + InputType::TextArray => serde_json::json!(inputs), + }; + + let 
body = match options.input_field.as_slice() { + [] => { + // inject input in body + input_value + } + [input] => { + let mut body = options.query.clone(); + + /// FIXME unwrap + body.as_object_mut().unwrap().insert(input.clone(), input_value); + body + } + [path @ .., input] => { + let mut body = options.query.clone(); + + /// FIXME unwrap + let mut current_value = &mut body; + for component in path { + current_value = current_value + .as_object_mut() + .unwrap() + .entry(component.clone()) + .or_insert(serde_json::json!({})); + } + + current_value.as_object_mut().unwrap().insert(input.clone(), input_value); + body + } + }; for attempt in 0..7 { - let response = request.send_string(&body); + let response = request.clone().send_json(&body); let result = check_response(response); let retry_duration = match result { - Ok(response) => { - return response_to_embedding(response, &options.response_field, inputs.len()) - } + Ok(response) => return response_to_embedding(response, options, expected_count), Err(retry) => { tracing::warn!("Failed: {}", retry.error); retry.into_duration(attempt) @@ -120,11 +228,11 @@ where std::thread::sleep(retry_duration); } - let response = request.send_string(&body); + let response = request.send_json(&body); let result = check_response(response); result .map_err(Retry::into_error) - .and_then(|response| response_to_embedding(response, &options.response_field, inputs.len())) + .and_then(|response| response_to_embedding(response, options, expected_count)) } fn check_response(response: Result) -> Result { @@ -139,7 +247,10 @@ fn check_response(response: Result) -> Result { Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response)) } - x => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)), + 402..=499 => { + Retry::give_up(EmbedError::rest_other_status_code(code, error_response)) + } + _ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)), }) } Err(ureq::Error::Transport(transport)) => { @@ -148,34 +259,66 @@ fn check_response(response: Result) -> Result>( +fn response_to_embedding( response: ureq::Response, - response_field: &[S], + options: &EmbedderOptions, expected_count: usize, ) -> Result>, EmbedError> { let response: serde_json::Value = response.into_json().map_err(EmbedError::rest_response_deserialization)?; let mut current_value = &response; - for component in response_field { + for component in &options.path_to_embeddings { let component = component.as_ref(); - let current_value = current_value.get(component).ok_or_else(|| { - EmbedError::rest_response_missing_embeddings(response, component, response_field) + current_value = current_value.get(component).ok_or_else(|| { + EmbedError::rest_response_missing_embeddings( + response.clone(), + component, + &options.path_to_embeddings, + ) })?; } - let embeddings = current_value.to_owned(); + let embeddings = match options.input_type { + InputType::Text => { + for component in &options.embedding_object { + current_value = current_value.get(component).ok_or_else(|| { + EmbedError::rest_response_missing_embeddings( + response.clone(), + component, + &options.embedding_object, + ) + })?; + } + let embeddings = current_value.to_owned(); + let embeddings: Embedding = + serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?; - let embeddings: VectorOrArrayOfVectors = - serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?; - - let embeddings = embeddings.into_array_of_vectors(); - - let embeddings: 
Vec> = embeddings - .into_iter() - .flatten() - .map(|embedding| Embeddings::from_single_embedding(embedding)) - .collect(); + vec![Embeddings::from_single_embedding(embeddings)] + } + InputType::TextArray => { + let empty = vec![]; + let values = current_value.as_array().unwrap_or(&empty); + let mut embeddings: Vec> = Vec::with_capacity(expected_count); + for value in values { + let mut current_value = value; + for component in &options.embedding_object { + current_value = current_value.get(component).ok_or_else(|| { + EmbedError::rest_response_missing_embeddings( + response.clone(), + component, + &options.embedding_object, + ) + })?; + } + let embedding = current_value.to_owned(); + let embedding: Embedding = + serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?; + embeddings.push(Embeddings::from_single_embedding(embedding)); + } + embeddings + } + }; if embeddings.len() != expected_count { return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len())); diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 89571e98a..540693d44 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -204,7 +204,7 @@ impl From for EmbeddingSettings { }, super::EmbedderOptions::Ollama(options) => Self { source: Setting::Set(EmbedderSource::Ollama), - model: Setting::Set(options.embedding_model.name().to_owned()), + model: Setting::Set(options.embedding_model.to_owned()), revision: Setting::NotSet, api_key: Setting::NotSet, dimensions: Setting::NotSet, @@ -248,7 +248,7 @@ impl From for EmbeddingConfig { let mut options: ollama::EmbedderOptions = super::ollama::EmbedderOptions::with_default_model(); if let Some(model) = model.set() { - options.embedding_model = super::ollama::EmbeddingModel::from_name(&model); + options.embedding_model = model; } this.embedder_options = super::EmbedderOptions::Ollama(options); } From f649f58013c969908a2a2753ab91e1689e766b2d Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 19 Mar 2024 15:42:53 +0100 Subject: [PATCH 5/9] embed no longer async --- meilisearch/src/routes/indexes/search.rs | 7 +++---- meilisearch/src/routes/multi_search.rs | 5 ++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 6a430b6a3..8de2be13f 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -202,7 +202,7 @@ pub async fn search_with_url_query( let index = index_scheduler.index(&index_uid)?; let features = index_scheduler.features(); - let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?; let search_result = tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) @@ -241,7 +241,7 @@ pub async fn search_with_post( let features = index_scheduler.features(); - let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?; let search_result = tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) @@ -260,7 +260,7 @@ pub async fn search_with_post( Ok(HttpResponse::Ok().json(search_result)) } -pub async fn embed( +pub fn embed( query: &mut SearchQuery, index_scheduler: &IndexScheduler, index: &milli::Index, @@ -287,7 +287,6 @@ pub async fn embed( let embeddings = embedder 
.embed(vec![q.to_owned()]) - .await .map_err(milli::vector::Error::from) .map_err(milli::Error::from)? .pop() diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs index 86aa58e70..f54b8ae8f 100644 --- a/meilisearch/src/routes/multi_search.rs +++ b/meilisearch/src/routes/multi_search.rs @@ -75,9 +75,8 @@ pub async fn multi_search_with_post( }) .with_index(query_index)?; - let distribution = embed(&mut query, index_scheduler.get_ref(), &index) - .await - .with_index(query_index)?; + let distribution = + embed(&mut query, index_scheduler.get_ref(), &index).with_index(query_index)?; let search_result = tokio::task::spawn_blocking(move || { perform_search(&index, query, features, distribution) From b6b4b6bab7563180475d0f306886870086804a32 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 19 Mar 2024 15:43:12 +0100 Subject: [PATCH 6/9] Remove the tokio and the reqwests --- Cargo.lock | 3 --- milli/Cargo.toml | 6 ------ 2 files changed, 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 60d0e4c0e..6a8c20f12 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3338,7 +3338,6 @@ dependencies = [ "filter-parser", "flatten-serde-json", "fst", - "futures", "fxhash", "geoutils", "grenad", @@ -3362,7 +3361,6 @@ dependencies = [ "rand", "rand_pcg", "rayon", - "reqwest", "roaring", "rstar", "serde", @@ -3376,7 +3374,6 @@ dependencies = [ "tiktoken-rs", "time", "tokenizers", - "tokio", "tracing", "ureq", "uuid", diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 59b3699cc..4833ad00b 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -80,12 +80,6 @@ tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0. hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [ "online", ] } -tokio = { version = "1.35.1", features = ["rt"] } -futures = "0.3.30" -reqwest = { version = "0.11.23", features = [ - "rustls-tls", - "json", -], default-features = false } tiktoken-rs = "0.5.8" liquid = "0.26.4" arroy = "0.2.0" From f87747f4d32af96a152776d021075782ce058c0f Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 20 Mar 2024 13:25:10 +0100 Subject: [PATCH 7/9] Remove unwraps --- .../src/update/index_documents/extract/mod.rs | 4 +--- milli/src/vector/error.rs | 20 +++++++++++++++++-- milli/src/vector/openai.rs | 11 ++++++---- milli/src/vector/rest.rs | 18 +++++++++++++---- 4 files changed, 40 insertions(+), 13 deletions(-) diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 5689bb04f..82486f3a8 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -242,11 +242,9 @@ fn send_original_documents_data( let request_threads = rayon::ThreadPoolBuilder::new() .num_threads(crate::vector::REQUEST_PARALLELISM) .thread_name(|index| format!("embedding-request-{index}")) - .build() - .unwrap(); + .build()?; rayon::spawn(move || { - /// FIXME: unwrap for (name, (embedder, prompt)) in embedders { let result = extract_vector_points( documents_chunk_cloned.clone(), diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index 92f077924..1e0bcc7fb 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -52,8 +52,6 @@ pub enum EmbedErrorKind { ModelForward(candle_core::Error), #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] ManualEmbed(String), - #[error("could not initialize 
asynchronous runtime: {0}")] - OpenAiRuntimeInit(std::io::Error), #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")] OllamaModelNotFoundError(Option), #[error("error deserialization the response body as JSON: {0}")] @@ -76,6 +74,10 @@ pub enum EmbedErrorKind { RestOtherStatusCode(u16, Option), #[error("could not reach embedding server: {0}")] RestNetwork(ureq::Transport), + #[error("was expected '{}' to be an object in query '{0}'", .1.join("."))] + RestNotAnObject(serde_json::Value, Vec), + #[error("while embedding tokenized, was expecting embeddings of dimension `{0}`, got embeddings of dimensions `{1}`")] + OpenAiUnexpectedDimension(usize, usize), } impl EmbedError { @@ -174,6 +176,20 @@ impl EmbedError { pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError { Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime } } + + pub(crate) fn rest_not_an_object( + query: serde_json::Value, + input_path: Vec, + ) -> EmbedError { + Self { kind: EmbedErrorKind::RestNotAnObject(query, input_path), fault: FaultSource::User } + } + + pub(crate) fn openai_unexpected_dimension(expected: usize, got: usize) -> EmbedError { + Self { + kind: EmbedErrorKind::OpenAiUnexpectedDimension(expected, got), + fault: FaultSource::Runtime, + } + } } #[derive(Debug, thiserror::Error)] diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index b2638966e..737878a1a 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -210,16 +210,19 @@ impl Embedder { while tokens.len() > max_token_count { let window = &tokens[..max_token_count]; let embedding = self.rest_embedder.embed_tokens(window)?; - /// FIXME: unwrap - embeddings_for_prompt.append(embedding.into_inner()).unwrap(); + embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| { + EmbedError::openai_unexpected_dimension(self.dimensions(), got.len()) + })?; tokens = &tokens[max_token_count - OVERLAP_SIZE..]; } // end of text let embedding = self.rest_embedder.embed_tokens(tokens)?; - /// FIXME: unwrap - embeddings_for_prompt.append(embedding.into_inner()).unwrap(); + + embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| { + EmbedError::openai_unexpected_dimension(self.dimensions(), got.len()) + })?; all_embeddings.push(embeddings_for_prompt); } diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs index 6fd47d882..8650bb68d 100644 --- a/milli/src/vector/rest.rs +++ b/milli/src/vector/rest.rs @@ -189,19 +189,29 @@ where [input] => { let mut body = options.query.clone(); - /// FIXME unwrap - body.as_object_mut().unwrap().insert(input.clone(), input_value); + body.as_object_mut() + .ok_or_else(|| { + EmbedError::rest_not_an_object( + options.query.clone(), + options.input_field.clone(), + ) + })? + .insert(input.clone(), input_value); body } [path @ .., input] => { let mut body = options.query.clone(); - /// FIXME unwrap let mut current_value = &mut body; for component in path { current_value = current_value .as_object_mut() - .unwrap() + .ok_or_else(|| { + EmbedError::rest_not_an_object( + options.query.clone(), + options.input_field.clone(), + ) + })? 
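+                        // descend into this path component of the query object,
+                        // creating an empty object if it is missing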
.entry(component.clone()) .or_insert(serde_json::json!({})); } From a1db342f01076a6ccebb2e76a8fe43fbf2ac22a0 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Mon, 25 Mar 2024 10:05:38 +0100 Subject: [PATCH 8/9] Expose REST embedder to the API --- milli/src/update/index_documents/mod.rs | 6 + milli/src/update/settings.rs | 74 +++++++++- milli/src/vector/hf.rs | 5 +- milli/src/vector/mod.rs | 58 +++++++- milli/src/vector/openai.rs | 22 +-- milli/src/vector/rest.rs | 41 +++++- milli/src/vector/settings.rs | 183 ++++++++++++++++++++++-- 7 files changed, 357 insertions(+), 32 deletions(-) diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 7499b68e5..913fbc881 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -2646,6 +2646,12 @@ mod tests { api_key: Setting::NotSet, dimensions: Setting::Set(3), document_template: Setting::NotSet, + url: Setting::NotSet, + query: Setting::NotSet, + input_field: Setting::NotSet, + path_to_embeddings: Setting::NotSet, + embedding_object: Setting::NotSet, + input_type: Setting::NotSet, }), ); settings.set_embedder_settings(embedders); diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index f54f45e1e..4c7289eb7 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1140,6 +1140,12 @@ fn validate_prompt( api_key, dimensions, document_template: Setting::Set(template), + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, }) => { // validate let template = crate::prompt::Prompt::new(template) @@ -1153,6 +1159,12 @@ fn validate_prompt( api_key, dimensions, document_template: Setting::Set(template), + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, })) } new => Ok(new), @@ -1165,8 +1177,20 @@ pub fn validate_embedding_settings( ) -> Result> { let settings = validate_prompt(name, settings)?; let Setting::Set(settings) = settings else { return Ok(settings) }; - let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = - settings; + let EmbeddingSettings { + source, + model, + revision, + api_key, + dimensions, + document_template, + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, + } = settings; if let Some(0) = dimensions.set() { return Err(crate::error::UserError::InvalidSettingsDimensions { @@ -1183,11 +1207,25 @@ pub fn validate_embedding_settings( api_key, dimensions, document_template, + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, })); }; match inferred_source { EmbedderSource::OpenAi => { check_unset(&revision, "revision", inferred_source, name)?; + + check_unset(&url, "url", inferred_source, name)?; + check_unset(&query, "query", inferred_source, name)?; + check_unset(&input_field, "inputField", inferred_source, name)?; + check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; + check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; + check_unset(&input_type, "inputType", inferred_source, name)?; + if let Setting::Set(model) = &model { let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str()) .ok_or(crate::error::UserError::InvalidOpenAiModel { @@ -1224,10 +1262,24 @@ pub fn validate_embedding_settings( check_set(&model, "model", inferred_source, name)?; check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&revision, "revision", inferred_source, name)?; + + 
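+            // As with the OpenAI source above, the options specific to the generic
+            // REST embedder must remain unset for an Ollama embedder.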
check_unset(&url, "url", inferred_source, name)?; + check_unset(&query, "query", inferred_source, name)?; + check_unset(&input_field, "inputField", inferred_source, name)?; + check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; + check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; + check_unset(&input_type, "inputType", inferred_source, name)?; } EmbedderSource::HuggingFace => { check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&dimensions, "dimensions", inferred_source, name)?; + + check_unset(&url, "url", inferred_source, name)?; + check_unset(&query, "query", inferred_source, name)?; + check_unset(&input_field, "inputField", inferred_source, name)?; + check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; + check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; + check_unset(&input_type, "inputType", inferred_source, name)?; } EmbedderSource::UserProvided => { check_unset(&model, "model", inferred_source, name)?; @@ -1235,6 +1287,18 @@ pub fn validate_embedding_settings( check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&document_template, "documentTemplate", inferred_source, name)?; check_set(&dimensions, "dimensions", inferred_source, name)?; + + check_unset(&url, "url", inferred_source, name)?; + check_unset(&query, "query", inferred_source, name)?; + check_unset(&input_field, "inputField", inferred_source, name)?; + check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; + check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; + check_unset(&input_type, "inputType", inferred_source, name)?; + } + EmbedderSource::Rest => { + check_unset(&model, "model", inferred_source, name)?; + check_unset(&revision, "revision", inferred_source, name)?; + check_set(&url, "url", inferred_source, name)?; } } Ok(Setting::Set(EmbeddingSettings { @@ -1244,6 +1308,12 @@ pub fn validate_embedding_settings( api_key, dimensions, document_template, + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, })) } diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index 939b6210a..e341a553e 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -194,7 +194,10 @@ impl Embedder { pub fn distribution(&self) -> Option { if self.options.model == "BAAI/bge-base-en-v1.5" { - Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 }) + Some(DistributionShift { + current_mean: ordered_float::OrderedFloat(0.85), + current_sigma: ordered_float::OrderedFloat(0.1), + }) } else { None } diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 39232e387..65654af4a 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -1,6 +1,9 @@ use std::collections::HashMap; use std::sync::Arc; +use ordered_float::OrderedFloat; +use serde::{Deserialize, Serialize}; + use self::error::{EmbedError, NewEmbedderError}; use crate::prompt::{Prompt, PromptData}; @@ -104,7 +107,10 @@ pub enum Embedder { OpenAi(openai::Embedder), /// An embedder based on the user providing the embeddings in the documents and queries. UserProvided(manual::Embedder), + /// An embedder based on making embedding queries against an embedding server. Ollama(ollama::Embedder), + /// An embedder based on making embedding queries against a generic JSON/REST embedding server. + Rest(rest::Embedder), } /// Configuration for an embedder. 
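Note (illustrative sketch, not part of the patch): with the new `Rest` variant,
code inside milli could target a generic embedding server roughly as follows.
The URL and model name are placeholders, and the server is assumed to accept an
OpenAI-shaped request/response, which is what the defaults introduced in this
patch describe (`input_field: ["input"]`, `path_to_embeddings: ["data"]`,
`embedding_object: ["embedding"]`).

    // In a milli function returning `Result<_, crate::Error>`:
    let options = crate::vector::rest::EmbedderOptions {
        url: "http://localhost:8080/v1/embeddings".to_string(),
        query: serde_json::json!({ "model": "all-minilm" }),
        // send the texts as a JSON array under the `input` field
        input_type: crate::vector::rest::InputType::TextArray,
        ..Default::default()
    };
    let embedder = crate::vector::Embedder::new(crate::vector::EmbedderOptions::Rest(options))?;

    // `embed` runs on the calling thread; `embed_chunks` fans out over a rayon pool.
    let threads = rayon::ThreadPoolBuilder::new()
        .num_threads(crate::vector::REQUEST_PARALLELISM)
        .thread_name(|index| format!("embedding-request-{index}"))
        .build()?;
    let _embeddings = embedder.embed(vec!["hello".to_owned()])?;
    let _chunked = embedder.embed_chunks(vec![vec!["world".to_owned()]], &threads)?;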
@@ -175,6 +181,7 @@ pub enum EmbedderOptions { OpenAi(openai::EmbedderOptions), Ollama(ollama::EmbedderOptions), UserProvided(manual::EmbedderOptions), + Rest(rest::EmbedderOptions), } impl Default for EmbedderOptions { @@ -209,6 +216,7 @@ impl Embedder { EmbedderOptions::UserProvided(options) => { Self::UserProvided(manual::Embedder::new(options)) } + EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?), }) } @@ -224,6 +232,7 @@ impl Embedder { Embedder::OpenAi(embedder) => embedder.embed(texts), Embedder::Ollama(embedder) => embedder.embed(texts), Embedder::UserProvided(embedder) => embedder.embed(texts), + Embedder::Rest(embedder) => embedder.embed(texts), } } @@ -240,6 +249,7 @@ impl Embedder { Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads), Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads), Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), + Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads), } } @@ -250,6 +260,7 @@ impl Embedder { Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), Embedder::Ollama(embedder) => embedder.chunk_count_hint(), Embedder::UserProvided(_) => 1, + Embedder::Rest(embedder) => embedder.chunk_count_hint(), } } @@ -260,6 +271,7 @@ impl Embedder { Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::UserProvided(_) => 1, + Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(), } } @@ -270,6 +282,7 @@ impl Embedder { Embedder::OpenAi(embedder) => embedder.dimensions(), Embedder::Ollama(embedder) => embedder.dimensions(), Embedder::UserProvided(embedder) => embedder.dimensions(), + Embedder::Rest(embedder) => embedder.dimensions(), } } @@ -280,6 +293,7 @@ impl Embedder { Embedder::OpenAi(embedder) => embedder.distribution(), Embedder::Ollama(embedder) => embedder.distribution(), Embedder::UserProvided(_embedder) => None, + Embedder::Rest(embedder) => embedder.distribution(), } } } @@ -288,17 +302,47 @@ impl Embedder { /// /// The intended use is to make the similarity score more comparable to the regular ranking score. /// This allows to correct effects where results are too "packed" around a certain value. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)] +#[serde(from = "DistributionShiftSerializable")] +#[serde(into = "DistributionShiftSerializable")] pub struct DistributionShift { /// Value where the results are "packed". /// /// Similarity scores are translated so that they are packed around 0.5 instead - pub current_mean: f32, + pub current_mean: OrderedFloat, /// standard deviation of a similarity score. /// /// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed. 
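+    // For instance (illustrative values): with a mean of 0.7 and a sigma of 0.1,
+    // `shift()` below maps a raw score of 0.8 to 0.5 + 0.4 * (0.8 - 0.7) / 0.1 = 0.9.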
- pub current_sigma: f32, + pub current_sigma: OrderedFloat, +} + +#[derive(Serialize, Deserialize)] +struct DistributionShiftSerializable { + current_mean: f32, + current_sigma: f32, +} + +impl From for DistributionShiftSerializable { + fn from( + DistributionShift { + current_mean: OrderedFloat(current_mean), + current_sigma: OrderedFloat(current_sigma), + }: DistributionShift, + ) -> Self { + Self { current_mean, current_sigma } + } +} + +impl From for DistributionShift { + fn from( + DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable, + ) -> Self { + Self { + current_mean: OrderedFloat(current_mean), + current_sigma: OrderedFloat(current_sigma), + } + } } impl DistributionShift { @@ -307,11 +351,13 @@ impl DistributionShift { if sigma <= 0.0 { None } else { - Some(Self { current_mean: mean, current_sigma: sigma }) + Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) }) } } pub fn shift(&self, score: f32) -> f32 { + let current_mean = self.current_mean.0; + let current_sigma = self.current_sigma.0; // // We're somewhat abusively mapping the distribution of distances to a gaussian. // The parameters we're given is the mean and sigma of the native result distribution. @@ -321,9 +367,9 @@ impl DistributionShift { let target_sigma = 0.4; // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive. - let factor = target_sigma / self.current_sigma; + let factor = target_sigma / current_sigma; // a*mu1 + b = mu2 => b = mu2 - a*mu1 - let offset = target_mean - (factor * self.current_mean); + let offset = target_mean - (factor * current_mean); let mut score = factor * score + offset; diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 737878a1a..24e94a9f7 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -1,3 +1,4 @@ +use ordered_float::OrderedFloat; use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; use super::error::{EmbedError, NewEmbedderError}; @@ -110,15 +111,18 @@ impl EmbeddingModel { fn distribution(&self) -> Option { match self { - EmbeddingModel::TextEmbeddingAda002 => { - Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) - } - EmbeddingModel::TextEmbedding3Large => { - Some(DistributionShift { current_mean: 0.70, current_sigma: 0.1 }) - } - EmbeddingModel::TextEmbedding3Small => { - Some(DistributionShift { current_mean: 0.75, current_sigma: 0.1 }) - } + EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift { + current_mean: OrderedFloat(0.90), + current_sigma: OrderedFloat(0.08), + }), + EmbeddingModel::TextEmbedding3Large => Some(DistributionShift { + current_mean: OrderedFloat(0.70), + current_sigma: OrderedFloat(0.1), + }), + EmbeddingModel::TextEmbedding3Small => Some(DistributionShift { + current_mean: OrderedFloat(0.75), + current_sigma: OrderedFloat(0.1), + }), } } diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs index 8650bb68d..b0ea07f82 100644 --- a/milli/src/vector/rest.rs +++ b/milli/src/vector/rest.rs @@ -1,5 +1,6 @@ +use deserr::Deserr; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; -use serde::Serialize; +use serde::{Deserialize, Serialize}; use super::{ DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, @@ -64,7 +65,7 @@ pub struct Embedder { dimensions: usize, } -#[derive(Debug)] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct EmbedderOptions { pub api_key: 
Option, pub distribution: Option, @@ -79,7 +80,41 @@ pub struct EmbedderOptions { pub input_type: InputType, } -#[derive(Debug)] +impl Default for EmbedderOptions { + fn default() -> Self { + Self { + url: Default::default(), + query: Default::default(), + input_field: vec!["input".into()], + path_to_embeddings: vec!["data".into()], + embedding_object: vec!["embedding".into()], + input_type: InputType::Text, + api_key: None, + distribution: None, + dimensions: None, + } + } +} + +impl std::hash::Hash for EmbedderOptions { + fn hash(&self, state: &mut H) { + self.api_key.hash(state); + self.distribution.hash(state); + self.dimensions.hash(state); + self.url.hash(state); + // skip hashing the query + // collisions in regular usage should be minimal, + // and the list is limited to 256 values anyway + self.input_field.hash(state); + self.path_to_embeddings.hash(state); + self.embedding_object.hash(state); + self.input_type.hash(state); + } +} + +#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)] +#[serde(rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] pub enum InputType { Text, TextArray, diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 540693d44..c5b0d0326 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -1,6 +1,7 @@ use deserr::Deserr; use serde::{Deserialize, Serialize}; +use super::rest::InputType; use super::{ollama, openai}; use crate::prompt::PromptData; use crate::update::Setting; @@ -29,6 +30,24 @@ pub struct EmbeddingSettings { #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] pub document_template: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub url: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub query: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub input_field: Setting>, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub path_to_embeddings: Setting>, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub embedding_object: Setting>, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub input_type: Setting, } pub fn check_unset( @@ -75,20 +94,42 @@ impl EmbeddingSettings { pub const DIMENSIONS: &'static str = "dimensions"; pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate"; + pub const URL: &'static str = "url"; + pub const QUERY: &'static str = "query"; + pub const INPUT_FIELD: &'static str = "inputField"; + pub const PATH_TO_EMBEDDINGS: &'static str = "pathToEmbeddings"; + pub const EMBEDDING_OBJECT: &'static str = "embeddingObject"; + pub const INPUT_TYPE: &'static str = "inputType"; + pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] { match field { - Self::SOURCE => { - &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided] - } + Self::SOURCE => &[ + EmbedderSource::HuggingFace, + EmbedderSource::OpenAi, + EmbedderSource::UserProvided, + EmbedderSource::Rest, + EmbedderSource::Ollama, + ], Self::MODEL => { &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] } Self::REVISION => &[EmbedderSource::HuggingFace], - Self::API_KEY => &[EmbedderSource::OpenAi], - Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided], - 
Self::DOCUMENT_TEMPLATE => { - &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] + Self::API_KEY => &[EmbedderSource::OpenAi, EmbedderSource::Rest], + Self::DIMENSIONS => { + &[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Rest] } + Self::DOCUMENT_TEMPLATE => &[ + EmbedderSource::HuggingFace, + EmbedderSource::OpenAi, + EmbedderSource::Ollama, + EmbedderSource::Rest, + ], + Self::URL => &[EmbedderSource::Rest], + Self::QUERY => &[EmbedderSource::Rest], + Self::INPUT_FIELD => &[EmbedderSource::Rest], + Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest], + Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest], + Self::INPUT_TYPE => &[EmbedderSource::Rest], _other => unreachable!("unknown field"), } } @@ -107,6 +148,18 @@ impl EmbeddingSettings { } EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE], EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], + EmbedderSource::Rest => &[ + Self::SOURCE, + Self::API_KEY, + Self::DIMENSIONS, + Self::DOCUMENT_TEMPLATE, + Self::URL, + Self::QUERY, + Self::INPUT_FIELD, + Self::PATH_TO_EMBEDDINGS, + Self::EMBEDDING_OBJECT, + Self::INPUT_TYPE, + ], } } @@ -141,6 +194,7 @@ pub enum EmbedderSource { HuggingFace, Ollama, UserProvided, + Rest, } impl std::fmt::Display for EmbedderSource { @@ -150,6 +204,7 @@ impl std::fmt::Display for EmbedderSource { EmbedderSource::HuggingFace => "huggingFace", EmbedderSource::UserProvided => "userProvided", EmbedderSource::Ollama => "ollama", + EmbedderSource::Rest => "rest", }; f.write_str(s) } @@ -157,8 +212,20 @@ impl std::fmt::Display for EmbedderSource { impl EmbeddingSettings { pub fn apply(&mut self, new: Self) { - let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = - new; + let EmbeddingSettings { + source, + model, + revision, + api_key, + dimensions, + document_template, + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, + } = new; let old_source = self.source; self.source.apply(source); // Reinitialize the whole setting object on a source change @@ -170,6 +237,12 @@ impl EmbeddingSettings { api_key, dimensions, document_template, + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, }; return; } @@ -179,6 +252,13 @@ impl EmbeddingSettings { self.api_key.apply(api_key); self.dimensions.apply(dimensions); self.document_template.apply(document_template); + + self.url.apply(url); + self.query.apply(query); + self.input_field.apply(input_field); + self.path_to_embeddings.apply(path_to_embeddings); + self.embedding_object.apply(embedding_object); + self.input_type.apply(input_type); } } @@ -193,6 +273,12 @@ impl From for EmbeddingSettings { api_key: Setting::NotSet, dimensions: Setting::NotSet, document_template: Setting::Set(prompt.template), + url: Setting::NotSet, + query: Setting::NotSet, + input_field: Setting::NotSet, + path_to_embeddings: Setting::NotSet, + embedding_object: Setting::NotSet, + input_type: Setting::NotSet, }, super::EmbedderOptions::OpenAi(options) => Self { source: Setting::Set(EmbedderSource::OpenAi), @@ -201,6 +287,12 @@ impl From for EmbeddingSettings { api_key: options.api_key.map(Setting::Set).unwrap_or_default(), dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(), document_template: Setting::Set(prompt.template), + url: Setting::NotSet, + query: Setting::NotSet, + input_field: Setting::NotSet, + path_to_embeddings: Setting::NotSet, + embedding_object: Setting::NotSet, + 
@@ -193,6 +273,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
                 api_key: Setting::NotSet,
                 dimensions: Setting::NotSet,
                 document_template: Setting::Set(prompt.template),
+                url: Setting::NotSet,
+                query: Setting::NotSet,
+                input_field: Setting::NotSet,
+                path_to_embeddings: Setting::NotSet,
+                embedding_object: Setting::NotSet,
+                input_type: Setting::NotSet,
             },
             super::EmbedderOptions::OpenAi(options) => Self {
                 source: Setting::Set(EmbedderSource::OpenAi),
@@ -201,6 +287,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
                 api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
                 dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
                 document_template: Setting::Set(prompt.template),
+                url: Setting::NotSet,
+                query: Setting::NotSet,
+                input_field: Setting::NotSet,
+                path_to_embeddings: Setting::NotSet,
+                embedding_object: Setting::NotSet,
+                input_type: Setting::NotSet,
             },
             super::EmbedderOptions::Ollama(options) => Self {
                 source: Setting::Set(EmbedderSource::Ollama),
@@ -209,6 +301,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
                 api_key: Setting::NotSet,
                 dimensions: Setting::NotSet,
                 document_template: Setting::Set(prompt.template),
+                url: Setting::NotSet,
+                query: Setting::NotSet,
+                input_field: Setting::NotSet,
+                path_to_embeddings: Setting::NotSet,
+                embedding_object: Setting::NotSet,
+                input_type: Setting::NotSet,
             },
             super::EmbedderOptions::UserProvided(options) => Self {
                 source: Setting::Set(EmbedderSource::UserProvided),
@@ -217,6 +315,37 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
                 api_key: Setting::NotSet,
                 dimensions: Setting::Set(options.dimensions),
                 document_template: Setting::NotSet,
+                url: Setting::NotSet,
+                query: Setting::NotSet,
+                input_field: Setting::NotSet,
+                path_to_embeddings: Setting::NotSet,
+                embedding_object: Setting::NotSet,
+                input_type: Setting::NotSet,
+            },
+            super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
+                api_key,
+                // TODO: support distribution
+                distribution: _,
+                dimensions,
+                url,
+                query,
+                input_field,
+                path_to_embeddings,
+                embedding_object,
+                input_type,
+            }) => Self {
+                source: Setting::Set(EmbedderSource::Rest),
+                model: Setting::NotSet,
+                revision: Setting::NotSet,
+                api_key: api_key.map(Setting::Set).unwrap_or_default(),
+                dimensions: dimensions.map(Setting::Set).unwrap_or_default(),
+                document_template: Setting::Set(prompt.template),
+                url: Setting::Set(url),
+                query: Setting::Set(query),
+                input_field: Setting::Set(input_field),
+                path_to_embeddings: Setting::Set(path_to_embeddings),
+                embedding_object: Setting::Set(embedding_object),
+                input_type: Setting::Set(input_type),
+            },
         }
     }
@@ -225,8 +354,20 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
     fn from(value: EmbeddingSettings) -> Self {
         let mut this = Self::default();
-        let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
-            value;
+        let EmbeddingSettings {
+            source,
+            model,
+            revision,
+            api_key,
+            dimensions,
+            document_template,
+            url,
+            query,
+            input_field,
+            path_to_embeddings,
+            embedding_object,
+            input_type,
+        } = value;
         if let Some(source) = source.set() {
             match source {
                 EmbedderSource::OpenAi => {
@@ -274,6 +415,26 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
                         dimensions: dimensions.set().unwrap(),
                     });
                 }
+                EmbedderSource::Rest => {
+                    let embedder_options = super::rest::EmbedderOptions::default();
+
+                    this.embedder_options =
+                        super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
+                            api_key: api_key.set(),
+                            distribution: None,
+                            dimensions: dimensions.set(),
+                            url: url.set().unwrap(),
+                            query: query.set().unwrap_or(embedder_options.query),
+                            input_field: input_field.set().unwrap_or(embedder_options.input_field),
+                            path_to_embeddings: path_to_embeddings
+                                .set()
+                                .unwrap_or(embedder_options.path_to_embeddings),
+                            embedding_object: embedding_object
+                                .set()
+                                .unwrap_or(embedder_options.embedding_object),
+                            input_type: input_type.set().unwrap_or(embedder_options.input_type),
+                        })
+                }
             }
         }
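(Aside: the net effect of the `Rest` arm above is that only `url` is mandatory — `url.set().unwrap()` relies on the field having been checked beforehand — while every other option falls back to the `Default` impl from `rest.rs`. An illustrative sketch, with an invented helper name and URL:)

```rust
// Illustrative only: the values the conversion above produces when just
// `source: "rest"` and `url` are set in the settings.
fn resolved_rest_options(url: String) -> super::rest::EmbedderOptions {
    let defaults = super::rest::EmbedderOptions::default();
    super::rest::EmbedderOptions {
        api_key: None,      // `apiKey` not set
        distribution: None, // not exposed through settings yet (see TODO above)
        dimensions: None,   // `dimensions` not set
        url,                // the one field without a fallback
        query: defaults.query,                           // serde_json::Value::Null
        input_field: defaults.input_field,               // ["input"]
        path_to_embeddings: defaults.path_to_embeddings, // ["data"]
        embedding_object: defaults.embedding_object,     // ["embedding"]
        input_type: defaults.input_type,                 // InputType::Text
    }
}
```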
From dfa5e41ea6fa37f1d5df6710da543b00050417d9 Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Mon, 25 Mar 2024 10:05:58 +0100
Subject: [PATCH 9/9] Check validity of the URL setting

---
 Cargo.lock                                 | 1 +
 meilisearch-types/src/error.rs             | 1 +
 meilisearch/src/routes/indexes/settings.rs | 1 +
 milli/Cargo.toml                           | 1 +
 milli/src/error.rs                         | 2 ++
 milli/src/update/settings.rs               | 8 ++++++++
 6 files changed, 14 insertions(+)

diff --git a/Cargo.lock b/Cargo.lock
index 6a8c20f12..214ba368f 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3376,6 +3376,7 @@ dependencies = [
  "tokenizers",
  "tracing",
  "ureq",
+ "url",
  "uuid",
 ]

diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs
index aed77411a..1b94201f2 100644
--- a/meilisearch-types/src/error.rs
+++ b/meilisearch-types/src/error.rs
@@ -353,6 +353,7 @@ impl ErrorCode for milli::Error {
                 | UserError::InvalidOpenAiModelDimensions { .. }
                 | UserError::InvalidOpenAiModelDimensionsMax { .. }
                 | UserError::InvalidSettingsDimensions { .. }
+                | UserError::InvalidUrl { .. }
                 | UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
                 UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
                 UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
diff --git a/meilisearch/src/routes/indexes/settings.rs b/meilisearch/src/routes/indexes/settings.rs
index 5dabd7b0d..99c3d0fbb 100644
--- a/meilisearch/src/routes/indexes/settings.rs
+++ b/meilisearch/src/routes/indexes/settings.rs
@@ -605,6 +605,7 @@ fn embedder_analytics(
                     EmbedderSource::HuggingFace => sources.insert("huggingFace"),
                     EmbedderSource::UserProvided => sources.insert("userProvided"),
                     EmbedderSource::Ollama => sources.insert("ollama"),
+                    EmbedderSource::Rest => sources.insert("rest"),
                 };
             }
         };
diff --git a/milli/Cargo.toml b/milli/Cargo.toml
index 4833ad00b..9f5803f4e 100644
--- a/milli/Cargo.toml
+++ b/milli/Cargo.toml
@@ -86,6 +86,7 @@ arroy = "0.2.0"
 rand = "0.8.5"
 tracing = "0.1.40"
 ureq = { version = "2.9.6", features = ["json"] }
+url = "2.5.0"

 [dev-dependencies]
 mimalloc = { version = "0.1.39", default-features = false }
diff --git a/milli/src/error.rs b/milli/src/error.rs
index 1147085dd..aba80b475 100644
--- a/milli/src/error.rs
+++ b/milli/src/error.rs
@@ -243,6 +243,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
     },
     #[error("`.embedders.{embedder_name}.dimensions`: `dimensions` cannot be zero")]
     InvalidSettingsDimensions { embedder_name: String },
+    #[error("`.embedders.{embedder_name}.url`: could not parse `{url}`: {inner_error}")]
+    InvalidUrl { embedder_name: String, inner_error: url::ParseError, url: String },
 }

 impl From for Error {
diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs
index 4c7289eb7..e902badc0 100644
--- a/milli/src/update/settings.rs
+++ b/milli/src/update/settings.rs
@@ -1199,6 +1199,14 @@ pub fn validate_embedding_settings(
         .into());
     }

+    if let Some(url) = url.as_ref().set() {
+        url::Url::parse(url).map_err(|error| crate::error::UserError::InvalidUrl {
+            embedder_name: name.to_owned(),
+            inner_error: error,
+            url: url.to_owned(),
+        })?;
+    }
+
     let Some(inferred_source) = source.set() else {
         return Ok(Setting::Set(EmbeddingSettings {
             source,
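(Aside: a quick illustration of what the new check accepts and rejects, assuming the `url` crate 2.x. `Url::parse` only accepts absolute URLs, so a value without a scheme fails with `RelativeUrlWithoutBase`, which the validation above surfaces as the new `InvalidUrl` user error. The URLs are invented examples.)

```rust
use url::Url;

fn main() {
    // Absolute URL with a scheme: accepted.
    assert!(Url::parse("http://localhost:8080/embed").is_ok());
    // No scheme at all: rejected as a relative URL, which
    // `validate_embedding_settings` maps to `UserError::InvalidUrl`.
    assert_eq!(
        Url::parse("embed.example.com/v1"),
        Err(url::ParseError::RelativeUrlWithoutBase)
    );
}
```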