diff --git a/Cargo.lock b/Cargo.lock index e4789da4a..e2069db87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -489,6 +489,11 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bbqueue" +version = "0.5.1" +source = "git+https://github.com/kerollmops/bbqueue#cbb87cc707b5af415ef203bdaf2443e06ba0d6d4" + [[package]] name = "benchmarks" version = "1.12.0" @@ -3611,6 +3616,7 @@ version = "1.12.0" dependencies = [ "allocator-api2", "arroy 0.5.0 (registry+https://github.com/rust-lang/crates.io-index)", + "bbqueue", "big_s", "bimap", "bincode", @@ -3623,6 +3629,7 @@ dependencies = [ "candle-transformers", "charabia", "concat-arrays", + "crossbeam", "crossbeam-channel", "csv", "deserr", diff --git a/crates/milli/Cargo.toml b/crates/milli/Cargo.toml index a0bd86a42..798a4ea19 100644 --- a/crates/milli/Cargo.toml +++ b/crates/milli/Cargo.toml @@ -98,6 +98,8 @@ allocator-api2 = "0.2.18" rustc-hash = "2.0.0" uell = "0.1.0" enum-iterator = "2.1.0" +bbqueue = { git = "https://github.com/kerollmops/bbqueue" } +crossbeam = "0.8.4" [dev-dependencies] mimalloc = { version = "0.1.43", default-features = false } diff --git a/crates/milli/src/update/new/channel.rs b/crates/milli/src/update/new/channel.rs index 00b471b52..21cd6b87d 100644 --- a/crates/milli/src/update/new/channel.rs +++ b/crates/milli/src/update/new/channel.rs @@ -1,6 +1,7 @@ use std::marker::PhantomData; use std::sync::atomic::{AtomicUsize, Ordering}; +use crossbeam::sync::{Parker, Unparker}; use crossbeam_channel::{IntoIter, Receiver, SendError, Sender}; use heed::types::Bytes; use heed::BytesDecode; @@ -8,6 +9,7 @@ use memmap2::Mmap; use roaring::RoaringBitmap; use super::extract::FacetKind; +use super::thread_local::{FullySend, ThreadLocal}; use super::StdResult; use crate::heed_codec::facet::{FieldDocIdFacetF64Codec, FieldDocIdFacetStringCodec}; use 
crate::index::main_key::{GEO_FACETED_DOCUMENTS_IDS_KEY, GEO_RTREE_KEY}; @@ -16,6 +18,50 @@ use crate::update::new::KvReaderFieldId; use crate::vector::Embedding; use crate::{DocumentId, Index}; +/// Creates the sender/receiver pair to be used by +/// the extractors and the writer loop. +/// +/// # Panics +/// +/// Panics if the number of provided `BBBuffer`s is not exactly equal to the number +/// of threads in the rayon threadpool, or if one of the buffers fails to split. +pub fn extractor_writer_bbqueue( + bbqueue: &[bbqueue::BBBuffer], +) -> (ExtractorBbqueueSender, WriterBbqueueReceiver) { + assert_eq!( + bbqueue.len(), + rayon::current_num_threads(), + "You must provide as many BBBuffer as the available number of threads to extract" + ); + + let parker = Parker::new(); + let extractors = ThreadLocal::with_capacity(bbqueue.len()); + let consumers = rayon::broadcast(|bi| { + let bbqueue = &bbqueue[bi.index()]; + let (producer, consumer) = bbqueue.try_split_framed().unwrap(); + extractors.get_or(|| FullySend(producer)); + consumer + }); + + ( + ExtractorBbqueueSender { inner: extractors, unparker: parker.unparker().clone() }, + WriterBbqueueReceiver { inner: consumers, parker }, + ) +} + +pub struct ExtractorBbqueueSender<'a> { + inner: ThreadLocal<FullySend<bbqueue::framed::FrameProducer<'a>>>, + /// Used to wake up the receiver thread; + /// signalled every time we write something in the producer. + unparker: Unparker, +} + +pub struct WriterBbqueueReceiver<'a> { + inner: Vec<bbqueue::framed::FrameConsumer<'a>>, + /// Used to park when no more work is required + parker: Parker, +} + /// The capacity of the channel is currently in number of messages. pub fn extractor_writer_channel(cap: usize) -> (ExtractorSender, WriterReceiver) { let (sender, receiver) = crossbeam_channel::bounded(cap);