Aggregate tool calls and display the calls to make.

This commit is contained in:
Clément Renault 2025-05-14 11:53:03 +02:00
parent 7d62307739
commit af482d8ee9
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -1,3 +1,4 @@
use std::collections::HashMap;
use std::mem; use std::mem;
use actix_web::web::{self, Data}; use actix_web::web::{self, Data};
@ -5,10 +6,11 @@ use actix_web::{Either, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event}; use actix_web_lab::sse::{self, Event};
use async_openai::config::OpenAIConfig; use async_openai::config::OpenAIConfig;
use async_openai::types::{ use async_openai::types::{
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, ChatCompletionMessageToolCallChunk, ChatCompletionRequestAssistantMessageArgs,
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, ChatCompletionRequestMessage, ChatCompletionRequestToolMessage,
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason, ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason,
FunctionObjectArgs, FunctionCallStream, FunctionObjectArgs,
}; };
use async_openai::Client; use async_openai::Client;
use futures::StreamExt; use futures::StreamExt;
@ -59,6 +61,12 @@ async fn chat(
// To enable later on, when the feature will be experimental // To enable later on, when the feature will be experimental
// index_scheduler.features().check_chat("Using the /chat route")?; // index_scheduler.features().check_chat("Using the /chat route")?;
assert_eq!(
chat_completion.n.unwrap_or(1),
1,
"Meilisearch /chat only support one completion at a time (n = 1, n = null)"
);
if chat_completion.stream.unwrap_or(false) { if chat_completion.stream.unwrap_or(false) {
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await) Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await)
} else { } else {
@ -76,12 +84,6 @@ async fn non_streamed_chat(
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
let client = Client::with_config(config); let client = Client::with_config(config);
assert_eq!(
chat_completion.n.unwrap_or(1),
1,
"Meilisearch /chat only support one completion at a time (n = 1, n = null)"
);
let rtxn = index_scheduler.read_txn().unwrap(); let rtxn = index_scheduler.read_txn().unwrap();
let search_in_index_description = index_scheduler let search_in_index_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-description") .chat_prompts(&rtxn, "searchInIndex-description")
@ -240,19 +242,116 @@ async fn streamed_chat(
search_queue: web::Data<SearchQueue>, search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest, mut chat_completion: CreateChatCompletionRequest,
) -> impl Responder { ) -> impl Responder {
assert!(chat_completion.stream.unwrap_or(false));
let api_key = std::env::var("MEILI_OPENAI_API_KEY") let api_key = std::env::var("MEILI_OPENAI_API_KEY")
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)"); .expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)");
let rtxn = index_scheduler.read_txn().unwrap();
let search_in_index_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION)
.to_string();
let search_in_index_q_param_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-q-param-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION)
.to_string();
let search_in_index_index_description = index_scheduler
.chat_prompts(&rtxn, "searchInIndex-index-param-description")
.unwrap()
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION)
.to_string();
drop(rtxn);
let tools = chat_completion.tools.get_or_insert_default();
tools.push(
ChatCompletionToolArgs::default()
.r#type(ChatCompletionToolType::Function)
.function(
FunctionObjectArgs::default()
.name("searchInIndex")
.description(&search_in_index_description)
.parameters(json!({
"type": "object",
"properties": {
"index_uid": {
"type": "string",
"enum": ["main"],
"description": search_in_index_index_description,
},
"q": {
"type": ["string", "null"],
"description": search_in_index_q_param_description,
}
},
"required": ["index_uid", "q"],
"additionalProperties": false,
}))
.strict(true)
.build()
.unwrap(),
)
.build()
.unwrap(),
);
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
let client = Client::with_config(config); let client = Client::with_config(config);
let response = client.chat().create_stream(chat_completion).await.unwrap(); let response = client.chat().create_stream(chat_completion).await.unwrap();
actix_web_lab::sse::Sse::from_stream(response.map(|response| { let mut global_tool_calls = HashMap::<u32, Call>::new();
response actix_web_lab::sse::Sse::from_stream(response.map(move |response| {
.map(|mut r| Event::Data(sse::Data::new_json(r.choices.pop().unwrap().delta).unwrap())) response.map(|mut r| {
let delta = r.choices.pop().unwrap().delta;
let ChatCompletionStreamResponseDelta {
ref content,
ref function_call,
ref tool_calls,
ref role,
ref refusal,
} = delta;
match tool_calls {
Some(tool_calls) => {
for chunk in tool_calls {
let ChatCompletionMessageToolCallChunk { index, id, r#type, function } =
chunk;
let FunctionCallStream { ref name, ref arguments } =
function.as_ref().unwrap();
global_tool_calls
.entry(*index)
.or_insert_with(|| Call {
id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(),
arguments: arguments.as_ref().unwrap().clone(),
})
.append(arguments.as_ref().unwrap());
}
}
None if !global_tool_calls.is_empty() => {
dbg!(&global_tool_calls);
}
_ => (),
}
Event::Data(sse::Data::new_json(delta).unwrap())
})
})) }))
} }
/// Accumulator for one tool (function) call whose pieces arrive spread
/// across several chunks of a streamed chat completion.
///
/// `streamed_chat` keeps one `Call` per tool-call chunk `index` and grows
/// it as further chunks with that index come in.
///
/// NOTE(review): at the visible call site the entry is created with the
/// first chunk's arguments and `append` is then invoked with that same
/// chunk, so the first fragment appears to be stored twice — confirm and
/// fix at the call site (e.g. seed `arguments` with an empty `String`).
#[derive(Debug)]
struct Call {
    /// Identifier of the tool call, taken from the chunk that created
    /// this entry.
    id: String,
    /// Name of the function to invoke, taken from the chunk that created
    /// this entry.
    function_name: String,
    /// Concatenation of the argument fragments received so far.
    arguments: String,
}

impl Call {
    /// Adds `arguments` to the end of the arguments accumulated so far.
    ///
    /// The fragment is concatenated verbatim: no separator is inserted and
    /// no validation is performed on the resulting string.
    fn append(&mut self, arguments: &str) {
        self.arguments += arguments;
    }
}
#[derive(Deserialize)] #[derive(Deserialize)]
struct SearchInIndexParameters { struct SearchInIndexParameters {
/// The index uid to search in. /// The index uid to search in.