mirror of
https://github.com/meilisearch/MeiliSearch
synced 2024-12-26 14:40:05 +01:00
Merge branch 'main' into tmp-release-v1.7.4
This commit is contained in:
commit
796213af9a
8
.github/ISSUE_TEMPLATE/sprint_issue.md
vendored
8
.github/ISSUE_TEMPLATE/sprint_issue.md
vendored
@ -2,7 +2,7 @@
|
||||
name: New sprint issue
|
||||
about: ⚠️ Should only be used by the engine team ⚠️
|
||||
title: ''
|
||||
labels: ''
|
||||
labels: 'missing usage in PRD, impacts docs'
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
@ -21,11 +21,7 @@ Related spec: WIP
|
||||
|
||||
## TODO
|
||||
|
||||
<!---Feel free to adapt this list with more technical/product steps-->
|
||||
|
||||
- [ ] Release a prototype
|
||||
- [ ] If prototype validated, merge changes into `main`
|
||||
- [ ] Update the spec
|
||||
<!---If necessary, create a list with technical/product steps-->
|
||||
|
||||
### Reminders when modifying the Setting API
|
||||
|
||||
|
2
.github/workflows/bench-pr.yml
vendored
2
.github/workflows/bench-pr.yml
vendored
@ -43,4 +43,4 @@ jobs:
|
||||
|
||||
- name: Run benchmarks on PR ${{ github.event.issue.id }}
|
||||
run: |
|
||||
cargo xtask bench --api-key "${{ secrets.BENCHMARK_API_KEY }}" --dashboard-url "${{ vars.BENCHMARK_DASHBOARD_URL }}" --reason "[Comment](${{ github.event.comment.url }}) on [#${{github.event.issue.id}}](${{ github.event.issue.url }})" -- ${{ steps.command.outputs.command-arguments }}
|
||||
cargo xtask bench --api-key "${{ secrets.BENCHMARK_API_KEY }}" --dashboard-url "${{ vars.BENCHMARK_DASHBOARD_URL }}" --reason "[Comment](${{ github.event.comment.html_url }}) on [#${{ github.event.issue.number }}](${{ github.event.issue.html_url }})" -- ${{ steps.command.outputs.command-arguments }}
|
38
.github/workflows/milestone-workflow.yml
vendored
38
.github/workflows/milestone-workflow.yml
vendored
@ -110,6 +110,44 @@ jobs:
|
||||
--milestone $MILESTONE_VERSION \
|
||||
--assignee curquiza
|
||||
|
||||
create-update-version-issue:
|
||||
needs: get-release-version
|
||||
# Create the update-version issue even if the release is a patch release
|
||||
if: github.event.action == 'created'
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
ISSUE_TEMPLATE: issue-template.md
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Download the issue template
|
||||
run: curl -s https://raw.githubusercontent.com/meilisearch/engine-team/main/issue-templates/update-version-issue.md > $ISSUE_TEMPLATE
|
||||
- name: Create the issue
|
||||
run: |
|
||||
gh issue create \
|
||||
--title "Update version in Cargo.toml for $MILESTONE_VERSION" \
|
||||
--label 'maintenance' \
|
||||
--body-file $ISSUE_TEMPLATE \
|
||||
--milestone $MILESTONE_VERSION
|
||||
|
||||
create-update-openapi-issue:
|
||||
needs: get-release-version
|
||||
# Create the openAPI issue if the release is not only a patch release
|
||||
if: github.event.action == 'created' && needs.get-release-version.outputs.is-patch == 'false'
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
ISSUE_TEMPLATE: issue-template.md
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Download the issue template
|
||||
run: curl -s https://raw.githubusercontent.com/meilisearch/engine-team/main/issue-templates/update-openapi-issue.md > $ISSUE_TEMPLATE
|
||||
- name: Create the issue
|
||||
run: |
|
||||
gh issue create \
|
||||
--title "Update Open API file for $MILESTONE_VERSION" \
|
||||
--label 'maintenance' \
|
||||
--body-file $ISSUE_TEMPLATE \
|
||||
--milestone $MILESTONE_VERSION
|
||||
|
||||
# ----------------
|
||||
# MILESTONE CLOSED
|
||||
# ----------------
|
||||
|
362
BENCHMARKS.md
Normal file
362
BENCHMARKS.md
Normal file
@ -0,0 +1,362 @@
|
||||
# Benchmarks
|
||||
|
||||
Currently this repository hosts two kinds of benchmarks:
|
||||
|
||||
1. The older "milli benchmarks", that use [criterion](https://github.com/bheisler/criterion.rs) and live in the "benchmarks" directory.
|
||||
2. The newer "bench" that are workload-based and so split between the [`workloads`](./workloads/) directory and the [`xtask::bench`](./xtask/src/bench/) module.
|
||||
|
||||
This document describes the newer "bench" benchmarks. For more details on the "milli benchmarks", see [benchmarks/README.md](./benchmarks/README.md).
|
||||
|
||||
## Design philosophy for the benchmarks
|
||||
|
||||
The newer "bench" benchmarks are **integration** benchmarks, in the sense that they spawn an actual Meilisearch server and measure its performance end-to-end, including HTTP request overhead.
|
||||
|
||||
Since this is prone to fluctuating, the benchmarks regain a bit of precision by measuring the runtime of the individual spans using the [logging machinery](./CONTRIBUTING.md#logging) of Meilisearch.
|
||||
|
||||
A span roughly translates to a function call. The benchmark runner collects all the spans by name using the [logs route](https://github.com/orgs/meilisearch/discussions/721) and sums their runtime. The processed results are then sent to the [benchmark dashboard](https://bench.meilisearch.dev), which is in charge of storing and presenting the data.
|
||||
|
||||
## Running the benchmarks
|
||||
|
||||
Benchmarks can run locally or in CI.
|
||||
|
||||
### Locally
|
||||
|
||||
#### With a local benchmark dashboard
|
||||
|
||||
The benchmarks dashboard lives in its [own repository](https://github.com/meilisearch/benchboard). We provide binaries for Ubuntu/Debian, but you can build from source for other platforms (MacOS should work as it was developed under that platform).
|
||||
|
||||
Run the `benchboard` binary to create a fresh database of results. By default it will serve the results and the API to gather results on `http://localhost:9001`.
|
||||
|
||||
From the Meilisearch repository, you can then run benchmarks with:
|
||||
|
||||
```sh
|
||||
cargo xtask bench -- workloads/my_workload_1.json ..
|
||||
```
|
||||
|
||||
This command will build and run Meilisearch locally on port 7700, so make sure that this port is available.
|
||||
To run benchmarks on a different commit, just use the usual git command to get back to the desired commit.
|
||||
|
||||
#### Without a local benchmark dashboard
|
||||
|
||||
To work with the raw results, you can also skip using a local benchmark dashboard.
|
||||
|
||||
Run:
|
||||
|
||||
```sh
|
||||
cargo xtask bench --no-dashboard -- workloads/my_workload_1.json workloads/my_workload_2.json ..
|
||||
```
|
||||
|
||||
For processing the results, look at [Looking at benchmark results/Without dashboard](#without-dashboard).
|
||||
|
||||
### In CI
|
||||
|
||||
We have dedicated runners to run workloads on CI. Currently, there are three ways of running the CI:
|
||||
|
||||
1. Automatically, on every push to `main`.
|
||||
2. Manually, by clicking the [`Run workflow`](https://github.com/meilisearch/meilisearch/actions/workflows/bench-manual.yml) button and specifying the target reference (tag, commit or branch) as well as one or multiple workloads to run. The workloads must exist in the Meilisearch repository (conventionally, in the [`workloads`](./workloads/) directory) on the target reference. Globbing (e.g., `workloads/*.json`) works.
|
||||
3. Manually on a PR, by posting a comment containing a `/bench` command, followed by one or multiple workloads to run. Globbing works. The workloads must exist in the Meilisearch repository in the branch of the PR.
|
||||
```
|
||||
/bench workloads/movies*.json /hackernews_1M.json
|
||||
```
|
||||
|
||||
## Looking at benchmark results
|
||||
|
||||
### On the dashboard
|
||||
|
||||
Results are available on the global dashboard used by CI at <https://bench.meilisearch.dev> or on your [local dashboard](#with-a-local-benchmark-dashboard).
|
||||
|
||||
The dashboard homepage presents three sections:
|
||||
|
||||
1. The latest invocations (a call to `cargo xtask bench`, either local or by CI) with their reason (generally set to some helpful link in CI) and their status.
|
||||
2. The latest workloads ran on `main`.
|
||||
3. The latest workloads ran on other references.
|
||||
|
||||
By default, the workload shows the total runtime delta with the latest applicable commit on `main`. The latest applicable commit is the latest commit for workload invocations that do not originate on `main`, and the latest previous commit for workload invocations that originate on `main`.
|
||||
|
||||
You can explicitly request a detailed comparison by span with the `main` branch, the branch or origin, or any previous commit, by clicking the links at the bottom of the workload invocation.
|
||||
|
||||
In the detailed comparison view, the spans are sorted by improvements, regressions, stable (no statistically significant change) and unstable (the span runtime is comparable to its standard deviation).
|
||||
|
||||
You can click on the name of any span to get a box plot comparing the target commit with multiple commits of the selected branch.
|
||||
|
||||
### Without dashboard
|
||||
|
||||
After the workloads are done running, the reports will live in the Meilisearch repository, in the `bench/reports` directory (by default).
|
||||
|
||||
You can then convert these reports into other formats.
|
||||
|
||||
- To [Firefox profiler](https://profiler.firefox.com) format. Run:
|
||||
```sh
|
||||
cd bench/reports
|
||||
cargo run --release --bin trace-to-firefox -- my_workload_1-0-trace.json
|
||||
```
|
||||
You can then upload the resulting `firefox-my_workload_1-0-trace.json` file to the online profiler.
|
||||
|
||||
|
||||
## Designing benchmark workloads
|
||||
|
||||
Benchmark workloads conventionally live in the `workloads` directory of the Meilisearch repository.
|
||||
|
||||
They are JSON files with the following structure (comments are not actually supported, to make your own, remove them or copy some existing workload file):
|
||||
|
||||
```jsonc
|
||||
{
|
||||
// Name of the workload. Must be unique to the workload, as it will be used to group results on the dashboard.
|
||||
"name": "hackernews.ndjson_1M,no-threads",
|
||||
// Number of consecutive runs of the commands that should be performed.
|
||||
// Each run uses a fresh instance of Meilisearch and a fresh database.
|
||||
// Each run produces its own report file.
|
||||
"run_count": 3,
|
||||
// List of arguments to add to the Meilisearch command line.
|
||||
"extra_cli_args": ["--max-indexing-threads=1"],
|
||||
// List of named assets that can be used in the commands.
|
||||
"assets": {
|
||||
// name of the asset.
|
||||
// Must be unique at the workload level.
|
||||
// For better results, the same asset (same sha256) should have the same name accross workloads.
|
||||
// Having multiple assets with the same name and distinct hashes is supported accross workloads,
|
||||
// but will lead to superfluous downloads.
|
||||
//
|
||||
// Assets are stored in the `bench/assets/` directory by default.
|
||||
"hackernews-100_000.ndjson": {
|
||||
// If the assets exists in the local filesystem (Meilisearch repository or for your local workloads)
|
||||
// Its file path can be specified here.
|
||||
// `null` if the asset should be downloaded from a remote location.
|
||||
"local_location": null,
|
||||
// URL of the remote location where the asset can be downloaded.
|
||||
// Use the `--assets-key` of the runner to pass an API key in the `Authorization: Bearer` header of the download requests.
|
||||
// `null` if the asset should be imported from a local location.
|
||||
// if both local and remote locations are specified, then the local one is tried first, then the remote one
|
||||
// if the file is locally missing or its hash differs.
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/hackernews/hackernews-100_000.ndjson",
|
||||
// SHA256 of the asset.
|
||||
// Optional, the `sha256` of the asset will be displayed during a run of the workload if it is missing.
|
||||
// If present, the hash of the asset in the `bench/assets/` directory will be compared against this hash before
|
||||
// running the workload. If the hashes differ, the asset will be downloaded anew.
|
||||
"sha256": "60ecd23485d560edbd90d9ca31f0e6dba1455422f2a44e402600fbb5f7f1b213",
|
||||
// Optional, one of "Auto", "Json", "NdJson" or "Raw".
|
||||
// If missing, assumed to be "Auto".
|
||||
// If "Auto", the format will be determined from the extension in the asset name.
|
||||
"format": "NdJson"
|
||||
},
|
||||
"hackernews-200_000.ndjson": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/hackernews/hackernews-200_000.ndjson",
|
||||
"sha256": "785b0271fdb47cba574fab617d5d332276b835c05dd86e4a95251cf7892a1685"
|
||||
},
|
||||
"hackernews-300_000.ndjson": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/hackernews/hackernews-300_000.ndjson",
|
||||
"sha256": "de73c7154652eddfaf69cdc3b2f824d5c452f095f40a20a1c97bb1b5c4d80ab2"
|
||||
},
|
||||
"hackernews-400_000.ndjson": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/hackernews/hackernews-400_000.ndjson",
|
||||
"sha256": "c1b00a24689110f366447e434c201c086d6f456d54ed1c4995894102794d8fe7"
|
||||
},
|
||||
"hackernews-500_000.ndjson": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/hackernews/hackernews-500_000.ndjson",
|
||||
"sha256": "ae98f9dbef8193d750e3e2dbb6a91648941a1edca5f6e82c143e7996f4840083"
|
||||
},
|
||||
"hackernews-600_000.ndjson": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/hackernews/hackernews-600_000.ndjson",
|
||||
"sha256": "b495fdc72c4a944801f786400f22076ab99186bee9699f67cbab2f21f5b74dbe"
|
||||
},
|
||||
"hackernews-700_000.ndjson": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/hackernews/hackernews-700_000.ndjson",
|
||||
"sha256": "4b2c63974f3dabaa4954e3d4598b48324d03c522321ac05b0d583f36cb78a28b"
|
||||
},
|
||||
"hackernews-800_000.ndjson": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/hackernews/hackernews-800_000.ndjson",
|
||||
"sha256": "cb7b6afe0e6caa1be111be256821bc63b0771b2a0e1fad95af7aaeeffd7ba546"
|
||||
},
|
||||
"hackernews-900_000.ndjson": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/hackernews/hackernews-900_000.ndjson",
|
||||
"sha256": "e1154ddcd398f1c867758a93db5bcb21a07b9e55530c188a2917fdef332d3ba9"
|
||||
},
|
||||
"hackernews-1_000_000.ndjson": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/hackernews/hackernews-1_000_000.ndjson",
|
||||
"sha256": "27e25efd0b68b159b8b21350d9af76938710cb29ce0393fa71b41c4f3c630ffe"
|
||||
}
|
||||
},
|
||||
// Core of the workload.
|
||||
// A list of commands to run sequentially.
|
||||
// A command is a request to the Meilisearch instance that is executed while the profiling runs.
|
||||
"commands": [
|
||||
{
|
||||
// Meilisearch route to call. `http://localhost:7700/` will be prepended.
|
||||
"route": "indexes/movies/settings",
|
||||
// HTTP method to call.
|
||||
"method": "PATCH",
|
||||
// If applicable, body of the request.
|
||||
// Optional, if missing, the body will be empty.
|
||||
"body": {
|
||||
// One of "empty", "inline" or "asset".
|
||||
// If using "empty", you can skip the entire "body" key.
|
||||
"inline": {
|
||||
// when "inline" is used, the body is the JSON object that is the value of the `"inline"` key.
|
||||
"displayedAttributes": [
|
||||
"title",
|
||||
"by",
|
||||
"score",
|
||||
"time"
|
||||
],
|
||||
"searchableAttributes": [
|
||||
"title"
|
||||
],
|
||||
"filterableAttributes": [
|
||||
"by"
|
||||
],
|
||||
"sortableAttributes": [
|
||||
"score",
|
||||
"time"
|
||||
]
|
||||
}
|
||||
},
|
||||
// Whether to wait before running the next request.
|
||||
// One of:
|
||||
// - DontWait: run the next command without waiting the response to this one.
|
||||
// - WaitForResponse: run the next command as soon as the response from the server is received.
|
||||
// - WaitForTask: run the next command once **all** the Meilisearch tasks created up to now have finished processing.
|
||||
"synchronous": "DontWait"
|
||||
},
|
||||
{
|
||||
"route": "indexes/movies/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
// When using "asset", use the name of an asset as value to use the content of that asset as body.
|
||||
// the content type is derived of the format of the asset:
|
||||
// "NdJson" => "application/x-ndjson"
|
||||
// "Json" => "application/json"
|
||||
// "Raw" => "application/octet-stream"
|
||||
// See [AssetFormat::to_content_type](https://github.com/meilisearch/meilisearch/blob/7b670a4afadb132ac4a01b6403108700501a391d/xtask/src/bench/assets.rs#L30)
|
||||
// for details and up-to-date list.
|
||||
"asset": "hackernews-100_000.ndjson"
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/movies/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "hackernews-200_000.ndjson"
|
||||
},
|
||||
"synchronous": "WaitForResponse"
|
||||
},
|
||||
{
|
||||
"route": "indexes/movies/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "hackernews-300_000.ndjson"
|
||||
},
|
||||
"synchronous": "WaitForResponse"
|
||||
},
|
||||
{
|
||||
"route": "indexes/movies/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "hackernews-400_000.ndjson"
|
||||
},
|
||||
"synchronous": "WaitForResponse"
|
||||
},
|
||||
{
|
||||
"route": "indexes/movies/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "hackernews-500_000.ndjson"
|
||||
},
|
||||
"synchronous": "WaitForResponse"
|
||||
},
|
||||
{
|
||||
"route": "indexes/movies/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "hackernews-600_000.ndjson"
|
||||
},
|
||||
"synchronous": "WaitForResponse"
|
||||
},
|
||||
{
|
||||
"route": "indexes/movies/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "hackernews-700_000.ndjson"
|
||||
},
|
||||
"synchronous": "WaitForResponse"
|
||||
},
|
||||
{
|
||||
"route": "indexes/movies/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "hackernews-800_000.ndjson"
|
||||
},
|
||||
"synchronous": "WaitForResponse"
|
||||
},
|
||||
{
|
||||
"route": "indexes/movies/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "hackernews-900_000.ndjson"
|
||||
},
|
||||
"synchronous": "WaitForResponse"
|
||||
},
|
||||
{
|
||||
"route": "indexes/movies/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "hackernews-1_000_000.ndjson"
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Adding new assets
|
||||
|
||||
Assets reside in our DigitalOcean S3 space. Assuming you have team access to the DigitalOcean S3 space:
|
||||
|
||||
1. go to <https://cloud.digitalocean.com/spaces/milli-benchmarks?i=d1c552&path=bench%2Fdatasets%2F>
|
||||
2. upload your dataset:
|
||||
1. if your dataset is a single file, upload that single file using the "upload" button,
|
||||
2. otherwise, create a folder using the "create folder" button, then inside that folder upload your individual files.
|
||||
|
||||
## Upgrading `https://bench.meilisearch.dev`
|
||||
|
||||
The URL of the server is in our password manager (look for "benchboard").
|
||||
|
||||
1. Make the needed modifications on the [benchboard repository](https://github.com/meilisearch/benchboard) and merge them to main.
|
||||
2. Publish a new release to produce the Ubuntu/Debian binary.
|
||||
3. Download the binary locally, send it to the server:
|
||||
```
|
||||
scp -6 ~/Downloads/benchboard root@\[<ipv6-address>\]:/bench/new-benchboard
|
||||
```
|
||||
Note that the ipv6 must be between escaped square brackets for SCP.
|
||||
4. SSH to the server:
|
||||
```
|
||||
ssh root@<ipv6-address>
|
||||
```
|
||||
Note the ipv6 must **NOT** be between escaped square brackets for SSH 🥲
|
||||
5. On the server, set the correct permissions for the new binary:
|
||||
```
|
||||
chown bench:bench /bench/new-benchboard
|
||||
chmod 700 /bench/new-benchboard
|
||||
```
|
||||
6. On the server, move the new binary to the location of the running binary (if unsure, start by making a backup of the running binary):
|
||||
```
|
||||
mv /bench/{new-,}benchboard
|
||||
```
|
||||
7. Restart the benchboard service.
|
||||
```
|
||||
systemctl restart benchboard
|
||||
```
|
||||
8. Check that the service runs correctly.
|
||||
```
|
||||
systemctl status benchboard
|
||||
```
|
||||
9. Check the availability of the service by going to <https://bench.meilisearch.dev> on your browser.
|
@ -4,7 +4,7 @@ First, thank you for contributing to Meilisearch! The goal of this document is t
|
||||
|
||||
Remember that there are many ways to contribute other than writing code: writing [tutorials or blog posts](https://github.com/meilisearch/awesome-meilisearch), improving [the documentation](https://github.com/meilisearch/documentation), submitting [bug reports](https://github.com/meilisearch/meilisearch/issues/new?assignees=&labels=&template=bug_report.md&title=) and [feature requests](https://github.com/meilisearch/product/discussions/categories/feedback-feature-proposal)...
|
||||
|
||||
The code in this repository is only concerned with managing multiple indexes, handling the update store, and exposing an HTTP API. Search and indexation are the domain of our core engine, [`milli`](https://github.com/meilisearch/milli), while tokenization is handled by [our `charabia` library](https://github.com/meilisearch/charabia/).
|
||||
Meilisearch can manage multiple indexes, handle the update store, and expose an HTTP API. Search and indexation are the domain of our core engine, [`milli`](https://github.com/meilisearch/meilisearch/tree/main/milli), while tokenization is handled by [our `charabia` library](https://github.com/meilisearch/charabia/).
|
||||
|
||||
If Meilisearch does not offer optimized support for your language, please consider contributing to `charabia` by following the [CONTRIBUTING.md file](https://github.com/meilisearch/charabia/blob/main/CONTRIBUTING.md) and integrating your intended normalizer/segmenter.
|
||||
|
||||
@ -81,6 +81,30 @@ Meilisearch follows the [cargo xtask](https://github.com/matklad/cargo-xtask) wo
|
||||
|
||||
Run `cargo xtask --help` from the root of the repository to find out what is available.
|
||||
|
||||
### Logging
|
||||
|
||||
Meilisearch uses [`tracing`](https://lib.rs/crates/tracing) for logging purposes. Tracing logs are structured and can be displayed as JSON to the end user, so prefer passing arguments as fields rather than interpolating them in the message.
|
||||
|
||||
Refer to the [documentation](https://docs.rs/tracing/0.1.40/tracing/index.html#using-the-macros) for the syntax of the spans and events.
|
||||
|
||||
Logging spans are used for 3 distinct purposes:
|
||||
|
||||
1. Regular logging
|
||||
2. Profiling
|
||||
3. Benchmarking
|
||||
|
||||
As a result, the spans should follow some rules:
|
||||
|
||||
- They should not be put on functions that are called too often. That is because opening and closing a span causes some overhead. For regular logging, avoid putting spans on functions that are taking less than a few hundred nanoseconds. For profiling or benchmarking, avoid putting spans on functions that are taking less than a few microseconds.
|
||||
- For profiling and benchmarking, use the `TRACE` level.
|
||||
- For profiling and benchmarking, use the following `target` prefixes:
|
||||
- `indexing::` for spans meant when profiling the indexing operations.
|
||||
- `search::` for spans meant when profiling the search operations.
|
||||
|
||||
### Benchmarking
|
||||
|
||||
See [BENCHMARKS.md](./BENCHMARKS.md)
|
||||
|
||||
## Git Guidelines
|
||||
|
||||
### Git Branches
|
||||
|
327
Cargo.lock
generated
327
Cargo.lock
generated
@ -149,10 +149,10 @@ dependencies = [
|
||||
"impl-more",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"tokio-rustls 0.23.4",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tracing",
|
||||
"webpki-roots 0.22.6",
|
||||
"webpki-roots 0.25.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -306,9 +306,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
|
||||
|
||||
[[package]]
|
||||
name = "anstream"
|
||||
version = "0.6.7"
|
||||
version = "0.6.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4cd2405b3ac1faab2990b74d728624cd9fd115651fcecc7c2d8daf01376275ba"
|
||||
checksum = "d96bd03f33fe50a863e394ee9718a706f988b9079b20c3784fb726e7678b62fb"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"anstyle-parse",
|
||||
@ -320,9 +320,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "anstyle"
|
||||
version = "1.0.1"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd"
|
||||
checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc"
|
||||
|
||||
[[package]]
|
||||
name = "anstyle-parse"
|
||||
@ -494,7 +494,7 @@ checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b"
|
||||
|
||||
[[package]]
|
||||
name = "benchmarks"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bytes",
|
||||
@ -628,7 +628,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "build-info"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"time",
|
||||
@ -877,9 +877,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "charabia"
|
||||
version = "0.8.7"
|
||||
version = "0.8.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3a9071b1586dd067b5fdfd2069fab932c047ca5bbce4bd2bdee8af0f4b155053"
|
||||
checksum = "60dc1a562fc8cb53d552d371758a4ecd76d15cc7489d2b968529cd9cadcbd854"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"cow-utils",
|
||||
@ -1529,7 +1529,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "dump"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"big_s",
|
||||
@ -1643,9 +1643,9 @@ checksum = "a246d82be1c9d791c5dfde9a2bd045fc3cbba3fa2b11ad558f27d01712f00569"
|
||||
|
||||
[[package]]
|
||||
name = "encoding_rs"
|
||||
version = "0.8.32"
|
||||
version = "0.8.33"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394"
|
||||
checksum = "7268b386296a025e474d5140678f75d6de9493ae55a5d709eeb9dd08149945e1"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
@ -1692,16 +1692,26 @@ dependencies = [
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.10.1"
|
||||
name = "env_filter"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "95b3f3e67048839cb0d0781f445682a35113da7121f7c949db0e2be96a4fbece"
|
||||
checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea"
|
||||
dependencies = [
|
||||
"humantime",
|
||||
"is-terminal",
|
||||
"log",
|
||||
"regex",
|
||||
"termcolor",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "env_logger"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c012a26a7f605efc424dd53697843a72be7dc86ad2d01f7814337794a12231d"
|
||||
dependencies = [
|
||||
"anstream",
|
||||
"anstyle",
|
||||
"env_filter",
|
||||
"humantime",
|
||||
"log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -1767,7 +1777,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "file-store"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"faux",
|
||||
"tempfile",
|
||||
@ -1790,7 +1800,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "filter-parser"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"insta",
|
||||
"nom",
|
||||
@ -1810,7 +1820,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "flatten-serde-json"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"criterion",
|
||||
"serde_json",
|
||||
@ -1928,7 +1938,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "fuzzers"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"arbitrary",
|
||||
"clap",
|
||||
@ -2102,8 +2112,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"wasi",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2222,7 +2234,7 @@ dependencies = [
|
||||
"atomic-polyfill",
|
||||
"hash32",
|
||||
"rustc_version",
|
||||
"spin 0.9.8",
|
||||
"spin",
|
||||
"stable_deref_trait",
|
||||
]
|
||||
|
||||
@ -2393,7 +2405,7 @@ dependencies = [
|
||||
"hyper",
|
||||
"rustls 0.21.10",
|
||||
"tokio",
|
||||
"tokio-rustls 0.24.1",
|
||||
"tokio-rustls",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -2420,7 +2432,7 @@ checksum = "206ca75c9c03ba3d4ace2460e57b189f39f43de612c2f85836e65c929701bb2d"
|
||||
|
||||
[[package]]
|
||||
name = "index-scheduler"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"big_s",
|
||||
@ -2453,9 +2465,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.1.0"
|
||||
version = "2.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d530e1a18b1cb4c484e6e34556a0d948706958449fca0cab753d649f2bce3d1f"
|
||||
checksum = "7b0b929d511467233429c45a44ac1dcaa21ba0f5ba11e4879e6ed28ddb4f9df4"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown",
|
||||
@ -2607,7 +2619,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "json-depth-checker"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"criterion",
|
||||
"serde_json",
|
||||
@ -2615,13 +2627,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "jsonwebtoken"
|
||||
version = "8.3.0"
|
||||
version = "9.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6971da4d9c3aa03c3d8f3ff0f4155b534aad021292003895a469716b2a230378"
|
||||
checksum = "5c7ea04a7c5c055c175f189b6dc6ba036fd62306b58c66c9f6389036c503a3f4"
|
||||
dependencies = [
|
||||
"base64 0.21.7",
|
||||
"js-sys",
|
||||
"pem",
|
||||
"ring 0.16.20",
|
||||
"ring",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"simple_asn1",
|
||||
@ -2733,9 +2746,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-cc-cedict-builder"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a90d23f7cef31c6ab7ac0d4f3b23940754207f7b5a80b080c39193caffe99ac2"
|
||||
checksum = "ca21f2ee3ca40e7f3ebbd568d041be1531c2c28dbf540e737aeba934ab53f330"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
@ -2752,9 +2765,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-compress"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1927b7d2bd4ffc19e07691bf8609722663c341f80260a1c636cee8f1ec420dce"
|
||||
checksum = "34da125091f3b3a49351f418484a16cb2a23f6888cd53fe219edad19d263da5d"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"flate2",
|
||||
@ -2763,9 +2776,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-core"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3299caa2b81c9a076535a4651a83bf7d624c15f2349f243187fffc64b5a78251"
|
||||
checksum = "09d4b717a8a31b73a3cbd3552e0abda14e0c85d97dc8b911035342533defdbad"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
@ -2780,9 +2793,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-decompress"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b82b8d2323a67dc8ff0c40751d199b7ba94cd5e3c13a5b31622d318acc79e5b"
|
||||
checksum = "98f4476c99cb4ffa54fbfc42953adf69ada7276cfbb594bce9829547de012058"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"flate2",
|
||||
@ -2791,9 +2804,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-dictionary"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cddf783b459d54b130d956889bec052c25fcb478a304e03fa9b2289387572bc5"
|
||||
checksum = "a45b92f0ce331c2202c6cec3135e4bfce29525ab3bb97a613c27c8e0a29fa967"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
@ -2811,9 +2824,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-ipadic-builder"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "27c708f08f14b0806f6c4cce5324b4bcba27209463026b78c31f399f8be9d30d"
|
||||
checksum = "642dee52201852df209cb43423ff1ca4d161a329f5cdba049a7b5820118345f2"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
@ -2832,9 +2845,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-ipadic-neologd-builder"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e5e67eb91652203d202f7d27ead220d1d8c9099552709b8429eae9c70f2312fb"
|
||||
checksum = "325144b154e68159373e944d1cd7f67c6ff9965a2af41240a8e41732b3fdb3af"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
@ -2853,9 +2866,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-ko-dic"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d45da8d9a5888f4d4e78bb29fc82ff9ae519962efb0d2d92343b6cf8e373952f"
|
||||
checksum = "b484a2f9964e7424264fda304beb6ff6ad883c347accfe1115e777dedef3661d"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"byteorder",
|
||||
@ -2870,9 +2883,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-ko-dic-builder"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41c0933295dc945178bbc08f34111dc3ef22bfee38820f78453c8f8d4f3463d1"
|
||||
checksum = "b9413d4d9bf7af921f5ac64414a290c7ba81695e8ba08dd2f6c950b57c281a69"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
@ -2890,12 +2903,11 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-tokenizer"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "348ce9bb3f2e5edc577420b98cca05b2177f3af50ef5ae278a1d8a1351d56197"
|
||||
checksum = "9987c818462d51ca67e131e40f0386e25e8c557e195059b1257f95731561185d"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"byteorder",
|
||||
"lindera-core",
|
||||
"lindera-dictionary",
|
||||
"once_cell",
|
||||
@ -2905,9 +2917,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-unidic"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74022a57c395ed7e213a9cd5833207e3c583145078ee9a164aeaec68b30c9d8e"
|
||||
checksum = "0c379cf436b2627cd7d3498642e491eadbff9b3e01231c516ce9f9b1893ab7c3"
|
||||
dependencies = [
|
||||
"bincode",
|
||||
"byteorder",
|
||||
@ -2922,9 +2934,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "lindera-unidic-builder"
|
||||
version = "0.27.2"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a34e5564ee81af82603cd6a03c3abe6e17cc0ae598bfa5078809f06e59e96e08"
|
||||
checksum = "601ec33b5174141396a7a4ca066278863840221fec32d0be19091e7fae91ed94"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bincode",
|
||||
@ -3115,7 +3127,7 @@ checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771"
|
||||
|
||||
[[package]]
|
||||
name = "meili-snap"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"insta",
|
||||
"md5",
|
||||
@ -3124,7 +3136,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "meilisearch"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"actix-cors",
|
||||
"actix-http",
|
||||
@ -3182,7 +3194,7 @@ dependencies = [
|
||||
"rayon",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"rustls 0.20.9",
|
||||
"rustls 0.21.10",
|
||||
"rustls-pemfile",
|
||||
"segment",
|
||||
"serde",
|
||||
@ -3217,7 +3229,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "meilisearch-auth"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"base64 0.21.7",
|
||||
"enum-iterator",
|
||||
@ -3236,7 +3248,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "meilisearch-types"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"actix-web",
|
||||
"anyhow",
|
||||
@ -3266,7 +3278,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "meilitool"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"clap",
|
||||
@ -3305,7 +3317,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "milli"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"arroy",
|
||||
"big_s",
|
||||
@ -3326,7 +3338,6 @@ dependencies = [
|
||||
"filter-parser",
|
||||
"flatten-serde-json",
|
||||
"fst",
|
||||
"futures",
|
||||
"fxhash",
|
||||
"geoutils",
|
||||
"grenad",
|
||||
@ -3350,7 +3361,6 @@ dependencies = [
|
||||
"rand",
|
||||
"rand_pcg",
|
||||
"rayon",
|
||||
"reqwest",
|
||||
"roaring",
|
||||
"rstar",
|
||||
"serde",
|
||||
@ -3364,8 +3374,9 @@ dependencies = [
|
||||
"tiktoken-rs",
|
||||
"time",
|
||||
"tokenizers",
|
||||
"tokio",
|
||||
"tracing",
|
||||
"ureq",
|
||||
"url",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
@ -3731,11 +3742,12 @@ checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
|
||||
|
||||
[[package]]
|
||||
name = "pem"
|
||||
version = "1.1.1"
|
||||
version = "3.0.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8835c273a76a90455d7344889b0964598e3316e2a79ede8e36f16bdcf2228b8"
|
||||
checksum = "1b8fcc794035347fb64beda2d3b462595dd2753e3f268d89c5aae77e8cf2c310"
|
||||
dependencies = [
|
||||
"base64 0.13.1",
|
||||
"base64 0.21.7",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -3746,7 +3758,7 @@ checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
|
||||
|
||||
[[package]]
|
||||
name = "permissive-json-pointer"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"big_s",
|
||||
"serde_json",
|
||||
@ -4244,7 +4256,7 @@ dependencies = [
|
||||
"serde_urlencoded",
|
||||
"system-configuration",
|
||||
"tokio",
|
||||
"tokio-rustls 0.24.1",
|
||||
"tokio-rustls",
|
||||
"tokio-util",
|
||||
"tower-service",
|
||||
"url",
|
||||
@ -4270,31 +4282,17 @@ checksum = "b9b1a3d5f46d53f4a3478e2be4a5a5ce5108ea58b100dcd139830eae7f79a3a1"
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.16.20"
|
||||
version = "0.17.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3053cf52e236a3ed746dfc745aa9cacf1b791d846bdaf412f60a8d7d6e17c8fc"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"libc",
|
||||
"once_cell",
|
||||
"spin 0.5.2",
|
||||
"untrusted 0.7.1",
|
||||
"web-sys",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ring"
|
||||
version = "0.17.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "688c63d65483050968b2a8937f7995f443e27041a0f7700aa59b0822aedebb74"
|
||||
checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"cfg-if",
|
||||
"getrandom",
|
||||
"libc",
|
||||
"spin 0.9.8",
|
||||
"untrusted 0.9.0",
|
||||
"windows-sys 0.48.0",
|
||||
"spin",
|
||||
"untrusted",
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -4369,18 +4367,6 @@ dependencies = [
|
||||
"windows-sys 0.52.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.20.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b80e3dec595989ea8510028f30c408a4630db12c9cbb8de34203b89d6577e99"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring 0.16.20",
|
||||
"sct",
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.21.10"
|
||||
@ -4388,11 +4374,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring 0.17.7",
|
||||
"rustls-webpki",
|
||||
"ring",
|
||||
"rustls-webpki 0.101.7",
|
||||
"sct",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls"
|
||||
version = "0.22.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
|
||||
dependencies = [
|
||||
"log",
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.102.2",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pemfile"
|
||||
version = "1.0.4"
|
||||
@ -4402,14 +4402,31 @@ dependencies = [
|
||||
"base64 0.21.7",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-pki-types"
|
||||
version = "1.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5ede67b28608b4c60685c7d54122d4400d90f62b40caee7700e700380a390fa8"
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.101.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765"
|
||||
dependencies = [
|
||||
"ring 0.17.7",
|
||||
"untrusted 0.9.0",
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustls-webpki"
|
||||
version = "0.102.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610"
|
||||
dependencies = [
|
||||
"ring",
|
||||
"rustls-pki-types",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -4455,8 +4472,8 @@ version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414"
|
||||
dependencies = [
|
||||
"ring 0.17.7",
|
||||
"untrusted 0.9.0",
|
||||
"ring",
|
||||
"untrusted",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -4490,9 +4507,9 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.195"
|
||||
version = "1.0.197"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02"
|
||||
checksum = "3fb1c873e1b9b056a4dc4c0c198b24c3ffa059243875552b2bd0933b1aee4ce2"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
@ -4508,9 +4525,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.195"
|
||||
version = "1.0.197"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c"
|
||||
checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -4519,9 +4536,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.111"
|
||||
version = "1.0.114"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "176e46fa42316f18edd598015a5166857fc835ec732f5215eac6b7bdbf0a84f4"
|
||||
checksum = "c5f09b1bd632ef549eaa9f60a1f8de742bdbc698e6cee2095fc84dde5f549ae0"
|
||||
dependencies = [
|
||||
"indexmap",
|
||||
"itoa",
|
||||
@ -4719,12 +4736,6 @@ dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6e63cff320ae2c57904679ba7cb63280a3dc4613885beafb148ee7bf9aa9042d"
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.9.8"
|
||||
@ -4917,18 +4928,18 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.56"
|
||||
version = "1.0.58"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad"
|
||||
checksum = "03468839009160513471e86a034bb2c5c0e4baae3b43f79ffc55c4a5427b3297"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.56"
|
||||
version = "1.0.58"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471"
|
||||
checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@ -5078,17 +5089,6 @@ dependencies = [
|
||||
"syn 2.0.48",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.23.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59"
|
||||
dependencies = [
|
||||
"rustls 0.20.9",
|
||||
"tokio",
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-rustls"
|
||||
version = "0.24.1"
|
||||
@ -5364,12 +5364,6 @@ version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a"
|
||||
|
||||
[[package]]
|
||||
name = "untrusted"
|
||||
version = "0.9.0"
|
||||
@ -5378,21 +5372,22 @@ checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1"
|
||||
|
||||
[[package]]
|
||||
name = "ureq"
|
||||
version = "2.9.1"
|
||||
version = "2.9.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f8cdd25c339e200129fe4de81451814e5228c9b771d57378817d6117cc2b3f97"
|
||||
checksum = "11f214ce18d8b2cbe84ed3aa6486ed3f5b285cf8d8fbdbce9f3f767a724adc35"
|
||||
dependencies = [
|
||||
"base64 0.21.7",
|
||||
"flate2",
|
||||
"log",
|
||||
"once_cell",
|
||||
"rustls 0.21.10",
|
||||
"rustls-webpki",
|
||||
"rustls 0.22.2",
|
||||
"rustls-pki-types",
|
||||
"rustls-webpki 0.102.2",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"socks",
|
||||
"url",
|
||||
"webpki-roots 0.25.3",
|
||||
"webpki-roots 0.26.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -5628,31 +5623,21 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki"
|
||||
version = "0.22.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53"
|
||||
dependencies = [
|
||||
"ring 0.17.7",
|
||||
"untrusted 0.9.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.22.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6c71e40d7d2c34a5106301fb632274ca37242cd0c9d3e64dbece371a40a2d87"
|
||||
dependencies = [
|
||||
"webpki",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.25.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1778a42e8b3b90bff8d0f5032bf22250792889a5cdc752aa0020c84abe3aaf10"
|
||||
|
||||
[[package]]
|
||||
name = "webpki-roots"
|
||||
version = "0.26.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009"
|
||||
dependencies = [
|
||||
"rustls-pki-types",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "whatlang"
|
||||
version = "0.16.4"
|
||||
@ -5941,7 +5926,7 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "xtask"
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"build-info",
|
||||
@ -6052,6 +6037,12 @@ dependencies = [
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeroize"
|
||||
version = "1.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
|
||||
|
||||
[[package]]
|
||||
name = "zerovec"
|
||||
version = "0.10.1"
|
||||
|
@ -21,7 +21,7 @@ members = [
|
||||
]
|
||||
|
||||
[workspace.package]
|
||||
version = "1.7.4"
|
||||
version = "1.8.0"
|
||||
authors = [
|
||||
"Quentin de Quelen <quentin@dequelen.me>",
|
||||
"Clément Renault <clement@meilisearch.com>",
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -277,6 +277,7 @@ pub(crate) mod test {
|
||||
}),
|
||||
pagination: Setting::NotSet,
|
||||
embedders: Setting::NotSet,
|
||||
search_cutoff_ms: Setting::NotSet,
|
||||
_kind: std::marker::PhantomData,
|
||||
};
|
||||
settings.check()
|
||||
|
@ -379,6 +379,7 @@ impl<T> From<v5::Settings<T>> for v6::Settings<v6::Unchecked> {
|
||||
v5::Setting::NotSet => v6::Setting::NotSet,
|
||||
},
|
||||
embedders: v6::Setting::NotSet,
|
||||
search_cutoff_ms: v6::Setting::NotSet,
|
||||
_kind: std::marker::PhantomData,
|
||||
}
|
||||
}
|
||||
|
@ -1301,8 +1301,8 @@ impl IndexScheduler {
|
||||
|
||||
wtxn.commit().map_err(Error::HeedTransaction)?;
|
||||
|
||||
// Once the tasks are commited, we should delete all the update files associated ASAP to avoid leaking files in case of a restart
|
||||
tracing::debug!("Deleting the upadate files");
|
||||
// Once the tasks are committed, we should delete all the update files associated ASAP to avoid leaking files in case of a restart
|
||||
tracing::debug!("Deleting the update files");
|
||||
|
||||
//We take one read transaction **per thread**. Then, every thread is going to pull out new IDs from the roaring bitmap with the help of an atomic shared index into the bitmap
|
||||
let idx = AtomicU32::new(0);
|
||||
@ -1332,7 +1332,7 @@ impl IndexScheduler {
|
||||
Ok(TickOutcome::TickAgain(processed_tasks))
|
||||
}
|
||||
|
||||
/// Once the tasks changes have been commited we must send all the tasks that were updated to our webhook if there is one.
|
||||
/// Once the tasks changes have been committed we must send all the tasks that were updated to our webhook if there is one.
|
||||
fn notify_webhook(&self, updated: &RoaringBitmap) -> Result<()> {
|
||||
if let Some(ref url) = self.webhook_url {
|
||||
struct TaskReader<'a, 'b> {
|
||||
|
@ -11,7 +11,7 @@ edition.workspace = true
|
||||
license.workspace = true
|
||||
|
||||
[dependencies]
|
||||
actix-web = { version = "4.4.1", default-features = false }
|
||||
actix-web = { version = "4.5.1", default-features = false }
|
||||
anyhow = "1.0.79"
|
||||
convert_case = "0.6.0"
|
||||
csv = "1.3.0"
|
||||
|
@ -259,6 +259,7 @@ InvalidSettingsProximityPrecision , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsFaceting , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsFilterableAttributes , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsPagination , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsSearchCutoffMs , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsEmbedders , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsRankingRules , InvalidRequest , BAD_REQUEST ;
|
||||
InvalidSettingsSearchableAttributes , InvalidRequest , BAD_REQUEST ;
|
||||
@ -352,6 +353,7 @@ impl ErrorCode for milli::Error {
|
||||
| UserError::InvalidOpenAiModelDimensions { .. }
|
||||
| UserError::InvalidOpenAiModelDimensionsMax { .. }
|
||||
| UserError::InvalidSettingsDimensions { .. }
|
||||
| UserError::InvalidUrl { .. }
|
||||
| UserError::InvalidPrompt(_) => Code::InvalidSettingsEmbedders,
|
||||
UserError::TooManyEmbedders(_) => Code::InvalidSettingsEmbedders,
|
||||
UserError::InvalidPromptForEmbeddings(..) => Code::InvalidSettingsEmbedders,
|
||||
|
@ -202,6 +202,9 @@ pub struct Settings<T> {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSettingsEmbedders>)]
|
||||
pub embedders: Setting<BTreeMap<String, Setting<milli::vector::settings::EmbeddingSettings>>>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default, error = DeserrJsonError<InvalidSettingsSearchCutoffMs>)]
|
||||
pub search_cutoff_ms: Setting<u64>,
|
||||
|
||||
#[serde(skip)]
|
||||
#[deserr(skip)]
|
||||
@ -227,6 +230,7 @@ impl Settings<Checked> {
|
||||
faceting: Setting::Reset,
|
||||
pagination: Setting::Reset,
|
||||
embedders: Setting::Reset,
|
||||
search_cutoff_ms: Setting::Reset,
|
||||
_kind: PhantomData,
|
||||
}
|
||||
}
|
||||
@ -249,6 +253,7 @@ impl Settings<Checked> {
|
||||
faceting,
|
||||
pagination,
|
||||
embedders,
|
||||
search_cutoff_ms,
|
||||
..
|
||||
} = self;
|
||||
|
||||
@ -269,6 +274,7 @@ impl Settings<Checked> {
|
||||
faceting,
|
||||
pagination,
|
||||
embedders,
|
||||
search_cutoff_ms,
|
||||
_kind: PhantomData,
|
||||
}
|
||||
}
|
||||
@ -315,6 +321,7 @@ impl Settings<Unchecked> {
|
||||
faceting: self.faceting,
|
||||
pagination: self.pagination,
|
||||
embedders: self.embedders,
|
||||
search_cutoff_ms: self.search_cutoff_ms,
|
||||
_kind: PhantomData,
|
||||
}
|
||||
}
|
||||
@ -347,19 +354,40 @@ pub fn apply_settings_to_builder(
|
||||
settings: &Settings<Checked>,
|
||||
builder: &mut milli::update::Settings,
|
||||
) {
|
||||
match settings.searchable_attributes {
|
||||
let Settings {
|
||||
displayed_attributes,
|
||||
searchable_attributes,
|
||||
filterable_attributes,
|
||||
sortable_attributes,
|
||||
ranking_rules,
|
||||
stop_words,
|
||||
non_separator_tokens,
|
||||
separator_tokens,
|
||||
dictionary,
|
||||
synonyms,
|
||||
distinct_attribute,
|
||||
proximity_precision,
|
||||
typo_tolerance,
|
||||
faceting,
|
||||
pagination,
|
||||
embedders,
|
||||
search_cutoff_ms,
|
||||
_kind,
|
||||
} = settings;
|
||||
|
||||
match searchable_attributes {
|
||||
Setting::Set(ref names) => builder.set_searchable_fields(names.clone()),
|
||||
Setting::Reset => builder.reset_searchable_fields(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.displayed_attributes {
|
||||
match displayed_attributes {
|
||||
Setting::Set(ref names) => builder.set_displayed_fields(names.clone()),
|
||||
Setting::Reset => builder.reset_displayed_fields(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.filterable_attributes {
|
||||
match filterable_attributes {
|
||||
Setting::Set(ref facets) => {
|
||||
builder.set_filterable_fields(facets.clone().into_iter().collect())
|
||||
}
|
||||
@ -367,13 +395,13 @@ pub fn apply_settings_to_builder(
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.sortable_attributes {
|
||||
match sortable_attributes {
|
||||
Setting::Set(ref fields) => builder.set_sortable_fields(fields.iter().cloned().collect()),
|
||||
Setting::Reset => builder.reset_sortable_fields(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.ranking_rules {
|
||||
match ranking_rules {
|
||||
Setting::Set(ref criteria) => {
|
||||
builder.set_criteria(criteria.iter().map(|c| c.clone().into()).collect())
|
||||
}
|
||||
@ -381,13 +409,13 @@ pub fn apply_settings_to_builder(
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.stop_words {
|
||||
match stop_words {
|
||||
Setting::Set(ref stop_words) => builder.set_stop_words(stop_words.clone()),
|
||||
Setting::Reset => builder.reset_stop_words(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.non_separator_tokens {
|
||||
match non_separator_tokens {
|
||||
Setting::Set(ref non_separator_tokens) => {
|
||||
builder.set_non_separator_tokens(non_separator_tokens.clone())
|
||||
}
|
||||
@ -395,7 +423,7 @@ pub fn apply_settings_to_builder(
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.separator_tokens {
|
||||
match separator_tokens {
|
||||
Setting::Set(ref separator_tokens) => {
|
||||
builder.set_separator_tokens(separator_tokens.clone())
|
||||
}
|
||||
@ -403,31 +431,31 @@ pub fn apply_settings_to_builder(
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.dictionary {
|
||||
match dictionary {
|
||||
Setting::Set(ref dictionary) => builder.set_dictionary(dictionary.clone()),
|
||||
Setting::Reset => builder.reset_dictionary(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.synonyms {
|
||||
match synonyms {
|
||||
Setting::Set(ref synonyms) => builder.set_synonyms(synonyms.clone().into_iter().collect()),
|
||||
Setting::Reset => builder.reset_synonyms(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.distinct_attribute {
|
||||
match distinct_attribute {
|
||||
Setting::Set(ref attr) => builder.set_distinct_field(attr.clone()),
|
||||
Setting::Reset => builder.reset_distinct_field(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.proximity_precision {
|
||||
match proximity_precision {
|
||||
Setting::Set(ref precision) => builder.set_proximity_precision((*precision).into()),
|
||||
Setting::Reset => builder.reset_proximity_precision(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.typo_tolerance {
|
||||
match typo_tolerance {
|
||||
Setting::Set(ref value) => {
|
||||
match value.enabled {
|
||||
Setting::Set(val) => builder.set_autorize_typos(val),
|
||||
@ -482,7 +510,7 @@ pub fn apply_settings_to_builder(
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match &settings.faceting {
|
||||
match faceting {
|
||||
Setting::Set(FacetingSettings { max_values_per_facet, sort_facet_values_by }) => {
|
||||
match max_values_per_facet {
|
||||
Setting::Set(val) => builder.set_max_values_per_facet(*val),
|
||||
@ -504,7 +532,7 @@ pub fn apply_settings_to_builder(
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.pagination {
|
||||
match pagination {
|
||||
Setting::Set(ref value) => match value.max_total_hits {
|
||||
Setting::Set(val) => builder.set_pagination_max_total_hits(val),
|
||||
Setting::Reset => builder.reset_pagination_max_total_hits(),
|
||||
@ -514,11 +542,17 @@ pub fn apply_settings_to_builder(
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match settings.embedders.clone() {
|
||||
Setting::Set(value) => builder.set_embedder_settings(value),
|
||||
match embedders {
|
||||
Setting::Set(value) => builder.set_embedder_settings(value.clone()),
|
||||
Setting::Reset => builder.reset_embedder_settings(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
|
||||
match search_cutoff_ms {
|
||||
Setting::Set(cutoff) => builder.set_search_cutoff(*cutoff),
|
||||
Setting::Reset => builder.reset_search_cutoff(),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn settings(
|
||||
@ -607,6 +641,8 @@ pub fn settings(
|
||||
.collect();
|
||||
let embedders = if embedders.is_empty() { Setting::NotSet } else { Setting::Set(embedders) };
|
||||
|
||||
let search_cutoff_ms = index.search_cutoff(rtxn)?;
|
||||
|
||||
Ok(Settings {
|
||||
displayed_attributes: match displayed_attributes {
|
||||
Some(attrs) => Setting::Set(attrs),
|
||||
@ -633,6 +669,10 @@ pub fn settings(
|
||||
faceting: Setting::Set(faceting),
|
||||
pagination: Setting::Set(pagination),
|
||||
embedders,
|
||||
search_cutoff_ms: match search_cutoff_ms {
|
||||
Some(cutoff) => Setting::Set(cutoff),
|
||||
None => Setting::Reset,
|
||||
},
|
||||
_kind: PhantomData,
|
||||
})
|
||||
}
|
||||
@ -783,6 +823,7 @@ pub(crate) mod test {
|
||||
faceting: Setting::NotSet,
|
||||
pagination: Setting::NotSet,
|
||||
embedders: Setting::NotSet,
|
||||
search_cutoff_ms: Setting::NotSet,
|
||||
_kind: PhantomData::<Unchecked>,
|
||||
};
|
||||
|
||||
@ -809,6 +850,7 @@ pub(crate) mod test {
|
||||
faceting: Setting::NotSet,
|
||||
pagination: Setting::NotSet,
|
||||
embedders: Setting::NotSet,
|
||||
search_cutoff_ms: Setting::NotSet,
|
||||
_kind: PhantomData::<Unchecked>,
|
||||
};
|
||||
|
||||
|
@ -14,18 +14,18 @@ default-run = "meilisearch"
|
||||
|
||||
[dependencies]
|
||||
actix-cors = "0.7.0"
|
||||
actix-http = { version = "3.5.1", default-features = false, features = [
|
||||
actix-http = { version = "3.6.0", default-features = false, features = [
|
||||
"compress-brotli",
|
||||
"compress-gzip",
|
||||
"rustls",
|
||||
"rustls-0_21",
|
||||
] }
|
||||
actix-utils = "3.0.1"
|
||||
actix-web = { version = "4.4.1", default-features = false, features = [
|
||||
actix-web = { version = "4.5.1", default-features = false, features = [
|
||||
"macros",
|
||||
"compress-brotli",
|
||||
"compress-gzip",
|
||||
"cookies",
|
||||
"rustls",
|
||||
"rustls-0_21",
|
||||
] }
|
||||
actix-web-static-files = { git = "https://github.com/kilork/actix-web-static-files.git", rev = "2d3b6160", optional = true }
|
||||
anyhow = { version = "1.0.79", features = ["backtrace"] }
|
||||
@ -52,7 +52,7 @@ index-scheduler = { path = "../index-scheduler" }
|
||||
indexmap = { version = "2.1.0", features = ["serde"] }
|
||||
is-terminal = "0.4.10"
|
||||
itertools = "0.11.0"
|
||||
jsonwebtoken = "8.3.0"
|
||||
jsonwebtoken = "9.2.0"
|
||||
lazy_static = "1.4.0"
|
||||
meilisearch-auth = { path = "../meilisearch-auth" }
|
||||
meilisearch-types = { path = "../meilisearch-types" }
|
||||
@ -75,7 +75,7 @@ reqwest = { version = "0.11.23", features = [
|
||||
"rustls-tls",
|
||||
"json",
|
||||
], default-features = false }
|
||||
rustls = "0.20.8"
|
||||
rustls = "0.21.6"
|
||||
rustls-pemfile = "1.0.2"
|
||||
segment = { version = "0.2.3", optional = true }
|
||||
serde = { version = "1.0.195", features = ["derive"] }
|
||||
|
@ -579,6 +579,7 @@ pub struct SearchAggregator {
|
||||
// requests
|
||||
total_received: usize,
|
||||
total_succeeded: usize,
|
||||
total_degraded: usize,
|
||||
time_spent: BinaryHeap<usize>,
|
||||
|
||||
// sort
|
||||
@ -758,9 +759,13 @@ impl SearchAggregator {
|
||||
hits_info: _,
|
||||
facet_distribution: _,
|
||||
facet_stats: _,
|
||||
degraded,
|
||||
} = result;
|
||||
|
||||
self.total_succeeded = self.total_succeeded.saturating_add(1);
|
||||
if *degraded {
|
||||
self.total_degraded = self.total_degraded.saturating_add(1);
|
||||
}
|
||||
self.time_spent.push(*processing_time_ms as usize);
|
||||
}
|
||||
|
||||
@ -802,6 +807,7 @@ impl SearchAggregator {
|
||||
semantic_ratio,
|
||||
embedder,
|
||||
hybrid,
|
||||
total_degraded,
|
||||
} = other;
|
||||
|
||||
if self.timestamp.is_none() {
|
||||
@ -816,6 +822,7 @@ impl SearchAggregator {
|
||||
// request
|
||||
self.total_received = self.total_received.saturating_add(total_received);
|
||||
self.total_succeeded = self.total_succeeded.saturating_add(total_succeeded);
|
||||
self.total_degraded = self.total_degraded.saturating_add(total_degraded);
|
||||
self.time_spent.append(time_spent);
|
||||
|
||||
// sort
|
||||
@ -921,6 +928,7 @@ impl SearchAggregator {
|
||||
semantic_ratio,
|
||||
embedder,
|
||||
hybrid,
|
||||
total_degraded,
|
||||
} = self;
|
||||
|
||||
if total_received == 0 {
|
||||
@ -940,6 +948,7 @@ impl SearchAggregator {
|
||||
"total_succeeded": total_succeeded,
|
||||
"total_failed": total_received.saturating_sub(total_succeeded), // just to be sure we never panics
|
||||
"total_received": total_received,
|
||||
"total_degraded": total_degraded,
|
||||
},
|
||||
"sort": {
|
||||
"with_geoPoint": sort_with_geo_point,
|
||||
|
@ -151,7 +151,7 @@ async fn run_http(
|
||||
.keep_alive(KeepAlive::Os);
|
||||
|
||||
if let Some(config) = opt_clone.get_ssl_config()? {
|
||||
http_server.bind_rustls(opt_clone.http_addr, config)?.run().await?;
|
||||
http_server.bind_rustls_021(opt_clone.http_addr, config)?.run().await?;
|
||||
} else {
|
||||
http_server.bind(&opt_clone.http_addr)?.run().await?;
|
||||
}
|
||||
|
@ -4,24 +4,17 @@ use prometheus::{
|
||||
register_int_gauge_vec, HistogramVec, IntCounterVec, IntGauge, IntGaugeVec,
|
||||
};
|
||||
|
||||
/// Create evenly distributed buckets
|
||||
fn create_buckets() -> [f64; 29] {
|
||||
(0..10)
|
||||
.chain((10..100).step_by(10))
|
||||
.chain((100..=1000).step_by(100))
|
||||
.map(|i| i as f64 / 1000.)
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
lazy_static! {
|
||||
pub static ref MEILISEARCH_HTTP_RESPONSE_TIME_CUSTOM_BUCKETS: [f64; 29] = create_buckets();
|
||||
pub static ref MEILISEARCH_HTTP_REQUESTS_TOTAL: IntCounterVec = register_int_counter_vec!(
|
||||
opts!("meilisearch_http_requests_total", "Meilisearch HTTP requests total"),
|
||||
&["method", "path"]
|
||||
&["method", "path", "status"]
|
||||
)
|
||||
.expect("Can't create a metric");
|
||||
pub static ref MEILISEARCH_DEGRADED_SEARCH_REQUESTS: IntGauge = register_int_gauge!(opts!(
|
||||
"meilisearch_degraded_search_requests",
|
||||
"Meilisearch number of degraded search requests"
|
||||
))
|
||||
.expect("Can't create a metric");
|
||||
pub static ref MEILISEARCH_DB_SIZE_BYTES: IntGauge =
|
||||
register_int_gauge!(opts!("meilisearch_db_size_bytes", "Meilisearch DB Size In Bytes"))
|
||||
.expect("Can't create a metric");
|
||||
@ -42,7 +35,7 @@ lazy_static! {
|
||||
"meilisearch_http_response_time_seconds",
|
||||
"Meilisearch HTTP response times",
|
||||
&["method", "path"],
|
||||
MEILISEARCH_HTTP_RESPONSE_TIME_CUSTOM_BUCKETS.to_vec()
|
||||
vec![0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0]
|
||||
)
|
||||
.expect("Can't create a metric");
|
||||
pub static ref MEILISEARCH_NB_TASKS: IntGaugeVec = register_int_gauge_vec!(
|
||||
|
@ -65,9 +65,6 @@ where
|
||||
.with_label_values(&[&request_method, request_path])
|
||||
.start_timer(),
|
||||
);
|
||||
crate::metrics::MEILISEARCH_HTTP_REQUESTS_TOTAL
|
||||
.with_label_values(&[&request_method, request_path])
|
||||
.inc();
|
||||
}
|
||||
};
|
||||
|
||||
@ -76,6 +73,14 @@ where
|
||||
Box::pin(async move {
|
||||
let res = fut.await?;
|
||||
|
||||
crate::metrics::MEILISEARCH_HTTP_REQUESTS_TOTAL
|
||||
.with_label_values(&[
|
||||
res.request().method().as_str(),
|
||||
res.request().path(),
|
||||
res.status().as_str(),
|
||||
])
|
||||
.inc();
|
||||
|
||||
if let Some(histogram_timer) = histogram_timer {
|
||||
histogram_timer.observe_duration();
|
||||
};
|
||||
|
@ -564,11 +564,11 @@ impl Opt {
|
||||
}
|
||||
if self.ssl_require_auth {
|
||||
let verifier = AllowAnyAuthenticatedClient::new(client_auth_roots);
|
||||
config.with_client_cert_verifier(verifier)
|
||||
config.with_client_cert_verifier(Arc::from(verifier))
|
||||
} else {
|
||||
let verifier =
|
||||
AllowAnyAnonymousOrAuthenticatedClient::new(client_auth_roots);
|
||||
config.with_client_cert_verifier(verifier)
|
||||
config.with_client_cert_verifier(Arc::from(verifier))
|
||||
}
|
||||
}
|
||||
None => config.with_no_client_auth(),
|
||||
|
@ -17,6 +17,7 @@ use crate::analytics::{Analytics, SearchAggregator};
|
||||
use crate::extractors::authentication::policies::*;
|
||||
use crate::extractors::authentication::GuardedData;
|
||||
use crate::extractors::sequential_extractor::SeqHandler;
|
||||
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
|
||||
use crate::search::{
|
||||
add_search_rules, perform_search, HybridQuery, MatchingStrategy, SearchQuery, SemanticRatio,
|
||||
DEFAULT_CROP_LENGTH, DEFAULT_CROP_MARKER, DEFAULT_HIGHLIGHT_POST_TAG,
|
||||
@ -201,7 +202,7 @@ pub async fn search_with_url_query(
|
||||
let index = index_scheduler.index(&index_uid)?;
|
||||
let features = index_scheduler.features();
|
||||
|
||||
let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?;
|
||||
let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?;
|
||||
|
||||
let search_result =
|
||||
tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
|
||||
@ -240,13 +241,16 @@ pub async fn search_with_post(
|
||||
|
||||
let features = index_scheduler.features();
|
||||
|
||||
let distribution = embed(&mut query, index_scheduler.get_ref(), &index).await?;
|
||||
let distribution = embed(&mut query, index_scheduler.get_ref(), &index)?;
|
||||
|
||||
let search_result =
|
||||
tokio::task::spawn_blocking(move || perform_search(&index, query, features, distribution))
|
||||
.await?;
|
||||
if let Ok(ref search_result) = search_result {
|
||||
aggregate.succeed(search_result);
|
||||
if search_result.degraded {
|
||||
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
|
||||
}
|
||||
}
|
||||
analytics.post_search(aggregate);
|
||||
|
||||
@ -256,7 +260,7 @@ pub async fn search_with_post(
|
||||
Ok(HttpResponse::Ok().json(search_result))
|
||||
}
|
||||
|
||||
pub async fn embed(
|
||||
pub fn embed(
|
||||
query: &mut SearchQuery,
|
||||
index_scheduler: &IndexScheduler,
|
||||
index: &milli::Index,
|
||||
@ -283,7 +287,6 @@ pub async fn embed(
|
||||
|
||||
let embeddings = embedder
|
||||
.embed(vec![q.to_owned()])
|
||||
.await
|
||||
.map_err(milli::vector::Error::from)
|
||||
.map_err(milli::Error::from)?
|
||||
.pop()
|
||||
|
@ -604,6 +604,8 @@ fn embedder_analytics(
|
||||
EmbedderSource::OpenAi => sources.insert("openAi"),
|
||||
EmbedderSource::HuggingFace => sources.insert("huggingFace"),
|
||||
EmbedderSource::UserProvided => sources.insert("userProvided"),
|
||||
EmbedderSource::Ollama => sources.insert("ollama"),
|
||||
EmbedderSource::Rest => sources.insert("rest"),
|
||||
};
|
||||
}
|
||||
};
|
||||
@ -623,6 +625,25 @@ fn embedder_analytics(
|
||||
)
|
||||
}
|
||||
|
||||
make_setting_route!(
|
||||
"/search-cutoff-ms",
|
||||
put,
|
||||
u64,
|
||||
meilisearch_types::deserr::DeserrJsonError<
|
||||
meilisearch_types::error::deserr_codes::InvalidSettingsSearchCutoffMs,
|
||||
>,
|
||||
search_cutoff_ms,
|
||||
"searchCutoffMs",
|
||||
analytics,
|
||||
|setting: &Option<u64>, req: &HttpRequest| {
|
||||
analytics.publish(
|
||||
"Search Cutoff Updated".to_string(),
|
||||
serde_json::json!({"search_cutoff_ms": setting }),
|
||||
Some(req),
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
macro_rules! generate_configure {
|
||||
($($mod:ident),*) => {
|
||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||
@ -653,7 +674,8 @@ generate_configure!(
|
||||
typo_tolerance,
|
||||
pagination,
|
||||
faceting,
|
||||
embedders
|
||||
embedders,
|
||||
search_cutoff_ms
|
||||
);
|
||||
|
||||
pub async fn update_all(
|
||||
@ -764,7 +786,8 @@ pub async fn update_all(
|
||||
"synonyms": {
|
||||
"total": new_settings.synonyms.as_ref().set().map(|synonyms| synonyms.len()),
|
||||
},
|
||||
"embedders": crate::routes::indexes::settings::embedder_analytics(new_settings.embedders.as_ref().set())
|
||||
"embedders": crate::routes::indexes::settings::embedder_analytics(new_settings.embedders.as_ref().set()),
|
||||
"search_cutoff_ms": new_settings.search_cutoff_ms.as_ref().set(),
|
||||
}),
|
||||
Some(&req),
|
||||
);
|
||||
|
@ -75,9 +75,8 @@ pub async fn multi_search_with_post(
|
||||
})
|
||||
.with_index(query_index)?;
|
||||
|
||||
let distribution = embed(&mut query, index_scheduler.get_ref(), &index)
|
||||
.await
|
||||
.with_index(query_index)?;
|
||||
let distribution =
|
||||
embed(&mut query, index_scheduler.get_ref(), &index).with_index(query_index)?;
|
||||
|
||||
let search_result = tokio::task::spawn_blocking(move || {
|
||||
perform_search(&index, query, features, distribution)
|
||||
|
@ -1,7 +1,7 @@
|
||||
use std::cmp::min;
|
||||
use std::collections::{BTreeMap, BTreeSet, HashSet};
|
||||
use std::str::FromStr;
|
||||
use std::time::Instant;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use deserr::Deserr;
|
||||
use either::Either;
|
||||
@ -14,7 +14,7 @@ use meilisearch_types::heed::RoTxn;
|
||||
use meilisearch_types::index_uid::IndexUid;
|
||||
use meilisearch_types::milli::score_details::{self, ScoreDetails, ScoringStrategy};
|
||||
use meilisearch_types::milli::vector::DistributionShift;
|
||||
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues};
|
||||
use meilisearch_types::milli::{FacetValueHit, OrderBy, SearchForFacetValues, TimeBudget};
|
||||
use meilisearch_types::settings::DEFAULT_PAGINATION_MAX_TOTAL_HITS;
|
||||
use meilisearch_types::{milli, Document};
|
||||
use milli::tokenizer::TokenizerBuilder;
|
||||
@ -323,6 +323,10 @@ pub struct SearchResult {
|
||||
pub facet_distribution: Option<BTreeMap<String, IndexMap<String, u64>>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub facet_stats: Option<BTreeMap<String, FacetStats>>,
|
||||
|
||||
// This information is only used for analytics purposes
|
||||
#[serde(skip)]
|
||||
pub degraded: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Debug, Clone, PartialEq)]
|
||||
@ -382,8 +386,10 @@ fn prepare_search<'t>(
|
||||
query: &'t SearchQuery,
|
||||
features: RoFeatures,
|
||||
distribution: Option<DistributionShift>,
|
||||
time_budget: TimeBudget,
|
||||
) -> Result<(milli::Search<'t>, bool, usize, usize), MeilisearchHttpError> {
|
||||
let mut search = index.search(rtxn);
|
||||
search.time_budget(time_budget);
|
||||
|
||||
if query.vector.is_some() {
|
||||
features.check_vector("Passing `vector` as a query parameter")?;
|
||||
@ -492,18 +498,28 @@ pub fn perform_search(
|
||||
) -> Result<SearchResult, MeilisearchHttpError> {
|
||||
let before_search = Instant::now();
|
||||
let rtxn = index.read_txn()?;
|
||||
let time_budget = match index.search_cutoff(&rtxn)? {
|
||||
Some(cutoff) => TimeBudget::new(Duration::from_millis(cutoff)),
|
||||
None => TimeBudget::default(),
|
||||
};
|
||||
|
||||
let (search, is_finite_pagination, max_total_hits, offset) =
|
||||
prepare_search(index, &rtxn, &query, features, distribution)?;
|
||||
prepare_search(index, &rtxn, &query, features, distribution, time_budget)?;
|
||||
|
||||
let milli::SearchResult { documents_ids, matching_words, candidates, document_scores, .. } =
|
||||
match &query.hybrid {
|
||||
Some(hybrid) => match *hybrid.semantic_ratio {
|
||||
ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?,
|
||||
ratio => search.execute_hybrid(ratio)?,
|
||||
},
|
||||
None => search.execute()?,
|
||||
};
|
||||
let milli::SearchResult {
|
||||
documents_ids,
|
||||
matching_words,
|
||||
candidates,
|
||||
document_scores,
|
||||
degraded,
|
||||
..
|
||||
} = match &query.hybrid {
|
||||
Some(hybrid) => match *hybrid.semantic_ratio {
|
||||
ratio if ratio == 0.0 || ratio == 1.0 => search.execute()?,
|
||||
ratio => search.execute_hybrid(ratio)?,
|
||||
},
|
||||
None => search.execute()?,
|
||||
};
|
||||
|
||||
let fields_ids_map = index.fields_ids_map(&rtxn).unwrap();
|
||||
|
||||
@ -530,7 +546,7 @@ pub fn perform_search(
|
||||
// The attributes to retrieve are the ones explicitly marked as to retrieve (all by default),
|
||||
// but these attributes must be also be present
|
||||
// - in the fields_ids_map
|
||||
// - in the the displayed attributes
|
||||
// - in the displayed attributes
|
||||
let to_retrieve_ids: BTreeSet<_> = query
|
||||
.attributes_to_retrieve
|
||||
.as_ref()
|
||||
@ -671,27 +687,16 @@ pub fn perform_search(
|
||||
|
||||
let sort_facet_values_by =
|
||||
index.sort_facet_values_by(&rtxn).map_err(milli::Error::from)?;
|
||||
let default_sort_facet_values_by =
|
||||
sort_facet_values_by.get("*").copied().unwrap_or_default();
|
||||
|
||||
if fields.iter().all(|f| f != "*") {
|
||||
let fields: Vec<_> = fields
|
||||
.iter()
|
||||
.map(|n| {
|
||||
(
|
||||
n,
|
||||
sort_facet_values_by
|
||||
.get(n)
|
||||
.copied()
|
||||
.unwrap_or(default_sort_facet_values_by),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let fields: Vec<_> =
|
||||
fields.iter().map(|n| (n, sort_facet_values_by.get(n))).collect();
|
||||
facet_distribution.facets(fields);
|
||||
}
|
||||
|
||||
let distribution = facet_distribution
|
||||
.candidates(candidates)
|
||||
.default_order_by(default_sort_facet_values_by)
|
||||
.default_order_by(sort_facet_values_by.get("*"))
|
||||
.execute()?;
|
||||
let stats = facet_distribution.compute_stats()?;
|
||||
(Some(distribution), Some(stats))
|
||||
@ -711,6 +716,7 @@ pub fn perform_search(
|
||||
processing_time_ms: before_search.elapsed().as_millis(),
|
||||
facet_distribution,
|
||||
facet_stats,
|
||||
degraded,
|
||||
};
|
||||
Ok(result)
|
||||
}
|
||||
@ -724,8 +730,13 @@ pub fn perform_facet_search(
|
||||
) -> Result<FacetSearchResult, MeilisearchHttpError> {
|
||||
let before_search = Instant::now();
|
||||
let rtxn = index.read_txn()?;
|
||||
let time_budget = match index.search_cutoff(&rtxn)? {
|
||||
Some(cutoff) => TimeBudget::new(Duration::from_millis(cutoff)),
|
||||
None => TimeBudget::default(),
|
||||
};
|
||||
|
||||
let (search, _, _, _) = prepare_search(index, &rtxn, &search_query, features, None)?;
|
||||
let (search, _, _, _) =
|
||||
prepare_search(index, &rtxn, &search_query, features, None, time_budget)?;
|
||||
let mut facet_search =
|
||||
SearchForFacetValues::new(facet_name, search, search_query.hybrid.is_some());
|
||||
if let Some(facet_query) = &facet_query {
|
||||
|
@ -328,6 +328,11 @@ impl Index<'_> {
|
||||
self.service.patch_encoded(url, settings, self.encoder).await
|
||||
}
|
||||
|
||||
pub async fn update_settings_search_cutoff_ms(&self, settings: Value) -> (Value, StatusCode) {
|
||||
let url = format!("/indexes/{}/settings/search-cutoff-ms", urlencode(self.uid.as_ref()));
|
||||
self.service.put_encoded(url, settings, self.encoder).await
|
||||
}
|
||||
|
||||
pub async fn delete_settings(&self) -> (Value, StatusCode) {
|
||||
let url = format!("/indexes/{}/settings", urlencode(self.uid.as_ref()));
|
||||
self.service.delete(url).await
|
||||
|
@ -16,6 +16,7 @@ pub use server::{default_settings, Server};
|
||||
pub struct Value(pub serde_json::Value);
|
||||
|
||||
impl Value {
|
||||
#[track_caller]
|
||||
pub fn uid(&self) -> u64 {
|
||||
if let Some(uid) = self["uid"].as_u64() {
|
||||
uid
|
||||
|
@ -1237,8 +1237,8 @@ async fn error_add_documents_missing_document_id() {
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
#[ignore] // // TODO: Fix in an other PR: this does not provoke any error.
|
||||
async fn error_document_field_limit_reached() {
|
||||
#[should_panic]
|
||||
async fn error_document_field_limit_reached_in_one_document() {
|
||||
let server = Server::new().await;
|
||||
let index = server.index("test");
|
||||
|
||||
@ -1246,22 +1246,241 @@ async fn error_document_field_limit_reached() {
|
||||
|
||||
let mut big_object = std::collections::HashMap::new();
|
||||
big_object.insert("id".to_owned(), "wow");
|
||||
for i in 0..65535 {
|
||||
for i in 0..(u16::MAX as usize + 1) {
|
||||
let key = i.to_string();
|
||||
big_object.insert(key, "I am a text!");
|
||||
}
|
||||
|
||||
let documents = json!([big_object]);
|
||||
|
||||
let (_response, code) = index.update_documents(documents, Some("id")).await;
|
||||
snapshot!(code, @"202");
|
||||
let (response, code) = index.update_documents(documents, Some("id")).await;
|
||||
snapshot!(code, @"500 Internal Server Error");
|
||||
|
||||
index.wait_task(0).await;
|
||||
let (response, code) = index.get_task(0).await;
|
||||
snapshot!(code, @"200");
|
||||
let response = index.wait_task(response.uid()).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
// Documents without a primary key are not accepted.
|
||||
snapshot!(json_string!(response, { ".duration" => "[duration]", ".enqueuedAt" => "[date]", ".startedAt" => "[date]", ".finishedAt" => "[date]" }),
|
||||
@"");
|
||||
snapshot!(response,
|
||||
@r###"
|
||||
{
|
||||
"uid": 1,
|
||||
"indexUid": "test",
|
||||
"status": "succeeded",
|
||||
"type": "documentAdditionOrUpdate",
|
||||
"canceledBy": null,
|
||||
"details": {
|
||||
"receivedDocuments": 1,
|
||||
"indexedDocuments": 1
|
||||
},
|
||||
"error": null,
|
||||
"duration": "[duration]",
|
||||
"enqueuedAt": "[date]",
|
||||
"startedAt": "[date]",
|
||||
"finishedAt": "[date]"
|
||||
}
|
||||
"###);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn error_document_field_limit_reached_over_multiple_documents() {
|
||||
let server = Server::new().await;
|
||||
let index = server.index("test");
|
||||
|
||||
index.create(Some("id")).await;
|
||||
|
||||
let mut big_object = std::collections::HashMap::new();
|
||||
big_object.insert("id".to_owned(), "wow");
|
||||
for i in 0..(u16::MAX / 2) {
|
||||
let key = i.to_string();
|
||||
big_object.insert(key, "I am a text!");
|
||||
}
|
||||
|
||||
let documents = json!([big_object]);
|
||||
|
||||
let (response, code) = index.update_documents(documents, Some("id")).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
|
||||
let response = index.wait_task(response.uid()).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
snapshot!(response,
|
||||
@r###"
|
||||
{
|
||||
"uid": 1,
|
||||
"indexUid": "test",
|
||||
"status": "succeeded",
|
||||
"type": "documentAdditionOrUpdate",
|
||||
"canceledBy": null,
|
||||
"details": {
|
||||
"receivedDocuments": 1,
|
||||
"indexedDocuments": 1
|
||||
},
|
||||
"error": null,
|
||||
"duration": "[duration]",
|
||||
"enqueuedAt": "[date]",
|
||||
"startedAt": "[date]",
|
||||
"finishedAt": "[date]"
|
||||
}
|
||||
"###);
|
||||
|
||||
let mut big_object = std::collections::HashMap::new();
|
||||
big_object.insert("id".to_owned(), "waw");
|
||||
for i in (u16::MAX as usize / 2)..(u16::MAX as usize + 1) {
|
||||
let key = i.to_string();
|
||||
big_object.insert(key, "I am a text!");
|
||||
}
|
||||
|
||||
let documents = json!([big_object]);
|
||||
|
||||
let (response, code) = index.update_documents(documents, Some("id")).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
|
||||
let response = index.wait_task(response.uid()).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
snapshot!(response,
|
||||
@r###"
|
||||
{
|
||||
"uid": 2,
|
||||
"indexUid": "test",
|
||||
"status": "failed",
|
||||
"type": "documentAdditionOrUpdate",
|
||||
"canceledBy": null,
|
||||
"details": {
|
||||
"receivedDocuments": 1,
|
||||
"indexedDocuments": 0
|
||||
},
|
||||
"error": {
|
||||
"message": "A document cannot contain more than 65,535 fields.",
|
||||
"code": "max_fields_limit_exceeded",
|
||||
"type": "invalid_request",
|
||||
"link": "https://docs.meilisearch.com/errors#max_fields_limit_exceeded"
|
||||
},
|
||||
"duration": "[duration]",
|
||||
"enqueuedAt": "[date]",
|
||||
"startedAt": "[date]",
|
||||
"finishedAt": "[date]"
|
||||
}
|
||||
"###);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn error_document_field_limit_reached_in_one_nested_document() {
|
||||
let server = Server::new().await;
|
||||
let index = server.index("test");
|
||||
|
||||
index.create(Some("id")).await;
|
||||
|
||||
let mut nested = std::collections::HashMap::new();
|
||||
for i in 0..(u16::MAX as usize + 1) {
|
||||
let key = i.to_string();
|
||||
nested.insert(key, "I am a text!");
|
||||
}
|
||||
let mut big_object = std::collections::HashMap::new();
|
||||
big_object.insert("id".to_owned(), "wow");
|
||||
|
||||
let documents = json!([big_object]);
|
||||
|
||||
let (response, code) = index.update_documents(documents, Some("id")).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
|
||||
let response = index.wait_task(response.uid()).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
// Documents without a primary key are not accepted.
|
||||
snapshot!(response,
|
||||
@r###"
|
||||
{
|
||||
"uid": 1,
|
||||
"indexUid": "test",
|
||||
"status": "succeeded",
|
||||
"type": "documentAdditionOrUpdate",
|
||||
"canceledBy": null,
|
||||
"details": {
|
||||
"receivedDocuments": 1,
|
||||
"indexedDocuments": 1
|
||||
},
|
||||
"error": null,
|
||||
"duration": "[duration]",
|
||||
"enqueuedAt": "[date]",
|
||||
"startedAt": "[date]",
|
||||
"finishedAt": "[date]"
|
||||
}
|
||||
"###);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn error_document_field_limit_reached_over_multiple_documents_with_nested_fields() {
|
||||
let server = Server::new().await;
|
||||
let index = server.index("test");
|
||||
|
||||
index.create(Some("id")).await;
|
||||
|
||||
let mut nested = std::collections::HashMap::new();
|
||||
for i in 0..(u16::MAX / 2) {
|
||||
let key = i.to_string();
|
||||
nested.insert(key, "I am a text!");
|
||||
}
|
||||
let mut big_object = std::collections::HashMap::new();
|
||||
big_object.insert("id".to_owned(), "wow");
|
||||
|
||||
let documents = json!([big_object]);
|
||||
|
||||
let (response, code) = index.update_documents(documents, Some("id")).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
|
||||
let response = index.wait_task(response.uid()).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
snapshot!(response,
|
||||
@r###"
|
||||
{
|
||||
"uid": 1,
|
||||
"indexUid": "test",
|
||||
"status": "succeeded",
|
||||
"type": "documentAdditionOrUpdate",
|
||||
"canceledBy": null,
|
||||
"details": {
|
||||
"receivedDocuments": 1,
|
||||
"indexedDocuments": 1
|
||||
},
|
||||
"error": null,
|
||||
"duration": "[duration]",
|
||||
"enqueuedAt": "[date]",
|
||||
"startedAt": "[date]",
|
||||
"finishedAt": "[date]"
|
||||
}
|
||||
"###);
|
||||
|
||||
let mut nested = std::collections::HashMap::new();
|
||||
for i in 0..(u16::MAX / 2) {
|
||||
let key = i.to_string();
|
||||
nested.insert(key, "I am a text!");
|
||||
}
|
||||
let mut big_object = std::collections::HashMap::new();
|
||||
big_object.insert("id".to_owned(), "wow");
|
||||
|
||||
let documents = json!([big_object]);
|
||||
|
||||
let (response, code) = index.update_documents(documents, Some("id")).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
|
||||
let response = index.wait_task(response.uid()).await;
|
||||
snapshot!(code, @"202 Accepted");
|
||||
snapshot!(response,
|
||||
@r###"
|
||||
{
|
||||
"uid": 2,
|
||||
"indexUid": "test",
|
||||
"status": "succeeded",
|
||||
"type": "documentAdditionOrUpdate",
|
||||
"canceledBy": null,
|
||||
"details": {
|
||||
"receivedDocuments": 1,
|
||||
"indexedDocuments": 1
|
||||
},
|
||||
"error": null,
|
||||
"duration": "[duration]",
|
||||
"enqueuedAt": "[date]",
|
||||
"startedAt": "[date]",
|
||||
"finishedAt": "[date]"
|
||||
}
|
||||
"###);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
|
@ -77,7 +77,8 @@ async fn import_dump_v1_movie_raw() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -238,7 +239,8 @@ async fn import_dump_v1_movie_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -385,7 +387,8 @@ async fn import_dump_v1_rubygems_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -518,7 +521,8 @@ async fn import_dump_v2_movie_raw() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -663,7 +667,8 @@ async fn import_dump_v2_movie_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -807,7 +812,8 @@ async fn import_dump_v2_rubygems_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -940,7 +946,8 @@ async fn import_dump_v3_movie_raw() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1085,7 +1092,8 @@ async fn import_dump_v3_movie_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1229,7 +1237,8 @@ async fn import_dump_v3_rubygems_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1362,7 +1371,8 @@ async fn import_dump_v4_movie_raw() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1507,7 +1517,8 @@ async fn import_dump_v4_movie_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1651,7 +1662,8 @@ async fn import_dump_v4_rubygems_with_settings() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###
|
||||
);
|
||||
@ -1895,7 +1907,8 @@ async fn import_dump_v6_containing_experimental_features() {
|
||||
},
|
||||
"pagination": {
|
||||
"maxTotalHits": 1000
|
||||
}
|
||||
},
|
||||
"searchCutoffMs": null
|
||||
}
|
||||
"###);
|
||||
|
||||
|
@ -123,6 +123,28 @@ async fn simple_facet_search_with_max_values() {
|
||||
assert_eq!(dbg!(response)["facetHits"].as_array().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn simple_facet_search_by_count_with_max_values() {
|
||||
let server = Server::new().await;
|
||||
let index = server.index("test");
|
||||
|
||||
let documents = DOCUMENTS.clone();
|
||||
index
|
||||
.update_settings_faceting(
|
||||
json!({ "maxValuesPerFacet": 1, "sortFacetValuesBy": { "*": "count" } }),
|
||||
)
|
||||
.await;
|
||||
index.update_settings_filterable_attributes(json!(["genres"])).await;
|
||||
index.add_documents(documents, None).await;
|
||||
index.wait_task(2).await;
|
||||
|
||||
let (response, code) =
|
||||
index.facet_search(json!({"facetName": "genres", "facetQuery": "a"})).await;
|
||||
|
||||
assert_eq!(code, 200, "{}", response);
|
||||
assert_eq!(dbg!(response)["facetHits"].as_array().unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn non_filterable_facet_search_error() {
|
||||
let server = Server::new().await;
|
||||
@ -157,3 +179,24 @@ async fn facet_search_dont_support_words() {
|
||||
assert_eq!(code, 200, "{}", response);
|
||||
assert_eq!(response["facetHits"].as_array().unwrap().len(), 0);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn simple_facet_search_with_sort_by_count() {
|
||||
let server = Server::new().await;
|
||||
let index = server.index("test");
|
||||
|
||||
let documents = DOCUMENTS.clone();
|
||||
index.update_settings_faceting(json!({ "sortFacetValuesBy": { "*": "count" } })).await;
|
||||
index.update_settings_filterable_attributes(json!(["genres"])).await;
|
||||
index.add_documents(documents, None).await;
|
||||
index.wait_task(2).await;
|
||||
|
||||
let (response, code) =
|
||||
index.facet_search(json!({"facetName": "genres", "facetQuery": "a"})).await;
|
||||
|
||||
assert_eq!(code, 200, "{}", response);
|
||||
let hits = response["facetHits"].as_array().unwrap();
|
||||
assert_eq!(hits.len(), 2);
|
||||
assert_eq!(hits[0], json!({ "value": "Action", "count": 3 }));
|
||||
assert_eq!(hits[1], json!({ "value": "Adventure", "count": 2 }));
|
||||
}
|
||||
|
@ -834,6 +834,94 @@ async fn test_score_details() {
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn test_degraded_score_details() {
|
||||
let server = Server::new().await;
|
||||
let index = server.index("test");
|
||||
|
||||
let documents = NESTED_DOCUMENTS.clone();
|
||||
|
||||
index.add_documents(json!(documents), None).await;
|
||||
// We can't really use anything else than 0ms here; otherwise, the test will get flaky.
|
||||
let (res, _code) = index.update_settings(json!({ "searchCutoffMs": 0 })).await;
|
||||
index.wait_task(res.uid()).await;
|
||||
|
||||
index
|
||||
.search(
|
||||
json!({
|
||||
"q": "b",
|
||||
"attributesToRetrieve": ["doggos.name", "cattos"],
|
||||
"showRankingScoreDetails": true,
|
||||
}),
|
||||
|response, code| {
|
||||
meili_snap::snapshot!(code, @"200 OK");
|
||||
meili_snap::snapshot!(meili_snap::json_string!(response, { ".processingTimeMs" => "[duration]" }), @r###"
|
||||
{
|
||||
"hits": [
|
||||
{
|
||||
"doggos": [
|
||||
{
|
||||
"name": "bobby"
|
||||
},
|
||||
{
|
||||
"name": "buddy"
|
||||
}
|
||||
],
|
||||
"cattos": "pésti",
|
||||
"_rankingScoreDetails": {
|
||||
"skipped": {
|
||||
"order": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"doggos": [
|
||||
{
|
||||
"name": "gros bill"
|
||||
}
|
||||
],
|
||||
"cattos": [
|
||||
"simba",
|
||||
"pestiféré"
|
||||
],
|
||||
"_rankingScoreDetails": {
|
||||
"skipped": {
|
||||
"order": 0
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"doggos": [
|
||||
{
|
||||
"name": "turbo"
|
||||
},
|
||||
{
|
||||
"name": "fast"
|
||||
}
|
||||
],
|
||||
"cattos": [
|
||||
"moumoute",
|
||||
"gomez"
|
||||
],
|
||||
"_rankingScoreDetails": {
|
||||
"skipped": {
|
||||
"order": 0
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"query": "b",
|
||||
"processingTimeMs": "[duration]",
|
||||
"limit": 20,
|
||||
"offset": 0,
|
||||
"estimatedTotalHits": 3
|
||||
}
|
||||
"###);
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn experimental_feature_vector_store() {
|
||||
let server = Server::new().await;
|
||||
|
@ -337,3 +337,31 @@ async fn settings_bad_pagination() {
|
||||
}
|
||||
"###);
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
async fn settings_bad_search_cutoff_ms() {
|
||||
let server = Server::new().await;
|
||||
let index = server.index("test");
|
||||
|
||||
let (response, code) = index.update_settings(json!({ "searchCutoffMs": "doggo" })).await;
|
||||
snapshot!(code, @"400 Bad Request");
|
||||
snapshot!(json_string!(response), @r###"
|
||||
{
|
||||
"message": "Invalid value type at `.searchCutoffMs`: expected a positive integer, but found a string: `\"doggo\"`",
|
||||
"code": "invalid_settings_search_cutoff_ms",
|
||||
"type": "invalid_request",
|
||||
"link": "https://docs.meilisearch.com/errors#invalid_settings_search_cutoff_ms"
|
||||
}
|
||||
"###);
|
||||
|
||||
let (response, code) = index.update_settings_search_cutoff_ms(json!("doggo")).await;
|
||||
snapshot!(code, @"400 Bad Request");
|
||||
snapshot!(json_string!(response), @r###"
|
||||
{
|
||||
"message": "Invalid value type: expected a positive integer, but found a string: `\"doggo\"`",
|
||||
"code": "invalid_settings_search_cutoff_ms",
|
||||
"type": "invalid_request",
|
||||
"link": "https://docs.meilisearch.com/errors#invalid_settings_search_cutoff_ms"
|
||||
}
|
||||
"###);
|
||||
}
|
||||
|
@ -35,6 +35,7 @@ static DEFAULT_SETTINGS_VALUES: Lazy<HashMap<&'static str, Value>> = Lazy::new(|
|
||||
"maxTotalHits": json!(1000),
|
||||
}),
|
||||
);
|
||||
map.insert("search_cutoff_ms", json!(null));
|
||||
map
|
||||
});
|
||||
|
||||
@ -49,12 +50,12 @@ async fn get_settings_unexisting_index() {
|
||||
async fn get_settings() {
|
||||
let server = Server::new().await;
|
||||
let index = server.index("test");
|
||||
index.create(None).await;
|
||||
index.wait_task(0).await;
|
||||
let (response, _code) = index.create(None).await;
|
||||
index.wait_task(response.uid()).await;
|
||||
let (response, code) = index.settings().await;
|
||||
assert_eq!(code, 200);
|
||||
let settings = response.as_object().unwrap();
|
||||
assert_eq!(settings.keys().len(), 15);
|
||||
assert_eq!(settings.keys().len(), 16);
|
||||
assert_eq!(settings["displayedAttributes"], json!(["*"]));
|
||||
assert_eq!(settings["searchableAttributes"], json!(["*"]));
|
||||
assert_eq!(settings["filterableAttributes"], json!([]));
|
||||
@ -84,6 +85,7 @@ async fn get_settings() {
|
||||
})
|
||||
);
|
||||
assert_eq!(settings["proximityPrecision"], json!("byWord"));
|
||||
assert_eq!(settings["searchCutoffMs"], json!(null));
|
||||
}
|
||||
|
||||
#[actix_rt::test]
|
||||
@ -285,7 +287,8 @@ test_setting_routes!(
|
||||
ranking_rules put,
|
||||
synonyms put,
|
||||
pagination patch,
|
||||
faceting patch
|
||||
faceting patch,
|
||||
search_cutoff_ms put
|
||||
);
|
||||
|
||||
#[actix_rt::test]
|
||||
|
@ -17,7 +17,7 @@ bincode = "1.3.3"
|
||||
bstr = "1.9.0"
|
||||
bytemuck = { version = "1.14.0", features = ["extern_crate_alloc"] }
|
||||
byteorder = "1.5.0"
|
||||
charabia = { version = "0.8.7", default-features = false }
|
||||
charabia = { version = "0.8.8", default-features = false }
|
||||
concat-arrays = "0.1.2"
|
||||
crossbeam-channel = "0.5.11"
|
||||
deserr = "0.6.1"
|
||||
@ -80,17 +80,13 @@ tokenizers = { git = "https://github.com/huggingface/tokenizers.git", tag = "v0.
|
||||
hf-hub = { git = "https://github.com/dureuill/hf-hub.git", branch = "rust_tls", default_features = false, features = [
|
||||
"online",
|
||||
] }
|
||||
tokio = { version = "1.35.1", features = ["rt"] }
|
||||
futures = "0.3.30"
|
||||
reqwest = { version = "0.11.23", features = [
|
||||
"rustls-tls",
|
||||
"json",
|
||||
], default-features = false }
|
||||
tiktoken-rs = "0.5.8"
|
||||
liquid = "0.26.4"
|
||||
arroy = "0.2.0"
|
||||
rand = "0.8.5"
|
||||
tracing = "0.1.40"
|
||||
ureq = { version = "2.9.6", features = ["json"] }
|
||||
url = "2.5.0"
|
||||
|
||||
[dev-dependencies]
|
||||
mimalloc = { version = "0.1.39", default-features = false }
|
||||
|
@ -6,7 +6,7 @@ use std::time::Instant;
|
||||
use heed::EnvOpenOptions;
|
||||
use milli::{
|
||||
execute_search, filtered_universe, DefaultSearchLogger, GeoSortStrategy, Index, SearchContext,
|
||||
SearchLogger, TermsMatchingStrategy,
|
||||
SearchLogger, TermsMatchingStrategy, TimeBudget,
|
||||
};
|
||||
|
||||
#[global_allocator]
|
||||
@ -65,6 +65,7 @@ fn main() -> Result<(), Box<dyn Error>> {
|
||||
None,
|
||||
&mut DefaultSearchLogger,
|
||||
logger,
|
||||
TimeBudget::max(),
|
||||
)?;
|
||||
if let Some((logger, dir)) = detailed_logger {
|
||||
logger.finish(&mut ctx, Path::new(dir))?;
|
||||
|
@ -243,6 +243,8 @@ only composed of alphanumeric characters (a-z A-Z 0-9), hyphens (-) and undersco
|
||||
},
|
||||
#[error("`.embedders.{embedder_name}.dimensions`: `dimensions` cannot be zero")]
|
||||
InvalidSettingsDimensions { embedder_name: String },
|
||||
#[error("`.embedders.{embedder_name}.url`: could not parse `{url}`: {inner_error}")]
|
||||
InvalidUrl { embedder_name: String, inner_error: url::ParseError, url: String },
|
||||
}
|
||||
|
||||
impl From<crate::vector::Error> for Error {
|
||||
|
@ -20,13 +20,13 @@ use crate::heed_codec::facet::{
|
||||
use crate::heed_codec::{
|
||||
BEU16StrCodec, FstSetCodec, ScriptLanguageCodec, StrBEU16Codec, StrRefCodec,
|
||||
};
|
||||
use crate::order_by_map::OrderByMap;
|
||||
use crate::proximity::ProximityPrecision;
|
||||
use crate::vector::EmbeddingConfig;
|
||||
use crate::{
|
||||
default_criteria, CboRoaringBitmapCodec, Criterion, DocumentId, ExternalDocumentsIds,
|
||||
FacetDistribution, FieldDistribution, FieldId, FieldIdWordCountCodec, GeoPoint, ObkvCodec,
|
||||
OrderBy, Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16,
|
||||
BEU32, BEU64,
|
||||
Result, RoaringBitmapCodec, RoaringBitmapLenCodec, Search, U8StrStrCodec, BEU16, BEU32, BEU64,
|
||||
};
|
||||
|
||||
pub const DEFAULT_MIN_WORD_LEN_ONE_TYPO: u8 = 5;
|
||||
@ -67,6 +67,7 @@ pub mod main_key {
|
||||
pub const PAGINATION_MAX_TOTAL_HITS: &str = "pagination-max-total-hits";
|
||||
pub const PROXIMITY_PRECISION: &str = "proximity-precision";
|
||||
pub const EMBEDDING_CONFIGS: &str = "embedding_configs";
|
||||
pub const SEARCH_CUTOFF: &str = "search_cutoff";
|
||||
}
|
||||
|
||||
pub mod db_name {
|
||||
@ -1373,21 +1374,19 @@ impl Index {
|
||||
self.main.remap_key_type::<Str>().delete(txn, main_key::MAX_VALUES_PER_FACET)
|
||||
}
|
||||
|
||||
pub fn sort_facet_values_by(&self, txn: &RoTxn) -> heed::Result<HashMap<String, OrderBy>> {
|
||||
let mut orders = self
|
||||
pub fn sort_facet_values_by(&self, txn: &RoTxn) -> heed::Result<OrderByMap> {
|
||||
let orders = self
|
||||
.main
|
||||
.remap_types::<Str, SerdeJson<HashMap<String, OrderBy>>>()
|
||||
.remap_types::<Str, SerdeJson<OrderByMap>>()
|
||||
.get(txn, main_key::SORT_FACET_VALUES_BY)?
|
||||
.unwrap_or_default();
|
||||
// Insert the default ordering if it is not already overwritten by the user.
|
||||
orders.entry("*".to_string()).or_insert(OrderBy::Lexicographic);
|
||||
Ok(orders)
|
||||
}
|
||||
|
||||
pub(crate) fn put_sort_facet_values_by(
|
||||
&self,
|
||||
txn: &mut RwTxn,
|
||||
val: &HashMap<String, OrderBy>,
|
||||
val: &OrderByMap,
|
||||
) -> heed::Result<()> {
|
||||
self.main.remap_types::<Str, SerdeJson<_>>().put(txn, main_key::SORT_FACET_VALUES_BY, &val)
|
||||
}
|
||||
@ -1507,6 +1506,18 @@ impl Index {
|
||||
_ => "default".to_owned(),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) fn put_search_cutoff(&self, wtxn: &mut RwTxn<'_>, cutoff: u64) -> heed::Result<()> {
|
||||
self.main.remap_types::<Str, BEU64>().put(wtxn, main_key::SEARCH_CUTOFF, &cutoff)
|
||||
}
|
||||
|
||||
pub fn search_cutoff(&self, rtxn: &RoTxn<'_>) -> Result<Option<u64>> {
|
||||
Ok(self.main.remap_types::<Str, BEU64>().get(rtxn, main_key::SEARCH_CUTOFF)?)
|
||||
}
|
||||
|
||||
pub(crate) fn delete_search_cutoff(&self, wtxn: &mut RwTxn<'_>) -> heed::Result<bool> {
|
||||
self.main.remap_key_type::<Str>().delete(wtxn, main_key::SEARCH_CUTOFF)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@ -2423,6 +2434,7 @@ pub(crate) mod tests {
|
||||
candidates: _,
|
||||
document_scores: _,
|
||||
mut documents_ids,
|
||||
degraded: _,
|
||||
} = search.execute().unwrap();
|
||||
let primary_key_id = index.fields_ids_map(&rtxn).unwrap().id("primary_key").unwrap();
|
||||
documents_ids.sort_unstable();
|
||||
|
@ -16,6 +16,7 @@ pub mod facet;
|
||||
mod fields_ids_map;
|
||||
pub mod heed_codec;
|
||||
pub mod index;
|
||||
pub mod order_by_map;
|
||||
pub mod prompt;
|
||||
pub mod proximity;
|
||||
pub mod score_details;
|
||||
@ -29,6 +30,7 @@ pub mod snapshot_tests;
|
||||
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::convert::{TryFrom, TryInto};
|
||||
use std::fmt;
|
||||
use std::hash::BuildHasherDefault;
|
||||
|
||||
use charabia::normalizer::{CharNormalizer, CompatibilityDecompositionNormalizer};
|
||||
@ -56,10 +58,10 @@ pub use self::heed_codec::{
|
||||
UncheckedU8StrStrCodec,
|
||||
};
|
||||
pub use self::index::Index;
|
||||
pub use self::search::facet::{FacetValueHit, SearchForFacetValues};
|
||||
pub use self::search::{
|
||||
FacetDistribution, FacetValueHit, Filter, FormatOptions, MatchBounds, MatcherBuilder,
|
||||
MatchingWords, OrderBy, Search, SearchForFacetValues, SearchResult, TermsMatchingStrategy,
|
||||
DEFAULT_VALUES_PER_FACET,
|
||||
FacetDistribution, Filter, FormatOptions, MatchBounds, MatcherBuilder, MatchingWords, OrderBy,
|
||||
Search, SearchResult, TermsMatchingStrategy, DEFAULT_VALUES_PER_FACET,
|
||||
};
|
||||
|
||||
pub type Result<T> = std::result::Result<T, error::Error>;
|
||||
@ -103,6 +105,73 @@ pub const MAX_WORD_LENGTH: usize = MAX_LMDB_KEY_LENGTH / 2;
|
||||
|
||||
pub const MAX_POSITION_PER_ATTRIBUTE: u32 = u16::MAX as u32 + 1;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct TimeBudget {
|
||||
started_at: std::time::Instant,
|
||||
budget: std::time::Duration,
|
||||
|
||||
/// When testing the time budget, ensuring we did more than iteration of the bucket sort can be useful.
|
||||
/// But to avoid being flaky, the only option is to add the ability to stop after a specific number of calls instead of a `Duration`.
|
||||
#[cfg(test)]
|
||||
stop_after: Option<(std::sync::Arc<std::sync::atomic::AtomicUsize>, usize)>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for TimeBudget {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_struct("TimeBudget")
|
||||
.field("started_at", &self.started_at)
|
||||
.field("budget", &self.budget)
|
||||
.field("left", &(self.budget - self.started_at.elapsed()))
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TimeBudget {
|
||||
fn default() -> Self {
|
||||
Self::new(std::time::Duration::from_millis(150))
|
||||
}
|
||||
}
|
||||
|
||||
impl TimeBudget {
|
||||
pub fn new(budget: std::time::Duration) -> Self {
|
||||
Self {
|
||||
started_at: std::time::Instant::now(),
|
||||
budget,
|
||||
|
||||
#[cfg(test)]
|
||||
stop_after: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max() -> Self {
|
||||
Self::new(std::time::Duration::from_secs(u64::MAX))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn with_stop_after(mut self, stop_after: usize) -> Self {
|
||||
use std::sync::atomic::AtomicUsize;
|
||||
use std::sync::Arc;
|
||||
|
||||
self.stop_after = Some((Arc::new(AtomicUsize::new(0)), stop_after));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn exceeded(&self) -> bool {
|
||||
#[cfg(test)]
|
||||
if let Some((current, stop_after)) = &self.stop_after {
|
||||
let current = current.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
if current >= *stop_after {
|
||||
return true;
|
||||
} else {
|
||||
// if a number has been specified then we ignore entirely the time budget
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
self.started_at.elapsed() > self.budget
|
||||
}
|
||||
}
|
||||
|
||||
// Convert an absolute word position into a relative position.
|
||||
// Return the field id of the attribute related to the absolute position
|
||||
// and the relative position in the attribute.
|
||||
|
57
milli/src/order_by_map.rs
Normal file
57
milli/src/order_by_map.rs
Normal file
@ -0,0 +1,57 @@
|
||||
use std::collections::{hash_map, HashMap};
|
||||
use std::iter::FromIterator;
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
|
||||
use crate::OrderBy;
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct OrderByMap(HashMap<String, OrderBy>);
|
||||
|
||||
impl OrderByMap {
|
||||
pub fn get(&self, key: impl AsRef<str>) -> OrderBy {
|
||||
self.0
|
||||
.get(key.as_ref())
|
||||
.copied()
|
||||
.unwrap_or_else(|| self.0.get("*").copied().unwrap_or_default())
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, key: String, value: OrderBy) -> Option<OrderBy> {
|
||||
self.0.insert(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OrderByMap {
|
||||
fn default() -> Self {
|
||||
let mut map = HashMap::new();
|
||||
map.insert("*".to_string(), OrderBy::Lexicographic);
|
||||
OrderByMap(map)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromIterator<(String, OrderBy)> for OrderByMap {
|
||||
fn from_iter<T: IntoIterator<Item = (String, OrderBy)>>(iter: T) -> Self {
|
||||
OrderByMap(iter.into_iter().collect())
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoIterator for OrderByMap {
|
||||
type Item = (String, OrderBy);
|
||||
type IntoIter = hash_map::IntoIter<String, OrderBy>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
self.0.into_iter()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for OrderByMap {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let mut map = Deserialize::deserialize(deserializer).map(OrderByMap)?;
|
||||
// Insert the default ordering if it is not already overwritten by the user.
|
||||
map.0.entry("*".to_string()).or_insert(OrderBy::default());
|
||||
Ok(map)
|
||||
}
|
||||
}
|
@ -17,6 +17,9 @@ pub enum ScoreDetails {
|
||||
Sort(Sort),
|
||||
Vector(Vector),
|
||||
GeoSort(GeoSort),
|
||||
|
||||
/// Returned when we don't have the time to finish applying all the subsequent ranking-rules
|
||||
Skipped,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
@ -50,6 +53,7 @@ impl ScoreDetails {
|
||||
ScoreDetails::Sort(_) => None,
|
||||
ScoreDetails::GeoSort(_) => None,
|
||||
ScoreDetails::Vector(_) => None,
|
||||
ScoreDetails::Skipped => Some(Rank { rank: 0, max_rank: 1 }),
|
||||
}
|
||||
}
|
||||
|
||||
@ -97,6 +101,7 @@ impl ScoreDetails {
|
||||
ScoreDetails::Vector(vector) => RankOrValue::Score(
|
||||
vector.value_similarity.as_ref().map(|(_, s)| *s as f64).unwrap_or(0.0f64),
|
||||
),
|
||||
ScoreDetails::Skipped => RankOrValue::Rank(Rank { rank: 0, max_rank: 1 }),
|
||||
}
|
||||
}
|
||||
|
||||
@ -256,6 +261,11 @@ impl ScoreDetails {
|
||||
details_map.insert(vector, details);
|
||||
order += 1;
|
||||
}
|
||||
ScoreDetails::Skipped => {
|
||||
details_map
|
||||
.insert("skipped".to_string(), serde_json::json!({ "order": order }));
|
||||
order += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
details_map
|
||||
|
@ -168,7 +168,7 @@ impl<'t, 'b, 'bitmap> FacetRangeSearch<'t, 'b, 'bitmap> {
|
||||
}
|
||||
|
||||
// should we stop?
|
||||
// We should if the the search range doesn't include any
|
||||
// We should if the search range doesn't include any
|
||||
// element from the previous key or its successors
|
||||
let should_stop = {
|
||||
match self.right {
|
||||
@ -232,7 +232,7 @@ impl<'t, 'b, 'bitmap> FacetRangeSearch<'t, 'b, 'bitmap> {
|
||||
}
|
||||
|
||||
// should we stop?
|
||||
// We should if the the search range doesn't include any
|
||||
// We should if the search range doesn't include any
|
||||
// element from the previous key or its successors
|
||||
let should_stop = {
|
||||
match self.right {
|
||||
|
@ -6,15 +6,18 @@ use roaring::RoaringBitmap;
|
||||
|
||||
pub use self::facet_distribution::{FacetDistribution, OrderBy, DEFAULT_VALUES_PER_FACET};
|
||||
pub use self::filter::{BadGeoError, Filter};
|
||||
pub use self::search::{FacetValueHit, SearchForFacetValues};
|
||||
use crate::heed_codec::facet::{FacetGroupKeyCodec, FacetGroupValueCodec, OrderedF64Codec};
|
||||
use crate::heed_codec::BytesRefCodec;
|
||||
use crate::{Index, Result};
|
||||
|
||||
mod facet_distribution;
|
||||
mod facet_distribution_iter;
|
||||
mod facet_range_search;
|
||||
mod facet_sort_ascending;
|
||||
mod facet_sort_descending;
|
||||
mod filter;
|
||||
mod search;
|
||||
|
||||
fn facet_extreme_value<'t>(
|
||||
mut extreme_it: impl Iterator<Item = heed::Result<(RoaringBitmap, &'t [u8])>> + 't,
|
||||
|
326
milli/src/search/facet/search.rs
Normal file
326
milli/src/search/facet/search.rs
Normal file
@ -0,0 +1,326 @@
|
||||
use std::cmp::{Ordering, Reverse};
|
||||
use std::collections::BinaryHeap;
|
||||
use std::ops::ControlFlow;
|
||||
|
||||
use charabia::normalizer::NormalizerOption;
|
||||
use charabia::Normalize;
|
||||
use fst::automaton::{Automaton, Str};
|
||||
use fst::{IntoStreamer, Streamer};
|
||||
use roaring::RoaringBitmap;
|
||||
use tracing::error;
|
||||
|
||||
use crate::error::UserError;
|
||||
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
|
||||
use crate::search::build_dfa;
|
||||
use crate::{DocumentId, FieldId, OrderBy, Result, Search};
|
||||
|
||||
/// The maximum number of values per facet returned by the facet search route.
|
||||
const DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET: usize = 100;
|
||||
|
||||
pub struct SearchForFacetValues<'a> {
|
||||
query: Option<String>,
|
||||
facet: String,
|
||||
search_query: Search<'a>,
|
||||
max_values: usize,
|
||||
is_hybrid: bool,
|
||||
}
|
||||
|
||||
impl<'a> SearchForFacetValues<'a> {
|
||||
pub fn new(
|
||||
facet: String,
|
||||
search_query: Search<'a>,
|
||||
is_hybrid: bool,
|
||||
) -> SearchForFacetValues<'a> {
|
||||
SearchForFacetValues {
|
||||
query: None,
|
||||
facet,
|
||||
search_query,
|
||||
max_values: DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET,
|
||||
is_hybrid,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn query(&mut self, query: impl Into<String>) -> &mut Self {
|
||||
self.query = Some(query.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_values(&mut self, max: usize) -> &mut Self {
|
||||
self.max_values = max;
|
||||
self
|
||||
}
|
||||
|
||||
fn one_original_value_of(
|
||||
&self,
|
||||
field_id: FieldId,
|
||||
facet_str: &str,
|
||||
any_docid: DocumentId,
|
||||
) -> Result<Option<String>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
let key: (FieldId, _, &str) = (field_id, any_docid, facet_str);
|
||||
Ok(index.field_id_docid_facet_strings.get(rtxn, &key)?.map(|v| v.to_owned()))
|
||||
}
|
||||
|
||||
pub fn execute(&self) -> Result<Vec<FacetValueHit>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
|
||||
let filterable_fields = index.filterable_fields(rtxn)?;
|
||||
if !filterable_fields.contains(&self.facet) {
|
||||
let (valid_fields, hidden_fields) =
|
||||
index.remove_hidden_fields(rtxn, filterable_fields)?;
|
||||
|
||||
return Err(UserError::InvalidFacetSearchFacetName {
|
||||
field: self.facet.clone(),
|
||||
valid_fields,
|
||||
hidden_fields,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
let fields_ids_map = index.fields_ids_map(rtxn)?;
|
||||
let fid = match fields_ids_map.id(&self.facet) {
|
||||
Some(fid) => fid,
|
||||
// we return an empty list of results when the attribute has been
|
||||
// set as filterable but no document contains this field (yet).
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let fst = match self.search_query.index.facet_id_string_fst.get(rtxn, &fid)? {
|
||||
Some(fst) => fst,
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let search_candidates = self
|
||||
.search_query
|
||||
.execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
|
||||
|
||||
let mut results = match index.sort_facet_values_by(rtxn)?.get(&self.facet) {
|
||||
OrderBy::Lexicographic => ValuesCollection::by_lexicographic(self.max_values),
|
||||
OrderBy::Count => ValuesCollection::by_count(self.max_values),
|
||||
};
|
||||
|
||||
match self.query.as_ref() {
|
||||
Some(query) => {
|
||||
let options = NormalizerOption { lossy: true, ..Default::default() };
|
||||
let query = query.normalize(&options);
|
||||
let query = query.as_ref();
|
||||
|
||||
let authorize_typos = self.search_query.index.authorize_typos(rtxn)?;
|
||||
let field_authorizes_typos =
|
||||
!self.search_query.index.exact_attributes_ids(rtxn)?.contains(&fid);
|
||||
|
||||
if authorize_typos && field_authorizes_typos {
|
||||
let exact_words_fst = self.search_query.index.exact_words(rtxn)?;
|
||||
if exact_words_fst.map_or(false, |fst| fst.contains(query)) {
|
||||
if fst.contains(query) {
|
||||
self.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
query,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?;
|
||||
}
|
||||
} else {
|
||||
let one_typo = self.search_query.index.min_word_len_one_typo(rtxn)?;
|
||||
let two_typos = self.search_query.index.min_word_len_two_typos(rtxn)?;
|
||||
|
||||
let is_prefix = true;
|
||||
let automaton = if query.len() < one_typo as usize {
|
||||
build_dfa(query, 0, is_prefix)
|
||||
} else if query.len() < two_typos as usize {
|
||||
build_dfa(query, 1, is_prefix)
|
||||
} else {
|
||||
build_dfa(query, 2, is_prefix)
|
||||
};
|
||||
|
||||
let mut stream = fst.search(automaton).into_stream();
|
||||
while let Some(facet_value) = stream.next() {
|
||||
let value = std::str::from_utf8(facet_value)?;
|
||||
if self
|
||||
.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
value,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?
|
||||
.is_break()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let automaton = Str::new(query).starts_with();
|
||||
let mut stream = fst.search(automaton).into_stream();
|
||||
while let Some(facet_value) = stream.next() {
|
||||
let value = std::str::from_utf8(facet_value)?;
|
||||
if self
|
||||
.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
value,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?
|
||||
.is_break()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let prefix = FacetGroupKey { field_id: fid, level: 0, left_bound: "" };
|
||||
for result in index.facet_id_string_docids.prefix_iter(rtxn, &prefix)? {
|
||||
let (FacetGroupKey { left_bound, .. }, FacetGroupValue { bitmap, .. }) =
|
||||
result?;
|
||||
let count = search_candidates.intersection_len(&bitmap);
|
||||
if count != 0 {
|
||||
let value = self
|
||||
.one_original_value_of(fid, left_bound, bitmap.min().unwrap())?
|
||||
.unwrap_or_else(|| left_bound.to_string());
|
||||
if results.insert(FacetValueHit { value, count }).is_break() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results.into_sorted_vec())
|
||||
}
|
||||
|
||||
fn fetch_original_facets_using_normalized(
|
||||
&self,
|
||||
fid: FieldId,
|
||||
value: &str,
|
||||
query: &str,
|
||||
search_candidates: &RoaringBitmap,
|
||||
results: &mut ValuesCollection,
|
||||
) -> Result<ControlFlow<()>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
|
||||
let database = index.facet_id_normalized_string_strings;
|
||||
let key = (fid, value);
|
||||
let original_strings = match database.get(rtxn, &key)? {
|
||||
Some(original_strings) => original_strings,
|
||||
None => {
|
||||
error!("the facet value is missing from the facet database: {key:?}");
|
||||
return Ok(ControlFlow::Continue(()));
|
||||
}
|
||||
};
|
||||
for original in original_strings {
|
||||
let key = FacetGroupKey { field_id: fid, level: 0, left_bound: original.as_str() };
|
||||
let docids = match index.facet_id_string_docids.get(rtxn, &key)? {
|
||||
Some(FacetGroupValue { bitmap, .. }) => bitmap,
|
||||
None => {
|
||||
error!("the facet value is missing from the facet database: {key:?}");
|
||||
return Ok(ControlFlow::Continue(()));
|
||||
}
|
||||
};
|
||||
let count = search_candidates.intersection_len(&docids);
|
||||
if count != 0 {
|
||||
let value = self
|
||||
.one_original_value_of(fid, &original, docids.min().unwrap())?
|
||||
.unwrap_or_else(|| query.to_string());
|
||||
if results.insert(FacetValueHit { value, count }).is_break() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ControlFlow::Continue(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, PartialEq)]
|
||||
pub struct FacetValueHit {
|
||||
/// The original facet value
|
||||
pub value: String,
|
||||
/// The number of documents associated to this facet
|
||||
pub count: u64,
|
||||
}
|
||||
|
||||
impl PartialOrd for FacetValueHit {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for FacetValueHit {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
self.count.cmp(&other.count).then_with(|| self.value.cmp(&other.value))
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for FacetValueHit {}
|
||||
|
||||
/// A wrapper type that collects the best facet values by
|
||||
/// lexicographic or number of associated values.
|
||||
enum ValuesCollection {
|
||||
/// Keeps the top values according to the lexicographic order.
|
||||
Lexicographic { max: usize, content: Vec<FacetValueHit> },
|
||||
/// Keeps the top values according to the number of values associated to them.
|
||||
///
|
||||
/// Note that it is a max heap and we need to move the smallest counts
|
||||
/// at the top to be able to pop them when we reach the max_values limit.
|
||||
Count { max: usize, content: BinaryHeap<Reverse<FacetValueHit>> },
|
||||
}
|
||||
|
||||
impl ValuesCollection {
|
||||
pub fn by_lexicographic(max: usize) -> Self {
|
||||
ValuesCollection::Lexicographic { max, content: Vec::new() }
|
||||
}
|
||||
|
||||
pub fn by_count(max: usize) -> Self {
|
||||
ValuesCollection::Count { max, content: BinaryHeap::new() }
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, value: FacetValueHit) -> ControlFlow<()> {
|
||||
match self {
|
||||
ValuesCollection::Lexicographic { max, content } => {
|
||||
if content.len() < *max {
|
||||
content.push(value);
|
||||
if content.len() < *max {
|
||||
return ControlFlow::Continue(());
|
||||
}
|
||||
}
|
||||
ControlFlow::Break(())
|
||||
}
|
||||
ValuesCollection::Count { max, content } => {
|
||||
if content.len() == *max {
|
||||
// Peeking gives us the worst value in the list as
|
||||
// this is a max-heap and we reversed it.
|
||||
let Some(mut peek) = content.peek_mut() else { return ControlFlow::Break(()) };
|
||||
if peek.0.count <= value.count {
|
||||
// Replace the current worst value in the heap
|
||||
// with the new one we received that is better.
|
||||
*peek = Reverse(value);
|
||||
}
|
||||
} else {
|
||||
content.push(Reverse(value));
|
||||
}
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the list of facet values in descending order of, either,
|
||||
/// count or lexicographic order of the value depending on the type.
|
||||
pub fn into_sorted_vec(self) -> Vec<FacetValueHit> {
|
||||
match self {
|
||||
ValuesCollection::Lexicographic { content, .. } => content.into_iter().collect(),
|
||||
ValuesCollection::Count { content, .. } => {
|
||||
// Convert the heap into a vec of hits by removing the Reverse wrapper.
|
||||
// Hits are already in the right order as they were reversed and there
|
||||
// are output in ascending order.
|
||||
content.into_sorted_vec().into_iter().map(|Reverse(hit)| hit).collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -10,6 +10,7 @@ struct ScoreWithRatioResult {
|
||||
matching_words: MatchingWords,
|
||||
candidates: RoaringBitmap,
|
||||
document_scores: Vec<(u32, ScoreWithRatio)>,
|
||||
degraded: bool,
|
||||
}
|
||||
|
||||
type ScoreWithRatio = (Vec<ScoreDetails>, f32);
|
||||
@ -49,8 +50,12 @@ fn compare_scores(
|
||||
order => return order,
|
||||
}
|
||||
}
|
||||
(Some(ScoreValue::Score(_)), Some(_)) => return Ordering::Greater,
|
||||
(Some(_), Some(ScoreValue::Score(_))) => return Ordering::Less,
|
||||
(Some(ScoreValue::Score(x)), Some(_)) => {
|
||||
return if x == 0. { Ordering::Less } else { Ordering::Greater }
|
||||
}
|
||||
(Some(_), Some(ScoreValue::Score(x))) => {
|
||||
return if x == 0. { Ordering::Greater } else { Ordering::Less }
|
||||
}
|
||||
// if we have this, we're bad
|
||||
(Some(ScoreValue::GeoSort(_)), Some(ScoreValue::Sort(_)))
|
||||
| (Some(ScoreValue::Sort(_)), Some(ScoreValue::GeoSort(_))) => {
|
||||
@ -72,6 +77,7 @@ impl ScoreWithRatioResult {
|
||||
matching_words: results.matching_words,
|
||||
candidates: results.candidates,
|
||||
document_scores,
|
||||
degraded: results.degraded,
|
||||
}
|
||||
}
|
||||
|
||||
@ -106,6 +112,7 @@ impl ScoreWithRatioResult {
|
||||
candidates: left.candidates | right.candidates,
|
||||
documents_ids,
|
||||
document_scores,
|
||||
degraded: left.degraded | right.degraded,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -131,6 +138,7 @@ impl<'a> Search<'a> {
|
||||
index: self.index,
|
||||
distribution_shift: self.distribution_shift,
|
||||
embedder_name: self.embedder_name.clone(),
|
||||
time_budget: self.time_budget.clone(),
|
||||
};
|
||||
|
||||
let vector_query = search.vector.take();
|
||||
|
@ -1,25 +1,17 @@
|
||||
use std::fmt;
|
||||
use std::ops::ControlFlow;
|
||||
|
||||
use charabia::normalizer::NormalizerOption;
|
||||
use charabia::Normalize;
|
||||
use fst::automaton::{Automaton, Str};
|
||||
use fst::{IntoStreamer, Streamer};
|
||||
use levenshtein_automata::{LevenshteinAutomatonBuilder as LevBuilder, DFA};
|
||||
use once_cell::sync::Lazy;
|
||||
use roaring::bitmap::RoaringBitmap;
|
||||
use tracing::error;
|
||||
|
||||
pub use self::facet::{FacetDistribution, Filter, OrderBy, DEFAULT_VALUES_PER_FACET};
|
||||
pub use self::new::matches::{FormatOptions, MatchBounds, MatcherBuilder, MatchingWords};
|
||||
use self::new::{execute_vector_search, PartialSearchResult};
|
||||
use crate::error::UserError;
|
||||
use crate::heed_codec::facet::{FacetGroupKey, FacetGroupValue};
|
||||
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use crate::vector::DistributionShift;
|
||||
use crate::{
|
||||
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, FieldId, Index,
|
||||
Result, SearchContext,
|
||||
execute_search, filtered_universe, AscDesc, DefaultSearchLogger, DocumentId, Index, Result,
|
||||
SearchContext, TimeBudget,
|
||||
};
|
||||
|
||||
// Building these factories is not free.
|
||||
@ -27,9 +19,6 @@ static LEVDIST0: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(0, true));
|
||||
static LEVDIST1: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(1, true));
|
||||
static LEVDIST2: Lazy<LevBuilder> = Lazy::new(|| LevBuilder::new(2, true));
|
||||
|
||||
/// The maximum number of values per facet returned by the facet search route.
|
||||
const DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET: usize = 100;
|
||||
|
||||
pub mod facet;
|
||||
mod fst_utils;
|
||||
pub mod hybrid;
|
||||
@ -54,6 +43,8 @@ pub struct Search<'a> {
|
||||
index: &'a Index,
|
||||
distribution_shift: Option<DistributionShift>,
|
||||
embedder_name: Option<String>,
|
||||
|
||||
time_budget: TimeBudget,
|
||||
}
|
||||
|
||||
impl<'a> Search<'a> {
|
||||
@ -75,6 +66,7 @@ impl<'a> Search<'a> {
|
||||
index,
|
||||
distribution_shift: None,
|
||||
embedder_name: None,
|
||||
time_budget: TimeBudget::max(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -154,6 +146,11 @@ impl<'a> Search<'a> {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn time_budget(&mut self, time_budget: TimeBudget) -> &mut Search<'a> {
|
||||
self.time_budget = time_budget;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn execute_for_candidates(&self, has_vector_search: bool) -> Result<RoaringBitmap> {
|
||||
if has_vector_search {
|
||||
let ctx = SearchContext::new(self.index, self.rtxn);
|
||||
@ -180,36 +177,43 @@ impl<'a> Search<'a> {
|
||||
}
|
||||
|
||||
let universe = filtered_universe(&ctx, &self.filter)?;
|
||||
let PartialSearchResult { located_query_terms, candidates, documents_ids, document_scores } =
|
||||
match self.vector.as_ref() {
|
||||
Some(vector) => execute_vector_search(
|
||||
&mut ctx,
|
||||
vector,
|
||||
self.scoring_strategy,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
self.offset,
|
||||
self.limit,
|
||||
self.distribution_shift,
|
||||
embedder_name,
|
||||
)?,
|
||||
None => execute_search(
|
||||
&mut ctx,
|
||||
self.query.as_deref(),
|
||||
self.terms_matching_strategy,
|
||||
self.scoring_strategy,
|
||||
self.exhaustive_number_hits,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
self.offset,
|
||||
self.limit,
|
||||
Some(self.words_limit),
|
||||
&mut DefaultSearchLogger,
|
||||
&mut DefaultSearchLogger,
|
||||
)?,
|
||||
};
|
||||
let PartialSearchResult {
|
||||
located_query_terms,
|
||||
candidates,
|
||||
documents_ids,
|
||||
document_scores,
|
||||
degraded,
|
||||
} = match self.vector.as_ref() {
|
||||
Some(vector) => execute_vector_search(
|
||||
&mut ctx,
|
||||
vector,
|
||||
self.scoring_strategy,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
self.offset,
|
||||
self.limit,
|
||||
self.distribution_shift,
|
||||
embedder_name,
|
||||
self.time_budget.clone(),
|
||||
)?,
|
||||
None => execute_search(
|
||||
&mut ctx,
|
||||
self.query.as_deref(),
|
||||
self.terms_matching_strategy,
|
||||
self.scoring_strategy,
|
||||
self.exhaustive_number_hits,
|
||||
universe,
|
||||
&self.sort_criteria,
|
||||
self.geo_strategy,
|
||||
self.offset,
|
||||
self.limit,
|
||||
Some(self.words_limit),
|
||||
&mut DefaultSearchLogger,
|
||||
&mut DefaultSearchLogger,
|
||||
self.time_budget.clone(),
|
||||
)?,
|
||||
};
|
||||
|
||||
// consume context and located_query_terms to build MatchingWords.
|
||||
let matching_words = match located_query_terms {
|
||||
@ -217,7 +221,7 @@ impl<'a> Search<'a> {
|
||||
None => MatchingWords::default(),
|
||||
};
|
||||
|
||||
Ok(SearchResult { matching_words, candidates, document_scores, documents_ids })
|
||||
Ok(SearchResult { matching_words, candidates, document_scores, documents_ids, degraded })
|
||||
}
|
||||
}
|
||||
|
||||
@ -240,6 +244,7 @@ impl fmt::Debug for Search<'_> {
|
||||
index: _,
|
||||
distribution_shift,
|
||||
embedder_name,
|
||||
time_budget,
|
||||
} = self;
|
||||
f.debug_struct("Search")
|
||||
.field("query", query)
|
||||
@ -255,6 +260,7 @@ impl fmt::Debug for Search<'_> {
|
||||
.field("words_limit", words_limit)
|
||||
.field("distribution_shift", distribution_shift)
|
||||
.field("embedder_name", embedder_name)
|
||||
.field("time_budget", time_budget)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@ -265,6 +271,7 @@ pub struct SearchResult {
|
||||
pub candidates: RoaringBitmap,
|
||||
pub documents_ids: Vec<DocumentId>,
|
||||
pub document_scores: Vec<Vec<ScoreDetails>>,
|
||||
pub degraded: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
@ -302,240 +309,6 @@ pub fn build_dfa(word: &str, typos: u8, is_prefix: bool) -> DFA {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct SearchForFacetValues<'a> {
|
||||
query: Option<String>,
|
||||
facet: String,
|
||||
search_query: Search<'a>,
|
||||
max_values: usize,
|
||||
is_hybrid: bool,
|
||||
}
|
||||
|
||||
impl<'a> SearchForFacetValues<'a> {
|
||||
pub fn new(
|
||||
facet: String,
|
||||
search_query: Search<'a>,
|
||||
is_hybrid: bool,
|
||||
) -> SearchForFacetValues<'a> {
|
||||
SearchForFacetValues {
|
||||
query: None,
|
||||
facet,
|
||||
search_query,
|
||||
max_values: DEFAULT_MAX_NUMBER_OF_VALUES_PER_FACET,
|
||||
is_hybrid,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn query(&mut self, query: impl Into<String>) -> &mut Self {
|
||||
self.query = Some(query.into());
|
||||
self
|
||||
}
|
||||
|
||||
pub fn max_values(&mut self, max: usize) -> &mut Self {
|
||||
self.max_values = max;
|
||||
self
|
||||
}
|
||||
|
||||
fn one_original_value_of(
|
||||
&self,
|
||||
field_id: FieldId,
|
||||
facet_str: &str,
|
||||
any_docid: DocumentId,
|
||||
) -> Result<Option<String>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
let key: (FieldId, _, &str) = (field_id, any_docid, facet_str);
|
||||
Ok(index.field_id_docid_facet_strings.get(rtxn, &key)?.map(|v| v.to_owned()))
|
||||
}
|
||||
|
||||
pub fn execute(&self) -> Result<Vec<FacetValueHit>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
|
||||
let filterable_fields = index.filterable_fields(rtxn)?;
|
||||
if !filterable_fields.contains(&self.facet) {
|
||||
let (valid_fields, hidden_fields) =
|
||||
index.remove_hidden_fields(rtxn, filterable_fields)?;
|
||||
|
||||
return Err(UserError::InvalidFacetSearchFacetName {
|
||||
field: self.facet.clone(),
|
||||
valid_fields,
|
||||
hidden_fields,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
let fields_ids_map = index.fields_ids_map(rtxn)?;
|
||||
let fid = match fields_ids_map.id(&self.facet) {
|
||||
Some(fid) => fid,
|
||||
// we return an empty list of results when the attribute has been
|
||||
// set as filterable but no document contains this field (yet).
|
||||
None => return Ok(Vec::new()),
|
||||
};
|
||||
|
||||
let fst = match self.search_query.index.facet_id_string_fst.get(rtxn, &fid)? {
|
||||
Some(fst) => fst,
|
||||
None => return Ok(vec![]),
|
||||
};
|
||||
|
||||
let search_candidates = self
|
||||
.search_query
|
||||
.execute_for_candidates(self.is_hybrid || self.search_query.vector.is_some())?;
|
||||
|
||||
match self.query.as_ref() {
|
||||
Some(query) => {
|
||||
let options = NormalizerOption { lossy: true, ..Default::default() };
|
||||
let query = query.normalize(&options);
|
||||
let query = query.as_ref();
|
||||
|
||||
let authorize_typos = self.search_query.index.authorize_typos(rtxn)?;
|
||||
let field_authorizes_typos =
|
||||
!self.search_query.index.exact_attributes_ids(rtxn)?.contains(&fid);
|
||||
|
||||
if authorize_typos && field_authorizes_typos {
|
||||
let exact_words_fst = self.search_query.index.exact_words(rtxn)?;
|
||||
if exact_words_fst.map_or(false, |fst| fst.contains(query)) {
|
||||
let mut results = vec![];
|
||||
if fst.contains(query) {
|
||||
self.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
query,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?;
|
||||
}
|
||||
Ok(results)
|
||||
} else {
|
||||
let one_typo = self.search_query.index.min_word_len_one_typo(rtxn)?;
|
||||
let two_typos = self.search_query.index.min_word_len_two_typos(rtxn)?;
|
||||
|
||||
let is_prefix = true;
|
||||
let automaton = if query.len() < one_typo as usize {
|
||||
build_dfa(query, 0, is_prefix)
|
||||
} else if query.len() < two_typos as usize {
|
||||
build_dfa(query, 1, is_prefix)
|
||||
} else {
|
||||
build_dfa(query, 2, is_prefix)
|
||||
};
|
||||
|
||||
let mut stream = fst.search(automaton).into_stream();
|
||||
let mut results = vec![];
|
||||
while let Some(facet_value) = stream.next() {
|
||||
let value = std::str::from_utf8(facet_value)?;
|
||||
if self
|
||||
.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
value,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?
|
||||
.is_break()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
} else {
|
||||
let automaton = Str::new(query).starts_with();
|
||||
let mut stream = fst.search(automaton).into_stream();
|
||||
let mut results = vec![];
|
||||
while let Some(facet_value) = stream.next() {
|
||||
let value = std::str::from_utf8(facet_value)?;
|
||||
if self
|
||||
.fetch_original_facets_using_normalized(
|
||||
fid,
|
||||
value,
|
||||
query,
|
||||
&search_candidates,
|
||||
&mut results,
|
||||
)?
|
||||
.is_break()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
None => {
|
||||
let mut results = vec![];
|
||||
let prefix = FacetGroupKey { field_id: fid, level: 0, left_bound: "" };
|
||||
for result in index.facet_id_string_docids.prefix_iter(rtxn, &prefix)? {
|
||||
let (FacetGroupKey { left_bound, .. }, FacetGroupValue { bitmap, .. }) =
|
||||
result?;
|
||||
let count = search_candidates.intersection_len(&bitmap);
|
||||
if count != 0 {
|
||||
let value = self
|
||||
.one_original_value_of(fid, left_bound, bitmap.min().unwrap())?
|
||||
.unwrap_or_else(|| left_bound.to_string());
|
||||
results.push(FacetValueHit { value, count });
|
||||
}
|
||||
if results.len() >= self.max_values {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn fetch_original_facets_using_normalized(
|
||||
&self,
|
||||
fid: FieldId,
|
||||
value: &str,
|
||||
query: &str,
|
||||
search_candidates: &RoaringBitmap,
|
||||
results: &mut Vec<FacetValueHit>,
|
||||
) -> Result<ControlFlow<()>> {
|
||||
let index = self.search_query.index;
|
||||
let rtxn = self.search_query.rtxn;
|
||||
|
||||
let database = index.facet_id_normalized_string_strings;
|
||||
let key = (fid, value);
|
||||
let original_strings = match database.get(rtxn, &key)? {
|
||||
Some(original_strings) => original_strings,
|
||||
None => {
|
||||
error!("the facet value is missing from the facet database: {key:?}");
|
||||
return Ok(ControlFlow::Continue(()));
|
||||
}
|
||||
};
|
||||
for original in original_strings {
|
||||
let key = FacetGroupKey { field_id: fid, level: 0, left_bound: original.as_str() };
|
||||
let docids = match index.facet_id_string_docids.get(rtxn, &key)? {
|
||||
Some(FacetGroupValue { bitmap, .. }) => bitmap,
|
||||
None => {
|
||||
error!("the facet value is missing from the facet database: {key:?}");
|
||||
return Ok(ControlFlow::Continue(()));
|
||||
}
|
||||
};
|
||||
let count = search_candidates.intersection_len(&docids);
|
||||
if count != 0 {
|
||||
let value = self
|
||||
.one_original_value_of(fid, &original, docids.min().unwrap())?
|
||||
.unwrap_or_else(|| query.to_string());
|
||||
results.push(FacetValueHit { value, count });
|
||||
}
|
||||
if results.len() >= self.max_values {
|
||||
return Ok(ControlFlow::Break(()));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ControlFlow::Continue(()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, serde::Serialize, PartialEq)]
|
||||
pub struct FacetValueHit {
|
||||
/// The original facet value
|
||||
pub value: String,
|
||||
/// The number of documents associated to this facet
|
||||
pub count: u64,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
#[allow(unused_imports)]
|
||||
|
@ -5,12 +5,14 @@ use super::ranking_rules::{BoxRankingRule, RankingRuleQueryTrait};
|
||||
use super::SearchContext;
|
||||
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use crate::search::new::distinct::{apply_distinct_rule, distinct_single_docid, DistinctOutput};
|
||||
use crate::Result;
|
||||
use crate::{Result, TimeBudget};
|
||||
|
||||
pub struct BucketSortOutput {
|
||||
pub docids: Vec<u32>,
|
||||
pub scores: Vec<Vec<ScoreDetails>>,
|
||||
pub all_candidates: RoaringBitmap,
|
||||
|
||||
pub degraded: bool,
|
||||
}
|
||||
|
||||
// TODO: would probably be good to regroup some of these inside of a struct?
|
||||
@ -25,6 +27,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|
||||
length: usize,
|
||||
scoring_strategy: ScoringStrategy,
|
||||
logger: &mut dyn SearchLogger<Q>,
|
||||
time_budget: TimeBudget,
|
||||
) -> Result<BucketSortOutput> {
|
||||
logger.initial_query(query);
|
||||
logger.ranking_rules(&ranking_rules);
|
||||
@ -41,6 +44,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|
||||
docids: vec![],
|
||||
scores: vec![],
|
||||
all_candidates: universe.clone(),
|
||||
degraded: false,
|
||||
});
|
||||
}
|
||||
if ranking_rules.is_empty() {
|
||||
@ -74,6 +78,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|
||||
scores: vec![Default::default(); results.len()],
|
||||
docids: results,
|
||||
all_candidates,
|
||||
degraded: false,
|
||||
});
|
||||
} else {
|
||||
let docids: Vec<u32> = universe.iter().skip(from).take(length).collect();
|
||||
@ -81,6 +86,7 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|
||||
scores: vec![Default::default(); docids.len()],
|
||||
docids,
|
||||
all_candidates: universe.clone(),
|
||||
degraded: false,
|
||||
});
|
||||
};
|
||||
}
|
||||
@ -154,6 +160,28 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|
||||
}
|
||||
|
||||
while valid_docids.len() < length {
|
||||
if time_budget.exceeded() {
|
||||
loop {
|
||||
let bucket = std::mem::take(&mut ranking_rule_universes[cur_ranking_rule_index]);
|
||||
ranking_rule_scores.push(ScoreDetails::Skipped);
|
||||
maybe_add_to_results!(bucket);
|
||||
ranking_rule_scores.pop();
|
||||
|
||||
if cur_ranking_rule_index == 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
back!();
|
||||
}
|
||||
|
||||
return Ok(BucketSortOutput {
|
||||
scores: valid_scores,
|
||||
docids: valid_docids,
|
||||
all_candidates,
|
||||
degraded: true,
|
||||
});
|
||||
}
|
||||
|
||||
// The universe for this bucket is zero, so we don't need to sort
|
||||
// anything, just go back to the parent ranking rule.
|
||||
if ranking_rule_universes[cur_ranking_rule_index].is_empty()
|
||||
@ -219,7 +247,12 @@ pub fn bucket_sort<'ctx, Q: RankingRuleQueryTrait>(
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(BucketSortOutput { docids: valid_docids, scores: valid_scores, all_candidates })
|
||||
Ok(BucketSortOutput {
|
||||
docids: valid_docids,
|
||||
scores: valid_scores,
|
||||
all_candidates,
|
||||
degraded: false,
|
||||
})
|
||||
}
|
||||
|
||||
/// Add the candidates to the results. Take `distinct`, `from`, `length`, and `cur_offset`
|
||||
|
@ -502,7 +502,7 @@ mod tests {
|
||||
|
||||
use super::*;
|
||||
use crate::index::tests::TempIndex;
|
||||
use crate::{execute_search, filtered_universe, SearchContext};
|
||||
use crate::{execute_search, filtered_universe, SearchContext, TimeBudget};
|
||||
|
||||
impl<'a> MatcherBuilder<'a> {
|
||||
fn new_test(rtxn: &'a heed::RoTxn, index: &'a TempIndex, query: &str) -> Self {
|
||||
@ -522,6 +522,7 @@ mod tests {
|
||||
Some(10),
|
||||
&mut crate::DefaultSearchLogger,
|
||||
&mut crate::DefaultSearchLogger,
|
||||
TimeBudget::max(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
|
@ -52,7 +52,8 @@ use crate::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use crate::search::new::distinct::apply_distinct_rule;
|
||||
use crate::vector::DistributionShift;
|
||||
use crate::{
|
||||
AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, UserError,
|
||||
AscDesc, DocumentId, FieldId, Filter, Index, Member, Result, TermsMatchingStrategy, TimeBudget,
|
||||
UserError,
|
||||
};
|
||||
|
||||
/// A structure used throughout the execution of a search query.
|
||||
@ -518,6 +519,7 @@ pub fn execute_vector_search(
|
||||
length: usize,
|
||||
distribution_shift: Option<DistributionShift>,
|
||||
embedder_name: &str,
|
||||
time_budget: TimeBudget,
|
||||
) -> Result<PartialSearchResult> {
|
||||
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
||||
|
||||
@ -537,7 +539,7 @@ pub fn execute_vector_search(
|
||||
let placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery> =
|
||||
&mut placeholder_search_logger;
|
||||
|
||||
let BucketSortOutput { docids, scores, all_candidates } = bucket_sort(
|
||||
let BucketSortOutput { docids, scores, all_candidates, degraded } = bucket_sort(
|
||||
ctx,
|
||||
ranking_rules,
|
||||
&PlaceholderQuery,
|
||||
@ -546,6 +548,7 @@ pub fn execute_vector_search(
|
||||
length,
|
||||
scoring_strategy,
|
||||
placeholder_search_logger,
|
||||
time_budget,
|
||||
)?;
|
||||
|
||||
Ok(PartialSearchResult {
|
||||
@ -553,6 +556,7 @@ pub fn execute_vector_search(
|
||||
document_scores: scores,
|
||||
documents_ids: docids,
|
||||
located_query_terms: None,
|
||||
degraded,
|
||||
})
|
||||
}
|
||||
|
||||
@ -572,6 +576,7 @@ pub fn execute_search(
|
||||
words_limit: Option<usize>,
|
||||
placeholder_search_logger: &mut dyn SearchLogger<PlaceholderQuery>,
|
||||
query_graph_logger: &mut dyn SearchLogger<QueryGraph>,
|
||||
time_budget: TimeBudget,
|
||||
) -> Result<PartialSearchResult> {
|
||||
check_sort_criteria(ctx, sort_criteria.as_ref())?;
|
||||
|
||||
@ -648,6 +653,7 @@ pub fn execute_search(
|
||||
length,
|
||||
scoring_strategy,
|
||||
query_graph_logger,
|
||||
time_budget,
|
||||
)?
|
||||
} else {
|
||||
let ranking_rules =
|
||||
@ -661,10 +667,11 @@ pub fn execute_search(
|
||||
length,
|
||||
scoring_strategy,
|
||||
placeholder_search_logger,
|
||||
time_budget,
|
||||
)?
|
||||
};
|
||||
|
||||
let BucketSortOutput { docids, scores, mut all_candidates } = bucket_sort_output;
|
||||
let BucketSortOutput { docids, scores, mut all_candidates, degraded } = bucket_sort_output;
|
||||
let fields_ids_map = ctx.index.fields_ids_map(ctx.txn)?;
|
||||
|
||||
// The candidates is the universe unless the exhaustive number of hits
|
||||
@ -682,6 +689,7 @@ pub fn execute_search(
|
||||
document_scores: scores,
|
||||
documents_ids: docids,
|
||||
located_query_terms,
|
||||
degraded,
|
||||
})
|
||||
}
|
||||
|
||||
@ -742,4 +750,6 @@ pub struct PartialSearchResult {
|
||||
pub candidates: RoaringBitmap,
|
||||
pub documents_ids: Vec<DocumentId>,
|
||||
pub document_scores: Vec<Vec<ScoreDetails>>,
|
||||
|
||||
pub degraded: bool,
|
||||
}
|
||||
|
429
milli/src/search/new/tests/cutoff.rs
Normal file
429
milli/src/search/new/tests/cutoff.rs
Normal file
@ -0,0 +1,429 @@
|
||||
//! This module test the search cutoff and ensure a few things:
|
||||
//! 1. A basic test works and mark the search as degraded
|
||||
//! 2. A test that ensure the filters are affectively applied even with a cutoff of 0
|
||||
//! 3. A test that ensure the cutoff works well with the ranking scores
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use big_s::S;
|
||||
use maplit::hashset;
|
||||
use meili_snap::snapshot;
|
||||
|
||||
use crate::index::tests::TempIndex;
|
||||
use crate::score_details::{ScoreDetails, ScoringStrategy};
|
||||
use crate::{Criterion, Filter, Search, TimeBudget};
|
||||
|
||||
fn create_index() -> TempIndex {
|
||||
let index = TempIndex::new();
|
||||
|
||||
index
|
||||
.update_settings(|s| {
|
||||
s.set_primary_key("id".to_owned());
|
||||
s.set_searchable_fields(vec!["text".to_owned()]);
|
||||
s.set_filterable_fields(hashset! { S("id") });
|
||||
s.set_criteria(vec![Criterion::Words, Criterion::Typo]);
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// reverse the ID / insertion order so we see better what was sorted from what got the insertion order ordering
|
||||
index
|
||||
.add_documents(documents!([
|
||||
{
|
||||
"id": 4,
|
||||
"text": "hella puppo kefir",
|
||||
},
|
||||
{
|
||||
"id": 3,
|
||||
"text": "hella puppy kefir",
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"text": "hello",
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"text": "hello puppy",
|
||||
},
|
||||
{
|
||||
"id": 0,
|
||||
"text": "hello puppy kefir",
|
||||
},
|
||||
]))
|
||||
.unwrap();
|
||||
index
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn basic_degraded_search() {
|
||||
let index = create_index();
|
||||
let rtxn = index.read_txn().unwrap();
|
||||
|
||||
let mut search = Search::new(&rtxn, &index);
|
||||
search.query("hello puppy kefir");
|
||||
search.limit(3);
|
||||
search.time_budget(TimeBudget::new(Duration::from_millis(0)));
|
||||
|
||||
let result = search.execute().unwrap();
|
||||
assert!(result.degraded);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn degraded_search_cannot_skip_filter() {
|
||||
let index = create_index();
|
||||
let rtxn = index.read_txn().unwrap();
|
||||
|
||||
let mut search = Search::new(&rtxn, &index);
|
||||
search.query("hello puppy kefir");
|
||||
search.limit(100);
|
||||
search.time_budget(TimeBudget::new(Duration::from_millis(0)));
|
||||
let filter_condition = Filter::from_str("id > 2").unwrap().unwrap();
|
||||
search.filter(filter_condition);
|
||||
|
||||
let result = search.execute().unwrap();
|
||||
assert!(result.degraded);
|
||||
snapshot!(format!("{:?}\n{:?}", result.candidates, result.documents_ids), @r###"
|
||||
RoaringBitmap<[0, 1]>
|
||||
[0, 1]
|
||||
"###);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(clippy::format_collect)] // the test is already quite big
|
||||
fn degraded_search_and_score_details() {
|
||||
let index = create_index();
|
||||
let rtxn = index.read_txn().unwrap();
|
||||
|
||||
let mut search = Search::new(&rtxn, &index);
|
||||
search.query("hello puppy kefir");
|
||||
search.limit(4);
|
||||
search.scoring_strategy(ScoringStrategy::Detailed);
|
||||
search.time_budget(TimeBudget::max());
|
||||
|
||||
let result = search.execute().unwrap();
|
||||
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
|
||||
IDs: [4, 1, 0, 3]
|
||||
Scores: 1.0000 0.9167 0.8333 0.6667
|
||||
Score Details:
|
||||
[
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 0,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 1,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 2,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 2,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 0,
|
||||
max_typo_count: 2,
|
||||
},
|
||||
),
|
||||
],
|
||||
]
|
||||
"###);
|
||||
|
||||
// Do ONE loop iteration. Not much can be deduced, almost everyone matched the words first bucket.
|
||||
search.time_budget(TimeBudget::max().with_stop_after(1));
|
||||
|
||||
let result = search.execute().unwrap();
|
||||
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
|
||||
IDs: [0, 1, 4, 2]
|
||||
Scores: 0.6667 0.6667 0.6667 0.0000
|
||||
Score Details:
|
||||
[
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Skipped,
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Skipped,
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Skipped,
|
||||
],
|
||||
[
|
||||
Skipped,
|
||||
],
|
||||
]
|
||||
"###);
|
||||
|
||||
// Do TWO loop iterations. The first document should be entirely sorted
|
||||
search.time_budget(TimeBudget::max().with_stop_after(2));
|
||||
|
||||
let result = search.execute().unwrap();
|
||||
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
|
||||
IDs: [4, 0, 1, 2]
|
||||
Scores: 1.0000 0.6667 0.6667 0.0000
|
||||
Score Details:
|
||||
[
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 0,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Skipped,
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Skipped,
|
||||
],
|
||||
[
|
||||
Skipped,
|
||||
],
|
||||
]
|
||||
"###);
|
||||
|
||||
// Do THREE loop iterations. The second document should be entirely sorted as well
|
||||
search.time_budget(TimeBudget::max().with_stop_after(3));
|
||||
|
||||
let result = search.execute().unwrap();
|
||||
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
|
||||
IDs: [4, 1, 0, 2]
|
||||
Scores: 1.0000 0.9167 0.6667 0.0000
|
||||
Score Details:
|
||||
[
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 0,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 1,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Skipped,
|
||||
],
|
||||
[
|
||||
Skipped,
|
||||
],
|
||||
]
|
||||
"###);
|
||||
|
||||
// Do FOUR loop iterations. The third document should be entirely sorted as well
|
||||
// The words bucket have still not progressed thus the last document doesn't have any info yet.
|
||||
search.time_budget(TimeBudget::max().with_stop_after(4));
|
||||
|
||||
let result = search.execute().unwrap();
|
||||
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
|
||||
IDs: [4, 1, 0, 2]
|
||||
Scores: 1.0000 0.9167 0.8333 0.0000
|
||||
Score Details:
|
||||
[
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 0,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 1,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 2,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Skipped,
|
||||
],
|
||||
]
|
||||
"###);
|
||||
|
||||
// After SIX loop iteration. The words ranking rule gave us a new bucket.
|
||||
// Since we reached the limit we were able to early exit without checking the typo ranking rule.
|
||||
search.time_budget(TimeBudget::max().with_stop_after(6));
|
||||
|
||||
let result = search.execute().unwrap();
|
||||
snapshot!(format!("IDs: {:?}\nScores: {}\nScore Details:\n{:#?}", result.documents_ids, result.document_scores.iter().map(|scores| format!("{:.4} ", ScoreDetails::global_score(scores.iter()))).collect::<String>(), result.document_scores), @r###"
|
||||
IDs: [4, 1, 0, 3]
|
||||
Scores: 1.0000 0.9167 0.8333 0.3333
|
||||
Score Details:
|
||||
[
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 0,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 1,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 3,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Typo(
|
||||
Typo {
|
||||
typo_count: 2,
|
||||
max_typo_count: 3,
|
||||
},
|
||||
),
|
||||
],
|
||||
[
|
||||
Words(
|
||||
Words {
|
||||
matching_words: 2,
|
||||
max_matching_words: 3,
|
||||
},
|
||||
),
|
||||
Skipped,
|
||||
],
|
||||
]
|
||||
"###);
|
||||
}
|
@ -1,5 +1,6 @@
|
||||
pub mod attribute_fid;
|
||||
pub mod attribute_position;
|
||||
pub mod cutoff;
|
||||
pub mod distinct;
|
||||
pub mod exactness;
|
||||
pub mod geo_sort;
|
||||
|
@ -5,7 +5,7 @@ The typo ranking rule should transform the query graph such that it only contain
|
||||
the combinations of word derivations that it used to compute its bucket.
|
||||
|
||||
The proximity ranking rule should then look for proximities only between those specific derivations.
|
||||
For example, given the the search query `beautiful summer` and the dataset:
|
||||
For example, given the search query `beautiful summer` and the dataset:
|
||||
```text
|
||||
{ "id": 0, "text": "beautigul summer...... beautiful day in the summer" }
|
||||
{ "id": 1, "text": "beautiful summer" }
|
||||
|
@ -339,6 +339,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
||||
prompt_reader: grenad::Reader<R>,
|
||||
indexer: GrenadParameters,
|
||||
embedder: Arc<Embedder>,
|
||||
request_threads: &rayon::ThreadPool,
|
||||
) -> Result<grenad::Reader<BufReader<File>>> {
|
||||
puffin::profile_function!();
|
||||
let n_chunks = embedder.chunk_count_hint(); // chunk level parallelism
|
||||
@ -376,7 +377,10 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
||||
|
||||
if chunks.len() == chunks.capacity() {
|
||||
let chunked_embeds = embedder
|
||||
.embed_chunks(std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)))
|
||||
.embed_chunks(
|
||||
std::mem::replace(&mut chunks, Vec::with_capacity(n_chunks)),
|
||||
request_threads,
|
||||
)
|
||||
.map_err(crate::vector::Error::from)
|
||||
.map_err(crate::Error::from)?;
|
||||
|
||||
@ -394,7 +398,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
||||
// send last chunk
|
||||
if !chunks.is_empty() {
|
||||
let chunked_embeds = embedder
|
||||
.embed_chunks(std::mem::take(&mut chunks))
|
||||
.embed_chunks(std::mem::take(&mut chunks), request_threads)
|
||||
.map_err(crate::vector::Error::from)
|
||||
.map_err(crate::Error::from)?;
|
||||
for (docid, embeddings) in chunks_ids
|
||||
@ -408,7 +412,7 @@ pub fn extract_embeddings<R: io::Read + io::Seek>(
|
||||
|
||||
if !current_chunk.is_empty() {
|
||||
let embeds = embedder
|
||||
.embed_chunks(vec![std::mem::take(&mut current_chunk)])
|
||||
.embed_chunks(vec![std::mem::take(&mut current_chunk)], request_threads)
|
||||
.map_err(crate::vector::Error::from)
|
||||
.map_err(crate::Error::from)?;
|
||||
|
||||
|
@ -238,6 +238,12 @@ fn send_original_documents_data(
|
||||
|
||||
let documents_chunk_cloned = original_documents_chunk.clone();
|
||||
let lmdb_writer_sx_cloned = lmdb_writer_sx.clone();
|
||||
|
||||
let request_threads = rayon::ThreadPoolBuilder::new()
|
||||
.num_threads(crate::vector::REQUEST_PARALLELISM)
|
||||
.thread_name(|index| format!("embedding-request-{index}"))
|
||||
.build()?;
|
||||
|
||||
rayon::spawn(move || {
|
||||
for (name, (embedder, prompt)) in embedders {
|
||||
let result = extract_vector_points(
|
||||
@ -249,7 +255,12 @@ fn send_original_documents_data(
|
||||
);
|
||||
match result {
|
||||
Ok(ExtractedVectorPoints { manual_vectors, remove_vectors, prompts }) => {
|
||||
let embeddings = match extract_embeddings(prompts, indexer, embedder.clone()) {
|
||||
let embeddings = match extract_embeddings(
|
||||
prompts,
|
||||
indexer,
|
||||
embedder.clone(),
|
||||
&request_threads,
|
||||
) {
|
||||
Ok(results) => Some(results),
|
||||
Err(error) => {
|
||||
let _ = lmdb_writer_sx_cloned.send(Err(error));
|
||||
|
@ -2646,6 +2646,12 @@ mod tests {
|
||||
api_key: Setting::NotSet,
|
||||
dimensions: Setting::Set(3),
|
||||
document_template: Setting::NotSet,
|
||||
url: Setting::NotSet,
|
||||
query: Setting::NotSet,
|
||||
input_field: Setting::NotSet,
|
||||
path_to_embeddings: Setting::NotSet,
|
||||
embedding_object: Setting::NotSet,
|
||||
input_type: Setting::NotSet,
|
||||
}),
|
||||
);
|
||||
settings.set_embedder_settings(embedders);
|
||||
|
@ -14,12 +14,13 @@ use super::IndexerConfig;
|
||||
use crate::criterion::Criterion;
|
||||
use crate::error::UserError;
|
||||
use crate::index::{DEFAULT_MIN_WORD_LEN_ONE_TYPO, DEFAULT_MIN_WORD_LEN_TWO_TYPOS};
|
||||
use crate::order_by_map::OrderByMap;
|
||||
use crate::proximity::ProximityPrecision;
|
||||
use crate::update::index_documents::IndexDocumentsMethod;
|
||||
use crate::update::{IndexDocuments, UpdateIndexingStep};
|
||||
use crate::vector::settings::{check_set, check_unset, EmbedderSource, EmbeddingSettings};
|
||||
use crate::vector::{Embedder, EmbeddingConfig, EmbeddingConfigs};
|
||||
use crate::{FieldsIdsMap, Index, OrderBy, Result};
|
||||
use crate::{FieldsIdsMap, Index, Result};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
pub enum Setting<T> {
|
||||
@ -145,10 +146,11 @@ pub struct Settings<'a, 't, 'i> {
|
||||
/// Attributes on which typo tolerance is disabled.
|
||||
exact_attributes: Setting<HashSet<String>>,
|
||||
max_values_per_facet: Setting<usize>,
|
||||
sort_facet_values_by: Setting<HashMap<String, OrderBy>>,
|
||||
sort_facet_values_by: Setting<OrderByMap>,
|
||||
pagination_max_total_hits: Setting<usize>,
|
||||
proximity_precision: Setting<ProximityPrecision>,
|
||||
embedder_settings: Setting<BTreeMap<String, Setting<EmbeddingSettings>>>,
|
||||
search_cutoff: Setting<u64>,
|
||||
}
|
||||
|
||||
impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
@ -182,6 +184,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
pagination_max_total_hits: Setting::NotSet,
|
||||
proximity_precision: Setting::NotSet,
|
||||
embedder_settings: Setting::NotSet,
|
||||
search_cutoff: Setting::NotSet,
|
||||
indexer_config,
|
||||
}
|
||||
}
|
||||
@ -340,7 +343,7 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
self.max_values_per_facet = Setting::Reset;
|
||||
}
|
||||
|
||||
pub fn set_sort_facet_values_by(&mut self, value: HashMap<String, OrderBy>) {
|
||||
pub fn set_sort_facet_values_by(&mut self, value: OrderByMap) {
|
||||
self.sort_facet_values_by = Setting::Set(value);
|
||||
}
|
||||
|
||||
@ -372,6 +375,14 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
self.embedder_settings = Setting::Reset;
|
||||
}
|
||||
|
||||
pub fn set_search_cutoff(&mut self, value: u64) {
|
||||
self.search_cutoff = Setting::Set(value);
|
||||
}
|
||||
|
||||
pub fn reset_search_cutoff(&mut self) {
|
||||
self.search_cutoff = Setting::Reset;
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
level = "trace"
|
||||
skip(self, progress_callback, should_abort, old_fields_ids_map),
|
||||
@ -1025,6 +1036,24 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
Ok(update)
|
||||
}
|
||||
|
||||
fn update_search_cutoff(&mut self) -> Result<bool> {
|
||||
let changed = match self.search_cutoff {
|
||||
Setting::Set(new) => {
|
||||
let old = self.index.search_cutoff(self.wtxn)?;
|
||||
if old == Some(new) {
|
||||
false
|
||||
} else {
|
||||
self.index.put_search_cutoff(self.wtxn, new)?;
|
||||
true
|
||||
}
|
||||
}
|
||||
Setting::Reset => self.index.delete_search_cutoff(self.wtxn)?,
|
||||
Setting::NotSet => false,
|
||||
};
|
||||
|
||||
Ok(changed)
|
||||
}
|
||||
|
||||
pub fn execute<FP, FA>(mut self, progress_callback: FP, should_abort: FA) -> Result<()>
|
||||
where
|
||||
FP: Fn(UpdateIndexingStep) + Sync,
|
||||
@ -1073,6 +1102,9 @@ impl<'a, 't, 'i> Settings<'a, 't, 'i> {
|
||||
// 3. Keep the old vectors but reattempt indexing on a prompt change: only actually changed prompt will need embedding + storage
|
||||
let embedding_configs_updated = self.update_embedding_configs()?;
|
||||
|
||||
// never trigger re-indexing
|
||||
self.update_search_cutoff()?;
|
||||
|
||||
if stop_words_updated
|
||||
|| non_separator_tokens_updated
|
||||
|| separator_tokens_updated
|
||||
@ -1131,6 +1163,12 @@ fn validate_prompt(
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template: Setting::Set(template),
|
||||
url,
|
||||
query,
|
||||
input_field,
|
||||
path_to_embeddings,
|
||||
embedding_object,
|
||||
input_type,
|
||||
}) => {
|
||||
// validate
|
||||
let template = crate::prompt::Prompt::new(template)
|
||||
@ -1144,6 +1182,12 @@ fn validate_prompt(
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template: Setting::Set(template),
|
||||
url,
|
||||
query,
|
||||
input_field,
|
||||
path_to_embeddings,
|
||||
embedding_object,
|
||||
input_type,
|
||||
}))
|
||||
}
|
||||
new => Ok(new),
|
||||
@ -1156,8 +1200,20 @@ pub fn validate_embedding_settings(
|
||||
) -> Result<Setting<EmbeddingSettings>> {
|
||||
let settings = validate_prompt(name, settings)?;
|
||||
let Setting::Set(settings) = settings else { return Ok(settings) };
|
||||
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
|
||||
settings;
|
||||
let EmbeddingSettings {
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
url,
|
||||
query,
|
||||
input_field,
|
||||
path_to_embeddings,
|
||||
embedding_object,
|
||||
input_type,
|
||||
} = settings;
|
||||
|
||||
if let Some(0) = dimensions.set() {
|
||||
return Err(crate::error::UserError::InvalidSettingsDimensions {
|
||||
@ -1166,6 +1222,14 @@ pub fn validate_embedding_settings(
|
||||
.into());
|
||||
}
|
||||
|
||||
if let Some(url) = url.as_ref().set() {
|
||||
url::Url::parse(url).map_err(|error| crate::error::UserError::InvalidUrl {
|
||||
embedder_name: name.to_owned(),
|
||||
inner_error: error,
|
||||
url: url.to_owned(),
|
||||
})?;
|
||||
}
|
||||
|
||||
let Some(inferred_source) = source.set() else {
|
||||
return Ok(Setting::Set(EmbeddingSettings {
|
||||
source,
|
||||
@ -1174,11 +1238,25 @@ pub fn validate_embedding_settings(
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
url,
|
||||
query,
|
||||
input_field,
|
||||
path_to_embeddings,
|
||||
embedding_object,
|
||||
input_type,
|
||||
}));
|
||||
};
|
||||
match inferred_source {
|
||||
EmbedderSource::OpenAi => {
|
||||
check_unset(&revision, "revision", inferred_source, name)?;
|
||||
|
||||
check_unset(&url, "url", inferred_source, name)?;
|
||||
check_unset(&query, "query", inferred_source, name)?;
|
||||
check_unset(&input_field, "inputField", inferred_source, name)?;
|
||||
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
|
||||
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
|
||||
check_unset(&input_type, "inputType", inferred_source, name)?;
|
||||
|
||||
if let Setting::Set(model) = &model {
|
||||
let model = crate::vector::openai::EmbeddingModel::from_name(model.as_str())
|
||||
.ok_or(crate::error::UserError::InvalidOpenAiModel {
|
||||
@ -1209,9 +1287,30 @@ pub fn validate_embedding_settings(
|
||||
}
|
||||
}
|
||||
}
|
||||
EmbedderSource::Ollama => {
|
||||
// Dimensions get inferred, only model name is required
|
||||
check_unset(&dimensions, "dimensions", inferred_source, name)?;
|
||||
check_set(&model, "model", inferred_source, name)?;
|
||||
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
||||
check_unset(&revision, "revision", inferred_source, name)?;
|
||||
|
||||
check_unset(&url, "url", inferred_source, name)?;
|
||||
check_unset(&query, "query", inferred_source, name)?;
|
||||
check_unset(&input_field, "inputField", inferred_source, name)?;
|
||||
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
|
||||
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
|
||||
check_unset(&input_type, "inputType", inferred_source, name)?;
|
||||
}
|
||||
EmbedderSource::HuggingFace => {
|
||||
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
||||
check_unset(&dimensions, "dimensions", inferred_source, name)?;
|
||||
|
||||
check_unset(&url, "url", inferred_source, name)?;
|
||||
check_unset(&query, "query", inferred_source, name)?;
|
||||
check_unset(&input_field, "inputField", inferred_source, name)?;
|
||||
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
|
||||
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
|
||||
check_unset(&input_type, "inputType", inferred_source, name)?;
|
||||
}
|
||||
EmbedderSource::UserProvided => {
|
||||
check_unset(&model, "model", inferred_source, name)?;
|
||||
@ -1219,6 +1318,18 @@ pub fn validate_embedding_settings(
|
||||
check_unset(&api_key, "apiKey", inferred_source, name)?;
|
||||
check_unset(&document_template, "documentTemplate", inferred_source, name)?;
|
||||
check_set(&dimensions, "dimensions", inferred_source, name)?;
|
||||
|
||||
check_unset(&url, "url", inferred_source, name)?;
|
||||
check_unset(&query, "query", inferred_source, name)?;
|
||||
check_unset(&input_field, "inputField", inferred_source, name)?;
|
||||
check_unset(&path_to_embeddings, "pathToEmbeddings", inferred_source, name)?;
|
||||
check_unset(&embedding_object, "embeddingObject", inferred_source, name)?;
|
||||
check_unset(&input_type, "inputType", inferred_source, name)?;
|
||||
}
|
||||
EmbedderSource::Rest => {
|
||||
check_unset(&model, "model", inferred_source, name)?;
|
||||
check_unset(&revision, "revision", inferred_source, name)?;
|
||||
check_set(&url, "url", inferred_source, name)?;
|
||||
}
|
||||
}
|
||||
Ok(Setting::Set(EmbeddingSettings {
|
||||
@ -1228,6 +1339,12 @@ pub fn validate_embedding_settings(
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
url,
|
||||
query,
|
||||
input_field,
|
||||
path_to_embeddings,
|
||||
embedding_object,
|
||||
input_type,
|
||||
}))
|
||||
}
|
||||
|
||||
@ -2050,6 +2167,7 @@ mod tests {
|
||||
pagination_max_total_hits,
|
||||
proximity_precision,
|
||||
embedder_settings,
|
||||
search_cutoff,
|
||||
} = settings;
|
||||
assert!(matches!(searchable_fields, Setting::NotSet));
|
||||
assert!(matches!(displayed_fields, Setting::NotSet));
|
||||
@ -2073,6 +2191,7 @@ mod tests {
|
||||
assert!(matches!(pagination_max_total_hits, Setting::NotSet));
|
||||
assert!(matches!(proximity_precision, Setting::NotSet));
|
||||
assert!(matches!(embedder_settings, Setting::NotSet));
|
||||
assert!(matches!(search_cutoff, Setting::NotSet));
|
||||
})
|
||||
.unwrap();
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ use std::path::PathBuf;
|
||||
use hf_hub::api::sync::ApiError;
|
||||
|
||||
use crate::error::FaultSource;
|
||||
use crate::vector::openai::OpenAiError;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("Error while generating embeddings: {inner}")]
|
||||
@ -51,26 +50,34 @@ pub enum EmbedErrorKind {
|
||||
TensorValue(candle_core::Error),
|
||||
#[error("could not run model: {0}")]
|
||||
ModelForward(candle_core::Error),
|
||||
#[error("could not reach OpenAI: {0}")]
|
||||
OpenAiNetwork(reqwest::Error),
|
||||
#[error("unexpected response from OpenAI: {0}")]
|
||||
OpenAiUnexpected(reqwest::Error),
|
||||
#[error("could not authenticate against OpenAI: {0}")]
|
||||
OpenAiAuth(OpenAiError),
|
||||
#[error("sent too many requests to OpenAI: {0}")]
|
||||
OpenAiTooManyRequests(OpenAiError),
|
||||
#[error("received internal error from OpenAI: {0:?}")]
|
||||
OpenAiInternalServerError(Option<OpenAiError>),
|
||||
#[error("sent too many tokens in a request to OpenAI: {0}")]
|
||||
OpenAiTooManyTokens(OpenAiError),
|
||||
#[error("received unhandled HTTP status code {0} from OpenAI")]
|
||||
OpenAiUnhandledStatusCode(u16),
|
||||
#[error("attempt to embed the following text in a configuration where embeddings must be user provided: {0:?}")]
|
||||
ManualEmbed(String),
|
||||
#[error("could not initialize asynchronous runtime: {0}")]
|
||||
OpenAiRuntimeInit(std::io::Error),
|
||||
#[error("initializing web client for sending embedding requests failed: {0}")]
|
||||
InitWebClient(reqwest::Error),
|
||||
#[error("model not found. Meilisearch will not automatically download models from the Ollama library, please pull the model manually: {0:?}")]
|
||||
OllamaModelNotFoundError(Option<String>),
|
||||
#[error("error deserialization the response body as JSON: {0}")]
|
||||
RestResponseDeserialization(std::io::Error),
|
||||
#[error("component `{0}` not found in path `{1}` in response: `{2}`")]
|
||||
RestResponseMissingEmbeddings(String, String, String),
|
||||
#[error("expected a response parseable as a vector or an array of vectors: {0}")]
|
||||
RestResponseFormat(serde_json::Error),
|
||||
#[error("expected a response containing {0} embeddings, got only {1}")]
|
||||
RestResponseEmbeddingCount(usize, usize),
|
||||
#[error("could not authenticate against embedding server: {0:?}")]
|
||||
RestUnauthorized(Option<String>),
|
||||
#[error("sent too many requests to embedding server: {0:?}")]
|
||||
RestTooManyRequests(Option<String>),
|
||||
#[error("sent a bad request to embedding server: {0:?}")]
|
||||
RestBadRequest(Option<String>),
|
||||
#[error("received internal error from embedding server: {0:?}")]
|
||||
RestInternalServerError(u16, Option<String>),
|
||||
#[error("received HTTP {0} from embedding server: {0:?}")]
|
||||
RestOtherStatusCode(u16, Option<String>),
|
||||
#[error("could not reach embedding server: {0}")]
|
||||
RestNetwork(ureq::Transport),
|
||||
#[error("was expected '{}' to be an object in query '{0}'", .1.join("."))]
|
||||
RestNotAnObject(serde_json::Value, Vec<String>),
|
||||
#[error("while embedding tokenized, was expecting embeddings of dimension `{0}`, got embeddings of dimensions `{1}`")]
|
||||
OpenAiUnexpectedDimension(usize, usize),
|
||||
}
|
||||
|
||||
impl EmbedError {
|
||||
@ -90,44 +97,98 @@ impl EmbedError {
|
||||
Self { kind: EmbedErrorKind::ModelForward(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn openai_network(inner: reqwest::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::OpenAiNetwork(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn openai_unexpected(inner: reqwest::Error) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiUnexpected(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_auth_error(inner: OpenAiError) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiAuth(inner), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_too_many_requests(inner: OpenAiError) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiTooManyRequests(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_internal_server_error(inner: Option<OpenAiError>) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiInternalServerError(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_too_many_tokens(inner: OpenAiError) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiTooManyTokens(inner), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_unhandled_status_code(code: u16) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiUnhandledStatusCode(code), fault: FaultSource::Bug }
|
||||
}
|
||||
|
||||
pub(crate) fn embed_on_manual_embedder(texts: String) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::ManualEmbed(texts), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_runtime_init(inner: std::io::Error) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OpenAiRuntimeInit(inner), fault: FaultSource::Runtime }
|
||||
pub(crate) fn ollama_model_not_found(inner: Option<String>) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::OllamaModelNotFoundError(inner), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub fn openai_initialize_web_client(inner: reqwest::Error) -> Self {
|
||||
Self { kind: EmbedErrorKind::InitWebClient(inner), fault: FaultSource::Runtime }
|
||||
pub(crate) fn rest_response_deserialization(error: std::io::Error) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestResponseDeserialization(error),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_response_missing_embeddings<S: AsRef<str>>(
|
||||
response: serde_json::Value,
|
||||
component: &str,
|
||||
response_field: &[S],
|
||||
) -> EmbedError {
|
||||
let response_field: Vec<&str> = response_field.iter().map(AsRef::as_ref).collect();
|
||||
let response_field = response_field.join(".");
|
||||
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestResponseMissingEmbeddings(
|
||||
component.to_owned(),
|
||||
response_field,
|
||||
serde_json::to_string_pretty(&response).unwrap_or_default(),
|
||||
),
|
||||
fault: FaultSource::Undecided,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_response_format(error: serde_json::Error) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::RestResponseFormat(error), fault: FaultSource::Undecided }
|
||||
}
|
||||
|
||||
pub(crate) fn rest_response_embedding_count(expected: usize, got: usize) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestResponseEmbeddingCount(expected, got),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_unauthorized(error_response: Option<String>) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::RestUnauthorized(error_response), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn rest_too_many_requests(error_response: Option<String>) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestTooManyRequests(error_response),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_bad_request(error_response: Option<String>) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::RestBadRequest(error_response), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn rest_internal_server_error(
|
||||
code: u16,
|
||||
error_response: Option<String>,
|
||||
) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestInternalServerError(code, error_response),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_other_status_code(code: u16, error_response: Option<String>) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::RestOtherStatusCode(code, error_response),
|
||||
fault: FaultSource::Undecided,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn rest_network(transport: ureq::Transport) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::RestNetwork(transport), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub(crate) fn rest_not_an_object(
|
||||
query: serde_json::Value,
|
||||
input_path: Vec<String>,
|
||||
) -> EmbedError {
|
||||
Self { kind: EmbedErrorKind::RestNotAnObject(query, input_path), fault: FaultSource::User }
|
||||
}
|
||||
|
||||
pub(crate) fn openai_unexpected_dimension(expected: usize, got: usize) -> EmbedError {
|
||||
Self {
|
||||
kind: EmbedErrorKind::OpenAiUnexpectedDimension(expected, got),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -188,16 +249,12 @@ impl NewEmbedderError {
|
||||
Self { kind: NewEmbedderErrorKind::LoadModel(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn hf_could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
|
||||
pub fn could_not_determine_dimension(inner: EmbedError) -> NewEmbedderError {
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::CouldNotDetermineDimension(inner),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn openai_invalid_api_key_format(inner: reqwest::header::InvalidHeaderValue) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::InvalidApiKeyFormat(inner), fault: FaultSource::User }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
@ -244,7 +301,4 @@ pub enum NewEmbedderErrorKind {
|
||||
CouldNotDetermineDimension(EmbedError),
|
||||
#[error("loading model failed: {0}")]
|
||||
LoadModel(candle_core::Error),
|
||||
// openai
|
||||
#[error("The API key passed to Authorization error was in an invalid format: {0}")]
|
||||
InvalidApiKeyFormat(reqwest::header::InvalidHeaderValue),
|
||||
}
|
||||
|
@ -131,7 +131,7 @@ impl Embedder {
|
||||
|
||||
let embeddings = this
|
||||
.embed(vec!["test".into()])
|
||||
.map_err(NewEmbedderError::hf_could_not_determine_dimension)?;
|
||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||
this.dimensions = embeddings.first().unwrap().dimension();
|
||||
|
||||
Ok(this)
|
||||
@ -194,7 +194,10 @@ impl Embedder {
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
if self.options.model == "BAAI/bge-base-en-v1.5" {
|
||||
Some(DistributionShift { current_mean: 0.85, current_sigma: 0.1 })
|
||||
Some(DistributionShift {
|
||||
current_mean: ordered_float::OrderedFloat(0.85),
|
||||
current_sigma: ordered_float::OrderedFloat(0.1),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -1,6 +1,9 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use ordered_float::OrderedFloat;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use self::error::{EmbedError, NewEmbedderError};
|
||||
use crate::prompt::{Prompt, PromptData};
|
||||
|
||||
@ -10,50 +13,71 @@ pub mod manual;
|
||||
pub mod openai;
|
||||
pub mod settings;
|
||||
|
||||
pub mod ollama;
|
||||
pub mod rest;
|
||||
|
||||
pub use self::error::Error;
|
||||
|
||||
pub type Embedding = Vec<f32>;
|
||||
|
||||
pub const REQUEST_PARALLELISM: usize = 40;
|
||||
|
||||
/// One or multiple embeddings stored consecutively in a flat vector.
|
||||
pub struct Embeddings<F> {
|
||||
data: Vec<F>,
|
||||
dimension: usize,
|
||||
}
|
||||
|
||||
impl<F> Embeddings<F> {
|
||||
/// Declares an empty vector of embeddings of the specified dimensions.
|
||||
pub fn new(dimension: usize) -> Self {
|
||||
Self { data: Default::default(), dimension }
|
||||
}
|
||||
|
||||
/// Declares a vector of embeddings containing a single element.
|
||||
///
|
||||
/// The dimension is inferred from the length of the passed embedding.
|
||||
pub fn from_single_embedding(embedding: Vec<F>) -> Self {
|
||||
Self { dimension: embedding.len(), data: embedding }
|
||||
}
|
||||
|
||||
/// Declares a vector of embeddings from its components.
|
||||
///
|
||||
/// `data.len()` must be a multiple of `dimension`, otherwise an error is returned.
|
||||
pub fn from_inner(data: Vec<F>, dimension: usize) -> Result<Self, Vec<F>> {
|
||||
let mut this = Self::new(dimension);
|
||||
this.append(data)?;
|
||||
Ok(this)
|
||||
}
|
||||
|
||||
/// Returns the number of embeddings in this vector of embeddings.
|
||||
pub fn embedding_count(&self) -> usize {
|
||||
self.data.len() / self.dimension
|
||||
}
|
||||
|
||||
/// Dimension of a single embedding.
|
||||
pub fn dimension(&self) -> usize {
|
||||
self.dimension
|
||||
}
|
||||
|
||||
/// Deconstructs self into the inner flat vector.
|
||||
pub fn into_inner(self) -> Vec<F> {
|
||||
self.data
|
||||
}
|
||||
|
||||
/// A reference to the inner flat vector.
|
||||
pub fn as_inner(&self) -> &[F] {
|
||||
&self.data
|
||||
}
|
||||
|
||||
/// Iterates over the embeddings contained in the flat vector.
|
||||
pub fn iter(&self) -> impl Iterator<Item = &'_ [F]> + '_ {
|
||||
self.data.as_slice().chunks_exact(self.dimension)
|
||||
}
|
||||
|
||||
/// Push an embedding at the end of the embeddings.
|
||||
///
|
||||
/// If `embedding.len() != self.dimension`, then the push operation fails.
|
||||
pub fn push(&mut self, mut embedding: Vec<F>) -> Result<(), Vec<F>> {
|
||||
if embedding.len() != self.dimension {
|
||||
return Err(embedding);
|
||||
@ -62,6 +86,9 @@ impl<F> Embeddings<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Append a flat vector of embeddings a the end of the embeddings.
|
||||
///
|
||||
/// If `embeddings.len() % self.dimension != 0`, then the append operation fails.
|
||||
pub fn append(&mut self, mut embeddings: Vec<F>) -> Result<(), Vec<F>> {
|
||||
if embeddings.len() % self.dimension != 0 {
|
||||
return Err(embeddings);
|
||||
@ -71,36 +98,60 @@ impl<F> Embeddings<F> {
|
||||
}
|
||||
}
|
||||
|
||||
/// An embedder can be used to transform text into embeddings.
|
||||
#[derive(Debug)]
|
||||
pub enum Embedder {
|
||||
/// An embedder based on running local models, fetched from the Hugging Face Hub.
|
||||
HuggingFace(hf::Embedder),
|
||||
/// An embedder based on making embedding queries against the OpenAI API.
|
||||
OpenAi(openai::Embedder),
|
||||
/// An embedder based on the user providing the embeddings in the documents and queries.
|
||||
UserProvided(manual::Embedder),
|
||||
/// An embedder based on making embedding queries against an <https://ollama.com> embedding server.
|
||||
Ollama(ollama::Embedder),
|
||||
/// An embedder based on making embedding queries against a generic JSON/REST embedding server.
|
||||
Rest(rest::Embedder),
|
||||
}
|
||||
|
||||
/// Configuration for an embedder.
|
||||
#[derive(Debug, Clone, Default, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbeddingConfig {
|
||||
/// Options of the embedder, specific to each kind of embedder
|
||||
pub embedder_options: EmbedderOptions,
|
||||
/// Document template
|
||||
pub prompt: PromptData,
|
||||
// TODO: add metrics and anything needed
|
||||
}
|
||||
|
||||
/// Map of embedder configurations.
|
||||
///
|
||||
/// Each configuration is mapped to a name.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct EmbeddingConfigs(HashMap<String, (Arc<Embedder>, Arc<Prompt>)>);
|
||||
|
||||
impl EmbeddingConfigs {
|
||||
/// Create the map from its internal component.s
|
||||
pub fn new(data: HashMap<String, (Arc<Embedder>, Arc<Prompt>)>) -> Self {
|
||||
Self(data)
|
||||
}
|
||||
|
||||
/// Get an embedder configuration and template from its name.
|
||||
pub fn get(&self, name: &str) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
|
||||
self.0.get(name).cloned()
|
||||
}
|
||||
|
||||
/// Get the default embedder configuration, if any.
|
||||
pub fn get_default(&self) -> Option<(Arc<Embedder>, Arc<Prompt>)> {
|
||||
self.get_default_embedder_name().and_then(|default| self.get(&default))
|
||||
}
|
||||
|
||||
/// Get the name of the default embedder configuration.
|
||||
///
|
||||
/// The default embedder is determined as follows:
|
||||
///
|
||||
/// - If there is only one embedder, it is always the default.
|
||||
/// - If there are multiple embedders and one of them is called `default`, then that one is the default embedder.
|
||||
/// - In all other cases, there is no default embedder.
|
||||
pub fn get_default_embedder_name(&self) -> Option<String> {
|
||||
let mut it = self.0.keys();
|
||||
let first_name = it.next();
|
||||
@ -123,11 +174,14 @@ impl IntoIterator for EmbeddingConfigs {
|
||||
}
|
||||
}
|
||||
|
||||
/// Options of an embedder, specific to each kind of embedder.
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub enum EmbedderOptions {
|
||||
HuggingFace(hf::EmbedderOptions),
|
||||
OpenAi(openai::EmbedderOptions),
|
||||
Ollama(ollama::EmbedderOptions),
|
||||
UserProvided(manual::EmbedderOptions),
|
||||
Rest(rest::EmbedderOptions),
|
||||
}
|
||||
|
||||
impl Default for EmbedderOptions {
|
||||
@ -137,91 +191,158 @@ impl Default for EmbedderOptions {
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
/// Default options for the Hugging Face embedder
|
||||
pub fn huggingface() -> Self {
|
||||
Self::HuggingFace(hf::EmbedderOptions::new())
|
||||
}
|
||||
|
||||
/// Default options for the OpenAI embedder
|
||||
pub fn openai(api_key: Option<String>) -> Self {
|
||||
Self::OpenAi(openai::EmbedderOptions::with_default_model(api_key))
|
||||
}
|
||||
|
||||
pub fn ollama() -> Self {
|
||||
Self::Ollama(ollama::EmbedderOptions::with_default_model())
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
/// Spawns a new embedder built from its options.
|
||||
pub fn new(options: EmbedderOptions) -> std::result::Result<Self, NewEmbedderError> {
|
||||
Ok(match options {
|
||||
EmbedderOptions::HuggingFace(options) => Self::HuggingFace(hf::Embedder::new(options)?),
|
||||
EmbedderOptions::OpenAi(options) => Self::OpenAi(openai::Embedder::new(options)?),
|
||||
EmbedderOptions::Ollama(options) => Self::Ollama(ollama::Embedder::new(options)?),
|
||||
EmbedderOptions::UserProvided(options) => {
|
||||
Self::UserProvided(manual::Embedder::new(options))
|
||||
}
|
||||
EmbedderOptions::Rest(options) => Self::Rest(rest::Embedder::new(options)?),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn embed(
|
||||
/// Embed one or multiple texts.
|
||||
///
|
||||
/// Each text can be embedded as one or multiple embeddings.
|
||||
pub fn embed(
|
||||
&self,
|
||||
texts: Vec<String>,
|
||||
) -> std::result::Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.embed(texts),
|
||||
Embedder::OpenAi(embedder) => {
|
||||
let client = embedder.new_client()?;
|
||||
embedder.embed(texts, &client).await
|
||||
}
|
||||
Embedder::OpenAi(embedder) => embedder.embed(texts),
|
||||
Embedder::Ollama(embedder) => embedder.embed(texts),
|
||||
Embedder::UserProvided(embedder) => embedder.embed(texts),
|
||||
Embedder::Rest(embedder) => embedder.embed(texts),
|
||||
}
|
||||
}
|
||||
|
||||
/// # Panics
|
||||
/// Embed multiple chunks of texts.
|
||||
///
|
||||
/// - if called from an asynchronous context
|
||||
/// Each chunk is composed of one or multiple texts.
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
threads: &rayon::ThreadPool,
|
||||
) -> std::result::Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.embed_chunks(text_chunks),
|
||||
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks),
|
||||
Embedder::OpenAi(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||
Embedder::Ollama(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||
Embedder::UserProvided(embedder) => embedder.embed_chunks(text_chunks),
|
||||
Embedder::Rest(embedder) => embedder.embed_chunks(text_chunks, threads),
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicates the preferred number of chunks to pass to [`Self::embed_chunks`]
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.chunk_count_hint(),
|
||||
Embedder::OpenAi(embedder) => embedder.chunk_count_hint(),
|
||||
Embedder::Ollama(embedder) => embedder.chunk_count_hint(),
|
||||
Embedder::UserProvided(_) => 1,
|
||||
Embedder::Rest(embedder) => embedder.chunk_count_hint(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicates the preferred number of texts in a single chunk passed to [`Self::embed`]
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
Embedder::OpenAi(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
Embedder::Ollama(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
Embedder::UserProvided(_) => 1,
|
||||
Embedder::Rest(embedder) => embedder.prompt_count_in_chunk_hint(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Indicates the dimensions of a single embedding produced by the embedder.
|
||||
pub fn dimensions(&self) -> usize {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.dimensions(),
|
||||
Embedder::OpenAi(embedder) => embedder.dimensions(),
|
||||
Embedder::Ollama(embedder) => embedder.dimensions(),
|
||||
Embedder::UserProvided(embedder) => embedder.dimensions(),
|
||||
Embedder::Rest(embedder) => embedder.dimensions(),
|
||||
}
|
||||
}
|
||||
|
||||
/// An optional distribution used to apply an affine transformation to the similarity score of a document.
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
match self {
|
||||
Embedder::HuggingFace(embedder) => embedder.distribution(),
|
||||
Embedder::OpenAi(embedder) => embedder.distribution(),
|
||||
Embedder::Ollama(embedder) => embedder.distribution(),
|
||||
Embedder::UserProvided(_embedder) => None,
|
||||
Embedder::Rest(embedder) => embedder.distribution(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
/// Describes the mean and sigma of distribution of embedding similarity in the embedding space.
|
||||
///
|
||||
/// The intended use is to make the similarity score more comparable to the regular ranking score.
|
||||
/// This allows to correct effects where results are too "packed" around a certain value.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Deserialize, Serialize)]
|
||||
#[serde(from = "DistributionShiftSerializable")]
|
||||
#[serde(into = "DistributionShiftSerializable")]
|
||||
pub struct DistributionShift {
|
||||
pub current_mean: f32,
|
||||
pub current_sigma: f32,
|
||||
/// Value where the results are "packed".
|
||||
///
|
||||
/// Similarity scores are translated so that they are packed around 0.5 instead
|
||||
pub current_mean: OrderedFloat<f32>,
|
||||
|
||||
/// standard deviation of a similarity score.
|
||||
///
|
||||
/// Set below 0.4 to make the results less packed around the mean, and above 0.4 to make them more packed.
|
||||
pub current_sigma: OrderedFloat<f32>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct DistributionShiftSerializable {
|
||||
current_mean: f32,
|
||||
current_sigma: f32,
|
||||
}
|
||||
|
||||
impl From<DistributionShift> for DistributionShiftSerializable {
|
||||
fn from(
|
||||
DistributionShift {
|
||||
current_mean: OrderedFloat(current_mean),
|
||||
current_sigma: OrderedFloat(current_sigma),
|
||||
}: DistributionShift,
|
||||
) -> Self {
|
||||
Self { current_mean, current_sigma }
|
||||
}
|
||||
}
|
||||
|
||||
impl From<DistributionShiftSerializable> for DistributionShift {
|
||||
fn from(
|
||||
DistributionShiftSerializable { current_mean, current_sigma }: DistributionShiftSerializable,
|
||||
) -> Self {
|
||||
Self {
|
||||
current_mean: OrderedFloat(current_mean),
|
||||
current_sigma: OrderedFloat(current_sigma),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl DistributionShift {
|
||||
@ -230,11 +351,13 @@ impl DistributionShift {
|
||||
if sigma <= 0.0 {
|
||||
None
|
||||
} else {
|
||||
Some(Self { current_mean: mean, current_sigma: sigma })
|
||||
Some(Self { current_mean: OrderedFloat(mean), current_sigma: OrderedFloat(sigma) })
|
||||
}
|
||||
}
|
||||
|
||||
pub fn shift(&self, score: f32) -> f32 {
|
||||
let current_mean = self.current_mean.0;
|
||||
let current_sigma = self.current_sigma.0;
|
||||
// <https://math.stackexchange.com/a/2894689>
|
||||
// We're somewhat abusively mapping the distribution of distances to a gaussian.
|
||||
// The parameters we're given is the mean and sigma of the native result distribution.
|
||||
@ -244,9 +367,9 @@ impl DistributionShift {
|
||||
let target_sigma = 0.4;
|
||||
|
||||
// a^2 sig1^2 = sig2^2 => a^2 = sig2^2 / sig1^2 => a = sig2 / sig1, assuming a, sig1, and sig2 positive.
|
||||
let factor = target_sigma / self.current_sigma;
|
||||
let factor = target_sigma / current_sigma;
|
||||
// a*mu1 + b = mu2 => b = mu2 - a*mu1
|
||||
let offset = target_mean - (factor * self.current_mean);
|
||||
let offset = target_mean - (factor * current_mean);
|
||||
|
||||
let mut score = factor * score + offset;
|
||||
|
||||
@ -262,6 +385,7 @@ impl DistributionShift {
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether CUDA is supported in this version of Meilisearch.
|
||||
pub const fn is_cuda_enabled() -> bool {
|
||||
cfg!(feature = "cuda")
|
||||
}
|
||||
|
102
milli/src/vector/ollama.rs
Normal file
102
milli/src/vector/ollama.rs
Normal file
@ -0,0 +1,102 @@
|
||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||
|
||||
use super::error::{EmbedError, EmbedErrorKind, NewEmbedderError, NewEmbedderErrorKind};
|
||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||
use super::{DistributionShift, Embeddings};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
rest_embedder: RestEmbedder,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub embedding_model: String,
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn with_default_model() -> Self {
|
||||
Self { embedding_model: "nomic-embed-text".into() }
|
||||
}
|
||||
|
||||
pub fn with_embedding_model(embedding_model: String) -> Self {
|
||||
Self { embedding_model }
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||
let model = options.embedding_model.as_str();
|
||||
let rest_embedder = match RestEmbedder::new(RestEmbedderOptions {
|
||||
api_key: None,
|
||||
distribution: None,
|
||||
dimensions: None,
|
||||
url: get_ollama_path(),
|
||||
query: serde_json::json!({
|
||||
"model": model,
|
||||
}),
|
||||
input_field: vec!["prompt".to_owned()],
|
||||
path_to_embeddings: Default::default(),
|
||||
embedding_object: vec!["embedding".to_owned()],
|
||||
input_type: super::rest::InputType::Text,
|
||||
}) {
|
||||
Ok(embedder) => embedder,
|
||||
Err(NewEmbedderError {
|
||||
kind:
|
||||
NewEmbedderErrorKind::CouldNotDetermineDimension(EmbedError {
|
||||
kind: super::error::EmbedErrorKind::RestOtherStatusCode(404, error),
|
||||
fault: _,
|
||||
}),
|
||||
fault: _,
|
||||
}) => {
|
||||
return Err(NewEmbedderError::could_not_determine_dimension(
|
||||
EmbedError::ollama_model_not_found(error),
|
||||
))
|
||||
}
|
||||
Err(error) => return Err(error),
|
||||
};
|
||||
|
||||
Ok(Self { rest_embedder })
|
||||
}
|
||||
|
||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
match self.rest_embedder.embed(texts) {
|
||||
Ok(embeddings) => Ok(embeddings),
|
||||
Err(EmbedError { kind: EmbedErrorKind::RestOtherStatusCode(404, error), fault: _ }) => {
|
||||
Err(EmbedError::ollama_model_not_found(error))
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
threads: &rayon::ThreadPool,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
threads.install(move || {
|
||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
self.rest_embedder.chunk_count_hint()
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
self.rest_embedder.prompt_count_in_chunk_hint()
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.rest_embedder.dimensions()
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn get_ollama_path() -> String {
|
||||
// Important: Hostname not enough, has to be entire path to embeddings endpoint
|
||||
std::env::var("MEILI_OLLAMA_URL").unwrap_or("http://localhost:11434/api/embeddings".to_string())
|
||||
}
|
@ -1,17 +1,10 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
use reqwest::StatusCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use ordered_float::OrderedFloat;
|
||||
use rayon::iter::{IntoParallelIterator, ParallelIterator as _};
|
||||
|
||||
use super::error::{EmbedError, NewEmbedderError};
|
||||
use super::{DistributionShift, Embedding, Embeddings};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
headers: reqwest::header::HeaderMap,
|
||||
tokenizer: tiktoken_rs::CoreBPE,
|
||||
options: EmbedderOptions,
|
||||
}
|
||||
use super::rest::{Embedder as RestEmbedder, EmbedderOptions as RestEmbedderOptions};
|
||||
use super::{DistributionShift, Embeddings};
|
||||
use crate::vector::error::EmbedErrorKind;
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
@ -20,6 +13,32 @@ pub struct EmbedderOptions {
|
||||
pub dimensions: Option<usize>,
|
||||
}
|
||||
|
||||
impl EmbedderOptions {
|
||||
pub fn dimensions(&self) -> usize {
|
||||
if self.embedding_model.supports_overriding_dimensions() {
|
||||
self.dimensions.unwrap_or(self.embedding_model.default_dimensions())
|
||||
} else {
|
||||
self.embedding_model.default_dimensions()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn query(&self) -> serde_json::Value {
|
||||
let model = self.embedding_model.name();
|
||||
|
||||
let mut query = serde_json::json!({
|
||||
"model": model,
|
||||
});
|
||||
|
||||
if self.embedding_model.supports_overriding_dimensions() {
|
||||
if let Some(dimensions) = self.dimensions {
|
||||
query["dimensions"] = dimensions.into();
|
||||
}
|
||||
}
|
||||
|
||||
query
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(
|
||||
Debug,
|
||||
Clone,
|
||||
@ -92,15 +111,18 @@ impl EmbeddingModel {
|
||||
|
||||
fn distribution(&self) -> Option<DistributionShift> {
|
||||
match self {
|
||||
EmbeddingModel::TextEmbeddingAda002 => {
|
||||
Some(DistributionShift { current_mean: 0.90, current_sigma: 0.08 })
|
||||
}
|
||||
EmbeddingModel::TextEmbedding3Large => {
|
||||
Some(DistributionShift { current_mean: 0.70, current_sigma: 0.1 })
|
||||
}
|
||||
EmbeddingModel::TextEmbedding3Small => {
|
||||
Some(DistributionShift { current_mean: 0.75, current_sigma: 0.1 })
|
||||
}
|
||||
EmbeddingModel::TextEmbeddingAda002 => Some(DistributionShift {
|
||||
current_mean: OrderedFloat(0.90),
|
||||
current_sigma: OrderedFloat(0.08),
|
||||
}),
|
||||
EmbeddingModel::TextEmbedding3Large => Some(DistributionShift {
|
||||
current_mean: OrderedFloat(0.70),
|
||||
current_sigma: OrderedFloat(0.1),
|
||||
}),
|
||||
EmbeddingModel::TextEmbedding3Small => Some(DistributionShift {
|
||||
current_mean: OrderedFloat(0.75),
|
||||
current_sigma: OrderedFloat(0.1),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
@ -125,178 +147,57 @@ impl EmbedderOptions {
|
||||
}
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new_client(&self) -> Result<reqwest::Client, EmbedError> {
|
||||
reqwest::ClientBuilder::new()
|
||||
.default_headers(self.headers.clone())
|
||||
.build()
|
||||
.map_err(EmbedError::openai_initialize_web_client)
|
||||
}
|
||||
fn infer_api_key() -> String {
|
||||
std::env::var("MEILI_OPENAI_API_KEY")
|
||||
.or_else(|_| std::env::var("OPENAI_API_KEY"))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
tokenizer: tiktoken_rs::CoreBPE,
|
||||
rest_embedder: RestEmbedder,
|
||||
options: EmbedderOptions,
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||
let mut headers = reqwest::header::HeaderMap::new();
|
||||
let mut inferred_api_key = Default::default();
|
||||
let api_key = options.api_key.as_ref().unwrap_or_else(|| {
|
||||
inferred_api_key = infer_api_key();
|
||||
&inferred_api_key
|
||||
});
|
||||
headers.insert(
|
||||
reqwest::header::AUTHORIZATION,
|
||||
reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
|
||||
.map_err(NewEmbedderError::openai_invalid_api_key_format)?,
|
||||
);
|
||||
headers.insert(
|
||||
reqwest::header::CONTENT_TYPE,
|
||||
reqwest::header::HeaderValue::from_static("application/json"),
|
||||
);
|
||||
|
||||
let rest_embedder = RestEmbedder::new(RestEmbedderOptions {
|
||||
api_key: Some(api_key.clone()),
|
||||
distribution: options.embedding_model.distribution(),
|
||||
dimensions: Some(options.dimensions()),
|
||||
url: OPENAI_EMBEDDINGS_URL.to_owned(),
|
||||
query: options.query(),
|
||||
input_field: vec!["input".to_owned()],
|
||||
input_type: crate::vector::rest::InputType::TextArray,
|
||||
path_to_embeddings: vec!["data".to_owned()],
|
||||
embedding_object: vec!["embedding".to_owned()],
|
||||
})?;
|
||||
|
||||
// looking at the code it is very unclear that this can actually fail.
|
||||
let tokenizer = tiktoken_rs::cl100k_base().unwrap();
|
||||
|
||||
Ok(Self { options, headers, tokenizer })
|
||||
Ok(Self { options, rest_embedder, tokenizer })
|
||||
}
|
||||
|
||||
pub async fn embed(
|
||||
&self,
|
||||
texts: Vec<String>,
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let mut tokenized = false;
|
||||
|
||||
for attempt in 0..7 {
|
||||
let result = if tokenized {
|
||||
self.try_embed_tokenized(&texts, client).await
|
||||
} else {
|
||||
self.try_embed(&texts, client).await
|
||||
};
|
||||
|
||||
let retry_duration = match result {
|
||||
Ok(embeddings) => return Ok(embeddings),
|
||||
Err(retry) => {
|
||||
tracing::warn!("Failed: {}", retry.error);
|
||||
tokenized |= retry.must_tokenize();
|
||||
retry.into_duration(attempt)
|
||||
}
|
||||
}?;
|
||||
|
||||
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
|
||||
tracing::warn!(
|
||||
"Attempt #{}, retrying after {}ms.",
|
||||
attempt,
|
||||
retry_duration.as_millis()
|
||||
);
|
||||
tokio::time::sleep(retry_duration).await;
|
||||
}
|
||||
|
||||
let result = if tokenized {
|
||||
self.try_embed_tokenized(&texts, client).await
|
||||
} else {
|
||||
self.try_embed(&texts, client).await
|
||||
};
|
||||
|
||||
result.map_err(Retry::into_error)
|
||||
}
|
||||
|
||||
async fn check_response(response: reqwest::Response) -> Result<reqwest::Response, Retry> {
|
||||
if !response.status().is_success() {
|
||||
match response.status() {
|
||||
StatusCode::UNAUTHORIZED => {
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
return Err(Retry::give_up(EmbedError::openai_auth_error(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
StatusCode::TOO_MANY_REQUESTS => {
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
return Err(Retry::rate_limited(EmbedError::openai_too_many_requests(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
| StatusCode::BAD_GATEWAY
|
||||
| StatusCode::SERVICE_UNAVAILABLE => {
|
||||
let error_response: Result<OpenAiErrorResponse, _> = response.json().await;
|
||||
return Err(Retry::retry_later(EmbedError::openai_internal_server_error(
|
||||
error_response.ok().map(|error_response| error_response.error),
|
||||
)));
|
||||
}
|
||||
StatusCode::BAD_REQUEST => {
|
||||
// Most probably, one text contained too many tokens
|
||||
let error_response: OpenAiErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
tracing::warn!("OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your prompt.");
|
||||
|
||||
return Err(Retry::retry_tokenized(EmbedError::openai_too_many_tokens(
|
||||
error_response.error,
|
||||
)));
|
||||
}
|
||||
code => {
|
||||
return Err(Retry::retry_later(EmbedError::openai_unhandled_status_code(
|
||||
code.as_u16(),
|
||||
)));
|
||||
}
|
||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
match self.rest_embedder.embed_ref(&texts) {
|
||||
Ok(embeddings) => Ok(embeddings),
|
||||
Err(EmbedError { kind: EmbedErrorKind::RestBadRequest(error), fault: _ }) => {
|
||||
tracing::warn!(error=?error, "OpenAI: received `BAD_REQUEST`. Input was maybe too long, retrying on tokenized version. For best performance, limit the size of your document template.");
|
||||
self.try_embed_tokenized(&texts)
|
||||
}
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn try_embed<S: AsRef<str> + serde::Serialize>(
|
||||
&self,
|
||||
texts: &[S],
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
||||
for text in texts {
|
||||
tracing::trace!("Received prompt: {}", text.as_ref())
|
||||
}
|
||||
let request = OpenAiRequest {
|
||||
model: self.options.embedding_model.name(),
|
||||
input: texts,
|
||||
dimensions: self.overriden_dimensions(),
|
||||
};
|
||||
let response = client
|
||||
.post(OPENAI_EMBEDDINGS_URL)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(EmbedError::openai_network)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
let response = Self::check_response(response).await?;
|
||||
|
||||
let response: OpenAiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
tracing::trace!("response: {:?}", response.data);
|
||||
|
||||
Ok(response
|
||||
.data
|
||||
.into_iter()
|
||||
.map(|data| Embeddings::from_single_embedding(data.embedding))
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn try_embed_tokenized(
|
||||
&self,
|
||||
text: &[String],
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Vec<Embeddings<f32>>, Retry> {
|
||||
fn try_embed_tokenized(&self, text: &[String]) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
pub const OVERLAP_SIZE: usize = 200;
|
||||
let mut all_embeddings = Vec::with_capacity(text.len());
|
||||
for text in text {
|
||||
@ -304,7 +205,7 @@ impl Embedder {
|
||||
let encoded = self.tokenizer.encode_ordinary(text.as_str());
|
||||
let len = encoded.len();
|
||||
if len < max_token_count {
|
||||
all_embeddings.append(&mut self.try_embed(&[text], client).await?);
|
||||
all_embeddings.append(&mut self.rest_embedder.embed_ref(&[text])?);
|
||||
continue;
|
||||
}
|
||||
|
||||
@ -312,215 +213,49 @@ impl Embedder {
|
||||
let mut embeddings_for_prompt = Embeddings::new(self.dimensions());
|
||||
while tokens.len() > max_token_count {
|
||||
let window = &tokens[..max_token_count];
|
||||
embeddings_for_prompt.push(self.embed_tokens(window, client).await?).unwrap();
|
||||
let embedding = self.rest_embedder.embed_tokens(window)?;
|
||||
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
|
||||
EmbedError::openai_unexpected_dimension(self.dimensions(), got.len())
|
||||
})?;
|
||||
|
||||
tokens = &tokens[max_token_count - OVERLAP_SIZE..];
|
||||
}
|
||||
|
||||
// end of text
|
||||
embeddings_for_prompt.push(self.embed_tokens(tokens, client).await?).unwrap();
|
||||
let embedding = self.rest_embedder.embed_tokens(tokens)?;
|
||||
|
||||
embeddings_for_prompt.append(embedding.into_inner()).map_err(|got| {
|
||||
EmbedError::openai_unexpected_dimension(self.dimensions(), got.len())
|
||||
})?;
|
||||
|
||||
all_embeddings.push(embeddings_for_prompt);
|
||||
}
|
||||
Ok(all_embeddings)
|
||||
}
|
||||
|
||||
async fn embed_tokens(
|
||||
&self,
|
||||
tokens: &[usize],
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Embedding, Retry> {
|
||||
for attempt in 0..9 {
|
||||
let duration = match self.try_embed_tokens(tokens, client).await {
|
||||
Ok(embedding) => return Ok(embedding),
|
||||
Err(retry) => retry.into_duration(attempt),
|
||||
}
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
tokio::time::sleep(duration).await;
|
||||
}
|
||||
|
||||
self.try_embed_tokens(tokens, client)
|
||||
.await
|
||||
.map_err(|retry| Retry::give_up(retry.into_error()))
|
||||
}
|
||||
|
||||
async fn try_embed_tokens(
|
||||
&self,
|
||||
tokens: &[usize],
|
||||
client: &reqwest::Client,
|
||||
) -> Result<Embedding, Retry> {
|
||||
let request = OpenAiTokensRequest {
|
||||
model: self.options.embedding_model.name(),
|
||||
input: tokens,
|
||||
dimensions: self.overriden_dimensions(),
|
||||
};
|
||||
let response = client
|
||||
.post(OPENAI_EMBEDDINGS_URL)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(EmbedError::openai_network)
|
||||
.map_err(Retry::retry_later)?;
|
||||
|
||||
let response = Self::check_response(response).await?;
|
||||
|
||||
let mut response: OpenAiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(EmbedError::openai_unexpected)
|
||||
.map_err(Retry::retry_later)?;
|
||||
Ok(response.data.pop().map(|data| data.embedding).unwrap_or_default())
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
threads: &rayon::ThreadPool,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_io()
|
||||
.enable_time()
|
||||
.build()
|
||||
.map_err(EmbedError::openai_runtime_init)?;
|
||||
let client = self.new_client()?;
|
||||
rt.block_on(futures::future::try_join_all(
|
||||
text_chunks.into_iter().map(|prompts| self.embed(prompts, &client)),
|
||||
))
|
||||
threads.install(move || {
|
||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
10
|
||||
self.rest_embedder.chunk_count_hint()
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
10
|
||||
self.rest_embedder.prompt_count_in_chunk_hint()
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
if self.options.embedding_model.supports_overriding_dimensions() {
|
||||
self.options.dimensions.unwrap_or(self.options.embedding_model.default_dimensions())
|
||||
} else {
|
||||
self.options.embedding_model.default_dimensions()
|
||||
}
|
||||
self.options.dimensions()
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
self.options.embedding_model.distribution()
|
||||
}
|
||||
|
||||
fn overriden_dimensions(&self) -> Option<usize> {
|
||||
if self.options.embedding_model.supports_overriding_dimensions() {
|
||||
self.options.dimensions
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// retrying in case of failure
|
||||
|
||||
struct Retry {
|
||||
error: EmbedError,
|
||||
strategy: RetryStrategy,
|
||||
}
|
||||
|
||||
enum RetryStrategy {
|
||||
GiveUp,
|
||||
Retry,
|
||||
RetryTokenized,
|
||||
RetryAfterRateLimit,
|
||||
}
|
||||
|
||||
impl Retry {
|
||||
fn give_up(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::GiveUp }
|
||||
}
|
||||
|
||||
fn retry_later(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::Retry }
|
||||
}
|
||||
|
||||
fn retry_tokenized(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::RetryTokenized }
|
||||
}
|
||||
|
||||
fn rate_limited(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
|
||||
}
|
||||
|
||||
fn into_duration(self, attempt: u32) -> Result<tokio::time::Duration, EmbedError> {
|
||||
match self.strategy {
|
||||
RetryStrategy::GiveUp => Err(self.error),
|
||||
RetryStrategy::Retry => Ok(tokio::time::Duration::from_millis((10u64).pow(attempt))),
|
||||
RetryStrategy::RetryTokenized => Ok(tokio::time::Duration::from_millis(1)),
|
||||
RetryStrategy::RetryAfterRateLimit => {
|
||||
Ok(tokio::time::Duration::from_millis(100 + 10u64.pow(attempt)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn must_tokenize(&self) -> bool {
|
||||
matches!(self.strategy, RetryStrategy::RetryTokenized)
|
||||
}
|
||||
|
||||
fn into_error(self) -> EmbedError {
|
||||
self.error
|
||||
}
|
||||
}
|
||||
|
||||
// openai api structs
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct OpenAiRequest<'a, S: AsRef<str> + serde::Serialize> {
|
||||
model: &'a str,
|
||||
input: &'a [S],
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
dimensions: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct OpenAiTokensRequest<'a> {
|
||||
model: &'a str,
|
||||
input: &'a [usize],
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
dimensions: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiResponse {
|
||||
data: Vec<OpenAiEmbedding>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiErrorResponse {
|
||||
error: OpenAiError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct OpenAiError {
|
||||
message: String,
|
||||
// type: String,
|
||||
code: Option<String>,
|
||||
}
|
||||
|
||||
impl Display for OpenAiError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match &self.code {
|
||||
Some(code) => write!(f, "{} ({})", self.message, code),
|
||||
None => write!(f, "{}", self.message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiEmbedding {
|
||||
embedding: Embedding,
|
||||
// object: String,
|
||||
// index: usize,
|
||||
}
|
||||
|
||||
fn infer_api_key() -> String {
|
||||
std::env::var("MEILI_OPENAI_API_KEY")
|
||||
.or_else(|_| std::env::var("OPENAI_API_KEY"))
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
373
milli/src/vector/rest.rs
Normal file
373
milli/src/vector/rest.rs
Normal file
@ -0,0 +1,373 @@
|
||||
use deserr::Deserr;
|
||||
use rayon::iter::{IntoParallelIterator as _, ParallelIterator as _};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::{
|
||||
DistributionShift, EmbedError, Embedding, Embeddings, NewEmbedderError, REQUEST_PARALLELISM,
|
||||
};
|
||||
|
||||
// retrying in case of failure
|
||||
|
||||
pub struct Retry {
|
||||
pub error: EmbedError,
|
||||
strategy: RetryStrategy,
|
||||
}
|
||||
|
||||
pub enum RetryStrategy {
|
||||
GiveUp,
|
||||
Retry,
|
||||
RetryTokenized,
|
||||
RetryAfterRateLimit,
|
||||
}
|
||||
|
||||
impl Retry {
|
||||
pub fn give_up(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::GiveUp }
|
||||
}
|
||||
|
||||
pub fn retry_later(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::Retry }
|
||||
}
|
||||
|
||||
pub fn retry_tokenized(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::RetryTokenized }
|
||||
}
|
||||
|
||||
pub fn rate_limited(error: EmbedError) -> Self {
|
||||
Self { error, strategy: RetryStrategy::RetryAfterRateLimit }
|
||||
}
|
||||
|
||||
pub fn into_duration(self, attempt: u32) -> Result<std::time::Duration, EmbedError> {
|
||||
match self.strategy {
|
||||
RetryStrategy::GiveUp => Err(self.error),
|
||||
RetryStrategy::Retry => Ok(std::time::Duration::from_millis((10u64).pow(attempt))),
|
||||
RetryStrategy::RetryTokenized => Ok(std::time::Duration::from_millis(1)),
|
||||
RetryStrategy::RetryAfterRateLimit => {
|
||||
Ok(std::time::Duration::from_millis(100 + 10u64.pow(attempt)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn must_tokenize(&self) -> bool {
|
||||
matches!(self.strategy, RetryStrategy::RetryTokenized)
|
||||
}
|
||||
|
||||
pub fn into_error(self) -> EmbedError {
|
||||
self.error
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Embedder {
|
||||
client: ureq::Agent,
|
||||
options: EmbedderOptions,
|
||||
bearer: Option<String>,
|
||||
dimensions: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
|
||||
pub struct EmbedderOptions {
|
||||
pub api_key: Option<String>,
|
||||
pub distribution: Option<DistributionShift>,
|
||||
pub dimensions: Option<usize>,
|
||||
pub url: String,
|
||||
pub query: serde_json::Value,
|
||||
pub input_field: Vec<String>,
|
||||
// path to the array of embeddings
|
||||
pub path_to_embeddings: Vec<String>,
|
||||
// shape of a single embedding
|
||||
pub embedding_object: Vec<String>,
|
||||
pub input_type: InputType,
|
||||
}
|
||||
|
||||
impl Default for EmbedderOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
url: Default::default(),
|
||||
query: Default::default(),
|
||||
input_field: vec!["input".into()],
|
||||
path_to_embeddings: vec!["data".into()],
|
||||
embedding_object: vec!["embedding".into()],
|
||||
input_type: InputType::Text,
|
||||
api_key: None,
|
||||
distribution: None,
|
||||
dimensions: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::hash::Hash for EmbedderOptions {
|
||||
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
|
||||
self.api_key.hash(state);
|
||||
self.distribution.hash(state);
|
||||
self.dimensions.hash(state);
|
||||
self.url.hash(state);
|
||||
// skip hashing the query
|
||||
// collisions in regular usage should be minimal,
|
||||
// and the list is limited to 256 values anyway
|
||||
self.input_field.hash(state);
|
||||
self.path_to_embeddings.hash(state);
|
||||
self.embedding_object.hash(state);
|
||||
self.input_type.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Deserialize, Serialize, PartialEq, Eq, Hash, Deserr)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[deserr(rename_all = camelCase, deny_unknown_fields)]
|
||||
pub enum InputType {
|
||||
Text,
|
||||
TextArray,
|
||||
}
|
||||
|
||||
impl Embedder {
|
||||
pub fn new(options: EmbedderOptions) -> Result<Self, NewEmbedderError> {
|
||||
let bearer = options.api_key.as_deref().map(|api_key| format!("Bearer {api_key}"));
|
||||
|
||||
let client = ureq::AgentBuilder::new()
|
||||
.max_idle_connections(REQUEST_PARALLELISM * 2)
|
||||
.max_idle_connections_per_host(REQUEST_PARALLELISM * 2)
|
||||
.build();
|
||||
|
||||
let dimensions = if let Some(dimensions) = options.dimensions {
|
||||
dimensions
|
||||
} else {
|
||||
infer_dimensions(&client, &options, bearer.as_deref())?
|
||||
};
|
||||
|
||||
Ok(Self { client, dimensions, options, bearer })
|
||||
}
|
||||
|
||||
pub fn embed(&self, texts: Vec<String>) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
embed(&self.client, &self.options, self.bearer.as_deref(), texts.as_slice(), texts.len())
|
||||
}
|
||||
|
||||
pub fn embed_ref<S>(&self, texts: &[S]) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||
where
|
||||
S: AsRef<str> + Serialize,
|
||||
{
|
||||
embed(&self.client, &self.options, self.bearer.as_deref(), texts, texts.len())
|
||||
}
|
||||
|
||||
pub fn embed_tokens(&self, tokens: &[usize]) -> Result<Embeddings<f32>, EmbedError> {
|
||||
let mut embeddings = embed(&self.client, &self.options, self.bearer.as_deref(), tokens, 1)?;
|
||||
// unwrap: guaranteed that embeddings.len() == 1, otherwise the previous line terminated in error
|
||||
Ok(embeddings.pop().unwrap())
|
||||
}
|
||||
|
||||
pub fn embed_chunks(
|
||||
&self,
|
||||
text_chunks: Vec<Vec<String>>,
|
||||
threads: &rayon::ThreadPool,
|
||||
) -> Result<Vec<Vec<Embeddings<f32>>>, EmbedError> {
|
||||
threads.install(move || {
|
||||
text_chunks.into_par_iter().map(move |chunk| self.embed(chunk)).collect()
|
||||
})
|
||||
}
|
||||
|
||||
pub fn chunk_count_hint(&self) -> usize {
|
||||
super::REQUEST_PARALLELISM
|
||||
}
|
||||
|
||||
pub fn prompt_count_in_chunk_hint(&self) -> usize {
|
||||
match self.options.input_type {
|
||||
InputType::Text => 1,
|
||||
InputType::TextArray => 10,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dimensions(&self) -> usize {
|
||||
self.dimensions
|
||||
}
|
||||
|
||||
pub fn distribution(&self) -> Option<DistributionShift> {
|
||||
self.options.distribution
|
||||
}
|
||||
}
|
||||
|
||||
fn infer_dimensions(
|
||||
client: &ureq::Agent,
|
||||
options: &EmbedderOptions,
|
||||
bearer: Option<&str>,
|
||||
) -> Result<usize, NewEmbedderError> {
|
||||
let v = embed(client, options, bearer, ["test"].as_slice(), 1)
|
||||
.map_err(NewEmbedderError::could_not_determine_dimension)?;
|
||||
// unwrap: guaranteed that v.len() == 1, otherwise the previous line terminated in error
|
||||
Ok(v.first().unwrap().dimension())
|
||||
}
|
||||
|
||||
fn embed<S>(
|
||||
client: &ureq::Agent,
|
||||
options: &EmbedderOptions,
|
||||
bearer: Option<&str>,
|
||||
inputs: &[S],
|
||||
expected_count: usize,
|
||||
) -> Result<Vec<Embeddings<f32>>, EmbedError>
|
||||
where
|
||||
S: Serialize,
|
||||
{
|
||||
let request = client.post(&options.url);
|
||||
let request =
|
||||
if let Some(bearer) = bearer { request.set("Authorization", bearer) } else { request };
|
||||
let request = request.set("Content-Type", "application/json");
|
||||
|
||||
let input_value = match options.input_type {
|
||||
InputType::Text => serde_json::json!(inputs.first()),
|
||||
InputType::TextArray => serde_json::json!(inputs),
|
||||
};
|
||||
|
||||
let body = match options.input_field.as_slice() {
|
||||
[] => {
|
||||
// inject input in body
|
||||
input_value
|
||||
}
|
||||
[input] => {
|
||||
let mut body = options.query.clone();
|
||||
|
||||
body.as_object_mut()
|
||||
.ok_or_else(|| {
|
||||
EmbedError::rest_not_an_object(
|
||||
options.query.clone(),
|
||||
options.input_field.clone(),
|
||||
)
|
||||
})?
|
||||
.insert(input.clone(), input_value);
|
||||
body
|
||||
}
|
||||
[path @ .., input] => {
|
||||
let mut body = options.query.clone();
|
||||
|
||||
let mut current_value = &mut body;
|
||||
for component in path {
|
||||
current_value = current_value
|
||||
.as_object_mut()
|
||||
.ok_or_else(|| {
|
||||
EmbedError::rest_not_an_object(
|
||||
options.query.clone(),
|
||||
options.input_field.clone(),
|
||||
)
|
||||
})?
|
||||
.entry(component.clone())
|
||||
.or_insert(serde_json::json!({}));
|
||||
}
|
||||
|
||||
current_value.as_object_mut().unwrap().insert(input.clone(), input_value);
|
||||
body
|
||||
}
|
||||
};
|
||||
|
||||
for attempt in 0..7 {
|
||||
let response = request.clone().send_json(&body);
|
||||
let result = check_response(response);
|
||||
|
||||
let retry_duration = match result {
|
||||
Ok(response) => return response_to_embedding(response, options, expected_count),
|
||||
Err(retry) => {
|
||||
tracing::warn!("Failed: {}", retry.error);
|
||||
retry.into_duration(attempt)
|
||||
}
|
||||
}?;
|
||||
|
||||
let retry_duration = retry_duration.min(std::time::Duration::from_secs(60)); // don't wait more than a minute
|
||||
tracing::warn!("Attempt #{}, retrying after {}ms.", attempt, retry_duration.as_millis());
|
||||
std::thread::sleep(retry_duration);
|
||||
}
|
||||
|
||||
let response = request.send_json(&body);
|
||||
let result = check_response(response);
|
||||
result
|
||||
.map_err(Retry::into_error)
|
||||
.and_then(|response| response_to_embedding(response, options, expected_count))
|
||||
}
|
||||
|
||||
fn check_response(response: Result<ureq::Response, ureq::Error>) -> Result<ureq::Response, Retry> {
|
||||
match response {
|
||||
Ok(response) => Ok(response),
|
||||
Err(ureq::Error::Status(code, response)) => {
|
||||
let error_response: Option<String> = response.into_string().ok();
|
||||
Err(match code {
|
||||
401 => Retry::give_up(EmbedError::rest_unauthorized(error_response)),
|
||||
429 => Retry::rate_limited(EmbedError::rest_too_many_requests(error_response)),
|
||||
400 => Retry::give_up(EmbedError::rest_bad_request(error_response)),
|
||||
500..=599 => {
|
||||
Retry::retry_later(EmbedError::rest_internal_server_error(code, error_response))
|
||||
}
|
||||
402..=499 => {
|
||||
Retry::give_up(EmbedError::rest_other_status_code(code, error_response))
|
||||
}
|
||||
_ => Retry::retry_later(EmbedError::rest_other_status_code(code, error_response)),
|
||||
})
|
||||
}
|
||||
Err(ureq::Error::Transport(transport)) => {
|
||||
Err(Retry::retry_later(EmbedError::rest_network(transport)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn response_to_embedding(
|
||||
response: ureq::Response,
|
||||
options: &EmbedderOptions,
|
||||
expected_count: usize,
|
||||
) -> Result<Vec<Embeddings<f32>>, EmbedError> {
|
||||
let response: serde_json::Value =
|
||||
response.into_json().map_err(EmbedError::rest_response_deserialization)?;
|
||||
|
||||
let mut current_value = &response;
|
||||
for component in &options.path_to_embeddings {
|
||||
let component = component.as_ref();
|
||||
current_value = current_value.get(component).ok_or_else(|| {
|
||||
EmbedError::rest_response_missing_embeddings(
|
||||
response.clone(),
|
||||
component,
|
||||
&options.path_to_embeddings,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
let embeddings = match options.input_type {
|
||||
InputType::Text => {
|
||||
for component in &options.embedding_object {
|
||||
current_value = current_value.get(component).ok_or_else(|| {
|
||||
EmbedError::rest_response_missing_embeddings(
|
||||
response.clone(),
|
||||
component,
|
||||
&options.embedding_object,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
let embeddings = current_value.to_owned();
|
||||
let embeddings: Embedding =
|
||||
serde_json::from_value(embeddings).map_err(EmbedError::rest_response_format)?;
|
||||
|
||||
vec![Embeddings::from_single_embedding(embeddings)]
|
||||
}
|
||||
InputType::TextArray => {
|
||||
let empty = vec![];
|
||||
let values = current_value.as_array().unwrap_or(&empty);
|
||||
let mut embeddings: Vec<Embeddings<f32>> = Vec::with_capacity(expected_count);
|
||||
for value in values {
|
||||
let mut current_value = value;
|
||||
for component in &options.embedding_object {
|
||||
current_value = current_value.get(component).ok_or_else(|| {
|
||||
EmbedError::rest_response_missing_embeddings(
|
||||
response.clone(),
|
||||
component,
|
||||
&options.embedding_object,
|
||||
)
|
||||
})?;
|
||||
}
|
||||
let embedding = current_value.to_owned();
|
||||
let embedding: Embedding =
|
||||
serde_json::from_value(embedding).map_err(EmbedError::rest_response_format)?;
|
||||
embeddings.push(Embeddings::from_single_embedding(embedding));
|
||||
}
|
||||
embeddings
|
||||
}
|
||||
};
|
||||
|
||||
if embeddings.len() != expected_count {
|
||||
return Err(EmbedError::rest_response_embedding_count(expected_count, embeddings.len()));
|
||||
}
|
||||
|
||||
Ok(embeddings)
|
||||
}
|
@ -1,7 +1,8 @@
|
||||
use deserr::Deserr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::openai;
|
||||
use super::rest::InputType;
|
||||
use super::{ollama, openai};
|
||||
use crate::prompt::PromptData;
|
||||
use crate::update::Setting;
|
||||
use crate::vector::EmbeddingConfig;
|
||||
@ -29,6 +30,24 @@ pub struct EmbeddingSettings {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub document_template: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub url: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub query: Setting<serde_json::Value>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub input_field: Setting<Vec<String>>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub path_to_embeddings: Setting<Vec<String>>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub embedding_object: Setting<Vec<String>>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
#[deserr(default)]
|
||||
pub input_type: Setting<InputType>,
|
||||
}
|
||||
|
||||
pub fn check_unset<T>(
|
||||
@ -75,16 +94,42 @@ impl EmbeddingSettings {
|
||||
pub const DIMENSIONS: &'static str = "dimensions";
|
||||
pub const DOCUMENT_TEMPLATE: &'static str = "documentTemplate";
|
||||
|
||||
pub const URL: &'static str = "url";
|
||||
pub const QUERY: &'static str = "query";
|
||||
pub const INPUT_FIELD: &'static str = "inputField";
|
||||
pub const PATH_TO_EMBEDDINGS: &'static str = "pathToEmbeddings";
|
||||
pub const EMBEDDING_OBJECT: &'static str = "embeddingObject";
|
||||
pub const INPUT_TYPE: &'static str = "inputType";
|
||||
|
||||
pub fn allowed_sources_for_field(field: &'static str) -> &'static [EmbedderSource] {
|
||||
match field {
|
||||
Self::SOURCE => {
|
||||
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::UserProvided]
|
||||
Self::SOURCE => &[
|
||||
EmbedderSource::HuggingFace,
|
||||
EmbedderSource::OpenAi,
|
||||
EmbedderSource::UserProvided,
|
||||
EmbedderSource::Rest,
|
||||
EmbedderSource::Ollama,
|
||||
],
|
||||
Self::MODEL => {
|
||||
&[EmbedderSource::HuggingFace, EmbedderSource::OpenAi, EmbedderSource::Ollama]
|
||||
}
|
||||
Self::MODEL => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
|
||||
Self::REVISION => &[EmbedderSource::HuggingFace],
|
||||
Self::API_KEY => &[EmbedderSource::OpenAi],
|
||||
Self::DIMENSIONS => &[EmbedderSource::OpenAi, EmbedderSource::UserProvided],
|
||||
Self::DOCUMENT_TEMPLATE => &[EmbedderSource::HuggingFace, EmbedderSource::OpenAi],
|
||||
Self::API_KEY => &[EmbedderSource::OpenAi, EmbedderSource::Rest],
|
||||
Self::DIMENSIONS => {
|
||||
&[EmbedderSource::OpenAi, EmbedderSource::UserProvided, EmbedderSource::Rest]
|
||||
}
|
||||
Self::DOCUMENT_TEMPLATE => &[
|
||||
EmbedderSource::HuggingFace,
|
||||
EmbedderSource::OpenAi,
|
||||
EmbedderSource::Ollama,
|
||||
EmbedderSource::Rest,
|
||||
],
|
||||
Self::URL => &[EmbedderSource::Rest],
|
||||
Self::QUERY => &[EmbedderSource::Rest],
|
||||
Self::INPUT_FIELD => &[EmbedderSource::Rest],
|
||||
Self::PATH_TO_EMBEDDINGS => &[EmbedderSource::Rest],
|
||||
Self::EMBEDDING_OBJECT => &[EmbedderSource::Rest],
|
||||
Self::INPUT_TYPE => &[EmbedderSource::Rest],
|
||||
_other => unreachable!("unknown field"),
|
||||
}
|
||||
}
|
||||
@ -101,7 +146,20 @@ impl EmbeddingSettings {
|
||||
EmbedderSource::HuggingFace => {
|
||||
&[Self::SOURCE, Self::MODEL, Self::REVISION, Self::DOCUMENT_TEMPLATE]
|
||||
}
|
||||
EmbedderSource::Ollama => &[Self::SOURCE, Self::MODEL, Self::DOCUMENT_TEMPLATE],
|
||||
EmbedderSource::UserProvided => &[Self::SOURCE, Self::DIMENSIONS],
|
||||
EmbedderSource::Rest => &[
|
||||
Self::SOURCE,
|
||||
Self::API_KEY,
|
||||
Self::DIMENSIONS,
|
||||
Self::DOCUMENT_TEMPLATE,
|
||||
Self::URL,
|
||||
Self::QUERY,
|
||||
Self::INPUT_FIELD,
|
||||
Self::PATH_TO_EMBEDDINGS,
|
||||
Self::EMBEDDING_OBJECT,
|
||||
Self::INPUT_TYPE,
|
||||
],
|
||||
}
|
||||
}
|
||||
|
||||
@ -134,7 +192,9 @@ pub enum EmbedderSource {
|
||||
#[default]
|
||||
OpenAi,
|
||||
HuggingFace,
|
||||
Ollama,
|
||||
UserProvided,
|
||||
Rest,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EmbedderSource {
|
||||
@ -143,6 +203,8 @@ impl std::fmt::Display for EmbedderSource {
|
||||
EmbedderSource::OpenAi => "openAi",
|
||||
EmbedderSource::HuggingFace => "huggingFace",
|
||||
EmbedderSource::UserProvided => "userProvided",
|
||||
EmbedderSource::Ollama => "ollama",
|
||||
EmbedderSource::Rest => "rest",
|
||||
};
|
||||
f.write_str(s)
|
||||
}
|
||||
@ -150,8 +212,20 @@ impl std::fmt::Display for EmbedderSource {
|
||||
|
||||
impl EmbeddingSettings {
|
||||
pub fn apply(&mut self, new: Self) {
|
||||
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
|
||||
new;
|
||||
let EmbeddingSettings {
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
url,
|
||||
query,
|
||||
input_field,
|
||||
path_to_embeddings,
|
||||
embedding_object,
|
||||
input_type,
|
||||
} = new;
|
||||
let old_source = self.source;
|
||||
self.source.apply(source);
|
||||
// Reinitialize the whole setting object on a source change
|
||||
@ -163,6 +237,12 @@ impl EmbeddingSettings {
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
url,
|
||||
query,
|
||||
input_field,
|
||||
path_to_embeddings,
|
||||
embedding_object,
|
||||
input_type,
|
||||
};
|
||||
return;
|
||||
}
|
||||
@ -172,6 +252,13 @@ impl EmbeddingSettings {
|
||||
self.api_key.apply(api_key);
|
||||
self.dimensions.apply(dimensions);
|
||||
self.document_template.apply(document_template);
|
||||
|
||||
self.url.apply(url);
|
||||
self.query.apply(query);
|
||||
self.input_field.apply(input_field);
|
||||
self.path_to_embeddings.apply(path_to_embeddings);
|
||||
self.embedding_object.apply(embedding_object);
|
||||
self.input_type.apply(input_type);
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,6 +273,12 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
api_key: Setting::NotSet,
|
||||
dimensions: Setting::NotSet,
|
||||
document_template: Setting::Set(prompt.template),
|
||||
url: Setting::NotSet,
|
||||
query: Setting::NotSet,
|
||||
input_field: Setting::NotSet,
|
||||
path_to_embeddings: Setting::NotSet,
|
||||
embedding_object: Setting::NotSet,
|
||||
input_type: Setting::NotSet,
|
||||
},
|
||||
super::EmbedderOptions::OpenAi(options) => Self {
|
||||
source: Setting::Set(EmbedderSource::OpenAi),
|
||||
@ -194,6 +287,26 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
api_key: options.api_key.map(Setting::Set).unwrap_or_default(),
|
||||
dimensions: options.dimensions.map(Setting::Set).unwrap_or_default(),
|
||||
document_template: Setting::Set(prompt.template),
|
||||
url: Setting::NotSet,
|
||||
query: Setting::NotSet,
|
||||
input_field: Setting::NotSet,
|
||||
path_to_embeddings: Setting::NotSet,
|
||||
embedding_object: Setting::NotSet,
|
||||
input_type: Setting::NotSet,
|
||||
},
|
||||
super::EmbedderOptions::Ollama(options) => Self {
|
||||
source: Setting::Set(EmbedderSource::Ollama),
|
||||
model: Setting::Set(options.embedding_model.to_owned()),
|
||||
revision: Setting::NotSet,
|
||||
api_key: Setting::NotSet,
|
||||
dimensions: Setting::NotSet,
|
||||
document_template: Setting::Set(prompt.template),
|
||||
url: Setting::NotSet,
|
||||
query: Setting::NotSet,
|
||||
input_field: Setting::NotSet,
|
||||
path_to_embeddings: Setting::NotSet,
|
||||
embedding_object: Setting::NotSet,
|
||||
input_type: Setting::NotSet,
|
||||
},
|
||||
super::EmbedderOptions::UserProvided(options) => Self {
|
||||
source: Setting::Set(EmbedderSource::UserProvided),
|
||||
@ -202,6 +315,37 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
api_key: Setting::NotSet,
|
||||
dimensions: Setting::Set(options.dimensions),
|
||||
document_template: Setting::NotSet,
|
||||
url: Setting::NotSet,
|
||||
query: Setting::NotSet,
|
||||
input_field: Setting::NotSet,
|
||||
path_to_embeddings: Setting::NotSet,
|
||||
embedding_object: Setting::NotSet,
|
||||
input_type: Setting::NotSet,
|
||||
},
|
||||
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
|
||||
api_key,
|
||||
// TODO: support distribution
|
||||
distribution: _,
|
||||
dimensions,
|
||||
url,
|
||||
query,
|
||||
input_field,
|
||||
path_to_embeddings,
|
||||
embedding_object,
|
||||
input_type,
|
||||
}) => Self {
|
||||
source: Setting::Set(EmbedderSource::Rest),
|
||||
model: Setting::NotSet,
|
||||
revision: Setting::NotSet,
|
||||
api_key: api_key.map(Setting::Set).unwrap_or_default(),
|
||||
dimensions: dimensions.map(Setting::Set).unwrap_or_default(),
|
||||
document_template: Setting::Set(prompt.template),
|
||||
url: Setting::Set(url),
|
||||
query: Setting::Set(query),
|
||||
input_field: Setting::Set(input_field),
|
||||
path_to_embeddings: Setting::Set(path_to_embeddings),
|
||||
embedding_object: Setting::Set(embedding_object),
|
||||
input_type: Setting::Set(input_type),
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -210,8 +354,20 @@ impl From<EmbeddingConfig> for EmbeddingSettings {
|
||||
impl From<EmbeddingSettings> for EmbeddingConfig {
|
||||
fn from(value: EmbeddingSettings) -> Self {
|
||||
let mut this = Self::default();
|
||||
let EmbeddingSettings { source, model, revision, api_key, dimensions, document_template } =
|
||||
value;
|
||||
let EmbeddingSettings {
|
||||
source,
|
||||
model,
|
||||
revision,
|
||||
api_key,
|
||||
dimensions,
|
||||
document_template,
|
||||
url,
|
||||
query,
|
||||
input_field,
|
||||
path_to_embeddings,
|
||||
embedding_object,
|
||||
input_type,
|
||||
} = value;
|
||||
if let Some(source) = source.set() {
|
||||
match source {
|
||||
EmbedderSource::OpenAi => {
|
||||
@ -229,6 +385,14 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
|
||||
}
|
||||
this.embedder_options = super::EmbedderOptions::OpenAi(options);
|
||||
}
|
||||
EmbedderSource::Ollama => {
|
||||
let mut options: ollama::EmbedderOptions =
|
||||
super::ollama::EmbedderOptions::with_default_model();
|
||||
if let Some(model) = model.set() {
|
||||
options.embedding_model = model;
|
||||
}
|
||||
this.embedder_options = super::EmbedderOptions::Ollama(options);
|
||||
}
|
||||
EmbedderSource::HuggingFace => {
|
||||
let mut options = super::hf::EmbedderOptions::default();
|
||||
if let Some(model) = model.set() {
|
||||
@ -251,6 +415,26 @@ impl From<EmbeddingSettings> for EmbeddingConfig {
|
||||
dimensions: dimensions.set().unwrap(),
|
||||
});
|
||||
}
|
||||
EmbedderSource::Rest => {
|
||||
let embedder_options = super::rest::EmbedderOptions::default();
|
||||
|
||||
this.embedder_options =
|
||||
super::EmbedderOptions::Rest(super::rest::EmbedderOptions {
|
||||
api_key: api_key.set(),
|
||||
distribution: None,
|
||||
dimensions: dimensions.set(),
|
||||
url: url.set().unwrap(),
|
||||
query: query.set().unwrap_or(embedder_options.query),
|
||||
input_field: input_field.set().unwrap_or(embedder_options.input_field),
|
||||
path_to_embeddings: path_to_embeddings
|
||||
.set()
|
||||
.unwrap_or(embedder_options.path_to_embeddings),
|
||||
embedding_object: embedding_object
|
||||
.set()
|
||||
.unwrap_or(embedder_options.embedding_object),
|
||||
input_type: input_type.set().unwrap_or(embedder_options.input_type),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
94
workloads/settings-add-remove-filters.json
Normal file
94
workloads/settings-add-remove-filters.json
Normal file
@ -0,0 +1,94 @@
|
||||
{
|
||||
"name": "settings-add-remove-filters.json",
|
||||
"run_count": 2,
|
||||
"extra_cli_args": [
|
||||
"--max-indexing-threads=4"
|
||||
],
|
||||
"assets": {
|
||||
"150k-people.json": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/150k-people.json",
|
||||
"sha256": "28c359a0956958af0ba204ec11bad3045a0864a10b4838914fea25a01724f84b"
|
||||
}
|
||||
},
|
||||
"commands": [
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"searchableAttributes": [
|
||||
"last_name",
|
||||
"first_name",
|
||||
"featured_job_organization_name",
|
||||
"facebook_url",
|
||||
"twitter_url",
|
||||
"linkedin_url"
|
||||
],
|
||||
"filterableAttributes": [
|
||||
"city",
|
||||
"region",
|
||||
"country_code"
|
||||
],
|
||||
"dictionary": [
|
||||
"https://",
|
||||
"http://",
|
||||
"www.",
|
||||
"crunchbase.com",
|
||||
"facebook.com",
|
||||
"twitter.com",
|
||||
"linkedin.com"
|
||||
],
|
||||
"stopWords": [
|
||||
"https://",
|
||||
"http://",
|
||||
"www.",
|
||||
"crunchbase.com",
|
||||
"facebook.com",
|
||||
"twitter.com",
|
||||
"linkedin.com"
|
||||
]
|
||||
}
|
||||
},
|
||||
"synchronous": "DontWait"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "150k-people.json"
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"filterableAttributes": [
|
||||
"city",
|
||||
"region",
|
||||
"country_code",
|
||||
"featured_job_title",
|
||||
"featured_job_organization_name"
|
||||
]
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"filterableAttributes": [
|
||||
"city",
|
||||
"region",
|
||||
"country_code"
|
||||
]
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
}
|
||||
]
|
||||
}
|
86
workloads/settings-proximity-precision.json
Normal file
86
workloads/settings-proximity-precision.json
Normal file
@ -0,0 +1,86 @@
|
||||
{
|
||||
"name": "settings-proximity-precision.json",
|
||||
"run_count": 2,
|
||||
"extra_cli_args": [
|
||||
"--max-indexing-threads=4"
|
||||
],
|
||||
"assets": {
|
||||
"150k-people.json": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/150k-people.json",
|
||||
"sha256": "28c359a0956958af0ba204ec11bad3045a0864a10b4838914fea25a01724f84b"
|
||||
}
|
||||
},
|
||||
"commands": [
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"searchableAttributes": [
|
||||
"last_name",
|
||||
"first_name",
|
||||
"featured_job_organization_name",
|
||||
"facebook_url",
|
||||
"twitter_url",
|
||||
"linkedin_url"
|
||||
],
|
||||
"filterableAttributes": [
|
||||
"city",
|
||||
"region",
|
||||
"country_code",
|
||||
"featured_job_title",
|
||||
"featured_job_organization_name"
|
||||
],
|
||||
"dictionary": [
|
||||
"https://",
|
||||
"http://",
|
||||
"www.",
|
||||
"crunchbase.com",
|
||||
"facebook.com",
|
||||
"twitter.com",
|
||||
"linkedin.com"
|
||||
],
|
||||
"stopWords": [
|
||||
"https://",
|
||||
"http://",
|
||||
"www.",
|
||||
"crunchbase.com",
|
||||
"facebook.com",
|
||||
"twitter.com",
|
||||
"linkedin.com"
|
||||
]
|
||||
}
|
||||
},
|
||||
"synchronous": "DontWait"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "150k-people.json"
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"proximityPrecision": "byAttribute"
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"proximityPrecision": "byWord"
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
}
|
||||
]
|
||||
}
|
114
workloads/settings-remove-add-swap-searchable.json
Normal file
114
workloads/settings-remove-add-swap-searchable.json
Normal file
@ -0,0 +1,114 @@
|
||||
{
|
||||
"name": "settings-remove-add-swap-searchable.json",
|
||||
"run_count": 2,
|
||||
"extra_cli_args": [
|
||||
"--max-indexing-threads=4"
|
||||
],
|
||||
"assets": {
|
||||
"150k-people.json": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/150k-people.json",
|
||||
"sha256": "28c359a0956958af0ba204ec11bad3045a0864a10b4838914fea25a01724f84b"
|
||||
}
|
||||
},
|
||||
"commands": [
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"searchableAttributes": [
|
||||
"last_name",
|
||||
"first_name",
|
||||
"featured_job_organization_name",
|
||||
"facebook_url",
|
||||
"twitter_url",
|
||||
"linkedin_url"
|
||||
],
|
||||
"filterableAttributes": [
|
||||
"city",
|
||||
"region",
|
||||
"country_code",
|
||||
"featured_job_title",
|
||||
"featured_job_organization_name"
|
||||
],
|
||||
"dictionary": [
|
||||
"https://",
|
||||
"http://",
|
||||
"www.",
|
||||
"crunchbase.com",
|
||||
"facebook.com",
|
||||
"twitter.com",
|
||||
"linkedin.com"
|
||||
],
|
||||
"stopWords": [
|
||||
"https://",
|
||||
"http://",
|
||||
"www.",
|
||||
"crunchbase.com",
|
||||
"facebook.com",
|
||||
"twitter.com",
|
||||
"linkedin.com"
|
||||
]
|
||||
}
|
||||
},
|
||||
"synchronous": "DontWait"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "150k-people.json"
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"searchableAttributes": [
|
||||
"last_name",
|
||||
"first_name",
|
||||
"featured_job_organization_name"
|
||||
]
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"searchableAttributes": [
|
||||
"last_name",
|
||||
"first_name",
|
||||
"featured_job_organization_name",
|
||||
"facebook_url",
|
||||
"twitter_url",
|
||||
"linkedin_url"
|
||||
]
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"searchableAttributes": [
|
||||
"first_name",
|
||||
"last_name",
|
||||
"featured_job_organization_name",
|
||||
"facebook_url",
|
||||
"twitter_url",
|
||||
"linkedin_url"
|
||||
]
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
}
|
||||
]
|
||||
}
|
115
workloads/settings-typo.json
Normal file
115
workloads/settings-typo.json
Normal file
@ -0,0 +1,115 @@
|
||||
{
|
||||
"name": "settings-typo.json",
|
||||
"run_count": 2,
|
||||
"extra_cli_args": [
|
||||
"--max-indexing-threads=4"
|
||||
],
|
||||
"assets": {
|
||||
"150k-people.json": {
|
||||
"local_location": null,
|
||||
"remote_location": "https://milli-benchmarks.fra1.digitaloceanspaces.com/bench/datasets/150k-people.json",
|
||||
"sha256": "28c359a0956958af0ba204ec11bad3045a0864a10b4838914fea25a01724f84b"
|
||||
}
|
||||
},
|
||||
"commands": [
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"searchableAttributes": [
|
||||
"last_name",
|
||||
"first_name",
|
||||
"featured_job_title",
|
||||
"featured_job_organization_name",
|
||||
"facebook_url",
|
||||
"twitter_url",
|
||||
"linkedin_url"
|
||||
],
|
||||
"filterableAttributes": [
|
||||
"city",
|
||||
"region",
|
||||
"country_code",
|
||||
"featured_job_title",
|
||||
"featured_job_organization_name"
|
||||
],
|
||||
"dictionary": [
|
||||
"https://",
|
||||
"http://",
|
||||
"www.",
|
||||
"crunchbase.com",
|
||||
"facebook.com",
|
||||
"twitter.com",
|
||||
"linkedin.com"
|
||||
],
|
||||
"stopWords": [
|
||||
"https://",
|
||||
"http://",
|
||||
"www.",
|
||||
"crunchbase.com",
|
||||
"facebook.com",
|
||||
"twitter.com",
|
||||
"linkedin.com"
|
||||
]
|
||||
}
|
||||
},
|
||||
"synchronous": "DontWait"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/documents",
|
||||
"method": "POST",
|
||||
"body": {
|
||||
"asset": "150k-people.json"
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"typoTolerance": {
|
||||
"disableOnAttributes": ["featured_job_organization_name"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"typoTolerance": {
|
||||
"disableOnAttributes": []
|
||||
}
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"typoTolerance": {
|
||||
"disableOnWords": ["Ben","Elowitz","Kevin","Flaherty", "Ron", "Dustin", "Owen", "Chris", "Mark", "Matt", "Peter", "Van", "Head", "of"]
|
||||
}
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
},
|
||||
{
|
||||
"route": "indexes/peoples/settings",
|
||||
"method": "PATCH",
|
||||
"body": {
|
||||
"inline": {
|
||||
"typoTolerance": {
|
||||
"disableOnWords": []
|
||||
}
|
||||
}
|
||||
},
|
||||
"synchronous": "WaitForTask"
|
||||
}
|
||||
]
|
||||
}
|
@ -11,157 +11,179 @@ use super::client::Client;
|
||||
use super::env_info;
|
||||
use super::workload::Workload;
|
||||
|
||||
pub async fn cancel_on_ctrl_c(
|
||||
invocation_uuid: Uuid,
|
||||
dashboard_client: Client,
|
||||
abort_handle: AbortHandle,
|
||||
) {
|
||||
tracing::info!("press Ctrl-C to cancel the invocation");
|
||||
match ctrl_c().await {
|
||||
Ok(()) => {
|
||||
tracing::info!(%invocation_uuid, "received Ctrl-C, cancelling invocation");
|
||||
mark_as_failed(dashboard_client, invocation_uuid, None).await;
|
||||
abort_handle.abort();
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DashboardClient {
|
||||
Client(Client),
|
||||
Dry,
|
||||
}
|
||||
|
||||
impl DashboardClient {
|
||||
pub fn new(dashboard_url: &str, api_key: Option<&str>) -> anyhow::Result<Self> {
|
||||
let dashboard_client = Client::new(
|
||||
Some(format!("{}/api/v1", dashboard_url)),
|
||||
api_key,
|
||||
Some(std::time::Duration::from_secs(60)),
|
||||
)?;
|
||||
|
||||
Ok(Self::Client(dashboard_client))
|
||||
}
|
||||
|
||||
pub fn new_dry() -> Self {
|
||||
Self::Dry
|
||||
}
|
||||
|
||||
pub async fn send_machine_info(&self, env: &env_info::Environment) -> anyhow::Result<()> {
|
||||
let Self::Client(dashboard_client) = self else { return Ok(()) };
|
||||
|
||||
let response = dashboard_client
|
||||
.put("machine")
|
||||
.json(&json!({"hostname": env.hostname}))
|
||||
.send()
|
||||
.await
|
||||
.context("sending machine information")?;
|
||||
if !response.status().is_success() {
|
||||
bail!(
|
||||
"could not send machine information: {} {}",
|
||||
response.status(),
|
||||
response.text().await.unwrap_or_else(|_| "unknown".into())
|
||||
);
|
||||
}
|
||||
Err(error) => tracing::warn!(
|
||||
error = &error as &dyn std::error::Error,
|
||||
"failed to listen to Ctrl-C signal, invocation won't be canceled on Ctrl-C"
|
||||
),
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn mark_as_failed(
|
||||
dashboard_client: Client,
|
||||
invocation_uuid: Uuid,
|
||||
failure_reason: Option<String>,
|
||||
) {
|
||||
let response = dashboard_client
|
||||
.post("cancel-invocation")
|
||||
.json(&json!({
|
||||
"invocation_uuid": invocation_uuid,
|
||||
"failure_reason": failure_reason,
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
let response = match response {
|
||||
Ok(response) => response,
|
||||
Err(response_error) => {
|
||||
tracing::error!(error = &response_error as &dyn std::error::Error, %invocation_uuid, "could not mark invocation as failed");
|
||||
return;
|
||||
pub async fn create_invocation(
|
||||
&self,
|
||||
build_info: build_info::BuildInfo,
|
||||
commit_message: &str,
|
||||
env: env_info::Environment,
|
||||
max_workloads: usize,
|
||||
reason: Option<&str>,
|
||||
) -> anyhow::Result<Uuid> {
|
||||
let Self::Client(dashboard_client) = self else { return Ok(Uuid::now_v7()) };
|
||||
|
||||
let response = dashboard_client
|
||||
.put("invocation")
|
||||
.json(&json!({
|
||||
"commit": {
|
||||
"sha1": build_info.commit_sha1,
|
||||
"message": commit_message,
|
||||
"commit_date": build_info.commit_timestamp,
|
||||
"branch": build_info.branch,
|
||||
"tag": build_info.describe.and_then(|describe| describe.as_tag()),
|
||||
},
|
||||
"machine_hostname": env.hostname,
|
||||
"max_workloads": max_workloads,
|
||||
"reason": reason
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.context("sending invocation")?;
|
||||
if !response.status().is_success() {
|
||||
bail!(
|
||||
"could not send new invocation: {}",
|
||||
response.text().await.unwrap_or_else(|_| "unknown".into())
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
if !response.status().is_success() {
|
||||
tracing::error!(
|
||||
%invocation_uuid,
|
||||
"could not mark invocation as failed: {}",
|
||||
response.text().await.unwrap()
|
||||
);
|
||||
return;
|
||||
}
|
||||
tracing::warn!(%invocation_uuid, "marked invocation as failed or canceled");
|
||||
}
|
||||
|
||||
pub async fn send_machine_info(
|
||||
dashboard_client: &Client,
|
||||
env: &env_info::Environment,
|
||||
) -> anyhow::Result<()> {
|
||||
let response = dashboard_client
|
||||
.put("machine")
|
||||
.json(&json!({"hostname": env.hostname}))
|
||||
.send()
|
||||
.await
|
||||
.context("sending machine information")?;
|
||||
if !response.status().is_success() {
|
||||
bail!(
|
||||
"could not send machine information: {} {}",
|
||||
response.status(),
|
||||
response.text().await.unwrap_or_else(|_| "unknown".into())
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_invocation(
|
||||
dashboard_client: &Client,
|
||||
build_info: build_info::BuildInfo,
|
||||
commit_message: &str,
|
||||
env: env_info::Environment,
|
||||
max_workloads: usize,
|
||||
reason: Option<&str>,
|
||||
) -> anyhow::Result<Uuid> {
|
||||
let response = dashboard_client
|
||||
.put("invocation")
|
||||
.json(&json!({
|
||||
"commit": {
|
||||
"sha1": build_info.commit_sha1,
|
||||
"message": commit_message,
|
||||
"commit_date": build_info.commit_timestamp,
|
||||
"branch": build_info.branch,
|
||||
"tag": build_info.describe.and_then(|describe| describe.as_tag()),
|
||||
},
|
||||
"machine_hostname": env.hostname,
|
||||
"max_workloads": max_workloads,
|
||||
"reason": reason
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.context("sending invocation")?;
|
||||
if !response.status().is_success() {
|
||||
bail!(
|
||||
"could not send new invocation: {}",
|
||||
response.text().await.unwrap_or_else(|_| "unknown".into())
|
||||
);
|
||||
}
|
||||
let invocation_uuid: Uuid =
|
||||
response.json().await.context("could not deserialize invocation response as JSON")?;
|
||||
Ok(invocation_uuid)
|
||||
}
|
||||
|
||||
pub async fn create_workload(
|
||||
dashboard_client: &Client,
|
||||
invocation_uuid: Uuid,
|
||||
workload: &Workload,
|
||||
) -> anyhow::Result<Uuid> {
|
||||
let response = dashboard_client
|
||||
.put("workload")
|
||||
.json(&json!({
|
||||
"invocation_uuid": invocation_uuid,
|
||||
"name": &workload.name,
|
||||
"max_runs": workload.run_count,
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.context("could not create new workload")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
bail!("creating new workload failed: {}", response.text().await.unwrap())
|
||||
let invocation_uuid: Uuid =
|
||||
response.json().await.context("could not deserialize invocation response as JSON")?;
|
||||
Ok(invocation_uuid)
|
||||
}
|
||||
|
||||
let workload_uuid: Uuid =
|
||||
response.json().await.context("could not deserialize JSON as UUID")?;
|
||||
Ok(workload_uuid)
|
||||
}
|
||||
pub async fn create_workload(
|
||||
&self,
|
||||
invocation_uuid: Uuid,
|
||||
workload: &Workload,
|
||||
) -> anyhow::Result<Uuid> {
|
||||
let Self::Client(dashboard_client) = self else { return Ok(Uuid::now_v7()) };
|
||||
|
||||
pub async fn create_run(
|
||||
dashboard_client: Client,
|
||||
workload_uuid: Uuid,
|
||||
report: &BTreeMap<String, CallStats>,
|
||||
) -> anyhow::Result<()> {
|
||||
let response = dashboard_client
|
||||
.put("run")
|
||||
.json(&json!({
|
||||
"workload_uuid": workload_uuid,
|
||||
"data": report
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.context("sending new run")?;
|
||||
if !response.status().is_success() {
|
||||
bail!(
|
||||
"sending new run failed: {}",
|
||||
response.text().await.unwrap_or_else(|_| "unknown".into())
|
||||
)
|
||||
let response = dashboard_client
|
||||
.put("workload")
|
||||
.json(&json!({
|
||||
"invocation_uuid": invocation_uuid,
|
||||
"name": &workload.name,
|
||||
"max_runs": workload.run_count,
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.context("could not create new workload")?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
bail!("creating new workload failed: {}", response.text().await.unwrap())
|
||||
}
|
||||
|
||||
let workload_uuid: Uuid =
|
||||
response.json().await.context("could not deserialize JSON as UUID")?;
|
||||
Ok(workload_uuid)
|
||||
}
|
||||
|
||||
pub async fn create_run(
|
||||
&self,
|
||||
workload_uuid: Uuid,
|
||||
report: &BTreeMap<String, CallStats>,
|
||||
) -> anyhow::Result<()> {
|
||||
let Self::Client(dashboard_client) = self else { return Ok(()) };
|
||||
|
||||
let response = dashboard_client
|
||||
.put("run")
|
||||
.json(&json!({
|
||||
"workload_uuid": workload_uuid,
|
||||
"data": report
|
||||
}))
|
||||
.send()
|
||||
.await
|
||||
.context("sending new run")?;
|
||||
if !response.status().is_success() {
|
||||
bail!(
|
||||
"sending new run failed: {}",
|
||||
response.text().await.unwrap_or_else(|_| "unknown".into())
|
||||
)
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn cancel_on_ctrl_c(self, invocation_uuid: Uuid, abort_handle: AbortHandle) {
|
||||
tracing::info!("press Ctrl-C to cancel the invocation");
|
||||
match ctrl_c().await {
|
||||
Ok(()) => {
|
||||
tracing::info!(%invocation_uuid, "received Ctrl-C, cancelling invocation");
|
||||
self.mark_as_failed(invocation_uuid, None).await;
|
||||
abort_handle.abort();
|
||||
}
|
||||
Err(error) => tracing::warn!(
|
||||
error = &error as &dyn std::error::Error,
|
||||
"failed to listen to Ctrl-C signal, invocation won't be canceled on Ctrl-C"
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn mark_as_failed(&self, invocation_uuid: Uuid, failure_reason: Option<String>) {
|
||||
if let DashboardClient::Client(client) = self {
|
||||
let response = client
|
||||
.post("cancel-invocation")
|
||||
.json(&json!({
|
||||
"invocation_uuid": invocation_uuid,
|
||||
"failure_reason": failure_reason,
|
||||
}))
|
||||
.send()
|
||||
.await;
|
||||
let response = match response {
|
||||
Ok(response) => response,
|
||||
Err(response_error) => {
|
||||
tracing::error!(error = &response_error as &dyn std::error::Error, %invocation_uuid, "could not mark invocation as failed");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if !response.status().is_success() {
|
||||
tracing::error!(
|
||||
%invocation_uuid,
|
||||
"could not mark invocation as failed: {}",
|
||||
response.text().await.unwrap()
|
||||
);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!(%invocation_uuid, "marked invocation as failed or canceled");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -50,6 +50,10 @@ pub struct BenchDeriveArgs {
|
||||
#[arg(long, default_value_t = default_dashboard_url())]
|
||||
dashboard_url: String,
|
||||
|
||||
/// Don't actually send results to the dashboard
|
||||
#[arg(long)]
|
||||
no_dashboard: bool,
|
||||
|
||||
/// Directory to output reports.
|
||||
#[arg(long, default_value_t = default_report_folder())]
|
||||
report_folder: String,
|
||||
@ -103,11 +107,11 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> {
|
||||
let assets_client =
|
||||
Client::new(None, args.assets_key.as_deref(), Some(std::time::Duration::from_secs(3600)))?; // 1h
|
||||
|
||||
let dashboard_client = Client::new(
|
||||
Some(format!("{}/api/v1", args.dashboard_url)),
|
||||
args.api_key.as_deref(),
|
||||
Some(std::time::Duration::from_secs(60)),
|
||||
)?;
|
||||
let dashboard_client = if args.no_dashboard {
|
||||
dashboard::DashboardClient::new_dry()
|
||||
} else {
|
||||
dashboard::DashboardClient::new(&args.dashboard_url, args.api_key.as_deref())?
|
||||
};
|
||||
|
||||
// reporting uses its own client because keeping the stream open to wait for entries
|
||||
// blocks any other requests
|
||||
@ -127,12 +131,12 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> {
|
||||
// enter runtime
|
||||
|
||||
rt.block_on(async {
|
||||
dashboard::send_machine_info(&dashboard_client, &env).await?;
|
||||
dashboard_client.send_machine_info(&env).await?;
|
||||
|
||||
let commit_message = build_info.commit_msg.context("missing commit message")?.split('\n').next().unwrap();
|
||||
let max_workloads = args.workload_file.len();
|
||||
let reason: Option<&str> = args.reason.as_deref();
|
||||
let invocation_uuid = dashboard::create_invocation(&dashboard_client, build_info, commit_message, env, max_workloads, reason).await?;
|
||||
let invocation_uuid = dashboard_client.create_invocation( build_info, commit_message, env, max_workloads, reason).await?;
|
||||
|
||||
tracing::info!(workload_count = args.workload_file.len(), "handling workload files");
|
||||
|
||||
@ -167,7 +171,7 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> {
|
||||
let abort_handle = workload_runs.abort_handle();
|
||||
tokio::spawn({
|
||||
let dashboard_client = dashboard_client.clone();
|
||||
dashboard::cancel_on_ctrl_c(invocation_uuid, dashboard_client, abort_handle)
|
||||
dashboard_client.cancel_on_ctrl_c(invocation_uuid, abort_handle)
|
||||
});
|
||||
|
||||
// wait for the end of the main task, handle result
|
||||
@ -178,7 +182,7 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> {
|
||||
}
|
||||
Ok(Err(error)) => {
|
||||
tracing::error!(%invocation_uuid, error = %error, "invocation failed, attempting to report the failure to dashboard");
|
||||
dashboard::mark_as_failed(dashboard_client, invocation_uuid, Some(error.to_string())).await;
|
||||
dashboard_client.mark_as_failed(invocation_uuid, Some(error.to_string())).await;
|
||||
tracing::warn!(%invocation_uuid, "invocation marked as failed following error");
|
||||
Err(error)
|
||||
},
|
||||
@ -186,7 +190,7 @@ pub fn run(args: BenchDeriveArgs) -> anyhow::Result<()> {
|
||||
match join_error.try_into_panic() {
|
||||
Ok(panic) => {
|
||||
tracing::error!("invocation panicked, attempting to report the failure to dashboard");
|
||||
dashboard::mark_as_failed(dashboard_client, invocation_uuid, Some("Panicked".into())).await;
|
||||
dashboard_client.mark_as_failed( invocation_uuid, Some("Panicked".into())).await;
|
||||
std::panic::resume_unwind(panic)
|
||||
}
|
||||
Err(_) => {
|
||||
|
@ -12,8 +12,9 @@ use uuid::Uuid;
|
||||
use super::assets::Asset;
|
||||
use super::client::Client;
|
||||
use super::command::SyncMode;
|
||||
use super::dashboard::DashboardClient;
|
||||
use super::BenchDeriveArgs;
|
||||
use crate::bench::{assets, dashboard, meili_process};
|
||||
use crate::bench::{assets, meili_process};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct Workload {
|
||||
@ -25,7 +26,7 @@ pub struct Workload {
|
||||
}
|
||||
|
||||
async fn run_commands(
|
||||
dashboard_client: &Client,
|
||||
dashboard_client: &DashboardClient,
|
||||
logs_client: &Client,
|
||||
meili_client: &Client,
|
||||
workload_uuid: Uuid,
|
||||
@ -64,7 +65,7 @@ async fn run_commands(
|
||||
#[tracing::instrument(skip(assets_client, dashboard_client, logs_client, meili_client, workload, master_key, args), fields(workload = workload.name))]
|
||||
pub async fn execute(
|
||||
assets_client: &Client,
|
||||
dashboard_client: &Client,
|
||||
dashboard_client: &DashboardClient,
|
||||
logs_client: &Client,
|
||||
meili_client: &Client,
|
||||
invocation_uuid: Uuid,
|
||||
@ -74,8 +75,7 @@ pub async fn execute(
|
||||
) -> anyhow::Result<()> {
|
||||
assets::fetch_assets(assets_client, &workload.assets, &args.asset_folder).await?;
|
||||
|
||||
let workload_uuid =
|
||||
dashboard::create_workload(dashboard_client, invocation_uuid, &workload).await?;
|
||||
let workload_uuid = dashboard_client.create_workload(invocation_uuid, &workload).await?;
|
||||
|
||||
let mut tasks = Vec::new();
|
||||
|
||||
@ -113,7 +113,7 @@ pub async fn execute(
|
||||
#[allow(clippy::too_many_arguments)] // not best code quality, but this is a benchmark runner
|
||||
#[tracing::instrument(skip(dashboard_client, logs_client, meili_client, workload, master_key, args), fields(workload = %workload.name))]
|
||||
async fn execute_run(
|
||||
dashboard_client: &Client,
|
||||
dashboard_client: &DashboardClient,
|
||||
logs_client: &Client,
|
||||
meili_client: &Client,
|
||||
workload_uuid: Uuid,
|
||||
@ -202,7 +202,7 @@ async fn start_report(
|
||||
}
|
||||
|
||||
async fn stop_report(
|
||||
dashboard_client: &Client,
|
||||
dashboard_client: &DashboardClient,
|
||||
logs_client: &Client,
|
||||
workload_uuid: Uuid,
|
||||
filename: String,
|
||||
@ -232,7 +232,7 @@ async fn stop_report(
|
||||
.context("could not convert trace to report")?;
|
||||
let context = || format!("writing report to {filename}");
|
||||
|
||||
dashboard::create_run(dashboard_client, workload_uuid, &report).await?;
|
||||
dashboard_client.create_run(workload_uuid, &report).await?;
|
||||
|
||||
let mut output_file = std::io::BufWriter::new(
|
||||
std::fs::File::options()
|
||||
|
Loading…
x
Reference in New Issue
Block a user