mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-07-04 12:27:13 +02:00
Correctly support tenant tokens and filters
This commit is contained in:
parent
46680585ae
commit
7636365a65
2 changed files with 72 additions and 34 deletions
|
@ -3,7 +3,7 @@ use std::mem;
|
|||
use std::time::Duration;
|
||||
|
||||
use actix_web::web::{self, Data};
|
||||
use actix_web::{Either, HttpResponse, Responder};
|
||||
use actix_web::{Either, HttpRequest, HttpResponse, Responder};
|
||||
use actix_web_lab::sse::{self, Event, Sse};
|
||||
use async_openai::config::OpenAIConfig;
|
||||
use async_openai::types::{
|
||||
|
@ -18,6 +18,7 @@ use async_openai::types::{
|
|||
use async_openai::Client;
|
||||
use futures::StreamExt;
|
||||
use index_scheduler::IndexScheduler;
|
||||
use meilisearch_auth::AuthController;
|
||||
use meilisearch_types::error::ResponseError;
|
||||
use meilisearch_types::keys::actions;
|
||||
use meilisearch_types::milli::index::IndexEmbeddingConfig;
|
||||
|
@ -31,7 +32,7 @@ use tokio::sync::mpsc::error::SendError;
|
|||
|
||||
use super::settings::chat::{ChatPrompts, ChatSettings};
|
||||
use crate::extractors::authentication::policies::ActionPolicy;
|
||||
use crate::extractors::authentication::GuardedData;
|
||||
use crate::extractors::authentication::{extract_token_from_request, GuardedData, Policy as _};
|
||||
use crate::metrics::MEILISEARCH_DEGRADED_SEARCH_REQUESTS;
|
||||
use crate::routes::indexes::search::search_kind;
|
||||
use crate::search::{
|
||||
|
@ -48,6 +49,8 @@ pub fn configure(cfg: &mut web::ServiceConfig) {
|
|||
/// Get a chat completion
|
||||
async fn chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
req: HttpRequest,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
web::Json(chat_completion): web::Json<CreateChatCompletionRequest>,
|
||||
) -> impl Responder {
|
||||
|
@ -61,9 +64,13 @@ async fn chat(
|
|||
);
|
||||
|
||||
if chat_completion.stream.unwrap_or(false) {
|
||||
Either::Right(streamed_chat(index_scheduler, search_queue, chat_completion).await)
|
||||
Either::Right(
|
||||
streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
|
||||
)
|
||||
} else {
|
||||
Either::Left(non_streamed_chat(index_scheduler, search_queue, chat_completion).await)
|
||||
Either::Left(
|
||||
non_streamed_chat(index_scheduler, auth_ctrl, req, search_queue, chat_completion).await,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -115,7 +122,9 @@ fn setup_search_tool(chat_completion: &mut CreateChatCompletionRequest, prompts:
|
|||
/// Process search request and return formatted results
|
||||
async fn process_search_request(
|
||||
index_scheduler: &GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
search_queue: &web::Data<SearchQueue>,
|
||||
auth_token: &str,
|
||||
index_uid: String,
|
||||
q: Option<String>,
|
||||
) -> Result<(Index, String), ResponseError> {
|
||||
|
@ -129,8 +138,14 @@ async fn process_search_request(
|
|||
..Default::default()
|
||||
};
|
||||
|
||||
let auth_filter = ActionPolicy::<{ actions::SEARCH }>::authenticate(
|
||||
auth_ctrl,
|
||||
auth_token,
|
||||
Some(index_uid.as_str()),
|
||||
)?;
|
||||
|
||||
// Tenant token search_rules.
|
||||
if let Some(search_rules) = index_scheduler.filters().get_index_search_rules(&index_uid) {
|
||||
if let Some(search_rules) = auth_filter.get_index_search_rules(&index_uid) {
|
||||
add_search_rules(&mut query.filter, search_rules);
|
||||
}
|
||||
|
||||
|
@ -176,6 +191,8 @@ async fn process_search_request(
|
|||
|
||||
async fn non_streamed_chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
req: HttpRequest,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
mut chat_completion: CreateChatCompletionRequest,
|
||||
) -> Result<HttpResponse, ResponseError> {
|
||||
|
@ -193,6 +210,7 @@ async fn non_streamed_chat(
|
|||
}
|
||||
let client = Client::with_config(config);
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap();
|
||||
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
|
||||
|
||||
let mut response;
|
||||
|
@ -219,9 +237,15 @@ async fn non_streamed_chat(
|
|||
let SearchInIndexParameters { index_uid, q } =
|
||||
serde_json::from_str(&call.function.arguments).unwrap();
|
||||
|
||||
let (_, text) =
|
||||
process_search_request(&index_scheduler, &search_queue, index_uid, q)
|
||||
.await?;
|
||||
let (_, text) = process_search_request(
|
||||
&index_scheduler,
|
||||
auth_ctrl.clone(),
|
||||
&search_queue,
|
||||
auth_token,
|
||||
index_uid,
|
||||
q,
|
||||
)
|
||||
.await?;
|
||||
|
||||
chat_completion.messages.push(ChatCompletionRequestMessage::Tool(
|
||||
ChatCompletionRequestToolMessage {
|
||||
|
@ -246,9 +270,11 @@ async fn non_streamed_chat(
|
|||
|
||||
async fn streamed_chat(
|
||||
index_scheduler: GuardedData<ActionPolicy<{ actions::CHAT }>, Data<IndexScheduler>>,
|
||||
auth_ctrl: web::Data<AuthController>,
|
||||
req: HttpRequest,
|
||||
search_queue: web::Data<SearchQueue>,
|
||||
mut chat_completion: CreateChatCompletionRequest,
|
||||
) -> impl Responder {
|
||||
) -> Result<impl Responder, ResponseError> {
|
||||
let chat_settings = match index_scheduler.chat_settings().unwrap() {
|
||||
Some(value) => serde_json::from_value(value).unwrap(),
|
||||
None => ChatSettings::default(),
|
||||
|
@ -262,6 +288,7 @@ async fn streamed_chat(
|
|||
config = config.with_api_base(base_api);
|
||||
}
|
||||
|
||||
let auth_token = extract_token_from_request(&req)?.unwrap().to_string();
|
||||
setup_search_tool(&mut chat_completion, &chat_settings.prompts);
|
||||
|
||||
let (tx, rx) = tokio::sync::mpsc::channel(10);
|
||||
|
@ -354,7 +381,9 @@ async fn streamed_chat(
|
|||
|
||||
let result = process_search_request(
|
||||
&index_scheduler,
|
||||
auth_ctrl.clone(),
|
||||
&search_queue,
|
||||
&auth_token,
|
||||
index_uid,
|
||||
q,
|
||||
)
|
||||
|
@ -417,7 +446,7 @@ async fn streamed_chat(
|
|||
let _ = tx.send(Event::Data(sse::Data::new("[DONE]")));
|
||||
});
|
||||
|
||||
Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10))
|
||||
Ok(Sse::from_infallible_receiver(rx).with_retry_duration(Duration::from_secs(10)))
|
||||
}
|
||||
|
||||
/// The structure used to aggregate the function calls to make.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue