From 29ab54b259f94d2403689052e1ba68ffcedceafa Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 25 Jul 2023 12:36:01 +0200 Subject: [PATCH] Replace the hnsw crate by the instant-distance one --- Cargo.lock | 71 +++++++------------ milli/Cargo.toml | 3 +- milli/src/distance.rs | 40 +++++++---- milli/src/index.rs | 5 +- milli/src/search/new/mod.rs | 49 ++++++------- milli/src/update/delete_documents.rs | 25 ++++--- .../src/update/index_documents/typed_chunk.rs | 54 ++++++++------ 7 files changed, 127 insertions(+), 120 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 370841384..bfc85dda8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1197,12 +1197,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "doc-comment" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" - [[package]] name = "dump" version = "1.3.0" @@ -1707,15 +1701,6 @@ dependencies = [ "byteorder", ] -[[package]] -name = "hashbrown" -version = "0.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" -dependencies = [ - "ahash 0.7.6", -] - [[package]] name = "hashbrown" version = "0.12.3" @@ -1814,22 +1799,6 @@ dependencies = [ "digest", ] -[[package]] -name = "hnsw" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b9740ebf8769ec4ad6762cc951ba18f39bba6dfbc2fbbe46285f7539af79752" -dependencies = [ - "ahash 0.7.6", - "hashbrown 0.11.2", - "libm", - "num-traits", - "rand_core", - "serde", - "smallvec", - "space", -] - [[package]] name = "http" version = "0.2.9" @@ -2008,6 +1977,21 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "instant-distance" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c619cdaa30bb84088963968bee12a45ea5fbbf355f2c021bcd15589f5ca494a" +dependencies = [ + "num_cpus", + "ordered-float", + "parking_lot", + "rand", + "rayon", + "serde", + "serde-big-array", +] + [[package]] name = "io-lifetimes" version = "1.0.11" @@ -2701,9 +2685,9 @@ dependencies = [ "geoutils", "grenad", "heed", - "hnsw", "indexmap 1.9.3", "insta", + "instant-distance", "itertools", "json-depth-checker", "levenshtein_automata", @@ -2727,7 +2711,6 @@ dependencies = [ "smallstr", "smallvec", "smartstring", - "space", "tempfile", "thiserror", "time", @@ -3607,6 +3590,15 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-big-array" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f" +dependencies = [ + "serde", +] + [[package]] name = "serde-cs" version = "0.2.4" @@ -3756,9 +3748,6 @@ name = "smallvec" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" -dependencies = [ - "serde", -] [[package]] name = "smartstring" @@ -3781,16 +3770,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "space" -version = "0.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5ab9701ae895386d13db622abf411989deff7109b13b46b6173bb4ce5c1d123" -dependencies = [ - "doc-comment", - "num-traits", -] - [[package]] name = "spin" version = "0.5.2" diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 2689975cd..cbe0794fe 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -33,8 +33,8 @@ heed = { git = "https://github.com/meilisearch/heed", tag = "v0.12.6", default-f "lmdb", "sync-read-txn", ] } -hnsw = { version = "0.11.0", features = ["serde1"] } indexmap = { version = "1.9.3", features = ["serde"] } +instant-distance = { version = "0.6.1", features = ["with-serde"] } json-depth-checker = { path = "../json-depth-checker" } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } memmap2 = "0.5.10" @@ -48,7 +48,6 @@ rstar = { version = "0.10.0", features = ["serde"] } serde = { version = "1.0.160", features = ["derive"] } serde_json = { version = "1.0.95", features = ["preserve_order"] } slice-group-by = "0.3.0" -space = "0.17.0" smallstr = { version = "0.3.0", features = ["serde"] } smallvec = "1.10.0" smartstring = "1.0.1" diff --git a/milli/src/distance.rs b/milli/src/distance.rs index c838e4bd4..e9e17e647 100644 --- a/milli/src/distance.rs +++ b/milli/src/distance.rs @@ -1,20 +1,36 @@ +use std::ops; + +use instant_distance::Point; use serde::{Deserialize, Serialize}; -use space::Metric; -#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)] -pub struct DotProduct; +use crate::normalize_vector; -impl Metric> for DotProduct { - type Unit = u32; +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +pub struct NDotProductPoint(Vec); - // Following . - // - // Here is a playground that validate the ordering of the bit representation of floats in range 0.0..=1.0: - // - fn distance(&self, a: &Vec, b: &Vec) -> Self::Unit { - let dist = 1.0 - dot_product_similarity(a, b); +impl NDotProductPoint { + pub fn new(point: Vec) -> Self { + NDotProductPoint(normalize_vector(point)) + } + + pub fn into_inner(self) -> Vec { + self.0 + } +} + +impl ops::Deref for NDotProductPoint { + type Target = [f32]; + + fn deref(&self) -> &Self::Target { + self.0.as_slice() + } +} + +impl Point for NDotProductPoint { + fn distance(&self, other: &Self) -> f32 { + let dist = 1.0 - dot_product_similarity(&self.0, &other.0); debug_assert!(!dist.is_nan()); - dist.to_bits() + dist } } diff --git a/milli/src/index.rs b/milli/src/index.rs index 392ed1705..847ab0088 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -8,12 +8,11 @@ use charabia::{Language, Script}; use heed::flags::Flags; use heed::types::*; use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn}; -use rand_pcg::Pcg32; use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; -use crate::distance::DotProduct; +use crate::distance::NDotProductPoint; use crate::error::{InternalError, UserError}; use crate::facet::FacetType; use crate::fields_ids_map::FieldsIdsMap; @@ -31,7 +30,7 @@ use crate::{ }; /// The HNSW data-structure that we serialize, fill and search in. -pub type Hnsw = hnsw::Hnsw, Pcg32, 12, 24>; +pub type Hnsw = instant_distance::Hnsw; pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index c27e02514..ad15d8e91 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -28,7 +28,7 @@ use db_cache::DatabaseCache; use exact_attribute::ExactAttribute; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use heed::RoTxn; -use hnsw::Searcher; +use instant_distance::Search; use interner::{DedupInterner, Interner}; pub use logger::visual::VisualSearchLogger; pub use logger::{DefaultSearchLogger, SearchLogger}; @@ -40,19 +40,18 @@ use ranking_rules::{ use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache}; use roaring::RoaringBitmap; use sort::Sort; -use space::Neighbor; use self::distinct::facet_string_values; use self::geo_sort::GeoSort; pub use self::geo_sort::Strategy as GeoSortStrategy; use self::graph_based_ranking_rule::Words; use self::interner::Interned; +use crate::distance::NDotProductPoint; use crate::error::FieldIdMapMissingEntry; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; use crate::{ - normalize_vector, AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, - UserError, BEU32, + AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32, }; /// A structure used throughout the execution of a search query. @@ -445,29 +444,31 @@ pub fn execute_search( check_sort_criteria(ctx, sort_criteria.as_ref())?; if let Some(vector) = vector { - let mut searcher = Searcher::new(); - let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default(); - let ef = hnsw.len().min(100); - let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef]; - let vector = normalize_vector(vector.clone()); - let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]); + let mut search = Search::default(); + let docids = match ctx.index.vector_hnsw(ctx.txn)? { + Some(hnsw) => { + let vector = NDotProductPoint::new(vector.clone()); + let neighbors = hnsw.search(&vector, &mut search); - let mut docids = Vec::new(); - let mut uniq_docids = RoaringBitmap::new(); - for Neighbor { index, distance: _ } in neighbors.iter() { - let index = BEU32::new(*index as u32); - let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get(); - if universe.contains(docid) && uniq_docids.insert(docid) { - docids.push(docid); - if docids.len() == (from + length) { - break; + let mut docids = Vec::new(); + let mut uniq_docids = RoaringBitmap::new(); + for instant_distance::Item { distance: _, pid, point: _ } in neighbors { + let index = BEU32::new(pid.into_inner()); + let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get(); + if universe.contains(docid) && uniq_docids.insert(docid) { + docids.push(docid); + if docids.len() == (from + length) { + break; + } + } } - } - } - // return the nearest documents that are also part of the candidates - // along with a dummy list of scores that are useless in this context. - let docids: Vec<_> = docids.into_iter().skip(from).take(length).collect(); + // return the nearest documents that are also part of the candidates + // along with a dummy list of scores that are useless in this context. + docids.into_iter().skip(from).take(length).collect() + } + None => Vec::new(), + }; return Ok(PartialSearchResult { candidates: universe, diff --git a/milli/src/update/delete_documents.rs b/milli/src/update/delete_documents.rs index c9124e591..906d6922f 100644 --- a/milli/src/update/delete_documents.rs +++ b/milli/src/update/delete_documents.rs @@ -4,10 +4,9 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use fst::IntoStreamer; use heed::types::{ByteSlice, DecodeIgnore, Str, UnalignedSlice}; use heed::{BytesDecode, BytesEncode, Database, RwIter}; -use hnsw::Searcher; +use instant_distance::PointId; use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; -use space::KnnPoints; use time::OffsetDateTime; use super::facet::delete::FacetsDelete; @@ -436,24 +435,24 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> { // An ugly and slow way to remove the vectors from the HNSW // It basically reconstructs the HNSW from scratch without editing the current one. - let current_hnsw = self.index.vector_hnsw(self.wtxn)?.unwrap_or_default(); - if !current_hnsw.is_empty() { - let mut new_hnsw = Hnsw::default(); - let mut searcher = Searcher::new(); - let mut new_vector_id_docids = Vec::new(); - + if let Some(current_hnsw) = self.index.vector_hnsw(self.wtxn)? { + let mut points = Vec::new(); + let mut docids = Vec::new(); for result in vector_id_docid.iter(self.wtxn)? { let (vector_id, docid) = result?; if !self.to_delete_docids.contains(docid.get()) { - let vector = current_hnsw.get_point(vector_id.get() as usize).clone(); - let vector_id = new_hnsw.insert(vector, &mut searcher); - new_vector_id_docids.push((vector_id as u32, docid)); + let pid = PointId::from(vector_id.get()); + let vector = current_hnsw[pid].clone(); + points.push(vector); + docids.push(docid); } } + let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points); + vector_id_docid.clear(self.wtxn)?; - for (vector_id, docid) in new_vector_id_docids { - vector_id_docid.put(self.wtxn, &BEU32::new(vector_id), &docid)?; + for (pid, docid) in pids.into_iter().zip(docids) { + vector_id_docid.put(self.wtxn, &BEU32::new(pid.into_inner()), &docid)?; } self.index.put_vector_hnsw(self.wtxn, &new_hnsw)?; } diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 3f197fbd1..921ce4ecd 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -9,22 +9,19 @@ use charabia::{Language, Script}; use grenad::MergerBuilder; use heed::types::ByteSlice; use heed::RwTxn; -use hnsw::Searcher; use roaring::RoaringBitmap; -use space::KnnPoints; use super::helpers::{ self, merge_ignore_values, serialize_roaring_bitmap, valid_lmdb_key, CursorClonableMmap, }; use super::{ClonableMmap, MergeFn}; +use crate::distance::NDotProductPoint; use crate::error::UserError; use crate::facet::FacetType; +use crate::index::Hnsw; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; -use crate::{ - lat_lng_to_xyz, normalize_vector, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, - BEU32, -}; +use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32}; pub(crate) enum TypedChunk { FieldIdDocidFacetStrings(grenad::Reader), @@ -230,17 +227,20 @@ pub(crate) fn write_typed_chunk_into_index( index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; } TypedChunk::VectorPoints(vector_points) => { - let mut hnsw = index.vector_hnsw(wtxn)?.unwrap_or_default(); - let mut searcher = Searcher::new(); - - let mut expected_dimensions = match index.vector_id_docid.iter(wtxn)?.next() { - Some(result) => { - let (vector_id, _) = result?; - Some(hnsw.get_point(vector_id.get() as usize).len()) - } - None => None, + let (pids, mut points): (Vec<_>, Vec<_>) = match index.vector_hnsw(wtxn)? { + Some(hnsw) => hnsw.iter().map(|(pid, point)| (pid, point.clone())).unzip(), + None => Default::default(), }; + // Convert the PointIds into DocumentIds + let mut docids = Vec::new(); + for pid in pids { + let docid = + index.vector_id_docid.get(wtxn, &BEU32::new(pid.into_inner()))?.unwrap(); + docids.push(docid.get()); + } + + let mut expected_dimensions = points.get(0).map(|p| p.len()); let mut cursor = vector_points.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? { // convert the key back to a u32 (4 bytes) @@ -256,12 +256,26 @@ pub(crate) fn write_typed_chunk_into_index( return Err(UserError::InvalidVectorDimensions { expected, found })?; } - let vector = normalize_vector(vector); - let vector_id = hnsw.insert(vector, &mut searcher) as u32; - index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?; + points.push(NDotProductPoint::new(vector)); + docids.push(docid); } - log::debug!("There are {} entries in the HNSW so far", hnsw.len()); - index.put_vector_hnsw(wtxn, &hnsw)?; + + assert_eq!(docids.len(), points.len()); + + let hnsw_length = points.len(); + let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points); + + index.vector_id_docid.clear(wtxn)?; + for (docid, pid) in docids.into_iter().zip(pids) { + index.vector_id_docid.put( + wtxn, + &BEU32::new(pid.into_inner()), + &BEU32::new(docid), + )?; + } + + log::debug!("There are {} entries in the HNSW so far", hnsw_length); + index.put_vector_hnsw(wtxn, &new_hnsw)?; } TypedChunk::ScriptLanguageDocids(hash_pair) => { let mut buffer = Vec::new();