mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-04 12:27:13 +02:00
REST embedder supports fragments
This commit is contained in:
parent
e7b9b8f002
commit
4235a82dcf
1 changed files with 197 additions and 34 deletions
|
@ -6,11 +6,13 @@ use rand::Rng;
|
||||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
use rayon::slice::ParallelSlice as _;
|
use rayon::slice::ParallelSlice as _;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
|
||||||
use super::error::EmbedErrorKind;
|
use super::error::EmbedErrorKind;
|
||||||
use super::json_template::ValueTemplate;
|
use super::json_template::{InjectableValue, JsonTemplate};
|
||||||
use super::{
|
use super::{
|
||||||
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, REQUEST_PARALLELISM,
|
DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, SearchQuery,
|
||||||
|
REQUEST_PARALLELISM,
|
||||||
};
|
};
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
use crate::progress::EmbedderStats;
|
use crate::progress::EmbedderStats;
|
||||||
|
@ -88,19 +90,54 @@ struct EmbedderData {
|
||||||
bearer: Option<String>,
|
bearer: Option<String>,
|
||||||
headers: BTreeMap<String, String>,
|
headers: BTreeMap<String, String>,
|
||||||
url: String,
|
url: String,
|
||||||
request: Request,
|
request: RequestData,
|
||||||
response: Response,
|
response: Response,
|
||||||
configuration_source: ConfigurationSource,
|
configuration_source: ConfigurationSource,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum RequestData {
|
||||||
|
Single(Request),
|
||||||
|
FromFragments(RequestFromFragments),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RequestData {
|
||||||
|
pub fn new(
|
||||||
|
request: Value,
|
||||||
|
indexing_fragments: BTreeMap<String, Value>,
|
||||||
|
search_fragments: BTreeMap<String, Value>,
|
||||||
|
) -> Result<Self, NewEmbedderError> {
|
||||||
|
Ok(if indexing_fragments.is_empty() && search_fragments.is_empty() {
|
||||||
|
RequestData::Single(Request::new(request)?)
|
||||||
|
} else {
|
||||||
|
RequestData::FromFragments(RequestFromFragments::new(request, search_fragments)?)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_type(&self) -> InputType {
|
||||||
|
match self {
|
||||||
|
RequestData::Single(request) => request.input_type(),
|
||||||
|
RequestData::FromFragments(request_from_fragments) => {
|
||||||
|
request_from_fragments.input_type()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_fragments(&self) -> bool {
|
||||||
|
matches!(self, RequestData::FromFragments(_))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
|
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
|
||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
pub api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
pub distribution: Option<DistributionShift>,
|
pub distribution: Option<DistributionShift>,
|
||||||
pub dimensions: Option<usize>,
|
pub dimensions: Option<usize>,
|
||||||
pub url: String,
|
pub url: String,
|
||||||
pub request: serde_json::Value,
|
pub request: Value,
|
||||||
pub response: serde_json::Value,
|
pub search_fragments: BTreeMap<String, Value>,
|
||||||
|
pub indexing_fragments: BTreeMap<String, Value>,
|
||||||
|
pub response: Value,
|
||||||
pub headers: BTreeMap<String, String>,
|
pub headers: BTreeMap<String, String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -138,7 +175,12 @@ impl Embedder {
|
||||||
.timeout(std::time::Duration::from_secs(30))
|
.timeout(std::time::Duration::from_secs(30))
|
||||||
.build();
|
.build();
|
||||||
|
|
||||||
let request = Request::new(options.request)?;
|
let request = RequestData::new(
|
||||||
|
options.request,
|
||||||
|
options.indexing_fragments,
|
||||||
|
options.search_fragments,
|
||||||
|
)?;
|
||||||
|
|
||||||
let response = Response::new(options.response, &request)?;
|
let response = Response::new(options.response, &request)?;
|
||||||
|
|
||||||
let data = EmbedderData {
|
let data = EmbedderData {
|
||||||
|
@ -188,7 +230,7 @@ impl Embedder {
|
||||||
embedder_stats: Option<&EmbedderStats>,
|
embedder_stats: Option<&EmbedderStats>,
|
||||||
) -> Result<Vec<Embedding>, EmbedError>
|
) -> Result<Vec<Embedding>, EmbedError>
|
||||||
where
|
where
|
||||||
S: AsRef<str> + Serialize,
|
S: Serialize,
|
||||||
{
|
{
|
||||||
embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline, embedder_stats)
|
embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline, embedder_stats)
|
||||||
}
|
}
|
||||||
|
@ -231,9 +273,9 @@ impl Embedder {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn embed_index_ref(
|
pub(crate) fn embed_index_ref<S: Serialize + Sync>(
|
||||||
&self,
|
&self,
|
||||||
texts: &[&str],
|
texts: &[S],
|
||||||
threads: &ThreadPoolNoAbort,
|
threads: &ThreadPoolNoAbort,
|
||||||
embedder_stats: &EmbedderStats,
|
embedder_stats: &EmbedderStats,
|
||||||
) -> Result<Vec<Embedding>, EmbedError> {
|
) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
|
@ -287,9 +329,44 @@ impl Embedder {
|
||||||
pub(super) fn cache(&self) -> &EmbeddingCache {
|
pub(super) fn cache(&self) -> &EmbeddingCache {
|
||||||
&self.cache
|
&self.cache
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn embed_one(
|
||||||
|
&self,
|
||||||
|
query: SearchQuery,
|
||||||
|
deadline: Option<Instant>,
|
||||||
|
embedder_stats: Option<&EmbedderStats>,
|
||||||
|
) -> Result<Embedding, EmbedError> {
|
||||||
|
let mut embeddings = match (&self.data.request, query) {
|
||||||
|
(RequestData::Single(_), SearchQuery::Text(text)) => {
|
||||||
|
embed(&self.data, &[text], 1, Some(self.dimensions), deadline, embedder_stats)
|
||||||
|
}
|
||||||
|
(RequestData::Single(_), SearchQuery::Media { q: _, media: _ }) => {
|
||||||
|
return Err(EmbedError::rest_media_not_a_fragment())
|
||||||
|
}
|
||||||
|
(RequestData::FromFragments(request_from_fragments), SearchQuery::Text(q)) => {
|
||||||
|
let fragment = request_from_fragments.render_search_fragment(Some(q), None)?;
|
||||||
|
|
||||||
|
embed(&self.data, &[fragment], 1, Some(self.dimensions), deadline, embedder_stats)
|
||||||
|
}
|
||||||
|
(
|
||||||
|
RequestData::FromFragments(request_from_fragments),
|
||||||
|
SearchQuery::Media { q, media },
|
||||||
|
) => {
|
||||||
|
let fragment = request_from_fragments.render_search_fragment(q, media)?;
|
||||||
|
|
||||||
|
embed(&self.data, &[fragment], 1, Some(self.dimensions), deadline, embedder_stats)
|
||||||
|
}
|
||||||
|
}?;
|
||||||
|
|
||||||
|
// unwrap: checked by `expected_count`
|
||||||
|
Ok(embeddings.pop().unwrap())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
fn infer_dimensions(data: &EmbedderData) -> Result<usize, NewEmbedderError> {
|
||||||
|
if data.request.has_fragments() {
|
||||||
|
return Err(NewEmbedderError::rest_cannot_infer_dimensions_for_fragment());
|
||||||
|
}
|
||||||
let v = embed(data, ["test"].as_slice(), 1, None, None, None)
|
let v = embed(data, ["test"].as_slice(), 1, None, None, None)
|
||||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||||
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
||||||
|
@ -307,6 +384,13 @@ fn embed<S>(
|
||||||
where
|
where
|
||||||
S: Serialize,
|
S: Serialize,
|
||||||
{
|
{
|
||||||
|
if inputs.is_empty() {
|
||||||
|
if expected_count != 0 {
|
||||||
|
return Err(EmbedError::rest_response_embedding_count(expected_count, 0));
|
||||||
|
}
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
let request = data.client.post(&data.url);
|
let request = data.client.post(&data.url);
|
||||||
let request = if let Some(bearer) = &data.bearer {
|
let request = if let Some(bearer) = &data.bearer {
|
||||||
request.set("Authorization", bearer)
|
request.set("Authorization", bearer)
|
||||||
|
@ -318,7 +402,12 @@ where
|
||||||
request = request.set(header.as_str(), value.as_str());
|
request = request.set(header.as_str(), value.as_str());
|
||||||
}
|
}
|
||||||
|
|
||||||
let body = data.request.inject_texts(inputs);
|
let body = match &data.request {
|
||||||
|
RequestData::Single(request) => request.inject_texts(inputs),
|
||||||
|
RequestData::FromFragments(request_from_fragments) => {
|
||||||
|
request_from_fragments.request_from_fragments(inputs).expect("inputs was empty")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
for attempt in 0..10 {
|
for attempt in 0..10 {
|
||||||
if let Some(embedder_stats) = &embedder_stats {
|
if let Some(embedder_stats) = &embedder_stats {
|
||||||
|
@ -426,7 +515,7 @@ fn response_to_embedding(
|
||||||
expected_count: usize,
|
expected_count: usize,
|
||||||
expected_dimensions: Option<usize>,
|
expected_dimensions: Option<usize>,
|
||||||
) -> Result<Vec<Embedding>, Retry> {
|
) -> Result<Vec<Embedding>, Retry> {
|
||||||
let response: serde_json::Value = response
|
let response: Value = response
|
||||||
.into_json()
|
.into_json()
|
||||||
.map_err(EmbedError::rest_response_deserialization)
|
.map_err(EmbedError::rest_response_deserialization)
|
||||||
.map_err(Retry::retry_later)?;
|
.map_err(Retry::retry_later)?;
|
||||||
|
@ -455,17 +544,19 @@ fn response_to_embedding(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
|
pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}";
|
||||||
|
pub(super) const REQUEST_FRAGMENT_PLACEHOLDER: &str = "{{fragment}}";
|
||||||
pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
|
pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}";
|
||||||
pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
|
pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}";
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Request {
|
pub struct Request {
|
||||||
template: ValueTemplate,
|
template: InjectableValue,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Request {
|
impl Request {
|
||||||
pub fn new(template: serde_json::Value) -> Result<Self, NewEmbedderError> {
|
pub fn new(template: Value) -> Result<Self, NewEmbedderError> {
|
||||||
let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) {
|
let template = match InjectableValue::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER)
|
||||||
|
{
|
||||||
Ok(template) => template,
|
Ok(template) => template,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
let message =
|
let message =
|
||||||
|
@ -485,42 +576,114 @@ impl Request {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn inject_texts<S: Serialize>(
|
pub fn inject_texts<S: Serialize>(&self, texts: impl IntoIterator<Item = S>) -> Value {
|
||||||
&self,
|
|
||||||
texts: impl IntoIterator<Item = S>,
|
|
||||||
) -> serde_json::Value {
|
|
||||||
self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
|
self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RequestFromFragments {
|
||||||
|
search_fragments: BTreeMap<String, JsonTemplate>,
|
||||||
|
request: InjectableValue,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RequestFromFragments {
|
||||||
|
pub fn new(
|
||||||
|
request: Value,
|
||||||
|
search_fragments: impl IntoIterator<Item = (String, Value)>,
|
||||||
|
) -> Result<Self, NewEmbedderError> {
|
||||||
|
let request =
|
||||||
|
match InjectableValue::new(request, REQUEST_FRAGMENT_PLACEHOLDER, REPEAT_PLACEHOLDER) {
|
||||||
|
Ok(template) => template,
|
||||||
|
Err(error) => {
|
||||||
|
let message =
|
||||||
|
error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
||||||
|
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let search_fragments: Result<_, NewEmbedderError> = search_fragments
|
||||||
|
.into_iter()
|
||||||
|
.map(|(name, value)| {
|
||||||
|
Ok((
|
||||||
|
name,
|
||||||
|
JsonTemplate::new(value).map_err(|error| {
|
||||||
|
NewEmbedderError::rest_could_not_parse_template(
|
||||||
|
error.parsing("searchFragments"),
|
||||||
|
)
|
||||||
|
})?,
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(Self { request, search_fragments: search_fragments? })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn input_type(&self) -> InputType {
|
||||||
|
if self.request.has_array_value() {
|
||||||
|
InputType::TextArray
|
||||||
|
} else {
|
||||||
|
InputType::Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn render_search_fragment(
|
||||||
|
&self,
|
||||||
|
q: Option<&str>,
|
||||||
|
media: Option<&Value>,
|
||||||
|
) -> Result<Value, EmbedError> {
|
||||||
|
let mut it = self.search_fragments.iter().filter_map(|(name, template)| {
|
||||||
|
let render = template.render_search(q, media).ok()?;
|
||||||
|
Some((name, render))
|
||||||
|
});
|
||||||
|
let Some((name, fragment)) = it.next() else {
|
||||||
|
return Err(EmbedError::rest_search_matches_no_fragment(q, media));
|
||||||
|
};
|
||||||
|
if let Some((second_name, _)) = it.next() {
|
||||||
|
return Err(EmbedError::rest_search_matches_multiple_fragments(
|
||||||
|
name,
|
||||||
|
second_name,
|
||||||
|
q,
|
||||||
|
media,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(fragment)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn request_from_fragments<'a, S: Serialize + 'a>(
|
||||||
|
&self,
|
||||||
|
fragments: impl IntoIterator<Item = &'a S>,
|
||||||
|
) -> Option<Value> {
|
||||||
|
self.request.inject(fragments.into_iter().map(|fragment| serde_json::json!(fragment))).ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Response {
|
pub struct Response {
|
||||||
template: ValueTemplate,
|
template: InjectableValue,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Response {
|
impl Response {
|
||||||
pub fn new(template: serde_json::Value, request: &Request) -> Result<Self, NewEmbedderError> {
|
pub fn new(template: Value, request: &RequestData) -> Result<Self, NewEmbedderError> {
|
||||||
let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER)
|
let template =
|
||||||
{
|
match InjectableValue::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER) {
|
||||||
Ok(template) => template,
|
Ok(template) => template,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
let message =
|
let message =
|
||||||
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER);
|
||||||
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
return Err(NewEmbedderError::rest_could_not_parse_template(message));
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
match (template.has_array_value(), request.template.has_array_value()) {
|
match (template.has_array_value(), request.input_type() == InputType::TextArray) {
|
||||||
(true, true) | (false, false) => Ok(Self {template}),
|
(true, true) | (false, false) => Ok(Self {template}),
|
||||||
(true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
|
(true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())),
|
||||||
(false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
|
(false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn extract_embeddings(
|
pub fn extract_embeddings(&self, response: Value) -> Result<Vec<Embedding>, EmbedError> {
|
||||||
&self,
|
|
||||||
response: serde_json::Value,
|
|
||||||
) -> Result<Vec<Embedding>, EmbedError> {
|
|
||||||
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
|
let extracted_values: Vec<Embedding> = match self.template.extract(response) {
|
||||||
Ok(extracted_values) => extracted_values,
|
Ok(extracted_values) => extracted_values,
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue