From 91c6ab8392c743a5c09d442f5c5556d13c9309ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Renault?= Date: Tue, 20 May 2025 18:01:08 +0200 Subject: [PATCH] Make sure errorneous calls are handled and forwarded to the LLM --- crates/meilisearch/src/routes/chat.rs | 37 +++++++++++++++++---------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/crates/meilisearch/src/routes/chat.rs b/crates/meilisearch/src/routes/chat.rs index 0dc54b37d..b3a67ff10 100644 --- a/crates/meilisearch/src/routes/chat.rs +++ b/crates/meilisearch/src/routes/chat.rs @@ -253,23 +253,32 @@ async fn non_streamed_chat( ); for call in meili_calls { - let SearchInIndexParameters { index_uid, q } = - serde_json::from_str(&call.function.arguments).unwrap(); + let result = match serde_json::from_str(&call.function.arguments) { + Ok(SearchInIndexParameters { index_uid, q }) => process_search_request( + &index_scheduler, + auth_ctrl.clone(), + &search_queue, + &auth_token, + index_uid, + q, + ) + .await + .map_err(|e| e.to_string()), + Err(err) => Err(err.to_string()), + }; - let (_, text) = process_search_request( - &index_scheduler, - auth_ctrl.clone(), - &search_queue, - auth_token, - index_uid, - q, - ) - .await?; + let text = match result { + Ok((_, text)) => text, + Err(err) => err, + }; chat_completion.messages.push(ChatCompletionRequestMessage::Tool( ChatCompletionRequestToolMessage { - tool_call_id: call.id, - content: ChatCompletionRequestToolMessageContent::Text(text), + tool_call_id: call.id.clone(), + content: ChatCompletionRequestToolMessageContent::Text(format!( + "{}\n\n{text}", + chat_settings.prompts.pre_query + )), }, )); } @@ -413,7 +422,7 @@ async fn streamed_chat( let is_error = result.is_err(); let text = match result { Ok((_, text)) => text, - Err(err) => err.to_string(), + Err(err) => err, }; let tool = ChatCompletionRequestToolMessage {