add toml config support

This commit is contained in:
Jun Kurihara 2022-06-27 15:07:29 -04:00
commit c3c95e9589
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
10 changed files with 246 additions and 116 deletions

View file

@ -48,6 +48,8 @@ hyper-trust-dns = { version = "0.4.2", default-features = false, features = [
] } ] }
rustls = "0.20.6" rustls = "0.20.6"
rand = "0.8.5" rand = "0.8.5"
toml = "0.5.9"
serde = { version = "1.0.137", features = ["derive"] }
[dev-dependencies] [dev-dependencies]

View file

@ -6,20 +6,20 @@
################################### ###################################
# Global settings # # Global settings #
################################### ###################################
http_port = 8080 # Both or either one of http/https ports must be specified
https_port = 8443 listen_port = 8080
listen_port_tls = 8443
################################### ###################################
# Backend settings # # Backend settings #
################################### ###################################
[application]
[[application]] [apps.localhost]
app_name = 'localhost' # this should be option, if null then same as server_name server_name = 'localhost'
hostname = 'localhost'
https_redirection = true
reverse_proxy = [ reverse_proxy = [
# default destination if path is not specified # default destination if path is not specified
# TODO: Array for load balancing # Array for load balancing
{ upstream = [ { upstream = [
{ location = 'www.google.com', tls = true }, { location = 'www.google.com', tls = true },
{ location = 'www.google.co.jp', tls = true }, { location = 'www.google.co.jp', tls = true },
@ -29,18 +29,13 @@ reverse_proxy = [
{ location = 'www.bing.co.jp', tls = true }, { location = 'www.bing.co.jp', tls = true },
] }, ] },
] ]
tls = { https_redirection = true, tls_cert_path = 'localhost.pem', tls_cert_key_path = 'localhost.pem' }
## List of destinations to send data to. ## List of destinations to send data to.
## At this point, round-robin is used for load-balancing if multiple URLs are specified. ## At this point, round-robin is used for load-balancing if multiple URLs are specified.
# allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24'] # TODO # allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24'] # TODO
# denyhosts = ['*'] # TODO # denyhosts = ['*'] # TODO
tls_cert_path = 'localhost.pem' [apps.another_localhost]
tls_cert_key_path = 'localhost.pem' server_name = 'localhost.localdomain'
[[application]]
app_name = 'locahost_application'
hostname = 'localhost.localdomain'
https_redirection = true
reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }] reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }]
tls_cert_path = 'localhost.pem'
tls_cert_key_path = 'localhost.pem'

View file

