WIP multi embedders

fixed template bugs
This commit is contained in:
Louis Dureuil 2023-12-12 21:19:48 +01:00
parent abbe131084
commit 922a640188
No known key found for this signature in database
20 changed files with 438 additions and 158 deletions

View file

@ -71,8 +71,8 @@ impl VectorStateDelta {
pub fn extract_vector_points<R: io::Read + io::Seek>(
obkv_documents: grenad::Reader<R>,
indexer: GrenadParameters,
field_id_map: FieldsIdsMap,
prompt: Option<&Prompt>,
field_id_map: &FieldsIdsMap,
prompt: &Prompt,
) -> Result<ExtractedVectorPoints> {
puffin::profile_function!();
@ -142,14 +142,11 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
// becomes autogenerated
match prompt {
Some(prompt) => VectorStateDelta::NowGenerated(prompt.render(
obkv,
DelAdd::Addition,
&field_id_map,
)?),
None => VectorStateDelta::NowRemoved,
}
VectorStateDelta::NowGenerated(prompt.render(
obkv,
DelAdd::Addition,
field_id_map,
)?)
} else {
VectorStateDelta::NowRemoved
}
@ -162,26 +159,18 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
match prompt {
Some(prompt) => {
// Don't give up if the old prompt was failing
let old_prompt = prompt
.render(obkv, DelAdd::Deletion, &field_id_map)
.unwrap_or_default();
let new_prompt =
prompt.render(obkv, DelAdd::Addition, &field_id_map)?;
if old_prompt != new_prompt {
log::trace!(
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
}
// We no longer have a prompt, so we need to remove any existing vector
None => VectorStateDelta::NowRemoved,
// Don't give up if the old prompt was failing
let old_prompt =
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
if old_prompt != new_prompt {
log::trace!(
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
} else {
VectorStateDelta::NowRemoved
@ -196,24 +185,16 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
if document_is_kept {
match prompt {
Some(prompt) => {
// Don't give up if the old prompt was failing
let old_prompt = prompt
.render(obkv, DelAdd::Deletion, &field_id_map)
.unwrap_or_default();
let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?;
if old_prompt != new_prompt {
log::trace!(
"🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"
);
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
}
None => VectorStateDelta::NowRemoved,
// Don't give up if the old prompt was failing
let old_prompt =
prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default();
let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?;
if old_prompt != new_prompt {
log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}");
VectorStateDelta::NowGenerated(new_prompt)
} else {
log::trace!("⏭️ Prompt unmodified, skipping");
VectorStateDelta::NoChange
}
} else {
VectorStateDelta::NowRemoved
@ -322,7 +303,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
prompt_reader: grenad::Reader<R>,
indexer: GrenadParameters,
embedder: Arc<Embedder>,
) -> Result<(grenad::Reader<BufReader<File>>, Option<usize>)> {
) -> Result<grenad::Reader<BufReader<File>>> {
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
@ -341,8 +322,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
let mut chunks_ids = Vec::with_capacity(n_chunks);
let mut cursor = prompt_reader.into_cursor()?;
let mut expected_dimension = None;
while let Some((key, value)) = cursor.move_on_next()? {
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
// SAFETY: precondition, the grenad value was saved from a string
@ -367,7 +346,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
)
.map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
@ -376,7 +354,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
expected_dimension = Some(embeddings.dimension());
}
chunks_ids.clear();
}
@ -387,7 +364,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
let chunked_embeds = rt
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
.map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in chunks_ids
.iter()
@ -395,7 +371,6 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
{
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
expected_dimension = Some(embeddings.dimension());
}
}
@ -403,14 +378,12 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
let embeds = rt
.block_on(embedder.embed(std::mem::take(&mut current_chunk)))
.map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?;
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
expected_dimension = Some(embeddings.dimension());
}
}
Ok((writer_into_reader(state_writer)?, expected_dimension))
writer_into_reader(state_writer)
}

View file

@ -292,43 +292,42 @@ fn send_original_documents_data(
let documents_chunk_cloned = original_documents_chunk.clone();
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
rayon::spawn(move || {
let (embedder, prompt) = embedders.get("default").cloned().unzip();
let result =
extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref());
match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
/// FIXME: support multiple embedders
let results = embedder.and_then(|embedder| {
match extract_embeddings(prompts, indexer, embedder.clone()) {
for (name, (embedder, prompt)) in embedders {
let result = extract_vector_points(
documents_chunk_cloned.clone(),
indexer,
&field_id_map,
&prompt,
);
match result {
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
Ok(results) => Some(results),
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
None
}
}
});
let (embeddings, expected_dimension) = results.unzip();
let expected_dimension = expected_dimension.flatten();
if !(remove_vectors.is_empty()
&& manual_vectors.is_empty()
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
{
/// FIXME FIXME FIXME
if expected_dimension.is_some() {
};
if !(remove_vectors.is_empty()
&& manual_vectors.is_empty()
&& embeddings.as_ref().map_or(true, |e| e.is_empty()))
{
let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
remove_vectors,
embeddings,
/// FIXME: compute an expected dimension from the manual vectors if any
expected_dimension: expected_dimension.unwrap(),
expected_dimension: embedder.dimensions(),
manual_vectors,
embedder_name: name,
}));
}
}
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
}
}
Err(error) => {
let _ = lmdb_writer_sx_cloned.send(Err(error));
}
};
}
});
// TODO: create a custom internal error

