4549: Hugging Face embedder improvements r=dureuill a=dureuill

Architectural changes/Internal improvements

### 1. Prefer safetensors weights over pytorch weights when available

safetensors weights are memory mapped, which reduces memory usage of supported models.

### 2. Update candle

Updates candle to `0.4.1`, now targeting crates.io and the tokenizers to `v0.15.2` (still on github).

This might fix https://github.com/meilisearch/meilisearch/issues/4399 thanks to the now included https://github.com/huggingface/candle/issues/1454

Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
meili-bors[bot] 2024-04-04 13:47:18 +00:00 committed by GitHub
commit 339a5e3431
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 186 additions and 203 deletions

373
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@ -71,10 +71,10 @@ itertools = "0.11.0"
puffin = "0.16.0" puffin = "0.16.0"
csv = "1.3.0" csv = "1.3.0"
candle-core = { git = "https://github.com/huggingface/candle.git", rev="5270224f407502b82fe90bc2622894ce3871b002", version = "0.3.3" } candle-core = { version = "0.4.1" }
candle-transformers = { git = "https://github.com/huggingface/candle.git", rev="5270224f407502b82fe90bc2622894ce3871b002", version = "0.3.3" } candle-transformers = { version = "0.4.1" }
candle-nn = { git = "https://github.com/huggingface/candle.git", rev="5270224f407502b82fe90bc2622894ce3871b002", version = "0.3.3" } candle-nn = { version = "0.4.1" }
tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1", default_features = false, features = [ tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.15.2", version = "0.15.2", default_features = false, features = [
"onig", "onig",
] } ] }
hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [ hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [

View File

@ -89,11 +89,11 @@ impl Embedder {
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?; let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?; let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
let (weights, source) = { let (weights, source) = {
api.get("pytorch_model.bin") api.get("model.safetensors")
.map(|filename| (filename, WeightSource::Pytorch)) .map(|filename| (filename, WeightSource::Safetensors))
.or_else(|_| { .or_else(|_| {
api.get("model.safetensors") api.get("pytorch_model.bin")
.map(|filename| (filename, WeightSource::Safetensors)) .map(|filename| (filename, WeightSource::Pytorch))
}) })
.map_err(NewEmbedderError::api_get)? .map_err(NewEmbedderError::api_get)?
}; };