@ -14,17 +14,19 @@ use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig};
pub struct Backend { pub struct Backend {
pub app_name: String, pub app_name: String,
pub hostname: String, pub server_name: String,
pub reverse_proxy: ReverseProxy, pub reverse_proxy: ReverseProxy,
pub https_redirection: Option<bool>,
// tls settings
pub tls_cert_path: Option<PathBuf>, pub tls_cert_path: Option<PathBuf>,
pub tls_cert_key_path: Option<PathBuf>, pub tls_cert_key_path: Option<PathBuf>,
pub https_redirection: Option<bool>,
pub server_config: Mutex<Option<ServerConfig>>, pub server_config: Mutex<Option<ServerConfig>>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ReverseProxy { pub struct ReverseProxy {
pub default_upstream: Upstream, pub default_upstream: Option<Upstream>,
pub upstream: HashMap<String, Upstream>, pub upstream: HashMap<String, Upstream>,
} }

View file

@ -1,3 +1,4 @@
mod parse; mod parse;
mod toml;
pub use parse::parse_opts; pub use parse::parse_opts;

View file

@ -1,60 +1,141 @@
use crate::{backend::*, constants::*, globals::*}; use super::toml::{ConfigToml, ReverseProxyOption};
use hyper::Uri; use crate::{backend::*, constants::*, error::*, globals::*, log::*};
use clap::Arg;
use std::net::SocketAddr;
use std::{collections::HashMap, sync::Mutex}; use std::{collections::HashMap, sync::Mutex};
// #[cfg(feature = "tls")] // #[cfg(feature = "tls")]
use std::path::PathBuf; use std::path::PathBuf;
pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend>) { pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend>) -> Result<()> {
// TODO: let _ = include_str!("../../Cargo.toml");
let options = clap::command!().arg(
Arg::new("config_file")
.long("config")
.short('c')
.takes_value(true)
.help("Configuration file path like \"./config.toml\""),
);
let matches = options.get_matches();
let config = if let Some(config_file_path) = matches.value_of("config_file") {
ConfigToml::new(config_file_path)?
} else {
// Default config Toml
ConfigToml::default()
};
// listen port and scket
globals.http_port = config.listen_port;
globals.https_port = config.listen_port_tls;
ensure!(
{ globals.http_port.is_some() || globals.https_port.is_some() } && {
if let (Some(p), Some(t)) = (globals.http_port, globals.https_port) {
p != t
} else {
true
}
},
anyhow!("Wrong port spec.")
);
globals.listen_sockets = LISTEN_ADDRESSES globals.listen_sockets = LISTEN_ADDRESSES
.to_vec() .to_vec()
.iter() .iter()
.flat_map(|x| { .flat_map(|x| {
vec![ let mut v: Vec<SocketAddr> = vec![];
format!("{}:{}", x, HTTP_LISTEN_PORT).parse().unwrap(), if let Some(p) = globals.http_port {
format!("{}:{}", x, HTTPS_LISTEN_PORT).parse().unwrap(), v.push(format!("{}:{}", x, p).parse().unwrap());
] }
if let Some(p) = globals.https_port {
v.push(format!("{}:{}", x, p).parse().unwrap());
}
v
}) })
.collect(); .collect();
globals.http_port = Some(HTTP_LISTEN_PORT); if globals.http_port.is_some() {
globals.https_port = Some(HTTPS_LISTEN_PORT); info!("Listen port: {}", globals.http_port.unwrap());
}
if globals.https_port.is_some() {
info!("Listen port: {} (for TLS)", globals.https_port.unwrap());
}
// TODO: // backend apps
let mut map_example: HashMap<String, Upstream> = HashMap::new(); ensure!(config.apps.is_some(), "Missing application spec.");
map_example.insert( let apps = config.apps.unwrap();
"/maps".to_string(), ensure!(!apps.0.is_empty(), "Wrong application spec.");
Upstream {
uri: vec![ // each app
"https://www.bing.com".parse::<Uri>().unwrap(), for (app_name, app) in apps.0.iter() {
"https://www.bing.co.jp".parse::<Uri>().unwrap(), ensure!(app.server_name.is_some(), "Missing server_name");
], let server_name = app.server_name.as_ref().unwrap();
// TLS settings
let (tls_cert_path, tls_cert_key_path, https_redirection) = if app.tls.is_none() {
ensure!(globals.http_port.is_some(), "Required HTTP port");
(None, None, None)
} else {
let tls = app.tls.as_ref().unwrap();
ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some());
(
tls.tls_cert_path.as_ref().map(PathBuf::from),
tls.tls_cert_key_path.as_ref().map(PathBuf::from),
if tls.https_redirection.is_none() {
Some(true) // Default true
} else {
ensure!(globals.https_port.is_some()); // only when both https ports are configured.
tls.https_redirection
},
)
};
if globals.http_port.is_none() {
// if only https_port is specified, tls must be configured
ensure!(app.tls.is_some())
}
// reverse proxy settings
ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy");
let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?;
backends.insert(
server_name.to_owned(),
Backend {
app_name: app_name.to_owned(),
server_name: server_name.to_owned(),
reverse_proxy,
tls_cert_path,
tls_cert_key_path,
https_redirection,
server_config: Mutex::new(None),
},
);
info!("Registering application: {} ({})", app_name, server_name);
}
Ok(())
}
fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> Result<ReverseProxy> {
let mut upstream: HashMap<String, Upstream> = HashMap::new();
let mut default_upstream: Option<Upstream> = None;
rp_settings.iter().for_each(|rpo| {
let elem = Upstream {
uri: rpo.upstream.iter().map(|x| x.to_uri().unwrap()).collect(),
cnt: Default::default(), cnt: Default::default(),
lb: Default::default(), lb: Default::default(),
}, };
); if rpo.path.is_some() {
backends.insert( upstream.insert(rpo.path.as_ref().unwrap().to_owned(), elem);
"localhost".to_string(), } else {
Backend { default_upstream = Some(elem)
app_name: "Localhost to Google except for maps".to_string(), }
hostname: "localhost".to_string(), });
reverse_proxy: ReverseProxy { ensure!(
default_upstream: Upstream { rp_settings.iter().filter(|rpo| rpo.path.is_none()).count() < 2,
uri: vec![ "Multiple default reverse proxy setting"
"https://www.google.com".parse::<Uri>().unwrap(),
"https://www.google.co.jp".parse::<Uri>().unwrap(),
],
cnt: Default::default(),
lb: Default::default(),
},
// default_upstream_uri: vec!["http://abehiroshi.la.coocan.jp/".parse::<Uri>().unwrap()], // httpのみの場合の好例
upstream: map_example,
},
https_redirection: Some(false), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず
tls_cert_path: Some(PathBuf::from(r"localhost1.pem")),
tls_cert_key_path: Some(PathBuf::from(r"localhost1.pem")),
server_config: Mutex::new(None),
},
); );
Ok(ReverseProxy {
default_upstream,
upstream,
})
} }

