From f72986446668e9ea504b79d55e7e8505b00c0685 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Mon, 31 Mar 2025 13:43:57 +0200 Subject: [PATCH] Check dimension mismatch at insertion time --- .../src/update/new/extract/vectors/mod.rs | 29 +++++++++++++++++-- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/crates/milli/src/update/new/extract/vectors/mod.rs b/crates/milli/src/update/new/extract/vectors/mod.rs index 6820ee67b..696864e7f 100644 --- a/crates/milli/src/update/new/extract/vectors/mod.rs +++ b/crates/milli/src/update/new/extract/vectors/mod.rs @@ -121,6 +121,7 @@ impl<'a, 'b, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a, 'b> { // do we have set embeddings? if let Some(embeddings) = new_vectors.embeddings { chunks.set_vectors( + update.external_document_id(), update.docid(), embeddings .into_vec(&context.doc_alloc, embedder_name) @@ -128,7 +129,7 @@ impl<'a, 'b, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a, 'b> { document_id: update.external_document_id().to_string(), error: error.to_string(), })?, - ); + )?; } else if new_vectors.regenerate { let new_rendered = prompt.render_document( update.external_document_id(), @@ -209,6 +210,7 @@ impl<'a, 'b, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a, 'b> { chunks.set_regenerate(insertion.docid(), new_vectors.regenerate); if let Some(embeddings) = new_vectors.embeddings { chunks.set_vectors( + insertion.external_document_id(), insertion.docid(), embeddings .into_vec(&context.doc_alloc, embedder_name) @@ -218,7 +220,7 @@ impl<'a, 'b, 'extractor> Extractor<'extractor> for EmbeddingExtractor<'a, 'b> { .to_string(), error: error.to_string(), })?, - ); + )?; } else if new_vectors.regenerate { let rendered = prompt.render_document( insertion.external_document_id(), @@ -273,6 +275,7 @@ struct Chunks<'a, 'b, 'extractor> { embedder: &'a Embedder, embedder_id: u8, embedder_name: &'a str, + dimensions: usize, prompt: &'a Prompt, possible_embedding_mistakes: &'a PossibleEmbeddingMistakes, user_provided: &'a RefCell>, @@ -297,6 +300,7 @@ impl<'a, 'b, 'extractor> Chunks<'a, 'b, 'extractor> { let capacity = embedder.prompt_count_in_chunk_hint() * embedder.chunk_count_hint(); let texts = BVec::with_capacity_in(capacity, doc_alloc); let ids = BVec::with_capacity_in(capacity, doc_alloc); + let dimensions = embedder.dimensions(); Self { texts, ids, @@ -309,6 +313,7 @@ impl<'a, 'b, 'extractor> Chunks<'a, 'b, 'extractor> { embedder_name, user_provided, has_manual_generation: None, + dimensions, } } @@ -490,7 +495,25 @@ impl<'a, 'b, 'extractor> Chunks<'a, 'b, 'extractor> { } } - fn set_vectors(&self, docid: DocumentId, embeddings: Vec) { + fn set_vectors( + &self, + external_docid: &'a str, + docid: DocumentId, + embeddings: Vec, + ) -> Result<()> { + for (embedding_index, embedding) in embeddings.iter().enumerate() { + if embedding.len() != self.dimensions { + return Err(UserError::InvalidIndexingVectorDimensions { + expected: self.dimensions, + found: embedding.len(), + embedder_name: self.embedder_name.to_string(), + document_id: external_docid.to_string(), + embedding_index, + } + .into()); + } + } self.sender.set_vectors(docid, self.embedder_id, embeddings).unwrap(); + Ok(()) } }