3948: Fix hnsw internal panic by using another library r=ManyTheFish a=Kerollmops

This pull request fixes #3923. The issue concerns the `hnsw` crate panicking due to a wrong call to the `[T]::copy_from_slice` function.

I decided to switch the library to `instant-distance`, which is maintained [by someone of trust](https://lib.rs/~djc), who maintains a lot of very important crates.

- [x] Make Clippy happy with the first commit.
- [x] Reproduce the #3923 bug without this patch
- [x] Check if the bug disappeared with this PR.
- [x] Test with [the Algolia e-commerce dataset](https://www.notion.so/meilisearch/Algolia-Ecommerce-c5fa3b5f23a7485295df7e87306d5859).

Co-authored-by: Kerollmops <clement@meilisearch.com>
This commit is contained in:
meili-bors[bot] 2023-07-25 13:28:25 +00:00 committed by GitHub
commit 9e3e69373e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 222 additions and 215 deletions

71
Cargo.lock generated
View File

@ -1197,12 +1197,6 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "doc-comment"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
[[package]] [[package]]
name = "dump" name = "dump"
version = "1.3.0" version = "1.3.0"
@ -1707,15 +1701,6 @@ dependencies = [
"byteorder", "byteorder",
] ]
[[package]]
name = "hashbrown"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
dependencies = [
"ahash 0.7.6",
]
[[package]] [[package]]
name = "hashbrown" name = "hashbrown"
version = "0.12.3" version = "0.12.3"
@ -1814,22 +1799,6 @@ dependencies = [
"digest", "digest",
] ]
[[package]]
name = "hnsw"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b9740ebf8769ec4ad6762cc951ba18f39bba6dfbc2fbbe46285f7539af79752"
dependencies = [
"ahash 0.7.6",
"hashbrown 0.11.2",
"libm",
"num-traits",
"rand_core",
"serde",
"smallvec",
"space",
]
[[package]] [[package]]
name = "http" name = "http"
version = "0.2.9" version = "0.2.9"
@ -2008,6 +1977,21 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "instant-distance"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c619cdaa30bb84088963968bee12a45ea5fbbf355f2c021bcd15589f5ca494a"
dependencies = [
"num_cpus",
"ordered-float",
"parking_lot",
"rand",
"rayon",
"serde",
"serde-big-array",
]
[[package]] [[package]]
name = "io-lifetimes" name = "io-lifetimes"
version = "1.0.11" version = "1.0.11"
@ -2701,9 +2685,9 @@ dependencies = [
"geoutils", "geoutils",
"grenad", "grenad",
"heed", "heed",
"hnsw",
"indexmap 1.9.3", "indexmap 1.9.3",
"insta", "insta",
"instant-distance",
"itertools", "itertools",
"json-depth-checker", "json-depth-checker",
"levenshtein_automata", "levenshtein_automata",
@ -2727,7 +2711,6 @@ dependencies = [
"smallstr", "smallstr",
"smallvec", "smallvec",
"smartstring", "smartstring",
"space",
"tempfile", "tempfile",
"thiserror", "thiserror",
"time", "time",
@ -3607,6 +3590,15 @@ dependencies = [
"serde_derive", "serde_derive",
] ]
[[package]]
name = "serde-big-array"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "serde-cs" name = "serde-cs"
version = "0.2.4" version = "0.2.4"
@ -3756,9 +3748,6 @@ name = "smallvec"
version = "1.10.0" version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
dependencies = [
"serde",
]
[[package]] [[package]]
name = "smartstring" name = "smartstring"
@ -3781,16 +3770,6 @@ dependencies = [
"winapi", "winapi",
] ]
[[package]]
name = "space"
version = "0.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c5ab9701ae895386d13db622abf411989deff7109b13b46b6173bb4ce5c1d123"
dependencies = [
"doc-comment",
"num-traits",
]
[[package]] [[package]]
name = "spin" name = "spin"
version = "0.5.2" version = "0.5.2"

View File

