mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-04 20:37:15 +02:00
Make _vectors.:embedding.regenerate mandatory + tests + error messages
This commit is contained in:
parent
298c7b0c93
commit
1daaed163a
4 changed files with 336 additions and 16 deletions
|
@ -1,5 +1,6 @@
|
|||
use std::collections::{BTreeMap, BTreeSet};
|
||||
|
||||
use deserr::{take_cf_content, DeserializeError, Deserr, Sequence};
|
||||
use obkv::KvReader;
|
||||
use serde_json::{from_slice, Value};
|
||||
|
||||
|
@ -10,13 +11,44 @@ use crate::{DocumentId, FieldId, InternalError, UserError};
|
|||
|
||||
pub const RESERVED_VECTORS_FIELD_NAME: &str = "_vectors";
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
#[serde(untagged)]
|
||||
pub enum Vectors {
|
||||
ImplicitlyUserProvided(VectorOrArrayOfVectors),
|
||||
Explicit(ExplicitVectors),
|
||||
}
|
||||
|
||||
impl<E: DeserializeError> Deserr<E> for Vectors {
|
||||
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: deserr::ValuePointerRef,
|
||||
) -> Result<Self, E> {
|
||||
match value {
|
||||
deserr::Value::Sequence(_) | deserr::Value::Null => {
|
||||
Ok(Vectors::ImplicitlyUserProvided(VectorOrArrayOfVectors::deserialize_from_value(
|
||||
value, location,
|
||||
)?))
|
||||
}
|
||||
deserr::Value::Map(_) => {
|
||||
Ok(Vectors::Explicit(ExplicitVectors::deserialize_from_value(value, location)?))
|
||||
}
|
||||
|
||||
value => Err(take_cf_content(E::error(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: value,
|
||||
accepted: &[
|
||||
deserr::ValueKind::Sequence,
|
||||
deserr::ValueKind::Map,
|
||||
deserr::ValueKind::Null,
|
||||
],
|
||||
},
|
||||
location,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Vectors {
|
||||
pub fn must_regenerate(&self) -> bool {
|
||||
match self {
|
||||
|
@ -37,9 +69,11 @@ impl Vectors {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
#[derive(serde::Serialize, Deserr, Debug)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct ExplicitVectors {
|
||||
#[serde(default)]
|
||||
#[deserr(default)]
|
||||
pub embeddings: Option<VectorOrArrayOfVectors>,
|
||||
pub regenerate: bool,
|
||||
}
|
||||
|
@ -149,13 +183,20 @@ impl ParsedVectorsDiff {
|
|||
|
||||
pub struct ParsedVectors(pub BTreeMap<String, Vectors>);
|
||||
|
||||
impl<E: DeserializeError> Deserr<E> for ParsedVectors {
|
||||
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: deserr::ValuePointerRef,
|
||||
) -> Result<Self, E> {
|
||||
let value = <BTreeMap<String, Vectors>>::deserialize_from_value(value, location)?;
|
||||
Ok(ParsedVectors(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl ParsedVectors {
|
||||
pub fn from_bytes(value: &[u8]) -> Result<Self, Error> {
|
||||
let Ok(value) = from_slice(value) else {
|
||||
let value = from_slice(value).map_err(Error::InternalSerdeJson)?;
|
||||
return Err(Error::InvalidMap(value));
|
||||
};
|
||||
Ok(ParsedVectors(value))
|
||||
let value: serde_json::Value = from_slice(value).map_err(Error::InternalSerdeJson)?;
|
||||
deserr::deserialize(value).map_err(|error| Error::InvalidEmbedderConf { error })
|
||||
}
|
||||
|
||||
pub fn retain_not_embedded_vectors(&mut self, embedders: &BTreeSet<String>) {
|
||||
|
@ -165,6 +206,7 @@ impl ParsedVectors {
|
|||
|
||||
pub enum Error {
|
||||
InvalidMap(Value),
|
||||
InvalidEmbedderConf { error: deserr::errors::JsonError },
|
||||
InternalSerdeJson(serde_json::Error),
|
||||
}
|
||||
|
||||
|
@ -174,6 +216,12 @@ impl Error {
|
|||
Error::InvalidMap(value) => {
|
||||
crate::Error::UserError(UserError::InvalidVectorsMapType { document_id, value })
|
||||
}
|
||||
Error::InvalidEmbedderConf { error } => {
|
||||
crate::Error::UserError(UserError::InvalidVectorsEmbedderConf {
|
||||
document_id,
|
||||
error,
|
||||
})
|
||||
}
|
||||
Error::InternalSerdeJson(error) => {
|
||||
crate::Error::InternalError(InternalError::SerdeJson(error))
|
||||
}
|
||||
|
@ -194,13 +242,73 @@ fn to_vector_map(
|
|||
}
|
||||
|
||||
/// Represents either a vector or an array of multiple vectors.
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||
#[derive(serde::Serialize, Debug)]
|
||||
#[serde(transparent)]
|
||||
pub struct VectorOrArrayOfVectors {
|
||||
#[serde(with = "either::serde_untagged_optional")]
|
||||
inner: Option<either::Either<Vec<Embedding>, Embedding>>,
|
||||
}
|
||||
|
||||
impl<E: DeserializeError> Deserr<E> for VectorOrArrayOfVectors {
|
||||
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: deserr::ValuePointerRef,
|
||||
) -> Result<Self, E> {
|
||||
match value {
|
||||
deserr::Value::Null => Ok(VectorOrArrayOfVectors { inner: None }),
|
||||
deserr::Value::Sequence(seq) => {
|
||||
let mut iter = seq.into_iter();
|
||||
let location = location.push_index(0);
|
||||
match iter.next().map(|v| v.into_value()) {
|
||||
None => {
|
||||
// With the strange way serde serialize the `Either`, we must send the left part
|
||||
// otherwise it'll consider we returned [[]]
|
||||
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Left(Vec::new())) })
|
||||
}
|
||||
Some(val @ deserr::Value::Sequence(_)) => {
|
||||
let first = Embedding::deserialize_from_value(val, location)?;
|
||||
let mut collect = vec![first];
|
||||
let mut tail = iter
|
||||
.map(|v| Embedding::deserialize_from_value(v.into_value(), location))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
collect.append(&mut tail);
|
||||
|
||||
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Left(collect)) })
|
||||
}
|
||||
Some(
|
||||
val @ deserr::Value::Integer(_)
|
||||
| val @ deserr::Value::NegativeInteger(_)
|
||||
| val @ deserr::Value::Float(_),
|
||||
) => {
|
||||
let first = <f32>::deserialize_from_value(val, location)?;
|
||||
let mut embedding = iter
|
||||
.map(|v| <f32>::deserialize_from_value(v.into_value(), location))
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
embedding.insert(0, first);
|
||||
Ok(VectorOrArrayOfVectors { inner: Some(either::Either::Right(embedding)) })
|
||||
}
|
||||
Some(value) => Err(take_cf_content(E::error(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: value,
|
||||
accepted: &[deserr::ValueKind::Sequence, deserr::ValueKind::Float],
|
||||
},
|
||||
location,
|
||||
))),
|
||||
}
|
||||
}
|
||||
value => Err(take_cf_content(E::error(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: value,
|
||||
accepted: &[deserr::ValueKind::Sequence, deserr::ValueKind::Null],
|
||||
},
|
||||
location,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl VectorOrArrayOfVectors {
|
||||
pub fn into_array_of_vectors(self) -> Option<Vec<Embedding>> {
|
||||
match self.inner? {
|
||||
|
@ -234,15 +342,19 @@ impl From<Vec<Embedding>> for VectorOrArrayOfVectors {
|
|||
mod test {
|
||||
use super::VectorOrArrayOfVectors;
|
||||
|
||||
fn embedding_from_str(s: &str) -> Result<VectorOrArrayOfVectors, deserr::errors::JsonError> {
|
||||
let value: serde_json::Value = serde_json::from_str(s).unwrap();
|
||||
deserr::deserialize(value)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn array_of_vectors() {
|
||||
let null: VectorOrArrayOfVectors = serde_json::from_str("null").unwrap();
|
||||
let empty: VectorOrArrayOfVectors = serde_json::from_str("[]").unwrap();
|
||||
let one: VectorOrArrayOfVectors = serde_json::from_str("[0.1]").unwrap();
|
||||
let two: VectorOrArrayOfVectors = serde_json::from_str("[0.1, 0.2]").unwrap();
|
||||
let one_vec: VectorOrArrayOfVectors = serde_json::from_str("[[0.1, 0.2]]").unwrap();
|
||||
let two_vecs: VectorOrArrayOfVectors =
|
||||
serde_json::from_str("[[0.1, 0.2], [0.3, 0.4]]").unwrap();
|
||||
let null = embedding_from_str("null").unwrap();
|
||||
let empty = embedding_from_str("[]").unwrap();
|
||||
let one = embedding_from_str("[0.1]").unwrap();
|
||||
let two = embedding_from_str("[0.1, 0.2]").unwrap();
|
||||
let one_vec = embedding_from_str("[[0.1, 0.2]]").unwrap();
|
||||
let two_vecs = embedding_from_str("[[0.1, 0.2], [0.3, 0.4]]").unwrap();
|
||||
|
||||
insta::assert_json_snapshot!(null.into_array_of_vectors(), @"null");
|
||||
insta::assert_json_snapshot!(empty.into_array_of_vectors(), @"[]");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue