diff --git a/crates/meilitool/src/main.rs b/crates/meilitool/src/main.rs index 8a8b774b8..b4787c479 100644 --- a/crates/meilitool/src/main.rs +++ b/crates/meilitool/src/main.rs @@ -1,6 +1,10 @@ use std::fs::{read_dir, read_to_string, remove_file, File}; +use std::hint::black_box; use std::io::{BufWriter, Write as _}; +use std::ops::Bound; use std::path::PathBuf; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::thread; use std::time::Instant; use anyhow::{bail, Context}; @@ -640,28 +644,70 @@ fn hair_dryer( for part in index_parts { match part { IndexPart::Arroy => { - let mut count = 0; - let total = index.vector_arroy.len(&rtxn)?; - eprintln!("Hair drying arroy for {uid}..."); - for (i, result) in index - .vector_arroy - .remap_types::() - .iter(&rtxn)? - .enumerate() - { - let (key, value) = result?; + // It would be better if it is a command parameter + let total_threads = thread::available_parallelism().unwrap().get() * 10; + eprintln!("Hair drying arroy for {uid} using {total_threads} threads..."); - // All of this just to avoid compiler optimizations 🤞 - // We must read all the bytes to make the pages hot in cache. - // - count += std::hint::black_box(key.iter().fold(0, |acc, _| acc + 1)); - count += std::hint::black_box(value.iter().fold(0, |acc, _| acc + 1)); + let database = index.vector_arroy.remap_types::(); + let num_keys = database.len(&rtxn)? as usize; + let first_entry = database.iter(&rtxn)?.next().transpose()?; + let last_entry = database.rev_iter(&rtxn)?.next().transpose()?; + let keys_by_thread = num_keys / total_threads; - if i % 10_000 == 0 { - let perc = (i as f64) / (total as f64) * 100.0; - eprintln!("Visited {i}/{total} ({perc:.2}%) keys") + let Some(((first_key, _), (last_key, _))) = first_entry.zip(last_entry) + else { + continue; + }; + + let first_key_num = first_key.try_into().map(u64::from_be_bytes).unwrap(); + let last_key_num = last_key.try_into().map(u64::from_be_bytes).unwrap(); + + eprintln!("between {first_key_num:x} and {last_key_num:x}"); + eprintln!("Iterating over {keys_by_thread} entries by thread..."); + + let progress = AtomicUsize::new(0); + let count = thread::scope(|s| -> anyhow::Result { + let mut handles = Vec::new(); + + for tid in 0..total_threads { + let index = &index; + let progress = &progress; + let handle = s.spawn(move || -> anyhow::Result { + let rtxn = index.read_txn()?; + let start = first_key_num + (keys_by_thread * tid) as u64; + let start_bytes = start.to_be_bytes(); + let range = (Bound::Included(&start_bytes[..]), Bound::Unbounded); + + let mut count: usize = 0; + for result in database.range(&rtxn, &range)?.take(keys_by_thread) { + let (key, value) = result?; + + // All of this just to avoid compiler optimizations 🤞 + // We must read all the bytes to make the pages hot in cache. + // + count += black_box(key.iter().fold(0, |acc, _| acc + 1)); + count += black_box(value.iter().fold(0, |acc, _| acc + 1)); + + let current_progress = progress.fetch_add(1, Ordering::Relaxed); + if current_progress % 10_000 == 0 { + let perc = (current_progress as f64) / (num_keys as f64) * 100.0; + eprintln!("Visited {current_progress}/{num_keys} ({perc:.2}%) keys"); + } + } + + Ok(count) + }); + + handles.push(handle); } - } + + let mut count = 0usize; + for handle in handles { + count += handle.join().unwrap()?; + } + Ok(count) + })?; + eprintln!("Done hair drying a total of at least {count} bytes."); } }