diff --git a/Cargo.lock b/Cargo.lock index fda5f2493..e4826b489 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,7 +46,7 @@ dependencies = [ "actix-tls", "actix-utils", "ahash 0.8.3", - "base64 0.21.2", + "base64 0.21.5", "bitflags 1.3.2", "brotli", "bytes", @@ -56,7 +56,7 @@ dependencies = [ "flate2", "futures-core", "h2", - "http", + "http 0.2.9", "httparse", "httpdate", "itoa", @@ -90,7 +90,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d66ff4d247d2b160861fa2866457e85706833527840e4133f8f49aa423a38799" dependencies = [ "bytestring", - "http", + "http 0.2.9", "regex", "serde", "tracing", @@ -120,7 +120,7 @@ dependencies = [ "futures-util", "mio", "num_cpus", - "socket2", + "socket2 0.4.9", "tokio", "tracing", ] @@ -189,7 +189,7 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "http", + "http 0.2.9", "itoa", "language-tags", "log", @@ -201,7 +201,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "smallvec", - "socket2", + "socket2 0.4.9", "time", "url", ] @@ -280,9 +280,9 @@ dependencies = [ [[package]] name = "aho-corasick" -version = "1.0.2" +version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43f6cb1bf222025340178f382c426f13757b2960e89779dfcb319c32542a5a41" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" dependencies = [ "memchr", ] @@ -310,16 +310,15 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstream" -version = "0.3.2" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163" +checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" dependencies = [ "anstyle", "anstyle-parse", "anstyle-query", "anstyle-wincon", "colorchoice", - "is-terminal", "utf8parse", ] @@ -349,9 +348,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "1.0.1" +version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" +checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" dependencies = [ "anstyle", "windows-sys 0.48.0", @@ -366,6 +365,12 @@ dependencies = [ "backtrace", ] +[[package]] +name = "anymap2" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c" + [[package]] name = "arbitrary" version = "1.3.0" @@ -375,6 +380,24 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "arroy" +version = "0.1.0" +source = "git+https://github.com/meilisearch/arroy.git#4f193fd534acd357b65bfe9eec4b3fed8ece2007" +dependencies = [ + "bytemuck", + "byteorder", + "heed", + "log", + "memmap2 0.9.0", + "ordered-float 4.2.0", + "rand", + "rayon", + "roaring", + "tempfile", + "thiserror", +] + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -456,9 +479,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" -version = "0.21.2" +version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d" +checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "base64ct" @@ -509,6 +532,21 @@ dependencies = [ "serde", ] +[[package]] +name = "bit-set" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1" +dependencies = [ + "bit-vec", +] + +[[package]] +name = "bit-vec" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb" + [[package]] name = "bitflags" version = "1.3.2" @@ -517,9 +555,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.3" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" dependencies = [ "serde", ] @@ -556,12 +594,12 @@ dependencies = [ [[package]] name = "bstr" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6798148dccfbff0fae41c7574d2fa8f1ef3492fba0face179de5d8d447d67b05" +checksum = "542f33a8835a0884b006a0c3df3dadd99c0c3f296ed26c2fdc8028e01ad6230c" dependencies = [ "memchr", - "regex-automata 0.3.6", + "regex-automata 0.4.3", "serde", ] @@ -589,9 +627,9 @@ checksum = "2c676a478f63e9fa2dd5368a42f28bba0d6c560b775f38583c8bbaa7fcd67c9c" [[package]] name = "bytemuck" -version = "1.13.1" +version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" dependencies = [ "bytemuck_derive", ] @@ -609,9 +647,9 @@ dependencies = [ [[package]] name = "byteorder" -version = "1.4.3" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" @@ -659,6 +697,58 @@ dependencies = [ "displaydoc", ] +[[package]] +name = "candle-core" +version = "0.3.1" +source = "git+https://github.com/huggingface/candle.git#f4fcf6090045ac44122fd5f0a7e46db6e3e16528" +dependencies = [ + "byteorder", + "gemm", + "half 2.3.1", + "memmap2 0.7.1", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror", + "yoke", + "zip", +] + +[[package]] +name = "candle-nn" +version = "0.3.1" +source = "git+https://github.com/huggingface/candle.git#f4fcf6090045ac44122fd5f0a7e46db6e3e16528" +dependencies = [ + "candle-core", + "half 2.3.1", + "num-traits", + "rayon", + "safetensors", + "serde", + "thiserror", +] + +[[package]] +name = "candle-transformers" +version = "0.3.1" +source = "git+https://github.com/huggingface/candle.git#f4fcf6090045ac44122fd5f0a7e46db6e3e16528" +dependencies = [ + "byteorder", + "candle-core", + "candle-nn", + "num-traits", + "rand", + "rayon", + "serde", + "serde_json", + "serde_plain", + "tracing", + "wav", +] + [[package]] name = "cargo_toml" version = "0.15.3" @@ -765,7 +855,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" dependencies = [ "ciborium-io", - "half", + "half 1.8.2", ] [[package]] @@ -780,20 +870,19 @@ dependencies = [ [[package]] name = "clap" -version = "4.3.21" +version = "4.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c27cdf28c0f604ba3f512b0c9a409f8de8513e4816705deb0498b627e7c3a3fd" +checksum = "2275f18819641850fa26c89acc84d465c1bf91ce57bc2748b28c420473352f64" dependencies = [ "clap_builder", "clap_derive", - "once_cell", ] [[package]] name = "clap_builder" -version = "4.3.21" +version = "4.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08a9f1ab5e9f01a9b81f202e8562eb9a10de70abf9eaeac1be465c28b75aa4aa" +checksum = "07cdf1b148b25c1e1f7a42225e30a0d99a615cd4637eae7365548dd4529b95bc" dependencies = [ "anstream", "anstyle", @@ -803,9 +892,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.3.12" +version = "4.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a9bb5758fc5dfe728d1019941681eccaf0cf8a4189b692a0ee2f2ecf90a050" +checksum = "cf9804afaaf59a91e75b022a30fb7229a7901f60c755489cc61c9b423b836442" dependencies = [ "heck", "proc-macro2", @@ -815,9 +904,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.5.0" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" [[package]] name = "cobs" @@ -851,6 +940,7 @@ dependencies = [ "encode_unicode", "lazy_static", "libc", + "unicode-width", "windows-sys 0.45.0", ] @@ -892,6 +982,16 @@ dependencies = [ "version_check", ] +[[package]] +name = "core-foundation" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.4" @@ -1040,6 +1140,12 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + [[package]] name = "crypto-common" version = "0.1.6" @@ -1226,6 +1332,15 @@ dependencies = [ "subtle", ] +[[package]] +name = "dirs" +version = "5.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "44c45a9d03d6676652bcb5e724c7e988de1acad23a711b5217ab9cbecbec2225" +dependencies = [ + "dirs-sys", +] + [[package]] name = "dirs-next" version = "1.0.2" @@ -1236,6 +1351,18 @@ dependencies = [ "dirs-sys-next", ] +[[package]] +name = "dirs-sys" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" +dependencies = [ + "libc", + "option-ext", + "redox_users", + "windows-sys 0.48.0", +] + [[package]] name = "dirs-sys-next" version = "0.1.2" @@ -1258,6 +1385,12 @@ dependencies = [ "syn 2.0.28", ] +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "doxygen-rs" version = "0.2.2" @@ -1274,7 +1407,7 @@ dependencies = [ "anyhow", "big_s", "flate2", - "http", + "http 0.2.9", "log", "maplit", "meili-snap", @@ -1292,6 +1425,16 @@ dependencies = [ "uuid 1.5.0", ] +[[package]] +name = "dyn-stack" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" +dependencies = [ + "bytemuck", + "reborrow", +] + [[package]] name = "either" version = "1.9.0" @@ -1436,23 +1579,31 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.2" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b30f669a7961ef1631673d2766cc92f52d64f7ef354d4fe0ddfd30ed52f0f4f" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ - "errno-dragonfly", "libc", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] -name = "errno-dragonfly" -version = "0.1.2" +name = "esaxx-rs" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf" +checksum = "d817e038c30374a4bcb22f94d0a8a0e216958d4c3dcde369b1439fec4bdda6e6" dependencies = [ "cc", - "libc", +] + +[[package]] +name = "fancy-regex" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2" +dependencies = [ + "bit-set", + "regex", ] [[package]] @@ -1568,9 +1719,9 @@ checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a" [[package]] name = "futures" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40" +checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335" dependencies = [ "futures-channel", "futures-core", @@ -1583,9 +1734,9 @@ dependencies = [ [[package]] name = "futures-channel" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2" +checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb" dependencies = [ "futures-core", "futures-sink", @@ -1593,15 +1744,15 @@ dependencies = [ [[package]] name = "futures-core" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c" +checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c" [[package]] name = "futures-executor" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0" +checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc" dependencies = [ "futures-core", "futures-task", @@ -1610,15 +1761,15 @@ dependencies = [ [[package]] name = "futures-io" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964" +checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa" [[package]] name = "futures-macro" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72" +checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", @@ -1627,21 +1778,21 @@ dependencies = [ [[package]] name = "futures-sink" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e" +checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817" [[package]] name = "futures-task" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65" +checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2" [[package]] name = "futures-util" -version = "0.3.28" +version = "0.3.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533" +checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104" dependencies = [ "futures-channel", "futures-core", @@ -1677,6 +1828,123 @@ dependencies = [ "byteorder", ] +[[package]] +name = "gemm" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b3afa707040531a7527477fd63a81ea4f6f3d26037a2f96776e57fb843b258e" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f16", + "gemm-f32", + "gemm-f64", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cc3973a4c30c73f26a099113953d0c772bb17ee2e07976c0a06b8fe1f38a57d" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30362894b93dada374442cb2edf4512ddf19513c9bec88e06a445bcb6b22e64f" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "988499faa80566b046b4fee2c5f15af55b5a20c1fe8486b112ebb34efa045ad6" +dependencies = [ + "bytemuck", + "dyn-stack", + "half 2.3.1", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f16" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6cf2854a12371684c38d9a865063a27661812a3ff5803454c5742e8f5a388ce" +dependencies = [ + "dyn-stack", + "gemm-common", + "gemm-f32", + "half 2.3.1", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bc84003cf6d950a7c7ca714ad6db281b6cef5c7d462f5cd9ad90ea2409c7227" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.16.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35187ef101a71eed0ecd26fb4a6255b4192a12f1c5335f3a795698f2d9b6cf33" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -1764,7 +2032,7 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.9", "indexmap 1.9.3", "slab", "tokio", @@ -1778,6 +2046,20 @@ version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", + "num-traits", + "rand", + "rand_distr", +] + [[package]] name = "hash32" version = "0.2.1" @@ -1827,7 +2109,7 @@ version = "0.20.0-alpha.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9648a50991c86df7d00c56c268c27754fcf4c80be2ba57fc4a00dc928c6fe934" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.1", "bytemuck", "byteorder", "heed-traits", @@ -1871,6 +2153,22 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "hf-hub" +version = "0.3.2" +source = "git+https://github.com/dureuill/hf-hub.git?branch=rust_tls#88d4f11cb9fa079f2912bacb96f5080b16825ce8" +dependencies = [ + "dirs", + "http 1.0.0", + "indicatif", + "log", + "rand", + "serde", + "serde_json", + "thiserror", + "ureq", +] + [[package]] name = "hmac" version = "0.12.1" @@ -1891,6 +2189,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.5" @@ -1898,7 +2207,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d5f38f16d184e36f2408a55281cd658ecbd3ca05cce6d6510a176eca393e26d1" dependencies = [ "bytes", - "http", + "http 0.2.9", "pin-project-lite", ] @@ -1931,13 +2240,13 @@ dependencies = [ "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "httparse", "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.4.9", "tokio", "tower-service", "tracing", @@ -1951,7 +2260,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d78e1e73ec14cf7375674f74d7dde185c8206fd9dea6fb6295e8a98098aaa97" dependencies = [ "futures-util", - "http", + "http 0.2.9", "hyper", "rustls 0.21.6", "tokio", @@ -2507,6 +2816,19 @@ dependencies = [ "serde", ] +[[package]] +name = "indicatif" +version = "0.17.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb28741c9db9a713d93deb3bb9515c20788cef5815265bee4980e87bde7e0f25" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + [[package]] name = "inout" version = "0.1.3" @@ -2541,21 +2863,6 @@ dependencies = [ "cfg-if", ] -[[package]] -name = "instant-distance" -version = "0.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c619cdaa30bb84088963968bee12a45ea5fbbf355f2c021bcd15589f5ca494a" -dependencies = [ - "num_cpus", - "ordered-float", - "parking_lot", - "rand", - "rayon", - "serde", - "serde-big-array", -] - [[package]] name = "io-lifetimes" version = "1.0.11" @@ -2591,7 +2898,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi", - "rustix 0.38.7", + "rustix 0.38.26", "windows-sys 0.48.0", ] @@ -2666,7 +2973,7 @@ version = "8.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "pem", "ring", "serde", @@ -2674,6 +2981,16 @@ dependencies = [ "simple_asn1", ] +[[package]] +name = "kstring" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3066350882a1cd6d950d055997f379ac37fd39f81cd4d8ed186032eb3c5747" +dependencies = [ + "serde", + "static_assertions", +] + [[package]] name = "language-tags" version = "0.3.2" @@ -2697,9 +3014,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.147" +version = "0.2.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" +checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c" [[package]] name = "libgit2-sys" @@ -2964,9 +3281,66 @@ checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4" [[package]] name = "linux-raw-sys" -version = "0.4.5" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" + +[[package]] +name = "liquid" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69f68ae1011499ae2ef879f631891f21c78e309755f4a5e483c4a8f12e10b609" +dependencies = [ + "doc-comment", + "liquid-core", + "liquid-derive", + "liquid-lib", + "serde", +] + +[[package]] +name = "liquid-core" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79e0724dfcaad5cfb7965ea0f178ca0870b8d7315178f4a7179f5696f7f04d5f" +dependencies = [ + "anymap2", + "itertools 0.10.5", + "kstring", + "liquid-derive", + "num-traits", + "pest", + "pest_derive", + "regex", + "serde", + "time", +] + +[[package]] +name = "liquid-derive" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc2fb41a9bb4257a3803154bdf7e2df7d45197d1941c9b1a90ad815231630721" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + +[[package]] +name = "liquid-lib" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2a17e273a6fb1fb6268f7a5867ddfd0bd4683c7e19b51084f3d567fad4348c0" +dependencies = [ + "itertools 0.10.5", + "liquid-core", + "once_cell", + "percent-encoding", + "regex", + "time", + "unicode-segmentation", +] [[package]] name = "litemap" @@ -3025,9 +3399,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.19" +version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "logging_timer" @@ -3057,6 +3431,22 @@ version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b8c72594ac26bfd34f2d99dfced2edfaddfe8a476e3ff2ca0eb293d925c4f83" +[[package]] +name = "macro_rules_attribute" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a82271f7bc033d84bbca59a3ce3e4159938cb08a9c3aebbe54d215131518a13" +dependencies = [ + "macro_rules_attribute-proc_macro", + "paste", +] + +[[package]] +name = "macro_rules_attribute-proc_macro" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8dd856d451cc0da70e2ef2ce95a18e39a93b7558bedf10201ad28503f918568" + [[package]] name = "manifest-dir-macros" version = "0.1.17" @@ -3121,7 +3511,7 @@ dependencies = [ "futures", "futures-util", "hex", - "http", + "http 0.2.9", "index-scheduler", "indexmap 2.0.0", "insta", @@ -3140,7 +3530,7 @@ dependencies = [ "num_cpus", "obkv", "once_cell", - "ordered-float", + "ordered-float 3.7.0", "parking_lot", "permissive-json-pointer", "pin-project-lite", @@ -3184,7 +3574,7 @@ dependencies = [ name = "meilisearch-auth" version = "1.5.1" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "enum-iterator", "hmac", "maplit", @@ -3215,7 +3605,7 @@ dependencies = [ "fst", "insta", "meili-snap", - "memmap2", + "memmap2 0.7.1", "milli", "roaring", "serde", @@ -3245,15 +3635,25 @@ dependencies = [ [[package]] name = "memchr" -version = "2.5.0" +version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" +checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "memmap2" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" +dependencies = [ + "libc", + "stable_deref_trait", +] + +[[package]] +name = "memmap2" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "deaba38d7abf1d4cca21cc89e932e542ba2b9258664d2a9ef0e61512039c9375" dependencies = [ "libc", ] @@ -3271,12 +3671,16 @@ dependencies = [ name = "milli" version = "1.5.1" dependencies = [ + "arroy", "big_s", "bimap", "bincode", "bstr", "bytemuck", "byteorder", + "candle-core", + "candle-nn", + "candle-transformers", "charabia", "concat-arrays", "crossbeam-channel", @@ -3286,30 +3690,33 @@ dependencies = [ "filter-parser", "flatten-serde-json", "fst", + "futures", "fxhash", "geoutils", "grenad", "heed", + "hf-hub", "indexmap 2.0.0", "insta", - "instant-distance", "itertools 0.11.0", "json-depth-checker", "levenshtein_automata", + "liquid", "log", "logging_timer", "maplit", "md5", "meili-snap", - "memmap2", + "memmap2 0.7.1", "mimalloc", "obkv", "once_cell", - "ordered-float", + "ordered-float 3.7.0", "puffin", "rand", "rand_pcg", "rayon", + "reqwest", "roaring", "rstar", "serde", @@ -3320,7 +3727,10 @@ dependencies = [ "smartstring", "tempfile", "thiserror", + "tiktoken-rs", "time", + "tokenizers", + "tokio", "uuid 1.5.0", ] @@ -3366,9 +3776,9 @@ dependencies = [ [[package]] name = "mio" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2" +checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0" dependencies = [ "libc", "log", @@ -3376,6 +3786,27 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "monostate" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15f370ae88093ec6b11a710dec51321a61d420fafd1bad6e30d01bd9c920e8ee" +dependencies = [ + "monostate-impl", + "serde", +] + +[[package]] +name = "monostate-impl" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "371717c0a5543d6a800cac822eac735aa7d2d2fbb41002e9856a4089532dbdce" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "nelson" version = "0.1.0" @@ -3422,6 +3853,16 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "bytemuck", + "num-traits", +] + [[package]] name = "num-integer" version = "0.1.45" @@ -3452,6 +3893,12 @@ dependencies = [ "libc", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "object" version = "0.31.1" @@ -3473,12 +3920,40 @@ version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" +[[package]] +name = "onig" +version = "6.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c4b31c8722ad9171c6d77d3557db078cab2bd50afcc9d09c8b315c59df8ca4f" +dependencies = [ + "bitflags 1.3.2", + "libc", + "once_cell", + "onig_sys", +] + +[[package]] +name = "onig_sys" +version = "69.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b829e3d7e9cc74c7e315ee8edb185bf4190da5acde74afd7fc59c35b1f086e7" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "oorandom" version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +[[package]] +name = "option-ext" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" + [[package]] name = "ordered-float" version = "3.7.0" @@ -3488,6 +3963,15 @@ dependencies = [ "num-traits", ] +[[package]] +name = "ordered-float" +version = "4.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a76df7075c7d4d01fdcb46c912dd17fba5b60c78ea480b475f2b6ab6f666584e" +dependencies = [ + "num-traits", +] + [[package]] name = "page_size" version = "0.5.0" @@ -3755,6 +4239,12 @@ dependencies = [ "plotters-backend", ] +[[package]] +name = "portable-atomic" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3bccab0e7fd7cc19f820a1c8c91720af652d0c88dc9664dd72aef2614f04af3b" + [[package]] name = "postcard" version = "1.0.8" @@ -3864,6 +4354,18 @@ dependencies = [ "serde", ] +[[package]] +name = "pulp" +version = "0.18.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7057c1435edb390ebfc51743abad043377f1f698ce8e649a9b52a4b378be5e4d" +dependencies = [ + "bytemuck", + "libm", + "num-complex", + "reborrow", +] + [[package]] name = "quote" version = "1.0.32" @@ -3903,6 +4405,16 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "rand_pcg" version = "0.3.1" @@ -3914,27 +4426,51 @@ dependencies = [ ] [[package]] -name = "rayon" -version = "1.7.0" +name = "raw-cpuid" +version = "10.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "rayon" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" dependencies = [ "either", "rayon-core", ] [[package]] -name = "rayon-core" -version = "1.11.0" +name = "rayon-cond" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" +checksum = "059f538b55efd2309c9794130bc149c6a553db90e9d99c2030785c82f0bd7df9" +dependencies = [ + "either", + "itertools 0.11.0", + "rayon", +] + +[[package]] +name = "rayon-core" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" dependencies = [ - "crossbeam-channel", "crossbeam-deque", "crossbeam-utils", - "num_cpus", ] +[[package]] +name = "reborrow" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + [[package]] name = "redox_syscall" version = "0.2.16" @@ -3953,6 +4489,15 @@ dependencies = [ "bitflags 1.3.2", ] +[[package]] +name = "redox_syscall" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + [[package]] name = "redox_users" version = "0.4.3" @@ -3996,6 +4541,12 @@ dependencies = [ "regex-syntax", ] +[[package]] +name = "regex-automata" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" + [[package]] name = "regex-syntax" version = "0.7.4" @@ -4004,17 +4555,17 @@ checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" [[package]] name = "reqwest" -version = "0.11.18" +version = "0.11.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55" +checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", "bytes", "encoding_rs", "futures-core", "futures-util", "h2", - "http", + "http 0.2.9", "http-body", "hyper", "hyper-rustls", @@ -4030,6 +4581,7 @@ dependencies = [ "serde", "serde_json", "serde_urlencoded", + "system-configuration", "tokio", "tokio-rustls 0.24.1", "tower-service", @@ -4037,7 +4589,7 @@ dependencies = [ "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots 0.22.6", + "webpki-roots 0.25.3", "winreg", ] @@ -4047,6 +4599,12 @@ version = "0.1.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c31b5c4033f8fdde8700e4657be2c497e7288f01515be52168c631e2e4d4086" +[[package]] +name = "riff" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9b1a3d5f46d53f4a3478e2be4a5a5ce5108ea58b100dcd139830eae7f79a3a1" + [[package]] name = "ring" version = "0.16.20" @@ -4092,6 +4650,12 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc_version" version = "0.4.0" @@ -4117,15 +4681,15 @@ dependencies = [ [[package]] name = "rustix" -version = "0.38.7" +version = "0.38.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "172891ebdceb05aa0005f533a6cbfca599ddd7d966f6f5d4d9b2e70478e70399" +checksum = "9470c4bf8246c8daf25f9598dca807fb6510347b1e1cfa55749113850c79d88a" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.1", "errno", "libc", - "linux-raw-sys 0.4.5", - "windows-sys 0.48.0", + "linux-raw-sys 0.4.12", + "windows-sys 0.52.0", ] [[package]] @@ -4148,7 +4712,7 @@ checksum = "1d1feddffcfcc0b33f5c6ce9a29e341e4cd59c3f78e7ee45f4a40c038b1d6cbb" dependencies = [ "log", "ring", - "rustls-webpki 0.101.3", + "rustls-webpki", "sct", ] @@ -4158,17 +4722,7 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2" dependencies = [ - "base64 0.21.2", -] - -[[package]] -name = "rustls-webpki" -version = "0.100.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e98ff011474fa39949b7e5c0428f9b4937eda7da7848bbb947786b7be0b27dab" -dependencies = [ - "ring", - "untrusted", + "base64 0.21.5", ] [[package]] @@ -4193,6 +4747,16 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" +[[package]] +name = "safetensors" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "same-file" version = "1.0.6" @@ -4238,6 +4802,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0293b4b29daaf487284529cc2f5675b8e57c61f70167ba415a463651fd6a918" +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + [[package]] name = "serde" version = "1.0.190" @@ -4247,15 +4817,6 @@ dependencies = [ "serde_derive", ] -[[package]] -name = "serde-big-array" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f" -dependencies = [ - "serde", -] - [[package]] name = "serde-cs" version = "0.2.4" @@ -4288,6 +4849,15 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_plain" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce1fc6db65a611022b23a0dec6975d63fb80a302cb3388835ff02c097258d50" +dependencies = [ + "serde", +] + [[package]] name = "serde_spanned" version = "0.6.3" @@ -4430,6 +5000,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys 0.48.0", +] + +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "spin" version = "0.5.2" @@ -4445,6 +5036,18 @@ dependencies = [ "lock_api", ] +[[package]] +name = "spm_precompiled" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5851699c4033c63636f7ea4cf7b7c1f1bf06d0cc03cfb42e711de5a5c46cf326" +dependencies = [ + "base64 0.13.1", + "nom", + "serde", + "unicode-segmentation", +] + [[package]] name = "stable_deref_trait" version = "1.2.0" @@ -4538,6 +5141,27 @@ dependencies = [ "winapi", ] +[[package]] +name = "system-configuration" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" +dependencies = [ + "bitflags 1.3.2", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tar" version = "0.4.40" @@ -4560,14 +5184,14 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.7.1" +version = "3.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc02fddf48964c42031a0b3fe0428320ecf3a73c401040fc0096f97794310651" +checksum = "7ef1adac450ad7f4b3c28589471ade84f25f731a7a0fe30d71dfa9f60fd808e5" dependencies = [ "cfg-if", "fastrand", - "redox_syscall 0.3.5", - "rustix 0.38.7", + "redox_syscall 0.4.1", + "rustix 0.38.26", "windows-sys 0.48.0", ] @@ -4582,24 +5206,39 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.44" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "611040a08a0439f8248d1990b111c95baa9c704c805fa1f62104b39655fd7f90" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.44" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "090198534930841fab3a5d1bb637cde49e339654e606195f8d9c76eeb081dc96" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", "syn 2.0.28", ] +[[package]] +name = "tiktoken-rs" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4427b6b1c6b38215b92dd47a83a0ecc6735573d0a5a4c14acc0ac5b33b28adb" +dependencies = [ + "anyhow", + "base64 0.21.5", + "bstr", + "fancy-regex", + "lazy_static", + "parking_lot", + "rustc-hash", +] + [[package]] name = "time" version = "0.3.30" @@ -4666,12 +5305,43 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] -name = "tokio" -version = "1.29.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da" +name = "tokenizers" +version = "0.14.1" +source = "git+https://github.com/huggingface/tokenizers.git?tag=v0.14.1#6357206cdcce4d78ffb1e0372feb456caea09375" +dependencies = [ + "aho-corasick", + "clap", + "derive_builder", + "esaxx-rs", + "getrandom", + "indicatif", + "itertools 0.11.0", + "lazy_static", + "log", + "macro_rules_attribute", + "monostate", + "onig", + "paste", + "rand", + "rayon", + "rayon-cond", + "regex", + "regex-syntax", + "serde", + "serde_json", + "spm_precompiled", + "thiserror", + "unicode-normalization-alignments", + "unicode-segmentation", + "unicode_categories", +] + +[[package]] +name = "tokio" +version = "1.34.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9" dependencies = [ - "autocfg", "backtrace", "bytes", "libc", @@ -4680,16 +5350,16 @@ dependencies = [ "parking_lot", "pin-project-lite", "signal-hook-registry", - "socket2", + "socket2 0.5.5", "tokio-macros", "windows-sys 0.48.0", ] [[package]] name = "tokio-macros" -version = "2.1.0" +version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", @@ -4791,9 +5461,21 @@ dependencies = [ "cfg-if", "log", "pin-project-lite", + "tracing-attributes", "tracing-core", ] +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.28", +] + [[package]] name = "tracing-core" version = "0.1.31" @@ -4860,18 +5542,39 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-normalization-alignments" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43f613e4fa046e69818dd287fdc4bc78175ff20331479dab6e1b0f98d57062de" +dependencies = [ + "smallvec", +] + [[package]] name = "unicode-segmentation" version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1dd624098567895118886609431a7c3b8f516e41d30e0643f03d94592a147e36" +[[package]] +name = "unicode-width" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85" + [[package]] name = "unicode-xid" version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +[[package]] +name = "unicode_categories" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" + [[package]] name = "untrusted" version = "0.7.1" @@ -4880,17 +5583,21 @@ checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" [[package]] name = "ureq" -version = "2.7.1" +version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9" +checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97" dependencies = [ - "base64 0.21.2", + "base64 0.21.5", + "flate2", "log", "once_cell", "rustls 0.21.6", - "rustls-webpki 0.100.2", + "rustls-webpki", + "serde", + "serde_json", + "socks", "url", - "webpki-roots 0.23.1", + "webpki-roots 0.25.3", ] [[package]] @@ -5083,6 +5790,15 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +[[package]] +name = "wav" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a65e199c799848b4f997072aa4d673c034f80f40191f97fe2f0a23f410be1609" +dependencies = [ + "riff", +] + [[package]] name = "web-sys" version = "0.3.64" @@ -5114,12 +5830,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.23.1" +version = "0.25.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b03058f88386e5ff5310d9111d53f48b17d732b401aeb83a8d5190f2ac459338" -dependencies = [ - "rustls-webpki 0.100.2", -] +checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10" [[package]] name = "whatlang" @@ -5180,6 +5893,15 @@ dependencies = [ "windows-targets 0.48.1", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.0", +] + [[package]] name = "windows-targets" version = "0.42.2" @@ -5210,6 +5932,21 @@ dependencies = [ "windows_x86_64_msvc 0.48.0", ] +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm 0.52.0", + "windows_aarch64_msvc 0.52.0", + "windows_i686_gnu 0.52.0", + "windows_i686_msvc 0.52.0", + "windows_x86_64_gnu 0.52.0", + "windows_x86_64_gnullvm 0.52.0", + "windows_x86_64_msvc 0.52.0", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.2" @@ -5222,6 +5959,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91ae572e1b79dba883e0d315474df7305d12f569b400fcf90581b06062f7e1bc" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + [[package]] name = "windows_aarch64_msvc" version = "0.42.2" @@ -5234,6 +5977,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b2ef27e0d7bdfcfc7b868b317c1d32c641a6fe4629c171b8928c7b08d98d7cf3" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + [[package]] name = "windows_i686_gnu" version = "0.42.2" @@ -5246,6 +5995,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "622a1962a7db830d6fd0a69683c80a18fda201879f0f447f065a3b7467daa241" +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + [[package]] name = "windows_i686_msvc" version = "0.42.2" @@ -5258,6 +6013,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4542c6e364ce21bf45d69fdd2a8e455fa38d316158cfd43b3ac1c5b1b19f8e00" +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + [[package]] name = "windows_x86_64_gnu" version = "0.42.2" @@ -5270,6 +6031,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca2b8a661f7628cbd23440e50b05d705db3686f894fc9580820623656af974b1" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + [[package]] name = "windows_x86_64_gnullvm" version = "0.42.2" @@ -5282,6 +6049,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7896dbc1f41e08872e9d5e8f8baa8fdd2677f29468c4e156210174edc7f7b953" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + [[package]] name = "windows_x86_64_msvc" version = "0.42.2" @@ -5294,6 +6067,12 @@ version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a515f5799fe4961cb532f983ce2b23082366b898e52ffbce459c86f67c8378a" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" + [[package]] name = "winnow" version = "0.5.4" @@ -5305,11 +6084,12 @@ dependencies = [ [[package]] name = "winreg" -version = "0.10.1" +version = "0.50.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d" +checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1" dependencies = [ - "winapi", + "cfg-if", + "windows-sys 0.48.0", ] [[package]] diff --git a/dump/src/lib.rs b/dump/src/lib.rs index 15b281c41..be0053a7c 100644 --- a/dump/src/lib.rs +++ b/dump/src/lib.rs @@ -276,6 +276,7 @@ pub(crate) mod test { ), }), pagination: Setting::NotSet, + embedders: Setting::NotSet, _kind: std::marker::PhantomData, }; settings.check() diff --git a/dump/src/reader/compat/v5_to_v6.rs b/dump/src/reader/compat/v5_to_v6.rs index 8a0d6e5e1..9351ae70d 100644 --- a/dump/src/reader/compat/v5_to_v6.rs +++ b/dump/src/reader/compat/v5_to_v6.rs @@ -378,6 +378,7 @@ impl From> for v6::Settings { v5::Setting::Reset => v6::Setting::Reset, v5::Setting::NotSet => v6::Setting::NotSet, }, + embedders: v6::Setting::NotSet, _kind: std::marker::PhantomData, } } diff --git a/index-scheduler/src/batch.rs b/index-scheduler/src/batch.rs index 94a8b3f07..cf8544ae7 100644 --- a/index-scheduler/src/batch.rs +++ b/index-scheduler/src/batch.rs @@ -1202,6 +1202,10 @@ impl IndexScheduler { let config = IndexDocumentsConfig { update_method: method, ..Default::default() }; + let embedder_configs = index.embedding_configs(index_wtxn)?; + // TODO: consider Arc'ing the map too (we only need read access + we'll be cloning it multiple times, so really makes sense) + let embedders = self.embedders(embedder_configs)?; + let mut builder = milli::update::IndexDocuments::new( index_wtxn, index, @@ -1220,6 +1224,8 @@ impl IndexScheduler { let (new_builder, user_result) = builder.add_documents(reader)?; builder = new_builder; + builder = builder.with_embedders(embedders.clone()); + let received_documents = if let Some(Details::DocumentAdditionOrUpdate { received_documents, @@ -1345,6 +1351,9 @@ impl IndexScheduler { for (task, (_, settings)) in tasks.iter_mut().zip(settings) { let checked_settings = settings.clone().check(); + if matches!(checked_settings.embedders, milli::update::Setting::Set(_)) { + self.features().check_vector("Passing `embedders` in settings")? + } if checked_settings.proximity_precision.set().is_some() { self.features.features().check_proximity_precision()?; } diff --git a/index-scheduler/src/features.rs b/index-scheduler/src/features.rs index ae2823c30..d6ce3cae4 100644 --- a/index-scheduler/src/features.rs +++ b/index-scheduler/src/features.rs @@ -56,12 +56,12 @@ impl RoFeatures { } } - pub fn check_vector(&self) -> Result<()> { + pub fn check_vector(&self, disabled_action: &'static str) -> Result<()> { if self.runtime.vector_store { Ok(()) } else { Err(FeatureNotEnabledError { - disabled_action: "Passing `vector` as a query parameter", + disabled_action, feature: "vector store", issue_link: "https://github.com/meilisearch/product/discussions/677", } diff --git a/index-scheduler/src/insta_snapshot.rs b/index-scheduler/src/insta_snapshot.rs index bd8fa5148..ddb9e934a 100644 --- a/index-scheduler/src/insta_snapshot.rs +++ b/index-scheduler/src/insta_snapshot.rs @@ -41,6 +41,7 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String { planned_failures: _, run_loop_iteration: _, currently_updating_index: _, + embedders: _, } = scheduler; let rtxn = env.read_txn().unwrap(); diff --git a/index-scheduler/src/lib.rs b/index-scheduler/src/lib.rs index a1b6497d9..b9b360fa4 100644 --- a/index-scheduler/src/lib.rs +++ b/index-scheduler/src/lib.rs @@ -52,6 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128}; use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn}; use meilisearch_types::milli::documents::DocumentsBatchBuilder; use meilisearch_types::milli::update::IndexerConfig; +use meilisearch_types::milli::vector::{Embedder, EmbedderOptions, EmbeddingConfigs}; use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32}; use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task}; use puffin::FrameView; @@ -341,6 +342,8 @@ pub struct IndexScheduler { /// so that a handle to the index is available from other threads (search) in an optimized manner. currently_updating_index: Arc>>, + embedders: Arc>>>, + // ================= test // The next entry is dedicated to the tests. /// Provide a way to set a breakpoint in multiple part of the scheduler. @@ -386,6 +389,7 @@ impl IndexScheduler { auth_path: self.auth_path.clone(), version_file_path: self.version_file_path.clone(), currently_updating_index: self.currently_updating_index.clone(), + embedders: self.embedders.clone(), #[cfg(test)] test_breakpoint_sdr: self.test_breakpoint_sdr.clone(), #[cfg(test)] @@ -484,6 +488,7 @@ impl IndexScheduler { auth_path: options.auth_path, version_file_path: options.version_file_path, currently_updating_index: Arc::new(RwLock::new(None)), + embedders: Default::default(), #[cfg(test)] test_breakpoint_sdr, @@ -1333,6 +1338,40 @@ impl IndexScheduler { } } + // TODO: consider using a type alias or a struct embedder/template + pub fn embedders( + &self, + embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>, + ) -> Result { + let res: Result<_> = embedding_configs + .into_iter() + .map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| { + let prompt = + Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?); + // optimistically return existing embedder + { + let embedders = self.embedders.read().unwrap(); + if let Some(embedder) = embedders.get(&embedder_options) { + return Ok((name, (embedder.clone(), prompt))); + } + } + + // add missing embedder + let embedder = Arc::new( + Embedder::new(embedder_options.clone()) + .map_err(meilisearch_types::milli::vector::Error::from) + .map_err(meilisearch_types::milli::Error::from)?, + ); + { + let mut embedders = self.embedders.write().unwrap(); + embedders.insert(embedder_options, embedder.clone()); + } + Ok((name, (embedder, prompt))) + }) + .collect(); + res.map(EmbeddingConfigs::new) + } + /// Blocks the thread until the test handle asks to progress to/through this breakpoint. /// /// Two messages are sent through the channel for each breakpoint. diff --git a/meilisearch-types/src/deserr/mod.rs b/meilisearch-types/src/deserr/mod.rs index df304cc2f..537b24574 100644 --- a/meilisearch-types/src/deserr/mod.rs +++ b/meilisearch-types/src/deserr/mod.rs @@ -188,3 +188,4 @@ merge_with_error_impl_take_error_message!(ParseOffsetDateTimeError); merge_with_error_impl_take_error_message!(ParseTaskKindError); merge_with_error_impl_take_error_message!(ParseTaskStatusError); merge_with_error_impl_take_error_message!(IndexUidFormatError); +merge_with_error_impl_take_error_message!(InvalidSearchSemanticRatio); diff --git a/meilisearch-types/src/error.rs b/meilisearch-types/src/error.rs index b1dc6b777..62591e991 100644 --- a/meilisearch-types/src/error.rs +++ b/meilisearch-types/src/error.rs @@ -222,6 +222,8 @@ InvalidVectorsType , InvalidRequest , BAD_REQUEST ; InvalidDocumentId , InvalidRequest , BAD_REQUEST ; InvalidDocumentLimit , InvalidRequest , BAD_REQUEST ; InvalidDocumentOffset , InvalidRequest , BAD_REQUEST ; +InvalidEmbedder , InvalidRequest , BAD_REQUEST ; +InvalidHybridQuery , InvalidRequest , BAD_REQUEST ; InvalidIndexLimit , InvalidRequest , BAD_REQUEST ; InvalidIndexOffset , InvalidRequest , BAD_REQUEST ; InvalidIndexPrimaryKey , InvalidRequest , BAD_REQUEST ; @@ -233,6 +235,7 @@ InvalidSearchAttributesToRetrieve , InvalidRequest , BAD_REQUEST ; InvalidSearchCropLength , InvalidRequest , BAD_REQUEST ; InvalidSearchCropMarker , InvalidRequest , BAD_REQUEST ; InvalidSearchFacets , InvalidRequest , BAD_REQUEST ; +InvalidSearchSemanticRatio , InvalidRequest , BAD_REQUEST ; InvalidFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; InvalidSearchFilter , InvalidRequest , BAD_REQUEST ; InvalidSearchHighlightPostTag , InvalidRequest , BAD_REQUEST ; @@ -256,6 +259,7 @@ InvalidSettingsProximityPrecision , InvalidRequest , BAD_REQUEST ; InvalidSettingsFaceting , InvalidRequest , BAD_REQUEST ; InvalidSettingsFilterableAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsPagination , InvalidRequest , BAD_REQUEST ; +InvalidSettingsEmbedders , InvalidRequest , BAD_REQUEST ; InvalidSettingsRankingRules , InvalidRequest , BAD_REQUEST ; InvalidSettingsSearchableAttributes , InvalidRequest , BAD_REQUEST ; InvalidSettingsSortableAttributes , InvalidRequest , BAD_REQUEST ; @@ -295,15 +299,18 @@ MissingFacetSearchFacetName , InvalidRequest , BAD_REQUEST ; MissingIndexUid , InvalidRequest , BAD_REQUEST ; MissingMasterKey , Auth , UNAUTHORIZED ; MissingPayload , InvalidRequest , BAD_REQUEST ; +MissingSearchHybrid , InvalidRequest , BAD_REQUEST ; MissingSwapIndexes , InvalidRequest , BAD_REQUEST ; MissingTaskFilters , InvalidRequest , BAD_REQUEST ; NoSpaceLeftOnDevice , System , UNPROCESSABLE_ENTITY; PayloadTooLarge , InvalidRequest , PAYLOAD_TOO_LARGE ; TaskNotFound , InvalidRequest , NOT_FOUND ; TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ; +TooManyVectors , InvalidRequest , BAD_REQUEST ; UnretrievableDocument , Internal , BAD_REQUEST ; UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ; -UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE +UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ; +VectorEmbeddingError , InvalidRequest , BAD_REQUEST } impl ErrorCode for JoinError { @@ -336,6 +343,10 @@ impl ErrorCode for milli::Error { UserError::InvalidDocumentId { .. } | UserError::TooManyDocumentIds { .. } => { Code::InvalidDocumentId } + UserError::MissingDocumentField(_) => Code::InvalidDocumentFields, + UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders, + UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders, + UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders, UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound, UserError::MultiplePrimaryKeyCandidatesFound { .. } => { Code::IndexPrimaryKeyMultipleCandidatesFound @@ -353,11 +364,15 @@ impl ErrorCode for milli::Error { UserError::CriterionError(_) => Code::InvalidSettingsRankingRules, UserError::InvalidGeoField { .. } => Code::InvalidDocumentGeoField, UserError::InvalidVectorDimensions { .. } => Code::InvalidVectorDimensions, + UserError::InvalidVectorsMapType { .. } => Code::InvalidVectorsType, UserError::InvalidVectorsType { .. } => Code::InvalidVectorsType, + UserError::TooManyVectors(_, _) => Code::TooManyVectors, UserError::SortError(_) => Code::InvalidSearchSort, UserError::InvalidMinTypoWordLenSetting(_, _) => { Code::InvalidSettingsTypoTolerance } + UserError::InvalidEmbedder(_) => Code::InvalidEmbedder, + UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError, } } } @@ -445,6 +460,15 @@ impl fmt::Display for DeserrParseIntError { } } +impl fmt::Display for deserr_codes::InvalidSearchSemanticRatio { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`." + ) + } +} + #[macro_export] macro_rules! internal_error { ($target:ty : $($other:path), *) => { diff --git a/meilisearch-types/src/settings.rs b/meilisearch-types/src/settings.rs index 487354b8e..da06d5264 100644 --- a/meilisearch-types/src/settings.rs +++ b/meilisearch-types/src/settings.rs @@ -199,6 +199,10 @@ pub struct Settings { #[deserr(default, error = DeserrJsonError)] pub pagination: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default, error = DeserrJsonError)] + pub embedders: Setting>>, + #[serde(skip)] #[deserr(skip)] pub _kind: PhantomData, @@ -222,6 +226,7 @@ impl Settings { typo_tolerance: Setting::Reset, faceting: Setting::Reset, pagination: Setting::Reset, + embedders: Setting::Reset, _kind: PhantomData, } } @@ -243,6 +248,7 @@ impl Settings { typo_tolerance, faceting, pagination, + embedders, .. } = self; @@ -262,6 +268,7 @@ impl Settings { typo_tolerance, faceting, pagination, + embedders, _kind: PhantomData, } } @@ -307,6 +314,7 @@ impl Settings { typo_tolerance: self.typo_tolerance, faceting: self.faceting, pagination: self.pagination, + embedders: self.embedders, _kind: PhantomData, } } @@ -490,6 +498,12 @@ pub fn apply_settings_to_builder( Setting::Reset => builder.reset_pagination_max_total_hits(), Setting::NotSet => (), } + + match settings.embedders.clone() { + Setting::Set(value) => builder.set_embedder_settings(value), + Setting::Reset => builder.reset_embedder_settings(), + Setting::NotSet => (), + } } pub fn settings( @@ -571,6 +585,12 @@ pub fn settings( ), }; + let embedders = index + .embedding_configs(rtxn)? + .into_iter() + .map(|(name, config)| (name, Setting::Set(config.into()))) + .collect(); + Ok(Settings { displayed_attributes: match displayed_attributes { Some(attrs) => Setting::Set(attrs), @@ -599,6 +619,7 @@ pub fn settings( typo_tolerance: Setting::Set(typo_tolerance), faceting: Setting::Set(faceting), pagination: Setting::Set(pagination), + embedders: Setting::Set(embedders), _kind: PhantomData, }) } @@ -747,6 +768,7 @@ pub(crate) mod test { typo_tolerance: Setting::NotSet, faceting: Setting::NotSet, pagination: Setting::NotSet, + embedders: Setting::NotSet, _kind: PhantomData::, }; @@ -772,6 +794,7 @@ pub(crate) mod test { typo_tolerance: Setting::NotSet, faceting: Setting::NotSet, pagination: Setting::NotSet, + embedders: Setting::NotSet, _kind: PhantomData::, }; diff --git a/meilisearch/src/analytics/segment_analytics.rs b/meilisearch/src/analytics/segment_analytics.rs index f75516731..1ad277c28 100644 --- a/meilisearch/src/analytics/segment_analytics.rs +++ b/meilisearch/src/analytics/segment_analytics.rs @@ -36,7 +36,7 @@ use crate::routes::{create_all_stats, Stats}; use crate::search::{ FacetSearchResult, MatchingStrategy, SearchQuery, SearchQueryWithIndex, SearchResult, DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, - DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, + DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEMANTIC_RATIO, }; use crate::Opt; @@ -586,6 +586,11 @@ pub struct SearchAggregator { // vector // The maximum number of floats in a vector request max_vector_size: usize, + // Whether the semantic ratio passed to a hybrid search equals the default ratio. + semantic_ratio: bool, + // Whether a non-default embedder was specified + embedder: bool, + hybrid: bool, // every time a search is done, we increment the counter linked to the used settings matching_strategy: HashMap, @@ -639,6 +644,7 @@ impl SearchAggregator { crop_marker, matching_strategy, attributes_to_search_on, + hybrid, } = query; let mut ret = Self::default(); @@ -712,6 +718,12 @@ impl SearchAggregator { ret.show_ranking_score = *show_ranking_score; ret.show_ranking_score_details = *show_ranking_score_details; + if let Some(hybrid) = hybrid { + ret.semantic_ratio = hybrid.semantic_ratio != DEFAULT_SEMANTIC_RATIO(); + ret.embedder = hybrid.embedder.is_some(); + ret.hybrid = true; + } + ret } @@ -765,6 +777,9 @@ impl SearchAggregator { facets_total_number_of_facets, show_ranking_score, show_ranking_score_details, + semantic_ratio, + embedder, + hybrid, } = other; if self.timestamp.is_none() { @@ -810,6 +825,9 @@ impl SearchAggregator { // vector self.max_vector_size = self.max_vector_size.max(max_vector_size); + self.semantic_ratio |= semantic_ratio; + self.hybrid |= hybrid; + self.embedder |= embedder; // pagination self.max_limit = self.max_limit.max(max_limit); @@ -878,6 +896,9 @@ impl SearchAggregator { facets_total_number_of_facets, show_ranking_score, show_ranking_score_details, + semantic_ratio, + embedder, + hybrid, } = self; if total_received == 0 { @@ -917,6 +938,11 @@ impl SearchAggregator { "vector": { "max_vector_size": max_vector_size, }, + "hybrid": { + "enabled": hybrid, + "semantic_ratio": semantic_ratio, + "embedder": embedder, + }, "pagination": { "max_limit": max_limit, "max_offset": max_offset, @@ -1012,6 +1038,7 @@ impl MultiSearchAggregator { crop_marker: _, matching_strategy: _, attributes_to_search_on: _, + hybrid: _, } = query; index_uid.as_str() @@ -1158,6 +1185,7 @@ impl FacetSearchAggregator { filter, matching_strategy, attributes_to_search_on, + hybrid, } = query; let mut ret = Self::default(); @@ -1171,7 +1199,8 @@ impl FacetSearchAggregator { || vector.is_some() || filter.is_some() || *matching_strategy != MatchingStrategy::default() - || attributes_to_search_on.is_some(); + || attributes_to_search_on.is_some() + || hybrid.is_some(); ret } diff --git a/meilisearch/src/error.rs b/meilisearch/src/error.rs index ca10c4593..3bd8f3edd 100644 --- a/meilisearch/src/error.rs +++ b/meilisearch/src/error.rs @@ -51,6 +51,8 @@ pub enum MeilisearchHttpError { DocumentFormat(#[from] DocumentFormatError), #[error(transparent)] Join(#[from] JoinError), + #[error("Invalid request: missing `hybrid` parameter when both `q` and `vector` are present.")] + MissingSearchHybrid, } impl ErrorCode for MeilisearchHttpError { @@ -74,6 +76,7 @@ impl ErrorCode for MeilisearchHttpError { MeilisearchHttpError::FileStore(_) => Code::Internal, MeilisearchHttpError::DocumentFormat(e) => e.error_code(), MeilisearchHttpError::Join(_) => Code::Internal, + MeilisearchHttpError::MissingSearchHybrid => Code::MissingSearchHybrid, } } } diff --git a/meilisearch/src/main.rs b/meilisearch/src/main.rs index 246d62c3b..ddd37bbb6 100644 --- a/meilisearch/src/main.rs +++ b/meilisearch/src/main.rs @@ -19,7 +19,11 @@ static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; /// does all the setup before meilisearch is launched fn setup(opt: &Opt) -> anyhow::Result<()> { let mut log_builder = env_logger::Builder::new(); - log_builder.parse_filters(&opt.log_level.to_string()); + let log_filters = format!( + "{},h2=warn,hyper=warn,tokio_util=warn,tracing=warn,rustls=warn,mio=warn,reqwest=warn", + opt.log_level + ); + log_builder.parse_filters(&log_filters); log_builder.init(); diff --git a/meilisearch/src/routes/indexes/facet_search.rs b/meilisearch/src/routes/indexes/facet_search.rs index 142a424c0..4b5d4d78a 100644 --- a/meilisearch/src/routes/indexes/facet_search.rs +++ b/meilisearch/src/routes/indexes/facet_search.rs @@ -13,9 +13,9 @@ use crate::analytics::{Analytics, FacetSearchAggregator}; use crate::extractors::authentication::policies::*; use crate::extractors::authentication::GuardedData; use crate::search::{ - add_search_rules, perform_facet_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, - DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, - DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, + add_search_rules, perform_facet_search, HybridQuery, MatchingStrategy, SearchQuery, + DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, + DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, }; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -36,6 +36,8 @@ pub struct FacetSearchQuery { pub q: Option, #[deserr(default, error = DeserrJsonError)] pub vector: Option>, + #[deserr(default, error = DeserrJsonError)] + pub hybrid: Option, #[deserr(default, error = DeserrJsonError)] pub filter: Option, #[deserr(default, error = DeserrJsonError, default)] @@ -95,6 +97,7 @@ impl From for SearchQuery { filter, matching_strategy, attributes_to_search_on, + hybrid, } = value; SearchQuery { @@ -119,6 +122,7 @@ impl From for SearchQuery { matching_strategy, vector, attributes_to_search_on, + hybrid, } } } diff --git a/meilisearch/src/routes/indexes/search.rs b/meilisearch/src/routes/indexes/search.rs index 5a0a9e92b..c474d285e 100644 --- a/meilisearch/src/routes/indexes/search.rs +++ b/meilisearch/src/routes/indexes/search.rs @@ -2,12 +2,14 @@ use actix_web::web::Data; use actix_web::{web, HttpRequest, HttpResponse}; use deserr::actix_web::{AwebJson, AwebQueryParameter}; use index_scheduler::IndexScheduler; -use log::debug; +use log::{debug, warn}; use meilisearch_types::deserr::query_params::Param; use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError}; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::error::ResponseError; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::milli; +use meilisearch_types::milli::vector::DistributionShift; use meilisearch_types::serde_cs::vec::CS; use serde_json::Value; @@ -16,9 +18,9 @@ use crate::extractors::authentication::policies::*; use crate::extractors::authentication::GuardedData; use crate::extractors::sequential_extractor::SeqHandler; use crate::search::{ - add_search_rules, perform_search, MatchingStrategy, SearchQuery, DEFAULT_CROP_LENGTH, - DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, DEFAULT_HIGHLIGHT_PRE_TAG, - DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, + add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio, + DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG, + DEFAULT_HIGHLIGHT_PRE_TAG, DEFAULT_SEARCH_LIMIT, DEFAULT_SEARCH_OFFSET, DEFAULT_SEMANTIC_RATIO, }; pub fn configure(cfg: &mut web::ServiceConfig) { @@ -74,6 +76,31 @@ pub struct SearchQueryGet { matching_strategy: MatchingStrategy, #[deserr(default, error = DeserrQueryParamError)] pub attributes_to_search_on: Option>, + #[deserr(default, error = DeserrQueryParamError)] + pub hybrid_embedder: Option, + #[deserr(default, error = DeserrQueryParamError)] + pub hybrid_semantic_ratio: Option, +} + +#[derive(Debug, Clone, Copy, Default, PartialEq, deserr::Deserr)] +#[deserr(try_from(String) = TryFrom::try_from -> InvalidSearchSemanticRatio)] +pub struct SemanticRatioGet(SemanticRatio); + +impl std::convert::TryFrom for SemanticRatioGet { + type Error = InvalidSearchSemanticRatio; + + fn try_from(s: String) -> Result { + let f: f32 = s.parse().map_err(|_| InvalidSearchSemanticRatio)?; + Ok(SemanticRatioGet(SemanticRatio::try_from(f)?)) + } +} + +impl std::ops::Deref for SemanticRatioGet { + type Target = SemanticRatio; + + fn deref(&self) -> &Self::Target { + &self.0 + } } impl From for SearchQuery { @@ -86,6 +113,20 @@ impl From for SearchQuery { None => None, }; + let hybrid = match (other.hybrid_embedder, other.hybrid_semantic_ratio) { + (None, None) => None, + (None, Some(semantic_ratio)) => { + Some(HybridQuery { semantic_ratio: *semantic_ratio, embedder: None }) + } + (Some(embedder), None) => Some(HybridQuery { + semantic_ratio: DEFAULT_SEMANTIC_RATIO(), + embedder: Some(embedder), + }), + (Some(embedder), Some(semantic_ratio)) => { + Some(HybridQuery { semantic_ratio: *semantic_ratio, embedder: Some(embedder) }) + } + }; + Self { q: other.q, vector: other.vector.map(CS::into_inner), @@ -108,6 +149,7 @@ impl From for SearchQuery { crop_marker: other.crop_marker, matching_strategy: other.matching_strategy, attributes_to_search_on: other.attributes_to_search_on.map(|o| o.into_iter().collect()), + hybrid, } } } @@ -158,8 +200,12 @@ pub async fn search_with_url_query( let index = index_scheduler.index(&index_uid)?; let features = index_scheduler.features(); + + let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; + let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; + tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) + .await?; if let Ok(ref search_result) = search_result { aggregate.succeed(search_result); } @@ -193,8 +239,12 @@ pub async fn search_with_post( let index = index_scheduler.index(&index_uid)?; let features = index_scheduler.features(); + + let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?; + let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?; + tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution)) + .await?; if let Ok(ref search_result) = search_result { aggregate.succeed(search_result); } @@ -206,6 +256,80 @@ pub async fn search_with_post( Ok(HttpResponse::Ok().json(search_result)) } +pub async fn embed( + query: &mut SearchQuery, + index_scheduler: &IndexScheduler, + index: &milli::Index, +) -> Result, ResponseError> { + match (&query.hybrid, &query.vector, &query.q) { + (Some(HybridQuery { semantic_ratio: _, embedder }), None, Some(q)) + if !q.trim().is_empty() => + { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedders = index_scheduler.embedders(embedder_configs)?; + + let embedder = if let Some(embedder_name) = embedder { + embedders.get(embedder_name) + } else { + embedders.get_default() + }; + + let embedder = embedder + .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) + .map_err(milli::Error::from)? + .0; + + let distribution = embedder.distribution(); + + let embeddings = embedder + .embed(vec![q.to_owned()]) + .await + .map_err(milli::vector::Error::from) + .map_err(milli::Error::from)? + .pop() + .expect("No vector returned from embedding"); + + if embeddings.iter().nth(1).is_some() { + warn!("Ignoring embeddings past the first one in long search query"); + query.vector = Some(embeddings.iter().next().unwrap().to_vec()); + } else { + query.vector = Some(embeddings.into_inner()); + } + Ok(distribution) + } + (Some(hybrid), vector, _) => { + let embedder_configs = index.embedding_configs(&index.read_txn()?)?; + let embedders = index_scheduler.embedders(embedder_configs)?; + + let embedder = if let Some(embedder_name) = &hybrid.embedder { + embedders.get(embedder_name) + } else { + embedders.get_default() + }; + + let embedder = embedder + .ok_or(milli::UserError::InvalidEmbedder("default".to_owned())) + .map_err(milli::Error::from)? + .0; + + if let Some(vector) = vector { + if vector.len() != embedder.dimensions() { + return Err(meilisearch_types::milli::Error::UserError( + meilisearch_types::milli::UserError::InvalidVectorDimensions { + expected: embedder.dimensions(), + found: vector.len(), + }, + ) + .into()); + } + } + + Ok(embedder.distribution()) + } + _ => Ok(None), + } +} + #[cfg(test)] mod test { use super::*; diff --git a/meilisearch/src/routes/indexes/settings.rs b/meilisearch/src/routes/indexes/settings.rs index c22db24f0..024b7e7c0 100644 --- a/meilisearch/src/routes/indexes/settings.rs +++ b/meilisearch/src/routes/indexes/settings.rs @@ -7,6 +7,7 @@ use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::ResponseError; use meilisearch_types::facet_values_sort::FacetValuesSort; use meilisearch_types::index_uid::IndexUid; +use meilisearch_types::milli::update::Setting; use meilisearch_types::settings::{settings, RankingRuleView, Settings, Unchecked}; use meilisearch_types::tasks::KindWithContent; use serde_json::json; @@ -545,6 +546,67 @@ make_setting_route!( } ); +make_setting_route!( + "/embedders", + patch, + std::collections::BTreeMap>, + meilisearch_types::deserr::DeserrJsonError< + meilisearch_types::error::deserr_codes::InvalidSettingsEmbedders, + >, + embedders, + "embedders", + analytics, + |setting: &Option>>, req: &HttpRequest| { + + + analytics.publish( + "Embedders Updated".to_string(), + serde_json::json!({"embedders": crate::routes::indexes::settings::embedder_analytics(setting.as_ref())}), + Some(req), + ); + } +); + +fn embedder_analytics( + setting: Option< + &std::collections::BTreeMap< + String, + Setting, + >, + >, +) -> serde_json::Value { + let mut sources = std::collections::HashSet::new(); + + if let Some(s) = &setting { + for source in s + .values() + .filter_map(|config| config.clone().set()) + .filter_map(|config| config.embedder_options.set()) + { + use meilisearch_types::milli::vector::settings::EmbedderSettings; + match source { + EmbedderSettings::OpenAi(_) => sources.insert("openAi"), + EmbedderSettings::HuggingFace(_) => sources.insert("huggingFace"), + EmbedderSettings::UserProvided(_) => sources.insert("userProvided"), + }; + } + }; + + let document_template_used = setting.as_ref().map(|map| { + map.values() + .filter_map(|config| config.clone().set()) + .any(|config| config.document_template.set().is_some()) + }); + + json!( + { + "total": setting.as_ref().map(|s| s.len()), + "sources": sources, + "document_template_used": document_template_used, + } + ) +} + macro_rules! generate_configure { ($($mod:ident),*) => { pub fn configure(cfg: &mut web::ServiceConfig) { @@ -574,7 +636,8 @@ generate_configure!( ranking_rules, typo_tolerance, pagination, - faceting + faceting, + embedders ); pub async fn update_all( @@ -681,6 +744,7 @@ pub async fn update_all( "synonyms": { "total": new_settings.synonyms.as_ref().set().map(|synonyms| synonyms.len()), }, + "embedders": crate::routes::indexes::settings::embedder_analytics(new_settings.embedders.as_ref().set()) }), Some(&req), ); diff --git a/meilisearch/src/routes/multi_search.rs b/meilisearch/src/routes/multi_search.rs index bcb8bb2a1..8e81688e6 100644 --- a/meilisearch/src/routes/multi_search.rs +++ b/meilisearch/src/routes/multi_search.rs @@ -13,6 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator}; use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::{AuthenticationError, GuardedData}; use crate::extractors::sequential_extractor::SeqHandler; +use crate::routes::indexes::search::embed; use crate::search::{ add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex, }; @@ -74,10 +75,15 @@ pub async fn multi_search_with_post( }) .with_index(query_index)?; - let search_result = - tokio::task::spawn_blocking(move || perform_search(&index, query, features)) - .await - .with_index(query_index)?; + let distribution = embed(&mut query, index_scheduler.get_ref(), &index) + .await + .with_index(query_index)?; + + let search_result = tokio::task::spawn_blocking(move || { + perform_search(&index, query, features, distribution) + }) + .await + .with_index(query_index)?; search_results.push(SearchResultWithIndex { index_uid: index_uid.into_inner(), diff --git a/meilisearch/src/search.rs b/meilisearch/src/search.rs index 41f073b48..b5dba8a58 100644 --- a/meilisearch/src/search.rs +++ b/meilisearch/src/search.rs @@ -7,24 +7,21 @@ use deserr::Deserr; use either::Either; use index_scheduler::RoFeatures; use indexmap::IndexMap; -use log::warn; use meilisearch_auth::IndexSearchRules; use meilisearch_types::deserr::DeserrJsonError; use meilisearch_types::error::deserr_codes::*; use meilisearch_types::heed::RoTxn; use meilisearch_types::index_uid::IndexUid; -use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy}; -use meilisearch_types::milli::{ - dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues, -}; +use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy}; +use meilisearch_types::milli::vector::DistributionShift; +use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues}; use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS; use meilisearch_types::{milli, Document}; use milli::tokenizer::TokenizerBuilder; use milli::{ AscDesc, FieldId, FieldsIdsMap, Filter, FormatOptions, Index, MatchBounds, MatcherBuilder, - SortError, TermsMatchingStrategy, VectorOrArrayOfVectors, DEFAULT_VALUES_PER_FACET, + SortError, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET, }; -use ordered_float::OrderedFloat; use regex::Regex; use serde::Serialize; use serde_json::{json, Value}; @@ -39,6 +36,7 @@ pub const DEFAULT_CROP_LENGTH: fn() -> usize = || 10; pub const DEFAULT_CROP_MARKER: fn() -> String = || "…".to_string(); pub const DEFAULT_HIGHLIGHT_PRE_TAG: fn() -> String = || "".to_string(); pub const DEFAULT_HIGHLIGHT_POST_TAG: fn() -> String = || "".to_string(); +pub const DEFAULT_SEMANTIC_RATIO: fn() -> SemanticRatio = || SemanticRatio(0.5); #[derive(Debug, Clone, Default, PartialEq, Deserr)] #[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] @@ -47,6 +45,8 @@ pub struct SearchQuery { pub q: Option, #[deserr(default, error = DeserrJsonError)] pub vector: Option>, + #[deserr(default, error = DeserrJsonError)] + pub hybrid: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -87,6 +87,48 @@ pub struct SearchQuery { pub attributes_to_search_on: Option>, } +#[derive(Debug, Clone, Default, PartialEq, Deserr)] +#[deserr(error = DeserrJsonError, rename_all = camelCase, deny_unknown_fields)] +pub struct HybridQuery { + /// TODO validate that sementic ratio is between 0.0 and 1,0 + #[deserr(default, error = DeserrJsonError, default)] + pub semantic_ratio: SemanticRatio, + #[deserr(default, error = DeserrJsonError, default)] + pub embedder: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Deserr)] +#[deserr(try_from(f32) = TryFrom::try_from -> InvalidSearchSemanticRatio)] +pub struct SemanticRatio(f32); + +impl Default for SemanticRatio { + fn default() -> Self { + DEFAULT_SEMANTIC_RATIO() + } +} + +impl std::convert::TryFrom for SemanticRatio { + type Error = InvalidSearchSemanticRatio; + + fn try_from(f: f32) -> Result { + // the suggested "fix" is: `!(0.0..=1.0).contains(&f)`` which is allegedly less readable + #[allow(clippy::manual_range_contains)] + if f > 1.0 || f < 0.0 { + Err(InvalidSearchSemanticRatio) + } else { + Ok(SemanticRatio(f)) + } + } +} + +impl std::ops::Deref for SemanticRatio { + type Target = f32; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + impl SearchQuery { pub fn is_finite_pagination(&self) -> bool { self.page.or(self.hits_per_page).is_some() @@ -106,6 +148,8 @@ pub struct SearchQueryWithIndex { pub q: Option, #[deserr(default, error = DeserrJsonError)] pub vector: Option>, + #[deserr(default, error = DeserrJsonError)] + pub hybrid: Option, #[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError)] pub offset: usize, #[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError)] @@ -171,6 +215,7 @@ impl SearchQueryWithIndex { crop_marker, matching_strategy, attributes_to_search_on, + hybrid, } = self; ( index_uid, @@ -196,6 +241,7 @@ impl SearchQueryWithIndex { crop_marker, matching_strategy, attributes_to_search_on, + hybrid, // do not use ..Default::default() here, // rather add any missing field from `SearchQuery` to `SearchQueryWithIndex` }, @@ -335,19 +381,44 @@ fn prepare_search<'t>( rtxn: &'t RoTxn, query: &'t SearchQuery, features: RoFeatures, + distribution: Option, ) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> { let mut search = index.search(rtxn); - if query.vector.is_some() && query.q.is_some() { - warn!("Ignoring the query string `q` when used with the `vector` parameter."); + if query.vector.is_some() { + features.check_vector("Passing `vector` as a query parameter")?; } + if query.hybrid.is_some() { + features.check_vector("Passing `hybrid` as a query parameter")?; + } + + if query.hybrid.is_none() && query.q.is_some() && query.vector.is_some() { + return Err(MeilisearchHttpError::MissingSearchHybrid); + } + + search.distribution_shift(distribution); + if let Some(ref vector) = query.vector { - search.vector(vector.clone()); + match &query.hybrid { + // If semantic ratio is 0.0, only the query search will impact the search results, + // skip the vector + Some(hybrid) if *hybrid.semantic_ratio == 0.0 => (), + _otherwise => { + search.vector(vector.clone()); + } + } } - if let Some(ref query) = query.q { - search.query(query); + if let Some(ref q) = query.q { + match &query.hybrid { + // If semantic ratio is 1.0, only the vector search will impact the search results, + // skip the query + Some(hybrid) if *hybrid.semantic_ratio == 1.0 => (), + _otherwise => { + search.query(q); + } + } } if let Some(ref searchable) = query.attributes_to_search_on { @@ -374,8 +445,8 @@ fn prepare_search<'t>( features.check_score_details()?; } - if query.vector.is_some() { - features.check_vector()?; + if let Some(HybridQuery { embedder: Some(embedder), .. }) = &query.hybrid { + search.embedder_name(embedder); } // compute the offset on the limit depending on the pagination mode. @@ -421,15 +492,22 @@ pub fn perform_search( index: &Index, query: SearchQuery, features: RoFeatures, + distribution: Option, ) -> Result { let before_search = Instant::now(); let rtxn = index.read_txn()?; let (search, is_finite_pagination, max_total_hits, offset) = - prepare_search(index, &rtxn, &query, features)?; + prepare_search(index, &rtxn, &query, features, distribution)?; let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } = - search.execute()?; + match &query.hybrid { + Some(hybrid) => match *hybrid.semantic_ratio { + ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?, + ratio => search.execute_hybrid(ratio)?, + }, + None => search.execute()?, + }; let fields_ids_map = index.fields_ids_map(&rtxn).unwrap(); @@ -538,13 +616,17 @@ pub fn perform_search( insert_geo_distance(sort, &mut document); } - let semantic_score = match query.vector.as_ref() { - Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? { - Some(vectors) => compute_semantic_score(vector, vectors)?, - None => None, - }, - None => None, - }; + let mut semantic_score = None; + for details in &score { + if let ScoreDetails::Vector(score_details::Vector { + target_vector: _, + value_similarity: Some((_matching_vector, similarity)), + }) = details + { + semantic_score = Some(*similarity); + break; + } + } let ranking_score = query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter())); @@ -647,8 +729,9 @@ pub fn perform_facet_search( let before_search = Instant::now(); let rtxn = index.read_txn()?; - let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features)?; - let mut facet_search = SearchForFacetValues::new(facet_name, search); + let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features, None)?; + let mut facet_search = + SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some()); if let Some(facet_query) = &facet_query { facet_search.query(facet_query); } @@ -676,18 +759,6 @@ fn insert_geo_distance(sorts: &[String], document: &mut Document) { } } -fn compute_semantic_score(query: &[f32], vectors: Value) -> milli::Result> { - let vectors = serde_json::from_value(vectors) - .map(VectorOrArrayOfVectors::into_array_of_vectors) - .map_err(InternalError::SerdeJson)?; - Ok(vectors - .into_iter() - .flatten() - .map(|v| OrderedFloat(dot_product_similarity(query, &v))) - .max() - .map(OrderedFloat::into_inner)) -} - fn compute_formatted_options( attr_to_highlight: &HashSet, attr_to_crop: &[String], @@ -815,22 +886,6 @@ fn make_document( Ok(document) } -/// Extract the JSON value under the field name specified -/// but doesn't support nested objects. -fn extract_field( - field_name: &str, - field_ids_map: &FieldsIdsMap, - obkv: obkv::KvReaderU16, -) -> Result, MeilisearchHttpError> { - match field_ids_map.id(field_name) { - Some(fid) => match obkv.get(fid) { - Some(value) => Ok(serde_json::from_slice(value).map(Some)?), - None => Ok(None), - }, - None => Ok(None), - } -} - fn format_fields<'a>( document: &Document, field_ids_map: &FieldsIdsMap, diff --git a/meilisearch/tests/dumps/mod.rs b/meilisearch/tests/dumps/mod.rs index 9e949436a..07cfddd37 100644 --- a/meilisearch/tests/dumps/mod.rs +++ b/meilisearch/tests/dumps/mod.rs @@ -77,7 +77,8 @@ async fn import_dump_v1_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -238,7 +239,8 @@ async fn import_dump_v1_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -385,7 +387,8 @@ async fn import_dump_v1_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -518,7 +521,8 @@ async fn import_dump_v2_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -663,7 +667,8 @@ async fn import_dump_v2_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -807,7 +812,8 @@ async fn import_dump_v2_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -940,7 +946,8 @@ async fn import_dump_v3_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1085,7 +1092,8 @@ async fn import_dump_v3_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1229,7 +1237,8 @@ async fn import_dump_v3_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1362,7 +1371,8 @@ async fn import_dump_v4_movie_raw() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1507,7 +1517,8 @@ async fn import_dump_v4_movie_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1651,7 +1662,8 @@ async fn import_dump_v4_rubygems_with_settings() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "### ); @@ -1896,7 +1908,8 @@ async fn import_dump_v6_containing_experimental_features() { }, "pagination": { "maxTotalHits": 1000 - } + }, + "embedders": {} } "###); diff --git a/meilisearch/tests/search/hybrid.rs b/meilisearch/tests/search/hybrid.rs new file mode 100644 index 000000000..c3534c110 --- /dev/null +++ b/meilisearch/tests/search/hybrid.rs @@ -0,0 +1,152 @@ +use meili_snap::snapshot; +use once_cell::sync::Lazy; + +use crate::common::index::Index; +use crate::common::{Server, Value}; +use crate::json; + +async fn index_with_documents<'a>(server: &'a Server, documents: &Value) -> Index<'a> { + let index = server.index("test"); + + let (response, code) = server.set_features(json!({"vectorStore": true})).await; + + meili_snap::snapshot!(code, @"200 OK"); + meili_snap::snapshot!(meili_snap::json_string!(response), @r###" + { + "scoreDetails": false, + "vectorStore": true, + "metrics": false, + "exportPuffinReports": false, + "proximityPrecision": false + } + "###); + + let (response, code) = index + .update_settings( + json!({ "embedders": {"default": {"source": {"userProvided": {"dimensions": 2}}}} }), + ) + .await; + assert_eq!(202, code, "{:?}", response); + index.wait_task(response.uid()).await; + + let (response, code) = index.add_documents(documents.clone(), None).await; + assert_eq!(202, code, "{:?}", response); + index.wait_task(response.uid()).await; + index +} + +static SIMPLE_SEARCH_DOCUMENTS: Lazy = Lazy::new(|| { + json!([ + { + "title": "Shazam!", + "desc": "a Captain Marvel ersatz", + "id": "1", + "_vectors": {"default": [1.0, 3.0]}, + }, + { + "title": "Captain Planet", + "desc": "He's not part of the Marvel Cinematic Universe", + "id": "2", + "_vectors": {"default": [1.0, 2.0]}, + }, + { + "title": "Captain Marvel", + "desc": "a Shazam ersatz", + "id": "3", + "_vectors": {"default": [2.0, 3.0]}, + }]) +}); + +#[actix_rt::test] +async fn simple_search() { + let server = Server::new().await; + let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await; + + let (response, code) = index + .search_post( + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.2}}), + ) + .await; + snapshot!(code, @"200 OK"); + snapshot!(response["hits"], @r###"[{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]}},{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]}},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]}}]"###); + + let (response, code) = index + .search_post( + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 0.8}}), + ) + .await; + snapshot!(code, @"200 OK"); + snapshot!(response["hits"], @r###"[{"title":"Captain Marvel","desc":"a Shazam ersatz","id":"3","_vectors":{"default":[2.0,3.0]},"_semanticScore":0.99029034},{"title":"Captain Planet","desc":"He's not part of the Marvel Cinematic Universe","id":"2","_vectors":{"default":[1.0,2.0]},"_semanticScore":0.97434163},{"title":"Shazam!","desc":"a Captain Marvel ersatz","id":"1","_vectors":{"default":[1.0,3.0]},"_semanticScore":0.9472136}]"###); +} + +#[actix_rt::test] +async fn invalid_semantic_ratio() { + let server = Server::new().await; + let index = index_with_documents(&server, &SIMPLE_SEARCH_DOCUMENTS).await; + + let (response, code) = index + .search_post( + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": 1.2}}), + ) + .await; + snapshot!(code, @"400 Bad Request"); + snapshot!(response, @r###" + { + "message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", + "code": "invalid_search_semantic_ratio", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" + } + "###); + + let (response, code) = index + .search_post( + json!({"q": "Captain", "vector": [1.0, 1.0], "hybrid": {"semanticRatio": -0.8}}), + ) + .await; + snapshot!(code, @"400 Bad Request"); + snapshot!(response, @r###" + { + "message": "Invalid value at `.hybrid.semanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", + "code": "invalid_search_semantic_ratio", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" + } + "###); + + let (response, code) = index + .search_get( + &yaup::to_string( + &json!({"q": "Captain", "vector": [1.0, 1.0], "hybridSemanticRatio": 1.2}), + ) + .unwrap(), + ) + .await; + snapshot!(code, @"400 Bad Request"); + snapshot!(response, @r###" + { + "message": "Invalid value in parameter `hybridSemanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", + "code": "invalid_search_semantic_ratio", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" + } + "###); + + let (response, code) = index + .search_get( + &yaup::to_string( + &json!({"q": "Captain", "vector": [1.0, 1.0], "hybridSemanticRatio": -0.2}), + ) + .unwrap(), + ) + .await; + snapshot!(code, @"400 Bad Request"); + snapshot!(response, @r###" + { + "message": "Invalid value in parameter `hybridSemanticRatio`: the value of `semanticRatio` is invalid, expected a float between `0.0` and `1.0`.", + "code": "invalid_search_semantic_ratio", + "type": "invalid_request", + "link": "https://docs.meilisearch.com/errors#invalid_search_semantic_ratio" + } + "###); +} diff --git a/meilisearch/tests/search/mod.rs b/meilisearch/tests/search/mod.rs index 00678f7d4..133a143fd 100644 --- a/meilisearch/tests/search/mod.rs +++ b/meilisearch/tests/search/mod.rs @@ -6,6 +6,7 @@ mod errors; mod facet_search; mod formatted; mod geo; +mod hybrid; mod multi; mod pagination; mod restrict_searchable; @@ -20,22 +21,27 @@ static DOCUMENTS: Lazy = Lazy::new(|| { { "title": "Shazam!", "id": "287947", + "_vectors": { "manual": [1, 2, 3]}, }, { "title": "Captain Marvel", "id": "299537", + "_vectors": { "manual": [1, 2, 54] }, }, { "title": "Escape Room", "id": "522681", + "_vectors": { "manual": [10, -23, 32] }, }, { "title": "How to Train Your Dragon: The Hidden World", "id": "166428", + "_vectors": { "manual": [-100, 231, 32] }, }, { "title": "Gläss", "id": "450465", + "_vectors": { "manual": [-100, 340, 90] }, } ]) }); @@ -57,6 +63,7 @@ static NESTED_DOCUMENTS: Lazy = Lazy::new(|| { }, ], "cattos": "pésti", + "_vectors": { "manual": [1, 2, 3]}, }, { "id": 654, @@ -69,12 +76,14 @@ static NESTED_DOCUMENTS: Lazy = Lazy::new(|| { }, ], "cattos": ["simba", "pestiféré"], + "_vectors": { "manual": [1, 2, 54] }, }, { "id": 750, "father": "romain", "mother": "michelle", "cattos": ["enigma"], + "_vectors": { "manual": [10, 23, 32] }, }, { "id": 951, @@ -91,6 +100,7 @@ static NESTED_DOCUMENTS: Lazy = Lazy::new(|| { }, ], "cattos": ["moumoute", "gomez"], + "_vectors": { "manual": [10, 23, 32] }, }, ]) }); @@ -802,6 +812,13 @@ async fn experimental_feature_score_details() { { "title": "How to Train Your Dragon: The Hidden World", "id": "166428", + "_vectors": { + "manual": [ + -100, + 231, + 32 + ] + }, "_rankingScoreDetails": { "words": { "order": 0, @@ -823,7 +840,7 @@ async fn experimental_feature_score_details() { "order": 3, "attributeRankingOrderScore": 1.0, "queryWordDistanceScore": 0.8095238095238095, - "score": 0.9365079365079364 + "score": 0.9727891156462584 }, "exactness": { "order": 4, @@ -870,13 +887,92 @@ async fn experimental_feature_vector_store() { meili_snap::snapshot!(code, @"200 OK"); meili_snap::snapshot!(response["vectorStore"], @"true"); + let (response, code) = index + .update_settings(json!({"embedders": { + "manual": { + "source": { + "userProvided": {"dimensions": 3} + } + } + }})) + .await; + + meili_snap::snapshot!(code, @"202 Accepted"); + let response = index.wait_task(response.uid()).await; + + meili_snap::snapshot!(meili_snap::json_string!(response["status"]), @"\"succeeded\""); + let (response, code) = index .search_post(json!({ "vector": [1.0, 2.0, 3.0], })) .await; + meili_snap::snapshot!(code, @"200 OK"); - meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @"[]"); + // vector search returns all documents that don't have vectors in the last bucket, like all sorts + meili_snap::snapshot!(meili_snap::json_string!(response["hits"]), @r###" + [ + { + "title": "Shazam!", + "id": "287947", + "_vectors": { + "manual": [ + 1, + 2, + 3 + ] + }, + "_semanticScore": 1.0 + }, + { + "title": "Captain Marvel", + "id": "299537", + "_vectors": { + "manual": [ + 1, + 2, + 54 + ] + }, + "_semanticScore": 0.9129112 + }, + { + "title": "Gläss", + "id": "450465", + "_vectors": { + "manual": [ + -100, + 340, + 90 + ] + }, + "_semanticScore": 0.8106413 + }, + { + "title": "How to Train Your Dragon: The Hidden World", + "id": "166428", + "_vectors": { + "manual": [ + -100, + 231, + 32 + ] + }, + "_semanticScore": 0.74120104 + }, + { + "title": "Escape Room", + "id": "522681", + "_vectors": { + "manual": [ + 10, + -23, + 32 + ] + } + } + ] + "###); } #[cfg(feature = "default")] @@ -1126,7 +1222,14 @@ async fn simple_search_with_strange_synonyms() { [ { "title": "How to Train Your Dragon: The Hidden World", - "id": "166428" + "id": "166428", + "_vectors": { + "manual": [ + -100, + 231, + 32 + ] + } } ] "###); @@ -1140,7 +1243,14 @@ async fn simple_search_with_strange_synonyms() { [ { "title": "How to Train Your Dragon: The Hidden World", - "id": "166428" + "id": "166428", + "_vectors": { + "manual": [ + -100, + 231, + 32 + ] + } } ] "###); @@ -1154,7 +1264,14 @@ async fn simple_search_with_strange_synonyms() { [ { "title": "How to Train Your Dragon: The Hidden World", - "id": "166428" + "id": "166428", + "_vectors": { + "manual": [ + -100, + 231, + 32 + ] + } } ] "###); diff --git a/meilisearch/tests/search/multi.rs b/meilisearch/tests/search/multi.rs index 0e2e5158d..aeec1bad4 100644 --- a/meilisearch/tests/search/multi.rs +++ b/meilisearch/tests/search/multi.rs @@ -72,7 +72,14 @@ async fn simple_search_single_index() { "hits": [ { "title": "Gläss", - "id": "450465" + "id": "450465", + "_vectors": { + "manual": [ + -100, + 340, + 90 + ] + } } ], "query": "glass", @@ -86,7 +93,14 @@ async fn simple_search_single_index() { "hits": [ { "title": "Captain Marvel", - "id": "299537" + "id": "299537", + "_vectors": { + "manual": [ + 1, + 2, + 54 + ] + } } ], "query": "captain", @@ -177,7 +191,14 @@ async fn simple_search_two_indexes() { "hits": [ { "title": "Gläss", - "id": "450465" + "id": "450465", + "_vectors": { + "manual": [ + -100, + 340, + 90 + ] + } } ], "query": "glass", @@ -203,7 +224,14 @@ async fn simple_search_two_indexes() { "age": 4 } ], - "cattos": "pésti" + "cattos": "pésti", + "_vectors": { + "manual": [ + 1, + 2, + 3 + ] + } }, { "id": 654, @@ -218,7 +246,14 @@ async fn simple_search_two_indexes() { "cattos": [ "simba", "pestiféré" - ] + ], + "_vectors": { + "manual": [ + 1, + 2, + 54 + ] + } } ], "query": "pésti", diff --git a/meilisearch/tests/settings/get_settings.rs b/meilisearch/tests/settings/get_settings.rs index 0ea556b94..9ab53c51e 100644 --- a/meilisearch/tests/settings/get_settings.rs +++ b/meilisearch/tests/settings/get_settings.rs @@ -54,7 +54,7 @@ async fn get_settings() { let (response, code) = index.settings().await; assert_eq!(code, 200); let settings = response.as_object().unwrap(); - assert_eq!(settings.keys().len(), 15); + assert_eq!(settings.keys().len(), 16); assert_eq!(settings["displayedAttributes"], json!(["*"])); assert_eq!(settings["searchableAttributes"], json!(["*"])); assert_eq!(settings["filterableAttributes"], json!([])); @@ -83,6 +83,7 @@ async fn get_settings() { "maxTotalHits": 1000, }) ); + assert_eq!(settings["embedders"], json!({})); } #[actix_rt::test] diff --git a/milli/Cargo.toml b/milli/Cargo.toml index 8aa2a6f3f..b977d64f1 100644 --- a/milli/Cargo.toml +++ b/milli/Cargo.toml @@ -27,13 +27,15 @@ fst = "0.4.7" fxhash = "0.2.1" geoutils = "0.5.1" grenad = { version = "0.4.5", default-features = false, features = [ - "rayon", "tempfile" + "rayon", + "tempfile", ] } heed = { version = "0.20.0-alpha.9", default-features = false, features = [ - "serde-json", "serde-bincode", "read-txn-no-tls" + "serde-json", + "serde-bincode", + "read-txn-no-tls", ] } indexmap = { version = "2.0.0", features = ["serde"] } -instant-distance = { version = "0.6.1", features = ["with-serde"] } json-depth-checker = { path = "../json-depth-checker" } levenshtein_automata = { version = "0.2.1", features = ["fst_automaton"] } memmap2 = "0.7.1" @@ -72,6 +74,23 @@ puffin = "0.16.0" log = "0.4.17" logging_timer = "1.1.0" csv = "1.2.1" +candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } +candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } +candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" } +tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" } +hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [ + "online", +] } +tokio = { version = "1.34.0", features = ["rt"] } +futures = "0.3.29" +reqwest = { version = "0.11.16", features = [ + "rustls-tls", + "json", +], default-features = false } +tiktoken-rs = "0.5.7" +liquid = "0.26.4" +arroy = { git = "https://github.com/meilisearch/arroy.git", version = "0.1.0" } +rand = "0.8.5" [dev-dependencies] mimalloc = { version = "0.1.37", default-features = false } @@ -83,7 +102,15 @@ meili-snap = { path = "../meili-snap" } rand = { version = "0.8.5", features = ["small_rng"] } [features] -all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"] +all-tokenizations = [ + "charabia/chinese", + "charabia/hebrew", + "charabia/japanese", + "charabia/thai", + "charabia/korean", + "charabia/greek", + "charabia/khmer", +] # Use POSIX semaphores instead of SysV semaphores in LMDB # For more information on this feature, see heed's Cargo.toml diff --git a/milli/examples/search.rs b/milli/examples/search.rs index 82de56434..a94677771 100644 --- a/milli/examples/search.rs +++ b/milli/examples/search.rs @@ -5,8 +5,8 @@ use std::time::Instant; use heed::EnvOpenOptions; use milli::{ - execute_search, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, SearchLogger, - TermsMatchingStrategy, + execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, + SearchLogger, TermsMatchingStrategy, }; #[global_allocator] @@ -49,14 +49,15 @@ fn main() -> Result<(), Box> { let start = Instant::now(); let mut ctx = SearchContext::new(&index, &txn); + let universe = filtered_universe(&ctx, &None)?; + let docs = execute_search( &mut ctx, - &(!query.trim().is_empty()).then(|| query.trim().to_owned()), - &None, + (!query.trim().is_empty()).then(|| query.trim()), TermsMatchingStrategy::Last, milli::score_details::ScoringStrategy::Skip, false, - &None, + universe, &None, GeoSortStrategy::default(), 0, diff --git a/milli/src/distance.rs b/milli/src/distance.rs deleted file mode 100644 index e9e17e647..000000000 --- a/milli/src/distance.rs +++ /dev/null @@ -1,41 +0,0 @@ -use std::ops; - -use instant_distance::Point; -use serde::{Deserialize, Serialize}; - -use crate::normalize_vector; - -#[derive(Debug, Default, Clone, Serialize, Deserialize)] -pub struct NDotProductPoint(Vec); - -impl NDotProductPoint { - pub fn new(point: Vec) -> Self { - NDotProductPoint(normalize_vector(point)) - } - - pub fn into_inner(self) -> Vec { - self.0 - } -} - -impl ops::Deref for NDotProductPoint { - type Target = [f32]; - - fn deref(&self) -> &Self::Target { - self.0.as_slice() - } -} - -impl Point for NDotProductPoint { - fn distance(&self, other: &Self) -> f32 { - let dist = 1.0 - dot_product_similarity(&self.0, &other.0); - debug_assert!(!dist.is_nan()); - dist - } -} - -/// Returns the dot product similarity score that will between 0.0 and 1.0 -/// if both vectors are normalized. The higher the more similar the vectors are. -pub fn dot_product_similarity(a: &[f32], b: &[f32]) -> f32 { - a.iter().zip(b).map(|(a, b)| a * b).sum() -} diff --git a/milli/src/error.rs b/milli/src/error.rs index cbbd8a3e5..9c5d8f416 100644 --- a/milli/src/error.rs +++ b/milli/src/error.rs @@ -61,6 +61,10 @@ pub enum InternalError { AbortedIndexation, #[error("The matching words list contains at least one invalid member.")] InvalidMatchingWords, + #[error(transparent)] + ArroyError(#[from] arroy::Error), + #[error(transparent)] + VectorEmbeddingError(#[from] crate::vector::Error), } #[derive(Error, Debug)] @@ -110,8 +114,10 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco InvalidGeoField(#[from] GeoError), #[error("Invalid vector dimensions: expected: `{}`, found: `{}`.", .expected, .found)] InvalidVectorDimensions { expected: usize, found: usize }, - #[error("The `_vectors` field in the document with the id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] - InvalidVectorsType { document_id: Value, value: Value }, + #[error("The `_vectors.{subfield}` field in the document with id: `{document_id}` is not an array. Was expecting an array of floats or an array of arrays of floats but instead got `{value}`.")] + InvalidVectorsType { document_id: Value, value: Value, subfield: String }, + #[error("The `_vectors` field in the document with id: `{document_id}` is not an object. Was expecting an object with a key for each embedder with manually provided vectors, but instead got `{value}`")] + InvalidVectorsMapType { document_id: Value, value: Value }, #[error("{0}")] InvalidFilter(String), #[error("Invalid type for filter subexpression: expected: {}, found: {1}.", .0.join(", "))] @@ -180,6 +186,49 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco UnknownInternalDocumentId { document_id: DocumentId }, #[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")] InvalidMinTypoWordLenSetting(u8, u8), + #[error(transparent)] + VectorEmbeddingError(#[from] crate::vector::Error), + #[error(transparent)] + MissingDocumentField(#[from] crate::prompt::error::RenderPromptError), + #[error(transparent)] + InvalidPrompt(#[from] crate::prompt::error::NewPromptError), + #[error("Invalid prompt in for embeddings with name '{0}': {1}.")] + InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError), + #[error("Too many embedders in the configuration. Found {0}, but limited to 256.")] + TooManyEmbedders(usize), + #[error("Cannot find embedder with name {0}.")] + InvalidEmbedder(String), + #[error("Too many vectors for document with id {0}: found {1}, but limited to 256.")] + TooManyVectors(String, usize), +} + +impl From for Error { + fn from(value: crate::vector::Error) -> Self { + match value.fault() { + FaultSource::User => Error::UserError(value.into()), + FaultSource::Runtime => Error::InternalError(value.into()), + FaultSource::Bug => Error::InternalError(value.into()), + FaultSource::Undecided => Error::InternalError(value.into()), + } + } +} + +impl From for Error { + fn from(value: arroy::Error) -> Self { + match value { + arroy::Error::Heed(heed) => heed.into(), + arroy::Error::Io(io) => io.into(), + arroy::Error::InvalidVecDimension { expected, received } => { + Error::UserError(UserError::InvalidVectorDimensions { expected, found: received }) + } + arroy::Error::DatabaseFull + | arroy::Error::InvalidItemAppend + | arroy::Error::UnmatchingDistance { .. } + | arroy::Error::MissingMetadata => { + Error::InternalError(InternalError::ArroyError(value)) + } + } + } } #[derive(Error, Debug)] @@ -336,6 +385,26 @@ impl From for Error { } } +#[derive(Debug, Clone, Copy)] +pub enum FaultSource { + User, + Runtime, + Bug, + Undecided, +} + +impl std::fmt::Display for FaultSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + FaultSource::User => "user error", + FaultSource::Runtime => "runtime error", + FaultSource::Bug => "coding error", + FaultSource::Undecided => "error", + }; + f.write_str(s) + } +} + #[test] fn conditionally_lookup_for_error_message() { let prefix = "Attribute `name` is not sortable."; diff --git a/milli/src/index.rs b/milli/src/index.rs index 01a01ac37..6ad39dcb1 100644 --- a/milli/src/index.rs +++ b/milli/src/index.rs @@ -10,7 +10,6 @@ use roaring::RoaringBitmap; use rstar::RTree; use time::OffsetDateTime; -use crate::distance::NDotProductPoint; use crate::documents::PrimaryKey; use crate::error::{InternalError, UserError}; use crate::fields_ids_map::FieldsIdsMap; @@ -22,7 +21,7 @@ use crate::heed_codec::{ BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec, }; use crate::proximity::ProximityPrecision; -use crate::readable_slices::ReadableSlices; +use crate::vector::EmbeddingConfig; use crate::{ default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds, FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec, @@ -30,9 +29,6 @@ use crate::{ BEU32, BEU64, }; -/// The HNSW data-structure that we serialize, fill and search in. -pub type Hnsw = instant_distance::Hnsw; - pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5; pub const DEFAULT_MIN_WORD_LEN_TWO_TYPOS: u8 = 9; @@ -48,10 +44,6 @@ pub mod main_key { pub const FIELDS_IDS_MAP_KEY: &str = "fields-ids-map"; pub const GEO_FACETED_DOCUMENTS_IDS_KEY: &str = "geo-faceted-documents-ids"; pub const GEO_RTREE_KEY: &str = "geo-rtree"; - /// The prefix of the key that is used to store the, potential big, HNSW structure. - /// It is concatenated with a big-endian encoded number (non-human readable). - /// e.g. vector-hnsw0x0032. - pub const VECTOR_HNSW_KEY_PREFIX: &str = "vector-hnsw"; pub const PRIMARY_KEY_KEY: &str = "primary-key"; pub const SEARCHABLE_FIELDS_KEY: &str = "searchable-fields"; pub const USER_DEFINED_SEARCHABLE_FIELDS_KEY: &str = "user-defined-searchable-fields"; @@ -74,6 +66,7 @@ pub mod main_key { pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by"; pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits"; pub const PROXIMITY_PRECISION: &str = "proximity-precision"; + pub const EMBEDDING_CONFIGS: &str = "embedding_configs"; } pub mod db_name { @@ -99,7 +92,8 @@ pub mod db_name { pub const FACET_ID_STRING_FST: &str = "facet-id-string-fst"; pub const FIELD_ID_DOCID_FACET_F64S: &str = "field-id-docid-facet-f64s"; pub const FIELD_ID_DOCID_FACET_STRINGS: &str = "field-id-docid-facet-strings"; - pub const VECTOR_ID_DOCID: &str = "vector-id-docids"; + pub const VECTOR_EMBEDDER_CATEGORY_ID: &str = "vector-embedder-category-id"; + pub const VECTOR_ARROY: &str = "vector-arroy"; pub const DOCUMENTS: &str = "documents"; pub const SCRIPT_LANGUAGE_DOCIDS: &str = "script_language_docids"; } @@ -166,8 +160,10 @@ pub struct Index { /// Maps the document id, the facet field id and the strings. pub field_id_docid_facet_strings: Database, - /// Maps a vector id to the document id that have it. - pub vector_id_docid: Database, + /// Maps an embedder name to its id in the arroy store. + pub embedder_category_id: Database, + /// Vector store based on arroy™. + pub vector_arroy: arroy::Database, /// Maps the document id to the document as an obkv store. pub(crate) documents: Database, @@ -182,7 +178,7 @@ impl Index { ) -> Result { use db_name::*; - options.max_dbs(24); + options.max_dbs(25); let env = options.open(path)?; let mut wtxn = env.write_txn()?; @@ -222,7 +218,11 @@ impl Index { env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_F64S))?; let field_id_docid_facet_strings = env.create_database(&mut wtxn, Some(FIELD_ID_DOCID_FACET_STRINGS))?; - let vector_id_docid = env.create_database(&mut wtxn, Some(VECTOR_ID_DOCID))?; + // vector stuff + let embedder_category_id = + env.create_database(&mut wtxn, Some(VECTOR_EMBEDDER_CATEGORY_ID))?; + let vector_arroy = env.create_database(&mut wtxn, Some(VECTOR_ARROY))?; + let documents = env.create_database(&mut wtxn, Some(DOCUMENTS))?; wtxn.commit()?; @@ -252,7 +252,8 @@ impl Index { facet_id_is_empty_docids, field_id_docid_facet_f64s, field_id_docid_facet_strings, - vector_id_docid, + vector_arroy, + embedder_category_id, documents, }) } @@ -475,63 +476,6 @@ impl Index { None => Ok(RoaringBitmap::new()), } } - - /* vector HNSW */ - - /// Writes the provided `hnsw`. - pub(crate) fn put_vector_hnsw(&self, wtxn: &mut RwTxn, hnsw: &Hnsw) -> heed::Result<()> { - // We must delete all the chunks before we write the new HNSW chunks. - self.delete_vector_hnsw(wtxn)?; - - let chunk_size = 1024 * 1024 * (1024 + 512); // 1.5 GiB - let bytes = bincode::serialize(hnsw).map_err(Into::into).map_err(heed::Error::Encoding)?; - for (i, chunk) in bytes.chunks(chunk_size).enumerate() { - let i = i as u32; - let mut key = main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes().to_vec(); - key.extend_from_slice(&i.to_be_bytes()); - self.main.remap_types::().put(wtxn, &key, chunk)?; - } - Ok(()) - } - - /// Delete the `hnsw`. - pub(crate) fn delete_vector_hnsw(&self, wtxn: &mut RwTxn) -> heed::Result { - let mut iter = self - .main - .remap_types::() - .prefix_iter_mut(wtxn, main_key::VECTOR_HNSW_KEY_PREFIX.as_bytes())?; - let mut deleted = false; - while iter.next().transpose()?.is_some() { - // We do not keep a reference to the key or the value. - unsafe { deleted |= iter.del_current()? }; - } - Ok(deleted) - } - - /// Returns the `hnsw`. - pub fn vector_hnsw(&self, rtxn: &RoTxn) -> Result> { - let mut slices = Vec::new(); - for result in self - .main - .remap_types::() - .prefix_iter(rtxn, main_key::VECTOR_HNSW_KEY_PREFIX)? - { - let (_, slice) = result?; - slices.push(slice); - } - - if slices.is_empty() { - Ok(None) - } else { - let readable_slices: ReadableSlices<_> = slices.into_iter().collect(); - Ok(Some( - bincode::deserialize_from(readable_slices) - .map_err(Into::into) - .map_err(heed::Error::Decoding)?, - )) - } - } - /* field distribution */ /// Writes the field distribution which associates every field name with @@ -1528,6 +1472,41 @@ impl Index { Ok(script_language) } + + pub(crate) fn put_embedding_configs( + &self, + wtxn: &mut RwTxn<'_>, + configs: Vec<(String, EmbeddingConfig)>, + ) -> heed::Result<()> { + self.main.remap_types::>>().put( + wtxn, + main_key::EMBEDDING_CONFIGS, + &configs, + ) + } + + pub(crate) fn delete_embedding_configs(&self, wtxn: &mut RwTxn<'_>) -> heed::Result { + self.main.remap_key_type::().delete(wtxn, main_key::EMBEDDING_CONFIGS) + } + + pub fn embedding_configs( + &self, + rtxn: &RoTxn<'_>, + ) -> Result> { + Ok(self + .main + .remap_types::>>() + .get(rtxn, main_key::EMBEDDING_CONFIGS)? + .unwrap_or_default()) + } + + pub fn default_embedding_name(&self, rtxn: &RoTxn<'_>) -> Result { + let configs = self.embedding_configs(rtxn)?; + Ok(match configs.as_slice() { + [(ref first_name, _)] => first_name.clone(), + _ => "default".to_owned(), + }) + } } #[cfg(test)] diff --git a/milli/src/lib.rs b/milli/src/lib.rs index acea72c41..f6b398304 100644 --- a/milli/src/lib.rs +++ b/milli/src/lib.rs @@ -10,18 +10,18 @@ pub mod documents; mod asc_desc; mod criterion; -pub mod distance; mod error; mod external_documents_ids; pub mod facet; mod fields_ids_map; pub mod heed_codec; pub mod index; +pub mod prompt; pub mod proximity; -mod readable_slices; pub mod score_details; mod search; pub mod update; +pub mod vector; #[cfg(test)] #[macro_use] @@ -32,13 +32,12 @@ use std::convert::{TryFrom, TryInto}; use std::hash::BuildHasherDefault; use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer}; -pub use distance::dot_product_similarity; pub use filter_parser::{Condition, FilterCondition, Span, Token}; use fxhash::{FxHasher32, FxHasher64}; pub use grenad::CompressionType; pub use search::new::{ - execute_search, DefaultSearchLogger, GeoSortStrategy, SearchContext, SearchLogger, - VisualSearchLogger, + execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, SearchContext, + SearchLogger, VisualSearchLogger, }; use serde_json::Value; pub use {charabia as tokenizer, heed}; diff --git a/milli/src/prompt/context.rs b/milli/src/prompt/context.rs new file mode 100644 index 000000000..a28a87caa --- /dev/null +++ b/milli/src/prompt/context.rs @@ -0,0 +1,97 @@ +use liquid::model::{ + ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +use super::document::Document; +use super::fields::Fields; +use crate::FieldsIdsMap; + +#[derive(Debug, Clone)] +pub struct Context<'a> { + document: &'a Document<'a>, + fields: Fields<'a>, +} + +impl<'a> Context<'a> { + pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self { + Self { document, fields: Fields::new(document, field_id_map) } + } +} + +impl<'a> ObjectView for Context<'a> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new( + std::iter::once(self.document.as_value()) + .chain(std::iter::once(self.fields.as_value())), + ) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "doc" || index == "fields" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + match index { + "doc" => Some(self.document.as_value()), + "fields" => Some(self.fields.as_value()), + _ => None, + } + } +} + +impl<'a> ValueView for Context<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => false, + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} diff --git a/milli/src/prompt/document.rs b/milli/src/prompt/document.rs new file mode 100644 index 000000000..b5d43b5be --- /dev/null +++ b/milli/src/prompt/document.rs @@ -0,0 +1,131 @@ +use std::cell::OnceCell; +use std::collections::BTreeMap; + +use liquid::model::{ + DisplayCow, KString, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +use crate::update::del_add::{DelAdd, KvReaderDelAdd}; +use crate::FieldsIdsMap; + +#[derive(Debug, Clone)] +pub struct Document<'a>(BTreeMap<&'a str, (&'a [u8], ParsedValue)>); + +#[derive(Debug, Clone)] +struct ParsedValue(std::cell::OnceCell); + +impl ParsedValue { + fn empty() -> ParsedValue { + ParsedValue(OnceCell::new()) + } + + fn get(&self, raw: &[u8]) -> &LiquidValue { + self.0.get_or_init(|| { + let value: serde_json::Value = serde_json::from_slice(raw).unwrap(); + liquid::model::to_value(&value).unwrap() + }) + } +} + +impl<'a> Document<'a> { + pub fn new( + data: obkv::KvReaderU16<'a>, + side: DelAdd, + inverted_field_map: &'a FieldsIdsMap, + ) -> Self { + let mut out_data = BTreeMap::new(); + for (fid, raw) in data { + let obkv = KvReaderDelAdd::new(raw); + let Some(raw) = obkv.get(side) else { + continue; + }; + let Some(name) = inverted_field_map.name(fid) else { + continue; + }; + out_data.insert(name, (raw, ParsedValue::empty())); + } + Self(out_data) + } + + fn is_empty(&self) -> bool { + self.0.is_empty() + } + + fn len(&self) -> usize { + self.0.len() + } + + fn iter(&self) -> impl Iterator + '_ { + self.0.iter().map(|(&k, (raw, data))| (k.to_owned().into(), data.get(raw).to_owned())) + } +} + +impl<'a> ObjectView for Document<'a> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + self.len() as i64 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + let keys = BTreeMap::keys(&self.0).map(|&s| s.into()); + Box::new(keys) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(self.0.values().map(|(raw, v)| v.get(raw) as &dyn ValueView)) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.0.iter().map(|(&k, (raw, data))| (k.into(), data.get(raw) as &dyn ValueView))) + } + + fn contains_key(&self, index: &str) -> bool { + self.0.contains_key(index) + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + self.0.get(index).map(|(raw, v)| v.get(raw) as &dyn ValueView) + } +} + +impl<'a> ValueView for Document<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => self.is_empty(), + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object(self.iter().collect()) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} diff --git a/milli/src/prompt/error.rs b/milli/src/prompt/error.rs new file mode 100644 index 000000000..8a762b60a --- /dev/null +++ b/milli/src/prompt/error.rs @@ -0,0 +1,56 @@ +use crate::error::FaultSource; + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct NewPromptError { + pub kind: NewPromptErrorKind, + pub fault: FaultSource, +} + +impl From for crate::Error { + fn from(value: NewPromptError) -> Self { + crate::Error::UserError(crate::UserError::InvalidPrompt(value)) + } +} + +impl NewPromptError { + pub(crate) fn cannot_parse_template(inner: liquid::Error) -> NewPromptError { + Self { kind: NewPromptErrorKind::CannotParseTemplate(inner), fault: FaultSource::User } + } + + pub(crate) fn invalid_fields_in_template(inner: liquid::Error) -> NewPromptError { + Self { kind: NewPromptErrorKind::InvalidFieldsInTemplate(inner), fault: FaultSource::User } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum NewPromptErrorKind { + #[error("cannot parse template: {0}")] + CannotParseTemplate(liquid::Error), + #[error("template contains invalid fields: {0}. Only `doc.*`, `fields[i].name`, `fields[i].value` are supported")] + InvalidFieldsInTemplate(liquid::Error), +} + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct RenderPromptError { + pub kind: RenderPromptErrorKind, + pub fault: FaultSource, +} +impl RenderPromptError { + pub(crate) fn missing_context(inner: liquid::Error) -> RenderPromptError { + Self { kind: RenderPromptErrorKind::MissingContext(inner), fault: FaultSource::User } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum RenderPromptErrorKind { + #[error("missing field in document: {0}")] + MissingContext(liquid::Error), +} + +impl From for crate::Error { + fn from(value: RenderPromptError) -> Self { + crate::Error::UserError(crate::UserError::MissingDocumentField(value)) + } +} diff --git a/milli/src/prompt/fields.rs b/milli/src/prompt/fields.rs new file mode 100644 index 000000000..3187485f1 --- /dev/null +++ b/milli/src/prompt/fields.rs @@ -0,0 +1,172 @@ +use liquid::model::{ + ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{ObjectView, ValueView}; + +use super::document::Document; +use crate::FieldsIdsMap; +#[derive(Debug, Clone)] +pub struct Fields<'a>(Vec>); + +impl<'a> Fields<'a> { + pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self { + Self( + std::iter::repeat(document) + .zip(field_id_map.iter()) + .map(|(document, (_fid, name))| FieldValue { document, name }) + .collect(), + ) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct FieldValue<'a> { + name: &'a str, + document: &'a Document<'a>, +} + +impl<'a> ValueView for FieldValue<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => self.is_empty(), + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.iter().map(|(k, v)| (k.to_string().into(), v.to_value())).collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} + +impl<'a> FieldValue<'a> { + pub fn name(&self) -> &&'a str { + &self.name + } + + pub fn value(&self) -> &dyn ValueView { + self.document.get(self.name).unwrap_or(&LiquidValue::Nil) + } + + pub fn is_empty(&self) -> bool { + self.size() == 0 + } +} + +impl<'a> ObjectView for FieldValue<'a> { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["name", "value"].iter().map(|&x| KStringCow::from_static(x))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new( + std::iter::once(self.name() as &dyn ValueView).chain(std::iter::once(self.value())), + ) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "name" || index == "value" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + match index { + "name" => Some(self.name()), + "value" => Some(self.value()), + _ => None, + } + } +} + +impl<'a> ArrayView for Fields<'a> { + fn as_value(&self) -> &dyn ValueView { + self.0.as_value() + } + + fn size(&self) -> i64 { + self.0.len() as i64 + } + + fn values<'k>(&'k self) -> Box + 'k> { + self.0.values() + } + + fn contains_key(&self, index: i64) -> bool { + self.0.contains_key(index) + } + + fn get(&self, index: i64) -> Option<&dyn ValueView> { + ArrayView::get(&self.0, index) + } +} + +impl<'a> ValueView for Fields<'a> { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + self.0.render() + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + self.0.source() + } + + fn type_name(&self) -> &'static str { + self.0.type_name() + } + + fn query_state(&self, state: liquid::model::State) -> bool { + self.0.query_state(state) + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + self.0.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + self.0.to_value() + } + + fn as_array(&self) -> Option<&dyn ArrayView> { + Some(self) + } +} diff --git a/milli/src/prompt/mod.rs b/milli/src/prompt/mod.rs new file mode 100644 index 000000000..97ccbfb61 --- /dev/null +++ b/milli/src/prompt/mod.rs @@ -0,0 +1,176 @@ +mod context; +mod document; +pub(crate) mod error; +mod fields; +mod template_checker; + +use std::convert::TryFrom; + +use error::{NewPromptError, RenderPromptError}; + +use self::context::Context; +use self::document::Document; +use crate::update::del_add::DelAdd; +use crate::FieldsIdsMap; + +pub struct Prompt { + template: liquid::Template, + template_text: String, +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct PromptData { + pub template: String, +} + +impl From for PromptData { + fn from(value: Prompt) -> Self { + Self { template: value.template_text } + } +} + +impl TryFrom for Prompt { + type Error = NewPromptError; + + fn try_from(value: PromptData) -> Result { + Prompt::new(value.template) + } +} + +impl Clone for Prompt { + fn clone(&self) -> Self { + let template_text = self.template_text.clone(); + Self { template: new_template(&template_text).unwrap(), template_text } + } +} + +fn new_template(text: &str) -> Result { + liquid::ParserBuilder::with_stdlib().build().unwrap().parse(text) +} + +fn default_template() -> liquid::Template { + new_template(default_template_text()).unwrap() +} + +fn default_template_text() -> &'static str { + "{% for field in fields %} \ + {{ field.name }}: {{ field.value }}\n\ + {% endfor %}" +} + +impl Default for Prompt { + fn default() -> Self { + Self { template: default_template(), template_text: default_template_text().into() } + } +} + +impl Default for PromptData { + fn default() -> Self { + Self { template: default_template_text().into() } + } +} + +impl Prompt { + pub fn new(template: String) -> Result { + let this = Self { + template: liquid::ParserBuilder::with_stdlib() + .build() + .unwrap() + .parse(&template) + .map_err(NewPromptError::cannot_parse_template)?, + template_text: template, + }; + + // render template with special object that's OK with `doc.*` and `fields.*` + this.template + .render(&template_checker::TemplateChecker) + .map_err(NewPromptError::invalid_fields_in_template)?; + + Ok(this) + } + + pub fn render( + &self, + document: obkv::KvReaderU16<'_>, + side: DelAdd, + field_id_map: &FieldsIdsMap, + ) -> Result { + let document = Document::new(document, side, field_id_map); + let context = Context::new(&document, field_id_map); + + self.template.render(&context).map_err(RenderPromptError::missing_context) + } +} + +#[cfg(test)] +mod test { + use super::Prompt; + use crate::error::FaultSource; + use crate::prompt::error::{NewPromptError, NewPromptErrorKind}; + + #[test] + fn default_template() { + // does not panic + Prompt::default(); + } + + #[test] + fn empty_template() { + Prompt::new("".into()).unwrap(); + } + + #[test] + fn template_ok() { + Prompt::new("{{doc.title}}: {{doc.overview}}".into()).unwrap(); + } + + #[test] + fn template_syntax() { + assert!(matches!( + Prompt::new("{{doc.title: {{doc.overview}}".into()), + Err(NewPromptError { + kind: NewPromptErrorKind::CannotParseTemplate(_), + fault: FaultSource::User + }) + )); + } + + #[test] + fn template_missing_doc() { + assert!(matches!( + Prompt::new("{{title}}: {{overview}}".into()), + Err(NewPromptError { + kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), + fault: FaultSource::User + }) + )); + } + + #[test] + fn template_nested_doc() { + Prompt::new("{{doc.actor.firstName}}: {{doc.actor.lastName}}".into()).unwrap(); + } + + #[test] + fn template_fields() { + Prompt::new("{% for field in fields %}{{field}}{% endfor %}".into()).unwrap(); + } + + #[test] + fn template_fields_ok() { + Prompt::new("{% for field in fields %}{{field.name}}: {{field.value}}{% endfor %}".into()) + .unwrap(); + } + + #[test] + fn template_fields_invalid() { + assert!(matches!( + // intentionally garbled field + Prompt::new("{% for field in fields %}{{field.vaelu}} {% endfor %}".into()), + Err(NewPromptError { + kind: NewPromptErrorKind::InvalidFieldsInTemplate(_), + fault: FaultSource::User + }) + )); + } +} diff --git a/milli/src/prompt/template_checker.rs b/milli/src/prompt/template_checker.rs new file mode 100644 index 000000000..4cda4a70d --- /dev/null +++ b/milli/src/prompt/template_checker.rs @@ -0,0 +1,301 @@ +use liquid::model::{ + ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue, +}; +use liquid::{Object, ObjectView, ValueView}; + +#[derive(Debug)] +pub struct TemplateChecker; + +#[derive(Debug)] +pub struct DummyDoc; + +#[derive(Debug)] +pub struct DummyFields; + +#[derive(Debug)] +pub struct DummyField; + +const DUMMY_VALUE: &LiquidValue = &LiquidValue::Nil; + +impl ObjectView for DummyField { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["name", "value"].iter().map(|s| KStringCow::from_static(s))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(vec![DUMMY_VALUE.as_view(), DUMMY_VALUE.as_view()].into_iter()) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "name" || index == "value" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + if self.contains_key(index) { + Some(DUMMY_VALUE.as_view()) + } else { + None + } + } +} + +impl ValueView for DummyField { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + DUMMY_VALUE.render() + } + + fn source(&self) -> DisplayCow<'_> { + DUMMY_VALUE.source() + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue => false, + State::Empty => false, + State::Blank => false, + } + } + + fn to_kstr(&self) -> KStringCow<'_> { + DUMMY_VALUE.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + let mut this = Object::new(); + this.insert("name".into(), LiquidValue::Nil); + this.insert("value".into(), LiquidValue::Nil); + LiquidValue::Object(this) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} + +impl ValueView for DummyFields { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + DUMMY_VALUE.render() + } + + fn source(&self) -> DisplayCow<'_> { + DUMMY_VALUE.source() + } + + fn type_name(&self) -> &'static str { + "array" + } + + fn query_state(&self, state: State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue => false, + State::Empty => false, + State::Blank => false, + } + } + + fn to_kstr(&self) -> KStringCow<'_> { + DUMMY_VALUE.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Array(vec![DummyField.to_value()]) + } + + fn as_array(&self) -> Option<&dyn ArrayView> { + Some(self) + } +} + +impl ArrayView for DummyFields { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + u16::MAX as i64 + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(std::iter::once(DummyField.as_value())) + } + + fn contains_key(&self, index: i64) -> bool { + index < self.size() + } + + fn get(&self, _index: i64) -> Option<&dyn ValueView> { + Some(DummyField.as_value()) + } +} + +impl ObjectView for DummyDoc { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 1000 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(std::iter::empty()) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new(std::iter::empty()) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(std::iter::empty()) + } + + fn contains_key(&self, _index: &str) -> bool { + true + } + + fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> { + // Recursively sends itself + Some(self) + } +} + +impl ValueView for DummyDoc { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> DisplayCow<'_> { + DUMMY_VALUE.render() + } + + fn source(&self) -> DisplayCow<'_> { + DUMMY_VALUE.source() + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue => false, + State::Empty => false, + State::Blank => false, + } + } + + fn to_kstr(&self) -> KStringCow<'_> { + DUMMY_VALUE.to_kstr() + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Nil + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} + +impl ObjectView for TemplateChecker { + fn as_value(&self) -> &dyn ValueView { + self + } + + fn size(&self) -> i64 { + 2 + } + + fn keys<'k>(&'k self) -> Box> + 'k> { + Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s))) + } + + fn values<'k>(&'k self) -> Box + 'k> { + Box::new( + std::iter::once(DummyDoc.as_value()).chain(std::iter::once(DummyFields.as_value())), + ) + } + + fn iter<'k>(&'k self) -> Box, &'k dyn ValueView)> + 'k> { + Box::new(self.keys().zip(self.values())) + } + + fn contains_key(&self, index: &str) -> bool { + index == "doc" || index == "fields" + } + + fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> { + match index { + "doc" => Some(DummyDoc.as_value()), + "fields" => Some(DummyFields.as_value()), + _ => None, + } + } +} + +impl ValueView for TemplateChecker { + fn as_debug(&self) -> &dyn std::fmt::Debug { + self + } + + fn render(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectRender::new(self))) + } + + fn source(&self) -> liquid::model::DisplayCow<'_> { + DisplayCow::Owned(Box::new(ObjectSource::new(self))) + } + + fn type_name(&self) -> &'static str { + "object" + } + + fn query_state(&self, state: liquid::model::State) -> bool { + match state { + State::Truthy => true, + State::DefaultValue | State::Empty | State::Blank => false, + } + } + + fn to_kstr(&self) -> liquid::model::KStringCow<'_> { + let s = ObjectRender::new(self).to_string(); + KStringCow::from_string(s) + } + + fn to_value(&self) -> LiquidValue { + LiquidValue::Object( + self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(), + ) + } + + fn as_object(&self) -> Option<&dyn ObjectView> { + Some(self) + } +} diff --git a/milli/src/readable_slices.rs b/milli/src/readable_slices.rs deleted file mode 100644 index 7f5be214f..000000000 --- a/milli/src/readable_slices.rs +++ /dev/null @@ -1,85 +0,0 @@ -use std::io::{self, Read}; -use std::iter::FromIterator; - -pub struct ReadableSlices { - inner: Vec, - pos: u64, -} - -impl FromIterator for ReadableSlices { - fn from_iter>(iter: T) -> Self { - ReadableSlices { inner: iter.into_iter().collect(), pos: 0 } - } -} - -impl> Read for ReadableSlices { - fn read(&mut self, mut buf: &mut [u8]) -> io::Result { - let original_buf_len = buf.len(); - - // We explore the list of slices to find the one where we must start reading. - let mut pos = self.pos; - let index = match self - .inner - .iter() - .map(|s| s.as_ref().len() as u64) - .position(|size| pos.checked_sub(size).map(|p| pos = p).is_none()) - { - Some(index) => index, - None => return Ok(0), - }; - - let mut inner_pos = pos as usize; - for slice in &self.inner[index..] { - let slice = &slice.as_ref()[inner_pos..]; - - if buf.len() > slice.len() { - // We must exhaust the current slice and go to the next one there is not enough here. - buf[..slice.len()].copy_from_slice(slice); - buf = &mut buf[slice.len()..]; - inner_pos = 0; - } else { - // There is enough in this slice to fill the remaining bytes of the buffer. - // Let's break just after filling it. - buf.copy_from_slice(&slice[..buf.len()]); - buf = &mut []; - break; - } - } - - let written = original_buf_len - buf.len(); - self.pos += written as u64; - Ok(written) - } -} - -#[cfg(test)] -mod test { - use std::io::Read; - - use super::ReadableSlices; - - #[test] - fn basic() { - let data: Vec<_> = (0..100).collect(); - let splits: Vec<_> = data.chunks(3).collect(); - let mut rdslices: ReadableSlices<_> = splits.into_iter().collect(); - - let mut output = Vec::new(); - let length = rdslices.read_to_end(&mut output).unwrap(); - assert_eq!(length, data.len()); - assert_eq!(output, data); - } - - #[test] - fn small_reads() { - let data: Vec<_> = (0..u8::MAX).collect(); - let splits: Vec<_> = data.chunks(27).collect(); - let mut rdslices: ReadableSlices<_> = splits.into_iter().collect(); - - let buffer = &mut [0; 45]; - let length = rdslices.read(buffer).unwrap(); - let expected: Vec<_> = (0..buffer.len() as u8).collect(); - assert_eq!(length, buffer.len()); - assert_eq!(buffer, &expected[..]); - } -} diff --git a/milli/src/score_details.rs b/milli/src/score_details.rs index 8fc998ae4..f6b9db58c 100644 --- a/milli/src/score_details.rs +++ b/milli/src/score_details.rs @@ -1,3 +1,6 @@ +use std::cmp::Ordering; + +use itertools::Itertools; use serde::Serialize; use crate::distance_between_two_points; @@ -12,9 +15,24 @@ pub enum ScoreDetails { ExactAttribute(ExactAttribute), ExactWords(ExactWords), Sort(Sort), + Vector(Vector), GeoSort(GeoSort), } +#[derive(Clone, Copy)] +pub enum ScoreValue<'a> { + Score(f64), + Sort(&'a Sort), + GeoSort(&'a GeoSort), +} + +enum RankOrValue<'a> { + Rank(Rank), + Sort(&'a Sort), + GeoSort(&'a GeoSort), + Score(f64), +} + impl ScoreDetails { pub fn local_score(&self) -> Option { self.rank().map(Rank::local_score) @@ -31,11 +49,55 @@ impl ScoreDetails { ScoreDetails::ExactWords(details) => Some(details.rank()), ScoreDetails::Sort(_) => None, ScoreDetails::GeoSort(_) => None, + ScoreDetails::Vector(_) => None, } } - pub fn global_score<'a>(details: impl Iterator) -> f64 { - Rank::global_score(details.filter_map(Self::rank)) + pub fn global_score<'a>(details: impl Iterator + 'a) -> f64 { + Self::score_values(details) + .find_map(|x| { + let ScoreValue::Score(score) = x else { + return None; + }; + Some(score) + }) + .unwrap_or(1.0f64) + } + + pub fn score_values<'a>( + details: impl Iterator + 'a, + ) -> impl Iterator> + 'a { + details + .map(ScoreDetails::rank_or_value) + .coalesce(|left, right| match (left, right) { + (RankOrValue::Rank(left), RankOrValue::Rank(right)) => { + Ok(RankOrValue::Rank(Rank::merge(left, right))) + } + (left, right) => Err((left, right)), + }) + .map(|rank_or_value| match rank_or_value { + RankOrValue::Rank(r) => ScoreValue::Score(r.local_score()), + RankOrValue::Sort(s) => ScoreValue::Sort(s), + RankOrValue::GeoSort(g) => ScoreValue::GeoSort(g), + RankOrValue::Score(s) => ScoreValue::Score(s), + }) + } + + fn rank_or_value(&self) -> RankOrValue<'_> { + match self { + ScoreDetails::Words(w) => RankOrValue::Rank(w.rank()), + ScoreDetails::Typo(t) => RankOrValue::Rank(t.rank()), + ScoreDetails::Proximity(p) => RankOrValue::Rank(*p), + ScoreDetails::Fid(f) => RankOrValue::Rank(*f), + ScoreDetails::Position(p) => RankOrValue::Rank(*p), + ScoreDetails::ExactAttribute(e) => RankOrValue::Rank(e.rank()), + ScoreDetails::ExactWords(e) => RankOrValue::Rank(e.rank()), + ScoreDetails::Sort(sort) => RankOrValue::Sort(sort), + ScoreDetails::GeoSort(geosort) => RankOrValue::GeoSort(geosort), + ScoreDetails::Vector(vector) => RankOrValue::Score( + vector.value_similarity.as_ref().map(|(_, s)| *s as f64).unwrap_or(0.0f64), + ), + } } /// Panics @@ -181,6 +243,19 @@ impl ScoreDetails { details_map.insert(sort, sort_details); order += 1; } + ScoreDetails::Vector(s) => { + let vector = format!("vectorSort({:?})", s.target_vector); + let value = s.value_similarity.as_ref().map(|(v, _)| v); + let similarity = s.value_similarity.as_ref().map(|(_, s)| s); + + let details = serde_json::json!({ + "order": order, + "value": value, + "similarity": similarity, + }); + details_map.insert(vector, details); + order += 1; + } } } details_map @@ -297,15 +372,21 @@ impl Rank { pub fn global_score(details: impl Iterator) -> f64 { let mut rank = Rank { rank: 1, max_rank: 1 }; for inner_rank in details { - rank.rank -= 1; - - rank.rank *= inner_rank.max_rank; - rank.max_rank *= inner_rank.max_rank; - - rank.rank += inner_rank.rank; + rank = Rank::merge(rank, inner_rank); } rank.local_score() } + + pub fn merge(mut outer: Rank, inner: Rank) -> Rank { + outer.rank = outer.rank.saturating_sub(1); + + outer.rank *= inner.max_rank; + outer.max_rank *= inner.max_rank; + + outer.rank += inner.rank; + + outer + } } #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)] @@ -335,13 +416,78 @@ pub struct Sort { pub value: serde_json::Value, } -#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)] +impl PartialOrd for Sort { + fn partial_cmp(&self, other: &Self) -> Option { + if self.field_name != other.field_name { + return None; + } + if self.ascending != other.ascending { + return None; + } + match (&self.value, &other.value) { + (serde_json::Value::Null, serde_json::Value::Null) => Some(Ordering::Equal), + (serde_json::Value::Null, _) => Some(Ordering::Less), + (_, serde_json::Value::Null) => Some(Ordering::Greater), + // numbers are always before strings + (serde_json::Value::Number(_), serde_json::Value::String(_)) => Some(Ordering::Greater), + (serde_json::Value::String(_), serde_json::Value::Number(_)) => Some(Ordering::Less), + (serde_json::Value::Number(left), serde_json::Value::Number(right)) => { + // FIXME: unwrap permitted here? + let order = left.as_f64().unwrap().partial_cmp(&right.as_f64().unwrap())?; + // 12 < 42, and when ascending, we want to see 12 first, so the smallest. + // Hence, when ascending, smaller is better + Some(if self.ascending { order.reverse() } else { order }) + } + (serde_json::Value::String(left), serde_json::Value::String(right)) => { + let order = left.cmp(right); + // Taking e.g. "a" and "z" + // "a" < "z", and when ascending, we want to see "a" first, so the smallest. + // Hence, when ascending, smaller is better + Some(if self.ascending { order.reverse() } else { order }) + } + _ => None, + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq)] pub struct GeoSort { pub target_point: [f64; 2], pub ascending: bool, pub value: Option<[f64; 2]>, } +impl PartialOrd for GeoSort { + fn partial_cmp(&self, other: &Self) -> Option { + if self.target_point != other.target_point { + return None; + } + if self.ascending != other.ascending { + return None; + } + Some(match (self.distance(), other.distance()) { + (None, None) => Ordering::Equal, + (None, Some(_)) => Ordering::Less, + (Some(_), None) => Ordering::Greater, + (Some(left), Some(right)) => { + let order = left.partial_cmp(&right)?; + if self.ascending { + // when ascending, the one with the smallest distance has the best score + order.reverse() + } else { + order + } + } + }) + } +} + +#[derive(Debug, Clone, PartialEq, PartialOrd)] +pub struct Vector { + pub target_vector: Vec, + pub value_similarity: Option<(Vec, f32)>, +} + impl GeoSort { pub fn distance(&self) -> Option { self.value.map(|value| distance_between_two_points(&self.target_point, &value)) diff --git a/milli/src/search/hybrid.rs b/milli/src/search/hybrid.rs new file mode 100644 index 000000000..67365cf52 --- /dev/null +++ b/milli/src/search/hybrid.rs @@ -0,0 +1,183 @@ +use std::cmp::Ordering; + +use itertools::Itertools; +use roaring::RoaringBitmap; + +use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy}; +use crate::{MatchingWords, Result, Search, SearchResult}; + +struct ScoreWithRatioResult { + matching_words: MatchingWords, + candidates: RoaringBitmap, + document_scores: Vec<(u32, ScoreWithRatio)>, +} + +type ScoreWithRatio = (Vec, f32); + +fn compare_scores( + &(ref left_scores, left_ratio): &ScoreWithRatio, + &(ref right_scores, right_ratio): &ScoreWithRatio, +) -> Ordering { + let mut left_it = ScoreDetails::score_values(left_scores.iter()); + let mut right_it = ScoreDetails::score_values(right_scores.iter()); + + loop { + let left = left_it.next(); + let right = right_it.next(); + + match (left, right) { + (None, None) => return Ordering::Equal, + (None, Some(_)) => return Ordering::Less, + (Some(_), None) => return Ordering::Greater, + (Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => { + let left = left * left_ratio as f64; + let right = right * right_ratio as f64; + if (left - right).abs() <= f64::EPSILON { + continue; + } + return left.partial_cmp(&right).unwrap(); + } + (Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => { + match left.partial_cmp(right).unwrap() { + Ordering::Equal => continue, + order => return order, + } + } + (Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => { + match left.partial_cmp(right).unwrap() { + Ordering::Equal => continue, + order => return order, + } + } + (Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater, + (Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less, + // if we have this, we're bad + (Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_))) + | (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => { + unreachable!("Unexpected geo and sort comparison") + } + } + } +} + +impl ScoreWithRatioResult { + fn new(results: SearchResult, ratio: f32) -> Self { + let document_scores = results + .documents_ids + .into_iter() + .zip(results.document_scores.into_iter().map(|scores| (scores, ratio))) + .collect(); + + Self { + matching_words: results.matching_words, + candidates: results.candidates, + document_scores, + } + } + + fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult { + let mut documents_ids = + Vec::with_capacity(left.document_scores.len() + right.document_scores.len()); + let mut document_scores = + Vec::with_capacity(left.document_scores.len() + right.document_scores.len()); + + let mut documents_seen = RoaringBitmap::new(); + for (docid, (main_score, _sub_score)) in left + .document_scores + .into_iter() + .merge_by(right.document_scores.into_iter(), |(_, left), (_, right)| { + // the first value is the one with the greatest score + compare_scores(left, right).is_ge() + }) + // remove documents we already saw + .filter(|(docid, _)| documents_seen.insert(*docid)) + // start skipping **after** the filter + .skip(from) + // take **after** skipping + .take(length) + { + documents_ids.push(docid); + // TODO: pass both scores to documents_score in some way? + document_scores.push(main_score); + } + + SearchResult { + matching_words: left.matching_words, + candidates: left.candidates | right.candidates, + documents_ids, + document_scores, + } + } +} + +impl<'a> Search<'a> { + pub fn execute_hybrid(&self, semantic_ratio: f32) -> Result { + // TODO: find classier way to achieve that than to reset vector and query params + // create separate keyword and semantic searches + let mut search = Search { + query: self.query.clone(), + vector: self.vector.clone(), + filter: self.filter.clone(), + offset: 0, + limit: self.limit + self.offset, + sort_criteria: self.sort_criteria.clone(), + searchable_attributes: self.searchable_attributes, + geo_strategy: self.geo_strategy, + terms_matching_strategy: self.terms_matching_strategy, + scoring_strategy: ScoringStrategy::Detailed, + words_limit: self.words_limit, + exhaustive_number_hits: self.exhaustive_number_hits, + rtxn: self.rtxn, + index: self.index, + distribution_shift: self.distribution_shift, + embedder_name: self.embedder_name.clone(), + }; + + let vector_query = search.vector.take(); + let keyword_results = search.execute()?; + + // skip semantic search if we don't have a vector query (placeholder search) + let Some(vector_query) = vector_query else { + return Ok(keyword_results); + }; + + // completely skip semantic search if the results of the keyword search are good enough + if self.results_good_enough(&keyword_results, semantic_ratio) { + return Ok(keyword_results); + } + + search.vector = Some(vector_query); + search.query = None; + + // TODO: would be better to have two distinct functions at this point + let vector_results = search.execute()?; + + let keyword_results = ScoreWithRatioResult::new(keyword_results, 1.0 - semantic_ratio); + let vector_results = ScoreWithRatioResult::new(vector_results, semantic_ratio); + + let merge_results = + ScoreWithRatioResult::merge(vector_results, keyword_results, self.offset, self.limit); + assert!(merge_results.documents_ids.len() <= self.limit); + Ok(merge_results) + } + + fn results_good_enough(&self, keyword_results: &SearchResult, semantic_ratio: f32) -> bool { + // A result is good enough if its keyword score is > 0.9 with a semantic ratio of 0.5 => 0.9 * 0.5 + const GOOD_ENOUGH_SCORE: f64 = 0.45; + + // 1. we check that we got a sufficient number of results + if keyword_results.document_scores.len() < self.limit + self.offset { + return false; + } + + // 2. and that all results have a good enough score. + // we need to check all results because due to sort like rules, they're not necessarily in relevancy order + for score in &keyword_results.document_scores { + let score = ScoreDetails::global_score(score.iter()); + if score * ((1.0 - semantic_ratio) as f64) < GOOD_ENOUGH_SCORE { + return false; + } + } + true + } +} diff --git a/milli/src/search/mod.rs b/milli/src/search/mod.rs index ee8cd1faf..3e4849578 100644 --- a/milli/src/search/mod.rs +++ b/milli/src/search/mod.rs @@ -12,12 +12,14 @@ use roaring::bitmap::RoaringBitmap; pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET}; pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords}; -use self::new::PartialSearchResult; +use self::new::{execute_vector_search, PartialSearchResult}; use crate::error::UserError; use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue}; use crate::score_details::{ScoreDetails, ScoringStrategy}; +use crate::vector::DistributionShift; use crate::{ - execute_search, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext, + execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, + Result, SearchContext, }; // Building these factories is not free. @@ -30,6 +32,7 @@ const MAX_NUMBER_OF_FACETS: usize = 100; pub mod facet; mod fst_utils; +pub mod hybrid; pub mod new; pub struct Search<'a> { @@ -46,8 +49,11 @@ pub struct Search<'a> { scoring_strategy: ScoringStrategy, words_limit: usize, exhaustive_number_hits: bool, + /// TODO: Add semantic ratio or pass it directly to execute_hybrid() rtxn: &'a heed::RoTxn<'a>, index: &'a Index, + distribution_shift: Option, + embedder_name: Option, } impl<'a> Search<'a> { @@ -67,6 +73,8 @@ impl<'a> Search<'a> { words_limit: 10, rtxn, index, + distribution_shift: None, + embedder_name: None, } } @@ -75,8 +83,8 @@ impl<'a> Search<'a> { self } - pub fn vector(&mut self, vector: impl Into>) -> &mut Search<'a> { - self.vector = Some(vector.into()); + pub fn vector(&mut self, vector: Vec) -> &mut Search<'a> { + self.vector = Some(vector); self } @@ -133,30 +141,75 @@ impl<'a> Search<'a> { self } + pub fn distribution_shift( + &mut self, + distribution_shift: Option, + ) -> &mut Search<'a> { + self.distribution_shift = distribution_shift; + self + } + + pub fn embedder_name(&mut self, embedder_name: impl Into) -> &mut Search<'a> { + self.embedder_name = Some(embedder_name.into()); + self + } + + pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result { + if has_vector_search { + let ctx = SearchContext::new(self.index, self.rtxn); + filtered_universe(&ctx, &self.filter) + } else { + Ok(self.execute()?.candidates) + } + } + pub fn execute(&self) -> Result { + let embedder_name; + let embedder_name = match &self.embedder_name { + Some(embedder_name) => embedder_name, + None => { + embedder_name = self.index.default_embedding_name(self.rtxn)?; + &embedder_name + } + }; + let mut ctx = SearchContext::new(self.index, self.rtxn); if let Some(searchable_attributes) = self.searchable_attributes { ctx.searchable_attributes(searchable_attributes)?; } + let universe = filtered_universe(&ctx, &self.filter)?; let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } = - execute_search( - &mut ctx, - &self.query, - &self.vector, - self.terms_matching_strategy, - self.scoring_strategy, - self.exhaustive_number_hits, - &self.filter, - &self.sort_criteria, - self.geo_strategy, - self.offset, - self.limit, - Some(self.words_limit), - &mut DefaultSearchLogger, - &mut DefaultSearchLogger, - )?; + match self.vector.as_ref() { + Some(vector) => execute_vector_search( + &mut ctx, + vector, + self.scoring_strategy, + universe, + &self.sort_criteria, + self.geo_strategy, + self.offset, + self.limit, + self.distribution_shift, + embedder_name, + )?, + None => execute_search( + &mut ctx, + self.query.as_deref(), + self.terms_matching_strategy, + self.scoring_strategy, + self.exhaustive_number_hits, + universe, + &self.sort_criteria, + self.geo_strategy, + self.offset, + self.limit, + Some(self.words_limit), + &mut DefaultSearchLogger, + &mut DefaultSearchLogger, + )?, + }; // consume context and located_query_terms to build MatchingWords. let matching_words = match located_query_terms { @@ -185,6 +238,8 @@ impl fmt::Debug for Search<'_> { exhaustive_number_hits, rtxn: _, index: _, + distribution_shift, + embedder_name, } = self; f.debug_struct("Search") .field("query", query) @@ -198,6 +253,8 @@ impl fmt::Debug for Search<'_> { .field("scoring_strategy", scoring_strategy) .field("exhaustive_number_hits", exhaustive_number_hits) .field("words_limit", words_limit) + .field("distribution_shift", distribution_shift) + .field("embedder_name", embedder_name) .finish() } } @@ -249,11 +306,16 @@ pub struct SearchForFacetValues<'a> { query: Option, facet: String, search_query: Search<'a>, + is_hybrid: bool, } impl<'a> SearchForFacetValues<'a> { - pub fn new(facet: String, search_query: Search<'a>) -> SearchForFacetValues<'a> { - SearchForFacetValues { query: None, facet, search_query } + pub fn new( + facet: String, + search_query: Search<'a>, + is_hybrid: bool, + ) -> SearchForFacetValues<'a> { + SearchForFacetValues { query: None, facet, search_query, is_hybrid } } pub fn query(&mut self, query: impl Into) -> &mut Self { @@ -303,7 +365,9 @@ impl<'a> SearchForFacetValues<'a> { None => return Ok(vec![]), }; - let search_candidates = self.search_query.execute()?.candidates; + let search_candidates = self + .search_query + .execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?; match self.query.as_ref() { Some(query) => { diff --git a/milli/src/search/new/geo_sort.rs b/milli/src/search/new/geo_sort.rs index b2e3a2f3d..5f5ceb379 100644 --- a/milli/src/search/new/geo_sort.rs +++ b/milli/src/search/new/geo_sort.rs @@ -107,12 +107,16 @@ impl GeoSort { /// Refill the internal buffer of cached docids based on the strategy. /// Drop the rtree if we don't need it anymore. - fn fill_buffer(&mut self, ctx: &mut SearchContext) -> Result<()> { + fn fill_buffer( + &mut self, + ctx: &mut SearchContext, + geo_candidates: &RoaringBitmap, + ) -> Result<()> { debug_assert!(self.field_ids.is_some(), "fill_buffer can't be called without the lat&lng"); debug_assert!(self.cached_sorted_docids.is_empty()); // lazily initialize the rtree if needed by the strategy, and cache it in `self.rtree` - let rtree = if self.strategy.use_rtree(self.geo_candidates.len() as usize) { + let rtree = if self.strategy.use_rtree(geo_candidates.len() as usize) { if let Some(rtree) = self.rtree.as_ref() { // get rtree from cache Some(rtree) @@ -131,7 +135,7 @@ impl GeoSort { if self.ascending { let point = lat_lng_to_xyz(&self.point); for point in rtree.nearest_neighbor_iter(&point) { - if self.geo_candidates.contains(point.data.0) { + if geo_candidates.contains(point.data.0) { self.cached_sorted_docids.push_back(point.data); if self.cached_sorted_docids.len() >= cache_size { break; @@ -143,7 +147,7 @@ impl GeoSort { // and we insert the points in reverse order they get reversed when emptying the cache later on let point = lat_lng_to_xyz(&opposite_of(self.point)); for point in rtree.nearest_neighbor_iter(&point) { - if self.geo_candidates.contains(point.data.0) { + if geo_candidates.contains(point.data.0) { self.cached_sorted_docids.push_front(point.data); if self.cached_sorted_docids.len() >= cache_size { break; @@ -155,8 +159,7 @@ impl GeoSort { // the iterative version let [lat, lng] = self.field_ids.unwrap(); - let mut documents = self - .geo_candidates + let mut documents = geo_candidates .iter() .map(|id| -> Result<_> { Ok((id, geo_value(id, lat, lng, ctx.index, ctx.txn)?)) }) .collect::>>()?; @@ -216,9 +219,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { assert!(self.query.is_none()); self.query = Some(query.clone()); - self.geo_candidates &= universe; - if self.geo_candidates.is_empty() { + let geo_candidates = &self.geo_candidates & universe; + + if geo_candidates.is_empty() { return Ok(()); } @@ -226,7 +230,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { let lat = fid_map.id("_geo.lat").expect("geo candidates but no fid for lat"); let lng = fid_map.id("_geo.lng").expect("geo candidates but no fid for lng"); self.field_ids = Some([lat, lng]); - self.fill_buffer(ctx)?; + self.fill_buffer(ctx, &geo_candidates)?; Ok(()) } @@ -238,9 +242,10 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { universe: &RoaringBitmap, ) -> Result>> { let query = self.query.as_ref().unwrap().clone(); - self.geo_candidates &= universe; - if self.geo_candidates.is_empty() { + let geo_candidates = &self.geo_candidates & universe; + + if geo_candidates.is_empty() { return Ok(Some(RankingRuleOutput { query, candidates: universe.clone(), @@ -261,7 +266,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { } }; while let Some((id, point)) = next(&mut self.cached_sorted_docids) { - if self.geo_candidates.contains(id) { + if geo_candidates.contains(id) { return Ok(Some(RankingRuleOutput { query, candidates: RoaringBitmap::from_iter([id]), @@ -276,7 +281,7 @@ impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for GeoSort { // if we got out of this loop it means we've exhausted our cache. // we need to refill it and run the function again. - self.fill_buffer(ctx)?; + self.fill_buffer(ctx, &geo_candidates)?; self.next_bucket(ctx, logger, universe) } diff --git a/milli/src/search/new/matches/mod.rs b/milli/src/search/new/matches/mod.rs index 5d61de0f4..067fa1efd 100644 --- a/milli/src/search/new/matches/mod.rs +++ b/milli/src/search/new/matches/mod.rs @@ -498,19 +498,19 @@ mod tests { use super::*; use crate::index::tests::TempIndex; - use crate::{execute_search, SearchContext}; + use crate::{execute_search, filtered_universe, SearchContext}; impl<'a> MatcherBuilder<'a> { fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self { let mut ctx = SearchContext::new(index, rtxn); + let universe = filtered_universe(&ctx, &None).unwrap(); let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search( &mut ctx, - &Some(query.to_string()), - &None, + Some(query), crate::TermsMatchingStrategy::default(), crate::score_details::ScoringStrategy::Skip, false, - &None, + universe, &None, crate::search::new::GeoSortStrategy::default(), 0, diff --git a/milli/src/search/new/mod.rs b/milli/src/search/new/mod.rs index a1b5da4e8..405b9747d 100644 --- a/milli/src/search/new/mod.rs +++ b/milli/src/search/new/mod.rs @@ -16,6 +16,7 @@ mod small_bitmap; mod exact_attribute; mod sort; +mod vector_sort; #[cfg(test)] mod tests; @@ -28,7 +29,6 @@ use db_cache::DatabaseCache; use exact_attribute::ExactAttribute; use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo}; use heed::RoTxn; -use instant_distance::Search; use interner::{DedupInterner, Interner}; pub use logger::visual::VisualSearchLogger; pub use logger::{DefaultSearchLogger, SearchLogger}; @@ -46,10 +46,11 @@ use self::geo_sort::GeoSort; pub use self::geo_sort::Strategy as GeoSortStrategy; use self::graph_based_ranking_rule::Words; use self::interner::Interned; -use crate::distance::NDotProductPoint; +use self::vector_sort::VectorSort; use crate::error::FieldIdMapMissingEntry; use crate::score_details::{ScoreDetails, ScoringStrategy}; use crate::search::new::distinct::apply_distinct_rule; +use crate::vector::DistributionShift; use crate::{ AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError, }; @@ -258,6 +259,80 @@ fn get_ranking_rules_for_placeholder_search<'ctx>( Ok(ranking_rules) } +fn get_ranking_rules_for_vector<'ctx>( + ctx: &SearchContext<'ctx>, + sort_criteria: &Option>, + geo_strategy: geo_sort::Strategy, + limit_plus_offset: usize, + target: &[f32], + distribution_shift: Option, + embedder_name: &str, +) -> Result>> { + // query graph search + + let mut sort = false; + let mut sorted_fields = HashSet::new(); + let mut geo_sorted = false; + + let mut vector = false; + let mut ranking_rules: Vec> = vec![]; + + let settings_ranking_rules = ctx.index.criteria(ctx.txn)?; + for rr in settings_ranking_rules { + match rr { + crate::Criterion::Words + | crate::Criterion::Typo + | crate::Criterion::Proximity + | crate::Criterion::Attribute + | crate::Criterion::Exactness => { + if !vector { + let vector_candidates = ctx.index.documents_ids(ctx.txn)?; + let vector_sort = VectorSort::new( + ctx, + target.to_vec(), + vector_candidates, + limit_plus_offset, + distribution_shift, + embedder_name, + )?; + ranking_rules.push(Box::new(vector_sort)); + vector = true; + } + } + crate::Criterion::Sort => { + if sort { + continue; + } + resolve_sort_criteria( + sort_criteria, + ctx, + &mut ranking_rules, + &mut sorted_fields, + &mut geo_sorted, + geo_strategy, + )?; + sort = true; + } + crate::Criterion::Asc(field_name) => { + if sorted_fields.contains(&field_name) { + continue; + } + sorted_fields.insert(field_name.clone()); + ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, true)?)); + } + crate::Criterion::Desc(field_name) => { + if sorted_fields.contains(&field_name) { + continue; + } + sorted_fields.insert(field_name.clone()); + ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, false)?)); + } + } + } + + Ok(ranking_rules) +} + /// Return the list of initialised ranking rules to be used for a query graph search. fn get_ranking_rules_for_query_graph_search<'ctx>( ctx: &SearchContext<'ctx>, @@ -422,15 +497,72 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>( Ok(()) } +pub fn filtered_universe(ctx: &SearchContext, filters: &Option) -> Result { + Ok(if let Some(filters) = filters { + filters.evaluate(ctx.txn, ctx.index)? + } else { + ctx.index.documents_ids(ctx.txn)? + }) +} + +#[allow(clippy::too_many_arguments)] +pub fn execute_vector_search( + ctx: &mut SearchContext, + vector: &[f32], + scoring_strategy: ScoringStrategy, + universe: RoaringBitmap, + sort_criteria: &Option>, + geo_strategy: geo_sort::Strategy, + from: usize, + length: usize, + distribution_shift: Option, + embedder_name: &str, +) -> Result { + check_sort_criteria(ctx, sort_criteria.as_ref())?; + + // FIXME: input universe = universe & documents_with_vectors + // for now if we're computing embeddings for ALL documents, we can assume that this is just universe + let ranking_rules = get_ranking_rules_for_vector( + ctx, + sort_criteria, + geo_strategy, + from + length, + vector, + distribution_shift, + embedder_name, + )?; + + let mut placeholder_search_logger = logger::DefaultSearchLogger; + let placeholder_search_logger: &mut dyn SearchLogger = + &mut placeholder_search_logger; + + let BucketSortOutput { docids, scores, all_candidates } = bucket_sort( + ctx, + ranking_rules, + &PlaceholderQuery, + &universe, + from, + length, + scoring_strategy, + placeholder_search_logger, + )?; + + Ok(PartialSearchResult { + candidates: all_candidates, + document_scores: scores, + documents_ids: docids, + located_query_terms: None, + }) +} + #[allow(clippy::too_many_arguments)] pub fn execute_search( ctx: &mut SearchContext, - query: &Option, - vector: &Option>, + query: Option<&str>, terms_matching_strategy: TermsMatchingStrategy, scoring_strategy: ScoringStrategy, exhaustive_number_hits: bool, - filters: &Option, + mut universe: RoaringBitmap, sort_criteria: &Option>, geo_strategy: geo_sort::Strategy, from: usize, @@ -439,60 +571,8 @@ pub fn execute_search( placeholder_search_logger: &mut dyn SearchLogger, query_graph_logger: &mut dyn SearchLogger, ) -> Result { - let mut universe = if let Some(filters) = filters { - filters.evaluate(ctx.txn, ctx.index)? - } else { - ctx.index.documents_ids(ctx.txn)? - }; - check_sort_criteria(ctx, sort_criteria.as_ref())?; - if let Some(vector) = vector { - let mut search = Search::default(); - let docids = match ctx.index.vector_hnsw(ctx.txn)? { - Some(hnsw) => { - if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() { - if vector.len() != expected_size { - return Err(UserError::InvalidVectorDimensions { - expected: expected_size, - found: vector.len(), - } - .into()); - } - } - - let vector = NDotProductPoint::new(vector.clone()); - - let neighbors = hnsw.search(&vector, &mut search); - - let mut docids = Vec::new(); - let mut uniq_docids = RoaringBitmap::new(); - for instant_distance::Item { distance: _, pid, point: _ } in neighbors { - let index = pid.into_inner(); - let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap(); - if universe.contains(docid) && uniq_docids.insert(docid) { - docids.push(docid); - if docids.len() == (from + length) { - break; - } - } - } - - // return the nearest documents that are also part of the candidates - // along with a dummy list of scores that are useless in this context. - docids.into_iter().skip(from).take(length).collect() - } - None => Vec::new(), - }; - - return Ok(PartialSearchResult { - candidates: universe, - document_scores: vec![Vec::new(); docids.len()], - documents_ids: docids, - located_query_terms: None, - }); - } - let mut located_query_terms = None; let query_terms = if let Some(query) = query { // We make sure that the analyzer is aware of the stop words @@ -546,7 +626,7 @@ pub fn execute_search( terms_matching_strategy, )?; - universe = + universe &= resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?; bucket_sort( diff --git a/milli/src/search/new/vector_sort.rs b/milli/src/search/new/vector_sort.rs new file mode 100644 index 000000000..b29a72827 --- /dev/null +++ b/milli/src/search/new/vector_sort.rs @@ -0,0 +1,170 @@ +use std::iter::FromIterator; + +use ordered_float::OrderedFloat; +use roaring::RoaringBitmap; + +use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait}; +use crate::score_details::{self, ScoreDetails}; +use crate::vector::DistributionShift; +use crate::{DocumentId, Result, SearchContext, SearchLogger}; + +pub struct VectorSort { + query: Option, + target: Vec, + vector_candidates: RoaringBitmap, + cached_sorted_docids: std::vec::IntoIter<(DocumentId, f32, Vec)>, + limit: usize, + distribution_shift: Option, + embedder_index: u8, +} + +impl VectorSort { + pub fn new( + ctx: &SearchContext, + target: Vec, + vector_candidates: RoaringBitmap, + limit: usize, + distribution_shift: Option, + embedder_name: &str, + ) -> Result { + let embedder_index = ctx + .index + .embedder_category_id + .get(ctx.txn, embedder_name)? + .ok_or_else(|| crate::UserError::InvalidEmbedder(embedder_name.to_owned()))?; + + Ok(Self { + query: None, + target, + vector_candidates, + cached_sorted_docids: Default::default(), + limit, + distribution_shift, + embedder_index, + }) + } + + fn fill_buffer( + &mut self, + ctx: &mut SearchContext<'_>, + vector_candidates: &RoaringBitmap, + ) -> Result<()> { + let writer_index = (self.embedder_index as u16) << 8; + let readers: std::result::Result, _> = (0..=u8::MAX) + .map_while(|k| { + arroy::Reader::open(ctx.txn, writer_index | (k as u16), ctx.index.vector_arroy) + .map(Some) + .or_else(|e| match e { + arroy::Error::MissingMetadata => Ok(None), + e => Err(e), + }) + .transpose() + }) + .collect(); + + let readers = readers?; + + let target = &self.target; + let mut results = Vec::new(); + + for reader in readers.iter() { + let nns_by_vector = + reader.nns_by_vector(ctx.txn, target, self.limit, None, Some(vector_candidates))?; + let vectors: std::result::Result, _> = nns_by_vector + .iter() + .map(|(docid, _)| reader.item_vector(ctx.txn, *docid).transpose().unwrap()) + .collect(); + let vectors = vectors?; + results.extend(nns_by_vector.into_iter().zip(vectors).map(|((x, y), z)| (x, y, z))); + } + results.sort_unstable_by_key(|(_, distance, _)| OrderedFloat(*distance)); + self.cached_sorted_docids = results.into_iter(); + + Ok(()) + } +} + +impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort { + fn id(&self) -> String { + "vector_sort".to_owned() + } + + fn start_iteration( + &mut self, + ctx: &mut SearchContext<'ctx>, + _logger: &mut dyn SearchLogger, + universe: &RoaringBitmap, + query: &Q, + ) -> Result<()> { + assert!(self.query.is_none()); + + self.query = Some(query.clone()); + let vector_candidates = &self.vector_candidates & universe; + self.fill_buffer(ctx, &vector_candidates)?; + Ok(()) + } + + #[allow(clippy::only_used_in_recursion)] + fn next_bucket( + &mut self, + ctx: &mut SearchContext<'ctx>, + _logger: &mut dyn SearchLogger, + universe: &RoaringBitmap, + ) -> Result>> { + let query = self.query.as_ref().unwrap().clone(); + let vector_candidates = &self.vector_candidates & universe; + + if vector_candidates.is_empty() { + return Ok(Some(RankingRuleOutput { + query, + candidates: universe.clone(), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: self.target.clone(), + value_similarity: None, + }), + })); + } + + for (docid, distance, vector) in self.cached_sorted_docids.by_ref() { + if vector_candidates.contains(docid) { + let score = 1.0 - distance; + let score = self + .distribution_shift + .map(|distribution| distribution.shift(score)) + .unwrap_or(score); + return Ok(Some(RankingRuleOutput { + query, + candidates: RoaringBitmap::from_iter([docid]), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: self.target.clone(), + value_similarity: Some((vector, score)), + }), + })); + } + } + + // if we got out of this loop it means we've exhausted our cache. + // we need to refill it and run the function again. + self.fill_buffer(ctx, &vector_candidates)?; + + // we tried filling the buffer, but it remained empty 😢 + // it means we don't actually have any document remaining in the universe with a vector. + // => exit + if self.cached_sorted_docids.len() == 0 { + return Ok(Some(RankingRuleOutput { + query, + candidates: universe.clone(), + score: ScoreDetails::Vector(score_details::Vector { + target_vector: self.target.clone(), + value_similarity: None, + }), + })); + } + + self.next_bucket(ctx, _logger, universe) + } + + fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger) { + self.query = None; + } +} diff --git a/milli/src/update/clear_documents.rs b/milli/src/update/clear_documents.rs index 59adda3e8..a6c7ff2b1 100644 --- a/milli/src/update/clear_documents.rs +++ b/milli/src/update/clear_documents.rs @@ -42,7 +42,8 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { facet_id_is_empty_docids, field_id_docid_facet_f64s, field_id_docid_facet_strings, - vector_id_docid, + vector_arroy, + embedder_category_id: _, documents, } = self.index; @@ -58,7 +59,6 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { self.index.put_field_distribution(self.wtxn, &FieldDistribution::default())?; self.index.delete_geo_rtree(self.wtxn)?; self.index.delete_geo_faceted_documents_ids(self.wtxn)?; - self.index.delete_vector_hnsw(self.wtxn)?; // Clear the other databases. external_documents_ids.clear(self.wtxn)?; @@ -82,7 +82,9 @@ impl<'t, 'i> ClearDocuments<'t, 'i> { facet_id_string_docids.clear(self.wtxn)?; field_id_docid_facet_f64s.clear(self.wtxn)?; field_id_docid_facet_strings.clear(self.wtxn)?; - vector_id_docid.clear(self.wtxn)?; + // vector + vector_arroy.clear(self.wtxn)?; + documents.clear(self.wtxn)?; Ok(number_of_documents) diff --git a/milli/src/update/index_documents/extract/extract_vector_points.rs b/milli/src/update/index_documents/extract/extract_vector_points.rs index 317a9aec3..3a0376511 100644 --- a/milli/src/update/index_documents/extract/extract_vector_points.rs +++ b/milli/src/update/index_documents/extract/extract_vector_points.rs @@ -1,9 +1,10 @@ use std::cmp::Ordering; -use std::convert::TryFrom; +use std::convert::{TryFrom, TryInto}; use std::fs::File; use std::io::{self, BufReader, BufWriter}; use std::mem::size_of; use std::str::from_utf8; +use std::sync::Arc; use bytemuck::cast_slice; use grenad::Writer; @@ -13,13 +14,56 @@ use serde_json::{from_slice, Value}; use super::helpers::{create_writer, writer_into_reader, GrenadParameters}; use crate::error::UserError; +use crate::prompt::Prompt; use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd}; use crate::update::index_documents::helpers::try_split_at; -use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors}; +use crate::vector::Embedder; +use crate::{DocumentId, FieldsIdsMap, InternalError, Result, VectorOrArrayOfVectors}; /// The length of the elements that are always in the buffer when inserting new values. const TRUNCATE_SIZE: usize = size_of::(); +pub struct ExtractedVectorPoints { + // docid, _index -> KvWriterDelAdd -> Vector + pub manual_vectors: grenad::Reader>, + // docid -> () + pub remove_vectors: grenad::Reader>, + // docid -> prompt + pub prompts: grenad::Reader>, +} + +enum VectorStateDelta { + NoChange, + // Remove all vectors, generated or manual, from this document + NowRemoved, + + // Add the manually specified vectors, passed in the other grenad + // Remove any previously generated vectors + // Note: changing the value of the manually specified vector **should not record** this delta + WasGeneratedNowManual(Vec>), + + ManualDelta(Vec>, Vec>), + + // Add the vector computed from the specified prompt + // Remove any previous vector + // Note: changing the value of the prompt **does require** recording this delta + NowGenerated(String), +} + +impl VectorStateDelta { + fn into_values(self) -> (bool, String, (Vec>, Vec>)) { + match self { + VectorStateDelta::NoChange => Default::default(), + VectorStateDelta::NowRemoved => (true, Default::default(), Default::default()), + VectorStateDelta::WasGeneratedNowManual(add) => { + (true, Default::default(), (Default::default(), add)) + } + VectorStateDelta::ManualDelta(del, add) => (false, Default::default(), (del, add)), + VectorStateDelta::NowGenerated(prompt) => (true, prompt, Default::default()), + } + } +} + /// Extracts the embedding vector contained in each document under the `_vectors` field. /// /// Returns the generated grenad reader containing the docid as key associated to the Vec @@ -27,16 +71,35 @@ const TRUNCATE_SIZE: usize = size_of::(); pub fn extract_vector_points( obkv_documents: grenad::Reader, indexer: GrenadParameters, - vectors_fid: FieldId, -) -> Result>> { + field_id_map: &FieldsIdsMap, + prompt: &Prompt, + embedder_name: &str, +) -> Result { puffin::profile_function!(); - let mut writer = create_writer( + // (docid, _index) -> KvWriterDelAdd -> Vector + let mut manual_vectors_writer = create_writer( indexer.chunk_compression_type, indexer.chunk_compression_level, tempfile::tempfile()?, ); + // (docid) -> (prompt) + let mut prompts_writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + // (docid) -> () + let mut remove_vectors_writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + let vectors_fid = field_id_map.id("_vectors"); + let mut key_buffer = Vec::new(); let mut cursor = obkv_documents.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? { @@ -53,43 +116,157 @@ pub fn extract_vector_points( // lazily get it when needed let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() }; - // first we retrieve the _vectors field - if let Some(value) = obkv.get(vectors_fid) { - let vectors_obkv = KvReaderDelAdd::new(value); + let vectors_field = vectors_fid + .and_then(|vectors_fid| obkv.get(vectors_fid)) + .map(KvReaderDelAdd::new) + .map(|obkv| to_vector_maps(obkv, document_id)) + .transpose()?; - // then we extract the values - let del_vectors = vectors_obkv - .get(DelAdd::Deletion) - .map(|vectors| extract_vectors(vectors, document_id)) - .transpose()? - .flatten(); - let add_vectors = vectors_obkv - .get(DelAdd::Addition) - .map(|vectors| extract_vectors(vectors, document_id)) - .transpose()? - .flatten(); + let (del_map, add_map) = vectors_field.unzip(); + let del_map = del_map.flatten(); + let add_map = add_map.flatten(); - // and we finally push the unique vectors into the writer - push_vectors_diff( - &mut writer, - &mut key_buffer, - del_vectors.unwrap_or_default(), - add_vectors.unwrap_or_default(), - )?; - } + let del_value = del_map.and_then(|mut map| map.remove(embedder_name)); + let add_value = add_map.and_then(|mut map| map.remove(embedder_name)); + + let delta = match (del_value, add_value) { + (Some(old), Some(new)) => { + // no autogeneration + let del_vectors = extract_vectors(old, document_id, embedder_name)?; + let add_vectors = extract_vectors(new, document_id, embedder_name)?; + + if add_vectors.len() > u8::MAX.into() { + return Err(crate::Error::UserError(crate::UserError::TooManyVectors( + document_id().to_string(), + add_vectors.len(), + ))); + } + + VectorStateDelta::ManualDelta(del_vectors, add_vectors) + } + (Some(_old), None) => { + // Do we keep this document? + let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + if document_is_kept { + // becomes autogenerated + VectorStateDelta::NowGenerated(prompt.render( + obkv, + DelAdd::Addition, + field_id_map, + )?) + } else { + VectorStateDelta::NowRemoved + } + } + (None, Some(new)) => { + // was possibly autogenerated, remove all vectors for that document + let add_vectors = extract_vectors(new, document_id, embedder_name)?; + if add_vectors.len() > u8::MAX.into() { + return Err(crate::Error::UserError(crate::UserError::TooManyVectors( + document_id().to_string(), + add_vectors.len(), + ))); + } + + VectorStateDelta::WasGeneratedNowManual(add_vectors) + } + (None, None) => { + // Do we keep this document? + let document_is_kept = obkv + .iter() + .map(|(_, deladd)| KvReaderDelAdd::new(deladd)) + .any(|deladd| deladd.get(DelAdd::Addition).is_some()); + + if document_is_kept { + // Don't give up if the old prompt was failing + let old_prompt = + prompt.render(obkv, DelAdd::Deletion, field_id_map).unwrap_or_default(); + let new_prompt = prompt.render(obkv, DelAdd::Addition, field_id_map)?; + if old_prompt != new_prompt { + log::trace!( + "🚀 Changing prompt from\n{old_prompt}\n===to===\n{new_prompt}" + ); + VectorStateDelta::NowGenerated(new_prompt) + } else { + log::trace!("⏭️ Prompt unmodified, skipping"); + VectorStateDelta::NoChange + } + } else { + VectorStateDelta::NowRemoved + } + } + }; + + // and we finally push the unique vectors into the writer + push_vectors_diff( + &mut remove_vectors_writer, + &mut prompts_writer, + &mut manual_vectors_writer, + &mut key_buffer, + delta, + )?; } - writer_into_reader(writer) + Ok(ExtractedVectorPoints { + // docid, _index -> KvWriterDelAdd -> Vector + manual_vectors: writer_into_reader(manual_vectors_writer)?, + // docid -> () + remove_vectors: writer_into_reader(remove_vectors_writer)?, + // docid -> prompt + prompts: writer_into_reader(prompts_writer)?, + }) +} + +fn to_vector_maps( + obkv: KvReaderDelAdd, + document_id: impl Fn() -> Value, +) -> Result<(Option>, Option>)> { + let del = to_vector_map(obkv, DelAdd::Deletion, &document_id)?; + let add = to_vector_map(obkv, DelAdd::Addition, &document_id)?; + Ok((del, add)) +} + +fn to_vector_map( + obkv: KvReaderDelAdd, + side: DelAdd, + document_id: &impl Fn() -> Value, +) -> Result>> { + Ok(if let Some(value) = obkv.get(side) { + let Ok(value) = from_slice(value) else { + let value = from_slice(value).map_err(InternalError::SerdeJson)?; + return Err(crate::Error::UserError(UserError::InvalidVectorsMapType { + document_id: document_id(), + value, + })); + }; + Some(value) + } else { + None + }) } /// Computes the diff between both Del and Add numbers and /// only inserts the parts that differ in the sorter. fn push_vectors_diff( - writer: &mut Writer>, + remove_vectors_writer: &mut Writer>, + prompts_writer: &mut Writer>, + manual_vectors_writer: &mut Writer>, key_buffer: &mut Vec, - mut del_vectors: Vec>, - mut add_vectors: Vec>, + delta: VectorStateDelta, ) -> Result<()> { + let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values(); + if must_remove { + key_buffer.truncate(TRUNCATE_SIZE); + remove_vectors_writer.insert(&key_buffer, [])?; + } + if !prompt.is_empty() { + key_buffer.truncate(TRUNCATE_SIZE); + prompts_writer.insert(&key_buffer, prompt.as_bytes())?; + } + // We sort and dedup the vectors del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b)); @@ -114,7 +291,7 @@ fn push_vectors_diff( let mut obkv = KvWriterDelAdd::memory(); obkv.insert(DelAdd::Deletion, cast_slice(&vector))?; let bytes = obkv.into_inner()?; - writer.insert(&key_buffer, bytes)?; + manual_vectors_writer.insert(&key_buffer, bytes)?; } EitherOrBoth::Right(vector) => { // We insert only the Add part of the Obkv to inform @@ -122,7 +299,7 @@ fn push_vectors_diff( let mut obkv = KvWriterDelAdd::memory(); obkv.insert(DelAdd::Addition, cast_slice(&vector))?; let bytes = obkv.into_inner()?; - writer.insert(&key_buffer, bytes)?; + manual_vectors_writer.insert(&key_buffer, bytes)?; } } } @@ -136,13 +313,112 @@ fn compare_vectors(a: &[f32], b: &[f32]) -> Ordering { } /// Extracts the vectors from a JSON value. -fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result>>> { - match from_slice(value) { - Ok(vectors) => Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors)), +fn extract_vectors( + value: Value, + document_id: impl Fn() -> Value, + name: &str, +) -> Result>> { + // FIXME: ugly clone of the vectors here + match serde_json::from_value(value.clone()) { + Ok(vectors) => { + Ok(VectorOrArrayOfVectors::into_array_of_vectors(vectors).unwrap_or_default()) + } Err(_) => Err(UserError::InvalidVectorsType { document_id: document_id(), - value: from_slice(value).map_err(InternalError::SerdeJson)?, + value, + subfield: name.to_owned(), } .into()), } } + +#[logging_timer::time] +pub fn extract_embeddings( + // docid, prompt + prompt_reader: grenad::Reader, + indexer: GrenadParameters, + embedder: Arc, +) -> Result>> { + let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?; + + let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism + let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk + + // docid, state with embedding + let mut state_writer = create_writer( + indexer.chunk_compression_type, + indexer.chunk_compression_level, + tempfile::tempfile()?, + ); + + let mut chunks = Vec::with_capacity(n_chunks); + let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk); + let mut current_chunk_ids = Vec::with_capacity(n_vectors_per_chunk); + let mut chunks_ids = Vec::with_capacity(n_chunks); + let mut cursor = prompt_reader.into_cursor()?; + + while let Some((key, value)) = cursor.move_on_next()? { + let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + // SAFETY: precondition, the grenad value was saved from a string + let prompt = unsafe { std::str::from_utf8_unchecked(value) }; + if current_chunk.len() == current_chunk.capacity() { + chunks.push(std::mem::replace( + &mut current_chunk, + Vec::with_capacity(n_vectors_per_chunk), + )); + chunks_ids.push(std::mem::replace( + &mut current_chunk_ids, + Vec::with_capacity(n_vectors_per_chunk), + )); + }; + current_chunk.push(prompt.to_owned()); + current_chunk_ids.push(docid); + + if chunks.len() == chunks.capacity() { + let chunked_embeds = rt + .block_on( + embedder + .embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))), + ) + .map_err(crate::vector::Error::from) + .map_err(crate::Error::from)?; + + for (docid, embeddings) in chunks_ids + .iter() + .flat_map(|docids| docids.iter()) + .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) + { + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + } + chunks_ids.clear(); + } + } + + // send last chunk + if !chunks.is_empty() { + let chunked_embeds = rt + .block_on(embedder.embed_chunks(std::mem::take(&mut chunks))) + .map_err(crate::vector::Error::from) + .map_err(crate::Error::from)?; + for (docid, embeddings) in chunks_ids + .iter() + .flat_map(|docids| docids.iter()) + .zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter())) + { + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + } + } + + if !current_chunk.is_empty() { + let embeds = rt + .block_on(embedder.embed(std::mem::take(&mut current_chunk))) + .map_err(crate::vector::Error::from) + .map_err(crate::Error::from)?; + + for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) { + state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?; + } + } + + writer_into_reader(state_writer) +} diff --git a/milli/src/update/index_documents/extract/mod.rs b/milli/src/update/index_documents/extract/mod.rs index 57f349894..1d06849de 100644 --- a/milli/src/update/index_documents/extract/mod.rs +++ b/milli/src/update/index_documents/extract/mod.rs @@ -23,7 +23,9 @@ use self::extract_facet_string_docids::extract_facet_string_docids; use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues}; use self::extract_fid_word_count_docids::extract_fid_word_count_docids; use self::extract_geo_points::extract_geo_points; -use self::extract_vector_points::extract_vector_points; +use self::extract_vector_points::{ + extract_embeddings, extract_vector_points, ExtractedVectorPoints, +}; use self::extract_word_docids::extract_word_docids; use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids; use self::extract_word_position_docids::extract_word_position_docids; @@ -33,7 +35,8 @@ use super::helpers::{ }; use super::{helpers, TypedChunk}; use crate::proximity::ProximityPrecision; -use crate::{FieldId, Result}; +use crate::vector::EmbeddingConfigs; +use crate::{FieldId, FieldsIdsMap, Result}; /// Extract data for each databases from obkv documents in parallel. /// Send data in grenad file over provided Sender. @@ -47,13 +50,14 @@ pub(crate) fn data_from_obkv_documents( faceted_fields: HashSet, primary_key_id: FieldId, geo_fields_ids: Option<(FieldId, FieldId)>, - vectors_field_id: Option, + field_id_map: FieldsIdsMap, stop_words: Option>, allowed_separators: Option<&[&str]>, dictionary: Option<&[&str]>, max_positions_per_attributes: Option, exact_attributes: HashSet, proximity_precision: ProximityPrecision, + embedders: EmbeddingConfigs, ) -> Result<()> { puffin::profile_function!(); @@ -64,7 +68,8 @@ pub(crate) fn data_from_obkv_documents( original_documents_chunk, indexer, lmdb_writer_sx.clone(), - vectors_field_id, + field_id_map.clone(), + embedders.clone(), ) }) .collect::>()?; @@ -276,24 +281,53 @@ fn send_original_documents_data( original_documents_chunk: Result>>, indexer: GrenadParameters, lmdb_writer_sx: Sender>, - vectors_field_id: Option, + field_id_map: FieldsIdsMap, + embedders: EmbeddingConfigs, ) -> Result<()> { let original_documents_chunk = original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?; - if let Some(vectors_field_id) = vectors_field_id { - let documents_chunk_cloned = original_documents_chunk.clone(); - let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); - rayon::spawn(move || { - let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id); - let _ = match result { - Ok(vector_points) => { - lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points))) + let documents_chunk_cloned = original_documents_chunk.clone(); + let lmdb_writer_sx_cloned = lmdb_writer_sx.clone(); + rayon::spawn(move || { + for (name, (embedder, prompt)) in embedders { + let result = extract_vector_points( + documents_chunk_cloned.clone(), + indexer, + &field_id_map, + &prompt, + &name, + ); + match result { + Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => { + let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) { + Ok(results) => Some(results), + Err(error) => { + let _ = lmdb_writer_sx_cloned.send(Err(error)); + None + } + }; + + if !(remove_vectors.is_empty() + && manual_vectors.is_empty() + && embeddings.as_ref().map_or(true, |e| e.is_empty())) + { + let _ = lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints { + remove_vectors, + embeddings, + expected_dimension: embedder.dimensions(), + manual_vectors, + embedder_name: name, + })); + } } - Err(error) => lmdb_writer_sx_cloned.send(Err(error)), - }; - }); - } + + Err(error) => { + let _ = lmdb_writer_sx_cloned.send(Err(error)); + } + } + } + }); // TODO: create a custom internal error lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap(); diff --git a/milli/src/update/index_documents/mod.rs b/milli/src/update/index_documents/mod.rs index f825cad1c..ffc3f6b3a 100644 --- a/milli/src/update/index_documents/mod.rs +++ b/milli/src/update/index_documents/mod.rs @@ -4,7 +4,7 @@ mod helpers; mod transform; mod typed_chunk; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use std::io::{Cursor, Read, Seek}; use std::iter::FromIterator; use std::num::NonZeroU32; @@ -14,6 +14,7 @@ use crossbeam_channel::{Receiver, Sender}; use heed::types::Str; use heed::Database; use log::debug; +use rand::SeedableRng; use roaring::RoaringBitmap; use serde::{Deserialize, Serialize}; use slice_group_by::GroupBy; @@ -36,6 +37,7 @@ pub use crate::update::index_documents::helpers::CursorClonableMmap; use crate::update::{ IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst, }; +use crate::vector::EmbeddingConfigs; use crate::{CboRoaringBitmapCodec, Index, Result}; static MERGED_DATABASE_COUNT: usize = 7; @@ -78,6 +80,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> { should_abort: FA, added_documents: u64, deleted_documents: u64, + embedders: EmbeddingConfigs, } #[derive(Default, Debug, Clone)] @@ -121,6 +124,7 @@ where index, added_documents: 0, deleted_documents: 0, + embedders: Default::default(), }) } @@ -167,6 +171,11 @@ where Ok((self, Ok(indexed_documents))) } + pub fn with_embedders(mut self, embedders: EmbeddingConfigs) -> Self { + self.embedders = embedders; + self + } + /// Remove a batch of documents from the current builder. /// /// Returns the number of documents deleted from the builder. @@ -322,17 +331,18 @@ where // get filterable fields for facet databases let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?; // get the fid of the `_geo.lat` and `_geo.lng` fields. - let geo_fields_ids = match self.index.fields_ids_map(self.wtxn)?.id("_geo") { + let mut field_id_map = self.index.fields_ids_map(self.wtxn)?; + + // self.index.fields_ids_map($a)? ==>> field_id_map + let geo_fields_ids = match field_id_map.id("_geo") { Some(gfid) => { let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid); let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid); // if `_geo` is faceted then we get the `lat` and `lng` if is_sortable || is_filterable { - let field_ids = self - .index - .fields_ids_map(self.wtxn)? + let field_ids = field_id_map .insert("_geo.lat") - .zip(self.index.fields_ids_map(self.wtxn)?.insert("_geo.lng")) + .zip(field_id_map.insert("_geo.lng")) .ok_or(UserError::AttributeLimitReached)?; Some(field_ids) } else { @@ -341,8 +351,6 @@ where } None => None, }; - // get the fid of the `_vectors` field. - let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors"); let stop_words = self.index.stop_words(self.wtxn)?; let separators = self.index.allowed_separators(self.wtxn)?; @@ -364,6 +372,8 @@ where self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes; + let cloned_embedder = self.embedders.clone(); + // Run extraction pipeline in parallel. pool.install(|| { puffin::profile_scope!("extract_and_send_grenad_chunks"); @@ -387,13 +397,14 @@ where faceted_fields, primary_key_id, geo_fields_ids, - vectors_field_id, + field_id_map, stop_words, separators.as_deref(), dictionary.as_deref(), max_positions_per_attributes, exact_attributes, proximity_precision, + cloned_embedder, ) }); @@ -402,7 +413,7 @@ where } // needs to be dropped to avoid channel waiting lock. - drop(lmdb_writer_sx) + drop(lmdb_writer_sx); }); let index_is_empty = self.index.number_of_documents(self.wtxn)? == 0; @@ -419,6 +430,8 @@ where let mut word_docids = None; let mut exact_word_docids = None; + let mut dimension = HashMap::new(); + for result in lmdb_writer_rx { if (self.should_abort)() { return Err(Error::InternalError(InternalError::AbortedIndexation)); @@ -448,6 +461,22 @@ where word_position_docids = Some(cloneable_chunk); TypedChunk::WordPositionDocids(chunk) } + TypedChunk::VectorPoints { + expected_dimension, + remove_vectors, + embeddings, + manual_vectors, + embedder_name, + } => { + dimension.insert(embedder_name.clone(), expected_dimension); + TypedChunk::VectorPoints { + remove_vectors, + embeddings, + expected_dimension, + manual_vectors, + embedder_name, + } + } otherwise => otherwise, }; @@ -480,6 +509,33 @@ where // We write the primary key field id into the main database self.index.put_primary_key(self.wtxn, &primary_key)?; let number_of_documents = self.index.number_of_documents(self.wtxn)?; + let mut rng = rand::rngs::StdRng::seed_from_u64(42); + + for (embedder_name, dimension) in dimension { + let wtxn = &mut *self.wtxn; + let vector_arroy = self.index.vector_arroy; + + let embedder_index = self.index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or( + InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, + )?; + + pool.install(|| { + let writer_index = (embedder_index as u16) << 8; + for k in 0..=u8::MAX { + let writer = arroy::Writer::prepare( + wtxn, + vector_arroy, + writer_index | (k as u16), + dimension, + )?; + if writer.is_empty(wtxn)? { + break; + } + writer.build(wtxn, &mut rng, None)?; + } + Result::Ok(()) + })?; + } self.execute_prefix_databases( word_docids, @@ -694,6 +750,8 @@ fn execute_word_prefix_docids( #[cfg(test)] mod tests { + use std::collections::BTreeMap; + use big_s::S; use fst::IntoStreamer; use heed::RwTxn; @@ -703,6 +761,7 @@ mod tests { use crate::documents::documents_batch_reader_from_objects; use crate::index::tests::TempIndex; use crate::search::TermsMatchingStrategy; + use crate::update::Setting; use crate::{db_snap, Filter, Search}; #[test] @@ -2494,18 +2553,39 @@ mod tests { /// Vectors must be of the same length. #[test] fn test_multiple_vectors() { + use crate::vector::settings::{EmbedderSettings, EmbeddingSettings}; let index = TempIndex::new(); - index.add_documents(documents!([{"id": 0, "_vectors": [[0, 1, 2], [3, 4, 5]] }])).unwrap(); - index.add_documents(documents!([{"id": 1, "_vectors": [6, 7, 8] }])).unwrap(); + index + .update_settings(|settings| { + let mut embedders = BTreeMap::default(); + embedders.insert( + "manual".to_string(), + Setting::Set(EmbeddingSettings { + embedder_options: Setting::Set(EmbedderSettings::UserProvided( + crate::vector::settings::UserProvidedSettings { dimensions: 3 }, + )), + document_template: Setting::NotSet, + }), + ); + settings.set_embedder_settings(embedders); + }) + .unwrap(); + index .add_documents( - documents!([{"id": 2, "_vectors": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }]), + documents!([{"id": 0, "_vectors": { "manual": [[0, 1, 2], [3, 4, 5]] } }]), + ) + .unwrap(); + index.add_documents(documents!([{"id": 1, "_vectors": { "manual": [6, 7, 8] }}])).unwrap(); + index + .add_documents( + documents!([{"id": 2, "_vectors": { "manual": [[9, 10, 11], [12, 13, 14], [15, 16, 17]] }}]), ) .unwrap(); let rtxn = index.read_txn().unwrap(); - let res = index.search(&rtxn).vector([0.0, 1.0, 2.0]).execute().unwrap(); + let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap(); assert_eq!(res.documents_ids.len(), 3); } diff --git a/milli/src/update/index_documents/typed_chunk.rs b/milli/src/update/index_documents/typed_chunk.rs index 49e36b87e..f8fb30c7b 100644 --- a/milli/src/update/index_documents/typed_chunk.rs +++ b/milli/src/update/index_documents/typed_chunk.rs @@ -1,4 +1,4 @@ -use std::collections::{HashMap, HashSet}; +use std::collections::HashMap; use std::convert::TryInto; use std::fs::File; use std::io::{self, BufReader}; @@ -8,9 +8,7 @@ use charabia::{Language, Script}; use grenad::MergerBuilder; use heed::types::Bytes; use heed::{PutFlags, RwTxn}; -use log::error; use obkv::{KvReader, KvWriter}; -use ordered_float::OrderedFloat; use roaring::RoaringBitmap; use super::helpers::{ @@ -18,16 +16,15 @@ use super::helpers::{ valid_lmdb_key, CursorClonableMmap, }; use super::{ClonableMmap, MergeFn}; -use crate::distance::NDotProductPoint; -use crate::error::UserError; use crate::external_documents_ids::{DocumentOperation, DocumentOperationKind}; use crate::facet::FacetType; use crate::index::db_name::DOCUMENTS; -use crate::index::Hnsw; use crate::update::del_add::{deladd_serialize_add_side, DelAdd, KvReaderDelAdd}; use crate::update::facet::FacetsUpdate; use crate::update::index_documents::helpers::{as_cloneable_grenad, try_split_array_at}; -use crate::{lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, Result, SerializationError}; +use crate::{ + lat_lng_to_xyz, DocumentId, FieldId, GeoPoint, Index, InternalError, Result, SerializationError, +}; pub(crate) enum TypedChunk { FieldIdDocidFacetStrings(grenad::Reader), @@ -47,7 +44,13 @@ pub(crate) enum TypedChunk { FieldIdFacetIsNullDocids(grenad::Reader>), FieldIdFacetIsEmptyDocids(grenad::Reader>), GeoPoints(grenad::Reader>), - VectorPoints(grenad::Reader>), + VectorPoints { + remove_vectors: grenad::Reader>, + embeddings: Option>>, + expected_dimension: usize, + manual_vectors: grenad::Reader>, + embedder_name: String, + }, ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>), } @@ -100,8 +103,8 @@ impl TypedChunk { TypedChunk::GeoPoints(grenad) => { format!("GeoPoints {{ number_of_entries: {} }}", grenad.len()) } - TypedChunk::VectorPoints(grenad) => { - format!("VectorPoints {{ number_of_entries: {} }}", grenad.len()) + TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension, embedder_name } => { + format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {}, embedder_name: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension, embedder_name) } TypedChunk::ScriptLanguageDocids(sl_map) => { format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len()) @@ -355,19 +358,77 @@ pub(crate) fn write_typed_chunk_into_index( index.put_geo_rtree(wtxn, &rtree)?; index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?; } - TypedChunk::VectorPoints(vector_points) => { - let mut vectors_set = HashSet::new(); - // We extract and store the previous vectors - if let Some(hnsw) = index.vector_hnsw(wtxn)? { - for (pid, point) in hnsw.iter() { - let pid_key = pid.into_inner(); - let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap(); - let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect(); - vectors_set.insert((docid, vector)); + TypedChunk::VectorPoints { + remove_vectors, + manual_vectors, + embeddings, + expected_dimension, + embedder_name, + } => { + let embedder_index = index.embedder_category_id.get(wtxn, &embedder_name)?.ok_or( + InternalError::DatabaseMissingEntry { db_name: "embedder_category_id", key: None }, + )?; + let writer_index = (embedder_index as u16) << 8; + // FIXME: allow customizing distance + let writers: std::result::Result, _> = (0..=u8::MAX) + .map(|k| { + arroy::Writer::prepare( + wtxn, + index.vector_arroy, + writer_index | (k as u16), + expected_dimension, + ) + }) + .collect(); + let writers = writers?; + + // remove vectors for docids we want them removed + let mut cursor = remove_vectors.into_cursor()?; + while let Some((key, _)) = cursor.move_on_next()? { + let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + + for writer in &writers { + // Uses invariant: vectors are packed in the first writers. + if !writer.del_item(wtxn, docid)? { + break; + } } } - let mut cursor = vector_points.into_cursor()?; + // add generated embeddings + if let Some(embeddings) = embeddings { + let mut cursor = embeddings.into_cursor()?; + while let Some((key, value)) = cursor.move_on_next()? { + let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap(); + let data = pod_collect_to_vec(value); + // it is a code error to have embeddings and not expected_dimension + let embeddings = + crate::vector::Embeddings::from_inner(data, expected_dimension) + // code error if we somehow got the wrong dimension + .unwrap(); + + if embeddings.embedding_count() > u8::MAX.into() { + let external_docid = if let Ok(Some(Ok(index))) = index + .external_id_of(wtxn, std::iter::once(docid)) + .map(|it| it.into_iter().next()) + { + index + } else { + format!("internal docid={docid}") + }; + return Err(crate::Error::UserError(crate::UserError::TooManyVectors( + external_docid, + embeddings.embedding_count(), + ))); + } + for (embedding, writer) in embeddings.iter().zip(&writers) { + writer.add_item(wtxn, docid, embedding)?; + } + } + } + + // perform the manual diff + let mut cursor = manual_vectors.into_cursor()?; while let Some((key, value)) = cursor.move_on_next()? { // convert the key back to a u32 (4 bytes) let (left, _index) = try_split_array_at(key).unwrap(); @@ -375,58 +436,52 @@ pub(crate) fn write_typed_chunk_into_index( let vector_deladd_obkv = KvReaderDelAdd::new(value); if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) { - // convert the vector back to a Vec - let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); - let key = (docid, vector); - if !vectors_set.remove(&key) { - error!("Unable to delete the vector: {:?}", key.1); + let vector: Vec = pod_collect_to_vec(value); + + let mut deleted_index = None; + for (index, writer) in writers.iter().enumerate() { + let Some(candidate) = writer.item_vector(wtxn, docid)? else { + // uses invariant: vectors are packed in the first writers. + break; + }; + if candidate == vector { + writer.del_item(wtxn, docid)?; + deleted_index = Some(index); + } + } + + // 🥲 enforce invariant: vectors are packed in the first writers. + if let Some(deleted_index) = deleted_index { + let mut last_index_with_a_vector = None; + for (index, writer) in writers.iter().enumerate().skip(deleted_index) { + let Some(candidate) = writer.item_vector(wtxn, docid)? else { + break; + }; + last_index_with_a_vector = Some((index, candidate)); + } + if let Some((last_index, vector)) = last_index_with_a_vector { + // unwrap: computed the index from the list of writers + let writer = writers.get(last_index).unwrap(); + writer.del_item(wtxn, docid)?; + writers.get(deleted_index).unwrap().add_item(wtxn, docid, &vector)?; + } } } + if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) { - // convert the vector back to a Vec - let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect(); - vectors_set.insert((docid, vector)); - } - } + let vector = pod_collect_to_vec(value); - // Extract the most common vector dimension - let expected_dimension_size = { - let mut dims = HashMap::new(); - vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1); - dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len) - }; - - // Ensure that the vector lengths are correct and - // prepare the vectors before inserting them in the HNSW. - let mut points = Vec::new(); - let mut docids = Vec::new(); - for (docid, vector) in vectors_set { - if expected_dimension_size.map_or(false, |expected| expected != vector.len()) { - return Err(UserError::InvalidVectorDimensions { - expected: expected_dimension_size.unwrap_or(vector.len()), - found: vector.len(), + // overflow was detected during vector extraction. + for writer in &writers { + if !writer.contains_item(wtxn, docid)? { + writer.add_item(wtxn, docid, &vector)?; + break; + } } - .into()); - } else { - let vector = vector.into_iter().map(OrderedFloat::into_inner).collect(); - points.push(NDotProductPoint::new(vector)); - docids.push(docid); } } - let hnsw_length = points.len(); - let (new_hnsw, pids) = Hnsw::builder().build_hnsw(points); - - assert_eq!(docids.len(), pids.len()); - - // Store the vectors in the point-docid relation database - index.vector_id_docid.clear(wtxn)?; - for (docid, pid) in docids.into_iter().zip(pids) { - index.vector_id_docid.put(wtxn, &pid.into_inner(), &docid)?; - } - - log::debug!("There are {} entries in the HNSW so far", hnsw_length); - index.put_vector_hnsw(wtxn, &new_hnsw)?; + log::debug!("Finished vector chunk for {}", embedder_name); } TypedChunk::ScriptLanguageDocids(sl_map) => { for (key, (deletion, addition)) in sl_map { diff --git a/milli/src/update/settings.rs b/milli/src/update/settings.rs index 712e595e9..d406c121c 100644 --- a/milli/src/update/settings.rs +++ b/milli/src/update/settings.rs @@ -1,9 +1,11 @@ use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use std::convert::TryInto; use std::result::Result as StdResult; +use std::sync::Arc; use charabia::{Normalize, Tokenizer, TokenizerBuilder}; use deserr::{DeserializeError, Deserr}; -use itertools::Itertools; +use itertools::{EitherOrBoth, Itertools}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use time::OffsetDateTime; @@ -15,6 +17,8 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS use crate::proximity::ProximityPrecision; use crate::update::index_documents::IndexDocumentsMethod; use crate::update::{IndexDocuments, UpdateIndexingStep}; +use crate::vector::settings::{EmbeddingSettings, PromptSettings}; +use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs}; use crate::{FieldsIdsMap, Index, OrderBy, Result}; #[derive(Debug, Clone, PartialEq, Eq, Copy)] @@ -73,6 +77,13 @@ impl Setting { otherwise => otherwise, } } + + pub fn apply(&mut self, new: Self) { + if let Setting::NotSet = new { + return; + } + *self = new; + } } impl Serialize for Setting { @@ -129,6 +140,7 @@ pub struct Settings<'a, 't, 'i> { sort_facet_values_by: Setting>, pagination_max_total_hits: Setting, proximity_precision: Setting, + embedder_settings: Setting>>, } impl<'a, 't, 'i> Settings<'a, 't, 'i> { @@ -161,6 +173,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { sort_facet_values_by: Setting::NotSet, pagination_max_total_hits: Setting::NotSet, proximity_precision: Setting::NotSet, + embedder_settings: Setting::NotSet, indexer_config, } } @@ -343,6 +356,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { self.proximity_precision = Setting::Reset; } + pub fn set_embedder_settings(&mut self, value: BTreeMap>) { + self.embedder_settings = Setting::Set(value); + } + + pub fn reset_embedder_settings(&mut self) { + self.embedder_settings = Setting::Reset; + } + fn reindex( &mut self, progress_callback: &FP, @@ -377,6 +398,9 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { fields_ids_map, )?; + let embedder_configs = self.index.embedding_configs(self.wtxn)?; + let embedders = self.embedders(embedder_configs)?; + // We index the generated `TransformOutput` which must contain // all the documents with fields in the newly defined searchable order. let indexing_builder = IndexDocuments::new( @@ -387,11 +411,33 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { &progress_callback, &should_abort, )?; + + let indexing_builder = indexing_builder.with_embedders(embedders); indexing_builder.execute_raw(output)?; Ok(()) } + fn embedders( + &self, + embedding_configs: Vec<(String, EmbeddingConfig)>, + ) -> Result { + let res: Result<_> = embedding_configs + .into_iter() + .map(|(name, EmbeddingConfig { embedder_options, prompt })| { + let prompt = Arc::new(prompt.try_into().map_err(crate::Error::from)?); + + let embedder = Arc::new( + Embedder::new(embedder_options.clone()) + .map_err(crate::vector::Error::from) + .map_err(crate::Error::from)?, + ); + Ok((name, (embedder, prompt))) + }) + .collect(); + res.map(EmbeddingConfigs::new) + } + fn update_displayed(&mut self) -> Result { match self.displayed_fields { Setting::Set(ref fields) => { @@ -890,6 +936,73 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { Ok(changed) } + fn update_embedding_configs(&mut self) -> Result { + let update = match std::mem::take(&mut self.embedder_settings) { + Setting::Set(configs) => { + let mut changed = false; + let old_configs = self.index.embedding_configs(self.wtxn)?; + let old_configs: BTreeMap> = + old_configs.into_iter().map(|(k, v)| (k, Setting::Set(v.into()))).collect(); + + let mut new_configs = BTreeMap::new(); + for joined in old_configs + .into_iter() + .merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right)) + { + match joined { + EitherOrBoth::Both((name, mut old), (_, new)) => { + old.apply(new); + let new = validate_prompt(&name, old)?; + changed = true; + new_configs.insert(name, new); + } + EitherOrBoth::Left((name, setting)) => { + new_configs.insert(name, setting); + } + EitherOrBoth::Right((name, setting)) => { + let setting = validate_prompt(&name, setting)?; + changed = true; + new_configs.insert(name, setting); + } + } + } + let new_configs: Vec<(String, EmbeddingConfig)> = new_configs + .into_iter() + .filter_map(|(name, setting)| match setting { + Setting::Set(value) => Some((name, value.into())), + Setting::Reset => None, + Setting::NotSet => Some((name, EmbeddingSettings::default().into())), + }) + .collect(); + + self.index.embedder_category_id.clear(self.wtxn)?; + for (index, (embedder_name, _)) in new_configs.iter().enumerate() { + self.index.embedder_category_id.put_with_flags( + self.wtxn, + heed::PutFlags::APPEND, + embedder_name, + &index + .try_into() + .map_err(|_| UserError::TooManyEmbedders(new_configs.len()))?, + )?; + } + + if new_configs.is_empty() { + self.index.delete_embedding_configs(self.wtxn)?; + } else { + self.index.put_embedding_configs(self.wtxn, new_configs)?; + } + changed + } + Setting::Reset => { + self.index.delete_embedding_configs(self.wtxn)?; + true + } + Setting::NotSet => false, + }; + Ok(update) + } + pub fn execute(mut self, progress_callback: FP, should_abort: FA) -> Result<()> where FP: Fn(UpdateIndexingStep) + Sync, @@ -927,6 +1040,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { let searchable_updated = self.update_searchable()?; let exact_attributes_updated = self.update_exact_attributes()?; let proximity_precision = self.update_proximity_precision()?; + // TODO: very rough approximation of the needs for reindexing where any change will result in + // a full reindexing. + // What can be done instead: + // 1. Only change the distance on a distance change + // 2. Only change the name -> embedder mapping on a name change + // 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage + let embedding_configs_updated = self.update_embedding_configs()?; if stop_words_updated || non_separator_tokens_updated @@ -937,6 +1057,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { || searchable_updated || exact_attributes_updated || proximity_precision + || embedding_configs_updated { self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?; } @@ -945,6 +1066,31 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> { } } +fn validate_prompt( + name: &str, + new: Setting, +) -> Result> { + match new { + Setting::Set(EmbeddingSettings { + embedder_options, + document_template: Setting::Set(PromptSettings { template: Setting::Set(template) }), + }) => { + // validate + let template = crate::prompt::Prompt::new(template) + .map(|prompt| crate::prompt::PromptData::from(prompt).template) + .map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?; + + Ok(Setting::Set(EmbeddingSettings { + embedder_options, + document_template: Setting::Set(PromptSettings { + template: Setting::Set(template), + }), + })) + } + new => Ok(new), + } +} + #[cfg(test)] mod tests { use big_s::S; @@ -1763,6 +1909,7 @@ mod tests { sort_facet_values_by, pagination_max_total_hits, proximity_precision, + embedder_settings, } = settings; assert!(matches!(searchable_fields, Setting::NotSet)); assert!(matches!(displayed_fields, Setting::NotSet)); @@ -1785,6 +1932,7 @@ mod tests { assert!(matches!(sort_facet_values_by, Setting::NotSet)); assert!(matches!(pagination_max_total_hits, Setting::NotSet)); assert!(matches!(proximity_precision, Setting::NotSet)); + assert!(matches!(embedder_settings, Setting::NotSet)); }) .unwrap(); } diff --git a/milli/src/vector/error.rs b/milli/src/vector/error.rs new file mode 100644 index 000000000..c5cce622d --- /dev/null +++ b/milli/src/vector/error.rs @@ -0,0 +1,244 @@ +use std::path::PathBuf; + +use hf_hub::api::sync::ApiError; + +use crate::error::FaultSource; +use crate::vector::openai::OpenAiError; + +#[derive(Debug, thiserror::Error)] +#[error("Error while generating embeddings: {inner}")] +pub struct Error { + pub inner: Box, +} + +impl> From for Error { + fn from(value: I) -> Self { + Self { inner: Box::new(value.into()) } + } +} + +impl Error { + pub fn fault(&self) -> FaultSource { + match &*self.inner { + ErrorKind::NewEmbedderError(inner) => inner.fault, + ErrorKind::EmbedError(inner) => inner.fault, + } + } +} + +#[derive(Debug, thiserror::Error)] +pub enum ErrorKind { + #[error(transparent)] + NewEmbedderError(#[from] NewEmbedderError), + #[error(transparent)] + EmbedError(#[from] EmbedError), +} + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct EmbedError { + pub kind: EmbedErrorKind, + pub fault: FaultSource, +} + +#[derive(Debug, thiserror::Error)] +pub enum EmbedErrorKind { + #[error("could not tokenize: {0}")] + Tokenize(Box), + #[error("unexpected tensor shape: {0}")] + TensorShape(candle_core::Error), + #[error("unexpected tensor value: {0}")] + TensorValue(candle_core::Error), + #[error("could not run model: {0}")] + ModelForward(candle_core::Error), + #[error("could not reach OpenAI: {0}")] + OpenAiNetwork(reqwest::Error), + #[error("unexpected response from OpenAI: {0}")] + OpenAiUnexpected(reqwest::Error), + #[error("could not authenticate against OpenAI: {0}")] + OpenAiAuth(OpenAiError), + #[error("sent too many requests to OpenAI: {0}")] + OpenAiTooManyRequests(OpenAiError), + #[error("received internal error from OpenAI: {0}")] + OpenAiInternalServerError(OpenAiError), + #[error("sent too many tokens in a request to OpenAI: {0}")] + OpenAiTooManyTokens(OpenAiError), + #[error("received unhandled HTTP status code {0} from OpenAI")] + OpenAiUnhandledStatusCode(u16), + #[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")] + ManualEmbed(String), +} + +impl EmbedError { + pub fn tokenize(inner: Box) -> Self { + Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime } + } + + pub fn tensor_shape(inner: candle_core::Error) -> Self { + Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug } + } + + pub fn tensor_value(inner: candle_core::Error) -> Self { + Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug } + } + + pub fn model_forward(inner: candle_core::Error) -> Self { + Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime } + } + + pub fn openai_network(inner: reqwest::Error) -> Self { + Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime } + } + + pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug } + } + + pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User } + } + + pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime } + } + + pub(crate) fn openai_internal_server_error(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime } + } + + pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug } + } + + pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError { + Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug } + } + + pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError { + Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("{fault}: {kind}")] +pub struct NewEmbedderError { + pub kind: NewEmbedderErrorKind, + pub fault: FaultSource, +} + +impl NewEmbedderError { + pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError { + let open_config = OpenConfig { filename: config_filename, inner }; + + Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime } + } + + pub fn deserialize_config( + config: String, + config_filename: PathBuf, + inner: serde_json::Error, + ) -> NewEmbedderError { + let deserialize_config = DeserializeConfig { config, filename: config_filename, inner }; + Self { + kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config), + fault: FaultSource::Runtime, + } + } + + pub fn open_tokenizer( + tokenizer_filename: PathBuf, + inner: Box, + ) -> NewEmbedderError { + let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner }; + Self { + kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer), + fault: FaultSource::Runtime, + } + } + + pub fn new_api_fail(inner: ApiError) -> Self { + Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug } + } + + pub fn api_get(inner: ApiError) -> Self { + Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided } + } + + pub fn pytorch_weight(inner: candle_core::Error) -> Self { + Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime } + } + + pub fn safetensor_weight(inner: candle_core::Error) -> Self { + Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime } + } + + pub fn load_model(inner: candle_core::Error) -> Self { + Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime } + } + + pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError { + Self { + kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner), + fault: FaultSource::Runtime, + } + } + + pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self { + Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime } + } + + pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self { + Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("could not open config at {filename:?}: {inner}")] +pub struct OpenConfig { + pub filename: PathBuf, + pub inner: std::io::Error, +} + +#[derive(Debug, thiserror::Error)] +#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")] +pub struct DeserializeConfig { + pub config: String, + pub filename: PathBuf, + pub inner: serde_json::Error, +} + +#[derive(Debug, thiserror::Error)] +#[error("could not open tokenizer at {filename}: {inner}")] +pub struct OpenTokenizer { + pub filename: PathBuf, + #[source] + pub inner: Box, +} + +#[derive(Debug, thiserror::Error)] +pub enum NewEmbedderErrorKind { + // hf + #[error(transparent)] + OpenConfig(OpenConfig), + #[error(transparent)] + DeserializeConfig(DeserializeConfig), + #[error(transparent)] + OpenTokenizer(OpenTokenizer), + #[error("could not build weights from Pytorch weights: {0}")] + PytorchWeight(candle_core::Error), + #[error("could not build weights from Safetensor weights: {0}")] + SafetensorWeight(candle_core::Error), + #[error("could not spawn HG_HUB API client: {0}")] + NewApiFail(ApiError), + #[error("fetching file from HG_HUB failed: {0}")] + ApiGet(ApiError), + #[error("could not determine model dimensions: test embedding failed with {0}")] + CouldNotDetermineDimension(EmbedError), + #[error("loading model failed: {0}")] + LoadModel(candle_core::Error), + // openai + #[error("initializing web client for sending embedding requests failed: {0}")] + InitWebClient(reqwest::Error), + #[error("The API key passed to Authorization error was in an invalid format: {0}")] + InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue), +} diff --git a/milli/src/vector/hf.rs b/milli/src/vector/hf.rs new file mode 100644 index 000000000..0a6bcbe93 --- /dev/null +++ b/milli/src/vector/hf.rs @@ -0,0 +1,195 @@ +use candle_core::Tensor; +use candle_nn::VarBuilder; +use candle_transformers::models::bert::{BertModel, Config, DTYPE}; +// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself +use hf_hub::api::sync::Api; +use hf_hub::{Repo, RepoType}; +use tokenizers::{PaddingParams, Tokenizer}; + +pub use super::error::{EmbedError, Error, NewEmbedderError}; +use super::{DistributionShift, Embedding, Embeddings}; + +#[derive( + Debug, + Clone, + Copy, + Default, + Hash, + PartialEq, + Eq, + serde::Deserialize, + serde::Serialize, + deserr::Deserr, +)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +enum WeightSource { + #[default] + Safetensors, + Pytorch, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct EmbedderOptions { + pub model: String, + pub revision: Option, +} + +impl EmbedderOptions { + pub fn new() -> Self { + Self { + model: "BAAI/bge-base-en-v1.5".to_string(), + revision: Some("617ca489d9e86b49b8167676d8220688b99db36e".into()), + } + } +} + +impl Default for EmbedderOptions { + fn default() -> Self { + Self::new() + } +} + +/// Perform embedding of documents and queries +pub struct Embedder { + model: BertModel, + tokenizer: Tokenizer, + options: EmbedderOptions, + dimensions: usize, +} + +impl std::fmt::Debug for Embedder { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("Embedder") + .field("model", &self.options.model) + .field("tokenizer", &self.tokenizer) + .field("options", &self.options) + .finish() + } +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> std::result::Result { + let device = candle_core::Device::Cpu; + let repo = match options.revision.clone() { + Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision), + None => Repo::model(options.model.clone()), + }; + let (config_filename, tokenizer_filename, weights_filename, weight_source) = { + let api = Api::new().map_err(NewEmbedderError::new_api_fail)?; + let api = api.repo(repo); + let config = api.get("config.json").map_err(NewEmbedderError::api_get)?; + let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?; + let (weights, source) = { + api.get("pytorch_model.bin") + .map(|filename| (filename, WeightSource::Pytorch)) + .or_else(|_| { + api.get("model.safetensors") + .map(|filename| (filename, WeightSource::Safetensors)) + }) + .map_err(NewEmbedderError::api_get)? + }; + (config, tokenizer, weights, source) + }; + + let config = std::fs::read_to_string(&config_filename) + .map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?; + let config: Config = serde_json::from_str(&config).map_err(|inner| { + NewEmbedderError::deserialize_config(config, config_filename, inner) + })?; + let mut tokenizer = Tokenizer::from_file(&tokenizer_filename) + .map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?; + + let vb = match weight_source { + WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device) + .map_err(NewEmbedderError::pytorch_weight)?, + WeightSource::Safetensors => unsafe { + VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device) + .map_err(NewEmbedderError::safetensor_weight)? + }, + }; + + let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?; + + if let Some(pp) = tokenizer.get_padding_mut() { + pp.strategy = tokenizers::PaddingStrategy::BatchLongest + } else { + let pp = PaddingParams { + strategy: tokenizers::PaddingStrategy::BatchLongest, + ..Default::default() + }; + tokenizer.with_padding(Some(pp)); + } + + let mut this = Self { model, tokenizer, options, dimensions: 0 }; + + let embeddings = this + .embed(vec!["test".into()]) + .map_err(NewEmbedderError::hf_could_not_determine_dimension)?; + this.dimensions = embeddings.first().unwrap().dimension(); + + Ok(this) + } + + pub fn embed( + &self, + mut texts: Vec, + ) -> std::result::Result>, EmbedError> { + let tokens = match texts.len() { + 1 => vec![self + .tokenizer + .encode(texts.pop().unwrap(), true) + .map_err(EmbedError::tokenize)?], + _ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?, + }; + let token_ids = tokens + .iter() + .map(|tokens| { + let tokens = tokens.get_ids().to_vec(); + Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape) + }) + .collect::, EmbedError>>()?; + + let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?; + let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?; + let embeddings = + self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?; + + // Apply some avg-pooling by taking the mean embedding value for all tokens (including padding) + let (_n_sentence, n_tokens, _hidden_size) = + embeddings.dims3().map_err(EmbedError::tensor_shape)?; + + let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64)) + .map_err(EmbedError::tensor_shape)?; + + let embeddings: Vec = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?; + Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect()) + } + + pub fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> std::result::Result>>, EmbedError> { + text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() + } + + pub fn chunk_count_hint(&self) -> usize { + 1 + } + + pub fn prompt_count_in_chunk_hint(&self) -> usize { + std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8) + } + + pub fn dimensions(&self) -> usize { + self.dimensions + } + + pub fn distribution(&self) -> Option { + if self.options.model == "BAAI/bge-base-en-v1.5" { + Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 }) + } else { + None + } + } +} diff --git a/milli/src/vector/manual.rs b/milli/src/vector/manual.rs new file mode 100644 index 000000000..7ed48a251 --- /dev/null +++ b/milli/src/vector/manual.rs @@ -0,0 +1,34 @@ +use super::error::EmbedError; +use super::Embeddings; + +#[derive(Debug, Clone, Copy)] +pub struct Embedder { + dimensions: usize, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct EmbedderOptions { + pub dimensions: usize, +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> Self { + Self { dimensions: options.dimensions } + } + + pub fn embed(&self, mut texts: Vec) -> Result>, EmbedError> { + let Some(text) = texts.pop() else { return Ok(Default::default()) }; + Err(EmbedError::embed_on_manual_embedder(text)) + } + + pub fn dimensions(&self) -> usize { + self.dimensions + } + + pub fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> Result>>, EmbedError> { + text_chunks.into_iter().map(|prompts| self.embed(prompts)).collect() + } +} diff --git a/milli/src/vector/mod.rs b/milli/src/vector/mod.rs new file mode 100644 index 000000000..81c4cf4a1 --- /dev/null +++ b/milli/src/vector/mod.rs @@ -0,0 +1,257 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use self::error::{EmbedError, NewEmbedderError}; +use crate::prompt::{Prompt, PromptData}; + +pub mod error; +pub mod hf; +pub mod manual; +pub mod openai; +pub mod settings; + +pub use self::error::Error; + +pub type Embedding = Vec; + +pub struct Embeddings { + data: Vec, + dimension: usize, +} + +impl Embeddings { + pub fn new(dimension: usize) -> Self { + Self { data: Default::default(), dimension } + } + + pub fn from_single_embedding(embedding: Vec) -> Self { + Self { dimension: embedding.len(), data: embedding } + } + + pub fn from_inner(data: Vec, dimension: usize) -> Result> { + let mut this = Self::new(dimension); + this.append(data)?; + Ok(this) + } + + pub fn embedding_count(&self) -> usize { + self.data.len() / self.dimension + } + + pub fn dimension(&self) -> usize { + self.dimension + } + + pub fn into_inner(self) -> Vec { + self.data + } + + pub fn as_inner(&self) -> &[F] { + &self.data + } + + pub fn iter(&self) -> impl Iterator + '_ { + self.data.as_slice().chunks_exact(self.dimension) + } + + pub fn push(&mut self, mut embedding: Vec) -> Result<(), Vec> { + if embedding.len() != self.dimension { + return Err(embedding); + } + self.data.append(&mut embedding); + Ok(()) + } + + pub fn append(&mut self, mut embeddings: Vec) -> Result<(), Vec> { + if embeddings.len() % self.dimension != 0 { + return Err(embeddings); + } + self.data.append(&mut embeddings); + Ok(()) + } +} + +#[derive(Debug)] +pub enum Embedder { + HuggingFace(hf::Embedder), + OpenAi(openai::Embedder), + UserProvided(manual::Embedder), +} + +#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)] +pub struct EmbeddingConfig { + pub embedder_options: EmbedderOptions, + pub prompt: PromptData, + // TODO: add metrics and anything needed +} + +#[derive(Clone, Default)] +pub struct EmbeddingConfigs(HashMap, Arc)>); + +impl EmbeddingConfigs { + pub fn new(data: HashMap, Arc)>) -> Self { + Self(data) + } + + pub fn get(&self, name: &str) -> Option<(Arc, Arc)> { + self.0.get(name).cloned() + } + + pub fn get_default(&self) -> Option<(Arc, Arc)> { + self.get_default_embedder_name().and_then(|default| self.get(&default)) + } + + pub fn get_default_embedder_name(&self) -> Option { + let mut it = self.0.keys(); + let first_name = it.next(); + let second_name = it.next(); + match (first_name, second_name) { + (None, _) => None, + (Some(first), None) => Some(first.to_owned()), + (Some(_), Some(_)) => Some("default".to_owned()), + } + } +} + +impl IntoIterator for EmbeddingConfigs { + type Item = (String, (Arc, Arc)); + + type IntoIter = std::collections::hash_map::IntoIter, Arc)>; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub enum EmbedderOptions { + HuggingFace(hf::EmbedderOptions), + OpenAi(openai::EmbedderOptions), + UserProvided(manual::EmbedderOptions), +} + +impl Default for EmbedderOptions { + fn default() -> Self { + Self::HuggingFace(Default::default()) + } +} + +impl EmbedderOptions { + pub fn huggingface() -> Self { + Self::HuggingFace(hf::EmbedderOptions::new()) + } + + pub fn openai(api_key: Option) -> Self { + Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key)) + } +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> std::result::Result { + Ok(match options { + EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?), + EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?), + EmbedderOptions::UserProvided(options) => { + Self::UserProvided(manual::Embedder::new(options)) + } + }) + } + + pub async fn embed( + &self, + texts: Vec, + ) -> std::result::Result>, EmbedError> { + match self { + Embedder::HuggingFace(embedder) => embedder.embed(texts), + Embedder::OpenAi(embedder) => embedder.embed(texts).await, + Embedder::UserProvided(embedder) => embedder.embed(texts), + } + } + + pub async fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> std::result::Result>>, EmbedError> { + match self { + Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks), + Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await, + Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks), + } + } + + pub fn chunk_count_hint(&self) -> usize { + match self { + Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(), + Embedder::OpenAi(embedder) => embedder.chunk_count_hint(), + Embedder::UserProvided(_) => 1, + } + } + + pub fn prompt_count_in_chunk_hint(&self) -> usize { + match self { + Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(), + Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(), + Embedder::UserProvided(_) => 1, + } + } + + pub fn dimensions(&self) -> usize { + match self { + Embedder::HuggingFace(embedder) => embedder.dimensions(), + Embedder::OpenAi(embedder) => embedder.dimensions(), + Embedder::UserProvided(embedder) => embedder.dimensions(), + } + } + + pub fn distribution(&self) -> Option { + match self { + Embedder::HuggingFace(embedder) => embedder.distribution(), + Embedder::OpenAi(embedder) => embedder.distribution(), + Embedder::UserProvided(_embedder) => None, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct DistributionShift { + pub current_mean: f32, + pub current_sigma: f32, +} + +impl DistributionShift { + /// `None` if sigma <= 0. + pub fn new(mean: f32, sigma: f32) -> Option { + if sigma <= 0.0 { + None + } else { + Some(Self { current_mean: mean, current_sigma: sigma }) + } + } + + pub fn shift(&self, score: f32) -> f32 { + // + // We're somewhat abusively mapping the distribution of distances to a gaussian. + // The parameters we're given is the mean and sigma of the native result distribution. + // We're using them to retarget the distribution to a gaussian centered on 0.5 with a sigma of 0.4. + + let target_mean = 0.5; + let target_sigma = 0.4; + + // a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive. + let factor = target_sigma / self.current_sigma; + // a*mu1 + b = mu2 => b = mu2 - a*mu1 + let offset = target_mean - (factor * self.current_mean); + + let mut score = factor * score + offset; + + // clamp the final score in the ]0, 1] interval. + if score <= 0.0 { + score = f32::EPSILON; + } + if score > 1.0 { + score = 1.0; + } + + score + } +} diff --git a/milli/src/vector/openai.rs b/milli/src/vector/openai.rs new file mode 100644 index 000000000..c11e6ddc6 --- /dev/null +++ b/milli/src/vector/openai.rs @@ -0,0 +1,445 @@ +use std::fmt::Display; + +use reqwest::StatusCode; +use serde::{Deserialize, Serialize}; + +use super::error::{EmbedError, NewEmbedderError}; +use super::{DistributionShift, Embedding, Embeddings}; + +#[derive(Debug)] +pub struct Embedder { + client: reqwest::Client, + tokenizer: tiktoken_rs::CoreBPE, + options: EmbedderOptions, +} + +#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)] +pub struct EmbedderOptions { + pub api_key: Option, + pub embedding_model: EmbeddingModel, +} + +#[derive( + Debug, + Clone, + Copy, + Default, + Hash, + PartialEq, + Eq, + serde::Serialize, + serde::Deserialize, + deserr::Deserr, +)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub enum EmbeddingModel { + #[default] + #[serde(rename = "text-embedding-ada-002")] + #[deserr(rename = "text-embedding-ada-002")] + TextEmbeddingAda002, +} + +impl EmbeddingModel { + pub fn max_token(&self) -> usize { + match self { + EmbeddingModel::TextEmbeddingAda002 => 8191, + } + } + + pub fn dimensions(&self) -> usize { + match self { + EmbeddingModel::TextEmbeddingAda002 => 1536, + } + } + + pub fn name(&self) -> &'static str { + match self { + EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002", + } + } + + pub fn from_name(name: &'static str) -> Option { + match name { + "text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002), + _ => None, + } + } + + fn distribution(&self) -> Option { + match self { + EmbeddingModel::TextEmbeddingAda002 => { + Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 }) + } + } + } +} + +pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings"; + +impl EmbedderOptions { + pub fn with_default_model(api_key: Option) -> Self { + Self { api_key, embedding_model: Default::default() } + } + + pub fn with_embedding_model(api_key: Option, embedding_model: EmbeddingModel) -> Self { + Self { api_key, embedding_model } + } +} + +impl Embedder { + pub fn new(options: EmbedderOptions) -> Result { + let mut headers = reqwest::header::HeaderMap::new(); + let mut inferred_api_key = Default::default(); + let api_key = options.api_key.as_ref().unwrap_or_else(|| { + inferred_api_key = infer_api_key(); + &inferred_api_key + }); + headers.insert( + reqwest::header::AUTHORIZATION, + reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key)) + .map_err(NewEmbedderError::openai_invalid_api_key_format)?, + ); + headers.insert( + reqwest::header::CONTENT_TYPE, + reqwest::header::HeaderValue::from_static("application/json"), + ); + let client = reqwest::ClientBuilder::new() + .default_headers(headers) + .build() + .map_err(NewEmbedderError::openai_initialize_web_client)?; + + // looking at the code it is very unclear that this can actually fail. + let tokenizer = tiktoken_rs::cl100k_base().unwrap(); + + Ok(Self { options, client, tokenizer }) + } + + pub async fn embed(&self, texts: Vec) -> Result>, EmbedError> { + let mut tokenized = false; + + for attempt in 0..7 { + let result = if tokenized { + self.try_embed_tokenized(&texts).await + } else { + self.try_embed(&texts).await + }; + + let retry_duration = match result { + Ok(embeddings) => return Ok(embeddings), + Err(retry) => { + log::warn!("Failed: {}", retry.error); + tokenized |= retry.must_tokenize(); + retry.into_duration(attempt) + } + }?; + log::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis()); + tokio::time::sleep(retry_duration).await; + } + + let result = if tokenized { + self.try_embed_tokenized(&texts).await + } else { + self.try_embed(&texts).await + }; + + result.map_err(Retry::into_error) + } + + async fn check_response(response: reqwest::Response) -> Result { + if !response.status().is_success() { + match response.status() { + StatusCode::UNAUTHORIZED => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + return Err(Retry::give_up(EmbedError::openai_auth_error( + error_response.error, + ))); + } + StatusCode::TOO_MANY_REQUESTS => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + return Err(Retry::rate_limited(EmbedError::openai_too_many_requests( + error_response.error, + ))); + } + StatusCode::INTERNAL_SERVER_ERROR => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + return Err(Retry::retry_later(EmbedError::openai_internal_server_error( + error_response.error, + ))); + } + StatusCode::SERVICE_UNAVAILABLE => { + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + return Err(Retry::retry_later(EmbedError::openai_internal_server_error( + error_response.error, + ))); + } + StatusCode::BAD_REQUEST => { + // Most probably, one text contained too many tokens + let error_response: OpenAiErrorResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + log::warn!("OpenAI: input was too long, retrying on tokenized version. For best performance, limit the size of your prompt."); + + return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens( + error_response.error, + ))); + } + code => { + return Err(Retry::give_up(EmbedError::openai_unhandled_status_code( + code.as_u16(), + ))); + } + } + } + Ok(response) + } + + async fn try_embed + serde::Serialize>( + &self, + texts: &[S], + ) -> Result>, Retry> { + for text in texts { + log::trace!("Received prompt: {}", text.as_ref()) + } + let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts }; + let response = self + .client + .post(OPENAI_EMBEDDINGS_URL) + .json(&request) + .send() + .await + .map_err(EmbedError::openai_network) + .map_err(Retry::retry_later)?; + + let response = Self::check_response(response).await?; + + let response: OpenAiResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + + log::trace!("response: {:?}", response.data); + + Ok(response + .data + .into_iter() + .map(|data| Embeddings::from_single_embedding(data.embedding)) + .collect()) + } + + async fn try_embed_tokenized(&self, text: &[String]) -> Result>, Retry> { + pub const OVERLAP_SIZE: usize = 200; + let mut all_embeddings = Vec::with_capacity(text.len()); + for text in text { + let max_token_count = self.options.embedding_model.max_token(); + let encoded = self.tokenizer.encode_ordinary(text.as_str()); + let len = encoded.len(); + if len < max_token_count { + all_embeddings.append(&mut self.try_embed(&[text]).await?); + continue; + } + + let mut tokens = encoded.as_slice(); + let mut embeddings_for_prompt = + Embeddings::new(self.options.embedding_model.dimensions()); + while tokens.len() > max_token_count { + let window = &tokens[..max_token_count]; + embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap(); + + tokens = &tokens[max_token_count - OVERLAP_SIZE..]; + } + + // end of text + embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap(); + + all_embeddings.push(embeddings_for_prompt); + } + Ok(all_embeddings) + } + + async fn embed_tokens(&self, tokens: &[usize]) -> Result { + for attempt in 0..9 { + let duration = match self.try_embed_tokens(tokens).await { + Ok(embedding) => return Ok(embedding), + Err(retry) => retry.into_duration(attempt), + } + .map_err(Retry::retry_later)?; + + tokio::time::sleep(duration).await; + } + + self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error())) + } + + async fn try_embed_tokens(&self, tokens: &[usize]) -> Result { + let request = + OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens }; + let response = self + .client + .post(OPENAI_EMBEDDINGS_URL) + .json(&request) + .send() + .await + .map_err(EmbedError::openai_network) + .map_err(Retry::retry_later)?; + + let response = Self::check_response(response).await?; + + let mut response: OpenAiResponse = response + .json() + .await + .map_err(EmbedError::openai_unexpected) + .map_err(Retry::retry_later)?; + Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default()) + } + + pub async fn embed_chunks( + &self, + text_chunks: Vec>, + ) -> Result>>, EmbedError> { + futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts))) + .await + } + + pub fn chunk_count_hint(&self) -> usize { + 10 + } + + pub fn prompt_count_in_chunk_hint(&self) -> usize { + 10 + } + + pub fn dimensions(&self) -> usize { + self.options.embedding_model.dimensions() + } + + pub fn distribution(&self) -> Option { + self.options.embedding_model.distribution() + } +} + +// retrying in case of failure + +struct Retry { + error: EmbedError, + strategy: RetryStrategy, +} + +enum RetryStrategy { + GiveUp, + Retry, + RetryTokenized, + RetryAfterRateLimit, +} + +impl Retry { + fn give_up(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::GiveUp } + } + + fn retry_later(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::Retry } + } + + fn retry_tokenized(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::RetryTokenized } + } + + fn rate_limited(error: EmbedError) -> Self { + Self { error, strategy: RetryStrategy::RetryAfterRateLimit } + } + + fn into_duration(self, attempt: u32) -> Result { + match self.strategy { + RetryStrategy::GiveUp => Err(self.error), + RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))), + RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)), + RetryStrategy::RetryAfterRateLimit => { + Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt))) + } + } + } + + fn must_tokenize(&self) -> bool { + matches!(self.strategy, RetryStrategy::RetryTokenized) + } + + fn into_error(self) -> EmbedError { + self.error + } +} + +// openai api structs + +#[derive(Debug, Serialize)] +struct OpenAiRequest<'a, S: AsRef + serde::Serialize> { + model: &'a str, + input: &'a [S], +} + +#[derive(Debug, Serialize)] +struct OpenAiTokensRequest<'a> { + model: &'a str, + input: &'a [usize], +} + +#[derive(Debug, Deserialize)] +struct OpenAiResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAiErrorResponse { + error: OpenAiError, +} + +#[derive(Debug, Deserialize)] +pub struct OpenAiError { + message: String, + // type: String, + code: Option, +} + +impl Display for OpenAiError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match &self.code { + Some(code) => write!(f, "{} ({})", self.message, code), + None => write!(f, "{}", self.message), + } + } +} + +#[derive(Debug, Deserialize)] +struct OpenAiEmbedding { + embedding: Embedding, + // object: String, + // index: usize, +} + +fn infer_api_key() -> String { + std::env::var("MEILI_OPENAI_API_KEY") + .or_else(|_| std::env::var("OPENAI_API_KEY")) + .unwrap_or_default() +} diff --git a/milli/src/vector/settings.rs b/milli/src/vector/settings.rs new file mode 100644 index 000000000..912cdf953 --- /dev/null +++ b/milli/src/vector/settings.rs @@ -0,0 +1,292 @@ +use deserr::Deserr; +use serde::{Deserialize, Serialize}; + +use crate::prompt::PromptData; +use crate::update::Setting; +use crate::vector::EmbeddingConfig; + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct EmbeddingSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")] + #[deserr(default, rename = "source")] + pub embedder_options: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub document_template: Setting, +} + +impl EmbeddingSettings { + pub fn apply(&mut self, new: Self) { + let EmbeddingSettings { embedder_options, document_template: prompt } = new; + self.embedder_options.apply(embedder_options); + self.document_template.apply(prompt); + } +} + +impl From for EmbeddingSettings { + fn from(value: EmbeddingConfig) -> Self { + Self { + embedder_options: Setting::Set(value.embedder_options.into()), + document_template: Setting::Set(value.prompt.into()), + } + } +} + +impl From for EmbeddingConfig { + fn from(value: EmbeddingSettings) -> Self { + let mut this = Self::default(); + let EmbeddingSettings { embedder_options, document_template: prompt } = value; + if let Some(embedder_options) = embedder_options.set() { + this.embedder_options = embedder_options.into(); + } + if let Some(prompt) = prompt.set() { + this.prompt = prompt.into(); + } + this + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct PromptSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub template: Setting, +} + +impl PromptSettings { + pub fn apply(&mut self, new: Self) { + let PromptSettings { template } = new; + self.template.apply(template); + } +} + +impl From for PromptSettings { + fn from(value: PromptData) -> Self { + Self { template: Setting::Set(value.template) } + } +} + +impl From for PromptData { + fn from(value: PromptSettings) -> Self { + let mut this = PromptData::default(); + let PromptSettings { template } = value; + if let Some(template) = template.set() { + this.template = template; + } + this + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +pub enum EmbedderSettings { + HuggingFace(Setting), + OpenAi(Setting), + UserProvided(UserProvidedSettings), +} + +impl Deserr for EmbedderSettings +where + E: deserr::DeserializeError, +{ + fn deserialize_from_value( + value: deserr::Value, + location: deserr::ValuePointerRef, + ) -> Result { + match value { + deserr::Value::Map(map) => { + if deserr::Map::len(&map) != 1 { + return Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::Unexpected { + msg: format!( + "Expected a single field, got {} fields", + deserr::Map::len(&map) + ), + }, + location, + ))); + } + let mut it = deserr::Map::into_iter(map); + let (k, v) = it.next().unwrap(); + + match k.as_str() { + "huggingFace" => Ok(EmbedderSettings::HuggingFace(Setting::Set( + HfEmbedderSettings::deserialize_from_value( + v.into_value(), + location.push_key(&k), + )?, + ))), + "openAi" => Ok(EmbedderSettings::OpenAi(Setting::Set( + OpenAiEmbedderSettings::deserialize_from_value( + v.into_value(), + location.push_key(&k), + )?, + ))), + "userProvided" => Ok(EmbedderSettings::UserProvided( + UserProvidedSettings::deserialize_from_value( + v.into_value(), + location.push_key(&k), + )?, + )), + other => Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::UnknownKey { + key: other, + accepted: &["huggingFace", "openAi", "userProvided"], + }, + location, + ))), + } + } + _ => Err(deserr::take_cf_content(E::error::( + None, + deserr::ErrorKind::IncorrectValueKind { + actual: value, + accepted: &[deserr::ValueKind::Map], + }, + location, + ))), + } + } +} + +impl Default for EmbedderSettings { + fn default() -> Self { + Self::OpenAi(Default::default()) + } +} + +impl From for EmbedderSettings { + fn from(value: crate::vector::EmbedderOptions) -> Self { + match value { + crate::vector::EmbedderOptions::HuggingFace(hf) => { + Self::HuggingFace(Setting::Set(hf.into())) + } + crate::vector::EmbedderOptions::OpenAi(openai) => { + Self::OpenAi(Setting::Set(openai.into())) + } + crate::vector::EmbedderOptions::UserProvided(user_provided) => { + Self::UserProvided(user_provided.into()) + } + } + } +} + +impl From for crate::vector::EmbedderOptions { + fn from(value: EmbedderSettings) -> Self { + match value { + EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()), + EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()), + EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()), + EmbedderSettings::OpenAi(_setting) => { + Self::OpenAi(crate::vector::openai::EmbedderOptions::with_default_model(None)) + } + EmbedderSettings::UserProvided(user_provided) => { + Self::UserProvided(user_provided.into()) + } + } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct HfEmbedderSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub model: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub revision: Setting, +} + +impl HfEmbedderSettings { + pub fn apply(&mut self, new: Self) { + let HfEmbedderSettings { model, revision } = new; + self.model.apply(model); + self.revision.apply(revision); + } +} + +impl From for HfEmbedderSettings { + fn from(value: crate::vector::hf::EmbedderOptions) -> Self { + Self { + model: Setting::Set(value.model), + revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet), + } + } +} + +impl From for crate::vector::hf::EmbedderOptions { + fn from(value: HfEmbedderSettings) -> Self { + let HfEmbedderSettings { model, revision } = value; + let mut this = Self::default(); + if let Some(model) = model.set() { + this.model = model; + } + if let Some(revision) = revision.set() { + this.revision = Some(revision); + } + this + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct OpenAiEmbedderSettings { + #[serde(default, skip_serializing_if = "Setting::is_not_set")] + #[deserr(default)] + pub api_key: Setting, + #[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "model")] + #[deserr(default, rename = "model")] + pub embedding_model: Setting, +} + +impl OpenAiEmbedderSettings { + pub fn apply(&mut self, new: Self) { + let Self { api_key, embedding_model: embedding_mode } = new; + self.api_key.apply(api_key); + self.embedding_model.apply(embedding_mode); + } +} + +impl From for OpenAiEmbedderSettings { + fn from(value: crate::vector::openai::EmbedderOptions) -> Self { + Self { + api_key: value.api_key.map(Setting::Set).unwrap_or(Setting::Reset), + embedding_model: Setting::Set(value.embedding_model), + } + } +} + +impl From for crate::vector::openai::EmbedderOptions { + fn from(value: OpenAiEmbedderSettings) -> Self { + let OpenAiEmbedderSettings { api_key, embedding_model } = value; + Self { api_key: api_key.set(), embedding_model: embedding_model.set().unwrap_or_default() } + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)] +#[serde(deny_unknown_fields, rename_all = "camelCase")] +#[deserr(rename_all = camelCase, deny_unknown_fields)] +pub struct UserProvidedSettings { + pub dimensions: usize, +} + +impl From for crate::vector::manual::EmbedderOptions { + fn from(value: UserProvidedSettings) -> Self { + Self { dimensions: value.dimensions } + } +} + +impl From for UserProvidedSettings { + fn from(value: crate::vector::manual::EmbedderOptions) -> Self { + Self { dimensions: value.dimensions } + } +}