diff --git a/crates/milli/src/update/index_documents/extract/extract_vector_points.rs b/crates/milli/src/update/index_documents/extract/extract_vector_points.rs index e940e743b..e6d874a69 100644 --- a/crates/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/crates/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -684,12 +684,10 @@ pub fn extract_embeddings( embedder: Arc, embedder_name: &str, possible_embedding_mistakes: &PossibleEmbeddingMistakes, - embedder_stats: Option>, + embedder_stats: Arc, unused_vectors_distribution: &UnusedVectorsDistribution, request_threads: &ThreadPoolNoAbort, ) -> Result>> { - println!("Extract embedder stats {}:", embedder_stats.is_some()); - let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk @@ -791,7 +789,7 @@ fn embed_chunks( text_chunks: Vec>, embedder_name: &str, possible_embedding_mistakes: &PossibleEmbeddingMistakes, - embedder_stats: Option>, + embedder_stats: Arc, unused_vectors_distribution: &UnusedVectorsDistribution, request_threads: &ThreadPoolNoAbort, ) -> Result>> { diff --git a/crates/milli/src/update/index_documents/extract/mod.rs b/crates/milli/src/update/index_documents/extract/mod.rs index f4f3ad22e..1eeddcccb 100644 --- a/crates/milli/src/update/index_documents/extract/mod.rs +++ b/crates/milli/src/update/index_documents/extract/mod.rs @@ -274,7 +274,7 @@ fn send_original_documents_data( embedder.clone(), &embedder_name, &possible_embedding_mistakes, - Some(embedder_stats.clone()), + embedder_stats.clone(), &unused_vectors_distribution, request_threads(), ) { diff --git a/crates/milli/src/update/new/extract/vectors/mod.rs b/crates/milli/src/update/new/extract/vectors/mod.rs index 946fb00b5..c21dabf74 100644 --- a/crates/milli/src/update/new/extract/vectors/mod.rs +++ b/crates/milli/src/update/new/extract/vectors/mod.rs @@ -23,7 +23,7 @@ pub struct EmbeddingExtractor<'a, 'b> { embedders: &'a EmbeddingConfigs, sender: EmbeddingSender<'a, 'b>, possible_embedding_mistakes: PossibleEmbeddingMistakes, - embedder_stats: Option>, + embedder_stats: Arc, threads: &'a ThreadPoolNoAbort, } @@ -32,7 +32,7 @@ impl<'a, 'b> EmbeddingExtractor<'a, 'b> { embedders: &'a EmbeddingConfigs, sender: EmbeddingSender<'a, 'b>, field_distribution: &'a FieldDistribution, - embedder_stats: Option>, + embedder_stats: Arc, threads: &'a ThreadPoolNoAbort, ) -> Self { let possible_embedding_mistakes = PossibleEmbeddingMistakes::new(field_distribution); @@ -311,7 +311,7 @@ struct Chunks<'a, 'b, 'extractor> { dimensions: usize, prompt: &'a Prompt, possible_embedding_mistakes: &'a PossibleEmbeddingMistakes, - embedder_stats: Option>, + embedder_stats: Arc, user_provided: &'a RefCell>, threads: &'a ThreadPoolNoAbort, sender: EmbeddingSender<'a, 'b>, @@ -327,7 +327,7 @@ impl<'a, 'b, 'extractor> Chunks<'a, 'b, 'extractor> { prompt: &'a Prompt, user_provided: &'a RefCell>, possible_embedding_mistakes: &'a PossibleEmbeddingMistakes, - embedder_stats: Option>, + embedder_stats: Arc, threads: &'a ThreadPoolNoAbort, sender: EmbeddingSender<'a, 'b>, doc_alloc: &'a Bump, @@ -416,7 +416,7 @@ impl<'a, 'b, 'extractor> Chunks<'a, 'b, 'extractor> { embedder_id: u8, embedder_name: &str, possible_embedding_mistakes: &PossibleEmbeddingMistakes, - embedder_stats: Option>, + embedder_stats: Arc, unused_vectors_distribution: &UnusedVectorsDistributionBump, threads: &ThreadPoolNoAbort, sender: EmbeddingSender<'a, 'b>, diff --git a/crates/milli/src/update/new/indexer/extract.rs b/crates/milli/src/update/new/indexer/extract.rs index 040886236..c721a2563 100644 --- a/crates/milli/src/update/new/indexer/extract.rs +++ b/crates/milli/src/update/new/indexer/extract.rs @@ -248,7 +248,7 @@ where embedders, embedding_sender, field_distribution, - Some(embedder_stats), + embedder_stats, request_threads(), ); let mut datastore = ThreadLocal::with_capacity(rayon::current_num_threads()); diff --git a/crates/milli/src/vector/composite.rs b/crates/milli/src/vector/composite.rs index 7d9497165..87f05d4fe 100644 --- a/crates/milli/src/vector/composite.rs +++ b/crates/milli/src/vector/composite.rs @@ -196,7 +196,7 @@ impl SubEmbedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - embedder_stats: Option>, + embedder_stats: Arc, ) -> std::result::Result>, EmbedError> { match self { SubEmbedder::HuggingFace(embedder) => embedder.embed_index(text_chunks), @@ -218,7 +218,7 @@ impl SubEmbedder { &self, texts: &[&str], threads: &ThreadPoolNoAbort, - embedder_stats: Option>, + embedder_stats: Arc, ) -> std::result::Result, EmbedError> { match self { SubEmbedder::HuggingFace(embedder) => embedder.embed_index_ref(texts), diff --git a/crates/milli/src/vector/mod.rs b/crates/milli/src/vector/mod.rs index efa981694..481eb6c99 100644 --- a/crates/milli/src/vector/mod.rs +++ b/crates/milli/src/vector/mod.rs @@ -749,7 +749,7 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - embedder_stats: Option>, + embedder_stats: Arc, ) -> std::result::Result>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed_index(text_chunks), @@ -772,7 +772,7 @@ impl Embedder { &self, texts: &[&str], threads: &ThreadPoolNoAbort, - embedder_stats: Option>, + embedder_stats: Arc, ) -> std::result::Result, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed_index_ref(texts), diff --git a/crates/milli/src/vector/ollama.rs b/crates/milli/src/vector/ollama.rs index e26b7e1ea..045b65b72 100644 --- a/crates/milli/src/vector/ollama.rs +++ b/crates/milli/src/vector/ollama.rs @@ -121,21 +121,21 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - embedder_stats: Option>, + embedder_stats: Arc, ) -> Result>, EmbedError> { // This condition helps reduce the number of active rayon jobs // so that we avoid consuming all the LMDB rtxns and avoid stack overflows. if threads.active_operations() >= REQUEST_PARALLELISM { text_chunks .into_iter() - .map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed(&chunk, None, Some(embedder_stats.clone()))) .collect() } else { threads .install(move || { text_chunks .into_par_iter() - .map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed(&chunk, None, Some(embedder_stats.clone()))) .collect() }) .map_err(|error| EmbedError { @@ -149,14 +149,14 @@ impl Embedder { &self, texts: &[&str], threads: &ThreadPoolNoAbort, - embedder_stats: Option>, + embedder_stats: Arc, ) -> Result>, EmbedError> { // This condition helps reduce the number of active rayon jobs // so that we avoid consuming all the LMDB rtxns and avoid stack overflows. if threads.active_operations() >= REQUEST_PARALLELISM { let embeddings: Result>, _> = texts .chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed(chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed(chunk, None, Some(embedder_stats.clone()))) .collect(); let embeddings = embeddings?; @@ -166,7 +166,7 @@ impl Embedder { .install(move || { let embeddings: Result>, _> = texts .par_chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed(chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed(chunk, None, Some(embedder_stats.clone()))) .collect(); let embeddings = embeddings?; diff --git a/crates/milli/src/vector/openai.rs b/crates/milli/src/vector/openai.rs index ca072d6e5..b64e3d467 100644 --- a/crates/milli/src/vector/openai.rs +++ b/crates/milli/src/vector/openai.rs @@ -262,21 +262,21 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - embedder_stats: Option>, + embedder_stats: Arc, ) -> Result>, EmbedError> { // This condition helps reduce the number of active rayon jobs // so that we avoid consuming all the LMDB rtxns and avoid stack overflows. if threads.active_operations() >= REQUEST_PARALLELISM { text_chunks .into_iter() - .map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed(&chunk, None, Some(embedder_stats.clone()))) .collect() } else { threads .install(move || { text_chunks .into_par_iter() - .map(move |chunk| self.embed(&chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed(&chunk, None, Some(embedder_stats.clone()))) .collect() }) .map_err(|error| EmbedError { @@ -290,14 +290,14 @@ impl Embedder { &self, texts: &[&str], threads: &ThreadPoolNoAbort, - embedder_stats: Option>, + embedder_stats: Arc, ) -> Result>, EmbedError> { // This condition helps reduce the number of active rayon jobs // so that we avoid consuming all the LMDB rtxns and avoid stack overflows. if threads.active_operations() >= REQUEST_PARALLELISM { let embeddings: Result>, _> = texts .chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed(chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed(chunk, None, Some(embedder_stats.clone()))) .collect(); let embeddings = embeddings?; Ok(embeddings.into_iter().flatten().collect()) @@ -306,7 +306,7 @@ impl Embedder { .install(move || { let embeddings: Result>, _> = texts .par_chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed(chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed(chunk, None, Some(embedder_stats.clone()))) .collect(); let embeddings = embeddings?; diff --git a/crates/milli/src/vector/rest.rs b/crates/milli/src/vector/rest.rs index 294b0ceda..409284b65 100644 --- a/crates/milli/src/vector/rest.rs +++ b/crates/milli/src/vector/rest.rs @@ -208,21 +208,21 @@ impl Embedder { &self, text_chunks: Vec>, threads: &ThreadPoolNoAbort, - embedder_stats: Option>, + embedder_stats: Arc, ) -> Result>, EmbedError> { // This condition helps reduce the number of active rayon jobs // so that we avoid consuming all the LMDB rtxns and avoid stack overflows. if threads.active_operations() >= REQUEST_PARALLELISM { text_chunks .into_iter() - .map(move |chunk| self.embed(chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed(chunk, None, Some(embedder_stats.clone()))) .collect() } else { threads .install(move || { text_chunks .into_par_iter() - .map(move |chunk| self.embed(chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed(chunk, None, Some(embedder_stats.clone()))) .collect() }) .map_err(|error| EmbedError { @@ -236,14 +236,14 @@ impl Embedder { &self, texts: &[&str], threads: &ThreadPoolNoAbort, - embedder_stats: Option>, + embedder_stats: Arc, ) -> Result, EmbedError> { // This condition helps reduce the number of active rayon jobs // so that we avoid consuming all the LMDB rtxns and avoid stack overflows. if threads.active_operations() >= REQUEST_PARALLELISM { let embeddings: Result>, _> = texts .chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed_ref(chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed_ref(chunk, None, Some(embedder_stats.clone()))) .collect(); let embeddings = embeddings?; @@ -253,7 +253,7 @@ impl Embedder { .install(move || { let embeddings: Result>, _> = texts .par_chunks(self.prompt_count_in_chunk_hint()) - .map(move |chunk| self.embed_ref(chunk, None, embedder_stats.clone())) + .map(move |chunk| self.embed_ref(chunk, None, Some(embedder_stats.clone()))) .collect(); let embeddings = embeddings?;