Correctly support tenant tokens and filters

This commit is contained in:
Clément Renault 2025-05-20 16:15:49 +02:00
parent bae6c98aa3
commit 7266aed770
No known key found for this signature in database
GPG Key ID: F250A4C4E3AE5F5F
2 changed files with 72 additions and 34 deletions

View File

@ -4,6 +4,7 @@ use std::marker::PhantomData;
use std::ops::Deref; use std::ops::Deref;
use std::pin::Pin; use std::pin::Pin;
use actix_web::http::header::AUTHORIZATION;
use actix_web::web::Data; use actix_web::web::Data;
use actix_web::FromRequest; use actix_web::FromRequest;
pub use error::AuthenticationError; pub use error::AuthenticationError;
@ -94,36 +95,44 @@ impl<P: Policy + 'static, D: 'static + Clone> FromRequest for GuardedData<P, D>
_payload: &mut actix_web::dev::Payload, _payload: &mut actix_web::dev::Payload,
) -> Self::Future { ) -> Self::Future {
match req.app_data::<Data<AuthController>>().cloned() { match req.app_data::<Data<AuthController>>().cloned() {
Some(auth) => match req Some(auth) => match extract_token_from_request(req) {
.headers() Ok(Some(token)) => {
.get("Authorization") // TODO: find a less hardcoded way?
.map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' ')) let index = req.match_info().get("index_uid");
{ Box::pin(Self::auth_bearer(
Some(mut type_token) => match type_token.next() { auth,
Some("Bearer") => { token.to_string(),
// TODO: find a less hardcoded way? index.map(String::from),
let index = req.match_info().get("index_uid"); req.app_data::<D>().cloned(),
match type_token.next() { ))
Some(token) => Box::pin(Self::auth_bearer( }
auth, Ok(None) => Box::pin(Self::auth_token(auth, req.app_data::<D>().cloned())),
token.to_string(), Err(e) => Box::pin(err(e.into())),
index.map(String::from),
req.app_data::<D>().cloned(),
)),
None => Box::pin(err(AuthenticationError::InvalidToken.into())),
}
}
_otherwise => {
Box::pin(err(AuthenticationError::MissingAuthorizationHeader.into()))
}
},
None => Box::pin(Self::auth_token(auth, req.app_data::<D>().cloned())),
}, },
None => Box::pin(err(AuthenticationError::IrretrievableState.into())), None => Box::pin(err(AuthenticationError::IrretrievableState.into())),
} }
} }
} }
pub fn extract_token_from_request(
req: &actix_web::HttpRequest,
) -> Result<Option<&str>, AuthenticationError> {
match req
.headers()
.get(AUTHORIZATION)
.map(|type_token| type_token.to_str().unwrap_or_default().splitn(2, ' '))
{
Some(mut type_token) => match type_token.next() {
Some("Bearer") => match type_token.next() {
Some(token) => Ok(Some(token)),
None => Err(AuthenticationError::InvalidToken),
},
_otherwise => Err(AuthenticationError::MissingAuthorizationHeader),
},
None => Ok(None),
}
}
pub trait Policy { pub trait Policy {
fn authenticate( fn authenticate(
auth: Data<AuthController>, auth: Data<AuthController>,

View File

@ -3,7 +3,7 @@ use std::mem;
use std::time::Duration; use std::time::Duration;
use actix_web::web::{self, Data}; use actix_web::web::{self, Data};
use actix_web::{Either, HttpResponse, Responder}; use actix_web::{Either, HttpRequest, HttpResponse, Responder};
use actix_web_lab::sse::{self, Event, Sse}; use actix_web_lab::sse::{self, Event, Sse};
use async_openai::config::OpenAIConfig; use async_openai::config::OpenAIConfig;
use async_openai::types::{ use async_openai::types::{
@ -18,6 +18,7 @@ use async_openai::types::{
use async_openai::Client; use async_openai::Client;
use futures::StreamExt; use futures::StreamExt;
use index_scheduler::IndexScheduler; use index_scheduler::IndexScheduler;
use meilisearch_auth::AuthController;
use meilisearch_types::error::ResponseError; use meilisearch_types::error::ResponseError;
use meilisearch_types::keys::actions; use meilisearch_types::keys::actions;
use meilisearch_types::milli::index::IndexEmbeddingConfig; use meilisearch_types::milli::index::IndexEmbeddingConfig;
@ -31,7 +32,7 @@ use tokio::sync::mpsc::error::SendError;
use super::settings::chat::{ChatPrompts, ChatSettings}; use super::settings::chat::{ChatPrompts, ChatSettings};
use crate::extractors::authentication::policies::ActionPolicy; use crate::extractors::authentication::policies::ActionPolicy;
use crate::extractors::authentication::GuardedData; use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS; use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
use crate::routes::indexes::search::search_kind; use crate::routes::indexes::search::search_kind;
use crate::search::{ use crate::search::{
@ -48,6 +49,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
/// Get a chat completion /// Get a chat completion
async fn chat( async fn chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>, index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>, search_queue: web::Data<SearchQueue>,
web::Json(chat_completion): web::Json<CreateChatCompletionRequest>, web::Json(chat_completion): web::Json<CreateChatCompletionRequest>,
) -> impl Responder { ) -> impl Responder {
@ -61,9 +64,13 @@ async fn chat(
); );
if chat_completion.stream.unwrap_or(false) { if chat_completion.stream.unwrap_or(false) {
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await) Either::Right(
streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
)
} else { } else {
Either::Left(non_streamed_chat(index_scheduler, search_queue, chat_completion).await) Either::Left(
non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
)
} }
} }
@ -115,7 +122,9 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts:
/// Process search request and return formatted results /// Process search request and return formatted results
async fn process_search_request( async fn process_search_request(
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>, index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
search_queue: &web::Data<SearchQueue>, search_queue: &web::Data<SearchQueue>,
auth_token: &str,
index_uid: String, index_uid: String,
q: Option<String>, q: Option<String>,
) -> Result<(Index, String), ResponseError> { ) -> Result<(Index, String), ResponseError> {
@ -129,8 +138,14 @@ async fn process_search_request(
..Default::default() ..Default::default()
}; };
let auth_filter = ActionPolicy::<{ actions::SEARCH }>::authenticate(
auth_ctrl,
auth_token,
Some(index_uid.as_str()),
)?;
// Tenant token search_rules. // Tenant token search_rules.
if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) { if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) {
add_search_rules(&mut query.filter, search_rules); add_search_rules(&mut query.filter, search_rules);
} }
@ -176,6 +191,8 @@ async fn process_search_request(
async fn non_streamed_chat( async fn non_streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>, index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>, search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest, mut chat_completion: CreateChatCompletionRequest,
) -> Result<HttpResponse, ResponseError> { ) -> Result<HttpResponse, ResponseError> {
@ -193,6 +210,7 @@ async fn non_streamed_chat(
} }
let client = Client::with_config(config); let client = Client::with_config(config);
let auth_token = extract_token_from_request(&req)?.unwrap();
setup_search_tool(&mut chat_completion, &chat_settings.prompts); setup_search_tool(&mut chat_completion, &chat_settings.prompts);
let mut response; let mut response;
@ -219,9 +237,15 @@ async fn non_streamed_chat(
let SearchInIndexParameters { index_uid, q } = let SearchInIndexParameters { index_uid, q } =
serde_json::from_str(&call.function.arguments).unwrap(); serde_json::from_str(&call.function.arguments).unwrap();
let (_, text) = let (_, text) = process_search_request(
process_search_request(&index_scheduler, &search_queue, index_uid, q) &index_scheduler,
.await?; auth_ctrl.clone(),
&search_queue,
auth_token,
index_uid,
q,
)
.await?;
chat_completion.messages.push(ChatCompletionRequestMessage::Tool( chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
ChatCompletionRequestToolMessage { ChatCompletionRequestToolMessage {
@ -246,9 +270,11 @@ async fn non_streamed_chat(
async fn streamed_chat( async fn streamed_chat(
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>, index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
auth_ctrl: web::Data<AuthController>,
req: HttpRequest,
search_queue: web::Data<SearchQueue>, search_queue: web::Data<SearchQueue>,
mut chat_completion: CreateChatCompletionRequest, mut chat_completion: CreateChatCompletionRequest,
) -> impl Responder { ) -> Result<impl Responder, ResponseError> {
let chat_settings = match index_scheduler.chat_settings().unwrap() { let chat_settings = match index_scheduler.chat_settings().unwrap() {
Some(value) => serde_json::from_value(value).unwrap(), Some(value) => serde_json::from_value(value).unwrap(),
None => ChatSettings::default(), None => ChatSettings::default(),
@ -262,6 +288,7 @@ async fn streamed_chat(
config = config.with_api_base(base_api); config = config.with_api_base(base_api);
} }
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
setup_search_tool(&mut chat_completion, &chat_settings.prompts); setup_search_tool(&mut chat_completion, &chat_settings.prompts);
let (tx, rx) = tokio::sync::mpsc::channel(10); let (tx, rx) = tokio::sync::mpsc::channel(10);
@ -354,7 +381,9 @@ async fn streamed_chat(
let result = process_search_request( let result = process_search_request(
&index_scheduler, &index_scheduler,
auth_ctrl.clone(),
&search_queue, &search_queue,
&auth_token,
index_uid, index_uid,
q, q,
) )
@ -417,7 +446,7 @@ async fn streamed_chat(
let _ = tx.send(Event::Data(sse::Data::new("[DONE]"))); let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));
}); });
Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)) Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
} }
/// The structure used to aggregate the function calls to make. /// The structure used to aggregate the function calls to make.