Introduce listing/getting/deleting/updating chat workspace settings

This commit is contained in:
Kerollmops 2025-05-30 12:12:47 +02:00 committed by Clément Renault
parent 50fafbbc8b
commit 0f7f5fa104
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
9 changed files with 241 additions and 51 deletions

View file

@ -44,7 +44,8 @@ use tokio::runtime::Handle;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::Sender;
use super::settings::chat::{ChatPrompts, GlobalChatSettings};
use super::settings::{ChatPrompts, GlobalChatSettings};
use super::ChatsParam;
use crate::error::MeilisearchHttpError;
use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
@ -60,13 +61,14 @@ const MEILI_REPORT_ERRORS_NAME: &str = "_meiliReportErrors";
const MEILI_SEARCH_IN_INDEX_FUNCTION_NAME: &str = "_meiliSearchInIndex";
/// Registers the `chat` handler for `POST` requests on the resources of this scope.
pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service(web::resource("/completions").route(web::post().to(chat)));
cfg.service(web::resource("").route(web::post().to(chat)));
}
/// Get a chat completion
async fn chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
chats_param: web::Path<ChatsParam>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
web::Json(chat_completion): web::Json<CreateChatCompletionRequest>,
@ -74,6 +76,8 @@ async fn chat(
// To enable later on, when the feature will be experimental
// index_scheduler.features().check_chat("Using the /chat route")?;
let ChatsParam { workspace_uid } = chats_param.into_inner();
assert_eq!(
chat_completion.n.unwrap_or(1),
1,
@ -82,11 +86,27 @@ async fn chat(
if chat_completion.stream.unwrap_or(false) {
Either::Right(
streamed_chat(index_scheduler, auth_ctrl, search_queue, req, chat_completion).await,
streamed_chat(
index_scheduler,
auth_ctrl,
search_queue,
&workspace_uid,
req,
chat_completion,
)
.await,
)
} else {
Either::Left(
non_streamed_chat(index_scheduler, auth_ctrl, search_queue, req, chat_completion).await,
non_streamed_chat(
index_scheduler,
auth_ctrl,
search_queue,
&workspace_uid,
req,
chat_completion,
)
.await,
)
}
}
@ -294,12 +314,14 @@ async fn non_streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
search_queue: web::Data<SearchQueue>,
workspace_uid: &str,
req: HttpRequest,
mut chat_completion: CreateChatCompletionRequest,
) -> Result<HttpResponse, ResponseError> {
let filters = index_scheduler.filters();
let chat_settings = match index_scheduler.chat_settings().unwrap() {
let rtxn = index_scheduler.read_txn()?;
let chat_settings = match index_scheduler.chat_settings(&rtxn, workspace_uid).unwrap() {
Some(value) => serde_json::from_value(value).unwrap(),
None => GlobalChatSettings::default(),
};
@ -387,15 +409,18 @@ async fn streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
search_queue: web::Data<SearchQueue>,
workspace_uid: &str,
req: HttpRequest,
mut chat_completion: CreateChatCompletionRequest,
) -> Result<impl Responder, ResponseError> {
let filters = index_scheduler.filters();
let chat_settings = match index_scheduler.chat_settings().unwrap() {
let rtxn = index_scheduler.read_txn()?;
let chat_settings = match index_scheduler.chat_settings(&rtxn, workspace_uid).unwrap() {
Some(value) => serde_json::from_value(value.clone()).unwrap(),
None => GlobalChatSettings::default(),
};
drop(rtxn);
let mut config = OpenAIConfig::default();
if let Setting::Set(api_key) = chat_settings.api_key.as_ref() {
@ -662,6 +687,7 @@ impl SseEventSender {
function_name: String,
function_arguments: String,
) -> Result<(), SendError<Event>> {
#[allow(deprecated)]
let message =
ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage {
content: None,
@ -698,6 +724,7 @@ impl SseEventSender {
resp.choices[0] = ChatChoiceStream {
index: 0,
#[allow(deprecated)]
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
@ -744,6 +771,7 @@ impl SseEventSender {
resp.choices[0] = ChatChoiceStream {
index: 0,
#[allow(deprecated)]
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,
@ -788,6 +816,7 @@ impl SseEventSender {
resp.choices[0] = ChatChoiceStream {
index: 0,
#[allow(deprecated)]
delta: ChatCompletionStreamResponseDelta {
content: None,
function_call: None,

View file

@ -0,0 +1,87 @@
use actix_web::{
web::{self, Data},
HttpResponse,
};
use deserr::{actix_web::AwebQueryParameter, Deserr};
use index_scheduler::IndexScheduler;
use meilisearch_types::{
deserr::{query_params::Param, DeserrQueryParamError},
error::{
deserr_codes::{InvalidIndexLimit, InvalidIndexOffset},
ResponseError,
},
keys::actions,
};
use serde::{Deserialize, Serialize};
use tracing::debug;
use utoipa::{IntoParams, ToSchema};
use crate::{
extractors::authentication::{policies::ActionPolicy, GuardedData},
routes::PAGINATION_DEFAULT_LIMIT,
};
use super::Pagination;
// TODO supports chats/$workspace/settings + /chats/$workspace/chat/completions
pub mod chat_completions;
pub mod settings;
/// Path parameters shared by the routes nested under the `/{workspace_uid}` scope.
///
/// Handlers receive it through a `web::Path<ChatsParam>` extractor.
#[derive(Deserialize)]
pub struct ChatsParam {
/// Value of the `{workspace_uid}` path segment.
workspace_uid: String,
}
/// Mounts the chat workspace routes on `cfg`.
///
/// - `GET ""` lists the workspaces through [`list_workspaces`].
/// - `/{workspace_uid}/chat/completions` is delegated to `chat_completions::configure`.
/// - `/{workspace_uid}/settings` is delegated to `settings::configure`.
pub fn configure(cfg: &mut web::ServiceConfig) {
    // Everything that targets a single workspace lives under its uid.
    let workspace_scope = web::scope("/{workspace_uid}")
        .service(web::scope("/chat/completions").configure(chat_completions::configure))
        .service(web::scope("/settings").configure(settings::configure));

    cfg.service(web::resource("").route(web::get().to(list_workspaces)));
    cfg.service(workspace_scope);
}
/// Query parameters accepted when listing chat workspaces.
///
/// Field names are exposed in camelCase and unknown query parameters are rejected
/// (see the `deserr` attributes below).
#[derive(Deserr, Debug, Clone, Copy, IntoParams)]
#[deserr(error = DeserrQueryParamError, rename_all = camelCase, deny_unknown_fields)]
#[into_params(rename_all = "camelCase", parameter_in = Query)]
pub struct ListChats {
/// The number of chat workspaces to skip before starting to retrieve anything
#[param(value_type = Option<usize>, default, example = 100)]
#[deserr(default, error = DeserrQueryParamError<InvalidIndexOffset>)]
pub offset: Param<usize>,
/// The number of chat workspaces to retrieve
#[param(value_type = Option<usize>, default = 20, example = 1)]
#[deserr(default = Param(PAGINATION_DEFAULT_LIMIT), error = DeserrQueryParamError<InvalidIndexLimit>)]
pub limit: Param<usize>,
}
impl ListChats {
fn as_pagination(self) -> Pagination {
Pagination { offset: self.offset.0, limit: self.limit.0 }
}
}
/// A chat workspace as returned by the workspace listing route.
#[derive(Debug, Serialize, Clone, ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ChatWorkspaceView {
/// Unique identifier for the chat workspace
pub uid: String,
}
pub async fn list_workspaces(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHATS_GET }>, Data<IndexScheduler>>,
paginate: AwebQueryParameter<ListChats, DeserrQueryParamError>,
) -> Result<HttpResponse, ResponseError> {
debug!(parameters = ?paginate, "List chat workspaces");
let filters = index_scheduler.filters();
let (total, workspaces) = index_scheduler.paginated_chat_workspace_uids(
filters,
*paginate.offset,
*paginate.limit,
)?;
let workspaces =
workspaces.into_iter().map(|uid| ChatWorkspaceView { uid }).collect::<Vec<_>>();
let ret = paginate.as_pagination().format_with(total, workspaces);
debug!(returns = ?ret, "List chat workspaces");
Ok(HttpResponse::Ok().json(ret))
}

View file

@ -10,21 +10,29 @@ use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::GuardedData;
use crate::extractors::sequential_extractor::SeqHandler;
use super::ChatsParam;
/// Registers the chat settings handlers on the root resource of this scope:
/// `GET` goes to `get_settings`, `PATCH` to `patch_settings` and `DELETE` to
/// `delete_settings` (the latter two wrapped in `SeqHandler`).
pub fn configure(cfg: &mut web::ServiceConfig) {
cfg.service(
web::resource("")
.route(web::get().to(get_settings))
.route(web::patch().to(SeqHandler(patch_settings))),
.route(web::patch().to(SeqHandler(patch_settings)))
.route(web::delete().to(SeqHandler(delete_settings))),
);
}
async fn get_settings(
index_scheduler: GuardedData<
ActionPolicy<{ actions::CHAT_SETTINGS_GET }>,
ActionPolicy<{ actions::CHATS_SETTINGS_GET }>,
Data<IndexScheduler>,
>,
chats_param: web::Path<ChatsParam>,
) -> Result<HttpResponse, ResponseError> {
let mut settings = match index_scheduler.chat_settings()? {
let ChatsParam { workspace_uid } = chats_param.into_inner();
// TODO do a spawn_blocking here ???
let rtxn = index_scheduler.read_txn()?;
let mut settings = match index_scheduler.chat_settings(&rtxn, &workspace_uid)? {
Some(value) => serde_json::from_value(value).unwrap(),
None => GlobalChatSettings::default(),
};
@ -34,12 +42,17 @@ async fn get_settings(
async fn patch_settings(
index_scheduler: GuardedData<
ActionPolicy<{ actions::CHAT_SETTINGS_UPDATE }>,
ActionPolicy<{ actions::CHATS_SETTINGS_UPDATE }>,
Data<IndexScheduler>,
>,
chats_param: web::Path<ChatsParam>,
web::Json(new): web::Json<GlobalChatSettings>,
) -> Result<HttpResponse, ResponseError> {
let old = match index_scheduler.chat_settings()? {
let ChatsParam { workspace_uid } = chats_param.into_inner();
// TODO do a spawn_blocking here
let mut wtxn = index_scheduler.write_txn()?;
let old = match index_scheduler.chat_settings(&mut wtxn, &workspace_uid)? {
Some(value) => serde_json::from_value(value).unwrap(),
None => GlobalChatSettings::default(),
};
@ -64,16 +77,39 @@ async fn patch_settings(
};
let value = serde_json::to_value(settings).unwrap();
index_scheduler.put_chat_settings(&value)?;
index_scheduler.put_chat_settings(&mut wtxn, &workspace_uid, &value)?;
wtxn.commit()?;
Ok(HttpResponse::Ok().finish())
}
/// Delete the chat settings of a workspace
///
/// Answers `200 OK` once the deletion has been committed, or `404 Not Found`
/// when `delete_chat_settings` reports that nothing was deleted for the
/// workspace named in the path.
async fn delete_settings(
    index_scheduler: GuardedData<
        ActionPolicy<{ actions::CHATS_SETTINGS_DELETE }>,
        Data<IndexScheduler>,
    >,
    chats_param: web::Path<ChatsParam>,
) -> Result<HttpResponse, ResponseError> {
    let workspace_uid = chats_param.into_inner().workspace_uid;

    // TODO do a spawn_blocking here
    let mut wtxn = index_scheduler.write_txn()?;
    let deleted = index_scheduler.delete_chat_settings(&mut wtxn, &workspace_uid)?;
    if !deleted {
        // Nothing was removed: answer without committing the transaction.
        return Ok(HttpResponse::NotFound().finish());
    }

    wtxn.commit()?;
    Ok(HttpResponse::Ok().finish())
}
/// The kind of source used for chat; variant names are (de)serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub enum ChatSource {
// NOTE(review): presumably selects the OpenAI API as the chat provider; confirm against the consumer.
OpenAi,
}
// TODO Implement Deserr on that.
// TODO Declare DbGlobalChatSettings (alias it).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct GlobalChatSettings {
@ -114,6 +150,8 @@ impl GlobalChatSettings {
}
}
// TODO Implement Deserr on that.
// TODO Declare DbChatPrompts (alias it).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(deny_unknown_fields, rename_all = "camelCase")]
pub struct ChatPrompts {

View file

@ -172,7 +172,7 @@ pub async fn list_indexes(
debug!(parameters = ?paginate, "List indexes");
let filters = index_scheduler.filters();
let (total, indexes) =
index_scheduler.get_paginated_indexes_stats(filters, *paginate.offset, *paginate.limit)?;
index_scheduler.paginated_indexes_stats(filters, *paginate.offset, *paginate.limit)?;
let indexes = indexes
.into_iter()
.map(|(name, stats)| IndexView {

View file

@ -52,7 +52,7 @@ const PAGINATION_DEFAULT_LIMIT_FN: fn() -> usize = || 20;
mod api_key;
pub mod batches;
pub mod chat;
pub mod chats;
mod dump;
pub mod features;
pub mod indexes;
@ -62,7 +62,6 @@ mod multi_search;
mod multi_search_analytics;
pub mod network;
mod open_api_utils;
pub mod settings;
mod snapshot;
mod swap_indexes;
pub mod tasks;
@ -116,8 +115,7 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
.service(web::scope("/metrics").configure(metrics::configure))
.service(web::scope("/experimental-features").configure(features::configure))
.service(web::scope("/network").configure(network::configure))
.service(web::scope("/chat").configure(chat::configure))
.service(web::scope("/settings/chat").configure(settings::chat::configure));
.service(web::scope("/chats").configure(chats::settings::configure));
#[cfg(feature = "swagger")]
{

View file

@ -1 +0,0 @@
pub mod chat;