Nearly support tools on the streaming route

This commit is contained in:
Clément Renault 2025-05-14 14:29:41 +02:00
parent 24050f06e4
commit da7d651f4b
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F

View File

@ -6,14 +6,16 @@ use actix_web::{Either, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event}; use actix_web_lab::sse::{self, Event};
use async_openai::config::OpenAIConfig; use async_openai::config::OpenAIConfig;
use async_openai::types::{ use async_openai::types::{
ChatCompletionMessageToolCallChunk, ChatCompletionRequestAssistantMessageArgs, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
ChatCompletionRequestMessage, ChatCompletionRequestToolMessage, ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta, ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, FinishReason, ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
FunctionCallStream, FunctionObjectArgs, CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionCallStream,
FunctionObjectArgs,
}; };
use async_openai::Client; use async_openai::Client;
use futures::StreamExt; use futures::StreamExt;
use futures_util::stream;
use index_scheduler::IndexScheduler; use index_scheduler::IndexScheduler;
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use meilisearch_types::keys::actions; use meilisearch_types::keys::actions;
@ -23,6 +25,7 @@ use meilisearch_types::milli::vector::EmbeddingConfig;
use meilisearch_types::{Document, Index}; use meilisearch_types::{Document, Index};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::json; use serde_json::json;
use tokio::runtime::Handle;
use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::GuardedData; use crate::extractors::authentication::GuardedData;
@ -297,26 +300,25 @@ async fn streamed_chat(
let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base let config = OpenAIConfig::default().with_api_key(&api_key); // we can also change the API base
let client = Client::with_config(config); let client = Client::with_config(config);
let response = client.chat().create_stream(chat_completion).await.unwrap(); let response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
let mut global_tool_calls = HashMap::<u32, Call>::new(); let mut global_tool_calls = HashMap::<u32, Call>::new();
actix_web_lab::sse::Sse::from_stream(response.map(move |response| { actix_web_lab::sse::Sse::from_stream(response.flat_map(move |response| match response {
response.map(|mut r| { Ok(resp) => {
let delta = &r.choices[0].delta; let delta = &resp.choices[0].delta;
let ChatCompletionStreamResponseDelta { let ChatCompletionStreamResponseDelta {
ref content, content: _,
ref function_call, function_call: _,
ref tool_calls, ref tool_calls,
ref role, role: _,
ref refusal, refusal: _,
} = delta; } = delta;
match tool_calls { match tool_calls {
Some(tool_calls) => { Some(tool_calls) => {
for chunk in tool_calls { for chunk in tool_calls {
let ChatCompletionMessageToolCallChunk { index, id, r#type, function } = let ChatCompletionMessageToolCallChunk { index, id, r#type: _, function } =
chunk; chunk;
let FunctionCallStream { ref name, ref arguments } = let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
function.as_ref().unwrap();
global_tool_calls global_tool_calls
.entry(*index) .entry(*index)
.or_insert_with(|| Call { .or_insert_with(|| Call {
@ -326,15 +328,120 @@ async fn streamed_chat(
}) })
.append(arguments.as_ref().unwrap()); .append(arguments.as_ref().unwrap());
} }
stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))])
} }
None if !global_tool_calls.is_empty() => { None if !global_tool_calls.is_empty() => {
dbg!(&global_tool_calls); dbg!(&global_tool_calls);
}
None => (),
}
Event::Data(sse::Data::new_json(r).unwrap()) let (meili_calls, other_calls): (Vec<_>, Vec<_>) =
}) mem::take(&mut global_tool_calls)
.into_iter()
.map(|(_, call)| ChatCompletionMessageToolCall {
id: call.id,
r#type: ChatCompletionToolType::Function,
function: FunctionCall {
name: call.function_name,
arguments: call.arguments,
},
})
.partition(|call| call.function.name == "searchInIndex");
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
.tool_calls(meili_calls.clone())
.build()
.unwrap()
.into(),
);
for call in meili_calls {
let SearchInIndexParameters { index_uid, q } =
serde_json::from_str(&call.function.arguments).unwrap();
let mut query = SearchQuery {
q,
hybrid: Some(HybridQuery {
semantic_ratio: SemanticRatio::default(),
embedder: EMBEDDER_NAME.to_string(),
}),
limit: 20,
..Default::default()
};
// Tenant token search_rules.
if let Some(search_rules) =
index_scheduler.filters().get_index_search_rules(&index_uid)
{
add_search_rules(&mut query.filter, search_rules);
}
// TBD
// let mut aggregate = SearchAggregator::<SearchPOST>::from_query(&query);
let index = index_scheduler.index(&index_uid).unwrap();
let search_kind = search_kind(
&query,
index_scheduler.get_ref(),
index_uid.to_string(),
&index,
)
.unwrap();
// let permit = search_queue.try_get_search_permit().await?;
let features = index_scheduler.features();
let index_cloned = index.clone();
// let search_result = tokio::task::spawn_blocking(move || {
let search_result = perform_search(
index_uid.to_string(),
&index_cloned,
query,
search_kind,
RetrieveVectors::new(false),
features,
);
// })
// .await;
// permit.drop().await;
// let search_result = search_result.unwrap();
if let Ok(ref search_result) = search_result {
// aggregate.succeed(search_result);
if search_result.degraded {
MEILISEARCH_DEGRADED_SEARCH_REQUESTS.inc();
}
}
// analytics.publish(aggregate, &req);
let search_result = search_result.unwrap();
let formatted = format_documents(
&index,
search_result.hits.into_iter().map(|doc| doc.document),
);
let text = formatted.join("\n");
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage {
tool_call_id: call.id,
content: ChatCompletionRequestToolMessageContent::Text(text),
},
));
}
let response = Handle::current().block_on(async {
client.chat().create_stream(chat_completion.clone()).await.unwrap()
});
// stream::iter(vec![
// Ok(Event::Data(sse::Data::new_json(json!({ "text": "Hello" })).unwrap())),
// Ok(Event::Data(sse::Data::new_json(json!({ "text": " world" })).unwrap())),
// Ok(Event::Data(sse::Data::new_json(json!({ "text": " !" })).unwrap())),
// ])
response
}
None => stream::iter(vec![Ok(Event::Data(sse::Data::new_json(resp).unwrap()))]),
}
}
Err(err) => stream::iter(vec![Err(err)]),
})) }))
} }