Fix handling of OpenAI-compatible Gemini req/res

This commit is contained in:
Dijana Pavlovic 2025-06-25 11:26:31 +02:00
parent bd2bd0f33b
commit 57355d960f
3 changed files with 89 additions and 76 deletions

View file

@ -111,7 +111,7 @@ utoipa = { version = "5.4.0", features = [
"openapi_extensions",
] }
utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] }
async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "better-error-handling" }
async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "better-error-handling", features = ["byot"] }
secrecy = "0.10.3"
actix-web-lab = { version = "0.24.1", default-features = false }

View file

@ -3,6 +3,8 @@ use std::fmt::Write as _;
use std::mem;
use std::ops::ControlFlow;
use std::time::Duration;
use std::pin::Pin;
use futures::stream::Stream;
use actix_web::web::{self, Data};
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
@ -18,6 +20,7 @@ use async_openai::types::{
FunctionCallStream, FunctionObjectArgs,
};
use async_openai::Client;
use async_openai::error::OpenAIError;
use bumpalo::Bump;
use futures::StreamExt;
use index_scheduler::IndexScheduler;
@ -512,91 +515,101 @@ async fn run_conversation<C: async_openai::config::Config>(
function_support: FunctionSupport,
) -> Result<ControlFlow<Option<FinishReason>, ()>, SendError<Event>> {
let mut finish_reason = None;
// safety: unwrap: can only happens if `stream` was set to `false`
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
let mut response: Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>>;
match source {
DbChatCompletionSource::Gemini =>{
response = client.chat().create_stream_byot(chat_completion.clone()).await.unwrap();
}
_ => {
// safety: unwrap: can only happens if `stream` was set to `false`
response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
}
}
while let Some(result) = response.next().await {
match result {
Ok(resp) => {
let choice = &resp.choices[0];
finish_reason = choice.finish_reason;
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } = &choice.delta;
match tool_calls {
Some(tool_calls) => {
for chunk in tool_calls {
let ChatCompletionMessageToolCallChunk {
index,
id,
r#type: _,
function,
} = chunk;
let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
// Accumulate tool calls if present
if let Some(tool_calls) = tool_calls {
for chunk in tool_calls {
let ChatCompletionMessageToolCallChunk {
index,
id,
r#type: _,
function,
} = chunk;
let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
global_tool_calls
.entry(*index)
.and_modify(|call| {
if call.is_internal() {
call.append(arguments.as_ref().unwrap())
global_tool_calls
.entry(index.unwrap_or(0))
.and_modify(|call| {
if call.is_internal() {
call.append(arguments.as_ref().unwrap())
}
})
.or_insert_with(|| {
if name.as_deref() == Some(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
{
Call::Internal {
id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(),
arguments: arguments.as_ref().unwrap().clone(),
}
})
.or_insert_with(|| {
if name.as_deref() == Some(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
{
Call::Internal {
id: id.as_ref().unwrap().clone(),
function_name: name.as_ref().unwrap().clone(),
arguments: arguments.as_ref().unwrap().clone(),
}
} else {
Call::External
}
});
}
} else {
Call::External
}
});
}
None => {
if !global_tool_calls.is_empty() {
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
mem::take(global_tool_calls)
.into_values()
.flat_map(|call| match call {
Call::Internal { id, function_name: name, arguments } => {
Some(ChatCompletionMessageToolCall {
id,
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall { name, arguments },
})
}
Call::External => None,
}
// If finish_reason is ToolCalls, process accumulated tool calls (for both OpenAI and Gemini)
if finish_reason == Some(FinishReason::ToolCalls) && !global_tool_calls.is_empty() {
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
mem::take(global_tool_calls)
.into_values()
.flat_map(|call| match call {
Call::Internal { id, function_name: name, arguments } => {
Some(ChatCompletionMessageToolCall {
id,
r#type: Some(ChatCompletionToolType::Function),
function: FunctionCall { name, arguments },
})
.partition(|call| {
call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
});
}
Call::External => None,
})
.partition(|call| {
call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
});
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
.tool_calls(meili_calls.clone())
.build()
.unwrap()
.into(),
);
chat_completion.messages.push(
ChatCompletionRequestAssistantMessageArgs::default()
.tool_calls(meili_calls.clone())
.build()
.unwrap()
.into(),
);
handle_meili_tools(
index_scheduler,
auth_ctrl,
search_queue,
auth_token,
tx,
meili_calls,
chat_completion,
&resp,
function_support,
)
.await?;
} else {
tx.forward_response(&resp).await?;
}
}
handle_meili_tools(
index_scheduler,
auth_ctrl,
search_queue,
auth_token,
tx,
meili_calls,
chat_completion,
&resp,
function_support,
)
.await?;
} else if tool_calls.is_none() && global_tool_calls.is_empty() {
// Only forward to user if there are no tool calls to process
tx.forward_response(&resp).await?;
}
}
Err(error) => {

View file

@ -67,7 +67,7 @@ impl SseEventSender {
) -> Result<(), SendError<Event>> {
let call_text = serde_json::to_string(message).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
index: Some(0),
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
@ -114,7 +114,7 @@ impl SseEventSender {
let progress = MeiliSearchProgress { call_id, function_name, function_arguments };
let call_text = serde_json::to_string(&progress).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
index: Some(0),
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {
@ -159,7 +159,7 @@ impl SseEventSender {
let sources = MeiliSearchSources { call_id, sources: documents };
let call_text = serde_json::to_string(&sources).unwrap();
let tool_call = ChatCompletionMessageToolCallChunk {
index: 0,
index: Some(0),
id: Some(uuid::Uuid::new_v4().to_string()),
r#type: Some(ChatCompletionToolType::Function),
function: Some(FunctionCallStream {