mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-22 21:04:27 +01:00
WIP vector extraction
This commit is contained in:
parent
5efd70c251
commit
9cbb2b066a
315
milli/src/update/new/extract/vectors/mod.rs
Normal file
315
milli/src/update/new/extract/vectors/mod.rs
Normal file
@ -0,0 +1,315 @@
|
|||||||
|
use crate::error::FaultSource;
|
||||||
|
use crate::prompt::Prompt;
|
||||||
|
use crate::update::new::channel::EmbeddingSender;
|
||||||
|
use crate::update::new::indexer::document_changes::{Extractor, FullySend};
|
||||||
|
use crate::update::new::vector_document::VectorDocument;
|
||||||
|
use crate::update::new::DocumentChange;
|
||||||
|
use crate::vector::error::EmbedErrorKind;
|
||||||
|
use crate::vector::Embedder;
|
||||||
|
use crate::{DocumentId, Result, ThreadPoolNoAbort, UserError};
|
||||||
|
|
||||||
|
pub struct EmbeddingExtractor<'a> {
|
||||||
|
embedder: &'a Embedder,
|
||||||
|
prompt: &'a Prompt,
|
||||||
|
embedder_id: u8,
|
||||||
|
embedder_name: &'a str,
|
||||||
|
sender: &'a EmbeddingSender<'a>,
|
||||||
|
threads: &'a ThreadPoolNoAbort,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a> {
|
||||||
|
type Data = FullySend<()>;
|
||||||
|
|
||||||
|
fn init_data<'doc>(
|
||||||
|
&'doc self,
|
||||||
|
_extractor_alloc: raw_collections::alloc::RefBump<'extractor>,
|
||||||
|
) -> crate::Result<Self::Data> {
|
||||||
|
Ok(FullySend(()))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn process<'doc>(
|
||||||
|
&'doc self,
|
||||||
|
changes: impl Iterator<Item = crate::Result<DocumentChange<'doc>>>,
|
||||||
|
context: &'doc crate::update::new::indexer::document_changes::DocumentChangeContext<
|
||||||
|
Self::Data,
|
||||||
|
>,
|
||||||
|
) -> crate::Result<()> {
|
||||||
|
let embedder_name: &str = self.embedder_name;
|
||||||
|
let embedder: &Embedder = self.embedder;
|
||||||
|
let prompt: &Prompt = self.prompt;
|
||||||
|
|
||||||
|
let mut chunks = Chunks::new(
|
||||||
|
embedder,
|
||||||
|
self.embedder_id,
|
||||||
|
embedder_name,
|
||||||
|
self.threads,
|
||||||
|
self.sender,
|
||||||
|
&context.doc_alloc,
|
||||||
|
);
|
||||||
|
|
||||||
|
for change in changes {
|
||||||
|
let change = change?;
|
||||||
|
match change {
|
||||||
|
DocumentChange::Deletion(deletion) => {
|
||||||
|
self.sender.delete(deletion.docid(), self.embedder_id).unwrap();
|
||||||
|
}
|
||||||
|
DocumentChange::Update(update) => {
|
||||||
|
/// FIXME: this will force the parsing/retrieval of VectorDocument once per embedder
|
||||||
|
/// consider doing all embedders at once?
|
||||||
|
let old_vectors = update.current_vectors(
|
||||||
|
&context.txn,
|
||||||
|
context.index,
|
||||||
|
context.db_fields_ids_map,
|
||||||
|
&context.doc_alloc,
|
||||||
|
)?;
|
||||||
|
let old_vectors = old_vectors.vectors_for_key(embedder_name)?.unwrap();
|
||||||
|
let new_vectors = update.updated_vectors(&context.doc_alloc)?;
|
||||||
|
if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| {
|
||||||
|
new_vectors.vectors_for_key(embedder_name).transpose()
|
||||||
|
}) {
|
||||||
|
let new_vectors = new_vectors?;
|
||||||
|
match (old_vectors.regenerate, new_vectors.regenerate) {
|
||||||
|
(true, true) | (false, false) => todo!(),
|
||||||
|
_ => {
|
||||||
|
self.sender
|
||||||
|
.set_user_provided(
|
||||||
|
update.docid(),
|
||||||
|
self.embedder_id,
|
||||||
|
!new_vectors.regenerate,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// do we have set embeddings?
|
||||||
|
if let Some(embeddings) = new_vectors.embeddings {
|
||||||
|
self.sender
|
||||||
|
.set_vectors(
|
||||||
|
update.docid(),
|
||||||
|
self.embedder_id,
|
||||||
|
embeddings.into_vec().map_err(UserError::SerdeJson)?,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
} else if new_vectors.regenerate {
|
||||||
|
let new_rendered = prompt.render_document(
|
||||||
|
update.current(
|
||||||
|
&context.txn,
|
||||||
|
context.index,
|
||||||
|
context.db_fields_ids_map,
|
||||||
|
)?,
|
||||||
|
context.new_fields_ids_map,
|
||||||
|
&context.doc_alloc,
|
||||||
|
)?;
|
||||||
|
let old_rendered = prompt.render_document(
|
||||||
|
update.new(
|
||||||
|
&context.txn,
|
||||||
|
context.index,
|
||||||
|
context.db_fields_ids_map,
|
||||||
|
)?,
|
||||||
|
context.new_fields_ids_map,
|
||||||
|
&context.doc_alloc,
|
||||||
|
)?;
|
||||||
|
if new_rendered != old_rendered {
|
||||||
|
chunks.push(update.docid(), new_rendered)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if old_vectors.regenerate {
|
||||||
|
let old_rendered = prompt.render_document(
|
||||||
|
update.current(
|
||||||
|
&context.txn,
|
||||||
|
context.index,
|
||||||
|
context.db_fields_ids_map,
|
||||||
|
)?,
|
||||||
|
context.new_fields_ids_map,
|
||||||
|
&context.doc_alloc,
|
||||||
|
)?;
|
||||||
|
let new_rendered = prompt.render_document(
|
||||||
|
update.new(&context.txn, context.index, context.db_fields_ids_map)?,
|
||||||
|
context.new_fields_ids_map,
|
||||||
|
&context.doc_alloc,
|
||||||
|
)?;
|
||||||
|
if new_rendered != old_rendered {
|
||||||
|
chunks.push(update.docid(), new_rendered)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
DocumentChange::Insertion(insertion) => {
|
||||||
|
// if no inserted vectors, then regenerate: true + no embeddings => autogenerate
|
||||||
|
let new_vectors = insertion.inserted_vectors(&context.doc_alloc)?;
|
||||||
|
if let Some(new_vectors) = new_vectors.as_ref().and_then(|new_vectors| {
|
||||||
|
new_vectors.vectors_for_key(embedder_name).transpose()
|
||||||
|
}) {
|
||||||
|
let new_vectors = new_vectors?;
|
||||||
|
self.sender
|
||||||
|
.set_user_provided(
|
||||||
|
insertion.docid(),
|
||||||
|
self.embedder_id,
|
||||||
|
!new_vectors.regenerate,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
if let Some(embeddings) = new_vectors.embeddings {
|
||||||
|
self.sender
|
||||||
|
.set_vectors(
|
||||||
|
insertion.docid(),
|
||||||
|
self.embedder_id,
|
||||||
|
embeddings.into_vec().map_err(UserError::SerdeJson)?,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
} else if new_vectors.regenerate {
|
||||||
|
let rendered = prompt.render_document(
|
||||||
|
insertion.new(),
|
||||||
|
context.new_fields_ids_map,
|
||||||
|
&context.doc_alloc,
|
||||||
|
)?;
|
||||||
|
chunks.push(insertion.docid(), rendered)?;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let rendered = prompt.render_document(
|
||||||
|
insertion.new(),
|
||||||
|
context.new_fields_ids_map,
|
||||||
|
&context.doc_alloc,
|
||||||
|
)?;
|
||||||
|
chunks.push(insertion.docid(), rendered)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks.drain()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use bumpalo::collections::Vec as BVec;
|
||||||
|
use bumpalo::Bump;
|
||||||
|
|
||||||
|
// **Warning**: the destructor of this struct is not normally run, make sure that all its fields:
|
||||||
|
// 1. don't have side effects tied to they destructors
|
||||||
|
// 2. if allocated, are allocated inside of the bumpalo
|
||||||
|
//
|
||||||
|
// Currently this is the case as:
|
||||||
|
// 1. BVec are inside of the bumaplo
|
||||||
|
// 2. All other fields are either trivial (u8) or references.
|
||||||
|
struct Chunks<'a> {
|
||||||
|
texts: BVec<'a, &'a str>,
|
||||||
|
ids: BVec<'a, DocumentId>,
|
||||||
|
|
||||||
|
embedder: &'a Embedder,
|
||||||
|
embedder_id: u8,
|
||||||
|
embedder_name: &'a str,
|
||||||
|
threads: &'a ThreadPoolNoAbort,
|
||||||
|
sender: &'a EmbeddingSender<'a>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Chunks<'a> {
|
||||||
|
pub fn new(
|
||||||
|
embedder: &'a Embedder,
|
||||||
|
embedder_id: u8,
|
||||||
|
embedder_name: &'a str,
|
||||||
|
threads: &'a ThreadPoolNoAbort,
|
||||||
|
sender: &'a EmbeddingSender<'a>,
|
||||||
|
doc_alloc: &'a Bump,
|
||||||
|
) -> Self {
|
||||||
|
let capacity = embedder.prompt_count_in_chunk_hint() * embedder.chunk_count_hint();
|
||||||
|
let texts = BVec::with_capacity_in(capacity, doc_alloc);
|
||||||
|
let ids = BVec::with_capacity_in(capacity, doc_alloc);
|
||||||
|
Self { texts, ids, embedder, threads, sender, embedder_id, embedder_name }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn push(&mut self, docid: DocumentId, rendered: &'a str) -> Result<()> {
|
||||||
|
if self.texts.len() < self.texts.capacity() {
|
||||||
|
self.texts.push(rendered);
|
||||||
|
self.ids.push(docid);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
Self::embed_chunks(
|
||||||
|
&mut self.texts,
|
||||||
|
&mut self.ids,
|
||||||
|
self.embedder,
|
||||||
|
self.embedder_id,
|
||||||
|
self.embedder_name,
|
||||||
|
self.threads,
|
||||||
|
self.sender,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn drain(mut self) -> Result<()> {
|
||||||
|
let res = Self::embed_chunks(
|
||||||
|
&mut self.texts,
|
||||||
|
&mut self.ids,
|
||||||
|
self.embedder,
|
||||||
|
self.embedder_id,
|
||||||
|
self.embedder_name,
|
||||||
|
self.threads,
|
||||||
|
self.sender,
|
||||||
|
);
|
||||||
|
// optimization: don't run bvec dtors as they only contain bumpalo allocated stuff
|
||||||
|
std::mem::forget(self);
|
||||||
|
res
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed_chunks(
|
||||||
|
texts: &mut BVec<'a, &'a str>,
|
||||||
|
ids: &mut BVec<'a, DocumentId>,
|
||||||
|
embedder: &'a Embedder,
|
||||||
|
embedder_id: u8,
|
||||||
|
embedder_name: &str,
|
||||||
|
threads: &'a ThreadPoolNoAbort,
|
||||||
|
sender: &'a EmbeddingSender<'a>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let res = match embedder.embed_chunks_ref(texts.as_slice(), threads) {
|
||||||
|
Ok(embeddings) => {
|
||||||
|
for (docid, embedding) in ids.into_iter().zip(embeddings) {
|
||||||
|
sender.set_vector(*docid, embedder_id, embedding).unwrap();
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
Err(error) => {
|
||||||
|
if let FaultSource::Bug = error.fault {
|
||||||
|
Err(crate::Error::InternalError(crate::InternalError::VectorEmbeddingError(
|
||||||
|
error.into(),
|
||||||
|
)))
|
||||||
|
} else {
|
||||||
|
let mut msg = format!(
|
||||||
|
r"While embedding documents for embedder `{embedder_name}`: {error}"
|
||||||
|
);
|
||||||
|
|
||||||
|
if let EmbedErrorKind::ManualEmbed(_) = &error.kind {
|
||||||
|
msg += &format!("\n- Note: `{embedder_name}` has `source: userProvided`, so documents must provide embeddings as an array in `_vectors.{embedder_name}`.");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// FIXME: reintroduce possible_embedding_mistakes and possible_embedding_mistakes
|
||||||
|
let mut hint_count = 0;
|
||||||
|
|
||||||
|
/*
|
||||||
|
for (vector_misspelling, count) in
|
||||||
|
possible_embedding_mistakes.vector_mistakes().take(2)
|
||||||
|
{
|
||||||
|
msg += &format!("\n- Hint: try replacing `{vector_misspelling}` by `_vectors` in {count} document(s).");
|
||||||
|
hint_count += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (embedder_misspelling, count) in possible_embedding_mistakes
|
||||||
|
.embedder_mistakes(embedder_name, unused_vectors_distribution)
|
||||||
|
.take(2)
|
||||||
|
{
|
||||||
|
msg += &format!("\n- Hint: try replacing `_vectors.{embedder_misspelling}` by `_vectors.{embedder_name}` in {count} document(s).");
|
||||||
|
hint_count += 1;
|
||||||
|
}
|
||||||
|
*/
|
||||||
|
if hint_count == 0 {
|
||||||
|
if let EmbedErrorKind::ManualEmbed(_) = &error.kind {
|
||||||
|
msg += &format!(
|
||||||
|
"\n- Hint: opt-out for a document with `_vectors.{embedder_name}: null`"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(crate::Error::UserError(crate::UserError::DocumentEmbeddingError(msg)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
texts.clear();
|
||||||
|
ids.clear();
|
||||||
|
res
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user