From 7ee20b0895f902275c727784b4b098e238e16abf Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 5 Mar 2024 14:42:06 +0100 Subject: [PATCH] Refactor xtask bench --- xtask/src/bench/assets.rs | 250 ++++++++ xtask/src/bench/client.rs | 80 +++ xtask/src/bench/command.rs | 194 ++++++ xtask/src/bench/dashboard.rs | 167 +++++ xtask/src/bench/meili_process.rs | 112 ++++ xtask/src/bench/mod.rs | 1014 +----------------------------- xtask/src/bench/workload.rs | 262 ++++++++ 7 files changed, 1094 insertions(+), 985 deletions(-) create mode 100644 xtask/src/bench/assets.rs create mode 100644 xtask/src/bench/client.rs create mode 100644 xtask/src/bench/command.rs create mode 100644 xtask/src/bench/dashboard.rs create mode 100644 xtask/src/bench/meili_process.rs create mode 100644 xtask/src/bench/workload.rs diff --git a/xtask/src/bench/assets.rs b/xtask/src/bench/assets.rs new file mode 100644 index 000000000..241928dbf --- /dev/null +++ b/xtask/src/bench/assets.rs @@ -0,0 +1,250 @@ +use std::collections::BTreeMap; +use std::io::{Read as _, Seek as _, Write as _}; + +use anyhow::{bail, Context}; +use futures_util::TryStreamExt as _; +use serde::Deserialize; +use sha2::Digest; + +use super::client::Client; + +#[derive(Deserialize, Clone)] +pub struct Asset { + pub local_location: Option, + pub remote_location: Option, + #[serde(default)] + pub format: AssetFormat, + pub sha256: Option, +} + +#[derive(Deserialize, Default, Copy, Clone)] +pub enum AssetFormat { + #[default] + Auto, + Json, + NdJson, + Raw, +} + +impl AssetFormat { + pub fn to_content_type(self, filename: &str) -> &'static str { + match self { + AssetFormat::Auto => Self::auto_detect(filename).to_content_type(filename), + AssetFormat::Json => "application/json", + AssetFormat::NdJson => "application/x-ndjson", + AssetFormat::Raw => "application/octet-stream", + } + } + + fn auto_detect(filename: &str) -> Self { + let path = std::path::Path::new(filename); + match path.extension().and_then(|extension| extension.to_str()) { + Some(extension) if extension.eq_ignore_ascii_case("json") => Self::Json, + Some(extension) if extension.eq_ignore_ascii_case("ndjson") => Self::NdJson, + extension => { + tracing::warn!(asset = filename, ?extension, "asset has format `Auto`, but extension was not recognized. Specify `Raw` format to suppress this warning."); + AssetFormat::Raw + } + } + } +} + +pub fn fetch_asset( + name: &str, + assets: &BTreeMap, + asset_folder: &str, +) -> anyhow::Result<(std::fs::File, AssetFormat)> { + let asset = + assets.get(name).with_context(|| format!("could not find asset with name '{name}'"))?; + let filename = if let Some(local_filename) = &asset.local_location { + local_filename.clone() + } else { + format!("{asset_folder}/{name}") + }; + + Ok(( + std::fs::File::open(&filename) + .with_context(|| format!("could not open asset '{name}' at '{filename}'"))?, + asset.format, + )) +} + +#[tracing::instrument(skip(client, assets), fields(asset_count = assets.len()))] +pub async fn fetch_assets( + client: &Client, + assets: &BTreeMap, + asset_folder: &str, +) -> anyhow::Result<()> { + let mut download_tasks = tokio::task::JoinSet::new(); + for (name, asset) in assets { + // trying local + if let Some(local) = &asset.local_location { + match std::fs::File::open(local) { + Ok(file) => { + if check_sha256(name, asset, file)? { + continue; + } else { + tracing::warn!(asset = name, file = local, "found local resource for asset but hash differed, skipping to asset store"); + } + } + Err(error) => match error.kind() { + std::io::ErrorKind::NotFound => { /* file does not exist, go to remote, no need for logs */ + } + _ => tracing::warn!( + error = &error as &dyn std::error::Error, + "error checking local resource, skipping to asset store" + ), + }, + } + } + + // checking asset store + let store_filename = format!("{}/{}", asset_folder, name); + + match std::fs::File::open(&store_filename) { + Ok(file) => { + if check_sha256(name, asset, file)? { + continue; + } else { + tracing::warn!(asset = name, file = store_filename, "found resource for asset in asset store, but hash differed, skipping to remote method"); + } + } + Err(error) => match error.kind() { + std::io::ErrorKind::NotFound => { /* file does not exist, go to remote, no need for logs */ + } + _ => tracing::warn!( + error = &error as &dyn std::error::Error, + "error checking resource in store, skipping to remote method" + ), + }, + } + + // downloading remote + match &asset.remote_location { + Some(location) => { + std::fs::create_dir_all(asset_folder).with_context(|| format!("could not create asset folder at {asset_folder}"))?; + download_tasks.spawn({ + let client = client.clone(); + let name = name.to_string(); + let location = location.to_string(); + let store_filename = store_filename.clone(); + let asset = asset.clone(); + download_asset(client, name, asset, location, store_filename)}); + }, + None => bail!("asset {name} has no remote location, but was not found locally or in the asset store"), + } + } + + while let Some(res) = download_tasks.join_next().await { + res.context("download task panicked")?.context("download task failed")?; + } + + Ok(()) +} + +fn check_sha256(name: &str, asset: &Asset, mut file: std::fs::File) -> anyhow::Result { + let mut bytes = Vec::new(); + file.read_to_end(&mut bytes).with_context(|| format!("hashing file for asset {name}"))?; + let mut file_hash = sha2::Sha256::new(); + file_hash.update(&bytes); + let file_hash = file_hash.finalize(); + let file_hash = format!("{:x}", file_hash); + tracing::debug!(hash = file_hash, "hashed local file"); + + Ok(match &asset.sha256 { + Some(hash) => { + tracing::debug!(hash, "hash from workload"); + if hash.to_ascii_lowercase() == file_hash { + true + } else { + tracing::warn!( + file_hash, + asset_hash = hash.to_ascii_lowercase(), + "hashes don't match" + ); + false + } + } + None => { + tracing::warn!(sha256 = file_hash, "Skipping hash for asset {name} that doesn't have one. Please add it to workload file"); + true + } + }) +} + +#[tracing::instrument(skip(client, asset, name), fields(asset = name))] +async fn download_asset( + client: Client, + name: String, + asset: Asset, + src: String, + dest_filename: String, +) -> anyhow::Result<()> { + let context = || format!("failure downloading asset {name} from {src}"); + + let response = client.get(&src).send().await.with_context(context)?; + + let file = std::fs::File::options() + .create(true) + .truncate(true) + .write(true) + .read(true) + .open(&dest_filename) + .with_context(|| format!("creating destination file {dest_filename}")) + .with_context(context)?; + + let mut dest = std::io::BufWriter::new( + file.try_clone().context("cloning I/O handle").with_context(context)?, + ); + + let total_len: Option = response + .headers() + .get(reqwest::header::CONTENT_LENGTH) + .and_then(|value| value.to_str().ok()) + .and_then(|value| value.parse().ok()); + + let progress = tokio::spawn({ + let name = name.clone(); + async move { + loop { + match file.metadata().context("could not get file metadata") { + Ok(metadata) => { + let len = metadata.len(); + tracing::info!( + asset = name, + downloaded_bytes = len, + total_bytes = total_len, + "asset download in progress" + ); + } + Err(error) => { + tracing::warn!(%error, "could not get file metadata"); + } + } + tokio::time::sleep(std::time::Duration::from_secs(60)).await; + } + } + }); + + let writing_context = || format!("while writing to destination file at {dest_filename}"); + + let mut response = response.bytes_stream(); + + while let Some(bytes) = + response.try_next().await.context("while downloading file").with_context(context)? + { + dest.write_all(&bytes).with_context(writing_context).with_context(context)?; + } + + progress.abort(); + + let mut file = dest.into_inner().with_context(writing_context).with_context(context)?; + + file.rewind().context("while rewinding asset file")?; + + if !check_sha256(&name, &asset, file)? { + bail!("asset '{name}': sha256 mismatch for file {dest_filename} downloaded from {src}") + } + + Ok(()) +} diff --git a/xtask/src/bench/client.rs b/xtask/src/bench/client.rs new file mode 100644 index 000000000..3e46615cc --- /dev/null +++ b/xtask/src/bench/client.rs @@ -0,0 +1,80 @@ +use anyhow::Context; +use serde::Deserialize; + +#[derive(Debug, Clone)] +pub struct Client { + base_url: Option, + client: reqwest::Client, +} + +impl Client { + pub fn new( + base_url: Option, + api_key: Option<&str>, + timeout: Option, + ) -> anyhow::Result { + let mut headers = reqwest::header::HeaderMap::new(); + if let Some(api_key) = api_key { + headers.append( + reqwest::header::AUTHORIZATION, + reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}")) + .context("Invalid authorization header")?, + ); + } + + let client = reqwest::ClientBuilder::new().default_headers(headers); + let client = if let Some(timeout) = timeout { client.timeout(timeout) } else { client }; + let client = client.build()?; + Ok(Self { base_url, client }) + } + + pub fn request(&self, method: reqwest::Method, route: &str) -> reqwest::RequestBuilder { + if let Some(base_url) = &self.base_url { + if route.is_empty() { + self.client.request(method, base_url) + } else { + self.client.request(method, format!("{}/{}", base_url, route)) + } + } else { + self.client.request(method, route) + } + } + + pub fn get(&self, route: &str) -> reqwest::RequestBuilder { + self.request(reqwest::Method::GET, route) + } + + pub fn put(&self, route: &str) -> reqwest::RequestBuilder { + self.request(reqwest::Method::PUT, route) + } + + pub fn post(&self, route: &str) -> reqwest::RequestBuilder { + self.request(reqwest::Method::POST, route) + } + + pub fn delete(&self, route: &str) -> reqwest::RequestBuilder { + self.request(reqwest::Method::DELETE, route) + } +} + +#[derive(Debug, Clone, Copy, Deserialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum Method { + Get, + Post, + Patch, + Delete, + Put, +} + +impl From for reqwest::Method { + fn from(value: Method) -> Self { + match value { + Method::Get => Self::GET, + Method::Post => Self::POST, + Method::Patch => Self::PATCH, + Method::Delete => Self::DELETE, + Method::Put => Self::PUT, + } + } +} diff --git a/xtask/src/bench/command.rs b/xtask/src/bench/command.rs new file mode 100644 index 000000000..0f0b5d213 --- /dev/null +++ b/xtask/src/bench/command.rs @@ -0,0 +1,194 @@ +use std::collections::BTreeMap; +use std::fmt::Display; +use std::io::Read as _; + +use anyhow::{bail, Context as _}; +use serde::Deserialize; + +use super::assets::{fetch_asset, Asset}; +use super::client::{Client, Method}; + +#[derive(Clone, Deserialize)] +pub struct Command { + pub route: String, + pub method: Method, + #[serde(default)] + pub body: Body, + #[serde(default)] + pub synchronous: SyncMode, +} + +#[derive(Default, Clone, Deserialize)] +#[serde(untagged)] +pub enum Body { + Inline { + inline: serde_json::Value, + }, + Asset { + asset: String, + }, + #[default] + Empty, +} + +impl Body { + pub fn get( + self, + assets: &BTreeMap, + asset_folder: &str, + ) -> anyhow::Result, &'static str)>> { + Ok(match self { + Body::Inline { inline: body } => Some(( + serde_json::to_vec(&body) + .context("serializing to bytes") + .context("while getting inline body")?, + "application/json", + )), + Body::Asset { asset: name } => Some({ + let context = || format!("while getting body from asset '{name}'"); + let (mut file, format) = + fetch_asset(&name, assets, asset_folder).with_context(context)?; + let mut buf = Vec::new(); + file.read_to_end(&mut buf).with_context(context)?; + (buf, format.to_content_type(&name)) + }), + Body::Empty => None, + }) + } +} + +impl Display for Command { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?} {} ({:?})", self.method, self.route, self.synchronous) + } +} + +#[derive(Default, Debug, Clone, Copy, Deserialize)] +pub enum SyncMode { + DontWait, + #[default] + WaitForResponse, + WaitForTask, +} + +pub async fn run_batch( + client: &Client, + batch: &[Command], + assets: &BTreeMap, + asset_folder: &str, +) -> anyhow::Result<()> { + let [.., last] = batch else { return Ok(()) }; + let sync = last.synchronous; + + let mut tasks = tokio::task::JoinSet::new(); + + for command in batch { + // FIXME: you probably don't want to copy assets everytime here + tasks.spawn({ + let client = client.clone(); + let command = command.clone(); + let assets = assets.clone(); + let asset_folder = asset_folder.to_owned(); + + async move { run(client, command, &assets, &asset_folder).await } + }); + } + + while let Some(result) = tasks.join_next().await { + result + .context("panicked while executing command")? + .context("error while executing command")?; + } + + match sync { + SyncMode::DontWait => {} + SyncMode::WaitForResponse => {} + SyncMode::WaitForTask => wait_for_tasks(client).await?, + } + + Ok(()) +} + +async fn wait_for_tasks(client: &Client) -> anyhow::Result<()> { + loop { + let response = client + .get("tasks?statuses=enqueued,processing") + .send() + .await + .context("could not wait for tasks")?; + let response: serde_json::Value = response + .json() + .await + .context("could not deserialize response to JSON") + .context("could not wait for tasks")?; + match response.get("total") { + Some(serde_json::Value::Number(number)) => { + let number = number.as_u64().with_context(|| { + format!("waiting for tasks: could not parse 'total' as integer, got {}", number) + })?; + if number == 0 { + break; + } else { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + continue; + } + } + Some(thing_else) => { + bail!(format!( + "waiting for tasks: could not parse 'total' as a number, got '{thing_else}'" + )) + } + None => { + bail!(format!( + "waiting for tasks: expected response to contain 'total', got '{response}'" + )) + } + } + } + Ok(()) +} + +#[tracing::instrument(skip(client, command, assets, asset_folder), fields(command = %command))] +pub async fn run( + client: Client, + mut command: Command, + assets: &BTreeMap, + asset_folder: &str, +) -> anyhow::Result<()> { + // memtake the body here to leave an empty body in its place, so that command is not partially moved-out + let body = std::mem::take(&mut command.body) + .get(assets, asset_folder) + .with_context(|| format!("while getting body for command {command}"))?; + + let request = client.request(command.method.into(), &command.route); + + let request = if let Some((body, content_type)) = body { + request.body(body).header(reqwest::header::CONTENT_TYPE, content_type) + } else { + request + }; + + let response = + request.send().await.with_context(|| format!("error sending command: {}", command))?; + + let code = response.status(); + if code.is_client_error() { + tracing::error!(%command, %code, "error in workload file"); + let response: serde_json::Value = response + .json() + .await + .context("could not deserialize response as JSON") + .context("parsing error in workload file when sending command")?; + bail!("error in workload file: server responded with error code {code} and '{response}'") + } else if code.is_server_error() { + tracing::error!(%command, %code, "server error"); + let response: serde_json::Value = response + .json() + .await + .context("could not deserialize response as JSON") + .context("parsing server error when sending command")?; + bail!("server error: server responded with error code {code} and '{response}'") + } + + Ok(()) +} diff --git a/xtask/src/bench/dashboard.rs b/xtask/src/bench/dashboard.rs new file mode 100644 index 000000000..833426207 --- /dev/null +++ b/xtask/src/bench/dashboard.rs @@ -0,0 +1,167 @@ +use std::collections::BTreeMap; + +use anyhow::{bail, Context}; +use serde_json::json; +use tokio::signal::ctrl_c; +use tokio::task::AbortHandle; +use tracing_trace::processor::span_stats::CallStats; +use uuid::Uuid; + +use super::client::Client; +use super::env_info; +use super::workload::Workload; + +pub async fn cancel_on_ctrl_c( + invocation_uuid: Uuid, + dashboard_client: Client, + abort_handle: AbortHandle, +) { + tracing::info!("press Ctrl-C to cancel the invocation"); + match ctrl_c().await { + Ok(()) => { + tracing::info!(%invocation_uuid, "received Ctrl-C, cancelling invocation"); + mark_as_failed(dashboard_client, invocation_uuid, None).await; + abort_handle.abort(); + } + Err(error) => tracing::warn!( + error = &error as &dyn std::error::Error, + "failed to listen to Ctrl-C signal, invocation won't be canceled on Ctrl-C" + ), + } +} + +pub async fn mark_as_failed( + dashboard_client: Client, + invocation_uuid: Uuid, + failure_reason: Option, +) { + let response = dashboard_client + .post("cancel-invocation") + .json(&json!({ + "invocation_uuid": invocation_uuid, + "failure_reason": failure_reason, + })) + .send() + .await; + let response = match response { + Ok(response) => response, + Err(response_error) => { + tracing::error!(error = &response_error as &dyn std::error::Error, %invocation_uuid, "could not mark invocation as failed"); + return; + } + }; + + if !response.status().is_success() { + tracing::error!( + %invocation_uuid, + "could not mark invocation as failed: {}", + response.text().await.unwrap() + ); + return; + } + tracing::warn!(%invocation_uuid, "marked invocation as failed or canceled"); +} + +pub async fn send_machine_info( + dashboard_client: &Client, + env: &env_info::Environment, +) -> anyhow::Result<()> { + let response = dashboard_client + .put("machine") + .json(&json!({"hostname": env.hostname})) + .send() + .await + .context("sending machine information")?; + if !response.status().is_success() { + bail!( + "could not send machine information: {} {}", + response.status(), + response.text().await.unwrap_or_else(|_| "unknown".into()) + ); + } + Ok(()) +} + +pub async fn create_invocation( + dashboard_client: &Client, + build_info: build_info::BuildInfo, + commit_message: &str, + env: env_info::Environment, + max_workloads: usize, + reason: Option<&str>, +) -> anyhow::Result { + let response = dashboard_client + .put("invocation") + .json(&json!({ + "commit": { + "sha1": build_info.commit_sha1, + "message": commit_message, + "commit_date": build_info.commit_timestamp, + "branch": build_info.branch, + "tag": build_info.describe.and_then(|describe| describe.as_tag()), + }, + "machine_hostname": env.hostname, + "max_workloads": max_workloads, + "reason": reason + })) + .send() + .await + .context("sending invocation")?; + if !response.status().is_success() { + bail!( + "could not send new invocation: {}", + response.text().await.unwrap_or_else(|_| "unknown".into()) + ); + } + let invocation_uuid: Uuid = + response.json().await.context("could not deserialize invocation response as JSON")?; + Ok(invocation_uuid) +} + +pub async fn create_workload( + dashboard_client: &Client, + invocation_uuid: Uuid, + workload: &Workload, +) -> anyhow::Result { + let response = dashboard_client + .put("workload") + .json(&json!({ + "invocation_uuid": invocation_uuid, + "name": &workload.name, + "max_runs": workload.run_count, + })) + .send() + .await + .context("could not create new workload")?; + + if !response.status().is_success() { + bail!("creating new workload failed: {}", response.text().await.unwrap()) + } + + let workload_uuid: Uuid = + response.json().await.context("could not deserialize JSON as UUID")?; + Ok(workload_uuid) +} + +pub async fn create_run( + dashboard_client: Client, + workload_uuid: Uuid, + report: &BTreeMap, +) -> anyhow::Result<()> { + let response = dashboard_client + .put("run") + .json(&json!({ + "workload_uuid": workload_uuid, + "data": report + })) + .send() + .await + .context("sending new run")?; + if !response.status().is_success() { + bail!( + "sending new run failed: {}", + response.text().await.unwrap_or_else(|_| "unknown".into()) + ) + } + Ok(()) +} diff --git a/xtask/src/bench/meili_process.rs b/xtask/src/bench/meili_process.rs new file mode 100644 index 000000000..99f6f4ea6 --- /dev/null +++ b/xtask/src/bench/meili_process.rs @@ -0,0 +1,112 @@ +use std::collections::BTreeMap; + +use anyhow::{bail, Context as _}; + +use super::assets::Asset; +use super::client::Client; +use super::workload::Workload; + +pub async fn kill(mut meilisearch: tokio::process::Child) { + if let Err(error) = meilisearch.kill().await { + tracing::warn!( + error = &error as &dyn std::error::Error, + "while terminating Meilisearch server" + ) + } +} + +#[tracing::instrument] +pub async fn build() -> anyhow::Result<()> { + let mut command = tokio::process::Command::new("cargo"); + command.arg("build").arg("--release").arg("-p").arg("meilisearch"); + + command.kill_on_drop(true); + + let mut builder = command.spawn().context("error building Meilisearch")?; + + if !builder.wait().await.context("could not build Meilisearch")?.success() { + bail!("failed building Meilisearch") + } + + Ok(()) +} + +#[tracing::instrument(skip(client, master_key, workload), fields(workload = workload.name))] +pub async fn start( + client: &Client, + master_key: Option<&str>, + workload: &Workload, + asset_folder: &str, +) -> anyhow::Result { + let mut command = tokio::process::Command::new("cargo"); + command + .arg("run") + .arg("--release") + .arg("-p") + .arg("meilisearch") + .arg("--bin") + .arg("meilisearch") + .arg("--"); + + command.arg("--db-path").arg("./_xtask_benchmark.ms"); + if let Some(master_key) = master_key { + command.arg("--master-key").arg(master_key); + } + command.arg("--experimental-enable-logs-route"); + + for extra_arg in workload.extra_cli_args.iter() { + command.arg(extra_arg); + } + + command.kill_on_drop(true); + + let mut meilisearch = command.spawn().context("Error starting Meilisearch")?; + + wait_for_health(client, &mut meilisearch, &workload.assets, asset_folder).await?; + + Ok(meilisearch) +} + +async fn wait_for_health( + client: &Client, + meilisearch: &mut tokio::process::Child, + assets: &BTreeMap, + asset_folder: &str, +) -> anyhow::Result<()> { + for i in 0..100 { + let res = super::command::run(client.clone(), health_command(), assets, asset_folder).await; + if res.is_ok() { + // check that this is actually the current Meilisearch instance that answered us + if let Some(exit_code) = + meilisearch.try_wait().context("cannot check Meilisearch server process status")? + { + tracing::error!("Got an health response from a different process"); + bail!("Meilisearch server exited early with code {exit_code}"); + } + + return Ok(()); + } + tokio::time::sleep(std::time::Duration::from_millis(500)).await; + // check whether the Meilisearch instance exited early (cut the wait) + if let Some(exit_code) = + meilisearch.try_wait().context("cannot check Meilisearch server process status")? + { + bail!("Meilisearch server exited early with code {exit_code}"); + } + tracing::debug!(attempt = i, "Waiting for Meilisearch to go up"); + } + bail!("meilisearch is not responding") +} + +fn health_command() -> super::command::Command { + super::command::Command { + route: "/health".into(), + method: super::client::Method::Get, + body: Default::default(), + synchronous: super::command::SyncMode::WaitForResponse, + } +} + +pub fn delete_db() { + let _ = std::fs::remove_dir_all("./_xtask_benchmark.ms"); +} diff --git a/xtask/src/bench/mod.rs b/xtask/src/bench/mod.rs index cfc7c124f..62c11b604 100644 --- a/xtask/src/bench/mod.rs +++ b/xtask/src/bench/mod.rs @@ -1,20 +1,21 @@ +mod assets; +mod client; +mod command; +mod dashboard; mod env_info; +mod meili_process; +mod workload; -use std::collections::BTreeMap; -use std::fmt::Display; -use std::io::{Read, Seek, Write}; use std::path::PathBuf; -use anyhow::{bail, Context}; +use anyhow::Context; use clap::Parser; -use futures_util::TryStreamExt; -use serde::Deserialize; -use serde_json::json; -use sha2::Digest; use tracing_subscriber::fmt::format::FmtSpan; use tracing_subscriber::layer::SubscriberExt; use tracing_subscriber::Layer; -use uuid::Uuid; + +use self::client::Client; +use self::workload::Workload; pub fn default_http_addr() -> String { "127.0.0.1:7700".to_string() @@ -35,62 +36,6 @@ pub fn default_dashboard_url() -> String { "http://localhost:9001".into() } -#[derive(Debug, Clone)] -pub struct Client { - base_url: Option, - client: reqwest::Client, -} - -impl Client { - pub fn new( - base_url: Option, - api_key: Option<&str>, - timeout: Option, - ) -> anyhow::Result { - let mut headers = reqwest::header::HeaderMap::new(); - if let Some(api_key) = api_key { - headers.append( - reqwest::header::AUTHORIZATION, - reqwest::header::HeaderValue::from_str(&format!("Bearer {api_key}")) - .context("Invalid authorization header")?, - ); - } - - let client = reqwest::ClientBuilder::new().default_headers(headers); - let client = if let Some(timeout) = timeout { client.timeout(timeout) } else { client }; - let client = client.build()?; - Ok(Self { base_url, client }) - } - - pub fn request(&self, method: reqwest::Method, route: &str) -> reqwest::RequestBuilder { - if let Some(base_url) = &self.base_url { - if route.is_empty() { - self.client.request(method, base_url) - } else { - self.client.request(method, format!("{}/{}", base_url, route)) - } - } else { - self.client.request(method, route) - } - } - - pub fn get(&self, route: &str) -> reqwest::RequestBuilder { - self.request(reqwest::Method::GET, route) - } - - pub fn put(&self, route: &str) -> reqwest::RequestBuilder { - self.request(reqwest::Method::PUT, route) - } - - pub fn post(&self, route: &str) -> reqwest::RequestBuilder { - self.request(reqwest::Method::POST, route) - } - - pub fn delete(&self, route: &str) -> reqwest::RequestBuilder { - self.request(reqwest::Method::DELETE, route) - } -} - /// Run benchmarks from a workload #[derive(Parser, Debug)] pub struct BenchDeriveArgs { @@ -134,166 +79,11 @@ pub struct BenchDeriveArgs { reason: Option, } -#[derive(Deserialize)] -pub struct Workload { - pub name: String, - pub run_count: u16, - pub extra_cli_args: Vec, - pub assets: BTreeMap, - pub commands: Vec, -} - -#[derive(Deserialize, Clone)] -pub struct Asset { - pub local_location: Option, - pub remote_location: Option, - #[serde(default)] - pub format: AssetFormat, - pub sha256: Option, -} - -#[derive(Deserialize, Default, Copy, Clone)] -pub enum AssetFormat { - #[default] - Auto, - Json, - NdJson, - Raw, -} -impl AssetFormat { - fn to_content_type(self, filename: &str) -> &'static str { - match self { - AssetFormat::Auto => Self::auto_detect(filename).to_content_type(filename), - AssetFormat::Json => "application/json", - AssetFormat::NdJson => "application/x-ndjson", - AssetFormat::Raw => "application/octet-stream", - } - } - - fn auto_detect(filename: &str) -> Self { - let path = std::path::Path::new(filename); - match path.extension().and_then(|extension| extension.to_str()) { - Some(extension) if extension.eq_ignore_ascii_case("json") => Self::Json, - Some(extension) if extension.eq_ignore_ascii_case("ndjson") => Self::NdJson, - extension => { - tracing::warn!(asset = filename, ?extension, "asset has format `Auto`, but extension was not recognized. Specify `Raw` format to suppress this warning."); - AssetFormat::Raw - } - } - } -} - -#[derive(Clone, Deserialize)] -pub struct Command { - pub route: String, - pub method: Method, - #[serde(default)] - pub body: Body, - #[serde(default)] - pub synchronous: SyncMode, -} - -#[derive(Default, Clone, Deserialize)] -#[serde(untagged)] -pub enum Body { - Inline { - inline: serde_json::Value, - }, - Asset { - asset: String, - }, - #[default] - Empty, -} - -impl Body { - pub fn get( - self, - assets: &BTreeMap, - asset_folder: &str, - ) -> anyhow::Result, &'static str)>> { - Ok(match self { - Body::Inline { inline: body } => Some(( - serde_json::to_vec(&body) - .context("serializing to bytes") - .context("while getting inline body")?, - "application/json", - )), - Body::Asset { asset: name } => Some({ - let context = || format!("while getting body from asset '{name}'"); - let (mut file, format) = - fetch_asset(&name, assets, asset_folder).with_context(context)?; - let mut buf = Vec::new(); - file.read_to_end(&mut buf).with_context(context)?; - (buf, format.to_content_type(&name)) - }), - Body::Empty => None, - }) - } -} - -fn fetch_asset( - name: &str, - assets: &BTreeMap, - asset_folder: &str, -) -> anyhow::Result<(std::fs::File, AssetFormat)> { - let asset = - assets.get(name).with_context(|| format!("could not find asset with name '{name}'"))?; - let filename = if let Some(local_filename) = &asset.local_location { - local_filename.clone() - } else { - format!("{asset_folder}/{name}") - }; - - Ok(( - std::fs::File::open(&filename) - .with_context(|| format!("could not open asset '{name}' at '{filename}'"))?, - asset.format, - )) -} - -impl Display for Command { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?} {} ({:?})", self.method, self.route, self.synchronous) - } -} - -#[derive(Debug, Clone, Copy, Deserialize)] -pub enum Method { - GET, - POST, - PATCH, - DELETE, - PUT, -} - -impl From for reqwest::Method { - fn from(value: Method) -> Self { - match value { - Method::GET => Self::GET, - Method::POST => Self::POST, - Method::PATCH => Self::PATCH, - Method::DELETE => Self::DELETE, - Method::PUT => Self::PUT, - } - } -} - -#[derive(Default, Debug, Clone, Copy, Deserialize)] -pub enum SyncMode { - DontWait, - #[default] - WaitForResponse, - WaitForTask, -} - pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> { + // setup logs let filter: tracing_subscriber::filter::Targets = args.log_filter.parse().context("invalid --log-filter")?; - let env = env_info::Environment::generate_from_current_config(); - let build_info = build_info::BuildInfo::from_build(); - let subscriber = tracing_subscriber::registry().with( tracing_subscriber::fmt::layer() .with_span_events(FmtSpan::NEW | FmtSpan::CLOSE) @@ -301,9 +91,15 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> { ); tracing::subscriber::set_global_default(subscriber).context("could not setup logging")?; + // fetch environment and build info + let env = env_info::Environment::generate_from_current_config(); + let build_info = build_info::BuildInfo::from_build(); + + // tokio runtime let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; let _scope = rt.enter(); + // setup clients let assets_client = Client::new(None, args.assets_key.as_deref(), Some(std::time::Duration::from_secs(3600)))?; // 1h @@ -328,55 +124,19 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> { Some(std::time::Duration::from_secs(60)), )?; + // enter runtime + rt.block_on(async { - let response = dashboard_client - .put("machine") - .json(&json!({"hostname": env.hostname})) - .send() - .await - .context("sending machine information")?; - if !response.status().is_success() { - bail!( - "could not send machine information: {} {}", - response.status(), - response.text().await.unwrap_or_else(|_| "unknown".into()) - ); - } + dashboard::send_machine_info(&dashboard_client, &env).await?; let commit_message = build_info.commit_msg.context("missing commit message")?.split('\n').next().unwrap(); let max_workloads = args.workload_file.len(); let reason: Option<&str> = args.reason.as_deref(); - let response = dashboard_client - .put("invocation") - .json(&json!({ - "commit": { - "sha1": build_info.commit_sha1, - "message": commit_message, - "commit_date": build_info.commit_timestamp, - "branch": build_info.branch, - "tag": build_info.describe.and_then(|describe| describe.as_tag()), - }, - "machine_hostname": env.hostname, - "max_workloads": max_workloads, - "reason": reason - })) - .send() - .await - .context("sending invocation")?; - - if !response.status().is_success() { - bail!( - "could not send new invocation: {}", - response.text().await.unwrap_or_else(|_| "unknown".into()) - ); - } - - let invocation_uuid: Uuid = - response.json().await.context("could not deserialize invocation response as JSON")?; - - + let invocation_uuid = dashboard::create_invocation(&dashboard_client, build_info, commit_message, env, max_workloads, reason).await?; tracing::info!(workload_count = args.workload_file.len(), "handling workload files"); + + // main task let workload_runs = tokio::spawn( { let dashboard_client = dashboard_client.clone(); @@ -388,7 +148,7 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> { ) .with_context(|| format!("error parsing {} as JSON", workload_file.display()))?; - run_workload( + workload::execute( &assets_client, &dashboard_client, &logs_client, @@ -403,26 +163,14 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> { Ok::<(), anyhow::Error>(()) }}); + // handle ctrl-c let abort_handle = workload_runs.abort_handle(); - tokio::spawn({ let dashboard_client = dashboard_client.clone(); - async move { - tracing::info!("press Ctrl-C to cancel the invocation"); - match tokio::signal::ctrl_c().await { - Ok(()) => { - tracing::info!(%invocation_uuid, "received Ctrl-C, cancelling invocation"); - mark_as_failed(dashboard_client, invocation_uuid, None).await; - abort_handle.abort(); - } - Err(error) => tracing::warn!( - error = &error as &dyn std::error::Error, - "failed to listen to Ctrl-C signal, invocation won't be canceled on Ctrl-C" - ), - } - } + dashboard::cancel_on_ctrl_c(invocation_uuid, dashboard_client, abort_handle) }); + // wait for the end of the main task, handle result match workload_runs.await { Ok(Ok(_)) => { tracing::info!("Success"); @@ -430,7 +178,7 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> { } Ok(Err(error)) => { tracing::error!(%invocation_uuid, error = %error, "invocation failed, attempting to report the failure to dashboard"); - mark_as_failed(dashboard_client, invocation_uuid, Some(error.to_string())).await; + dashboard::mark_as_failed(dashboard_client, invocation_uuid, Some(error.to_string())).await; tracing::warn!(%invocation_uuid, "invocation marked as failed following error"); Err(error) }, @@ -438,7 +186,7 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> { match join_error.try_into_panic() { Ok(panic) => { tracing::error!("invocation panicked, attempting to report the failure to dashboard"); - mark_as_failed(dashboard_client, invocation_uuid, Some("Panicked".into())).await; + dashboard::mark_as_failed(dashboard_client, invocation_uuid, Some("Panicked".into())).await; std::panic::resume_unwind(panic) } Err(_) => { @@ -453,707 +201,3 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> { Ok(()) } - -async fn mark_as_failed( - dashboard_client: Client, - invocation_uuid: Uuid, - failure_reason: Option, -) { - let response = dashboard_client - .post("cancel-invocation") - .json(&json!({ - "invocation_uuid": invocation_uuid, - "failure_reason": failure_reason, - })) - .send() - .await; - let response = match response { - Ok(response) => response, - Err(response_error) => { - tracing::error!(error = &response_error as &dyn std::error::Error, %invocation_uuid, "could not mark invocation as failed"); - return; - } - }; - - if !response.status().is_success() { - tracing::error!( - %invocation_uuid, - "could not mark invocation as failed: {}", - response.text().await.unwrap() - ); - return; - } - tracing::warn!(%invocation_uuid, "marked invocation as failed or canceled"); -} - -#[allow(clippy::too_many_arguments)] // not best code quality, but this is a benchmark runner -#[tracing::instrument(skip(assets_client, dashboard_client, logs_client, meili_client, workload, master_key, args), fields(workload = workload.name))] -async fn run_workload( - assets_client: &Client, - dashboard_client: &Client, - logs_client: &Client, - meili_client: &Client, - invocation_uuid: Uuid, - master_key: Option<&str>, - workload: Workload, - args: &BenchDeriveArgs, -) -> anyhow::Result<()> { - fetch_assets(assets_client, &workload.assets, &args.asset_folder).await?; - - let response = dashboard_client - .put("workload") - .json(&json!({ - "invocation_uuid": invocation_uuid, - "name": &workload.name, - "max_runs": workload.run_count, - })) - .send() - .await - .context("could not create new workload")?; - - if !response.status().is_success() { - bail!("creating new workload failed: {}", response.text().await.unwrap()) - } - - let workload_uuid: Uuid = - response.json().await.context("could not deserialize JSON as UUID")?; - - let mut tasks = Vec::new(); - - for i in 0..workload.run_count { - tasks.push( - run_workload_run( - dashboard_client, - logs_client, - meili_client, - workload_uuid, - master_key, - &workload, - args, - i, - ) - .await?, - ); - } - - let mut reports = Vec::with_capacity(workload.run_count as usize); - - for task in tasks { - reports.push( - task.await - .context("task panicked while processing report")? - .context("task failed while processing report")?, - ); - } - - tracing::info!(workload = workload.name, "Successful workload"); - - Ok(()) -} - -#[tracing::instrument(skip(client, assets), fields(asset_count = assets.len()))] -async fn fetch_assets( - client: &Client, - assets: &BTreeMap, - asset_folder: &str, -) -> anyhow::Result<()> { - let mut download_tasks = tokio::task::JoinSet::new(); - for (name, asset) in assets { - // trying local - if let Some(local) = &asset.local_location { - match std::fs::File::open(local) { - Ok(file) => { - if check_sha256(name, asset, file)? { - continue; - } else { - tracing::warn!(asset = name, file = local, "found local resource for asset but hash differed, skipping to asset store"); - } - } - Err(error) => match error.kind() { - std::io::ErrorKind::NotFound => { /* file does not exist, go to remote, no need for logs */ - } - _ => tracing::warn!( - error = &error as &dyn std::error::Error, - "error checking local resource, skipping to asset store" - ), - }, - } - } - - // checking asset store - let store_filename = format!("{}/{}", asset_folder, name); - - match std::fs::File::open(&store_filename) { - Ok(file) => { - if check_sha256(name, asset, file)? { - continue; - } else { - tracing::warn!(asset = name, file = store_filename, "found resource for asset in asset store, but hash differed, skipping to remote method"); - } - } - Err(error) => match error.kind() { - std::io::ErrorKind::NotFound => { /* file does not exist, go to remote, no need for logs */ - } - _ => tracing::warn!( - error = &error as &dyn std::error::Error, - "error checking resource in store, skipping to remote method" - ), - }, - } - - // downloading remote - match &asset.remote_location { - Some(location) => { - std::fs::create_dir_all(asset_folder).with_context(|| format!("could not create asset folder at {asset_folder}"))?; - download_tasks.spawn({ - let client = client.clone(); - let name = name.to_string(); - let location = location.to_string(); - let store_filename = store_filename.clone(); - let asset = asset.clone(); - download_asset(client, name, asset, location, store_filename)}); - }, - None => bail!("asset {name} has no remote location, but was not found locally or in the asset store"), - } - } - - while let Some(res) = download_tasks.join_next().await { - res.context("download task panicked")?.context("download task failed")?; - } - - Ok(()) -} - -fn check_sha256(name: &str, asset: &Asset, mut file: std::fs::File) -> anyhow::Result { - let mut bytes = Vec::new(); - file.read_to_end(&mut bytes).with_context(|| format!("hashing file for asset {name}"))?; - let mut file_hash = sha2::Sha256::new(); - file_hash.update(&bytes); - let file_hash = file_hash.finalize(); - let file_hash = format!("{:x}", file_hash); - tracing::debug!(hash = file_hash, "hashed local file"); - - Ok(match &asset.sha256 { - Some(hash) => { - tracing::debug!(hash, "hash from workload"); - if hash.to_ascii_lowercase() == file_hash { - true - } else { - tracing::warn!( - file_hash, - asset_hash = hash.to_ascii_lowercase(), - "hashes don't match" - ); - false - } - } - None => { - tracing::warn!(sha256 = file_hash, "Skipping hash for asset {name} that doesn't have one. Please add it to workload file"); - true - } - }) -} - -#[tracing::instrument(skip(client, asset, name), fields(asset = name))] -async fn download_asset( - client: Client, - name: String, - asset: Asset, - src: String, - dest_filename: String, -) -> anyhow::Result<()> { - let context = || format!("failure downloading asset {name} from {src}"); - - let response = client.get(&src).send().await.with_context(context)?; - - let file = std::fs::File::options() - .create(true) - .truncate(true) - .write(true) - .read(true) - .open(&dest_filename) - .with_context(|| format!("creating destination file {dest_filename}")) - .with_context(context)?; - - let mut dest = std::io::BufWriter::new( - file.try_clone().context("cloning I/O handle").with_context(context)?, - ); - - let total_len: Option = response - .headers() - .get(reqwest::header::CONTENT_LENGTH) - .and_then(|value| value.to_str().ok()) - .and_then(|value| value.parse().ok()); - - let progress = tokio::spawn({ - let name = name.clone(); - async move { - loop { - match file.metadata().context("could not get file metadata") { - Ok(metadata) => { - let len = metadata.len(); - tracing::info!( - asset = name, - downloaded_bytes = len, - total_bytes = total_len, - "asset download in progress" - ); - } - Err(error) => { - tracing::warn!(%error, "could not get file metadata"); - } - } - tokio::time::sleep(std::time::Duration::from_secs(60)).await; - } - } - }); - - let writing_context = || format!("while writing to destination file at {dest_filename}"); - - let mut response = response.bytes_stream(); - - while let Some(bytes) = - response.try_next().await.context("while downloading file").with_context(context)? - { - dest.write_all(&bytes).with_context(writing_context).with_context(context)?; - } - - progress.abort(); - - let mut file = dest.into_inner().with_context(writing_context).with_context(context)?; - - file.rewind().context("while rewinding asset file")?; - - if !check_sha256(&name, &asset, file)? { - bail!("asset '{name}': sha256 mismatch for file {dest_filename} downloaded from {src}") - } - - Ok(()) -} - -#[allow(clippy::too_many_arguments)] // not best code quality, but this is a benchmark runner -#[tracing::instrument(skip(dashboard_client, logs_client, meili_client, workload, master_key, args), fields(workload = %workload.name))] -async fn run_workload_run( - dashboard_client: &Client, - logs_client: &Client, - meili_client: &Client, - workload_uuid: Uuid, - master_key: Option<&str>, - workload: &Workload, - args: &BenchDeriveArgs, - run_number: u16, -) -> anyhow::Result>> { - delete_db(); - build_meilisearch().await?; - let meilisearch = - start_meilisearch(meili_client, master_key, workload, &args.asset_folder).await?; - let processor = run_commands( - dashboard_client, - logs_client, - meili_client, - workload_uuid, - workload, - args, - run_number, - ) - .await?; - - kill_meilisearch(meilisearch).await; - - tracing::info!(run_number, "Successful run"); - - Ok(processor) -} - -async fn kill_meilisearch(mut meilisearch: tokio::process::Child) { - if let Err(error) = meilisearch.kill().await { - tracing::warn!( - error = &error as &dyn std::error::Error, - "while terminating Meilisearch server" - ) - } -} - -#[tracing::instrument] -async fn build_meilisearch() -> anyhow::Result<()> { - let mut command = tokio::process::Command::new("cargo"); - command.arg("build").arg("--release").arg("-p").arg("meilisearch"); - - command.kill_on_drop(true); - - let mut builder = command.spawn().context("error building Meilisearch")?; - - if !builder.wait().await.context("could not build Meilisearch")?.success() { - bail!("failed building Meilisearch") - } - - Ok(()) -} - -#[tracing::instrument(skip(client, master_key, workload), fields(workload = workload.name))] -async fn start_meilisearch( - client: &Client, - master_key: Option<&str>, - workload: &Workload, - asset_folder: &str, -) -> anyhow::Result { - let mut command = tokio::process::Command::new("cargo"); - command - .arg("run") - .arg("--release") - .arg("-p") - .arg("meilisearch") - .arg("--bin") - .arg("meilisearch") - .arg("--"); - - command.arg("--db-path").arg("./_xtask_benchmark.ms"); - if let Some(master_key) = master_key { - command.arg("--master-key").arg(master_key); - } - command.arg("--experimental-enable-logs-route"); - - for extra_arg in workload.extra_cli_args.iter() { - command.arg(extra_arg); - } - - command.kill_on_drop(true); - - let mut meilisearch = command.spawn().context("Error starting Meilisearch")?; - - wait_for_health(client, &mut meilisearch, &workload.assets, asset_folder).await?; - - Ok(meilisearch) -} - -async fn wait_for_health( - client: &Client, - meilisearch: &mut tokio::process::Child, - assets: &BTreeMap, - asset_folder: &str, -) -> anyhow::Result<()> { - for i in 0..100 { - let res = run_command(client.clone(), health_command(), assets, asset_folder).await; - if res.is_ok() { - // check that this is actually the current Meilisearch instance that answered us - if let Some(exit_code) = - meilisearch.try_wait().context("cannot check Meilisearch server process status")? - { - tracing::error!("Got an health response from a different process"); - bail!("Meilisearch server exited early with code {exit_code}"); - } - - return Ok(()); - } - tokio::time::sleep(std::time::Duration::from_millis(500)).await; - // check whether the Meilisearch instance exited early (cut the wait) - if let Some(exit_code) = - meilisearch.try_wait().context("cannot check Meilisearch server process status")? - { - bail!("Meilisearch server exited early with code {exit_code}"); - } - tracing::debug!(attempt = i, "Waiting for Meilisearch to go up"); - } - bail!("meilisearch is not responding") -} - -fn health_command() -> Command { - Command { - route: "/health".into(), - method: Method::GET, - body: Default::default(), - synchronous: SyncMode::WaitForResponse, - } -} - -fn delete_db() { - let _ = std::fs::remove_dir_all("./_xtask_benchmark.ms"); -} - -async fn run_commands( - dashboard_client: &Client, - logs_client: &Client, - meili_client: &Client, - workload_uuid: Uuid, - workload: &Workload, - args: &BenchDeriveArgs, - run_number: u16, -) -> anyhow::Result>> { - let report_folder = &args.report_folder; - let workload_name = &workload.name; - - std::fs::create_dir_all(report_folder) - .with_context(|| format!("could not create report directory at {report_folder}"))?; - - let trace_filename = format!("{report_folder}/{workload_name}-{run_number}-trace.json"); - let report_filename = format!("{report_folder}/{workload_name}-{run_number}-report.json"); - - let report_handle = start_report(logs_client, trace_filename).await?; - - for batch in workload - .commands - .as_slice() - .split_inclusive(|command| !matches!(command.synchronous, SyncMode::DontWait)) - { - run_batch(meili_client, batch, &workload.assets, &args.asset_folder).await?; - } - - let processor = - stop_report(dashboard_client, logs_client, workload_uuid, report_filename, report_handle) - .await?; - - Ok(processor) -} - -async fn stop_report( - dashboard_client: &Client, - logs_client: &Client, - workload_uuid: Uuid, - filename: String, - report_handle: tokio::task::JoinHandle>, -) -> anyhow::Result>> { - let response = logs_client.delete("").send().await.context("while stopping report")?; - if !response.status().is_success() { - bail!("received HTTP {} while stopping report", response.status()) - } - - let mut file = tokio::time::timeout(std::time::Duration::from_secs(1000), report_handle) - .await - .context("while waiting for the end of the report")? - .context("report writing task panicked")? - .context("while writing report")?; - - file.rewind().context("while rewinding report file")?; - - let process_handle = tokio::task::spawn({ - let dashboard_client = dashboard_client.clone(); - async move { - let span = tracing::info_span!("processing trace to report", filename); - let _guard = span.enter(); - let report = tracing_trace::processor::span_stats::to_call_stats( - tracing_trace::TraceReader::new(std::io::BufReader::new(file)), - ) - .context("could not convert trace to report")?; - let context = || format!("writing report to {filename}"); - - let response = dashboard_client - .put("run") - .json(&json!({ - "workload_uuid": workload_uuid, - "data": report - })) - .send() - .await - .context("sending new run")?; - - if !response.status().is_success() { - bail!( - "sending new run failed: {}", - response.text().await.unwrap_or_else(|_| "unknown".into()) - ) - } - - let mut output_file = std::io::BufWriter::new( - std::fs::File::options() - .create(true) - .truncate(true) - .write(true) - .read(true) - .open(&filename) - .with_context(context)?, - ); - - for (key, value) in report { - serde_json::to_writer(&mut output_file, &json!({key: value})) - .context("serializing span stat")?; - writeln!(&mut output_file).with_context(context)?; - } - output_file.flush().with_context(context)?; - let mut output_file = output_file.into_inner().with_context(context)?; - - output_file.rewind().context("could not rewind output_file").with_context(context)?; - - tracing::info!("success"); - Ok(output_file) - } - }); - - Ok(process_handle) -} - -async fn start_report( - logs_client: &Client, - filename: String, -) -> anyhow::Result>> { - let report_file = std::fs::File::options() - .create(true) - .truncate(true) - .write(true) - .read(true) - .open(&filename) - .with_context(|| format!("could not create file at {filename}"))?; - let mut report_file = std::io::BufWriter::new(report_file); - - let response = logs_client - .post("") - .json(&json!({ - "mode": "profile", - "target": "indexing::=trace" - })) - .send() - .await - .context("failed to start report")?; - - let code = response.status(); - if code.is_client_error() { - tracing::error!(%code, "request error when trying to start report"); - let response: serde_json::Value = response - .json() - .await - .context("could not deserialize response as JSON") - .context("response error when trying to start report")?; - bail!( - "request error when trying to start report: server responded with error code {code} and '{response}'" - ) - } else if code.is_server_error() { - tracing::error!(%code, "server error when trying to start report"); - let response: serde_json::Value = response - .json() - .await - .context("could not deserialize response as JSON") - .context("response error trying to start report")?; - bail!("server error when trying to start report: server responded with error code {code} and '{response}'") - } - - Ok(tokio::task::spawn(async move { - let mut stream = response.bytes_stream(); - while let Some(bytes) = stream.try_next().await.context("while waiting for report")? { - report_file - .write_all(&bytes) - .with_context(|| format!("while writing report to {filename}"))?; - } - report_file.into_inner().with_context(|| format!("while writing report to {filename}")) - })) -} - -async fn run_batch( - client: &Client, - batch: &[Command], - assets: &BTreeMap, - asset_folder: &str, -) -> anyhow::Result<()> { - let [.., last] = batch else { return Ok(()) }; - let sync = last.synchronous; - - let mut tasks = tokio::task::JoinSet::new(); - - for command in batch { - // FIXME: you probably don't want to copy assets everytime here - tasks.spawn({ - let client = client.clone(); - let command = command.clone(); - let assets = assets.clone(); - let asset_folder = asset_folder.to_owned(); - - async move { run_command(client, command, &assets, &asset_folder).await } - }); - } - - while let Some(result) = tasks.join_next().await { - result - .context("panicked while executing command")? - .context("error while executing command")?; - } - - match sync { - SyncMode::DontWait => {} - SyncMode::WaitForResponse => {} - SyncMode::WaitForTask => wait_for_tasks(client).await?, - } - - Ok(()) -} - -async fn wait_for_tasks(client: &Client) -> anyhow::Result<()> { - loop { - let response = client - .get("tasks?statuses=enqueued,processing") - .send() - .await - .context("could not wait for tasks")?; - let response: serde_json::Value = response - .json() - .await - .context("could not deserialize response to JSON") - .context("could not wait for tasks")?; - match response.get("total") { - Some(serde_json::Value::Number(number)) => { - let number = number.as_u64().with_context(|| { - format!("waiting for tasks: could not parse 'total' as integer, got {}", number) - })?; - if number == 0 { - break; - } else { - tokio::time::sleep(std::time::Duration::from_secs(1)).await; - continue; - } - } - Some(thing_else) => { - bail!(format!( - "waiting for tasks: could not parse 'total' as a number, got '{thing_else}'" - )) - } - None => { - bail!(format!( - "waiting for tasks: expected response to contain 'total', got '{response}'" - )) - } - } - } - Ok(()) -} - -#[tracing::instrument(skip(client, command, assets, asset_folder), fields(command = %command))] -async fn run_command( - client: Client, - mut command: Command, - assets: &BTreeMap, - asset_folder: &str, -) -> anyhow::Result<()> { - // memtake the body here to leave an empty body in its place, so that command is not partially moved-out - let body = std::mem::take(&mut command.body) - .get(assets, asset_folder) - .with_context(|| format!("while getting body for command {command}"))?; - - let request = client.request(command.method.into(), &command.route); - - let request = if let Some((body, content_type)) = body { - request.body(body).header(reqwest::header::CONTENT_TYPE, content_type) - } else { - request - }; - - let response = - request.send().await.with_context(|| format!("error sending command: {}", command))?; - - let code = response.status(); - if code.is_client_error() { - tracing::error!(%command, %code, "error in workload file"); - let response: serde_json::Value = response - .json() - .await - .context("could not deserialize response as JSON") - .context("parsing error in workload file when sending command")?; - bail!("error in workload file: server responded with error code {code} and '{response}'") - } else if code.is_server_error() { - tracing::error!(%command, %code, "server error"); - let response: serde_json::Value = response - .json() - .await - .context("could not deserialize response as JSON") - .context("parsing server error when sending command")?; - bail!("server error: server responded with error code {code} and '{response}'") - } - - Ok(()) -} diff --git a/xtask/src/bench/workload.rs b/xtask/src/bench/workload.rs new file mode 100644 index 000000000..b3e952f29 --- /dev/null +++ b/xtask/src/bench/workload.rs @@ -0,0 +1,262 @@ +use std::collections::BTreeMap; +use std::fs::File; +use std::io::{Seek as _, Write as _}; + +use anyhow::{bail, Context as _}; +use futures_util::TryStreamExt as _; +use serde::Deserialize; +use serde_json::json; +use tokio::task::JoinHandle; +use uuid::Uuid; + +use super::assets::Asset; +use super::client::Client; +use super::command::SyncMode; +use super::BenchDeriveArgs; +use crate::bench::{assets, dashboard, meili_process}; + +#[derive(Deserialize)] +pub struct Workload { + pub name: String, + pub run_count: u16, + pub extra_cli_args: Vec, + pub assets: BTreeMap, + pub commands: Vec, +} + +async fn run_commands( + dashboard_client: &Client, + logs_client: &Client, + meili_client: &Client, + workload_uuid: Uuid, + workload: &Workload, + args: &BenchDeriveArgs, + run_number: u16, +) -> anyhow::Result>> { + let report_folder = &args.report_folder; + let workload_name = &workload.name; + + std::fs::create_dir_all(report_folder) + .with_context(|| format!("could not create report directory at {report_folder}"))?; + + let trace_filename = format!("{report_folder}/{workload_name}-{run_number}-trace.json"); + let report_filename = format!("{report_folder}/{workload_name}-{run_number}-report.json"); + + let report_handle = start_report(logs_client, trace_filename).await?; + + for batch in workload + .commands + .as_slice() + .split_inclusive(|command| !matches!(command.synchronous, SyncMode::DontWait)) + { + super::command::run_batch(meili_client, batch, &workload.assets, &args.asset_folder) + .await?; + } + + let processor = + stop_report(dashboard_client, logs_client, workload_uuid, report_filename, report_handle) + .await?; + + Ok(processor) +} + +#[allow(clippy::too_many_arguments)] // not best code quality, but this is a benchmark runner +#[tracing::instrument(skip(assets_client, dashboard_client, logs_client, meili_client, workload, master_key, args), fields(workload = workload.name))] +pub async fn execute( + assets_client: &Client, + dashboard_client: &Client, + logs_client: &Client, + meili_client: &Client, + invocation_uuid: Uuid, + master_key: Option<&str>, + workload: Workload, + args: &BenchDeriveArgs, +) -> anyhow::Result<()> { + assets::fetch_assets(assets_client, &workload.assets, &args.asset_folder).await?; + + let workload_uuid = + dashboard::create_workload(dashboard_client, invocation_uuid, &workload).await?; + + let mut tasks = Vec::new(); + + for i in 0..workload.run_count { + tasks.push( + execute_run( + dashboard_client, + logs_client, + meili_client, + workload_uuid, + master_key, + &workload, + args, + i, + ) + .await?, + ); + } + + let mut reports = Vec::with_capacity(workload.run_count as usize); + + for task in tasks { + reports.push( + task.await + .context("task panicked while processing report")? + .context("task failed while processing report")?, + ); + } + + tracing::info!(workload = workload.name, "Successful workload"); + + Ok(()) +} + +#[allow(clippy::too_many_arguments)] // not best code quality, but this is a benchmark runner +#[tracing::instrument(skip(dashboard_client, logs_client, meili_client, workload, master_key, args), fields(workload = %workload.name))] +async fn execute_run( + dashboard_client: &Client, + logs_client: &Client, + meili_client: &Client, + workload_uuid: Uuid, + master_key: Option<&str>, + workload: &Workload, + args: &BenchDeriveArgs, + run_number: u16, +) -> anyhow::Result>> { + meili_process::delete_db(); + + meili_process::build().await?; + let meilisearch = + meili_process::start(meili_client, master_key, workload, &args.asset_folder).await?; + + let processor = run_commands( + dashboard_client, + logs_client, + meili_client, + workload_uuid, + workload, + args, + run_number, + ) + .await?; + + meili_process::kill(meilisearch).await; + + tracing::info!(run_number, "Successful run"); + + Ok(processor) +} + +async fn start_report( + logs_client: &Client, + filename: String, +) -> anyhow::Result>> { + let report_file = std::fs::File::options() + .create(true) + .truncate(true) + .write(true) + .read(true) + .open(&filename) + .with_context(|| format!("could not create file at {filename}"))?; + let mut report_file = std::io::BufWriter::new(report_file); + + let response = logs_client + .post("") + .json(&json!({ + "mode": "profile", + "target": "indexing::=trace" + })) + .send() + .await + .context("failed to start report")?; + + let code = response.status(); + if code.is_client_error() { + tracing::error!(%code, "request error when trying to start report"); + let response: serde_json::Value = response + .json() + .await + .context("could not deserialize response as JSON") + .context("response error when trying to start report")?; + bail!( + "request error when trying to start report: server responded with error code {code} and '{response}'" + ) + } else if code.is_server_error() { + tracing::error!(%code, "server error when trying to start report"); + let response: serde_json::Value = response + .json() + .await + .context("could not deserialize response as JSON") + .context("response error trying to start report")?; + bail!("server error when trying to start report: server responded with error code {code} and '{response}'") + } + + Ok(tokio::task::spawn(async move { + let mut stream = response.bytes_stream(); + while let Some(bytes) = stream.try_next().await.context("while waiting for report")? { + report_file + .write_all(&bytes) + .with_context(|| format!("while writing report to {filename}"))?; + } + report_file.into_inner().with_context(|| format!("while writing report to {filename}")) + })) +} + +async fn stop_report( + dashboard_client: &Client, + logs_client: &Client, + workload_uuid: Uuid, + filename: String, + report_handle: tokio::task::JoinHandle>, +) -> anyhow::Result>> { + let response = logs_client.delete("").send().await.context("while stopping report")?; + if !response.status().is_success() { + bail!("received HTTP {} while stopping report", response.status()) + } + + let mut file = tokio::time::timeout(std::time::Duration::from_secs(1000), report_handle) + .await + .context("while waiting for the end of the report")? + .context("report writing task panicked")? + .context("while writing report")?; + + file.rewind().context("while rewinding report file")?; + + let process_handle = tokio::task::spawn({ + let dashboard_client = dashboard_client.clone(); + async move { + let span = tracing::info_span!("processing trace to report", filename); + let _guard = span.enter(); + let report = tracing_trace::processor::span_stats::to_call_stats( + tracing_trace::TraceReader::new(std::io::BufReader::new(file)), + ) + .context("could not convert trace to report")?; + let context = || format!("writing report to {filename}"); + + dashboard::create_run(dashboard_client, workload_uuid, &report).await?; + + let mut output_file = std::io::BufWriter::new( + std::fs::File::options() + .create(true) + .truncate(true) + .write(true) + .read(true) + .open(&filename) + .with_context(context)?, + ); + + for (key, value) in report { + serde_json::to_writer(&mut output_file, &json!({key: value})) + .context("serializing span stat")?; + writeln!(&mut output_file).with_context(context)?; + } + output_file.flush().with_context(context)?; + let mut output_file = output_file.into_inner().with_context(context)?; + + output_file.rewind().context("could not rewind output_file").with_context(context)?; + + Ok(output_file) + } + }); + + Ok(process_handle) +}