From a1beddd5d988f1dbefc0bbe704b5f246318bd4bb Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 16 Jul 2024 15:04:40 +0200 Subject: [PATCH] rest embedder: use json_template --- milli/src/vector/mod.rs | 4 +- milli/src/vector/rest.rs | 329 ++++++++++++++++++++------------------- 2 files changed, 174 insertions(+), 159 deletions(-) diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index a3e9f7c2b..a1c937d24 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -228,7 +228,9 @@ impl Embedder { EmbedderOptions::UserProvided(options) => { Self::UserProvided(manual::Embedder::new(options)) } - EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?), + EmbedderOptions::Rest(options) => { + Self::Rest(rest::Embedder::new(options, rest::ConfigurationSource::User)?) + } }) } diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs index b651cba63..35a7ebc41 100644 --- a/milli/src/vector/rest.rs +++ b/milli/src/vector/rest.rs @@ -4,6 +4,7 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use serde::{Deserialize, Serialize}; use super::error::EmbedErrorKind; +use super::json_template::ValueTemplate; use super::{ DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, }; @@ -11,12 +12,18 @@ use crate::error::FaultSource; use crate::ThreadPoolNoAbort; // retrying in case of failure - pub struct Retry { pub error: EmbedError, strategy: RetryStrategy, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConfigurationSource { + OpenAi, + Ollama, + User, +} + pub enum RetryStrategy { GiveUp, Retry, @@ -63,10 +70,20 @@ impl Retry { #[derive(Debug)] pub struct Embedder { - client: ureq::Agent, - options: EmbedderOptions, - bearer: Option, + data: EmbedderData, dimensions: usize, + distribution: Option, +} + +/// All data needed to perform requests and parse responses +#[derive(Debug)] +struct EmbedderData { + client: ureq::Agent, + bearer: Option, + url: String, + request: Request, + response: Response, + configuration_source: ConfigurationSource, } #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] @@ -75,29 +92,8 @@ pub struct EmbedderOptions { pub distribution: Option, pub dimensions: Option, pub url: String, - pub query: serde_json::Value, - pub input_field: Vec, - // path to the array of embeddings - pub path_to_embeddings: Vec, - // shape of a single embedding - pub embedding_object: Vec, - pub input_type: InputType, -} - -impl Default for EmbedderOptions { - fn default() -> Self { - Self { - url: Default::default(), - query: Default::default(), - input_field: vec!["input".into()], - path_to_embeddings: vec!["data".into()], - embedding_object: vec!["embedding".into()], - input_type: InputType::Text, - api_key: None, - distribution: None, - dimensions: None, - } - } + pub request: serde_json::Value, + pub response: serde_json::Value, } impl std::hash::Hash for EmbedderOptions { @@ -106,26 +102,25 @@ impl std::hash::Hash for EmbedderOptions { self.distribution.hash(state); self.dimensions.hash(state); self.url.hash(state); - // skip hashing the query + // skip hashing the request and response // collisions in regular usage should be minimal, // and the list is limited to 256 values anyway - self.input_field.hash(state); - self.path_to_embeddings.hash(state); - self.embedding_object.hash(state); - self.input_type.hash(state); } } #[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)] #[serde(rename_all = "camelCase")] #[deserr(rename_all = camelCase, deny_unknown_fields)] -pub enum InputType { +enum InputType { Text, TextArray, } impl Embedder { - pub fn new(options: EmbedderOptions) -> Result { + pub fn new( + options: EmbedderOptions, + configuration_source: ConfigurationSource, + ) -> Result { let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}")); let client = ureq::AgentBuilder::new() @@ -133,28 +128,40 @@ impl Embedder { .max_idle_connections_per_host(REQUEST_PARALLELISM * 2) .build(); + let request = Request::new(options.request)?; + let response = Response::new(options.response, &request)?; + + let data = EmbedderData { + client, + bearer, + url: options.url, + request, + response, + configuration_source, + }; + let dimensions = if let Some(dimensions) = options.dimensions { dimensions } else { - infer_dimensions(&client, &options, bearer.as_deref())? + infer_dimensions(&data)? }; - Ok(Self { client, dimensions, options, bearer }) + Ok(Self { data, dimensions, distribution: options.distribution }) } pub fn embed(&self, texts: Vec) -> Result>, EmbedError> { - embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len()) + embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions)) } pub fn embed_ref(&self, texts: &[S]) -> Result>, EmbedError> where S: AsRef + Serialize, { - embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len()) + embed(&self.data, texts, texts.len(), Some(self.dimensions)) } pub fn embed_tokens(&self, tokens: &[usize]) -> Result, EmbedError> { - let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?; + let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?; // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error Ok(embeddings.pop().unwrap()) } @@ -179,7 +186,7 @@ impl Embedder { } pub fn prompt_count_in_chunk_hint(&self) -> usize { - match self.options.input_type { + match self.data.request.input_type() { InputType::Text => 1, InputType::TextArray => 10, } @@ -190,87 +197,44 @@ impl Embedder { } pub fn distribution(&self) -> Option { - self.options.distribution + self.distribution } } -fn infer_dimensions( - client: &ureq::Agent, - options: &EmbedderOptions, - bearer: Option<&str>, -) -> Result { - let v = embed(client, options, bearer, ["test"].as_slice(), 1) +fn infer_dimensions(data: &EmbedderData) -> Result { + let v = embed(data, ["test"].as_slice(), 1, None) .map_err(NewEmbedderError::could_not_determine_dimension)?; // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error Ok(v.first().unwrap().dimension()) } fn embed( - client: &ureq::Agent, - options: &EmbedderOptions, - bearer: Option<&str>, + data: &EmbedderData, inputs: &[S], expected_count: usize, + expected_dimension: Option, ) -> Result>, EmbedError> where S: Serialize, { - let request = client.post(&options.url); - let request = - if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request }; + let request = data.client.post(&data.url); + let request = if let Some(bearer) = &data.bearer { + request.set("Authorization", bearer) + } else { + request + }; let request = request.set("Content-Type", "application/json"); - let input_value = match options.input_type { - InputType::Text => serde_json::json!(inputs.first()), - InputType::TextArray => serde_json::json!(inputs), - }; - - let body = match options.input_field.as_slice() { - [] => { - // inject input in body - input_value - } - [input] => { - let mut body = options.query.clone(); - - body.as_object_mut() - .ok_or_else(|| { - EmbedError::rest_not_an_object( - options.query.clone(), - options.input_field.clone(), - ) - })? - .insert(input.clone(), input_value); - body - } - [path @ .., input] => { - let mut body = options.query.clone(); - - let mut current_value = &mut body; - for component in path { - current_value = current_value - .as_object_mut() - .ok_or_else(|| { - EmbedError::rest_not_an_object( - options.query.clone(), - options.input_field.clone(), - ) - })? - .entry(component.clone()) - .or_insert(serde_json::json!({})); - } - - current_value.as_object_mut().unwrap().insert(input.clone(), input_value); - body - } - }; + let body = data.request.inject_texts(inputs); for attempt in 0..10 { let response = request.clone().send_json(&body); - let result = check_response(response); + let result = check_response(response, data.configuration_source); let retry_duration = match result { - Ok(response) => return response_to_embedding(response, options, expected_count), + Ok(response) => { + return response_to_embedding(response, data, expected_count, expected_dimension) + } Err(retry) => { tracing::warn!("Failed: {}", retry.error); retry.into_duration(attempt) @@ -288,13 +252,16 @@ where } let response = request.send_json(&body); - let result = check_response(response); - result - .map_err(Retry::into_error) - .and_then(|response| response_to_embedding(response, options, expected_count)) + let result = check_response(response, data.configuration_source); + result.map_err(Retry::into_error).and_then(|response| { + response_to_embedding(response, data, expected_count, expected_dimension) + }) } -fn check_response(response: Result) -> Result { +fn check_response( + response: Result, + configuration_source: ConfigurationSource, +) -> Result { match response { Ok(response) => Ok(response), Err(ureq::Error::Status(code, response)) => { @@ -302,7 +269,10 @@ fn check_response(response: Result) -> Result Retry::give_up(EmbedError::rest_unauthorized(error_response)), 429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)), - 400 => Retry::give_up(EmbedError::rest_bad_request(error_response)), + 400 => Retry::give_up(EmbedError::rest_bad_request( + error_response, + configuration_source, + )), 500..=599 => { Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response)) } @@ -320,68 +290,111 @@ fn check_response(response: Result) -> Result, ) -> Result>, EmbedError> { let response: serde_json::Value = response.into_json().map_err(EmbedError::rest_response_deserialization)?; - let mut current_value = &response; - for component in &options.path_to_embeddings { - let component = component.as_ref(); - current_value = current_value.get(component).ok_or_else(|| { - EmbedError::rest_response_missing_embeddings( - response.clone(), - component, - &options.path_to_embeddings, - ) - })?; - } - - let embeddings = match options.input_type { - InputType::Text => { - for component in &options.embedding_object { - current_value = current_value.get(component).ok_or_else(|| { - EmbedError::rest_response_missing_embeddings( - response.clone(), - component, - &options.embedding_object, - ) - })?; - } - let embeddings = current_value.to_owned(); - let embeddings: Embedding = - serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?; - - vec![Embeddings::from_single_embedding(embeddings)] - } - InputType::TextArray => { - let empty = vec![]; - let values = current_value.as_array().unwrap_or(&empty); - let mut embeddings: Vec> = Vec::with_capacity(expected_count); - for value in values { - let mut current_value = value; - for component in &options.embedding_object { - current_value = current_value.get(component).ok_or_else(|| { - EmbedError::rest_response_missing_embeddings( - response.clone(), - component, - &options.embedding_object, - ) - })?; - } - let embedding = current_value.to_owned(); - let embedding: Embedding = - serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?; - embeddings.push(Embeddings::from_single_embedding(embedding)); - } - embeddings - } - }; + let embeddings = data.response.extract_embeddings(response)?; if embeddings.len() != expected_count { return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len())); } + if let Some(dimensions) = expected_dimensions { + for embedding in &embeddings { + if embedding.dimension() != dimensions { + return Err(EmbedError::rest_unexpected_dimension( + dimensions, + embedding.dimension(), + )); + } + } + } + Ok(embeddings) } + +pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}"; +pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}"; +pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}"; + +#[derive(Debug)] +pub struct Request { + template: ValueTemplate, +} + +impl Request { + pub fn new(template: serde_json::Value) -> Result { + let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) { + Ok(template) => template, + Err(error) => { + let message = + error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER); + return Err(NewEmbedderError::rest_could_not_parse_template(message)); + } + }; + + Ok(Self { template }) + } + + fn input_type(&self) -> InputType { + if self.template.has_array_value() { + InputType::TextArray + } else { + InputType::Text + } + } + + pub fn inject_texts( + &self, + texts: impl IntoIterator, + ) -> serde_json::Value { + self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap() + } +} + +#[derive(Debug)] +pub struct Response { + template: ValueTemplate, +} + +impl Response { + pub fn new(template: serde_json::Value, request: &Request) -> Result { + let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER) + { + Ok(template) => template, + Err(error) => { + let message = + error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER); + return Err(NewEmbedderError::rest_could_not_parse_template(message)); + } + }; + + match (template.has_array_value(), request.template.has_array_value()) { + (true, true) | (false, false) => Ok(Self {template}), + (true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())), + (false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())), + } + } + + pub fn extract_embeddings( + &self, + response: serde_json::Value, + ) -> Result>, EmbedError> { + let extracted_values: Vec = match self.template.extract(response) { + Ok(extracted_values) => extracted_values, + Err(error) => { + let error_message = + error.error_message("response", "{{embedding}}", "an array of numbers"); + return Err(EmbedError::rest_extraction_error(error_message)); + } + }; + let embeddings: Vec> = + extracted_values.into_iter().map(Embeddings::from_single_embedding).collect(); + + Ok(embeddings) + } +}