Better chat completions settings management

This commit is contained in:
Kerollmops 2025-05-30 15:02:24 +02:00 committed by Clément Renault
parent 0f7f5fa104
commit 02cbcea3db
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
6 changed files with 219 additions and 107 deletions

View file

@ -26,13 +26,15 @@ use bumpalo::Bump;
use futures::StreamExt;
use index_scheduler::IndexScheduler;
use meilisearch_auth::AuthController;
use meilisearch_types::error::ResponseError;
use meilisearch_types::error::{Code, ResponseError};
use meilisearch_types::features::{
ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings as DbChatSettings,
};
use meilisearch_types::heed::RoTxn;
use meilisearch_types::keys::actions;
use meilisearch_types::milli::index::ChatConfig;
use meilisearch_types::milli::prompt::{Prompt, PromptData};
use meilisearch_types::milli::update::new::document::DocumentFromDb;
use meilisearch_types::milli::update::Setting;
use meilisearch_types::milli::{
all_obkv_to_json, obkv_to_json, DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap,
MetadataBuilder, TimeBudget,
@ -44,7 +46,6 @@ use tokio::runtime::Handle;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::Sender;
use super::settings::{ChatPrompts, GlobalChatSettings};
use super::ChatsParam;
use crate::error::MeilisearchHttpError;
use crate::extractors::authentication::policies::ActionPolicy;
@ -132,7 +133,7 @@ fn setup_search_tool(
index_scheduler: &Data<IndexScheduler>,
filters: &meilisearch_auth::AuthFilter,
chat_completion: &mut CreateChatCompletionRequest,
prompts: &ChatPrompts,
prompts: &DbChatCompletionPrompts,
) -> Result<FunctionSupport, ResponseError> {
let tools = chat_completion.tools.get_or_insert_default();
if tools.iter().find(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
@ -167,7 +168,7 @@ fn setup_search_tool(
});
let mut index_uids = Vec::new();
let mut function_description = prompts.search_description.clone().unwrap();
let mut function_description = prompts.search_description.clone();
index_scheduler.try_for_each_index::<_, ()>(|name, index| {
// Make sure to skip unauthorized indexes
if !filters.is_index_authorized(&name) {
@ -195,13 +196,13 @@ fn setup_search_tool(
"index_uid": {
"type": "string",
"enum": index_uids,
"description": prompts.search_index_uid_param.clone().unwrap(),
"description": prompts.search_index_uid_param,
},
"q": {
// Unfortunately, Mistral does not support an array of types, here.
// "type": ["string", "null"],
"type": "string",
"description": prompts.search_q_param.clone().unwrap(),
"description": prompts.search_q_param,
}
},
"required": ["index_uid", "q"],
@ -219,9 +220,7 @@ fn setup_search_tool(
chat_completion.messages.insert(
0,
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
content: ChatCompletionRequestSystemMessageContent::Text(
prompts.system.as_ref().unwrap().clone(),
),
content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
name: None,
}),
);
@ -322,23 +321,27 @@ async fn non_streamed_chat(
let rtxn = index_scheduler.read_txn()?;
let chat_settings = match index_scheduler.chat_settings(&rtxn, workspace_uid).unwrap() {
Some(value) => serde_json::from_value(value).unwrap(),
None => GlobalChatSettings::default(),
Some(settings) => settings,
None => {
return Err(ResponseError::from_msg(
format!("Chat `{workspace_uid}` not found"),
Code::ChatWorkspaceNotFound,
))
}
};
let mut config = OpenAIConfig::default();
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
if let Some(api_key) = chat_settings.api_key.as_ref() {
config = config.with_api_key(api_key);
}
if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
if let Some(base_api) = chat_settings.base_api.as_ref() {
config = config.with_api_base(base_api);
}
let client = Client::with_config(config);
let auth_token = extract_token_from_request(&req)?.unwrap();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
let FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors } =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
let mut response;
loop {
@ -381,13 +384,11 @@ async fn non_streamed_chat(
Err(err) => err,
};
let answer = format!("{}\n\n{text}", chat_settings.prompts.pre_query);
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage {
tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(format!(
"{}\n\n{text}",
chat_settings.prompts.clone().unwrap().pre_query.unwrap()
)),
content: ChatCompletionRequestToolMessageContent::Text(answer),
},
));
}
@ -416,24 +417,28 @@ async fn streamed_chat(
let filters = index_scheduler.filters();
let rtxn = index_scheduler.read_txn()?;
let chat_settings = match index_scheduler.chat_settings(&rtxn, workspace_uid).unwrap() {
Some(value) => serde_json::from_value(value.clone()).unwrap(),
None => GlobalChatSettings::default(),
let chat_settings = match index_scheduler.chat_settings(&rtxn, workspace_uid)? {
Some(settings) => settings,
None => {
return Err(ResponseError::from_msg(
format!("Chat `{workspace_uid}` not found"),
Code::ChatWorkspaceNotFound,
))
}
};
drop(rtxn);
let mut config = OpenAIConfig::default();
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
if let Some(api_key) = chat_settings.api_key.as_ref() {
config = config.with_api_key(api_key);
}
if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
if let Some(base_api) = chat_settings.base_api.as_ref() {
config = config.with_api_base(base_api);
}
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
let function_support =
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
let (tx, rx) = tokio::sync::mpsc::channel(10);
let tx = SseEventSender(tx);
@ -478,7 +483,7 @@ async fn run_conversation<C: Config>(
search_queue: &web::Data<SearchQueue>,
auth_token: &str,
client: &Client<C>,
chat_settings: &GlobalChatSettings,
chat_settings: &DbChatSettings,
chat_completion: &mut CreateChatCompletionRequest,
tx: &SseEventSender,
global_tool_calls: &mut HashMap<u32, Call>,
@ -605,7 +610,7 @@ async fn handle_meili_tools(
auth_ctrl: &web::Data<AuthController>,
search_queue: &web::Data<SearchQueue>,
auth_token: &str,
chat_settings: &GlobalChatSettings,
chat_settings: &DbChatSettings,
tx: &SseEventSender,
meili_calls: Vec<ChatCompletionMessageToolCall>,
chat_completion: &mut CreateChatCompletionRequest,
@ -658,12 +663,10 @@ async fn handle_meili_tools(
Err(err) => err,
};
let answer = format!("{}\n\n{text}", chat_settings.prompts.pre_query);
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
tool_call_id: call.id.clone(),
content: ChatCompletionRequestToolMessageContent::Text(format!(
"{}\n\n{text}",
chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()
)),
content: ChatCompletionRequestToolMessageContent::Text(answer),
});
if append_to_conversation {