spawn handshake async task to avoid blocking

This commit is contained in:
Jun Kurihara 2022-07-23 14:59:31 +09:00
commit 01a13d0168
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
2 changed files with 40 additions and 31 deletions

View file

@ -4,9 +4,9 @@ pub const LISTEN_ADDRESSES_V6: &[&str] = &["[::]"];
// pub const HTTPS_LISTEN_PORT: u16 = 8443; // pub const HTTPS_LISTEN_PORT: u16 = 8443;
pub const PROXY_TIMEOUT_SEC: u64 = 60; pub const PROXY_TIMEOUT_SEC: u64 = 60;
pub const UPSTREAM_TIMEOUT_SEC: u64 = 60; pub const UPSTREAM_TIMEOUT_SEC: u64 = 60;
pub const TLS_HANDSHAKE_TIMEOUT_SEC: u64 = 3;
pub const MAX_CLIENTS: usize = 512; pub const MAX_CLIENTS: usize = 512;
pub const MAX_CONCURRENT_STREAMS: u32 = 64; pub const MAX_CONCURRENT_STREAMS: u32 = 64;
// #[cfg(feature = "tls")]
pub const CERTS_WATCH_DELAY_SECS: u32 = 60; pub const CERTS_WATCH_DELAY_SECS: u32 = 60;
// #[cfg(feature = "http3")] // #[cfg(feature = "http3")]

View file

@ -6,7 +6,7 @@ use std::sync::Arc;
use tokio::{ use tokio::{
net::TcpListener, net::TcpListener,
sync::watch, sync::watch,
time::{sleep, Duration}, time::{sleep, timeout, Duration},
}; };
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
@ -56,28 +56,36 @@ where
let server_clone = server.clone(); let server_clone = server.clone();
let self_inner = self.clone(); let self_inner = self.clone();
let fut = async move { // spawns async handshake to avoid blocking thread by sequential handshake.
match acceptor.accept(raw_stream).await { let handshake_fut = async move {
Ok(stream) => { // timeout is introduced to avoid get stuck here.
// Retrieve SNI match timeout(Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), acceptor.accept(raw_stream)).await {
let (_, conn) = stream.get_ref(); Ok(x) => match x {
let server_name = conn.sni_hostname(); Ok(stream) => {
debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", server_name); // Retrieve SNI
let server_name = server_name.map_or_else(|| None, |v| Some(v.as_bytes().to_ascii_lowercase())); let (_, conn) = stream.get_ref();
if server_name.is_none(){ let server_name = conn.sni_hostname();
Err(anyhow!("No SNI is given")) debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", server_name);
} else { let server_name = server_name.map_or_else(|| None, |v| Some(v.as_bytes().to_ascii_lowercase()));
self_inner.client_serve(stream, server_clone, client_addr, server_name); // TODO: don't want to pass copied value... if server_name.is_none(){
Ok(()) Err(anyhow!("No SNI is given"))
} else {
// this immediately spawns another future to actually handle stream. so it is okay to introduce timeout for handshake.
self_inner.client_serve(stream, server_clone, client_addr, server_name); // TODO: don't want to pass copied value...
Ok(())
}
},
Err(e) => {
Err(anyhow!("Failed to handshake TLS: {}", e))
} }
}, },
Err(e) => { Err(e) => {
Err(anyhow!("Failed to accept TLS stream {}", e)) Err(anyhow!("Timeout to handshake TLS: {}", e))
} }
} }
}; };
self.globals.runtime_handle.spawn( async move { self.globals.runtime_handle.spawn( async move {
if let Err(e) = fut.await { if let Err(e) = handshake_fut.await {
error!("{}", e); error!("{}", e);
} }
}); });
@ -131,11 +139,12 @@ where
Ok(d) => d, Ok(d) => d,
Err(_) => continue Err(_) => continue
}; };
let new_server_name = if let Some(sn) = hsd_downcast.server_name { let new_server_name = match hsd_downcast.server_name {
sn.as_bytes().to_ascii_lowercase() Some(sn) => sn.as_bytes().to_ascii_lowercase(),
} else { None => {
warn!("HTTP/3 no SNI is given"); warn!("HTTP/3 no SNI is given");
continue; continue;
}
}; };
debug!( debug!(
"HTTP/3 connection incoming (SNI {:?})", "HTTP/3 connection incoming (SNI {:?})",
@ -161,7 +170,7 @@ where
endpoint.set_server_config(Some(QuicServerConfig::with_crypto(server_crypto.clone().unwrap()))); endpoint.set_server_config(Some(QuicServerConfig::with_crypto(server_crypto.clone().unwrap())));
} }
} }
// complete => break else => break
} }
} }
endpoint.wait_idle().await; endpoint.wait_idle().await;
@ -199,10 +208,10 @@ where
_= self.listener_service_h3(rx) => { _= self.listener_service_h3(rx) => {
error!("UDP proxy service for QUIC exited"); error!("UDP proxy service for QUIC exited");
}, },
// complete => { else => {
// error!("Something went wrong"); error!("Something went wrong");
// return Ok(()) return Ok(())
// } }
}; };
Ok(()) Ok(())
} else { } else {
@ -213,10 +222,10 @@ where
_ = self.listener_service(server, rx) => { _ = self.listener_service(server, rx) => {
error!("TCP proxy service for TLS exited"); error!("TCP proxy service for TLS exited");
}, },
// complete => { else => {
// error!("Something went wrong"); error!("Something went wrong");
// return Ok(()) return Ok(())
// } }
}; };
Ok(()) Ok(())
} }