diff --git a/Cargo.lock b/Cargo.lock
index 20259b1d1..9ca569b86 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -120,7 +120,7 @@ dependencies = [
  "futures-util",
  "mio",
  "num_cpus",
- "socket2",
+ "socket2 0.4.9",
  "tokio",
  "tracing",
 ]
@@ -201,7 +201,7 @@ dependencies = [
  "serde_json",
  "serde_urlencoded",
  "smallvec",
- "socket2",
+ "socket2 0.4.9",
  "time",
  "url",
 ]
@@ -1690,9 +1690,9 @@ checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a"

 [[package]]
 name = "futures"
-version = "0.3.28"
+version = "0.3.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40"
+checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335"
 dependencies = [
  "futures-channel",
  "futures-core",
@@ -1705,9 +1705,9 @@ dependencies = [
 [[package]]
 name = "futures-channel"
-version = "0.3.28"
+version = "0.3.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2"
+checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb"
 dependencies = [
  "futures-core",
  "futures-sink",
 ]
@@ -1715,15 +1715,15 @@ dependencies = [

 [[package]]
 name = "futures-core"
-version = "0.3.28"
+version = "0.3.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
+checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c"

 [[package]]
 name = "futures-executor"
-version = "0.3.28"
+version = "0.3.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0"
+checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc"
 dependencies = [
  "futures-core",
  "futures-task",
@@ -1732,15 +1732,15 @@ dependencies = [

 [[package]]
 name = "futures-io"
-version = "0.3.28"
+version = "0.3.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
+checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa"

 [[package]]
 name = "futures-macro"
-version = "0.3.28"
+version = "0.3.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
+checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb"
 dependencies = [
  "proc-macro2",
  "quote",
@@ -1749,21 +1749,21 @@ dependencies = [

 [[package]]
 name = "futures-sink"
-version = "0.3.28"
+version = "0.3.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e"
+checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817"

 [[package]]
 name = "futures-task"
-version = "0.3.28"
+version = "0.3.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
+checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2"

 [[package]]
 name = "futures-util"
-version = "0.3.28"
+version = "0.3.29"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
+checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104"
 dependencies = [
  "futures-channel",
  "futures-core",
@@ -2207,7 +2207,7 @@ dependencies = [
  "httpdate",
  "itoa",
  "pin-project-lite",
- "socket2",
+ "socket2 0.4.9",
  "tokio",
  "tower-service",
  "tracing",
@@ -2980,9 +2980,9 @@ dependencies = [

 [[package]]
 name = "libc"
-version = "0.2.147"
+version = "0.2.150"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
+checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c"

 [[package]]
 name = "libgit2-sys"
@@ -3589,6 +3589,7 @@ dependencies = [
  "filter-parser",
  "flatten-serde-json",
  "fst",
+ "futures",
  "fxhash",
  "geoutils",
  "grenad",
@@ -3626,6 +3627,7 @@ dependencies = [
  "thiserror",
  "time",
  "tokenizers",
+ "tokio",
  "uuid 1.5.0",
 ]

@@ -3671,9 +3673,9 @@ dependencies = [

 [[package]]
 name = "mio"
-version = "0.8.8"
+version = "0.8.9"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2"
+checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0"
 dependencies = [
  "libc",
  "log",
@@ -4977,6 +4979,16 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "socket2"
+version = "0.5.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
+dependencies = [
+ "libc",
+ "windows-sys 0.48.0",
+]
+
 [[package]]
 name = "spin"
 version = "0.5.2"
@@ -5258,11 +5270,10 @@ dependencies = [

 [[package]]
 name = "tokio"
-version = "1.29.1"
+version = "1.34.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da"
+checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9"
 dependencies = [
- "autocfg",
  "backtrace",
  "bytes",
  "libc",
@@ -5271,16 +5282,16 @@ dependencies = [
  "parking_lot",
  "pin-project-lite",
  "signal-hook-registry",
- "socket2",
+ "socket2 0.5.5",
  "tokio-macros",
  "windows-sys 0.48.0",
 ]

 [[package]]
 name = "tokio-macros"
-version = "2.1.0"
+version = "2.2.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
+checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
 dependencies = [
  "proc-macro2",
  "quote",
diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs
index 653cb108b..8d4b0327d 100644
--- a/meilisearch-types/src/error.rs
+++ b/meilisearch-types/src/error.rs
@@ -302,7 +302,8 @@ TaskNotFound , InvalidRequest , NOT_FOUND ;
 TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ;
 UnretrievableDocument , Internal , BAD_REQUEST ;
 UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ;
-UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE
+UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ;
+VectorEmbeddingError , InvalidRequest , BAD_REQUEST
 }

 impl ErrorCode for JoinError {
@@ -357,6 +358,7 @@ impl ErrorCode for milli::Error {
                     UserError::InvalidMinTypoWordLenSetting(_, _) => {
                         Code::InvalidSettingsTypoTolerance
                     }
+                    UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError,
                 }
             }
         }
diff --git a/milli/Cargo.toml b/milli/Cargo.toml
index acf658ff6..12a2fdcb8 100644
--- a/milli/Cargo.toml
+++ b/milli/Cargo.toml
@@ -77,6 +77,8 @@ candle-transformers = { git = "https://github.com/huggingface/candle.git", versi
 candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
tag = "v0.14.1", version = "0.14.1" } hf-hub = "0.3.2" +tokio = { version = "1.34.0", features = ["rt"] } +futures = "0.3.29" [dev-dependencies] mimalloc = { version = "0.1.37", default-features = false } diff --git a/milli/src/error.rs b/milli/src/error.rs index cbbd8a3e5..06a2aa1bb 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -180,6 +180,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco UnknownInternalDocumentId { document_id: DocumentId }, #[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")] InvalidMinTypoWordLenSetting(u8, u8), + #[error(transparent)] + VectorEmbeddingError(#[from] crate::vector::Error), } #[derive(Error, Debug)] diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 317a9aec3..40593260d 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -1,13 +1,15 @@ use std::cmp::Ordering; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::fs::File; use std::io::{self, BufReader, BufWriter}; use std::mem::size_of; use std::str::from_utf8; +use std::sync::{Arc, OnceLock}; use bytemuck::cast_slice; use grenad::Writer; use itertools::EitherOrBoth; +use obkv::KvReader; use ordered_float::OrderedFloat; use serde_json::{from_slice, Value}; @@ -15,11 +17,53 @@ use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; use crate::error::UserError; use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::index_documents::helpers::try_split_at; +use crate::vector::{Embedder, EmbedderOptions}; use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors}; /// The length of the elements that are always in the buffer when inserting new values. const TRUNCATE_SIZE: usize = size_of::(); +pub struct ExtractedVectorPoints { + // docid, _index -> KvWriterDelAdd -> Vector + pub manual_vectors: grenad::Reader>, + // docid -> () + pub remove_vectors: grenad::Reader>, + // docid -> prompt + pub prompts: grenad::Reader>, +} + +enum VectorStateDelta { + NoChange, + // Remove all vectors, generated or manual, from this document + NowRemoved, + + // Add the manually specified vectors, passed in the other grenad + // Remove any previously generated vectors + // Note: changing the value of the manually specified vector **should not record** this delta + WasGeneratedNowManual(Vec>), + + ManualDelta(Vec>, Vec>), + + // Add the vector computed from the specified prompt + // Remove any previous vector + // Note: changing the value of the prompt **does require** recording this delta + NowGenerated(String), +} + +impl VectorStateDelta { + fn into_values(self) -> (bool, String, (Vec>, Vec>)) { + match self { + VectorStateDelta::NoChange => Default::default(), + VectorStateDelta::NowRemoved => (true, Default::default(), Default::default()), + VectorStateDelta::WasGeneratedNowManual(add) => { + (true, Default::default(), (Default::default(), add)) + } + VectorStateDelta::ManualDelta(del, add) => (false, Default::default(), (del, add)), + VectorStateDelta::NowGenerated(prompt) => (true, prompt, Default::default()), + } + } +} + /// Extracts the embedding vector contained in each document under the `_vectors` field. 
 ///
 /// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
@@ -28,10 +72,25 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
     obkv_documents: grenad::Reader<R>,
     indexer: GrenadParameters,
     vectors_fid: FieldId,
-) -> Result<grenad::Reader<BufReader<File>>> {
+) -> Result<ExtractedVectorPoints> {
     puffin::profile_function!();

-    let mut writer = create_writer(
+    // (docid, _index) -> KvWriterDelAdd -> Vector
+    let mut manual_vectors_writer = create_writer(
+        indexer.chunk_compression_type,
+        indexer.chunk_compression_level,
+        tempfile::tempfile()?,
+    );
+
+    // (docid) -> (prompt)
+    let mut prompts_writer = create_writer(
+        indexer.chunk_compression_type,
+        indexer.chunk_compression_level,
+        tempfile::tempfile()?,
+    );
+
+    // (docid) -> ()
+    let mut remove_vectors_writer = create_writer(
         indexer.chunk_compression_type,
         indexer.chunk_compression_level,
         tempfile::tempfile()?,
@@ -53,43 +112,119 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
         // lazily get it when needed
         let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() };

-        // first we retrieve the _vectors field
-        if let Some(value) = obkv.get(vectors_fid) {
+        let delta = if let Some(value) = obkv.get(vectors_fid) {
             let vectors_obkv = KvReaderDelAdd::new(value);

+            match (vectors_obkv.get(DelAdd::Deletion), vectors_obkv.get(DelAdd::Addition)) {
+                (Some(old), Some(new)) => {
+                    // no autogeneration
+                    let del_vectors = extract_vectors(old, document_id)?;
+                    let add_vectors = extract_vectors(new, document_id)?;
-            // then we extract the values
-            let del_vectors = vectors_obkv
-                .get(DelAdd::Deletion)
-                .map(|vectors| extract_vectors(vectors, document_id))
-                .transpose()?
-                .flatten();
-            let add_vectors = vectors_obkv
-                .get(DelAdd::Addition)
-                .map(|vectors| extract_vectors(vectors, document_id))
-                .transpose()?
-                .flatten();

+                    VectorStateDelta::ManualDelta(
+                        del_vectors.unwrap_or_default(),
+                        add_vectors.unwrap_or_default(),
+                    )
+                }
+                (None, Some(new)) => {
+                    // was possibly autogenerated, remove all vectors for that document
+                    let add_vectors = extract_vectors(new, document_id)?;
-            // and we finally push the unique vectors into the writer
-            push_vectors_diff(
-                &mut writer,
-                &mut key_buffer,
-                del_vectors.unwrap_or_default(),
-                add_vectors.unwrap_or_default(),
-            )?;
-        }

+                    VectorStateDelta::WasGeneratedNowManual(add_vectors.unwrap_or_default())
+                }
+                (Some(_old), None) => {
+                    // Do we keep this document?
+                    let document_is_kept = obkv
+                        .iter()
+                        .map(|(_, deladd)| KvReaderDelAdd::new(deladd))
+                        .any(|deladd| deladd.get(DelAdd::Addition).is_some());
+                    if document_is_kept {
+                        // becomes autogenerated
+                        VectorStateDelta::NowGenerated(prompt_for(obkv, DelAdd::Addition))
+                    } else {
+                        VectorStateDelta::NowRemoved
+                    }
+                }
+                (None, None) => {
+                    // no change
+                    VectorStateDelta::NoChange
+                }
+            }
+        } else {
+            // Do we keep this document?
+            let document_is_kept = obkv
+                .iter()
+                .map(|(_, deladd)| KvReaderDelAdd::new(deladd))
+                .any(|deladd| deladd.get(DelAdd::Addition).is_some());
+
+            if document_is_kept {
+                // fixme only if obkv changed
+                let old_prompt = prompt_for(obkv, DelAdd::Deletion);
+                let new_prompt = prompt_for(obkv, DelAdd::Addition);
+                if old_prompt != new_prompt {
+                    VectorStateDelta::NowGenerated(new_prompt)
+                } else {
+                    VectorStateDelta::NoChange
+                }
+            } else {
+                VectorStateDelta::NowRemoved
+            }
+        };
+
+        // and we finally push the unique vectors into the writer
+        push_vectors_diff(
+            &mut remove_vectors_writer,
+            &mut prompts_writer,
+            &mut manual_vectors_writer,
+            &mut key_buffer,
+            delta,
+        )?;
     }

-    writer_into_reader(writer)
+    Ok(ExtractedVectorPoints {
+        // docid, _index -> KvWriterDelAdd -> Vector
+        manual_vectors: writer_into_reader(manual_vectors_writer)?,
+        // docid -> ()
+        remove_vectors: writer_into_reader(remove_vectors_writer)?,
+        // docid -> prompt
+        prompts: writer_into_reader(prompts_writer)?,
+    })
+}
+
+fn prompt_for(obkv: KvReader<'_, FieldId>, side: DelAdd) -> String {
+    let mut texts = String::new();
+    for (_fid, value) in obkv.iter() {
+        let deladd = KvReaderDelAdd::new(value);
+        let Some(value) = deladd.get(side) else {
+            continue;
+        };
+        let Ok(value) = from_slice(value) else {
+            continue;
+        };
+
+        texts += value;
+    }
+    texts
 }

 /// Computes the diff between both Del and Add numbers and
 /// only inserts the parts that differ in the sorter.
 fn push_vectors_diff(
-    writer: &mut Writer<BufWriter<File>>,
+    remove_vectors_writer: &mut Writer<BufWriter<File>>,
+    prompts_writer: &mut Writer<BufWriter<File>>,
+    manual_vectors_writer: &mut Writer<BufWriter<File>>,
     key_buffer: &mut Vec<u8>,
-    mut del_vectors: Vec<Vec<f32>>,
-    mut add_vectors: Vec<Vec<f32>>,
+    delta: VectorStateDelta,
 ) -> Result<()> {
+    let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values();
+    if must_remove {
+        key_buffer.truncate(TRUNCATE_SIZE);
+        remove_vectors_writer.insert(&key_buffer, [])?;
+    }
+    if !prompt.is_empty() {
+        key_buffer.truncate(TRUNCATE_SIZE);
+        prompts_writer.insert(&key_buffer, prompt.as_bytes())?;
+    }
+
     // We sort and dedup the vectors
     del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
     add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
@@ -114,7 +249,7 @@ fn push_vectors_diff(
                 let mut obkv = KvWriterDelAdd::memory();
                 obkv.insert(DelAdd::Deletion, cast_slice(&vector))?;
                 let bytes = obkv.into_inner()?;
-                writer.insert(&key_buffer, bytes)?;
+                manual_vectors_writer.insert(&key_buffer, bytes)?;
             }
             EitherOrBoth::Right(vector) => {
                 // We insert only the Add part of the Obkv to inform
@@ -122,7 +257,7 @@ fn push_vectors_diff(
                 let mut obkv = KvWriterDelAdd::memory();
                 obkv.insert(DelAdd::Addition, cast_slice(&vector))?;
                 let bytes = obkv.into_inner()?;
-                writer.insert(&key_buffer, bytes)?;
+                manual_vectors_writer.insert(&key_buffer, bytes)?;
             }
         }
     }
@@ -146,3 +281,76 @@ fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Option<Vec<Vec<f32>>>> {
         .into()),
     }
 }
+
+pub fn extract_embeddings<R: io::Read + io::Seek>(
+    // docid, prompt
+    prompt_reader: grenad::Reader<R>,
+    indexer: GrenadParameters,
+    embedder: Arc<OnceLock<Embedder>>,
+) -> Result<grenad::Reader<BufReader<File>>> {
+    let rt = tokio::runtime::Builder::new_current_thread().build()?;
+    let embedder = embedder.get_or_init(|| Embedder::new(EmbedderOptions::new()).unwrap());
+
+    let n_chunks = 1; // chunk level parallelism
+    let n_vectors_per_chunk = 2000; // number of vectors in a single chunk
+
+    // docid, state with embedding
+    let mut state_writer = create_writer(
+        indexer.chunk_compression_type,
+        indexer.chunk_compression_level,
+        tempfile::tempfile()?,
+    );
+
+    let mut chunks = Vec::with_capacity(n_chunks);
+    let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk);
+    let mut all_ids = Vec::with_capacity(n_chunks * n_vectors_per_chunk);
+
+    let mut cursor = prompt_reader.into_cursor()?;
+    while let Some((key, value)) = cursor.move_on_next()? {
+        let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
+        // SAFETY: precondition, the grenad value was saved from a string
+        let prompt = unsafe { std::str::from_utf8_unchecked(value) };
+        all_ids.push(docid);
+        current_chunk = if current_chunk.len() == current_chunk.capacity() {
+            chunks.push(std::mem::take(&mut current_chunk));
+            Vec::with_capacity(n_vectors_per_chunk)
+        } else {
+            current_chunk
+        };
+        current_chunk.push(prompt.to_owned());
+
+        if chunks.len() == chunks.capacity() {
+            let chunked_embeds = rt
+                .block_on(
+                    embedder
+                        .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
+                )
+                .map_err(crate::vector::Error::from)
+                .map_err(crate::UserError::from)
+                .map_err(crate::Error::from)?;
+            for (docid, embedding) in
+                all_ids.iter().zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
+            {
+                state_writer.insert(docid.to_ne_bytes(), cast_slice(embedding))?
+            }
+        }
+    }
+
+    // send last chunk
+    if !chunks.is_empty() {
+        let chunked_embeds = rt
+            .block_on(
+                embedder.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
+            )
+            .map_err(crate::vector::Error::from)
+            .map_err(crate::UserError::from)
+            .map_err(crate::Error::from)?;
+        for (docid, embedding) in
+            all_ids.iter().zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
+        {
+            state_writer.insert(docid.to_ne_bytes(), cast_slice(embedding))?
+        }
+    }
+
+    writer_into_reader(state_writer)
+}
diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs
index 91f3e1c62..495378999 100644
--- a/milli/src/update/index_documents/extract/mod.rs
+++ b/milli/src/update/index_documents/extract/mod.rs
@@ -12,6 +12,7 @@ mod extract_word_position_docids;
 use std::collections::HashSet;
 use std::fs::File;
 use std::io::BufReader;
+use std::sync::{Arc, OnceLock};

 use crossbeam_channel::Sender;
 use log::debug;
@@ -23,7 +24,9 @@ use self::extract_facet_string_docids::extract_facet_string_docids;
 use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues};
 use self::extract_fid_word_count_docids::extract_fid_word_count_docids;
 use self::extract_geo_points::extract_geo_points;
-use self::extract_vector_points::extract_vector_points;
+use self::extract_vector_points::{
+    extract_embeddings, extract_vector_points, ExtractedVectorPoints,
+};
 use self::extract_word_docids::extract_word_docids;
 use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;
 use self::extract_word_position_docids::extract_word_position_docids;
@@ -52,6 +55,7 @@ pub(crate) fn data_from_obkv_documents(
     dictionary: Option<&[&str]>,
     max_positions_per_attributes: Option<u32>,
     exact_attributes: HashSet<FieldId>,
+    embedder: Arc<OnceLock<Embedder>>,
 ) -> Result<()> {
     puffin::profile_function!();

@@ -63,6 +67,7 @@ pub(crate) fn data_from_obkv_documents(
                 indexer,
                 lmdb_writer_sx.clone(),
                 vectors_field_id,
+                embedder.clone(),
             )
         })
         .collect::<Result<()>>()?;
@@ -273,6 +278,7 @@ fn send_original_documents_data(
     indexer: GrenadParameters,
     lmdb_writer_sx: Sender<Result<TypedChunk>>,
     vectors_field_id: Option<FieldId>,
+    embedder: Arc<OnceLock<Embedder>>,
 ) -> Result<()> {
     let original_documents_chunk =
         original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;
@@ -283,8 +289,17 @@ fn send_original_documents_data(
     rayon::spawn(move || {
         let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id);
         let _ = match result {
-            Ok(vector_points) => {
-                lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))
+            Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
+                match extract_embeddings(prompts, indexer, embedder) {
+                    Ok(embeddings) => {
+                        lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
+                            remove_vectors,
+                            embeddings,
+                            manual_vectors,
+                        }))
+                    }
+                    Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
+                }
             }
             Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
         };
diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs
index be2fbb25e..c4fa0b02a 100644
--- a/milli/src/update/index_documents/mod.rs
+++ b/milli/src/update/index_documents/mod.rs
@@ -363,6 +363,8 @@ where
             self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB
         let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes;

+        let cloned_embedder = self.indexer_config.embedder.clone();
+
         // Run extraction pipeline in parallel.
         pool.install(|| {
             puffin::profile_scope!("extract_and_send_grenad_chunks");
@@ -392,6 +394,7 @@ where
                     dictionary.as_deref(),
                     max_positions_per_attributes,
                     exact_attributes,
+                    cloned_embedder,
                 )
             });

diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs
index 49e36b87e..e1e3f2381 100644
--- a/milli/src/update/index_documents/typed_chunk.rs
+++ b/milli/src/update/index_documents/typed_chunk.rs
@@ -47,7 +47,11 @@ pub(crate) enum TypedChunk {
     FieldIdFacetIsNullDocids(grenad::Reader<BufReader<File>>),
     FieldIdFacetIsEmptyDocids(grenad::Reader<BufReader<File>>),
     GeoPoints(grenad::Reader<BufReader<File>>),
-    VectorPoints(grenad::Reader<BufReader<File>>),
+    VectorPoints {
+        remove_vectors: grenad::Reader<BufReader<File>>,
+        embeddings: grenad::Reader<BufReader<File>>,
+        manual_vectors: grenad::Reader<BufReader<File>>,
+    },
     ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
 }

@@ -100,8 +104,8 @@ impl TypedChunk {
             TypedChunk::GeoPoints(grenad) => {
                 format!("GeoPoints {{ number_of_entries: {} }}", grenad.len())
             }
-            TypedChunk::VectorPoints(grenad) => {
-                format!("VectorPoints {{ number_of_entries: {} }}", grenad.len())
+            TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings } => {
+                format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.len())
             }
             TypedChunk::ScriptLanguageDocids(sl_map) => {
                 format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len())
@@ -355,19 +359,41 @@ pub(crate) fn write_typed_chunk_into_index(
             index.put_geo_rtree(wtxn, &rtree)?;
             index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
         }
-        TypedChunk::VectorPoints(vector_points) => {
-            let mut vectors_set = HashSet::new();
+        TypedChunk::VectorPoints { remove_vectors, manual_vectors, embeddings } => {
+            let mut docid_vectors_map: HashMap<DocumentId, HashSet<Vec<OrderedFloat<f32>>>> =
+                HashMap::new();
+
             // We extract and store the previous vectors
             if let Some(hnsw) = index.vector_hnsw(wtxn)? {
                 for (pid, point) in hnsw.iter() {
                     let pid_key = pid.into_inner();
                     let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap();
                     let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect();
-                    vectors_set.insert((docid, vector));
+                    docid_vectors_map.entry(docid).or_default().insert(vector);
                 }
             }

-            let mut cursor = vector_points.into_cursor()?;
+            // remove vectors for docids we want them removed
+            let mut cursor = remove_vectors.into_cursor()?;
+            while let Some((key, _)) = cursor.move_on_next()? {
+                let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
+
+                docid_vectors_map.remove(&docid);
+            }
+
+            // add generated embeddings
+            let mut cursor = embeddings.into_cursor()?;
+            while let Some((key, value)) = cursor.move_on_next()? {
+                let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
+                let vector: Vec<OrderedFloat<f32>> =
+                    pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
+                let mut set = HashSet::new();
+                set.insert(vector);
+                docid_vectors_map.insert(docid, set);
+            }
+
+            // perform the manual diff
+            let mut cursor = manual_vectors.into_cursor()?;
             while let Some((key, value)) = cursor.move_on_next()? {
                 // convert the key back to a u32 (4 bytes)
                 let (left, _index) = try_split_array_at(key).unwrap();
@@ -376,23 +402,30 @@ pub(crate) fn write_typed_chunk_into_index(
                 let vector_deladd_obkv = KvReaderDelAdd::new(value);
                 if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) {
                     // convert the vector back to a Vec<f32>
-                    let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
-                    let key = (docid, vector);
-                    if !vectors_set.remove(&key) {
-                        error!("Unable to delete the vector: {:?}", key.1);
-                    }
+                    let vector: Vec<OrderedFloat<f32>> =
+                        pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
+                    docid_vectors_map.entry(docid).and_modify(|v| {
+                        if !v.remove(&vector) {
+                            error!("Unable to delete the vector: {:?}", vector);
+                        }
+                    });
                 }

                 if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) {
                     // convert the vector back to a Vec<f32>
                     let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
-                    vectors_set.insert((docid, vector));
+                    docid_vectors_map.entry(docid).and_modify(|v| {
+                        v.insert(vector);
+                    });
                 }
             }

             // Extract the most common vector dimension
             let expected_dimension_size = {
                 let mut dims = HashMap::new();
-                vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1);
+                docid_vectors_map
+                    .values()
+                    .flat_map(|v| v.iter())
+                    .for_each(|v| *dims.entry(v.len()).or_insert(0) += 1);
                 dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len)
             };

@@ -400,7 +433,10 @@ pub(crate) fn write_typed_chunk_into_index(
             // prepare the vectors before inserting them in the HNSW.
             let mut points = Vec::new();
             let mut docids = Vec::new();
-            for (docid, vector) in vectors_set {
+            for (docid, vector) in docid_vectors_map
+                .into_iter()
+                .flat_map(|(docid, vectors)| std::iter::repeat(docid).zip(vectors))
+            {
                 if expected_dimension_size.map_or(false, |expected| expected != vector.len()) {
                     return Err(UserError::InvalidVectorDimensions {
                         expected: expected_dimension_size.unwrap_or(vector.len()),
diff --git a/milli/src/update/indexer_config.rs b/milli/src/update/indexer_config.rs
index ff7942fdb..6821f88cd 100644
--- a/milli/src/update/indexer_config.rs
+++ b/milli/src/update/indexer_config.rs
@@ -1,3 +1,5 @@
+use std::sync::{Arc, OnceLock};
+
 use grenad::CompressionType;
 use rayon::ThreadPool;

@@ -12,6 +14,7 @@ pub struct IndexerConfig {
     pub thread_pool: Option<ThreadPool>,
     pub max_positions_per_attributes: Option<u32>,
     pub skip_index_budget: bool,
+    pub embedder: Arc<OnceLock<Embedder>>,
 }

 impl Default for IndexerConfig {
@@ -26,6 +29,7 @@ impl Default for IndexerConfig {
             thread_pool: None,
             max_positions_per_attributes: None,
             skip_index_budget: false,
+            embedder: Arc::new(OnceLock::new()),
         }
     }
 }
diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index ff0f0711b..b838671a9 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -117,7 +117,10 @@ impl Embedder {
         Ok(Self { model, tokenizer, options })
     }

-    pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Vec<f32>>, EmbedError> {
+    pub async fn embed(
+        &self,
+        texts: Vec<String>,
+    ) -> std::result::Result<Vec<Vec<f32>>, EmbedError> {
         let tokens = self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?;
         let token_ids = tokens
             .iter()
@@ -147,6 +150,14 @@ impl Embedder {
         let embeddings = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
         Ok(embeddings)
     }
+
+    pub async fn embed_chunks(
+        &self,
+        text_chunks: Vec<Vec<String>>,
+    ) -> std::result::Result<Vec<Vec<Vec<f32>>>, EmbedError> {
+        futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
+            .await
+    }
 }

 fn normalize_l2(v: &Tensor) -> Result<Tensor> {