mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-03 11:57:07 +02:00
Better chat completions settings management
This commit is contained in:
parent
0f7f5fa104
commit
02cbcea3db
6 changed files with 219 additions and 107 deletions
|
@ -26,13 +26,15 @@ use bumpalo::Bump;
|
|||
use futures::StreamExt;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_auth::AuthController;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::error::{Code, ResponseError};
|
||||
use meilisearch_types::features::{
|
||||
ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings as DbChatSettings,
|
||||
};
|
||||
use meilisearch_types::heed::RoTxn;
|
||||
use meilisearch_types::keys::actions;
|
||||
use meilisearch_types::milli::index::ChatConfig;
|
||||
use meilisearch_types::milli::prompt::{Prompt, PromptData};
|
||||
use meilisearch_types::milli::update::new::document::DocumentFromDb;
|
||||
use meilisearch_types::milli::update::Setting;
|
||||
use meilisearch_types::milli::{
|
||||
all_obkv_to_json, obkv_to_json, DocumentId, FieldIdMapWithMetadata, GlobalFieldsIdsMap,
|
||||
MetadataBuilder, TimeBudget,
|
||||
|
@ -44,7 +46,6 @@ use tokio::runtime::Handle;
|
|||
use tokio::sync::mpsc::error::SendError;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
|
||||
use super::settings::{ChatPrompts, GlobalChatSettings};
|
||||
use super::ChatsParam;
|
||||
use crate::error::MeilisearchHttpError;
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
|
@ -132,7 +133,7 @@ fn setup_search_tool(
|
|||
index_scheduler: &Data<IndexScheduler>,
|
||||
filters: &meilisearch_auth::AuthFilter,
|
||||
chat_completion: &mut CreateChatCompletionRequest,
|
||||
prompts: &ChatPrompts,
|
||||
prompts: &DbChatCompletionPrompts,
|
||||
) -> Result<FunctionSupport, ResponseError> {
|
||||
let tools = chat_completion.tools.get_or_insert_default();
|
||||
if tools.iter().find(|t| t.function.name == MEILI_SEARCH_IN_INDEX_FUNCTION_NAME).is_some() {
|
||||
|
@ -167,7 +168,7 @@ fn setup_search_tool(
|
|||
});
|
||||
|
||||
let mut index_uids = Vec::new();
|
||||
let mut function_description = prompts.search_description.clone().unwrap();
|
||||
let mut function_description = prompts.search_description.clone();
|
||||
index_scheduler.try_for_each_index::<_, ()>(|name, index| {
|
||||
// Make sure to skip unauthorized indexes
|
||||
if !filters.is_index_authorized(&name) {
|
||||
|
@ -195,13 +196,13 @@ fn setup_search_tool(
|
|||
"index_uid": {
|
||||
"type": "string",
|
||||
"enum": index_uids,
|
||||
"description": prompts.search_index_uid_param.clone().unwrap(),
|
||||
"description": prompts.search_index_uid_param,
|
||||
},
|
||||
"q": {
|
||||
// Unfortunately, Mistral does not support an array of types, here.
|
||||
// "type": ["string", "null"],
|
||||
"type": "string",
|
||||
"description": prompts.search_q_param.clone().unwrap(),
|
||||
"description": prompts.search_q_param,
|
||||
}
|
||||
},
|
||||
"required": ["index_uid", "q"],
|
||||
|
@ -219,9 +220,7 @@ fn setup_search_tool(
|
|||
chat_completion.messages.insert(
|
||||
0,
|
||||
ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage {
|
||||
content: ChatCompletionRequestSystemMessageContent::Text(
|
||||
prompts.system.as_ref().unwrap().clone(),
|
||||
),
|
||||
content: ChatCompletionRequestSystemMessageContent::Text(prompts.system.clone()),
|
||||
name: None,
|
||||
}),
|
||||
);
|
||||
|
@ -322,23 +321,27 @@ async fn non_streamed_chat(
|
|||
|
||||
let rtxn = index_scheduler.read_txn()?;
|
||||
let chat_settings = match index_scheduler.chat_settings(&rtxn, workspace_uid).unwrap() {
|
||||
Some(value) => serde_json::from_value(value).unwrap(),
|
||||
None => GlobalChatSettings::default(),
|
||||
Some(settings) => settings,
|
||||
None => {
|
||||
return Err(ResponseError::from_msg(
|
||||
format!("Chat `{workspace_uid}` not found"),
|
||||
Code::ChatWorkspaceNotFound,
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
let mut config = OpenAIConfig::default();
|
||||
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
|
||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
let client = Client::with_config(config);
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap();
|
||||
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||
let FunctionSupport { report_progress, report_sources, append_to_conversation, report_errors } =
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
|
||||
|
||||
let mut response;
|
||||
loop {
|
||||
|
@ -381,13 +384,11 @@ async fn non_streamed_chat(
|
|||
Err(err) => err,
|
||||
};
|
||||
|
||||
let answer = format!("{}\n\n{text}", chat_settings.prompts.pre_query);
|
||||
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
|
||||
ChatCompletionRequestToolMessage {
|
||||
tool_call_id: call.id.clone(),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(format!(
|
||||
"{}\n\n{text}",
|
||||
chat_settings.prompts.clone().unwrap().pre_query.unwrap()
|
||||
)),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(answer),
|
||||
},
|
||||
));
|
||||
}
|
||||
|
@ -416,24 +417,28 @@ async fn streamed_chat(
|
|||
let filters = index_scheduler.filters();
|
||||
|
||||
let rtxn = index_scheduler.read_txn()?;
|
||||
let chat_settings = match index_scheduler.chat_settings(&rtxn, workspace_uid).unwrap() {
|
||||
Some(value) => serde_json::from_value(value.clone()).unwrap(),
|
||||
None => GlobalChatSettings::default(),
|
||||
let chat_settings = match index_scheduler.chat_settings(&rtxn, workspace_uid)? {
|
||||
Some(settings) => settings,
|
||||
None => {
|
||||
return Err(ResponseError::from_msg(
|
||||
format!("Chat `{workspace_uid}` not found"),
|
||||
Code::ChatWorkspaceNotFound,
|
||||
))
|
||||
}
|
||||
};
|
||||
drop(rtxn);
|
||||
|
||||
let mut config = OpenAIConfig::default();
|
||||
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
|
||||
if let Some(api_key) = chat_settings.api_key.as_ref() {
|
||||
config = config.with_api_key(api_key);
|
||||
}
|
||||
if let Setting::Set(base_api) = chat_settings.base_api.as_ref() {
|
||||
if let Some(base_api) = chat_settings.base_api.as_ref() {
|
||||
config = config.with_api_base(base_api);
|
||||
}
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||
let prompts = chat_settings.prompts.clone().or(Setting::Set(ChatPrompts::default())).unwrap();
|
||||
let function_support =
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &prompts)?;
|
||||
setup_search_tool(&index_scheduler, filters, &mut chat_completion, &chat_settings.prompts)?;
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||
let tx = SseEventSender(tx);
|
||||
|
@ -478,7 +483,7 @@ async fn run_conversation<C: Config>(
|
|||
search_queue: &web::Data<SearchQueue>,
|
||||
auth_token: &str,
|
||||
client: &Client<C>,
|
||||
chat_settings: &GlobalChatSettings,
|
||||
chat_settings: &DbChatSettings,
|
||||
chat_completion: &mut CreateChatCompletionRequest,
|
||||
tx: &SseEventSender,
|
||||
global_tool_calls: &mut HashMap<u32, Call>,
|
||||
|
@ -605,7 +610,7 @@ async fn handle_meili_tools(
|
|||
auth_ctrl: &web::Data<AuthController>,
|
||||
search_queue: &web::Data<SearchQueue>,
|
||||
auth_token: &str,
|
||||
chat_settings: &GlobalChatSettings,
|
||||
chat_settings: &DbChatSettings,
|
||||
tx: &SseEventSender,
|
||||
meili_calls: Vec<ChatCompletionMessageToolCall>,
|
||||
chat_completion: &mut CreateChatCompletionRequest,
|
||||
|
@ -658,12 +663,10 @@ async fn handle_meili_tools(
|
|||
Err(err) => err,
|
||||
};
|
||||
|
||||
let answer = format!("{}\n\n{text}", chat_settings.prompts.pre_query);
|
||||
let tool = ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage {
|
||||
tool_call_id: call.id.clone(),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(format!(
|
||||
"{}\n\n{text}",
|
||||
chat_settings.prompts.as_ref().unwrap().pre_query.as_ref().unwrap()
|
||||
)),
|
||||
content: ChatCompletionRequestToolMessageContent::Text(answer),
|
||||
});
|
||||
|
||||
if append_to_conversation {
|
||||
|
|
|
@ -1,7 +1,13 @@
|
|||
use actix_web::web::{self, Data};
|
||||
use actix_web::HttpResponse;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::error::{Code, ResponseError};
|
||||
use meilisearch_types::features::{
|
||||
ChatCompletionPrompts as DbChatCompletionPrompts, ChatCompletionSettings,
|
||||
ChatCompletionSource as DbChatCompletionSource, DEFAULT_CHAT_PRE_QUERY_PROMPT,
|
||||
DEFAULT_CHAT_SEARCH_DESCRIPTION_PROMPT, DEFAULT_CHAT_SEARCH_INDEX_UID_PARAM_PROMPT,
|
||||
DEFAULT_CHAT_SEARCH_Q_PARAM_PROMPT, DEFAULT_CHAT_SYSTEM_PROMPT,
|
||||
};
|
||||
use meilisearch_types::keys::actions;
|
||||
use meilisearch_types::milli::update::Setting;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
@ -15,7 +21,7 @@ use super::ChatsParam;
|
|||
pub fn configure(cfg: &mut web::ServiceConfig) {
|
||||
cfg.service(
|
||||
web::resource("")
|
||||
.route(web::get().to(get_settings))
|
||||
.route(web::get().to(SeqHandler(get_settings)))
|
||||
.route(web::patch().to(SeqHandler(patch_settings)))
|
||||
.route(web::delete().to(SeqHandler(delete_settings))),
|
||||
);
|
||||
|
@ -33,8 +39,13 @@ async fn get_settings(
|
|||
// TODO do a spawn_blocking here ???
|
||||
let rtxn = index_scheduler.read_txn()?;
|
||||
let mut settings = match index_scheduler.chat_settings(&rtxn, &workspace_uid)? {
|
||||
Some(value) => serde_json::from_value(value).unwrap(),
|
||||
None => GlobalChatSettings::default(),
|
||||
Some(settings) => settings,
|
||||
None => {
|
||||
return Err(ResponseError::from_msg(
|
||||
format!("Chat `{workspace_uid}` not found"),
|
||||
Code::ChatWorkspaceNotFound,
|
||||
))
|
||||
}
|
||||
};
|
||||
settings.hide_secrets();
|
||||
Ok(HttpResponse::Ok().json(settings))
|
||||
|
@ -52,35 +63,73 @@ async fn patch_settings(
|
|||
|
||||
// TODO do a spawn_blocking here
|
||||
let mut wtxn = index_scheduler.write_txn()?;
|
||||
let old = match index_scheduler.chat_settings(&mut wtxn, &workspace_uid)? {
|
||||
Some(value) => serde_json::from_value(value).unwrap(),
|
||||
None => GlobalChatSettings::default(),
|
||||
};
|
||||
let old_settings =
|
||||
index_scheduler.chat_settings(&mut wtxn, &workspace_uid)?.unwrap_or_default();
|
||||
|
||||
let settings = GlobalChatSettings {
|
||||
source: new.source.or(old.source),
|
||||
base_api: new.base_api.clone().or(old.base_api),
|
||||
api_key: new.api_key.clone().or(old.api_key),
|
||||
prompts: match (new.prompts, old.prompts) {
|
||||
(Setting::NotSet, set) | (set, Setting::NotSet) => set,
|
||||
(Setting::Set(_) | Setting::Reset, Setting::Reset) => Setting::Reset,
|
||||
(Setting::Reset, Setting::Set(set)) => Setting::Set(set),
|
||||
// If both are set we must merge the prompts settings
|
||||
(Setting::Set(new), Setting::Set(old)) => Setting::Set(ChatPrompts {
|
||||
system: new.system.or(old.system),
|
||||
search_description: new.search_description.or(old.search_description),
|
||||
search_q_param: new.search_q_param.or(old.search_q_param),
|
||||
search_index_uid_param: new.search_index_uid_param.or(old.search_index_uid_param),
|
||||
pre_query: new.pre_query.or(old.pre_query),
|
||||
}),
|
||||
let prompts = match new.prompts {
|
||||
Setting::Set(new_prompts) => DbChatCompletionPrompts {
|
||||
system: match new_prompts.system {
|
||||
Setting::Set(new_system) => new_system,
|
||||
Setting::Reset => DEFAULT_CHAT_SYSTEM_PROMPT.to_string(),
|
||||
Setting::NotSet => old_settings.prompts.system,
|
||||
},
|
||||
search_description: match new_prompts.search_description {
|
||||
Setting::Set(new_description) => new_description,
|
||||
Setting::Reset => DEFAULT_CHAT_SEARCH_DESCRIPTION_PROMPT.to_string(),
|
||||
Setting::NotSet => old_settings.prompts.search_description,
|
||||
},
|
||||
search_q_param: match new_prompts.search_q_param {
|
||||
Setting::Set(new_description) => new_description,
|
||||
Setting::Reset => DEFAULT_CHAT_SEARCH_Q_PARAM_PROMPT.to_string(),
|
||||
Setting::NotSet => old_settings.prompts.search_q_param,
|
||||
},
|
||||
search_index_uid_param: match new_prompts.search_index_uid_param {
|
||||
Setting::Set(new_description) => new_description,
|
||||
Setting::Reset => DEFAULT_CHAT_SEARCH_INDEX_UID_PARAM_PROMPT.to_string(),
|
||||
Setting::NotSet => old_settings.prompts.search_index_uid_param,
|
||||
},
|
||||
pre_query: match new_prompts.pre_query {
|
||||
Setting::Set(new_description) => new_description,
|
||||
Setting::Reset => DEFAULT_CHAT_PRE_QUERY_PROMPT.to_string(),
|
||||
Setting::NotSet => old_settings.prompts.pre_query,
|
||||
},
|
||||
},
|
||||
Setting::Reset => DbChatCompletionPrompts::default(),
|
||||
Setting::NotSet => old_settings.prompts,
|
||||
};
|
||||
|
||||
let value = serde_json::to_value(settings).unwrap();
|
||||
index_scheduler.put_chat_settings(&mut wtxn, &workspace_uid, &value)?;
|
||||
let settings = ChatCompletionSettings {
|
||||
source: match new.source {
|
||||
Setting::Set(new_source) => new_source.into(),
|
||||
Setting::Reset => DbChatCompletionSource::default(),
|
||||
Setting::NotSet => old_settings.source,
|
||||
},
|
||||
base_api: match new.base_api {
|
||||
Setting::Set(new_base_api) => Some(new_base_api),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.base_api,
|
||||
},
|
||||
api_key: match new.api_key {
|
||||
Setting::Set(new_api_key) => Some(new_api_key),
|
||||
Setting::Reset => None,
|
||||
Setting::NotSet => old_settings.api_key,
|
||||
},
|
||||
prompts,
|
||||
};
|
||||
|
||||
// TODO send analytics
|
||||
// analytics.publish(
|
||||
// PatchNetworkAnalytics {
|
||||
// network_size: merged_remotes.len(),
|
||||
// network_has_self: merged_self.is_some(),
|
||||
// },
|
||||
// &req,
|
||||
// );
|
||||
|
||||
index_scheduler.put_chat_settings(&mut wtxn, &workspace_uid, &settings)?;
|
||||
wtxn.commit()?;
|
||||
|
||||
Ok(HttpResponse::Ok().finish())
|
||||
Ok(HttpResponse::Ok().json(settings))
|
||||
}
|
||||
|
||||
async fn delete_settings(
|
||||
|
@ -96,58 +145,42 @@ async fn delete_settings(
|
|||
let mut wtxn = index_scheduler.write_txn()?;
|
||||
if index_scheduler.delete_chat_settings(&mut wtxn, &workspace_uid)? {
|
||||
wtxn.commit()?;
|
||||
Ok(HttpResponse::Ok().finish())
|
||||
Ok(HttpResponse::NoContent().finish())
|
||||
} else {
|
||||
Ok(HttpResponse::NotFound().finish())
|
||||
Err(ResponseError::from_msg(
|
||||
format!("Chat `{workspace_uid}` not found"),
|
||||
Code::ChatWorkspaceNotFound,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Default, Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub enum ChatSource {
|
||||
pub enum ChatCompletionSource {
|
||||
#[default]
|
||||
OpenAi,
|
||||
}
|
||||
|
||||
// TODO Implement Deserr on that.
|
||||
// TODO Declare DbGlobalChatSettings (alias it).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub struct GlobalChatSettings {
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub source: Setting<ChatSource>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub base_api: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub api_key: Setting<String>,
|
||||
#[serde(default, skip_serializing_if = "Setting::is_not_set")]
|
||||
pub prompts: Setting<ChatPrompts>,
|
||||
impl From<ChatCompletionSource> for DbChatCompletionSource {
|
||||
fn from(source: ChatCompletionSource) -> Self {
|
||||
match source {
|
||||
ChatCompletionSource::OpenAi => DbChatCompletionSource::OpenAi,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GlobalChatSettings {
|
||||
pub fn hide_secrets(&mut self) {
|
||||
match &mut self.api_key {
|
||||
Setting::Set(key) => Self::hide_secret(key),
|
||||
Setting::Reset => (),
|
||||
Setting::NotSet => (),
|
||||
}
|
||||
}
|
||||
|
||||
fn hide_secret(secret: &mut String) {
|
||||
match secret.len() {
|
||||
x if x < 10 => {
|
||||
secret.replace_range(.., "XXX...");
|
||||
}
|
||||
x if x < 20 => {
|
||||
secret.replace_range(2.., "XXXX...");
|
||||
}
|
||||
x if x < 30 => {
|
||||
secret.replace_range(3.., "XXXXX...");
|
||||
}
|
||||
_x => {
|
||||
secret.replace_range(5.., "XXXXXX...");
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO Implement Deserr on that.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(deny_unknown_fields, rename_all = "camelCase")]
|
||||
pub struct GlobalChatSettings {
|
||||
#[serde(default)]
|
||||
pub source: Setting<ChatCompletionSource>,
|
||||
#[serde(default)]
|
||||
pub base_api: Setting<String>,
|
||||
#[serde(default)]
|
||||
pub api_key: Setting<String>,
|
||||
#[serde(default)]
|
||||
pub prompts: Setting<ChatPrompts>,
|
||||
}
|
||||
|
||||
// TODO Implement Deserr on that.
|
||||
|
|
|
@ -115,7 +115,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
|
|||
.service(web::scope("/metrics").configure(metrics::configure))
|
||||
.service(web::scope("/experimental-features").configure(features::configure))
|
||||
.service(web::scope("/network").configure(network::configure))
|
||||
.service(web::scope("/chats").configure(chats::settings::configure));
|
||||
.service(web::scope("/chats").configure(chats::configure));
|
||||
|
||||
#[cfg(feature = "swagger")]
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue