Merge pull request #156 from junkurihara/feat/rustls-0.23

feat: rustls-0.23
This commit is contained in:
Jun Kurihara 2024-06-02 02:57:52 +09:00 committed by GitHub
commit d66863ad3f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
38 changed files with 934 additions and 1309 deletions

6
.gitmodules vendored
View file

@ -1,6 +1,6 @@
[submodule "submodules/h3"]
path = submodules/h3
url = git@github.com:junkurihara/h3.git
[submodule "submodules/rusty-http-cache-semantics"]
path = submodules/rusty-http-cache-semantics
url = git@github.com:junkurihara/rusty-http-cache-semantics.git
[submodule "submodules/s2n-quic"]
path = submodules/s2n-quic
url = git@github.com:junkurihara/s2n-quic.git

View file

@ -2,7 +2,17 @@
## 0.8.0 (Unreleased)
## 0.7.1 -- 0.7.3
### Important Changes
- Breaking: Support for `rustls`-0.23.x for http/1.1, 2 and 3. No configuration update is needed at this point.
- Breaking: Along with `rustls`, the cert manager was split from `rpxy-lib` and moved to a new inner crate `rpxy-cert`. This change is to make the cert manager reusable for other projects and to support not only static file based certificates but also other types, e.g., dynamic fetching and management via ACME, in the future.
### Improvement
- Refactor: lots of minor improvements
## 0.7.1
- deps and patches

View file

@ -1,5 +1,5 @@
[workspace.package]
version = "0.7.2"
version = "0.8.0-alpha.0"
authors = ["Jun Kurihara"]
homepage = "https://github.com/junkurihara/rust-rpxy"
repository = "https://github.com/junkurihara/rust-rpxy"
@ -9,7 +9,7 @@ edition = "2021"
publish = false
[workspace]
members = ["rpxy-bin", "rpxy-lib"]
members = ["rpxy-bin", "rpxy-lib", "rpxy-certs"]
exclude = ["submodules"]
resolver = "2"

View file

