Streaming supports tool calling

This commit is contained in:
Clément Renault 2025-05-14 14:58:01 +02:00
parent da7d651f4b
commit 6e8b371111
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -1,9 +1,10 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::mem; use std::mem;
use std::time::Duration;
use actix_web::web::{self, Data}; use actix_web::web::{self, Data};
use actix_web::{Either, HttpResponse, Responder}; use actix_web::{Either, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event}; use actix_web_lab::sse::{self, Event, Sse};
use async_openai::config::OpenAIConfig; use async_openai::config::OpenAIConfig;
use async_openai::types::{ use async_openai::types::{
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
@ -15,7 +16,6 @@ use async_openai::types::{
}; };
use async_openai::Client; use async_openai::Client;
use futures::StreamExt; use futures::StreamExt;
use futures_util::stream;
use index_scheduler::IndexScheduler; use index_scheduler::IndexScheduler;
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use meilisearch_types::keys::actions; use meilisearch_types::keys::actions;
@ -59,7 +59,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
async fn chat( async fn chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>, index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT_GET }>, Data<IndexScheduler>>,
search_queue: web::Data<SearchQueue>, search_queue: web::Data<SearchQueue>,
web::Json(mut chat_completion): web::Json<CreateChatCompletionRequest>, web::Json(chat_completion): web::Json<CreateChatCompletionRequest>,
) -> impl Responder { ) -> impl Responder {
// To enable later on, when the feature will be experimental // To enable later on, when the feature will be experimental
// index_scheduler.features().check_chat("Using the /chat route")?; // index_scheduler.features().check_chat("Using the /chat route")?;
@ -298,27 +298,47 @@ async fn streamed_chat(
.unwrap(), .unwrap(),
); );
let (tx, rx) = tokio::sync::mpsc::channel(10);
let _join_handle = Handle::current().spawn(async move {
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
let client = Client::with_config(config); let client = Client::with_config(config);
let response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
let mut global_tool_calls = HashMap::<u32, Call>::new(); let mut global_tool_calls = HashMap::<u32, Call>::new();
actix_web_lab::sse::Sse::from_stream(response.flat_map(move |response| match response {
'main: loop {
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
while let Some(result) = response.next().await {
match result {
Ok(resp) => { Ok(resp) => {
let delta = &resp.choices[0].delta; let delta = &resp.choices[0].delta;
let ChatCompletionStreamResponseDelta { let ChatCompletionStreamResponseDelta {
content: _, content,
function_call: _, function_call: _,
ref tool_calls, ref tool_calls,
role: _, role: _,
refusal: _, refusal: _,
} = delta; } = delta;
if content.is_none() && tool_calls.is_none() && global_tool_calls.is_empty()
{
break 'main;
}
if let Some(text) = content {
tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await.unwrap()
}
match tool_calls { match tool_calls {
Some(tool_calls) => { Some(tool_calls) => {
for chunk in tool_calls { for chunk in tool_calls {
let ChatCompletionMessageToolCallChunk { index, id, r#type: _, function } = let ChatCompletionMessageToolCallChunk {
chunk; index,
let FunctionCallStream { name, arguments } = function.as_ref().unwrap(); id,
r#type: _,
function,
} = chunk;
let FunctionCallStream { name, arguments } =
function.as_ref().unwrap();
global_tool_calls global_tool_calls
.entry(*index) .entry(*index)
.or_insert_with(|| Call { .or_insert_with(|| Call {
@ -328,10 +348,12 @@ async fn streamed_chat(
}) })
.append(arguments.as_ref().unwrap()); .append(arguments.as_ref().unwrap());
} }
stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]) tx.send(Event::Data(sse::Data::new_json(&resp).unwrap()))
.await
.unwrap()
} }
None if !global_tool_calls.is_empty() => { None if !global_tool_calls.is_empty() => {
dbg!(&global_tool_calls); // dbg!(&global_tool_calls);
let (meili_calls, other_calls): (Vec<_>, Vec<_>) = let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
mem::take(&mut global_tool_calls) mem::take(&mut global_tool_calls)
@ -387,23 +409,24 @@ async fn streamed_chat(
) )
.unwrap(); .unwrap();
// let permit = search_queue.try_get_search_permit().await?; let permit =
search_queue.try_get_search_permit().await.unwrap();
let features = index_scheduler.features(); let features = index_scheduler.features();
let index_cloned = index.clone(); let index_cloned = index.clone();
// let search_result = tokio::task::spawn_blocking(move || { let search_result = tokio::task::spawn_blocking(move || {
let search_result = perform_search( perform_search(
index_uid.to_string(), index_uid.to_string(),
&index_cloned, &index_cloned,
query, query,
search_kind, search_kind,
RetrieveVectors::new(false), RetrieveVectors::new(false),
features, features,
); )
// }) })
// .await; .await;
// permit.drop().await; permit.drop().await;
// let search_result = search_result.unwrap(); let search_result = search_result.unwrap();
if let Ok(ref search_result) = search_result { if let Ok(ref search_result) = search_result {
// aggregate.succeed(search_result); // aggregate.succeed(search_result);
if search_result.degraded { if search_result.degraded {
@ -418,31 +441,31 @@ async fn streamed_chat(
search_result.hits.into_iter().map(|doc| doc.document), search_result.hits.into_iter().map(|doc| doc.document),
); );
let text = formatted.join("\n"); let text = formatted.join("\n");
chat_completion.messages.push(ChatCompletionRequestMessage::Tool( chat_completion.messages.push(
ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage { ChatCompletionRequestToolMessage {
tool_call_id: call.id, tool_call_id: call.id,
content: ChatCompletionRequestToolMessageContent::Text(text), content:
ChatCompletionRequestToolMessageContent::Text(
text,
),
}, },
)); ),
);
}
}
None => (),
}
}
Err(_err) => {
// writeln!(lock, "error: {err}").unwrap();
}
}
}
} }
let response = Handle::current().block_on(async {
client.chat().create_stream(chat_completion.clone()).await.unwrap()
}); });
// stream::iter(vec![ Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))
// Ok(Event::Data(sse::Data::new_json(json!({ "text": "Hello" })).unwrap())),
// Ok(Event::Data(sse::Data::new_json(json!({ "text": " world" })).unwrap())),
// Ok(Event::Data(sse::Data::new_json(json!({ "text": " !" })).unwrap())),
// ])
response
}
None => stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]),
}
}
Err(err) => stream::iter(vec![Err(err)]),
}))
} }
/// The structure used to aggregate the function calls to make. /// The structure used to aggregate the function calls to make.