diff --git a/Cargo.lock b/Cargo.lock index 46218fc34..ccf79f9a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1221,6 +1221,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "dump" version = "1.2.0" @@ -1725,6 +1731,15 @@ dependencies = [ "byteorder", ] +[[package]] +name = "hashbrown" +version = "0.11.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e" +dependencies = [ + "ahash 0.7.6", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1826,6 +1841,22 @@ dependencies = [ "digest", ] +[[package]] +name = "hnsw" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b9740ebf8769ec4ad6762cc951ba18f39bba6dfbc2fbbe46285f7539af79752" +dependencies = [ + "ahash 0.7.6", + "hashbrown 0.11.2", + "libm", + "num-traits", + "rand_core", + "serde", + "smallvec", + "space", +] + [[package]] name = "http" version = "0.2.9" @@ -1956,7 +1987,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" dependencies = [ "autocfg", - "hashbrown", + "hashbrown 0.12.3", "serde", ] @@ -2057,7 +2088,7 @@ checksum = "37228e06c75842d1097432d94d02f37fe3ebfca9791c2e8fef6e9db17ed128c1" dependencies = [ "cedarwood", "fxhash", - "hashbrown", + "hashbrown 0.12.3", "lazy_static", "phf", "phf_codegen", @@ -2564,6 +2595,7 @@ dependencies = [ "num_cpus", "obkv", "once_cell", + "ordered-float", "parking_lot", "permissive-json-pointer", "pin-project-lite", @@ -2683,6 +2715,7 @@ dependencies = [ "bimap", "bincode", "bstr", + "bytemuck", "byteorder", "charabia", "concat-arrays", @@ -2697,6 +2730,7 @@ dependencies = [ "geoutils", "grenad", "heed", + "hnsw", "insta", "itertools", "json-depth-checker", @@ -2711,6 +2745,7 @@ dependencies = [ "once_cell", "ordered-float", "rand", + "rand_pcg", "rayon", "roaring", "rstar", @@ -2720,6 +2755,7 @@ dependencies = [ "smallstr", "smallvec", "smartstring", + "space", "tempfile", "thiserror", "time", @@ -3272,6 +3308,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_pcg" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e" +dependencies = [ + "rand_core", + "serde", +] + [[package]] name = "rayon" version = "1.7.0" @@ -3731,6 +3777,9 @@ name = "smallvec" version = "1.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0" +dependencies = [ + "serde", +] [[package]] name = "smartstring" @@ -3753,6 +3802,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "space" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5ab9701ae895386d13db622abf411989deff7109b13b46b6173bb4ce5c1d123" +dependencies = [ + "doc-comment", + "num-traits", +] + [[package]] name = "spin" version = "0.5.2" @@ -4404,7 +4463,7 @@ version = "0.16.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c531a2dc4c462b833788be2c07eef4e621d0e9edbd55bf280cc164c1c1aa043" dependencies = [ - "hashbrown", + "hashbrown 0.12.3", "once_cell", ] diff --git a/index-scheduler/src/features.rs 
b/index-scheduler/src/features.rs index 7f4a30741..5a663fe67 100644 --- a/index-scheduler/src/features.rs +++ b/index-scheduler/src/features.rs @@ -62,7 +62,7 @@ impl RoFeatures { Err(FeatureNotEnabledError { disabled_action: "Passing `vector` as a query parameter", feature: "vector store", - issue_link: "https://github.com/meilisearch/meilisearch/discussions/TODO", + issue_link: "https://github.com/meilisearch/product/discussions/677", } .into()) } diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 6d81ff241..3880fac4b 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -217,6 +217,8 @@ InvalidDocumentFields , InvalidRequest , BAD_REQUEST ; MissingDocumentFilter , InvalidRequest , BAD_REQUEST ; InvalidDocumentFilter , InvalidRequest , BAD_REQUEST ; InvalidDocumentGeoField , InvalidRequest , BAD_REQUEST ; +InvalidVectorDimensions , InvalidRequest , BAD_REQUEST ; +InvalidVectorsType , InvalidRequest , BAD_REQUEST ; InvalidDocumentId , InvalidRequest , BAD_REQUEST ; InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ; InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ; @@ -239,6 +241,7 @@ InvalidSearchMatchingStrategy , InvalidRequest , BAD_REQUEST ; InvalidSearchOffset , InvalidRequest , BAD_REQUEST ; InvalidSearchPage , InvalidRequest , BAD_REQUEST ; InvalidSearchQ , InvalidRequest , BAD_REQUEST ; +InvalidSearchVector , InvalidRequest , BAD_REQUEST ; InvalidSearchShowMatchesPosition , InvalidRequest , BAD_REQUEST ; InvalidSearchShowRankingScore , InvalidRequest , BAD_REQUEST ; InvalidSearchShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ; @@ -335,6 +338,8 @@ impl ErrorCode for milli::Error { UserError::InvalidSortableAttribute { .. } => Code::InvalidSearchSort, UserError::CriterionError(_) => Code::InvalidSettingsRankingRules, UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField, + UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions, + UserError::InvalidVectorsType { .. 
} => Code::InvalidVectorsType, UserError::SortError(_) => Code::InvalidSearchSort, UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance diff --git a/meilisearch/Cargo.toml b/meilisearch/Cargo.toml index 8fcd69591..d90dd24dd 100644 --- a/meilisearch/Cargo.toml +++ b/meilisearch/Cargo.toml @@ -48,6 +48,7 @@ mime = "0.3.17" num_cpus = "1.15.0" obkv = "0.2.0" once_cell = "1.17.1" +ordered-float = "3.7.0" parking_lot = "0.12.1" permissive-json-pointer = { path = "../permissive-json-pointer" } pin-project-lite = "0.2.9" diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index 55e4905bd..9a96c4650 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -548,6 +548,10 @@ pub struct SearchAggregator { // The maximum number of terms in a q request max_terms_number: usize, + // vector + // The maximum number of floats in a vector request + max_vector_size: usize, + // every time a search is done, we increment the counter linked to the used settings matching_strategy: HashMap, @@ -617,6 +621,10 @@ impl SearchAggregator { ret.max_terms_number = q.split_whitespace().count(); } + if let Some(ref vector) = query.vector { + ret.max_vector_size = vector.len(); + } + if query.is_finite_pagination() { let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT); ret.max_limit = limit; diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 3ab093b5d..0c45f08c7 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -34,6 +34,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) { pub struct SearchQueryGet { #[deserr(default, error = DeserrQueryParamError)] q: Option, + #[deserr(default, error = DeserrQueryParamError)] + vector: Option>, #[deserr(default = Param(DEFAULT_SEARCH_OFFSET()), error = DeserrQueryParamError)] offset: Param, #[deserr(default = Param(DEFAULT_SEARCH_LIMIT()), error = DeserrQueryParamError)] @@ -84,6 +86,7 @@ impl From for SearchQuery { Self { q: other.q, + vector: other.vector, offset: other.offset.0, limit: other.limit.0, page: other.page.as_deref().copied(), diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 62f49c148..ec7e79692 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -6,18 +6,21 @@ use std::time::Instant; use deserr::Deserr; use either::Either; use index_scheduler::RoFeatures; +use log::warn; use meilisearch_auth::IndexSearchRules; use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; +use meilisearch_types::milli::{dot_product_similarity, InternalError}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; use milli::tokenizer::TokenizerBuilder; use milli::{ AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, - SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, + SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET, }; +use ordered_float::OrderedFloat; use regex::Regex; use serde::Serialize; use serde_json::{json, Value}; @@ -33,11 +36,13 @@ pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string(); pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "".to_string(); pub 
const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "".to_string(); -#[derive(Debug, Clone, Default, PartialEq, Eq, Deserr)] +#[derive(Debug, Clone, Default, PartialEq, Deserr)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] pub struct SearchQuery { #[deserr(default, error = DeserrJsonError)] pub q: Option, + #[deserr(default, error = DeserrJsonError)] + pub vector: Option>, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -86,13 +91,15 @@ impl SearchQuery { // This struct contains the fields of `SearchQuery` inline. // This is because neither deserr nor serde support `flatten` when using `deny_unknown_fields. // The `From` implementation ensures both structs remain up to date. -#[derive(Debug, Clone, PartialEq, Eq, Deserr)] +#[derive(Debug, Clone, PartialEq, Deserr)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] pub struct SearchQueryWithIndex { #[deserr(error = DeserrJsonError, missing_field_error = DeserrJsonError::missing_index_uid)] pub index_uid: IndexUid, #[deserr(default, error = DeserrJsonError)] pub q: Option, + #[deserr(default, error = DeserrJsonError)] + pub vector: Option>, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -136,6 +143,7 @@ impl SearchQueryWithIndex { let SearchQueryWithIndex { index_uid, q, + vector, offset, limit, page, @@ -159,6 +167,7 @@ impl SearchQueryWithIndex { index_uid, SearchQuery { q, + vector, offset, limit, page, @@ -220,6 +229,8 @@ pub struct SearchHit { pub ranking_score: Option, #[serde(rename = "_rankingScoreDetails", skip_serializing_if = "Option::is_none")] pub ranking_score_details: Option>, + #[serde(rename = "_semanticScore", skip_serializing_if = "Option::is_none")] + pub semantic_score: Option, } #[derive(Serialize, Debug, Clone, PartialEq)] @@ -227,6 +238,8 @@ pub struct SearchHit { pub struct SearchResult { pub hits: Vec, pub query: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub vector: Option>, pub processing_time_ms: u128, #[serde(flatten)] pub hits_info: HitsInfo, @@ -289,6 +302,14 @@ pub fn perform_search( let mut search = index.search(&rtxn); + if query.vector.is_some() && query.q.is_some() { + warn!("Ignoring the query string `q` when used with the `vector` parameter."); + } + + if let Some(ref vector) = query.vector { + search.vector(vector.clone()); + } + if let Some(ref query) = query.q { search.query(query); } @@ -312,6 +333,10 @@ pub fn perform_search( features.check_score_details()?; } + if query.vector.is_some() { + features.check_vector()?; + } + // compute the offset on the limit depending on the pagination mode. let (offset, limit) = if is_finite_pagination { let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT); @@ -418,7 +443,6 @@ pub fn perform_search( formatter_builder.highlight_suffix(query.highlight_post_tag); let mut documents = Vec::new(); - let documents_iter = index.documents(&rtxn, documents_ids)?; for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) { @@ -445,6 +469,14 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } + let semantic_score = match query.vector.as_ref() { + Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? 
{ + Some(vectors) => compute_semantic_score(vector, vectors)?, + None => None, + }, + None => None, + }; + let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); let ranking_score_details = @@ -456,6 +488,7 @@ pub fn perform_search( matches_position, ranking_score_details, ranking_score, + semantic_score, }; documents.push(hit); } @@ -505,7 +538,8 @@ pub fn perform_search( let result = SearchResult { hits: documents, hits_info, - query: query.q.clone().unwrap_or_default(), + query: query.q.unwrap_or_default(), + vector: query.vector, processing_time_ms: before_search.elapsed().as_millis(), facet_distribution, facet_stats, @@ -529,6 +563,17 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) { } } +fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result> { + let vectors = serde_json::from_value(vectors) + .map(VectorOrArrayOfVectors::into_array_of_vectors) + .map_err(InternalError::SerdeJson)?; + Ok(vectors + .into_iter() + .map(|v| OrderedFloat(dot_product_similarity(query, &v))) + .max() + .map(OrderedFloat::into_inner)) +} + fn compute_formatted_options( attr_to_highlight: &HashSet, attr_to_crop: &[String], @@ -656,6 +701,22 @@ fn make_document( Ok(document) } +/// Extract the JSON value under the field name specified +/// but doesn't support nested objects. +fn extract_field( + field_name: &str, + field_ids_map: &FieldsIdsMap, + obkv: obkv::KvReaderU16, +) -> Result, MeilisearchHttpError> { + match field_ids_map.id(field_name) { + Some(fid) => match obkv.get(fid) { + Some(value) => Ok(serde_json::from_slice(value).map(Some)?), + None => Ok(None), + }, + None => Ok(None), + } +} + fn format_fields>( document: &Document, field_ids_map: &FieldsIdsMap, diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 138103723..08f0c2645 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -15,6 +15,7 @@ license.workspace = true bimap = { version = "0.6.3", features = ["serde"] } bincode = "1.3.3" bstr = "1.4.0" +bytemuck = { version = "1.13.1", features = ["extern_crate_alloc"] } byteorder = "1.4.3" charabia = { version = "0.7.2", default-features = false } concat-arrays = "0.1.2" @@ -32,18 +33,21 @@ heed = { git = "https://github.com/meilisearch/heed", tag = "v0.12.6", default-f "lmdb", "sync-read-txn", ] } +hnsw = { version = "0.11.0", features = ["serde1"] } json-depth-checker = { path = "../json-depth-checker" } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } memmap2 = "0.5.10" obkv = "0.2.0" once_cell = "1.17.1" ordered-float = "3.6.0" +rand_pcg = { version = "0.3.1", features = ["serde1"] } rayon = "1.7.0" roaring = "0.10.1" rstar = { version = "0.10.0", features = ["serde"] } serde = { version = "1.0.160", features = ["derive"] } serde_json = { version = "1.0.95", features = ["preserve_order"] } slice-group-by = "0.3.0" +space = "0.17.0" smallstr = { version = "0.3.0", features = ["serde"] } smallvec = "1.10.0" smartstring = "1.0.1" diff --git a/milli/examples/search.rs b/milli/examples/search.rs index 87c9a004d..82de56434 100644 --- a/milli/examples/search.rs +++ b/milli/examples/search.rs @@ -52,6 +52,7 @@ fn main() -> Result<(), Box> { let docs = execute_search( &mut ctx, &(!query.trim().is_empty()).then(|| query.trim().to_owned()), + &None, TermsMatchingStrategy::Last, milli::score_details::ScoringStrategy::Skip, false, diff --git a/milli/src/distance.rs b/milli/src/distance.rs new file mode 100644 index 000000000..c838e4bd4 --- /dev/null +++ b/milli/src/distance.rs @@ -0,0 +1,25 @@ 
+use serde::{Deserialize, Serialize}; +use space::Metric; + +#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)] +pub struct DotProduct; + +impl Metric<Vec<f32>> for DotProduct { + type Unit = u32; + + // Following . + // + // Here is a playground that validates the ordering of the bit representation of floats in range 0.0..=1.0: + // + fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit { + let dist = 1.0 - dot_product_similarity(a, b); + debug_assert!(!dist.is_nan()); + dist.to_bits() + } +} + +/// Returns the dot product similarity score that will be between 0.0 and 1.0 +/// if both vectors are normalized. The higher the score, the more similar the vectors are. +pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 { + a.iter().zip(b).map(|(a, b)| a * b).sum() +} diff --git a/milli/src/error.rs b/milli/src/error.rs index 8d55eabbd..3df599b61 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -110,9 +110,13 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco }, #[error(transparent)] InvalidGeoField(#[from] GeoError), + #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)] + InvalidVectorDimensions { expected: usize, found: usize }, + #[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] + InvalidVectorsType { document_id: Value, value: Value }, #[error("{0}")] InvalidFilter(String), - #[error("Invalid type for filter subexpression: `expected {}, found: {1}`.", .0.join(", "))] + #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] InvalidFilterExpression(&'static [&'static str], Value), #[error("Attribute `{}` is not sortable. {}", .field, diff --git a/milli/src/index.rs b/milli/src/index.rs index fad3f665c..a22901993 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -8,10 +8,12 @@ use charabia::{Language, Script}; use heed::flags::Flags; use heed::types::*; use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn}; +use rand_pcg::Pcg32; use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; +use crate::distance::DotProduct; use crate::error::{InternalError, UserError}; use crate::facet::FacetType; use crate::fields_ids_map::FieldsIdsMap; @@ -20,12 +22,16 @@ use crate::heed_codec::facet::{ FieldIdCodec, OrderedF64Codec, }; use crate::heed_codec::{ScriptLanguageCodec, StrBEU16Codec, StrRefCodec}; +use crate::readable_slices::ReadableSlices; use crate::{ default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec, Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32, }; +/// The HNSW data-structure that we serialize, fill and search in. +pub type Hnsw = hnsw::Hnsw<DotProduct, Vec<f32>, Pcg32, 12, 24>; + pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; @@ -42,6 +48,10 @@ pub mod main_key { pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map"; pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids"; pub const GEO_RTREE_KEY: &str = "geo-rtree"; + /// The prefix of the key that is used to store the, potentially big, HNSW structure. + /// It is concatenated with a big-endian encoded number (non-human readable). + /// e.g. vector-hnsw0x0032.
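+ /// The chunk index is encoded in big-endian, so a prefix iteration returns the chunks
+ /// in write order and `vector_hnsw` can reassemble the serialized bytes by concatenation.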
+ pub const VECTOR_HNSW_KEY_PREFIX: &str = "vector-hnsw"; pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids"; pub const NUMBER_FACETED_DOCUMENTS_IDS_PREFIX: &str = "number-faceted-documents-ids"; pub const PRIMARY_KEY_KEY: &str = "primary-key"; @@ -86,6 +96,7 @@ pub mod db_name { pub const FACET_ID_STRING_DOCIDS: &str = "facet-id-string-docids"; pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s"; pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings"; + pub const VECTOR_ID_DOCID: &str = "vector-id-docids"; pub const DOCUMENTS: &str = "documents"; pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids"; } @@ -149,6 +160,9 @@ pub struct Index { /// Maps the document id, the facet field id and the strings. pub field_id_docid_facet_strings: Database, + /// Maps a vector id to the document id that have it. + pub vector_id_docid: Database, OwnedType>, + /// Maps the document id to the document as an obkv store. pub(crate) documents: Database, ObkvCodec>, } @@ -162,7 +176,7 @@ impl Index { ) -> Result { use db_name::*; - options.max_dbs(23); + options.max_dbs(24); unsafe { options.flag(Flags::MdbAlwaysFreePages) }; let env = options.open(path)?; @@ -198,11 +212,11 @@ impl Index { env.create_database(&mut wtxn, Some(FACET_ID_IS_NULL_DOCIDS))?; let facet_id_is_empty_docids = env.create_database(&mut wtxn, Some(FACET_ID_IS_EMPTY_DOCIDS))?; - let field_id_docid_facet_f64s = env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?; let field_id_docid_facet_strings = env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?; + let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?; let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?; wtxn.commit()?; @@ -231,6 +245,7 @@ impl Index { facet_id_is_empty_docids, field_id_docid_facet_f64s, field_id_docid_facet_strings, + vector_id_docid, documents, }) } @@ -502,6 +517,56 @@ impl Index { } } + /* vector HNSW */ + + /// Writes the provided `hnsw`. + pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> { + // We must delete all the chunks before we write the new HNSW chunks. + self.delete_vector_hnsw(wtxn)?; + + let chunk_size = 1024 * 1024 * (1024 + 512); // 1.5 GiB + let bytes = bincode::serialize(hnsw).map_err(|_| heed::Error::Encoding)?; + for (i, chunk) in bytes.chunks(chunk_size).enumerate() { + let i = i as u32; + let mut key = main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes().to_vec(); + key.extend_from_slice(&i.to_be_bytes()); + self.main.put::<_, ByteSlice, ByteSlice>(wtxn, &key, chunk)?; + } + Ok(()) + } + + /// Delete the `hnsw`. + pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result { + let mut iter = self.main.prefix_iter_mut::<_, ByteSlice, DecodeIgnore>( + wtxn, + main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes(), + )?; + let mut deleted = false; + while iter.next().transpose()?.is_some() { + // We do not keep a reference to the key or the value. + unsafe { deleted |= iter.del_current()? }; + } + Ok(deleted) + } + + /// Returns the `hnsw`. + pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result> { + let mut slices = Vec::new(); + for result in + self.main.prefix_iter::<_, Str, ByteSlice>(rtxn, main_key::VECTOR_HNSW_KEY_PREFIX)? 
+ { + let (_, slice) = result?; + slices.push(slice); + } + + if slices.is_empty() { + Ok(None) + } else { + let readable_slices: ReadableSlices<_> = slices.into_iter().collect(); + Ok(Some(bincode::deserialize_from(readable_slices).map_err(|_| heed::Error::Decoding)?)) + } + } + /* field distribution */ /// Writes the field distribution which associates every field name with diff --git a/milli/src/lib.rs b/milli/src/lib.rs index d3ee4f08e..99126f60e 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -10,6 +10,7 @@ pub mod documents; mod asc_desc; mod criterion; +pub mod distance; mod error; mod external_documents_ids; pub mod facet; @@ -17,6 +18,7 @@ mod fields_ids_map; pub mod heed_codec; pub mod index; pub mod proximity; +mod readable_slices; pub mod score_details; mod search; pub mod update; @@ -30,6 +32,7 @@ use std::convert::{TryFrom, TryInto}; use std::hash::BuildHasherDefault; use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; +pub use distance::dot_product_similarity; pub use filter_parser::{Condition, FilterCondition, Span, Token}; use fxhash::{FxHasher32, FxHasher64}; pub use grenad::CompressionType; @@ -284,6 +287,35 @@ pub fn normalize_facet(original: &str) -> String { CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase() } +/// Represents either a vector or an array of multiple vectors. +#[derive(serde::Serialize, serde::Deserialize, Debug)] +#[serde(transparent)] +pub struct VectorOrArrayOfVectors { + #[serde(with = "either::serde_untagged")] + inner: either::Either, Vec>>, +} + +impl VectorOrArrayOfVectors { + pub fn into_array_of_vectors(self) -> Vec> { + match self.inner { + either::Either::Left(vector) => vec![vector], + either::Either::Right(vectors) => vectors, + } + } +} + +/// Normalize a vector by dividing the dimensions by the length of it. +pub fn normalize_vector(mut vector: Vec) -> Vec { + let squared: f32 = vector.iter().map(|x| x * x).sum(); + let length = squared.sqrt(); + if length <= f32::EPSILON { + vector + } else { + vector.iter_mut().for_each(|x| *x /= length); + vector + } +} + #[cfg(test)] mod tests { use serde_json::json; diff --git a/milli/src/readable_slices.rs b/milli/src/readable_slices.rs new file mode 100644 index 000000000..7f5be214f --- /dev/null +++ b/milli/src/readable_slices.rs @@ -0,0 +1,85 @@ +use std::io::{self, Read}; +use std::iter::FromIterator; + +pub struct ReadableSlices { + inner: Vec, + pos: u64, +} + +impl FromIterator for ReadableSlices { + fn from_iter>(iter: T) -> Self { + ReadableSlices { inner: iter.into_iter().collect(), pos: 0 } + } +} + +impl> Read for ReadableSlices { + fn read(&mut self, mut buf: &mut [u8]) -> io::Result { + let original_buf_len = buf.len(); + + // We explore the list of slices to find the one where we must start reading. + let mut pos = self.pos; + let index = match self + .inner + .iter() + .map(|s| s.as_ref().len() as u64) + .position(|size| pos.checked_sub(size).map(|p| pos = p).is_none()) + { + Some(index) => index, + None => return Ok(0), + }; + + let mut inner_pos = pos as usize; + for slice in &self.inner[index..] { + let slice = &slice.as_ref()[inner_pos..]; + + if buf.len() > slice.len() { + // We must exhaust the current slice and go to the next one there is not enough here. + buf[..slice.len()].copy_from_slice(slice); + buf = &mut buf[slice.len()..]; + inner_pos = 0; + } else { + // There is enough in this slice to fill the remaining bytes of the buffer. + // Let's break just after filling it. 
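+ // For example, reading 4 bytes from the slices [0, 1, 2] and [3, 4] exhausts the first
+ // slice, takes this branch to copy the 3, and leaves `pos` at 4 for the next call.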
+ buf.copy_from_slice(&slice[..buf.len()]); + buf = &mut []; + break; + } + } + + let written = original_buf_len - buf.len(); + self.pos += written as u64; + Ok(written) + } +} + +#[cfg(test)] +mod test { + use std::io::Read; + + use super::ReadableSlices; + + #[test] + fn basic() { + let data: Vec<_> = (0..100).collect(); + let splits: Vec<_> = data.chunks(3).collect(); + let mut rdslices: ReadableSlices<_> = splits.into_iter().collect(); + + let mut output = Vec::new(); + let length = rdslices.read_to_end(&mut output).unwrap(); + assert_eq!(length, data.len()); + assert_eq!(output, data); + } + + #[test] + fn small_reads() { + let data: Vec<_> = (0..u8::MAX).collect(); + let splits: Vec<_> = data.chunks(27).collect(); + let mut rdslices: ReadableSlices<_> = splits.into_iter().collect(); + + let buffer = &mut [0; 45]; + let length = rdslices.read(buffer).unwrap(); + let expected: Vec<_> = (0..buffer.len() as u8).collect(); + assert_eq!(length, buffer.len()); + assert_eq!(buffer, &expected[..]); + } +} diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 3c972d9b0..970c0b7ab 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -23,6 +23,7 @@ pub mod new; pub struct Search<'a> { query: Option, + vector: Option>, // this should be linked to the String in the query filter: Option>, offset: usize, @@ -41,6 +42,7 @@ impl<'a> Search<'a> { pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { Search { query: None, + vector: None, filter: None, offset: 0, limit: 20, @@ -60,6 +62,11 @@ impl<'a> Search<'a> { self } + pub fn vector(&mut self, vector: impl Into>) -> &mut Search<'a> { + self.vector = Some(vector.into()); + self + } + pub fn offset(&mut self, offset: usize) -> &mut Search<'a> { self.offset = offset; self @@ -114,6 +121,7 @@ impl<'a> Search<'a> { execute_search( &mut ctx, &self.query, + &self.vector, self.terms_matching_strategy, self.scoring_strategy, self.exhaustive_number_hits, @@ -141,6 +149,7 @@ impl fmt::Debug for Search<'_> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let Search { query, + vector: _, filter, offset, limit, @@ -155,6 +164,7 @@ impl fmt::Debug for Search<'_> { } = self; f.debug_struct("Search") .field("query", query) + .field("vector", &"[...]") .field("filter", filter) .field("offset", offset) .field("limit", limit) diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index f33d595e5..ce28e16c1 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -509,6 +509,7 @@ mod tests { let crate::search::PartialSearchResult { located_query_terms, .. 
} = execute_search( &mut ctx, &Some(query.to_string()), + &None, crate::TermsMatchingStrategy::default(), crate::score_details::ScoringStrategy::Skip, false, diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 8df764f29..8bdcf077b 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -28,6 +28,7 @@ use db_cache::DatabaseCache; use exact_attribute::ExactAttribute; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use heed::RoTxn; +use hnsw::Searcher; use interner::{DedupInterner, Interner}; pub use logger::visual::VisualSearchLogger; pub use logger::{DefaultSearchLogger, SearchLogger}; @@ -39,6 +40,7 @@ use ranking_rules::{ use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache}; use roaring::RoaringBitmap; use sort::Sort; +use space::Neighbor; use self::geo_sort::GeoSort; pub use self::geo_sort::Strategy as GeoSortStrategy; @@ -46,7 +48,10 @@ use self::graph_based_ranking_rule::Words; use self::interner::Interned; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; -use crate::{AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError}; +use crate::{ + normalize_vector, AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, + UserError, BEU32, +}; /// A structure used throughout the execution of a search query. pub struct SearchContext<'ctx> { @@ -350,6 +355,7 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( pub fn execute_search( ctx: &mut SearchContext, query: &Option, + vector: &Option>, terms_matching_strategy: TermsMatchingStrategy, scoring_strategy: ScoringStrategy, exhaustive_number_hits: bool, @@ -370,8 +376,40 @@ pub fn execute_search( check_sort_criteria(ctx, sort_criteria.as_ref())?; - let mut located_query_terms = None; + if let Some(vector) = vector { + let mut searcher = Searcher::new(); + let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default(); + let ef = hnsw.len().min(100); + let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef]; + let vector = normalize_vector(vector.clone()); + let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]); + let mut docids = Vec::new(); + let mut uniq_docids = RoaringBitmap::new(); + for Neighbor { index, distance: _ } in neighbors.iter() { + let index = BEU32::new(*index as u32); + let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get(); + if universe.contains(docid) && uniq_docids.insert(docid) { + docids.push(docid); + if docids.len() == (from + length) { + break; + } + } + } + + // return the nearest documents that are also part of the candidates + // along with a dummy list of scores that are useless in this context. + let docids: Vec<_> = docids.into_iter().skip(from).take(length).collect(); + + return Ok(PartialSearchResult { + candidates: universe, + document_scores: vec![Vec::new(); docids.len()], + documents_ids: docids, + located_query_terms: None, + }); + } + + let mut located_query_terms = None; let query_terms = if let Some(query) = query { // We make sure that the analyzer is aware of the stop words // this ensures that the query builder is able to properly remove them. 
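Reviewer note on the new ANN branch above: the query vector is normalized, the HNSW is searched with `ef` capped at 100, and the returned vector ids are mapped back to document ids through `vector_id_docid`. Below is a minimal, self-contained sketch of that flow built only from calls that appear in this diff (`Hnsw::default`, `Searcher::new`, `insert`, `len`, `nearest`, plus milli's new `index::Hnsw` alias and `normalize_vector`); it is illustrative only, not part of the patch, and assumes `milli` is usable as a dependency:

```rust
use hnsw::Searcher;
use space::Neighbor;

use milli::index::Hnsw; // = hnsw::Hnsw<DotProduct, Vec<f32>, Pcg32, 12, 24>
use milli::normalize_vector;

fn main() {
    let mut hnsw = Hnsw::default();
    let mut searcher = Searcher::new();

    // Index a few document vectors, normalized exactly like `typed_chunk.rs` does.
    for vector in vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0], vec![0.7, 0.7, 0.0]] {
        let _vector_id = hnsw.insert(normalize_vector(vector), &mut searcher);
    }

    // Query, mirroring `execute_search`: `ef` is the number of indexed points, capped at 100.
    let query = normalize_vector(vec![0.8, 0.2, 0.0]);
    let ef = hnsw.len().min(100);
    let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef];
    let neighbors = hnsw.nearest(&query, ef, &mut searcher, &mut dest[..]);

    // Smaller distance means more similar, so the closest vector ids come first.
    for Neighbor { index, distance } in neighbors.iter() {
        println!("vector id {index}, distance bits {distance}");
    }
}
```

Because the stored vectors and the query are both normalized, `DotProduct` can encode the distance as the raw bits of `1.0 - dot_product_similarity`: for non-negative finite floats (the `0.0..=1.0` range mentioned in `distance.rs`) the IEEE-754 bit pattern grows with the value, so ordering the `u32` distances orders the results by similarity.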
@@ -439,7 +477,6 @@ pub fn execute_search( }; let BucketSortOutput { docids, scores, mut all_candidates } = bucket_sort_output; - let fields_ids_map = ctx.index.fields_ids_map(ctx.txn)?; // The candidates is the universe unless the exhaustive number of hits diff --git a/milli/src/update/clear_documents.rs b/milli/src/update/clear_documents.rs index 04119c641..f4a2d43fe 100644 --- a/milli/src/update/clear_documents.rs +++ b/milli/src/update/clear_documents.rs @@ -39,6 +39,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> { facet_id_is_empty_docids, field_id_docid_facet_f64s, field_id_docid_facet_strings, + vector_id_docid, documents, } = self.index; @@ -57,6 +58,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> { self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?; self.index.delete_geo_rtree(self.wtxn)?; self.index.delete_geo_faceted_documents_ids(self.wtxn)?; + self.index.delete_vector_hnsw(self.wtxn)?; // We clean all the faceted documents ids. for field_id in faceted_fields { @@ -95,6 +97,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> { facet_id_string_docids.clear(self.wtxn)?; field_id_docid_facet_f64s.clear(self.wtxn)?; field_id_docid_facet_strings.clear(self.wtxn)?; + vector_id_docid.clear(self.wtxn)?; documents.clear(self.wtxn)?; Ok(number_of_documents) diff --git a/milli/src/update/delete_documents.rs b/milli/src/update/delete_documents.rs index b971768a3..766f0e16e 100644 --- a/milli/src/update/delete_documents.rs +++ b/milli/src/update/delete_documents.rs @@ -4,8 +4,10 @@ use std::collections::{BTreeSet, HashMap, HashSet}; use fst::IntoStreamer; use heed::types::{ByteSlice, DecodeIgnore, Str, UnalignedSlice}; use heed::{BytesDecode, BytesEncode, Database, RwIter}; +use hnsw::Searcher; use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; +use space::KnnPoints; use time::OffsetDateTime; use super::facet::delete::FacetsDelete; @@ -14,6 +16,7 @@ use crate::error::InternalError; use crate::facet::FacetType; use crate::heed_codec::facet::FieldDocIdFacetCodec; use crate::heed_codec::CboRoaringBitmapCodec; +use crate::index::Hnsw; use crate::{ ExternalDocumentsIds, FieldId, FieldIdMapMissingEntry, Index, Result, RoaringBitmapCodec, BEU32, }; @@ -240,6 +243,7 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> { facet_id_exists_docids, facet_id_is_null_docids, facet_id_is_empty_docids, + vector_id_docid, documents, } = self.index; // Remove from the documents database @@ -429,6 +433,30 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> { &self.to_delete_docids, )?; + // An ugly and slow way to remove the vectors from the HNSW + // It basically reconstructs the HNSW from scratch without editing the current one. + let current_hnsw = self.index.vector_hnsw(self.wtxn)?.unwrap_or_default(); + if !current_hnsw.is_empty() { + let mut new_hnsw = Hnsw::default(); + let mut searcher = Searcher::new(); + let mut new_vector_id_docids = Vec::new(); + + for result in vector_id_docid.iter(self.wtxn)? 
{ + let (vector_id, docid) = result?; + if !self.to_delete_docids.contains(docid.get()) { + let vector = current_hnsw.get_point(vector_id.get() as usize).clone(); + let vector_id = new_hnsw.insert(vector, &mut searcher); + new_vector_id_docids.push((vector_id as u32, docid)); + } + } + + vector_id_docid.clear(self.wtxn)?; + for (vector_id, docid) in new_vector_id_docids { + vector_id_docid.put(self.wtxn, &BEU32::new(vector_id), &docid)?; + } + self.index.put_vector_hnsw(self.wtxn, &new_hnsw)?; + } + self.index.put_soft_deleted_documents_ids(self.wtxn, &RoaringBitmap::new())?; Ok(DetailedDocumentDeletionResult { diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs new file mode 100644 index 000000000..0fad3be07 --- /dev/null +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -0,0 +1,65 @@ +use std::convert::TryFrom; +use std::fs::File; +use std::io; + +use bytemuck::cast_slice; +use serde_json::{from_slice, Value}; + +use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; +use crate::error::UserError; +use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors}; + +/// Extracts the embedding vector contained in each document under the `_vectors` field. +/// +/// Returns the generated grenad reader containing the docid as key associated to the Vec +#[logging_timer::time] +pub fn extract_vector_points( + obkv_documents: grenad::Reader, + indexer: GrenadParameters, + primary_key_id: FieldId, + vectors_fid: FieldId, +) -> Result> { + let mut writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + let mut cursor = obkv_documents.into_cursor()?; + while let Some((docid_bytes, value)) = cursor.move_on_next()? 
{ + let obkv = obkv::KvReader::new(value); + + // since we only needs the primary key when we throw an error we create this getter to + // lazily get it when needed + let document_id = || -> Value { + let document_id = obkv.get(primary_key_id).unwrap(); + serde_json::from_slice(document_id).unwrap() + }; + + // first we retrieve the _vectors field + if let Some(vectors) = obkv.get(vectors_fid) { + // extract the vectors + let vectors = match from_slice(vectors) { + Ok(vectors) => VectorOrArrayOfVectors::into_array_of_vectors(vectors), + Err(_) => { + return Err(UserError::InvalidVectorsType { + document_id: document_id(), + value: from_slice(vectors).map_err(InternalError::SerdeJson)?, + } + .into()) + } + }; + + for (i, vector) in vectors.into_iter().enumerate().take(u16::MAX as usize) { + let index = u16::try_from(i).unwrap(); + let mut key = docid_bytes.to_vec(); + key.extend_from_slice(&index.to_be_bytes()); + let bytes = cast_slice(&vector); + writer.insert(key, bytes)?; + } + } + // else => the `_vectors` object was `null`, there is nothing to do + } + + writer_into_reader(writer) +} diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 632f568ab..6259c7272 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -4,6 +4,7 @@ mod extract_facet_string_docids; mod extract_fid_docid_facet_values; mod extract_fid_word_count_docids; mod extract_geo_points; +mod extract_vector_points; mod extract_word_docids; mod extract_word_fid_docids; mod extract_word_pair_proximity_docids; @@ -22,6 +23,7 @@ use self::extract_facet_string_docids::extract_facet_string_docids; use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues}; use self::extract_fid_word_count_docids::extract_fid_word_count_docids; use self::extract_geo_points::extract_geo_points; +use self::extract_vector_points::extract_vector_points; use self::extract_word_docids::extract_word_docids; use self::extract_word_fid_docids::extract_word_fid_docids; use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids; @@ -45,6 +47,7 @@ pub(crate) fn data_from_obkv_documents( faceted_fields: HashSet, primary_key_id: FieldId, geo_fields_ids: Option<(FieldId, FieldId)>, + vectors_field_id: Option, stop_words: Option>, max_positions_per_attributes: Option, exact_attributes: HashSet, @@ -69,6 +72,7 @@ pub(crate) fn data_from_obkv_documents( &faceted_fields, primary_key_id, geo_fields_ids, + vectors_field_id, &stop_words, max_positions_per_attributes, ) @@ -279,6 +283,7 @@ fn send_and_extract_flattened_documents_data( faceted_fields: &HashSet, primary_key_id: FieldId, geo_fields_ids: Option<(FieldId, FieldId)>, + vectors_field_id: Option, stop_words: &Option>, max_positions_per_attributes: Option, ) -> Result<( @@ -307,6 +312,25 @@ fn send_and_extract_flattened_documents_data( }); } + if let Some(vectors_field_id) = vectors_field_id { + let documents_chunk_cloned = flattened_documents_chunk.clone(); + let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); + rayon::spawn(move || { + let result = extract_vector_points( + documents_chunk_cloned, + indexer, + primary_key_id, + vectors_field_id, + ); + let _ = match result { + Ok(vector_points) => { + lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points))) + } + Err(error) => lmdb_writer_sx_cloned.send(Err(error)), + }; + }); + } + let (docid_word_positions_chunk, docid_fid_facet_values_chunks): (Result<_>, 
Result<_>) = rayon::join( || { diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 993f87a1f..5b6e03637 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -304,6 +304,8 @@ where } None => None, }; + // get the fid of the `_vectors` field. + let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors"); let stop_words = self.index.stop_words(self.wtxn)?; let exact_attributes = self.index.exact_attributes_ids(self.wtxn)?; @@ -340,6 +342,7 @@ where faceted_fields, primary_key_id, geo_fields_ids, + vectors_field_id, stop_words, max_positions_per_attributes, exact_attributes, diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 89b10bffe..3f197fbd1 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -4,20 +4,27 @@ use std::convert::TryInto; use std::fs::File; use std::io; +use bytemuck::allocation::pod_collect_to_vec; use charabia::{Language, Script}; use grenad::MergerBuilder; use heed::types::ByteSlice; use heed::RwTxn; +use hnsw::Searcher; use roaring::RoaringBitmap; +use space::KnnPoints; use super::helpers::{ self, merge_ignore_values, serialize_roaring_bitmap, valid_lmdb_key, CursorClonableMmap, }; use super::{ClonableMmap, MergeFn}; +use crate::error::UserError; use crate::facet::FacetType; use crate::update::facet::FacetsUpdate; -use crate::update::index_documents::helpers::as_cloneable_grenad; -use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result}; +use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; +use crate::{ + lat_lng_to_xyz, normalize_vector, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result, + BEU32, +}; pub(crate) enum TypedChunk { FieldIdDocidFacetStrings(grenad::Reader), @@ -38,6 +45,7 @@ pub(crate) enum TypedChunk { FieldIdFacetIsNullDocids(grenad::Reader), FieldIdFacetIsEmptyDocids(grenad::Reader), GeoPoints(grenad::Reader), + VectorPoints(grenad::Reader), ScriptLanguageDocids(HashMap<(Script, Language), RoaringBitmap>), } @@ -221,6 +229,40 @@ pub(crate) fn write_typed_chunk_into_index( index.put_geo_rtree(wtxn, &rtree)?; index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; } + TypedChunk::VectorPoints(vector_points) => { + let mut hnsw = index.vector_hnsw(wtxn)?.unwrap_or_default(); + let mut searcher = Searcher::new(); + + let mut expected_dimensions = match index.vector_id_docid.iter(wtxn)?.next() { + Some(result) => { + let (vector_id, _) = result?; + Some(hnsw.get_point(vector_id.get() as usize).len()) + } + None => None, + }; + + let mut cursor = vector_points.into_cursor()?; + while let Some((key, value)) = cursor.move_on_next()? 
{ + // convert the key back to a u32 (4 bytes) + let (left, _index) = try_split_array_at(key).unwrap(); + let docid = DocumentId::from_be_bytes(left); + // convert the vector back to a Vec + let vector: Vec = pod_collect_to_vec(value); + + // TODO Inform the user about the document that has a wrong `_vectors` + let found = vector.len(); + let expected = *expected_dimensions.get_or_insert(found); + if expected != found { + return Err(UserError::InvalidVectorDimensions { expected, found })?; + } + + let vector = normalize_vector(vector); + let vector_id = hnsw.insert(vector, &mut searcher) as u32; + index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?; + } + log::debug!("There are {} entries in the HNSW so far", hnsw.len()); + index.put_vector_hnsw(wtxn, &hnsw)?; + } TypedChunk::ScriptLanguageDocids(hash_pair) => { let mut buffer = Vec::new(); for (key, value) in hash_pair {