diff --git a/crates/meilisearch/src/routes/chats/chat_completions.rs b/crates/meilisearch/src/routes/chats/chat_completions.rs index e14ce3c2c..ed8df3c8b 100644 --- a/crates/meilisearch/src/routes/chats/chat_completions.rs +++ b/crates/meilisearch/src/routes/chats/chat_completions.rs @@ -1,26 +1,21 @@ -use std::cell::RefCell; use std::collections::HashMap; use std::fmt::Write as _; use std::mem; use std::ops::ControlFlow; -use std::sync::RwLock; use std::time::Duration; use actix_web::web::{self, Data}; use actix_web::{Either, HttpRequest, HttpResponse, Responder}; -use actix_web_lab::sse::{self, Event, Sse}; +use actix_web_lab::sse::{Event, Sse}; use async_openai::config::{Config, OpenAIConfig}; -use async_openai::error::{ApiError, OpenAIError}; -use async_openai::reqwest_eventsource::Error as EventSourceError; use async_openai::types::{ - ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, - ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs, - ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage, - ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage, - ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta, - ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest, - CreateChatCompletionStreamResponse, FinishReason, FunctionCall, FunctionCallStream, - FunctionObjectArgs, Role, + ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, + ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage, + ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent, + ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, + ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType, + CreateChatCompletionRequest, CreateChatCompletionStreamResponse, FinishReason, FunctionCall, + FunctionCallStream, FunctionObjectArgs, }; use async_openai::Client; use bumpalo::Bump; @@ -31,38 +26,30 @@ use meilisearch_types::error::{Code, ResponseError}; use meilisearch_types::features::{ ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings as DbChatSettings, }; -use meilisearch_types::heed::RoTxn; use meilisearch_types::keys::actions; use meilisearch_types::milli::index::ChatConfig; -use meilisearch_types::milli::prompt::{Prompt, PromptData}; -use meilisearch_types::milli::update::new::document::DocumentFromDb; -use meilisearch_types::milli::{ - all_obkv_to_json, obkv_to_json, DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, - MetadataBuilder, TimeBudget, -}; +use meilisearch_types::milli::{all_obkv_to_json, obkv_to_json, TimeBudget}; use meilisearch_types::{Document, Index}; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use serde_json::json; use tokio::runtime::Handle; use tokio::sync::mpsc::error::SendError; -use tokio::sync::mpsc::Sender; -use uuid::Uuid; -use super::ChatsParam; +use super::errors::StreamErrorEvent; +use super::utils::format_documents; +use super::{ + ChatsParam, MEILI_APPEND_CONVERSATION_MESSAGE_NAME, MEILI_SEARCH_IN_INDEX_FUNCTION_NAME, + MEILI_SEARCH_PROGRESS_NAME, MEILI_SEARCH_SOURCES_NAME, +}; use crate::error::MeilisearchHttpError; use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _}; use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; +use crate::routes::chats::utils::SseEventSender; use crate::routes::indexes::search::search_kind; use crate::search::{add_search_rules, prepare_search, search_from_kind, SearchQuery}; use crate::search_queue::SearchQueue; -const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress"; -const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage"; -const MEILI_SEARCH_SOURCES_NAME: &str = "_meiliSearchSources"; -const MEILI_REPORT_ERRORS_NAME: &str = "_meiliReportErrors"; -const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex"; - pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service(web::resource("").route(web::post().to(chat))); } @@ -140,7 +127,6 @@ fn setup_search_tool( let mut report_progress = false; let mut report_sources = false; let mut append_to_conversation = false; - let mut report_errors = false; tools.retain(|tool| { match tool.function.name.as_str() { MEILI_SEARCH_PROGRESS_NAME => { @@ -155,10 +141,6 @@ fn setup_search_tool( append_to_conversation = true; false } - MEILI_REPORT_ERRORS_NAME => { - report_errors = true; - false - } _ => true, // keep other tools } }); @@ -443,7 +425,7 @@ async fn streamed_chat( tracing::debug!("Conversation function support: {function_support:?}"); let (tx, rx) = tokio::sync::mpsc::channel(10); - let tx = SseEventSender(tx); + let tx = SseEventSender::new(tx); let _join_handle = Handle::current().spawn(async move { let client = Client::with_config(config.clone()); let mut global_tool_calls = HashMap::::new(); @@ -521,9 +503,7 @@ async fn run_conversation( } }) .or_insert_with(|| { - if name - .as_ref() - .map_or(false, |n| n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME) + if name.as_deref() == Some(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME) { Call::Internal { id: id.as_ref().unwrap().clone(), @@ -680,181 +660,6 @@ async fn handle_meili_tools( Ok(()) } -pub struct SseEventSender(Sender); - -impl SseEventSender { - /// Ask the front-end user to append this tool *call* to the conversation - pub async fn append_tool_call_conversation_message( - &self, - resp: CreateChatCompletionStreamResponse, - call_id: String, - function_name: String, - function_arguments: String, - ) -> Result<(), SendError> { - #[allow(deprecated)] - let message = - ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { - content: None, - refusal: None, - name: None, - audio: None, - tool_calls: Some(vec![ChatCompletionMessageToolCall { - id: call_id, - r#type: Some(ChatCompletionToolType::Function), - function: FunctionCall { name: function_name, arguments: function_arguments }, - }]), - function_call: None, - }); - - self.append_conversation_message(resp, &message).await - } - - /// Ask the front-end user to append this tool to the conversation - pub async fn append_conversation_message( - &self, - mut resp: CreateChatCompletionStreamResponse, - message: &ChatCompletionRequestMessage, - ) -> Result<(), SendError> { - let call_text = serde_json::to_string(message).unwrap(); - let tool_call = ChatCompletionMessageToolCallChunk { - index: 0, - id: Some(uuid::Uuid::new_v4().to_string()), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()), - arguments: Some(call_text), - }), - }; - - resp.choices[0] = ChatChoiceStream { - index: 0, - #[allow(deprecated)] - delta: ChatCompletionStreamResponseDelta { - content: None, - function_call: None, - tool_calls: Some(vec![tool_call]), - role: Some(Role::Assistant), - refusal: None, - }, - finish_reason: None, - logprobs: None, - }; - - self.send_json(&resp).await - } - - pub async fn report_search_progress( - &self, - mut resp: CreateChatCompletionStreamResponse, - call_id: &str, - function_name: &str, - function_arguments: &str, - ) -> Result<(), SendError> { - #[derive(Debug, Clone, Serialize)] - /// Provides information about the current Meilisearch search operation. - struct MeiliSearchProgress<'a> { - /// The call ID to track the sources of the search. - call_id: &'a str, - /// The name of the function we are executing. - function_name: &'a str, - /// The arguments of the function we are executing, encoded in JSON. - function_arguments: &'a str, - } - - let progress = MeiliSearchProgress { call_id, function_name, function_arguments }; - let call_text = serde_json::to_string(&progress).unwrap(); - let tool_call = ChatCompletionMessageToolCallChunk { - index: 0, - id: Some(uuid::Uuid::new_v4().to_string()), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()), - arguments: Some(call_text), - }), - }; - - resp.choices[0] = ChatChoiceStream { - index: 0, - #[allow(deprecated)] - delta: ChatCompletionStreamResponseDelta { - content: None, - function_call: None, - tool_calls: Some(vec![tool_call]), - role: Some(Role::Assistant), - refusal: None, - }, - finish_reason: None, - logprobs: None, - }; - - self.send_json(&resp).await - } - - pub async fn report_sources( - &self, - mut resp: CreateChatCompletionStreamResponse, - call_id: &str, - documents: &[Document], - ) -> Result<(), SendError> { - #[derive(Debug, Clone, Serialize)] - /// Provides sources of the search. - struct MeiliSearchSources<'a> { - /// The call ID to track the original search associated to those sources. - call_id: &'a str, - /// The documents associated with the search (call_id). - /// Only the displayed attributes of the documents are returned. - sources: &'a [Document], - } - - let sources = MeiliSearchSources { call_id, sources: documents }; - let call_text = serde_json::to_string(&sources).unwrap(); - let tool_call = ChatCompletionMessageToolCallChunk { - index: 0, - id: Some(uuid::Uuid::new_v4().to_string()), - r#type: Some(ChatCompletionToolType::Function), - function: Some(FunctionCallStream { - name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()), - arguments: Some(call_text), - }), - }; - - resp.choices[0] = ChatChoiceStream { - index: 0, - #[allow(deprecated)] - delta: ChatCompletionStreamResponseDelta { - content: None, - function_call: None, - tool_calls: Some(vec![tool_call]), - role: Some(Role::Assistant), - refusal: None, - }, - finish_reason: None, - logprobs: None, - }; - - self.send_json(&resp).await - } - - pub async fn forward_response( - &self, - resp: &CreateChatCompletionStreamResponse, - ) -> Result<(), SendError> { - self.send_json(resp).await - } - - pub async fn send_error(&self, error: &StreamErrorEvent) -> Result<(), SendError> { - self.send_json(error).await - } - - pub async fn stop(self) -> Result<(), SendError> { - self.0.send(Event::Data(sse::Data::new("[DONE]"))).await - } - - async fn send_json(&self, data: &S) -> Result<(), SendError> { - self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await - } -} - /// The structure used to aggregate the function calls to make. #[derive(Debug)] enum Call { @@ -892,220 +697,3 @@ struct SearchInIndexParameters { /// The query parameter to use. q: Option, } - -fn format_documents<'t, 'doc>( - rtxn: &RoTxn<'t>, - index: &Index, - doc_alloc: &'doc Bump, - internal_docids: Vec, -) -> Result, ResponseError> { - let ChatConfig { prompt: PromptData { template, max_bytes }, .. } = index.chat_config(rtxn)?; - - let prompt = Prompt::new(template, max_bytes).unwrap(); - let fid_map = index.fields_ids_map(rtxn)?; - let metadata_builder = MetadataBuilder::from_index(index, rtxn)?; - let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder); - let global = RwLock::new(fid_map_with_meta); - let gfid_map = RefCell::new(GlobalFieldsIdsMap::new(&global)); - - let external_ids: Vec = index - .external_id_of(rtxn, internal_docids.iter().copied())? - .into_iter() - .collect::>()?; - - let mut renders = Vec::new(); - for (docid, external_docid) in internal_docids.into_iter().zip(external_ids) { - let document = match DocumentFromDb::new(docid, rtxn, index, &fid_map)? { - Some(doc) => doc, - None => continue, - }; - - let text = prompt.render_document(&external_docid, document, &gfid_map, doc_alloc).unwrap(); - renders.push(text); - } - - Ok(renders) -} - -/// An error that occurs during the streaming process. -/// -/// It directly comes from the OpenAI API and you can -/// read more about error events on their website: -/// -#[derive(Debug, Serialize, Deserialize)] -pub struct StreamErrorEvent { - /// The unique ID of the server event. - event_id: String, - /// The event type, must be error. - r#type: String, - /// Details of the error. - error: StreamError, -} - -/// Details of the error. -#[derive(Debug, Serialize, Deserialize)] -pub struct StreamError { - /// The type of error (e.g., "invalid_request_error", "server_error"). - r#type: String, - /// Error code, if any. - code: Option, - /// A human-readable error message. - message: String, - /// Parameter related to the error, if any. - param: Option, - /// The event_id of the client event that caused the error, if applicable. - event_id: Option, -} - -impl StreamErrorEvent { - pub async fn from_openai_error(error: OpenAIError) -> Result { - let error_type = "error".to_string(); - match error { - OpenAIError::Reqwest(e) => Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: error_type, - error: StreamError { - r#type: "internal_reqwest_error".to_string(), - code: Some("internal".to_string()), - message: e.to_string(), - param: None, - event_id: None, - }, - }), - OpenAIError::ApiError(ApiError { message, r#type, param, code }) => { - Ok(StreamErrorEvent { - r#type: error_type, - event_id: Uuid::new_v4().to_string(), - error: StreamError { - r#type: r#type.unwrap_or_else(|| "unknown".to_string()), - code, - message, - param, - event_id: None, - }, - }) - } - OpenAIError::JSONDeserialize(error) => Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: error_type, - error: StreamError { - r#type: "json_deserialize_error".to_string(), - code: Some("internal".to_string()), - message: error.to_string(), - param: None, - event_id: None, - }, - }), - OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(), - OpenAIError::StreamError(error) => match error { - EventSourceError::InvalidStatusCode(_status_code, response) => { - let OpenAiOutsideError { - error: OpenAiInnerError { code, message, param, r#type }, - } = response.json().await?; - - Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: error_type, - error: StreamError { r#type, code, message, param, event_id: None }, - }) - } - EventSourceError::InvalidContentType(_header_value, response) => { - let OpenAiOutsideError { - error: OpenAiInnerError { code, message, param, r#type }, - } = response.json().await?; - - Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: error_type, - error: StreamError { r#type, code, message, param, event_id: None }, - }) - } - EventSourceError::Utf8(error) => Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: error_type, - error: StreamError { - r#type: "invalid_utf8_error".to_string(), - code: None, - message: error.to_string(), - param: None, - event_id: None, - }, - }), - EventSourceError::Parser(error) => Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: error_type, - error: StreamError { - r#type: "parser_error".to_string(), - code: None, - message: error.to_string(), - param: None, - event_id: None, - }, - }), - EventSourceError::Transport(error) => Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: error_type, - error: StreamError { - r#type: "transport_error".to_string(), - code: None, - message: error.to_string(), - param: None, - event_id: None, - }, - }), - EventSourceError::InvalidLastEventId(message) => Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: error_type, - error: StreamError { - r#type: "invalid_last_event_id".to_string(), - code: None, - message, - param: None, - event_id: None, - }, - }), - EventSourceError::StreamEnded => Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: error_type, - error: StreamError { - r#type: "stream_ended".to_string(), - code: None, - message: "Stream ended".to_string(), - param: None, - event_id: None, - }, - }), - }, - OpenAIError::InvalidArgument(message) => Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: error_type, - error: StreamError { - r#type: "invalid_argument".to_string(), - code: None, - message, - param: None, - event_id: None, - }, - }), - } - } -} - -#[derive(Debug, Clone, Deserialize)] -pub struct OpenAiOutsideError { - /// Emitted when an error occurs. - error: OpenAiInnerError, -} - -/// Emitted when an error occurs. -#[derive(Debug, Clone, Deserialize)] -pub struct OpenAiInnerError { - /// The error code. - code: Option, - /// The error message. - message: String, - /// The error parameter. - param: Option, - /// The type of the event. Always `error`. - r#type: String, -} diff --git a/crates/meilisearch/src/routes/chats/errors.rs b/crates/meilisearch/src/routes/chats/errors.rs new file mode 100644 index 000000000..f1aa9722b --- /dev/null +++ b/crates/meilisearch/src/routes/chats/errors.rs @@ -0,0 +1,187 @@ +use async_openai::error::{ApiError, OpenAIError}; +use async_openai::reqwest_eventsource::Error as EventSourceError; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +#[derive(Debug, Clone, Deserialize)] +pub struct OpenAiOutsideError { + /// Emitted when an error occurs. + error: OpenAiInnerError, +} + +/// Emitted when an error occurs. +#[derive(Debug, Clone, Deserialize)] +pub struct OpenAiInnerError { + /// The error code. + code: Option, + /// The error message. + message: String, + /// The error parameter. + param: Option, + /// The type of the event. Always `error`. + r#type: String, +} + +/// An error that occurs during the streaming process. +/// +/// It directly comes from the OpenAI API and you can +/// read more about error events on their website: +/// +#[derive(Debug, Serialize, Deserialize)] +pub struct StreamErrorEvent { + /// The unique ID of the server event. + pub event_id: String, + /// The event type, must be error. + pub r#type: String, + /// Details of the error. + pub error: StreamError, +} + +/// Details of the error. +#[derive(Debug, Serialize, Deserialize)] +pub struct StreamError { + /// The type of error (e.g., "invalid_request_error", "server_error"). + pub r#type: String, + /// Error code, if any. + pub code: Option, + /// A human-readable error message. + pub message: String, + /// Parameter related to the error, if any. + pub param: Option, + /// The event_id of the client event that caused the error, if applicable. + pub event_id: Option, +} + +impl StreamErrorEvent { + pub async fn from_openai_error(error: OpenAIError) -> Result { + let error_type = "error".to_string(); + match error { + OpenAIError::Reqwest(e) => Ok(StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: error_type, + error: StreamError { + r#type: "internal_reqwest_error".to_string(), + code: Some("internal".to_string()), + message: e.to_string(), + param: None, + event_id: None, + }, + }), + OpenAIError::ApiError(ApiError { message, r#type, param, code }) => { + Ok(StreamErrorEvent { + r#type: error_type, + event_id: Uuid::new_v4().to_string(), + error: StreamError { + r#type: r#type.unwrap_or_else(|| "unknown".to_string()), + code, + message, + param, + event_id: None, + }, + }) + } + OpenAIError::JSONDeserialize(error) => Ok(StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: error_type, + error: StreamError { + r#type: "json_deserialize_error".to_string(), + code: Some("internal".to_string()), + message: error.to_string(), + param: None, + event_id: None, + }, + }), + OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(), + OpenAIError::StreamError(error) => match error { + EventSourceError::InvalidStatusCode(_status_code, response) => { + let OpenAiOutsideError { + error: OpenAiInnerError { code, message, param, r#type }, + } = response.json().await?; + + Ok(StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: error_type, + error: StreamError { r#type, code, message, param, event_id: None }, + }) + } + EventSourceError::InvalidContentType(_header_value, response) => { + let OpenAiOutsideError { + error: OpenAiInnerError { code, message, param, r#type }, + } = response.json().await?; + + Ok(StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: error_type, + error: StreamError { r#type, code, message, param, event_id: None }, + }) + } + EventSourceError::Utf8(error) => Ok(StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: error_type, + error: StreamError { + r#type: "invalid_utf8_error".to_string(), + code: None, + message: error.to_string(), + param: None, + event_id: None, + }, + }), + EventSourceError::Parser(error) => Ok(StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: error_type, + error: StreamError { + r#type: "parser_error".to_string(), + code: None, + message: error.to_string(), + param: None, + event_id: None, + }, + }), + EventSourceError::Transport(error) => Ok(StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: error_type, + error: StreamError { + r#type: "transport_error".to_string(), + code: None, + message: error.to_string(), + param: None, + event_id: None, + }, + }), + EventSourceError::InvalidLastEventId(message) => Ok(StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: error_type, + error: StreamError { + r#type: "invalid_last_event_id".to_string(), + code: None, + message, + param: None, + event_id: None, + }, + }), + EventSourceError::StreamEnded => Ok(StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: error_type, + error: StreamError { + r#type: "stream_ended".to_string(), + code: None, + message: "Stream ended".to_string(), + param: None, + event_id: None, + }, + }), + }, + OpenAIError::InvalidArgument(message) => Ok(StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: error_type, + error: StreamError { + r#type: "invalid_argument".to_string(), + code: None, + message, + param: None, + event_id: None, + }, + }), + } + } +} diff --git a/crates/meilisearch/src/routes/chats/mod.rs b/crates/meilisearch/src/routes/chats/mod.rs index 0fa0d54b4..bb0476ab8 100644 --- a/crates/meilisearch/src/routes/chats/mod.rs +++ b/crates/meilisearch/src/routes/chats/mod.rs @@ -1,30 +1,35 @@ -use actix_web::{ - web::{self, Data}, - HttpResponse, -}; -use deserr::{actix_web::AwebQueryParameter, Deserr}; +use actix_web::web::{self, Data}; +use actix_web::HttpResponse; +use deserr::actix_web::AwebQueryParameter; +use deserr::Deserr; use index_scheduler::IndexScheduler; -use meilisearch_types::{ - deserr::{query_params::Param, DeserrQueryParamError}, - error::{ - deserr_codes::{InvalidIndexLimit, InvalidIndexOffset}, - ResponseError, - }, - keys::actions, -}; +use meilisearch_types::deserr::query_params::Param; +use meilisearch_types::deserr::DeserrQueryParamError; +use meilisearch_types::error::deserr_codes::{InvalidIndexLimit, InvalidIndexOffset}; +use meilisearch_types::error::ResponseError; +use meilisearch_types::keys::actions; use serde::{Deserialize, Serialize}; use tracing::debug; use utoipa::{IntoParams, ToSchema}; -use crate::{ - extractors::authentication::{policies::ActionPolicy, GuardedData}, - routes::PAGINATION_DEFAULT_LIMIT, -}; - use super::Pagination; +use crate::extractors::authentication::policies::ActionPolicy; +use crate::extractors::authentication::GuardedData; +use crate::routes::PAGINATION_DEFAULT_LIMIT; pub mod chat_completions; +mod errors; pub mod settings; +mod utils; + +/// The function name to report search progress. +const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress"; +/// The function name to append a conversation message in the user conversation. +const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage"; +/// The function name to report sources to the frontend. +const MEILI_SEARCH_SOURCES_NAME: &str = "_meiliSearchSources"; +/// The *internal* function name to provide to the LLM to search in indexes. +const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex"; #[derive(Deserialize)] pub struct ChatsParam { diff --git a/crates/meilisearch/src/routes/chats/utils.rs b/crates/meilisearch/src/routes/chats/utils.rs new file mode 100644 index 000000000..424b4ea64 --- /dev/null +++ b/crates/meilisearch/src/routes/chats/utils.rs @@ -0,0 +1,243 @@ +use std::cell::RefCell; +use std::sync::RwLock; + +use actix_web_lab::sse::{self, Event}; +use async_openai::types::{ + ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk, + ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage, + ChatCompletionStreamResponseDelta, ChatCompletionToolType, CreateChatCompletionStreamResponse, + FunctionCall, FunctionCallStream, Role, +}; +use bumpalo::Bump; +use meilisearch_types::error::ResponseError; +use meilisearch_types::heed::RoTxn; +use meilisearch_types::milli::index::ChatConfig; +use meilisearch_types::milli::prompt::{Prompt, PromptData}; +use meilisearch_types::milli::update::new::document::DocumentFromDb; +use meilisearch_types::milli::{ + DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, +}; +use meilisearch_types::{Document, Index}; +use serde::Serialize; +use tokio::sync::mpsc::error::SendError; +use tokio::sync::mpsc::Sender; + +use super::errors::StreamErrorEvent; +use super::MEILI_APPEND_CONVERSATION_MESSAGE_NAME; +use crate::routes::chats::{MEILI_SEARCH_PROGRESS_NAME, MEILI_SEARCH_SOURCES_NAME}; + +pub struct SseEventSender(Sender); + +impl SseEventSender { + pub fn new(sender: Sender) -> Self { + Self(sender) + } + + /// Ask the front-end user to append this tool *call* to the conversation + pub async fn append_tool_call_conversation_message( + &self, + resp: CreateChatCompletionStreamResponse, + call_id: String, + function_name: String, + function_arguments: String, + ) -> Result<(), SendError> { + #[allow(deprecated)] + let message = + ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { + content: None, + refusal: None, + name: None, + audio: None, + tool_calls: Some(vec![ChatCompletionMessageToolCall { + id: call_id, + r#type: Some(ChatCompletionToolType::Function), + function: FunctionCall { name: function_name, arguments: function_arguments }, + }]), + function_call: None, + }); + + self.append_conversation_message(resp, &message).await + } + + /// Ask the front-end user to append this tool to the conversation + pub async fn append_conversation_message( + &self, + mut resp: CreateChatCompletionStreamResponse, + message: &ChatCompletionRequestMessage, + ) -> Result<(), SendError> { + let call_text = serde_json::to_string(message).unwrap(); + let tool_call = ChatCompletionMessageToolCallChunk { + index: 0, + id: Some(uuid::Uuid::new_v4().to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()), + arguments: Some(call_text), + }), + }; + + resp.choices[0] = ChatChoiceStream { + index: 0, + #[allow(deprecated)] + delta: ChatCompletionStreamResponseDelta { + content: None, + function_call: None, + tool_calls: Some(vec![tool_call]), + role: Some(Role::Assistant), + refusal: None, + }, + finish_reason: None, + logprobs: None, + }; + + self.send_json(&resp).await + } + + pub async fn report_search_progress( + &self, + mut resp: CreateChatCompletionStreamResponse, + call_id: &str, + function_name: &str, + function_arguments: &str, + ) -> Result<(), SendError> { + #[derive(Debug, Clone, Serialize)] + /// Provides information about the current Meilisearch search operation. + struct MeiliSearchProgress<'a> { + /// The call ID to track the sources of the search. + call_id: &'a str, + /// The name of the function we are executing. + function_name: &'a str, + /// The arguments of the function we are executing, encoded in JSON. + function_arguments: &'a str, + } + + let progress = MeiliSearchProgress { call_id, function_name, function_arguments }; + let call_text = serde_json::to_string(&progress).unwrap(); + let tool_call = ChatCompletionMessageToolCallChunk { + index: 0, + id: Some(uuid::Uuid::new_v4().to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()), + arguments: Some(call_text), + }), + }; + + resp.choices[0] = ChatChoiceStream { + index: 0, + #[allow(deprecated)] + delta: ChatCompletionStreamResponseDelta { + content: None, + function_call: None, + tool_calls: Some(vec![tool_call]), + role: Some(Role::Assistant), + refusal: None, + }, + finish_reason: None, + logprobs: None, + }; + + self.send_json(&resp).await + } + + pub async fn report_sources( + &self, + mut resp: CreateChatCompletionStreamResponse, + call_id: &str, + documents: &[Document], + ) -> Result<(), SendError> { + #[derive(Debug, Clone, Serialize)] + /// Provides sources of the search. + struct MeiliSearchSources<'a> { + /// The call ID to track the original search associated to those sources. + call_id: &'a str, + /// The documents associated with the search (call_id). + /// Only the displayed attributes of the documents are returned. + sources: &'a [Document], + } + + let sources = MeiliSearchSources { call_id, sources: documents }; + let call_text = serde_json::to_string(&sources).unwrap(); + let tool_call = ChatCompletionMessageToolCallChunk { + index: 0, + id: Some(uuid::Uuid::new_v4().to_string()), + r#type: Some(ChatCompletionToolType::Function), + function: Some(FunctionCallStream { + name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()), + arguments: Some(call_text), + }), + }; + + resp.choices[0] = ChatChoiceStream { + index: 0, + #[allow(deprecated)] + delta: ChatCompletionStreamResponseDelta { + content: None, + function_call: None, + tool_calls: Some(vec![tool_call]), + role: Some(Role::Assistant), + refusal: None, + }, + finish_reason: None, + logprobs: None, + }; + + self.send_json(&resp).await + } + + pub async fn forward_response( + &self, + resp: &CreateChatCompletionStreamResponse, + ) -> Result<(), SendError> { + self.send_json(resp).await + } + + pub async fn send_error(&self, error: &StreamErrorEvent) -> Result<(), SendError> { + self.send_json(error).await + } + + pub async fn stop(self) -> Result<(), SendError> { + self.0.send(Event::Data(sse::Data::new("[DONE]"))).await + } + + async fn send_json(&self, data: &S) -> Result<(), SendError> { + self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await + } +} + +/// Format documents based on the provided template and maximum bytes. +/// +/// This formatting function is usually used to generate a summary of the documents for LLMs. +pub fn format_documents<'t, 'doc>( + rtxn: &RoTxn<'t>, + index: &Index, + doc_alloc: &'doc Bump, + internal_docids: Vec, +) -> Result, ResponseError> { + let ChatConfig { prompt: PromptData { template, max_bytes }, .. } = index.chat_config(rtxn)?; + + let prompt = Prompt::new(template, max_bytes).unwrap(); + let fid_map = index.fields_ids_map(rtxn)?; + let metadata_builder = MetadataBuilder::from_index(index, rtxn)?; + let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder); + let global = RwLock::new(fid_map_with_meta); + let gfid_map = RefCell::new(GlobalFieldsIdsMap::new(&global)); + + let external_ids: Vec = index + .external_id_of(rtxn, internal_docids.iter().copied())? + .into_iter() + .collect::>()?; + + let mut renders = Vec::new(); + for (docid, external_docid) in internal_docids.into_iter().zip(external_ids) { + let document = match DocumentFromDb::new(docid, rtxn, index, &fid_map)? { + Some(doc) => doc, + None => continue, + }; + + let text = prompt.render_document(&external_docid, document, &gfid_map, doc_alloc).unwrap(); + renders.push(text); + } + + Ok(renders) +}