From 9335fc702efd1ba79a00f1850bd93f080ee6307d Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 9 Jul 2022 23:57:08 +0900 Subject: [PATCH] some refactor --- src/config/parse.rs | 6 +-- src/main.rs | 6 +-- src/proxy/backend.rs | 7 ++- src/proxy/proxy_handler.rs | 69 ++++++++++++++++-------------- src/proxy/proxy_tls.rs | 59 ++++++++++++++------------ src/proxy/utils_request.rs | 87 ++++++++++++++++++++++++++++---------- 6 files changed, 146 insertions(+), 88 deletions(-) diff --git a/src/config/parse.rs b/src/config/parse.rs index 347164b..6015082 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -91,7 +91,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> // each app for (app_name, app) in apps.0.iter() { ensure!(app.server_name.is_some(), "Missing server_name"); - let server_name = app.server_name.as_ref().unwrap(); + let server_name = app.server_name.as_ref().unwrap().to_ascii_lowercase(); // TLS settings let (tls_cert_path, tls_cert_key_path, https_redirection) = if app.tls.is_none() { @@ -122,7 +122,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; backends.apps.insert( - server_name.to_owned(), + server_name.as_bytes().to_vec(), Backend { app_name: app_name.to_owned(), server_name: server_name.to_owned(), @@ -149,7 +149,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", d, d_sn[0] ); - backends.default_app = Some(d_sn[0].to_owned()); + backends.default_server_name = Some(d_sn[0].as_bytes().to_vec()); } } diff --git a/src/main.rs b/src/main.rs index babea2b..6a56710 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,7 +16,7 @@ use crate::{ error::*, globals::*, log::*, - proxy::{Backend, Backends, Proxy}, + proxy::{Backend, Backends, Proxy, ServerNameLC}, }; use futures::future::select_all; use hyper::Client; @@ -64,8 +64,8 @@ fn main() { }; let mut backends = Backends { - default_app: None, - apps: HashMap::::default(), + default_server_name: None, + apps: HashMap::::default(), }; parse_opts(&mut globals, &mut backends).expect("Invalid configuration"); diff --git a/src/proxy/backend.rs b/src/proxy/backend.rs index 2e9ae53..24d4265 100644 --- a/src/proxy/backend.rs +++ b/src/proxy/backend.rs @@ -13,9 +13,12 @@ use std::{ }; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; +// server name (hostname or ip address) in ascii lower case +pub type ServerNameLC = Vec; + pub struct Backends { - pub apps: HashMap, // TODO: hyper::uriで抜いたhostで引っ掛ける。Stringでいいのか? - pub default_app: Option, // for plaintext http + pub apps: HashMap, // TODO: hyper::uriで抜いたhostで引っ掛ける。Stringでいいのか? + pub default_server_name: Option, // for plaintext http } pub struct Backend { diff --git a/src/proxy/proxy_handler.rs b/src/proxy/proxy_handler.rs index 765672d..9893525 100644 --- a/src/proxy/proxy_handler.rs +++ b/src/proxy/proxy_handler.rs @@ -23,27 +23,27 @@ where // Here we start to handle with server_name // Find backend application for given server_name, and drop if incoming request is invalid as request. - let (server_name, _port) = parse_host_port(&req)?; + // let (server_name, _port) = parse_host_port(&req)?; + let server_name_bytes = req.parse_host()?.to_ascii_lowercase(); - if !self.backends.apps.contains_key(&server_name) && self.backends.default_app.is_none() { + let backend = if let Some(be) = self.backends.apps.get(&server_name_bytes) { + be + } else if let Some(default_server_name) = &self.backends.default_server_name { + debug!("Serving by default app"); + self.backends.apps.get(default_server_name).unwrap() + } else { // info!("{} => {}", request_log, StatusCode::SERVICE_UNAVAILABLE); return http_error(StatusCode::SERVICE_UNAVAILABLE); - } - let backend = if let Some(be) = self.backends.apps.get(&server_name) { - be - } else { - let default_be = self.backends.default_app.as_ref().unwrap(); - debug!("Serving by default app: {}", default_be); - self.backends.apps.get(default_be).unwrap() }; // Redirect to https if !tls_enabled and redirect_to_https is true if !self.tls_enabled && backend.https_redirection.unwrap_or(false) { - debug!("Redirect to secure connection: {}", server_name); + debug!("Redirect to secure connection: {}", &backend.server_name); // info!("{} => {}", request_log, StatusCode::PERMANENT_REDIRECT); - return secure_redirection(&server_name, self.globals.https_port, &req); + return secure_redirection(&backend.server_name, self.globals.https_port, &req); } + /////////////////////// // Find reverse proxy for given path and choose one of upstream host // TODO: More flexible path matcher let path = req.uri().path(); @@ -59,6 +59,7 @@ where } else { return http_error(StatusCode::INTERNAL_SERVER_ERROR); }; + /////////////////////// // Upgrade in request header let upgrade_in_request = extract_upgrade(req.headers()); @@ -77,7 +78,8 @@ where error!("Failed to generate destination uri for reverse proxy"); return http_error(StatusCode::SERVICE_UNAVAILABLE); }; - debug!("Request to be forwarded: {:?}", req_forwarded); + // debug!("Request to be forwarded: {:?}", req_forwarded); + req_forwarded.log(&client_addr, Some("Forwarding")); // Forward request to let mut res_backend = match self.forwarder.request(req_forwarded).await { @@ -87,21 +89,6 @@ where return http_error(StatusCode::BAD_REQUEST); } }; - #[cfg(feature = "h3")] - { - if self.globals.http3 { - if let Some(port) = self.globals.https_port { - let alt_svc_value = HeaderValue::from_str(&format!( - "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", - port, H3_ALT_SVC_MAX_AGE, port, H3_ALT_SVC_MAX_AGE - )) - .unwrap(); - res_backend - .headers_mut() - .insert(header::ALT_SVC, alt_svc_value); - } - } - } debug!("Response from backend: {:?}", res_backend.status()); // let response_log = res_backend.status().to_string(); @@ -175,6 +162,21 @@ where "server", &format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")), )?; + #[cfg(feature = "h3")] + { + if self.globals.http3 { + if let Some(port) = self.globals.https_port { + append_header_entry( + headers, + header::ALT_SVC.as_str(), + &format!( + "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", + port, H3_ALT_SVC_MAX_AGE, port, H3_ALT_SVC_MAX_AGE + ), + )?; + } + } + } Ok(()) } @@ -192,7 +194,9 @@ where // Add te: trailer if contained in original request let te_trailers = { if let Some(te) = req.headers().get(header::TE) { - te.to_str()?.split(',').any(|x| x.trim() == "trailers") + te.as_bytes() + .split(|v| v == &b',' || v == &b' ') + .any(|x| x == "trailers".as_bytes()) } else { false } @@ -205,7 +209,6 @@ where remove_hop_header(headers); // X-Forwarded-For add_forwarding_header(headers, client_addr, self.tls_enabled)?; - // println!("{:?}", headers); // Add te: trailer if te_trailer if te_trailers { @@ -214,10 +217,14 @@ where // add "host" header of original server_name if not exist (default) if req.headers().get(header::HOST).is_none() { - let org_host = req.uri().host().unwrap_or("none").to_owned(); + let org_host = req + .uri() + .host() + .ok_or_else(|| anyhow!("Invalid request"))? + .to_owned(); req .headers_mut() - .insert(header::HOST, HeaderValue::from_str(org_host.as_str())?); + .insert(header::HOST, HeaderValue::from_str(&org_host)?); }; // apply upstream-specific headers given in upstream_option diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 148c119..684be0d 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -1,4 +1,7 @@ -use super::proxy_main::{LocalExecutor, Proxy}; +use super::{ + proxy_main::{LocalExecutor, Proxy}, + ServerNameLC, +}; use crate::{constants::*, error::*, log::*}; #[cfg(feature = "h3")] use futures::StreamExt; @@ -9,7 +12,7 @@ use rustls::ServerConfig; use std::sync::Arc; use tokio::{net::TcpListener, sync::watch, time::Duration}; -type ServerCryptoMap = HashMap; +type ServerCryptoMap = HashMap>; impl Proxy where @@ -18,16 +21,19 @@ where async fn cert_service(&self, server_crypto_tx: watch::Sender>) { info!("Start cert watch service"); loop { - let mut hm_server_config = HashMap::::default(); - for (server_name, backend) in self.backends.apps.iter() { + let mut hm_server_config = HashMap::>::default(); + for (server_name_bytes, backend) in self.backends.apps.iter() { if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { match backend.update_server_config().await { Err(_e) => { - error!("Failed to update certs for {}: {}", server_name, _e); + error!( + "Failed to update certs for {}: {}", + &backend.server_name, _e + ); break; } Ok(server_config) => { - hm_server_config.insert(server_name.to_owned(), server_config); + hm_server_config.insert(server_name_bytes.to_vec(), Arc::new(server_config)); } } } @@ -66,21 +72,20 @@ where let start = acceptor.unwrap(); let client_hello = start.client_hello(); - debug!("SNI in ClientHello: {:?}", client_hello.server_name()); // Find server config for given SNI - let svn = if let Some(svn) = client_hello.server_name() { - svn - } else { + if client_hello.server_name().is_none(){ info!("No SNI in ClientHello"); continue; - }; - let server_crypto = if let Some(p) = server_crypto_map.as_ref().unwrap().get(svn) { - p.to_owned() - } else { + } + let server_name = client_hello.server_name().unwrap().to_ascii_lowercase(); + debug!("SNI in ClientHello: {:?}", server_name); + let server_crypto = server_crypto_map.as_ref().unwrap().get(server_name.as_bytes()); + if server_crypto.is_none() { + debug!("No TLS serving app for {}", server_name); continue; }; // Finally serve the TLS connection - if let Ok(stream) = start.into_stream(Arc::new(server_crypto)).await { + if let Ok(stream) = start.into_stream(server_crypto.unwrap().clone()).await { self.clone().client_serve(stream, server.clone(), _client_addr).await } } @@ -101,7 +106,7 @@ where &self, peeked_conn: &mut quinn::Connecting, server_crypto_map: &ServerCryptoMap, - ) -> Option { + ) -> Option> { let hsd = if let Ok(h) = peeked_conn.handshake_data().await { h } else { @@ -112,12 +117,14 @@ where } else { return None; }; - let server_name = hsd_downcast.server_name?; + let server_name = hsd_downcast.server_name?.to_ascii_lowercase(); info!( "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", server_name ); - server_crypto_map.get(&server_name).cloned() + server_crypto_map + .get(&server_name.as_bytes().to_vec()) + .cloned() } #[cfg(feature = "h3")] @@ -127,20 +134,20 @@ where ) -> Result<()> { // TODO: Work around to initially serve incoming connection // かなり適当。エラーが出たり出なかったり。原因がわからない… - let tls_app_names: Vec = self + let next = self .backends .apps .iter() .filter(|&(_, backend)| { backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() }) - .map(|(name, _)| name.to_string()) - .collect(); - ensure!(!tls_app_names.is_empty(), "No TLS supported app"); - let initial_app_name = tls_app_names.get(0).ok_or_else(|| anyhow!(""))?.as_str(); + .map(|(name, _)| name) + .next(); + ensure!(next.is_some(), "No TLS supported app"); + let initial_app_name = next.ok_or_else(|| anyhow!(""))?; debug!( - "HTTP/3 SNI multiplexer initial app_name: {}", - initial_app_name + "HTTP/3 SNI multiplexer initial app_name: {:?}", + String::from_utf8(initial_app_name.to_vec()) ); let backend_serve = self .backends @@ -168,7 +175,7 @@ where let is_acceptable = if let Some(new_server_crypto) = self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await { // Set ServerConfig::set_server_config for given SNI - endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(Arc::new(new_server_crypto)))); + endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto))); true } else { false diff --git a/src/proxy/utils_request.rs b/src/proxy/utils_request.rs index 16c2705..163f31c 100644 --- a/src/proxy/utils_request.rs +++ b/src/proxy/utils_request.rs @@ -1,5 +1,5 @@ use crate::{error::*, log::*, utils::*}; -use hyper::{header, Request, Uri}; +use hyper::{header, Request}; use std::fmt::Display; //////////////////////////////////////////////////// @@ -36,29 +36,70 @@ impl MsgLog for &Request { } } -pub(super) fn parse_host_port( - req: &Request, -) -> Result<(String, Option)> { - let headers_host = req.headers().get("host"); - let uri_host = req.uri().host(); - let uri_port = req.uri().port_u16(); +pub trait ParseHost { + fn parse_host(&self) -> Result<&[u8]>; +} +impl ParseHost for Request { + fn parse_host(&self) -> Result<&[u8]> { + let headers_host = self.headers().get(header::HOST); + let uri_host = self.uri().host(); + // let uri_port = self.uri().port_u16(); - ensure!( - !(headers_host.is_none() && uri_host.is_none()), - "No host in request header" - ); + ensure!( + !(headers_host.is_none() && uri_host.is_none()), + "No host in request header" + ); - // prioritize server_name in uri - if let Some(v) = uri_host { - Ok((v.to_string(), uri_port)) - } else { - let uri_from_host = headers_host.unwrap().to_str()?.parse::()?; - Ok(( - uri_from_host - .host() - .ok_or_else(|| anyhow!("Failed to parse host"))? - .to_string(), - uri_from_host.port_u16(), - )) + // prioritize server_name in uri + uri_host.map_or_else( + || { + let m = headers_host.unwrap().as_bytes(); + if m.starts_with(&[b'[']) { + println!("v6 bracket"); + // v6 address with bracket case. if port is specified, always it is in this case. + let mut iter = m.split(|ptr| ptr == &b'[' || ptr == &b']'); + iter.next().ok_or_else(|| anyhow!("Invalid Host"))?; // first item is always blank + iter.next().ok_or_else(|| anyhow!("Invalid Host")) + } else if m.len() - m.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 { + println!("v6 non-bracket"); + // v6 address case, if 2 or more ':' is contained + Ok(m) + } else { + // v4 address or hostname + m.split(|colon| colon == &b':') + .into_iter() + .next() + .ok_or_else(|| anyhow!("Invalid Host")) + } + }, + |v| Ok(v.as_bytes()), + ) } } + +// pub(super) fn parse_host_port( +// req: &Request, +// ) -> Result<(String, Option)> { +// let headers_host = req.headers().get("host"); +// let uri_host = req.uri().host(); +// let uri_port = req.uri().port_u16(); + +// ensure!( +// !(headers_host.is_none() && uri_host.is_none()), +// "No host in request header" +// ); + +// // prioritize server_name in uri +// if let Some(v) = uri_host { +// Ok((v.to_string(), uri_port)) +// } else { +// let uri_from_host = headers_host.unwrap().to_str()?.parse::()?; +// Ok(( +// uri_from_host +// .host() +// .ok_or_else(|| anyhow!("Failed to parse host"))? +// .to_string(), +// uri_from_host.port_u16(), +// )) +// } +// }