@ -14,6 +14,7 @@ publish.workspace = true
[features]
default = ["http3-quinn", "cache", "rustls-backend"]
# default = ["http3-s2n", "cache", "rustls-backend"]
http3-quinn = ["rpxy-lib/http3-quinn"]
http3-s2n = ["rpxy-lib/http3-s2n"]
native-tls-backend = ["rpxy-lib/native-tls-backend"]
@ -26,11 +27,11 @@ rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [
"sticky-cookie",
] }
mimalloc = { version = "*", default-features = false }
anyhow = "1.0.86"
rustc-hash = "1.1.0"
serde = { version = "1.0.202", default-features = false, features = ["derive"] }
derive_builder = "0.20.0"
tokio = { version = "1.37.0", default-features = false, features = [
serde = { version = "1.0.203", default-features = false, features = ["derive"] }
tokio = { version = "1.38.0", default-features = false, features = [
"net",
"rt-multi-thread",
"time",
@ -38,8 +39,7 @@ tokio = { version = "1.37.0", default-features = false, features = [
"macros",
] }
async-trait = "0.1.80"
rustls-pemfile = "1.0.4"
mimalloc = { version = "*", default-features = false }
# config
clap = { version = "4.5.4", features = ["std", "cargo", "wrap_help"] }
@ -50,5 +50,10 @@ hot_reload = "0.1.5"
tracing = { version = "0.1.40" }
tracing-subscriber = { version = "0.3.18", features = ["env-filter"] }
################################
# cert management
rpxy-certs = { path = "../rpxy-certs/", default-features = false, features = [
"http3",
] }
[dev-dependencies]

View file

@ -4,6 +4,6 @@ mod toml;
pub use {
self::toml::ConfigToml,
parse::{build_settings, parse_opts},
parse::{build_cert_manager, build_settings, parse_opts},
service::ConfigTomlReloader,
};

View file

@ -1,10 +1,10 @@
use super::toml::ConfigToml;
use crate::{
cert_file_reader::CryptoFileSource,
error::{anyhow, ensure},
};
use crate::error::{anyhow, ensure};
use clap::{Arg, ArgAction};
use hot_reload::{ReloaderReceiver, ReloaderService};
use rpxy_certs::{build_cert_reloader, CryptoFileSourceBuilder, CryptoReloader, ServerCryptoBase};
use rpxy_lib::{AppConfig, AppConfigList, ProxyConfig};
use rustc_hash::FxHashMap as HashMap;
/// Parsed options
pub struct Opts {
@ -37,20 +37,13 @@ pub fn parse_opts() -> Result<Opts, anyhow::Error> {
let config_file_path = matches.get_one::<String>("config_file").unwrap().to_owned();
let watch = matches.get_one::<bool>("watch").unwrap().to_owned();
Ok(Opts {
config_file_path,
watch,
})
Ok(Opts { config_file_path, watch })
}
pub fn build_settings(
config: &ConfigToml,
) -> std::result::Result<(ProxyConfig, AppConfigList<CryptoFileSource>), anyhow::Error> {
///////////////////////////////////
pub fn build_settings(config: &ConfigToml) -> std::result::Result<(ProxyConfig, AppConfigList), anyhow::Error> {
// build proxy config
let proxy_config: ProxyConfig = config.try_into()?;
///////////////////////////////////
// backend_apps
let apps = config.apps.clone().ok_or(anyhow!("Missing application spec"))?;
@ -78,9 +71,8 @@ pub fn build_settings(
}
// build applications
let mut app_config_list_inner = Vec::<AppConfig<CryptoFileSource>>::new();
let mut app_config_list_inner = Vec::<AppConfig>::new();
// let mut backends = Backends::new();
for (app_name, app) in apps.0.iter() {
let _server_name_string = app.server_name.as_ref().ok_or(anyhow!("No server name"))?;
let registered_app_name = app_name.to_ascii_lowercase();
@ -95,3 +87,35 @@ pub fn build_settings(
Ok((proxy_config, app_config_list))
}
/* ----------------------- */
/// Build cert map
pub async fn build_cert_manager(
config: &ConfigToml,
) -> Result<
Option<(
ReloaderService<CryptoReloader, ServerCryptoBase>,
ReloaderReceiver<ServerCryptoBase>,
)>,
anyhow::Error,
> {
let apps = config.apps.as_ref().ok_or(anyhow!("No apps"))?;
if config.listen_port_tls.is_none() {
return Ok(None);
}
let mut crypto_source_map = HashMap::default();
for app in apps.0.values() {
if let Some(tls) = app.tls.as_ref() {
ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some());
let server_name = app.server_name.as_ref().ok_or(anyhow!("No server name"))?;
let crypto_file_source = CryptoFileSourceBuilder::default()
.tls_cert_path(tls.tls_cert_path.as_ref().unwrap())
.tls_cert_key_path(tls.tls_cert_key_path.as_ref().unwrap())
.client_ca_cert_path(tls.client_ca_cert_path.as_deref())
.build()?;
crypto_source_map.insert(server_name.to_owned(), crypto_file_source);
}
}
let res = build_cert_reloader(&crypto_source_map, None).await?;
Ok(Some(res))
}

View file

@ -1,5 +1,4 @@
use crate::{
cert_file_reader::{CryptoFileSource, CryptoFileSourceBuilder},
constants::*,
error::{anyhow, ensure},
};
@ -214,7 +213,7 @@ impl ConfigToml {
}
impl Application {
pub fn build_app_config(&self, app_name: &str) -> std::result::Result<AppConfig<CryptoFileSource>, anyhow::Error> {
pub fn build_app_config(&self, app_name: &str) -> std::result::Result<AppConfig, anyhow::Error> {
let server_name_string = self.server_name.as_ref().ok_or(anyhow!("Missing server_name"))?;
// reverse proxy settings
@ -224,11 +223,6 @@ impl Application {
let tls_config = if self.tls.is_some() {
let tls = self.tls.as_ref().unwrap();
ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some());
let inner = CryptoFileSourceBuilder::default()
.tls_cert_path(tls.tls_cert_path.as_ref().unwrap())
.tls_cert_key_path(tls.tls_cert_key_path.as_ref().unwrap())
.client_ca_cert_path(tls.client_ca_cert_path.as_deref())
.build()?;
let https_redirection = if tls.https_redirection.is_none() {
true // Default true
@ -237,7 +231,7 @@ impl Application {
};
Some(TlsConfig {
inner,
mutual_tls: tls.client_ca_cert_path.is_some(),
https_redirection,
})
} else {

View file

@ -1,15 +1,15 @@
#[global_allocator]
static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
mod cert_file_reader;
mod config;
mod constants;
mod error;
mod log;
use crate::{
config::{build_settings, parse_opts, ConfigToml, ConfigTomlReloader},
config::{build_cert_manager, build_settings, parse_opts, ConfigToml, ConfigTomlReloader},
constants::CONFIG_WATCH_DELAY_SECS,
error::*,
log::*,
};
use hot_reload::{ReloaderReceiver, ReloaderService};
@ -36,13 +36,10 @@ fn main() {
std::process::exit(1);
}
} else {
let (config_service, config_rx) = ReloaderService::<ConfigTomlReloader, ConfigToml>::new(
&parsed_opts.config_file_path,
CONFIG_WATCH_DELAY_SECS,
false,
)
.await
.unwrap();
let (config_service, config_rx) =
ReloaderService::<ConfigTomlReloader, ConfigToml>::new(&parsed_opts.config_file_path, CONFIG_WATCH_DELAY_SECS, false)
.await
.unwrap();
tokio::select! {
Err(e) = config_service.start() => {
@ -53,6 +50,9 @@ fn main() {
error!("rpxy service existed: {e}");
std::process::exit(1);
}
else => {
std::process::exit(0);
}
}
}
});
@ -63,23 +63,16 @@ async fn rpxy_service_without_watcher(
runtime_handle: tokio::runtime::Handle,
) -> Result<(), anyhow::Error> {
info!("Start rpxy service");
let config_toml = match ConfigToml::new(config_file_path) {
Ok(v) => v,
Err(e) => {
error!("Invalid toml file: {e}");
std::process::exit(1);
}
};
let (proxy_conf, app_conf) = match build_settings(&config_toml) {
Ok(v) => v,
Err(e) => {
error!("Invalid configuration: {e}");
return Err(anyhow::anyhow!(e));
}
};
entrypoint(&proxy_conf, &app_conf, &runtime_handle, None)
let config_toml = ConfigToml::new(config_file_path).map_err(|e| anyhow!("Invalid toml file: {e}"))?;
let (proxy_conf, app_conf) = build_settings(&config_toml).map_err(|e| anyhow!("Invalid configuration: {e}"))?;
let cert_service_and_rx = build_cert_manager(&config_toml)
.await
.map_err(|e| anyhow::anyhow!(e))
.map_err(|e| anyhow!("Invalid cert configuration: {e}"))?;
rpxy_entrypoint(&proxy_conf, &app_conf, cert_service_and_rx.as_ref(), &runtime_handle, None)
.await
.map_err(|e| anyhow!(e))
}
async fn rpxy_service_with_watcher(
@ -89,14 +82,15 @@ async fn rpxy_service_with_watcher(
info!("Start rpxy service with dynamic config reloader");
// Initial loading
config_rx.changed().await?;
let config_toml = config_rx.borrow().clone().unwrap();
let (mut proxy_conf, mut app_conf) = match build_settings(&config_toml) {
Ok(v) => v,
Err(e) => {
error!("Invalid configuration: {e}");
return Err(anyhow::anyhow!(e));
}
};
let config_toml = config_rx
.borrow()
.clone()
.ok_or(anyhow!("Something wrong in config reloader receiver"))?;
let (mut proxy_conf, mut app_conf) = build_settings(&config_toml).map_err(|e| anyhow!("Invalid configuration: {e}"))?;
let mut cert_service_and_rx = build_cert_manager(&config_toml)
.await
.map_err(|e| anyhow!("Invalid cert configuration: {e}"))?;
// Notifier for proxy service termination
let term_notify = std::sync::Arc::new(tokio::sync::Notify::new());
@ -104,16 +98,15 @@ async fn rpxy_service_with_watcher(
// Continuous monitoring
loop {
tokio::select! {
_ = entrypoint(&proxy_conf, &app_conf, &runtime_handle, Some(term_notify.clone())) => {
error!("rpxy entrypoint exited");
break;
rpxy_res = rpxy_entrypoint(&proxy_conf, &app_conf, cert_service_and_rx.as_ref(), &runtime_handle, Some(term_notify.clone())) => {
error!("rpxy entrypoint or cert service exited");
return rpxy_res.map_err(|e| anyhow!(e));
}
_ = config_rx.changed() => {
if config_rx.borrow().is_none() {
let Some(config_toml) = config_rx.borrow().clone() else {
error!("Something wrong in config reloader receiver");
break;
}
let config_toml = config_rx.borrow().clone().unwrap();
return Err(anyhow!("Something wrong in config reloader receiver"));
};
match build_settings(&config_toml) {
Ok((p, a)) => {
(proxy_conf, app_conf) = (p, a)
@ -123,6 +116,16 @@ async fn rpxy_service_with_watcher(
continue;
}
};
match build_cert_manager(&config_toml).await {
Ok(c) => {
cert_service_and_rx = c;
},
Err(e) => {
error!("Invalid cert configuration. Configuration does not updated: {e}");
continue;
}
};
info!("Configuration updated. Terminate all spawned proxy services and force to re-bind TCP/UDP sockets");
term_notify.notify_waiters();
// tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
@ -131,5 +134,34 @@ async fn rpxy_service_with_watcher(
}
}
Err(anyhow::anyhow!("rpxy or continuous monitoring service exited"))
Ok(())
}
/// Wrapper of entry point for rpxy service with certificate management service
async fn rpxy_entrypoint(
proxy_config: &rpxy_lib::ProxyConfig,
app_config_list: &rpxy_lib::AppConfigList,
cert_service_and_rx: Option<&(
ReloaderService<rpxy_certs::CryptoReloader, rpxy_certs::ServerCryptoBase>,
ReloaderReceiver<rpxy_certs::ServerCryptoBase>,
)>, // TODO:
runtime_handle: &tokio::runtime::Handle,
term_notify: Option<std::sync::Arc<tokio::sync::Notify>>,
) -> Result<(), anyhow::Error> {
if let Some((cert_service, cert_rx)) = cert_service_and_rx {
tokio::select! {
rpxy_res = entrypoint(proxy_config, app_config_list, Some(cert_rx), runtime_handle, term_notify) => {
error!("rpxy entrypoint exited");
rpxy_res.map_err(|e| anyhow!(e))
}
cert_res = cert_service.start() => {
error!("cert reloader service exited");
cert_res.map_err(|e| anyhow!(e))
}
}
} else {
entrypoint(proxy_config, app_config_list, None, runtime_handle, term_notify)
.await
.map_err(|e| anyhow!(e))
}
}

39
rpxy-certs/Cargo.toml Normal file
View file

@ -0,0 +1,39 @@
[package]
name = "rpxy-certs"
description = "Cert manager library for `rpxy`"
version.workspace = true
authors.workspace = true
homepage.workspace = true
repository.workspace = true
license.workspace = true
readme.workspace = true
edition.workspace = true
publish.workspace = true
[features]
default = ["http3"]
http3 = []
[dependencies]
rustc-hash = { version = "1.1.0" }
tracing = { version = "0.1.40" }
derive_builder = { version = "0.20.0" }
thiserror = { version = "1.0.61" }
hot_reload = { version = "0.1.5" }
async-trait = { version = "0.1.80" }
rustls = { version = "0.23.8", default-features = false, features = [
"std",
"aws_lc_rs",
] }
rustls-pemfile = { version = "2.1.2" }
rustls-webpki = { version = "0.102.4", default-features = false, features = [
"std",
"aws_lc_rs",
] }
x509-parser = { version = "0.16.0" }
[dev-dependencies]
tokio = { version = "1.38.0", default-features = false, features = [
"rt-multi-thread",
"macros",
] }

157
rpxy-certs/src/certs.rs Normal file
View file

@ -0,0 +1,157 @@
use crate::error::*;
use rustc_hash::FxHashMap as HashMap;
use rustls::{crypto::aws_lc_rs::sign::any_supported_type, pki_types, sign::CertifiedKey};
use std::sync::Arc;
use x509_parser::prelude::*;
/* ------------------------------------------------ */
/// Raw certificates in rustls format
type Certificate = rustls::pki_types::CertificateDer<'static>;
/// Raw private key in rustls format
type PrivateKey = pki_types::PrivateKeyDer<'static>;
/// Subject Key ID in bytes
type SubjectKeyIdentifier = Vec<u8>;
/// Client CA trust anchors subject to the subject key identifier
type TrustAnchors = HashMap<SubjectKeyIdentifier, pki_types::TrustAnchor<'static>>;
/* ------------------------------------------------ */
/// Raw certificates and private keys loaded from files for a single server name
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct SingleServerCertsKeys {
certs: Vec<Certificate>,
cert_keys: Arc<Vec<PrivateKey>>,
client_ca_certs: Option<Vec<Certificate>>,
}
impl SingleServerCertsKeys {
/// Create a new instance of SingleServerCrypto
pub fn new(certs: &[Certificate], cert_keys: &Arc<Vec<PrivateKey>>, client_ca_certs: &Option<Vec<Certificate>>) -> Self {
Self {
certs: certs.to_owned(),
cert_keys: cert_keys.clone(),
client_ca_certs: client_ca_certs.clone(),
}
}
/// Check if mutual tls is enabled
pub fn is_mutual_tls(&self) -> bool {
self.client_ca_certs.is_some()
}
/* ------------------------------------------------ */
/// Convert the certificates to bytes in der
pub fn certs_bytes(&self) -> Vec<Vec<u8>> {
self.certs.iter().map(|c| c.to_vec()).collect()
}
/// Convert the private keys to bytes in der
pub fn cert_keys_bytes(&self) -> Vec<Vec<u8>> {
self
.cert_keys
.iter()
.map(|k| match k {
pki_types::PrivateKeyDer::Pkcs1(pkcs1) => pkcs1.secret_pkcs1_der().to_owned(),
pki_types::PrivateKeyDer::Sec1(sec1) => sec1.secret_sec1_der().to_owned(),
pki_types::PrivateKeyDer::Pkcs8(pkcs8) => pkcs8.secret_pkcs8_der().to_owned(),
_ => unreachable!(),
})
.collect()
}
/// Convert the client CA certificates to bytes in der
pub fn client_ca_certs_bytes(&self) -> Option<Vec<Vec<u8>>> {
self.client_ca_certs.as_ref().map(|v| v.iter().map(|c| c.to_vec()).collect())
}
/* ------------------------------------------------ */
/// Parse the certificates and private keys for a single server and return a rustls CertifiedKey
pub fn rustls_certified_key(&self) -> Result<CertifiedKey, RpxyCertError> {
let signing_key = self
.cert_keys
.clone()
.iter()
.find_map(|k| if let Ok(sk) = any_supported_type(k) { Some(sk) } else { None })
.ok_or_else(|| RpxyCertError::InvalidCertificateAndKey)?;
let cert = self.certs.iter().map(|c| Certificate::from(c.to_vec())).collect::<Vec<_>>();
Ok(CertifiedKey::new(cert, signing_key))
}
/* ------------------------------------------------ */
/// Parse the client CA certificates and return a hashmap of pairs of a subject key identifier (key) and a trust anchor (value)
pub fn rustls_client_certs_trust_anchors(&self) -> Result<TrustAnchors, RpxyCertError> {
let Some(certs) = self.client_ca_certs.as_ref() else {
return Err(RpxyCertError::NoClientCert);
};
let certs = certs.iter().map(|c| Certificate::from(c.to_vec())).collect::<Vec<_>>();
let trust_anchors = certs
.iter()
.filter_map(|v| {
// retrieve trust anchor
let trust_anchor = webpki::anchor_from_trusted_cert(v).ok()?;
// retrieve ca key id (subject key id)
let x509_cert = parse_x509_certificate(v).map(|v| v.1).ok()?;
let mut subject_key_ids = x509_cert.iter_extensions().filter_map(|ext| match ext.parsed_extension() {
ParsedExtension::SubjectKeyIdentifier(skid) => Some(skid),
_ => None,
});
let skid = subject_key_ids.next()?;
Some((skid.0.to_owned(), trust_anchor.to_owned()))
})
.collect::<HashMap<_, _>>();
Ok(trust_anchors)
}
}
/* ------------------------------------------------ */
#[cfg(test)]
mod tests {
use super::super::*;
#[tokio::test]
async fn read_server_crt_key_files() {
let tls_cert_path = "../example-certs/server.crt";
let tls_cert_key_path = "../example-certs/server.key";
let crypto_file_source = CryptoFileSourceBuilder::default()
.tls_cert_key_path(tls_cert_key_path)
.tls_cert_path(tls_cert_path)
.build();
assert!(crypto_file_source.is_ok());
let crypto_file_source = crypto_file_source.unwrap();
let crypto_elem = crypto_file_source.read().await;
assert!(crypto_elem.is_ok());
let crypto_elem = crypto_elem.unwrap();
let certificed_key = crypto_elem.rustls_certified_key();
assert!(certificed_key.is_ok());
}
#[tokio::test]
async fn read_server_crt_key_files_with_client_ca_crt() {
let tls_cert_path = "../example-certs/server.crt";
let tls_cert_key_path = "../example-certs/server.key";
let client_ca_cert_path = Some("../example-certs/client.ca.crt");
let crypto_file_source = CryptoFileSourceBuilder::default()
.tls_cert_key_path(tls_cert_key_path)
.tls_cert_path(tls_cert_path)
.client_ca_cert_path(client_ca_cert_path)
.build();
assert!(crypto_file_source.is_ok());
let crypto_file_source = crypto_file_source.unwrap();
let crypto_elem = crypto_file_source.read().await;
assert!(crypto_elem.is_ok());
let crypto_elem = crypto_elem.unwrap();
assert!(crypto_elem.is_mutual_tls());
let certificed_key = crypto_elem.rustls_certified_key();
assert!(certificed_key.is_ok());
let trust_anchors = crypto_elem.rustls_client_certs_trust_anchors();
assert!(trust_anchors.is_ok());
let trust_anchors = trust_anchors.unwrap();
assert_eq!(trust_anchors.len(), 1);
}
}

View file

@ -1,18 +1,29 @@
use crate::log::*;
use crate::{certs::SingleServerCertsKeys, error::*, log::*};
use async_trait::async_trait;
use derive_builder::Builder;
use rpxy_lib::{
reexports::{Certificate, PrivateKey},
CertsAndKeys, CryptoSource,
};
use std::{
fs::File,
io::{self, BufReader, Cursor, Read},
path::{Path, PathBuf},
sync::Arc,
};
/* ------------------------------------------------ */
#[async_trait]
// Trait to read certs and keys anywhere from KVS, file, sqlite, etc.
pub trait CryptoSource {
type Error;
/// read crypto materials from source
async fn read(&self) -> Result<SingleServerCertsKeys, Self::Error>;
/// Returns true when mutual tls is enabled
fn is_mutual_tls(&self) -> bool;
}
/* ------------------------------------------------ */
#[derive(Builder, Debug, Clone)]
/// Crypto-related file reader implementing certs::CryptoRead trait
/// Crypto-related file reader implementing `CryptoSource` trait
pub struct CryptoFileSource {
#[builder(setter(custom))]
/// Always exist
@ -42,11 +53,12 @@ impl CryptoFileSourceBuilder {
}
}
/* ------------------------------------------------ */
#[async_trait]
impl CryptoSource for CryptoFileSource {
type Error = io::Error;
type Error = RpxyCertError;
/// read crypto materials from source
async fn read(&self) -> Result<CertsAndKeys, Self::Error> {
async fn read(&self) -> Result<SingleServerCertsKeys, Self::Error> {
read_certs_and_keys(
&self.tls_cert_path,
&self.tls_cert_key_path,
@ -59,127 +71,90 @@ impl CryptoSource for CryptoFileSource {
}
}
/* ------------------------------------------------ */
/// Read certificates and private keys from file
fn read_certs_and_keys(
cert_path: &PathBuf,
cert_key_path: &PathBuf,
client_ca_cert_path: Option<&PathBuf>,
) -> Result<CertsAndKeys, io::Error> {
) -> Result<SingleServerCertsKeys, RpxyCertError> {
debug!("Read TLS server certificates and private key");
let certs: Vec<_> = {
let certs_path_str = cert_path.display().to_string();
// certificates
let raw_certs = {
let mut reader = BufReader::new(File::open(cert_path).map_err(|e| {
io::Error::new(
e.kind(),
format!("Unable to load the certificates [{certs_path_str}]: {e}"),
format!("Unable to load the certificates [{}]: {e}", cert_path.display()),
)
})?);
rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "Unable to parse the certificates"))?
}
.drain(..)
.map(Certificate)
.collect();
};
let cert_keys: Vec<_> = {
let cert_key_path_str = cert_key_path.display().to_string();
// private keys
let raw_cert_keys = {
let encoded_keys = {
let mut encoded_keys = vec![];
File::open(cert_key_path)
.map_err(|e| {
io::Error::new(
e.kind(),
format!("Unable to load the certificate keys [{cert_key_path_str}]: {e}"),
format!("Unable to load the certificate keys [{}]: {e}", cert_key_path.display()),
)
})?
.read_to_end(&mut encoded_keys)?;
encoded_keys
};
let mut reader = Cursor::new(encoded_keys);
let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unable to parse the certificates private keys (PKCS8)",
)
})?;
let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader)
.map(|v| v.map(rustls::pki_types::PrivateKeyDer::Pkcs8))
.collect::<Result<Vec<_>, _>>()
.map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unable to parse the certificates private keys (PKCS8)",
)
})?;
reader.set_position(0);
let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader)?;
let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader)
.map(|v| v.map(rustls::pki_types::PrivateKeyDer::Pkcs1))
.collect::<Result<Vec<_>, _>>()?;
let mut keys = pkcs8_keys;
keys.append(&mut rsa_keys);
if keys.is_empty() {
return Err(io::Error::new(
return Err(RpxyCertError::IoError(io::Error::new(
io::ErrorKind::InvalidInput,
"No private keys found - Make sure that they are in PKCS#8/PEM format",
));
)));
}
keys.drain(..).map(PrivateKey).collect()
keys
};
// client ca certificates
let client_ca_certs = if let Some(path) = client_ca_cert_path {
debug!("Read CA certificates for client authentication");
// Reads client certificate and returns client
let certs: Vec<_> = {
let certs_path_str = path.display().to_string();
let certs = {
let mut reader = BufReader::new(File::open(path).map_err(|e| {
io::Error::new(
e.kind(),
format!("Unable to load the client certificates [{certs_path_str}]: {e}"),
format!("Unable to load the client certificates [{}]: {e}", path.display()),
)
})?);
rustls_pemfile::certs(&mut reader)
.collect::<Result<Vec<_>, _>>()
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "Unable to parse the client certificates"))?
}
.drain(..)
.map(Certificate)
.collect();
};
Some(certs)
} else {
None
};
Ok(CertsAndKeys {
certs,
cert_keys,
client_ca_certs,
})
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn read_server_crt_key_files() {
let tls_cert_path = "../example-certs/server.crt";
let tls_cert_key_path = "../example-certs/server.key";
let crypto_file_source = CryptoFileSourceBuilder::default()
.tls_cert_key_path(tls_cert_key_path)
.tls_cert_path(tls_cert_path)
.build();
assert!(crypto_file_source.is_ok());
let crypto_file_source = crypto_file_source.unwrap();
let crypto_elem = crypto_file_source.read().await;
assert!(crypto_elem.is_ok());
}
#[tokio::test]
async fn read_server_crt_key_files_with_client_ca_crt() {
let tls_cert_path = "../example-certs/server.crt";
let tls_cert_key_path = "../example-certs/server.key";
let client_ca_cert_path = Some("../example-certs/client.ca.crt");
let crypto_file_source = CryptoFileSourceBuilder::default()
.tls_cert_key_path(tls_cert_key_path)
.tls_cert_path(tls_cert_path)
.client_ca_cert_path(client_ca_cert_path)
.build();
assert!(crypto_file_source.is_ok());
let crypto_file_source = crypto_file_source.unwrap();
let crypto_elem = crypto_file_source.read().await;
assert!(crypto_elem.is_ok());
let crypto_elem = crypto_elem.unwrap();
assert!(crypto_elem.client_ca_certs.is_some());
}
Ok(SingleServerCertsKeys::new(
&raw_certs,
&Arc::new(raw_cert_keys),
&client_ca_certs,
))
}

27
rpxy-certs/src/error.rs Normal file
View file

@ -0,0 +1,27 @@
use thiserror::Error;
/// Describes things that can go wrong in the Rpxy certificate
#[derive(Debug, Error)]
pub enum RpxyCertError {
/// Error when reading certificates and keys
#[error("Failed to read certificates from file: {0}")]
IoError(#[from] std::io::Error),
/// Error when parsing certificates and keys to generate a rustls CertifiedKey
#[error("Unable to find a valid certificate and key")]
InvalidCertificateAndKey,
/// Error when parsing client CA certificates: No client certificate found
#[error("No client certificate found")]
NoClientCert,
/// Error for hot reload certificate reloader
#[error("Certificate reload error: {0}")]
CertificateReloadError(#[from] hot_reload::ReloaderError<crate::server_crypto::ServerCryptoBase>),
/// Error when converting server name bytes to string
#[error("Failed to convert server name bytes to string: {0}")]
ServerNameBytesToString(#[from] std::string::FromUtf8Error),
/// Rustls error
#[error("Rustls error: {0}")]
RustlsError(#[from] rustls::Error),
/// Rustls CryptoProvider error
#[error("Rustls No default CryptoProvider error")]
NoDefaultCryptoProvider,
}

64
rpxy-certs/src/lib.rs Normal file
View file

@ -0,0 +1,64 @@
mod certs;
mod crypto_source;
mod error;
mod reloader_service;
mod server_crypto;
#[allow(unused_imports)]
mod log {
pub(super) use tracing::{debug, error, info, warn};
}
use crate::{error::*, log::*, reloader_service::DynCryptoSource};
use hot_reload::{ReloaderReceiver, ReloaderService};
use rustc_hash::FxHashMap as HashMap;
use rustls::crypto::{aws_lc_rs, CryptoProvider};
use std::sync::Arc;
/* ------------------------------------------------ */
pub use crate::{
certs::SingleServerCertsKeys,
crypto_source::{CryptoFileSource, CryptoFileSourceBuilder, CryptoFileSourceBuilderError, CryptoSource},
reloader_service::CryptoReloader,
server_crypto::{ServerCrypto, ServerCryptoBase},
};
/* ------------------------------------------------ */
// Constants
/// Default delay in seconds to watch certificates
const DEFAULT_CERTS_WATCH_DELAY_SECS: u32 = 60;
/// Load certificates only when updated
const LOAD_CERTS_ONLY_WHEN_UPDATED: bool = true;
/// Result type inner of certificate reloader service
type ReloaderServiceResultInner = (
ReloaderService<CryptoReloader, ServerCryptoBase>,
ReloaderReceiver<ServerCryptoBase>,
);
/// Build certificate reloader service, which accepts a map of server names to `CryptoSource` instances
pub async fn build_cert_reloader<T>(
crypto_source_map: &HashMap<String, T>,
certs_watch_period: Option<u32>,
) -> Result<ReloaderServiceResultInner, RpxyCertError>
where
T: CryptoSource<Error = RpxyCertError> + Send + Sync + Clone + 'static,
{
info!("Building certificate reloader service");
// Install aws_lc_rs as default crypto provider for rustls
let _ = CryptoProvider::install_default(aws_lc_rs::default_provider());
let source = crypto_source_map
.iter()
.map(|(k, v)| {
let server_name_bytes = k.as_bytes().to_vec().to_ascii_lowercase();
let dyn_crypto_source = Arc::new(Box::new(v.clone()) as Box<DynCryptoSource>);
(server_name_bytes, dyn_crypto_source)
})
.collect::<HashMap<_, _>>();
let certs_watch_period = certs_watch_period.unwrap_or(DEFAULT_CERTS_WATCH_DELAY_SECS);
let (cert_reloader_service, cert_reloader_rx) =
ReloaderService::<CryptoReloader, ServerCryptoBase>::new(&source, certs_watch_period, !LOAD_CERTS_ONLY_WHEN_UPDATED).await?;
Ok((cert_reloader_service, cert_reloader_rx))
}

View file

@ -0,0 +1,84 @@
use crate::{
crypto_source::CryptoSource,
error::*,
log::*,
server_crypto::{ServerCryptoBase, ServerNameBytes},
};
use async_trait::async_trait;
use hot_reload::{Reload, ReloaderError};
use rustc_hash::FxHashMap as HashMap;
use std::sync::Arc;
/* ------------------------------------------------ */
/// Boxed CryptoSource trait object with Send and Sync
/// TODO: support for not only `CryptoFileSource` but also other type of sources
pub(super) type DynCryptoSource = dyn CryptoSource<Error = RpxyCertError> + Send + Sync + 'static;
#[derive(Clone)]
/// Reloader service for certificates and keys for TLS
pub struct CryptoReloader {
inner: HashMap<ServerNameBytes, Arc<Box<DynCryptoSource>>>,
}
impl<T> Extend<(ServerNameBytes, T)> for CryptoReloader
where
T: CryptoSource<Error = RpxyCertError> + Send + Sync + 'static,
{
fn extend<I: IntoIterator<Item = (ServerNameBytes, T)>>(&mut self, iter: I) {
let iter = iter
.into_iter()
.map(|(k, v)| (k, Arc::new(Box::new(v) as Box<DynCryptoSource>)));
self.inner.extend(iter);
}
}
#[async_trait]
impl Reload<ServerCryptoBase> for CryptoReloader {
type Source = HashMap<ServerNameBytes, Arc<Box<DynCryptoSource>>>;
async fn new(source: &Self::Source) -> Result<Self, ReloaderError<ServerCryptoBase>> {
let mut inner = HashMap::default();
inner.extend(source.clone());
Ok(Self { inner })
}
async fn reload(&self) -> Result<Option<ServerCryptoBase>, ReloaderError<ServerCryptoBase>> {
let mut server_crypto_base = ServerCryptoBase::default();
for (server_name_bytes, crypto_source) in self.inner.iter() {
let certs_keys = crypto_source.read().await.map_err(|e| {
error!("Failed to reload cert, key or ca cert: {e}");
ReloaderError::<ServerCryptoBase>::Reload("Failed to reload cert, key or ca cert")
})?;
server_crypto_base.inner.insert(server_name_bytes.clone(), certs_keys);
}
Ok(Some(server_crypto_base))
}
}
/* ------------------------------------------------ */
#[cfg(test)]
mod tests {
use super::*;
use crate::crypto_source::CryptoFileSourceBuilder;
#[tokio::test]
async fn test_crypto_reloader() {
let tls_cert_path = "../example-certs/server.crt";
let tls_cert_key_path = "../example-certs/server.key";
let client_ca_cert_path = Some("../example-certs/client.ca.crt");
let mut crypto_reloader = CryptoReloader::new(&HashMap::default()).await.unwrap();
let crypto_source = CryptoFileSourceBuilder::default()
.tls_cert_path(tls_cert_path)
.tls_cert_key_path(tls_cert_key_path)
.client_ca_cert_path(client_ca_cert_path)
.build()
.unwrap();
crypto_reloader.extend(vec![(b"localhost".to_vec(), crypto_source)]);
let server_crypto_base = crypto_reloader.reload().await.unwrap().unwrap();
assert_eq!(server_crypto_base.inner.len(), 1);
}
}

View file

@ -0,0 +1,206 @@
use crate::{certs::SingleServerCertsKeys, error::*, log::*};
use rustc_hash::FxHashMap as HashMap;
use rustls::{
crypto::CryptoProvider,
server::{ResolvesServerCertUsingSni, WebPkiClientVerifier},
RootCertStore, ServerConfig,
};
use std::sync::Arc;
/* ------------------------------------------------ */
/// ServerName in bytes type (TODO: this may be changed to define `common` layer defining types of names. or should be independent?)
pub type ServerNameBytes = Vec<u8>;
/// Convert ServerName in bytes to string
fn server_name_bytes_to_string(server_name_bytes: &ServerNameBytes) -> Result<String, RpxyCertError> {
let server_name = String::from_utf8(server_name_bytes.to_ascii_lowercase())?;
Ok(server_name)
}
/// ServerName (SNI) to ServerConfig map type
pub type ServerNameCryptoMap = HashMap<ServerNameBytes, Arc<ServerConfig>>;
/// ServerName (SNI) to ServerConfig map
pub struct ServerCrypto {
// For Quic/HTTP3, only servers with no client authentication, aggregated server config
pub aggregated_config_no_client_auth: Arc<ServerConfig>,
// For TLS over TCP/HTTP2 and 1.1, map of SNI to server_crypto for all given servers
pub individual_config_map: Arc<ServerNameCryptoMap>,
}
/* ------------------------------------------------ */
/// Reloader target for the certificate reloader service
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub struct ServerCryptoBase {
/// Map of server name to certs and keys
pub(super) inner: HashMap<ServerNameBytes, SingleServerCertsKeys>,
}
impl TryInto<Arc<ServerCrypto>> for &ServerCryptoBase {
type Error = RpxyCertError;
fn try_into(self) -> Result<Arc<ServerCrypto>, Self::Error> {
let aggregated = self.build_aggrated_server_crypto()?;
let individual = self.build_individual_server_crypto_map()?;
Ok(Arc::new(ServerCrypto {
aggregated_config_no_client_auth: Arc::new(aggregated),
individual_config_map: Arc::new(individual),
}))
}
}
impl ServerCryptoBase {
/// Build individual server crypto inner object
fn build_individual_server_crypto_map(&self) -> Result<ServerNameCryptoMap, RpxyCertError> {
let mut server_crypto_map: ServerNameCryptoMap = HashMap::default();
// AWS LC provider by default
let provider = CryptoProvider::get_default().ok_or(RpxyCertError::NoDefaultCryptoProvider)?;
for (server_name_bytes, certs_keys) in self.inner.iter() {
let server_name = server_name_bytes_to_string(server_name_bytes)?;
// Parse server certificates and private keys
let Ok(certified_key) = certs_keys.rustls_certified_key() else {
warn!("Failed to add certificate for {server_name}");
continue;
};
let mut resolver_local = ResolvesServerCertUsingSni::new();
if let Err(e) = resolver_local.add(&server_name, certified_key) {
error!("{server_name}: Failed to read some certificates and keys {e}");
};
// With no client authentication case
if !certs_keys.is_mutual_tls() {
let mut server_crypto_local = ServerConfig::builder_with_provider(provider.clone())
.with_safe_default_protocol_versions()?
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver_local));
#[cfg(feature = "http3")]
{
server_crypto_local.alpn_protocols = vec![b"h3".to_vec(), b"h2".to_vec(), b"http/1.1".to_vec()];
}
#[cfg(not(feature = "http3"))]
{
server_crypto_local.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
}
server_crypto_map.insert(server_name_bytes.clone(), Arc::new(server_crypto_local));
continue;
}
// With client authentication case, enable only http2 and http1.1
let mut client_ca_roots_local = RootCertStore::empty();
let Ok(trust_anchors) = certs_keys.rustls_client_certs_trust_anchors() else {
warn!("Failed to add client CA certificate for {server_name}");
continue;
};
let trust_anchors_without_skid = trust_anchors.values().map(|ta| ta.to_owned());
client_ca_roots_local.extend(trust_anchors_without_skid);
let Ok(client_cert_verifier) =
WebPkiClientVerifier::builder_with_provider(Arc::new(client_ca_roots_local), provider.clone()).build()
else {
warn!("Failed to build client CA certificate verifier for {server_name}");
continue;
};
let mut server_crypto_local = ServerConfig::builder_with_provider(provider.clone())
.with_safe_default_protocol_versions()?
.with_client_cert_verifier(client_cert_verifier)
.with_cert_resolver(Arc::new(resolver_local));
server_crypto_local.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
server_crypto_map.insert(server_name_bytes.clone(), Arc::new(server_crypto_local));
}
Ok(server_crypto_map)
}
/* ------------------------------------------------ */
/// Build aggregated server crypto inner object for no client auth server especially for http3
fn build_aggrated_server_crypto(&self) -> Result<ServerConfig, RpxyCertError> {
let mut resolver_global = ResolvesServerCertUsingSni::new();
// AWS LC provider by default
let provider = CryptoProvider::get_default().ok_or(RpxyCertError::NoDefaultCryptoProvider)?;
for (server_name_bytes, certs_keys) in self.inner.iter() {
let server_name = server_name_bytes_to_string(server_name_bytes)?;
// Parse server certificates and private keys
let Ok(certified_key) = certs_keys.rustls_certified_key() else {
warn!("Failed to add certificate for {server_name}");
continue;
};
// Add server certificates and private keys to resolver only if client CA certs are not present
if !certs_keys.is_mutual_tls() {
// aggregated server config for no client auth server for http3
if let Err(e) = resolver_global.add(&server_name, certified_key) {
error!("{server_name}: Failed to read some certificates and keys {e}");
};
}
}
let mut server_crypto_global = ServerConfig::builder_with_provider(provider.clone())
.with_safe_default_protocol_versions()?
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver_global));
#[cfg(feature = "http3")]
{
server_crypto_global.alpn_protocols = vec![b"h3".to_vec(), b"h2".to_vec(), b"http/1.1".to_vec()];
}
#[cfg(not(feature = "http3"))]
{
server_crypto_global.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
}
Ok(server_crypto_global)
}
}
/* ------------------------------------------------ */
#[cfg(test)]
mod tests {
use super::*;
use crate::{CryptoFileSourceBuilder, CryptoSource};
use std::convert::TryInto;
async fn read_file_source() -> SingleServerCertsKeys {
let tls_cert_path = "../example-certs/server.crt";
let tls_cert_key_path = "../example-certs/server.key";
let client_ca_cert_path = Some("../example-certs/client.ca.crt");
let crypto_file_source = CryptoFileSourceBuilder::default()
.tls_cert_key_path(tls_cert_key_path)
.tls_cert_path(tls_cert_path)
.client_ca_cert_path(client_ca_cert_path)
.build();
crypto_file_source.unwrap().read().await.unwrap()
}
#[tokio::test]
async fn test_server_crypto_base_try_into() {
let _ = CryptoProvider::install_default(rustls::crypto::aws_lc_rs::default_provider());
let mut server_crypto_base = ServerCryptoBase::default();
let single_certs_keys = read_file_source().await;
server_crypto_base.inner.insert(b"localhost".to_vec(), single_certs_keys);
let server_crypto: Arc<ServerCrypto> = (&server_crypto_base).try_into().unwrap();
assert_eq!(server_crypto.individual_config_map.len(), 1);
#[cfg(feature = "http3")]
{
assert_eq!(
server_crypto.aggregated_config_no_client_auth.alpn_protocols,
vec![b"h3".to_vec(), b"h2".to_vec(), b"http/1.1".to_vec()]
);
}
#[cfg(not(feature = "http3"))]
{
assert_eq!(
server_crypto.aggregated_config_no_client_auth.alpn_protocols,
vec![b"h2".to_vec(), b"http/1.1".to_vec()]
);
}
}
}

View file

@ -13,14 +13,16 @@ publish.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
# default = ["http3-s2n", "sticky-cookie", "cache", "rustls-backend"]
default = ["http3-quinn", "sticky-cookie", "cache", "rustls-backend"]
http3-quinn = ["socket2", "quinn", "h3", "h3-quinn"]
http3-quinn = ["socket2", "quinn", "h3", "h3-quinn", "rpxy-certs/http3"]
http3-s2n = [
"h3",
"s2n-quic",
"s2n-quic-core",
"s2n-quic-rustls",
"s2n-quic-h3",
"rpxy-certs/http3",
]
cache = ["http-cache-semantics", "lru", "sha2", "base64"]
sticky-cookie = ["base64", "sha2", "chrono"]
@ -34,7 +36,7 @@ rustc-hash = "1.1.0"
bytes = "1.6.0"
derive_builder = "0.20.0"
futures = { version = "0.3.30", features = ["alloc", "async-await"] }
tokio = { version = "1.37.0", default-features = false, features = [
tokio = { version = "1.38.0", default-features = false, features = [
"net",
"rt-multi-thread",
"time",
@ -53,7 +55,7 @@ thiserror = "1.0.61"
http = "1.1.0"
http-body-util = "0.1.1"
hyper = { version = "1.3.1", default-features = false }
hyper-util = { version = "0.1.3", features = ["full"] }
hyper-util = { version = "0.1.5", features = ["full"] }
futures-util = { version = "0.3.30", default-features = false }
futures-channel = { version = "0.3.30", default-features = false }
@ -62,7 +64,7 @@ hyper-tls = { version = "0.6.0", features = [
"alpn",
"vendored",
], optional = true }
hyper-rustls = { version = "0.27.1", default-features = false, features = [
hyper-rustls = { version = "0.27.2", default-features = false, features = [
"ring",
"native-tokio",
"http1",
@ -70,25 +72,32 @@ hyper-rustls = { version = "0.27.1", default-features = false, features = [
], optional = true }
# tls and cert management for server
rpxy-certs = { path = "../rpxy-certs/", default-features = false }
hot_reload = "0.1.5"
rustls = { version = "0.21.12", default-features = false }
tokio-rustls = { version = "0.24.1", features = ["early-data"] }
webpki = "0.22.4"
x509-parser = "0.16.0"
rustls = { version = "0.23.8", default-features = false }
tokio-rustls = { version = "0.26.0", features = ["early-data"] }
# logging
tracing = { version = "0.1.40" }
# http/3
quinn = { version = "0.10.2", optional = true }
h3 = { path = "../submodules/h3/h3/", optional = true }
h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true }
s2n-quic = { version = "1.37.0", default-features = false, features = [
quinn = { version = "0.11.1", optional = true }
h3 = { version = "0.0.5", optional = true }
h3-quinn = { version = "0.0.6", optional = true }
### TODO: workaround for s2n-quic, waiting for release of s2n-quic-0.38.0
s2n-quic = { path = "../submodules/s2n-quic/quic/s2n-quic", optional = true, default-features = false, features = [
"provider-tls-rustls",
], optional = true }
s2n-quic-core = { version = "0.37.0", default-features = false, optional = true }
s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true }
s2n-quic-rustls = { version = "0.37.0", optional = true }
] }
s2n-quic-core = { path = "../submodules/s2n-quic/quic/s2n-quic-core", optional = true, default-features = false }
s2n-quic-rustls = { path = "../submodules/s2n-quic/quic/s2n-quic-rustls", optional = true }
s2n-quic-h3 = { path = "../submodules/s2n-quic/quic/s2n-quic-h3", optional = true }
# s2n-quic = { version = "1.37.0", default-features = false, features = [
# "provider-tls-rustls",
# ], optional = true }
# s2n-quic-core = { version = "0.37.0", default-features = false, optional = true }
# s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true }
# s2n-quic-rustls = { version = "0.37.0", optional = true }
##########
# for UDP socket wit SO_REUSEADDR when h3 with quinn
socket2 = { version = "0.5.7", features = ["all"], optional = true }

View file

@ -1,5 +1,4 @@
use crate::{
crypto::CryptoSource,
error::*,
log::*,
name_exp::{ByteName, ServerName},
@ -13,10 +12,7 @@ use super::upstream::PathManager;
/// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
#[derive(Builder)]
pub struct BackendApp<T>
where
T: CryptoSource,
{
pub struct BackendApp {
#[builder(setter(into))]
/// backend application name, e.g., app1
pub app_name: String,
@ -28,50 +24,30 @@ where
/// tls settings: https redirection with 30x
#[builder(default)]
pub https_redirection: Option<bool>,
/// TLS settings: source meta for server cert, key, client ca cert
/// tls settings: mutual TLS is enabled
#[builder(default)]
pub crypto_source: Option<T>,
pub mutual_tls: Option<bool>,
}
impl<'a, T> BackendAppBuilder<T>
where
T: CryptoSource,
{
impl<'a> BackendAppBuilder {
pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self {
self.server_name = Some(server_name.to_server_name());
self
}
}
#[derive(Default)]
/// HashMap and some meta information for multiple Backend structs.
pub struct BackendAppManager<T>
where
T: CryptoSource,
{
pub struct BackendAppManager {
/// HashMap of Backend structs, key is server name
pub apps: HashMap<ServerName, BackendApp<T>>,
pub apps: HashMap<ServerName, BackendApp>,
/// for plaintext http
pub default_server_name: Option<ServerName>,
}
impl<T> Default for BackendAppManager<T>
where
T: CryptoSource,
{
fn default() -> Self {
Self {
apps: HashMap::<ServerName, BackendApp<T>>::default(),
default_server_name: None,
}
}
}
impl<T> TryFrom<&AppConfig<T>> for BackendApp<T>
where
T: CryptoSource + Clone,
{
impl TryFrom<&AppConfig> for BackendApp {
type Error = RpxyError;
fn try_from(app_config: &AppConfig<T>) -> Result<Self, Self::Error> {
fn try_from(app_config: &AppConfig) -> Result<Self, Self::Error> {
let mut backend_builder = BackendAppBuilder::default();
let path_manager = PathManager::try_from(app_config)?;
backend_builder
@ -85,26 +61,21 @@ where
let tls = app_config.tls.as_ref().unwrap();
backend_builder
.https_redirection(Some(tls.https_redirection))
.crypto_source(Some(tls.inner.clone()))
.mutual_tls(Some(tls.mutual_tls))
.build()?
};
Ok(backend)
}
}
impl<T> TryFrom<&AppConfigList<T>> for BackendAppManager<T>
where
T: CryptoSource + Clone,
{
impl TryFrom<&AppConfigList> for BackendAppManager {
type Error = RpxyError;
fn try_from(config_list: &AppConfigList<T>) -> Result<Self, Self::Error> {
fn try_from(config_list: &AppConfigList) -> Result<Self, Self::Error> {
let mut manager = Self::default();
for app_config in config_list.inner.iter() {
let backend: BackendApp<T> = BackendApp::try_from(app_config)?;
manager
.apps
.insert(app_config.server_name.clone().to_server_name(), backend);
let backend: BackendApp = BackendApp::try_from(app_config)?;
manager.apps.insert(app_config.server_name.clone().to_server_name(), backend);
info!(
"Registering application {} ({})",

View file

@ -6,7 +6,6 @@ use super::load_balance::{
// use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption};
use super::upstream_opts::UpstreamOption;
use crate::{
crypto::CryptoSource,
error::RpxyError,
globals::{AppConfig, UpstreamUri},
log::*,
@ -28,12 +27,9 @@ pub struct PathManager {
inner: HashMap<PathName, UpstreamCandidates>,
}
impl<T> TryFrom<&AppConfig<T>> for PathManager
where
T: CryptoSource,
{
impl TryFrom<&AppConfig> for PathManager {
type Error = RpxyError;
fn try_from(app_config: &AppConfig<T>) -> Result<Self, Self::Error> {
fn try_from(app_config: &AppConfig) -> Result<Self, Self::Error> {
let mut inner: HashMap<PathName, UpstreamCandidates> = HashMap::default();
app_config.reverse_proxy.iter().for_each(|rpc| {

View file

@ -1,21 +1,10 @@
pub const RESPONSE_HEADER_SERVER: &str = "rpxy";
// pub const LISTEN_ADDRESSES_V4: &[&str] = &["0.0.0.0"];
// pub const LISTEN_ADDRESSES_V6: &[&str] = &["[::]"];
pub const TCP_LISTEN_BACKLOG: u32 = 1024;
// pub const HTTP_LISTEN_PORT: u16 = 8080;
// pub const HTTPS_LISTEN_PORT: u16 = 8443;
pub const PROXY_IDLE_TIMEOUT_SEC: u64 = 20;
pub const UPSTREAM_IDLE_TIMEOUT_SEC: u64 = 20;
pub const TLS_HANDSHAKE_TIMEOUT_SEC: u64 = 15; // default as with firefox browser
pub const MAX_CLIENTS: usize = 512;
pub const MAX_CONCURRENT_STREAMS: u32 = 64;
pub const CERTS_WATCH_DELAY_SECS: u32 = 60;
pub const LOAD_CERTS_ONLY_WHEN_UPDATED: bool = true;
// #[cfg(feature = "http3")]
// pub const H3_RESPONSE_BUF_SIZE: usize = 65_536; // 64KB
// #[cfg(feature = "http3")]
// pub const H3_REQUEST_BUF_SIZE: usize = 65_536; // 64KB // handled by quinn
#[allow(non_snake_case)]
#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]

View file

@ -1,91 +0,0 @@
use async_trait::async_trait;
use rustc_hash::FxHashSet as HashSet;
use rustls::{
sign::{any_supported_type, CertifiedKey},
Certificate, OwnedTrustAnchor, PrivateKey,
};
use std::io;
use x509_parser::prelude::*;
#[async_trait]
// Trait to read certs and keys anywhere from KVS, file, sqlite, etc.
pub trait CryptoSource {
type Error;
/// read crypto materials from source
async fn read(&self) -> Result<CertsAndKeys, Self::Error>;
/// Returns true when mutual tls is enabled
fn is_mutual_tls(&self) -> bool;
}
/// Certificates and private keys in rustls loaded from files
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct CertsAndKeys {
pub certs: Vec<Certificate>,
pub cert_keys: Vec<PrivateKey>,
pub client_ca_certs: Option<Vec<Certificate>>,
}
impl CertsAndKeys {
pub fn parse_server_certs_and_keys(&self) -> Result<CertifiedKey, anyhow::Error> {
// for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() {
let signing_key = self
.cert_keys
.iter()
.find_map(|k| {
if let Ok(sk) = any_supported_type(k) {
Some(sk)
} else {
None
}
})
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::InvalidInput,
"Unable to find a valid certificate and key",
)
})?;
Ok(CertifiedKey::new(self.certs.clone(), signing_key))
}
pub fn parse_client_ca_certs(&self) -> Result<(Vec<OwnedTrustAnchor>, HashSet<Vec<u8>>), anyhow::Error> {
let certs = self.client_ca_certs.as_ref().ok_or(anyhow::anyhow!("No client cert"))?;
let owned_trust_anchors: Vec<_> = certs
.iter()
.map(|v| {
// let trust_anchor = tokio_rustls::webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap();
let trust_anchor = webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap();
rustls::OwnedTrustAnchor::from_subject_spki_name_constraints(
trust_anchor.subject,
trust_anchor.spki,
trust_anchor.name_constraints,
)
})
.collect();
// TODO: SKID is not used currently
let subject_key_identifiers: HashSet<_> = certs
.iter()
.filter_map(|v| {
// retrieve ca key id (subject key id)
let cert = parse_x509_certificate(&v.0).unwrap().1;
let subject_key_ids = cert
.iter_extensions()
.filter_map(|ext| match ext.parsed_extension() {
ParsedExtension::SubjectKeyIdentifier(skid) => Some(skid),
_ => None,
})
.collect::<Vec<_>>();
if !subject_key_ids.is_empty() {
Some(subject_key_ids[0].0.to_owned())
} else {
None
}
})
.collect();
Ok((owned_trust_anchors, subject_key_identifiers))
}
}

View file

@ -1,36 +0,0 @@
mod certs;
mod service;
use crate::{
backend::BackendAppManager,
constants::{CERTS_WATCH_DELAY_SECS, LOAD_CERTS_ONLY_WHEN_UPDATED},
error::RpxyResult,
};
use hot_reload::{ReloaderReceiver, ReloaderService};
use service::CryptoReloader;
use std::sync::Arc;
pub use certs::{CertsAndKeys, CryptoSource};
pub use service::{ServerCrypto, ServerCryptoBase, SniServerCryptoMap};
/// Result type inner of certificate reloader service
type ReloaderServiceResultInner<T> = (
ReloaderService<CryptoReloader<T>, ServerCryptoBase>,
ReloaderReceiver<ServerCryptoBase>,
);
/// Build certificate reloader service
pub(crate) async fn build_cert_reloader<T>(
app_manager: &Arc<BackendAppManager<T>>,
) -> RpxyResult<ReloaderServiceResultInner<T>>
where
T: CryptoSource + Clone + Send + Sync + 'static,
{
let (cert_reloader_service, cert_reloader_rx) = ReloaderService::<
service::CryptoReloader<T>,
service::ServerCryptoBase,
>::new(
app_manager, CERTS_WATCH_DELAY_SECS, !LOAD_CERTS_ONLY_WHEN_UPDATED
)
.await?;
Ok((cert_reloader_service, cert_reloader_rx))
}

View file

@ -1,251 +0,0 @@
use super::certs::{CertsAndKeys, CryptoSource};
use crate::{backend::BackendAppManager, log::*, name_exp::ServerName};
use async_trait::async_trait;
use hot_reload::*;
use rustc_hash::FxHashMap as HashMap;
use rustls::{server::ResolvesServerCertUsingSni, sign::CertifiedKey, RootCertStore, ServerConfig};
use std::sync::Arc;
#[derive(Clone)]
/// Reloader service for certificates and keys for TLS
pub struct CryptoReloader<T>
where
T: CryptoSource,
{
inner: Arc<BackendAppManager<T>>,
}
/// SNI to ServerConfig map type
pub type SniServerCryptoMap = HashMap<ServerName, Arc<ServerConfig>>;
/// SNI to ServerConfig map
pub struct ServerCrypto {
// For Quic/HTTP3, only servers with no client authentication
#[cfg(feature = "http3-quinn")]
pub inner_global_no_client_auth: Arc<ServerConfig>,
#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))]
pub inner_global_no_client_auth: s2n_quic_rustls::Server,
// For TLS over TCP/HTTP2 and 1.1, map of SNI to server_crypto for all given servers
pub inner_local_map: Arc<SniServerCryptoMap>,
}
/// Reloader target for the certificate reloader service
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub struct ServerCryptoBase {
inner: HashMap<ServerName, CertsAndKeys>,
}
#[async_trait]
impl<T> Reload<ServerCryptoBase> for CryptoReloader<T>
where
T: CryptoSource + Sync + Send,
{
type Source = Arc<BackendAppManager<T>>;
async fn new(source: &Self::Source) -> Result<Self, ReloaderError<ServerCryptoBase>> {
Ok(Self { inner: source.clone() })
}
async fn reload(&self) -> Result<Option<ServerCryptoBase>, ReloaderError<ServerCryptoBase>> {
let mut certs_and_keys_map = ServerCryptoBase::default();
for (server_name_bytes_exp, backend) in self.inner.apps.iter() {
if let Some(crypto_source) = &backend.crypto_source {
let certs_and_keys = crypto_source
.read()
.await
.map_err(|_e| ReloaderError::<ServerCryptoBase>::Reload("Failed to reload cert, key or ca cert"))?;
certs_and_keys_map
.inner
.insert(server_name_bytes_exp.to_owned(), certs_and_keys);
}
}
Ok(Some(certs_and_keys_map))
}
}
impl TryInto<Arc<ServerCrypto>> for &ServerCryptoBase {
type Error = anyhow::Error;
fn try_into(self) -> Result<Arc<ServerCrypto>, Self::Error> {
#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
let server_crypto_global = self.build_server_crypto_global()?;
let server_crypto_local_map: SniServerCryptoMap = self.build_server_crypto_local_map()?;
Ok(Arc::new(ServerCrypto {
#[cfg(feature = "http3-quinn")]
inner_global_no_client_auth: Arc::new(server_crypto_global),
#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))]
inner_global_no_client_auth: server_crypto_global,
inner_local_map: Arc::new(server_crypto_local_map),
}))
}
}
impl ServerCryptoBase {
fn build_server_crypto_local_map(&self) -> Result<SniServerCryptoMap, ReloaderError<ServerCryptoBase>> {
let mut server_crypto_local_map: SniServerCryptoMap = HashMap::default();
for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() {
let server_name: String = server_name_bytes_exp.try_into()?;
// Parse server certificates and private keys
let Ok(certified_key): Result<CertifiedKey, _> = certs_and_keys.parse_server_certs_and_keys() else {
warn!("Failed to add certificate for {}", server_name);
continue;
};
let mut resolver_local = ResolvesServerCertUsingSni::new();
let mut client_ca_roots_local = RootCertStore::empty();
// add server certificate and key
if let Err(e) = resolver_local.add(server_name.as_str(), certified_key.to_owned()) {
error!("{}: Failed to read some certificates and keys {}", server_name.as_str(), e)
}
// add client certificate if specified
if certs_and_keys.client_ca_certs.is_some() {
// add client certificate if specified
match certs_and_keys.parse_client_ca_certs() {
Ok((owned_trust_anchors, _subject_key_ids)) => {
client_ca_roots_local.add_trust_anchors(owned_trust_anchors.into_iter());
}
Err(e) => {
warn!("Failed to add client CA certificate for {}: {}", server_name.as_str(), e);
}
}
}
let mut server_config_local = if client_ca_roots_local.is_empty() {
// with no client auth, enable http1.1 -- 3
#[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))]
{
ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver_local))
}
#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
{
let mut sc = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver_local));
sc.alpn_protocols = vec![b"h3".to_vec(), b"hq-29".to_vec()]; // TODO: remove hq-29 later?
sc
}
} else {
// with client auth, enable only http1.1 and 2
// let client_certs_verifier = rustls::server::AllowAnyAnonymousOrAuthenticatedClient::new(client_ca_roots);
let client_certs_verifier = rustls::server::AllowAnyAuthenticatedClient::new(client_ca_roots_local);
ServerConfig::builder()
.with_safe_defaults()
.with_client_cert_verifier(Arc::new(client_certs_verifier))
.with_cert_resolver(Arc::new(resolver_local))
};
server_config_local.alpn_protocols.push(b"h2".to_vec());
server_config_local.alpn_protocols.push(b"http/1.1".to_vec());
server_crypto_local_map.insert(server_name_bytes_exp.to_owned(), Arc::new(server_config_local));
}
Ok(server_crypto_local_map)
}
#[cfg(feature = "http3-quinn")]
fn build_server_crypto_global(&self) -> Result<ServerConfig, ReloaderError<ServerCryptoBase>> {
let mut resolver_global = ResolvesServerCertUsingSni::new();
for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() {
let server_name: String = server_name_bytes_exp.try_into()?;
// Parse server certificates and private keys
let Ok(certified_key): Result<CertifiedKey, _> = certs_and_keys.parse_server_certs_and_keys() else {
warn!("Failed to add certificate for {}", server_name);
continue;
};
if certs_and_keys.client_ca_certs.is_none() {
// aggregated server config for no client auth server for http3
if let Err(e) = resolver_global.add(server_name.as_str(), certified_key) {
error!("{}: Failed to read some certificates and keys {}", server_name.as_str(), e)
}
}
}
//////////////
let mut server_crypto_global = ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver_global));
//////////////////////////////
server_crypto_global.alpn_protocols = vec![
b"h3".to_vec(),
b"hq-29".to_vec(), // TODO: remove later?
b"h2".to_vec(),
b"http/1.1".to_vec(),
];
Ok(server_crypto_global)
}
#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))]
fn build_server_crypto_global(&self) -> Result<s2n_quic_rustls::Server, ReloaderError<ServerCryptoBase>> {
let mut resolver_global = s2n_quic_rustls::rustls::server::ResolvesServerCertUsingSni::new();
for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() {
let server_name: String = server_name_bytes_exp.try_into()?;
// Parse server certificates and private keys
let Ok(certified_key) = parse_server_certs_and_keys_s2n(certs_and_keys) else {
warn!("Failed to add certificate for {}", server_name);
continue;
};
if certs_and_keys.client_ca_certs.is_none() {
// aggregated server config for no client auth server for http3
if let Err(e) = resolver_global.add(server_name.as_str(), certified_key) {
error!("{}: Failed to read some certificates and keys {}", server_name.as_str(), e)
}
}
}
let alpn = [
b"h3".to_vec(),
b"hq-29".to_vec(), // TODO: remove later?
b"h2".to_vec(),
b"http/1.1".to_vec(),
];
let server_crypto_global = s2n_quic::provider::tls::rustls::Server::builder()
.with_cert_resolver(Arc::new(resolver_global))
.map_err(|e| anyhow::anyhow!(e))?
.with_application_protocols(alpn.iter())
.map_err(|e| anyhow::anyhow!(e))?
.build()
.map_err(|e| anyhow::anyhow!(e))?;
Ok(server_crypto_global)
}
}
#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))]
/// This is workaround for the version difference between rustls and s2n-quic-rustls
fn parse_server_certs_and_keys_s2n(
certs_and_keys: &CertsAndKeys,
) -> Result<s2n_quic_rustls::rustls::sign::CertifiedKey, anyhow::Error> {
let signing_key = certs_and_keys
.cert_keys
.iter()
.find_map(|k| {
let s2n_private_key = s2n_quic_rustls::PrivateKey(k.0.clone());
if let Ok(sk) = s2n_quic_rustls::rustls::sign::any_supported_type(&s2n_private_key) {
Some(sk)
} else {
None
}
})
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidInput, "Unable to find a valid certificate and key"))?;
let certs: Vec<_> = certs_and_keys
.certs
.iter()
.map(|c| s2n_quic_rustls::rustls::Certificate(c.0.clone()))
.collect();
Ok(s2n_quic_rustls::rustls::sign::CertifiedKey::new(certs, signing_key))
}

View file

@ -16,6 +16,10 @@ pub enum RpxyError {
NoServerNameInClientHello,
#[error("No TLS serving app: {0}")]
NoTlsServingApp(String),
#[error("No default crypto provider")]
NoDefaultCryptoProvider,
#[error("Failed to build server config: {0}")]
FailedToBuildServerConfig(String),
#[error("Failed to update server crypto: {0}")]
FailedToUpdateServerCrypto(String),
#[error("No server crypto: {0}")]
@ -60,7 +64,7 @@ pub enum RpxyError {
#[error("No certificate reloader when building a proxy for TLS")]
NoCertificateReloader,
#[error("Certificate reload error: {0}")]
CertificateReloadError(#[from] hot_reload::ReloaderError<crate::crypto::ServerCryptoBase>),
CertificateReloadError(#[from] hot_reload::ReloaderError<rpxy_certs::ServerCryptoBase>),
// backend errors
#[error("Invalid reverse proxy setting")]

View file

@ -1,9 +1,6 @@
use crate::{
constants::*,
count::RequestCount,
crypto::{CryptoSource, ServerCryptoBase},
};
use crate::{constants::*, count::RequestCount};
use hot_reload::ReloaderReceiver;
use rpxy_certs::ServerCryptoBase;
use std::{net::SocketAddr, sync::Arc, time::Duration};
/// Global object containing proxy configurations and shared object like counters.
@ -17,7 +14,7 @@ pub struct Globals {
pub runtime_handle: tokio::runtime::Handle,
/// Shared context - Notify object to stop async tasks
pub term_notify: Option<Arc<tokio::sync::Notify>>,
/// Shared context - Certificate reloader service receiver
/// Shared context - Certificate reloader service receiver // TODO: newer one
pub cert_reloader_rx: Option<ReloaderReceiver<ServerCryptoBase>>,
}
@ -127,24 +124,18 @@ impl Default for ProxyConfig {
/// Configuration parameters for backend applications
#[derive(PartialEq, Eq, Clone)]
pub struct AppConfigList<T>
where
T: CryptoSource,
{
pub inner: Vec<AppConfig<T>>,
pub struct AppConfigList {
pub inner: Vec<AppConfig>,
pub default_app: Option<String>,
}
/// Configuration parameters for single backend application
#[derive(PartialEq, Eq, Clone)]
pub struct AppConfig<T>
where
T: CryptoSource,
{
pub struct AppConfig {
pub app_name: String,
pub server_name: String,
pub reverse_proxy: Vec<ReverseProxyConfig>,
pub tls: Option<TlsConfig<T>>,
pub tls: Option<TlsConfig>,
}
/// Configuration parameters for single reverse proxy corresponding to the path
@ -165,10 +156,7 @@ pub struct UpstreamUri {
/// Configuration parameters on TLS for a single backend application
#[derive(PartialEq, Eq, Clone)]
pub struct TlsConfig<T>
where
T: CryptoSource,
{
pub inner: T,
pub struct TlsConfig {
pub mutual_tls: bool,
pub https_redirection: bool,
}

View file

@ -1,7 +1,6 @@
mod backend;
mod constants;
mod count;
mod crypto;
mod error;
mod forwarder;
mod globals;
@ -10,33 +9,35 @@ mod log;
mod message_handler;
mod name_exp;
mod proxy;
/* ------------------------------------------------ */
use crate::{
crypto::build_cert_reloader, error::*, forwarder::Forwarder, globals::Globals, log::*,
message_handler::HttpMessageHandlerBuilder, proxy::Proxy,
// crypto::build_cert_reloader,
error::*,
forwarder::Forwarder,
globals::Globals,
log::*,
message_handler::HttpMessageHandlerBuilder,
proxy::Proxy,
};
use futures::future::select_all;
use hot_reload::ReloaderReceiver;
use rpxy_certs::ServerCryptoBase;
use std::sync::Arc;
pub use crate::{
crypto::{CertsAndKeys, CryptoSource},
globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri},
};
/* ------------------------------------------------ */
pub use crate::globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri};
pub mod reexports {
pub use hyper::Uri;
pub use rustls::{Certificate, PrivateKey};
}
/// Entrypoint that creates and spawns tasks of reverse proxy services
pub async fn entrypoint<T>(
pub async fn entrypoint(
proxy_config: &ProxyConfig,
app_config_list: &AppConfigList<T>,
app_config_list: &AppConfigList,
cert_rx: Option<&ReloaderReceiver<ServerCryptoBase>>, // TODO:
runtime_handle: &tokio::runtime::Handle,
term_notify: Option<Arc<tokio::sync::Notify>>,
) -> RpxyResult<()>
where
T: CryptoSource + Clone + Send + Sync + 'static,
{
) -> RpxyResult<()> {
#[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))]
warn!("Both \"http3-quinn\" and \"http3-s2n\" features are enabled. \"http3-quinn\" will be used");
@ -78,25 +79,16 @@ where
// 1. build backends, and make it contained in Arc
let app_manager = Arc::new(backend::BackendAppManager::try_from(app_config_list)?);
// 2. build crypto reloader service
let (cert_reloader_service, cert_reloader_rx) = match proxy_config.https_port {
Some(_) => {
let (s, r) = build_cert_reloader(&app_manager).await?;
(Some(s), Some(r))
}
None => (None, None),
};
// 3. build global shared context
// 2. build global shared context
let globals = Arc::new(Globals {
proxy_config: proxy_config.clone(),
request_count: Default::default(),
runtime_handle: runtime_handle.clone(),
term_notify: term_notify.clone(),
cert_reloader_rx: cert_reloader_rx.clone(),
cert_reloader_rx: cert_rx.cloned(),
});
// 4. build message handler containing Arc-ed http_client and backends, and make it contained in Arc as well
// 3. build message handler containing Arc-ed http_client and backends, and make it contained in Arc as well
let forwarder = Arc::new(Forwarder::try_new(&globals).await?);
let message_handler = Arc::new(
HttpMessageHandlerBuilder::default()
@ -106,7 +98,7 @@ where
.build()?,
);
// 5. spawn each proxy for a given socket with copied Arc-ed message_handler.
// 4. spawn each proxy for a given socket with copied Arc-ed message_handler.
// build hyper connection builder shared with proxy instances
let connection_builder = proxy::connection_builder(&globals);
@ -127,23 +119,9 @@ where
globals.runtime_handle.spawn(async move { proxy.start().await })
});
// wait for all future
match cert_reloader_service {
Some(cert_service) => {
tokio::select! {
_ = cert_service.start() => {
error!("Certificate reloader service got down");
}
_ = select_all(futures_iter) => {
error!("Some proxy services are down");
}
}
}
None => {
if let (Ok(Err(e)), _, _) = select_all(futures_iter).await {
error!("Some proxy services are down: {}", e);
}
}
if let (Ok(Err(e)), _, _) = select_all(futures_iter).await {
error!("Some proxy services are down: {}", e);
return Err(e);
}
Ok(())

View file

@ -7,7 +7,6 @@ use super::{
};
use crate::{
backend::{BackendAppManager, LoadBalanceContext},
crypto::CryptoSource,
error::*,
forwarder::{ForwardRequest, Forwarder},
globals::Globals,
@ -34,20 +33,18 @@ pub(super) struct HandlerContext {
#[derive(Clone, Builder)]
/// HTTP message handler for requests from clients and responses from backend applications,
/// responsible to manipulate and forward messages to upstream backends and downstream clients.
pub struct HttpMessageHandler<U, C>
pub struct HttpMessageHandler<C>
where
C: Send + Sync + Connect + Clone + 'static,
U: CryptoSource + Clone,
{
forwarder: Arc<Forwarder<C>>,
pub(super) globals: Arc<Globals>,
app_manager: Arc<BackendAppManager<U>>,
app_manager: Arc<BackendAppManager>,
}
impl<U, C> HttpMessageHandler<U, C>
impl<C> HttpMessageHandler<C>
where
C: Send + Sync + Connect + Clone + 'static,
U: CryptoSource + Clone,
{
/// Handle incoming request message from a client.
/// Responsible to passthrough responses from backend applications or generate synthetic error responses.
@ -64,14 +61,7 @@ where
log_data.client_addr(&client_addr);
let http_result = self
.handle_request_inner(
&mut log_data,
req,
client_addr,
listen_addr,
tls_enabled,
tls_server_name,
)
.handle_request_inner(&mut log_data, req, client_addr, listen_addr, tls_enabled, tls_server_name)
.await;
// passthrough or synthetic response

View file

@ -3,17 +3,15 @@ use crate::{
backend::{BackendApp, UpstreamCandidates},
constants::RESPONSE_HEADER_SERVER,
log::*,
CryptoSource,
};
use anyhow::{anyhow, ensure, Result};
use http::{header, HeaderValue, Request, Response, Uri};
use hyper_util::client::legacy::connect::Connect;
use std::net::SocketAddr;
impl<U, C> HttpMessageHandler<U, C>
impl<C> HttpMessageHandler<C>
where
C: Send + Sync + Connect + Clone + 'static,
U: CryptoSource + Clone,
{
////////////////////////////////////////////////////
// Functions to generate messages
@ -21,7 +19,7 @@ where
#[allow(unused_variables)]
/// Manipulate a response message sent from a backend application to forward downstream to a client.
pub(super) fn generate_response_forwarded<B>(&self, response: &mut Response<B>, backend_app: &BackendApp<U>) -> Result<()> {
pub(super) fn generate_response_forwarded<B>(&self, response: &mut Response<B>, backend_app: &BackendApp) -> Result<()> {
let headers = response.headers_mut();
remove_connection_header(headers);
remove_hop_header(headers);
@ -31,15 +29,15 @@ where
{
// Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled
// TODO: This is a workaround for avoiding a client authentication in HTTP/3
if self.globals.proxy_config.http3 && backend_app.crypto_source.as_ref().is_some_and(|v| !v.is_mutual_tls()) {
if self.globals.proxy_config.http3
&& backend_app.https_redirection.is_some()
&& backend_app.mutual_tls.as_ref().is_some_and(|v| !v)
{
if let Some(port) = self.globals.proxy_config.https_port {
add_header_entry_overwrite_if_exist(
headers,
header::ALT_SVC.as_str(),
format!(
"h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}",
port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age
),
format!("h3=\":{}\"; ma={}", port, self.globals.proxy_config.h3_alt_svc_max_age),
)?;
}
} else {

View file

@ -11,10 +11,16 @@ mod proxy_quic_s2n;
use crate::{
globals::Globals,
hyper_ext::rt::{LocalExecutor, TokioTimer},
name_exp::ServerName,
};
use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder};
use rustc_hash::FxHashMap as HashMap;
use rustls::ServerConfig;
use std::sync::Arc;
/// SNI to ServerConfig map type
pub type SniServerCryptoMap = HashMap<ServerName, Arc<ServerConfig>>;
pub(crate) use proxy_main::Proxy;
/// build connection builder shared with proxy instances

View file

@ -1,6 +1,5 @@
use super::proxy_main::Proxy;
use crate::{
crypto::CryptoSource,
error::*,
hyper_ext::body::{IncomingLike, RequestBody},
log::*,
@ -17,10 +16,9 @@ use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestSt
#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))]
use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream};
impl<U, T> Proxy<U, T>
impl<T> Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{
pub(super) async fn h3_serve_connection<C>(
&self,

View file

@ -1,7 +1,6 @@
use super::socket::bind_tcp_socket;
use crate::{
constants::TLS_HANDSHAKE_TIMEOUT_SEC,
crypto::{CryptoSource, ServerCrypto, SniServerCryptoMap},
error::*,
globals::Globals,
hyper_ext::{
@ -20,14 +19,15 @@ use hyper::{
service::service_fn,
};
use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder};
use rpxy_certs::ServerCrypto;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::time::timeout;
/// Wrapper function to handle request for HTTP/1.1 and HTTP/2
/// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler
async fn serve_request<U, T>(
async fn serve_request<T>(
req: Request<Incoming>,
handler: Arc<HttpMessageHandler<U, T>>,
handler: Arc<HttpMessageHandler<T>>,
client_addr: SocketAddr,
listen_addr: SocketAddr,
tls_enabled: bool,
@ -35,7 +35,6 @@ async fn serve_request<U, T>(
) -> RpxyResult<Response<ResponseBody>>
where
T: Send + Sync + Connect + Clone,
U: CryptoSource + Clone,
{
handler
.handle_request(
@ -50,10 +49,9 @@ where
#[derive(Clone)]
/// Proxy main object responsible to serve requests received from clients at the given socket address.
pub(crate) struct Proxy<U, T, E = LocalExecutor>
pub(crate) struct Proxy<T, E = LocalExecutor>
where
T: Send + Sync + Connect + Clone + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{
/// global context shared among async tasks
pub globals: Arc<Globals>,
@ -64,13 +62,12 @@ where
/// hyper connection builder serving http request
pub connection_builder: Arc<ConnectionBuilder<E>>,
/// message handler serving incoming http request
pub message_handler: Arc<HttpMessageHandler<U, T>>,
pub message_handler: Arc<HttpMessageHandler<T>>,
}
impl<U, T> Proxy<U, T>
impl<T> Proxy<T>
where
T: Send + Sync + Connect + Clone + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{
/// Serves requests from clients
fn serve_connection<I>(&self, stream: I, peer_addr: SocketAddr, tls_server_name: Option<ServerName>)
@ -168,7 +165,7 @@ where
let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?;
info!("Start TCP proxy serving with HTTPS request for configured host names");
let mut server_crypto_map: Option<Arc<SniServerCryptoMap>> = None;
let mut server_crypto_map: Option<Arc<super::SniServerCryptoMap>> = None;
loop {
select! {
tcp_cnx = tcp_listener.accept().fuse() => {
@ -230,12 +227,16 @@ where
error!("Reloader is broken");
break;
}
let cert_keys_map = server_crypto_rx.borrow().clone().unwrap();
let Some(server_crypto): Option<Arc<ServerCrypto>> = (&cert_keys_map).try_into().ok() else {
let server_crypto_base = server_crypto_rx.borrow().clone().unwrap();
let Some(server_config): Option<Arc<ServerCrypto>> = (&server_crypto_base).try_into().ok() else {
error!("Failed to update server crypto");
break;
};
server_crypto_map = Some(server_crypto.inner_local_map.clone());
let map = server_config.individual_config_map.clone().iter().map(|(k,v)| {
let server_name = ServerName::from(k.as_slice());
(server_name, v.clone())
}).collect::<rustc_hash::FxHashMap<_,_>>();
server_crypto_map = Some(Arc::new(map));
}
}
}

View file

@ -1,20 +1,17 @@
use super::proxy_main::Proxy;
use super::socket::bind_udp_socket;
use crate::{
crypto::{CryptoSource, ServerCrypto},
error::*,
log::*,
name_exp::ByteName,
};
use super::{proxy_main::Proxy, socket::bind_udp_socket};
use crate::{error::*, log::*, name_exp::ByteName};
use hyper_util::client::legacy::connect::Connect;
use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig};
use quinn::{
crypto::rustls::{HandshakeData, QuicServerConfig},
Endpoint, TransportConfig,
};
use rpxy_certs::ServerCrypto;
use rustls::ServerConfig;
use std::sync::Arc;
impl<U, T> Proxy<U, T>
impl<T> Proxy<T>
where
T: Send + Sync + Connect + Clone + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{
pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> {
let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else {
@ -22,13 +19,14 @@ where
};
info!("Start UDP proxy serving with HTTP/3 request for configured host names [quinn]");
// first set as null config server
let rustls_server_config = ServerConfig::builder()
.with_safe_default_cipher_suites()
.with_safe_default_kx_groups()
// AWS LC provider by default
let provider = rustls::crypto::CryptoProvider::get_default().ok_or(RpxyError::NoDefaultCryptoProvider)?;
let rustls_server_config = ServerConfig::builder_with_provider(provider.clone())
.with_protocol_versions(&[&rustls::version::TLS13])
.map_err(|e| RpxyError::QuinnInvalidTlsProtocolVersion(e.to_string()))?
.map_err(|e| RpxyError::FailedToBuildServerConfig(format!("TLS 1.3 server config failed: {e}")))?
.with_no_client_auth()
.with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new()));
let quinn_server_config_crypto = QuicServerConfig::try_from(Arc::new(rustls_server_config)).unwrap();
let mut transport_config_quic = TransportConfig::default();
transport_config_quic
@ -42,20 +40,15 @@ where
.map(|v| quinn::IdleTimeout::try_from(v).unwrap()),
);
let mut server_config_h3 = QuicServerConfig::with_crypto(Arc::new(rustls_server_config));
let mut server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(quinn_server_config_crypto));
server_config_h3.transport = Arc::new(transport_config_quic);
server_config_h3.concurrent_connections(self.globals.proxy_config.h3_max_concurrent_connections);
server_config_h3.max_incoming(self.globals.proxy_config.h3_max_concurrent_connections as usize);
// To reuse address
let udp_socket = bind_udp_socket(&self.listening_on)?;
let runtime = quinn::default_runtime()
.ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "No async runtime found"))?;
let endpoint = Endpoint::new(
quinn::EndpointConfig::default(),
Some(server_config_h3),
udp_socket,
runtime,
)?;
let runtime =
quinn::default_runtime().ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "No async runtime found"))?;
let endpoint = Endpoint::new(quinn::EndpointConfig::default(), Some(server_config_h3), udp_socket, runtime)?;
let mut server_crypto: Option<Arc<ServerCrypto>> = None;
loop {
@ -64,8 +57,10 @@ where
if server_crypto.is_none() || new_conn.is_none() {
continue;
}
let mut conn: quinn::Connecting = new_conn.unwrap();
let Ok(hsd) = conn.handshake_data().await else {
let Ok(mut incoming) = new_conn.unwrap().accept() else {
continue
};
let Ok(hsd) = incoming.handshake_data().await else {
continue
};
@ -84,8 +79,8 @@ where
// TODO: 通常のTLSと同じenumか何かにまとめたい
let self_clone = self.clone();
self.globals.runtime_handle.spawn(async move {
let client_addr = conn.remote_address();
let quic_connection = match conn.await {
let client_addr = incoming.remote_address();
let quic_connection = match incoming.await {
Ok(new_conn) => {
info!("New connection established");
h3_quinn::Connection::new(new_conn)
@ -114,8 +109,12 @@ where
error!("Failed to update server crypto for h3");
break;
};
endpoint.set_server_config(Some(QuicServerConfig::with_crypto(inner.clone().inner_global_no_client_auth.clone())));
let rustls_server_config = inner.aggregated_config_no_client_auth.clone();
let Ok(quinn_server_config_crypto) = QuicServerConfig::try_from(rustls_server_config) else {
error!("Failed to update server crypto for h3");
break;
};
endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(Arc::new(quinn_server_config_crypto))));
}
else => break
}

View file

@ -1,21 +1,15 @@
use super::proxy_main::Proxy;
use crate::{
crypto::CryptoSource,
crypto::{ServerCrypto, ServerCryptoBase},
error::*,
log::*,
name_exp::ByteName,
};
use crate::{error::*, log::*, name_exp::ByteName};
use anyhow::anyhow;
use hot_reload::ReloaderReceiver;
use hyper_util::client::legacy::connect::Connect;
use rpxy_certs::{ServerCrypto, ServerCryptoBase};
use s2n_quic::provider;
use std::sync::Arc;
impl<U, T> Proxy<U, T>
impl<T> Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{
/// Start UDP proxy serving with HTTP/3 request for configured host names
pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> {
@ -25,7 +19,7 @@ where
info!("Start UDP proxy serving with HTTP/3 request for configured host names [s2n-quic]");
// initially wait for receipt
let mut server_crypto: Option<Arc<ServerCrypto>> = {
let mut server_crypto: Option<s2n_quic_rustls::Server> = {
let _ = server_crypto_rx.changed().await;
let sc = self.receive_server_crypto(server_crypto_rx.clone())?;
Some(sc)
@ -57,16 +51,24 @@ where
}
/// Receive server crypto from reloader
fn receive_server_crypto(
&self,
server_crypto_rx: ReloaderReceiver<ServerCryptoBase>,
) -> RpxyResult<Arc<ServerCrypto>> {
fn receive_server_crypto(&self, server_crypto_rx: ReloaderReceiver<ServerCryptoBase>) -> RpxyResult<s2n_quic_rustls::Server> {
let cert_keys_map = server_crypto_rx.borrow().clone().ok_or_else(|| {
error!("Reloader is broken");
RpxyError::CertificateReloadError(anyhow!("Reloader is broken").into())
})?;
let server_crypto: Option<Arc<ServerCrypto>> = (&cert_keys_map).try_into().ok();
let server_crypto: Option<s2n_quic_rustls::Server> = (&cert_keys_map).try_into().ok().and_then(|v: Arc<ServerCrypto>| {
let rustls_server_config = v.aggregated_config_no_client_auth.clone();
let resolver = rustls_server_config.cert_resolver.clone();
let alpn = rustls_server_config.alpn_protocols.clone();
#[allow(deprecated)]
let tls = provider::tls::rustls::server::Builder::default()
.with_cert_resolver(resolver)
.and_then(|t| t.with_application_protocols(alpn.iter()))
.and_then(|t| t.build())
.ok();
tls
});
server_crypto.ok_or_else(|| {
error!("Failed to update server crypto for h3 [s2n-quic]");
RpxyError::FailedToUpdateServerCrypto("Failed to update server crypto for h3 [s2n-quic]".to_string())
@ -74,7 +76,7 @@ where
}
/// Event loop for UDP proxy serving with HTTP/3 request for configured host names
async fn h3_listener_service_inner(&self, server_crypto: &Option<Arc<ServerCrypto>>) -> RpxyResult<()> {
async fn h3_listener_service_inner(&self, server_crypto: &Option<s2n_quic_rustls::Server>) -> RpxyResult<()> {
// setup UDP socket
let io = provider::io::tokio::Builder::default()
.with_receive_address(self.listening_on)?
@ -97,14 +99,11 @@ where
// setup tls
let Some(server_crypto) = server_crypto else {
warn!("No server crypto is given [s2n-quic]");
return Err(RpxyError::NoServerCrypto(
"No server crypto is given [s2n-quic]".to_string(),
));
return Err(RpxyError::NoServerCrypto("No server crypto is given [s2n-quic]".to_string()));
};
let tls = server_crypto.inner_global_no_client_auth.clone();
let mut server = s2n_quic::Server::builder()
.with_tls(tls)?
.with_tls(server_crypto.to_owned())?
.with_io(io)?
.with_limits(limits)?
.start()?;

@ -1 +0,0 @@
Subproject commit b44edeb60d234d49c45828395108f7519a048d4b

1
submodules/s2n-quic Submodule

@ -0,0 +1 @@
Subproject commit a3e8d34d74aa653acb53eff53781796db2fd6e39

View file

@ -1,17 +0,0 @@
[package]
name = "s2n-quic-h3"
# this in an unpublished internal crate so the version should not be changed
version = "0.1.0"
authors = ["AWS s2n"]
edition = "2021"
rust-version = "1.63"
license = "Apache-2.0"
# this contains an http3 implementation for testing purposes and should not be published
publish = false
[dependencies]
bytes = { version = "1", default-features = false }
futures = { version = "0.3", default-features = false }
h3 = { path = "../h3/h3/" }
s2n-quic = "1.37.0"
s2n-quic-core = "0.37.0"

View file

@ -1,10 +0,0 @@
# s2n-quic-h3
This is an internal crate used by [s2n-quic](https://github.com/aws/s2n-quic) written as a proof of concept for implementing HTTP3 on top of s2n-quic. The API is not currently stable and should not be used directly.
## License
This project is licensed under the [Apache-2.0 License][license-url].
[license-badge]: https://img.shields.io/badge/license-apache-blue.svg
[license-url]: https://aws.amazon.com/apache-2-0/

View file

@ -1,7 +0,0 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
mod s2n_quic;
pub use self::s2n_quic::*;
pub use h3;

View file

@ -1,506 +0,0 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
use bytes::{Buf, Bytes};
use futures::ready;
use h3::quic::{self, Error, StreamId, WriteBuf};
use s2n_quic::stream::{BidirectionalStream, ReceiveStream};
use s2n_quic_core::varint::VarInt;
use std::{
convert::TryInto,
fmt::{self, Display},
sync::Arc,
task::{self, Poll},
};
pub struct Connection {
conn: s2n_quic::connection::Handle,
bidi_acceptor: s2n_quic::connection::BidirectionalStreamAcceptor,
recv_acceptor: s2n_quic::connection::ReceiveStreamAcceptor,
}
impl Connection {
pub fn new(new_conn: s2n_quic::Connection) -> Self {
let (handle, acceptor) = new_conn.split();
let (bidi, recv) = acceptor.split();
Self {
conn: handle,
bidi_acceptor: bidi,
recv_acceptor: recv,
}
}
}
#[derive(Debug)]
pub struct ConnectionError(s2n_quic::connection::Error);
impl std::error::Error for ConnectionError {}
impl fmt::Display for ConnectionError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl Error for ConnectionError {
fn is_timeout(&self) -> bool {
matches!(self.0, s2n_quic::connection::Error::IdleTimerExpired { .. })
}
fn err_code(&self) -> Option<u64> {
match self.0 {
s2n_quic::connection::Error::Application { error, .. } => Some(error.into()),
_ => None,
}
}
}
impl From<s2n_quic::connection::Error> for ConnectionError {
fn from(e: s2n_quic::connection::Error) -> Self {
Self(e)
}
}
impl<B> quic::Connection<B> for Connection
where
B: Buf,
{
type BidiStream = BidiStream<B>;
type SendStream = SendStream<B>;
type RecvStream = RecvStream;
type OpenStreams = OpenStreams;
type Error = ConnectionError;
fn poll_accept_recv(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Option<Self::RecvStream>, Self::Error>> {
let recv = match ready!(self.recv_acceptor.poll_accept_receive_stream(cx))? {
Some(x) => x,
None => return Poll::Ready(Ok(None)),
};
Poll::Ready(Ok(Some(Self::RecvStream::new(recv))))
}
fn poll_accept_bidi(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Option<Self::BidiStream>, Self::Error>> {
let (recv, send) = match ready!(self.bidi_acceptor.poll_accept_bidirectional_stream(cx))? {
Some(x) => x.split(),
None => return Poll::Ready(Ok(None)),
};
Poll::Ready(Ok(Some(Self::BidiStream {
send: Self::SendStream::new(send),
recv: Self::RecvStream::new(recv),
})))
}
fn poll_open_bidi(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Self::BidiStream, Self::Error>> {
let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?;
Ok(stream.into()).into()
}
fn poll_open_send(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Self::SendStream, Self::Error>> {
let stream = ready!(self.conn.poll_open_send_stream(cx))?;
Ok(stream.into()).into()
}
fn opener(&self) -> Self::OpenStreams {
OpenStreams {
conn: self.conn.clone(),
}
}
fn close(&mut self, code: h3::error::Code, _reason: &[u8]) {
self.conn.close(
code.value()
.try_into()
.expect("s2n-quic supports error codes up to 2^62-1"),
);
}
}
pub struct OpenStreams {
conn: s2n_quic::connection::Handle,
}
impl<B> quic::OpenStreams<B> for OpenStreams
where
B: Buf,
{
type BidiStream = BidiStream<B>;
type SendStream = SendStream<B>;
type RecvStream = RecvStream;
type Error = ConnectionError;
fn poll_open_bidi(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Self::BidiStream, Self::Error>> {
let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?;
Ok(stream.into()).into()
}
fn poll_open_send(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Self::SendStream, Self::Error>> {
let stream = ready!(self.conn.poll_open_send_stream(cx))?;
Ok(stream.into()).into()
}
fn close(&mut self, code: h3::error::Code, _reason: &[u8]) {
self.conn.close(
code.value()
.try_into()
.unwrap_or_else(|_| VarInt::MAX.into()),
);
}
}
impl Clone for OpenStreams {
fn clone(&self) -> Self {
Self {
conn: self.conn.clone(),
}
}
}
pub struct BidiStream<B>
where
B: Buf,
{
send: SendStream<B>,
recv: RecvStream,
}
impl<B> quic::BidiStream<B> for BidiStream<B>
where
B: Buf,
{
type SendStream = SendStream<B>;
type RecvStream = RecvStream;
fn split(self) -> (Self::SendStream, Self::RecvStream) {
(self.send, self.recv)
}
}
impl<B> quic::RecvStream for BidiStream<B>
where
B: Buf,
{
type Buf = Bytes;
type Error = ReadError;
fn poll_data(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
self.recv.poll_data(cx)
}
fn stop_sending(&mut self, error_code: u64) {
self.recv.stop_sending(error_code)
}
fn recv_id(&self) -> StreamId {
self.recv.stream.id().try_into().expect("invalid stream id")
}
}
impl<B> quic::SendStream<B> for BidiStream<B>
where
B: Buf,
{
type Error = SendStreamError;
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.send.poll_ready(cx)
}
fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
self.send.poll_finish(cx)
}
fn reset(&mut self, reset_code: u64) {
self.send.reset(reset_code)
}
fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
self.send.send_data(data)
}
fn send_id(&self) -> StreamId {
self.send.stream.id().try_into().expect("invalid stream id")
}
}
impl<B> From<BidirectionalStream> for BidiStream<B>
where
B: Buf,
{
fn from(bidi: BidirectionalStream) -> Self {
let (recv, send) = bidi.split();
BidiStream {
send: send.into(),
recv: recv.into(),
}
}
}
pub struct RecvStream {
stream: s2n_quic::stream::ReceiveStream,
}
impl RecvStream {
fn new(stream: s2n_quic::stream::ReceiveStream) -> Self {
Self { stream }
}
}
impl quic::RecvStream for RecvStream {
type Buf = Bytes;
type Error = ReadError;
fn poll_data(
&mut self,
cx: &mut task::Context<'_>,
) -> Poll<Result<Option<Self::Buf>, Self::Error>> {
let buf = ready!(self.stream.poll_receive(cx))?;
Ok(buf).into()
}
fn stop_sending(&mut self, error_code: u64) {
let _ = self.stream.stop_sending(
s2n_quic::application::Error::new(error_code)
.expect("s2n-quic supports error codes up to 2^62-1"),
);
}
fn recv_id(&self) -> StreamId {
self.stream.id().try_into().expect("invalid stream id")
}
}
impl From<ReceiveStream> for RecvStream {
fn from(recv: ReceiveStream) -> Self {
RecvStream::new(recv)
}
}
#[derive(Debug)]
pub struct ReadError(s2n_quic::stream::Error);
impl std::error::Error for ReadError {}
impl fmt::Display for ReadError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}
impl From<ReadError> for Arc<dyn Error> {
fn from(e: ReadError) -> Self {
Arc::new(e)
}
}
impl From<s2n_quic::stream::Error> for ReadError {
fn from(e: s2n_quic::stream::Error) -> Self {
Self(e)
}
}
impl Error for ReadError {
fn is_timeout(&self) -> bool {
matches!(
self.0,
s2n_quic::stream::Error::ConnectionError {
error: s2n_quic::connection::Error::IdleTimerExpired { .. },
..
}
)
}
fn err_code(&self) -> Option<u64> {
match self.0 {
s2n_quic::stream::Error::ConnectionError {
error: s2n_quic::connection::Error::Application { error, .. },
..
} => Some(error.into()),
s2n_quic::stream::Error::StreamReset { error, .. } => Some(error.into()),
_ => None,
}
}
}
pub struct SendStream<B: Buf> {
stream: s2n_quic::stream::SendStream,
chunk: Option<Bytes>,
buf: Option<WriteBuf<B>>, // TODO: Replace with buf: PhantomData<B>
// after https://github.com/hyperium/h3/issues/78 is resolved
}
impl<B> SendStream<B>
where
B: Buf,
{
fn new(stream: s2n_quic::stream::SendStream) -> SendStream<B> {
Self {
stream,
chunk: None,
buf: Default::default(),
}
}
}
impl<B> quic::SendStream<B> for SendStream<B>
where
B: Buf,
{
type Error = SendStreamError;
fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
loop {
// try to flush the current chunk if we have one
if let Some(chunk) = self.chunk.as_mut() {
ready!(self.stream.poll_send(chunk, cx))?;
// s2n-quic will take the whole chunk on send, even if it exceeds the limits
debug_assert!(chunk.is_empty());
self.chunk = None;
}
// try to take the next chunk from the WriteBuf
if let Some(ref mut data) = self.buf {
let len = data.chunk().len();
// if the write buf is empty, then clear it and break
if len == 0 {
self.buf = None;
break;
}
// copy the first chunk from WriteBuf and prepare it to flush
let chunk = data.copy_to_bytes(len);
self.chunk = Some(chunk);
// loop back around to flush the chunk
continue;
}
// if we didn't have either a chunk or WriteBuf, then we're ready
break;
}
Poll::Ready(Ok(()))
// TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved
// self.available_bytes = ready!(self.stream.poll_send_ready(cx))?;
// Poll::Ready(Ok(()))
}
fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> {
if self.buf.is_some() {
return Err(Self::Error::NotReady);
}
self.buf = Some(data.into());
Ok(())
// TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved
// let mut data = data.into();
// while self.available_bytes > 0 && data.has_remaining() {
// let len = data.chunk().len();
// let chunk = data.copy_to_bytes(len);
// self.stream.send_data(chunk)?;
// self.available_bytes = self.available_bytes.saturating_sub(len);
// }
// Ok(())
}
fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
// ensure all chunks are flushed to the QUIC stream before finishing
ready!(self.poll_ready(cx))?;
self.stream.finish()?;
Ok(()).into()
}
fn reset(&mut self, reset_code: u64) {
let _ = self
.stream
.reset(reset_code.try_into().unwrap_or_else(|_| VarInt::MAX.into()));
}
fn send_id(&self) -> StreamId {
self.stream.id().try_into().expect("invalid stream id")
}
}
impl<B> From<s2n_quic::stream::SendStream> for SendStream<B>
where
B: Buf,
{
fn from(send: s2n_quic::stream::SendStream) -> Self {
SendStream::new(send)
}
}
#[derive(Debug)]
pub enum SendStreamError {
Write(s2n_quic::stream::Error),
NotReady,
}
impl std::error::Error for SendStreamError {}
impl Display for SendStreamError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:?}")
}
}
impl From<s2n_quic::stream::Error> for SendStreamError {
fn from(e: s2n_quic::stream::Error) -> Self {
Self::Write(e)
}
}
impl Error for SendStreamError {
fn is_timeout(&self) -> bool {
matches!(
self,
Self::Write(s2n_quic::stream::Error::ConnectionError {
error: s2n_quic::connection::Error::IdleTimerExpired { .. },
..
})
)
}
fn err_code(&self) -> Option<u64> {
match self {
Self::Write(s2n_quic::stream::Error::StreamReset { error, .. }) => {
Some((*error).into())
}
Self::Write(s2n_quic::stream::Error::ConnectionError {
error: s2n_quic::connection::Error::Application { error, .. },
..
}) => Some((*error).into()),
_ => None,
}
}
}
impl From<SendStreamError> for Arc<dyn Error> {
fn from(e: SendStreamError) -> Self {
Arc::new(e)
}
}