rest embedder: use json_template

This commit is contained in:
Louis Dureuil 2024-07-16 15:04:40 +02:00
parent 4109182ca4
commit a1beddd5d9
No known key found for this signature in database
2 changed files with 174 additions and 159 deletions

View File

@ -228,7 +228,9 @@ impl Embedder {
EmbedderOptions::UserProvided(options) => { EmbedderOptions::UserProvided(options) => {
Self::UserProvided(manual::Embedder::new(options)) Self::UserProvided(manual::Embedder::new(options))
} }
EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?), EmbedderOptions::Rest(options) => {
Self::Rest(rest::Embedder::new(options, rest::ConfigurationSource::User)?)
}
}) })
} }

View File

@ -4,6 +4,7 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use super::error::EmbedErrorKind; use super::error::EmbedErrorKind;
use super::json_template::ValueTemplate;
use super::{ use super::{
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
}; };
@ -11,12 +12,18 @@ use crate::error::FaultSource;
use crate::ThreadPoolNoAbort; use crate::ThreadPoolNoAbort;
// retrying in case of failure // retrying in case of failure
pub struct Retry { pub struct Retry {
pub error: EmbedError, pub error: EmbedError,
strategy: RetryStrategy, strategy: RetryStrategy,
} }
/// Identifies on whose behalf a REST embedder was configured.
///
/// The value is stored with the request data and forwarded to
/// `EmbedError::rest_bad_request` when the remote server answers with HTTP 400.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfigurationSource {
/// NOTE(review): presumably set when the REST embedder is driven by the OpenAI embedder — confirm at the call site.
OpenAi,
/// NOTE(review): presumably set when the REST embedder is driven by the Ollama embedder — confirm at the call site.
Ollama,
/// The configuration was supplied through `EmbedderOptions::Rest`.
User,
}
pub enum RetryStrategy { pub enum RetryStrategy {
GiveUp, GiveUp,
Retry, Retry,
@ -63,10 +70,20 @@ impl Retry {
#[derive(Debug)] #[derive(Debug)]
pub struct Embedder { pub struct Embedder {
client: ureq::Agent, data: EmbedderData,
options: EmbedderOptions,
bearer: Option<String>,
dimensions: usize, dimensions: usize,
distribution: Option<DistributionShift>,
}
/// All data needed to perform requests and parse responses
#[derive(Debug)]
struct EmbedderData {
/// HTTP agent used to POST embedding requests.
client: ureq::Agent,
/// Pre-formatted `Bearer {api_key}` value for the `Authorization` header;
/// `None` when no API key was configured, in which case the header is not set.
bearer: Option<String>,
/// Endpoint the embedding requests are posted to.
url: String,
/// Template into which the input texts are injected to build the request body.
request: Request,
/// Template used to extract the embeddings from the JSON response.
response: Response,
/// Origin of the configuration, passed along when checking the HTTP response.
configuration_source: ConfigurationSource,
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
@ -75,29 +92,8 @@ pub struct EmbedderOptions {
pub distribution: Option<DistributionShift>, pub distribution: Option<DistributionShift>,
pub dimensions: Option<usize>, pub dimensions: Option<usize>,
pub url: String, pub url: String,
pub query: serde_json::Value, pub request: serde_json::Value,
pub input_field: Vec<String>, pub response: serde_json::Value,
// path to the array of embeddings
pub path_to_embeddings: Vec<String>,
// shape of a single embedding
pub embedding_object: Vec<String>,
pub input_type: InputType,
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self {
url: Default::default(),
query: Default::default(),
input_field: vec!["input".into()],
path_to_embeddings: vec!["data".into()],
embedding_object: vec!["embedding".into()],
input_type: InputType::Text,
api_key: None,
distribution: None,
dimensions: None,
}
}
} }
impl std::hash::Hash for EmbedderOptions { impl std::hash::Hash for EmbedderOptions {
@ -106,26 +102,25 @@ impl std::hash::Hash for EmbedderOptions {
self.distribution.hash(state); self.distribution.hash(state);
self.dimensions.hash(state); self.dimensions.hash(state);
self.url.hash(state); self.url.hash(state);
// skip hashing the query // skip hashing the request and response
// collisions in regular usage should be minimal, // collisions in regular usage should be minimal,
// and the list is limited to 256 values anyway // and the list is limited to 256 values anyway
self.input_field.hash(state);
self.path_to_embeddings.hash(state);
self.embedding_object.hash(state);
self.input_type.hash(state);
} }
} }
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)] #[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
#[serde(rename_all = "camelCase")] #[serde(rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)] #[deserr(rename_all = camelCase, deny_unknown_fields)]
pub enum InputType { enum InputType {
Text, Text,
TextArray, TextArray,
} }
impl Embedder { impl Embedder {
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> { pub fn new(
options: EmbedderOptions,
configuration_source: ConfigurationSource,
) -> Result<Self, NewEmbedderError> {
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}")); let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
let client = ureq::AgentBuilder::new() let client = ureq::AgentBuilder::new()
@ -133,28 +128,40 @@ impl Embedder {
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2) .max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
.build(); .build();
let request = Request::new(options.request)?;
let response = Response::new(options.response, &request)?;
let data = EmbedderData {
client,
bearer,
url: options.url,
request,
response,
configuration_source,
};
let dimensions = if let Some(dimensions) = options.dimensions { let dimensions = if let Some(dimensions) = options.dimensions {
dimensions dimensions
} else { } else {
infer_dimensions(&client, &options, bearer.as_deref())? infer_dimensions(&data)?
}; };
Ok(Self { client, dimensions, options, bearer }) Ok(Self { data, dimensions, distribution: options.distribution })
} }
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> { pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len()) embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions))
} }
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError> pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
where where
S: AsRef<str> + Serialize, S: AsRef<str> + Serialize,
{ {
embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len()) embed(&self.data, texts, texts.len(), Some(self.dimensions))
} }
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> { pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?; let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?;
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error // unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
Ok(embeddings.pop().unwrap()) Ok(embeddings.pop().unwrap())
} }
@ -179,7 +186,7 @@ impl Embedder {
} }
pub fn prompt_count_in_chunk_hint(&self) -> usize { pub fn prompt_count_in_chunk_hint(&self) -> usize {
match self.options.input_type { match self.data.request.input_type() {
InputType::Text => 1, InputType::Text => 1,
InputType::TextArray => 10, InputType::TextArray => 10,
} }
@ -190,87 +197,44 @@ impl Embedder {
} }
pub fn distribution(&self) -> Option<DistributionShift> { pub fn distribution(&self) -> Option<DistributionShift> {
self.options.distribution self.distribution
} }
} }
fn infer_dimensions( fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
client: &ureq::Agent, let v = embed(data, ["test"].as_slice(), 1, None)
options: &EmbedderOptions,
bearer: Option<&str>,
) -> Result<usize, NewEmbedderError> {
let v = embed(client, options, bearer, ["test"].as_slice(), 1)
.map_err(NewEmbedderError::could_not_determine_dimension)?; .map_err(NewEmbedderError::could_not_determine_dimension)?;
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
Ok(v.first().unwrap().dimension()) Ok(v.first().unwrap().dimension())
} }
fn embed<S>( fn embed<S>(
client: &ureq::Agent, data: &EmbedderData,
options: &EmbedderOptions,
bearer: Option<&str>,
inputs: &[S], inputs: &[S],
expected_count: usize, expected_count: usize,
expected_dimension: Option<usize>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> ) -> Result<Vec<Embeddings<f32>>, EmbedError>
where where
S: Serialize, S: Serialize,
{ {
let request = client.post(&options.url); let request = data.client.post(&data.url);
let request = let request = if let Some(bearer) = &data.bearer {
if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request }; request.set("Authorization", bearer)
} else {
request
};
let request = request.set("Content-Type", "application/json"); let request = request.set("Content-Type", "application/json");
let input_value = match options.input_type { let body = data.request.inject_texts(inputs);
InputType::Text => serde_json::json!(inputs.first()),
InputType::TextArray => serde_json::json!(inputs),
};
let body = match options.input_field.as_slice() {
[] => {
// inject input in body
input_value
}
[input] => {
let mut body = options.query.clone();
body.as_object_mut()
.ok_or_else(|| {
EmbedError::rest_not_an_object(
options.query.clone(),
options.input_field.clone(),
)
})?
.insert(input.clone(), input_value);
body
}
[path @ .., input] => {
let mut body = options.query.clone();
let mut current_value = &mut body;
for component in path {
current_value = current_value
.as_object_mut()
.ok_or_else(|| {
EmbedError::rest_not_an_object(
options.query.clone(),
options.input_field.clone(),
)
})?
.entry(component.clone())
.or_insert(serde_json::json!({}));
}
current_value.as_object_mut().unwrap().insert(input.clone(), input_value);
body
}
};
for attempt in 0..10 { for attempt in 0..10 {
let response = request.clone().send_json(&body); let response = request.clone().send_json(&body);
let result = check_response(response); let result = check_response(response, data.configuration_source);
let retry_duration = match result { let retry_duration = match result {
Ok(response) => return response_to_embedding(response, options, expected_count), Ok(response) => {
return response_to_embedding(response, data, expected_count, expected_dimension)
}
Err(retry) => { Err(retry) => {
tracing::warn!("Failed: {}", retry.error); tracing::warn!("Failed: {}", retry.error);
retry.into_duration(attempt) retry.into_duration(attempt)
@ -288,13 +252,16 @@ where
} }
let response = request.send_json(&body); let response = request.send_json(&body);
let result = check_response(response); let result = check_response(response, data.configuration_source);
result result.map_err(Retry::into_error).and_then(|response| {
.map_err(Retry::into_error) response_to_embedding(response, data, expected_count, expected_dimension)
.and_then(|response| response_to_embedding(response, options, expected_count)) })
} }
fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> { fn check_response(
response: Result<ureq::Response, ureq::Error>,
configuration_source: ConfigurationSource,
) -> Result<ureq::Response, Retry> {
match response { match response {
Ok(response) => Ok(response), Ok(response) => Ok(response),
Err(ureq::Error::Status(code, response)) => { Err(ureq::Error::Status(code, response)) => {
@ -302,7 +269,10 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq:
Err(match code { Err(match code {
401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)), 401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)), 429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
400 => Retry::give_up(EmbedError::rest_bad_request(error_response)), 400 => Retry::give_up(EmbedError::rest_bad_request(
error_response,
configuration_source,
)),
500..=599 => { 500..=599 => {
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response)) Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
} }
@ -320,68 +290,111 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq:
fn response_to_embedding( fn response_to_embedding(
response: ureq::Response, response: ureq::Response,
options: &EmbedderOptions, data: &EmbedderData,
expected_count: usize, expected_count: usize,
expected_dimensions: Option<usize>,
) -> Result<Vec<Embeddings<f32>>, EmbedError> { ) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let response: serde_json::Value = let response: serde_json::Value =
response.into_json().map_err(EmbedError::rest_response_deserialization)?; response.into_json().map_err(EmbedError::rest_response_deserialization)?;
let mut current_value = &response; let embeddings = data.response.extract_embeddings(response)?;
for component in &options.path_to_embeddings {
let component = component.as_ref();
current_value = current_value.get(component).ok_or_else(|| {
EmbedError::rest_response_missing_embeddings(
response.clone(),
component,
&options.path_to_embeddings,
)
})?;
}
let embeddings = match options.input_type {
InputType::Text => {
for component in &options.embedding_object {
current_value = current_value.get(component).ok_or_else(|| {
EmbedError::rest_response_missing_embeddings(
response.clone(),
component,
&options.embedding_object,
)
})?;
}
let embeddings = current_value.to_owned();
let embeddings: Embedding =
serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
vec![Embeddings::from_single_embedding(embeddings)]
}
InputType::TextArray => {
let empty = vec![];
let values = current_value.as_array().unwrap_or(&empty);
let mut embeddings: Vec<Embeddings<f32>> = Vec::with_capacity(expected_count);
for value in values {
let mut current_value = value;
for component in &options.embedding_object {
current_value = current_value.get(component).ok_or_else(|| {
EmbedError::rest_response_missing_embeddings(
response.clone(),
component,
&options.embedding_object,
)
})?;
}
let embedding = current_value.to_owned();
let embedding: Embedding =
serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?;
embeddings.push(Embeddings::from_single_embedding(embedding));
}
embeddings
}
};
if embeddings.len() != expected_count { if embeddings.len() != expected_count {
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len())); return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
} }
if let Some(dimensions) = expected_dimensions {
for embedding in &embeddings {
if embedding.dimension() != dimensions {
return Err(EmbedError::rest_unexpected_dimension(
dimensions,
embedding.dimension(),
));
}
}
}
Ok(embeddings) Ok(embeddings)
} }
/// Placeholder handed to `ValueTemplate::new` when parsing the `request` template.
/// NOTE(review): presumably marks where an input text is injected — confirm in `json_template`.
pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
/// Placeholder handed to `ValueTemplate::new` when parsing the `response` template.
/// NOTE(review): presumably marks where an embedding is read from — confirm in `json_template`.
pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
/// Repeat marker handed to `ValueTemplate::new` for both the `request` and `response` templates.
/// NOTE(review): presumably stands for "more items of the same shape" in an array — confirm in `json_template`.
pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
/// Parsed form of the `request` JSON template of a REST embedder.
#[derive(Debug)]
pub struct Request {
/// Template built from the `request` value with `REQUEST_PLACEHOLDER` and `REPEAT_PLACEHOLDER`.
template: ValueTemplate,
}
impl Request {
    /// Parses `template` into a request template.
    ///
    /// # Errors
    ///
    /// Returns `NewEmbedderError::rest_could_not_parse_template` carrying the message
    /// produced by the template error when `ValueTemplate::new` rejects `template`.
    pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
        ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER)
            .map(|template| Self { template })
            .map_err(|parse_error| {
                let message =
                    parse_error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
                NewEmbedderError::rest_could_not_parse_template(message)
            })
    }

    /// Reports whether one request carries a single text or an array of texts,
    /// depending on whether the template has an array value.
    fn input_type(&self) -> InputType {
        let accepts_many = self.template.has_array_value();
        if accepts_many {
            InputType::TextArray
        } else {
            InputType::Text
        }
    }

    /// Builds a request body by serializing each text to JSON and injecting the
    /// resulting values into the template.
    ///
    /// # Panics
    ///
    /// Panics when `inject` does not produce a value.
    /// NOTE(review): the conditions under which `inject` fails are not visible here —
    /// confirm this is unreachable for a template accepted by `new`.
    pub fn inject_texts<S: Serialize>(
        &self,
        texts: impl IntoIterator<Item = S>,
    ) -> serde_json::Value {
        let json_texts = texts.into_iter().map(|text| serde_json::json!(text));
        self.template.inject(json_texts).unwrap()
    }
}
/// Parsed form of the `response` JSON template of a REST embedder.
#[derive(Debug)]
pub struct Response {
/// Template built from the `response` value with `RESPONSE_PLACEHOLDER` and `REPEAT_PLACEHOLDER`.
template: ValueTemplate,
}
impl Response {
pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
{
Ok(template) => template,
Err(error) => {
let message =
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
return Err(NewEmbedderError::rest_could_not_parse_template(message));
}
};
match (template.has_array_value(), request.template.has_array_value()) {
(true, true) | (false, false) => Ok(Self {template}),
(true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
(false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
}
}
pub fn extract_embeddings(
&self,
response: serde_json::Value,
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
Ok(extracted_values) => extracted_values,
Err(error) => {
let error_message =
error.error_message("response", "{{embedding}}", "an array of numbers");
return Err(EmbedError::rest_extraction_error(error_message));
}
};
let embeddings: Vec<Embeddings<f32>> =
extracted_values.into_iter().map(Embeddings::from_single_embedding).collect();
Ok(embeddings)
}
}