change handling of maximum capable request number, max streams

This commit is contained in:
Jun Kurihara 2022-07-22 22:26:18 +09:00
commit 36c8ebcb54
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
12 changed files with 138 additions and 89 deletions

View file

@ -1,5 +1,5 @@
[package] [package]
name = "rust-rpxy" name = "rpxy"
version = "0.1.0" version = "0.1.0"
authors = ["Jun Kurihara"] authors = ["Jun Kurihara"]
homepage = "https://github.com/junkurihara/rust-rpxy" homepage = "https://github.com/junkurihara/rust-rpxy"
@ -17,7 +17,7 @@ h3 = ["quinn"]
[dependencies] [dependencies]
anyhow = "1.0.58" anyhow = "1.0.58"
clap = { version = "3.2.10", features = ["std", "cargo", "wrap_help"] } clap = { version = "3.2.13", features = ["std", "cargo", "wrap_help"] }
env_logger = "0.9.0" env_logger = "0.9.0"
futures = "0.3.21" futures = "0.3.21"
hyper = { version = "0.14.20", default-features = false, features = [ hyper = { version = "0.14.20", default-features = false, features = [
@ -27,7 +27,7 @@ hyper = { version = "0.14.20", default-features = false, features = [
"stream", "stream",
] } ] }
log = "0.4.17" log = "0.4.17"
tokio = { version = "1.19.2", default-features = false, features = [ tokio = { version = "1.20.0", default-features = false, features = [
"net", "net",
"rt-multi-thread", "rt-multi-thread",
"parking_lot", "parking_lot",
@ -51,9 +51,10 @@ hyper-rustls = { version = "0.23.0", default-features = false, features = [
quinn = { version = "0.8.3", optional = true } quinn = { version = "0.8.3", optional = true }
h3 = { git = "https://github.com/hyperium/h3.git" } h3 = { git = "https://github.com/hyperium/h3.git" }
h3-quinn = { git = "https://github.com/hyperium/h3.git" } h3-quinn = { git = "https://github.com/hyperium/h3.git" }
bytes = "1.1.0" bytes = "1.2.0"
mimalloc-rust = "0.2.0"
[target.'cfg(not(target_env = "msvc"))'.dependencies]
tikv-jemallocator = "0.5.0"
[dev-dependencies] [dev-dependencies]

View file

@ -10,8 +10,10 @@
listen_port = 8080 listen_port = 8080
listen_port_tls = 8443 listen_port_tls = 8443
# Optional # Optional for h2 and http1.1
max_concurrent_streams = 128 max_concurrent_streams = 100
# Optional. Counted in total for http1.1, 2, 3
max_clients = 512 max_clients = 512
# Optional: Listen [::] # Optional: Listen [::]
@ -65,5 +67,11 @@ reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }]
# Experimantal settings # # Experimantal settings #
################################### ###################################
[experimental] [experimental]
h3 = true ignore_sni_consistency = false # Higly recommend not to be true. If true, you ignore RFC. if not specified, it is always false.
ignore_sni_consistency = false # Higly recommend not to be true. If true, you ignore RFC.
[experimenta.h3] # If this specified, h3 is enabled
alt_svc_max_age = 3600 # sec
request_max_body_size = 65536 # bytes
max_concurrent_connections = 10000
max_concurrent_bidistream = 100
max_concurrent_unistream = 100

View file

@ -157,16 +157,29 @@ pub fn parse_opts(globals: &mut Globals) -> Result<()> {
// experimental // experimental
if let Some(exp) = config.experimental { if let Some(exp) = config.experimental {
if let Some(b) = exp.h3 { if let Some(h3option) = exp.h3 {
globals.http3 = b; globals.http3 = true;
if b { info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable.");
info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable.") if let Some(x) = h3option.alt_svc_max_age {
globals.h3_alt_svc_max_age = x;
}
if let Some(x) = h3option.request_max_body_size {
globals.h3_request_max_body_size = x;
}
if let Some(x) = h3option.max_concurrent_connections {
globals.h3_max_concurrent_connections = x;
}
if let Some(x) = h3option.max_concurrent_bidistream {
globals.h3_max_concurrent_bidistream = x.into();
}
if let Some(x) = h3option.max_concurrent_unistream {
globals.h3_max_concurrent_unistream = x.into();
} }
} }
if let Some(b) = exp.ignore_sni_consistency { if let Some(b) = exp.ignore_sni_consistency {
globals.sni_consistency = !b; globals.sni_consistency = !b;
if b { if b {
info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC.") info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC.");
} }
} }
} }

View file

@ -15,9 +15,18 @@ pub struct ConfigToml {
pub experimental: Option<Experimental>, pub experimental: Option<Experimental>,
} }
#[derive(Deserialize, Debug, Default)]
pub struct Http3Option {
pub alt_svc_max_age: Option<u32>,
pub request_max_body_size: Option<usize>,
pub max_concurrent_connections: Option<u32>,
pub max_concurrent_bidistream: Option<u32>,
pub max_concurrent_unistream: Option<u32>,
}
#[derive(Deserialize, Debug, Default)] #[derive(Deserialize, Debug, Default)]
pub struct Experimental { pub struct Experimental {
pub h3: Option<bool>, pub h3: Option<Http3Option>,
pub ignore_sni_consistency: Option<bool>, pub ignore_sni_consistency: Option<bool>,
} }

View file

@ -5,17 +5,21 @@ pub const LISTEN_ADDRESSES_V6: &[&str] = &["[::]"];
pub const PROXY_TIMEOUT_SEC: u64 = 60; pub const PROXY_TIMEOUT_SEC: u64 = 60;
pub const UPSTREAM_TIMEOUT_SEC: u64 = 60; pub const UPSTREAM_TIMEOUT_SEC: u64 = 60;
pub const MAX_CLIENTS: usize = 512; pub const MAX_CLIENTS: usize = 512;
pub const MAX_CONCURRENT_STREAMS: u32 = 32; pub const MAX_CONCURRENT_STREAMS: u32 = 64;
// #[cfg(feature = "tls")] // #[cfg(feature = "tls")]
pub const CERTS_WATCH_DELAY_SECS: u32 = 30; pub const CERTS_WATCH_DELAY_SECS: u32 = 30;
#[cfg(feature = "h3")]
pub const H3_ALT_SVC_MAX_AGE: u32 = 3600;
// #[cfg(feature = "h3")] // #[cfg(feature = "h3")]
// pub const H3_RESPONSE_BUF_SIZE: usize = 65_536; // 64KB // pub const H3_RESPONSE_BUF_SIZE: usize = 65_536; // 64KB
// #[cfg(feature = "h3")] // #[cfg(feature = "h3")]
// pub const H3_REQUEST_BUF_SIZE: usize = 65_536; // 64KB // handled by quinn // pub const H3_REQUEST_BUF_SIZE: usize = 65_536; // 64KB // handled by quinn
#[allow(non_snake_case)]
#[cfg(feature = "h3")] #[cfg(feature = "h3")]
pub const H3_REQUEST_MAX_BODY_SIZE: usize = 268_435_456; // 256MB pub mod H3 {
pub const ALT_SVC_MAX_AGE: u32 = 3600;
pub const REQUEST_MAX_BODY_SIZE: usize = 268_435_456; // 256MB
pub const MAX_CONCURRENT_CONNECTIONS: u32 = 4096;
pub const MAX_CONCURRENT_BIDISTREAM: u32 = 64;
pub const MAX_CONCURRENT_UNISTREAM: u32 = 64;
}

View file

@ -15,21 +15,27 @@ pub struct Globals {
pub upstream_timeout: Duration, pub upstream_timeout: Duration,
pub max_clients: usize, pub max_clients: usize,
pub clients_count: ClientsCount, pub request_count: RequestCount,
pub max_concurrent_streams: u32, pub max_concurrent_streams: u32,
pub keepalive: bool, pub keepalive: bool,
pub http3: bool,
pub sni_consistency: bool,
pub runtime_handle: tokio::runtime::Handle, pub runtime_handle: tokio::runtime::Handle,
pub backends: Backends, pub backends: Backends,
// experimentals
pub sni_consistency: bool,
pub http3: bool,
pub h3_alt_svc_max_age: u32,
pub h3_request_max_body_size: usize,
pub h3_max_concurrent_bidistream: quinn::VarInt,
pub h3_max_concurrent_unistream: quinn::VarInt,
pub h3_max_concurrent_connections: u32,
} }
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct ClientsCount(Arc<AtomicUsize>); pub struct RequestCount(Arc<AtomicUsize>);
impl ClientsCount { impl RequestCount {
pub fn current(&self) -> usize { pub fn current(&self) -> usize {
self.0.load(Ordering::Relaxed) self.0.load(Ordering::Relaxed)
} }

View file

@ -1,8 +1,6 @@
use std::net::SocketAddr;
pub use log::{debug, error, info, warn};
use crate::utils::ToCanonical; use crate::utils::ToCanonical;
pub use log::{debug, error, info, warn, Level};
use std::net::SocketAddr;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct MessageLog { pub struct MessageLog {

View file

@ -1,6 +1,9 @@
use mimalloc_rust::*; #[cfg(not(target_env = "msvc"))]
use tikv_jemallocator::Jemalloc;
#[cfg(not(target_env = "msvc"))]
#[global_allocator] #[global_allocator]
static GLOBAL_MIMALLOC: GlobalMiMalloc = GlobalMiMalloc; static GLOBAL: Jemalloc = Jemalloc;
mod backend; mod backend;
mod backend_opt; mod backend_opt;
@ -33,18 +36,16 @@ use tokio::time::Duration;
fn main() { fn main() {
// env::set_var("RUST_LOG", "info"); // env::set_var("RUST_LOG", "info");
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info"))
.format(|buf, record| { .format(|buf, rec| {
let ts = buf.timestamp(); let ts = buf.timestamp();
writeln!( match rec.level() {
buf, log::Level::Debug => {
"{} [{}] {}", writeln!(buf, "{} [{}] {} ({})", ts, rec.level(), rec.args(), rec.target(),)
ts, }
record.level(), _ => {
// record.target(), writeln!(buf, "{} [{}] {}", ts, rec.level(), rec.args(),)
record.args(), }
// record.file().unwrap_or("unknown"), }
// record.line().unwrap_or(0),
)
}) })
.init(); .init();
info!("Start http (reverse) proxy"); info!("Start http (reverse) proxy");
@ -59,23 +60,29 @@ fn main() {
listen_sockets: Vec::new(), listen_sockets: Vec::new(),
http_port: None, http_port: None,
https_port: None, https_port: None,
http3: false,
sni_consistency: true,
// TODO: Reconsider each timeout values // TODO: Reconsider each timeout values
proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC), proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC),
upstream_timeout: Duration::from_secs(UPSTREAM_TIMEOUT_SEC), upstream_timeout: Duration::from_secs(UPSTREAM_TIMEOUT_SEC),
max_clients: MAX_CLIENTS, max_clients: MAX_CLIENTS,
clients_count: Default::default(), request_count: Default::default(),
max_concurrent_streams: MAX_CONCURRENT_STREAMS, max_concurrent_streams: MAX_CONCURRENT_STREAMS,
keepalive: true, keepalive: true,
runtime_handle: runtime.handle().clone(),
runtime_handle: runtime.handle().clone(),
backends: Backends { backends: Backends {
default_server_name: None, default_server_name: None,
apps: HashMap::<ServerNameLC, Backend>::default(), apps: HashMap::<ServerNameLC, Backend>::default(),
}, },
sni_consistency: true,
http3: false,
h3_alt_svc_max_age: H3::ALT_SVC_MAX_AGE,
h3_request_max_body_size: H3::REQUEST_MAX_BODY_SIZE,
h3_max_concurrent_connections: H3::MAX_CONCURRENT_CONNECTIONS,
h3_max_concurrent_bidistream: H3::MAX_CONCURRENT_BIDISTREAM.into(),
h3_max_concurrent_unistream: H3::MAX_CONCURRENT_UNISTREAM.into(),
}; };
if let Err(e) = parse_opts(&mut globals) { if let Err(e) = parse_opts(&mut globals) {

View file

@ -2,7 +2,6 @@
use super::{utils_headers::*, utils_request::*, utils_response::ResLog, utils_synth_response::*}; use super::{utils_headers::*, utils_request::*, utils_response::ResLog, utils_synth_response::*};
use crate::{ use crate::{
backend::{ServerNameLC, Upstream}, backend::{ServerNameLC, Upstream},
constants::*,
error::*, error::*,
globals::Globals, globals::Globals,
log::*, log::*,
@ -229,7 +228,7 @@ where
header::ALT_SVC.as_str(), header::ALT_SVC.as_str(),
format!( format!(
"h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}",
port, H3_ALT_SVC_MAX_AGE, port, H3_ALT_SVC_MAX_AGE port, self.globals.h3_alt_svc_max_age, port, self.globals.h3_alt_svc_max_age
), ),
)?; )?;
} }

View file

@ -1,5 +1,5 @@
use super::Proxy; use super::Proxy;
use crate::{backend::ServerNameLC, constants::*, error::*, log::*}; use crate::{backend::ServerNameLC, error::*, log::*};
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use h3::{quic::BidiStream, server::RequestStream}; use h3::{quic::BidiStream, server::RequestStream};
use hyper::{client::connect::Connect, Body, Request, Response}; use hyper::{client::connect::Connect, Body, Request, Response};
@ -10,12 +10,7 @@ impl<T> Proxy<T>
where where
T: Connect + Clone + Sync + Send + 'static, T: Connect + Clone + Sync + Send + 'static,
{ {
pub(super) fn client_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) { pub(super) fn connection_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) {
let clients_count = self.globals.clients_count.clone();
if clients_count.increment() > self.globals.max_clients {
clients_count.decrement();
return;
}
let fut = self let fut = self
.clone() .clone()
.handle_connection_h3(conn, tls_server_name.to_vec()); .handle_connection_h3(conn, tls_server_name.to_vec());
@ -24,8 +19,6 @@ where
if let Err(e) = fut.await { if let Err(e) = fut.await {
warn!("QUIC or HTTP/3 connection failed: {}", e) warn!("QUIC or HTTP/3 connection failed: {}", e)
} }
clients_count.decrement();
debug!("Client #: {}", clients_count.current());
}); });
} }
@ -45,40 +38,37 @@ where
"QUIC/HTTP3 connection established from {:?} {:?}", "QUIC/HTTP3 connection established from {:?} {:?}",
client_addr, tls_server_name client_addr, tls_server_name
); );
// TODO: Is here enough to fetch server_name from NewConnection?
// Does this work enough? // to avoid deep nested call from listener_service_h3
// while let Some((req, stream)) = h3_conn
// .accept()
// .await
// .map_err(|e| anyhow!("HTTP/3 accept failed: {}", e))?
while let Some((req, stream)) = match h3_conn.accept().await { while let Some((req, stream)) = match h3_conn.accept().await {
Ok(opt_req) => opt_req, Ok(opt_req) => opt_req,
Err(e) => { Err(e) => {
warn!( warn!("HTTP/3 failed to accept incoming connection: {}", e);
"HTTP/3 failed to accept incoming connection (likely timeout): {}",
e
);
return Ok(h3_conn.shutdown(0).await?); return Ok(h3_conn.shutdown(0).await?);
} }
} { } {
debug!( // We consider the connection count separately from the stream count.
"HTTP/3 new request from {}: {} {}", // Max clients for h1/h2 = max 'stream' for h3.
client_addr, let request_count = self.globals.request_count.clone();
req.method(), if request_count.increment() > self.globals.max_clients {
req.uri() request_count.decrement();
); return Ok(h3_conn.shutdown(0).await?);
}
debug!("Request incoming: current # {}", request_count.current());
let self_inner = self.clone(); let self_inner = self.clone();
let tls_server_name_inner = tls_server_name.clone(); let tls_server_name_inner = tls_server_name.clone();
self.globals.runtime_handle.spawn(async move { self.globals.runtime_handle.spawn(async move {
if let Err(e) = timeout( if let Err(e) = timeout(
self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2 self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
self_inner.handle_stream_h3(req, stream, client_addr, tls_server_name_inner), self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner),
) )
.await .await
{ {
error!("HTTP/3 failed to process stream: {}", e); error!("HTTP/3 failed to process stream: {}", e);
} }
request_count.decrement();
debug!("Request processed: current # {}", request_count.current());
}); });
} }
} }
@ -91,7 +81,7 @@ where
Ok(()) Ok(())
} }
async fn handle_stream_h3<S>( async fn stream_serve_h3<S>(
self, self,
req: Request<()>, req: Request<()>,
stream: RequestStream<S, Bytes>, stream: RequestStream<S, Bytes>,
@ -111,13 +101,14 @@ where
// Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1 // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1
// The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn. // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn.
let max_body_size = self.globals.h3_request_max_body_size;
self.globals.runtime_handle.spawn(async move { self.globals.runtime_handle.spawn(async move {
let mut sender = body_sender; let mut sender = body_sender;
let mut size = 0usize; let mut size = 0usize;
while let Some(mut body) = recv_stream.recv_data().await? { while let Some(mut body) = recv_stream.recv_data().await? {
debug!("HTTP/3 incoming request body"); debug!("HTTP/3 incoming request body");
size += body.remaining(); size += body.remaining();
if size > H3_REQUEST_MAX_BODY_SIZE { if size > max_body_size {
error!("Exceeds max request body size for HTTP/3"); error!("Exceeds max request body size for HTTP/3");
return Err(anyhow!("Exceeds max request body size for HTTP/3")); return Err(anyhow!("Exceeds max request body size for HTTP/3"));
} }

View file

@ -56,11 +56,12 @@ where
) where ) where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static, I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{ {
let clients_count = self.globals.clients_count.clone(); let request_count = self.globals.request_count.clone();
if clients_count.increment() > self.globals.max_clients { if request_count.increment() > self.globals.max_clients {
clients_count.decrement(); request_count.decrement();
return; return;
} }
debug!("Request incoming: current # {}", request_count.current());
// let inner = tls_server_name.map_or_else(|| None, |v| Some(v.as_bytes().to_ascii_lowercase())); // let inner = tls_server_name.map_or_else(|| None, |v| Some(v.as_bytes().to_ascii_lowercase()));
self.globals.runtime_handle.clone().spawn(async move { self.globals.runtime_handle.clone().spawn(async move {
@ -84,8 +85,8 @@ where
.await .await
.ok(); .ok();
clients_count.decrement(); request_count.decrement();
debug!("Client #: {}", clients_count.current()); debug!("Request processed: current # {}", request_count.current());
}); });
} }

View file

@ -6,7 +6,11 @@ use futures::{future::FutureExt, select};
use hyper::{client::connect::Connect, server::conn::Http}; use hyper::{client::connect::Connect, server::conn::Http};
use rustls::ServerConfig; use rustls::ServerConfig;
use std::sync::Arc; use std::sync::Arc;
use tokio::{net::TcpListener, sync::watch, time::Duration}; use tokio::{
net::TcpListener,
sync::watch,
time::{sleep, Duration},
};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
impl<T> Proxy<T> impl<T> Proxy<T>
@ -29,7 +33,7 @@ where
} else { } else {
error!("Failed to update certs"); error!("Failed to update certs");
} }
tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await;
} }
} }
@ -50,8 +54,7 @@ where
if tls_acceptor.is_none() || tcp_cnx.is_err() { if tls_acceptor.is_none() || tcp_cnx.is_err() {
continue; continue;
} }
let (raw_stream, client_addr) = tcp_cnx.unwrap();
let (raw_stream, _client_addr) = tcp_cnx.unwrap();
if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await { if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await {
// Retrieve SNI // Retrieve SNI
@ -62,7 +65,7 @@ where
if server_name.is_none(){ if server_name.is_none(){
continue; continue;
} }
self.clone().client_serve(stream, server.clone(), _client_addr, server_name); // TODO: don't want to pass copied value... self.clone().client_serve(stream, server.clone(), client_addr, server_name); // TODO: don't want to pass copied value...
} }
} }
_ = server_crypto_rx.changed().fuse() => { _ = server_crypto_rx.changed().fuse() => {
@ -83,13 +86,20 @@ where
&self, &self,
mut server_crypto_rx: watch::Receiver<Option<Arc<ServerConfig>>>, mut server_crypto_rx: watch::Receiver<Option<Arc<ServerConfig>>>,
) -> Result<()> { ) -> Result<()> {
let mut transport_config_quic = quinn::TransportConfig::default();
transport_config_quic
.max_concurrent_bidi_streams(self.globals.h3_max_concurrent_bidistream)
.max_concurrent_uni_streams(self.globals.h3_max_concurrent_unistream);
let server_crypto = self let server_crypto = self
.globals .globals
.backends .backends
.generate_server_crypto_with_cert_resolver() .generate_server_crypto_with_cert_resolver()
.await?; .await?;
let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); let mut server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto));
server_config_h3.transport = Arc::new(transport_config_quic);
server_config_h3.concurrent_connections(self.globals.h3_max_concurrent_connections);
let (endpoint, mut incoming) = quinn::Endpoint::server(server_config_h3, self.listening_on)?; let (endpoint, mut incoming) = quinn::Endpoint::server(server_config_h3, self.listening_on)?;
info!("Start UDP proxy serving with HTTP/3 request for configured host names"); info!("Start UDP proxy serving with HTTP/3 request for configured host names");
@ -121,7 +131,9 @@ where
"HTTP/3 connection incoming (SNI {:?})", "HTTP/3 connection incoming (SNI {:?})",
new_server_name new_server_name
); );
self.clone().client_serve_h3(conn, new_server_name.as_ref()); // TODO: server_nameをここで出してどんどん深く投げていくのは効率が悪い。connecting -> connectionsの後でいいのでは
// TODO: 通常のTLSと同じenumか何かにまとめたい
self.clone().connection_serve_h3(conn, new_server_name.as_ref());
} }
_ = server_crypto_rx.changed().fuse() => { _ = server_crypto_rx.changed().fuse() => {
if server_crypto_rx.borrow().is_none() { if server_crypto_rx.borrow().is_none() {