Introduce the ThreadPoolNoAbort wrapper

This commit is contained in:
Clément Renault 2024-04-24 16:40:12 +02:00
parent b3173d0423
commit d4aeff92d0
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
14 changed files with 129 additions and 60 deletions

View file

@ -3,6 +3,7 @@ use std::path::PathBuf;
use hf_hub::api::sync::ApiError;
use crate::error::FaultSource;
use crate::PanicCatched;
#[derive(Debug, thiserror::Error)]
#[error("Error while generating embeddings: {inner}")]
@ -80,6 +81,8 @@ pub enum EmbedErrorKind {
OpenAiUnexpectedDimension(usize, usize),
#[error("no embedding was produced")]
MissingEmbedding,
#[error(transparent)]
PanicInThreadPool(#[from] PanicCatched),
}
impl EmbedError {

View file

@ -7,6 +7,7 @@ use serde::{Deserialize, Serialize};
use self::error::{EmbedError, NewEmbedderError};
use crate::prompt::{Prompt, PromptData};
use crate::ThreadPoolNoAbort;
pub mod error;
pub mod hf;
@ -254,7 +255,7 @@ impl Embedder {
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
threads: &ThreadPoolNoAbort,
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
match self {
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),

View file

@ -3,6 +3,8 @@ use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, Embeddings};
use crate::error::FaultSource;
use crate::ThreadPoolNoAbort;
#[derive(Debug)]
pub struct Embedder {
@ -71,11 +73,16 @@ impl Embedder {
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
threads: &ThreadPoolNoAbort,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
threads
.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
.map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error),
fault: FaultSource::Bug,
})?
}
pub fn chunk_count_hint(&self) -> usize {

View file

@ -4,7 +4,9 @@ use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
use super::error::{EmbedError, NewEmbedderError};
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
use super::{DistributionShift, Embeddings};
use crate::error::FaultSource;
use crate::vector::error::EmbedErrorKind;
use crate::ThreadPoolNoAbort;
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
@ -241,11 +243,16 @@ impl Embedder {
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
threads: &ThreadPoolNoAbort,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
threads
.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
.map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error),
fault: FaultSource::Bug,
})?
}
pub fn chunk_count_hint(&self) -> usize {

View file

@ -2,9 +2,12 @@ use deserr::Deserr;
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
use serde::{Deserialize, Serialize};
use super::error::EmbedErrorKind;
use super::{
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
};
use crate::error::FaultSource;
use crate::ThreadPoolNoAbort;
// retrying in case of failure
@ -158,11 +161,16 @@ impl Embedder {
pub fn embed_chunks(
&self,
text_chunks: Vec<Vec<String>>,
threads: &rayon::ThreadPool,
threads: &ThreadPoolNoAbort,
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
threads.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
threads
.install(move || {
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
})
.map_err(|error| EmbedError {
kind: EmbedErrorKind::PanicInThreadPool(error),
fault: FaultSource::Bug,
})?
}
pub fn chunk_count_hint(&self) -> usize {