Better chat settings management

Clément Renault 2025-05-21 21:06:11 +02:00
parent afb43d266e
commit 7929872091
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
3 changed files with 95 additions and 45 deletions

@@ -28,6 +28,7 @@ use meilisearch_types::keys::actions;
 use meilisearch_types::milli::index::ChatConfig;
 use meilisearch_types::milli::prompt::{Prompt, PromptData};
 use meilisearch_types::milli::update::new::document::DocumentFromDb;
+use meilisearch_types::milli::update::Setting;
 use meilisearch_types::milli::{
     DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap, MetadataBuilder, TimeBudget,
 };
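The newly imported `Setting` type is Meilisearch's three-state wrapper for configurable values, which the rest of the diff matches on. Below is a minimal, self-contained sketch of how such a type behaves, assuming the usual `Set` / `Reset` / `NotSet` variants; the real definition lives in `meilisearch_types::milli::update` and its helper methods may differ.

```rust
// Minimal sketch of the three-state settings type imported above. The variant
// names follow the usual Meilisearch convention; the `set` helper is
// illustrative only.
#[derive(Debug, Clone, PartialEq)]
enum Setting<T> {
    /// An explicit value was provided.
    Set(T),
    /// The caller asked to reset the value to its default.
    Reset,
    /// The caller did not mention the value at all.
    NotSet,
}

impl<T> Setting<T> {
    /// Returns the inner value only when one was explicitly set.
    fn set(self) -> Option<T> {
        match self {
            Setting::Set(value) => Some(value),
            Setting::Reset | Setting::NotSet => None,
        }
    }
}

fn main() {
    let api_key: Setting<String> = Setting::Set("sk-test".into());
    // Pattern used throughout the diff: only act on explicitly set values.
    if let Setting::Set(key) = &api_key {
        println!("configuring client with a key of length {}", key.len());
    }
    assert_eq!(Setting::<String>::NotSet.set(), None);
}
```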
@@ -107,20 +108,20 @@ fn setup_search_tool(
         .function(
             FunctionObjectArgs::default()
                 .name(SEARCH_IN_INDEX_FUNCTION_NAME)
-                .description(&prompts.search_description)
+                .description(&prompts.search_description.clone().unwrap())
                 .parameters(json!({
                     "type": "object",
                     "properties": {
                         "index_uid": {
                             "type": "string",
                             "enum": index_uids,
-                            "description": prompts.search_index_uid_param,
+                            "description": prompts.search_index_uid_param.clone().unwrap(),
                         },
                         "q": {
                             // Unfortunately, Mistral does not support an array of types, here.
                             // "type": ["string", "null"],
                             "type": "string",
-                            "description": prompts.search_q_param,
+                            "description": prompts.search_q_param.clone().unwrap(),
                         }
                     },
                     "required": ["index_uid", "q"],
@@ -136,7 +137,9 @@ fn setup_search_tool(
     chat_completion.messages.insert(
         0,
         ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
-            content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
+            content: ChatCompletionRequestSystemMessageContent::Text(
+                prompts.system.as_ref().unwrap().clone(),
+            ),
             name: None,
         }),
     );
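The two hunks above now treat every prompt as optional and unwrap it at the call site. A rough sketch of what the prompt container might look like under that reading: the field names are taken from the diff, but the `Option<String>` representation and the default texts are assumptions made for illustration (the real `ChatPrompts` may wrap its fields in `Setting` instead).

```rust
// Hypothetical shape of the prompt container consumed by `setup_search_tool`
// after this change. Field names come from the diff; the representation and
// the default strings below are illustrative only.
#[derive(Debug, Clone)]
struct ChatPrompts {
    system: Option<String>,
    search_description: Option<String>,
    search_q_param: Option<String>,
    search_index_uid_param: Option<String>,
    pre_query: Option<String>,
}

impl Default for ChatPrompts {
    fn default() -> Self {
        ChatPrompts {
            system: Some("You are a helpful search assistant.".into()),
            search_description: Some("Search the Meilisearch indexes.".into()),
            search_q_param: Some("The search query.".into()),
            search_index_uid_param: Some("The index to search in.".into()),
            pre_query: Some("Here are the search results:".into()),
        }
    }
}

fn main() {
    let prompts = ChatPrompts::default();
    // Mirrors the `.clone().unwrap()` accesses in the hunks above: with
    // defaults in place, the unwraps cannot panic when no custom prompt is set.
    println!("{}", prompts.search_description.clone().unwrap());
    println!("{}", prompts.system.as_ref().unwrap());
}
```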
@@ -239,16 +242,17 @@ async fn non_streamed_chat(
     };
     let mut config = OpenAIConfig::default();
-    if let Some(api_key) = chat_settings.api_key.as_ref() {
+    if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
         config = config.with_api_key(api_key);
     }
-    if let Some(base_api) = chat_settings.base_api.as_ref() {
+    if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
         config = config.with_api_base(base_api);
     }
     let client = Client::with_config(config);
     let auth_token = extract_token_from_request(&req)?.unwrap();
-    setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
+    let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
+    setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
     let mut response;
     loop {
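The new `prompts` binding falls back to `ChatPrompts::default()` through `or(...).unwrap()`. The sketch below shows how such combinators could be defined on the three-state type; whether the real `Setting` exposes `or` and `unwrap` with exactly these semantics is an assumption inferred from their use in this diff.

```rust
// Sketch of the `or(...)` / `unwrap()` combinators that the new `prompts`
// binding relies on; definitions are illustrative, not the crate's API.
#[derive(Debug, Clone, PartialEq)]
enum Setting<T> {
    Set(T),
    Reset,
    NotSet,
}

impl<T> Setting<T> {
    /// Keeps an explicitly set value, otherwise falls back to `other`.
    fn or(self, other: Setting<T>) -> Setting<T> {
        match self {
            Setting::Set(_) => self,
            Setting::Reset | Setting::NotSet => other,
        }
    }

    /// Panics unless a value was explicitly set, like `Option::unwrap`.
    fn unwrap(self) -> T {
        match self {
            Setting::Set(value) => value,
            Setting::Reset | Setting::NotSet => panic!("called `unwrap` on a non-`Set` setting"),
        }
    }
}

fn main() {
    // Same shape as `chat_settings.prompts.clone().or(Setting::Set(default)).unwrap()`.
    let stored: Setting<String> = Setting::NotSet;
    let prompts = stored.or(Setting::Set("default prompts".to_string())).unwrap();
    assert_eq!(prompts, "default prompts");
}
```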
@@ -296,7 +300,7 @@ async fn non_streamed_chat(
                         tool_call_id: call.id.clone(),
                         content: ChatCompletionRequestToolMessageContent::Text(format!(
                             "{}\n\n{text}",
-                            chat_settings.prompts.pre_query
+                            chat_settings.prompts.clone().unwrap().pre_query.unwrap()
                         )),
                     },
                 ));
@@ -325,20 +329,21 @@ async fn streamed_chat(
     let filters = index_scheduler.filters();
     let chat_settings = match index_scheduler.chat_settings().unwrap() {
-        Some(value) => serde_json::from_value(value).unwrap(),
+        Some(value) => serde_json::from_value(value.clone()).unwrap(),
         None => GlobalChatSettings::default(),
     };
     let mut config = OpenAIConfig::default();
-    if let Some(api_key) = chat_settings.api_key.as_ref() {
+    if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
         config = config.with_api_key(api_key);
     }
-    if let Some(base_api) = chat_settings.base_api.as_ref() {
+    if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
         config = config.with_api_base(base_api);
     }
     let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
-    setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
+    let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
+    setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
     let (tx, rx) = tokio::sync::mpsc::channel(10);
     let _join_handle = Handle::current().spawn(async move {
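The streamed path loads the stored settings the same way as the non-streamed one: the scheduler returns an optional JSON value that is deserialized into `GlobalChatSettings`, falling back to the default when nothing is stored yet. A rough sketch of that loading step follows; the field set comes from the diff, but modelling the fields as `Option`s (rather than the crate's `Setting` type) and the JSON key names are assumptions for illustration.

```rust
// Rough sketch of the settings-loading step shared by both handlers. The
// struct layout and key names below are assumptions, not the crate's types.
use serde::Deserialize;
use serde_json::json;

#[derive(Debug, Default, Deserialize)]
#[serde(default)]
struct GlobalChatSettings {
    api_key: Option<String>,
    base_api: Option<String>,
    prompts: Option<serde_json::Value>,
}

fn main() {
    // Stand-in for `index_scheduler.chat_settings().unwrap()`.
    let stored = Some(json!({ "base_api": "http://localhost:11434/v1" }));

    // Missing settings fall back to the default, as in the diff.
    let settings: GlobalChatSettings = match stored {
        Some(value) => serde_json::from_value(value).unwrap(),
        None => GlobalChatSettings::default(),
    };
    println!("{settings:?}");
}
```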
@@ -447,7 +452,7 @@ async fn streamed_chat(
                                     let tool = ChatCompletionRequestToolMessage {
                                         tool_call_id: call.id.clone(),
                                         content: ChatCompletionRequestToolMessageContent::Text(
-                                            format!("{}\n\n{text}", chat_settings.prompts.pre_query),
+                                            format!("{}\n\n{text}", chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()),
                                         ),
                                     };
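The `as_ref().unwrap()` chain above panics if either the prompts object or its `pre_query` field is missing. As a purely hypothetical alternative (not part of this commit), the same lookup could be written with an explicit fallback; the types, method name, and fallback text below are illustrative only.

```rust
// Hypothetical helper equivalent to
// `chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()`,
// but falling back to a default string instead of panicking.
#[derive(Debug, Default, Clone)]
struct ChatPrompts {
    pre_query: Option<String>,
}

#[derive(Debug, Default, Clone)]
struct GlobalChatSettings {
    prompts: Option<ChatPrompts>,
}

impl GlobalChatSettings {
    /// Returns the pre-query prompt, or a fallback when it is missing.
    fn pre_query(&self) -> &str {
        self.prompts
            .as_ref()
            .and_then(|prompts| prompts.pre_query.as_deref())
            .unwrap_or("Here are the search results:")
    }
}

fn main() {
    let chat_settings = GlobalChatSettings::default();
    let text = r#"[{"title": "Shazam!"}]"#;
    // Same message text shape as the tool message built in the hunk above.
    println!("{}\n\n{text}", chat_settings.pre_query());
}
```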