diff --git a/meilisearch/src/option.rs b/meilisearch/src/option.rs
index 651af7336..fed824079 100644
--- a/meilisearch/src/option.rs
+++ b/meilisearch/src/option.rs
@@ -13,6 +13,7 @@ use byte_unit::{Byte, ByteError};
 use clap::Parser;
 use meilisearch_types::features::InstanceTogglableFeatures;
 use meilisearch_types::milli::update::IndexerConfig;
+use meilisearch_types::milli::ThreadPoolNoAbortBuilder;
 use rustls::server::{
     AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ServerSessionMemoryCache,
 };
@@ -666,7 +667,7 @@ impl TryFrom<&IndexerOpts> for IndexerConfig {
     type Error = anyhow::Error;
 
     fn try_from(other: &IndexerOpts) -> Result<Self, Self::Error> {
-        let thread_pool = rayon::ThreadPoolBuilder::new()
+        let thread_pool = ThreadPoolNoAbortBuilder::new()
             .thread_name(|index| format!("indexing-thread:{index}"))
             .num_threads(*other.max_indexing_threads)
             .build()?;
diff --git a/milli/src/error.rs b/milli/src/error.rs
index 1d61bef63..e4550de1f 100644
--- a/milli/src/error.rs
+++ b/milli/src/error.rs
@@ -9,6 +9,7 @@ use serde_json::Value;
 use thiserror::Error;
 
 use crate::documents::{self, DocumentsBatchCursorError};
+use crate::thread_pool_no_abort::PanicCatched;
 use crate::{CriterionError, DocumentId, FieldId, Object, SortError};
 
 pub fn is_reserved_keyword(keyword: &str) -> bool {
@@ -39,17 +40,19 @@ pub enum InternalError {
     Fst(#[from] fst::Error),
     #[error(transparent)]
     DocumentsError(#[from] documents::Error),
-    #[error("Invalid compression type have been specified to grenad.")]
+    #[error("Invalid compression type have been specified to grenad")]
     GrenadInvalidCompressionType,
-    #[error("Invalid grenad file with an invalid version format.")]
+    #[error("Invalid grenad file with an invalid version format")]
     GrenadInvalidFormatVersion,
-    #[error("Invalid merge while processing {process}.")]
+    #[error("Invalid merge while processing {process}")]
     IndexingMergingKeys { process: &'static str },
     #[error("{}", HeedError::InvalidDatabaseTyping)]
     InvalidDatabaseTyping,
     #[error(transparent)]
     RayonThreadPool(#[from] ThreadPoolBuildError),
     #[error(transparent)]
+    PanicInThreadPool(#[from] PanicCatched),
+    #[error(transparent)]
     SerdeJson(#[from] serde_json::Error),
     #[error(transparent)]
     Serialization(#[from] SerializationError),
@@ -57,9 +60,9 @@ pub enum InternalError {
     Store(#[from] MdbError),
     #[error(transparent)]
     Utf8(#[from] str::Utf8Error),
-    #[error("An indexation process was explicitly aborted.")]
+    #[error("An indexation process was explicitly aborted")]
     AbortedIndexation,
-    #[error("The matching words list contains at least one invalid member.")]
+    #[error("The matching words list contains at least one invalid member")]
     InvalidMatchingWords,
     #[error(transparent)]
     ArroyError(#[from] arroy::Error),
diff --git a/milli/src/lib.rs b/milli/src/lib.rs
index cd44f2f2e..a1e240464 100644
--- a/milli/src/lib.rs
+++ b/milli/src/lib.rs
@@ -21,6 +21,7 @@ pub mod prompt;
 pub mod proximity;
 pub mod score_details;
 mod search;
+mod thread_pool_no_abort;
 pub mod update;
 pub mod vector;
 
@@ -42,6 +43,7 @@ pub use search::new::{
     SearchLogger, VisualSearchLogger,
 };
 use serde_json::Value;
+pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
 pub use {charabia as tokenizer, heed};
 
 pub use self::asc_desc::{AscDesc, AscDescError, Member, SortError};
diff --git a/milli/src/thread_pool_no_abort.rs b/milli/src/thread_pool_no_abort.rs
new file mode 100644
index 000000000..14e5b0491
--- /dev/null
+++ b/milli/src/thread_pool_no_abort.rs
@@ -0,0 +1,69 @@
+use std::sync::atomic::{AtomicBool, Ordering};
+use std::sync::Arc;
+
+use rayon::{ThreadPool, ThreadPoolBuilder};
+use thiserror::Error;
+
+/// A rayon ThreadPool wrapper that can catch panics in the pool
+/// and modifies the install function accordingly.
+#[derive(Debug)]
+pub struct ThreadPoolNoAbort {
+    thread_pool: ThreadPool,
+    /// Set to true if the thread pool caught a panic.
+    pool_catched_panic: Arc<AtomicBool>,
+}
+
+impl ThreadPoolNoAbort {
+    pub fn install<OP, R>(&self, op: OP) -> Result<R, PanicCatched>
+    where
+        OP: FnOnce() -> R + Send,
+        R: Send,
+    {
+        let output = self.thread_pool.install(op);
+        // While resetting the pool panic catcher, we return an error if we caught one.
+        if self.pool_catched_panic.swap(false, Ordering::SeqCst) {
+            Err(PanicCatched)
+        } else {
+            Ok(output)
+        }
+    }
+
+    pub fn current_num_threads(&self) -> usize {
+        self.thread_pool.current_num_threads()
+    }
+}
+
+#[derive(Error, Debug)]
+#[error("A panic occurred. Read the logs to find more information about it")]
+pub struct PanicCatched;
+
+#[derive(Default)]
+pub struct ThreadPoolNoAbortBuilder(ThreadPoolBuilder);
+
+impl ThreadPoolNoAbortBuilder {
+    pub fn new() -> ThreadPoolNoAbortBuilder {
+        ThreadPoolNoAbortBuilder::default()
+    }
+
+    pub fn thread_name<F>(mut self, closure: F) -> Self
+    where
+        F: FnMut(usize) -> String + 'static,
+    {
+        self.0 = self.0.thread_name(closure);
+        self
+    }
+
+    pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolNoAbortBuilder {
+        self.0 = self.0.num_threads(num_threads);
+        self
+    }
+
+    pub fn build(mut self) -> Result<ThreadPoolNoAbort, rayon::ThreadPoolBuildError> {
+        let pool_catched_panic = Arc::new(AtomicBool::new(false));
+        self.0 = self.0.panic_handler({
+            let catched_panic = pool_catched_panic.clone();
+            move |_result| catched_panic.store(true, Ordering::SeqCst)
+        });
+        Ok(ThreadPoolNoAbort { thread_pool: self.0.build()?, pool_catched_panic })
+    }
+}
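Reviewer note: a minimal usage sketch of the API added above. The builder, `install`, and the `PanicCatched` error come straight from this diff (re-exported from `milli` in the `lib.rs` hunk); the closure body and error plumbing are illustrative only.

```rust
use milli::{PanicCatched, ThreadPoolNoAbortBuilder};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let pool = ThreadPoolNoAbortBuilder::new()
        .thread_name(|index| format!("example-thread:{index}"))
        .num_threads(2)
        .build()?;

    // `install` now returns a Result: if a thread in the pool panicked while
    // running `op`, we get `Err(PanicCatched)` instead of a process abort.
    let sum: Result<u64, PanicCatched> = pool.install(|| (1..=100u64).sum());
    assert_eq!(sum?, 5050);
    Ok(())
}
```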
diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs
index 23f945c7a..9d0e7e360 100644
--- a/milli/src/update/index_documents/extract/extract_vector_points.rs
+++ b/milli/src/update/index_documents/extract/extract_vector_points.rs
@@ -19,7 +19,7 @@ use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
 use crate::update::index_documents::helpers::try_split_at;
 use crate::update::settings::InnerIndexSettingsDiff;
 use crate::vector::Embedder;
-use crate::{DocumentId, InternalError, Result, VectorOrArrayOfVectors};
+use crate::{DocumentId, InternalError, Result, ThreadPoolNoAbort, VectorOrArrayOfVectors};
 
 /// The length of the elements that are always in the buffer when inserting new values.
 const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
@@ -347,7 +347,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
     prompt_reader: grenad::Reader<R>,
     indexer: GrenadParameters,
     embedder: Arc<Embedder>,
-    request_threads: &rayon::ThreadPool,
+    request_threads: &ThreadPoolNoAbort,
 ) -> Result<grenad::Reader<BufReader<File>>> {
     puffin::profile_function!();
     let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs
index bc6fe2aff..573e0898a 100644
--- a/milli/src/update/index_documents/extract/mod.rs
+++ b/milli/src/update/index_documents/extract/mod.rs
@@ -31,7 +31,7 @@ use self::extract_word_position_docids::extract_word_position_docids;
 use super::helpers::{as_cloneable_grenad, CursorClonableMmap, GrenadParameters};
 use super::{helpers, TypedChunk};
 use crate::update::settings::InnerIndexSettingsDiff;
-use crate::{FieldId, Result};
+use crate::{FieldId, Result, ThreadPoolNoAbortBuilder};
 
 /// Extract data for each databases from obkv documents in parallel.
 /// Send data in grenad file over provided Sender.
@@ -229,7 +229,7 @@ fn send_original_documents_data(
     let documents_chunk_cloned = original_documents_chunk.clone();
     let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
 
-    let request_threads = rayon::ThreadPoolBuilder::new()
+    let request_threads = ThreadPoolNoAbortBuilder::new()
         .num_threads(crate::vector::REQUEST_PARALLELISM)
         .thread_name(|index| format!("embedding-request-{index}"))
         .build()?;
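For context on why the wrapper is needed at these extraction call sites: rayon's documented default is to abort the entire process when a panic escapes a worker thread. The sketch below shows the raw `rayon::ThreadPoolBuilder::panic_handler` hook that `ThreadPoolNoAbortBuilder::build` (above) uses to set its `AtomicBool`; the function name and logging here are illustrative, not part of this diff.

```rust
use std::any::Any;

use rayon::ThreadPoolBuilder;

// Hypothetical standalone example: without a panic handler, a panic escaping
// a rayon worker aborts the process; with one, we get a callback instead.
fn build_logging_pool() -> Result<rayon::ThreadPool, rayon::ThreadPoolBuildError> {
    ThreadPoolBuilder::new()
        .panic_handler(|panic: Box<dyn Any + Send>| {
            // The payload is type-erased; string payloads are the common case.
            let message =
                panic.downcast_ref::<&str>().copied().unwrap_or("<non-string panic payload>");
            eprintln!("worker thread panicked: {message}");
        })
        .build()
}

fn main() {
    let pool = build_logging_pool().expect("failed to build pool");
    let value = pool.install(|| 2 + 2);
    println!("computed {value} on the pool");
}
```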
diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs
index aa9789a1a..bb180a7ee 100644
--- a/milli/src/update/index_documents/mod.rs
+++ b/milli/src/update/index_documents/mod.rs
@@ -33,6 +33,7 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters};
 pub use self::transform::{Transform, TransformOutput};
 use crate::documents::{obkv_to_object, DocumentsBatchReader};
 use crate::error::{Error, InternalError, UserError};
+use crate::thread_pool_no_abort::ThreadPoolNoAbortBuilder;
 pub use crate::update::index_documents::helpers::CursorClonableMmap;
 use crate::update::{
     IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
@@ -298,18 +299,18 @@ where
         let backup_pool;
         let pool = match self.indexer_config.thread_pool {
             Some(ref pool) => pool,
-            #[cfg(not(test))]
             None => {
-                // We initialize a bakcup pool with the default
+                // We initialize a backup pool with the default
                 // settings if none have already been set.
-                backup_pool = rayon::ThreadPoolBuilder::new().build()?;
-                &backup_pool
-            }
-            #[cfg(test)]
-            None => {
-                // We initialize a bakcup pool with the default
-                // settings if none have already been set.
-                backup_pool = rayon::ThreadPoolBuilder::new().num_threads(1).build()?;
+                #[allow(unused_mut)]
+                let mut pool_builder = ThreadPoolNoAbortBuilder::new();
+
+                #[cfg(test)]
+                {
+                    pool_builder = pool_builder.num_threads(1);
+                }
+
+                backup_pool = pool_builder.build()?;
                 &backup_pool
             }
         };
@@ -533,7 +534,7 @@ where
             }
 
             Ok(())
-        })?;
+        }).map_err(InternalError::from)??;
 
         // We write the field distribution into the main database
        self.index.put_field_distribution(self.wtxn, &field_distribution)?;
@@ -562,7 +563,8 @@ where
                 writer.build(wtxn, &mut rng, None)?;
             }
             Result::Ok(())
-            })?;
+            })
+            .map_err(InternalError::from)??;
         }
 
         self.execute_prefix_databases(
diff --git a/milli/src/update/indexer_config.rs b/milli/src/update/indexer_config.rs
index ff7942fdb..115059a1d 100644
--- a/milli/src/update/indexer_config.rs
+++ b/milli/src/update/indexer_config.rs
@@ -1,5 +1,6 @@
 use grenad::CompressionType;
-use rayon::ThreadPool;
+
+use crate::thread_pool_no_abort::ThreadPoolNoAbort;
 
 #[derive(Debug)]
 pub struct IndexerConfig {
@@ -9,7 +10,7 @@ pub struct IndexerConfig {
     pub max_memory: Option<usize>,
     pub chunk_compression_type: CompressionType,
     pub chunk_compression_level: Option<u32>,
-    pub thread_pool: Option<ThreadPool>,
+    pub thread_pool: Option<ThreadPoolNoAbort>,
     pub max_positions_per_attributes: Option<u32>,
     pub skip_index_budget: bool,
 }
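The `.map_err(InternalError::from)??` call sites above are the subtle part of this change: `install` now returns `Result<R, PanicCatched>`, and here `R` is itself a `Result`, so the panic layer is first converted into the common error type and then both layers are unwrapped. A self-contained sketch of that shape with stand-in types (`PoolPanic`, `TaskError`, and this `install` are all hypothetical):

```rust
#[derive(Debug)]
struct PoolPanic;

#[derive(Debug)]
enum TaskError {
    Panic(PoolPanic),
    #[allow(dead_code)]
    Failed,
}

impl From<PoolPanic> for TaskError {
    fn from(panic: PoolPanic) -> Self {
        TaskError::Panic(panic)
    }
}

// An `install`-like API: the outer Result reports a pool panic, the inner
// Result is whatever the closure itself returned.
fn install<T>(op: impl FnOnce() -> T) -> Result<T, PoolPanic> {
    Ok(op())
}

fn run() -> Result<(), TaskError> {
    // The first `?` unwraps the pool layer (PoolPanic converts via `From`),
    // the second `?` unwraps the closure's own Result.
    install(|| -> Result<(), TaskError> { Ok(()) }).map_err(TaskError::from)??;
    Ok(())
}

fn main() {
    run().unwrap();
}
```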
diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs
index d3369ef3d..650e1171e 100644
--- a/milli/src/vector/error.rs
+++ b/milli/src/vector/error.rs
@@ -3,6 +3,7 @@ use std::path::PathBuf;
 use hf_hub::api::sync::ApiError;
 
 use crate::error::FaultSource;
+use crate::PanicCatched;
 
 #[derive(Debug, thiserror::Error)]
 #[error("Error while generating embeddings: {inner}")]
@@ -80,6 +81,8 @@ pub enum EmbedErrorKind {
     OpenAiUnexpectedDimension(usize, usize),
     #[error("no embedding was produced")]
     MissingEmbedding,
+    #[error(transparent)]
+    PanicInThreadPool(#[from] PanicCatched),
 }
 
 impl EmbedError {
diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
index 58f7ba5e1..306c1c1e9 100644
--- a/milli/src/vector/mod.rs
+++ b/milli/src/vector/mod.rs
@@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
 
 use self::error::{EmbedError, NewEmbedderError};
 use crate::prompt::{Prompt, PromptData};
+use crate::ThreadPoolNoAbort;
 
 pub mod error;
 pub mod hf;
@@ -254,7 +255,7 @@ impl Embedder {
     pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
-        threads: &rayon::ThreadPool,
+        threads: &ThreadPoolNoAbort,
     ) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
         match self {
             Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
diff --git a/milli/src/vector/ollama.rs b/milli/src/vector/ollama.rs
index cf5030fb4..2c29cc816 100644
--- a/milli/src/vector/ollama.rs
+++ b/milli/src/vector/ollama.rs
@@ -3,6 +3,8 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
 use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
 use super::{DistributionShift, Embeddings};
+use crate::error::FaultSource;
+use crate::ThreadPoolNoAbort;
 
 #[derive(Debug)]
 pub struct Embedder {
@@ -71,11 +73,16 @@ impl Embedder {
     pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
-        threads: &rayon::ThreadPool,
+        threads: &ThreadPoolNoAbort,
     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        threads.install(move || {
-            text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
-        })
+        threads
+            .install(move || {
+                text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
+            })
+            .map_err(|error| EmbedError {
+                kind: EmbedErrorKind::PanicInThreadPool(error),
+                fault: FaultSource::Bug,
+            })?
     }
 
     pub fn chunk_count_hint(&self) -> usize {
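Note that `embed_chunks` above ends with a single trailing `?`, unlike the double `??` at the indexing call sites: once `map_err` has mapped the panic into an `EmbedError`, both layers share one error type, so one `?` peels off the outer layer and the inner `Result` is returned as-is. A self-contained illustration of that flattening (generic stand-in, not milli code):

```rust
// One `?` suffices when both layers share an error type: unwrap the outer
// Result and return the inner one unchanged.
fn flatten<T, E>(outer: Result<Result<T, E>, E>) -> Result<T, E> {
    // i.e. match outer { Ok(inner) => inner, Err(e) => Err(e) }
    outer?
}

fn main() {
    assert_eq!(flatten::<u32, String>(Ok(Ok(7))), Ok(7));
    assert_eq!(flatten::<u32, String>(Err("boom".into())), Err("boom".to_string()));
}
```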
diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs
index 141de486b..e180aedaa 100644
--- a/milli/src/vector/openai.rs
+++ b/milli/src/vector/openai.rs
@@ -4,7 +4,9 @@ use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
 use super::error::{EmbedError, NewEmbedderError};
 use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
 use super::{DistributionShift, Embeddings};
+use crate::error::FaultSource;
 use crate::vector::error::EmbedErrorKind;
+use crate::ThreadPoolNoAbort;
 
 #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
 pub struct EmbedderOptions {
@@ -241,11 +243,16 @@ impl Embedder {
     pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
-        threads: &rayon::ThreadPool,
+        threads: &ThreadPoolNoAbort,
     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        threads.install(move || {
-            text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
-        })
+        threads
+            .install(move || {
+                text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
+            })
+            .map_err(|error| EmbedError {
+                kind: EmbedErrorKind::PanicInThreadPool(error),
+                fault: FaultSource::Bug,
+            })?
     }
 
     pub fn chunk_count_hint(&self) -> usize {
diff --git a/milli/src/vector/rest.rs b/milli/src/vector/rest.rs
index b0ea07f82..60f54782e 100644
--- a/milli/src/vector/rest.rs
+++ b/milli/src/vector/rest.rs
@@ -2,9 +2,12 @@ use deserr::Deserr;
 use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
 use serde::{Deserialize, Serialize};
 
+use super::error::EmbedErrorKind;
 use super::{
     DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
 };
+use crate::error::FaultSource;
+use crate::ThreadPoolNoAbort;
 
 // retrying in case of failure
 
@@ -158,11 +161,16 @@ impl Embedder {
     pub fn embed_chunks(
         &self,
         text_chunks: Vec<Vec<String>>,
-        threads: &rayon::ThreadPool,
+        threads: &ThreadPoolNoAbort,
     ) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
-        threads.install(move || {
-            text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
-        })
+        threads
+            .install(move || {
+                text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
+            })
+            .map_err(|error| EmbedError {
+                kind: EmbedErrorKind::PanicInThreadPool(error),
+                fault: FaultSource::Bug,
+            })?
     }
 
     pub fn chunk_count_hint(&self) -> usize {
diff --git a/tracing-trace/src/processor/firefox_profiler.rs b/tracing-trace/src/processor/firefox_profiler.rs
index bae8ea44a..da3380e5c 100644
--- a/tracing-trace/src/processor/firefox_profiler.rs
+++ b/tracing-trace/src/processor/firefox_profiler.rs
@@ -217,9 +217,7 @@ fn add_memory_samples(
     memory_counters: &mut Option<MemoryCounterHandles>,
     last_memory: &mut MemoryStats,
 ) -> Option<MemoryStats> {
-    let Some(stats) = memory else {
-        return None;
-    };
+    let stats = memory?;
 
     let memory_counters =
         memory_counters.get_or_insert_with(|| MemoryCounterHandles::new(profile, main));
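The final `firefox_profiler.rs` hunk is an unrelated tidy-up: `let Some(x) = opt else { return None; }` is exactly what the `?` operator does on an `Option`. A tiny stand-in example (names and types are hypothetical, not the tracing-trace signature):

```rust
fn double_sample(memory: Option<u64>) -> Option<u64> {
    // Before: let Some(stats) = memory else { return None; };
    let stats = memory?;
    Some(stats * 2)
}

fn main() {
    assert_eq!(double_sample(Some(21)), Some(42));
    assert_eq!(double_sample(None), None);
}
```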