mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-12-23 05:00:06 +01:00
Remove some settings
This commit is contained in:
parent
3c1a14f1cd
commit
5b51cb04af
@ -23,7 +23,7 @@ use super::{Embedding, Embeddings};
|
||||
)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub enum WeightSource {
|
||||
enum WeightSource {
|
||||
#[default]
|
||||
Safetensors,
|
||||
Pytorch,
|
||||
@ -33,20 +33,13 @@ pub enum WeightSource {
|
||||
pub struct EmbedderOptions {
|
||||
pub model: String,
|
||||
pub revision: Option<String>,
|
||||
pub weight_source: WeightSource,
|
||||
pub normalize_embeddings: bool,
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
//model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
|
||||
model: "BAAI/bge-base-en-v1.5".to_string(),
|
||||
//revision: Some("refs/pr/21".to_string()),
|
||||
revision: None,
|
||||
//weight_source: Default::default(),
|
||||
weight_source: WeightSource::Pytorch,
|
||||
normalize_embeddings: true,
|
||||
revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -82,20 +75,21 @@ impl Embedder {
|
||||
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
||||
None => Repo::model(options.model.clone()),
|
||||
};
|
||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||
let (config_filename, tokenizer_filename, weights_filename, weight_source) = {
|
||||
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
|
||||
let api = api.repo(repo);
|
||||
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
|
||||
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
|
||||
let weights = match options.weight_source {
|
||||
WeightSource::Pytorch => {
|
||||
api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)?
|
||||
}
|
||||
WeightSource::Safetensors => {
|
||||
api.get("model.safetensors").map_err(NewEmbedderError::api_get)?
|
||||
}
|
||||
let (weights, source) = {
|
||||
api.get("pytorch_model.bin")
|
||||
.map(|filename| (filename, WeightSource::Pytorch))
|
||||
.or_else(|_| {
|
||||
api.get("model.safetensors")
|
||||
.map(|filename| (filename, WeightSource::Safetensors))
|
||||
})
|
||||
.map_err(NewEmbedderError::api_get)?
|
||||
};
|
||||
(config, tokenizer, weights)
|
||||
(config, tokenizer, weights, source)
|
||||
};
|
||||
|
||||
let config = std::fs::read_to_string(&config_filename)
|
||||
@ -106,7 +100,7 @@ impl Embedder {
|
||||
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
||||
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
||||
|
||||
let vb = match options.weight_source {
|
||||
let vb = match weight_source {
|
||||
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
|
||||
.map_err(NewEmbedderError::pytorch_weight)?,
|
||||
WeightSource::Safetensors => unsafe {
|
||||
@ -168,12 +162,6 @@ impl Embedder {
|
||||
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||
.map_err(EmbedError::tensor_shape)?;
|
||||
|
||||
let embeddings: Tensor = if self.options.normalize_embeddings {
|
||||
normalize_l2(&embeddings).map_err(EmbedError::tensor_value)?
|
||||
} else {
|
||||
embeddings
|
||||
};
|
||||
|
||||
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
||||
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
|
||||
}
|
||||
@ -197,7 +185,3 @@ impl Embedder {
|
||||
self.dimensions
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
|
||||
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::prompt::PromptData;
|
||||
use crate::update::Setting;
|
||||
use crate::vector::hf::WeightSource;
|
||||
use crate::vector::EmbeddingConfig;
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
||||
@ -204,26 +203,13 @@ pub struct HfEmbedderSettings {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub revision: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub weight_source: Setting<WeightSource>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub normalize_embeddings: Setting<bool>,
|
||||
}
|
||||
|
||||
impl HfEmbedderSettings {
|
||||
pub fn apply(&mut self, new: Self) {
|
||||
let HfEmbedderSettings {
|
||||
model,
|
||||
revision,
|
||||
weight_source,
|
||||
normalize_embeddings: normalize_embedding,
|
||||
} = new;
|
||||
let HfEmbedderSettings { model, revision } = new;
|
||||
self.model.apply(model);
|
||||
self.revision.apply(revision);
|
||||
self.weight_source.apply(weight_source);
|
||||
self.normalize_embeddings.apply(normalize_embedding);
|
||||
}
|
||||
}
|
||||
|
||||
@ -232,15 +218,13 @@ impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings {
|
||||
Self {
|
||||
model: Setting::Set(value.model),
|
||||
revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet),
|
||||
weight_source: Setting::Set(value.weight_source),
|
||||
normalize_embeddings: Setting::Set(value.normalize_embeddings),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
|
||||
fn from(value: HfEmbedderSettings) -> Self {
|
||||
let HfEmbedderSettings { model, revision, weight_source, normalize_embeddings } = value;
|
||||
let HfEmbedderSettings { model, revision } = value;
|
||||
let mut this = Self::default();
|
||||
if let Some(model) = model.set() {
|
||||
this.model = model;
|
||||
@ -248,12 +232,6 @@ impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
|
||||
if let Some(revision) = revision.set() {
|
||||
this.revision = Some(revision);
|
||||
}
|
||||
if let Some(weight_source) = weight_source.set() {
|
||||
this.weight_source = weight_source;
|
||||
}
|
||||
if let Some(normalize_embeddings) = normalize_embeddings.set() {
|
||||
this.normalize_embeddings = normalize_embeddings;
|
||||
}
|
||||
this
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user