Add cuda feature

This commit is contained in:
Louis Dureuil 2024-01-07 21:27:29 +01:00
parent 262b20fdba
commit 84f49d76cd
No known key found for this signature in database
3 changed files with 25 additions and 1 deletions

21
Cargo.lock generated
View File

@ -706,6 +706,8 @@ version = "0.3.3"
source = "git+https://github.com/huggingface/candle.git#5270224f407502b82fe90bc2622894ce3871b002" source = "git+https://github.com/huggingface/candle.git#5270224f407502b82fe90bc2622894ce3871b002"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"candle-kernels",
"cudarc",
"gemm", "gemm",
"half 2.3.1", "half 2.3.1",
"memmap2 0.9.3", "memmap2 0.9.3",
@ -720,6 +722,16 @@ dependencies = [
"zip", "zip",
] ]
[[package]]
name = "candle-kernels"
version = "0.3.1"
source = "git+https://github.com/huggingface/candle.git#f4fcf6090045ac44122fd5f0a7e46db6e3e16528"
dependencies = [
"anyhow",
"glob",
"rayon",
]
[[package]] [[package]]
name = "candle-nn" name = "candle-nn"
version = "0.3.3" version = "0.3.3"
@ -1163,6 +1175,15 @@ dependencies = [
"memchr", "memchr",
] ]
[[package]]
name = "cudarc"
version = "0.9.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1871a911a2b9a3f66a285896a719159985683bf9903aa2cf89e0c9f53e14552"
dependencies = [
"half 2.3.1",
]
[[package]] [[package]]
name = "darling" name = "darling"
version = "0.14.4" version = "0.14.4"

View File

@ -137,3 +137,6 @@ greek = ["charabia/greek"]
# allow khmer specialized tokenization # allow khmer specialized tokenization
khmer = ["charabia/khmer"] khmer = ["charabia/khmer"]
# allow CUDA support
cuda = ["candle-core/cuda"]

View File

@ -70,7 +70,7 @@ impl std::fmt::Debug for Embedder {
impl Embedder { impl Embedder {
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> { pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
let device = candle_core::Device::Cpu; let device = candle_core::Device::cuda_if_available(0).unwrap();
let repo = match options.revision.clone() { let repo = match options.revision.clone() {
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
None => Repo::model(options.model.clone()), None => Repo::model(options.model.clone()),