2023-06-20 11:17:20 +02:00
|
|
|
use std::convert::TryFrom;
|
2023-06-08 11:35:36 +02:00
|
|
|
use std::fs::File;
|
2023-09-28 16:26:01 +02:00
|
|
|
use std::io::{self, BufReader};
|
2023-06-08 11:35:36 +02:00
|
|
|
|
|
|
|
use bytemuck::cast_slice;
|
2023-06-20 16:18:24 +02:00
|
|
|
use serde_json::{from_slice, Value};
|
2023-06-08 11:35:36 +02:00
|
|
|
|
|
|
|
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
|
2023-06-20 16:18:24 +02:00
|
|
|
use crate::error::UserError;
|
2023-10-31 17:44:42 +01:00
|
|
|
use crate::update::index_documents::helpers::try_split_at;
|
|
|
|
use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors};
|
2023-06-08 11:35:36 +02:00
|
|
|
|
2023-06-20 11:17:20 +02:00
|
|
|
/// Extracts the embedding vector contained in each document under the `_vectors` field.
|
2023-06-08 11:35:36 +02:00
|
|
|
///
|
|
|
|
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
|
|
|
|
#[logging_timer::time]
|
|
|
|
pub fn extract_vector_points<R: io::Read + io::Seek>(
|
|
|
|
obkv_documents: grenad::Reader<R>,
|
|
|
|
indexer: GrenadParameters,
|
2023-06-20 11:17:20 +02:00
|
|
|
vectors_fid: FieldId,
|
2023-09-28 16:26:01 +02:00
|
|
|
) -> Result<grenad::Reader<BufReader<File>>> {
|
2023-07-10 18:41:54 +02:00
|
|
|
puffin::profile_function!();
|
|
|
|
|
2023-06-08 11:35:36 +02:00
|
|
|
let mut writer = create_writer(
|
|
|
|
indexer.chunk_compression_type,
|
|
|
|
indexer.chunk_compression_level,
|
|
|
|
tempfile::tempfile()?,
|
|
|
|
);
|
|
|
|
|
|
|
|
let mut cursor = obkv_documents.into_cursor()?;
|
2023-10-31 17:44:42 +01:00
|
|
|
while let Some((key, value)) = cursor.move_on_next()? {
|
|
|
|
// this must always be serialized as (docid, external_docid);
|
|
|
|
let (docid_bytes, external_id_bytes) =
|
|
|
|
try_split_at(key, std::mem::size_of::<DocumentId>()).unwrap();
|
|
|
|
debug_assert!(std::str::from_utf8(external_id_bytes).is_ok());
|
|
|
|
|
2023-06-08 11:35:36 +02:00
|
|
|
let obkv = obkv::KvReader::new(value);
|
|
|
|
|
2023-06-20 16:18:24 +02:00
|
|
|
// since we only needs the primary key when we throw an error we create this getter to
|
|
|
|
// lazily get it when needed
|
2023-10-31 17:44:42 +01:00
|
|
|
let document_id = || -> Value { std::str::from_utf8(external_id_bytes).unwrap().into() };
|
2023-06-20 16:18:24 +02:00
|
|
|
|
2023-06-20 11:17:20 +02:00
|
|
|
// first we retrieve the _vectors field
|
|
|
|
if let Some(vectors) = obkv.get(vectors_fid) {
|
|
|
|
// extract the vectors
|
2023-06-20 16:18:24 +02:00
|
|
|
let vectors = match from_slice(vectors) {
|
|
|
|
Ok(vectors) => VectorOrArrayOfVectors::into_array_of_vectors(vectors),
|
2023-06-20 16:26:00 +02:00
|
|
|
Err(_) => {
|
|
|
|
return Err(UserError::InvalidVectorsType {
|
|
|
|
document_id: document_id(),
|
|
|
|
value: from_slice(vectors).map_err(InternalError::SerdeJson)?,
|
|
|
|
}
|
|
|
|
.into())
|
|
|
|
}
|
2023-06-20 16:18:24 +02:00
|
|
|
};
|
2023-06-20 11:17:20 +02:00
|
|
|
|
2023-08-14 16:03:55 +02:00
|
|
|
if let Some(vectors) = vectors {
|
|
|
|
for (i, vector) in vectors.into_iter().enumerate().take(u16::MAX as usize) {
|
|
|
|
let index = u16::try_from(i).unwrap();
|
|
|
|
let mut key = docid_bytes.to_vec();
|
|
|
|
key.extend_from_slice(&index.to_be_bytes());
|
|
|
|
let bytes = cast_slice(&vector);
|
|
|
|
writer.insert(key, bytes)?;
|
|
|
|
}
|
2023-06-20 11:17:20 +02:00
|
|
|
}
|
2023-06-08 11:35:36 +02:00
|
|
|
}
|
2023-06-20 11:17:20 +02:00
|
|
|
// else => the `_vectors` object was `null`, there is nothing to do
|
2023-06-08 11:35:36 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
writer_into_reader(writer)
|
|
|
|
}
|