// Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy use super::Proxy; use crate::{constants::*, error::*, log::*}; use hyper::{ client::connect::Connect, header::{HeaderMap, HeaderValue}, http::uri::Scheme, Body, Request, Response, StatusCode, Uri, Version, }; use std::net::SocketAddr; use tokio::io::copy_bidirectional; const HOP_HEADERS: &[&str] = &[ "connection", "te", "trailer", "keep-alive", "proxy-connection", "proxy-authenticate", "proxy-authorization", "transfer-encoding", "upgrade", ]; impl Proxy where T: Connect + Clone + Sync + Send + 'static, { pub async fn handle_request( self, mut req: Request, client_addr: SocketAddr, // アクセス制御用 ) -> Result> { info!( "Handling {:?} request from {}: {} {} {:?}", req.version(), client_addr, req.method(), req.uri(), req .headers() .get("user-agent") .map_or_else(|| "", |ua| ua.to_str().unwrap()) ); // Here we start to handle with server_name // Find backend application for given server_name let (server_name, _port) = if let Ok(v) = parse_host_port(&req, self.tls_enabled) { v } else { return http_error(StatusCode::SERVICE_UNAVAILABLE); }; let backend = if let Some(be) = self.backends.apps.get(server_name.as_str()) { be } else if let Some(default_be) = &self.backends.default_app { debug!("Serving by default app: {}", default_be); self.backends.apps.get(default_be).unwrap() } else { return http_error(StatusCode::SERVICE_UNAVAILABLE); }; // Redirect to https if tls_enabled is false and redirect_to_https is true let path_and_query = req.uri().path_and_query().unwrap().as_str().to_owned(); if !self.tls_enabled && backend.https_redirection.unwrap_or(false) { debug!("Redirect to secure connection: {}", server_name); return secure_redirection(&server_name, self.globals.https_port, &path_and_query); } // Find reverse proxy for given path and choose one of upstream host // TODO: More flexible path matcher let path = req.uri().path(); let upstream_uri = if let Some(upstream) = backend.reverse_proxy.upstream.get(path) { upstream.get() } else if let Some(default_upstream) = &backend.reverse_proxy.default_upstream { default_upstream.get() } else { return http_error(StatusCode::NOT_FOUND); }; let upstream_scheme_host = if let Some(u) = upstream_uri { u } else { return http_error(StatusCode::INTERNAL_SERVER_ERROR); }; // Upgrade in request header let upgrade_in_request = extract_upgrade(req.headers()); let request_upgraded = req.extensions_mut().remove::(); // Build request from destination information let req_forwarded = if let Ok(req) = generate_request_forwarded( client_addr, req, upstream_scheme_host, path_and_query, &upgrade_in_request, ) { req } else { error!("Failed to generate destination uri for reverse proxy"); return http_error(StatusCode::SERVICE_UNAVAILABLE); }; debug!("Request to be forwarded: {:?}", req_forwarded); // Forward request to let mut res_backend = match self.forwarder.request(req_forwarded).await { Ok(res) => res, Err(e) => { error!("Failed to get response from backend: {}", e); return http_error(StatusCode::BAD_REQUEST); } }; #[cfg(feature = "h3")] { if self.globals.http3 { if let Some(port) = self.globals.https_port { res_backend.headers_mut().insert( hyper::header::ALT_SVC, format!( "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", port, H3_ALT_SVC_MAX_AGE, port, H3_ALT_SVC_MAX_AGE ) .parse() .unwrap(), ); } } } debug!("Response from backend: {:?}", res_backend.status()); if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS { // Handle StatusCode::SWITCHING_PROTOCOLS in response let upgrade_in_response = extract_upgrade(res_backend.headers()); if upgrade_in_request == upgrade_in_response { if let Some(request_upgraded) = request_upgraded { let mut response_upgraded = res_backend .extensions_mut() .remove::() .ok_or_else(|| anyhow!("Response does not have an upgrade extension"))? // TODO: any response code? .await?; // TODO: H3で死ぬことがある // thread 'rpxy' panicked at 'Failed to upgrade request: hyper::Error(User(ManualUpgrade))', src/proxy/proxy_handler.rs:124:63 tokio::spawn(async move { let mut request_upgraded = request_upgraded.await.map_err(|e| { error!("Failed to upgrade request: {}", e); anyhow!("Failed to upgrade request: {}", e) })?; // TODO: any response code? copy_bidirectional(&mut response_upgraded, &mut request_upgraded) .await .map_err(|e| anyhow!("Coping between upgraded connections failed: {}", e))?; // TODO: any response code? Ok(()) as Result<()> }); Ok(res_backend) } else { error!("Request does not have an upgrade extension"); http_error(StatusCode::BAD_GATEWAY) } } else { error!( "Backend tried to switch to protocol {:?} when {:?} was requested", upgrade_in_response, upgrade_in_request ); http_error(StatusCode::BAD_GATEWAY) } } else { // Generate response to client if generate_response_forwarded(&mut res_backend).is_ok() { Ok(res_backend) } else { http_error(StatusCode::BAD_GATEWAY) } } } } fn generate_response_forwarded(response: &mut Response) -> Result<()> { let headers = response.headers_mut(); remove_hop_header(headers); remove_connection_header(headers); Ok(()) } fn generate_request_forwarded( client_addr: SocketAddr, mut req: Request, upstream_scheme_host: &Uri, path_and_query: String, upgrade: &Option, ) -> Result> { debug!("Generate request to be forwarded"); // Add te: trailer if contained in original request let te_trailers = { if let Some(te) = req.headers().get("te") { te.to_str() .unwrap() .split(',') .any(|x| x.trim() == "trailers") } else { false } }; let headers = req.headers_mut(); // delete headers specified in header.connection remove_connection_header(headers); // delete hop headers including header.connection remove_hop_header(headers); // X-Forwarded-For add_forwarding_header(headers, client_addr)?; // Add te: trailer if te_trailer if te_trailers { headers.insert("te", "trailer".parse().unwrap()); } // update uri in request *req.uri_mut() = Uri::builder() .scheme(upstream_scheme_host.scheme().unwrap().as_str()) .authority(upstream_scheme_host.authority().unwrap().as_str()) .path_and_query(&path_and_query) .build()?; // upgrade if let Some(v) = upgrade { req.headers_mut().insert("upgrade", v.parse().unwrap()); req .headers_mut() .insert("connection", HeaderValue::from_str("upgrade")?); } // Change version to http/1.1 when destination scheme is http if req.version() != Version::HTTP_11 && upstream_scheme_host.scheme() == Some(&Scheme::HTTP) { *req.version_mut() = Version::HTTP_11; } else if req.version() == Version::HTTP_3 { debug!("HTTP/3 is currently unsupported for request to upstream. Use HTTP/2."); *req.version_mut() = Version::HTTP_2; } Ok(req) } fn add_forwarding_header(headers: &mut HeaderMap, client_addr: SocketAddr) -> Result<()> { // TODO: Option対応? let client_ip = client_addr.ip(); match headers.entry("x-forwarded-for") { hyper::header::Entry::Vacant(entry) => { entry.insert(client_ip.to_string().parse()?); } hyper::header::Entry::Occupied(entry) => { let client_ip_str = client_ip.to_string(); let mut addr = String::with_capacity(entry.get().as_bytes().len() + 2 + client_ip_str.len()); addr.push_str(std::str::from_utf8(entry.get().as_bytes()).unwrap()); addr.push(','); addr.push(' '); addr.push_str(&client_ip_str); } } Ok(()) } fn remove_connection_header(headers: &mut HeaderMap) { if headers.get("connection").is_some() { let v = headers.get("connection").cloned().unwrap(); for m in v.to_str().unwrap().split(',') { if !m.is_empty() { headers.remove(m.trim()); } } } } fn remove_hop_header(headers: &mut HeaderMap) { HOP_HEADERS.iter().for_each(|key| { headers.remove(*key); }); } fn http_error(status_code: StatusCode) -> Result> { let response = Response::builder() .status(status_code) .body(Body::empty()) .unwrap(); Ok(response) } fn extract_upgrade(headers: &HeaderMap) -> Option { if let Some(c) = headers.get("connection") { if c .to_str() .unwrap_or("") .split(',') .into_iter() .any(|w| w.trim().to_ascii_lowercase() == "upgrade") { if let Some(u) = headers.get("upgrade") { let m = u.to_str().unwrap().to_string(); debug!("Upgrade in request header: {}", m); return Some(m); } } } None } fn secure_redirection( server_name: &str, tls_port: Option, path_and_query: &str, ) -> Result> { let dest_uri: String = if let Some(tls_port) = tls_port { if tls_port == 443 { format!("https://{}{}", server_name, path_and_query) } else { format!("https://{}:{}{}", server_name, tls_port, path_and_query) } } else { bail!("Internal error! TLS port is not set internally."); }; let response = Response::builder() .status(StatusCode::MOVED_PERMANENTLY) .header("Location", dest_uri) .body(Body::empty()) .unwrap(); Ok(response) } fn parse_host_port( req: &Request, tls_enabled: bool, ) -> Result<(String, u16)> { let host_port_headers = req.headers().get("host"); let host_uri = req.uri().host(); let port_uri = req.uri().port_u16(); if host_port_headers.is_none() && host_uri.is_none() { bail!("No host in request header"); } let (host, port) = match (host_uri, host_port_headers) { (Some(x), _) => { let port = if let Some(p) = port_uri { p } else if tls_enabled { 443 } else { 80 }; (x.to_string(), port) } (None, Some(x)) => { let hp_as_uri = x.to_str().unwrap().parse::().unwrap(); let host = hp_as_uri .host() .ok_or_else(|| anyhow!("Failed to parse host"))?; let port = if let Some(p) = hp_as_uri.port() { p.as_u16() } else if tls_enabled { 443 } else { 80 }; (host.to_string(), port) } (None, None) => { bail!("Host unspecified in request") } }; Ok((host, port)) }