diff --git a/.gitmodules b/.gitmodules index b9069a0..59b7ea8 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "quinn"] path = quinn url = git@github.com:junkurihara/quinn.git +[submodule "s2n-quic"] + path = s2n-quic + url = git@github.com:junkurihara/s2n-quic.git diff --git a/Cargo.toml b/Cargo.toml index 64d1414..aa65657 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,7 +1,7 @@ [workspace] members = ["rpxy-bin", "rpxy-lib"] -exclude = ["quinn", "h3-quinn", "h3"] +exclude = ["quinn", "h3-quinn", "h3", "s2n-quic"] [profile.release] codegen-units = 1 diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index 0fc2ae4..3d1f90c 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -13,7 +13,7 @@ publish = false [features] default = ["http3"] -http3 = [] +http3 = ["rpxy-lib/http3"] [dependencies] rpxy-lib = { path = "../rpxy-lib/", features = ["http3", "sticky-cookie"] } diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index 72f8a8a..1915b68 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -44,6 +44,7 @@ where if proxy_config.https_port.is_some() { info!("Listen port: {} (for TLS)", proxy_config.https_port.unwrap()); } + #[cfg(feature = "http3")] if proxy_config.http3 { info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable."); } diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index 749239c..05d63b0 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -3,6 +3,8 @@ mod proxy_client_cert; #[cfg(feature = "http3")] mod proxy_h3; mod proxy_main; +#[cfg(feature = "http3")] +mod proxy_quic; mod proxy_tls; mod socket; diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 324060f..eac6dbf 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -1,7 +1,7 @@ use super::Proxy; use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; use bytes::{Buf, Bytes}; -use h3::{quic::BidiStream, server::RequestStream}; +use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; use hyper::{client::connect::Connect, Body, Request, Response}; use std::net::SocketAddr; use tokio::time::{timeout, Duration}; @@ -11,67 +11,64 @@ where T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { - pub(super) async fn connection_serve_h3( + pub(super) async fn connection_serve_h3( self, - conn: quinn::Connecting, + quic_connection: C, tls_server_name: ServerNameBytesExp, - ) -> Result<()> { - let client_addr = conn.remote_address(); - - match conn.await { - Ok(new_conn) => { - let mut h3_conn = h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn)).await?; - info!( - "QUIC/HTTP3 connection established from {:?} {:?}", - client_addr, tls_server_name - ); - // TODO: Is here enough to fetch server_name from NewConnection? - // to avoid deep nested call from listener_service_h3 - loop { - // this routine follows hyperium/h3 examples https://github.com/hyperium/h3/blob/master/examples/server.rs - match h3_conn.accept().await { - Ok(None) => { - break; - } - Err(e) => { - warn!("HTTP/3 error on accept incoming connection: {}", e); - match e.get_error_level() { - h3::error::ErrorLevel::ConnectionError => break, - h3::error::ErrorLevel::StreamError => continue, - } - } - Ok(Some((req, stream))) => { - // We consider the connection count separately from the stream count. - // Max clients for h1/h2 = max 'stream' for h3. - let request_count = self.globals.request_count.clone(); - if request_count.increment() > self.globals.proxy_config.max_clients { - request_count.decrement(); - h3_conn.shutdown(0).await?; - break; - } - debug!("Request incoming: current # {}", request_count.current()); - - let self_inner = self.clone(); - let tls_server_name_inner = tls_server_name.clone(); - self.globals.runtime_handle.spawn(async move { - if let Err(e) = timeout( - self_inner.globals.proxy_config.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2 - self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), - ) - .await - { - error!("HTTP/3 failed to process stream: {}", e); - } - request_count.decrement(); - debug!("Request processed: current # {}", request_count.current()); - }); - } + client_addr: SocketAddr, + ) -> Result<()> + where + C: ConnectionQuic, + >::BidiStream: BidiStream + Send + 'static, + <>::BidiStream as BidiStream>::RecvStream: Send, + <>::BidiStream as BidiStream>::SendStream: Send, + { + let mut h3_conn = h3::server::Connection::<_, Bytes>::new(quic_connection).await?; + info!( + "QUIC/HTTP3 connection established from {:?} {:?}", + client_addr, tls_server_name + ); + // TODO: Is here enough to fetch server_name from NewConnection? + // to avoid deep nested call from listener_service_h3 + loop { + // this routine follows hyperium/h3 examples https://github.com/hyperium/h3/blob/master/examples/server.rs + match h3_conn.accept().await { + Ok(None) => { + break; + } + Err(e) => { + warn!("HTTP/3 error on accept incoming connection: {}", e); + match e.get_error_level() { + h3::error::ErrorLevel::ConnectionError => break, + h3::error::ErrorLevel::StreamError => continue, } } - } - Err(err) => { - warn!("QUIC accepting connection failed: {:?}", err); - return Err(RpxyError::QuicConn(err)); + Ok(Some((req, stream))) => { + // We consider the connection count separately from the stream count. + // Max clients for h1/h2 = max 'stream' for h3. + let request_count = self.globals.request_count.clone(); + if request_count.increment() > self.globals.proxy_config.max_clients { + request_count.decrement(); + h3_conn.shutdown(0).await?; + break; + } + debug!("Request incoming: current # {}", request_count.current()); + + let self_inner = self.clone(); + let tls_server_name_inner = tls_server_name.clone(); + self.globals.runtime_handle.spawn(async move { + if let Err(e) = timeout( + self_inner.globals.proxy_config.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2 + self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), + ) + .await + { + error!("HTTP/3 failed to process stream: {}", e); + } + request_count.decrement(); + debug!("Request processed: current # {}", request_count.current()); + }); + } } } diff --git a/rpxy-lib/src/proxy/proxy_quic.rs b/rpxy-lib/src/proxy/proxy_quic.rs new file mode 100644 index 0000000..0e660c1 --- /dev/null +++ b/rpxy-lib/src/proxy/proxy_quic.rs @@ -0,0 +1,124 @@ +use super::socket::bind_udp_socket; +use super::{ + crypto_service::{ServerCrypto, ServerCryptoBase}, + proxy_main::Proxy, +}; +use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; +use hot_reload::ReloaderReceiver; +use hyper::client::connect::Connect; +use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; +use rustls::ServerConfig; +use std::sync::Arc; + +impl Proxy +where + T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone + Sync + Send + 'static, +{ + pub(super) async fn listener_service_h3( + &self, + mut server_crypto_rx: ReloaderReceiver, + ) -> Result<()> { + info!("Start UDP proxy serving with HTTP/3 request for configured host names"); + // first set as null config server + let rustls_server_config = ServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13])? + .with_no_client_auth() + .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new())); + + let mut transport_config_quic = TransportConfig::default(); + transport_config_quic + .max_concurrent_bidi_streams(self.globals.proxy_config.h3_max_concurrent_bidistream) + .max_concurrent_uni_streams(self.globals.proxy_config.h3_max_concurrent_unistream) + .max_idle_timeout( + self + .globals + .proxy_config + .h3_max_idle_timeout + .map(|v| quinn::IdleTimeout::try_from(v).unwrap()), + ); + + let mut server_config_h3 = QuicServerConfig::with_crypto(Arc::new(rustls_server_config)); + server_config_h3.transport = Arc::new(transport_config_quic); + server_config_h3.concurrent_connections(self.globals.proxy_config.h3_max_concurrent_connections); + + // To reuse address + let udp_socket = bind_udp_socket(&self.listening_on)?; + let runtime = quinn::default_runtime() + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "No async runtime found"))?; + let endpoint = Endpoint::new( + quinn::EndpointConfig::default(), + Some(server_config_h3), + udp_socket, + runtime, + )?; + + let mut server_crypto: Option> = None; + loop { + tokio::select! { + new_conn = endpoint.accept() => { + if server_crypto.is_none() || new_conn.is_none() { + continue; + } + let mut conn: quinn::Connecting = new_conn.unwrap(); + let Ok(hsd) = conn.handshake_data().await else { + continue + }; + + let Ok(hsd_downcast) = hsd.downcast::() else { + continue + }; + let Some(new_server_name) = hsd_downcast.server_name else { + warn!("HTTP/3 no SNI is given"); + continue; + }; + debug!( + "HTTP/3 connection incoming (SNI {:?})", + new_server_name + ); + // TODO: server_nameをここで出してどんどん深く投げていくのは効率が悪い。connecting -> connectionsの後でいいのでは? + // TODO: 通常のTLSと同じenumか何かにまとめたい + let self_clone = self.clone(); + self.globals.runtime_handle.spawn(async move { + let client_addr = conn.remote_address(); + let quic_connection = match conn.await { + Ok(new_conn) => { + info!("New connection established"); + h3_quinn::Connection::new(new_conn) + }, + Err(e) => { + warn!("QUIC accepting connection failed: {:?}", e); + return Err(RpxyError::QuicConn(e)); + } + }; + // Timeout is based on underlying quic + if let Err(e) = self_clone.connection_serve_h3(quic_connection, new_server_name.to_server_name_vec(), client_addr).await { + warn!("QUIC or HTTP/3 connection failed: {}", e); + }; + Ok(()) + }); + } + _ = server_crypto_rx.changed() => { + if server_crypto_rx.borrow().is_none() { + error!("Reloader is broken"); + break; + } + let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); + + server_crypto = (&cert_keys_map).try_into().ok(); + let Some(inner) = server_crypto.clone() else { + error!("Failed to update server crypto for h3"); + break; + }; + endpoint.set_server_config(Some(QuicServerConfig::with_crypto(inner.clone().inner_global_no_client_auth.clone()))); + + } + else => break + } + } + endpoint.wait_idle().await; + Ok(()) as Result<()> + } +} diff --git a/rpxy-lib/src/proxy/proxy_tls.rs b/rpxy-lib/src/proxy/proxy_tls.rs index 5512eff..b937b02 100644 --- a/rpxy-lib/src/proxy/proxy_tls.rs +++ b/rpxy-lib/src/proxy/proxy_tls.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "http3")] -use super::socket::bind_udp_socket; use super::{ crypto_service::{CryptoReloader, ServerCrypto, ServerCryptoBase, SniServerCryptoMap}, proxy_main::{LocalExecutor, Proxy}, @@ -8,10 +6,6 @@ use super::{ use crate::{certs::CryptoSource, constants::*, error::*, log::*, utils::BytesName}; use hot_reload::{ReloaderReceiver, ReloaderService}; use hyper::{client::connect::Connect, server::conn::Http}; -#[cfg(feature = "http3")] -use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; -#[cfg(feature = "http3")] -use rustls::ServerConfig; use std::sync::Arc; use tokio::time::{timeout, Duration}; @@ -105,99 +99,6 @@ where Ok(()) as Result<()> } - #[cfg(feature = "http3")] - async fn listener_service_h3(&self, mut server_crypto_rx: ReloaderReceiver) -> Result<()> { - info!("Start UDP proxy serving with HTTP/3 request for configured host names"); - // first set as null config server - let rustls_server_config = ServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13])? - .with_no_client_auth() - .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new())); - - let mut transport_config_quic = TransportConfig::default(); - transport_config_quic - .max_concurrent_bidi_streams(self.globals.proxy_config.h3_max_concurrent_bidistream) - .max_concurrent_uni_streams(self.globals.proxy_config.h3_max_concurrent_unistream) - .max_idle_timeout( - self - .globals - .proxy_config - .h3_max_idle_timeout - .map(|v| quinn::IdleTimeout::try_from(v).unwrap()), - ); - - let mut server_config_h3 = QuicServerConfig::with_crypto(Arc::new(rustls_server_config)); - server_config_h3.transport = Arc::new(transport_config_quic); - server_config_h3.concurrent_connections(self.globals.proxy_config.h3_max_concurrent_connections); - - // To reuse address - let udp_socket = bind_udp_socket(&self.listening_on)?; - let runtime = quinn::default_runtime() - .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "No async runtime found"))?; - let endpoint = Endpoint::new( - quinn::EndpointConfig::default(), - Some(server_config_h3), - udp_socket, - runtime, - )?; - - let mut server_crypto: Option> = None; - loop { - tokio::select! { - new_conn = endpoint.accept() => { - if server_crypto.is_none() || new_conn.is_none() { - continue; - } - let mut conn: quinn::Connecting = new_conn.unwrap(); - let Ok(hsd) = conn.handshake_data().await else { - continue - }; - - let Ok(hsd_downcast) = hsd.downcast::() else { - continue - }; - let Some(new_server_name) = hsd_downcast.server_name else { - warn!("HTTP/3 no SNI is given"); - continue; - }; - debug!( - "HTTP/3 connection incoming (SNI {:?})", - new_server_name - ); - // TODO: server_nameをここで出してどんどん深く投げていくのは効率が悪い。connecting -> connectionsの後でいいのでは? - // TODO: 通常のTLSと同じenumか何かにまとめたい - let fut = self.clone().connection_serve_h3(conn, new_server_name.to_server_name_vec()); - self.globals.runtime_handle.spawn(async move { - // Timeout is based on underlying quic - if let Err(e) = fut.await { - warn!("QUIC or HTTP/3 connection failed: {}", e) - } - }); - } - _ = server_crypto_rx.changed() => { - if server_crypto_rx.borrow().is_none() { - error!("Reloader is broken"); - break; - } - let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); - - server_crypto = (&cert_keys_map).try_into().ok(); - let Some(inner) = server_crypto.clone() else { - error!("Failed to update server crypto for h3"); - break; - }; - endpoint.set_server_config(Some(QuicServerConfig::with_crypto(inner.clone().inner_global_no_client_auth.clone()))); - - } - else => break - } - } - endpoint.wait_idle().await; - Ok(()) as Result<()> - } - pub async fn start_with_tls(self, server: Http) -> Result<()> { let (cert_reloader_service, cert_reloader_rx) = ReloaderService::, ServerCryptoBase>::new( &self.globals.clone(), diff --git a/s2n-quic b/s2n-quic new file mode 160000 index 0000000..179acb8 --- /dev/null +++ b/s2n-quic @@ -0,0 +1 @@ +Subproject commit 179acb8a873eafbfc7b68de4018cd251caddfa44