mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-12-22 20:50:04 +01:00
Merge #4304
4304: Add CUDA GPU support for Hugging Face embedders r=Kerollmops a=dureuill Adds a "cuda" feature to `milli`. Compiling with this feature requires that the CUDA support library be installed (see "with CUDA support" paragraph in https://huggingface.github.io/candle/guide/installation.html), and adds CUDA support to the `huggingFace` embedder. To enable GPU support, users will need to: 1. Have a compatible NVidia GPU under Linux 2. Follow [the guide](https://huggingface.github.io/candle/guide/installation.html) to install the CUDA dependencies 3. Compile Meilisearch with the `cuda` feature: `cargo build --release --features cuda` # Impact Enabling the CUDA feature allows to use an available GPU to compute embeddings with a `huggingFace` embedder. On an AWS Graviton 2, this yields a x3 - x5 improvement on indexing time. # Technical details - I had to change the CI so that the cuda feature is not included in the `Tests all features` workflow - To achieve that, I had to add a binary following the `cargo xtask` design pattern, to list all features excepted the cuda one. - I then changed the workflow accordingly (renamed to "Tests almost all features" 😉) - A test run of the new feature was done on a temporary version of this PR that had it enabled for PRs: [See the results here](https://github.com/meilisearch/meilisearch/actions/runs/7461331929/job/20301216732) Co-authored-by: Louis Dureuil <louis@meilisearch.com>
This commit is contained in:
commit
b6fc181993
2
.cargo/config.toml
Normal file
2
.cargo/config.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[alias]
|
||||
xtask = "run --package xtask --"
|
18
.github/workflows/test-suite.yml
vendored
18
.github/workflows/test-suite.yml
vendored
@ -82,7 +82,7 @@ jobs:
|
||||
args: --locked --release --all
|
||||
|
||||
test-all-features:
|
||||
name: Tests all features
|
||||
name: Tests almost all features
|
||||
runs-on: ubuntu-latest
|
||||
container:
|
||||
# Use ubuntu-18.04 to compile with glibc 2.27, which are the production expectations
|
||||
@ -98,16 +98,12 @@ jobs:
|
||||
with:
|
||||
toolchain: stable
|
||||
override: true
|
||||
- name: Run cargo build with all features
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: build
|
||||
args: --workspace --locked --release --all-features
|
||||
- name: Run cargo test with all features
|
||||
uses: actions-rs/cargo@v1
|
||||
with:
|
||||
command: test
|
||||
args: --workspace --locked --release --all-features
|
||||
- name: Run cargo build with almost all features
|
||||
run: |
|
||||
cargo build --workspace --locked --release --features "$(cargo xtask list-features --exclude-feature cuda)"
|
||||
- name: Run cargo test with almost all features
|
||||
run: |
|
||||
cargo test --workspace --locked --release --features "$(cargo xtask list-features --exclude-feature cuda)"
|
||||
|
||||
test-disabled-tokenization:
|
||||
name: Test disabled tokenization
|
||||
|
@ -75,6 +75,12 @@ If you get a "Too many open files" error you might want to increase the open fil
|
||||
ulimit -Sn 3000
|
||||
```
|
||||
|
||||
#### Build tools
|
||||
|
||||
Meilisearch follows the [cargo xtask](https://github.com/matklad/cargo-xtask) workflow to provide some build tools.
|
||||
|
||||
Run `cargo xtask --help` from the root of the repository to find out what is available.
|
||||
|
||||
## Git Guidelines
|
||||
|
||||
### Git Branches
|
||||
|
64
Cargo.lock
generated
64
Cargo.lock
generated
@ -700,12 +700,23 @@ dependencies = [
|
||||
"displaydoc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "camino"
|
||||
version = "1.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c59e92b5a388f549b863a7bea62612c09f24c8393560709a54558a9abdfb3b9c"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-core"
|
||||
version = "0.3.3"
|
||||
source = "git+https://github.com/huggingface/candle.git#5270224f407502b82fe90bc2622894ce3871b002"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"candle-kernels",
|
||||
"cudarc",
|
||||
"gemm",
|
||||
"half 2.3.1",
|
||||
"memmap2 0.9.3",
|
||||
@ -720,6 +731,16 @@ dependencies = [
|
||||
"zip",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-kernels"
|
||||
version = "0.3.1"
|
||||
source = "git+https://github.com/huggingface/candle.git#f4fcf6090045ac44122fd5f0a7e46db6e3e16528"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"glob",
|
||||
"rayon",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "candle-nn"
|
||||
version = "0.3.3"
|
||||
@ -752,6 +773,29 @@ dependencies = [
|
||||
"wav",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cargo-platform"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ceed8ef69d8518a5dda55c07425450b58a4e1946f4951eab6d7191ee86c2443d"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cargo_metadata"
|
||||
version = "0.18.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d886547e41f740c616ae73108f6eb70afe6d940c7bc697cb30f13daec073037"
|
||||
dependencies = [
|
||||
"camino",
|
||||
"cargo-platform",
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cargo_toml"
|
||||
version = "0.18.0"
|
||||
@ -1163,6 +1207,15 @@ dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cudarc"
|
||||
version = "0.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9395df0cab995685664e79cc35ad6302bf08fb9c5d82301875a183affe1278b1"
|
||||
dependencies = [
|
||||
"half 2.3.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "darling"
|
||||
version = "0.14.4"
|
||||
@ -4827,6 +4880,9 @@ name = "semver"
|
||||
version = "1.0.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0293b4b29daaf487284529cc2f5675b8e57c61f70167ba415a463651fd6a918"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "seq-macro"
|
||||
@ -6174,6 +6230,14 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "xtask"
|
||||
version = "1.6.0"
|
||||
dependencies = [
|
||||
"cargo_metadata",
|
||||
"clap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yada"
|
||||
version = "0.5.0"
|
||||
|
@ -16,6 +16,7 @@ members = [
|
||||
"json-depth-checker",
|
||||
"benchmarks",
|
||||
"fuzzers",
|
||||
"xtask",
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
|
@ -137,3 +137,6 @@ greek = ["charabia/greek"]
|
||||
|
||||
# allow khmer specialized tokenization
|
||||
khmer = ["charabia/khmer"]
|
||||
|
||||
# allow CUDA support, see <https://github.com/meilisearch/meilisearch/issues/4306>
|
||||
cuda = ["candle-core/cuda"]
|
||||
|
@ -70,7 +70,13 @@ impl std::fmt::Debug for Embedder {
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||
let device = candle_core::Device::Cpu;
|
||||
let device = match candle_core::Device::cuda_if_available(0) {
|
||||
Ok(device) => device,
|
||||
Err(error) => {
|
||||
log::warn!("could not initialize CUDA device for Hugging Face embedder, defaulting to CPU: {}", error);
|
||||
candle_core::Device::Cpu
|
||||
}
|
||||
};
|
||||
let repo = match options.revision.clone() {
|
||||
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
||||
None => Repo::model(options.model.clone()),
|
||||
|
15
xtask/Cargo.toml
Normal file
15
xtask/Cargo.toml
Normal file
@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "xtask"
|
||||
version.workspace = true
|
||||
authors.workspace = true
|
||||
description = "Workspace automation tool following the xtask pattern <https://github.com/matklad/cargo-xtask>"
|
||||
homepage.workspace = true
|
||||
readme.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
cargo_metadata = "0.18.1"
|
||||
clap = { version = "4.4.14", features = ["derive"] }
|
41
xtask/src/main.rs
Normal file
41
xtask/src/main.rs
Normal file
@ -0,0 +1,41 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
/// List features available in the workspace
|
||||
#[derive(Parser, Debug)]
|
||||
struct ListFeaturesDeriveArgs {
|
||||
/// Feature to exclude from the list. Repeat the argument to exclude multiple features
|
||||
#[arg(short, long)]
|
||||
exclude_feature: Vec<String>,
|
||||
}
|
||||
|
||||
/// Utilitary commands
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about)]
|
||||
#[command(name = "cargo xtask")]
|
||||
#[command(bin_name = "cargo xtask")]
|
||||
enum Command {
|
||||
ListFeatures(ListFeaturesDeriveArgs),
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let args = Command::parse();
|
||||
match args {
|
||||
Command::ListFeatures(args) => list_features(args),
|
||||
}
|
||||
}
|
||||
|
||||
fn list_features(args: ListFeaturesDeriveArgs) {
|
||||
let exclude_features: HashSet<_> = args.exclude_feature.into_iter().collect();
|
||||
let metadata = cargo_metadata::MetadataCommand::new().no_deps().exec().unwrap();
|
||||
let features: Vec<String> = metadata
|
||||
.packages
|
||||
.iter()
|
||||
.flat_map(|package| package.features.keys())
|
||||
.filter(|feature| !exclude_features.contains(feature.as_str()))
|
||||
.map(|s| s.to_owned())
|
||||
.collect();
|
||||
let features = features.join(" ");
|
||||
println!("{features}")
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user