MeiliSearch/milli/src/update/index_documents/extract/extract_vector_points.rs

428 lines
16 KiB
Rust

use std::cmp::Ordering;
use std::convert::{TryFrom, TryInto};
use std::fs::File;
use std::io::{self, BufReader, BufWriter};
use std::mem::size_of;
use std::str::from_utf8;
use std::sync::Arc;
use bytemuck::cast_slice;
use grenad::Writer;
use itertools::EitherOrBoth;
use ordered_float::OrderedFloat;
use serde_json::{from_slice, Value};
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::error::UserError;
use crate::prompt::Prompt;
use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
use crate::update::index_documents::helpers::try_split_at;
use crate::vector::Embedder;
use crate::{DocumentId, FieldsIdsMap, InternalError, Result, VectorOrArrayOfVectors};
/// The length of the elements that are always in the buffer when inserting new values.
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
pub struct ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
pub manual_vectors: grenad::Reader<BufReader<File>>,
// docid -> ()
pub remove_vectors: grenad::Reader<BufReader<File>>,
// docid -> prompt
pub prompts: grenad::Reader<BufReader<File>>,
}
enum VectorStateDelta {
NoChange,
// Remove all vectors, generated or manual, from this document
NowRemoved,
// Add the manually specified vectors, passed in the other grenad
// Remove any previously generated vectors
// Note: changing the value of the manually specified vector **should not record** this delta
WasGeneratedNowManual(Vec<Vec<f32>>),
ManualDelta(Vec<Vec<f32>>, Vec<Vec<f32>>),
// Add the vector computed from the specified prompt
// Remove any previous vector
// Note: changing the value of the prompt **does require** recording this delta
NowGenerated(String),
}
impl VectorStateDelta {
fn into_values(self) -> (bool, String, (Vec<Vec<f32>>, Vec<Vec<f32>>)) {
match self {
VectorStateDelta::NoChange => Default::default(),
VectorStateDelta::NowRemoved => (true, Default::default(), Default::default()),
VectorStateDelta::WasGeneratedNowManual(add) => {
(true, Default::default(), (Default::default(), add))
}
VectorStateDelta::ManualDelta(del, add) => (false, Default::default(), (del, add)),
VectorStateDelta::NowGenerated(prompt) => (true, prompt, Default::default()),
}
}
}
/// Extracts the embedding vector contained in each document under the `_vectors` field.
///
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
#[tracing::instrument(level = "trace", skip_all, target = "indexing::extract")]
pub fn extract_vector_points<R: io::Read + io::Seek>(
obkv_documents: grenad::Reader<R>,
indexer: GrenadParameters,
field_id_map: &FieldsIdsMap,
prompt: &Prompt,
embedder_name: &str,
) -> Result<ExtractedVectorPoints> {
puffin::profile_function!();
// (docid, _index) -> KvWriterDelAdd -> Vector
let mut manual_vectors_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
// (docid) -> (prompt)
let mut prompts_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
// (docid) -> ()
let mut remove_vectors_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
let vectors_fid = field_id_map.id("_vectors");
let mut key_buffer = Vec::new();
let mut cursor = obkv_documents.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
// this must always be serialized as (docid, external_docid);
let (docid_bytes, external_id_bytes) =
try_split_at(key, std::mem::size_of::<DocumentId>()).unwrap();
debug_assert!(from_utf8(external_id_bytes).is_ok());
let obkv = obkv::KvReader::new(value);
key_buffer.clear();
key_buffer.extend_from_slice(docid_bytes);
// since we only needs the primary key when we throw an error we create this getter to
// lazily get it when needed
let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() };
let vectors_field = vectors_fid
.and_then(|vectors_fid| obkv.get(vectors_fid))
.map(KvReaderDelAdd::new)
.map(|obkv| to_vector_maps(obkv, document_id))
.transpose()?;
let (del_map, add_map) = vectors_field.unzip();
let del_map = del_map.flatten();
let add_map = add_map.flatten();
let del_value = del_map.and_then(|mut map| map.remove(embedder_name));
let add_value = add_map.and_then(|mut map| map.remove(embedder_name));
let delta = match (del_value, add_value) {
(Some(old), Some(new)) => {
// no autogeneration
let del_vectors = extract_vectors(old, document_id, embedder_name)?;
let add_vectors = extract_vectors(new, document_id, embedder_name)?;
if add_vectors.len() > usize::from(u8::MAX) {
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
document_id().to_string(),
add_vectors.len(),
)));
}
VectorStateDelta::ManualDelta(del_vectors, add_vectors)
}
(Some(_old), None) => {
// Do we keep this document?
let document_is_kept = obkv
.iter()
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
// becomes autogenerated
VectorStateDelta::NowGenerated(prompt.render(
obkv,
DelAdd::Addition,
field_id_map,
)?)
} else {
VectorStateDelta::NowRemoved
}
}
(None, Some(new)) => {
// was possibly autogenerated, remove all vectors for that document
let add_vectors = extract_vectors(new, document_id, embedder_name)?;
if add_vectors.len() > usize::from(u8::MAX) {
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
document_id().to_string(),
add_vectors.len(),
)));
}
VectorStateDelta::WasGeneratedNowManual(add_vectors)
}
(None, None) => {
// Do we keep this document?
let document_is_kept = obkv
.iter()
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
// Don't give up if the old prompt was failing
let old_prompt =
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
if old_prompt != new_prompt {
tracing::trace!(
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
tracing::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
} else {
VectorStateDelta::NowRemoved
}
}
};
// and we finally push the unique vectors into the writer
push_vectors_diff(
&mut remove_vectors_writer,
&mut prompts_writer,
&mut manual_vectors_writer,
&mut key_buffer,
delta,
)?;
}
Ok(ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
manual_vectors: writer_into_reader(manual_vectors_writer)?,
// docid -> ()
remove_vectors: writer_into_reader(remove_vectors_writer)?,
// docid -> prompt
prompts: writer_into_reader(prompts_writer)?,
})
}
fn to_vector_maps(
obkv: KvReaderDelAdd,
document_id: impl Fn() -> Value,
) -> Result<(Option<serde_json::Map<String, Value>>, Option<serde_json::Map<String, Value>>)> {
let del = to_vector_map(obkv, DelAdd::Deletion, &document_id)?;
let add = to_vector_map(obkv, DelAdd::Addition, &document_id)?;
Ok((del, add))
}
fn to_vector_map(
obkv: KvReaderDelAdd,
side: DelAdd,
document_id: &impl Fn() -> Value,
) -> Result<Option<serde_json::Map<String, Value>>> {
Ok(if let Some(value) = obkv.get(side) {
let Ok(value) = from_slice(value) else {
let value = from_slice(value).map_err(InternalError::SerdeJson)?;
return Err(crate::Error::UserError(UserError::InvalidVectorsMapType {
document_id: document_id(),
value,
}));
};
Some(value)
} else {
None
})
}
/// Computes the diff between both Del and Add numbers and
/// only inserts the parts that differ in the sorter.
fn push_vectors_diff(
remove_vectors_writer: &mut Writer<BufWriter<File>>,
prompts_writer: &mut Writer<BufWriter<File>>,
manual_vectors_writer: &mut Writer<BufWriter<File>>,
key_buffer: &mut Vec<u8>,
delta: VectorStateDelta,
) -> Result<()> {
puffin::profile_function!();
let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values();
if must_remove {
key_buffer.truncate(TRUNCATE_SIZE);
remove_vectors_writer.insert(&key_buffer, [])?;
}
if !prompt.is_empty() {
key_buffer.truncate(TRUNCATE_SIZE);
prompts_writer.insert(&key_buffer, prompt.as_bytes())?;
}
// We sort and dedup the vectors
del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
del_vectors.dedup_by(|a, b| compare_vectors(a, b).is_eq());
add_vectors.dedup_by(|a, b| compare_vectors(a, b).is_eq());
let merged_vectors_iter =
itertools::merge_join_by(del_vectors, add_vectors, |del, add| compare_vectors(del, add));
// insert vectors into the writer
for (i, eob) in merged_vectors_iter.into_iter().enumerate().take(u16::MAX as usize) {
// Generate the key by extending the unique index to it.
key_buffer.truncate(TRUNCATE_SIZE);
let index = u16::try_from(i).unwrap();
key_buffer.extend_from_slice(&index.to_be_bytes());
match eob {
EitherOrBoth::Both(_, _) => (), // no need to touch anything
EitherOrBoth::Left(vector) => {
// We insert only the Del part of the Obkv to inform
// that we only want to remove all those vectors.
let mut obkv = KvWriterDelAdd::memory();
obkv.insert(DelAdd::Deletion, cast_slice(&vector))?;
let bytes = obkv.into_inner()?;
manual_vectors_writer.insert(&key_buffer, bytes)?;
}
EitherOrBoth::Right(vector) => {
// We insert only the Add part of the Obkv to inform
// that we only want to remove all those vectors.
let mut obkv = KvWriterDelAdd::memory();
obkv.insert(DelAdd::Addition, cast_slice(&vector))?;
let bytes = obkv.into_inner()?;
manual_vectors_writer.insert(&key_buffer, bytes)?;
}
}
}
Ok(())
}
/// Compares two vectors by using the OrderingFloat helper.
fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering {
a.iter().copied().map(OrderedFloat).cmp(b.iter().copied().map(OrderedFloat))
}
/// Extracts the vectors from a JSON value.
fn extract_vectors(
value: Value,
document_id: impl Fn() -> Value,
name: &str,
) -> Result<Vec<Vec<f32>>> {
// FIXME: ugly clone of the vectors here
match serde_json::from_value(value.clone()) {
Ok(vectors) => {
Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default())
}
Err(_) => Err(UserError::InvalidVectorsType {
document_id: document_id(),
value,
subfield: name.to_owned(),
}
.into()),
}
}
#[tracing::instrument(level = "trace", skip_all, target = "indexing::extract")]
pub fn extract_embeddings<R: io::Read + io::Seek>(
// docid, prompt
prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters,
embedder: Arc<Embedder>,
request_threads: &rayon::ThreadPool,
) -> Result<grenad::Reader<BufReader<File>>> {
puffin::profile_function!();
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk
// docid, state with embedding
let mut state_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
let mut chunks = Vec::with_capacity(n_chunks);
let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk);
let mut current_chunk_ids = Vec::with_capacity(n_vectors_per_chunk);
let mut chunks_ids = Vec::with_capacity(n_chunks);
let mut cursor = prompt_reader.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
// SAFETY: precondition, the grenad value was saved from a string
let prompt = unsafe { std::str::from_utf8_unchecked(value) };
if current_chunk.len() == current_chunk.capacity() {
chunks.push(std::mem::replace(
&mut current_chunk,
Vec::with_capacity(n_vectors_per_chunk),
));
chunks_ids.push(std::mem::replace(
&mut current_chunk_ids,
Vec::with_capacity(n_vectors_per_chunk),
));
};
current_chunk.push(prompt.to_owned());
current_chunk_ids.push(docid);
if chunks.len() == chunks.capacity() {
let chunked_embeds = embedder
.embed_chunks(
std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)),
request_threads,
)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
.iter()
.flat_map(|docids| docids.iter())
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
chunks_ids.clear();
}
}
// send last chunk
if !chunks.is_empty() {
let chunked_embeds = embedder
.embed_chunks(std::mem::take(&mut chunks), request_threads)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
.iter()
.flat_map(|docids| docids.iter())
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
}
if !current_chunk.is_empty() {
let embeds = embedder
.embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
if let Some(embeds) = embeds.first() {
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
}
}
writer_into_reader(state_writer)
}