mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-06-10 18:11:41 +02:00
Clean up chat completions modules a bit
This commit is contained in:
parent
201a808fe2
commit
7d574433b6
@ -1,26 +1,21 @@
|
|||||||
use std::cell::RefCell;
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::fmt::Write as _;
|
use std::fmt::Write as _;
|
||||||
use std::mem;
|
use std::mem;
|
||||||
use std::ops::ControlFlow;
|
use std::ops::ControlFlow;
|
||||||
use std::sync::RwLock;
|
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use actix_web::web::{self, Data};
|
use actix_web::web::{self, Data};
|
||||||
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
||||||
use actix_web_lab::sse::{self, Event, Sse};
|
use actix_web_lab::sse::{Event, Sse};
|
||||||
use async_openai::config::{Config, OpenAIConfig};
|
use async_openai::config::{Config, OpenAIConfig};
|
||||||
use async_openai::error::{ApiError, OpenAIError};
|
|
||||||
use async_openai::reqwest_eventsource::Error as EventSourceError;
|
|
||||||
use async_openai::types::{
|
use async_openai::types::{
|
||||||
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||||
ChatCompletionRequestAssistantMessage, ChatCompletionRequestAssistantMessageArgs,
|
ChatCompletionRequestAssistantMessageArgs, ChatCompletionRequestMessage,
|
||||||
ChatCompletionRequestMessage, ChatCompletionRequestSystemMessage,
|
ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
|
||||||
ChatCompletionRequestSystemMessageContent, ChatCompletionRequestToolMessage,
|
ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
|
||||||
ChatCompletionRequestToolMessageContent, ChatCompletionStreamResponseDelta,
|
ChatCompletionStreamResponseDelta, ChatCompletionToolArgs, ChatCompletionToolType,
|
||||||
ChatCompletionToolArgs, ChatCompletionToolType, CreateChatCompletionRequest,
|
CreateChatCompletionRequest, CreateChatCompletionStreamResponse, FinishReason, FunctionCall,
|
||||||
CreateChatCompletionStreamResponse, FinishReason, FunctionCall, FunctionCallStream,
|
FunctionCallStream, FunctionObjectArgs,
|
||||||
FunctionObjectArgs, Role,
|
|
||||||
};
|
};
|
||||||
use async_openai::Client;
|
use async_openai::Client;
|
||||||
use bumpalo::Bump;
|
use bumpalo::Bump;
|
||||||
@ -31,38 +26,30 @@ use meilisearch_types::error::{Code, ResponseError};
|
|||||||
use meilisearch_types::features::{
|
use meilisearch_types::features::{
|
||||||
ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings as DbChatSettings,
|
ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings as DbChatSettings,
|
||||||
};
|
};
|
||||||
use meilisearch_types::heed::RoTxn;
|
|
||||||
use meilisearch_types::keys::actions;
|
use meilisearch_types::keys::actions;
|
||||||
use meilisearch_types::milli::index::ChatConfig;
|
use meilisearch_types::milli::index::ChatConfig;
|
||||||
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
use meilisearch_types::milli::{all_obkv_to_json, obkv_to_json, TimeBudget};
|
||||||
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
|
||||||
use meilisearch_types::milli::{
|
|
||||||
all_obkv_to_json, obkv_to_json, DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap,
|
|
||||||
MetadataBuilder, TimeBudget,
|
|
||||||
};
|
|
||||||
use meilisearch_types::{Document, Index};
|
use meilisearch_types::{Document, Index};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::Deserialize;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use tokio::runtime::Handle;
|
use tokio::runtime::Handle;
|
||||||
use tokio::sync::mpsc::error::SendError;
|
use tokio::sync::mpsc::error::SendError;
|
||||||
use tokio::sync::mpsc::Sender;
|
|
||||||
use uuid::Uuid;
|
|
||||||
|
|
||||||
use super::ChatsParam;
|
use super::errors::StreamErrorEvent;
|
||||||
|
use super::utils::format_documents;
|
||||||
|
use super::{
|
||||||
|
ChatsParam, MEILI_APPEND_CONVERSATION_MESSAGE_NAME, MEILI_SEARCH_IN_INDEX_FUNCTION_NAME,
|
||||||
|
MEILI_SEARCH_PROGRESS_NAME, MEILI_SEARCH_SOURCES_NAME,
|
||||||
|
};
|
||||||
use crate::error::MeilisearchHttpError;
|
use crate::error::MeilisearchHttpError;
|
||||||
use crate::extractors::authentication::policies::ActionPolicy;
|
use crate::extractors::authentication::policies::ActionPolicy;
|
||||||
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
|
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
|
||||||
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
|
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
|
||||||
|
use crate::routes::chats::utils::SseEventSender;
|
||||||
use crate::routes::indexes::search::search_kind;
|
use crate::routes::indexes::search::search_kind;
|
||||||
use crate::search::{add_search_rules, prepare_search, search_from_kind, SearchQuery};
|
use crate::search::{add_search_rules, prepare_search, search_from_kind, SearchQuery};
|
||||||
use crate::search_queue::SearchQueue;
|
use crate::search_queue::SearchQueue;
|
||||||
|
|
||||||
const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
|
|
||||||
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
|
|
||||||
const MEILI_SEARCH_SOURCES_NAME: &str = "_meiliSearchSources";
|
|
||||||
const MEILI_REPORT_ERRORS_NAME: &str = "_meiliReportErrors";
|
|
||||||
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
|
||||||
|
|
||||||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||||
cfg.service(web::resource("").route(web::post().to(chat)));
|
cfg.service(web::resource("").route(web::post().to(chat)));
|
||||||
}
|
}
|
||||||
@ -140,7 +127,6 @@ fn setup_search_tool(
|
|||||||
let mut report_progress = false;
|
let mut report_progress = false;
|
||||||
let mut report_sources = false;
|
let mut report_sources = false;
|
||||||
let mut append_to_conversation = false;
|
let mut append_to_conversation = false;
|
||||||
let mut report_errors = false;
|
|
||||||
tools.retain(|tool| {
|
tools.retain(|tool| {
|
||||||
match tool.function.name.as_str() {
|
match tool.function.name.as_str() {
|
||||||
MEILI_SEARCH_PROGRESS_NAME => {
|
MEILI_SEARCH_PROGRESS_NAME => {
|
||||||
@ -155,10 +141,6 @@ fn setup_search_tool(
|
|||||||
append_to_conversation = true;
|
append_to_conversation = true;
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
MEILI_REPORT_ERRORS_NAME => {
|
|
||||||
report_errors = true;
|
|
||||||
false
|
|
||||||
}
|
|
||||||
_ => true, // keep other tools
|
_ => true, // keep other tools
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
@ -443,7 +425,7 @@ async fn streamed_chat(
|
|||||||
tracing::debug!("Conversation function support: {function_support:?}");
|
tracing::debug!("Conversation function support: {function_support:?}");
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||||
let tx = SseEventSender(tx);
|
let tx = SseEventSender::new(tx);
|
||||||
let _join_handle = Handle::current().spawn(async move {
|
let _join_handle = Handle::current().spawn(async move {
|
||||||
let client = Client::with_config(config.clone());
|
let client = Client::with_config(config.clone());
|
||||||
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
let mut global_tool_calls = HashMap::<u32, Call>::new();
|
||||||
@ -521,9 +503,7 @@ async fn run_conversation<C: Config>(
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
.or_insert_with(|| {
|
.or_insert_with(|| {
|
||||||
if name
|
if name.as_deref() == Some(MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
||||||
.as_ref()
|
|
||||||
.map_or(false, |n| n == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME)
|
|
||||||
{
|
{
|
||||||
Call::Internal {
|
Call::Internal {
|
||||||
id: id.as_ref().unwrap().clone(),
|
id: id.as_ref().unwrap().clone(),
|
||||||
@ -680,181 +660,6 @@ async fn handle_meili_tools(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct SseEventSender(Sender<Event>);
|
|
||||||
|
|
||||||
impl SseEventSender {
|
|
||||||
/// Ask the front-end user to append this tool *call* to the conversation
|
|
||||||
pub async fn append_tool_call_conversation_message(
|
|
||||||
&self,
|
|
||||||
resp: CreateChatCompletionStreamResponse,
|
|
||||||
call_id: String,
|
|
||||||
function_name: String,
|
|
||||||
function_arguments: String,
|
|
||||||
) -> Result<(), SendError<Event>> {
|
|
||||||
#[allow(deprecated)]
|
|
||||||
let message =
|
|
||||||
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
|
||||||
content: None,
|
|
||||||
refusal: None,
|
|
||||||
name: None,
|
|
||||||
audio: None,
|
|
||||||
tool_calls: Some(vec![ChatCompletionMessageToolCall {
|
|
||||||
id: call_id,
|
|
||||||
r#type: Some(ChatCompletionToolType::Function),
|
|
||||||
function: FunctionCall { name: function_name, arguments: function_arguments },
|
|
||||||
}]),
|
|
||||||
function_call: None,
|
|
||||||
});
|
|
||||||
|
|
||||||
self.append_conversation_message(resp, &message).await
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Ask the front-end user to append this tool to the conversation
|
|
||||||
pub async fn append_conversation_message(
|
|
||||||
&self,
|
|
||||||
mut resp: CreateChatCompletionStreamResponse,
|
|
||||||
message: &ChatCompletionRequestMessage,
|
|
||||||
) -> Result<(), SendError<Event>> {
|
|
||||||
let call_text = serde_json::to_string(message).unwrap();
|
|
||||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
|
||||||
index: 0,
|
|
||||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
|
||||||
r#type: Some(ChatCompletionToolType::Function),
|
|
||||||
function: Some(FunctionCallStream {
|
|
||||||
name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()),
|
|
||||||
arguments: Some(call_text),
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
resp.choices[0] = ChatChoiceStream {
|
|
||||||
index: 0,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
delta: ChatCompletionStreamResponseDelta {
|
|
||||||
content: None,
|
|
||||||
function_call: None,
|
|
||||||
tool_calls: Some(vec![tool_call]),
|
|
||||||
role: Some(Role::Assistant),
|
|
||||||
refusal: None,
|
|
||||||
},
|
|
||||||
finish_reason: None,
|
|
||||||
logprobs: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
self.send_json(&resp).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn report_search_progress(
|
|
||||||
&self,
|
|
||||||
mut resp: CreateChatCompletionStreamResponse,
|
|
||||||
call_id: &str,
|
|
||||||
function_name: &str,
|
|
||||||
function_arguments: &str,
|
|
||||||
) -> Result<(), SendError<Event>> {
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
|
||||||
/// Provides information about the current Meilisearch search operation.
|
|
||||||
struct MeiliSearchProgress<'a> {
|
|
||||||
/// The call ID to track the sources of the search.
|
|
||||||
call_id: &'a str,
|
|
||||||
/// The name of the function we are executing.
|
|
||||||
function_name: &'a str,
|
|
||||||
/// The arguments of the function we are executing, encoded in JSON.
|
|
||||||
function_arguments: &'a str,
|
|
||||||
}
|
|
||||||
|
|
||||||
let progress = MeiliSearchProgress { call_id, function_name, function_arguments };
|
|
||||||
let call_text = serde_json::to_string(&progress).unwrap();
|
|
||||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
|
||||||
index: 0,
|
|
||||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
|
||||||
r#type: Some(ChatCompletionToolType::Function),
|
|
||||||
function: Some(FunctionCallStream {
|
|
||||||
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
|
||||||
arguments: Some(call_text),
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
resp.choices[0] = ChatChoiceStream {
|
|
||||||
index: 0,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
delta: ChatCompletionStreamResponseDelta {
|
|
||||||
content: None,
|
|
||||||
function_call: None,
|
|
||||||
tool_calls: Some(vec![tool_call]),
|
|
||||||
role: Some(Role::Assistant),
|
|
||||||
refusal: None,
|
|
||||||
},
|
|
||||||
finish_reason: None,
|
|
||||||
logprobs: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
self.send_json(&resp).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn report_sources(
|
|
||||||
&self,
|
|
||||||
mut resp: CreateChatCompletionStreamResponse,
|
|
||||||
call_id: &str,
|
|
||||||
documents: &[Document],
|
|
||||||
) -> Result<(), SendError<Event>> {
|
|
||||||
#[derive(Debug, Clone, Serialize)]
|
|
||||||
/// Provides sources of the search.
|
|
||||||
struct MeiliSearchSources<'a> {
|
|
||||||
/// The call ID to track the original search associated to those sources.
|
|
||||||
call_id: &'a str,
|
|
||||||
/// The documents associated with the search (call_id).
|
|
||||||
/// Only the displayed attributes of the documents are returned.
|
|
||||||
sources: &'a [Document],
|
|
||||||
}
|
|
||||||
|
|
||||||
let sources = MeiliSearchSources { call_id, sources: documents };
|
|
||||||
let call_text = serde_json::to_string(&sources).unwrap();
|
|
||||||
let tool_call = ChatCompletionMessageToolCallChunk {
|
|
||||||
index: 0,
|
|
||||||
id: Some(uuid::Uuid::new_v4().to_string()),
|
|
||||||
r#type: Some(ChatCompletionToolType::Function),
|
|
||||||
function: Some(FunctionCallStream {
|
|
||||||
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
|
|
||||||
arguments: Some(call_text),
|
|
||||||
}),
|
|
||||||
};
|
|
||||||
|
|
||||||
resp.choices[0] = ChatChoiceStream {
|
|
||||||
index: 0,
|
|
||||||
#[allow(deprecated)]
|
|
||||||
delta: ChatCompletionStreamResponseDelta {
|
|
||||||
content: None,
|
|
||||||
function_call: None,
|
|
||||||
tool_calls: Some(vec![tool_call]),
|
|
||||||
role: Some(Role::Assistant),
|
|
||||||
refusal: None,
|
|
||||||
},
|
|
||||||
finish_reason: None,
|
|
||||||
logprobs: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
self.send_json(&resp).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn forward_response(
|
|
||||||
&self,
|
|
||||||
resp: &CreateChatCompletionStreamResponse,
|
|
||||||
) -> Result<(), SendError<Event>> {
|
|
||||||
self.send_json(resp).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn send_error(&self, error: &StreamErrorEvent) -> Result<(), SendError<Event>> {
|
|
||||||
self.send_json(error).await
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn stop(self) -> Result<(), SendError<Event>> {
|
|
||||||
self.0.send(Event::Data(sse::Data::new("[DONE]"))).await
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn send_json<S: Serialize>(&self, data: &S) -> Result<(), SendError<Event>> {
|
|
||||||
self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// The structure used to aggregate the function calls to make.
|
/// The structure used to aggregate the function calls to make.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum Call {
|
enum Call {
|
||||||
@ -892,220 +697,3 @@ struct SearchInIndexParameters {
|
|||||||
/// The query parameter to use.
|
/// The query parameter to use.
|
||||||
q: Option<String>,
|
q: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn format_documents<'t, 'doc>(
|
|
||||||
rtxn: &RoTxn<'t>,
|
|
||||||
index: &Index,
|
|
||||||
doc_alloc: &'doc Bump,
|
|
||||||
internal_docids: Vec<DocumentId>,
|
|
||||||
) -> Result<Vec<&'doc str>, ResponseError> {
|
|
||||||
let ChatConfig { prompt: PromptData { template, max_bytes }, .. } = index.chat_config(rtxn)?;
|
|
||||||
|
|
||||||
let prompt = Prompt::new(template, max_bytes).unwrap();
|
|
||||||
let fid_map = index.fields_ids_map(rtxn)?;
|
|
||||||
let metadata_builder = MetadataBuilder::from_index(index, rtxn)?;
|
|
||||||
let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder);
|
|
||||||
let global = RwLock::new(fid_map_with_meta);
|
|
||||||
let gfid_map = RefCell::new(GlobalFieldsIdsMap::new(&global));
|
|
||||||
|
|
||||||
let external_ids: Vec<String> = index
|
|
||||||
.external_id_of(rtxn, internal_docids.iter().copied())?
|
|
||||||
.into_iter()
|
|
||||||
.collect::<Result<_, _>>()?;
|
|
||||||
|
|
||||||
let mut renders = Vec::new();
|
|
||||||
for (docid, external_docid) in internal_docids.into_iter().zip(external_ids) {
|
|
||||||
let document = match DocumentFromDb::new(docid, rtxn, index, &fid_map)? {
|
|
||||||
Some(doc) => doc,
|
|
||||||
None => continue,
|
|
||||||
};
|
|
||||||
|
|
||||||
let text = prompt.render_document(&external_docid, document, &gfid_map, doc_alloc).unwrap();
|
|
||||||
renders.push(text);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(renders)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// An error that occurs during the streaming process.
|
|
||||||
///
|
|
||||||
/// It directly comes from the OpenAI API and you can
|
|
||||||
/// read more about error events on their website:
|
|
||||||
/// <https://platform.openai.com/docs/api-reference/realtime-server-events/error>
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct StreamErrorEvent {
|
|
||||||
/// The unique ID of the server event.
|
|
||||||
event_id: String,
|
|
||||||
/// The event type, must be error.
|
|
||||||
r#type: String,
|
|
||||||
/// Details of the error.
|
|
||||||
error: StreamError,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Details of the error.
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
|
||||||
pub struct StreamError {
|
|
||||||
/// The type of error (e.g., "invalid_request_error", "server_error").
|
|
||||||
r#type: String,
|
|
||||||
/// Error code, if any.
|
|
||||||
code: Option<String>,
|
|
||||||
/// A human-readable error message.
|
|
||||||
message: String,
|
|
||||||
/// Parameter related to the error, if any.
|
|
||||||
param: Option<String>,
|
|
||||||
/// The event_id of the client event that caused the error, if applicable.
|
|
||||||
event_id: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl StreamErrorEvent {
|
|
||||||
pub async fn from_openai_error(error: OpenAIError) -> Result<Self, reqwest::Error> {
|
|
||||||
let error_type = "error".to_string();
|
|
||||||
match error {
|
|
||||||
OpenAIError::Reqwest(e) => Ok(StreamErrorEvent {
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
r#type: error_type,
|
|
||||||
error: StreamError {
|
|
||||||
r#type: "internal_reqwest_error".to_string(),
|
|
||||||
code: Some("internal".to_string()),
|
|
||||||
message: e.to_string(),
|
|
||||||
param: None,
|
|
||||||
event_id: None,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
OpenAIError::ApiError(ApiError { message, r#type, param, code }) => {
|
|
||||||
Ok(StreamErrorEvent {
|
|
||||||
r#type: error_type,
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
error: StreamError {
|
|
||||||
r#type: r#type.unwrap_or_else(|| "unknown".to_string()),
|
|
||||||
code,
|
|
||||||
message,
|
|
||||||
param,
|
|
||||||
event_id: None,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
OpenAIError::JSONDeserialize(error) => Ok(StreamErrorEvent {
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
r#type: error_type,
|
|
||||||
error: StreamError {
|
|
||||||
r#type: "json_deserialize_error".to_string(),
|
|
||||||
code: Some("internal".to_string()),
|
|
||||||
message: error.to_string(),
|
|
||||||
param: None,
|
|
||||||
event_id: None,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(),
|
|
||||||
OpenAIError::StreamError(error) => match error {
|
|
||||||
EventSourceError::InvalidStatusCode(_status_code, response) => {
|
|
||||||
let OpenAiOutsideError {
|
|
||||||
error: OpenAiInnerError { code, message, param, r#type },
|
|
||||||
} = response.json().await?;
|
|
||||||
|
|
||||||
Ok(StreamErrorEvent {
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
r#type: error_type,
|
|
||||||
error: StreamError { r#type, code, message, param, event_id: None },
|
|
||||||
})
|
|
||||||
}
|
|
||||||
EventSourceError::InvalidContentType(_header_value, response) => {
|
|
||||||
let OpenAiOutsideError {
|
|
||||||
error: OpenAiInnerError { code, message, param, r#type },
|
|
||||||
} = response.json().await?;
|
|
||||||
|
|
||||||
Ok(StreamErrorEvent {
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
r#type: error_type,
|
|
||||||
error: StreamError { r#type, code, message, param, event_id: None },
|
|
||||||
})
|
|
||||||
}
|
|
||||||
EventSourceError::Utf8(error) => Ok(StreamErrorEvent {
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
r#type: error_type,
|
|
||||||
error: StreamError {
|
|
||||||
r#type: "invalid_utf8_error".to_string(),
|
|
||||||
code: None,
|
|
||||||
message: error.to_string(),
|
|
||||||
param: None,
|
|
||||||
event_id: None,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
EventSourceError::Parser(error) => Ok(StreamErrorEvent {
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
r#type: error_type,
|
|
||||||
error: StreamError {
|
|
||||||
r#type: "parser_error".to_string(),
|
|
||||||
code: None,
|
|
||||||
message: error.to_string(),
|
|
||||||
param: None,
|
|
||||||
event_id: None,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
EventSourceError::Transport(error) => Ok(StreamErrorEvent {
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
r#type: error_type,
|
|
||||||
error: StreamError {
|
|
||||||
r#type: "transport_error".to_string(),
|
|
||||||
code: None,
|
|
||||||
message: error.to_string(),
|
|
||||||
param: None,
|
|
||||||
event_id: None,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
EventSourceError::InvalidLastEventId(message) => Ok(StreamErrorEvent {
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
r#type: error_type,
|
|
||||||
error: StreamError {
|
|
||||||
r#type: "invalid_last_event_id".to_string(),
|
|
||||||
code: None,
|
|
||||||
message,
|
|
||||||
param: None,
|
|
||||||
event_id: None,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
EventSourceError::StreamEnded => Ok(StreamErrorEvent {
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
r#type: error_type,
|
|
||||||
error: StreamError {
|
|
||||||
r#type: "stream_ended".to_string(),
|
|
||||||
code: None,
|
|
||||||
message: "Stream ended".to_string(),
|
|
||||||
param: None,
|
|
||||||
event_id: None,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
},
|
|
||||||
OpenAIError::InvalidArgument(message) => Ok(StreamErrorEvent {
|
|
||||||
event_id: Uuid::new_v4().to_string(),
|
|
||||||
r#type: error_type,
|
|
||||||
error: StreamError {
|
|
||||||
r#type: "invalid_argument".to_string(),
|
|
||||||
code: None,
|
|
||||||
message,
|
|
||||||
param: None,
|
|
||||||
event_id: None,
|
|
||||||
},
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct OpenAiOutsideError {
|
|
||||||
/// Emitted when an error occurs.
|
|
||||||
error: OpenAiInnerError,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Emitted when an error occurs.
|
|
||||||
#[derive(Debug, Clone, Deserialize)]
|
|
||||||
pub struct OpenAiInnerError {
|
|
||||||
/// The error code.
|
|
||||||
code: Option<String>,
|
|
||||||
/// The error message.
|
|
||||||
message: String,
|
|
||||||
/// The error parameter.
|
|
||||||
param: Option<String>,
|
|
||||||
/// The type of the event. Always `error`.
|
|
||||||
r#type: String,
|
|
||||||
}
|
|
||||||
|
187
crates/meilisearch/src/routes/chats/errors.rs
Normal file
187
crates/meilisearch/src/routes/chats/errors.rs
Normal file
@ -0,0 +1,187 @@
|
|||||||
|
use async_openai::error::{ApiError, OpenAIError};
|
||||||
|
use async_openai::reqwest_eventsource::Error as EventSourceError;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct OpenAiOutsideError {
|
||||||
|
/// Emitted when an error occurs.
|
||||||
|
error: OpenAiInnerError,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Emitted when an error occurs.
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct OpenAiInnerError {
|
||||||
|
/// The error code.
|
||||||
|
code: Option<String>,
|
||||||
|
/// The error message.
|
||||||
|
message: String,
|
||||||
|
/// The error parameter.
|
||||||
|
param: Option<String>,
|
||||||
|
/// The type of the event. Always `error`.
|
||||||
|
r#type: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An error that occurs during the streaming process.
|
||||||
|
///
|
||||||
|
/// It directly comes from the OpenAI API and you can
|
||||||
|
/// read more about error events on their website:
|
||||||
|
/// <https://platform.openai.com/docs/api-reference/realtime-server-events/error>
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct StreamErrorEvent {
|
||||||
|
/// The unique ID of the server event.
|
||||||
|
pub event_id: String,
|
||||||
|
/// The event type, must be error.
|
||||||
|
pub r#type: String,
|
||||||
|
/// Details of the error.
|
||||||
|
pub error: StreamError,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Details of the error.
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct StreamError {
|
||||||
|
/// The type of error (e.g., "invalid_request_error", "server_error").
|
||||||
|
pub r#type: String,
|
||||||
|
/// Error code, if any.
|
||||||
|
pub code: Option<String>,
|
||||||
|
/// A human-readable error message.
|
||||||
|
pub message: String,
|
||||||
|
/// Parameter related to the error, if any.
|
||||||
|
pub param: Option<String>,
|
||||||
|
/// The event_id of the client event that caused the error, if applicable.
|
||||||
|
pub event_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamErrorEvent {
|
||||||
|
pub async fn from_openai_error(error: OpenAIError) -> Result<Self, reqwest::Error> {
|
||||||
|
let error_type = "error".to_string();
|
||||||
|
match error {
|
||||||
|
OpenAIError::Reqwest(e) => Ok(StreamErrorEvent {
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
r#type: error_type,
|
||||||
|
error: StreamError {
|
||||||
|
r#type: "internal_reqwest_error".to_string(),
|
||||||
|
code: Some("internal".to_string()),
|
||||||
|
message: e.to_string(),
|
||||||
|
param: None,
|
||||||
|
event_id: None,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
OpenAIError::ApiError(ApiError { message, r#type, param, code }) => {
|
||||||
|
Ok(StreamErrorEvent {
|
||||||
|
r#type: error_type,
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
error: StreamError {
|
||||||
|
r#type: r#type.unwrap_or_else(|| "unknown".to_string()),
|
||||||
|
code,
|
||||||
|
message,
|
||||||
|
param,
|
||||||
|
event_id: None,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
OpenAIError::JSONDeserialize(error) => Ok(StreamErrorEvent {
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
r#type: error_type,
|
||||||
|
error: StreamError {
|
||||||
|
r#type: "json_deserialize_error".to_string(),
|
||||||
|
code: Some("internal".to_string()),
|
||||||
|
message: error.to_string(),
|
||||||
|
param: None,
|
||||||
|
event_id: None,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
OpenAIError::FileSaveError(_) | OpenAIError::FileReadError(_) => unreachable!(),
|
||||||
|
OpenAIError::StreamError(error) => match error {
|
||||||
|
EventSourceError::InvalidStatusCode(_status_code, response) => {
|
||||||
|
let OpenAiOutsideError {
|
||||||
|
error: OpenAiInnerError { code, message, param, r#type },
|
||||||
|
} = response.json().await?;
|
||||||
|
|
||||||
|
Ok(StreamErrorEvent {
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
r#type: error_type,
|
||||||
|
error: StreamError { r#type, code, message, param, event_id: None },
|
||||||
|
})
|
||||||
|
}
|
||||||
|
EventSourceError::InvalidContentType(_header_value, response) => {
|
||||||
|
let OpenAiOutsideError {
|
||||||
|
error: OpenAiInnerError { code, message, param, r#type },
|
||||||
|
} = response.json().await?;
|
||||||
|
|
||||||
|
Ok(StreamErrorEvent {
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
r#type: error_type,
|
||||||
|
error: StreamError { r#type, code, message, param, event_id: None },
|
||||||
|
})
|
||||||
|
}
|
||||||
|
EventSourceError::Utf8(error) => Ok(StreamErrorEvent {
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
r#type: error_type,
|
||||||
|
error: StreamError {
|
||||||
|
r#type: "invalid_utf8_error".to_string(),
|
||||||
|
code: None,
|
||||||
|
message: error.to_string(),
|
||||||
|
param: None,
|
||||||
|
event_id: None,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
EventSourceError::Parser(error) => Ok(StreamErrorEvent {
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
r#type: error_type,
|
||||||
|
error: StreamError {
|
||||||
|
r#type: "parser_error".to_string(),
|
||||||
|
code: None,
|
||||||
|
message: error.to_string(),
|
||||||
|
param: None,
|
||||||
|
event_id: None,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
EventSourceError::Transport(error) => Ok(StreamErrorEvent {
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
r#type: error_type,
|
||||||
|
error: StreamError {
|
||||||
|
r#type: "transport_error".to_string(),
|
||||||
|
code: None,
|
||||||
|
message: error.to_string(),
|
||||||
|
param: None,
|
||||||
|
event_id: None,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
EventSourceError::InvalidLastEventId(message) => Ok(StreamErrorEvent {
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
r#type: error_type,
|
||||||
|
error: StreamError {
|
||||||
|
r#type: "invalid_last_event_id".to_string(),
|
||||||
|
code: None,
|
||||||
|
message,
|
||||||
|
param: None,
|
||||||
|
event_id: None,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
EventSourceError::StreamEnded => Ok(StreamErrorEvent {
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
r#type: error_type,
|
||||||
|
error: StreamError {
|
||||||
|
r#type: "stream_ended".to_string(),
|
||||||
|
code: None,
|
||||||
|
message: "Stream ended".to_string(),
|
||||||
|
param: None,
|
||||||
|
event_id: None,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
OpenAIError::InvalidArgument(message) => Ok(StreamErrorEvent {
|
||||||
|
event_id: Uuid::new_v4().to_string(),
|
||||||
|
r#type: error_type,
|
||||||
|
error: StreamError {
|
||||||
|
r#type: "invalid_argument".to_string(),
|
||||||
|
code: None,
|
||||||
|
message,
|
||||||
|
param: None,
|
||||||
|
event_id: None,
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -1,30 +1,35 @@
|
|||||||
use actix_web::{
|
use actix_web::web::{self, Data};
|
||||||
web::{self, Data},
|
use actix_web::HttpResponse;
|
||||||
HttpResponse,
|
use deserr::actix_web::AwebQueryParameter;
|
||||||
};
|
use deserr::Deserr;
|
||||||
use deserr::{actix_web::AwebQueryParameter, Deserr};
|
|
||||||
use index_scheduler::IndexScheduler;
|
use index_scheduler::IndexScheduler;
|
||||||
use meilisearch_types::{
|
use meilisearch_types::deserr::query_params::Param;
|
||||||
deserr::{query_params::Param, DeserrQueryParamError},
|
use meilisearch_types::deserr::DeserrQueryParamError;
|
||||||
error::{
|
use meilisearch_types::error::deserr_codes::{InvalidIndexLimit, InvalidIndexOffset};
|
||||||
deserr_codes::{InvalidIndexLimit, InvalidIndexOffset},
|
use meilisearch_types::error::ResponseError;
|
||||||
ResponseError,
|
use meilisearch_types::keys::actions;
|
||||||
},
|
|
||||||
keys::actions,
|
|
||||||
};
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
use utoipa::{IntoParams, ToSchema};
|
use utoipa::{IntoParams, ToSchema};
|
||||||
|
|
||||||
use crate::{
|
|
||||||
extractors::authentication::{policies::ActionPolicy, GuardedData},
|
|
||||||
routes::PAGINATION_DEFAULT_LIMIT,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::Pagination;
|
use super::Pagination;
|
||||||
|
use crate::extractors::authentication::policies::ActionPolicy;
|
||||||
|
use crate::extractors::authentication::GuardedData;
|
||||||
|
use crate::routes::PAGINATION_DEFAULT_LIMIT;
|
||||||
|
|
||||||
pub mod chat_completions;
|
pub mod chat_completions;
|
||||||
|
mod errors;
|
||||||
pub mod settings;
|
pub mod settings;
|
||||||
|
mod utils;
|
||||||
|
|
||||||
|
/// The function name to report search progress.
|
||||||
|
const MEILI_SEARCH_PROGRESS_NAME: &str = "_meiliSearchProgress";
|
||||||
|
/// The function name used to ask the front-end to append a message to the user conversation.
|
||||||
|
const MEILI_APPEND_CONVERSATION_MESSAGE_NAME: &str = "_meiliAppendConversationMessage";
|
||||||
|
/// The function name to report sources to the frontend.
|
||||||
|
const MEILI_SEARCH_SOURCES_NAME: &str = "_meiliSearchSources";
|
||||||
|
/// The *internal* function name to provide to the LLM to search in indexes.
|
||||||
|
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
pub struct ChatsParam {
|
pub struct ChatsParam {
|
||||||
|
243
crates/meilisearch/src/routes/chats/utils.rs
Normal file
243
crates/meilisearch/src/routes/chats/utils.rs
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
use std::cell::RefCell;
|
||||||
|
use std::sync::RwLock;
|
||||||
|
|
||||||
|
use actix_web_lab::sse::{self, Event};
|
||||||
|
use async_openai::types::{
|
||||||
|
ChatChoiceStream, ChatCompletionMessageToolCall, ChatCompletionMessageToolCallChunk,
|
||||||
|
ChatCompletionRequestAssistantMessage, ChatCompletionRequestMessage,
|
||||||
|
ChatCompletionStreamResponseDelta, ChatCompletionToolType, CreateChatCompletionStreamResponse,
|
||||||
|
FunctionCall, FunctionCallStream, Role,
|
||||||
|
};
|
||||||
|
use bumpalo::Bump;
|
||||||
|
use meilisearch_types::error::ResponseError;
|
||||||
|
use meilisearch_types::heed::RoTxn;
|
||||||
|
use meilisearch_types::milli::index::ChatConfig;
|
||||||
|
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
||||||
|
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
||||||
|
use meilisearch_types::milli::{
|
||||||
|
DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder,
|
||||||
|
};
|
||||||
|
use meilisearch_types::{Document, Index};
|
||||||
|
use serde::Serialize;
|
||||||
|
use tokio::sync::mpsc::error::SendError;
|
||||||
|
use tokio::sync::mpsc::Sender;
|
||||||
|
|
||||||
|
use super::errors::StreamErrorEvent;
|
||||||
|
use super::MEILI_APPEND_CONVERSATION_MESSAGE_NAME;
|
||||||
|
use crate::routes::chats::{MEILI_SEARCH_PROGRESS_NAME, MEILI_SEARCH_SOURCES_NAME};
|
||||||
|
|
||||||
|
/// Newtype around a channel [`Sender`] of SSE [`Event`]s.
///
/// Every payload pushed through it is wrapped into an `Event::Data` before being sent.
pub struct SseEventSender(Sender<Event>);
|
||||||
|
|
||||||
|
impl SseEventSender {
|
||||||
|
pub fn new(sender: Sender<Event>) -> Self {
|
||||||
|
Self(sender)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ask the front-end user to append this tool *call* to the conversation
|
||||||
|
pub async fn append_tool_call_conversation_message(
|
||||||
|
&self,
|
||||||
|
resp: CreateChatCompletionStreamResponse,
|
||||||
|
call_id: String,
|
||||||
|
function_name: String,
|
||||||
|
function_arguments: String,
|
||||||
|
) -> Result<(), SendError<Event>> {
|
||||||
|
#[allow(deprecated)]
|
||||||
|
let message =
|
||||||
|
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
|
||||||
|
content: None,
|
||||||
|
refusal: None,
|
||||||
|
name: None,
|
||||||
|
audio: None,
|
||||||
|
tool_calls: Some(vec![ChatCompletionMessageToolCall {
|
||||||
|
id: call_id,
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: FunctionCall { name: function_name, arguments: function_arguments },
|
||||||
|
}]),
|
||||||
|
function_call: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
self.append_conversation_message(resp, &message).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ask the front-end user to append this tool to the conversation
|
||||||
|
pub async fn append_conversation_message(
|
||||||
|
&self,
|
||||||
|
mut resp: CreateChatCompletionStreamResponse,
|
||||||
|
message: &ChatCompletionRequestMessage,
|
||||||
|
) -> Result<(), SendError<Event>> {
|
||||||
|
let call_text = serde_json::to_string(message).unwrap();
|
||||||
|
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||||
|
index: 0,
|
||||||
|
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: Some(FunctionCallStream {
|
||||||
|
name: Some(MEILI_APPEND_CONVERSATION_MESSAGE_NAME.to_string()),
|
||||||
|
arguments: Some(call_text),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
resp.choices[0] = ChatChoiceStream {
|
||||||
|
index: 0,
|
||||||
|
#[allow(deprecated)]
|
||||||
|
delta: ChatCompletionStreamResponseDelta {
|
||||||
|
content: None,
|
||||||
|
function_call: None,
|
||||||
|
tool_calls: Some(vec![tool_call]),
|
||||||
|
role: Some(Role::Assistant),
|
||||||
|
refusal: None,
|
||||||
|
},
|
||||||
|
finish_reason: None,
|
||||||
|
logprobs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.send_json(&resp).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn report_search_progress(
|
||||||
|
&self,
|
||||||
|
mut resp: CreateChatCompletionStreamResponse,
|
||||||
|
call_id: &str,
|
||||||
|
function_name: &str,
|
||||||
|
function_arguments: &str,
|
||||||
|
) -> Result<(), SendError<Event>> {
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
/// Provides information about the current Meilisearch search operation.
|
||||||
|
struct MeiliSearchProgress<'a> {
|
||||||
|
/// The call ID to track the sources of the search.
|
||||||
|
call_id: &'a str,
|
||||||
|
/// The name of the function we are executing.
|
||||||
|
function_name: &'a str,
|
||||||
|
/// The arguments of the function we are executing, encoded in JSON.
|
||||||
|
function_arguments: &'a str,
|
||||||
|
}
|
||||||
|
|
||||||
|
let progress = MeiliSearchProgress { call_id, function_name, function_arguments };
|
||||||
|
let call_text = serde_json::to_string(&progress).unwrap();
|
||||||
|
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||||
|
index: 0,
|
||||||
|
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: Some(FunctionCallStream {
|
||||||
|
name: Some(MEILI_SEARCH_PROGRESS_NAME.to_string()),
|
||||||
|
arguments: Some(call_text),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
resp.choices[0] = ChatChoiceStream {
|
||||||
|
index: 0,
|
||||||
|
#[allow(deprecated)]
|
||||||
|
delta: ChatCompletionStreamResponseDelta {
|
||||||
|
content: None,
|
||||||
|
function_call: None,
|
||||||
|
tool_calls: Some(vec![tool_call]),
|
||||||
|
role: Some(Role::Assistant),
|
||||||
|
refusal: None,
|
||||||
|
},
|
||||||
|
finish_reason: None,
|
||||||
|
logprobs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.send_json(&resp).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn report_sources(
|
||||||
|
&self,
|
||||||
|
mut resp: CreateChatCompletionStreamResponse,
|
||||||
|
call_id: &str,
|
||||||
|
documents: &[Document],
|
||||||
|
) -> Result<(), SendError<Event>> {
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
/// Provides sources of the search.
|
||||||
|
struct MeiliSearchSources<'a> {
|
||||||
|
/// The call ID to track the original search associated to those sources.
|
||||||
|
call_id: &'a str,
|
||||||
|
/// The documents associated with the search (call_id).
|
||||||
|
/// Only the displayed attributes of the documents are returned.
|
||||||
|
sources: &'a [Document],
|
||||||
|
}
|
||||||
|
|
||||||
|
let sources = MeiliSearchSources { call_id, sources: documents };
|
||||||
|
let call_text = serde_json::to_string(&sources).unwrap();
|
||||||
|
let tool_call = ChatCompletionMessageToolCallChunk {
|
||||||
|
index: 0,
|
||||||
|
id: Some(uuid::Uuid::new_v4().to_string()),
|
||||||
|
r#type: Some(ChatCompletionToolType::Function),
|
||||||
|
function: Some(FunctionCallStream {
|
||||||
|
name: Some(MEILI_SEARCH_SOURCES_NAME.to_string()),
|
||||||
|
arguments: Some(call_text),
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
|
||||||
|
resp.choices[0] = ChatChoiceStream {
|
||||||
|
index: 0,
|
||||||
|
#[allow(deprecated)]
|
||||||
|
delta: ChatCompletionStreamResponseDelta {
|
||||||
|
content: None,
|
||||||
|
function_call: None,
|
||||||
|
tool_calls: Some(vec![tool_call]),
|
||||||
|
role: Some(Role::Assistant),
|
||||||
|
refusal: None,
|
||||||
|
},
|
||||||
|
finish_reason: None,
|
||||||
|
logprobs: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
self.send_json(&resp).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn forward_response(
|
||||||
|
&self,
|
||||||
|
resp: &CreateChatCompletionStreamResponse,
|
||||||
|
) -> Result<(), SendError<Event>> {
|
||||||
|
self.send_json(resp).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send_error(&self, error: &StreamErrorEvent) -> Result<(), SendError<Event>> {
|
||||||
|
self.send_json(error).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn stop(self) -> Result<(), SendError<Event>> {
|
||||||
|
self.0.send(Event::Data(sse::Data::new("[DONE]"))).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_json<S: Serialize>(&self, data: &S) -> Result<(), SendError<Event>> {
|
||||||
|
self.0.send(Event::Data(sse::Data::new_json(data).unwrap())).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Format documents based on the provided template and maximum bytes.
|
||||||
|
///
|
||||||
|
/// This formatting function is usually used to generate a summary of the documents for LLMs.
|
||||||
|
pub fn format_documents<'t, 'doc>(
|
||||||
|
rtxn: &RoTxn<'t>,
|
||||||
|
index: &Index,
|
||||||
|
doc_alloc: &'doc Bump,
|
||||||
|
internal_docids: Vec<DocumentId>,
|
||||||
|
) -> Result<Vec<&'doc str>, ResponseError> {
|
||||||
|
let ChatConfig { prompt: PromptData { template, max_bytes }, .. } = index.chat_config(rtxn)?;
|
||||||
|
|
||||||
|
let prompt = Prompt::new(template, max_bytes).unwrap();
|
||||||
|
let fid_map = index.fields_ids_map(rtxn)?;
|
||||||
|
let metadata_builder = MetadataBuilder::from_index(index, rtxn)?;
|
||||||
|
let fid_map_with_meta = FieldIdMapWithMetadata::new(fid_map.clone(), metadata_builder);
|
||||||
|
let global = RwLock::new(fid_map_with_meta);
|
||||||
|
let gfid_map = RefCell::new(GlobalFieldsIdsMap::new(&global));
|
||||||
|
|
||||||
|
let external_ids: Vec<String> = index
|
||||||
|
.external_id_of(rtxn, internal_docids.iter().copied())?
|
||||||
|
.into_iter()
|
||||||
|
.collect::<Result<_, _>>()?;
|
||||||
|
|
||||||
|
let mut renders = Vec::new();
|
||||||
|
for (docid, external_docid) in internal_docids.into_iter().zip(external_ids) {
|
||||||
|
let document = match DocumentFromDb::new(docid, rtxn, index, &fid_map)? {
|
||||||
|
Some(doc) => doc,
|
||||||
|
None => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let text = prompt.render_document(&external_docid, document, &gfid_map, doc_alloc).unwrap();
|
||||||
|
renders.push(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(renders)
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user