From af482d8ee93d24ab831bbdc5529e5e1b2be62fd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Wed, 14 May 2025 11:53:03 +0200 Subject: [PATCH] Aggregate tool calls and display the calls to make. --- crates/meilisearch/src/routes/chat.rs | 127 +++++++++++++++++++++++--- 1 file changed, 113 insertions(+), 14 deletions(-) diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index ad46d91c8..4c9c9934b 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::mem; use actix_web::web::{self, Data}; @@ -5,10 +6,11 @@ use actix_web::{Either, HttpResponse, Responder}; use actix_web_lab::sse::{self, Event}; use async_openai::config::OpenAIConfig; use async_openai::types::{ - ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, - ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, + ChatCompletionMessageToolCallChunk, ChatCompletionRequestAssistantMessageArgs, + ChatCompletionRequestMessage, ChatCompletionRequestToolMessage, + ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason, - FunctionObjectArgs, + FunctionCallStream, FunctionObjectArgs, }; use async_openai::Client; use futures::StreamExt; @@ -59,6 +61,12 @@ async fn chat( // To enable later on, when the feature will be experimental // index_scheduler.features().check_chat("Using the /chat route")?; + assert_eq!( + chat_completion.n.unwrap_or(1), + 1, + "Meilisearch /chat only support one completion at a time (n = 1, n = null)" + ); + if chat_completion.stream.unwrap_or(false) { Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await) } else { @@ -76,12 +84,6 @@ async fn non_streamed_chat( let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base let client = Client::with_config(config); - assert_eq!( - chat_completion.n.unwrap_or(1), - 1, - "Meilisearch /chat only support one completion at a time (n = 1, n = null)" - ); - let rtxn = index_scheduler.read_txn().unwrap(); let search_in_index_description = index_scheduler .chat_prompts(&rtxn, "searchInIndex-description") @@ -240,19 +242,116 @@ async fn streamed_chat( search_queue: web::Data, mut chat_completion: CreateChatCompletionRequest, ) -> impl Responder { - assert!(chat_completion.stream.unwrap_or(false)); - let api_key = std::env::var("MEILI_OPENAI_API_KEY") .expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); + + let rtxn = index_scheduler.read_txn().unwrap(); + let search_in_index_description = index_scheduler + .chat_prompts(&rtxn, "searchInIndex-description") + .unwrap() + .unwrap_or(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION) + .to_string(); + let search_in_index_q_param_description = index_scheduler + .chat_prompts(&rtxn, "searchInIndex-q-param-description") + .unwrap() + .unwrap_or(DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION) + .to_string(); + let search_in_index_index_description = index_scheduler + .chat_prompts(&rtxn, "searchInIndex-index-param-description") + .unwrap() + .unwrap_or(DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION) + .to_string(); + drop(rtxn); + + let tools = chat_completion.tools.get_or_insert_default(); + tools.push( + ChatCompletionToolArgs::default() + .r#type(ChatCompletionToolType::Function) + .function( + FunctionObjectArgs::default() + .name("searchInIndex") + .description(&search_in_index_description) + .parameters(json!({ + "type": "object", + "properties": { + "index_uid": { + "type": "string", + "enum": ["main"], + "description": search_in_index_index_description, + }, + "q": { + "type": ["string", "null"], + "description": search_in_index_q_param_description, + } + }, + "required": ["index_uid", "q"], + "additionalProperties": false, + })) + .strict(true) + .build() + .unwrap(), + ) + .build() + .unwrap(), + ); + let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base let client = Client::with_config(config); let response = client.chat().create_stream(chat_completion).await.unwrap(); - actix_web_lab::sse::Sse::from_stream(response.map(|response| { - response - .map(|mut r| Event::Data(sse::Data::new_json(r.choices.pop().unwrap().delta).unwrap())) + let mut global_tool_calls = HashMap::::new(); + actix_web_lab::sse::Sse::from_stream(response.map(move |response| { + response.map(|mut r| { + let delta = r.choices.pop().unwrap().delta; + let ChatCompletionStreamResponseDelta { + ref content, + ref function_call, + ref tool_calls, + ref role, + ref refusal, + } = delta; + + match tool_calls { + Some(tool_calls) => { + for chunk in tool_calls { + let ChatCompletionMessageToolCallChunk { index, id, r#type, function } = + chunk; + let FunctionCallStream { ref name, ref arguments } = + function.as_ref().unwrap(); + global_tool_calls + .entry(*index) + .or_insert_with(|| Call { + id: id.as_ref().unwrap().clone(), + function_name: name.as_ref().unwrap().clone(), + arguments: arguments.as_ref().unwrap().clone(), + }) + .append(arguments.as_ref().unwrap()); + } + } + None if !global_tool_calls.is_empty() => { + dbg!(&global_tool_calls); + } + _ => (), + } + + Event::Data(sse::Data::new_json(delta).unwrap()) + }) })) } +/// The structure used to aggregate the function calls to make. +#[derive(Debug)] +struct Call { + id: String, + function_name: String, + arguments: String, +} + +impl Call { + fn append(&mut self, arguments: &str) { + self.arguments.push_str(arguments); + } +} + #[derive(Deserialize)] struct SearchInIndexParameters { /// The index uid to search in.