mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-04 04:17:10 +02:00
REST embedder supports fragments
This commit is contained in:
parent
e7b9b8f002
commit
4235a82dcf
1 changed files with 197 additions and 34 deletions
|
@ -6,11 +6,13 @@ use rand::Rng;
|
|||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||
use rayon::slice::ParallelSlice as _;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
use super::error::EmbedErrorKind;
|
||||
use super::json_template::ValueTemplate;
|
||||
use super::json_template::{InjectableValue, JsonTemplate};
|
||||
use super::{
|
||||
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, REQUEST_PARALLELISM,
|
||||
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, SearchQuery,
|
||||
REQUEST_PARALLELISM,
|
||||
};
|
||||
use crate::error::FaultSource;
|
||||
use crate::progress::EmbedderStats;
|
||||
|
@ -88,19 +90,54 @@ struct EmbedderData {
|
|||
bearer: Option<String>,
|
||||
headers: BTreeMap<String, String>,
|
||||
url: String,
|
||||
request: Request,
|
||||
request: RequestData,
|
||||
response: Response,
|
||||
configuration_source: ConfigurationSource,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum RequestData {
|
||||
Single(Request),
|
||||
FromFragments(RequestFromFragments),
|
||||
}
|
||||
|
||||
impl RequestData {
|
||||
pub fn new(
|
||||
request: Value,
|
||||
indexing_fragments: BTreeMap<String, Value>,
|
||||
search_fragments: BTreeMap<String, Value>,
|
||||
) -> Result<Self, NewEmbedderError> {
|
||||
Ok(if indexing_fragments.is_empty() && search_fragments.is_empty() {
|
||||
RequestData::Single(Request::new(request)?)
|
||||
} else {
|
||||
RequestData::FromFragments(RequestFromFragments::new(request, search_fragments)?)
|
||||
})
|
||||
}
|
||||
|
||||
fn input_type(&self) -> InputType {
|
||||
match self {
|
||||
RequestData::Single(request) => request.input_type(),
|
||||
RequestData::FromFragments(request_from_fragments) => {
|
||||
request_from_fragments.input_type()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn has_fragments(&self) -> bool {
|
||||
matches!(self, RequestData::FromFragments(_))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub api_key: Option<String>,
|
||||
pub distribution: Option<DistributionShift>,
|
||||
pub dimensions: Option<usize>,
|
||||
pub url: String,
|
||||
pub request: serde_json::Value,
|
||||
pub response: serde_json::Value,
|
||||
pub request: Value,
|
||||
pub search_fragments: BTreeMap<String, Value>,
|
||||
pub indexing_fragments: BTreeMap<String, Value>,
|
||||
pub response: Value,
|
||||
pub headers: BTreeMap<String, String>,
|
||||
}
|
||||
|
||||
|
@ -138,7 +175,12 @@ impl Embedder {
|
|||
.timeout(std::time::Duration::from_secs(30))
|
||||
.build();
|
||||
|
||||
let request = Request::new(options.request)?;
|
||||
let request = RequestData::new(
|
||||
options.request,
|
||||
options.indexing_fragments,
|
||||
options.search_fragments,
|
||||
)?;
|
||||
|
||||
let response = Response::new(options.response, &request)?;
|
||||
|
||||
let data = EmbedderData {
|
||||
|
@ -188,7 +230,7 @@ impl Embedder {
|
|||
embedder_stats: Option<&EmbedderStats>,
|
||||
) -> Result<Vec<Embedding>, EmbedError>
|
||||
where
|
||||
S: AsRef<str> + Serialize,
|
||||
S: Serialize,
|
||||
{
|
||||
embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline, embedder_stats)
|
||||
}
|
||||
|
@ -231,9 +273,9 @@ impl Embedder {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn embed_index_ref(
|
||||
pub(crate) fn embed_index_ref<S: Serialize + Sync>(
|
||||
&self,
|
||||
texts: &[&str],
|
||||
texts: &[S],
|
||||
threads: &ThreadPoolNoAbort,
|
||||
embedder_stats: &EmbedderStats,
|
||||
) -> Result<Vec<Embedding>, EmbedError> {
|
||||
|
@ -287,9 +329,44 @@ impl Embedder {
|
|||
pub(super) fn cache(&self) -> &EmbeddingCache {
|
||||
&self.cache
|
||||
}
|
||||
|
||||
pub(crate) fn embed_one(
|
||||
&self,
|
||||
query: SearchQuery,
|
||||
deadline: Option<Instant>,
|
||||
embedder_stats: Option<&EmbedderStats>,
|
||||
) -> Result<Embedding, EmbedError> {
|
||||
let mut embeddings = match (&self.data.request, query) {
|
||||
(RequestData::Single(_), SearchQuery::Text(text)) => {
|
||||
embed(&self.data, &[text], 1, Some(self.dimensions), deadline, embedder_stats)
|
||||
}
|
||||
(RequestData::Single(_), SearchQuery::Media { q: _, media: _ }) => {
|
||||
return Err(EmbedError::rest_media_not_a_fragment())
|
||||
}
|
||||
(RequestData::FromFragments(request_from_fragments), SearchQuery::Text(q)) => {
|
||||
let fragment = request_from_fragments.render_search_fragment(Some(q), None)?;
|
||||
|
||||
embed(&self.data, &[fragment], 1, Some(self.dimensions), deadline, embedder_stats)
|
||||
}
|
||||
(
|
||||
RequestData::FromFragments(request_from_fragments),
|
||||
SearchQuery::Media { q, media },
|
||||
) => {
|
||||
let fragment = request_from_fragments.render_search_fragment(q, media)?;
|
||||
|
||||
embed(&self.data, &[fragment], 1, Some(self.dimensions), deadline, embedder_stats)
|
||||
}
|
||||
}?;
|
||||
|
||||
// unwrap: checked by `expected_count`
|
||||
Ok(embeddings.pop().unwrap())
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
||||
if data.request.has_fragments() {
|
||||
return Err(NewEmbedderError::rest_cannot_infer_dimensions_for_fragment());
|
||||
}
|
||||
let v = embed(data, ["test"].as_slice(), 1, None, None, None)
|
||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
||||
|
@ -307,6 +384,13 @@ fn embed<S>(
|
|||
where
|
||||
S: Serialize,
|
||||
{
|
||||
if inputs.is_empty() {
|
||||
if expected_count != 0 {
|
||||
return Err(EmbedError::rest_response_embedding_count(expected_count, 0));
|
||||
}
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let request = data.client.post(&data.url);
|
||||
let request = if let Some(bearer) = &data.bearer {
|
||||
request.set("Authorization", bearer)
|
||||
|
@ -318,7 +402,12 @@ where
|
|||
request = request.set(header.as_str(), value.as_str());
|
||||
}
|
||||
|
||||
let body = data.request.inject_texts(inputs);
|
||||
let body = match &data.request {
|
||||
RequestData::Single(request) => request.inject_texts(inputs),
|
||||
RequestData::FromFragments(request_from_fragments) => {
|
||||
request_from_fragments.request_from_fragments(inputs).expect("inputs was empty")
|
||||
}
|
||||
};
|
||||
|
||||
for attempt in 0..10 {
|
||||
if let Some(embedder_stats) = &embedder_stats {
|
||||
|
@ -426,7 +515,7 @@ fn response_to_embedding(
|
|||
expected_count: usize,
|
||||
expected_dimensions: Option<usize>,
|
||||
) -> Result<Vec<Embedding>, Retry> {
|
||||
let response: serde_json::Value = response
|
||||
let response: Value = response
|
||||
.into_json()
|
||||
.map_err(EmbedError::rest_response_deserialization)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
@ -455,17 +544,19 @@ fn response_to_embedding(
|
|||
}
|
||||
|
||||
pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
|
||||
pub(super) const REQUEST_FRAGMENT_PLACEHOLDER: &str = "{{fragment}}";
|
||||
pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
|
||||
pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Request {
|
||||
template: ValueTemplate,
|
||||
template: InjectableValue,
|
||||
}
|
||||
|
||||
impl Request {
|
||||
pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
|
||||
let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) {
|
||||
pub fn new(template: Value) -> Result<Self, NewEmbedderError> {
|
||||
let template = match InjectableValue::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER)
|
||||
{
|
||||
Ok(template) => template,
|
||||
Err(error) => {
|
||||
let message =
|
||||
|
@ -485,42 +576,114 @@ impl Request {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn inject_texts<S: Serialize>(
|
||||
&self,
|
||||
texts: impl IntoIterator<Item = S>,
|
||||
) -> serde_json::Value {
|
||||
pub fn inject_texts<S: Serialize>(&self, texts: impl IntoIterator<Item = S>) -> Value {
|
||||
self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct RequestFromFragments {
|
||||
search_fragments: BTreeMap<String, JsonTemplate>,
|
||||
request: InjectableValue,
|
||||
}
|
||||
|
||||
impl RequestFromFragments {
|
||||
pub fn new(
|
||||
request: Value,
|
||||
search_fragments: impl IntoIterator<Item = (String, Value)>,
|
||||
) -> Result<Self, NewEmbedderError> {
|
||||
let request =
|
||||
match InjectableValue::new(request, REQUEST_FRAGMENT_PLACEHOLDER, REPEAT_PLACEHOLDER) {
|
||||
Ok(template) => template,
|
||||
Err(error) => {
|
||||
let message =
|
||||
error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
||||
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
||||
}
|
||||
};
|
||||
|
||||
let search_fragments: Result<_, NewEmbedderError> = search_fragments
|
||||
.into_iter()
|
||||
.map(|(name, value)| {
|
||||
Ok((
|
||||
name,
|
||||
JsonTemplate::new(value).map_err(|error| {
|
||||
NewEmbedderError::rest_could_not_parse_template(
|
||||
error.parsing("searchFragments"),
|
||||
)
|
||||
})?,
|
||||
))
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Self { request, search_fragments: search_fragments? })
|
||||
}
|
||||
|
||||
fn input_type(&self) -> InputType {
|
||||
if self.request.has_array_value() {
|
||||
InputType::TextArray
|
||||
} else {
|
||||
InputType::Text
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render_search_fragment(
|
||||
&self,
|
||||
q: Option<&str>,
|
||||
media: Option<&Value>,
|
||||
) -> Result<Value, EmbedError> {
|
||||
let mut it = self.search_fragments.iter().filter_map(|(name, template)| {
|
||||
let render = template.render_search(q, media).ok()?;
|
||||
Some((name, render))
|
||||
});
|
||||
let Some((name, fragment)) = it.next() else {
|
||||
return Err(EmbedError::rest_search_matches_no_fragment(q, media));
|
||||
};
|
||||
if let Some((second_name, _)) = it.next() {
|
||||
return Err(EmbedError::rest_search_matches_multiple_fragments(
|
||||
name,
|
||||
second_name,
|
||||
q,
|
||||
media,
|
||||
));
|
||||
}
|
||||
|
||||
Ok(fragment)
|
||||
}
|
||||
|
||||
pub fn request_from_fragments<'a, S: Serialize + 'a>(
|
||||
&self,
|
||||
fragments: impl IntoIterator<Item = &'a S>,
|
||||
) -> Option<Value> {
|
||||
self.request.inject(fragments.into_iter().map(|fragment| serde_json::json!(fragment))).ok()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Response {
|
||||
template: ValueTemplate,
|
||||
template: InjectableValue,
|
||||
}
|
||||
|
||||
impl Response {
|
||||
pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
|
||||
let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
|
||||
{
|
||||
Ok(template) => template,
|
||||
Err(error) => {
|
||||
let message =
|
||||
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
||||
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
||||
}
|
||||
};
|
||||
pub fn new(template: Value, request: &RequestData) -> Result<Self, NewEmbedderError> {
|
||||
let template =
|
||||
match InjectableValue::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER) {
|
||||
Ok(template) => template,
|
||||
Err(error) => {
|
||||
let message =
|
||||
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
||||
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
||||
}
|
||||
};
|
||||
|
||||
match (template.has_array_value(), request.template.has_array_value()) {
|
||||
match (template.has_array_value(), request.input_type() == InputType::TextArray) {
|
||||
(true, true) | (false, false) => Ok(Self {template}),
|
||||
(true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
|
||||
(false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extract_embeddings(
|
||||
&self,
|
||||
response: serde_json::Value,
|
||||
) -> Result<Vec<Embedding>, EmbedError> {
|
||||
pub fn extract_embeddings(&self, response: Value) -> Result<Vec<Embedding>, EmbedError> {
|
||||
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
|
||||
Ok(extracted_values) => extracted_values,
|
||||
Err(error) => {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue