From 57355d960fc4af24550c70098b5808f1c93116b3 Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Wed, 25 Jun 2025 11:26:31 +0200 Subject: [PATCH] Fix handling of OpenAI-compatible Gemini req/res --- crates/meilisearch/Cargo.toml | 2 +- .../src/routes/chats/chat_completions.rs | 157 ++++++++++-------- crates/meilisearch/src/routes/chats/utils.rs | 6 +- 3 files changed, 89 insertions(+), 76 deletions(-) diff --git a/crates/meilisearch/Cargo.toml b/crates/meilisearch/Cargo.toml index fe00d9fee..980e4540a 100644 --- a/crates/meilisearch/Cargo.toml +++ b/crates/meilisearch/Cargo.toml @@ -111,7 +111,7 @@ utoipa = { version = "5.4.0", features = [ "openapi_extensions", ] } utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] } -async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "better-error-handling" } +async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "better-error-handling", features = ["byot"] } secrecy = "0.10.3" actix-web-lab = { version = "0.24.1", default-features = false } diff --git a/crates/meilisearch/src/routes/chats/chat_completions.rs b/crates/meilisearch/src/routes/chats/chat_completions.rs index 8108e24dc..46444e2f9 100644 --- a/crates/meilisearch/src/routes/chats/chat_completions.rs +++ b/crates/meilisearch/src/routes/chats/chat_completions.rs @@ -3,6 +3,8 @@ use std::fmt::Write as _; use std::mem; use std::ops::ControlFlow; use std::time::Duration; +use std::pin::Pin; +use futures::stream::Stream; use actix_web::web::{self, Data}; use actix_web::{Either, HttpRequest, HttpResponse, Responder}; @@ -18,6 +20,7 @@ use async_openai::types::{ FunctionCallStream, FunctionObjectArgs, }; use async_openai::Client; +use async_openai::error::OpenAIError; use bumpalo::Bump; use futures::StreamExt; use index_scheduler::IndexScheduler; @@ -512,91 +515,101 @@ async fn run_conversation( function_support: FunctionSupport, ) -> Result, ()>, SendError> { let mut finish_reason = None; - // safety: unwrap: can only happens if `stream` was set to `false` - let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); + let mut response: Pin> + Send>>; + + match source { + DbChatCompletionSource::Gemini =>{ + response = client.chat().create_stream_byot(chat_completion.clone()).await.unwrap(); + } + _ => { + // safety: unwrap: can only happens if `stream` was set to `false` + response = client.chat().create_stream(chat_completion.clone()).await.unwrap(); + } + } + while let Some(result) = response.next().await { match result { Ok(resp) => { let choice = &resp.choices[0]; finish_reason = choice.finish_reason; - + let ChatCompletionStreamResponseDelta { ref tool_calls, .. } = &choice.delta; - match tool_calls { - Some(tool_calls) => { - for chunk in tool_calls { - let ChatCompletionMessageToolCallChunk { - index, - id, - r#type: _, - function, - } = chunk; - let FunctionCallStream { name, arguments } = function.as_ref().unwrap(); + // Accumulate tool calls if present + if let Some(tool_calls) = tool_calls { + for chunk in tool_calls { + let ChatCompletionMessageToolCallChunk { + index, + id, + r#type: _, + function, + } = chunk; + let FunctionCallStream { name, arguments } = function.as_ref().unwrap(); - global_tool_calls - .entry(*index) - .and_modify(|call| { - if call.is_internal() { - call.append(arguments.as_ref().unwrap()) + global_tool_calls + .entry(index.unwrap_or(0)) + .and_modify(|call| { + if call.is_internal() { + call.append(arguments.as_ref().unwrap()) + } + }) + .or_insert_with(|| { + if name.as_deref() == Some(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME) + { + Call::Internal { + id: id.as_ref().unwrap().clone(), + function_name: name.as_ref().unwrap().clone(), + arguments: arguments.as_ref().unwrap().clone(), } - }) - .or_insert_with(|| { - if name.as_deref() == Some(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME) - { - Call::Internal { - id: id.as_ref().unwrap().clone(), - function_name: name.as_ref().unwrap().clone(), - arguments: arguments.as_ref().unwrap().clone(), - } - } else { - Call::External - } - }); - } + } else { + Call::External + } + }); } - None => { - if !global_tool_calls.is_empty() { - let (meili_calls, _other_calls): (Vec<_>, Vec<_>) = - mem::take(global_tool_calls) - .into_values() - .flat_map(|call| match call { - Call::Internal { id, function_name: name, arguments } => { - Some(ChatCompletionMessageToolCall { - id, - r#type: Some(ChatCompletionToolType::Function), - function: FunctionCall { name, arguments }, - }) - } - Call::External => None, + } + + // If finish_reason is ToolCalls, process accumulated tool calls (for both OpenAI and Gemini) + if finish_reason == Some(FinishReason::ToolCalls) && !global_tool_calls.is_empty() { + let (meili_calls, _other_calls): (Vec<_>, Vec<_>) = + mem::take(global_tool_calls) + .into_values() + .flat_map(|call| match call { + Call::Internal { id, function_name: name, arguments } => { + Some(ChatCompletionMessageToolCall { + id, + r#type: Some(ChatCompletionToolType::Function), + function: FunctionCall { name, arguments }, }) - .partition(|call| { - call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME - }); + } + Call::External => None, + }) + .partition(|call| { + call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME + }); - chat_completion.messages.push( - ChatCompletionRequestAssistantMessageArgs::default() - .tool_calls(meili_calls.clone()) - .build() - .unwrap() - .into(), - ); + chat_completion.messages.push( + ChatCompletionRequestAssistantMessageArgs::default() + .tool_calls(meili_calls.clone()) + .build() + .unwrap() + .into(), + ); - handle_meili_tools( - index_scheduler, - auth_ctrl, - search_queue, - auth_token, - tx, - meili_calls, - chat_completion, - &resp, - function_support, - ) - .await?; - } else { - tx.forward_response(&resp).await?; - } - } + handle_meili_tools( + index_scheduler, + auth_ctrl, + search_queue, + auth_token, + tx, + meili_calls, + chat_completion, + &resp, + function_support, + ) + .await?; + } else if tool_calls.is_none() && global_tool_calls.is_empty() { + // Only forward to user if there are no tool calls to process + tx.forward_response(&resp).await?; } } Err(error) => { diff --git a/crates/meilisearch/src/routes/chats/utils.rs b/crates/meilisearch/src/routes/chats/utils.rs index 61961bd4b..9c7c48557 100644 --- a/crates/meilisearch/src/routes/chats/utils.rs +++ b/crates/meilisearch/src/routes/chats/utils.rs @@ -67,7 +67,7 @@ impl SseEventSender { ) -> Result<(), SendError> { let call_text = serde_json::to_string(message).unwrap(); let tool_call = ChatCompletionMessageToolCallChunk { - index: 0, + index: Some(0), id: Some(uuid::Uuid::new_v4().to_string()), r#type: Some(ChatCompletionToolType::Function), function: Some(FunctionCallStream { @@ -114,7 +114,7 @@ impl SseEventSender { let progress = MeiliSearchProgress { call_id, function_name, function_arguments }; let call_text = serde_json::to_string(&progress).unwrap(); let tool_call = ChatCompletionMessageToolCallChunk { - index: 0, + index: Some(0), id: Some(uuid::Uuid::new_v4().to_string()), r#type: Some(ChatCompletionToolType::Function), function: Some(FunctionCallStream { @@ -159,7 +159,7 @@ impl SseEventSender { let sources = MeiliSearchSources { call_id, sources: documents }; let call_text = serde_json::to_string(&sources).unwrap(); let tool_call = ChatCompletionMessageToolCallChunk { - index: 0, + index: Some(0), id: Some(uuid::Uuid::new_v4().to_string()), r#type: Some(ChatCompletionToolType::Function), function: Some(FunctionCallStream {