diff --git a/Cargo.toml b/Cargo.toml index ed3f832..b18377b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,7 +12,8 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = [] +default = ["h3"] +h3 = ["quinn"] [dependencies] anyhow = "1.0.58" @@ -26,7 +27,7 @@ hyper = { version = "0.14.19", default-features = false, features = [ "stream", ] } log = "0.4.17" -tokio = { version = "1.19.2", features = [ +tokio = { version = "1.19.2", default-features = false, features = [ "net", "rt-multi-thread", "parking_lot", @@ -48,6 +49,10 @@ hyper-rustls = { version = "0.23.0", default-features = false, features = [ "http2", ] } parking_lot = "0.12.1" +quinn = { version = "0.8.3", optional = true } +h3 = { git = "https://github.com/hyperium/h3.git" } +h3-quinn = { git = "https://github.com/hyperium/h3.git" } +bytes = "1.1.0" [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = "0.5.0" diff --git a/bench/bench.sh b/bench/bench.sh index e024055..428ab62 100644 --- a/bench/bench.sh +++ b/bench/bench.sh @@ -2,12 +2,12 @@ echo "----------------------------" echo "Benchmark on rpxy" -ab -c 100 -n 10000 http://127.0.0.1:8080/ # TODO: localhost = 127.0.0.1を解決できるように決めておかんとだめそう +ab -c 100 -n 10000 http://127.0.0.1:8080/index.html # TODO: localhost = 127.0.0.1を解決できるように決めておかんとだめそう echo "----------------------------" echo "Benchmark on nginx" -ab -c 100 -n 10000 http://127.0.0.1:8090/ +ab -c 100 -n 10000 http://127.0.0.1:8090/index.html echo "----------------------------" echo "Benchmark on caddy" -ab -c 100 -n 10000 http://127.0.0.1:8100/ +ab -c 100 -n 10000 http://127.0.0.1:8100/index.html diff --git a/bench/rpxy.toml b/bench/rpxy.toml index d6667de..d34150a 100644 --- a/bench/rpxy.toml +++ b/bench/rpxy.toml @@ -1,6 +1,7 @@ listen_port = 8080 # listen_port_tls = 8443 listen_ipv6 = false +listen_only_ipv6 = false max_concurrent_streams = 128 max_clients = 512 @@ -17,3 +18,7 @@ reverse_proxy = [ { upstream = [{ location = 'backend-nginx', tls = false }] }, # { upstream = [{ location = '192.168.100.100', tls = false }] }, ] + + +[experimental] +h3 = true diff --git a/config-example.toml b/config-example.toml index 5befe8b..f9cb30b 100644 --- a/config-example.toml +++ b/config-example.toml @@ -16,6 +16,7 @@ max_clients = 512 # Optional: Listen [::] listen_ipv6 = false +listen_only_ipv6 = false # Optional: App that serves all plaintext http request by referring to HOSTS or request header # execpt for configured application. @@ -54,3 +55,9 @@ tls = { https_redirection = true, tls_cert_path = 'localhost.pem', tls_cert_key_ [apps.another_localhost] server_name = 'localhost.localdomain' reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }] + +################################### +# Experimantal settings # +################################### +[experimental] +h3 = true diff --git a/src/backend.rs b/src/backend.rs index 554eccd..d5f56ee 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -180,7 +180,20 @@ impl Backend { "Unable to find a valid certificate and key", ) })?; - server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + #[cfg(feature = "h3")] + { + server_config.alpn_protocols = vec![ + b"h3".to_vec(), + b"hq-29".to_vec(), // quinn draft example TODO: remove later + b"h2".to_vec(), + b"http/1.1".to_vec(), + ]; + } + #[cfg(not(feature = "h3"))] + { + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + } let mut config_store = self.server_config.lock(); *config_store = Some(server_config); diff --git a/src/config/parse.rs b/src/config/parse.rs index 7008486..bb18674 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -39,11 +39,18 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> }, anyhow!("Wrong port spec.") ); - let mut listen_addresses: Vec<&str> = LISTEN_ADDRESSES_V4.to_vec(); - if let Some(v) = config.listen_ipv6 { + let mut listen_addresses: Vec<&str> = Vec::new(); + if let Some(v) = config.listen_only_ipv6 { if v { listen_addresses.extend(LISTEN_ADDRESSES_V6.iter()); } + } else if let Some(v) = config.listen_ipv6 { + listen_addresses.extend(LISTEN_ADDRESSES_V4.iter()); + if v { + listen_addresses.extend(LISTEN_ADDRESSES_V6.iter()); + } + } else { + listen_addresses.extend(LISTEN_ADDRESSES_V4.iter()); } globals.listen_sockets = listen_addresses .iter() @@ -144,6 +151,16 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> } } + // experimental + if let Some(exp) = config.experimental { + if let Some(b) = exp.h3 { + globals.http3 = b; + if b { + info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable.") + } + } + } + Ok(()) } diff --git a/src/config/toml.rs b/src/config/toml.rs index b04e048..87fc8bb 100644 --- a/src/config/toml.rs +++ b/src/config/toml.rs @@ -8,10 +8,17 @@ pub struct ConfigToml { pub listen_port: Option, pub listen_port_tls: Option, pub listen_ipv6: Option, + pub listen_only_ipv6: Option, pub max_concurrent_streams: Option, pub max_clients: Option, pub apps: Option, pub default_app: Option, + pub experimental: Option, +} + +#[derive(Deserialize, Debug, Default)] +pub struct Experimental { + pub h3: Option, } #[derive(Deserialize, Debug, Default)] diff --git a/src/globals.rs b/src/globals.rs index 30227c7..19dd539 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -16,6 +16,7 @@ pub struct Globals { pub clients_count: ClientsCount, pub max_concurrent_streams: u32, pub keepalive: bool, + pub http3: bool, pub runtime_handle: tokio::runtime::Handle, } diff --git a/src/main.rs b/src/main.rs index 5010d73..be59bfc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -55,6 +55,7 @@ fn main() { listen_sockets: Vec::new(), http_port: None, https_port: None, + http3: false, timeout: Duration::from_secs(TIMEOUT_SEC), max_clients: MAX_CLIENTS, clients_count: Default::default(), diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index a094bdc..20b5810 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "h3")] +mod proxy_h3; mod proxy_handler; mod proxy_main; mod proxy_tls; diff --git a/src/proxy/proxy_h3.rs b/src/proxy/proxy_h3.rs new file mode 100644 index 0000000..a9eecc0 --- /dev/null +++ b/src/proxy/proxy_h3.rs @@ -0,0 +1,110 @@ +use super::Proxy; +use crate::{error::*, log::*}; +use bytes::{Buf, Bytes}; +use h3::{quic::BidiStream, server::RequestStream}; +use hyper::{client::connect::Connect, Body, HeaderMap, Request, Response}; +use std::net::SocketAddr; + +impl Proxy +where + T: Connect + Clone + Sync + Send + 'static, +{ + pub async fn client_serve_h3(self, conn: quinn::Connecting) -> Result<()> { + let client_addr = conn.remote_address(); + + match conn.await { + Ok(new_conn) => { + info!("QUIC connection established from {:?} {:?}", client_addr, { + let hsd = new_conn + .connection + .handshake_data() + .ok_or_else(|| anyhow!(""))? + .downcast::() + .map_err(|_| anyhow!(""))?; + ( + hsd.protocol.map_or_else( + || "".into(), + |x| String::from_utf8_lossy(&x).into_owned(), + ), + hsd.server_name.map_or_else(|| "".into(), |x| x), + ) + }); + + let mut h3_conn = + h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn)) + .await?; + info!("HTTP/3 connection established"); + + while let Some((req, stream)) = h3_conn + .accept() + .await + .map_err(|e| anyhow!("HTTP/3 accept failed: {}", e))? + { + info!("HTTP/3 new request received"); + + let self_inner = self.clone(); + self.globals.runtime_handle.spawn(async move { + if let Err(e) = self_inner.handle_request_h3(req, stream, client_addr).await { + error!("HTTP/3 request failed: {}", e); + } + }); + } + } + Err(err) => { + warn!("QUIC accepting connection failed: {:?}", err); + } + } + + Ok(()) + } + + async fn handle_request_h3( + self, + req: Request<()>, + mut stream: RequestStream, + client_addr: SocketAddr, + ) -> Result<()> + where + S: BidiStream, + { + let (req_parts, _) = req.into_parts(); + + // TODO: h3 -> h2/http1.1などのプロトコル変換がなければ、bodyはBytes単位で直でsend_dataして転送した方がいい。やむなし。 + let mut body_chunk: Vec = Vec::new(); + while let Some(request_body) = stream.recv_data().await? { + body_chunk.extend_from_slice(request_body.chunk()); + } + let body = if body_chunk.is_empty() { + Body::default() + } else { + debug!("HTTP/3 request with non-empty body"); + Body::from(body_chunk) + }; + // trailers + let trailers = if let Some(trailers) = stream.recv_trailers().await? { + debug!("HTTP/3 request with trailers"); + trailers + } else { + HeaderMap::new() + }; + + let new_req: Request = Request::from_parts(req_parts, body); + let res = self.handle_request(new_req, client_addr).await?; + + let (new_res_parts, new_body) = res.into_parts(); + let new_res = Response::from_parts(new_res_parts, ()); + + match stream.send_response(new_res).await { + Ok(_) => { + debug!("HTTP/3 response to connection successful"); + let data = hyper::body::to_bytes(new_body).await?; + stream.send_data(data).await?; + stream.send_trailers(trailers).await?; + } + Err(err) => { + error!("Unable to send response to connection peer: {:?}", err); + } + } + Ok(stream.finish().await?) + } +} diff --git a/src/proxy/proxy_handler.rs b/src/proxy/proxy_handler.rs index c73e003..3b947fa 100644 --- a/src/proxy/proxy_handler.rs +++ b/src/proxy/proxy_handler.rs @@ -98,6 +98,14 @@ where return http_error(StatusCode::BAD_REQUEST); } }; + #[cfg(feature = "h3")] + { + if let Some(port) = self.globals.https_port { + res_backend + .headers_mut() + .insert("alt-svc", format!("h3=\":{}\"", port).parse().unwrap()); + } + } debug!("Response from backend: {:?}", res_backend.status()); if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS { @@ -156,12 +164,12 @@ fn generate_request_forwarded( debug!("Generate request to be forwarded"); // Add te: trailer if contained in original request - let te_trailer = { + let te_trailers = { if let Some(te) = req.headers().get("te") { te.to_str() .unwrap() .split(',') - .any(|x| x.trim() == "trailer") + .any(|x| x.trim() == "trailers") } else { false } @@ -175,7 +183,7 @@ fn generate_request_forwarded( // X-Forwarded-For add_forwarding_header(headers, client_addr)?; // Add te: trailer if te_trailer - if te_trailer { + if te_trailers { headers.insert("te", "trailer".parse().unwrap()); } @@ -200,6 +208,9 @@ fn generate_request_forwarded( // Change version to http/1.1 when destination scheme is http if req.version() != Version::HTTP_11 && upstream_scheme_host.scheme() == Some(&Scheme::HTTP) { *req.version_mut() = Version::HTTP_11; + } else if req.version() == Version::HTTP_3 { + debug!("HTTP/3 is currently unsupported for request to upstream. Use HTTP/2."); + *req.version_mut() = Version::HTTP_2; } Ok(req) @@ -290,7 +301,10 @@ fn secure_redirection( Ok(response) } -fn parse_host_port(req: &Request, tls_enabled: bool) -> Result<(String, u16)> { +fn parse_host_port( + req: &Request, + tls_enabled: bool, +) -> Result<(String, u16)> { let host_port_headers = req.headers().get("host"); let host_uri = req.uri().host(); let port_uri = req.uri().port_u16(); diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index b46e88c..1a8c707 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -74,13 +74,14 @@ where }); } - async fn start_without_tls( - self, - listener: TcpListener, - server: Http, - ) -> Result<()> { + async fn start_without_tls(self, server: Http) -> Result<()> { let listener_service = async { - while let Ok((stream, _client_addr)) = listener.accept().await { + let tcp_listener = TcpListener::bind(&self.listening_on).await?; + info!( + "Start TCP proxy serving with HTTP request for configured host names: {:?}", + tcp_listener.local_addr()? + ); + while let Ok((stream, _client_addr)) = tcp_listener.accept().await { self .clone() .client_serve(stream, server.clone(), _client_addr) @@ -93,8 +94,6 @@ where } pub async fn start(self) -> Result<()> { - let tcp_listener = TcpListener::bind(&self.listening_on).await?; - let mut server = Http::new(); server.http1_keep_alive(self.globals.keepalive); server.http2_max_concurrent_streams(self.globals.max_concurrent_streams); @@ -103,18 +102,10 @@ where let server = server.with_executor(executor); if self.tls_enabled { - info!( - "Start TCP proxy serving with HTTPS request for configured host names: {:?}", - tcp_listener.local_addr()? - ); // #[cfg(feature = "tls")] - self.start_with_tls(tcp_listener, server).await?; + self.start_with_tls(server).await?; } else { - info!( - "Start TCP proxy serving with HTTP request for configured host names: {:?}", - tcp_listener.local_addr()? - ); - self.start_without_tls(tcp_listener, server).await?; + self.start_without_tls(server).await?; } Ok(()) diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index b8f939f..eb5a909 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -1,7 +1,10 @@ use super::proxy_main::{LocalExecutor, Proxy}; use crate::{constants::CERTS_WATCH_DELAY_SECS, error::*, log::*}; +#[cfg(feature = "h3")] +use futures::StreamExt; use futures::{future::FutureExt, join, select}; use hyper::{client::connect::Connect, server::conn::Http}; +use rustls::ServerConfig; use std::{sync::Arc, time::Duration}; use tokio::net::TcpListener; @@ -9,11 +12,7 @@ impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - pub async fn start_with_tls( - self, - listener: TcpListener, - server: Http, - ) -> Result<()> { + pub async fn start_with_tls(self, server: Http) -> Result<()> { let cert_service = async { info!("Start cert watch service for {}", self.listening_on); loop { @@ -28,10 +27,17 @@ where } }; + // TCP Listener Service, i.e., http/2 and http/1.1 let listener_service = async { + let tcp_listener = TcpListener::bind(&self.listening_on).await?; + info!( + "Start TCP proxy serving with HTTPS request for configured host names: {:?}", + tcp_listener.local_addr()? + ); + loop { select! { - tcp_cnx = listener.accept().fuse() => { + tcp_cnx = tcp_listener.accept().fuse() => { if tcp_cnx.is_err() { continue; } @@ -53,25 +59,13 @@ where info!("No SNI in ClientHello"); continue; }; - let backend_serve = if let Some(backend_serve) = self.backends.apps.get(svn){ - backend_serve - } else { - info!("No configuration for the server name {} given in client_hello", svn); - continue; - }; - - if backend_serve.tls_cert_path.is_none() { // at least cert does exit - debug!("SNI indicates a site that doesn't support TLS."); - continue; - } - let server_config = if let Some(p) = backend_serve.get_tls_server_config(){ + let server_crypto = if let Some(p) = self.fetch_server_crypto(svn) { p } else { - error!("Failed to load server config"); continue; }; // Finally serve the TLS connection - if let Ok(stream) = start.into_stream(Arc::new(server_config)).await { + if let Ok(stream) = start.into_stream(Arc::new(server_crypto)).await { self.clone().client_serve(stream, server.clone(), _client_addr).await } } @@ -81,6 +75,112 @@ where Ok(()) as Result<()> }; - join!(listener_service, cert_service).0 + /////////////////////// + #[cfg(feature = "h3")] + let listener_service_h3 = async { + // TODO: Work around to initially serve incoming connection + // かなり適当。エラーが出たり出なかったり。原因がわからない… + let tls_app_names: Vec = self + .backends + .apps + .iter() + .filter(|&(_, backend)| { + backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() + }) + .map(|(name, _)| name.to_string()) + .collect(); + ensure!(!tls_app_names.is_empty(), "No TLS supported app"); + let initial_app_name = tls_app_names.get(0).unwrap().as_str(); + info!("Initial app_name: {}", initial_app_name); + let backend_serve = self.backends.apps.get(initial_app_name).unwrap(); + let server_crypto = backend_serve.get_tls_server_config().unwrap(); + let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); + + let (endpoint, incoming) = + quinn::Endpoint::server(server_config_h3, self.listening_on).unwrap(); + debug!("HTTP/3 UDP listening on {}", endpoint.local_addr().unwrap()); + + let mut p = incoming.peekable(); + loop { + // TODO: Not sure if this properly works to handle multiple "server_name"s to host multiple hosts. + // peek() should work for that. + if let Some(peeked_conn) = std::pin::Pin::new(&mut p).peek_mut().await { + let hsd = peeked_conn.handshake_data().await; + let hsd_downcast = hsd? + .downcast::() + .unwrap(); + let svn = if let Some(sni) = hsd_downcast.server_name { + sni + } else { + debug!("HTTP/3 no SNI is given"); + continue; + }; + let new_server_crypto = if let Some(p) = self.fetch_server_crypto(&svn) { + p + } else { + continue; + }; + // Set ServerConfig::set_server_config for given SNI + let mut new_server_config_h3 = + quinn::ServerConfig::with_crypto(Arc::new(new_server_crypto)); + if svn == "localhost" { + new_server_config_h3.concurrent_connections(512); + } + info!( + "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", + svn + ); + endpoint.set_server_config(Some(new_server_config_h3)); + } + + // Then acquire actual connection + let peekable_incoming = std::pin::Pin::new(&mut p); + if let Some(conn) = peekable_incoming.get_mut().next().await { + let fut = self.clone().client_serve_h3(conn); + self.globals.runtime_handle.spawn(async { + if let Err(e) = fut.await { + warn!("QUIC or HTTP/3 connection failed: {}", e) + } + }); + } else { + break; + } + } + endpoint.wait_idle().await; + Ok(()) as Result<()> + }; + + #[cfg(not(feature = "h3"))] + { + join!(listener_service, cert_service).0 + } + #[cfg(feature = "h3")] + { + join!(listener_service, cert_service, listener_service_h3).0 + } + } + + fn fetch_server_crypto(&self, server_name: &str) -> Option { + let backend_serve = if let Some(backend_serve) = self.backends.apps.get(server_name) { + backend_serve + } else { + warn!( + "No configuration for the server name {} given in client_hello", + server_name + ); + return None; + }; + + if backend_serve.tls_cert_path.is_none() { + // at least cert does exit + warn!("SNI indicates a site that doesn't support TLS."); + return None; + } + if let Some(p) = backend_serve.get_tls_server_config() { + Some(p) + } else { + error!("Failed to load server config"); + None + } } }