From 84f49d76cd2869cc4af08d835d6182bfc2c9e042 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Sun, 7 Jan 2024 21:27:29 +0100 Subject: [PATCH] Add cuda feature --- Cargo.lock | 21 +++++++++++++++++++++ milli/Cargo.toml | 3 +++ milli/src/vector/hf.rs | 2 +- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index bf3851db5..f1fc93b1d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -706,6 +706,8 @@ version = "0.3.3" source = "git+https://github.com/huggingface/candle.git#5270224f407502b82fe90bc2622894ce3871b002" dependencies = [ "byteorder", + "candle-kernels", + "cudarc", "gemm", "half 2.3.1", "memmap2 0.9.3", @@ -720,6 +722,16 @@ dependencies = [ "zip", ] +[[package]] +name = "candle-kernels" +version = "0.3.1" +source = "git+https://github.com/huggingface/candle.git#f4fcf6090045ac44122fd5f0a7e46db6e3e16528" +dependencies = [ + "anyhow", + "glob", + "rayon", +] + [[package]] name = "candle-nn" version = "0.3.3" @@ -1163,6 +1175,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "cudarc" +version = "0.9.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1871a911a2b9a3f66a285896a719159985683bf9903aa2cf89e0c9f53e14552" +dependencies = [ + "half 2.3.1", +] + [[package]] name = "darling" version = "0.14.4" diff --git a/milli/Cargo.toml b/milli/Cargo.toml index ec27b5f39..047a30e35 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -137,3 +137,6 @@ greek = ["charabia/greek"] # allow khmer specialized tokenization khmer = ["charabia/khmer"] + +# allow CUDA support +cuda = ["candle-core/cuda"] diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index 0a6bcbe93..3a3949e77 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -70,7 +70,7 @@ impl std::fmt::Debug for Embedder { impl Embedder { pub fn new(options: EmbedderOptions) -> std::result::Result { - let device = candle_core::Device::Cpu; + let device = candle_core::Device::cuda_if_available(0).unwrap(); let repo = match options.revision.clone() { Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), None => Repo::model(options.model.clone()),