REST embedder supports fragments

This commit is contained in:
Louis Dureuil 2025-06-29 23:54:06 +02:00
parent e7b9b8f002
commit 4235a82dcf
No known key found for this signature in database

View file

@ -6,11 +6,13 @@ use rand::Rng;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use rayon::slice::ParallelSlice as _;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use super::error::EmbedErrorKind;
use super::json_template::ValueTemplate;
use super::json_template::{InjectableValue, JsonTemplate};
use super::{
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, REQUEST_PARALLELISM,
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, SearchQuery,
REQUEST_PARALLELISM,
};
use crate::error::FaultSource;
use crate::progress::EmbedderStats;
@ -88,19 +90,54 @@ struct EmbedderData {
bearer: Option<String>,
headers: BTreeMap<String, String>,
url: String,
request: Request,
request: RequestData,
response: Response,
configuration_source: ConfigurationSource,
}
#[derive(Debug)]
pub enum RequestData {
Single(Request),
FromFragments(RequestFromFragments),
}
impl RequestData {
pub fn new(
request: Value,
indexing_fragments: BTreeMap<String, Value>,
search_fragments: BTreeMap<String, Value>,
) -> Result<Self, NewEmbedderError> {
Ok(if indexing_fragments.is_empty() && search_fragments.is_empty() {
RequestData::Single(Request::new(request)?)
} else {
RequestData::FromFragments(RequestFromFragments::new(request, search_fragments)?)
})
}
fn input_type(&self) -> InputType {
match self {
RequestData::Single(request) => request.input_type(),
RequestData::FromFragments(request_from_fragments) => {
request_from_fragments.input_type()
}
}
}
fn has_fragments(&self) -> bool {
matches!(self, RequestData::FromFragments(_))
}
}
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
pub struct EmbedderOptions {
pub api_key: Option<String>,
pub distribution: Option<DistributionShift>,
pub dimensions: Option<usize>,
pub url: String,
pub request: serde_json::Value,
pub response: serde_json::Value,
pub request: Value,
pub search_fragments: BTreeMap<String, Value>,
pub indexing_fragments: BTreeMap<String, Value>,
pub response: Value,
pub headers: BTreeMap<String, String>,
}
@ -138,7 +175,12 @@ impl Embedder {
.timeout(std::time::Duration::from_secs(30))
.build();
let request = Request::new(options.request)?;
let request = RequestData::new(
options.request,
options.indexing_fragments,
options.search_fragments,
)?;
let response = Response::new(options.response, &request)?;
let data = EmbedderData {
@ -188,7 +230,7 @@ impl Embedder {
embedder_stats: Option<&EmbedderStats>,
) -> Result<Vec<Embedding>, EmbedError>
where
S: AsRef<str> + Serialize,
S: Serialize,
{
embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline, embedder_stats)
}
@ -231,9 +273,9 @@ impl Embedder {
}
}
pub(crate) fn embed_index_ref(
pub(crate) fn embed_index_ref<S: Serialize + Sync>(
&self,
texts: &[&str],
texts: &[S],
threads: &ThreadPoolNoAbort,
embedder_stats: &EmbedderStats,
) -> Result<Vec<Embedding>, EmbedError> {
@ -287,9 +329,44 @@ impl Embedder {
pub(super) fn cache(&self) -> &EmbeddingCache {
&self.cache
}
pub(crate) fn embed_one(
&self,
query: SearchQuery,
deadline: Option<Instant>,
embedder_stats: Option<&EmbedderStats>,
) -> Result<Embedding, EmbedError> {
let mut embeddings = match (&self.data.request, query) {
(RequestData::Single(_), SearchQuery::Text(text)) => {
embed(&self.data, &[text], 1, Some(self.dimensions), deadline, embedder_stats)
}
(RequestData::Single(_), SearchQuery::Media { q: _, media: _ }) => {
return Err(EmbedError::rest_media_not_a_fragment())
}
(RequestData::FromFragments(request_from_fragments), SearchQuery::Text(q)) => {
let fragment = request_from_fragments.render_search_fragment(Some(q), None)?;
embed(&self.data, &[fragment], 1, Some(self.dimensions), deadline, embedder_stats)
}
(
RequestData::FromFragments(request_from_fragments),
SearchQuery::Media { q, media },
) => {
let fragment = request_from_fragments.render_search_fragment(q, media)?;
embed(&self.data, &[fragment], 1, Some(self.dimensions), deadline, embedder_stats)
}
}?;
// unwrap: checked by `expected_count`
Ok(embeddings.pop().unwrap())
}
}
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
if data.request.has_fragments() {
return Err(NewEmbedderError::rest_cannot_infer_dimensions_for_fragment());
}
let v = embed(data, ["test"].as_slice(), 1, None, None, None)
.map_err(NewEmbedderError::could_not_determine_dimension)?;
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
@ -307,6 +384,13 @@ fn embed<S>(
where
S: Serialize,
{
if inputs.is_empty() {
if expected_count != 0 {
return Err(EmbedError::rest_response_embedding_count(expected_count, 0));
}
return Ok(Vec::new());
}
let request = data.client.post(&data.url);
let request = if let Some(bearer) = &data.bearer {
request.set("Authorization", bearer)
@ -318,7 +402,12 @@ where
request = request.set(header.as_str(), value.as_str());
}
let body = data.request.inject_texts(inputs);
let body = match &data.request {
RequestData::Single(request) => request.inject_texts(inputs),
RequestData::FromFragments(request_from_fragments) => {
request_from_fragments.request_from_fragments(inputs).expect("inputs was empty")
}
};
for attempt in 0..10 {
if let Some(embedder_stats) = &embedder_stats {
@ -426,7 +515,7 @@ fn response_to_embedding(
expected_count: usize,
expected_dimensions: Option<usize>,
) -> Result<Vec<Embedding>, Retry> {
let response: serde_json::Value = response
let response: Value = response
.into_json()
.map_err(EmbedError::rest_response_deserialization)
.map_err(Retry::retry_later)?;
@ -455,17 +544,19 @@ fn response_to_embedding(
}
pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
pub(super) const REQUEST_FRAGMENT_PLACEHOLDER: &str = "{{fragment}}";
pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
#[derive(Debug)]
pub struct Request {
template: ValueTemplate,
template: InjectableValue,
}
impl Request {
pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) {
pub fn new(template: Value) -> Result<Self, NewEmbedderError> {
let template = match InjectableValue::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER)
{
Ok(template) => template,
Err(error) => {
let message =
@ -485,42 +576,114 @@ impl Request {
}
}
pub fn inject_texts<S: Serialize>(
&self,
texts: impl IntoIterator<Item = S>,
) -> serde_json::Value {
pub fn inject_texts<S: Serialize>(&self, texts: impl IntoIterator<Item = S>) -> Value {
self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
}
}
#[derive(Debug)]
pub struct RequestFromFragments {
search_fragments: BTreeMap<String, JsonTemplate>,
request: InjectableValue,
}
impl RequestFromFragments {
pub fn new(
request: Value,
search_fragments: impl IntoIterator<Item = (String, Value)>,
) -> Result<Self, NewEmbedderError> {
let request =
match InjectableValue::new(request, REQUEST_FRAGMENT_PLACEHOLDER, REPEAT_PLACEHOLDER) {
Ok(template) => template,
Err(error) => {
let message =
error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
return Err(NewEmbedderError::rest_could_not_parse_template(message));
}
};
let search_fragments: Result<_, NewEmbedderError> = search_fragments
.into_iter()
.map(|(name, value)| {
Ok((
name,
JsonTemplate::new(value).map_err(|error| {
NewEmbedderError::rest_could_not_parse_template(
error.parsing("searchFragments"),
)
})?,
))
})
.collect();
Ok(Self { request, search_fragments: search_fragments? })
}
fn input_type(&self) -> InputType {
if self.request.has_array_value() {
InputType::TextArray
} else {
InputType::Text
}
}
pub fn render_search_fragment(
&self,
q: Option<&str>,
media: Option<&Value>,
) -> Result<Value, EmbedError> {
let mut it = self.search_fragments.iter().filter_map(|(name, template)| {
let render = template.render_search(q, media).ok()?;
Some((name, render))
});
let Some((name, fragment)) = it.next() else {
return Err(EmbedError::rest_search_matches_no_fragment(q, media));
};
if let Some((second_name, _)) = it.next() {
return Err(EmbedError::rest_search_matches_multiple_fragments(
name,
second_name,
q,
media,
));
}
Ok(fragment)
}
pub fn request_from_fragments<'a, S: Serialize + 'a>(
&self,
fragments: impl IntoIterator<Item = &'a S>,
) -> Option<Value> {
self.request.inject(fragments.into_iter().map(|fragment| serde_json::json!(fragment))).ok()
}
}
#[derive(Debug)]
pub struct Response {
template: ValueTemplate,
template: InjectableValue,
}
impl Response {
pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
{
Ok(template) => template,
Err(error) => {
let message =
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
return Err(NewEmbedderError::rest_could_not_parse_template(message));
}
};
pub fn new(template: Value, request: &RequestData) -> Result<Self, NewEmbedderError> {
let template =
match InjectableValue::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER) {
Ok(template) => template,
Err(error) => {
let message =
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
return Err(NewEmbedderError::rest_could_not_parse_template(message));
}
};
match (template.has_array_value(), request.template.has_array_value()) {
match (template.has_array_value(), request.input_type() == InputType::TextArray) {
(true, true) | (false, false) => Ok(Self {template}),
(true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
(false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
}
}
pub fn extract_embeddings(
&self,
response: serde_json::Value,
) -> Result<Vec<Embedding>, EmbedError> {
pub fn extract_embeddings(&self, response: Value) -> Result<Vec<Embedding>, EmbedError> {
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
Ok(extracted_values) => extracted_values,
Err(error) => {