diff --git a/crates/meilisearch-types/src/features.rs b/crates/meilisearch-types/src/features.rs index 651077484..83054e784 100644 --- a/crates/meilisearch-types/src/features.rs +++ b/crates/meilisearch-types/src/features.rs @@ -107,7 +107,7 @@ impl ChatCompletionSettings { } } -#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Default)] #[serde(rename_all = "camelCase")] pub enum ChatCompletionSource { #[default] diff --git a/crates/meilisearch/src/routes/chats/chat_completions.rs b/crates/meilisearch/src/routes/chats/chat_completions.rs index 7431609e6..8108e24dc 100644 --- a/crates/meilisearch/src/routes/chats/chat_completions.rs +++ b/crates/meilisearch/src/routes/chats/chat_completions.rs @@ -23,7 +23,10 @@ use futures::StreamExt; use index_scheduler::IndexScheduler; use meilisearch_auth::AuthController; use meilisearch_types::error::{Code, ResponseError}; -use meilisearch_types::features::{ChatCompletionPrompts as DbChatCompletionPrompts, SystemRole}; +use meilisearch_types::features::{ + ChatCompletionPrompts as DbChatCompletionPrompts, + ChatCompletionSource as DbChatCompletionSource, SystemRole, +}; use meilisearch_types::keys::actions; use meilisearch_types::milli::index::ChatConfig; use meilisearch_types::milli::{all_obkv_to_json, obkv_to_json, TimeBudget}; @@ -34,7 +37,7 @@ use tokio::runtime::Handle; use tokio::sync::mpsc::error::SendError; use super::config::Config; -use super::errors::StreamErrorEvent; +use super::errors::{MistralError, OpenAiOutsideError, StreamErrorEvent}; use super::utils::format_documents; use super::{ ChatsParam, MEILI_APPEND_CONVERSATION_MESSAGE_NAME, MEILI_SEARCH_IN_INDEX_FUNCTION_NAME, @@ -469,6 +472,7 @@ async fn streamed_chat( &search_queue, &auth_token, &client, + chat_settings.source, &mut chat_completion, &tx, &mut global_tool_calls, @@ -501,6 +505,7 @@ async fn run_conversation( search_queue: &web::Data, auth_token: &str, client: &Client, + source: DbChatCompletionSource, chat_completion: &mut CreateChatCompletionRequest, tx: &SseEventSender, global_tool_calls: &mut HashMap, @@ -595,7 +600,13 @@ async fn run_conversation( } } Err(error) => { - let error = StreamErrorEvent::from_openai_error(error).await.unwrap(); + let result = match source { + DbChatCompletionSource::Mistral => { + StreamErrorEvent::from_openai_error::(error).await + } + _ => StreamErrorEvent::from_openai_error::(error).await, + }; + let error = result.unwrap_or_else(StreamErrorEvent::from_reqwest_error); tx.send_error(&error).await?; return Ok(ControlFlow::Break(None)); } diff --git a/crates/meilisearch/src/routes/chats/errors.rs b/crates/meilisearch/src/routes/chats/errors.rs index efa60ba50..e7fb661ed 100644 --- a/crates/meilisearch/src/routes/chats/errors.rs +++ b/crates/meilisearch/src/routes/chats/errors.rs @@ -4,6 +4,39 @@ use meilisearch_types::error::ResponseError; use serde::{Deserialize, Serialize}; use uuid::Uuid; +/// The error type which is always `error`. +const ERROR_TYPE: &str = "error"; + +/// The error struct returned by the Mistral API. +/// +/// ```json +/// { +/// "object": "error", +/// "message": "Service tier capacity exceeded for this model.", +/// "type": "invalid_request_error", +/// "param": null, +/// "code": null +/// } +/// ``` +#[derive(Debug, Clone, Deserialize)] +pub struct MistralError { + message: String, + r#type: String, + param: Option, + code: Option, +} + +impl From for StreamErrorEvent { + fn from(error: MistralError) -> Self { + let MistralError { message, r#type, param, code } = error; + StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: ERROR_TYPE.to_owned(), + error: StreamError { r#type, code, message, param, event_id: None }, + } + } +} + #[derive(Debug, Clone, Deserialize)] pub struct OpenAiOutsideError { /// Emitted when an error occurs. @@ -23,6 +56,17 @@ pub struct OpenAiInnerError { r#type: String, } +impl From for StreamErrorEvent { + fn from(error: OpenAiOutsideError) -> Self { + let OpenAiOutsideError { error: OpenAiInnerError { code, message, param, r#type } } = error; + StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: ERROR_TYPE.to_string(), + error: StreamError { r#type, code, message, param, event_id: None }, + } + } +} + /// An error that occurs during the streaming process. /// /// It directly comes from the OpenAI API and you can @@ -54,13 +98,15 @@ pub struct StreamError { } impl StreamErrorEvent { - const ERROR_TYPE: &str = "error"; - - pub async fn from_openai_error(error: OpenAIError) -> Result { + pub async fn from_openai_error(error: OpenAIError) -> Result + where + E: serde::de::DeserializeOwned, + Self: From, + { match error { OpenAIError::Reqwest(e) => Ok(StreamErrorEvent { event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), + r#type: ERROR_TYPE.to_string(), error: StreamError { r#type: "internal_reqwest_error".to_string(), code: Some("internal".to_string()), @@ -71,7 +117,7 @@ impl StreamErrorEvent { }), OpenAIError::ApiError(ApiError { message, r#type, param, code }) => { Ok(StreamErrorEvent { - r#type: Self::ERROR_TYPE.to_string(), + r#type: ERROR_TYPE.to_string(), event_id: Uuid::new_v4().to_string(), error: StreamError { r#type: r#type.unwrap_or_else(|| "unknown".to_string()), @@ -84,7 +130,7 @@ impl StreamErrorEvent { } OpenAIError::JSONDeserialize(error) => Ok(StreamErrorEvent { event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), + r#type: ERROR_TYPE.to_string(), error: StreamError { r#type: "json_deserialize_error".to_string(), code: Some("internal".to_string()), @@ -96,30 +142,16 @@ impl StreamErrorEvent { OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(), OpenAIError::StreamError(error) => match error { EventSourceError::InvalidStatusCode(_status_code, response) => { - let OpenAiOutsideError { - error: OpenAiInnerError { code, message, param, r#type }, - } = response.json().await?; - - Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), - error: StreamError { r#type, code, message, param, event_id: None }, - }) + let error = response.json::().await?; + Ok(StreamErrorEvent::from(error)) } EventSourceError::InvalidContentType(_header_value, response) => { - let OpenAiOutsideError { - error: OpenAiInnerError { code, message, param, r#type }, - } = response.json().await?; - - Ok(StreamErrorEvent { - event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), - error: StreamError { r#type, code, message, param, event_id: None }, - }) + let error = response.json::().await?; + Ok(StreamErrorEvent::from(error)) } EventSourceError::Utf8(error) => Ok(StreamErrorEvent { event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), + r#type: ERROR_TYPE.to_string(), error: StreamError { r#type: "invalid_utf8_error".to_string(), code: None, @@ -130,7 +162,7 @@ impl StreamErrorEvent { }), EventSourceError::Parser(error) => Ok(StreamErrorEvent { event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), + r#type: ERROR_TYPE.to_string(), error: StreamError { r#type: "parser_error".to_string(), code: None, @@ -141,7 +173,7 @@ impl StreamErrorEvent { }), EventSourceError::Transport(error) => Ok(StreamErrorEvent { event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), + r#type: ERROR_TYPE.to_string(), error: StreamError { r#type: "transport_error".to_string(), code: None, @@ -152,7 +184,7 @@ impl StreamErrorEvent { }), EventSourceError::InvalidLastEventId(message) => Ok(StreamErrorEvent { event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), + r#type: ERROR_TYPE.to_string(), error: StreamError { r#type: "invalid_last_event_id".to_string(), code: None, @@ -163,7 +195,7 @@ impl StreamErrorEvent { }), EventSourceError::StreamEnded => Ok(StreamErrorEvent { event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), + r#type: ERROR_TYPE.to_string(), error: StreamError { r#type: "stream_ended".to_string(), code: None, @@ -175,7 +207,7 @@ impl StreamErrorEvent { }, OpenAIError::InvalidArgument(message) => Ok(StreamErrorEvent { event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), + r#type: ERROR_TYPE.to_string(), error: StreamError { r#type: "invalid_argument".to_string(), code: None, @@ -191,7 +223,7 @@ impl StreamErrorEvent { let ResponseError { code, message, .. } = error; StreamErrorEvent { event_id: Uuid::new_v4().to_string(), - r#type: Self::ERROR_TYPE.to_string(), + r#type: ERROR_TYPE.to_string(), error: StreamError { r#type: "response_error".to_string(), code: Some(code.as_str().to_string()), @@ -201,4 +233,18 @@ impl StreamErrorEvent { }, } } + + pub fn from_reqwest_error(error: reqwest::Error) -> Self { + StreamErrorEvent { + event_id: Uuid::new_v4().to_string(), + r#type: ERROR_TYPE.to_string(), + error: StreamError { + r#type: "reqwest_error".to_string(), + code: None, + message: error.to_string(), + param: None, + event_id: None, + }, + } + } }