mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-12-22 12:40:04 +01:00
Small commit to add hybrid search and autoembedding
This commit is contained in:
parent
21bcf32109
commit
13c2c6c16b
281
Cargo.lock
generated
281
Cargo.lock
generated
@ -46,7 +46,7 @@ dependencies = [
|
||||
"actix-tls",
|
||||
"actix-utils",
|
||||
"ahash 0.8.3",
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.5",
|
||||
"bitflags 1.3.2",
|
||||
"brotli",
|
||||
"bytes",
|
||||
@ -120,7 +120,7 @@ dependencies = [
|
||||
"futures-util",
|
||||
"mio",
|
||||
"num_cpus",
|
||||
"socket2",
|
||||
"socket2 0.4.9",
|
||||
"tokio",
|
||||
"tracing",
|
||||
]
|
||||
@ -201,7 +201,7 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"smallvec",
|
||||
"socket2",
|
||||
"socket2 0.4.9",
|
||||
"time",
|
||||
"url",
|
||||
]
|
||||
@ -365,6 +365,12 @@ dependencies = [
|
||||
"backtrace",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anymap2"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d301b3b94cb4b2f23d7917810addbbaff90738e0ca2be692bd027e70d7e0330c"
|
||||
|
||||
[[package]]
|
||||
name = "arbitrary"
|
||||
version = "1.3.0"
|
||||
@ -455,9 +461,9 @@ checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
|
||||
|
||||
[[package]]
|
||||
name = "base64"
|
||||
version = "0.21.2"
|
||||
version = "0.21.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "604178f6c5c21f02dc555784810edfb88d34ac2c73b2eae109655649ee73ce3d"
|
||||
checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9"
|
||||
|
||||
[[package]]
|
||||
name = "base64ct"
|
||||
@ -508,6 +514,21 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-set"
|
||||
version = "0.5.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0700ddab506f33b20a03b13996eccd309a48e5ff77d0d95926aa0210fb4e95f1"
|
||||
dependencies = [
|
||||
"bit-vec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bit-vec"
|
||||
version = "0.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "1.3.2"
|
||||
@ -555,12 +576,12 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "bstr"
|
||||
version = "1.6.0"
|
||||
version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6798148dccfbff0fae41c7574d2fa8f1ef3492fba0face179de5d8d447d67b05"
|
||||
checksum = "542f33a8835a0884b006a0c3df3dadd99c0c3f296ed26c2fdc8028e01ad6230c"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
"regex-automata 0.3.6",
|
||||
"regex-automata 0.4.3",
|
||||
"serde",
|
||||
]
|
||||
|
||||
@ -1346,6 +1367,12 @@ dependencies = [
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "doc-comment"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10"
|
||||
|
||||
[[package]]
|
||||
name = "doxygen-rs"
|
||||
version = "0.2.2"
|
||||
@ -1562,6 +1589,16 @@ dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fancy-regex"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b95f7c0680e4142284cf8b22c14a476e87d61b004a3a0861872b32ef7ead40a2"
|
||||
dependencies = [
|
||||
"bit-set",
|
||||
"regex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastrand"
|
||||
version = "2.0.0"
|
||||
@ -1690,9 +1727,9 @@ checksum = "7ab85b9b05e3978cc9a9cf8fea7f01b494e1a09ed3037e16ba39edc7a29eb61a"
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.28"
|
||||
version = "0.3.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23342abe12aba583913b2e62f22225ff9c950774065e4bfb61a19cd9770fec40"
|
||||
checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
@ -1705,9 +1742,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.28"
|
||||
version = "0.3.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "955518d47e09b25bbebc7a18df10b81f0c766eaf4c4f1cccef2fca5f2a4fb5f2"
|
||||
checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
@ -1715,15 +1752,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.28"
|
||||
version = "0.3.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4bca583b7e26f571124fe5b7561d49cb2868d79116cfa0eefce955557c6fee8c"
|
||||
checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.28"
|
||||
version = "0.3.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ccecee823288125bd88b4d7f565c9e58e41858e47ab72e8ea2d64e93624386e0"
|
||||
checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
@ -1732,15 +1769,15 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.28"
|
||||
version = "0.3.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4fff74096e71ed47f8e023204cfd0aa1289cd54ae5430a9523be060cdb849964"
|
||||
checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.28"
|
||||
version = "0.3.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "89ca545a94061b6365f2c7355b4b32bd20df3ff95f02da9329b34ccc3bd6ee72"
|
||||
checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -1749,21 +1786,21 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.28"
|
||||
version = "0.3.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f43be4fe21a13b9781a69afa4985b0f6ee0e1afab2c6f454a8cf30e2b2237b6e"
|
||||
checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.28"
|
||||
version = "0.3.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "76d3d132be6c0e6aa1534069c705a74a5997a356c0dc2f86a47765e5617c5b65"
|
||||
checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.28"
|
||||
version = "0.3.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "26b01e40b772d54cf6c6d721c1d1abd0647a0106a12ecaa1c186273392a69533"
|
||||
checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
@ -2207,7 +2244,7 @@ dependencies = [
|
||||
"httpdate",
|
||||
"itoa",
|
||||
"pin-project-lite",
|
||||
"socket2",
|
||||
"socket2 0.4.9",
|
||||
"tokio",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
@ -2949,7 +2986,7 @@ version = "8.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378"
|
||||
dependencies = [
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.5",
|
||||
"pem",
|
||||
"ring",
|
||||
"serde",
|
||||
@ -2957,6 +2994,16 @@ dependencies = [
|
||||
"simple_asn1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "kstring"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec3066350882a1cd6d950d055997f379ac37fd39f81cd4d8ed186032eb3c5747"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "language-tags"
|
||||
version = "0.3.2"
|
||||
@ -2980,9 +3027,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.147"
|
||||
version = "0.2.150"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
|
||||
checksum = "89d92a4743f9a61002fae18374ed11e7973f530cb3a3255fb354818118b2203c"
|
||||
|
||||
[[package]]
|
||||
name = "libgit2-sys"
|
||||
@ -3251,6 +3298,63 @@ version = "0.4.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "57bcfdad1b858c2db7c38303a6d2ad4dfaf5eb53dfeb0910128b2c26d6158503"
|
||||
|
||||
[[package]]
|
||||
name = "liquid"
|
||||
version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "69f68ae1011499ae2ef879f631891f21c78e309755f4a5e483c4a8f12e10b609"
|
||||
dependencies = [
|
||||
"doc-comment",
|
||||
"liquid-core",
|
||||
"liquid-derive",
|
||||
"liquid-lib",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "liquid-core"
|
||||
version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "79e0724dfcaad5cfb7965ea0f178ca0870b8d7315178f4a7179f5696f7f04d5f"
|
||||
dependencies = [
|
||||
"anymap2",
|
||||
"itertools 0.10.5",
|
||||
"kstring",
|
||||
"liquid-derive",
|
||||
"num-traits",
|
||||
"pest",
|
||||
"pest_derive",
|
||||
"regex",
|
||||
"serde",
|
||||
"time",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "liquid-derive"
|
||||
version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fc2fb41a9bb4257a3803154bdf7e2df7d45197d1941c9b1a90ad815231630721"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "liquid-lib"
|
||||
version = "0.26.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e2a17e273a6fb1fb6268f7a5867ddfd0bd4683c7e19b51084f3d567fad4348c0"
|
||||
dependencies = [
|
||||
"itertools 0.10.5",
|
||||
"liquid-core",
|
||||
"once_cell",
|
||||
"percent-encoding",
|
||||
"regex",
|
||||
"time",
|
||||
"unicode-segmentation",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "litemap"
|
||||
version = "0.6.1"
|
||||
@ -3483,7 +3587,7 @@ dependencies = [
|
||||
name = "meilisearch-auth"
|
||||
version = "1.5.1"
|
||||
dependencies = [
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.5",
|
||||
"enum-iterator",
|
||||
"hmac",
|
||||
"maplit",
|
||||
@ -3544,9 +3648,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.5.0"
|
||||
version = "2.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d"
|
||||
checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167"
|
||||
|
||||
[[package]]
|
||||
name = "memmap2"
|
||||
@ -3589,6 +3693,7 @@ dependencies = [
|
||||
"filter-parser",
|
||||
"flatten-serde-json",
|
||||
"fst",
|
||||
"futures",
|
||||
"fxhash",
|
||||
"geoutils",
|
||||
"grenad",
|
||||
@ -3600,6 +3705,7 @@ dependencies = [
|
||||
"itertools 0.11.0",
|
||||
"json-depth-checker",
|
||||
"levenshtein_automata",
|
||||
"liquid",
|
||||
"log",
|
||||
"logging_timer",
|
||||
"maplit",
|
||||
@ -3607,6 +3713,7 @@ dependencies = [
|
||||
"meili-snap",
|
||||
"memmap2",
|
||||
"mimalloc",
|
||||
"nolife",
|
||||
"obkv",
|
||||
"once_cell",
|
||||
"ordered-float",
|
||||
@ -3614,6 +3721,7 @@ dependencies = [
|
||||
"rand",
|
||||
"rand_pcg",
|
||||
"rayon",
|
||||
"reqwest",
|
||||
"roaring",
|
||||
"rstar",
|
||||
"serde",
|
||||
@ -3624,8 +3732,10 @@ dependencies = [
|
||||
"smartstring",
|
||||
"tempfile",
|
||||
"thiserror",
|
||||
"tiktoken-rs",
|
||||
"time",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"uuid 1.5.0",
|
||||
]
|
||||
|
||||
@ -3671,9 +3781,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "0.8.8"
|
||||
version = "0.8.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "927a765cd3fc26206e66b296465fa9d3e5ab003e651c1b3c060e7956d96b19d2"
|
||||
checksum = "3dce281c5e46beae905d4de1870d8b1509a9142b62eedf18b443b011ca8343d0"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
@ -3725,6 +3835,12 @@ name = "nelson"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/meilisearch/nelson.git?rev=675f13885548fb415ead8fbb447e9e6d9314000a#675f13885548fb415ead8fbb447e9e6d9314000a"
|
||||
|
||||
[[package]]
|
||||
name = "nolife"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bc52aaf087e8a52e7a2692f83f2dac6ac7ff9d0136bf9c6ac496635cfe3e50dc"
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.3"
|
||||
@ -4480,6 +4596,12 @@ dependencies = [
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f"
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.7.4"
|
||||
@ -4488,11 +4610,11 @@ checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2"
|
||||
|
||||
[[package]]
|
||||
name = "reqwest"
|
||||
version = "0.11.18"
|
||||
version = "0.11.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cde824a14b7c14f85caff81225f411faacc04a2013f41670f41443742b1c1c55"
|
||||
checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b"
|
||||
dependencies = [
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.5",
|
||||
"bytes",
|
||||
"encoding_rs",
|
||||
"futures-core",
|
||||
@ -4514,6 +4636,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_urlencoded",
|
||||
"system-configuration",
|
||||
"tokio",
|
||||
"tokio-rustls 0.24.1",
|
||||
"tower-service",
|
||||
@ -4521,7 +4644,7 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-futures",
|
||||
"web-sys",
|
||||
"webpki-roots 0.22.6",
|
||||
"webpki-roots 0.25.3",
|
||||
"winreg",
|
||||
]
|
||||
|
||||
@ -4582,6 +4705,12 @@ version = "0.1.23"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.4.0"
|
||||
@ -4648,7 +4777,7 @@ version = "1.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2d3987094b1d07b653b7dfdc3f70ce9a1da9c51ac18c1b06b662e4f9a0e9f4b2"
|
||||
dependencies = [
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -4977,6 +5106,16 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.5.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.5.2"
|
||||
@ -5097,6 +5236,27 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"core-foundation",
|
||||
"system-configuration-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "system-configuration-sys"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a75fb188eb626b924683e3b95e3a48e63551fcfb51949de2f06a9d91dbee93c9"
|
||||
dependencies = [
|
||||
"core-foundation-sys",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tar"
|
||||
version = "0.4.40"
|
||||
@ -5159,6 +5319,21 @@ dependencies = [
|
||||
"syn 2.0.28",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tiktoken-rs"
|
||||
version = "0.5.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a4427b6b1c6b38215b92dd47a83a0ecc6735573d0a5a4c14acc0ac5b33b28adb"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"base64 0.21.5",
|
||||
"bstr",
|
||||
"fancy-regex",
|
||||
"lazy_static",
|
||||
"parking_lot",
|
||||
"rustc-hash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.30"
|
||||
@ -5258,11 +5433,10 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.29.1"
|
||||
version = "1.34.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "532826ff75199d5833b9d2c5fe410f29235e25704ee5f0ef599fb51c21f4a4da"
|
||||
checksum = "d0c014766411e834f7af5b8f4cf46257aab4036ca95e9d2c144a10f59ad6f5b9"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"backtrace",
|
||||
"bytes",
|
||||
"libc",
|
||||
@ -5271,16 +5445,16 @@ dependencies = [
|
||||
"parking_lot",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"socket2",
|
||||
"socket2 0.5.5",
|
||||
"tokio-macros",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-macros"
|
||||
version = "2.1.0"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "630bdcf245f78637c13ec01ffae6187cca34625e8c63150d424b59e55af2675e"
|
||||
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -5508,7 +5682,7 @@ version = "2.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0b11c96ac7ee530603dcdf68ed1557050f374ce55a5a07193ebf8cbc9f8927e9"
|
||||
dependencies = [
|
||||
"base64 0.21.2",
|
||||
"base64 0.21.5",
|
||||
"flate2",
|
||||
"log",
|
||||
"native-tls",
|
||||
@ -5758,6 +5932,12 @@ dependencies = [
|
||||
"rustls-webpki 0.100.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.25.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10"
|
||||
|
||||
[[package]]
|
||||
name = "whatlang"
|
||||
version = "0.16.2"
|
||||
@ -5942,11 +6122,12 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "winreg"
|
||||
version = "0.10.1"
|
||||
version = "0.50.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "80d0f4e272c85def139476380b12f9ac60926689dd2e01d4923222f40580869d"
|
||||
checksum = "524e57b2c537c0f9b1e69f1965311ec12182b4122e45035b1508cd24d2adadb1"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
"cfg-if",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -276,6 +276,7 @@ pub(crate) mod test {
|
||||
),
|
||||
}),
|
||||
pagination: Setting::NotSet,
|
||||
embedders: Setting::NotSet,
|
||||
_kind: std::marker::PhantomData,
|
||||
};
|
||||
settings.check()
|
||||
|
@ -378,6 +378,7 @@ impl<T> From<v5::Settings<T>> for v6::Settings<v6::Unchecked> {
|
||||
v5::Setting::Reset => v6::Setting::Reset,
|
||||
v5::Setting::NotSet => v6::Setting::NotSet,
|
||||
},
|
||||
embedders: v6::Setting::NotSet,
|
||||
_kind: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -1202,6 +1202,10 @@ impl IndexScheduler {
|
||||
|
||||
let config = IndexDocumentsConfig { update_method: method, ..Default::default() };
|
||||
|
||||
let embedder_configs = index.embedding_configs(index_wtxn)?;
|
||||
// TODO: consider Arc'ing the map too (we only need read access + we'll be cloning it multiple times, so really makes sense)
|
||||
let embedders = self.embedders(embedder_configs)?;
|
||||
|
||||
let mut builder = milli::update::IndexDocuments::new(
|
||||
index_wtxn,
|
||||
index,
|
||||
@ -1220,6 +1224,8 @@ impl IndexScheduler {
|
||||
let (new_builder, user_result) = builder.add_documents(reader)?;
|
||||
builder = new_builder;
|
||||
|
||||
builder = builder.with_embedders(embedders.clone());
|
||||
|
||||
let received_documents =
|
||||
if let Some(Details::DocumentAdditionOrUpdate {
|
||||
received_documents,
|
||||
@ -1345,6 +1351,9 @@ impl IndexScheduler {
|
||||
|
||||
for (task, (_, settings)) in tasks.iter_mut().zip(settings) {
|
||||
let checked_settings = settings.clone().check();
|
||||
if matches!(checked_settings.embedders, milli::update::Setting::Set(_)) {
|
||||
self.features().check_vector("Passing `embedders` in settings")?
|
||||
}
|
||||
if checked_settings.proximity_precision.set().is_some() {
|
||||
self.features.features().check_proximity_precision()?;
|
||||
}
|
||||
|
@ -56,12 +56,12 @@ impl RoFeatures {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn check_vector(&self) -> Result<()> {
|
||||
pub fn check_vector(&self, disabled_action: &'static str) -> Result<()> {
|
||||
if self.runtime.vector_store {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(FeatureNotEnabledError {
|
||||
disabled_action: "Passing `vector` as a query parameter",
|
||||
disabled_action,
|
||||
feature: "vector store",
|
||||
issue_link: "https://github.com/meilisearch/product/discussions/677",
|
||||
}
|
||||
|
@ -41,6 +41,7 @@ pub fn snapshot_index_scheduler(scheduler: &IndexScheduler) -> String {
|
||||
planned_failures: _,
|
||||
run_loop_iteration: _,
|
||||
currently_updating_index: _,
|
||||
embedders: _,
|
||||
} = scheduler;
|
||||
|
||||
let rtxn = env.read_txn().unwrap();
|
||||
|
@ -52,6 +52,7 @@ use meilisearch_types::heed::types::{SerdeBincode, SerdeJson, Str, I128};
|
||||
use meilisearch_types::heed::{self, Database, Env, PutFlags, RoTxn, RwTxn};
|
||||
use meilisearch_types::milli::documents::DocumentsBatchBuilder;
|
||||
use meilisearch_types::milli::update::IndexerConfig;
|
||||
use meilisearch_types::milli::vector::{Embedder, EmbedderOptions};
|
||||
use meilisearch_types::milli::{self, CboRoaringBitmapCodec, Index, RoaringBitmapCodec, BEU32};
|
||||
use meilisearch_types::tasks::{Kind, KindWithContent, Status, Task};
|
||||
use puffin::FrameView;
|
||||
@ -341,6 +342,8 @@ pub struct IndexScheduler {
|
||||
/// so that a handle to the index is available from other threads (search) in an optimized manner.
|
||||
currently_updating_index: Arc<RwLock<Option<(String, Index)>>>,
|
||||
|
||||
embedders: Arc<RwLock<HashMap<EmbedderOptions, std::sync::Arc<Embedder>>>>,
|
||||
|
||||
// ================= test
|
||||
// The next entry is dedicated to the tests.
|
||||
/// Provide a way to set a breakpoint in multiple part of the scheduler.
|
||||
@ -386,6 +389,7 @@ impl IndexScheduler {
|
||||
auth_path: self.auth_path.clone(),
|
||||
version_file_path: self.version_file_path.clone(),
|
||||
currently_updating_index: self.currently_updating_index.clone(),
|
||||
embedders: self.embedders.clone(),
|
||||
#[cfg(test)]
|
||||
test_breakpoint_sdr: self.test_breakpoint_sdr.clone(),
|
||||
#[cfg(test)]
|
||||
@ -484,6 +488,7 @@ impl IndexScheduler {
|
||||
auth_path: options.auth_path,
|
||||
version_file_path: options.version_file_path,
|
||||
currently_updating_index: Arc::new(RwLock::new(None)),
|
||||
embedders: Default::default(),
|
||||
|
||||
#[cfg(test)]
|
||||
test_breakpoint_sdr,
|
||||
@ -1333,6 +1338,42 @@ impl IndexScheduler {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: consider using a type alias or a struct embedder/template
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub fn embedders(
|
||||
&self,
|
||||
embedding_configs: Vec<(String, milli::vector::EmbeddingConfig)>,
|
||||
) -> Result<HashMap<String, (Arc<milli::vector::Embedder>, Arc<milli::prompt::Prompt>)>> {
|
||||
let res: Result<_> = embedding_configs
|
||||
.into_iter()
|
||||
.map(|(name, milli::vector::EmbeddingConfig { embedder_options, prompt })| {
|
||||
let prompt =
|
||||
Arc::new(prompt.try_into().map_err(meilisearch_types::milli::Error::from)?);
|
||||
// optimistically return existing embedder
|
||||
{
|
||||
let embedders = self.embedders.read().unwrap();
|
||||
if let Some(embedder) = embedders.get(&embedder_options) {
|
||||
return Ok((name, (embedder.clone(), prompt)));
|
||||
}
|
||||
}
|
||||
|
||||
// add missing embedder
|
||||
let embedder = Arc::new(
|
||||
Embedder::new(embedder_options.clone())
|
||||
.map_err(meilisearch_types::milli::vector::Error::from)
|
||||
.map_err(meilisearch_types::milli::UserError::from)
|
||||
.map_err(meilisearch_types::milli::Error::from)?,
|
||||
);
|
||||
{
|
||||
let mut embedders = self.embedders.write().unwrap();
|
||||
embedders.insert(embedder_options, embedder.clone());
|
||||
}
|
||||
Ok((name, (embedder, prompt)))
|
||||
})
|
||||
.collect();
|
||||
res
|
||||
}
|
||||
|
||||
/// Blocks the thread until the test handle asks to progress to/through this breakpoint.
|
||||
///
|
||||
/// Two messages are sent through the channel for each breakpoint.
|
||||
|
@ -256,6 +256,7 @@ InvalidSettingsProximityPrecision , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsFaceting , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsFilterableAttributes , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsPagination , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsEmbedders , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsRankingRules , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsSearchableAttributes , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsSortableAttributes , InvalidRequest , BAD_REQUEST ;
|
||||
@ -303,7 +304,8 @@ TaskNotFound , InvalidRequest , NOT_FOUND ;
|
||||
TooManyOpenFiles , System , UNPROCESSABLE_ENTITY ;
|
||||
UnretrievableDocument , Internal , BAD_REQUEST ;
|
||||
UnretrievableErrorCode , InvalidRequest , BAD_REQUEST ;
|
||||
UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE
|
||||
UnsupportedMediaType , InvalidRequest , UNSUPPORTED_MEDIA_TYPE ;
|
||||
VectorEmbeddingError , InvalidRequest , BAD_REQUEST
|
||||
}
|
||||
|
||||
impl ErrorCode for JoinError {
|
||||
@ -336,6 +338,9 @@ impl ErrorCode for milli::Error {
|
||||
UserError::InvalidDocumentId { .. } | UserError::TooManyDocumentIds { .. } => {
|
||||
Code::InvalidDocumentId
|
||||
}
|
||||
UserError::MissingDocumentField(_) => Code::InvalidDocumentFields,
|
||||
UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
|
||||
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
|
||||
UserError::NoPrimaryKeyCandidateFound => Code::IndexPrimaryKeyNoCandidateFound,
|
||||
UserError::MultiplePrimaryKeyCandidatesFound { .. } => {
|
||||
Code::IndexPrimaryKeyMultipleCandidatesFound
|
||||
@ -358,6 +363,7 @@ impl ErrorCode for milli::Error {
|
||||
UserError::InvalidMinTypoWordLenSetting(_, _) => {
|
||||
Code::InvalidSettingsTypoTolerance
|
||||
}
|
||||
UserError::VectorEmbeddingError(_) => Code::VectorEmbeddingError,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -199,6 +199,10 @@ pub struct Settings<T> {
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSettingsPagination>)]
|
||||
pub pagination: Setting<PaginationSettings>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSettingsEmbedders>)]
|
||||
pub embedders: Setting<BTreeMap<String, Setting<milli::vector::settings::EmbeddingSettings>>>,
|
||||
|
||||
#[serde(skip)]
|
||||
#[deserr(skip)]
|
||||
pub _kind: PhantomData<T>,
|
||||
@ -222,6 +226,7 @@ impl Settings<Checked> {
|
||||
typo_tolerance: Setting::Reset,
|
||||
faceting: Setting::Reset,
|
||||
pagination: Setting::Reset,
|
||||
embedders: Setting::Reset,
|
||||
_kind: PhantomData,
|
||||
}
|
||||
}
|
||||
@ -243,6 +248,7 @@ impl Settings<Checked> {
|
||||
typo_tolerance,
|
||||
faceting,
|
||||
pagination,
|
||||
embedders,
|
||||
..
|
||||
} = self;
|
||||
|
||||
@ -262,6 +268,7 @@ impl Settings<Checked> {
|
||||
typo_tolerance,
|
||||
faceting,
|
||||
pagination,
|
||||
embedders,
|
||||
_kind: PhantomData,
|
||||
}
|
||||
}
|
||||
@ -307,6 +314,7 @@ impl Settings<Unchecked> {
|
||||
typo_tolerance: self.typo_tolerance,
|
||||
faceting: self.faceting,
|
||||
pagination: self.pagination,
|
||||
embedders: self.embedders,
|
||||
_kind: PhantomData,
|
||||
}
|
||||
}
|
||||
@ -490,6 +498,12 @@ pub fn apply_settings_to_builder(
|
||||
Setting::Reset => builder.reset_pagination_max_total_hits(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.embedders.clone() {
|
||||
Setting::Set(value) => builder.set_embedder_settings(value),
|
||||
Setting::Reset => builder.reset_embedder_settings(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn settings(
|
||||
@ -571,6 +585,12 @@ pub fn settings(
|
||||
),
|
||||
};
|
||||
|
||||
let embedders = index
|
||||
.embedding_configs(rtxn)?
|
||||
.into_iter()
|
||||
.map(|(name, config)| (name, Setting::Set(config.into())))
|
||||
.collect();
|
||||
|
||||
Ok(Settings {
|
||||
displayed_attributes: match displayed_attributes {
|
||||
Some(attrs) => Setting::Set(attrs),
|
||||
@ -599,6 +619,7 @@ pub fn settings(
|
||||
typo_tolerance: Setting::Set(typo_tolerance),
|
||||
faceting: Setting::Set(faceting),
|
||||
pagination: Setting::Set(pagination),
|
||||
embedders: Setting::Set(embedders),
|
||||
_kind: PhantomData,
|
||||
})
|
||||
}
|
||||
@ -747,6 +768,7 @@ pub(crate) mod test {
|
||||
typo_tolerance: Setting::NotSet,
|
||||
faceting: Setting::NotSet,
|
||||
pagination: Setting::NotSet,
|
||||
embedders: Setting::NotSet,
|
||||
_kind: PhantomData::<Unchecked>,
|
||||
};
|
||||
|
||||
@ -772,6 +794,7 @@ pub(crate) mod test {
|
||||
typo_tolerance: Setting::NotSet,
|
||||
faceting: Setting::NotSet,
|
||||
pagination: Setting::NotSet,
|
||||
embedders: Setting::NotSet,
|
||||
_kind: PhantomData::<Unchecked>,
|
||||
};
|
||||
|
||||
|
@ -686,7 +686,7 @@ impl SearchAggregator {
|
||||
ret.max_terms_number = q.split_whitespace().count();
|
||||
}
|
||||
|
||||
if let Some(ref vector) = vector {
|
||||
if let Some(meilisearch_types::milli::VectorQuery::Vector(ref vector)) = vector {
|
||||
ret.max_vector_size = vector.len();
|
||||
}
|
||||
|
||||
|
@ -19,7 +19,11 @@ static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
/// does all the setup before meilisearch is launched
|
||||
fn setup(opt: &Opt) -> anyhow::Result<()> {
|
||||
let mut log_builder = env_logger::Builder::new();
|
||||
log_builder.parse_filters(&opt.log_level.to_string());
|
||||
let log_filters = format!(
|
||||
"{},h2=warn,hyper=warn,tokio_util=warn,tracing=warn,rustls=warn,mio=warn,reqwest=warn",
|
||||
opt.log_level
|
||||
);
|
||||
log_builder.parse_filters(&log_filters);
|
||||
|
||||
log_builder.init();
|
||||
|
||||
|
@ -7,6 +7,7 @@ use meilisearch_types::deserr::DeserrJsonError;
|
||||
use meilisearch_types::error::deserr_codes::*;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::VectorQuery;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::analytics::{Analytics, FacetSearchAggregator};
|
||||
@ -117,7 +118,7 @@ impl From<FacetSearchQuery> for SearchQuery {
|
||||
highlight_post_tag: DEFAULT_HIGHLIGHT_POST_TAG(),
|
||||
crop_marker: DEFAULT_CROP_MARKER(),
|
||||
matching_strategy,
|
||||
vector,
|
||||
vector: vector.map(VectorQuery::Vector),
|
||||
attributes_to_search_on,
|
||||
}
|
||||
}
|
||||
|
@ -2,12 +2,13 @@ use actix_web::web::Data;
|
||||
use actix_web::{web, HttpRequest, HttpResponse};
|
||||
use deserr::actix_web::{AwebJson, AwebQueryParameter};
|
||||
use index_scheduler::IndexScheduler;
|
||||
use log::debug;
|
||||
use log::{debug, warn};
|
||||
use meilisearch_types::deserr::query_params::Param;
|
||||
use meilisearch_types::deserr::{DeserrJsonError, DeserrQueryParamError};
|
||||
use meilisearch_types::error::deserr_codes::*;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::VectorQuery;
|
||||
use meilisearch_types::serde_cs::vec::CS;
|
||||
use serde_json::Value;
|
||||
|
||||
@ -88,7 +89,7 @@ impl From<SearchQueryGet> for SearchQuery {
|
||||
|
||||
Self {
|
||||
q: other.q,
|
||||
vector: other.vector.map(CS::into_inner),
|
||||
vector: other.vector.map(CS::into_inner).map(VectorQuery::Vector),
|
||||
offset: other.offset.0,
|
||||
limit: other.limit.0,
|
||||
page: other.page.as_deref().copied(),
|
||||
@ -193,6 +194,9 @@ pub async fn search_with_post(
|
||||
let index = index_scheduler.index(&index_uid)?;
|
||||
|
||||
let features = index_scheduler.features();
|
||||
|
||||
embed(&mut query, index_scheduler.get_ref(), &index).await?;
|
||||
|
||||
let search_result =
|
||||
tokio::task::spawn_blocking(move || perform_search(&index, query, features)).await?;
|
||||
if let Ok(ref search_result) = search_result {
|
||||
@ -206,6 +210,38 @@ pub async fn search_with_post(
|
||||
Ok(HttpResponse::Ok().json(search_result))
|
||||
}
|
||||
|
||||
pub async fn embed(
|
||||
query: &mut SearchQuery,
|
||||
index_scheduler: &IndexScheduler,
|
||||
index: &meilisearch_types::milli::Index,
|
||||
) -> Result<(), ResponseError> {
|
||||
if let Some(VectorQuery::String(prompt)) = query.vector.take() {
|
||||
let embedder_configs = index.embedding_configs(&index.read_txn()?)?;
|
||||
let embedder = index_scheduler.embedders(embedder_configs)?;
|
||||
|
||||
/// FIXME: add error if no embedder, remove unwrap, support multiple embedders
|
||||
let embeddings = embedder
|
||||
.get("default")
|
||||
.unwrap()
|
||||
.0
|
||||
.embed(vec![prompt])
|
||||
.await
|
||||
.map_err(meilisearch_types::milli::vector::Error::from)
|
||||
.map_err(meilisearch_types::milli::UserError::from)
|
||||
.map_err(meilisearch_types::milli::Error::from)?
|
||||
.pop()
|
||||
.expect("No vector returned from embedding");
|
||||
|
||||
if embeddings.iter().nth(1).is_some() {
|
||||
warn!("Ignoring embeddings past the first one in long search query");
|
||||
query.vector = Some(VectorQuery::Vector(embeddings.iter().next().unwrap().to_vec()));
|
||||
} else {
|
||||
query.vector = Some(VectorQuery::Vector(embeddings.into_inner()));
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
@ -13,6 +13,7 @@ use crate::analytics::{Analytics, MultiSearchAggregator};
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
use crate::extractors::authentication::{AuthenticationError, GuardedData};
|
||||
use crate::extractors::sequential_extractor::SeqHandler;
|
||||
use crate::routes::indexes::search::embed;
|
||||
use crate::search::{
|
||||
add_search_rules, perform_search, SearchQueryWithIndex, SearchResultWithIndex,
|
||||
};
|
||||
@ -74,6 +75,8 @@ pub async fn multi_search_with_post(
|
||||
})
|
||||
.with_index(query_index)?;
|
||||
|
||||
embed(&mut query, index_scheduler.get_ref(), &index).await.with_index(query_index)?;
|
||||
|
||||
let search_result =
|
||||
tokio::task::spawn_blocking(move || perform_search(&index, query, features))
|
||||
.await
|
||||
|
@ -16,6 +16,7 @@ use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use meilisearch_types::milli::{
|
||||
dot_product_similarity, FacetValueHit, InternalError, OrderBy, SearchForFacetValues,
|
||||
VectorQuery,
|
||||
};
|
||||
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
|
||||
use meilisearch_types::{milli, Document};
|
||||
@ -46,7 +47,7 @@ pub struct SearchQuery {
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
||||
pub q: Option<String>,
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchVector>)]
|
||||
pub vector: Option<Vec<f32>>,
|
||||
pub vector: Option<milli::VectorQuery>,
|
||||
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
||||
pub offset: usize,
|
||||
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
||||
@ -105,7 +106,7 @@ pub struct SearchQueryWithIndex {
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
||||
pub q: Option<String>,
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSearchQ>)]
|
||||
pub vector: Option<Vec<f32>>,
|
||||
pub vector: Option<VectorQuery>,
|
||||
#[deserr(default = DEFAULT_SEARCH_OFFSET(), error = DeserrJsonError<InvalidSearchOffset>)]
|
||||
pub offset: usize,
|
||||
#[deserr(default = DEFAULT_SEARCH_LIMIT(), error = DeserrJsonError<InvalidSearchLimit>)]
|
||||
@ -339,11 +340,18 @@ fn prepare_search<'t>(
|
||||
let mut search = index.search(rtxn);
|
||||
|
||||
if query.vector.is_some() && query.q.is_some() {
|
||||
warn!("Ignoring the query string `q` when used with the `vector` parameter.");
|
||||
warn!("Attempting hybrid search");
|
||||
}
|
||||
|
||||
if let Some(ref vector) = query.vector {
|
||||
search.vector(vector.clone());
|
||||
match vector {
|
||||
VectorQuery::Vector(vector) => {
|
||||
search.vector(vector.clone());
|
||||
}
|
||||
VectorQuery::String(_) => {
|
||||
panic!("Failed while preparing search; caller did not generate embedding for query")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(ref query) = query.q {
|
||||
@ -375,7 +383,7 @@ fn prepare_search<'t>(
|
||||
}
|
||||
|
||||
if query.vector.is_some() {
|
||||
features.check_vector()?;
|
||||
features.check_vector("Passing `vector` as a query parameter")?;
|
||||
}
|
||||
|
||||
// compute the offset on the limit depending on the pagination mode.
|
||||
@ -429,7 +437,11 @@ pub fn perform_search(
|
||||
prepare_search(index, &rtxn, &query, features)?;
|
||||
|
||||
let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
|
||||
search.execute()?;
|
||||
if query.q.is_some() && query.vector.is_some() {
|
||||
search.execute_hybrid()?
|
||||
} else {
|
||||
search.execute()?
|
||||
};
|
||||
|
||||
let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
|
||||
|
||||
@ -538,13 +550,13 @@ pub fn perform_search(
|
||||
insert_geo_distance(sort, &mut document);
|
||||
}
|
||||
|
||||
let semantic_score = match query.vector.as_ref() {
|
||||
let semantic_score = /*match query.vector.as_ref() {
|
||||
Some(vector) => match extract_field("_vectors", &fields_ids_map, obkv)? {
|
||||
Some(vectors) => compute_semantic_score(vector, vectors)?,
|
||||
None => None,
|
||||
},
|
||||
None => None,
|
||||
};
|
||||
};*/ None;
|
||||
|
||||
let ranking_score =
|
||||
query.show_ranking_score.then(|| ScoreDetails::global_score(score.iter()));
|
||||
@ -629,7 +641,8 @@ pub fn perform_search(
|
||||
hits: documents,
|
||||
hits_info,
|
||||
query: query.q.unwrap_or_default(),
|
||||
vector: query.vector,
|
||||
// FIXME: display input vector
|
||||
vector: None,
|
||||
processing_time_ms: before_search.elapsed().as_millis(),
|
||||
facet_distribution,
|
||||
facet_stats,
|
||||
|
@ -27,10 +27,13 @@ fst = "0.4.7"
|
||||
fxhash = "0.2.1"
|
||||
geoutils = "0.5.1"
|
||||
grenad = { version = "0.4.5", default-features = false, features = [
|
||||
"rayon", "tempfile"
|
||||
"rayon",
|
||||
"tempfile",
|
||||
] }
|
||||
heed = { version = "0.20.0-alpha.9", default-features = false, features = [
|
||||
"serde-json", "serde-bincode", "read-txn-no-tls"
|
||||
"serde-json",
|
||||
"serde-bincode",
|
||||
"read-txn-no-tls",
|
||||
] }
|
||||
indexmap = { version = "2.0.0", features = ["serde"] }
|
||||
instant-distance = { version = "0.6.1", features = ["with-serde"] }
|
||||
@ -77,6 +80,15 @@ candle-transformers = { git = "https://github.com/huggingface/candle.git", versi
|
||||
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.3.1" }
|
||||
tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.14.1", version = "0.14.1" }
|
||||
hf-hub = "0.3.2"
|
||||
tokio = { version = "1.34.0", features = ["rt"] }
|
||||
futures = "0.3.29"
|
||||
nolife = { version = "0.3.1" }
|
||||
reqwest = { version = "0.11.16", features = [
|
||||
"rustls-tls",
|
||||
"json",
|
||||
], default-features = false }
|
||||
tiktoken-rs = "0.5.7"
|
||||
liquid = "0.26.4"
|
||||
|
||||
[dev-dependencies]
|
||||
mimalloc = { version = "0.1.37", default-features = false }
|
||||
@ -88,7 +100,15 @@ meili-snap = { path = "../meili-snap" }
|
||||
rand = { version = "0.8.5", features = ["small_rng"] }
|
||||
|
||||
[features]
|
||||
all-tokenizations = ["charabia/chinese", "charabia/hebrew", "charabia/japanese", "charabia/thai", "charabia/korean", "charabia/greek", "charabia/khmer"]
|
||||
all-tokenizations = [
|
||||
"charabia/chinese",
|
||||
"charabia/hebrew",
|
||||
"charabia/japanese",
|
||||
"charabia/thai",
|
||||
"charabia/korean",
|
||||
"charabia/greek",
|
||||
"charabia/khmer",
|
||||
]
|
||||
|
||||
# Use POSIX semaphores instead of SysV semaphores in LMDB
|
||||
# For more information on this feature, see heed's Cargo.toml
|
||||
|
@ -5,8 +5,8 @@ use std::time::Instant;
|
||||
|
||||
use heed::EnvOpenOptions;
|
||||
use milli::{
|
||||
execute_search, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext, SearchLogger,
|
||||
TermsMatchingStrategy,
|
||||
execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext,
|
||||
SearchLogger, TermsMatchingStrategy,
|
||||
};
|
||||
|
||||
#[global_allocator]
|
||||
@ -49,14 +49,15 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
let start = Instant::now();
|
||||
|
||||
let mut ctx = SearchContext::new(&index, &txn);
|
||||
let universe = filtered_universe(&ctx, &None)?;
|
||||
|
||||
let docs = execute_search(
|
||||
&mut ctx,
|
||||
&(!query.trim().is_empty()).then(|| query.trim().to_owned()),
|
||||
&None,
|
||||
(!query.trim().is_empty()).then(|| query.trim()),
|
||||
TermsMatchingStrategy::Last,
|
||||
milli::score_details::ScoringStrategy::Skip,
|
||||
false,
|
||||
&None,
|
||||
universe,
|
||||
&None,
|
||||
GeoSortStrategy::default(),
|
||||
0,
|
||||
|
@ -180,6 +180,14 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
|
||||
UnknownInternalDocumentId { document_id: DocumentId },
|
||||
#[error("`minWordSizeForTypos` setting is invalid. `oneTypo` and `twoTypos` fields should be between `0` and `255`, and `twoTypos` should be greater or equals to `oneTypo` but found `oneTypo: {0}` and twoTypos: {1}`.")]
|
||||
InvalidMinTypoWordLenSetting(u8, u8),
|
||||
#[error(transparent)]
|
||||
VectorEmbeddingError(#[from] crate::vector::Error),
|
||||
#[error(transparent)]
|
||||
MissingDocumentField(#[from] crate::prompt::error::RenderPromptError),
|
||||
#[error(transparent)]
|
||||
InvalidPrompt(#[from] crate::prompt::error::NewPromptError),
|
||||
#[error("Invalid prompt in for embeddings with name '{0}': {1}")]
|
||||
InvalidPromptForEmbeddings(String, crate::prompt::error::NewPromptError),
|
||||
}
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
@ -336,6 +344,26 @@ impl From<HeedError> for Error {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum FaultSource {
|
||||
User,
|
||||
Runtime,
|
||||
Bug,
|
||||
Undecided,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for FaultSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let s = match self {
|
||||
FaultSource::User => "user error",
|
||||
FaultSource::Runtime => "runtime error",
|
||||
FaultSource::Bug => "coding error",
|
||||
FaultSource::Undecided => "error",
|
||||
};
|
||||
f.write_str(s)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn conditionally_lookup_for_error_message() {
|
||||
let prefix = "Attribute `name` is not sortable.";
|
||||
|
@ -23,6 +23,7 @@ use crate::heed_codec::{
|
||||
};
|
||||
use crate::proximity::ProximityPrecision;
|
||||
use crate::readable_slices::ReadableSlices;
|
||||
use crate::vector::EmbeddingConfig;
|
||||
use crate::{
|
||||
default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
|
||||
FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec,
|
||||
@ -74,6 +75,7 @@ pub mod main_key {
|
||||
pub const SORT_FACET_VALUES_BY: &str = "sort-facet-values-by";
|
||||
pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits";
|
||||
pub const PROXIMITY_PRECISION: &str = "proximity-precision";
|
||||
pub const EMBEDDING_CONFIGS: &str = "embedding_configs";
|
||||
}
|
||||
|
||||
pub mod db_name {
|
||||
@ -1528,6 +1530,33 @@ impl Index {
|
||||
|
||||
Ok(script_language)
|
||||
}
|
||||
|
||||
pub(crate) fn put_embedding_configs(
|
||||
&self,
|
||||
wtxn: &mut RwTxn<'_>,
|
||||
configs: Vec<(String, EmbeddingConfig)>,
|
||||
) -> heed::Result<()> {
|
||||
self.main.remap_types::<Str, SerdeJson<Vec<(String, EmbeddingConfig)>>>().put(
|
||||
wtxn,
|
||||
main_key::EMBEDDING_CONFIGS,
|
||||
&configs,
|
||||
)
|
||||
}
|
||||
|
||||
pub(crate) fn delete_embedding_configs(&self, wtxn: &mut RwTxn<'_>) -> heed::Result<bool> {
|
||||
self.main.remap_key_type::<Str>().delete(wtxn, main_key::EMBEDDING_CONFIGS)
|
||||
}
|
||||
|
||||
pub fn embedding_configs(
|
||||
&self,
|
||||
rtxn: &RoTxn<'_>,
|
||||
) -> Result<Vec<(String, crate::vector::EmbeddingConfig)>> {
|
||||
Ok(self
|
||||
.main
|
||||
.remap_types::<Str, SerdeJson<Vec<(String, EmbeddingConfig)>>>()
|
||||
.get(rtxn, main_key::EMBEDDING_CONFIGS)?
|
||||
.unwrap_or_default())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -17,11 +17,13 @@ pub mod facet;
|
||||
mod fields_ids_map;
|
||||
pub mod heed_codec;
|
||||
pub mod index;
|
||||
pub mod prompt;
|
||||
pub mod proximity;
|
||||
mod readable_slices;
|
||||
pub mod score_details;
|
||||
mod search;
|
||||
pub mod update;
|
||||
pub mod vector;
|
||||
|
||||
#[cfg(test)]
|
||||
#[macro_use]
|
||||
@ -37,8 +39,8 @@ pub use filter_parser::{Condition, FilterCondition, Span, Token};
|
||||
use fxhash::{FxHasher32, FxHasher64};
|
||||
pub use grenad::CompressionType;
|
||||
pub use search::new::{
|
||||
execute_search, DefaultSearchLogger, GeoSortStrategy, SearchContext, SearchLogger,
|
||||
VisualSearchLogger,
|
||||
execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, SearchContext,
|
||||
SearchLogger, VisualSearchLogger,
|
||||
};
|
||||
use serde_json::Value;
|
||||
pub use {charabia as tokenizer, heed};
|
||||
@ -60,7 +62,7 @@ pub use self::index::Index;
|
||||
pub use self::search::{
|
||||
FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder,
|
||||
MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy,
|
||||
DEFAULT_VALUES_PER_FACET,
|
||||
VectorQuery, DEFAULT_VALUES_PER_FACET,
|
||||
};
|
||||
|
||||
pub type Result<T> = std::result::Result<T, error::Error>;
|
||||
|
97
milli/src/prompt/context.rs
Normal file
97
milli/src/prompt/context.rs
Normal file
@ -0,0 +1,97 @@
|
||||
use liquid::model::{
|
||||
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
|
||||
};
|
||||
use liquid::{ObjectView, ValueView};
|
||||
|
||||
use super::document::Document;
|
||||
use super::fields::Fields;
|
||||
use crate::FieldsIdsMap;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Context<'a> {
|
||||
document: &'a Document<'a>,
|
||||
fields: Fields<'a>,
|
||||
}
|
||||
|
||||
impl<'a> Context<'a> {
|
||||
pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self {
|
||||
Self { document, fields: Fields::new(document, field_id_map) }
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ObjectView for Context<'a> {
|
||||
fn as_value(&self) -> &dyn ValueView {
|
||||
self
|
||||
}
|
||||
|
||||
fn size(&self) -> i64 {
|
||||
2
|
||||
}
|
||||
|
||||
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
|
||||
Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s)))
|
||||
}
|
||||
|
||||
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
||||
Box::new(
|
||||
std::iter::once(self.document.as_value())
|
||||
.chain(std::iter::once(self.fields.as_value())),
|
||||
)
|
||||
}
|
||||
|
||||
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
|
||||
Box::new(self.keys().zip(self.values()))
|
||||
}
|
||||
|
||||
fn contains_key(&self, index: &str) -> bool {
|
||||
index == "doc" || index == "fields"
|
||||
}
|
||||
|
||||
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
|
||||
match index {
|
||||
"doc" => Some(self.document.as_value()),
|
||||
"fields" => Some(self.fields.as_value()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ValueView for Context<'a> {
|
||||
fn as_debug(&self) -> &dyn std::fmt::Debug {
|
||||
self
|
||||
}
|
||||
|
||||
fn render(&self) -> liquid::model::DisplayCow<'_> {
|
||||
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
|
||||
}
|
||||
|
||||
fn source(&self) -> liquid::model::DisplayCow<'_> {
|
||||
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
|
||||
}
|
||||
|
||||
fn type_name(&self) -> &'static str {
|
||||
"object"
|
||||
}
|
||||
|
||||
fn query_state(&self, state: liquid::model::State) -> bool {
|
||||
match state {
|
||||
State::Truthy => true,
|
||||
State::DefaultValue | State::Empty | State::Blank => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
|
||||
let s = ObjectRender::new(self).to_string();
|
||||
KStringCow::from_string(s)
|
||||
}
|
||||
|
||||
fn to_value(&self) -> LiquidValue {
|
||||
LiquidValue::Object(
|
||||
self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn as_object(&self) -> Option<&dyn ObjectView> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
131
milli/src/prompt/document.rs
Normal file
131
milli/src/prompt/document.rs
Normal file
@ -0,0 +1,131 @@
|
||||
use std::cell::OnceCell;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use liquid::model::{
|
||||
DisplayCow, KString, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
|
||||
};
|
||||
use liquid::{ObjectView, ValueView};
|
||||
|
||||
use crate::update::del_add::{DelAdd, KvReaderDelAdd};
|
||||
use crate::FieldsIdsMap;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Document<'a>(BTreeMap<&'a str, (&'a [u8], ParsedValue)>);
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ParsedValue(std::cell::OnceCell<LiquidValue>);
|
||||
|
||||
impl ParsedValue {
|
||||
fn empty() -> ParsedValue {
|
||||
ParsedValue(OnceCell::new())
|
||||
}
|
||||
|
||||
fn get(&self, raw: &[u8]) -> &LiquidValue {
|
||||
self.0.get_or_init(|| {
|
||||
let value: serde_json::Value = serde_json::from_slice(raw).unwrap();
|
||||
liquid::model::to_value(&value).unwrap()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Document<'a> {
|
||||
pub fn new(
|
||||
data: obkv::KvReaderU16<'a>,
|
||||
side: DelAdd,
|
||||
inverted_field_map: &'a FieldsIdsMap,
|
||||
) -> Self {
|
||||
let mut out_data = BTreeMap::new();
|
||||
for (fid, raw) in data {
|
||||
let obkv = KvReaderDelAdd::new(raw);
|
||||
let Some(raw) = obkv.get(side) else {
|
||||
continue;
|
||||
};
|
||||
let Some(name) = inverted_field_map.name(fid) else {
|
||||
continue;
|
||||
};
|
||||
out_data.insert(name, (raw, ParsedValue::empty()));
|
||||
}
|
||||
Self(out_data)
|
||||
}
|
||||
|
||||
fn is_empty(&self) -> bool {
|
||||
self.0.is_empty()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.0.len()
|
||||
}
|
||||
|
||||
fn iter(&self) -> impl Iterator<Item = (KString, LiquidValue)> + '_ {
|
||||
self.0.iter().map(|(&k, (raw, data))| (k.to_owned().into(), data.get(raw).to_owned()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ObjectView for Document<'a> {
|
||||
fn as_value(&self) -> &dyn ValueView {
|
||||
self
|
||||
}
|
||||
|
||||
fn size(&self) -> i64 {
|
||||
self.len() as i64
|
||||
}
|
||||
|
||||
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
|
||||
let keys = BTreeMap::keys(&self.0).map(|&s| s.into());
|
||||
Box::new(keys)
|
||||
}
|
||||
|
||||
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
||||
Box::new(self.0.values().map(|(raw, v)| v.get(raw) as &dyn ValueView))
|
||||
}
|
||||
|
||||
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
|
||||
Box::new(self.0.iter().map(|(&k, (raw, data))| (k.into(), data.get(raw) as &dyn ValueView)))
|
||||
}
|
||||
|
||||
fn contains_key(&self, index: &str) -> bool {
|
||||
self.0.contains_key(index)
|
||||
}
|
||||
|
||||
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
|
||||
self.0.get(index).map(|(raw, v)| v.get(raw) as &dyn ValueView)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ValueView for Document<'a> {
|
||||
fn as_debug(&self) -> &dyn std::fmt::Debug {
|
||||
self
|
||||
}
|
||||
|
||||
fn render(&self) -> liquid::model::DisplayCow<'_> {
|
||||
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
|
||||
}
|
||||
|
||||
fn source(&self) -> liquid::model::DisplayCow<'_> {
|
||||
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
|
||||
}
|
||||
|
||||
fn type_name(&self) -> &'static str {
|
||||
"object"
|
||||
}
|
||||
|
||||
fn query_state(&self, state: liquid::model::State) -> bool {
|
||||
match state {
|
||||
State::Truthy => true,
|
||||
State::DefaultValue | State::Empty | State::Blank => self.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
|
||||
let s = ObjectRender::new(self).to_string();
|
||||
KStringCow::from_string(s)
|
||||
}
|
||||
|
||||
fn to_value(&self) -> LiquidValue {
|
||||
LiquidValue::Object(self.iter().collect())
|
||||
}
|
||||
|
||||
fn as_object(&self) -> Option<&dyn ObjectView> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
56
milli/src/prompt/error.rs
Normal file
56
milli/src/prompt/error.rs
Normal file
@ -0,0 +1,56 @@
|
||||
use crate::error::FaultSource;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{fault}: {kind}")]
|
||||
pub struct NewPromptError {
|
||||
pub kind: NewPromptErrorKind,
|
||||
pub fault: FaultSource,
|
||||
}
|
||||
|
||||
impl From<NewPromptError> for crate::Error {
|
||||
fn from(value: NewPromptError) -> Self {
|
||||
crate::Error::UserError(crate::UserError::InvalidPrompt(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl NewPromptError {
|
||||
pub(crate) fn cannot_parse_template(inner: liquid::Error) -> NewPromptError {
|
||||
Self { kind: NewPromptErrorKind::CannotParseTemplate(inner), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn invalid_fields_in_template(inner: liquid::Error) -> NewPromptError {
|
||||
Self { kind: NewPromptErrorKind::InvalidFieldsInTemplate(inner), fault: FaultSource::User }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum NewPromptErrorKind {
|
||||
#[error("cannot parse template: {0}")]
|
||||
CannotParseTemplate(liquid::Error),
|
||||
#[error("template contains invalid fields: {0}. Only `doc.*`, `fields[i].name`, `fields[i].value` are supported")]
|
||||
InvalidFieldsInTemplate(liquid::Error),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{fault}: {kind}")]
|
||||
pub struct RenderPromptError {
|
||||
pub kind: RenderPromptErrorKind,
|
||||
pub fault: FaultSource,
|
||||
}
|
||||
impl RenderPromptError {
|
||||
pub(crate) fn missing_context(inner: liquid::Error) -> RenderPromptError {
|
||||
Self { kind: RenderPromptErrorKind::MissingContext(inner), fault: FaultSource::User }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum RenderPromptErrorKind {
|
||||
#[error("missing field in document: {0}")]
|
||||
MissingContext(liquid::Error),
|
||||
}
|
||||
|
||||
impl From<RenderPromptError> for crate::Error {
|
||||
fn from(value: RenderPromptError) -> Self {
|
||||
crate::Error::UserError(crate::UserError::MissingDocumentField(value))
|
||||
}
|
||||
}
|
172
milli/src/prompt/fields.rs
Normal file
172
milli/src/prompt/fields.rs
Normal file
@ -0,0 +1,172 @@
|
||||
use liquid::model::{
|
||||
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
|
||||
};
|
||||
use liquid::{ObjectView, ValueView};
|
||||
|
||||
use super::document::Document;
|
||||
use crate::FieldsIdsMap;
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Fields<'a>(Vec<FieldValue<'a>>);
|
||||
|
||||
impl<'a> Fields<'a> {
|
||||
pub fn new(document: &'a Document<'a>, field_id_map: &'a FieldsIdsMap) -> Self {
|
||||
Self(
|
||||
std::iter::repeat(document)
|
||||
.zip(field_id_map.iter())
|
||||
.map(|(document, (_fid, name))| FieldValue { document, name })
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct FieldValue<'a> {
|
||||
name: &'a str,
|
||||
document: &'a Document<'a>,
|
||||
}
|
||||
|
||||
impl<'a> ValueView for FieldValue<'a> {
|
||||
fn as_debug(&self) -> &dyn std::fmt::Debug {
|
||||
self
|
||||
}
|
||||
|
||||
fn render(&self) -> liquid::model::DisplayCow<'_> {
|
||||
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
|
||||
}
|
||||
|
||||
fn source(&self) -> liquid::model::DisplayCow<'_> {
|
||||
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
|
||||
}
|
||||
|
||||
fn type_name(&self) -> &'static str {
|
||||
"object"
|
||||
}
|
||||
|
||||
fn query_state(&self, state: liquid::model::State) -> bool {
|
||||
match state {
|
||||
State::Truthy => true,
|
||||
State::DefaultValue | State::Empty | State::Blank => self.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
|
||||
let s = ObjectRender::new(self).to_string();
|
||||
KStringCow::from_string(s)
|
||||
}
|
||||
|
||||
fn to_value(&self) -> LiquidValue {
|
||||
LiquidValue::Object(
|
||||
self.iter().map(|(k, v)| (k.to_string().into(), v.to_value())).collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn as_object(&self) -> Option<&dyn ObjectView> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> FieldValue<'a> {
|
||||
pub fn name(&self) -> &&'a str {
|
||||
&self.name
|
||||
}
|
||||
|
||||
pub fn value(&self) -> &dyn ValueView {
|
||||
self.document.get(self.name).unwrap_or(&LiquidValue::Nil)
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.size() == 0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ObjectView for FieldValue<'a> {
|
||||
fn as_value(&self) -> &dyn ValueView {
|
||||
self
|
||||
}
|
||||
|
||||
fn size(&self) -> i64 {
|
||||
2
|
||||
}
|
||||
|
||||
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
|
||||
Box::new(["name", "value"].iter().map(|&x| KStringCow::from_static(x)))
|
||||
}
|
||||
|
||||
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
||||
Box::new(
|
||||
std::iter::once(self.name() as &dyn ValueView).chain(std::iter::once(self.value())),
|
||||
)
|
||||
}
|
||||
|
||||
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
|
||||
Box::new(self.keys().zip(self.values()))
|
||||
}
|
||||
|
||||
fn contains_key(&self, index: &str) -> bool {
|
||||
index == "name" || index == "value"
|
||||
}
|
||||
|
||||
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
|
||||
match index {
|
||||
"name" => Some(self.name()),
|
||||
"value" => Some(self.value()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ArrayView for Fields<'a> {
|
||||
fn as_value(&self) -> &dyn ValueView {
|
||||
self.0.as_value()
|
||||
}
|
||||
|
||||
fn size(&self) -> i64 {
|
||||
self.0.len() as i64
|
||||
}
|
||||
|
||||
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
||||
self.0.values()
|
||||
}
|
||||
|
||||
fn contains_key(&self, index: i64) -> bool {
|
||||
self.0.contains_key(index)
|
||||
}
|
||||
|
||||
fn get(&self, index: i64) -> Option<&dyn ValueView> {
|
||||
ArrayView::get(&self.0, index)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> ValueView for Fields<'a> {
|
||||
fn as_debug(&self) -> &dyn std::fmt::Debug {
|
||||
self
|
||||
}
|
||||
|
||||
fn render(&self) -> liquid::model::DisplayCow<'_> {
|
||||
self.0.render()
|
||||
}
|
||||
|
||||
fn source(&self) -> liquid::model::DisplayCow<'_> {
|
||||
self.0.source()
|
||||
}
|
||||
|
||||
fn type_name(&self) -> &'static str {
|
||||
self.0.type_name()
|
||||
}
|
||||
|
||||
fn query_state(&self, state: liquid::model::State) -> bool {
|
||||
self.0.query_state(state)
|
||||
}
|
||||
|
||||
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
|
||||
self.0.to_kstr()
|
||||
}
|
||||
|
||||
fn to_value(&self) -> LiquidValue {
|
||||
self.0.to_value()
|
||||
}
|
||||
|
||||
fn as_array(&self) -> Option<&dyn ArrayView> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
144
milli/src/prompt/mod.rs
Normal file
144
milli/src/prompt/mod.rs
Normal file
@ -0,0 +1,144 @@
|
||||
mod context;
|
||||
mod document;
|
||||
pub(crate) mod error;
|
||||
mod fields;
|
||||
mod template_checker;
|
||||
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use error::{NewPromptError, RenderPromptError};
|
||||
|
||||
use self::context::Context;
|
||||
use self::document::Document;
|
||||
use crate::update::del_add::DelAdd;
|
||||
use crate::FieldsIdsMap;
|
||||
|
||||
pub struct Prompt {
|
||||
template: liquid::Template,
|
||||
template_text: String,
|
||||
strategy: PromptFallbackStrategy,
|
||||
fallback: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct PromptData {
|
||||
pub template: String,
|
||||
pub strategy: PromptFallbackStrategy,
|
||||
pub fallback: String,
|
||||
}
|
||||
|
||||
impl From<Prompt> for PromptData {
|
||||
fn from(value: Prompt) -> Self {
|
||||
Self { template: value.template_text, strategy: value.strategy, fallback: value.fallback }
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<PromptData> for Prompt {
|
||||
type Error = NewPromptError;
|
||||
|
||||
fn try_from(value: PromptData) -> Result<Self, Self::Error> {
|
||||
Prompt::new(value.template, Some(value.strategy), Some(value.fallback))
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for Prompt {
|
||||
fn clone(&self) -> Self {
|
||||
let template_text = self.template_text.clone();
|
||||
Self {
|
||||
template: new_template(&template_text).unwrap(),
|
||||
template_text,
|
||||
strategy: self.strategy,
|
||||
fallback: self.fallback.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn new_template(text: &str) -> Result<liquid::Template, liquid::Error> {
|
||||
liquid::ParserBuilder::with_stdlib().build().unwrap().parse(text)
|
||||
}
|
||||
|
||||
fn default_template() -> liquid::Template {
|
||||
new_template(default_template_text()).unwrap()
|
||||
}
|
||||
|
||||
fn default_template_text() -> &'static str {
|
||||
"{% for field in fields %} \
|
||||
{{ field.name }}: {{ field.value }}\n\
|
||||
{% endfor %}"
|
||||
}
|
||||
|
||||
fn default_fallback() -> &'static str {
|
||||
"<MISSING>"
|
||||
}
|
||||
|
||||
impl Default for Prompt {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
template: default_template(),
|
||||
template_text: default_template_text().into(),
|
||||
strategy: Default::default(),
|
||||
fallback: default_fallback().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PromptData {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
template: default_template_text().into(),
|
||||
strategy: Default::default(),
|
||||
fallback: default_fallback().into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Prompt {
|
||||
pub fn new(
|
||||
template: String,
|
||||
strategy: Option<PromptFallbackStrategy>,
|
||||
fallback: Option<String>,
|
||||
) -> Result<Self, NewPromptError> {
|
||||
let this = Self {
|
||||
template: liquid::ParserBuilder::with_stdlib()
|
||||
.build()
|
||||
.unwrap()
|
||||
.parse(&template)
|
||||
.map_err(NewPromptError::cannot_parse_template)?,
|
||||
template_text: template,
|
||||
strategy: strategy.unwrap_or_default(),
|
||||
fallback: fallback.unwrap_or_default(),
|
||||
};
|
||||
|
||||
// render template with special object that's OK with `doc.*` and `fields.*`
|
||||
/// FIXME: doesn't work for nested objects e.g. `doc.a.b`
|
||||
this.template
|
||||
.render(&template_checker::TemplateChecker)
|
||||
.map_err(NewPromptError::invalid_fields_in_template)?;
|
||||
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
pub fn render(
|
||||
&self,
|
||||
document: obkv::KvReaderU16<'_>,
|
||||
side: DelAdd,
|
||||
field_id_map: &FieldsIdsMap,
|
||||
) -> Result<String, RenderPromptError> {
|
||||
let document = Document::new(document, side, field_id_map);
|
||||
let context = Context::new(&document, field_id_map);
|
||||
|
||||
self.template.render(&context).map_err(RenderPromptError::missing_context)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug, Default, Clone, PartialEq, Eq, Copy, serde::Serialize, serde::Deserialize, deserr::Deserr,
|
||||
)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub enum PromptFallbackStrategy {
|
||||
Fallback,
|
||||
Skip,
|
||||
#[default]
|
||||
Error,
|
||||
}
|
282
milli/src/prompt/template_checker.rs
Normal file
282
milli/src/prompt/template_checker.rs
Normal file
@ -0,0 +1,282 @@
|
||||
use liquid::model::{
|
||||
ArrayView, DisplayCow, KStringCow, ObjectRender, ObjectSource, State, Value as LiquidValue,
|
||||
};
|
||||
use liquid::{ObjectView, ValueView};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct TemplateChecker;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DummyDoc;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DummyFields;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct DummyField;
|
||||
|
||||
const DUMMY_VALUE: &LiquidValue = &LiquidValue::Nil;
|
||||
|
||||
impl ObjectView for DummyField {
|
||||
fn as_value(&self) -> &dyn ValueView {
|
||||
self
|
||||
}
|
||||
|
||||
fn size(&self) -> i64 {
|
||||
2
|
||||
}
|
||||
|
||||
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
|
||||
Box::new(["name", "value"].iter().map(|s| KStringCow::from_static(s)))
|
||||
}
|
||||
|
||||
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
||||
Box::new(std::iter::empty())
|
||||
}
|
||||
|
||||
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
|
||||
Box::new(std::iter::empty())
|
||||
}
|
||||
|
||||
fn contains_key(&self, index: &str) -> bool {
|
||||
index == "name" || index == "value"
|
||||
}
|
||||
|
||||
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
|
||||
if self.contains_key(index) {
|
||||
Some(DUMMY_VALUE.as_view())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueView for DummyField {
|
||||
fn as_debug(&self) -> &dyn std::fmt::Debug {
|
||||
self
|
||||
}
|
||||
|
||||
fn render(&self) -> DisplayCow<'_> {
|
||||
DUMMY_VALUE.render()
|
||||
}
|
||||
|
||||
fn source(&self) -> DisplayCow<'_> {
|
||||
DUMMY_VALUE.source()
|
||||
}
|
||||
|
||||
fn type_name(&self) -> &'static str {
|
||||
"object"
|
||||
}
|
||||
|
||||
fn query_state(&self, state: State) -> bool {
|
||||
DUMMY_VALUE.query_state(state)
|
||||
}
|
||||
|
||||
fn to_kstr(&self) -> KStringCow<'_> {
|
||||
DUMMY_VALUE.to_kstr()
|
||||
}
|
||||
|
||||
fn to_value(&self) -> LiquidValue {
|
||||
LiquidValue::Nil
|
||||
}
|
||||
|
||||
fn as_object(&self) -> Option<&dyn ObjectView> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueView for DummyFields {
|
||||
fn as_debug(&self) -> &dyn std::fmt::Debug {
|
||||
self
|
||||
}
|
||||
|
||||
fn render(&self) -> DisplayCow<'_> {
|
||||
DUMMY_VALUE.render()
|
||||
}
|
||||
|
||||
fn source(&self) -> DisplayCow<'_> {
|
||||
DUMMY_VALUE.source()
|
||||
}
|
||||
|
||||
fn type_name(&self) -> &'static str {
|
||||
"array"
|
||||
}
|
||||
|
||||
fn query_state(&self, state: State) -> bool {
|
||||
DUMMY_VALUE.query_state(state)
|
||||
}
|
||||
|
||||
fn to_kstr(&self) -> KStringCow<'_> {
|
||||
DUMMY_VALUE.to_kstr()
|
||||
}
|
||||
|
||||
fn to_value(&self) -> LiquidValue {
|
||||
LiquidValue::Nil
|
||||
}
|
||||
|
||||
fn as_array(&self) -> Option<&dyn ArrayView> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ArrayView for DummyFields {
|
||||
fn as_value(&self) -> &dyn ValueView {
|
||||
self
|
||||
}
|
||||
|
||||
fn size(&self) -> i64 {
|
||||
i64::MAX
|
||||
}
|
||||
|
||||
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
||||
Box::new(std::iter::empty())
|
||||
}
|
||||
|
||||
fn contains_key(&self, _index: i64) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn get(&self, _index: i64) -> Option<&dyn ValueView> {
|
||||
Some(DummyField.as_value())
|
||||
}
|
||||
}
|
||||
|
||||
impl ObjectView for DummyDoc {
|
||||
fn as_value(&self) -> &dyn ValueView {
|
||||
self
|
||||
}
|
||||
|
||||
fn size(&self) -> i64 {
|
||||
1000
|
||||
}
|
||||
|
||||
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
|
||||
Box::new(std::iter::empty())
|
||||
}
|
||||
|
||||
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
||||
Box::new(std::iter::empty())
|
||||
}
|
||||
|
||||
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
|
||||
Box::new(std::iter::empty())
|
||||
}
|
||||
|
||||
fn contains_key(&self, _index: &str) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
fn get<'s>(&'s self, _index: &str) -> Option<&'s dyn ValueView> {
|
||||
Some(DUMMY_VALUE.as_view())
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueView for DummyDoc {
|
||||
fn as_debug(&self) -> &dyn std::fmt::Debug {
|
||||
self
|
||||
}
|
||||
|
||||
fn render(&self) -> DisplayCow<'_> {
|
||||
DUMMY_VALUE.render()
|
||||
}
|
||||
|
||||
fn source(&self) -> DisplayCow<'_> {
|
||||
DUMMY_VALUE.source()
|
||||
}
|
||||
|
||||
fn type_name(&self) -> &'static str {
|
||||
"object"
|
||||
}
|
||||
|
||||
fn query_state(&self, state: State) -> bool {
|
||||
DUMMY_VALUE.query_state(state)
|
||||
}
|
||||
|
||||
fn to_kstr(&self) -> KStringCow<'_> {
|
||||
DUMMY_VALUE.to_kstr()
|
||||
}
|
||||
|
||||
fn to_value(&self) -> LiquidValue {
|
||||
LiquidValue::Nil
|
||||
}
|
||||
|
||||
fn as_object(&self) -> Option<&dyn ObjectView> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ObjectView for TemplateChecker {
|
||||
fn as_value(&self) -> &dyn ValueView {
|
||||
self
|
||||
}
|
||||
|
||||
fn size(&self) -> i64 {
|
||||
2
|
||||
}
|
||||
|
||||
fn keys<'k>(&'k self) -> Box<dyn Iterator<Item = KStringCow<'k>> + 'k> {
|
||||
Box::new(["doc", "fields"].iter().map(|s| KStringCow::from_static(s)))
|
||||
}
|
||||
|
||||
fn values<'k>(&'k self) -> Box<dyn Iterator<Item = &'k dyn ValueView> + 'k> {
|
||||
Box::new(
|
||||
std::iter::once(DummyDoc.as_value()).chain(std::iter::once(DummyFields.as_value())),
|
||||
)
|
||||
}
|
||||
|
||||
fn iter<'k>(&'k self) -> Box<dyn Iterator<Item = (KStringCow<'k>, &'k dyn ValueView)> + 'k> {
|
||||
Box::new(self.keys().zip(self.values()))
|
||||
}
|
||||
|
||||
fn contains_key(&self, index: &str) -> bool {
|
||||
index == "doc" || index == "fields"
|
||||
}
|
||||
|
||||
fn get<'s>(&'s self, index: &str) -> Option<&'s dyn ValueView> {
|
||||
match index {
|
||||
"doc" => Some(DummyDoc.as_value()),
|
||||
"fields" => Some(DummyFields.as_value()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ValueView for TemplateChecker {
|
||||
fn as_debug(&self) -> &dyn std::fmt::Debug {
|
||||
self
|
||||
}
|
||||
|
||||
fn render(&self) -> liquid::model::DisplayCow<'_> {
|
||||
DisplayCow::Owned(Box::new(ObjectRender::new(self)))
|
||||
}
|
||||
|
||||
fn source(&self) -> liquid::model::DisplayCow<'_> {
|
||||
DisplayCow::Owned(Box::new(ObjectSource::new(self)))
|
||||
}
|
||||
|
||||
fn type_name(&self) -> &'static str {
|
||||
"object"
|
||||
}
|
||||
|
||||
fn query_state(&self, state: liquid::model::State) -> bool {
|
||||
match state {
|
||||
State::Truthy => true,
|
||||
State::DefaultValue | State::Empty | State::Blank => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn to_kstr(&self) -> liquid::model::KStringCow<'_> {
|
||||
let s = ObjectRender::new(self).to_string();
|
||||
KStringCow::from_string(s)
|
||||
}
|
||||
|
||||
fn to_value(&self) -> LiquidValue {
|
||||
LiquidValue::Object(
|
||||
self.iter().map(|(k, x)| (k.to_string().into(), x.to_value())).collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn as_object(&self) -> Option<&dyn ObjectView> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
@ -1,3 +1,6 @@
|
||||
use std::cmp::Ordering;
|
||||
|
||||
use itertools::Itertools;
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::distance_between_two_points;
|
||||
@ -12,9 +15,24 @@ pub enum ScoreDetails {
|
||||
ExactAttribute(ExactAttribute),
|
||||
ExactWords(ExactWords),
|
||||
Sort(Sort),
|
||||
Vector(Vector),
|
||||
GeoSort(GeoSort),
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum ScoreValue<'a> {
|
||||
Score(f64),
|
||||
Sort(&'a Sort),
|
||||
GeoSort(&'a GeoSort),
|
||||
}
|
||||
|
||||
enum RankOrValue<'a> {
|
||||
Rank(Rank),
|
||||
Sort(&'a Sort),
|
||||
GeoSort(&'a GeoSort),
|
||||
Score(f64),
|
||||
}
|
||||
|
||||
impl ScoreDetails {
|
||||
pub fn local_score(&self) -> Option<f64> {
|
||||
self.rank().map(Rank::local_score)
|
||||
@ -31,11 +49,55 @@ impl ScoreDetails {
|
||||
ScoreDetails::ExactWords(details) => Some(details.rank()),
|
||||
ScoreDetails::Sort(_) => None,
|
||||
ScoreDetails::GeoSort(_) => None,
|
||||
ScoreDetails::Vector(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn global_score<'a>(details: impl Iterator<Item = &'a Self>) -> f64 {
|
||||
Rank::global_score(details.filter_map(Self::rank))
|
||||
pub fn global_score<'a>(details: impl Iterator<Item = &'a Self> + 'a) -> f64 {
|
||||
Self::score_values(details)
|
||||
.find_map(|x| {
|
||||
let ScoreValue::Score(score) = x else {
|
||||
return None;
|
||||
};
|
||||
Some(score)
|
||||
})
|
||||
.unwrap_or(1.0f64)
|
||||
}
|
||||
|
||||
pub fn score_values<'a>(
|
||||
details: impl Iterator<Item = &'a Self> + 'a,
|
||||
) -> impl Iterator<Item = ScoreValue<'a>> + 'a {
|
||||
details
|
||||
.map(ScoreDetails::rank_or_value)
|
||||
.coalesce(|left, right| match (left, right) {
|
||||
(RankOrValue::Rank(left), RankOrValue::Rank(right)) => {
|
||||
Ok(RankOrValue::Rank(Rank::merge(left, right)))
|
||||
}
|
||||
(left, right) => Err((left, right)),
|
||||
})
|
||||
.map(|rank_or_value| match rank_or_value {
|
||||
RankOrValue::Rank(r) => ScoreValue::Score(r.local_score()),
|
||||
RankOrValue::Sort(s) => ScoreValue::Sort(s),
|
||||
RankOrValue::GeoSort(g) => ScoreValue::GeoSort(g),
|
||||
RankOrValue::Score(s) => ScoreValue::Score(s),
|
||||
})
|
||||
}
|
||||
|
||||
fn rank_or_value(&self) -> RankOrValue<'_> {
|
||||
match self {
|
||||
ScoreDetails::Words(w) => RankOrValue::Rank(w.rank()),
|
||||
ScoreDetails::Typo(t) => RankOrValue::Rank(t.rank()),
|
||||
ScoreDetails::Proximity(p) => RankOrValue::Rank(*p),
|
||||
ScoreDetails::Fid(f) => RankOrValue::Rank(*f),
|
||||
ScoreDetails::Position(p) => RankOrValue::Rank(*p),
|
||||
ScoreDetails::ExactAttribute(e) => RankOrValue::Rank(e.rank()),
|
||||
ScoreDetails::ExactWords(e) => RankOrValue::Rank(e.rank()),
|
||||
ScoreDetails::Sort(sort) => RankOrValue::Sort(sort),
|
||||
ScoreDetails::GeoSort(geosort) => RankOrValue::GeoSort(geosort),
|
||||
ScoreDetails::Vector(vector) => RankOrValue::Score(
|
||||
vector.value_similarity.as_ref().map(|(_, s)| *s as f64).unwrap_or(0.0f64),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// Panics
|
||||
@ -181,6 +243,19 @@ impl ScoreDetails {
|
||||
details_map.insert(sort, sort_details);
|
||||
order += 1;
|
||||
}
|
||||
ScoreDetails::Vector(s) => {
|
||||
let vector = format!("vectorSort({:?})", s.target_vector);
|
||||
let value = s.value_similarity.as_ref().map(|(v, _)| v);
|
||||
let similarity = s.value_similarity.as_ref().map(|(_, s)| s);
|
||||
|
||||
let details = serde_json::json!({
|
||||
"order": order,
|
||||
"value": value,
|
||||
"similarity": similarity,
|
||||
});
|
||||
details_map.insert(vector, details);
|
||||
order += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
details_map
|
||||
@ -297,15 +372,21 @@ impl Rank {
|
||||
pub fn global_score(details: impl Iterator<Item = Self>) -> f64 {
|
||||
let mut rank = Rank { rank: 1, max_rank: 1 };
|
||||
for inner_rank in details {
|
||||
rank.rank -= 1;
|
||||
|
||||
rank.rank *= inner_rank.max_rank;
|
||||
rank.max_rank *= inner_rank.max_rank;
|
||||
|
||||
rank.rank += inner_rank.rank;
|
||||
rank = Rank::merge(rank, inner_rank);
|
||||
}
|
||||
rank.local_score()
|
||||
}
|
||||
|
||||
pub fn merge(mut outer: Rank, inner: Rank) -> Rank {
|
||||
outer.rank = outer.rank.saturating_sub(1);
|
||||
|
||||
outer.rank *= inner.max_rank;
|
||||
outer.max_rank *= inner.max_rank;
|
||||
|
||||
outer.rank += inner.rank;
|
||||
|
||||
outer
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize)]
|
||||
@ -335,13 +416,78 @@ pub struct Sort {
|
||||
pub value: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, PartialOrd)]
|
||||
impl PartialOrd for Sort {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
if self.field_name != other.field_name {
|
||||
return None;
|
||||
}
|
||||
if self.ascending != other.ascending {
|
||||
return None;
|
||||
}
|
||||
match (&self.value, &other.value) {
|
||||
(serde_json::Value::Null, serde_json::Value::Null) => Some(Ordering::Equal),
|
||||
(serde_json::Value::Null, _) => Some(Ordering::Less),
|
||||
(_, serde_json::Value::Null) => Some(Ordering::Greater),
|
||||
// numbers are always before strings
|
||||
(serde_json::Value::Number(_), serde_json::Value::String(_)) => Some(Ordering::Greater),
|
||||
(serde_json::Value::String(_), serde_json::Value::Number(_)) => Some(Ordering::Less),
|
||||
(serde_json::Value::Number(left), serde_json::Value::Number(right)) => {
|
||||
// FIXME: unwrap permitted here?
|
||||
let order = left.as_f64().unwrap().partial_cmp(&right.as_f64().unwrap())?;
|
||||
// 12 < 42, and when ascending, we want to see 12 first, so the smallest.
|
||||
// Hence, when ascending, smaller is better
|
||||
Some(if self.ascending { order.reverse() } else { order })
|
||||
}
|
||||
(serde_json::Value::String(left), serde_json::Value::String(right)) => {
|
||||
let order = left.cmp(right);
|
||||
// Taking e.g. "a" and "z"
|
||||
// "a" < "z", and when ascending, we want to see "a" first, so the smallest.
|
||||
// Hence, when ascending, smaller is better
|
||||
Some(if self.ascending { order.reverse() } else { order })
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
pub struct GeoSort {
|
||||
pub target_point: [f64; 2],
|
||||
pub ascending: bool,
|
||||
pub value: Option<[f64; 2]>,
|
||||
}
|
||||
|
||||
impl PartialOrd for GeoSort {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
if self.target_point != other.target_point {
|
||||
return None;
|
||||
}
|
||||
if self.ascending != other.ascending {
|
||||
return None;
|
||||
}
|
||||
Some(match (self.distance(), other.distance()) {
|
||||
(None, None) => Ordering::Equal,
|
||||
(None, Some(_)) => Ordering::Less,
|
||||
(Some(_), None) => Ordering::Greater,
|
||||
(Some(left), Some(right)) => {
|
||||
let order = left.partial_cmp(&right)?;
|
||||
if self.ascending {
|
||||
// when ascending, the one with the smallest distance has the best score
|
||||
order.reverse()
|
||||
} else {
|
||||
order
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, PartialOrd)]
|
||||
pub struct Vector {
|
||||
pub target_vector: Vec<f32>,
|
||||
pub value_similarity: Option<(Vec<f32>, f32)>,
|
||||
}
|
||||
|
||||
impl GeoSort {
|
||||
pub fn distance(&self) -> Option<f64> {
|
||||
self.value.map(|value| distance_between_two_points(&self.target_point, &value))
|
||||
|
336
milli/src/search/hybrid.rs
Normal file
336
milli/src/search/hybrid.rs
Normal file
@ -0,0 +1,336 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::collections::HashMap;
|
||||
|
||||
use itertools::Itertools;
|
||||
use roaring::RoaringBitmap;
|
||||
|
||||
use super::new::{execute_vector_search, PartialSearchResult};
|
||||
use crate::score_details::{ScoreDetails, ScoreValue, ScoringStrategy};
|
||||
use crate::{
|
||||
execute_search, DefaultSearchLogger, MatchingWords, Result, Search, SearchContext, SearchResult,
|
||||
};
|
||||
|
||||
struct CombinedSearchResult {
|
||||
matching_words: MatchingWords,
|
||||
candidates: RoaringBitmap,
|
||||
document_scores: Vec<(u32, CombinedScore)>,
|
||||
}
|
||||
|
||||
type CombinedScore = (Vec<ScoreDetails>, Option<Vec<ScoreDetails>>);
|
||||
|
||||
fn compare_scores(left: &CombinedScore, right: &CombinedScore) -> Ordering {
|
||||
let mut left_main_it = ScoreDetails::score_values(left.0.iter());
|
||||
let mut left_sub_it =
|
||||
ScoreDetails::score_values(left.1.as_ref().map(|x| x.iter()).into_iter().flatten());
|
||||
|
||||
let mut right_main_it = ScoreDetails::score_values(right.0.iter());
|
||||
let mut right_sub_it =
|
||||
ScoreDetails::score_values(right.1.as_ref().map(|x| x.iter()).into_iter().flatten());
|
||||
|
||||
let mut left_main = left_main_it.next();
|
||||
let mut left_sub = left_sub_it.next();
|
||||
let mut right_main = right_main_it.next();
|
||||
let mut right_sub = right_sub_it.next();
|
||||
|
||||
loop {
|
||||
let left =
|
||||
take_best_score(&mut left_main, &mut left_sub, &mut left_main_it, &mut left_sub_it);
|
||||
|
||||
let right =
|
||||
take_best_score(&mut right_main, &mut right_sub, &mut right_main_it, &mut right_sub_it);
|
||||
|
||||
match (left, right) {
|
||||
(None, None) => return Ordering::Equal,
|
||||
(None, Some(_)) => return Ordering::Less,
|
||||
(Some(_), None) => return Ordering::Greater,
|
||||
(Some(ScoreValue::Score(left)), Some(ScoreValue::Score(right))) => {
|
||||
if (left - right).abs() <= f64::EPSILON {
|
||||
continue;
|
||||
}
|
||||
return left.partial_cmp(&right).unwrap();
|
||||
}
|
||||
(Some(ScoreValue::Sort(left)), Some(ScoreValue::Sort(right))) => {
|
||||
match left.partial_cmp(right).unwrap() {
|
||||
Ordering::Equal => continue,
|
||||
order => return order,
|
||||
}
|
||||
}
|
||||
(Some(ScoreValue::GeoSort(left)), Some(ScoreValue::GeoSort(right))) => {
|
||||
match left.partial_cmp(right).unwrap() {
|
||||
Ordering::Equal => continue,
|
||||
order => return order,
|
||||
}
|
||||
}
|
||||
(Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater,
|
||||
(Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less,
|
||||
// if we have this, we're bad
|
||||
(Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_)))
|
||||
| (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => {
|
||||
unreachable!("Unexpected geo and sort comparison")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn take_best_score<'a>(
|
||||
main_score: &mut Option<ScoreValue<'a>>,
|
||||
sub_score: &mut Option<ScoreValue<'a>>,
|
||||
main_it: &mut impl Iterator<Item = ScoreValue<'a>>,
|
||||
sub_it: &mut impl Iterator<Item = ScoreValue<'a>>,
|
||||
) -> Option<ScoreValue<'a>> {
|
||||
match (*main_score, *sub_score) {
|
||||
(Some(main), None) => {
|
||||
*main_score = main_it.next();
|
||||
Some(main)
|
||||
}
|
||||
(None, Some(sub)) => {
|
||||
*sub_score = sub_it.next();
|
||||
Some(sub)
|
||||
}
|
||||
(main @ Some(ScoreValue::Score(main_f)), sub @ Some(ScoreValue::Score(sub_v))) => {
|
||||
// take max, both advance
|
||||
*main_score = main_it.next();
|
||||
*sub_score = sub_it.next();
|
||||
if main_f >= sub_v {
|
||||
main
|
||||
} else {
|
||||
sub
|
||||
}
|
||||
}
|
||||
(main @ Some(ScoreValue::Score(_)), _) => {
|
||||
*main_score = main_it.next();
|
||||
main
|
||||
}
|
||||
(_, sub @ Some(ScoreValue::Score(_))) => {
|
||||
*sub_score = sub_it.next();
|
||||
sub
|
||||
}
|
||||
(main @ Some(ScoreValue::GeoSort(main_geo)), sub @ Some(ScoreValue::GeoSort(sub_geo))) => {
|
||||
// take best advance both
|
||||
*main_score = main_it.next();
|
||||
*sub_score = sub_it.next();
|
||||
if main_geo >= sub_geo {
|
||||
main
|
||||
} else {
|
||||
sub
|
||||
}
|
||||
}
|
||||
(main @ Some(ScoreValue::Sort(main_sort)), sub @ Some(ScoreValue::Sort(sub_sort))) => {
|
||||
// take best advance both
|
||||
*main_score = main_it.next();
|
||||
*sub_score = sub_it.next();
|
||||
if main_sort >= sub_sort {
|
||||
main
|
||||
} else {
|
||||
sub
|
||||
}
|
||||
}
|
||||
(
|
||||
Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)),
|
||||
Some(ScoreValue::GeoSort(_) | ScoreValue::Sort(_)),
|
||||
) => None,
|
||||
|
||||
(None, None) => None,
|
||||
}
|
||||
}
|
||||
|
||||
impl CombinedSearchResult {
|
||||
fn new(main_results: SearchResult, ancillary_results: PartialSearchResult) -> Self {
|
||||
let mut docid_scores = HashMap::new();
|
||||
for (docid, score) in
|
||||
main_results.documents_ids.iter().zip(main_results.document_scores.into_iter())
|
||||
{
|
||||
docid_scores.insert(*docid, (score, None));
|
||||
}
|
||||
|
||||
for (docid, score) in ancillary_results
|
||||
.documents_ids
|
||||
.iter()
|
||||
.zip(ancillary_results.document_scores.into_iter())
|
||||
{
|
||||
docid_scores
|
||||
.entry(*docid)
|
||||
.and_modify(|(_main_score, ancillary_score)| *ancillary_score = Some(score));
|
||||
}
|
||||
|
||||
let mut document_scores: Vec<_> = docid_scores.into_iter().collect();
|
||||
|
||||
document_scores.sort_by(|(_, left), (_, right)| compare_scores(left, right).reverse());
|
||||
|
||||
Self {
|
||||
matching_words: main_results.matching_words,
|
||||
candidates: main_results.candidates,
|
||||
document_scores,
|
||||
}
|
||||
}
|
||||
|
||||
fn merge(left: Self, right: Self, from: usize, length: usize) -> SearchResult {
|
||||
let mut documents_ids =
|
||||
Vec::with_capacity(left.document_scores.len() + right.document_scores.len());
|
||||
let mut document_scores =
|
||||
Vec::with_capacity(left.document_scores.len() + right.document_scores.len());
|
||||
|
||||
let mut documents_seen = RoaringBitmap::new();
|
||||
for (docid, (main_score, _sub_score)) in left
|
||||
.document_scores
|
||||
.into_iter()
|
||||
.merge_by(right.document_scores.into_iter(), |(_, left), (_, right)| {
|
||||
// the first value is the one with the greatest score
|
||||
compare_scores(left, right).is_ge()
|
||||
})
|
||||
// remove documents we already saw
|
||||
.filter(|(docid, _)| documents_seen.insert(*docid))
|
||||
// start skipping **after** the filter
|
||||
.skip(from)
|
||||
// take **after** skipping
|
||||
.take(length)
|
||||
{
|
||||
documents_ids.push(docid);
|
||||
// TODO: pass both scores to documents_score in some way?
|
||||
document_scores.push(main_score);
|
||||
}
|
||||
|
||||
SearchResult {
|
||||
matching_words: left.matching_words,
|
||||
candidates: left.candidates | right.candidates,
|
||||
documents_ids,
|
||||
document_scores,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Search<'a> {
|
||||
pub fn execute_hybrid(&self) -> Result<SearchResult> {
|
||||
// TODO: find classier way to achieve that than to reset vector and query params
|
||||
// create separate keyword and semantic searches
|
||||
let mut search = Search {
|
||||
query: self.query.clone(),
|
||||
vector: self.vector.clone(),
|
||||
filter: self.filter.clone(),
|
||||
offset: 0,
|
||||
limit: self.limit + self.offset,
|
||||
sort_criteria: self.sort_criteria.clone(),
|
||||
searchable_attributes: self.searchable_attributes,
|
||||
geo_strategy: self.geo_strategy,
|
||||
terms_matching_strategy: self.terms_matching_strategy,
|
||||
scoring_strategy: ScoringStrategy::Detailed,
|
||||
words_limit: self.words_limit,
|
||||
exhaustive_number_hits: self.exhaustive_number_hits,
|
||||
rtxn: self.rtxn,
|
||||
index: self.index,
|
||||
};
|
||||
|
||||
let vector_query = search.vector.take();
|
||||
let keyword_query = self.query.as_deref();
|
||||
|
||||
let keyword_results = search.execute()?;
|
||||
|
||||
// skip semantic search if we don't have a vector query (placeholder search)
|
||||
let Some(vector_query) = vector_query else {
|
||||
return Ok(keyword_results);
|
||||
};
|
||||
|
||||
// completely skip semantic search if the results of the keyword search are good enough
|
||||
if self.results_good_enough(&keyword_results) {
|
||||
return Ok(keyword_results);
|
||||
}
|
||||
|
||||
search.vector = Some(vector_query);
|
||||
search.query = None;
|
||||
|
||||
// TODO: would be better to have two distinct functions at this point
|
||||
let vector_results = search.execute()?;
|
||||
|
||||
// Compute keyword scores for vector_results
|
||||
let keyword_results_for_vector =
|
||||
self.keyword_results_for_vector(keyword_query, &vector_results)?;
|
||||
|
||||
// compute vector scores for keyword_results
|
||||
let vector_results_for_keyword =
|
||||
// can unwrap because we returned already if there was no vector query
|
||||
self.vector_results_for_keyword(search.vector.as_ref().unwrap(), &keyword_results)?;
|
||||
|
||||
let keyword_results =
|
||||
CombinedSearchResult::new(keyword_results, vector_results_for_keyword);
|
||||
let vector_results = CombinedSearchResult::new(vector_results, keyword_results_for_vector);
|
||||
|
||||
let merge_results =
|
||||
CombinedSearchResult::merge(vector_results, keyword_results, self.offset, self.limit);
|
||||
assert!(merge_results.documents_ids.len() <= self.limit);
|
||||
Ok(merge_results)
|
||||
}
|
||||
|
||||
fn vector_results_for_keyword(
|
||||
&self,
|
||||
vector: &[f32],
|
||||
keyword_results: &SearchResult,
|
||||
) -> Result<PartialSearchResult> {
|
||||
let mut ctx = SearchContext::new(self.index, self.rtxn);
|
||||
|
||||
if let Some(searchable_attributes) = self.searchable_attributes {
|
||||
ctx.searchable_attributes(searchable_attributes)?;
|
||||
}
|
||||
|
||||
let universe = keyword_results.documents_ids.iter().collect();
|
||||
|
||||
execute_vector_search(
|
||||
&mut ctx,
|
||||
vector,
|
||||
ScoringStrategy::Detailed,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
0,
|
||||
self.limit + self.offset,
|
||||
)
|
||||
}
|
||||
|
||||
fn keyword_results_for_vector(
|
||||
&self,
|
||||
query: Option<&str>,
|
||||
vector_results: &SearchResult,
|
||||
) -> Result<PartialSearchResult> {
|
||||
let mut ctx = SearchContext::new(self.index, self.rtxn);
|
||||
|
||||
if let Some(searchable_attributes) = self.searchable_attributes {
|
||||
ctx.searchable_attributes(searchable_attributes)?;
|
||||
}
|
||||
|
||||
let universe = vector_results.documents_ids.iter().collect();
|
||||
|
||||
execute_search(
|
||||
&mut ctx,
|
||||
query,
|
||||
self.terms_matching_strategy,
|
||||
ScoringStrategy::Detailed,
|
||||
self.exhaustive_number_hits,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
0,
|
||||
self.limit + self.offset,
|
||||
Some(self.words_limit),
|
||||
&mut DefaultSearchLogger,
|
||||
&mut DefaultSearchLogger,
|
||||
)
|
||||
}
|
||||
|
||||
fn results_good_enough(&self, keyword_results: &SearchResult) -> bool {
|
||||
const GOOD_ENOUGH_SCORE: f64 = 0.9;
|
||||
|
||||
// 1. we check that we got a sufficient number of results
|
||||
if keyword_results.document_scores.len() < self.limit + self.offset {
|
||||
return false;
|
||||
}
|
||||
|
||||
// 2. and that all results have a good enough score.
|
||||
// we need to check all results because due to sort like rules, they're not necessarily in relevancy order
|
||||
for score in &keyword_results.document_scores {
|
||||
let score = ScoreDetails::global_score(score.iter());
|
||||
if score < GOOD_ENOUGH_SCORE {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
@ -3,6 +3,7 @@ use std::ops::ControlFlow;
|
||||
|
||||
use charabia::normalizer::NormalizerOption;
|
||||
use charabia::Normalize;
|
||||
use deserr::{DeserializeError, Deserr, Sequence};
|
||||
use fst::automaton::{Automaton, Str};
|
||||
use fst::{IntoStreamer, Streamer};
|
||||
use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
|
||||
@ -12,12 +13,13 @@ use roaring::bitmap::RoaringBitmap;
|
||||
|
||||
pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET};
|
||||
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
|
||||
use self::new::PartialSearchResult;
|
||||
use self::new::{execute_vector_search, PartialSearchResult};
|
||||
use crate::error::UserError;
|
||||
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
|
||||
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use crate::{
|
||||
execute_search, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index, Result, SearchContext,
|
||||
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
|
||||
Result, SearchContext,
|
||||
};
|
||||
|
||||
// Building these factories is not free.
|
||||
@ -30,6 +32,7 @@ const MAX_NUMBER_OF_FACETS: usize = 100;
|
||||
|
||||
pub mod facet;
|
||||
mod fst_utils;
|
||||
pub mod hybrid;
|
||||
pub mod new;
|
||||
|
||||
pub struct Search<'a> {
|
||||
@ -50,6 +53,53 @@ pub struct Search<'a> {
|
||||
index: &'a Index,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum VectorQuery {
|
||||
Vector(Vec<f32>),
|
||||
String(String),
|
||||
}
|
||||
|
||||
impl<E> Deserr<E> for VectorQuery
|
||||
where
|
||||
E: DeserializeError,
|
||||
{
|
||||
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: deserr::ValuePointerRef,
|
||||
) -> std::result::Result<Self, E> {
|
||||
match value {
|
||||
deserr::Value::String(s) => Ok(VectorQuery::String(s)),
|
||||
deserr::Value::Sequence(seq) => {
|
||||
let v: std::result::Result<Vec<f32>, _> = seq
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(index, v)| match v.into_value() {
|
||||
deserr::Value::Float(f) => Ok(f as f32),
|
||||
deserr::Value::Integer(i) => Ok(i as f32),
|
||||
v => Err(deserr::take_cf_content(E::error::<V>(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: v,
|
||||
accepted: &[deserr::ValueKind::Float, deserr::ValueKind::Integer],
|
||||
},
|
||||
location.push_index(index),
|
||||
))),
|
||||
})
|
||||
.collect();
|
||||
Ok(VectorQuery::Vector(v?))
|
||||
}
|
||||
_ => Err(deserr::take_cf_content(E::error::<V>(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: value,
|
||||
accepted: &[deserr::ValueKind::String, deserr::ValueKind::Sequence],
|
||||
},
|
||||
location,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Search<'a> {
|
||||
pub fn new(rtxn: &'a heed::RoTxn, index: &'a Index) -> Search<'a> {
|
||||
Search {
|
||||
@ -75,8 +125,8 @@ impl<'a> Search<'a> {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn vector(&mut self, vector: impl Into<Vec<f32>>) -> &mut Search<'a> {
|
||||
self.vector = Some(vector.into());
|
||||
pub fn vector(&mut self, vector: Vec<f32>) -> &mut Search<'a> {
|
||||
self.vector = Some(vector);
|
||||
self
|
||||
}
|
||||
|
||||
@ -140,23 +190,35 @@ impl<'a> Search<'a> {
|
||||
ctx.searchable_attributes(searchable_attributes)?;
|
||||
}
|
||||
|
||||
let universe = filtered_universe(&ctx, &self.filter)?;
|
||||
let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } =
|
||||
execute_search(
|
||||
&mut ctx,
|
||||
&self.query,
|
||||
&self.vector,
|
||||
self.terms_matching_strategy,
|
||||
self.scoring_strategy,
|
||||
self.exhaustive_number_hits,
|
||||
&self.filter,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
self.offset,
|
||||
self.limit,
|
||||
Some(self.words_limit),
|
||||
&mut DefaultSearchLogger,
|
||||
&mut DefaultSearchLogger,
|
||||
)?;
|
||||
match self.vector.as_ref() {
|
||||
Some(vector) => execute_vector_search(
|
||||
&mut ctx,
|
||||
vector,
|
||||
self.scoring_strategy,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
self.offset,
|
||||
self.limit,
|
||||
)?,
|
||||
None => execute_search(
|
||||
&mut ctx,
|
||||
self.query.as_deref(),
|
||||
self.terms_matching_strategy,
|
||||
self.scoring_strategy,
|
||||
self.exhaustive_number_hits,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
self.offset,
|
||||
self.limit,
|
||||
Some(self.words_limit),
|
||||
&mut DefaultSearchLogger,
|
||||
&mut DefaultSearchLogger,
|
||||
)?,
|
||||
};
|
||||
|
||||
// consume context and located_query_terms to build MatchingWords.
|
||||
let matching_words = match located_query_terms {
|
||||
|
@ -498,19 +498,19 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::index::tests::TempIndex;
|
||||
use crate::{execute_search, SearchContext};
|
||||
use crate::{execute_search, filtered_universe, SearchContext};
|
||||
|
||||
impl<'a> MatcherBuilder<'a> {
|
||||
fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self {
|
||||
let mut ctx = SearchContext::new(index, rtxn);
|
||||
let universe = filtered_universe(&ctx, &None).unwrap();
|
||||
let crate::search::PartialSearchResult { located_query_terms, .. } = execute_search(
|
||||
&mut ctx,
|
||||
&Some(query.to_string()),
|
||||
&None,
|
||||
Some(query),
|
||||
crate::TermsMatchingStrategy::default(),
|
||||
crate::score_details::ScoringStrategy::Skip,
|
||||
false,
|
||||
&None,
|
||||
universe,
|
||||
&None,
|
||||
crate::search::new::GeoSortStrategy::default(),
|
||||
0,
|
||||
|
@ -16,6 +16,7 @@ mod small_bitmap;
|
||||
|
||||
mod exact_attribute;
|
||||
mod sort;
|
||||
mod vector_sort;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
@ -28,7 +29,6 @@ use db_cache::DatabaseCache;
|
||||
use exact_attribute::ExactAttribute;
|
||||
use graph_based_ranking_rule::{Exactness, Fid, Position, Proximity, Typo};
|
||||
use heed::RoTxn;
|
||||
use instant_distance::Search;
|
||||
use interner::{DedupInterner, Interner};
|
||||
pub use logger::visual::VisualSearchLogger;
|
||||
pub use logger::{DefaultSearchLogger, SearchLogger};
|
||||
@ -46,7 +46,7 @@ use self::geo_sort::GeoSort;
|
||||
pub use self::geo_sort::Strategy as GeoSortStrategy;
|
||||
use self::graph_based_ranking_rule::Words;
|
||||
use self::interner::Interned;
|
||||
use crate::distance::NDotProductPoint;
|
||||
use self::vector_sort::VectorSort;
|
||||
use crate::error::FieldIdMapMissingEntry;
|
||||
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use crate::search::new::distinct::apply_distinct_rule;
|
||||
@ -258,6 +258,70 @@ fn get_ranking_rules_for_placeholder_search<'ctx>(
|
||||
Ok(ranking_rules)
|
||||
}
|
||||
|
||||
fn get_ranking_rules_for_vector<'ctx>(
|
||||
ctx: &SearchContext<'ctx>,
|
||||
sort_criteria: &Option<Vec<AscDesc>>,
|
||||
geo_strategy: geo_sort::Strategy,
|
||||
target: &[f32],
|
||||
) -> Result<Vec<BoxRankingRule<'ctx, PlaceholderQuery>>> {
|
||||
// query graph search
|
||||
|
||||
let mut sort = false;
|
||||
let mut sorted_fields = HashSet::new();
|
||||
let mut geo_sorted = false;
|
||||
|
||||
let mut vector = false;
|
||||
let mut ranking_rules: Vec<BoxRankingRule<PlaceholderQuery>> = vec![];
|
||||
|
||||
let settings_ranking_rules = ctx.index.criteria(ctx.txn)?;
|
||||
for rr in settings_ranking_rules {
|
||||
match rr {
|
||||
crate::Criterion::Words
|
||||
| crate::Criterion::Typo
|
||||
| crate::Criterion::Proximity
|
||||
| crate::Criterion::Attribute
|
||||
| crate::Criterion::Exactness => {
|
||||
if !vector {
|
||||
let vector_candidates = ctx.index.documents_ids(ctx.txn)?;
|
||||
let vector_sort = VectorSort::new(ctx, target.to_vec(), vector_candidates)?;
|
||||
ranking_rules.push(Box::new(vector_sort));
|
||||
vector = true;
|
||||
}
|
||||
}
|
||||
crate::Criterion::Sort => {
|
||||
if sort {
|
||||
continue;
|
||||
}
|
||||
resolve_sort_criteria(
|
||||
sort_criteria,
|
||||
ctx,
|
||||
&mut ranking_rules,
|
||||
&mut sorted_fields,
|
||||
&mut geo_sorted,
|
||||
geo_strategy,
|
||||
)?;
|
||||
sort = true;
|
||||
}
|
||||
crate::Criterion::Asc(field_name) => {
|
||||
if sorted_fields.contains(&field_name) {
|
||||
continue;
|
||||
}
|
||||
sorted_fields.insert(field_name.clone());
|
||||
ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, true)?));
|
||||
}
|
||||
crate::Criterion::Desc(field_name) => {
|
||||
if sorted_fields.contains(&field_name) {
|
||||
continue;
|
||||
}
|
||||
sorted_fields.insert(field_name.clone());
|
||||
ranking_rules.push(Box::new(Sort::new(ctx.index, ctx.txn, field_name, false)?));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ranking_rules)
|
||||
}
|
||||
|
||||
/// Return the list of initialised ranking rules to be used for a query graph search.
|
||||
fn get_ranking_rules_for_query_graph_search<'ctx>(
|
||||
ctx: &SearchContext<'ctx>,
|
||||
@ -422,15 +486,62 @@ fn resolve_sort_criteria<'ctx, Query: RankingRuleQueryTrait>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn filtered_universe(ctx: &SearchContext, filters: &Option<Filter>) -> Result<RoaringBitmap> {
|
||||
Ok(if let Some(filters) = filters {
|
||||
filters.evaluate(ctx.txn, ctx.index)?
|
||||
} else {
|
||||
ctx.index.documents_ids(ctx.txn)?
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn execute_vector_search(
|
||||
ctx: &mut SearchContext,
|
||||
vector: &[f32],
|
||||
scoring_strategy: ScoringStrategy,
|
||||
universe: RoaringBitmap,
|
||||
sort_criteria: &Option<Vec<AscDesc>>,
|
||||
geo_strategy: geo_sort::Strategy,
|
||||
from: usize,
|
||||
length: usize,
|
||||
) -> Result<PartialSearchResult> {
|
||||
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
||||
|
||||
/// FIXME: input universe = universe & documents_with_vectors
|
||||
// for now if we're computing embeddings for ALL documents, we can assume that this is just universe
|
||||
let ranking_rules = get_ranking_rules_for_vector(ctx, sort_criteria, geo_strategy, vector)?;
|
||||
|
||||
let mut placeholder_search_logger = logger::DefaultSearchLogger;
|
||||
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =
|
||||
&mut placeholder_search_logger;
|
||||
|
||||
let BucketSortOutput { docids, scores, all_candidates } = bucket_sort(
|
||||
ctx,
|
||||
ranking_rules,
|
||||
&PlaceholderQuery,
|
||||
&universe,
|
||||
from,
|
||||
length,
|
||||
scoring_strategy,
|
||||
placeholder_search_logger,
|
||||
)?;
|
||||
|
||||
Ok(PartialSearchResult {
|
||||
candidates: all_candidates,
|
||||
document_scores: scores,
|
||||
documents_ids: docids,
|
||||
located_query_terms: None,
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn execute_search(
|
||||
ctx: &mut SearchContext,
|
||||
query: &Option<String>,
|
||||
vector: &Option<Vec<f32>>,
|
||||
query: Option<&str>,
|
||||
terms_matching_strategy: TermsMatchingStrategy,
|
||||
scoring_strategy: ScoringStrategy,
|
||||
exhaustive_number_hits: bool,
|
||||
filters: &Option<Filter>,
|
||||
mut universe: RoaringBitmap,
|
||||
sort_criteria: &Option<Vec<AscDesc>>,
|
||||
geo_strategy: geo_sort::Strategy,
|
||||
from: usize,
|
||||
@ -439,60 +550,8 @@ pub fn execute_search(
|
||||
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
|
||||
query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
|
||||
) -> Result<PartialSearchResult> {
|
||||
let mut universe = if let Some(filters) = filters {
|
||||
filters.evaluate(ctx.txn, ctx.index)?
|
||||
} else {
|
||||
ctx.index.documents_ids(ctx.txn)?
|
||||
};
|
||||
|
||||
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
||||
|
||||
if let Some(vector) = vector {
|
||||
let mut search = Search::default();
|
||||
let docids = match ctx.index.vector_hnsw(ctx.txn)? {
|
||||
Some(hnsw) => {
|
||||
if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() {
|
||||
if vector.len() != expected_size {
|
||||
return Err(UserError::InvalidVectorDimensions {
|
||||
expected: expected_size,
|
||||
found: vector.len(),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
}
|
||||
|
||||
let vector = NDotProductPoint::new(vector.clone());
|
||||
|
||||
let neighbors = hnsw.search(&vector, &mut search);
|
||||
|
||||
let mut docids = Vec::new();
|
||||
let mut uniq_docids = RoaringBitmap::new();
|
||||
for instant_distance::Item { distance: _, pid, point: _ } in neighbors {
|
||||
let index = pid.into_inner();
|
||||
let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap();
|
||||
if universe.contains(docid) && uniq_docids.insert(docid) {
|
||||
docids.push(docid);
|
||||
if docids.len() == (from + length) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// return the nearest documents that are also part of the candidates
|
||||
// along with a dummy list of scores that are useless in this context.
|
||||
docids.into_iter().skip(from).take(length).collect()
|
||||
}
|
||||
None => Vec::new(),
|
||||
};
|
||||
|
||||
return Ok(PartialSearchResult {
|
||||
candidates: universe,
|
||||
document_scores: vec![Vec::new(); docids.len()],
|
||||
documents_ids: docids,
|
||||
located_query_terms: None,
|
||||
});
|
||||
}
|
||||
|
||||
let mut located_query_terms = None;
|
||||
let query_terms = if let Some(query) = query {
|
||||
// We make sure that the analyzer is aware of the stop words
|
||||
@ -546,7 +605,7 @@ pub fn execute_search(
|
||||
terms_matching_strategy,
|
||||
)?;
|
||||
|
||||
universe =
|
||||
universe &=
|
||||
resolve_universe(ctx, &universe, &graph, terms_matching_strategy, query_graph_logger)?;
|
||||
|
||||
bucket_sort(
|
||||
|
150
milli/src/search/new/vector_sort.rs
Normal file
150
milli/src/search/new/vector_sort.rs
Normal file
@ -0,0 +1,150 @@
|
||||
use std::future::Future;
|
||||
use std::iter::FromIterator;
|
||||
use std::pin::Pin;
|
||||
|
||||
use nolife::DynBoxScope;
|
||||
use roaring::RoaringBitmap;
|
||||
|
||||
use super::ranking_rules::{RankingRule, RankingRuleOutput, RankingRuleQueryTrait};
|
||||
use crate::distance::NDotProductPoint;
|
||||
use crate::index::Hnsw;
|
||||
use crate::score_details::{self, ScoreDetails};
|
||||
use crate::{Result, SearchContext, SearchLogger, UserError};
|
||||
|
||||
pub struct VectorSort<Q: RankingRuleQueryTrait> {
|
||||
query: Option<Q>,
|
||||
target: Vec<f32>,
|
||||
vector_candidates: RoaringBitmap,
|
||||
scope: nolife::DynBoxScope<SearchFamily>,
|
||||
}
|
||||
|
||||
type Item<'a> = instant_distance::Item<'a, NDotProductPoint>;
|
||||
type SearchFut = Pin<Box<dyn Future<Output = nolife::Never>>>;
|
||||
|
||||
struct SearchFamily;
|
||||
impl<'a> nolife::Family<'a> for SearchFamily {
|
||||
type Family = Box<dyn Iterator<Item = Item<'a>> + 'a>;
|
||||
}
|
||||
|
||||
async fn search_scope(
|
||||
mut time_capsule: nolife::TimeCapsule<SearchFamily>,
|
||||
hnsw: Hnsw,
|
||||
target: Vec<f32>,
|
||||
) -> nolife::Never {
|
||||
let mut search = instant_distance::Search::default();
|
||||
let it = Box::new(hnsw.search(&NDotProductPoint::new(target), &mut search));
|
||||
let mut it: Box<dyn Iterator<Item = Item>> = it;
|
||||
loop {
|
||||
time_capsule.freeze(&mut it).await;
|
||||
}
|
||||
}
|
||||
|
||||
impl<Q: RankingRuleQueryTrait> VectorSort<Q> {
|
||||
pub fn new(
|
||||
ctx: &SearchContext,
|
||||
target: Vec<f32>,
|
||||
vector_candidates: RoaringBitmap,
|
||||
) -> Result<Self> {
|
||||
let hnsw =
|
||||
ctx.index.vector_hnsw(ctx.txn)?.unwrap_or(Hnsw::builder().build_hnsw(Vec::default()).0);
|
||||
|
||||
if let Some(expected_size) = hnsw.iter().map(|(_, point)| point.len()).next() {
|
||||
if target.len() != expected_size {
|
||||
return Err(UserError::InvalidVectorDimensions {
|
||||
expected: expected_size,
|
||||
found: target.len(),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
}
|
||||
|
||||
let target_clone = target.clone();
|
||||
let producer = move |time_capsule| -> SearchFut {
|
||||
Box::pin(search_scope(time_capsule, hnsw, target_clone))
|
||||
};
|
||||
let scope = DynBoxScope::new(producer);
|
||||
|
||||
Ok(Self { query: None, target, vector_candidates, scope })
|
||||
}
|
||||
}
|
||||
|
||||
impl<'ctx, Q: RankingRuleQueryTrait> RankingRule<'ctx, Q> for VectorSort<Q> {
|
||||
fn id(&self) -> String {
|
||||
"vector_sort".to_owned()
|
||||
}
|
||||
|
||||
fn start_iteration(
|
||||
&mut self,
|
||||
_ctx: &mut SearchContext<'ctx>,
|
||||
_logger: &mut dyn SearchLogger<Q>,
|
||||
universe: &RoaringBitmap,
|
||||
query: &Q,
|
||||
) -> Result<()> {
|
||||
assert!(self.query.is_none());
|
||||
|
||||
self.query = Some(query.clone());
|
||||
self.vector_candidates &= universe;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::only_used_in_recursion)]
|
||||
fn next_bucket(
|
||||
&mut self,
|
||||
ctx: &mut SearchContext<'ctx>,
|
||||
_logger: &mut dyn SearchLogger<Q>,
|
||||
universe: &RoaringBitmap,
|
||||
) -> Result<Option<RankingRuleOutput<Q>>> {
|
||||
let query = self.query.as_ref().unwrap().clone();
|
||||
self.vector_candidates &= universe;
|
||||
|
||||
if self.vector_candidates.is_empty() {
|
||||
return Ok(Some(RankingRuleOutput {
|
||||
query,
|
||||
candidates: universe.clone(),
|
||||
score: ScoreDetails::Vector(score_details::Vector {
|
||||
target_vector: self.target.clone(),
|
||||
value_similarity: None,
|
||||
}),
|
||||
}));
|
||||
}
|
||||
|
||||
let scope = &mut self.scope;
|
||||
let target = &self.target;
|
||||
let vector_candidates = &self.vector_candidates;
|
||||
|
||||
scope.enter(|it| {
|
||||
for item in it.by_ref() {
|
||||
let item: Item = item;
|
||||
let index = item.pid.into_inner();
|
||||
let docid = ctx.index.vector_id_docid.get(ctx.txn, &index)?.unwrap();
|
||||
|
||||
if vector_candidates.contains(docid) {
|
||||
return Ok(Some(RankingRuleOutput {
|
||||
query,
|
||||
candidates: RoaringBitmap::from_iter([docid]),
|
||||
score: ScoreDetails::Vector(score_details::Vector {
|
||||
target_vector: target.clone(),
|
||||
value_similarity: Some((
|
||||
item.point.clone().into_inner(),
|
||||
1.0 - item.distance,
|
||||
)),
|
||||
}),
|
||||
}));
|
||||
}
|
||||
}
|
||||
Ok(Some(RankingRuleOutput {
|
||||
query,
|
||||
candidates: universe.clone(),
|
||||
score: ScoreDetails::Vector(score_details::Vector {
|
||||
target_vector: target.clone(),
|
||||
value_similarity: None,
|
||||
}),
|
||||
}))
|
||||
})
|
||||
}
|
||||
|
||||
fn end_iteration(&mut self, _ctx: &mut SearchContext<'ctx>, _logger: &mut dyn SearchLogger<Q>) {
|
||||
self.query = None;
|
||||
}
|
||||
}
|
@ -1,9 +1,10 @@
|
||||
use std::cmp::Ordering;
|
||||
use std::convert::TryFrom;
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::fs::File;
|
||||
use std::io::{self, BufReader, BufWriter};
|
||||
use std::mem::size_of;
|
||||
use std::str::from_utf8;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytemuck::cast_slice;
|
||||
use grenad::Writer;
|
||||
@ -13,13 +14,56 @@ use serde_json::{from_slice, Value};
|
||||
|
||||
use super::helpers::{create_writer, writer_into_reader, GrenadParameters};
|
||||
use crate::error::UserError;
|
||||
use crate::prompt::Prompt;
|
||||
use crate::update::del_add::{DelAdd, KvReaderDelAdd, KvWriterDelAdd};
|
||||
use crate::update::index_documents::helpers::try_split_at;
|
||||
use crate::{DocumentId, FieldId, InternalError, Result, VectorOrArrayOfVectors};
|
||||
use crate::vector::Embedder;
|
||||
use crate::{DocumentId, FieldsIdsMap, InternalError, Result, VectorOrArrayOfVectors};
|
||||
|
||||
/// The length of the elements that are always in the buffer when inserting new values.
|
||||
const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
|
||||
|
||||
pub struct ExtractedVectorPoints {
|
||||
// docid, _index -> KvWriterDelAdd -> Vector
|
||||
pub manual_vectors: grenad::Reader<BufReader<File>>,
|
||||
// docid -> ()
|
||||
pub remove_vectors: grenad::Reader<BufReader<File>>,
|
||||
// docid -> prompt
|
||||
pub prompts: grenad::Reader<BufReader<File>>,
|
||||
}
|
||||
|
||||
enum VectorStateDelta {
|
||||
NoChange,
|
||||
// Remove all vectors, generated or manual, from this document
|
||||
NowRemoved,
|
||||
|
||||
// Add the manually specified vectors, passed in the other grenad
|
||||
// Remove any previously generated vectors
|
||||
// Note: changing the value of the manually specified vector **should not record** this delta
|
||||
WasGeneratedNowManual(Vec<Vec<f32>>),
|
||||
|
||||
ManualDelta(Vec<Vec<f32>>, Vec<Vec<f32>>),
|
||||
|
||||
// Add the vector computed from the specified prompt
|
||||
// Remove any previous vector
|
||||
// Note: changing the value of the prompt **does require** recording this delta
|
||||
NowGenerated(String),
|
||||
}
|
||||
|
||||
impl VectorStateDelta {
|
||||
fn into_values(self) -> (bool, String, (Vec<Vec<f32>>, Vec<Vec<f32>>)) {
|
||||
match self {
|
||||
VectorStateDelta::NoChange => Default::default(),
|
||||
VectorStateDelta::NowRemoved => (true, Default::default(), Default::default()),
|
||||
VectorStateDelta::WasGeneratedNowManual(add) => {
|
||||
(true, Default::default(), (Default::default(), add))
|
||||
}
|
||||
VectorStateDelta::ManualDelta(del, add) => (false, Default::default(), (del, add)),
|
||||
VectorStateDelta::NowGenerated(prompt) => (true, prompt, Default::default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts the embedding vector contained in each document under the `_vectors` field.
|
||||
///
|
||||
/// Returns the generated grenad reader containing the docid as key associated to the Vec<f32>
|
||||
@ -27,16 +71,34 @@ const TRUNCATE_SIZE: usize = size_of::<DocumentId>();
|
||||
pub fn extract_vector_points<R: io::Read + io::Seek>(
|
||||
obkv_documents: grenad::Reader<R>,
|
||||
indexer: GrenadParameters,
|
||||
vectors_fid: FieldId,
|
||||
) -> Result<grenad::Reader<BufReader<File>>> {
|
||||
field_id_map: FieldsIdsMap,
|
||||
prompt: Option<&Prompt>,
|
||||
) -> Result<ExtractedVectorPoints> {
|
||||
puffin::profile_function!();
|
||||
|
||||
let mut writer = create_writer(
|
||||
// (docid, _index) -> KvWriterDelAdd -> Vector
|
||||
let mut manual_vectors_writer = create_writer(
|
||||
indexer.chunk_compression_type,
|
||||
indexer.chunk_compression_level,
|
||||
tempfile::tempfile()?,
|
||||
);
|
||||
|
||||
// (docid) -> (prompt)
|
||||
let mut prompts_writer = create_writer(
|
||||
indexer.chunk_compression_type,
|
||||
indexer.chunk_compression_level,
|
||||
tempfile::tempfile()?,
|
||||
);
|
||||
|
||||
// (docid) -> ()
|
||||
let mut remove_vectors_writer = create_writer(
|
||||
indexer.chunk_compression_type,
|
||||
indexer.chunk_compression_level,
|
||||
tempfile::tempfile()?,
|
||||
);
|
||||
|
||||
let vectors_fid = field_id_map.id("_vectors");
|
||||
|
||||
let mut key_buffer = Vec::new();
|
||||
let mut cursor = obkv_documents.into_cursor()?;
|
||||
while let Some((key, value)) = cursor.move_on_next()? {
|
||||
@ -53,43 +115,148 @@ pub fn extract_vector_points<R: io::Read + io::Seek>(
|
||||
// lazily get it when needed
|
||||
let document_id = || -> Value { from_utf8(external_id_bytes).unwrap().into() };
|
||||
|
||||
// first we retrieve the _vectors field
|
||||
if let Some(value) = obkv.get(vectors_fid) {
|
||||
let delta = if let Some(value) = vectors_fid.and_then(|vectors_fid| obkv.get(vectors_fid)) {
|
||||
let vectors_obkv = KvReaderDelAdd::new(value);
|
||||
match (vectors_obkv.get(DelAdd::Deletion), vectors_obkv.get(DelAdd::Addition)) {
|
||||
(Some(old), Some(new)) => {
|
||||
// no autogeneration
|
||||
let del_vectors = extract_vectors(old, document_id)?;
|
||||
let add_vectors = extract_vectors(new, document_id)?;
|
||||
|
||||
// then we extract the values
|
||||
let del_vectors = vectors_obkv
|
||||
.get(DelAdd::Deletion)
|
||||
.map(|vectors| extract_vectors(vectors, document_id))
|
||||
.transpose()?
|
||||
.flatten();
|
||||
let add_vectors = vectors_obkv
|
||||
.get(DelAdd::Addition)
|
||||
.map(|vectors| extract_vectors(vectors, document_id))
|
||||
.transpose()?
|
||||
.flatten();
|
||||
VectorStateDelta::ManualDelta(
|
||||
del_vectors.unwrap_or_default(),
|
||||
add_vectors.unwrap_or_default(),
|
||||
)
|
||||
}
|
||||
(None, Some(new)) => {
|
||||
// was possibly autogenerated, remove all vectors for that document
|
||||
let add_vectors = extract_vectors(new, document_id)?;
|
||||
|
||||
// and we finally push the unique vectors into the writer
|
||||
push_vectors_diff(
|
||||
&mut writer,
|
||||
&mut key_buffer,
|
||||
del_vectors.unwrap_or_default(),
|
||||
add_vectors.unwrap_or_default(),
|
||||
)?;
|
||||
}
|
||||
VectorStateDelta::WasGeneratedNowManual(add_vectors.unwrap_or_default())
|
||||
}
|
||||
(Some(_old), None) => {
|
||||
// Do we keep this document?
|
||||
let document_is_kept = obkv
|
||||
.iter()
|
||||
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
|
||||
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
||||
if document_is_kept {
|
||||
// becomes autogenerated
|
||||
match prompt {
|
||||
Some(prompt) => VectorStateDelta::NowGenerated(prompt.render(
|
||||
obkv,
|
||||
DelAdd::Addition,
|
||||
&field_id_map,
|
||||
)?),
|
||||
None => VectorStateDelta::NowRemoved,
|
||||
}
|
||||
} else {
|
||||
VectorStateDelta::NowRemoved
|
||||
}
|
||||
}
|
||||
(None, None) => {
|
||||
// Do we keep this document?
|
||||
let document_is_kept = obkv
|
||||
.iter()
|
||||
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
|
||||
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
||||
|
||||
if document_is_kept {
|
||||
match prompt {
|
||||
Some(prompt) => {
|
||||
// Don't give up if the old prompt was failing
|
||||
let old_prompt = prompt
|
||||
.render(obkv, DelAdd::Deletion, &field_id_map)
|
||||
.unwrap_or_default();
|
||||
let new_prompt =
|
||||
prompt.render(obkv, DelAdd::Addition, &field_id_map)?;
|
||||
if old_prompt != new_prompt {
|
||||
log::trace!(
|
||||
"Changing prompt from\n{old_prompt}\n===\nto\n{new_prompt}"
|
||||
);
|
||||
VectorStateDelta::NowGenerated(new_prompt)
|
||||
} else {
|
||||
VectorStateDelta::NoChange
|
||||
}
|
||||
}
|
||||
// We no longer have a prompt, so we need to remove any existing vector
|
||||
None => VectorStateDelta::NowRemoved,
|
||||
}
|
||||
} else {
|
||||
VectorStateDelta::NowRemoved
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Do we keep this document?
|
||||
let document_is_kept = obkv
|
||||
.iter()
|
||||
.map(|(_, deladd)| KvReaderDelAdd::new(deladd))
|
||||
.any(|deladd| deladd.get(DelAdd::Addition).is_some());
|
||||
|
||||
if document_is_kept {
|
||||
match prompt {
|
||||
Some(prompt) => {
|
||||
// Don't give up if the old prompt was failing
|
||||
let old_prompt = prompt
|
||||
.render(obkv, DelAdd::Deletion, &field_id_map)
|
||||
.unwrap_or_default();
|
||||
let new_prompt = prompt.render(obkv, DelAdd::Addition, &field_id_map)?;
|
||||
if old_prompt != new_prompt {
|
||||
log::trace!(
|
||||
"Changing prompt from\n{old_prompt}\n===\nto\n{new_prompt}"
|
||||
);
|
||||
VectorStateDelta::NowGenerated(new_prompt)
|
||||
} else {
|
||||
VectorStateDelta::NoChange
|
||||
}
|
||||
}
|
||||
None => VectorStateDelta::NowRemoved,
|
||||
}
|
||||
} else {
|
||||
VectorStateDelta::NowRemoved
|
||||
}
|
||||
};
|
||||
|
||||
// and we finally push the unique vectors into the writer
|
||||
push_vectors_diff(
|
||||
&mut remove_vectors_writer,
|
||||
&mut prompts_writer,
|
||||
&mut manual_vectors_writer,
|
||||
&mut key_buffer,
|
||||
delta,
|
||||
)?;
|
||||
}
|
||||
|
||||
writer_into_reader(writer)
|
||||
Ok(ExtractedVectorPoints {
|
||||
// docid, _index -> KvWriterDelAdd -> Vector
|
||||
manual_vectors: writer_into_reader(manual_vectors_writer)?,
|
||||
// docid -> ()
|
||||
remove_vectors: writer_into_reader(remove_vectors_writer)?,
|
||||
// docid -> prompt
|
||||
prompts: writer_into_reader(prompts_writer)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// Computes the diff between both Del and Add numbers and
|
||||
/// only inserts the parts that differ in the sorter.
|
||||
fn push_vectors_diff(
|
||||
writer: &mut Writer<BufWriter<File>>,
|
||||
remove_vectors_writer: &mut Writer<BufWriter<File>>,
|
||||
prompts_writer: &mut Writer<BufWriter<File>>,
|
||||
manual_vectors_writer: &mut Writer<BufWriter<File>>,
|
||||
key_buffer: &mut Vec<u8>,
|
||||
mut del_vectors: Vec<Vec<f32>>,
|
||||
mut add_vectors: Vec<Vec<f32>>,
|
||||
delta: VectorStateDelta,
|
||||
) -> Result<()> {
|
||||
let (must_remove, prompt, (mut del_vectors, mut add_vectors)) = delta.into_values();
|
||||
if must_remove {
|
||||
key_buffer.truncate(TRUNCATE_SIZE);
|
||||
remove_vectors_writer.insert(&key_buffer, [])?;
|
||||
}
|
||||
if !prompt.is_empty() {
|
||||
key_buffer.truncate(TRUNCATE_SIZE);
|
||||
prompts_writer.insert(&key_buffer, prompt.as_bytes())?;
|
||||
}
|
||||
|
||||
// We sort and dedup the vectors
|
||||
del_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
|
||||
add_vectors.sort_unstable_by(|a, b| compare_vectors(a, b));
|
||||
@ -114,7 +281,7 @@ fn push_vectors_diff(
|
||||
let mut obkv = KvWriterDelAdd::memory();
|
||||
obkv.insert(DelAdd::Deletion, cast_slice(&vector))?;
|
||||
let bytes = obkv.into_inner()?;
|
||||
writer.insert(&key_buffer, bytes)?;
|
||||
manual_vectors_writer.insert(&key_buffer, bytes)?;
|
||||
}
|
||||
EitherOrBoth::Right(vector) => {
|
||||
// We insert only the Add part of the Obkv to inform
|
||||
@ -122,7 +289,7 @@ fn push_vectors_diff(
|
||||
let mut obkv = KvWriterDelAdd::memory();
|
||||
obkv.insert(DelAdd::Addition, cast_slice(&vector))?;
|
||||
let bytes = obkv.into_inner()?;
|
||||
writer.insert(&key_buffer, bytes)?;
|
||||
manual_vectors_writer.insert(&key_buffer, bytes)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -146,3 +313,102 @@ fn extract_vectors(value: &[u8], document_id: impl Fn() -> Value) -> Result<Opti
|
||||
.into()),
|
||||
}
|
||||
}
|
||||
|
||||
#[logging_timer::time]
|
||||
pub fn extract_embeddings<R: io::Read + io::Seek>(
|
||||
// docid, prompt
|
||||
prompt_reader: grenad::Reader<R>,
|
||||
indexer: GrenadParameters,
|
||||
embedder: Arc<Embedder>,
|
||||
) -> Result<(grenad::Reader<BufReader<File>>, Option<usize>)> {
|
||||
let rt = tokio::runtime::Builder::new_current_thread().enable_io().enable_time().build()?;
|
||||
|
||||
let n_chunks = embedder.chunk_count_hint(); // chunk level parellelism
|
||||
let n_vectors_per_chunk = embedder.prompt_count_in_chunk_hint(); // number of vectors in a single chunk
|
||||
|
||||
// docid, state with embedding
|
||||
let mut state_writer = create_writer(
|
||||
indexer.chunk_compression_type,
|
||||
indexer.chunk_compression_level,
|
||||
tempfile::tempfile()?,
|
||||
);
|
||||
|
||||
let mut chunks = Vec::with_capacity(n_chunks);
|
||||
let mut current_chunk = Vec::with_capacity(n_vectors_per_chunk);
|
||||
let mut current_chunk_ids = Vec::with_capacity(n_vectors_per_chunk);
|
||||
let mut chunks_ids = Vec::with_capacity(n_chunks);
|
||||
let mut cursor = prompt_reader.into_cursor()?;
|
||||
|
||||
let mut expected_dimension = None;
|
||||
|
||||
while let Some((key, value)) = cursor.move_on_next()? {
|
||||
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
|
||||
// SAFETY: precondition, the grenad value was saved from a string
|
||||
let prompt = unsafe { std::str::from_utf8_unchecked(value) };
|
||||
if current_chunk.len() == current_chunk.capacity() {
|
||||
chunks.push(std::mem::replace(
|
||||
&mut current_chunk,
|
||||
Vec::with_capacity(n_vectors_per_chunk),
|
||||
));
|
||||
chunks_ids.push(std::mem::replace(
|
||||
&mut current_chunk_ids,
|
||||
Vec::with_capacity(n_vectors_per_chunk),
|
||||
));
|
||||
};
|
||||
current_chunk.push(prompt.to_owned());
|
||||
current_chunk_ids.push(docid);
|
||||
|
||||
if chunks.len() == chunks.capacity() {
|
||||
let chunked_embeds = rt
|
||||
.block_on(
|
||||
embedder
|
||||
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks))),
|
||||
)
|
||||
.map_err(crate::vector::Error::from)
|
||||
.map_err(crate::UserError::from)
|
||||
.map_err(crate::Error::from)?;
|
||||
|
||||
for (docid, embeddings) in chunks_ids
|
||||
.iter()
|
||||
.flat_map(|docids| docids.iter())
|
||||
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
|
||||
{
|
||||
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
||||
expected_dimension = Some(embeddings.dimension());
|
||||
}
|
||||
chunks_ids.clear();
|
||||
}
|
||||
}
|
||||
|
||||
// send last chunk
|
||||
if !chunks.is_empty() {
|
||||
let chunked_embeds = rt
|
||||
.block_on(embedder.embed_chunks(std::mem::take(&mut chunks)))
|
||||
.map_err(crate::vector::Error::from)
|
||||
.map_err(crate::UserError::from)
|
||||
.map_err(crate::Error::from)?;
|
||||
for (docid, embeddings) in chunks_ids
|
||||
.iter()
|
||||
.flat_map(|docids| docids.iter())
|
||||
.zip(chunked_embeds.iter().flat_map(|embeds| embeds.iter()))
|
||||
{
|
||||
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
||||
expected_dimension = Some(embeddings.dimension());
|
||||
}
|
||||
}
|
||||
|
||||
if !current_chunk.is_empty() {
|
||||
let embeds = rt
|
||||
.block_on(embedder.embed(std::mem::take(&mut current_chunk)))
|
||||
.map_err(crate::vector::Error::from)
|
||||
.map_err(crate::UserError::from)
|
||||
.map_err(crate::Error::from)?;
|
||||
|
||||
for (docid, embeddings) in current_chunk_ids.iter().zip(embeds.iter()) {
|
||||
state_writer.insert(docid.to_be_bytes(), cast_slice(embeddings.as_inner()))?;
|
||||
expected_dimension = Some(embeddings.dimension());
|
||||
}
|
||||
}
|
||||
|
||||
Ok((writer_into_reader(state_writer)?, expected_dimension))
|
||||
}
|
||||
|
@ -9,9 +9,10 @@ mod extract_word_docids;
|
||||
mod extract_word_pair_proximity_docids;
|
||||
mod extract_word_position_docids;
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::fs::File;
|
||||
use std::io::BufReader;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crossbeam_channel::Sender;
|
||||
use log::debug;
|
||||
@ -23,7 +24,9 @@ use self::extract_facet_string_docids::extract_facet_string_docids;
|
||||
use self::extract_fid_docid_facet_values::{extract_fid_docid_facet_values, ExtractedFacetValues};
|
||||
use self::extract_fid_word_count_docids::extract_fid_word_count_docids;
|
||||
use self::extract_geo_points::extract_geo_points;
|
||||
use self::extract_vector_points::extract_vector_points;
|
||||
use self::extract_vector_points::{
|
||||
extract_embeddings, extract_vector_points, ExtractedVectorPoints,
|
||||
};
|
||||
use self::extract_word_docids::extract_word_docids;
|
||||
use self::extract_word_pair_proximity_docids::extract_word_pair_proximity_docids;
|
||||
use self::extract_word_position_docids::extract_word_position_docids;
|
||||
@ -32,8 +35,10 @@ use super::helpers::{
|
||||
MergeFn, MergeableReader,
|
||||
};
|
||||
use super::{helpers, TypedChunk};
|
||||
use crate::prompt::Prompt;
|
||||
use crate::proximity::ProximityPrecision;
|
||||
use crate::{FieldId, Result};
|
||||
use crate::vector::Embedder;
|
||||
use crate::{FieldId, FieldsIdsMap, Result};
|
||||
|
||||
/// Extract data for each databases from obkv documents in parallel.
|
||||
/// Send data in grenad file over provided Sender.
|
||||
@ -47,13 +52,14 @@ pub(crate) fn data_from_obkv_documents(
|
||||
faceted_fields: HashSet<FieldId>,
|
||||
primary_key_id: FieldId,
|
||||
geo_fields_ids: Option<(FieldId, FieldId)>,
|
||||
vectors_field_id: Option<FieldId>,
|
||||
field_id_map: FieldsIdsMap,
|
||||
stop_words: Option<fst::Set<&[u8]>>,
|
||||
allowed_separators: Option<&[&str]>,
|
||||
dictionary: Option<&[&str]>,
|
||||
max_positions_per_attributes: Option<u32>,
|
||||
exact_attributes: HashSet<FieldId>,
|
||||
proximity_precision: ProximityPrecision,
|
||||
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
|
||||
) -> Result<()> {
|
||||
puffin::profile_function!();
|
||||
|
||||
@ -64,7 +70,8 @@ pub(crate) fn data_from_obkv_documents(
|
||||
original_documents_chunk,
|
||||
indexer,
|
||||
lmdb_writer_sx.clone(),
|
||||
vectors_field_id,
|
||||
field_id_map.clone(),
|
||||
embedders.clone(),
|
||||
)
|
||||
})
|
||||
.collect::<Result<()>>()?;
|
||||
@ -276,24 +283,42 @@ fn send_original_documents_data(
|
||||
original_documents_chunk: Result<grenad::Reader<BufReader<File>>>,
|
||||
indexer: GrenadParameters,
|
||||
lmdb_writer_sx: Sender<Result<TypedChunk>>,
|
||||
vectors_field_id: Option<FieldId>,
|
||||
field_id_map: FieldsIdsMap,
|
||||
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
|
||||
) -> Result<()> {
|
||||
let original_documents_chunk =
|
||||
original_documents_chunk.and_then(|c| unsafe { as_cloneable_grenad(&c) })?;
|
||||
|
||||
if let Some(vectors_field_id) = vectors_field_id {
|
||||
let documents_chunk_cloned = original_documents_chunk.clone();
|
||||
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
||||
rayon::spawn(move || {
|
||||
let result = extract_vector_points(documents_chunk_cloned, indexer, vectors_field_id);
|
||||
let _ = match result {
|
||||
Ok(vector_points) => {
|
||||
lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints(vector_points)))
|
||||
}
|
||||
Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
|
||||
};
|
||||
});
|
||||
}
|
||||
let documents_chunk_cloned = original_documents_chunk.clone();
|
||||
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
||||
rayon::spawn(move || {
|
||||
let (embedder, prompt) = embedders.get("default").cloned().unzip();
|
||||
let result =
|
||||
extract_vector_points(documents_chunk_cloned, indexer, field_id_map, prompt.as_deref());
|
||||
let _ = match result {
|
||||
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
|
||||
/// FIXME: support multiple embedders
|
||||
let results = embedder.and_then(|embedder| {
|
||||
match extract_embeddings(prompts, indexer, embedder.clone()) {
|
||||
Ok(results) => Some(results),
|
||||
Err(error) => {
|
||||
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
||||
None
|
||||
}
|
||||
}
|
||||
});
|
||||
let (embeddings, expected_dimension) = results.unzip();
|
||||
let expected_dimension = expected_dimension.flatten();
|
||||
lmdb_writer_sx_cloned.send(Ok(TypedChunk::VectorPoints {
|
||||
remove_vectors,
|
||||
embeddings,
|
||||
expected_dimension,
|
||||
manual_vectors,
|
||||
}))
|
||||
}
|
||||
Err(error) => lmdb_writer_sx_cloned.send(Err(error)),
|
||||
};
|
||||
});
|
||||
|
||||
// TODO: create a custom internal error
|
||||
lmdb_writer_sx.send(Ok(TypedChunk::Documents(original_documents_chunk))).unwrap();
|
||||
|
@ -4,11 +4,12 @@ mod helpers;
|
||||
mod transform;
|
||||
mod typed_chunk;
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::io::{Cursor, Read, Seek};
|
||||
use std::iter::FromIterator;
|
||||
use std::num::NonZeroU32;
|
||||
use std::result::Result as StdResult;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crossbeam_channel::{Receiver, Sender};
|
||||
use heed::types::Str;
|
||||
@ -32,10 +33,12 @@ use self::helpers::{grenad_obkv_into_chunks, GrenadParameters};
|
||||
pub use self::transform::{Transform, TransformOutput};
|
||||
use crate::documents::{obkv_to_object, DocumentsBatchReader};
|
||||
use crate::error::{Error, InternalError, UserError};
|
||||
use crate::prompt::Prompt;
|
||||
pub use crate::update::index_documents::helpers::CursorClonableMmap;
|
||||
use crate::update::{
|
||||
IndexerConfig, UpdateIndexingStep, WordPrefixDocids, WordPrefixIntegerDocids, WordsPrefixesFst,
|
||||
};
|
||||
use crate::vector::Embedder;
|
||||
use crate::{CboRoaringBitmapCodec, Index, Result};
|
||||
|
||||
static MERGED_DATABASE_COUNT: usize = 7;
|
||||
@ -78,6 +81,7 @@ pub struct IndexDocuments<'t, 'i, 'a, FP, FA> {
|
||||
should_abort: FA,
|
||||
added_documents: u64,
|
||||
deleted_documents: u64,
|
||||
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
|
||||
}
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
@ -121,6 +125,7 @@ where
|
||||
index,
|
||||
added_documents: 0,
|
||||
deleted_documents: 0,
|
||||
embedders: Default::default(),
|
||||
})
|
||||
}
|
||||
|
||||
@ -167,6 +172,14 @@ where
|
||||
Ok((self, Ok(indexed_documents)))
|
||||
}
|
||||
|
||||
pub fn with_embedders(
|
||||
mut self,
|
||||
embedders: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>,
|
||||
) -> Self {
|
||||
self.embedders = embedders;
|
||||
self
|
||||
}
|
||||
|
||||
/// Remove a batch of documents from the current builder.
|
||||
///
|
||||
/// Returns the number of documents deleted from the builder.
|
||||
@ -322,17 +335,18 @@ where
|
||||
// get filterable fields for facet databases
|
||||
let faceted_fields = self.index.faceted_fields_ids(self.wtxn)?;
|
||||
// get the fid of the `_geo.lat` and `_geo.lng` fields.
|
||||
let geo_fields_ids = match self.index.fields_ids_map(self.wtxn)?.id("_geo") {
|
||||
let mut field_id_map = self.index.fields_ids_map(self.wtxn)?;
|
||||
|
||||
// self.index.fields_ids_map($a)? ==>> field_id_map
|
||||
let geo_fields_ids = match field_id_map.id("_geo") {
|
||||
Some(gfid) => {
|
||||
let is_sortable = self.index.sortable_fields_ids(self.wtxn)?.contains(&gfid);
|
||||
let is_filterable = self.index.filterable_fields_ids(self.wtxn)?.contains(&gfid);
|
||||
// if `_geo` is faceted then we get the `lat` and `lng`
|
||||
if is_sortable || is_filterable {
|
||||
let field_ids = self
|
||||
.index
|
||||
.fields_ids_map(self.wtxn)?
|
||||
let field_ids = field_id_map
|
||||
.insert("_geo.lat")
|
||||
.zip(self.index.fields_ids_map(self.wtxn)?.insert("_geo.lng"))
|
||||
.zip(field_id_map.insert("_geo.lng"))
|
||||
.ok_or(UserError::AttributeLimitReached)?;
|
||||
Some(field_ids)
|
||||
} else {
|
||||
@ -341,8 +355,6 @@ where
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
// get the fid of the `_vectors` field.
|
||||
let vectors_field_id = self.index.fields_ids_map(self.wtxn)?.id("_vectors");
|
||||
|
||||
let stop_words = self.index.stop_words(self.wtxn)?;
|
||||
let separators = self.index.allowed_separators(self.wtxn)?;
|
||||
@ -364,6 +376,8 @@ where
|
||||
self.indexer_config.documents_chunk_size.unwrap_or(1024 * 1024 * 4); // 4MiB
|
||||
let max_positions_per_attributes = self.indexer_config.max_positions_per_attributes;
|
||||
|
||||
let cloned_embedder = self.embedders.clone();
|
||||
|
||||
// Run extraction pipeline in parallel.
|
||||
pool.install(|| {
|
||||
puffin::profile_scope!("extract_and_send_grenad_chunks");
|
||||
@ -387,13 +401,14 @@ where
|
||||
faceted_fields,
|
||||
primary_key_id,
|
||||
geo_fields_ids,
|
||||
vectors_field_id,
|
||||
field_id_map,
|
||||
stop_words,
|
||||
separators.as_deref(),
|
||||
dictionary.as_deref(),
|
||||
max_positions_per_attributes,
|
||||
exact_attributes,
|
||||
proximity_precision,
|
||||
cloned_embedder,
|
||||
)
|
||||
});
|
||||
|
||||
@ -2505,7 +2520,7 @@ mod tests {
|
||||
.unwrap();
|
||||
|
||||
let rtxn = index.read_txn().unwrap();
|
||||
let res = index.search(&rtxn).vector([0.0, 1.0, 2.0]).execute().unwrap();
|
||||
let res = index.search(&rtxn).vector([0.0, 1.0, 2.0].to_vec()).execute().unwrap();
|
||||
assert_eq!(res.documents_ids.len(), 3);
|
||||
}
|
||||
|
||||
|
@ -47,7 +47,12 @@ pub(crate) enum TypedChunk {
|
||||
FieldIdFacetIsNullDocids(grenad::Reader<BufReader<File>>),
|
||||
FieldIdFacetIsEmptyDocids(grenad::Reader<BufReader<File>>),
|
||||
GeoPoints(grenad::Reader<BufReader<File>>),
|
||||
VectorPoints(grenad::Reader<BufReader<File>>),
|
||||
VectorPoints {
|
||||
remove_vectors: grenad::Reader<BufReader<File>>,
|
||||
embeddings: Option<grenad::Reader<BufReader<File>>>,
|
||||
expected_dimension: Option<usize>,
|
||||
manual_vectors: grenad::Reader<BufReader<File>>,
|
||||
},
|
||||
ScriptLanguageDocids(HashMap<(Script, Language), (RoaringBitmap, RoaringBitmap)>),
|
||||
}
|
||||
|
||||
@ -100,8 +105,8 @@ impl TypedChunk {
|
||||
TypedChunk::GeoPoints(grenad) => {
|
||||
format!("GeoPoints {{ number_of_entries: {} }}", grenad.len())
|
||||
}
|
||||
TypedChunk::VectorPoints(grenad) => {
|
||||
format!("VectorPoints {{ number_of_entries: {} }}", grenad.len())
|
||||
TypedChunk::VectorPoints{ remove_vectors, manual_vectors, embeddings, expected_dimension } => {
|
||||
format!("VectorPoints {{ remove_vectors: {}, manual_vectors: {}, embeddings: {}, dimension: {} }}", remove_vectors.len(), manual_vectors.len(), embeddings.as_ref().map(|e| e.len()).unwrap_or_default(), expected_dimension.unwrap_or_default())
|
||||
}
|
||||
TypedChunk::ScriptLanguageDocids(sl_map) => {
|
||||
format!("ScriptLanguageDocids {{ number_of_entries: {} }}", sl_map.len())
|
||||
@ -355,19 +360,64 @@ pub(crate) fn write_typed_chunk_into_index(
|
||||
index.put_geo_rtree(wtxn, &rtree)?;
|
||||
index.put_geo_faceted_documents_ids(wtxn, &geo_faceted_docids)?;
|
||||
}
|
||||
TypedChunk::VectorPoints(vector_points) => {
|
||||
let mut vectors_set = HashSet::new();
|
||||
TypedChunk::VectorPoints {
|
||||
remove_vectors,
|
||||
manual_vectors,
|
||||
embeddings,
|
||||
expected_dimension,
|
||||
} => {
|
||||
if remove_vectors.is_empty()
|
||||
&& manual_vectors.is_empty()
|
||||
&& embeddings.as_ref().map_or(true, |e| e.is_empty())
|
||||
{
|
||||
return Ok((RoaringBitmap::new(), is_merged_database));
|
||||
}
|
||||
|
||||
let mut docid_vectors_map: HashMap<DocumentId, HashSet<Vec<OrderedFloat<f32>>>> =
|
||||
HashMap::new();
|
||||
|
||||
// We extract and store the previous vectors
|
||||
if let Some(hnsw) = index.vector_hnsw(wtxn)? {
|
||||
for (pid, point) in hnsw.iter() {
|
||||
let pid_key = pid.into_inner();
|
||||
let docid = index.vector_id_docid.get(wtxn, &pid_key)?.unwrap();
|
||||
let vector: Vec<_> = point.iter().copied().map(OrderedFloat).collect();
|
||||
vectors_set.insert((docid, vector));
|
||||
docid_vectors_map.entry(docid).or_default().insert(vector);
|
||||
}
|
||||
}
|
||||
|
||||
let mut cursor = vector_points.into_cursor()?;
|
||||
// remove vectors for docids we want them removed
|
||||
let mut cursor = remove_vectors.into_cursor()?;
|
||||
while let Some((key, _)) = cursor.move_on_next()? {
|
||||
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
|
||||
|
||||
docid_vectors_map.remove(&docid);
|
||||
}
|
||||
|
||||
// add generated embeddings
|
||||
if let Some((embeddings, expected_dimension)) = embeddings.zip(expected_dimension) {
|
||||
let mut cursor = embeddings.into_cursor()?;
|
||||
while let Some((key, value)) = cursor.move_on_next()? {
|
||||
let docid = key.try_into().map(DocumentId::from_be_bytes).unwrap();
|
||||
let data: Vec<OrderedFloat<_>> =
|
||||
pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
|
||||
// it is a code error to have embeddings and not expected_dimension
|
||||
let embeddings =
|
||||
crate::vector::Embeddings::from_inner(data, expected_dimension)
|
||||
// code error if we somehow got the wrong dimension
|
||||
.unwrap();
|
||||
|
||||
let mut set = HashSet::new();
|
||||
for embedding in embeddings.iter() {
|
||||
set.insert(embedding.to_vec());
|
||||
}
|
||||
|
||||
docid_vectors_map.insert(docid, set);
|
||||
}
|
||||
}
|
||||
|
||||
// perform the manual diff
|
||||
let mut cursor = manual_vectors.into_cursor()?;
|
||||
while let Some((key, value)) = cursor.move_on_next()? {
|
||||
// convert the key back to a u32 (4 bytes)
|
||||
let (left, _index) = try_split_array_at(key).unwrap();
|
||||
@ -376,23 +426,30 @@ pub(crate) fn write_typed_chunk_into_index(
|
||||
let vector_deladd_obkv = KvReaderDelAdd::new(value);
|
||||
if let Some(value) = vector_deladd_obkv.get(DelAdd::Deletion) {
|
||||
// convert the vector back to a Vec<f32>
|
||||
let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
|
||||
let key = (docid, vector);
|
||||
if !vectors_set.remove(&key) {
|
||||
error!("Unable to delete the vector: {:?}", key.1);
|
||||
}
|
||||
let vector: Vec<OrderedFloat<_>> =
|
||||
pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
|
||||
docid_vectors_map.entry(docid).and_modify(|v| {
|
||||
if !v.remove(&vector) {
|
||||
error!("Unable to delete the vector: {:?}", vector);
|
||||
}
|
||||
});
|
||||
}
|
||||
if let Some(value) = vector_deladd_obkv.get(DelAdd::Addition) {
|
||||
// convert the vector back to a Vec<f32>
|
||||
let vector = pod_collect_to_vec(value).into_iter().map(OrderedFloat).collect();
|
||||
vectors_set.insert((docid, vector));
|
||||
docid_vectors_map.entry(docid).and_modify(|v| {
|
||||
v.insert(vector);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the most common vector dimension
|
||||
let expected_dimension_size = {
|
||||
let mut dims = HashMap::new();
|
||||
vectors_set.iter().for_each(|(_, v)| *dims.entry(v.len()).or_insert(0) += 1);
|
||||
docid_vectors_map
|
||||
.values()
|
||||
.flat_map(|v| v.iter())
|
||||
.for_each(|v| *dims.entry(v.len()).or_insert(0) += 1);
|
||||
dims.into_iter().max_by_key(|(_, count)| *count).map(|(len, _)| len)
|
||||
};
|
||||
|
||||
@ -400,7 +457,10 @@ pub(crate) fn write_typed_chunk_into_index(
|
||||
// prepare the vectors before inserting them in the HNSW.
|
||||
let mut points = Vec::new();
|
||||
let mut docids = Vec::new();
|
||||
for (docid, vector) in vectors_set {
|
||||
for (docid, vector) in docid_vectors_map
|
||||
.into_iter()
|
||||
.flat_map(|(docid, vectors)| std::iter::repeat(docid).zip(vectors))
|
||||
{
|
||||
if expected_dimension_size.map_or(false, |expected| expected != vector.len()) {
|
||||
return Err(UserError::InvalidVectorDimensions {
|
||||
expected: expected_dimension_size.unwrap_or(vector.len()),
|
||||
|
@ -3,7 +3,7 @@ use std::result::Result as StdResult;
|
||||
|
||||
use charabia::{Normalize, Tokenizer, TokenizerBuilder};
|
||||
use deserr::{DeserializeError, Deserr};
|
||||
use itertools::Itertools;
|
||||
use itertools::{EitherOrBoth, Itertools};
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use time::OffsetDateTime;
|
||||
|
||||
@ -15,6 +15,8 @@ use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS
|
||||
use crate::proximity::ProximityPrecision;
|
||||
use crate::update::index_documents::IndexDocumentsMethod;
|
||||
use crate::update::{IndexDocuments, UpdateIndexingStep};
|
||||
use crate::vector::settings::{EmbeddingSettings, PromptSettings};
|
||||
use crate::vector::EmbeddingConfig;
|
||||
use crate::{FieldsIdsMap, Index, OrderBy, Result};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
@ -73,6 +75,13 @@ impl<T> Setting<T> {
|
||||
otherwise => otherwise,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply(&mut self, new: Self) {
|
||||
if let Setting::NotSet = new {
|
||||
return;
|
||||
}
|
||||
*self = new;
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Serialize> Serialize for Setting<T> {
|
||||
@ -129,6 +138,7 @@ pub struct Settings<'a, 't, 'i> {
|
||||
sort_facet_values_by: Setting<HashMap<String, OrderBy>>,
|
||||
pagination_max_total_hits: Setting<usize>,
|
||||
proximity_precision: Setting<ProximityPrecision>,
|
||||
embedder_settings: Setting<BTreeMap<String, Setting<EmbeddingSettings>>>,
|
||||
}
|
||||
|
||||
impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
@ -161,6 +171,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
sort_facet_values_by: Setting::NotSet,
|
||||
pagination_max_total_hits: Setting::NotSet,
|
||||
proximity_precision: Setting::NotSet,
|
||||
embedder_settings: Setting::NotSet,
|
||||
indexer_config,
|
||||
}
|
||||
}
|
||||
@ -343,6 +354,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
self.proximity_precision = Setting::Reset;
|
||||
}
|
||||
|
||||
pub fn set_embedder_settings(&mut self, value: BTreeMap<String, Setting<EmbeddingSettings>>) {
|
||||
self.embedder_settings = Setting::Set(value);
|
||||
}
|
||||
|
||||
pub fn reset_embedder_settings(&mut self) {
|
||||
self.embedder_settings = Setting::Reset;
|
||||
}
|
||||
|
||||
fn reindex<FP, FA>(
|
||||
&mut self,
|
||||
progress_callback: &FP,
|
||||
@ -890,6 +909,60 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
Ok(changed)
|
||||
}
|
||||
|
||||
fn update_embedding_configs(&mut self) -> Result<bool> {
|
||||
let update = match std::mem::take(&mut self.embedder_settings) {
|
||||
Setting::Set(configs) => {
|
||||
let mut changed = false;
|
||||
let old_configs = self.index.embedding_configs(self.wtxn)?;
|
||||
let old_configs: BTreeMap<String, Setting<EmbeddingSettings>> =
|
||||
old_configs.into_iter().map(|(k, v)| (k, Setting::Set(v.into()))).collect();
|
||||
|
||||
let mut new_configs = BTreeMap::new();
|
||||
for joined in old_configs
|
||||
.into_iter()
|
||||
.merge_join_by(configs.into_iter(), |(left, _), (right, _)| left.cmp(right))
|
||||
{
|
||||
match joined {
|
||||
EitherOrBoth::Both((name, mut old), (_, new)) => {
|
||||
old.apply(new);
|
||||
let new = validate_prompt(&name, old)?;
|
||||
changed = true;
|
||||
new_configs.insert(name, new);
|
||||
}
|
||||
EitherOrBoth::Left((name, setting)) => {
|
||||
new_configs.insert(name, setting);
|
||||
}
|
||||
EitherOrBoth::Right((name, setting)) => {
|
||||
let setting = validate_prompt(&name, setting)?;
|
||||
changed = true;
|
||||
new_configs.insert(name, setting);
|
||||
}
|
||||
}
|
||||
}
|
||||
let new_configs: Vec<(String, EmbeddingConfig)> = new_configs
|
||||
.into_iter()
|
||||
.filter_map(|(name, setting)| match setting {
|
||||
Setting::Set(value) => Some((name, value.into())),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => Some((name, EmbeddingSettings::default().into())),
|
||||
})
|
||||
.collect();
|
||||
if new_configs.is_empty() {
|
||||
self.index.delete_embedding_configs(self.wtxn)?;
|
||||
} else {
|
||||
self.index.put_embedding_configs(self.wtxn, new_configs)?;
|
||||
}
|
||||
changed
|
||||
}
|
||||
Setting::Reset => {
|
||||
self.index.delete_embedding_configs(self.wtxn)?;
|
||||
true
|
||||
}
|
||||
Setting::NotSet => false,
|
||||
};
|
||||
Ok(update)
|
||||
}
|
||||
|
||||
pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()>
|
||||
where
|
||||
FP: Fn(UpdateIndexingStep) + Sync,
|
||||
@ -927,6 +1000,13 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
let searchable_updated = self.update_searchable()?;
|
||||
let exact_attributes_updated = self.update_exact_attributes()?;
|
||||
let proximity_precision = self.update_proximity_precision()?;
|
||||
// TODO: very rough approximation of the needs for reindexing where any change will result in
|
||||
// a full reindexing.
|
||||
// What can be done instead:
|
||||
// 1. Only change the distance on a distance change
|
||||
// 2. Only change the name -> embedder mapping on a name change
|
||||
// 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage
|
||||
let embedding_configs_updated = self.update_embedding_configs()?;
|
||||
|
||||
if stop_words_updated
|
||||
|| non_separator_tokens_updated
|
||||
@ -937,6 +1017,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
|| searchable_updated
|
||||
|| exact_attributes_updated
|
||||
|| proximity_precision
|
||||
|| embedding_configs_updated
|
||||
{
|
||||
self.reindex(&progress_callback, &should_abort, old_fields_ids_map)?;
|
||||
}
|
||||
@ -945,6 +1026,34 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_prompt(
|
||||
name: &str,
|
||||
new: Setting<EmbeddingSettings>,
|
||||
) -> Result<Setting<EmbeddingSettings>> {
|
||||
match new {
|
||||
Setting::Set(EmbeddingSettings {
|
||||
embedder_options,
|
||||
prompt:
|
||||
Setting::Set(PromptSettings { template: Setting::Set(template), strategy, fallback }),
|
||||
}) => {
|
||||
// validate
|
||||
let template = crate::prompt::Prompt::new(template, None, None)
|
||||
.map(|prompt| crate::prompt::PromptData::from(prompt).template)
|
||||
.map_err(|inner| UserError::InvalidPromptForEmbeddings(name.to_owned(), inner))?;
|
||||
|
||||
Ok(Setting::Set(EmbeddingSettings {
|
||||
embedder_options,
|
||||
prompt: Setting::Set(PromptSettings {
|
||||
template: Setting::Set(template),
|
||||
strategy,
|
||||
fallback,
|
||||
}),
|
||||
}))
|
||||
}
|
||||
new => Ok(new),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use big_s::S;
|
||||
@ -1763,6 +1872,7 @@ mod tests {
|
||||
sort_facet_values_by,
|
||||
pagination_max_total_hits,
|
||||
proximity_precision,
|
||||
embedder_settings,
|
||||
} = settings;
|
||||
assert!(matches!(searchable_fields, Setting::NotSet));
|
||||
assert!(matches!(displayed_fields, Setting::NotSet));
|
||||
@ -1785,6 +1895,7 @@ mod tests {
|
||||
assert!(matches!(sort_facet_values_by, Setting::NotSet));
|
||||
assert!(matches!(pagination_max_total_hits, Setting::NotSet));
|
||||
assert!(matches!(proximity_precision, Setting::NotSet));
|
||||
assert!(matches!(embedder_settings, Setting::NotSet));
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
229
milli/src/vector/error.rs
Normal file
229
milli/src/vector/error.rs
Normal file
@ -0,0 +1,229 @@
|
||||
use std::path::PathBuf;
|
||||
|
||||
use hf_hub::api::sync::ApiError;
|
||||
|
||||
use crate::error::FaultSource;
|
||||
use crate::vector::openai::OpenAiError;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Error while generating embeddings: {inner}")]
|
||||
pub struct Error {
|
||||
pub inner: Box<ErrorKind>,
|
||||
}
|
||||
|
||||
impl<I: Into<ErrorKind>> From<I> for Error {
|
||||
fn from(value: I) -> Self {
|
||||
Self { inner: Box::new(value.into()) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn fault(&self) -> FaultSource {
|
||||
match &*self.inner {
|
||||
ErrorKind::NewEmbedderError(inner) => inner.fault,
|
||||
ErrorKind::EmbedError(inner) => inner.fault,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ErrorKind {
|
||||
#[error(transparent)]
|
||||
NewEmbedderError(#[from] NewEmbedderError),
|
||||
#[error(transparent)]
|
||||
EmbedError(#[from] EmbedError),
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{fault}: {kind}")]
|
||||
pub struct EmbedError {
|
||||
pub kind: EmbedErrorKind,
|
||||
pub fault: FaultSource,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum EmbedErrorKind {
|
||||
#[error("could not tokenize: {0}")]
|
||||
Tokenize(Box<dyn std::error::Error + Send + Sync>),
|
||||
#[error("unexpected tensor shape: {0}")]
|
||||
TensorShape(candle_core::Error),
|
||||
#[error("unexpected tensor value: {0}")]
|
||||
TensorValue(candle_core::Error),
|
||||
#[error("could not run model: {0}")]
|
||||
ModelForward(candle_core::Error),
|
||||
#[error("could not reach OpenAI: {0}")]
|
||||
OpenAiNetwork(reqwest::Error),
|
||||
#[error("unexpected response from OpenAI: {0}")]
|
||||
OpenAiUnexpected(reqwest::Error),
|
||||
#[error("could not authenticate against OpenAI: {0}")]
|
||||
OpenAiAuth(OpenAiError),
|
||||
#[error("sent too many requests to OpenAI: {0}")]
|
||||
OpenAiTooManyRequests(OpenAiError),
|
||||
#[error("received internal error from OpenAI: {0}")]
|
||||
OpenAiInternalServerError(OpenAiError),
|
||||
#[error("sent too many tokens in a request to OpenAI: {0}")]
|
||||
OpenAiTooManyTokens(OpenAiError),
|
||||
#[error("received unhandled HTTP status code {0} from OpenAI")]
|
||||
OpenAiUnhandledStatusCode(u16),
|
||||
}
|
||||
|
||||
impl EmbedError {
|
||||
pub fn tokenize(inner: Box<dyn std::error::Error + Send + Sync>) -> Self {
|
||||
Self { kind: EmbedErrorKind::Tokenize(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn tensor_shape(inner: candle_core::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::TensorShape(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub fn tensor_value(inner: candle_core::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::TensorValue(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub fn model_forward(inner: candle_core::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn openai_network(inner: reqwest::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_internal_server_error(inner: OpenAiError) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("{fault}: {kind}")]
|
||||
pub struct NewEmbedderError {
|
||||
pub kind: NewEmbedderErrorKind,
|
||||
pub fault: FaultSource,
|
||||
}
|
||||
|
||||
impl NewEmbedderError {
|
||||
pub fn open_config(config_filename: PathBuf, inner: std::io::Error) -> NewEmbedderError {
|
||||
let open_config = OpenConfig { filename: config_filename, inner };
|
||||
|
||||
Self { kind: NewEmbedderErrorKind::OpenConfig(open_config), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn deserialize_config(
|
||||
config: String,
|
||||
config_filename: PathBuf,
|
||||
inner: serde_json::Error,
|
||||
) -> NewEmbedderError {
|
||||
let deserialize_config = DeserializeConfig { config, filename: config_filename, inner };
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn open_tokenizer(
|
||||
tokenizer_filename: PathBuf,
|
||||
inner: Box<dyn std::error::Error + Send + Sync>,
|
||||
) -> NewEmbedderError {
|
||||
let open_tokenizer = OpenTokenizer { filename: tokenizer_filename, inner };
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::OpenTokenizer(open_tokenizer),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_api_fail(inner: ApiError) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::NewApiFail(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub fn api_get(inner: ApiError) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::ApiGet(inner), fault: FaultSource::Undecided }
|
||||
}
|
||||
|
||||
pub fn pytorch_weight(inner: candle_core::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn safetensor_weight(inner: candle_core::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn load_model(inner: candle_core::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("could not open config at {filename:?}: {inner}")]
|
||||
pub struct OpenConfig {
|
||||
pub filename: PathBuf,
|
||||
pub inner: std::io::Error,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")]
|
||||
pub struct DeserializeConfig {
|
||||
pub config: String,
|
||||
pub filename: PathBuf,
|
||||
pub inner: serde_json::Error,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("could not open tokenizer at {filename}: {inner}")]
|
||||
pub struct OpenTokenizer {
|
||||
pub filename: PathBuf,
|
||||
#[source]
|
||||
pub inner: Box<dyn std::error::Error + Send + Sync>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum NewEmbedderErrorKind {
|
||||
// hf
|
||||
#[error(transparent)]
|
||||
OpenConfig(OpenConfig),
|
||||
#[error(transparent)]
|
||||
DeserializeConfig(DeserializeConfig),
|
||||
#[error(transparent)]
|
||||
OpenTokenizer(OpenTokenizer),
|
||||
#[error("could not build weights from Pytorch weights: {0}")]
|
||||
PytorchWeight(candle_core::Error),
|
||||
#[error("could not build weights from Safetensor weights: {0}")]
|
||||
SafetensorWeight(candle_core::Error),
|
||||
#[error("could not spawn HG_HUB API client: {0}")]
|
||||
NewApiFail(ApiError),
|
||||
#[error("fetching file from HG_HUB failed: {0}")]
|
||||
ApiGet(ApiError),
|
||||
#[error("loading model failed: {0}")]
|
||||
LoadModel(candle_core::Error),
|
||||
// openai
|
||||
#[error("initializing web client for sending embedding requests failed: {0}")]
|
||||
InitWebClient(reqwest::Error),
|
||||
#[error("The API key passed to Authorization error was in an invalid format: {0}")]
|
||||
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
|
||||
}
|
192
milli/src/vector/hf.rs
Normal file
192
milli/src/vector/hf.rs
Normal file
@ -0,0 +1,192 @@
|
||||
use candle_core::Tensor;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::bert::{BertModel, Config, DTYPE};
|
||||
// FIXME: currently we'll be using the hub to retrieve model, in the future we might want to embed it into Meilisearch itself
|
||||
use hf_hub::api::sync::Api;
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use tokenizers::{PaddingParams, Tokenizer};
|
||||
|
||||
pub use super::error::{EmbedError, Error, NewEmbedderError};
|
||||
use super::{Embedding, Embeddings};
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Clone,
|
||||
Copy,
|
||||
Default,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
serde::Deserialize,
|
||||
serde::Serialize,
|
||||
deserr::Deserr,
|
||||
)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub enum WeightSource {
|
||||
#[default]
|
||||
Safetensors,
|
||||
Pytorch,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub model: String,
|
||||
pub revision: Option<String>,
|
||||
pub weight_source: WeightSource,
|
||||
pub normalize_embeddings: bool,
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
//model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
|
||||
model: "BAAI/bge-base-en-v1.5".to_string(),
|
||||
//revision: Some("refs/pr/21".to_string()),
|
||||
revision: None,
|
||||
//weight_source: Default::default(),
|
||||
weight_source: WeightSource::Pytorch,
|
||||
normalize_embeddings: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EmbedderOptions {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform embedding of documents and queries
|
||||
pub struct Embedder {
|
||||
model: BertModel,
|
||||
tokenizer: Tokenizer,
|
||||
options: EmbedderOptions,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for Embedder {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("Embedder")
|
||||
.field("model", &self.options.model)
|
||||
.field("tokenizer", &self.tokenizer)
|
||||
.field("options", &self.options)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||
let device = candle_core::Device::Cpu;
|
||||
let repo = match options.revision.clone() {
|
||||
Some(revision) => Repo::with_revision(options.model.clone(), RepoType::Model, revision),
|
||||
None => Repo::model(options.model.clone()),
|
||||
};
|
||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||
let api = Api::new().map_err(NewEmbedderError::new_api_fail)?;
|
||||
let api = api.repo(repo);
|
||||
let config = api.get("config.json").map_err(NewEmbedderError::api_get)?;
|
||||
let tokenizer = api.get("tokenizer.json").map_err(NewEmbedderError::api_get)?;
|
||||
let weights = match options.weight_source {
|
||||
WeightSource::Pytorch => {
|
||||
api.get("pytorch_model.bin").map_err(NewEmbedderError::api_get)?
|
||||
}
|
||||
WeightSource::Safetensors => {
|
||||
api.get("model.safetensors").map_err(NewEmbedderError::api_get)?
|
||||
}
|
||||
};
|
||||
(config, tokenizer, weights)
|
||||
};
|
||||
|
||||
let config = std::fs::read_to_string(&config_filename)
|
||||
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
|
||||
let config: Config = serde_json::from_str(&config).map_err(|inner| {
|
||||
NewEmbedderError::deserialize_config(config, config_filename, inner)
|
||||
})?;
|
||||
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
||||
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
||||
|
||||
let vb = match options.weight_source {
|
||||
WeightSource::Pytorch => VarBuilder::from_pth(&weights_filename, DTYPE, &device)
|
||||
.map_err(NewEmbedderError::pytorch_weight)?,
|
||||
WeightSource::Safetensors => unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)
|
||||
.map_err(NewEmbedderError::safetensor_weight)?
|
||||
},
|
||||
};
|
||||
|
||||
let model = BertModel::load(vb, &config).map_err(NewEmbedderError::load_model)?;
|
||||
|
||||
if let Some(pp) = tokenizer.get_padding_mut() {
|
||||
pp.strategy = tokenizers::PaddingStrategy::BatchLongest
|
||||
} else {
|
||||
let pp = PaddingParams {
|
||||
strategy: tokenizers::PaddingStrategy::BatchLongest,
|
||||
..Default::default()
|
||||
};
|
||||
tokenizer.with_padding(Some(pp));
|
||||
}
|
||||
|
||||
Ok(Self { model, tokenizer, options })
|
||||
}
|
||||
|
||||
pub async fn embed(
|
||||
&self,
|
||||
mut texts: Vec<String>,
|
||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let tokens = match texts.len() {
|
||||
1 => vec![self
|
||||
.tokenizer
|
||||
.encode(texts.pop().unwrap(), true)
|
||||
.map_err(EmbedError::tokenize)?],
|
||||
_ => self.tokenizer.encode_batch(texts, true).map_err(EmbedError::tokenize)?,
|
||||
};
|
||||
let token_ids = tokens
|
||||
.iter()
|
||||
.map(|tokens| {
|
||||
let tokens = tokens.get_ids().to_vec();
|
||||
Tensor::new(tokens.as_slice(), &self.model.device).map_err(EmbedError::tensor_shape)
|
||||
})
|
||||
.collect::<Result<Vec<_>, EmbedError>>()?;
|
||||
|
||||
let token_ids = Tensor::stack(&token_ids, 0).map_err(EmbedError::tensor_shape)?;
|
||||
let token_type_ids = token_ids.zeros_like().map_err(EmbedError::tensor_shape)?;
|
||||
let embeddings =
|
||||
self.model.forward(&token_ids, &token_type_ids).map_err(EmbedError::model_forward)?;
|
||||
|
||||
// Apply some avg-pooling by taking the mean embedding value for all tokens (including padding)
|
||||
let (_n_sentence, n_tokens, _hidden_size) =
|
||||
embeddings.dims3().map_err(EmbedError::tensor_shape)?;
|
||||
|
||||
let embeddings = (embeddings.sum(1).map_err(EmbedError::tensor_value)? / (n_tokens as f64))
|
||||
.map_err(EmbedError::tensor_shape)?;
|
||||
|
||||
let embeddings: Tensor = if self.options.normalize_embeddings {
|
||||
normalize_l2(&embeddings).map_err(EmbedError::tensor_value)?
|
||||
} else {
|
||||
embeddings
|
||||
};
|
||||
|
||||
let embeddings: Vec<Embedding> = embeddings.to_vec2().map_err(EmbedError::tensor_shape)?;
|
||||
Ok(embeddings.into_iter().map(Embeddings::from_single_embedding).collect())
|
||||
}
|
||||
|
||||
pub async fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
std::thread::available_parallelism().map(|x| x.get()).unwrap_or(8)
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_l2(v: &Tensor) -> Result<Tensor, candle_core::Error> {
|
||||
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
||||
}
|
142
milli/src/vector/mod.rs
Normal file
142
milli/src/vector/mod.rs
Normal file
@ -0,0 +1,142 @@
|
||||
use self::error::{EmbedError, NewEmbedderError};
|
||||
use crate::prompt::PromptData;
|
||||
|
||||
pub mod error;
|
||||
pub mod hf;
|
||||
pub mod openai;
|
||||
pub mod settings;
|
||||
|
||||
pub use self::error::Error;
|
||||
|
||||
pub type Embedding = Vec<f32>;
|
||||
|
||||
pub struct Embeddings<F> {
|
||||
data: Vec<F>,
|
||||
dimension: usize,
|
||||
}
|
||||
|
||||
impl<F> Embeddings<F> {
|
||||
pub fn new(dimension: usize) -> Self {
|
||||
Self { data: Default::default(), dimension }
|
||||
}
|
||||
|
||||
pub fn from_single_embedding(embedding: Vec<F>) -> Self {
|
||||
Self { dimension: embedding.len(), data: embedding }
|
||||
}
|
||||
|
||||
pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
|
||||
let mut this = Self::new(dimension);
|
||||
this.append(data)?;
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
pub fn dimension(&self) -> usize {
|
||||
self.dimension
|
||||
}
|
||||
|
||||
pub fn into_inner(self) -> Vec<F> {
|
||||
self.data
|
||||
}
|
||||
|
||||
pub fn as_inner(&self) -> &[F] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
|
||||
self.data.as_slice().chunks_exact(self.dimension)
|
||||
}
|
||||
|
||||
pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
|
||||
if embedding.len() != self.dimension {
|
||||
return Err(embedding);
|
||||
}
|
||||
self.data.append(&mut embedding);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
|
||||
if embeddings.len() % self.dimension != 0 {
|
||||
return Err(embeddings);
|
||||
}
|
||||
self.data.append(&mut embeddings);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum Embedder {
|
||||
HuggingFace(hf::Embedder),
|
||||
OpenAi(openai::Embedder),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbeddingConfig {
|
||||
pub embedder_options: EmbedderOptions,
|
||||
pub prompt: PromptData,
|
||||
// TODO: add metrics and anything needed
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub enum EmbedderOptions {
|
||||
HuggingFace(hf::EmbedderOptions),
|
||||
OpenAi(openai::EmbedderOptions),
|
||||
}
|
||||
|
||||
impl Default for EmbedderOptions {
|
||||
fn default() -> Self {
|
||||
Self::HuggingFace(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn huggingface() -> Self {
|
||||
Self::HuggingFace(hf::EmbedderOptions::new())
|
||||
}
|
||||
|
||||
pub fn openai(api_key: String) -> Self {
|
||||
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||
Ok(match options {
|
||||
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
|
||||
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn embed(
|
||||
&self,
|
||||
texts: Vec<String>,
|
||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.embed(texts).await,
|
||||
Embedder::OpenAi(embedder) => embedder.embed(texts).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks).await,
|
||||
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks).await,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
|
||||
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
}
|
||||
}
|
||||
}
|
416
milli/src/vector/openai.rs
Normal file
416
milli/src/vector/openai.rs
Normal file
@ -0,0 +1,416 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
use reqwest::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::error::{EmbedError, NewEmbedderError};
|
||||
use super::{Embedding, Embeddings};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
client: reqwest::Client,
|
||||
tokenizer: tiktoken_rs::CoreBPE,
|
||||
options: EmbedderOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub api_key: String,
|
||||
pub embedding_model: EmbeddingModel,
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Clone,
|
||||
Copy,
|
||||
Default,
|
||||
Hash,
|
||||
PartialEq,
|
||||
Eq,
|
||||
serde::Serialize,
|
||||
serde::Deserialize,
|
||||
deserr::Deserr,
|
||||
)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub enum EmbeddingModel {
|
||||
#[default]
|
||||
TextEmbeddingAda002,
|
||||
}
|
||||
|
||||
impl EmbeddingModel {
|
||||
pub fn max_token(&self) -> usize {
|
||||
match self {
|
||||
EmbeddingModel::TextEmbeddingAda002 => 8191,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
match self {
|
||||
EmbeddingModel::TextEmbeddingAda002 => 1536,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
EmbeddingModel::TextEmbeddingAda002 => "text-embedding-ada-002",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_name(name: &'static str) -> Option<Self> {
|
||||
match name {
|
||||
"text-embedding-ada-002" => Some(EmbeddingModel::TextEmbeddingAda002),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub const OPENAI_EMBEDDINGS_URL: &str = "https://api.openai.com/v1/embeddings";
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn with_default_model(api_key: String) -> Self {
|
||||
Self { api_key, embedding_model: Default::default() }
|
||||
}
|
||||
|
||||
pub fn with_embedding_model(api_key: String, embedding_model: EmbeddingModel) -> Self {
|
||||
Self { api_key, embedding_model }
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", &options.api_key))
|
||||
.map_err(NewEmbedderError::openai_invalid_api_key_format)?,
|
||||
);
|
||||
headers.insert(
|
||||
reqwest::header::CONTENT_TYPE,
|
||||
reqwest::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
let client = reqwest::ClientBuilder::new()
|
||||
.default_headers(headers)
|
||||
.build()
|
||||
.map_err(NewEmbedderError::openai_initialize_web_client)?;
|
||||
|
||||
// looking at the code it is very unclear that this can actually fail.
|
||||
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
|
||||
|
||||
Ok(Self { options, client, tokenizer })
|
||||
}
|
||||
|
||||
pub async fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let mut tokenized = false;
|
||||
|
||||
for attempt in 0..7 {
|
||||
let result = if tokenized {
|
||||
self.try_embed_tokenized(&texts).await
|
||||
} else {
|
||||
self.try_embed(&texts).await
|
||||
};
|
||||
|
||||
let retry_duration = match result {
|
||||
Ok(embeddings) => return Ok(embeddings),
|
||||
Err(retry) => {
|
||||
log::warn!("Failed: {}", retry.error);
|
||||
tokenized |= retry.must_tokenize();
|
||||
retry.into_duration(attempt)
|
||||
}
|
||||
}?;
|
||||
log::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
|
||||
tokio::time::sleep(retry_duration).await;
|
||||
}
|
||||
|
||||
let result = if tokenized {
|
||||
self.try_embed_tokenized(&texts).await
|
||||
} else {
|
||||
self.try_embed(&texts).await
|
||||
};
|
||||
|
||||
result.map_err(Retry::into_error)
|
||||
}
|
||||
|
||||
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
|
||||
if !response.status().is_success() {
|
||||
match response.status() {
|
||||
StatusCode::UNAUTHORIZED => {
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
return Err(Retry::give_up(EmbedError::openai_auth_error(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
return Err(Retry::rate_limited(EmbedError::openai_too_many_requests(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR => {
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
StatusCode::SERVICE_UNAVAILABLE => {
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
StatusCode::BAD_REQUEST => {
|
||||
// Most probably, one text contained too many tokens
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
log::warn!("OpenAI: input was too long, retrying on tokenized version. For best performance, limit the size of your prompt.");
|
||||
|
||||
return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
code => {
|
||||
return Err(Retry::give_up(EmbedError::openai_unhandled_status_code(
|
||||
code.as_u16(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn try_embed<S: AsRef<str> + serde::Serialize>(
|
||||
&self,
|
||||
texts: &[S],
|
||||
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
||||
for text in texts {
|
||||
log::trace!("Received prompt: {}", text.as_ref())
|
||||
}
|
||||
let request = OpenAiRequest { model: self.options.embedding_model.name(), input: texts };
|
||||
let response = self
|
||||
.client
|
||||
.post(OPENAI_EMBEDDINGS_URL)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(EmbedError::openai_network)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
let response = Self::check_response(response).await?;
|
||||
|
||||
let response: OpenAiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
log::trace!("response: {:?}", response.data);
|
||||
|
||||
Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|data| Embeddings::from_single_embedding(data.embedding))
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, Retry> {
|
||||
pub const OVERLAP_SIZE: usize = 200;
|
||||
let mut all_embeddings = Vec::with_capacity(text.len());
|
||||
for text in text {
|
||||
let max_token_count = self.options.embedding_model.max_token();
|
||||
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
||||
let len = encoded.len();
|
||||
if len < max_token_count {
|
||||
all_embeddings.append(&mut self.try_embed(&[text]).await?);
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut tokens = encoded.as_slice();
|
||||
let mut embeddings_for_prompt =
|
||||
Embeddings::new(self.options.embedding_model.dimensions());
|
||||
while tokens.len() > max_token_count {
|
||||
let window = &tokens[..max_token_count];
|
||||
embeddings_for_prompt.push(self.embed_tokens(window).await?).unwrap();
|
||||
|
||||
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
|
||||
}
|
||||
|
||||
// end of text
|
||||
embeddings_for_prompt.push(self.embed_tokens(tokens).await?).unwrap();
|
||||
|
||||
all_embeddings.push(embeddings_for_prompt);
|
||||
}
|
||||
Ok(all_embeddings)
|
||||
}
|
||||
|
||||
async fn embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
|
||||
for attempt in 0..9 {
|
||||
let duration = match self.try_embed_tokens(tokens).await {
|
||||
Ok(embedding) => return Ok(embedding),
|
||||
Err(retry) => retry.into_duration(attempt),
|
||||
}
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
tokio::time::sleep(duration).await;
|
||||
}
|
||||
|
||||
self.try_embed_tokens(tokens).await.map_err(|retry| Retry::give_up(retry.into_error()))
|
||||
}
|
||||
|
||||
async fn try_embed_tokens(&self, tokens: &[usize]) -> Result<Embedding, Retry> {
|
||||
let request =
|
||||
OpenAiTokensRequest { model: self.options.embedding_model.name(), input: tokens };
|
||||
let response = self
|
||||
.client
|
||||
.post(OPENAI_EMBEDDINGS_URL)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(EmbedError::openai_network)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
let response = Self::check_response(response).await?;
|
||||
|
||||
let mut response: OpenAiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
|
||||
}
|
||||
|
||||
pub async fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
futures::future::try_join_all(text_chunks.into_iter().map(|prompts| self.embed(prompts)))
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
10
|
||||
}
|
||||
}
|
||||
|
||||
// retrying in case of failure
|
||||
|
||||
struct Retry {
|
||||
error: EmbedError,
|
||||
strategy: RetryStrategy,
|
||||
}
|
||||
|
||||
enum RetryStrategy {
|
||||
GiveUp,
|
||||
Retry,
|
||||
RetryTokenized,
|
||||
RetryAfterRateLimit,
|
||||
}
|
||||
|
||||
impl Retry {
|
||||
fn give_up(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::GiveUp }
|
||||
}
|
||||
|
||||
fn retry_later(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::Retry }
|
||||
}
|
||||
|
||||
fn retry_tokenized(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::RetryTokenized }
|
||||
}
|
||||
|
||||
fn rate_limited(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
|
||||
}
|
||||
|
||||
fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
|
||||
match self.strategy {
|
||||
RetryStrategy::GiveUp => Err(self.error),
|
||||
RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
|
||||
RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)),
|
||||
RetryStrategy::RetryAfterRateLimit => {
|
||||
Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn must_tokenize(&self) -> bool {
|
||||
matches!(self.strategy, RetryStrategy::RetryTokenized)
|
||||
}
|
||||
|
||||
fn into_error(self) -> EmbedError {
|
||||
self.error
|
||||
}
|
||||
}
|
||||
|
||||
// openai api structs
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
|
||||
model: &'a str,
|
||||
input: &'a [S],
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct OpenAiTokensRequest<'a> {
|
||||
model: &'a str,
|
||||
input: &'a [usize],
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiResponse {
|
||||
data: Vec<OpenAiEmbedding>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiErrorResponse {
|
||||
error: OpenAiError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OpenAiError {
|
||||
message: String,
|
||||
// type: String,
|
||||
code: Option<String>,
|
||||
}
|
||||
|
||||
impl Display for OpenAiError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.code {
|
||||
Some(code) => write!(f, "{} ({})", self.message, code),
|
||||
None => write!(f, "{}", self.message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiEmbedding {
|
||||
embedding: Embedding,
|
||||
// object: String,
|
||||
// index: usize,
|
||||
}
|
308
milli/src/vector/settings.rs
Normal file
308
milli/src/vector/settings.rs
Normal file
@ -0,0 +1,308 @@
|
||||
use deserr::Deserr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::prompt::{PromptData, PromptFallbackStrategy};
|
||||
use crate::update::Setting;
|
||||
use crate::vector::hf::WeightSource;
|
||||
use crate::vector::EmbeddingConfig;
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub struct EmbeddingSettings {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set", rename = "source")]
|
||||
#[deserr(default, rename = "source")]
|
||||
pub embedder_options: Setting<EmbedderSettings>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub prompt: Setting<PromptSettings>,
|
||||
}
|
||||
|
||||
impl EmbeddingSettings {
|
||||
pub fn apply(&mut self, new: Self) {
|
||||
let EmbeddingSettings { embedder_options, prompt } = new;
|
||||
self.embedder_options.apply(embedder_options);
|
||||
self.prompt.apply(prompt);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
fn from(value: EmbeddingConfig) -> Self {
|
||||
Self {
|
||||
embedder_options: Setting::Set(value.embedder_options.into()),
|
||||
prompt: Setting::Set(value.prompt.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EmbeddingSettings> for EmbeddingConfig {
|
||||
fn from(value: EmbeddingSettings) -> Self {
|
||||
let mut this = Self::default();
|
||||
let EmbeddingSettings { embedder_options, prompt } = value;
|
||||
if let Some(embedder_options) = embedder_options.set() {
|
||||
this.embedder_options = embedder_options.into();
|
||||
}
|
||||
if let Some(prompt) = prompt.set() {
|
||||
this.prompt = prompt.into();
|
||||
}
|
||||
this
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub struct PromptSettings {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub template: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub strategy: Setting<PromptFallbackStrategy>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub fallback: Setting<String>,
|
||||
}
|
||||
|
||||
impl PromptSettings {
|
||||
pub fn apply(&mut self, new: Self) {
|
||||
let PromptSettings { template, strategy, fallback } = new;
|
||||
self.template.apply(template);
|
||||
self.strategy.apply(strategy);
|
||||
self.fallback.apply(fallback);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PromptData> for PromptSettings {
|
||||
fn from(value: PromptData) -> Self {
|
||||
Self {
|
||||
template: Setting::Set(value.template),
|
||||
strategy: Setting::Set(value.strategy),
|
||||
fallback: Setting::Set(value.fallback),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<PromptSettings> for PromptData {
|
||||
fn from(value: PromptSettings) -> Self {
|
||||
let mut this = PromptData::default();
|
||||
let PromptSettings { template, strategy, fallback } = value;
|
||||
if let Some(template) = template.set() {
|
||||
this.template = template;
|
||||
}
|
||||
if let Some(strategy) = strategy.set() {
|
||||
this.strategy = strategy;
|
||||
}
|
||||
if let Some(fallback) = fallback.set() {
|
||||
this.fallback = fallback;
|
||||
}
|
||||
this
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub enum EmbedderSettings {
|
||||
HuggingFace(Setting<HfEmbedderSettings>),
|
||||
OpenAi(Setting<OpenAiEmbedderSettings>),
|
||||
}
|
||||
|
||||
impl<E> Deserr<E> for EmbedderSettings
|
||||
where
|
||||
E: deserr::DeserializeError,
|
||||
{
|
||||
fn deserialize_from_value<V: deserr::IntoValue>(
|
||||
value: deserr::Value<V>,
|
||||
location: deserr::ValuePointerRef,
|
||||
) -> Result<Self, E> {
|
||||
match value {
|
||||
deserr::Value::Map(map) => {
|
||||
if deserr::Map::len(&map) != 1 {
|
||||
return Err(deserr::take_cf_content(E::error::<V>(
|
||||
None,
|
||||
deserr::ErrorKind::Unexpected {
|
||||
msg: format!(
|
||||
"Expected a single field, got {} fields",
|
||||
deserr::Map::len(&map)
|
||||
),
|
||||
},
|
||||
location,
|
||||
)));
|
||||
}
|
||||
let mut it = deserr::Map::into_iter(map);
|
||||
let (k, v) = it.next().unwrap();
|
||||
|
||||
match k.as_str() {
|
||||
"huggingFace" => Ok(EmbedderSettings::HuggingFace(Setting::Set(
|
||||
HfEmbedderSettings::deserialize_from_value(
|
||||
v.into_value(),
|
||||
location.push_key(&k),
|
||||
)?,
|
||||
))),
|
||||
"openAi" => Ok(EmbedderSettings::OpenAi(Setting::Set(
|
||||
OpenAiEmbedderSettings::deserialize_from_value(
|
||||
v.into_value(),
|
||||
location.push_key(&k),
|
||||
)?,
|
||||
))),
|
||||
other => Err(deserr::take_cf_content(E::error::<V>(
|
||||
None,
|
||||
deserr::ErrorKind::UnknownKey {
|
||||
key: other,
|
||||
accepted: &["huggingFace", "openAi"],
|
||||
},
|
||||
location,
|
||||
))),
|
||||
}
|
||||
}
|
||||
_ => Err(deserr::take_cf_content(E::error::<V>(
|
||||
None,
|
||||
deserr::ErrorKind::IncorrectValueKind {
|
||||
actual: value,
|
||||
accepted: &[deserr::ValueKind::Map],
|
||||
},
|
||||
location,
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for EmbedderSettings {
|
||||
fn default() -> Self {
|
||||
Self::HuggingFace(Default::default())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::vector::EmbedderOptions> for EmbedderSettings {
|
||||
fn from(value: crate::vector::EmbedderOptions) -> Self {
|
||||
match value {
|
||||
crate::vector::EmbedderOptions::HuggingFace(hf) => {
|
||||
Self::HuggingFace(Setting::Set(hf.into()))
|
||||
}
|
||||
crate::vector::EmbedderOptions::OpenAi(openai) => {
|
||||
Self::OpenAi(Setting::Set(openai.into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<EmbedderSettings> for crate::vector::EmbedderOptions {
|
||||
fn from(value: EmbedderSettings) -> Self {
|
||||
match value {
|
||||
EmbedderSettings::HuggingFace(Setting::Set(hf)) => Self::HuggingFace(hf.into()),
|
||||
EmbedderSettings::HuggingFace(_setting) => Self::HuggingFace(Default::default()),
|
||||
EmbedderSettings::OpenAi(Setting::Set(ai)) => Self::OpenAi(ai.into()),
|
||||
EmbedderSettings::OpenAi(_setting) => Self::OpenAi(
|
||||
crate::vector::openai::EmbedderOptions::with_default_model(infer_api_key()),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub struct HfEmbedderSettings {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub model: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub revision: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub weight_source: Setting<WeightSource>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub normalize_embeddings: Setting<bool>,
|
||||
}
|
||||
|
||||
impl HfEmbedderSettings {
|
||||
pub fn apply(&mut self, new: Self) {
|
||||
let HfEmbedderSettings {
|
||||
model,
|
||||
revision,
|
||||
weight_source,
|
||||
normalize_embeddings: normalize_embedding,
|
||||
} = new;
|
||||
self.model.apply(model);
|
||||
self.revision.apply(revision);
|
||||
self.weight_source.apply(weight_source);
|
||||
self.normalize_embeddings.apply(normalize_embedding);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::vector::hf::EmbedderOptions> for HfEmbedderSettings {
|
||||
fn from(value: crate::vector::hf::EmbedderOptions) -> Self {
|
||||
Self {
|
||||
model: Setting::Set(value.model),
|
||||
revision: value.revision.map(Setting::Set).unwrap_or(Setting::NotSet),
|
||||
weight_source: Setting::Set(value.weight_source),
|
||||
normalize_embeddings: Setting::Set(value.normalize_embeddings),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<HfEmbedderSettings> for crate::vector::hf::EmbedderOptions {
|
||||
fn from(value: HfEmbedderSettings) -> Self {
|
||||
let HfEmbedderSettings { model, revision, weight_source, normalize_embeddings } = value;
|
||||
let mut this = Self::default();
|
||||
if let Some(model) = model.set() {
|
||||
this.model = model;
|
||||
}
|
||||
if let Some(revision) = revision.set() {
|
||||
this.revision = Some(revision);
|
||||
}
|
||||
if let Some(weight_source) = weight_source.set() {
|
||||
this.weight_source = weight_source;
|
||||
}
|
||||
if let Some(normalize_embeddings) = normalize_embeddings.set() {
|
||||
this.normalize_embeddings = normalize_embeddings;
|
||||
}
|
||||
this
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq, Deserr)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub struct OpenAiEmbedderSettings {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub api_key: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub embedding_model: Setting<crate::vector::openai::EmbeddingModel>,
|
||||
}
|
||||
|
||||
impl OpenAiEmbedderSettings {
|
||||
pub fn apply(&mut self, new: Self) {
|
||||
let Self { api_key, embedding_model: embedding_mode } = new;
|
||||
self.api_key.apply(api_key);
|
||||
self.embedding_model.apply(embedding_mode);
|
||||
}
|
||||
}
|
||||
|
||||
impl From<crate::vector::openai::EmbedderOptions> for OpenAiEmbedderSettings {
|
||||
fn from(value: crate::vector::openai::EmbedderOptions) -> Self {
|
||||
Self {
|
||||
api_key: Setting::Set(value.api_key),
|
||||
embedding_model: Setting::Set(value.embedding_model),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<OpenAiEmbedderSettings> for crate::vector::openai::EmbedderOptions {
|
||||
fn from(value: OpenAiEmbedderSettings) -> Self {
|
||||
let OpenAiEmbedderSettings { api_key, embedding_model } = value;
|
||||
Self {
|
||||
api_key: api_key.set().unwrap_or_else(infer_api_key),
|
||||
embedding_model: embedding_model.set().unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_api_key() -> String {
|
||||
/// FIXME: get key from instance options?
|
||||
std::env::var("MEILI_OPENAI_API_KEY").unwrap_or_default()
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user