use crate::{ crypto_source::CryptoSource, error::*, log::*, server_crypto::{ServerCryptoBase, ServerNameBytes}, }; use ahash::HashMap; use async_trait::async_trait; use hot_reload::{Reload, ReloaderError}; use std::sync::Arc; /* ------------------------------------------------ */ /// Boxed CryptoSource trait object with Send and Sync /// TODO: support for not only `CryptoFileSource` but also other type of sources pub(super) type DynCryptoSource = dyn CryptoSource + Send + Sync + 'static; #[derive(Clone)] /// Reloader service for certificates and keys for TLS pub struct CryptoReloader { inner: HashMap>>, } impl Extend<(ServerNameBytes, T)> for CryptoReloader where T: CryptoSource + Send + Sync + 'static, { fn extend>(&mut self, iter: I) { let iter = iter .into_iter() .map(|(k, v)| (k, Arc::new(Box::new(v) as Box))); self.inner.extend(iter); } } #[async_trait] impl Reload for CryptoReloader { type Source = HashMap>>; async fn new(source: &Self::Source) -> Result> { let mut inner = HashMap::default(); inner.extend(source.clone()); Ok(Self { inner }) } async fn reload(&self) -> Result, ReloaderError> { let mut server_crypto_base = ServerCryptoBase::default(); for (server_name_bytes, crypto_source) in self.inner.iter() { let certs_keys = match crypto_source.read().await { Ok(certs_keys) => certs_keys, Err(e) => { error!("Failed to read certs and keys, skip at this time: {}", e); continue; } }; server_crypto_base.inner.insert(server_name_bytes.clone(), certs_keys); } Ok(Some(server_crypto_base)) } } /* ------------------------------------------------ */ #[cfg(test)] mod tests { use super::*; use crate::crypto_source::CryptoFileSourceBuilder; #[tokio::test] async fn test_crypto_reloader() { let tls_cert_path = "../example-certs/server.crt"; let tls_cert_key_path = "../example-certs/server.key"; let client_ca_cert_path = Some("../example-certs/client.ca.crt"); let mut crypto_reloader = CryptoReloader::new(&HashMap::default()).await.unwrap(); let crypto_source = CryptoFileSourceBuilder::default() .tls_cert_path(tls_cert_path) .tls_cert_key_path(tls_cert_key_path) .client_ca_cert_path(client_ca_cert_path) .build() .unwrap(); crypto_reloader.extend(vec![(b"localhost".to_vec(), crypto_source)]); let server_crypto_base = crypto_reloader.reload().await.unwrap().unwrap(); assert_eq!(server_crypto_base.inner.len(), 1); } }