Support pooling

This commit is contained in:
Louis Dureuil 2025-02-18 14:16:41 +01:00
parent 0f1aeb8eaa
commit 11759c4be4
No known key found for this signature in database
2 changed files with 166 additions and 14 deletions

View file

@ -262,6 +262,31 @@ impl NewEmbedderError {
}
}
pub fn open_pooling_config(
pooling_config_filename: PathBuf,
inner: std::io::Error,
) -> NewEmbedderError {
let open_config = OpenPoolingConfig { filename: pooling_config_filename, inner };
Self {
kind: NewEmbedderErrorKind::OpenPoolingConfig(open_config),
fault: FaultSource::Runtime,
}
}
pub fn deserialize_pooling_config(
model_name: String,
pooling_config_filename: PathBuf,
inner: serde_json::Error,
) -> NewEmbedderError {
let deserialize_pooling_config =
DeserializePoolingConfig { model_name, filename: pooling_config_filename, inner };
Self {
kind: NewEmbedderErrorKind::DeserializePoolingConfig(deserialize_pooling_config),
fault: FaultSource::Runtime,
}
}
pub fn open_tokenizer(
tokenizer_filename: PathBuf,
inner: Box<dyn std::error::Error + Send + Sync>,
@ -319,6 +344,13 @@ pub struct OpenConfig {
pub inner: std::io::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("could not open pooling config at {filename}: {inner}")]
pub struct OpenPoolingConfig {
pub filename: PathBuf,
pub inner: std::io::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
pub struct DeserializeConfig {
@ -327,6 +359,14 @@ pub struct DeserializeConfig {
pub inner: serde_json::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("for model '{model_name}', could not deserialize file at `{filename}` as a pooling config: {inner}")]
pub struct DeserializePoolingConfig {
pub model_name: String,
pub filename: PathBuf,
pub inner: serde_json::Error,
}
#[derive(Debug, thiserror::Error)]
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
if architectures.is_empty() {
@ -354,8 +394,12 @@ pub enum NewEmbedderErrorKind {
#[error(transparent)]
OpenConfig(OpenConfig),
#[error(transparent)]
OpenPoolingConfig(OpenPoolingConfig),
#[error(transparent)]
DeserializeConfig(DeserializeConfig),
#[error(transparent)]
DeserializePoolingConfig(DeserializePoolingConfig),
#[error(transparent)]
UnsupportedModel(UnsupportedModel),
#[error(transparent)]
OpenTokenizer(OpenTokenizer),