From 21bcf32109870b2db8cefc1803910d5771eecc8c Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 15 Nov 2023 15:43:57 +0100 Subject: [PATCH 01/28] Add candle and hg_hub, updating a lot of deps in the process --- Cargo.lock | 699 ++++++++++++++++++++++++++++++++++++++++++++--- milli/Cargo.toml | 5 + 2 files changed, 673 insertions(+), 31 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fda5f2493..f6ce4b26b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -280,9 +280,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.0.2" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] @@ -310,16 +310,15 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.3.2" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163" +checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" dependencies = [ "anstyle", "anstyle-parse", "anstyle-query", "anstyle-wincon", "colorchoice", - "is-terminal", "utf8parse", ] @@ -349,9 +348,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "1.0.1" +version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" +checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" dependencies = [ "anstyle", "windows-sys 0.48.0", @@ -589,9 +588,9 @@ checksum = "2c676a478f63e9fa2dd5368a42f28bba0d6c560b775f38583c8bbaa7fcd67c9c" [[package]] name = "bytemuck" -version = "1.13.1" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" dependencies = [ "bytemuck_derive", ] @@ -659,6 +658,58 @@ dependencies = [ "displaydoc", ] +[[package]] +name = "candle-core" +version = "0.3.1" +source = "git+https://github.com/huggingface/candle.git#f4fcf6090045ac44122fd5f0a7e46db6e3e16528" +dependencies = [ + "byteorder", + "gemm", + "half 2.3.1", + "memmap2", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror", + "yoke", + "zip", +] + +[[package]] +name = "candle-nn" +version = "0.3.1" +source = "git+https://github.com/huggingface/candle.git#f4fcf6090045ac44122fd5f0a7e46db6e3e16528" +dependencies = [ + "candle-core", + "half 2.3.1", + "num-traits", + "rayon", + "safetensors", + "serde", + "thiserror", +] + +[[package]] +name = "candle-transformers" +version = "0.3.1" +source = "git+https://github.com/huggingface/candle.git#f4fcf6090045ac44122fd5f0a7e46db6e3e16528" +dependencies = [ + "byteorder", + "candle-core", + "candle-nn", + "num-traits", + "rand", + "rayon", + "serde", + "serde_json", + "serde_plain", + "tracing", + "wav", +] + [[package]] name = "cargo_toml" version = "0.15.3" @@ -765,7 +816,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" dependencies = [ "ciborium-io", - "half", + "half 1.8.2", ] [[package]] @@ -780,20 +831,19 @@ dependencies = [ [[package]] name = "clap" -version = "4.3.21" +version = "4.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c27cdf28c0f604ba3f512b0c9a409f8de8513e4816705deb0498b627e7c3a3fd" +checksum = "2275f18819641850fa26c89acc84d465c1bf91ce57bc2748b28c420473352f64" dependencies = [ "clap_builder", "clap_derive", - "once_cell", ] [[package]] name = "clap_builder" -version = "4.3.21" +version = "4.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08a9f1ab5e9f01a9b81f202e8562eb9a10de70abf9eaeac1be465c28b75aa4aa" +checksum = "07cdf1b148b25c1e1f7a42225e30a0d99a615cd4637eae7365548dd4529b95bc" dependencies = [ "anstream", "anstyle", @@ -803,9 +853,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.3.12" +version = "4.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a9bb5758fc5dfe728d1019941681eccaf0cf8a4189b692a0ee2f2ecf90a050" +checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" dependencies = [ "heck", "proc-macro2", @@ -815,9 +865,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" [[package]] name = "cobs" @@ -851,6 +901,7 @@ dependencies = [ "encode_unicode", "lazy_static", "libc", + "unicode-width", "windows-sys 0.45.0", ] @@ -892,6 +943,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "core-foundation" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.4" @@ -1040,6 +1101,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -1226,6 +1293,15 @@ dependencies = [ "subtle", ] +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + [[package]] name = "dirs-next" version = "1.0.2" @@ -1236,6 +1312,18 @@ dependencies = [ "dirs-sys-next", ] +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + [[package]] name = "dirs-sys-next" version = "0.1.2" @@ -1292,6 +1380,16 @@ dependencies = [ "uuid 1.5.0", ] +[[package]] +name = "dyn-stack" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" +dependencies = [ + "bytemuck", + "reborrow", +] + [[package]] name = "either" version = "1.9.0" @@ -1455,6 +1553,15 @@ dependencies = [ "libc", ] +[[package]] +name = "esaxx-rs" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" +dependencies = [ + "cc", +] + [[package]] name = "fastrand" version = "2.0.0" @@ -1551,6 +1658,21 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.0" @@ -1677,6 +1799,123 @@ dependencies = [ "byteorder", ] +[[package]] +name = "gemm" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b3afa707040531a7527477fd63a81ea4f6f3d26037a2f96776e57fb843b258e" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f16", + "gemm-f32", + "gemm-f64", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cc3973a4c30c73f26a099113953d0c772bb17ee2e07976c0a06b8fe1f38a57d" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30362894b93dada374442cb2edf4512ddf19513c9bec88e06a445bcb6b22e64f" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "988499faa80566b046b4fee2c5f15af55b5a20c1fe8486b112ebb34efa045ad6" +dependencies = [ + "bytemuck", + "dyn-stack", + "half 2.3.1", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f16" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6cf2854a12371684c38d9a865063a27661812a3ff5803454c5742e8f5a388ce" +dependencies = [ + "dyn-stack", + "gemm-common", + "gemm-f32", + "half 2.3.1", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bc84003cf6d950a7c7ca714ad6db281b6cef5c7d462f5cd9ad90ea2409c7227" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35187ef101a71eed0ecd26fb4a6255b4192a12f1c5335f3a795698f2d9b6cf33" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1778,6 +2017,20 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", + "num-traits", + "rand", + "rand_distr", +] + [[package]] name = "hash32" version = "0.2.1" @@ -1871,6 +2124,23 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hf-hub" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" +dependencies = [ + "dirs", + "indicatif", + "log", + "native-tls", + "rand", + "serde", + "serde_json", + "thiserror", + "ureq", +] + [[package]] name = "hmac" version = "0.12.1" @@ -2507,6 +2777,19 @@ dependencies = [ "serde", ] +[[package]] +name = "indicatif" +version = "0.17.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + [[package]] name = "inout" version = "0.1.3" @@ -3057,6 +3340,22 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b8c72594ac26bfd34f2d99dfced2edfaddfe8a476e3ff2ca0eb293d925c4f83" +[[package]] +name = "macro_rules_attribute" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a82271f7bc033d84bbca59a3ce3e4159938cb08a9c3aebbe54d215131518a13" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568" + [[package]] name = "manifest-dir-macros" version = "0.1.17" @@ -3256,6 +3555,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" dependencies = [ "libc", + "stable_deref_trait", ] [[package]] @@ -3277,6 +3577,9 @@ dependencies = [ "bstr", "bytemuck", "byteorder", + "candle-core", + "candle-nn", + "candle-transformers", "charabia", "concat-arrays", "crossbeam-channel", @@ -3290,6 +3593,7 @@ dependencies = [ "geoutils", "grenad", "heed", + "hf-hub", "indexmap 2.0.0", "insta", "instant-distance", @@ -3321,6 +3625,7 @@ dependencies = [ "tempfile", "thiserror", "time", + "tokenizers", "uuid 1.5.0", ] @@ -3376,6 +3681,45 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "monostate" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15f370ae88093ec6b11a710dec51321a61d420fafd1bad6e30d01bd9c920e8ee" +dependencies = [ + "monostate-impl", + "serde", +] + +[[package]] +name = "monostate-impl" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "371717c0a5543d6a800cac822eac735aa7d2d2fbb41002e9856a4089532dbdce" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + +[[package]] +name = "native-tls" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" +dependencies = [ + "lazy_static", + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nelson" version = "0.1.0" @@ -3422,6 +3766,16 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "bytemuck", + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.45" @@ -3452,6 +3806,12 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.31.1" @@ -3473,12 +3833,84 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +[[package]] +name = "onig" +version = "6.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f" +dependencies = [ + "bitflags 1.3.2", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "oorandom" version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "openssl" +version = "0.10.59" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a257ad03cd8fb16ad4172fedf8094451e1af1c4b70097636ef2eac9a5f0cc33" +dependencies = [ + "bitflags 2.3.3", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "openssl-sys" +version = "0.9.95" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40a4130519a360279579c2053038317e40eff64d13fd3f004f9e1b72b8a6aaf9" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "ordered-float" version = "3.7.0" @@ -3755,6 +4187,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bccab0e7fd7cc19f820a1c8c91720af652d0c88dc9664dd72aef2614f04af3b" + [[package]] name = "postcard" version = "1.0.8" @@ -3864,6 +4302,18 @@ dependencies = [ "serde", ] +[[package]] +name = "pulp" +version = "0.18.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7057c1435edb390ebfc51743abad043377f1f698ce8e649a9b52a4b378be5e4d" +dependencies = [ + "bytemuck", + "libm", + "num-complex", + "reborrow", +] + [[package]] name = "quote" version = "1.0.32" @@ -3903,6 +4353,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "rand_pcg" version = "0.3.1" @@ -3914,27 +4374,51 @@ dependencies = [ ] [[package]] -name = "rayon" -version = "1.7.0" +name = "raw-cpuid" +version = "10.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "rayon" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" dependencies = [ "either", "rayon-core", ] [[package]] -name = "rayon-core" -version = "1.11.0" +name = "rayon-cond" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" +dependencies = [ + "either", + "itertools 0.11.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" dependencies = [ - "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "num_cpus", ] +[[package]] +name = "reborrow" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + [[package]] name = "redox_syscall" version = "0.2.16" @@ -4047,6 +4531,12 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c31b5c4033f8fdde8700e4657be2c497e7288f01515be52168c631e2e4d4086" +[[package]] +name = "riff" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9b1a3d5f46d53f4a3478e2be4a5a5ce5108ea58b100dcd139830eae7f79a3a1" + [[package]] name = "ring" version = "0.16.20" @@ -4193,6 +4683,16 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +[[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "same-file" version = "1.0.6" @@ -4202,6 +4702,15 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" +dependencies = [ + "windows-sys 0.48.0", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -4218,6 +4727,29 @@ dependencies = [ "untrusted", ] +[[package]] +name = "security-framework" +version = "2.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "segment" version = "0.2.2" @@ -4238,6 +4770,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0293b4b29daaf487284529cc2f5675b8e57c61f70167ba415a463651fd6a918" +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + [[package]] name = "serde" version = "1.0.190" @@ -4288,6 +4826,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_plain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" +dependencies = [ + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.3" @@ -4445,6 +4992,18 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -4582,18 +5141,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.44" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "611040a08a0439f8248d1990b111c95baa9c704c805fa1f62104b39655fd7f90" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.44" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "090198534930841fab3a5d1bb637cde49e339654e606195f8d9c76eeb081dc96" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", @@ -4665,6 +5224,38 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" +[[package]] +name = "tokenizers" +version = "0.14.1" +source = "git+https://github.com/huggingface/tokenizers.git?tag=v0.14.1#6357206cdcce4d78ffb1e0372feb456caea09375" +dependencies = [ + "aho-corasick", + "clap", + "derive_builder", + "esaxx-rs", + "getrandom", + "indicatif", + "itertools 0.11.0", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + [[package]] name = "tokio" version = "1.29.1" @@ -4791,9 +5382,21 @@ dependencies = [ "cfg-if", "log", "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "tracing-core" version = "0.1.31" @@ -4860,18 +5463,39 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + [[package]] name = "unicode-segmentation" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +[[package]] +name = "unicode-width" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + [[package]] name = "unicode-xid" version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "untrusted" version = "0.7.1" @@ -4885,10 +5509,14 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" dependencies = [ "base64 0.21.2", + "flate2", "log", + "native-tls", "once_cell", "rustls 0.21.6", "rustls-webpki 0.100.2", + "serde", + "serde_json", "url", "webpki-roots 0.23.1", ] @@ -5083,6 +5711,15 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +[[package]] +name = "wav" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a65e199c799848b4f997072aa4d673c034f80f40191f97fe2f0a23f410be1609" +dependencies = [ + "riff", +] + [[package]] name = "web-sys" version = "0.3.64" diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 8aa2a6f3f..0c1c5ab97 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -72,6 +72,11 @@ puffin = "0.16.0" log = "0.4.17" logging_timer = "1.1.0" csv = "1.2.1" +candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } +candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } +candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } +tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" } +hf-hub = "0.3.2" [dev-dependencies] mimalloc = { version = "0.1.37", default-features = false } From 13c2c6c16beda22942029326348db0e9929df421 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 15 Nov 2023 15:46:37 +0100 Subject: [PATCH 02/28] Small commit to add hybrid search and autoembedding --- Cargo.lock | 281 +++++++++--- dump/src/lib.rs | 1 + dump/src/reader/compat/v5_to_v6.rs | 1 + index-scheduler/src/batch.rs | 9 + index-scheduler/src/features.rs | 4 +- index-scheduler/src/insta_snapshot.rs | 1 + index-scheduler/src/lib.rs | 41 ++ meilisearch-types/src/error.rs | 8 +- meilisearch-types/src/settings.rs | 23 + .../src/analytics/segment_analytics.rs | 2 +- meilisearch/src/main.rs | 6 +- .../src/routes/indexes/facet_search.rs | 3 +- meilisearch/src/routes/indexes/search.rs | 40 +- meilisearch/src/routes/multi_search.rs | 3 + meilisearch/src/search.rs | 31 +- milli/Cargo.toml | 26 +- milli/examples/search.rs | 11 +- milli/src/error.rs | 28 ++ milli/src/index.rs | 29 ++ milli/src/lib.rs | 8 +- milli/src/prompt/context.rs | 97 ++++ milli/src/prompt/document.rs | 131 ++++++ milli/src/prompt/error.rs | 56 +++ milli/src/prompt/fields.rs | 172 ++++++++ milli/src/prompt/mod.rs | 144 ++++++ milli/src/prompt/template_checker.rs | 282 ++++++++++++ milli/src/score_details.rs | 164 ++++++- milli/src/search/hybrid.rs | 336 ++++++++++++++ milli/src/search/mod.rs | 102 ++++- milli/src/search/new/matches/mod.rs | 8 +- milli/src/search/new/mod.rs | 175 +++++--- milli/src/search/new/vector_sort.rs | 150 +++++++ .../extract/extract_vector_points.rs | 330 ++++++++++++-- .../src/update/index_documents/extract/mod.rs | 63 ++- milli/src/update/index_documents/mod.rs | 35 +- .../src/update/index_documents/typed_chunk.rs | 90 +++- milli/src/update/settings.rs | 113 ++++- milli/src/vector/error.rs | 229 ++++++++++ milli/src/vector/hf.rs | 192 ++++++++ milli/src/vector/mod.rs | 142 ++++++ milli/src/vector/openai.rs | 416 ++++++++++++++++++ milli/src/vector/settings.rs | 308 +++++++++++++ 42 files changed, 4045 insertions(+), 246 deletions(-) create mode 100644 milli/src/prompt/context.rs create mode 100644 milli/src/prompt/document.rs create mode 100644 milli/src/prompt/error.rs create mode 100644 milli/src/prompt/fields.rs create mode 100644 milli/src/prompt/mod.rs create mode 100644 milli/src/prompt/template_checker.rs create mode 100644 milli/src/search/hybrid.rs create mode 100644 milli/src/search/new/vector_sort.rs create mode 100644 milli/src/vector/error.rs create mode 100644 milli/src/vector/hf.rs create mode 100644 milli/src/vector/mod.rs create mode 100644 milli/src/vector/openai.rs create mode 100644 milli/src/vector/settings.rs diff --git a/Cargo.lock b/Cargo.lock index f6ce4b26b..a407244b1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,7 +46,7 @@ dependencies = [ "actix-tls", "actix-utils", "ahash 0.8.3", - "base64 0.21.2", + "base64 0.21.5", "bitflags 1.3.2", "brotli", "bytes", @@ -120,7 +120,7 @@ dependencies = [ "futures-util", "mio", "num_cpus", - "socket2", + "socket2 0.4.9", "tokio", "tracing", ] @@ -201,7 +201,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "smallvec", - "socket2", + "socket2 0.4.9", "time", "url", ] @@ -365,6 +365,12 @@ dependencies = [ "backtrace", ] +[[package]] +name = "anymap2" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" + [[package]] name = "arbitrary" version = "1.3.0" @@ -455,9 +461,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.2" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "base64ct" @@ -508,6 +514,21 @@ dependencies = [ "serde", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -555,12 +576,12 @@ dependencies = [ [[package]] name = "bstr" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6798148dccfbff0fae41c7574d2fa8f1ef3492fba0face179de5d8d447d67b05" +checksum = "542f33a8835a0884b006a0c3df3dadd99c0c3f296ed26c2fdc8028e01ad6230c" dependencies = [ "memchr", - "regex-automata 0.3.6", + "regex-automata 0.4.3", "serde", ] @@ -1346,6 +1367,12 @@ dependencies = [ "syn 2.0.28", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "doxygen-rs" version = "0.2.2" @@ -1562,6 +1589,16 @@ dependencies = [ "cc", ] +[[package]] +name = "fancy-regex" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +dependencies = [ + "bit-set", + "regex", +] + [[package]] name = "fastrand" version = "2.0.0" @@ -1690,9 +1727,9 @@ checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" [[package]] name = "futures" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" dependencies = [ "futures-channel", "futures-core", @@ -1705,9 +1742,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" dependencies = [ "futures-core", "futures-sink", @@ -1715,15 +1752,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" dependencies = [ "futures-core", "futures-task", @@ -1732,15 +1769,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", @@ -1749,21 +1786,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" dependencies = [ "futures-channel", "futures-core", @@ -2207,7 +2244,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.4.9", "tokio", "tower-service", "tracing", @@ -2949,7 +2986,7 @@ version = "8.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "pem", "ring", "serde", @@ -2957,6 +2994,16 @@ dependencies = [ "simple_asn1", ] +[[package]] +name = "kstring" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3066350882a1cd6d950d055997f379ac37fd39f81cd4d8ed186032eb3c5747" +dependencies = [ + "serde", + "static_assertions", +] + [[package]] name = "language-tags" version = "0.3.2" @@ -2980,9 +3027,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.147" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "libgit2-sys" @@ -3251,6 +3298,63 @@ version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" +[[package]] +name = "liquid" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f68ae1011499ae2ef879f631891f21c78e309755f4a5e483c4a8f12e10b609" +dependencies = [ + "doc-comment", + "liquid-core", + "liquid-derive", + "liquid-lib", + "serde", +] + +[[package]] +name = "liquid-core" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79e0724dfcaad5cfb7965ea0f178ca0870b8d7315178f4a7179f5696f7f04d5f" +dependencies = [ + "anymap2", + "itertools 0.10.5", + "kstring", + "liquid-derive", + "num-traits", + "pest", + "pest_derive", + "regex", + "serde", + "time", +] + +[[package]] +name = "liquid-derive" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2fb41a9bb4257a3803154bdf7e2df7d45197d1941c9b1a90ad815231630721" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + +[[package]] +name = "liquid-lib" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2a17e273a6fb1fb6268f7a5867ddfd0bd4683c7e19b51084f3d567fad4348c0" +dependencies = [ + "itertools 0.10.5", + "liquid-core", + "once_cell", + "percent-encoding", + "regex", + "time", + "unicode-segmentation", +] + [[package]] name = "litemap" version = "0.6.1" @@ -3483,7 +3587,7 @@ dependencies = [ name = "meilisearch-auth" version = "1.5.1" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "enum-iterator", "hmac", "maplit", @@ -3544,9 +3648,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "memmap2" @@ -3589,6 +3693,7 @@ dependencies = [ "filter-parser", "flatten-serde-json", "fst", + "futures", "fxhash", "geoutils", "grenad", @@ -3600,6 +3705,7 @@ dependencies = [ "itertools 0.11.0", "json-depth-checker", "levenshtein_automata", + "liquid", "log", "logging_timer", "maplit", @@ -3607,6 +3713,7 @@ dependencies = [ "meili-snap", "memmap2", "mimalloc", + "nolife", "obkv", "once_cell", "ordered-float", @@ -3614,6 +3721,7 @@ dependencies = [ "rand", "rand_pcg", "rayon", + "reqwest", "roaring", "rstar", "serde", @@ -3624,8 +3732,10 @@ dependencies = [ "smartstring", "tempfile", "thiserror", + "tiktoken-rs", "time", "tokenizers", + "tokio", "uuid 1.5.0", ] @@ -3671,9 +3781,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" dependencies = [ "libc", "log", @@ -3725,6 +3835,12 @@ name = "nelson" version = "0.1.0" source = "git+https://github.com/meilisearch/nelson.git?rev=675f13885548fb415ead8fbb447e9e6d9314000a#675f13885548fb415ead8fbb447e9e6d9314000a" +[[package]] +name = "nolife" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52aaf087e8a52e7a2692f83f2dac6ac7ff9d0136bf9c6ac496635cfe3e50dc" + [[package]] name = "nom" version = "7.1.3" @@ -4480,6 +4596,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-automata" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" + [[package]] name = "regex-syntax" version = "0.7.4" @@ -4488,11 +4610,11 @@ checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" [[package]] name = "reqwest" -version = "0.11.18" +version = "0.11.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "bytes", "encoding_rs", "futures-core", @@ -4514,6 +4636,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "system-configuration", "tokio", "tokio-rustls 0.24.1", "tower-service", @@ -4521,7 +4644,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots 0.22.6", + "webpki-roots 0.25.3", "winreg", ] @@ -4582,6 +4705,12 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.4.0" @@ -4648,7 +4777,7 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", ] [[package]] @@ -4977,6 +5106,16 @@ dependencies = [ "winapi", ] +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + [[package]] name = "spin" version = "0.5.2" @@ -5097,6 +5236,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tar" version = "0.4.40" @@ -5159,6 +5319,21 @@ dependencies = [ "syn 2.0.28", ] +[[package]] +name = "tiktoken-rs" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4427b6b1c6b38215b92dd47a83a0ecc6735573d0a5a4c14acc0ac5b33b28adb" +dependencies = [ + "anyhow", + "base64 0.21.5", + "bstr", + "fancy-regex", + "lazy_static", + "parking_lot", + "rustc-hash", +] + [[package]] name = "time" version = "0.3.30" @@ -5258,11 +5433,10 @@ dependencies = [ [[package]] name = "tokio" -version = "1.29.1" +version = "1.34.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da" +checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" dependencies = [ - "autocfg", "backtrace", "bytes", "libc", @@ -5271,16 +5445,16 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.5.5", "tokio-macros", "windows-sys 0.48.0", ] [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", @@ -5508,7 +5682,7 @@ version = "2.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "flate2", "log", "native-tls", @@ -5758,6 +5932,12 @@ dependencies = [ "rustls-webpki 0.100.2", ] +[[package]] +name = "webpki-roots" +version = "0.25.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" + [[package]] name = "whatlang" version = "0.16.2" @@ -5942,11 +6122,12 @@ dependencies = [ [[package]] name = "winreg" -version = "0.10.1" +version = "0.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ - "winapi", + "cfg-if", + "windows-sys 0.48.0", ] [[package]] diff --git a/dump/src/lib.rs b/dump/src/lib.rs index 15b281c41..be0053a7c 100644 --- a/dump/src/lib.rs +++ b/dump/src/lib.rs @@ -276,6 +276,7 @@ pub(crate) mod test { ), }), pagination: Setting::NotSet, + embedders: Setting::NotSet, _kind: std::marker::PhantomData, }; settings.check() diff --git a/dump/src/reader/compat/v5_to_v6.rs b/dump/src/reader/compat/v5_to_v6.rs index 8a0d6e5e1..9351ae70d 100644 --- a/dump/src/reader/compat/v5_to_v6.rs +++ b/dump/src/reader/compat/v5_to_v6.rs @@ -378,6 +378,7 @@ impl From> for v6::Settings { v5::Setting::Reset => v6::Setting::Reset, v5::Setting::NotSet => v6::Setting::NotSet, }, + embedders: v6::Setting::NotSet, _kind: std::marker::PhantomData, } } diff --git a/index-scheduler/src/batch.rs b/index-scheduler/src/batch.rs index 94a8b3f07..cf8544ae7 100644 --- a/index-scheduler/src/batch.rs +++ b/index-scheduler/src/batch.rs @@ -1202,6 +1202,10 @@ impl IndexScheduler { let config = IndexDocumentsConfig { update_method: method, ..Default::default() }; + let embedder_configs = index.embedding_configs(index_wtxn)?; + // TODO: consider Arc'ing the map too (we only need read access + we'll be cloning it multiple times, so really makes sense) + let embedders = self.embedders(embedder_configs)?; + let mut builder = milli::update::IndexDocuments::new( index_wtxn, index, @@ -1220,6 +1224,8 @@ impl IndexScheduler { let (new_builder, user_result) = builder.add_documents(reader)?; builder = new_builder; + builder = builder.with_embedders(embedders.clone()); + let received_documents = if let Some(Details::DocumentAdditionOrUpdate { received_documents, @@ -1345,6 +1351,9 @@ impl IndexScheduler { for (task, (_, settings)) in tasks.iter_mut().zip(settings) { let checked_settings = settings.clone().check(); + if matches!(checked_settings.embedders, milli::update::Setting::Set(_)) { + self.features().check_vector("Passing `embedders` in settings")? + } if checked_settings.proximity_precision.set().is_some() { self.features.features().check_proximity_precision()?; } diff --git a/index-scheduler/src/features.rs b/index-scheduler/src/features.rs index ae2823c30..d6ce3cae4 100644 --- a/index-scheduler/src/features.rs +++ b/index-scheduler/src/features.rs @@ -56,12 +56,12 @@ impl RoFeatures { } } - pub fn check_vector(&self) -> Result<()> { + pub fn check_vector(&self, disabled_action: &'static str) -> Result<()> { if self.runtime.vector_store { Ok(()) } else { Err(FeatureNotEnabledError { - disabled_action: "Passing `vector` as a query parameter", + disabled_action, feature: "vector store", issue_link: "https://github.com/meilisearch/product/discussions/677", } diff --git a/index-scheduler/src/insta_snapshot.rs b/index-scheduler/src/insta_snapshot.rs index bd8fa5148..ddb9e934a 100644 --- a/index-scheduler/src/insta_snapshot.rs +++ b/index-scheduler/src/insta_snapshot.rs @@ -41,6 +41,7 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String { planned_failures: _, run_loop_iteration: _, currently_updating_index: _, + embedders: _, } = scheduler; let rtxn = env.read_txn().unwrap(); diff --git a/index-scheduler/src/lib.rs b/index-scheduler/src/lib.rs index a1b6497d9..fbe38a7fb 100644 --- a/index-scheduler/src/lib.rs +++ b/index-scheduler/src/lib.rs @@ -52,6 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128}; use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn}; use meilisearch_types::milli::documents::DocumentsBatchBuilder; use meilisearch_types::milli::update::IndexerConfig; +use meilisearch_types::milli::vector::{Embedder, EmbedderOptions}; use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32}; use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task}; use puffin::FrameView; @@ -341,6 +342,8 @@ pub struct IndexScheduler { /// so that a handle to the index is available from other threads (search) in an optimized manner. currently_updating_index: Arc>>, + embedders: Arc>>>, + // ================= test // The next entry is dedicated to the tests. /// Provide a way to set a breakpoint in multiple part of the scheduler. @@ -386,6 +389,7 @@ impl IndexScheduler { auth_path: self.auth_path.clone(), version_file_path: self.version_file_path.clone(), currently_updating_index: self.currently_updating_index.clone(), + embedders: self.embedders.clone(), #[cfg(test)] test_breakpoint_sdr: self.test_breakpoint_sdr.clone(), #[cfg(test)] @@ -484,6 +488,7 @@ impl IndexScheduler { auth_path: options.auth_path, version_file_path: options.version_file_path, currently_updating_index: Arc::new(RwLock::new(None)), + embedders: Default::default(), #[cfg(test)] test_breakpoint_sdr, @@ -1333,6 +1338,42 @@ impl IndexScheduler { } } + // TODO: consider using a type alias or a struct embedder/template + #[allow(clippy::type_complexity)] + pub fn embedders( + &self, + embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>, + ) -> Result, Arc)>> { + let res: Result<_> = embedding_configs + .into_iter() + .map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| { + let prompt = + Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?); + // optimistically return existing embedder + { + let embedders = self.embedders.read().unwrap(); + if let Some(embedder) = embedders.get(&embedder_options) { + return Ok((name, (embedder.clone(), prompt))); + } + } + + // add missing embedder + let embedder = Arc::new( + Embedder::new(embedder_options.clone()) + .map_err(meilisearch_types::milli::vector::Error::from) + .map_err(meilisearch_types::milli::UserError::from) + .map_err(meilisearch_types::milli::Error::from)?, + ); + { + let mut embedders = self.embedders.write().unwrap(); + embedders.insert(embedder_options, embedder.clone()); + } + Ok((name, (embedder, prompt))) + }) + .collect(); + res + } + /// Blocks the thread until the test handle asks to progress to/through this breakpoint. /// /// Two messages are sent through the channel for each breakpoint. diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index b1dc6b777..b1cc7cf82 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -256,6 +256,7 @@ InvalidSettingsProximityPrecision , InvalidRequest , BAD_REQUEST ; InvalidSettingsFaceting , InvalidRequest , BAD_REQUEST ; InvalidSettingsFilterableAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsPagination , InvalidRequest , BAD_REQUEST ; +InvalidSettingsEmbedders , InvalidRequest , BAD_REQUEST ; InvalidSettingsRankingRules , InvalidRequest , BAD_REQUEST ; InvalidSettingsSearchableAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsSortableAttributes , InvalidRequest , BAD_REQUEST ; @@ -303,7 +304,8 @@ TaskNotFound , InvalidRequest , NOT_FOUND ; TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ; UnretrievableDocument , Internal , BAD_REQUEST ; UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ; -UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE +UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ; +VectorEmbeddingError , InvalidRequest , BAD_REQUEST } impl ErrorCode for JoinError { @@ -336,6 +338,9 @@ impl ErrorCode for milli::Error { UserError::InvalidDocumentId { .. } | UserError::TooManyDocumentIds { .. } => { Code::InvalidDocumentId } + UserError::MissingDocumentField(_) => Code::InvalidDocumentFields, + UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, + UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, UserError::MultiplePrimaryKeyCandidatesFound { .. } => { Code::IndexPrimaryKeyMultipleCandidatesFound @@ -358,6 +363,7 @@ impl ErrorCode for milli::Error { UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance } + UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError, } } } diff --git a/meilisearch-types/src/settings.rs b/meilisearch-types/src/settings.rs index 487354b8e..da06d5264 100644 --- a/meilisearch-types/src/settings.rs +++ b/meilisearch-types/src/settings.rs @@ -199,6 +199,10 @@ pub struct Settings { #[deserr(default, error = DeserrJsonError)] pub pagination: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default, error = DeserrJsonError)] + pub embedders: Setting>>, + #[serde(skip)] #[deserr(skip)] pub _kind: PhantomData, @@ -222,6 +226,7 @@ impl Settings { typo_tolerance: Setting::Reset, faceting: Setting::Reset, pagination: Setting::Reset, + embedders: Setting::Reset, _kind: PhantomData, } } @@ -243,6 +248,7 @@ impl Settings { typo_tolerance, faceting, pagination, + embedders, .. } = self; @@ -262,6 +268,7 @@ impl Settings { typo_tolerance, faceting, pagination, + embedders, _kind: PhantomData, } } @@ -307,6 +314,7 @@ impl Settings { typo_tolerance: self.typo_tolerance, faceting: self.faceting, pagination: self.pagination, + embedders: self.embedders, _kind: PhantomData, } } @@ -490,6 +498,12 @@ pub fn apply_settings_to_builder( Setting::Reset => builder.reset_pagination_max_total_hits(), Setting::NotSet => (), } + + match settings.embedders.clone() { + Setting::Set(value) => builder.set_embedder_settings(value), + Setting::Reset => builder.reset_embedder_settings(), + Setting::NotSet => (), + } } pub fn settings( @@ -571,6 +585,12 @@ pub fn settings( ), }; + let embedders = index + .embedding_configs(rtxn)? + .into_iter() + .map(|(name, config)| (name, Setting::Set(config.into()))) + .collect(); + Ok(Settings { displayed_attributes: match displayed_attributes { Some(attrs) => Setting::Set(attrs), @@ -599,6 +619,7 @@ pub fn settings( typo_tolerance: Setting::Set(typo_tolerance), faceting: Setting::Set(faceting), pagination: Setting::Set(pagination), + embedders: Setting::Set(embedders), _kind: PhantomData, }) } @@ -747,6 +768,7 @@ pub(crate) mod test { typo_tolerance: Setting::NotSet, faceting: Setting::NotSet, pagination: Setting::NotSet, + embedders: Setting::NotSet, _kind: PhantomData::, }; @@ -772,6 +794,7 @@ pub(crate) mod test { typo_tolerance: Setting::NotSet, faceting: Setting::NotSet, pagination: Setting::NotSet, + embedders: Setting::NotSet, _kind: PhantomData::, }; diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index f75516731..d5f08936d 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -686,7 +686,7 @@ impl SearchAggregator { ret.max_terms_number = q.split_whitespace().count(); } - if let Some(ref vector) = vector { + if let Some(meilisearch_types::milli::VectorQuery::Vector(ref vector)) = vector { ret.max_vector_size = vector.len(); } diff --git a/meilisearch/src/main.rs b/meilisearch/src/main.rs index 246d62c3b..ddd37bbb6 100644 --- a/meilisearch/src/main.rs +++ b/meilisearch/src/main.rs @@ -19,7 +19,11 @@ static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; /// does all the setup before meilisearch is launched fn setup(opt: &Opt) -> anyhow::Result<()> { let mut log_builder = env_logger::Builder::new(); - log_builder.parse_filters(&opt.log_level.to_string()); + let log_filters = format!( + "{},h2=warn,hyper=warn,tokio_util=warn,tracing=warn,rustls=warn,mio=warn,reqwest=warn", + opt.log_level + ); + log_builder.parse_filters(&log_filters); log_builder.init(); diff --git a/meilisearch/src/routes/indexes/facet_search.rs b/meilisearch/src/routes/indexes/facet_search.rs index 142a424c0..72440711c 100644 --- a/meilisearch/src/routes/indexes/facet_search.rs +++ b/meilisearch/src/routes/indexes/facet_search.rs @@ -7,6 +7,7 @@ use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::milli::VectorQuery; use serde_json::Value; use crate::analytics::{Analytics, FacetSearchAggregator}; @@ -117,7 +118,7 @@ impl From for SearchQuery { highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), crop_marker: DEFAULT_CROP_MARKER(), matching_strategy, - vector, + vector: vector.map(VectorQuery::Vector), attributes_to_search_on, } } diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 5a0a9e92b..e63a95e60 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -2,12 +2,13 @@ use actix_web::web::Data; use actix_web::{web, HttpRequest, HttpResponse}; use deserr::actix_web::{AwebJson, AwebQueryParameter}; use index_scheduler::IndexScheduler; -use log::debug; +use log::{debug, warn}; use meilisearch_types::deserr::query_params::Param; use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::milli::VectorQuery; use meilisearch_types::serde_cs::vec::CS; use serde_json::Value; @@ -88,7 +89,7 @@ impl From for SearchQuery { Self { q: other.q, - vector: other.vector.map(CS::into_inner), + vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector), offset: other.offset.0, limit: other.limit.0, page: other.page.as_deref().copied(), @@ -193,6 +194,9 @@ pub async fn search_with_post( let index = index_scheduler.index(&index_uid)?; let features = index_scheduler.features(); + + embed(&mut query, index_scheduler.get_ref(), &index).await?; + let search_result = tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; if let Ok(ref search_result) = search_result { @@ -206,6 +210,38 @@ pub async fn search_with_post( Ok(HttpResponse::Ok().json(search_result)) } +pub async fn embed( + query: &mut SearchQuery, + index_scheduler: &IndexScheduler, + index: &meilisearch_types::milli::Index, +) -> Result<(), ResponseError> { + if let Some(VectorQuery::String(prompt)) = query.vector.take() { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedder = index_scheduler.embedders(embedder_configs)?; + + /// FIXME: add error if no embedder, remove unwrap, support multiple embedders + let embeddings = embedder + .get("default") + .unwrap() + .0 + .embed(vec![prompt]) + .await + .map_err(meilisearch_types::milli::vector::Error::from) + .map_err(meilisearch_types::milli::UserError::from) + .map_err(meilisearch_types::milli::Error::from)? + .pop() + .expect("No vector returned from embedding"); + + if embeddings.iter().nth(1).is_some() { + warn!("Ignoring embeddings past the first one in long search query"); + query.vector = Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec())); + } else { + query.vector = Some(VectorQuery::Vector(embeddings.into_inner())); + } + }; + Ok(()) +} + #[cfg(test)] mod test { use super::*; diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs index bcb8bb2a1..4e578572d 100644 --- a/meilisearch/src/routes/multi_search.rs +++ b/meilisearch/src/routes/multi_search.rs @@ -13,6 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator}; use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::{AuthenticationError, GuardedData}; use crate::extractors::sequential_extractor::SeqHandler; +use crate::routes::indexes::search::embed; use crate::search::{ add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex, }; @@ -74,6 +75,8 @@ pub async fn multi_search_with_post( }) .with_index(query_index)?; + embed(&mut query, index_scheduler.get_ref(), &index).await.with_index(query_index)?; + let search_result = tokio::task::spawn_blocking(move || perform_search(&index, query, features)) .await diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 41f073b48..235b745a9 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -16,6 +16,7 @@ use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; use meilisearch_types::milli::{ dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues, + VectorQuery, }; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; @@ -46,7 +47,7 @@ pub struct SearchQuery { #[deserr(default, error = DeserrJsonError)] pub q: Option, #[deserr(default, error = DeserrJsonError)] - pub vector: Option>, + pub vector: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -105,7 +106,7 @@ pub struct SearchQueryWithIndex { #[deserr(default, error = DeserrJsonError)] pub q: Option, #[deserr(default, error = DeserrJsonError)] - pub vector: Option>, + pub vector: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -339,11 +340,18 @@ fn prepare_search<'t>( let mut search = index.search(rtxn); if query.vector.is_some() && query.q.is_some() { - warn!("Ignoring the query string `q` when used with the `vector` parameter."); + warn!("Attempting hybrid search"); } if let Some(ref vector) = query.vector { - search.vector(vector.clone()); + match vector { + VectorQuery::Vector(vector) => { + search.vector(vector.clone()); + } + VectorQuery::String(_) => { + panic!("Failed while preparing search; caller did not generate embedding for query") + } + } } if let Some(ref query) = query.q { @@ -375,7 +383,7 @@ fn prepare_search<'t>( } if query.vector.is_some() { - features.check_vector()?; + features.check_vector("Passing `vector` as a query parameter")?; } // compute the offset on the limit depending on the pagination mode. @@ -429,7 +437,11 @@ pub fn perform_search( prepare_search(index, &rtxn, &query, features)?; let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = - search.execute()?; + if query.q.is_some() && query.vector.is_some() { + search.execute_hybrid()? + } else { + search.execute()? + }; let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); @@ -538,13 +550,13 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } - let semantic_score = match query.vector.as_ref() { + let semantic_score = /*match query.vector.as_ref() { Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? { Some(vectors) => compute_semantic_score(vector, vectors)?, None => None, }, None => None, - }; + };*/ None; let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); @@ -629,7 +641,8 @@ pub fn perform_search( hits: documents, hits_info, query: query.q.unwrap_or_default(), - vector: query.vector, + // FIXME: display input vector + vector: None, processing_time_ms: before_search.elapsed().as_millis(), facet_distribution, facet_stats, diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 0c1c5ab97..38931ca0f 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -27,10 +27,13 @@ fst = "0.4.7" fxhash = "0.2.1" geoutils = "0.5.1" grenad = { version = "0.4.5", default-features = false, features = [ - "rayon", "tempfile" + "rayon", + "tempfile", ] } heed = { version = "0.20.0-alpha.9", default-features = false, features = [ - "serde-json", "serde-bincode", "read-txn-no-tls" + "serde-json", + "serde-bincode", + "read-txn-no-tls", ] } indexmap = { version = "2.0.0", features = ["serde"] } instant-distance = { version = "0.6.1", features = ["with-serde"] } @@ -77,6 +80,15 @@ candle-transformers = { git = "https://github.com/huggingface/candle.git", versi candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" } hf-hub = "0.3.2" +tokio = { version = "1.34.0", features = ["rt"] } +futures = "0.3.29" +nolife = { version = "0.3.1" } +reqwest = { version = "0.11.16", features = [ + "rustls-tls", + "json", +], default-features = false } +tiktoken-rs = "0.5.7" +liquid = "0.26.4" [dev-dependencies] mimalloc = { version = "0.1.37", default-features = false } @@ -88,7 +100,15 @@ meili-snap = { path = "../meili-snap" } rand = { version = "0.8.5", features = ["small_rng"] } [features] -all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"] +all-tokenizations = [ + "charabia/chinese", + "charabia/hebrew", + "charabia/japanese", + "charabia/thai", + "charabia/korean", + "charabia/greek", + "charabia/khmer", +] # Use POSIX semaphores instead of SysV semaphores in LMDB # For more information on this feature, see heed's Cargo.toml diff --git a/milli/examples/search.rs b/milli/examples/search.rs index 82de56434..a94677771 100644 --- a/milli/examples/search.rs +++ b/milli/examples/search.rs @@ -5,8 +5,8 @@ use std::time::Instant; use heed::EnvOpenOptions; use milli::{ - execute_search, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, SearchLogger, - TermsMatchingStrategy, + execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, + SearchLogger, TermsMatchingStrategy, }; #[global_allocator] @@ -49,14 +49,15 @@ fn main() -> Result<(), Box> { let start = Instant::now(); let mut ctx = SearchContext::new(&index, &txn); + let universe = filtered_universe(&ctx, &None)?; + let docs = execute_search( &mut ctx, - &(!query.trim().is_empty()).then(|| query.trim().to_owned()), - &None, + (!query.trim().is_empty()).then(|| query.trim()), TermsMatchingStrategy::Last, milli::score_details::ScoringStrategy::Skip, false, - &None, + universe, &None, GeoSortStrategy::default(), 0, diff --git a/milli/src/error.rs b/milli/src/error.rs index cbbd8a3e5..032fd63a7 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -180,6 +180,14 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco UnknownInternalDocumentId { document_id: DocumentId }, #[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")] InvalidMinTypoWordLenSetting(u8, u8), + #[error(transparent)] + VectorEmbeddingError(#[from] crate::vector::Error), + #[error(transparent)] + MissingDocumentField(#[from] crate::prompt::error::RenderPromptError), + #[error(transparent)] + InvalidPrompt(#[from] crate::prompt::error::NewPromptError), + #[error("Invalid prompt in for embeddings with name '{0}': {1}")] + InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), } #[derive(Error, Debug)] @@ -336,6 +344,26 @@ impl From for Error { } } +#[derive(Debug, Clone, Copy)] +pub enum FaultSource { + User, + Runtime, + Bug, + Undecided, +} + +impl std::fmt::Display for FaultSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + FaultSource::User => "user error", + FaultSource::Runtime => "runtime error", + FaultSource::Bug => "coding error", + FaultSource::Undecided => "error", + }; + f.write_str(s) + } +} + #[test] fn conditionally_lookup_for_error_message() { let prefix = "Attribute `name` is not sortable."; diff --git a/milli/src/index.rs b/milli/src/index.rs index 01a01ac37..307d87906 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -23,6 +23,7 @@ use crate::heed_codec::{ }; use crate::proximity::ProximityPrecision; use crate::readable_slices::ReadableSlices; +use crate::vector::EmbeddingConfig; use crate::{ default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec, @@ -74,6 +75,7 @@ pub mod main_key { pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by"; pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits"; pub const PROXIMITY_PRECISION: &str = "proximity-precision"; + pub const EMBEDDING_CONFIGS: &str = "embedding_configs"; } pub mod db_name { @@ -1528,6 +1530,33 @@ impl Index { Ok(script_language) } + + pub(crate) fn put_embedding_configs( + &self, + wtxn: &mut RwTxn<'_>, + configs: Vec<(String, EmbeddingConfig)>, + ) -> heed::Result<()> { + self.main.remap_types::>>().put( + wtxn, + main_key::EMBEDDING_CONFIGS, + &configs, + ) + } + + pub(crate) fn delete_embedding_configs(&self, wtxn: &mut RwTxn<'_>) -> heed::Result { + self.main.remap_key_type::().delete(wtxn, main_key::EMBEDDING_CONFIGS) + } + + pub fn embedding_configs( + &self, + rtxn: &RoTxn<'_>, + ) -> Result> { + Ok(self + .main + .remap_types::>>() + .get(rtxn, main_key::EMBEDDING_CONFIGS)? + .unwrap_or_default()) + } } #[cfg(test)] diff --git a/milli/src/lib.rs b/milli/src/lib.rs index acea72c41..b3c15e205 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -17,11 +17,13 @@ pub mod facet; mod fields_ids_map; pub mod heed_codec; pub mod index; +pub mod prompt; pub mod proximity; mod readable_slices; pub mod score_details; mod search; pub mod update; +pub mod vector; #[cfg(test)] #[macro_use] @@ -37,8 +39,8 @@ pub use filter_parser::{Condition, FilterCondition, Span, Token}; use fxhash::{FxHasher32, FxHasher64}; pub use grenad::CompressionType; pub use search::new::{ - execute_search, DefaultSearchLogger, GeoSortStrategy, SearchContext, SearchLogger, - VisualSearchLogger, + execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, SearchContext, + SearchLogger, VisualSearchLogger, }; use serde_json::Value; pub use {charabia as tokenizer, heed}; @@ -60,7 +62,7 @@ pub use self::index::Index; pub use self::search::{ FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy, - DEFAULT_VALUES_PER_FACET, + VectorQuery, DEFAULT_VALUES_PER_FACET, }; pub type Result = std::result::Result; diff --git a/milli/src/prompt/context.rs b/milli/src/prompt/context.rs new file mode 100644 index 000000000..a28a87caa --- /dev/null +++ b/milli/src/prompt/context.rs @@ -0,0 +1,97 @@ +use liquid::model::{ + ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +use super::document::Document; +use super::fields::Fields; +use crate::FieldsIdsMap; + +#[derive(Debug, Clone)] +pub struct Context<'a> { + document: &'a Document<'a>, + fields: Fields<'a>, +} + +impl<'a> Context<'a> { + pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self { + Self { document, fields: Fields::new(document, field_id_map) } + } +} + +impl<'a> ObjectView for Context<'a> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new( + std::iter::once(self.document.as_value()) + .chain(std::iter::once(self.fields.as_value())), + ) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "doc" || index == "fields" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + match index { + "doc" => Some(self.document.as_value()), + "fields" => Some(self.fields.as_value()), + _ => None, + } + } +} + +impl<'a> ValueView for Context<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => false, + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} diff --git a/milli/src/prompt/document.rs b/milli/src/prompt/document.rs new file mode 100644 index 000000000..b5d43b5be --- /dev/null +++ b/milli/src/prompt/document.rs @@ -0,0 +1,131 @@ +use std::cell::OnceCell; +use std::collections::BTreeMap; + +use liquid::model::{ + DisplayCow, KString, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +use crate::update::del_add::{DelAdd, KvReaderDelAdd}; +use crate::FieldsIdsMap; + +#[derive(Debug, Clone)] +pub struct Document<'a>(BTreeMap<&'a str, (&'a [u8], ParsedValue)>); + +#[derive(Debug, Clone)] +struct ParsedValue(std::cell::OnceCell); + +impl ParsedValue { + fn empty() -> ParsedValue { + ParsedValue(OnceCell::new()) + } + + fn get(&self, raw: &[u8]) -> &LiquidValue { + self.0.get_or_init(|| { + let value: serde_json::Value = serde_json::from_slice(raw).unwrap(); + liquid::model::to_value(&value).unwrap() + }) + } +} + +impl<'a> Document<'a> { + pub fn new( + data: obkv::KvReaderU16<'a>, + side: DelAdd, + inverted_field_map: &'a FieldsIdsMap, + ) -> Self { + let mut out_data = BTreeMap::new(); + for (fid, raw) in data { + let obkv = KvReaderDelAdd::new(raw); + let Some(raw) = obkv.get(side) else { + continue; + }; + let Some(name) = inverted_field_map.name(fid) else { + continue; + }; + out_data.insert(name, (raw, ParsedValue::empty())); + } + Self(out_data) + } + + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn iter(&self) -> impl Iterator + '_ { + self.0.iter().map(|(&k, (raw, data))| (k.to_owned().into(), data.get(raw).to_owned())) + } +} + +impl<'a> ObjectView for Document<'a> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + self.len() as i64 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + let keys = BTreeMap::keys(&self.0).map(|&s| s.into()); + Box::new(keys) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(self.0.values().map(|(raw, v)| v.get(raw) as &dyn ValueView)) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.0.iter().map(|(&k, (raw, data))| (k.into(), data.get(raw) as &dyn ValueView))) + } + + fn contains_key(&self, index: &str) -> bool { + self.0.contains_key(index) + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + self.0.get(index).map(|(raw, v)| v.get(raw) as &dyn ValueView) + } +} + +impl<'a> ValueView for Document<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => self.is_empty(), + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object(self.iter().collect()) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} diff --git a/milli/src/prompt/error.rs b/milli/src/prompt/error.rs new file mode 100644 index 000000000..8a762b60a --- /dev/null +++ b/milli/src/prompt/error.rs @@ -0,0 +1,56 @@ +use crate::error::FaultSource; + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct NewPromptError { + pub kind: NewPromptErrorKind, + pub fault: FaultSource, +} + +impl From for crate::Error { + fn from(value: NewPromptError) -> Self { + crate::Error::UserError(crate::UserError::InvalidPrompt(value)) + } +} + +impl NewPromptError { + pub(crate) fn cannot_parse_template(inner: liquid::Error) -> NewPromptError { + Self { kind: NewPromptErrorKind::CannotParseTemplate(inner), fault: FaultSource::User } + } + + pub(crate) fn invalid_fields_in_template(inner: liquid::Error) -> NewPromptError { + Self { kind: NewPromptErrorKind::InvalidFieldsInTemplate(inner), fault: FaultSource::User } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum NewPromptErrorKind { + #[error("cannot parse template: {0}")] + CannotParseTemplate(liquid::Error), + #[error("template contains invalid fields: {0}. Only `doc.*`, `fields[i].name`, `fields[i].value` are supported")] + InvalidFieldsInTemplate(liquid::Error), +} + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct RenderPromptError { + pub kind: RenderPromptErrorKind, + pub fault: FaultSource, +} +impl RenderPromptError { + pub(crate) fn missing_context(inner: liquid::Error) -> RenderPromptError { + Self { kind: RenderPromptErrorKind::MissingContext(inner), fault: FaultSource::User } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RenderPromptErrorKind { + #[error("missing field in document: {0}")] + MissingContext(liquid::Error), +} + +impl From for crate::Error { + fn from(value: RenderPromptError) -> Self { + crate::Error::UserError(crate::UserError::MissingDocumentField(value)) + } +} diff --git a/milli/src/prompt/fields.rs b/milli/src/prompt/fields.rs new file mode 100644 index 000000000..3187485f1 --- /dev/null +++ b/milli/src/prompt/fields.rs @@ -0,0 +1,172 @@ +use liquid::model::{ + ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +use super::document::Document; +use crate::FieldsIdsMap; +#[derive(Debug, Clone)] +pub struct Fields<'a>(Vec>); + +impl<'a> Fields<'a> { + pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self { + Self( + std::iter::repeat(document) + .zip(field_id_map.iter()) + .map(|(document, (_fid, name))| FieldValue { document, name }) + .collect(), + ) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct FieldValue<'a> { + name: &'a str, + document: &'a Document<'a>, +} + +impl<'a> ValueView for FieldValue<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => self.is_empty(), + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.iter().map(|(k, v)| (k.to_string().into(), v.to_value())).collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} + +impl<'a> FieldValue<'a> { + pub fn name(&self) -> &&'a str { + &self.name + } + + pub fn value(&self) -> &dyn ValueView { + self.document.get(self.name).unwrap_or(&LiquidValue::Nil) + } + + pub fn is_empty(&self) -> bool { + self.size() == 0 + } +} + +impl<'a> ObjectView for FieldValue<'a> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["name", "value"].iter().map(|&x| KStringCow::from_static(x))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new( + std::iter::once(self.name() as &dyn ValueView).chain(std::iter::once(self.value())), + ) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "name" || index == "value" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + match index { + "name" => Some(self.name()), + "value" => Some(self.value()), + _ => None, + } + } +} + +impl<'a> ArrayView for Fields<'a> { + fn as_value(&self) -> &dyn ValueView { + self.0.as_value() + } + + fn size(&self) -> i64 { + self.0.len() as i64 + } + + fn values<'k>(&'k self) -> Box + 'k> { + self.0.values() + } + + fn contains_key(&self, index: i64) -> bool { + self.0.contains_key(index) + } + + fn get(&self, index: i64) -> Option<&dyn ValueView> { + ArrayView::get(&self.0, index) + } +} + +impl<'a> ValueView for Fields<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + self.0.render() + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + self.0.source() + } + + fn type_name(&self) -> &'static str { + self.0.type_name() + } + + fn query_state(&self, state: liquid::model::State) -> bool { + self.0.query_state(state) + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + self.0.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + self.0.to_value() + } + + fn as_array(&self) -> Option<&dyn ArrayView> { + Some(self) + } +} diff --git a/milli/src/prompt/mod.rs b/milli/src/prompt/mod.rs new file mode 100644 index 000000000..351a51bb1 --- /dev/null +++ b/milli/src/prompt/mod.rs @@ -0,0 +1,144 @@ +mod context; +mod document; +pub(crate) mod error; +mod fields; +mod template_checker; + +use std::convert::TryFrom; + +use error::{NewPromptError, RenderPromptError}; + +use self::context::Context; +use self::document::Document; +use crate::update::del_add::DelAdd; +use crate::FieldsIdsMap; + +pub struct Prompt { + template: liquid::Template, + template_text: String, + strategy: PromptFallbackStrategy, + fallback: String, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct PromptData { + pub template: String, + pub strategy: PromptFallbackStrategy, + pub fallback: String, +} + +impl From for PromptData { + fn from(value: Prompt) -> Self { + Self { template: value.template_text, strategy: value.strategy, fallback: value.fallback } + } +} + +impl TryFrom for Prompt { + type Error = NewPromptError; + + fn try_from(value: PromptData) -> Result { + Prompt::new(value.template, Some(value.strategy), Some(value.fallback)) + } +} + +impl Clone for Prompt { + fn clone(&self) -> Self { + let template_text = self.template_text.clone(); + Self { + template: new_template(&template_text).unwrap(), + template_text, + strategy: self.strategy, + fallback: self.fallback.clone(), + } + } +} + +fn new_template(text: &str) -> Result { + liquid::ParserBuilder::with_stdlib().build().unwrap().parse(text) +} + +fn default_template() -> liquid::Template { + new_template(default_template_text()).unwrap() +} + +fn default_template_text() -> &'static str { + "{% for field in fields %} \ + {{ field.name }}: {{ field.value }}\n\ + {% endfor %}" +} + +fn default_fallback() -> &'static str { + "" +} + +impl Default for Prompt { + fn default() -> Self { + Self { + template: default_template(), + template_text: default_template_text().into(), + strategy: Default::default(), + fallback: default_fallback().into(), + } + } +} + +impl Default for PromptData { + fn default() -> Self { + Self { + template: default_template_text().into(), + strategy: Default::default(), + fallback: default_fallback().into(), + } + } +} + +impl Prompt { + pub fn new( + template: String, + strategy: Option, + fallback: Option, + ) -> Result { + let this = Self { + template: liquid::ParserBuilder::with_stdlib() + .build() + .unwrap() + .parse(&template) + .map_err(NewPromptError::cannot_parse_template)?, + template_text: template, + strategy: strategy.unwrap_or_default(), + fallback: fallback.unwrap_or_default(), + }; + + // render template with special object that's OK with `doc.*` and `fields.*` + /// FIXME: doesn't work for nested objects e.g. `doc.a.b` + this.template + .render(&template_checker::TemplateChecker) + .map_err(NewPromptError::invalid_fields_in_template)?; + + Ok(this) + } + + pub fn render( + &self, + document: obkv::KvReaderU16<'_>, + side: DelAdd, + field_id_map: &FieldsIdsMap, + ) -> Result { + let document = Document::new(document, side, field_id_map); + let context = Context::new(&document, field_id_map); + + self.template.render(&context).map_err(RenderPromptError::missing_context) + } +} + +#[derive( + Debug, Default, Clone, PartialEq, Eq, Copy, serde::Serialize, serde::Deserialize, deserr::Deserr, +)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub enum PromptFallbackStrategy { + Fallback, + Skip, + #[default] + Error, +} diff --git a/milli/src/prompt/template_checker.rs b/milli/src/prompt/template_checker.rs new file mode 100644 index 000000000..641a9ed64 --- /dev/null +++ b/milli/src/prompt/template_checker.rs @@ -0,0 +1,282 @@ +use liquid::model::{ + ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +#[derive(Debug)] +pub struct TemplateChecker; + +#[derive(Debug)] +pub struct DummyDoc; + +#[derive(Debug)] +pub struct DummyFields; + +#[derive(Debug)] +pub struct DummyField; + +const DUMMY_VALUE: &LiquidValue = &LiquidValue::Nil; + +impl ObjectView for DummyField { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["name", "value"].iter().map(|s| KStringCow::from_static(s))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(std::iter::empty()) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(std::iter::empty()) + } + + fn contains_key(&self, index: &str) -> bool { + index == "name" || index == "value" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + if self.contains_key(index) { + Some(DUMMY_VALUE.as_view()) + } else { + None + } + } +} + +impl ValueView for DummyField { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + DUMMY_VALUE.render() + } + + fn source(&self) -> DisplayCow<'_> { + DUMMY_VALUE.source() + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: State) -> bool { + DUMMY_VALUE.query_state(state) + } + + fn to_kstr(&self) -> KStringCow<'_> { + DUMMY_VALUE.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Nil + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} + +impl ValueView for DummyFields { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + DUMMY_VALUE.render() + } + + fn source(&self) -> DisplayCow<'_> { + DUMMY_VALUE.source() + } + + fn type_name(&self) -> &'static str { + "array" + } + + fn query_state(&self, state: State) -> bool { + DUMMY_VALUE.query_state(state) + } + + fn to_kstr(&self) -> KStringCow<'_> { + DUMMY_VALUE.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Nil + } + + fn as_array(&self) -> Option<&dyn ArrayView> { + Some(self) + } +} + +impl ArrayView for DummyFields { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + i64::MAX + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(std::iter::empty()) + } + + fn contains_key(&self, _index: i64) -> bool { + true + } + + fn get(&self, _index: i64) -> Option<&dyn ValueView> { + Some(DummyField.as_value()) + } +} + +impl ObjectView for DummyDoc { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 1000 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(std::iter::empty()) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(std::iter::empty()) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(std::iter::empty()) + } + + fn contains_key(&self, _index: &str) -> bool { + true + } + + fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> { + Some(DUMMY_VALUE.as_view()) + } +} + +impl ValueView for DummyDoc { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + DUMMY_VALUE.render() + } + + fn source(&self) -> DisplayCow<'_> { + DUMMY_VALUE.source() + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: State) -> bool { + DUMMY_VALUE.query_state(state) + } + + fn to_kstr(&self) -> KStringCow<'_> { + DUMMY_VALUE.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Nil + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} + +impl ObjectView for TemplateChecker { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new( + std::iter::once(DummyDoc.as_value()).chain(std::iter::once(DummyFields.as_value())), + ) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "doc" || index == "fields" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + match index { + "doc" => Some(DummyDoc.as_value()), + "fields" => Some(DummyFields.as_value()), + _ => None, + } + } +} + +impl ValueView for TemplateChecker { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => false, + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} diff --git a/milli/src/score_details.rs b/milli/src/score_details.rs index 8fc998ae4..f6b9db58c 100644 --- a/milli/src/score_details.rs +++ b/milli/src/score_details.rs @@ -1,3 +1,6 @@ +use std::cmp::Ordering; + +use itertools::Itertools; use serde::Serialize; use crate::distance_between_two_points; @@ -12,9 +15,24 @@ pub enum ScoreDetails { ExactAttribute(ExactAttribute), ExactWords(ExactWords), Sort(Sort), + Vector(Vector), GeoSort(GeoSort), } +#[derive(Clone, Copy)] +pub enum ScoreValue<'a> { + Score(f64), + Sort(&'a Sort), + GeoSort(&'a GeoSort), +} + +enum RankOrValue<'a> { + Rank(Rank), + Sort(&'a Sort), + GeoSort(&'a GeoSort), + Score(f64), +} + impl ScoreDetails { pub fn local_score(&self) -> Option { self.rank().map(Rank::local_score) @@ -31,11 +49,55 @@ impl ScoreDetails { ScoreDetails::ExactWords(details) => Some(details.rank()), ScoreDetails::Sort(_) => None, ScoreDetails::GeoSort(_) => None, + ScoreDetails::Vector(_) => None, } } - pub fn global_score<'a>(details: impl Iterator) -> f64 { - Rank::global_score(details.filter_map(Self::rank)) + pub fn global_score<'a>(details: impl Iterator + 'a) -> f64 { + Self::score_values(details) + .find_map(|x| { + let ScoreValue::Score(score) = x else { + return None; + }; + Some(score) + }) + .unwrap_or(1.0f64) + } + + pub fn score_values<'a>( + details: impl Iterator + 'a, + ) -> impl Iterator> + 'a { + details + .map(ScoreDetails::rank_or_value) + .coalesce(|left, right| match (left, right) { + (RankOrValue::Rank(left), RankOrValue::Rank(right)) => { + Ok(RankOrValue::Rank(Rank::merge(left, right))) + } + (left, right) => Err((left, right)), + }) + .map(|rank_or_value| match rank_or_value { + RankOrValue::Rank(r) => ScoreValue::Score(r.local_score()), + RankOrValue::Sort(s) => ScoreValue::Sort(s), + RankOrValue::GeoSort(g) => ScoreValue::GeoSort(g), + RankOrValue::Score(s) => ScoreValue::Score(s), + }) + } + + fn rank_or_value(&self) -> RankOrValue<'_> { + match self { + ScoreDetails::Words(w) => RankOrValue::Rank(w.rank()), + ScoreDetails::Typo(t) => RankOrValue::Rank(t.rank()), + ScoreDetails::Proximity(p) => RankOrValue::Rank(*p), + ScoreDetails::Fid(f) => RankOrValue::Rank(*f), + ScoreDetails::Position(p) => RankOrValue::Rank(*p), + ScoreDetails::ExactAttribute(e) => RankOrValue::Rank(e.rank()), + ScoreDetails::ExactWords(e) => RankOrValue::Rank(e.rank()), + ScoreDetails::Sort(sort) => RankOrValue::Sort(sort), + ScoreDetails::GeoSort(geosort) => RankOrValue::GeoSort(geosort), + ScoreDetails::Vector(vector) => RankOrValue::Score( + vector.value_similarity.as_ref().map(|(_, s)| *s as f64).unwrap_or(0.0f64), + ), + } } /// Panics @@ -181,6 +243,19 @@ impl ScoreDetails { details_map.insert(sort, sort_details); order += 1; } + ScoreDetails::Vector(s) => { + let vector = format!("vectorSort({:?})", s.target_vector); + let value = s.value_similarity.as_ref().map(|(v, _)| v); + let similarity = s.value_similarity.as_ref().map(|(_, s)| s); + + let details = serde_json::json!({ + "order": order, + "value": value, + "similarity": similarity, + }); + details_map.insert(vector, details); + order += 1; + } } } details_map @@ -297,15 +372,21 @@ impl Rank { pub fn global_score(details: impl Iterator) -> f64 { let mut rank = Rank { rank: 1, max_rank: 1 }; for inner_rank in details { - rank.rank -= 1; - - rank.rank *= inner_rank.max_rank; - rank.max_rank *= inner_rank.max_rank; - - rank.rank += inner_rank.rank; + rank = Rank::merge(rank, inner_rank); } rank.local_score() } + + pub fn merge(mut outer: Rank, inner: Rank) -> Rank { + outer.rank = outer.rank.saturating_sub(1); + + outer.rank *= inner.max_rank; + outer.max_rank *= inner.max_rank; + + outer.rank += inner.rank; + + outer + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] @@ -335,13 +416,78 @@ pub struct Sort { pub value: serde_json::Value, } -#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +impl PartialOrd for Sort { + fn partial_cmp(&self, other: &Self) -> Option { + if self.field_name != other.field_name { + return None; + } + if self.ascending != other.ascending { + return None; + } + match (&self.value, &other.value) { + (serde_json::Value::Null, serde_json::Value::Null) => Some(Ordering::Equal), + (serde_json::Value::Null, _) => Some(Ordering::Less), + (_, serde_json::Value::Null) => Some(Ordering::Greater), + // numbers are always before strings + (serde_json::Value::Number(_), serde_json::Value::String(_)) => Some(Ordering::Greater), + (serde_json::Value::String(_), serde_json::Value::Number(_)) => Some(Ordering::Less), + (serde_json::Value::Number(left), serde_json::Value::Number(right)) => { + // FIXME: unwrap permitted here? + let order = left.as_f64().unwrap().partial_cmp(&right.as_f64().unwrap())?; + // 12 < 42, and when ascending, we want to see 12 first, so the smallest. + // Hence, when ascending, smaller is better + Some(if self.ascending { order.reverse() } else { order }) + } + (serde_json::Value::String(left), serde_json::Value::String(right)) => { + let order = left.cmp(right); + // Taking e.g. "a" and "z" + // "a" < "z", and when ascending, we want to see "a" first, so the smallest. + // Hence, when ascending, smaller is better + Some(if self.ascending { order.reverse() } else { order }) + } + _ => None, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] pub struct GeoSort { pub target_point: [f64; 2], pub ascending: bool, pub value: Option<[f64; 2]>, } +impl PartialOrd for GeoSort { + fn partial_cmp(&self, other: &Self) -> Option { + if self.target_point != other.target_point { + return None; + } + if self.ascending != other.ascending { + return None; + } + Some(match (self.distance(), other.distance()) { + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + (Some(left), Some(right)) => { + let order = left.partial_cmp(&right)?; + if self.ascending { + // when ascending, the one with the smallest distance has the best score + order.reverse() + } else { + order + } + } + }) + } +} + +#[derive(Debug, Clone, PartialEq, PartialOrd)] +pub struct Vector { + pub target_vector: Vec, + pub value_similarity: Option<(Vec, f32)>, +} + impl GeoSort { pub fn distance(&self) -> Option { self.value.map(|value| distance_between_two_points(&self.target_point, &value)) diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs new file mode 100644 index 000000000..02c518126 --- /dev/null +++ b/milli/src/search/hybrid.rs @@ -0,0 +1,336 @@ +use std::cmp::Ordering; +use std::collections::HashMap; + +use itertools::Itertools; +use roaring::RoaringBitmap; + +use super::new::{execute_vector_search, PartialSearchResult}; +use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; +use crate::{ + execute_search, DefaultSearchLogger, MatchingWords, Result, Search, SearchContext, SearchResult, +}; + +struct CombinedSearchResult { + matching_words: MatchingWords, + candidates: RoaringBitmap, + document_scores: Vec<(u32, CombinedScore)>, +} + +type CombinedScore = (Vec, Option>); + +fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering { + let mut left_main_it = ScoreDetails::score_values(left.0.iter()); + let mut left_sub_it = + ScoreDetails::score_values(left.1.as_ref().map(|x| x.iter()).into_iter().flatten()); + + let mut right_main_it = ScoreDetails::score_values(right.0.iter()); + let mut right_sub_it = + ScoreDetails::score_values(right.1.as_ref().map(|x| x.iter()).into_iter().flatten()); + + let mut left_main = left_main_it.next(); + let mut left_sub = left_sub_it.next(); + let mut right_main = right_main_it.next(); + let mut right_sub = right_sub_it.next(); + + loop { + let left = + take_best_score(&mut left_main, &mut left_sub, &mut left_main_it, &mut left_sub_it); + + let right = + take_best_score(&mut right_main, &mut right_sub, &mut right_main_it, &mut right_sub_it); + + match (left, right) { + (None, None) => return Ordering::Equal, + (None, Some(_)) => return Ordering::Less, + (Some(_), None) => return Ordering::Greater, + (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { + if (left - right).abs() <= f64::EPSILON { + continue; + } + return left.partial_cmp(&right).unwrap(); + } + (Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => { + match left.partial_cmp(right).unwrap() { + Ordering::Equal => continue, + order => return order, + } + } + (Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => { + match left.partial_cmp(right).unwrap() { + Ordering::Equal => continue, + order => return order, + } + } + (Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater, + (Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less, + // if we have this, we're bad + (Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_))) + | (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => { + unreachable!("Unexpected geo and sort comparison") + } + } + } +} + +fn take_best_score<'a>( + main_score: &mut Option>, + sub_score: &mut Option>, + main_it: &mut impl Iterator>, + sub_it: &mut impl Iterator>, +) -> Option> { + match (*main_score, *sub_score) { + (Some(main), None) => { + *main_score = main_it.next(); + Some(main) + } + (None, Some(sub)) => { + *sub_score = sub_it.next(); + Some(sub) + } + (main @ Some(ScoreValue::Score(main_f)), sub @ Some(ScoreValue::Score(sub_v))) => { + // take max, both advance + *main_score = main_it.next(); + *sub_score = sub_it.next(); + if main_f >= sub_v { + main + } else { + sub + } + } + (main @ Some(ScoreValue::Score(_)), _) => { + *main_score = main_it.next(); + main + } + (_, sub @ Some(ScoreValue::Score(_))) => { + *sub_score = sub_it.next(); + sub + } + (main @ Some(ScoreValue::GeoSort(main_geo)), sub @ Some(ScoreValue::GeoSort(sub_geo))) => { + // take best advance both + *main_score = main_it.next(); + *sub_score = sub_it.next(); + if main_geo >= sub_geo { + main + } else { + sub + } + } + (main @ Some(ScoreValue::Sort(main_sort)), sub @ Some(ScoreValue::Sort(sub_sort))) => { + // take best advance both + *main_score = main_it.next(); + *sub_score = sub_it.next(); + if main_sort >= sub_sort { + main + } else { + sub + } + } + ( + Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)), + Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)), + ) => None, + + (None, None) => None, + } +} + +impl CombinedSearchResult { + fn new(main_results: SearchResult, ancillary_results: PartialSearchResult) -> Self { + let mut docid_scores = HashMap::new(); + for (docid, score) in + main_results.documents_ids.iter().zip(main_results.document_scores.into_iter()) + { + docid_scores.insert(*docid, (score, None)); + } + + for (docid, score) in ancillary_results + .documents_ids + .iter() + .zip(ancillary_results.document_scores.into_iter()) + { + docid_scores + .entry(*docid) + .and_modify(|(_main_score, ancillary_score)| *ancillary_score = Some(score)); + } + + let mut document_scores: Vec<_> = docid_scores.into_iter().collect(); + + document_scores.sort_by(|(_, left), (_, right)| compare_scores(left, right).reverse()); + + Self { + matching_words: main_results.matching_words, + candidates: main_results.candidates, + document_scores, + } + } + + fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult { + let mut documents_ids = + Vec::with_capacity(left.document_scores.len() + right.document_scores.len()); + let mut document_scores = + Vec::with_capacity(left.document_scores.len() + right.document_scores.len()); + + let mut documents_seen = RoaringBitmap::new(); + for (docid, (main_score, _sub_score)) in left + .document_scores + .into_iter() + .merge_by(right.document_scores.into_iter(), |(_, left), (_, right)| { + // the first value is the one with the greatest score + compare_scores(left, right).is_ge() + }) + // remove documents we already saw + .filter(|(docid, _)| documents_seen.insert(*docid)) + // start skipping **after** the filter + .skip(from) + // take **after** skipping + .take(length) + { + documents_ids.push(docid); + // TODO: pass both scores to documents_score in some way? + document_scores.push(main_score); + } + + SearchResult { + matching_words: left.matching_words, + candidates: left.candidates | right.candidates, + documents_ids, + document_scores, + } + } +} + +impl<'a> Search<'a> { + pub fn execute_hybrid(&self) -> Result { + // TODO: find classier way to achieve that than to reset vector and query params + // create separate keyword and semantic searches + let mut search = Search { + query: self.query.clone(), + vector: self.vector.clone(), + filter: self.filter.clone(), + offset: 0, + limit: self.limit + self.offset, + sort_criteria: self.sort_criteria.clone(), + searchable_attributes: self.searchable_attributes, + geo_strategy: self.geo_strategy, + terms_matching_strategy: self.terms_matching_strategy, + scoring_strategy: ScoringStrategy::Detailed, + words_limit: self.words_limit, + exhaustive_number_hits: self.exhaustive_number_hits, + rtxn: self.rtxn, + index: self.index, + }; + + let vector_query = search.vector.take(); + let keyword_query = self.query.as_deref(); + + let keyword_results = search.execute()?; + + // skip semantic search if we don't have a vector query (placeholder search) + let Some(vector_query) = vector_query else { + return Ok(keyword_results); + }; + + // completely skip semantic search if the results of the keyword search are good enough + if self.results_good_enough(&keyword_results) { + return Ok(keyword_results); + } + + search.vector = Some(vector_query); + search.query = None; + + // TODO: would be better to have two distinct functions at this point + let vector_results = search.execute()?; + + // Compute keyword scores for vector_results + let keyword_results_for_vector = + self.keyword_results_for_vector(keyword_query, &vector_results)?; + + // compute vector scores for keyword_results + let vector_results_for_keyword = + // can unwrap because we returned already if there was no vector query + self.vector_results_for_keyword(search.vector.as_ref().unwrap(), &keyword_results)?; + + let keyword_results = + CombinedSearchResult::new(keyword_results, vector_results_for_keyword); + let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector); + + let merge_results = + CombinedSearchResult::merge(vector_results, keyword_results, self.offset, self.limit); + assert!(merge_results.documents_ids.len() <= self.limit); + Ok(merge_results) + } + + fn vector_results_for_keyword( + &self, + vector: &[f32], + keyword_results: &SearchResult, + ) -> Result { + let mut ctx = SearchContext::new(self.index, self.rtxn); + + if let Some(searchable_attributes) = self.searchable_attributes { + ctx.searchable_attributes(searchable_attributes)?; + } + + let universe = keyword_results.documents_ids.iter().collect(); + + execute_vector_search( + &mut ctx, + vector, + ScoringStrategy::Detailed, + universe, + &self.sort_criteria, + self.geo_strategy, + 0, + self.limit + self.offset, + ) + } + + fn keyword_results_for_vector( + &self, + query: Option<&str>, + vector_results: &SearchResult, + ) -> Result { + let mut ctx = SearchContext::new(self.index, self.rtxn); + + if let Some(searchable_attributes) = self.searchable_attributes { + ctx.searchable_attributes(searchable_attributes)?; + } + + let universe = vector_results.documents_ids.iter().collect(); + + execute_search( + &mut ctx, + query, + self.terms_matching_strategy, + ScoringStrategy::Detailed, + self.exhaustive_number_hits, + universe, + &self.sort_criteria, + self.geo_strategy, + 0, + self.limit + self.offset, + Some(self.words_limit), + &mut DefaultSearchLogger, + &mut DefaultSearchLogger, + ) + } + + fn results_good_enough(&self, keyword_results: &SearchResult) -> bool { + const GOOD_ENOUGH_SCORE: f64 = 0.9; + + // 1. we check that we got a sufficient number of results + if keyword_results.document_scores.len() < self.limit + self.offset { + return false; + } + + // 2. and that all results have a good enough score. + // we need to check all results because due to sort like rules, they're not necessarily in relevancy order + for score in &keyword_results.document_scores { + let score = ScoreDetails::global_score(score.iter()); + if score < GOOD_ENOUGH_SCORE { + return false; + } + } + true + } +} diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index ee8cd1faf..8b541ffcd 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -3,6 +3,7 @@ use std::ops::ControlFlow; use charabia::normalizer::NormalizerOption; use charabia::Normalize; +use deserr::{DeserializeError, Deserr, Sequence}; use fst::automaton::{Automaton, Str}; use fst::{IntoStreamer, Streamer}; use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; @@ -12,12 +13,13 @@ use roaring::bitmap::RoaringBitmap; pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET}; pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; -use self::new::PartialSearchResult; +use self::new::{execute_vector_search, PartialSearchResult}; use crate::error::UserError; use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue}; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::{ - execute_search, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext, + execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, + Result, SearchContext, }; // Building these factories is not free. @@ -30,6 +32,7 @@ const MAX_NUMBER_OF_FACETS: usize = 100; pub mod facet; mod fst_utils; +pub mod hybrid; pub mod new; pub struct Search<'a> { @@ -50,6 +53,53 @@ pub struct Search<'a> { index: &'a Index, } +#[derive(Debug, Clone, PartialEq)] +pub enum VectorQuery { + Vector(Vec), + String(String), +} + +impl Deserr for VectorQuery +where + E: DeserializeError, +{ + fn deserialize_from_value( + value: deserr::Value, + location: deserr::ValuePointerRef, + ) -> std::result::Result { + match value { + deserr::Value::String(s) => Ok(VectorQuery::String(s)), + deserr::Value::Sequence(seq) => { + let v: std::result::Result, _> = seq + .into_iter() + .enumerate() + .map(|(index, v)| match v.into_value() { + deserr::Value::Float(f) => Ok(f as f32), + deserr::Value::Integer(i) => Ok(i as f32), + v => Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::IncorrectValueKind { + actual: v, + accepted: &[deserr::ValueKind::Float, deserr::ValueKind::Integer], + }, + location.push_index(index), + ))), + }) + .collect(); + Ok(VectorQuery::Vector(v?)) + } + _ => Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::IncorrectValueKind { + actual: value, + accepted: &[deserr::ValueKind::String, deserr::ValueKind::Sequence], + }, + location, + ))), + } + } +} + impl<'a> Search<'a> { pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { Search { @@ -75,8 +125,8 @@ impl<'a> Search<'a> { self } - pub fn vector(&mut self, vector: impl Into>) -> &mut Search<'a> { - self.vector = Some(vector.into()); + pub fn vector(&mut self, vector: Vec) -> &mut Search<'a> { + self.vector = Some(vector); self } @@ -140,23 +190,35 @@ impl<'a> Search<'a> { ctx.searchable_attributes(searchable_attributes)?; } + let universe = filtered_universe(&ctx, &self.filter)?; let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } = - execute_search( - &mut ctx, - &self.query, - &self.vector, - self.terms_matching_strategy, - self.scoring_strategy, - self.exhaustive_number_hits, - &self.filter, - &self.sort_criteria, - self.geo_strategy, - self.offset, - self.limit, - Some(self.words_limit), - &mut DefaultSearchLogger, - &mut DefaultSearchLogger, - )?; + match self.vector.as_ref() { + Some(vector) => execute_vector_search( + &mut ctx, + vector, + self.scoring_strategy, + universe, + &self.sort_criteria, + self.geo_strategy, + self.offset, + self.limit, + )?, + None => execute_search( + &mut ctx, + self.query.as_deref(), + self.terms_matching_strategy, + self.scoring_strategy, + self.exhaustive_number_hits, + universe, + &self.sort_criteria, + self.geo_strategy, + self.offset, + self.limit, + Some(self.words_limit), + &mut DefaultSearchLogger, + &mut DefaultSearchLogger, + )?, + }; // consume context and located_query_terms to build MatchingWords. let matching_words = match located_query_terms { diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index 5d61de0f4..067fa1efd 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -498,19 +498,19 @@ mod tests { use super::*; use crate::index::tests::TempIndex; - use crate::{execute_search, SearchContext}; + use crate::{execute_search, filtered_universe, SearchContext}; impl<'a> MatcherBuilder<'a> { fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self { let mut ctx = SearchContext::new(index, rtxn); + let universe = filtered_universe(&ctx, &None).unwrap(); let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search( &mut ctx, - &Some(query.to_string()), - &None, + Some(query), crate::TermsMatchingStrategy::default(), crate::score_details::ScoringStrategy::Skip, false, - &None, + universe, &None, crate::search::new::GeoSortStrategy::default(), 0, diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index a1b5da4e8..372c89601 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -16,6 +16,7 @@ mod small_bitmap; mod exact_attribute; mod sort; +mod vector_sort; #[cfg(test)] mod tests; @@ -28,7 +29,6 @@ use db_cache::DatabaseCache; use exact_attribute::ExactAttribute; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use heed::RoTxn; -use instant_distance::Search; use interner::{DedupInterner, Interner}; pub use logger::visual::VisualSearchLogger; pub use logger::{DefaultSearchLogger, SearchLogger}; @@ -46,7 +46,7 @@ use self::geo_sort::GeoSort; pub use self::geo_sort::Strategy as GeoSortStrategy; use self::graph_based_ranking_rule::Words; use self::interner::Interned; -use crate::distance::NDotProductPoint; +use self::vector_sort::VectorSort; use crate::error::FieldIdMapMissingEntry; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; @@ -258,6 +258,70 @@ fn get_ranking_rules_for_placeholder_search<'ctx>( Ok(ranking_rules) } +fn get_ranking_rules_for_vector<'ctx>( + ctx: &SearchContext<'ctx>, + sort_criteria: &Option>, + geo_strategy: geo_sort::Strategy, + target: &[f32], +) -> Result>> { + // query graph search + + let mut sort = false; + let mut sorted_fields = HashSet::new(); + let mut geo_sorted = false; + + let mut vector = false; + let mut ranking_rules: Vec> = vec![]; + + let settings_ranking_rules = ctx.index.criteria(ctx.txn)?; + for rr in settings_ranking_rules { + match rr { + crate::Criterion::Words + | crate::Criterion::Typo + | crate::Criterion::Proximity + | crate::Criterion::Attribute + | crate::Criterion::Exactness => { + if !vector { + let vector_candidates = ctx.index.documents_ids(ctx.txn)?; + let vector_sort = VectorSort::new(ctx, target.to_vec(), vector_candidates)?; + ranking_rules.push(Box::new(vector_sort)); + vector = true; + } + } + crate::Criterion::Sort => { + if sort { + continue; + } + resolve_sort_criteria( + sort_criteria, + ctx, + &mut ranking_rules, + &mut sorted_fields, + &mut geo_sorted, + geo_strategy, + )?; + sort = true; + } + crate::Criterion::Asc(field_name) => { + if sorted_fields.contains(&field_name) { + continue; + } + sorted_fields.insert(field_name.clone()); + ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, true)?)); + } + crate::Criterion::Desc(field_name) => { + if sorted_fields.contains(&field_name) { + continue; + } + sorted_fields.insert(field_name.clone()); + ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, false)?)); + } + } + } + + Ok(ranking_rules) +} + /// Return the list of initialised ranking rules to be used for a query graph search. fn get_ranking_rules_for_query_graph_search<'ctx>( ctx: &SearchContext<'ctx>, @@ -422,15 +486,62 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( Ok(()) } +pub fn filtered_universe(ctx: &SearchContext, filters: &Option) -> Result { + Ok(if let Some(filters) = filters { + filters.evaluate(ctx.txn, ctx.index)? + } else { + ctx.index.documents_ids(ctx.txn)? + }) +} + +#[allow(clippy::too_many_arguments)] +pub fn execute_vector_search( + ctx: &mut SearchContext, + vector: &[f32], + scoring_strategy: ScoringStrategy, + universe: RoaringBitmap, + sort_criteria: &Option>, + geo_strategy: geo_sort::Strategy, + from: usize, + length: usize, +) -> Result { + check_sort_criteria(ctx, sort_criteria.as_ref())?; + + /// FIXME: input universe = universe & documents_with_vectors + // for now if we're computing embeddings for ALL documents, we can assume that this is just universe + let ranking_rules = get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, vector)?; + + let mut placeholder_search_logger = logger::DefaultSearchLogger; + let placeholder_search_logger: &mut dyn SearchLogger = + &mut placeholder_search_logger; + + let BucketSortOutput { docids, scores, all_candidates } = bucket_sort( + ctx, + ranking_rules, + &PlaceholderQuery, + &universe, + from, + length, + scoring_strategy, + placeholder_search_logger, + )?; + + Ok(PartialSearchResult { + candidates: all_candidates, + document_scores: scores, + documents_ids: docids, + located_query_terms: None, + }) +} + #[allow(clippy::too_many_arguments)] pub fn execute_search( ctx: &mut SearchContext, - query: &Option, - vector: &Option>, + query: Option<&str>, terms_matching_strategy: TermsMatchingStrategy, scoring_strategy: ScoringStrategy, exhaustive_number_hits: bool, - filters: &Option, + mut universe: RoaringBitmap, sort_criteria: &Option>, geo_strategy: geo_sort::Strategy, from: usize, @@ -439,60 +550,8 @@ pub fn execute_search( placeholder_search_logger: &mut dyn SearchLogger, query_graph_logger: &mut dyn SearchLogger, ) -> Result { - let mut universe = if let Some(filters) = filters { - filters.evaluate(ctx.txn, ctx.index)? - } else { - ctx.index.documents_ids(ctx.txn)? - }; - check_sort_criteria(ctx, sort_criteria.as_ref())?; - if let Some(vector) = vector { - let mut search = Search::default(); - let docids = match ctx.index.vector_hnsw(ctx.txn)? { - Some(hnsw) => { - if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() { - if vector.len() != expected_size { - return Err(UserError::InvalidVectorDimensions { - expected: expected_size, - found: vector.len(), - } - .into()); - } - } - - let vector = NDotProductPoint::new(vector.clone()); - - let neighbors = hnsw.search(&vector, &mut search); - - let mut docids = Vec::new(); - let mut uniq_docids = RoaringBitmap::new(); - for instant_distance::Item { distance: _, pid, point: _ } in neighbors { - let index = pid.into_inner(); - let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap(); - if universe.contains(docid) && uniq_docids.insert(docid) { - docids.push(docid); - if docids.len() == (from + length) { - break; - } - } - } - - // return the nearest documents that are also part of the candidates - // along with a dummy list of scores that are useless in this context. - docids.into_iter().skip(from).take(length).collect() - } - None => Vec::new(), - }; - - return Ok(PartialSearchResult { - candidates: universe, - document_scores: vec![Vec::new(); docids.len()], - documents_ids: docids, - located_query_terms: None, - }); - } - let mut located_query_terms = None; let query_terms = if let Some(query) = query { // We make sure that the analyzer is aware of the stop words @@ -546,7 +605,7 @@ pub fn execute_search( terms_matching_strategy, )?; - universe = + universe &= resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?; bucket_sort( diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs new file mode 100644 index 000000000..831ed45cd --- /dev/null +++ b/milli/src/search/new/vector_sort.rs @@ -0,0 +1,150 @@ +use std::future::Future; +use std::iter::FromIterator; +use std::pin::Pin; + +use nolife::DynBoxScope; +use roaring::RoaringBitmap; + +use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; +use crate::distance::NDotProductPoint; +use crate::index::Hnsw; +use crate::score_details::{self, ScoreDetails}; +use crate::{Result, SearchContext, SearchLogger, UserError}; + +pub struct VectorSort { + query: Option, + target: Vec, + vector_candidates: RoaringBitmap, + scope: nolife::DynBoxScope, +} + +type Item<'a> = instant_distance::Item<'a, NDotProductPoint>; +type SearchFut = Pin>>; + +struct SearchFamily; +impl<'a> nolife::Family<'a> for SearchFamily { + type Family = Box> + 'a>; +} + +async fn search_scope( + mut time_capsule: nolife::TimeCapsule, + hnsw: Hnsw, + target: Vec, +) -> nolife::Never { + let mut search = instant_distance::Search::default(); + let it = Box::new(hnsw.search(&NDotProductPoint::new(target), &mut search)); + let mut it: Box> = it; + loop { + time_capsule.freeze(&mut it).await; + } +} + +impl VectorSort { + pub fn new( + ctx: &SearchContext, + target: Vec, + vector_candidates: RoaringBitmap, + ) -> Result { + let hnsw = + ctx.index.vector_hnsw(ctx.txn)?.unwrap_or(Hnsw::builder().build_hnsw(Vec::default()).0); + + if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() { + if target.len() != expected_size { + return Err(UserError::InvalidVectorDimensions { + expected: expected_size, + found: target.len(), + } + .into()); + } + } + + let target_clone = target.clone(); + let producer = move |time_capsule| -> SearchFut { + Box::pin(search_scope(time_capsule, hnsw, target_clone)) + }; + let scope = DynBoxScope::new(producer); + + Ok(Self { query: None, target, vector_candidates, scope }) + } +} + +impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { + fn id(&self) -> String { + "vector_sort".to_owned() + } + + fn start_iteration( + &mut self, + _ctx: &mut SearchContext<'ctx>, + _logger: &mut dyn SearchLogger, + universe: &RoaringBitmap, + query: &Q, + ) -> Result<()> { + assert!(self.query.is_none()); + + self.query = Some(query.clone()); + self.vector_candidates &= universe; + + Ok(()) + } + + #[allow(clippy::only_used_in_recursion)] + fn next_bucket( + &mut self, + ctx: &mut SearchContext<'ctx>, + _logger: &mut dyn SearchLogger, + universe: &RoaringBitmap, + ) -> Result>> { + let query = self.query.as_ref().unwrap().clone(); + self.vector_candidates &= universe; + + if self.vector_candidates.is_empty() { + return Ok(Some(RankingRuleOutput { + query, + candidates: universe.clone(), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: self.target.clone(), + value_similarity: None, + }), + })); + } + + let scope = &mut self.scope; + let target = &self.target; + let vector_candidates = &self.vector_candidates; + + scope.enter(|it| { + for item in it.by_ref() { + let item: Item = item; + let index = item.pid.into_inner(); + let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap(); + + if vector_candidates.contains(docid) { + return Ok(Some(RankingRuleOutput { + query, + candidates: RoaringBitmap::from_iter([docid]), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: target.clone(), + value_similarity: Some(( + item.point.clone().into_inner(), + 1.0 - item.distance, + )), + }), + })); + } + } + Ok(Some(RankingRuleOutput { + query, + candidates: universe.clone(), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: target.clone(), + value_similarity: None, + }), + })) + }) + } + + fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger) { + self.query = None; + } +} diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 317a9aec3..8399c220b 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -1,9 +1,10 @@ use std::cmp::Ordering; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::fs::File; use std::io::{self, BufReader, BufWriter}; use std::mem::size_of; use std::str::from_utf8; +use std::sync::Arc; use bytemuck::cast_slice; use grenad::Writer; @@ -13,13 +14,56 @@ use serde_json::{from_slice, Value}; use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; use crate::error::UserError; +use crate::prompt::Prompt; use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::index_documents::helpers::try_split_at; -use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors}; +use crate::vector::Embedder; +use crate::{DocumentId, FieldsIdsMap, InternalError, Result, VectorOrArrayOfVectors}; /// The length of the elements that are always in the buffer when inserting new values. const TRUNCATE_SIZE: usize = size_of::(); +pub struct ExtractedVectorPoints { + // docid, _index -> KvWriterDelAdd -> Vector + pub manual_vectors: grenad::Reader>, + // docid -> () + pub remove_vectors: grenad::Reader>, + // docid -> prompt + pub prompts: grenad::Reader>, +} + +enum VectorStateDelta { + NoChange, + // Remove all vectors, generated or manual, from this document + NowRemoved, + + // Add the manually specified vectors, passed in the other grenad + // Remove any previously generated vectors + // Note: changing the value of the manually specified vector **should not record** this delta + WasGeneratedNowManual(Vec>), + + ManualDelta(Vec>, Vec>), + + // Add the vector computed from the specified prompt + // Remove any previous vector + // Note: changing the value of the prompt **does require** recording this delta + NowGenerated(String), +} + +impl VectorStateDelta { + fn into_values(self) -> (bool, String, (Vec>, Vec>)) { + match self { + VectorStateDelta::NoChange => Default::default(), + VectorStateDelta::NowRemoved => (true, Default::default(), Default::default()), + VectorStateDelta::WasGeneratedNowManual(add) => { + (true, Default::default(), (Default::default(), add)) + } + VectorStateDelta::ManualDelta(del, add) => (false, Default::default(), (del, add)), + VectorStateDelta::NowGenerated(prompt) => (true, prompt, Default::default()), + } + } +} + /// Extracts the embedding vector contained in each document under the `_vectors` field. /// /// Returns the generated grenad reader containing the docid as key associated to the Vec @@ -27,16 +71,34 @@ const TRUNCATE_SIZE: usize = size_of::(); pub fn extract_vector_points( obkv_documents: grenad::Reader, indexer: GrenadParameters, - vectors_fid: FieldId, -) -> Result>> { + field_id_map: FieldsIdsMap, + prompt: Option<&Prompt>, +) -> Result { puffin::profile_function!(); - let mut writer = create_writer( + // (docid, _index) -> KvWriterDelAdd -> Vector + let mut manual_vectors_writer = create_writer( indexer.chunk_compression_type, indexer.chunk_compression_level, tempfile::tempfile()?, ); + // (docid) -> (prompt) + let mut prompts_writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + // (docid) -> () + let mut remove_vectors_writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + let vectors_fid = field_id_map.id("_vectors"); + let mut key_buffer = Vec::new(); let mut cursor = obkv_documents.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? { @@ -53,43 +115,148 @@ pub fn extract_vector_points( // lazily get it when needed let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() }; - // first we retrieve the _vectors field - if let Some(value) = obkv.get(vectors_fid) { + let delta = if let Some(value) = vectors_fid.and_then(|vectors_fid| obkv.get(vectors_fid)) { let vectors_obkv = KvReaderDelAdd::new(value); + match (vectors_obkv.get(DelAdd::Deletion), vectors_obkv.get(DelAdd::Addition)) { + (Some(old), Some(new)) => { + // no autogeneration + let del_vectors = extract_vectors(old, document_id)?; + let add_vectors = extract_vectors(new, document_id)?; - // then we extract the values - let del_vectors = vectors_obkv - .get(DelAdd::Deletion) - .map(|vectors| extract_vectors(vectors, document_id)) - .transpose()? - .flatten(); - let add_vectors = vectors_obkv - .get(DelAdd::Addition) - .map(|vectors| extract_vectors(vectors, document_id)) - .transpose()? - .flatten(); + VectorStateDelta::ManualDelta( + del_vectors.unwrap_or_default(), + add_vectors.unwrap_or_default(), + ) + } + (None, Some(new)) => { + // was possibly autogenerated, remove all vectors for that document + let add_vectors = extract_vectors(new, document_id)?; - // and we finally push the unique vectors into the writer - push_vectors_diff( - &mut writer, - &mut key_buffer, - del_vectors.unwrap_or_default(), - add_vectors.unwrap_or_default(), - )?; - } + VectorStateDelta::WasGeneratedNowManual(add_vectors.unwrap_or_default()) + } + (Some(_old), None) => { + // Do we keep this document? + let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + if document_is_kept { + // becomes autogenerated + match prompt { + Some(prompt) => VectorStateDelta::NowGenerated(prompt.render( + obkv, + DelAdd::Addition, + &field_id_map, + )?), + None => VectorStateDelta::NowRemoved, + } + } else { + VectorStateDelta::NowRemoved + } + } + (None, None) => { + // Do we keep this document? + let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + + if document_is_kept { + match prompt { + Some(prompt) => { + // Don't give up if the old prompt was failing + let old_prompt = prompt + .render(obkv, DelAdd::Deletion, &field_id_map) + .unwrap_or_default(); + let new_prompt = + prompt.render(obkv, DelAdd::Addition, &field_id_map)?; + if old_prompt != new_prompt { + log::trace!( + "Changing prompt from\n{old_prompt}\n===\nto\n{new_prompt}" + ); + VectorStateDelta::NowGenerated(new_prompt) + } else { + VectorStateDelta::NoChange + } + } + // We no longer have a prompt, so we need to remove any existing vector + None => VectorStateDelta::NowRemoved, + } + } else { + VectorStateDelta::NowRemoved + } + } + } + } else { + // Do we keep this document? + let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + + if document_is_kept { + match prompt { + Some(prompt) => { + // Don't give up if the old prompt was failing + let old_prompt = prompt + .render(obkv, DelAdd::Deletion, &field_id_map) + .unwrap_or_default(); + let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?; + if old_prompt != new_prompt { + log::trace!( + "Changing prompt from\n{old_prompt}\n===\nto\n{new_prompt}" + ); + VectorStateDelta::NowGenerated(new_prompt) + } else { + VectorStateDelta::NoChange + } + } + None => VectorStateDelta::NowRemoved, + } + } else { + VectorStateDelta::NowRemoved + } + }; + + // and we finally push the unique vectors into the writer + push_vectors_diff( + &mut remove_vectors_writer, + &mut prompts_writer, + &mut manual_vectors_writer, + &mut key_buffer, + delta, + )?; } - writer_into_reader(writer) + Ok(ExtractedVectorPoints { + // docid, _index -> KvWriterDelAdd -> Vector + manual_vectors: writer_into_reader(manual_vectors_writer)?, + // docid -> () + remove_vectors: writer_into_reader(remove_vectors_writer)?, + // docid -> prompt + prompts: writer_into_reader(prompts_writer)?, + }) } /// Computes the diff between both Del and Add numbers and /// only inserts the parts that differ in the sorter. fn push_vectors_diff( - writer: &mut Writer>, + remove_vectors_writer: &mut Writer>, + prompts_writer: &mut Writer>, + manual_vectors_writer: &mut Writer>, key_buffer: &mut Vec, - mut del_vectors: Vec>, - mut add_vectors: Vec>, + delta: VectorStateDelta, ) -> Result<()> { + let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values(); + if must_remove { + key_buffer.truncate(TRUNCATE_SIZE); + remove_vectors_writer.insert(&key_buffer, [])?; + } + if !prompt.is_empty() { + key_buffer.truncate(TRUNCATE_SIZE); + prompts_writer.insert(&key_buffer, prompt.as_bytes())?; + } + // We sort and dedup the vectors del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); @@ -114,7 +281,7 @@ fn push_vectors_diff( let mut obkv = KvWriterDelAdd::memory(); obkv.insert(DelAdd::Deletion, cast_slice(&vector))?; let bytes = obkv.into_inner()?; - writer.insert(&key_buffer, bytes)?; + manual_vectors_writer.insert(&key_buffer, bytes)?; } EitherOrBoth::Right(vector) => { // We insert only the Add part of the Obkv to inform @@ -122,7 +289,7 @@ fn push_vectors_diff( let mut obkv = KvWriterDelAdd::memory(); obkv.insert(DelAdd::Addition, cast_slice(&vector))?; let bytes = obkv.into_inner()?; - writer.insert(&key_buffer, bytes)?; + manual_vectors_writer.insert(&key_buffer, bytes)?; } } } @@ -146,3 +313,102 @@ fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result( + // docid, prompt + prompt_reader: grenad::Reader, + indexer: GrenadParameters, + embedder: Arc, +) -> Result<(grenad::Reader>, Option)> { + let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; + + let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism + let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk + + // docid, state with embedding + let mut state_writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + let mut chunks = Vec::with_capacity(n_chunks); + let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk); + let mut current_chunk_ids = Vec::with_capacity(n_vectors_per_chunk); + let mut chunks_ids = Vec::with_capacity(n_chunks); + let mut cursor = prompt_reader.into_cursor()?; + + let mut expected_dimension = None; + + while let Some((key, value)) = cursor.move_on_next()? { + let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + // SAFETY: precondition, the grenad value was saved from a string + let prompt = unsafe { std::str::from_utf8_unchecked(value) }; + if current_chunk.len() == current_chunk.capacity() { + chunks.push(std::mem::replace( + &mut current_chunk, + Vec::with_capacity(n_vectors_per_chunk), + )); + chunks_ids.push(std::mem::replace( + &mut current_chunk_ids, + Vec::with_capacity(n_vectors_per_chunk), + )); + }; + current_chunk.push(prompt.to_owned()); + current_chunk_ids.push(docid); + + if chunks.len() == chunks.capacity() { + let chunked_embeds = rt + .block_on( + embedder + .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))), + ) + .map_err(crate::vector::Error::from) + .map_err(crate::UserError::from) + .map_err(crate::Error::from)?; + + for (docid, embeddings) in chunks_ids + .iter() + .flat_map(|docids| docids.iter()) + .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) + { + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + expected_dimension = Some(embeddings.dimension()); + } + chunks_ids.clear(); + } + } + + // send last chunk + if !chunks.is_empty() { + let chunked_embeds = rt + .block_on(embedder.embed_chunks(std::mem::take(&mut chunks))) + .map_err(crate::vector::Error::from) + .map_err(crate::UserError::from) + .map_err(crate::Error::from)?; + for (docid, embeddings) in chunks_ids + .iter() + .flat_map(|docids| docids.iter()) + .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) + { + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + expected_dimension = Some(embeddings.dimension()); + } + } + + if !current_chunk.is_empty() { + let embeds = rt + .block_on(embedder.embed(std::mem::take(&mut current_chunk))) + .map_err(crate::vector::Error::from) + .map_err(crate::UserError::from) + .map_err(crate::Error::from)?; + + for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + expected_dimension = Some(embeddings.dimension()); + } + } + + Ok((writer_into_reader(state_writer)?, expected_dimension)) +} diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 57f349894..40b0dcd61 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -9,9 +9,10 @@ mod extract_word_docids; mod extract_word_pair_proximity_docids; mod extract_word_position_docids; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::fs::File; use std::io::BufReader; +use std::sync::Arc; use crossbeam_channel::Sender; use log::debug; @@ -23,7 +24,9 @@ use self::extract_facet_string_docids::extract_facet_string_docids; use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues}; use self::extract_fid_word_count_docids::extract_fid_word_count_docids; use self::extract_geo_points::extract_geo_points; -use self::extract_vector_points::extract_vector_points; +use self::extract_vector_points::{ + extract_embeddings, extract_vector_points, ExtractedVectorPoints, +}; use self::extract_word_docids::extract_word_docids; use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids; use self::extract_word_position_docids::extract_word_position_docids; @@ -32,8 +35,10 @@ use super::helpers::{ MergeFn, MergeableReader, }; use super::{helpers, TypedChunk}; +use crate::prompt::Prompt; use crate::proximity::ProximityPrecision; -use crate::{FieldId, Result}; +use crate::vector::Embedder; +use crate::{FieldId, FieldsIdsMap, Result}; /// Extract data for each databases from obkv documents in parallel. /// Send data in grenad file over provided Sender. @@ -47,13 +52,14 @@ pub(crate) fn data_from_obkv_documents( faceted_fields: HashSet, primary_key_id: FieldId, geo_fields_ids: Option<(FieldId, FieldId)>, - vectors_field_id: Option, + field_id_map: FieldsIdsMap, stop_words: Option>, allowed_separators: Option<&[&str]>, dictionary: Option<&[&str]>, max_positions_per_attributes: Option, exact_attributes: HashSet, proximity_precision: ProximityPrecision, + embedders: HashMap, Arc)>, ) -> Result<()> { puffin::profile_function!(); @@ -64,7 +70,8 @@ pub(crate) fn data_from_obkv_documents( original_documents_chunk, indexer, lmdb_writer_sx.clone(), - vectors_field_id, + field_id_map.clone(), + embedders.clone(), ) }) .collect::>()?; @@ -276,24 +283,42 @@ fn send_original_documents_data( original_documents_chunk: Result>>, indexer: GrenadParameters, lmdb_writer_sx: Sender>, - vectors_field_id: Option, + field_id_map: FieldsIdsMap, + embedders: HashMap, Arc)>, ) -> Result<()> { let original_documents_chunk = original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?; - if let Some(vectors_field_id) = vectors_field_id { - let documents_chunk_cloned = original_documents_chunk.clone(); - let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); - rayon::spawn(move || { - let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id); - let _ = match result { - Ok(vector_points) => { - lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points))) - } - Err(error) => lmdb_writer_sx_cloned.send(Err(error)), - }; - }); - } + let documents_chunk_cloned = original_documents_chunk.clone(); + let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); + rayon::spawn(move || { + let (embedder, prompt) = embedders.get("default").cloned().unzip(); + let result = + extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref()); + let _ = match result { + Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { + /// FIXME: support multiple embedders + let results = embedder.and_then(|embedder| { + match extract_embeddings(prompts, indexer, embedder.clone()) { + Ok(results) => Some(results), + Err(error) => { + let _ = lmdb_writer_sx_cloned.send(Err(error)); + None + } + } + }); + let (embeddings, expected_dimension) = results.unzip(); + let expected_dimension = expected_dimension.flatten(); + lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { + remove_vectors, + embeddings, + expected_dimension, + manual_vectors, + })) + } + Err(error) => lmdb_writer_sx_cloned.send(Err(error)), + }; + }); // TODO: create a custom internal error lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap(); diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index f825cad1c..76848b628 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -4,11 +4,12 @@ mod helpers; mod transform; mod typed_chunk; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::io::{Cursor, Read, Seek}; use std::iter::FromIterator; use std::num::NonZeroU32; use std::result::Result as StdResult; +use std::sync::Arc; use crossbeam_channel::{Receiver, Sender}; use heed::types::Str; @@ -32,10 +33,12 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters}; pub use self::transform::{Transform, TransformOutput}; use crate::documents::{obkv_to_object, DocumentsBatchReader}; use crate::error::{Error, InternalError, UserError}; +use crate::prompt::Prompt; pub use crate::update::index_documents::helpers::CursorClonableMmap; use crate::update::{ IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, }; +use crate::vector::Embedder; use crate::{CboRoaringBitmapCodec, Index, Result}; static MERGED_DATABASE_COUNT: usize = 7; @@ -78,6 +81,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> { should_abort: FA, added_documents: u64, deleted_documents: u64, + embedders: HashMap, Arc)>, } #[derive(Default, Debug, Clone)] @@ -121,6 +125,7 @@ where index, added_documents: 0, deleted_documents: 0, + embedders: Default::default(), }) } @@ -167,6 +172,14 @@ where Ok((self, Ok(indexed_documents))) } + pub fn with_embedders( + mut self, + embedders: HashMap, Arc)>, + ) -> Self { + self.embedders = embedders; + self + } + /// Remove a batch of documents from the current builder. /// /// Returns the number of documents deleted from the builder. @@ -322,17 +335,18 @@ where // get filterable fields for facet databases let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?; // get the fid of the `_geo.lat` and `_geo.lng` fields. - let geo_fields_ids = match self.index.fields_ids_map(self.wtxn)?.id("_geo") { + let mut field_id_map = self.index.fields_ids_map(self.wtxn)?; + + // self.index.fields_ids_map($a)? ==>> field_id_map + let geo_fields_ids = match field_id_map.id("_geo") { Some(gfid) => { let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid); let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid); // if `_geo` is faceted then we get the `lat` and `lng` if is_sortable || is_filterable { - let field_ids = self - .index - .fields_ids_map(self.wtxn)? + let field_ids = field_id_map .insert("_geo.lat") - .zip(self.index.fields_ids_map(self.wtxn)?.insert("_geo.lng")) + .zip(field_id_map.insert("_geo.lng")) .ok_or(UserError::AttributeLimitReached)?; Some(field_ids) } else { @@ -341,8 +355,6 @@ where } None => None, }; - // get the fid of the `_vectors` field. - let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors"); let stop_words = self.index.stop_words(self.wtxn)?; let separators = self.index.allowed_separators(self.wtxn)?; @@ -364,6 +376,8 @@ where self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes; + let cloned_embedder = self.embedders.clone(); + // Run extraction pipeline in parallel. pool.install(|| { puffin::profile_scope!("extract_and_send_grenad_chunks"); @@ -387,13 +401,14 @@ where faceted_fields, primary_key_id, geo_fields_ids, - vectors_field_id, + field_id_map, stop_words, separators.as_deref(), dictionary.as_deref(), max_positions_per_attributes, exact_attributes, proximity_precision, + cloned_embedder, ) }); @@ -2505,7 +2520,7 @@ mod tests { .unwrap(); let rtxn = index.read_txn().unwrap(); - let res = index.search(&rtxn).vector([0.0, 1.0, 2.0]).execute().unwrap(); + let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap(); assert_eq!(res.documents_ids.len(), 3); } diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 49e36b87e..36d230d00 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -47,7 +47,12 @@ pub(crate) enum TypedChunk { FieldIdFacetIsNullDocids(grenad::Reader>), FieldIdFacetIsEmptyDocids(grenad::Reader>), GeoPoints(grenad::Reader>), - VectorPoints(grenad::Reader>), + VectorPoints { + remove_vectors: grenad::Reader>, + embeddings: Option>>, + expected_dimension: Option, + manual_vectors: grenad::Reader>, + }, ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), } @@ -100,8 +105,8 @@ impl TypedChunk { TypedChunk::GeoPoints(grenad) => { format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) } - TypedChunk::VectorPoints(grenad) => { - format!("VectorPoints {{ number_of_entries: {} }}", grenad.len()) + TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => { + format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension.unwrap_or_default()) } TypedChunk::ScriptLanguageDocids(sl_map) => { format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) @@ -355,19 +360,64 @@ pub(crate) fn write_typed_chunk_into_index( index.put_geo_rtree(wtxn, &rtree)?; index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; } - TypedChunk::VectorPoints(vector_points) => { - let mut vectors_set = HashSet::new(); + TypedChunk::VectorPoints { + remove_vectors, + manual_vectors, + embeddings, + expected_dimension, + } => { + if remove_vectors.is_empty() + && manual_vectors.is_empty() + && embeddings.as_ref().map_or(true, |e| e.is_empty()) + { + return Ok((RoaringBitmap::new(), is_merged_database)); + } + + let mut docid_vectors_map: HashMap>>> = + HashMap::new(); + // We extract and store the previous vectors if let Some(hnsw) = index.vector_hnsw(wtxn)? { for (pid, point) in hnsw.iter() { let pid_key = pid.into_inner(); let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap(); let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect(); - vectors_set.insert((docid, vector)); + docid_vectors_map.entry(docid).or_default().insert(vector); } } - let mut cursor = vector_points.into_cursor()?; + // remove vectors for docids we want them removed + let mut cursor = remove_vectors.into_cursor()?; + while let Some((key, _)) = cursor.move_on_next()? { + let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + + docid_vectors_map.remove(&docid); + } + + // add generated embeddings + if let Some((embeddings, expected_dimension)) = embeddings.zip(expected_dimension) { + let mut cursor = embeddings.into_cursor()?; + while let Some((key, value)) = cursor.move_on_next()? { + let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + let data: Vec> = + pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); + // it is a code error to have embeddings and not expected_dimension + let embeddings = + crate::vector::Embeddings::from_inner(data, expected_dimension) + // code error if we somehow got the wrong dimension + .unwrap(); + + let mut set = HashSet::new(); + for embedding in embeddings.iter() { + set.insert(embedding.to_vec()); + } + + docid_vectors_map.insert(docid, set); + } + } + + // perform the manual diff + let mut cursor = manual_vectors.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? { // convert the key back to a u32 (4 bytes) let (left, _index) = try_split_array_at(key).unwrap(); @@ -376,23 +426,30 @@ pub(crate) fn write_typed_chunk_into_index( let vector_deladd_obkv = KvReaderDelAdd::new(value); if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { // convert the vector back to a Vec - let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); - let key = (docid, vector); - if !vectors_set.remove(&key) { - error!("Unable to delete the vector: {:?}", key.1); - } + let vector: Vec> = + pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); + docid_vectors_map.entry(docid).and_modify(|v| { + if !v.remove(&vector) { + error!("Unable to delete the vector: {:?}", vector); + } + }); } if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { // convert the vector back to a Vec let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); - vectors_set.insert((docid, vector)); + docid_vectors_map.entry(docid).and_modify(|v| { + v.insert(vector); + }); } } // Extract the most common vector dimension let expected_dimension_size = { let mut dims = HashMap::new(); - vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1); + docid_vectors_map + .values() + .flat_map(|v| v.iter()) + .for_each(|v| *dims.entry(v.len()).or_insert(0) += 1); dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len) }; @@ -400,7 +457,10 @@ pub(crate) fn write_typed_chunk_into_index( // prepare the vectors before inserting them in the HNSW. let mut points = Vec::new(); let mut docids = Vec::new(); - for (docid, vector) in vectors_set { + for (docid, vector) in docid_vectors_map + .into_iter() + .flat_map(|(docid, vectors)| std::iter::repeat(docid).zip(vectors)) + { if expected_dimension_size.map_or(false, |expected| expected != vector.len()) { return Err(UserError::InvalidVectorDimensions { expected: expected_dimension_size.unwrap_or(vector.len()), diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 712e595e9..5e3683f32 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -3,7 +3,7 @@ use std::result::Result as StdResult; use charabia::{Normalize, Tokenizer, TokenizerBuilder}; use deserr::{DeserializeError, Deserr}; -use itertools::Itertools; +use itertools::{EitherOrBoth, Itertools}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use time::OffsetDateTime; @@ -15,6 +15,8 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS use crate::proximity::ProximityPrecision; use crate::update::index_documents::IndexDocumentsMethod; use crate::update::{IndexDocuments, UpdateIndexingStep}; +use crate::vector::settings::{EmbeddingSettings, PromptSettings}; +use crate::vector::EmbeddingConfig; use crate::{FieldsIdsMap, Index, OrderBy, Result}; #[derive(Debug, Clone, PartialEq, Eq, Copy)] @@ -73,6 +75,13 @@ impl Setting { otherwise => otherwise, } } + + pub fn apply(&mut self, new: Self) { + if let Setting::NotSet = new { + return; + } + *self = new; + } } impl Serialize for Setting { @@ -129,6 +138,7 @@ pub struct Settings<'a, 't, 'i> { sort_facet_values_by: Setting>, pagination_max_total_hits: Setting, proximity_precision: Setting, + embedder_settings: Setting>>, } impl<'a, 't, 'i> Settings<'a, 't, 'i> { @@ -161,6 +171,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { sort_facet_values_by: Setting::NotSet, pagination_max_total_hits: Setting::NotSet, proximity_precision: Setting::NotSet, + embedder_settings: Setting::NotSet, indexer_config, } } @@ -343,6 +354,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { self.proximity_precision = Setting::Reset; } + pub fn set_embedder_settings(&mut self, value: BTreeMap>) { + self.embedder_settings = Setting::Set(value); + } + + pub fn reset_embedder_settings(&mut self) { + self.embedder_settings = Setting::Reset; + } + fn reindex( &mut self, progress_callback: &FP, @@ -890,6 +909,60 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { Ok(changed) } + fn update_embedding_configs(&mut self) -> Result { + let update = match std::mem::take(&mut self.embedder_settings) { + Setting::Set(configs) => { + let mut changed = false; + let old_configs = self.index.embedding_configs(self.wtxn)?; + let old_configs: BTreeMap> = + old_configs.into_iter().map(|(k, v)| (k, Setting::Set(v.into()))).collect(); + + let mut new_configs = BTreeMap::new(); + for joined in old_configs + .into_iter() + .merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right)) + { + match joined { + EitherOrBoth::Both((name, mut old), (_, new)) => { + old.apply(new); + let new = validate_prompt(&name, old)?; + changed = true; + new_configs.insert(name, new); + } + EitherOrBoth::Left((name, setting)) => { + new_configs.insert(name, setting); + } + EitherOrBoth::Right((name, setting)) => { + let setting = validate_prompt(&name, setting)?; + changed = true; + new_configs.insert(name, setting); + } + } + } + let new_configs: Vec<(String, EmbeddingConfig)> = new_configs + .into_iter() + .filter_map(|(name, setting)| match setting { + Setting::Set(value) => Some((name, value.into())), + Setting::Reset => None, + Setting::NotSet => Some((name, EmbeddingSettings::default().into())), + }) + .collect(); + if new_configs.is_empty() { + self.index.delete_embedding_configs(self.wtxn)?; + } else { + self.index.put_embedding_configs(self.wtxn, new_configs)?; + } + changed + } + Setting::Reset => { + self.index.delete_embedding_configs(self.wtxn)?; + true + } + Setting::NotSet => false, + }; + Ok(update) + } + pub fn execute(mut self, progress_callback: FP, should_abort: FA) -> Result<()> where FP: Fn(UpdateIndexingStep) + Sync, @@ -927,6 +1000,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { let searchable_updated = self.update_searchable()?; let exact_attributes_updated = self.update_exact_attributes()?; let proximity_precision = self.update_proximity_precision()?; + // TODO: very rough approximation of the needs for reindexing where any change will result in + // a full reindexing. + // What can be done instead: + // 1. Only change the distance on a distance change + // 2. Only change the name -> embedder mapping on a name change + // 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage + let embedding_configs_updated = self.update_embedding_configs()?; if stop_words_updated || non_separator_tokens_updated @@ -937,6 +1017,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { || searchable_updated || exact_attributes_updated || proximity_precision + || embedding_configs_updated { self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?; } @@ -945,6 +1026,34 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { } } +fn validate_prompt( + name: &str, + new: Setting, +) -> Result> { + match new { + Setting::Set(EmbeddingSettings { + embedder_options, + prompt: + Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }), + }) => { + // validate + let template = crate::prompt::Prompt::new(template, None, None) + .map(|prompt| crate::prompt::PromptData::from(prompt).template) + .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; + + Ok(Setting::Set(EmbeddingSettings { + embedder_options, + prompt: Setting::Set(PromptSettings { + template: Setting::Set(template), + strategy, + fallback, + }), + })) + } + new => Ok(new), + } +} + #[cfg(test)] mod tests { use big_s::S; @@ -1763,6 +1872,7 @@ mod tests { sort_facet_values_by, pagination_max_total_hits, proximity_precision, + embedder_settings, } = settings; assert!(matches!(searchable_fields, Setting::NotSet)); assert!(matches!(displayed_fields, Setting::NotSet)); @@ -1785,6 +1895,7 @@ mod tests { assert!(matches!(sort_facet_values_by, Setting::NotSet)); assert!(matches!(pagination_max_total_hits, Setting::NotSet)); assert!(matches!(proximity_precision, Setting::NotSet)); + assert!(matches!(embedder_settings, Setting::NotSet)); }) .unwrap(); } diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs new file mode 100644 index 000000000..1ae7a4678 --- /dev/null +++ b/milli/src/vector/error.rs @@ -0,0 +1,229 @@ +use std::path::PathBuf; + +use hf_hub::api::sync::ApiError; + +use crate::error::FaultSource; +use crate::vector::openai::OpenAiError; + +#[derive(Debug, thiserror::Error)] +#[error("Error while generating embeddings: {inner}")] +pub struct Error { + pub inner: Box, +} + +impl> From for Error { + fn from(value: I) -> Self { + Self { inner: Box::new(value.into()) } + } +} + +impl Error { + pub fn fault(&self) -> FaultSource { + match &*self.inner { + ErrorKind::NewEmbedderError(inner) => inner.fault, + ErrorKind::EmbedError(inner) => inner.fault, + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ErrorKind { + #[error(transparent)] + NewEmbedderError(#[from] NewEmbedderError), + #[error(transparent)] + EmbedError(#[from] EmbedError), +} + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct EmbedError { + pub kind: EmbedErrorKind, + pub fault: FaultSource, +} + +#[derive(Debug, thiserror::Error)] +pub enum EmbedErrorKind { + #[error("could not tokenize: {0}")] + Tokenize(Box), + #[error("unexpected tensor shape: {0}")] + TensorShape(candle_core::Error), + #[error("unexpected tensor value: {0}")] + TensorValue(candle_core::Error), + #[error("could not run model: {0}")] + ModelForward(candle_core::Error), + #[error("could not reach OpenAI: {0}")] + OpenAiNetwork(reqwest::Error), + #[error("unexpected response from OpenAI: {0}")] + OpenAiUnexpected(reqwest::Error), + #[error("could not authenticate against OpenAI: {0}")] + OpenAiAuth(OpenAiError), + #[error("sent too many requests to OpenAI: {0}")] + OpenAiTooManyRequests(OpenAiError), + #[error("received internal error from OpenAI: {0}")] + OpenAiInternalServerError(OpenAiError), + #[error("sent too many tokens in a request to OpenAI: {0}")] + OpenAiTooManyTokens(OpenAiError), + #[error("received unhandled HTTP status code {0} from OpenAI")] + OpenAiUnhandledStatusCode(u16), +} + +impl EmbedError { + pub fn tokenize(inner: Box) -> Self { + Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime } + } + + pub fn tensor_shape(inner: candle_core::Error) -> Self { + Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug } + } + + pub fn tensor_value(inner: candle_core::Error) -> Self { + Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug } + } + + pub fn model_forward(inner: candle_core::Error) -> Self { + Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime } + } + + pub fn openai_network(inner: reqwest::Error) -> Self { + Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime } + } + + pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug } + } + + pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User } + } + + pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime } + } + + pub(crate) fn openai_internal_server_error(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime } + } + + pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug } + } + + pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct NewEmbedderError { + pub kind: NewEmbedderErrorKind, + pub fault: FaultSource, +} + +impl NewEmbedderError { + pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError { + let open_config = OpenConfig { filename: config_filename, inner }; + + Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime } + } + + pub fn deserialize_config( + config: String, + config_filename: PathBuf, + inner: serde_json::Error, + ) -> NewEmbedderError { + let deserialize_config = DeserializeConfig { config, filename: config_filename, inner }; + Self { + kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config), + fault: FaultSource::Runtime, + } + } + + pub fn open_tokenizer( + tokenizer_filename: PathBuf, + inner: Box, + ) -> NewEmbedderError { + let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner }; + Self { + kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer), + fault: FaultSource::Runtime, + } + } + + pub fn new_api_fail(inner: ApiError) -> Self { + Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug } + } + + pub fn api_get(inner: ApiError) -> Self { + Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided } + } + + pub fn pytorch_weight(inner: candle_core::Error) -> Self { + Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime } + } + + pub fn safetensor_weight(inner: candle_core::Error) -> Self { + Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime } + } + + pub fn load_model(inner: candle_core::Error) -> Self { + Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime } + } + + pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { + Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } + } + + pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { + Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("could not open config at {filename:?}: {inner}")] +pub struct OpenConfig { + pub filename: PathBuf, + pub inner: std::io::Error, +} + +#[derive(Debug, thiserror::Error)] +#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")] +pub struct DeserializeConfig { + pub config: String, + pub filename: PathBuf, + pub inner: serde_json::Error, +} + +#[derive(Debug, thiserror::Error)] +#[error("could not open tokenizer at {filename}: {inner}")] +pub struct OpenTokenizer { + pub filename: PathBuf, + #[source] + pub inner: Box, +} + +#[derive(Debug, thiserror::Error)] +pub enum NewEmbedderErrorKind { + // hf + #[error(transparent)] + OpenConfig(OpenConfig), + #[error(transparent)] + DeserializeConfig(DeserializeConfig), + #[error(transparent)] + OpenTokenizer(OpenTokenizer), + #[error("could not build weights from Pytorch weights: {0}")] + PytorchWeight(candle_core::Error), + #[error("could not build weights from Safetensor weights: {0}")] + SafetensorWeight(candle_core::Error), + #[error("could not spawn HG_HUB API client: {0}")] + NewApiFail(ApiError), + #[error("fetching file from HG_HUB failed: {0}")] + ApiGet(ApiError), + #[error("loading model failed: {0}")] + LoadModel(candle_core::Error), + // openai + #[error("initializing web client for sending embedding requests failed: {0}")] + InitWebClient(reqwest::Error), + #[error("The API key passed to Authorization error was in an invalid format: {0}")] + InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue), +} diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs new file mode 100644 index 000000000..81cdd4b34 --- /dev/null +++ b/milli/src/vector/hf.rs @@ -0,0 +1,192 @@ +use candle_core::Tensor; +use candle_nn::VarBuilder; +use candle_transformers::models::bert::{BertModel, Config, DTYPE}; +// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself +use hf_hub::api::sync::Api; +use hf_hub::{Repo, RepoType}; +use tokenizers::{PaddingParams, Tokenizer}; + +pub use super::error::{EmbedError, Error, NewEmbedderError}; +use super::{Embedding, Embeddings}; + +#[derive( + Debug, + Clone, + Copy, + Default, + Hash, + PartialEq, + Eq, + serde::Deserialize, + serde::Serialize, + deserr::Deserr, +)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub enum WeightSource { + #[default] + Safetensors, + Pytorch, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct EmbedderOptions { + pub model: String, + pub revision: Option, + pub weight_source: WeightSource, + pub normalize_embeddings: bool, +} + +impl EmbedderOptions { + pub fn new() -> Self { + Self { + //model: "sentence-transformers/all-MiniLM-L6-v2".to_string(), + model: "BAAI/bge-base-en-v1.5".to_string(), + //revision: Some("refs/pr/21".to_string()), + revision: None, + //weight_source: Default::default(), + weight_source: WeightSource::Pytorch, + normalize_embeddings: true, + } + } +} + +impl Default for EmbedderOptions { + fn default() -> Self { + Self::new() + } +} + +/// Perform embedding of documents and queries +pub struct Embedder { + model: BertModel, + tokenizer: Tokenizer, + options: EmbedderOptions, +} + +impl std::fmt::Debug for Embedder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Embedder") + .field("model", &self.options.model) + .field("tokenizer", &self.tokenizer) + .field("options", &self.options) + .finish() + } +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> std::result::Result { + let device = candle_core::Device::Cpu; + let repo = match options.revision.clone() { + Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), + None => Repo::model(options.model.clone()), + }; + let (config_filename, tokenizer_filename, weights_filename) = { + let api = Api::new().map_err(NewEmbedderError::new_api_fail)?; + let api = api.repo(repo); + let config = api.get("config.json").map_err(NewEmbedderError::api_get)?; + let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?; + let weights = match options.weight_source { + WeightSource::Pytorch => { + api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)? + } + WeightSource::Safetensors => { + api.get("model.safetensors").map_err(NewEmbedderError::api_get)? + } + }; + (config, tokenizer, weights) + }; + + let config = std::fs::read_to_string(&config_filename) + .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?; + let config: Config = serde_json::from_str(&config).map_err(|inner| { + NewEmbedderError::deserialize_config(config, config_filename, inner) + })?; + let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) + .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; + + let vb = match options.weight_source { + WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device) + .map_err(NewEmbedderError::pytorch_weight)?, + WeightSource::Safetensors => unsafe { + VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device) + .map_err(NewEmbedderError::safetensor_weight)? + }, + }; + + let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?; + + if let Some(pp) = tokenizer.get_padding_mut() { + pp.strategy = tokenizers::PaddingStrategy::BatchLongest + } else { + let pp = PaddingParams { + strategy: tokenizers::PaddingStrategy::BatchLongest, + ..Default::default() + }; + tokenizer.with_padding(Some(pp)); + } + + Ok(Self { model, tokenizer, options }) + } + + pub async fn embed( + &self, + mut texts: Vec, + ) -> std::result::Result>, EmbedError> { + let tokens = match texts.len() { + 1 => vec![self + .tokenizer + .encode(texts.pop().unwrap(), true) + .map_err(EmbedError::tokenize)?], + _ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?, + }; + let token_ids = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_ids().to_vec(); + Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape) + }) + .collect::, EmbedError>>()?; + + let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?; + let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; + let embeddings = + self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; + + // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) + let (_n_sentence, n_tokens, _hidden_size) = + embeddings.dims3().map_err(EmbedError::tensor_shape)?; + + let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) + .map_err(EmbedError::tensor_shape)?; + + let embeddings: Tensor = if self.options.normalize_embeddings { + normalize_l2(&embeddings).map_err(EmbedError::tensor_value)? + } else { + embeddings + }; + + let embeddings: Vec = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; + Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) + } + + pub async fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> std::result::Result>>, EmbedError> { + futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) + .await + } + + pub fn chunk_count_hint(&self) -> usize { + 1 + } + + pub fn prompt_count_in_chunk_hint(&self) -> usize { + std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8) + } +} + +fn normalize_l2(v: &Tensor) -> Result { + v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?) +} diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs new file mode 100644 index 000000000..faaa7bf2a --- /dev/null +++ b/milli/src/vector/mod.rs @@ -0,0 +1,142 @@ +use self::error::{EmbedError, NewEmbedderError}; +use crate::prompt::PromptData; + +pub mod error; +pub mod hf; +pub mod openai; +pub mod settings; + +pub use self::error::Error; + +pub type Embedding = Vec; + +pub struct Embeddings { + data: Vec, + dimension: usize, +} + +impl Embeddings { + pub fn new(dimension: usize) -> Self { + Self { data: Default::default(), dimension } + } + + pub fn from_single_embedding(embedding: Vec) -> Self { + Self { dimension: embedding.len(), data: embedding } + } + + pub fn from_inner(data: Vec, dimension: usize) -> Result> { + let mut this = Self::new(dimension); + this.append(data)?; + Ok(this) + } + + pub fn dimension(&self) -> usize { + self.dimension + } + + pub fn into_inner(self) -> Vec { + self.data + } + + pub fn as_inner(&self) -> &[F] { + &self.data + } + + pub fn iter(&self) -> impl Iterator + '_ { + self.data.as_slice().chunks_exact(self.dimension) + } + + pub fn push(&mut self, mut embedding: Vec) -> Result<(), Vec> { + if embedding.len() != self.dimension { + return Err(embedding); + } + self.data.append(&mut embedding); + Ok(()) + } + + pub fn append(&mut self, mut embeddings: Vec) -> Result<(), Vec> { + if embeddings.len() % self.dimension != 0 { + return Err(embeddings); + } + self.data.append(&mut embeddings); + Ok(()) + } +} + +#[derive(Debug)] +pub enum Embedder { + HuggingFace(hf::Embedder), + OpenAi(openai::Embedder), +} + +#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] +pub struct EmbeddingConfig { + pub embedder_options: EmbedderOptions, + pub prompt: PromptData, + // TODO: add metrics and anything needed +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub enum EmbedderOptions { + HuggingFace(hf::EmbedderOptions), + OpenAi(openai::EmbedderOptions), +} + +impl Default for EmbedderOptions { + fn default() -> Self { + Self::HuggingFace(Default::default()) + } +} + +impl EmbedderOptions { + pub fn huggingface() -> Self { + Self::HuggingFace(hf::EmbedderOptions::new()) + } + + pub fn openai(api_key: String) -> Self { + Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) + } +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> std::result::Result { + Ok(match options { + EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), + EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), + }) + } + + pub async fn embed( + &self, + texts: Vec, + ) -> std::result::Result>, EmbedError> { + match self { + Embedder::HuggingFace(embedder) => embedder.embed(texts).await, + Embedder::OpenAi(embedder) => embedder.embed(texts).await, + } + } + + pub async fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> std::result::Result>>, EmbedError> { + match self { + Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await, + Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, + } + } + + pub fn chunk_count_hint(&self) -> usize { + match self { + Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), + Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), + } + } + + pub fn prompt_count_in_chunk_hint(&self) -> usize { + match self { + Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), + Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), + } + } +} diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs new file mode 100644 index 000000000..670dc8526 --- /dev/null +++ b/milli/src/vector/openai.rs @@ -0,0 +1,416 @@ +use std::fmt::Display; + +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; + +use super::error::{EmbedError, NewEmbedderError}; +use super::{Embedding, Embeddings}; + +#[derive(Debug)] +pub struct Embedder { + client: reqwest::Client, + tokenizer: tiktoken_rs::CoreBPE, + options: EmbedderOptions, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct EmbedderOptions { + pub api_key: String, + pub embedding_model: EmbeddingModel, +} + +#[derive( + Debug, + Clone, + Copy, + Default, + Hash, + PartialEq, + Eq, + serde::Serialize, + serde::Deserialize, + deserr::Deserr, +)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub enum EmbeddingModel { + #[default] + TextEmbeddingAda002, +} + +impl EmbeddingModel { + pub fn max_token(&self) -> usize { + match self { + EmbeddingModel::TextEmbeddingAda002 => 8191, + } + } + + pub fn dimensions(&self) -> usize { + match self { + EmbeddingModel::TextEmbeddingAda002 => 1536, + } + } + + pub fn name(&self) -> &'static str { + match self { + EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", + } + } + + pub fn from_name(name: &'static str) -> Option { + match name { + "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), + _ => None, + } + } +} + +pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; + +impl EmbedderOptions { + pub fn with_default_model(api_key: String) -> Self { + Self { api_key, embedding_model: Default::default() } + } + + pub fn with_embedding_model(api_key: String, embedding_model: EmbeddingModel) -> Self { + Self { api_key, embedding_model } + } +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> Result { + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::AUTHORIZATION, + reqwest::header::HeaderValue::from_str(&format!("Bearer {}", &options.api_key)) + .map_err(NewEmbedderError::openai_invalid_api_key_format)?, + ); + headers.insert( + reqwest::header::CONTENT_TYPE, + reqwest::header::HeaderValue::from_static("application/json"), + ); + let client = reqwest::ClientBuilder::new() + .default_headers(headers) + .build() + .map_err(NewEmbedderError::openai_initialize_web_client)?; + + // looking at the code it is very unclear that this can actually fail. + let tokenizer = tiktoken_rs::cl100k_base().unwrap(); + + Ok(Self { options, client, tokenizer }) + } + + pub async fn embed(&self, texts: Vec) -> Result>, EmbedError> { + let mut tokenized = false; + + for attempt in 0..7 { + let result = if tokenized { + self.try_embed_tokenized(&texts).await + } else { + self.try_embed(&texts).await + }; + + let retry_duration = match result { + Ok(embeddings) => return Ok(embeddings), + Err(retry) => { + log::warn!("Failed: {}", retry.error); + tokenized |= retry.must_tokenize(); + retry.into_duration(attempt) + } + }?; + log::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis()); + tokio::time::sleep(retry_duration).await; + } + + let result = if tokenized { + self.try_embed_tokenized(&texts).await + } else { + self.try_embed(&texts).await + }; + + result.map_err(Retry::into_error) + } + + async fn check_response(response: reqwest::Response) -> Result { + if !response.status().is_success() { + match response.status() { + StatusCode::UNAUTHORIZED => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + return Err(Retry::give_up(EmbedError::openai_auth_error( + error_response.error, + ))); + } + StatusCode::TOO_MANY_REQUESTS => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + return Err(Retry::rate_limited(EmbedError::openai_too_many_requests( + error_response.error, + ))); + } + StatusCode::INTERNAL_SERVER_ERROR => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + return Err(Retry::retry_later(EmbedError::openai_internal_server_error( + error_response.error, + ))); + } + StatusCode::SERVICE_UNAVAILABLE => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + return Err(Retry::retry_later(EmbedError::openai_internal_server_error( + error_response.error, + ))); + } + StatusCode::BAD_REQUEST => { + // Most probably, one text contained too many tokens + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + log::warn!("OpenAI: input was too long, retrying on tokenized version. For best performance, limit the size of your prompt."); + + return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens( + error_response.error, + ))); + } + code => { + return Err(Retry::give_up(EmbedError::openai_unhandled_status_code( + code.as_u16(), + ))); + } + } + } + Ok(response) + } + + async fn try_embed + serde::Serialize>( + &self, + texts: &[S], + ) -> Result>, Retry> { + for text in texts { + log::trace!("Received prompt: {}", text.as_ref()) + } + let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts }; + let response = self + .client + .post(OPENAI_EMBEDDINGS_URL) + .json(&request) + .send() + .await + .map_err(EmbedError::openai_network) + .map_err(Retry::retry_later)?; + + let response = Self::check_response(response).await?; + + let response: OpenAiResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + log::trace!("response: {:?}", response.data); + + Ok(response + .data + .into_iter() + .map(|data| Embeddings::from_single_embedding(data.embedding)) + .collect()) + } + + async fn try_embed_tokenized(&self, text: &[String]) -> Result>, Retry> { + pub const OVERLAP_SIZE: usize = 200; + let mut all_embeddings = Vec::with_capacity(text.len()); + for text in text { + let max_token_count = self.options.embedding_model.max_token(); + let encoded = self.tokenizer.encode_ordinary(text.as_str()); + let len = encoded.len(); + if len < max_token_count { + all_embeddings.append(&mut self.try_embed(&[text]).await?); + continue; + } + + let mut tokens = encoded.as_slice(); + let mut embeddings_for_prompt = + Embeddings::new(self.options.embedding_model.dimensions()); + while tokens.len() > max_token_count { + let window = &tokens[..max_token_count]; + embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap(); + + tokens = &tokens[max_token_count - OVERLAP_SIZE..]; + } + + // end of text + embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap(); + + all_embeddings.push(embeddings_for_prompt); + } + Ok(all_embeddings) + } + + async fn embed_tokens(&self, tokens: &[usize]) -> Result { + for attempt in 0..9 { + let duration = match self.try_embed_tokens(tokens).await { + Ok(embedding) => return Ok(embedding), + Err(retry) => retry.into_duration(attempt), + } + .map_err(Retry::retry_later)?; + + tokio::time::sleep(duration).await; + } + + self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error())) + } + + async fn try_embed_tokens(&self, tokens: &[usize]) -> Result { + let request = + OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens }; + let response = self + .client + .post(OPENAI_EMBEDDINGS_URL) + .json(&request) + .send() + .await + .map_err(EmbedError::openai_network) + .map_err(Retry::retry_later)?; + + let response = Self::check_response(response).await?; + + let mut response: OpenAiResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default()) + } + + pub async fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> Result>>, EmbedError> { + futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) + .await + } + + pub fn chunk_count_hint(&self) -> usize { + 10 + } + + pub fn prompt_count_in_chunk_hint(&self) -> usize { + 10 + } +} + +// retrying in case of failure + +struct Retry { + error: EmbedError, + strategy: RetryStrategy, +} + +enum RetryStrategy { + GiveUp, + Retry, + RetryTokenized, + RetryAfterRateLimit, +} + +impl Retry { + fn give_up(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::GiveUp } + } + + fn retry_later(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::Retry } + } + + fn retry_tokenized(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::RetryTokenized } + } + + fn rate_limited(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::RetryAfterRateLimit } + } + + fn into_duration(self, attempt: u32) -> Result { + match self.strategy { + RetryStrategy::GiveUp => Err(self.error), + RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))), + RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)), + RetryStrategy::RetryAfterRateLimit => { + Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt))) + } + } + } + + fn must_tokenize(&self) -> bool { + matches!(self.strategy, RetryStrategy::RetryTokenized) + } + + fn into_error(self) -> EmbedError { + self.error + } +} + +// openai api structs + +#[derive(Debug, Serialize)] +struct OpenAiRequest<'a, S: AsRef + serde::Serialize> { + model: &'a str, + input: &'a [S], +} + +#[derive(Debug, Serialize)] +struct OpenAiTokensRequest<'a> { + model: &'a str, + input: &'a [usize], +} + +#[derive(Debug, Deserialize)] +struct OpenAiResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAiErrorResponse { + error: OpenAiError, +} + +#[derive(Debug, Deserialize)] +pub struct OpenAiError { + message: String, + // type: String, + code: Option, +} + +impl Display for OpenAiError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.code { + Some(code) => write!(f, "{} ({})", self.message, code), + None => write!(f, "{}", self.message), + } + } +} + +#[derive(Debug, Deserialize)] +struct OpenAiEmbedding { + embedding: Embedding, + // object: String, + // index: usize, +} diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs new file mode 100644 index 000000000..2c0cf7924 --- /dev/null +++ b/milli/src/vector/settings.rs @@ -0,0 +1,308 @@ +use deserr::Deserr; +use serde::{Deserialize, Serialize}; + +use crate::prompt::{PromptData, PromptFallbackStrategy}; +use crate::update::Setting; +use crate::vector::hf::WeightSource; +use crate::vector::EmbeddingConfig; + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct EmbeddingSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")] + #[deserr(default, rename = "source")] + pub embedder_options: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub prompt: Setting, +} + +impl EmbeddingSettings { + pub fn apply(&mut self, new: Self) { + let EmbeddingSettings { embedder_options, prompt } = new; + self.embedder_options.apply(embedder_options); + self.prompt.apply(prompt); + } +} + +impl From for EmbeddingSettings { + fn from(value: EmbeddingConfig) -> Self { + Self { + embedder_options: Setting::Set(value.embedder_options.into()), + prompt: Setting::Set(value.prompt.into()), + } + } +} + +impl From for EmbeddingConfig { + fn from(value: EmbeddingSettings) -> Self { + let mut this = Self::default(); + let EmbeddingSettings { embedder_options, prompt } = value; + if let Some(embedder_options) = embedder_options.set() { + this.embedder_options = embedder_options.into(); + } + if let Some(prompt) = prompt.set() { + this.prompt = prompt.into(); + } + this + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct PromptSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub template: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub strategy: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub fallback: Setting, +} + +impl PromptSettings { + pub fn apply(&mut self, new: Self) { + let PromptSettings { template, strategy, fallback } = new; + self.template.apply(template); + self.strategy.apply(strategy); + self.fallback.apply(fallback); + } +} + +impl From for PromptSettings { + fn from(value: PromptData) -> Self { + Self { + template: Setting::Set(value.template), + strategy: Setting::Set(value.strategy), + fallback: Setting::Set(value.fallback), + } + } +} + +impl From for PromptData { + fn from(value: PromptSettings) -> Self { + let mut this = PromptData::default(); + let PromptSettings { template, strategy, fallback } = value; + if let Some(template) = template.set() { + this.template = template; + } + if let Some(strategy) = strategy.set() { + this.strategy = strategy; + } + if let Some(fallback) = fallback.set() { + this.fallback = fallback; + } + this + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +pub enum EmbedderSettings { + HuggingFace(Setting), + OpenAi(Setting), +} + +impl Deserr for EmbedderSettings +where + E: deserr::DeserializeError, +{ + fn deserialize_from_value( + value: deserr::Value, + location: deserr::ValuePointerRef, + ) -> Result { + match value { + deserr::Value::Map(map) => { + if deserr::Map::len(&map) != 1 { + return Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::Unexpected { + msg: format!( + "Expected a single field, got {} fields", + deserr::Map::len(&map) + ), + }, + location, + ))); + } + let mut it = deserr::Map::into_iter(map); + let (k, v) = it.next().unwrap(); + + match k.as_str() { + "huggingFace" => Ok(EmbedderSettings::HuggingFace(Setting::Set( + HfEmbedderSettings::deserialize_from_value( + v.into_value(), + location.push_key(&k), + )?, + ))), + "openAi" => Ok(EmbedderSettings::OpenAi(Setting::Set( + OpenAiEmbedderSettings::deserialize_from_value( + v.into_value(), + location.push_key(&k), + )?, + ))), + other => Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::UnknownKey { + key: other, + accepted: &["huggingFace", "openAi"], + }, + location, + ))), + } + } + _ => Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::IncorrectValueKind { + actual: value, + accepted: &[deserr::ValueKind::Map], + }, + location, + ))), + } + } +} + +impl Default for EmbedderSettings { + fn default() -> Self { + Self::HuggingFace(Default::default()) + } +} + +impl From for EmbedderSettings { + fn from(value: crate::vector::EmbedderOptions) -> Self { + match value { + crate::vector::EmbedderOptions::HuggingFace(hf) => { + Self::HuggingFace(Setting::Set(hf.into())) + } + crate::vector::EmbedderOptions::OpenAi(openai) => { + Self::OpenAi(Setting::Set(openai.into())) + } + } + } +} + +impl From for crate::vector::EmbedderOptions { + fn from(value: EmbedderSettings) -> Self { + match value { + EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()), + EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()), + EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()), + EmbedderSettings::OpenAi(_setting) => Self::OpenAi( + crate::vector::openai::EmbedderOptions::with_default_model(infer_api_key()), + ), + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct HfEmbedderSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub model: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub revision: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub weight_source: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub normalize_embeddings: Setting, +} + +impl HfEmbedderSettings { + pub fn apply(&mut self, new: Self) { + let HfEmbedderSettings { + model, + revision, + weight_source, + normalize_embeddings: normalize_embedding, + } = new; + self.model.apply(model); + self.revision.apply(revision); + self.weight_source.apply(weight_source); + self.normalize_embeddings.apply(normalize_embedding); + } +} + +impl From for HfEmbedderSettings { + fn from(value: crate::vector::hf::EmbedderOptions) -> Self { + Self { + model: Setting::Set(value.model), + revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet), + weight_source: Setting::Set(value.weight_source), + normalize_embeddings: Setting::Set(value.normalize_embeddings), + } + } +} + +impl From for crate::vector::hf::EmbedderOptions { + fn from(value: HfEmbedderSettings) -> Self { + let HfEmbedderSettings { model, revision, weight_source, normalize_embeddings } = value; + let mut this = Self::default(); + if let Some(model) = model.set() { + this.model = model; + } + if let Some(revision) = revision.set() { + this.revision = Some(revision); + } + if let Some(weight_source) = weight_source.set() { + this.weight_source = weight_source; + } + if let Some(normalize_embeddings) = normalize_embeddings.set() { + this.normalize_embeddings = normalize_embeddings; + } + this + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct OpenAiEmbedderSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub api_key: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub embedding_model: Setting, +} + +impl OpenAiEmbedderSettings { + pub fn apply(&mut self, new: Self) { + let Self { api_key, embedding_model: embedding_mode } = new; + self.api_key.apply(api_key); + self.embedding_model.apply(embedding_mode); + } +} + +impl From for OpenAiEmbedderSettings { + fn from(value: crate::vector::openai::EmbedderOptions) -> Self { + Self { + api_key: Setting::Set(value.api_key), + embedding_model: Setting::Set(value.embedding_model), + } + } +} + +impl From for crate::vector::openai::EmbedderOptions { + fn from(value: OpenAiEmbedderSettings) -> Self { + let OpenAiEmbedderSettings { api_key, embedding_model } = value; + Self { + api_key: api_key.set().unwrap_or_else(infer_api_key), + embedding_model: embedding_model.set().unwrap_or_default(), + } + } +} + +fn infer_api_key() -> String { + /// FIXME: get key from instance options? + std::env::var("MEILI_OPENAI_API_KEY").unwrap_or_default() +} From dde3a04679055d69838ce13cbe23a3a509e97a48 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 7 Dec 2023 13:33:15 +0100 Subject: [PATCH 03/28] WIP arroy integration --- Cargo.lock | 181 ++++++++++++++---- milli/Cargo.toml | 12 +- milli/src/index.rs | 111 +++++------ milli/src/lib.rs | 1 - milli/src/readable_slices.rs | 85 -------- milli/src/search/new/vector_sort.rs | 57 ++---- milli/src/update/clear_documents.rs | 8 +- .../src/update/index_documents/extract/mod.rs | 3 +- milli/src/update/index_documents/mod.rs | 4 + .../src/update/index_documents/typed_chunk.rs | 144 +++++++------- 10 files changed, 280 insertions(+), 326 deletions(-) delete mode 100644 milli/src/readable_slices.rs diff --git a/Cargo.lock b/Cargo.lock index a407244b1..ed6d0c291 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -380,6 +380,24 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arroy" +version = "0.1.0" +source = "git+https://github.com/meilisearch/arroy.git#4b59476f457e5443ff250ea10d40d8b66a692674" +dependencies = [ + "bytemuck", + "byteorder", + "heed", + "log", + "memmap2 0.9.0", + "ordered-float 4.2.0", + "rand", + "rayon", + "roaring", + "tempfile", + "thiserror", +] + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -537,9 +555,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.3" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" dependencies = [ "serde", ] @@ -629,9 +647,9 @@ dependencies = [ [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" @@ -687,7 +705,7 @@ dependencies = [ "byteorder", "gemm", "half 2.3.1", - "memmap2", + "memmap2 0.7.1", "num-traits", "num_cpus", "rand", @@ -1561,23 +1579,12 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.2" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b30f669a7961ef1631673d2766cc92f52d64f7ef354d4fe0ddfd30ed52f0f4f" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ - "errno-dragonfly", - "libc", - "windows-sys 0.48.0", -] - -[[package]] -name = "errno-dragonfly" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" -dependencies = [ - "cc", "libc", + "windows-sys 0.52.0", ] [[package]] @@ -2117,7 +2124,7 @@ version = "0.20.0-alpha.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9648a50991c86df7d00c56c268c27754fcf4c80be2ba57fc4a00dc928c6fe934" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.1", "bytemuck", "byteorder", "heed-traits", @@ -2868,7 +2875,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c619cdaa30bb84088963968bee12a45ea5fbbf355f2c021bcd15589f5ca494a" dependencies = [ "num_cpus", - "ordered-float", + "ordered-float 3.7.0", "parking_lot", "rand", "rayon", @@ -2911,7 +2918,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi", - "rustix 0.38.7", + "rustix 0.38.26", "windows-sys 0.48.0", ] @@ -3294,9 +3301,9 @@ checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" [[package]] name = "linux-raw-sys" -version = "0.4.5" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" [[package]] name = "liquid" @@ -3412,9 +3419,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.19" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "logging_timer" @@ -3543,7 +3550,7 @@ dependencies = [ "num_cpus", "obkv", "once_cell", - "ordered-float", + "ordered-float 3.7.0", "parking_lot", "permissive-json-pointer", "pin-project-lite", @@ -3618,7 +3625,7 @@ dependencies = [ "fst", "insta", "meili-snap", - "memmap2", + "memmap2 0.7.1", "milli", "roaring", "serde", @@ -3662,6 +3669,15 @@ dependencies = [ "stable_deref_trait", ] +[[package]] +name = "memmap2" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deaba38d7abf1d4cca21cc89e932e542ba2b9258664d2a9ef0e61512039c9375" +dependencies = [ + "libc", +] + [[package]] name = "memoffset" version = "0.9.0" @@ -3675,6 +3691,7 @@ dependencies = [ name = "milli" version = "1.5.1" dependencies = [ + "arroy", "big_s", "bimap", "bincode", @@ -3711,12 +3728,12 @@ dependencies = [ "maplit", "md5", "meili-snap", - "memmap2", + "memmap2 0.7.1", "mimalloc", "nolife", "obkv", "once_cell", - "ordered-float", + "ordered-float 3.7.0", "puffin", "rand", "rand_pcg", @@ -3983,7 +4000,7 @@ version = "0.10.59" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a257ad03cd8fb16ad4172fedf8094451e1af1c4b70097636ef2eac9a5f0cc33" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.1", "cfg-if", "foreign-types", "libc", @@ -4036,6 +4053,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ordered-float" +version = "4.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +dependencies = [ + "num-traits", +] + [[package]] name = "page_size" version = "0.5.0" @@ -4553,6 +4579,15 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_users" version = "0.4.3" @@ -4736,15 +4771,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.7" +version = "0.38.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "172891ebdceb05aa0005f533a6cbfca599ddd7d966f6f5d4d9b2e70478e70399" +checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.1", "errno", "libc", - "linux-raw-sys 0.4.5", - "windows-sys 0.48.0", + "linux-raw-sys 0.4.12", + "windows-sys 0.52.0", ] [[package]] @@ -5279,14 +5314,14 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.7.1" +version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc02fddf48964c42031a0b3fe0428320ecf3a73c401040fc0096f97794310651" +checksum = "7ef1adac450ad7f4b3c28589471ade84f25f731a7a0fe30d71dfa9f60fd808e5" dependencies = [ "cfg-if", "fastrand", - "redox_syscall 0.3.5", - "rustix 0.38.7", + "redox_syscall 0.4.1", + "rustix 0.38.26", "windows-sys 0.48.0", ] @@ -5997,6 +6032,15 @@ dependencies = [ "windows-targets 0.48.1", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", +] + [[package]] name = "windows-targets" version = "0.42.2" @@ -6027,6 +6071,21 @@ dependencies = [ "windows_x86_64_msvc 0.48.0", ] +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.2" @@ -6039,6 +6098,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" @@ -6051,6 +6116,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.42.2" @@ -6063,6 +6134,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.42.2" @@ -6075,6 +6152,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" @@ -6087,6 +6170,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" @@ -6099,6 +6188,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" @@ -6111,6 +6206,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + [[package]] name = "winnow" version = "0.5.4" diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 38931ca0f..0aee03b2f 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -89,6 +89,8 @@ reqwest = { version = "0.11.16", features = [ ], default-features = false } tiktoken-rs = "0.5.7" liquid = "0.26.4" +arroy = { git = "https://github.com/meilisearch/arroy.git", version = "0.1.0" } +rand = "0.8.5" [dev-dependencies] mimalloc = { version = "0.1.37", default-features = false } @@ -100,15 +102,7 @@ meili-snap = { path = "../meili-snap" } rand = { version = "0.8.5", features = ["small_rng"] } [features] -all-tokenizations = [ - "charabia/chinese", - "charabia/hebrew", - "charabia/japanese", - "charabia/thai", - "charabia/korean", - "charabia/greek", - "charabia/khmer", -] +all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"] # Use POSIX semaphores instead of SysV semaphores in LMDB # For more information on this feature, see heed's Cargo.toml diff --git a/milli/src/index.rs b/milli/src/index.rs index 307d87906..c494f2f2b 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -22,7 +22,6 @@ use crate::heed_codec::{ BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec, }; use crate::proximity::ProximityPrecision; -use crate::readable_slices::ReadableSlices; use crate::vector::EmbeddingConfig; use crate::{ default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, @@ -49,10 +48,6 @@ pub mod main_key { pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map"; pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids"; pub const GEO_RTREE_KEY: &str = "geo-rtree"; - /// The prefix of the key that is used to store the, potential big, HNSW structure. - /// It is concatenated with a big-endian encoded number (non-human readable). - /// e.g. vector-hnsw0x0032. - pub const VECTOR_HNSW_KEY_PREFIX: &str = "vector-hnsw"; pub const PRIMARY_KEY_KEY: &str = "primary-key"; pub const SEARCHABLE_FIELDS_KEY: &str = "searchable-fields"; pub const USER_DEFINED_SEARCHABLE_FIELDS_KEY: &str = "user-defined-searchable-fields"; @@ -75,6 +70,7 @@ pub mod main_key { pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by"; pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits"; pub const PROXIMITY_PRECISION: &str = "proximity-precision"; + pub const VECTOR_UNAVAILABLE_VECTOR_IDS: &str = "vector-unavailable-vector-ids"; pub const EMBEDDING_CONFIGS: &str = "embedding_configs"; } @@ -102,6 +98,9 @@ pub mod db_name { pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s"; pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings"; pub const VECTOR_ID_DOCID: &str = "vector-id-docids"; + pub const VECTOR_DOCID_IDS: &str = "vector-docid-ids"; + pub const VECTOR_EMBEDDER_CATEGORY_ID: &str = "vector-embedder-category-id"; + pub const VECTOR_ARROY: &str = "vector-arroy"; pub const DOCUMENTS: &str = "documents"; pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids"; } @@ -168,8 +167,16 @@ pub struct Index { /// Maps the document id, the facet field id and the strings. pub field_id_docid_facet_strings: Database, - /// Maps a vector id to the document id that have it. + /// Maps a vector id to its document id. pub vector_id_docid: Database, + /// Maps a doc id to its vector ids. + pub docid_vector_ids: Database, + + /// Maps an embedder name to its id in the arroy store. + pub embedder_category_id: Database, + + /// Vector store based on arroy™. + pub vector_arroy: arroy::Database, /// Maps the document id to the document as an obkv store. pub(crate) documents: Database, @@ -184,7 +191,7 @@ impl Index { ) -> Result { use db_name::*; - options.max_dbs(24); + options.max_dbs(27); let env = options.open(path)?; let mut wtxn = env.write_txn()?; @@ -224,7 +231,13 @@ impl Index { env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?; let field_id_docid_facet_strings = env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?; + // vector stuff let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?; + let docid_vector_ids = env.create_database(&mut wtxn, Some(VECTOR_DOCID_IDS))?; + let embedder_category_id = + env.create_database(&mut wtxn, Some(VECTOR_EMBEDDER_CATEGORY_ID))?; + let vector_arroy = env.create_database(&mut wtxn, Some(VECTOR_ARROY))?; + let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?; wtxn.commit()?; @@ -255,6 +268,9 @@ impl Index { field_id_docid_facet_f64s, field_id_docid_facet_strings, vector_id_docid, + vector_arroy, + docid_vector_ids, + embedder_category_id, documents, }) } @@ -477,63 +493,6 @@ impl Index { None => Ok(RoaringBitmap::new()), } } - - /* vector HNSW */ - - /// Writes the provided `hnsw`. - pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> { - // We must delete all the chunks before we write the new HNSW chunks. - self.delete_vector_hnsw(wtxn)?; - - let chunk_size = 1024 * 1024 * (1024 + 512); // 1.5 GiB - let bytes = bincode::serialize(hnsw).map_err(Into::into).map_err(heed::Error::Encoding)?; - for (i, chunk) in bytes.chunks(chunk_size).enumerate() { - let i = i as u32; - let mut key = main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes().to_vec(); - key.extend_from_slice(&i.to_be_bytes()); - self.main.remap_types::().put(wtxn, &key, chunk)?; - } - Ok(()) - } - - /// Delete the `hnsw`. - pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result { - let mut iter = self - .main - .remap_types::() - .prefix_iter_mut(wtxn, main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes())?; - let mut deleted = false; - while iter.next().transpose()?.is_some() { - // We do not keep a reference to the key or the value. - unsafe { deleted |= iter.del_current()? }; - } - Ok(deleted) - } - - /// Returns the `hnsw`. - pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result> { - let mut slices = Vec::new(); - for result in self - .main - .remap_types::() - .prefix_iter(rtxn, main_key::VECTOR_HNSW_KEY_PREFIX)? - { - let (_, slice) = result?; - slices.push(slice); - } - - if slices.is_empty() { - Ok(None) - } else { - let readable_slices: ReadableSlices<_> = slices.into_iter().collect(); - Ok(Some( - bincode::deserialize_from(readable_slices) - .map_err(Into::into) - .map_err(heed::Error::Decoding)?, - )) - } - } - /* field distribution */ /// Writes the field distribution which associates every field name with @@ -1557,6 +1516,30 @@ impl Index { .get(rtxn, main_key::EMBEDDING_CONFIGS)? .unwrap_or_default()) } + + pub(crate) fn put_unavailable_vector_ids( + &self, + wtxn: &mut RwTxn<'_>, + unavailable_vector_ids: RoaringBitmap, + ) -> heed::Result<()> { + self.main.remap_types::().put( + wtxn, + main_key::VECTOR_UNAVAILABLE_VECTOR_IDS, + &unavailable_vector_ids, + ) + } + + pub(crate) fn delete_unavailable_vector_ids(&self, wtxn: &mut RwTxn<'_>) -> heed::Result { + self.main.remap_key_type::().delete(wtxn, main_key::VECTOR_UNAVAILABLE_VECTOR_IDS) + } + + pub fn unavailable_vector_ids(&self, rtxn: &RoTxn<'_>) -> Result { + Ok(self + .main + .remap_types::() + .get(rtxn, main_key::VECTOR_UNAVAILABLE_VECTOR_IDS)? + .unwrap_or_default()) + } } #[cfg(test)] diff --git a/milli/src/lib.rs b/milli/src/lib.rs index b3c15e205..b865747e0 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -19,7 +19,6 @@ pub mod heed_codec; pub mod index; pub mod prompt; pub mod proximity; -mod readable_slices; pub mod score_details; mod search; pub mod update; diff --git a/milli/src/readable_slices.rs b/milli/src/readable_slices.rs deleted file mode 100644 index 7f5be214f..000000000 --- a/milli/src/readable_slices.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::io::{self, Read}; -use std::iter::FromIterator; - -pub struct ReadableSlices { - inner: Vec, - pos: u64, -} - -impl FromIterator for ReadableSlices { - fn from_iter>(iter: T) -> Self { - ReadableSlices { inner: iter.into_iter().collect(), pos: 0 } - } -} - -impl> Read for ReadableSlices { - fn read(&mut self, mut buf: &mut [u8]) -> io::Result { - let original_buf_len = buf.len(); - - // We explore the list of slices to find the one where we must start reading. - let mut pos = self.pos; - let index = match self - .inner - .iter() - .map(|s| s.as_ref().len() as u64) - .position(|size| pos.checked_sub(size).map(|p| pos = p).is_none()) - { - Some(index) => index, - None => return Ok(0), - }; - - let mut inner_pos = pos as usize; - for slice in &self.inner[index..] { - let slice = &slice.as_ref()[inner_pos..]; - - if buf.len() > slice.len() { - // We must exhaust the current slice and go to the next one there is not enough here. - buf[..slice.len()].copy_from_slice(slice); - buf = &mut buf[slice.len()..]; - inner_pos = 0; - } else { - // There is enough in this slice to fill the remaining bytes of the buffer. - // Let's break just after filling it. - buf.copy_from_slice(&slice[..buf.len()]); - buf = &mut []; - break; - } - } - - let written = original_buf_len - buf.len(); - self.pos += written as u64; - Ok(written) - } -} - -#[cfg(test)] -mod test { - use std::io::Read; - - use super::ReadableSlices; - - #[test] - fn basic() { - let data: Vec<_> = (0..100).collect(); - let splits: Vec<_> = data.chunks(3).collect(); - let mut rdslices: ReadableSlices<_> = splits.into_iter().collect(); - - let mut output = Vec::new(); - let length = rdslices.read_to_end(&mut output).unwrap(); - assert_eq!(length, data.len()); - assert_eq!(output, data); - } - - #[test] - fn small_reads() { - let data: Vec<_> = (0..u8::MAX).collect(); - let splits: Vec<_> = data.chunks(27).collect(); - let mut rdslices: ReadableSlices<_> = splits.into_iter().collect(); - - let buffer = &mut [0; 45]; - let length = rdslices.read(buffer).unwrap(); - let expected: Vec<_> = (0..buffer.len() as u8).collect(); - assert_eq!(length, buffer.len()); - assert_eq!(buffer, &expected[..]); - } -} diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs index 831ed45cd..59b7a72c2 100644 --- a/milli/src/search/new/vector_sort.rs +++ b/milli/src/search/new/vector_sort.rs @@ -11,64 +11,31 @@ use crate::index::Hnsw; use crate::score_details::{self, ScoreDetails}; use crate::{Result, SearchContext, SearchLogger, UserError}; -pub struct VectorSort { +pub struct VectorSort<'ctx, Q: RankingRuleQueryTrait> { query: Option, target: Vec, vector_candidates: RoaringBitmap, - scope: nolife::DynBoxScope, + reader: arroy::Reader<'ctx, arroy::distances::DotProduct>, + limit: usize, } -type Item<'a> = instant_distance::Item<'a, NDotProductPoint>; -type SearchFut = Pin>>; - -struct SearchFamily; -impl<'a> nolife::Family<'a> for SearchFamily { - type Family = Box> + 'a>; -} - -async fn search_scope( - mut time_capsule: nolife::TimeCapsule, - hnsw: Hnsw, - target: Vec, -) -> nolife::Never { - let mut search = instant_distance::Search::default(); - let it = Box::new(hnsw.search(&NDotProductPoint::new(target), &mut search)); - let mut it: Box> = it; - loop { - time_capsule.freeze(&mut it).await; - } -} - -impl VectorSort { +impl<'ctx, Q: RankingRuleQueryTrait> VectorSort<'ctx, Q> { pub fn new( - ctx: &SearchContext, + ctx: &'ctx SearchContext, target: Vec, vector_candidates: RoaringBitmap, + limit: usize, ) -> Result { - let hnsw = - ctx.index.vector_hnsw(ctx.txn)?.unwrap_or(Hnsw::builder().build_hnsw(Vec::default()).0); - - if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() { - if target.len() != expected_size { - return Err(UserError::InvalidVectorDimensions { - expected: expected_size, - found: target.len(), - } - .into()); - } - } + /// FIXME? what to do in case of missing metadata + let reader = arroy::Reader::open(ctx.txn, 0, ctx.index.vector_arroy)?; let target_clone = target.clone(); - let producer = move |time_capsule| -> SearchFut { - Box::pin(search_scope(time_capsule, hnsw, target_clone)) - }; - let scope = DynBoxScope::new(producer); - Ok(Self { query: None, target, vector_candidates, scope }) + Ok(Self { query: None, target, vector_candidates, reader, limit }) } } -impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { +impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<'ctx, Q> { fn id(&self) -> String { "vector_sort".to_owned() } @@ -108,11 +75,11 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { }), })); } - - let scope = &mut self.scope; let target = &self.target; let vector_candidates = &self.vector_candidates; + let result = self.reader.nns_by_vector(ctx.txn, &target, count, search_k, candidates) + scope.enter(|it| { for item in it.by_ref() { let item: Item = item; diff --git a/milli/src/update/clear_documents.rs b/milli/src/update/clear_documents.rs index 59adda3e8..3b1a6c5d8 100644 --- a/milli/src/update/clear_documents.rs +++ b/milli/src/update/clear_documents.rs @@ -43,6 +43,9 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { field_id_docid_facet_f64s, field_id_docid_facet_strings, vector_id_docid, + vector_arroy, + docid_vector_ids, + embedder_category_id: _, documents, } = self.index; @@ -58,7 +61,6 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?; self.index.delete_geo_rtree(self.wtxn)?; self.index.delete_geo_faceted_documents_ids(self.wtxn)?; - self.index.delete_vector_hnsw(self.wtxn)?; // Clear the other databases. external_documents_ids.clear(self.wtxn)?; @@ -82,7 +84,11 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { facet_id_string_docids.clear(self.wtxn)?; field_id_docid_facet_f64s.clear(self.wtxn)?; field_id_docid_facet_strings.clear(self.wtxn)?; + // vector + vector_arroy.clear(self.wtxn)?; vector_id_docid.clear(self.wtxn)?; + docid_vector_ids.clear(self.wtxn)?; + documents.clear(self.wtxn)?; Ok(number_of_documents) diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 40b0dcd61..06bc8b609 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -312,7 +312,8 @@ fn send_original_documents_data( lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { remove_vectors, embeddings, - expected_dimension, + /// FIXME: compute an expected dimension from the manual vectors if any + expected_dimension: expected_dimension.unwrap(), manual_vectors, })) } diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 76848b628..eaac26dd3 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -15,6 +15,7 @@ use crossbeam_channel::{Receiver, Sender}; use heed::types::Str; use heed::Database; use log::debug; +use rand::SeedableRng; use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; use slice_group_by::GroupBy; @@ -489,6 +490,9 @@ where } } + let writer = arroy::Writer::prepare(self.wtxn, self.index.vector_arroy, 0, 0)?; + writer.build(self.wtxn, &mut rand::rngs::StdRng::from_entropy(), None)?; + // We write the field distribution into the main database self.index.put_field_distribution(self.wtxn, &field_distribution)?; diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 36d230d00..bc82518ca 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::convert::TryInto; use std::fs::File; use std::io::{self, BufReader}; @@ -27,6 +27,7 @@ use crate::index::Hnsw; use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd}; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; +use crate::update::{available_documents_ids, AvailableDocumentsIds}; use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError}; pub(crate) enum TypedChunk { @@ -50,7 +51,7 @@ pub(crate) enum TypedChunk { VectorPoints { remove_vectors: grenad::Reader>, embeddings: Option>>, - expected_dimension: Option, + expected_dimension: usize, manual_vectors: grenad::Reader>, }, ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), @@ -106,7 +107,7 @@ impl TypedChunk { format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) } TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => { - format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension.unwrap_or_default()) + format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension) } TypedChunk::ScriptLanguageDocids(sl_map) => { format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) @@ -373,46 +374,53 @@ pub(crate) fn write_typed_chunk_into_index( return Ok((RoaringBitmap::new(), is_merged_database)); } - let mut docid_vectors_map: HashMap>>> = - HashMap::new(); - - // We extract and store the previous vectors - if let Some(hnsw) = index.vector_hnsw(wtxn)? { - for (pid, point) in hnsw.iter() { - let pid_key = pid.into_inner(); - let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap(); - let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect(); - docid_vectors_map.entry(docid).or_default().insert(vector); - } - } + let mut unavailable_vector_ids = index.unavailable_vector_ids(&wtxn)?; + /// FIXME: allow customizing distance + /// FIXME: allow customizing index + let writer = arroy::Writer::prepare(wtxn, index.vector_arroy, 0, expected_dimension)?; // remove vectors for docids we want them removed let mut cursor = remove_vectors.into_cursor()?; while let Some((key, _)) = cursor.move_on_next()? { let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); - docid_vectors_map.remove(&docid); + let Some(to_remove_vector_ids) = index.docid_vector_ids.get(&wtxn, &docid)? else { + continue; + }; + unavailable_vector_ids -= to_remove_vector_ids; + + for item in to_remove_vector_ids { + writer.del_item(wtxn, item)?; + } } + let mut available_vector_ids = + AvailableDocumentsIds::from_documents_ids(&unavailable_vector_ids); // add generated embeddings - if let Some((embeddings, expected_dimension)) = embeddings.zip(expected_dimension) { + if let Some(embeddings) = embeddings { let mut cursor = embeddings.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? { let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); - let data: Vec> = - pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); + let data = pod_collect_to_vec(value); // it is a code error to have embeddings and not expected_dimension let embeddings = crate::vector::Embeddings::from_inner(data, expected_dimension) // code error if we somehow got the wrong dimension .unwrap(); - let mut set = HashSet::new(); + let mut new_vector_ids = RoaringBitmap::new(); for embedding in embeddings.iter() { - set.insert(embedding.to_vec()); - } + /// FIXME: error when you get over 9000 + let next_vector_id = available_vector_ids.next().unwrap(); + unavailable_vector_ids.insert(next_vector_id); - docid_vectors_map.insert(docid, set); + new_vector_ids.insert(next_vector_id); + + index.vector_id_docid.put(wtxn, &next_vector_id, &docid)?; + + writer.add_item(wtxn, next_vector_id, embedding)?; + } + index.docid_vector_ids.put(wtxn, &docid, &new_vector_ids)?; } } @@ -425,68 +433,44 @@ pub(crate) fn write_typed_chunk_into_index( let vector_deladd_obkv = KvReaderDelAdd::new(value); if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { - // convert the vector back to a Vec - let vector: Vec> = - pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); - docid_vectors_map.entry(docid).and_modify(|v| { - if !v.remove(&vector) { - error!("Unable to delete the vector: {:?}", vector); + let vector = pod_collect_to_vec(value); + let Some(mut docid_vector_ids) = index.docid_vector_ids.get(&wtxn, &docid)? + else { + error!("Unable to delete the vector: {:?}", vector); + continue; + }; + for item in docid_vector_ids { + /// FIXME: comparing the vectors by equality is inefficient, and dangerous by perfect equality + let candidate = writer.item_vector(&wtxn, item)?.expect("Inconsistent dbs"); + if candidate == vector { + writer.del_item(wtxn, item)?; + unavailable_vector_ids.remove(item); + index.vector_id_docid.delete(wtxn, &item)?; + docid_vector_ids.remove(item); + break; } - }); - } - if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { - // convert the vector back to a Vec - let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); - docid_vectors_map.entry(docid).and_modify(|v| { - v.insert(vector); - }); - } - } - - // Extract the most common vector dimension - let expected_dimension_size = { - let mut dims = HashMap::new(); - docid_vectors_map - .values() - .flat_map(|v| v.iter()) - .for_each(|v| *dims.entry(v.len()).or_insert(0) += 1); - dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len) - }; - - // Ensure that the vector lengths are correct and - // prepare the vectors before inserting them in the HNSW. - let mut points = Vec::new(); - let mut docids = Vec::new(); - for (docid, vector) in docid_vectors_map - .into_iter() - .flat_map(|(docid, vectors)| std::iter::repeat(docid).zip(vectors)) - { - if expected_dimension_size.map_or(false, |expected| expected != vector.len()) { - return Err(UserError::InvalidVectorDimensions { - expected: expected_dimension_size.unwrap_or(vector.len()), - found: vector.len(), } - .into()); - } else { - let vector = vector.into_iter().map(OrderedFloat::into_inner).collect(); - points.push(NDotProductPoint::new(vector)); - docids.push(docid); + index.docid_vector_ids.put(wtxn, &docid, &docid_vector_ids)?; + } + let mut available_vector_ids = + AvailableDocumentsIds::from_documents_ids(&unavailable_vector_ids); + + if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { + let vector = pod_collect_to_vec(value); + let next_vector_id = available_vector_ids.next().unwrap(); + + writer.add_item(wtxn, next_vector_id, &vector)?; + unavailable_vector_ids.insert(next_vector_id); + index.vector_id_docid.put(wtxn, &next_vector_id, &docid)?; + let mut docid_vector_ids = + index.docid_vector_ids.get(&wtxn, &docid)?.unwrap_or_default(); + docid_vector_ids.insert(next_vector_id); + index.docid_vector_ids.put(wtxn, &docid, &docid_vector_ids)?; } } - let hnsw_length = points.len(); - let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points); - - assert_eq!(docids.len(), pids.len()); - - // Store the vectors in the point-docid relation database - index.vector_id_docid.clear(wtxn)?; - for (docid, pid) in docids.into_iter().zip(pids) { - index.vector_id_docid.put(wtxn, &pid.into_inner(), &docid)?; - } - - log::debug!("There are {} entries in the HNSW so far", hnsw_length); - index.put_vector_hnsw(wtxn, &new_hnsw)?; + log::debug!("There are {} entries in the arroy so far", unavailable_vector_ids.len()); + index.put_unavailable_vector_ids(wtxn, unavailable_vector_ids)?; } TypedChunk::ScriptLanguageDocids(sl_map) => { for (key, (deletion, addition)) in sl_map { From cb4ebe163e4eae5004e8e51b4154b5c54fb22a00 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 7 Dec 2023 17:03:10 +0100 Subject: [PATCH 04/28] WIP --- Cargo.lock | 2 +- milli/src/error.rs | 20 +++ milli/src/index.rs | 43 +------ milli/src/search/new/mod.rs | 11 +- milli/src/search/new/vector_sort.rs | 117 ++++++++++-------- milli/src/update/clear_documents.rs | 4 - milli/src/update/index_documents/mod.rs | 38 +++++- .../src/update/index_documents/typed_chunk.rs | 107 ++++++++-------- 8 files changed, 185 insertions(+), 157 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ed6d0c291..fba78b3b6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -383,7 +383,7 @@ dependencies = [ [[package]] name = "arroy" version = "0.1.0" -source = "git+https://github.com/meilisearch/arroy.git#4b59476f457e5443ff250ea10d40d8b66a692674" +source = "git+https://github.com/meilisearch/arroy.git#0079af0ec960bc9c51dd66e898a6b5e980cbb083" dependencies = [ "bytemuck", "byteorder", diff --git a/milli/src/error.rs b/milli/src/error.rs index 032fd63a7..3d07590b0 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -61,6 +61,8 @@ pub enum InternalError { AbortedIndexation, #[error("The matching words list contains at least one invalid member.")] InvalidMatchingWords, + #[error(transparent)] + ArroyError(#[from] arroy::Error), } #[derive(Error, Debug)] @@ -190,6 +192,24 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), } +impl From for Error { + fn from(value: arroy::Error) -> Self { + match value { + arroy::Error::Heed(heed) => heed.into(), + arroy::Error::Io(io) => io.into(), + arroy::Error::InvalidVecDimension { expected, received } => { + Error::UserError(UserError::InvalidVectorDimensions { expected, found: received }) + } + arroy::Error::DatabaseFull + | arroy::Error::InvalidItemAppend + | arroy::Error::UnmatchingDistance { .. } + | arroy::Error::MissingMetadata => { + Error::InternalError(InternalError::ArroyError(value)) + } + } + } +} + #[derive(Error, Debug)] pub enum GeoError { #[error("The `_geo` field in the document with the id: `{document_id}` is not an object. Was expecting an object with the `_geo.lat` and `_geo.lng` fields but instead got `{value}`.")] diff --git a/milli/src/index.rs b/milli/src/index.rs index c494f2f2b..c5e190d38 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -70,7 +70,6 @@ pub mod main_key { pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by"; pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits"; pub const PROXIMITY_PRECISION: &str = "proximity-precision"; - pub const VECTOR_UNAVAILABLE_VECTOR_IDS: &str = "vector-unavailable-vector-ids"; pub const EMBEDDING_CONFIGS: &str = "embedding_configs"; } @@ -97,8 +96,6 @@ pub mod db_name { pub const FACET_ID_STRING_FST: &str = "facet-id-string-fst"; pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s"; pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings"; - pub const VECTOR_ID_DOCID: &str = "vector-id-docids"; - pub const VECTOR_DOCID_IDS: &str = "vector-docid-ids"; pub const VECTOR_EMBEDDER_CATEGORY_ID: &str = "vector-embedder-category-id"; pub const VECTOR_ARROY: &str = "vector-arroy"; pub const DOCUMENTS: &str = "documents"; @@ -167,16 +164,10 @@ pub struct Index { /// Maps the document id, the facet field id and the strings. pub field_id_docid_facet_strings: Database, - /// Maps a vector id to its document id. - pub vector_id_docid: Database, - /// Maps a doc id to its vector ids. - pub docid_vector_ids: Database, - /// Maps an embedder name to its id in the arroy store. - pub embedder_category_id: Database, - + pub embedder_category_id: Database, /// Vector store based on arroy™. - pub vector_arroy: arroy::Database, + pub vector_arroy: arroy::Database, /// Maps the document id to the document as an obkv store. pub(crate) documents: Database, @@ -191,7 +182,7 @@ impl Index { ) -> Result { use db_name::*; - options.max_dbs(27); + options.max_dbs(25); let env = options.open(path)?; let mut wtxn = env.write_txn()?; @@ -232,8 +223,6 @@ impl Index { let field_id_docid_facet_strings = env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?; // vector stuff - let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?; - let docid_vector_ids = env.create_database(&mut wtxn, Some(VECTOR_DOCID_IDS))?; let embedder_category_id = env.create_database(&mut wtxn, Some(VECTOR_EMBEDDER_CATEGORY_ID))?; let vector_arroy = env.create_database(&mut wtxn, Some(VECTOR_ARROY))?; @@ -267,9 +256,7 @@ impl Index { facet_id_is_empty_docids, field_id_docid_facet_f64s, field_id_docid_facet_strings, - vector_id_docid, vector_arroy, - docid_vector_ids, embedder_category_id, documents, }) @@ -1516,30 +1503,6 @@ impl Index { .get(rtxn, main_key::EMBEDDING_CONFIGS)? .unwrap_or_default()) } - - pub(crate) fn put_unavailable_vector_ids( - &self, - wtxn: &mut RwTxn<'_>, - unavailable_vector_ids: RoaringBitmap, - ) -> heed::Result<()> { - self.main.remap_types::().put( - wtxn, - main_key::VECTOR_UNAVAILABLE_VECTOR_IDS, - &unavailable_vector_ids, - ) - } - - pub(crate) fn delete_unavailable_vector_ids(&self, wtxn: &mut RwTxn<'_>) -> heed::Result { - self.main.remap_key_type::().delete(wtxn, main_key::VECTOR_UNAVAILABLE_VECTOR_IDS) - } - - pub fn unavailable_vector_ids(&self, rtxn: &RoTxn<'_>) -> Result { - Ok(self - .main - .remap_types::() - .get(rtxn, main_key::VECTOR_UNAVAILABLE_VECTOR_IDS)? - .unwrap_or_default()) - } } #[cfg(test)] diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index 372c89601..ad5c59f99 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -262,6 +262,7 @@ fn get_ranking_rules_for_vector<'ctx>( ctx: &SearchContext<'ctx>, sort_criteria: &Option>, geo_strategy: geo_sort::Strategy, + limit_plus_offset: usize, target: &[f32], ) -> Result>> { // query graph search @@ -283,7 +284,12 @@ fn get_ranking_rules_for_vector<'ctx>( | crate::Criterion::Exactness => { if !vector { let vector_candidates = ctx.index.documents_ids(ctx.txn)?; - let vector_sort = VectorSort::new(ctx, target.to_vec(), vector_candidates)?; + let vector_sort = VectorSort::new( + ctx, + target.to_vec(), + vector_candidates, + limit_plus_offset, + )?; ranking_rules.push(Box::new(vector_sort)); vector = true; } @@ -509,7 +515,8 @@ pub fn execute_vector_search( /// FIXME: input universe = universe & documents_with_vectors // for now if we're computing embeddings for ALL documents, we can assume that this is just universe - let ranking_rules = get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, vector)?; + let ranking_rules = + get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, from + length, vector)?; let mut placeholder_search_logger = logger::DefaultSearchLogger; let placeholder_search_logger: &mut dyn SearchLogger = diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs index 59b7a72c2..9bf13c631 100644 --- a/milli/src/search/new/vector_sort.rs +++ b/milli/src/search/new/vector_sort.rs @@ -1,48 +1,83 @@ -use std::future::Future; use std::iter::FromIterator; -use std::pin::Pin; -use nolife::DynBoxScope; +use ordered_float::OrderedFloat; use roaring::RoaringBitmap; use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; -use crate::distance::NDotProductPoint; -use crate::index::Hnsw; use crate::score_details::{self, ScoreDetails}; -use crate::{Result, SearchContext, SearchLogger, UserError}; +use crate::{DocumentId, Result, SearchContext, SearchLogger}; -pub struct VectorSort<'ctx, Q: RankingRuleQueryTrait> { +pub struct VectorSort { query: Option, target: Vec, vector_candidates: RoaringBitmap, - reader: arroy::Reader<'ctx, arroy::distances::DotProduct>, + cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec)>, limit: usize, } -impl<'ctx, Q: RankingRuleQueryTrait> VectorSort<'ctx, Q> { +impl VectorSort { pub fn new( - ctx: &'ctx SearchContext, + _ctx: &SearchContext, target: Vec, vector_candidates: RoaringBitmap, limit: usize, ) -> Result { - /// FIXME? what to do in case of missing metadata - let reader = arroy::Reader::open(ctx.txn, 0, ctx.index.vector_arroy)?; + Ok(Self { + query: None, + target, + vector_candidates, + cached_sorted_docids: Default::default(), + limit, + }) + } - let target_clone = target.clone(); + fn fill_buffer(&mut self, ctx: &mut SearchContext<'_>) -> Result<()> { + let readers: std::result::Result, _> = (0..=u8::MAX) + .map_while(|k| { + arroy::Reader::open(ctx.txn, k.into(), ctx.index.vector_arroy) + .map(Some) + .or_else(|e| match e { + arroy::Error::MissingMetadata => Ok(None), + e => Err(e), + }) + .transpose() + }) + .collect(); - Ok(Self { query: None, target, vector_candidates, reader, limit }) + let readers = readers?; + + let target = &self.target; + let mut results = Vec::new(); + + for reader in readers.iter() { + let nns_by_vector = reader.nns_by_vector( + ctx.txn, + &target, + self.limit, + None, + Some(&self.vector_candidates), + )?; + let vectors: std::result::Result, _> = nns_by_vector + .iter() + .map(|(docid, _)| reader.item_vector(ctx.txn, *docid).transpose().unwrap()) + .collect(); + let vectors = vectors?; + results.extend(nns_by_vector.into_iter().zip(vectors).map(|((x, y), z)| (x, y, z))); + } + results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance)); + self.cached_sorted_docids = results.into_iter(); + Ok(()) } } -impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<'ctx, Q> { +impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { fn id(&self) -> String { "vector_sort".to_owned() } fn start_iteration( &mut self, - _ctx: &mut SearchContext<'ctx>, + ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger, universe: &RoaringBitmap, query: &Q, @@ -51,7 +86,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<'ctx, Q self.query = Some(query.clone()); self.vector_candidates &= universe; - + self.fill_buffer(ctx)?; Ok(()) } @@ -75,40 +110,24 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<'ctx, Q }), })); } - let target = &self.target; - let vector_candidates = &self.vector_candidates; - let result = self.reader.nns_by_vector(ctx.txn, &target, count, search_k, candidates) - - scope.enter(|it| { - for item in it.by_ref() { - let item: Item = item; - let index = item.pid.into_inner(); - let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap(); - - if vector_candidates.contains(docid) { - return Ok(Some(RankingRuleOutput { - query, - candidates: RoaringBitmap::from_iter([docid]), - score: ScoreDetails::Vector(score_details::Vector { - target_vector: target.clone(), - value_similarity: Some(( - item.point.clone().into_inner(), - 1.0 - item.distance, - )), - }), - })); - } + while let Some((docid, distance, vector)) = self.cached_sorted_docids.next() { + if self.vector_candidates.contains(docid) { + return Ok(Some(RankingRuleOutput { + query, + candidates: RoaringBitmap::from_iter([docid]), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: self.target.clone(), + value_similarity: Some((vector, 1.0 - distance)), + }), + })); } - Ok(Some(RankingRuleOutput { - query, - candidates: universe.clone(), - score: ScoreDetails::Vector(score_details::Vector { - target_vector: target.clone(), - value_similarity: None, - }), - })) - }) + } + + // if we got out of this loop it means we've exhausted our cache. + // we need to refill it and run the function again. + self.fill_buffer(ctx)?; + self.next_bucket(ctx, _logger, universe) } fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger) { diff --git a/milli/src/update/clear_documents.rs b/milli/src/update/clear_documents.rs index 3b1a6c5d8..a6c7ff2b1 100644 --- a/milli/src/update/clear_documents.rs +++ b/milli/src/update/clear_documents.rs @@ -42,9 +42,7 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { facet_id_is_empty_docids, field_id_docid_facet_f64s, field_id_docid_facet_strings, - vector_id_docid, vector_arroy, - docid_vector_ids, embedder_category_id: _, documents, } = self.index; @@ -86,8 +84,6 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { field_id_docid_facet_strings.clear(self.wtxn)?; // vector vector_arroy.clear(self.wtxn)?; - vector_id_docid.clear(self.wtxn)?; - docid_vector_ids.clear(self.wtxn)?; documents.clear(self.wtxn)?; diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index eaac26dd3..472c77111 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -418,7 +418,7 @@ where } // needs to be dropped to avoid channel waiting lock. - drop(lmdb_writer_sx) + drop(lmdb_writer_sx); }); let index_is_empty = self.index.number_of_documents(self.wtxn)? == 0; @@ -435,6 +435,8 @@ where let mut word_docids = None; let mut exact_word_docids = None; + let mut dimension = None; + for result in lmdb_writer_rx { if (self.should_abort)() { return Err(Error::InternalError(InternalError::AbortedIndexation)); @@ -464,6 +466,20 @@ where word_position_docids = Some(cloneable_chunk); TypedChunk::WordPositionDocids(chunk) } + TypedChunk::VectorPoints { + expected_dimension, + remove_vectors, + embeddings, + manual_vectors, + } => { + dimension = Some(expected_dimension); + TypedChunk::VectorPoints { + remove_vectors, + embeddings, + expected_dimension, + manual_vectors, + } + } otherwise => otherwise, }; @@ -490,9 +506,6 @@ where } } - let writer = arroy::Writer::prepare(self.wtxn, self.index.vector_arroy, 0, 0)?; - writer.build(self.wtxn, &mut rand::rngs::StdRng::from_entropy(), None)?; - // We write the field distribution into the main database self.index.put_field_distribution(self.wtxn, &field_distribution)?; @@ -500,6 +513,23 @@ where self.index.put_primary_key(self.wtxn, &primary_key)?; let number_of_documents = self.index.number_of_documents(self.wtxn)?; + if let Some(dimension) = dimension { + let wtxn = &mut *self.wtxn; + let vector_arroy = self.index.vector_arroy; + pool.install(|| { + /// FIXME: do for each embedder + let mut rng = rand::rngs::StdRng::from_entropy(); + for k in 0..=u8::MAX { + let writer = arroy::Writer::prepare(wtxn, vector_arroy, k.into(), dimension)?; + if writer.is_empty(wtxn)? { + break; + } + writer.build(wtxn, &mut rng, None)?; + } + Result::Ok(()) + })?; + } + self.execute_prefix_databases( word_docids, exact_word_docids, diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index bc82518ca..82397ed3d 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -8,9 +8,7 @@ use charabia::{Language, Script}; use grenad::MergerBuilder; use heed::types::Bytes; use heed::{PutFlags, RwTxn}; -use log::error; use obkv::{KvReader, KvWriter}; -use ordered_float::OrderedFloat; use roaring::RoaringBitmap; use super::helpers::{ @@ -18,16 +16,12 @@ use super::helpers::{ valid_lmdb_key, CursorClonableMmap, }; use super::{ClonableMmap, MergeFn}; -use crate::distance::NDotProductPoint; -use crate::error::UserError; use crate::external_documents_ids::{DocumentOperation, DocumentOperationKind}; use crate::facet::FacetType; use crate::index::db_name::DOCUMENTS; -use crate::index::Hnsw; use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd}; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; -use crate::update::{available_documents_ids, AvailableDocumentsIds}; use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError}; pub(crate) enum TypedChunk { @@ -374,28 +368,28 @@ pub(crate) fn write_typed_chunk_into_index( return Ok((RoaringBitmap::new(), is_merged_database)); } - let mut unavailable_vector_ids = index.unavailable_vector_ids(&wtxn)?; /// FIXME: allow customizing distance - /// FIXME: allow customizing index - let writer = arroy::Writer::prepare(wtxn, index.vector_arroy, 0, expected_dimension)?; + let writers: std::result::Result, _> = (0..=u8::MAX) + .map(|k| { + /// FIXME: allow customizing index and then do index << 8 + k + arroy::Writer::prepare(wtxn, index.vector_arroy, k.into(), expected_dimension) + }) + .collect(); + let writers = writers?; // remove vectors for docids we want them removed let mut cursor = remove_vectors.into_cursor()?; while let Some((key, _)) = cursor.move_on_next()? { let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); - let Some(to_remove_vector_ids) = index.docid_vector_ids.get(&wtxn, &docid)? else { - continue; - }; - unavailable_vector_ids -= to_remove_vector_ids; - - for item in to_remove_vector_ids { - writer.del_item(wtxn, item)?; + for writer in &writers { + // Uses invariant: vectors are packed in the first writers. + if !writer.del_item(wtxn, docid)? { + break; + } } } - let mut available_vector_ids = - AvailableDocumentsIds::from_documents_ids(&unavailable_vector_ids); // add generated embeddings if let Some(embeddings) = embeddings { let mut cursor = embeddings.into_cursor()?; @@ -408,19 +402,10 @@ pub(crate) fn write_typed_chunk_into_index( // code error if we somehow got the wrong dimension .unwrap(); - let mut new_vector_ids = RoaringBitmap::new(); - for embedding in embeddings.iter() { - /// FIXME: error when you get over 9000 - let next_vector_id = available_vector_ids.next().unwrap(); - unavailable_vector_ids.insert(next_vector_id); - - new_vector_ids.insert(next_vector_id); - - index.vector_id_docid.put(wtxn, &next_vector_id, &docid)?; - - writer.add_item(wtxn, next_vector_id, embedding)?; + /// FIXME: detect overflow + for (embedding, writer) in embeddings.iter().zip(&writers) { + writer.add_item(wtxn, docid, embedding)?; } - index.docid_vector_ids.put(wtxn, &docid, &new_vector_ids)?; } } @@ -433,44 +418,52 @@ pub(crate) fn write_typed_chunk_into_index( let vector_deladd_obkv = KvReaderDelAdd::new(value); if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { - let vector = pod_collect_to_vec(value); - let Some(mut docid_vector_ids) = index.docid_vector_ids.get(&wtxn, &docid)? - else { - error!("Unable to delete the vector: {:?}", vector); - continue; - }; - for item in docid_vector_ids { - /// FIXME: comparing the vectors by equality is inefficient, and dangerous by perfect equality - let candidate = writer.item_vector(&wtxn, item)?.expect("Inconsistent dbs"); - if candidate == vector { - writer.del_item(wtxn, item)?; - unavailable_vector_ids.remove(item); - index.vector_id_docid.delete(wtxn, &item)?; - docid_vector_ids.remove(item); + let vector: Vec = pod_collect_to_vec(value); + + let mut deleted_index = None; + for (index, writer) in writers.iter().enumerate() { + let Some(candidate) = writer.item_vector(&wtxn, docid)? else { + // uses invariant: vectors are packed in the first writers. break; + }; + if candidate == vector { + writer.del_item(wtxn, docid)?; + deleted_index = Some(index); + } + } + + // 🥲 enforce invariant: vectors are packed in the first writers. + if let Some(deleted_index) = deleted_index { + let mut last_index_with_a_vector = None; + for (index, writer) in writers.iter().enumerate().skip(deleted_index) { + let Some(candidate) = writer.item_vector(&wtxn, docid)? else { + break; + }; + last_index_with_a_vector = Some((index, candidate)); + } + if let Some((last_index, vector)) = last_index_with_a_vector { + // unwrap: computed the index from the list of writers + let writer = writers.get(last_index).unwrap(); + writer.del_item(wtxn, docid)?; + writers.get(deleted_index).unwrap().add_item(wtxn, docid, &vector)?; } } - index.docid_vector_ids.put(wtxn, &docid, &docid_vector_ids)?; } - let mut available_vector_ids = - AvailableDocumentsIds::from_documents_ids(&unavailable_vector_ids); if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { let vector = pod_collect_to_vec(value); - let next_vector_id = available_vector_ids.next().unwrap(); - writer.add_item(wtxn, next_vector_id, &vector)?; - unavailable_vector_ids.insert(next_vector_id); - index.vector_id_docid.put(wtxn, &next_vector_id, &docid)?; - let mut docid_vector_ids = - index.docid_vector_ids.get(&wtxn, &docid)?.unwrap_or_default(); - docid_vector_ids.insert(next_vector_id); - index.docid_vector_ids.put(wtxn, &docid, &docid_vector_ids)?; + /// FIXME: detect overflow + for writer in &writers { + if !writer.contains_item(wtxn, docid)? { + writer.add_item(wtxn, docid, &vector)?; + break; + } + } } } - log::debug!("There are {} entries in the arroy so far", unavailable_vector_ids.len()); - index.put_unavailable_vector_ids(wtxn, unavailable_vector_ids)?; + log::debug!("There are 🤷‍♀️ entries in the arroy so far"); } TypedChunk::ScriptLanguageDocids(sl_map) => { for (key, (deletion, addition)) in sl_map { From fb539f61fe16831a008d3b948bc763ea89fc4eb6 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 7 Dec 2023 17:35:45 +0100 Subject: [PATCH 05/28] WIP --- .../src/update/index_documents/extract/mod.rs | 28 +++++++++++++------ .../src/update/index_documents/typed_chunk.rs | 7 ----- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 06bc8b609..69530a507 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -295,7 +295,7 @@ fn send_original_documents_data( let (embedder, prompt) = embedders.get("default").cloned().unzip(); let result = extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref()); - let _ = match result { + match result { Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { /// FIXME: support multiple embedders let results = embedder.and_then(|embedder| { @@ -309,15 +309,25 @@ fn send_original_documents_data( }); let (embeddings, expected_dimension) = results.unzip(); let expected_dimension = expected_dimension.flatten(); - lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { - remove_vectors, - embeddings, - /// FIXME: compute an expected dimension from the manual vectors if any - expected_dimension: expected_dimension.unwrap(), - manual_vectors, - })) + if !(remove_vectors.is_empty() + && manual_vectors.is_empty() + && embeddings.as_ref().map_or(true, |e| e.is_empty())) + { + /// FIXME FIXME FIXME + if expected_dimension.is_some() { + let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { + remove_vectors, + embeddings, + /// FIXME: compute an expected dimension from the manual vectors if any + expected_dimension: expected_dimension.unwrap(), + manual_vectors, + })); + } + } + } + Err(error) => { + let _ = lmdb_writer_sx_cloned.send(Err(error)); } - Err(error) => lmdb_writer_sx_cloned.send(Err(error)), }; }); diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 82397ed3d..84b17dca9 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -361,13 +361,6 @@ pub(crate) fn write_typed_chunk_into_index( embeddings, expected_dimension, } => { - if remove_vectors.is_empty() - && manual_vectors.is_empty() - && embeddings.as_ref().map_or(true, |e| e.is_empty()) - { - return Ok((RoaringBitmap::new(), is_merged_database)); - } - /// FIXME: allow customizing distance let writers: std::result::Result, _> = (0..=u8::MAX) .map(|k| { From 687d92f217a450ad30685ccc351b00796f624bd3 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 7 Dec 2023 23:05:04 +0100 Subject: [PATCH 06/28] prompt bifluor+ --- .../update/index_documents/extract/extract_vector_points.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 8399c220b..d8d6c933c 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -172,10 +172,11 @@ pub fn extract_vector_points( prompt.render(obkv, DelAdd::Addition, &field_id_map)?; if old_prompt != new_prompt { log::trace!( - "Changing prompt from\n{old_prompt}\n===\nto\n{new_prompt}" + "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" ); VectorStateDelta::NowGenerated(new_prompt) } else { + log::trace!("⏭️ Prompt unmodified, skipping"); VectorStateDelta::NoChange } } @@ -204,10 +205,11 @@ pub fn extract_vector_points( let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?; if old_prompt != new_prompt { log::trace!( - "Changing prompt from\n{old_prompt}\n===\nto\n{new_prompt}" + "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" ); VectorStateDelta::NowGenerated(new_prompt) } else { + log::trace!("⏭️ Prompt unmodified, skipping"); VectorStateDelta::NoChange } } From e56f1600321ae6a0f926037798943cbb90e89f3d Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 7 Dec 2023 23:05:26 +0100 Subject: [PATCH 07/28] Actually pass embedders on reindex --- milli/src/update/settings.rs | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 5e3683f32..b8355be51 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1,5 +1,7 @@ use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::convert::TryInto; use std::result::Result as StdResult; +use std::sync::Arc; use charabia::{Normalize, Tokenizer, TokenizerBuilder}; use deserr::{DeserializeError, Deserr}; @@ -12,11 +14,12 @@ use super::IndexerConfig; use crate::criterion::Criterion; use crate::error::UserError; use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS}; +use crate::prompt::Prompt; use crate::proximity::ProximityPrecision; use crate::update::index_documents::IndexDocumentsMethod; use crate::update::{IndexDocuments, UpdateIndexingStep}; use crate::vector::settings::{EmbeddingSettings, PromptSettings}; -use crate::vector::EmbeddingConfig; +use crate::vector::{Embedder, EmbeddingConfig}; use crate::{FieldsIdsMap, Index, OrderBy, Result}; #[derive(Debug, Clone, PartialEq, Eq, Copy)] @@ -396,6 +399,9 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { fields_ids_map, )?; + let embedder_configs = self.index.embedding_configs(self.wtxn)?; + let embedders = self.embedders(embedder_configs)?; + // We index the generated `TransformOutput` which must contain // all the documents with fields in the newly defined searchable order. let indexing_builder = IndexDocuments::new( @@ -406,11 +412,34 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { &progress_callback, &should_abort, )?; + + let indexing_builder = indexing_builder.with_embedders(embedders); indexing_builder.execute_raw(output)?; Ok(()) } + fn embedders( + &self, + embedding_configs: Vec<(String, EmbeddingConfig)>, + ) -> Result, Arc)>> { + let res: Result<_> = embedding_configs + .into_iter() + .map(|(name, EmbeddingConfig { embedder_options, prompt })| { + let prompt = Arc::new(prompt.try_into().map_err(crate::Error::from)?); + + let embedder = Arc::new( + Embedder::new(embedder_options.clone()) + .map_err(crate::vector::Error::from) + .map_err(crate::UserError::from) + .map_err(crate::Error::from)?, + ); + Ok((name, (embedder, prompt))) + }) + .collect(); + res + } + fn update_displayed(&mut self) -> Result { match self.displayed_fields { Setting::Set(ref fields) => { From 65e49b7092475d11afc97152395190cdd3e954e9 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Dec 2023 10:05:06 +0100 Subject: [PATCH 08/28] Remove stuff, add distribution shift (WIP) --- Cargo.lock | 219 ++++-------------- meilisearch/src/search.rs | 45 +--- milli/Cargo.toml | 16 +- milli/src/distance.rs | 41 ---- milli/src/index.rs | 4 - milli/src/lib.rs | 2 - milli/src/search/new/mod.rs | 13 +- milli/src/search/new/vector_sort.rs | 16 +- .../src/update/index_documents/typed_chunk.rs | 4 +- milli/src/vector/mod.rs | 44 ++++ 10 files changed, 126 insertions(+), 278 deletions(-) delete mode 100644 milli/src/distance.rs diff --git a/Cargo.lock b/Cargo.lock index fba78b3b6..3c2f38840 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -56,7 +56,7 @@ dependencies = [ "flate2", "futures-core", "h2", - "http", + "http 0.2.9", "httparse", "httpdate", "itoa", @@ -90,7 +90,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66ff4d247d2b160861fa2866457e85706833527840e4133f8f49aa423a38799" dependencies = [ "bytestring", - "http", + "http 0.2.9", "regex", "serde", "tracing", @@ -189,7 +189,7 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "http", + "http 0.2.9", "itoa", "language-tags", "log", @@ -1407,7 +1407,7 @@ dependencies = [ "anyhow", "big_s", "flate2", - "http", + "http 0.2.9", "log", "maplit", "meili-snap", @@ -1702,21 +1702,6 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" -[[package]] -name = "foreign-types" -version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" -dependencies = [ - "foreign-types-shared", -] - -[[package]] -name = "foreign-types-shared" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" - [[package]] name = "form_urlencoded" version = "1.2.0" @@ -2047,7 +2032,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 1.9.3", "slab", "tokio", @@ -2171,13 +2156,12 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hf-hub" version = "0.3.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2b780635574b3d92f036890d8373433d6f9fc7abb320ee42a5c25897fc8ed732" +source = "git+https://github.com/dureuill/hf-hub.git?branch=rust_tls#88d4f11cb9fa079f2912bacb96f5080b16825ce8" dependencies = [ "dirs", + "http 1.0.0", "indicatif", "log", - "native-tls", "rand", "serde", "serde_json", @@ -2205,6 +2189,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -2212,7 +2207,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] @@ -2245,7 +2240,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "httparse", "httpdate", @@ -2265,7 +2260,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d78e1e73ec14cf7375674f74d7dde185c8206fd9dea6fb6295e8a98098aaa97" dependencies = [ "futures-util", - "http", + "http 0.2.9", "hyper", "rustls 0.21.6", "tokio", @@ -2868,21 +2863,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "instant-distance" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c619cdaa30bb84088963968bee12a45ea5fbbf355f2c021bcd15589f5ca494a" -dependencies = [ - "num_cpus", - "ordered-float 3.7.0", - "parking_lot", - "rand", - "rayon", - "serde", - "serde-big-array", -] - [[package]] name = "io-lifetimes" version = "1.0.11" @@ -3531,7 +3511,7 @@ dependencies = [ "futures", "futures-util", "hex", - "http", + "http 0.2.9", "index-scheduler", "indexmap 2.0.0", "insta", @@ -3718,7 +3698,6 @@ dependencies = [ "hf-hub", "indexmap 2.0.0", "insta", - "instant-distance", "itertools 0.11.0", "json-depth-checker", "levenshtein_automata", @@ -3730,7 +3709,6 @@ dependencies = [ "meili-snap", "memmap2 0.7.1", "mimalloc", - "nolife", "obkv", "once_cell", "ordered-float 3.7.0", @@ -3829,35 +3807,11 @@ dependencies = [ "syn 2.0.28", ] -[[package]] -name = "native-tls" -version = "0.2.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07226173c32f2926027b63cce4bcd8076c3552846cbe7925f3aaffeac0a3b92e" -dependencies = [ - "lazy_static", - "libc", - "log", - "openssl", - "openssl-probe", - "openssl-sys", - "schannel", - "security-framework", - "security-framework-sys", - "tempfile", -] - [[package]] name = "nelson" version = "0.1.0" source = "git+https://github.com/meilisearch/nelson.git?rev=675f13885548fb415ead8fbb447e9e6d9314000a#675f13885548fb415ead8fbb447e9e6d9314000a" -[[package]] -name = "nolife" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc52aaf087e8a52e7a2692f83f2dac6ac7ff9d0136bf9c6ac496635cfe3e50dc" - [[package]] name = "nom" version = "7.1.3" @@ -3994,50 +3948,6 @@ version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" -[[package]] -name = "openssl" -version = "0.10.59" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a257ad03cd8fb16ad4172fedf8094451e1af1c4b70097636ef2eac9a5f0cc33" -dependencies = [ - "bitflags 2.4.1", - "cfg-if", - "foreign-types", - "libc", - "once_cell", - "openssl-macros", - "openssl-sys", -] - -[[package]] -name = "openssl-macros" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.28", -] - -[[package]] -name = "openssl-probe" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" - -[[package]] -name = "openssl-sys" -version = "0.9.95" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40a4130519a360279579c2053038317e40eff64d13fd3f004f9e1b72b8a6aaf9" -dependencies = [ - "cc", - "libc", - "pkg-config", - "vcpkg", -] - [[package]] name = "option-ext" version = "0.2.0" @@ -4655,7 +4565,7 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-rustls", @@ -4802,7 +4712,7 @@ checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb" dependencies = [ "log", "ring", - "rustls-webpki 0.101.3", + "rustls-webpki", "sct", ] @@ -4815,16 +4725,6 @@ dependencies = [ "base64 0.21.5", ] -[[package]] -name = "rustls-webpki" -version = "0.100.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab" -dependencies = [ - "ring", - "untrusted", -] - [[package]] name = "rustls-webpki" version = "0.101.3" @@ -4866,15 +4766,6 @@ dependencies = [ "winapi-util", ] -[[package]] -name = "schannel" -version = "0.1.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" -dependencies = [ - "windows-sys 0.48.0", -] - [[package]] name = "scopeguard" version = "1.2.0" @@ -4891,29 +4782,6 @@ dependencies = [ "untrusted", ] -[[package]] -name = "security-framework" -version = "2.9.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" -dependencies = [ - "bitflags 1.3.2", - "core-foundation", - "core-foundation-sys", - "libc", - "security-framework-sys", -] - -[[package]] -name = "security-framework-sys" -version = "2.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" -dependencies = [ - "core-foundation-sys", - "libc", -] - [[package]] name = "segment" version = "0.2.2" @@ -4949,15 +4817,6 @@ dependencies = [ "serde_derive", ] -[[package]] -name = "serde-big-array" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f" -dependencies = [ - "serde", -] - [[package]] name = "serde-cs" version = "0.2.4" @@ -5151,6 +5010,17 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "spin" version = "0.5.2" @@ -5713,21 +5583,21 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "ureq" -version = "2.7.1" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" +checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97" dependencies = [ "base64 0.21.5", "flate2", "log", - "native-tls", "once_cell", "rustls 0.21.6", - "rustls-webpki 0.100.2", + "rustls-webpki", "serde", "serde_json", + "socks", "url", - "webpki-roots 0.23.1", + "webpki-roots 0.25.3", ] [[package]] @@ -5958,15 +5828,6 @@ dependencies = [ "webpki", ] -[[package]] -name = "webpki-roots" -version = "0.23.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338" -dependencies = [ - "rustls-webpki 0.100.2", -] - [[package]] name = "webpki-roots" version = "0.25.3" diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 235b745a9..9136157f9 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -14,18 +14,14 @@ use meilisearch_types::error::deserr_codes::*; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; -use meilisearch_types::milli::{ - dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues, - VectorQuery, -}; +use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; use milli::tokenizer::TokenizerBuilder; use milli::{ AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, - SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET, + SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, }; -use ordered_float::OrderedFloat; use regex::Regex; use serde::Serialize; use serde_json::{json, Value}; @@ -550,13 +546,8 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } - let semantic_score = /*match query.vector.as_ref() { - Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? { - Some(vectors) => compute_semantic_score(vector, vectors)?, - None => None, - }, - None => None, - };*/ None; + /// FIXME: remove this or set to value from the score details + let semantic_score = None; let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); @@ -689,18 +680,6 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) { } } -fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result> { - let vectors = serde_json::from_value(vectors) - .map(VectorOrArrayOfVectors::into_array_of_vectors) - .map_err(InternalError::SerdeJson)?; - Ok(vectors - .into_iter() - .flatten() - .map(|v| OrderedFloat(dot_product_similarity(query, &v))) - .max() - .map(OrderedFloat::into_inner)) -} - fn compute_formatted_options( attr_to_highlight: &HashSet, attr_to_crop: &[String], @@ -828,22 +807,6 @@ fn make_document( Ok(document) } -/// Extract the JSON value under the field name specified -/// but doesn't support nested objects. -fn extract_field( - field_name: &str, - field_ids_map: &FieldsIdsMap, - obkv: obkv::KvReaderU16, -) -> Result, MeilisearchHttpError> { - match field_ids_map.id(field_name) { - Some(fid) => match obkv.get(fid) { - Some(value) => Ok(serde_json::from_slice(value).map(Some)?), - None => Ok(None), - }, - None => Ok(None), - } -} - fn format_fields<'a>( document: &Document, field_ids_map: &FieldsIdsMap, diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 0aee03b2f..b977d64f1 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -36,7 +36,6 @@ heed = { version = "0.20.0-alpha.9", default-features = false, features = [ "read-txn-no-tls", ] } indexmap = { version = "2.0.0", features = ["serde"] } -instant-distance = { version = "0.6.1", features = ["with-serde"] } json-depth-checker = { path = "../json-depth-checker" } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } memmap2 = "0.7.1" @@ -79,10 +78,11 @@ candle-core = { git = "https://github.com/huggingface/candle.git", version = "0. candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" } -hf-hub = "0.3.2" +hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [ + "online", +] } tokio = { version = "1.34.0", features = ["rt"] } futures = "0.3.29" -nolife = { version = "0.3.1" } reqwest = { version = "0.11.16", features = [ "rustls-tls", "json", @@ -102,7 +102,15 @@ meili-snap = { path = "../meili-snap" } rand = { version = "0.8.5", features = ["small_rng"] } [features] -all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"] +all-tokenizations = [ + "charabia/chinese", + "charabia/hebrew", + "charabia/japanese", + "charabia/thai", + "charabia/korean", + "charabia/greek", + "charabia/khmer", +] # Use POSIX semaphores instead of SysV semaphores in LMDB # For more information on this feature, see heed's Cargo.toml diff --git a/milli/src/distance.rs b/milli/src/distance.rs deleted file mode 100644 index e9e17e647..000000000 --- a/milli/src/distance.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::ops; - -use instant_distance::Point; -use serde::{Deserialize, Serialize}; - -use crate::normalize_vector; - -#[derive(Debug, Default, Clone, Serialize, Deserialize)] -pub struct NDotProductPoint(Vec); - -impl NDotProductPoint { - pub fn new(point: Vec) -> Self { - NDotProductPoint(normalize_vector(point)) - } - - pub fn into_inner(self) -> Vec { - self.0 - } -} - -impl ops::Deref for NDotProductPoint { - type Target = [f32]; - - fn deref(&self) -> &Self::Target { - self.0.as_slice() - } -} - -impl Point for NDotProductPoint { - fn distance(&self, other: &Self) -> f32 { - let dist = 1.0 - dot_product_similarity(&self.0, &other.0); - debug_assert!(!dist.is_nan()); - dist - } -} - -/// Returns the dot product similarity score that will between 0.0 and 1.0 -/// if both vectors are normalized. The higher the more similar the vectors are. -pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 { - a.iter().zip(b).map(|(a, b)| a * b).sum() -} diff --git a/milli/src/index.rs b/milli/src/index.rs index c5e190d38..05babf410 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -10,7 +10,6 @@ use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; -use crate::distance::NDotProductPoint; use crate::documents::PrimaryKey; use crate::error::{InternalError, UserError}; use crate::fields_ids_map::FieldsIdsMap; @@ -30,9 +29,6 @@ use crate::{ BEU32, BEU64, }; -/// The HNSW data-structure that we serialize, fill and search in. -pub type Hnsw = instant_distance::Hnsw; - pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; diff --git a/milli/src/lib.rs b/milli/src/lib.rs index b865747e0..ce37fe375 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -10,7 +10,6 @@ pub mod documents; mod asc_desc; mod criterion; -pub mod distance; mod error; mod external_documents_ids; pub mod facet; @@ -33,7 +32,6 @@ use std::convert::{TryFrom, TryInto}; use std::hash::BuildHasherDefault; use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; -pub use distance::dot_product_similarity; pub use filter_parser::{Condition, FilterCondition, Span, Token}; use fxhash::{FxHasher32, FxHasher64}; pub use grenad::CompressionType; diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index ad5c59f99..bba6cf119 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -50,6 +50,7 @@ use self::vector_sort::VectorSort; use crate::error::FieldIdMapMissingEntry; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; +use crate::vector::DistributionShift; use crate::{ AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, }; @@ -264,6 +265,7 @@ fn get_ranking_rules_for_vector<'ctx>( geo_strategy: geo_sort::Strategy, limit_plus_offset: usize, target: &[f32], + distribution_shift: Option, ) -> Result>> { // query graph search @@ -289,6 +291,7 @@ fn get_ranking_rules_for_vector<'ctx>( target.to_vec(), vector_candidates, limit_plus_offset, + distribution_shift, )?; ranking_rules.push(Box::new(vector_sort)); vector = true; @@ -515,8 +518,14 @@ pub fn execute_vector_search( /// FIXME: input universe = universe & documents_with_vectors // for now if we're computing embeddings for ALL documents, we can assume that this is just universe - let ranking_rules = - get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, from + length, vector)?; + let ranking_rules = get_ranking_rules_for_vector( + ctx, + sort_criteria, + geo_strategy, + from + length, + vector, + None, + )?; let mut placeholder_search_logger = logger::DefaultSearchLogger; let placeholder_search_logger: &mut dyn SearchLogger = diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs index 9bf13c631..2d7cdbe39 100644 --- a/milli/src/search/new/vector_sort.rs +++ b/milli/src/search/new/vector_sort.rs @@ -5,6 +5,7 @@ use roaring::RoaringBitmap; use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; use crate::score_details::{self, ScoreDetails}; +use crate::vector::DistributionShift; use crate::{DocumentId, Result, SearchContext, SearchLogger}; pub struct VectorSort { @@ -13,6 +14,7 @@ pub struct VectorSort { vector_candidates: RoaringBitmap, cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec)>, limit: usize, + distribution_shift: Option, } impl VectorSort { @@ -21,6 +23,7 @@ impl VectorSort { target: Vec, vector_candidates: RoaringBitmap, limit: usize, + distribution_shift: Option, ) -> Result { Ok(Self { query: None, @@ -28,6 +31,7 @@ impl VectorSort { vector_candidates, cached_sorted_docids: Default::default(), limit, + distribution_shift, }) } @@ -52,7 +56,7 @@ impl VectorSort { for reader in readers.iter() { let nns_by_vector = reader.nns_by_vector( ctx.txn, - &target, + target, self.limit, None, Some(&self.vector_candidates), @@ -66,6 +70,7 @@ impl VectorSort { } results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance)); self.cached_sorted_docids = results.into_iter(); + Ok(()) } } @@ -111,14 +116,19 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { })); } - while let Some((docid, distance, vector)) = self.cached_sorted_docids.next() { + for (docid, distance, vector) in self.cached_sorted_docids.by_ref() { if self.vector_candidates.contains(docid) { + let score = 1.0 - distance; + let score = self + .distribution_shift + .map(|distribution| distribution.shift(score)) + .unwrap_or(score); return Ok(Some(RankingRuleOutput { query, candidates: RoaringBitmap::from_iter([docid]), score: ScoreDetails::Vector(score_details::Vector { target_vector: self.target.clone(), - value_similarity: Some((vector, 1.0 - distance)), + value_similarity: Some((vector, score)), }), })); } diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 84b17dca9..da99ed685 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -415,7 +415,7 @@ pub(crate) fn write_typed_chunk_into_index( let mut deleted_index = None; for (index, writer) in writers.iter().enumerate() { - let Some(candidate) = writer.item_vector(&wtxn, docid)? else { + let Some(candidate) = writer.item_vector(wtxn, docid)? else { // uses invariant: vectors are packed in the first writers. break; }; @@ -429,7 +429,7 @@ pub(crate) fn write_typed_chunk_into_index( if let Some(deleted_index) = deleted_index { let mut last_index_with_a_vector = None; for (index, writer) in writers.iter().enumerate().skip(deleted_index) { - let Some(candidate) = writer.item_vector(&wtxn, docid)? else { + let Some(candidate) = writer.item_vector(wtxn, docid)? else { break; }; last_index_with_a_vector = Some((index, candidate)); diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index faaa7bf2a..91640b8fb 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -140,3 +140,47 @@ impl Embedder { } } } + +#[derive(Debug, Clone, Copy)] +pub struct DistributionShift { + pub current_mean: f32, + pub current_sigma: f32, +} + +impl DistributionShift { + /// `None` if sigma <= 0. + pub fn new(mean: f32, sigma: f32) -> Option { + if sigma <= 0.0 { + None + } else { + Some(Self { current_mean: mean, current_sigma: sigma }) + } + } + + pub fn shift(&self, score: f32) -> f32 { + // + // We're somewhat abusively mapping the distribution of distances to a gaussian. + // The parameters we're given is the mean and sigma of the native result distribution. + // We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4. + + let target_mean = 0.5; + let target_sigma = 0.4; + + // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive. + let factor = target_sigma / self.current_sigma; + // a*mu1 + b = mu2 => b = mu2 - a*mu1 + let offset = target_mean - (factor * self.current_mean); + + let mut score = factor * score + offset; + + // clamp the final score in the ]0, 1] interval. + if score <= 0.0 { + score = f32::EPSILON; + } + if score > 1.0 { + score = 1.0; + } + + score + } +} From 11e2a2c1aabbb8897f9d49f48f31071a5c7378bb Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Dec 2023 12:08:09 +0100 Subject: [PATCH 09/28] Fix geosort bug --- milli/src/search/new/geo_sort.rs | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/milli/src/search/new/geo_sort.rs b/milli/src/search/new/geo_sort.rs index b2e3a2f3d..5f5ceb379 100644 --- a/milli/src/search/new/geo_sort.rs +++ b/milli/src/search/new/geo_sort.rs @@ -107,12 +107,16 @@ impl GeoSort { /// Refill the internal buffer of cached docids based on the strategy. /// Drop the rtree if we don't need it anymore. - fn fill_buffer(&mut self, ctx: &mut SearchContext) -> Result<()> { + fn fill_buffer( + &mut self, + ctx: &mut SearchContext, + geo_candidates: &RoaringBitmap, + ) -> Result<()> { debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng"); debug_assert!(self.cached_sorted_docids.is_empty()); // lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree` - let rtree = if self.strategy.use_rtree(self.geo_candidates.len() as usize) { + let rtree = if self.strategy.use_rtree(geo_candidates.len() as usize) { if let Some(rtree) = self.rtree.as_ref() { // get rtree from cache Some(rtree) @@ -131,7 +135,7 @@ impl GeoSort { if self.ascending { let point = lat_lng_to_xyz(&self.point); for point in rtree.nearest_neighbor_iter(&point) { - if self.geo_candidates.contains(point.data.0) { + if geo_candidates.contains(point.data.0) { self.cached_sorted_docids.push_back(point.data); if self.cached_sorted_docids.len() >= cache_size { break; @@ -143,7 +147,7 @@ impl GeoSort { // and we insert the points in reverse order they get reversed when emptying the cache later on let point = lat_lng_to_xyz(&opposite_of(self.point)); for point in rtree.nearest_neighbor_iter(&point) { - if self.geo_candidates.contains(point.data.0) { + if geo_candidates.contains(point.data.0) { self.cached_sorted_docids.push_front(point.data); if self.cached_sorted_docids.len() >= cache_size { break; @@ -155,8 +159,7 @@ impl GeoSort { // the iterative version let [lat, lng] = self.field_ids.unwrap(); - let mut documents = self - .geo_candidates + let mut documents = geo_candidates .iter() .map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) }) .collect::>>()?; @@ -216,9 +219,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { assert!(self.query.is_none()); self.query = Some(query.clone()); - self.geo_candidates &= universe; - if self.geo_candidates.is_empty() { + let geo_candidates = &self.geo_candidates & universe; + + if geo_candidates.is_empty() { return Ok(()); } @@ -226,7 +230,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { let lat = fid_map.id("_geo.lat").expect("geo candidates but no fid for lat"); let lng = fid_map.id("_geo.lng").expect("geo candidates but no fid for lng"); self.field_ids = Some([lat, lng]); - self.fill_buffer(ctx)?; + self.fill_buffer(ctx, &geo_candidates)?; Ok(()) } @@ -238,9 +242,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { universe: &RoaringBitmap, ) -> Result>> { let query = self.query.as_ref().unwrap().clone(); - self.geo_candidates &= universe; - if self.geo_candidates.is_empty() { + let geo_candidates = &self.geo_candidates & universe; + + if geo_candidates.is_empty() { return Ok(Some(RankingRuleOutput { query, candidates: universe.clone(), @@ -261,7 +266,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { } }; while let Some((id, point)) = next(&mut self.cached_sorted_docids) { - if self.geo_candidates.contains(id) { + if geo_candidates.contains(id) { return Ok(Some(RankingRuleOutput { query, candidates: RoaringBitmap::from_iter([id]), @@ -276,7 +281,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { // if we got out of this loop it means we've exhausted our cache. // we need to refill it and run the function again. - self.fill_buffer(ctx)?; + self.fill_buffer(ctx, &geo_candidates)?; self.next_bucket(ctx, logger, universe) } From d4715e0c4d1a2b7517eb03ad10c3c586ce86a12d Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Dec 2023 12:08:23 +0100 Subject: [PATCH 10/28] Fix same vector sort bug --- milli/src/search/new/vector_sort.rs | 42 +++++++++++++++++++---------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs index 2d7cdbe39..38fcfde48 100644 --- a/milli/src/search/new/vector_sort.rs +++ b/milli/src/search/new/vector_sort.rs @@ -35,7 +35,11 @@ impl VectorSort { }) } - fn fill_buffer(&mut self, ctx: &mut SearchContext<'_>) -> Result<()> { + fn fill_buffer( + &mut self, + ctx: &mut SearchContext<'_>, + vector_candidates: &RoaringBitmap, + ) -> Result<()> { let readers: std::result::Result, _> = (0..=u8::MAX) .map_while(|k| { arroy::Reader::open(ctx.txn, k.into(), ctx.index.vector_arroy) @@ -54,13 +58,8 @@ impl VectorSort { let mut results = Vec::new(); for reader in readers.iter() { - let nns_by_vector = reader.nns_by_vector( - ctx.txn, - target, - self.limit, - None, - Some(&self.vector_candidates), - )?; + let nns_by_vector = + reader.nns_by_vector(ctx.txn, target, self.limit, None, Some(vector_candidates))?; let vectors: std::result::Result, _> = nns_by_vector .iter() .map(|(docid, _)| reader.item_vector(ctx.txn, *docid).transpose().unwrap()) @@ -90,8 +89,8 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { assert!(self.query.is_none()); self.query = Some(query.clone()); - self.vector_candidates &= universe; - self.fill_buffer(ctx)?; + let vector_candidates = &self.vector_candidates & universe; + self.fill_buffer(ctx, &vector_candidates)?; Ok(()) } @@ -103,9 +102,9 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { universe: &RoaringBitmap, ) -> Result>> { let query = self.query.as_ref().unwrap().clone(); - self.vector_candidates &= universe; + let vector_candidates = &self.vector_candidates & universe; - if self.vector_candidates.is_empty() { + if vector_candidates.is_empty() { return Ok(Some(RankingRuleOutput { query, candidates: universe.clone(), @@ -117,7 +116,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { } for (docid, distance, vector) in self.cached_sorted_docids.by_ref() { - if self.vector_candidates.contains(docid) { + if vector_candidates.contains(docid) { let score = 1.0 - distance; let score = self .distribution_shift @@ -136,7 +135,22 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { // if we got out of this loop it means we've exhausted our cache. // we need to refill it and run the function again. - self.fill_buffer(ctx)?; + self.fill_buffer(ctx, &vector_candidates)?; + + // we tried filling the buffer, but it remained empty 😢 + // it means we don't actually have any document remaining in the universe with a vector. + // => exit + if self.cached_sorted_docids.len() == 0 { + return Ok(Some(RankingRuleOutput { + query, + candidates: universe.clone(), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: self.target.clone(), + value_similarity: None, + }), + })); + } + self.next_bucket(ctx, _logger, universe) } From abbe1310848b8746725639871a9009e709f402e6 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Dec 2023 12:08:36 +0100 Subject: [PATCH 11/28] Cosmetic change --- index-scheduler/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/index-scheduler/src/lib.rs b/index-scheduler/src/lib.rs index fbe38a7fb..d01b0a17d 100644 --- a/index-scheduler/src/lib.rs +++ b/index-scheduler/src/lib.rs @@ -342,7 +342,7 @@ pub struct IndexScheduler { /// so that a handle to the index is available from other threads (search) in an optimized manner. currently_updating_index: Arc>>, - embedders: Arc>>>, + embedders: Arc>>>, // ================= test // The next entry is dedicated to the tests. From 922a640188bd4b4930bf18b6b6ce9b8a73927d28 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Dec 2023 21:19:48 +0100 Subject: [PATCH 12/28] WIP multi embedders fixed template bugs --- index-scheduler/src/lib.rs | 1 - meilisearch-types/src/error.rs | 5 ++ .../src/analytics/segment_analytics.rs | 33 ++++++- .../src/routes/indexes/facet_search.rs | 10 ++- meilisearch/src/routes/indexes/search.rs | 51 ++++++++--- meilisearch/src/search.rs | 16 ++++ milli/src/error.rs | 19 +++- milli/src/prompt/mod.rs | 78 +++++++++++++++- milli/src/prompt/template_checker.rs | 45 +++++++--- milli/src/search/new/mod.rs | 2 +- .../extract/extract_vector_points.rs | 89 +++++++------------ .../src/update/index_documents/extract/mod.rs | 47 +++++----- milli/src/update/index_documents/mod.rs | 20 +++-- .../src/update/index_documents/typed_chunk.rs | 21 +++-- milli/src/update/settings.rs | 18 +++- milli/src/vector/error.rs | 15 ++++ milli/src/vector/hf.rs | 21 +++-- milli/src/vector/mod.rs | 24 ++++- milli/src/vector/openai.rs | 23 ++++- milli/src/vector/settings.rs | 58 ++++++++---- 20 files changed, 438 insertions(+), 158 deletions(-) diff --git a/index-scheduler/src/lib.rs b/index-scheduler/src/lib.rs index d01b0a17d..65d257ea0 100644 --- a/index-scheduler/src/lib.rs +++ b/index-scheduler/src/lib.rs @@ -1361,7 +1361,6 @@ impl IndexScheduler { let embedder = Arc::new( Embedder::new(embedder_options.clone()) .map_err(meilisearch_types::milli::vector::Error::from) - .map_err(meilisearch_types::milli::UserError::from) .map_err(meilisearch_types::milli::Error::from)?, ); { diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index b1cc7cf82..5df1ae106 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -222,6 +222,8 @@ InvalidVectorsType , InvalidRequest , BAD_REQUEST ; InvalidDocumentId , InvalidRequest , BAD_REQUEST ; InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ; InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ; +InvalidEmbedder , InvalidRequest , BAD_REQUEST ; +InvalidHybridQuery , InvalidRequest , BAD_REQUEST ; InvalidIndexLimit , InvalidRequest , BAD_REQUEST ; InvalidIndexOffset , InvalidRequest , BAD_REQUEST ; InvalidIndexPrimaryKey , InvalidRequest , BAD_REQUEST ; @@ -233,6 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; +InvalidSemanticRatio , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; @@ -340,6 +343,7 @@ impl ErrorCode for milli::Error { } UserError::MissingDocumentField(_) => Code::InvalidDocumentFields, UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, + UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders, UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, UserError::MultiplePrimaryKeyCandidatesFound { .. } => { @@ -363,6 +367,7 @@ impl ErrorCode for milli::Error { UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance } + UserError::InvalidEmbedder(_) => Code::InvalidEmbedder, UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError, } } diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index d5f08936d..67770d87c 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -36,7 +36,7 @@ use crate::routes::{create_all_stats, Stats}; use crate::search::{ FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, - DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, + DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEMANTIC_RATIO, }; use crate::Opt; @@ -586,6 +586,11 @@ pub struct SearchAggregator { // vector // The maximum number of floats in a vector request max_vector_size: usize, + // Whether the semantic ratio passed to a hybrid search equals the default ratio. + semantic_ratio: bool, + // Whether a non-default embedder was specified + embedder: bool, + hybrid: bool, // every time a search is done, we increment the counter linked to the used settings matching_strategy: HashMap, @@ -639,6 +644,7 @@ impl SearchAggregator { crop_marker, matching_strategy, attributes_to_search_on, + hybrid, } = query; let mut ret = Self::default(); @@ -712,6 +718,12 @@ impl SearchAggregator { ret.show_ranking_score = *show_ranking_score; ret.show_ranking_score_details = *show_ranking_score_details; + if let Some(hybrid) = hybrid { + ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO(); + ret.embedder = hybrid.embedder.is_some(); + ret.hybrid = true; + } + ret } @@ -765,6 +777,9 @@ impl SearchAggregator { facets_total_number_of_facets, show_ranking_score, show_ranking_score_details, + semantic_ratio, + embedder, + hybrid, } = other; if self.timestamp.is_none() { @@ -810,6 +825,9 @@ impl SearchAggregator { // vector self.max_vector_size = self.max_vector_size.max(max_vector_size); + self.semantic_ratio |= semantic_ratio; + self.hybrid |= hybrid; + self.embedder |= embedder; // pagination self.max_limit = self.max_limit.max(max_limit); @@ -878,6 +896,9 @@ impl SearchAggregator { facets_total_number_of_facets, show_ranking_score, show_ranking_score_details, + semantic_ratio, + embedder, + hybrid, } = self; if total_received == 0 { @@ -917,6 +938,11 @@ impl SearchAggregator { "vector": { "max_vector_size": max_vector_size, }, + "hybrid": { + "enabled": hybrid, + "semantic_ratio": semantic_ratio, + "embedder": embedder, + }, "pagination": { "max_limit": max_limit, "max_offset": max_offset, @@ -1012,6 +1038,7 @@ impl MultiSearchAggregator { crop_marker: _, matching_strategy: _, attributes_to_search_on: _, + hybrid: _, } = query; index_uid.as_str() @@ -1158,6 +1185,7 @@ impl FacetSearchAggregator { filter, matching_strategy, attributes_to_search_on, + hybrid, } = query; let mut ret = Self::default(); @@ -1171,7 +1199,8 @@ impl FacetSearchAggregator { || vector.is_some() || filter.is_some() || *matching_strategy != MatchingStrategy::default() - || attributes_to_search_on.is_some(); + || attributes_to_search_on.is_some() + || hybrid.is_some(); ret } diff --git a/meilisearch/src/routes/indexes/facet_search.rs b/meilisearch/src/routes/indexes/facet_search.rs index 72440711c..59c0e7353 100644 --- a/meilisearch/src/routes/indexes/facet_search.rs +++ b/meilisearch/src/routes/indexes/facet_search.rs @@ -14,9 +14,9 @@ use crate::analytics::{Analytics, FacetSearchAggregator}; use crate::extractors::authentication::policies::*; use crate::extractors::authentication::GuardedData; use crate::search::{ - add_search_rules, perform_facet_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, - DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, - DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, + add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery, + DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, + DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, }; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -37,6 +37,8 @@ pub struct FacetSearchQuery { pub q: Option, #[deserr(default, error = DeserrJsonError)] pub vector: Option>, + #[deserr(default, error = DeserrJsonError)] + pub hybrid: Option, #[deserr(default, error = DeserrJsonError)] pub filter: Option, #[deserr(default, error = DeserrJsonError, default)] @@ -96,6 +98,7 @@ impl From for SearchQuery { filter, matching_strategy, attributes_to_search_on, + hybrid, } = value; SearchQuery { @@ -120,6 +123,7 @@ impl From for SearchQuery { matching_strategy, vector: vector.map(VectorQuery::Vector), attributes_to_search_on, + hybrid, } } } diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index e63a95e60..ec4825661 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -8,7 +8,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; -use meilisearch_types::milli::VectorQuery; +use meilisearch_types::milli::{self, VectorQuery}; use meilisearch_types::serde_cs::vec::CS; use serde_json::Value; @@ -17,9 +17,9 @@ use crate::extractors::authentication::policies::*; use crate::extractors::authentication::GuardedData; use crate::extractors::sequential_extractor::SeqHandler; use crate::search::{ - add_search_rules, perform_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, - DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, - DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, + add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, + DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, + DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, }; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -75,6 +75,10 @@ pub struct SearchQueryGet { matching_strategy: MatchingStrategy, #[deserr(default, error = DeserrQueryParamError)] pub attributes_to_search_on: Option>, + #[deserr(default, error = DeserrQueryParamError)] + pub hybrid_embedder: Option, + #[deserr(default, error = DeserrQueryParamError)] + pub hybrid_semantic_ratio: Option, } impl From for SearchQuery { @@ -87,6 +91,18 @@ impl From for SearchQuery { None => None, }; + let hybrid = match (other.hybrid_embedder, other.hybrid_semantic_ratio) { + (None, None) => None, + (None, Some(semantic_ratio)) => Some(HybridQuery { semantic_ratio, embedder: None }), + (Some(embedder), None) => Some(HybridQuery { + semantic_ratio: DEFAULT_SEMANTIC_RATIO(), + embedder: Some(embedder), + }), + (Some(embedder), Some(semantic_ratio)) => { + Some(HybridQuery { semantic_ratio, embedder: Some(embedder) }) + } + }; + Self { q: other.q, vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector), @@ -109,6 +125,7 @@ impl From for SearchQuery { crop_marker: other.crop_marker, matching_strategy: other.matching_strategy, attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()), + hybrid, } } } @@ -159,6 +176,9 @@ pub async fn search_with_url_query( let index = index_scheduler.index(&index_uid)?; let features = index_scheduler.features(); + + embed(&mut query, index_scheduler.get_ref(), &index).await?; + let search_result = tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; if let Ok(ref search_result) = search_result { @@ -213,22 +233,31 @@ pub async fn search_with_post( pub async fn embed( query: &mut SearchQuery, index_scheduler: &IndexScheduler, - index: &meilisearch_types::milli::Index, + index: &milli::Index, ) -> Result<(), ResponseError> { if let Some(VectorQuery::String(prompt)) = query.vector.take() { let embedder_configs = index.embedding_configs(&index.read_txn()?)?; let embedder = index_scheduler.embedders(embedder_configs)?; - /// FIXME: add error if no embedder, remove unwrap, support multiple embedders + let embedder_name = if let Some(HybridQuery { + semantic_ratio: _, + embedder: Some(embedder), + }) = &query.hybrid + { + embedder + } else { + "default" + }; + let embeddings = embedder - .get("default") - .unwrap() + .get(embedder_name) + .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned())) + .map_err(milli::Error::from)? .0 .embed(vec![prompt]) .await - .map_err(meilisearch_types::milli::vector::Error::from) - .map_err(meilisearch_types::milli::UserError::from) - .map_err(meilisearch_types::milli::Error::from)? + .map_err(milli::vector::Error::from) + .map_err(milli::Error::from)? .pop() .expect("No vector returned from embedding"); diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 9136157f9..c1e667570 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -36,6 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10; pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string(); pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "".to_string(); pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "".to_string(); +pub const DEFAULT_SEMANTIC_RATIO: fn() -> f32 = || 0.5; #[derive(Debug, Clone, Default, PartialEq, Deserr)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] @@ -44,6 +45,8 @@ pub struct SearchQuery { pub q: Option, #[deserr(default, error = DeserrJsonError)] pub vector: Option, + #[deserr(default, error = DeserrJsonError)] + pub hybrid: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -84,6 +87,15 @@ pub struct SearchQuery { pub attributes_to_search_on: Option>, } +#[derive(Debug, Clone, Default, PartialEq, Deserr)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct HybridQuery { + #[deserr(default, error = DeserrJsonError, default = DEFAULT_SEMANTIC_RATIO())] + pub semantic_ratio: f32, + #[deserr(default, error = DeserrJsonError, default)] + pub embedder: Option, +} + impl SearchQuery { pub fn is_finite_pagination(&self) -> bool { self.page.or(self.hits_per_page).is_some() @@ -103,6 +115,8 @@ pub struct SearchQueryWithIndex { pub q: Option, #[deserr(default, error = DeserrJsonError)] pub vector: Option, + #[deserr(default, error = DeserrJsonError)] + pub hybrid: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -168,6 +182,7 @@ impl SearchQueryWithIndex { crop_marker, matching_strategy, attributes_to_search_on, + hybrid, } = self; ( index_uid, @@ -193,6 +208,7 @@ impl SearchQueryWithIndex { crop_marker, matching_strategy, attributes_to_search_on, + hybrid, // do not use ..Default::default() here, // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` }, diff --git a/milli/src/error.rs b/milli/src/error.rs index 3d07590b0..95a0aba6d 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -63,6 +63,8 @@ pub enum InternalError { InvalidMatchingWords, #[error(transparent)] ArroyError(#[from] arroy::Error), + #[error(transparent)] + VectorEmbeddingError(#[from] crate::vector::Error), } #[derive(Error, Debug)] @@ -188,8 +190,23 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco MissingDocumentField(#[from] crate::prompt::error::RenderPromptError), #[error(transparent)] InvalidPrompt(#[from] crate::prompt::error::NewPromptError), - #[error("Invalid prompt in for embeddings with name '{0}': {1}")] + #[error("Invalid prompt in for embeddings with name '{0}': {1}.")] InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), + #[error("Too many embedders in the configuration. Found {0}, but limited to 256.")] + TooManyEmbedders(usize), + #[error("Cannot find embedder with name {0}.")] + InvalidEmbedder(String), +} + +impl From for Error { + fn from(value: crate::vector::Error) -> Self { + match value.fault() { + FaultSource::User => Error::UserError(value.into()), + FaultSource::Runtime => Error::InternalError(value.into()), + FaultSource::Bug => Error::InternalError(value.into()), + FaultSource::Undecided => Error::InternalError(value.into()), + } + } } impl From for Error { diff --git a/milli/src/prompt/mod.rs b/milli/src/prompt/mod.rs index 351a51bb1..67ef8b4f6 100644 --- a/milli/src/prompt/mod.rs +++ b/milli/src/prompt/mod.rs @@ -110,7 +110,6 @@ impl Prompt { }; // render template with special object that's OK with `doc.*` and `fields.*` - /// FIXME: doesn't work for nested objects e.g. `doc.a.b` this.template .render(&template_checker::TemplateChecker) .map_err(NewPromptError::invalid_fields_in_template)?; @@ -142,3 +141,80 @@ pub enum PromptFallbackStrategy { #[default] Error, } + +#[cfg(test)] +mod test { + use super::Prompt; + use crate::error::FaultSource; + use crate::prompt::error::{NewPromptError, NewPromptErrorKind}; + + #[test] + fn default_template() { + // does not panic + Prompt::default(); + } + + #[test] + fn empty_template() { + Prompt::new("".into(), None, None).unwrap(); + } + + #[test] + fn template_ok() { + Prompt::new("{{doc.title}}: {{doc.overview}}".into(), None, None).unwrap(); + } + + #[test] + fn template_syntax() { + assert!(matches!( + Prompt::new("{{doc.title: {{doc.overview}}".into(), None, None), + Err(NewPromptError { + kind: NewPromptErrorKind::CannotParseTemplate(_), + fault: FaultSource::User + }) + )); + } + + #[test] + fn template_missing_doc() { + assert!(matches!( + Prompt::new("{{title}}: {{overview}}".into(), None, None), + Err(NewPromptError { + kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), + fault: FaultSource::User + }) + )); + } + + #[test] + fn template_nested_doc() { + Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into(), None, None).unwrap(); + } + + #[test] + fn template_fields() { + Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into(), None, None).unwrap(); + } + + #[test] + fn template_fields_ok() { + Prompt::new( + "{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into(), + None, + None, + ) + .unwrap(); + } + + #[test] + fn template_fields_invalid() { + assert!(matches!( + // intentionally garbled field + Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into(), None, None), + Err(NewPromptError { + kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), + fault: FaultSource::User + }) + )); + } +} diff --git a/milli/src/prompt/template_checker.rs b/milli/src/prompt/template_checker.rs index 641a9ed64..4cda4a70d 100644 --- a/milli/src/prompt/template_checker.rs +++ b/milli/src/prompt/template_checker.rs @@ -1,7 +1,7 @@ use liquid::model::{ ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, }; -use liquid::{ObjectView, ValueView}; +use liquid::{Object, ObjectView, ValueView}; #[derive(Debug)] pub struct TemplateChecker; @@ -31,11 +31,11 @@ impl ObjectView for DummyField { } fn values<'k>(&'k self) -> Box + 'k> { - Box::new(std::iter::empty()) + Box::new(vec![DUMMY_VALUE.as_view(), DUMMY_VALUE.as_view()].into_iter()) } fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { - Box::new(std::iter::empty()) + Box::new(self.keys().zip(self.values())) } fn contains_key(&self, index: &str) -> bool { @@ -69,7 +69,12 @@ impl ValueView for DummyField { } fn query_state(&self, state: State) -> bool { - DUMMY_VALUE.query_state(state) + match state { + State::Truthy => true, + State::DefaultValue => false, + State::Empty => false, + State::Blank => false, + } } fn to_kstr(&self) -> KStringCow<'_> { @@ -77,7 +82,10 @@ impl ValueView for DummyField { } fn to_value(&self) -> LiquidValue { - LiquidValue::Nil + let mut this = Object::new(); + this.insert("name".into(), LiquidValue::Nil); + this.insert("value".into(), LiquidValue::Nil); + LiquidValue::Object(this) } fn as_object(&self) -> Option<&dyn ObjectView> { @@ -103,7 +111,12 @@ impl ValueView for DummyFields { } fn query_state(&self, state: State) -> bool { - DUMMY_VALUE.query_state(state) + match state { + State::Truthy => true, + State::DefaultValue => false, + State::Empty => false, + State::Blank => false, + } } fn to_kstr(&self) -> KStringCow<'_> { @@ -111,7 +124,7 @@ impl ValueView for DummyFields { } fn to_value(&self) -> LiquidValue { - LiquidValue::Nil + LiquidValue::Array(vec![DummyField.to_value()]) } fn as_array(&self) -> Option<&dyn ArrayView> { @@ -125,15 +138,15 @@ impl ArrayView for DummyFields { } fn size(&self) -> i64 { - i64::MAX + u16::MAX as i64 } fn values<'k>(&'k self) -> Box + 'k> { - Box::new(std::iter::empty()) + Box::new(std::iter::once(DummyField.as_value())) } - fn contains_key(&self, _index: i64) -> bool { - true + fn contains_key(&self, index: i64) -> bool { + index < self.size() } fn get(&self, _index: i64) -> Option<&dyn ValueView> { @@ -167,7 +180,8 @@ impl ObjectView for DummyDoc { } fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> { - Some(DUMMY_VALUE.as_view()) + // Recursively sends itself + Some(self) } } @@ -189,7 +203,12 @@ impl ValueView for DummyDoc { } fn query_state(&self, state: State) -> bool { - DUMMY_VALUE.query_state(state) + match state { + State::Truthy => true, + State::DefaultValue => false, + State::Empty => false, + State::Blank => false, + } } fn to_kstr(&self) -> KStringCow<'_> { diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index bba6cf119..bc7f6fb08 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -516,7 +516,7 @@ pub fn execute_vector_search( ) -> Result { check_sort_criteria(ctx, sort_criteria.as_ref())?; - /// FIXME: input universe = universe & documents_with_vectors + // FIXME: input universe = universe & documents_with_vectors // for now if we're computing embeddings for ALL documents, we can assume that this is just universe let ranking_rules = get_ranking_rules_for_vector( ctx, diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index d8d6c933c..6edde98fb 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -71,8 +71,8 @@ impl VectorStateDelta { pub fn extract_vector_points( obkv_documents: grenad::Reader, indexer: GrenadParameters, - field_id_map: FieldsIdsMap, - prompt: Option<&Prompt>, + field_id_map: &FieldsIdsMap, + prompt: &Prompt, ) -> Result { puffin::profile_function!(); @@ -142,14 +142,11 @@ pub fn extract_vector_points( .any(|deladd| deladd.get(DelAdd::Addition).is_some()); if document_is_kept { // becomes autogenerated - match prompt { - Some(prompt) => VectorStateDelta::NowGenerated(prompt.render( - obkv, - DelAdd::Addition, - &field_id_map, - )?), - None => VectorStateDelta::NowRemoved, - } + VectorStateDelta::NowGenerated(prompt.render( + obkv, + DelAdd::Addition, + field_id_map, + )?) } else { VectorStateDelta::NowRemoved } @@ -162,26 +159,18 @@ pub fn extract_vector_points( .any(|deladd| deladd.get(DelAdd::Addition).is_some()); if document_is_kept { - match prompt { - Some(prompt) => { - // Don't give up if the old prompt was failing - let old_prompt = prompt - .render(obkv, DelAdd::Deletion, &field_id_map) - .unwrap_or_default(); - let new_prompt = - prompt.render(obkv, DelAdd::Addition, &field_id_map)?; - if old_prompt != new_prompt { - log::trace!( - "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" - ); - VectorStateDelta::NowGenerated(new_prompt) - } else { - log::trace!("⏭️ Prompt unmodified, skipping"); - VectorStateDelta::NoChange - } - } - // We no longer have a prompt, so we need to remove any existing vector - None => VectorStateDelta::NowRemoved, + // Don't give up if the old prompt was failing + let old_prompt = + prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); + let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; + if old_prompt != new_prompt { + log::trace!( + "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" + ); + VectorStateDelta::NowGenerated(new_prompt) + } else { + log::trace!("⏭️ Prompt unmodified, skipping"); + VectorStateDelta::NoChange } } else { VectorStateDelta::NowRemoved @@ -196,24 +185,16 @@ pub fn extract_vector_points( .any(|deladd| deladd.get(DelAdd::Addition).is_some()); if document_is_kept { - match prompt { - Some(prompt) => { - // Don't give up if the old prompt was failing - let old_prompt = prompt - .render(obkv, DelAdd::Deletion, &field_id_map) - .unwrap_or_default(); - let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?; - if old_prompt != new_prompt { - log::trace!( - "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" - ); - VectorStateDelta::NowGenerated(new_prompt) - } else { - log::trace!("⏭️ Prompt unmodified, skipping"); - VectorStateDelta::NoChange - } - } - None => VectorStateDelta::NowRemoved, + // Don't give up if the old prompt was failing + let old_prompt = + prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); + let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; + if old_prompt != new_prompt { + log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"); + VectorStateDelta::NowGenerated(new_prompt) + } else { + log::trace!("⏭️ Prompt unmodified, skipping"); + VectorStateDelta::NoChange } } else { VectorStateDelta::NowRemoved @@ -322,7 +303,7 @@ pub fn extract_embeddings( prompt_reader: grenad::Reader, indexer: GrenadParameters, embedder: Arc, -) -> Result<(grenad::Reader>, Option)> { +) -> Result>> { let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism @@ -341,8 +322,6 @@ pub fn extract_embeddings( let mut chunks_ids = Vec::with_capacity(n_chunks); let mut cursor = prompt_reader.into_cursor()?; - let mut expected_dimension = None; - while let Some((key, value)) = cursor.move_on_next()? { let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); // SAFETY: precondition, the grenad value was saved from a string @@ -367,7 +346,6 @@ pub fn extract_embeddings( .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))), ) .map_err(crate::vector::Error::from) - .map_err(crate::UserError::from) .map_err(crate::Error::from)?; for (docid, embeddings) in chunks_ids @@ -376,7 +354,6 @@ pub fn extract_embeddings( .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) { state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; - expected_dimension = Some(embeddings.dimension()); } chunks_ids.clear(); } @@ -387,7 +364,6 @@ pub fn extract_embeddings( let chunked_embeds = rt .block_on(embedder.embed_chunks(std::mem::take(&mut chunks))) .map_err(crate::vector::Error::from) - .map_err(crate::UserError::from) .map_err(crate::Error::from)?; for (docid, embeddings) in chunks_ids .iter() @@ -395,7 +371,6 @@ pub fn extract_embeddings( .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) { state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; - expected_dimension = Some(embeddings.dimension()); } } @@ -403,14 +378,12 @@ pub fn extract_embeddings( let embeds = rt .block_on(embedder.embed(std::mem::take(&mut current_chunk))) .map_err(crate::vector::Error::from) - .map_err(crate::UserError::from) .map_err(crate::Error::from)?; for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; - expected_dimension = Some(embeddings.dimension()); } } - Ok((writer_into_reader(state_writer)?, expected_dimension)) + writer_into_reader(state_writer) } diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 69530a507..4831cc69d 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -292,43 +292,42 @@ fn send_original_documents_data( let documents_chunk_cloned = original_documents_chunk.clone(); let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); rayon::spawn(move || { - let (embedder, prompt) = embedders.get("default").cloned().unzip(); - let result = - extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref()); - match result { - Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { - /// FIXME: support multiple embedders - let results = embedder.and_then(|embedder| { - match extract_embeddings(prompts, indexer, embedder.clone()) { + for (name, (embedder, prompt)) in embedders { + let result = extract_vector_points( + documents_chunk_cloned.clone(), + indexer, + &field_id_map, + &prompt, + ); + match result { + Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { + let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) { Ok(results) => Some(results), Err(error) => { let _ = lmdb_writer_sx_cloned.send(Err(error)); None } - } - }); - let (embeddings, expected_dimension) = results.unzip(); - let expected_dimension = expected_dimension.flatten(); - if !(remove_vectors.is_empty() - && manual_vectors.is_empty() - && embeddings.as_ref().map_or(true, |e| e.is_empty())) - { - /// FIXME FIXME FIXME - if expected_dimension.is_some() { + }; + + if !(remove_vectors.is_empty() + && manual_vectors.is_empty() + && embeddings.as_ref().map_or(true, |e| e.is_empty())) + { let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { remove_vectors, embeddings, - /// FIXME: compute an expected dimension from the manual vectors if any - expected_dimension: expected_dimension.unwrap(), + expected_dimension: embedder.dimensions(), manual_vectors, + embedder_name: name, })); } } + + Err(error) => { + let _ = lmdb_writer_sx_cloned.send(Err(error)); + } } - Err(error) => { - let _ = lmdb_writer_sx_cloned.send(Err(error)); - } - }; + } }); // TODO: create a custom internal error diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 472c77111..c3c39b90f 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -435,7 +435,7 @@ where let mut word_docids = None; let mut exact_word_docids = None; - let mut dimension = None; + let mut dimension = HashMap::new(); for result in lmdb_writer_rx { if (self.should_abort)() { @@ -471,13 +471,15 @@ where remove_vectors, embeddings, manual_vectors, + embedder_name, } => { - dimension = Some(expected_dimension); + dimension.insert(embedder_name.clone(), expected_dimension); TypedChunk::VectorPoints { remove_vectors, embeddings, expected_dimension, manual_vectors, + embedder_name, } } otherwise => otherwise, @@ -513,14 +515,22 @@ where self.index.put_primary_key(self.wtxn, &primary_key)?; let number_of_documents = self.index.number_of_documents(self.wtxn)?; - if let Some(dimension) = dimension { + for (embedder_name, dimension) in dimension { let wtxn = &mut *self.wtxn; let vector_arroy = self.index.vector_arroy; + /// FIXME: unwrap + let embedder_index = + self.index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap(); pool.install(|| { - /// FIXME: do for each embedder + let writer_index = (embedder_index as u16) << 8; let mut rng = rand::rngs::StdRng::from_entropy(); for k in 0..=u8::MAX { - let writer = arroy::Writer::prepare(wtxn, vector_arroy, k.into(), dimension)?; + let writer = arroy::Writer::prepare( + wtxn, + vector_arroy, + writer_index | (k as u16), + dimension, + )?; if writer.is_empty(wtxn)? { break; } diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index da99ed685..dde2124ed 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -47,6 +47,7 @@ pub(crate) enum TypedChunk { embeddings: Option>>, expected_dimension: usize, manual_vectors: grenad::Reader>, + embedder_name: String, }, ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), } @@ -100,8 +101,8 @@ impl TypedChunk { TypedChunk::GeoPoints(grenad) => { format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) } - TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => { - format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension) + TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => { + format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name) } TypedChunk::ScriptLanguageDocids(sl_map) => { format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) @@ -360,12 +361,20 @@ pub(crate) fn write_typed_chunk_into_index( manual_vectors, embeddings, expected_dimension, + embedder_name, } => { - /// FIXME: allow customizing distance + /// FIXME: unwrap + let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap(); + let writer_index = (embedder_index as u16) << 8; + // FIXME: allow customizing distance let writers: std::result::Result, _> = (0..=u8::MAX) .map(|k| { - /// FIXME: allow customizing index and then do index << 8 + k - arroy::Writer::prepare(wtxn, index.vector_arroy, k.into(), expected_dimension) + arroy::Writer::prepare( + wtxn, + index.vector_arroy, + writer_index | (k as u16), + expected_dimension, + ) }) .collect(); let writers = writers?; @@ -456,7 +465,7 @@ pub(crate) fn write_typed_chunk_into_index( } } - log::debug!("There are 🤷‍♀️ entries in the arroy so far"); + log::debug!("Finished vector chunk for {}", embedder_name); } TypedChunk::ScriptLanguageDocids(sl_map) => { for (key, (deletion, addition)) in sl_map { diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index b8355be51..1149dbce5 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -431,7 +431,6 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { let embedder = Arc::new( Embedder::new(embedder_options.clone()) .map_err(crate::vector::Error::from) - .map_err(crate::UserError::from) .map_err(crate::Error::from)?, ); Ok((name, (embedder, prompt))) @@ -976,6 +975,19 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { Setting::NotSet => Some((name, EmbeddingSettings::default().into())), }) .collect(); + + self.index.embedder_category_id.clear(self.wtxn)?; + for (index, (embedder_name, _)) in new_configs.iter().enumerate() { + self.index.embedder_category_id.put_with_flags( + self.wtxn, + heed::PutFlags::APPEND, + embedder_name, + &index + .try_into() + .map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?, + )?; + } + if new_configs.is_empty() { self.index.delete_embedding_configs(self.wtxn)?; } else { @@ -1062,7 +1074,7 @@ fn validate_prompt( match new { Setting::Set(EmbeddingSettings { embedder_options, - prompt: + document_template: Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }), }) => { // validate @@ -1072,7 +1084,7 @@ fn validate_prompt( Ok(Setting::Set(EmbeddingSettings { embedder_options, - prompt: Setting::Set(PromptSettings { + document_template: Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback, diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs index 1ae7a4678..c5cce622d 100644 --- a/milli/src/vector/error.rs +++ b/milli/src/vector/error.rs @@ -65,6 +65,8 @@ pub enum EmbedErrorKind { OpenAiTooManyTokens(OpenAiError), #[error("received unhandled HTTP status code {0} from OpenAI")] OpenAiUnhandledStatusCode(u16), + #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] + ManualEmbed(String), } impl EmbedError { @@ -111,6 +113,10 @@ impl EmbedError { pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError { Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug } } + + pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError { + Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User } + } } #[derive(Debug, thiserror::Error)] @@ -170,6 +176,13 @@ impl NewEmbedderError { Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime } } + pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError { + Self { + kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner), + fault: FaultSource::Runtime, + } + } + pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } } @@ -219,6 +232,8 @@ pub enum NewEmbedderErrorKind { NewApiFail(ApiError), #[error("fetching file from HG_HUB failed: {0}")] ApiGet(ApiError), + #[error("could not determine model dimensions: test embedding failed with {0}")] + CouldNotDetermineDimension(EmbedError), #[error("loading model failed: {0}")] LoadModel(candle_core::Error), // openai diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index 81cdd4b34..07185d25c 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -62,6 +62,7 @@ pub struct Embedder { model: BertModel, tokenizer: Tokenizer, options: EmbedderOptions, + dimensions: usize, } impl std::fmt::Debug for Embedder { @@ -126,10 +127,17 @@ impl Embedder { tokenizer.with_padding(Some(pp)); } - Ok(Self { model, tokenizer, options }) + let mut this = Self { model, tokenizer, options, dimensions: 0 }; + + let embeddings = this + .embed(vec!["test".into()]) + .map_err(NewEmbedderError::hf_could_not_determine_dimension)?; + this.dimensions = embeddings.first().unwrap().dimension(); + + Ok(this) } - pub async fn embed( + pub fn embed( &self, mut texts: Vec, ) -> std::result::Result>, EmbedError> { @@ -170,12 +178,11 @@ impl Embedder { Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) } - pub async fn embed_chunks( + pub fn embed_chunks( &self, text_chunks: Vec>, ) -> std::result::Result>>, EmbedError> { - futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) - .await + text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() } pub fn chunk_count_hint(&self) -> usize { @@ -185,6 +192,10 @@ impl Embedder { pub fn prompt_count_in_chunk_hint(&self) -> usize { std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8) } + + pub fn dimensions(&self) -> usize { + self.dimensions + } } fn normalize_l2(v: &Tensor) -> Result { diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 91640b8fb..7185e56b1 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -3,6 +3,7 @@ use crate::prompt::PromptData; pub mod error; pub mod hf; +pub mod manual; pub mod openai; pub mod settings; @@ -67,6 +68,7 @@ impl Embeddings { pub enum Embedder { HuggingFace(hf::Embedder), OpenAi(openai::Embedder), + UserProvided(manual::Embedder), } #[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] @@ -80,6 +82,7 @@ pub struct EmbeddingConfig { pub enum EmbedderOptions { HuggingFace(hf::EmbedderOptions), OpenAi(openai::EmbedderOptions), + UserProvided(manual::EmbedderOptions), } impl Default for EmbedderOptions { @@ -93,7 +96,7 @@ impl EmbedderOptions { Self::HuggingFace(hf::EmbedderOptions::new()) } - pub fn openai(api_key: String) -> Self { + pub fn openai(api_key: Option) -> Self { Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) } } @@ -103,6 +106,9 @@ impl Embedder { Ok(match options { EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), + EmbedderOptions::UserProvided(options) => { + Self::UserProvided(manual::Embedder::new(options)) + } }) } @@ -111,8 +117,9 @@ impl Embedder { texts: Vec, ) -> std::result::Result>, EmbedError> { match self { - Embedder::HuggingFace(embedder) => embedder.embed(texts).await, + Embedder::HuggingFace(embedder) => embedder.embed(texts), Embedder::OpenAi(embedder) => embedder.embed(texts).await, + Embedder::UserProvided(embedder) => embedder.embed(texts), } } @@ -121,8 +128,9 @@ impl Embedder { text_chunks: Vec>, ) -> std::result::Result>>, EmbedError> { match self { - Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await, + Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, + Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), } } @@ -130,6 +138,7 @@ impl Embedder { match self { Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), + Embedder::UserProvided(_) => 1, } } @@ -137,6 +146,15 @@ impl Embedder { match self { Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), + Embedder::UserProvided(_) => 1, + } + } + + pub fn dimensions(&self) -> usize { + match self { + Embedder::HuggingFace(embedder) => embedder.dimensions(), + Embedder::OpenAi(embedder) => embedder.dimensions(), + Embedder::UserProvided(embedder) => embedder.dimensions(), } } } diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 670dc8526..bab62f5e4 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -15,7 +15,7 @@ pub struct Embedder { #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub struct EmbedderOptions { - pub api_key: String, + pub api_key: Option, pub embedding_model: EmbeddingModel, } @@ -68,11 +68,11 @@ impl EmbeddingModel { pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; impl EmbedderOptions { - pub fn with_default_model(api_key: String) -> Self { + pub fn with_default_model(api_key: Option) -> Self { Self { api_key, embedding_model: Default::default() } } - pub fn with_embedding_model(api_key: String, embedding_model: EmbeddingModel) -> Self { + pub fn with_embedding_model(api_key: Option, embedding_model: EmbeddingModel) -> Self { Self { api_key, embedding_model } } } @@ -80,9 +80,14 @@ impl EmbedderOptions { impl Embedder { pub fn new(options: EmbedderOptions) -> Result { let mut headers = reqwest::header::HeaderMap::new(); + let mut inferred_api_key = Default::default(); + let api_key = options.api_key.as_ref().unwrap_or_else(|| { + inferred_api_key = infer_api_key(); + &inferred_api_key + }); headers.insert( reqwest::header::AUTHORIZATION, - reqwest::header::HeaderValue::from_str(&format!("Bearer {}", &options.api_key)) + reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key)) .map_err(NewEmbedderError::openai_invalid_api_key_format)?, ); headers.insert( @@ -315,6 +320,10 @@ impl Embedder { pub fn prompt_count_in_chunk_hint(&self) -> usize { 10 } + + pub fn dimensions(&self) -> usize { + self.options.embedding_model.dimensions() + } } // retrying in case of failure @@ -414,3 +423,9 @@ struct OpenAiEmbedding { // object: String, // index: usize, } + +fn infer_api_key() -> String { + std::env::var("MEILI_OPENAI_API_KEY") + .or_else(|_| std::env::var("OPENAI_API_KEY")) + .unwrap_or_default() +} diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index 2c0cf7924..f90c3cc71 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -15,14 +15,14 @@ pub struct EmbeddingSettings { pub embedder_options: Setting, #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] - pub prompt: Setting, + pub document_template: Setting, } impl EmbeddingSettings { pub fn apply(&mut self, new: Self) { - let EmbeddingSettings { embedder_options, prompt } = new; + let EmbeddingSettings { embedder_options, document_template: prompt } = new; self.embedder_options.apply(embedder_options); - self.prompt.apply(prompt); + self.document_template.apply(prompt); } } @@ -30,7 +30,7 @@ impl From for EmbeddingSettings { fn from(value: EmbeddingConfig) -> Self { Self { embedder_options: Setting::Set(value.embedder_options.into()), - prompt: Setting::Set(value.prompt.into()), + document_template: Setting::Set(value.prompt.into()), } } } @@ -38,7 +38,7 @@ impl From for EmbeddingSettings { impl From for EmbeddingConfig { fn from(value: EmbeddingSettings) -> Self { let mut this = Self::default(); - let EmbeddingSettings { embedder_options, prompt } = value; + let EmbeddingSettings { embedder_options, document_template: prompt } = value; if let Some(embedder_options) = embedder_options.set() { this.embedder_options = embedder_options.into(); } @@ -105,6 +105,7 @@ impl From for PromptData { pub enum EmbedderSettings { HuggingFace(Setting), OpenAi(Setting), + UserProvided(UserProvidedSettings), } impl Deserr for EmbedderSettings @@ -145,11 +146,17 @@ where location.push_key(&k), )?, ))), + "userProvided" => Ok(EmbedderSettings::UserProvided( + UserProvidedSettings::deserialize_from_value( + v.into_value(), + location.push_key(&k), + )?, + )), other => Err(deserr::take_cf_content(E::error::( None, deserr::ErrorKind::UnknownKey { key: other, - accepted: &["huggingFace", "openAi"], + accepted: &["huggingFace", "openAi", "userProvided"], }, location, ))), @@ -182,6 +189,9 @@ impl From for EmbedderSettings { crate::vector::EmbedderOptions::OpenAi(openai) => { Self::OpenAi(Setting::Set(openai.into())) } + crate::vector::EmbedderOptions::UserProvided(user_provided) => { + Self::UserProvided(user_provided.into()) + } } } } @@ -192,9 +202,12 @@ impl From for crate::vector::EmbedderOptions { EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()), EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()), EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()), - EmbedderSettings::OpenAi(_setting) => Self::OpenAi( - crate::vector::openai::EmbedderOptions::with_default_model(infer_api_key()), - ), + EmbedderSettings::OpenAi(_setting) => { + Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None)) + } + EmbedderSettings::UserProvided(user_provided) => { + Self::UserProvided(user_provided.into()) + } } } } @@ -286,7 +299,7 @@ impl OpenAiEmbedderSettings { impl From for OpenAiEmbedderSettings { fn from(value: crate::vector::openai::EmbedderOptions) -> Self { Self { - api_key: Setting::Set(value.api_key), + api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset), embedding_model: Setting::Set(value.embedding_model), } } @@ -295,14 +308,25 @@ impl From for OpenAiEmbedderSettings { impl From for crate::vector::openai::EmbedderOptions { fn from(value: OpenAiEmbedderSettings) -> Self { let OpenAiEmbedderSettings { api_key, embedding_model } = value; - Self { - api_key: api_key.set().unwrap_or_else(infer_api_key), - embedding_model: embedding_model.set().unwrap_or_default(), - } + Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() } } } -fn infer_api_key() -> String { - /// FIXME: get key from instance options? - std::env::var("MEILI_OPENAI_API_KEY").unwrap_or_default() +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct UserProvidedSettings { + pub dimensions: usize, +} + +impl From for crate::vector::manual::EmbedderOptions { + fn from(value: UserProvidedSettings) -> Self { + Self { dimensions: value.dimensions } + } +} + +impl From for UserProvidedSettings { + fn from(value: crate::vector::manual::EmbedderOptions) -> Self { + Self { dimensions: value.dimensions } + } } From 12940d79a96905f38a9016cb647f2013693e849a Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Tue, 12 Dec 2023 23:39:01 +0100 Subject: [PATCH 13/28] WIP - manual embedder - multi embedders OK - clippy + tests OK --- meilisearch-types/src/error.rs | 3 + meilisearch/src/routes/indexes/search.rs | 60 +++--- meilisearch/src/search.rs | 21 +- meilisearch/tests/dumps/mod.rs | 39 ++-- meilisearch/tests/search/mod.rs | 26 ++- meilisearch/tests/settings/get_settings.rs | 3 +- milli/src/error.rs | 8 +- .../extract/extract_vector_points.rs | 197 +++++++++++------- .../src/update/index_documents/extract/mod.rs | 1 + milli/src/update/index_documents/mod.rs | 10 +- .../src/update/index_documents/typed_chunk.rs | 26 ++- milli/src/vector/manual.rs | 34 +++ milli/src/vector/mod.rs | 4 + 13 files changed, 292 insertions(+), 140 deletions(-) create mode 100644 milli/src/vector/manual.rs diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 5df1ae106..9df41b68f 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -305,6 +305,7 @@ NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENT PayloadTooLarge , InvalidRequest , PAYLOAD_TOO_LARGE ; TaskNotFound , InvalidRequest , NOT_FOUND ; TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ; +TooManyVectors , InvalidRequest , BAD_REQUEST ; UnretrievableDocument , Internal , BAD_REQUEST ; UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ; UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ; @@ -362,7 +363,9 @@ impl ErrorCode for milli::Error { UserError::CriterionError(_) => Code::InvalidSettingsRankingRules, UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField, UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions, + UserError::InvalidVectorsMapType { .. } => Code::InvalidVectorsType, UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType, + UserError::TooManyVectors(_, _) => Code::TooManyVectors, UserError::SortError(_) => Code::InvalidSearchSort, UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index ec4825661..c057d4809 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -235,38 +235,42 @@ pub async fn embed( index_scheduler: &IndexScheduler, index: &milli::Index, ) -> Result<(), ResponseError> { - if let Some(VectorQuery::String(prompt)) = query.vector.take() { - let embedder_configs = index.embedding_configs(&index.read_txn()?)?; - let embedder = index_scheduler.embedders(embedder_configs)?; + match query.vector.take() { + Some(VectorQuery::String(prompt)) => { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedder = index_scheduler.embedders(embedder_configs)?; - let embedder_name = if let Some(HybridQuery { - semantic_ratio: _, - embedder: Some(embedder), - }) = &query.hybrid - { - embedder - } else { - "default" - }; + let embedder_name = + if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) = + &query.hybrid + { + embedder + } else { + "default" + }; - let embeddings = embedder - .get(embedder_name) - .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned())) - .map_err(milli::Error::from)? - .0 - .embed(vec![prompt]) - .await - .map_err(milli::vector::Error::from) - .map_err(milli::Error::from)? - .pop() - .expect("No vector returned from embedding"); + let embeddings = embedder + .get(embedder_name) + .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned())) + .map_err(milli::Error::from)? + .0 + .embed(vec![prompt]) + .await + .map_err(milli::vector::Error::from) + .map_err(milli::Error::from)? + .pop() + .expect("No vector returned from embedding"); - if embeddings.iter().nth(1).is_some() { - warn!("Ignoring embeddings past the first one in long search query"); - query.vector = Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec())); - } else { - query.vector = Some(VectorQuery::Vector(embeddings.into_inner())); + if embeddings.iter().nth(1).is_some() { + warn!("Ignoring embeddings past the first one in long search query"); + query.vector = + Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec())); + } else { + query.vector = Some(VectorQuery::Vector(embeddings.into_inner())); + } } + Some(vector) => query.vector = Some(vector), + None => {} }; Ok(()) } diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index c1e667570..d496da1a3 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -13,7 +13,7 @@ use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; -use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; +use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; @@ -562,8 +562,17 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } - /// FIXME: remove this or set to value from the score details - let semantic_score = None; + let mut semantic_score = None; + for details in &score { + if let ScoreDetails::Vector(score_details::Vector { + target_vector: _, + value_similarity: Some((_matching_vector, similarity)), + }) = details + { + semantic_score = Some(*similarity); + break; + } + } let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); @@ -648,8 +657,10 @@ pub fn perform_search( hits: documents, hits_info, query: query.q.unwrap_or_default(), - // FIXME: display input vector - vector: None, + vector: match query.vector { + Some(VectorQuery::Vector(vector)) => Some(vector), + _ => None, + }, processing_time_ms: before_search.elapsed().as_millis(), facet_distribution, facet_stats, diff --git a/meilisearch/tests/dumps/mod.rs b/meilisearch/tests/dumps/mod.rs index 9e949436a..07cfddd37 100644 --- a/meilisearch/tests/dumps/mod.rs +++ b/meilisearch/tests/dumps/mod.rs @@ -77,7 +77,8 @@ async fn import_dump_v1_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -238,7 +239,8 @@ async fn import_dump_v1_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -385,7 +387,8 @@ async fn import_dump_v1_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -518,7 +521,8 @@ async fn import_dump_v2_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -663,7 +667,8 @@ async fn import_dump_v2_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -807,7 +812,8 @@ async fn import_dump_v2_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -940,7 +946,8 @@ async fn import_dump_v3_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1085,7 +1092,8 @@ async fn import_dump_v3_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1229,7 +1237,8 @@ async fn import_dump_v3_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1362,7 +1371,8 @@ async fn import_dump_v4_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1507,7 +1517,8 @@ async fn import_dump_v4_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1651,7 +1662,8 @@ async fn import_dump_v4_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1896,7 +1908,8 @@ async fn import_dump_v6_containing_experimental_features() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "###); diff --git a/meilisearch/tests/search/mod.rs b/meilisearch/tests/search/mod.rs index 00678f7d4..fa97beaaf 100644 --- a/meilisearch/tests/search/mod.rs +++ b/meilisearch/tests/search/mod.rs @@ -876,7 +876,31 @@ async fn experimental_feature_vector_store() { })) .await; meili_snap::snapshot!(code, @"200 OK"); - meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @"[]"); + // vector search returns all documents that don't have vectors in the last bucket, like all sorts + meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###" + [ + { + "title": "Shazam!", + "id": "287947" + }, + { + "title": "Captain Marvel", + "id": "299537" + }, + { + "title": "Escape Room", + "id": "522681" + }, + { + "title": "How to Train Your Dragon: The Hidden World", + "id": "166428" + }, + { + "title": "Gläss", + "id": "450465" + } + ] + "###); } #[cfg(feature = "default")] diff --git a/meilisearch/tests/settings/get_settings.rs b/meilisearch/tests/settings/get_settings.rs index 0ea556b94..9ab53c51e 100644 --- a/meilisearch/tests/settings/get_settings.rs +++ b/meilisearch/tests/settings/get_settings.rs @@ -54,7 +54,7 @@ async fn get_settings() { let (response, code) = index.settings().await; assert_eq!(code, 200); let settings = response.as_object().unwrap(); - assert_eq!(settings.keys().len(), 15); + assert_eq!(settings.keys().len(), 16); assert_eq!(settings["displayedAttributes"], json!(["*"])); assert_eq!(settings["searchableAttributes"], json!(["*"])); assert_eq!(settings["filterableAttributes"], json!([])); @@ -83,6 +83,7 @@ async fn get_settings() { "maxTotalHits": 1000, }) ); + assert_eq!(settings["embedders"], json!({})); } #[actix_rt::test] diff --git a/milli/src/error.rs b/milli/src/error.rs index 95a0aba6d..9c5d8f416 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -114,8 +114,10 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco InvalidGeoField(#[from] GeoError), #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)] InvalidVectorDimensions { expected: usize, found: usize }, - #[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] - InvalidVectorsType { document_id: Value, value: Value }, + #[error("The `_vectors.{subfield}` field in the document with id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] + InvalidVectorsType { document_id: Value, value: Value, subfield: String }, + #[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")] + InvalidVectorsMapType { document_id: Value, value: Value }, #[error("{0}")] InvalidFilter(String), #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] @@ -196,6 +198,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco TooManyEmbedders(usize), #[error("Cannot find embedder with name {0}.")] InvalidEmbedder(String), + #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")] + TooManyVectors(String, usize), } impl From for Error { diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 6edde98fb..3a0376511 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -73,6 +73,7 @@ pub fn extract_vector_points( indexer: GrenadParameters, field_id_map: &FieldsIdsMap, prompt: &Prompt, + embedder_name: &str, ) -> Result { puffin::profile_function!(); @@ -115,89 +116,87 @@ pub fn extract_vector_points( // lazily get it when needed let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() }; - let delta = if let Some(value) = vectors_fid.and_then(|vectors_fid| obkv.get(vectors_fid)) { - let vectors_obkv = KvReaderDelAdd::new(value); - match (vectors_obkv.get(DelAdd::Deletion), vectors_obkv.get(DelAdd::Addition)) { - (Some(old), Some(new)) => { - // no autogeneration - let del_vectors = extract_vectors(old, document_id)?; - let add_vectors = extract_vectors(new, document_id)?; + let vectors_field = vectors_fid + .and_then(|vectors_fid| obkv.get(vectors_fid)) + .map(KvReaderDelAdd::new) + .map(|obkv| to_vector_maps(obkv, document_id)) + .transpose()?; - VectorStateDelta::ManualDelta( - del_vectors.unwrap_or_default(), - add_vectors.unwrap_or_default(), - ) - } - (None, Some(new)) => { - // was possibly autogenerated, remove all vectors for that document - let add_vectors = extract_vectors(new, document_id)?; + let (del_map, add_map) = vectors_field.unzip(); + let del_map = del_map.flatten(); + let add_map = add_map.flatten(); - VectorStateDelta::WasGeneratedNowManual(add_vectors.unwrap_or_default()) - } - (Some(_old), None) => { - // Do we keep this document? - let document_is_kept = obkv - .iter() - .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) - .any(|deladd| deladd.get(DelAdd::Addition).is_some()); - if document_is_kept { - // becomes autogenerated - VectorStateDelta::NowGenerated(prompt.render( - obkv, - DelAdd::Addition, - field_id_map, - )?) - } else { - VectorStateDelta::NowRemoved - } - } - (None, None) => { - // Do we keep this document? - let document_is_kept = obkv - .iter() - .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) - .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + let del_value = del_map.and_then(|mut map| map.remove(embedder_name)); + let add_value = add_map.and_then(|mut map| map.remove(embedder_name)); - if document_is_kept { - // Don't give up if the old prompt was failing - let old_prompt = - prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); - let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; - if old_prompt != new_prompt { - log::trace!( - "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" - ); - VectorStateDelta::NowGenerated(new_prompt) - } else { - log::trace!("⏭️ Prompt unmodified, skipping"); - VectorStateDelta::NoChange - } - } else { - VectorStateDelta::NowRemoved - } + let delta = match (del_value, add_value) { + (Some(old), Some(new)) => { + // no autogeneration + let del_vectors = extract_vectors(old, document_id, embedder_name)?; + let add_vectors = extract_vectors(new, document_id, embedder_name)?; + + if add_vectors.len() > u8::MAX.into() { + return Err(crate::Error::UserError(crate::UserError::TooManyVectors( + document_id().to_string(), + add_vectors.len(), + ))); + } + + VectorStateDelta::ManualDelta(del_vectors, add_vectors) + } + (Some(_old), None) => { + // Do we keep this document? + let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + if document_is_kept { + // becomes autogenerated + VectorStateDelta::NowGenerated(prompt.render( + obkv, + DelAdd::Addition, + field_id_map, + )?) + } else { + VectorStateDelta::NowRemoved } } - } else { - // Do we keep this document? - let document_is_kept = obkv - .iter() - .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) - .any(|deladd| deladd.get(DelAdd::Addition).is_some()); - - if document_is_kept { - // Don't give up if the old prompt was failing - let old_prompt = - prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); - let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; - if old_prompt != new_prompt { - log::trace!("🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}"); - VectorStateDelta::NowGenerated(new_prompt) - } else { - log::trace!("⏭️ Prompt unmodified, skipping"); - VectorStateDelta::NoChange + (None, Some(new)) => { + // was possibly autogenerated, remove all vectors for that document + let add_vectors = extract_vectors(new, document_id, embedder_name)?; + if add_vectors.len() > u8::MAX.into() { + return Err(crate::Error::UserError(crate::UserError::TooManyVectors( + document_id().to_string(), + add_vectors.len(), + ))); + } + + VectorStateDelta::WasGeneratedNowManual(add_vectors) + } + (None, None) => { + // Do we keep this document? + let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + + if document_is_kept { + // Don't give up if the old prompt was failing + let old_prompt = + prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); + let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; + if old_prompt != new_prompt { + log::trace!( + "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" + ); + VectorStateDelta::NowGenerated(new_prompt) + } else { + log::trace!("⏭️ Prompt unmodified, skipping"); + VectorStateDelta::NoChange + } + } else { + VectorStateDelta::NowRemoved } - } else { - VectorStateDelta::NowRemoved } }; @@ -221,6 +220,34 @@ pub fn extract_vector_points( }) } +fn to_vector_maps( + obkv: KvReaderDelAdd, + document_id: impl Fn() -> Value, +) -> Result<(Option>, Option>)> { + let del = to_vector_map(obkv, DelAdd::Deletion, &document_id)?; + let add = to_vector_map(obkv, DelAdd::Addition, &document_id)?; + Ok((del, add)) +} + +fn to_vector_map( + obkv: KvReaderDelAdd, + side: DelAdd, + document_id: &impl Fn() -> Value, +) -> Result>> { + Ok(if let Some(value) = obkv.get(side) { + let Ok(value) = from_slice(value) else { + let value = from_slice(value).map_err(InternalError::SerdeJson)?; + return Err(crate::Error::UserError(UserError::InvalidVectorsMapType { + document_id: document_id(), + value, + })); + }; + Some(value) + } else { + None + }) +} + /// Computes the diff between both Del and Add numbers and /// only inserts the parts that differ in the sorter. fn push_vectors_diff( @@ -286,12 +313,20 @@ fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering { } /// Extracts the vectors from a JSON value. -fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result>>> { - match from_slice(value) { - Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)), +fn extract_vectors( + value: Value, + document_id: impl Fn() -> Value, + name: &str, +) -> Result>> { + // FIXME: ugly clone of the vectors here + match serde_json::from_value(value.clone()) { + Ok(vectors) => { + Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default()) + } Err(_) => Err(UserError::InvalidVectorsType { document_id: document_id(), - value: from_slice(value).map_err(InternalError::SerdeJson)?, + value, + subfield: name.to_owned(), } .into()), } diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 4831cc69d..a852b035b 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -298,6 +298,7 @@ fn send_original_documents_data( indexer, &field_id_map, &prompt, + &name, ); match result { Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index c3c39b90f..075dcd184 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -514,16 +514,18 @@ where // We write the primary key field id into the main database self.index.put_primary_key(self.wtxn, &primary_key)?; let number_of_documents = self.index.number_of_documents(self.wtxn)?; + let mut rng = rand::rngs::StdRng::from_entropy(); for (embedder_name, dimension) in dimension { let wtxn = &mut *self.wtxn; let vector_arroy = self.index.vector_arroy; - /// FIXME: unwrap - let embedder_index = - self.index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap(); + + let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or( + InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, + )?; + pool.install(|| { let writer_index = (embedder_index as u16) << 8; - let mut rng = rand::rngs::StdRng::from_entropy(); for k in 0..=u8::MAX { let writer = arroy::Writer::prepare( wtxn, diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index dde2124ed..f8fb30c7b 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -22,7 +22,9 @@ use crate::index::db_name::DOCUMENTS; use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd}; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; -use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError}; +use crate::{ + lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, InternalError, Result, SerializationError, +}; pub(crate) enum TypedChunk { FieldIdDocidFacetStrings(grenad::Reader), @@ -363,8 +365,9 @@ pub(crate) fn write_typed_chunk_into_index( expected_dimension, embedder_name, } => { - /// FIXME: unwrap - let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.unwrap(); + let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or( + InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, + )?; let writer_index = (embedder_index as u16) << 8; // FIXME: allow customizing distance let writers: std::result::Result, _> = (0..=u8::MAX) @@ -404,7 +407,20 @@ pub(crate) fn write_typed_chunk_into_index( // code error if we somehow got the wrong dimension .unwrap(); - /// FIXME: detect overflow + if embeddings.embedding_count() > u8::MAX.into() { + let external_docid = if let Ok(Some(Ok(index))) = index + .external_id_of(wtxn, std::iter::once(docid)) + .map(|it| it.into_iter().next()) + { + index + } else { + format!("internal docid={docid}") + }; + return Err(crate::Error::UserError(crate::UserError::TooManyVectors( + external_docid, + embeddings.embedding_count(), + ))); + } for (embedding, writer) in embeddings.iter().zip(&writers) { writer.add_item(wtxn, docid, embedding)?; } @@ -455,7 +471,7 @@ pub(crate) fn write_typed_chunk_into_index( if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { let vector = pod_collect_to_vec(value); - /// FIXME: detect overflow + // overflow was detected during vector extraction. for writer in &writers { if !writer.contains_item(wtxn, docid)? { writer.add_item(wtxn, docid, &vector)?; diff --git a/milli/src/vector/manual.rs b/milli/src/vector/manual.rs new file mode 100644 index 000000000..7ed48a251 --- /dev/null +++ b/milli/src/vector/manual.rs @@ -0,0 +1,34 @@ +use super::error::EmbedError; +use super::Embeddings; + +#[derive(Debug, Clone, Copy)] +pub struct Embedder { + dimensions: usize, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct EmbedderOptions { + pub dimensions: usize, +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> Self { + Self { dimensions: options.dimensions } + } + + pub fn embed(&self, mut texts: Vec) -> Result>, EmbedError> { + let Some(text) = texts.pop() else { return Ok(Default::default()) }; + Err(EmbedError::embed_on_manual_embedder(text)) + } + + pub fn dimensions(&self) -> usize { + self.dimensions + } + + pub fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> Result>>, EmbedError> { + text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() + } +} diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index 7185e56b1..fa39c20a2 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -31,6 +31,10 @@ impl Embeddings { Ok(this) } + pub fn embedding_count(&self) -> usize { + self.data.len() / self.dimension + } + pub fn dimension(&self) -> usize { self.dimension } From e0cc775dc4aa7fed199f628a0328ba0e0b20f295 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 13 Dec 2023 15:38:44 +0100 Subject: [PATCH 14/28] Various changes - DistributionShift in Search object (to be set from model in embed?) - Fix issue where embedder index wasn't computed at search time - Accept as default embedder either the "default" one, or the only embedder when there is only one --- index-scheduler/src/lib.rs | 7 ++- meilisearch/src/routes/indexes/search.rs | 20 ++++++--- meilisearch/src/search.rs | 4 ++ milli/src/index.rs | 8 ++++ milli/src/search/hybrid.rs | 13 ++++++ milli/src/search/mod.rs | 33 ++++++++++++++ milli/src/search/new/mod.rs | 7 ++- milli/src/search/new/vector_sort.rs | 11 ++++- .../src/update/index_documents/extract/mod.rs | 10 ++--- milli/src/update/index_documents/mod.rs | 11 ++--- milli/src/update/settings.rs | 7 ++- milli/src/vector/mod.rs | 43 ++++++++++++++++++- 12 files changed, 141 insertions(+), 33 deletions(-) diff --git a/index-scheduler/src/lib.rs b/index-scheduler/src/lib.rs index 65d257ea0..b9b360fa4 100644 --- a/index-scheduler/src/lib.rs +++ b/index-scheduler/src/lib.rs @@ -52,7 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128}; use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn}; use meilisearch_types::milli::documents::DocumentsBatchBuilder; use meilisearch_types::milli::update::IndexerConfig; -use meilisearch_types::milli::vector::{Embedder, EmbedderOptions}; +use meilisearch_types::milli::vector::{Embedder, EmbedderOptions, EmbeddingConfigs}; use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32}; use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task}; use puffin::FrameView; @@ -1339,11 +1339,10 @@ impl IndexScheduler { } // TODO: consider using a type alias or a struct embedder/template - #[allow(clippy::type_complexity)] pub fn embedders( &self, embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>, - ) -> Result, Arc)>> { + ) -> Result { let res: Result<_> = embedding_configs .into_iter() .map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| { @@ -1370,7 +1369,7 @@ impl IndexScheduler { Ok((name, (embedder, prompt))) }) .collect(); - res + res.map(EmbeddingConfigs::new) } /// Blocks the thread until the test handle asks to progress to/through this breakpoint. diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index c057d4809..7a9a14687 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -238,22 +238,28 @@ pub async fn embed( match query.vector.take() { Some(VectorQuery::String(prompt)) => { let embedder_configs = index.embedding_configs(&index.read_txn()?)?; - let embedder = index_scheduler.embedders(embedder_configs)?; + let embedders = index_scheduler.embedders(embedder_configs)?; let embedder_name = if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) = &query.hybrid { - embedder + Some(embedder) } else { - "default" + None }; - let embeddings = embedder - .get(embedder_name) - .ok_or(milli::UserError::InvalidEmbedder(embedder_name.to_owned())) + let embedder = if let Some(embedder_name) = embedder_name { + embedders.get(embedder_name) + } else { + embedders.get_default() + }; + + let embedder = embedder + .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) .map_err(milli::Error::from)? - .0 + .0; + let embeddings = embedder .embed(vec![prompt]) .await .map_err(milli::vector::Error::from) diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index d496da1a3..53f6140fb 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -398,6 +398,10 @@ fn prepare_search<'t>( features.check_vector("Passing `vector` as a query parameter")?; } + if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid { + search.embedder_name(embedder); + } + // compute the offset on the limit depending on the pagination mode. let (offset, limit) = if is_finite_pagination { let limit = query.hits_per_page.unwrap_or_else(DEFAULT_SEARCH_LIMIT); diff --git a/milli/src/index.rs b/milli/src/index.rs index 05babf410..6ad39dcb1 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -1499,6 +1499,14 @@ impl Index { .get(rtxn, main_key::EMBEDDING_CONFIGS)? .unwrap_or_default()) } + + pub fn default_embedding_name(&self, rtxn: &RoTxn<'_>) -> Result { + let configs = self.embedding_configs(rtxn)?; + Ok(match configs.as_slice() { + [(ref first_name, _)] => first_name.clone(), + _ => "default".to_owned(), + }) + } } #[cfg(test)] diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs index 02c518126..cbec20c65 100644 --- a/milli/src/search/hybrid.rs +++ b/milli/src/search/hybrid.rs @@ -218,6 +218,8 @@ impl<'a> Search<'a> { exhaustive_number_hits: self.exhaustive_number_hits, rtxn: self.rtxn, index: self.index, + distribution_shift: self.distribution_shift, + embedder_name: self.embedder_name.clone(), }; let vector_query = search.vector.take(); @@ -265,6 +267,15 @@ impl<'a> Search<'a> { vector: &[f32], keyword_results: &SearchResult, ) -> Result { + let embedder_name; + let embedder_name = match &self.embedder_name { + Some(embedder_name) => embedder_name, + None => { + embedder_name = self.index.default_embedding_name(self.rtxn)?; + &embedder_name + } + }; + let mut ctx = SearchContext::new(self.index, self.rtxn); if let Some(searchable_attributes) = self.searchable_attributes { @@ -282,6 +293,8 @@ impl<'a> Search<'a> { self.geo_strategy, 0, self.limit + self.offset, + self.distribution_shift, + embedder_name, ) } diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 8b541ffcd..04a6005e3 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -17,6 +17,7 @@ use self::new::{execute_vector_search, PartialSearchResult}; use crate::error::UserError; use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue}; use crate::score_details::{ScoreDetails, ScoringStrategy}; +use crate::vector::DistributionShift; use crate::{ execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext, @@ -51,6 +52,8 @@ pub struct Search<'a> { exhaustive_number_hits: bool, rtxn: &'a heed::RoTxn<'a>, index: &'a Index, + distribution_shift: Option, + embedder_name: Option, } #[derive(Debug, Clone, PartialEq)] @@ -117,6 +120,8 @@ impl<'a> Search<'a> { words_limit: 10, rtxn, index, + distribution_shift: None, + embedder_name: None, } } @@ -183,7 +188,29 @@ impl<'a> Search<'a> { self } + pub fn distribution_shift( + &mut self, + distribution_shift: Option, + ) -> &mut Search<'a> { + self.distribution_shift = distribution_shift; + self + } + + pub fn embedder_name(&mut self, embedder_name: impl Into) -> &mut Search<'a> { + self.embedder_name = Some(embedder_name.into()); + self + } + pub fn execute(&self) -> Result { + let embedder_name; + let embedder_name = match &self.embedder_name { + Some(embedder_name) => embedder_name, + None => { + embedder_name = self.index.default_embedding_name(self.rtxn)?; + &embedder_name + } + }; + let mut ctx = SearchContext::new(self.index, self.rtxn); if let Some(searchable_attributes) = self.searchable_attributes { @@ -202,6 +229,8 @@ impl<'a> Search<'a> { self.geo_strategy, self.offset, self.limit, + self.distribution_shift, + embedder_name, )?, None => execute_search( &mut ctx, @@ -247,6 +276,8 @@ impl fmt::Debug for Search<'_> { exhaustive_number_hits, rtxn: _, index: _, + distribution_shift, + embedder_name, } = self; f.debug_struct("Search") .field("query", query) @@ -260,6 +291,8 @@ impl fmt::Debug for Search<'_> { .field("scoring_strategy", scoring_strategy) .field("exhaustive_number_hits", exhaustive_number_hits) .field("words_limit", words_limit) + .field("distribution_shift", distribution_shift) + .field("embedder_name", embedder_name) .finish() } } diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index bc7f6fb08..405b9747d 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -266,6 +266,7 @@ fn get_ranking_rules_for_vector<'ctx>( limit_plus_offset: usize, target: &[f32], distribution_shift: Option, + embedder_name: &str, ) -> Result>> { // query graph search @@ -292,6 +293,7 @@ fn get_ranking_rules_for_vector<'ctx>( vector_candidates, limit_plus_offset, distribution_shift, + embedder_name, )?; ranking_rules.push(Box::new(vector_sort)); vector = true; @@ -513,6 +515,8 @@ pub fn execute_vector_search( geo_strategy: geo_sort::Strategy, from: usize, length: usize, + distribution_shift: Option, + embedder_name: &str, ) -> Result { check_sort_criteria(ctx, sort_criteria.as_ref())?; @@ -524,7 +528,8 @@ pub fn execute_vector_search( geo_strategy, from + length, vector, - None, + distribution_shift, + embedder_name, )?; let mut placeholder_search_logger = logger::DefaultSearchLogger; diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs index 38fcfde48..6a37ceb7d 100644 --- a/milli/src/search/new/vector_sort.rs +++ b/milli/src/search/new/vector_sort.rs @@ -15,16 +15,21 @@ pub struct VectorSort { cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec)>, limit: usize, distribution_shift: Option, + embedder_index: u8, } impl VectorSort { pub fn new( - _ctx: &SearchContext, + ctx: &SearchContext, target: Vec, vector_candidates: RoaringBitmap, limit: usize, distribution_shift: Option, + embedder_name: &str, ) -> Result { + /// FIXME: unwrap + let embedder_index = ctx.index.embedder_category_id.get(ctx.txn, embedder_name)?.unwrap(); + Ok(Self { query: None, target, @@ -32,6 +37,7 @@ impl VectorSort { cached_sorted_docids: Default::default(), limit, distribution_shift, + embedder_index, }) } @@ -40,9 +46,10 @@ impl VectorSort { ctx: &mut SearchContext<'_>, vector_candidates: &RoaringBitmap, ) -> Result<()> { + let writer_index = (self.embedder_index as u16) << 8; let readers: std::result::Result, _> = (0..=u8::MAX) .map_while(|k| { - arroy::Reader::open(ctx.txn, k.into(), ctx.index.vector_arroy) + arroy::Reader::open(ctx.txn, writer_index | (k as u16), ctx.index.vector_arroy) .map(Some) .or_else(|e| match e { arroy::Error::MissingMetadata => Ok(None), diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index a852b035b..1d06849de 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -9,10 +9,9 @@ mod extract_word_docids; mod extract_word_pair_proximity_docids; mod extract_word_position_docids; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::fs::File; use std::io::BufReader; -use std::sync::Arc; use crossbeam_channel::Sender; use log::debug; @@ -35,9 +34,8 @@ use super::helpers::{ MergeFn, MergeableReader, }; use super::{helpers, TypedChunk}; -use crate::prompt::Prompt; use crate::proximity::ProximityPrecision; -use crate::vector::Embedder; +use crate::vector::EmbeddingConfigs; use crate::{FieldId, FieldsIdsMap, Result}; /// Extract data for each databases from obkv documents in parallel. @@ -59,7 +57,7 @@ pub(crate) fn data_from_obkv_documents( max_positions_per_attributes: Option, exact_attributes: HashSet, proximity_precision: ProximityPrecision, - embedders: HashMap, Arc)>, + embedders: EmbeddingConfigs, ) -> Result<()> { puffin::profile_function!(); @@ -284,7 +282,7 @@ fn send_original_documents_data( indexer: GrenadParameters, lmdb_writer_sx: Sender>, field_id_map: FieldsIdsMap, - embedders: HashMap, Arc)>, + embedders: EmbeddingConfigs, ) -> Result<()> { let original_documents_chunk = original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?; diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 075dcd184..efc6b22ff 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -9,7 +9,6 @@ use std::io::{Cursor, Read, Seek}; use std::iter::FromIterator; use std::num::NonZeroU32; use std::result::Result as StdResult; -use std::sync::Arc; use crossbeam_channel::{Receiver, Sender}; use heed::types::Str; @@ -34,12 +33,11 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters}; pub use self::transform::{Transform, TransformOutput}; use crate::documents::{obkv_to_object, DocumentsBatchReader}; use crate::error::{Error, InternalError, UserError}; -use crate::prompt::Prompt; pub use crate::update::index_documents::helpers::CursorClonableMmap; use crate::update::{ IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, }; -use crate::vector::Embedder; +use crate::vector::EmbeddingConfigs; use crate::{CboRoaringBitmapCodec, Index, Result}; static MERGED_DATABASE_COUNT: usize = 7; @@ -82,7 +80,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> { should_abort: FA, added_documents: u64, deleted_documents: u64, - embedders: HashMap, Arc)>, + embedders: EmbeddingConfigs, } #[derive(Default, Debug, Clone)] @@ -173,10 +171,7 @@ where Ok((self, Ok(indexed_documents))) } - pub fn with_embedders( - mut self, - embedders: HashMap, Arc)>, - ) -> Self { + pub fn with_embedders(mut self, embedders: EmbeddingConfigs) -> Self { self.embedders = embedders; self } diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 1149dbce5..e9f345e42 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -14,12 +14,11 @@ use super::IndexerConfig; use crate::criterion::Criterion; use crate::error::UserError; use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS}; -use crate::prompt::Prompt; use crate::proximity::ProximityPrecision; use crate::update::index_documents::IndexDocumentsMethod; use crate::update::{IndexDocuments, UpdateIndexingStep}; use crate::vector::settings::{EmbeddingSettings, PromptSettings}; -use crate::vector::{Embedder, EmbeddingConfig}; +use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs}; use crate::{FieldsIdsMap, Index, OrderBy, Result}; #[derive(Debug, Clone, PartialEq, Eq, Copy)] @@ -422,7 +421,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { fn embedders( &self, embedding_configs: Vec<(String, EmbeddingConfig)>, - ) -> Result, Arc)>> { + ) -> Result { let res: Result<_> = embedding_configs .into_iter() .map(|(name, EmbeddingConfig { embedder_options, prompt })| { @@ -436,7 +435,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { Ok((name, (embedder, prompt))) }) .collect(); - res + res.map(EmbeddingConfigs::new) } fn update_displayed(&mut self) -> Result { diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index fa39c20a2..df5750e77 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -1,5 +1,8 @@ +use std::collections::HashMap; +use std::sync::Arc; + use self::error::{EmbedError, NewEmbedderError}; -use crate::prompt::PromptData; +use crate::prompt::{Prompt, PromptData}; pub mod error; pub mod hf; @@ -82,6 +85,44 @@ pub struct EmbeddingConfig { // TODO: add metrics and anything needed } +#[derive(Clone, Default)] +pub struct EmbeddingConfigs(HashMap, Arc)>); + +impl EmbeddingConfigs { + pub fn new(data: HashMap, Arc)>) -> Self { + Self(data) + } + + pub fn get(&self, name: &str) -> Option<(Arc, Arc)> { + self.0.get(name).cloned() + } + + pub fn get_default(&self) -> Option<(Arc, Arc)> { + self.get_default_embedder_name().and_then(|default| self.get(&default)) + } + + pub fn get_default_embedder_name(&self) -> Option { + let mut it = self.0.keys(); + let first_name = it.next(); + let second_name = it.next(); + match (first_name, second_name) { + (None, _) => None, + (Some(first), None) => Some(first.to_owned()), + (Some(_), Some(_)) => Some("default".to_owned()), + } + } +} + +impl IntoIterator for EmbeddingConfigs { + type Item = (String, (Arc, Arc)); + + type IntoIter = std::collections::hash_map::IntoIter, Arc)>; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + #[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] pub enum EmbedderOptions { HuggingFace(hf::EmbedderOptions), From 61bd2fb7a99f3d243e2c2100cbfcf2fb588b1894 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 13 Dec 2023 18:03:19 +0100 Subject: [PATCH 15/28] Update arroy --- Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 3c2f38840..e4826b489 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -383,7 +383,7 @@ dependencies = [ [[package]] name = "arroy" version = "0.1.0" -source = "git+https://github.com/meilisearch/arroy.git#0079af0ec960bc9c51dd66e898a6b5e980cbb083" +source = "git+https://github.com/meilisearch/arroy.git#4f193fd534acd357b65bfe9eec4b3fed8ece2007" dependencies = [ "bytemuck", "byteorder", From 806e5b68997d2c0081e54be2f4f9367aefb0e006 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 13 Dec 2023 21:49:13 +0100 Subject: [PATCH 16/28] Tests pass --- meilisearch/tests/search/mod.rs | 122 +++++++++++++++++++++--- meilisearch/tests/search/multi.rs | 45 ++++++++- milli/src/search/new/vector_sort.rs | 7 +- milli/src/update/index_documents/mod.rs | 30 +++++- 4 files changed, 179 insertions(+), 25 deletions(-) diff --git a/meilisearch/tests/search/mod.rs b/meilisearch/tests/search/mod.rs index fa97beaaf..ad9c2aaa2 100644 --- a/meilisearch/tests/search/mod.rs +++ b/meilisearch/tests/search/mod.rs @@ -20,22 +20,27 @@ static DOCUMENTS: Lazy = Lazy::new(|| { { "title": "Shazam!", "id": "287947", + "_vectors": { "manual": [1, 2, 3]}, }, { "title": "Captain Marvel", "id": "299537", + "_vectors": { "manual": [1, 2, 54] }, }, { "title": "Escape Room", "id": "522681", + "_vectors": { "manual": [10, -23, 32] }, }, { "title": "How to Train Your Dragon: The Hidden World", "id": "166428", + "_vectors": { "manual": [-100, 231, 32] }, }, { "title": "Gläss", "id": "450465", + "_vectors": { "manual": [-100, 340, 90] }, } ]) }); @@ -57,6 +62,7 @@ static NESTED_DOCUMENTS: Lazy = Lazy::new(|| { }, ], "cattos": "pésti", + "_vectors": { "manual": [1, 2, 3]}, }, { "id": 654, @@ -69,12 +75,14 @@ static NESTED_DOCUMENTS: Lazy = Lazy::new(|| { }, ], "cattos": ["simba", "pestiféré"], + "_vectors": { "manual": [1, 2, 54] }, }, { "id": 750, "father": "romain", "mother": "michelle", "cattos": ["enigma"], + "_vectors": { "manual": [10, 23, 32] }, }, { "id": 951, @@ -91,6 +99,7 @@ static NESTED_DOCUMENTS: Lazy = Lazy::new(|| { }, ], "cattos": ["moumoute", "gomez"], + "_vectors": { "manual": [10, 23, 32] }, }, ]) }); @@ -802,6 +811,13 @@ async fn experimental_feature_score_details() { { "title": "How to Train Your Dragon: The Hidden World", "id": "166428", + "_vectors": { + "manual": [ + -100, + 231, + 32 + ] + }, "_rankingScoreDetails": { "words": { "order": 0, @@ -823,7 +839,7 @@ async fn experimental_feature_score_details() { "order": 3, "attributeRankingOrderScore": 1.0, "queryWordDistanceScore": 0.8095238095238095, - "score": 0.9365079365079364 + "score": 0.9727891156462584 }, "exactness": { "order": 4, @@ -870,34 +886,89 @@ async fn experimental_feature_vector_store() { meili_snap::snapshot!(code, @"200 OK"); meili_snap::snapshot!(response["vectorStore"], @"true"); + let (response, code) = index + .update_settings(json!({"embedders": { + "manual": { + "source": { + "userProvided": {"dimensions": 3} + } + } + }})) + .await; + + meili_snap::snapshot!(code, @"202 Accepted"); + let response = index.wait_task(response.uid()).await; + + meili_snap::snapshot!(meili_snap::json_string!(response["status"]), @"\"succeeded\""); + let (response, code) = index .search_post(json!({ "vector": [1.0, 2.0, 3.0], })) .await; + meili_snap::snapshot!(code, @"200 OK"); // vector search returns all documents that don't have vectors in the last bucket, like all sorts meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###" [ { "title": "Shazam!", - "id": "287947" + "id": "287947", + "_vectors": { + "manual": [ + 1, + 2, + 3 + ] + }, + "_semanticScore": 1.0 }, { "title": "Captain Marvel", - "id": "299537" - }, - { - "title": "Escape Room", - "id": "522681" - }, - { - "title": "How to Train Your Dragon: The Hidden World", - "id": "166428" + "id": "299537", + "_vectors": { + "manual": [ + 1, + 2, + 54 + ] + }, + "_semanticScore": 0.9129112 }, { "title": "Gläss", - "id": "450465" + "id": "450465", + "_vectors": { + "manual": [ + -100, + 340, + 90 + ] + }, + "_semanticScore": 0.8106413 + }, + { + "title": "How to Train Your Dragon: The Hidden World", + "id": "166428", + "_vectors": { + "manual": [ + -100, + 231, + 32 + ] + }, + "_semanticScore": 0.74120104 + }, + { + "title": "Escape Room", + "id": "522681", + "_vectors": { + "manual": [ + 10, + -23, + 32 + ] + } } ] "###); @@ -1150,7 +1221,14 @@ async fn simple_search_with_strange_synonyms() { [ { "title": "How to Train Your Dragon: The Hidden World", - "id": "166428" + "id": "166428", + "_vectors": { + "manual": [ + -100, + 231, + 32 + ] + } } ] "###); @@ -1164,7 +1242,14 @@ async fn simple_search_with_strange_synonyms() { [ { "title": "How to Train Your Dragon: The Hidden World", - "id": "166428" + "id": "166428", + "_vectors": { + "manual": [ + -100, + 231, + 32 + ] + } } ] "###); @@ -1178,7 +1263,14 @@ async fn simple_search_with_strange_synonyms() { [ { "title": "How to Train Your Dragon: The Hidden World", - "id": "166428" + "id": "166428", + "_vectors": { + "manual": [ + -100, + 231, + 32 + ] + } } ] "###); diff --git a/meilisearch/tests/search/multi.rs b/meilisearch/tests/search/multi.rs index 0e2e5158d..aeec1bad4 100644 --- a/meilisearch/tests/search/multi.rs +++ b/meilisearch/tests/search/multi.rs @@ -72,7 +72,14 @@ async fn simple_search_single_index() { "hits": [ { "title": "Gläss", - "id": "450465" + "id": "450465", + "_vectors": { + "manual": [ + -100, + 340, + 90 + ] + } } ], "query": "glass", @@ -86,7 +93,14 @@ async fn simple_search_single_index() { "hits": [ { "title": "Captain Marvel", - "id": "299537" + "id": "299537", + "_vectors": { + "manual": [ + 1, + 2, + 54 + ] + } } ], "query": "captain", @@ -177,7 +191,14 @@ async fn simple_search_two_indexes() { "hits": [ { "title": "Gläss", - "id": "450465" + "id": "450465", + "_vectors": { + "manual": [ + -100, + 340, + 90 + ] + } } ], "query": "glass", @@ -203,7 +224,14 @@ async fn simple_search_two_indexes() { "age": 4 } ], - "cattos": "pésti" + "cattos": "pésti", + "_vectors": { + "manual": [ + 1, + 2, + 3 + ] + } }, { "id": 654, @@ -218,7 +246,14 @@ async fn simple_search_two_indexes() { "cattos": [ "simba", "pestiféré" - ] + ], + "_vectors": { + "manual": [ + 1, + 2, + 54 + ] + } } ], "query": "pésti", diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs index 6a37ceb7d..b29a72827 100644 --- a/milli/src/search/new/vector_sort.rs +++ b/milli/src/search/new/vector_sort.rs @@ -27,8 +27,11 @@ impl VectorSort { distribution_shift: Option, embedder_name: &str, ) -> Result { - /// FIXME: unwrap - let embedder_index = ctx.index.embedder_category_id.get(ctx.txn, embedder_name)?.unwrap(); + let embedder_index = ctx + .index + .embedder_category_id + .get(ctx.txn, embedder_name)? + .ok_or_else(|| crate::UserError::InvalidEmbedder(embedder_name.to_owned()))?; Ok(Self { query: None, diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index efc6b22ff..6906bbcd3 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -750,6 +750,8 @@ fn execute_word_prefix_docids( #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use big_s::S; use fst::IntoStreamer; use heed::RwTxn; @@ -759,6 +761,7 @@ mod tests { use crate::documents::documents_batch_reader_from_objects; use crate::index::tests::TempIndex; use crate::search::TermsMatchingStrategy; + use crate::update::Setting; use crate::{db_snap, Filter, Search}; #[test] @@ -2550,13 +2553,34 @@ mod tests { /// Vectors must be of the same length. #[test] fn test_multiple_vectors() { + use crate::vector::settings::{EmbedderSettings, EmbeddingSettings}; let index = TempIndex::new(); - index.add_documents(documents!([{"id": 0, "_vectors": [[0, 1, 2], [3, 4, 5]] }])).unwrap(); - index.add_documents(documents!([{"id": 1, "_vectors": [6, 7, 8] }])).unwrap(); + index + .update_settings(|settings| { + let mut embedders = BTreeMap::default(); + embedders.insert( + "manual".to_string(), + Setting::Set(EmbeddingSettings { + embedder_options: Setting::Set(EmbedderSettings::UserProvided( + crate::vector::settings::UserProvidedSettings { dimensions: 3 }, + )), + document_template: Setting::NotSet, + }), + ); + settings.set_embedder_settings(embedders); + }) + .unwrap(); + index .add_documents( - documents!([{"id": 2, "_vectors": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }]), + documents!([{"id": 0, "_vectors": { "manual": [[0, 1, 2], [3, 4, 5]] } }]), + ) + .unwrap(); + index.add_documents(documents!([{"id": 1, "_vectors": { "manual": [6, 7, 8] }}])).unwrap(); + index + .add_documents( + documents!([{"id": 2, "_vectors": { "manual": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }}]), ) .unwrap(); From b8e4709dfa6377ab4c84540e2e08069750be82e6 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 13 Dec 2023 22:06:39 +0100 Subject: [PATCH 17/28] Remove prompt strategy and fallback --- milli/src/prompt/mod.rs | 74 ++++++++---------------------------- milli/src/update/settings.rs | 7 +--- milli/src/vector/settings.rs | 26 ++----------- 3 files changed, 21 insertions(+), 86 deletions(-) diff --git a/milli/src/prompt/mod.rs b/milli/src/prompt/mod.rs index 67ef8b4f6..97ccbfb61 100644 --- a/milli/src/prompt/mod.rs +++ b/milli/src/prompt/mod.rs @@ -16,20 +16,16 @@ use crate::FieldsIdsMap; pub struct Prompt { template: liquid::Template, template_text: String, - strategy: PromptFallbackStrategy, - fallback: String, } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct PromptData { pub template: String, - pub strategy: PromptFallbackStrategy, - pub fallback: String, } impl From for PromptData { fn from(value: Prompt) -> Self { - Self { template: value.template_text, strategy: value.strategy, fallback: value.fallback } + Self { template: value.template_text } } } @@ -37,19 +33,14 @@ impl TryFrom for Prompt { type Error = NewPromptError; fn try_from(value: PromptData) -> Result { - Prompt::new(value.template, Some(value.strategy), Some(value.fallback)) + Prompt::new(value.template) } } impl Clone for Prompt { fn clone(&self) -> Self { let template_text = self.template_text.clone(); - Self { - template: new_template(&template_text).unwrap(), - template_text, - strategy: self.strategy, - fallback: self.fallback.clone(), - } + Self { template: new_template(&template_text).unwrap(), template_text } } } @@ -67,37 +58,20 @@ fn default_template_text() -> &'static str { {% endfor %}" } -fn default_fallback() -> &'static str { - "" -} - impl Default for Prompt { fn default() -> Self { - Self { - template: default_template(), - template_text: default_template_text().into(), - strategy: Default::default(), - fallback: default_fallback().into(), - } + Self { template: default_template(), template_text: default_template_text().into() } } } impl Default for PromptData { fn default() -> Self { - Self { - template: default_template_text().into(), - strategy: Default::default(), - fallback: default_fallback().into(), - } + Self { template: default_template_text().into() } } } impl Prompt { - pub fn new( - template: String, - strategy: Option, - fallback: Option, - ) -> Result { + pub fn new(template: String) -> Result { let this = Self { template: liquid::ParserBuilder::with_stdlib() .build() @@ -105,8 +79,6 @@ impl Prompt { .parse(&template) .map_err(NewPromptError::cannot_parse_template)?, template_text: template, - strategy: strategy.unwrap_or_default(), - fallback: fallback.unwrap_or_default(), }; // render template with special object that's OK with `doc.*` and `fields.*` @@ -130,18 +102,6 @@ impl Prompt { } } -#[derive( - Debug, Default, Clone, PartialEq, Eq, Copy, serde::Serialize, serde::Deserialize, deserr::Deserr, -)] -#[serde(deny_unknown_fields, rename_all = "camelCase")] -#[deserr(rename_all = camelCase, deny_unknown_fields)] -pub enum PromptFallbackStrategy { - Fallback, - Skip, - #[default] - Error, -} - #[cfg(test)] mod test { use super::Prompt; @@ -156,18 +116,18 @@ mod test { #[test] fn empty_template() { - Prompt::new("".into(), None, None).unwrap(); + Prompt::new("".into()).unwrap(); } #[test] fn template_ok() { - Prompt::new("{{doc.title}}: {{doc.overview}}".into(), None, None).unwrap(); + Prompt::new("{{doc.title}}: {{doc.overview}}".into()).unwrap(); } #[test] fn template_syntax() { assert!(matches!( - Prompt::new("{{doc.title: {{doc.overview}}".into(), None, None), + Prompt::new("{{doc.title: {{doc.overview}}".into()), Err(NewPromptError { kind: NewPromptErrorKind::CannotParseTemplate(_), fault: FaultSource::User @@ -178,7 +138,7 @@ mod test { #[test] fn template_missing_doc() { assert!(matches!( - Prompt::new("{{title}}: {{overview}}".into(), None, None), + Prompt::new("{{title}}: {{overview}}".into()), Err(NewPromptError { kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), fault: FaultSource::User @@ -188,29 +148,25 @@ mod test { #[test] fn template_nested_doc() { - Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into(), None, None).unwrap(); + Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into()).unwrap(); } #[test] fn template_fields() { - Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into(), None, None).unwrap(); + Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into()).unwrap(); } #[test] fn template_fields_ok() { - Prompt::new( - "{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into(), - None, - None, - ) - .unwrap(); + Prompt::new("{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into()) + .unwrap(); } #[test] fn template_fields_invalid() { assert!(matches!( // intentionally garbled field - Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into(), None, None), + Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into()), Err(NewPromptError { kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), fault: FaultSource::User diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index e9f345e42..d406c121c 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1073,11 +1073,10 @@ fn validate_prompt( match new { Setting::Set(EmbeddingSettings { embedder_options, - document_template: - Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }), + document_template: Setting::Set(PromptSettings { template: Setting::Set(template) }), }) => { // validate - let template = crate::prompt::Prompt::new(template, None, None) + let template = crate::prompt::Prompt::new(template) .map(|prompt| crate::prompt::PromptData::from(prompt).template) .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; @@ -1085,8 +1084,6 @@ fn validate_prompt( embedder_options, document_template: Setting::Set(PromptSettings { template: Setting::Set(template), - strategy, - fallback, }), })) } diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index f90c3cc71..bd385e3f3 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -1,7 +1,7 @@ use deserr::Deserr; use serde::{Deserialize, Serialize}; -use crate::prompt::{PromptData, PromptFallbackStrategy}; +use crate::prompt::PromptData; use crate::update::Setting; use crate::vector::hf::WeightSource; use crate::vector::EmbeddingConfig; @@ -56,46 +56,28 @@ pub struct PromptSettings { #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] pub template: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub strategy: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub fallback: Setting, } impl PromptSettings { pub fn apply(&mut self, new: Self) { - let PromptSettings { template, strategy, fallback } = new; + let PromptSettings { template } = new; self.template.apply(template); - self.strategy.apply(strategy); - self.fallback.apply(fallback); } } impl From for PromptSettings { fn from(value: PromptData) -> Self { - Self { - template: Setting::Set(value.template), - strategy: Setting::Set(value.strategy), - fallback: Setting::Set(value.fallback), - } + Self { template: Setting::Set(value.template) } } } impl From for PromptData { fn from(value: PromptSettings) -> Self { let mut this = PromptData::default(); - let PromptSettings { template, strategy, fallback } = value; + let PromptSettings { template } = value; if let Some(template) = template.set() { this.template = template; } - if let Some(strategy) = strategy.set() { - this.strategy = strategy; - } - if let Some(fallback) = fallback.set() { - this.fallback = fallback; - } this } } From 3c1a14f1cdd46fb9a3327d216d53d7dfc3394c60 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 13 Dec 2023 23:09:14 +0100 Subject: [PATCH 18/28] Add settings routes --- meilisearch/src/routes/indexes/settings.rs | 66 +++++++++++++++++++++- 1 file changed, 65 insertions(+), 1 deletion(-) diff --git a/meilisearch/src/routes/indexes/settings.rs b/meilisearch/src/routes/indexes/settings.rs index c22db24f0..024b7e7c0 100644 --- a/meilisearch/src/routes/indexes/settings.rs +++ b/meilisearch/src/routes/indexes/settings.rs @@ -7,6 +7,7 @@ use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::ResponseError; use meilisearch_types::facet_values_sort::FacetValuesSort; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::milli::update::Setting; use meilisearch_types::settings::{settings, RankingRuleView, Settings, Unchecked}; use meilisearch_types::tasks::KindWithContent; use serde_json::json; @@ -545,6 +546,67 @@ make_setting_route!( } ); +make_setting_route!( + "/embedders", + patch, + std::collections::BTreeMap>, + meilisearch_types::deserr::DeserrJsonError< + meilisearch_types::error::deserr_codes::InvalidSettingsEmbedders, + >, + embedders, + "embedders", + analytics, + |setting: &Option>>, req: &HttpRequest| { + + + analytics.publish( + "Embedders Updated".to_string(), + serde_json::json!({"embedders": crate::routes::indexes::settings::embedder_analytics(setting.as_ref())}), + Some(req), + ); + } +); + +fn embedder_analytics( + setting: Option< + &std::collections::BTreeMap< + String, + Setting, + >, + >, +) -> serde_json::Value { + let mut sources = std::collections::HashSet::new(); + + if let Some(s) = &setting { + for source in s + .values() + .filter_map(|config| config.clone().set()) + .filter_map(|config| config.embedder_options.set()) + { + use meilisearch_types::milli::vector::settings::EmbedderSettings; + match source { + EmbedderSettings::OpenAi(_) => sources.insert("openAi"), + EmbedderSettings::HuggingFace(_) => sources.insert("huggingFace"), + EmbedderSettings::UserProvided(_) => sources.insert("userProvided"), + }; + } + }; + + let document_template_used = setting.as_ref().map(|map| { + map.values() + .filter_map(|config| config.clone().set()) + .any(|config| config.document_template.set().is_some()) + }); + + json!( + { + "total": setting.as_ref().map(|s| s.len()), + "sources": sources, + "document_template_used": document_template_used, + } + ) +} + macro_rules! generate_configure { ($($mod:ident),*) => { pub fn configure(cfg: &mut web::ServiceConfig) { @@ -574,7 +636,8 @@ generate_configure!( ranking_rules, typo_tolerance, pagination, - faceting + faceting, + embedders ); pub async fn update_all( @@ -681,6 +744,7 @@ pub async fn update_all( "synonyms": { "total": new_settings.synonyms.as_ref().set().map(|synonyms| synonyms.len()), }, + "embedders": crate::routes::indexes::settings::embedder_analytics(new_settings.embedders.as_ref().set()) }), Some(&req), ); From 5b51cb04afd4a005f269527b7f88f47055835784 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 13 Dec 2023 23:09:50 +0100 Subject: [PATCH 19/28] Remove some settings --- milli/src/vector/hf.rs | 42 +++++++++++------------------------- milli/src/vector/settings.rs | 26 ++-------------------- 2 files changed, 15 insertions(+), 53 deletions(-) diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index 07185d25c..3162dadec 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -23,7 +23,7 @@ use super::{Embedding, Embeddings}; )] #[serde(deny_unknown_fields, rename_all = "camelCase")] #[deserr(rename_all = camelCase, deny_unknown_fields)] -pub enum WeightSource { +enum WeightSource { #[default] Safetensors, Pytorch, @@ -33,20 +33,13 @@ pub enum WeightSource { pub struct EmbedderOptions { pub model: String, pub revision: Option, - pub weight_source: WeightSource, - pub normalize_embeddings: bool, } impl EmbedderOptions { pub fn new() -> Self { Self { - //model: "sentence-transformers/all-MiniLM-L6-v2".to_string(), model: "BAAI/bge-base-en-v1.5".to_string(), - //revision: Some("refs/pr/21".to_string()), - revision: None, - //weight_source: Default::default(), - weight_source: WeightSource::Pytorch, - normalize_embeddings: true, + revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()), } } } @@ -82,20 +75,21 @@ impl Embedder { Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), None => Repo::model(options.model.clone()), }; - let (config_filename, tokenizer_filename, weights_filename) = { + let (config_filename, tokenizer_filename, weights_filename, weight_source) = { let api = Api::new().map_err(NewEmbedderError::new_api_fail)?; let api = api.repo(repo); let config = api.get("config.json").map_err(NewEmbedderError::api_get)?; let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?; - let weights = match options.weight_source { - WeightSource::Pytorch => { - api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)? - } - WeightSource::Safetensors => { - api.get("model.safetensors").map_err(NewEmbedderError::api_get)? - } + let (weights, source) = { + api.get("pytorch_model.bin") + .map(|filename| (filename, WeightSource::Pytorch)) + .or_else(|_| { + api.get("model.safetensors") + .map(|filename| (filename, WeightSource::Safetensors)) + }) + .map_err(NewEmbedderError::api_get)? }; - (config, tokenizer, weights) + (config, tokenizer, weights, source) }; let config = std::fs::read_to_string(&config_filename) @@ -106,7 +100,7 @@ impl Embedder { let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; - let vb = match options.weight_source { + let vb = match weight_source { WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device) .map_err(NewEmbedderError::pytorch_weight)?, WeightSource::Safetensors => unsafe { @@ -168,12 +162,6 @@ impl Embedder { let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) .map_err(EmbedError::tensor_shape)?; - let embeddings: Tensor = if self.options.normalize_embeddings { - normalize_l2(&embeddings).map_err(EmbedError::tensor_value)? - } else { - embeddings - }; - let embeddings: Vec = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) } @@ -197,7 +185,3 @@ impl Embedder { self.dimensions } } - -fn normalize_l2(v: &Tensor) -> Result { - v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?) -} diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index bd385e3f3..e37b0fde7 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -3,7 +3,6 @@ use serde::{Deserialize, Serialize}; use crate::prompt::PromptData; use crate::update::Setting; -use crate::vector::hf::WeightSource; use crate::vector::EmbeddingConfig; #[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] @@ -204,26 +203,13 @@ pub struct HfEmbedderSettings { #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] pub revision: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub weight_source: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] - pub normalize_embeddings: Setting, } impl HfEmbedderSettings { pub fn apply(&mut self, new: Self) { - let HfEmbedderSettings { - model, - revision, - weight_source, - normalize_embeddings: normalize_embedding, - } = new; + let HfEmbedderSettings { model, revision } = new; self.model.apply(model); self.revision.apply(revision); - self.weight_source.apply(weight_source); - self.normalize_embeddings.apply(normalize_embedding); } } @@ -232,15 +218,13 @@ impl From for HfEmbedderSettings { Self { model: Setting::Set(value.model), revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet), - weight_source: Setting::Set(value.weight_source), - normalize_embeddings: Setting::Set(value.normalize_embeddings), } } } impl From for crate::vector::hf::EmbedderOptions { fn from(value: HfEmbedderSettings) -> Self { - let HfEmbedderSettings { model, revision, weight_source, normalize_embeddings } = value; + let HfEmbedderSettings { model, revision } = value; let mut this = Self::default(); if let Some(model) = model.set() { this.model = model; @@ -248,12 +232,6 @@ impl From for crate::vector::hf::EmbedderOptions { if let Some(revision) = revision.set() { this.revision = Some(revision); } - if let Some(weight_source) = weight_source.set() { - this.weight_source = weight_source; - } - if let Some(normalize_embeddings) = normalize_embeddings.set() { - this.normalize_embeddings = normalize_embeddings; - } this } } From a4536b1381d80b9fc9f4283da36cbce8434057a7 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Wed, 13 Dec 2023 23:25:38 +0100 Subject: [PATCH 20/28] Small adjustments to respect the spec --- milli/src/vector/openai.rs | 2 ++ milli/src/vector/settings.rs | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index bab62f5e4..7ae626494 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -35,6 +35,8 @@ pub struct EmbedderOptions { #[deserr(rename_all = camelCase, deny_unknown_fields)] pub enum EmbeddingModel { #[default] + #[serde(rename = "text-embedding-ada-002")] + #[deserr(rename = "text-embedding-ada-002")] TextEmbeddingAda002, } diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index e37b0fde7..a91692613 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -243,8 +243,8 @@ pub struct OpenAiEmbedderSettings { #[serde(default, skip_serializing_if = "Setting::is_not_set")] #[deserr(default)] pub api_key: Setting, - #[serde(default, skip_serializing_if = "Setting::is_not_set")] - #[deserr(default)] + #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "model")] + #[deserr(default, rename = "model")] pub embedding_model: Setting, } From 9991152bbe2607e584860afc845a49fa901496ca Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Wed, 13 Dec 2023 11:46:09 +0100 Subject: [PATCH 21/28] Add TODOs --- meilisearch/src/search.rs | 4 ++++ milli/src/search/hybrid.rs | 1 + milli/src/search/mod.rs | 1 + 3 files changed, 6 insertions(+) diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 53f6140fb..7bf8ea160 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -90,6 +90,7 @@ pub struct SearchQuery { #[derive(Debug, Clone, Default, PartialEq, Deserr)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] pub struct HybridQuery { + /// TODO validate that sementic ratio is between 0.0 and 1,0 #[deserr(default, error = DeserrJsonError, default = DEFAULT_SEMANTIC_RATIO())] pub semantic_ratio: f32, #[deserr(default, error = DeserrJsonError, default)] @@ -452,6 +453,9 @@ pub fn perform_search( let (search, is_finite_pagination, max_total_hits, offset) = prepare_search(index, &rtxn, &query, features)?; + /// TODO: Change if-cond to query.hybrid.is_some + /// + < 1.0 or remove q + /// + > 0.0 or remove vector let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = if query.q.is_some() && query.vector.is_some() { search.execute_hybrid()? diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs index cbec20c65..129857fb5 100644 --- a/milli/src/search/hybrid.rs +++ b/milli/src/search/hybrid.rs @@ -252,6 +252,7 @@ impl<'a> Search<'a> { // can unwrap because we returned already if there was no vector query self.vector_results_for_keyword(search.vector.as_ref().unwrap(), &keyword_results)?; + /// TODO apply sementic ratio let keyword_results = CombinedSearchResult::new(keyword_results, vector_results_for_keyword); let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector); diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 04a6005e3..44fb3556f 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -50,6 +50,7 @@ pub struct Search<'a> { scoring_strategy: ScoringStrategy, words_limit: usize, exhaustive_number_hits: bool, + /// TODO: Add semantic ratio or pass it directly to execute_hybrid() rtxn: &'a heed::RoTxn<'a>, index: &'a Index, distribution_shift: Option, From ac68f3319449921fc7c28ed7d65f5cd6bb0ad12b Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Wed, 13 Dec 2023 11:47:12 +0100 Subject: [PATCH 22/28] Add simple test --- meilisearch/tests/search/hybrid.rs | 152 +++++++++++++++++++++++++++++ meilisearch/tests/search/mod.rs | 1 + 2 files changed, 153 insertions(+) create mode 100644 meilisearch/tests/search/hybrid.rs diff --git a/meilisearch/tests/search/hybrid.rs b/meilisearch/tests/search/hybrid.rs new file mode 100644 index 000000000..e5f34bcd6 --- /dev/null +++ b/meilisearch/tests/search/hybrid.rs @@ -0,0 +1,152 @@ +use meili_snap::{json_string, snapshot}; +use once_cell::sync::Lazy; + +use crate::common::index::Index; +use crate::common::{Server, Value}; +use crate::json; + +async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Index<'a> { + let index = server.index("test"); + + let (response, code) = server.set_features(json!({"vectorStore": true})).await; + + meili_snap::snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response), @r###" + { + "scoreDetails": false, + "vectorStore": true, + "metrics": false, + "exportPuffinReports": false, + "proximityPrecision": false + } + "###); + + let (response, code) = index + .update_settings( + json!({ "embedders": {"default": {"source": {"userProvided": {"dimensions": 2}}}} }), + ) + .await; + assert_eq!(202, code, "{:?}", response); + index.wait_task(0).await; + + let (response, code) = index.add_documents(documents.clone(), None).await; + assert_eq!(202, code, "{:?}", response); + index.wait_task(1).await; + index +} + +static SIMPLE_SEARCH_DOCUMENTS: Lazy = Lazy::new(|| { + json!([ + { + "title": "Shazam!", + "desc": "a Captain Marvel ersatz", + "id": "1", + "_vectors": {"default": [1.0, 3.0]}, + }, + { + "title": "Captain Planet", + "desc": "He's not part of the Marvel Cinematic Universe", + "id": "2", + "_vectors": {"default": [1.0, 2.0]}, + }, + { + "title": "Captain Marvel", + "desc": "a Shazam ersatz", + "id": "3", + "_vectors": {"default": [2.0, 3.0]}, + }]) +}); + +#[actix_rt::test] +async fn simple_search() { + let server = Server::new().await; + let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await; + + let (response, code) = index + .search_post( + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.2}}), + ) + .await; + snapshot!(code, @"200 OK"); + snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); + + let (response, code) = index + .search_post( + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}}), + ) + .await; + snapshot!(code, @"200 OK"); + snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); +} + +#[actix_rt::test] +async fn invalid_semantic_ratio() { + let server = Server::new().await; + let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await; + + let (response, code) = index + .search_post( + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 1.2}}), + ) + .await; + snapshot!(code, @"400 Bad Request"); + snapshot!(response, @r###" + { + "message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a value between `0.0` and `1.0`.", + "code": "invalid_search_semantic_ratio", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" + } + "###); + + let (response, code) = index + .search_post( + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": -0.8}}), + ) + .await; + snapshot!(code, @"400 Bad Request"); + snapshot!(response, @r###" + { + "message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a value between `0.0` and `1.0`.", + "code": "invalid_search_semantic_ratio", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" + } + "###); + + let (response, code) = index + .search_get( + &yaup::to_string( + &json!({"q": "Captain", "vector": [1.0, 1.0], "hybridSemanticRatio": 1.2}), + ) + .unwrap(), + ) + .await; + snapshot!(code, @"400 Bad Request"); + snapshot!(response, @r###" + { + "message": "Invalid value type for parameter `hybridSemanticRatio`: expected a string, but found a string: `1.2`", + "code": "invalid_search_semantic_ratio", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" + } + "###); + + let (response, code) = index + .search_get( + &yaup::to_string( + &json!({"q": "Captain", "vector": [1.0, 1.0], "hybridSemanticRatio": -0.2}), + ) + .unwrap(), + ) + .await; + snapshot!(code, @"400 Bad Request"); + snapshot!(response, @r###" + { + "message": "Invalid value type for parameter `hybridSemanticRatio`: expected a string, but found a string: `-0.2`", + "code": "invalid_search_semantic_ratio", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" + } + "###); +} diff --git a/meilisearch/tests/search/mod.rs b/meilisearch/tests/search/mod.rs index ad9c2aaa2..133a143fd 100644 --- a/meilisearch/tests/search/mod.rs +++ b/meilisearch/tests/search/mod.rs @@ -6,6 +6,7 @@ mod errors; mod facet_search; mod formatted; mod geo; +mod hybrid; mod multi; mod pagination; mod restrict_searchable; From 93dcbf598d67a37879b0e2ae7ecb7fc63145c30f Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Thu, 14 Dec 2023 10:21:10 +0100 Subject: [PATCH 23/28] Deserialize semantic ratio --- meilisearch-types/src/deserr/mod.rs | 1 + meilisearch-types/src/error.rs | 11 +++++++++- meilisearch/src/routes/indexes/search.rs | 8 +++---- meilisearch/src/search.rs | 28 ++++++++++++++++++------ 4 files changed, 36 insertions(+), 12 deletions(-) diff --git a/meilisearch-types/src/deserr/mod.rs b/meilisearch-types/src/deserr/mod.rs index df304cc2f..537b24574 100644 --- a/meilisearch-types/src/deserr/mod.rs +++ b/meilisearch-types/src/deserr/mod.rs @@ -188,3 +188,4 @@ merge_with_error_impl_take_error_message!(ParseOffsetDateTimeError); merge_with_error_impl_take_error_message!(ParseTaskKindError); merge_with_error_impl_take_error_message!(ParseTaskStatusError); merge_with_error_impl_take_error_message!(IndexUidFormatError); +merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio); diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 9df41b68f..1dc33b140 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -235,7 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; -InvalidSemanticRatio , InvalidRequest , BAD_REQUEST ; +InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; @@ -459,6 +459,15 @@ impl fmt::Display for DeserrParseIntError { } } +impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "the value of `semanticRatio` is invalid, expected a value between `0.0` and `1.0`." + ) + } +} + #[macro_export] macro_rules! internal_error { ($target:ty : $($other:path), *) => { diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 7a9a14687..ad7f0dc89 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -17,7 +17,7 @@ use crate::extractors::authentication::policies::*; use crate::extractors::authentication::GuardedData; use crate::extractors::sequential_extractor::SeqHandler; use crate::search::{ - add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, + add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, }; @@ -75,10 +75,10 @@ pub struct SearchQueryGet { matching_strategy: MatchingStrategy, #[deserr(default, error = DeserrQueryParamError)] pub attributes_to_search_on: Option>, - #[deserr(default, error = DeserrQueryParamError)] + #[deserr(default, error = DeserrQueryParamError)] pub hybrid_embedder: Option, - #[deserr(default, error = DeserrQueryParamError)] - pub hybrid_semantic_ratio: Option, + #[deserr(default, error = DeserrQueryParamError)] + pub hybrid_semantic_ratio: Option, } impl From for SearchQuery { diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 7bf8ea160..674b6e25e 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -36,7 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10; pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string(); pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "".to_string(); pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "".to_string(); -pub const DEFAULT_SEMANTIC_RATIO: fn() -> f32 = || 0.5; +pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5); #[derive(Debug, Clone, Default, PartialEq, Deserr)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] @@ -91,12 +91,27 @@ pub struct SearchQuery { #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] pub struct HybridQuery { /// TODO validate that sementic ratio is between 0.0 and 1,0 - #[deserr(default, error = DeserrJsonError, default = DEFAULT_SEMANTIC_RATIO())] - pub semantic_ratio: f32, + #[deserr(default, error = DeserrJsonError)] + pub semantic_ratio: SemanticRatio, #[deserr(default, error = DeserrJsonError, default)] pub embedder: Option, } +#[derive(Debug, Clone, Copy, Default, PartialEq, Deserr)] +#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)] +pub struct SemanticRatio(f32); +impl std::convert::TryFrom for SemanticRatio { + type Error = InvalidSearchSemanticRatio; + + fn try_from(f: f32) -> Result { + if f > 1.0 || f < 0.0 { + Err(InvalidSearchSemanticRatio) + } else { + Ok(SemanticRatio(f)) + } + } +} + impl SearchQuery { pub fn is_finite_pagination(&self) -> bool { self.page.or(self.hits_per_page).is_some() @@ -457,10 +472,9 @@ pub fn perform_search( /// + < 1.0 or remove q /// + > 0.0 or remove vector let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = - if query.q.is_some() && query.vector.is_some() { - search.execute_hybrid()? - } else { - search.execute()? + match query.hybrid { + Some(_) => search.execute_hybrid()?, + None => search.execute()?, }; let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); From f3f39444697e1f564c8ec2b24c924d943f574c2f Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Thu, 14 Dec 2023 11:21:25 +0100 Subject: [PATCH 24/28] Fix error checking --- meilisearch-types/src/error.rs | 2 +- meilisearch/src/routes/indexes/search.rs | 29 +++++++++++++++++++++--- meilisearch/src/search.rs | 3 +++ meilisearch/tests/search/hybrid.rs | 10 ++++---- 4 files changed, 35 insertions(+), 9 deletions(-) diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 1dc33b140..6f2624053 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -463,7 +463,7 @@ impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!( f, - "the value of `semanticRatio` is invalid, expected a value between `0.0` and `1.0`." + "the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`." ) } } diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index ad7f0dc89..b8db20da4 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -78,7 +78,28 @@ pub struct SearchQueryGet { #[deserr(default, error = DeserrQueryParamError)] pub hybrid_embedder: Option, #[deserr(default, error = DeserrQueryParamError)] - pub hybrid_semantic_ratio: Option, + pub hybrid_semantic_ratio: Option, +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, deserr::Deserr)] +#[deserr(try_from(String) = TryFrom::try_from -> InvalidSearchSemanticRatio)] +pub struct SemanticRatioGet(SemanticRatio); + +impl std::convert::TryFrom for SemanticRatioGet { + type Error = InvalidSearchSemanticRatio; + + fn try_from(s: String) -> Result { + let f: f32 = s.parse().map_err(|_| InvalidSearchSemanticRatio)?; + Ok(SemanticRatioGet(SemanticRatio::try_from(f)?)) + } +} + +impl std::ops::Deref for SemanticRatioGet { + type Target = SemanticRatio; + + fn deref(&self) -> &Self::Target { + &self.0 + } } impl From for SearchQuery { @@ -93,13 +114,15 @@ impl From for SearchQuery { let hybrid = match (other.hybrid_embedder, other.hybrid_semantic_ratio) { (None, None) => None, - (None, Some(semantic_ratio)) => Some(HybridQuery { semantic_ratio, embedder: None }), + (None, Some(semantic_ratio)) => { + Some(HybridQuery { semantic_ratio: *semantic_ratio, embedder: None }) + } (Some(embedder), None) => Some(HybridQuery { semantic_ratio: DEFAULT_SEMANTIC_RATIO(), embedder: Some(embedder), }), (Some(embedder), Some(semantic_ratio)) => { - Some(HybridQuery { semantic_ratio, embedder: Some(embedder) }) + Some(HybridQuery { semantic_ratio: *semantic_ratio, embedder: Some(embedder) }) } }; diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 674b6e25e..27a6efab4 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -100,6 +100,7 @@ pub struct HybridQuery { #[derive(Debug, Clone, Copy, Default, PartialEq, Deserr)] #[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)] pub struct SemanticRatio(f32); + impl std::convert::TryFrom for SemanticRatio { type Error = InvalidSearchSemanticRatio; @@ -383,7 +384,9 @@ fn prepare_search<'t>( } if let Some(ref query) = query.q { + // if !matches!(query.hybrid, query.) { search.query(query); + // } } if let Some(ref searchable) = query.attributes_to_search_on { diff --git a/meilisearch/tests/search/hybrid.rs b/meilisearch/tests/search/hybrid.rs index e5f34bcd6..578667244 100644 --- a/meilisearch/tests/search/hybrid.rs +++ b/meilisearch/tests/search/hybrid.rs @@ -68,7 +68,7 @@ async fn simple_search() { ) .await; snapshot!(code, @"200 OK"); - snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); + snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); let (response, code) = index .search_post( @@ -92,7 +92,7 @@ async fn invalid_semantic_ratio() { snapshot!(code, @"400 Bad Request"); snapshot!(response, @r###" { - "message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a value between `0.0` and `1.0`.", + "message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", "code": "invalid_search_semantic_ratio", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" @@ -107,7 +107,7 @@ async fn invalid_semantic_ratio() { snapshot!(code, @"400 Bad Request"); snapshot!(response, @r###" { - "message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a value between `0.0` and `1.0`.", + "message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", "code": "invalid_search_semantic_ratio", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" @@ -125,7 +125,7 @@ async fn invalid_semantic_ratio() { snapshot!(code, @"400 Bad Request"); snapshot!(response, @r###" { - "message": "Invalid value type for parameter `hybridSemanticRatio`: expected a string, but found a string: `1.2`", + "message": "Invalid value in parameter `hybridSemanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", "code": "invalid_search_semantic_ratio", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" @@ -143,7 +143,7 @@ async fn invalid_semantic_ratio() { snapshot!(code, @"400 Bad Request"); snapshot!(response, @r###" { - "message": "Invalid value type for parameter `hybridSemanticRatio`: expected a string, but found a string: `-0.2`", + "message": "Invalid value in parameter `hybridSemanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", "code": "invalid_search_semantic_ratio", "type": "invalid_request", "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" From 1b7c164a559c1b8794463c58ce275761d22d5263 Mon Sep 17 00:00:00 2001 From: ManyTheFish Date: Thu, 14 Dec 2023 11:42:55 +0100 Subject: [PATCH 25/28] Pass the semantic ratio to milli --- meilisearch/src/search.rs | 50 ++++++++++++++++++++++++++------------- 1 file changed, 34 insertions(+), 16 deletions(-) diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 27a6efab4..b596e2cc8 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -113,6 +113,14 @@ impl std::convert::TryFrom for SemanticRatio { } } +impl std::ops::Deref for SemanticRatio { + type Target = f32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + impl SearchQuery { pub fn is_finite_pagination(&self) -> bool { self.page.or(self.hits_per_page).is_some() @@ -373,20 +381,30 @@ fn prepare_search<'t>( } if let Some(ref vector) = query.vector { - match vector { - VectorQuery::Vector(vector) => { - search.vector(vector.clone()); - } - VectorQuery::String(_) => { - panic!("Failed while preparing search; caller did not generate embedding for query") - } + match &query.hybrid { + // If semantic ratio is 0.0, only the query search will impact the search results, + // skip the vector + Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (), + _otherwise => match vector { + VectorQuery::Vector(vector) => { + search.vector(vector.clone()); + } + VectorQuery::String(_) => { + panic!("Failed while preparing search; caller did not generate embedding for query") + } + }, } } - if let Some(ref query) = query.q { - // if !matches!(query.hybrid, query.) { - search.query(query); - // } + if let Some(ref q) = query.q { + match &query.hybrid { + // If semantic ratio is 1.0, only the vector search will impact the search results, + // skip the query + Some(hybrid) if *hybrid.semantic_ratio == 1.0 => (), + _otherwise => { + search.query(q); + } + } } if let Some(ref searchable) = query.attributes_to_search_on { @@ -471,12 +489,12 @@ pub fn perform_search( let (search, is_finite_pagination, max_total_hits, offset) = prepare_search(index, &rtxn, &query, features)?; - /// TODO: Change if-cond to query.hybrid.is_some - /// + < 1.0 or remove q - /// + > 0.0 or remove vector let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = - match query.hybrid { - Some(_) => search.execute_hybrid()?, + match &query.hybrid { + Some(hybrid) => match *hybrid.semantic_ratio { + 0.0 | 1.0 => search.execute()?, + ratio => search.execute_hybrid(ratio)?, + }, None => search.execute()?, }; From 217105b7da69920b53b2bc42914041c516ef2601 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 14 Dec 2023 12:42:37 +0100 Subject: [PATCH 26/28] hybrid search uses semantic ratio, error handling --- meilisearch-types/src/error.rs | 3 +- .../src/analytics/segment_analytics.rs | 2 +- meilisearch/src/error.rs | 3 + .../src/routes/indexes/facet_search.rs | 3 +- meilisearch/src/routes/indexes/search.rs | 70 +++--- meilisearch/src/search.rs | 43 ++-- meilisearch/tests/search/hybrid.rs | 2 +- milli/src/lib.rs | 2 +- milli/src/search/hybrid.rs | 229 +++--------------- milli/src/search/mod.rs | 48 ---- 10 files changed, 89 insertions(+), 316 deletions(-) diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index 6f2624053..62591e991 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -235,7 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; -InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ; +InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; @@ -299,6 +299,7 @@ MissingFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; MissingIndexUid , InvalidRequest , BAD_REQUEST ; MissingMasterKey , Auth , UNAUTHORIZED ; MissingPayload , InvalidRequest , BAD_REQUEST ; +MissingSearchHybrid , InvalidRequest , BAD_REQUEST ; MissingSwapIndexes , InvalidRequest , BAD_REQUEST ; MissingTaskFilters , InvalidRequest , BAD_REQUEST ; NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENTITY; diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index 67770d87c..1ad277c28 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -692,7 +692,7 @@ impl SearchAggregator { ret.max_terms_number = q.split_whitespace().count(); } - if let Some(meilisearch_types::milli::VectorQuery::Vector(ref vector)) = vector { + if let Some(ref vector) = vector { ret.max_vector_size = vector.len(); } diff --git a/meilisearch/src/error.rs b/meilisearch/src/error.rs index ca10c4593..3bd8f3edd 100644 --- a/meilisearch/src/error.rs +++ b/meilisearch/src/error.rs @@ -51,6 +51,8 @@ pub enum MeilisearchHttpError { DocumentFormat(#[from] DocumentFormatError), #[error(transparent)] Join(#[from] JoinError), + #[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")] + MissingSearchHybrid, } impl ErrorCode for MeilisearchHttpError { @@ -74,6 +76,7 @@ impl ErrorCode for MeilisearchHttpError { MeilisearchHttpError::FileStore(_) => Code::Internal, MeilisearchHttpError::DocumentFormat(e) => e.error_code(), MeilisearchHttpError::Join(_) => Code::Internal, + MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid, } } } diff --git a/meilisearch/src/routes/indexes/facet_search.rs b/meilisearch/src/routes/indexes/facet_search.rs index 59c0e7353..4b5d4d78a 100644 --- a/meilisearch/src/routes/indexes/facet_search.rs +++ b/meilisearch/src/routes/indexes/facet_search.rs @@ -7,7 +7,6 @@ use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; -use meilisearch_types::milli::VectorQuery; use serde_json::Value; use crate::analytics::{Analytics, FacetSearchAggregator}; @@ -121,7 +120,7 @@ impl From for SearchQuery { highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(), crop_marker: DEFAULT_CROP_MARKER(), matching_strategy, - vector: vector.map(VectorQuery::Vector), + vector, attributes_to_search_on, hybrid, } diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index b8db20da4..c2b6ca3fc 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -8,7 +8,7 @@ use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; -use meilisearch_types::milli::{self, VectorQuery}; +use meilisearch_types::milli; use meilisearch_types::serde_cs::vec::CS; use serde_json::Value; @@ -128,7 +128,7 @@ impl From for SearchQuery { Self { q: other.q, - vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector), + vector: other.vector.map(CS::into_inner), offset: other.offset.0, limit: other.limit.0, page: other.page.as_deref().copied(), @@ -258,49 +258,37 @@ pub async fn embed( index_scheduler: &IndexScheduler, index: &milli::Index, ) -> Result<(), ResponseError> { - match query.vector.take() { - Some(VectorQuery::String(prompt)) => { - let embedder_configs = index.embedding_configs(&index.read_txn()?)?; - let embedders = index_scheduler.embedders(embedder_configs)?; + if let (None, Some(q), Some(HybridQuery { semantic_ratio: _, embedder })) = + (&query.vector, &query.q, &query.hybrid) + { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedders = index_scheduler.embedders(embedder_configs)?; - let embedder_name = - if let Some(HybridQuery { semantic_ratio: _, embedder: Some(embedder) }) = - &query.hybrid - { - Some(embedder) - } else { - None - }; + let embedder = if let Some(embedder_name) = embedder { + embedders.get(embedder_name) + } else { + embedders.get_default() + }; - let embedder = if let Some(embedder_name) = embedder_name { - embedders.get(embedder_name) - } else { - embedders.get_default() - }; + let embedder = embedder + .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) + .map_err(milli::Error::from)? + .0; + let embeddings = embedder + .embed(vec![q.to_owned()]) + .await + .map_err(milli::vector::Error::from) + .map_err(milli::Error::from)? + .pop() + .expect("No vector returned from embedding"); - let embedder = embedder - .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) - .map_err(milli::Error::from)? - .0; - let embeddings = embedder - .embed(vec![prompt]) - .await - .map_err(milli::vector::Error::from) - .map_err(milli::Error::from)? - .pop() - .expect("No vector returned from embedding"); - - if embeddings.iter().nth(1).is_some() { - warn!("Ignoring embeddings past the first one in long search query"); - query.vector = - Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec())); - } else { - query.vector = Some(VectorQuery::Vector(embeddings.into_inner())); - } + if embeddings.iter().nth(1).is_some() { + warn!("Ignoring embeddings past the first one in long search query"); + query.vector = Some(embeddings.iter().next().unwrap().to_vec()); + } else { + query.vector = Some(embeddings.into_inner()); } - Some(vector) => query.vector = Some(vector), - None => {} - }; + } Ok(()) } diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index b596e2cc8..267a404c0 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -7,14 +7,13 @@ use deserr::Deserr; use either::Either; use index_scheduler::RoFeatures; use indexmap::IndexMap; -use log::warn; use meilisearch_auth::IndexSearchRules; use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; -use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, VectorQuery}; +use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; use milli::tokenizer::TokenizerBuilder; @@ -44,7 +43,7 @@ pub struct SearchQuery { #[deserr(default, error = DeserrJsonError)] pub q: Option, #[deserr(default, error = DeserrJsonError)] - pub vector: Option, + pub vector: Option>, #[deserr(default, error = DeserrJsonError)] pub hybrid: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] @@ -105,6 +104,8 @@ impl std::convert::TryFrom for SemanticRatio { type Error = InvalidSearchSemanticRatio; fn try_from(f: f32) -> Result { + // the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable + #[allow(clippy::manual_range_contains)] if f > 1.0 || f < 0.0 { Err(InvalidSearchSemanticRatio) } else { @@ -139,7 +140,7 @@ pub struct SearchQueryWithIndex { #[deserr(default, error = DeserrJsonError)] pub q: Option, #[deserr(default, error = DeserrJsonError)] - pub vector: Option, + pub vector: Option>, #[deserr(default, error = DeserrJsonError)] pub hybrid: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] @@ -376,8 +377,16 @@ fn prepare_search<'t>( ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { let mut search = index.search(rtxn); - if query.vector.is_some() && query.q.is_some() { - warn!("Attempting hybrid search"); + if query.vector.is_some() { + features.check_vector("Passing `vector` as a query parameter")?; + } + + if query.hybrid.is_some() { + features.check_vector("Passing `hybrid` as a query parameter")?; + } + + if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() { + return Err(MeilisearchHttpError::MissingSearchHybrid); } if let Some(ref vector) = query.vector { @@ -385,14 +394,9 @@ fn prepare_search<'t>( // If semantic ratio is 0.0, only the query search will impact the search results, // skip the vector Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (), - _otherwise => match vector { - VectorQuery::Vector(vector) => { - search.vector(vector.clone()); - } - VectorQuery::String(_) => { - panic!("Failed while preparing search; caller did not generate embedding for query") - } - }, + _otherwise => { + search.vector(vector.clone()); + } } } @@ -431,10 +435,6 @@ fn prepare_search<'t>( features.check_score_details()?; } - if query.vector.is_some() { - features.check_vector("Passing `vector` as a query parameter")?; - } - if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid { search.embedder_name(embedder); } @@ -492,7 +492,7 @@ pub fn perform_search( let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = match &query.hybrid { Some(hybrid) => match *hybrid.semantic_ratio { - 0.0 | 1.0 => search.execute()?, + ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?, ratio => search.execute_hybrid(ratio)?, }, None => search.execute()?, @@ -700,10 +700,7 @@ pub fn perform_search( hits: documents, hits_info, query: query.q.unwrap_or_default(), - vector: match query.vector { - Some(VectorQuery::Vector(vector)) => Some(vector), - _ => None, - }, + vector: query.vector, processing_time_ms: before_search.elapsed().as_millis(), facet_distribution, facet_stats, diff --git a/meilisearch/tests/search/hybrid.rs b/meilisearch/tests/search/hybrid.rs index 578667244..7986091b0 100644 --- a/meilisearch/tests/search/hybrid.rs +++ b/meilisearch/tests/search/hybrid.rs @@ -1,4 +1,4 @@ -use meili_snap::{json_string, snapshot}; +use meili_snap::snapshot; use once_cell::sync::Lazy; use crate::common::index::Index; diff --git a/milli/src/lib.rs b/milli/src/lib.rs index ce37fe375..f6b398304 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -59,7 +59,7 @@ pub use self::index::Index; pub use self::search::{ FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy, - VectorQuery, DEFAULT_VALUES_PER_FACET, + DEFAULT_VALUES_PER_FACET, }; pub type Result = std::result::Result; diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs index 129857fb5..67365cf52 100644 --- a/milli/src/search/hybrid.rs +++ b/milli/src/search/hybrid.rs @@ -1,49 +1,37 @@ use std::cmp::Ordering; -use std::collections::HashMap; use itertools::Itertools; use roaring::RoaringBitmap; -use super::new::{execute_vector_search, PartialSearchResult}; use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; -use crate::{ - execute_search, DefaultSearchLogger, MatchingWords, Result, Search, SearchContext, SearchResult, -}; +use crate::{MatchingWords, Result, Search, SearchResult}; -struct CombinedSearchResult { +struct ScoreWithRatioResult { matching_words: MatchingWords, candidates: RoaringBitmap, - document_scores: Vec<(u32, CombinedScore)>, + document_scores: Vec<(u32, ScoreWithRatio)>, } -type CombinedScore = (Vec, Option>); +type ScoreWithRatio = (Vec, f32); -fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering { - let mut left_main_it = ScoreDetails::score_values(left.0.iter()); - let mut left_sub_it = - ScoreDetails::score_values(left.1.as_ref().map(|x| x.iter()).into_iter().flatten()); - - let mut right_main_it = ScoreDetails::score_values(right.0.iter()); - let mut right_sub_it = - ScoreDetails::score_values(right.1.as_ref().map(|x| x.iter()).into_iter().flatten()); - - let mut left_main = left_main_it.next(); - let mut left_sub = left_sub_it.next(); - let mut right_main = right_main_it.next(); - let mut right_sub = right_sub_it.next(); +fn compare_scores( + &(ref left_scores, left_ratio): &ScoreWithRatio, + &(ref right_scores, right_ratio): &ScoreWithRatio, +) -> Ordering { + let mut left_it = ScoreDetails::score_values(left_scores.iter()); + let mut right_it = ScoreDetails::score_values(right_scores.iter()); loop { - let left = - take_best_score(&mut left_main, &mut left_sub, &mut left_main_it, &mut left_sub_it); - - let right = - take_best_score(&mut right_main, &mut right_sub, &mut right_main_it, &mut right_sub_it); + let left = left_it.next(); + let right = right_it.next(); match (left, right) { (None, None) => return Ordering::Equal, (None, Some(_)) => return Ordering::Less, (Some(_), None) => return Ordering::Greater, (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { + let left = left * left_ratio as f64; + let right = right * right_ratio as f64; if (left - right).abs() <= f64::EPSILON { continue; } @@ -72,94 +60,17 @@ fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering { } } -fn take_best_score<'a>( - main_score: &mut Option>, - sub_score: &mut Option>, - main_it: &mut impl Iterator>, - sub_it: &mut impl Iterator>, -) -> Option> { - match (*main_score, *sub_score) { - (Some(main), None) => { - *main_score = main_it.next(); - Some(main) - } - (None, Some(sub)) => { - *sub_score = sub_it.next(); - Some(sub) - } - (main @ Some(ScoreValue::Score(main_f)), sub @ Some(ScoreValue::Score(sub_v))) => { - // take max, both advance - *main_score = main_it.next(); - *sub_score = sub_it.next(); - if main_f >= sub_v { - main - } else { - sub - } - } - (main @ Some(ScoreValue::Score(_)), _) => { - *main_score = main_it.next(); - main - } - (_, sub @ Some(ScoreValue::Score(_))) => { - *sub_score = sub_it.next(); - sub - } - (main @ Some(ScoreValue::GeoSort(main_geo)), sub @ Some(ScoreValue::GeoSort(sub_geo))) => { - // take best advance both - *main_score = main_it.next(); - *sub_score = sub_it.next(); - if main_geo >= sub_geo { - main - } else { - sub - } - } - (main @ Some(ScoreValue::Sort(main_sort)), sub @ Some(ScoreValue::Sort(sub_sort))) => { - // take best advance both - *main_score = main_it.next(); - *sub_score = sub_it.next(); - if main_sort >= sub_sort { - main - } else { - sub - } - } - ( - Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)), - Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)), - ) => None, - - (None, None) => None, - } -} - -impl CombinedSearchResult { - fn new(main_results: SearchResult, ancillary_results: PartialSearchResult) -> Self { - let mut docid_scores = HashMap::new(); - for (docid, score) in - main_results.documents_ids.iter().zip(main_results.document_scores.into_iter()) - { - docid_scores.insert(*docid, (score, None)); - } - - for (docid, score) in ancillary_results +impl ScoreWithRatioResult { + fn new(results: SearchResult, ratio: f32) -> Self { + let document_scores = results .documents_ids - .iter() - .zip(ancillary_results.document_scores.into_iter()) - { - docid_scores - .entry(*docid) - .and_modify(|(_main_score, ancillary_score)| *ancillary_score = Some(score)); - } - - let mut document_scores: Vec<_> = docid_scores.into_iter().collect(); - - document_scores.sort_by(|(_, left), (_, right)| compare_scores(left, right).reverse()); + .into_iter() + .zip(results.document_scores.into_iter().map(|scores| (scores, ratio))) + .collect(); Self { - matching_words: main_results.matching_words, - candidates: main_results.candidates, + matching_words: results.matching_words, + candidates: results.candidates, document_scores, } } @@ -200,7 +111,7 @@ impl CombinedSearchResult { } impl<'a> Search<'a> { - pub fn execute_hybrid(&self) -> Result { + pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result { // TODO: find classier way to achieve that than to reset vector and query params // create separate keyword and semantic searches let mut search = Search { @@ -223,8 +134,6 @@ impl<'a> Search<'a> { }; let vector_query = search.vector.take(); - let keyword_query = self.query.as_deref(); - let keyword_results = search.execute()?; // skip semantic search if we don't have a vector query (placeholder search) @@ -233,7 +142,7 @@ impl<'a> Search<'a> { }; // completely skip semantic search if the results of the keyword search are good enough - if self.results_good_enough(&keyword_results) { + if self.results_good_enough(&keyword_results, semantic_ratio) { return Ok(keyword_results); } @@ -243,94 +152,18 @@ impl<'a> Search<'a> { // TODO: would be better to have two distinct functions at this point let vector_results = search.execute()?; - // Compute keyword scores for vector_results - let keyword_results_for_vector = - self.keyword_results_for_vector(keyword_query, &vector_results)?; - - // compute vector scores for keyword_results - let vector_results_for_keyword = - // can unwrap because we returned already if there was no vector query - self.vector_results_for_keyword(search.vector.as_ref().unwrap(), &keyword_results)?; - - /// TODO apply sementic ratio - let keyword_results = - CombinedSearchResult::new(keyword_results, vector_results_for_keyword); - let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector); + let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio); + let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio); let merge_results = - CombinedSearchResult::merge(vector_results, keyword_results, self.offset, self.limit); + ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit); assert!(merge_results.documents_ids.len() <= self.limit); Ok(merge_results) } - fn vector_results_for_keyword( - &self, - vector: &[f32], - keyword_results: &SearchResult, - ) -> Result { - let embedder_name; - let embedder_name = match &self.embedder_name { - Some(embedder_name) => embedder_name, - None => { - embedder_name = self.index.default_embedding_name(self.rtxn)?; - &embedder_name - } - }; - - let mut ctx = SearchContext::new(self.index, self.rtxn); - - if let Some(searchable_attributes) = self.searchable_attributes { - ctx.searchable_attributes(searchable_attributes)?; - } - - let universe = keyword_results.documents_ids.iter().collect(); - - execute_vector_search( - &mut ctx, - vector, - ScoringStrategy::Detailed, - universe, - &self.sort_criteria, - self.geo_strategy, - 0, - self.limit + self.offset, - self.distribution_shift, - embedder_name, - ) - } - - fn keyword_results_for_vector( - &self, - query: Option<&str>, - vector_results: &SearchResult, - ) -> Result { - let mut ctx = SearchContext::new(self.index, self.rtxn); - - if let Some(searchable_attributes) = self.searchable_attributes { - ctx.searchable_attributes(searchable_attributes)?; - } - - let universe = vector_results.documents_ids.iter().collect(); - - execute_search( - &mut ctx, - query, - self.terms_matching_strategy, - ScoringStrategy::Detailed, - self.exhaustive_number_hits, - universe, - &self.sort_criteria, - self.geo_strategy, - 0, - self.limit + self.offset, - Some(self.words_limit), - &mut DefaultSearchLogger, - &mut DefaultSearchLogger, - ) - } - - fn results_good_enough(&self, keyword_results: &SearchResult) -> bool { - const GOOD_ENOUGH_SCORE: f64 = 0.9; + fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool { + // A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5 + const GOOD_ENOUGH_SCORE: f64 = 0.45; // 1. we check that we got a sufficient number of results if keyword_results.document_scores.len() < self.limit + self.offset { @@ -341,7 +174,7 @@ impl<'a> Search<'a> { // we need to check all results because due to sort like rules, they're not necessarily in relevancy order for score in &keyword_results.document_scores { let score = ScoreDetails::global_score(score.iter()); - if score < GOOD_ENOUGH_SCORE { + if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE { return false; } } diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 44fb3556f..0fb24be84 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -3,7 +3,6 @@ use std::ops::ControlFlow; use charabia::normalizer::NormalizerOption; use charabia::Normalize; -use deserr::{DeserializeError, Deserr, Sequence}; use fst::automaton::{Automaton, Str}; use fst::{IntoStreamer, Streamer}; use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA}; @@ -57,53 +56,6 @@ pub struct Search<'a> { embedder_name: Option, } -#[derive(Debug, Clone, PartialEq)] -pub enum VectorQuery { - Vector(Vec), - String(String), -} - -impl Deserr for VectorQuery -where - E: DeserializeError, -{ - fn deserialize_from_value( - value: deserr::Value, - location: deserr::ValuePointerRef, - ) -> std::result::Result { - match value { - deserr::Value::String(s) => Ok(VectorQuery::String(s)), - deserr::Value::Sequence(seq) => { - let v: std::result::Result, _> = seq - .into_iter() - .enumerate() - .map(|(index, v)| match v.into_value() { - deserr::Value::Float(f) => Ok(f as f32), - deserr::Value::Integer(i) => Ok(i as f32), - v => Err(deserr::take_cf_content(E::error::( - None, - deserr::ErrorKind::IncorrectValueKind { - actual: v, - accepted: &[deserr::ValueKind::Float, deserr::ValueKind::Integer], - }, - location.push_index(index), - ))), - }) - .collect(); - Ok(VectorQuery::Vector(v?)) - } - _ => Err(deserr::take_cf_content(E::error::( - None, - deserr::ErrorKind::IncorrectValueKind { - actual: value, - accepted: &[deserr::ValueKind::String, deserr::ValueKind::Sequence], - }, - location, - ))), - } - } -} - impl<'a> Search<'a> { pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> { Search { From 87bba98bd8b84f634260f9f198a0ea95a0892232 Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 14 Dec 2023 16:01:35 +0100 Subject: [PATCH 27/28] Various changes - fixed seed for arroy - check vector dimensions as soon as it is provided to search - don't embed whitespace --- meilisearch/src/routes/indexes/search.rs | 102 ++++++++++++++++------- meilisearch/src/routes/multi_search.rs | 13 +-- meilisearch/src/search.rs | 22 +++-- meilisearch/tests/search/hybrid.rs | 6 +- milli/src/search/mod.rs | 22 ++++- milli/src/update/index_documents/mod.rs | 2 +- milli/src/vector/hf.rs | 10 ++- milli/src/vector/mod.rs | 8 ++ milli/src/vector/openai.rs | 14 +++- 9 files changed, 148 insertions(+), 51 deletions(-) diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index c2b6ca3fc..c474d285e 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -9,6 +9,7 @@ use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli; +use meilisearch_types::milli::vector::DistributionShift; use meilisearch_types::serde_cs::vec::CS; use serde_json::Value; @@ -200,10 +201,11 @@ pub async fn search_with_url_query( let index = index_scheduler.index(&index_uid)?; let features = index_scheduler.features(); - embed(&mut query, index_scheduler.get_ref(), &index).await?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; + tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) + .await?; if let Ok(ref search_result) = search_result { aggregate.succeed(search_result); } @@ -238,10 +240,11 @@ pub async fn search_with_post( let features = index_scheduler.features(); - embed(&mut query, index_scheduler.get_ref(), &index).await?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; + tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) + .await?; if let Ok(ref search_result) = search_result { aggregate.succeed(search_result); } @@ -257,39 +260,74 @@ pub async fn embed( query: &mut SearchQuery, index_scheduler: &IndexScheduler, index: &milli::Index, -) -> Result<(), ResponseError> { - if let (None, Some(q), Some(HybridQuery { semantic_ratio: _, embedder })) = - (&query.vector, &query.q, &query.hybrid) - { - let embedder_configs = index.embedding_configs(&index.read_txn()?)?; - let embedders = index_scheduler.embedders(embedder_configs)?; +) -> Result, ResponseError> { + match (&query.hybrid, &query.vector, &query.q) { + (Some(HybridQuery { semantic_ratio: _, embedder }), None, Some(q)) + if !q.trim().is_empty() => + { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedders = index_scheduler.embedders(embedder_configs)?; - let embedder = if let Some(embedder_name) = embedder { - embedders.get(embedder_name) - } else { - embedders.get_default() - }; + let embedder = if let Some(embedder_name) = embedder { + embedders.get(embedder_name) + } else { + embedders.get_default() + }; - let embedder = embedder - .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) - .map_err(milli::Error::from)? - .0; - let embeddings = embedder - .embed(vec![q.to_owned()]) - .await - .map_err(milli::vector::Error::from) - .map_err(milli::Error::from)? - .pop() - .expect("No vector returned from embedding"); + let embedder = embedder + .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) + .map_err(milli::Error::from)? + .0; - if embeddings.iter().nth(1).is_some() { - warn!("Ignoring embeddings past the first one in long search query"); - query.vector = Some(embeddings.iter().next().unwrap().to_vec()); - } else { - query.vector = Some(embeddings.into_inner()); + let distribution = embedder.distribution(); + + let embeddings = embedder + .embed(vec![q.to_owned()]) + .await + .map_err(milli::vector::Error::from) + .map_err(milli::Error::from)? + .pop() + .expect("No vector returned from embedding"); + + if embeddings.iter().nth(1).is_some() { + warn!("Ignoring embeddings past the first one in long search query"); + query.vector = Some(embeddings.iter().next().unwrap().to_vec()); + } else { + query.vector = Some(embeddings.into_inner()); + } + Ok(distribution) } + (Some(hybrid), vector, _) => { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedders = index_scheduler.embedders(embedder_configs)?; + + let embedder = if let Some(embedder_name) = &hybrid.embedder { + embedders.get(embedder_name) + } else { + embedders.get_default() + }; + + let embedder = embedder + .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) + .map_err(milli::Error::from)? + .0; + + if let Some(vector) = vector { + if vector.len() != embedder.dimensions() { + return Err(meilisearch_types::milli::Error::UserError( + meilisearch_types::milli::UserError::InvalidVectorDimensions { + expected: embedder.dimensions(), + found: vector.len(), + }, + ) + .into()); + } + } + + Ok(embedder.distribution()) + } + _ => Ok(None), } - Ok(()) } #[cfg(test)] diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs index 4e578572d..8e81688e6 100644 --- a/meilisearch/src/routes/multi_search.rs +++ b/meilisearch/src/routes/multi_search.rs @@ -75,12 +75,15 @@ pub async fn multi_search_with_post( }) .with_index(query_index)?; - embed(&mut query, index_scheduler.get_ref(), &index).await.with_index(query_index)?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index) + .await + .with_index(query_index)?; - let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, features)) - .await - .with_index(query_index)?; + let search_result = tokio::task::spawn_blocking(move || { + perform_search(&index, query, features, distribution) + }) + .await + .with_index(query_index)?; search_results.push(SearchResultWithIndex { index_uid: index_uid.into_inner(), diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 267a404c0..b5dba8a58 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -13,6 +13,7 @@ use meilisearch_types::error::deserr_codes::*; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; +use meilisearch_types::milli::vector::DistributionShift; use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; @@ -90,16 +91,22 @@ pub struct SearchQuery { #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] pub struct HybridQuery { /// TODO validate that sementic ratio is between 0.0 and 1,0 - #[deserr(default, error = DeserrJsonError)] + #[deserr(default, error = DeserrJsonError, default)] pub semantic_ratio: SemanticRatio, #[deserr(default, error = DeserrJsonError, default)] pub embedder: Option, } -#[derive(Debug, Clone, Copy, Default, PartialEq, Deserr)] +#[derive(Debug, Clone, Copy, PartialEq, Deserr)] #[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)] pub struct SemanticRatio(f32); +impl Default for SemanticRatio { + fn default() -> Self { + DEFAULT_SEMANTIC_RATIO() + } +} + impl std::convert::TryFrom for SemanticRatio { type Error = InvalidSearchSemanticRatio; @@ -374,6 +381,7 @@ fn prepare_search<'t>( rtxn: &'t RoTxn, query: &'t SearchQuery, features: RoFeatures, + distribution: Option, ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { let mut search = index.search(rtxn); @@ -389,6 +397,8 @@ fn prepare_search<'t>( return Err(MeilisearchHttpError::MissingSearchHybrid); } + search.distribution_shift(distribution); + if let Some(ref vector) = query.vector { match &query.hybrid { // If semantic ratio is 0.0, only the query search will impact the search results, @@ -482,12 +492,13 @@ pub fn perform_search( index: &Index, query: SearchQuery, features: RoFeatures, + distribution: Option, ) -> Result { let before_search = Instant::now(); let rtxn = index.read_txn()?; let (search, is_finite_pagination, max_total_hits, offset) = - prepare_search(index, &rtxn, &query, features)?; + prepare_search(index, &rtxn, &query, features, distribution)?; let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = match &query.hybrid { @@ -718,8 +729,9 @@ pub fn perform_facet_search( let before_search = Instant::now(); let rtxn = index.read_txn()?; - let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features)?; - let mut facet_search = SearchForFacetValues::new(facet_name, search); + let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features, None)?; + let mut facet_search = + SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some()); if let Some(facet_query) = &facet_query { facet_search.query(facet_query); } diff --git a/meilisearch/tests/search/hybrid.rs b/meilisearch/tests/search/hybrid.rs index 7986091b0..c3534c110 100644 --- a/meilisearch/tests/search/hybrid.rs +++ b/meilisearch/tests/search/hybrid.rs @@ -27,11 +27,11 @@ async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Inde ) .await; assert_eq!(202, code, "{:?}", response); - index.wait_task(0).await; + index.wait_task(response.uid()).await; let (response, code) = index.add_documents(documents.clone(), None).await; assert_eq!(202, code, "{:?}", response); - index.wait_task(1).await; + index.wait_task(response.uid()).await; index } @@ -68,7 +68,7 @@ async fn simple_search() { ) .await; snapshot!(code, @"200 OK"); - snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); + snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]}}]"###); let (response, code) = index .search_post( diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index 0fb24be84..3e4849578 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -154,6 +154,15 @@ impl<'a> Search<'a> { self } + pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result { + if has_vector_search { + let ctx = SearchContext::new(self.index, self.rtxn); + filtered_universe(&ctx, &self.filter) + } else { + Ok(self.execute()?.candidates) + } + } + pub fn execute(&self) -> Result { let embedder_name; let embedder_name = match &self.embedder_name { @@ -297,11 +306,16 @@ pub struct SearchForFacetValues<'a> { query: Option, facet: String, search_query: Search<'a>, + is_hybrid: bool, } impl<'a> SearchForFacetValues<'a> { - pub fn new(facet: String, search_query: Search<'a>) -> SearchForFacetValues<'a> { - SearchForFacetValues { query: None, facet, search_query } + pub fn new( + facet: String, + search_query: Search<'a>, + is_hybrid: bool, + ) -> SearchForFacetValues<'a> { + SearchForFacetValues { query: None, facet, search_query, is_hybrid } } pub fn query(&mut self, query: impl Into) -> &mut Self { @@ -351,7 +365,9 @@ impl<'a> SearchForFacetValues<'a> { None => return Ok(vec![]), }; - let search_candidates = self.search_query.execute()?.candidates; + let search_candidates = self + .search_query + .execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?; match self.query.as_ref() { Some(query) => { diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index 6906bbcd3..ffc3f6b3a 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -509,7 +509,7 @@ where // We write the primary key field id into the main database self.index.put_primary_key(self.wtxn, &primary_key)?; let number_of_documents = self.index.number_of_documents(self.wtxn)?; - let mut rng = rand::rngs::StdRng::from_entropy(); + let mut rng = rand::rngs::StdRng::seed_from_u64(42); for (embedder_name, dimension) in dimension { let wtxn = &mut *self.wtxn; diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs index 3162dadec..0a6bcbe93 100644 --- a/milli/src/vector/hf.rs +++ b/milli/src/vector/hf.rs @@ -7,7 +7,7 @@ use hf_hub::{Repo, RepoType}; use tokenizers::{PaddingParams, Tokenizer}; pub use super::error::{EmbedError, Error, NewEmbedderError}; -use super::{Embedding, Embeddings}; +use super::{DistributionShift, Embedding, Embeddings}; #[derive( Debug, @@ -184,4 +184,12 @@ impl Embedder { pub fn dimensions(&self) -> usize { self.dimensions } + + pub fn distribution(&self) -> Option { + if self.options.model == "BAAI/bge-base-en-v1.5" { + Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 }) + } else { + None + } + } } diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs index df5750e77..81c4cf4a1 100644 --- a/milli/src/vector/mod.rs +++ b/milli/src/vector/mod.rs @@ -202,6 +202,14 @@ impl Embedder { Embedder::UserProvided(embedder) => embedder.dimensions(), } } + + pub fn distribution(&self) -> Option { + match self { + Embedder::HuggingFace(embedder) => embedder.distribution(), + Embedder::OpenAi(embedder) => embedder.distribution(), + Embedder::UserProvided(_embedder) => None, + } + } } #[derive(Debug, Clone, Copy)] diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs index 7ae626494..c11e6ddc6 100644 --- a/milli/src/vector/openai.rs +++ b/milli/src/vector/openai.rs @@ -4,7 +4,7 @@ use reqwest::StatusCode; use serde::{Deserialize, Serialize}; use super::error::{EmbedError, NewEmbedderError}; -use super::{Embedding, Embeddings}; +use super::{DistributionShift, Embedding, Embeddings}; #[derive(Debug)] pub struct Embedder { @@ -65,6 +65,14 @@ impl EmbeddingModel { _ => None, } } + + fn distribution(&self) -> Option { + match self { + EmbeddingModel::TextEmbeddingAda002 => { + Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) + } + } + } } pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; @@ -326,6 +334,10 @@ impl Embedder { pub fn dimensions(&self) -> usize { self.options.embedding_model.dimensions() } + + pub fn distribution(&self) -> Option { + self.options.embedding_model.distribution() + } } // retrying in case of failure From eb5cb91da22647d034b361dfb40147746329faca Mon Sep 17 00:00:00 2001 From: Louis Dureuil Date: Thu, 14 Dec 2023 16:19:46 +0100 Subject: [PATCH 28/28] Switch default from hf to openai --- milli/src/vector/settings.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs index a91692613..912cdf953 100644 --- a/milli/src/vector/settings.rs +++ b/milli/src/vector/settings.rs @@ -157,7 +157,7 @@ where impl Default for EmbedderSettings { fn default() -> Self { - Self::HuggingFace(Default::default()) + Self::OpenAi(Default::default()) } }