From e7b9b8f00230429831fe5467f3a1d5161465e6e2 Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Sun, 29 Jun 2025 23:53:06 +0200
Subject: [PATCH] Change embedder API

---
 crates/milli/src/vector/mod.rs | 75 ++++++++++++++++++++++++++++++++--
 1 file changed, 71 insertions(+), 4 deletions(-)

diff --git a/crates/milli/src/vector/mod.rs b/crates/milli/src/vector/mod.rs
index 37ade8f81..87ecd2414 100644
--- a/crates/milli/src/vector/mod.rs
+++ b/crates/milli/src/vector/mod.rs
@@ -797,6 +797,27 @@ pub enum EmbedderOptions {
     Composite(composite::EmbedderOptions),
 }
 
+impl EmbedderOptions {
+    pub fn fragment(&self, name: &str) -> Option<&serde_json::Value> {
+        match &self {
+            EmbedderOptions::HuggingFace(_)
+            | EmbedderOptions::OpenAi(_)
+            | EmbedderOptions::Ollama(_)
+            | EmbedderOptions::UserProvided(_) => None,
+            EmbedderOptions::Rest(embedder_options) => {
+                embedder_options.indexing_fragments.get(name)
+            }
+            EmbedderOptions::Composite(embedder_options) => {
+                if let SubEmbedderOptions::Rest(embedder_options) = &embedder_options.index {
+                    embedder_options.indexing_fragments.get(name)
+                } else {
+                    None
+                }
+            }
+        }
+    }
+}
+
 impl Default for EmbedderOptions {
     fn default() -> Self {
         Self::HuggingFace(Default::default())
@@ -837,6 +858,17 @@ impl Embedder {
 
     #[tracing::instrument(level = "debug", skip_all, target = "search")]
     pub fn embed_search(
+        &self,
+        query: SearchQuery<'_>,
+        deadline: Option<Instant>,
+    ) -> std::result::Result<Embedding, EmbedError> {
+        match query {
+            SearchQuery::Text(text) => self.embed_search_text(text, deadline),
+            SearchQuery::Media { q, media } => self.embed_search_media(q, media, deadline),
+        }
+    }
+
+    pub fn embed_search_text(
         &self,
         text: &str,
         deadline: Option<Instant>,
@@ -858,10 +890,7 @@ impl Embedder {
                 .pop()
                 .ok_or_else(EmbedError::missing_embedding),
             Embedder::UserProvided(embedder) => embedder.embed_one(text),
-            Embedder::Rest(embedder) => embedder
-                .embed_ref(&[text], deadline, None)?
-                .pop()
-                .ok_or_else(EmbedError::missing_embedding),
+            Embedder::Rest(embedder) => embedder.embed_one(SearchQuery::Text(text), deadline, None),
             Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline, None),
         }?;
 
@@ -872,6 +901,18 @@ impl Embedder {
         Ok(embedding)
     }
 
+    pub fn embed_search_media(
+        &self,
+        q: Option<&str>,
+        media: Option<&serde_json::Value>,
+        deadline: Option<Instant>,
+    ) -> std::result::Result<Embedding, EmbedError> {
+        let Embedder::Rest(embedder) = self else {
+            return Err(EmbedError::rest_media_not_a_rest());
+        };
+        embedder.embed_one(SearchQuery::Media { q, media }, deadline, None)
+    }
+
     /// Embed multiple chunks of texts.
     ///
     /// Each chunk is composed of one or multiple texts.
@@ -916,6 +957,26 @@ impl Embedder {
         }
     }
 
+    pub fn embed_index_ref_fragments(
+        &self,
+        fragments: &[serde_json::Value],
+        threads: &ThreadPoolNoAbort,
+        embedder_stats: &EmbedderStats,
+    ) -> std::result::Result<Vec<Embedding>, EmbedError> {
+        if let Embedder::Rest(embedder) = self {
+            embedder.embed_index_ref(fragments, threads, embedder_stats)
+        } else {
+            let Embedder::Composite(embedder) = self else {
+                unimplemented!("embedding fragments is only available for rest embedders")
+            };
+            let crate::vector::composite::SubEmbedder::Rest(embedder) = &embedder.index else {
+                unimplemented!("embedding fragments is only available for rest embedders")
+            };
+
+            embedder.embed_index_ref(fragments, threads, embedder_stats)
+        }
+    }
+
     /// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
     pub fn chunk_count_hint(&self) -> usize {
         match self {
@@ -987,6 +1048,12 @@ impl Embedder {
     }
 }
 
+#[derive(Clone, Copy)]
+pub enum SearchQuery<'a> {
+    Text(&'a str),
+    Media { q: Option<&'a str>, media: Option<&'a serde_json::Value> },
+}
+
 /// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
 ///
 /// The intended use is to make the similarity score more comparable to the regular ranking score.