Add cuda feature

This commit is contained in:
Louis Dureuil 2024-01-07 21:27:29 +01:00
parent 262b20fdba
commit 84f49d76cd
No known key found for this signature in database
3 changed files with 25 additions and 1 deletions

View file

@ -70,7 +70,7 @@ impl std::fmt::Debug for Embedder {
impl Embedder {
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
let device = candle_core::Device::Cpu;
let device = candle_core::Device::cuda_if_available(0).unwrap();
let repo = match options.revision.clone() {
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
None => Repo::model(options.model.clone()),