From 34349faeae779a4ab8687ab16408615ceccba03b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Thu, 8 Jun 2023 11:35:36 +0200
Subject: [PATCH 01/40] Create a new _vector extractor

---
 Cargo.lock                                     |  1 +
 milli/Cargo.toml                               |  1 +
 .../extract/extract_vector_points.rs           | 40 +++++++++++++++++++
 .../src/update/index_documents/extract/mod.rs  |  2 +
 4 files changed, 44 insertions(+)
 create mode 100644 milli/src/update/index_documents/extract/extract_vector_points.rs

diff --git a/Cargo.lock b/Cargo.lock
index 46218fc34..9d09fef9d 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2683,6 +2683,7 @@ dependencies = [
  "bimap",
  "bincode",
  "bstr",
+ "bytemuck",
  "byteorder",
  "charabia",
  "concat-arrays",
diff --git a/milli/Cargo.toml b/milli/Cargo.toml
index 138103723..5ff73303a 100644
--- a/milli/Cargo.toml
+++ b/milli/Cargo.toml
@@ -15,6 +15,7 @@ license.workspace = true
 bimap = { version = "0.6.3", features = ["serde"] }
 bincode = "1.3.3"
 bstr = "1.4.0"
+bytemuck = "1.13.1"
 byteorder = "1.4.3"
 charabia = { version = "0.7.2", default-features = false }
 concat-arrays = "0.1.2"
diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs
new file mode 100644
index 000000000..409df5dbd
--- /dev/null
+++ b/milli/src/update/index_documents/extract/extract_vector_points.rs
@@ -0,0 +1,40 @@
+use std::fs::File;
+use std::io;
+
+use bytemuck::cast_slice;
+use serde_json::from_slice;
+
+use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
+use crate::{FieldId, InternalError, Result};
+
+/// Extracts the embedding vector contained in each document under the `_vector` field.
+///
+/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
+#[logging_timer::time]
+pub fn extract_vector_points<R: io::Read + io::Seek>(
+    obkv_documents: grenad::Reader<R>,
+    indexer: GrenadParameters,
+    vector_fid: FieldId,
+) -> Result<grenad::Reader<File>> {
+    let mut writer = create_writer(
+        indexer.chunk_compression_type,
+        indexer.chunk_compression_level,
+        tempfile::tempfile()?,
+    );
+
+    let mut cursor = obkv_documents.into_cursor()?;
+    while let Some((docid_bytes, value)) = cursor.move_on_next()? {
+        let obkv = obkv::KvReader::new(value);
+
+        // first we get the _vector field
+        if let Some(vector) = obkv.get(vector_fid) {
+            // try to extract the vector
+            let vector: Vec<f32> = from_slice(vector).map_err(InternalError::SerdeJson).unwrap();
+            let bytes = cast_slice(&vector);
+            writer.insert(docid_bytes, bytes)?;
+        }
+        // else => the _vector object was `null`, there is nothing to do
+    }
+
+    writer_into_reader(writer)
+}
diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs
index 632f568ab..128fc29c0 100644
--- a/milli/src/update/index_documents/extract/mod.rs
+++ b/milli/src/update/index_documents/extract/mod.rs
@@ -4,6 +4,7 @@ mod extract_facet_string_docids;
 mod extract_fid_docid_facet_values;
 mod extract_fid_word_count_docids;
 mod extract_geo_points;
+mod extract_vector_points;
 mod extract_word_docids;
 mod extract_word_fid_docids;
 mod extract_word_pair_proximity_docids;
@@ -22,6 +23,7 @@ use self::extract_facet_string_docids::extract_facet_string_docids;
 use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues};
 use self::extract_fid_word_count_docids::extract_fid_word_count_docids;
 use self::extract_geo_points::extract_geo_points;
+use self::extract_vector_points::extract_vector_points;
 use self::extract_word_docids::extract_word_docids;
 use self::extract_word_fid_docids::extract_word_fid_docids;
 use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;

From 7ac2f1489d5cb7bc2c3333cc9596d617365c2fb0 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Thu, 8 Jun 2023 11:51:55 +0200
Subject: [PATCH 02/40] Extract the vectors from the documents

---
 .../src/update/index_documents/extract/mod.rs | 17 +++++++++++++
 milli/src/update/index_documents/mod.rs       |  3 +++
 .../src/update/index_documents/typed_chunk.rs | 24 +++++++++++++++++++
 3 files changed, 44 insertions(+)

diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs
index 128fc29c0..fdc6f5616 100644
--- a/milli/src/update/index_documents/extract/mod.rs
+++ b/milli/src/update/index_documents/extract/mod.rs
@@ -47,6 +47,7 @@ pub(crate) fn data_from_obkv_documents(
     faceted_fields: HashSet<FieldId>,
     primary_key_id: FieldId,
     geo_fields_ids: Option<(FieldId, FieldId)>,
+    vector_field_id: Option<FieldId>,
     stop_words: Option<fst::Set<&[u8]>>,
     max_positions_per_attributes: Option<u32>,
     exact_attributes: HashSet<FieldId>,
@@ -71,6 +72,7 @@
                 &faceted_fields,
                 primary_key_id,
                 geo_fields_ids,
+                vector_field_id,
                 &stop_words,
                 max_positions_per_attributes,
             )
@@ -281,6 +283,7 @@ fn send_and_extract_flattened_documents_data(
     faceted_fields: &HashSet<FieldId>,
     primary_key_id: FieldId,
     geo_fields_ids: Option<(FieldId, FieldId)>,
+    vector_field_id: Option<FieldId>,
    stop_words: &Option<fst::Set<&[u8]>>,
     max_positions_per_attributes: Option<u32>,
 ) -> Result<(
@@ -309,6 +312,20 @@
         });
     }

+    if let Some(vector_field_id) = vector_field_id {
+        let documents_chunk_cloned = flattened_documents_chunk.clone();
+        let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
+        rayon::spawn(move || {
+            let result = extract_vector_points(documents_chunk_cloned, indexer, vector_field_id);
+            let _ = match result {
+                Ok(vector_points) => {
+                    lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))
+                }
+                Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
+            };
+        });
+    }
+
     let (docid_word_positions_chunk, docid_fid_facet_values_chunks): (Result<_>, Result<_>) =
         rayon::join(
             || {
diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs
index 993f87a1f..adbab54db 100644
--- a/milli/src/update/index_documents/mod.rs
+++ b/milli/src/update/index_documents/mod.rs
@@ -304,6 +304,8 @@ where
             }
             None => None,
         };
+        // get the fid of the `_vector` field.
+        let vector_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vector");

         let stop_words = self.index.stop_words(self.wtxn)?;
         let exact_attributes = self.index.exact_attributes_ids(self.wtxn)?;
@@ -340,6 +342,7 @@ where
                     faceted_fields,
                     primary_key_id,
                     geo_fields_ids,
+                    vector_field_id,
                     stop_words,
                     max_positions_per_attributes,
                     exact_attributes,
diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs
index 89b10bffe..8b3477948 100644
--- a/milli/src/update/index_documents/typed_chunk.rs
+++ b/milli/src/update/index_documents/typed_chunk.rs
@@ -38,6 +38,7 @@ pub(crate) enum TypedChunk {
     FieldIdFacetIsNullDocids(grenad::Reader<File>),
     FieldIdFacetIsEmptyDocids(grenad::Reader<File>),
     GeoPoints(grenad::Reader<File>),
+    VectorPoints(grenad::Reader<File>),
     ScriptLanguageDocids(HashMap<(Script, Language), RoaringBitmap>),
 }

@@ -221,6 +222,29 @@ pub(crate) fn write_typed_chunk_into_index(
             index.put_geo_rtree(wtxn, &rtree)?;
             index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
         }
+        TypedChunk::VectorPoints(vector_points) => {
+            // let mut rtree = index.geo_rtree(wtxn)?.unwrap_or_default();
+            // let mut geo_faceted_docids = index.geo_faceted_documents_ids(wtxn)?;

+            // let mut cursor = geo_points.into_cursor()?;
+            // while let Some((key, value)) = cursor.move_on_next()? {
+            //     // convert the key back to a u32 (4 bytes)
+            //     let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();

+            //     // convert the latitude and longitude back to a f64 (8 bytes)
+            //     let (lat, tail) = helpers::try_split_array_at::<u8, 8>(value).unwrap();
+            //     let (lng, _) = helpers::try_split_array_at::<u8, 8>(tail).unwrap();
+            //     let point = [f64::from_ne_bytes(lat), f64::from_ne_bytes(lng)];
+            //     let xyz_point = lat_lng_to_xyz(&point);

+            //     rtree.insert(GeoPoint::new(xyz_point, (docid, point)));
+            //     geo_faceted_docids.insert(docid);
+            // }
+            // index.put_geo_rtree(wtxn, &rtree)?;
+            // index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;

+            todo!("index vector points")
+        }
         TypedChunk::ScriptLanguageDocids(hash_pair) => {
             let mut buffer = Vec::new();
             for (key, value) in hash_pair {

From 4571e512d2b306469454f7a82467282dc36d3f41 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Thu, 8 Jun 2023 12:19:06 +0200
Subject: [PATCH 03/40] Store the vectors in an HNSW in LMDB

---
 Cargo.lock                                    | 63 ++++++++++++++++++-
 milli/Cargo.toml                              |  5 +-
 milli/src/dot_product.rs                      | 16 +++++
 milli/src/index.rs                            | 53 ++++++++++++----
 milli/src/lib.rs                              |  1 +
 milli/src/update/clear_documents.rs           |  3 +
 milli/src/update/delete_documents.rs          |  1 +
 .../src/update/index_documents/typed_chunk.rs | 36 +++++------
 8 files changed, 142 insertions(+), 36 deletions(-)
 create mode 100644 milli/src/dot_product.rs

diff --git a/Cargo.lock b/Cargo.lock
index 9d09fef9d..904d1c225 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1221,6 +1221,12 @@ dependencies = [
  "winapi",
 ]

+[[package]]
+name = "doc-comment"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
+
 [[package]]
 name = "dump"
 version = "1.2.0"
@@ -1725,6 +1731,15 @@ dependencies = [
  "byteorder",
 ]

+[[package]]
+name = "hashbrown"
+version = "0.11.2"
"0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash 0.7.6", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1826,6 +1841,22 @@ dependencies = [ "digest", ] +[[package]] +name = "hnsw" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b9740ebf8769ec4ad6762cc951ba18f39bba6dfbc2fbbe46285f7539af79752" +dependencies = [ + "ahash 0.7.6", + "hashbrown 0.11.2", + "libm", + "num-traits", + "rand_core", + "serde", + "smallvec", + "space", +] + [[package]] name = "http" version = "0.2.9" @@ -1956,7 +1987,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", "serde", ] @@ -2057,7 +2088,7 @@ checksum = "37228e06c75842d1097432d94d02f37fe3ebfca9791c2e8fef6e9db17ed128c1" dependencies = [ "cedarwood", "fxhash", - "hashbrown", + "hashbrown 0.12.3", "lazy_static", "phf", "phf_codegen", @@ -2698,6 +2729,7 @@ dependencies = [ "geoutils", "grenad", "heed", + "hnsw", "insta", "itertools", "json-depth-checker", @@ -2712,6 +2744,7 @@ dependencies = [ "once_cell", "ordered-float", "rand", + "rand_pcg", "rayon", "roaring", "rstar", @@ -2721,6 +2754,7 @@ dependencies = [ "smallstr", "smallvec", "smartstring", + "space", "tempfile", "thiserror", "time", @@ -3273,6 +3307,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_pcg" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e" +dependencies = [ + "rand_core", + "serde", +] + [[package]] name = "rayon" version = "1.7.0" @@ -3732,6 +3776,9 @@ name = "smallvec" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +dependencies = [ + "serde", +] [[package]] name = "smartstring" @@ -3754,6 +3801,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "space" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5ab9701ae895386d13db622abf411989deff7109b13b46b6173bb4ce5c1d123" +dependencies = [ + "doc-comment", + "num-traits", +] + [[package]] name = "spin" version = "0.5.2" @@ -4405,7 +4462,7 @@ version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c531a2dc4c462b833788be2c07eef4e621d0e9edbd55bf280cc164c1c1aa043" dependencies = [ - "hashbrown", + "hashbrown 0.12.3", "once_cell", ] diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 5ff73303a..08f0c2645 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -15,7 +15,7 @@ license.workspace = true bimap = { version = "0.6.3", features = ["serde"] } bincode = "1.3.3" bstr = "1.4.0" -bytemuck = "1.13.1" +bytemuck = { version = "1.13.1", features = ["extern_crate_alloc"] } byteorder = "1.4.3" charabia = { version = "0.7.2", default-features = false } concat-arrays = "0.1.2" @@ -33,18 +33,21 @@ heed = { git = "https://github.com/meilisearch/heed", tag = "v0.12.6", default-f "lmdb", "sync-read-txn", ] } +hnsw = { version = "0.11.0", features = ["serde1"] } json-depth-checker = { path = "../json-depth-checker" } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } memmap2 = "0.5.10" obkv = "0.2.0" once_cell = "1.17.1" ordered-float = "3.6.0" 
+rand_pcg = { version = "0.3.1", features = ["serde1"] }
 rayon = "1.7.0"
 roaring = "0.10.1"
 rstar = { version = "0.10.0", features = ["serde"] }
 serde = { version = "1.0.160", features = ["derive"] }
 serde_json = { version = "1.0.95", features = ["preserve_order"] }
 slice-group-by = "0.3.0"
+space = "0.17.0"
 smallstr = { version = "0.3.0", features = ["serde"] }
 smallvec = "1.10.0"
 smartstring = "1.0.1"
diff --git a/milli/src/dot_product.rs b/milli/src/dot_product.rs
new file mode 100644
index 000000000..2f5f1e474
--- /dev/null
+++ b/milli/src/dot_product.rs
@@ -0,0 +1,16 @@
+use serde::{Deserialize, Serialize};
+use space::Metric;
+
+#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
+pub struct DotProduct;
+
+impl Metric<Vec<f32>> for DotProduct {
+    type Unit = u32;
+
+    // Following .
+    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
+        let dist: f32 = a.iter().zip(b).map(|(a, b)| a * b).sum();
+        debug_assert!(!dist.is_nan());
+        dist.to_bits()
+    }
+}
diff --git a/milli/src/index.rs b/milli/src/index.rs
index fad3f665c..4cdfb010c 100644
--- a/milli/src/index.rs
+++ b/milli/src/index.rs
@@ -8,10 +8,12 @@ use charabia::{Language, Script};
 use heed::flags::Flags;
 use heed::types::*;
 use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn};
+use rand_pcg::Pcg32;
 use roaring::RoaringBitmap;
 use rstar::RTree;
 use time::OffsetDateTime;

+use crate::dot_product::DotProduct;
 use crate::error::{InternalError, UserError};
 use crate::facet::FacetType;
 use crate::fields_ids_map::FieldsIdsMap;
@@ -26,6 +28,9 @@ use crate::{
     Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32,
 };

+/// The HNSW data-structure that we serialize, fill and search in.
+pub type Hnsw = hnsw::Hnsw<DotProduct, Vec<f32>, Pcg32, 12, 24>;
+
 pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
 pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;

@@ -42,6 +47,7 @@ pub mod main_key {
     pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map";
     pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids";
     pub const GEO_RTREE_KEY: &str = "geo-rtree";
+    pub const VECTOR_HNSW_KEY: &str = "vector-hnsw";
     pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids";
     pub const NUMBER_FACETED_DOCUMENTS_IDS_PREFIX: &str = "number-faceted-documents-ids";
     pub const PRIMARY_KEY_KEY: &str = "primary-key";
@@ -86,6 +92,7 @@ pub mod db_name {
     pub const FACET_ID_STRING_DOCIDS: &str = "facet-id-string-docids";
     pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s";
     pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings";
+    pub const VECTOR_ID_DOCID: &str = "vector-id-docids";
     pub const DOCUMENTS: &str = "documents";
     pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids";
 }
@@ -149,6 +156,9 @@ pub struct Index {
     /// Maps the document id, the facet field id and the strings.
     pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>,

+    /// Maps a vector id to the document id that have it.
+    pub vector_id_docid: Database<OwnedType<BEU32>, OwnedType<BEU32>>,
+
     /// Maps the document id to the document as an obkv store.
     pub(crate) documents: Database<OwnedType<BEU32>, ObkvCodec>,
 }

@@ -162,7 +172,7 @@ impl Index {
     ) -> Result<Index> {
         use db_name::*;

-        options.max_dbs(23);
+        options.max_dbs(24);
         unsafe { options.flag(Flags::MdbAlwaysFreePages) };

         let env = options.open(path)?;
@@ -198,11 +208,11 @@ impl Index {
             env.create_database(&mut wtxn, Some(FACET_ID_IS_NULL_DOCIDS))?;
         let facet_id_is_empty_docids =
             env.create_database(&mut wtxn, Some(FACET_ID_IS_EMPTY_DOCIDS))?;
-
         let field_id_docid_facet_f64s =
             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?;
         let field_id_docid_facet_strings =
             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?;
+        let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?;
         let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?;

         wtxn.commit()?;
@@ -231,6 +241,7 @@ impl Index {
             facet_id_is_empty_docids,
             field_id_docid_facet_f64s,
             field_id_docid_facet_strings,
+            vector_id_docid,
             documents,
         })
     }
@@ -502,6 +513,26 @@ impl Index {
         }
     }

+    /* vector HNSW */
+
+    /// Writes the provided `hnsw`.
+    pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> {
+        self.main.put::<_, Str, SerdeBincode<Hnsw>>(wtxn, main_key::VECTOR_HNSW_KEY, hnsw)
+    }
+
+    /// Delete the `hnsw`.
+    pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result<bool> {
+        self.main.delete::<_, Str>(wtxn, main_key::VECTOR_HNSW_KEY)
+    }
+
+    /// Returns the `hnsw`.
+    pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result<Option<Hnsw>> {
+        match self.main.get::<_, Str, SerdeBincode<Hnsw>>(rtxn, main_key::VECTOR_HNSW_KEY)? {
+            Some(hnsw) => Ok(Some(hnsw)),
+            None => Ok(None),
+        }
+    }
+
     /* field distribution */

     /// Writes the field distribution which associates every field name with
@@ -1466,9 +1497,9 @@ pub(crate) mod tests {
         db_snap!(index, field_distribution,
             @r###"
-        age 1 |
-        id 2 |
-        name 2 |
+        age 1
+        id 2
+        name 2
         "###
         );

@@ -1486,9 +1517,9 @@ pub(crate) mod tests {
         db_snap!(index, field_distribution,
             @r###"
-        age 1 |
-        id 2 |
-        name 2 |
+        age 1
+        id 2
+        name 2
         "###
         );

@@ -1502,9 +1533,9 @@ pub(crate) mod tests {
         db_snap!(index, field_distribution,
             @r###"
-        has_dog 1 |
-        id 2 |
-        name 2 |
+        has_dog 1
+        id 2
+        name 2
         "###
         );
     }
diff --git a/milli/src/lib.rs b/milli/src/lib.rs
index d3ee4f08e..2e62e35ac 100644
--- a/milli/src/lib.rs
+++ b/milli/src/lib.rs
@@ -10,6 +10,7 @@ pub mod documents;

 mod asc_desc;
 mod criterion;
+pub mod dot_product;
 mod error;
 mod external_documents_ids;
 pub mod facet;
diff --git a/milli/src/update/clear_documents.rs b/milli/src/update/clear_documents.rs
index 04119c641..f4a2d43fe 100644
--- a/milli/src/update/clear_documents.rs
+++ b/milli/src/update/clear_documents.rs
@@ -39,6 +39,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
             facet_id_is_empty_docids,
             field_id_docid_facet_f64s,
             field_id_docid_facet_strings,
+            vector_id_docid,
             documents,
         } = self.index;

@@ -57,6 +58,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
         self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?;
         self.index.delete_geo_rtree(self.wtxn)?;
         self.index.delete_geo_faceted_documents_ids(self.wtxn)?;
+        self.index.delete_vector_hnsw(self.wtxn)?;

         // We clean all the faceted documents ids.
         for field_id in faceted_fields {
@@ -95,6 +97,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
         facet_id_string_docids.clear(self.wtxn)?;
         field_id_docid_facet_f64s.clear(self.wtxn)?;
         field_id_docid_facet_strings.clear(self.wtxn)?;
+        vector_id_docid.clear(self.wtxn)?;
         documents.clear(self.wtxn)?;

         Ok(number_of_documents)
diff --git a/milli/src/update/delete_documents.rs b/milli/src/update/delete_documents.rs
index b971768a3..73af66a95 100644
--- a/milli/src/update/delete_documents.rs
+++ b/milli/src/update/delete_documents.rs
@@ -240,6 +240,7 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> {
             facet_id_exists_docids,
             facet_id_is_null_docids,
             facet_id_is_empty_docids,
+            vector_id_docid,
             documents,
         } = self.index;
         // Remove from the documents database
diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs
index 8b3477948..e2c67044c 100644
--- a/milli/src/update/index_documents/typed_chunk.rs
+++ b/milli/src/update/index_documents/typed_chunk.rs
@@ -4,10 +4,12 @@ use std::convert::TryInto;
 use std::fs::File;
 use std::io;

+use bytemuck::allocation::pod_collect_to_vec;
 use charabia::{Language, Script};
 use grenad::MergerBuilder;
 use heed::types::ByteSlice;
 use heed::RwTxn;
+use hnsw::Searcher;
 use roaring::RoaringBitmap;

 use super::helpers::{
@@ -17,7 +19,7 @@ use super::{ClonableMmap, MergeFn};
 use crate::facet::FacetType;
 use crate::update::facet::FacetsUpdate;
 use crate::update::index_documents::helpers::as_cloneable_grenad;
-use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result};
+use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32};

 pub(crate) enum TypedChunk {
     FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
@@ -223,27 +225,19 @@ pub(crate) fn write_typed_chunk_into_index(
             index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
         }
         TypedChunk::VectorPoints(vector_points) => {
-            // let mut rtree = index.geo_rtree(wtxn)?.unwrap_or_default();
-            // let mut geo_faceted_docids = index.geo_faceted_documents_ids(wtxn)?;
+            let mut hnsw = index.vector_hnsw(wtxn)?.unwrap_or_default();
+            let mut searcher = Searcher::new();

-            // let mut cursor = geo_points.into_cursor()?;
-            // while let Some((key, value)) = cursor.move_on_next()? {
-            //     // convert the key back to a u32 (4 bytes)
-            //     let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
-
-            //     // convert the latitude and longitude back to a f64 (8 bytes)
-            //     let (lat, tail) = helpers::try_split_array_at::<u8, 8>(value).unwrap();
-            //     let (lng, _) = helpers::try_split_array_at::<u8, 8>(tail).unwrap();
-            //     let point = [f64::from_ne_bytes(lat), f64::from_ne_bytes(lng)];
-            //     let xyz_point = lat_lng_to_xyz(&point);
-
-            //     rtree.insert(GeoPoint::new(xyz_point, (docid, point)));
-            //     geo_faceted_docids.insert(docid);
-            // }
-            // index.put_geo_rtree(wtxn, &rtree)?;
-            // index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
-
-            todo!("index vector points")
+            let mut cursor = vector_points.into_cursor()?;
+            while let Some((key, value)) = cursor.move_on_next()? {
+                // convert the key back to a u32 (4 bytes)
+                let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
+                // convert the vector back to a Vec<f32>
+                let vector: Vec<f32> = pod_collect_to_vec(value);
+                let vector_id = hnsw.insert(vector, &mut searcher) as u32;
+                index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?;
+            }
+            index.put_vector_hnsw(wtxn, &hnsw)?;
         }
         TypedChunk::ScriptLanguageDocids(hash_pair) => {
             let mut buffer = Vec::new();
             for (key, value) in hash_pair {

From cad90e8cbc9d8b9e0110c855072cef479d06d538 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Thu, 8 Jun 2023 15:44:03 +0200
Subject: [PATCH 04/40] Add a vector field to the search routes

---
 meilisearch/src/routes/indexes/search.rs |  3 +++
 meilisearch/src/search.rs                | 10 ++++++++--
 2 files changed, 11 insertions(+), 2 deletions(-)

diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs
index 3ab093b5d..fae24dba2 100644
--- a/meilisearch/src/routes/indexes/search.rs
+++ b/meilisearch/src/routes/indexes/search.rs
@@ -34,6 +34,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
 pub struct SearchQueryGet {
     #[deserr(default, error = DeserrQueryParamError)]
     q: Option<String>,
+    #[deserr(default, error = DeserrQueryParamError)]
+    vector: Option<Vec<f32>>,
     #[deserr(default = Param(DEFAULT_SEARCH_OFFSET()), error = DeserrQueryParamError)]
     offset: Param<usize>,
     #[deserr(default = Param(DEFAULT_SEARCH_LIMIT()), error = DeserrQueryParamError)]
     limit: Param<usize>,
@@ -84,6 +86,7 @@ impl From<SearchQueryGet> for SearchQuery {
         Self {
             q: other.q,
+            vector: other.vector,
             offset: other.offset.0,
             limit: other.limit.0,
             page: other.page.as_deref().copied(),
diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs
index 62f49c148..81bcb6aaa 100644
--- a/meilisearch/src/search.rs
+++ b/meilisearch/src/search.rs
@@ -33,11 +33,13 @@ pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string();
 pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
 pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();

-#[derive(Debug, Clone, Default, PartialEq, Eq, Deserr)]
+#[derive(Debug, Clone, Default, PartialEq, Deserr)]
 #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
 pub struct SearchQuery {
     #[deserr(default, error = DeserrJsonError)]
     pub q: Option<String>,
+    #[deserr(default, error = DeserrJsonError)]
+    pub vector: Option<Vec<f32>>,
     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)]
     pub offset: usize,
     #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)]
@@ -86,13 +88,15 @@ impl SearchQuery {

 // This struct contains the fields of `SearchQuery` inline.
 // This is because neither deserr nor serde support `flatten` when using `deny_unknown_fields.
 // The `From` implementation ensures both structs remain up to date.
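 // (`Eq` can no longer be derived on these query structs: the new `vector`
 // field carries `f32` values, which implement `PartialEq` but not `Eq`.)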
-#[derive(Debug, Clone, PartialEq, Eq, Deserr)]
+#[derive(Debug, Clone, PartialEq, Deserr)]
 #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
 pub struct SearchQueryWithIndex {
     #[deserr(error = DeserrJsonError, missing_field_error = DeserrJsonError::missing_index_uid)]
     pub index_uid: IndexUid,
     #[deserr(default, error = DeserrJsonError)]
     pub q: Option<String>,
+    #[deserr(default, error = DeserrJsonError)]
+    pub vector: Option<Vec<f32>>,
     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)]
     pub offset: usize,
     #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)]
@@ -136,6 +140,7 @@ impl SearchQueryWithIndex {
         let SearchQueryWithIndex {
             index_uid,
             q,
+            vector,
             offset,
             limit,
             page,
@@ -159,6 +164,7 @@ impl SearchQueryWithIndex {
             index_uid,
             SearchQuery {
                 q,
+                vector,
                 offset,
                 limit,
                 page,

From 642b0f3a1bbc6bda47e83d7d9c587d683ba494bc Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Thu, 8 Jun 2023 18:47:06 +0200
Subject: [PATCH 05/40] Expose a new vector field on the search route

---
 meilisearch/src/search.rs           |  4 ++++
 milli/examples/search.rs            |  1 +
 milli/src/dot_product.rs            |  4 ++++
 milli/src/search/mod.rs             | 10 +++++++++
 milli/src/search/new/matches/mod.rs |  1 +
 milli/src/search/new/mod.rs         | 35 ++++++++++++++++++++++++++++-
 6 files changed, 54 insertions(+), 1 deletion(-)

diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs
index 81bcb6aaa..f1fb341a2 100644
--- a/meilisearch/src/search.rs
+++ b/meilisearch/src/search.rs
@@ -295,6 +295,10 @@ pub fn perform_search(
     let mut search = index.search(&rtxn);

+    if let Some(ref vector) = query.vector {
+        search.vector(vector.clone());
+    }
+
     if let Some(ref query) = query.q {
         search.query(query);
     }
diff --git a/milli/examples/search.rs b/milli/examples/search.rs
index 87c9a004d..82de56434 100644
--- a/milli/examples/search.rs
+++ b/milli/examples/search.rs
@@ -52,6 +52,7 @@ fn main() -> Result<(), Box<dyn Error>> {
         let docs = execute_search(
             &mut ctx,
             &(!query.trim().is_empty()).then(|| query.trim().to_owned()),
+            &None,
             TermsMatchingStrategy::Last,
             milli::score_details::ScoringStrategy::Skip,
             false,
diff --git a/milli/src/dot_product.rs b/milli/src/dot_product.rs
index 2f5f1e474..86dd2f1d4 100644
--- a/milli/src/dot_product.rs
+++ b/milli/src/dot_product.rs
@@ -7,9 +7,13 @@ pub struct DotProduct;
 impl Metric<Vec<f32>> for DotProduct {
     type Unit = u32;

+    // TODO explain me this function, I don't understand why f32.to_bits is ordered.
+    // I tried to do this and it wasn't OK
+    //
     // Following .
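+    //
+    // Answering the TODO above: for IEEE-754 floats with the sign bit clear,
+    // the bit pattern grows monotonically with the value (the exponent sits
+    // above the mantissa), so comparing the `u32`s from `to_bits()` compares
+    // the distances themselves. The trick breaks for negative inputs, which
+    // `1.0 - dist` below can produce when the vectors are not normalized, and
+    // for NaN (hence the `debug_assert!`).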
     fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
         let dist: f32 = a.iter().zip(b).map(|(a, b)| a * b).sum();
+        let dist = 1.0 - dist;
         debug_assert!(!dist.is_nan());
         dist.to_bits()
     }
diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs
index 3c972d9b0..970c0b7ab 100644
--- a/milli/src/search/mod.rs
+++ b/milli/src/search/mod.rs
@@ -23,6 +23,7 @@ pub mod new;

 pub struct Search<'a> {
     query: Option<String>,
+    vector: Option<Vec<f32>>, // this should be linked to the String in the query
     filter: Option<Filter<'a>>,
     offset: usize,
@@ -41,6 +42,7 @@ impl<'a> Search<'a> {
     pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
         Search {
             query: None,
+            vector: None,
             filter: None,
             offset: 0,
             limit: 20,
@@ -60,6 +62,11 @@ impl<'a> Search<'a> {
         self
     }

+    pub fn vector(&mut self, vector: impl Into<Vec<f32>>) -> &mut Search<'a> {
+        self.vector = Some(vector.into());
+        self
+    }
+
     pub fn offset(&mut self, offset: usize) -> &mut Search<'a> {
         self.offset = offset;
         self
@@ -114,6 +121,7 @@ impl<'a> Search<'a> {
         execute_search(
             &mut ctx,
             &self.query,
+            &self.vector,
             self.terms_matching_strategy,
             self.scoring_strategy,
             self.exhaustive_number_hits,
@@ -141,6 +149,7 @@ impl fmt::Debug for Search<'_> {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         let Search {
             query,
+            vector: _,
             filter,
             offset,
             limit,
@@ -155,6 +164,7 @@ impl fmt::Debug for Search<'_> {
         } = self;
         f.debug_struct("Search")
             .field("query", query)
+            .field("vector", &"[...]")
             .field("filter", filter)
             .field("offset", offset)
             .field("limit", limit)
diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs
index f33d595e5..ce28e16c1 100644
--- a/milli/src/search/new/matches/mod.rs
+++ b/milli/src/search/new/matches/mod.rs
@@ -509,6 +509,7 @@ mod tests {
         let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search(
             &mut ctx,
             &Some(query.to_string()),
+            &None,
             crate::TermsMatchingStrategy::default(),
             crate::score_details::ScoringStrategy::Skip,
             false,
diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs
index 8df764f29..948a2fa21 100644
--- a/milli/src/search/new/mod.rs
+++ b/milli/src/search/new/mod.rs
@@ -28,6 +28,7 @@ use db_cache::DatabaseCache;
 use exact_attribute::ExactAttribute;
 use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo};
 use heed::RoTxn;
+use hnsw::Searcher;
 use interner::{DedupInterner, Interner};
 pub use logger::visual::VisualSearchLogger;
 pub use logger::{DefaultSearchLogger, SearchLogger};
@@ -39,6 +40,7 @@ use ranking_rules::{
 use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache};
 use roaring::RoaringBitmap;
 use sort::Sort;
+use space::Neighbor;

 use self::geo_sort::GeoSort;
 pub use self::geo_sort::Strategy as GeoSortStrategy;
@@ -46,7 +48,9 @@ use self::graph_based_ranking_rule::Words;
 use self::interner::Interned;
 use crate::score_details::{ScoreDetails, ScoringStrategy};
 use crate::search::new::distinct::apply_distinct_rule;
-use crate::{AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError};
+use crate::{
+    AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32,
+};

 /// A structure used throughout the execution of a search query.
 pub struct SearchContext<'ctx> {
@@ -350,6 +354,7 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>(
 pub fn execute_search(
     ctx: &mut SearchContext,
     query: &Option<String>,
+    vector: &Option<Vec<f32>>,
     terms_matching_strategy: TermsMatchingStrategy,
     scoring_strategy: ScoringStrategy,
     exhaustive_number_hits: bool,
@@ -442,6 +447,34 @@ pub fn execute_search(
     let fields_ids_map = ctx.index.fields_ids_map(ctx.txn)?;

+    let docids = match vector {
+        Some(vector) => {
+            // return the nearest documents that are also part of the candidates.
+            let mut searcher = Searcher::new();
+            let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default();
+            let ef = hnsw.len().min(100);
+            let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef];
+            let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]);
+
+            let mut docids = Vec::new();
+            for Neighbor { index, distance } in neighbors.iter() {
+                let index = BEU32::new(*index as u32);
+                let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get();
+                dbg!(distance, f32::from_bits(*distance));
+                if universe.contains(docid) {
+                    docids.push(docid);
+                    if docids.len() == length {
+                        break;
+                    }
+                }
+            }
+
+            docids
+        }
+        // return the search docids if the vector field is not specified
+        None => docids,
+    };
+
     // The candidates is the universe unless the exhaustive number of hits
     // is requested and a distinct attribute is set.
     if exhaustive_number_hits {

From 268a9ef416206b61b49dbecf881987223fea9f74 Mon Sep 17 00:00:00 2001
From: Kerollmops
Date: Tue, 13 Jun 2023 15:19:01 +0200
Subject: [PATCH 06/40] Move to the hgg crate

---
 Cargo.lock                                    | 54 +++++++------------
 milli/Cargo.toml                              |  3 +-
 milli/src/{dot_product.rs => distance.rs}     | 14 +++++
 milli/src/index.rs                            | 39 ++++++--------
 milli/src/lib.rs                              |  2 +-
 milli/src/search/new/mod.rs                   | 36 ++++---------
 milli/src/update/clear_documents.rs           |  4 +-
 milli/src/update/delete_documents.rs          |  3 +-
 .../src/update/index_documents/typed_chunk.rs | 17 +++---
 9 files changed, 73 insertions(+), 99 deletions(-)
 rename milli/src/{dot_product.rs => distance.rs} (63%)

diff --git a/Cargo.lock b/Cargo.lock
index 904d1c225..f2fe02366 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1736,9 +1736,6 @@ name = "hashbrown"
 version = "0.11.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
-dependencies = [
- "ahash 0.7.6",
-]

 [[package]]
 name = "hashbrown"
 version = "0.12.3"
@@ -1749,6 +1746,12 @@ dependencies = [
  "ahash 0.7.6",
 ]

+[[package]]
+name = "header-vec"
+version = "0.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bda7e66d32131841c4264e34a32c934df0dedb08d737f861326d616d4338f06f"
+
 [[package]]
 name = "heapless"
 version = "0.7.16"
@@ -1832,6 +1835,19 @@ version = "0.4.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"

+[[package]]
+name = "hgg"
+version = "0.4.2-alpha.0"
+source = "git+https://github.com/rust-cv/hgg#6d1eacde635158163fb663d9327a2d6f612dd435"
+dependencies = [
+ "ahash 0.7.6",
+ "hashbrown 0.11.2",
+ "header-vec",
+ "num-traits",
+ "serde",
+ "space",
+]
+
 [[package]]
 name = "hmac"
 version = "0.12.1"
@@ -1841,22 +1857,6 @@ dependencies = [
  "digest",
 ]

-[[package]]
-name = "hnsw"
-version = "0.11.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2b9740ebf8769ec4ad6762cc951ba18f39bba6dfbc2fbbe46285f7539af79752"
-dependencies = [
- "ahash 0.7.6",
- "hashbrown 0.11.2",
- "libm",
- "num-traits", - "rand_core", - "serde", - "smallvec", - "space", -] - [[package]] name = "http" version = "0.2.9" @@ -2729,7 +2729,7 @@ dependencies = [ "geoutils", "grenad", "heed", - "hnsw", + "hgg", "insta", "itertools", "json-depth-checker", @@ -2744,7 +2744,6 @@ dependencies = [ "once_cell", "ordered-float", "rand", - "rand_pcg", "rayon", "roaring", "rstar", @@ -3307,16 +3306,6 @@ dependencies = [ "getrandom", ] -[[package]] -name = "rand_pcg" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e" -dependencies = [ - "rand_core", - "serde", -] - [[package]] name = "rayon" version = "1.7.0" @@ -3776,9 +3765,6 @@ name = "smallvec" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" -dependencies = [ - "serde", -] [[package]] name = "smartstring" diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 08f0c2645..c17d100f5 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -33,14 +33,13 @@ heed = { git = "https://github.com/meilisearch/heed", tag = "v0.12.6", default-f "lmdb", "sync-read-txn", ] } -hnsw = { version = "0.11.0", features = ["serde1"] } +hgg = { git = "https://github.com/rust-cv/hgg", features = ["serde"] } json-depth-checker = { path = "../json-depth-checker" } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } memmap2 = "0.5.10" obkv = "0.2.0" once_cell = "1.17.1" ordered-float = "3.6.0" -rand_pcg = { version = "0.3.1", features = ["serde1"] } rayon = "1.7.0" roaring = "0.10.1" rstar = { version = "0.10.0", features = ["serde"] } diff --git a/milli/src/dot_product.rs b/milli/src/distance.rs similarity index 63% rename from milli/src/dot_product.rs rename to milli/src/distance.rs index 86dd2f1d4..c26a745a4 100644 --- a/milli/src/dot_product.rs +++ b/milli/src/distance.rs @@ -18,3 +18,17 @@ impl Metric> for DotProduct { dist.to_bits() } } + +#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)] +pub struct Euclidean; + +impl Metric> for Euclidean { + type Unit = u32; + + fn distance(&self, a: &Vec, b: &Vec) -> Self::Unit { + let squared: f32 = a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum(); + let dist = squared.sqrt(); + debug_assert!(!dist.is_nan()); + dist.to_bits() + } +} diff --git a/milli/src/index.rs b/milli/src/index.rs index 4cdfb010c..e29c6da22 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -8,12 +8,11 @@ use charabia::{Language, Script}; use heed::flags::Flags; use heed::types::*; use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn}; -use rand_pcg::Pcg32; use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; -use crate::dot_product::DotProduct; +use crate::distance::Euclidean; use crate::error::{InternalError, UserError}; use crate::facet::FacetType; use crate::fields_ids_map::FieldsIdsMap; @@ -28,8 +27,8 @@ use crate::{ Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32, }; -/// The HNSW data-structure that we serialize, fill and search in. -pub type Hnsw = hnsw::Hnsw, Pcg32, 12, 24>; +/// The HGG data-structure that we serialize, fill and search in. 
+pub type Hgg = hgg::Hgg<Euclidean, Vec<f32>, DocumentId>;

 pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
 pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;

@@ -47,7 +46,7 @@ pub mod main_key {
     pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map";
     pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids";
     pub const GEO_RTREE_KEY: &str = "geo-rtree";
-    pub const VECTOR_HNSW_KEY: &str = "vector-hnsw";
+    pub const VECTOR_HGG_KEY: &str = "vector-hgg";
     pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids";
     pub const NUMBER_FACETED_DOCUMENTS_IDS_PREFIX: &str = "number-faceted-documents-ids";
     pub const PRIMARY_KEY_KEY: &str = "primary-key";
@@ -92,7 +91,6 @@ pub mod db_name {
     pub const FACET_ID_STRING_DOCIDS: &str = "facet-id-string-docids";
     pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s";
     pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings";
-    pub const VECTOR_ID_DOCID: &str = "vector-id-docids";
     pub const DOCUMENTS: &str = "documents";
     pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids";
 }
@@ -156,9 +154,6 @@ pub struct Index {
     /// Maps the document id, the facet field id and the strings.
     pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>,

-    /// Maps a vector id to the document id that have it.
-    pub vector_id_docid: Database<OwnedType<BEU32>, OwnedType<BEU32>>,
-
     /// Maps the document id to the document as an obkv store.
     pub(crate) documents: Database<OwnedType<BEU32>, ObkvCodec>,
 }
@@ -172,7 +167,7 @@ impl Index {
     ) -> Result<Index> {
         use db_name::*;

-        options.max_dbs(24);
+        options.max_dbs(23);
         unsafe { options.flag(Flags::MdbAlwaysFreePages) };

         let env = options.open(path)?;
@@ -212,7 +207,6 @@ impl Index {
             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?;
         let field_id_docid_facet_strings =
             env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?;
-        let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?;
         let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?;

         wtxn.commit()?;
@@ -241,7 +235,6 @@ impl Index {
             facet_id_is_empty_docids,
             field_id_docid_facet_f64s,
             field_id_docid_facet_strings,
-            vector_id_docid,
             documents,
         })
     }
@@ -513,22 +506,22 @@ impl Index {
         }
     }

-    /* vector HNSW */
+    /* vector HGG */

-    /// Writes the provided `hnsw`.
-    pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> {
-        self.main.put::<_, Str, SerdeBincode<Hnsw>>(wtxn, main_key::VECTOR_HNSW_KEY, hnsw)
+    /// Writes the provided `hgg`.
+    pub(crate) fn put_vector_hgg(&self, wtxn: &mut RwTxn, hgg: &Hgg) -> heed::Result<()> {
+        self.main.put::<_, Str, SerdeBincode<Hgg>>(wtxn, main_key::VECTOR_HGG_KEY, hgg)
     }

-    /// Delete the `hnsw`.
-    pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result<bool> {
-        self.main.delete::<_, Str>(wtxn, main_key::VECTOR_HNSW_KEY)
+    /// Delete the `hgg`.
+    pub(crate) fn delete_vector_hgg(&self, wtxn: &mut RwTxn) -> heed::Result<bool> {
+        self.main.delete::<_, Str>(wtxn, main_key::VECTOR_HGG_KEY)
     }

-    /// Returns the `hnsw`.
-    pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result<Option<Hnsw>> {
-        match self.main.get::<_, Str, SerdeBincode<Hnsw>>(rtxn, main_key::VECTOR_HNSW_KEY)? {
-            Some(hnsw) => Ok(Some(hnsw)),
-            None => Ok(None),
-        }
-    }
+    /// Returns the `hgg`.
+    pub fn vector_hgg(&self, rtxn: &RoTxn) -> Result<Option<Hgg>> {
+        match self.main.get::<_, Str, SerdeBincode<Hgg>>(rtxn, main_key::VECTOR_HGG_KEY)?
{ + Some(hgg) => Ok(Some(hgg)), None => Ok(None), } } diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 2e62e35ac..4c7428fa8 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -10,7 +10,7 @@ pub mod documents; mod asc_desc; mod criterion; -pub mod dot_product; +mod distance; mod error; mod external_documents_ids; pub mod facet; diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 948a2fa21..f1aa21484 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -28,7 +28,6 @@ use db_cache::DatabaseCache; use exact_attribute::ExactAttribute; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use heed::RoTxn; -use hnsw::Searcher; use interner::{DedupInterner, Interner}; pub use logger::visual::VisualSearchLogger; pub use logger::{DefaultSearchLogger, SearchLogger}; @@ -40,7 +39,7 @@ use ranking_rules::{ use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache}; use roaring::RoaringBitmap; use sort::Sort; -use space::Neighbor; +use space::{KnnMap, Neighbor}; use self::geo_sort::GeoSort; pub use self::geo_sort::Strategy as GeoSortStrategy; @@ -48,9 +47,7 @@ use self::graph_based_ranking_rule::Words; use self::interner::Interned; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; -use crate::{ - AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32, -}; +use crate::{AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError}; /// A structure used throughout the execution of a search query. pub struct SearchContext<'ctx> { @@ -450,26 +447,15 @@ pub fn execute_search( let docids = match vector { Some(vector) => { // return the nearest documents that are also part of the candidates. - let mut searcher = Searcher::new(); - let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default(); - let ef = hnsw.len().min(100); - let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef]; - let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]); - - let mut docids = Vec::new(); - for Neighbor { index, distance } in neighbors.iter() { - let index = BEU32::new(*index as u32); - let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get(); - dbg!(distance, f32::from_bits(*distance)); - if universe.contains(docid) { - docids.push(docid); - if docids.len() == length { - break; - } - } - } - - docids + let hgg = ctx.index.vector_hgg(ctx.txn)?.unwrap_or_default(); + hgg.knn_values(&vector, 100) + .into_iter() + .filter(|(Neighbor { distance, .. 
}, docid)| { + dbg!(distance, f32::from_bits(*distance)); + universe.contains(**docid) + }) + .map(|(_, docid)| *docid) + .collect() } // return the search docids if the vector field is not specified None => docids, diff --git a/milli/src/update/clear_documents.rs b/milli/src/update/clear_documents.rs index f4a2d43fe..e5e7f5491 100644 --- a/milli/src/update/clear_documents.rs +++ b/milli/src/update/clear_documents.rs @@ -39,7 +39,6 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> { facet_id_is_empty_docids, field_id_docid_facet_f64s, field_id_docid_facet_strings, - vector_id_docid, documents, } = self.index; @@ -58,7 +57,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> { self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?; self.index.delete_geo_rtree(self.wtxn)?; self.index.delete_geo_faceted_documents_ids(self.wtxn)?; - self.index.delete_vector_hnsw(self.wtxn)?; + self.index.delete_vector_hgg(self.wtxn)?; // We clean all the faceted documents ids. for field_id in faceted_fields { @@ -97,7 +96,6 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> { facet_id_string_docids.clear(self.wtxn)?; field_id_docid_facet_f64s.clear(self.wtxn)?; field_id_docid_facet_strings.clear(self.wtxn)?; - vector_id_docid.clear(self.wtxn)?; documents.clear(self.wtxn)?; Ok(number_of_documents) diff --git a/milli/src/update/delete_documents.rs b/milli/src/update/delete_documents.rs index 73af66a95..890c2b329 100644 --- a/milli/src/update/delete_documents.rs +++ b/milli/src/update/delete_documents.rs @@ -240,7 +240,6 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> { facet_id_exists_docids, facet_id_is_null_docids, facet_id_is_empty_docids, - vector_id_docid, documents, } = self.index; // Remove from the documents database @@ -275,6 +274,8 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> { &mut words_to_delete, )?; + todo!("delete the documents from the Hgg datastructure"); + // We construct an FST set that contains the words to delete from the words FST. let words_to_delete = fst::Set::from_iter(words_to_delete.difference(&words_to_keep))?; diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index e2c67044c..82c02375c 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -9,8 +9,8 @@ use charabia::{Language, Script}; use grenad::MergerBuilder; use heed::types::ByteSlice; use heed::RwTxn; -use hnsw::Searcher; use roaring::RoaringBitmap; +use space::KnnInsert; use super::helpers::{ self, merge_ignore_values, serialize_roaring_bitmap, valid_lmdb_key, CursorClonableMmap, @@ -19,7 +19,7 @@ use super::{ClonableMmap, MergeFn}; use crate::facet::FacetType; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::as_cloneable_grenad; -use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32}; +use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result}; pub(crate) enum TypedChunk { FieldIdDocidFacetStrings(grenad::Reader), @@ -225,19 +225,16 @@ pub(crate) fn write_typed_chunk_into_index( index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; } TypedChunk::VectorPoints(vector_points) => { - let mut hnsw = index.vector_hnsw(wtxn)?.unwrap_or_default(); - let mut searcher = Searcher::new(); - + let mut hgg = index.vector_hgg(wtxn)?.unwrap_or_default(); let mut cursor = vector_points.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? 
{
                 // convert the key back to a u32 (4 bytes)
                 let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
-                // convert the vector back to a Vec<f32>
-                let vector: Vec<f32> = pod_collect_to_vec(value);
-                let vector_id = hnsw.insert(vector, &mut searcher) as u32;
-                index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?;
+                // convert the vector back to a Vec<f32> and insert it.
+                // TODO enable again when the library is fixed
+                hgg.insert(pod_collect_to_vec(value), docid);
             }
-            index.put_vector_hnsw(wtxn, &hnsw)?;
+            index.put_vector_hgg(wtxn, &hgg)?;
         }
         TypedChunk::ScriptLanguageDocids(hash_pair) => {
             let mut buffer = Vec::new();
             for (key, value) in hash_pair {

From aca305bb77e910f0dec3f24c8a0065c12ae5e1e3 Mon Sep 17 00:00:00 2001
From: Kerollmops
Date: Wed, 14 Jun 2023 14:17:22 +0200
Subject: [PATCH 07/40] Introduce an optimized version of the euclidean distance function

---
 milli/src/distance.rs | 56 +++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/milli/src/distance.rs b/milli/src/distance.rs
index c26a745a4..bbd2f15eb 100644
--- a/milli/src/distance.rs
+++ b/milli/src/distance.rs
@@ -1,6 +1,13 @@
 use serde::{Deserialize, Serialize};
 use space::Metric;

+#[cfg(any(
+    target_arch = "x86",
+    target_arch = "x86_64",
+    all(target_arch = "aarch64", target_feature = "neon")
+))]
+const MIN_DIM_SIZE_SIMD: usize = 16;
+
 #[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
 pub struct DotProduct;

@@ -26,9 +33,58 @@ impl Metric<Vec<f32>> for Euclidean {
     type Unit = u32;

     fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
+        #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
+        {
+            if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD {
+                let squared = unsafe { squared_euclid_neon(&a, &b) };
+                let dist = squared.sqrt();
+                debug_assert!(!dist.is_nan());
+                return dist.to_bits();
+            }
+        }
+
         let squared: f32 = a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum();
         let dist = squared.sqrt();
         debug_assert!(!dist.is_nan());
         dist.to_bits()
     }
 }
+
+#[cfg(target_feature = "neon")]
+use std::arch::aarch64::*;
+
+#[cfg(target_feature = "neon")]
+pub(crate) unsafe fn squared_euclid_neon(v1: &[f32], v2: &[f32]) -> f32 {
+    let n = v1.len();
+    let m = n - (n % 16);
+    let mut ptr1: *const f32 = v1.as_ptr();
+    let mut ptr2: *const f32 = v2.as_ptr();
+    let mut sum1 = vdupq_n_f32(0.);
+    let mut sum2 = vdupq_n_f32(0.);
+    let mut sum3 = vdupq_n_f32(0.);
+    let mut sum4 = vdupq_n_f32(0.);
+
+    let mut i: usize = 0;
+    while i < m {
+        let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2));
+        sum1 = vfmaq_f32(sum1, sub1, sub1);
+
+        let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4)));
+        sum2 = vfmaq_f32(sum2, sub2, sub2);
+
+        let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8)));
+        sum3 = vfmaq_f32(sum3, sub3, sub3);
+
+        let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12)));
+        sum4 = vfmaq_f32(sum4, sub4, sub4);
+
+        ptr1 = ptr1.add(16);
+        ptr2 = ptr2.add(16);
+        i += 16;
+    }
+    let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4);
+    for i in 0..n - m {
+        result += (*ptr1.add(i) - *ptr2.add(i)).powi(2);
+    }
+    result
+}

From c79e82c62a57a48ca30a3f8c9092bd8868996619 Mon Sep 17 00:00:00 2001
From: Kerollmops
Date: Wed, 14 Jun 2023 14:17:55 +0200
Subject: [PATCH 08/40] Log more to make sure we insert vectors in the hgg data-structure

---
 milli/src/update/index_documents/typed_chunk.rs | 1 +
 1 file changed, 1 insertion(+)

diff --git a/milli/src/update/index_documents/typed_chunk.rs
b/milli/src/update/index_documents/typed_chunk.rs index 82c02375c..122484a6d 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -234,6 +234,7 @@ pub(crate) fn write_typed_chunk_into_index( // TODO enable again when the library is fixed hgg.insert(pod_collect_to_vec(value), docid); } + log::debug!("There are {} entries in the HGG so far", hgg.len()); index.put_vector_hgg(wtxn, &hgg)?; } TypedChunk::ScriptLanguageDocids(hash_pair) => { From c79e82c62a57a48ca30a3f8c9092bd8868996619 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 14 Jun 2023 14:20:05 +0200 Subject: [PATCH 09/40] Move back to the hnsw crate This reverts commit 7a4b6c065482f988b01298642f4c18775503f92f. --- Cargo.lock | 54 ++++++++++++------- milli/Cargo.toml | 3 +- milli/src/{distance.rs => dot_product.rs} | 0 milli/src/index.rs | 39 ++++++++------ milli/src/lib.rs | 3 +- milli/src/search/new/mod.rs | 36 +++++++++---- milli/src/update/clear_documents.rs | 4 +- milli/src/update/delete_documents.rs | 3 +- .../src/update/index_documents/typed_chunk.rs | 19 ++++--- 9 files changed, 101 insertions(+), 60 deletions(-) rename milli/src/{distance.rs => dot_product.rs} (100%) diff --git a/Cargo.lock b/Cargo.lock index f2fe02366..904d1c225 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1736,6 +1736,9 @@ name = "hashbrown" version = "0.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash 0.7.6", +] [[package]] name = "hashbrown" @@ -1746,12 +1749,6 @@ dependencies = [ "ahash 0.7.6", ] -[[package]] -name = "header-vec" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bda7e66d32131841c4264e34a32c934df0dedb08d737f861326d616d4338f06f" - [[package]] name = "heapless" version = "0.7.16" @@ -1835,19 +1832,6 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" -[[package]] -name = "hgg" -version = "0.4.2-alpha.0" -source = "git+https://github.com/rust-cv/hgg#6d1eacde635158163fb663d9327a2d6f612dd435" -dependencies = [ - "ahash 0.7.6", - "hashbrown 0.11.2", - "header-vec", - "num-traits", - "serde", - "space", -] - [[package]] name = "hmac" version = "0.12.1" @@ -1857,6 +1841,22 @@ dependencies = [ "digest", ] +[[package]] +name = "hnsw" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b9740ebf8769ec4ad6762cc951ba18f39bba6dfbc2fbbe46285f7539af79752" +dependencies = [ + "ahash 0.7.6", + "hashbrown 0.11.2", + "libm", + "num-traits", + "rand_core", + "serde", + "smallvec", + "space", +] + [[package]] name = "http" version = "0.2.9" @@ -2729,7 +2729,7 @@ dependencies = [ "geoutils", "grenad", "heed", - "hgg", + "hnsw", "insta", "itertools", "json-depth-checker", @@ -2744,6 +2744,7 @@ dependencies = [ "once_cell", "ordered-float", "rand", + "rand_pcg", "rayon", "roaring", "rstar", @@ -3306,6 +3307,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_pcg" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e" +dependencies = [ + "rand_core", + "serde", +] + [[package]] name = "rayon" version = "1.7.0" @@ -3765,6 +3776,9 @@ name = "smallvec" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +dependencies = [ + "serde", +] [[package]] name = "smartstring" diff --git a/milli/Cargo.toml b/milli/Cargo.toml index c17d100f5..08f0c2645 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -33,13 +33,14 @@ heed = { git = "https://github.com/meilisearch/heed", tag = "v0.12.6", default-f "lmdb", "sync-read-txn", ] } -hgg = { git = "https://github.com/rust-cv/hgg", features = ["serde"] } +hnsw = { version = "0.11.0", features = ["serde1"] } json-depth-checker = { path = "../json-depth-checker" } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } memmap2 = "0.5.10" obkv = "0.2.0" once_cell = "1.17.1" ordered-float = "3.6.0" +rand_pcg = { version = "0.3.1", features = ["serde1"] } rayon = "1.7.0" roaring = "0.10.1" rstar = { version = "0.10.0", features = ["serde"] } diff --git a/milli/src/distance.rs b/milli/src/dot_product.rs similarity index 100% rename from milli/src/distance.rs rename to milli/src/dot_product.rs diff --git a/milli/src/index.rs b/milli/src/index.rs index e29c6da22..4cdfb010c 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -8,11 +8,12 @@ use charabia::{Language, Script}; use heed::flags::Flags; use heed::types::*; use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn}; +use rand_pcg::Pcg32; use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; -use crate::distance::Euclidean; +use crate::dot_product::DotProduct; use crate::error::{InternalError, UserError}; use crate::facet::FacetType; use crate::fields_ids_map::FieldsIdsMap; @@ -27,8 +28,8 @@ use crate::{ Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32, }; -/// The HGG data-structure that we serialize, fill and search in. -pub type Hgg = hgg::Hgg, DocumentId>; +/// The HNSW data-structure that we serialize, fill and search in. +pub type Hnsw = hnsw::Hnsw, Pcg32, 12, 24>; pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; @@ -46,7 +47,7 @@ pub mod main_key { pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map"; pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids"; pub const GEO_RTREE_KEY: &str = "geo-rtree"; - pub const VECTOR_HGG_KEY: &str = "vector-hgg"; + pub const VECTOR_HNSW_KEY: &str = "vector-hnsw"; pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids"; pub const NUMBER_FACETED_DOCUMENTS_IDS_PREFIX: &str = "number-faceted-documents-ids"; pub const PRIMARY_KEY_KEY: &str = "primary-key"; @@ -91,6 +92,7 @@ pub mod db_name { pub const FACET_ID_STRING_DOCIDS: &str = "facet-id-string-docids"; pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s"; pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings"; + pub const VECTOR_ID_DOCID: &str = "vector-id-docids"; pub const DOCUMENTS: &str = "documents"; pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids"; } @@ -154,6 +156,9 @@ pub struct Index { /// Maps the document id, the facet field id and the strings. pub field_id_docid_facet_strings: Database, + /// Maps a vector id to the document id that have it. + pub vector_id_docid: Database, OwnedType>, + /// Maps the document id to the document as an obkv store. 
pub(crate) documents: Database, ObkvCodec>, } @@ -167,7 +172,7 @@ impl Index { ) -> Result { use db_name::*; - options.max_dbs(23); + options.max_dbs(24); unsafe { options.flag(Flags::MdbAlwaysFreePages) }; let env = options.open(path)?; @@ -207,6 +212,7 @@ impl Index { env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?; let field_id_docid_facet_strings = env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?; + let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?; let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?; wtxn.commit()?; @@ -235,6 +241,7 @@ impl Index { facet_id_is_empty_docids, field_id_docid_facet_f64s, field_id_docid_facet_strings, + vector_id_docid, documents, }) } @@ -506,22 +513,22 @@ impl Index { } } - /* vector HGG */ + /* vector HNSW */ - /// Writes the provided `hgg`. - pub(crate) fn put_vector_hgg(&self, wtxn: &mut RwTxn, hgg: &Hgg) -> heed::Result<()> { - self.main.put::<_, Str, SerdeBincode>(wtxn, main_key::VECTOR_HGG_KEY, hgg) + /// Writes the provided `hnsw`. + pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> { + self.main.put::<_, Str, SerdeBincode>(wtxn, main_key::VECTOR_HNSW_KEY, hnsw) } - /// Delete the `hgg`. - pub(crate) fn delete_vector_hgg(&self, wtxn: &mut RwTxn) -> heed::Result { - self.main.delete::<_, Str>(wtxn, main_key::VECTOR_HGG_KEY) + /// Delete the `hnsw`. + pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result { + self.main.delete::<_, Str>(wtxn, main_key::VECTOR_HNSW_KEY) } - /// Returns the `hgg`. - pub fn vector_hgg(&self, rtxn: &RoTxn) -> Result> { - match self.main.get::<_, Str, SerdeBincode>(rtxn, main_key::VECTOR_HGG_KEY)? { - Some(hgg) => Ok(Some(hgg)), + /// Returns the `hnsw`. + pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result> { + match self.main.get::<_, Str, SerdeBincode>(rtxn, main_key::VECTOR_HNSW_KEY)? 
{ + Some(hnsw) => Ok(Some(hnsw)), None => Ok(None), } } diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 4c7428fa8..a1dc6ca4f 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -10,7 +10,7 @@ pub mod documents; mod asc_desc; mod criterion; -mod distance; +pub mod dot_product; mod error; mod external_documents_ids; pub mod facet; @@ -20,6 +20,7 @@ pub mod index; pub mod proximity; pub mod score_details; mod search; +mod search; pub mod update; #[cfg(test)] diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index f1aa21484..948a2fa21 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -28,6 +28,7 @@ use db_cache::DatabaseCache; use exact_attribute::ExactAttribute; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use heed::RoTxn; +use hnsw::Searcher; use interner::{DedupInterner, Interner}; pub use logger::visual::VisualSearchLogger; pub use logger::{DefaultSearchLogger, SearchLogger}; @@ -39,7 +40,7 @@ use ranking_rules::{ use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache}; use roaring::RoaringBitmap; use sort::Sort; -use space::{KnnMap, Neighbor}; +use space::Neighbor; use self::geo_sort::GeoSort; pub use self::geo_sort::Strategy as GeoSortStrategy; @@ -47,7 +48,9 @@ use self::graph_based_ranking_rule::Words; use self::interner::Interned; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; -use crate::{AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError}; +use crate::{ + AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32, +}; /// A structure used throughout the execution of a search query. pub struct SearchContext<'ctx> { @@ -447,15 +450,26 @@ pub fn execute_search( let docids = match vector { Some(vector) => { // return the nearest documents that are also part of the candidates. - let hgg = ctx.index.vector_hgg(ctx.txn)?.unwrap_or_default(); - hgg.knn_values(&vector, 100) - .into_iter() - .filter(|(Neighbor { distance, .. 
}, docid)| { - dbg!(distance, f32::from_bits(*distance)); - universe.contains(**docid) - }) - .map(|(_, docid)| *docid) - .collect() + let mut searcher = Searcher::new(); + let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default(); + let ef = hnsw.len().min(100); + let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef]; + let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]); + + let mut docids = Vec::new(); + for Neighbor { index, distance } in neighbors.iter() { + let index = BEU32::new(*index as u32); + let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get(); + dbg!(distance, f32::from_bits(*distance)); + if universe.contains(docid) { + docids.push(docid); + if docids.len() == length { + break; + } + } + } + + docids } // return the search docids if the vector field is not specified None => docids, diff --git a/milli/src/update/clear_documents.rs b/milli/src/update/clear_documents.rs index e5e7f5491..f4a2d43fe 100644 --- a/milli/src/update/clear_documents.rs +++ b/milli/src/update/clear_documents.rs @@ -39,6 +39,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> { facet_id_is_empty_docids, field_id_docid_facet_f64s, field_id_docid_facet_strings, + vector_id_docid, documents, } = self.index; @@ -57,7 +58,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> { self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?; self.index.delete_geo_rtree(self.wtxn)?; self.index.delete_geo_faceted_documents_ids(self.wtxn)?; - self.index.delete_vector_hgg(self.wtxn)?; + self.index.delete_vector_hnsw(self.wtxn)?; // We clean all the faceted documents ids. for field_id in faceted_fields { @@ -96,6 +97,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> { facet_id_string_docids.clear(self.wtxn)?; field_id_docid_facet_f64s.clear(self.wtxn)?; field_id_docid_facet_strings.clear(self.wtxn)?; + vector_id_docid.clear(self.wtxn)?; documents.clear(self.wtxn)?; Ok(number_of_documents) diff --git a/milli/src/update/delete_documents.rs b/milli/src/update/delete_documents.rs index 890c2b329..73af66a95 100644 --- a/milli/src/update/delete_documents.rs +++ b/milli/src/update/delete_documents.rs @@ -240,6 +240,7 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> { facet_id_exists_docids, facet_id_is_null_docids, facet_id_is_empty_docids, + vector_id_docid, documents, } = self.index; // Remove from the documents database @@ -274,8 +275,6 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> { &mut words_to_delete, )?; - todo!("delete the documents from the Hgg datastructure"); - // We construct an FST set that contains the words to delete from the words FST. 
let words_to_delete = fst::Set::from_iter(words_to_delete.difference(&words_to_keep))?; diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 122484a6d..e136dc139 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -9,8 +9,8 @@ use charabia::{Language, Script}; use grenad::MergerBuilder; use heed::types::ByteSlice; use heed::RwTxn; +use hnsw::Searcher; use roaring::RoaringBitmap; -use space::KnnInsert; use super::helpers::{ self, merge_ignore_values, serialize_roaring_bitmap, valid_lmdb_key, CursorClonableMmap, @@ -19,7 +19,7 @@ use super::{ClonableMmap, MergeFn}; use crate::facet::FacetType; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::as_cloneable_grenad; -use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result}; +use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32}; pub(crate) enum TypedChunk { FieldIdDocidFacetStrings(grenad::Reader), @@ -225,17 +225,20 @@ pub(crate) fn write_typed_chunk_into_index( index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; } TypedChunk::VectorPoints(vector_points) => { - let mut hgg = index.vector_hgg(wtxn)?.unwrap_or_default(); + let mut hnsw = index.vector_hnsw(wtxn)?.unwrap_or_default(); + let mut searcher = Searcher::new(); + let mut cursor = vector_points.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? { // convert the key back to a u32 (4 bytes) let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); - // convert the vector back to a Vec and insert it. - // TODO enable again when the library is fixed - hgg.insert(pod_collect_to_vec(value), docid); + // convert the vector back to a Vec + let vector: Vec = pod_collect_to_vec(value); + let vector_id = hnsw.insert(vector, &mut searcher) as u32; + index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?; } - log::debug!("There are {} entries in the HGG so far", hgg.len()); - index.put_vector_hgg(wtxn, &hgg)?; + log::debug!("There are {} entries in the HNSW so far", hnsw.len()); + index.put_vector_hnsw(wtxn, &hnsw)?; } TypedChunk::ScriptLanguageDocids(hash_pair) => { let mut buffer = Vec::new(); From 8debf6fe81e761d3d15d3f1c996d1ecfd910893e Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 14 Jun 2023 14:22:36 +0200 Subject: [PATCH 10/40] Use a basic euclidean distance function --- milli/src/index.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/milli/src/index.rs b/milli/src/index.rs index 4cdfb010c..5dc9a7ad7 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -13,7 +13,7 @@ use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; -use crate::dot_product::DotProduct; +use crate::dot_product::Euclidean; use crate::error::{InternalError, UserError}; use crate::facet::FacetType; use crate::fields_ids_map::FieldsIdsMap; @@ -29,7 +29,7 @@ use crate::{ }; /// The HNSW data-structure that we serialize, fill and search in. 
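// A short aside on why these metrics return `u32` rather than `f32`: the
// `space` crate wants a totally ordered distance unit, and for non-negative
// IEEE-754 floats the raw bit pattern sorts exactly like the numeric value,
// so `dist.to_bits()` is a cheap total order (assuming the distance is never
// NaN, which the `debug_assert!` in each metric guards against). For example:
//
//     assert!(0.5f32.to_bits() < 2.0f32.to_bits()); // same order as 0.5 < 2.0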
-pub type Hnsw = hnsw::Hnsw<DotProduct, Vec<f32>, Pcg32, 12, 24>; +pub type Hnsw = hnsw::Hnsw<Euclidean, Vec<f32>, Pcg32, 12, 24>; pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; From 436a10bef490a69c45ce8049898d6b99422dd124 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 14 Jun 2023 14:34:58 +0200 Subject: [PATCH 11/40] Replace the euclidean with a dot product --- milli/src/index.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/milli/src/index.rs b/milli/src/index.rs index 5dc9a7ad7..4cdfb010c 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -13,7 +13,7 @@ use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; -use crate::dot_product::Euclidean; +use crate::dot_product::DotProduct; use crate::error::{InternalError, UserError}; use crate::facet::FacetType; use crate::fields_ids_map::FieldsIdsMap; @@ -29,7 +29,7 @@ use crate::{ }; /// The HNSW data-structure that we serialize, fill and search in. -pub type Hnsw = hnsw::Hnsw<Euclidean, Vec<f32>, Pcg32, 12, 24>; +pub type Hnsw = hnsw::Hnsw<DotProduct, Vec<f32>, Pcg32, 12, 24>; pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; From c2a402f3ae3da297b0e13c8175ef5b5c3536d17d Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 14 Jun 2023 14:54:28 +0200 Subject: [PATCH 12/40] Implement an ugly deletion of values in the HNSW --- milli/src/update/delete_documents.rs | 27 +++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/milli/src/update/delete_documents.rs b/milli/src/update/delete_documents.rs index 73af66a95..766f0e16e 100644 --- a/milli/src/update/delete_documents.rs +++ b/milli/src/update/delete_documents.rs @@ -4,8 +4,10 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use fst::IntoStreamer; use heed::types::{ByteSlice, DecodeIgnore, Str, UnalignedSlice}; use heed::{BytesDecode, BytesEncode, Database, RwIter}; +use hnsw::Searcher; use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; +use space::KnnPoints; use time::OffsetDateTime; use super::facet::delete::FacetsDelete; @@ -14,6 +16,7 @@ use crate::error::InternalError; use crate::facet::FacetType; use crate::heed_codec::facet::FieldDocIdFacetCodec; use crate::heed_codec::CboRoaringBitmapCodec; +use crate::index::Hnsw; use crate::{ ExternalDocumentsIds, FieldId, FieldIdMapMissingEntry, Index, Result, RoaringBitmapCodec, BEU32, }; @@ -430,6 +433,30 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> { &self.to_delete_docids, )?; + // An ugly and slow way to remove the vectors from the HNSW + // It basically reconstructs the HNSW from scratch without editing the current one. + let current_hnsw = self.index.vector_hnsw(self.wtxn)?.unwrap_or_default(); + if !current_hnsw.is_empty() { + let mut new_hnsw = Hnsw::default(); + let mut searcher = Searcher::new(); + let mut new_vector_id_docids = Vec::new(); + + for result in vector_id_docid.iter(self.wtxn)? 
{ + let (vector_id, docid) = result?; + if !self.to_delete_docids.contains(docid.get()) { + let vector = current_hnsw.get_point(vector_id.get() as usize).clone(); + let vector_id = new_hnsw.insert(vector, &mut searcher); + new_vector_id_docids.push((vector_id as u32, docid)); + } + } + + vector_id_docid.clear(self.wtxn)?; + for (vector_id, docid) in new_vector_id_docids { + vector_id_docid.put(self.wtxn, &BEU32::new(vector_id), &docid)?; + } + self.index.put_vector_hnsw(self.wtxn, &new_hnsw)?; + } + self.index.put_soft_deleted_documents_ids(self.wtxn, &RoaringBitmap::new())?; Ok(DetailedDocumentDeletionResult { From 23eaaf1001bbbb4620d878540058e541860341c3 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 14 Jun 2023 14:58:51 +0200 Subject: [PATCH 13/40] Change the name of the distance module --- milli/src/{dot_product.rs => distance.rs} | 0 milli/src/index.rs | 2 +- milli/src/lib.rs | 3 +-- 3 files changed, 2 insertions(+), 3 deletions(-) rename milli/src/{dot_product.rs => distance.rs} (100%) diff --git a/milli/src/dot_product.rs b/milli/src/distance.rs similarity index 100% rename from milli/src/dot_product.rs rename to milli/src/distance.rs diff --git a/milli/src/index.rs b/milli/src/index.rs index 4cdfb010c..29d602330 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -13,7 +13,7 @@ use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; -use crate::dot_product::DotProduct; +use crate::distance::DotProduct; use crate::error::{InternalError, UserError}; use crate::facet::FacetType; use crate::fields_ids_map::FieldsIdsMap; diff --git a/milli/src/lib.rs b/milli/src/lib.rs index a1dc6ca4f..021880a50 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -10,7 +10,7 @@ pub mod documents; mod asc_desc; mod criterion; -pub mod dot_product; +pub mod distance; mod error; mod external_documents_ids; pub mod facet; @@ -20,7 +20,6 @@ pub mod index; pub mod proximity; pub mod score_details; mod search; -mod search; pub mod update; #[cfg(test)] From 3c31e1cdd1cce1b8e3ab5740609dbee0a1f0b166 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 14 Jun 2023 15:45:00 +0200 Subject: [PATCH 14/40] Support more pages but in an ugly way --- milli/src/search/new/mod.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 948a2fa21..d56e9d1ed 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -457,19 +457,18 @@ pub fn execute_search( let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]); let mut docids = Vec::new(); - for Neighbor { index, distance } in neighbors.iter() { + for Neighbor { index, distance: _ } in neighbors.iter() { let index = BEU32::new(*index as u32); let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get(); - dbg!(distance, f32::from_bits(*distance)); if universe.contains(docid) { docids.push(docid); - if docids.len() == length { + if docids.len() == (from + length) { break; } } } - docids + docids.into_iter().skip(from).take(length).collect() } // return the search docids if the vector field is not specified None => docids, From 2cf747cb89522a25c3b2f0ee4024d05cf28da0f4 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 14 Jun 2023 15:57:31 +0200 Subject: [PATCH 15/40] Fix the tests --- milli/src/index.rs | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/milli/src/index.rs b/milli/src/index.rs index 29d602330..dcfcc0730 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ 
-1497,9 +1497,9 @@ pub(crate) mod tests { db_snap!(index, field_distribution, @r###" - age 1 - id 2 - name 2 + age 1 | + id 2 | + name 2 | "### ); @@ -1517,9 +1517,9 @@ pub(crate) mod tests { db_snap!(index, field_distribution, @r###" - age 1 - id 2 - name 2 + age 1 | + id 2 | + name 2 | "### ); @@ -1533,9 +1533,9 @@ pub(crate) mod tests { db_snap!(index, field_distribution, @r###" - has_dog 1 - id 2 - name 2 + has_dog 1 | + id 2 | + name 2 | "### ); } From 3b560ef7d099ac204e74fe12639d896457f881a3 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 14 Jun 2023 15:59:10 +0200 Subject: [PATCH 16/40] Make clippy happy --- milli/src/distance.rs | 2 +- milli/src/search/new/mod.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/milli/src/distance.rs b/milli/src/distance.rs index bbd2f15eb..22047eea3 100644 --- a/milli/src/distance.rs +++ b/milli/src/distance.rs @@ -36,7 +36,7 @@ impl Metric> for Euclidean { #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] { if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD { - let squared = unsafe { squared_euclid_neon(&a, &b) }; + let squared = unsafe { squared_euclid_neon(a, b) }; let dist = squared.sqrt(); debug_assert!(!dist.is_nan()); return dist.to_bits(); diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index d56e9d1ed..246a89045 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -454,7 +454,7 @@ pub fn execute_search( let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default(); let ef = hnsw.len().min(100); let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef]; - let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]); + let neighbors = hnsw.nearest(vector, ef, &mut searcher, &mut dest[..]); let mut docids = Vec::new(); for Neighbor { index, distance: _ } in neighbors.iter() { From a7e0f0de89ef8b7a0e261dbaff4f0893cbd1f7d6 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 14 Jun 2023 16:34:09 +0200 Subject: [PATCH 17/40] Introduce a new error message for invalid vector dimensions --- meilisearch-types/src/error.rs | 2 ++ milli/src/error.rs | 4 +++- .../src/update/index_documents/typed_chunk.rs | 18 ++++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 6d81ff241..886a0fe30 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -217,6 +217,7 @@ InvalidDocumentFields , InvalidRequest , BAD_REQUEST ; MissingDocumentFilter , InvalidRequest , BAD_REQUEST ; InvalidDocumentFilter , InvalidRequest , BAD_REQUEST ; InvalidDocumentGeoField , InvalidRequest , BAD_REQUEST ; +InvalidVectorDimensions , InvalidRequest , BAD_REQUEST ; InvalidDocumentId , InvalidRequest , BAD_REQUEST ; InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ; InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ; @@ -335,6 +336,7 @@ impl ErrorCode for milli::Error { UserError::InvalidSortableAttribute { .. } => Code::InvalidSearchSort, UserError::CriterionError(_) => Code::InvalidSettingsRankingRules, UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField, + UserError::InvalidVectorDimensions { .. 
} => Code::InvalidVectorDimensions, UserError::SortError(_) => Code::InvalidSearchSort, UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance diff --git a/milli/src/error.rs b/milli/src/error.rs index 8d55eabbd..a12334f90 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -110,9 +110,11 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco }, #[error(transparent)] InvalidGeoField(#[from] GeoError), + #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)] + InvalidVectorDimensions { expected: usize, found: usize }, #[error("{0}")] InvalidFilter(String), - #[error("Invalid type for filter subexpression: `expected {}, found: {1}`.", .0.join(", "))] + #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] InvalidFilterExpression(&'static [&'static str], Value), #[error("Attribute `{}` is not sortable. {}", .field, diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index e136dc139..0e2e85c1c 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -11,11 +11,13 @@ use heed::types::ByteSlice; use heed::RwTxn; use hnsw::Searcher; use roaring::RoaringBitmap; +use space::KnnPoints; use super::helpers::{ self, merge_ignore_values, serialize_roaring_bitmap, valid_lmdb_key, CursorClonableMmap, }; use super::{ClonableMmap, MergeFn}; +use crate::error::UserError; use crate::facet::FacetType; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::as_cloneable_grenad; @@ -228,12 +230,28 @@ pub(crate) fn write_typed_chunk_into_index( let mut hnsw = index.vector_hnsw(wtxn)?.unwrap_or_default(); let mut searcher = Searcher::new(); + let mut expected_dimensions = match index.vector_id_docid.iter(wtxn)?.next() { + Some(result) => { + let (vector_id, _) = result?; + Some(hnsw.get_point(vector_id.get() as usize).len()) + } + None => None, + }; + let mut cursor = vector_points.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? 
{ // convert the key back to a u32 (4 bytes) let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); // convert the vector back to a Vec let vector: Vec = pod_collect_to_vec(value); + + // TODO Move this error in the vector extractor + let found = vector.len(); + let expected = *expected_dimensions.get_or_insert(found); + if expected != found { + return Err(UserError::InvalidVectorDimensions { expected, found })?; + } + let vector_id = hnsw.insert(vector, &mut searcher) as u32; index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?; } From 717d4fddd4ca7900da06b9baeef62e78a40fde6d Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 14 Jun 2023 16:35:29 +0200 Subject: [PATCH 18/40] Remove the unused distance --- milli/src/distance.rs | 56 ------------------------------------------- 1 file changed, 56 deletions(-) diff --git a/milli/src/distance.rs b/milli/src/distance.rs index 22047eea3..c26a745a4 100644 --- a/milli/src/distance.rs +++ b/milli/src/distance.rs @@ -1,13 +1,6 @@ use serde::{Deserialize, Serialize}; use space::Metric; -#[cfg(any( - target_arch = "x86", - target_arch = "x86_64", - all(target_arch = "aarch64", target_feature = "neon") -))] -const MIN_DIM_SIZE_SIMD: usize = 16; - #[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)] pub struct DotProduct; @@ -33,58 +26,9 @@ impl Metric> for Euclidean { type Unit = u32; fn distance(&self, a: &Vec, b: &Vec) -> Self::Unit { - #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] - { - if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD { - let squared = unsafe { squared_euclid_neon(a, b) }; - let dist = squared.sqrt(); - debug_assert!(!dist.is_nan()); - return dist.to_bits(); - } - } - let squared: f32 = a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum(); let dist = squared.sqrt(); debug_assert!(!dist.is_nan()); dist.to_bits() } } - -#[cfg(target_feature = "neon")] -use std::arch::aarch64::*; - -#[cfg(target_feature = "neon")] -pub(crate) unsafe fn squared_euclid_neon(v1: &[f32], v2: &[f32]) -> f32 { - let n = v1.len(); - let m = n - (n % 16); - let mut ptr1: *const f32 = v1.as_ptr(); - let mut ptr2: *const f32 = v2.as_ptr(); - let mut sum1 = vdupq_n_f32(0.); - let mut sum2 = vdupq_n_f32(0.); - let mut sum3 = vdupq_n_f32(0.); - let mut sum4 = vdupq_n_f32(0.); - - let mut i: usize = 0; - while i < m { - let sub1 = vsubq_f32(vld1q_f32(ptr1), vld1q_f32(ptr2)); - sum1 = vfmaq_f32(sum1, sub1, sub1); - - let sub2 = vsubq_f32(vld1q_f32(ptr1.add(4)), vld1q_f32(ptr2.add(4))); - sum2 = vfmaq_f32(sum2, sub2, sub2); - - let sub3 = vsubq_f32(vld1q_f32(ptr1.add(8)), vld1q_f32(ptr2.add(8))); - sum3 = vfmaq_f32(sum3, sub3, sub3); - - let sub4 = vsubq_f32(vld1q_f32(ptr1.add(12)), vld1q_f32(ptr2.add(12))); - sum4 = vfmaq_f32(sum4, sub4, sub4); - - ptr1 = ptr1.add(16); - ptr2 = ptr2.add(16); - i += 16; - } - let mut result = vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3) + vaddvq_f32(sum4); - for i in 0..n - m { - result += (*ptr1.add(i) - *ptr2.add(i)).powi(2); - } - result -} From 1b2923f7c0c01094834475be174d8d5d57b0adc2 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 20 Jun 2023 10:14:25 +0200 Subject: [PATCH 19/40] Return the vector in the output of the search routes --- meilisearch/src/search.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index f1fb341a2..a85c0a437 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -233,6 +233,8 @@ pub struct SearchHit 
{ pub struct SearchResult { pub hits: Vec, pub query: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub vector: Option>, pub processing_time_ms: u128, #[serde(flatten)] pub hits_info: HitsInfo, @@ -515,7 +517,8 @@ pub fn perform_search( let result = SearchResult { hits: documents, hits_info, - query: query.q.clone().unwrap_or_default(), + query: query.q.unwrap_or_default(), + vector: query.vector, processing_time_ms: before_search.elapsed().as_millis(), facet_distribution, facet_stats, From 321ec5f3fa01107829b49e11ccefbb2ac2490bc0 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 20 Jun 2023 11:17:20 +0200 Subject: [PATCH 20/40] Accept multiple vectors by documents using the _vectors field --- .../extract/extract_vector_points.rs | 32 +++++++++++++------ .../src/update/index_documents/extract/mod.rs | 4 +-- milli/src/update/index_documents/mod.rs | 6 ++-- .../src/update/index_documents/typed_chunk.rs | 5 +-- 4 files changed, 31 insertions(+), 16 deletions(-) diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 409df5dbd..7e2bd25c5 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -1,20 +1,22 @@ +use std::convert::TryFrom; use std::fs::File; use std::io; use bytemuck::cast_slice; +use either::Either; use serde_json::from_slice; use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; use crate::{FieldId, InternalError, Result}; -/// Extracts the embedding vector contained in each document under the `_vector` field. +/// Extracts the embedding vector contained in each document under the `_vectors` field. /// /// Returns the generated grenad reader containing the docid as key associated to the Vec #[logging_timer::time] pub fn extract_vector_points( obkv_documents: grenad::Reader, indexer: GrenadParameters, - vector_fid: FieldId, + vectors_fid: FieldId, ) -> Result> { let mut writer = create_writer( indexer.chunk_compression_type, @@ -26,14 +28,26 @@ pub fn extract_vector_points( while let Some((docid_bytes, value)) = cursor.move_on_next()? 
{ let obkv = obkv::KvReader::new(value); - // first we get the _vector field - if let Some(vector) = obkv.get(vector_fid) { - // try to extract the vector - let vector: Vec = from_slice(vector).map_err(InternalError::SerdeJson).unwrap(); - let bytes = cast_slice(&vector); - writer.insert(docid_bytes, bytes)?; + // first we retrieve the _vectors field + if let Some(vectors) = obkv.get(vectors_fid) { + // extract the vectors + let vectors: Either>, Vec> = + from_slice(vectors).map_err(InternalError::SerdeJson).unwrap(); + let vectors = vectors.map_right(|v| vec![v]).into_inner(); + + for (i, vector) in vectors.into_iter().enumerate() { + match u16::try_from(i) { + Ok(i) => { + let mut key = docid_bytes.to_vec(); + key.extend_from_slice(&i.to_ne_bytes()); + let bytes = cast_slice(&vector); + writer.insert(key, bytes)?; + } + Err(_) => continue, + } + } } - // else => the _vector object was `null`, there is nothing to do + // else => the `_vectors` object was `null`, there is nothing to do } writer_into_reader(writer) diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index fdc6f5616..325d52279 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -47,7 +47,7 @@ pub(crate) fn data_from_obkv_documents( faceted_fields: HashSet, primary_key_id: FieldId, geo_fields_ids: Option<(FieldId, FieldId)>, - vector_field_id: Option, + vectors_field_id: Option, stop_words: Option>, max_positions_per_attributes: Option, exact_attributes: HashSet, @@ -72,7 +72,7 @@ pub(crate) fn data_from_obkv_documents( &faceted_fields, primary_key_id, geo_fields_ids, - vector_field_id, + vectors_field_id, &stop_words, max_positions_per_attributes, ) diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index adbab54db..5b6e03637 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -304,8 +304,8 @@ where } None => None, }; - // get the fid of the `_vector` field. - let vector_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vector"); + // get the fid of the `_vectors` field. + let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors"); let stop_words = self.index.stop_words(self.wtxn)?; let exact_attributes = self.index.exact_attributes_ids(self.wtxn)?; @@ -342,7 +342,7 @@ where faceted_fields, primary_key_id, geo_fields_ids, - vector_field_id, + vectors_field_id, stop_words, max_positions_per_attributes, exact_attributes, diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 0e2e85c1c..7d23ef320 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -20,7 +20,7 @@ use super::{ClonableMmap, MergeFn}; use crate::error::UserError; use crate::facet::FacetType; use crate::update::facet::FacetsUpdate; -use crate::update::index_documents::helpers::as_cloneable_grenad; +use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32}; pub(crate) enum TypedChunk { @@ -241,7 +241,8 @@ pub(crate) fn write_typed_chunk_into_index( let mut cursor = vector_points.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? 
{ // convert the key back to a u32 (4 bytes) - let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + let (left, _index) = try_split_array_at(key).unwrap(); + let docid = DocumentId::from_be_bytes(left); // convert the vector back to a Vec<f32> let vector: Vec<f32> = pod_collect_to_vec(value); From ab9f2269aa3c99373291682b58c73f0ad1970c7f Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 20 Jun 2023 11:45:29 +0200 Subject: [PATCH 21/40] Normalize the vectors during indexation and search --- milli/src/lib.rs | 12 ++++++++++++ milli/src/search/new/mod.rs | 6 ++++-- milli/src/update/index_documents/extract/mod.rs | 6 +++--- milli/src/update/index_documents/typed_chunk.rs | 2 ++ 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 021880a50..5bebdbda5 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -285,6 +285,18 @@ pub fn normalize_facet(original: &str) -> String { CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase() } +/// Normalize a vector by dividing each of its dimensions by its length. +pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> { + let squared: f32 = vector.iter().map(|x| x * x).sum(); + let length = squared.sqrt(); + if length <= f32::EPSILON { + vector + } else { + vector.iter_mut().for_each(|x| *x = *x / length); + vector + } +} + #[cfg(test)] mod tests { use serde_json::json; diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 246a89045..51917b772 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -49,7 +49,8 @@ use self::interner::Interned; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; use crate::{ - AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, BEU32, + normalize_vector, AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, + UserError, BEU32, }; /// A structure used throughout the execution of a search query. 
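// Why both sides get normalized: once every indexed vector and the query
// vector have unit length, their dot product equals the cosine of the angle
// between them, so the `DotProduct` metric's `1.0 - dot` distance behaves
// like a cosine distance. A small worked example of `normalize_vector`
// (the values come out exact here because the length is 5.0):
//
//     let unit = normalize_vector(vec![3.0, 4.0]);
//     assert_eq!(unit, vec![0.6, 0.8]);
//     let dot: f32 = unit.iter().zip(&unit).map(|(a, b)| a * b).sum();
//     assert!((dot - 1.0).abs() < 1e-6); // a vector scores ~1.0 against itself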
@@ -454,7 +455,8 @@ pub fn execute_search( let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default(); let ef = hnsw.len().min(100); let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef]; - let neighbors = hnsw.nearest(vector, ef, &mut searcher, &mut dest[..]); + let vector = normalize_vector(vector.clone()); + let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]); let mut docids = Vec::new(); for Neighbor { index, distance: _ } in neighbors.iter() { diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 325d52279..14f08b106 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -283,7 +283,7 @@ fn send_and_extract_flattened_documents_data( faceted_fields: &HashSet, primary_key_id: FieldId, geo_fields_ids: Option<(FieldId, FieldId)>, - vector_field_id: Option, + vectors_field_id: Option, stop_words: &Option>, max_positions_per_attributes: Option, ) -> Result<( @@ -312,11 +312,11 @@ fn send_and_extract_flattened_documents_data( }); } - if let Some(vector_field_id) = vector_field_id { + if let Some(vectors_field_id) = vectors_field_id { let documents_chunk_cloned = flattened_documents_chunk.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); rayon::spawn(move || { - let result = extract_vector_points(documents_chunk_cloned, indexer, vector_field_id); + let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id); let _ = match result { Ok(vector_points) => { lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points))) diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 7d23ef320..a63aacf83 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -19,6 +19,7 @@ use super::helpers::{ use super::{ClonableMmap, MergeFn}; use crate::error::UserError; use crate::facet::FacetType; +use crate::normalize_vector; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32}; @@ -253,6 +254,7 @@ pub(crate) fn write_typed_chunk_into_index( return Err(UserError::InvalidVectorDimensions { expected, found })?; } + let vector = normalize_vector(vector); let vector_id = hnsw.insert(vector, &mut searcher) as u32; index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?; } From 5c5a4e075d7ae43bcd2b143063dce6bec206b424 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 20 Jun 2023 13:48:39 +0200 Subject: [PATCH 22/40] Make clippy happy --- milli/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 5bebdbda5..04c81039a 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -292,7 +292,7 @@ pub fn normalize_vector(mut vector: Vec) -> Vec { if length <= f32::EPSILON { vector } else { - vector.iter_mut().for_each(|x| *x = *x / length); + vector.iter_mut().for_each(|x| *x /= length); vector } } From 3e3c74339231646bb72d93e9a02db10ec7c345f7 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 20 Jun 2023 13:48:57 +0200 Subject: [PATCH 23/40] Make Rustfmt happy --- milli/src/update/index_documents/typed_chunk.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/milli/src/update/index_documents/typed_chunk.rs 
b/milli/src/update/index_documents/typed_chunk.rs index a63aacf83..915bb2299 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -19,10 +19,12 @@ use super::helpers::{ use super::{ClonableMmap, MergeFn}; use crate::error::UserError; use crate::facet::FacetType; -use crate::normalize_vector; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; -use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, BEU32}; +use crate::{ + lat_lng_to_xyz, normalize_vector, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, + BEU32, +}; pub(crate) enum TypedChunk { FieldIdDocidFacetStrings(grenad::Reader), From 737aec17056ebe328d0edceab1cb1f6eceb4f447 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 20 Jun 2023 14:38:58 +0200 Subject: [PATCH 24/40] Expose an _semanticSimilarity as a dot product in the documents --- Cargo.lock | 1 + meilisearch/Cargo.toml | 1 + meilisearch/src/search.rs | 22 ++++++++++++++++++++++ milli/src/distance.rs | 18 ++++++++++++++---- milli/src/lib.rs | 1 + 5 files changed, 39 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 904d1c225..ccf79f9a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2595,6 +2595,7 @@ dependencies = [ "num_cpus", "obkv", "once_cell", + "ordered-float", "parking_lot", "permissive-json-pointer", "pin-project-lite", diff --git a/meilisearch/Cargo.toml b/meilisearch/Cargo.toml index 8fcd69591..d90dd24dd 100644 --- a/meilisearch/Cargo.toml +++ b/meilisearch/Cargo.toml @@ -48,6 +48,7 @@ mime = "0.3.17" num_cpus = "1.15.0" obkv = "0.2.0" once_cell = "1.17.1" +ordered-float = "3.7.0" parking_lot = "0.12.1" permissive-json-pointer = { path = "../permissive-json-pointer" } pin-project-lite = "0.2.9" diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index a85c0a437..c0d707657 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -10,6 +10,7 @@ use meilisearch_auth::IndexSearchRules; use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::milli::dot_product_similarity; use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; @@ -18,6 +19,7 @@ use milli::{ AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, }; +use ordered_float::OrderedFloat; use regex::Regex; use serde::Serialize; use serde_json::{json, Value}; @@ -457,6 +459,10 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } + if let Some(vector) = query.vector.as_ref() { + insert_semantic_similarity(&vector, &mut document); + } + let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); let ranking_score_details = @@ -542,6 +548,22 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) { } } +fn insert_semantic_similarity(query: &[f32], document: &mut Document) { + if let Some(value) = document.get("_vectors") { + let vectors: Vec> = match serde_json::from_value(value.clone()) { + Ok(Either::Left(vector)) => vec![vector], + Ok(Either::Right(vectors)) => vectors, + Err(_) => return, + }; + let similarity = vectors + .into_iter() + .map(|v| 
OrderedFloat(dot_product_similarity(query, &v))) + .max() + .map(OrderedFloat::into_inner); + document.insert("_semanticSimilarity".to_string(), json!(similarity)); + } +} + fn compute_formatted_options( attr_to_highlight: &HashSet<String>, attr_to_crop: &[String], diff --git a/milli/src/distance.rs b/milli/src/distance.rs index c26a745a4..1b91b4654 100644 --- a/milli/src/distance.rs +++ b/milli/src/distance.rs @@ -12,13 +12,18 @@ impl Metric<Vec<f32>> for DotProduct { // // Following . fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit { - let dist: f32 = a.iter().zip(b).map(|(a, b)| a * b).sum(); - let dist = 1.0 - dist; + let dist = 1.0 - dot_product_similarity(a, b); debug_assert!(!dist.is_nan()); dist.to_bits() } } +/// Returns the dot product similarity score that will be between 0.0 and 1.0 +/// if both vectors are normalized. The higher it is, the more similar the vectors are. +pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b).map(|(a, b)| a * b).sum() +} + #[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)] pub struct Euclidean; @@ -26,9 +31,14 @@ impl Metric<Vec<f32>> for Euclidean { type Unit = u32; fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit { - let squared: f32 = a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum(); - let dist = squared.sqrt(); + let dist = euclidean_squared_distance(a, b).sqrt(); debug_assert!(!dist.is_nan()); dist.to_bits() } } + +/// Returns the squared euclidean distance between both vectors, which will be +/// between 0.0 and +inf. The smaller it is, the nearer the vectors are. +pub fn euclidean_squared_distance(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum() +} diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 04c81039a..c93bf88ff 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -31,6 +31,7 @@ use std::convert::{TryFrom, TryInto}; use std::hash::BuildHasherDefault; use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; +pub use distance::{dot_product_similarity, euclidean_squared_distance}; pub use filter_parser::{Condition, FilterCondition, Span, Token}; use fxhash::{FxHasher32, FxHasher64}; pub use grenad::CompressionType; From 7aa12753370bbd11b3ee06459d6a1516dffd8564 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 20 Jun 2023 15:54:28 +0200 Subject: [PATCH 25/40] Display the _semanticSimilarity even if the `_vectors` field is not displayed --- meilisearch/src/search.rs | 45 ++++++++++++------- milli/src/lib.rs | 17 +++++++ .../extract/extract_vector_points.rs | 11 ++--- 3 files changed, 53 insertions(+), 20 deletions(-) diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index c0d707657..f7cfe99f9 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -17,7 +17,7 @@ use meilisearch_types::{milli, Document}; use milli::tokenizer::TokenizerBuilder; use milli::{ AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, - SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, + SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET, }; use ordered_float::OrderedFloat; use regex::Regex; @@ -432,7 +432,6 @@ pub fn perform_search( formatter_builder.highlight_suffix(query.highlight_post_tag); let mut documents = Vec::new(); - let documents_iter = index.documents(&rtxn, documents_ids)?; for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { @@ -460,7 +459,9 @@ pub fn perform_search( } if let Some(vector) = query.vector.as_ref() { - 
insert_semantic_similarity(&vector, &mut document); + if let Some(vectors) = extract_field("_vectors", &fields_ids_map, obkv)? { + insert_semantic_similarity(vector, vectors, &mut document); + } } let ranking_score = @@ -548,20 +549,18 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) { } } -fn insert_semantic_similarity(query: &[f32], document: &mut Document) { - if let Some(value) = document.get("_vectors") { - let vectors: Vec> = match serde_json::from_value(value.clone()) { - Ok(Either::Left(vector)) => vec![vector], - Ok(Either::Right(vectors)) => vectors, +fn insert_semantic_similarity(query: &[f32], vectors: Value, document: &mut Document) { + let vectors = + match serde_json::from_value(vectors).map(VectorOrArrayOfVectors::into_array_of_vectors) { + Ok(vectors) => vectors, Err(_) => return, }; - let similarity = vectors - .into_iter() - .map(|v| OrderedFloat(dot_product_similarity(query, &v))) - .max() - .map(OrderedFloat::into_inner); - document.insert("_semanticSimilarity".to_string(), json!(similarity)); - } + let similarity = vectors + .into_iter() + .map(|v| OrderedFloat(dot_product_similarity(query, &v))) + .max() + .map(OrderedFloat::into_inner); + document.insert("_semanticSimilarity".to_string(), json!(similarity)); } fn compute_formatted_options( @@ -691,6 +690,22 @@ fn make_document( Ok(document) } +/// Extract the JSON value under the field name specified +/// but doesn't support nested objects. +fn extract_field( + field_name: &str, + field_ids_map: &FieldsIdsMap, + obkv: obkv::KvReaderU16, +) -> Result, MeilisearchHttpError> { + match field_ids_map.id(field_name) { + Some(fid) => match obkv.get(fid) { + Some(value) => Ok(serde_json::from_slice(value).map(Some)?), + None => Ok(None), + }, + None => Ok(None), + } +} + fn format_fields>( document: &Document, field_ids_map: &FieldsIdsMap, diff --git a/milli/src/lib.rs b/milli/src/lib.rs index c93bf88ff..63cf6f397 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -286,6 +286,23 @@ pub fn normalize_facet(original: &str) -> String { CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase() } +/// Represents either a vector or an array of multiple vectors. +#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[serde(transparent)] +pub struct VectorOrArrayOfVectors { + #[serde(with = "either::serde_untagged")] + inner: either::Either, Vec>>, +} + +impl VectorOrArrayOfVectors { + pub fn into_array_of_vectors(self) -> Vec> { + match self.inner { + either::Either::Left(vector) => vec![vector], + either::Either::Right(vectors) => vectors, + } + } +} + /// Normalize a vector by dividing the dimensions by the lenght of it. pub fn normalize_vector(mut vector: Vec) -> Vec { let squared: f32 = vector.iter().map(|x| x * x).sum(); diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 7e2bd25c5..c2a08b320 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -3,11 +3,10 @@ use std::fs::File; use std::io; use bytemuck::cast_slice; -use either::Either; use serde_json::from_slice; use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; -use crate::{FieldId, InternalError, Result}; +use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors}; /// Extracts the embedding vector contained in each document under the `_vectors` field. 
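// For reference, `VectorOrArrayOfVectors` (defined in the lib.rs hunk above)
// is what lets `_vectors` hold either a single embedding or several: the
// untagged `Either` tries both JSON shapes. A sketch of the two accepted
// forms and how `into_array_of_vectors` flattens them:
//
//     let one: VectorOrArrayOfVectors = serde_json::from_str("[1.0, 2.0]").unwrap();
//     assert_eq!(one.into_array_of_vectors(), vec![vec![1.0, 2.0]]);
//     let many: VectorOrArrayOfVectors = serde_json::from_str("[[1.0], [2.0]]").unwrap();
//     assert_eq!(many.into_array_of_vectors(), vec![vec![1.0], vec![2.0]]);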
/// @@ -31,9 +30,11 @@ pub fn extract_vector_points( // first we retrieve the _vectors field if let Some(vectors) = obkv.get(vectors_fid) { // extract the vectors - let vectors: Either>, Vec> = - from_slice(vectors).map_err(InternalError::SerdeJson).unwrap(); - let vectors = vectors.map_right(|v| vec![v]).into_inner(); + // TODO return a user error before unwrapping + let vectors = from_slice(vectors) + .map_err(InternalError::SerdeJson) + .map(VectorOrArrayOfVectors::into_array_of_vectors) + .unwrap(); for (i, vector) in vectors.into_iter().enumerate() { match u16::try_from(i) { From 531748c536f9f1fc2c1239f7bae62a720c14cda0 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 20 Jun 2023 16:18:24 +0200 Subject: [PATCH 26/40] Return a user error when the _vectors type is invalid --- meilisearch-types/src/error.rs | 2 ++ milli/src/error.rs | 2 ++ .../extract/extract_vector_points.rs | 23 ++++++++++++++----- .../src/update/index_documents/extract/mod.rs | 7 +++++- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 886a0fe30..ff9b06d85 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -218,6 +218,7 @@ MissingDocumentFilter , InvalidRequest , BAD_REQUEST ; InvalidDocumentFilter , InvalidRequest , BAD_REQUEST ; InvalidDocumentGeoField , InvalidRequest , BAD_REQUEST ; InvalidVectorDimensions , InvalidRequest , BAD_REQUEST ; +InvalidVectorsType , InvalidRequest , BAD_REQUEST ; InvalidDocumentId , InvalidRequest , BAD_REQUEST ; InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ; InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ; @@ -337,6 +338,7 @@ impl ErrorCode for milli::Error { UserError::CriterionError(_) => Code::InvalidSettingsRankingRules, UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField, UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions, + UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType, UserError::SortError(_) => Code::InvalidSearchSort, UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance diff --git a/milli/src/error.rs b/milli/src/error.rs index a12334f90..3df599b61 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -112,6 +112,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco InvalidGeoField(#[from] GeoError), #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)] InvalidVectorDimensions { expected: usize, found: usize }, + #[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. 
Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] + InvalidVectorsType { document_id: Value, value: Value }, #[error("{0}")] InvalidFilter(String), #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index c2a08b320..e78dbc080 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -3,9 +3,10 @@ use std::fs::File; use std::io; use bytemuck::cast_slice; -use serde_json::from_slice; +use serde_json::{from_slice, Value}; use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; +use crate::error::UserError; use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors}; /// Extracts the embedding vector contained in each document under the `_vectors` field. @@ -15,6 +16,7 @@ use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors}; pub fn extract_vector_points( obkv_documents: grenad::Reader, indexer: GrenadParameters, + primary_key_id: FieldId, vectors_fid: FieldId, ) -> Result<grenad::Reader<File>> { let mut writer = create_writer( @@ -27,14 +29,23 @@ pub fn extract_vector_points( while let Some((docid_bytes, value)) = cursor.move_on_next()? { let obkv = obkv::KvReader::new(value); + // since we only need the primary key when we throw an error, we create this getter to + // lazily get it when needed + let document_id = || -> Value { + let document_id = obkv.get(primary_key_id).unwrap(); + serde_json::from_slice(document_id).unwrap() + }; + // first we retrieve the _vectors field if let Some(vectors) = obkv.get(vectors_fid) { // extract the vectors let vectors = match from_slice(vectors) { Ok(vectors) => VectorOrArrayOfVectors::into_array_of_vectors(vectors), Err(_) => { return Err(UserError::InvalidVectorsType { document_id: document_id(), value: from_slice(vectors).map_err(InternalError::SerdeJson)?, }.into()), }; for (i, vector) in vectors.into_iter().enumerate() { match u16::try_from(i) { diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 14f08b106..6259c7272 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -316,7 +316,12 @@ fn send_and_extract_flattened_documents_data( let documents_chunk_cloned = flattened_documents_chunk.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); rayon::spawn(move || { - let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id); + let result = extract_vector_points( + documents_chunk_cloned, + indexer, + primary_key_id, + vectors_field_id, + ); let _ = match result { Ok(vector_points) => { lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points))) Err(error) => lmdb_writer_sx_cloned.send(Err(error)), }; }); } From ff3664431f8d16310614f946c1ce5b63a2ab56dc Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Tue, 20 Jun 2023 16:26:00 +0200 Subject: [PATCH 27/40] Make rustfmt happy --- .../index_documents/extract/extract_vector_points.rs | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs 
b/milli/src/update/index_documents/extract/extract_vector_points.rs index e78dbc080..ddf25917c 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -41,10 +41,13 @@ pub fn extract_vector_points( // extract the vectors let vectors = match from_slice(vectors) { Ok(vectors) => VectorOrArrayOfVectors::into_array_of_vectors(vectors), - Err(_) => return Err(UserError::InvalidVectorsType { - document_id: document_id(), - value: from_slice(vectors).map_err(InternalError::SerdeJson)?, - }.into()), + Err(_) => { + return Err(UserError::InvalidVectorsType { + document_id: document_id(), + value: from_slice(vectors).map_err(InternalError::SerdeJson)?, + } + .into()) + } }; for (i, vector) in vectors.into_iter().enumerate() { From 66b8cfd8c83038980abd74e851969cda08e07fed Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Wed, 21 Jun 2023 17:07:02 +0200 Subject: [PATCH 28/40] Introduce a way to store the HNSW on multiple LMDB entries --- milli/src/index.rs | 46 +++++++++++++++++--- milli/src/lib.rs | 1 + milli/src/readable_slices.rs | 84 ++++++++++++++++++++++++++++++++++++ 3 files changed, 125 insertions(+), 6 deletions(-) create mode 100644 milli/src/readable_slices.rs diff --git a/milli/src/index.rs b/milli/src/index.rs index dcfcc0730..8343515cf 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -22,6 +22,7 @@ use crate::heed_codec::facet::{ FieldIdCodec, OrderedF64Codec, }; use crate::heed_codec::{ScriptLanguageCodec, StrBEU16Codec, StrRefCodec}; +use crate::readable_slices::ReadableSlices; use crate::{ default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec, @@ -47,7 +48,10 @@ pub mod main_key { pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map"; pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids"; pub const GEO_RTREE_KEY: &str = "geo-rtree"; - pub const VECTOR_HNSW_KEY: &str = "vector-hnsw"; + /// The prefix of the key that is used to store the, potential big, HNSW structure. + /// It is concatenated with a big-endian encoded number (non-human readable). + /// e.g. vector-hnsw0x0032. + pub const VECTOR_HNSW_KEY_PREFIX: &str = "vector-hnsw"; pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids"; pub const NUMBER_FACETED_DOCUMENTS_IDS_PREFIX: &str = "number-faceted-documents-ids"; pub const PRIMARY_KEY_KEY: &str = "primary-key"; @@ -517,19 +521,49 @@ impl Index { /// Writes the provided `hnsw`. pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> { - self.main.put::<_, Str, SerdeBincode>(wtxn, main_key::VECTOR_HNSW_KEY, hnsw) + // We must delete all the chunks before we write the new HNSW chunks. + self.delete_vector_hnsw(wtxn)?; + + let chunk_size = 1024 * 1024 * (1024 + 512); // 1.5 GiB + let bytes = bincode::serialize(hnsw).map_err(|_| heed::Error::Encoding)?; + for (i, chunk) in bytes.chunks(chunk_size).enumerate() { + let i = i as u32; + let mut key = main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes().to_vec(); + key.extend_from_slice(&i.to_be_bytes()); + self.main.put::<_, ByteSlice, ByteSlice>(wtxn, &key, chunk)?; + } + Ok(()) } /// Delete the `hnsw`. 
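// A note on the chunk keys written by `put_vector_hnsw` above: each 1.5 GiB
// slice of the bincode-serialized graph lives under the shared prefix plus a
// big-endian chunk number, so a plain LMDB prefix iteration yields the chunks
// back in writing order. Sketch of how the first key is laid out:
//
//     let mut key = main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes().to_vec();
//     key.extend_from_slice(&0u32.to_be_bytes()); // b"vector-hnsw\x00\x00\x00\x00"
//     // Big-endian numbering keeps lexicographic order equal to chunk order,
//     // which is what `vector_hnsw` relies on when it re-assembles the chunks
//     // into one reader with `ReadableSlices`.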
pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result { - self.main.delete::<_, Str>(wtxn, main_key::VECTOR_HNSW_KEY) + let mut iter = self.main.prefix_iter_mut::<_, ByteSlice, DecodeIgnore>( + wtxn, + main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes(), + )?; + let mut deleted = false; + while let Some(_) = iter.next().transpose()? { + // We do not keep a reference to the key or the value. + unsafe { deleted |= iter.del_current()? }; + } + Ok(deleted) } /// Returns the `hnsw`. pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result> { - match self.main.get::<_, Str, SerdeBincode>(rtxn, main_key::VECTOR_HNSW_KEY)? { - Some(hnsw) => Ok(Some(hnsw)), - None => Ok(None), + let mut slices = Vec::new(); + for result in + self.main.prefix_iter::<_, Str, ByteSlice>(rtxn, main_key::VECTOR_HNSW_KEY_PREFIX)? + { + let (_, slice) = result?; + slices.push(slice); + } + + if slices.is_empty() { + Ok(None) + } else { + let readable_slices: ReadableSlices<_> = slices.into_iter().collect(); + Ok(Some(bincode::deserialize_from(readable_slices).map_err(|_| heed::Error::Decoding)?)) } } diff --git a/milli/src/lib.rs b/milli/src/lib.rs index 63cf6f397..626c30ab0 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -18,6 +18,7 @@ mod fields_ids_map; pub mod heed_codec; pub mod index; pub mod proximity; +mod readable_slices; pub mod score_details; mod search; pub mod update; diff --git a/milli/src/readable_slices.rs b/milli/src/readable_slices.rs new file mode 100644 index 000000000..9ba6c1ba1 --- /dev/null +++ b/milli/src/readable_slices.rs @@ -0,0 +1,84 @@ +use std::io::{self, Read}; +use std::iter::FromIterator; + +pub struct ReadableSlices { + inner: Vec, + pos: u64, +} + +impl FromIterator for ReadableSlices { + fn from_iter>(iter: T) -> Self { + ReadableSlices { inner: iter.into_iter().collect(), pos: 0 } + } +} + +impl> Read for ReadableSlices { + fn read(&mut self, mut buf: &mut [u8]) -> io::Result { + let original_buf_len = buf.len(); + + // We explore the list of slices to find the one where we must start reading. + let mut pos = self.pos; + let index = match self + .inner + .iter() + .map(|s| s.as_ref().len() as u64) + .position(|size| pos.checked_sub(size).map(|p| pos = p).is_none()) + { + Some(index) => index, + None => return Ok(0), + }; + + let mut inner_pos = pos as usize; + for slice in &self.inner[index..] { + let slice = &slice.as_ref()[inner_pos..]; + + if buf.len() > slice.len() { + // We must exhaust the current slice and go to the next one there is not enough here. + buf[..slice.len()].copy_from_slice(slice); + buf = &mut buf[slice.len()..]; + inner_pos = 0; + } else { + // There is enough in this slice to fill the remaining bytes of the buffer. + // Let's break just after filling it. 
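+                // (Copying exactly `buf.len()` bytes here means `written`
+                // below equals the caller's full request, keeping `self.pos`
+                // consistent across successive reads.)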
+        for result in
+            self.main.prefix_iter::<_, Str, ByteSlice>(rtxn, main_key::VECTOR_HNSW_KEY_PREFIX)?
+        {
+            let (_, slice) = result?;
+            slices.push(slice);
+        }
+
+        if slices.is_empty() {
+            Ok(None)
+        } else {
+            let readable_slices: ReadableSlices<_> = slices.into_iter().collect();
+            Ok(Some(bincode::deserialize_from(readable_slices).map_err(|_| heed::Error::Decoding)?))
         }
     }
 
diff --git a/milli/src/lib.rs b/milli/src/lib.rs
index 63cf6f397..626c30ab0 100644
--- a/milli/src/lib.rs
+++ b/milli/src/lib.rs
@@ -18,6 +18,7 @@ mod fields_ids_map;
 pub mod heed_codec;
 pub mod index;
 pub mod proximity;
+mod readable_slices;
 pub mod score_details;
 mod search;
 pub mod update;

diff --git a/milli/src/readable_slices.rs b/milli/src/readable_slices.rs
new file mode 100644
index 000000000..9ba6c1ba1
--- /dev/null
+++ b/milli/src/readable_slices.rs
@@ -0,0 +1,84 @@
+use std::io::{self, Read};
+use std::iter::FromIterator;
+
+pub struct ReadableSlices<T> {
+    inner: Vec<T>,
+    pos: u64,
+}
+
+impl<T> FromIterator<T> for ReadableSlices<T> {
+    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
+        ReadableSlices { inner: iter.into_iter().collect(), pos: 0 }
+    }
+}
+
+impl<T: AsRef<[u8]>> Read for ReadableSlices<T> {
+    fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
+        let original_buf_len = buf.len();
+
+        // We explore the list of slices to find the one where we must start reading.
+        let mut pos = self.pos;
+        let index = match self
+            .inner
+            .iter()
+            .map(|s| s.as_ref().len() as u64)
+            .position(|size| pos.checked_sub(size).map(|p| pos = p).is_none())
+        {
+            Some(index) => index,
+            None => return Ok(0),
+        };
+
+        let mut inner_pos = pos as usize;
+        for slice in &self.inner[index..] {
+            let slice = &slice.as_ref()[inner_pos..];
+
+            if buf.len() > slice.len() {
+                // We must exhaust the current slice and go to the next one
+                // because there is not enough here.
+                buf[..slice.len()].copy_from_slice(slice);
+                buf = &mut buf[slice.len()..];
+                inner_pos = 0;
+            } else {
+                // There is enough in this slice to fill the remaining bytes of the buffer.
+                // Let's break just after filling it.
+                buf.copy_from_slice(&slice[..buf.len()]);
+                buf = &mut [];
+                break;
+            }
+        }
+
+        let written = original_buf_len - buf.len();
+        self.pos += written as u64;
+        Ok(written)
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use super::ReadableSlices;
+    use std::io::Read;
+
+    #[test]
+    fn basic() {
+        let data: Vec<_> = (0..100).collect();
+        let splits: Vec<_> = data.chunks(3).collect();
+        let mut rdslices: ReadableSlices<_> = splits.into_iter().collect();
+
+        let mut output = Vec::new();
+        let length = rdslices.read_to_end(&mut output).unwrap();
+        assert_eq!(length, data.len());
+        assert_eq!(output, data);
+    }
+
+    #[test]
+    fn small_reads() {
+        let data: Vec<_> = (0..u8::MAX).collect();
+        let splits: Vec<_> = data.chunks(27).collect();
+        let mut rdslices: ReadableSlices<_> = splits.into_iter().collect();
+
+        let buffer = &mut [0; 45];
+        let length = rdslices.read(buffer).unwrap();
+        let expected: Vec<_> = (0..buffer.len() as u8).collect();
+        assert_eq!(length, buffer.len());
+        assert_eq!(buffer, &expected[..]);
+    }
+}
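The write and read paths of this patch form a simple round-trip contract: splitting the bincode bytes into chunks and feeding the chunks back through `ReadableSlices` must reproduce the original value. A minimal test-style sketch of that contract (the test name is hypothetical, and it assumes `ReadableSlices` is reachable from the test; a plain `Vec<String>` stands in for the `Hnsw`):

    #[test]
    fn chunked_round_trip() {
        let value: Vec<String> = (0..1_000).map(|i| i.to_string()).collect();
        let bytes = bincode::serialize(&value).unwrap();
        // Split into fixed-size chunks, as `put_vector_hnsw` does with LMDB entries.
        let chunks: Vec<&[u8]> = bytes.chunks(64).collect();
        // Read them back as one contiguous stream, as `vector_hnsw` does.
        let reader: ReadableSlices<_> = chunks.into_iter().collect();
        let decoded: Vec<String> = bincode::deserialize_from(reader).unwrap();
        assert_eq!(decoded, value);
    }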
From 7c2f5f77b8d4ea07bd97087bb6d842321019aa38 Mon Sep 17 00:00:00 2001
From: Kerollmops
Date: Wed, 21 Jun 2023 17:10:19 +0200
Subject: [PATCH 29/40] Make clippy and fmt happy

---
 milli/src/index.rs           | 2 +-
 milli/src/readable_slices.rs | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/milli/src/index.rs b/milli/src/index.rs
index 8343515cf..a22901993 100644
--- a/milli/src/index.rs
+++ b/milli/src/index.rs
@@ -542,7 +542,7 @@ impl Index {
             main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes(),
         )?;
         let mut deleted = false;
-        while let Some(_) = iter.next().transpose()? {
+        while iter.next().transpose()?.is_some() {
             // We do not keep a reference to the key or the value.
             unsafe { deleted |= iter.del_current()? };
         }

diff --git a/milli/src/readable_slices.rs b/milli/src/readable_slices.rs
index 9ba6c1ba1..7f5be214f 100644
--- a/milli/src/readable_slices.rs
+++ b/milli/src/readable_slices.rs
@@ -54,9 +54,10 @@ impl<T: AsRef<[u8]>> Read for ReadableSlices<T> {
 
 #[cfg(test)]
 mod test {
-    use super::ReadableSlices;
     use std::io::Read;
 
+    use super::ReadableSlices;
+
     #[test]
     fn basic() {
         let data: Vec<_> = (0..100).collect();

From 66fb5c150cbc10addfd6e3f1623487cf86522989 Mon Sep 17 00:00:00 2001
From: Kerollmops
Date: Mon, 26 Jun 2023 10:59:03 +0200
Subject: [PATCH 30/40] Rename _semanticSimilarity into _semanticScore

---
 meilisearch/src/search.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs
index f7cfe99f9..fe4b10e59 100644
--- a/meilisearch/src/search.rs
+++ b/meilisearch/src/search.rs
@@ -460,7 +460,7 @@ pub fn perform_search(
 
         if let Some(vector) = query.vector.as_ref() {
             if let Some(vectors) = extract_field("_vectors", &fields_ids_map, obkv)? {
-                insert_semantic_similarity(vector, vectors, &mut document);
+                insert_semantic_score(vector, vectors, &mut document);
             }
         }
 
@@ -549,7 +549,7 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) {
     }
 }
 
-fn insert_semantic_similarity(query: &[f32], vectors: Value, document: &mut Document) {
+fn insert_semantic_score(query: &[f32], vectors: Value, document: &mut Document) {
     let vectors =
         match serde_json::from_value(vectors).map(VectorOrArrayOfVectors::into_array_of_vectors) {
             Ok(vectors) => vectors,
@@ -560,7 +560,7 @@ fn insert_semantic_similarity(query: &[f32], vectors: Value, document: &mut Document) {
         .map(|v| OrderedFloat(dot_product_similarity(query, &v)))
         .max()
         .map(OrderedFloat::into_inner);
-    document.insert("_semanticSimilarity".to_string(), json!(similarity));
+    document.insert("_semanticScore".to_string(), json!(similarity));
 }
 
 fn compute_formatted_options(
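The score reported here is a plain dot product between the query vector and each `_vectors` entry; when both sides are normalized to unit length, that dot product is exactly the cosine of the angle between them. A small sketch of that relationship (not part of the patch; the local `normalize` helper mirrors milli's `normalize_vector`):

    // With both inputs normalized to unit length, dot product == cosine similarity.
    fn normalize(mut v: Vec<f32>) -> Vec<f32> {
        let length = v.iter().map(|x| x * x).sum::<f32>().sqrt();
        v.iter_mut().for_each(|x| *x /= length);
        v
    }

    fn main() {
        let a = normalize(vec![3.0, 4.0]);
        let b = normalize(vec![4.0, 3.0]);
        let dot: f32 = a.iter().zip(&b).map(|(x, y)| x * y).sum();
        assert!((dot - 24.0 / 25.0).abs() < 1e-6); // cos of the angle between (3,4) and (4,3)
    }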
From 864ad2a23c891496a512fed195ee4f4d10dd7761 Mon Sep 17 00:00:00 2001
From: Louis Dureuil
Date: Mon, 26 Jun 2023 16:36:01 +0200
Subject: [PATCH 31/40] Check that vector store feature is enabled

---
 meilisearch/src/search.rs | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs
index fe4b10e59..8f2d48650 100644
--- a/meilisearch/src/search.rs
+++ b/meilisearch/src/search.rs
@@ -326,6 +326,10 @@ pub fn perform_search(
         features.check_score_details()?;
     }
 
+    if query.vector.is_some() {
+        features.check_vector()?;
+    }
+
     // compute the offset on the limit depending on the pagination mode.
     let (offset, limit) = if is_finite_pagination {
         let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);

From 816d7ed174a6b1743a6f7c0ce0f18b898ce240f8 Mon Sep 17 00:00:00 2001
From: Kerollmops
Date: Mon, 26 Jun 2023 17:28:07 +0200
Subject: [PATCH 32/40] Update the Vector Store product feature link

---
 index-scheduler/src/features.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/index-scheduler/src/features.rs b/index-scheduler/src/features.rs
index 7f4a30741..5a663fe67 100644
--- a/index-scheduler/src/features.rs
+++ b/index-scheduler/src/features.rs
@@ -62,7 +62,7 @@ impl RoFeatures {
             Err(FeatureNotEnabledError {
                 disabled_action: "Passing `vector` as a query parameter",
                 feature: "vector store",
-                issue_link: "https://github.com/meilisearch/meilisearch/discussions/TODO",
+                issue_link: "https://github.com/meilisearch/product/discussions/677",
             }
             .into())
         }

From eecf20f1095c039679cecfd89bf7023d97845a39 Mon Sep 17 00:00:00 2001
From: Kerollmops
Date: Mon, 26 Jun 2023 18:07:56 +0200
Subject: [PATCH 33/40] Introduce a new invalid_vector_store

---
 meilisearch-types/src/error.rs           | 1 +
 meilisearch/src/routes/indexes/search.rs | 2 +-
 meilisearch/src/search.rs                | 2 +-
 3 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs
index ff9b06d85..3880fac4b 100644
--- a/meilisearch-types/src/error.rs
+++ b/meilisearch-types/src/error.rs
@@ -241,6 +241,7 @@ InvalidSearchMatchingStrategy        , InvalidRequest , BAD_REQUEST ;
 InvalidSearchOffset                  , InvalidRequest , BAD_REQUEST ;
 InvalidSearchPage                    , InvalidRequest , BAD_REQUEST ;
 InvalidSearchQ                       , InvalidRequest , BAD_REQUEST ;
+InvalidSearchVector                  , InvalidRequest , BAD_REQUEST ;
 InvalidSearchShowMatchesPosition     , InvalidRequest , BAD_REQUEST ;
 InvalidSearchShowRankingScore        , InvalidRequest , BAD_REQUEST ;
 InvalidSearchShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ;

diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs
index fae24dba2..0c45f08c7 100644
--- a/meilisearch/src/routes/indexes/search.rs
+++ b/meilisearch/src/routes/indexes/search.rs
@@ -34,7 +34,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
 pub struct SearchQueryGet {
     #[deserr(default, error = DeserrQueryParamError<InvalidSearchQ>)]
     q: Option<String>,
-    #[deserr(default, error = DeserrQueryParamError<InvalidSearchQ>)]
+    #[deserr(default, error = DeserrQueryParamError<InvalidSearchVector>)]
     vector: Option<Vec<f32>>,
     #[deserr(default = Param(DEFAULT_SEARCH_OFFSET()), error = DeserrQueryParamError<InvalidSearchOffset>)]
     offset: Param<usize>,

diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs
index 8f2d48650..a8c6765bc 100644
--- a/meilisearch/src/search.rs
+++ b/meilisearch/src/search.rs
@@ -40,7 +40,7 @@ pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
 pub struct SearchQuery {
     #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
     pub q: Option<String>,
-    #[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
+    #[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
     pub vector: Option<Vec<f32>>,
     #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
     pub offset: usize,

From f3e4d706388f7c5a78cca19f9f5acf06785ae496 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Tue, 27 Jun 2023 10:19:30 +0200
Subject: [PATCH 34/40] Send analytics about the query vector length

---
 meilisearch/src/analytics/segment_analytics.rs | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs
index 55e4905bd..9a96c4650 100644
--- a/meilisearch/src/analytics/segment_analytics.rs
+++ b/meilisearch/src/analytics/segment_analytics.rs
@@ -548,6 +548,10 @@ pub struct SearchAggregator {
     // The maximum number of terms in a q request
     max_terms_number: usize,
 
+    // vector
+    // The maximum number of floats in a vector request
+    max_vector_size: usize,
+
     // every time a search is done, we increment the counter linked to the used settings
     matching_strategy: HashMap<String, usize>,
 
@@ -617,6 +621,10 @@ impl SearchAggregator {
             ret.max_terms_number = q.split_whitespace().count();
         }
 
+        if let Some(ref vector) = query.vector {
+            ret.max_vector_size = vector.len();
+        }
+
         if query.is_finite_pagination() {
             let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
             ret.max_limit = limit;

From 63bfe1cee22fcd5c620e01edad3b75a8e0e97511 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Tue, 27 Jun 2023 12:28:32 +0200
Subject: [PATCH 35/40] Ignore when there are too many vectors

---
 .../extract/extract_vector_points.rs | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs
index ddf25917c..0fad3be07 100644
--- a/milli/src/update/index_documents/extract/extract_vector_points.rs
+++ b/milli/src/update/index_documents/extract/extract_vector_points.rs
@@ -50,16 +50,12 @@ pub fn extract_vector_points(
             }
         };
 
-        for (i, vector) in vectors.into_iter().enumerate() {
-            match u16::try_from(i) {
-                Ok(i) => {
-                    let mut key = docid_bytes.to_vec();
-                    key.extend_from_slice(&i.to_ne_bytes());
-                    let bytes = cast_slice(&vector);
-                    writer.insert(key, bytes)?;
-                }
-                Err(_) => continue,
-            }
+        for (i, vector) in vectors.into_iter().enumerate().take(u16::MAX as usize) {
+            let index = u16::try_from(i).unwrap();
+            let mut key = docid_bytes.to_vec();
+            key.extend_from_slice(&index.to_be_bytes());
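+            // `to_be_bytes` (rather than `to_ne_bytes`) keeps the per-document
+            // vector keys sorted by index and independent of the host endianness.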
+            let bytes = cast_slice(&vector);
+            writer.insert(key, bytes)?;
         }
     }
     // else => the `_vectors` object was `null`, there is nothing to do

From 29d8268c94e05ca19820911175265e4298150d59 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Tue, 27 Jun 2023 12:29:11 +0200
Subject: [PATCH 36/40] Fix the vector query part by using the correct universe

---
 milli/src/search/new/mod.rs | 63 +++++++++++++++++++------------------
 1 file changed, 33 insertions(+), 30 deletions(-)

diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs
index 51917b772..8bdcf077b 100644
--- a/milli/src/search/new/mod.rs
+++ b/milli/src/search/new/mod.rs
@@ -376,8 +376,40 @@ pub fn execute_search(
 
     check_sort_criteria(ctx, sort_criteria.as_ref())?;
 
-    let mut located_query_terms = None;
+    if let Some(vector) = vector {
+        let mut searcher = Searcher::new();
+        let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default();
+        let ef = hnsw.len().min(100);
+        let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef];
+        let vector = normalize_vector(vector.clone());
+        let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]);
+
+        let mut docids = Vec::new();
+        let mut uniq_docids = RoaringBitmap::new();
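+        // A document may expose several `_vectors` entries, so several HNSW
+        // points can map back to the same docid; the bitmap dedups them.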
+        for Neighbor { index, distance: _ } in neighbors.iter() {
+            let index = BEU32::new(*index as u32);
+            let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get();
+            if universe.contains(docid) && uniq_docids.insert(docid) {
+                docids.push(docid);
+                if docids.len() == (from + length) {
+                    break;
+                }
+            }
+        }
+
+        // return the nearest documents that are also part of the candidates
+        // along with a dummy list of scores that are useless in this context.
+        let docids: Vec<_> = docids.into_iter().skip(from).take(length).collect();
+
+        return Ok(PartialSearchResult {
+            candidates: universe,
+            document_scores: vec![Vec::new(); docids.len()],
+            documents_ids: docids,
+            located_query_terms: None,
+        });
+    }
+
+    let mut located_query_terms = None;
     let query_terms = if let Some(query) = query {
         // We make sure that the analyzer is aware of the stop words
         // this ensures that the query builder is able to properly remove them.
@@ -445,37 +477,8 @@ pub fn execute_search(
     };
 
     let BucketSortOutput { docids, scores, mut all_candidates } = bucket_sort_output;
-    let fields_ids_map = ctx.index.fields_ids_map(ctx.txn)?;
-
-    let docids = match vector {
-        Some(vector) => {
-            // return the nearest documents that are also part of the candidates.
-            let mut searcher = Searcher::new();
-            let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default();
-            let ef = hnsw.len().min(100);
-            let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef];
-            let vector = normalize_vector(vector.clone());
-            let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]);
-
-            let mut docids = Vec::new();
-            for Neighbor { index, distance: _ } in neighbors.iter() {
-                let index = BEU32::new(*index as u32);
-                let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get();
-                if universe.contains(docid) {
-                    docids.push(docid);
-                    if docids.len() == (from + length) {
-                        break;
-                    }
-                }
-            }
-
-            docids.into_iter().skip(from).take(length).collect()
-        }
-        // return the search docids if the vector field is not specified
-        None => docids,
-    };
-
     // The candidates is the universe unless the exhaustive number of hits
     // is requested and a distinct attribute is set.
     if exhaustive_number_hits {

From ebad1f396f898ab245b16a96f83c404e604b5d09 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Tue, 27 Jun 2023 12:29:40 +0200
Subject: [PATCH 37/40] Remove the useless euclidean distance implementation

---
 milli/src/distance.rs | 25 +++----------------------
 milli/src/lib.rs      |  4 ++--
 2 files changed, 5 insertions(+), 24 deletions(-)

diff --git a/milli/src/distance.rs b/milli/src/distance.rs
index 1b91b4654..c838e4bd4 100644
--- a/milli/src/distance.rs
+++ b/milli/src/distance.rs
@@ -7,10 +7,10 @@ pub struct DotProduct;
 impl Metric<Vec<f32>> for DotProduct {
     type Unit = u32;
 
-    // TODO explain me this function, I don't understand why f32.to_bits is ordered.
-    // I tried to do this and it wasn't OK
-    //
     // Following .
+    //
+    // Here is a playground that validates the ordering of the bit representation
+    // of floats in range 0.0..=1.0:
+    //
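+    // In short: a non-negative finite f32 stores sign bit 0, then the exponent,
+    // then the mantissa, so its bit pattern, read as a u32, grows monotonically
+    // with the value, which makes `to_bits()` safe to use as an ordered `Unit`.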
     fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
         let dist = 1.0 - dot_product_similarity(a, b);
         debug_assert!(!dist.is_nan());
         dist.to_bits()
@@ -23,22 +23,3 @@ impl Metric<Vec<f32>> for DotProduct {
 pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
     a.iter().zip(b).map(|(a, b)| a * b).sum()
 }
-
-#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
-pub struct Euclidean;
-
-impl Metric<Vec<f32>> for Euclidean {
-    type Unit = u32;
-
-    fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
-        let dist = euclidean_squared_distance(a, b).sqrt();
-        debug_assert!(!dist.is_nan());
-        dist.to_bits()
-    }
-}
-
-/// Return the squared euclidean distance between both vectors that will
-/// between 0.0 and +inf. The smaller the nearer the vectors are.
-pub fn euclidean_squared_distance(a: &[f32], b: &[f32]) -> f32 {
-    a.iter().zip(b).map(|(a, b)| (a - b).powi(2)).sum()
-}

diff --git a/milli/src/lib.rs b/milli/src/lib.rs
index 626c30ab0..99126f60e 100644
--- a/milli/src/lib.rs
+++ b/milli/src/lib.rs
@@ -32,7 +32,7 @@ use std::convert::{TryFrom, TryInto};
 use std::hash::BuildHasherDefault;
 
 use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
-pub use distance::{dot_product_similarity, euclidean_squared_distance};
+pub use distance::dot_product_similarity;
 pub use filter_parser::{Condition, FilterCondition, Span, Token};
 use fxhash::{FxHasher32, FxHasher64};
 pub use grenad::CompressionType;
@@ -304,7 +304,7 @@ impl VectorOrArrayOfVectors {
     }
 }
 
-/// Normalize a vector by dividing the dimensions by the lenght of it.
+/// Normalize a vector by dividing the dimensions by the length of it.
 pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
     let squared: f32 = vector.iter().map(|x| x * x).sum();
     let length = squared.sqrt();

From 30741d17fa997806cc1bd61c9d44c754e9becd7a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Tue, 27 Jun 2023 12:30:44 +0200
Subject: [PATCH 38/40] Change the TODO message

---
 milli/src/update/index_documents/typed_chunk.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs
index 915bb2299..3f197fbd1 100644
--- a/milli/src/update/index_documents/typed_chunk.rs
+++ b/milli/src/update/index_documents/typed_chunk.rs
@@ -249,7 +249,7 @@ pub(crate) fn write_typed_chunk_into_index(
             // convert the vector back to a Vec<f32>
             let vector: Vec<f32> = pod_collect_to_vec(value);
 
-            // TODO Move this error in the vector extractor
+            // TODO Inform the user about the document that has a wrong `_vectors`
             let found = vector.len();
             let expected = *expected_dimensions.get_or_insert(found);
             if expected != found {

From b2b413db12d2db1bb57c704c16dc9d7d9ae5f325 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Tue, 27 Jun 2023 12:31:23 +0200
Subject: [PATCH 39/40] Return all the _semanticScore values in the documents

---
 meilisearch/src/search.rs | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs
index a8c6765bc..346c9b1ec 100644
--- a/meilisearch/src/search.rs
+++ b/meilisearch/src/search.rs
@@ -228,6 +228,8 @@ pub struct SearchHit {
     pub ranking_score: Option<f64>,
     #[serde(rename = "_rankingScoreDetails", skip_serializing_if = "Option::is_none")]
     pub ranking_score_details: Option<serde_json::Map<String, serde_json::Value>>,
+    #[serde(rename = "_semanticScore", skip_serializing_if = "Option::is_none")]
+    pub semantic_score: Option<f32>,
 }
 
 #[derive(Serialize, Debug, Clone, PartialEq)]
@@ -462,11 +464,13 @@ pub fn perform_search(
             insert_geo_distance(sort, &mut document);
         }
 
-        if let Some(vector) = query.vector.as_ref() {
-            if let Some(vectors) = extract_field("_vectors", &fields_ids_map, obkv)? {
-                insert_semantic_score(vector, vectors, &mut document);
-            }
-        }
+        let semantic_score = match query.vector.as_ref() {
+            Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? {
+                Some(vectors) => compute_semantic_score(vector, vectors)?,
+                None => None,
+            },
+            None => None,
+        };
 
         let ranking_score =
             query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
@@ -479,6 +483,7 @@ pub fn perform_search(
             matches_position,
             ranking_score_details,
             ranking_score,
+            semantic_score,
         };
         documents.push(hit);
     }
@@ -553,18 +558,15 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) {
     }
 }
 
-fn insert_semantic_score(query: &[f32], vectors: Value, document: &mut Document) {
-    let vectors =
-        match serde_json::from_value(vectors).map(VectorOrArrayOfVectors::into_array_of_vectors) {
-            Ok(vectors) => vectors,
-            Err(_) => return,
-        };
-    let similarity = vectors
+fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result<Option<f32>> {
+    let vectors = serde_json::from_value(vectors)
+        .map(VectorOrArrayOfVectors::into_array_of_vectors)
+        .map_err(InternalError::SerdeJson)?;
+    Ok(vectors
        .into_iter()
        .map(|v| OrderedFloat(dot_product_similarity(query, &v)))
        .max()
-        .map(OrderedFloat::into_inner);
-    document.insert("_semanticScore".to_string(), json!(similarity));
+        .map(OrderedFloat::into_inner))
 }
 
 fn compute_formatted_options(
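The scoring fold itself is worth spelling out: `f32` has no total order (because of NaN), so each similarity is wrapped in `OrderedFloat` before `.max()` and unwrapped afterwards, and an empty `_vectors` array naturally yields `None`. A standalone sketch of that shape (the `best_score` helper is hypothetical; it assumes the `ordered_float` crate, which the code above already relies on):

    use ordered_float::OrderedFloat;

    fn best_score(query: &[f32], vectors: &[Vec<f32>]) -> Option<f32> {
        vectors
            .iter()
            // Dot product of the query against each document vector.
            .map(|v| OrderedFloat(v.iter().zip(query).map(|(a, b)| a * b).sum::<f32>()))
            .max()
            .map(OrderedFloat::into_inner)
    }

    fn main() {
        let query = [1.0, 0.0];
        assert_eq!(best_score(&query, &[vec![0.0, 1.0], vec![0.5, 0.5]]), Some(0.5));
        assert_eq!(best_score(&query, &[]), None);
    }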
From e69be93e425ed77ae5c621b0bb347763fc162d48 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Cl=C3=A9ment=20Renault?=
Date: Tue, 27 Jun 2023 12:31:52 +0200
Subject: [PATCH 40/40] Log warn about using both q and vector field parameters

---
 meilisearch/src/search.rs | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs
index 346c9b1ec..ec7e79692 100644
--- a/meilisearch/src/search.rs
+++ b/meilisearch/src/search.rs
@@ -6,12 +6,13 @@ use std::time::Instant;
 use deserr::Deserr;
 use either::Either;
 use index_scheduler::RoFeatures;
+use log::warn;
 use meilisearch_auth::IndexSearchRules;
 use meilisearch_types::deserr::DeserrJsonError;
 use meilisearch_types::error::deserr_codes::*;
 use meilisearch_types::index_uid::IndexUid;
-use meilisearch_types::milli::dot_product_similarity;
 use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
+use meilisearch_types::milli::{dot_product_similarity, InternalError};
 use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
 use meilisearch_types::{milli, Document};
 use milli::tokenizer::TokenizerBuilder;
@@ -301,6 +302,10 @@ pub fn perform_search(
 
     let mut search = index.search(&rtxn);
 
+    if query.vector.is_some() && query.q.is_some() {
+        warn!("Ignoring the query string `q` when used with the `vector` parameter.");
+    }
+
     if let Some(ref vector) = query.vector {
         search.vector(vector.clone());
     }
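Note the precedence this last patch documents rather than enforces: when both `q` and `vector` are supplied, `execute_search` takes the early-return vector path introduced in PATCH 36, so the query string is effectively ignored and the warning above is the only trace of it.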