Accept multiple vectors by documents using the _vectors field

This commit is contained in:
Kerollmops 2023-06-20 11:17:20 +02:00 committed by Clément Renault
parent 1b2923f7c0
commit 321ec5f3fa
No known key found for this signature in database
GPG key ID: 92ADA4E935E71FA4
4 changed files with 31 additions and 16 deletions

View file

@ -1,20 +1,22 @@
use std::convert::TryFrom;
use std::fs::File;
use std::io;
use bytemuck::cast_slice;
use either::Either;
use serde_json::from_slice;
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::{FieldId, InternalError, Result};
/// Extracts the embedding vector contained in each document under the `_vector` field.
/// Extracts the embedding vector contained in each document under the `_vectors` field.
///
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
#[logging_timer::time]
pub fn extract_vector_points<R: io::Read + io::Seek>(
obkv_documents: grenad::Reader<R>,
indexer: GrenadParameters,
vector_fid: FieldId,
vectors_fid: FieldId,
) -> Result<grenad::Reader<File>> {
let mut writer = create_writer(
indexer.chunk_compression_type,
@ -26,14 +28,26 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
while let Some((docid_bytes, value)) = cursor.move_on_next()? {
let obkv = obkv::KvReader::new(value);
// first we get the _vector field
if let Some(vector) = obkv.get(vector_fid) {
// try to extract the vector
let vector: Vec<f32> = from_slice(vector).map_err(InternalError::SerdeJson).unwrap();
let bytes = cast_slice(&vector);
writer.insert(docid_bytes, bytes)?;
// first we retrieve the _vectors field
if let Some(vectors) = obkv.get(vectors_fid) {
// extract the vectors
let vectors: Either<Vec<Vec<f32>>, Vec<f32>> =
from_slice(vectors).map_err(InternalError::SerdeJson).unwrap();
let vectors = vectors.map_right(|v| vec![v]).into_inner();
for (i, vector) in vectors.into_iter().enumerate() {
match u16::try_from(i) {
Ok(i) => {
let mut key = docid_bytes.to_vec();
key.extend_from_slice(&i.to_ne_bytes());
let bytes = cast_slice(&vector);
writer.insert(key, bytes)?;
}
Err(_) => continue,
}
}
}
// else => the _vector object was `null`, there is nothing to do
// else => the `_vectors` object was `null`, there is nothing to do
}
writer_into_reader(writer)

View file

@ -47,7 +47,7 @@ pub(crate) fn data_from_obkv_documents(
faceted_fields: HashSet<FieldId>,
primary_key_id: FieldId,
geo_fields_ids: Option<(FieldId, FieldId)>,
vector_field_id: Option<FieldId>,
vectors_field_id: Option<FieldId>,
stop_words: Option<fst::Set<&[u8]>>,
max_positions_per_attributes: Option<u32>,
exact_attributes: HashSet<FieldId>,
@ -72,7 +72,7 @@ pub(crate) fn data_from_obkv_documents(
&faceted_fields,
primary_key_id,
geo_fields_ids,
vector_field_id,
vectors_field_id,
&stop_words,
max_positions_per_attributes,
)