diff --git a/Cargo.lock b/Cargo.lock index b44b151d1..214ba368f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3338,7 +3338,6 @@ dependencies = [ "filter-parser", "flatten-serde-json", "fst", - "futures", "fxhash", "geoutils", "grenad", @@ -3362,7 +3361,6 @@ dependencies = [ "rand", "rand_pcg", "rayon", - "reqwest", "roaring", "rstar", "serde", @@ -3376,8 +3374,9 @@ dependencies = [ "tiktoken-rs", "time", "tokenizers", - "tokio", "tracing", + "ureq", + "url", "uuid", ] diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index aed77411a..1b94201f2 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -353,6 +353,7 @@ impl ErrorCode for milli::Error { | UserError::InvalidOpenAiModelDimensions { .. } | UserError::InvalidOpenAiModelDimensionsMax { .. } | UserError::InvalidSettingsDimensions { .. } + | UserError::InvalidUrl { .. } | UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders, UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 6a430b6a3..8de2be13f 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -202,7 +202,7 @@ pub async fn search_with_url_query( let index = index_scheduler.index(&index_uid)?; let features = index_scheduler.features(); - let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?; let search_result = tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) @@ -241,7 +241,7 @@ pub async fn search_with_post( let features = index_scheduler.features(); - let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?; let search_result = tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) @@ -260,7 +260,7 @@ pub async fn search_with_post( Ok(HttpResponse::Ok().json(search_result)) } -pub async fn embed( +pub fn embed( query: &mut SearchQuery, index_scheduler: &IndexScheduler, index: &milli::Index, @@ -287,7 +287,6 @@ pub async fn embed( let embeddings = embedder .embed(vec![q.to_owned()]) - .await .map_err(milli::vector::Error::from) .map_err(milli::Error::from)? 
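// [annotation, not part of the diff] `embed` is now a plain synchronous function: the
// ureq-based embedders block the calling thread, so the search handlers invoke it
// directly instead of `.await`ing it, before spawning `perform_search` on a blocking task.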
.pop() diff --git a/meilisearch/src/routes/indexes/settings.rs b/meilisearch/src/routes/indexes/settings.rs index 5dabd7b0d..99c3d0fbb 100644 --- a/meilisearch/src/routes/indexes/settings.rs +++ b/meilisearch/src/routes/indexes/settings.rs @@ -605,6 +605,7 @@ fn embedder_analytics( EmbedderSource::HuggingFace => sources.insert("huggingFace"), EmbedderSource::UserProvided => sources.insert("userProvided"), EmbedderSource::Ollama => sources.insert("ollama"), + EmbedderSource::Rest => sources.insert("rest"), }; } }; diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs index 86aa58e70..f54b8ae8f 100644 --- a/meilisearch/src/routes/multi_search.rs +++ b/meilisearch/src/routes/multi_search.rs @@ -75,9 +75,8 @@ pub async fn multi_search_with_post( }) .with_index(query_index)?; - let distribution = embed(&mut query, index_scheduler.get_ref(), &index) - .await - .with_index(query_index)?; + let distribution = + embed(&mut query, index_scheduler.get_ref(), &index).with_index(query_index)?; let search_result = tokio::task::spawn_blocking(move || { perform_search(&index, query, features, distribution) diff --git a/milli/Cargo.toml b/milli/Cargo.toml index fa4215404..9f5803f4e 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -80,17 +80,13 @@ tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0. hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [ "online", ] } -tokio = { version = "1.35.1", features = ["rt"] } -futures = "0.3.30" -reqwest = { version = "0.11.23", features = [ - "rustls-tls", - "json", -], default-features = false } tiktoken-rs = "0.5.8" liquid = "0.26.4" arroy = "0.2.0" rand = "0.8.5" tracing = "0.1.40" +ureq = { version = "2.9.6", features = ["json"] } +url = "2.5.0" [dev-dependencies] mimalloc = { version = "0.1.39", default-features = false } diff --git a/milli/src/error.rs b/milli/src/error.rs index 1147085dd..aba80b475 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -243,6 +243,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco }, #[error("`.embedders.{embedder_name}.dimensions`: `dimensions` cannot be zero")] InvalidSettingsDimensions { embedder_name: String }, + #[error("`.embedders.{embedder_name}.url`: could not parse `{url}`: {inner_error}")] + InvalidUrl { embedder_name: String, inner_error: url::ParseError, url: String }, } impl From for Error { diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index ece841659..40b32bf9c 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -339,6 +339,7 @@ pub fn extract_embeddings( prompt_reader: grenad::Reader, indexer: GrenadParameters, embedder: Arc, + request_threads: &rayon::ThreadPool, ) -> Result>> { puffin::profile_function!(); let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism @@ -376,7 +377,10 @@ pub fn extract_embeddings( if chunks.len() == chunks.capacity() { let chunked_embeds = embedder - .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))) + .embed_chunks( + std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)), + request_threads, + ) .map_err(crate::vector::Error::from) .map_err(crate::Error::from)?; @@ -394,7 +398,7 @@ pub fn extract_embeddings( // send last chunk if !chunks.is_empty() { let 
chunked_embeds = embedder - .embed_chunks(std::mem::take(&mut chunks)) + .embed_chunks(std::mem::take(&mut chunks), request_threads) .map_err(crate::vector::Error::from) .map_err(crate::Error::from)?; for (docid, embeddings) in chunks_ids @@ -408,7 +412,7 @@ pub fn extract_embeddings( if !current_chunk.is_empty() { let embeds = embedder - .embed_chunks(vec![std::mem::take(&mut current_chunk)]) + .embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads) .map_err(crate::vector::Error::from) .map_err(crate::Error::from)?; diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 43f3f4947..82486f3a8 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -238,6 +238,12 @@ fn send_original_documents_data( let documents_chunk_cloned = original_documents_chunk.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); + + let request_threads = rayon::ThreadPoolBuilder::new() + .num_threads(crate::vector::REQUEST_PARALLELISM) + .thread_name(|index| format!("embedding-request-{index}")) + .build()?; + rayon::spawn(move || { for (name, (embedder, prompt)) in embedders { let result = extract_vector_points( @@ -249,7 +255,12 @@ fn send_original_documents_data( ); match result { Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { - let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) { + let embeddings = match extract_embeddings( + prompts, + indexer, + embedder.clone(), + &request_threads, + ) { Ok(results) => Some(results), Err(error) => { let _ = lmdb_writer_sx_cloned.send(Err(error)); diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 7499b68e5..913fbc881 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -2646,6 +2646,12 @@ mod tests { api_key: Setting::NotSet, dimensions: Setting::Set(3), document_template: Setting::NotSet, + url: Setting::NotSet, + query: Setting::NotSet, + input_field: Setting::NotSet, + path_to_embeddings: Setting::NotSet, + embedding_object: Setting::NotSet, + input_type: Setting::NotSet, }), ); settings.set_embedder_settings(embedders); diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index f54f45e1e..e902badc0 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1140,6 +1140,12 @@ fn validate_prompt( api_key, dimensions, document_template: Setting::Set(template), + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, }) => { // validate let template = crate::prompt::Prompt::new(template) @@ -1153,6 +1159,12 @@ fn validate_prompt( api_key, dimensions, document_template: Setting::Set(template), + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, })) } new => Ok(new), @@ -1165,8 +1177,20 @@ pub fn validate_embedding_settings( ) -> Result> { let settings = validate_prompt(name, settings)?; let Setting::Set(settings) = settings else { return Ok(settings) }; - let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } = - settings; + let EmbeddingSettings { + source, + model, + revision, + api_key, + dimensions, + document_template, + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, + } = settings; if let Some(0) = dimensions.set() { return Err(crate::error::UserError::InvalidSettingsDimensions { @@ -1175,6 
+1199,14 @@ pub fn validate_embedding_settings( .into()); } + if let Some(url) = url.as_ref().set() { + url::Url::parse(url).map_err(|error| crate::error::UserError::InvalidUrl { + embedder_name: name.to_owned(), + inner_error: error, + url: url.to_owned(), + })?; + } + let Some(inferred_source) = source.set() else { return Ok(Setting::Set(EmbeddingSettings { source, @@ -1183,11 +1215,25 @@ pub fn validate_embedding_settings( api_key, dimensions, document_template, + url, + query, + input_field, + path_to_embeddings, + embedding_object, + input_type, })); }; match inferred_source { EmbedderSource::OpenAi => { check_unset(&revision, "revision", inferred_source, name)?; + + check_unset(&url, "url", inferred_source, name)?; + check_unset(&query, "query", inferred_source, name)?; + check_unset(&input_field, "inputField", inferred_source, name)?; + check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; + check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; + check_unset(&input_type, "inputType", inferred_source, name)?; + if let Setting::Set(model) = &model { let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str()) .ok_or(crate::error::UserError::InvalidOpenAiModel { @@ -1224,10 +1270,24 @@ pub fn validate_embedding_settings( check_set(&model, "model", inferred_source, name)?; check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&revision, "revision", inferred_source, name)?; + + check_unset(&url, "url", inferred_source, name)?; + check_unset(&query, "query", inferred_source, name)?; + check_unset(&input_field, "inputField", inferred_source, name)?; + check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; + check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; + check_unset(&input_type, "inputType", inferred_source, name)?; } EmbedderSource::HuggingFace => { check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&dimensions, "dimensions", inferred_source, name)?; + + check_unset(&url, "url", inferred_source, name)?; + check_unset(&query, "query", inferred_source, name)?; + check_unset(&input_field, "inputField", inferred_source, name)?; + check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; + check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; + check_unset(&input_type, "inputType", inferred_source, name)?; } EmbedderSource::UserProvided => { check_unset(&model, "model", inferred_source, name)?; @@ -1235,6 +1295,18 @@ pub fn validate_embedding_settings( check_unset(&api_key, "apiKey", inferred_source, name)?; check_unset(&document_template, "documentTemplate", inferred_source, name)?; check_set(&dimensions, "dimensions", inferred_source, name)?; + + check_unset(&url, "url", inferred_source, name)?; + check_unset(&query, "query", inferred_source, name)?; + check_unset(&input_field, "inputField", inferred_source, name)?; + check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?; + check_unset(&embedding_object, "embeddingObject", inferred_source, name)?; + check_unset(&input_type, "inputType", inferred_source, name)?; + } + EmbedderSource::Rest => { + check_unset(&model, "model", inferred_source, name)?; + check_unset(&revision, "revision", inferred_source, name)?; + check_set(&url, "url", inferred_source, name)?; } } Ok(Setting::Set(EmbeddingSettings { @@ -1244,6 +1316,12 @@ pub fn validate_embedding_settings( api_key, dimensions, document_template, + url, + query, + input_field, + 
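// [annotation, not part of the diff] `url`, `query`, `input_field`, `path_to_embeddings`,
// `embedding_object` and `input_type` are the six new REST-oriented settings; the match
// above requires `url` for the `rest` source and rejects all six for every other source.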
path_to_embeddings, + embedding_object, + input_type, })) } diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index 9bbdeaa90..1e0bcc7fb 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -2,9 +2,7 @@ use std::path::PathBuf; use hf_hub::api::sync::ApiError; -use super::ollama::OllamaError; use crate::error::FaultSource; -use crate::vector::openai::OpenAiError; #[derive(Debug, thiserror::Error)] #[error("Error while generating embeddings: {inner}")] @@ -52,37 +50,34 @@ pub enum EmbedErrorKind { TensorValue(candle_core::Error), #[error("could not run model: {0}")] ModelForward(candle_core::Error), - #[error("could not reach OpenAI: {0}")] - OpenAiNetwork(reqwest::Error), - #[error("unexpected response from OpenAI: {0}")] - OpenAiUnexpected(reqwest::Error), - #[error("could not authenticate against OpenAI: {0}")] - OpenAiAuth(OpenAiError), - #[error("sent too many requests to OpenAI: {0}")] - OpenAiTooManyRequests(OpenAiError), - #[error("received internal error from OpenAI: {0:?}")] - OpenAiInternalServerError(Option), - #[error("sent too many tokens in a request to OpenAI: {0}")] - OpenAiTooManyTokens(OpenAiError), - #[error("received unhandled HTTP status code {0} from OpenAI")] - OpenAiUnhandledStatusCode(u16), #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] ManualEmbed(String), - #[error("could not initialize asynchronous runtime: {0}")] - OpenAiRuntimeInit(std::io::Error), - #[error("initializing web client for sending embedding requests failed: {0}")] - InitWebClient(reqwest::Error), - // Dedicated Ollama error kinds, might have to merge them into one cohesive error type for all backends. - #[error("unexpected response from Ollama: {0}")] - OllamaUnexpected(reqwest::Error), - #[error("sent too many requests to Ollama: {0}")] - OllamaTooManyRequests(OllamaError), - #[error("received internal error from Ollama: {0}")] - OllamaInternalServerError(OllamaError), - #[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0}")] - OllamaModelNotFoundError(OllamaError), - #[error("received unhandled HTTP status code {0} from Ollama")] - OllamaUnhandledStatusCode(u16), + #[error("model not found. 
Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")]
+    OllamaModelNotFoundError(Option<String>),
+    #[error("error deserializing the response body as JSON: {0}")]
+    RestResponseDeserialization(std::io::Error),
+    #[error("component `{0}` not found in path `{1}` in response: `{2}`")]
+    RestResponseMissingEmbeddings(String, String, String),
+    #[error("expected a response parseable as a vector or an array of vectors: {0}")]
+    RestResponseFormat(serde_json::Error),
+    #[error("expected a response containing {0} embeddings, got only {1}")]
+    RestResponseEmbeddingCount(usize, usize),
+    #[error("could not authenticate against embedding server: {0:?}")]
+    RestUnauthorized(Option<String>),
+    #[error("sent too many requests to embedding server: {0:?}")]
+    RestTooManyRequests(Option<String>),
+    #[error("sent a bad request to embedding server: {0:?}")]
+    RestBadRequest(Option<String>),
+    #[error("received internal error from embedding server: {0:?}")]
+    RestInternalServerError(u16, Option<String>),
+    #[error("received HTTP {0} from embedding server: {1:?}")]
+    RestOtherStatusCode(u16, Option<String>),
+    #[error("could not reach embedding server: {0}")]
+    RestNetwork(ureq::Transport),
+    #[error("expected '{}' to be an object in query '{0}'", .1.join("."))]
+    RestNotAnObject(serde_json::Value, Vec<String>),
+    #[error("while embedding tokenized text, expected embeddings of dimension `{0}`, got embeddings of dimension `{1}`")]
+    OpenAiUnexpectedDimension(usize, usize),
 }

 impl EmbedError {
@@ -102,64 +97,98 @@ impl EmbedError {
         Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
     }

-    pub fn openai_network(inner: reqwest::Error) -> Self {
-        Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
-    }
-
-    pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
-    }
-
-    pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
-    }
-
-    pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
-    }
-
-    pub(crate) fn openai_internal_server_error(inner: Option<OpenAiError>) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
-    }
-
-    pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
-    }
-
-    pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
-    }
-
     pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
         Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
     }

-    pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
-        Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
-    }
-
-    pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
-        Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
-    }
-
-    pub(crate) fn ollama_unexpected(inner: reqwest::Error) -> EmbedError {
-        Self { kind: EmbedErrorKind::OllamaUnexpected(inner), fault: FaultSource::Bug }
-    }
-
-    pub(crate) fn ollama_model_not_found(inner: OllamaError) -> EmbedError {
+    pub(crate) fn ollama_model_not_found(inner: Option<String>) -> EmbedError {
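// [annotation, not part of the diff] each constructor below pairs an `EmbedErrorKind`
// with a `FaultSource` assigning blame: `User` for configuration mistakes, `Runtime` for
// transient or server-side failures, `Bug` for broken internal invariants, and
// `Undecided` when the failure cannot be attributed to either side.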
         Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
     }

-    pub(crate) fn ollama_too_many_requests(inner: OllamaError) -> EmbedError {
-        Self { kind: EmbedErrorKind::OllamaTooManyRequests(inner), fault: FaultSource::Runtime }
+    pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestResponseDeserialization(error),
+            fault: FaultSource::Runtime,
+        }
     }

-    pub(crate) fn ollama_internal_server_error(inner: OllamaError) -> EmbedError {
-        Self { kind: EmbedErrorKind::OllamaInternalServerError(inner), fault: FaultSource::Runtime }
+    pub(crate) fn rest_response_missing_embeddings<S: AsRef<str>>(
+        response: serde_json::Value,
+        component: &str,
+        response_field: &[S],
+    ) -> EmbedError {
+        let response_field: Vec<&str> = response_field.iter().map(AsRef::as_ref).collect();
+        let response_field = response_field.join(".");
+
+        Self {
+            kind: EmbedErrorKind::RestResponseMissingEmbeddings(
+                component.to_owned(),
+                response_field,
+                serde_json::to_string_pretty(&response).unwrap_or_default(),
+            ),
+            fault: FaultSource::Undecided,
+        }
     }

-    pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError {
-        Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug }
+    pub(crate) fn rest_response_format(error: serde_json::Error) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestResponseFormat(error), fault: FaultSource::Undecided }
+    }
+
+    pub(crate) fn rest_response_embedding_count(expected: usize, got: usize) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestResponseEmbeddingCount(expected, got),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_unauthorized(error_response: Option<String>) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestUnauthorized(error_response), fault: FaultSource::User }
+    }
+
+    pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestTooManyRequests(error_response),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_bad_request(error_response: Option<String>) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestBadRequest(error_response), fault: FaultSource::User }
+    }
+
+    pub(crate) fn rest_internal_server_error(
+        code: u16,
+        error_response: Option<String>,
+    ) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestInternalServerError(code, error_response),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::RestOtherStatusCode(code, error_response),
+            fault: FaultSource::Undecided,
+        }
+    }
+
+    pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime }
+    }
+
+    pub(crate) fn rest_not_an_object(
+        query: serde_json::Value,
+        input_path: Vec<String>,
+    ) -> EmbedError {
+        Self { kind: EmbedErrorKind::RestNotAnObject(query, input_path), fault: FaultSource::User }
+    }
+
+    pub(crate) fn openai_unexpected_dimension(expected: usize, got: usize) -> EmbedError {
+        Self {
+            kind: EmbedErrorKind::OpenAiUnexpectedDimension(expected, got),
+            fault: FaultSource::Runtime,
+        }
     }
 }

@@ -220,23 +249,12 @@ impl NewEmbedderError {
         Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
     }

-    pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
+    pub fn could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
         Self {
             kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
             fault: FaultSource::Runtime,
         }
     }
-
-    pub fn ollama_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
-        Self {
-            kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
-            fault: FaultSource::User,
-        }
-    }
-
-    pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
-        Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
-    }
 }

 #[derive(Debug, thiserror::Error)]
@@ -283,7 +301,4 @@ pub enum NewEmbedderErrorKind {
     CouldNotDetermineDimension(EmbedError),
     #[error("loading model failed: {0}")]
     LoadModel(candle_core::Error),
-    // openai
-    #[error("The API key passed to Authorization error was in an invalid format: {0}")]
-    InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
 }
diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs
index 04e169c71..e341a553e 100644
--- a/milli/src/vector/hf.rs
+++ b/milli/src/vector/hf.rs
@@ -131,7 +131,7 @@ impl Embedder {
         let embeddings = this
             .embed(vec!["test".into()])
-            .map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
+            .map_err(NewEmbedderError::could_not_determine_dimension)?;
         this.dimensions = embeddings.first().unwrap().dimension();

         Ok(this)
@@ -194,7 +194,10 @@ impl Embedder {
     pub fn distribution(&self) -> Option<DistributionShift> {
         if self.options.model == "BAAI/bge-base-en-v1.5" {
-            Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 })
+            Some(DistributionShift {
+                current_mean: ordered_float::OrderedFloat(0.85),
+                current_sigma: ordered_float::OrderedFloat(0.1),
+            })
         } else {
             None
         }
diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index 035ac555e..65654af4a 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -1,6 +1,9 @@
 use std::collections::HashMap;
 use std::sync::Arc;

+use ordered_float::OrderedFloat;
+use serde::{Deserialize, Serialize};
+
 use self::error::{EmbedError, NewEmbedderError};
 use crate::prompt::{Prompt, PromptData};

@@ -11,51 +14,70 @@ pub mod openai;
 pub mod settings;

 pub mod ollama;
+pub mod rest;

 pub use self::error::Error;

 pub type Embedding = Vec<f32>;

+pub const REQUEST_PARALLELISM: usize = 40;
+
+/// One or multiple embeddings stored consecutively in a flat vector.
 pub struct Embeddings<F> {
     data: Vec<F>,
     dimension: usize,
 }

 impl<F> Embeddings<F> {
+    /// Declares an empty vector of embeddings of the specified dimension.
     pub fn new(dimension: usize) -> Self {
         Self { data: Default::default(), dimension }
     }

+    /// Declares a vector of embeddings containing a single element.
+    ///
+    /// The dimension is inferred from the length of the passed embedding.
     pub fn from_single_embedding(embedding: Vec<F>) -> Self {
         Self { dimension: embedding.len(), data: embedding }
     }

+    /// Declares a vector of embeddings from its components.
+    ///
+    /// `data.len()` must be a multiple of `dimension`, otherwise an error is returned.
     pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
         let mut this = Self::new(dimension);
         this.append(data)?;
         Ok(this)
     }

+    /// Returns the number of embeddings in this vector of embeddings.
     pub fn embedding_count(&self) -> usize {
         self.data.len() / self.dimension
     }

+    /// Dimension of a single embedding.
     pub fn dimension(&self) -> usize {
         self.dimension
     }

+    /// Deconstructs self into the inner flat vector.
     pub fn into_inner(self) -> Vec<F> {
         self.data
     }

+    /// A reference to the inner flat vector.
     pub fn as_inner(&self) -> &[F] {
         &self.data
     }

+    /// Iterates over the embeddings contained in the flat vector.
     pub fn iter(&self) -> impl Iterator<Item = &[F]> + '_ {
         self.data.as_slice().chunks_exact(self.dimension)
     }

+    /// Push an embedding at the end of the embeddings.
+    ///
+    /// If `embedding.len() != self.dimension`, then the push operation fails.
     pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
         if embedding.len() != self.dimension {
             return Err(embedding);
@@ -64,6 +86,9 @@ impl<F> Embeddings<F> {
         Ok(())
     }

+    /// Append a flat vector of embeddings at the end of the embeddings.
+    ///
+    /// If `embeddings.len() % self.dimension != 0`, then the append operation fails.
     pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
         if embeddings.len() % self.dimension != 0 {
             return Err(embeddings);
@@ -73,37 +98,60 @@ impl<F> Embeddings<F> {
     }
 }

+/// An embedder can be used to transform text into embeddings.
 #[derive(Debug)]
 pub enum Embedder {
+    /// An embedder based on running local models, fetched from the Hugging Face Hub.
     HuggingFace(hf::Embedder),
+    /// An embedder based on making embedding queries against the OpenAI API.
     OpenAi(openai::Embedder),
+    /// An embedder based on the user providing the embeddings in the documents and queries.
     UserProvided(manual::Embedder),
+    /// An embedder based on making embedding queries against an Ollama embedding server.
     Ollama(ollama::Embedder),
+    /// An embedder based on making embedding queries against a generic JSON/REST embedding server.
+    Rest(rest::Embedder),
 }

+/// Configuration for an embedder.
 #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
 pub struct EmbeddingConfig {
+    /// Options of the embedder, specific to each kind of embedder.
     pub embedder_options: EmbedderOptions,
+    /// Document template.
     pub prompt: PromptData,
     // TODO: add metrics and anything needed
 }

+/// Map of embedder configurations.
+///
+/// Each configuration is mapped to a name.
 #[derive(Clone, Default)]
 pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);

 impl EmbeddingConfigs {
+    /// Create the map from its internal components.
     pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
         Self(data)
     }

+    /// Get an embedder configuration and template from its name.
     pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
         self.0.get(name).cloned()
     }

+    /// Get the default embedder configuration, if any.
     pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
         self.get_default_embedder_name().and_then(|default| self.get(&default))
     }

+    /// Get the name of the default embedder configuration.
+    ///
+    /// The default embedder is determined as follows:
+    ///
+    /// - If there is only one embedder, it is always the default.
+    /// - If there are multiple embedders and one of them is called `default`, then that one is the default embedder.
+    /// - In all other cases, there is no default embedder.
     pub fn get_default_embedder_name(&self) -> Option<String> {
         let mut it = self.0.keys();
         let first_name = it.next();
@@ -126,12 +174,14 @@ impl IntoIterator for EmbeddingConfigs {
     }
 }

+/// Options of an embedder, specific to each kind of embedder.
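// [annotation, not part of the diff] each variant of the options enum below mirrors one
// `Embedder` variant above; `Rest(rest::EmbedderOptions)` is the new generic backend.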
 #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
 pub enum EmbedderOptions {
     HuggingFace(hf::EmbedderOptions),
     OpenAi(openai::EmbedderOptions),
     Ollama(ollama::EmbedderOptions),
     UserProvided(manual::EmbedderOptions),
+    Rest(rest::EmbedderOptions),
 }

 impl Default for EmbedderOptions {
@@ -141,10 +191,12 @@
     }
 }

 impl EmbedderOptions {
+    /// Default options for the Hugging Face embedder.
     pub fn huggingface() -> Self {
         Self::HuggingFace(hf::EmbedderOptions::new())
     }

+    /// Default options for the OpenAI embedder.
     pub fn openai(api_key: Option<String>) -> Self {
         Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
     }
@@ -155,6 +207,7 @@
 }

 impl Embedder {
+    /// Spawns a new embedder built from its options.
     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
         Ok(match options {
             EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
@@ -163,83 +216,133 @@
             EmbedderOptions::UserProvided(options) => {
                 Self::UserProvided(manual::Embedder::new(options))
             }
+            EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?),
         })
     }

-    pub async fn embed(
+    /// Embed one or multiple texts.
+    ///
+    /// Each text can be embedded as one or multiple embeddings.
+    pub fn embed(
         &self,
         texts: Vec<String>,
     ) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed(texts),
-            Embedder::OpenAi(embedder) => {
-                let client = embedder.new_client()?;
-                embedder.embed(texts, &client).await
-            }
-            Embedder::Ollama(embedder) => {
-                let client = embedder.new_client()?;
-                embedder.embed(texts, &client).await
-            }
+            Embedder::OpenAi(embedder) => embedder.embed(texts),
+            Embedder::Ollama(embedder) => embedder.embed(texts),
             Embedder::UserProvided(embedder) => embedder.embed(texts),
+            Embedder::Rest(embedder) => embedder.embed(texts),
         }
     }

-    /// # Panics
+    /// Embed multiple chunks of texts.
     ///
-    /// - if called from an asynchronous context
+    /// Each chunk is composed of one or multiple texts.
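// [annotation, not part of the diff] the `threads` parameter below is the dedicated
// rayon pool built in `send_original_documents_data` (thread names
// `embedding-request-{index}`, sized to `REQUEST_PARALLELISM` = 40), keeping blocking
// HTTP calls off the regular indexing threads.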
     pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
+        threads: &rayon::ThreadPool,
     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
-            Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
-            Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks),
+            Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
+            Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
             Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
+            Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads),
         }
     }

+    /// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`].
     pub fn chunk_count_hint(&self) -> usize {
         match self {
             Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
             Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
             Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
             Embedder::UserProvided(_) => 1,
+            Embedder::Rest(embedder) => embedder.chunk_count_hint(),
         }
     }

+    /// Indicates the preferred number of texts in a single chunk passed to [`Self::embed`].
     pub fn prompt_count_in_chunk_hint(&self) -> usize {
         match self {
             Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
             Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
             Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
             Embedder::UserProvided(_) => 1,
+            Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
         }
     }

+    /// Indicates the dimensions of a single embedding produced by the embedder.
     pub fn dimensions(&self) -> usize {
         match self {
             Embedder::HuggingFace(embedder) => embedder.dimensions(),
             Embedder::OpenAi(embedder) => embedder.dimensions(),
             Embedder::Ollama(embedder) => embedder.dimensions(),
             Embedder::UserProvided(embedder) => embedder.dimensions(),
+            Embedder::Rest(embedder) => embedder.dimensions(),
         }
     }

+    /// An optional distribution used to apply an affine transformation to the similarity score of a document.
     pub fn distribution(&self) -> Option<DistributionShift> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.distribution(),
             Embedder::OpenAi(embedder) => embedder.distribution(),
             Embedder::Ollama(embedder) => embedder.distribution(),
             Embedder::UserProvided(_embedder) => None,
+            Embedder::Rest(embedder) => embedder.distribution(),
         }
     }
 }

-#[derive(Debug, Clone, Copy)]
+/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
+///
+/// The intended use is to make the similarity score more comparable to the regular ranking score.
+/// This makes it possible to correct effects where results are too "packed" around a certain value.
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
+#[serde(from = "DistributionShiftSerializable")]
+#[serde(into = "DistributionShiftSerializable")]
 pub struct DistributionShift {
-    pub current_mean: f32,
-    pub current_sigma: f32,
+    /// Value where the results are "packed".
+    ///
+    /// Similarity scores are translated so that they are packed around 0.5 instead.
+    pub current_mean: OrderedFloat<f32>,
+
+    /// Standard deviation of a similarity score.
+    ///
+    /// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
+    pub current_sigma: OrderedFloat<f32>,
+}
+
+#[derive(Serialize, Deserialize)]
+struct DistributionShiftSerializable {
+    current_mean: f32,
+    current_sigma: f32,
+}
+
+impl From<DistributionShift> for DistributionShiftSerializable {
+    fn from(
+        DistributionShift {
+            current_mean: OrderedFloat(current_mean),
+            current_sigma: OrderedFloat(current_sigma),
+        }: DistributionShift,
+    ) -> Self {
+        Self { current_mean, current_sigma }
+    }
+}
+
+impl From<DistributionShiftSerializable> for DistributionShift {
+    fn from(
+        DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable,
+    ) -> Self {
+        Self {
+            current_mean: OrderedFloat(current_mean),
+            current_sigma: OrderedFloat(current_sigma),
+        }
+    }
 }

 impl DistributionShift {
@@ -248,11 +351,13 @@
         if sigma <= 0.0 {
             None
         } else {
-            Some(Self { current_mean: mean, current_sigma: sigma })
+            Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
         }
     }

     pub fn shift(&self, score: f32) -> f32 {
+        let current_mean = self.current_mean.0;
+        let current_sigma = self.current_sigma.0;
         //
         // We're somewhat abusively mapping the distribution of distances to a gaussian.
         // The parameters we're given is the mean and sigma of the native result distribution.
@@ -262,9 +367,9 @@
         let target_sigma = 0.4;

         // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
-        let factor = target_sigma / self.current_sigma;
+        let factor = target_sigma / current_sigma;
         // a*mu1 + b = mu2 => b = mu2 - a*mu1
-        let offset = target_mean - (factor * self.current_mean);
+        let offset = target_mean - (factor * current_mean);

         let mut score = factor * score + offset;

@@ -280,6 +385,7 @@
     }
 }

+/// Whether CUDA is supported in this version of Meilisearch.
 pub const fn is_cuda_enabled() -> bool {
     cfg!(feature = "cuda")
 }
diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs
index 76988f70b..9c44e8052 100644
--- a/milli/src/vector/ollama.rs
+++ b/milli/src/vector/ollama.rs
@@ -1,293 +1,94 @@
-// Copied from "openai.rs" with the sections I actually understand changed for Ollama.
-// The common components of the Ollama and OpenAI interfaces might need to be extracted.
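The rewritten `ollama.rs` that follows no longer speaks HTTP itself; it just configures the generic REST embedder. As a rough sketch of what that configuration means on the wire (the model name is the default from this diff, the embedding values are made up):

```rust
use serde_json::json;

fn main() {
    // `query` as assembled in the new `Embedder::new`: only the model is fixed.
    let query = json!({ "model": "nomic-embed-text" });

    // `input_field: ["prompt"]` injects each text to embed at `body["prompt"]`.
    let mut body = query.clone();
    body.as_object_mut().unwrap().insert("prompt".into(), json!("hello world"));
    assert_eq!(body, json!({ "model": "nomic-embed-text", "prompt": "hello world" }));

    // `path_to_embeddings: []` plus `embedding_object: ["embedding"]` means the
    // embedding is read straight from `response["embedding"]`.
    let response = json!({ "embedding": [0.1, 0.2, 0.3] }); // hypothetical reply
    assert!(response["embedding"].is_array());
}
```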
+use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; -use std::fmt::Display; - -use reqwest::StatusCode; - -use super::error::{EmbedError, NewEmbedderError}; -use super::openai::Retry; -use super::{DistributionShift, Embedding, Embeddings}; +use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; +use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; +use super::{DistributionShift, Embeddings}; #[derive(Debug)] pub struct Embedder { - headers: reqwest::header::HeaderMap, - options: EmbedderOptions, + rest_embedder: RestEmbedder, } #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { - pub embedding_model: EmbeddingModel, -} - -#[derive( - Debug, Clone, Hash, PartialEq, Eq, serde::Serialize, serde::Deserialize, deserr::Deserr, -)] -#[deserr(deny_unknown_fields)] -pub struct EmbeddingModel { - name: String, - dimensions: usize, -} - -#[derive(Debug, serde::Serialize)] -struct OllamaRequest<'a> { - model: &'a str, - prompt: &'a str, -} - -#[derive(Debug, serde::Deserialize)] -struct OllamaResponse { - embedding: Embedding, -} - -#[derive(Debug, serde::Deserialize)] -pub struct OllamaError { - error: String, -} - -impl EmbeddingModel { - pub fn max_token(&self) -> usize { - // this might not be the same for all models - 8192 - } - - pub fn default_dimensions(&self) -> usize { - // Dimensions for nomic-embed-text - 768 - } - - pub fn name(&self) -> String { - self.name.clone() - } - - pub fn from_name(name: &str) -> Self { - Self { name: name.to_string(), dimensions: 0 } - } - - pub fn supports_overriding_dimensions(&self) -> bool { - false - } -} - -impl Default for EmbeddingModel { - fn default() -> Self { - Self { name: "nomic-embed-text".to_string(), dimensions: 0 } - } + pub embedding_model: String, } impl EmbedderOptions { pub fn with_default_model() -> Self { - Self { embedding_model: Default::default() } + Self { embedding_model: "nomic-embed-text".into() } } - pub fn with_embedding_model(embedding_model: EmbeddingModel) -> Self { + pub fn with_embedding_model(embedding_model: String) -> Self { Self { embedding_model } } } impl Embedder { - pub fn new_client(&self) -> Result { - reqwest::ClientBuilder::new() - .default_headers(self.headers.clone()) - .build() - .map_err(EmbedError::openai_initialize_web_client) - } - pub fn new(options: EmbedderOptions) -> Result { - let mut headers = reqwest::header::HeaderMap::new(); - headers.insert( - reqwest::header::CONTENT_TYPE, - reqwest::header::HeaderValue::from_static("application/json"), - ); - - let mut embedder = Self { options, headers }; - - let rt = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .map_err(EmbedError::openai_runtime_init) - .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; - - // Get dimensions from Ollama - let request = - OllamaRequest { model: &embedder.options.embedding_model.name(), prompt: "test" }; - // TODO: Refactor into shared error type - let client = embedder - .new_client() - .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; - - rt.block_on(async move { - let response = client - .post(get_ollama_path()) - .json(&request) - .send() - .await - .map_err(EmbedError::ollama_unexpected) - .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; - - // Process error in case model not found - let response = Self::check_response(response).await.map_err(|_err| { - let e = 
EmbedError::ollama_model_not_found(OllamaError { - error: format!("model: {}", embedder.options.embedding_model.name()), - }); - NewEmbedderError::ollama_could_not_determine_dimension(e) - })?; - - let response: OllamaResponse = response - .json() - .await - .map_err(EmbedError::ollama_unexpected) - .map_err(NewEmbedderError::ollama_could_not_determine_dimension)?; - - let embedding = Embeddings::from_single_embedding(response.embedding); - - embedder.options.embedding_model.dimensions = embedding.dimension(); - - tracing::info!( - "ollama model {} with dimensionality {} added", - embedder.options.embedding_model.name(), - embedding.dimension() - ); - - Ok(embedder) - }) - } - - async fn check_response(response: reqwest::Response) -> Result { - if !response.status().is_success() { - // Not the same number of possible error cases covered as with OpenAI. - match response.status() { - StatusCode::TOO_MANY_REQUESTS => { - let error_response: OllamaError = response - .json() - .await - .map_err(EmbedError::ollama_unexpected) - .map_err(Retry::retry_later)?; - - return Err(Retry::rate_limited(EmbedError::ollama_too_many_requests( - OllamaError { error: error_response.error }, - ))); - } - StatusCode::SERVICE_UNAVAILABLE => { - let error_response: OllamaError = response - .json() - .await - .map_err(EmbedError::ollama_unexpected) - .map_err(Retry::retry_later)?; - return Err(Retry::retry_later(EmbedError::ollama_internal_server_error( - OllamaError { error: error_response.error }, - ))); - } - StatusCode::NOT_FOUND => { - let error_response: OllamaError = response - .json() - .await - .map_err(EmbedError::ollama_unexpected) - .map_err(Retry::give_up)?; - - return Err(Retry::give_up(EmbedError::ollama_model_not_found(OllamaError { - error: error_response.error, - }))); - } - code => { - return Err(Retry::give_up(EmbedError::ollama_unhandled_status_code( - code.as_u16(), - ))); - } + let model = options.embedding_model.as_str(); + let rest_embedder = match RestEmbedder::new(RestEmbedderOptions { + api_key: None, + distribution: None, + dimensions: None, + url: get_ollama_path(), + query: serde_json::json!({ + "model": model, + }), + input_field: vec!["prompt".to_owned()], + path_to_embeddings: Default::default(), + embedding_object: vec!["embedding".to_owned()], + input_type: super::rest::InputType::Text, + }) { + Ok(embedder) => embedder, + Err(NewEmbedderError { + kind: + NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError { + kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error), + fault: _, + }), + fault: _, + }) => { + return Err(NewEmbedderError::could_not_determine_dimension( + EmbedError::ollama_model_not_found(error), + )) } - } - Ok(response) + Err(error) => return Err(error), + }; + + Ok(Self { rest_embedder }) } - pub async fn embed( - &self, - texts: Vec, - client: &reqwest::Client, - ) -> Result>, EmbedError> { - // Ollama only embedds one document at a time. 
- let mut results = Vec::with_capacity(texts.len()); - - // The retry loop is inside the texts loop, might have to switch that around - for text in texts { - // Retries copied from openai.rs - for attempt in 0..7 { - let retry_duration = match self.try_embed(&text, client).await { - Ok(result) => { - results.push(result); - break; - } - Err(retry) => { - tracing::warn!("Failed: {}", retry.error); - retry.into_duration(attempt) - } - }?; - tracing::warn!( - "Attempt #{}, retrying after {}ms.", - attempt, - retry_duration.as_millis() - ); - tokio::time::sleep(retry_duration).await; + pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { + match self.rest_embedder.embed(texts) { + Ok(embeddings) => Ok(embeddings), + Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => { + Err(EmbedError::ollama_model_not_found(error)) } + Err(error) => Err(error), } - - Ok(results) - } - - async fn try_embed( - &self, - text: &str, - client: &reqwest::Client, - ) -> Result, Retry> { - let request = OllamaRequest { model: &self.options.embedding_model.name(), prompt: text }; - let response = client - .post(get_ollama_path()) - .json(&request) - .send() - .await - .map_err(EmbedError::openai_network) - .map_err(Retry::retry_later)?; - - let response = Self::check_response(response).await?; - - let response: OllamaResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - tracing::trace!("response: {:?}", response.embedding); - - let embedding = Embeddings::from_single_embedding(response.embedding); - Ok(embedding) } pub fn embed_chunks( &self, text_chunks: Vec>, + threads: &rayon::ThreadPool, ) -> Result>>, EmbedError> { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .map_err(EmbedError::openai_runtime_init)?; - let client = self.new_client()?; - rt.block_on(futures::future::try_join_all( - text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)), - )) + threads.install(move || { + text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + }) } - // Defaults copied from openai.rs pub fn chunk_count_hint(&self) -> usize { - 10 + self.rest_embedder.chunk_count_hint() } pub fn prompt_count_in_chunk_hint(&self) -> usize { - 10 + self.rest_embedder.prompt_count_in_chunk_hint() } pub fn dimensions(&self) -> usize { - self.options.embedding_model.dimensions + self.rest_embedder.dimensions() } pub fn distribution(&self) -> Option { @@ -295,12 +96,6 @@ impl Embedder { } } -impl Display for OllamaError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.error) - } -} - fn get_ollama_path() -> String { // Important: Hostname not enough, has to be entire path to embeddings endpoint std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string()) diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index dcf3f4c89..24e94a9f7 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -1,17 +1,10 @@ -use std::fmt::Display; - -use reqwest::StatusCode; -use serde::{Deserialize, Serialize}; +use ordered_float::OrderedFloat; +use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; use super::error::{EmbedError, NewEmbedderError}; -use super::{DistributionShift, Embedding, Embeddings}; - -#[derive(Debug)] -pub struct Embedder { - headers: reqwest::header::HeaderMap, - tokenizer: tiktoken_rs::CoreBPE, - options: EmbedderOptions, -} +use 
super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; +use super::{DistributionShift, Embeddings}; +use crate::vector::error::EmbedErrorKind; #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { @@ -20,6 +13,32 @@ pub struct EmbedderOptions { pub dimensions: Option, } +impl EmbedderOptions { + pub fn dimensions(&self) -> usize { + if self.embedding_model.supports_overriding_dimensions() { + self.dimensions.unwrap_or(self.embedding_model.default_dimensions()) + } else { + self.embedding_model.default_dimensions() + } + } + + pub fn query(&self) -> serde_json::Value { + let model = self.embedding_model.name(); + + let mut query = serde_json::json!({ + "model": model, + }); + + if self.embedding_model.supports_overriding_dimensions() { + if let Some(dimensions) = self.dimensions { + query["dimensions"] = dimensions.into(); + } + } + + query + } +} + #[derive( Debug, Clone, @@ -92,15 +111,18 @@ impl EmbeddingModel { fn distribution(&self) -> Option { match self { - EmbeddingModel::TextEmbeddingAda002 => { - Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) - } - EmbeddingModel::TextEmbedding3Large => { - Some(DistributionShift { current_mean: 0.70, current_sigma: 0.1 }) - } - EmbeddingModel::TextEmbedding3Small => { - Some(DistributionShift { current_mean: 0.75, current_sigma: 0.1 }) - } + EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift { + current_mean: OrderedFloat(0.90), + current_sigma: OrderedFloat(0.08), + }), + EmbeddingModel::TextEmbedding3Large => Some(DistributionShift { + current_mean: OrderedFloat(0.70), + current_sigma: OrderedFloat(0.1), + }), + EmbeddingModel::TextEmbedding3Small => Some(DistributionShift { + current_mean: OrderedFloat(0.75), + current_sigma: OrderedFloat(0.1), + }), } } @@ -125,178 +147,57 @@ impl EmbedderOptions { } } -impl Embedder { - pub fn new_client(&self) -> Result { - reqwest::ClientBuilder::new() - .default_headers(self.headers.clone()) - .build() - .map_err(EmbedError::openai_initialize_web_client) - } +fn infer_api_key() -> String { + std::env::var("MEILI_OPENAI_API_KEY") + .or_else(|_| std::env::var("OPENAI_API_KEY")) + .unwrap_or_default() +} +#[derive(Debug)] +pub struct Embedder { + tokenizer: tiktoken_rs::CoreBPE, + rest_embedder: RestEmbedder, + options: EmbedderOptions, +} + +impl Embedder { pub fn new(options: EmbedderOptions) -> Result { - let mut headers = reqwest::header::HeaderMap::new(); let mut inferred_api_key = Default::default(); let api_key = options.api_key.as_ref().unwrap_or_else(|| { inferred_api_key = infer_api_key(); &inferred_api_key }); - headers.insert( - reqwest::header::AUTHORIZATION, - reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key)) - .map_err(NewEmbedderError::openai_invalid_api_key_format)?, - ); - headers.insert( - reqwest::header::CONTENT_TYPE, - reqwest::header::HeaderValue::from_static("application/json"), - ); + + let rest_embedder = RestEmbedder::new(RestEmbedderOptions { + api_key: Some(api_key.clone()), + distribution: options.embedding_model.distribution(), + dimensions: Some(options.dimensions()), + url: OPENAI_EMBEDDINGS_URL.to_owned(), + query: options.query(), + input_field: vec!["input".to_owned()], + input_type: crate::vector::rest::InputType::TextArray, + path_to_embeddings: vec!["data".to_owned()], + embedding_object: vec!["embedding".to_owned()], + })?; // looking at the code it is very unclear that this can actually fail. 
let tokenizer = tiktoken_rs::cl100k_base().unwrap(); - Ok(Self { options, headers, tokenizer }) + Ok(Self { options, rest_embedder, tokenizer }) } - pub async fn embed( - &self, - texts: Vec, - client: &reqwest::Client, - ) -> Result>, EmbedError> { - let mut tokenized = false; - - for attempt in 0..7 { - let result = if tokenized { - self.try_embed_tokenized(&texts, client).await - } else { - self.try_embed(&texts, client).await - }; - - let retry_duration = match result { - Ok(embeddings) => return Ok(embeddings), - Err(retry) => { - tracing::warn!("Failed: {}", retry.error); - tokenized |= retry.must_tokenize(); - retry.into_duration(attempt) - } - }?; - - let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute - tracing::warn!( - "Attempt #{}, retrying after {}ms.", - attempt, - retry_duration.as_millis() - ); - tokio::time::sleep(retry_duration).await; - } - - let result = if tokenized { - self.try_embed_tokenized(&texts, client).await - } else { - self.try_embed(&texts, client).await - }; - - result.map_err(Retry::into_error) - } - - async fn check_response(response: reqwest::Response) -> Result { - if !response.status().is_success() { - match response.status() { - StatusCode::UNAUTHORIZED => { - let error_response: OpenAiErrorResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - return Err(Retry::give_up(EmbedError::openai_auth_error( - error_response.error, - ))); - } - StatusCode::TOO_MANY_REQUESTS => { - let error_response: OpenAiErrorResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - return Err(Retry::rate_limited(EmbedError::openai_too_many_requests( - error_response.error, - ))); - } - StatusCode::INTERNAL_SERVER_ERROR - | StatusCode::BAD_GATEWAY - | StatusCode::SERVICE_UNAVAILABLE => { - let error_response: Result = response.json().await; - return Err(Retry::retry_later(EmbedError::openai_internal_server_error( - error_response.ok().map(|error_response| error_response.error), - ))); - } - StatusCode::BAD_REQUEST => { - // Most probably, one text contained too many tokens - let error_response: OpenAiErrorResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your prompt."); - - return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens( - error_response.error, - ))); - } - code => { - return Err(Retry::retry_later(EmbedError::openai_unhandled_status_code( - code.as_u16(), - ))); - } + pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { + match self.rest_embedder.embed_ref(&texts) { + Ok(embeddings) => Ok(embeddings), + Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error), fault: _ }) => { + tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. 
For best performance, limit the size of your document template."); + self.try_embed_tokenized(&texts) } + Err(error) => Err(error), } - Ok(response) } - async fn try_embed + serde::Serialize>( - &self, - texts: &[S], - client: &reqwest::Client, - ) -> Result>, Retry> { - for text in texts { - tracing::trace!("Received prompt: {}", text.as_ref()) - } - let request = OpenAiRequest { - model: self.options.embedding_model.name(), - input: texts, - dimensions: self.overriden_dimensions(), - }; - let response = client - .post(OPENAI_EMBEDDINGS_URL) - .json(&request) - .send() - .await - .map_err(EmbedError::openai_network) - .map_err(Retry::retry_later)?; - - let response = Self::check_response(response).await?; - - let response: OpenAiResponse = response - .json() - .await - .map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - - tracing::trace!("response: {:?}", response.data); - - Ok(response - .data - .into_iter() - .map(|data| Embeddings::from_single_embedding(data.embedding)) - .collect()) - } - - async fn try_embed_tokenized( - &self, - text: &[String], - client: &reqwest::Client, - ) -> Result>, Retry> { + fn try_embed_tokenized(&self, text: &[String]) -> Result>, EmbedError> { pub const OVERLAP_SIZE: usize = 200; let mut all_embeddings = Vec::with_capacity(text.len()); for text in text { @@ -304,7 +205,7 @@ impl Embedder { let encoded = self.tokenizer.encode_ordinary(text.as_str()); let len = encoded.len(); if len < max_token_count { - all_embeddings.append(&mut self.try_embed(&[text], client).await?); + all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?); continue; } @@ -312,215 +213,49 @@ impl Embedder { let mut embeddings_for_prompt = Embeddings::new(self.dimensions()); while tokens.len() > max_token_count { let window = &tokens[..max_token_count]; - embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap(); + let embedding = self.rest_embedder.embed_tokens(window)?; + embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| { + EmbedError::openai_unexpected_dimension(self.dimensions(), got.len()) + })?; tokens = &tokens[max_token_count - OVERLAP_SIZE..]; } // end of text - embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap(); + let embedding = self.rest_embedder.embed_tokens(tokens)?; + + embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| { + EmbedError::openai_unexpected_dimension(self.dimensions(), got.len()) + })?; all_embeddings.push(embeddings_for_prompt); } Ok(all_embeddings) } - async fn embed_tokens( - &self, - tokens: &[usize], - client: &reqwest::Client, - ) -> Result { - for attempt in 0..9 { - let duration = match self.try_embed_tokens(tokens, client).await { - Ok(embedding) => return Ok(embedding), - Err(retry) => retry.into_duration(attempt), - } - .map_err(Retry::retry_later)?; - - tokio::time::sleep(duration).await; - } - - self.try_embed_tokens(tokens, client) - .await - .map_err(|retry| Retry::give_up(retry.into_error())) - } - - async fn try_embed_tokens( - &self, - tokens: &[usize], - client: &reqwest::Client, - ) -> Result { - let request = OpenAiTokensRequest { - model: self.options.embedding_model.name(), - input: tokens, - dimensions: self.overriden_dimensions(), - }; - let response = client - .post(OPENAI_EMBEDDINGS_URL) - .json(&request) - .send() - .await - .map_err(EmbedError::openai_network) - .map_err(Retry::retry_later)?; - - let response = Self::check_response(response).await?; - - let mut response: OpenAiResponse = response - .json() - .await - 
.map_err(EmbedError::openai_unexpected) - .map_err(Retry::retry_later)?; - Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default()) - } - pub fn embed_chunks( &self, text_chunks: Vec>, + threads: &rayon::ThreadPool, ) -> Result>>, EmbedError> { - let rt = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .map_err(EmbedError::openai_runtime_init)?; - let client = self.new_client()?; - rt.block_on(futures::future::try_join_all( - text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)), - )) + threads.install(move || { + text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + }) } pub fn chunk_count_hint(&self) -> usize { - 10 + self.rest_embedder.chunk_count_hint() } pub fn prompt_count_in_chunk_hint(&self) -> usize { - 10 + self.rest_embedder.prompt_count_in_chunk_hint() } pub fn dimensions(&self) -> usize { - if self.options.embedding_model.supports_overriding_dimensions() { - self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions()) - } else { - self.options.embedding_model.default_dimensions() - } + self.options.dimensions() } pub fn distribution(&self) -> Option { self.options.embedding_model.distribution() } - - fn overriden_dimensions(&self) -> Option { - if self.options.embedding_model.supports_overriding_dimensions() { - self.options.dimensions - } else { - None - } - } -} - -// retrying in case of failure - -pub struct Retry { - pub error: EmbedError, - strategy: RetryStrategy, -} - -pub enum RetryStrategy { - GiveUp, - Retry, - RetryTokenized, - RetryAfterRateLimit, -} - -impl Retry { - pub fn give_up(error: EmbedError) -> Self { - Self { error, strategy: RetryStrategy::GiveUp } - } - - pub fn retry_later(error: EmbedError) -> Self { - Self { error, strategy: RetryStrategy::Retry } - } - - pub fn retry_tokenized(error: EmbedError) -> Self { - Self { error, strategy: RetryStrategy::RetryTokenized } - } - - pub fn rate_limited(error: EmbedError) -> Self { - Self { error, strategy: RetryStrategy::RetryAfterRateLimit } - } - - pub fn into_duration(self, attempt: u32) -> Result { - match self.strategy { - RetryStrategy::GiveUp => Err(self.error), - RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))), - RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)), - RetryStrategy::RetryAfterRateLimit => { - Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt))) - } - } - } - - pub fn must_tokenize(&self) -> bool { - matches!(self.strategy, RetryStrategy::RetryTokenized) - } - - pub fn into_error(self) -> EmbedError { - self.error - } -} - -// openai api structs - -#[derive(Debug, Serialize)] -struct OpenAiRequest<'a, S: AsRef + serde::Serialize> { - model: &'a str, - input: &'a [S], - #[serde(skip_serializing_if = "Option::is_none")] - dimensions: Option, -} - -#[derive(Debug, Serialize)] -struct OpenAiTokensRequest<'a> { - model: &'a str, - input: &'a [usize], - #[serde(skip_serializing_if = "Option::is_none")] - dimensions: Option, -} - -#[derive(Debug, Deserialize)] -struct OpenAiResponse { - data: Vec, -} - -#[derive(Debug, Deserialize)] -struct OpenAiErrorResponse { - error: OpenAiError, -} - -#[derive(Debug, Deserialize)] -pub struct OpenAiError { - message: String, - // type: String, - code: Option, -} - -impl Display for OpenAiError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match &self.code { - Some(code) => write!(f, "{} ({})", self.message, code), - None => write!(f, "{}", 
 
     pub fn chunk_count_hint(&self) -> usize {
-        10
+        self.rest_embedder.chunk_count_hint()
     }
 
     pub fn prompt_count_in_chunk_hint(&self) -> usize {
-        10
+        self.rest_embedder.prompt_count_in_chunk_hint()
     }
 
     pub fn dimensions(&self) -> usize {
-        if self.options.embedding_model.supports_overriding_dimensions() {
-            self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
-        } else {
-            self.options.embedding_model.default_dimensions()
-        }
+        self.options.dimensions()
     }
 
     pub fn distribution(&self) -> Option<DistributionShift> {
         self.options.embedding_model.distribution()
     }
-
-    fn overriden_dimensions(&self) -> Option<usize> {
-        if self.options.embedding_model.supports_overriding_dimensions() {
-            self.options.dimensions
-        } else {
-            None
-        }
-    }
-}
-
-// retrying in case of failure
-
-pub struct Retry {
-    pub error: EmbedError,
-    strategy: RetryStrategy,
-}
-
-pub enum RetryStrategy {
-    GiveUp,
-    Retry,
-    RetryTokenized,
-    RetryAfterRateLimit,
-}
-
-impl Retry {
-    pub fn give_up(error: EmbedError) -> Self {
-        Self { error, strategy: RetryStrategy::GiveUp }
-    }
-
-    pub fn retry_later(error: EmbedError) -> Self {
-        Self { error, strategy: RetryStrategy::Retry }
-    }
-
-    pub fn retry_tokenized(error: EmbedError) -> Self {
-        Self { error, strategy: RetryStrategy::RetryTokenized }
-    }
-
-    pub fn rate_limited(error: EmbedError) -> Self {
-        Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
-    }
-
-    pub fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
-        match self.strategy {
-            RetryStrategy::GiveUp => Err(self.error),
-            RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
-            RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)),
-            RetryStrategy::RetryAfterRateLimit => {
-                Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt)))
-            }
-        }
-    }
-
-    pub fn must_tokenize(&self) -> bool {
-        matches!(self.strategy, RetryStrategy::RetryTokenized)
-    }
-
-    pub fn into_error(self) -> EmbedError {
-        self.error
-    }
-}
-
-// openai api structs
-
-#[derive(Debug, Serialize)]
-struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
-    model: &'a str,
-    input: &'a [S],
-    #[serde(skip_serializing_if = "Option::is_none")]
-    dimensions: Option<usize>,
-}
-
-#[derive(Debug, Serialize)]
-struct OpenAiTokensRequest<'a> {
-    model: &'a str,
-    input: &'a [usize],
-    #[serde(skip_serializing_if = "Option::is_none")]
-    dimensions: Option<usize>,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAiResponse {
-    data: Vec<OpenAiEmbedding>,
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAiErrorResponse {
-    error: OpenAiError,
-}
-
-#[derive(Debug, Deserialize)]
-pub struct OpenAiError {
-    message: String,
-    // type: String,
-    code: Option<String>,
-}
-
-impl Display for OpenAiError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        match &self.code {
-            Some(code) => write!(f, "{} ({})", self.message, code),
-            None => write!(f, "{}", self.message),
-        }
-    }
-}
-
-#[derive(Debug, Deserialize)]
-struct OpenAiEmbedding {
-    embedding: Embedding,
-    // object: String,
-    // index: usize,
-}
-
-fn infer_api_key() -> String {
-    std::env::var("MEILI_OPENAI_API_KEY")
-        .or_else(|_| std::env::var("OPENAI_API_KEY"))
-        .unwrap_or_default()
 }
diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs
new file mode 100644
index 000000000..b0ea07f82
--- /dev/null
+++ b/milli/src/vector/rest.rs
@@ -0,0 +1,373 @@
+use deserr::Deserr;
+use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
+use serde::{Deserialize, Serialize};
+
+use super::{
+    DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
+};
+
+// retrying in case of failure
+
+pub struct Retry {
+    pub error: EmbedError,
+    strategy: RetryStrategy,
+}
+
+pub enum RetryStrategy {
+    GiveUp,
+    Retry,
+    RetryTokenized,
+    RetryAfterRateLimit,
+}
+
+impl Retry {
+    pub fn give_up(error: EmbedError) -> Self {
+        Self { error, strategy: RetryStrategy::GiveUp }
+    }
+
+    pub fn retry_later(error: EmbedError) -> Self {
+        Self { error, strategy: RetryStrategy::Retry }
+    }
+
+    pub fn retry_tokenized(error: EmbedError) -> Self {
+        Self { error, strategy: RetryStrategy::RetryTokenized }
+    }
+
+    pub fn rate_limited(error: EmbedError) -> Self {
+        Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
+    }
+
+    pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
+        match self.strategy {
+            RetryStrategy::GiveUp => Err(self.error),
+            RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
+            RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
+            RetryStrategy::RetryAfterRateLimit => {
+                Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
+            }
+        }
+    }
+
+    pub fn must_tokenize(&self) -> bool {
+        matches!(self.strategy, RetryStrategy::RetryTokenized)
+    }
+
+    pub fn into_error(self) -> EmbedError {
+        self.error
+    }
+}
+
+#[derive(Debug)]
+pub struct Embedder {
+    client: ureq::Agent,
+    options: EmbedderOptions,
+    bearer: Option<String>,
+    dimensions: usize,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
+pub struct EmbedderOptions {
+    pub api_key: Option<String>,
+    pub distribution: Option<DistributionShift>,
+    pub dimensions: Option<usize>,
+    pub url: String,
+    pub query: serde_json::Value,
+    pub input_field: Vec<String>,
+    // path to the array of embeddings
+    pub path_to_embeddings: Vec<String>,
+    // shape of a single embedding
+    pub embedding_object: Vec<String>,
+    pub input_type: InputType,
+}
+
+impl Default for EmbedderOptions {
+    fn default() -> Self {
+        Self {
+            url: Default::default(),
+            query: Default::default(),
+            input_field: vec!["input".into()],
+            path_to_embeddings: vec!["data".into()],
+            embedding_object: vec!["embedding".into()],
+            input_type: InputType::Text,
+            api_key: None,
+            distribution: None,
+            dimensions: None,
+        }
+    }
+}
+
+impl std::hash::Hash for EmbedderOptions {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.api_key.hash(state);
+        self.distribution.hash(state);
+        self.dimensions.hash(state);
+        self.url.hash(state);
+        // skip hashing the query
+        // collisions in regular usage should be minimal,
+        // and the list is limited to 256 values anyway
+        self.input_field.hash(state);
+        self.path_to_embeddings.hash(state);
+        self.embedding_object.hash(state);
+        self.input_type.hash(state);
+    }
+}
+
+#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
+#[serde(rename_all = "camelCase")]
+#[deserr(rename_all = camelCase, deny_unknown_fields)]
+pub enum InputType {
+    Text,
+    TextArray,
+}
+
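With the defaults above (`input` / `data` / `embedding`), an OpenAI-compatible endpoint works without overrides. As a sketch, a configuration for such an endpoint might look like this (URL, key and model are placeholders; this reuses the `EmbedderOptions` and `InputType` defined above):

```rust
let options = EmbedderOptions {
    api_key: Some("sk-example".into()),
    distribution: None,
    // leave `dimensions` unset to let `Embedder::new` probe the endpoint
    // once and infer the dimension from the reply
    dimensions: None,
    url: "https://api.example.com/v1/embeddings".into(),
    // static part of the request body; the input is injected into it
    query: serde_json::json!({ "model": "some-embedding-model" }),
    input_field: vec!["input".into()],
    path_to_embeddings: vec!["data".into()],
    embedding_object: vec!["embedding".into()],
    input_type: InputType::TextArray,
};
```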
+impl Embedder {
+    pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
+        let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
+
+        let client = ureq::AgentBuilder::new()
+            .max_idle_connections(REQUEST_PARALLELISM * 2)
+            .max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
+            .build();
+
+        let dimensions = if let Some(dimensions) = options.dimensions {
+            dimensions
+        } else {
+            infer_dimensions(&client, &options, bearer.as_deref())?
+        };
+
+        Ok(Self { client, dimensions, options, bearer })
+    }
+
+    pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+        embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len())
+    }
+
+    pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
+    where
+        S: AsRef<str> + Serialize,
+    {
+        embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len())
+    }
+
+    pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
+        let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?;
+        // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
+        Ok(embeddings.pop().unwrap())
+    }
+
+    pub fn embed_chunks(
+        &self,
+        text_chunks: Vec<Vec<String>>,
+        threads: &rayon::ThreadPool,
+    ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
+        threads.install(move || {
+            text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
+        })
+    }
+
+    pub fn chunk_count_hint(&self) -> usize {
+        super::REQUEST_PARALLELISM
+    }
+
+    pub fn prompt_count_in_chunk_hint(&self) -> usize {
+        match self.options.input_type {
+            InputType::Text => 1,
+            InputType::TextArray => 10,
+        }
+    }
+
+    pub fn dimensions(&self) -> usize {
+        self.dimensions
+    }
+
+    pub fn distribution(&self) -> Option<DistributionShift> {
+        self.options.distribution
+    }
+}
+
+fn infer_dimensions(
+    client: &ureq::Agent,
+    options: &EmbedderOptions,
+    bearer: Option<&str>,
+) -> Result<usize, NewEmbedderError> {
+    let v = embed(client, options, bearer, ["test"].as_slice(), 1)
+        .map_err(NewEmbedderError::could_not_determine_dimension)?;
+    // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
+    Ok(v.first().unwrap().dimension())
+}
+
+fn embed<S>(
+    client: &ureq::Agent,
+    options: &EmbedderOptions,
+    bearer: Option<&str>,
+    inputs: &[S],
+    expected_count: usize,
+) -> Result<Vec<Embeddings<f32>>, EmbedError>
+where
+    S: Serialize,
+{
+    let request = client.post(&options.url);
+    let request =
+        if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
+    let request = request.set("Content-Type", "application/json");
+
+    let input_value = match options.input_type {
+        InputType::Text => serde_json::json!(inputs.first()),
+        InputType::TextArray => serde_json::json!(inputs),
+    };
+
+    let body = match options.input_field.as_slice() {
+        [] => {
+            // inject input in body
+            input_value
+        }
+        [input] => {
+            let mut body = options.query.clone();
+
+            body.as_object_mut()
+                .ok_or_else(|| {
+                    EmbedError::rest_not_an_object(
+                        options.query.clone(),
+                        options.input_field.clone(),
+                    )
+                })?
+                .insert(input.clone(), input_value);
+            body
+        }
+        [path @ .., input] => {
+            let mut body = options.query.clone();
+
+            let mut current_value = &mut body;
+            for component in path {
+                current_value = current_value
+                    .as_object_mut()
+                    .ok_or_else(|| {
+                        EmbedError::rest_not_an_object(
+                            options.query.clone(),
+                            options.input_field.clone(),
+                        )
+                    })?
+                    .entry(component.clone())
+                    .or_insert(serde_json::json!({}));
+            }
+
+            current_value.as_object_mut().unwrap().insert(input.clone(), input_value);
+            body
+        }
+    };
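The `[path @ .., input]` arm above creates any missing intermediate objects inside `query` before injecting the input at the last component. A runnable sketch with hypothetical values, using the same `serde_json` `entry`/`or_insert` calls:

```rust
use serde_json::json;

fn main() {
    // as they might appear in the settings: query = {"model":"m"},
    // input_field = ["data", "input"], inputType = textArray
    let query = json!({ "model": "m" });
    let input_field = ["data", "input"];

    let mut body = query.clone();
    body.as_object_mut()
        .unwrap()
        .entry(input_field[0])
        .or_insert(json!({}))
        .as_object_mut()
        .unwrap()
        .insert(input_field[1].to_string(), json!(["first text", "second text"]));

    // body is now {"data":{"input":["first text","second text"]},"model":"m"}
    println!("{body}");
}
```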
+
+    for attempt in 0..7 {
+        let response = request.clone().send_json(&body);
+        let result = check_response(response);
+
+        let retry_duration = match result {
+            Ok(response) => return response_to_embedding(response, options, expected_count),
+            Err(retry) => {
+                tracing::warn!("Failed: {}", retry.error);
+                retry.into_duration(attempt)
+            }
+        }?;
+
+        let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
+        tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
+        std::thread::sleep(retry_duration);
+    }
+
+    let response = request.send_json(&body);
+    let result = check_response(response);
+    result
+        .map_err(Retry::into_error)
+        .and_then(|response| response_to_embedding(response, options, expected_count))
+}
+
+fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
+    match response {
+        Ok(response) => Ok(response),
+        Err(ureq::Error::Status(code, response)) => {
+            let error_response: Option<String> = response.into_string().ok();
+            Err(match code {
+                401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
+                429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
+                400 => Retry::give_up(EmbedError::rest_bad_request(error_response)),
+                500..=599 => {
+                    Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
+                }
+                402..=499 => {
+                    Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
+                }
+                _ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
+            })
+        }
+        Err(ureq::Error::Transport(transport)) => {
+            Err(Retry::retry_later(EmbedError::rest_network(transport)))
+        }
+    }
+}
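Given `into_duration` above, the seven in-loop attempts back off exponentially: a plain retry waits 10^attempt ms (1 ms up to 10^6 ms), a rate-limited one waits 100 + 10^attempt ms, and every wait is clamped to one minute; an eighth and final request outside the loop returns its error as-is. A sketch that mirrors (rather than calls) that schedule:

```rust
use std::time::Duration;

fn main() {
    let clamp = |d: Duration| d.min(Duration::from_secs(60)); // same cap as the retry loop
    for attempt in 0..7u32 {
        let retry = clamp(Duration::from_millis(10u64.pow(attempt)));
        let rate_limited = clamp(Duration::from_millis(100 + 10u64.pow(attempt)));
        println!("attempt {attempt}: retry {retry:?}, rate-limited {rate_limited:?}");
    }
}
```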
+
+fn response_to_embedding(
+    response: ureq::Response,
+    options: &EmbedderOptions,
+    expected_count: usize,
+) -> Result<Vec<Embeddings<f32>>, EmbedError> {
+    let response: serde_json::Value =
+        response.into_json().map_err(EmbedError::rest_response_deserialization)?;
+
+    let mut current_value = &response;
+    for component in &options.path_to_embeddings {
+        let component = component.as_ref();
+        current_value = current_value.get(component).ok_or_else(|| {
+            EmbedError::rest_response_missing_embeddings(
+                response.clone(),
+                component,
+                &options.path_to_embeddings,
+            )
+        })?;
+    }
+
+    let embeddings = match options.input_type {
+        InputType::Text => {
+            for component in &options.embedding_object {
+                current_value = current_value.get(component).ok_or_else(|| {
+                    EmbedError::rest_response_missing_embeddings(
+                        response.clone(),
+                        component,
+                        &options.embedding_object,
+                    )
+                })?;
+            }
+            let embeddings = current_value.to_owned();
+            let embeddings: Embedding =
+                serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
+
+            vec![Embeddings::from_single_embedding(embeddings)]
+        }
+        InputType::TextArray => {
+            let empty = vec![];
+            let values = current_value.as_array().unwrap_or(&empty);
+            let mut embeddings: Vec<Embeddings<f32>> = Vec::with_capacity(expected_count);
+            for value in values {
+                let mut current_value = value;
+                for component in &options.embedding_object {
+                    current_value = current_value.get(component).ok_or_else(|| {
+                        EmbedError::rest_response_missing_embeddings(
+                            response.clone(),
+                            component,
+                            &options.embedding_object,
+                        )
+                    })?;
+                }
+                let embedding = current_value.to_owned();
+                let embedding: Embedding =
+                    serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?;
+                embeddings.push(Embeddings::from_single_embedding(embedding));
+            }
+            embeddings
+        }
+    };
+
+    if embeddings.len() != expected_count {
+        return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
+    }
+
+    Ok(embeddings)
+}
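`response_to_embedding` walks `path_to_embeddings` once into the response, then `embedding_object` inside each element of the resulting array (in `TextArray` mode), and finally checks the count against `expected_count`. A sketch of the default `data` / `embedding` shape with made-up values:

```rust
use serde_json::json;

fn main() {
    // response shape for path_to_embeddings = ["data"], embedding_object = ["embedding"]
    let response = json!({
        "data": [
            { "embedding": [0.1, 0.2, 0.3] },
            { "embedding": [0.4, 0.5, 0.6] }
        ]
    });

    // path_to_embeddings: descend to the per-input array
    let values = response["data"].as_array().unwrap();

    // embedding_object: descend inside each element to the vector itself
    let embeddings: Vec<Vec<f32>> = values
        .iter()
        .map(|value| serde_json::from_value(value["embedding"].clone()).unwrap())
        .collect();

    assert_eq!(embeddings.len(), 2); // must match expected_count
    println!("{embeddings:?}");
}
```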
diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs
index 89571e98a..c5b0d0326 100644
--- a/milli/src/vector/settings.rs
+++ b/milli/src/vector/settings.rs
@@ -1,6 +1,7 @@
 use deserr::Deserr;
 use serde::{Deserialize, Serialize};
 
+use super::rest::InputType;
 use super::{ollama, openai};
 use crate::prompt::PromptData;
 use crate::update::Setting;
@@ -29,6 +30,24 @@ pub struct EmbeddingSettings {
     #[serde(default, skip_serializing_if = "Setting::is_not_set")]
     #[deserr(default)]
     pub document_template: Setting<String>,
+    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+    #[deserr(default)]
+    pub url: Setting<String>,
+    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+    #[deserr(default)]
+    pub query: Setting<serde_json::Value>,
+    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+    #[deserr(default)]
+    pub input_field: Setting<Vec<String>>,
+    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+    #[deserr(default)]
+    pub path_to_embeddings: Setting<Vec<String>>,
+    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+    #[deserr(default)]
+    pub embedding_object: Setting<Vec<String>>,
+    #[serde(default, skip_serializing_if = "Setting::is_not_set")]
+    #[deserr(default)]
+    pub input_type: Setting<InputType>,
 }
 
 pub fn check_unset<T>(
@@ -75,20 +94,42 @@ impl EmbeddingSettings {
     pub const DIMENSIONS: &'static str = "dimensions";
     pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
 
+    pub const URL: &'static str = "url";
+    pub const QUERY: &'static str = "query";
+    pub const INPUT_FIELD: &'static str = "inputField";
+    pub const PATH_TO_EMBEDDINGS: &'static str = "pathToEmbeddings";
+    pub const EMBEDDING_OBJECT: &'static str = "embeddingObject";
+    pub const INPUT_TYPE: &'static str = "inputType";
+
     pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
         match field {
-            Self::SOURCE => {
-                &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided]
-            }
+            Self::SOURCE => &[
+                EmbedderSource::HuggingFace,
+                EmbedderSource::OpenAi,
+                EmbedderSource::UserProvided,
+                EmbedderSource::Rest,
+                EmbedderSource::Ollama,
+            ],
             Self::MODEL => {
                 &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
             }
             Self::REVISION => &[EmbedderSource::HuggingFace],
-            Self::API_KEY => &[EmbedderSource::OpenAi],
-            Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided],
-            Self::DOCUMENT_TEMPLATE => {
-                &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
+            Self::API_KEY => &[EmbedderSource::OpenAi, EmbedderSource::Rest],
+            Self::DIMENSIONS => {
+                &[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Rest]
             }
+            Self::DOCUMENT_TEMPLATE => &[
+                EmbedderSource::HuggingFace,
+                EmbedderSource::OpenAi,
+                EmbedderSource::Ollama,
+                EmbedderSource::Rest,
+            ],
+            Self::URL => &[EmbedderSource::Rest],
+            Self::QUERY => &[EmbedderSource::Rest],
+            Self::INPUT_FIELD => &[EmbedderSource::Rest],
+            Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
+            Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest],
+            Self::INPUT_TYPE => &[EmbedderSource::Rest],
             _other => unreachable!("unknown field"),
         }
     }
@@ -107,6 +148,18 @@ impl EmbeddingSettings {
             }
             EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE],
             EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
+            EmbedderSource::Rest => &[
+                Self::SOURCE,
+                Self::API_KEY,
+                Self::DIMENSIONS,
+                Self::DOCUMENT_TEMPLATE,
+                Self::URL,
+                Self::QUERY,
+                Self::INPUT_FIELD,
+                Self::PATH_TO_EMBEDDINGS,
+                Self::EMBEDDING_OBJECT,
+                Self::INPUT_TYPE,
+            ],
         }
     }
 
@@ -141,6 +194,7 @@ pub enum EmbedderSource {
     HuggingFace,
     Ollama,
     UserProvided,
+    Rest,
 }
 
 impl std::fmt::Display for EmbedderSource {
@@ -150,6 +204,7 @@ impl std::fmt::Display for EmbedderSource {
             EmbedderSource::HuggingFace => "huggingFace",
             EmbedderSource::UserProvided => "userProvided",
             EmbedderSource::Ollama => "ollama",
+            EmbedderSource::Rest => "rest",
         };
         f.write_str(s)
     }
@@ -157,8 +212,20 @@ impl std::fmt::Display for EmbedderSource {
 
 impl EmbeddingSettings {
     pub fn apply(&mut self, new: Self) {
-        let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
-            new;
+        let EmbeddingSettings {
+            source,
+            model,
+            revision,
+            api_key,
+            dimensions,
+            document_template,
+            url,
+            query,
+            input_field,
+            path_to_embeddings,
+            embedding_object,
+            input_type,
+        } = new;
         let old_source = self.source;
         self.source.apply(source);
         // Reinitialize the whole setting object on a source change
@@ -170,6 +237,12 @@ impl EmbeddingSettings {
                 api_key,
                 dimensions,
                 document_template,
+                url,
+                query,
+                input_field,
+                path_to_embeddings,
+                embedding_object,
+                input_type,
             };
             return;
         }
@@ -179,6 +252,13 @@ impl EmbeddingSettings {
         self.api_key.apply(api_key);
         self.dimensions.apply(dimensions);
         self.document_template.apply(document_template);
+
+        self.url.apply(url);
+        self.query.apply(query);
+        self.input_field.apply(input_field);
+        self.path_to_embeddings.apply(path_to_embeddings);
+        self.embedding_object.apply(embedding_object);
+        self.input_type.apply(input_type);
     }
 }
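The source-change reset in `apply` above matters for partial settings updates: a patch that only switches `source` must not keep REST-specific fields from the previous configuration. A self-contained miniature of that rule; the two-field `Setting` type here is a stand-in for the real one, purely for illustration:

```rust
// Stand-in for milli's `Setting<T>`: a patch field is either set or absent.
#[derive(Clone, Copy, Debug)]
enum Setting<T> {
    Set(T),
    NotSet,
}

impl<T> Setting<T> {
    // adopt the new value only when the patch provides one
    fn apply(&mut self, new: Self) {
        if let Setting::Set(_) = new {
            *self = new;
        }
    }
}

#[derive(Debug)]
struct MiniSettings {
    source: Setting<&'static str>,
    url: Setting<&'static str>,
}

impl MiniSettings {
    fn apply(&mut self, new: Self) {
        let source_changed = match (&self.source, &new.source) {
            (Setting::Set(old), Setting::Set(new)) => old != new,
            _ => false,
        };
        self.source.apply(new.source);
        if source_changed {
            // source changed: reinitialize everything else from the patch
            self.url = new.url;
            return;
        }
        self.url.apply(new.url);
    }
}

fn main() {
    let mut settings =
        MiniSettings { source: Setting::Set("rest"), url: Setting::Set("http://localhost") };
    // a patch that only changes the source drops the now-stale `url`
    settings.apply(MiniSettings { source: Setting::Set("openAi"), url: Setting::NotSet });
    println!("{settings:?}"); // source: Set("openAi"), url: NotSet
}
```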
 
@@ -193,6 +273,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
                 api_key: Setting::NotSet,
                 dimensions: Setting::NotSet,
                 document_template: Setting::Set(prompt.template),
+                url: Setting::NotSet,
+                query: Setting::NotSet,
+                input_field: Setting::NotSet,
+                path_to_embeddings: Setting::NotSet,
+                embedding_object: Setting::NotSet,
+                input_type: Setting::NotSet,
             },
             super::EmbedderOptions::OpenAi(options) => Self {
                 source: Setting::Set(EmbedderSource::OpenAi),
@@ -201,14 +287,26 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
                 api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
                 dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
                 document_template: Setting::Set(prompt.template),
+                url: Setting::NotSet,
+                query: Setting::NotSet,
+                input_field: Setting::NotSet,
+                path_to_embeddings: Setting::NotSet,
+                embedding_object: Setting::NotSet,
+                input_type: Setting::NotSet,
             },
             super::EmbedderOptions::Ollama(options) => Self {
                 source: Setting::Set(EmbedderSource::Ollama),
-                model: Setting::Set(options.embedding_model.name().to_owned()),
+                model: Setting::Set(options.embedding_model.to_owned()),
                 revision: Setting::NotSet,
                 api_key: Setting::NotSet,
                 dimensions: Setting::NotSet,
                 document_template: Setting::Set(prompt.template),
+                url: Setting::NotSet,
+                query: Setting::NotSet,
+                input_field: Setting::NotSet,
+                path_to_embeddings: Setting::NotSet,
+                embedding_object: Setting::NotSet,
+                input_type: Setting::NotSet,
             },
             super::EmbedderOptions::UserProvided(options) => Self {
                 source: Setting::Set(EmbedderSource::UserProvided),
@@ -217,6 +315,37 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
                 api_key: Setting::NotSet,
                 dimensions: Setting::Set(options.dimensions),
                 document_template: Setting::NotSet,
+                url: Setting::NotSet,
+                query: Setting::NotSet,
+                input_field: Setting::NotSet,
+                path_to_embeddings: Setting::NotSet,
+                embedding_object: Setting::NotSet,
+                input_type: Setting::NotSet,
+            },
+            super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
+                api_key,
+                // TODO: support distribution
+                distribution: _,
+                dimensions,
+                url,
+                query,
+                input_field,
+                path_to_embeddings,
+                embedding_object,
+                input_type,
+            }) => Self {
+                source: Setting::Set(EmbedderSource::Rest),
+                model: Setting::NotSet,
+                revision: Setting::NotSet,
+                api_key: api_key.map(Setting::Set).unwrap_or_default(),
+                dimensions: dimensions.map(Setting::Set).unwrap_or_default(),
+                document_template: Setting::Set(prompt.template),
+                url: Setting::Set(url),
+                query: Setting::Set(query),
+                input_field: Setting::Set(input_field),
+                path_to_embeddings: Setting::Set(path_to_embeddings),
+                embedding_object: Setting::Set(embedding_object),
+                input_type: Setting::Set(input_type),
+            },
         }
     }
@@ -225,8 +354,20 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
     fn from(value: EmbeddingSettings) -> Self {
         let mut this = Self::default();
-        let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
-            value;
+        let EmbeddingSettings {
+            source,
+            model,
+            revision,
+            api_key,
+            dimensions,
+            document_template,
+            url,
+            query,
+            input_field,
+            path_to_embeddings,
+            embedding_object,
+            input_type,
+        } = value;
         if let Some(source) = source.set() {
             match source {
                 EmbedderSource::OpenAi => {
@@ -248,7 +389,7 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
                     let mut options: ollama::EmbedderOptions =
                         super::ollama::EmbedderOptions::with_default_model();
                     if let Some(model) = model.set() {
-                        options.embedding_model = super::ollama::EmbeddingModel::from_name(&model);
+                        options.embedding_model = model;
                     }
                     this.embedder_options = super::EmbedderOptions::Ollama(options);
                 }
@@ -274,6 +415,26 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
                         dimensions: dimensions.set().unwrap(),
                     });
                 }
+                EmbedderSource::Rest => {
+                    let embedder_options = super::rest::EmbedderOptions::default();
+
+                    this.embedder_options =
+                        super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
+                            api_key: api_key.set(),
+                            distribution: None,
+                            dimensions: dimensions.set(),
+                            url: url.set().unwrap(),
+                            query: query.set().unwrap_or(embedder_options.query),
+                            input_field: input_field.set().unwrap_or(embedder_options.input_field),
+                            path_to_embeddings: path_to_embeddings
+                                .set()
+                                .unwrap_or(embedder_options.path_to_embeddings),
+                            embedding_object: embedding_object
+                                .set()
+                                .unwrap_or(embedder_options.embedding_object),
+                            input_type: input_type.set().unwrap_or(embedder_options.input_type),
+                        })
+                }
             }
         }
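In user-facing terms, the conversion above means a `rest` embedder requires `url` (note the `url.set().unwrap()`), while every other field falls back to the OpenAI-shaped defaults. A hypothetical settings fragment exercising the new source, with field names taken from the constants defined in this file:

```rust
use serde_json::json;

fn main() {
    // hypothetical `embedders` fragment of an index-settings payload
    let settings = json!({
        "embedders": {
            "default": {
                "source": "rest",
                "url": "http://localhost:8080/embed",   // the only required field
                "apiKey": "<token>",                    // optional, sent as a bearer token
                "dimensions": 768,                      // optional; inferred when omitted
                "query": { "model": "some-model" },
                "inputField": ["texts"],
                "pathToEmbeddings": [],                 // embeddings at the response root
                "embeddingObject": ["embedding"],
                "inputType": "textArray"
            }
        }
    });
    println!("{settings:#}");
}
```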