use candle_core::Tensor;
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
// FIXME: currently we'll be using the hub to retrieve the model; in the future we might want to
// embed it into Meilisearch itself
use hf_hub::api::sync::Api;
use hf_hub::{Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer};

pub use super::error::{EmbedError, Error, NewEmbedderError};
use super::{DistributionShift, Embedding};

#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    Hash,
    PartialEq,
    Eq,
    serde::Deserialize,
    serde::Serialize,
    deserr::Deserr,
)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
enum WeightSource {
    #[default]
    Safetensors,
    Pytorch,
}

#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct EmbedderOptions {
    pub model: String,
    pub revision: Option<String>,
    pub distribution: Option<DistributionShift>,
    #[serde(default)]
    pub pooling: OverridePooling,
}

#[derive(
    Debug,
    Clone,
    Copy,
    Default,
    Hash,
    PartialEq,
    Eq,
    serde::Deserialize,
    serde::Serialize,
    utoipa::ToSchema,
    deserr::Deserr,
)]
#[deserr(rename_all = camelCase, deny_unknown_fields)]
#[serde(rename_all = "camelCase")]
pub enum OverridePooling {
    UseModel,
    ForceCls,
    #[default]
    ForceMean,
}

impl EmbedderOptions {
    pub fn new() -> Self {
        Self {
            model: "BAAI/bge-base-en-v1.5".to_string(),
            revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
            distribution: None,
            pooling: OverridePooling::UseModel,
        }
    }
}

impl Default for EmbedderOptions {
    fn default() -> Self {
        Self::new()
    }
}

/// Perform embedding of documents and queries
pub struct Embedder {
    model: BertModel,
    tokenizer: Tokenizer,
    options: EmbedderOptions,
    dimensions: usize,
    pooling: Pooling,
}

impl std::fmt::Debug for Embedder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Embedder")
            .field("model", &self.options.model)
            .field("tokenizer", &self.tokenizer)
            .field("options", &self.options)
            .field("pooling", &self.pooling)
            .finish()
    }
}

#[derive(Clone, Copy, serde::Deserialize)]
struct PoolingConfig {
    #[serde(default)]
    pub pooling_mode_cls_token: bool,
    #[serde(default)]
    pub pooling_mode_mean_tokens: bool,
    #[serde(default)]
    pub pooling_mode_max_tokens: bool,
    #[serde(default)]
    pub pooling_mode_mean_sqrt_len_tokens: bool,
    #[serde(default)]
    pub pooling_mode_lasttoken: bool,
}

#[derive(Debug, Clone, Copy, Default)]
pub enum Pooling {
    #[default]
    Mean,
    Cls,
    Max,
    MeanSqrtLen,
    LastToken,
}

impl Pooling {
    fn override_with(&mut self, pooling: OverridePooling) {
        match pooling {
            OverridePooling::UseModel => {}
            OverridePooling::ForceCls => *self = Pooling::Cls,
            OverridePooling::ForceMean => *self = Pooling::Mean,
        }
    }
}

impl From<PoolingConfig> for Pooling {
    fn from(value: PoolingConfig) -> Self {
        if value.pooling_mode_cls_token {
            Self::Cls
        } else if value.pooling_mode_mean_tokens {
            Self::Mean
        } else if value.pooling_mode_lasttoken {
            Self::LastToken
        } else if value.pooling_mode_mean_sqrt_len_tokens {
            Self::MeanSqrtLen
        } else if value.pooling_mode_max_tokens {
            Self::Max
        } else {
            Self::default()
        }
    }
}
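// An illustrative check of the mapping above, assuming the usual shape of a
// sentence-transformers `1_Pooling/config.json` payload. `word_embedding_dimension` stands in for
// the extra keys such files carry; since `PoolingConfig` does not use `deny_unknown_fields`,
// unknown keys are ignored during deserialization.
#[cfg(test)]
mod pooling_config_tests {
    use super::{Pooling, PoolingConfig};

    #[test]
    fn cls_flag_maps_to_cls_pooling() {
        let json = r#"{
            "word_embedding_dimension": 768,
            "pooling_mode_cls_token": true,
            "pooling_mode_mean_tokens": false,
            "pooling_mode_max_tokens": false,
            "pooling_mode_mean_sqrt_len_tokens": false
        }"#;
        let config: PoolingConfig = serde_json::from_str(json).unwrap();
        assert!(matches!(Pooling::from(config), Pooling::Cls));
    }
}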
impl Embedder {
    pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
        let device = match candle_core::Device::cuda_if_available(0) {
            Ok(device) => device,
            Err(error) => {
                tracing::warn!("could not initialize CUDA device for Hugging Face embedder, defaulting to CPU: {}", error);
                candle_core::Device::Cpu
            }
        };
        let repo = match options.revision.clone() {
            Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
            None => Repo::model(options.model.clone()),
        };
        let (config_filename, tokenizer_filename, weights_filename, weight_source, pooling) = {
            let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
            let api = api.repo(repo);
            let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
            let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
            let (weights, source) = {
                api.get("model.safetensors")
                    .map(|filename| (filename, WeightSource::Safetensors))
                    .or_else(|_| {
                        api.get("pytorch_model.bin")
                            .map(|filename| (filename, WeightSource::Pytorch))
                    })
                    .map_err(NewEmbedderError::api_get)?
            };
            let pooling = match api.get("1_Pooling/config.json") {
                Ok(pooling) => Some(pooling),
                Err(hf_hub::api::sync::ApiError::RequestError(error))
                    if matches!(*error, ureq::Error::Status(404, _)) =>
                {
                    // ignore the error if the file simply doesn't exist
                    None
                }
                Err(error) => return Err(NewEmbedderError::api_get(error)),
            };

            let mut pooling: Pooling = match pooling {
                Some(pooling_filename) => {
                    let pooling = std::fs::read_to_string(&pooling_filename).map_err(|inner| {
                        NewEmbedderError::open_pooling_config(pooling_filename.clone(), inner)
                    })?;
                    let pooling: PoolingConfig =
                        serde_json::from_str(&pooling).map_err(|inner| {
                            NewEmbedderError::deserialize_pooling_config(
                                options.model.clone(),
                                pooling_filename,
                                inner,
                            )
                        })?;
                    pooling.into()
                }
                None => Pooling::default(),
            };

            pooling.override_with(options.pooling);

            (config, tokenizer, weights, source, pooling)
        };

        let config = std::fs::read_to_string(&config_filename)
            .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
        let config: Config = serde_json::from_str(&config).map_err(|inner| {
            NewEmbedderError::deserialize_config(
                options.model.clone(),
                config,
                config_filename,
                inner,
            )
        })?;

        let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
            .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;

        let vb = match weight_source {
            WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
                .map_err(NewEmbedderError::pytorch_weight)?,
            WeightSource::Safetensors => unsafe {
                VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
                    .map_err(NewEmbedderError::safetensor_weight)?
            },
        };

        tracing::debug!(model = options.model, weight = ?weight_source, pooling = ?pooling, "model config");

        let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;

        if let Some(pp) = tokenizer.get_padding_mut() {
            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
        } else {
            let pp = PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            };
            tokenizer.with_padding(Some(pp));
        }

        let mut this = Self { model, tokenizer, options, dimensions: 0, pooling };

        let embeddings = this
            .embed(vec!["test".into()])
            .map_err(NewEmbedderError::could_not_determine_dimension)?;
        this.dimensions = embeddings.first().unwrap().len();

        Ok(this)
    }
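    /// Embeds each of the given texts independently.
    ///
    /// A minimal usage sketch (illustrative; `ignore`d because constructing the embedder
    /// downloads the default model from the Hugging Face hub):
    ///
    /// ```ignore
    /// let embedder = Embedder::new(EmbedderOptions::new())?;
    /// let embeddings = embedder.embed(vec!["hello world".to_string()])?;
    /// assert_eq!(embeddings[0].len(), embedder.dimensions());
    /// ```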
    pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Embedding>, EmbedError> {
        texts.into_iter().map(|text| self.embed_one(&text)).collect()
    }

    fn pooling(embeddings: Tensor, pooling: Pooling) -> Result<Tensor, EmbedError> {
        match pooling {
            Pooling::Mean => Self::mean_pooling(embeddings),
            Pooling::Cls => Self::cls_pooling(embeddings),
            Pooling::Max => Self::max_pooling(embeddings),
            Pooling::MeanSqrtLen => Self::mean_sqrt_pooling(embeddings),
            Pooling::LastToken => Self::last_token_pooling(embeddings),
        }
    }

    fn cls_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
        embeddings.get_on_dim(1, 0).map_err(EmbedError::tensor_value)
    }

    fn mean_sqrt_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
        let (_n_sentence, n_tokens, _hidden_size) =
            embeddings.dims3().map_err(EmbedError::tensor_shape)?;

        (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64).sqrt())
            .map_err(EmbedError::tensor_shape)
    }

    fn mean_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
        let (_n_sentence, n_tokens, _hidden_size) =
            embeddings.dims3().map_err(EmbedError::tensor_shape)?;

        (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
            .map_err(EmbedError::tensor_shape)
    }

    fn max_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
        embeddings.max(1).map_err(EmbedError::tensor_shape)
    }

    fn last_token_pooling(embeddings: Tensor) -> Result<Tensor, EmbedError> {
        let (_n_sentence, n_tokens, _hidden_size) =
            embeddings.dims3().map_err(EmbedError::tensor_shape)?;

        embeddings.get_on_dim(1, n_tokens - 1).map_err(EmbedError::tensor_value)
    }
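    /// Embeds a single text.
    ///
    /// The token sequence is truncated to the first 512 tokens (matching the usual position
    /// limit of the BERT-style models loaded here), run through the encoder to produce a
    /// `[1, n_tokens, hidden_size]` tensor, then pooled and squeezed into a single
    /// `[hidden_size]` embedding.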
    pub fn embed_one(&self, text: &str) -> std::result::Result<Embedding, EmbedError> {
        let tokens = self.tokenizer.encode(text, true).map_err(EmbedError::tokenize)?;
        let token_ids = tokens.get_ids();
        let token_ids = if token_ids.len() > 512 { &token_ids[..512] } else { token_ids };
        let token_ids =
            Tensor::new(token_ids, &self.model.device).map_err(EmbedError::tensor_shape)?;
        let token_ids = Tensor::stack(&[token_ids], 0).map_err(EmbedError::tensor_shape)?;
        let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
        let embeddings = self
            .model
            .forward(&token_ids, &token_type_ids, None)
            .map_err(EmbedError::model_forward)?;

        let embedding = Self::pooling(embeddings, self.pooling)?;

        let embedding = embedding.squeeze(0).map_err(EmbedError::tensor_shape)?;
        let embedding: Embedding = embedding.to_vec1().map_err(EmbedError::tensor_shape)?;
        Ok(embedding)
    }

    pub fn embed_index(
        &self,
        text_chunks: Vec<Vec<String>>,
    ) -> std::result::Result<Vec<Vec<Embedding>>, EmbedError> {
        text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect()
    }

    pub fn chunk_count_hint(&self) -> usize {
        1
    }

    pub fn prompt_count_in_chunk_hint(&self) -> usize {
        std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
    }

    pub fn dimensions(&self) -> usize {
        self.dimensions
    }

    pub fn distribution(&self) -> Option<DistributionShift> {
        self.options.distribution.or_else(|| {
            if self.options.model == "BAAI/bge-base-en-v1.5" {
                Some(DistributionShift {
                    current_mean: ordered_float::OrderedFloat(0.85),
                    current_sigma: ordered_float::OrderedFloat(0.1),
                })
            } else {
                None
            }
        })
    }

    pub(crate) fn embed_index_ref(&self, texts: &[&str]) -> Result<Vec<Embedding>, EmbedError> {
        texts.iter().map(|text| self.embed_one(text)).collect()
    }
}
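// A couple of smoke tests sketched in for illustration. The first exercises the pooling
// arithmetic on a toy tensor; the second shows end-to-end usage but is `#[ignore]`d because
// `Embedder::new` downloads the model from the Hugging Face hub.
#[cfg(test)]
mod tests {
    use candle_core::{Device, Tensor};

    use super::{Embedder, EmbedderOptions};

    #[test]
    fn mean_pooling_collapses_the_token_dimension() {
        // One sentence, two tokens, three hidden dimensions.
        let embeddings = Tensor::new(&[[[1f32, 2., 3.], [3., 4., 5.]]], &Device::Cpu).unwrap();
        let pooled = Embedder::mean_pooling(embeddings).unwrap();
        // Mean pooling sums over the token dimension and divides by the token count:
        // [[1, 2, 3], [3, 4, 5]] -> [2, 3, 4].
        assert_eq!(pooled.dims(), &[1, 3]);
        assert_eq!(pooled.to_vec2::<f32>().unwrap(), vec![vec![2f32, 3., 4.]]);
    }

    #[test]
    #[ignore = "downloads BAAI/bge-base-en-v1.5 from the Hugging Face hub"]
    fn embed_one_returns_a_vector_of_the_advertised_dimension() {
        let embedder = Embedder::new(EmbedderOptions::new()).unwrap();
        let embedding = embedder.embed_one("a search query").unwrap();
        assert_eq!(embedding.len(), embedder.dimensions());
    }
}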