From d7d782499aa391d436834736def3db12464e3723 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 19 Jul 2022 17:39:06 +0900 Subject: [PATCH] totally refine sni inspection mechanism using rustls --- src/backend.rs | 118 +++------------------------ src/proxy/proxy_h3.rs | 4 +- src/proxy/proxy_main.rs | 5 +- src/proxy/proxy_tls.rs | 172 ++++++++++++---------------------------- 4 files changed, 65 insertions(+), 234 deletions(-) diff --git a/src/backend.rs b/src/backend.rs index a5ed13f..d283cd9 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,8 +1,6 @@ use crate::{backend_opt::UpstreamOption, log::*}; -use h3::server; use rand::Rng; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; -use rustls::server::ResolvesServerCert; use std::{ borrow::Cow, fs::File, @@ -15,7 +13,7 @@ use std::{ }; use tokio_rustls::rustls::{ server::ResolvesServerCertUsingSni, - sign::{any_supported_type, CertifiedKey, SigningKey}, + sign::{any_supported_type, CertifiedKey}, Certificate, PrivateKey, ServerConfig, }; @@ -131,112 +129,6 @@ impl Upstream { } impl Backend { - pub async fn update_server_config(&self) -> io::Result { - debug!("Update TLS server config"); - let (certs_path, certs_keys_path) = - if let (Some(c), Some(k)) = (self.tls_cert_path.as_ref(), self.tls_cert_key_path.as_ref()) { - (c, k) - } else { - return Err(io::Error::new( - io::ErrorKind::Other, - "Invalid certs and keys paths", - )); - }; - let certs: Vec<_> = { - let certs_path_str = certs_path.display().to_string(); - let mut reader = BufReader::new(File::open(certs_path).map_err(|e| { - io::Error::new( - e.kind(), - format!( - "Unable to load the certificates [{}]: {}", - certs_path_str, e - ), - ) - })?); - rustls_pemfile::certs(&mut reader).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unable to parse the certificates", - ) - })? - } - .drain(..) - .map(Certificate) - .collect(); - let certs_keys: Vec<_> = { - let certs_keys_path_str = certs_keys_path.display().to_string(); - let encoded_keys = { - let mut encoded_keys = vec![]; - File::open(certs_keys_path) - .map_err(|e| { - io::Error::new( - e.kind(), - format!( - "Unable to load the certificate keys [{}]: {}", - certs_keys_path_str, e - ), - ) - })? - .read_to_end(&mut encoded_keys)?; - encoded_keys - }; - let mut reader = Cursor::new(encoded_keys); - let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unable to parse the certificates private keys (PKCS8)", - ) - })?; - reader.set_position(0); - let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader)?; - let mut keys = pkcs8_keys; - keys.append(&mut rsa_keys); - if keys.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "No private keys found - Make sure that they are in PKCS#8/PEM format", - )); - } - keys.drain(..).map(PrivateKey).collect() - }; - - let mut server_config = certs_keys - .into_iter() - .find_map(|certs_key| { - let server_config_builder = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth(); - if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) { - Some(found_config) - } else { - None - } - }) - .ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unable to find a valid certificate and key", - ) - })?; - - #[cfg(feature = "h3")] - { - server_config.alpn_protocols = vec![ - b"h3".to_vec(), - b"hq-29".to_vec(), // quinn draft example TODO: remove later - b"h2".to_vec(), - b"http/1.1".to_vec(), - ]; - } - #[cfg(not(feature = "h3"))] - { - server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - } - - // server_config; - Ok(server_config) - } - pub fn read_certs_and_key(&self) -> io::Result { debug!("Read TLS server certificates and private key"); let (certs_path, certs_keys_path) = @@ -330,6 +222,7 @@ impl Backends { ) -> Result { let mut resolver = ResolvesServerCertUsingSni::new(); + let mut cnt = 0; for (_, backend) in self.apps.iter() { if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { match backend.read_certs_and_key() { @@ -340,6 +233,12 @@ impl Backends { backend.server_name.as_str(), e ) + } else { + debug!( + "Add certificate for server_name: {}", + backend.server_name.as_str() + ); + cnt += 1; } } Err(e) => { @@ -352,6 +251,7 @@ impl Backends { } } } + debug!("Load certificate chain for {} server_name's", cnt); let mut server_config = ServerConfig::builder() .with_safe_defaults() diff --git a/src/proxy/proxy_h3.rs b/src/proxy/proxy_h3.rs index 955e5e8..087a69c 100644 --- a/src/proxy/proxy_h3.rs +++ b/src/proxy/proxy_h3.rs @@ -10,7 +10,7 @@ impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - pub async fn client_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) { + pub(super) fn client_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) { let clients_count = self.globals.clients_count.clone(); if clients_count.increment() > self.globals.max_clients { clients_count.decrement(); @@ -29,7 +29,7 @@ where }); } - pub async fn handle_connection_h3( + async fn handle_connection_h3( self, conn: quinn::Connecting, tls_server_name: ServerNameLC, diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index 70054a3..da99021 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -45,7 +45,7 @@ impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - pub async fn client_serve( + pub(super) fn client_serve( self, stream: I, server: Http, @@ -94,8 +94,7 @@ where while let Ok((stream, _client_addr)) = tcp_listener.accept().await { self .clone() - .client_serve(stream, server.clone(), _client_addr, None) - .await; + .client_serve(stream, server.clone(), _client_addr, None); } Ok(()) as Result<()> }; diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 24d2f32..750ea2d 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -1,45 +1,32 @@ use super::proxy_main::{LocalExecutor, Proxy}; -use crate::{backend::ServerNameLC, constants::*, error::*, log::*}; +use crate::{constants::*, error::*, log::*}; #[cfg(feature = "h3")] use futures::StreamExt; use futures::{future::FutureExt, select}; use hyper::{client::connect::Connect, server::conn::Http}; -use rustc_hash::FxHashMap as HashMap; use rustls::ServerConfig; -#[cfg(feature = "h3")] -use std::pin::Pin; use std::sync::Arc; use tokio::{net::TcpListener, sync::watch, time::Duration}; -type ServerCryptoMap = HashMap>; - impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - async fn cert_service(&self, server_crypto_tx: watch::Sender>) { + async fn cert_service(&self, server_crypto_tx: watch::Sender>>) { info!("Start cert watch service"); loop { - let mut hm_server_config = HashMap::>::default(); - for (server_name_bytes, backend) in self.globals.backends.apps.iter() { - if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { - match backend.update_server_config().await { - Err(_e) => { - error!( - "Failed to update certs for {}: {}", - &backend.server_name, _e - ); - break; - } - Ok(server_config) => { - hm_server_config.insert(server_name_bytes.to_vec(), Arc::new(server_config)); - } - } + if let Ok(server_crypto) = self + .globals + .backends + .generate_server_crypto_with_cert_resolver() + .await + { + if let Err(_e) = server_crypto_tx.send(Some(Arc::new(server_crypto))) { + error!("Failed to populate server crypto"); + break; } - } - if let Err(_e) = server_crypto_tx.send(Some(hm_server_config)) { - error!("Failed to populate server crypto"); - break; + } else { + error!("Failed to update certs"); } tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; } @@ -49,18 +36,18 @@ where async fn listener_service( &self, server: Http, - mut server_crypto_rx: watch::Receiver>, + mut server_crypto_rx: watch::Receiver>>, ) -> Result<()> { let tcp_listener = TcpListener::bind(&self.listening_on).await?; info!("Start TCP proxy serving with HTTPS request for configured host names"); - let mut server_crypto_map: Option = None; + let mut server_crypto: Option> = None; loop { select! { tcp_cnx = tcp_listener.accept().fuse() => { // First check SNI let rustls_acceptor = rustls::server::Acceptor::new(); - if server_crypto_map.is_none() || tcp_cnx.is_err() || rustls_acceptor.is_err() { + if server_crypto.is_none() || tcp_cnx.is_err() || rustls_acceptor.is_err() { continue; } let (raw_stream, _client_addr) = tcp_cnx.unwrap(); @@ -78,21 +65,16 @@ where } let server_name = client_hello.server_name().unwrap().to_ascii_lowercase(); debug!("SNI in ClientHello: {:?}", server_name); - let server_crypto = server_crypto_map.as_ref().unwrap().get(server_name.as_bytes()); - if server_crypto.is_none() { - debug!("No TLS serving app for {}", server_name); - continue; - }; // Finally serve the TLS connection - if let Ok(stream) = start.into_stream(server_crypto.unwrap().clone()).await { - self.clone().client_serve(stream, server.clone(), _client_addr, Some(server_name.as_bytes())).await + if let Ok(stream) = start.into_stream(server_crypto.clone().unwrap()).await { + self.clone().client_serve(stream, server.clone(), _client_addr, Some(server_name.as_bytes())) } } _ = server_crypto_rx.changed().fuse() => { if server_crypto_rx.borrow().is_none() { break; } - server_crypto_map = server_crypto_rx.borrow().clone(); + server_crypto = server_crypto_rx.borrow().clone(); } complete => break } @@ -100,110 +82,60 @@ where Ok(()) as Result<()> } - #[cfg(feature = "h3")] - async fn parse_sni_and_get_crypto_h3<'a, 'b>( - &self, - peeked_conn: &mut quinn::Connecting, - server_crypto_map: &'a ServerCryptoMap, - ) -> Option<(&'a ServerNameLC, &'a Arc)> { - let hsd = if let Ok(h) = peeked_conn.handshake_data().await { - h - } else { - return None; - }; - let hsd_downcast = if let Ok(d) = hsd.downcast::() { - d - } else { - return None; - }; - let server_name = hsd_downcast.server_name?.to_ascii_lowercase(); - info!( - "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", - server_name - ); - server_crypto_map.get_key_value(&server_name.into_bytes()) - // .map_or_else(|| None, |(k, v)| Some((k.clone(), v.clone()))); - } - #[cfg(feature = "h3")] async fn listener_service_h3( &self, - mut server_crypto_rx: watch::Receiver>, + mut server_crypto_rx: watch::Receiver>>, ) -> Result<()> { - // TODO: Work around to initially serve incoming connection - // かなり適当。エラーが出たり出なかったり。原因がわからない… - // let next = self - // .globals - // .backends - // .apps - // .iter() - // .filter(|&(_, backend)| { - // backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() - // }) - // .map(|(name, _)| name) - // .next(); - // ensure!(next.is_some(), "No TLS supported app"); - // let initial_app_name = next.ok_or_else(|| anyhow!(""))?; - // debug!( - // "HTTP/3 SNI multiplexer initial app_name: {:?}", - // std::str::from_utf8(initial_app_name) - // ); - // let backend_serve = self - // .globals - // .backends - // .apps - // .get(initial_app_name) - // .ok_or_else(|| anyhow!(""))?; - - // let initial_server_crypto = backend_serve.update_server_config().await?; - let initial_server_crypto = self + let server_crypto = self .globals .backends .generate_server_crypto_with_cert_resolver() - .await - .unwrap(); + .await?; - let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(initial_server_crypto)); - let (endpoint, incoming) = quinn::Endpoint::server(server_config_h3, self.listening_on)?; + let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); + let (endpoint, mut incoming) = quinn::Endpoint::server(server_config_h3, self.listening_on)?; info!("Start UDP proxy serving with HTTP/3 request for configured host names"); - let mut server_crypto_map: Option = None; - let mut p = incoming.peekable(); + let mut server_crypto: Option> = None; loop { select! { - // TODO: Not sure if this properly works to handle multiple "server_name"s to host multiple hosts. - // peek() should work for that. - peeked_conn = Pin::new(&mut p).peek_mut().fuse() => { - if server_crypto_map.is_none() || peeked_conn.is_none() { + new_conn = incoming.next().fuse() => { + if server_crypto.is_none() || new_conn.is_none() { continue; } - let peeked_conn = peeked_conn.unwrap(); - - let new_server_name = match self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await { - Some((new_server_name, _new_server_crypto)) => { - debug!("omg"); - // Set ServerConfig::set_server_config for given SNI - // endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto.clone()))); - Some(new_server_name) - }, - None => None + let mut conn = new_conn.unwrap(); + let hsd = if let Ok(h) = conn.handshake_data().await { + h + } else { + continue }; - - // Then acquire actual connection - let peekable_incoming = Pin::new(&mut p); - if let Some(conn) = peekable_incoming.get_mut().next().await { - if let Some(new_server_name) = new_server_name { - self.clone().client_serve_h3(conn, new_server_name).await; - } + let hsd_downcast = if let Ok(d) = hsd.downcast::() { + d } else { continue; - } + }; + let new_server_name = if let Some(sn) = hsd_downcast.server_name { + sn.as_bytes().to_ascii_lowercase() + } else { + warn!("HTTP/3 no SNI is given"); + continue; + }; + debug!( + "HTTP/3 connection incoming (SNI {:?})", + new_server_name + ); + self.clone().client_serve_h3(conn, new_server_name.as_ref()); } _ = server_crypto_rx.changed().fuse() => { if server_crypto_rx.borrow().is_none() { break; } - server_crypto_map = server_crypto_rx.borrow().clone(); + server_crypto = server_crypto_rx.borrow().clone(); + if server_crypto.is_some(){ + debug!("Reload server crypto"); + endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(server_crypto.clone().unwrap()))); + } } complete => break } @@ -213,7 +145,7 @@ where } pub async fn start_with_tls(self, server: Http) -> Result<()> { - let (tx, rx) = watch::channel::>(None); + let (tx, rx) = watch::channel::>>(None); #[cfg(not(feature = "h3"))] { select! {