mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-12-23 21:20:24 +01:00
Fix the vector extractions for the diff indexing
This commit is contained in:
parent
1c39459cf4
commit
ff522c919d
@ -1,15 +1,25 @@
|
|||||||
|
use std::cmp::Ordering;
|
||||||
use std::convert::TryFrom;
|
use std::convert::TryFrom;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{self, BufReader};
|
use std::io::{self, BufReader, BufWriter};
|
||||||
|
use std::mem::size_of;
|
||||||
|
use std::str::from_utf8;
|
||||||
|
|
||||||
use bytemuck::cast_slice;
|
use bytemuck::cast_slice;
|
||||||
|
use grenad::Writer;
|
||||||
|
use itertools::EitherOrBoth;
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
use serde_json::{from_slice, Value};
|
use serde_json::{from_slice, Value};
|
||||||
|
|
||||||
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
|
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
|
||||||
use crate::error::UserError;
|
use crate::error::UserError;
|
||||||
|
use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
|
||||||
use crate::update::index_documents::helpers::try_split_at;
|
use crate::update::index_documents::helpers::try_split_at;
|
||||||
use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors};
|
use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors};
|
||||||
|
|
||||||
|
/// The length of the elements that are always in the buffer when inserting new values.
|
||||||
|
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
|
||||||
|
|
||||||
/// Extracts the embedding vector contained in each document under the `_vectors` field.
|
/// Extracts the embedding vector contained in each document under the `_vectors` field.
|
||||||
///
|
///
|
||||||
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
|
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
|
||||||
@ -27,45 +37,112 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
|||||||
tempfile::tempfile()?,
|
tempfile::tempfile()?,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let mut key_buffer = Vec::new();
|
||||||
let mut cursor = obkv_documents.into_cursor()?;
|
let mut cursor = obkv_documents.into_cursor()?;
|
||||||
while let Some((key, value)) = cursor.move_on_next()? {
|
while let Some((key, value)) = cursor.move_on_next()? {
|
||||||
// this must always be serialized as (docid, external_docid);
|
// this must always be serialized as (docid, external_docid);
|
||||||
let (docid_bytes, external_id_bytes) =
|
let (docid_bytes, external_id_bytes) =
|
||||||
try_split_at(key, std::mem::size_of::<DocumentId>()).unwrap();
|
try_split_at(key, std::mem::size_of::<DocumentId>()).unwrap();
|
||||||
debug_assert!(std::str::from_utf8(external_id_bytes).is_ok());
|
debug_assert!(from_utf8(external_id_bytes).is_ok());
|
||||||
|
|
||||||
let obkv = obkv::KvReader::new(value);
|
let obkv = obkv::KvReader::new(value);
|
||||||
|
key_buffer.clear();
|
||||||
|
key_buffer.extend_from_slice(docid_bytes);
|
||||||
|
|
||||||
// since we only needs the primary key when we throw an error we create this getter to
|
// since we only needs the primary key when we throw an error we create this getter to
|
||||||
// lazily get it when needed
|
// lazily get it when needed
|
||||||
let document_id = || -> Value { std::str::from_utf8(external_id_bytes).unwrap().into() };
|
let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() };
|
||||||
|
|
||||||
// first we retrieve the _vectors field
|
// first we retrieve the _vectors field
|
||||||
if let Some(vectors) = obkv.get(vectors_fid) {
|
if let Some(value) = obkv.get(vectors_fid) {
|
||||||
// extract the vectors
|
let vectors_obkv = KvReaderDelAdd::new(value);
|
||||||
let vectors = match from_slice(vectors) {
|
|
||||||
Ok(vectors) => VectorOrArrayOfVectors::into_array_of_vectors(vectors),
|
|
||||||
Err(_) => {
|
|
||||||
return Err(UserError::InvalidVectorsType {
|
|
||||||
document_id: document_id(),
|
|
||||||
value: from_slice(vectors).map_err(InternalError::SerdeJson)?,
|
|
||||||
}
|
|
||||||
.into())
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
if let Some(vectors) = vectors {
|
// then we extract the values
|
||||||
for (i, vector) in vectors.into_iter().enumerate().take(u16::MAX as usize) {
|
let del_vectors = vectors_obkv
|
||||||
let index = u16::try_from(i).unwrap();
|
.get(DelAdd::Deletion)
|
||||||
let mut key = docid_bytes.to_vec();
|
.map(|vectors| extract_vectors(vectors, document_id))
|
||||||
key.extend_from_slice(&index.to_be_bytes());
|
.transpose()?
|
||||||
let bytes = cast_slice(&vector);
|
.flatten();
|
||||||
writer.insert(key, bytes)?;
|
let add_vectors = vectors_obkv
|
||||||
|
.get(DelAdd::Addition)
|
||||||
|
.map(|vectors| extract_vectors(vectors, document_id))
|
||||||
|
.transpose()?
|
||||||
|
.flatten();
|
||||||
|
|
||||||
|
// and we finally push the unique vectors into the writer
|
||||||
|
push_vectors_diff(
|
||||||
|
&mut writer,
|
||||||
|
&mut key_buffer,
|
||||||
|
del_vectors.unwrap_or_default(),
|
||||||
|
add_vectors.unwrap_or_default(),
|
||||||
|
)?;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
// else => the `_vectors` object was `null`, there is nothing to do
|
|
||||||
}
|
|
||||||
|
|
||||||
writer_into_reader(writer)
|
writer_into_reader(writer)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Computes the diff between both Del and Add numbers and
|
||||||
|
/// only inserts the parts that differ in the sorter.
|
||||||
|
fn push_vectors_diff(
|
||||||
|
writer: &mut Writer<BufWriter<File>>,
|
||||||
|
key_buffer: &mut Vec<u8>,
|
||||||
|
mut del_vectors: Vec<Vec<f32>>,
|
||||||
|
mut add_vectors: Vec<Vec<f32>>,
|
||||||
|
) -> Result<()> {
|
||||||
|
// We sort and dedup the vectors
|
||||||
|
del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
|
||||||
|
add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
|
||||||
|
del_vectors.dedup_by(|a, b| compare_vectors(a, b).is_eq());
|
||||||
|
add_vectors.dedup_by(|a, b| compare_vectors(a, b).is_eq());
|
||||||
|
|
||||||
|
let merged_vectors_iter =
|
||||||
|
itertools::merge_join_by(del_vectors, add_vectors, |del, add| compare_vectors(del, add));
|
||||||
|
|
||||||
|
// insert vectors into the writer
|
||||||
|
for (i, eob) in merged_vectors_iter.into_iter().enumerate().take(u16::MAX as usize) {
|
||||||
|
// Generate the key by extending the unique index to it.
|
||||||
|
key_buffer.truncate(TRUNCATE_SIZE);
|
||||||
|
let index = u16::try_from(i).unwrap();
|
||||||
|
key_buffer.extend_from_slice(&index.to_be_bytes());
|
||||||
|
|
||||||
|
match eob {
|
||||||
|
EitherOrBoth::Both(_, _) => (), // no need to touch anything
|
||||||
|
EitherOrBoth::Left(vector) => {
|
||||||
|
// We insert only the Del part of the Obkv to inform
|
||||||
|
// that we only want to remove all those vectors.
|
||||||
|
let mut obkv = KvWriterDelAdd::memory();
|
||||||
|
obkv.insert(DelAdd::Deletion, cast_slice(&vector))?;
|
||||||
|
let bytes = obkv.into_inner()?;
|
||||||
|
writer.insert(&key_buffer, bytes)?;
|
||||||
|
}
|
||||||
|
EitherOrBoth::Right(vector) => {
|
||||||
|
// We insert only the Add part of the Obkv to inform
|
||||||
|
// that we only want to remove all those vectors.
|
||||||
|
let mut obkv = KvWriterDelAdd::memory();
|
||||||
|
obkv.insert(DelAdd::Addition, cast_slice(&vector))?;
|
||||||
|
let bytes = obkv.into_inner()?;
|
||||||
|
writer.insert(&key_buffer, bytes)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Compares two vectors by using the OrderingFloat helper.
|
||||||
|
fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering {
|
||||||
|
a.iter().copied().map(OrderedFloat).cmp(b.iter().copied().map(OrderedFloat))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extracts the vectors from a JSON value.
|
||||||
|
fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Option<Vec<Vec<f32>>>> {
|
||||||
|
match from_slice(value) {
|
||||||
|
Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)),
|
||||||
|
Err(_) => Err(UserError::InvalidVectorsType {
|
||||||
|
document_id: document_id(),
|
||||||
|
value: from_slice(value).map_err(InternalError::SerdeJson)?,
|
||||||
|
}
|
||||||
|
.into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
use std::collections::HashMap;
|
use std::collections::{HashMap, HashSet};
|
||||||
use std::convert::TryInto;
|
use std::convert::TryInto;
|
||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io::{self, BufReader};
|
use std::io::{self, BufReader};
|
||||||
@ -8,7 +8,9 @@ use charabia::{Language, Script};
|
|||||||
use grenad::MergerBuilder;
|
use grenad::MergerBuilder;
|
||||||
use heed::types::ByteSlice;
|
use heed::types::ByteSlice;
|
||||||
use heed::RwTxn;
|
use heed::RwTxn;
|
||||||
|
use log::error;
|
||||||
use obkv::{KvReader, KvWriter};
|
use obkv::{KvReader, KvWriter};
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
use roaring::RoaringBitmap;
|
use roaring::RoaringBitmap;
|
||||||
|
|
||||||
use super::helpers::{self, merge_ignore_values, valid_lmdb_key, CursorClonableMmap};
|
use super::helpers::{self, merge_ignore_values, valid_lmdb_key, CursorClonableMmap};
|
||||||
@ -22,10 +24,9 @@ use crate::index::Hnsw;
|
|||||||
use crate::update::del_add::{DelAdd, KvReaderDelAdd};
|
use crate::update::del_add::{DelAdd, KvReaderDelAdd};
|
||||||
use crate::update::facet::FacetsUpdate;
|
use crate::update::facet::FacetsUpdate;
|
||||||
use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
|
use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
|
||||||
use crate::update::index_documents::validate_document_id_value;
|
|
||||||
use crate::{
|
use crate::{
|
||||||
lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, FieldId, GeoPoint, Index, InternalError,
|
lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, FieldId, GeoPoint, Index, Result,
|
||||||
Result, SerializationError, BEU32,
|
SerializationError, BEU32,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) enum TypedChunk {
|
pub(crate) enum TypedChunk {
|
||||||
@ -366,44 +367,70 @@ pub(crate) fn write_typed_chunk_into_index(
|
|||||||
index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
|
index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
|
||||||
}
|
}
|
||||||
TypedChunk::VectorPoints(vector_points) => {
|
TypedChunk::VectorPoints(vector_points) => {
|
||||||
let (pids, mut points): (Vec<_>, Vec<_>) = match index.vector_hnsw(wtxn)? {
|
let mut vectors_set = HashSet::new();
|
||||||
Some(hnsw) => hnsw.iter().map(|(pid, point)| (pid, point.clone())).unzip(),
|
// We extract and store the previous vectors
|
||||||
None => Default::default(),
|
if let Some(hnsw) = index.vector_hnsw(wtxn)? {
|
||||||
};
|
for (pid, point) in hnsw.iter() {
|
||||||
|
let pid_key = BEU32::new(pid.into_inner());
|
||||||
// Convert the PointIds into DocumentIds
|
let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap().get();
|
||||||
let mut docids = Vec::new();
|
let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect();
|
||||||
for pid in pids {
|
vectors_set.insert((docid, vector));
|
||||||
let docid =
|
}
|
||||||
index.vector_id_docid.get(wtxn, &BEU32::new(pid.into_inner()))?.unwrap();
|
|
||||||
docids.push(docid.get());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut expected_dimensions = points.get(0).map(|p| p.len());
|
|
||||||
let mut cursor = vector_points.into_cursor()?;
|
let mut cursor = vector_points.into_cursor()?;
|
||||||
while let Some((key, value)) = cursor.move_on_next()? {
|
while let Some((key, value)) = cursor.move_on_next()? {
|
||||||
// convert the key back to a u32 (4 bytes)
|
// convert the key back to a u32 (4 bytes)
|
||||||
let (left, _index) = try_split_array_at(key).unwrap();
|
let (left, _index) = try_split_array_at(key).unwrap();
|
||||||
let docid = DocumentId::from_be_bytes(left);
|
let docid = DocumentId::from_be_bytes(left);
|
||||||
// convert the vector back to a Vec<f32>
|
|
||||||
let vector: Vec<f32> = pod_collect_to_vec(value);
|
|
||||||
|
|
||||||
// TODO Inform the user about the document that has a wrong `_vectors`
|
let vector_deladd_obkv = KvReaderDelAdd::new(value);
|
||||||
let found = vector.len();
|
if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) {
|
||||||
let expected = *expected_dimensions.get_or_insert(found);
|
// convert the vector back to a Vec<f32>
|
||||||
if expected != found {
|
let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
|
||||||
return Err(UserError::InvalidVectorDimensions { expected, found }.into());
|
let key = (docid, vector);
|
||||||
|
if !vectors_set.remove(&key) {
|
||||||
|
error!("Unable to delete the vector: {:?}", key.1);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) {
|
||||||
|
// convert the vector back to a Vec<f32>
|
||||||
|
let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
|
||||||
|
vectors_set.insert((docid, vector));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract the most common vector dimension
|
||||||
|
let expected_dimension_size = {
|
||||||
|
let mut dims = HashMap::new();
|
||||||
|
vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1);
|
||||||
|
dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Ensure that the vector lenghts are correct and
|
||||||
|
// prepare the vectors before inserting them in the HNSW.
|
||||||
|
let mut points = Vec::new();
|
||||||
|
let mut docids = Vec::new();
|
||||||
|
for (docid, vector) in vectors_set {
|
||||||
|
if expected_dimension_size.map_or(false, |expected| expected != vector.len()) {
|
||||||
|
return Err(UserError::InvalidVectorDimensions {
|
||||||
|
expected: expected_dimension_size.unwrap_or(vector.len()),
|
||||||
|
found: vector.len(),
|
||||||
|
}
|
||||||
|
.into());
|
||||||
|
} else {
|
||||||
|
let vector = vector.into_iter().map(OrderedFloat::into_inner).collect();
|
||||||
points.push(NDotProductPoint::new(vector));
|
points.push(NDotProductPoint::new(vector));
|
||||||
docids.push(docid);
|
docids.push(docid);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
assert_eq!(docids.len(), points.len());
|
|
||||||
|
|
||||||
let hnsw_length = points.len();
|
let hnsw_length = points.len();
|
||||||
let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points);
|
let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points);
|
||||||
|
|
||||||
|
assert_eq!(docids.len(), pids.len());
|
||||||
|
|
||||||
|
// Store the vectors in the point-docid relation database
|
||||||
index.vector_id_docid.clear(wtxn)?;
|
index.vector_id_docid.clear(wtxn)?;
|
||||||
for (docid, pid) in docids.into_iter().zip(pids) {
|
for (docid, pid) in docids.into_iter().zip(pids) {
|
||||||
index.vector_id_docid.put(
|
index.vector_id_docid.put(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user