diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 77d8ad8..ac743d9 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -1,6 +1,5 @@ use super::proxy_main::{LocalExecutor, Proxy}; use crate::{constants::*, error::*, log::*}; -use futures::{future::FutureExt, select}; use hyper::{client::connect::Connect, server::conn::Http}; use rustls::ServerConfig; use std::sync::Arc; @@ -47,39 +46,50 @@ where // let mut server_crypto: Option> = None; let mut tls_acceptor: Option = None; loop { - select! { - tcp_cnx = tcp_listener.accept().fuse() => { + tokio::select! { + tcp_cnx = tcp_listener.accept() => { if tls_acceptor.is_none() || tcp_cnx.is_err() { continue; } let (raw_stream, client_addr) = tcp_cnx.unwrap(); + let acceptor = tls_acceptor.clone().unwrap(); + let server_clone = server.clone(); + let self_inner = self.clone(); - match tls_acceptor.as_ref().unwrap().accept(raw_stream).await { - Ok(stream) => { - // Retrieve SNI - let (_, conn) = stream.get_ref(); - let server_name = conn.sni_hostname(); - debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", server_name); - let server_name = server_name.map_or_else(|| None, |v| Some(v.as_bytes().to_ascii_lowercase())); - if server_name.is_none(){ - continue; + let fut = async move { + match acceptor.accept(raw_stream).await { + Ok(stream) => { + // Retrieve SNI + let (_, conn) = stream.get_ref(); + let server_name = conn.sni_hostname(); + debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", server_name); + let server_name = server_name.map_or_else(|| None, |v| Some(v.as_bytes().to_ascii_lowercase())); + if server_name.is_none(){ + Err(anyhow!("No SNI is given")) + } else { + self_inner.client_serve(stream, server_clone, client_addr, server_name); // TODO: don't want to pass copied value... + Ok(()) + } + }, + Err(e) => { + Err(anyhow!("Failed to accept TLS stream {}", e)) } - self.clone().client_serve(stream, server.clone(), client_addr, server_name); // TODO: don't want to pass copied value... - }, - Err(e) => { - error!("Failed to accept TLS stream {}", e); - continue; } - } + }; + self.globals.runtime_handle.spawn( async move { + if let Err(e) = fut.await { + error!("{}", e); + } + }); } - _ = server_crypto_rx.changed().fuse() => { + _ = server_crypto_rx.changed() => { if server_crypto_rx.borrow().is_none() { break; } let server_crypto = server_crypto_rx.borrow().clone().unwrap(); tls_acceptor = Some(TlsAcceptor::from(server_crypto)); } - complete => break + else => break } } Ok(()) as Result<()> @@ -106,8 +116,8 @@ where let mut server_crypto: Option> = None; loop { - select! { - new_conn = incoming.next().fuse() => { + tokio::select! { + new_conn = incoming.next() => { if server_crypto.is_none() || new_conn.is_none() { continue; } @@ -141,7 +151,7 @@ where } }); } - _ = server_crypto_rx.changed().fuse() => { + _ = server_crypto_rx.changed() => { if server_crypto_rx.borrow().is_none() { break; } @@ -151,7 +161,7 @@ where endpoint.set_server_config(Some(QuicServerConfig::with_crypto(server_crypto.clone().unwrap()))); } } - complete => break + // complete => break } } endpoint.wait_idle().await; @@ -179,34 +189,34 @@ where #[cfg(feature = "http3")] { if self.globals.http3 { - select! { - _= self.cert_service(tx).fuse() => { + tokio::select! { + _= self.cert_service(tx) => { error!("Cert service for TLS exited"); }, - _ = self.listener_service(server, rx.clone()).fuse() => { + _ = self.listener_service(server, rx.clone()) => { error!("TCP proxy service for TLS exited"); }, - _= self.listener_service_h3(rx).fuse() => { + _= self.listener_service_h3(rx) => { error!("UDP proxy service for QUIC exited"); }, - complete => { - error!("Something went wrong"); - return Ok(()) - } + // complete => { + // error!("Something went wrong"); + // return Ok(()) + // } }; Ok(()) } else { - select! { - _= self.cert_service(tx).fuse() => { + tokio::select! { + _= self.cert_service(tx) => { error!("Cert service for TLS exited"); }, - _ = self.listener_service(server, rx).fuse() => { + _ = self.listener_service(server, rx) => { error!("TCP proxy service for TLS exited"); }, - complete => { - error!("Something went wrong"); - return Ok(()) - } + // complete => { + // error!("Something went wrong"); + // return Ok(()) + // } }; Ok(()) }