diff --git a/milli/src/lib.rs b/milli/src/lib.rs
index acea72c41..1624118b7 100644
--- a/milli/src/lib.rs
+++ b/milli/src/lib.rs
@@ -22,6 +22,7 @@ mod readable_slices;
 pub mod score_details;
 mod search;
 pub mod update;
+pub mod vector;
 
 #[cfg(test)]
 #[macro_use]
diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs
new file mode 100644
index 000000000..a6599386f
--- /dev/null
+++ b/milli/src/vector/error.rs
@@ -0,0 +1,191 @@
+use std::fmt::Display;
+use std::path::PathBuf;
+
+use hf_hub::api::sync::ApiError;
+
+/// Who is responsible for a given failure, used to decide how it is reported.
+#[derive(Debug, Clone, Copy)]
+pub enum FaultSource {
+    User,
+    Runtime,
+    Bug,
+    Undecided,
+}
+
+impl Display for FaultSource {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let s = match self {
+            FaultSource::User => "user error",
+            FaultSource::Runtime => "runtime error",
+            FaultSource::Bug => "coding error",
+            FaultSource::Undecided => "error",
+        };
+        f.write_str(s)
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("Error while generating embeddings: {inner}")]
+pub struct Error {
+    pub inner: Box<ErrorKind>,
+}
+
+impl<I: Into<ErrorKind>> From<I> for Error {
+    fn from(value: I) -> Self {
+        Self { inner: Box::new(value.into()) }
+    }
+}
+
+impl Error {
+    pub fn fault(&self) -> FaultSource {
+        match &*self.inner {
+            ErrorKind::NewEmbedderError(inner) => inner.fault,
+            ErrorKind::EmbedError(inner) => inner.fault,
+        }
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum ErrorKind {
+    #[error(transparent)]
+    NewEmbedderError(#[from] NewEmbedderError),
+    #[error(transparent)]
+    EmbedError(#[from] EmbedError),
+}
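Every error carries a `FaultSource`, and `Error::fault()` surfaces it regardless of which phase failed. Below, a minimal sketch (not part of the diff) of routing on it from a caller; note that `mod.rs` further down re-exports only the three error types, so naming `FaultSource` from outside the module is an assumption here:

```rust
use milli::vector::Error;
// Assumption: `FaultSource` is reachable by the caller; the diff currently
// re-exports only `EmbedError`, `Error` and `NewEmbedderError`.
use milli::vector::FaultSource;

fn report(err: &Error) {
    match err.fault() {
        FaultSource::User => eprintln!("invalid request: {err}"),
        FaultSource::Runtime | FaultSource::Undecided => eprintln!("transient failure: {err}"),
        FaultSource::Bug => eprintln!("internal error, please report it: {err}"),
    }
}
```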
+#[derive(Debug, thiserror::Error)]
+#[error("{fault}: {kind}")]
+pub struct EmbedError {
+    pub kind: EmbedErrorKind,
+    pub fault: FaultSource,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum EmbedErrorKind {
+    #[error("could not tokenize: {0}")]
+    Tokenize(Box<dyn std::error::Error + Send + Sync>),
+    #[error("unexpected tensor shape: {0}")]
+    TensorShape(candle_core::Error),
+    #[error("unexpected tensor value: {0}")]
+    TensorValue(candle_core::Error),
+    #[error("could not run model: {0}")]
+    ModelForward(candle_core::Error),
+}
+
+impl EmbedError {
+    pub fn tokenize(inner: Box<dyn std::error::Error + Send + Sync>) -> Self {
+        Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime }
+    }
+
+    pub fn tensor_shape(inner: candle_core::Error) -> Self {
+        Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug }
+    }
+
+    pub fn tensor_value(inner: candle_core::Error) -> Self {
+        Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug }
+    }
+
+    pub fn model_forward(inner: candle_core::Error) -> Self {
+        Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("{fault}: {kind}")]
+pub struct NewEmbedderError {
+    pub kind: NewEmbedderErrorKind,
+    pub fault: FaultSource,
+}
+
+impl NewEmbedderError {
+    pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError {
+        let open_config = OpenConfig { filename: config_filename, inner };
+
+        Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime }
+    }
+
+    pub fn deserialize_config(
+        config: String,
+        config_filename: PathBuf,
+        inner: serde_json::Error,
+    ) -> NewEmbedderError {
+        let deserialize_config = DeserializeConfig { config, filename: config_filename, inner };
+        Self {
+            kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub fn open_tokenizer(
+        tokenizer_filename: PathBuf,
+        inner: Box<dyn std::error::Error + Send + Sync>,
+    ) -> NewEmbedderError {
+        let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner };
+        Self {
+            kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer),
+            fault: FaultSource::Runtime,
+        }
+    }
+
+    pub fn new_api_fail(inner: ApiError) -> Self {
+        Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug }
+    }
+
+    pub fn api_get(inner: ApiError) -> Self {
+        Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided }
+    }
+
+    pub fn pytorch_weight(inner: candle_core::Error) -> Self {
+        Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
+    }
+
+    pub fn safetensor_weight(inner: candle_core::Error) -> Self {
+        Self { kind: NewEmbedderErrorKind::SafetensorWeight(inner), fault: FaultSource::Runtime }
+    }
+
+    pub fn load_model(inner: candle_core::Error) -> Self {
+        Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
+    }
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("could not open config at {filename:?}: {inner}")]
+pub struct OpenConfig {
+    pub filename: PathBuf,
+    pub inner: std::io::Error,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("could not deserialize config at {filename:?}: {inner}. Config follows:\n{config}")]
+pub struct DeserializeConfig {
+    pub config: String,
+    pub filename: PathBuf,
+    pub inner: serde_json::Error,
+}
+
+#[derive(Debug, thiserror::Error)]
+#[error("could not open tokenizer at {filename:?}: {inner}")]
+pub struct OpenTokenizer {
+    pub filename: PathBuf,
+    #[source]
+    pub inner: Box<dyn std::error::Error + Send + Sync>,
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum NewEmbedderErrorKind {
+    #[error(transparent)]
+    OpenConfig(OpenConfig),
+    #[error(transparent)]
+    DeserializeConfig(DeserializeConfig),
+    #[error(transparent)]
+    OpenTokenizer(OpenTokenizer),
+    #[error("could not build weights from Pytorch weights: {0}")]
+    PytorchWeight(candle_core::Error),
+    #[error("could not build weights from Safetensor weights: {0}")]
+    SafetensorWeight(candle_core::Error),
+    #[error("could not spawn HF_HUB API client: {0}")]
+    NewApiFail(ApiError),
+    #[error("fetching file from HF_HUB failed: {0}")]
+    ApiGet(ApiError),
+    #[error("loading model failed: {0}")]
+    LoadModel(candle_core::Error),
+}
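Each constructor pins both the error kind and the fault source, so fallible call sites stay one `map_err` long. A sketch of that call-site shape (hypothetical helper, not in the diff):

```rust
use std::path::Path;

use milli::vector::NewEmbedderError;

// Hypothetical helper: the constructor picks kind and fault, the call site
// only maps the underlying `std::io::Error` into it.
fn read_config(path: &Path) -> Result<String, NewEmbedderError> {
    std::fs::read_to_string(path)
        .map_err(|inner| NewEmbedderError::open_config(path.to_path_buf(), inner))
}
```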
diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs
new file mode 100644
index 000000000..ff0f0711b
--- /dev/null
+++ b/milli/src/vector/mod.rs
@@ -0,0 +1,154 @@
+use candle_core::Tensor;
+use candle_nn::VarBuilder;
+use candle_transformers::models::bert::{BertModel, Config, DTYPE};
+// FIXME: currently we'll be using the hub to retrieve the model; in the future we might want to embed it into Meilisearch itself
+use hf_hub::api::sync::Api;
+use hf_hub::{Repo, RepoType};
+use tokenizers::{PaddingParams, Tokenizer};
+
+pub use self::error::{EmbedError, Error, NewEmbedderError};
+
+mod error;
+
+#[derive(Debug, Default)]
+pub enum WeightSource {
+    #[default]
+    Safetensors,
+    Pytorch,
+}
+
+#[derive(Debug)]
+pub struct EmbedderOptions {
+    pub model: String,
+    pub revision: Option<String>,
+    pub weight_source: WeightSource,
+    pub normalize_embeddings: bool,
+}
+
+impl EmbedderOptions {
+    pub fn new() -> Self {
+        Self {
+            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
+            //model: "BAAI/bge-base-en-v1.5".to_string(),
+            revision: Some("refs/pr/21".to_string()),
+            //revision: None,
+            weight_source: Default::default(),
+            //weight_source: WeightSource::Pytorch,
+            normalize_embeddings: true,
+        }
+    }
+}
+
+impl Default for EmbedderOptions {
+    fn default() -> Self {
+        Self::new()
+    }
+}
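A usage sketch for the options: the defaults resolve to all-MiniLM-L6-v2 with safetensors weights, and the commented-out values above describe the alternative spelled out here (hypothetical function, not in the diff):

```rust
use milli::vector::{EmbedderOptions, WeightSource};

// Hypothetical helper contrasting the two configurations from `new()` above.
fn options() -> (EmbedderOptions, EmbedderOptions) {
    // Defaults: all-MiniLM-L6-v2 at revision `refs/pr/21`, safetensors weights.
    let defaults = EmbedderOptions::default();
    // The commented-out alternative: bge-base-en-v1.5, Pytorch weights, no pinned revision.
    let alternative = EmbedderOptions {
        model: "BAAI/bge-base-en-v1.5".to_string(),
        revision: None,
        weight_source: WeightSource::Pytorch,
        normalize_embeddings: true,
    };
    (defaults, alternative)
}
```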
+/// Perform embedding of documents and queries
+pub struct Embedder {
+    model: BertModel,
+    tokenizer: Tokenizer,
+    options: EmbedderOptions,
+}
+
+impl std::fmt::Debug for Embedder {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Embedder")
+            .field("model", &self.options.model)
+            .field("tokenizer", &self.tokenizer)
+            .field("options", &self.options)
+            .finish()
+    }
+}
+
+impl Embedder {
+    pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
+        let device = candle_core::Device::Cpu;
+        let repo = match options.revision.clone() {
+            Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
+            None => Repo::model(options.model.clone()),
+        };
+        let (config_filename, tokenizer_filename, weights_filename) = {
+            let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
+            let api = api.repo(repo);
+            let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
+            let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
+            let weights = match options.weight_source {
+                WeightSource::Pytorch => {
+                    api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)?
+                }
+                WeightSource::Safetensors => {
+                    api.get("model.safetensors").map_err(NewEmbedderError::api_get)?
+                }
+            };
+            (config, tokenizer, weights)
+        };
+
+        let config = std::fs::read_to_string(&config_filename)
+            .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
+        let config: Config = serde_json::from_str(&config).map_err(|inner| {
+            NewEmbedderError::deserialize_config(config, config_filename, inner)
+        })?;
+        let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
+            .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
+
+        let vb = match options.weight_source {
+            WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
+                .map_err(NewEmbedderError::pytorch_weight)?,
+            // SAFETY: the weights file is memory-mapped, which is unsound if the
+            // file is modified while the mapping is live.
+            WeightSource::Safetensors => unsafe {
+                VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
+                    .map_err(NewEmbedderError::safetensor_weight)?
+            },
+        };
+
+        let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
+
+        // Pad every sequence in a batch to the longest one, so the encodings
+        // can be stacked into a single rectangular tensor in `embed()`.
+        if let Some(pp) = tokenizer.get_padding_mut() {
+            pp.strategy = tokenizers::PaddingStrategy::BatchLongest
+        } else {
+            let pp = PaddingParams {
+                strategy: tokenizers::PaddingStrategy::BatchLongest,
+                ..Default::default()
+            };
+            tokenizer.with_padding(Some(pp));
+        }
+
+        Ok(Self { model, tokenizer, options })
+    }
+
+    pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Vec<f32>>, EmbedError> {
+        let tokens = self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?;
+        let token_ids = tokens
+            .iter()
+            .map(|tokens| {
+                let tokens = tokens.get_ids().to_vec();
+                Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
+            })
+            .collect::<Result<Vec<_>, EmbedError>>()?;
+
+        let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
+        let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
+        let embeddings =
+            self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
+
+        // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
+        let (_n_sentence, n_tokens, _hidden_size) =
+            embeddings.dims3().map_err(EmbedError::tensor_shape)?;
+
+        let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
+            .map_err(EmbedError::tensor_shape)?;
+        let embeddings: Tensor = if self.options.normalize_embeddings {
+            normalize_l2(&embeddings).map_err(EmbedError::tensor_value)?
+        } else {
+            embeddings
+        };
+
+        let embeddings = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
+        Ok(embeddings)
+    }
+}
+
+fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
+    // Divide each row by its L2 norm.
+    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
+}
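End-to-end sketch of the new module, assuming network access to the Hugging Face Hub on first run (per the FIXME at the top of `mod.rs`):

```rust
use milli::vector::{Embedder, EmbedderOptions, Error};

fn main() -> Result<(), Error> {
    // The first call downloads config.json, tokenizer.json and the weights;
    // both error types convert into `Error` through the blanket `From` impl.
    let embedder = Embedder::new(EmbedderOptions::default())?;
    let embeddings = embedder.embed(vec![
        "the best food for kittens".to_string(),
        "where to adopt a kitten".to_string(),
    ])?;
    // One vector per input; 384 dimensions for all-MiniLM-L6-v2.
    assert_eq!(embeddings.len(), 2);
    Ok(())
}
```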