From 6e8b371111c2f115fc86e0b5b384961660ef6ec8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 14 May 2025 14:58:01 +0200 Subject: [PATCH] Streaming supports tool calling --- crates/meilisearch/src/routes/chat.rs | 293 ++++++++++++++------------ 1 file changed, 158 insertions(+), 135 deletions(-) diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 207feb256..d2def9488 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -1,9 +1,10 @@ use std::collections::HashMap; use std::mem; +use std::time::Duration; use actix_web::web::{self, Data}; use actix_web::{Either, HttpResponse, Responder}; -use actix_web_lab::sse::{self, Event}; +use actix_web_lab::sse::{self, Event, Sse}; use async_openai::config::OpenAIConfig; use async_openai::types::{ ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, @@ -15,7 +16,6 @@ use async_openai::types::{ }; use async_openai::Client; use futures::StreamExt; -use futures_util::stream; use index_scheduler::IndexScheduler; use meilisearch_types::error::ResponseError; use meilisearch_types::keys::actions; @@ -59,7 +59,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) { async fn chat( index_scheduler: GuardedData, Data>, search_queue: web::Data, - web::Json(mut chat_completion): web::Json, + web::Json(chat_completion): web::Json, ) -> impl Responder { // To enable later on, when the feature will be experimental // index_scheduler.features().check_chat("Using the /chat route")?; @@ -298,151 +298,174 @@ async fn streamed_chat( .unwrap(), ); - let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base - let client = Client::with_config(config); - let response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); - let mut global_tool_calls = HashMap::::new(); - actix_web_lab::sse::Sse::from_stream(response.flat_map(move |response| match response { - Ok(resp) => { - let delta = &resp.choices[0].delta; - let ChatCompletionStreamResponseDelta { - content: _, - function_call: _, - ref tool_calls, - role: _, - refusal: _, - } = delta; + let (tx, rx) = tokio::sync::mpsc::channel(10); + let _join_handle = Handle::current().spawn(async move { + let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base + let client = Client::with_config(config); + let mut global_tool_calls = HashMap::::new(); - match tool_calls { - Some(tool_calls) => { - for chunk in tool_calls { - let ChatCompletionMessageToolCallChunk { index, id, r#type: _, function } = - chunk; - let FunctionCallStream { name, arguments } = function.as_ref().unwrap(); - global_tool_calls - .entry(*index) - .or_insert_with(|| Call { - id: id.as_ref().unwrap().clone(), - function_name: name.as_ref().unwrap().clone(), - arguments: arguments.as_ref().unwrap().clone(), - }) - .append(arguments.as_ref().unwrap()); - } - stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]) - } - None if !global_tool_calls.is_empty() => { - dbg!(&global_tool_calls); + 'main: loop { + let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); - let (meili_calls, other_calls): (Vec<_>, Vec<_>) = - mem::take(&mut global_tool_calls) - .into_iter() - .map(|(_, call)| ChatCompletionMessageToolCall { - id: call.id, - r#type: ChatCompletionToolType::Function, - function: FunctionCall { - name: call.function_name, - arguments: call.arguments, - }, - }) - .partition(|call| call.function.name == "searchInIndex"); + while let Some(result) = response.next().await { + match result { + Ok(resp) => { + let delta = &resp.choices[0].delta; + let ChatCompletionStreamResponseDelta { + content, + function_call: _, + ref tool_calls, + role: _, + refusal: _, + } = delta; - chat_completion.messages.push( - ChatCompletionRequestAssistantMessageArgs::default() - .tool_calls(meili_calls.clone()) - .build() - .unwrap() - .into(), - ); - - for call in meili_calls { - let SearchInIndexParameters { index_uid, q } = - serde_json::from_str(&call.function.arguments).unwrap(); - - let mut query = SearchQuery { - q, - hybrid: Some(HybridQuery { - semantic_ratio: SemanticRatio::default(), - embedder: EMBEDDER_NAME.to_string(), - }), - limit: 20, - ..Default::default() - }; - - // Tenant token search_rules. - if let Some(search_rules) = - index_scheduler.filters().get_index_search_rules(&index_uid) + if content.is_none() && tool_calls.is_none() && global_tool_calls.is_empty() { - add_search_rules(&mut query.filter, search_rules); + break 'main; } - // TBD - // let mut aggregate = SearchAggregator::::from_query(&query); + if let Some(text) = content { + tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())).await.unwrap() + } - let index = index_scheduler.index(&index_uid).unwrap(); - let search_kind = search_kind( - &query, - index_scheduler.get_ref(), - index_uid.to_string(), - &index, - ) - .unwrap(); - - // let permit = search_queue.try_get_search_permit().await?; - let features = index_scheduler.features(); - let index_cloned = index.clone(); - // let search_result = tokio::task::spawn_blocking(move || { - let search_result = perform_search( - index_uid.to_string(), - &index_cloned, - query, - search_kind, - RetrieveVectors::new(false), - features, - ); - // }) - // .await; - // permit.drop().await; - - // let search_result = search_result.unwrap(); - if let Ok(ref search_result) = search_result { - // aggregate.succeed(search_result); - if search_result.degraded { - MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); + match tool_calls { + Some(tool_calls) => { + for chunk in tool_calls { + let ChatCompletionMessageToolCallChunk { + index, + id, + r#type: _, + function, + } = chunk; + let FunctionCallStream { name, arguments } = + function.as_ref().unwrap(); + global_tool_calls + .entry(*index) + .or_insert_with(|| Call { + id: id.as_ref().unwrap().clone(), + function_name: name.as_ref().unwrap().clone(), + arguments: arguments.as_ref().unwrap().clone(), + }) + .append(arguments.as_ref().unwrap()); + } + tx.send(Event::Data(sse::Data::new_json(&resp).unwrap())) + .await + .unwrap() } + None if !global_tool_calls.is_empty() => { + // dbg!(&global_tool_calls); + + let (meili_calls, other_calls): (Vec<_>, Vec<_>) = + mem::take(&mut global_tool_calls) + .into_iter() + .map(|(_, call)| ChatCompletionMessageToolCall { + id: call.id, + r#type: ChatCompletionToolType::Function, + function: FunctionCall { + name: call.function_name, + arguments: call.arguments, + }, + }) + .partition(|call| call.function.name == "searchInIndex"); + + chat_completion.messages.push( + ChatCompletionRequestAssistantMessageArgs::default() + .tool_calls(meili_calls.clone()) + .build() + .unwrap() + .into(), + ); + + for call in meili_calls { + let SearchInIndexParameters { index_uid, q } = + serde_json::from_str(&call.function.arguments).unwrap(); + + let mut query = SearchQuery { + q, + hybrid: Some(HybridQuery { + semantic_ratio: SemanticRatio::default(), + embedder: EMBEDDER_NAME.to_string(), + }), + limit: 20, + ..Default::default() + }; + + // Tenant token search_rules. + if let Some(search_rules) = + index_scheduler.filters().get_index_search_rules(&index_uid) + { + add_search_rules(&mut query.filter, search_rules); + } + + // TBD + // let mut aggregate = SearchAggregator::::from_query(&query); + + let index = index_scheduler.index(&index_uid).unwrap(); + let search_kind = search_kind( + &query, + index_scheduler.get_ref(), + index_uid.to_string(), + &index, + ) + .unwrap(); + + let permit = + search_queue.try_get_search_permit().await.unwrap(); + let features = index_scheduler.features(); + let index_cloned = index.clone(); + let search_result = tokio::task::spawn_blocking(move || { + perform_search( + index_uid.to_string(), + &index_cloned, + query, + search_kind, + RetrieveVectors::new(false), + features, + ) + }) + .await; + permit.drop().await; + + let search_result = search_result.unwrap(); + if let Ok(ref search_result) = search_result { + // aggregate.succeed(search_result); + if search_result.degraded { + MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc(); + } + } + // analytics.publish(aggregate, &req); + + let search_result = search_result.unwrap(); + let formatted = format_documents( + &index, + search_result.hits.into_iter().map(|doc| doc.document), + ); + let text = formatted.join("\n"); + chat_completion.messages.push( + ChatCompletionRequestMessage::Tool( + ChatCompletionRequestToolMessage { + tool_call_id: call.id, + content: + ChatCompletionRequestToolMessageContent::Text( + text, + ), + }, + ), + ); + } + } + None => (), } - // analytics.publish(aggregate, &req); - - let search_result = search_result.unwrap(); - let formatted = format_documents( - &index, - search_result.hits.into_iter().map(|doc| doc.document), - ); - let text = formatted.join("\n"); - chat_completion.messages.push(ChatCompletionRequestMessage::Tool( - ChatCompletionRequestToolMessage { - tool_call_id: call.id, - content: ChatCompletionRequestToolMessageContent::Text(text), - }, - )); } - - let response = Handle::current().block_on(async { - client.chat().create_stream(chat_completion.clone()).await.unwrap() - }); - - // stream::iter(vec![ - // Ok(Event::Data(sse::Data::new_json(json!({ "text": "Hello" })).unwrap())), - // Ok(Event::Data(sse::Data::new_json(json!({ "text": " world" })).unwrap())), - // Ok(Event::Data(sse::Data::new_json(json!({ "text": " !" })).unwrap())), - // ]) - - response + Err(_err) => { + // writeln!(lock, "error: {err}").unwrap(); + } } - None => stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]), } } - Err(err) => stream::iter(vec![Err(err)]), - })) + }); + + Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)) } /// The structure used to aggregate the function calls to make.