@ -472,6 +472,77 @@ pub fn parse_filter(input: Span) -> IResult<FilterCondition> {
terminated(|input| parse_expression(input, 0), eof)(input) terminated(|input| parse_expression(input, 0), eof)(input)
} }
impl<'a> std::fmt::Display for FilterCondition<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FilterCondition::Not(filter) => {
write!(f, "NOT ({filter})")
}
FilterCondition::Condition { fid, op } => {
write!(f, "{fid} {op}")
}
FilterCondition::In { fid, els } => {
write!(f, "{fid} IN[")?;
for el in els {
write!(f, "{el}, ")?;
}
write!(f, "]")
}
FilterCondition::Or(els) => {
write!(f, "OR[")?;
for el in els {
write!(f, "{el}, ")?;
}
write!(f, "]")
}
FilterCondition::And(els) => {
write!(f, "AND[")?;
for el in els {
write!(f, "{el}, ")?;
}
write!(f, "]")
}
FilterCondition::GeoLowerThan { point, radius } => {
write!(f, "_geoRadius({}, {}, {})", point[0], point[1], radius)
}
FilterCondition::GeoBoundingBox {
top_right_point: top_left_point,
bottom_left_point: bottom_right_point,
} => {
write!(
f,
"_geoBoundingBox([{}, {}], [{}, {}])",
top_left_point[0],
top_left_point[1],
bottom_right_point[0],
bottom_right_point[1]
)
}
}
}
}
impl<'a> std::fmt::Display for Condition<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Condition::GreaterThan(token) => write!(f, "> {token}"),
Condition::GreaterThanOrEqual(token) => write!(f, ">= {token}"),
Condition::Equal(token) => write!(f, "= {token}"),
Condition::NotEqual(token) => write!(f, "!= {token}"),
Condition::Null => write!(f, "IS NULL"),
Condition::Empty => write!(f, "IS EMPTY"),
Condition::Exists => write!(f, "EXISTS"),
Condition::LowerThan(token) => write!(f, "< {token}"),
Condition::LowerThanOrEqual(token) => write!(f, "<= {token}"),
Condition::Between { from, to } => write!(f, "{from} TO {to}"),
}
}
}
impl<'a> std::fmt::Display for Token<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{{{}}}", self.value())
}
}
#[cfg(test)] #[cfg(test)]
pub mod tests { pub mod tests {
use super::*; use super::*;
@ -852,74 +923,3 @@ pub mod tests {
assert_eq!(token.value(), s); assert_eq!(token.value(), s);
} }
} }
impl<'a> std::fmt::Display for FilterCondition<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FilterCondition::Not(filter) => {
write!(f, "NOT ({filter})")
}
FilterCondition::Condition { fid, op } => {
write!(f, "{fid} {op}")
}
FilterCondition::In { fid, els } => {
write!(f, "{fid} IN[")?;
for el in els {
write!(f, "{el}, ")?;
}
write!(f, "]")
}
FilterCondition::Or(els) => {
write!(f, "OR[")?;
for el in els {
write!(f, "{el}, ")?;
}
write!(f, "]")
}
FilterCondition::And(els) => {
write!(f, "AND[")?;
for el in els {
write!(f, "{el}, ")?;
}
write!(f, "]")
}
FilterCondition::GeoLowerThan { point, radius } => {
write!(f, "_geoRadius({}, {}, {})", point[0], point[1], radius)
}
FilterCondition::GeoBoundingBox {
top_right_point: top_left_point,
bottom_left_point: bottom_right_point,
} => {
write!(
f,
"_geoBoundingBox([{}, {}], [{}, {}])",
top_left_point[0],
top_left_point[1],
bottom_right_point[0],
bottom_right_point[1]
)
}
}
}
}
impl<'a> std::fmt::Display for Condition<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Condition::GreaterThan(token) => write!(f, "> {token}"),
Condition::GreaterThanOrEqual(token) => write!(f, ">= {token}"),
Condition::Equal(token) => write!(f, "= {token}"),
Condition::NotEqual(token) => write!(f, "!= {token}"),
Condition::Null => write!(f, "IS NULL"),
Condition::Empty => write!(f, "IS EMPTY"),
Condition::Exists => write!(f, "EXISTS"),
Condition::LowerThan(token) => write!(f, "< {token}"),
Condition::LowerThanOrEqual(token) => write!(f, "<= {token}"),
Condition::Between { from, to } => write!(f, "{from} TO {to}"),
}
}
}
impl<'a> std::fmt::Display for Token<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{{{}}}", self.value())
}
}

