mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-09 22:48:54 +01:00
rest embedder: use json_template
This commit is contained in:
parent
4109182ca4
commit
a1beddd5d9
@ -228,7 +228,9 @@ impl Embedder {
|
|||||||
EmbedderOptions::UserProvided(options) => {
|
EmbedderOptions::UserProvided(options) => {
|
||||||
Self::UserProvided(manual::Embedder::new(options))
|
Self::UserProvided(manual::Embedder::new(options))
|
||||||
}
|
}
|
||||||
EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?),
|
EmbedderOptions::Rest(options) => {
|
||||||
|
Self::Rest(rest::Embedder::new(options, rest::ConfigurationSource::User)?)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
use super::error::EmbedErrorKind;
|
use super::error::EmbedErrorKind;
|
||||||
|
use super::json_template::ValueTemplate;
|
||||||
use super::{
|
use super::{
|
||||||
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
||||||
};
|
};
|
||||||
@ -11,12 +12,18 @@ use crate::error::FaultSource;
|
|||||||
use crate::ThreadPoolNoAbort;
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
// retrying in case of failure
|
// retrying in case of failure
|
||||||
|
|
||||||
pub struct Retry {
|
pub struct Retry {
|
||||||
pub error: EmbedError,
|
pub error: EmbedError,
|
||||||
strategy: RetryStrategy,
|
strategy: RetryStrategy,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum ConfigurationSource {
|
||||||
|
OpenAi,
|
||||||
|
Ollama,
|
||||||
|
User,
|
||||||
|
}
|
||||||
|
|
||||||
pub enum RetryStrategy {
|
pub enum RetryStrategy {
|
||||||
GiveUp,
|
GiveUp,
|
||||||
Retry,
|
Retry,
|
||||||
@ -63,10 +70,20 @@ impl Retry {
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Embedder {
|
pub struct Embedder {
|
||||||
client: ureq::Agent,
|
data: EmbedderData,
|
||||||
options: EmbedderOptions,
|
|
||||||
bearer: Option<String>,
|
|
||||||
dimensions: usize,
|
dimensions: usize,
|
||||||
|
distribution: Option<DistributionShift>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// All data needed to perform requests and parse responses
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct EmbedderData {
|
||||||
|
client: ureq::Agent,
|
||||||
|
bearer: Option<String>,
|
||||||
|
url: String,
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
configuration_source: ConfigurationSource,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
|
||||||
@ -75,29 +92,8 @@ pub struct EmbedderOptions {
|
|||||||
pub distribution: Option<DistributionShift>,
|
pub distribution: Option<DistributionShift>,
|
||||||
pub dimensions: Option<usize>,
|
pub dimensions: Option<usize>,
|
||||||
pub url: String,
|
pub url: String,
|
||||||
pub query: serde_json::Value,
|
pub request: serde_json::Value,
|
||||||
pub input_field: Vec<String>,
|
pub response: serde_json::Value,
|
||||||
// path to the array of embeddings
|
|
||||||
pub path_to_embeddings: Vec<String>,
|
|
||||||
// shape of a single embedding
|
|
||||||
pub embedding_object: Vec<String>,
|
|
||||||
pub input_type: InputType,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Default for EmbedderOptions {
|
|
||||||
fn default() -> Self {
|
|
||||||
Self {
|
|
||||||
url: Default::default(),
|
|
||||||
query: Default::default(),
|
|
||||||
input_field: vec!["input".into()],
|
|
||||||
path_to_embeddings: vec!["data".into()],
|
|
||||||
embedding_object: vec!["embedding".into()],
|
|
||||||
input_type: InputType::Text,
|
|
||||||
api_key: None,
|
|
||||||
distribution: None,
|
|
||||||
dimensions: None,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl std::hash::Hash for EmbedderOptions {
|
impl std::hash::Hash for EmbedderOptions {
|
||||||
@ -106,26 +102,25 @@ impl std::hash::Hash for EmbedderOptions {
|
|||||||
self.distribution.hash(state);
|
self.distribution.hash(state);
|
||||||
self.dimensions.hash(state);
|
self.dimensions.hash(state);
|
||||||
self.url.hash(state);
|
self.url.hash(state);
|
||||||
// skip hashing the query
|
// skip hashing the request and response
|
||||||
// collisions in regular usage should be minimal,
|
// collisions in regular usage should be minimal,
|
||||||
// and the list is limited to 256 values anyway
|
// and the list is limited to 256 values anyway
|
||||||
self.input_field.hash(state);
|
|
||||||
self.path_to_embeddings.hash(state);
|
|
||||||
self.embedding_object.hash(state);
|
|
||||||
self.input_type.hash(state);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
|
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
|
||||||
#[serde(rename_all = "camelCase")]
|
#[serde(rename_all = "camelCase")]
|
||||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||||
pub enum InputType {
|
enum InputType {
|
||||||
Text,
|
Text,
|
||||||
TextArray,
|
TextArray,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Embedder {
|
impl Embedder {
|
||||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
pub fn new(
|
||||||
|
options: EmbedderOptions,
|
||||||
|
configuration_source: ConfigurationSource,
|
||||||
|
) -> Result<Self, NewEmbedderError> {
|
||||||
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
|
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
|
||||||
|
|
||||||
let client = ureq::AgentBuilder::new()
|
let client = ureq::AgentBuilder::new()
|
||||||
@ -133,28 +128,40 @@ impl Embedder {
|
|||||||
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
|
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
|
let request = Request::new(options.request)?;
|
||||||
|
let response = Response::new(options.response, &request)?;
|
||||||
|
|
||||||
|
let data = EmbedderData {
|
||||||
|
client,
|
||||||
|
bearer,
|
||||||
|
url: options.url,
|
||||||
|
request,
|
||||||
|
response,
|
||||||
|
configuration_source,
|
||||||
|
};
|
||||||
|
|
||||||
let dimensions = if let Some(dimensions) = options.dimensions {
|
let dimensions = if let Some(dimensions) = options.dimensions {
|
||||||
dimensions
|
dimensions
|
||||||
} else {
|
} else {
|
||||||
infer_dimensions(&client, &options, bearer.as_deref())?
|
infer_dimensions(&data)?
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self { client, dimensions, options, bearer })
|
Ok(Self { data, dimensions, distribution: options.distribution })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len())
|
embed(&self.data, texts.as_slice(), texts.len(), Some(self.dimensions))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||||
where
|
where
|
||||||
S: AsRef<str> + Serialize,
|
S: AsRef<str> + Serialize,
|
||||||
{
|
{
|
||||||
embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len())
|
embed(&self.data, texts, texts.len(), Some(self.dimensions))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
|
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
|
||||||
let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?;
|
let mut embeddings = embed(&self.data, tokens, 1, Some(self.dimensions))?;
|
||||||
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
|
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
|
||||||
Ok(embeddings.pop().unwrap())
|
Ok(embeddings.pop().unwrap())
|
||||||
}
|
}
|
||||||
@ -179,7 +186,7 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||||
match self.options.input_type {
|
match self.data.request.input_type() {
|
||||||
InputType::Text => 1,
|
InputType::Text => 1,
|
||||||
InputType::TextArray => 10,
|
InputType::TextArray => 10,
|
||||||
}
|
}
|
||||||
@ -190,87 +197,44 @@ impl Embedder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||||
self.options.distribution
|
self.distribution
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_dimensions(
|
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
||||||
client: &ureq::Agent,
|
let v = embed(data, ["test"].as_slice(), 1, None)
|
||||||
options: &EmbedderOptions,
|
|
||||||
bearer: Option<&str>,
|
|
||||||
) -> Result<usize, NewEmbedderError> {
|
|
||||||
let v = embed(client, options, bearer, ["test"].as_slice(), 1)
|
|
||||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||||
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
||||||
Ok(v.first().unwrap().dimension())
|
Ok(v.first().unwrap().dimension())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn embed<S>(
|
fn embed<S>(
|
||||||
client: &ureq::Agent,
|
data: &EmbedderData,
|
||||||
options: &EmbedderOptions,
|
|
||||||
bearer: Option<&str>,
|
|
||||||
inputs: &[S],
|
inputs: &[S],
|
||||||
expected_count: usize,
|
expected_count: usize,
|
||||||
|
expected_dimension: Option<usize>,
|
||||||
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||||
where
|
where
|
||||||
S: Serialize,
|
S: Serialize,
|
||||||
{
|
{
|
||||||
let request = client.post(&options.url);
|
let request = data.client.post(&data.url);
|
||||||
let request =
|
let request = if let Some(bearer) = &data.bearer {
|
||||||
if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
|
request.set("Authorization", bearer)
|
||||||
|
} else {
|
||||||
|
request
|
||||||
|
};
|
||||||
let request = request.set("Content-Type", "application/json");
|
let request = request.set("Content-Type", "application/json");
|
||||||
|
|
||||||
let input_value = match options.input_type {
|
let body = data.request.inject_texts(inputs);
|
||||||
InputType::Text => serde_json::json!(inputs.first()),
|
|
||||||
InputType::TextArray => serde_json::json!(inputs),
|
|
||||||
};
|
|
||||||
|
|
||||||
let body = match options.input_field.as_slice() {
|
|
||||||
[] => {
|
|
||||||
// inject input in body
|
|
||||||
input_value
|
|
||||||
}
|
|
||||||
[input] => {
|
|
||||||
let mut body = options.query.clone();
|
|
||||||
|
|
||||||
body.as_object_mut()
|
|
||||||
.ok_or_else(|| {
|
|
||||||
EmbedError::rest_not_an_object(
|
|
||||||
options.query.clone(),
|
|
||||||
options.input_field.clone(),
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.insert(input.clone(), input_value);
|
|
||||||
body
|
|
||||||
}
|
|
||||||
[path @ .., input] => {
|
|
||||||
let mut body = options.query.clone();
|
|
||||||
|
|
||||||
let mut current_value = &mut body;
|
|
||||||
for component in path {
|
|
||||||
current_value = current_value
|
|
||||||
.as_object_mut()
|
|
||||||
.ok_or_else(|| {
|
|
||||||
EmbedError::rest_not_an_object(
|
|
||||||
options.query.clone(),
|
|
||||||
options.input_field.clone(),
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.entry(component.clone())
|
|
||||||
.or_insert(serde_json::json!({}));
|
|
||||||
}
|
|
||||||
|
|
||||||
current_value.as_object_mut().unwrap().insert(input.clone(), input_value);
|
|
||||||
body
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
for attempt in 0..10 {
|
for attempt in 0..10 {
|
||||||
let response = request.clone().send_json(&body);
|
let response = request.clone().send_json(&body);
|
||||||
let result = check_response(response);
|
let result = check_response(response, data.configuration_source);
|
||||||
|
|
||||||
let retry_duration = match result {
|
let retry_duration = match result {
|
||||||
Ok(response) => return response_to_embedding(response, options, expected_count),
|
Ok(response) => {
|
||||||
|
return response_to_embedding(response, data, expected_count, expected_dimension)
|
||||||
|
}
|
||||||
Err(retry) => {
|
Err(retry) => {
|
||||||
tracing::warn!("Failed: {}", retry.error);
|
tracing::warn!("Failed: {}", retry.error);
|
||||||
retry.into_duration(attempt)
|
retry.into_duration(attempt)
|
||||||
@ -288,13 +252,16 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
let response = request.send_json(&body);
|
let response = request.send_json(&body);
|
||||||
let result = check_response(response);
|
let result = check_response(response, data.configuration_source);
|
||||||
result
|
result.map_err(Retry::into_error).and_then(|response| {
|
||||||
.map_err(Retry::into_error)
|
response_to_embedding(response, data, expected_count, expected_dimension)
|
||||||
.and_then(|response| response_to_embedding(response, options, expected_count))
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
|
fn check_response(
|
||||||
|
response: Result<ureq::Response, ureq::Error>,
|
||||||
|
configuration_source: ConfigurationSource,
|
||||||
|
) -> Result<ureq::Response, Retry> {
|
||||||
match response {
|
match response {
|
||||||
Ok(response) => Ok(response),
|
Ok(response) => Ok(response),
|
||||||
Err(ureq::Error::Status(code, response)) => {
|
Err(ureq::Error::Status(code, response)) => {
|
||||||
@ -302,7 +269,10 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq:
|
|||||||
Err(match code {
|
Err(match code {
|
||||||
401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
|
401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
|
||||||
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
|
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
|
||||||
400 => Retry::give_up(EmbedError::rest_bad_request(error_response)),
|
400 => Retry::give_up(EmbedError::rest_bad_request(
|
||||||
|
error_response,
|
||||||
|
configuration_source,
|
||||||
|
)),
|
||||||
500..=599 => {
|
500..=599 => {
|
||||||
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
|
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
|
||||||
}
|
}
|
||||||
@ -320,68 +290,111 @@ fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq:
|
|||||||
|
|
||||||
fn response_to_embedding(
|
fn response_to_embedding(
|
||||||
response: ureq::Response,
|
response: ureq::Response,
|
||||||
options: &EmbedderOptions,
|
data: &EmbedderData,
|
||||||
expected_count: usize,
|
expected_count: usize,
|
||||||
|
expected_dimensions: Option<usize>,
|
||||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
let response: serde_json::Value =
|
let response: serde_json::Value =
|
||||||
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
|
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
|
||||||
|
|
||||||
let mut current_value = &response;
|
let embeddings = data.response.extract_embeddings(response)?;
|
||||||
for component in &options.path_to_embeddings {
|
|
||||||
let component = component.as_ref();
|
|
||||||
current_value = current_value.get(component).ok_or_else(|| {
|
|
||||||
EmbedError::rest_response_missing_embeddings(
|
|
||||||
response.clone(),
|
|
||||||
component,
|
|
||||||
&options.path_to_embeddings,
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let embeddings = match options.input_type {
|
|
||||||
InputType::Text => {
|
|
||||||
for component in &options.embedding_object {
|
|
||||||
current_value = current_value.get(component).ok_or_else(|| {
|
|
||||||
EmbedError::rest_response_missing_embeddings(
|
|
||||||
response.clone(),
|
|
||||||
component,
|
|
||||||
&options.embedding_object,
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
let embeddings = current_value.to_owned();
|
|
||||||
let embeddings: Embedding =
|
|
||||||
serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
|
|
||||||
|
|
||||||
vec![Embeddings::from_single_embedding(embeddings)]
|
|
||||||
}
|
|
||||||
InputType::TextArray => {
|
|
||||||
let empty = vec![];
|
|
||||||
let values = current_value.as_array().unwrap_or(&empty);
|
|
||||||
let mut embeddings: Vec<Embeddings<f32>> = Vec::with_capacity(expected_count);
|
|
||||||
for value in values {
|
|
||||||
let mut current_value = value;
|
|
||||||
for component in &options.embedding_object {
|
|
||||||
current_value = current_value.get(component).ok_or_else(|| {
|
|
||||||
EmbedError::rest_response_missing_embeddings(
|
|
||||||
response.clone(),
|
|
||||||
component,
|
|
||||||
&options.embedding_object,
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
}
|
|
||||||
let embedding = current_value.to_owned();
|
|
||||||
let embedding: Embedding =
|
|
||||||
serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?;
|
|
||||||
embeddings.push(Embeddings::from_single_embedding(embedding));
|
|
||||||
}
|
|
||||||
embeddings
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if embeddings.len() != expected_count {
|
if embeddings.len() != expected_count {
|
||||||
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
|
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(dimensions) = expected_dimensions {
|
||||||
|
for embedding in &embeddings {
|
||||||
|
if embedding.dimension() != dimensions {
|
||||||
|
return Err(EmbedError::rest_unexpected_dimension(
|
||||||
|
dimensions,
|
||||||
|
embedding.dimension(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(embeddings)
|
Ok(embeddings)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
|
||||||
|
pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
|
||||||
|
pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Request {
|
||||||
|
template: ValueTemplate,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Request {
|
||||||
|
pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
|
||||||
|
let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) {
|
||||||
|
Ok(template) => template,
|
||||||
|
Err(error) => {
|
||||||
|
let message =
|
||||||
|
error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
||||||
|
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self { template })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_type(&self) -> InputType {
|
||||||
|
if self.template.has_array_value() {
|
||||||
|
InputType::TextArray
|
||||||
|
} else {
|
||||||
|
InputType::Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn inject_texts<S: Serialize>(
|
||||||
|
&self,
|
||||||
|
texts: impl IntoIterator<Item = S>,
|
||||||
|
) -> serde_json::Value {
|
||||||
|
self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Response {
|
||||||
|
template: ValueTemplate,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Response {
|
||||||
|
pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
|
||||||
|
let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
|
||||||
|
{
|
||||||
|
Ok(template) => template,
|
||||||
|
Err(error) => {
|
||||||
|
let message =
|
||||||
|
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
||||||
|
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
match (template.has_array_value(), request.template.has_array_value()) {
|
||||||
|
(true, true) | (false, false) => Ok(Self {template}),
|
||||||
|
(true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
|
||||||
|
(false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn extract_embeddings(
|
||||||
|
&self,
|
||||||
|
response: serde_json::Value,
|
||||||
|
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||||
|
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
|
||||||
|
Ok(extracted_values) => extracted_values,
|
||||||
|
Err(error) => {
|
||||||
|
let error_message =
|
||||||
|
error.error_message("response", "{{embedding}}", "an array of numbers");
|
||||||
|
return Err(EmbedError::rest_extraction_error(error_message));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let embeddings: Vec<Embeddings<f32>> =
|
||||||
|
extracted_values.into_iter().map(Embeddings::from_single_embedding).collect();
|
||||||
|
|
||||||
|
Ok(embeddings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user