From d3004d8040a9ddadbea6be97e8059c2d68ab50b6 Mon Sep 17 00:00:00 2001 From: Jakob Klemm Date: Sun, 3 Mar 2024 01:11:25 +0100 Subject: [PATCH 1/3] Implemented Ollama as an embeddings provider Initial prototype of Ollama embeddings actually working, error handlign / retries still missing. Allow model to be any String and require dimensions parameter Fixed rustfmt formatting issues There were some formatting issues in the initial PR and this should not make the changes comply with the Rust style guidelines Because I accidentally didn't follow the style guide for commits in my commit messages I squashed them into one to comply --- meilisearch/src/routes/indexes/settings.rs | 1 + milli/src/update/settings.rs | 7 + milli/src/vector/error.rs | 26 +++ milli/src/vector/mod.rs | 18 ++ milli/src/vector/ollama.rs | 255 +++++++++++++++++++++ milli/src/vector/openai.rs | 20 +- milli/src/vector/settings.rs | 38 ++- 7 files changed, 350 insertions(+), 15 deletions(-) create mode 100644 milli/src/vector/ollama.rs diff --git a/meilisearch/src/routes/indexes/settings.rs b/meilisearch/src/routes/indexes/settings.rs index c71d83279..c782e78cb 100644 --- a/meilisearch/src/routes/indexes/settings.rs +++ b/meilisearch/src/routes/indexes/settings.rs @@ -604,6 +604,7 @@ fn embedder_analytics( EmbedderSource::OpenAi => sources.insert("openAi"), EmbedderSource::HuggingFace => sources.insert("huggingFace"), EmbedderSource::UserProvided => sources.insert("userProvided"), + EmbedderSource::Ollama => sources.insert("ollama"), }; } }; diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 3cad79467..df273b023 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1178,6 +1178,13 @@ pub fn validate_embedding_settings( } } } + EmbedderSource::Ollama => { + // Existence & corrent dimensions of models cannot easily be checked here. + check_set(&dimensions, "dimensions", inferred_source, name)?; + check_set(&model, "model", inferred_source, name)?; + check_unset(&api_key, "apiKey", inferred_source, name)?; + check_unset(&revision, "revision", inferred_source, name)?; + } EmbedderSource::HuggingFace => { check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&dimensions, "dimensions", inferred_source, name)?; diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index 3673c85e3..ffdda42ca 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -2,6 +2,7 @@ use std::path::PathBuf; use hf_hub::api::sync::ApiError; +use super::ollama::OllamaError; use crate::error::FaultSource; use crate::vector::openai::OpenAiError; @@ -71,6 +72,15 @@ pub enum EmbedErrorKind { OpenAiRuntimeInit(std::io::Error), #[error("initializing web client for sending embedding requests failed: {0}")] InitWebClient(reqwest::Error), + // Dedicated Ollama error kinds, might have to merge them into one cohesive error type for all backends. + #[error("unexpected response from Ollama: {0}")] + OllamaUnexpected(reqwest::Error), + #[error("sent too many requests to Ollama: {0}")] + OllamaTooManyRequests(OllamaError), + #[error("received internal error from Ollama: {0}")] + OllamaInternalServerError(OllamaError), + #[error("received unhandled HTTP status code {0} from Ollama")] + OllamaUnhandledStatusCode(u16), } impl EmbedError { @@ -129,6 +139,22 @@ impl EmbedError { pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } } + + pub fn ollama_unexpected(inner: reqwest::Error) -> EmbedError { + Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug } + } + + pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError { + Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime } + } + + pub(crate) fn ollama_internal_server_error(inner: OllamaError) -> EmbedError { + Self { kind: EmbedErrorKind::OllamaInternalServerError(inner), fault: FaultSource::Runtime } + } + + pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError { + Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug } + } } #[derive(Debug, thiserror::Error)] diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 99b7bff7e..635c72878 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -10,6 +10,8 @@ pub mod manual; pub mod openai; pub mod settings; +pub mod ollama; + pub use self::error::Error; pub type Embedding = Vec; @@ -76,6 +78,7 @@ pub enum Embedder { HuggingFace(hf::Embedder), OpenAi(openai::Embedder), UserProvided(manual::Embedder), + Ollama(ollama::Embedder), } #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] @@ -127,6 +130,7 @@ impl IntoIterator for EmbeddingConfigs { pub enum EmbedderOptions { HuggingFace(hf::EmbedderOptions), OpenAi(openai::EmbedderOptions), + Ollama(ollama::EmbedderOptions), UserProvided(manual::EmbedderOptions), } @@ -144,6 +148,10 @@ impl EmbedderOptions { pub fn openai(api_key: Option) -> Self { Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) } + + pub fn ollama() -> Self { + Self::Ollama(ollama::EmbedderOptions::with_default_model()) + } } impl Embedder { @@ -151,6 +159,7 @@ impl Embedder { Ok(match options { EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), + EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?), EmbedderOptions::UserProvided(options) => { Self::UserProvided(manual::Embedder::new(options)) } @@ -167,6 +176,10 @@ impl Embedder { let client = embedder.new_client()?; embedder.embed(texts, &client).await } + Embedder::Ollama(embedder) => { + let client = embedder.new_client()?; + embedder.embed(texts, &client).await + } Embedder::UserProvided(embedder) => embedder.embed(texts), } } @@ -181,6 +194,7 @@ impl Embedder { match self { Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks), + Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks), Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), } } @@ -189,6 +203,7 @@ impl Embedder { match self { Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), + Embedder::Ollama(embedder) => embedder.chunk_count_hint(), Embedder::UserProvided(_) => 1, } } @@ -197,6 +212,7 @@ impl Embedder { match self { Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), + Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::UserProvided(_) => 1, } } @@ -205,6 +221,7 @@ impl Embedder { match self { Embedder::HuggingFace(embedder) => embedder.dimensions(), Embedder::OpenAi(embedder) => embedder.dimensions(), + Embedder::Ollama(embedder) => embedder.dimensions(), Embedder::UserProvided(embedder) => embedder.dimensions(), } } @@ -213,6 +230,7 @@ impl Embedder { match self { Embedder::HuggingFace(embedder) => embedder.distribution(), Embedder::OpenAi(embedder) => embedder.distribution(), + Embedder::Ollama(embedder) => embedder.distribution(), Embedder::UserProvided(_embedder) => None, } } diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs new file mode 100644 index 000000000..a83022dbd --- /dev/null +++ b/milli/src/vector/ollama.rs @@ -0,0 +1,255 @@ +// Copied from "openai.rs" with the sections I actually understand changed for Ollama. +// The common components of the Ollama and OpenAI interfaces might need to be extracted. + +use std::fmt::Display; + +use reqwest::StatusCode; + +use super::error::{EmbedError, NewEmbedderError}; +use super::openai::Retry; +use super::{DistributionShift, Embedding, Embeddings}; + +#[derive(Debug)] +pub struct Embedder { + headers: reqwest::header::HeaderMap, + options: EmbedderOptions, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct EmbedderOptions { + pub embedding_model: EmbeddingModel, + pub dimensions: usize, +} + +#[derive( + Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr, +)] +#[deserr(deny_unknown_fields)] +pub struct EmbeddingModel { + name: String, +} + +#[derive(Debug, serde::Serialize)] +struct OllamaRequest<'a> { + model: &'a str, + prompt: &'a str, +} + +#[derive(Debug, serde::Deserialize)] +struct OllamaResponse { + embedding: Embedding, +} + +#[derive(Debug, serde::Deserialize)] +struct OllamaErrorResponse { + error: OllamaError, +} + +#[derive(Debug, serde::Deserialize)] +pub struct OllamaError { + message: String, + // type: String, + code: Option, +} + +impl EmbeddingModel { + pub fn max_token(&self) -> usize { + // this might not be the same for all models + 8192 + } + + pub fn default_dimensions(&self) -> usize { + // Dimensions for nomic-embed-text + 768 + } + + pub fn name(&self) -> String { + self.name.clone() + } + + pub fn from_name(name: &str) -> Self { + Self { name: name.to_string() } + } + + pub fn supports_overriding_dimensions(&self) -> bool { + false + } +} + +impl Default for EmbeddingModel { + fn default() -> Self { + Self { name: "nomic-embed-text".to_string() } + } +} + +impl EmbedderOptions { + pub fn with_default_model() -> Self { + Self { embedding_model: Default::default(), dimensions: 768 } + } + + pub fn with_embedding_model(embedding_model: EmbeddingModel, dimensions: usize) -> Self { + Self { embedding_model, dimensions } + } +} + +impl Embedder { + pub fn new_client(&self) -> Result { + reqwest::ClientBuilder::new() + .default_headers(self.headers.clone()) + .build() + .map_err(EmbedError::openai_initialize_web_client) + } + + pub fn new(options: EmbedderOptions) -> Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::CONTENT_TYPE, + reqwest::header::HeaderValue::from_static("application/json"), + ); + + Ok(Self { options, headers }) + } + + async fn check_response(response: reqwest::Response) -> Result { + if !response.status().is_success() { + // Not the same number of possible error cases covered as with OpenAI. + match response.status() { + StatusCode::TOO_MANY_REQUESTS => { + let error_response: OllamaErrorResponse = response + .json() + .await + .map_err(EmbedError::ollama_unexpected) + .map_err(Retry::retry_later)?; + + return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests( + error_response.error, + ))); + } + StatusCode::SERVICE_UNAVAILABLE => { + let error_response: OllamaErrorResponse = response + .json() + .await + .map_err(EmbedError::ollama_unexpected) + .map_err(Retry::retry_later)?; + return Err(Retry::retry_later(EmbedError::ollama_internal_server_error( + error_response.error, + ))); + } + code => { + return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code( + code.as_u16(), + ))); + } + } + } + Ok(response) + } + + pub async fn embed( + &self, + texts: Vec, + client: &reqwest::Client, + ) -> Result>, EmbedError> { + // Ollama only embedds one document at a time. + let mut results = Vec::with_capacity(texts.len()); + + // The retry loop is inside the texts loop, might have to switch that around + for text in texts { + // Retries copied from openai.rs + for attempt in 0..7 { + let retry_duration = match self.try_embed(&text, client).await { + Ok(result) => { + results.push(result); + break; + } + Err(retry) => { + tracing::warn!("Failed: {}", retry.error); + retry.into_duration(attempt) + } + }?; + tracing::warn!( + "Attempt #{}, retrying after {}ms.", + attempt, + retry_duration.as_millis() + ); + tokio::time::sleep(retry_duration).await; + } + } + + Ok(results) + } + + async fn try_embed( + &self, + text: &str, + client: &reqwest::Client, + ) -> Result, Retry> { + let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text }; + let response = client + .post(get_ollama_path()) + .json(&request) + .send() + .await + .map_err(EmbedError::openai_network) + .map_err(Retry::retry_later)?; + + let response = Self::check_response(response).await?; + + let response: OllamaResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + tracing::trace!("response: {:?}", response.embedding); + + let embedding = Embeddings::from_single_embedding(response.embedding); + Ok(embedding) + } + + pub fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> Result>>, EmbedError> { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .map_err(EmbedError::openai_runtime_init)?; + let client = self.new_client()?; + rt.block_on(futures::future::try_join_all( + text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)), + )) + } + + // Defaults copied from openai.rs + pub fn chunk_count_hint(&self) -> usize { + 10 + } + + pub fn prompt_count_in_chunk_hint(&self) -> usize { + 10 + } + + pub fn dimensions(&self) -> usize { + self.options.dimensions + } + + pub fn distribution(&self) -> Option { + None + } +} + +impl Display for OllamaError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.code { + Some(code) => write!(f, "{} ({})", self.message, code), + None => write!(f, "{}", self.message), + } + } +} + +fn get_ollama_path() -> String { + // Important: Hostname not enough, has to be entire path to embeddings endpoint + std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string()) +} diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index cbddddfb7..ae0fdddf8 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -429,12 +429,12 @@ impl Embedder { // retrying in case of failure -struct Retry { - error: EmbedError, +pub struct Retry { + pub error: EmbedError, strategy: RetryStrategy, } -enum RetryStrategy { +pub enum RetryStrategy { GiveUp, Retry, RetryTokenized, @@ -442,23 +442,23 @@ enum RetryStrategy { } impl Retry { - fn give_up(error: EmbedError) -> Self { + pub fn give_up(error: EmbedError) -> Self { Self { error, strategy: RetryStrategy::GiveUp } } - fn retry_later(error: EmbedError) -> Self { + pub fn retry_later(error: EmbedError) -> Self { Self { error, strategy: RetryStrategy::Retry } } - fn retry_tokenized(error: EmbedError) -> Self { + pub fn retry_tokenized(error: EmbedError) -> Self { Self { error, strategy: RetryStrategy::RetryTokenized } } - fn rate_limited(error: EmbedError) -> Self { + pub fn rate_limited(error: EmbedError) -> Self { Self { error, strategy: RetryStrategy::RetryAfterRateLimit } } - fn into_duration(self, attempt: u32) -> Result { + pub fn into_duration(self, attempt: u32) -> Result { match self.strategy { RetryStrategy::GiveUp => Err(self.error), RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))), @@ -469,11 +469,11 @@ impl Retry { } } - fn must_tokenize(&self) -> bool { + pub fn must_tokenize(&self) -> bool { matches!(self.strategy, RetryStrategy::RetryTokenized) } - fn into_error(self) -> EmbedError { + pub fn into_error(self) -> EmbedError { self.error } } diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 834a1c81d..5595f60e3 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -1,7 +1,7 @@ use deserr::Deserr; use serde::{Deserialize, Serialize}; -use super::openai; +use super::{ollama, openai}; use crate::prompt::PromptData; use crate::update::Setting; use crate::vector::EmbeddingConfig; @@ -80,11 +80,17 @@ impl EmbeddingSettings { Self::SOURCE => { &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided] } - Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi], + Self::MODEL => { + &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] + } Self::REVISION => &[EmbedderSource::HuggingFace], Self::API_KEY => &[EmbedderSource::OpenAi], - Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided], - Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi], + Self::DIMENSIONS => { + &[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Ollama] + } + Self::DOCUMENT_TEMPLATE => { + &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] + } _other => unreachable!("unknown field"), } } @@ -101,6 +107,9 @@ impl EmbeddingSettings { EmbedderSource::HuggingFace => { &[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE] } + EmbedderSource::Ollama => { + &[Self::SOURCE, Self::MODEL, Self::DIMENSIONS, Self::DOCUMENT_TEMPLATE] + } EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], } } @@ -134,6 +143,7 @@ pub enum EmbedderSource { #[default] OpenAi, HuggingFace, + Ollama, UserProvided, } @@ -143,6 +153,7 @@ impl std::fmt::Display for EmbedderSource { EmbedderSource::OpenAi => "openAi", EmbedderSource::HuggingFace => "huggingFace", EmbedderSource::UserProvided => "userProvided", + EmbedderSource::Ollama => "ollama", }; f.write_str(s) } @@ -192,7 +203,15 @@ impl From for EmbeddingSettings { model: Setting::Set(options.embedding_model.name().to_owned()), revision: Setting::NotSet, api_key: options.api_key.map(Setting::Set).unwrap_or_default(), - dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(), + dimensions: Setting::Set(options.dimensions.unwrap_or_default()), + document_template: Setting::Set(prompt.template), + }, + super::EmbedderOptions::Ollama(options) => Self { + source: Setting::Set(EmbedderSource::Ollama), + model: Setting::Set(options.embedding_model.name().to_owned()), + revision: Setting::NotSet, + api_key: Setting::NotSet, + dimensions: Setting::Set(options.dimensions), document_template: Setting::Set(prompt.template), }, super::EmbedderOptions::UserProvided(options) => Self { @@ -229,6 +248,15 @@ impl From for EmbeddingConfig { } this.embedder_options = super::EmbedderOptions::OpenAi(options); } + EmbedderSource::Ollama => { + let mut options: ollama::EmbedderOptions = + super::ollama::EmbedderOptions::with_default_model(); + if let (Some(model), Some(dim)) = (model.set(), dimensions.set()) { + options.embedding_model = super::ollama::EmbeddingModel::from_name(&model); + options.dimensions = dim; + } + this.embedder_options = super::EmbedderOptions::Ollama(options); + } EmbedderSource::HuggingFace => { let mut options = super::hf::EmbedderOptions::default(); if let Some(model) = model.set() { From 88bc9556a93009d780f48a37fed88e2187b96ffc Mon Sep 17 00:00:00 2001 From: Jakob Klemm Date: Tue, 12 Mar 2024 19:59:11 +0100 Subject: [PATCH 2/3] Add Ollama dimension inference and add clearer errors Instead of the user manually specifying the model dimensions it will now automatically get determined Just like with hf.rs the word "test" gets embedded to determine the dimensions of the output Add a dedicated error type for if the model doesn't exist (don't automatically pull it though) and set the fault of that error to be the user --- milli/src/update/settings.rs | 4 +- milli/src/vector/error.rs | 15 +++++- milli/src/vector/ollama.rs | 100 ++++++++++++++++++++++++++--------- milli/src/vector/settings.rs | 13 ++--- 4 files changed, 96 insertions(+), 36 deletions(-) diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index df273b023..ee2f58a01 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1179,8 +1179,8 @@ pub fn validate_embedding_settings( } } EmbedderSource::Ollama => { - // Existence & corrent dimensions of models cannot easily be checked here. - check_set(&dimensions, "dimensions", inferred_source, name)?; + // Dimensions get inferred, only model name is required + check_unset(&dimensions, "dimensions", inferred_source, name)?; check_set(&model, "model", inferred_source, name)?; check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&revision, "revision", inferred_source, name)?; diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index ffdda42ca..3f4d5eb51 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -79,6 +79,8 @@ pub enum EmbedErrorKind { OllamaTooManyRequests(OllamaError), #[error("received internal error from Ollama: {0}")] OllamaInternalServerError(OllamaError), + #[error("model not found. MeiliSearch will not automatically download models from the Ollama library, please pull the model manually: {0}")] + OllamaModelNotFoundError(OllamaError), #[error("received unhandled HTTP status code {0} from Ollama")] OllamaUnhandledStatusCode(u16), } @@ -140,10 +142,14 @@ impl EmbedError { Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } } - pub fn ollama_unexpected(inner: reqwest::Error) -> EmbedError { + pub(crate) fn ollama_unexpected(inner: reqwest::Error) -> EmbedError { Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug } } + pub(crate) fn ollama_model_not_found(inner: OllamaError) -> EmbedError { + Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User } + } + pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError { Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime } } @@ -221,6 +227,13 @@ impl NewEmbedderError { } } + pub fn ollama_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError { + Self { + kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner), + fault: FaultSource::User, + } + } + pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } } diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs index a83022dbd..76988f70b 100644 --- a/milli/src/vector/ollama.rs +++ b/milli/src/vector/ollama.rs @@ -18,7 +18,6 @@ pub struct Embedder { #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { pub embedding_model: EmbeddingModel, - pub dimensions: usize, } #[derive( @@ -27,6 +26,7 @@ pub struct EmbedderOptions { #[deserr(deny_unknown_fields)] pub struct EmbeddingModel { name: String, + dimensions: usize, } #[derive(Debug, serde::Serialize)] @@ -40,16 +40,9 @@ struct OllamaResponse { embedding: Embedding, } -#[derive(Debug, serde::Deserialize)] -struct OllamaErrorResponse { - error: OllamaError, -} - #[derive(Debug, serde::Deserialize)] pub struct OllamaError { - message: String, - // type: String, - code: Option, + error: String, } impl EmbeddingModel { @@ -68,7 +61,7 @@ impl EmbeddingModel { } pub fn from_name(name: &str) -> Self { - Self { name: name.to_string() } + Self { name: name.to_string(), dimensions: 0 } } pub fn supports_overriding_dimensions(&self) -> bool { @@ -78,17 +71,17 @@ impl EmbeddingModel { impl Default for EmbeddingModel { fn default() -> Self { - Self { name: "nomic-embed-text".to_string() } + Self { name: "nomic-embed-text".to_string(), dimensions: 0 } } } impl EmbedderOptions { pub fn with_default_model() -> Self { - Self { embedding_model: Default::default(), dimensions: 768 } + Self { embedding_model: Default::default() } } - pub fn with_embedding_model(embedding_model: EmbeddingModel, dimensions: usize) -> Self { - Self { embedding_model, dimensions } + pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self { + Self { embedding_model } } } @@ -107,7 +100,58 @@ impl Embedder { reqwest::header::HeaderValue::from_static("application/json"), ); - Ok(Self { options, headers }) + let mut embedder = Self { options, headers }; + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .map_err(EmbedError::openai_runtime_init) + .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; + + // Get dimensions from Ollama + let request = + OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" }; + // TODO: Refactor into shared error type + let client = embedder + .new_client() + .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; + + rt.block_on(async move { + let response = client + .post(get_ollama_path()) + .json(&request) + .send() + .await + .map_err(EmbedError::ollama_unexpected) + .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; + + // Process error in case model not found + let response = Self::check_response(response).await.map_err(|_err| { + let e = EmbedError::ollama_model_not_found(OllamaError { + error: format!("model: {}", embedder.options.embedding_model.name()), + }); + NewEmbedderError::ollama_could_not_determine_dimension(e) + })?; + + let response: OllamaResponse = response + .json() + .await + .map_err(EmbedError::ollama_unexpected) + .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; + + let embedding = Embeddings::from_single_embedding(response.embedding); + + embedder.options.embedding_model.dimensions = embedding.dimension(); + + tracing::info!( + "ollama model {} with dimensionality {} added", + embedder.options.embedding_model.name(), + embedding.dimension() + ); + + Ok(embedder) + }) } async fn check_response(response: reqwest::Response) -> Result { @@ -115,26 +159,37 @@ impl Embedder { // Not the same number of possible error cases covered as with OpenAI. match response.status() { StatusCode::TOO_MANY_REQUESTS => { - let error_response: OllamaErrorResponse = response + let error_response: OllamaError = response .json() .await .map_err(EmbedError::ollama_unexpected) .map_err(Retry::retry_later)?; return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests( - error_response.error, + OllamaError { error: error_response.error }, ))); } StatusCode::SERVICE_UNAVAILABLE => { - let error_response: OllamaErrorResponse = response + let error_response: OllamaError = response .json() .await .map_err(EmbedError::ollama_unexpected) .map_err(Retry::retry_later)?; return Err(Retry::retry_later(EmbedError::ollama_internal_server_error( - error_response.error, + OllamaError { error: error_response.error }, ))); } + StatusCode::NOT_FOUND => { + let error_response: OllamaError = response + .json() + .await + .map_err(EmbedError::ollama_unexpected) + .map_err(Retry::give_up)?; + + return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError { + error: error_response.error, + }))); + } code => { return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code( code.as_u16(), @@ -232,7 +287,7 @@ impl Embedder { } pub fn dimensions(&self) -> usize { - self.options.dimensions + self.options.embedding_model.dimensions } pub fn distribution(&self) -> Option { @@ -242,10 +297,7 @@ impl Embedder { impl Display for OllamaError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.code { - Some(code) => write!(f, "{} ({})", self.message, code), - None => write!(f, "{}", self.message), - } + write!(f, "{}", self.error) } } diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 5595f60e3..84d58a996 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -85,9 +85,7 @@ impl EmbeddingSettings { } Self::REVISION => &[EmbedderSource::HuggingFace], Self::API_KEY => &[EmbedderSource::OpenAi], - Self::DIMENSIONS => { - &[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Ollama] - } + Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided], Self::DOCUMENT_TEMPLATE => { &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] } @@ -107,9 +105,7 @@ impl EmbeddingSettings { EmbedderSource::HuggingFace => { &[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE] } - EmbedderSource::Ollama => { - &[Self::SOURCE, Self::MODEL, Self::DIMENSIONS, Self::DOCUMENT_TEMPLATE] - } + EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE], EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], } } @@ -211,7 +207,7 @@ impl From for EmbeddingSettings { model: Setting::Set(options.embedding_model.name().to_owned()), revision: Setting::NotSet, api_key: Setting::NotSet, - dimensions: Setting::Set(options.dimensions), + dimensions: Setting::NotSet, document_template: Setting::Set(prompt.template), }, super::EmbedderOptions::UserProvided(options) => Self { @@ -251,9 +247,8 @@ impl From for EmbeddingConfig { EmbedderSource::Ollama => { let mut options: ollama::EmbedderOptions = super::ollama::EmbedderOptions::with_default_model(); - if let (Some(model), Some(dim)) = (model.set(), dimensions.set()) { + if let Some(model) = model.set() { options.embedding_model = super::ollama::EmbeddingModel::from_name(&model); - options.dimensions = dim; } this.embedder_options = super::EmbedderOptions::Ollama(options); } From ae67d5eef097b1e722b4dc101bc8b696149d3096 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 13 Mar 2024 09:45:04 +0100 Subject: [PATCH 3/3] Update milli/src/vector/error.rs Fix Meilisearch capitalization --- milli/src/vector/error.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index 3f4d5eb51..20bf49a6b 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -79,7 +79,7 @@ pub enum EmbedErrorKind { OllamaTooManyRequests(OllamaError), #[error("received internal error from Ollama: {0}")] OllamaInternalServerError(OllamaError), - #[error("model not found. MeiliSearch will not automatically download models from the Ollama library, please pull the model manually: {0}")] + #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0}")] OllamaModelNotFoundError(OllamaError), #[error("received unhandled HTTP status code {0} from Ollama")] OllamaUnhandledStatusCode(u16),