From c4e9f761e95f900a8d6648c34cdea7668b4acb6b Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Nov 2024 22:49:22 +0100 Subject: [PATCH] Emit better error messages when parsing vectors --- .../milli/src/update/new/document_change.rs | 20 +- .../src/update/new/extract/vectors/mod.rs | 4 +- .../milli/src/update/new/vector_document.rs | 69 +++++-- crates/milli/src/vector/parsed_vectors.rs | 186 ++++++++++++++++-- 4 files changed, 242 insertions(+), 37 deletions(-) diff --git a/crates/milli/src/update/new/document_change.rs b/crates/milli/src/update/new/document_change.rs index 4a61c110d..899655db1 100644 --- a/crates/milli/src/update/new/document_change.rs +++ b/crates/milli/src/update/new/document_change.rs @@ -97,7 +97,7 @@ impl<'doc> Insertion<'doc> { doc_alloc: &'doc Bump, embedders: &'doc EmbeddingConfigs, ) -> Result>> { - VectorDocumentFromVersions::new(&self.new, doc_alloc, embedders) + VectorDocumentFromVersions::new(self.external_document_id, &self.new, doc_alloc, embedders) } } @@ -169,7 +169,7 @@ impl<'doc> Update<'doc> { doc_alloc: &'doc Bump, embedders: &'doc EmbeddingConfigs, ) -> Result>> { - VectorDocumentFromVersions::new(&self.new, doc_alloc, embedders) + VectorDocumentFromVersions::new(self.external_document_id, &self.new, doc_alloc, embedders) } pub fn merged_vectors( @@ -181,10 +181,22 @@ impl<'doc> Update<'doc> { embedders: &'doc EmbeddingConfigs, ) -> Result>> { if self.has_deletion { - MergedVectorDocument::without_db(&self.new, doc_alloc, embedders) + MergedVectorDocument::without_db( + self.external_document_id, + &self.new, + doc_alloc, + embedders, + ) } else { MergedVectorDocument::with_db( - self.docid, index, rtxn, mapper, &self.new, doc_alloc, embedders, + self.docid, + self.external_document_id, + index, + rtxn, + mapper, + &self.new, + doc_alloc, + embedders, ) } } diff --git a/crates/milli/src/update/new/extract/vectors/mod.rs b/crates/milli/src/update/new/extract/vectors/mod.rs index 514791a65..efb02b2ab 100644 --- a/crates/milli/src/update/new/extract/vectors/mod.rs +++ b/crates/milli/src/update/new/extract/vectors/mod.rs @@ -126,7 +126,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> { .into_vec(&context.doc_alloc, embedder_name) .map_err(|error| UserError::InvalidVectorsEmbedderConf { document_id: update.external_document_id().to_string(), - error, + error: error.to_string(), })?, ); } else if new_vectors.regenerate { @@ -210,7 +210,7 @@ impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> { document_id: insertion .external_document_id() .to_string(), - error, + error: error.to_string(), })?, ); } else if new_vectors.regenerate { diff --git a/crates/milli/src/update/new/vector_document.rs b/crates/milli/src/update/new/vector_document.rs index 381c4dab6..736456f0f 100644 --- a/crates/milli/src/update/new/vector_document.rs +++ b/crates/milli/src/update/new/vector_document.rs @@ -12,7 +12,7 @@ use super::indexer::de::DeserrRawValue; use crate::documents::FieldIdMapper; use crate::index::IndexEmbeddingConfig; use crate::vector::parsed_vectors::{ - RawVectors, VectorOrArrayOfVectors, RESERVED_VECTORS_FIELD_NAME, + RawVectors, RawVectorsError, VectorOrArrayOfVectors, RESERVED_VECTORS_FIELD_NAME, }; use crate::vector::{ArroyWrapper, Embedding, EmbeddingConfigs}; use crate::{DocumentId, Index, InternalError, Result, UserError}; @@ -143,7 +143,14 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> { Ok((&*config_name, entry)) }) .chain(self.vectors_field.iter().flat_map(|map| map.iter()).map(|(name, value)| { - Ok((name, entry_from_raw_value(value, false).map_err(InternalError::SerdeJson)?)) + Ok(( + name, + entry_from_raw_value(value, false).map_err(|_| { + InternalError::Serialization(crate::SerializationError::Decoding { + db_name: Some(crate::index::db_name::VECTOR_ARROY), + }) + })?, + )) })) } @@ -155,20 +162,38 @@ impl<'t> VectorDocument<'t> for VectorDocumentFromDb<'t> { Some(self.entry_from_db(embedder_id, config)?) } None => match self.vectors_field.as_ref().and_then(|obkv| obkv.get(key)) { - Some(embedding_from_doc) => Some( - entry_from_raw_value(embedding_from_doc, false) - .map_err(InternalError::SerdeJson)?, - ), + Some(embedding_from_doc) => { + Some(entry_from_raw_value(embedding_from_doc, false).map_err(|_| { + InternalError::Serialization(crate::SerializationError::Decoding { + db_name: Some(crate::index::db_name::VECTOR_ARROY), + }) + })?) + } None => None, }, }) } } +fn entry_from_raw_value_user<'doc>( + external_docid: &str, + embedder_name: &str, + value: &'doc RawValue, + has_configured_embedder: bool, +) -> Result> { + entry_from_raw_value(value, has_configured_embedder).map_err(|error| { + UserError::InvalidVectorsEmbedderConf { + document_id: external_docid.to_string(), + error: error.msg(embedder_name), + } + .into() + }) +} + fn entry_from_raw_value( value: &RawValue, has_configured_embedder: bool, -) -> std::result::Result, serde_json::Error> { +) -> std::result::Result, RawVectorsError> { let value: RawVectors = RawVectors::from_raw_value(value)?; Ok(match value { @@ -194,12 +219,14 @@ fn entry_from_raw_value( } pub struct VectorDocumentFromVersions<'doc> { + external_document_id: &'doc str, vectors: RawMap<'doc>, embedders: &'doc EmbeddingConfigs, } impl<'doc> VectorDocumentFromVersions<'doc> { pub fn new( + external_document_id: &'doc str, versions: &Versions<'doc>, bump: &'doc Bump, embedders: &'doc EmbeddingConfigs, @@ -208,7 +235,7 @@ impl<'doc> VectorDocumentFromVersions<'doc> { if let Some(vectors_field) = document.vectors_field()? { let vectors = RawMap::from_raw_value(vectors_field, bump).map_err(UserError::SerdeJson)?; - Ok(Some(Self { vectors, embedders })) + Ok(Some(Self { external_document_id, vectors, embedders })) } else { Ok(None) } @@ -218,16 +245,24 @@ impl<'doc> VectorDocumentFromVersions<'doc> { impl<'doc> VectorDocument<'doc> for VectorDocumentFromVersions<'doc> { fn iter_vectors(&self) -> impl Iterator)>> { self.vectors.iter().map(|(embedder, vectors)| { - let vectors = entry_from_raw_value(vectors, self.embedders.contains(embedder)) - .map_err(UserError::SerdeJson)?; + let vectors = entry_from_raw_value_user( + self.external_document_id, + embedder, + vectors, + self.embedders.contains(embedder), + )?; Ok((embedder, vectors)) }) } fn vectors_for_key(&self, key: &str) -> Result>> { let Some(vectors) = self.vectors.get(key) else { return Ok(None) }; - let vectors = entry_from_raw_value(vectors, self.embedders.contains(key)) - .map_err(UserError::SerdeJson)?; + let vectors = entry_from_raw_value_user( + self.external_document_id, + key, + vectors, + self.embedders.contains(key), + )?; Ok(Some(vectors)) } } @@ -238,8 +273,10 @@ pub struct MergedVectorDocument<'doc> { } impl<'doc> MergedVectorDocument<'doc> { + #[allow(clippy::too_many_arguments)] pub fn with_db( docid: DocumentId, + external_document_id: &'doc str, index: &'doc Index, rtxn: &'doc RoTxn, db_fields_ids_map: &'doc Mapper, @@ -248,16 +285,20 @@ impl<'doc> MergedVectorDocument<'doc> { embedders: &'doc EmbeddingConfigs, ) -> Result> { let db = VectorDocumentFromDb::new(docid, index, rtxn, db_fields_ids_map, doc_alloc)?; - let new_doc = VectorDocumentFromVersions::new(versions, doc_alloc, embedders)?; + let new_doc = + VectorDocumentFromVersions::new(&external_document_id, versions, doc_alloc, embedders)?; Ok(if db.is_none() && new_doc.is_none() { None } else { Some(Self { new_doc, db }) }) } pub fn without_db( + external_document_id: &'doc str, versions: &Versions<'doc>, doc_alloc: &'doc Bump, embedders: &'doc EmbeddingConfigs, ) -> Result> { - let Some(new_doc) = VectorDocumentFromVersions::new(versions, doc_alloc, embedders)? else { + let Some(new_doc) = + VectorDocumentFromVersions::new(external_document_id, versions, doc_alloc, embedders)? + else { return Ok(None); }; Ok(Some(Self { new_doc: Some(new_doc), db: None })) diff --git a/crates/milli/src/vector/parsed_vectors.rs b/crates/milli/src/vector/parsed_vectors.rs index 5f8b30f1f..a45729abd 100644 --- a/crates/milli/src/vector/parsed_vectors.rs +++ b/crates/milli/src/vector/parsed_vectors.rs @@ -19,10 +19,54 @@ pub enum RawVectors<'doc> { ImplicitlyUserProvided(#[serde(borrow)] Option<&'doc RawValue>), } +pub enum RawVectorsError { + DeserializeSeq { index: usize, error: String }, + DeserializeKey { error: String }, + DeserializeRegenerate { error: String }, + DeserializeEmbeddings { error: String }, + UnknownField { field: String }, + MissingRegenerate, + WrongKind { kind: &'static str, value: String }, + Parsing(serde_json::Error), +} + +impl RawVectorsError { + pub fn msg(self, embedder_name: &str) -> String { + match self { + RawVectorsError::DeserializeSeq { index, error } => format!( + "Could not parse `._vectors.{embedder_name}[{index}]`: {error}" + ), + RawVectorsError::DeserializeKey { error } => format!( + "Could not parse a field at `._vectors.{embedder_name}`: {error}" + ), + RawVectorsError::DeserializeRegenerate { error } => format!( + "Could not parse `._vectors.{embedder_name}.regenerate`: {error}" + ), + RawVectorsError::DeserializeEmbeddings { error } => format!( + "Could not parse `._vectors.{embedder_name}.embeddings`: {error}" + ), + RawVectorsError::UnknownField { field } => format!( + "Unexpected field `._vectors.{embedder_name}.{field}`\n \ + \t - note: the allowed fields are `regenerate` and `embeddings`" + ), + RawVectorsError::MissingRegenerate => format!( + "Missing field `._vectors.{embedder_name}.regenerate`\n \ + \t - note: `._vectors.{embedder_name}` must be an array of floats, an array of arrays of floats, or an object with field `regenerate`" + ), + RawVectorsError::WrongKind { kind, value } => format!( + "Expected `._vectors.{embedder_name}` to be an array of floats, an array of arrays of floats, or an object with at least the field `regenerate`, but got the {kind} `{value}`" + ), + RawVectorsError::Parsing(error) => format!( + "Could not parse `._vectors.{embedder_name}`: {error}" + ), + } + } +} + impl<'doc> RawVectors<'doc> { - pub fn from_raw_value(raw: &'doc RawValue) -> Result { + pub fn from_raw_value(raw: &'doc RawValue) -> Result { use serde::de::Deserializer as _; - Ok(match raw.deserialize_any(RawVectorsVisitor)? { + Ok(match raw.deserialize_any(RawVectorsVisitor).map_err(RawVectorsError::Parsing)?? { RawVectorsVisitorValue::ImplicitNone => RawVectors::ImplicitlyUserProvided(None), RawVectorsVisitorValue::Implicit => RawVectors::ImplicitlyUserProvided(Some(raw)), RawVectorsVisitorValue::Explicit { regenerate, embeddings } => { @@ -41,7 +85,7 @@ enum RawVectorsVisitorValue<'doc> { } impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor { - type Value = RawVectorsVisitorValue<'doc>; + type Value = std::result::Result, RawVectorsError>; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { write!(formatter, "a map containing at least `regenerate`, or an array of floats`") @@ -51,7 +95,7 @@ impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor { where E: serde::de::Error, { - Ok(RawVectorsVisitorValue::ImplicitNone) + Ok(Ok(RawVectorsVisitorValue::ImplicitNone)) } fn visit_some(self, deserializer: D) -> Result @@ -65,42 +109,150 @@ impl<'doc> serde::de::Visitor<'doc> for RawVectorsVisitor { where E: serde::de::Error, { - Ok(RawVectorsVisitorValue::ImplicitNone) + Ok(Ok(RawVectorsVisitorValue::ImplicitNone)) } fn visit_seq(self, mut seq: A) -> Result where A: serde::de::SeqAccess<'doc>, { + let mut index = 0; // must consume all elements or parsing fails - while let Some(_) = seq.next_element::<&RawValue>()? {} - Ok(RawVectorsVisitorValue::Implicit) + loop { + match seq.next_element::<&RawValue>() { + Ok(Some(_)) => index += 1, + Err(error) => { + return Ok(Err(RawVectorsError::DeserializeSeq { + index, + error: error.to_string(), + })) + } + Ok(None) => break, + }; + } + Ok(Ok(RawVectorsVisitorValue::Implicit)) } fn visit_map(self, mut map: A) -> Result where A: serde::de::MapAccess<'doc>, { - use serde::de::Error as _; let mut regenerate = None; let mut embeddings = None; - while let Some(s) = map.next_key()? { - match s { - "regenerate" => { - let value: bool = map.next_value()?; + loop { + match map.next_key::<&str>() { + Ok(Some("regenerate")) => { + let value: bool = match map.next_value() { + Ok(value) => value, + Err(error) => { + return Ok(Err(RawVectorsError::DeserializeRegenerate { + error: error.to_string(), + })) + } + }; regenerate = Some(value); } - "embeddings" => { - let value: &RawValue = map.next_value()?; + Ok(Some("embeddings")) => { + let value: &RawValue = match map.next_value() { + Ok(value) => value, + Err(error) => { + return Ok(Err(RawVectorsError::DeserializeEmbeddings { + error: error.to_string(), + })) + } + }; embeddings = Some(value); } - other => return Err(A::Error::unknown_field(other, &["regenerate", "embeddings"])), + Ok(Some(other)) => { + return Ok(Err(RawVectorsError::UnknownField { field: other.to_string() })) + } + Ok(None) => break, + Err(error) => { + return Ok(Err(RawVectorsError::DeserializeKey { error: error.to_string() })) + } } } let Some(regenerate) = regenerate else { - return Err(A::Error::missing_field("regenerate")); + return Ok(Err(RawVectorsError::MissingRegenerate)); }; - Ok(RawVectorsVisitorValue::Explicit { regenerate, embeddings }) + Ok(Ok(RawVectorsVisitorValue::Explicit { regenerate, embeddings })) + } + + fn visit_bool(self, v: bool) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "boolean", value: v.to_string() })) + } + + fn visit_i64(self, v: i64) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() })) + } + + fn visit_i128(self, v: i128) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() })) + } + + fn visit_u64(self, v: u64) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() })) + } + + fn visit_u128(self, v: u128) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "integer", value: v.to_string() })) + } + + fn visit_f64(self, v: f64) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "number", value: v.to_string() })) + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "string", value: v.to_string() })) + } + + fn visit_string(self, v: String) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "string", value: v })) + } + + fn visit_bytes(self, v: &[u8]) -> Result + where + E: serde::de::Error, + { + Ok(Err(RawVectorsError::WrongKind { kind: "bytes", value: format!("{v:?}") })) + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: serde::Deserializer<'doc>, + { + deserializer.deserialize_any(self) + } + + fn visit_enum(self, _data: A) -> Result + where + A: serde::de::EnumAccess<'doc>, + { + Ok(Err(RawVectorsError::WrongKind { kind: "enum", value: "a variant".to_string() })) } }