64
src/config/toml.rs Normal file
View file

@ -0,0 +1,64 @@
use crate::error::*;
use serde::Deserialize;
use std::{collections::HashMap, fs};
#[derive(Deserialize, Debug, Default)]
pub struct ConfigToml {
pub listen_port: Option<u16>,
pub listen_port_tls: Option<u16>,
pub apps: Option<Apps>,
}
#[derive(Deserialize, Debug, Default)]
pub struct Apps(pub HashMap<String, Application>);
#[derive(Deserialize, Debug, Default)]
pub struct Application {
pub server_name: Option<String>,
pub reverse_proxy: Option<Vec<ReverseProxyOption>>,
pub tls: Option<TlsOption>,
}
#[derive(Deserialize, Debug, Default)]
pub struct TlsOption {
pub tls_cert_path: Option<String>,
pub tls_cert_key_path: Option<String>,
pub https_redirection: Option<bool>,
}
#[derive(Deserialize, Debug, Default)]
pub struct ReverseProxyOption {
pub path: Option<String>,
pub upstream: Vec<UpstreamOption>,
}
#[derive(Deserialize, Debug, Default)]
pub struct UpstreamOption {
pub location: String,
pub tls: Option<bool>,
}
impl UpstreamOption {
pub fn to_uri(&self) -> Result<hyper::Uri> {
let mut scheme = "http";
if let Some(t) = self.tls {
if t {
scheme = "https";
}
}
let location = format!("{}://{}", scheme, self.location);
location.parse::<hyper::Uri>().map_err(|e| anyhow!("{}", e))
}
}
impl ConfigToml {
pub fn new(config_file: &str) -> Result<Self> {
let config_str = if let Ok(s) = fs::read_to_string(config_file) {
s
} else {
bail!("Failed to read config file");
};
let parsed: Result<ConfigToml> = toml::from_str(&config_str)
.map_err(|e: toml::de::Error| anyhow!("Failed to parse toml config: {:?}", e));
parsed
}
}

View file

@ -57,7 +57,7 @@ fn main() {
let mut backends: HashMap<String, Backend> = HashMap::new(); let mut backends: HashMap<String, Backend> = HashMap::new();
parse_opts(&mut globals, &mut backends); let _ = parse_opts(&mut globals, &mut backends).expect("Invalid configuration");
entrypoint(Arc::new(globals), Arc::new(backends)) entrypoint(Arc::new(globals), Arc::new(backends))
.await .await
@ -78,8 +78,6 @@ async fn entrypoint(globals: Arc<Globals>, backends: Arc<HashMap<String, Backend
tls_enabled = https_port == (addr.port() as u16) tls_enabled = https_port == (addr.port() as u16)
} }
info!("Listen address: {:?} (TLS = {})", addr, tls_enabled);
let proxy = Proxy { let proxy = Proxy {
globals: globals.clone(), globals: globals.clone(),
listening_on: addr, listening_on: addr,

View file

@ -32,14 +32,14 @@ where
client_addr: SocketAddr, // アクセス制御用 client_addr: SocketAddr, // アクセス制御用
) -> Result<Response<Body>> { ) -> Result<Response<Body>> {
debug!("Handling request: {:?}", req); debug!("Handling request: {:?}", req);
// Here we start to handle with hostname // Here we start to handle with server_name
// Find backend application for given hostname // Find backend application for given server_name
let (hostname, _port) = if let Ok(v) = parse_host_port(&req, self.tls_enabled) { let (server_name, _port) = if let Ok(v) = parse_host_port(&req, self.tls_enabled) {
v v
} else { } else {
return http_error(StatusCode::SERVICE_UNAVAILABLE); return http_error(StatusCode::SERVICE_UNAVAILABLE);
}; };
let backend = if let Some(be) = self.backends.get(hostname.as_str()) { let backend = if let Some(be) = self.backends.get(server_name.as_str()) {
be be
} else { } else {
return http_error(StatusCode::SERVICE_UNAVAILABLE); return http_error(StatusCode::SERVICE_UNAVAILABLE);
@ -48,16 +48,19 @@ where
// Redirect to https if tls_enabled is false and redirect_to_https is true // Redirect to https if tls_enabled is false and redirect_to_https is true
let path_and_query = req.uri().path_and_query().unwrap().as_str().to_owned(); let path_and_query = req.uri().path_and_query().unwrap().as_str().to_owned();
if !self.tls_enabled && backend.https_redirection.unwrap_or(false) { if !self.tls_enabled && backend.https_redirection.unwrap_or(false) {
debug!("Redirect to secure connection: {}", hostname); debug!("Redirect to secure connection: {}", server_name);
return secure_redirection(&hostname, self.globals.https_port, &path_and_query); return secure_redirection(&server_name, self.globals.https_port, &path_and_query);
} }
// Find reverse proxy for given path and choose one of upstream host // Find reverse proxy for given path and choose one of upstream host
// TODO: More flexible path matcher
let path = req.uri().path(); let path = req.uri().path();
let upstream_uri = if let Some(upstream) = backend.reverse_proxy.upstream.get(path) { let upstream_uri = if let Some(upstream) = backend.reverse_proxy.upstream.get(path) {
upstream.get() upstream.get()
} else if let Some(default_upstream) = &backend.reverse_proxy.default_upstream {
default_upstream.get()
} else { } else {
backend.reverse_proxy.default_upstream.get() return http_error(StatusCode::NOT_FOUND);
}; };
let upstream_scheme_host = if let Some(u) = upstream_uri { let upstream_scheme_host = if let Some(u) = upstream_uri {
u u
@ -263,15 +266,15 @@ fn extract_upgrade(headers: &HeaderMap) -> Option<String> {
} }
fn secure_redirection( fn secure_redirection(
hostname: &str, server_name: &str,
tls_port: Option<u16>, tls_port: Option<u16>,
path_and_query: &str, path_and_query: &str,
) -> Result<Response<Body>> { ) -> Result<Response<Body>> {
let dest_uri: String = if let Some(tls_port) = tls_port { let dest_uri: String = if let Some(tls_port) = tls_port {
if tls_port == 443 { if tls_port == 443 {
format!("https://{}{}", hostname, path_and_query) format!("https://{}{}", server_name, path_and_query)
} else { } else {
format!("https://{}:{}{}", hostname, tls_port, path_and_query) format!("https://{}:{}{}", server_name, tls_port, path_and_query)
} }
} else { } else {
bail!("Internal error! TLS port is not set internally."); bail!("Internal error! TLS port is not set internally.");
@ -285,15 +288,15 @@ fn secure_redirection(
} }
fn parse_host_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String, u16)> { fn parse_host_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String, u16)> {
let hostname_port_headers = req.headers().get("host"); let host_port_headers = req.headers().get("host");
let hostname_uri = req.uri().host(); let host_uri = req.uri().host();
let port_uri = req.uri().port_u16(); let port_uri = req.uri().port_u16();
if hostname_port_headers.is_none() && hostname_uri.is_none() { if host_port_headers.is_none() && host_uri.is_none() {
bail!("No host in request header"); bail!("No host in request header");
} }
let (hostname, port) = match (hostname_uri, hostname_port_headers) { let (host, port) = match (host_uri, host_port_headers) {
(Some(x), _) => { (Some(x), _) => {
let port = if let Some(p) = port_uri { let port = if let Some(p) = port_uri {
p p
@ -306,9 +309,9 @@ fn parse_host_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String, u1
} }
(None, Some(x)) => { (None, Some(x)) => {
let hp_as_uri = x.to_str().unwrap().parse::<Uri>().unwrap(); let hp_as_uri = x.to_str().unwrap().parse::<Uri>().unwrap();
let hostname = hp_as_uri let host = hp_as_uri
.host() .host()
.ok_or_else(|| anyhow!("Failed to parse hostname"))?; .ok_or_else(|| anyhow!("Failed to parse host"))?;
let port = if let Some(p) = hp_as_uri.port() { let port = if let Some(p) = hp_as_uri.port() {
p.as_u16() p.as_u16()
} else if tls_enabled { } else if tls_enabled {
@ -316,38 +319,12 @@ fn parse_host_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String, u1
} else { } else {
80 80
}; };
(hostname.to_string(), port) (host.to_string(), port)
} }
(None, None) => { (None, None) => {
bail!("Host unspecified in request") bail!("Host unspecified in request")
} }
}; };
Ok((hostname, port)) Ok((host, port))
} }
// fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
// #[allow(clippy::blocks_in_if_conditions)]
// if headers
// .get(&*CONNECTION_HEADER)
// .map(|value| {
// value
// .to_str()
// .unwrap()
// .split(',')
// .any(|e| e.trim() == *UPGRADE_HEADER)
// })
// .unwrap_or(false)
// {
// if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
// debug!(
// "Found upgrade header with value: {}",
// upgrade_value.to_str().unwrap().to_owned()
// );
// return Some(upgrade_value.to_str().unwrap().to_owned());
// }
// }
// None
// }

View file

@ -39,7 +39,7 @@ where
{ {
pub listening_on: SocketAddr, pub listening_on: SocketAddr,
pub tls_enabled: bool, // TCP待受がTLSかどうか pub tls_enabled: bool, // TCP待受がTLSかどうか
pub backends: Arc<HashMap<String, Backend>>, // TODO: hyper::uriで抜いたhostnameで引っ掛ける。Stringでいいのか pub backends: Arc<HashMap<String, Backend>>, // TODO: hyper::uriで抜いたhostで引っ掛ける。Stringでいいのか
pub forwarder: Arc<Client<T>>, pub forwarder: Arc<Client<T>>,
pub globals: Arc<Globals>, pub globals: Arc<Globals>,
} }

View file

@ -17,10 +17,10 @@ where
let cert_service = async { let cert_service = async {
info!("Start cert watch service for {}", self.listening_on); info!("Start cert watch service for {}", self.listening_on);
loop { loop {
for (hostname, backend) in self.backends.iter() { for (server_name, backend) in self.backends.iter() {
if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() {
if let Err(_e) = backend.update_server_config().await { if let Err(_e) = backend.update_server_config().await {
warn!("Failed to update certs for {}", hostname); warn!("Failed to update certs for {}", server_name);
} }
} }
} }
@ -59,9 +59,19 @@ where
info!("No configuration for the server name {} given in client_hello", svn); info!("No configuration for the server name {} given in client_hello", svn);
continue; continue;
}; };
let server_config = backend_serve.get_tls_server_config();
if backend_serve.tls_cert_path.is_none() { // at least cert does exit
debug!("SNI indicates a site that doesn't support TLS.");
continue;
}
let server_config = if let Some(p) = backend_serve.get_tls_server_config(){
p
} else {
error!("Failed to load server config");
continue;
};
// Finally serve the TLS connection // Finally serve the TLS connection
if let Ok(stream) = start.into_stream(Arc::new(server_config.unwrap())).await { if let Ok(stream) = start.into_stream(Arc::new(server_config)).await {
self.clone().client_serve(stream, server.clone(), _client_addr).await self.clone().client_serve(stream, server.clone(), _client_addr).await
} }
} }