mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-05-14 16:23:57 +02:00
Aggregate tool calls and display the calls to make.
This commit is contained in:
parent
7d62307739
commit
af482d8ee9
@ -1,3 +1,4 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
use std::mem;
|
use std::mem;
|
||||||
|
|
||||||
use actix_web::web::{self, Data};
|
use actix_web::web::{self, Data};
|
||||||
@ -5,10 +6,11 @@ use actix_web::{Either, HttpResponse, Responder};
|
|||||||
use actix_web_lab::sse::{self, Event};
|
use actix_web_lab::sse::{self, Event};
|
||||||
use async_openai::config::OpenAIConfig;
|
use async_openai::config::OpenAIConfig;
|
||||||
use async_openai::types::{
|
use async_openai::types::{
|
||||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
|
ChatCompletionMessageToolCallChunk, ChatCompletionRequestAssistantMessageArgs,
|
||||||
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
ChatCompletionRequestMessage, ChatCompletionRequestToolMessage,
|
||||||
|
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
|
||||||
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason,
|
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason,
|
||||||
FunctionObjectArgs,
|
FunctionCallStream, FunctionObjectArgs,
|
||||||
};
|
};
|
||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
use futures::StreamExt;
|
use futures::StreamExt;
|
||||||
@ -59,6 +61,12 @@ async fn chat(
|
|||||||
// To enable later on, when the feature will be experimental
|
// To enable later on, when the feature will be experimental
|
||||||
// index_scheduler.features().check_chat("Using the /chat route")?;
|
// index_scheduler.features().check_chat("Using the /chat route")?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
chat_completion.n.unwrap_or(1),
|
||||||
|
1,
|
||||||
|
"Meilisearch /chat only support one completion at a time (n = 1, n = null)"
|
||||||
|
);
|
||||||
|
|
||||||
if chat_completion.stream.unwrap_or(false) {
|
if chat_completion.stream.unwrap_or(false) {
|
||||||
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await)
|
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await)
|
||||||
} else {
|
} else {
|
||||||
@ -76,12 +84,6 @@ async fn non_streamed_chat(
|
|||||||
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
||||||
let client = Client::with_config(config);
|
let client = Client::with_config(config);
|
||||||
|
|
||||||
assert_eq!(
|
|
||||||
chat_completion.n.unwrap_or(1),
|
|
||||||
1,
|
|
||||||
"Meilisearch /chat only support one completion at a time (n = 1, n = null)"
|
|
||||||
);
|
|
||||||
|
|
||||||
let rtxn = index_scheduler.read_txn().unwrap();
|
let rtxn = index_scheduler.read_txn().unwrap();
|
||||||
let search_in_index_description = index_scheduler
|
let search_in_index_description = index_scheduler
|
||||||
.chat_prompts(&rtxn, "searchInIndex-description")
|
.chat_prompts(&rtxn, "searchInIndex-description")
|
||||||
@ -240,19 +242,116 @@ async fn streamed_chat(
|
|||||||
search_queue: web::Data<SearchQueue>,
|
search_queue: web::Data<SearchQueue>,
|
||||||
mut chat_completion: CreateChatCompletionRequest,
|
mut chat_completion: CreateChatCompletionRequest,
|
||||||
) -> impl Responder {
|
) -> impl Responder {
|
||||||
assert!(chat_completion.stream.unwrap_or(false));
|
|
||||||
|
|
||||||
let api_key = std::env::var("MEILI_OPENAI_API_KEY")
|
let api_key = std::env::var("MEILI_OPENAI_API_KEY")
|
||||||
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)");
|
.expect("cannot find OpenAI API Key (MEILI_OPENAI_API_KEY)");
|
||||||
|
|
||||||
|
let rtxn = index_scheduler.read_txn().unwrap();
|
||||||
|
let search_in_index_description = index_scheduler
|
||||||
|
.chat_prompts(&rtxn, "searchInIndex-description")
|
||||||
|
.unwrap()
|
||||||
|
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_TOOL_DESCRIPTION)
|
||||||
|
.to_string();
|
||||||
|
let search_in_index_q_param_description = index_scheduler
|
||||||
|
.chat_prompts(&rtxn, "searchInIndex-q-param-description")
|
||||||
|
.unwrap()
|
||||||
|
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_Q_PARAMETER_TOOL_DESCRIPTION)
|
||||||
|
.to_string();
|
||||||
|
let search_in_index_index_description = index_scheduler
|
||||||
|
.chat_prompts(&rtxn, "searchInIndex-index-param-description")
|
||||||
|
.unwrap()
|
||||||
|
.unwrap_or(DEFAULT_SEARCH_IN_INDEX_INDEX_PARAMETER_TOOL_DESCRIPTION)
|
||||||
|
.to_string();
|
||||||
|
drop(rtxn);
|
||||||
|
|
||||||
|
let tools = chat_completion.tools.get_or_insert_default();
|
||||||
|
tools.push(
|
||||||
|
ChatCompletionToolArgs::default()
|
||||||
|
.r#type(ChatCompletionToolType::Function)
|
||||||
|
.function(
|
||||||
|
FunctionObjectArgs::default()
|
||||||
|
.name("searchInIndex")
|
||||||
|
.description(&search_in_index_description)
|
||||||
|
.parameters(json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"index_uid": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["main"],
|
||||||
|
"description": search_in_index_index_description,
|
||||||
|
},
|
||||||
|
"q": {
|
||||||
|
"type": ["string", "null"],
|
||||||
|
"description": search_in_index_q_param_description,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["index_uid", "q"],
|
||||||
|
"additionalProperties": false,
|
||||||
|
}))
|
||||||
|
.strict(true)
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.build()
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
|
||||||
let client = Client::with_config(config);
|
let client = Client::with_config(config);
|
||||||
let response = client.chat().create_stream(chat_completion).await.unwrap();
|
let response = client.chat().create_stream(chat_completion).await.unwrap();
|
||||||
actix_web_lab::sse::Sse::from_stream(response.map(|response| {
|
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||||
response
|
actix_web_lab::sse::Sse::from_stream(response.map(move |response| {
|
||||||
.map(|mut r| Event::Data(sse::Data::new_json(r.choices.pop().unwrap().delta).unwrap()))
|
response.map(|mut r| {
|
||||||
|
let delta = r.choices.pop().unwrap().delta;
|
||||||
|
let ChatCompletionStreamResponseDelta {
|
||||||
|
ref content,
|
||||||
|
ref function_call,
|
||||||
|
ref tool_calls,
|
||||||
|
ref role,
|
||||||
|
ref refusal,
|
||||||
|
} = delta;
|
||||||
|
|
||||||
|
match tool_calls {
|
||||||
|
Some(tool_calls) => {
|
||||||
|
for chunk in tool_calls {
|
||||||
|
let ChatCompletionMessageToolCallChunk { index, id, r#type, function } =
|
||||||
|
chunk;
|
||||||
|
let FunctionCallStream { ref name, ref arguments } =
|
||||||
|
function.as_ref().unwrap();
|
||||||
|
global_tool_calls
|
||||||
|
.entry(*index)
|
||||||
|
.or_insert_with(|| Call {
|
||||||
|
id: id.as_ref().unwrap().clone(),
|
||||||
|
function_name: name.as_ref().unwrap().clone(),
|
||||||
|
arguments: arguments.as_ref().unwrap().clone(),
|
||||||
|
})
|
||||||
|
.append(arguments.as_ref().unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None if !global_tool_calls.is_empty() => {
|
||||||
|
dbg!(&global_tool_calls);
|
||||||
|
}
|
||||||
|
_ => (),
|
||||||
|
}
|
||||||
|
|
||||||
|
Event::Data(sse::Data::new_json(delta).unwrap())
|
||||||
|
})
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The structure used to aggregate the function calls to make.
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Call {
|
||||||
|
id: String,
|
||||||
|
function_name: String,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Call {
|
||||||
|
fn append(&mut self, arguments: &str) {
|
||||||
|
self.arguments.push_str(arguments);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct SearchInIndexParameters {
|
struct SearchInIndexParameters {
|
||||||
/// The index uid to search in.
|
/// The index uid to search in.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user