mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-01-05 19:23:31 +01:00
Merge #3825
3825: Accept semantic vectors and allow users to query nearest neighbors r=Kerollmops a=Kerollmops This Pull Request brings a new feature to the current API. The engine accepts a new `_vectors` field akin to the `_geo` one. This vector is stored in Meilisearch and can be retrieved via search. This work is the first step toward hybrid search, bringing the best of both worlds: keyword and semantic search ❤️🔥 ## ToDo - [x] Make it possible to get the `limit` nearest neighbors from a user-generated vector by using the `vector` field of search route. - [x] Delete the documents and vectors from the HNSW-related data structures. - [x] Do it the slow and ugly way (we need to be able to iterate over all the values). - [ ] Do it the efficient way (Wait for a new method or implement it myself). - [ ] ~~Move from the `hnsw` crate to the hgg one~~ The hgg crate is too slow. Meilisearch takes approximately 88s to answer a query. It is related to the time it takes to deserialize the `Hgg` data structure or search in it. I didn't take the time to measure precisely. We moved back to the hnsw crate which takes approximately 40ms to answer. - [ ] ~~Wait for a fix for https://github.com/rust-cv/hgg/issues/4.~~ - [x] Fix the current dot product function. - [x] Fill in the other `SearchResult` fields. - [x] Remove the `hnsw` dependency of the meilisearch crate. - [x] Fix the pages by taking the offset into account. - [x] Release a first prototype https://github.com/meilisearch/product/discussions/621#discussioncomment-6183647 - [x] Make the pagination and filtering faster and more correct. - [x] Return the original vector in the output search results (like `query`). - [x] Return an `_semanticSimilarity` field in the documents (it's a dot product) - [x] Return this score even if the `_vectors` field is not displayed - [x] Rename the field `_semanticScore`. - [ ] Return the `_geoDistance` value even if the `_geo` field is not displayed - [x] Store the HNSW on possibly multiple LMDB values. - [ ] Measure it and make it faster if needed - [ ] Export the `ReadableSlices` type into a small external crate - [x] Accept an `_vectors` field instead of the `_vector` one. - [x] Normalize all vectors. - [ ] Remove the `_vectors` field from the default searchable attributes (as we do with `_geo`?). - [ ] Correctly compute the candidates by remembering the documents having a valid `_vectors` field. - [ ] Return the right errors: - [ ] Return an error when the query vector is not the same length as the vectors in the HNSW. - [ ] We must return the user document id that triggered the vector dimension issue. - [x] If an indexation error occurs. - [ ] Fix the error codes when using the search route. - [ ] ~~Introduce some settings:~~ We currently ensure that the vector length is consistent over the whole set of documents and return an error for when a vector dimension doesn't follow the current number of dimensions. - [ ] The length of the vector the user will provide. - [ ] The distance function (we only support dot as of now). - [ ] Introduce other distance functions - [ ] Euclidean - [ ] Dot Product - [ ] Cosine - [ ] Make them SIMD optimized - [ ] Give credit to qdrant - [ ] Add tests. - [ ] Write a mini spec. - [ ] Release it in v1.3 as an experimental feature. Co-authored-by: Clément Renault <clement@meilisearch.com> Co-authored-by: Kerollmops <clement@meilisearch.com>
This commit is contained in:
commit
d8ea688481
65
Cargo.lock
generated
65
Cargo.lock
generated
@ -1221,6 +1221,12 @@ dependencies = [
|
|||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "doc-comment"
|
||||||
|
version = "0.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "dump"
|
name = "dump"
|
||||||
version = "1.2.0"
|
version = "1.2.0"
|
||||||
@ -1725,6 +1731,15 @@ dependencies = [
|
|||||||
"byteorder",
|
"byteorder",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hashbrown"
|
||||||
|
version = "0.11.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ab5ef0d4909ef3724cc8cce6ccc8572c5c817592e9285f5464f8e86f8bd3726e"
|
||||||
|
dependencies = [
|
||||||
|
"ahash 0.7.6",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hashbrown"
|
name = "hashbrown"
|
||||||
version = "0.12.3"
|
version = "0.12.3"
|
||||||
@ -1826,6 +1841,22 @@ dependencies = [
|
|||||||
"digest",
|
"digest",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hnsw"
|
||||||
|
version = "0.11.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "2b9740ebf8769ec4ad6762cc951ba18f39bba6dfbc2fbbe46285f7539af79752"
|
||||||
|
dependencies = [
|
||||||
|
"ahash 0.7.6",
|
||||||
|
"hashbrown 0.11.2",
|
||||||
|
"libm",
|
||||||
|
"num-traits",
|
||||||
|
"rand_core",
|
||||||
|
"serde",
|
||||||
|
"smallvec",
|
||||||
|
"space",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "http"
|
name = "http"
|
||||||
version = "0.2.9"
|
version = "0.2.9"
|
||||||
@ -1956,7 +1987,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"autocfg",
|
"autocfg",
|
||||||
"hashbrown",
|
"hashbrown 0.12.3",
|
||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -2057,7 +2088,7 @@ checksum = "37228e06c75842d1097432d94d02f37fe3ebfca9791c2e8fef6e9db17ed128c1"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"cedarwood",
|
"cedarwood",
|
||||||
"fxhash",
|
"fxhash",
|
||||||
"hashbrown",
|
"hashbrown 0.12.3",
|
||||||
"lazy_static",
|
"lazy_static",
|
||||||
"phf",
|
"phf",
|
||||||
"phf_codegen",
|
"phf_codegen",
|
||||||
@ -2564,6 +2595,7 @@ dependencies = [
|
|||||||
"num_cpus",
|
"num_cpus",
|
||||||
"obkv",
|
"obkv",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
|
"ordered-float",
|
||||||
"parking_lot",
|
"parking_lot",
|
||||||
"permissive-json-pointer",
|
"permissive-json-pointer",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
@ -2683,6 +2715,7 @@ dependencies = [
|
|||||||
"bimap",
|
"bimap",
|
||||||
"bincode",
|
"bincode",
|
||||||
"bstr",
|
"bstr",
|
||||||
|
"bytemuck",
|
||||||
"byteorder",
|
"byteorder",
|
||||||
"charabia",
|
"charabia",
|
||||||
"concat-arrays",
|
"concat-arrays",
|
||||||
@ -2697,6 +2730,7 @@ dependencies = [
|
|||||||
"geoutils",
|
"geoutils",
|
||||||
"grenad",
|
"grenad",
|
||||||
"heed",
|
"heed",
|
||||||
|
"hnsw",
|
||||||
"insta",
|
"insta",
|
||||||
"itertools",
|
"itertools",
|
||||||
"json-depth-checker",
|
"json-depth-checker",
|
||||||
@ -2711,6 +2745,7 @@ dependencies = [
|
|||||||
"once_cell",
|
"once_cell",
|
||||||
"ordered-float",
|
"ordered-float",
|
||||||
"rand",
|
"rand",
|
||||||
|
"rand_pcg",
|
||||||
"rayon",
|
"rayon",
|
||||||
"roaring",
|
"roaring",
|
||||||
"rstar",
|
"rstar",
|
||||||
@ -2720,6 +2755,7 @@ dependencies = [
|
|||||||
"smallstr",
|
"smallstr",
|
||||||
"smallvec",
|
"smallvec",
|
||||||
"smartstring",
|
"smartstring",
|
||||||
|
"space",
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror",
|
"thiserror",
|
||||||
"time",
|
"time",
|
||||||
@ -3272,6 +3308,16 @@ dependencies = [
|
|||||||
"getrandom",
|
"getrandom",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "rand_pcg"
|
||||||
|
version = "0.3.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e"
|
||||||
|
dependencies = [
|
||||||
|
"rand_core",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rayon"
|
name = "rayon"
|
||||||
version = "1.7.0"
|
version = "1.7.0"
|
||||||
@ -3731,6 +3777,9 @@ name = "smallvec"
|
|||||||
version = "1.10.0"
|
version = "1.10.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
|
checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
|
||||||
|
dependencies = [
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "smartstring"
|
name = "smartstring"
|
||||||
@ -3753,6 +3802,16 @@ dependencies = [
|
|||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "space"
|
||||||
|
version = "0.17.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c5ab9701ae895386d13db622abf411989deff7109b13b46b6173bb4ce5c1d123"
|
||||||
|
dependencies = [
|
||||||
|
"doc-comment",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "spin"
|
name = "spin"
|
||||||
version = "0.5.2"
|
version = "0.5.2"
|
||||||
@ -4404,7 +4463,7 @@ version = "0.16.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9c531a2dc4c462b833788be2c07eef4e621d0e9edbd55bf280cc164c1c1aa043"
|
checksum = "9c531a2dc4c462b833788be2c07eef4e621d0e9edbd55bf280cc164c1c1aa043"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"hashbrown",
|
"hashbrown 0.12.3",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ impl RoFeatures {
|
|||||||
Err(FeatureNotEnabledError {
|
Err(FeatureNotEnabledError {
|
||||||
disabled_action: "Passing `vector` as a query parameter",
|
disabled_action: "Passing `vector` as a query parameter",
|
||||||
feature: "vector store",
|
feature: "vector store",
|
||||||
issue_link: "https://github.com/meilisearch/meilisearch/discussions/TODO",
|
issue_link: "https://github.com/meilisearch/product/discussions/677",
|
||||||
}
|
}
|
||||||
.into())
|
.into())
|
||||||
}
|
}
|
||||||
|
@ -217,6 +217,8 @@ InvalidDocumentFields , InvalidRequest , BAD_REQUEST ;
|
|||||||
MissingDocumentFilter , InvalidRequest , BAD_REQUEST ;
|
MissingDocumentFilter , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidDocumentFilter , InvalidRequest , BAD_REQUEST ;
|
InvalidDocumentFilter , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidDocumentGeoField , InvalidRequest , BAD_REQUEST ;
|
InvalidDocumentGeoField , InvalidRequest , BAD_REQUEST ;
|
||||||
|
InvalidVectorDimensions , InvalidRequest , BAD_REQUEST ;
|
||||||
|
InvalidVectorsType , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidDocumentId , InvalidRequest , BAD_REQUEST ;
|
InvalidDocumentId , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ;
|
InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ;
|
InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ;
|
||||||
@ -239,6 +241,7 @@ InvalidSearchMatchingStrategy , InvalidRequest , BAD_REQUEST ;
|
|||||||
InvalidSearchOffset , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchOffset , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchPage , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchPage , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchQ , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchQ , InvalidRequest , BAD_REQUEST ;
|
||||||
|
InvalidSearchVector , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchShowMatchesPosition , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchShowMatchesPosition , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchShowRankingScore , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchShowRankingScore , InvalidRequest , BAD_REQUEST ;
|
||||||
InvalidSearchShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ;
|
InvalidSearchShowRankingScoreDetails , InvalidRequest , BAD_REQUEST ;
|
||||||
@ -335,6 +338,8 @@ impl ErrorCode for milli::Error {
|
|||||||
UserError::InvalidSortableAttribute { .. } => Code::InvalidSearchSort,
|
UserError::InvalidSortableAttribute { .. } => Code::InvalidSearchSort,
|
||||||
UserError::CriterionError(_) => Code::InvalidSettingsRankingRules,
|
UserError::CriterionError(_) => Code::InvalidSettingsRankingRules,
|
||||||
UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField,
|
UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField,
|
||||||
|
UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions,
|
||||||
|
UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType,
|
||||||
UserError::SortError(_) => Code::InvalidSearchSort,
|
UserError::SortError(_) => Code::InvalidSearchSort,
|
||||||
UserError::InvalidMinTypoWordLenSetting(_, _) => {
|
UserError::InvalidMinTypoWordLenSetting(_, _) => {
|
||||||
Code::InvalidSettingsTypoTolerance
|
Code::InvalidSettingsTypoTolerance
|
||||||
|
@ -48,6 +48,7 @@ mime = "0.3.17"
|
|||||||
num_cpus = "1.15.0"
|
num_cpus = "1.15.0"
|
||||||
obkv = "0.2.0"
|
obkv = "0.2.0"
|
||||||
once_cell = "1.17.1"
|
once_cell = "1.17.1"
|
||||||
|
ordered-float = "3.7.0"
|
||||||
parking_lot = "0.12.1"
|
parking_lot = "0.12.1"
|
||||||
permissive-json-pointer = { path = "../permissive-json-pointer" }
|
permissive-json-pointer = { path = "../permissive-json-pointer" }
|
||||||
pin-project-lite = "0.2.9"
|
pin-project-lite = "0.2.9"
|
||||||
|
@ -548,6 +548,10 @@ pub struct SearchAggregator {
|
|||||||
// The maximum number of terms in a q request
|
// The maximum number of terms in a q request
|
||||||
max_terms_number: usize,
|
max_terms_number: usize,
|
||||||
|
|
||||||
|
// vector
|
||||||
|
// The maximum number of floats in a vector request
|
||||||
|
max_vector_size: usize,
|
||||||
|
|
||||||
// every time a search is done, we increment the counter linked to the used settings
|
// every time a search is done, we increment the counter linked to the used settings
|
||||||
matching_strategy: HashMap<String, usize>,
|
matching_strategy: HashMap<String, usize>,
|
||||||
|
|
||||||
@ -617,6 +621,10 @@ impl SearchAggregator {
|
|||||||
ret.max_terms_number = q.split_whitespace().count();
|
ret.max_terms_number = q.split_whitespace().count();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(ref vector) = query.vector {
|
||||||
|
ret.max_vector_size = vector.len();
|
||||||
|
}
|
||||||
|
|
||||||
if query.is_finite_pagination() {
|
if query.is_finite_pagination() {
|
||||||
let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
|
let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
|
||||||
ret.max_limit = limit;
|
ret.max_limit = limit;
|
||||||
|
@ -34,6 +34,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
|
|||||||
pub struct SearchQueryGet {
|
pub struct SearchQueryGet {
|
||||||
#[deserr(default, error = DeserrQueryParamError<InvalidSearchQ>)]
|
#[deserr(default, error = DeserrQueryParamError<InvalidSearchQ>)]
|
||||||
q: Option<String>,
|
q: Option<String>,
|
||||||
|
#[deserr(default, error = DeserrQueryParamError<InvalidSearchVector>)]
|
||||||
|
vector: Option<Vec<f32>>,
|
||||||
#[deserr(default = Param(DEFAULT_SEARCH_OFFSET()), error = DeserrQueryParamError<InvalidSearchOffset>)]
|
#[deserr(default = Param(DEFAULT_SEARCH_OFFSET()), error = DeserrQueryParamError<InvalidSearchOffset>)]
|
||||||
offset: Param<usize>,
|
offset: Param<usize>,
|
||||||
#[deserr(default = Param(DEFAULT_SEARCH_LIMIT()), error = DeserrQueryParamError<InvalidSearchLimit>)]
|
#[deserr(default = Param(DEFAULT_SEARCH_LIMIT()), error = DeserrQueryParamError<InvalidSearchLimit>)]
|
||||||
@ -84,6 +86,7 @@ impl From<SearchQueryGet> for SearchQuery {
|
|||||||
|
|
||||||
Self {
|
Self {
|
||||||
q: other.q,
|
q: other.q,
|
||||||
|
vector: other.vector,
|
||||||
offset: other.offset.0,
|
offset: other.offset.0,
|
||||||
limit: other.limit.0,
|
limit: other.limit.0,
|
||||||
page: other.page.as_deref().copied(),
|
page: other.page.as_deref().copied(),
|
||||||
|
@ -6,18 +6,21 @@ use std::time::Instant;
|
|||||||
use deserr::Deserr;
|
use deserr::Deserr;
|
||||||
use either::Either;
|
use either::Either;
|
||||||
use index_scheduler::RoFeatures;
|
use index_scheduler::RoFeatures;
|
||||||
|
use log::warn;
|
||||||
use meilisearch_auth::IndexSearchRules;
|
use meilisearch_auth::IndexSearchRules;
|
||||||
use meilisearch_types::deserr::DeserrJsonError;
|
use meilisearch_types::deserr::DeserrJsonError;
|
||||||
use meilisearch_types::error::deserr_codes::*;
|
use meilisearch_types::error::deserr_codes::*;
|
||||||
use meilisearch_types::index_uid::IndexUid;
|
use meilisearch_types::index_uid::IndexUid;
|
||||||
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
|
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
|
||||||
|
use meilisearch_types::milli::{dot_product_similarity, InternalError};
|
||||||
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
|
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
|
||||||
use meilisearch_types::{milli, Document};
|
use meilisearch_types::{milli, Document};
|
||||||
use milli::tokenizer::TokenizerBuilder;
|
use milli::tokenizer::TokenizerBuilder;
|
||||||
use milli::{
|
use milli::{
|
||||||
AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder,
|
AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder,
|
||||||
SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
|
SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET,
|
||||||
};
|
};
|
||||||
|
use ordered_float::OrderedFloat;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
@ -33,11 +36,13 @@ pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string();
|
|||||||
pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
|
pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "<em>".to_string();
|
||||||
pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
|
pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "</em>".to_string();
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Deserr)]
|
#[derive(Debug, Clone, Default, PartialEq, Deserr)]
|
||||||
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||||
pub struct SearchQuery {
|
pub struct SearchQuery {
|
||||||
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
||||||
pub q: Option<String>,
|
pub q: Option<String>,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
|
||||||
|
pub vector: Option<Vec<f32>>,
|
||||||
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
||||||
pub offset: usize,
|
pub offset: usize,
|
||||||
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
||||||
@ -86,13 +91,15 @@ impl SearchQuery {
|
|||||||
// This struct contains the fields of `SearchQuery` inline.
|
// This struct contains the fields of `SearchQuery` inline.
|
||||||
// This is because neither deserr nor serde support `flatten` when using `deny_unknown_fields.
|
// This is because neither deserr nor serde support `flatten` when using `deny_unknown_fields.
|
||||||
// The `From<SearchQueryWithIndex>` implementation ensures both structs remain up to date.
|
// The `From<SearchQueryWithIndex>` implementation ensures both structs remain up to date.
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Deserr)]
|
#[derive(Debug, Clone, PartialEq, Deserr)]
|
||||||
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)]
|
||||||
pub struct SearchQueryWithIndex {
|
pub struct SearchQueryWithIndex {
|
||||||
#[deserr(error = DeserrJsonError<InvalidIndexUid>, missing_field_error = DeserrJsonError::missing_index_uid)]
|
#[deserr(error = DeserrJsonError<InvalidIndexUid>, missing_field_error = DeserrJsonError::missing_index_uid)]
|
||||||
pub index_uid: IndexUid,
|
pub index_uid: IndexUid,
|
||||||
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
||||||
pub q: Option<String>,
|
pub q: Option<String>,
|
||||||
|
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
||||||
|
pub vector: Option<Vec<f32>>,
|
||||||
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
||||||
pub offset: usize,
|
pub offset: usize,
|
||||||
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
||||||
@ -136,6 +143,7 @@ impl SearchQueryWithIndex {
|
|||||||
let SearchQueryWithIndex {
|
let SearchQueryWithIndex {
|
||||||
index_uid,
|
index_uid,
|
||||||
q,
|
q,
|
||||||
|
vector,
|
||||||
offset,
|
offset,
|
||||||
limit,
|
limit,
|
||||||
page,
|
page,
|
||||||
@ -159,6 +167,7 @@ impl SearchQueryWithIndex {
|
|||||||
index_uid,
|
index_uid,
|
||||||
SearchQuery {
|
SearchQuery {
|
||||||
q,
|
q,
|
||||||
|
vector,
|
||||||
offset,
|
offset,
|
||||||
limit,
|
limit,
|
||||||
page,
|
page,
|
||||||
@ -220,6 +229,8 @@ pub struct SearchHit {
|
|||||||
pub ranking_score: Option<f64>,
|
pub ranking_score: Option<f64>,
|
||||||
#[serde(rename = "_rankingScoreDetails", skip_serializing_if = "Option::is_none")]
|
#[serde(rename = "_rankingScoreDetails", skip_serializing_if = "Option::is_none")]
|
||||||
pub ranking_score_details: Option<serde_json::Map<String, serde_json::Value>>,
|
pub ranking_score_details: Option<serde_json::Map<String, serde_json::Value>>,
|
||||||
|
#[serde(rename = "_semanticScore", skip_serializing_if = "Option::is_none")]
|
||||||
|
pub semantic_score: Option<f32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Debug, Clone, PartialEq)]
|
#[derive(Serialize, Debug, Clone, PartialEq)]
|
||||||
@ -227,6 +238,8 @@ pub struct SearchHit {
|
|||||||
pub struct SearchResult {
|
pub struct SearchResult {
|
||||||
pub hits: Vec<SearchHit>,
|
pub hits: Vec<SearchHit>,
|
||||||
pub query: String,
|
pub query: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub vector: Option<Vec<f32>>,
|
||||||
pub processing_time_ms: u128,
|
pub processing_time_ms: u128,
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
pub hits_info: HitsInfo,
|
pub hits_info: HitsInfo,
|
||||||
@ -289,6 +302,14 @@ pub fn perform_search(
|
|||||||
|
|
||||||
let mut search = index.search(&rtxn);
|
let mut search = index.search(&rtxn);
|
||||||
|
|
||||||
|
if query.vector.is_some() && query.q.is_some() {
|
||||||
|
warn!("Ignoring the query string `q` when used with the `vector` parameter.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref vector) = query.vector {
|
||||||
|
search.vector(vector.clone());
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(ref query) = query.q {
|
if let Some(ref query) = query.q {
|
||||||
search.query(query);
|
search.query(query);
|
||||||
}
|
}
|
||||||
@ -312,6 +333,10 @@ pub fn perform_search(
|
|||||||
features.check_score_details()?;
|
features.check_score_details()?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if query.vector.is_some() {
|
||||||
|
features.check_vector()?;
|
||||||
|
}
|
||||||
|
|
||||||
// compute the offset on the limit depending on the pagination mode.
|
// compute the offset on the limit depending on the pagination mode.
|
||||||
let (offset, limit) = if is_finite_pagination {
|
let (offset, limit) = if is_finite_pagination {
|
||||||
let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
|
let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT);
|
||||||
@ -418,7 +443,6 @@ pub fn perform_search(
|
|||||||
formatter_builder.highlight_suffix(query.highlight_post_tag);
|
formatter_builder.highlight_suffix(query.highlight_post_tag);
|
||||||
|
|
||||||
let mut documents = Vec::new();
|
let mut documents = Vec::new();
|
||||||
|
|
||||||
let documents_iter = index.documents(&rtxn, documents_ids)?;
|
let documents_iter = index.documents(&rtxn, documents_ids)?;
|
||||||
|
|
||||||
for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) {
|
for ((_id, obkv), score) in documents_iter.into_iter().zip(document_scores.into_iter()) {
|
||||||
@ -445,6 +469,14 @@ pub fn perform_search(
|
|||||||
insert_geo_distance(sort, &mut document);
|
insert_geo_distance(sort, &mut document);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let semantic_score = match query.vector.as_ref() {
|
||||||
|
Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? {
|
||||||
|
Some(vectors) => compute_semantic_score(vector, vectors)?,
|
||||||
|
None => None,
|
||||||
|
},
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
let ranking_score =
|
let ranking_score =
|
||||||
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
|
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
|
||||||
let ranking_score_details =
|
let ranking_score_details =
|
||||||
@ -456,6 +488,7 @@ pub fn perform_search(
|
|||||||
matches_position,
|
matches_position,
|
||||||
ranking_score_details,
|
ranking_score_details,
|
||||||
ranking_score,
|
ranking_score,
|
||||||
|
semantic_score,
|
||||||
};
|
};
|
||||||
documents.push(hit);
|
documents.push(hit);
|
||||||
}
|
}
|
||||||
@ -505,7 +538,8 @@ pub fn perform_search(
|
|||||||
let result = SearchResult {
|
let result = SearchResult {
|
||||||
hits: documents,
|
hits: documents,
|
||||||
hits_info,
|
hits_info,
|
||||||
query: query.q.clone().unwrap_or_default(),
|
query: query.q.unwrap_or_default(),
|
||||||
|
vector: query.vector,
|
||||||
processing_time_ms: before_search.elapsed().as_millis(),
|
processing_time_ms: before_search.elapsed().as_millis(),
|
||||||
facet_distribution,
|
facet_distribution,
|
||||||
facet_stats,
|
facet_stats,
|
||||||
@ -529,6 +563,17 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result<Option<f32>> {
|
||||||
|
let vectors = serde_json::from_value(vectors)
|
||||||
|
.map(VectorOrArrayOfVectors::into_array_of_vectors)
|
||||||
|
.map_err(InternalError::SerdeJson)?;
|
||||||
|
Ok(vectors
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| OrderedFloat(dot_product_similarity(query, &v)))
|
||||||
|
.max()
|
||||||
|
.map(OrderedFloat::into_inner))
|
||||||
|
}
|
||||||
|
|
||||||
fn compute_formatted_options(
|
fn compute_formatted_options(
|
||||||
attr_to_highlight: &HashSet<String>,
|
attr_to_highlight: &HashSet<String>,
|
||||||
attr_to_crop: &[String],
|
attr_to_crop: &[String],
|
||||||
@ -656,6 +701,22 @@ fn make_document(
|
|||||||
Ok(document)
|
Ok(document)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extract the JSON value under the field name specified
|
||||||
|
/// but doesn't support nested objects.
|
||||||
|
fn extract_field(
|
||||||
|
field_name: &str,
|
||||||
|
field_ids_map: &FieldsIdsMap,
|
||||||
|
obkv: obkv::KvReaderU16,
|
||||||
|
) -> Result<Option<serde_json::Value>, MeilisearchHttpError> {
|
||||||
|
match field_ids_map.id(field_name) {
|
||||||
|
Some(fid) => match obkv.get(fid) {
|
||||||
|
Some(value) => Ok(serde_json::from_slice(value).map(Some)?),
|
||||||
|
None => Ok(None),
|
||||||
|
},
|
||||||
|
None => Ok(None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn format_fields<A: AsRef<[u8]>>(
|
fn format_fields<A: AsRef<[u8]>>(
|
||||||
document: &Document,
|
document: &Document,
|
||||||
field_ids_map: &FieldsIdsMap,
|
field_ids_map: &FieldsIdsMap,
|
||||||
|
@ -15,6 +15,7 @@ license.workspace = true
|
|||||||
bimap = { version = "0.6.3", features = ["serde"] }
|
bimap = { version = "0.6.3", features = ["serde"] }
|
||||||
bincode = "1.3.3"
|
bincode = "1.3.3"
|
||||||
bstr = "1.4.0"
|
bstr = "1.4.0"
|
||||||
|
bytemuck = { version = "1.13.1", features = ["extern_crate_alloc"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
charabia = { version = "0.7.2", default-features = false }
|
charabia = { version = "0.7.2", default-features = false }
|
||||||
concat-arrays = "0.1.2"
|
concat-arrays = "0.1.2"
|
||||||
@ -32,18 +33,21 @@ heed = { git = "https://github.com/meilisearch/heed", tag = "v0.12.6", default-f
|
|||||||
"lmdb",
|
"lmdb",
|
||||||
"sync-read-txn",
|
"sync-read-txn",
|
||||||
] }
|
] }
|
||||||
|
hnsw = { version = "0.11.0", features = ["serde1"] }
|
||||||
json-depth-checker = { path = "../json-depth-checker" }
|
json-depth-checker = { path = "../json-depth-checker" }
|
||||||
levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] }
|
levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] }
|
||||||
memmap2 = "0.5.10"
|
memmap2 = "0.5.10"
|
||||||
obkv = "0.2.0"
|
obkv = "0.2.0"
|
||||||
once_cell = "1.17.1"
|
once_cell = "1.17.1"
|
||||||
ordered-float = "3.6.0"
|
ordered-float = "3.6.0"
|
||||||
|
rand_pcg = { version = "0.3.1", features = ["serde1"] }
|
||||||
rayon = "1.7.0"
|
rayon = "1.7.0"
|
||||||
roaring = "0.10.1"
|
roaring = "0.10.1"
|
||||||
rstar = { version = "0.10.0", features = ["serde"] }
|
rstar = { version = "0.10.0", features = ["serde"] }
|
||||||
serde = { version = "1.0.160", features = ["derive"] }
|
serde = { version = "1.0.160", features = ["derive"] }
|
||||||
serde_json = { version = "1.0.95", features = ["preserve_order"] }
|
serde_json = { version = "1.0.95", features = ["preserve_order"] }
|
||||||
slice-group-by = "0.3.0"
|
slice-group-by = "0.3.0"
|
||||||
|
space = "0.17.0"
|
||||||
smallstr = { version = "0.3.0", features = ["serde"] }
|
smallstr = { version = "0.3.0", features = ["serde"] }
|
||||||
smallvec = "1.10.0"
|
smallvec = "1.10.0"
|
||||||
smartstring = "1.0.1"
|
smartstring = "1.0.1"
|
||||||
|
@ -52,6 +52,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
|||||||
let docs = execute_search(
|
let docs = execute_search(
|
||||||
&mut ctx,
|
&mut ctx,
|
||||||
&(!query.trim().is_empty()).then(|| query.trim().to_owned()),
|
&(!query.trim().is_empty()).then(|| query.trim().to_owned()),
|
||||||
|
&None,
|
||||||
TermsMatchingStrategy::Last,
|
TermsMatchingStrategy::Last,
|
||||||
milli::score_details::ScoringStrategy::Skip,
|
milli::score_details::ScoringStrategy::Skip,
|
||||||
false,
|
false,
|
||||||
|
25
milli/src/distance.rs
Normal file
25
milli/src/distance.rs
Normal file
@ -0,0 +1,25 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use space::Metric;
|
||||||
|
|
||||||
|
#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
|
||||||
|
pub struct DotProduct;
|
||||||
|
|
||||||
|
impl Metric<Vec<f32>> for DotProduct {
|
||||||
|
type Unit = u32;
|
||||||
|
|
||||||
|
// Following <https://docs.rs/space/0.17.0/space/trait.Metric.html>.
|
||||||
|
//
|
||||||
|
// Here is a playground that validate the ordering of the bit representation of floats in range 0.0..=1.0:
|
||||||
|
// <https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=6c59e31a3cc5036b32edf51e8937b56e>
|
||||||
|
fn distance(&self, a: &Vec<f32>, b: &Vec<f32>) -> Self::Unit {
|
||||||
|
let dist = 1.0 - dot_product_similarity(a, b);
|
||||||
|
debug_assert!(!dist.is_nan());
|
||||||
|
dist.to_bits()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the dot product similarity score that will between 0.0 and 1.0
|
||||||
|
/// if both vectors are normalized. The higher the more similar the vectors are.
|
||||||
|
pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||||
|
a.iter().zip(b).map(|(a, b)| a * b).sum()
|
||||||
|
}
|
@ -110,9 +110,13 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
|
|||||||
},
|
},
|
||||||
#[error(transparent)]
|
#[error(transparent)]
|
||||||
InvalidGeoField(#[from] GeoError),
|
InvalidGeoField(#[from] GeoError),
|
||||||
|
#[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)]
|
||||||
|
InvalidVectorDimensions { expected: usize, found: usize },
|
||||||
|
#[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")]
|
||||||
|
InvalidVectorsType { document_id: Value, value: Value },
|
||||||
#[error("{0}")]
|
#[error("{0}")]
|
||||||
InvalidFilter(String),
|
InvalidFilter(String),
|
||||||
#[error("Invalid type for filter subexpression: `expected {}, found: {1}`.", .0.join(", "))]
|
#[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))]
|
||||||
InvalidFilterExpression(&'static [&'static str], Value),
|
InvalidFilterExpression(&'static [&'static str], Value),
|
||||||
#[error("Attribute `{}` is not sortable. {}",
|
#[error("Attribute `{}` is not sortable. {}",
|
||||||
.field,
|
.field,
|
||||||
|
@ -8,10 +8,12 @@ use charabia::{Language, Script};
|
|||||||
use heed::flags::Flags;
|
use heed::flags::Flags;
|
||||||
use heed::types::*;
|
use heed::types::*;
|
||||||
use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn};
|
use heed::{CompactionOption, Database, PolyDatabase, RoTxn, RwTxn};
|
||||||
|
use rand_pcg::Pcg32;
|
||||||
use roaring::RoaringBitmap;
|
use roaring::RoaringBitmap;
|
||||||
use rstar::RTree;
|
use rstar::RTree;
|
||||||
use time::OffsetDateTime;
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
|
use crate::distance::DotProduct;
|
||||||
use crate::error::{InternalError, UserError};
|
use crate::error::{InternalError, UserError};
|
||||||
use crate::facet::FacetType;
|
use crate::facet::FacetType;
|
||||||
use crate::fields_ids_map::FieldsIdsMap;
|
use crate::fields_ids_map::FieldsIdsMap;
|
||||||
@ -20,12 +22,16 @@ use crate::heed_codec::facet::{
|
|||||||
FieldIdCodec, OrderedF64Codec,
|
FieldIdCodec, OrderedF64Codec,
|
||||||
};
|
};
|
||||||
use crate::heed_codec::{ScriptLanguageCodec, StrBEU16Codec, StrRefCodec};
|
use crate::heed_codec::{ScriptLanguageCodec, StrBEU16Codec, StrRefCodec};
|
||||||
|
use crate::readable_slices::ReadableSlices;
|
||||||
use crate::{
|
use crate::{
|
||||||
default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
|
default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
|
||||||
FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec,
|
FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec,
|
||||||
Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32,
|
Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// The HNSW data-structure that we serialize, fill and search in.
|
||||||
|
pub type Hnsw = hnsw::Hnsw<DotProduct, Vec<f32>, Pcg32, 12, 24>;
|
||||||
|
|
||||||
pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
|
pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
|
||||||
pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;
|
pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9;
|
||||||
|
|
||||||
@ -42,6 +48,10 @@ pub mod main_key {
|
|||||||
pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map";
|
pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map";
|
||||||
pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids";
|
pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids";
|
||||||
pub const GEO_RTREE_KEY: &str = "geo-rtree";
|
pub const GEO_RTREE_KEY: &str = "geo-rtree";
|
||||||
|
/// The prefix of the key that is used to store the, potential big, HNSW structure.
|
||||||
|
/// It is concatenated with a big-endian encoded number (non-human readable).
|
||||||
|
/// e.g. vector-hnsw0x0032.
|
||||||
|
pub const VECTOR_HNSW_KEY_PREFIX: &str = "vector-hnsw";
|
||||||
pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids";
|
pub const HARD_EXTERNAL_DOCUMENTS_IDS_KEY: &str = "hard-external-documents-ids";
|
||||||
pub const NUMBER_FACETED_DOCUMENTS_IDS_PREFIX: &str = "number-faceted-documents-ids";
|
pub const NUMBER_FACETED_DOCUMENTS_IDS_PREFIX: &str = "number-faceted-documents-ids";
|
||||||
pub const PRIMARY_KEY_KEY: &str = "primary-key";
|
pub const PRIMARY_KEY_KEY: &str = "primary-key";
|
||||||
@ -86,6 +96,7 @@ pub mod db_name {
|
|||||||
pub const FACET_ID_STRING_DOCIDS: &str = "facet-id-string-docids";
|
pub const FACET_ID_STRING_DOCIDS: &str = "facet-id-string-docids";
|
||||||
pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s";
|
pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s";
|
||||||
pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings";
|
pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings";
|
||||||
|
pub const VECTOR_ID_DOCID: &str = "vector-id-docids";
|
||||||
pub const DOCUMENTS: &str = "documents";
|
pub const DOCUMENTS: &str = "documents";
|
||||||
pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids";
|
pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids";
|
||||||
}
|
}
|
||||||
@ -149,6 +160,9 @@ pub struct Index {
|
|||||||
/// Maps the document id, the facet field id and the strings.
|
/// Maps the document id, the facet field id and the strings.
|
||||||
pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>,
|
pub field_id_docid_facet_strings: Database<FieldDocIdFacetStringCodec, Str>,
|
||||||
|
|
||||||
|
/// Maps a vector id to the document id that have it.
|
||||||
|
pub vector_id_docid: Database<OwnedType<BEU32>, OwnedType<BEU32>>,
|
||||||
|
|
||||||
/// Maps the document id to the document as an obkv store.
|
/// Maps the document id to the document as an obkv store.
|
||||||
pub(crate) documents: Database<OwnedType<BEU32>, ObkvCodec>,
|
pub(crate) documents: Database<OwnedType<BEU32>, ObkvCodec>,
|
||||||
}
|
}
|
||||||
@ -162,7 +176,7 @@ impl Index {
|
|||||||
) -> Result<Index> {
|
) -> Result<Index> {
|
||||||
use db_name::*;
|
use db_name::*;
|
||||||
|
|
||||||
options.max_dbs(23);
|
options.max_dbs(24);
|
||||||
unsafe { options.flag(Flags::MdbAlwaysFreePages) };
|
unsafe { options.flag(Flags::MdbAlwaysFreePages) };
|
||||||
|
|
||||||
let env = options.open(path)?;
|
let env = options.open(path)?;
|
||||||
@ -198,11 +212,11 @@ impl Index {
|
|||||||
env.create_database(&mut wtxn, Some(FACET_ID_IS_NULL_DOCIDS))?;
|
env.create_database(&mut wtxn, Some(FACET_ID_IS_NULL_DOCIDS))?;
|
||||||
let facet_id_is_empty_docids =
|
let facet_id_is_empty_docids =
|
||||||
env.create_database(&mut wtxn, Some(FACET_ID_IS_EMPTY_DOCIDS))?;
|
env.create_database(&mut wtxn, Some(FACET_ID_IS_EMPTY_DOCIDS))?;
|
||||||
|
|
||||||
let field_id_docid_facet_f64s =
|
let field_id_docid_facet_f64s =
|
||||||
env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?;
|
env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?;
|
||||||
let field_id_docid_facet_strings =
|
let field_id_docid_facet_strings =
|
||||||
env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?;
|
env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?;
|
||||||
|
let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?;
|
||||||
let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?;
|
let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?;
|
||||||
wtxn.commit()?;
|
wtxn.commit()?;
|
||||||
|
|
||||||
@ -231,6 +245,7 @@ impl Index {
|
|||||||
facet_id_is_empty_docids,
|
facet_id_is_empty_docids,
|
||||||
field_id_docid_facet_f64s,
|
field_id_docid_facet_f64s,
|
||||||
field_id_docid_facet_strings,
|
field_id_docid_facet_strings,
|
||||||
|
vector_id_docid,
|
||||||
documents,
|
documents,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -502,6 +517,56 @@ impl Index {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* vector HNSW */
|
||||||
|
|
||||||
|
/// Writes the provided `hnsw`.
|
||||||
|
pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> {
|
||||||
|
// We must delete all the chunks before we write the new HNSW chunks.
|
||||||
|
self.delete_vector_hnsw(wtxn)?;
|
||||||
|
|
||||||
|
let chunk_size = 1024 * 1024 * (1024 + 512); // 1.5 GiB
|
||||||
|
let bytes = bincode::serialize(hnsw).map_err(|_| heed::Error::Encoding)?;
|
||||||
|
for (i, chunk) in bytes.chunks(chunk_size).enumerate() {
|
||||||
|
let i = i as u32;
|
||||||
|
let mut key = main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes().to_vec();
|
||||||
|
key.extend_from_slice(&i.to_be_bytes());
|
||||||
|
self.main.put::<_, ByteSlice, ByteSlice>(wtxn, &key, chunk)?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Delete the `hnsw`.
|
||||||
|
pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result<bool> {
|
||||||
|
let mut iter = self.main.prefix_iter_mut::<_, ByteSlice, DecodeIgnore>(
|
||||||
|
wtxn,
|
||||||
|
main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes(),
|
||||||
|
)?;
|
||||||
|
let mut deleted = false;
|
||||||
|
while iter.next().transpose()?.is_some() {
|
||||||
|
// We do not keep a reference to the key or the value.
|
||||||
|
unsafe { deleted |= iter.del_current()? };
|
||||||
|
}
|
||||||
|
Ok(deleted)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns the `hnsw`.
|
||||||
|
pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result<Option<Hnsw>> {
|
||||||
|
let mut slices = Vec::new();
|
||||||
|
for result in
|
||||||
|
self.main.prefix_iter::<_, Str, ByteSlice>(rtxn, main_key::VECTOR_HNSW_KEY_PREFIX)?
|
||||||
|
{
|
||||||
|
let (_, slice) = result?;
|
||||||
|
slices.push(slice);
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.is_empty() {
|
||||||
|
Ok(None)
|
||||||
|
} else {
|
||||||
|
let readable_slices: ReadableSlices<_> = slices.into_iter().collect();
|
||||||
|
Ok(Some(bincode::deserialize_from(readable_slices).map_err(|_| heed::Error::Decoding)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/* field distribution */
|
/* field distribution */
|
||||||
|
|
||||||
/// Writes the field distribution which associates every field name with
|
/// Writes the field distribution which associates every field name with
|
||||||
|
@ -10,6 +10,7 @@ pub mod documents;
|
|||||||
|
|
||||||
mod asc_desc;
|
mod asc_desc;
|
||||||
mod criterion;
|
mod criterion;
|
||||||
|
pub mod distance;
|
||||||
mod error;
|
mod error;
|
||||||
mod external_documents_ids;
|
mod external_documents_ids;
|
||||||
pub mod facet;
|
pub mod facet;
|
||||||
@ -17,6 +18,7 @@ mod fields_ids_map;
|
|||||||
pub mod heed_codec;
|
pub mod heed_codec;
|
||||||
pub mod index;
|
pub mod index;
|
||||||
pub mod proximity;
|
pub mod proximity;
|
||||||
|
mod readable_slices;
|
||||||
pub mod score_details;
|
pub mod score_details;
|
||||||
mod search;
|
mod search;
|
||||||
pub mod update;
|
pub mod update;
|
||||||
@ -30,6 +32,7 @@ use std::convert::{TryFrom, TryInto};
|
|||||||
use std::hash::BuildHasherDefault;
|
use std::hash::BuildHasherDefault;
|
||||||
|
|
||||||
use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
|
use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
|
||||||
|
pub use distance::dot_product_similarity;
|
||||||
pub use filter_parser::{Condition, FilterCondition, Span, Token};
|
pub use filter_parser::{Condition, FilterCondition, Span, Token};
|
||||||
use fxhash::{FxHasher32, FxHasher64};
|
use fxhash::{FxHasher32, FxHasher64};
|
||||||
pub use grenad::CompressionType;
|
pub use grenad::CompressionType;
|
||||||
@ -284,6 +287,35 @@ pub fn normalize_facet(original: &str) -> String {
|
|||||||
CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase()
|
CompatibilityDecompositionNormalizer.normalize_str(original.trim()).to_lowercase()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Represents either a vector or an array of multiple vectors.
|
||||||
|
#[derive(serde::Serialize, serde::Deserialize, Debug)]
|
||||||
|
#[serde(transparent)]
|
||||||
|
pub struct VectorOrArrayOfVectors {
|
||||||
|
#[serde(with = "either::serde_untagged")]
|
||||||
|
inner: either::Either<Vec<f32>, Vec<Vec<f32>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VectorOrArrayOfVectors {
|
||||||
|
pub fn into_array_of_vectors(self) -> Vec<Vec<f32>> {
|
||||||
|
match self.inner {
|
||||||
|
either::Either::Left(vector) => vec![vector],
|
||||||
|
either::Either::Right(vectors) => vectors,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Normalize a vector by dividing the dimensions by the length of it.
|
||||||
|
pub fn normalize_vector(mut vector: Vec<f32>) -> Vec<f32> {
|
||||||
|
let squared: f32 = vector.iter().map(|x| x * x).sum();
|
||||||
|
let length = squared.sqrt();
|
||||||
|
if length <= f32::EPSILON {
|
||||||
|
vector
|
||||||
|
} else {
|
||||||
|
vector.iter_mut().for_each(|x| *x /= length);
|
||||||
|
vector
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
85
milli/src/readable_slices.rs
Normal file
85
milli/src/readable_slices.rs
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
use std::io::{self, Read};
|
||||||
|
use std::iter::FromIterator;
|
||||||
|
|
||||||
|
pub struct ReadableSlices<A> {
|
||||||
|
inner: Vec<A>,
|
||||||
|
pos: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A> FromIterator<A> for ReadableSlices<A> {
|
||||||
|
fn from_iter<T: IntoIterator<Item = A>>(iter: T) -> Self {
|
||||||
|
ReadableSlices { inner: iter.into_iter().collect(), pos: 0 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<A: AsRef<[u8]>> Read for ReadableSlices<A> {
|
||||||
|
fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
|
||||||
|
let original_buf_len = buf.len();
|
||||||
|
|
||||||
|
// We explore the list of slices to find the one where we must start reading.
|
||||||
|
let mut pos = self.pos;
|
||||||
|
let index = match self
|
||||||
|
.inner
|
||||||
|
.iter()
|
||||||
|
.map(|s| s.as_ref().len() as u64)
|
||||||
|
.position(|size| pos.checked_sub(size).map(|p| pos = p).is_none())
|
||||||
|
{
|
||||||
|
Some(index) => index,
|
||||||
|
None => return Ok(0),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut inner_pos = pos as usize;
|
||||||
|
for slice in &self.inner[index..] {
|
||||||
|
let slice = &slice.as_ref()[inner_pos..];
|
||||||
|
|
||||||
|
if buf.len() > slice.len() {
|
||||||
|
// We must exhaust the current slice and go to the next one there is not enough here.
|
||||||
|
buf[..slice.len()].copy_from_slice(slice);
|
||||||
|
buf = &mut buf[slice.len()..];
|
||||||
|
inner_pos = 0;
|
||||||
|
} else {
|
||||||
|
// There is enough in this slice to fill the remaining bytes of the buffer.
|
||||||
|
// Let's break just after filling it.
|
||||||
|
buf.copy_from_slice(&slice[..buf.len()]);
|
||||||
|
buf = &mut [];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let written = original_buf_len - buf.len();
|
||||||
|
self.pos += written as u64;
|
||||||
|
Ok(written)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test {
|
||||||
|
use std::io::Read;
|
||||||
|
|
||||||
|
use super::ReadableSlices;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn basic() {
|
||||||
|
let data: Vec<_> = (0..100).collect();
|
||||||
|
let splits: Vec<_> = data.chunks(3).collect();
|
||||||
|
let mut rdslices: ReadableSlices<_> = splits.into_iter().collect();
|
||||||
|
|
||||||
|
let mut output = Vec::new();
|
||||||
|
let length = rdslices.read_to_end(&mut output).unwrap();
|
||||||
|
assert_eq!(length, data.len());
|
||||||
|
assert_eq!(output, data);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn small_reads() {
|
||||||
|
let data: Vec<_> = (0..u8::MAX).collect();
|
||||||
|
let splits: Vec<_> = data.chunks(27).collect();
|
||||||
|
let mut rdslices: ReadableSlices<_> = splits.into_iter().collect();
|
||||||
|
|
||||||
|
let buffer = &mut [0; 45];
|
||||||
|
let length = rdslices.read(buffer).unwrap();
|
||||||
|
let expected: Vec<_> = (0..buffer.len() as u8).collect();
|
||||||
|
assert_eq!(length, buffer.len());
|
||||||
|
assert_eq!(buffer, &expected[..]);
|
||||||
|
}
|
||||||
|
}
|
@ -23,6 +23,7 @@ pub mod new;
|
|||||||
|
|
||||||
pub struct Search<'a> {
|
pub struct Search<'a> {
|
||||||
query: Option<String>,
|
query: Option<String>,
|
||||||
|
vector: Option<Vec<f32>>,
|
||||||
// this should be linked to the String in the query
|
// this should be linked to the String in the query
|
||||||
filter: Option<Filter<'a>>,
|
filter: Option<Filter<'a>>,
|
||||||
offset: usize,
|
offset: usize,
|
||||||
@ -41,6 +42,7 @@ impl<'a> Search<'a> {
|
|||||||
pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
|
pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
|
||||||
Search {
|
Search {
|
||||||
query: None,
|
query: None,
|
||||||
|
vector: None,
|
||||||
filter: None,
|
filter: None,
|
||||||
offset: 0,
|
offset: 0,
|
||||||
limit: 20,
|
limit: 20,
|
||||||
@ -60,6 +62,11 @@ impl<'a> Search<'a> {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn vector(&mut self, vector: impl Into<Vec<f32>>) -> &mut Search<'a> {
|
||||||
|
self.vector = Some(vector.into());
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
pub fn offset(&mut self, offset: usize) -> &mut Search<'a> {
|
pub fn offset(&mut self, offset: usize) -> &mut Search<'a> {
|
||||||
self.offset = offset;
|
self.offset = offset;
|
||||||
self
|
self
|
||||||
@ -114,6 +121,7 @@ impl<'a> Search<'a> {
|
|||||||
execute_search(
|
execute_search(
|
||||||
&mut ctx,
|
&mut ctx,
|
||||||
&self.query,
|
&self.query,
|
||||||
|
&self.vector,
|
||||||
self.terms_matching_strategy,
|
self.terms_matching_strategy,
|
||||||
self.scoring_strategy,
|
self.scoring_strategy,
|
||||||
self.exhaustive_number_hits,
|
self.exhaustive_number_hits,
|
||||||
@ -141,6 +149,7 @@ impl fmt::Debug for Search<'_> {
|
|||||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
let Search {
|
let Search {
|
||||||
query,
|
query,
|
||||||
|
vector: _,
|
||||||
filter,
|
filter,
|
||||||
offset,
|
offset,
|
||||||
limit,
|
limit,
|
||||||
@ -155,6 +164,7 @@ impl fmt::Debug for Search<'_> {
|
|||||||
} = self;
|
} = self;
|
||||||
f.debug_struct("Search")
|
f.debug_struct("Search")
|
||||||
.field("query", query)
|
.field("query", query)
|
||||||
|
.field("vector", &"[...]")
|
||||||
.field("filter", filter)
|
.field("filter", filter)
|
||||||
.field("offset", offset)
|
.field("offset", offset)
|
||||||
.field("limit", limit)
|
.field("limit", limit)
|
||||||
|
@ -509,6 +509,7 @@ mod tests {
|
|||||||
let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search(
|
let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search(
|
||||||
&mut ctx,
|
&mut ctx,
|
||||||
&Some(query.to_string()),
|
&Some(query.to_string()),
|
||||||
|
&None,
|
||||||
crate::TermsMatchingStrategy::default(),
|
crate::TermsMatchingStrategy::default(),
|
||||||
crate::score_details::ScoringStrategy::Skip,
|
crate::score_details::ScoringStrategy::Skip,
|
||||||
false,
|
false,
|
||||||
|
@ -28,6 +28,7 @@ use db_cache::DatabaseCache;
|
|||||||
use exact_attribute::ExactAttribute;
|
use exact_attribute::ExactAttribute;
|
||||||
use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo};
|
use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo};
|
||||||
use heed::RoTxn;
|
use heed::RoTxn;
|
||||||
|
use hnsw::Searcher;
|
||||||
use interner::{DedupInterner, Interner};
|
use interner::{DedupInterner, Interner};
|
||||||
pub use logger::visual::VisualSearchLogger;
|
pub use logger::visual::VisualSearchLogger;
|
||||||
pub use logger::{DefaultSearchLogger, SearchLogger};
|
pub use logger::{DefaultSearchLogger, SearchLogger};
|
||||||
@ -39,6 +40,7 @@ use ranking_rules::{
|
|||||||
use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache};
|
use resolve_query_graph::{compute_query_graph_docids, PhraseDocIdsCache};
|
||||||
use roaring::RoaringBitmap;
|
use roaring::RoaringBitmap;
|
||||||
use sort::Sort;
|
use sort::Sort;
|
||||||
|
use space::Neighbor;
|
||||||
|
|
||||||
use self::geo_sort::GeoSort;
|
use self::geo_sort::GeoSort;
|
||||||
pub use self::geo_sort::Strategy as GeoSortStrategy;
|
pub use self::geo_sort::Strategy as GeoSortStrategy;
|
||||||
@ -46,7 +48,10 @@ use self::graph_based_ranking_rule::Words;
|
|||||||
use self::interner::Interned;
|
use self::interner::Interned;
|
||||||
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
||||||
use crate::search::new::distinct::apply_distinct_rule;
|
use crate::search::new::distinct::apply_distinct_rule;
|
||||||
use crate::{AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError};
|
use crate::{
|
||||||
|
normalize_vector, AscDesc, DocumentId, Filter, Index, Member, Result, TermsMatchingStrategy,
|
||||||
|
UserError, BEU32,
|
||||||
|
};
|
||||||
|
|
||||||
/// A structure used throughout the execution of a search query.
|
/// A structure used throughout the execution of a search query.
|
||||||
pub struct SearchContext<'ctx> {
|
pub struct SearchContext<'ctx> {
|
||||||
@ -350,6 +355,7 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>(
|
|||||||
pub fn execute_search(
|
pub fn execute_search(
|
||||||
ctx: &mut SearchContext,
|
ctx: &mut SearchContext,
|
||||||
query: &Option<String>,
|
query: &Option<String>,
|
||||||
|
vector: &Option<Vec<f32>>,
|
||||||
terms_matching_strategy: TermsMatchingStrategy,
|
terms_matching_strategy: TermsMatchingStrategy,
|
||||||
scoring_strategy: ScoringStrategy,
|
scoring_strategy: ScoringStrategy,
|
||||||
exhaustive_number_hits: bool,
|
exhaustive_number_hits: bool,
|
||||||
@ -370,8 +376,40 @@ pub fn execute_search(
|
|||||||
|
|
||||||
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
||||||
|
|
||||||
let mut located_query_terms = None;
|
if let Some(vector) = vector {
|
||||||
|
let mut searcher = Searcher::new();
|
||||||
|
let hnsw = ctx.index.vector_hnsw(ctx.txn)?.unwrap_or_default();
|
||||||
|
let ef = hnsw.len().min(100);
|
||||||
|
let mut dest = vec![Neighbor { index: 0, distance: 0 }; ef];
|
||||||
|
let vector = normalize_vector(vector.clone());
|
||||||
|
let neighbors = hnsw.nearest(&vector, ef, &mut searcher, &mut dest[..]);
|
||||||
|
|
||||||
|
let mut docids = Vec::new();
|
||||||
|
let mut uniq_docids = RoaringBitmap::new();
|
||||||
|
for Neighbor { index, distance: _ } in neighbors.iter() {
|
||||||
|
let index = BEU32::new(*index as u32);
|
||||||
|
let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap().get();
|
||||||
|
if universe.contains(docid) && uniq_docids.insert(docid) {
|
||||||
|
docids.push(docid);
|
||||||
|
if docids.len() == (from + length) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// return the nearest documents that are also part of the candidates
|
||||||
|
// along with a dummy list of scores that are useless in this context.
|
||||||
|
let docids: Vec<_> = docids.into_iter().skip(from).take(length).collect();
|
||||||
|
|
||||||
|
return Ok(PartialSearchResult {
|
||||||
|
candidates: universe,
|
||||||
|
document_scores: vec![Vec::new(); docids.len()],
|
||||||
|
documents_ids: docids,
|
||||||
|
located_query_terms: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut located_query_terms = None;
|
||||||
let query_terms = if let Some(query) = query {
|
let query_terms = if let Some(query) = query {
|
||||||
// We make sure that the analyzer is aware of the stop words
|
// We make sure that the analyzer is aware of the stop words
|
||||||
// this ensures that the query builder is able to properly remove them.
|
// this ensures that the query builder is able to properly remove them.
|
||||||
@ -439,7 +477,6 @@ pub fn execute_search(
|
|||||||
};
|
};
|
||||||
|
|
||||||
let BucketSortOutput { docids, scores, mut all_candidates } = bucket_sort_output;
|
let BucketSortOutput { docids, scores, mut all_candidates } = bucket_sort_output;
|
||||||
|
|
||||||
let fields_ids_map = ctx.index.fields_ids_map(ctx.txn)?;
|
let fields_ids_map = ctx.index.fields_ids_map(ctx.txn)?;
|
||||||
|
|
||||||
// The candidates is the universe unless the exhaustive number of hits
|
// The candidates is the universe unless the exhaustive number of hits
|
||||||
|
@ -39,6 +39,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
|
|||||||
facet_id_is_empty_docids,
|
facet_id_is_empty_docids,
|
||||||
field_id_docid_facet_f64s,
|
field_id_docid_facet_f64s,
|
||||||
field_id_docid_facet_strings,
|
field_id_docid_facet_strings,
|
||||||
|
vector_id_docid,
|
||||||
documents,
|
documents,
|
||||||
} = self.index;
|
} = self.index;
|
||||||
|
|
||||||
@ -57,6 +58,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
|
|||||||
self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?;
|
self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?;
|
||||||
self.index.delete_geo_rtree(self.wtxn)?;
|
self.index.delete_geo_rtree(self.wtxn)?;
|
||||||
self.index.delete_geo_faceted_documents_ids(self.wtxn)?;
|
self.index.delete_geo_faceted_documents_ids(self.wtxn)?;
|
||||||
|
self.index.delete_vector_hnsw(self.wtxn)?;
|
||||||
|
|
||||||
// We clean all the faceted documents ids.
|
// We clean all the faceted documents ids.
|
||||||
for field_id in faceted_fields {
|
for field_id in faceted_fields {
|
||||||
@ -95,6 +97,7 @@ impl<'t, 'u, 'i> ClearDocuments<'t, 'u, 'i> {
|
|||||||
facet_id_string_docids.clear(self.wtxn)?;
|
facet_id_string_docids.clear(self.wtxn)?;
|
||||||
field_id_docid_facet_f64s.clear(self.wtxn)?;
|
field_id_docid_facet_f64s.clear(self.wtxn)?;
|
||||||
field_id_docid_facet_strings.clear(self.wtxn)?;
|
field_id_docid_facet_strings.clear(self.wtxn)?;
|
||||||
|
vector_id_docid.clear(self.wtxn)?;
|
||||||
documents.clear(self.wtxn)?;
|
documents.clear(self.wtxn)?;
|
||||||
|
|
||||||
Ok(number_of_documents)
|
Ok(number_of_documents)
|
||||||
|
@ -4,8 +4,10 @@ use std::collections::{BTreeSet, HashMap, HashSet};
|
|||||||
use fst::IntoStreamer;
|
use fst::IntoStreamer;
|
||||||
use heed::types::{ByteSlice, DecodeIgnore, Str, UnalignedSlice};
|
use heed::types::{ByteSlice, DecodeIgnore, Str, UnalignedSlice};
|
||||||
use heed::{BytesDecode, BytesEncode, Database, RwIter};
|
use heed::{BytesDecode, BytesEncode, Database, RwIter};
|
||||||
|
use hnsw::Searcher;
|
||||||
use roaring::RoaringBitmap;
|
use roaring::RoaringBitmap;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use space::KnnPoints;
|
||||||
use time::OffsetDateTime;
|
use time::OffsetDateTime;
|
||||||
|
|
||||||
use super::facet::delete::FacetsDelete;
|
use super::facet::delete::FacetsDelete;
|
||||||
@ -14,6 +16,7 @@ use crate::error::InternalError;
|
|||||||
use crate::facet::FacetType;
|
use crate::facet::FacetType;
|
||||||
use crate::heed_codec::facet::FieldDocIdFacetCodec;
|
use crate::heed_codec::facet::FieldDocIdFacetCodec;
|
||||||
use crate::heed_codec::CboRoaringBitmapCodec;
|
use crate::heed_codec::CboRoaringBitmapCodec;
|
||||||
|
use crate::index::Hnsw;
|
||||||
use crate::{
|
use crate::{
|
||||||
ExternalDocumentsIds, FieldId, FieldIdMapMissingEntry, Index, Result, RoaringBitmapCodec, BEU32,
|
ExternalDocumentsIds, FieldId, FieldIdMapMissingEntry, Index, Result, RoaringBitmapCodec, BEU32,
|
||||||
};
|
};
|
||||||
@ -240,6 +243,7 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> {
|
|||||||
facet_id_exists_docids,
|
facet_id_exists_docids,
|
||||||
facet_id_is_null_docids,
|
facet_id_is_null_docids,
|
||||||
facet_id_is_empty_docids,
|
facet_id_is_empty_docids,
|
||||||
|
vector_id_docid,
|
||||||
documents,
|
documents,
|
||||||
} = self.index;
|
} = self.index;
|
||||||
// Remove from the documents database
|
// Remove from the documents database
|
||||||
@ -429,6 +433,30 @@ impl<'t, 'u, 'i> DeleteDocuments<'t, 'u, 'i> {
|
|||||||
&self.to_delete_docids,
|
&self.to_delete_docids,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
// An ugly and slow way to remove the vectors from the HNSW
|
||||||
|
// It basically reconstructs the HNSW from scratch without editing the current one.
|
||||||
|
let current_hnsw = self.index.vector_hnsw(self.wtxn)?.unwrap_or_default();
|
||||||
|
if !current_hnsw.is_empty() {
|
||||||
|
let mut new_hnsw = Hnsw::default();
|
||||||
|
let mut searcher = Searcher::new();
|
||||||
|
let mut new_vector_id_docids = Vec::new();
|
||||||
|
|
||||||
|
for result in vector_id_docid.iter(self.wtxn)? {
|
||||||
|
let (vector_id, docid) = result?;
|
||||||
|
if !self.to_delete_docids.contains(docid.get()) {
|
||||||
|
let vector = current_hnsw.get_point(vector_id.get() as usize).clone();
|
||||||
|
let vector_id = new_hnsw.insert(vector, &mut searcher);
|
||||||
|
new_vector_id_docids.push((vector_id as u32, docid));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
vector_id_docid.clear(self.wtxn)?;
|
||||||
|
for (vector_id, docid) in new_vector_id_docids {
|
||||||
|
vector_id_docid.put(self.wtxn, &BEU32::new(vector_id), &docid)?;
|
||||||
|
}
|
||||||
|
self.index.put_vector_hnsw(self.wtxn, &new_hnsw)?;
|
||||||
|
}
|
||||||
|
|
||||||
self.index.put_soft_deleted_documents_ids(self.wtxn, &RoaringBitmap::new())?;
|
self.index.put_soft_deleted_documents_ids(self.wtxn, &RoaringBitmap::new())?;
|
||||||
|
|
||||||
Ok(DetailedDocumentDeletionResult {
|
Ok(DetailedDocumentDeletionResult {
|
||||||
|
@ -0,0 +1,65 @@
|
|||||||
|
use std::convert::TryFrom;
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io;
|
||||||
|
|
||||||
|
use bytemuck::cast_slice;
|
||||||
|
use serde_json::{from_slice, Value};
|
||||||
|
|
||||||
|
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
|
||||||
|
use crate::error::UserError;
|
||||||
|
use crate::{FieldId, InternalError, Result, VectorOrArrayOfVectors};
|
||||||
|
|
||||||
|
/// Extracts the embedding vector contained in each document under the `_vectors` field.
|
||||||
|
///
|
||||||
|
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
|
||||||
|
#[logging_timer::time]
|
||||||
|
pub fn extract_vector_points<R: io::Read + io::Seek>(
|
||||||
|
obkv_documents: grenad::Reader<R>,
|
||||||
|
indexer: GrenadParameters,
|
||||||
|
primary_key_id: FieldId,
|
||||||
|
vectors_fid: FieldId,
|
||||||
|
) -> Result<grenad::Reader<File>> {
|
||||||
|
let mut writer = create_writer(
|
||||||
|
indexer.chunk_compression_type,
|
||||||
|
indexer.chunk_compression_level,
|
||||||
|
tempfile::tempfile()?,
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut cursor = obkv_documents.into_cursor()?;
|
||||||
|
while let Some((docid_bytes, value)) = cursor.move_on_next()? {
|
||||||
|
let obkv = obkv::KvReader::new(value);
|
||||||
|
|
||||||
|
// since we only needs the primary key when we throw an error we create this getter to
|
||||||
|
// lazily get it when needed
|
||||||
|
let document_id = || -> Value {
|
||||||
|
let document_id = obkv.get(primary_key_id).unwrap();
|
||||||
|
serde_json::from_slice(document_id).unwrap()
|
||||||
|
};
|
||||||
|
|
||||||
|
// first we retrieve the _vectors field
|
||||||
|
if let Some(vectors) = obkv.get(vectors_fid) {
|
||||||
|
// extract the vectors
|
||||||
|
let vectors = match from_slice(vectors) {
|
||||||
|
Ok(vectors) => VectorOrArrayOfVectors::into_array_of_vectors(vectors),
|
||||||
|
Err(_) => {
|
||||||
|
return Err(UserError::InvalidVectorsType {
|
||||||
|
document_id: document_id(),
|
||||||
|
value: from_slice(vectors).map_err(InternalError::SerdeJson)?,
|
||||||
|
}
|
||||||
|
.into())
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
for (i, vector) in vectors.into_iter().enumerate().take(u16::MAX as usize) {
|
||||||
|
let index = u16::try_from(i).unwrap();
|
||||||
|
let mut key = docid_bytes.to_vec();
|
||||||
|
key.extend_from_slice(&index.to_be_bytes());
|
||||||
|
let bytes = cast_slice(&vector);
|
||||||
|
writer.insert(key, bytes)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// else => the `_vectors` object was `null`, there is nothing to do
|
||||||
|
}
|
||||||
|
|
||||||
|
writer_into_reader(writer)
|
||||||
|
}
|
@ -4,6 +4,7 @@ mod extract_facet_string_docids;
|
|||||||
mod extract_fid_docid_facet_values;
|
mod extract_fid_docid_facet_values;
|
||||||
mod extract_fid_word_count_docids;
|
mod extract_fid_word_count_docids;
|
||||||
mod extract_geo_points;
|
mod extract_geo_points;
|
||||||
|
mod extract_vector_points;
|
||||||
mod extract_word_docids;
|
mod extract_word_docids;
|
||||||
mod extract_word_fid_docids;
|
mod extract_word_fid_docids;
|
||||||
mod extract_word_pair_proximity_docids;
|
mod extract_word_pair_proximity_docids;
|
||||||
@ -22,6 +23,7 @@ use self::extract_facet_string_docids::extract_facet_string_docids;
|
|||||||
use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues};
|
use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues};
|
||||||
use self::extract_fid_word_count_docids::extract_fid_word_count_docids;
|
use self::extract_fid_word_count_docids::extract_fid_word_count_docids;
|
||||||
use self::extract_geo_points::extract_geo_points;
|
use self::extract_geo_points::extract_geo_points;
|
||||||
|
use self::extract_vector_points::extract_vector_points;
|
||||||
use self::extract_word_docids::extract_word_docids;
|
use self::extract_word_docids::extract_word_docids;
|
||||||
use self::extract_word_fid_docids::extract_word_fid_docids;
|
use self::extract_word_fid_docids::extract_word_fid_docids;
|
||||||
use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;
|
use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;
|
||||||
@ -45,6 +47,7 @@ pub(crate) fn data_from_obkv_documents(
|
|||||||
faceted_fields: HashSet<FieldId>,
|
faceted_fields: HashSet<FieldId>,
|
||||||
primary_key_id: FieldId,
|
primary_key_id: FieldId,
|
||||||
geo_fields_ids: Option<(FieldId, FieldId)>,
|
geo_fields_ids: Option<(FieldId, FieldId)>,
|
||||||
|
vectors_field_id: Option<FieldId>,
|
||||||
stop_words: Option<fst::Set<&[u8]>>,
|
stop_words: Option<fst::Set<&[u8]>>,
|
||||||
max_positions_per_attributes: Option<u32>,
|
max_positions_per_attributes: Option<u32>,
|
||||||
exact_attributes: HashSet<FieldId>,
|
exact_attributes: HashSet<FieldId>,
|
||||||
@ -69,6 +72,7 @@ pub(crate) fn data_from_obkv_documents(
|
|||||||
&faceted_fields,
|
&faceted_fields,
|
||||||
primary_key_id,
|
primary_key_id,
|
||||||
geo_fields_ids,
|
geo_fields_ids,
|
||||||
|
vectors_field_id,
|
||||||
&stop_words,
|
&stop_words,
|
||||||
max_positions_per_attributes,
|
max_positions_per_attributes,
|
||||||
)
|
)
|
||||||
@ -279,6 +283,7 @@ fn send_and_extract_flattened_documents_data(
|
|||||||
faceted_fields: &HashSet<FieldId>,
|
faceted_fields: &HashSet<FieldId>,
|
||||||
primary_key_id: FieldId,
|
primary_key_id: FieldId,
|
||||||
geo_fields_ids: Option<(FieldId, FieldId)>,
|
geo_fields_ids: Option<(FieldId, FieldId)>,
|
||||||
|
vectors_field_id: Option<FieldId>,
|
||||||
stop_words: &Option<fst::Set<&[u8]>>,
|
stop_words: &Option<fst::Set<&[u8]>>,
|
||||||
max_positions_per_attributes: Option<u32>,
|
max_positions_per_attributes: Option<u32>,
|
||||||
) -> Result<(
|
) -> Result<(
|
||||||
@ -307,6 +312,25 @@ fn send_and_extract_flattened_documents_data(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(vectors_field_id) = vectors_field_id {
|
||||||
|
let documents_chunk_cloned = flattened_documents_chunk.clone();
|
||||||
|
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
||||||
|
rayon::spawn(move || {
|
||||||
|
let result = extract_vector_points(
|
||||||
|
documents_chunk_cloned,
|
||||||
|
indexer,
|
||||||
|
primary_key_id,
|
||||||
|
vectors_field_id,
|
||||||
|
);
|
||||||
|
let _ = match result {
|
||||||
|
Ok(vector_points) => {
|
||||||
|
lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))
|
||||||
|
}
|
||||||
|
Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let (docid_word_positions_chunk, docid_fid_facet_values_chunks): (Result<_>, Result<_>) =
|
let (docid_word_positions_chunk, docid_fid_facet_values_chunks): (Result<_>, Result<_>) =
|
||||||
rayon::join(
|
rayon::join(
|
||||||
|| {
|
|| {
|
||||||
|
@ -304,6 +304,8 @@ where
|
|||||||
}
|
}
|
||||||
None => None,
|
None => None,
|
||||||
};
|
};
|
||||||
|
// get the fid of the `_vectors` field.
|
||||||
|
let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors");
|
||||||
|
|
||||||
let stop_words = self.index.stop_words(self.wtxn)?;
|
let stop_words = self.index.stop_words(self.wtxn)?;
|
||||||
let exact_attributes = self.index.exact_attributes_ids(self.wtxn)?;
|
let exact_attributes = self.index.exact_attributes_ids(self.wtxn)?;
|
||||||
@ -340,6 +342,7 @@ where
|
|||||||
faceted_fields,
|
faceted_fields,
|
||||||
primary_key_id,
|
primary_key_id,
|
||||||
geo_fields_ids,
|
geo_fields_ids,
|
||||||
|
vectors_field_id,
|
||||||
stop_words,
|
stop_words,
|
||||||
max_positions_per_attributes,
|
max_positions_per_attributes,
|
||||||
exact_attributes,
|
exact_attributes,
|
||||||
|
@ -4,20 +4,27 @@ use std::convert::TryInto;
|
|||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io;
|
use std::io;
|
||||||
|
|
||||||
|
use bytemuck::allocation::pod_collect_to_vec;
|
||||||
use charabia::{Language, Script};
|
use charabia::{Language, Script};
|
||||||
use grenad::MergerBuilder;
|
use grenad::MergerBuilder;
|
||||||
use heed::types::ByteSlice;
|
use heed::types::ByteSlice;
|
||||||
use heed::RwTxn;
|
use heed::RwTxn;
|
||||||
|
use hnsw::Searcher;
|
||||||
use roaring::RoaringBitmap;
|
use roaring::RoaringBitmap;
|
||||||
|
use space::KnnPoints;
|
||||||
|
|
||||||
use super::helpers::{
|
use super::helpers::{
|
||||||
self, merge_ignore_values, serialize_roaring_bitmap, valid_lmdb_key, CursorClonableMmap,
|
self, merge_ignore_values, serialize_roaring_bitmap, valid_lmdb_key, CursorClonableMmap,
|
||||||
};
|
};
|
||||||
use super::{ClonableMmap, MergeFn};
|
use super::{ClonableMmap, MergeFn};
|
||||||
|
use crate::error::UserError;
|
||||||
use crate::facet::FacetType;
|
use crate::facet::FacetType;
|
||||||
use crate::update::facet::FacetsUpdate;
|
use crate::update::facet::FacetsUpdate;
|
||||||
use crate::update::index_documents::helpers::as_cloneable_grenad;
|
use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at};
|
||||||
use crate::{lat_lng_to_xyz, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result};
|
use crate::{
|
||||||
|
lat_lng_to_xyz, normalize_vector, CboRoaringBitmapCodec, DocumentId, GeoPoint, Index, Result,
|
||||||
|
BEU32,
|
||||||
|
};
|
||||||
|
|
||||||
pub(crate) enum TypedChunk {
|
pub(crate) enum TypedChunk {
|
||||||
FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
|
FieldIdDocidFacetStrings(grenad::Reader<CursorClonableMmap>),
|
||||||
@ -38,6 +45,7 @@ pub(crate) enum TypedChunk {
|
|||||||
FieldIdFacetIsNullDocids(grenad::Reader<File>),
|
FieldIdFacetIsNullDocids(grenad::Reader<File>),
|
||||||
FieldIdFacetIsEmptyDocids(grenad::Reader<File>),
|
FieldIdFacetIsEmptyDocids(grenad::Reader<File>),
|
||||||
GeoPoints(grenad::Reader<File>),
|
GeoPoints(grenad::Reader<File>),
|
||||||
|
VectorPoints(grenad::Reader<File>),
|
||||||
ScriptLanguageDocids(HashMap<(Script, Language), RoaringBitmap>),
|
ScriptLanguageDocids(HashMap<(Script, Language), RoaringBitmap>),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -221,6 +229,40 @@ pub(crate) fn write_typed_chunk_into_index(
|
|||||||
index.put_geo_rtree(wtxn, &rtree)?;
|
index.put_geo_rtree(wtxn, &rtree)?;
|
||||||
index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
|
index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
|
||||||
}
|
}
|
||||||
|
TypedChunk::VectorPoints(vector_points) => {
|
||||||
|
let mut hnsw = index.vector_hnsw(wtxn)?.unwrap_or_default();
|
||||||
|
let mut searcher = Searcher::new();
|
||||||
|
|
||||||
|
let mut expected_dimensions = match index.vector_id_docid.iter(wtxn)?.next() {
|
||||||
|
Some(result) => {
|
||||||
|
let (vector_id, _) = result?;
|
||||||
|
Some(hnsw.get_point(vector_id.get() as usize).len())
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut cursor = vector_points.into_cursor()?;
|
||||||
|
while let Some((key, value)) = cursor.move_on_next()? {
|
||||||
|
// convert the key back to a u32 (4 bytes)
|
||||||
|
let (left, _index) = try_split_array_at(key).unwrap();
|
||||||
|
let docid = DocumentId::from_be_bytes(left);
|
||||||
|
// convert the vector back to a Vec<f32>
|
||||||
|
let vector: Vec<f32> = pod_collect_to_vec(value);
|
||||||
|
|
||||||
|
// TODO Inform the user about the document that has a wrong `_vectors`
|
||||||
|
let found = vector.len();
|
||||||
|
let expected = *expected_dimensions.get_or_insert(found);
|
||||||
|
if expected != found {
|
||||||
|
return Err(UserError::InvalidVectorDimensions { expected, found })?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let vector = normalize_vector(vector);
|
||||||
|
let vector_id = hnsw.insert(vector, &mut searcher) as u32;
|
||||||
|
index.vector_id_docid.put(wtxn, &BEU32::new(vector_id), &BEU32::new(docid))?;
|
||||||
|
}
|
||||||
|
log::debug!("There are {} entries in the HNSW so far", hnsw.len());
|
||||||
|
index.put_vector_hnsw(wtxn, &hnsw)?;
|
||||||
|
}
|
||||||
TypedChunk::ScriptLanguageDocids(hash_pair) => {
|
TypedChunk::ScriptLanguageDocids(hash_pair) => {
|
||||||
let mut buffer = Vec::new();
|
let mut buffer = Vec::new();
|
||||||
for (key, value) in hash_pair {
|
for (key, value) in hash_pair {
|
||||||
|
Loading…
x
Reference in New Issue
Block a user