From b82c86c8f5fb49e3f52ec30182779ba1c6a219a3 Mon Sep 17 00:00:00 2001 From: Kerollmops Date: Mon, 30 May 2022 13:59:27 +0200 Subject: [PATCH] Allow users to filter indexUid with a * --- meilisearch-http/src/routes/tasks.rs | 43 ++++++++++++++++++++++++++-- meilisearch-http/tests/tasks/mod.rs | 25 ++++++++++++++++ 2 files changed, 65 insertions(+), 3 deletions(-) diff --git a/meilisearch-http/src/routes/tasks.rs b/meilisearch-http/src/routes/tasks.rs index 34132db0d..93af5af26 100644 --- a/meilisearch-http/src/routes/tasks.rs +++ b/meilisearch-http/src/routes/tasks.rs @@ -6,6 +6,7 @@ use meilisearch_lib::{IndexUid, MeiliSearch}; use serde::Deserialize; use serde_cs::vec::CS; use serde_json::json; +use std::str::FromStr; use crate::analytics::Analytics; use crate::extractors::authentication::{policies::*, GuardedData}; @@ -23,7 +24,26 @@ pub struct TaskFilterQuery { #[serde(rename = "type")] type_: Option>, status: Option>, - index_uid: Option>, + index_uid: Option>, +} + +/// A type that tries to match either a star (*) or an IndexUid. +#[derive(Debug)] +enum StarOrIndexUid { + Star, + IndexUid(IndexUid), +} + +impl FromStr for StarOrIndexUid { + type Err = ::Err; + + fn from_str(s: &str) -> Result { + if s.trim() == "*" { + Ok(StarOrIndexUid::Star) + } else { + IndexUid::from_str(s).map(StarOrIndexUid::IndexUid) + } + } } #[rustfmt::skip] @@ -70,12 +90,29 @@ async fn get_tasks( let search_rules = &meilisearch.filters().search_rules; - // We first filter on potential indexes and make sure + // We first tranform a potential indexUid=* into a + // "not specified indexUid filter". + let index_uid = + match index_uid { + Some(indexes) => indexes + .into_inner() + .into_iter() + .fold(Some(Vec::new()), |acc, val| match (acc, val) { + (None, _) | (_, StarOrIndexUid::Star) => None, + (Some(mut acc), StarOrIndexUid::IndexUid(uid)) => { + acc.push(uid); + Some(acc) + } + }), + None => None, + }; + + // Then we filter on potential indexes and make sure // that the search filter restrictions are also applied. let indexes_filters = match index_uid { Some(indexes) => { let mut filters = TaskFilter::default(); - for name in indexes.into_inner() { + for name in indexes { if search_rules.is_index_authorized(&name) { filters.filter_index(name.to_string()); } diff --git a/meilisearch-http/tests/tasks/mod.rs b/meilisearch-http/tests/tasks/mod.rs index 80bf6cb3d..b14491fd2 100644 --- a/meilisearch-http/tests/tasks/mod.rs +++ b/meilisearch-http/tests/tasks/mod.rs @@ -59,6 +59,31 @@ async fn list_tasks() { assert_eq!(response["results"].as_array().unwrap().len(), 2); } +#[actix_rt::test] +async fn list_tasks_with_index_filter() { + let server = Server::new().await; + let index = server.index("test"); + index.create(None).await; + index.wait_task(0).await; + index + .add_documents( + serde_json::from_str(include_str!("../assets/test_set.json")).unwrap(), + None, + ) + .await; + let (response, code) = index.service.get("/tasks?indexUid=test").await; + assert_eq!(code, 200); + assert_eq!(response["results"].as_array().unwrap().len(), 2); + + let (response, code) = index.service.get("/tasks?indexUid=*").await; + assert_eq!(code, 200); + assert_eq!(response["results"].as_array().unwrap().len(), 2); + + let (response, code) = index.service.get("/tasks?indexUid=*,pasteque").await; + assert_eq!(code, 200); + assert_eq!(response["results"].as_array().unwrap().len(), 2); +} + #[actix_rt::test] async fn list_tasks_status_filtered() { let server = Server::new().await;