mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-05-14 08:14:05 +02:00
Add milli::vector module
This commit is contained in:
parent
2e1903bd50
commit
9ae4dee202
@ -22,6 +22,7 @@ mod readable_slices;
|
||||
pub mod score_details;
|
||||
mod search;
|
||||
pub mod update;
|
||||
pub mod vector;
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
|
191
milli/src/vector/error.rs
Normal file
191
milli/src/vector/error.rs
Normal file
@ -0,0 +1,191 @@
|
||||
use std::fmt::Display;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use hf_hub::api::sync::ApiError;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum FaultSource {
|
||||
User,
|
||||
Runtime,
|
||||
Bug,
|
||||
Undecided,
|
||||
}
|
||||
|
||||
impl Display for FaultSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let s = match self {
|
||||
FaultSource::User => "user error",
|
||||
FaultSource::Runtime => "runtime error",
|
||||
FaultSource::Bug => "coding error",
|
||||
FaultSource::Undecided => "error",
|
||||
};
|
||||
f.write_str(s)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Error while generating embeddings: {inner}")]
|
||||
pub struct Error {
|
||||
pub inner: Box<ErrorKind>,
|
||||
}
|
||||
|
||||
impl<I: Into<ErrorKind>> From<I> for Error {
|
||||
fn from(value: I) -> Self {
|
||||
Self { inner: Box::new(value.into()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn fault(&self) -> FaultSource {
|
||||
match &*self.inner {
|
||||
ErrorKind::NewEmbedderError(inner) => inner.fault,
|
||||
ErrorKind::EmbedError(inner) => inner.fault,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ErrorKind {
|
||||
#[error(transparent)]
|
||||
NewEmbedderError(#[from] NewEmbedderError),
|
||||
#[error(transparent)]
|
||||
EmbedError(#[from] EmbedError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{fault}: {kind}")]
|
||||
pub struct EmbedError {
|
||||
pub kind: EmbedErrorKind,
|
||||
pub fault: FaultSource,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum EmbedErrorKind {
|
||||
#[error("could not tokenize: {0}")]
|
||||
Tokenize(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("unexpected tensor shape: {0}")]
|
||||
TensorShape(candle_core::Error),
|
||||
#[error("unexpected tensor value: {0}")]
|
||||
TensorValue(candle_core::Error),
|
||||
#[error("could not run model: {0}")]
|
||||
ModelForward(candle_core::Error),
|
||||
}
|
||||
|
||||
impl EmbedError {
|
||||
pub fn tokenize(inner: Box<dyn std::error::Error + Send + Sync>) -> Self {
|
||||
Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn tensor_shape(inner: candle_core::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub fn tensor_value(inner: candle_core::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub fn model_forward(inner: candle_core::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{fault}: {kind}")]
|
||||
pub struct NewEmbedderError {
|
||||
pub kind: NewEmbedderErrorKind,
|
||||
pub fault: FaultSource,
|
||||
}
|
||||
|
||||
impl NewEmbedderError {
|
||||
pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError {
|
||||
let open_config = OpenConfig { filename: config_filename, inner };
|
||||
|
||||
Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn deserialize_config(
|
||||
config: String,
|
||||
config_filename: PathBuf,
|
||||
inner: serde_json::Error,
|
||||
) -> NewEmbedderError {
|
||||
let deserialize_config = DeserializeConfig { config, filename: config_filename, inner };
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open_tokenizer(
|
||||
tokenizer_filename: PathBuf,
|
||||
inner: Box<dyn std::error::Error + Send + Sync>,
|
||||
) -> NewEmbedderError {
|
||||
let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner };
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_api_fail(inner: ApiError) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub fn api_get(inner: ApiError) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided }
|
||||
}
|
||||
|
||||
pub fn pytorch_weight(inner: candle_core::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn safetensor_weight(inner: candle_core::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn load_model(inner: candle_core::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("could not open config at {filename:?}: {inner}")]
|
||||
pub struct OpenConfig {
|
||||
pub filename: PathBuf,
|
||||
pub inner: std::io::Error,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")]
|
||||
pub struct DeserializeConfig {
|
||||
pub config: String,
|
||||
pub filename: PathBuf,
|
||||
pub inner: serde_json::Error,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("could not open tokenizer at {filename}: {inner}")]
|
||||
pub struct OpenTokenizer {
|
||||
pub filename: PathBuf,
|
||||
#[source]
|
||||
pub inner: Box<dyn std::error::Error + Send + Sync>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum NewEmbedderErrorKind {
|
||||
#[error(transparent)]
|
||||
OpenConfig(OpenConfig),
|
||||
#[error(transparent)]
|
||||
DeserializeConfig(DeserializeConfig),
|
||||
#[error(transparent)]
|
||||
OpenTokenizer(OpenTokenizer),
|
||||
#[error("could not build weights from Pytorch weights: {0}")]
|
||||
PytorchWeight(candle_core::Error),
|
||||
#[error("could not build weights from Safetensor weights: {0}")]
|
||||
SafetensorWeight(candle_core::Error),
|
||||
#[error("could not spawn HG_HUB API client: {0}")]
|
||||
NewApiFail(ApiError),
|
||||
#[error("fetching file from HG_HUB failed: {0}")]
|
||||
ApiGet(ApiError),
|
||||
#[error("loading model failed: {0}")]
|
||||
LoadModel(candle_core::Error),
|
||||
}
|
154
milli/src/vector/mod.rs
Normal file
154
milli/src/vector/mod.rs
Normal file
@ -0,0 +1,154 @@
|
||||
use candle_core::Tensor;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
||||
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use tokenizers::{PaddingParams, Tokenizer};
|
||||
|
||||
pub use self::error::{EmbedError, Error, NewEmbedderError};
|
||||
|
||||
mod error;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub enum WeightSource {
|
||||
#[default]
|
||||
Safetensors,
|
||||
Pytorch,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EmbedderOptions {
|
||||
pub model: String,
|
||||
pub revision: Option<String>,
|
||||
pub weight_source: WeightSource,
|
||||
pub normalize_embeddings: bool,
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
|
||||
//model: "BAAI/bge-base-en-v1.5".to_string(),
|
||||
revision: Some("refs/pr/21".to_string()),
|
||||
//revision: None,
|
||||
weight_source: Default::default(),
|
||||
//weight_source: WeightSource::Pytorch,
|
||||
normalize_embeddings: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EmbedderOptions {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform embedding of documents and queries
|
||||
pub struct Embedder {
|
||||
model: BertModel,
|
||||
tokenizer: Tokenizer,
|
||||
options: EmbedderOptions,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Embedder {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Embedder")
|
||||
.field("model", &self.options.model)
|
||||
.field("tokenizer", &self.tokenizer)
|
||||
.field("options", &self.options)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||
let device = candle_core::Device::Cpu;
|
||||
let repo = match options.revision.clone() {
|
||||
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
||||
None => Repo::model(options.model.clone()),
|
||||
};
|
||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
|
||||
let api = api.repo(repo);
|
||||
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
|
||||
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
|
||||
let weights = match options.weight_source {
|
||||
WeightSource::Pytorch => {
|
||||
api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)?
|
||||
}
|
||||
WeightSource::Safetensors => {
|
||||
api.get("model.safetensors").map_err(NewEmbedderError::api_get)?
|
||||
}
|
||||
};
|
||||
(config, tokenizer, weights)
|
||||
};
|
||||
|
||||
let config = std::fs::read_to_string(&config_filename)
|
||||
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
|
||||
let config: Config = serde_json::from_str(&config).map_err(|inner| {
|
||||
NewEmbedderError::deserialize_config(config, config_filename, inner)
|
||||
})?;
|
||||
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
||||
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
||||
|
||||
let vb = match options.weight_source {
|
||||
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
|
||||
.map_err(NewEmbedderError::pytorch_weight)?,
|
||||
WeightSource::Safetensors => unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
|
||||
.map_err(NewEmbedderError::safetensor_weight)?
|
||||
},
|
||||
};
|
||||
|
||||
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
|
||||
|
||||
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||
} else {
|
||||
let pp = PaddingParams {
|
||||
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
||||
..Default::default()
|
||||
};
|
||||
tokenizer.with_padding(Some(pp));
|
||||
}
|
||||
|
||||
Ok(Self { model, tokenizer, options })
|
||||
}
|
||||
|
||||
pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Vec<f32>>, EmbedError> {
|
||||
let tokens = self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?;
|
||||
let token_ids = tokens
|
||||
.iter()
|
||||
.map(|tokens| {
|
||||
let tokens = tokens.get_ids().to_vec();
|
||||
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
|
||||
})
|
||||
.collect::<Result<Vec<_>, EmbedError>>()?;
|
||||
|
||||
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
|
||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||
let embeddings =
|
||||
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
||||
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) =
|
||||
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
||||
|
||||
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||
.map_err(EmbedError::tensor_shape)?;
|
||||
let embeddings: Tensor = if self.options.normalize_embeddings {
|
||||
normalize_l2(&embeddings).map_err(EmbedError::tensor_value)?
|
||||
} else {
|
||||
embeddings
|
||||
};
|
||||
|
||||
let embeddings = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
||||
Ok(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
|
||||
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user