refactor: add cert_reader object in backend

This commit is contained in:
Jun Kurihara 2023-07-12 20:31:31 +09:00
commit 6c0fd85ca5
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
11 changed files with 96 additions and 40 deletions

View file

@ -13,14 +13,20 @@ pub use self::{
upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder},
upstream_opts::UpstreamOption,
};
use crate::utils::{BytesName, PathNameBytesExp, ServerNameBytesExp};
use crate::{
certs::CryptoSource,
utils::{BytesName, PathNameBytesExp, ServerNameBytesExp},
};
use derive_builder::Builder;
use rustc_hash::FxHashMap as HashMap;
use std::{borrow::Cow, path::PathBuf};
/// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
#[derive(Builder)]
pub struct Backend {
pub struct Backend<T>
where
T: CryptoSource,
{
#[builder(setter(into))]
/// backend application name, e.g., app1
pub app_name: String,
@ -39,8 +45,14 @@ pub struct Backend {
pub https_redirection: Option<bool>,
#[builder(setter(custom), default)]
pub client_ca_cert_path: Option<PathBuf>,
#[builder(default)]
pub crypto_source: Option<T>,
}
impl<'a> BackendBuilder {
impl<'a, T> BackendBuilder<T>
where
T: CryptoSource,
{
pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self {
self.server_name = Some(server_name.into().to_ascii_lowercase());
self
@ -63,9 +75,23 @@ fn opt_string_to_opt_pathbuf(input: &Option<String>) -> Option<PathBuf> {
input.to_owned().as_ref().map(PathBuf::from)
}
#[derive(Default)]
/// HashMap and some meta information for multiple Backend structs.
pub struct Backends {
pub apps: HashMap<ServerNameBytesExp, Backend>, // hyper::uriで抜いたhostで引っ掛ける
pub struct Backends<T>
where
T: CryptoSource,
{
pub apps: HashMap<ServerNameBytesExp, Backend<T>>, // hyper::uriで抜いたhostで引っ掛ける
pub default_server_name_bytes: Option<ServerNameBytesExp>, // for plaintext http
}
impl<T> Backends<T>
where
T: CryptoSource,
{
pub fn new() -> Self {
Backends {
apps: HashMap::<ServerNameBytesExp, Backend<T>>::default(),
default_server_name_bytes: None,
}
}
}

View file

@ -11,7 +11,7 @@ use std::{
path::PathBuf,
};
#[derive(Builder, Debug)]
#[derive(Builder, Debug, Clone)]
/// Crypto-related file reader implementing certs::CryptoRead trait
pub struct CryptoFileSource {
#[builder(setter(custom))]

View file

@ -1,9 +1,12 @@
use super::toml::ConfigToml;
use crate::{backend::Backends, error::*, globals::*, log::*, utils::BytesName};
use crate::{backend::Backends, certs::CryptoSource, error::*, globals::*, log::*, utils::BytesName};
use clap::Arg;
use tokio::runtime::Handle;
pub fn build_globals(runtime_handle: Handle) -> std::result::Result<Globals, anyhow::Error> {
pub fn build_globals<T>(runtime_handle: Handle) -> std::result::Result<Globals<T>, anyhow::Error>
where
T: CryptoSource + Clone,
{
let _ = include_str!("../../Cargo.toml");
let options = clap::command!().arg(
Arg::new("config_file")
@ -72,7 +75,7 @@ pub fn build_globals(runtime_handle: Handle) -> std::result::Result<Globals, any
}
// build backends
let mut backends = Backends::default();
let mut backends = Backends::new();
for (app_name, app) in apps.0.iter() {
let server_name_string = app.server_name.as_ref().ok_or(anyhow!("No server name"))?;
let backend = app.try_into()?;

View file

@ -1,5 +1,6 @@
use crate::{
backend::{Backend, BackendBuilder, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption},
certs::CryptoSource,
constants::*,
error::*,
globals::ProxyConfig,
@ -170,10 +171,13 @@ impl ConfigToml {
}
}
impl TryInto<Backend> for &Application {
impl<T> TryInto<Backend<T>> for &Application
where
T: CryptoSource + Clone,
{
type Error = anyhow::Error;
fn try_into(self) -> std::result::Result<Backend, Self::Error> {
fn try_into(self) -> std::result::Result<Backend<T>, Self::Error> {
let server_name_string = self.server_name.as_ref().ok_or(anyhow!("Missing server_name"))?;
// backend builder

View file

@ -1,3 +1,4 @@
use crate::certs::CryptoSource;
use crate::{backend::Backends, constants::*};
use std::net::SocketAddr;
use std::sync::{
@ -8,12 +9,15 @@ use tokio::time::Duration;
/// Global object containing proxy configurations and shared object like counters.
/// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks.
pub struct Globals {
pub struct Globals<T>
where
T: CryptoSource,
{
/// Configuration parameters for proxy transport and request handlers
pub proxy_config: ProxyConfig, // TODO: proxy configはarcに包んでこいつだけ使いまわせばいいように変えていく。backendsも
/// Backend application objects to which http request handler forward incoming requests
pub backends: Backends,
pub backends: Backends<T>,
/// Shared context - Counter for serving requests
pub request_count: RequestCount,

View file

@ -2,6 +2,7 @@
use super::{utils_headers::*, utils_request::*, utils_synth_response::*, HandlerContext};
use crate::{
backend::{Backend, UpstreamGroup},
certs::CryptoSource,
error::*,
globals::Globals,
log::*,
@ -18,17 +19,19 @@ use std::{env, net::SocketAddr, sync::Arc};
use tokio::{io::copy_bidirectional, time::timeout};
#[derive(Clone, Builder)]
pub struct HttpMessageHandler<T>
pub struct HttpMessageHandler<T, U>
where
T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone,
{
forwarder: Arc<Client<T>>,
globals: Arc<Globals>,
globals: Arc<Globals<U>>,
}
impl<T> HttpMessageHandler<T>
impl<T, U> HttpMessageHandler<T, U>
where
T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone,
{
fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result<Response<Body>> {
log_data.status_code(&status_code).output();
@ -194,11 +197,10 @@ where
////////////////////////////////////////////////////
// Functions to generate messages
fn generate_response_forwarded<B: core::fmt::Debug>(
&self,
response: &mut Response<B>,
chosen_backend: &Backend,
) -> Result<()> {
fn generate_response_forwarded<B>(&self, response: &mut Response<B>, chosen_backend: &Backend<U>) -> Result<()>
where
B: core::fmt::Debug,
{
let headers = response.headers_mut();
remove_connection_header(headers);
remove_hop_header(headers);

View file

@ -1,3 +1,4 @@
use certs::CryptoSource;
#[cfg(not(target_env = "msvc"))]
use tikv_jemallocator::Jemalloc;
@ -18,7 +19,8 @@ mod proxy;
mod utils;
use crate::{
config::build_globals, error::*, globals::*, handler::HttpMessageHandlerBuilder, log::*, proxy::ProxyBuilder,
cert_file_reader::CryptoFileSource, config::build_globals, error::*, globals::*, handler::HttpMessageHandlerBuilder,
log::*, proxy::ProxyBuilder,
};
use futures::future::select_all;
use hyper::Client;
@ -34,7 +36,7 @@ fn main() {
let runtime = runtime_builder.build().unwrap();
runtime.block_on(async {
let globals = match build_globals(runtime.handle().clone()) {
let globals: Globals<CryptoFileSource> = match build_globals(runtime.handle().clone()) {
Ok(g) => g,
Err(e) => {
error!("Invalid configuration: {}", e);
@ -48,7 +50,10 @@ fn main() {
}
// entrypoint creates and spawns tasks of proxy services
async fn entrypoint(globals: Arc<Globals>) -> Result<()> {
async fn entrypoint<T>(globals: Arc<Globals<T>>) -> Result<()>
where
T: CryptoSource + Clone + Send + Sync + 'static,
{
// let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector();
let connector = hyper_rustls::HttpsConnectorBuilder::new()
.with_webpki_roots()

View file

@ -1,6 +1,6 @@
use crate::{
cert_file_reader::read_certs_and_keys, // TODO: Trait defining read_certs_and_keys and add struct implementing the trait to backend when build backend
certs::CertsAndKeys,
certs::{CertsAndKeys, CryptoSource},
globals::Globals,
log::*,
utils::ServerNameBytesExp,
@ -18,8 +18,11 @@ use x509_parser::prelude::*;
#[derive(Clone)]
/// Reloader service for certificates and keys for TLS
pub struct CryptoReloader {
globals: Arc<Globals>,
pub struct CryptoReloader<T>
where
T: CryptoSource,
{
globals: Arc<Globals<T>>,
}
pub type SniServerCryptoMap = HashMap<ServerNameBytesExp, Arc<ServerConfig>>;
@ -37,8 +40,11 @@ pub struct ServerCryptoBase {
}
#[async_trait]
impl Reload<ServerCryptoBase> for CryptoReloader {
type Source = Arc<Globals>;
impl<T> Reload<ServerCryptoBase> for CryptoReloader<T>
where
T: CryptoSource + Sync + Send,
{
type Source = Arc<Globals<T>>;
async fn new(source: &Self::Source) -> Result<Self, ReloaderError<ServerCryptoBase>> {
Ok(Self {
globals: source.clone(),

View file

@ -1,14 +1,15 @@
use super::Proxy;
use crate::{error::*, log::*, utils::ServerNameBytesExp};
use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp};
use bytes::{Buf, Bytes};
use h3::{quic::BidiStream, server::RequestStream};
use hyper::{client::connect::Connect, Body, Request, Response};
use std::net::SocketAddr;
use tokio::time::{timeout, Duration};
impl<T> Proxy<T>
impl<T, U> Proxy<T, U>
where
T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{
pub(super) async fn connection_serve_h3(
self,

View file

@ -1,5 +1,7 @@
// use super::proxy_handler::handle_request;
use crate::{error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp};
use crate::{
certs::CryptoSource, error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp,
};
use derive_builder::{self, Builder};
use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request};
use std::{net::SocketAddr, sync::Arc};
@ -32,19 +34,21 @@ where
}
#[derive(Clone, Builder)]
pub struct Proxy<T>
pub struct Proxy<T, U>
where
T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{
pub listening_on: SocketAddr,
pub tls_enabled: bool, // TCP待受がTLSかどうか
pub msg_handler: HttpMessageHandler<T>,
pub globals: Arc<Globals>,
pub msg_handler: HttpMessageHandler<T, U>,
pub globals: Arc<Globals<U>>,
}
impl<T> Proxy<T>
impl<T, U> Proxy<T, U>
where
T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone + Sync + Send,
{
pub(super) fn client_serve<I>(
self,

View file

@ -2,7 +2,7 @@ use super::{
crypto_service::{CryptoReloader, ServerCrypto, ServerCryptoBase, SniServerCryptoMap},
proxy_main::{LocalExecutor, Proxy},
};
use crate::{constants::*, error::*, log::*, utils::BytesName};
use crate::{certs::CryptoSource, constants::*, error::*, log::*, utils::BytesName};
use hot_reload::{ReloaderReceiver, ReloaderService};
use hyper::{client::connect::Connect, server::conn::Http};
#[cfg(feature = "http3")]
@ -15,9 +15,10 @@ use tokio::{
time::{timeout, Duration},
};
impl<T> Proxy<T>
impl<T, U> Proxy<T, U>
where
T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone + Sync + Send + 'static,
{
// TCP Listener Service, i.e., http/2 and http/1.1
async fn listener_service(
@ -181,7 +182,7 @@ where
}
pub async fn start_with_tls(self, server: Http<LocalExecutor>) -> Result<()> {
let (cert_reloader_service, cert_reloader_rx) = ReloaderService::<CryptoReloader, ServerCryptoBase>::new(
let (cert_reloader_service, cert_reloader_rx) = ReloaderService::<CryptoReloader<U>, ServerCryptoBase>::new(
&self.globals.clone(),
CERTS_WATCH_DELAY_SECS,
!LOAD_CERTS_ONLY_WHEN_UPDATED,