diff --git a/Cargo.toml b/Cargo.toml index b18377b..c42df7f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,7 +48,6 @@ hyper-rustls = { version = "0.23.0", default-features = false, features = [ "http1", "http2", ] } -parking_lot = "0.12.1" quinn = { version = "0.8.3", optional = true } h3 = { git = "https://github.com/hyperium/h3.git" } h3-quinn = { git = "https://github.com/hyperium/h3.git" } diff --git a/src/config/parse.rs b/src/config/parse.rs index 650a8e7..347164b 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -7,7 +7,6 @@ use crate::{ proxy::{Backend, Backends, ReverseProxy, Upstream, UpstreamOption}, }; use clap::Arg; -use parking_lot::Mutex; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use std::net::SocketAddr; @@ -132,7 +131,6 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> tls_cert_path, tls_cert_key_path, https_redirection, - server_config: Mutex::new(None), }, ); info!("Registering application: {} ({})", app_name, server_name); diff --git a/src/proxy/backend.rs b/src/proxy/backend.rs index 18794e0..2e9ae53 100644 --- a/src/proxy/backend.rs +++ b/src/proxy/backend.rs @@ -1,6 +1,5 @@ use super::UpstreamOption; use crate::log::*; -use parking_lot::Mutex; use rand::Rng; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use std::{ @@ -28,7 +27,6 @@ pub struct Backend { pub tls_cert_path: Option, pub tls_cert_key_path: Option, pub https_redirection: Option, - pub server_config: Mutex>, } #[derive(Debug, Clone)] @@ -89,15 +87,7 @@ impl Upstream { } impl Backend { - pub fn get_tls_server_config(&self) -> Option { - let lock = self.server_config.lock(); - let opt_clone = lock.clone(); - if let Some(sc) = opt_clone { - return Some(sc); - } - None - } - pub async fn update_server_config(&self) -> io::Result<()> { + pub async fn update_server_config(&self) -> io::Result { debug!("Update TLS server config"); let (certs_path, certs_keys_path) = if let (Some(c), Some(k)) = (self.tls_cert_path.as_ref(), self.tls_cert_key_path.as_ref()) { @@ -204,10 +194,7 @@ impl Backend { server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; } - let mut config_store = self.server_config.lock(); - *config_store = Some(server_config); - // server_config; - Ok(()) + Ok(server_config) } } diff --git a/src/proxy/proxy_handler.rs b/src/proxy/proxy_handler.rs index 7c2e6af..d77c6a5 100644 --- a/src/proxy/proxy_handler.rs +++ b/src/proxy/proxy_handler.rs @@ -218,7 +218,7 @@ fn generate_request_forwarded( remove_hop_header(headers); // X-Forwarded-For add_forwarding_header(headers, client_addr)?; - println!("{:?}", headers); + // println!("{:?}", headers); // Add te: trailer if te_trailer if te_trailers { diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index c9d8914..148c119 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -4,39 +4,58 @@ use crate::{constants::*, error::*, log::*}; use futures::StreamExt; use futures::{future::FutureExt, select}; use hyper::{client::connect::Connect, server::conn::Http}; +use rustc_hash::FxHashMap as HashMap; use rustls::ServerConfig; use std::sync::Arc; -use tokio::{net::TcpListener, time::Duration}; +use tokio::{net::TcpListener, sync::watch, time::Duration}; + +type ServerCryptoMap = HashMap; impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - pub async fn cert_service(&self) { + async fn cert_service(&self, server_crypto_tx: watch::Sender>) { info!("Start cert watch service"); loop { + let mut hm_server_config = HashMap::::default(); for (server_name, backend) in self.backends.apps.iter() { if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { - if let Err(_e) = backend.update_server_config().await { - warn!("Failed to update certs for {}: {}", server_name, _e); + match backend.update_server_config().await { + Err(_e) => { + error!("Failed to update certs for {}: {}", server_name, _e); + break; + } + Ok(server_config) => { + hm_server_config.insert(server_name.to_owned(), server_config); + } } } } + if let Err(_e) = server_crypto_tx.send(Some(hm_server_config)) { + error!("Failed to populate server crypto"); + break; + } tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; } } // TCP Listener Service, i.e., http/2 and http/1.1 - pub async fn listener_service(&self, server: Http) -> Result<()> { + async fn listener_service( + &self, + server: Http, + mut server_crypto_rx: watch::Receiver>, + ) -> Result<()> { let tcp_listener = TcpListener::bind(&self.listening_on).await?; info!("Start TCP proxy serving with HTTPS request for configured host names"); + let mut server_crypto_map: Option = None; loop { select! { tcp_cnx = tcp_listener.accept().fuse() => { // First check SNI let rustls_acceptor = rustls::server::Acceptor::new(); - if tcp_cnx.is_err() || rustls_acceptor.is_err() { + if server_crypto_map.is_none() || tcp_cnx.is_err() || rustls_acceptor.is_err() { continue; } let (raw_stream, _client_addr) = tcp_cnx.unwrap(); @@ -55,8 +74,8 @@ where info!("No SNI in ClientHello"); continue; }; - let server_crypto = if let Some(p) = self.fetch_server_crypto(svn) { - p + let server_crypto = if let Some(p) = server_crypto_map.as_ref().unwrap().get(svn) { + p.to_owned() } else { continue; }; @@ -65,6 +84,12 @@ where self.clone().client_serve(stream, server.clone(), _client_addr).await } } + _ = server_crypto_rx.changed().fuse() => { + if server_crypto_rx.borrow().is_none() { + break; + } + server_crypto_map = server_crypto_rx.borrow().clone(); + } complete => break } } @@ -72,10 +97,11 @@ where } #[cfg(feature = "h3")] - async fn parse_sni_and_get_config_h3( + async fn parse_sni_and_get_crypto_h3( &self, peeked_conn: &mut quinn::Connecting, - ) -> Option { + server_crypto_map: &ServerCryptoMap, + ) -> Option { let hsd = if let Ok(h) = peeked_conn.handshake_data().await { h } else { @@ -91,14 +117,14 @@ where "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", server_name ); - let new_server_crypto = self.fetch_server_crypto(&server_name)?; - Some(quinn::ServerConfig::with_crypto(Arc::new( - new_server_crypto, - ))) + server_crypto_map.get(&server_name).cloned() } #[cfg(feature = "h3")] - pub async fn listener_service_h3(&self) -> Result<()> { + async fn listener_service_h3( + &self, + mut server_crypto_rx: watch::Receiver>, + ) -> Result<()> { // TODO: Work around to initially serve incoming connection // かなり適当。エラーが出たり出なかったり。原因がわからない… let tls_app_names: Vec = self @@ -121,42 +147,49 @@ where .apps .get(initial_app_name) .ok_or_else(|| anyhow!(""))?; - while backend_serve.get_tls_server_config().is_none() { - tokio::time::sleep(Duration::from_millis(10)).await; - } - let server_crypto = backend_serve - .get_tls_server_config() - .ok_or_else(|| anyhow!(""))?; - let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); + let initial_server_crypto = backend_serve.update_server_config().await?; + + let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(initial_server_crypto)); let (endpoint, incoming) = quinn::Endpoint::server(server_config_h3, self.listening_on)?; info!("Start UDP proxy serving with HTTP/3 request for configured host names"); + let mut server_crypto_map: Option = None; let mut p = incoming.peekable(); loop { - // TODO: Not sure if this properly works to handle multiple "server_name"s to host multiple hosts. - // peek() should work for that. - let peeked_conn = std::pin::Pin::new(&mut p) - .peek_mut() - .await - .ok_or_else(|| anyhow!("Failed to peek"))?; - let is_acceptable = - if let Some(new_server_config) = self.parse_sni_and_get_config_h3(peeked_conn).await { - // Set ServerConfig::set_server_config for given SNI - endpoint.set_server_config(Some(new_server_config)); - true - } else { - false - }; - - // Then acquire actual connection - let peekable_incoming = std::pin::Pin::new(&mut p); - if let Some(conn) = peekable_incoming.get_mut().next().await { - if is_acceptable { - self.clone().client_serve_h3(conn).await; + select! { + // TODO: Not sure if this properly works to handle multiple "server_name"s to host multiple hosts. + // peek() should work for that. + peeked_conn = std::pin::Pin::new(&mut p).peek_mut().fuse() => { + if server_crypto_map.is_none() || peeked_conn.is_none() { + continue; + } + let peeked_conn = peeked_conn.unwrap(); + let is_acceptable = + if let Some(new_server_crypto) = self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await { + // Set ServerConfig::set_server_config for given SNI + endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(Arc::new(new_server_crypto)))); + true + } else { + false + }; + // Then acquire actual connection + let peekable_incoming = std::pin::Pin::new(&mut p); + if let Some(conn) = peekable_incoming.get_mut().next().await { + if is_acceptable { + self.clone().client_serve_h3(conn).await; + } + } else { + continue; + } } - } else { - break; + _ = server_crypto_rx.changed().fuse() => { + if server_crypto_rx.borrow().is_none() { + break; + } + server_crypto_map = server_crypto_rx.borrow().clone(); + } + complete => break } } endpoint.wait_idle().await; @@ -164,16 +197,16 @@ where } pub async fn start_with_tls(self, server: Http) -> Result<()> { + let (tx, rx) = watch::channel::>(None); #[cfg(not(feature = "h3"))] { select! { - _= cert_service => { + _= self.cert_service(tx) => { error!("Cert service for TLS exited"); }, - _ = listener_service => { + _ = self.listener_service(server, rx) => { error!("TCP proxy service for TLS exited"); }, - }; Ok(()) } @@ -181,23 +214,23 @@ where { if self.globals.http3 { tokio::select! { - _= self.cert_service() => { + _= self.cert_service(tx) => { error!("Cert service for TLS exited"); }, - _ = self.listener_service(server) => { + _ = self.listener_service(server, rx.clone()) => { error!("TCP proxy service for TLS exited"); }, - _= self.listener_service_h3() => { + _= self.listener_service_h3(rx) => { error!("UDP proxy service for QUIC exited"); }, }; Ok(()) } else { tokio::select! { - _= self.cert_service() => { + _= self.cert_service(tx) => { error!("Cert service for TLS exited"); }, - _ = self.listener_service(server) => { + _ = self.listener_service(server, rx) => { error!("TCP proxy service for TLS exited"); }, @@ -206,28 +239,4 @@ where } } } - - fn fetch_server_crypto(&self, server_name: &str) -> Option { - let backend_serve = if let Some(backend_serve) = self.backends.apps.get(server_name) { - backend_serve - } else { - warn!( - "No configuration for the server name {} given in client_hello", - server_name - ); - return None; - }; - - if backend_serve.tls_cert_path.is_none() { - // at least cert does exit - warn!("SNI indicates a site that doesn't support TLS."); - return None; - } - if let Some(p) = backend_serve.get_tls_server_config() { - Some(p) - } else { - error!("Failed to load server config"); - None - } - } }