mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-06-10 01:51:36 +02:00
Clean up chat completions modules a bit
This commit is contained in:
parent
201a808fe2
commit
7d574433b6
@ -1,26 +1,21 @@
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::fmt::Write as _;
|
||||
use std::mem;
|
||||
use std::ops::ControlFlow;
|
||||
use std::sync::RwLock;
|
||||
use std::time::Duration;
|
||||
|
||||
use actix_web::web::{self, Data};
|
||||
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
||||
use actix_web_lab::sse::{self, Event, Sse};
|
||||
use actix_web_lab::sse::{Event, Sse};
|
||||
use async_openai::config::{Config, OpenAIConfig};
|
||||
use async_openai::error::{ApiError, OpenAIError};
|
||||
use async_openai::reqwest_eventsource::Error as EventSourceError;
|
||||
use async_openai::types::{
|
||||
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
|
||||
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
|
||||
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage,
|
||||
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
|
||||
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
|
||||
CreateChatCompletionStreamResponse, FinishReason, FunctionCall, FunctionCallStream,
|
||||
FunctionObjectArgs, Role,
|
||||
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
|
||||
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
|
||||
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
||||
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
|
||||
CreateChatCompletionRequest, CreateChatCompletionStreamResponse, FinishReason, FunctionCall,
|
||||
FunctionCallStream, FunctionObjectArgs,
|
||||
};
|
||||
use async_openai::Client;
|
||||
use bumpalo::Bump;
|
||||
@ -31,38 +26,30 @@ use meilisearch_types::error::{Code, ResponseError};
|
||||
use meilisearch_types::features::{
|
||||
ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings as DbChatSettings,
|
||||
};
|
||||
use meilisearch_types::heed::RoTxn;
|
||||
use meilisearch_types::keys::actions;
|
||||
use meilisearch_types::milli::index::ChatConfig;
|
||||
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
||||
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
||||
use meilisearch_types::milli::{
|
||||
all_obkv_to_json, obkv_to_json, DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap,
|
||||
MetadataBuilder, TimeBudget,
|
||||
};
|
||||
use meilisearch_types::milli::{all_obkv_to_json, obkv_to_json, TimeBudget};
|
||||
use meilisearch_types::{Document, Index};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::ChatsParam;
|
||||
use super::errors::StreamErrorEvent;
|
||||
use super::utils::format_documents;
|
||||
use super::{
|
||||
ChatsParam, MEILI_APPEND_CONVERSATION_MESSAGE_NAME, MEILI_SEARCH_IN_INDEX_FUNCTION_NAME,
|
||||
MEILI_SEARCH_PROGRESS_NAME, MEILI_SEARCH_SOURCES_NAME,
|
||||
};
|
||||
use crate::error::MeilisearchHttpError;
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
|
||||
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
|
||||
use crate::routes::chats::utils::SseEventSender;
|
||||
use crate::routes::indexes::search::search_kind;
|
||||
use crate::search::{add_search_rules, prepare_search, search_from_kind, SearchQuery};
|
||||
use crate::search_queue::SearchQueue;
|
||||
|
||||
/// The function name used for the tool call that reports search progress to the front-end.
const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
|
||||
/// The function name used for the tool call that asks the front-end to append
/// a message to the conversation.
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
|
||||
/// The function name used for the tool call that reports the search sources to the front-end.
const MEILI_SEARCH_SOURCES_NAME: &str = "_meiliSearchSources";
|
||||
/// The tool name that, when found in the request tools by `setup_search_tool`,
/// turns on `report_errors` and is removed from the tool list.
const MEILI_REPORT_ERRORS_NAME: &str = "_meiliReportErrors";
|
||||
/// The function name that identifies a tool call as the internal index search;
/// calls carrying this name are aggregated as `Call::Internal`.
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
||||
|
||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(web::resource("").route(web::post().to(chat)));
|
||||
}
|
||||
@ -140,7 +127,6 @@ fn setup_search_tool(
|
||||
let mut report_progress = false;
|
||||
let mut report_sources = false;
|
||||
let mut append_to_conversation = false;
|
||||
let mut report_errors = false;
|
||||
tools.retain(|tool| {
|
||||
match tool.function.name.as_str() {
|
||||
MEILI_SEARCH_PROGRESS_NAME => {
|
||||
@ -155,10 +141,6 @@ fn setup_search_tool(
|
||||
append_to_conversation = true;
|
||||
false
|
||||
}
|
||||
MEILI_REPORT_ERRORS_NAME => {
|
||||
report_errors = true;
|
||||
false
|
||||
}
|
||||
_ => true, // keep other tools
|
||||
}
|
||||
});
|
||||
@ -443,7 +425,7 @@ async fn streamed_chat(
|
||||
tracing::debug!("Conversation function support: {function_support:?}");
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||
let tx = SseEventSender(tx);
|
||||
let tx = SseEventSender::new(tx);
|
||||
let _join_handle = Handle::current().spawn(async move {
|
||||
let client = Client::with_config(config.clone());
|
||||
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||
@ -521,9 +503,7 @@ async fn run_conversation<C: Config>(
|
||||
}
|
||||
})
|
||||
.or_insert_with(|| {
|
||||
if name
|
||||
.as_ref()
|
||||
.map_or(false, |n| n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||
if name.as_deref() == Some(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||
{
|
||||
Call::Internal {
|
||||
id: id.as_ref().unwrap().clone(),
|
||||
@ -680,181 +660,6 @@ async fn handle_meili_tools(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct SseEventSender(Sender<Event>);
|
||||
|
||||
impl SseEventSender {
|
||||
/// Ask the front-end user to append this tool *call* to the conversation
|
||||
pub async fn append_tool_call_conversation_message(
|
||||
&self,
|
||||
resp: CreateChatCompletionStreamResponse,
|
||||
call_id: String,
|
||||
function_name: String,
|
||||
function_arguments: String,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[allow(deprecated)]
|
||||
let message =
|
||||
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
||||
content: None,
|
||||
refusal: None,
|
||||
name: None,
|
||||
audio: None,
|
||||
tool_calls: Some(vec![ChatCompletionMessageToolCall {
|
||||
id: call_id,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall { name: function_name, arguments: function_arguments },
|
||||
}]),
|
||||
function_call: None,
|
||||
});
|
||||
|
||||
self.append_conversation_message(resp, &message).await
|
||||
}
|
||||
|
||||
/// Ask the front-end user to append this tool to the conversation
|
||||
pub async fn append_conversation_message(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
message: &ChatCompletionRequestMessage,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
let call_text = serde_json::to_string(message).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
#[allow(deprecated)]
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn report_search_progress(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
call_id: &str,
|
||||
function_name: &str,
|
||||
function_arguments: &str,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Provides information about the current Meilisearch search operation.
|
||||
struct MeiliSearchProgress<'a> {
|
||||
/// The call ID to track the sources of the search.
|
||||
call_id: &'a str,
|
||||
/// The name of the function we are executing.
|
||||
function_name: &'a str,
|
||||
/// The arguments of the function we are executing, encoded in JSON.
|
||||
function_arguments: &'a str,
|
||||
}
|
||||
|
||||
let progress = MeiliSearchProgress { call_id, function_name, function_arguments };
|
||||
let call_text = serde_json::to_string(&progress).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
#[allow(deprecated)]
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn report_sources(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
call_id: &str,
|
||||
documents: &[Document],
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Provides sources of the search.
|
||||
struct MeiliSearchSources<'a> {
|
||||
/// The call ID to track the original search associated to those sources.
|
||||
call_id: &'a str,
|
||||
/// The documents associated with the search (call_id).
|
||||
/// Only the displayed attributes of the documents are returned.
|
||||
sources: &'a [Document],
|
||||
}
|
||||
|
||||
let sources = MeiliSearchSources { call_id, sources: documents };
|
||||
let call_text = serde_json::to_string(&sources).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
#[allow(deprecated)]
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn forward_response(
|
||||
&self,
|
||||
resp: &CreateChatCompletionStreamResponse,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
self.send_json(resp).await
|
||||
}
|
||||
|
||||
pub async fn send_error(&self, error: &StreamErrorEvent) -> Result<(), SendError<Event>> {
|
||||
self.send_json(error).await
|
||||
}
|
||||
|
||||
pub async fn stop(self) -> Result<(), SendError<Event>> {
|
||||
self.0.send(Event::Data(sse::Data::new("[DONE]"))).await
|
||||
}
|
||||
|
||||
async fn send_json<S: Serialize>(&self, data: &S) -> Result<(), SendError<Event>> {
|
||||
self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await
|
||||
}
|
||||
}
|
||||
|
||||
/// The structure used to aggregate the function calls to make.
|
||||
#[derive(Debug)]
|
||||
enum Call {
|
||||
@ -892,220 +697,3 @@ struct SearchInIndexParameters {
|
||||
/// The query parameter to use.
|
||||
q: Option<String>,
|
||||
}
|
||||
|
||||
fn format_documents<'t, 'doc>(
|
||||
rtxn: &RoTxn<'t>,
|
||||
index: &Index,
|
||||
doc_alloc: &'doc Bump,
|
||||
internal_docids: Vec<DocumentId>,
|
||||
) -> Result<Vec<&'doc str>, ResponseError> {
|
||||
let ChatConfig { prompt: PromptData { template, max_bytes }, .. } = index.chat_config(rtxn)?;
|
||||
|
||||
let prompt = Prompt::new(template, max_bytes).unwrap();
|
||||
let fid_map = index.fields_ids_map(rtxn)?;
|
||||
let metadata_builder = MetadataBuilder::from_index(index, rtxn)?;
|
||||
let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder);
|
||||
let global = RwLock::new(fid_map_with_meta);
|
||||
let gfid_map = RefCell::new(GlobalFieldsIdsMap::new(&global));
|
||||
|
||||
let external_ids: Vec<String> = index
|
||||
.external_id_of(rtxn, internal_docids.iter().copied())?
|
||||
.into_iter()
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let mut renders = Vec::new();
|
||||
for (docid, external_docid) in internal_docids.into_iter().zip(external_ids) {
|
||||
let document = match DocumentFromDb::new(docid, rtxn, index, &fid_map)? {
|
||||
Some(doc) => doc,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let text = prompt.render_document(&external_docid, document, &gfid_map, doc_alloc).unwrap();
|
||||
renders.push(text);
|
||||
}
|
||||
|
||||
Ok(renders)
|
||||
}
|
||||
|
||||
/// An error that occurs during the streaming process.
|
||||
///
|
||||
/// It directly comes from the OpenAI API and you can
|
||||
/// read more about error events on their website:
|
||||
/// <https://platform.openai.com/docs/api-reference/realtime-server-events/error>
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct StreamErrorEvent {
|
||||
/// The unique ID of the server event.
|
||||
event_id: String,
|
||||
/// The event type, must be error.
|
||||
r#type: String,
|
||||
/// Details of the error.
|
||||
error: StreamError,
|
||||
}
|
||||
|
||||
/// Details of the error.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct StreamError {
|
||||
/// The type of error (e.g., "invalid_request_error", "server_error").
|
||||
r#type: String,
|
||||
/// Error code, if any.
|
||||
code: Option<String>,
|
||||
/// A human-readable error message.
|
||||
message: String,
|
||||
/// Parameter related to the error, if any.
|
||||
param: Option<String>,
|
||||
/// The event_id of the client event that caused the error, if applicable.
|
||||
event_id: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamErrorEvent {
|
||||
pub async fn from_openai_error(error: OpenAIError) -> Result<Self, reqwest::Error> {
|
||||
let error_type = "error".to_string();
|
||||
match error {
|
||||
OpenAIError::Reqwest(e) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "internal_reqwest_error".to_string(),
|
||||
code: Some("internal".to_string()),
|
||||
message: e.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
OpenAIError::ApiError(ApiError { message, r#type, param, code }) => {
|
||||
Ok(StreamErrorEvent {
|
||||
r#type: error_type,
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
error: StreamError {
|
||||
r#type: r#type.unwrap_or_else(|| "unknown".to_string()),
|
||||
code,
|
||||
message,
|
||||
param,
|
||||
event_id: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
OpenAIError::JSONDeserialize(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "json_deserialize_error".to_string(),
|
||||
code: Some("internal".to_string()),
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(),
|
||||
OpenAIError::StreamError(error) => match error {
|
||||
EventSourceError::InvalidStatusCode(_status_code, response) => {
|
||||
let OpenAiOutsideError {
|
||||
error: OpenAiInnerError { code, message, param, r#type },
|
||||
} = response.json().await?;
|
||||
|
||||
Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError { r#type, code, message, param, event_id: None },
|
||||
})
|
||||
}
|
||||
EventSourceError::InvalidContentType(_header_value, response) => {
|
||||
let OpenAiOutsideError {
|
||||
error: OpenAiInnerError { code, message, param, r#type },
|
||||
} = response.json().await?;
|
||||
|
||||
Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError { r#type, code, message, param, event_id: None },
|
||||
})
|
||||
}
|
||||
EventSourceError::Utf8(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "invalid_utf8_error".to_string(),
|
||||
code: None,
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::Parser(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "parser_error".to_string(),
|
||||
code: None,
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::Transport(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "transport_error".to_string(),
|
||||
code: None,
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::InvalidLastEventId(message) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "invalid_last_event_id".to_string(),
|
||||
code: None,
|
||||
message,
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::StreamEnded => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "stream_ended".to_string(),
|
||||
code: None,
|
||||
message: "Stream ended".to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
},
|
||||
OpenAIError::InvalidArgument(message) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "invalid_argument".to_string(),
|
||||
code: None,
|
||||
message,
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OpenAiOutsideError {
|
||||
/// Emitted when an error occurs.
|
||||
error: OpenAiInnerError,
|
||||
}
|
||||
|
||||
/// Emitted when an error occurs.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OpenAiInnerError {
|
||||
/// The error code.
|
||||
code: Option<String>,
|
||||
/// The error message.
|
||||
message: String,
|
||||
/// The error parameter.
|
||||
param: Option<String>,
|
||||
/// The type of the event. Always `error`.
|
||||
r#type: String,
|
||||
}
|
||||
|
187
crates/meilisearch/src/routes/chats/errors.rs
Normal file
187
crates/meilisearch/src/routes/chats/errors.rs
Normal file
@ -0,0 +1,187 @@
|
||||
use async_openai::error::{ApiError, OpenAIError};
|
||||
use async_openai::reqwest_eventsource::Error as EventSourceError;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OpenAiOutsideError {
|
||||
/// Emitted when an error occurs.
|
||||
error: OpenAiInnerError,
|
||||
}
|
||||
|
||||
/// Emitted when an error occurs.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct OpenAiInnerError {
|
||||
/// The error code.
|
||||
code: Option<String>,
|
||||
/// The error message.
|
||||
message: String,
|
||||
/// The error parameter.
|
||||
param: Option<String>,
|
||||
/// The type of the event. Always `error`.
|
||||
r#type: String,
|
||||
}
|
||||
|
||||
/// An error that occurs during the streaming process.
|
||||
///
|
||||
/// It directly comes from the OpenAI API and you can
|
||||
/// read more about error events on their website:
|
||||
/// <https://platform.openai.com/docs/api-reference/realtime-server-events/error>
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct StreamErrorEvent {
|
||||
/// The unique ID of the server event.
|
||||
pub event_id: String,
|
||||
/// The event type, must be error.
|
||||
pub r#type: String,
|
||||
/// Details of the error.
|
||||
pub error: StreamError,
|
||||
}
|
||||
|
||||
/// Details of the error.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct StreamError {
|
||||
/// The type of error (e.g., "invalid_request_error", "server_error").
|
||||
pub r#type: String,
|
||||
/// Error code, if any.
|
||||
pub code: Option<String>,
|
||||
/// A human-readable error message.
|
||||
pub message: String,
|
||||
/// Parameter related to the error, if any.
|
||||
pub param: Option<String>,
|
||||
/// The event_id of the client event that caused the error, if applicable.
|
||||
pub event_id: Option<String>,
|
||||
}
|
||||
|
||||
impl StreamErrorEvent {
|
||||
pub async fn from_openai_error(error: OpenAIError) -> Result<Self, reqwest::Error> {
|
||||
let error_type = "error".to_string();
|
||||
match error {
|
||||
OpenAIError::Reqwest(e) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "internal_reqwest_error".to_string(),
|
||||
code: Some("internal".to_string()),
|
||||
message: e.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
OpenAIError::ApiError(ApiError { message, r#type, param, code }) => {
|
||||
Ok(StreamErrorEvent {
|
||||
r#type: error_type,
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
error: StreamError {
|
||||
r#type: r#type.unwrap_or_else(|| "unknown".to_string()),
|
||||
code,
|
||||
message,
|
||||
param,
|
||||
event_id: None,
|
||||
},
|
||||
})
|
||||
}
|
||||
OpenAIError::JSONDeserialize(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "json_deserialize_error".to_string(),
|
||||
code: Some("internal".to_string()),
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(),
|
||||
OpenAIError::StreamError(error) => match error {
|
||||
EventSourceError::InvalidStatusCode(_status_code, response) => {
|
||||
let OpenAiOutsideError {
|
||||
error: OpenAiInnerError { code, message, param, r#type },
|
||||
} = response.json().await?;
|
||||
|
||||
Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError { r#type, code, message, param, event_id: None },
|
||||
})
|
||||
}
|
||||
EventSourceError::InvalidContentType(_header_value, response) => {
|
||||
let OpenAiOutsideError {
|
||||
error: OpenAiInnerError { code, message, param, r#type },
|
||||
} = response.json().await?;
|
||||
|
||||
Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError { r#type, code, message, param, event_id: None },
|
||||
})
|
||||
}
|
||||
EventSourceError::Utf8(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "invalid_utf8_error".to_string(),
|
||||
code: None,
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::Parser(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "parser_error".to_string(),
|
||||
code: None,
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::Transport(error) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "transport_error".to_string(),
|
||||
code: None,
|
||||
message: error.to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::InvalidLastEventId(message) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "invalid_last_event_id".to_string(),
|
||||
code: None,
|
||||
message,
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
EventSourceError::StreamEnded => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "stream_ended".to_string(),
|
||||
code: None,
|
||||
message: "Stream ended".to_string(),
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
},
|
||||
OpenAIError::InvalidArgument(message) => Ok(StreamErrorEvent {
|
||||
event_id: Uuid::new_v4().to_string(),
|
||||
r#type: error_type,
|
||||
error: StreamError {
|
||||
r#type: "invalid_argument".to_string(),
|
||||
code: None,
|
||||
message,
|
||||
param: None,
|
||||
event_id: None,
|
||||
},
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
@ -1,30 +1,35 @@
|
||||
use actix_web::{
|
||||
web::{self, Data},
|
||||
HttpResponse,
|
||||
};
|
||||
use deserr::{actix_web::AwebQueryParameter, Deserr};
|
||||
use actix_web::web::{self, Data};
|
||||
use actix_web::HttpResponse;
|
||||
use deserr::actix_web::AwebQueryParameter;
|
||||
use deserr::Deserr;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_types::{
|
||||
deserr::{query_params::Param, DeserrQueryParamError},
|
||||
error::{
|
||||
deserr_codes::{InvalidIndexLimit, InvalidIndexOffset},
|
||||
ResponseError,
|
||||
},
|
||||
keys::actions,
|
||||
};
|
||||
use meilisearch_types::deserr::query_params::Param;
|
||||
use meilisearch_types::deserr::DeserrQueryParamError;
|
||||
use meilisearch_types::error::deserr_codes::{InvalidIndexLimit, InvalidIndexOffset};
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::keys::actions;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::debug;
|
||||
use utoipa::{IntoParams, ToSchema};
|
||||
|
||||
use crate::{
|
||||
extractors::authentication::{policies::ActionPolicy, GuardedData},
|
||||
routes::PAGINATION_DEFAULT_LIMIT,
|
||||
};
|
||||
|
||||
use super::Pagination;
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
use crate::extractors::authentication::GuardedData;
|
||||
use crate::routes::PAGINATION_DEFAULT_LIMIT;
|
||||
|
||||
pub mod chat_completions;
|
||||
mod errors;
|
||||
pub mod settings;
|
||||
mod utils;
|
||||
|
||||
/// The function name to report search progress.
|
||||
const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
|
||||
/// The function name to append a conversation message in the user conversation.
|
||||
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
|
||||
/// The function name to report sources to the frontend.
|
||||
const MEILI_SEARCH_SOURCES_NAME: &str = "_meiliSearchSources";
|
||||
/// The *internal* function name to provide to the LLM to search in indexes.
|
||||
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
||||
|
||||
#[derive(Deserialize)]
|
||||
pub struct ChatsParam {
|
||||
|
243
crates/meilisearch/src/routes/chats/utils.rs
Normal file
243
crates/meilisearch/src/routes/chats/utils.rs
Normal file
@ -0,0 +1,243 @@
|
||||
use std::cell::RefCell;
|
||||
use std::sync::RwLock;
|
||||
|
||||
use actix_web_lab::sse::{self, Event};
|
||||
use async_openai::types::{
|
||||
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||
ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage,
|
||||
ChatCompletionStreamResponseDelta, ChatCompletionToolType, CreateChatCompletionStreamResponse,
|
||||
FunctionCall, FunctionCallStream, Role,
|
||||
};
|
||||
use bumpalo::Bump;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::heed::RoTxn;
|
||||
use meilisearch_types::milli::index::ChatConfig;
|
||||
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
||||
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
||||
use meilisearch_types::milli::{
|
||||
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder,
|
||||
};
|
||||
use meilisearch_types::{Document, Index};
|
||||
use serde::Serialize;
|
||||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use super::errors::StreamErrorEvent;
|
||||
use super::MEILI_APPEND_CONVERSATION_MESSAGE_NAME;
|
||||
use crate::routes::chats::{MEILI_SEARCH_PROGRESS_NAME, MEILI_SEARCH_SOURCES_NAME};
|
||||
|
||||
/// Newtype around the channel sender used to push SSE [`Event`]s.
pub struct SseEventSender(Sender<Event>);
|
||||
|
||||
impl SseEventSender {
|
||||
pub fn new(sender: Sender<Event>) -> Self {
|
||||
Self(sender)
|
||||
}
|
||||
|
||||
/// Ask the front-end user to append this tool *call* to the conversation
|
||||
pub async fn append_tool_call_conversation_message(
|
||||
&self,
|
||||
resp: CreateChatCompletionStreamResponse,
|
||||
call_id: String,
|
||||
function_name: String,
|
||||
function_arguments: String,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[allow(deprecated)]
|
||||
let message =
|
||||
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
||||
content: None,
|
||||
refusal: None,
|
||||
name: None,
|
||||
audio: None,
|
||||
tool_calls: Some(vec![ChatCompletionMessageToolCall {
|
||||
id: call_id,
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: FunctionCall { name: function_name, arguments: function_arguments },
|
||||
}]),
|
||||
function_call: None,
|
||||
});
|
||||
|
||||
self.append_conversation_message(resp, &message).await
|
||||
}
|
||||
|
||||
/// Ask the front-end user to append this tool to the conversation
|
||||
pub async fn append_conversation_message(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
message: &ChatCompletionRequestMessage,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
let call_text = serde_json::to_string(message).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
#[allow(deprecated)]
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn report_search_progress(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
call_id: &str,
|
||||
function_name: &str,
|
||||
function_arguments: &str,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Provides information about the current Meilisearch search operation.
|
||||
struct MeiliSearchProgress<'a> {
|
||||
/// The call ID to track the sources of the search.
|
||||
call_id: &'a str,
|
||||
/// The name of the function we are executing.
|
||||
function_name: &'a str,
|
||||
/// The arguments of the function we are executing, encoded in JSON.
|
||||
function_arguments: &'a str,
|
||||
}
|
||||
|
||||
let progress = MeiliSearchProgress { call_id, function_name, function_arguments };
|
||||
let call_text = serde_json::to_string(&progress).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
#[allow(deprecated)]
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn report_sources(
|
||||
&self,
|
||||
mut resp: CreateChatCompletionStreamResponse,
|
||||
call_id: &str,
|
||||
documents: &[Document],
|
||||
) -> Result<(), SendError<Event>> {
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
/// Provides sources of the search.
|
||||
struct MeiliSearchSources<'a> {
|
||||
/// The call ID to track the original search associated to those sources.
|
||||
call_id: &'a str,
|
||||
/// The documents associated with the search (call_id).
|
||||
/// Only the displayed attributes of the documents are returned.
|
||||
sources: &'a [Document],
|
||||
}
|
||||
|
||||
let sources = MeiliSearchSources { call_id, sources: documents };
|
||||
let call_text = serde_json::to_string(&sources).unwrap();
|
||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||
index: 0,
|
||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||
r#type: Some(ChatCompletionToolType::Function),
|
||||
function: Some(FunctionCallStream {
|
||||
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
|
||||
arguments: Some(call_text),
|
||||
}),
|
||||
};
|
||||
|
||||
resp.choices[0] = ChatChoiceStream {
|
||||
index: 0,
|
||||
#[allow(deprecated)]
|
||||
delta: ChatCompletionStreamResponseDelta {
|
||||
content: None,
|
||||
function_call: None,
|
||||
tool_calls: Some(vec![tool_call]),
|
||||
role: Some(Role::Assistant),
|
||||
refusal: None,
|
||||
},
|
||||
finish_reason: None,
|
||||
logprobs: None,
|
||||
};
|
||||
|
||||
self.send_json(&resp).await
|
||||
}
|
||||
|
||||
pub async fn forward_response(
|
||||
&self,
|
||||
resp: &CreateChatCompletionStreamResponse,
|
||||
) -> Result<(), SendError<Event>> {
|
||||
self.send_json(resp).await
|
||||
}
|
||||
|
||||
pub async fn send_error(&self, error: &StreamErrorEvent) -> Result<(), SendError<Event>> {
|
||||
self.send_json(error).await
|
||||
}
|
||||
|
||||
pub async fn stop(self) -> Result<(), SendError<Event>> {
|
||||
self.0.send(Event::Data(sse::Data::new("[DONE]"))).await
|
||||
}
|
||||
|
||||
async fn send_json<S: Serialize>(&self, data: &S) -> Result<(), SendError<Event>> {
|
||||
self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await
|
||||
}
|
||||
}
|
||||
|
||||
/// Format documents based on the provided template and maximum bytes.
|
||||
///
|
||||
/// This formatting function is usually used to generate a summary of the documents for LLMs.
|
||||
pub fn format_documents<'t, 'doc>(
|
||||
rtxn: &RoTxn<'t>,
|
||||
index: &Index,
|
||||
doc_alloc: &'doc Bump,
|
||||
internal_docids: Vec<DocumentId>,
|
||||
) -> Result<Vec<&'doc str>, ResponseError> {
|
||||
let ChatConfig { prompt: PromptData { template, max_bytes }, .. } = index.chat_config(rtxn)?;
|
||||
|
||||
let prompt = Prompt::new(template, max_bytes).unwrap();
|
||||
let fid_map = index.fields_ids_map(rtxn)?;
|
||||
let metadata_builder = MetadataBuilder::from_index(index, rtxn)?;
|
||||
let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder);
|
||||
let global = RwLock::new(fid_map_with_meta);
|
||||
let gfid_map = RefCell::new(GlobalFieldsIdsMap::new(&global));
|
||||
|
||||
let external_ids: Vec<String> = index
|
||||
.external_id_of(rtxn, internal_docids.iter().copied())?
|
||||
.into_iter()
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
let mut renders = Vec::new();
|
||||
for (docid, external_docid) in internal_docids.into_iter().zip(external_ids) {
|
||||
let document = match DocumentFromDb::new(docid, rtxn, index, &fid_map)? {
|
||||
Some(doc) => doc,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let text = prompt.render_document(&external_docid, document, &gfid_map, doc_alloc).unwrap();
|
||||
renders.push(text);
|
||||
}
|
||||
|
||||
Ok(renders)
|
||||
}
|
Loading…
x
Reference in New Issue
Block a user