2023-06-20 11:17:20 +02:00
|
|
|
use std::convert::TryFrom;
|
2023-06-08 11:35:36 +02:00
|
|
|
use std::fs::File;
|
|
|
|
use std::io;
|
|
|
|
|
|
|
|
use bytemuck::cast_slice;
|
2023-06-20 16:18:24 +02:00
|
|
|
use serde_json::{from_slice, Value};
|
2023-06-08 11:35:36 +02:00
|
|
|
|
|
|
|
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
|
2023-06-20 16:18:24 +02:00
|
|
|
use crate::error::UserError;
|
2023-06-20 15:54:28 +02:00
|
|
|
use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors};
|
2023-06-08 11:35:36 +02:00
|
|
|
|
2023-06-20 11:17:20 +02:00
|
|
|
/// Extracts the embedding vector contained in each document under the `_vectors` field.
|
2023-06-08 11:35:36 +02:00
|
|
|
///
|
|
|
|
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
|
|
|
|
#[logging_timer::time]
|
|
|
|
pub fn extract_vector_points<R: io::Read + io::Seek>(
|
|
|
|
obkv_documents: grenad::Reader<R>,
|
|
|
|
indexer: GrenadParameters,
|
2023-06-20 16:18:24 +02:00
|
|
|
primary_key_id: FieldId,
|
2023-06-20 11:17:20 +02:00
|
|
|
vectors_fid: FieldId,
|
2023-06-08 11:35:36 +02:00
|
|
|
) -> Result<grenad::Reader<File>> {
|
|
|
|
let mut writer = create_writer(
|
|
|
|
indexer.chunk_compression_type,
|
|
|
|
indexer.chunk_compression_level,
|
|
|
|
tempfile::tempfile()?,
|
|
|
|
);
|
|
|
|
|
|
|
|
let mut cursor = obkv_documents.into_cursor()?;
|
|
|
|
while let Some((docid_bytes, value)) = cursor.move_on_next()? {
|
|
|
|
let obkv = obkv::KvReader::new(value);
|
|
|
|
|
2023-06-20 16:18:24 +02:00
|
|
|
// since we only needs the primary key when we throw an error we create this getter to
|
|
|
|
// lazily get it when needed
|
|
|
|
let document_id = || -> Value {
|
|
|
|
let document_id = obkv.get(primary_key_id).unwrap();
|
|
|
|
serde_json::from_slice(document_id).unwrap()
|
|
|
|
};
|
|
|
|
|
2023-06-20 11:17:20 +02:00
|
|
|
// first we retrieve the _vectors field
|
|
|
|
if let Some(vectors) = obkv.get(vectors_fid) {
|
|
|
|
// extract the vectors
|
2023-06-20 16:18:24 +02:00
|
|
|
let vectors = match from_slice(vectors) {
|
|
|
|
Ok(vectors) => VectorOrArrayOfVectors::into_array_of_vectors(vectors),
|
2023-06-20 16:26:00 +02:00
|
|
|
Err(_) => {
|
|
|
|
return Err(UserError::InvalidVectorsType {
|
|
|
|
document_id: document_id(),
|
|
|
|
value: from_slice(vectors).map_err(InternalError::SerdeJson)?,
|
|
|
|
}
|
|
|
|
.into())
|
|
|
|
}
|
2023-06-20 16:18:24 +02:00
|
|
|
};
|
2023-06-20 11:17:20 +02:00
|
|
|
|
|
|
|
for (i, vector) in vectors.into_iter().enumerate() {
|
|
|
|
match u16::try_from(i) {
|
|
|
|
Ok(i) => {
|
|
|
|
let mut key = docid_bytes.to_vec();
|
|
|
|
key.extend_from_slice(&i.to_ne_bytes());
|
|
|
|
let bytes = cast_slice(&vector);
|
|
|
|
writer.insert(key, bytes)?;
|
|
|
|
}
|
|
|
|
Err(_) => continue,
|
|
|
|
}
|
|
|
|
}
|
2023-06-08 11:35:36 +02:00
|
|
|
}
|
2023-06-20 11:17:20 +02:00
|
|
|
// else => the `_vectors` object was `null`, there is nothing to do
|
2023-06-08 11:35:36 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
writer_into_reader(writer)
|
|
|
|
}
|