From f692021bfcbd98eec2b8833d164d287e83a2322f Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Mon, 22 Jan 2024 10:13:27 +0100 Subject: [PATCH] Implement PR comments --- milli/Cargo.toml | 2 +- milli/src/vector/hf.rs | 8 +++++++- xtask/Cargo.toml | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 047a30e35..dc2b992e0 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -138,5 +138,5 @@ greek = ["charabia/greek"] # allow khmer specialized tokenization khmer = ["charabia/khmer"] -# allow CUDA support +# allow CUDA support, see cuda = ["candle-core/cuda"] diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index 3a3949e77..7acb09aa8 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -70,7 +70,13 @@ impl std::fmt::Debug for Embedder { impl Embedder { pub fn new(options: EmbedderOptions) -> std::result::Result { - let device = candle_core::Device::cuda_if_available(0).unwrap(); + let device = match candle_core::Device::cuda_if_available(0) { + Ok(device) => device, + Err(error) => { + log::warn!("could not initialize CUDA device for Hugging Face embedder, defaulting to CPU: {}", error); + candle_core::Device::Cpu + } + }; let repo = match options.revision.clone() { Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), None => Repo::model(options.model.clone()), diff --git a/xtask/Cargo.toml b/xtask/Cargo.toml index af9ecc7b5..07271ea09 100644 --- a/xtask/Cargo.toml +++ b/xtask/Cargo.toml @@ -2,7 +2,7 @@ name = "xtask" version.workspace = true authors.workspace = true -description.workspace = true +description = "Workspace automation tool following the xtask pattern " homepage.workspace = true readme.workspace = true edition.workspace = true