From 01a13d0168ffd9167b1e716272502d1409c35e07 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 23 Jul 2022 14:59:31 +0900 Subject: [PATCH] spawn handshake async task to avoid blocking --- src/constants.rs | 2 +- src/proxy/proxy_tls.rs | 69 ++++++++++++++++++++++++------------------ 2 files changed, 40 insertions(+), 31 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 4abc782..aaa06eb 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -4,9 +4,9 @@ pub const LISTEN_ADDRESSES_V6: &[&str] = &["[::]"]; // pub const HTTPS_LISTEN_PORT: u16 = 8443; pub const PROXY_TIMEOUT_SEC: u64 = 60; pub const UPSTREAM_TIMEOUT_SEC: u64 = 60; +pub const TLS_HANDSHAKE_TIMEOUT_SEC: u64 = 3; pub const MAX_CLIENTS: usize = 512; pub const MAX_CONCURRENT_STREAMS: u32 = 64; -// #[cfg(feature = "tls")] pub const CERTS_WATCH_DELAY_SECS: u32 = 60; // #[cfg(feature = "http3")] diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index ac743d9..a9c12a1 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use tokio::{ net::TcpListener, sync::watch, - time::{sleep, Duration}, + time::{sleep, timeout, Duration}, }; use tokio_rustls::TlsAcceptor; @@ -56,28 +56,36 @@ where let server_clone = server.clone(); let self_inner = self.clone(); - let fut = async move { - match acceptor.accept(raw_stream).await { - Ok(stream) => { - // Retrieve SNI - let (_, conn) = stream.get_ref(); - let server_name = conn.sni_hostname(); - debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", server_name); - let server_name = server_name.map_or_else(|| None, |v| Some(v.as_bytes().to_ascii_lowercase())); - if server_name.is_none(){ - Err(anyhow!("No SNI is given")) - } else { - self_inner.client_serve(stream, server_clone, client_addr, server_name); // TODO: don't want to pass copied value... - Ok(()) + // spawns async handshake to avoid blocking thread by sequential handshake. + let handshake_fut = async move { + // timeout is introduced to avoid get stuck here. + match timeout(Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), acceptor.accept(raw_stream)).await { + Ok(x) => match x { + Ok(stream) => { + // Retrieve SNI + let (_, conn) = stream.get_ref(); + let server_name = conn.sni_hostname(); + debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", server_name); + let server_name = server_name.map_or_else(|| None, |v| Some(v.as_bytes().to_ascii_lowercase())); + if server_name.is_none(){ + Err(anyhow!("No SNI is given")) + } else { + // this immediately spawns another future to actually handle stream. so it is okay to introduce timeout for handshake. + self_inner.client_serve(stream, server_clone, client_addr, server_name); // TODO: don't want to pass copied value... + Ok(()) + } + }, + Err(e) => { + Err(anyhow!("Failed to handshake TLS: {}", e)) } }, Err(e) => { - Err(anyhow!("Failed to accept TLS stream {}", e)) + Err(anyhow!("Timeout to handshake TLS: {}", e)) } } }; self.globals.runtime_handle.spawn( async move { - if let Err(e) = fut.await { + if let Err(e) = handshake_fut.await { error!("{}", e); } }); @@ -131,11 +139,12 @@ where Ok(d) => d, Err(_) => continue }; - let new_server_name = if let Some(sn) = hsd_downcast.server_name { - sn.as_bytes().to_ascii_lowercase() - } else { - warn!("HTTP/3 no SNI is given"); - continue; + let new_server_name = match hsd_downcast.server_name { + Some(sn) => sn.as_bytes().to_ascii_lowercase(), + None => { + warn!("HTTP/3 no SNI is given"); + continue; + } }; debug!( "HTTP/3 connection incoming (SNI {:?})", @@ -161,7 +170,7 @@ where endpoint.set_server_config(Some(QuicServerConfig::with_crypto(server_crypto.clone().unwrap()))); } } - // complete => break + else => break } } endpoint.wait_idle().await; @@ -199,10 +208,10 @@ where _= self.listener_service_h3(rx) => { error!("UDP proxy service for QUIC exited"); }, - // complete => { - // error!("Something went wrong"); - // return Ok(()) - // } + else => { + error!("Something went wrong"); + return Ok(()) + } }; Ok(()) } else { @@ -213,10 +222,10 @@ where _ = self.listener_service(server, rx) => { error!("TCP proxy service for TLS exited"); }, - // complete => { - // error!("Something went wrong"); - // return Ok(()) - // } + else => { + error!("Something went wrong"); + return Ok(()) + } }; Ok(()) }