View File

@ -199,6 +199,30 @@ macro_rules! snapshot {
}; };
} }
/// Create a string from the value by serializing it as Json, optionally
/// redacting some parts of it.
///
/// The second argument to the macro can be an object expression for redaction.
/// It's in the form { selector => replacement }. For more information about redactions
/// refer to the redactions feature in the `insta` guide.
#[macro_export]
macro_rules! json_string {
($value:expr, {$($k:expr => $v:expr),*$(,)?}) => {
{
let (_, snap) = meili_snap::insta::_prepare_snapshot_for_redaction!($value, {$($k => $v),*}, Json, File);
snap
}
};
($value:expr) => {{
let value = meili_snap::insta::_macro_support::serialize_value(
&$value,
meili_snap::insta::_macro_support::SerializationFormat::Json,
meili_snap::insta::_macro_support::SnapshotLocation::File
);
value
}};
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate as meili_snap; use crate as meili_snap;
@ -250,27 +274,3 @@ mod tests {
} }
} }
} }
/// Create a string from the value by serializing it as Json, optionally
/// redacting some parts of it.
///
/// The second argument to the macro can be an object expression for redaction.
/// It's in the form { selector => replacement }. For more information about redactions
/// refer to the redactions feature in the `insta` guide.
#[macro_export]
macro_rules! json_string {
($value:expr, {$($k:expr => $v:expr),*$(,)?}) => {
{
let (_, snap) = meili_snap::insta::_prepare_snapshot_for_redaction!($value, {$($k => $v),*}, Json, File);
snap
}
};
($value:expr) => {{
let value = meili_snap::insta::_macro_support::serialize_value(
&$value,
meili_snap::insta::_macro_support::SerializationFormat::Json,
meili_snap::insta::_macro_support::SnapshotLocation::File
);
value
}};
}

View File

@ -33,8 +33,8 @@ heed = { git = "https://github.com/meilisearch/heed", tag = "v0.12.6", default-f
"lmdb", "lmdb",
"sync-read-txn", "sync-read-txn",
] } ] }
hnsw = { version = "0.11.0", features = ["serde1"] }
indexmap = { version = "1.9.3", features = ["serde"] } indexmap = { version = "1.9.3", features = ["serde"] }
instant-distance = { version = "0.6.1", features = ["with-serde"] }
json-depth-checker = { path = "../json-depth-checker" } json-depth-checker = { path = "../json-depth-checker" }
levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] }
memmap2 = "0.5.10" memmap2 = "0.5.10"
@ -48,7 +48,6 @@ rstar = { version = "0.10.0", features = ["serde"] }
serde = { version = "1.0.160", features = ["derive"] } serde = { version = "1.0.160", features = ["derive"] }
serde_json = { version = "1.0.95", features = ["preserve_order"] } serde_json = { version = "1.0.95", features = ["preserve_order"] }
slice-group-by = "0.3.0" slice-group-by = "0.3.0"
space = "0.17.0"
smallstr = { version = "0.3.0", features = ["serde"] } smallstr = { version = "0.3.0", features = ["serde"] }
smallvec = "1.10.0" smallvec = "1.10.0"
smartstring = "1.0.1" smartstring = "1.0.1"

View File

