From 88bc9556a93009d780f48a37fed88e2187b96ffc Mon Sep 17 00:00:00 2001 From: Jakob Klemm Date: Tue, 12 Mar 2024 19:59:11 +0100 Subject: [PATCH] Add Ollama dimension inference and add clearer errors Instead of the user manually specifying the model dimensions it will now automatically get determined Just like with hf.rs the word "test" gets embedded to determine the dimensions of the output Add a dedicated error type for if the model doesn't exist (don't automatically pull it though) and set the fault of that error to be the user --- milli/src/update/settings.rs | 4 +- milli/src/vector/error.rs | 15 +++++- milli/src/vector/ollama.rs | 100 ++++++++++++++++++++++++++--------- milli/src/vector/settings.rs | 13 ++--- 4 files changed, 96 insertions(+), 36 deletions(-) diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index df273b023..ee2f58a01 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1179,8 +1179,8 @@ pub fn validate_embedding_settings( } } EmbedderSource::Ollama => { - // Existence & corrent dimensions of models cannot easily be checked here. - check_set(&dimensions, "dimensions", inferred_source, name)?; + // Dimensions get inferred, only model name is required + check_unset(&dimensions, "dimensions", inferred_source, name)?; check_set(&model, "model", inferred_source, name)?; check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&revision, "revision", inferred_source, name)?; diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index ffdda42ca..3f4d5eb51 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -79,6 +79,8 @@ pub enum EmbedErrorKind { OllamaTooManyRequests(OllamaError), #[error("received internal error from Ollama: {0}")] OllamaInternalServerError(OllamaError), + #[error("model not found. MeiliSearch will not automatically download models from the Ollama library, please pull the model manually: {0}")] + OllamaModelNotFoundError(OllamaError), #[error("received unhandled HTTP status code {0} from Ollama")] OllamaUnhandledStatusCode(u16), } @@ -140,10 +142,14 @@ impl EmbedError { Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } } - pub fn ollama_unexpected(inner: reqwest::Error) -> EmbedError { + pub(crate) fn ollama_unexpected(inner: reqwest::Error) -> EmbedError { Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug } } + pub(crate) fn ollama_model_not_found(inner: OllamaError) -> EmbedError { + Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User } + } + pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError { Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime } } @@ -221,6 +227,13 @@ impl NewEmbedderError { } } + pub fn ollama_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError { + Self { + kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner), + fault: FaultSource::User, + } + } + pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } } diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs index a83022dbd..76988f70b 100644 --- a/milli/src/vector/ollama.rs +++ b/milli/src/vector/ollama.rs @@ -18,7 +18,6 @@ pub struct Embedder { #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { pub embedding_model: EmbeddingModel, - pub dimensions: usize, } #[derive( @@ -27,6 +26,7 @@ pub struct EmbedderOptions { #[deserr(deny_unknown_fields)] pub struct EmbeddingModel { name: String, + dimensions: usize, } #[derive(Debug, serde::Serialize)] @@ -40,16 +40,9 @@ struct OllamaResponse { embedding: Embedding, } -#[derive(Debug, serde::Deserialize)] -struct OllamaErrorResponse { - error: OllamaError, -} - #[derive(Debug, serde::Deserialize)] pub struct OllamaError { - message: String, - // type: String, - code: Option, + error: String, } impl EmbeddingModel { @@ -68,7 +61,7 @@ impl EmbeddingModel { } pub fn from_name(name: &str) -> Self { - Self { name: name.to_string() } + Self { name: name.to_string(), dimensions: 0 } } pub fn supports_overriding_dimensions(&self) -> bool { @@ -78,17 +71,17 @@ impl EmbeddingModel { impl Default for EmbeddingModel { fn default() -> Self { - Self { name: "nomic-embed-text".to_string() } + Self { name: "nomic-embed-text".to_string(), dimensions: 0 } } } impl EmbedderOptions { pub fn with_default_model() -> Self { - Self { embedding_model: Default::default(), dimensions: 768 } + Self { embedding_model: Default::default() } } - pub fn with_embedding_model(embedding_model: EmbeddingModel, dimensions: usize) -> Self { - Self { embedding_model, dimensions } + pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self { + Self { embedding_model } } } @@ -107,7 +100,58 @@ impl Embedder { reqwest::header::HeaderValue::from_static("application/json"), ); - Ok(Self { options, headers }) + let mut embedder = Self { options, headers }; + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .map_err(EmbedError::openai_runtime_init) + .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; + + // Get dimensions from Ollama + let request = + OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" }; + // TODO: Refactor into shared error type + let client = embedder + .new_client() + .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; + + rt.block_on(async move { + let response = client + .post(get_ollama_path()) + .json(&request) + .send() + .await + .map_err(EmbedError::ollama_unexpected) + .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; + + // Process error in case model not found + let response = Self::check_response(response).await.map_err(|_err| { + let e = EmbedError::ollama_model_not_found(OllamaError { + error: format!("model: {}", embedder.options.embedding_model.name()), + }); + NewEmbedderError::ollama_could_not_determine_dimension(e) + })?; + + let response: OllamaResponse = response + .json() + .await + .map_err(EmbedError::ollama_unexpected) + .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; + + let embedding = Embeddings::from_single_embedding(response.embedding); + + embedder.options.embedding_model.dimensions = embedding.dimension(); + + tracing::info!( + "ollama model {} with dimensionality {} added", + embedder.options.embedding_model.name(), + embedding.dimension() + ); + + Ok(embedder) + }) } async fn check_response(response: reqwest::Response) -> Result { @@ -115,26 +159,37 @@ impl Embedder { // Not the same number of possible error cases covered as with OpenAI. match response.status() { StatusCode::TOO_MANY_REQUESTS => { - let error_response: OllamaErrorResponse = response + let error_response: OllamaError = response .json() .await .map_err(EmbedError::ollama_unexpected) .map_err(Retry::retry_later)?; return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests( - error_response.error, + OllamaError { error: error_response.error }, ))); } StatusCode::SERVICE_UNAVAILABLE => { - let error_response: OllamaErrorResponse = response + let error_response: OllamaError = response .json() .await .map_err(EmbedError::ollama_unexpected) .map_err(Retry::retry_later)?; return Err(Retry::retry_later(EmbedError::ollama_internal_server_error( - error_response.error, + OllamaError { error: error_response.error }, ))); } + StatusCode::NOT_FOUND => { + let error_response: OllamaError = response + .json() + .await + .map_err(EmbedError::ollama_unexpected) + .map_err(Retry::give_up)?; + + return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError { + error: error_response.error, + }))); + } code => { return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code( code.as_u16(), @@ -232,7 +287,7 @@ impl Embedder { } pub fn dimensions(&self) -> usize { - self.options.dimensions + self.options.embedding_model.dimensions } pub fn distribution(&self) -> Option { @@ -242,10 +297,7 @@ impl Embedder { impl Display for OllamaError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.code { - Some(code) => write!(f, "{} ({})", self.message, code), - None => write!(f, "{}", self.message), - } + write!(f, "{}", self.error) } } diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 5595f60e3..84d58a996 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -85,9 +85,7 @@ impl EmbeddingSettings { } Self::REVISION => &[EmbedderSource::HuggingFace], Self::API_KEY => &[EmbedderSource::OpenAi], - Self::DIMENSIONS => { - &[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Ollama] - } + Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided], Self::DOCUMENT_TEMPLATE => { &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama] } @@ -107,9 +105,7 @@ impl EmbeddingSettings { EmbedderSource::HuggingFace => { &[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE] } - EmbedderSource::Ollama => { - &[Self::SOURCE, Self::MODEL, Self::DIMENSIONS, Self::DOCUMENT_TEMPLATE] - } + EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE], EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS], } } @@ -211,7 +207,7 @@ impl From for EmbeddingSettings { model: Setting::Set(options.embedding_model.name().to_owned()), revision: Setting::NotSet, api_key: Setting::NotSet, - dimensions: Setting::Set(options.dimensions), + dimensions: Setting::NotSet, document_template: Setting::Set(prompt.template), }, super::EmbedderOptions::UserProvided(options) => Self { @@ -251,9 +247,8 @@ impl From for EmbeddingConfig { EmbedderSource::Ollama => { let mut options: ollama::EmbedderOptions = super::ollama::EmbedderOptions::with_default_model(); - if let (Some(model), Some(dim)) = (model.set(), dimensions.set()) { + if let Some(model) = model.set() { options.embedding_model = super::ollama::EmbeddingModel::from_name(&model); - options.dimensions = dim; } this.embedder_options = super::EmbedderOptions::Ollama(options); }