Return a user error when the _vectors type is invalid

This commit is contained in:
Kerollmops 2023-06-20 16:18:24 +02:00 committed by Clément Renault
parent 7aa1275337
commit 531748c536
No known key found for this signature in database
GPG Key ID: 92ADA4E935E71FA4
4 changed files with 27 additions and 7 deletions

View File

@ -218,6 +218,7 @@ MissingDocumentFilter , InvalidRequest , BAD_REQUEST ;
InvalidDocumentFilter , InvalidRequest , BAD_REQUEST ;
InvalidDocumentGeoField , InvalidRequest , BAD_REQUEST ;
InvalidVectorDimensions , InvalidRequest , BAD_REQUEST ;
InvalidVectorsType , InvalidRequest , BAD_REQUEST ;
InvalidDocumentId , InvalidRequest , BAD_REQUEST ;
InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ;
InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ;
@ -337,6 +338,7 @@ impl ErrorCode for milli::Error {
UserError::CriterionError(_) => Code::InvalidSettingsRankingRules,
UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField,
UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions,
UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType,
UserError::SortError(_) => Code::InvalidSearchSort,
UserError::InvalidMinTypoWordLenSetting(_, _) => {
Code::InvalidSettingsTypoTolerance

View File

@ -112,6 +112,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
InvalidGeoField(#[from] GeoError),
#[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)]
InvalidVectorDimensions { expected: usize, found: usize },
#[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")]
InvalidVectorsType { document_id: Value, value: Value },
#[error("{0}")]
InvalidFilter(String),
#[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]

View File

@ -3,9 +3,10 @@ use std::fs::File;
use std::io;
use bytemuck::cast_slice;
use serde_json::from_slice;
use serde_json::{from_slice, Value};
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::error::UserError;
use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors};
/// Extracts the embedding vector contained in each document under the `_vectors` field.
@ -15,6 +16,7 @@ use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors};
pub fn extract_vector_points<R: io::Read + io::Seek>(
obkv_documents: grenad::Reader<R>,
indexer: GrenadParameters,
primary_key_id: FieldId,
vectors_fid: FieldId,
) -> Result<grenad::Reader<File>> {
let mut writer = create_writer(
@ -27,14 +29,23 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
while let Some((docid_bytes, value)) = cursor.move_on_next()? {
let obkv = obkv::KvReader::new(value);
// since we only needs the primary key when we throw an error we create this getter to
// lazily get it when needed
let document_id = || -> Value {
let document_id = obkv.get(primary_key_id).unwrap();
serde_json::from_slice(document_id).unwrap()
};
// first we retrieve the _vectors field
if let Some(vectors) = obkv.get(vectors_fid) {
// extract the vectors
// TODO return a user error before unwrapping
let vectors = from_slice(vectors)
.map_err(InternalError::SerdeJson)
.map(VectorOrArrayOfVectors::into_array_of_vectors)
.unwrap();
let vectors = match from_slice(vectors) {
Ok(vectors) => VectorOrArrayOfVectors::into_array_of_vectors(vectors),
Err(_) => return Err(UserError::InvalidVectorsType {
document_id: document_id(),
value: from_slice(vectors).map_err(InternalError::SerdeJson)?,
}.into()),
};
for (i, vector) in vectors.into_iter().enumerate() {
match u16::try_from(i) {

View File

@ -316,7 +316,12 @@ fn send_and_extract_flattened_documents_data(
let documents_chunk_cloned = flattened_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
rayon::spawn(move || {
let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id);
let result = extract_vector_points(
documents_chunk_cloned,
indexer,
primary_key_id,
vectors_field_id,
);
let _ = match result {
Ok(vector_points) => {
lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))