Remove the vectors from the documents database

This commit is contained in:
Tamo 2024-05-22 15:27:09 +02:00
parent 7a84697570
commit 84e498299b
14 changed files with 407 additions and 51 deletions

View file

@ -10,16 +10,16 @@ use bytemuck::cast_slice;
use grenad::Writer;
use itertools::EitherOrBoth;
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use serde_json::Value;
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::prompt::Prompt;
use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
use crate::update::index_documents::helpers::try_split_at;
use crate::update::settings::InnerIndexSettingsDiff;
use crate::vector::parsed_vectors::{ParsedVectorsDiff, RESERVED_VECTORS_FIELD_NAME};
use crate::vector::Embedder;
use crate::{DocumentId, Result, ThreadPoolNoAbort};
use crate::{try_split_array_at, DocumentId, Result, ThreadPoolNoAbort};
/// The length of the elements that are always in the buffer when inserting new values.
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
@ -35,6 +35,8 @@ pub struct ExtractedVectorPoints {
// embedder
pub embedder_name: String,
pub embedder: Arc<Embedder>,
pub user_defined: RoaringBitmap,
pub remove_from_user_defined: RoaringBitmap,
}
enum VectorStateDelta {
@ -80,6 +82,11 @@ struct EmbedderVectorExtractor {
prompts_writer: Writer<BufWriter<File>>,
// (docid) -> ()
remove_vectors_writer: Writer<BufWriter<File>>,
// The docids of the documents that contain a user defined embedding
user_defined: RoaringBitmap,
// The docids of the documents that contain an auto-generated embedding
remove_from_user_defined: RoaringBitmap,
}
/// Extracts the embedding vector contained in each document under the `_vectors` field.
@ -134,6 +141,8 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
manual_vectors_writer,
prompts_writer,
remove_vectors_writer,
user_defined: RoaringBitmap::new(),
remove_from_user_defined: RoaringBitmap::new(),
});
}
@ -141,13 +150,15 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
let mut cursor = obkv_documents.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
// this must always be serialized as (docid, external_docid);
const SIZE_OF_DOCUMENTID: usize = std::mem::size_of::<DocumentId>();
let (docid_bytes, external_id_bytes) =
try_split_at(key, std::mem::size_of::<DocumentId>()).unwrap();
try_split_array_at::<u8, SIZE_OF_DOCUMENTID>(key).unwrap();
debug_assert!(from_utf8(external_id_bytes).is_ok());
let docid = DocumentId::from_be_bytes(docid_bytes);
let obkv = obkv::KvReader::new(value);
key_buffer.clear();
key_buffer.extend_from_slice(docid_bytes);
key_buffer.extend_from_slice(docid_bytes.as_slice());
// since we only need the primary key when we throw an error we create this getter to
// lazily get it when needed
@ -163,10 +174,22 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
manual_vectors_writer,
prompts_writer,
remove_vectors_writer,
user_defined,
remove_from_user_defined,
} in extractors.iter_mut()
{
let delta = match parsed_vectors.remove(embedder_name) {
(Some(old), Some(new)) => {
match (old.is_user_provided(), new.is_user_provided()) {
(true, true) | (false, false) => (),
(true, false) => {
remove_from_user_defined.insert(docid);
}
(false, true) => {
user_defined.insert(docid);
}
}
// no autogeneration
let del_vectors = old.into_array_of_vectors();
let add_vectors = new.into_array_of_vectors();
@ -187,6 +210,7 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
remove_from_user_defined.insert(docid);
// becomes autogenerated
VectorStateDelta::NowGenerated(prompt.render(
obkv,
@ -198,6 +222,11 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
}
}
(None, Some(new)) => {
if new.is_user_provided() {
user_defined.insert(docid);
} else {
remove_from_user_defined.insert(docid);
}
// was possibly autogenerated, remove all vectors for that document
let add_vectors = new.into_array_of_vectors();
if add_vectors.len() > usize::from(u8::MAX) {
@ -239,6 +268,7 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
VectorStateDelta::NoChange
}
} else {
remove_from_user_defined.remove(docid);
VectorStateDelta::NowRemoved
}
}
@ -265,18 +295,18 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
manual_vectors_writer,
prompts_writer,
remove_vectors_writer,
user_defined,
remove_from_user_defined,
} in extractors
{
results.push(ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
manual_vectors: writer_into_reader(manual_vectors_writer)?,
// docid -> ()
remove_vectors: writer_into_reader(remove_vectors_writer)?,
// docid -> prompt
prompts: writer_into_reader(prompts_writer)?,
embedder,
embedder_name,
user_defined,
remove_from_user_defined,
})
}

