Allow overriding pooling method

This commit is contained in:
Louis Dureuil 2025-02-18 17:12:23 +01:00
parent 11759c4be4
commit 7b4ce468a6
No known key found for this signature in database
4 changed files with 78 additions and 1 deletion

View file

@ -34,6 +34,30 @@ pub struct EmbedderOptions {
pub model: String,
pub revision: Option<String>,
pub distribution: Option<DistributionShift>,
#[serde(default)]
pub pooling: OverridePooling,
}
/// Setting that decides whether the pooling method loaded for a model is kept
/// as-is or replaced by a fixed one.
///
/// It is applied through `Pooling::override_with`: `UseModel` leaves the
/// current pooling untouched, while the `Force*` variants replace it.
///
/// Variants are (de)serialized in camelCase (`useModel`, `forceCls`,
/// `forceMean`), and unknown fields are rejected by `deserr`.
///
/// NOTE(review): the derived `Default` is `ForceMean`, whereas the
/// `EmbedderOptions` constructor in this file sets `UseModel`. Presumably the
/// derived default exists so that options stored without a `pooling` field keep
/// using mean pooling — confirm before changing either value.
#[derive(
Debug,
Clone,
Copy,
Default,
Hash,
PartialEq,
Eq,
serde::Deserialize,
serde::Serialize,
utoipa::ToSchema,
deserr::Deserr,
)]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
pub enum OverridePooling {
/// Do not override: keep whatever pooling method was determined for the model.
UseModel,
/// Always use `Pooling::Cls`, regardless of what was determined for the model.
ForceCls,
/// Always use `Pooling::Mean`, regardless of what was determined for the model.
///
/// This is the derived `Default`, i.e. the value `#[serde(default)]` on
/// `EmbedderOptions::pooling` falls back to when the field is absent.
#[default]
ForceMean,
}
impl EmbedderOptions {
@ -42,6 +66,7 @@ impl EmbedderOptions {
model: "BAAI/bge-base-en-v1.5".to_string(),
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
distribution: None,
pooling: OverridePooling::UseModel,
}
}
}
@ -95,6 +120,15 @@ pub enum Pooling {
MeanSqrtLen,
LastToken,
}
impl Pooling {
    /// Applies a user-requested pooling override to `self` in place.
    ///
    /// - `OverridePooling::UseModel` leaves `self` unchanged.
    /// - `OverridePooling::ForceCls` replaces `self` with `Pooling::Cls`.
    /// - `OverridePooling::ForceMean` replaces `self` with `Pooling::Mean`.
    fn override_with(&mut self, pooling: OverridePooling) {
        let forced = match pooling {
            // Nothing to override: keep the value already stored in `self`.
            OverridePooling::UseModel => return,
            OverridePooling::ForceCls => Self::Cls,
            OverridePooling::ForceMean => Self::Mean,
        };
        *self = forced;
    }
}
impl From<PoolingConfig> for Pooling {
fn from(value: PoolingConfig) -> Self {
@ -151,7 +185,7 @@ impl Embedder {
}
Err(error) => return Err(NewEmbedderError::api_get(error)),
};
let pooling: Pooling = match pooling {
let mut pooling: Pooling = match pooling {
Some(pooling_filename) => {
let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| {
NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner)
@ -170,6 +204,8 @@ impl Embedder {
None => Pooling::default(),
};
pooling.override_with(options.pooling);
(config, tokenizer, weights, source, pooling)
};