Change embedder API

This commit is contained in:
Louis Dureuil 2025-06-29 23:53:06 +02:00
parent 5716ab70f3
commit e7b9b8f002
No known key found for this signature in database

View file

@ -797,6 +797,27 @@ pub enum EmbedderOptions {
Composite(composite::EmbedderOptions),
}
impl EmbedderOptions {
pub fn fragment(&self, name: &str) -> Option<&serde_json::Value> {
match &self {
EmbedderOptions::HuggingFace(_)
| EmbedderOptions::OpenAi(_)
| EmbedderOptions::Ollama(_)
| EmbedderOptions::UserProvided(_) => None,
EmbedderOptions::Rest(embedder_options) => {
embedder_options.indexing_fragments.get(name)
}
EmbedderOptions::Composite(embedder_options) => {
if let SubEmbedderOptions::Rest(embedder_options) = &embedder_options.index {
embedder_options.indexing_fragments.get(name)
} else {
None
}
}
}
}
}
impl Default for EmbedderOptions {
fn default() -> Self {
Self::HuggingFace(Default::default())
@ -837,6 +858,17 @@ impl Embedder {
#[tracing::instrument(level = "debug", skip_all, target = "search")]
pub fn embed_search(
&self,
query: SearchQuery<'_>,
deadline: Option<Instant>,
) -> std::result::Result<Embedding, EmbedError> {
match query {
SearchQuery::Text(text) => self.embed_search_text(text, deadline),
SearchQuery::Media { q, media } => self.embed_search_media(q, media, deadline),
}
}
pub fn embed_search_text(
&self,
text: &str,
deadline: Option<Instant>,
@ -858,10 +890,7 @@ impl Embedder {
.pop()
.ok_or_else(EmbedError::missing_embedding),
Embedder::UserProvided(embedder) => embedder.embed_one(text),
Embedder::Rest(embedder) => embedder
.embed_ref(&[text], deadline, None)?
.pop()
.ok_or_else(EmbedError::missing_embedding),
Embedder::Rest(embedder) => embedder.embed_one(SearchQuery::Text(text), deadline, None),
Embedder::Composite(embedder) => embedder.search.embed_one(text, deadline, None),
}?;
@ -872,6 +901,18 @@ impl Embedder {
Ok(embedding)
}
pub fn embed_search_media(
&self,
q: Option<&str>,
media: Option<&serde_json::Value>,
deadline: Option<Instant>,
) -> std::result::Result<Embedding, EmbedError> {
let Embedder::Rest(embedder) = self else {
return Err(EmbedError::rest_media_not_a_rest());
};
embedder.embed_one(SearchQuery::Media { q, media }, deadline, None)
}
/// Embed multiple chunks of texts.
///
/// Each chunk is composed of one or multiple texts.
@ -916,6 +957,26 @@ impl Embedder {
}
}
pub fn embed_index_ref_fragments(
&self,
fragments: &[serde_json::Value],
threads: &ThreadPoolNoAbort,
embedder_stats: &EmbedderStats,
) -> std::result::Result<Vec<Embedding>, EmbedError> {
if let Embedder::Rest(embedder) = self {
embedder.embed_index_ref(fragments, threads, embedder_stats)
} else {
let Embedder::Composite(embedder) = self else {
unimplemented!("embedding fragments is only available for rest embedders")
};
let crate::vector::composite::SubEmbedder::Rest(embedder) = &embedder.index else {
unimplemented!("embedding fragments is only available for rest embedders")
};
embedder.embed_index_ref(fragments, threads, embedder_stats)
}
}
/// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
pub fn chunk_count_hint(&self) -> usize {
match self {
@ -987,6 +1048,12 @@ impl Embedder {
}
}
#[derive(Clone, Copy)]
pub enum SearchQuery<'a> {
Text(&'a str),
Media { q: Option<&'a str>, media: Option<&'a serde_json::Value> },
}
/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
///
/// The intended use is to make the similarity score more comparable to the regular ranking score.