use std::fs::File; use std::io::{self, BufReader, Cursor, Read}; use std::path::Path; use std::sync::Arc; use std::time::Duration; use futures::{future::FutureExt, join, select}; use hyper::{client::connect::Connect, server::conn::Http}; use tokio::{ net::TcpListener, sync::mpsc::{self, Receiver}, }; use tokio_rustls::{ rustls::{Certificate, PrivateKey, ServerConfig}, TlsAcceptor, }; use crate::acceptor::{LocalExecutor, PacketAcceptor}; use crate::constants::CERTS_WATCH_DELAY_SECS; use crate::error::*; pub fn create_tls_acceptor(certs_path: P, certs_keys_path: P2) -> io::Result where P: AsRef, P2: AsRef, { let certs: Vec<_> = { let certs_path_str = certs_path.as_ref().display().to_string(); let mut reader = BufReader::new(File::open(certs_path).map_err(|e| { io::Error::new( e.kind(), format!( "Unable to load the certificates [{}]: {}", certs_path_str, e ), ) })?); rustls_pemfile::certs(&mut reader).map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "Unable to parse the certificates", ) })? } .drain(..) .map(Certificate) .collect(); let certs_keys: Vec<_> = { let certs_keys_path_str = certs_keys_path.as_ref().display().to_string(); let encoded_keys = { let mut encoded_keys = vec![]; File::open(certs_keys_path) .map_err(|e| { io::Error::new( e.kind(), format!( "Unable to load the certificate keys [{}]: {}", certs_keys_path_str, e ), ) })? .read_to_end(&mut encoded_keys)?; encoded_keys }; let mut reader = Cursor::new(encoded_keys); let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "Unable to parse the certificates private keys (PKCS8)", ) })?; reader.set_position(0); let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| { io::Error::new( io::ErrorKind::InvalidInput, "Unable to parse the certificates private keys (RSA)", ) })?; let mut keys = pkcs8_keys; keys.append(&mut rsa_keys); if keys.is_empty() { return Err(io::Error::new( io::ErrorKind::InvalidInput, "No private keys found - Make sure that they are in PKCS#8/PEM format", )); } keys.drain(..).map(PrivateKey).collect() }; let mut server_config = certs_keys .into_iter() .find_map(|certs_key| { let server_config_builder = ServerConfig::builder() .with_safe_defaults() .with_no_client_auth(); if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) { Some(found_config) } else { None } }) .ok_or_else(|| { io::Error::new( io::ErrorKind::InvalidInput, "Unable to find a valid certificate and key", ) })?; server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; Ok(TlsAcceptor::from(Arc::new(server_config))) } impl PacketAcceptor where T: Connect + Clone + Send + Sync + 'static, { async fn start_https_service( self, mut tls_acceptor_receiver: Receiver, listener: TcpListener, server: Http, ) -> Result<()> { let mut tls_acceptor: Option = None; let listener_service = async { loop { select! { tcp_cnx = listener.accept().fuse() => { if tls_acceptor.is_none() || tcp_cnx.is_err() { continue; } let (raw_stream, _client_addr) = tcp_cnx.unwrap(); if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await { self.clone().client_serve(stream, server.clone(), _client_addr).await } } new_tls_acceptor = tls_acceptor_receiver.recv().fuse() => { if new_tls_acceptor.is_none() { break; } tls_acceptor = new_tls_acceptor; } complete => break } } Ok(()) as Result<()> }; listener_service.await?; Ok(()) } pub async fn start_with_tls( self, listener: TcpListener, server: Http, ) -> Result<()> { let certs_path = self.globals.tls_cert_path.as_ref().unwrap().clone(); let certs_keys_path = self.globals.tls_cert_key_path.as_ref().unwrap().clone(); let (tls_acceptor_sender, tls_acceptor_receiver) = mpsc::channel(1); let https_service = self.start_https_service(tls_acceptor_receiver, listener, server); let cert_service = async { loop { match create_tls_acceptor(&certs_path, &certs_keys_path) { Ok(tls_acceptor) => { if tls_acceptor_sender.send(tls_acceptor).await.is_err() { break; } } Err(e) => eprintln!("TLS certificates error: {}", e), } tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; } Ok(()) as Result<()> }; return join!(https_service, cert_service).0; } }