diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index 1e0bcc7fb..d3369ef3d 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -58,7 +58,7 @@ pub enum EmbedErrorKind { RestResponseDeserialization(std::io::Error), #[error("component `{0}` not found in path `{1}` in response: `{2}`")] RestResponseMissingEmbeddings(String, String, String), - #[error("expected a response parseable as a vector or an array of vectors: {0}")] + #[error("unexpected format of the embedding response: {0}")] RestResponseFormat(serde_json::Error), #[error("expected a response containing {0} embeddings, got only {1}")] RestResponseEmbeddingCount(usize, usize), @@ -78,6 +78,8 @@ pub enum EmbedErrorKind { RestNotAnObject(serde_json::Value, Vec), #[error("while embedding tokenized, was expecting embeddings of dimension `{0}`, got embeddings of dimensions `{1}`")] OpenAiUnexpectedDimension(usize, usize), + #[error("no embedding was produced")] + MissingEmbedding, } impl EmbedError { @@ -190,6 +192,9 @@ impl EmbedError { fault: FaultSource::Runtime, } } + pub(crate) fn missing_embedding() -> EmbedError { + Self { kind: EmbedErrorKind::MissingEmbedding, fault: FaultSource::Undecided } + } } #[derive(Debug, thiserror::Error)] diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 5aa58da5d..58f7ba5e1 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -237,6 +237,17 @@ impl Embedder { } } + pub fn embed_one(&self, text: String) -> std::result::Result { + let mut embeddings = self.embed(vec![text])?; + let embeddings = embeddings.pop().ok_or_else(EmbedError::missing_embedding)?; + Ok(if embeddings.iter().nth(1).is_some() { + tracing::warn!("Ignoring embeddings past the first one in long search query"); + embeddings.iter().next().unwrap().to_vec() + } else { + embeddings.into_inner() + }) + } + /// Embed multiple chunks of texts. /// /// Each chunk is composed of one or multiple texts.