diff --git a/src/server.rs b/src/server.rs index f0d513a..5536fd3 100644 --- a/src/server.rs +++ b/src/server.rs @@ -7,11 +7,72 @@ use tokio_rustls::rustls::{ CertificateDer, PrivateKeyDer, pem::{self, PemObject, SectionKind}, }, - server::ResolvesServerCertUsingSni, + server::{ClientHello, ResolvesServerCert}, sign::CertifiedKey, }; use x509_parser::prelude::GeneralName; +/// Something that resolves do different cert chains/keys based +/// on client-supplied server name (via SNI). +/// Forked from rustls to add wildcard subdomains. +#[derive(Debug)] +pub struct ResolvesServerCertUsingSni { + by_name: HashMap>, +} + +impl ResolvesServerCertUsingSni { + /// Create a new and empty (i.e., knows no certificates) resolver. + pub fn new() -> Self { + Self { + by_name: HashMap::new(), + } + } + + /// Add a new `sign::CertifiedKey` to be used for the given SNI `name`. + /// + /// This function fails if `name` is not a valid DNS name, or if + /// it's not valid for the supplied certificate, or if the certificate + /// chain is syntactically faulty. + pub fn add(&mut self, name: &str, ck: CertifiedKey) -> Result<(), tokio_rustls::rustls::Error> { + self.by_name.insert(name.to_string(), Arc::new(ck)); + Ok(()) + } +} + +impl ResolvesServerCert for ResolvesServerCertUsingSni { + fn resolve(&self, client_hello: ClientHello<'_>) -> Option> { + if let Some(name) = client_hello.server_name() { + if let Some(key) = self.by_name.get(name) { + // Exact match + return Some(key.clone()); + } + // Iteratively remove subdomains, replacing them with wildcards. + let mut name = name.to_string(); + loop { + if !name.contains('.') { + // We made it to the top level! + return None; + } + if name.as_bytes()[0] == b'*' { + if !name[2..].contains('.') { + // We made it to the top level! + return None; + } + name = format!("*{}", &name[name[2..].find('.').unwrap() + 2..]); + } else { + name = format!("*{}", &name[name.find('.').unwrap()..]); + } + if let Some(key) = self.by_name.get(&name) { + return Some(key.clone()); + } + } + } else { + // This kind of resolver requires SNI + None + } + } +} + pub async fn play( records: &'static Records, use_tls: bool, @@ -102,8 +163,11 @@ pub async fn play( .general_names .iter() { + if debug { + println!("Add name to resolver: {:?}", name); + } if let GeneralName::DNSName(name) = name { - resolver.add(name, cert_key.clone()).ok(); + resolver.add(name, cert_key.clone()).unwrap(); } } }