Correctly support tenant tokens and filters

This commit is contained in:
Clément Renault 2025-05-20 16:15:49 +02:00
parent 46680585ae
commit 7636365a65
No known key found for this signature in database
GPG key ID: F250A4C4E3AE5F5F
2 changed files with 72 additions and 34 deletions

View file

@ -3,7 +3,7 @@ use std::mem;
use std::time::Duration;
use actix_web::web::{self, Data};
use actix_web::{Either, HttpResponse, Responder};
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event, Sse};
use async_openai::config::OpenAIConfig;
use async_openai::types::{
@ -18,6 +18,7 @@ use async_openai::types::{
use async_openai::Client;
use futures::StreamExt;
use index_scheduler::IndexScheduler;
use meilisearch_auth::AuthController;
use meilisearch_types::error::ResponseError;
use meilisearch_types::keys::actions;
use meilisearch_types::milli::index::IndexEmbeddingConfig;
@ -31,7 +32,7 @@ use tokio::sync::mpsc::error::SendError;
use super::settings::chat::{ChatPrompts, ChatSettings};
use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::GuardedData;
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
use crate::routes::indexes::search::search_kind;
use crate::search::{
@ -48,6 +49,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
/// Get a chat completion
async fn chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
web::Json(chat_completion): web::Json<CreateChatCompletionRequest>,
) -> impl Responder {
@ -61,9 +64,13 @@ async fn chat(
);
if chat_completion.stream.unwrap_or(false) {
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await)
Either::Right(
streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
)
} else {
Either::Left(non_streamed_chat(index_scheduler, search_queue, chat_completion).await)
Either::Left(
non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
)
}
}
@ -115,7 +122,9 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts:
/// Process search request and return formatted results
async fn process_search_request(
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
search_queue: &web::Data<SearchQueue>,
auth_token: &str,
index_uid: String,
q: Option<String>,
) -> Result<(Index, String), ResponseError> {
@ -129,8 +138,14 @@ async fn process_search_request(
..Default::default()
};
let auth_filter = ActionPolicy::<{ actions::SEARCH }>::authenticate(
auth_ctrl,
auth_token,
Some(index_uid.as_str()),
)?;
// Tenant token search_rules.
if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) {
if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) {
add_search_rules(&mut query.filter, search_rules);
}
@ -176,6 +191,8 @@ async fn process_search_request(
async fn non_streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest,
) -> Result<HttpResponse, ResponseError> {
@ -193,6 +210,7 @@ async fn non_streamed_chat(
}
let client = Client::with_config(config);
let auth_token = extract_token_from_request(&req)?.unwrap();
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
let mut response;
@ -219,9 +237,15 @@ async fn non_streamed_chat(
let SearchInIndexParameters { index_uid, q } =
serde_json::from_str(&call.function.arguments).unwrap();
let (_, text) =
process_search_request(&index_scheduler, &search_queue, index_uid, q)
.await?;
let (_, text) = process_search_request(
&index_scheduler,
auth_ctrl.clone(),
&search_queue,
auth_token,
index_uid,
q,
)
.await?;
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage {
@ -246,9 +270,11 @@ async fn non_streamed_chat(
async fn streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest,
) -> impl Responder {
) -> Result<impl Responder, ResponseError> {
let chat_settings = match index_scheduler.chat_settings().unwrap() {
Some(value) => serde_json::from_value(value).unwrap(),
None => ChatSettings::default(),
@ -262,6 +288,7 @@ async fn streamed_chat(
config = config.with_api_base(base_api);
}
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
let (tx, rx) = tokio::sync::mpsc::channel(10);
@ -354,7 +381,9 @@ async fn streamed_chat(
let result = process_search_request(
&index_scheduler,
auth_ctrl.clone(),
&search_queue,
&auth_token,
index_uid,
q,
)
@ -417,7 +446,7 @@ async fn streamed_chat(
let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));
});
Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
}
/// The structure used to aggregate the function calls to make.