Add some security when merging buckets together

Clément Renault 2024-10-29 10:21:21 +01:00
parent 9fcf51dcc6
commit fdfad0c3c1
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F


@@ -69,6 +69,7 @@ use crate::{CboRoaringBitmapCodec, Result};
 pub struct CboCachedSorter<'extractor> {
     hasher: DefaultHashBuilder,
     alloc: &'extractor Bump,
+    max_memory: usize,
     caches: InnerCaches<'extractor>,
 }
@@ -81,6 +82,7 @@ impl<'extractor> CboCachedSorter<'extractor> {
     pub fn new_in(buckets: usize, max_memory: usize, alloc: &'extractor Bump) -> Self {
         Self {
             hasher: DefaultHashBuilder::default(),
+            max_memory,
             caches: InnerCaches::Normal(NormalCaches {
                 caches: iter::repeat_with(|| HashMap::new_in(alloc)).take(buckets).collect(),
             }),
@@ -95,7 +97,11 @@ impl<'extractor> CboCachedSorter<'extractor> {
         }
     }
 
-    pub fn insert_del_u32(&mut self, key: &[u8], n: u32) -> grenad::Result<(), crate::Error> {
+    pub fn insert_del_u32(&mut self, key: &[u8], n: u32) -> Result<()> {
+        if self.alloc.allocated_bytes() >= self.max_memory {
+            self.start_spilling()?;
+        }
+
         let buckets = self.buckets();
         match &mut self.caches {
             InnerCaches::Normal(normal) => {
@@ -108,7 +114,11 @@ impl<'extractor> CboCachedSorter<'extractor> {
         }
     }
 
-    pub fn insert_add_u32(&mut self, key: &[u8], n: u32) -> grenad::Result<(), crate::Error> {
+    pub fn insert_add_u32(&mut self, key: &[u8], n: u32) -> Result<()> {
+        if self.alloc.allocated_bytes() >= self.max_memory {
+            self.start_spilling()?;
+        }
+
         let buckets = self.buckets();
         match &mut self.caches {
             InnerCaches::Normal(normal) => {
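
Note on the two hunks above: insert_del_u32 and insert_add_u32 now share the same guard, and spilling is triggered as soon as the bump allocator reports that the budget is spent, before the new entry is stored. A minimal, self-contained sketch of that budget-then-spill shape (BudgetedCache and its byte accounting are illustrative stand-ins; the real code checks Bump::allocated_bytes and spills into grenad sorters):

use std::collections::BTreeMap;
use std::mem;

struct BudgetedCache {
    max_memory: usize,
    in_memory_bytes: usize,
    entries: BTreeMap<Vec<u8>, u32>,
    spilled: Vec<(Vec<u8>, u32)>, // stand-in for the on-disk grenad sorter
}

impl BudgetedCache {
    fn insert(&mut self, key: &[u8], n: u32) {
        // Same shape as the new guard: check the budget before growing.
        if self.in_memory_bytes >= self.max_memory {
            self.start_spilling();
        }
        self.in_memory_bytes += key.len() + mem::size_of::<u32>();
        self.entries.insert(key.to_vec(), n);
    }

    fn start_spilling(&mut self) {
        // Move every in-memory entry to the spill area; unlike this sketch,
        // the real cache keeps writing new unknown keys to disk afterwards.
        self.spilled.extend(mem::take(&mut self.entries));
        self.in_memory_bytes = 0;
    }
}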
@@ -121,8 +131,10 @@ impl<'extractor> CboCachedSorter<'extractor> {
         }
     }
 
-    pub fn spill_only(&mut self) -> grenad::Result<()> {
-        let CboCachedSorter { hasher: _, alloc: _, caches } = self;
+    /// Makes sure the cache is no longer allocating data
+    /// and writes every new and unknown entry to disk.
+    fn start_spilling(&mut self) -> Result<()> {
+        let CboCachedSorter { hasher: _, alloc: _, max_memory: _, caches } = self;
 
         if let InnerCaches::Normal(normal_caches) = caches {
             let dummy = NormalCaches { caches: Vec::new() };
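
The dummy value above hints at how the transition is made: the Normal maps have to be moved out of the enum before a Spilling variant can be built from them. A sketch of that take-and-swap pattern, with illustrative types in place of NormalCaches/SpillingCaches:

use std::mem;

enum Caches {
    Normal(Vec<String>),
    Spilling(Vec<String>),
}

fn start_spilling(caches: &mut Caches) {
    if let Caches::Normal(maps) = caches {
        // Take the maps, leaving an empty (dummy) Vec behind,
        // then rebuild the enum as the Spilling variant.
        let maps = mem::take(maps);
        *caches = Caches::Spilling(maps);
    }
}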
@@ -137,22 +149,24 @@ impl<'extractor> CboCachedSorter<'extractor> {
         match &mut self.caches {
             InnerCaches::Normal(NormalCaches { caches }) => caches
                 .iter_mut()
-                .map(|map| {
+                .enumerate()
+                .map(|(bucket, map)| {
                     let file = tempfile::tempfile()?;
                     let writer = create_writer(CompressionType::None, None, file);
                     let spilled = writer_into_reader(writer)?;
-                    Ok(FrozenCache { cache: FrozenMap::new(map), spilled })
+                    Ok(FrozenCache { bucket, cache: FrozenMap::new(map), spilled })
                 })
                 .collect(),
             InnerCaches::Spilling(SpillingCaches { caches, spilled_entries, .. }) => caches
                 .iter_mut()
                 .zip(mem::take(spilled_entries))
-                .map(|(map, sorter)| {
+                .enumerate()
+                .map(|(bucket, (map, sorter))| {
                     let file = tempfile::tempfile()?;
                     let mut writer = create_writer(CompressionType::None, None, file);
                     sorter.write_into_stream_writer(&mut writer)?;
                     let spilled = writer_into_reader(writer)?;
-                    Ok(FrozenCache { cache: FrozenMap::new(map), spilled })
+                    Ok(FrozenCache { bucket, cache: FrozenMap::new(map), spilled })
                 })
                 .collect(),
         }
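
Freezing now records where each cache came from: the position in the vector is the bucket id, so enumerate() is enough to tag every FrozenCache. The same tagging step in isolation (Frozen is a stand-in for FrozenCache, whose real payload is a FrozenMap plus a grenad reader):

struct Frozen<T> {
    bucket: usize,
    cache: T,
}

fn freeze<T>(caches: Vec<T>) -> Vec<Frozen<T>> {
    caches
        .into_iter()
        .enumerate() // the position in the vector is the bucket id
        .map(|(bucket, cache)| Frozen { bucket, cache })
        .collect()
}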
@@ -251,7 +265,7 @@ impl<'extractor> SpillingCaches<'extractor> {
         buckets: usize,
         key: &[u8],
         n: u32,
-    ) -> grenad::Result<(), crate::Error> {
+    ) -> Result<()> {
         let hash = compute_bytes_hash(hasher, key);
         let bucket = compute_bucket_from_hash(buckets, hash);
         match self.caches[bucket].raw_entry_mut().from_hash(hash, |&k| k == key) {
@@ -278,7 +292,7 @@ impl<'extractor> SpillingCaches<'extractor> {
         buckets: usize,
         key: &[u8],
         n: u32,
-    ) -> grenad::Result<(), crate::Error> {
+    ) -> Result<()> {
         let hash = compute_bytes_hash(hasher, key);
         let bucket = compute_bucket_from_hash(buckets, hash);
         match self.caches[bucket].raw_entry_mut().from_hash(hash, |&k| k == key) {
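
These two hunks only change signatures, from the spelled-out grenad::Result<(), crate::Error> to the crate-local Result<()> alias. Where an error actually originates in grenad, it then has to be converted at the boundary, which is what the .map_err(Into::into) in the last hunk does. A self-contained illustration of that pattern, with both error types hypothetical:

struct GrenadError;
struct CrateError;

impl From<GrenadError> for CrateError {
    fn from(_: GrenadError) -> Self {
        CrateError
    }
}

fn writes_with_grenad() -> Result<(), GrenadError> {
    Ok(())
}

// Returning the crate-local error type means grenad errors must be
// converted, either with `?` (which calls From) or with map_err.
fn insert() -> Result<(), CrateError> {
    writes_with_grenad().map_err(Into::into)
}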
@@ -319,9 +333,10 @@ fn spill_entry_to_disk(
     cbo_buffer: &mut Vec<u8>,
     key: &[u8],
     deladd: DelAddRoaringBitmap,
-) -> grenad::Result<(), crate::Error> {
+) -> Result<()> {
     deladd_buffer.clear();
     let mut value_writer = KvWriterDelAdd::new(deladd_buffer);
+
     match deladd {
         DelAddRoaringBitmap { del: Some(del), add: None } => {
             cbo_buffer.clear();
@@ -344,22 +359,34 @@ fn spill_entry_to_disk(
         }
         DelAddRoaringBitmap { del: None, add: None } => return Ok(()),
     }
 
     let bytes = value_writer.into_inner().unwrap();
-    spilled_entries.insert(key, bytes)
+    spilled_entries.insert(key, bytes).map_err(Into::into)
 }
 
 pub struct FrozenCache<'a, 'extractor> {
+    bucket: usize,
     cache: FrozenMap<'a, 'extractor, &'extractor [u8], DelAddRoaringBitmap, DefaultHashBuilder>,
     spilled: grenad::Reader<BufReader<File>>,
 }
 
+/// Merges the caches, which must all be associated with the same bucket.
+///
+/// # Panics
+///
+/// - If the bucket IDs in these frozen caches are not exactly the same.
 pub fn merge_caches<F>(frozen: Vec<FrozenCache>, mut iter: F) -> Result<()>
 where
     F: for<'a> FnMut(&'a [u8], DelAddRoaringBitmap) -> Result<()>,
 {
-    let (mut maps, spilled): (Vec<_>, Vec<_>) =
-        frozen.into_iter().map(|FrozenCache { cache, spilled }| (cache, spilled)).collect();
+    let mut current_bucket = None;
+    let (mut maps, spilled): (Vec<_>, Vec<_>) = frozen
+        .into_iter()
+        .map(|FrozenCache { bucket, cache, spilled }| {
+            assert_eq!(*current_bucket.get_or_insert(bucket), bucket);
+            (cache, spilled)
+        })
+        .collect();
 
     // First manage the spilled entries by looking into the HashMaps,
     // merge them and mark them as dummy.
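
The safety net named in the commit title is the assert_eq! above: the first frozen cache initializes current_bucket through Option::get_or_insert, and every following cache must carry the same bucket id or the merge panics, exactly as the new doc comment promises. The same idiom in isolation:

fn assert_same_bucket(buckets: &[usize]) {
    let mut current_bucket = None;
    for &bucket in buckets {
        // The first iteration stores `bucket`; later ones are compared to it.
        assert_eq!(*current_bucket.get_or_insert(bucket), bucket);
    }
}

fn main() {
    assert_same_bucket(&[2, 2, 2]); // all caches from bucket 2: fine
    // assert_same_bucket(&[2, 3]); // mixed buckets: panics, by design
}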