Emit better error messages when parsing vectors

This commit is contained in:
Louis Dureuil 2024-11-12 22:49:22 +01:00
parent 8a6e61c77f
commit c4e9f761e9
No known key found for this signature in database
4 changed files with 242 additions and 37 deletions

View file

@ -19,10 +19,54 @@ pub enum RawVectors<'doc> {
ImplicitlyUserProvided(#[serde(borrow)] Option<&'doc RawValue>),
}
pub enum RawVectorsError {
DeserializeSeq { index: usize, error: String },
DeserializeKey { error: String },
DeserializeRegenerate { error: String },
DeserializeEmbeddings { error: String },
UnknownField { field: String },
MissingRegenerate,
WrongKind { kind: &'static str, value: String },
Parsing(serde_json::Error),
}
impl RawVectorsError {
pub fn msg(self, embedder_name: &str) -> String {
match self {
RawVectorsError::DeserializeSeq { index, error } => format!(
"Could not parse `._vectors.{embedder_name}[{index}]`: {error}"
),
RawVectorsError::DeserializeKey { error } => format!(
"Could not parse a field at `._vectors.{embedder_name}`: {error}"
),
RawVectorsError::DeserializeRegenerate { error } => format!(
"Could not parse `._vectors.{embedder_name}.regenerate`: {error}"
),
RawVectorsError::DeserializeEmbeddings { error } => format!(
"Could not parse `._vectors.{embedder_name}.embeddings`: {error}"
),
RawVectorsError::UnknownField { field } => format!(
"Unexpected field `._vectors.{embedder_name}.{field}`\n \
\t - note: the allowed fields are `regenerate` and `embeddings`"
),
RawVectorsError::MissingRegenerate => format!(
"Missing field `._vectors.{embedder_name}.regenerate`\n \
\t - note: `._vectors.{embedder_name}` must be an array of floats, an array of arrays of floats, or an object with field `regenerate`"
),
RawVectorsError::WrongKind { kind, value } => format!(
"Expected `._vectors.{embedder_name}` to be an array of floats, an array of arrays of floats, or an object with at least the field `regenerate`, but got the {kind} `{value}`"
),
RawVectorsError::Parsing(error) => format!(
"Could not parse `._vectors.{embedder_name}`: {error}"
),
}
}
}
impl<'doc> RawVectors<'doc> {
pub fn from_raw_value(raw: &'doc RawValue) -> Result<Self, serde_json::Error> {
pub fn from_raw_value(raw: &'doc RawValue) -> Result<Self, RawVectorsError> {
use serde::de::Deserializer as _;
Ok(match raw.deserialize_any(RawVectorsVisitor)? {
Ok(match raw.deserialize_any(RawVectorsVisitor).map_err(RawVectorsError::Parsing)?? {
RawVectorsVisitorValue::ImplicitNone => RawVectors::ImplicitlyUserProvided(None),
RawVectorsVisitorValue::Implicit => RawVectors::ImplicitlyUserProvided(Some(raw)),
RawVectorsVisitorValue::Explicit { regenerate, embeddings } => {
@ -41,7 +85,7 @@ enum RawVectorsVisitorValue<'doc> {
}
impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor {
type Value = RawVectorsVisitorValue<'doc>;
type Value = std::result::Result<RawVectorsVisitorValue<'doc>, RawVectorsError>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "a map containing at least `regenerate`, or an array of floats`")
@ -51,7 +95,7 @@ impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor {
where
E: serde::de::Error,
{
Ok(RawVectorsVisitorValue::ImplicitNone)
Ok(Ok(RawVectorsVisitorValue::ImplicitNone))
}
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
@ -65,42 +109,150 @@ impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor {
where
E: serde::de::Error,
{
Ok(RawVectorsVisitorValue::ImplicitNone)
Ok(Ok(RawVectorsVisitorValue::ImplicitNone))
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'doc>,
{
let mut index = 0;
// must consume all elements or parsing fails
while let Some(_) = seq.next_element::<&RawValue>()? {}
Ok(RawVectorsVisitorValue::Implicit)
loop {
match seq.next_element::<&RawValue>() {
Ok(Some(_)) => index += 1,
Err(error) => {
return Ok(Err(RawVectorsError::DeserializeSeq {
index,
error: error.to_string(),
}))
}
Ok(None) => break,
};
}
Ok(Ok(RawVectorsVisitorValue::Implicit))
}
fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: serde::de::MapAccess<'doc>,
{
use serde::de::Error as _;
let mut regenerate = None;
let mut embeddings = None;
while let Some(s) = map.next_key()? {
match s {
"regenerate" => {
let value: bool = map.next_value()?;
loop {
match map.next_key::<&str>() {
Ok(Some("regenerate")) => {
let value: bool = match map.next_value() {
Ok(value) => value,
Err(error) => {
return Ok(Err(RawVectorsError::DeserializeRegenerate {
error: error.to_string(),
}))
}
};
regenerate = Some(value);
}
"embeddings" => {
let value: &RawValue = map.next_value()?;
Ok(Some("embeddings")) => {
let value: &RawValue = match map.next_value() {
Ok(value) => value,
Err(error) => {
return Ok(Err(RawVectorsError::DeserializeEmbeddings {
error: error.to_string(),
}))
}
};
embeddings = Some(value);
}
other => return Err(A::Error::unknown_field(other, &["regenerate", "embeddings"])),
Ok(Some(other)) => {
return Ok(Err(RawVectorsError::UnknownField { field: other.to_string() }))
}
Ok(None) => break,
Err(error) => {
return Ok(Err(RawVectorsError::DeserializeKey { error: error.to_string() }))
}
}
}
let Some(regenerate) = regenerate else {
return Err(A::Error::missing_field("regenerate"));
return Ok(Err(RawVectorsError::MissingRegenerate));
};
Ok(RawVectorsVisitorValue::Explicit { regenerate, embeddings })
Ok(Ok(RawVectorsVisitorValue::Explicit { regenerate, embeddings }))
}
fn visit_bool<E>(self, v: bool) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Err(RawVectorsError::WrongKind { kind: "boolean", value: v.to_string() }))
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() }))
}
fn visit_i128<E>(self, v: i128) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() }))
}
fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() }))
}
fn visit_u128<E>(self, v: u128) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() }))
}
fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Err(RawVectorsError::WrongKind { kind: "number", value: v.to_string() }))
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Err(RawVectorsError::WrongKind { kind: "string", value: v.to_string() }))
}
fn visit_string<E>(self, v: String) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Err(RawVectorsError::WrongKind { kind: "string", value: v }))
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Ok(Err(RawVectorsError::WrongKind { kind: "bytes", value: format!("{v:?}") }))
}
fn visit_newtype_struct<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: serde::Deserializer<'doc>,
{
deserializer.deserialize_any(self)
}
fn visit_enum<A>(self, _data: A) -> Result<Self::Value, A::Error>
where
A: serde::de::EnumAccess<'doc>,
{
Ok(Err(RawVectorsError::WrongKind { kind: "enum", value: "a variant".to_string() }))
}
}