diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 4aa5f7a..97753f2 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -212,21 +212,13 @@ impl Backend { } } +#[derive(Default)] /// HashMap and some meta information for multiple Backend structs. pub struct Backends { pub apps: HashMap, // hyper::uriで抜いたhostで引っ掛ける pub default_server_name_bytes: Option, // for plaintext http } -impl Default for Backends { - fn default() -> Self { - Self { - default_server_name_bytes: None, - apps: HashMap::::default(), - } - } -} - pub type SniServerCryptoMap = HashMap>; pub struct ServerCrypto { // For Quic/HTTP3, only servers with no client authentication diff --git a/src/config/mod.rs b/src/config/mod.rs index 6e8123c..54b2600 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,4 +1,4 @@ mod parse; mod toml; -pub use parse::parse_opts; +pub use parse::build_globals; diff --git a/src/config/parse.rs b/src/config/parse.rs index 1b13d30..1593aba 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -1,17 +1,9 @@ -use super::toml::{ConfigToml, ReverseProxyOption}; -use crate::{ - backend::{BackendBuilder, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption}, - constants::*, - error::*, - globals::*, - log::*, - utils::{BytesName, PathNameBytesExp}, -}; +use super::toml::ConfigToml; +use crate::{backend::Backends, error::*, globals::*, log::*, utils::BytesName}; use clap::Arg; -use rustc_hash::FxHashMap as HashMap; -use std::net::SocketAddr; +use tokio::runtime::Handle; -pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Error> { +pub fn build_globals(runtime_handle: Handle) -> std::result::Result { let _ = include_str!("../../Cargo.toml"); let options = clap::command!().arg( Arg::new("config_file") @@ -22,6 +14,7 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro ); let matches = options.get_matches(); + /////////////////////////////////// let config = if let Some(config_file_path) = matches.get_one::("config_file") { ConfigToml::new(config_file_path)? } else { @@ -29,117 +22,67 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro ConfigToml::default() }; - // listen port and socket - globals.proxy_config.http_port = config.listen_port; - globals.proxy_config.https_port = config.listen_port_tls; - ensure!( - { globals.proxy_config.http_port.is_some() || globals.proxy_config.https_port.is_some() } && { - if let (Some(p), Some(t)) = (globals.proxy_config.http_port, globals.proxy_config.https_port) { - p != t - } else { - true - } - }, - anyhow!("Wrong port spec.") - ); - // NOTE: when [::]:xx is bound, both v4 and v6 listeners are enabled. - let listen_addresses: Vec<&str> = match config.listen_ipv6 { - Some(true) => { - info!("Listen both IPv4 and IPv6"); - LISTEN_ADDRESSES_V6.to_vec() - } - Some(false) | None => { - info!("Listen IPv4"); - LISTEN_ADDRESSES_V4.to_vec() - } - }; - globals.proxy_config.listen_sockets = listen_addresses - .iter() - .flat_map(|x| { - let mut v: Vec = vec![]; - if let Some(p) = globals.proxy_config.http_port { - v.push(format!("{x}:{p}").parse().unwrap()); - } - if let Some(p) = globals.proxy_config.https_port { - v.push(format!("{x}:{p}").parse().unwrap()); - } - v - }) - .collect(); - if globals.proxy_config.http_port.is_some() { - info!("Listen port: {}", globals.proxy_config.http_port.unwrap()); + /////////////////////////////////// + // build proxy config + let proxy_config: ProxyConfig = (&config).try_into()?; + // For loggings + if proxy_config.listen_sockets.iter().any(|addr| addr.is_ipv6()) { + info!("Listen both IPv4 and IPv6") + } else { + info!("Listen IPv4") } - if globals.proxy_config.https_port.is_some() { - info!("Listen port: {} (for TLS)", globals.proxy_config.https_port.unwrap()); + if proxy_config.http_port.is_some() { + info!("Listen port: {}", proxy_config.http_port.unwrap()); + } + if proxy_config.https_port.is_some() { + info!("Listen port: {} (for TLS)", proxy_config.https_port.unwrap()); + } + if proxy_config.http3 { + info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable."); + } + if !proxy_config.sni_consistency { + info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC."); } - // max values - if let Some(c) = config.max_clients { - globals.proxy_config.max_clients = c as usize; - } - if let Some(c) = config.max_concurrent_streams { - globals.proxy_config.max_concurrent_streams = c; - } + /////////////////////////////////// + // backend_apps + let apps = config.apps.ok_or(anyhow!("Missing application spec"))?; - // backend apps - ensure!(config.apps.is_some(), "Missing application spec."); - let apps = config.apps.unwrap(); + // assertions for all backend apps ensure!(!apps.0.is_empty(), "Wrong application spec."); + // if only https_port is specified, tls must be configured for all apps + if proxy_config.http_port.is_none() { + ensure!( + apps.0.iter().all(|(_, app)| app.tls.is_some()), + "Some apps serves only plaintext HTTP" + ); + } + // https redirection can be configured if both ports are active + if !(proxy_config.https_port.is_some() && proxy_config.http_port.is_some()) { + ensure!( + apps.0.iter().all(|(_, app)| { + if let Some(tls) = app.tls.as_ref() { + tls.https_redirection.is_none() + } else { + true + } + }), + "https_redirection can be specified only when both http_port and https_port are specified" + ); + } - // each app + // build backends + let mut backends = Backends::default(); for (app_name, app) in apps.0.iter() { - ensure!(app.server_name.is_some(), "Missing server_name"); - let server_name_string = app.server_name.as_ref().unwrap(); - if globals.proxy_config.http_port.is_none() { - // if only https_port is specified, tls must be configured - ensure!(app.tls.is_some()) - } - - // backend builder - let mut backend_builder = BackendBuilder::default(); - // reverse proxy settings - ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); - let reverse_proxy = get_reverse_proxy(server_name_string, app.reverse_proxy.as_ref().unwrap())?; - - backend_builder - .app_name(server_name_string) - .server_name(server_name_string) - .reverse_proxy(reverse_proxy); - - // TLS settings and build backend instance - let backend = if app.tls.is_none() { - ensure!(globals.proxy_config.http_port.is_some(), "Required HTTP port"); - backend_builder.build()? - } else { - let tls = app.tls.as_ref().unwrap(); - ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some()); - - let https_redirection = if tls.https_redirection.is_none() { - Some(true) // Default true - } else { - ensure!(globals.proxy_config.https_port.is_some()); // only when both https ports are configured. - tls.https_redirection - }; - - backend_builder - .tls_cert_path(&tls.tls_cert_path) - .tls_cert_key_path(&tls.tls_cert_key_path) - .https_redirection(https_redirection) - .client_ca_cert_path(&tls.client_ca_cert_path) - .build()? - }; - - globals - .backends - .apps - .insert(server_name_string.to_server_name_vec(), backend); + let server_name_string = app.server_name.as_ref().ok_or(anyhow!("No server name"))?; + let backend = app.try_into()?; + backends.apps.insert(server_name_string.to_server_name_vec(), backend); info!("Registering application: {} ({})", app_name, server_name_string); } // default backend application for plaintext http requests if let Some(d) = config.default_app { - let d_sn: Vec<&str> = globals - .backends + let d_sn: Vec<&str> = backends .apps .iter() .filter(|(_k, v)| v.app_name == d) @@ -150,86 +93,17 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", d, d_sn[0] ); - globals.backends.default_server_name_bytes = Some(d_sn[0].to_server_name_vec()); + backends.default_server_name_bytes = Some(d_sn[0].to_server_name_vec()); } } - // experimental - if let Some(exp) = config.experimental { - #[cfg(feature = "http3")] - { - if let Some(h3option) = exp.h3 { - globals.proxy_config.http3 = true; - info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable."); - if let Some(x) = h3option.alt_svc_max_age { - globals.proxy_config.h3_alt_svc_max_age = x; - } - if let Some(x) = h3option.request_max_body_size { - globals.proxy_config.h3_request_max_body_size = x; - } - if let Some(x) = h3option.max_concurrent_connections { - globals.proxy_config.h3_max_concurrent_connections = x; - } - if let Some(x) = h3option.max_concurrent_bidistream { - globals.proxy_config.h3_max_concurrent_bidistream = x.into(); - } - if let Some(x) = h3option.max_concurrent_unistream { - globals.proxy_config.h3_max_concurrent_unistream = x.into(); - } - if let Some(x) = h3option.max_idle_timeout { - if x == 0u64 { - globals.proxy_config.h3_max_idle_timeout = None; - } else { - globals.proxy_config.h3_max_idle_timeout = - Some(quinn::IdleTimeout::try_from(tokio::time::Duration::from_secs(x)).unwrap()) - } - } - } - } + /////////////////////////////////// + let globals = Globals { + proxy_config, + backends, + request_count: Default::default(), + runtime_handle, + }; - if let Some(b) = exp.ignore_sni_consistency { - globals.proxy_config.sni_consistency = !b; - if b { - info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC."); - } - } - } - - Ok(()) -} - -fn get_reverse_proxy( - server_name_string: &str, - rp_settings: &[ReverseProxyOption], -) -> std::result::Result { - let mut upstream: HashMap = HashMap::default(); - - rp_settings.iter().for_each(|rpo| { - let upstream_vec: Vec = rpo.upstream.iter().map(|x| x.try_into().unwrap()).collect(); - // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()); - // let lb_upstream_num = vec_upstream.len(); - let elem = UpstreamGroupBuilder::default() - .upstream(&upstream_vec) - .path(&rpo.path) - .replace_path(&rpo.replace_path) - .lb(&rpo.load_balance, &upstream_vec, server_name_string, &rpo.path) - .opts(&rpo.upstream_options) - .build() - .unwrap(); - - upstream.insert(elem.path.clone(), elem); - }); - ensure!( - rp_settings.iter().filter(|rpo| rpo.path.is_none()).count() < 2, - "Multiple default reverse proxy setting" - ); - ensure!( - upstream - .iter() - .all(|(_, elem)| !(elem.opts.contains(&UpstreamOption::ConvertHttpsTo11) - && elem.opts.contains(&UpstreamOption::ConvertHttpsTo2))), - "either one of force_http11 or force_http2 can be enabled" - ); - - Ok(ReverseProxy { upstream }) + Ok(globals) } diff --git a/src/config/toml.rs b/src/config/toml.rs index 29e76cc..b883f6a 100644 --- a/src/config/toml.rs +++ b/src/config/toml.rs @@ -1,7 +1,13 @@ -use crate::{backend::Upstream, error::*}; +use crate::{ + backend::{Backend, BackendBuilder, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption}, + constants::*, + error::*, + globals::ProxyConfig, + utils::PathNameBytesExp, +}; use rustc_hash::FxHashMap as HashMap; use serde::Deserialize; -use std::fs; +use std::{fs, net::SocketAddr}; #[derive(Deserialize, Debug, Default)] pub struct ConfigToml { @@ -66,20 +72,93 @@ pub struct UpstreamParams { pub tls: Option, } -impl TryInto for &UpstreamParams { - type Error = RpxyError; +impl TryInto for &ConfigToml { + type Error = anyhow::Error; - fn try_into(self) -> std::result::Result { - let mut scheme = "http"; - if let Some(t) = self.tls { - if t { - scheme = "https"; + fn try_into(self) -> std::result::Result { + let mut proxy_config = ProxyConfig { + // listen port and socket + http_port: self.listen_port, + https_port: self.listen_port_tls, + ..Default::default() + }; + ensure!( + proxy_config.http_port.is_some() || proxy_config.https_port.is_some(), + anyhow!("Either/Both of http_port or https_port must be specified") + ); + if proxy_config.http_port.is_some() && proxy_config.https_port.is_some() { + ensure!( + proxy_config.http_port.unwrap() != proxy_config.https_port.unwrap(), + anyhow!("http_port and https_port must be different") + ); + } + + // NOTE: when [::]:xx is bound, both v4 and v6 listeners are enabled. + let listen_addresses: Vec<&str> = if let Some(true) = self.listen_ipv6 { + LISTEN_ADDRESSES_V6.to_vec() + } else { + LISTEN_ADDRESSES_V4.to_vec() + }; + proxy_config.listen_sockets = listen_addresses + .iter() + .flat_map(|addr| { + let mut v: Vec = vec![]; + if let Some(port) = proxy_config.http_port { + v.push(format!("{addr}:{port}").parse().unwrap()); + } + if let Some(port) = proxy_config.https_port { + v.push(format!("{addr}:{port}").parse().unwrap()); + } + v + }) + .collect(); + + // max values + if let Some(c) = self.max_clients { + proxy_config.max_clients = c as usize; + } + if let Some(c) = self.max_concurrent_streams { + proxy_config.max_concurrent_streams = c; + } + + // experimental + if let Some(exp) = &self.experimental { + #[cfg(feature = "http3")] + { + if let Some(h3option) = &exp.h3 { + proxy_config.http3 = true; + if let Some(x) = h3option.alt_svc_max_age { + proxy_config.h3_alt_svc_max_age = x; + } + if let Some(x) = h3option.request_max_body_size { + proxy_config.h3_request_max_body_size = x; + } + if let Some(x) = h3option.max_concurrent_connections { + proxy_config.h3_max_concurrent_connections = x; + } + if let Some(x) = h3option.max_concurrent_bidistream { + proxy_config.h3_max_concurrent_bidistream = x.into(); + } + if let Some(x) = h3option.max_concurrent_unistream { + proxy_config.h3_max_concurrent_unistream = x.into(); + } + if let Some(x) = h3option.max_idle_timeout { + if x == 0u64 { + proxy_config.h3_max_idle_timeout = None; + } else { + proxy_config.h3_max_idle_timeout = + Some(quinn::IdleTimeout::try_from(tokio::time::Duration::from_secs(x)).unwrap()) + } + } + } + } + + if let Some(ignore) = exp.ignore_sni_consistency { + proxy_config.sni_consistency = !ignore; } } - let location = format!("{}://{}", scheme, self.location); - Ok(Upstream { - uri: location.parse::().map_err(|e| anyhow!("{}", e))?, - }) + + Ok(proxy_config) } } @@ -90,3 +169,98 @@ impl ConfigToml { toml::from_str(&config_str).map_err(RpxyError::TomlDe) } } + +impl TryInto for &Application { + type Error = anyhow::Error; + + fn try_into(self) -> std::result::Result { + let server_name_string = self.server_name.as_ref().ok_or(anyhow!("Missing server_name"))?; + + // backend builder + let mut backend_builder = BackendBuilder::default(); + // reverse proxy settings + let reverse_proxy = self.try_into()?; + + backend_builder + .app_name(server_name_string) + .server_name(server_name_string) + .reverse_proxy(reverse_proxy); + + // TLS settings and build backend instance + let backend = if self.tls.is_none() { + backend_builder.build()? + } else { + let tls = self.tls.as_ref().unwrap(); + ensure!(tls.tls_cert_key_path.is_some() && tls.tls_cert_path.is_some()); + + let https_redirection = if tls.https_redirection.is_none() { + Some(true) // Default true + } else { + tls.https_redirection + }; + + backend_builder + .tls_cert_path(&tls.tls_cert_path) + .tls_cert_key_path(&tls.tls_cert_key_path) + .https_redirection(https_redirection) + .client_ca_cert_path(&tls.client_ca_cert_path) + .build()? + }; + Ok(backend) + } +} + +impl TryInto for &Application { + type Error = anyhow::Error; + + fn try_into(self) -> std::result::Result { + let server_name_string = self.server_name.as_ref().ok_or(anyhow!("Missing server_name"))?; + let rp_settings = self.reverse_proxy.as_ref().ok_or(anyhow!("Missing reverse_proxy"))?; + + let mut upstream: HashMap = HashMap::default(); + + rp_settings.iter().for_each(|rpo| { + let upstream_vec: Vec = rpo.upstream.iter().map(|x| x.try_into().unwrap()).collect(); + // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()); + // let lb_upstream_num = vec_upstream.len(); + let elem = UpstreamGroupBuilder::default() + .upstream(&upstream_vec) + .path(&rpo.path) + .replace_path(&rpo.replace_path) + .lb(&rpo.load_balance, &upstream_vec, server_name_string, &rpo.path) + .opts(&rpo.upstream_options) + .build() + .unwrap(); + + upstream.insert(elem.path.clone(), elem); + }); + ensure!( + rp_settings.iter().filter(|rpo| rpo.path.is_none()).count() < 2, + "Multiple default reverse proxy setting" + ); + ensure!( + upstream + .iter() + .all(|(_, elem)| !(elem.opts.contains(&UpstreamOption::ConvertHttpsTo11) + && elem.opts.contains(&UpstreamOption::ConvertHttpsTo2))), + "either one of force_http11 or force_http2 can be enabled" + ); + + Ok(ReverseProxy { upstream }) + } +} + +impl TryInto for &UpstreamParams { + type Error = RpxyError; + + fn try_into(self) -> std::result::Result { + let scheme = match self.tls { + Some(true) => "https", + _ => "http", + }; + let location = format!("{}://{}", scheme, self.location); + Ok(Upstream { + uri: location.parse::().map_err(|e| anyhow!("{}", e))?, + }) + } +} diff --git a/src/globals.rs b/src/globals.rs index b5c4a46..cd47611 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -10,7 +10,7 @@ use tokio::time::Duration; /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks. pub struct Globals { /// Configuration parameters for proxy transport and request handlers - pub proxy_config: ProxyConfig, + pub proxy_config: ProxyConfig, // TODO: proxy configはarcに包んでこいつだけ使いまわせばいいように変えていく。backendsも? /// Shared context - Backend application objects to which http request handler forward incoming requests pub backends: Backends, diff --git a/src/main.rs b/src/main.rs index 2b5e1fa..77b5d1b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -16,8 +16,7 @@ mod proxy; mod utils; use crate::{ - backend::Backends, config::parse_opts, error::*, globals::*, handler::HttpMessageHandlerBuilder, log::*, - proxy::ProxyBuilder, + config::build_globals, error::*, globals::*, handler::HttpMessageHandlerBuilder, log::*, proxy::ProxyBuilder, }; use futures::future::select_all; use hyper::Client; @@ -33,23 +32,17 @@ fn main() { let runtime = runtime_builder.build().unwrap(); runtime.block_on(async { - let mut globals = Globals { - // TODO: proxy configはarcに包んでこいつだけ使いまわせばいいように変えていく。backendsも? - proxy_config: ProxyConfig::default(), - backends: Backends::default(), - - request_count: Default::default(), - runtime_handle: runtime.handle().clone(), - }; - - if let Err(e) = parse_opts(&mut globals) { - error!("Invalid configuration: {}", e); - std::process::exit(1); + let globals = match build_globals(runtime.handle().clone()) { + Ok(g) => g, + Err(e) => { + error!("Invalid configuration: {}", e); + std::process::exit(1); + } }; entrypoint(Arc::new(globals)).await.unwrap() }); - warn!("Exit the program"); + warn!("rpxy exited!"); } // entrypoint creates and spawns tasks of proxy services