diff --git a/crates/milli/src/vector/rest.rs b/crates/milli/src/vector/rest.rs index fbe3c1129..9477959ad 100644 --- a/crates/milli/src/vector/rest.rs +++ b/crates/milli/src/vector/rest.rs @@ -6,11 +6,13 @@ use rand::Rng; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use rayon::slice::ParallelSlice as _; use serde::{Deserialize, Serialize}; +use serde_json::Value; use super::error::EmbedErrorKind; -use super::json_template::ValueTemplate; +use super::json_template::{InjectableValue, JsonTemplate}; use super::{ - DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, REQUEST_PARALLELISM, + DistributionShift, EmbedError, Embedding, EmbeddingCache, NewEmbedderError, SearchQuery, + REQUEST_PARALLELISM, }; use crate::error::FaultSource; use crate::progress::EmbedderStats; @@ -88,19 +90,54 @@ struct EmbedderData { bearer: Option, headers: BTreeMap, url: String, - request: Request, + request: RequestData, response: Response, configuration_source: ConfigurationSource, } +#[derive(Debug)] +pub enum RequestData { + Single(Request), + FromFragments(RequestFromFragments), +} + +impl RequestData { + pub fn new( + request: Value, + indexing_fragments: BTreeMap, + search_fragments: BTreeMap, + ) -> Result { + Ok(if indexing_fragments.is_empty() && search_fragments.is_empty() { + RequestData::Single(Request::new(request)?) + } else { + RequestData::FromFragments(RequestFromFragments::new(request, search_fragments)?) + }) + } + + fn input_type(&self) -> InputType { + match self { + RequestData::Single(request) => request.input_type(), + RequestData::FromFragments(request_from_fragments) => { + request_from_fragments.input_type() + } + } + } + + fn has_fragments(&self) -> bool { + matches!(self, RequestData::FromFragments(_)) + } +} + #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] pub struct EmbedderOptions { pub api_key: Option, pub distribution: Option, pub dimensions: Option, pub url: String, - pub request: serde_json::Value, - pub response: serde_json::Value, + pub request: Value, + pub search_fragments: BTreeMap, + pub indexing_fragments: BTreeMap, + pub response: Value, pub headers: BTreeMap, } @@ -138,7 +175,12 @@ impl Embedder { .timeout(std::time::Duration::from_secs(30)) .build(); - let request = Request::new(options.request)?; + let request = RequestData::new( + options.request, + options.indexing_fragments, + options.search_fragments, + )?; + let response = Response::new(options.response, &request)?; let data = EmbedderData { @@ -188,7 +230,7 @@ impl Embedder { embedder_stats: Option<&EmbedderStats>, ) -> Result, EmbedError> where - S: AsRef + Serialize, + S: Serialize, { embed(&self.data, texts, texts.len(), Some(self.dimensions), deadline, embedder_stats) } @@ -231,9 +273,9 @@ impl Embedder { } } - pub(crate) fn embed_index_ref( + pub(crate) fn embed_index_ref( &self, - texts: &[&str], + texts: &[S], threads: &ThreadPoolNoAbort, embedder_stats: &EmbedderStats, ) -> Result, EmbedError> { @@ -287,9 +329,44 @@ impl Embedder { pub(super) fn cache(&self) -> &EmbeddingCache { &self.cache } + + pub(crate) fn embed_one( + &self, + query: SearchQuery, + deadline: Option, + embedder_stats: Option<&EmbedderStats>, + ) -> Result { + let mut embeddings = match (&self.data.request, query) { + (RequestData::Single(_), SearchQuery::Text(text)) => { + embed(&self.data, &[text], 1, Some(self.dimensions), deadline, embedder_stats) + } + (RequestData::Single(_), SearchQuery::Media { q: _, media: _ }) => { + return Err(EmbedError::rest_media_not_a_fragment()) + } + (RequestData::FromFragments(request_from_fragments), SearchQuery::Text(q)) => { + let fragment = request_from_fragments.render_search_fragment(Some(q), None)?; + + embed(&self.data, &[fragment], 1, Some(self.dimensions), deadline, embedder_stats) + } + ( + RequestData::FromFragments(request_from_fragments), + SearchQuery::Media { q, media }, + ) => { + let fragment = request_from_fragments.render_search_fragment(q, media)?; + + embed(&self.data, &[fragment], 1, Some(self.dimensions), deadline, embedder_stats) + } + }?; + + // unwrap: checked by `expected_count` + Ok(embeddings.pop().unwrap()) + } } fn infer_dimensions(data: &EmbedderData) -> Result { + if data.request.has_fragments() { + return Err(NewEmbedderError::rest_cannot_infer_dimensions_for_fragment()); + } let v = embed(data, ["test"].as_slice(), 1, None, None, None) .map_err(NewEmbedderError::could_not_determine_dimension)?; // unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error @@ -307,6 +384,13 @@ fn embed( where S: Serialize, { + if inputs.is_empty() { + if expected_count != 0 { + return Err(EmbedError::rest_response_embedding_count(expected_count, 0)); + } + return Ok(Vec::new()); + } + let request = data.client.post(&data.url); let request = if let Some(bearer) = &data.bearer { request.set("Authorization", bearer) @@ -318,7 +402,12 @@ where request = request.set(header.as_str(), value.as_str()); } - let body = data.request.inject_texts(inputs); + let body = match &data.request { + RequestData::Single(request) => request.inject_texts(inputs), + RequestData::FromFragments(request_from_fragments) => { + request_from_fragments.request_from_fragments(inputs).expect("inputs was empty") + } + }; for attempt in 0..10 { if let Some(embedder_stats) = &embedder_stats { @@ -426,7 +515,7 @@ fn response_to_embedding( expected_count: usize, expected_dimensions: Option, ) -> Result, Retry> { - let response: serde_json::Value = response + let response: Value = response .into_json() .map_err(EmbedError::rest_response_deserialization) .map_err(Retry::retry_later)?; @@ -455,17 +544,19 @@ fn response_to_embedding( } pub(super) const REQUEST_PLACEHOLDER: &str = "{{text}}"; +pub(super) const REQUEST_FRAGMENT_PLACEHOLDER: &str = "{{fragment}}"; pub(super) const RESPONSE_PLACEHOLDER: &str = "{{embedding}}"; pub(super) const REPEAT_PLACEHOLDER: &str = "{{..}}"; #[derive(Debug)] pub struct Request { - template: ValueTemplate, + template: InjectableValue, } impl Request { - pub fn new(template: serde_json::Value) -> Result { - let template = match ValueTemplate::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) { + pub fn new(template: Value) -> Result { + let template = match InjectableValue::new(template, REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER) + { Ok(template) => template, Err(error) => { let message = @@ -485,42 +576,114 @@ impl Request { } } - pub fn inject_texts( - &self, - texts: impl IntoIterator, - ) -> serde_json::Value { + pub fn inject_texts(&self, texts: impl IntoIterator) -> Value { self.template.inject(texts.into_iter().map(|s| serde_json::json!(s))).unwrap() } } +#[derive(Debug)] +pub struct RequestFromFragments { + search_fragments: BTreeMap, + request: InjectableValue, +} + +impl RequestFromFragments { + pub fn new( + request: Value, + search_fragments: impl IntoIterator, + ) -> Result { + let request = + match InjectableValue::new(request, REQUEST_FRAGMENT_PLACEHOLDER, REPEAT_PLACEHOLDER) { + Ok(template) => template, + Err(error) => { + let message = + error.error_message("request", REQUEST_PLACEHOLDER, REPEAT_PLACEHOLDER); + return Err(NewEmbedderError::rest_could_not_parse_template(message)); + } + }; + + let search_fragments: Result<_, NewEmbedderError> = search_fragments + .into_iter() + .map(|(name, value)| { + Ok(( + name, + JsonTemplate::new(value).map_err(|error| { + NewEmbedderError::rest_could_not_parse_template( + error.parsing("searchFragments"), + ) + })?, + )) + }) + .collect(); + + Ok(Self { request, search_fragments: search_fragments? }) + } + + fn input_type(&self) -> InputType { + if self.request.has_array_value() { + InputType::TextArray + } else { + InputType::Text + } + } + + pub fn render_search_fragment( + &self, + q: Option<&str>, + media: Option<&Value>, + ) -> Result { + let mut it = self.search_fragments.iter().filter_map(|(name, template)| { + let render = template.render_search(q, media).ok()?; + Some((name, render)) + }); + let Some((name, fragment)) = it.next() else { + return Err(EmbedError::rest_search_matches_no_fragment(q, media)); + }; + if let Some((second_name, _)) = it.next() { + return Err(EmbedError::rest_search_matches_multiple_fragments( + name, + second_name, + q, + media, + )); + } + + Ok(fragment) + } + + pub fn request_from_fragments<'a, S: Serialize + 'a>( + &self, + fragments: impl IntoIterator, + ) -> Option { + self.request.inject(fragments.into_iter().map(|fragment| serde_json::json!(fragment))).ok() + } +} + #[derive(Debug)] pub struct Response { - template: ValueTemplate, + template: InjectableValue, } impl Response { - pub fn new(template: serde_json::Value, request: &Request) -> Result { - let template = match ValueTemplate::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER) - { - Ok(template) => template, - Err(error) => { - let message = - error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER); - return Err(NewEmbedderError::rest_could_not_parse_template(message)); - } - }; + pub fn new(template: Value, request: &RequestData) -> Result { + let template = + match InjectableValue::new(template, RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER) { + Ok(template) => template, + Err(error) => { + let message = + error.error_message("response", RESPONSE_PLACEHOLDER, REPEAT_PLACEHOLDER); + return Err(NewEmbedderError::rest_could_not_parse_template(message)); + } + }; - match (template.has_array_value(), request.template.has_array_value()) { + match (template.has_array_value(), request.input_type() == InputType::TextArray) { (true, true) | (false, false) => Ok(Self {template}), (true, false) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has multiple embeddings, but `request` has only one text to embed".to_string())), (false, true) => Err(NewEmbedderError::rest_could_not_parse_template("in `response`: `response` has a single embedding, but `request` has multiple texts to embed".to_string())), } } - pub fn extract_embeddings( - &self, - response: serde_json::Value, - ) -> Result, EmbedError> { + pub fn extract_embeddings(&self, response: Value) -> Result, EmbedError> { let extracted_values: Vec = match self.template.extract(response) { Ok(extracted_values) => extracted_values, Err(error) => {