View file

@ -238,6 +238,8 @@ fn send_original_documents_data(
prompts,
embedder_name,
embedder,
user_defined,
remove_from_user_defined: auto_generated,
} in extracted_vectors
{
let embeddings = match extract_embeddings(
@ -262,6 +264,8 @@ fn send_original_documents_data(
expected_dimension: embedder.dimensions(),
manual_vectors,
embedder_name,
user_defined,
remove_from_user_defined: auto_generated,
}));
}
}

View file

@ -501,6 +501,8 @@ where
embeddings,
manual_vectors,
embedder_name,
user_defined,
remove_from_user_defined,
} => {
dimension.insert(embedder_name.clone(), expected_dimension);
TypedChunk::VectorPoints {
@ -509,6 +511,8 @@ where
expected_dimension,
manual_vectors,
embedder_name,
user_defined,
remove_from_user_defined,
}
}
otherwise => otherwise,
@ -2616,10 +2620,11 @@ mod tests {
let rtxn = index.read_txn().unwrap();
let mut embedding_configs = index.embedding_configs(&rtxn).unwrap();
let (embedder_name, embedder) = embedding_configs.pop().unwrap();
let (embedder_name, embedder, user_defined) = embedding_configs.pop().unwrap();
insta::assert_snapshot!(embedder_name, @"manual");
insta::assert_debug_snapshot!(user_defined, @"RoaringBitmap<[0, 1, 2]>");
let embedder =
std::sync::Arc::new(crate::vector::Embedder::new(embedder.embedder_options).unwrap());
assert_eq!("manual", embedder_name);
let res = index
.search(&rtxn)
.semantic(embedder_name, embedder, Some([0.0, 1.0, 2.0].to_vec()))

View file

@ -90,6 +90,8 @@ pub(crate) enum TypedChunk {
expected_dimension: usize,
manual_vectors: grenad::Reader<BufReader<File>>,
embedder_name: String,
user_defined: RoaringBitmap,
remove_from_user_defined: RoaringBitmap,
},
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
}
@ -155,7 +157,7 @@ pub(crate) fn write_typed_chunk_into_index(
let mut iter = merger.into_stream_merger_iter()?;
let embedders: BTreeSet<_> =
index.embedding_configs(wtxn)?.into_iter().map(|(k, _v)| k).collect();
index.embedding_configs(wtxn)?.into_iter().map(|(name, _, _)| name).collect();
let mut vectors_buffer = Vec::new();
while let Some((key, reader)) = iter.next()? {
let mut writer: KvWriter<_, FieldId> = KvWriter::memory();
@ -181,7 +183,7 @@ pub(crate) fn write_typed_chunk_into_index(
// if the `_vectors` field cannot be parsed as map of vectors, just write it as-is
break 'vectors Some(addition);
};
vectors.retain_user_provided_vectors(&embedders);
vectors.retain_not_embedded_vectors(&embedders);
let crate::vector::parsed_vectors::ParsedVectors(vectors) = vectors;
if vectors.is_empty() {
// skip writing empty `_vectors` map
@ -619,6 +621,8 @@ pub(crate) fn write_typed_chunk_into_index(
let mut remove_vectors_builder = MergerBuilder::new(keep_first as MergeFn);
let mut manual_vectors_builder = MergerBuilder::new(keep_first as MergeFn);
let mut embeddings_builder = MergerBuilder::new(keep_first as MergeFn);
let mut user_defined = RoaringBitmap::new();
let mut remove_from_user_defined = RoaringBitmap::new();
let mut params = None;
for typed_chunk in typed_chunks {
let TypedChunk::VectorPoints {
@ -627,6 +631,8 @@ pub(crate) fn write_typed_chunk_into_index(
embeddings,
expected_dimension,
embedder_name,
user_defined: ud,
remove_from_user_defined: rud,
} = typed_chunk
else {
unreachable!();
@ -639,11 +645,21 @@ pub(crate) fn write_typed_chunk_into_index(
if let Some(embeddings) = embeddings {
embeddings_builder.push(embeddings.into_cursor()?);
}
user_defined |= ud;
remove_from_user_defined |= rud;
}
// typed chunks always contains at least 1 chunk.
let Some((expected_dimension, embedder_name)) = params else { unreachable!() };
let mut embedding_configs = index.embedding_configs(&wtxn)?;
let (_name, _conf, ud) =
embedding_configs.iter_mut().find(|config| config.0 == embedder_name).unwrap();
*ud -= remove_from_user_defined;
*ud |= user_defined;
index.put_embedding_configs(wtxn, embedding_configs)?;
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
)?;

View file

@ -6,6 +6,7 @@ use std::sync::Arc;
use charabia::{Normalize, Tokenizer, TokenizerBuilder};
use deserr::{DeserializeError, Deserr};
use itertools::{EitherOrBoth, Itertools};
use roaring::RoaringBitmap;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use time::OffsetDateTime;
@ -926,8 +927,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
Setting::Set(configs) => {
let mut changed = false;
let old_configs = self.index.embedding_configs(self.wtxn)?;
let old_configs: BTreeMap<String, Setting<EmbeddingSettings>> =
old_configs.into_iter().map(|(k, v)| (k, Setting::Set(v.into()))).collect();
let old_configs: BTreeMap<String, (Setting<EmbeddingSettings>, RoaringBitmap)> =
old_configs
.into_iter()
.map(|(name, setting, user_defined)| {
(name, (Setting::Set(setting.into()), user_defined))
})
.collect();
let mut new_configs = BTreeMap::new();
for joined in old_configs
@ -936,15 +942,19 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
{
match joined {
// updated config
EitherOrBoth::Both((name, mut old), (_, new)) => {
EitherOrBoth::Both((name, (mut old, user_defined)), (_, new)) => {
changed |= EmbeddingSettings::apply_and_need_reindex(&mut old, new);
if changed {
tracing::debug!(embedder = name, "need reindex");
tracing::debug!(
embedder = name,
documents = user_defined.len(),
"need reindex"
);
} else {
tracing::debug!(embedder = name, "skip reindex");
}
let new = validate_embedding_settings(old, &name)?;
new_configs.insert(name, new);
new_configs.insert(name, (new, user_defined));
}
// unchanged config
EitherOrBoth::Left((name, setting)) => {
@ -961,21 +971,23 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
);
let setting = validate_embedding_settings(setting, &name)?;
changed = true;
new_configs.insert(name, setting);
new_configs.insert(name, (setting, RoaringBitmap::new()));
}
}
}
let new_configs: Vec<(String, EmbeddingConfig)> = new_configs
let new_configs: Vec<(String, EmbeddingConfig, RoaringBitmap)> = new_configs
.into_iter()
.filter_map(|(name, setting)| match setting {
Setting::Set(value) => Some((name, value.into())),
.filter_map(|(name, (setting, user_defined))| match setting {
Setting::Set(settings) => Some((name, settings.into(), user_defined)),
Setting::Reset => None,
Setting::NotSet => Some((name, EmbeddingSettings::default().into())),
Setting::NotSet => {
Some((name, EmbeddingSettings::default().into(), user_defined))
}
})
.collect();
self.index.embedder_category_id.clear(self.wtxn)?;
for (index, (embedder_name, _)) in new_configs.iter().enumerate() {
for (index, (embedder_name, _, _)) in new_configs.iter().enumerate() {
self.index.embedder_category_id.put_with_flags(
self.wtxn,
heed::PutFlags::APPEND,
@ -1359,10 +1371,12 @@ impl InnerIndexSettings {
}
}
fn embedders(embedding_configs: Vec<(String, EmbeddingConfig)>) -> Result<EmbeddingConfigs> {
fn embedders(
embedding_configs: Vec<(String, EmbeddingConfig, RoaringBitmap)>,
) -> Result<EmbeddingConfigs> {
let res: Result<_> = embedding_configs
.into_iter()
.map(|(name, EmbeddingConfig { embedder_options, prompt })| {
.map(|(name, EmbeddingConfig { embedder_options, prompt }, _)| {
let prompt = Arc::new(prompt.try_into().map_err(crate::Error::from)?);
let embedder = Arc::new(