From 66b8cfd8c83038980abd74e851969cda08e07fed Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 21 Jun 2023 17:07:02 +0200 Subject: [PATCH] Introduce a way to store the HNSW on multiple LMDB entries --- milli/src/index.rs | 46 +++++++++++++++++--- milli/src/lib.rs | 1 + milli/src/readable_slices.rs | 84 ++++++++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 6 deletions(-) create mode 100644 milli/src/readable_slices.rs diff --git a/milli/src/index.rs b/milli/src/index.rs index dcfcc0730..8343515cf 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -22,6 +22,7 @@ use crate::heed_codec::facet::{ FieldIdCodec, OrderedF64Codec, }; use crate::heed_codec::{ScriptLanguageCodec, StrBEU16Codec, StrRefCodec}; +use crate::readable_slices::ReadableSlices; use crate::{ default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec, @@ -47,7 +48,10 @@ pub mod main_key { pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map"; pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids"; pub const GEO_RTREE_KEY: &str = "geo-rtree"; - pub const VECTOR_HNSW_KEY: &str = "vector-hnsw"; + /// The prefix of the key that is used to store the, potential big, HNSW structure. + /// It is concatenated with a big-endian encoded number (non-human readable). + /// e.g. vector-hnsw0x0032. + pub const VECTOR_HNSW_KEY_PREFIX: &str = "vector-hnsw"; pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids"; pub const NUMBER_FACETED_DOCUMENTS_IDS_PREFIX: &str = "number-faceted-documents-ids"; pub const PRIMARY_KEY_KEY: &str = "primary-key"; @@ -517,19 +521,49 @@ impl Index { /// Writes the provided `hnsw`. pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> { - self.main.put::<_, Str, SerdeBincode>(wtxn, main_key::VECTOR_HNSW_KEY, hnsw) + // We must delete all the chunks before we write the new HNSW chunks. + self.delete_vector_hnsw(wtxn)?; + + let chunk_size = 1024 * 1024 * (1024 + 512); // 1.5 GiB + let bytes = bincode::serialize(hnsw).map_err(|_| heed::Error::Encoding)?; + for (i, chunk) in bytes.chunks(chunk_size).enumerate() { + let i = i as u32; + let mut key = main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes().to_vec(); + key.extend_from_slice(&i.to_be_bytes()); + self.main.put::<_, ByteSlice, ByteSlice>(wtxn, &key, chunk)?; + } + Ok(()) } /// Delete the `hnsw`. pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result { - self.main.delete::<_, Str>(wtxn, main_key::VECTOR_HNSW_KEY) + let mut iter = self.main.prefix_iter_mut::<_, ByteSlice, DecodeIgnore>( + wtxn, + main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes(), + )?; + let mut deleted = false; + while let Some(_) = iter.next().transpose()? { + // We do not keep a reference to the key or the value. + unsafe { deleted |= iter.del_current()? }; + } + Ok(deleted) } /// Returns the `hnsw`. pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result> { - match self.main.get::<_, Str, SerdeBincode>(rtxn, main_key::VECTOR_HNSW_KEY)? { - Some(hnsw) => Ok(Some(hnsw)), - None => Ok(None), + let mut slices = Vec::new(); + for result in + self.main.prefix_iter::<_, Str, ByteSlice>(rtxn, main_key::VECTOR_HNSW_KEY_PREFIX)? + { + let (_, slice) = result?; + slices.push(slice); + } + + if slices.is_empty() { + Ok(None) + } else { + let readable_slices: ReadableSlices<_> = slices.into_iter().collect(); + Ok(Some(bincode::deserialize_from(readable_slices).map_err(|_| heed::Error::Decoding)?)) } } diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 63cf6f397..626c30ab0 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -18,6 +18,7 @@ mod fields_ids_map; pub mod heed_codec; pub mod index; pub mod proximity; +mod readable_slices; pub mod score_details; mod search; pub mod update; diff --git a/milli/src/readable_slices.rs b/milli/src/readable_slices.rs new file mode 100644 index 000000000..9ba6c1ba1 --- /dev/null +++ b/milli/src/readable_slices.rs @@ -0,0 +1,84 @@ +use std::io::{self, Read}; +use std::iter::FromIterator; + +pub struct ReadableSlices { + inner: Vec, + pos: u64, +} + +impl FromIterator for ReadableSlices { + fn from_iter>(iter: T) -> Self { + ReadableSlices { inner: iter.into_iter().collect(), pos: 0 } + } +} + +impl> Read for ReadableSlices { + fn read(&mut self, mut buf: &mut [u8]) -> io::Result { + let original_buf_len = buf.len(); + + // We explore the list of slices to find the one where we must start reading. + let mut pos = self.pos; + let index = match self + .inner + .iter() + .map(|s| s.as_ref().len() as u64) + .position(|size| pos.checked_sub(size).map(|p| pos = p).is_none()) + { + Some(index) => index, + None => return Ok(0), + }; + + let mut inner_pos = pos as usize; + for slice in &self.inner[index..] { + let slice = &slice.as_ref()[inner_pos..]; + + if buf.len() > slice.len() { + // We must exhaust the current slice and go to the next one there is not enough here. + buf[..slice.len()].copy_from_slice(slice); + buf = &mut buf[slice.len()..]; + inner_pos = 0; + } else { + // There is enough in this slice to fill the remaining bytes of the buffer. + // Let's break just after filling it. + buf.copy_from_slice(&slice[..buf.len()]); + buf = &mut []; + break; + } + } + + let written = original_buf_len - buf.len(); + self.pos += written as u64; + Ok(written) + } +} + +#[cfg(test)] +mod test { + use super::ReadableSlices; + use std::io::Read; + + #[test] + fn basic() { + let data: Vec<_> = (0..100).collect(); + let splits: Vec<_> = data.chunks(3).collect(); + let mut rdslices: ReadableSlices<_> = splits.into_iter().collect(); + + let mut output = Vec::new(); + let length = rdslices.read_to_end(&mut output).unwrap(); + assert_eq!(length, data.len()); + assert_eq!(output, data); + } + + #[test] + fn small_reads() { + let data: Vec<_> = (0..u8::MAX).collect(); + let splits: Vec<_> = data.chunks(27).collect(); + let mut rdslices: ReadableSlices<_> = splits.into_iter().collect(); + + let buffer = &mut [0; 45]; + let length = rdslices.read(buffer).unwrap(); + let expected: Vec<_> = (0..buffer.len() as u8).collect(); + assert_eq!(length, buffer.len()); + assert_eq!(buffer, &expected[..]); + } +}