diff --git a/crates/milli/src/update/new/extract/mod.rs b/crates/milli/src/update/new/extract/mod.rs index 2abefb098..05c90d8f8 100644 --- a/crates/milli/src/update/new/extract/mod.rs +++ b/crates/milli/src/update/new/extract/mod.rs @@ -12,7 +12,7 @@ pub use documents::*; pub use faceted::*; pub use geo::*; pub use searchable::*; -pub use vectors::EmbeddingExtractor; +pub use vectors::{EmbeddingExtractor, SettingsChangeEmbeddingExtractor}; /// TODO move in permissive json pointer pub mod perm_json_p { diff --git a/crates/milli/src/update/new/extract/vectors/mod.rs b/crates/milli/src/update/new/extract/vectors/mod.rs index 43647e786..ac00e9811 100644 --- a/crates/milli/src/update/new/extract/vectors/mod.rs +++ b/crates/milli/src/update/new/extract/vectors/mod.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::collections::BTreeMap; use bumpalo::collections::Vec as BVec; use bumpalo::Bump; @@ -8,13 +9,16 @@ use super::cache::DelAddRoaringBitmap; use crate::error::FaultSource; use crate::prompt::Prompt; use crate::update::new::channel::EmbeddingSender; +use crate::update::new::document_change::DatabaseDocument; use crate::update::new::indexer::document_changes::{DocumentChangeContext, Extractor}; +use crate::update::new::indexer::settings_changes::SettingsChangeExtractor; use crate::update::new::thread_local::MostlySend; use crate::update::new::vector_document::VectorDocument; use crate::update::new::DocumentChange; use crate::vector::error::{ EmbedErrorKind, PossibleEmbeddingMistakes, UnusedVectorsDistributionBump, }; +use crate::vector::settings::{EmbedderAction, ReindexAction}; use crate::vector::{Embedder, Embedding, EmbeddingConfigs}; use crate::{DocumentId, FieldDistribution, InternalError, Result, ThreadPoolNoAbort, UserError}; @@ -290,6 +294,200 @@ impl<'extractor> Extractor<'extractor> for EmbeddingExtractor<'_, '_> { } } +pub struct SettingsChangeEmbeddingExtractor<'a, 'b> { + embedders: &'a EmbeddingConfigs, + old_embedders: &'a EmbeddingConfigs, + embedder_actions: &'a BTreeMap, + embedder_category_id: &'a std::collections::HashMap, + sender: EmbeddingSender<'a, 'b>, + possible_embedding_mistakes: PossibleEmbeddingMistakes, + threads: &'a ThreadPoolNoAbort, +} + +impl<'a, 'b> SettingsChangeEmbeddingExtractor<'a, 'b> { + pub fn new( + embedders: &'a EmbeddingConfigs, + old_embedders: &'a EmbeddingConfigs, + embedder_actions: &'a BTreeMap, + embedder_category_id: &'a std::collections::HashMap, + sender: EmbeddingSender<'a, 'b>, + field_distribution: &'a FieldDistribution, + threads: &'a ThreadPoolNoAbort, + ) -> Self { + let possible_embedding_mistakes = PossibleEmbeddingMistakes::new(field_distribution); + Self { + embedders, + old_embedders, + embedder_actions, + embedder_category_id, + sender, + threads, + possible_embedding_mistakes, + } + } +} + +impl<'extractor> SettingsChangeExtractor<'extractor> for SettingsChangeEmbeddingExtractor<'_, '_> { + type Data = RefCell>; + + fn init_data<'doc>(&'doc self, extractor_alloc: &'extractor Bump) -> crate::Result { + Ok(RefCell::new(EmbeddingExtractorData(HashMap::new_in(extractor_alloc)))) + } + + fn process<'doc>( + &'doc self, + documents: impl Iterator>>, + context: &'doc DocumentChangeContext, + ) -> crate::Result<()> { + let embedders = self.embedders.inner_as_ref(); + let old_embedders = self.old_embedders.inner_as_ref(); + let unused_vectors_distribution = UnusedVectorsDistributionBump::new_in(&context.doc_alloc); + + let mut all_chunks = BVec::with_capacity_in(embedders.len(), &context.doc_alloc); + for (embedder_name, (embedder, prompt, _is_quantized)) in embedders { + // if the embedder is not in the embedder_actions, we don't need to reindex. + if let Some((embedder_id, reindex_action)) = + self.embedder_actions.get(embedder_name).and_then(|action| { + let embedder_id = self + .embedder_category_id + .get(embedder_name) + .expect("embedder_category_id should be present"); + action.reindex().map(|reindex| (*embedder_id, reindex)) + }) + { + all_chunks.push(( + Chunks::new( + embedder, + embedder_id, + embedder_name, + prompt, + context.data, + &self.possible_embedding_mistakes, + self.threads, + self.sender, + &context.doc_alloc, + ), + reindex_action, + )) + } + } + + for document in documents { + let document = document?; + + let current_vectors = document.current_vectors( + &context.rtxn, + context.index, + context.db_fields_ids_map, + &context.doc_alloc, + )?; + + for (chunks, reindex_action) in &mut all_chunks { + let embedder_name = chunks.embedder_name(); + let current_vectors = current_vectors.vectors_for_key(embedder_name)?; + + // if the vectors for this document have been already provided, we don't need to reindex. + let (is_new_embedder, must_regenerate) = + current_vectors.as_ref().map_or((true, true), |vectors| { + (!vectors.has_configured_embedder, vectors.regenerate) + }); + + match reindex_action { + ReindexAction::RegeneratePrompts => { + if !must_regenerate { + continue; + } + // we need to regenerate the prompts for the document + + // Get the old prompt and render the document with it + let Some((_, old_prompt, _)) = old_embedders.get(embedder_name) else { + unreachable!("ReindexAction::RegeneratePrompts implies that the embedder {embedder_name} is in the old_embedders") + }; + let old_rendered = old_prompt.render_document( + document.external_document_id(), + document.current( + &context.rtxn, + context.index, + context.db_fields_ids_map, + )?, + context.new_fields_ids_map, + &context.doc_alloc, + )?; + + // Get the new prompt and render the document with it + let new_prompt = chunks.prompt(); + let new_rendered = new_prompt.render_document( + document.external_document_id(), + document.current( + &context.rtxn, + context.index, + context.db_fields_ids_map, + )?, + context.new_fields_ids_map, + &context.doc_alloc, + )?; + + // Compare the rendered documents + // if they are different, regenerate the vectors + if new_rendered != old_rendered { + chunks.set_autogenerated( + document.docid(), + document.external_document_id(), + new_rendered, + &unused_vectors_distribution, + )?; + } + } + ReindexAction::FullReindex => { + let prompt = chunks.prompt(); + // if no inserted vectors, then regenerate: true + no embeddings => autogenerate + if let Some(embeddings) = current_vectors + .and_then(|vectors| vectors.embeddings) + // insert the embeddings only for new embedders + .filter(|_| is_new_embedder) + { + chunks.set_regenerate(document.docid(), must_regenerate); + chunks.set_vectors( + document.external_document_id(), + document.docid(), + embeddings.into_vec(&context.doc_alloc, embedder_name).map_err( + |error| UserError::InvalidVectorsEmbedderConf { + document_id: document.external_document_id().to_string(), + error: error.to_string(), + }, + )?, + )?; + } else if must_regenerate { + let rendered = prompt.render_document( + document.external_document_id(), + document.current( + &context.rtxn, + context.index, + context.db_fields_ids_map, + )?, + context.new_fields_ids_map, + &context.doc_alloc, + )?; + chunks.set_autogenerated( + document.docid(), + document.external_document_id(), + rendered, + &unused_vectors_distribution, + )?; + } + } + } + } + } + + for (chunk, _) in all_chunks { + chunk.drain(&unused_vectors_distribution)?; + } + + Ok(()) + } +} + // **Warning**: the destructor of this struct is not normally run, make sure that all its fields: // 1. don't have side effects tied to they destructors // 2. if allocated, are allocated inside of the bumpalo diff --git a/crates/milli/src/update/new/indexer/extract.rs b/crates/milli/src/update/new/indexer/extract.rs index 246416503..6b8115d42 100644 --- a/crates/milli/src/update/new/indexer/extract.rs +++ b/crates/milli/src/update/new/indexer/extract.rs @@ -12,6 +12,7 @@ use super::super::steps::IndexingStep; use super::super::thread_local::{FullySend, ThreadLocal}; use super::super::FacetFieldIdsDelta; use super::document_changes::{extract, DocumentChanges, IndexingContext}; +use super::settings_changes::settings_change_extract; use crate::documents::FieldIdMapper; use crate::documents::PrimaryKey; use crate::index::IndexEmbeddingConfig; @@ -353,6 +354,53 @@ where extractor_allocs, )?; + 'vectors: { + if settings_delta.embedder_actions().is_empty() { + break 'vectors; + } + + let embedding_sender = extractor_sender.embeddings(); + + // extract the remaining embedders + let extractor = SettingsChangeEmbeddingExtractor::new( + settings_delta.new_embedders(), + settings_delta.old_embedders(), + settings_delta.embedder_actions(), + settings_delta.new_embedder_category_id(), + embedding_sender, + field_distribution, + request_threads(), + ); + let mut datastore = ThreadLocal::with_capacity(rayon::current_num_threads()); + { + let span = tracing::debug_span!(target: "indexing::documents::extract", "vectors"); + let _entered = span.enter(); + + settings_change_extract( + &documents, + &extractor, + indexing_context, + extractor_allocs, + &datastore, + IndexingStep::ExtractingEmbeddings, + )?; + } + { + let span = tracing::debug_span!(target: "indexing::documents::merge", "vectors"); + let _entered = span.enter(); + + for config in &mut index_embeddings { + 'data: for data in datastore.iter_mut() { + let data = &mut data.get_mut().0; + let Some(deladd) = data.remove(&config.name) else { + continue 'data; + }; + deladd.apply_to(&mut config.user_provided, modified_docids); + } + } + } + } + indexing_context.progress.update_progress(IndexingStep::WaitingForDatabaseWrites); finished_extraction.store(true, std::sync::atomic::Ordering::Relaxed); diff --git a/crates/milli/src/update/new/indexer/mod.rs b/crates/milli/src/update/new/indexer/mod.rs index 0848cc39d..9626940b0 100644 --- a/crates/milli/src/update/new/indexer/mod.rs +++ b/crates/milli/src/update/new/indexer/mod.rs @@ -168,6 +168,7 @@ where index_embeddings, arroy_memory, &mut arroy_writers, + None, &indexing_context.must_stop_processing, ) }) diff --git a/crates/milli/src/update/new/indexer/write.rs b/crates/milli/src/update/new/indexer/write.rs index 5a600eeb3..19696f169 100644 --- a/crates/milli/src/update/new/indexer/write.rs +++ b/crates/milli/src/update/new/indexer/write.rs @@ -1,3 +1,4 @@ +use std::collections::BTreeMap; use std::sync::atomic::AtomicBool; use bstr::ByteSlice as _; @@ -13,6 +14,7 @@ use crate::fields_ids_map::metadata::FieldIdMapWithMetadata; use crate::index::IndexEmbeddingConfig; use crate::progress::Progress; use crate::update::settings::InnerIndexSettings; +use crate::vector::settings::EmbedderAction; use crate::vector::{ArroyWrapper, Embedder, EmbeddingConfigs, Embeddings}; use crate::{Error, Index, InternalError, Result, UserError}; @@ -106,6 +108,7 @@ pub fn build_vectors( index_embeddings: Vec, arroy_memory: Option, arroy_writers: &mut HashMap, + embeder_actions: Option<&BTreeMap>, must_stop_processing: &MSP, ) -> Result<()> where @@ -117,14 +120,17 @@ where let seed = rand::random(); let mut rng = rand::rngs::StdRng::seed_from_u64(seed); - for (_index, (_embedder_name, _embedder, writer, dimensions)) in arroy_writers { + for (_index, (embedder_name, _embedder, writer, dimensions)) in arroy_writers { let dimensions = *dimensions; + let is_being_quantized = embeder_actions + .and_then(|actions| actions.get(*embedder_name).map(|action| action.is_being_quantized)) + .unwrap_or(false); writer.build_and_quantize( wtxn, progress, &mut rng, dimensions, - false, + is_being_quantized, arroy_memory, must_stop_processing, )?;