From c47efbfc93fd7b5f4605879d38a298ae25ea2431 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Thu, 23 Jun 2022 23:10:19 -0400 Subject: [PATCH] fix redirection --- src/backend.rs | 2 +- src/config.rs | 2 +- src/constants.rs | 4 +- src/globals.rs | 4 +- src/main.rs | 2 +- src/proxy/proxy_handler.rs | 107 ++++++++++++++++++++++--------------- 6 files changed, 72 insertions(+), 49 deletions(-) diff --git a/src/backend.rs b/src/backend.rs index bcf2de4..76e5aa6 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -21,7 +21,7 @@ pub struct Backend { #[derive(Debug, Clone)] pub struct ReverseProxy { pub default_destination_uri: hyper::Uri, - pub destination_uris: Option>, // TODO: url pathで引っ掛ける。 + pub destination_uris: HashMap, // TODO: url pathで引っ掛ける。 } impl Backend { diff --git a/src/config.rs b/src/config.rs index cb2d860..715cc6e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -33,7 +33,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap hostname: "localhost".to_string(), reverse_proxy: ReverseProxy { default_destination_uri: "https://google.com/".parse::().unwrap(), - destination_uris: Some(map_example), + destination_uris: map_example, }, redirect_to_https: Some(true), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず diff --git a/src/constants.rs b/src/constants.rs index 1f5c5be..927ffc1 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,6 +1,6 @@ pub const LISTEN_ADDRESSES: &[&str] = &["0.0.0.0", "[::]"]; -pub const HTTP_LISTEN_PORT: u32 = 8080; -pub const HTTPS_LISTEN_PORT: u32 = 8443; +pub const HTTP_LISTEN_PORT: u16 = 8080; +pub const HTTPS_LISTEN_PORT: u16 = 8443; pub const TIMEOUT_SEC: u64 = 10; pub const MAX_CLIENTS: usize = 512; pub const MAX_CONCURRENT_STREAMS: u32 = 16; diff --git a/src/globals.rs b/src/globals.rs index 5508c7d..30227c7 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -8,8 +8,8 @@ use tokio::time::Duration; #[derive(Debug, Clone)] pub struct Globals { pub listen_sockets: Vec, - pub http_port: Option, - pub https_port: Option, + pub http_port: Option, + pub https_port: Option, pub timeout: Duration, pub max_clients: usize, diff --git a/src/main.rs b/src/main.rs index 3acad8f..b0878c0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -79,7 +79,7 @@ async fn entrypoint(globals: Arc, backends: Arc, backends: Arc>, ) -> Result> { - let headers = req.headers(); - + debug!("req: {:?}", req); // Here we start to handle with hostname // Find backend application for given hostname - let (hostname, port) = parse_hostname_port(headers, tls_enabled)?; + let (hostname, _port) = parse_hostname_port(&req, tls_enabled)?; let path = req.uri().path(); + let path_and_query = req.uri().path_and_query().unwrap().as_str(); + println!("{:?}", path_and_query); let backend = if let Some(be) = backends.get(hostname.as_str()) { be } else { @@ -41,31 +42,18 @@ pub async fn handle_request( // Redirect to https if tls_enabled is false and redirect_to_https is true if !tls_enabled && backend.redirect_to_https.unwrap_or(false) { - if let Some(https_port) = globals.https_port { - let dest = if https_port == 443 { - format!("https://{}{}", hostname, path) - } else { - format!( - "https://{}:{}{}", - hostname, - globals.https_port.unwrap(), - path - ) - }; - return https_redirection(dest); - } else { - return http_error(StatusCode::SERVICE_UNAVAILABLE); - } + debug!("Redirect to https: {}", hostname); + return https_redirection(hostname, globals.https_port, path_and_query); } // Find reverse proxy for given path - // let destination_uri = if backend.reverse_proxy.destination_uris.is_some() { - // if let (b) = backend.re - // } else { - // backend.reverse_proxy.default_destination_uri.clone(); - // }; + let destination_uri = if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) { + uri.to_owned() + } else { + backend.reverse_proxy.default_destination_uri.clone() + }; - debug!("path: {}", req.uri().path()); + debug!("destination_uri: {}", destination_uri); // if req.version() == hyper::Version::HTTP_11 { // Ok(Response::new(Body::from("Hello World"))) // } else { @@ -73,7 +61,7 @@ pub async fn handle_request( // with an appropriate StatusCode instead of an Err. // Err("not HTTP/1.1, abort connection") // http_error(StatusCode::NOT_FOUND) - https_redirection("https://www.google.com/".to_string()) + https_redirection("www.google.com".to_string(), Some(443_u16), "/") // } // }); } @@ -86,32 +74,67 @@ fn http_error(status_code: StatusCode) -> Result> { Ok(response) } -fn https_redirection(redirect_to: String) -> Result> { +fn https_redirection( + hostname: String, + https_port: Option, + path_and_query: &str, +) -> Result> { + let dest_uri: String = if let Some(https_port) = https_port { + if https_port == 443 { + format!("https://{}{}", hostname, path_and_query) + } else { + format!("https://{}:{}{}", hostname, https_port, path_and_query) + } + } else { + return http_error(StatusCode::SERVICE_UNAVAILABLE); + }; let response = Response::builder() .status(StatusCode::MOVED_PERMANENTLY) - .header("Location", redirect_to) + .header("Location", dest_uri) .body(Body::empty()) .unwrap(); Ok(response) } -fn parse_hostname_port(headers: &HeaderMap, tls_enabled: bool) -> Result<(String, u16)> { - let hostname_port = headers - .get("host") - .ok_or_else(|| anyhow!("No host in request header"))?; - let hp_as_uri = hostname_port.to_str().unwrap().parse::().unwrap(); +fn parse_hostname_port(req: &Request, tls_enabled: bool) -> Result<(String, u16)> { + let hostname_port_headers = req.headers().get("host"); + let hostname_uri = req.uri().host(); + let port_uri = req.uri().port_u16(); - let hostname = hp_as_uri - .host() - .ok_or_else(|| anyhow!("Failed to parse hostname"))?; + if hostname_port_headers.is_none() && hostname_uri.is_none() { + bail!("No host in request header"); + } - let port = if let Some(p) = hp_as_uri.port() { - p.as_u16() - } else if tls_enabled { - 443 - } else { - 80 + let (hostname, port) = match (hostname_uri, hostname_port_headers) { + (Some(x), _) => { + let hostname = hostname_uri.unwrap(); + let port = if let Some(p) = port_uri { + p + } else if tls_enabled { + 443 + } else { + 80 + }; + (hostname.to_string(), port) + } + (None, Some(x)) => { + let hp_as_uri = x.to_str().unwrap().parse::().unwrap(); + let hostname = hp_as_uri + .host() + .ok_or_else(|| anyhow!("Failed to parse hostname"))?; + let port = if let Some(p) = hp_as_uri.port() { + p.as_u16() + } else if tls_enabled { + 443 + } else { + 80 + }; + (hostname.to_string(), port) + } + (None, None) => { + bail!("Host unspecified in request") + } }; - Ok((hostname.to_string(), port)) + Ok((hostname, port)) }