diff --git a/Cargo.toml b/Cargo.toml index 21211a4..87f2273 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,8 @@ hyper-trust-dns = { version = "0.4.2", default-features = false, features = [ ] } rustls = "0.20.6" rand = "0.8.5" +toml = "0.5.9" +serde = { version = "1.0.137", features = ["derive"] } [dev-dependencies] diff --git a/config-example.toml b/config-example.toml index b8b846f..4d69180 100644 --- a/config-example.toml +++ b/config-example.toml @@ -6,20 +6,20 @@ ################################### # Global settings # ################################### -http_port = 8080 -https_port = 8443 +# Both or either one of http/https ports must be specified +listen_port = 8080 +listen_port_tls = 8443 ################################### # Backend settings # ################################### +[application] -[[application]] -app_name = 'localhost' # this should be option, if null then same as server_name -hostname = 'localhost' -https_redirection = true +[apps.localhost] +server_name = 'localhost' reverse_proxy = [ # default destination if path is not specified - # TODO: Array for load balancing + # Array for load balancing { upstream = [ { location = 'www.google.com', tls = true }, { location = 'www.google.co.jp', tls = true }, @@ -29,18 +29,13 @@ reverse_proxy = [ { location = 'www.bing.co.jp', tls = true }, ] }, ] +tls = { https_redirection = true, tls_cert_path = 'localhost.pem', tls_cert_key_path = 'localhost.pem' } + + ## List of destinations to send data to. ## At this point, round-robin is used for load-balancing if multiple URLs are specified. # allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24'] # TODO # denyhosts = ['*'] # TODO -tls_cert_path = 'localhost.pem' -tls_cert_key_path = 'localhost.pem' - - -[[application]] -app_name = 'locahost_application' -hostname = 'localhost.localdomain' -https_redirection = true +[apps.another_localhost] +server_name = 'localhost.localdomain' reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }] -tls_cert_path = 'localhost.pem' -tls_cert_key_path = 'localhost.pem' diff --git a/src/backend.rs b/src/backend.rs index aebb31e..e67f80e 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -14,17 +14,19 @@ use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; pub struct Backend { pub app_name: String, - pub hostname: String, + pub server_name: String, pub reverse_proxy: ReverseProxy, - pub https_redirection: Option, + + // tls settings pub tls_cert_path: Option, pub tls_cert_key_path: Option, + pub https_redirection: Option, pub server_config: Mutex>, } #[derive(Debug, Clone)] pub struct ReverseProxy { - pub default_upstream: Upstream, + pub default_upstream: Option, pub upstream: HashMap, } diff --git a/src/config/mod.rs b/src/config/mod.rs index 9f0d40a..6e8123c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,3 +1,4 @@ mod parse; +mod toml; pub use parse::parse_opts; diff --git a/src/config/parse.rs b/src/config/parse.rs index 207bcf9..f35ccf2 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -1,60 +1,141 @@ -use crate::{backend::*, constants::*, globals::*}; -use hyper::Uri; +use super::toml::{ConfigToml, ReverseProxyOption}; +use crate::{backend::*, constants::*, error::*, globals::*, log::*}; +use clap::Arg; +use std::net::SocketAddr; use std::{collections::HashMap, sync::Mutex}; // #[cfg(feature = "tls")] use std::path::PathBuf; -pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap) { - // TODO: +pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap) -> Result<()> { + let _ = include_str!("../../Cargo.toml"); + let options = clap::command!().arg( + Arg::new("config_file") + .long("config") + .short('c') + .takes_value(true) + .help("Configuration file path like \"./config.toml\""), + ); + let matches = options.get_matches(); + + let config = if let Some(config_file_path) = matches.value_of("config_file") { + ConfigToml::new(config_file_path)? + } else { + // Default config Toml + ConfigToml::default() + }; + + // listen port and scket + globals.http_port = config.listen_port; + globals.https_port = config.listen_port_tls; + ensure!( + { globals.http_port.is_some() || globals.https_port.is_some() } && { + if let (Some(p), Some(t)) = (globals.http_port, globals.https_port) { + p != t + } else { + true + } + }, + anyhow!("Wrong port spec.") + ); globals.listen_sockets = LISTEN_ADDRESSES .to_vec() .iter() .flat_map(|x| { - vec![ - format!("{}:{}", x, HTTP_LISTEN_PORT).parse().unwrap(), - format!("{}:{}", x, HTTPS_LISTEN_PORT).parse().unwrap(), - ] + let mut v: Vec = vec![]; + if let Some(p) = globals.http_port { + v.push(format!("{}:{}", x, p).parse().unwrap()); + } + if let Some(p) = globals.https_port { + v.push(format!("{}:{}", x, p).parse().unwrap()); + } + v }) .collect(); - globals.http_port = Some(HTTP_LISTEN_PORT); - globals.https_port = Some(HTTPS_LISTEN_PORT); + if globals.http_port.is_some() { + info!("Listen port: {}", globals.http_port.unwrap()); + } + if globals.https_port.is_some() { + info!("Listen port: {} (for TLS)", globals.https_port.unwrap()); + } - // TODO: - let mut map_example: HashMap = HashMap::new(); - map_example.insert( - "/maps".to_string(), - Upstream { - uri: vec![ - "https://www.bing.com".parse::().unwrap(), - "https://www.bing.co.jp".parse::().unwrap(), - ], + // backend apps + ensure!(config.apps.is_some(), "Missing application spec."); + let apps = config.apps.unwrap(); + ensure!(!apps.0.is_empty(), "Wrong application spec."); + + // each app + for (app_name, app) in apps.0.iter() { + ensure!(app.server_name.is_some(), "Missing server_name"); + let server_name = app.server_name.as_ref().unwrap(); + + // TLS settings + let (tls_cert_path, tls_cert_key_path, https_redirection) = if app.tls.is_none() { + ensure!(globals.http_port.is_some(), "Required HTTP port"); + (None, None, None) + } else { + let tls = app.tls.as_ref().unwrap(); + ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some()); + + ( + tls.tls_cert_path.as_ref().map(PathBuf::from), + tls.tls_cert_key_path.as_ref().map(PathBuf::from), + if tls.https_redirection.is_none() { + Some(true) // Default true + } else { + ensure!(globals.https_port.is_some()); // only when both https ports are configured. + tls.https_redirection + }, + ) + }; + if globals.http_port.is_none() { + // if only https_port is specified, tls must be configured + ensure!(app.tls.is_some()) + } + + // reverse proxy settings + ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); + let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; + + backends.insert( + server_name.to_owned(), + Backend { + app_name: app_name.to_owned(), + server_name: server_name.to_owned(), + reverse_proxy, + + tls_cert_path, + tls_cert_key_path, + https_redirection, + server_config: Mutex::new(None), + }, + ); + info!("Registering application: {} ({})", app_name, server_name); + } + Ok(()) +} + +fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> Result { + let mut upstream: HashMap = HashMap::new(); + let mut default_upstream: Option = None; + rp_settings.iter().for_each(|rpo| { + let elem = Upstream { + uri: rpo.upstream.iter().map(|x| x.to_uri().unwrap()).collect(), cnt: Default::default(), lb: Default::default(), - }, - ); - backends.insert( - "localhost".to_string(), - Backend { - app_name: "Localhost to Google except for maps".to_string(), - hostname: "localhost".to_string(), - reverse_proxy: ReverseProxy { - default_upstream: Upstream { - uri: vec![ - "https://www.google.com".parse::().unwrap(), - "https://www.google.co.jp".parse::().unwrap(), - ], - cnt: Default::default(), - lb: Default::default(), - }, - // default_upstream_uri: vec!["http://abehiroshi.la.coocan.jp/".parse::().unwrap()], // httpのみの場合の好例 - upstream: map_example, - }, - https_redirection: Some(false), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず - - tls_cert_path: Some(PathBuf::from(r"localhost1.pem")), - tls_cert_key_path: Some(PathBuf::from(r"localhost1.pem")), - server_config: Mutex::new(None), - }, + }; + if rpo.path.is_some() { + upstream.insert(rpo.path.as_ref().unwrap().to_owned(), elem); + } else { + default_upstream = Some(elem) + } + }); + ensure!( + rp_settings.iter().filter(|rpo| rpo.path.is_none()).count() < 2, + "Multiple default reverse proxy setting" ); + Ok(ReverseProxy { + default_upstream, + upstream, + }) } diff --git a/src/config/toml.rs b/src/config/toml.rs new file mode 100644 index 0000000..e614dbe --- /dev/null +++ b/src/config/toml.rs @@ -0,0 +1,64 @@ +use crate::error::*; +use serde::Deserialize; +use std::{collections::HashMap, fs}; + +#[derive(Deserialize, Debug, Default)] +pub struct ConfigToml { + pub listen_port: Option, + pub listen_port_tls: Option, + pub apps: Option, +} + +#[derive(Deserialize, Debug, Default)] +pub struct Apps(pub HashMap); + +#[derive(Deserialize, Debug, Default)] +pub struct Application { + pub server_name: Option, + pub reverse_proxy: Option>, + pub tls: Option, +} + +#[derive(Deserialize, Debug, Default)] +pub struct TlsOption { + pub tls_cert_path: Option, + pub tls_cert_key_path: Option, + pub https_redirection: Option, +} + +#[derive(Deserialize, Debug, Default)] +pub struct ReverseProxyOption { + pub path: Option, + pub upstream: Vec, +} + +#[derive(Deserialize, Debug, Default)] +pub struct UpstreamOption { + pub location: String, + pub tls: Option, +} +impl UpstreamOption { + pub fn to_uri(&self) -> Result { + let mut scheme = "http"; + if let Some(t) = self.tls { + if t { + scheme = "https"; + } + } + let location = format!("{}://{}", scheme, self.location); + location.parse::().map_err(|e| anyhow!("{}", e)) + } +} + +impl ConfigToml { + pub fn new(config_file: &str) -> Result { + let config_str = if let Ok(s) = fs::read_to_string(config_file) { + s + } else { + bail!("Failed to read config file"); + }; + let parsed: Result = toml::from_str(&config_str) + .map_err(|e: toml::de::Error| anyhow!("Failed to parse toml config: {:?}", e)); + parsed + } +} diff --git a/src/main.rs b/src/main.rs index 0907fa2..74287b2 100644 --- a/src/main.rs +++ b/src/main.rs @@ -57,7 +57,7 @@ fn main() { let mut backends: HashMap = HashMap::new(); - parse_opts(&mut globals, &mut backends); + let _ = parse_opts(&mut globals, &mut backends).expect("Invalid configuration"); entrypoint(Arc::new(globals), Arc::new(backends)) .await @@ -78,8 +78,6 @@ async fn entrypoint(globals: Arc, backends: Arc Result> { debug!("Handling request: {:?}", req); - // Here we start to handle with hostname - // Find backend application for given hostname - let (hostname, _port) = if let Ok(v) = parse_host_port(&req, self.tls_enabled) { + // Here we start to handle with server_name + // Find backend application for given server_name + let (server_name, _port) = if let Ok(v) = parse_host_port(&req, self.tls_enabled) { v } else { return http_error(StatusCode::SERVICE_UNAVAILABLE); }; - let backend = if let Some(be) = self.backends.get(hostname.as_str()) { + let backend = if let Some(be) = self.backends.get(server_name.as_str()) { be } else { return http_error(StatusCode::SERVICE_UNAVAILABLE); @@ -48,16 +48,19 @@ where // Redirect to https if tls_enabled is false and redirect_to_https is true let path_and_query = req.uri().path_and_query().unwrap().as_str().to_owned(); if !self.tls_enabled && backend.https_redirection.unwrap_or(false) { - debug!("Redirect to secure connection: {}", hostname); - return secure_redirection(&hostname, self.globals.https_port, &path_and_query); + debug!("Redirect to secure connection: {}", server_name); + return secure_redirection(&server_name, self.globals.https_port, &path_and_query); } // Find reverse proxy for given path and choose one of upstream host + // TODO: More flexible path matcher let path = req.uri().path(); let upstream_uri = if let Some(upstream) = backend.reverse_proxy.upstream.get(path) { upstream.get() + } else if let Some(default_upstream) = &backend.reverse_proxy.default_upstream { + default_upstream.get() } else { - backend.reverse_proxy.default_upstream.get() + return http_error(StatusCode::NOT_FOUND); }; let upstream_scheme_host = if let Some(u) = upstream_uri { u @@ -263,15 +266,15 @@ fn extract_upgrade(headers: &HeaderMap) -> Option { } fn secure_redirection( - hostname: &str, + server_name: &str, tls_port: Option, path_and_query: &str, ) -> Result> { let dest_uri: String = if let Some(tls_port) = tls_port { if tls_port == 443 { - format!("https://{}{}", hostname, path_and_query) + format!("https://{}{}", server_name, path_and_query) } else { - format!("https://{}:{}{}", hostname, tls_port, path_and_query) + format!("https://{}:{}{}", server_name, tls_port, path_and_query) } } else { bail!("Internal error! TLS port is not set internally."); @@ -285,15 +288,15 @@ fn secure_redirection( } fn parse_host_port(req: &Request, tls_enabled: bool) -> Result<(String, u16)> { - let hostname_port_headers = req.headers().get("host"); - let hostname_uri = req.uri().host(); + let host_port_headers = req.headers().get("host"); + let host_uri = req.uri().host(); let port_uri = req.uri().port_u16(); - if hostname_port_headers.is_none() && hostname_uri.is_none() { + if host_port_headers.is_none() && host_uri.is_none() { bail!("No host in request header"); } - let (hostname, port) = match (hostname_uri, hostname_port_headers) { + let (host, port) = match (host_uri, host_port_headers) { (Some(x), _) => { let port = if let Some(p) = port_uri { p @@ -306,9 +309,9 @@ fn parse_host_port(req: &Request, tls_enabled: bool) -> Result<(String, u1 } (None, Some(x)) => { let hp_as_uri = x.to_str().unwrap().parse::().unwrap(); - let hostname = hp_as_uri + let host = hp_as_uri .host() - .ok_or_else(|| anyhow!("Failed to parse hostname"))?; + .ok_or_else(|| anyhow!("Failed to parse host"))?; let port = if let Some(p) = hp_as_uri.port() { p.as_u16() } else if tls_enabled { @@ -316,38 +319,12 @@ fn parse_host_port(req: &Request, tls_enabled: bool) -> Result<(String, u1 } else { 80 }; - (hostname.to_string(), port) + (host.to_string(), port) } (None, None) => { bail!("Host unspecified in request") } }; - Ok((hostname, port)) + Ok((host, port)) } - -// fn get_upgrade_type(headers: &HeaderMap) -> Option { -// #[allow(clippy::blocks_in_if_conditions)] -// if headers -// .get(&*CONNECTION_HEADER) -// .map(|value| { -// value -// .to_str() -// .unwrap() -// .split(',') -// .any(|e| e.trim() == *UPGRADE_HEADER) -// }) -// .unwrap_or(false) -// { -// if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) { -// debug!( -// "Found upgrade header with value: {}", -// upgrade_value.to_str().unwrap().to_owned() -// ); - -// return Some(upgrade_value.to_str().unwrap().to_owned()); -// } -// } - -// None -// } diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index c260204..cdf0bfb 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -39,7 +39,7 @@ where { pub listening_on: SocketAddr, pub tls_enabled: bool, // TCP待受がTLSかどうか - pub backends: Arc>, // TODO: hyper::uriで抜いたhostnameで引っ掛ける。Stringでいいのか? + pub backends: Arc>, // TODO: hyper::uriで抜いたhostで引っ掛ける。Stringでいいのか? pub forwarder: Arc>, pub globals: Arc, } diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 1af0068..6e539d9 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -17,10 +17,10 @@ where let cert_service = async { info!("Start cert watch service for {}", self.listening_on); loop { - for (hostname, backend) in self.backends.iter() { + for (server_name, backend) in self.backends.iter() { if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { if let Err(_e) = backend.update_server_config().await { - warn!("Failed to update certs for {}", hostname); + warn!("Failed to update certs for {}", server_name); } } } @@ -59,9 +59,19 @@ where info!("No configuration for the server name {} given in client_hello", svn); continue; }; - let server_config = backend_serve.get_tls_server_config(); + + if backend_serve.tls_cert_path.is_none() { // at least cert does exit + debug!("SNI indicates a site that doesn't support TLS."); + continue; + } + let server_config = if let Some(p) = backend_serve.get_tls_server_config(){ + p + } else { + error!("Failed to load server config"); + continue; + }; // Finally serve the TLS connection - if let Ok(stream) = start.into_stream(Arc::new(server_config.unwrap())).await { + if let Ok(stream) = start.into_stream(Arc::new(server_config)).await { self.clone().client_serve(stream, server.clone(), _client_addr).await } }