mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-01-26 05:07:28 +01:00
251 lines
8.5 KiB
Rust
251 lines
8.5 KiB
Rust
use std::collections::BTreeMap;
|
|
use std::io::{Read as _, Seek as _, Write as _};
|
|
|
|
use anyhow::{bail, Context};
|
|
use futures_util::TryStreamExt as _;
|
|
use serde::Deserialize;
|
|
use sha2::Digest;
|
|
|
|
use super::client::Client;
|
|
|
|
#[derive(Deserialize, Clone)]
|
|
pub struct Asset {
|
|
pub local_location: Option<String>,
|
|
pub remote_location: Option<String>,
|
|
#[serde(default)]
|
|
pub format: AssetFormat,
|
|
pub sha256: Option<String>,
|
|
}
|
|
|
|
#[derive(Deserialize, Default, Copy, Clone)]
|
|
pub enum AssetFormat {
|
|
#[default]
|
|
Auto,
|
|
Json,
|
|
NdJson,
|
|
Raw,
|
|
}
|
|
|
|
impl AssetFormat {
|
|
pub fn to_content_type(self, filename: &str) -> &'static str {
|
|
match self {
|
|
AssetFormat::Auto => Self::auto_detect(filename).to_content_type(filename),
|
|
AssetFormat::Json => "application/json",
|
|
AssetFormat::NdJson => "application/x-ndjson",
|
|
AssetFormat::Raw => "application/octet-stream",
|
|
}
|
|
}
|
|
|
|
fn auto_detect(filename: &str) -> Self {
|
|
let path = std::path::Path::new(filename);
|
|
match path.extension().and_then(|extension| extension.to_str()) {
|
|
Some(extension) if extension.eq_ignore_ascii_case("json") => Self::Json,
|
|
Some(extension) if extension.eq_ignore_ascii_case("ndjson") => Self::NdJson,
|
|
extension => {
|
|
tracing::warn!(asset = filename, ?extension, "asset has format `Auto`, but extension was not recognized. Specify `Raw` format to suppress this warning.");
|
|
AssetFormat::Raw
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn fetch_asset(
|
|
name: &str,
|
|
assets: &BTreeMap<String, Asset>,
|
|
asset_folder: &str,
|
|
) -> anyhow::Result<(std::fs::File, AssetFormat)> {
|
|
let asset =
|
|
assets.get(name).with_context(|| format!("could not find asset with name '{name}'"))?;
|
|
let filename = if let Some(local_filename) = &asset.local_location {
|
|
local_filename.clone()
|
|
} else {
|
|
format!("{asset_folder}/{name}")
|
|
};
|
|
|
|
Ok((
|
|
std::fs::File::open(&filename)
|
|
.with_context(|| format!("could not open asset '{name}' at '{filename}'"))?,
|
|
asset.format,
|
|
))
|
|
}
|
|
|
|
#[tracing::instrument(skip(client, assets), fields(asset_count = assets.len()))]
|
|
pub async fn fetch_assets(
|
|
client: &Client,
|
|
assets: &BTreeMap<String, Asset>,
|
|
asset_folder: &str,
|
|
) -> anyhow::Result<()> {
|
|
let mut download_tasks = tokio::task::JoinSet::new();
|
|
for (name, asset) in assets {
|
|
// trying local
|
|
if let Some(local) = &asset.local_location {
|
|
match std::fs::File::open(local) {
|
|
Ok(file) => {
|
|
if check_sha256(name, asset, file)? {
|
|
continue;
|
|
} else {
|
|
tracing::warn!(asset = name, file = local, "found local resource for asset but hash differed, skipping to asset store");
|
|
}
|
|
}
|
|
Err(error) => match error.kind() {
|
|
std::io::ErrorKind::NotFound => { /* file does not exist, go to remote, no need for logs */
|
|
}
|
|
_ => tracing::warn!(
|
|
error = &error as &dyn std::error::Error,
|
|
"error checking local resource, skipping to asset store"
|
|
),
|
|
},
|
|
}
|
|
}
|
|
|
|
// checking asset store
|
|
let store_filename = format!("{}/{}", asset_folder, name);
|
|
|
|
match std::fs::File::open(&store_filename) {
|
|
Ok(file) => {
|
|
if check_sha256(name, asset, file)? {
|
|
continue;
|
|
} else {
|
|
tracing::warn!(asset = name, file = store_filename, "found resource for asset in asset store, but hash differed, skipping to remote method");
|
|
}
|
|
}
|
|
Err(error) => match error.kind() {
|
|
std::io::ErrorKind::NotFound => { /* file does not exist, go to remote, no need for logs */
|
|
}
|
|
_ => tracing::warn!(
|
|
error = &error as &dyn std::error::Error,
|
|
"error checking resource in store, skipping to remote method"
|
|
),
|
|
},
|
|
}
|
|
|
|
// downloading remote
|
|
match &asset.remote_location {
|
|
Some(location) => {
|
|
std::fs::create_dir_all(asset_folder).with_context(|| format!("could not create asset folder at {asset_folder}"))?;
|
|
download_tasks.spawn({
|
|
let client = client.clone();
|
|
let name = name.to_string();
|
|
let location = location.to_string();
|
|
let store_filename = store_filename.clone();
|
|
let asset = asset.clone();
|
|
download_asset(client, name, asset, location, store_filename)});
|
|
},
|
|
None => bail!("asset {name} has no remote location, but was not found locally or in the asset store"),
|
|
}
|
|
}
|
|
|
|
while let Some(res) = download_tasks.join_next().await {
|
|
res.context("download task panicked")?.context("download task failed")?;
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn check_sha256(name: &str, asset: &Asset, mut file: std::fs::File) -> anyhow::Result<bool> {
|
|
let mut bytes = Vec::new();
|
|
file.read_to_end(&mut bytes).with_context(|| format!("hashing file for asset {name}"))?;
|
|
let mut file_hash = sha2::Sha256::new();
|
|
file_hash.update(&bytes);
|
|
let file_hash = file_hash.finalize();
|
|
let file_hash = format!("{:x}", file_hash);
|
|
tracing::debug!(hash = file_hash, "hashed local file");
|
|
|
|
Ok(match &asset.sha256 {
|
|
Some(hash) => {
|
|
tracing::debug!(hash, "hash from workload");
|
|
if hash.to_ascii_lowercase() == file_hash {
|
|
true
|
|
} else {
|
|
tracing::warn!(
|
|
file_hash,
|
|
asset_hash = hash.to_ascii_lowercase(),
|
|
"hashes don't match"
|
|
);
|
|
false
|
|
}
|
|
}
|
|
None => {
|
|
tracing::warn!(sha256 = file_hash, "Skipping hash for asset {name} that doesn't have one. Please add it to workload file");
|
|
true
|
|
}
|
|
})
|
|
}
|
|
|
|
#[tracing::instrument(skip(client, asset, name), fields(asset = name))]
|
|
async fn download_asset(
|
|
client: Client,
|
|
name: String,
|
|
asset: Asset,
|
|
src: String,
|
|
dest_filename: String,
|
|
) -> anyhow::Result<()> {
|
|
let context = || format!("failure downloading asset {name} from {src}");
|
|
|
|
let response = client.get(&src).send().await.with_context(context)?;
|
|
|
|
let file = std::fs::File::options()
|
|
.create(true)
|
|
.truncate(true)
|
|
.write(true)
|
|
.read(true)
|
|
.open(&dest_filename)
|
|
.with_context(|| format!("creating destination file {dest_filename}"))
|
|
.with_context(context)?;
|
|
|
|
let mut dest = std::io::BufWriter::new(
|
|
file.try_clone().context("cloning I/O handle").with_context(context)?,
|
|
);
|
|
|
|
let total_len: Option<u64> = response
|
|
.headers()
|
|
.get(reqwest::header::CONTENT_LENGTH)
|
|
.and_then(|value| value.to_str().ok())
|
|
.and_then(|value| value.parse().ok());
|
|
|
|
let progress = tokio::spawn({
|
|
let name = name.clone();
|
|
async move {
|
|
loop {
|
|
match file.metadata().context("could not get file metadata") {
|
|
Ok(metadata) => {
|
|
let len = metadata.len();
|
|
tracing::info!(
|
|
asset = name,
|
|
downloaded_bytes = len,
|
|
total_bytes = total_len,
|
|
"asset download in progress"
|
|
);
|
|
}
|
|
Err(error) => {
|
|
tracing::warn!(%error, "could not get file metadata");
|
|
}
|
|
}
|
|
tokio::time::sleep(std::time::Duration::from_secs(60)).await;
|
|
}
|
|
}
|
|
});
|
|
|
|
let writing_context = || format!("while writing to destination file at {dest_filename}");
|
|
|
|
let mut response = response.bytes_stream();
|
|
|
|
while let Some(bytes) =
|
|
response.try_next().await.context("while downloading file").with_context(context)?
|
|
{
|
|
dest.write_all(&bytes).with_context(writing_context).with_context(context)?;
|
|
}
|
|
|
|
progress.abort();
|
|
|
|
let mut file = dest.into_inner().with_context(writing_context).with_context(context)?;
|
|
|
|
file.rewind().context("while rewinding asset file")?;
|
|
|
|
if !check_sha256(&name, &asset, file)? {
|
|
bail!("asset '{name}': sha256 mismatch for file {dest_filename} downloaded from {src}")
|
|
}
|
|
|
|
Ok(())
|
|
}
|