mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-11-22 21:04:27 +01:00
Introduce the ThreadPoolNoAbort wrapper
This commit is contained in:
parent
b3173d0423
commit
d4aeff92d0
@ -6,7 +6,6 @@ use std::num::ParseIntError;
|
|||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
use std::str::FromStr;
|
use std::str::FromStr;
|
||||||
use std::sync::atomic::{AtomicBool, Ordering};
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::{env, fmt, fs};
|
use std::{env, fmt, fs};
|
||||||
|
|
||||||
@ -14,6 +13,7 @@ use byte_unit::{Byte, ByteError};
|
|||||||
use clap::Parser;
|
use clap::Parser;
|
||||||
use meilisearch_types::features::InstanceTogglableFeatures;
|
use meilisearch_types::features::InstanceTogglableFeatures;
|
||||||
use meilisearch_types::milli::update::IndexerConfig;
|
use meilisearch_types::milli::update::IndexerConfig;
|
||||||
|
use meilisearch_types::milli::ThreadPoolNoAbortBuilder;
|
||||||
use rustls::server::{
|
use rustls::server::{
|
||||||
AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ServerSessionMemoryCache,
|
AllowAnyAnonymousOrAuthenticatedClient, AllowAnyAuthenticatedClient, ServerSessionMemoryCache,
|
||||||
};
|
};
|
||||||
@ -667,23 +667,15 @@ impl TryFrom<&IndexerOpts> for IndexerConfig {
|
|||||||
type Error = anyhow::Error;
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
fn try_from(other: &IndexerOpts) -> Result<Self, Self::Error> {
|
fn try_from(other: &IndexerOpts) -> Result<Self, Self::Error> {
|
||||||
let pool_panic_catched = Arc::new(AtomicBool::new(false));
|
let thread_pool = ThreadPoolNoAbortBuilder::new()
|
||||||
let thread_pool = rayon::ThreadPoolBuilder::new()
|
|
||||||
.thread_name(|index| format!("indexing-thread:{index}"))
|
.thread_name(|index| format!("indexing-thread:{index}"))
|
||||||
.num_threads(*other.max_indexing_threads)
|
.num_threads(*other.max_indexing_threads)
|
||||||
.panic_handler({
|
|
||||||
// TODO What should we do with this Box<dyn Any + Send>.
|
|
||||||
// So, let's just set a value to true to cancel the task with a message for now.
|
|
||||||
let panic_cathed = pool_panic_catched.clone();
|
|
||||||
move |_result| panic_cathed.store(true, Ordering::SeqCst)
|
|
||||||
})
|
|
||||||
.build()?;
|
.build()?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
log_every_n: Some(DEFAULT_LOG_EVERY_N),
|
log_every_n: Some(DEFAULT_LOG_EVERY_N),
|
||||||
max_memory: other.max_indexing_memory.map(|b| b.get_bytes() as usize),
|
max_memory: other.max_indexing_memory.map(|b| b.get_bytes() as usize),
|
||||||
thread_pool: Some(thread_pool),
|
thread_pool: Some(thread_pool),
|
||||||
pool_panic_catched,
|
|
||||||
max_positions_per_attributes: None,
|
max_positions_per_attributes: None,
|
||||||
skip_index_budget: other.skip_index_budget,
|
skip_index_budget: other.skip_index_budget,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
|
@ -9,6 +9,7 @@ use serde_json::Value;
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
use crate::documents::{self, DocumentsBatchCursorError};
|
use crate::documents::{self, DocumentsBatchCursorError};
|
||||||
|
use crate::thread_pool_no_abort::PanicCatched;
|
||||||
use crate::{CriterionError, DocumentId, FieldId, Object, SortError};
|
use crate::{CriterionError, DocumentId, FieldId, Object, SortError};
|
||||||
|
|
||||||
pub fn is_reserved_keyword(keyword: &str) -> bool {
|
pub fn is_reserved_keyword(keyword: &str) -> bool {
|
||||||
@ -49,8 +50,8 @@ pub enum InternalError {
|
|||||||
InvalidDatabaseTyping,
|
InvalidDatabaseTyping,
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
RayonThreadPool(#[from] ThreadPoolBuildError),
|
RayonThreadPool(#[from] ThreadPoolBuildError),
|
||||||
#[error("A panic occured. Read the logs to find more information about it")]
|
#[error(transparent)]
|
||||||
PanicInThreadPool,
|
PanicInThreadPool(#[from] PanicCatched),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
SerdeJson(#[from] serde_json::Error),
|
SerdeJson(#[from] serde_json::Error),
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
|
@ -21,6 +21,7 @@ pub mod prompt;
|
|||||||
pub mod proximity;
|
pub mod proximity;
|
||||||
pub mod score_details;
|
pub mod score_details;
|
||||||
mod search;
|
mod search;
|
||||||
|
mod thread_pool_no_abort;
|
||||||
pub mod update;
|
pub mod update;
|
||||||
pub mod vector;
|
pub mod vector;
|
||||||
|
|
||||||
@ -42,6 +43,7 @@ pub use search::new::{
|
|||||||
SearchLogger, VisualSearchLogger,
|
SearchLogger, VisualSearchLogger,
|
||||||
};
|
};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
|
pub use thread_pool_no_abort::{PanicCatched, ThreadPoolNoAbort, ThreadPoolNoAbortBuilder};
|
||||||
pub use {charabia as tokenizer, heed};
|
pub use {charabia as tokenizer, heed};
|
||||||
|
|
||||||
pub use self::asc_desc::{AscDesc, AscDescError, Member, SortError};
|
pub use self::asc_desc::{AscDesc, AscDescError, Member, SortError};
|
||||||
|
69
milli/src/thread_pool_no_abort.rs
Normal file
69
milli/src/thread_pool_no_abort.rs
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
use std::sync::atomic::{AtomicBool, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use rayon::{ThreadPool, ThreadPoolBuilder};
|
||||||
|
use thiserror::Error;
|
||||||
|
|
||||||
|
/// A rayon ThreadPool wrapper that can catch panics in the pool
|
||||||
|
/// and modifies the install function accordingly.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ThreadPoolNoAbort {
|
||||||
|
thread_pool: ThreadPool,
|
||||||
|
/// Set to true if the thread pool catched a panic.
|
||||||
|
pool_catched_panic: Arc<AtomicBool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThreadPoolNoAbort {
|
||||||
|
pub fn install<OP, R>(&self, op: OP) -> Result<R, PanicCatched>
|
||||||
|
where
|
||||||
|
OP: FnOnce() -> R + Send,
|
||||||
|
R: Send,
|
||||||
|
{
|
||||||
|
let output = self.thread_pool.install(op);
|
||||||
|
// While reseting the pool panic catcher we return an error if we catched one.
|
||||||
|
if self.pool_catched_panic.swap(false, Ordering::SeqCst) {
|
||||||
|
Err(PanicCatched)
|
||||||
|
} else {
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn current_num_threads(&self) -> usize {
|
||||||
|
self.thread_pool.current_num_threads()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Error, Debug)]
|
||||||
|
#[error("A panic occured. Read the logs to find more information about it")]
|
||||||
|
pub struct PanicCatched;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ThreadPoolNoAbortBuilder(ThreadPoolBuilder);
|
||||||
|
|
||||||
|
impl ThreadPoolNoAbortBuilder {
|
||||||
|
pub fn new() -> ThreadPoolNoAbortBuilder {
|
||||||
|
ThreadPoolNoAbortBuilder::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn thread_name<F>(mut self, closure: F) -> Self
|
||||||
|
where
|
||||||
|
F: FnMut(usize) -> String + 'static,
|
||||||
|
{
|
||||||
|
self.0 = self.0.thread_name(closure);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn num_threads(mut self, num_threads: usize) -> ThreadPoolNoAbortBuilder {
|
||||||
|
self.0 = self.0.num_threads(num_threads);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build(mut self) -> Result<ThreadPoolNoAbort, rayon::ThreadPoolBuildError> {
|
||||||
|
let pool_catched_panic = Arc::new(AtomicBool::new(false));
|
||||||
|
self.0 = self.0.panic_handler({
|
||||||
|
let catched_panic = pool_catched_panic.clone();
|
||||||
|
move |_result| catched_panic.store(true, Ordering::SeqCst)
|
||||||
|
});
|
||||||
|
Ok(ThreadPoolNoAbort { thread_pool: self.0.build()?, pool_catched_panic })
|
||||||
|
}
|
||||||
|
}
|
@ -19,7 +19,7 @@ use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
|
|||||||
use crate::update::index_documents::helpers::try_split_at;
|
use crate::update::index_documents::helpers::try_split_at;
|
||||||
use crate::update::settings::InnerIndexSettingsDiff;
|
use crate::update::settings::InnerIndexSettingsDiff;
|
||||||
use crate::vector::Embedder;
|
use crate::vector::Embedder;
|
||||||
use crate::{DocumentId, InternalError, Result, VectorOrArrayOfVectors};
|
use crate::{DocumentId, InternalError, Result, ThreadPoolNoAbort, VectorOrArrayOfVectors};
|
||||||
|
|
||||||
/// The length of the elements that are always in the buffer when inserting new values.
|
/// The length of the elements that are always in the buffer when inserting new values.
|
||||||
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
|
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
|
||||||
@ -347,7 +347,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
|||||||
prompt_reader: grenad::Reader<R>,
|
prompt_reader: grenad::Reader<R>,
|
||||||
indexer: GrenadParameters,
|
indexer: GrenadParameters,
|
||||||
embedder: Arc<Embedder>,
|
embedder: Arc<Embedder>,
|
||||||
request_threads: &rayon::ThreadPool,
|
request_threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<grenad::Reader<BufReader<File>>> {
|
) -> Result<grenad::Reader<BufReader<File>>> {
|
||||||
puffin::profile_function!();
|
puffin::profile_function!();
|
||||||
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
|
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
|
||||||
|
@ -31,7 +31,7 @@ use self::extract_word_position_docids::extract_word_position_docids;
|
|||||||
use super::helpers::{as_cloneable_grenad, CursorClonableMmap, GrenadParameters};
|
use super::helpers::{as_cloneable_grenad, CursorClonableMmap, GrenadParameters};
|
||||||
use super::{helpers, TypedChunk};
|
use super::{helpers, TypedChunk};
|
||||||
use crate::update::settings::InnerIndexSettingsDiff;
|
use crate::update::settings::InnerIndexSettingsDiff;
|
||||||
use crate::{FieldId, Result};
|
use crate::{FieldId, Result, ThreadPoolNoAbortBuilder};
|
||||||
|
|
||||||
/// Extract data for each databases from obkv documents in parallel.
|
/// Extract data for each databases from obkv documents in parallel.
|
||||||
/// Send data in grenad file over provided Sender.
|
/// Send data in grenad file over provided Sender.
|
||||||
@ -229,7 +229,7 @@ fn send_original_documents_data(
|
|||||||
let documents_chunk_cloned = original_documents_chunk.clone();
|
let documents_chunk_cloned = original_documents_chunk.clone();
|
||||||
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
||||||
|
|
||||||
let request_threads = rayon::ThreadPoolBuilder::new()
|
let request_threads = ThreadPoolNoAbortBuilder::new()
|
||||||
.num_threads(crate::vector::REQUEST_PARALLELISM)
|
.num_threads(crate::vector::REQUEST_PARALLELISM)
|
||||||
.thread_name(|index| format!("embedding-request-{index}"))
|
.thread_name(|index| format!("embedding-request-{index}"))
|
||||||
.build()?;
|
.build()?;
|
||||||
|
@ -8,7 +8,6 @@ use std::collections::{HashMap, HashSet};
|
|||||||
use std::io::{Read, Seek};
|
use std::io::{Read, Seek};
|
||||||
use std::num::NonZeroU32;
|
use std::num::NonZeroU32;
|
||||||
use std::result::Result as StdResult;
|
use std::result::Result as StdResult;
|
||||||
use std::sync::atomic::Ordering;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crossbeam_channel::{Receiver, Sender};
|
use crossbeam_channel::{Receiver, Sender};
|
||||||
@ -34,6 +33,7 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters};
|
|||||||
pub use self::transform::{Transform, TransformOutput};
|
pub use self::transform::{Transform, TransformOutput};
|
||||||
use crate::documents::{obkv_to_object, DocumentsBatchReader};
|
use crate::documents::{obkv_to_object, DocumentsBatchReader};
|
||||||
use crate::error::{Error, InternalError, UserError};
|
use crate::error::{Error, InternalError, UserError};
|
||||||
|
use crate::thread_pool_no_abort::ThreadPoolNoAbortBuilder;
|
||||||
pub use crate::update::index_documents::helpers::CursorClonableMmap;
|
pub use crate::update::index_documents::helpers::CursorClonableMmap;
|
||||||
use crate::update::{
|
use crate::update::{
|
||||||
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
|
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
|
||||||
@ -297,17 +297,13 @@ where
|
|||||||
let settings_diff = Arc::new(settings_diff);
|
let settings_diff = Arc::new(settings_diff);
|
||||||
|
|
||||||
let backup_pool;
|
let backup_pool;
|
||||||
let pool_catched_panic = self.indexer_config.pool_panic_catched.clone();
|
|
||||||
let pool = match self.indexer_config.thread_pool {
|
let pool = match self.indexer_config.thread_pool {
|
||||||
Some(ref pool) => pool,
|
Some(ref pool) => pool,
|
||||||
None => {
|
None => {
|
||||||
// We initialize a backup pool with the default
|
// We initialize a backup pool with the default
|
||||||
// settings if none have already been set.
|
// settings if none have already been set.
|
||||||
let mut pool_builder = rayon::ThreadPoolBuilder::new();
|
#[allow(unused_mut)]
|
||||||
pool_builder = pool_builder.panic_handler({
|
let mut pool_builder = ThreadPoolNoAbortBuilder::new();
|
||||||
let catched_panic = pool_catched_panic.clone();
|
|
||||||
move |_result| catched_panic.store(true, Ordering::SeqCst)
|
|
||||||
});
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
{
|
{
|
||||||
@ -538,12 +534,7 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
})?;
|
}).map_err(InternalError::from)??;
|
||||||
|
|
||||||
// While reseting the pool panic catcher we return an error if we catched one.
|
|
||||||
if pool_catched_panic.swap(false, Ordering::SeqCst) {
|
|
||||||
return Err(InternalError::PanicInThreadPool.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
// We write the field distribution into the main database
|
// We write the field distribution into the main database
|
||||||
self.index.put_field_distribution(self.wtxn, &field_distribution)?;
|
self.index.put_field_distribution(self.wtxn, &field_distribution)?;
|
||||||
@ -572,12 +563,8 @@ where
|
|||||||
writer.build(wtxn, &mut rng, None)?;
|
writer.build(wtxn, &mut rng, None)?;
|
||||||
}
|
}
|
||||||
Result::Ok(())
|
Result::Ok(())
|
||||||
})?;
|
})
|
||||||
|
.map_err(InternalError::from)??;
|
||||||
// While reseting the pool panic catcher we return an error if we catched one.
|
|
||||||
if pool_catched_panic.swap(false, Ordering::SeqCst) {
|
|
||||||
return Err(InternalError::PanicInThreadPool.into());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
self.execute_prefix_databases(
|
self.execute_prefix_databases(
|
||||||
|
@ -1,8 +1,6 @@
|
|||||||
use std::sync::atomic::AtomicBool;
|
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use grenad::CompressionType;
|
use grenad::CompressionType;
|
||||||
use rayon::ThreadPool;
|
|
||||||
|
use crate::thread_pool_no_abort::ThreadPoolNoAbort;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct IndexerConfig {
|
pub struct IndexerConfig {
|
||||||
@ -12,10 +10,7 @@ pub struct IndexerConfig {
|
|||||||
pub max_memory: Option<usize>,
|
pub max_memory: Option<usize>,
|
||||||
pub chunk_compression_type: CompressionType,
|
pub chunk_compression_type: CompressionType,
|
||||||
pub chunk_compression_level: Option<u32>,
|
pub chunk_compression_level: Option<u32>,
|
||||||
pub thread_pool: Option<ThreadPool>,
|
pub thread_pool: Option<ThreadPoolNoAbort>,
|
||||||
/// Set to true if the thread pool catched a panic
|
|
||||||
/// and we must abort the task
|
|
||||||
pub pool_panic_catched: Arc<AtomicBool>,
|
|
||||||
pub max_positions_per_attributes: Option<u32>,
|
pub max_positions_per_attributes: Option<u32>,
|
||||||
pub skip_index_budget: bool,
|
pub skip_index_budget: bool,
|
||||||
}
|
}
|
||||||
@ -30,7 +25,6 @@ impl Default for IndexerConfig {
|
|||||||
chunk_compression_type: CompressionType::None,
|
chunk_compression_type: CompressionType::None,
|
||||||
chunk_compression_level: None,
|
chunk_compression_level: None,
|
||||||
thread_pool: None,
|
thread_pool: None,
|
||||||
pool_panic_catched: Arc::default(),
|
|
||||||
max_positions_per_attributes: None,
|
max_positions_per_attributes: None,
|
||||||
skip_index_budget: false,
|
skip_index_budget: false,
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,7 @@ use std::path::PathBuf;
|
|||||||
use hf_hub::api::sync::ApiError;
|
use hf_hub::api::sync::ApiError;
|
||||||
|
|
||||||
use crate::error::FaultSource;
|
use crate::error::FaultSource;
|
||||||
|
use crate::PanicCatched;
|
||||||
|
|
||||||
#[derive(Debug, thiserror::Error)]
|
#[derive(Debug, thiserror::Error)]
|
||||||
#[error("Error while generating embeddings: {inner}")]
|
#[error("Error while generating embeddings: {inner}")]
|
||||||
@ -80,6 +81,8 @@ pub enum EmbedErrorKind {
|
|||||||
OpenAiUnexpectedDimension(usize, usize),
|
OpenAiUnexpectedDimension(usize, usize),
|
||||||
#[error("no embedding was produced")]
|
#[error("no embedding was produced")]
|
||||||
MissingEmbedding,
|
MissingEmbedding,
|
||||||
|
#[error(transparent)]
|
||||||
|
PanicInThreadPool(#[from] PanicCatched),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EmbedError {
|
impl EmbedError {
|
||||||
|
@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
|
|
||||||
use self::error::{EmbedError, NewEmbedderError};
|
use self::error::{EmbedError, NewEmbedderError};
|
||||||
use crate::prompt::{Prompt, PromptData};
|
use crate::prompt::{Prompt, PromptData};
|
||||||
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod hf;
|
pub mod hf;
|
||||||
@ -254,7 +255,7 @@ impl Embedder {
|
|||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &rayon::ThreadPool,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
match self {
|
match self {
|
||||||
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
||||||
|
@ -3,6 +3,8 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
|||||||
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
||||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||||
use super::{DistributionShift, Embeddings};
|
use super::{DistributionShift, Embeddings};
|
||||||
|
use crate::error::FaultSource;
|
||||||
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct Embedder {
|
pub struct Embedder {
|
||||||
@ -71,11 +73,16 @@ impl Embedder {
|
|||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &rayon::ThreadPool,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
threads.install(move || {
|
threads
|
||||||
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||||
})
|
})
|
||||||
|
.map_err(|error| EmbedError {
|
||||||
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
fault: FaultSource::Bug,
|
||||||
|
})?
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
|
@ -4,7 +4,9 @@ use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
|||||||
use super::error::{EmbedError, NewEmbedderError};
|
use super::error::{EmbedError, NewEmbedderError};
|
||||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||||
use super::{DistributionShift, Embeddings};
|
use super::{DistributionShift, Embeddings};
|
||||||
|
use crate::error::FaultSource;
|
||||||
use crate::vector::error::EmbedErrorKind;
|
use crate::vector::error::EmbedErrorKind;
|
||||||
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||||
pub struct EmbedderOptions {
|
pub struct EmbedderOptions {
|
||||||
@ -241,11 +243,16 @@ impl Embedder {
|
|||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &rayon::ThreadPool,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
threads.install(move || {
|
threads
|
||||||
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||||
})
|
})
|
||||||
|
.map_err(|error| EmbedError {
|
||||||
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
fault: FaultSource::Bug,
|
||||||
|
})?
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
|
@ -2,9 +2,12 @@ use deserr::Deserr;
|
|||||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use super::error::EmbedErrorKind;
|
||||||
use super::{
|
use super::{
|
||||||
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
||||||
};
|
};
|
||||||
|
use crate::error::FaultSource;
|
||||||
|
use crate::ThreadPoolNoAbort;
|
||||||
|
|
||||||
// retrying in case of failure
|
// retrying in case of failure
|
||||||
|
|
||||||
@ -158,11 +161,16 @@ impl Embedder {
|
|||||||
pub fn embed_chunks(
|
pub fn embed_chunks(
|
||||||
&self,
|
&self,
|
||||||
text_chunks: Vec<Vec<String>>,
|
text_chunks: Vec<Vec<String>>,
|
||||||
threads: &rayon::ThreadPool,
|
threads: &ThreadPoolNoAbort,
|
||||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||||
threads.install(move || {
|
threads
|
||||||
|
.install(move || {
|
||||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||||
})
|
})
|
||||||
|
.map_err(|error| EmbedError {
|
||||||
|
kind: EmbedErrorKind::PanicInThreadPool(error),
|
||||||
|
fault: FaultSource::Bug,
|
||||||
|
})?
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn chunk_count_hint(&self) -> usize {
|
pub fn chunk_count_hint(&self) -> usize {
|
||||||
|
@ -217,9 +217,7 @@ fn add_memory_samples(
|
|||||||
memory_counters: &mut Option<MemoryCounterHandles>,
|
memory_counters: &mut Option<MemoryCounterHandles>,
|
||||||
last_memory: &mut MemoryStats,
|
last_memory: &mut MemoryStats,
|
||||||
) -> Option<MemoryStats> {
|
) -> Option<MemoryStats> {
|
||||||
let Some(stats) = memory else {
|
let stats = memory?;
|
||||||
return None;
|
|
||||||
};
|
|
||||||
|
|
||||||
let memory_counters =
|
let memory_counters =
|
||||||
memory_counters.get_or_insert_with(|| MemoryCounterHandles::new(profile, main));
|
memory_counters.get_or_insert_with(|| MemoryCounterHandles::new(profile, main));
|
||||||
|
Loading…
Reference in New Issue
Block a user