@ -1,20 +1,36 @@
use std::ops;
use instant_distance::Point;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use space::Metric;
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)] use crate::normalize_vector;
pub struct DotProduct;
impl Metric<Vec<f32>> for DotProduct { #[derive(Debug, Default, Clone, Serialize, Deserialize)]
type Unit = u32; pub struct NDotProductPoint(Vec<f32>);
// Following <https://docs.rs/space/0.17.0/space/trait.Metric.html>. impl NDotProductPoint {
// pub fn new(point: Vec<f32>) -> Self {
// Here is a playground that validate the ordering of the bit representation of floats in range 0.0..=1.0: NDotProductPoint(normalize_vector(point))
// <https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=6c59e31a3cc5036b32edf51e8937b56e> }
fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
let dist = 1.0 - dot_product_similarity(a, b); pub fn into_inner(self) -> Vec<f32> {
self.0
}
}
impl ops::Deref for NDotProductPoint {
type Target = [f32];
fn deref(&self) -> &Self::Target {
self.0.as_slice()
}
}
impl Point for NDotProductPoint {
fn distance(&self, other: &Self) -> f32 {
let dist = 1.0 - dot_product_similarity(&self.0, &other.0);
debug_assert!(!dist.is_nan()); debug_assert!(!dist.is_nan());
dist.to_bits() dist
} }
} }

View File

@ -8,12 +8,11 @@ use charabia::{Language, Script};
use heed::flags::Flags; use heed::flags::Flags;
use heed::types::*; use heed::types::*;
use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn}; use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn};
use rand_pcg::Pcg32;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use rstar::RTree; use rstar::RTree;
use time::OffsetDateTime; use time::OffsetDateTime;
use crate::distance::DotProduct; use crate::distance::NDotProductPoint;
use crate::error::{InternalError, UserError}; use crate::error::{InternalError, UserError};
use crate::facet::FacetType; use crate::facet::FacetType;
use crate::fields_ids_map::FieldsIdsMap; use crate::fields_ids_map::FieldsIdsMap;
@ -31,7 +30,7 @@ use crate::{
}; };
/// The HNSW data-structure that we serialize, fill and search in. /// The HNSW data-structure that we serialize, fill and search in.
pub type Hnsw = hnsw::Hnsw<DotProduct, Vec<f32>, Pcg32, 12, 24>; pub type Hnsw = instant_distance::Hnsw<NDotProductPoint>;
pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;

View File

