Implement a first version of a streamed chat API

This commit is contained in:
Clément Renault 2025-05-14 11:18:21 +02:00
parent 3a71df7b5a
commit 7d62307739
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
3 changed files with 214 additions and 36 deletions

211
Cargo.lock generated
View File

@ -27,7 +27,7 @@ checksum = "f9e772b3bcafe335042b5db010ab7c09013dad6eac4915c91d8d50902769f331"
dependencies = [
"actix-utils",
"actix-web",
"derive_more",
"derive_more 0.99.17",
"futures-util",
"log",
"once_cell",
@ -36,24 +36,24 @@ dependencies = [
[[package]]
name = "actix-http"
version = "3.9.0"
version = "3.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d48f96fc3003717aeb9856ca3d02a8c7de502667ad76eeacd830b48d2e91fac4"
checksum = "44dfe5c9e0004c623edc65391dfd51daa201e7e30ebd9c9bedf873048ec32bc2"
dependencies = [
"actix-codec",
"actix-rt",
"actix-service",
"actix-tls",
"actix-utils",
"ahash 0.8.11",
"base64 0.22.1",
"bitflags 2.9.0",
"brotli",
"brotli 8.0.1",
"bytes",
"bytestring",
"derive_more",
"derive_more 2.0.1",
"encoding_rs",
"flate2",
"foldhash",
"futures-core",
"h2 0.3.26",
"http 0.2.11",
@ -65,7 +65,7 @@ dependencies = [
"mime",
"percent-encoding",
"pin-project-lite",
"rand",
"rand 0.9.1",
"sha1",
"smallvec",
"tokio",
@ -92,6 +92,7 @@ dependencies = [
"bytestring",
"cfg-if",
"http 0.2.11",
"regex",
"regex-lite",
"serde",
"tracing",
@ -187,7 +188,7 @@ dependencies = [
"bytestring",
"cfg-if",
"cookie",
"derive_more",
"derive_more 0.99.17",
"encoding_rs",
"futures-core",
"futures-util",
@ -220,6 +221,43 @@ dependencies = [
"syn 2.0.87",
]
[[package]]
name = "actix-web-lab"
version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a33034dd88446a5deb20e42156dbfe43d07e0499345db3ae65b3f51854190531"
dependencies = [
"actix-http",
"actix-router",
"actix-service",
"actix-utils",
"actix-web",
"ahash 0.8.11",
"arc-swap",
"bytes",
"bytestring",
"csv",
"derive_more 2.0.1",
"form_urlencoded",
"futures-core",
"futures-util",
"http 0.2.11",
"impl-more",
"itertools 0.14.0",
"local-channel",
"mime",
"pin-project-lite",
"regex",
"serde",
"serde_html_form",
"serde_json",
"serde_path_to_error",
"tokio",
"tokio-stream",
"tracing",
"url",
]
[[package]]
name = "addr2line"
version = "0.20.0"
@ -385,6 +423,12 @@ dependencies = [
"derive_arbitrary",
]
[[package]]
name = "arc-swap"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]]
name = "arrayvec"
version = "0.7.4"
@ -405,7 +449,7 @@ dependencies = [
"nohash",
"ordered-float",
"page_size",
"rand",
"rand 0.8.5",
"rayon",
"roaring",
"tempfile",
@ -436,7 +480,7 @@ dependencies = [
"derive_builder 0.20.2",
"eventsource-stream",
"futures",
"rand",
"rand 0.8.5",
"reqwest",
"reqwest-eventsource",
"secrecy",
@ -493,7 +537,7 @@ dependencies = [
"getrandom 0.2.15",
"instant",
"pin-project-lite",
"rand",
"rand 0.8.5",
"tokio",
]
@ -549,8 +593,8 @@ dependencies = [
"memmap2",
"milli",
"mimalloc",
"rand",
"rand_chacha",
"rand 0.8.5",
"rand_chacha 0.3.1",
"reqwest",
"roaring",
"serde_json",
@ -702,7 +746,18 @@ checksum = "74f7971dbd9326d58187408ab83117d8ac1bb9c17b085fdacd1cf2f598719b6b"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
"brotli-decompressor",
"brotli-decompressor 4.0.1",
]
[[package]]
name = "brotli"
version = "8.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9991eea70ea4f293524138648e41ee89b0b2b12ddef3b255effa43c8056e0e0d"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
"brotli-decompressor 5.0.0",
]
[[package]]
@ -715,6 +770,16 @@ dependencies = [
"alloc-stdlib",
]
[[package]]
name = "brotli-decompressor"
version = "5.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "874bb8112abecc98cbd6d81ea4fa7e94fb9449648c93cc89aa40c81c24d7de03"
dependencies = [
"alloc-no-stdlib",
"alloc-stdlib",
]
[[package]]
name = "bstr"
version = "1.11.3"
@ -881,7 +946,7 @@ dependencies = [
"memmap2",
"num-traits",
"num_cpus",
"rand",
"rand 0.8.5",
"rand_distr",
"rayon",
"safetensors",
@ -927,7 +992,7 @@ dependencies = [
"candle-nn",
"fancy-regex",
"num-traits",
"rand",
"rand 0.8.5",
"rayon",
"serde",
"serde_json",
@ -1613,6 +1678,27 @@ dependencies = [
"syn 1.0.109",
]
[[package]]
name = "derive_more"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "093242cf7570c207c83073cf82f79706fe7b8317e98620a47d5be7c3d8497678"
dependencies = [
"derive_more-impl",
]
[[package]]
name = "derive_more-impl"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.87",
"unicode-xid",
]
[[package]]
name = "deserr"
version = "0.6.3"
@ -2412,7 +2498,7 @@ dependencies = [
"cfg-if",
"crunchy",
"num-traits",
"rand",
"rand 0.8.5",
"rand_distr",
]
@ -2542,7 +2628,7 @@ dependencies = [
"http 1.2.0",
"indicatif",
"log",
"rand",
"rand 0.8.5",
"serde",
"serde_json",
"thiserror 1.0.69",
@ -2821,9 +2907,9 @@ dependencies = [
[[package]]
name = "impl-more"
version = "0.1.6"
version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "206ca75c9c03ba3d4ace2460e57b189f39f43de612c2f85836e65c929701bb2d"
checksum = "e8a5a9a0ff0086c7a148acb942baaabeadf9504d10400b5a05645853729b9cd2"
[[package]]
name = "index-scheduler"
@ -3691,10 +3777,11 @@ dependencies = [
"actix-rt",
"actix-utils",
"actix-web",
"actix-web-lab",
"anyhow",
"async-openai",
"async-trait",
"brotli",
"brotli 6.0.0",
"bstr",
"build-info",
"byte-unit",
@ -3736,7 +3823,7 @@ dependencies = [
"pin-project-lite",
"platform-dirs",
"prometheus",
"rand",
"rand 0.8.5",
"rayon",
"regex",
"reqwest",
@ -3785,7 +3872,7 @@ dependencies = [
"hmac",
"maplit",
"meilisearch-types",
"rand",
"rand 0.8.5",
"roaring",
"serde",
"serde_json",
@ -3915,7 +4002,7 @@ dependencies = [
"obkv",
"once_cell",
"ordered-float",
"rand",
"rand 0.8.5",
"rayon",
"rayon-par-bridge",
"rhai",
@ -4461,7 +4548,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48e4cc64c2ad9ebe670cb8fd69dd50ae301650392e81c05f9bfcb2d5bdbc24b0"
dependencies = [
"phf_shared",
"rand",
"rand 0.8.5",
]
[[package]]
@ -4741,7 +4828,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fadfaed2cd7f389d0161bb73eeb07b7b78f8691047a6f3e73caaeae55310a4a6"
dependencies = [
"bytes",
"rand",
"rand 0.8.5",
"ring",
"rustc-hash 2.1.0",
"rustls",
@ -4786,8 +4873,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
dependencies = [
"libc",
"rand_chacha",
"rand_core",
"rand_chacha 0.3.1",
"rand_core 0.6.4",
]
[[package]]
name = "rand"
version = "0.9.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97"
dependencies = [
"rand_chacha 0.9.0",
"rand_core 0.9.3",
]
[[package]]
@ -4797,7 +4894,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
dependencies = [
"ppv-lite86",
"rand_core",
"rand_core 0.6.4",
]
[[package]]
name = "rand_chacha"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb"
dependencies = [
"ppv-lite86",
"rand_core 0.9.3",
]
[[package]]
@ -4809,6 +4916,15 @@ dependencies = [
"getrandom 0.2.15",
]
[[package]]
name = "rand_core"
version = "0.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38"
dependencies = [
"getrandom 0.3.1",
]
[[package]]
name = "rand_distr"
version = "0.4.3"
@ -4816,7 +4932,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31"
dependencies = [
"num-traits",
"rand",
"rand 0.8.5",
]
[[package]]
@ -5114,7 +5230,7 @@ dependencies = [
"borsh",
"bytes",
"num-traits",
"rand",
"rand 0.8.5",
"rkyv",
"serde",
"serde_json",
@ -5382,6 +5498,19 @@ dependencies = [
"syn 2.0.87",
]
[[package]]
name = "serde_html_form"
version = "0.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4"
dependencies = [
"form_urlencoded",
"indexmap",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "serde_json"
version = "1.0.140"
@ -5395,6 +5524,16 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59fab13f937fa393d08645bf3a84bdfe86e296747b506ada67bb15f10f218b2a"
dependencies = [
"itoa",
"serde",
]
[[package]]
name = "serde_plain"
version = "1.0.2"
@ -5813,7 +5952,7 @@ dependencies = [
"getrandom 0.2.15",
"once_cell",
"rustix",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@ -5993,7 +6132,7 @@ dependencies = [
"monostate",
"onig",
"paste",
"rand",
"rand 0.8.5",
"rayon",
"rayon-cond",
"regex",
@ -6367,6 +6506,12 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51733f11c9c4f72aa0c160008246859e340b00807569a0da0e7a1079b27ba85"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "unicode_categories"
version = "0.1.1"

View File

@ -113,6 +113,7 @@ utoipa = { version = "5.3.1", features = [
] }
utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] }
async-openai = "0.28.1"
actix-web-lab = { version = "0.24.1", default-features = false }
[dev-dependencies]
actix-rt = "2.10.0"

View File

@ -1,7 +1,8 @@
use std::mem;
use actix_web::web::{self, Data};
use actix_web::HttpResponse;
use actix_web::{Either, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event};
use async_openai::config::OpenAIConfig;
use async_openai::types::{
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
@ -10,6 +11,7 @@ use async_openai::types::{
FunctionObjectArgs,
};
use async_openai::Client;
use futures::StreamExt;
use index_scheduler::IndexScheduler;
use meilisearch_types::error::ResponseError;
use meilisearch_types::keys::actions;
@ -53,10 +55,22 @@ async fn chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
search_queue: web::Data<SearchQueue>,
web::Json(mut chat_completion): web::Json<CreateChatCompletionRequest>,
) -> Result<HttpResponse, ResponseError> {
) -> impl Responder {
// To enable later on, when the feature will be experimental
// index_scheduler.features().check_chat("Using the /chat route")?;
if chat_completion.stream.unwrap_or(false) {
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await)
} else {
Either::Left(non_streamed_chat(index_scheduler, search_queue, chat_completion).await)
}
}
async fn non_streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest,
) -> Result<HttpResponse, ResponseError> {
let api_key = std::env::var("MEILI_OPENAI_API_KEY")
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)");
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
@ -119,7 +133,7 @@ async fn chat(
.build()
.unwrap(),
);
response = dbg!(client.chat().create(chat_completion.clone()).await.unwrap());
response = client.chat().create(chat_completion.clone()).await.unwrap();
let choice = &mut response.choices[0];
match choice.finish_reason {
@ -221,6 +235,24 @@ async fn chat(
Ok(HttpResponse::Ok().json(response))
}
async fn streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest,
) -> impl Responder {
assert!(chat_completion.stream.unwrap_or(false));
let api_key = std::env::var("MEILI_OPENAI_API_KEY")
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)");
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
let client = Client::with_config(config);
let response = client.chat().create_stream(chat_completion).await.unwrap();
actix_web_lab::sse::Sse::from_stream(response.map(|response| {
response
.map(|mut r| Event::Data(sse::Data::new_json(r.choices.pop().unwrap().delta).unwrap()))
}))
}
#[derive(Deserialize)]
struct SearchInIndexParameters {
/// The index uid to search in.