mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-01-23 19:57:30 +01:00
Add RestEmbedder
This commit is contained in:
parent
c3d02f092d
commit
8708cbef25
@ -83,6 +83,32 @@ pub enum EmbedErrorKind {
|
|||||||
OllamaModelNotFoundError(OllamaError),
|
OllamaModelNotFoundError(OllamaError),
|
||||||
#[error("received unhandled HTTP status code {0} from Ollama")]
|
#[error("received unhandled HTTP status code {0} from Ollama")]
|
||||||
OllamaUnhandledStatusCode(u16),
|
OllamaUnhandledStatusCode(u16),
|
||||||
|
#[error("error serializing template context: {0}")]
|
||||||
|
RestTemplateContextSerialization(liquid::Error),
|
||||||
|
#[error(
|
||||||
|
"error rendering request template: {0}. Hint: available variable in the context: {{{{input}}}}'"
|
||||||
|
)]
|
||||||
|
RestTemplateError(liquid::Error),
|
||||||
|
#[error("error deserialization the response body as JSON: {0}")]
|
||||||
|
RestResponseDeserialization(std::io::Error),
|
||||||
|
#[error("component `{0}` not found in path `{1}` in response: `{2}`")]
|
||||||
|
RestResponseMissingEmbeddings(String, String, String),
|
||||||
|
#[error("expected a response parseable as a vector or an array of vectors: {0}")]
|
||||||
|
RestResponseFormat(serde_json::Error),
|
||||||
|
#[error("expected a response containing {0} embeddings, got only {1}")]
|
||||||
|
RestResponseEmbeddingCount(usize, usize),
|
||||||
|
#[error("could not authenticate against embedding server: {0:?}")]
|
||||||
|
RestUnauthorized(Option<String>),
|
||||||
|
#[error("sent too many requests to embedding server: {0:?}")]
|
||||||
|
RestTooManyRequests(Option<String>),
|
||||||
|
#[error("sent a bad request to embedding server: {0:?}")]
|
||||||
|
RestBadRequest(Option<String>),
|
||||||
|
#[error("received internal error from embedding server: {0:?}")]
|
||||||
|
RestInternalServerError(u16, Option<String>),
|
||||||
|
#[error("received HTTP {0} from embedding server: {0:?}")]
|
||||||
|
RestOtherStatusCode(u16, Option<String>),
|
||||||
|
#[error("could not reach embedding server: {0}")]
|
||||||
|
RestNetwork(ureq::Transport),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedError {
|
impl EmbedError {
|
||||||
@ -161,6 +187,89 @@ impl EmbedError {
|
|||||||
pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError {
|
pub(crate) fn ollama_unhandled_status_code(code: u16) -> EmbedError {
|
||||||
Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug }
|
Self { kind: EmbedErrorKind::OllamaUnhandledStatusCode(code), fault: FaultSource::Bug }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_template_context_serialization(error: liquid::Error) -> EmbedError {
|
||||||
|
Self {
|
||||||
|
kind: EmbedErrorKind::RestTemplateContextSerialization(error),
|
||||||
|
fault: FaultSource::Bug,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_template_render(error: liquid::Error) -> EmbedError {
|
||||||
|
Self { kind: EmbedErrorKind::RestTemplateError(error), fault: FaultSource::User }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
|
||||||
|
Self {
|
||||||
|
kind: EmbedErrorKind::RestResponseDeserialization(error),
|
||||||
|
fault: FaultSource::Runtime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_response_missing_embeddings<S: AsRef<str>>(
|
||||||
|
response: serde_json::Value,
|
||||||
|
component: &str,
|
||||||
|
response_field: &[S],
|
||||||
|
) -> EmbedError {
|
||||||
|
let response_field: Vec<&str> = response_field.iter().map(AsRef::as_ref).collect();
|
||||||
|
let response_field = response_field.join(".");
|
||||||
|
|
||||||
|
Self {
|
||||||
|
kind: EmbedErrorKind::RestResponseMissingEmbeddings(
|
||||||
|
component.to_owned(),
|
||||||
|
response_field,
|
||||||
|
serde_json::to_string_pretty(&response).unwrap_or_default(),
|
||||||
|
),
|
||||||
|
fault: FaultSource::Undecided,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_response_format(error: serde_json::Error) -> EmbedError {
|
||||||
|
Self { kind: EmbedErrorKind::RestResponseFormat(error), fault: FaultSource::Undecided }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_response_embedding_count(expected: usize, got: usize) -> EmbedError {
|
||||||
|
Self {
|
||||||
|
kind: EmbedErrorKind::RestResponseEmbeddingCount(expected, got),
|
||||||
|
fault: FaultSource::Runtime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_unauthorized(error_response: Option<String>) -> EmbedError {
|
||||||
|
Self { kind: EmbedErrorKind::RestUnauthorized(error_response), fault: FaultSource::User }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError {
|
||||||
|
Self {
|
||||||
|
kind: EmbedErrorKind::RestTooManyRequests(error_response),
|
||||||
|
fault: FaultSource::Runtime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_bad_request(error_response: Option<String>) -> EmbedError {
|
||||||
|
Self { kind: EmbedErrorKind::RestBadRequest(error_response), fault: FaultSource::User }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_internal_server_error(
|
||||||
|
code: u16,
|
||||||
|
error_response: Option<String>,
|
||||||
|
) -> EmbedError {
|
||||||
|
Self {
|
||||||
|
kind: EmbedErrorKind::RestInternalServerError(code, error_response),
|
||||||
|
fault: FaultSource::Runtime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError {
|
||||||
|
Self {
|
||||||
|
kind: EmbedErrorKind::RestOtherStatusCode(code, error_response),
|
||||||
|
fault: FaultSource::Undecided,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
|
||||||
|
Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
@ -11,6 +11,7 @@ pub mod openai;
|
|||||||
pub mod settings;
|
pub mod settings;
|
||||||
|
|
||||||
pub mod ollama;
|
pub mod ollama;
|
||||||
|
pub mod rest;
|
||||||
|
|
||||||
pub use self::error::Error;
|
pub use self::error::Error;
|
||||||
|
|
||||||
|
185
milli/src/vector/rest.rs
Normal file
185
milli/src/vector/rest.rs
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
|
|
||||||
|
use super::openai::Retry;
|
||||||
|
use super::{DistributionShift, EmbedError, Embeddings, NewEmbedderError};
|
||||||
|
use crate::VectorOrArrayOfVectors;
|
||||||
|
|
||||||
|
pub struct Embedder {
|
||||||
|
client: ureq::Agent,
|
||||||
|
options: EmbedderOptions,
|
||||||
|
bearer: Option<String>,
|
||||||
|
dimensions: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EmbedderOptions {
|
||||||
|
api_key: Option<String>,
|
||||||
|
distribution: Option<DistributionShift>,
|
||||||
|
dimensions: Option<usize>,
|
||||||
|
url: String,
|
||||||
|
query: liquid::Template,
|
||||||
|
response_field: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Embedder {
|
||||||
|
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||||
|
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer: {api_key}"));
|
||||||
|
|
||||||
|
let client = ureq::agent();
|
||||||
|
|
||||||
|
let dimensions = if let Some(dimensions) = options.dimensions {
|
||||||
|
dimensions
|
||||||
|
} else {
|
||||||
|
infer_dimensions(&client, &options, bearer.as_deref())?
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self { client, dimensions, options, bearer })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
|
embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_chunks(
|
||||||
|
&self,
|
||||||
|
text_chunks: Vec<Vec<String>>,
|
||||||
|
threads: &rayon::ThreadPool,
|
||||||
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
|
threads
|
||||||
|
.install(move || text_chunks.into_par_iter().map(|chunk| self.embed(chunk)))
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
|
10
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||||
|
10
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dimensions(&self) -> usize {
|
||||||
|
self.dimensions
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
|
self.options.distribution
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_dimensions(
|
||||||
|
client: &ureq::Agent,
|
||||||
|
options: &EmbedderOptions,
|
||||||
|
bearer: Option<&str>,
|
||||||
|
) -> Result<usize, NewEmbedderError> {
|
||||||
|
let v = embed(client, options, bearer, ["test"].as_slice())
|
||||||
|
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||||
|
// unwrap: guaranteed that v.len() == ["test"].len() == 1, otherwise the previous line terminated in error
|
||||||
|
Ok(v.first().unwrap().dimension())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn embed<S>(
|
||||||
|
client: &ureq::Agent,
|
||||||
|
options: &EmbedderOptions,
|
||||||
|
bearer: Option<&str>,
|
||||||
|
inputs: &[S],
|
||||||
|
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||||
|
where
|
||||||
|
S: serde::Serialize,
|
||||||
|
{
|
||||||
|
let request = client.post(&options.url);
|
||||||
|
let request =
|
||||||
|
if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
|
||||||
|
let request = request.set("Content-Type", "application/json");
|
||||||
|
|
||||||
|
let body = options
|
||||||
|
.query
|
||||||
|
.render(
|
||||||
|
&liquid::to_object(&serde_json::json!({
|
||||||
|
"input": inputs,
|
||||||
|
}))
|
||||||
|
.map_err(EmbedError::rest_template_context_serialization)?,
|
||||||
|
)
|
||||||
|
.map_err(EmbedError::rest_template_render)?;
|
||||||
|
|
||||||
|
for attempt in 0..7 {
|
||||||
|
let response = request.send_string(&body);
|
||||||
|
let result = check_response(response);
|
||||||
|
|
||||||
|
let retry_duration = match result {
|
||||||
|
Ok(response) => {
|
||||||
|
return response_to_embedding(response, &options.response_field, inputs.len())
|
||||||
|
}
|
||||||
|
Err(retry) => {
|
||||||
|
tracing::warn!("Failed: {}", retry.error);
|
||||||
|
retry.into_duration(attempt)
|
||||||
|
}
|
||||||
|
}?;
|
||||||
|
|
||||||
|
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
|
||||||
|
tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
|
||||||
|
std::thread::sleep(retry_duration);
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = request.send_string(&body);
|
||||||
|
let result = check_response(response);
|
||||||
|
result
|
||||||
|
.map_err(Retry::into_error)
|
||||||
|
.and_then(|response| response_to_embedding(response, &options.response_field, inputs.len()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
|
||||||
|
match response {
|
||||||
|
Ok(response) => Ok(response),
|
||||||
|
Err(ureq::Error::Status(code, response)) => {
|
||||||
|
let error_response: Option<String> = response.into_string().ok();
|
||||||
|
Err(match code {
|
||||||
|
401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
|
||||||
|
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
|
||||||
|
400 => Retry::give_up(EmbedError::rest_bad_request(error_response)),
|
||||||
|
500..=599 => {
|
||||||
|
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
|
||||||
|
}
|
||||||
|
x => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
Err(ureq::Error::Transport(transport)) => {
|
||||||
|
Err(Retry::retry_later(EmbedError::rest_network(transport)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn response_to_embedding<S: AsRef<str>>(
|
||||||
|
response: ureq::Response,
|
||||||
|
response_field: &[S],
|
||||||
|
expected_count: usize,
|
||||||
|
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
|
let response: serde_json::Value =
|
||||||
|
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
|
||||||
|
|
||||||
|
let mut current_value = &response;
|
||||||
|
for component in response_field {
|
||||||
|
let component = component.as_ref();
|
||||||
|
let current_value = current_value.get(component).ok_or_else(|| {
|
||||||
|
EmbedError::rest_response_missing_embeddings(response, component, response_field)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let embeddings = current_value.to_owned();
|
||||||
|
|
||||||
|
let embeddings: VectorOrArrayOfVectors =
|
||||||
|
serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
|
||||||
|
|
||||||
|
let embeddings = embeddings.into_array_of_vectors();
|
||||||
|
|
||||||
|
let embeddings: Vec<Embeddings<f32>> = embeddings
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.map(|embedding| Embeddings::from_single_embedding(embedding))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if embeddings.len() != expected_count {
|
||||||
|
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(embeddings)
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user