From 5840808021d4fd44c08905cfa88f5e7334e85060 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Thu, 7 Jul 2022 20:28:30 +0900 Subject: [PATCH] refactor --- src/constants.rs | 3 -- src/proxy/proxy_main.rs | 23 +-------- src/proxy/proxy_tls.rs | 106 +++++++++++++++++++--------------------- 3 files changed, 51 insertions(+), 81 deletions(-) diff --git a/src/constants.rs b/src/constants.rs index 9b1141c..9fef23b 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -8,9 +8,6 @@ pub const MAX_CONCURRENT_STREAMS: u32 = 16; // #[cfg(feature = "tls")] pub const CERTS_WATCH_DELAY_SECS: u32 = 10; -pub const GET_LISTENER_RETRY_MAX_CNT: u64 = 128; -pub const GET_LISTENER_RETRY_WAITING_MSEC: u64 = 10; - #[cfg(feature = "h3")] pub const H3_ALT_SVC_MAX_AGE: u32 = 60; #[cfg(feature = "h3")] diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index e0bdefa..2686c42 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -1,6 +1,6 @@ // use super::proxy_handler::handle_request; use super::Backends; -use crate::{constants::*, error::*, globals::Globals, log::*}; +use crate::{error::*, globals::Globals, log::*}; use hyper::{ client::connect::Connect, server::conn::Http, service::service_fn, Body, Client, Request, }; @@ -77,28 +77,9 @@ where }); } - // Work around to forcibly get tcp listener for "address already in use" - pub(super) async fn try_bind_tcp_listener(&self) -> Result { - let mut cnt = 0; - while cnt < GET_LISTENER_RETRY_MAX_CNT { - if let Ok(listener) = TcpListener::bind(&self.listening_on).await { - return Ok(listener); - } - cnt += 1; - tokio::time::sleep(tokio::time::Duration::from_millis( - GET_LISTENER_RETRY_WAITING_MSEC, - )) - .await; - } - - error!("Failed to get tcp listener: {}", self.listening_on); - Err(anyhow!("Failed to get tcp listener: {}", self.listening_on)) - } - async fn start_without_tls(self, server: Http) -> Result<()> { let listener_service = async { - // let tcp_listener = TcpListener::bind(&self.listening_on).await?; - let tcp_listener = self.try_bind_tcp_listener().await?; + let tcp_listener = TcpListener::bind(&self.listening_on).await?; info!("Start TCP proxy serving with HTTP request for configured host names"); while let Ok((stream, _client_addr)) = tcp_listener.accept().await { self diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 7384a03..6951561 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -6,6 +6,7 @@ use futures::{future::FutureExt, select}; use hyper::{client::connect::Connect, server::conn::Http}; use rustls::ServerConfig; use std::{sync::Arc, time::Duration}; +use tokio::net::TcpListener; impl Proxy where @@ -27,8 +28,7 @@ where // TCP Listener Service, i.e., http/2 and http/1.1 pub async fn listener_service(&self, server: Http) -> Result<()> { - // let tcp_listener = TcpListener::bind(&self.listening_on).await?; - let tcp_listener = self.try_bind_tcp_listener().await?; + let tcp_listener = TcpListener::bind(&self.listening_on).await?; info!("Start TCP proxy serving with HTTPS request for configured host names"); loop { @@ -71,6 +71,32 @@ where Ok(()) as Result<()> } + #[cfg(feature = "h3")] + async fn get_new_server_config_h3( + &self, + peeked_conn: &mut quinn::Connecting, + ) -> Option { + let hsd = if let Ok(h) = peeked_conn.handshake_data().await { + h + } else { + return None; + }; + let hsd_downcast = if let Ok(d) = hsd.downcast::() { + d + } else { + return None; + }; + let server_name = hsd_downcast.server_name?; + info!( + "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", + server_name + ); + let new_server_crypto = self.fetch_server_crypto(&server_name)?; + Some(quinn::ServerConfig::with_crypto(Arc::new( + new_server_crypto, + ))) + } + #[cfg(feature = "h3")] pub async fn listener_service_h3(&self) -> Result<()> { // TODO: Work around to initially serve incoming connection @@ -85,60 +111,48 @@ where .map(|(name, _)| name.to_string()) .collect(); ensure!(!tls_app_names.is_empty(), "No TLS supported app"); - let initial_app_name = tls_app_names.get(0).unwrap().as_str(); + let initial_app_name = tls_app_names.get(0).ok_or_else(|| anyhow!(""))?.as_str(); debug!( "HTTP/3 SNI multiplexer initial app_name: {}", initial_app_name ); - let backend_serve = self.backends.apps.get(initial_app_name).unwrap(); + let backend_serve = self + .backends + .apps + .get(initial_app_name) + .ok_or_else(|| anyhow!(""))?; while backend_serve.get_tls_server_config().is_none() { tokio::time::sleep(tokio::time::Duration::from_millis(10)).await; } - let server_crypto = backend_serve.get_tls_server_config().unwrap(); + let server_crypto = backend_serve + .get_tls_server_config() + .ok_or_else(|| anyhow!(""))?; let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); - let (endpoint, incoming) = self.try_bind_quic_listener(server_config_h3).await?; - // quinn::Endpoint::server(server_config_h3, self.listening_on).unwrap(); + let (endpoint, incoming) = quinn::Endpoint::server(server_config_h3, self.listening_on)?; info!("Start UDP proxy serving with HTTP/3 request for configured host names"); let mut p = incoming.peekable(); loop { // TODO: Not sure if this properly works to handle multiple "server_name"s to host multiple hosts. // peek() should work for that. - let success = if let Some(peeked_conn) = std::pin::Pin::new(&mut p).peek_mut().await { - let hsd = peeked_conn.handshake_data().await; - let hsd_downcast = hsd? - .downcast::() - .unwrap(); - if let Some(svn) = hsd_downcast.server_name { - if let Some(new_server_crypto) = self.fetch_server_crypto(&svn) { - // Set ServerConfig::set_server_config for given SNI - let mut new_server_config_h3 = - quinn::ServerConfig::with_crypto(Arc::new(new_server_crypto)); - if svn == "localhost" { - new_server_config_h3.concurrent_connections(512); - } - info!( - "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", - svn - ); - endpoint.set_server_config(Some(new_server_config_h3)); - true - } else { - false - } + let peeked_conn = std::pin::Pin::new(&mut p) + .peek_mut() + .await + .ok_or_else(|| anyhow!("Failed to peek"))?; + let is_acceptable = + if let Some(new_server_config) = self.get_new_server_config_h3(peeked_conn).await { + // Set ServerConfig::set_server_config for given SNI + endpoint.set_server_config(Some(new_server_config)); + true } else { - debug!("HTTP/3 no SNI is given"); false - } - } else { - false - }; + }; // Then acquire actual connection let peekable_incoming = std::pin::Pin::new(&mut p); if let Some(conn) = peekable_incoming.get_mut().next().await { - if success { + if is_acceptable { self.clone().client_serve_h3(conn).await; } } else { @@ -193,28 +207,6 @@ where } } - // Work around to forcibly get quic listener for "address already in use" - #[cfg(feature = "h3")] - async fn try_bind_quic_listener( - &self, - server_config: quinn::ServerConfig, - ) -> Result<(quinn::Endpoint, quinn::Incoming)> { - let mut cnt = 0; - while cnt < GET_LISTENER_RETRY_MAX_CNT { - if let Ok(listener) = quinn::Endpoint::server(server_config.clone(), self.listening_on) { - return Ok(listener); - } - cnt += 1; - tokio::time::sleep(tokio::time::Duration::from_millis( - GET_LISTENER_RETRY_WAITING_MSEC, - )) - .await; - } - - error!("Failed to get quic listener: {}", self.listening_on); - Err(anyhow!("Failed to get tcp listener: {}", self.listening_on)) - } - fn fetch_server_crypto(&self, server_name: &str) -> Option { let backend_serve = if let Some(backend_serve) = self.backends.apps.get(server_name) { backend_serve