diff --git a/Cargo.toml b/Cargo.toml index 50e2199..c1c70d9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,7 @@ h3 = { path = "./h3/h3/", optional = true } h3-quinn = { path = "./h3/h3-quinn/", optional = true } thiserror = "1.0.37" x509-parser = "0.14.0" +derive_builder = "0.12.0" [target.'cfg(not(target_env = "msvc"))'.dependencies] diff --git a/src/backend/mod.rs b/src/backend/mod.rs index c6a2842..7c60c3c 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -5,9 +5,11 @@ use crate::{ log::*, utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, }; +use derive_builder::Builder; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use rustls::{OwnedTrustAnchor, RootCertStore}; use std::{ + borrow::Cow, fs::File, io::{self, BufReader, Cursor, Read}, path::PathBuf, @@ -18,22 +20,51 @@ use tokio_rustls::rustls::{ sign::{any_supported_type, CertifiedKey}, Certificate, PrivateKey, ServerConfig, }; -pub use upstream::{ReverseProxy, Upstream, UpstreamGroup}; +pub use upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}; pub use upstream_opts::UpstreamOption; use x509_parser::prelude::*; /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings. +#[derive(Builder)] pub struct Backend { + #[builder(setter(into))] pub app_name: String, + #[builder(setter(custom))] pub server_name: String, pub reverse_proxy: ReverseProxy, // tls settings + #[builder(setter(custom), default)] pub tls_cert_path: Option, + #[builder(setter(custom), default)] pub tls_cert_key_path: Option, + #[builder(default)] pub https_redirection: Option, + #[builder(setter(custom), default)] pub client_ca_cert_path: Option, } +impl<'a> BackendBuilder { + pub fn server_name(&mut self, server_name: impl Into>) -> &mut Self { + self.server_name = Some(server_name.into().to_ascii_lowercase()); + self + } + pub fn tls_cert_path(&mut self, v: &Option) -> &mut Self { + self.tls_cert_path = Some(opt_string_to_opt_pathbuf(v)); + self + } + pub fn tls_cert_key_path(&mut self, v: &Option) -> &mut Self { + self.tls_cert_key_path = Some(opt_string_to_opt_pathbuf(v)); + self + } + pub fn client_ca_cert_path(&mut self, v: &Option) -> &mut Self { + self.client_ca_cert_path = Some(opt_string_to_opt_pathbuf(v)); + self + } +} + +fn opt_string_to_opt_pathbuf(input: &Option) -> Option { + input.to_owned().as_ref().map(PathBuf::from) +} impl Backend { pub fn read_certs_and_key(&self) -> io::Result { diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index aef0d83..c2fdb34 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -1,5 +1,6 @@ use super::{BytesName, PathNameBytesExp, UpstreamOption}; use crate::log::*; +use derive_builder::Builder; use rand::Rng; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use std::{ @@ -66,15 +67,50 @@ pub struct Upstream { pub uri: hyper::Uri, // base uri without specific path } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Builder)] pub struct UpstreamGroup { pub upstream: Vec, + #[builder(setter(custom), default)] pub path: PathNameBytesExp, + #[builder(setter(custom), default)] pub replace_path: Option, + #[builder(default)] pub lb: LoadBalance, + #[builder(default)] pub cnt: UpstreamCount, // counter for load balancing + #[builder(setter(custom), default)] pub opts: HashSet, } +impl UpstreamGroupBuilder { + pub fn path(&mut self, v: &Option) -> &mut Self { + let path = match v { + Some(p) => p.to_path_name_vec(), + None => "/".to_path_name_vec(), + }; + self.path = Some(path); + self + } + pub fn replace_path(&mut self, v: &Option) -> &mut Self { + self.replace_path = Some( + v.to_owned() + .as_ref() + .map_or_else(|| None, |v| Some(v.to_path_name_vec())), + ); + self + } + pub fn opts(&mut self, v: &Option>) -> &mut Self { + let opts = if let Some(opts) = v { + opts + .iter() + .filter_map(|str| UpstreamOption::try_from(str.as_str()).ok()) + .collect::>() + } else { + Default::default() + }; + self.opts = Some(opts); + self + } +} #[derive(Debug, Clone, Default)] pub struct UpstreamCount(Arc); diff --git a/src/config/parse.rs b/src/config/parse.rs index 93406fd..39e2dac 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -1,6 +1,6 @@ use super::toml::{ConfigToml, ReverseProxyOption}; use crate::{ - backend::{Backend, ReverseProxy, UpstreamGroup, UpstreamOption}, + backend::{BackendBuilder, ReverseProxy, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption}, constants::*, error::*, globals::*, @@ -8,9 +8,8 @@ use crate::{ utils::{BytesName, PathNameBytesExp}, }; use clap::Arg; -use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use rustc_hash::FxHashMap as HashMap; use std::net::SocketAddr; -use std::path::PathBuf; pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Error> { let _ = include_str!("../../Cargo.toml"); @@ -91,49 +90,49 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro for (app_name, app) in apps.0.iter() { ensure!(app.server_name.is_some(), "Missing server_name"); let server_name_string = app.server_name.as_ref().unwrap(); - - // TLS settings - let (tls_cert_path, tls_cert_key_path, https_redirection, client_ca_cert_path) = if app.tls.is_none() { - ensure!(globals.http_port.is_some(), "Required HTTP port"); - (None, None, None, None) - } else { - let tls = app.tls.as_ref().unwrap(); - ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some()); - - ( - tls.tls_cert_path.as_ref().map(PathBuf::from), - tls.tls_cert_key_path.as_ref().map(PathBuf::from), - if tls.https_redirection.is_none() { - Some(true) // Default true - } else { - ensure!(globals.https_port.is_some()); // only when both https ports are configured. - tls.https_redirection - }, - tls.client_ca_cert_path.as_ref().map(PathBuf::from), - ) - }; if globals.http_port.is_none() { // if only https_port is specified, tls must be configured ensure!(app.tls.is_some()) } + // backend builder + let mut backend_builder = BackendBuilder::default(); // reverse proxy settings ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; - globals.backends.apps.insert( - server_name_string.to_server_name_vec(), - Backend { - app_name: app_name.to_owned(), - server_name: server_name_string.to_ascii_lowercase(), - reverse_proxy, + backend_builder + .app_name(server_name_string) + .server_name(server_name_string) + .reverse_proxy(reverse_proxy); - tls_cert_path, - tls_cert_key_path, - https_redirection, - client_ca_cert_path, - }, - ); + // TLS settings and build backend instance + let backend = if app.tls.is_none() { + ensure!(globals.http_port.is_some(), "Required HTTP port"); + backend_builder.build()? + } else { + let tls = app.tls.as_ref().unwrap(); + ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some()); + + let https_redirection = if tls.https_redirection.is_none() { + Some(true) // Default true + } else { + ensure!(globals.https_port.is_some()); // only when both https ports are configured. + tls.https_redirection + }; + + backend_builder + .tls_cert_path(&tls.tls_cert_path) + .tls_cert_key_path(&tls.tls_cert_key_path) + .https_redirection(https_redirection) + .client_ca_cert_path(&tls.client_ca_cert_path) + .build()? + }; + + globals + .backends + .apps + .insert(server_name_string.to_server_name_vec(), backend); info!("Registering application: {} ({})", app_name, server_name_string); } @@ -194,33 +193,15 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result { let mut upstream: HashMap = HashMap::default(); rp_settings.iter().for_each(|rpo| { - let path = match &rpo.path { - Some(p) => p.to_path_name_vec(), - None => "/".to_path_name_vec(), - }; + let elem = UpstreamGroupBuilder::default() + .upstream(rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect()) + .path(&rpo.path) + .replace_path(&rpo.replace_path) + .opts(&rpo.upstream_options) + .build() + .unwrap(); - let elem = UpstreamGroup { - upstream: rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(), - path: path.clone(), - replace_path: rpo - .replace_path - .as_ref() - .map_or_else(|| None, |v| Some(v.to_path_name_vec())), - cnt: Default::default(), - lb: Default::default(), - opts: { - if let Some(opts) = &rpo.upstream_options { - opts - .iter() - .filter_map(|str| UpstreamOption::try_from(str.as_str()).ok()) - .collect::>() - } else { - Default::default() - } - }, - }; - - upstream.insert(path, elem); + upstream.insert(elem.path.clone(), elem); }); ensure!( rp_settings.iter().filter(|rpo| rpo.path.is_none()).count() < 2, diff --git a/src/error.rs b/src/error.rs index aa679f8..6da3b02 100644 --- a/src/error.rs +++ b/src/error.rs @@ -7,6 +7,12 @@ pub type Result = std::result::Result; /// Describes things that can go wrong in the Rpxy #[derive(Debug, Error)] pub enum RpxyError { + #[error("Proxy build error")] + ProxyBuild(#[from] crate::proxy::ProxyBuilderError), + + #[error("MessageHandler build error")] + HandlerBuild(#[from] crate::handler::HttpMessageHandlerBuilderError), + #[error("Http Message Handler Error: {0}")] Handler(&'static str), diff --git a/src/handler/handler_main.rs b/src/handler/handler_main.rs index fbc5161..4f60ee5 100644 --- a/src/handler/handler_main.rs +++ b/src/handler/handler_main.rs @@ -7,6 +7,7 @@ use crate::{ log::*, utils::ServerNameBytesExp, }; +use derive_builder::Builder; use hyper::{ client::connect::Connect, header::{self, HeaderValue}, @@ -16,13 +17,13 @@ use hyper::{ use std::{env, net::SocketAddr, sync::Arc}; use tokio::{io::copy_bidirectional, time::timeout}; -#[derive(Clone)] +#[derive(Clone, Builder)] pub struct HttpMessageHandler where T: Connect + Clone + Sync + Send + 'static, { - pub forwarder: Arc>, - pub globals: Arc, + forwarder: Arc>, + globals: Arc, } impl HttpMessageHandler diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 799ef60..c2225ce 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -3,4 +3,4 @@ mod utils_headers; mod utils_request; mod utils_synth_response; -pub use handler_main::HttpMessageHandler; +pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; diff --git a/src/main.rs b/src/main.rs index 58610ea..2fd8602 100644 --- a/src/main.rs +++ b/src/main.rs @@ -21,12 +21,12 @@ use crate::{ constants::*, error::*, globals::*, + handler::HttpMessageHandlerBuilder, log::*, - proxy::Proxy, + proxy::ProxyBuilder, utils::ServerNameBytesExp, }; use futures::future::select_all; -use handler::HttpMessageHandler; use hyper::Client; // use hyper_trust_dns::TrustDnsResolver; use rustc_hash::FxHashMap as HashMap; @@ -110,10 +110,11 @@ async fn entrypoint(globals: Arc) -> Result<()> { .enable_http1() .enable_http2() .build(); - let msg_handler = HttpMessageHandler { - forwarder: Arc::new(Client::builder().build::<_, hyper::Body>(connector)), - globals: globals.clone(), - }; + + let msg_handler = HttpMessageHandlerBuilder::default() + .forwarder(Arc::new(Client::builder().build::<_, hyper::Body>(connector))) + .globals(globals.clone()) + .build()?; let addresses = globals.listen_sockets.clone(); let futures = select_all(addresses.into_iter().map(|addr| { @@ -122,12 +123,14 @@ async fn entrypoint(globals: Arc) -> Result<()> { tls_enabled = https_port == addr.port() } - let proxy = Proxy { - globals: globals.clone(), - listening_on: addr, - tls_enabled, - msg_handler: msg_handler.clone(), - }; + let proxy = ProxyBuilder::default() + .globals(globals.clone()) + .listening_on(addr) + .tls_enabled(tls_enabled) + .msg_handler(msg_handler.clone()) + .build() + .unwrap(); + globals.runtime_handle.spawn(proxy.start()) })); diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 82d775b..04413f5 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -4,4 +4,4 @@ mod proxy_h3; mod proxy_main; mod proxy_tls; -pub use proxy_main::Proxy; +pub use proxy_main::{Proxy, ProxyBuilder, ProxyBuilderError}; diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index 964ad70..722ef3c 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -1,5 +1,6 @@ // use super::proxy_handler::handle_request; use crate::{error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp}; +use derive_builder::{self, Builder}; use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request}; use std::{net::SocketAddr, sync::Arc}; use tokio::{ @@ -30,7 +31,7 @@ where } } -#[derive(Clone)] +#[derive(Clone, Builder)] pub struct Proxy where T: Connect + Clone + Sync + Send + 'static,