From acdd5aa6ea143b2b92079e50cc0e22afeebee570 Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Thu, 12 Dec 2024 17:54:28 +0100 Subject: [PATCH] Use the thread source id instead of the destination id when filtering on the cache to merge --- crates/milli/src/update/new/extract/cache.rs | 44 ++++++++++++-------- 1 file changed, 27 insertions(+), 17 deletions(-) diff --git a/crates/milli/src/update/new/extract/cache.rs b/crates/milli/src/update/new/extract/cache.rs index 62c00d2b1..e2c8bb5fe 100644 --- a/crates/milli/src/update/new/extract/cache.rs +++ b/crates/milli/src/update/new/extract/cache.rs @@ -177,12 +177,12 @@ impl<'extractor> BalancedCaches<'extractor> { Ok(()) } - pub fn freeze(&mut self) -> Result>> { + pub fn freeze(&mut self, source_id: usize) -> Result>> { match &mut self.caches { InnerCaches::Normal(NormalCaches { caches }) => caches .iter_mut() .enumerate() - .map(|(bucket, map)| { + .map(|(bucket_id, map)| { // safety: we are transmuting the Bbbul into a FrozenBbbul // that are the same size. let map = unsafe { @@ -201,14 +201,19 @@ impl<'extractor> BalancedCaches<'extractor> { >, >(map) }; - Ok(FrozenCache { bucket, cache: FrozenMap::new(map), spilled: Vec::new() }) + Ok(FrozenCache { + source_id, + bucket_id, + cache: FrozenMap::new(map), + spilled: Vec::new(), + }) }) .collect(), InnerCaches::Spilling(SpillingCaches { caches, spilled_entries, .. }) => caches .iter_mut() .zip(mem::take(spilled_entries)) .enumerate() - .map(|(bucket, (map, sorter))| { + .map(|(bucket_id, (map, sorter))| { let spilled = sorter .into_reader_cursors()? .into_iter() @@ -234,7 +239,7 @@ impl<'extractor> BalancedCaches<'extractor> { >, >(map) }; - Ok(FrozenCache { bucket, cache: FrozenMap::new(map), spilled }) + Ok(FrozenCache { source_id, bucket_id, cache: FrozenMap::new(map), spilled }) }) .collect(), } @@ -440,7 +445,8 @@ fn spill_entry_to_sorter( } pub struct FrozenCache<'a, 'extractor> { - bucket: usize, + bucket_id: usize, + source_id: usize, cache: FrozenMap< 'a, 'extractor, @@ -457,9 +463,9 @@ pub fn transpose_and_freeze_caches<'a, 'extractor>( let width = caches.first().map(BalancedCaches::buckets).unwrap_or(0); let mut bucket_caches: Vec<_> = iter::repeat_with(Vec::new).take(width).collect(); - for thread_cache in caches { - for frozen in thread_cache.freeze()? { - bucket_caches[frozen.bucket].push(frozen); + for (thread_index, thread_cache) in caches.iter_mut().enumerate() { + for frozen in thread_cache.freeze(thread_index)? { + bucket_caches[frozen.bucket_id].push(frozen); } } @@ -479,13 +485,13 @@ where let mut maps = Vec::new(); let mut heap = BinaryHeap::new(); let mut current_bucket = None; - for FrozenCache { bucket, cache, spilled } in frozen { - assert_eq!(*current_bucket.get_or_insert(bucket), bucket); - maps.push((bucket, cache)); + for FrozenCache { source_id, bucket_id, cache, spilled } in frozen { + assert_eq!(*current_bucket.get_or_insert(bucket_id), bucket_id); + maps.push((source_id, cache)); for reader in spilled { let mut cursor = reader.into_cursor()?; if cursor.move_on_next()?.is_some() { - heap.push(Entry { cursor, bucket }); + heap.push(Entry { cursor, source_id }); } } } @@ -520,8 +526,12 @@ where // Once we merged all of the spilled bitmaps we must also // fetch the entries from the non-spilled entries (the HashMaps). - for (map_bucket, map) in maps.iter_mut() { - if first_entry.bucket != *map_bucket { + for (source_id, map) in maps.iter_mut() { + debug_assert!( + !(map.get(first_key).is_some() && first_entry.source_id == *source_id), + "A thread should not have spiled a key that has been inserted in the cache" + ); + if first_entry.source_id != *source_id { if let Some(new) = map.get_mut(first_key) { output.union_and_clear_bbbul(new); } @@ -564,14 +574,14 @@ where struct Entry { cursor: ReaderCursor, - bucket: usize, + source_id: usize, } impl Ord for Entry { fn cmp(&self, other: &Entry) -> Ordering { let skey = self.cursor.current().map(|(k, _)| k); let okey = other.cursor.current().map(|(k, _)| k); - skey.cmp(&okey).then(self.bucket.cmp(&other.bucket)).reverse() + skey.cmp(&okey).then(self.source_id.cmp(&other.source_id)).reverse() } }