refactor: add cert_reader object in backend

This commit is contained in:
Jun Kurihara 2023-07-12 20:31:31 +09:00
commit 6c0fd85ca5
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
11 changed files with 96 additions and 40 deletions

View file

@ -13,14 +13,20 @@ pub use self::{
upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder},
upstream_opts::UpstreamOption, upstream_opts::UpstreamOption,
}; };
use crate::utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}; use crate::{
certs::CryptoSource,
utils::{BytesName, PathNameBytesExp, ServerNameBytesExp},
};
use derive_builder::Builder; use derive_builder::Builder;
use rustc_hash::FxHashMap as HashMap; use rustc_hash::FxHashMap as HashMap;
use std::{borrow::Cow, path::PathBuf}; use std::{borrow::Cow, path::PathBuf};
/// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings. /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
#[derive(Builder)] #[derive(Builder)]
pub struct Backend { pub struct Backend<T>
where
T: CryptoSource,
{
#[builder(setter(into))] #[builder(setter(into))]
/// backend application name, e.g., app1 /// backend application name, e.g., app1
pub app_name: String, pub app_name: String,
@ -39,8 +45,14 @@ pub struct Backend {
pub https_redirection: Option<bool>, pub https_redirection: Option<bool>,
#[builder(setter(custom), default)] #[builder(setter(custom), default)]
pub client_ca_cert_path: Option<PathBuf>, pub client_ca_cert_path: Option<PathBuf>,
#[builder(default)]
pub crypto_source: Option<T>,
} }
impl<'a> BackendBuilder { impl<'a, T> BackendBuilder<T>
where
T: CryptoSource,
{
pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self { pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self {
self.server_name = Some(server_name.into().to_ascii_lowercase()); self.server_name = Some(server_name.into().to_ascii_lowercase());
self self
@ -63,9 +75,23 @@ fn opt_string_to_opt_pathbuf(input: &Option<String>) -> Option<PathBuf> {
input.to_owned().as_ref().map(PathBuf::from) input.to_owned().as_ref().map(PathBuf::from)
} }
#[derive(Default)]
/// HashMap and some meta information for multiple Backend structs. /// HashMap and some meta information for multiple Backend structs.
pub struct Backends { pub struct Backends<T>
pub apps: HashMap<ServerNameBytesExp, Backend>, // hyper::uriで抜いたhostで引っ掛ける where
T: CryptoSource,
{
pub apps: HashMap<ServerNameBytesExp, Backend<T>>, // hyper::uriで抜いたhostで引っ掛ける
pub default_server_name_bytes: Option<ServerNameBytesExp>, // for plaintext http pub default_server_name_bytes: Option<ServerNameBytesExp>, // for plaintext http
} }
impl<T> Backends<T>
where
T: CryptoSource,
{
pub fn new() -> Self {
Backends {
apps: HashMap::<ServerNameBytesExp, Backend<T>>::default(),
default_server_name_bytes: None,
}
}
}

View file

@ -11,7 +11,7 @@ use std::{
path::PathBuf, path::PathBuf,
}; };
#[derive(Builder, Debug)] #[derive(Builder, Debug, Clone)]
/// Crypto-related file reader implementing certs::CryptoRead trait /// Crypto-related file reader implementing certs::CryptoRead trait
pub struct CryptoFileSource { pub struct CryptoFileSource {
#[builder(setter(custom))] #[builder(setter(custom))]

View file

@ -1,9 +1,12 @@
use super::toml::ConfigToml; use super::toml::ConfigToml;
use crate::{backend::Backends, error::*, globals::*, log::*, utils::BytesName}; use crate::{backend::Backends, certs::CryptoSource, error::*, globals::*, log::*, utils::BytesName};
use clap::Arg; use clap::Arg;
use tokio::runtime::Handle; use tokio::runtime::Handle;
pub fn build_globals(runtime_handle: Handle) -> std::result::Result<Globals, anyhow::Error> { pub fn build_globals<T>(runtime_handle: Handle) -> std::result::Result<Globals<T>, anyhow::Error>
where
T: CryptoSource + Clone,
{
let _ = include_str!("../../Cargo.toml"); let _ = include_str!("../../Cargo.toml");
let options = clap::command!().arg( let options = clap::command!().arg(
Arg::new("config_file") Arg::new("config_file")
@ -72,7 +75,7 @@ pub fn build_globals(runtime_handle: Handle) -> std::result::Result<Globals, any
} }
// build backends // build backends
let mut backends = Backends::default(); let mut backends = Backends::new();
for (app_name, app) in apps.0.iter() { for (app_name, app) in apps.0.iter() {
let server_name_string = app.server_name.as_ref().ok_or(anyhow!("No server name"))?; let server_name_string = app.server_name.as_ref().ok_or(anyhow!("No server name"))?;
let backend = app.try_into()?; let backend = app.try_into()?;

View file

@ -1,5 +1,6 @@
use crate::{ use crate::{
backend::{Backend, BackendBuilder, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption}, backend::{Backend, BackendBuilder, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption},
certs::CryptoSource,
constants::*, constants::*,
error::*, error::*,
globals::ProxyConfig, globals::ProxyConfig,
@ -170,10 +171,13 @@ impl ConfigToml {
} }
} }
impl TryInto<Backend> for &Application { impl<T> TryInto<Backend<T>> for &Application
where
T: CryptoSource + Clone,
{
type Error = anyhow::Error; type Error = anyhow::Error;
fn try_into(self) -> std::result::Result<Backend, Self::Error> { fn try_into(self) -> std::result::Result<Backend<T>, Self::Error> {
let server_name_string = self.server_name.as_ref().ok_or(anyhow!("Missing server_name"))?; let server_name_string = self.server_name.as_ref().ok_or(anyhow!("Missing server_name"))?;
// backend builder // backend builder

View file

@ -1,3 +1,4 @@
use crate::certs::CryptoSource;
use crate::{backend::Backends, constants::*}; use crate::{backend::Backends, constants::*};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::{ use std::sync::{
@ -8,12 +9,15 @@ use tokio::time::Duration;
/// Global object containing proxy configurations and shared object like counters. /// Global object containing proxy configurations and shared object like counters.
/// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks. /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks.
pub struct Globals { pub struct Globals<T>
where
T: CryptoSource,
{
/// Configuration parameters for proxy transport and request handlers /// Configuration parameters for proxy transport and request handlers
pub proxy_config: ProxyConfig, // TODO: proxy configはarcに包んでこいつだけ使いまわせばいいように変えていく。backendsも pub proxy_config: ProxyConfig, // TODO: proxy configはarcに包んでこいつだけ使いまわせばいいように変えていく。backendsも
/// Backend application objects to which http request handler forward incoming requests /// Backend application objects to which http request handler forward incoming requests
pub backends: Backends, pub backends: Backends<T>,
/// Shared context - Counter for serving requests /// Shared context - Counter for serving requests
pub request_count: RequestCount, pub request_count: RequestCount,

View file

@ -2,6 +2,7 @@
use super::{utils_headers::*, utils_request::*, utils_synth_response::*, HandlerContext}; use super::{utils_headers::*, utils_request::*, utils_synth_response::*, HandlerContext};
use crate::{ use crate::{
backend::{Backend, UpstreamGroup}, backend::{Backend, UpstreamGroup},
certs::CryptoSource,
error::*, error::*,
globals::Globals, globals::Globals,
log::*, log::*,
@ -18,17 +19,19 @@ use std::{env, net::SocketAddr, sync::Arc};
use tokio::{io::copy_bidirectional, time::timeout}; use tokio::{io::copy_bidirectional, time::timeout};
#[derive(Clone, Builder)] #[derive(Clone, Builder)]
pub struct HttpMessageHandler<T> pub struct HttpMessageHandler<T, U>
where where
T: Connect + Clone + Sync + Send + 'static, T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone,
{ {
forwarder: Arc<Client<T>>, forwarder: Arc<Client<T>>,
globals: Arc<Globals>, globals: Arc<Globals<U>>,
} }
impl<T> HttpMessageHandler<T> impl<T, U> HttpMessageHandler<T, U>
where where
T: Connect + Clone + Sync + Send + 'static, T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone,
{ {
fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result<Response<Body>> { fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result<Response<Body>> {
log_data.status_code(&status_code).output(); log_data.status_code(&status_code).output();
@ -194,11 +197,10 @@ where
//////////////////////////////////////////////////// ////////////////////////////////////////////////////
// Functions to generate messages // Functions to generate messages
fn generate_response_forwarded<B: core::fmt::Debug>( fn generate_response_forwarded<B>(&self, response: &mut Response<B>, chosen_backend: &Backend<U>) -> Result<()>
&self, where
response: &mut Response<B>, B: core::fmt::Debug,
chosen_backend: &Backend, {
) -> Result<()> {
let headers = response.headers_mut(); let headers = response.headers_mut();
remove_connection_header(headers); remove_connection_header(headers);
remove_hop_header(headers); remove_hop_header(headers);

View file

@ -1,3 +1,4 @@
use certs::CryptoSource;
#[cfg(not(target_env = "msvc"))] #[cfg(not(target_env = "msvc"))]
use tikv_jemallocator::Jemalloc; use tikv_jemallocator::Jemalloc;
@ -18,7 +19,8 @@ mod proxy;
mod utils; mod utils;
use crate::{ use crate::{
config::build_globals, error::*, globals::*, handler::HttpMessageHandlerBuilder, log::*, proxy::ProxyBuilder, cert_file_reader::CryptoFileSource, config::build_globals, error::*, globals::*, handler::HttpMessageHandlerBuilder,
log::*, proxy::ProxyBuilder,
}; };
use futures::future::select_all; use futures::future::select_all;
use hyper::Client; use hyper::Client;
@ -34,7 +36,7 @@ fn main() {
let runtime = runtime_builder.build().unwrap(); let runtime = runtime_builder.build().unwrap();
runtime.block_on(async { runtime.block_on(async {
let globals = match build_globals(runtime.handle().clone()) { let globals: Globals<CryptoFileSource> = match build_globals(runtime.handle().clone()) {
Ok(g) => g, Ok(g) => g,
Err(e) => { Err(e) => {
error!("Invalid configuration: {}", e); error!("Invalid configuration: {}", e);
@ -48,7 +50,10 @@ fn main() {
} }
// entrypoint creates and spawns tasks of proxy services // entrypoint creates and spawns tasks of proxy services
async fn entrypoint(globals: Arc<Globals>) -> Result<()> { async fn entrypoint<T>(globals: Arc<Globals<T>>) -> Result<()>
where
T: CryptoSource + Clone + Send + Sync + 'static,
{
// let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector(); // let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector();
let connector = hyper_rustls::HttpsConnectorBuilder::new() let connector = hyper_rustls::HttpsConnectorBuilder::new()
.with_webpki_roots() .with_webpki_roots()

View file

@ -1,6 +1,6 @@
use crate::{ use crate::{
cert_file_reader::read_certs_and_keys, // TODO: Trait defining read_certs_and_keys and add struct implementing the trait to backend when build backend cert_file_reader::read_certs_and_keys, // TODO: Trait defining read_certs_and_keys and add struct implementing the trait to backend when build backend
certs::CertsAndKeys, certs::{CertsAndKeys, CryptoSource},
globals::Globals, globals::Globals,
log::*, log::*,
utils::ServerNameBytesExp, utils::ServerNameBytesExp,
@ -18,8 +18,11 @@ use x509_parser::prelude::*;
#[derive(Clone)] #[derive(Clone)]
/// Reloader service for certificates and keys for TLS /// Reloader service for certificates and keys for TLS
pub struct CryptoReloader { pub struct CryptoReloader<T>
globals: Arc<Globals>, where
T: CryptoSource,
{
globals: Arc<Globals<T>>,
} }
pub type SniServerCryptoMap = HashMap<ServerNameBytesExp, Arc<ServerConfig>>; pub type SniServerCryptoMap = HashMap<ServerNameBytesExp, Arc<ServerConfig>>;
@ -37,8 +40,11 @@ pub struct ServerCryptoBase {
} }
#[async_trait] #[async_trait]
impl Reload<ServerCryptoBase> for CryptoReloader { impl<T> Reload<ServerCryptoBase> for CryptoReloader<T>
type Source = Arc<Globals>; where
T: CryptoSource + Sync + Send,
{
type Source = Arc<Globals<T>>;
async fn new(source: &Self::Source) -> Result<Self, ReloaderError<ServerCryptoBase>> { async fn new(source: &Self::Source) -> Result<Self, ReloaderError<ServerCryptoBase>> {
Ok(Self { Ok(Self {
globals: source.clone(), globals: source.clone(),

View file

@ -1,14 +1,15 @@
use super::Proxy; use super::Proxy;
use crate::{error::*, log::*, utils::ServerNameBytesExp}; use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp};
use bytes::{Buf, Bytes}; use bytes::{Buf, Bytes};
use h3::{quic::BidiStream, server::RequestStream}; use h3::{quic::BidiStream, server::RequestStream};
use hyper::{client::connect::Connect, Body, Request, Response}; use hyper::{client::connect::Connect, Body, Request, Response};
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::time::{timeout, Duration}; use tokio::time::{timeout, Duration};
impl<T> Proxy<T> impl<T, U> Proxy<T, U>
where where
T: Connect + Clone + Sync + Send + 'static, T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{ {
pub(super) async fn connection_serve_h3( pub(super) async fn connection_serve_h3(
self, self,

View file

@ -1,5 +1,7 @@
// use super::proxy_handler::handle_request; // use super::proxy_handler::handle_request;
use crate::{error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp}; use crate::{
certs::CryptoSource, error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp,
};
use derive_builder::{self, Builder}; use derive_builder::{self, Builder};
use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request}; use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request};
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
@ -32,19 +34,21 @@ where
} }
#[derive(Clone, Builder)] #[derive(Clone, Builder)]
pub struct Proxy<T> pub struct Proxy<T, U>
where where
T: Connect + Clone + Sync + Send + 'static, T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{ {
pub listening_on: SocketAddr, pub listening_on: SocketAddr,
pub tls_enabled: bool, // TCP待受がTLSかどうか pub tls_enabled: bool, // TCP待受がTLSかどうか
pub msg_handler: HttpMessageHandler<T>, pub msg_handler: HttpMessageHandler<T, U>,
pub globals: Arc<Globals>, pub globals: Arc<Globals<U>>,
} }
impl<T> Proxy<T> impl<T, U> Proxy<T, U>
where where
T: Connect + Clone + Sync + Send + 'static, T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone + Sync + Send,
{ {
pub(super) fn client_serve<I>( pub(super) fn client_serve<I>(
self, self,

View file

@ -2,7 +2,7 @@ use super::{
crypto_service::{CryptoReloader, ServerCrypto, ServerCryptoBase, SniServerCryptoMap}, crypto_service::{CryptoReloader, ServerCrypto, ServerCryptoBase, SniServerCryptoMap},
proxy_main::{LocalExecutor, Proxy}, proxy_main::{LocalExecutor, Proxy},
}; };
use crate::{constants::*, error::*, log::*, utils::BytesName}; use crate::{certs::CryptoSource, constants::*, error::*, log::*, utils::BytesName};
use hot_reload::{ReloaderReceiver, ReloaderService}; use hot_reload::{ReloaderReceiver, ReloaderService};
use hyper::{client::connect::Connect, server::conn::Http}; use hyper::{client::connect::Connect, server::conn::Http};
#[cfg(feature = "http3")] #[cfg(feature = "http3")]
@ -15,9 +15,10 @@ use tokio::{
time::{timeout, Duration}, time::{timeout, Duration},
}; };
impl<T> Proxy<T> impl<T, U> Proxy<T, U>
where where
T: Connect + Clone + Sync + Send + 'static, T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{ {
// TCP Listener Service, i.e., http/2 and http/1.1 // TCP Listener Service, i.e., http/2 and http/1.1
async fn listener_service( async fn listener_service(
@ -181,7 +182,7 @@ where
} }
pub async fn start_with_tls(self, server: Http<LocalExecutor>) -> Result<()> { pub async fn start_with_tls(self, server: Http<LocalExecutor>) -> Result<()> {
let (cert_reloader_service, cert_reloader_rx) = ReloaderService::<CryptoReloader, ServerCryptoBase>::new( let (cert_reloader_service, cert_reloader_rx) = ReloaderService::<CryptoReloader<U>, ServerCryptoBase>::new(
&self.globals.clone(), &self.globals.clone(),
CERTS_WATCH_DELAY_SECS, CERTS_WATCH_DELAY_SECS,
!LOAD_CERTS_ONLY_WHEN_UPDATED, !LOAD_CERTS_ONLY_WHEN_UPDATED,