mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-05-14 16:23:57 +02:00
Add milli::vector module
This commit is contained in:
parent
2e1903bd50
commit
9ae4dee202
@ -22,6 +22,7 @@ mod readable_slices;
|
|||||||
pub mod score_details;
|
pub mod score_details;
|
||||||
mod search;
|
mod search;
|
||||||
pub mod update;
|
pub mod update;
|
||||||
|
pub mod vector;
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[macro_use]
|
#[macro_use]
|
||||||
|
191
milli/src/vector/error.rs
Normal file
191
milli/src/vector/error.rs
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
use std::fmt::Display;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
|
||||||
|
use hf_hub::api::sync::ApiError;
|
||||||
|
|
||||||
|
/// Attribution of an embedding-related error: who or what is considered responsible for it.
///
/// The `Display` implementation of this type renders a short label that is used as the prefix
/// of the error messages carrying a `FaultSource`.
#[derive(Debug, Clone, Copy)]
pub enum FaultSource {
    /// The error is attributed to the user; displayed as `user error`.
    User,
    /// The error is attributed to the runtime environment; displayed as `runtime error`.
    Runtime,
    /// The error is attributed to a programming mistake; displayed as `coding error`.
    Bug,
    /// No attribution has been made; displayed as plain `error`.
    Undecided,
}
|
||||||
|
|
||||||
|
impl Display for FaultSource {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let s = match self {
|
||||||
|
FaultSource::User => "user error",
|
||||||
|
FaultSource::Runtime => "runtime error",
|
||||||
|
FaultSource::Bug => "coding error",
|
||||||
|
FaultSource::Undecided => "error",
|
||||||
|
};
|
||||||
|
f.write_str(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("Error while generating embeddings: {inner}")]
|
||||||
|
pub struct Error {
|
||||||
|
pub inner: Box<ErrorKind>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I: Into<ErrorKind>> From<I> for Error {
|
||||||
|
fn from(value: I) -> Self {
|
||||||
|
Self { inner: Box::new(value.into()) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error {
|
||||||
|
pub fn fault(&self) -> FaultSource {
|
||||||
|
match &*self.inner {
|
||||||
|
ErrorKind::NewEmbedderError(inner) => inner.fault,
|
||||||
|
ErrorKind::EmbedError(inner) => inner.fault,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum ErrorKind {
|
||||||
|
#[error(transparent)]
|
||||||
|
NewEmbedderError(#[from] NewEmbedderError),
|
||||||
|
#[error(transparent)]
|
||||||
|
EmbedError(#[from] EmbedError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("{fault}: {kind}")]
|
||||||
|
pub struct EmbedError {
|
||||||
|
pub kind: EmbedErrorKind,
|
||||||
|
pub fault: FaultSource,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum EmbedErrorKind {
|
||||||
|
#[error("could not tokenize: {0}")]
|
||||||
|
Tokenize(Box<dyn std::error::Error + Send + Sync>),
|
||||||
|
#[error("unexpected tensor shape: {0}")]
|
||||||
|
TensorShape(candle_core::Error),
|
||||||
|
#[error("unexpected tensor value: {0}")]
|
||||||
|
TensorValue(candle_core::Error),
|
||||||
|
#[error("could not run model: {0}")]
|
||||||
|
ModelForward(candle_core::Error),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbedError {
|
||||||
|
pub fn tokenize(inner: Box<dyn std::error::Error + Send + Sync>) -> Self {
|
||||||
|
Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tensor_shape(inner: candle_core::Error) -> Self {
|
||||||
|
Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tensor_value(inner: candle_core::Error) -> Self {
|
||||||
|
Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn model_forward(inner: candle_core::Error) -> Self {
|
||||||
|
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("{fault}: {kind}")]
|
||||||
|
pub struct NewEmbedderError {
|
||||||
|
pub kind: NewEmbedderErrorKind,
|
||||||
|
pub fault: FaultSource,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NewEmbedderError {
|
||||||
|
pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError {
|
||||||
|
let open_config = OpenConfig { filename: config_filename, inner };
|
||||||
|
|
||||||
|
Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn deserialize_config(
|
||||||
|
config: String,
|
||||||
|
config_filename: PathBuf,
|
||||||
|
inner: serde_json::Error,
|
||||||
|
) -> NewEmbedderError {
|
||||||
|
let deserialize_config = DeserializeConfig { config, filename: config_filename, inner };
|
||||||
|
Self {
|
||||||
|
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
|
||||||
|
fault: FaultSource::Runtime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn open_tokenizer(
|
||||||
|
tokenizer_filename: PathBuf,
|
||||||
|
inner: Box<dyn std::error::Error + Send + Sync>,
|
||||||
|
) -> NewEmbedderError {
|
||||||
|
let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner };
|
||||||
|
Self {
|
||||||
|
kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer),
|
||||||
|
fault: FaultSource::Runtime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_api_fail(inner: ApiError) -> Self {
|
||||||
|
Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn api_get(inner: ApiError) -> Self {
|
||||||
|
Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn pytorch_weight(inner: candle_core::Error) -> Self {
|
||||||
|
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn safetensor_weight(inner: candle_core::Error) -> Self {
|
||||||
|
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_model(inner: candle_core::Error) -> Self {
|
||||||
|
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("could not open config at {filename:?}: {inner}")]
|
||||||
|
pub struct OpenConfig {
|
||||||
|
pub filename: PathBuf,
|
||||||
|
pub inner: std::io::Error,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")]
|
||||||
|
pub struct DeserializeConfig {
|
||||||
|
pub config: String,
|
||||||
|
pub filename: PathBuf,
|
||||||
|
pub inner: serde_json::Error,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
#[error("could not open tokenizer at {filename}: {inner}")]
|
||||||
|
pub struct OpenTokenizer {
|
||||||
|
pub filename: PathBuf,
|
||||||
|
#[source]
|
||||||
|
pub inner: Box<dyn std::error::Error + Send + Sync>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum NewEmbedderErrorKind {
|
||||||
|
#[error(transparent)]
|
||||||
|
OpenConfig(OpenConfig),
|
||||||
|
#[error(transparent)]
|
||||||
|
DeserializeConfig(DeserializeConfig),
|
||||||
|
#[error(transparent)]
|
||||||
|
OpenTokenizer(OpenTokenizer),
|
||||||
|
#[error("could not build weights from Pytorch weights: {0}")]
|
||||||
|
PytorchWeight(candle_core::Error),
|
||||||
|
#[error("could not build weights from Safetensor weights: {0}")]
|
||||||
|
SafetensorWeight(candle_core::Error),
|
||||||
|
#[error("could not spawn HG_HUB API client: {0}")]
|
||||||
|
NewApiFail(ApiError),
|
||||||
|
#[error("fetching file from HG_HUB failed: {0}")]
|
||||||
|
ApiGet(ApiError),
|
||||||
|
#[error("loading model failed: {0}")]
|
||||||
|
LoadModel(candle_core::Error),
|
||||||
|
}
|
154
milli/src/vector/mod.rs
Normal file
154
milli/src/vector/mod.rs
Normal file
@ -0,0 +1,154 @@
|
|||||||
|
use candle_core::Tensor;
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
||||||
|
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
|
||||||
|
use hf_hub::api::sync::Api;
|
||||||
|
use hf_hub::{Repo, RepoType};
|
||||||
|
use tokenizers::{PaddingParams, Tokenizer};
|
||||||
|
|
||||||
|
pub use self::error::{EmbedError, Error, NewEmbedderError};
|
||||||
|
|
||||||
|
mod error;
|
||||||
|
|
||||||
|
/// File format from which the model weights are loaded.
#[derive(Debug, Default)]
pub enum WeightSource {
    /// Load the weights from a Safetensors file. This is the default.
    #[default]
    Safetensors,
    /// Load the weights from a Pytorch file.
    Pytorch,
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct EmbedderOptions {
|
||||||
|
pub model: String,
|
||||||
|
pub revision: Option<String>,
|
||||||
|
pub weight_source: WeightSource,
|
||||||
|
pub normalize_embeddings: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EmbedderOptions {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
|
||||||
|
//model: "BAAI/bge-base-en-v1.5".to_string(),
|
||||||
|
revision: Some("refs/pr/21".to_string()),
|
||||||
|
//revision: None,
|
||||||
|
weight_source: Default::default(),
|
||||||
|
//weight_source: WeightSource::Pytorch,
|
||||||
|
normalize_embeddings: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for EmbedderOptions {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Perform embedding of documents and queries
|
||||||
|
pub struct Embedder {
|
||||||
|
model: BertModel,
|
||||||
|
tokenizer: Tokenizer,
|
||||||
|
options: EmbedderOptions,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for Embedder {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("Embedder")
|
||||||
|
.field("model", &self.options.model)
|
||||||
|
.field("tokenizer", &self.tokenizer)
|
||||||
|
.field("options", &self.options)
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Embedder {
|
||||||
|
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||||
|
let device = candle_core::Device::Cpu;
|
||||||
|
let repo = match options.revision.clone() {
|
||||||
|
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
||||||
|
None => Repo::model(options.model.clone()),
|
||||||
|
};
|
||||||
|
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||||
|
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
|
||||||
|
let api = api.repo(repo);
|
||||||
|
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
|
||||||
|
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
|
||||||
|
let weights = match options.weight_source {
|
||||||
|
WeightSource::Pytorch => {
|
||||||
|
api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)?
|
||||||
|
}
|
||||||
|
WeightSource::Safetensors => {
|
||||||
|
api.get("model.safetensors").map_err(NewEmbedderError::api_get)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
(config, tokenizer, weights)
|
||||||
|
};
|
||||||
|
|
||||||
|
let config = std::fs::read_to_string(&config_filename)
|
||||||
|
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
|
||||||
|
let config: Config = serde_json::from_str(&config).map_err(|inner| {
|
||||||
|
NewEmbedderError::deserialize_config(config, config_filename, inner)
|
||||||
|
})?;
|
||||||
|
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
||||||
|
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
||||||
|
|
||||||
|
let vb = match options.weight_source {
|
||||||
|
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
|
||||||
|
.map_err(NewEmbedderError::pytorch_weight)?,
|
||||||
|
WeightSource::Safetensors => unsafe {
|
||||||
|
VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
|
||||||
|
.map_err(NewEmbedderError::safetensor_weight)?
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
|
||||||
|
|
||||||
|
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||||
|
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||||
|
} else {
|
||||||
|
let pp = PaddingParams {
|
||||||
|
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
tokenizer.with_padding(Some(pp));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self { model, tokenizer, options })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn embed(&self, texts: Vec<String>) -> std::result::Result<Vec<Vec<f32>>, EmbedError> {
|
||||||
|
let tokens = self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?;
|
||||||
|
let token_ids = tokens
|
||||||
|
.iter()
|
||||||
|
.map(|tokens| {
|
||||||
|
let tokens = tokens.get_ids().to_vec();
|
||||||
|
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, EmbedError>>()?;
|
||||||
|
|
||||||
|
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
|
||||||
|
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||||
|
let embeddings =
|
||||||
|
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
||||||
|
|
||||||
|
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||||
|
let (_n_sentence, n_tokens, _hidden_size) =
|
||||||
|
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
||||||
|
|
||||||
|
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||||
|
.map_err(EmbedError::tensor_shape)?;
|
||||||
|
let embeddings: Tensor = if self.options.normalize_embeddings {
|
||||||
|
normalize_l2(&embeddings).map_err(EmbedError::tensor_value)?
|
||||||
|
} else {
|
||||||
|
embeddings
|
||||||
|
};
|
||||||
|
|
||||||
|
let embeddings = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
||||||
|
Ok(embeddings)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Divides `v` by its L2 norm computed along dimension 1 (kept as a dimension so the
/// division broadcasts back over `v`).
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
    let squared = v.sqr()?;
    let norms = squared.sum_keepdim(1)?.sqrt()?;
    v.broadcast_div(&norms)
}
|
Loading…
x
Reference in New Issue
Block a user