From ac52c857e8f5ecf85a42e32abe7a14450fdfdd66 Mon Sep 17 00:00:00 2001
From: Louis Dureuil <louis@meilisearch.com>
Date: Tue, 19 Mar 2024 15:41:37 +0100
Subject: [PATCH] Update ollama and openai impls to use the rest embedder
 internally

---
 .../extract/extract_vector_points.rs          |  10 +-
 .../src/update/index_documents/extract/mod.rs |  15 +-
 milli/src/vector/error.rs                     | 116 +----
 milli/src/vector/mod.rs                       |  22 +-
 milli/src/vector/ollama.rs                    | 307 ++----------
 milli/src/vector/openai.rs                    | 452 +++++-------------
 milli/src/vector/rest.rs                      | 247 ++++++++--
 milli/src/vector/settings.rs                  |   4 +-
 8 files changed, 394 insertions(+), 779 deletions(-)

diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs
index ece841659..40b32bf9c 100644
--- a/milli/src/update/index_documents/extract/extract_vector_points.rs
+++ b/milli/src/update/index_documents/extract/extract_vector_points.rs
@@ -339,6 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
     prompt_reader: grenad::Reader<R>,
     indexer: GrenadParameters,
     embedder: Arc<Embedder>,
+    request_threads: &rayon::ThreadPool,
 ) -> Result<grenad::Reader<BufReader<File>>> {
     puffin::profile_function!();
     let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
@@ -376,7 +377,10 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
 
         if chunks.len() == chunks.capacity() {
             let chunked_embeds = embedder
-                .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
+                .embed_chunks(
+                    std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)),
+                    request_threads,
+                )
                 .map_err(crate::vector::Error::from)
                 .map_err(crate::Error::from)?;
 
@@ -394,7 +398,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
     // send last chunk
     if !chunks.is_empty() {
         let chunked_embeds = embedder
-            .embed_chunks(std::mem::take(&mut chunks))
+            .embed_chunks(std::mem::take(&mut chunks), request_threads)
             .map_err(crate::vector::Error::from)
             .map_err(crate::Error::from)?;
         for (docid, embeddings) in chunks_ids
@@ -408,7 +412,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
 
     if !current_chunk.is_empty() {
         let embeds = embedder
-            .embed_chunks(vec![std::mem::take(&mut current_chunk)])
+            .embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads)
             .map_err(crate::vector::Error::from)
             .map_err(crate::Error::from)?;
 
diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs
index 43f3f4947..5689bb04f 100644
--- a/milli/src/update/index_documents/extract/mod.rs
+++ b/milli/src/update/index_documents/extract/mod.rs
@@ -238,7 +238,15 @@ fn send_original_documents_data(
     let documents_chunk_cloned = original_documents_chunk.clone();
     let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
+
+    let request_threads = rayon::ThreadPoolBuilder::new()
+        .num_threads(crate::vector::REQUEST_PARALLELISM)
+        .thread_name(|index| format!("embedding-request-{index}"))
+        .build()
+        .unwrap();
+
     rayon::spawn(move || {
+        // FIXME: unwrap
         for (name, (embedder, prompt)) in embedders {
             let result = extract_vector_points(
                 documents_chunk_cloned.clone(),
                 indexer,
                 lmdb_writer_sx_cloned.clone(),
                 &field_id_map,
                 &embedder,
             );
             match result {
                 Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
-                    let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
+                    let embeddings = match extract_embeddings(
+                        prompts,
+                        indexer,
+                        embedder.clone(),
+                        &request_threads,
+                    ) {
                         Ok(results) => Some(results),
                         Err(error) => {
                             let _ = lmdb_writer_sx_cloned.send(Err(error));
diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs
index b2eb37e81..92f077924 100644
--- a/milli/src/vector/error.rs
+++ b/milli/src/vector/error.rs
@@ -2,9 +2,7 @@ use std::path::PathBuf;
 
 use hf_hub::api::sync::ApiError;
 
-use super::ollama::OllamaError;
 use crate::error::FaultSource;
-use crate::vector::openai::OpenAiError;
 
 #[derive(Debug, thiserror::Error)]
 #[error("Error while generating embeddings: {inner}")]
@@ -52,43 +50,12 @@ pub enum EmbedErrorKind {
     TensorValue(candle_core::Error),
     #[error("could not run model: {0}")]
     ModelForward(candle_core::Error),
-    #[error("could not reach OpenAI: {0}")]
-    OpenAiNetwork(ureq::Transport),
-    #[error("unexpected response from OpenAI: {0}")]
-    OpenAiUnexpected(ureq::Error),
-    #[error("could not authenticate against OpenAI: {0:?}")]
-    OpenAiAuth(Option<OpenAiError>),
-    #[error("sent too many requests to OpenAI: {0:?}")]
-    OpenAiTooManyRequests(Option<OpenAiError>),
-    #[error("received internal error from OpenAI: {0:?}")]
-    OpenAiInternalServerError(Option<OpenAiError>),
-    #[error("sent too many tokens in a request to OpenAI: {0:?}")]
-    OpenAiTooManyTokens(Option<OpenAiError>),
-    #[error("received unhandled HTTP status code {0} from OpenAI")]
-    OpenAiUnhandledStatusCode(u16),
     #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
     ManualEmbed(String),
     #[error("could not initialize asynchronous runtime: {0}")]
     OpenAiRuntimeInit(std::io::Error),
-    #[error("initializing web client for sending embedding requests failed: {0}")]
-    InitWebClient(reqwest::Error),
-    // Dedicated Ollama error kinds, might have to merge them into one cohesive error type for all backends.
-    #[error("unexpected response from Ollama: {0}")]
-    OllamaUnexpected(reqwest::Error),
-    #[error("sent too many requests to Ollama: {0}")]
-    OllamaTooManyRequests(OllamaError),
-    #[error("received internal error from Ollama: {0}")]
-    OllamaInternalServerError(OllamaError),
-    #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0}")]
-    OllamaModelNotFoundError(OllamaError),
-    #[error("received unhandled HTTP status code {0} from Ollama")]
-    OllamaUnhandledStatusCode(u16),
-    #[error("error serializing template context: {0}")]
-    RestTemplateContextSerialization(liquid::Error),
-    #[error(
-        "error rendering request template: {0}. Hint: available variable in the context: {{{{input}}}}'"
-    )]
-    RestTemplateError(liquid::Error),
+    #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")]
+    OllamaModelNotFoundError(Option<String>),
     #[error("error deserialization the response body as JSON: {0}")]
     RestResponseDeserialization(std::io::Error),
     #[error("component `{0}` not found in path `{1}` in response: `{2}`")]
@@ -128,77 +95,14 @@ impl EmbedError {
         Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
     }
 
-    pub fn openai_network(inner: ureq::Transport) -> Self {
-        Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
-    }
-
-    pub fn openai_unexpected(inner: ureq::Error) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
-    }
-
-    pub(crate) fn openai_auth_error(inner: Option<OpenAiError>) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
-    }
-
-    pub(crate) fn openai_too_many_requests(inner: Option<OpenAiError>) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
-    }
-
-    pub(crate) fn openai_internal_server_error(inner: Option<OpenAiError>) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
-    }
-
-    pub(crate) fn openai_too_many_tokens(inner: Option<OpenAiError>) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
-    }
-
-    pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
-    }
-
     pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
         Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
     }
 
-    pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
-    }
-
-    pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
-        Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
-    }
-
-    pub(crate) fn ollama_unexpected(inner: reqwest::Error) -> EmbedError {
-        Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug }
-    }
-
-    pub(crate) fn ollama_model_not_found(inner: OllamaError) -> EmbedError {
+    pub(crate) fn ollama_model_not_found(inner: Option<String>) -> EmbedError {
         Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
     }
 
-    pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError {
-        Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime }
-    }
-
-    pub(crate) fn ollama_internal_server_error(inner: OllamaError) -> EmbedError {
-        Self { kind: EmbedErrorKind::OllamaInternalServerError(inner), fault: FaultSource::Runtime }
-    }
-
-    pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError {
-        Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug }
-    }
-
-    pub(crate) fn rest_template_context_serialization(error: liquid::Error) -> EmbedError {
-        Self {
-            kind: EmbedErrorKind::RestTemplateContextSerialization(error),
-            fault: FaultSource::Bug,
-        }
-    }
-
-    pub(crate) fn rest_template_render(error: liquid::Error) -> EmbedError {
-        Self { kind: EmbedErrorKind::RestTemplateError(error), fault: FaultSource::User }
-    }
-
     pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
         Self {
            kind: EmbedErrorKind::RestResponseDeserialization(error),
@@ -335,17 +239,6 @@ impl NewEmbedderError {
             fault: FaultSource::Runtime,
         }
     }
-
-    pub fn ollama_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
-        Self {
-            kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
-            fault: FaultSource::User,
-        }
-    }
-
-    pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
-        Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
-    }
 }
 
 #[derive(Debug, thiserror::Error)]
@@ -392,7 +285,4 @@ pub enum NewEmbedderErrorKind {
     CouldNotDetermineDimension(EmbedError),
     #[error("loading model failed: {0}")]
     LoadModel(candle_core::Error),
-    // openai
-    #[error("The API key passed to Authorization error was in an invalid format: {0}")]
-    InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
 }
diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index 7eef3d442..39232e387 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -17,6 +17,8 @@ pub use self::error::Error;
 
 pub type Embedding = Vec<f32>;
 
+pub const REQUEST_PARALLELISM: usize = 40;
+
 /// One or multiple embeddings stored consecutively in a flat vector.
 pub struct Embeddings<F> {
     data: Vec<F>,
@@ -99,7 +101,7 @@ pub enum Embedder {
     /// An embedder based on running local models, fetched from the Hugging Face Hub.
     HuggingFace(hf::Embedder),
     /// An embedder based on making embedding queries against the OpenAI API.
-    OpenAi(openai::sync::Embedder),
+    OpenAi(openai::Embedder),
     /// An embedder based on the user providing the embeddings in the documents and queries.
     UserProvided(manual::Embedder),
     Ollama(ollama::Embedder),
@@ -202,7 +204,7 @@ impl Embedder {
     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
         Ok(match options {
             EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
-            EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::sync::Embedder::new(options)?),
+            EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
             EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
             EmbedderOptions::UserProvided(options) => {
                 Self::UserProvided(manual::Embedder::new(options))
@@ -213,17 +215,14 @@ impl Embedder {
     /// Embed one or multiple texts.
     ///
     /// Each text can be embedded as one or multiple embeddings.
-    pub async fn embed(
+    pub fn embed(
         &self,
         texts: Vec<String>,
     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed(texts),
             Embedder::OpenAi(embedder) => embedder.embed(texts),
-            Embedder::Ollama(embedder) => {
-                let client = embedder.new_client()?;
-                embedder.embed(texts, &client).await
-            }
+            Embedder::Ollama(embedder) => embedder.embed(texts),
             Embedder::UserProvided(embedder) => embedder.embed(texts),
         }
     }
@@ -231,18 +230,15 @@ impl Embedder {
     /// Embed multiple chunks of texts.
     ///
     /// Each chunk is composed of one or multiple texts.
-    ///
-    /// # Panics
-    ///
-    /// - if called from an asynchronous context
     pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
+        threads: &rayon::ThreadPool,
     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
-            Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
-            Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks),
+            Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
+            Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
             Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
         }
     }
diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs
index 76988f70b..9c44e8052 100644
--- a/milli/src/vector/ollama.rs
+++ b/milli/src/vector/ollama.rs
@@ -1,293 +1,94 @@
-// Copied from "openai.rs" with the sections I actually understand changed for Ollama.
-// The common components of the Ollama and OpenAI interfaces might need to be extracted.
+use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
 
-use std::fmt::Display;
-
-use reqwest::StatusCode;
-
-use super::error::{EmbedError, NewEmbedderError};
-use super::openai::Retry;
-use super::{DistributionShift, Embedding, Embeddings};
+use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
+use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
+use super::{DistributionShift, Embeddings};
 
 #[derive(Debug)]
 pub struct Embedder {
-    headers: reqwest::header::HeaderMap,
-    options: EmbedderOptions,
+    rest_embedder: RestEmbedder,
 }
 
 #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
 pub struct EmbedderOptions {
-    pub embedding_model: EmbeddingModel,
-}
-
-#[derive(
-    Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr,
-)]
-#[deserr(deny_unknown_fields)]
-pub struct EmbeddingModel {
-    name: String,
-    dimensions: usize,
-}
-
-#[derive(Debug, serde::Serialize)]
-struct OllamaRequest<'a> {
-    model: &'a str,
-    prompt: &'a str,
-}
-
-#[derive(Debug, serde::Deserialize)]
-struct OllamaResponse {
-    embedding: Embedding,
-}
-
-#[derive(Debug, serde::Deserialize)]
-pub struct OllamaError {
-    error: String,
-}
-
-impl EmbeddingModel {
-    pub fn max_token(&self) -> usize {
-        // this might not be the same for all models
-        8192
-    }
-
-    pub fn default_dimensions(&self) -> usize {
-        // Dimensions for nomic-embed-text
-        768
-    }
-
-    pub fn name(&self) -> String {
-        self.name.clone()
-    }
-
-    pub fn from_name(name: &str) -> Self {
-        Self { name: name.to_string(), dimensions: 0 }
-    }
-
-    pub fn supports_overriding_dimensions(&self) -> bool {
-        false
-    }
-}
-
-impl Default for EmbeddingModel {
-    fn default() -> Self {
-        Self { name: "nomic-embed-text".to_string(), dimensions: 0 }
-    }
+    pub embedding_model: String,
 }
 
 impl EmbedderOptions {
     pub fn with_default_model() -> Self {
-        Self { embedding_model: Default::default() }
+        Self { embedding_model: "nomic-embed-text".into() }
     }
 
-    pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self {
+    pub fn with_embedding_model(embedding_model: String) -> Self {
         Self { embedding_model }
     }
 }
 
 impl Embedder {
-    pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
-        reqwest::ClientBuilder::new()
-            .default_headers(self.headers.clone())
-            .build()
-            .map_err(EmbedError::openai_initialize_web_client)
-    }
-
     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
-        let mut headers = reqwest::header::HeaderMap::new();
-        headers.insert(
-            reqwest::header::CONTENT_TYPE,
-            reqwest::header::HeaderValue::from_static("application/json"),
-        );
-
-        let mut embedder = Self { options, headers };
-
-        let rt = tokio::runtime::Builder::new_current_thread()
-            .enable_io()
-            .enable_time()
-            .build()
-            .map_err(EmbedError::openai_runtime_init)
-            .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
-
-        // Get dimensions from Ollama
-        let request =
-            OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" };
-        // TODO: Refactor into shared error type
-        let client = embedder
-            .new_client()
-            .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
-
-        rt.block_on(async move {
-            let response = client
-                .post(get_ollama_path())
-                .json(&request)
-                .send()
-                .await
-                .map_err(EmbedError::ollama_unexpected)
-                .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
-
-            // Process error in case model not found
-            let response = Self::check_response(response).await.map_err(|_err| {
-                let e = EmbedError::ollama_model_not_found(OllamaError {
-                    error: format!("model: {}", embedder.options.embedding_model.name()),
-                });
-                NewEmbedderError::ollama_could_not_determine_dimension(e)
-            })?;
-
-            let response: OllamaResponse = response
-                .json()
-                .await
-                .map_err(EmbedError::ollama_unexpected)
-                .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?;
-
-            let embedding = Embeddings::from_single_embedding(response.embedding);
-
-            embedder.options.embedding_model.dimensions = embedding.dimension();
-
-            tracing::info!(
-                "ollama model {} with dimensionality {} added",
-                embedder.options.embedding_model.name(),
-                embedding.dimension()
-            );
-
-            Ok(embedder)
-        })
-    }
-
-    async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
-        if !response.status().is_success() {
-            // Not the same number of possible error cases covered as with OpenAI.
-            match response.status() {
-                StatusCode::TOO_MANY_REQUESTS => {
-                    let error_response: OllamaError = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::ollama_unexpected)
-                        .map_err(Retry::retry_later)?;
-
-                    return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests(
-                        OllamaError { error: error_response.error },
-                    )));
-                }
-                StatusCode::SERVICE_UNAVAILABLE => {
-                    let error_response: OllamaError = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::ollama_unexpected)
-                        .map_err(Retry::retry_later)?;
-                    return Err(Retry::retry_later(EmbedError::ollama_internal_server_error(
-                        OllamaError { error: error_response.error },
-                    )));
-                }
-                StatusCode::NOT_FOUND => {
-                    let error_response: OllamaError = response
-                        .json()
-                        .await
-                        .map_err(EmbedError::ollama_unexpected)
-                        .map_err(Retry::give_up)?;
-
-                    return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError {
-                        error: error_response.error,
-                    })));
-                }
-                code => {
-                    return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code(
-                        code.as_u16(),
-                    )));
-                }
+        let model = options.embedding_model.as_str();
+        let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
+            api_key: None,
+            distribution: None,
+            dimensions: None,
+            url: get_ollama_path(),
+            query: serde_json::json!({
+                "model": model,
+            }),
+            input_field: vec!["prompt".to_owned()],
+            path_to_embeddings: Default::default(),
+            embedding_object: vec!["embedding".to_owned()],
+            input_type: super::rest::InputType::Text,
+        }) {
+            Ok(embedder) => embedder,
+            Err(NewEmbedderError {
+                kind:
+                    NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError {
+                        kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error),
+                        fault: _,
+                    }),
+                fault: _,
+            }) => {
+                return Err(NewEmbedderError::could_not_determine_dimension(
+                    EmbedError::ollama_model_not_found(error),
+                ))
             }
-        }
-
-        Ok(response)
+            Err(error) => return Err(error),
+        };
+
+        Ok(Self { rest_embedder })
     }
 
-    pub async fn embed(
-        &self,
-        texts: Vec<String>,
-        client: &reqwest::Client,
-    ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
-        // Ollama only embedds one document at a time.
-        let mut results = Vec::with_capacity(texts.len());
-
-        // The retry loop is inside the texts loop, might have to switch that around
-        for text in texts {
-            // Retries copied from openai.rs
-            for attempt in 0..7 {
-                let retry_duration = match self.try_embed(&text, client).await {
-                    Ok(result) => {
-                        results.push(result);
-                        break;
-                    }
-                    Err(retry) => {
-                        tracing::warn!("Failed: {}", retry.error);
-                        retry.into_duration(attempt)
-                    }
-                }?;
-                tracing::warn!(
-                    "Attempt #{}, retrying after {}ms.",
-                    attempt,
-                    retry_duration.as_millis()
-                );
-                tokio::time::sleep(retry_duration).await;
+    pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+        match self.rest_embedder.embed(texts) {
+            Ok(embeddings) => Ok(embeddings),
+            Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
+                Err(EmbedError::ollama_model_not_found(error))
             }
+            Err(error) => Err(error),
         }
-
-        Ok(results)
-    }
-
-    async fn try_embed(
-        &self,
-        text: &str,
-        client: &reqwest::Client,
-    ) -> Result<Embeddings<f32>, Retry> {
-        let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text };
-        let response = client
-            .post(get_ollama_path())
-            .json(&request)
-            .send()
-            .await
-            .map_err(EmbedError::openai_network)
-            .map_err(Retry::retry_later)?;
-
-        let response = Self::check_response(response).await?;
-
-        let response: OllamaResponse = response
-            .json()
-            .await
-            .map_err(EmbedError::openai_unexpected)
-            .map_err(Retry::retry_later)?;
-
-        tracing::trace!("response: {:?}", response.embedding);
-
-        let embedding = Embeddings::from_single_embedding(response.embedding);
-        Ok(embedding)
     }
 
     pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
+        threads: &rayon::ThreadPool,
     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        let rt = tokio::runtime::Builder::new_current_thread()
-            .enable_io()
-            .enable_time()
-            .build()
-            .map_err(EmbedError::openai_runtime_init)?;
-        let client = self.new_client()?;
-        rt.block_on(futures::future::try_join_all(
-            text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
-        ))
+        threads.install(move || {
+            text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
+        })
     }
 
-    // Defaults copied from openai.rs
     pub fn chunk_count_hint(&self) -> usize {
-        10
+        self.rest_embedder.chunk_count_hint()
     }
 
     pub fn prompt_count_in_chunk_hint(&self) -> usize {
-        10
+        self.rest_embedder.prompt_count_in_chunk_hint()
     }
 
     pub fn dimensions(&self) -> usize {
-        self.options.embedding_model.dimensions
+        self.rest_embedder.dimensions()
     }
 
     pub fn distribution(&self) -> Option<DistributionShift> {
@@ -295,12 +96,6 @@ impl Embedder {
     }
 }
 
-impl Display for OllamaError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "{}", self.error)
-    }
-}
-
 fn get_ollama_path() -> String {
     // Important: Hostname not enough, has to be entire path to embeddings endpoint
     std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string())
diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs
index 5d13d5ee2..b2638966e 100644
--- a/milli/src/vector/openai.rs
+++ b/milli/src/vector/openai.rs
@@ -1,9 +1,9 @@
-use std::fmt::Display;
-
-use serde::{Deserialize, Serialize};
+use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
 
 use super::error::{EmbedError, NewEmbedderError};
-use super::{DistributionShift, Embedding, Embeddings};
+use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
+use super::{DistributionShift, Embeddings};
+use crate::vector::error::EmbedErrorKind;
 
 #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
 pub struct EmbedderOptions {
@@ -12,6 +12,32 @@ pub struct EmbedderOptions {
     pub dimensions: Option<usize>,
 }
 
+impl EmbedderOptions {
+    pub fn dimensions(&self) -> usize {
+        if self.embedding_model.supports_overriding_dimensions() {
+            self.dimensions.unwrap_or(self.embedding_model.default_dimensions())
+        } else {
+            self.embedding_model.default_dimensions()
+        }
+    }
+
+    pub fn query(&self) -> serde_json::Value {
+        let model = self.embedding_model.name();
+
+        let mut query = serde_json::json!({
+            "model": model,
+        });
+
+        if self.embedding_model.supports_overriding_dimensions() {
+            if let Some(dimensions) = self.dimensions {
+                query["dimensions"] = dimensions.into();
+            }
+        }
+
+        query
+    }
+}
+
 #[derive(
     Debug,
     Clone,
@@ -117,364 +143,112 @@ impl EmbedderOptions {
     }
 }
 
-// retrying in case of failure
-
-pub struct Retry {
-    pub error: EmbedError,
-    strategy: RetryStrategy,
-}
-
-pub enum RetryStrategy {
-    GiveUp,
-    Retry,
-    RetryTokenized,
-    RetryAfterRateLimit,
-}
-
-impl Retry {
-    pub fn give_up(error: EmbedError) -> Self {
-        Self { error, strategy: RetryStrategy::GiveUp }
-    }
-
-    pub fn retry_later(error: EmbedError) -> Self {
-        Self { error, strategy: RetryStrategy::Retry }
-    }
-
-    pub fn retry_tokenized(error: EmbedError) -> Self {
-        Self { error, strategy: RetryStrategy::RetryTokenized }
-    }
-
-    pub fn rate_limited(error: EmbedError) -> Self {
-        Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
-    }
-
-    pub fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
-        match self.strategy {
-            RetryStrategy::GiveUp => Err(self.error),
-            RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
-            RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)),
-            RetryStrategy::RetryAfterRateLimit => {
-                Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt)))
-            }
-        }
-    }
-
-    pub fn must_tokenize(&self) -> bool {
-        matches!(self.strategy, RetryStrategy::RetryTokenized)
-    }
-
-    pub fn into_error(self) -> EmbedError {
-        self.error
-    }
-}
-
-// openai api structs
-
-#[derive(Debug, Serialize)]
-struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
-    model: &'a str,
-    input: &'a [S],
-    #[serde(skip_serializing_if = "Option::is_none")]
-    dimensions: Option<usize>,
-}
-
-#[derive(Debug, Serialize)]
-struct OpenAiTokensRequest<'a> {
-    model: &'a str,
-    input: &'a [usize],
-    #[serde(skip_serializing_if = "Option::is_none")]
-    dimensions: Option<usize>,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAiResponse {
-    data: Vec<OpenAiEmbedding>,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAiErrorResponse {
-    error: OpenAiError,
-}
-
-#[derive(Debug, Deserialize)]
-pub struct OpenAiError {
-    message: String,
-    // type: String,
-    code: Option<String>,
-}
-
-impl Display for OpenAiError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match &self.code {
-            Some(code) => write!(f, "{} ({})", self.message, code),
-            None => write!(f, "{}", self.message),
-        }
-    }
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAiEmbedding {
-    embedding: Embedding,
-    // object: String,
-    // index: usize,
-}
-
 fn infer_api_key() -> String {
     std::env::var("MEILI_OPENAI_API_KEY")
         .or_else(|_| std::env::var("OPENAI_API_KEY"))
         .unwrap_or_default()
 }
 
-pub mod sync {
-    use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
+#[derive(Debug)]
+pub struct Embedder {
+    tokenizer: tiktoken_rs::CoreBPE,
+    rest_embedder: RestEmbedder,
+    options: EmbedderOptions,
+}
 
-    use super::{
-        EmbedError, Embedding, Embeddings, NewEmbedderError, OpenAiErrorResponse, OpenAiRequest,
-        OpenAiResponse, OpenAiTokensRequest, Retry, OPENAI_EMBEDDINGS_URL,
-    };
-    use crate::vector::DistributionShift;
+impl Embedder {
+    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
+        let mut inferred_api_key = Default::default();
+        let api_key = options.api_key.as_ref().unwrap_or_else(|| {
+            inferred_api_key = infer_api_key();
+            &inferred_api_key
+        });
 
-    const REQUEST_PARALLELISM: usize = 10;
+        let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
+            api_key: Some(api_key.clone()),
+            distribution: options.embedding_model.distribution(),
+            dimensions: Some(options.dimensions()),
+            url: OPENAI_EMBEDDINGS_URL.to_owned(),
+            query: options.query(),
+            input_field: vec!["input".to_owned()],
+            input_type: crate::vector::rest::InputType::TextArray,
+            path_to_embeddings: vec!["data".to_owned()],
+            embedding_object: vec!["embedding".to_owned()],
+        })?;
 
-    #[derive(Debug)]
-    pub struct Embedder {
-        tokenizer: tiktoken_rs::CoreBPE,
-        options: super::EmbedderOptions,
-        bearer: String,
-        threads: rayon::ThreadPool,
+        // looking at the code it is very unclear that this can actually fail.
+        let tokenizer = tiktoken_rs::cl100k_base().unwrap();
+
+        Ok(Self { options, rest_embedder, tokenizer })
     }
 
-    impl Embedder {
-        pub fn new(options: super::EmbedderOptions) -> Result<Self, NewEmbedderError> {
-            let mut inferred_api_key = Default::default();
-            let api_key = options.api_key.as_ref().unwrap_or_else(|| {
-                inferred_api_key = super::infer_api_key();
-                &inferred_api_key
-            });
-            let bearer = format!("Bearer {api_key}");
-
-            // looking at the code it is very unclear that this can actually fail.
-            let tokenizer = tiktoken_rs::cl100k_base().unwrap();
-
-            // FIXME: unwrap
-            let threads = rayon::ThreadPoolBuilder::new()
-                .num_threads(REQUEST_PARALLELISM)
-                .thread_name(|index| format!("embedder-chunk-{index}"))
-                .build()
-                .unwrap();
-
-            Ok(Self { options, bearer, tokenizer, threads })
+    pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+        match self.rest_embedder.embed_ref(&texts) {
+            Ok(embeddings) => Ok(embeddings),
+            Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error), fault: _ }) => {
+                tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
+                self.try_embed_tokenized(&texts)
+            }
+            Err(error) => Err(error),
         }
+    }
 
-        pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
-            let mut tokenized = false;
-
-            let client = ureq::agent();
-
-            for attempt in 0..7 {
-                let result = if tokenized {
-                    self.try_embed_tokenized(&texts, &client)
-                } else {
-                    self.try_embed(&texts, &client)
-                };
-
-                let retry_duration = match result {
-                    Ok(embeddings) => return Ok(embeddings),
-                    Err(retry) => {
-                        tracing::warn!("Failed: {}", retry.error);
-                        tokenized |= retry.must_tokenize();
-                        retry.into_duration(attempt)
-                    }
-                }?;
-
-                let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
-                tracing::warn!(
-                    "Attempt #{}, retrying after {}ms.",
-                    attempt,
-                    retry_duration.as_millis()
-                );
-                std::thread::sleep(retry_duration);
+    fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+        pub const OVERLAP_SIZE: usize = 200;
+        let mut all_embeddings = Vec::with_capacity(text.len());
+        for text in text {
+            let max_token_count = self.options.embedding_model.max_token();
+            let encoded = self.tokenizer.encode_ordinary(text.as_str());
+            let len = encoded.len();
+            if len < max_token_count {
+                all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
+                continue;
             }
 
-            let result = if tokenized {
-                self.try_embed_tokenized(&texts, &client)
-            } else {
-                self.try_embed(&texts, &client)
-            };
+            let mut tokens = encoded.as_slice();
+            let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
+            while tokens.len() > max_token_count {
+                let window = &tokens[..max_token_count];
+                let embedding = self.rest_embedder.embed_tokens(window)?;
+                // FIXME: unwrap
+                embeddings_for_prompt.append(embedding.into_inner()).unwrap();
 
-            result.map_err(Retry::into_error)
-        }
-
-        fn check_response(
-            response: Result<ureq::Response, ureq::Error>,
-        ) -> Result<ureq::Response, Retry> {
-            match response {
-                Ok(response) => Ok(response),
-                Err(ureq::Error::Status(code, response)) => {
-                    let error_response: Option<OpenAiErrorResponse> = response.into_json().ok();
-                    let error = error_response.map(|response| response.error);
-                    Err(match code {
-                        401 => Retry::give_up(EmbedError::openai_auth_error(error)),
-                        429 => Retry::rate_limited(EmbedError::openai_too_many_requests(error)),
-                        400 => {
-                            tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
-
-                            Retry::retry_tokenized(EmbedError::openai_too_many_tokens(error))
-                        }
-                        500..=599 => {
-                            Retry::retry_later(EmbedError::openai_internal_server_error(error))
-                        }
-                        x => Retry::retry_later(EmbedError::openai_unhandled_status_code(code)),
-                    })
-                }
-                Err(ureq::Error::Transport(transport)) => {
-                    Err(Retry::retry_later(EmbedError::openai_network(transport)))
-                }
-            }
-        }
-
-        fn try_embed<S: AsRef<str> + serde::Serialize>(
-            &self,
-            texts: &[S],
-            client: &ureq::Agent,
-        ) -> Result<Vec<Embeddings<f32>>, Retry> {
-            for text in texts {
-                tracing::trace!("Received prompt: {}", text.as_ref())
-            }
-            let request = OpenAiRequest {
-                model: self.options.embedding_model.name(),
-                input: texts,
-                dimensions: self.overriden_dimensions(),
-            };
-            let response = client
-                .post(OPENAI_EMBEDDINGS_URL)
-                .set("Authorization", &self.bearer)
-                .send_json(&request);
-
-            let response = Self::check_response(response)?;
-
-            let response: OpenAiResponse = response
-                .into_json()
-                .map_err(EmbedError::openai_unexpected)
-                .map_err(Retry::retry_later)?;
-
-            tracing::trace!("response: {:?}", response.data);
-
-            Ok(response
-                .data
-                .into_iter()
-                .map(|data| Embeddings::from_single_embedding(data.embedding))
-                .collect())
-        }
-
-        fn try_embed_tokenized(
-            &self,
-            text: &[String],
-            client: &ureq::Agent,
-        ) -> Result<Vec<Embeddings<f32>>, Retry> {
-            pub const OVERLAP_SIZE: usize = 200;
-            let mut all_embeddings = Vec::with_capacity(text.len());
-            for text in text {
-                let max_token_count = self.options.embedding_model.max_token();
-                let encoded = self.tokenizer.encode_ordinary(text.as_str());
-                let len = encoded.len();
-                if len < max_token_count {
-                    all_embeddings.append(&mut self.try_embed(&[text], client)?);
-                    continue;
-                }
-
-                let mut tokens = encoded.as_slice();
-                let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
-                while tokens.len() > max_token_count {
-                    let window = &tokens[..max_token_count];
-                    embeddings_for_prompt.push(self.embed_tokens(window, client)?).unwrap();
-
-                    tokens = &tokens[max_token_count - OVERLAP_SIZE..];
-                }
-
-                // end of text
-                embeddings_for_prompt.push(self.embed_tokens(tokens, client)?).unwrap();
-
-                all_embeddings.push(embeddings_for_prompt);
-            }
-            Ok(all_embeddings)
-        }
-
-        fn embed_tokens(&self, tokens: &[usize], client: &ureq::Agent) -> Result<Embedding, Retry> {
-            for attempt in 0..9 {
-                let duration = match self.try_embed_tokens(tokens, client) {
-                    Ok(embedding) => return Ok(embedding),
-                    Err(retry) => retry.into_duration(attempt),
-                }
-                .map_err(Retry::retry_later)?;
-
-                std::thread::sleep(duration);
+                tokens = &tokens[max_token_count - OVERLAP_SIZE..];
             }
 
+            // end of text
+            let embedding = self.rest_embedder.embed_tokens(tokens)?;
+            // FIXME: unwrap
+            embeddings_for_prompt.append(embedding.into_inner()).unwrap();
+
+            all_embeddings.push(embeddings_for_prompt);
         }
+        Ok(all_embeddings)
+    }
 
-            self.try_embed_tokens(tokens, client)
-                .map_err(|retry| Retry::give_up(retry.into_error()))
-        }
+    pub fn embed_chunks(
+        &self,
+        text_chunks: Vec<Vec<String>>,
+        threads: &rayon::ThreadPool,
+    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
+        threads.install(move || {
+            text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
+        })
+    }
 
-        fn try_embed_tokens(
-            &self,
-            tokens: &[usize],
-            client: &ureq::Agent,
-        ) -> Result<Embedding, Retry> {
-            let request = OpenAiTokensRequest {
-                model: self.options.embedding_model.name(),
-                input: tokens,
-                dimensions: self.overriden_dimensions(),
-            };
-            let response = client
-                .post(OPENAI_EMBEDDINGS_URL)
-                .set("Authorization", &self.bearer)
-                .send_json(&request);
-
-            let response = Self::check_response(response)?;
+    pub fn chunk_count_hint(&self) -> usize {
+        self.rest_embedder.chunk_count_hint()
+    }
 
-            let mut response: OpenAiResponse = response
-                .into_json()
-                .map_err(EmbedError::openai_unexpected)
-                .map_err(Retry::retry_later)?;
+    pub fn prompt_count_in_chunk_hint(&self) -> usize {
+        self.rest_embedder.prompt_count_in_chunk_hint()
+    }
 
-            Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
-        }
+    pub fn dimensions(&self) -> usize {
+        self.options.dimensions()
+    }
 
-        pub fn embed_chunks(
-            &self,
-            text_chunks: Vec<Vec<String>>,
-        ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-            self.threads
-                .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk)))
-                .collect()
-        }
+    pub fn distribution(&self) -> Option<DistributionShift> {
+        self.options.embedding_model.distribution()
     }
 
-        pub fn chunk_count_hint(&self) -> usize {
-            10
-        }
-
-        pub fn prompt_count_in_chunk_hint(&self) -> usize {
-            10
-        }
-
-        pub fn dimensions(&self) -> usize {
-            if self.options.embedding_model.supports_overriding_dimensions() {
-                self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
-            } else {
-                self.options.embedding_model.default_dimensions()
-            }
-        }
-
-        pub fn distribution(&self) -> Option<DistributionShift> {
-            self.options.embedding_model.distribution()
-        }
-
-        fn overriden_dimensions(&self) -> Option<usize> {
-            if self.options.embedding_model.supports_overriding_dimensions() {
-                self.options.dimensions
-            } else {
-                None
-            }
-        }
-    }
 }
diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs
index 975bd3790..6fd47d882 100644
--- a/milli/src/vector/rest.rs
+++ b/milli/src/vector/rest.rs
@@ -1,9 +1,62 @@
 use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
+use serde::Serialize;
 
-use super::openai::Retry;
-use super::{DistributionShift, EmbedError, Embeddings, NewEmbedderError};
-use crate::VectorOrArrayOfVectors;
+use super::{
+    DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
+};
 
+// retrying in case of failure
+
+pub struct Retry {
+    pub error: EmbedError,
+    strategy: RetryStrategy,
+}
+
+pub enum RetryStrategy {
+    GiveUp,
+    Retry,
+    RetryTokenized,
+    RetryAfterRateLimit,
+}
+
+impl Retry {
+    pub fn give_up(error: EmbedError) -> Self {
+        Self { error, strategy: RetryStrategy::GiveUp }
+    }
+
+    pub fn retry_later(error: EmbedError) -> Self {
+        Self { error, strategy: RetryStrategy::Retry }
+    }
+
+    pub fn retry_tokenized(error: EmbedError) -> Self {
+        Self { error, strategy: RetryStrategy::RetryTokenized }
+    }
+
+    pub fn rate_limited(error: EmbedError) -> Self {
+        Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
+    }
+
+    pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
+        match self.strategy {
+            RetryStrategy::GiveUp => Err(self.error),
+            RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
+            RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
+            RetryStrategy::RetryAfterRateLimit => {
+                Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
+            }
+        }
+    }
+
+    pub fn must_tokenize(&self) -> bool {
+        matches!(self.strategy, RetryStrategy::RetryTokenized)
+    }
+
+    pub fn into_error(self) -> EmbedError {
+        self.error
+    }
+}
+
+#[derive(Debug)]
 pub struct Embedder {
     client: ureq::Agent,
     options: EmbedderOptions,
@@ -11,20 +64,35 @@ pub struct Embedder {
     dimensions: usize,
 }
 
+#[derive(Debug)]
 pub struct EmbedderOptions {
-    api_key: Option<String>,
-    distribution: Option<DistributionShift>,
-    dimensions: Option<usize>,
-    url: String,
-    query: liquid::Template,
-    response_field: Vec<String>,
+    pub api_key: Option<String>,
+    pub distribution: Option<DistributionShift>,
+    pub dimensions: Option<usize>,
+    pub url: String,
+    pub query: serde_json::Value,
+    pub input_field: Vec<String>,
+    // path to the array of embeddings
+    pub path_to_embeddings: Vec<String>,
+    // shape of a single embedding
+    pub embedding_object: Vec<String>,
+    pub input_type: InputType,
+}
+
+#[derive(Debug)]
+pub enum InputType {
+    Text,
+    TextArray,
 }
 
 impl Embedder {
     pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
-        let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer: {api_key}"));
+        let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
 
-        let client = ureq::agent();
+        let client = ureq::AgentBuilder::new()
+            .max_idle_connections(REQUEST_PARALLELISM * 2)
+            .max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
+            .build();
 
         let dimensions = if let Some(dimensions) = options.dimensions {
             dimensions
@@ -36,7 +104,20 @@ impl Embedder {
     }
 
     pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
-        embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice())
+        embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len())
+    }
+
+    pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
+    where
+        S: AsRef<str> + Serialize,
+    {
+        embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len())
+    }
+
+    pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
+        let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?;
+        // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
+        Ok(embeddings.pop().unwrap())
     }
 
     pub fn embed_chunks(
@@ -44,17 +125,20 @@ impl Embedder {
         text_chunks: Vec<Vec<String>>,
         threads: &rayon::ThreadPool,
     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        threads
-            .install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk)))
-            .collect()
+        threads.install(move || {
+            text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
+        })
     }
 
     pub fn chunk_count_hint(&self) -> usize {
-        10
+        super::REQUEST_PARALLELISM
     }
 
     pub fn prompt_count_in_chunk_hint(&self) -> usize {
-        10
+        match self.options.input_type {
+            InputType::Text => 1,
+            InputType::TextArray => 10,
+        }
     }
 
     pub fn dimensions(&self) -> usize {
@@ -71,9 +155,9 @@ fn infer_dimensions(
     options: &EmbedderOptions,
     bearer: Option<&str>,
 ) -> Result<usize, NewEmbedderError> {
-    let v = embed(client, options, bearer, ["test"].as_slice())
+    let v = embed(client, options, bearer, ["test"].as_slice(), 1)
         .map_err(NewEmbedderError::could_not_determine_dimension)?;
-    // unwrap: guaranteed that v.len() == ["test"].len() == 1, otherwise the previous line terminated in error
+    // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
     Ok(v.first().unwrap().dimension())
 }
 
@@ -82,33 +166,57 @@ fn embed<S>(
     options: &EmbedderOptions,
     bearer: Option<&str>,
     inputs: &[S],
+    expected_count: usize,
 ) -> Result<Vec<Embeddings<f32>>, EmbedError>
 where
-    S: serde::Serialize,
+    S: Serialize,
 {
     let request = client.post(&options.url);
     let request =
         if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
     let request = request.set("Content-Type", "application/json");
 
-    let body = options
-        .query
-        .render(
-            &liquid::to_object(&serde_json::json!({
-                "input": inputs,
-            }))
-            .map_err(EmbedError::rest_template_context_serialization)?,
-        )
-        .map_err(EmbedError::rest_template_render)?;
+    let input_value = match options.input_type {
+        InputType::Text => serde_json::json!(inputs.first()),
+        InputType::TextArray => serde_json::json!(inputs),
+    };
+
+    let body = match options.input_field.as_slice() {
+        [] => {
+            // inject input in body
+            input_value
+        }
+        [input] => {
+            let mut body = options.query.clone();
+
+            // FIXME: unwrap
+            body.as_object_mut().unwrap().insert(input.clone(), input_value);
+            body
+        }
+        [path @ .., input] => {
+            let mut body = options.query.clone();
+
+            // FIXME: unwrap
+            let mut current_value = &mut body;
+            for component in path {
+                current_value = current_value
+                    .as_object_mut()
+                    .unwrap()
+                    .entry(component.clone())
+                    .or_insert(serde_json::json!({}));
+            }
+
+            current_value.as_object_mut().unwrap().insert(input.clone(), input_value);
+            body
+        }
+    };
 
     for attempt in 0..7 {
-        let response = request.send_string(&body);
+        let response = request.clone().send_json(&body);
         let result = check_response(response);
 
         let retry_duration = match result {
-            Ok(response) => {
-                return response_to_embedding(response, &options.response_field, inputs.len())
-            }
+            Ok(response) => return response_to_embedding(response, options, expected_count),
             Err(retry) => {
                 tracing::warn!("Failed: {}", retry.error);
                 retry.into_duration(attempt)
@@ -120,11 +228,11 @@ where
         std::thread::sleep(retry_duration);
     }
 
-    let response = request.send_string(&body);
+    let response = request.send_json(&body);
     let result = check_response(response);
     result
         .map_err(Retry::into_error)
-        .and_then(|response| response_to_embedding(response, &options.response_field, inputs.len()))
+        .and_then(|response| response_to_embedding(response, options, expected_count))
 }
 
 fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
@@ -139,7 +247,10 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq
                 Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
             }
-            x => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
+            402..=499 => {
+                Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
+            }
+            _ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
         })
     }
     Err(ureq::Error::Transport(transport)) => {
@@ -148,34 +259,66 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq
 
-fn response_to_embedding<S: AsRef<str>>(
+fn response_to_embedding(
     response: ureq::Response,
-    response_field: &[S],
+    options: &EmbedderOptions,
     expected_count: usize,
 ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
     let response: serde_json::Value =
         response.into_json().map_err(EmbedError::rest_response_deserialization)?;
 
     let mut current_value = &response;
-    for component in response_field {
+    for component in &options.path_to_embeddings {
         let component = component.as_ref();
-        let current_value = current_value.get(component).ok_or_else(|| {
-            EmbedError::rest_response_missing_embeddings(response, component, response_field)
+        current_value = current_value.get(component).ok_or_else(|| {
+            EmbedError::rest_response_missing_embeddings(
+                response.clone(),
+                component,
+                &options.path_to_embeddings,
+            )
         })?;
     }
 
-    let embeddings = current_value.to_owned();
+    let embeddings = match options.input_type {
+        InputType::Text => {
+            for component in &options.embedding_object {
+                current_value = current_value.get(component).ok_or_else(|| {
+                    EmbedError::rest_response_missing_embeddings(
+                        response.clone(),
+                        component,
+                        &options.embedding_object,
+                    )
+                })?;
+            }
+            let embeddings = current_value.to_owned();
+            let embeddings: Embedding =
+                serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
 
-    let embeddings: VectorOrArrayOfVectors =
-        serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
-
-    let embeddings = embeddings.into_array_of_vectors();
-
-    let embeddings: Vec<Embeddings<f32>> = embeddings
-        .into_iter()
-        .flatten()
-        .map(|embedding| Embeddings::from_single_embedding(embedding))
-        .collect();
+            vec![Embeddings::from_single_embedding(embeddings)]
+        }
+        InputType::TextArray => {
+            let empty = vec![];
+            let values = current_value.as_array().unwrap_or(&empty);
+            let mut embeddings: Vec<Embeddings<f32>> = Vec::with_capacity(expected_count);
+            for value in values {
+                let mut current_value = value;
+                for component in &options.embedding_object {
+                    current_value = current_value.get(component).ok_or_else(|| {
+                        EmbedError::rest_response_missing_embeddings(
+                            response.clone(),
+                            component,
+                            &options.embedding_object,
+                        )
+                    })?;
+                }
+                let embedding = current_value.to_owned();
+                let embedding: Embedding =
+                    serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?;
+                embeddings.push(Embeddings::from_single_embedding(embedding));
+            }
+            embeddings
+        }
+    };
 
     if embeddings.len() != expected_count {
         return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs
index 89571e98a..540693d44 100644
--- a/milli/src/vector/settings.rs
+++ b/milli/src/vector/settings.rs
@@ -204,7 +204,7 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
             },
             super::EmbedderOptions::Ollama(options) => Self {
                 source: Setting::Set(EmbedderSource::Ollama),
-                model: Setting::Set(options.embedding_model.name().to_owned()),
+                model: Setting::Set(options.embedding_model.to_owned()),
                 revision: Setting::NotSet,
                 api_key: Setting::NotSet,
                 dimensions: Setting::NotSet,
@@ -248,7 +248,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
             let mut options: ollama::EmbedderOptions =
                 super::ollama::EmbedderOptions::with_default_model();
             if let Some(model) = model.set() {
-                options.embedding_model = super::ollama::EmbeddingModel::from_name(&model);
+                options.embedding_model = model;
             }
             this.embedder_options = super::EmbedderOptions::Ollama(options);
         }
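
Note (illustrative, not part of the commit): after this patch, both the Ollama and OpenAI backends are thin adapters over milli/src/vector/rest.rs. The sketch below drives the new rest::Embedder directly, configured the same way the Ollama integration above configures it. It is a minimal sketch under stated assumptions: that the `rest` module and `REQUEST_PARALLELISM` are reachable from the caller's crate as `milli::vector::rest` and `milli::vector::REQUEST_PARALLELISM`, and that a local Ollama server is running; the document strings are made up.

    use milli::vector::rest::{Embedder, EmbedderOptions, InputType};

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let embedder = Embedder::new(EmbedderOptions {
            api_key: None,
            distribution: None,
            // `None` makes Embedder::new infer the dimension count by
            // embedding the string "test" once at construction time.
            dimensions: None,
            url: "http://localhost:11434/api/embeddings".into(),
            query: serde_json::json!({ "model": "nomic-embed-text" }),
            // With InputType::Text and input_field == ["prompt"], each request
            // body becomes {"model": "nomic-embed-text", "prompt": "<text>"}.
            input_field: vec!["prompt".into()],
            // The response is {"embedding": [...]}: no path to an array of
            // embeddings, and the embedding object sits at the top level.
            path_to_embeddings: vec![],
            embedding_object: vec!["embedding".into()],
            input_type: InputType::Text,
        })?;

        // Mirrors send_original_documents_data above: requests are fanned out
        // on a dedicated pool capped at REQUEST_PARALLELISM (40) threads.
        let threads = rayon::ThreadPoolBuilder::new()
            .num_threads(milli::vector::REQUEST_PARALLELISM)
            .thread_name(|index| format!("embedding-request-{index}"))
            .build()?;

        let chunks = vec![
            vec!["a first document".to_string()],
            vec!["a second document".to_string()],
        ];
        let embeddings = embedder.embed_chunks(chunks, &threads)?;
        assert_eq!(embeddings.len(), 2);
        Ok(())
    }

The trade-off recorded in prompt_count_in_chunk_hint follows from the same configuration: an InputType::Text backend like Ollama embeds one prompt per request, while an InputType::TextArray backend like OpenAI batches ten prompts per request.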