Remove some settings

Louis Dureuil 2023-12-13 23:09:50 +01:00
parent 3c1a14f1cd
commit 5b51cb04af
2 changed files with 15 additions and 53 deletions

@@ -23,7 +23,7 @@ use super::{Embedding, Embeddings};
 )]
 #[serde(deny_unknown_fields, rename_all = "camelCase")]
 #[deserr(rename_all = camelCase, deny_unknown_fields)]
-pub enum WeightSource {
+enum WeightSource {
     #[default]
     Safetensors,
     Pytorch,
@@ -33,20 +33,13 @@ pub enum WeightSource {
 pub struct EmbedderOptions {
     pub model: String,
     pub revision: Option<String>,
-    pub weight_source: WeightSource,
-    pub normalize_embeddings: bool,
 }

 impl EmbedderOptions {
     pub fn new() -> Self {
         Self {
-            //model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
             model: "BAAI/bge-base-en-v1.5".to_string(),
-            //revision: Some("refs/pr/21".to_string()),
-            revision: None,
-            //weight_source: Default::default(),
-            weight_source: WeightSource::Pytorch,
-            normalize_embeddings: true,
+            revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
         }
     }
 }
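
For context, the revision pinned in the new default feeds straight into the hf_hub Repo that the construction code (next hunk) already builds via Repo::with_revision. A minimal standalone sketch, not part of the commit, assuming the hf_hub crate's blocking API:

// Sketch only: resolve the default model at the pinned revision and
// download its tokenizer into the local Hugging Face cache.
use hf_hub::{api::sync::Api, Repo, RepoType};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let repo = Repo::with_revision(
        "BAAI/bge-base-en-v1.5".to_string(),
        RepoType::Model,
        "617ca489d9e86b49b8167676d8220688b99db36e".to_string(),
    );
    let api = Api::new()?.repo(repo);
    let tokenizer_path = api.get("tokenizer.json")?;
    println!("tokenizer at {}", tokenizer_path.display());
    Ok(())
}
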
@@ -82,20 +75,21 @@ impl Embedder {
             Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
             None => Repo::model(options.model.clone()),
         };
-        let (config_filename, tokenizer_filename, weights_filename) = {
+        let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
             let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
             let api = api.repo(repo);
             let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
             let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
-            let weights = match options.weight_source {
-                WeightSource::Pytorch => {
-                    api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)?
-                }
-                WeightSource::Safetensors => {
-                    api.get("model.safetensors").map_err(NewEmbedderError::api_get)?
-                }
+            let (weights, source) = {
+                api.get("pytorch_model.bin")
+                    .map(|filename| (filename, WeightSource::Pytorch))
+                    .or_else(|_| {
+                        api.get("model.safetensors")
+                            .map(|filename| (filename, WeightSource::Safetensors))
+                    })
+                    .map_err(NewEmbedderError::api_get)?
             };
-            (config, tokenizer, weights)
+            (config, tokenizer, weights, source)
         };

         let config = std::fs::read_to_string(&config_filename)
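
The new lookup no longer reads a weightSource setting: it tries the PyTorch checkpoint first, falls back to safetensors, and records which format was actually found. A self-contained sketch of that fallback pattern, not taken from the repository, assuming hf_hub's blocking API with error handling simplified:

// Sketch only: probe for weight files in order and remember which
// format was found, mirroring the fallback introduced above.
use std::path::PathBuf;

use hf_hub::api::sync::{Api, ApiError};

#[derive(Debug, Clone, Copy)]
enum WeightSource {
    Safetensors,
    Pytorch,
}

fn fetch_weights(model_id: &str) -> Result<(PathBuf, WeightSource), ApiError> {
    let api = Api::new()?;
    let repo = api.model(model_id.to_string());
    // Prefer the PyTorch checkpoint; if the repo only ships safetensors, use that.
    repo.get("pytorch_model.bin")
        .map(|path| (path, WeightSource::Pytorch))
        .or_else(|_| repo.get("model.safetensors").map(|path| (path, WeightSource::Safetensors)))
}

The next hunk then matches on the detected weight_source instead of options.weight_source when building the VarBuilder.
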
@@ -106,7 +100,7 @@ impl Embedder {
         let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
             .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;

-        let vb = match options.weight_source {
+        let vb = match weight_source {
             WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
                 .map_err(NewEmbedderError::pytorch_weight)?,
             WeightSource::Safetensors => unsafe {
@@ -168,12 +162,6 @@ impl Embedder {
         let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
             .map_err(EmbedError::tensor_shape)?;

-        let embeddings: Tensor = if self.options.normalize_embeddings {
-            normalize_l2(&embeddings).map_err(EmbedError::tensor_value)?
-        } else {
-            embeddings
-        };
-
         let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
         Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
     }
@@ -197,7 +185,3 @@ impl Embedder {
         self.dimensions
     }
 }
-
-fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
-    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
-}
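
For reference, the removed helper rescaled each row of the embedding matrix to unit L2 norm (divide by the square root of the row-wise sum of squares). A standalone candle_core sketch of the same operation, kept here only to document what was dropped:

// Sketch only (the helper removed above): L2-normalize each row of a 2-D tensor.
use candle_core::{Device, Tensor};

fn normalize_l2(v: &Tensor) -> candle_core::Result<Tensor> {
    // Row norms: sqrt of the sum of squares along dim 1, kept as a column for broadcasting.
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

fn main() -> candle_core::Result<()> {
    let rows = Tensor::new(&[[3.0f32, 4.0], [0.0, 5.0]], &Device::Cpu)?;
    let unit = normalize_l2(&rows)?;
    // Each row now has length 1, e.g. [0.6, 0.8] and [0.0, 1.0].
    println!("{:?}", unit.to_vec2::<f32>()?);
    Ok(())
}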