From d4aeff92d054c7f183d24bf813d4b73988cb15b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 24 Apr 2024 16:40:12 +0200 Subject: [PATCH] Introduce the ThreadPoolNoAbort wrapper --- meilisearch/src/option.rs | 12 +--- milli/src/error.rs | 5 +- milli/src/lib.rs | 2 + milli/src/thread_pool_no_abort.rs | 69 +++++++++++++++++++ .../extract/extract_vector_points.rs | 4 +- .../src/update/index_documents/extract/mod.rs | 4 +- milli/src/update/index_documents/mod.rs | 25 ++----- milli/src/update/indexer_config.rs | 12 +--- milli/src/vector/error.rs | 3 + milli/src/vector/mod.rs | 3 +- milli/src/vector/ollama.rs | 15 ++-- milli/src/vector/openai.rs | 15 ++-- milli/src/vector/rest.rs | 16 +++-- .../src/processor/firefox_profiler.rs | 4 +- 14 files changed, 129 insertions(+), 60 deletions(-) create mode 100644 milli/src/thread_pool_no_abort.rs diff --git a/meilisearch/src/option.rs b/meilisearch/src/option.rs index d3d9150d6..fed824079 100644 --- a/meilisearch/src/option.rs +++ b/meilisearch/src/option.rs @@ -6,7 +6,6 @@ use std::num::ParseIntError; use std::ops::Deref; use std::path::PathBuf; use std::str::FromStr; -use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::{env, fmt, fs}; @@ -14,6 +13,7 @@ use byte_unit::{Byte, ByteError}; use clap::Parser; use meilisearch_types::features::InstanceTogglableFeatures; use meilisearch_types::milli::update::IndexerConfig; +use meilisearch_types::milli::ThreadPoolNoAbortBuilder; use rustls::server::{ AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ServerSessionMemoryCache, }; @@ -667,23 +667,15 @@ impl TryFrom<&IndexerOpts> for IndexerConfig { type Error = anyhow::Error; fn try_from(other: &IndexerOpts) -> Result { - let pool_panic_catched = Arc::new(AtomicBool::new(false)); - let thread_pool = rayon::ThreadPoolBuilder::new() + let thread_pool = ThreadPoolNoAbortBuilder::new() .thread_name(|index| format!("indexing-thread:{index}")) .num_threads(*other.max_indexing_threads) - .panic_handler({ - // TODO What should we do with this Box. - // So, let's just set a value to true to cancel the task with a message for now. - let panic_cathed = pool_panic_catched.clone(); - move |_result| panic_cathed.store(true, Ordering::SeqCst) - }) .build()?; Ok(Self { log_every_n: Some(DEFAULT_LOG_EVERY_N), max_memory: other.max_indexing_memory.map(|b| b.get_bytes() as usize), thread_pool: Some(thread_pool), - pool_panic_catched, max_positions_per_attributes: None, skip_index_budget: other.skip_index_budget, ..Default::default() diff --git a/milli/src/error.rs b/milli/src/error.rs index 02691a5cf..e4550de1f 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -9,6 +9,7 @@ use serde_json::Value; use thiserror::Error; use crate::documents::{self, DocumentsBatchCursorError}; +use crate::thread_pool_no_abort::PanicCatched; use crate::{CriterionError, DocumentId, FieldId, Object, SortError}; pub fn is_reserved_keyword(keyword: &str) -> bool { @@ -49,8 +50,8 @@ pub enum InternalError { InvalidDatabaseTyping, #[error(transparent)] RayonThreadPool(#[from] ThreadPoolBuildError), - #[error("A panic occured. 
Read the logs to find more information about it")] - PanicInThreadPool, + #[error(transparent)] + PanicInThreadPool(#[from] PanicCatched), #[error(transparent)] SerdeJson(#[from] serde_json::Error), #[error(transparent)] diff --git a/milli/src/lib.rs b/milli/src/lib.rs index cd44f2f2e..a1e240464 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -21,6 +21,7 @@ pub mod prompt; pub mod proximity; pub mod score_details; mod search; +mod thread_pool_no_abort; pub mod update; pub mod vector; @@ -42,6 +43,7 @@ pub use search::new::{ SearchLogger, VisualSearchLogger, }; use serde_json::Value; +pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder}; pub use {charabia as tokenizer, heed}; pub use self::asc_desc::{AscDesc, AscDescError, Member, SortError}; diff --git a/milli/src/thread_pool_no_abort.rs b/milli/src/thread_pool_no_abort.rs new file mode 100644 index 000000000..14e5b0491 --- /dev/null +++ b/milli/src/thread_pool_no_abort.rs @@ -0,0 +1,69 @@ +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; + +use rayon::{ThreadPool, ThreadPoolBuilder}; +use thiserror::Error; + +/// A rayon ThreadPool wrapper that can catch panics in the pool +/// and modifies the install function accordingly. +#[derive(Debug)] +pub struct ThreadPoolNoAbort { + thread_pool: ThreadPool, + /// Set to true if the thread pool catched a panic. + pool_catched_panic: Arc, +} + +impl ThreadPoolNoAbort { + pub fn install(&self, op: OP) -> Result + where + OP: FnOnce() -> R + Send, + R: Send, + { + let output = self.thread_pool.install(op); + // While reseting the pool panic catcher we return an error if we catched one. + if self.pool_catched_panic.swap(false, Ordering::SeqCst) { + Err(PanicCatched) + } else { + Ok(output) + } + } + + pub fn current_num_threads(&self) -> usize { + self.thread_pool.current_num_threads() + } +} + +#[derive(Error, Debug)] +#[error("A panic occured. 
Read the logs to find more information about it")] +pub struct PanicCatched; + +#[derive(Default)] +pub struct ThreadPoolNoAbortBuilder(ThreadPoolBuilder); + +impl ThreadPoolNoAbortBuilder { + pub fn new() -> ThreadPoolNoAbortBuilder { + ThreadPoolNoAbortBuilder::default() + } + + pub fn thread_name(mut self, closure: F) -> Self + where + F: FnMut(usize) -> String + 'static, + { + self.0 = self.0.thread_name(closure); + self + } + + pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolNoAbortBuilder { + self.0 = self.0.num_threads(num_threads); + self + } + + pub fn build(mut self) -> Result { + let pool_catched_panic = Arc::new(AtomicBool::new(false)); + self.0 = self.0.panic_handler({ + let catched_panic = pool_catched_panic.clone(); + move |_result| catched_panic.store(true, Ordering::SeqCst) + }); + Ok(ThreadPoolNoAbort { thread_pool: self.0.build()?, pool_catched_panic }) + } +} diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 23f945c7a..9d0e7e360 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -19,7 +19,7 @@ use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::index_documents::helpers::try_split_at; use crate::update::settings::InnerIndexSettingsDiff; use crate::vector::Embedder; -use crate::{DocumentId, InternalError, Result, VectorOrArrayOfVectors}; +use crate::{DocumentId, InternalError, Result, ThreadPoolNoAbort, VectorOrArrayOfVectors}; /// The length of the elements that are always in the buffer when inserting new values. const TRUNCATE_SIZE: usize = size_of::(); @@ -347,7 +347,7 @@ pub fn extract_embeddings( prompt_reader: grenad::Reader, indexer: GrenadParameters, embedder: Arc, - request_threads: &rayon::ThreadPool, + request_threads: &ThreadPoolNoAbort, ) -> Result>> { puffin::profile_function!(); let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index bc6fe2aff..573e0898a 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -31,7 +31,7 @@ use self::extract_word_position_docids::extract_word_position_docids; use super::helpers::{as_cloneable_grenad, CursorClonableMmap, GrenadParameters}; use super::{helpers, TypedChunk}; use crate::update::settings::InnerIndexSettingsDiff; -use crate::{FieldId, Result}; +use crate::{FieldId, Result, ThreadPoolNoAbortBuilder}; /// Extract data for each databases from obkv documents in parallel. /// Send data in grenad file over provided Sender. 
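[Usage sketch, not part of the diff] The new milli/src/thread_pool_no_abort.rs module above replaces the ad-hoc pool_panic_catched flag removed from IndexerConfig later in this patch: the builder installs a panic handler that only records the panic in an AtomicBool, and install() checks and resets that flag after the closure returns. The sketch below shows the intended calling pattern, assuming install returns Result<R, PanicCatched> as the module's Ok/Err arms indicate; the function name, thread name, and thread count are illustrative only.

use milli::{PanicCatched, ThreadPoolNoAbortBuilder};
use rayon::prelude::*;

fn sum_of_squares(values: &[u64]) -> Result<u64, PanicCatched> {
    // Built the same way option.rs builds the indexing pool in this patch.
    let pool = ThreadPoolNoAbortBuilder::new()
        .thread_name(|index| format!("example-thread:{index}"))
        .num_threads(2)
        .build()
        .expect("failed to build the thread pool");

    // Ok(..) when the closure completes; Err(PanicCatched) when the pool's
    // panic handler recorded a panic that rayon could not propagate to the
    // caller, instead of the whole process aborting.
    pool.install(|| values.par_iter().map(|v| v * v).sum())
}

fn main() {
    match sum_of_squares(&[1, 2, 3]) {
        Ok(total) => println!("sum of squares: {total}"),
        Err(PanicCatched) => eprintln!("a worker panicked, but the process is still alive"),
    }
}

Because install() clears the flag with swap(false, ..), one caught panic does not poison later calls on the same pool.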
@@ -229,7 +229,7 @@ fn send_original_documents_data( let documents_chunk_cloned = original_documents_chunk.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); - let request_threads = rayon::ThreadPoolBuilder::new() + let request_threads = ThreadPoolNoAbortBuilder::new() .num_threads(crate::vector::REQUEST_PARALLELISM) .thread_name(|index| format!("embedding-request-{index}")) .build()?; diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index c61c83757..bb180a7ee 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -8,7 +8,6 @@ use std::collections::{HashMap, HashSet}; use std::io::{Read, Seek}; use std::num::NonZeroU32; use std::result::Result as StdResult; -use std::sync::atomic::Ordering; use std::sync::Arc; use crossbeam_channel::{Receiver, Sender}; @@ -34,6 +33,7 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters}; pub use self::transform::{Transform, TransformOutput}; use crate::documents::{obkv_to_object, DocumentsBatchReader}; use crate::error::{Error, InternalError, UserError}; +use crate::thread_pool_no_abort::ThreadPoolNoAbortBuilder; pub use crate::update::index_documents::helpers::CursorClonableMmap; use crate::update::{ IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, @@ -297,17 +297,13 @@ where let settings_diff = Arc::new(settings_diff); let backup_pool; - let pool_catched_panic = self.indexer_config.pool_panic_catched.clone(); let pool = match self.indexer_config.thread_pool { Some(ref pool) => pool, None => { // We initialize a backup pool with the default // settings if none have already been set. - let mut pool_builder = rayon::ThreadPoolBuilder::new(); - pool_builder = pool_builder.panic_handler({ - let catched_panic = pool_catched_panic.clone(); - move |_result| catched_panic.store(true, Ordering::SeqCst) - }); + #[allow(unused_mut)] + let mut pool_builder = ThreadPoolNoAbortBuilder::new(); #[cfg(test)] { @@ -538,12 +534,7 @@ where } Ok(()) - })?; - - // While reseting the pool panic catcher we return an error if we catched one. - if pool_catched_panic.swap(false, Ordering::SeqCst) { - return Err(InternalError::PanicInThreadPool.into()); - } + }).map_err(InternalError::from)??; // We write the field distribution into the main database self.index.put_field_distribution(self.wtxn, &field_distribution)?; @@ -572,12 +563,8 @@ where writer.build(wtxn, &mut rng, None)?; } Result::Ok(()) - })?; - - // While reseting the pool panic catcher we return an error if we catched one. 
- if pool_catched_panic.swap(false, Ordering::SeqCst) { - return Err(InternalError::PanicInThreadPool.into()); - } + }) + .map_err(InternalError::from)??; } self.execute_prefix_databases( diff --git a/milli/src/update/indexer_config.rs b/milli/src/update/indexer_config.rs index b23d8e700..115059a1d 100644 --- a/milli/src/update/indexer_config.rs +++ b/milli/src/update/indexer_config.rs @@ -1,8 +1,6 @@ -use std::sync::atomic::AtomicBool; -use std::sync::Arc; - use grenad::CompressionType; -use rayon::ThreadPool; + +use crate::thread_pool_no_abort::ThreadPoolNoAbort; #[derive(Debug)] pub struct IndexerConfig { @@ -12,10 +10,7 @@ pub struct IndexerConfig { pub max_memory: Option, pub chunk_compression_type: CompressionType, pub chunk_compression_level: Option, - pub thread_pool: Option, - /// Set to true if the thread pool catched a panic - /// and we must abort the task - pub pool_panic_catched: Arc, + pub thread_pool: Option, pub max_positions_per_attributes: Option, pub skip_index_budget: bool, } @@ -30,7 +25,6 @@ impl Default for IndexerConfig { chunk_compression_type: CompressionType::None, chunk_compression_level: None, thread_pool: None, - pool_panic_catched: Arc::default(), max_positions_per_attributes: None, skip_index_budget: false, } diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index d3369ef3d..650e1171e 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -3,6 +3,7 @@ use std::path::PathBuf; use hf_hub::api::sync::ApiError; use crate::error::FaultSource; +use crate::PanicCatched; #[derive(Debug, thiserror::Error)] #[error("Error while generating embeddings: {inner}")] @@ -80,6 +81,8 @@ pub enum EmbedErrorKind { OpenAiUnexpectedDimension(usize, usize), #[error("no embedding was produced")] MissingEmbedding, + #[error(transparent)] + PanicInThreadPool(#[from] PanicCatched), } impl EmbedError { diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 58f7ba5e1..306c1c1e9 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize}; use self::error::{EmbedError, NewEmbedderError}; use crate::prompt::{Prompt, PromptData}; +use crate::ThreadPoolNoAbort; pub mod error; pub mod hf; @@ -254,7 +255,7 @@ impl Embedder { pub fn embed_chunks( &self, text_chunks: Vec>, - threads: &rayon::ThreadPool, + threads: &ThreadPoolNoAbort, ) -> std::result::Result>>, EmbedError> { match self { Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs index cf5030fb4..2c29cc816 100644 --- a/milli/src/vector/ollama.rs +++ b/milli/src/vector/ollama.rs @@ -3,6 +3,8 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; use super::{DistributionShift, Embeddings}; +use crate::error::FaultSource; +use crate::ThreadPoolNoAbort; #[derive(Debug)] pub struct Embedder { @@ -71,11 +73,16 @@ impl Embedder { pub fn embed_chunks( &self, text_chunks: Vec>, - threads: &rayon::ThreadPool, + threads: &ThreadPoolNoAbort, ) -> Result>>, EmbedError> { - threads.install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() - }) + threads + .install(move || { + text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), 
+ fault: FaultSource::Bug, + })? } pub fn chunk_count_hint(&self) -> usize { diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 141de486b..e180aedaa 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -4,7 +4,9 @@ use rayon::iter::{IntoParallelIterator, ParallelIterator as _}; use super::error::{EmbedError, NewEmbedderError}; use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions}; use super::{DistributionShift, Embeddings}; +use crate::error::FaultSource; use crate::vector::error::EmbedErrorKind; +use crate::ThreadPoolNoAbort; #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { @@ -241,11 +243,16 @@ impl Embedder { pub fn embed_chunks( &self, text_chunks: Vec>, - threads: &rayon::ThreadPool, + threads: &ThreadPoolNoAbort, ) -> Result>>, EmbedError> { - threads.install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() - }) + threads + .install(move || { + text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? } pub fn chunk_count_hint(&self) -> usize { diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs index b0ea07f82..60f54782e 100644 --- a/milli/src/vector/rest.rs +++ b/milli/src/vector/rest.rs @@ -2,9 +2,12 @@ use deserr::Deserr; use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _}; use serde::{Deserialize, Serialize}; +use super::error::EmbedErrorKind; use super::{ DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM, }; +use crate::error::FaultSource; +use crate::ThreadPoolNoAbort; // retrying in case of failure @@ -158,11 +161,16 @@ impl Embedder { pub fn embed_chunks( &self, text_chunks: Vec>, - threads: &rayon::ThreadPool, + threads: &ThreadPoolNoAbort, ) -> Result>>, EmbedError> { - threads.install(move || { - text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() - }) + threads + .install(move || { + text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect() + }) + .map_err(|error| EmbedError { + kind: EmbedErrorKind::PanicInThreadPool(error), + fault: FaultSource::Bug, + })? } pub fn chunk_count_hint(&self) -> usize { diff --git a/tracing-trace/src/processor/firefox_profiler.rs b/tracing-trace/src/processor/firefox_profiler.rs index bae8ea44a..da3380e5c 100644 --- a/tracing-trace/src/processor/firefox_profiler.rs +++ b/tracing-trace/src/processor/firefox_profiler.rs @@ -217,9 +217,7 @@ fn add_memory_samples( memory_counters: &mut Option, last_memory: &mut MemoryStats, ) -> Option { - let Some(stats) = memory else { - return None; - }; + let stats = memory?; let memory_counters = memory_counters.get_or_insert_with(|| MemoryCounterHandles::new(profile, main));
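[Call-site pattern, not part of the diff] Since install() now wraps the closure's own result, call sites whose closure is itself fallible get a nested Result: the outer layer is PanicCatched from the pool, the inner one is the closure's error. That is why index_documents/mod.rs ends those blocks with .map_err(InternalError::from)?? and why the ollama, openai, and rest embedders map the outer error into EmbedErrorKind::PanicInThreadPool before applying ?. A minimal sketch of the same shape, with MyError and run_in_pool as illustrative names only:

use milli::{PanicCatched, ThreadPoolNoAbort};

#[derive(Debug)]
enum MyError {
    Panic(PanicCatched),
}

impl From<PanicCatched> for MyError {
    fn from(err: PanicCatched) -> Self {
        MyError::Panic(err)
    }
}

fn run_in_pool(pool: &ThreadPoolNoAbort) -> Result<u64, MyError> {
    let value = pool
        .install(|| {
            // ... parallel work that can fail on its own ...
            Ok::<u64, MyError>(21 * 2)
        })
        // outer `?`: a panic caught by the pool; inner `?`: the closure's error
        .map_err(MyError::from)??;
    Ok(value)
}

Only the outer layer is new in this patch; errors returned by the closures themselves flow through unchanged.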