refactor with derive_builder

This commit is contained in:
Jun Kurihara 2023-01-19 18:27:31 +09:00
commit d2b5cdcc5b
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
10 changed files with 142 additions and 82 deletions

View file

@ -54,6 +54,7 @@ h3 = { path = "./h3/h3/", optional = true }
h3-quinn = { path = "./h3/h3-quinn/", optional = true }
thiserror = "1.0.37"
x509-parser = "0.14.0"
derive_builder = "0.12.0"
[target.'cfg(not(target_env = "msvc"))'.dependencies]

View file

@ -5,9 +5,11 @@ use crate::{
log::*,
utils::{BytesName, PathNameBytesExp, ServerNameBytesExp},
};
use derive_builder::Builder;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
use rustls::{OwnedTrustAnchor, RootCertStore};
use std::{
borrow::Cow,
fs::File,
io::{self, BufReader, Cursor, Read},
path::PathBuf,
@ -18,22 +20,51 @@ use tokio_rustls::rustls::{
sign::{any_supported_type, CertifiedKey},
Certificate, PrivateKey, ServerConfig,
};
pub use upstream::{ReverseProxy, Upstream, UpstreamGroup};
pub use upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder};
pub use upstream_opts::UpstreamOption;
use x509_parser::prelude::*;
/// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
#[derive(Builder)]
pub struct Backend {
#[builder(setter(into))]
pub app_name: String,
#[builder(setter(custom))]
pub server_name: String,
pub reverse_proxy: ReverseProxy,
// tls settings
#[builder(setter(custom), default)]
pub tls_cert_path: Option<PathBuf>,
#[builder(setter(custom), default)]
pub tls_cert_key_path: Option<PathBuf>,
#[builder(default)]
pub https_redirection: Option<bool>,
#[builder(setter(custom), default)]
pub client_ca_cert_path: Option<PathBuf>,
}
impl<'a> BackendBuilder {
pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self {
self.server_name = Some(server_name.into().to_ascii_lowercase());
self
}
pub fn tls_cert_path(&mut self, v: &Option<String>) -> &mut Self {
self.tls_cert_path = Some(opt_string_to_opt_pathbuf(v));
self
}
pub fn tls_cert_key_path(&mut self, v: &Option<String>) -> &mut Self {
self.tls_cert_key_path = Some(opt_string_to_opt_pathbuf(v));
self
}
pub fn client_ca_cert_path(&mut self, v: &Option<String>) -> &mut Self {
self.client_ca_cert_path = Some(opt_string_to_opt_pathbuf(v));
self
}
}
fn opt_string_to_opt_pathbuf(input: &Option<String>) -> Option<PathBuf> {
input.to_owned().as_ref().map(PathBuf::from)
}
impl Backend {
pub fn read_certs_and_key(&self) -> io::Result<CertifiedKey> {

View file

@ -1,5 +1,6 @@
use super::{BytesName, PathNameBytesExp, UpstreamOption};
use crate::log::*;
use derive_builder::Builder;
use rand::Rng;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
use std::{
@ -66,15 +67,50 @@ pub struct Upstream {
pub uri: hyper::Uri, // base uri without specific path
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Builder)]
pub struct UpstreamGroup {
pub upstream: Vec<Upstream>,
#[builder(setter(custom), default)]
pub path: PathNameBytesExp,
#[builder(setter(custom), default)]
pub replace_path: Option<PathNameBytesExp>,
#[builder(default)]
pub lb: LoadBalance,
#[builder(default)]
pub cnt: UpstreamCount, // counter for load balancing
#[builder(setter(custom), default)]
pub opts: HashSet<UpstreamOption>,
}
impl UpstreamGroupBuilder {
pub fn path(&mut self, v: &Option<String>) -> &mut Self {
let path = match v {
Some(p) => p.to_path_name_vec(),
None => "/".to_path_name_vec(),
};
self.path = Some(path);
self
}
pub fn replace_path(&mut self, v: &Option<String>) -> &mut Self {
self.replace_path = Some(
v.to_owned()
.as_ref()
.map_or_else(|| None, |v| Some(v.to_path_name_vec())),
);
self
}
pub fn opts(&mut self, v: &Option<Vec<String>>) -> &mut Self {
let opts = if let Some(opts) = v {
opts
.iter()
.filter_map(|str| UpstreamOption::try_from(str.as_str()).ok())
.collect::<HashSet<UpstreamOption>>()
} else {
Default::default()
};
self.opts = Some(opts);
self
}
}
#[derive(Debug, Clone, Default)]
pub struct UpstreamCount(Arc<AtomicUsize>);

View file

@ -1,6 +1,6 @@
use super::toml::{ConfigToml, ReverseProxyOption};
use crate::{
backend::{Backend, ReverseProxy, UpstreamGroup, UpstreamOption},
backend::{BackendBuilder, ReverseProxy, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption},
constants::*,
error::*,
globals::*,
@ -8,9 +8,8 @@ use crate::{
utils::{BytesName, PathNameBytesExp},
};
use clap::Arg;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
use rustc_hash::FxHashMap as HashMap;
use std::net::SocketAddr;
use std::path::PathBuf;
pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Error> {
let _ = include_str!("../../Cargo.toml");
@ -91,49 +90,49 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro
for (app_name, app) in apps.0.iter() {
ensure!(app.server_name.is_some(), "Missing server_name");
let server_name_string = app.server_name.as_ref().unwrap();
// TLS settings
let (tls_cert_path, tls_cert_key_path, https_redirection, client_ca_cert_path) = if app.tls.is_none() {
ensure!(globals.http_port.is_some(), "Required HTTP port");
(None, None, None, None)
} else {
let tls = app.tls.as_ref().unwrap();
ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some());
(
tls.tls_cert_path.as_ref().map(PathBuf::from),
tls.tls_cert_key_path.as_ref().map(PathBuf::from),
if tls.https_redirection.is_none() {
Some(true) // Default true
} else {
ensure!(globals.https_port.is_some()); // only when both https ports are configured.
tls.https_redirection
},
tls.client_ca_cert_path.as_ref().map(PathBuf::from),
)
};
if globals.http_port.is_none() {
// if only https_port is specified, tls must be configured
ensure!(app.tls.is_some())
}
// backend builder
let mut backend_builder = BackendBuilder::default();
// reverse proxy settings
ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy");
let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?;
globals.backends.apps.insert(
server_name_string.to_server_name_vec(),
Backend {
app_name: app_name.to_owned(),
server_name: server_name_string.to_ascii_lowercase(),
reverse_proxy,
backend_builder
.app_name(server_name_string)
.server_name(server_name_string)
.reverse_proxy(reverse_proxy);
tls_cert_path,
tls_cert_key_path,
https_redirection,
client_ca_cert_path,
},
);
// TLS settings and build backend instance
let backend = if app.tls.is_none() {
ensure!(globals.http_port.is_some(), "Required HTTP port");
backend_builder.build()?
} else {
let tls = app.tls.as_ref().unwrap();
ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some());
let https_redirection = if tls.https_redirection.is_none() {
Some(true) // Default true
} else {
ensure!(globals.https_port.is_some()); // only when both https ports are configured.
tls.https_redirection
};
backend_builder
.tls_cert_path(&tls.tls_cert_path)
.tls_cert_key_path(&tls.tls_cert_key_path)
.https_redirection(https_redirection)
.client_ca_cert_path(&tls.client_ca_cert_path)
.build()?
};
globals
.backends
.apps
.insert(server_name_string.to_server_name_vec(), backend);
info!("Registering application: {} ({})", app_name, server_name_string);
}
@ -194,33 +193,15 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro
fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result<ReverseProxy, anyhow::Error> {
let mut upstream: HashMap<PathNameBytesExp, UpstreamGroup> = HashMap::default();
rp_settings.iter().for_each(|rpo| {
let path = match &rpo.path {
Some(p) => p.to_path_name_vec(),
None => "/".to_path_name_vec(),
};
let elem = UpstreamGroupBuilder::default()
.upstream(rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect())
.path(&rpo.path)
.replace_path(&rpo.replace_path)
.opts(&rpo.upstream_options)
.build()
.unwrap();
let elem = UpstreamGroup {
upstream: rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(),
path: path.clone(),
replace_path: rpo
.replace_path
.as_ref()
.map_or_else(|| None, |v| Some(v.to_path_name_vec())),
cnt: Default::default(),
lb: Default::default(),
opts: {
if let Some(opts) = &rpo.upstream_options {
opts
.iter()
.filter_map(|str| UpstreamOption::try_from(str.as_str()).ok())
.collect::<HashSet<UpstreamOption>>()
} else {
Default::default()
}
},
};
upstream.insert(path, elem);
upstream.insert(elem.path.clone(), elem);
});
ensure!(
rp_settings.iter().filter(|rpo| rpo.path.is_none()).count() < 2,

View file

@ -7,6 +7,12 @@ pub type Result<T> = std::result::Result<T, RpxyError>;
/// Describes things that can go wrong in the Rpxy
#[derive(Debug, Error)]
pub enum RpxyError {
#[error("Proxy build error")]
ProxyBuild(#[from] crate::proxy::ProxyBuilderError),
#[error("MessageHandler build error")]
HandlerBuild(#[from] crate::handler::HttpMessageHandlerBuilderError),
#[error("Http Message Handler Error: {0}")]
Handler(&'static str),

View file

@ -7,6 +7,7 @@ use crate::{
log::*,
utils::ServerNameBytesExp,
};
use derive_builder::Builder;
use hyper::{
client::connect::Connect,
header::{self, HeaderValue},
@ -16,13 +17,13 @@ use hyper::{
use std::{env, net::SocketAddr, sync::Arc};
use tokio::{io::copy_bidirectional, time::timeout};
#[derive(Clone)]
#[derive(Clone, Builder)]
pub struct HttpMessageHandler<T>
where
T: Connect + Clone + Sync + Send + 'static,
{
pub forwarder: Arc<Client<T>>,
pub globals: Arc<Globals>,
forwarder: Arc<Client<T>>,
globals: Arc<Globals>,
}
impl<T> HttpMessageHandler<T>

View file

@ -3,4 +3,4 @@ mod utils_headers;
mod utils_request;
mod utils_synth_response;
pub use handler_main::HttpMessageHandler;
pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError};

View file

@ -21,12 +21,12 @@ use crate::{
constants::*,
error::*,
globals::*,
handler::HttpMessageHandlerBuilder,
log::*,
proxy::Proxy,
proxy::ProxyBuilder,
utils::ServerNameBytesExp,
};
use futures::future::select_all;
use handler::HttpMessageHandler;
use hyper::Client;
// use hyper_trust_dns::TrustDnsResolver;
use rustc_hash::FxHashMap as HashMap;
@ -110,10 +110,11 @@ async fn entrypoint(globals: Arc<Globals>) -> Result<()> {
.enable_http1()
.enable_http2()
.build();
let msg_handler = HttpMessageHandler {
forwarder: Arc::new(Client::builder().build::<_, hyper::Body>(connector)),
globals: globals.clone(),
};
let msg_handler = HttpMessageHandlerBuilder::default()
.forwarder(Arc::new(Client::builder().build::<_, hyper::Body>(connector)))
.globals(globals.clone())
.build()?;
let addresses = globals.listen_sockets.clone();
let futures = select_all(addresses.into_iter().map(|addr| {
@ -122,12 +123,14 @@ async fn entrypoint(globals: Arc<Globals>) -> Result<()> {
tls_enabled = https_port == addr.port()
}
let proxy = Proxy {
globals: globals.clone(),
listening_on: addr,
tls_enabled,
msg_handler: msg_handler.clone(),
};
let proxy = ProxyBuilder::default()
.globals(globals.clone())
.listening_on(addr)
.tls_enabled(tls_enabled)
.msg_handler(msg_handler.clone())
.build()
.unwrap();
globals.runtime_handle.spawn(proxy.start())
}));

View file

@ -4,4 +4,4 @@ mod proxy_h3;
mod proxy_main;
mod proxy_tls;
pub use proxy_main::Proxy;
pub use proxy_main::{Proxy, ProxyBuilder, ProxyBuilderError};

View file

@ -1,5 +1,6 @@
// use super::proxy_handler::handle_request;
use crate::{error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp};
use derive_builder::{self, Builder};
use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request};
use std::{net::SocketAddr, sync::Arc};
use tokio::{
@ -30,7 +31,7 @@ where
}
}
#[derive(Clone)]
#[derive(Clone, Builder)]
pub struct Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,