Server: resolve wildcard domains
This commit is contained in:
parent
86cf2f9665
commit
46e85e6ee8
1 changed files with 66 additions and 2 deletions
|
|
@ -7,11 +7,72 @@ use tokio_rustls::rustls::{
|
|||
CertificateDer, PrivateKeyDer,
|
||||
pem::{self, PemObject, SectionKind},
|
||||
},
|
||||
server::ResolvesServerCertUsingSni,
|
||||
server::{ClientHello, ResolvesServerCert},
|
||||
sign::CertifiedKey,
|
||||
};
|
||||
use x509_parser::prelude::GeneralName;
|
||||
|
||||
/// Something that resolves do different cert chains/keys based
|
||||
/// on client-supplied server name (via SNI).
|
||||
/// Forked from rustls to add wildcard subdomains.
|
||||
#[derive(Debug)]
|
||||
pub struct ResolvesServerCertUsingSni {
|
||||
by_name: HashMap<String, Arc<CertifiedKey>>,
|
||||
}
|
||||
|
||||
impl ResolvesServerCertUsingSni {
|
||||
/// Create a new and empty (i.e., knows no certificates) resolver.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
by_name: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new `sign::CertifiedKey` to be used for the given SNI `name`.
|
||||
///
|
||||
/// This function fails if `name` is not a valid DNS name, or if
|
||||
/// it's not valid for the supplied certificate, or if the certificate
|
||||
/// chain is syntactically faulty.
|
||||
pub fn add(&mut self, name: &str, ck: CertifiedKey) -> Result<(), tokio_rustls::rustls::Error> {
|
||||
self.by_name.insert(name.to_string(), Arc::new(ck));
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl ResolvesServerCert for ResolvesServerCertUsingSni {
|
||||
fn resolve(&self, client_hello: ClientHello<'_>) -> Option<Arc<CertifiedKey>> {
|
||||
if let Some(name) = client_hello.server_name() {
|
||||
if let Some(key) = self.by_name.get(name) {
|
||||
// Exact match
|
||||
return Some(key.clone());
|
||||
}
|
||||
// Iteratively remove subdomains, replacing them with wildcards.
|
||||
let mut name = name.to_string();
|
||||
loop {
|
||||
if !name.contains('.') {
|
||||
// We made it to the top level!
|
||||
return None;
|
||||
}
|
||||
if name.as_bytes()[0] == b'*' {
|
||||
if !name[2..].contains('.') {
|
||||
// We made it to the top level!
|
||||
return None;
|
||||
}
|
||||
name = format!("*{}", &name[name[2..].find('.').unwrap() + 2..]);
|
||||
} else {
|
||||
name = format!("*{}", &name[name.find('.').unwrap()..]);
|
||||
}
|
||||
if let Some(key) = self.by_name.get(&name) {
|
||||
return Some(key.clone());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// This kind of resolver requires SNI
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn play(
|
||||
records: &'static Records,
|
||||
use_tls: bool,
|
||||
|
|
@ -102,8 +163,11 @@ pub async fn play(
|
|||
.general_names
|
||||
.iter()
|
||||
{
|
||||
if debug {
|
||||
println!("Add name to resolver: {:?}", name);
|
||||
}
|
||||
if let GeneralName::DNSName(name) = name {
|
||||
resolver.add(name, cert_key.clone()).ok();
|
||||
resolver.add(name, cert_key.clone()).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue