diff --git a/src/config.rs b/src/config.rs index 7844918..0859a6c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -24,7 +24,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap let mut map_example: HashMap = HashMap::new(); map_example.insert( "/maps".to_string(), - Uri::builder().authority("www.bing.com").build().unwrap(), + "https://www.bing.com".parse::().unwrap(), ); backends.insert( "localhost".to_string(), @@ -32,7 +32,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap app_name: "Localhost to Google except for maps".to_string(), hostname: "localhost".to_string(), reverse_proxy: ReverseProxy { - default_destination_uri: Uri::builder().authority("www.google.com").build().unwrap(), + default_destination_uri: "https://www.google.com".parse::().unwrap(), destination_uris: map_example, }, https_redirection: Some(true), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず diff --git a/src/proxy/proxy_handler.rs b/src/proxy/proxy_handler.rs index caac334..4097620 100644 --- a/src/proxy/proxy_handler.rs +++ b/src/proxy/proxy_handler.rs @@ -1,26 +1,33 @@ +// Motivated by https://github.com/felipenoris/hyper-reverse-proxy use super::Proxy; use crate::{error::*, log::*}; use hyper::{ client::connect::Connect, - header::{HeaderMap, HeaderName, HeaderValue}, + header::{HeaderMap, HeaderValue}, Body, Request, Response, StatusCode, Uri, }; use std::net::SocketAddr; -// pub static HEADERS: phf::Map<&'static str, HeaderName> = phf_map! { -// "CONNECTION" => HeaderName::from_static("connection"), -// "ws" => "wss", -// }; +const HOP_HEADERS: &[&str] = &[ + "connection", + "te", + "trailer", + "keep-alive", + "proxy-connection", + "proxy-authenticate", + "proxy-authorization", + "transfer-encoding", + "upgrade", +]; impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - // TODO: ここでbackendの名前単位でリクエストを分岐させる pub async fn handle_request( self, req: Request, - client_ip: SocketAddr, // アクセス制御用 + client_addr: SocketAddr, // アクセス制御用 ) -> Result> { debug!("Handling request: {:?}", req); // Here we start to handle with hostname @@ -45,20 +52,24 @@ where // Find reverse proxy for given path let path = req.uri().path(); - let destination_host_uri = if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) { - uri.to_owned() - } else { - backend.reverse_proxy.default_destination_uri.clone() - }; + let destination_scheme_host = + if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) { + uri.to_owned() + } else { + backend.reverse_proxy.default_destination_uri.clone() + }; - // TODO: Upgrade - // TODO: X-Forwarded-For - // TODO: Transfer Encoding + // Upgrade in request header + let upgrade_in_request = extract_upgrade(req.headers()); // Build request from destination information - let req_forwarded = if let Ok(req) = - generate_request_forwarded(client_ip, req, destination_host_uri, path_and_query) - { + let req_forwarded = if let Ok(req) = generate_request_forwarded( + client_addr, + req, + destination_scheme_host, + path_and_query, + upgrade_in_request, + ) { req } else { error!("Failed to generate destination uri for reverse proxy"); @@ -81,35 +92,89 @@ where } } -// Motivated by https://github.com/felipenoris/hyper-reverse-proxy fn generate_request_forwarded( - client_ip: SocketAddr, + client_addr: SocketAddr, mut req: Request, - destination_host_uri: Uri, + destination_scheme_host: Uri, path_and_query: String, + upgrade: Option, ) -> Result> { debug!("Generate request to be forwarded"); + // TODO: Transfer Encoding + // update "host" key in request header if req.headers().contains_key("host") { // HTTP/1.1 req.headers_mut().insert( "host", - HeaderValue::from_str(destination_host_uri.host().unwrap()) + HeaderValue::from_str(destination_scheme_host.host().unwrap()) .map_err(|_| anyhow!("Failed to insert destination host into forwarded request"))?, ); } + let headers = req.headers_mut(); + // delete headers specified in header.connection + remove_connection_header(headers); + // delete hop headers including header.connection + remove_hop_header(headers); + // X-Forwarded-For + add_forwarding_header(headers, client_addr)?; + // update uri in request *req.uri_mut() = Uri::builder() - .scheme(destination_host_uri.scheme().unwrap().as_str()) - .authority(destination_host_uri.authority().unwrap().as_str()) + .scheme(destination_scheme_host.scheme().unwrap().as_str()) + .authority(destination_scheme_host.authority().unwrap().as_str()) .path_and_query(&path_and_query) .build()?; + // upgrade + if let Some(v) = upgrade { + req.headers_mut().insert("upgrade", v.parse().unwrap()); + req + .headers_mut() + .insert("connection", HeaderValue::from_str("upgrade")?); + } + Ok(req) } +fn add_forwarding_header(headers: &mut HeaderMap, client_addr: SocketAddr) -> Result<()> { + let client_ip = client_addr.ip(); + match headers.entry("x-forwarded-for") { + hyper::header::Entry::Vacant(entry) => { + entry.insert(client_ip.to_string().parse()?); + } + hyper::header::Entry::Occupied(entry) => { + let client_ip_str = client_ip.to_string(); + let mut addr = String::with_capacity(entry.get().as_bytes().len() + 2 + client_ip_str.len()); + + addr.push_str(std::str::from_utf8(entry.get().as_bytes()).unwrap()); + addr.push(','); + addr.push(' '); + addr.push_str(&client_ip_str); + } + } + Ok(()) +} + +fn remove_connection_header(headers: &mut HeaderMap) { + if headers.get("connection").is_some() { + let v = headers.get("connection").cloned().unwrap(); + for m in v.to_str().unwrap().split(',') { + if !m.is_empty() { + headers.remove(m.trim()); + } + } + } +} + +fn remove_hop_header(headers: &mut HeaderMap) { + let _ = HOP_HEADERS.iter().for_each(|key| { + headers.remove(*key); + }); +} + fn http_error(status_code: StatusCode) -> Result> { let response = Response::builder() .status(status_code) @@ -118,6 +183,25 @@ fn http_error(status_code: StatusCode) -> Result> { Ok(response) } +fn extract_upgrade(headers: &HeaderMap) -> Option { + if let Some(c) = headers.get("connection") { + if c + .to_str() + .unwrap_or("") + .split(',') + .into_iter() + .any(|w| w.trim().to_ascii_lowercase() == "upgrade") + { + if let Some(u) = headers.get("upgrade") { + let m = u.to_str().unwrap().to_string(); + debug!("Upgrade in request header: {}", m); + return Some(m); + } + } + } + None +} + fn secure_redirection( hostname: &str, tls_port: Option,