View file

@ -435,7 +435,7 @@ where
let mut word_docids = None;
let mut exact_word_docids = None;
let mut dimension = None;
let mut dimension = HashMap::new();
for result in lmdb_writer_rx {
if (self.should_abort)() {
@ -471,13 +471,15 @@ where
remove_vectors,
embeddings,
manual_vectors,
embedder_name,
} => {
dimension = Some(expected_dimension);
dimension.insert(embedder_name.clone(), expected_dimension);
TypedChunk::VectorPoints {
remove_vectors,
embeddings,
expected_dimension,
manual_vectors,
embedder_name,
}
}
otherwise => otherwise,
@ -513,14 +515,22 @@ where
self.index.put_primary_key(self.wtxn, &primary_key)?;
let number_of_documents = self.index.number_of_documents(self.wtxn)?;
if let Some(dimension) = dimension {
for (embedder_name, dimension) in dimension {
let wtxn = &mut *self.wtxn;
let vector_arroy = self.index.vector_arroy;
/// FIXME: unwrap
let embedder_index =
self.index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
pool.install(|| {
/// FIXME: do for each embedder
let writer_index = (embedder_index as u16) << 8;
let mut rng = rand::rngs::StdRng::from_entropy();
for k in 0..=u8::MAX {
let writer = arroy::Writer::prepare(wtxn, vector_arroy, k.into(), dimension)?;
let writer = arroy::Writer::prepare(
wtxn,
vector_arroy,
writer_index | (k as u16),
dimension,
)?;
if writer.is_empty(wtxn)? {
break;
}

View file

@ -47,6 +47,7 @@ pub(crate) enum TypedChunk {
embeddings: Option<grenad::Reader<BufReader<File>>>,
expected_dimension: usize,
manual_vectors: grenad::Reader<BufReader<File>>,
embedder_name: String,
},
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
}
@ -100,8 +101,8 @@ impl TypedChunk {
TypedChunk::GeoPoints(grenad) => {
format!("GeoPoints {{ number_of_entries: {} }}", grenad.len())
}
TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => {
format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension)
TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => {
format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name)
}
TypedChunk::ScriptLanguageDocids(sl_map) => {
format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len())
@ -360,12 +361,20 @@ pub(crate) fn write_typed_chunk_into_index(
manual_vectors,
embeddings,
expected_dimension,
embedder_name,
} => {
/// FIXME: allow customizing distance
/// FIXME: unwrap
let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap();
let writer_index = (embedder_index as u16) << 8;
// FIXME: allow customizing distance
let writers: std::result::Result<Vec<_>, _> = (0..=u8::MAX)
.map(|k| {
/// FIXME: allow customizing index and then do index << 8 + k
arroy::Writer::prepare(wtxn, index.vector_arroy, k.into(), expected_dimension)
arroy::Writer::prepare(
wtxn,
index.vector_arroy,
writer_index | (k as u16),
expected_dimension,
)
})
.collect();
let writers = writers?;
@ -456,7 +465,7 @@ pub(crate) fn write_typed_chunk_into_index(
}
}
log::debug!("There are 🤷‍♀️ entries in the arroy so far");
log::debug!("Finished vector chunk for {}", embedder_name);
}
TypedChunk::ScriptLanguageDocids(sl_map) => {
for (key, (deletion, addition)) in sl_map {

View file

@ -431,7 +431,6 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
let embedder = Arc::new(
Embedder::new(embedder_options.clone())
.map_err(crate::vector::Error::from)
.map_err(crate::UserError::from)
.map_err(crate::Error::from)?,
);
Ok((name, (embedder, prompt)))
@ -976,6 +975,19 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
Setting::NotSet => Some((name, EmbeddingSettings::default().into())),
})
.collect();
self.index.embedder_category_id.clear(self.wtxn)?;
for (index, (embedder_name, _)) in new_configs.iter().enumerate() {
self.index.embedder_category_id.put_with_flags(
self.wtxn,
heed::PutFlags::APPEND,
embedder_name,
&index
.try_into()
.map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?,
)?;
}
if new_configs.is_empty() {
self.index.delete_embedding_configs(self.wtxn)?;
} else {
@ -1062,7 +1074,7 @@ fn validate_prompt(
match new {
Setting::Set(EmbeddingSettings {
embedder_options,
prompt:
document_template:
Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }),
}) => {
// validate
@ -1072,7 +1084,7 @@ fn validate_prompt(
Ok(Setting::Set(EmbeddingSettings {
embedder_options,
prompt: Setting::Set(PromptSettings {
document_template: Setting::Set(PromptSettings {
template: Setting::Set(template),
strategy,
fallback,