diff --git a/Cargo.lock b/Cargo.lock index f802ac4ab..62cd4b300 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1771,6 +1771,7 @@ dependencies = [ "oxidized-json-checker", "parking_lot", "paste", + "pin-project", "rand 0.7.3", "rayon", "regex", diff --git a/meilisearch-http/Cargo.toml b/meilisearch-http/Cargo.toml index e6912e428..e32ebf587 100644 --- a/meilisearch-http/Cargo.toml +++ b/meilisearch-http/Cargo.toml @@ -73,6 +73,7 @@ tokio = { version = "1", features = ["full"] } uuid = { version = "0.8.2", features = ["serde"] } walkdir = "2.3.2" obkv = "0.1.1" +pin-project = "1.0.7" [dependencies.sentry] default-features = false diff --git a/meilisearch-http/src/helpers/authentication.rs b/meilisearch-http/src/helpers/authentication.rs index 9944c0bd4..a1a0c431e 100644 --- a/meilisearch-http/src/helpers/authentication.rs +++ b/meilisearch-http/src/helpers/authentication.rs @@ -1,14 +1,16 @@ -use std::cell::RefCell; use std::pin::Pin; -use std::rc::Rc; use std::task::{Context, Poll}; use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform}; use actix_web::web; -use futures::future::{err, ok, Future, Ready}; +use actix_web::body::Body; +use futures::ready; +use futures::future::{ok, Future, Ready}; +use actix_web::ResponseError as _; +use pin_project::pin_project; -use crate::error::{Error, ResponseError}; use crate::Data; +use crate::error::{Error, ResponseError}; #[derive(Clone, Copy)] pub enum Authentication { @@ -17,13 +19,11 @@ pub enum Authentication { Admin, } -impl Transform for Authentication +impl Transform for Authentication where - S: Service, Error = actix_web::Error>, - S::Future: 'static, - B: 'static, + S: Service, Error = actix_web::Error>, { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = actix_web::Error; type InitError = (); type Transform = LoggingMiddleware; @@ -32,54 +32,45 @@ where fn new_transform(&self, service: S) -> Self::Future { ok(LoggingMiddleware { acl: *self, - service: Rc::new(RefCell::new(service)), + service, }) } } pub struct LoggingMiddleware { acl: Authentication, - service: Rc>, + service: S, } #[allow(clippy::type_complexity)] -impl Service for LoggingMiddleware +impl Service for LoggingMiddleware where - S: Service, Error = actix_web::Error> + 'static, - S::Future: 'static, - B: 'static, + S: Service, Error = actix_web::Error>, { - type Response = ServiceResponse; + type Response = ServiceResponse; type Error = actix_web::Error; - type Future = Pin>>>; + type Future = AuthenticationFuture; fn poll_ready(&self, cx: &mut Context) -> Poll> { self.service.poll_ready(cx) } fn call(&self, req: ServiceRequest) -> Self::Future { - let svc = self.service.clone(); - // This unwrap is left because this error should never appear. If that's the case, then - // it means that actix-web has an issue or someone changes the type `Data`. let data = req.app_data::>().unwrap(); if data.api_keys().master.is_none() { - return Box::pin(svc.call(req)); + return AuthenticationFuture::Authenticated(self.service.call(req)) } let auth_header = match req.headers().get("X-Meili-API-Key") { Some(auth) => match auth.to_str() { Ok(auth) => auth, Err(_) => { - return Box::pin(err( - ResponseError::from(Error::MissingAuthorizationHeader).into() - )) + return AuthenticationFuture::NoHeader(Some(req)) } }, None => { - return Box::pin(err( - ResponseError::from(Error::MissingAuthorizationHeader).into() - )); + return AuthenticationFuture::NoHeader(Some(req)) } }; @@ -97,12 +88,66 @@ where }; if authenticated { - Box::pin(svc.call(req)) + AuthenticationFuture::Authenticated(self.service.call(req)) } else { - Box::pin(err(ResponseError::from(Error::InvalidToken( - auth_header.to_string(), - )) - .into())) + AuthenticationFuture::Refused(Some(req)) + } + } +} + +#[pin_project(project = AuthProj)] +pub enum AuthenticationFuture +where + S: Service, +{ + Authenticated(#[pin] S::Future), + NoHeader(Option), + Refused(Option), +} + +impl Future for AuthenticationFuture +where + S: Service, Error = actix_web::Error>, +{ + type Output = Result, actix_web::Error>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) ->Poll { + let this = self.project(); + match this { + AuthProj::Authenticated(fut) => { + match ready!(fut.poll(cx)) { + Ok(resp) => Poll::Ready(Ok(resp)), + Err(e) => Poll::Ready(Err(e)), + } + } + AuthProj::NoHeader(req) => { + match req.take() { + Some(req) => { + let response = ResponseError::from(Error::MissingAuthorizationHeader); + let response = response.error_response(); + let response = req.into_response(response); + Poll::Ready(Ok(response)) + } + // https://doc.rust-lang.org/nightly/std/future/trait.Future.html#panics + None => unreachable!("poll called again on ready future"), + } + } + AuthProj::Refused(req) => { + match req.take() { + Some(req) => { + let bad_token = req.headers() + .get("X-Meili-API-Key") + .map(|h| h.to_str().map(String::from).unwrap_or_default()) + .unwrap_or_default(); + let response = ResponseError::from(Error::InvalidToken(bad_token)); + let response = response.error_response(); + let response = req.into_response(response); + Poll::Ready(Ok(response)) + } + // https://doc.rust-lang.org/nightly/std/future/trait.Future.html#panics + None => unreachable!("poll called again on ready future"), + } + } } } }