diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 61ee792..8d7e55c 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -2,7 +2,7 @@ use super::proxy_main::{LocalExecutor, Proxy}; use crate::{constants::*, error::*, log::*}; #[cfg(feature = "h3")] use futures::StreamExt; -use futures::{future::FutureExt, join, select}; +use futures::{future::FutureExt, select}; use hyper::{client::connect::Connect, server::conn::Http}; use rustls::ServerConfig; use std::{sync::Arc, time::Duration}; @@ -11,174 +11,190 @@ impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - pub async fn start_with_tls(self, server: Http) -> Result<()> { - let cert_service = async { - info!("Start cert watch service for {}", self.listening_on); - loop { - for (server_name, backend) in self.backends.apps.iter() { - if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { - if let Err(_e) = backend.update_server_config().await { - warn!("Failed to update certs for {}: {}", server_name, _e); - } + pub async fn cert_service(&self) { + info!("Start cert watch service for {}", self.listening_on); + loop { + for (server_name, backend) in self.backends.apps.iter() { + if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { + if let Err(_e) = backend.update_server_config().await { + warn!("Failed to update certs for {}: {}", server_name, _e); } } - tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; } - }; + tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; + } + } - // TCP Listener Service, i.e., http/2 and http/1.1 - let listener_service = async { - // let tcp_listener = TcpListener::bind(&self.listening_on).await?; - let tcp_listener = self.try_bind_tcp_listener().await?; - info!( - "Start TCP proxy serving with HTTPS request for configured host names: {:?}", - tcp_listener.local_addr()? - ); + // TCP Listener Service, i.e., http/2 and http/1.1 + pub async fn listener_service(&self, server: Http) -> Result<()> { + // let tcp_listener = TcpListener::bind(&self.listening_on).await?; + let tcp_listener = self.try_bind_tcp_listener().await?; + info!( + "Start TCP proxy serving with HTTPS request for configured host names: {:?}", + tcp_listener.local_addr()? + ); - loop { - select! { - tcp_cnx = tcp_listener.accept().fuse() => { - if tcp_cnx.is_err() { - continue; - } - let (raw_stream, _client_addr) = tcp_cnx.unwrap(); - - // First check SNI - let rustls_acceptor = rustls::server::Acceptor::new().unwrap(); - let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls_acceptor, raw_stream).await; - if acceptor.is_err() { - continue; - } - let start = acceptor.unwrap(); - let client_hello = start.client_hello(); - debug!("SNI in ClientHello: {:?}", client_hello.server_name()); - // Find server config for given SNI - let svn = if let Some(svn) = client_hello.server_name() { - svn - } else { - info!("No SNI in ClientHello"); - continue; - }; - let server_crypto = if let Some(p) = self.fetch_server_crypto(svn) { - p - } else { - continue; - }; - // Finally serve the TLS connection - if let Ok(stream) = start.into_stream(Arc::new(server_crypto)).await { - self.clone().client_serve(stream, server.clone(), _client_addr).await - } + loop { + select! { + tcp_cnx = tcp_listener.accept().fuse() => { + if tcp_cnx.is_err() { + continue; } - complete => break - } - } - Ok(()) as Result<()> - }; + let (raw_stream, _client_addr) = tcp_cnx.unwrap(); - /////////////////////// - #[cfg(feature = "h3")] - let listener_service_h3 = async { - // TODO: Work around to initially serve incoming connection - // かなり適当。エラーが出たり出なかったり。原因がわからない… - let tls_app_names: Vec = self - .backends - .apps - .iter() - .filter(|&(_, backend)| { - backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() - }) - .map(|(name, _)| name.to_string()) - .collect(); - ensure!(!tls_app_names.is_empty(), "No TLS supported app"); - let initial_app_name = tls_app_names.get(0).unwrap().as_str(); - debug!( - "HTTP/3 SNI multiplexer initial app_name: {}", - initial_app_name - ); - let backend_serve = self.backends.apps.get(initial_app_name).unwrap(); - while backend_serve.get_tls_server_config().is_none() { - tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; - } - let server_crypto = backend_serve.get_tls_server_config().unwrap(); - let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); - - let (endpoint, incoming) = self.try_bind_quic_listener(server_config_h3).await?; - // quinn::Endpoint::server(server_config_h3, self.listening_on).unwrap(); - info!( - "Start UDP proxy serving with HTTP/3 request for configured host names: {:?}", - endpoint.local_addr()? - ); - - let mut p = incoming.peekable(); - loop { - // TODO: Not sure if this properly works to handle multiple "server_name"s to host multiple hosts. - // peek() should work for that. - let success = if let Some(peeked_conn) = std::pin::Pin::new(&mut p).peek_mut().await { - let hsd = peeked_conn.handshake_data().await; - let hsd_downcast = hsd? - .downcast::() - .unwrap(); - if let Some(svn) = hsd_downcast.server_name { - if let Some(new_server_crypto) = self.fetch_server_crypto(&svn) { - // Set ServerConfig::set_server_config for given SNI - let mut new_server_config_h3 = - quinn::ServerConfig::with_crypto(Arc::new(new_server_crypto)); - if svn == "localhost" { - new_server_config_h3.concurrent_connections(512); - } - info!( - "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", - svn - ); - endpoint.set_server_config(Some(new_server_config_h3)); - true - } else { - false - } + // First check SNI + let rustls_acceptor = rustls::server::Acceptor::new().unwrap(); + let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls_acceptor, raw_stream).await; + if acceptor.is_err() { + continue; + } + let start = acceptor.unwrap(); + let client_hello = start.client_hello(); + debug!("SNI in ClientHello: {:?}", client_hello.server_name()); + // Find server config for given SNI + let svn = if let Some(svn) = client_hello.server_name() { + svn + } else { + info!("No SNI in ClientHello"); + continue; + }; + let server_crypto = if let Some(p) = self.fetch_server_crypto(svn) { + p + } else { + continue; + }; + // Finally serve the TLS connection + if let Ok(stream) = start.into_stream(Arc::new(server_crypto)).await { + self.clone().client_serve(stream, server.clone(), _client_addr).await + } + } + complete => break + } + } + Ok(()) as Result<()> + } + + #[cfg(feature = "h3")] + pub async fn listener_service_h3(&self) -> Result<()> { + // TODO: Work around to initially serve incoming connection + // かなり適当。エラーが出たり出なかったり。原因がわからない… + let tls_app_names: Vec = self + .backends + .apps + .iter() + .filter(|&(_, backend)| { + backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() + }) + .map(|(name, _)| name.to_string()) + .collect(); + ensure!(!tls_app_names.is_empty(), "No TLS supported app"); + let initial_app_name = tls_app_names.get(0).unwrap().as_str(); + debug!( + "HTTP/3 SNI multiplexer initial app_name: {}", + initial_app_name + ); + let backend_serve = self.backends.apps.get(initial_app_name).unwrap(); + while backend_serve.get_tls_server_config().is_none() { + tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; + } + let server_crypto = backend_serve.get_tls_server_config().unwrap(); + let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); + + let (endpoint, incoming) = self.try_bind_quic_listener(server_config_h3).await?; + // quinn::Endpoint::server(server_config_h3, self.listening_on).unwrap(); + info!( + "Start UDP proxy serving with HTTP/3 request for configured host names: {:?}", + endpoint.local_addr()? + ); + + let mut p = incoming.peekable(); + loop { + // TODO: Not sure if this properly works to handle multiple "server_name"s to host multiple hosts. + // peek() should work for that. + let success = if let Some(peeked_conn) = std::pin::Pin::new(&mut p).peek_mut().await { + let hsd = peeked_conn.handshake_data().await; + let hsd_downcast = hsd? + .downcast::() + .unwrap(); + if let Some(svn) = hsd_downcast.server_name { + if let Some(new_server_crypto) = self.fetch_server_crypto(&svn) { + // Set ServerConfig::set_server_config for given SNI + let mut new_server_config_h3 = + quinn::ServerConfig::with_crypto(Arc::new(new_server_crypto)); + if svn == "localhost" { + new_server_config_h3.concurrent_connections(512); + } + info!( + "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", + svn + ); + endpoint.set_server_config(Some(new_server_config_h3)); + true } else { - debug!("HTTP/3 no SNI is given"); false } } else { + debug!("HTTP/3 no SNI is given"); false - }; - - // Then acquire actual connection - let peekable_incoming = std::pin::Pin::new(&mut p); - if let Some(conn) = peekable_incoming.get_mut().next().await { - if success { - self.clone().client_serve_h3(conn).await; - } - } else { - break; } - } - endpoint.wait_idle().await; - Ok(()) as Result<()> - }; + } else { + false + }; + // Then acquire actual connection + let peekable_incoming = std::pin::Pin::new(&mut p); + if let Some(conn) = peekable_incoming.get_mut().next().await { + if success { + self.clone().client_serve_h3(conn).await; + } + } else { + break; + } + } + endpoint.wait_idle().await; + Ok(()) as Result<()> + } + + pub async fn start_with_tls(self, server: Http) -> Result<()> { #[cfg(not(feature = "h3"))] { - join!(listener_service, cert_service).0 + select! { + _= cert_service => { + error!("Cert service for TLS exited"); + }, + _ = listener_service => { + error!("TCP proxy service for TLS exited"); + }, + + }; + Ok(()) } #[cfg(feature = "h3")] { if self.globals.http3 { tokio::select! { - _= cert_service => { + _= self.cert_service() => { error!("Cert service for TLS exited"); }, - _ = listener_service => { + _ = self.listener_service(server) => { error!("TCP proxy service for TLS exited"); }, - _= listener_service_h3 => { - error!("UDP proxy service for TLS exited"); + _= self.listener_service_h3() => { + error!("UDP proxy service for QUIC exited"); }, }; - // join!(listener_service, cert_service, listener_service_h3).0 Ok(()) } else { - join!(listener_service, cert_service).0 + tokio::select! { + _= self.cert_service() => { + error!("Cert service for TLS exited"); + }, + _ = self.listener_service(server) => { + error!("TCP proxy service for TLS exited"); + }, + + }; + Ok(()) } } }