mirror of
https://github.com/meilisearch/MeiliSearch
synced 2025-01-11 05:54:30 +01:00
add cors
This commit is contained in:
parent
6a32432b01
commit
bbf9f41a04
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -1026,6 +1026,7 @@ dependencies = [
|
||||
"chrono 0.4.9 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"crossbeam-channel 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"env_logger 0.7.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"futures 0.3.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"heed 0.6.1 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"http 0.1.19 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"http-service 0.4.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
|
@ -39,11 +39,11 @@ tide = "0.5.1"
|
||||
ureq = { version = "0.11.2", features = ["tls"], default-features = false }
|
||||
walkdir = "2.2.9"
|
||||
whoami = "0.6"
|
||||
|
||||
http-service = "0.4.0"
|
||||
futures = "0.3.1"
|
||||
|
||||
[dev-dependencies]
|
||||
http-service-mock = "0.4.0"
|
||||
http-service = "0.4.0"
|
||||
tempdir = "0.3.7"
|
||||
|
||||
[dev-dependencies.assert-json-diff]
|
||||
|
424
meilisearch-http/src/cors.rs
Normal file
424
meilisearch-http/src/cors.rs
Normal file
@ -0,0 +1,424 @@
|
||||
//! Cors middleware
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
use http::header::HeaderValue;
|
||||
use http::{header, Method, StatusCode};
|
||||
use http_service::Body;
|
||||
|
||||
use tide::middleware::{Middleware, Next};
|
||||
use tide::{Request, Response};
|
||||
|
||||
/// Middleware for CORS
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// use http::header::HeaderValue;
|
||||
/// use tide::middleware::{Cors, Origin};
|
||||
///
|
||||
/// Cors::new()
|
||||
/// .allow_methods(HeaderValue::from_static("GET, POST, OPTIONS"))
|
||||
/// .allow_origin(Origin::from("*"))
|
||||
/// .allow_credentials(false);
|
||||
/// ```
|
||||
#[derive(Clone, Debug, Hash)]
|
||||
pub struct Cors {
|
||||
allow_credentials: Option<HeaderValue>,
|
||||
allow_headers: HeaderValue,
|
||||
allow_methods: HeaderValue,
|
||||
allow_origin: Origin,
|
||||
expose_headers: Option<HeaderValue>,
|
||||
max_age: HeaderValue,
|
||||
}
|
||||
|
||||
pub const DEFAULT_MAX_AGE: &str = "86400";
|
||||
pub const DEFAULT_METHODS: &str = "GET, POST, OPTIONS";
|
||||
pub const WILDCARD: &str = "*";
|
||||
|
||||
impl Cors {
|
||||
/// Creates a new Cors middleware.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
allow_credentials: None,
|
||||
allow_headers: HeaderValue::from_static(WILDCARD),
|
||||
allow_methods: HeaderValue::from_static(DEFAULT_METHODS),
|
||||
allow_origin: Origin::Any,
|
||||
expose_headers: None,
|
||||
max_age: HeaderValue::from_static(DEFAULT_MAX_AGE),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set allow_credentials and return new Cors
|
||||
pub fn allow_credentials(mut self, allow_credentials: bool) -> Self {
|
||||
self.allow_credentials = match HeaderValue::from_str(&allow_credentials.to_string()) {
|
||||
Ok(header) => Some(header),
|
||||
Err(_) => None,
|
||||
};
|
||||
self
|
||||
}
|
||||
|
||||
/// Set allow_headers and return new Cors
|
||||
pub fn allow_headers<T: Into<HeaderValue>>(mut self, headers: T) -> Self {
|
||||
self.allow_headers = headers.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set max_age and return new Cors
|
||||
pub fn max_age<T: Into<HeaderValue>>(mut self, max_age: T) -> Self {
|
||||
self.max_age = max_age.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set allow_methods and return new Cors
|
||||
pub fn allow_methods<T: Into<HeaderValue>>(mut self, methods: T) -> Self {
|
||||
self.allow_methods = methods.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set allow_origin and return new Cors
|
||||
pub fn allow_origin<T: Into<Origin>>(mut self, origin: T) -> Self {
|
||||
self.allow_origin = origin.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Set expose_headers and return new Cors
|
||||
pub fn expose_headers<T: Into<HeaderValue>>(mut self, headers: T) -> Self {
|
||||
self.expose_headers = Some(headers.into());
|
||||
self
|
||||
}
|
||||
|
||||
fn build_preflight_response(&self, origin: &HeaderValue) -> http::response::Response<Body> {
|
||||
let mut response = http::Response::builder()
|
||||
.status(StatusCode::OK)
|
||||
.header::<_, HeaderValue>(header::ACCESS_CONTROL_ALLOW_ORIGIN, origin.clone())
|
||||
.header(
|
||||
header::ACCESS_CONTROL_ALLOW_METHODS,
|
||||
self.allow_methods.clone(),
|
||||
)
|
||||
.header(
|
||||
header::ACCESS_CONTROL_ALLOW_HEADERS,
|
||||
self.allow_headers.clone(),
|
||||
)
|
||||
.header(header::ACCESS_CONTROL_MAX_AGE, self.max_age.clone())
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
if let Some(allow_credentials) = self.allow_credentials.clone() {
|
||||
response
|
||||
.headers_mut()
|
||||
.append(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, allow_credentials);
|
||||
}
|
||||
|
||||
if let Some(expose_headers) = self.expose_headers.clone() {
|
||||
response
|
||||
.headers_mut()
|
||||
.append(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers);
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
||||
/// Look at origin of request and determine allow_origin
|
||||
fn response_origin<T: Into<HeaderValue>>(&self, origin: T) -> Option<HeaderValue> {
|
||||
let origin = origin.into();
|
||||
if !self.is_valid_origin(origin.clone()) {
|
||||
return None;
|
||||
}
|
||||
|
||||
match self.allow_origin {
|
||||
Origin::Any => Some(HeaderValue::from_static(WILDCARD)),
|
||||
_ => Some(origin),
|
||||
}
|
||||
}
|
||||
|
||||
/// Determine if origin is appropriate
|
||||
fn is_valid_origin<T: Into<HeaderValue>>(&self, origin: T) -> bool {
|
||||
let origin = match origin.into().to_str() {
|
||||
Ok(s) => s.to_string(),
|
||||
Err(_) => return false,
|
||||
};
|
||||
|
||||
match &self.allow_origin {
|
||||
Origin::Any => true,
|
||||
Origin::Exact(s) => s == &origin,
|
||||
Origin::List(list) => list.contains(&origin),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<State: Send + Sync + 'static> Middleware<State> for Cors {
|
||||
fn handle<'a>(&'a self, req: Request<State>, next: Next<'a, State>) -> BoxFuture<'a, Response> {
|
||||
Box::pin(async move {
|
||||
let origin = req
|
||||
.headers()
|
||||
.get(header::ORIGIN)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| HeaderValue::from_static(""));
|
||||
|
||||
if !self.is_valid_origin(&origin) {
|
||||
return http::Response::builder()
|
||||
.status(StatusCode::UNAUTHORIZED)
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
.into();
|
||||
}
|
||||
|
||||
// Return results immediately upon preflight request
|
||||
if req.method() == Method::OPTIONS {
|
||||
return self.build_preflight_response(&origin).into();
|
||||
}
|
||||
|
||||
let mut response: http_service::Response = next.run(req).await.into();
|
||||
let headers = response.headers_mut();
|
||||
|
||||
headers.append(
|
||||
header::ACCESS_CONTROL_ALLOW_ORIGIN,
|
||||
self.response_origin(origin).unwrap(),
|
||||
);
|
||||
|
||||
if let Some(allow_credentials) = self.allow_credentials.clone() {
|
||||
headers.append(header::ACCESS_CONTROL_ALLOW_CREDENTIALS, allow_credentials);
|
||||
}
|
||||
|
||||
if let Some(expose_headers) = self.expose_headers.clone() {
|
||||
headers.append(header::ACCESS_CONTROL_EXPOSE_HEADERS, expose_headers);
|
||||
}
|
||||
response.into()
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Cors {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// allow_origin enum
|
||||
#[derive(Clone, Debug, Hash, PartialEq)]
|
||||
pub enum Origin {
|
||||
/// Wildcard. Accept all origin requests
|
||||
Any,
|
||||
/// Set a single allow_origin target
|
||||
Exact(String),
|
||||
/// Set multiple allow_origin targets
|
||||
List(Vec<String>),
|
||||
}
|
||||
|
||||
impl From<String> for Origin {
|
||||
fn from(s: String) -> Self {
|
||||
if s == "*" {
|
||||
return Origin::Any;
|
||||
}
|
||||
Origin::Exact(s)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Origin {
|
||||
fn from(s: &str) -> Self {
|
||||
Origin::from(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<String>> for Origin {
|
||||
fn from(list: Vec<String>) -> Self {
|
||||
if list.len() == 1 {
|
||||
return Self::from(list[0].clone());
|
||||
}
|
||||
|
||||
Origin::List(list)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Vec<&str>> for Origin {
|
||||
fn from(list: Vec<&str>) -> Self {
|
||||
Origin::from(list.iter().map(|s| s.to_string()).collect::<Vec<String>>())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
use http::header::HeaderValue;
|
||||
use http_service::Body;
|
||||
use http_service_mock::make_server;
|
||||
|
||||
const ALLOW_ORIGIN: &str = "example.com";
|
||||
const ALLOW_METHODS: &str = "GET, POST, OPTIONS, DELETE";
|
||||
const EXPOSE_HEADER: &str = "X-My-Custom-Header";
|
||||
|
||||
const ENDPOINT: &str = "/cors";
|
||||
|
||||
fn app() -> crate::Server<()> {
|
||||
let mut app = crate::Server::new();
|
||||
app.at(ENDPOINT).get(|_| async move { "Hello World" });
|
||||
|
||||
app
|
||||
}
|
||||
|
||||
fn request() -> http::Request<http_service::Body> {
|
||||
http::Request::get(ENDPOINT)
|
||||
.header(http::header::ORIGIN, ALLOW_ORIGIN)
|
||||
.method(http::method::Method::GET)
|
||||
.body(Body::empty())
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preflight_request() {
|
||||
let mut app = app();
|
||||
app.middleware(
|
||||
Cors::new()
|
||||
.allow_origin(Origin::from(ALLOW_ORIGIN))
|
||||
.allow_methods(HeaderValue::from_static(ALLOW_METHODS))
|
||||
.expose_headers(HeaderValue::from_static(EXPOSE_HEADER))
|
||||
.allow_credentials(true),
|
||||
);
|
||||
|
||||
let mut server = make_server(app.into_http_service()).unwrap();
|
||||
|
||||
let req = http::Request::get(ENDPOINT)
|
||||
.header(http::header::ORIGIN, ALLOW_ORIGIN)
|
||||
.method(http::method::Method::OPTIONS)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let res = server.simulate(req).unwrap();
|
||||
|
||||
assert_eq!(res.status(), 200);
|
||||
|
||||
assert_eq!(
|
||||
res.headers().get("access-control-allow-origin").unwrap(),
|
||||
ALLOW_ORIGIN
|
||||
);
|
||||
assert_eq!(
|
||||
res.headers().get("access-control-allow-methods").unwrap(),
|
||||
ALLOW_METHODS
|
||||
);
|
||||
assert_eq!(
|
||||
res.headers().get("access-control-allow-headers").unwrap(),
|
||||
WILDCARD
|
||||
);
|
||||
assert_eq!(
|
||||
res.headers().get("access-control-max-age").unwrap(),
|
||||
DEFAULT_MAX_AGE
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
res.headers()
|
||||
.get("access-control-allow-credentials")
|
||||
.unwrap(),
|
||||
"true"
|
||||
);
|
||||
}
|
||||
#[test]
|
||||
fn default_cors_middleware() {
|
||||
let mut app = app();
|
||||
app.middleware(Cors::new());
|
||||
|
||||
let mut server = make_server(app.into_http_service()).unwrap();
|
||||
let res = server.simulate(request()).unwrap();
|
||||
|
||||
assert_eq!(res.status(), 200);
|
||||
|
||||
assert_eq!(
|
||||
res.headers().get("access-control-allow-origin").unwrap(),
|
||||
"*"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_cors_middleware() {
|
||||
let mut app = app();
|
||||
app.middleware(
|
||||
Cors::new()
|
||||
.allow_origin(Origin::from(ALLOW_ORIGIN))
|
||||
.allow_credentials(false)
|
||||
.allow_methods(HeaderValue::from_static(ALLOW_METHODS))
|
||||
.expose_headers(HeaderValue::from_static(EXPOSE_HEADER)),
|
||||
);
|
||||
|
||||
let mut server = make_server(app.into_http_service()).unwrap();
|
||||
let res = server.simulate(request()).unwrap();
|
||||
|
||||
assert_eq!(res.status(), 200);
|
||||
assert_eq!(
|
||||
res.headers().get("access-control-allow-origin").unwrap(),
|
||||
ALLOW_ORIGIN
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credentials_true() {
|
||||
let mut app = app();
|
||||
app.middleware(Cors::new().allow_credentials(true));
|
||||
|
||||
let mut server = make_server(app.into_http_service()).unwrap();
|
||||
let res = server.simulate(request()).unwrap();
|
||||
|
||||
assert_eq!(res.status(), 200);
|
||||
assert_eq!(
|
||||
res.headers()
|
||||
.get("access-control-allow-credentials")
|
||||
.unwrap(),
|
||||
"true"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn set_allow_origin_list() {
|
||||
let mut app = app();
|
||||
let origins = vec![ALLOW_ORIGIN, "foo.com", "bar.com"];
|
||||
app.middleware(Cors::new().allow_origin(origins.clone()));
|
||||
let mut server = make_server(app.into_http_service()).unwrap();
|
||||
|
||||
for origin in origins {
|
||||
let request = http::Request::get(ENDPOINT)
|
||||
.header(http::header::ORIGIN, origin)
|
||||
.method(http::method::Method::GET)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let res = server.simulate(request).unwrap();
|
||||
|
||||
assert_eq!(res.status(), 200);
|
||||
assert_eq!(
|
||||
res.headers().get("access-control-allow-origin").unwrap(),
|
||||
origin
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn not_set_origin_header() {
|
||||
let mut app = app();
|
||||
app.middleware(Cors::new());
|
||||
|
||||
let request = http::Request::get(ENDPOINT)
|
||||
.method(http::method::Method::GET)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let mut server = make_server(app.into_http_service()).unwrap();
|
||||
let res = server.simulate(request).unwrap();
|
||||
|
||||
assert_eq!(res.status(), 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unauthorized_origin() {
|
||||
let mut app = app();
|
||||
app.middleware(Cors::new().allow_origin(ALLOW_ORIGIN));
|
||||
|
||||
let request = http::Request::get(ENDPOINT)
|
||||
.header(http::header::ORIGIN, "unauthorize-origin.net")
|
||||
.method(http::method::Method::GET)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
|
||||
let mut server = make_server(app.into_http_service()).unwrap();
|
||||
let res = server.simulate(request).unwrap();
|
||||
|
||||
assert_eq!(res.status(), 401);
|
||||
}
|
||||
}
|
@ -5,7 +5,6 @@ use async_std::task;
|
||||
use log::info;
|
||||
use main_error::MainError;
|
||||
use structopt::StructOpt;
|
||||
// use tide::middleware::{CorsMiddleware, CorsOrigin};
|
||||
use tide::middleware::RequestLogger;
|
||||
|
||||
use meilisearch_http::data::Data;
|
||||
@ -13,7 +12,10 @@ use meilisearch_http::option::Opt;
|
||||
use meilisearch_http::routes;
|
||||
use meilisearch_http::routes::index::index_update_callback;
|
||||
|
||||
use cors::Cors;
|
||||
|
||||
mod analytics;
|
||||
mod cors;
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
#[global_allocator]
|
||||
@ -36,11 +38,7 @@ pub fn main() -> Result<(), MainError> {
|
||||
|
||||
let mut app = tide::with_state(data);
|
||||
|
||||
// app.middleware(
|
||||
// CorsMiddleware::new()
|
||||
// .allow_origin(CorsOrigin::from("*"))
|
||||
// .allow_methods(HeaderValue::from_static("GET, POST, OPTIONS")),
|
||||
// );
|
||||
app.middleware(Cors::new());
|
||||
app.middleware(RequestLogger::new());
|
||||
// app.middleware(tide_compression::Compression::new());
|
||||
// app.middleware(tide_compression::Decompression::new());
|
||||
|
@ -136,6 +136,7 @@ pub fn load_routes(app: &mut tide::Server<Data>) {
|
||||
.post(|ctx| into_response(setting::update_displayed(ctx)))
|
||||
.delete(|ctx| into_response(setting::delete_displayed(ctx)));
|
||||
});
|
||||
|
||||
router.at("/index-new-fields")
|
||||
.get(|ctx| into_response(setting::get_index_new_fields(ctx)))
|
||||
.post(|ctx| into_response(setting::update_index_new_fields(ctx)));
|
||||
|
Loading…
x
Reference in New Issue
Block a user