diff --git a/src/client.rs b/src/client.rs index 03353ad..b2dd267 100644 --- a/src/client.rs +++ b/src/client.rs @@ -24,6 +24,10 @@ use tokio_rustls::{ danger::{HandshakeSignatureValid, ServerCertVerifier}, }, pki_types::ServerName, + pki_types::{ + CertificateDer, + pem::{self, PemObject, SectionKind}, + }, }, }; @@ -90,6 +94,7 @@ pub async fn play( use_tls: bool, connect_to: (String, u16), repeat: u32, + cert_path: Option<&str>, debug: bool, ) { // Semaphore used to limit the number of concurrent clients. @@ -124,12 +129,46 @@ pub async fn play( }); if use_tls { - let mut config = tokio_rustls::rustls::ClientConfig::builder() + let config_builder = tokio_rustls::rustls::ClientConfig::builder(); + let mut config = if let Some(cert_path) = cert_path { + let mut certs = tokio_rustls::rustls::RootCertStore::empty(); + for file in std::fs::read_dir(cert_path).unwrap_or_else(|e| { + panic!("Cannot read certificate directory `{cert_path}`: {e:?}") + }) { + match file { + Ok(file) => { + if file.file_name().as_encoded_bytes().ends_with(b".crt") { + for section in + <(pem::SectionKind, Vec) as PemObject>::pem_file_iter( + file.path(), + ) + .unwrap() + { + let (kind, data) = section.unwrap(); + if kind == SectionKind::Certificate { + let (_rem, cert) = + x509_parser::parse_x509_certificate(&data).unwrap(); + if cert.is_ca() { + certs + .add(CertificateDer::from_slice(Box::leak( + data.to_vec().into_boxed_slice(), + ))) + .unwrap(); + } + } + } + } + } + Err(e) => eprintln!("Error listing cert directory: {e:?}"), + } + } + config_builder.with_root_certificates(certs) + } else { + config_builder.with_platform_verifier().unwrap() //.dangerous() //.with_custom_certificate_verifier(Arc::new(DummyCertVerifier)) - .with_platform_verifier() - .unwrap() - .with_no_client_auth(); + } + .with_no_client_auth(); let mut enable_early_data = false; for (var, val) in std::env::vars() { match var.as_str() { diff --git a/src/main.rs b/src/main.rs index 08d39aa..392e2a8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -61,6 +61,9 @@ struct OptClient { /// Only play this record #[argp(option)] record: Option, + /// Path to PEM certificates (if not provided, use system's certificates) + #[argp(option, short = 'c')] + certs: Option, /// Print debug info #[argp(switch, short = 'd')] debug: bool, @@ -142,6 +145,7 @@ async fn main() { subopt.tls, (subopt.connect_addr, subopt.connect_port), subopt.repeat, + subopt.certs.as_deref(), subopt.debug, ) .await;