refactor with derive_builder

This commit is contained in:
Jun Kurihara 2023-01-19 18:27:31 +09:00
commit d2b5cdcc5b
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
10 changed files with 142 additions and 82 deletions

View file

@ -54,6 +54,7 @@ h3 = { path = "./h3/h3/", optional = true }
h3-quinn = { path = "./h3/h3-quinn/", optional = true } h3-quinn = { path = "./h3/h3-quinn/", optional = true }
thiserror = "1.0.37" thiserror = "1.0.37"
x509-parser = "0.14.0" x509-parser = "0.14.0"
derive_builder = "0.12.0"
[target.'cfg(not(target_env = "msvc"))'.dependencies] [target.'cfg(not(target_env = "msvc"))'.dependencies]

View file

@ -5,9 +5,11 @@ use crate::{
log::*, log::*,
utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, utils::{BytesName, PathNameBytesExp, ServerNameBytesExp},
}; };
use derive_builder::Builder;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
use rustls::{OwnedTrustAnchor, RootCertStore}; use rustls::{OwnedTrustAnchor, RootCertStore};
use std::{ use std::{
borrow::Cow,
fs::File, fs::File,
io::{self, BufReader, Cursor, Read}, io::{self, BufReader, Cursor, Read},
path::PathBuf, path::PathBuf,
@ -18,22 +20,51 @@ use tokio_rustls::rustls::{
sign::{any_supported_type, CertifiedKey}, sign::{any_supported_type, CertifiedKey},
Certificate, PrivateKey, ServerConfig, Certificate, PrivateKey, ServerConfig,
}; };
pub use upstream::{ReverseProxy, Upstream, UpstreamGroup}; pub use upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder};
pub use upstream_opts::UpstreamOption; pub use upstream_opts::UpstreamOption;
use x509_parser::prelude::*; use x509_parser::prelude::*;
/// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings. /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
#[derive(Builder)]
pub struct Backend { pub struct Backend {
#[builder(setter(into))]
pub app_name: String, pub app_name: String,
#[builder(setter(custom))]
pub server_name: String, pub server_name: String,
pub reverse_proxy: ReverseProxy, pub reverse_proxy: ReverseProxy,
// tls settings // tls settings
#[builder(setter(custom), default)]
pub tls_cert_path: Option<PathBuf>, pub tls_cert_path: Option<PathBuf>,
#[builder(setter(custom), default)]
pub tls_cert_key_path: Option<PathBuf>, pub tls_cert_key_path: Option<PathBuf>,
#[builder(default)]
pub https_redirection: Option<bool>, pub https_redirection: Option<bool>,
#[builder(setter(custom), default)]
pub client_ca_cert_path: Option<PathBuf>, pub client_ca_cert_path: Option<PathBuf>,
} }
impl<'a> BackendBuilder {
pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self {
self.server_name = Some(server_name.into().to_ascii_lowercase());
self
}
pub fn tls_cert_path(&mut self, v: &Option<String>) -> &mut Self {
self.tls_cert_path = Some(opt_string_to_opt_pathbuf(v));
self
}
pub fn tls_cert_key_path(&mut self, v: &Option<String>) -> &mut Self {
self.tls_cert_key_path = Some(opt_string_to_opt_pathbuf(v));
self
}
pub fn client_ca_cert_path(&mut self, v: &Option<String>) -> &mut Self {
self.client_ca_cert_path = Some(opt_string_to_opt_pathbuf(v));
self
}
}
fn opt_string_to_opt_pathbuf(input: &Option<String>) -> Option<PathBuf> {
input.to_owned().as_ref().map(PathBuf::from)
}
impl Backend { impl Backend {
pub fn read_certs_and_key(&self) -> io::Result<CertifiedKey> { pub fn read_certs_and_key(&self) -> io::Result<CertifiedKey> {

View file

@ -1,5 +1,6 @@
use super::{BytesName, PathNameBytesExp, UpstreamOption}; use super::{BytesName, PathNameBytesExp, UpstreamOption};
use crate::log::*; use crate::log::*;
use derive_builder::Builder;
use rand::Rng; use rand::Rng;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
use std::{ use std::{
@ -66,15 +67,50 @@ pub struct Upstream {
pub uri: hyper::Uri, // base uri without specific path pub uri: hyper::Uri, // base uri without specific path
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone, Builder)]
pub struct UpstreamGroup { pub struct UpstreamGroup {
pub upstream: Vec<Upstream>, pub upstream: Vec<Upstream>,
#[builder(setter(custom), default)]
pub path: PathNameBytesExp, pub path: PathNameBytesExp,
#[builder(setter(custom), default)]
pub replace_path: Option<PathNameBytesExp>, pub replace_path: Option<PathNameBytesExp>,
#[builder(default)]
pub lb: LoadBalance, pub lb: LoadBalance,
#[builder(default)]
pub cnt: UpstreamCount, // counter for load balancing pub cnt: UpstreamCount, // counter for load balancing
#[builder(setter(custom), default)]
pub opts: HashSet<UpstreamOption>, pub opts: HashSet<UpstreamOption>,
} }
impl UpstreamGroupBuilder {
pub fn path(&mut self, v: &Option<String>) -> &mut Self {
let path = match v {
Some(p) => p.to_path_name_vec(),
None => "/".to_path_name_vec(),
};
self.path = Some(path);
self
}
pub fn replace_path(&mut self, v: &Option<String>) -> &mut Self {
self.replace_path = Some(
v.to_owned()
.as_ref()
.map_or_else(|| None, |v| Some(v.to_path_name_vec())),
);
self
}
pub fn opts(&mut self, v: &Option<Vec<String>>) -> &mut Self {
let opts = if let Some(opts) = v {
opts
.iter()
.filter_map(|str| UpstreamOption::try_from(str.as_str()).ok())
.collect::<HashSet<UpstreamOption>>()
} else {
Default::default()
};
self.opts = Some(opts);
self
}
}
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
pub struct UpstreamCount(Arc<AtomicUsize>); pub struct UpstreamCount(Arc<AtomicUsize>);

View file

@ -1,6 +1,6 @@
use super::toml::{ConfigToml, ReverseProxyOption}; use super::toml::{ConfigToml, ReverseProxyOption};
use crate::{ use crate::{
backend::{Backend, ReverseProxy, UpstreamGroup, UpstreamOption}, backend::{BackendBuilder, ReverseProxy, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption},
constants::*, constants::*,
error::*, error::*,
globals::*, globals::*,
@ -8,9 +8,8 @@ use crate::{
utils::{BytesName, PathNameBytesExp}, utils::{BytesName, PathNameBytesExp},
}; };
use clap::Arg; use clap::Arg;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use rustc_hash::FxHashMap as HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::path::PathBuf;
pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Error> { pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Error> {
let _ = include_str!("../../Cargo.toml"); let _ = include_str!("../../Cargo.toml");
@ -91,49 +90,49 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro
for (app_name, app) in apps.0.iter() { for (app_name, app) in apps.0.iter() {
ensure!(app.server_name.is_some(), "Missing server_name"); ensure!(app.server_name.is_some(), "Missing server_name");
let server_name_string = app.server_name.as_ref().unwrap(); let server_name_string = app.server_name.as_ref().unwrap();
// TLS settings
let (tls_cert_path, tls_cert_key_path, https_redirection, client_ca_cert_path) = if app.tls.is_none() {
ensure!(globals.http_port.is_some(), "Required HTTP port");
(None, None, None, None)
} else {
let tls = app.tls.as_ref().unwrap();
ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some());
(
tls.tls_cert_path.as_ref().map(PathBuf::from),
tls.tls_cert_key_path.as_ref().map(PathBuf::from),
if tls.https_redirection.is_none() {
Some(true) // Default true
} else {
ensure!(globals.https_port.is_some()); // only when both https ports are configured.
tls.https_redirection
},
tls.client_ca_cert_path.as_ref().map(PathBuf::from),
)
};
if globals.http_port.is_none() { if globals.http_port.is_none() {
// if only https_port is specified, tls must be configured // if only https_port is specified, tls must be configured
ensure!(app.tls.is_some()) ensure!(app.tls.is_some())
} }
// backend builder
let mut backend_builder = BackendBuilder::default();
// reverse proxy settings // reverse proxy settings
ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy");
let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?;
globals.backends.apps.insert( backend_builder
server_name_string.to_server_name_vec(), .app_name(server_name_string)
Backend { .server_name(server_name_string)
app_name: app_name.to_owned(), .reverse_proxy(reverse_proxy);
server_name: server_name_string.to_ascii_lowercase(),
reverse_proxy,
tls_cert_path, // TLS settings and build backend instance
tls_cert_key_path, let backend = if app.tls.is_none() {
https_redirection, ensure!(globals.http_port.is_some(), "Required HTTP port");
client_ca_cert_path, backend_builder.build()?
}, } else {
); let tls = app.tls.as_ref().unwrap();
ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some());
let https_redirection = if tls.https_redirection.is_none() {
Some(true) // Default true
} else {
ensure!(globals.https_port.is_some()); // only when both https ports are configured.
tls.https_redirection
};
backend_builder
.tls_cert_path(&tls.tls_cert_path)
.tls_cert_key_path(&tls.tls_cert_key_path)
.https_redirection(https_redirection)
.client_ca_cert_path(&tls.client_ca_cert_path)
.build()?
};
globals
.backends
.apps
.insert(server_name_string.to_server_name_vec(), backend);
info!("Registering application: {} ({})", app_name, server_name_string); info!("Registering application: {} ({})", app_name, server_name_string);
} }
@ -194,33 +193,15 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro
fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result<ReverseProxy, anyhow::Error> { fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result<ReverseProxy, anyhow::Error> {
let mut upstream: HashMap<PathNameBytesExp, UpstreamGroup> = HashMap::default(); let mut upstream: HashMap<PathNameBytesExp, UpstreamGroup> = HashMap::default();
rp_settings.iter().for_each(|rpo| { rp_settings.iter().for_each(|rpo| {
let path = match &rpo.path { let elem = UpstreamGroupBuilder::default()
Some(p) => p.to_path_name_vec(), .upstream(rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect())
None => "/".to_path_name_vec(), .path(&rpo.path)
}; .replace_path(&rpo.replace_path)
.opts(&rpo.upstream_options)
.build()
.unwrap();
let elem = UpstreamGroup { upstream.insert(elem.path.clone(), elem);
upstream: rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(),
path: path.clone(),
replace_path: rpo
.replace_path
.as_ref()
.map_or_else(|| None, |v| Some(v.to_path_name_vec())),
cnt: Default::default(),
lb: Default::default(),
opts: {
if let Some(opts) = &rpo.upstream_options {
opts
.iter()
.filter_map(|str| UpstreamOption::try_from(str.as_str()).ok())
.collect::<HashSet<UpstreamOption>>()
} else {
Default::default()
}
},
};
upstream.insert(path, elem);
}); });
ensure!( ensure!(
rp_settings.iter().filter(|rpo| rpo.path.is_none()).count() < 2, rp_settings.iter().filter(|rpo| rpo.path.is_none()).count() < 2,

View file

@ -7,6 +7,12 @@ pub type Result<T> = std::result::Result<T, RpxyError>;
/// Describes things that can go wrong in the Rpxy /// Describes things that can go wrong in the Rpxy
#[derive(Debug, Error)] #[derive(Debug, Error)]
pub enum RpxyError { pub enum RpxyError {
#[error("Proxy build error")]
ProxyBuild(#[from] crate::proxy::ProxyBuilderError),
#[error("MessageHandler build error")]
HandlerBuild(#[from] crate::handler::HttpMessageHandlerBuilderError),
#[error("Http Message Handler Error: {0}")] #[error("Http Message Handler Error: {0}")]
Handler(&'static str), Handler(&'static str),

View file

@ -7,6 +7,7 @@ use crate::{
log::*, log::*,
utils::ServerNameBytesExp, utils::ServerNameBytesExp,
}; };
use derive_builder::Builder;
use hyper::{ use hyper::{
client::connect::Connect, client::connect::Connect,
header::{self, HeaderValue}, header::{self, HeaderValue},
@ -16,13 +17,13 @@ use hyper::{
use std::{env, net::SocketAddr, sync::Arc}; use std::{env, net::SocketAddr, sync::Arc};
use tokio::{io::copy_bidirectional, time::timeout}; use tokio::{io::copy_bidirectional, time::timeout};
#[derive(Clone)] #[derive(Clone, Builder)]
pub struct HttpMessageHandler<T> pub struct HttpMessageHandler<T>
where where
T: Connect + Clone + Sync + Send + 'static, T: Connect + Clone + Sync + Send + 'static,
{ {
pub forwarder: Arc<Client<T>>, forwarder: Arc<Client<T>>,
pub globals: Arc<Globals>, globals: Arc<Globals>,
} }
impl<T> HttpMessageHandler<T> impl<T> HttpMessageHandler<T>

View file

@ -3,4 +3,4 @@ mod utils_headers;
mod utils_request; mod utils_request;
mod utils_synth_response; mod utils_synth_response;
pub use handler_main::HttpMessageHandler; pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError};

View file

@ -21,12 +21,12 @@ use crate::{
constants::*, constants::*,
error::*, error::*,
globals::*, globals::*,
handler::HttpMessageHandlerBuilder,
log::*, log::*,
proxy::Proxy, proxy::ProxyBuilder,
utils::ServerNameBytesExp, utils::ServerNameBytesExp,
}; };
use futures::future::select_all; use futures::future::select_all;
use handler::HttpMessageHandler;
use hyper::Client; use hyper::Client;
// use hyper_trust_dns::TrustDnsResolver; // use hyper_trust_dns::TrustDnsResolver;
use rustc_hash::FxHashMap as HashMap; use rustc_hash::FxHashMap as HashMap;
@ -110,10 +110,11 @@ async fn entrypoint(globals: Arc<Globals>) -> Result<()> {
.enable_http1() .enable_http1()
.enable_http2() .enable_http2()
.build(); .build();
let msg_handler = HttpMessageHandler {
forwarder: Arc::new(Client::builder().build::<_, hyper::Body>(connector)), let msg_handler = HttpMessageHandlerBuilder::default()
globals: globals.clone(), .forwarder(Arc::new(Client::builder().build::<_, hyper::Body>(connector)))
}; .globals(globals.clone())
.build()?;
let addresses = globals.listen_sockets.clone(); let addresses = globals.listen_sockets.clone();
let futures = select_all(addresses.into_iter().map(|addr| { let futures = select_all(addresses.into_iter().map(|addr| {
@ -122,12 +123,14 @@ async fn entrypoint(globals: Arc<Globals>) -> Result<()> {
tls_enabled = https_port == addr.port() tls_enabled = https_port == addr.port()
} }
let proxy = Proxy { let proxy = ProxyBuilder::default()
globals: globals.clone(), .globals(globals.clone())
listening_on: addr, .listening_on(addr)
tls_enabled, .tls_enabled(tls_enabled)
msg_handler: msg_handler.clone(), .msg_handler(msg_handler.clone())
}; .build()
.unwrap();
globals.runtime_handle.spawn(proxy.start()) globals.runtime_handle.spawn(proxy.start())
})); }));

View file

@ -4,4 +4,4 @@ mod proxy_h3;
mod proxy_main; mod proxy_main;
mod proxy_tls; mod proxy_tls;
pub use proxy_main::Proxy; pub use proxy_main::{Proxy, ProxyBuilder, ProxyBuilderError};

View file

@ -1,5 +1,6 @@
// use super::proxy_handler::handle_request; // use super::proxy_handler::handle_request;
use crate::{error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp}; use crate::{error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp};
use derive_builder::{self, Builder};
use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request}; use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request};
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use tokio::{ use tokio::{
@ -30,7 +31,7 @@ where
} }
} }
#[derive(Clone)] #[derive(Clone, Builder)]
pub struct Proxy<T> pub struct Proxy<T>
where where
T: Connect + Clone + Sync + Send + 'static, T: Connect + Clone + Sync + Send + 'static,