Server: resolve wildcard domains

This commit is contained in:
Pascal Engélibert 2026-03-18 16:00:03 +01:00
commit 46e85e6ee8

View file

@ -7,11 +7,72 @@ use tokio_rustls::rustls::{
CertificateDer, PrivateKeyDer,
pem::{self, PemObject, SectionKind},
},
server::ResolvesServerCertUsingSni,
server::{ClientHello, ResolvesServerCert},
sign::CertifiedKey,
};
use x509_parser::prelude::GeneralName;
/// Something that resolves do different cert chains/keys based
/// on client-supplied server name (via SNI).
/// Forked from rustls to add wildcard subdomains.
#[derive(Debug)]
pub struct ResolvesServerCertUsingSni {
by_name: HashMap<String, Arc<CertifiedKey>>,
}
impl ResolvesServerCertUsingSni {
/// Create a new and empty (i.e., knows no certificates) resolver.
pub fn new() -> Self {
Self {
by_name: HashMap::new(),
}
}
/// Add a new `sign::CertifiedKey` to be used for the given SNI `name`.
///
/// This function fails if `name` is not a valid DNS name, or if
/// it's not valid for the supplied certificate, or if the certificate
/// chain is syntactically faulty.
pub fn add(&mut self, name: &str, ck: CertifiedKey) -> Result<(), tokio_rustls::rustls::Error> {
self.by_name.insert(name.to_string(), Arc::new(ck));
Ok(())
}
}
impl ResolvesServerCert for ResolvesServerCertUsingSni {
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
if let Some(name) = client_hello.server_name() {
if let Some(key) = self.by_name.get(name) {
// Exact match
return Some(key.clone());
}
// Iteratively remove subdomains, replacing them with wildcards.
let mut name = name.to_string();
loop {
if !name.contains('.') {
// We made it to the top level!
return None;
}
if name.as_bytes()[0] == b'*' {
if !name[2..].contains('.') {
// We made it to the top level!
return None;
}
name = format!("*{}", &name[name[2..].find('.').unwrap() + 2..]);
} else {
name = format!("*{}", &name[name.find('.').unwrap()..]);
}
if let Some(key) = self.by_name.get(&name) {
return Some(key.clone());
}
}
} else {
// This kind of resolver requires SNI
None
}
}
}
pub async fn play(
records: &'static Records,
use_tls: bool,
@ -102,8 +163,11 @@ pub async fn play(
.general_names
.iter()
{
if debug {
println!("Add name to resolver: {:?}", name);
}
if let GeneralName::DNSName(name) = name {
resolver.add(name, cert_key.clone()).ok();
resolver.add(name, cert_key.clone()).unwrap();
}
}
}