diff --git a/src/backend.rs b/src/backend.rs index 8d58d8c..a5ed13f 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,6 +1,8 @@ use crate::{backend_opt::UpstreamOption, log::*}; +use h3::server; use rand::Rng; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use rustls::server::ResolvesServerCert; use std::{ borrow::Cow, fs::File, @@ -11,7 +13,11 @@ use std::{ Arc, }, }; -use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; +use tokio_rustls::rustls::{ + server::ResolvesServerCertUsingSni, + sign::{any_supported_type, CertifiedKey, SigningKey}, + Certificate, PrivateKey, ServerConfig, +}; // server name (hostname or ip address) in ascii lower case pub type ServerNameLC = Vec; @@ -230,4 +236,142 @@ impl Backend { // server_config; Ok(server_config) } + + pub fn read_certs_and_key(&self) -> io::Result { + debug!("Read TLS server certificates and private key"); + let (certs_path, certs_keys_path) = + if let (Some(c), Some(k)) = (self.tls_cert_path.as_ref(), self.tls_cert_key_path.as_ref()) { + (c, k) + } else { + return Err(io::Error::new( + io::ErrorKind::Other, + "Invalid certs and keys paths", + )); + }; + let certs: Vec<_> = { + let certs_path_str = certs_path.display().to_string(); + let mut reader = BufReader::new(File::open(certs_path).map_err(|e| { + io::Error::new( + e.kind(), + format!( + "Unable to load the certificates [{}]: {}", + certs_path_str, e + ), + ) + })?); + rustls_pemfile::certs(&mut reader).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to parse the certificates", + ) + })? + } + .drain(..) + .map(Certificate) + .collect(); + let certs_keys: Vec<_> = { + let certs_keys_path_str = certs_keys_path.display().to_string(); + let encoded_keys = { + let mut encoded_keys = vec![]; + File::open(certs_keys_path) + .map_err(|e| { + io::Error::new( + e.kind(), + format!( + "Unable to load the certificate keys [{}]: {}", + certs_keys_path_str, e + ), + ) + })? + .read_to_end(&mut encoded_keys)?; + encoded_keys + }; + let mut reader = Cursor::new(encoded_keys); + let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to parse the certificates private keys (PKCS8)", + ) + })?; + reader.set_position(0); + let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader)?; + let mut keys = pkcs8_keys; + keys.append(&mut rsa_keys); + if keys.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "No private keys found - Make sure that they are in PKCS#8/PEM format", + )); + } + keys.drain(..).map(PrivateKey).collect() + }; + let signing_key = certs_keys + .iter() + .find_map(|k| { + if let Ok(sk) = any_supported_type(k) { + Some(sk) + } else { + None + } + }) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to find a valid certificate and key", + ) + })?; + Ok(CertifiedKey::new(certs, signing_key)) + } +} + +impl Backends { + pub async fn generate_server_crypto_with_cert_resolver( + &self, + ) -> Result { + let mut resolver = ResolvesServerCertUsingSni::new(); + + for (_, backend) in self.apps.iter() { + if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { + match backend.read_certs_and_key() { + Ok(certified_key) => { + if let Err(e) = resolver.add(backend.server_name.as_str(), certified_key) { + error!( + "{}: Failed to read some certificates and keys {}", + backend.server_name.as_str(), + e + ) + } + } + Err(e) => { + warn!( + "Failed to add certificate for {}: {}", + backend.server_name.as_str(), + e + ); + } + } + } + } + + let mut server_config = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver)); + + #[cfg(feature = "h3")] + { + server_config.alpn_protocols = vec![ + b"h3".to_vec(), + b"hq-29".to_vec(), // quinn draft example TODO: remove later + b"h2".to_vec(), + b"http/1.1".to_vec(), + ]; + } + #[cfg(not(feature = "h3"))] + { + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + } + + Ok(server_config) + } } diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index bc38b3c..24d2f32 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -101,7 +101,7 @@ where } #[cfg(feature = "h3")] - async fn parse_sni_and_get_crypto_h3<'a>( + async fn parse_sni_and_get_crypto_h3<'a, 'b>( &self, peeked_conn: &mut quinn::Connecting, server_crypto_map: &'a ServerCryptoMap, @@ -132,30 +132,36 @@ where ) -> Result<()> { // TODO: Work around to initially serve incoming connection // かなり適当。エラーが出たり出なかったり。原因がわからない… - let next = self - .globals - .backends - .apps - .iter() - .filter(|&(_, backend)| { - backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() - }) - .map(|(name, _)| name) - .next(); - ensure!(next.is_some(), "No TLS supported app"); - let initial_app_name = next.ok_or_else(|| anyhow!(""))?; - debug!( - "HTTP/3 SNI multiplexer initial app_name: {:?}", - std::str::from_utf8(initial_app_name) - ); - let backend_serve = self - .globals - .backends - .apps - .get(initial_app_name) - .ok_or_else(|| anyhow!(""))?; + // let next = self + // .globals + // .backends + // .apps + // .iter() + // .filter(|&(_, backend)| { + // backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() + // }) + // .map(|(name, _)| name) + // .next(); + // ensure!(next.is_some(), "No TLS supported app"); + // let initial_app_name = next.ok_or_else(|| anyhow!(""))?; + // debug!( + // "HTTP/3 SNI multiplexer initial app_name: {:?}", + // std::str::from_utf8(initial_app_name) + // ); + // let backend_serve = self + // .globals + // .backends + // .apps + // .get(initial_app_name) + // .ok_or_else(|| anyhow!(""))?; - let initial_server_crypto = backend_serve.update_server_config().await?; + // let initial_server_crypto = backend_serve.update_server_config().await?; + let initial_server_crypto = self + .globals + .backends + .generate_server_crypto_with_cert_resolver() + .await + .unwrap(); let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(initial_server_crypto)); let (endpoint, incoming) = quinn::Endpoint::server(server_config_h3, self.listening_on)?; @@ -174,9 +180,10 @@ where let peeked_conn = peeked_conn.unwrap(); let new_server_name = match self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await { - Some((new_server_name, new_server_crypto)) => { + Some((new_server_name, _new_server_crypto)) => { + debug!("omg"); // Set ServerConfig::set_server_config for given SNI - endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto.clone()))); + // endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto.clone()))); Some(new_server_name) }, None => None