mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-01-30 23:27:36 +01:00
462 lines
15 KiB
Rust
462 lines
15 KiB
Rust
use std::collections::BTreeMap;
|
|
use std::time::Instant;
|
|
|
|
use deserr::Deserr;
|
|
use rand::Rng;
|
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
|
use rayon::slice::ParallelSlice as _;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use super::error::EmbedErrorKind;
|
|
use super::json_template::ValueTemplate;
|
|
use super::{DistributionShift, EmbedError, Embedding, NewEmbedderError, REQUEST_PARALLELISM};
|
|
use crate::error::FaultSource;
|
|
use crate::ThreadPoolNoAbort;
|
|
|
|
// retrying in case of failure
|
|
pub struct Retry {
|
|
pub error: EmbedError,
|
|
strategy: RetryStrategy,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
pub enum ConfigurationSource {
|
|
OpenAi,
|
|
Ollama,
|
|
User,
|
|
}
|
|
|
|
pub enum RetryStrategy {
|
|
GiveUp,
|
|
Retry,
|
|
RetryTokenized,
|
|
RetryAfterRateLimit,
|
|
}
|
|
|
|
impl Retry {
|
|
pub fn give_up(error: EmbedError) -> Self {
|
|
Self { error, strategy: RetryStrategy::GiveUp }
|
|
}
|
|
|
|
pub fn retry_later(error: EmbedError) -> Self {
|
|
Self { error, strategy: RetryStrategy::Retry }
|
|
}
|
|
|
|
pub fn retry_tokenized(error: EmbedError) -> Self {
|
|
Self { error, strategy: RetryStrategy::RetryTokenized }
|
|
}
|
|
|
|
pub fn rate_limited(error: EmbedError) -> Self {
|
|
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
|
|
}
|
|
|
|
pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
|
|
match self.strategy {
|
|
RetryStrategy::GiveUp => Err(self.error),
|
|
RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
|
|
RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
|
|
RetryStrategy::RetryAfterRateLimit => {
|
|
Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn must_tokenize(&self) -> bool {
|
|
matches!(self.strategy, RetryStrategy::RetryTokenized)
|
|
}
|
|
|
|
pub fn into_error(self) -> EmbedError {
|
|
self.error
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct Embedder {
|
|
data: EmbedderData,
|
|
dimensions: usize,
|
|
distribution: Option<DistributionShift>,
|
|
}
|
|
|
|
/// All data needed to perform requests and parse responses
|
|
#[derive(Debug)]
|
|
struct EmbedderData {
|
|
client: ureq::Agent,
|
|
bearer: Option<String>,
|
|
headers: BTreeMap<String, String>,
|
|
url: String,
|
|
request: Request,
|
|
response: Response,
|
|
configuration_source: ConfigurationSource,
|
|
}
|
|
|
|
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
|
|
pub struct EmbedderOptions {
|
|
pub api_key: Option<String>,
|
|
pub distribution: Option<DistributionShift>,
|
|
pub dimensions: Option<usize>,
|
|
pub url: String,
|
|
pub request: serde_json::Value,
|
|
pub response: serde_json::Value,
|
|
pub headers: BTreeMap<String, String>,
|
|
}
|
|
|
|
impl std::hash::Hash for EmbedderOptions {
|
|
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
|
self.api_key.hash(state);
|
|
self.distribution.hash(state);
|
|
self.dimensions.hash(state);
|
|
self.url.hash(state);
|
|
// skip hashing the request and response
|
|
// collisions in regular usage should be minimal,
|
|
// and the list is limited to 256 values anyway
|
|
}
|
|
}
|
|
|
|
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
|
enum InputType {
|
|
Text,
|
|
TextArray,
|
|
}
|
|
|
|
impl Embedder {
|
|
pub fn new(
|
|
options: EmbedderOptions,
|
|
configuration_source: ConfigurationSource,
|
|
) -> Result<Self, NewEmbedderError> {
|
|
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
|
|
|
|
let client = ureq::AgentBuilder::new()
|
|
.max_idle_connections(REQUEST_PARALLELISM * 2)
|
|
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
|
|
.build();
|
|
|
|
let request = Request::new(options.request)?;
|
|
let response = Response::new(options.response, &request)?;
|
|
|
|
let data = EmbedderData {
|
|
client,
|
|
bearer,
|
|
url: options.url,
|
|
request,
|
|
response,
|
|
configuration_source,
|
|
headers: options.headers,
|
|
};
|
|
|
|
let dimensions = if let Some(dimensions) = options.dimensions {
|
|
dimensions
|
|
} else {
|
|
infer_dimensions(&data)?
|
|
};
|
|
|
|
Ok(Self { data, dimensions, distribution: options.distribution })
|
|
}
|
|
|
|
pub fn embed(
|
|
&self,
|
|
texts: Vec<String>,
|
|
deadline: Option<Instant>,
|
|
) -> Result<Vec<Embedding>, EmbedError> {
|
|
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions), deadline)
|
|
}
|
|
|
|
pub fn embed_ref<S>(
|
|
&self,
|
|
texts: &[S],
|
|
deadline: Option<Instant>,
|
|
) -> Result<Vec<Embedding>, EmbedError>
|
|
where
|
|
S: AsRef<str> + Serialize,
|
|
{
|
|
embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline)
|
|
}
|
|
|
|
pub fn embed_tokens(
|
|
&self,
|
|
tokens: &[usize],
|
|
deadline: Option<Instant>,
|
|
) -> Result<Embedding, EmbedError> {
|
|
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions), deadline)?;
|
|
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
|
|
Ok(embeddings.pop().unwrap())
|
|
}
|
|
|
|
pub fn embed_chunks(
|
|
&self,
|
|
text_chunks: Vec<Vec<String>>,
|
|
threads: &ThreadPoolNoAbort,
|
|
) -> Result<Vec<Vec<Embedding>>, EmbedError> {
|
|
threads
|
|
.install(move || {
|
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk, None)).collect()
|
|
})
|
|
.map_err(|error| EmbedError {
|
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
|
fault: FaultSource::Bug,
|
|
})?
|
|
}
|
|
|
|
pub(crate) fn embed_chunks_ref(
|
|
&self,
|
|
texts: &[&str],
|
|
threads: &ThreadPoolNoAbort,
|
|
) -> Result<Vec<Embedding>, EmbedError> {
|
|
threads
|
|
.install(move || {
|
|
let embeddings: Result<Vec<Vec<Embedding>>, _> = texts
|
|
.par_chunks(self.prompt_count_in_chunk_hint())
|
|
.map(move |chunk| self.embed_ref(chunk, None))
|
|
.collect();
|
|
|
|
let embeddings = embeddings?;
|
|
Ok(embeddings.into_iter().flatten().collect())
|
|
})
|
|
.map_err(|error| EmbedError {
|
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
|
fault: FaultSource::Bug,
|
|
})?
|
|
}
|
|
|
|
pub fn chunk_count_hint(&self) -> usize {
|
|
super::REQUEST_PARALLELISM
|
|
}
|
|
|
|
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
|
match self.data.request.input_type() {
|
|
InputType::Text => 1,
|
|
InputType::TextArray => 10,
|
|
}
|
|
}
|
|
|
|
pub fn dimensions(&self) -> usize {
|
|
self.dimensions
|
|
}
|
|
|
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
|
self.distribution
|
|
}
|
|
}
|
|
|
|
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
|
let v = embed(data, ["test"].as_slice(), 1, None, None)
|
|
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
|
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
|
Ok(v.first().unwrap().len())
|
|
}
|
|
|
|
fn embed<S>(
|
|
data: &EmbedderData,
|
|
inputs: &[S],
|
|
expected_count: usize,
|
|
expected_dimension: Option<usize>,
|
|
deadline: Option<Instant>,
|
|
) -> Result<Vec<Embedding>, EmbedError>
|
|
where
|
|
S: Serialize,
|
|
{
|
|
let request = data.client.post(&data.url);
|
|
let request = if let Some(bearer) = &data.bearer {
|
|
request.set("Authorization", bearer)
|
|
} else {
|
|
request
|
|
};
|
|
let mut request = request.set("Content-Type", "application/json");
|
|
for (header, value) in &data.headers {
|
|
request = request.set(header.as_str(), value.as_str());
|
|
}
|
|
|
|
let body = data.request.inject_texts(inputs);
|
|
|
|
for attempt in 0..10 {
|
|
let response = request.clone().send_json(&body);
|
|
let result = check_response(response, data.configuration_source).and_then(|response| {
|
|
response_to_embedding(response, data, expected_count, expected_dimension)
|
|
});
|
|
|
|
let retry_duration = match result {
|
|
Ok(response) => return Ok(response),
|
|
Err(retry) => {
|
|
tracing::warn!("Failed: {}", retry.error);
|
|
if let Some(deadline) = deadline {
|
|
let now = std::time::Instant::now();
|
|
if now > deadline {
|
|
tracing::warn!("Could not embed due to deadline");
|
|
return Err(retry.into_error());
|
|
}
|
|
|
|
let duration_to_deadline = deadline - now;
|
|
retry.into_duration(attempt).map(|duration| duration.min(duration_to_deadline))
|
|
} else {
|
|
retry.into_duration(attempt)
|
|
}
|
|
}
|
|
}?;
|
|
|
|
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
|
|
|
|
// randomly up to double the retry duration
|
|
let retry_duration = retry_duration
|
|
+ rand::thread_rng().gen_range(std::time::Duration::ZERO..retry_duration);
|
|
|
|
tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
|
|
std::thread::sleep(retry_duration);
|
|
}
|
|
|
|
let response = request.send_json(&body);
|
|
let result = check_response(response, data.configuration_source);
|
|
result.map_err(Retry::into_error).and_then(|response| {
|
|
response_to_embedding(response, data, expected_count, expected_dimension)
|
|
.map_err(Retry::into_error)
|
|
})
|
|
}
|
|
|
|
fn check_response(
|
|
response: Result<ureq::Response, ureq::Error>,
|
|
configuration_source: ConfigurationSource,
|
|
) -> Result<ureq::Response, Retry> {
|
|
match response {
|
|
Ok(response) => Ok(response),
|
|
Err(ureq::Error::Status(code, response)) => {
|
|
let error_response: Option<String> = response.into_string().ok();
|
|
Err(match code {
|
|
401 => Retry::give_up(EmbedError::rest_unauthorized(
|
|
error_response,
|
|
configuration_source,
|
|
)),
|
|
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
|
|
400 => Retry::give_up(EmbedError::rest_bad_request(
|
|
error_response,
|
|
configuration_source,
|
|
)),
|
|
500..=599 => {
|
|
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
|
|
}
|
|
402..=499 => {
|
|
Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
|
|
}
|
|
_ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
|
|
})
|
|
}
|
|
Err(ureq::Error::Transport(transport)) => {
|
|
Err(Retry::retry_later(EmbedError::rest_network(transport)))
|
|
}
|
|
}
|
|
}
|
|
|
|
fn response_to_embedding(
|
|
response: ureq::Response,
|
|
data: &EmbedderData,
|
|
expected_count: usize,
|
|
expected_dimensions: Option<usize>,
|
|
) -> Result<Vec<Embedding>, Retry> {
|
|
let response: serde_json::Value = response
|
|
.into_json()
|
|
.map_err(EmbedError::rest_response_deserialization)
|
|
.map_err(Retry::retry_later)?;
|
|
|
|
let embeddings = data.response.extract_embeddings(response).map_err(Retry::give_up)?;
|
|
|
|
if embeddings.len() != expected_count {
|
|
return Err(Retry::give_up(EmbedError::rest_response_embedding_count(
|
|
expected_count,
|
|
embeddings.len(),
|
|
)));
|
|
}
|
|
|
|
if let Some(dimensions) = expected_dimensions {
|
|
for embedding in &embeddings {
|
|
if embedding.len() != dimensions {
|
|
return Err(Retry::give_up(EmbedError::rest_unexpected_dimension(
|
|
dimensions,
|
|
embedding.len(),
|
|
)));
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(embeddings)
|
|
}
|
|
|
|
pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
|
|
pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
|
|
pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
|
|
|
|
#[derive(Debug)]
|
|
pub struct Request {
|
|
template: ValueTemplate,
|
|
}
|
|
|
|
impl Request {
|
|
pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
|
|
let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) {
|
|
Ok(template) => template,
|
|
Err(error) => {
|
|
let message =
|
|
error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
|
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
|
}
|
|
};
|
|
|
|
Ok(Self { template })
|
|
}
|
|
|
|
fn input_type(&self) -> InputType {
|
|
if self.template.has_array_value() {
|
|
InputType::TextArray
|
|
} else {
|
|
InputType::Text
|
|
}
|
|
}
|
|
|
|
pub fn inject_texts<S: Serialize>(
|
|
&self,
|
|
texts: impl IntoIterator<Item = S>,
|
|
) -> serde_json::Value {
|
|
self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct Response {
|
|
template: ValueTemplate,
|
|
}
|
|
|
|
impl Response {
|
|
pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
|
|
let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
|
|
{
|
|
Ok(template) => template,
|
|
Err(error) => {
|
|
let message =
|
|
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
|
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
|
}
|
|
};
|
|
|
|
match (template.has_array_value(), request.template.has_array_value()) {
|
|
(true, true) | (false, false) => Ok(Self {template}),
|
|
(true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
|
|
(false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
|
|
}
|
|
}
|
|
|
|
pub fn extract_embeddings(
|
|
&self,
|
|
response: serde_json::Value,
|
|
) -> Result<Vec<Embedding>, EmbedError> {
|
|
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
|
|
Ok(extracted_values) => extracted_values,
|
|
Err(error) => {
|
|
let error_message =
|
|
error.error_message("response", "{{embedding}}", "an array of numbers");
|
|
return Err(EmbedError::rest_extraction_error(error_message));
|
|
}
|
|
};
|
|
let embeddings: Vec<Embedding> = extracted_values.into_iter().collect();
|
|
|
|
Ok(embeddings)
|
|
}
|
|
}
|