@ -28,7 +28,7 @@ use db_cache::DatabaseCache;
use exact_attribute::ExactAttribute; use exact_attribute::ExactAttribute;
use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo};
use heed::RoTxn; use heed::RoTxn;
use hnsw::Searcher; use instant_distance::Search;
use interner::{DedupInterner, Interner}; use interner::{DedupInterner, Interner};
pub use logger::visual::VisualSearchLogger; pub use logger::visual::VisualSearchLogger;
pub use logger::{DefaultSearchLogger, SearchLogger}; pub use logger::{DefaultSearchLogger, SearchLogger};
@ -40,19 +40,18 @@ use ranking_rules::{
use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache}; use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache};
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use sort::Sort; use sort::Sort;
use space::Neighbor;
use self::distinct::facet_string_values; use self::distinct::facet_string_values;
use self::geo_sort::GeoSort; use self::geo_sort::GeoSort;
pub use self::geo_sort::Strategy as GeoSortStrategy; pub use self::geo_sort::Strategy as GeoSortStrategy;
use self::graph_based_ranking_rule::Words; use self::graph_based_ranking_rule::Words;
use self::interner::Interned; use self::interner::Interned;
use crate::distance::NDotProductPoint;
use crate::error::FieldIdMapMissingEntry; use crate::error::FieldIdMapMissingEntry;
use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::score_details::{ScoreDetails, ScoringStrategy};
use crate::search::new::distinct::apply_distinct_rule; use crate::search::new::distinct::apply_distinct_rule;
use crate::{ use crate::{
normalize_vector, AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32,
UserError, BEU32,
}; };
/// A structure used throughout the execution of a search query. /// A structure used throughout the execution of a search query.
@ -445,17 +444,16 @@ pub fn execute_search(
check_sort_criteria(ctx, sort_criteria.as_ref())?; check_sort_criteria(ctx, sort_criteria.as_ref())?;
if let Some(vector) = vector { if let Some(vector) = vector {
let mut searcher = Searcher::new(); let mut search = Search::default();
let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default(); let docids = match ctx.index.vector_hnsw(ctx.txn)? {
let ef = hnsw.len().min(100); Some(hnsw) => {
let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef]; let vector = NDotProductPoint::new(vector.clone());
let vector = normalize_vector(vector.clone()); let neighbors = hnsw.search(&vector, &mut search);
let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]);
let mut docids = Vec::new(); let mut docids = Vec::new();
let mut uniq_docids = RoaringBitmap::new(); let mut uniq_docids = RoaringBitmap::new();
for Neighbor { index, distance: _ } in neighbors.iter() { for instant_distance::Item { distance: _, pid, point: _ } in neighbors {
let index = BEU32::new(*index as u32); let index = BEU32::new(pid.into_inner());
let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get(); let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get();
if universe.contains(docid) && uniq_docids.insert(docid) { if universe.contains(docid) && uniq_docids.insert(docid) {
docids.push(docid); docids.push(docid);
@ -467,7 +465,10 @@ pub fn execute_search(
// return the nearest documents that are also part of the candidates // return the nearest documents that are also part of the candidates
// along with a dummy list of scores that are useless in this context. // along with a dummy list of scores that are useless in this context.
let docids: Vec<_> = docids.into_iter().skip(from).take(length).collect(); docids.into_iter().skip(from).take(length).collect()
}
None => Vec::new(),
};
return Ok(PartialSearchResult { return Ok(PartialSearchResult {
candidates: universe, candidates: universe,

View File

@ -4,10 +4,9 @@ use std::collections::{BTreeSet, HashMap, HashSet};
use fst::IntoStreamer; use fst::IntoStreamer;
use heed::types::{ByteSlice, DecodeIgnore, Str, UnalignedSlice}; use heed::types::{ByteSlice, DecodeIgnore, Str, UnalignedSlice};
use heed::{BytesDecode, BytesEncode, Database, RwIter}; use heed::{BytesDecode, BytesEncode, Database, RwIter};
use hnsw::Searcher; use instant_distance::PointId;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use space::KnnPoints;
use time::OffsetDateTime; use time::OffsetDateTime;
use super::facet::delete::FacetsDelete; use super::facet::delete::FacetsDelete;
@ -436,24 +435,24 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> {
// An ugly and slow way to remove the vectors from the HNSW // An ugly and slow way to remove the vectors from the HNSW
// It basically reconstructs the HNSW from scratch without editing the current one. // It basically reconstructs the HNSW from scratch without editing the current one.
let current_hnsw = self.index.vector_hnsw(self.wtxn)?.unwrap_or_default(); if let Some(current_hnsw) = self.index.vector_hnsw(self.wtxn)? {
if !current_hnsw.is_empty() { let mut points = Vec::new();
let mut new_hnsw = Hnsw::default(); let mut docids = Vec::new();
let mut searcher = Searcher::new();
let mut new_vector_id_docids = Vec::new();
for result in vector_id_docid.iter(self.wtxn)? { for result in vector_id_docid.iter(self.wtxn)? {
let (vector_id, docid) = result?; let (vector_id, docid) = result?;
if !self.to_delete_docids.contains(docid.get()) { if !self.to_delete_docids.contains(docid.get()) {
let vector = current_hnsw.get_point(vector_id.get() as usize).clone(); let pid = PointId::from(vector_id.get());
let vector_id = new_hnsw.insert(vector, &mut searcher); let vector = current_hnsw[pid].clone();
new_vector_id_docids.push((vector_id as u32, docid)); points.push(vector);
docids.push(docid);
} }
} }
let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points);
vector_id_docid.clear(self.wtxn)?; vector_id_docid.clear(self.wtxn)?;
for (vector_id, docid) in new_vector_id_docids { for (pid, docid) in pids.into_iter().zip(docids) {
vector_id_docid.put(self.wtxn, &BEU32::new(vector_id), &docid)?; vector_id_docid.put(self.wtxn, &BEU32::new(pid.into_inner()), &docid)?;
} }
self.index.put_vector_hnsw(self.wtxn, &new_hnsw)?; self.index.put_vector_hnsw(self.wtxn, &new_hnsw)?;
} }

View File

@ -9,22 +9,19 @@ use charabia::{Language, Script};
use grenad::MergerBuilder; use grenad::MergerBuilder;
use heed::types::ByteSlice; use heed::types::ByteSlice;
use heed::RwTxn; use heed::RwTxn;
use hnsw::Searcher;
use roaring::RoaringBitmap; use roaring::RoaringBitmap;
use space::KnnPoints;
use super::helpers::{ use super::helpers::{
self, merge_ignore_values, serialize_roaring_bitmap, valid_lmdb_key, CursorClonableMmap, self, merge_ignore_values, serialize_roaring_bitmap, valid_lmdb_key, CursorClonableMmap,
}; };
use super::{ClonableMmap, MergeFn}; use super::{ClonableMmap, MergeFn};
use crate::distance::NDotProductPoint;
use crate::error::UserError; use crate::error::UserError;
use crate::facet::FacetType; use crate::facet::FacetType;
use crate::index::Hnsw;
use crate::update::facet::FacetsUpdate; use crate::update::facet::FacetsUpdate;
use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
use crate::{ use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32};
lat_lng_to_xyz, normalize_vector, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result,
BEU32,
};
pub(crate) enum TypedChunk { pub(crate) enum TypedChunk {
FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>), FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
@ -230,17 +227,20 @@ pub(crate) fn write_typed_chunk_into_index(
index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
} }
TypedChunk::VectorPoints(vector_points) => { TypedChunk::VectorPoints(vector_points) => {
let mut hnsw = index.vector_hnsw(wtxn)?.unwrap_or_default(); let (pids, mut points): (Vec<_>, Vec<_>) = match index.vector_hnsw(wtxn)? {
let mut searcher = Searcher::new(); Some(hnsw) => hnsw.iter().map(|(pid, point)| (pid, point.clone())).unzip(),
None => Default::default(),
let mut expected_dimensions = match index.vector_id_docid.iter(wtxn)?.next() {
Some(result) => {
let (vector_id, _) = result?;
Some(hnsw.get_point(vector_id.get() as usize).len())
}
None => None,
}; };
// Convert the PointIds into DocumentIds
let mut docids = Vec::new();
for pid in pids {
let docid =
index.vector_id_docid.get(wtxn, &BEU32::new(pid.into_inner()))?.unwrap();
docids.push(docid.get());
}
let mut expected_dimensions = points.get(0).map(|p| p.len());
let mut cursor = vector_points.into_cursor()?; let mut cursor = vector_points.into_cursor()?;
while let Some((key, value)) = cursor.move_on_next()? { while let Some((key, value)) = cursor.move_on_next()? {
// convert the key back to a u32 (4 bytes) // convert the key back to a u32 (4 bytes)
@ -256,12 +256,26 @@ pub(crate) fn write_typed_chunk_into_index(
return Err(UserError::InvalidVectorDimensions { expected, found })?; return Err(UserError::InvalidVectorDimensions { expected, found })?;
} }
let vector = normalize_vector(vector); points.push(NDotProductPoint::new(vector));
let vector_id = hnsw.insert(vector, &mut searcher) as u32; docids.push(docid);
index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?;
} }
log::debug!("There are {} entries in the HNSW so far", hnsw.len());
index.put_vector_hnsw(wtxn, &hnsw)?; assert_eq!(docids.len(), points.len());
let hnsw_length = points.len();
let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points);
index.vector_id_docid.clear(wtxn)?;
for (docid, pid) in docids.into_iter().zip(pids) {
index.vector_id_docid.put(
wtxn,
&BEU32::new(pid.into_inner()),
&BEU32::new(docid),
)?;
}
log::debug!("There are {} entries in the HNSW so far", hnsw_length);
index.put_vector_hnsw(wtxn, &new_hnsw)?;
} }
TypedChunk::ScriptLanguageDocids(hash_pair) => { TypedChunk::ScriptLanguageDocids(hash_pair) => {
let mut buffer = Vec::new(); let mut buffer = Vec::new();