mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-04 20:37:15 +02:00
Fix handling of OpenAI-compatible Gemini req/res
This commit is contained in:
parent
bd2bd0f33b
commit
57355d960f
3 changed files with 89 additions and 76 deletions
|
@ -111,7 +111,7 @@ utoipa = { version = "5.4.0", features = [
|
|||
"openapi_extensions",
|
||||
] }
|
||||
utoipa-scalar = { version = "0.3.0", optional = true, features = ["actix-web"] }
|
||||
async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "better-error-handling" }
|
||||
async-openai = { git = "https://github.com/meilisearch/async-openai", branch = "better-error-handling", features = ["byot"] }
|
||||
secrecy = "0.10.3"
|
||||
actix-web-lab = { version = "0.24.1", default-features = false }
|
||||
|
||||
|
|
|
@ -3,6 +3,8 @@ use std::fmt::Write as _;
|
|||
use std::mem;
|
||||
use std::ops::ControlFlow;
|
||||
use std::time::Duration;
|
||||
use std::pin::Pin;
|
||||
use futures::stream::Stream;
|
||||
|
||||
use actix_web::web::{self, Data};
|
||||
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
||||
|
@ -18,6 +20,7 @@ use async_openai::types::{
|
|||
FunctionCallStream, FunctionObjectArgs,
|
||||
};
|
||||
use async_openai::Client;
|
||||
use async_openai::error::OpenAIError;
|
||||
use bumpalo::Bump;
|
||||
use futures::StreamExt;
|
||||
use index_scheduler::IndexScheduler;
|
||||
|
@ -512,91 +515,101 @@ async fn run_conversation<C: async_openai::config::Config>(
|
|||
function_support: FunctionSupport,
|
||||
) -> Result<ControlFlow<Option<FinishReason>, ()>, SendError<Event>> {
|
||||
let mut finish_reason = None;
|
||||
// safety: unwrap: can only happens if `stream` was set to `false`
|
||||
let mut response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
|
||||
let mut response: Pin<Box<dyn Stream<Item = Result<CreateChatCompletionStreamResponse, OpenAIError>> + Send>>;
|
||||
|
||||
match source {
|
||||
DbChatCompletionSource::Gemini =>{
|
||||
response = client.chat().create_stream_byot(chat_completion.clone()).await.unwrap();
|
||||
}
|
||||
_ => {
|
||||
// safety: unwrap: can only happens if `stream` was set to `false`
|
||||
response = client.chat().create_stream(chat_completion.clone()).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
while let Some(result) = response.next().await {
|
||||
match result {
|
||||
Ok(resp) => {
|
||||
let choice = &resp.choices[0];
|
||||
finish_reason = choice.finish_reason;
|
||||
|
||||
|
||||
let ChatCompletionStreamResponseDelta { ref tool_calls, .. } = &choice.delta;
|
||||
|
||||
match tool_calls {
|
||||
Some(tool_calls) => {
|
||||
for chunk in tool_calls {
|
||||
let ChatCompletionMessageToolCallChunk {
|
||||
index,
|
||||
id,
|
||||
r#type: _,
|
||||
function,
|
||||
} = chunk;
|
||||
let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
|
||||
// Accumulate tool calls if present
|
||||
if let Some(tool_calls) = tool_calls {
|
||||
for chunk in tool_calls {
|
||||
let ChatCompletionMessageToolCallChunk {
|
||||
index,
|
||||
id,
|
||||
r#type: _,
|
||||
function,
|
||||
} = chunk;
|
||||
let FunctionCallStream { name, arguments } = function.as_ref().unwrap();
|
||||
|
||||
global_tool_calls
|
||||
.entry(*index)
|
||||
.and_modify(|call| {
|
||||
if call.is_internal() {
|
||||
call.append(arguments.as_ref().unwrap())
|
||||
global_tool_calls
|
||||
.entry(index.unwrap_or(0))
|
||||
.and_modify(|call| {
|
||||
if call.is_internal() {
|
||||
call.append(arguments.as_ref().unwrap())
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| {
|
||||
if name.as_deref() == Some(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||
{
|
||||
Call::Internal {
|
||||
id: id.as_ref().unwrap().clone(),
|
||||
function_name: name.as_ref().unwrap().clone(),
|
||||
arguments: arguments.as_ref().unwrap().clone(),
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| {
|
||||
if name.as_deref() == Some(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||
{
|
||||
Call::Internal {
|
||||
id: id.as_ref().unwrap().clone(),
|
||||
function_name: name.as_ref().unwrap().clone(),
|
||||
arguments: arguments.as_ref().unwrap().clone(),
|
||||
}
|
||||
} else {
|
||||
Call::External
|
||||
}
|
||||
});
|
||||
}
|
||||
} else {
|
||||
Call::External
|
||||
}
|
||||
});
|
||||
}
|
||||
None => {
|
||||
if !global_tool_calls.is_empty() {
|
||||
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
|
||||
mem::take(global_tool_calls)
|
||||
.into_values()
|
||||
.flat_map(|call| match call {
|
||||
Call::Internal { id, function_name: name, arguments } => {
|
||||
Some(ChatCompletionMessageToolCall {
|
||||
id,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall { name, arguments },
|
||||
})
|
||||
}
|
||||
Call::External => None,
|
||||
}
|
||||
|
||||
// If finish_reason is ToolCalls, process accumulated tool calls (for both OpenAI and Gemini)
|
||||
if finish_reason == Some(FinishReason::ToolCalls) && !global_tool_calls.is_empty() {
|
||||
let (meili_calls, _other_calls): (Vec<_>, Vec<_>) =
|
||||
mem::take(global_tool_calls)
|
||||
.into_values()
|
||||
.flat_map(|call| match call {
|
||||
Call::Internal { id, function_name: name, arguments } => {
|
||||
Some(ChatCompletionMessageToolCall {
|
||||
id,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall { name, arguments },
|
||||
})
|
||||
.partition(|call| {
|
||||
call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||
});
|
||||
}
|
||||
Call::External => None,
|
||||
})
|
||||
.partition(|call| {
|
||||
call.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME
|
||||
});
|
||||
|
||||
chat_completion.messages.push(
|
||||
ChatCompletionRequestAssistantMessageArgs::default()
|
||||
.tool_calls(meili_calls.clone())
|
||||
.build()
|
||||
.unwrap()
|
||||
.into(),
|
||||
);
|
||||
chat_completion.messages.push(
|
||||
ChatCompletionRequestAssistantMessageArgs::default()
|
||||
.tool_calls(meili_calls.clone())
|
||||
.build()
|
||||
.unwrap()
|
||||
.into(),
|
||||
);
|
||||
|
||||
handle_meili_tools(
|
||||
index_scheduler,
|
||||
auth_ctrl,
|
||||
search_queue,
|
||||
auth_token,
|
||||
tx,
|
||||
meili_calls,
|
||||
chat_completion,
|
||||
&resp,
|
||||
function_support,
|
||||
)
|
||||
.await?;
|
||||
} else {
|
||||
tx.forward_response(&resp).await?;
|
||||
}
|
||||
}
|
||||
handle_meili_tools(
|
||||
index_scheduler,
|
||||
auth_ctrl,
|
||||
search_queue,
|
||||
auth_token,
|
||||
tx,
|
||||
meili_calls,
|
||||
chat_completion,
|
||||
&resp,
|
||||
function_support,
|
||||
)
|
||||
.await?;
|
||||
} else if tool_calls.is_none() && global_tool_calls.is_empty() {
|
||||
// Only forward to user if there are no tool calls to process
|
||||
tx.forward_response(&resp).await?;
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
|
|
|
@ -67,7 +67,7 @@ impl SseEventSender {
|
|||
) -> Result<(), SendError<Event>> {
|
||||
let call_text = serde_json::to_string(message).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
index: Some(0),
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
|
@ -114,7 +114,7 @@ impl SseEventSender {
|
|||
let progress = MeiliSearchProgress { call_id, function_name, function_arguments };
|
||||
let call_text = serde_json::to_string(&progress).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
index: Some(0),
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
|
@ -159,7 +159,7 @@ impl SseEventSender {
|
|||
let sources = MeiliSearchSources { call_id, sources: documents };
|
||||
let call_text = serde_json::to_string(&sources).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
index: Some(0),
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue