Make _vectors.:embedding.regenerate mandatory + tests + error messages

This commit is contained in:
Tamo 2024-06-27 11:01:52 +02:00
parent 298c7b0c93
commit 1daaed163a
4 changed files with 336 additions and 16 deletions

View file

@ -119,6 +119,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
InvalidVectorDimensions { expected: usize, found: usize },
#[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")]
InvalidVectorsMapType { document_id: String, value: Value },
#[error("Bad embedder configuration in the document with id: `{document_id}`. {error}")]
InvalidVectorsEmbedderConf { document_id: String, error: deserr::errors::JsonError },
#[error("{0}")]
InvalidFilter(String),
#[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]

View file

@ -1,5 +1,6 @@
use std::collections::{BTreeMap, BTreeSet};
use deserr::{take_cf_content, DeserializeError, Deserr, Sequence};
use obkv::KvReader;
use serde_json::{from_slice, Value};
@ -10,13 +11,44 @@ use crate::{DocumentId, FieldId, InternalError, UserError};
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
#[derive(serde::Serialize, serde::Deserialize, Debug)]
#[derive(serde::Serialize, Debug)]
#[serde(untagged)]
pub enum Vectors {
ImplicitlyUserProvided(VectorOrArrayOfVectors),
Explicit(ExplicitVectors),
}
impl<E: DeserializeError> Deserr<E> for Vectors {
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef,
) -> Result<Self, E> {
match value {
deserr::Value::Sequence(_) | deserr::Value::Null => {
Ok(Vectors::ImplicitlyUserProvided(VectorOrArrayOfVectors::deserialize_from_value(
value, location,
)?))
}
deserr::Value::Map(_) => {
Ok(Vectors::Explicit(ExplicitVectors::deserialize_from_value(value, location)?))
}
value => Err(take_cf_content(E::error(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: value,
accepted: &[
deserr::ValueKind::Sequence,
deserr::ValueKind::Map,
deserr::ValueKind::Null,
],
},
location,
))),
}
}
}
impl Vectors {
pub fn must_regenerate(&self) -> bool {
match self {
@ -37,9 +69,11 @@ impl Vectors {
}
}
#[derive(serde::Serialize, serde::Deserialize, Debug)]
#[derive(serde::Serialize, Deserr, Debug)]
#[serde(rename_all = "camelCase")]
pub struct ExplicitVectors {
#[serde(default)]
#[deserr(default)]
pub embeddings: Option<VectorOrArrayOfVectors>,
pub regenerate: bool,
}
@ -149,13 +183,20 @@ impl ParsedVectorsDiff {
pub struct ParsedVectors(pub BTreeMap<String, Vectors>);
impl<E: DeserializeError> Deserr<E> for ParsedVectors {
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef,
) -> Result<Self, E> {
let value = <BTreeMap<String, Vectors>>::deserialize_from_value(value, location)?;
Ok(ParsedVectors(value))
}
}
impl ParsedVectors {
pub fn from_bytes(value: &[u8]) -> Result<Self, Error> {
let Ok(value) = from_slice(value) else {
let value = from_slice(value).map_err(Error::InternalSerdeJson)?;
return Err(Error::InvalidMap(value));
};
Ok(ParsedVectors(value))
let value: serde_json::Value = from_slice(value).map_err(Error::InternalSerdeJson)?;
deserr::deserialize(value).map_err(|error| Error::InvalidEmbedderConf { error })
}
pub fn retain_not_embedded_vectors(&mut self, embedders: &BTreeSet<String>) {
@ -165,6 +206,7 @@ impl ParsedVectors {
pub enum Error {
InvalidMap(Value),
InvalidEmbedderConf { error: deserr::errors::JsonError },
InternalSerdeJson(serde_json::Error),
}
@ -174,6 +216,12 @@ impl Error {
Error::InvalidMap(value) => {
crate::Error::UserError(UserError::InvalidVectorsMapType { document_id, value })
}
Error::InvalidEmbedderConf { error } => {
crate::Error::UserError(UserError::InvalidVectorsEmbedderConf {
document_id,
error,
})
}
Error::InternalSerdeJson(error) => {
crate::Error::InternalError(InternalError::SerdeJson(error))
}
@ -194,13 +242,73 @@ fn to_vector_map(
}
/// Represents either a vector or an array of multiple vectors.
#[derive(serde::Serialize, serde::Deserialize, Debug)]
#[derive(serde::Serialize, Debug)]
#[serde(transparent)]
pub struct VectorOrArrayOfVectors {
#[serde(with = "either::serde_untagged_optional")]
inner: Option<either::Either<Vec<Embedding>, Embedding>>,
}
impl<E: DeserializeError> Deserr<E> for VectorOrArrayOfVectors {
fn deserialize_from_value<V: deserr::IntoValue>(
value: deserr::Value<V>,
location: deserr::ValuePointerRef,
) -> Result<Self, E> {
match value {
deserr::Value::Null => Ok(VectorOrArrayOfVectors { inner: None }),
deserr::Value::Sequence(seq) => {
let mut iter = seq.into_iter();
let location = location.push_index(0);
match iter.next().map(|v| v.into_value()) {
None => {
// With the strange way serde serialize the `Either`, we must send the left part
// otherwise it'll consider we returned [[]]
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Left(Vec::new())) })
}
Some(val @ deserr::Value::Sequence(_)) => {
let first = Embedding::deserialize_from_value(val, location)?;
let mut collect = vec![first];
let mut tail = iter
.map(|v| Embedding::deserialize_from_value(v.into_value(), location))
.collect::<Result<Vec<_>, _>>()?;
collect.append(&mut tail);
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Left(collect)) })
}
Some(
val @ deserr::Value::Integer(_)
| val @ deserr::Value::NegativeInteger(_)
| val @ deserr::Value::Float(_),
) => {
let first = <f32>::deserialize_from_value(val, location)?;
let mut embedding = iter
.map(|v| <f32>::deserialize_from_value(v.into_value(), location))
.collect::<Result<Vec<_>, _>>()?;
embedding.insert(0, first);
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Right(embedding)) })
}
Some(value) => Err(take_cf_content(E::error(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: value,
accepted: &[deserr::ValueKind::Sequence, deserr::ValueKind::Float],
},
location,
))),
}
}
value => Err(take_cf_content(E::error(
None,
deserr::ErrorKind::IncorrectValueKind {
actual: value,
accepted: &[deserr::ValueKind::Sequence, deserr::ValueKind::Null],
},
location,
))),
}
}
}
impl VectorOrArrayOfVectors {
pub fn into_array_of_vectors(self) -> Option<Vec<Embedding>> {
match self.inner? {
@ -234,15 +342,19 @@ impl From<Vec<Embedding>> for VectorOrArrayOfVectors {
mod test {
use super::VectorOrArrayOfVectors;
fn embedding_from_str(s: &str) -> Result<VectorOrArrayOfVectors, deserr::errors::JsonError> {
let value: serde_json::Value = serde_json::from_str(s).unwrap();
deserr::deserialize(value)
}
#[test]
fn array_of_vectors() {
let null: VectorOrArrayOfVectors = serde_json::from_str("null").unwrap();
let empty: VectorOrArrayOfVectors = serde_json::from_str("[]").unwrap();
let one: VectorOrArrayOfVectors = serde_json::from_str("[0.1]").unwrap();
let two: VectorOrArrayOfVectors = serde_json::from_str("[0.1, 0.2]").unwrap();
let one_vec: VectorOrArrayOfVectors = serde_json::from_str("[[0.1, 0.2]]").unwrap();
let two_vecs: VectorOrArrayOfVectors =
serde_json::from_str("[[0.1, 0.2], [0.3, 0.4]]").unwrap();
let null = embedding_from_str("null").unwrap();
let empty = embedding_from_str("[]").unwrap();
let one = embedding_from_str("[0.1]").unwrap();
let two = embedding_from_str("[0.1, 0.2]").unwrap();
let one_vec = embedding_from_str("[[0.1, 0.2]]").unwrap();
let two_vecs = embedding_from_str("[[0.1, 0.2], [0.3, 0.4]]").unwrap();
insta::assert_json_snapshot!(null.into_array_of_vectors(), @"null");
insta::assert_json_snapshot!(empty.into_array_of_vectors(), @"[]");