mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-05-14 16:23:57 +02:00
Aggregate tool calls and display the calls to make.
This commit is contained in:
parent
7d62307739
commit
af482d8ee9
@ -1,3 +1,4 @@
|
||||
use std::collections::HashMap;
|
||||
use std::mem;
|
||||
|
||||
use actix_web::web::{self, Data};
|
||||
@ -5,10 +6,11 @@ use actix_web::{Either, HttpResponse, Responder};
|
||||
use actix_web_lab::sse::{self, Event};
|
||||
use async_openai::config::OpenAIConfig;
|
||||
use async_openai::types::{
|
||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
|
||||
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
||||
ChatCompletionMessageToolCallChunk, ChatCompletionRequestAssistantMessageArgs,
|
||||
ChatCompletionRequestMessage, ChatCompletionRequestToolMessage,
|
||||
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
|
||||
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason,
|
||||
FunctionObjectArgs,
|
||||
FunctionCallStream, FunctionObjectArgs,
|
||||
};
|
||||
use async_openai::Client;
|
||||
use futures::StreamExt;
|
||||
@ -59,6 +61,12 @@ async fn chat(
|
||||
// To enable later on, when the feature will be experimental
|
||||
// index_scheduler.features().check_chat("Using the /chat route")?;
|
||||
|
||||
assert_eq!(
|
||||
chat_completion.n.unwrap_or(1),
|
||||
1,
|
||||
"Meilisearch /chat only support one completion at a time (n = 1, n = null)"
|
||||
);
|
||||
|
||||
if chat_completion.stream.unwrap_or(false) {
|
||||
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await)
|
||||
} else {
|
||||
@ -76,12 +84,6 @@ async fn non_streamed_chat(
|
||||
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
||||
let client = Client::with_config(config);
|
||||
|
||||
assert_eq!(
|
||||
chat_completion.n.unwrap_or(1),
|
||||
1,
|
||||
"Meilisearch /chat only support one completion at a time (n = 1, n = null)"
|
||||
);
|
||||
|
||||
let rtxn = index_scheduler.read_txn().unwrap();
|
||||
let search_in_index_description = index_scheduler
|
||||
.chat_prompts(&rtxn, "searchInIndex-description")
|
||||
@ -240,19 +242,116 @@ async fn streamed_chat(
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
mut chat_completion: CreateChatCompletionRequest,
|
||||
) -> impl Responder {
|
||||
assert!(chat_completion.stream.unwrap_or(false));
|
||||
|
||||
let api_key = std::env::var("MEILI_OPENAI_API_KEY")
|
||||
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)");
|
||||
|
||||
let rtxn = index_scheduler.read_txn().unwrap();
|
||||
let search_in_index_description = index_scheduler
|
||||
.chat_prompts(&rtxn, "searchInIndex-description")
|
||||
.unwrap()
|
||||
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION)
|
||||
.to_string();
|
||||
let search_in_index_q_param_description = index_scheduler
|
||||
.chat_prompts(&rtxn, "searchInIndex-q-param-description")
|
||||
.unwrap()
|
||||
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION)
|
||||
.to_string();
|
||||
let search_in_index_index_description = index_scheduler
|
||||
.chat_prompts(&rtxn, "searchInIndex-index-param-description")
|
||||
.unwrap()
|
||||
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION)
|
||||
.to_string();
|
||||
drop(rtxn);
|
||||
|
||||
let tools = chat_completion.tools.get_or_insert_default();
|
||||
tools.push(
|
||||
ChatCompletionToolArgs::default()
|
||||
.r#type(ChatCompletionToolType::Function)
|
||||
.function(
|
||||
FunctionObjectArgs::default()
|
||||
.name("searchInIndex")
|
||||
.description(&search_in_index_description)
|
||||
.parameters(json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"index_uid": {
|
||||
"type": "string",
|
||||
"enum": ["main"],
|
||||
"description": search_in_index_index_description,
|
||||
},
|
||||
"q": {
|
||||
"type": ["string", "null"],
|
||||
"description": search_in_index_q_param_description,
|
||||
}
|
||||
},
|
||||
"required": ["index_uid", "q"],
|
||||
"additionalProperties": false,
|
||||
}))
|
||||
.strict(true)
|
||||
.build()
|
||||
.unwrap(),
|
||||
)
|
||||
.build()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
||||
let client = Client::with_config(config);
|
||||
let response = client.chat().create_stream(chat_completion).await.unwrap();
|
||||
actix_web_lab::sse::Sse::from_stream(response.map(|response| {
|
||||
response
|
||||
.map(|mut r| Event::Data(sse::Data::new_json(r.choices.pop().unwrap().delta).unwrap()))
|
||||
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||
actix_web_lab::sse::Sse::from_stream(response.map(move |response| {
|
||||
response.map(|mut r| {
|
||||
let delta = r.choices.pop().unwrap().delta;
|
||||
let ChatCompletionStreamResponseDelta {
|
||||
ref content,
|
||||
ref function_call,
|
||||
ref tool_calls,
|
||||
ref role,
|
||||
ref refusal,
|
||||
} = delta;
|
||||
|
||||
match tool_calls {
|
||||
Some(tool_calls) => {
|
||||
for chunk in tool_calls {
|
||||
let ChatCompletionMessageToolCallChunk { index, id, r#type, function } =
|
||||
chunk;
|
||||
let FunctionCallStream { ref name, ref arguments } =
|
||||
function.as_ref().unwrap();
|
||||
global_tool_calls
|
||||
.entry(*index)
|
||||
.or_insert_with(|| Call {
|
||||
id: id.as_ref().unwrap().clone(),
|
||||
function_name: name.as_ref().unwrap().clone(),
|
||||
arguments: arguments.as_ref().unwrap().clone(),
|
||||
})
|
||||
.append(arguments.as_ref().unwrap());
|
||||
}
|
||||
}
|
||||
None if !global_tool_calls.is_empty() => {
|
||||
dbg!(&global_tool_calls);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
Event::Data(sse::Data::new_json(delta).unwrap())
|
||||
})
|
||||
}))
|
||||
}
|
||||
|
||||
/// The structure used to aggregate the function calls to make.
|
||||
#[derive(Debug)]
|
||||
struct Call {
|
||||
id: String,
|
||||
function_name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
impl Call {
|
||||
fn append(&mut self, arguments: &str) {
|
||||
self.arguments.push_str(arguments);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct SearchInIndexParameters {
|
||||
/// The index uid to search in.
|
||||
|
Loading…
x
Reference in New Issue
Block a user