diff --git a/crates/milli/src/vector/error.rs b/crates/milli/src/vector/error.rs
index d1b2516f5..650249bff 100644
--- a/crates/milli/src/vector/error.rs
+++ b/crates/milli/src/vector/error.rs
@@ -262,6 +262,31 @@ impl NewEmbedderError {
         }
     }
 
+    pub fn open_pooling_config(
+        pooling_config_filename: PathBuf,
+        inner: std::io::Error,
+    ) -> NewEmbedderError {
+        let open_config = OpenPoolingConfig { filename: pooling_config_filename, inner };
+
+        Self {
+            kind: NewEmbedderErrorKind::OpenPoolingConfig(open_config),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub fn deserialize_pooling_config(
+        model_name: String,
+        pooling_config_filename: PathBuf,
+        inner: serde_json::Error,
+    ) -> NewEmbedderError {
+        let deserialize_pooling_config =
+            DeserializePoolingConfig { model_name, filename: pooling_config_filename, inner };
+        Self {
+            kind: NewEmbedderErrorKind::DeserializePoolingConfig(deserialize_pooling_config),
+            fault: FaultSource::Runtime,
+        }
+    }
+
     pub fn open_tokenizer(
         tokenizer_filename: PathBuf,
         inner: Box<dyn std::error::Error + Send + Sync>,
@@ -319,6 +344,13 @@ pub struct OpenConfig {
     pub inner: std::io::Error,
 }
 
+#[derive(Debug, thiserror::Error)]
+#[error("could not open pooling config at {filename}: {inner}")]
+pub struct OpenPoolingConfig {
+    pub filename: PathBuf,
+    pub inner: std::io::Error,
+}
+
 #[derive(Debug, thiserror::Error)]
 #[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
 pub struct DeserializeConfig {
@@ -327,6 +359,14 @@ pub struct DeserializeConfig {
     pub inner: serde_json::Error,
 }
 
+#[derive(Debug, thiserror::Error)]
+#[error("for model '{model_name}', could not deserialize file at `{filename}` as a pooling config: {inner}")]
+pub struct DeserializePoolingConfig {
+    pub model_name: String,
+    pub filename: PathBuf,
+    pub inner: serde_json::Error,
+}
+
 #[derive(Debug, thiserror::Error)]
 #[error("model `{model_name}` appears to be unsupported{}\n  - inner error: {inner}",
 if architectures.is_empty() {
@@ -354,8 +394,12 @@ pub enum NewEmbedderErrorKind {
     #[error(transparent)]
     OpenConfig(OpenConfig),
     #[error(transparent)]
+    OpenPoolingConfig(OpenPoolingConfig),
+    #[error(transparent)]
     DeserializeConfig(DeserializeConfig),
     #[error(transparent)]
+    DeserializePoolingConfig(DeserializePoolingConfig),
+    #[error(transparent)]
     UnsupportedModel(UnsupportedModel),
     #[error(transparent)]
     OpenTokenizer(OpenTokenizer),
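
The two new variants map to the two ways loading a pooling config can fail: the `1_Pooling/config.json` file cannot be read (`std::io::Error`, wrapped by `OpenPoolingConfig`), or its contents do not deserialize as a pooling config (`serde_json::Error`, wrapped by `DeserializePoolingConfig`). A rough standalone sketch of that flow, with plain strings standing in for the milli error types, a cut-down two-field config, and a placeholder model name:

    use std::path::PathBuf;

    // Cut-down stand-in for the pooling config deserialized in hf.rs below;
    // the real struct has one boolean flag per supported pooling mode.
    #[derive(Debug, serde::Deserialize)]
    struct PoolingConfig {
        #[serde(default)]
        pooling_mode_cls_token: bool,
        #[serde(default)]
        pooling_mode_mean_tokens: bool,
    }

    fn load_pooling_config(filename: PathBuf) -> Result<PoolingConfig, String> {
        // Failure mode 1: the file cannot be read -> OpenPoolingConfig.
        let raw = std::fs::read_to_string(&filename).map_err(|inner| {
            format!("could not open pooling config at {}: {inner}", filename.display())
        })?;
        // Failure mode 2: the contents are not a valid pooling config -> DeserializePoolingConfig.
        serde_json::from_str(&raw).map_err(|inner| {
            format!(
                "for model 'some-model', could not deserialize file at `{}` as a pooling config: {inner}",
                filename.display()
            )
        })
    }

    fn main() {
        match load_pooling_config(PathBuf::from("1_Pooling/config.json")) {
            Ok(config) => println!("{config:?}"),
            Err(message) => eprintln!("{message}"),
        }
    }
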
diff --git a/crates/milli/src/vector/hf.rs b/crates/milli/src/vector/hf.rs
index 447a88f5d..9ec34daef 100644
--- a/crates/milli/src/vector/hf.rs
+++ b/crates/milli/src/vector/hf.rs
@@ -58,6 +58,7 @@ pub struct Embedder {
     tokenizer: Tokenizer,
     options: EmbedderOptions,
     dimensions: usize,
+    pooling: Pooling,
 }
 
 impl std::fmt::Debug for Embedder {
@@ -66,10 +67,53 @@ impl std::fmt::Debug for Embedder {
             .field("model", &self.options.model)
             .field("tokenizer", &self.tokenizer)
             .field("options", &self.options)
+            .field("pooling", &self.pooling)
             .finish()
     }
 }
 
+#[derive(Clone, Copy, serde::Deserialize)]
+struct PoolingConfig {
+    #[serde(default)]
+    pub pooling_mode_cls_token: bool,
+    #[serde(default)]
+    pub pooling_mode_mean_tokens: bool,
+    #[serde(default)]
+    pub pooling_mode_max_tokens: bool,
+    #[serde(default)]
+    pub pooling_mode_mean_sqrt_len_tokens: bool,
+    #[serde(default)]
+    pub pooling_mode_lasttoken: bool,
+}
+
+#[derive(Debug, Clone, Copy, Default)]
+pub enum Pooling {
+    #[default]
+    Mean,
+    Cls,
+    Max,
+    MeanSqrtLen,
+    LastToken,
+}
+
+impl From<PoolingConfig> for Pooling {
+    fn from(value: PoolingConfig) -> Self {
+        if value.pooling_mode_cls_token {
+            Self::Cls
+        } else if value.pooling_mode_mean_tokens {
+            Self::Mean
+        } else if value.pooling_mode_lasttoken {
+            Self::LastToken
+        } else if value.pooling_mode_mean_sqrt_len_tokens {
+            Self::MeanSqrtLen
+        } else if value.pooling_mode_max_tokens {
+            Self::Max
+        } else {
+            Self::default()
+        }
+    }
+}
+
 impl Embedder {
     pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
         let device = match candle_core::Device::cuda_if_available(0) {
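
`PoolingConfig` only declares the pooling flags; the `1_Pooling/config.json` shipped with sentence-transformers style models typically carries extra keys such as `word_embedding_dimension`, which serde ignores by default. When a file sets several flags, the `From` impl above keeps the first match in its chain, with CLS winning over mean, and a file with no flags (or no pooling file at all) falls back to mean pooling. A small standalone check of the deserialization, using an illustrative JSON sample and a copy of the struct above:

    // Copy of the PoolingConfig flags above; unknown keys in the JSON are
    // dropped because serde ignores them unless deny_unknown_fields is set.
    #[derive(Debug, serde::Deserialize)]
    struct PoolingConfig {
        #[serde(default)]
        pooling_mode_cls_token: bool,
        #[serde(default)]
        pooling_mode_mean_tokens: bool,
        #[serde(default)]
        pooling_mode_max_tokens: bool,
        #[serde(default)]
        pooling_mode_mean_sqrt_len_tokens: bool,
        #[serde(default)]
        pooling_mode_lasttoken: bool,
    }

    fn main() {
        // Illustrative contents in the style of a sentence-transformers pooling config.
        let raw = r#"{
            "word_embedding_dimension": 384,
            "pooling_mode_cls_token": false,
            "pooling_mode_mean_tokens": true,
            "pooling_mode_max_tokens": false,
            "pooling_mode_mean_sqrt_len_tokens": false
        }"#;
        let config: PoolingConfig = serde_json::from_str(raw).unwrap();
        assert!(config.pooling_mode_mean_tokens);
        // The From<PoolingConfig> impl above resolves this to Pooling::Mean.
        println!("{config:?}");
    }
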
@@ -83,7 +127,7 @@ impl Embedder {
             Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
             None => Repo::model(options.model.clone()),
         };
-        let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
+        let (config_filename, tokenizer_filename, weights_filename, weight_source, pooling) = {
             let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
             let api = api.repo(repo);
             let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
@@ -97,7 +141,36 @@ impl Embedder {
                     })
                     .map_err(NewEmbedderError::api_get)?
             };
-            (config, tokenizer, weights, source)
+            let pooling = match api.get("1_Pooling/config.json") {
+                Ok(pooling) => Some(pooling),
+                Err(hf_hub::api::sync::ApiError::RequestError(error))
+                    if matches!(*error, ureq::Error::Status(404, _)) =>
+                {
+                    // ignore the error if the file simply doesn't exist
+                    None
+                }
+                Err(error) => return Err(NewEmbedderError::api_get(error)),
+            };
+            let pooling: Pooling = match pooling {
+                Some(pooling_filename) => {
+                    let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| {
+                        NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner)
+                    })?;
+
+                    let pooling: PoolingConfig =
+                        serde_json::from_str(&pooling).map_err(|inner| {
+                            NewEmbedderError::deserialize_pooling_config(
+                                options.model.clone(),
+                                pooling_filename,
+                                inner,
+                            )
+                        })?;
+                    pooling.into()
+                }
+                None => Pooling::default(),
+            };
+
+            (config, tokenizer, weights, source, pooling)
         };
 
         let config = std::fs::read_to_string(&config_filename)
@@ -122,6 +195,8 @@ impl Embedder {
             },
         };
 
+        tracing::debug!(model = options.model, weight = ?weight_source, pooling = ?pooling, "model config");
+
         let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
 
         if let Some(pp) = tokenizer.get_padding_mut() {
@@ -134,7 +209,7 @@ impl Embedder {
             tokenizer.with_padding(Some(pp));
         }
 
-        let mut this = Self { model, tokenizer, options, dimensions: 0 };
+        let mut this = Self { model, tokenizer, options, dimensions: 0, pooling };
 
         let embeddings = this
             .embed(vec!["test".into()])
@@ -168,17 +243,53 @@ impl Embedder {
             .forward(&token_ids, &token_type_ids, None)
             .map_err(EmbedError::model_forward)?;
 
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-        let (_n_sentence, n_tokens, _hidden_size) =
-            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
-
-        let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
-            .map_err(EmbedError::tensor_shape)?;
+        let embeddings = Self::pooling(embeddings, self.pooling)?;
 
         let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
         Ok(embeddings)
     }
 
+    fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {
+        match pooling {
+            Pooling::Mean => Self::mean_pooling(embeddings),
+            Pooling::Cls => Self::cls_pooling(embeddings),
+            Pooling::Max => Self::max_pooling(embeddings),
+            Pooling::MeanSqrtLen => Self::mean_sqrt_pooling(embeddings),
+            Pooling::LastToken => Self::last_token_pooling(embeddings),
+        }
+    }
+
+    fn cls_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
+        embeddings.get_on_dim(1, 0).map_err(EmbedError::tensor_value)
+    }
+
+    fn mean_sqrt_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
+        let (_n_sentence, n_tokens, _hidden_size) =
+            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
+
+        (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64).sqrt())
+            .map_err(EmbedError::tensor_shape)
+    }
+
+    fn mean_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
+        let (_n_sentence, n_tokens, _hidden_size) =
+            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
+
+        (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
+            .map_err(EmbedError::tensor_shape)
+    }
+
+    fn max_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
+        embeddings.max(1).map_err(EmbedError::tensor_shape)
+    }
+
+    fn last_token_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
+        let (_n_sentence, n_tokens, _hidden_size) =
+            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
+
+        embeddings.get_on_dim(1, n_tokens - 1).map_err(EmbedError::tensor_value)
+    }
+
     pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
         let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
         let token_ids = tokens.get_ids();
@@ -192,11 +303,8 @@ impl Embedder {
             .forward(&token_ids, &token_type_ids, None)
             .map_err(EmbedError::model_forward)?;
 
-        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
-        let (_n_sentence, n_tokens, _hidden_size) =
-            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
-        let embedding = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
-            .map_err(EmbedError::tensor_shape)?;
+        let embedding = Self::pooling(embeddings, self.pooling)?;
+
         let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
         let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
         Ok(embedding)
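
Overall, the change swaps the previously hard-coded mean pooling for whichever strategy the model declares in `1_Pooling/config.json`. The methods above do the work on candle tensors; as a dependency-free illustration of the same arithmetic, here is what each strategy computes for one sentence given its token embeddings as an `n_tokens x hidden` matrix (a sketch, not the candle implementation):

    // Token embeddings for one sentence: one Vec<f32> of length `hidden` per token.
    fn cls_pooling(tokens: &[Vec<f32>]) -> Vec<f32> {
        // Keep only the first ([CLS]) token's embedding.
        tokens.first().cloned().unwrap_or_default()
    }

    fn last_token_pooling(tokens: &[Vec<f32>]) -> Vec<f32> {
        // Keep only the last token's embedding.
        tokens.last().cloned().unwrap_or_default()
    }

    fn sum_over_tokens(tokens: &[Vec<f32>]) -> Vec<f32> {
        let hidden = tokens.first().map_or(0, Vec::len);
        tokens.iter().fold(vec![0.0; hidden], |mut acc, token| {
            for (a, v) in acc.iter_mut().zip(token) {
                *a += *v;
            }
            acc
        })
    }

    fn mean_pooling(tokens: &[Vec<f32>]) -> Vec<f32> {
        // Sum over the token axis, then divide by the number of tokens.
        let n = tokens.len() as f32;
        sum_over_tokens(tokens).into_iter().map(|v| v / n).collect()
    }

    fn mean_sqrt_len_pooling(tokens: &[Vec<f32>]) -> Vec<f32> {
        // Same sum, but divided by sqrt(n_tokens) instead of n_tokens.
        let n = (tokens.len() as f32).sqrt();
        sum_over_tokens(tokens).into_iter().map(|v| v / n).collect()
    }

    fn max_pooling(tokens: &[Vec<f32>]) -> Vec<f32> {
        // Element-wise maximum over the token axis.
        let hidden = tokens.first().map_or(0, Vec::len);
        tokens.iter().fold(vec![f32::NEG_INFINITY; hidden], |mut acc, token| {
            for (a, v) in acc.iter_mut().zip(token) {
                *a = a.max(*v);
            }
            acc
        })
    }

    fn main() {
        // Two tokens with a 3-dimensional hidden state.
        let tokens: Vec<Vec<f32>> = vec![vec![1.0, 0.0, 2.0], vec![3.0, 4.0, 0.0]];
        println!("cls           = {:?}", cls_pooling(&tokens)); // [1.0, 0.0, 2.0]
        println!("last token    = {:?}", last_token_pooling(&tokens)); // [3.0, 4.0, 0.0]
        println!("mean          = {:?}", mean_pooling(&tokens)); // [2.0, 2.0, 1.0]
        println!("mean sqrt len = {:?}", mean_sqrt_len_pooling(&tokens));
        println!("max           = {:?}", max_pooling(&tokens)); // [3.0, 4.0, 2.0]
    }

As in the code it replaces, mean pooling in the embedder averages over every token of the padded sequence, so padding tokens still contribute to the result.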