Merge branch 'main' into change-proximity-precision-settings

Commit 9e1b458010, authored by Many the fish on 2023-12-18 09:08:47 +01:00 and committed via GitHub.
55 changed files with 5801 additions and 723 deletions


@ -42,7 +42,8 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
facet_id_is_empty_docids,
field_id_docid_facet_f64s,
field_id_docid_facet_strings,
vector_id_docid,
vector_arroy,
embedder_category_id: _,
documents,
} = self.index;
@ -58,7 +59,6 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?;
self.index.delete_geo_rtree(self.wtxn)?;
self.index.delete_geo_faceted_documents_ids(self.wtxn)?;
self.index.delete_vector_hnsw(self.wtxn)?;
// Clear the other databases.
external_documents_ids.clear(self.wtxn)?;
@ -82,7 +82,9 @@ impl<'t, 'i> ClearDocuments<'t, 'i> {
facet_id_string_docids.clear(self.wtxn)?;
field_id_docid_facet_f64s.clear(self.wtxn)?;
field_id_docid_facet_strings.clear(self.wtxn)?;
vector_id_docid.clear(self.wtxn)?;
// vector
vector_arroy.clear(self.wtxn)?;
documents.clear(self.wtxn)?;
Ok(number_of_documents)


@ -1,9 +1,10 @@
use std::cmp::Ordering;
use std::convert::TryFrom;
use std::convert::{TryFrom, TryInto};
use std::fs::File;
use std::io::{self, BufReader, BufWriter};
use std::mem::size_of;
use std::str::from_utf8;
use std::sync::Arc;
use bytemuck::cast_slice;
use grenad::Writer;
@ -13,13 +14,56 @@ use serde_json::{from_slice, Value};
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
use crate::error::UserError;
use crate::prompt::Prompt;
use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
use crate::update::index_documents::helpers::try_split_at;
use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors};
use crate::vector::Embedder;
use crate::{DocumentId, FieldsIdsMap, InternalError, Result, VectorOrArrayOfVectors};
/// The length of the elements that are always in the buffer when inserting new values.
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
pub struct ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
pub manual_vectors: grenad::Reader<BufReader<File>>,
// docid -> ()
pub remove_vectors: grenad::Reader<BufReader<File>>,
// docid -> prompt
pub prompts: grenad::Reader<BufReader<File>>,
}
enum VectorStateDelta {
NoChange,
// Remove all vectors, generated or manual, from this document
NowRemoved,
// Add the manually specified vectors, passed in the other grenad
// Remove any previously generated vectors
// Note: changing the value of the manually specified vector **should not record** this delta
WasGeneratedNowManual(Vec<Vec<f32>>),
ManualDelta(Vec<Vec<f32>>, Vec<Vec<f32>>),
// Add the vector computed from the specified prompt
// Remove any previous vector
// Note: changing the value of the prompt **does require** recording this delta
NowGenerated(String),
}
impl VectorStateDelta {
fn into_values(self) -> (bool, String, (Vec<Vec<f32>>, Vec<Vec<f32>>)) {
match self {
VectorStateDelta::NoChange => Default::default(),
VectorStateDelta::NowRemoved => (true, Default::default(), Default::default()),
VectorStateDelta::WasGeneratedNowManual(add) => {
(true, Default::default(), (Default::default(), add))
}
VectorStateDelta::ManualDelta(del, add) => (false, Default::default(), (del, add)),
VectorStateDelta::NowGenerated(prompt) => (true, prompt, Default::default()),
}
}
}
/// Extracts the embedding vector contained in each document under the `_vectors` field.
///
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
@ -27,16 +71,35 @@ const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
pub fn extract_vector_points<R: io::Read + io::Seek>(
obkv_documents: grenad::Reader<R>,
indexer: GrenadParameters,
vectors_fid: FieldId,
) -> Result<grenad::Reader<BufReader<File>>> {
field_id_map: &FieldsIdsMap,
prompt: &Prompt,
embedder_name: &str,
) -> Result<ExtractedVectorPoints> {
puffin::profile_function!();
let mut writer = create_writer(
// (docid, _index) -> KvWriterDelAdd -> Vector
let mut manual_vectors_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
// (docid) -> (prompt)
let mut prompts_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
// (docid) -> ()
let mut remove_vectors_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
let vectors_fid = field_id_map.id("_vectors");
let mut key_buffer = Vec::new();
let mut cursor = obkv_documents.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
@ -53,43 +116,157 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
// lazily get it when needed
let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() };
// first we retrieve the _vectors field
if let Some(value) = obkv.get(vectors_fid) {
let vectors_obkv = KvReaderDelAdd::new(value);
let vectors_field = vectors_fid
.and_then(|vectors_fid| obkv.get(vectors_fid))
.map(KvReaderDelAdd::new)
.map(|obkv| to_vector_maps(obkv, document_id))
.transpose()?;
// then we extract the values
let del_vectors = vectors_obkv
.get(DelAdd::Deletion)
.map(|vectors| extract_vectors(vectors, document_id))
.transpose()?
.flatten();
let add_vectors = vectors_obkv
.get(DelAdd::Addition)
.map(|vectors| extract_vectors(vectors, document_id))
.transpose()?
.flatten();
let (del_map, add_map) = vectors_field.unzip();
let del_map = del_map.flatten();
let add_map = add_map.flatten();
// and we finally push the unique vectors into the writer
push_vectors_diff(
&mut writer,
&mut key_buffer,
del_vectors.unwrap_or_default(),
add_vectors.unwrap_or_default(),
)?;
}
let del_value = del_map.and_then(|mut map| map.remove(embedder_name));
let add_value = add_map.and_then(|mut map| map.remove(embedder_name));
let delta = match (del_value, add_value) {
(Some(old), Some(new)) => {
// no autogeneration
let del_vectors = extract_vectors(old, document_id, embedder_name)?;
let add_vectors = extract_vectors(new, document_id, embedder_name)?;
if add_vectors.len() > u8::MAX.into() {
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
document_id().to_string(),
add_vectors.len(),
)));
}
VectorStateDelta::ManualDelta(del_vectors, add_vectors)
}
(Some(_old), None) => {
// Do we keep this document?
let document_is_kept = obkv
.iter()
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
// becomes autogenerated
VectorStateDelta::NowGenerated(prompt.render(
obkv,
DelAdd::Addition,
field_id_map,
)?)
} else {
VectorStateDelta::NowRemoved
}
}
(None, Some(new)) => {
// was possibly autogenerated, remove all vectors for that document
let add_vectors = extract_vectors(new, document_id, embedder_name)?;
if add_vectors.len() > u8::MAX.into() {
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
document_id().to_string(),
add_vectors.len(),
)));
}
VectorStateDelta::WasGeneratedNowManual(add_vectors)
}
(None, None) => {
// Do we keep this document?
let document_is_kept = obkv
.iter()
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
// Don't give up if the old prompt was failing
let old_prompt =
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
if old_prompt != new_prompt {
log::trace!(
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
} else {
VectorStateDelta::NowRemoved
}
}
};
// and we finally push the unique vectors into the writer
push_vectors_diff(
&mut remove_vectors_writer,
&mut prompts_writer,
&mut manual_vectors_writer,
&mut key_buffer,
delta,
)?;
}
writer_into_reader(writer)
Ok(ExtractedVectorPoints {
// docid, _index -> KvWriterDelAdd -> Vector
manual_vectors: writer_into_reader(manual_vectors_writer)?,
// docid -> ()
remove_vectors: writer_into_reader(remove_vectors_writer)?,
// docid -> prompt
prompts: writer_into_reader(prompts_writer)?,
})
}
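The big `match` above reduces to a small decision table over whether the `_vectors.<embedder>` value is present on the Del and Add sides and whether the document survives the update. A compact restatement of my reading of that logic, with plain labels standing in for the real `VectorStateDelta` payloads:

```rust
// Hypothetical standalone summary, not crate code: maps the presence of the
// old/new `_vectors.<embedder>` value (plus document survival and prompt
// changes) to the delta that gets recorded.
fn classify(has_del: bool, has_add: bool, kept: bool, prompt_changed: bool) -> &'static str {
    match (has_del, has_add) {
        (true, true) => "ManualDelta(del_vectors, add_vectors)",
        (true, false) if kept => "NowGenerated(new_prompt)", // manual value gone, fall back to the prompt
        (true, false) => "NowRemoved",                       // document deleted
        (false, true) => "WasGeneratedNowManual(add_vectors)",
        (false, false) if !kept => "NowRemoved",
        (false, false) if prompt_changed => "NowGenerated(new_prompt)",
        (false, false) => "NoChange",
    }
}

fn main() {
    assert_eq!(classify(false, false, true, false), "NoChange");
    assert_eq!(classify(true, false, true, false), "NowGenerated(new_prompt)");
}
```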
fn to_vector_maps(
obkv: KvReaderDelAdd,
document_id: impl Fn() -> Value,
) -> Result<(Option<serde_json::Map<String, Value>>, Option<serde_json::Map<String, Value>>)> {
let del = to_vector_map(obkv, DelAdd::Deletion, &document_id)?;
let add = to_vector_map(obkv, DelAdd::Addition, &document_id)?;
Ok((del, add))
}
fn to_vector_map(
obkv: KvReaderDelAdd,
side: DelAdd,
document_id: &impl Fn() -> Value,
) -> Result<Option<serde_json::Map<String, Value>>> {
Ok(if let Some(value) = obkv.get(side) {
let Ok(value) = from_slice(value) else {
let value = from_slice(value).map_err(InternalError::SerdeJson)?;
return Err(crate::Error::UserError(UserError::InvalidVectorsMapType {
document_id: document_id(),
value,
}));
};
Some(value)
} else {
None
})
}
/// Computes the diff between both Del and Add numbers and
/// only inserts the parts that differ in the sorter.
fn push_vectors_diff(
writer: &mut Writer<BufWriter<File>>,
remove_vectors_writer: &mut Writer<BufWriter<File>>,
prompts_writer: &mut Writer<BufWriter<File>>,
manual_vectors_writer: &mut Writer<BufWriter<File>>,
key_buffer: &mut Vec<u8>,
mut del_vectors: Vec<Vec<f32>>,
mut add_vectors: Vec<Vec<f32>>,
delta: VectorStateDelta,
) -> Result<()> {
let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values();
if must_remove {
key_buffer.truncate(TRUNCATE_SIZE);
remove_vectors_writer.insert(&key_buffer, [])?;
}
if !prompt.is_empty() {
key_buffer.truncate(TRUNCATE_SIZE);
prompts_writer.insert(&key_buffer, prompt.as_bytes())?;
}
// We sort and dedup the vectors
del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
@ -114,7 +291,7 @@ fn push_vectors_diff(
let mut obkv = KvWriterDelAdd::memory();
obkv.insert(DelAdd::Deletion, cast_slice(&vector))?;
let bytes = obkv.into_inner()?;
writer.insert(&key_buffer, bytes)?;
manual_vectors_writer.insert(&key_buffer, bytes)?;
}
EitherOrBoth::Right(vector) => {
// We insert only the Add part of the Obkv to inform
@ -122,7 +299,7 @@ fn push_vectors_diff(
let mut obkv = KvWriterDelAdd::memory();
obkv.insert(DelAdd::Addition, cast_slice(&vector))?;
let bytes = obkv.into_inner()?;
writer.insert(&key_buffer, bytes)?;
manual_vectors_writer.insert(&key_buffer, bytes)?;
}
}
}
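For context on the repeated `key_buffer.truncate(TRUNCATE_SIZE)` calls: as I read the diff, every key starts with the big-endian `DocumentId`, and the manual-vectors writer appends a per-vector index after that prefix, so truncating back to `TRUNCATE_SIZE` restores a docid-only key for the remove-vectors and prompt entries. A minimal sketch under that assumption (the width of the per-vector index is a guess):

```rust
// Sketch of the assumed key layout; the u16 vector index is an assumption.
const TRUNCATE_SIZE: usize = std::mem::size_of::<u32>(); // DocumentId is a u32

fn start_key(buffer: &mut Vec<u8>, docid: u32) {
    buffer.clear();
    buffer.extend_from_slice(&docid.to_be_bytes());
}

fn manual_vector_key(buffer: &mut Vec<u8>, vector_index: u16) {
    buffer.truncate(TRUNCATE_SIZE); // keep only the docid prefix
    buffer.extend_from_slice(&vector_index.to_be_bytes());
}

fn main() {
    let mut buffer = Vec::new();
    start_key(&mut buffer, 42);
    manual_vector_key(&mut buffer, 1);
    assert_eq!(buffer, [0, 0, 0, 42, 0, 1]);
    buffer.truncate(TRUNCATE_SIZE); // back to a docid-only key for remove_vectors/prompts
    assert_eq!(buffer, [0, 0, 0, 42]);
}
```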
@ -136,13 +313,112 @@ fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering {
}
/// Extracts the vectors from a JSON value.
fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Option<Vec<Vec<f32>>>> {
match from_slice(value) {
Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)),
fn extract_vectors(
value: Value,
document_id: impl Fn() -> Value,
name: &str,
) -> Result<Vec<Vec<f32>>> {
// FIXME: ugly clone of the vectors here
match serde_json::from_value(value.clone()) {
Ok(vectors) => {
Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default())
}
Err(_) => Err(UserError::InvalidVectorsType {
document_id: document_id(),
value: from_slice(value).map_err(InternalError::SerdeJson)?,
value,
subfield: name.to_owned(),
}
.into()),
}
}
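`VectorOrArrayOfVectors::into_array_of_vectors` accepts either a single embedding or a list of embeddings under `_vectors.<embedder>`. A rough standalone equivalent of that shape, assuming an untagged serde enum (my approximation, not the crate's actual type; needs `serde` with the derive feature and `serde_json`):

```rust
use serde::Deserialize;

// Approximation of the accepted `_vectors.<embedder>` value shapes.
#[derive(Deserialize)]
#[serde(untagged)]
enum VectorOrVectors {
    Many(Vec<Vec<f32>>),
    One(Vec<f32>),
}

impl VectorOrVectors {
    fn into_array_of_vectors(self) -> Vec<Vec<f32>> {
        match self {
            VectorOrVectors::Many(vectors) => vectors,
            VectorOrVectors::One(vector) => vec![vector],
        }
    }
}

fn main() {
    let one: VectorOrVectors = serde_json::from_str("[6, 7, 8]").unwrap();
    let many: VectorOrVectors = serde_json::from_str("[[0, 1, 2], [3, 4, 5]]").unwrap();
    assert_eq!(one.into_array_of_vectors(), vec![vec![6.0, 7.0, 8.0]]);
    assert_eq!(many.into_array_of_vectors().len(), 2);
}
```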
#[logging_timer::time]
pub fn extract_embeddings<R: io::Read + io::Seek>(
// docid, prompt
prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters,
embedder: Arc<Embedder>,
) -> Result<grenad::Reader<BufReader<File>>> {
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
let n_chunks = embedder.chunk_count_hint(); // chunk-level parallelism
let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk
// docid, state with embedding
let mut state_writer = create_writer(
indexer.chunk_compression_type,
indexer.chunk_compression_level,
tempfile::tempfile()?,
);
let mut chunks = Vec::with_capacity(n_chunks);
let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk);
let mut current_chunk_ids = Vec::with_capacity(n_vectors_per_chunk);
let mut chunks_ids = Vec::with_capacity(n_chunks);
let mut cursor = prompt_reader.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
// SAFETY: precondition, the grenad value was saved from a string
let prompt = unsafe { std::str::from_utf8_unchecked(value) };
if current_chunk.len() == current_chunk.capacity() {
chunks.push(std::mem::replace(
&mut current_chunk,
Vec::with_capacity(n_vectors_per_chunk),
));
chunks_ids.push(std::mem::replace(
&mut current_chunk_ids,
Vec::with_capacity(n_vectors_per_chunk),
));
};
current_chunk.push(prompt.to_owned());
current_chunk_ids.push(docid);
if chunks.len() == chunks.capacity() {
let chunked_embeds = rt
.block_on(
embedder
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
)
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
.iter()
.flat_map(|docids| docids.iter())
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
chunks_ids.clear();
}
}
// send last chunk
if !chunks.is_empty() {
let chunked_embeds = rt
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
.iter()
.flat_map(|docids| docids.iter())
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
}
if !current_chunk.is_empty() {
let embeds = rt
.block_on(embedder.embed(std::mem::take(&mut current_chunk)))
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
}
}
writer_into_reader(state_writer)
}
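`extract_embeddings` batches prompts at two levels: prompts fill chunks of `prompt_count_in_chunk_hint` entries, and once `chunk_count_hint` chunks are full they are sent to the embedder in a single `embed_chunks` call; the two trailing blocks flush whatever is left. A simplified sketch of that strategy, with a hypothetical `flush` callback standing in for the embedder:

```rust
// Simplified batching sketch; `flush` stands in for embedder.embed_chunks.
fn batch_prompts(
    prompts: impl IntoIterator<Item = String>,
    n_chunks: usize,
    n_per_chunk: usize,
    mut flush: impl FnMut(Vec<Vec<String>>),
) {
    let mut chunks: Vec<Vec<String>> = Vec::with_capacity(n_chunks);
    let mut current: Vec<String> = Vec::with_capacity(n_per_chunk);
    for prompt in prompts {
        if current.len() == n_per_chunk {
            chunks.push(std::mem::replace(&mut current, Vec::with_capacity(n_per_chunk)));
        }
        current.push(prompt);
        if chunks.len() == n_chunks {
            flush(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)));
        }
    }
    // flush the leftovers, mirroring the two trailing blocks above
    if !chunks.is_empty() {
        flush(std::mem::take(&mut chunks));
    }
    if !current.is_empty() {
        flush(vec![current]);
    }
}

fn main() {
    let prompts = (0..7).map(|i| format!("document {i}"));
    batch_prompts(prompts, 2, 2, |batch| println!("embedding {} chunk(s)", batch.len()));
}
```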


@ -23,7 +23,9 @@ use self::extract_facet_string_docids::extract_facet_string_docids;
use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues};
use self::extract_fid_word_count_docids::extract_fid_word_count_docids;
use self::extract_geo_points::extract_geo_points;
use self::extract_vector_points::extract_vector_points;
use self::extract_vector_points::{
extract_embeddings, extract_vector_points, ExtractedVectorPoints,
};
use self::extract_word_docids::extract_word_docids;
use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;
use self::extract_word_position_docids::extract_word_position_docids;
@ -33,7 +35,8 @@ use super::helpers::{
};
use super::{helpers, TypedChunk};
use crate::proximity::ProximityPrecision;
use crate::{FieldId, Result};
use crate::vector::EmbeddingConfigs;
use crate::{FieldId, FieldsIdsMap, Result};
/// Extract data for each databases from obkv documents in parallel.
/// Send data in grenad file over provided Sender.
@ -47,13 +50,14 @@ pub(crate) fn data_from_obkv_documents(
faceted_fields: HashSet<FieldId>,
primary_key_id: FieldId,
geo_fields_ids: Option<(FieldId, FieldId)>,
vectors_field_id: Option<FieldId>,
field_id_map: FieldsIdsMap,
stop_words: Option<fst::Set<&[u8]>>,
allowed_separators: Option<&[&str]>,
dictionary: Option<&[&str]>,
max_positions_per_attributes: Option<u32>,
exact_attributes: HashSet<FieldId>,
proximity_precision: ProximityPrecision,
embedders: EmbeddingConfigs,
) -> Result<()> {
puffin::profile_function!();
@ -64,7 +68,8 @@ pub(crate) fn data_from_obkv_documents(
original_documents_chunk,
indexer,
lmdb_writer_sx.clone(),
vectors_field_id,
field_id_map.clone(),
embedders.clone(),
)
})
.collect::<Result<()>>()?;
@ -276,24 +281,53 @@ fn send_original_documents_data(
original_documents_chunk: Result<grenad::Reader<BufReader<File>>>,
indexer: GrenadParameters,
lmdb_writer_sx: Sender<Result<TypedChunk>>,
vectors_field_id: Option<FieldId>,
field_id_map: FieldsIdsMap,
embedders: EmbeddingConfigs,
) -> Result<()> {
let original_documents_chunk =
original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;
if let Some(vectors_field_id) = vectors_field_id {
let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
rayon::spawn(move || {
let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id);
let _ = match result {
Ok(vector_points) => {
lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))
let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
rayon::spawn(move || {
for (name, (embedder, prompt)) in embedders {
let result = extract_vector_points(
documents_chunk_cloned.clone(),
indexer,
&field_id_map,
&prompt,
&name,
);
match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
Ok(results) => Some(results),
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
None
}
};
if !(remove_vectors.is_empty()
&& manual_vectors.is_empty()
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
{
let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
remove_vectors,
embeddings,
expected_dimension: embedder.dimensions(),
manual_vectors,
embedder_name: name,
}));
}
}
Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
};
});
}
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
}
}
}
});
// TODO: create a custom internal error
lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap();
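The shape of this function, as I read it: the original documents chunk is forwarded right away, while a spawned task loops over the configured embedders, runs `extract_vector_points` then `extract_embeddings`, and streams the resulting `TypedChunk::VectorPoints` (or errors) back over the channel. A rough sketch of that fan-out using std threads and mpsc in place of rayon and crossbeam:

```rust
use std::sync::mpsc;
use std::thread;

// Stand-in for TypedChunk, reduced to the variants this sketch needs.
enum Chunk {
    Documents(String),
    VectorPoints { embedder_name: String },
}

fn main() {
    let (sender, receiver) = mpsc::channel();
    let embedders = vec!["manual".to_string(), "default".to_string()];

    let vector_sender = sender.clone();
    let _worker = thread::spawn(move || {
        for name in embedders {
            // extract_vector_points + extract_embeddings would run here
            let _ = vector_sender.send(Chunk::VectorPoints { embedder_name: name });
        }
    });

    // the documents chunk is sent on the main path regardless of embedders
    sender.send(Chunk::Documents("obkv documents".to_string())).unwrap();
    drop(sender); // close the channel so the consumer loop can finish

    for chunk in receiver {
        match chunk {
            Chunk::Documents(_) => println!("write documents"),
            Chunk::VectorPoints { embedder_name } => println!("write vectors for {embedder_name}"),
        }
    }
}
```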


@ -4,7 +4,7 @@ mod helpers;
mod transform;
mod typed_chunk;
use std::collections::HashSet;
use std::collections::{HashMap, HashSet};
use std::io::{Cursor, Read, Seek};
use std::iter::FromIterator;
use std::num::NonZeroU32;
@ -14,6 +14,7 @@ use crossbeam_channel::{Receiver, Sender};
use heed::types::Str;
use heed::Database;
use log::debug;
use rand::SeedableRng;
use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize};
use slice_group_by::GroupBy;
@ -36,6 +37,7 @@ pub use crate::update::index_documents::helpers::CursorClonableMmap;
use crate::update::{
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
};
use crate::vector::EmbeddingConfigs;
use crate::{CboRoaringBitmapCodec, Index, Result};
static MERGED_DATABASE_COUNT: usize = 7;
@ -78,6 +80,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> {
should_abort: FA,
added_documents: u64,
deleted_documents: u64,
embedders: EmbeddingConfigs,
}
#[derive(Default, Debug, Clone)]
@ -121,6 +124,7 @@ where
index,
added_documents: 0,
deleted_documents: 0,
embedders: Default::default(),
})
}
@ -167,6 +171,11 @@ where
Ok((self, Ok(indexed_documents)))
}
pub fn with_embedders(mut self, embedders: EmbeddingConfigs) -> Self {
self.embedders = embedders;
self
}
/// Remove a batch of documents from the current builder.
///
/// Returns the number of documents deleted from the builder.
@ -322,17 +331,18 @@ where
// get filterable fields for facet databases
let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?;
// get the fid of the `_geo.lat` and `_geo.lng` fields.
let geo_fields_ids = match self.index.fields_ids_map(self.wtxn)?.id("_geo") {
let mut field_id_map = self.index.fields_ids_map(self.wtxn)?;
// self.index.fields_ids_map($a)? ==>> field_id_map
let geo_fields_ids = match field_id_map.id("_geo") {
Some(gfid) => {
let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid);
let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid);
// if `_geo` is faceted then we get the `lat` and `lng`
if is_sortable || is_filterable {
let field_ids = self
.index
.fields_ids_map(self.wtxn)?
let field_ids = field_id_map
.insert("_geo.lat")
.zip(self.index.fields_ids_map(self.wtxn)?.insert("_geo.lng"))
.zip(field_id_map.insert("_geo.lng"))
.ok_or(UserError::AttributeLimitReached)?;
Some(field_ids)
} else {
@ -341,8 +351,6 @@ where
}
None => None,
};
// get the fid of the `_vectors` field.
let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors");
let stop_words = self.index.stop_words(self.wtxn)?;
let separators = self.index.allowed_separators(self.wtxn)?;
@ -364,6 +372,8 @@ where
self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB
let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes;
let cloned_embedder = self.embedders.clone();
// Run extraction pipeline in parallel.
pool.install(|| {
puffin::profile_scope!("extract_and_send_grenad_chunks");
@ -387,13 +397,14 @@ where
faceted_fields,
primary_key_id,
geo_fields_ids,
vectors_field_id,
field_id_map,
stop_words,
separators.as_deref(),
dictionary.as_deref(),
max_positions_per_attributes,
exact_attributes,
proximity_precision,
cloned_embedder,
)
});
@ -402,7 +413,7 @@ where
}
// needs to be dropped to avoid channel waiting lock.
drop(lmdb_writer_sx)
drop(lmdb_writer_sx);
});
let index_is_empty = self.index.number_of_documents(self.wtxn)? == 0;
@ -419,6 +430,8 @@ where
let mut word_docids = None;
let mut exact_word_docids = None;
let mut dimension = HashMap::new();
for result in lmdb_writer_rx {
if (self.should_abort)() {
return Err(Error::InternalError(InternalError::AbortedIndexation));
@ -448,6 +461,22 @@ where
word_position_docids = Some(cloneable_chunk);
TypedChunk::WordPositionDocids(chunk)
}
TypedChunk::VectorPoints {
expected_dimension,
remove_vectors,
embeddings,
manual_vectors,
embedder_name,
} => {
dimension.insert(embedder_name.clone(), expected_dimension);
TypedChunk::VectorPoints {
remove_vectors,
embeddings,
expected_dimension,
manual_vectors,
embedder_name,
}
}
otherwise => otherwise,
};
@ -480,6 +509,33 @@ where
// We write the primary key field id into the main database
self.index.put_primary_key(self.wtxn, &primary_key)?;
let number_of_documents = self.index.number_of_documents(self.wtxn)?;
let mut rng = rand::rngs::StdRng::seed_from_u64(42);
for (embedder_name, dimension) in dimension {
let wtxn = &mut *self.wtxn;
let vector_arroy = self.index.vector_arroy;
let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
)?;
pool.install(|| {
let writer_index = (embedder_index as u16) << 8;
for k in 0..=u8::MAX {
let writer = arroy::Writer::prepare(
wtxn,
vector_arroy,
writer_index | (k as u16),
dimension,
)?;
if writer.is_empty(wtxn)? {
break;
}
writer.build(wtxn, &mut rng, None)?;
}
Result::Ok(())
})?;
}
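The `(embedder_index as u16) << 8` packing here (and again in the typed chunk writer below) gives each embedder a contiguous block of 256 arroy indexes: the high byte of the 16-bit index is the embedder's category id and the low byte is the vector slot within a document, which is also why per-document vector counts are capped at `u8::MAX`. A small sketch of that encoding:

```rust
// Sketch of the assumed 16-bit arroy index layout: [embedder id | vector slot].
fn arroy_index(embedder_id: u8, vector_slot: u8) -> u16 {
    ((embedder_id as u16) << 8) | vector_slot as u16
}

fn unpack(index: u16) -> (u8, u8) {
    ((index >> 8) as u8, (index & 0xff) as u8)
}

fn main() {
    assert_eq!(arroy_index(1, 3), 0x0103);
    assert_eq!(unpack(0x0103), (1, 3));
    // every embedder owns a contiguous block of 256 writer indexes
    assert_eq!(arroy_index(2, 0)..=arroy_index(2, u8::MAX), 0x0200..=0x02ff);
}
```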
self.execute_prefix_databases(
word_docids,
@ -694,6 +750,8 @@ fn execute_word_prefix_docids(
#[cfg(test)]
mod tests {
use std::collections::BTreeMap;
use big_s::S;
use fst::IntoStreamer;
use heed::RwTxn;
@ -703,6 +761,7 @@ mod tests {
use crate::documents::documents_batch_reader_from_objects;
use crate::index::tests::TempIndex;
use crate::search::TermsMatchingStrategy;
use crate::update::Setting;
use crate::{db_snap, Filter, Search};
#[test]
@ -2494,18 +2553,39 @@ mod tests {
/// Vectors must be of the same length.
#[test]
fn test_multiple_vectors() {
use crate::vector::settings::{EmbedderSettings, EmbeddingSettings};
let index = TempIndex::new();
index.add_documents(documents!([{"id": 0, "_vectors": [[0, 1, 2], [3, 4, 5]] }])).unwrap();
index.add_documents(documents!([{"id": 1, "_vectors": [6, 7, 8] }])).unwrap();
index
.update_settings(|settings| {
let mut embedders = BTreeMap::default();
embedders.insert(
"manual".to_string(),
Setting::Set(EmbeddingSettings {
embedder_options: Setting::Set(EmbedderSettings::UserProvided(
crate::vector::settings::UserProvidedSettings { dimensions: 3 },
)),
document_template: Setting::NotSet,
}),
);
settings.set_embedder_settings(embedders);
})
.unwrap();
index
.add_documents(
documents!([{"id": 2, "_vectors": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }]),
documents!([{"id": 0, "_vectors": { "manual": [[0, 1, 2], [3, 4, 5]] } }]),
)
.unwrap();
index.add_documents(documents!([{"id": 1, "_vectors": { "manual": [6, 7, 8] }}])).unwrap();
index
.add_documents(
documents!([{"id": 2, "_vectors": { "manual": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }}]),
)
.unwrap();
let rtxn = index.read_txn().unwrap();
let res = index.search(&rtxn).vector([0.0, 1.0, 2.0]).execute().unwrap();
let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap();
assert_eq!(res.documents_ids.len(), 3);
}


@ -1,4 +1,4 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashMap;
use std::convert::TryInto;
use std::fs::File;
use std::io::{self, BufReader};
@ -8,9 +8,7 @@ use charabia::{Language, Script};
use grenad::MergerBuilder;
use heed::types::Bytes;
use heed::{PutFlags, RwTxn};
use log::error;
use obkv::{KvReader, KvWriter};
use ordered_float::OrderedFloat;
use roaring::RoaringBitmap;
use super::helpers::{
@ -18,16 +16,15 @@ use super::helpers::{
valid_lmdb_key, CursorClonableMmap,
};
use super::{ClonableMmap, MergeFn};
use crate::distance::NDotProductPoint;
use crate::error::UserError;
use crate::external_documents_ids::{DocumentOperation, DocumentOperationKind};
use crate::facet::FacetType;
use crate::index::db_name::DOCUMENTS;
use crate::index::Hnsw;
use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd};
use crate::update::facet::FacetsUpdate;
use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError};
use crate::{
lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, InternalError, Result, SerializationError,
};
pub(crate) enum TypedChunk {
FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
@ -47,7 +44,13 @@ pub(crate) enum TypedChunk {
FieldIdFacetIsNullDocids(grenad::Reader<BufReader<File>>),
FieldIdFacetIsEmptyDocids(grenad::Reader<BufReader<File>>),
GeoPoints(grenad::Reader<BufReader<File>>),
VectorPoints(grenad::Reader<BufReader<File>>),
VectorPoints {
remove_vectors: grenad::Reader<BufReader<File>>,
embeddings: Option<grenad::Reader<BufReader<File>>>,
expected_dimension: usize,
manual_vectors: grenad::Reader<BufReader<File>>,
embedder_name: String,
},
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
}
@ -100,8 +103,8 @@ impl TypedChunk {
TypedChunk::GeoPoints(grenad) => {
format!("GeoPoints {{ number_of_entries: {} }}", grenad.len())
}
TypedChunk::VectorPoints(grenad) => {
format!("VectorPoints {{ number_of_entries: {} }}", grenad.len())
TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => {
format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name)
}
TypedChunk::ScriptLanguageDocids(sl_map) => {
format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len())
@ -355,19 +358,77 @@ pub(crate) fn write_typed_chunk_into_index(
index.put_geo_rtree(wtxn, &rtree)?;
index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
}
TypedChunk::VectorPoints(vector_points) => {
let mut vectors_set = HashSet::new();
// We extract and store the previous vectors
if let Some(hnsw) = index.vector_hnsw(wtxn)? {
for (pid, point) in hnsw.iter() {
let pid_key = pid.into_inner();
let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap();
let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect();
vectors_set.insert((docid, vector));
TypedChunk::VectorPoints {
remove_vectors,
manual_vectors,
embeddings,
expected_dimension,
embedder_name,
} => {
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or(
InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None },
)?;
let writer_index = (embedder_index as u16) << 8;
// FIXME: allow customizing distance
let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
.map(|k| {
arroy::Writer::prepare(
wtxn,
index.vector_arroy,
writer_index | (k as u16),
expected_dimension,
)
})
.collect();
let writers = writers?;
// remove vectors for docids we want them removed
let mut cursor = remove_vectors.into_cursor()?;
while let Some((key, _)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
for writer in &writers {
// Uses invariant: vectors are packed in the first writers.
if !writer.del_item(wtxn, docid)? {
break;
}
}
}
let mut cursor = vector_points.into_cursor()?;
// add generated embeddings
if let Some(embeddings) = embeddings {
let mut cursor = embeddings.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
let data = pod_collect_to_vec(value);
// it is a code error to have embeddings and not expected_dimension
let embeddings =
crate::vector::Embeddings::from_inner(data, expected_dimension)
// code error if we somehow got the wrong dimension
.unwrap();
if embeddings.embedding_count() > u8::MAX.into() {
let external_docid = if let Ok(Some(Ok(index))) = index
.external_id_of(wtxn, std::iter::once(docid))
.map(|it| it.into_iter().next())
{
index
} else {
format!("internal docid={docid}")
};
return Err(crate::Error::UserError(crate::UserError::TooManyVectors(
external_docid,
embeddings.embedding_count(),
)));
}
for (embedding, writer) in embeddings.iter().zip(&writers) {
writer.add_item(wtxn, docid, embedding)?;
}
}
}
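As I read the layout here, each embeddings value is the raw concatenation of a document's f32 embeddings, so the embedding count is `floats.len() / expected_dimension` and the k-th dimension-sized slice goes to the k-th writer. A rough decoding sketch under that assumption (using `bytemuck`, which this file already depends on):

```rust
// Sketch of the assumed value layout: a flat run of f32s, one dimension-sized
// slice per embedding. pod_collect_to_vec copies, so unaligned input is fine.
use bytemuck::pod_collect_to_vec;

fn split_embeddings(value: &[u8], dimension: usize) -> Vec<Vec<f32>> {
    let floats: Vec<f32> = pod_collect_to_vec(value);
    assert_eq!(floats.len() % dimension, 0, "truncated embedding payload");
    floats.chunks_exact(dimension).map(<[f32]>::to_vec).collect()
}

fn main() {
    let stored: Vec<u8> =
        [0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0].iter().flat_map(|f| f.to_ne_bytes()).collect();
    let embeddings = split_embeddings(&stored, 3);
    assert_eq!(embeddings.len(), 2); // two 3-dimensional embeddings for this document
    assert_eq!(embeddings[1], [3.0, 4.0, 5.0]);
}
```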
// perform the manual diff
let mut cursor = manual_vectors.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? {
// convert the key back to a u32 (4 bytes)
let (left, _index) = try_split_array_at(key).unwrap();
@ -375,58 +436,52 @@ pub(crate) fn write_typed_chunk_into_index(
let vector_deladd_obkv = KvReaderDelAdd::new(value);
if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) {
// convert the vector back to a Vec<f32>
let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
let key = (docid, vector);
if !vectors_set.remove(&key) {
error!("Unable to delete the vector: {:?}", key.1);
let vector: Vec<f32> = pod_collect_to_vec(value);
let mut deleted_index = None;
for (index, writer) in writers.iter().enumerate() {
let Some(candidate) = writer.item_vector(wtxn, docid)? else {
// uses invariant: vectors are packed in the first writers.
break;
};
if candidate == vector {
writer.del_item(wtxn, docid)?;
deleted_index = Some(index);
}
}
// 🥲 enforce invariant: vectors are packed in the first writers.
if let Some(deleted_index) = deleted_index {
let mut last_index_with_a_vector = None;
for (index, writer) in writers.iter().enumerate().skip(deleted_index) {
let Some(candidate) = writer.item_vector(wtxn, docid)? else {
break;
};
last_index_with_a_vector = Some((index, candidate));
}
if let Some((last_index, vector)) = last_index_with_a_vector {
// unwrap: computed the index from the list of writers
let writer = writers.get(last_index).unwrap();
writer.del_item(wtxn, docid)?;
writers.get(deleted_index).unwrap().add_item(wtxn, docid, &vector)?;
}
}
}
if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) {
// convert the vector back to a Vec<f32>
let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
vectors_set.insert((docid, vector));
}
}
let vector = pod_collect_to_vec(value);
// Extract the most common vector dimension
let expected_dimension_size = {
let mut dims = HashMap::new();
vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1);
dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len)
};
// Ensure that the vector lengths are correct and
// prepare the vectors before inserting them in the HNSW.
let mut points = Vec::new();
let mut docids = Vec::new();
for (docid, vector) in vectors_set {
if expected_dimension_size.map_or(false, |expected| expected != vector.len()) {
return Err(UserError::InvalidVectorDimensions {
expected: expected_dimension_size.unwrap_or(vector.len()),
found: vector.len(),
// overflow was detected during vector extraction.
for writer in &writers {
if !writer.contains_item(wtxn, docid)? {
writer.add_item(wtxn, docid, &vector)?;
break;
}
}
.into());
} else {
let vector = vector.into_iter().map(OrderedFloat::into_inner).collect();
points.push(NDotProductPoint::new(vector));
docids.push(docid);
}
}
let hnsw_length = points.len();
let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points);
assert_eq!(docids.len(), pids.len());
// Store the vectors in the point-docid relation database
index.vector_id_docid.clear(wtxn)?;
for (docid, pid) in docids.into_iter().zip(pids) {
index.vector_id_docid.put(wtxn, &pid.into_inner(), &docid)?;
}
log::debug!("There are {} entries in the HNSW so far", hnsw_length);
index.put_vector_hnsw(wtxn, &new_hnsw)?;
log::debug!("Finished vector chunk for {}", embedder_name);
}
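The deletion dance above preserves one invariant: a document's vectors always occupy the lowest-numbered writers with no gaps, so readers (and the loops above) can stop at the first empty slot. Deleting a vector from the middle therefore moves the document's last stored vector into the hole. A toy model of that repair, with a plain slot array standing in for the arroy writers:

```rust
// Toy model of the "packed in the first writers" invariant, not arroy itself.
fn delete_vector(slots: &mut [Option<Vec<f32>>], target: &[f32]) {
    let Some(deleted) = slots.iter().position(|s| s.as_deref() == Some(target)) else {
        return; // nothing to delete
    };
    slots[deleted] = None;
    // find the last occupied slot after the hole and move it into the hole
    let last = slots
        .iter()
        .enumerate()
        .skip(deleted + 1)
        .take_while(|(_, slot)| slot.is_some())
        .last()
        .map(|(index, _)| index);
    if let Some(last) = last {
        let moved = slots[last].take();
        slots[deleted] = moved;
    }
}

fn main() {
    let mut slots = vec![Some(vec![0.0]), Some(vec![1.0]), Some(vec![2.0]), None];
    delete_vector(&mut slots, &[1.0]);
    // still densely packed: the last vector was moved into the freed slot
    assert_eq!(slots, vec![Some(vec![0.0]), Some(vec![2.0]), None, None]);
}
```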
TypedChunk::ScriptLanguageDocids(sl_map) => {
for (key, (deletion, addition)) in sl_map {


@ -1,9 +1,11 @@
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::convert::TryInto;
use std::result::Result as StdResult;
use std::sync::Arc;
use charabia::{Normalize, Tokenizer, TokenizerBuilder};
use deserr::{DeserializeError, Deserr};
use itertools::Itertools;
use itertools::{EitherOrBoth, Itertools};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use time::OffsetDateTime;
@ -15,6 +17,8 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS
use crate::proximity::ProximityPrecision;
use crate::update::index_documents::IndexDocumentsMethod;
use crate::update::{IndexDocuments, UpdateIndexingStep};
use crate::vector::settings::{EmbeddingSettings, PromptSettings};
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
use crate::{FieldsIdsMap, Index, OrderBy, Result};
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
@ -73,6 +77,13 @@ impl<T> Setting<T> {
otherwise => otherwise,
}
}
pub fn apply(&mut self, new: Self) {
if let Setting::NotSet = new {
return;
}
*self = new;
}
}
impl<T: Serialize> Serialize for Setting<T> {
@ -129,6 +140,7 @@ pub struct Settings<'a, 't, 'i> {
sort_facet_values_by: Setting<HashMap<String, OrderBy>>,
pagination_max_total_hits: Setting<usize>,
proximity_precision: Setting<ProximityPrecision>,
embedder_settings: Setting<BTreeMap<String, Setting<EmbeddingSettings>>>,
}
impl<'a, 't, 'i> Settings<'a, 't, 'i> {
@ -161,6 +173,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
sort_facet_values_by: Setting::NotSet,
pagination_max_total_hits: Setting::NotSet,
proximity_precision: Setting::NotSet,
embedder_settings: Setting::NotSet,
indexer_config,
}
}
@ -343,6 +356,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
self.proximity_precision = Setting::Reset;
}
pub fn set_embedder_settings(&mut self, value: BTreeMap<String, Setting<EmbeddingSettings>>) {
self.embedder_settings = Setting::Set(value);
}
pub fn reset_embedder_settings(&mut self) {
self.embedder_settings = Setting::Reset;
}
fn reindex<FP, FA>(
&mut self,
progress_callback: &FP,
@ -377,6 +398,9 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
fields_ids_map,
)?;
let embedder_configs = self.index.embedding_configs(self.wtxn)?;
let embedders = self.embedders(embedder_configs)?;
// We index the generated `TransformOutput` which must contain
// all the documents with fields in the newly defined searchable order.
let indexing_builder = IndexDocuments::new(
@ -387,11 +411,33 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
&progress_callback,
&should_abort,
)?;
let indexing_builder = indexing_builder.with_embedders(embedders);
indexing_builder.execute_raw(output)?;
Ok(())
}
fn embedders(
&self,
embedding_configs: Vec<(String, EmbeddingConfig)>,
) -> Result<EmbeddingConfigs> {
let res: Result<_> = embedding_configs
.into_iter()
.map(|(name, EmbeddingConfig { embedder_options, prompt })| {
let prompt = Arc::new(prompt.try_into().map_err(crate::Error::from)?);
let embedder = Arc::new(
Embedder::new(embedder_options.clone())
.map_err(crate::vector::Error::from)
.map_err(crate::Error::from)?,
);
Ok((name, (embedder, prompt)))
})
.collect();
res.map(EmbeddingConfigs::new)
}
fn update_displayed(&mut self) -> Result<bool> {
match self.displayed_fields {
Setting::Set(ref fields) => {
@ -890,6 +936,73 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
Ok(changed)
}
fn update_embedding_configs(&mut self) -> Result<bool> {
let update = match std::mem::take(&mut self.embedder_settings) {
Setting::Set(configs) => {
let mut changed = false;
let old_configs = self.index.embedding_configs(self.wtxn)?;
let old_configs: BTreeMap<String, Setting<EmbeddingSettings>> =
old_configs.into_iter().map(|(k, v)| (k, Setting::Set(v.into()))).collect();
let mut new_configs = BTreeMap::new();
for joined in old_configs
.into_iter()
.merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right))
{
match joined {
EitherOrBoth::Both((name, mut old), (_, new)) => {
old.apply(new);
let new = validate_prompt(&name, old)?;
changed = true;
new_configs.insert(name, new);
}
EitherOrBoth::Left((name, setting)) => {
new_configs.insert(name, setting);
}
EitherOrBoth::Right((name, setting)) => {
let setting = validate_prompt(&name, setting)?;
changed = true;
new_configs.insert(name, setting);
}
}
}
let new_configs: Vec<(String, EmbeddingConfig)> = new_configs
.into_iter()
.filter_map(|(name, setting)| match setting {
Setting::Set(value) => Some((name, value.into())),
Setting::Reset => None,
Setting::NotSet => Some((name, EmbeddingSettings::default().into())),
})
.collect();
self.index.embedder_category_id.clear(self.wtxn)?;
for (index, (embedder_name, _)) in new_configs.iter().enumerate() {
self.index.embedder_category_id.put_with_flags(
self.wtxn,
heed::PutFlags::APPEND,
embedder_name,
&index
.try_into()
.map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?,
)?;
}
if new_configs.is_empty() {
self.index.delete_embedding_configs(self.wtxn)?;
} else {
self.index.put_embedding_configs(self.wtxn, new_configs)?;
}
changed
}
Setting::Reset => {
self.index.delete_embedding_configs(self.wtxn)?;
true
}
Setting::NotSet => false,
};
Ok(update)
}
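Summarizing the merge above: old and new embedder maps are joined by name; names present on both sides are merged with `Setting::apply` (so `NotSet` fields keep their old value), old-only names are kept, new-only names are added (validation elided here), and anything resolving to `Reset` is dropped from the stored list. A standalone sketch of that flow with string payloads instead of `EmbeddingSettings` (uses `itertools`, already imported in this file):

```rust
use itertools::{EitherOrBoth, Itertools};
use std::collections::BTreeMap;

// Minimal re-declaration so the sketch stands alone.
#[derive(Debug, PartialEq)]
enum Setting<T> {
    Set(T),
    Reset,
    NotSet,
}

impl<T> Setting<T> {
    fn apply(&mut self, new: Self) {
        if matches!(new, Setting::NotSet) {
            return; // an unset incoming field keeps the existing value
        }
        *self = new;
    }
}

fn merge_embedders(
    old: BTreeMap<String, Setting<String>>,
    new: BTreeMap<String, Setting<String>>,
) -> Vec<(String, String)> {
    let mut merged = BTreeMap::new();
    for joined in old.into_iter().merge_join_by(new, |(left, _), (right, _)| left.cmp(right)) {
        match joined {
            EitherOrBoth::Both((name, mut old), (_, new)) => {
                old.apply(new);
                merged.insert(name, old);
            }
            EitherOrBoth::Left((name, setting)) | EitherOrBoth::Right((name, setting)) => {
                merged.insert(name, setting);
            }
        }
    }
    merged
        .into_iter()
        .filter_map(|(name, setting)| match setting {
            Setting::Set(config) => Some((name, config)),
            Setting::Reset => None, // a reset embedder is removed entirely
            Setting::NotSet => Some((name, String::new())), // falls back to the default config
        })
        .collect()
}

fn main() {
    let old = BTreeMap::from([("manual".to_string(), Setting::Set("dim=3".to_string()))]);
    let new = BTreeMap::from([
        ("manual".to_string(), Setting::NotSet),                      // left untouched
        ("default".to_string(), Setting::Set("openai".to_string())), // newly added
    ]);
    assert_eq!(
        merge_embedders(old, new),
        vec![
            ("default".to_string(), "openai".to_string()),
            ("manual".to_string(), "dim=3".to_string()),
        ]
    );
}
```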
pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()>
where
FP: Fn(UpdateIndexingStep) + Sync,
@ -927,6 +1040,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
let searchable_updated = self.update_searchable()?;
let exact_attributes_updated = self.update_exact_attributes()?;
let proximity_precision = self.update_proximity_precision()?;
// TODO: very rough approximation of the needs for reindexing where any change will result in
// a full reindexing.
// What can be done instead:
// 1. Only change the distance on a distance change
// 2. Only change the name -> embedder mapping on a name change
// 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage
let embedding_configs_updated = self.update_embedding_configs()?;
if stop_words_updated
|| non_separator_tokens_updated
@ -937,6 +1057,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|| searchable_updated
|| exact_attributes_updated
|| proximity_precision
|| embedding_configs_updated
{
self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?;
}
@ -945,6 +1066,31 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
}
}
fn validate_prompt(
name: &str,
new: Setting<EmbeddingSettings>,
) -> Result<Setting<EmbeddingSettings>> {
match new {
Setting::Set(EmbeddingSettings {
embedder_options,
document_template: Setting::Set(PromptSettings { template: Setting::Set(template) }),
}) => {
// validate
let template = crate::prompt::Prompt::new(template)
.map(|prompt| crate::prompt::PromptData::from(prompt).template)
.map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?;
Ok(Setting::Set(EmbeddingSettings {
embedder_options,
document_template: Setting::Set(PromptSettings {
template: Setting::Set(template),
}),
}))
}
new => Ok(new),
}
}
#[cfg(test)]
mod tests {
use big_s::S;
@ -1763,6 +1909,7 @@ mod tests {
sort_facet_values_by,
pagination_max_total_hits,
proximity_precision,
embedder_settings,
} = settings;
assert!(matches!(searchable_fields, Setting::NotSet));
assert!(matches!(displayed_fields, Setting::NotSet));
@ -1785,6 +1932,7 @@ mod tests {
assert!(matches!(sort_facet_values_by, Setting::NotSet));
assert!(matches!(pagination_max_total_hits, Setting::NotSet));
assert!(matches!(proximity_precision, Setting::NotSet));
assert!(matches!(embedder_settings, Setting::NotSet));
})
.unwrap();
}