From b57c59baa4f0683d521d91d7d9e23dd1f98dd08b Mon Sep 17 00:00:00 2001 From: ad hoc Date: Fri, 4 Mar 2022 20:12:44 +0100 Subject: [PATCH] sequential extractor --- Cargo.lock | 2 +- meilisearch-http/Cargo.toml | 2 +- .../src/extractors/authentication/mod.rs | 5 +- meilisearch-http/src/extractors/mod.rs | 1 + .../src/extractors/sequential_extractor.rs | 148 ++++++++++++++++++ meilisearch-http/src/routes/api_key.rs | 15 +- meilisearch-http/src/routes/dump.rs | 7 +- .../src/routes/indexes/documents.rs | 15 +- meilisearch-http/src/routes/indexes/mod.rs | 11 +- meilisearch-http/src/routes/indexes/search.rs | 5 +- .../src/routes/indexes/settings.rs | 14 +- meilisearch-http/src/routes/indexes/tasks.rs | 5 +- meilisearch-http/src/routes/tasks.rs | 5 +- meilisearch-http/tests/auth/authorization.rs | 1 + 14 files changed, 198 insertions(+), 38 deletions(-) create mode 100644 meilisearch-http/src/extractors/sequential_extractor.rs diff --git a/Cargo.lock b/Cargo.lock index 1ba427a1e..815fdf3b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1780,7 +1780,7 @@ dependencies = [ "once_cell", "parking_lot 0.11.2", "paste", - "pin-project", + "pin-project-lite", "platform-dirs", "rand", "rayon", diff --git a/meilisearch-http/Cargo.toml b/meilisearch-http/Cargo.toml index c3dd1026f..da7d9e61a 100644 --- a/meilisearch-http/Cargo.toml +++ b/meilisearch-http/Cargo.toml @@ -54,7 +54,6 @@ num_cpus = "1.13.0" obkv = "0.2.0" once_cell = "1.8.0" parking_lot = "0.11.2" -pin-project = "1.0.8" platform-dirs = "0.3.0" rand = "0.8.4" rayon = "1.5.1" @@ -78,6 +77,7 @@ tokio = { version = "1.11.0", features = ["full"] } tokio-stream = "0.1.7" uuid = { version = "0.8.2", features = ["serde"] } walkdir = "2.3.2" +pin-project-lite = "0.2.8" [dev-dependencies] actix-rt = "2.2.0" diff --git a/meilisearch-http/src/extractors/authentication/mod.rs b/meilisearch-http/src/extractors/authentication/mod.rs index ab0030fc1..873f7cbcd 100644 --- a/meilisearch-http/src/extractors/authentication/mod.rs +++ b/meilisearch-http/src/extractors/authentication/mod.rs @@ -41,10 +41,7 @@ impl GuardedData { }), None => Err(AuthenticationError::IrretrievableState.into()), }, - (token, None) => { - let token = token.to_string(); - Err(AuthenticationError::InvalidToken(token).into()) - } + (token, None) => Err(AuthenticationError::InvalidToken(token).into()), } } diff --git a/meilisearch-http/src/extractors/mod.rs b/meilisearch-http/src/extractors/mod.rs index 09a56e4a0..98a22f8c9 100644 --- a/meilisearch-http/src/extractors/mod.rs +++ b/meilisearch-http/src/extractors/mod.rs @@ -1,3 +1,4 @@ pub mod payload; #[macro_use] pub mod authentication; +pub mod sequential_extractor; diff --git a/meilisearch-http/src/extractors/sequential_extractor.rs b/meilisearch-http/src/extractors/sequential_extractor.rs new file mode 100644 index 000000000..1176334ad --- /dev/null +++ b/meilisearch-http/src/extractors/sequential_extractor.rs @@ -0,0 +1,148 @@ +#![allow(non_snake_case)] +use std::{future::Future, pin::Pin, task::Poll}; + +use actix_web::{dev::Payload, FromRequest, Handler, HttpRequest}; +use pin_project_lite::pin_project; + +/// `SeqHandler` is an actix `Handler` that enforces that extractors errors are returned in the +/// same order as they are defined in the wrapped handler. This is needed because, by default, actix +/// to resolves the extractors concurrently, whereas we always need the authentication extractor to +/// throw first. +#[derive(Clone)] +pub struct SeqHandler(pub H); + +pub struct SeqFromRequest(T); + +/// This macro implements `FromRequest` for arbitrary arity handler, except for one, which is +/// useless anyway. +macro_rules! gen_seq { + ($ty:ident; $($T:ident)+) => { + pin_project! { + pub struct $ty<$($T: FromRequest), +> { + $( + #[pin] + $T: ExtractFuture<$T::Future, $T, $T::Error>, + )+ + } + } + + impl<$($T: FromRequest), +> Future for $ty<$($T),+> { + type Output = Result, actix_web::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + let mut this = self.project(); + + let mut count_fut = 0; + let mut count_finished = 0; + + $( + count_fut += 1; + match this.$T.as_mut().project() { + ExtractProj::Future { fut } => match fut.poll(cx) { + Poll::Ready(Ok(output)) => { + count_finished += 1; + let _ = this + .$T + .as_mut() + .project_replace(ExtractFuture::Done { output }); + } + Poll::Ready(Err(error)) => { + count_finished += 1; + let _ = this + .$T + .as_mut() + .project_replace(ExtractFuture::Error { error }); + } + Poll::Pending => (), + }, + ExtractProj::Done { .. } => count_finished += 1, + ExtractProj::Error { .. } => { + // short circuit if all previous are finished and we had an error. + if count_finished == count_fut { + match this.$T.project_replace(ExtractFuture::Empty) { + ExtractReplaceProj::Error { error } => { + return Poll::Ready(Err(error.into())) + } + _ => unreachable!("Invalid future state"), + } + } else { + count_finished += 1; + } + } + ExtractProj::Empty => unreachable!("From request polled after being finished. {}", stringify!($T)), + } + )+ + + if count_fut == count_finished { + let result = ( + $( + match this.$T.project_replace(ExtractFuture::Empty) { + ExtractReplaceProj::Done { output } => output, + ExtractReplaceProj::Error { error } => return Poll::Ready(Err(error.into())), + _ => unreachable!("Invalid future state"), + }, + )+ + ); + + Poll::Ready(Ok(SeqFromRequest(result))) + } else { + Poll::Pending + } + } + } + + impl<$($T: FromRequest,)+> FromRequest for SeqFromRequest<($($T,)+)> { + type Error = actix_web::Error; + + type Future = $ty<$($T),+>; + + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + $ty { + $( + $T: ExtractFuture::Future { + fut: $T::from_request(req, payload), + }, + )+ + } + } + } + + impl Handler> for SeqHandler + where + Han: Handler<($($T),+)>, + { + type Output = Han::Output; + type Future = Han::Future; + + fn call(&self, args: SeqFromRequest<($($T),+)>) -> Self::Future { + self.0.call(args.0) + } + } + }; +} + +// Not working for a single argument, but then, it is not really necessary. +// gen_seq! { SeqFromRequestFut1; A } +gen_seq! { SeqFromRequestFut2; A B } +gen_seq! { SeqFromRequestFut3; A B C } +gen_seq! { SeqFromRequestFut4; A B C D } +gen_seq! { SeqFromRequestFut5; A B C D E } +gen_seq! { SeqFromRequestFut6; A B C D E F } + +pin_project! { + #[project = ExtractProj] + #[project_replace = ExtractReplaceProj] + enum ExtractFuture { + Future { + #[pin] + fut: Fut, + }, + Done { + output: Res, + }, + Error { + error: Err, + }, + Empty, + } +} diff --git a/meilisearch-http/src/routes/api_key.rs b/meilisearch-http/src/routes/api_key.rs index aea35e68e..310b09c4d 100644 --- a/meilisearch-http/src/routes/api_key.rs +++ b/meilisearch-http/src/routes/api_key.rs @@ -7,20 +7,23 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use time::OffsetDateTime; -use crate::extractors::authentication::{policies::*, GuardedData}; +use crate::extractors::{ + authentication::{policies::*, GuardedData}, + sequential_extractor::SeqHandler, +}; use meilisearch_error::{Code, ResponseError}; pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service( web::resource("") - .route(web::post().to(create_api_key)) - .route(web::get().to(list_api_keys)), + .route(web::post().to(SeqHandler(create_api_key))) + .route(web::get().to(SeqHandler(list_api_keys))), ) .service( web::resource("/{api_key}") - .route(web::get().to(get_api_key)) - .route(web::patch().to(patch_api_key)) - .route(web::delete().to(delete_api_key)), + .route(web::get().to(SeqHandler(get_api_key))) + .route(web::patch().to(SeqHandler(patch_api_key))) + .route(web::delete().to(SeqHandler(delete_api_key))), ); } diff --git a/meilisearch-http/src/routes/dump.rs b/meilisearch-http/src/routes/dump.rs index 0627ea378..65cd7521f 100644 --- a/meilisearch-http/src/routes/dump.rs +++ b/meilisearch-http/src/routes/dump.rs @@ -7,10 +7,13 @@ use serde_json::json; use crate::analytics::Analytics; use crate::extractors::authentication::{policies::*, GuardedData}; +use crate::extractors::sequential_extractor::SeqHandler; pub fn configure(cfg: &mut web::ServiceConfig) { - cfg.service(web::resource("").route(web::post().to(create_dump))) - .service(web::resource("/{dump_uid}/status").route(web::get().to(get_dump_status))); + cfg.service(web::resource("").route(web::post().to(SeqHandler(create_dump)))) + .service( + web::resource("/{dump_uid}/status").route(web::get().to(SeqHandler(get_dump_status))), + ); } pub async fn create_dump( diff --git a/meilisearch-http/src/routes/indexes/documents.rs b/meilisearch-http/src/routes/indexes/documents.rs index d18c600af..66551ec77 100644 --- a/meilisearch-http/src/routes/indexes/documents.rs +++ b/meilisearch-http/src/routes/indexes/documents.rs @@ -20,6 +20,7 @@ use crate::analytics::Analytics; use crate::error::MeilisearchHttpError; use crate::extractors::authentication::{policies::*, GuardedData}; use crate::extractors::payload::Payload; +use crate::extractors::sequential_extractor::SeqHandler; use crate::task::SummarizedTaskView; const DEFAULT_RETRIEVE_DOCUMENTS_OFFSET: usize = 0; @@ -71,17 +72,17 @@ pub struct DocumentParam { pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service( web::resource("") - .route(web::get().to(get_all_documents)) - .route(web::post().to(add_documents)) - .route(web::put().to(update_documents)) - .route(web::delete().to(clear_all_documents)), + .route(web::get().to(SeqHandler(get_all_documents))) + .route(web::post().to(SeqHandler(add_documents))) + .route(web::put().to(SeqHandler(update_documents))) + .route(web::delete().to(SeqHandler(clear_all_documents))), ) // this route needs to be before the /documents/{document_id} to match properly - .service(web::resource("/delete-batch").route(web::post().to(delete_documents))) + .service(web::resource("/delete-batch").route(web::post().to(SeqHandler(delete_documents)))) .service( web::resource("/{document_id}") - .route(web::get().to(get_document)) - .route(web::delete().to(delete_document)), + .route(web::get().to(SeqHandler(get_document))) + .route(web::delete().to(SeqHandler(delete_document))), ); } diff --git a/meilisearch-http/src/routes/indexes/mod.rs b/meilisearch-http/src/routes/indexes/mod.rs index 50e54e6b4..bd74fd724 100644 --- a/meilisearch-http/src/routes/indexes/mod.rs +++ b/meilisearch-http/src/routes/indexes/mod.rs @@ -9,6 +9,7 @@ use time::OffsetDateTime; use crate::analytics::Analytics; use crate::extractors::authentication::{policies::*, GuardedData}; +use crate::extractors::sequential_extractor::SeqHandler; use crate::task::SummarizedTaskView; pub mod documents; @@ -20,17 +21,17 @@ pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service( web::resource("") .route(web::get().to(list_indexes)) - .route(web::post().to(create_index)), + .route(web::post().to(SeqHandler(create_index))), ) .service( web::scope("/{index_uid}") .service( web::resource("") - .route(web::get().to(get_index)) - .route(web::put().to(update_index)) - .route(web::delete().to(delete_index)), + .route(web::get().to(SeqHandler(get_index))) + .route(web::put().to(SeqHandler(update_index))) + .route(web::delete().to(SeqHandler(delete_index))), ) - .service(web::resource("/stats").route(web::get().to(get_index_stats))) + .service(web::resource("/stats").route(web::get().to(SeqHandler(get_index_stats)))) .service(web::scope("/documents").configure(documents::configure)) .service(web::scope("/search").configure(search::configure)) .service(web::scope("/tasks").configure(tasks::configure)) diff --git a/meilisearch-http/src/routes/indexes/search.rs b/meilisearch-http/src/routes/indexes/search.rs index a1695633e..14b3f74f5 100644 --- a/meilisearch-http/src/routes/indexes/search.rs +++ b/meilisearch-http/src/routes/indexes/search.rs @@ -9,12 +9,13 @@ use serde_json::Value; use crate::analytics::{Analytics, SearchAggregator}; use crate::extractors::authentication::{policies::*, GuardedData}; +use crate::extractors::sequential_extractor::SeqHandler; pub fn configure(cfg: &mut web::ServiceConfig) { cfg.service( web::resource("") - .route(web::get().to(search_with_url_query)) - .route(web::post().to(search_with_post)), + .route(web::get().to(SeqHandler(search_with_url_query))) + .route(web::post().to(SeqHandler(search_with_post))), ); } diff --git a/meilisearch-http/src/routes/indexes/settings.rs b/meilisearch-http/src/routes/indexes/settings.rs index 8b38072c4..eeb3e71b3 100644 --- a/meilisearch-http/src/routes/indexes/settings.rs +++ b/meilisearch-http/src/routes/indexes/settings.rs @@ -23,6 +23,7 @@ macro_rules! make_setting_route { use crate::analytics::Analytics; use crate::extractors::authentication::{policies::*, GuardedData}; + use crate::extractors::sequential_extractor::SeqHandler; use crate::task::SummarizedTaskView; use meilisearch_error::ResponseError; @@ -98,9 +99,9 @@ macro_rules! make_setting_route { pub fn resources() -> Resource { Resource::new($route) - .route(web::get().to(get)) - .route(web::post().to(update)) - .route(web::delete().to(delete)) + .route(web::get().to(SeqHandler(get))) + .route(web::post().to(SeqHandler(update))) + .route(web::delete().to(SeqHandler(delete))) } } }; @@ -226,11 +227,12 @@ make_setting_route!( macro_rules! generate_configure { ($($mod:ident),*) => { pub fn configure(cfg: &mut web::ServiceConfig) { + use crate::extractors::sequential_extractor::SeqHandler; cfg.service( web::resource("") - .route(web::post().to(update_all)) - .route(web::get().to(get_all)) - .route(web::delete().to(delete_all))) + .route(web::post().to(SeqHandler(update_all))) + .route(web::get().to(SeqHandler(get_all))) + .route(web::delete().to(SeqHandler(delete_all)))) $(.service($mod::resources()))*; } }; diff --git a/meilisearch-http/src/routes/indexes/tasks.rs b/meilisearch-http/src/routes/indexes/tasks.rs index 8545831a0..01ed85db8 100644 --- a/meilisearch-http/src/routes/indexes/tasks.rs +++ b/meilisearch-http/src/routes/indexes/tasks.rs @@ -8,11 +8,12 @@ use time::OffsetDateTime; use crate::analytics::Analytics; use crate::extractors::authentication::{policies::*, GuardedData}; +use crate::extractors::sequential_extractor::SeqHandler; use crate::task::{TaskListView, TaskView}; pub fn configure(cfg: &mut web::ServiceConfig) { - cfg.service(web::resource("").route(web::get().to(get_all_tasks_status))) - .service(web::resource("{task_id}").route(web::get().to(get_task_status))); + cfg.service(web::resource("").route(web::get().to(SeqHandler(get_all_tasks_status)))) + .service(web::resource("{task_id}").route(web::get().to(SeqHandler(get_task_status)))); } #[derive(Debug, Serialize)] diff --git a/meilisearch-http/src/routes/tasks.rs b/meilisearch-http/src/routes/tasks.rs index 350cef3dc..ae932253a 100644 --- a/meilisearch-http/src/routes/tasks.rs +++ b/meilisearch-http/src/routes/tasks.rs @@ -7,11 +7,12 @@ use serde_json::json; use crate::analytics::Analytics; use crate::extractors::authentication::{policies::*, GuardedData}; +use crate::extractors::sequential_extractor::SeqHandler; use crate::task::{TaskListView, TaskView}; pub fn configure(cfg: &mut web::ServiceConfig) { - cfg.service(web::resource("").route(web::get().to(get_tasks))) - .service(web::resource("/{task_id}").route(web::get().to(get_task))); + cfg.service(web::resource("").route(web::get().to(SeqHandler(get_tasks)))) + .service(web::resource("/{task_id}").route(web::get().to(SeqHandler(get_task)))); } async fn get_tasks( diff --git a/meilisearch-http/tests/auth/authorization.rs b/meilisearch-http/tests/auth/authorization.rs index 30df2dd2d..dcb4504c6 100644 --- a/meilisearch-http/tests/auth/authorization.rs +++ b/meilisearch-http/tests/auth/authorization.rs @@ -91,6 +91,7 @@ async fn error_access_expired_key() { thread::sleep(time::Duration::new(1, 0)); for (method, route) in AUTHORIZATIONS.keys() { + dbg!(route); let (response, code) = server.dummy_request(method, route).await; assert_eq!(response, INVALID_RESPONSE.clone());