From 36c8ebcb5467988ac536a92fe4752f8d7b830055 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 22 Jul 2022 22:26:18 +0900 Subject: [PATCH] change handling of maximum capable request number, max streams --- Cargo.toml | 11 +++++---- config-example.toml | 16 +++++++++---- src/config/parse.rs | 23 +++++++++++++++---- src/config/toml.rs | 11 ++++++++- src/constants.rs | 14 ++++++++---- src/globals.rs | 18 ++++++++++----- src/log.rs | 6 ++--- src/main.rs | 41 +++++++++++++++++++-------------- src/msg_handler/handler.rs | 3 +-- src/proxy/proxy_h3.rs | 47 +++++++++++++++----------------------- src/proxy/proxy_main.rs | 11 +++++---- src/proxy/proxy_tls.rs | 26 +++++++++++++++------ 12 files changed, 138 insertions(+), 89 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 18485a7..9743e0c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "rust-rpxy" +name = "rpxy" version = "0.1.0" authors = ["Jun Kurihara"] homepage = "https://github.com/junkurihara/rust-rpxy" @@ -17,7 +17,7 @@ h3 = ["quinn"] [dependencies] anyhow = "1.0.58" -clap = { version = "3.2.10", features = ["std", "cargo", "wrap_help"] } +clap = { version = "3.2.13", features = ["std", "cargo", "wrap_help"] } env_logger = "0.9.0" futures = "0.3.21" hyper = { version = "0.14.20", default-features = false, features = [ @@ -27,7 +27,7 @@ hyper = { version = "0.14.20", default-features = false, features = [ "stream", ] } log = "0.4.17" -tokio = { version = "1.19.2", default-features = false, features = [ +tokio = { version = "1.20.0", default-features = false, features = [ "net", "rt-multi-thread", "parking_lot", @@ -51,9 +51,10 @@ hyper-rustls = { version = "0.23.0", default-features = false, features = [ quinn = { version = "0.8.3", optional = true } h3 = { git = "https://github.com/hyperium/h3.git" } h3-quinn = { git = "https://github.com/hyperium/h3.git" } -bytes = "1.1.0" -mimalloc-rust = "0.2.0" +bytes = "1.2.0" +[target.'cfg(not(target_env = "msvc"))'.dependencies] +tikv-jemallocator = "0.5.0" [dev-dependencies] diff --git a/config-example.toml b/config-example.toml index 68c0e8a..ef8f5d9 100644 --- a/config-example.toml +++ b/config-example.toml @@ -10,8 +10,10 @@ listen_port = 8080 listen_port_tls = 8443 -# Optional -max_concurrent_streams = 128 +# Optional for h2 and http1.1 +max_concurrent_streams = 100 + +# Optional. Counted in total for http1.1, 2, 3 max_clients = 512 # Optional: Listen [::] @@ -65,5 +67,11 @@ reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }] # Experimantal settings # ################################### [experimental] -h3 = true -ignore_sni_consistency = false # Higly recommend not to be true. If true, you ignore RFC. +ignore_sni_consistency = false # Higly recommend not to be true. If true, you ignore RFC. if not specified, it is always false. + +[experimenta.h3] # If this specified, h3 is enabled +alt_svc_max_age = 3600 # sec +request_max_body_size = 65536 # bytes +max_concurrent_connections = 10000 +max_concurrent_bidistream = 100 +max_concurrent_unistream = 100 diff --git a/src/config/parse.rs b/src/config/parse.rs index 40319c1..a8290a6 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -157,16 +157,29 @@ pub fn parse_opts(globals: &mut Globals) -> Result<()> { // experimental if let Some(exp) = config.experimental { - if let Some(b) = exp.h3 { - globals.http3 = b; - if b { - info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable.") + if let Some(h3option) = exp.h3 { + globals.http3 = true; + info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable."); + if let Some(x) = h3option.alt_svc_max_age { + globals.h3_alt_svc_max_age = x; + } + if let Some(x) = h3option.request_max_body_size { + globals.h3_request_max_body_size = x; + } + if let Some(x) = h3option.max_concurrent_connections { + globals.h3_max_concurrent_connections = x; + } + if let Some(x) = h3option.max_concurrent_bidistream { + globals.h3_max_concurrent_bidistream = x.into(); + } + if let Some(x) = h3option.max_concurrent_unistream { + globals.h3_max_concurrent_unistream = x.into(); } } if let Some(b) = exp.ignore_sni_consistency { globals.sni_consistency = !b; if b { - info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC.") + info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC."); } } } diff --git a/src/config/toml.rs b/src/config/toml.rs index 19813ab..11e2db9 100644 --- a/src/config/toml.rs +++ b/src/config/toml.rs @@ -15,9 +15,18 @@ pub struct ConfigToml { pub experimental: Option, } +#[derive(Deserialize, Debug, Default)] +pub struct Http3Option { + pub alt_svc_max_age: Option, + pub request_max_body_size: Option, + pub max_concurrent_connections: Option, + pub max_concurrent_bidistream: Option, + pub max_concurrent_unistream: Option, +} + #[derive(Deserialize, Debug, Default)] pub struct Experimental { - pub h3: Option, + pub h3: Option, pub ignore_sni_consistency: Option, } diff --git a/src/constants.rs b/src/constants.rs index 76bcb7b..3816ab6 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -5,17 +5,21 @@ pub const LISTEN_ADDRESSES_V6: &[&str] = &["[::]"]; pub const PROXY_TIMEOUT_SEC: u64 = 60; pub const UPSTREAM_TIMEOUT_SEC: u64 = 60; pub const MAX_CLIENTS: usize = 512; -pub const MAX_CONCURRENT_STREAMS: u32 = 32; +pub const MAX_CONCURRENT_STREAMS: u32 = 64; // #[cfg(feature = "tls")] pub const CERTS_WATCH_DELAY_SECS: u32 = 30; -#[cfg(feature = "h3")] -pub const H3_ALT_SVC_MAX_AGE: u32 = 3600; - // #[cfg(feature = "h3")] // pub const H3_RESPONSE_BUF_SIZE: usize = 65_536; // 64KB // #[cfg(feature = "h3")] // pub const H3_REQUEST_BUF_SIZE: usize = 65_536; // 64KB // handled by quinn +#[allow(non_snake_case)] #[cfg(feature = "h3")] -pub const H3_REQUEST_MAX_BODY_SIZE: usize = 268_435_456; // 256MB +pub mod H3 { + pub const ALT_SVC_MAX_AGE: u32 = 3600; + pub const REQUEST_MAX_BODY_SIZE: usize = 268_435_456; // 256MB + pub const MAX_CONCURRENT_CONNECTIONS: u32 = 4096; + pub const MAX_CONCURRENT_BIDISTREAM: u32 = 64; + pub const MAX_CONCURRENT_UNISTREAM: u32 = 64; +} diff --git a/src/globals.rs b/src/globals.rs index abec03f..1880fe8 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -15,21 +15,27 @@ pub struct Globals { pub upstream_timeout: Duration, pub max_clients: usize, - pub clients_count: ClientsCount, + pub request_count: RequestCount, pub max_concurrent_streams: u32, pub keepalive: bool, - pub http3: bool, - pub sni_consistency: bool, pub runtime_handle: tokio::runtime::Handle, - pub backends: Backends, + + // experimentals + pub sni_consistency: bool, + pub http3: bool, + pub h3_alt_svc_max_age: u32, + pub h3_request_max_body_size: usize, + pub h3_max_concurrent_bidistream: quinn::VarInt, + pub h3_max_concurrent_unistream: quinn::VarInt, + pub h3_max_concurrent_connections: u32, } #[derive(Debug, Clone, Default)] -pub struct ClientsCount(Arc); +pub struct RequestCount(Arc); -impl ClientsCount { +impl RequestCount { pub fn current(&self) -> usize { self.0.load(Ordering::Relaxed) } diff --git a/src/log.rs b/src/log.rs index 7a20d42..10a0267 100644 --- a/src/log.rs +++ b/src/log.rs @@ -1,8 +1,6 @@ -use std::net::SocketAddr; - -pub use log::{debug, error, info, warn}; - use crate::utils::ToCanonical; +pub use log::{debug, error, info, warn, Level}; +use std::net::SocketAddr; #[derive(Debug, Clone)] pub struct MessageLog { diff --git a/src/main.rs b/src/main.rs index 23ce609..e87ff75 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,9 @@ -use mimalloc_rust::*; +#[cfg(not(target_env = "msvc"))] +use tikv_jemallocator::Jemalloc; + +#[cfg(not(target_env = "msvc"))] #[global_allocator] -static GLOBAL_MIMALLOC: GlobalMiMalloc = GlobalMiMalloc; +static GLOBAL: Jemalloc = Jemalloc; mod backend; mod backend_opt; @@ -33,18 +36,16 @@ use tokio::time::Duration; fn main() { // env::set_var("RUST_LOG", "info"); env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) - .format(|buf, record| { + .format(|buf, rec| { let ts = buf.timestamp(); - writeln!( - buf, - "{} [{}] {}", - ts, - record.level(), - // record.target(), - record.args(), - // record.file().unwrap_or("unknown"), - // record.line().unwrap_or(0), - ) + match rec.level() { + log::Level::Debug => { + writeln!(buf, "{} [{}] {} ({})", ts, rec.level(), rec.args(), rec.target(),) + } + _ => { + writeln!(buf, "{} [{}] {}", ts, rec.level(), rec.args(),) + } + } }) .init(); info!("Start http (reverse) proxy"); @@ -59,23 +60,29 @@ fn main() { listen_sockets: Vec::new(), http_port: None, https_port: None, - http3: false, - sni_consistency: true, // TODO: Reconsider each timeout values proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC), upstream_timeout: Duration::from_secs(UPSTREAM_TIMEOUT_SEC), max_clients: MAX_CLIENTS, - clients_count: Default::default(), + request_count: Default::default(), max_concurrent_streams: MAX_CONCURRENT_STREAMS, keepalive: true, - runtime_handle: runtime.handle().clone(), + runtime_handle: runtime.handle().clone(), backends: Backends { default_server_name: None, apps: HashMap::::default(), }, + + sni_consistency: true, + http3: false, + h3_alt_svc_max_age: H3::ALT_SVC_MAX_AGE, + h3_request_max_body_size: H3::REQUEST_MAX_BODY_SIZE, + h3_max_concurrent_connections: H3::MAX_CONCURRENT_CONNECTIONS, + h3_max_concurrent_bidistream: H3::MAX_CONCURRENT_BIDISTREAM.into(), + h3_max_concurrent_unistream: H3::MAX_CONCURRENT_UNISTREAM.into(), }; if let Err(e) = parse_opts(&mut globals) { diff --git a/src/msg_handler/handler.rs b/src/msg_handler/handler.rs index 76dd8fa..ea2b621 100644 --- a/src/msg_handler/handler.rs +++ b/src/msg_handler/handler.rs @@ -2,7 +2,6 @@ use super::{utils_headers::*, utils_request::*, utils_response::ResLog, utils_synth_response::*}; use crate::{ backend::{ServerNameLC, Upstream}, - constants::*, error::*, globals::Globals, log::*, @@ -229,7 +228,7 @@ where header::ALT_SVC.as_str(), format!( "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", - port, H3_ALT_SVC_MAX_AGE, port, H3_ALT_SVC_MAX_AGE + port, self.globals.h3_alt_svc_max_age, port, self.globals.h3_alt_svc_max_age ), )?; } diff --git a/src/proxy/proxy_h3.rs b/src/proxy/proxy_h3.rs index 7749ba2..7eb5868 100644 --- a/src/proxy/proxy_h3.rs +++ b/src/proxy/proxy_h3.rs @@ -1,5 +1,5 @@ use super::Proxy; -use crate::{backend::ServerNameLC, constants::*, error::*, log::*}; +use crate::{backend::ServerNameLC, error::*, log::*}; use bytes::{Buf, Bytes}; use h3::{quic::BidiStream, server::RequestStream}; use hyper::{client::connect::Connect, Body, Request, Response}; @@ -10,12 +10,7 @@ impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - pub(super) fn client_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) { - let clients_count = self.globals.clients_count.clone(); - if clients_count.increment() > self.globals.max_clients { - clients_count.decrement(); - return; - } + pub(super) fn connection_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) { let fut = self .clone() .handle_connection_h3(conn, tls_server_name.to_vec()); @@ -24,8 +19,6 @@ where if let Err(e) = fut.await { warn!("QUIC or HTTP/3 connection failed: {}", e) } - clients_count.decrement(); - debug!("Client #: {}", clients_count.current()); }); } @@ -45,40 +38,37 @@ where "QUIC/HTTP3 connection established from {:?} {:?}", client_addr, tls_server_name ); - - // Does this work enough? - // while let Some((req, stream)) = h3_conn - // .accept() - // .await - // .map_err(|e| anyhow!("HTTP/3 accept failed: {}", e))? + // TODO: Is here enough to fetch server_name from NewConnection? + // to avoid deep nested call from listener_service_h3 while let Some((req, stream)) = match h3_conn.accept().await { Ok(opt_req) => opt_req, Err(e) => { - warn!( - "HTTP/3 failed to accept incoming connection (likely timeout): {}", - e - ); + warn!("HTTP/3 failed to accept incoming connection: {}", e); return Ok(h3_conn.shutdown(0).await?); } } { - debug!( - "HTTP/3 new request from {}: {} {}", - client_addr, - req.method(), - req.uri() - ); + // We consider the connection count separately from the stream count. + // Max clients for h1/h2 = max 'stream' for h3. + let request_count = self.globals.request_count.clone(); + if request_count.increment() > self.globals.max_clients { + request_count.decrement(); + return Ok(h3_conn.shutdown(0).await?); + } + debug!("Request incoming: current # {}", request_count.current()); let self_inner = self.clone(); let tls_server_name_inner = tls_server_name.clone(); self.globals.runtime_handle.spawn(async move { if let Err(e) = timeout( self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2 - self_inner.handle_stream_h3(req, stream, client_addr, tls_server_name_inner), + self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), ) .await { error!("HTTP/3 failed to process stream: {}", e); } + request_count.decrement(); + debug!("Request processed: current # {}", request_count.current()); }); } } @@ -91,7 +81,7 @@ where Ok(()) } - async fn handle_stream_h3( + async fn stream_serve_h3( self, req: Request<()>, stream: RequestStream, @@ -111,13 +101,14 @@ where // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1 // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn. + let max_body_size = self.globals.h3_request_max_body_size; self.globals.runtime_handle.spawn(async move { let mut sender = body_sender; let mut size = 0usize; while let Some(mut body) = recv_stream.recv_data().await? { debug!("HTTP/3 incoming request body"); size += body.remaining(); - if size > H3_REQUEST_MAX_BODY_SIZE { + if size > max_body_size { error!("Exceeds max request body size for HTTP/3"); return Err(anyhow!("Exceeds max request body size for HTTP/3")); } diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index c56e7b2..3fa8df5 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -56,11 +56,12 @@ where ) where I: AsyncRead + AsyncWrite + Send + Unpin + 'static, { - let clients_count = self.globals.clients_count.clone(); - if clients_count.increment() > self.globals.max_clients { - clients_count.decrement(); + let request_count = self.globals.request_count.clone(); + if request_count.increment() > self.globals.max_clients { + request_count.decrement(); return; } + debug!("Request incoming: current # {}", request_count.current()); // let inner = tls_server_name.map_or_else(|| None, |v| Some(v.as_bytes().to_ascii_lowercase())); self.globals.runtime_handle.clone().spawn(async move { @@ -84,8 +85,8 @@ where .await .ok(); - clients_count.decrement(); - debug!("Client #: {}", clients_count.current()); + request_count.decrement(); + debug!("Request processed: current # {}", request_count.current()); }); } diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 02f1d20..4503965 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -6,7 +6,11 @@ use futures::{future::FutureExt, select}; use hyper::{client::connect::Connect, server::conn::Http}; use rustls::ServerConfig; use std::sync::Arc; -use tokio::{net::TcpListener, sync::watch, time::Duration}; +use tokio::{ + net::TcpListener, + sync::watch, + time::{sleep, Duration}, +}; use tokio_rustls::TlsAcceptor; impl Proxy @@ -29,7 +33,7 @@ where } else { error!("Failed to update certs"); } - tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; + sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; } } @@ -50,8 +54,7 @@ where if tls_acceptor.is_none() || tcp_cnx.is_err() { continue; } - - let (raw_stream, _client_addr) = tcp_cnx.unwrap(); + let (raw_stream, client_addr) = tcp_cnx.unwrap(); if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await { // Retrieve SNI @@ -62,7 +65,7 @@ where if server_name.is_none(){ continue; } - self.clone().client_serve(stream, server.clone(), _client_addr, server_name); // TODO: don't want to pass copied value... + self.clone().client_serve(stream, server.clone(), client_addr, server_name); // TODO: don't want to pass copied value... } } _ = server_crypto_rx.changed().fuse() => { @@ -83,13 +86,20 @@ where &self, mut server_crypto_rx: watch::Receiver>>, ) -> Result<()> { + let mut transport_config_quic = quinn::TransportConfig::default(); + transport_config_quic + .max_concurrent_bidi_streams(self.globals.h3_max_concurrent_bidistream) + .max_concurrent_uni_streams(self.globals.h3_max_concurrent_unistream); + let server_crypto = self .globals .backends .generate_server_crypto_with_cert_resolver() .await?; - let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); + let mut server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); + server_config_h3.transport = Arc::new(transport_config_quic); + server_config_h3.concurrent_connections(self.globals.h3_max_concurrent_connections); let (endpoint, mut incoming) = quinn::Endpoint::server(server_config_h3, self.listening_on)?; info!("Start UDP proxy serving with HTTP/3 request for configured host names"); @@ -121,7 +131,9 @@ where "HTTP/3 connection incoming (SNI {:?})", new_server_name ); - self.clone().client_serve_h3(conn, new_server_name.as_ref()); + // TODO: server_nameをここで出してどんどん深く投げていくのは効率が悪い。connecting -> connectionsの後でいいのでは? + // TODO: 通常のTLSと同じenumか何かにまとめたい + self.clone().connection_serve_h3(conn, new_server_name.as_ref()); } _ = server_crypto_rx.changed().fuse() => { if server_crypto_rx.borrow().is_none() {