diff --git a/src/backend.rs b/src/backend/mod.rs similarity index 59% rename from src/backend.rs rename to src/backend/mod.rs index ae86be2..fd0b60d 100644 --- a/src/backend.rs +++ b/src/backend/mod.rs @@ -1,31 +1,27 @@ -use crate::{backend_opt::UpstreamOption, log::*}; -use rand::Rng; -use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +mod upstream; +mod upstream_opts; + +use crate::log::*; +use rustc_hash::FxHashMap as HashMap; use std::{ - borrow::Cow, fs::File, io::{self, BufReader, Cursor, Read}, path::PathBuf, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, + sync::Arc, }; use tokio_rustls::rustls::{ server::ResolvesServerCertUsingSni, sign::{any_supported_type, CertifiedKey}, Certificate, PrivateKey, ServerConfig, }; +pub use upstream::{ReverseProxy, Upstream, UpstreamGroup}; +pub use upstream_opts::UpstreamOption; -// server name (hostname or ip address) in ascii lower case -pub type ServerNameLC = Vec; -pub type PathNameLC = Vec; - -pub struct Backends { - pub apps: HashMap, // hyper::uriで抜いたhostで引っ掛ける - pub default_server_name: Option, // for plaintext http -} +// server name (hostname or ip address) and path name representation in backends +pub type ServerNameExp = Vec; // lowercase ascii bytes +pub type PathNameExp = Vec; // lowercase ascii bytes +/// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings. pub struct Backend { pub app_name: String, pub server_name: String, @@ -37,104 +33,6 @@ pub struct Backend { pub https_redirection: Option, } -#[derive(Debug, Clone)] -pub struct ReverseProxy { - pub upstream: HashMap, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。 -} - -impl ReverseProxy { - pub fn get<'a>(&self, path_str: impl Into>) -> Option<&UpstreamGroup> { - // trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、 - // コスト的にこの程度で十分 - let path_lc = path_str.into().to_ascii_lowercase(); - let path_bytes = path_lc.as_bytes(); - - let matched_upstream = self - .upstream - .iter() - .filter(|(route_bytes, _)| { - match path_bytes.starts_with(route_bytes) { - true => { - route_bytes.len() == 1 // route = '/', i.e., default - || match path_bytes.get(route_bytes.len()) { - None => true, // exact case - Some(p) => p == &b'/', // sub-path case - } - } - _ => false, - } - }) - .max_by_key(|(route_bytes, _)| route_bytes.len()); - if let Some((_path, u)) = matched_upstream { - debug!( - "Found upstream: {:?}", - String::from_utf8(_path.to_vec()).unwrap_or_else(|_| "".to_string()) - ); - Some(u) - } else { - None - } - } -} - -#[allow(dead_code)] -#[derive(Debug, Clone)] -pub enum LoadBalance { - RoundRobin, - Random, -} -impl Default for LoadBalance { - fn default() -> Self { - Self::RoundRobin - } -} - -#[derive(Debug, Clone)] -pub struct Upstream { - pub uri: hyper::Uri, // base uri without specific path -} - -#[derive(Debug, Clone)] -pub struct UpstreamGroup { - pub upstream: Vec, - pub path: PathNameLC, - pub replace_path: Option, - pub lb: LoadBalance, - pub cnt: UpstreamCount, // counter for load balancing - pub opts: HashSet, -} - -#[derive(Debug, Clone, Default)] -pub struct UpstreamCount(Arc); - -impl UpstreamGroup { - pub fn get(&self) -> Option<&Upstream> { - match self.lb { - LoadBalance::RoundRobin => { - let idx = self.increment_cnt(); - self.upstream.get(idx) - } - LoadBalance::Random => { - let mut rng = rand::thread_rng(); - let max = self.upstream.len() - 1; - self.upstream.get(rng.gen_range(0..max)) - } - } - } - - fn current_cnt(&self) -> usize { - self.cnt.0.load(Ordering::Relaxed) - } - - fn increment_cnt(&self) -> usize { - if self.current_cnt() < self.upstream.len() - 1 { - self.cnt.0.fetch_add(1, Ordering::Relaxed) - } else { - self.cnt.0.fetch_and(0, Ordering::Relaxed) - } - } -} - impl Backend { pub fn read_certs_and_key(&self) -> io::Result { debug!("Read TLS server certificates and private key"); @@ -210,6 +108,12 @@ impl Backend { } } +/// HashMap and some meta information for multiple Backend structs. +pub struct Backends { + pub apps: HashMap, // hyper::uriで抜いたhostで引っ掛ける + pub default_server_name: Option, // for plaintext http +} + impl Backends { pub async fn generate_server_crypto_with_cert_resolver(&self) -> Result { let mut resolver = ResolvesServerCertUsingSni::new(); diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs new file mode 100644 index 0000000..da4a97a --- /dev/null +++ b/src/backend/upstream.rs @@ -0,0 +1,109 @@ +use super::{PathNameExp, UpstreamOption}; +use crate::log::*; +use rand::Rng; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use std::{ + borrow::Cow, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +#[derive(Debug, Clone)] +pub struct ReverseProxy { + pub upstream: HashMap, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。 +} + +impl ReverseProxy { + pub fn get<'a>(&self, path_str: impl Into>) -> Option<&UpstreamGroup> { + // trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、 + // コスト的にこの程度で十分 + let path_lc = path_str.into().to_ascii_lowercase(); + let path_bytes = path_lc.as_bytes(); + + let matched_upstream = self + .upstream + .iter() + .filter(|(route_bytes, _)| { + match path_bytes.starts_with(route_bytes) { + true => { + route_bytes.len() == 1 // route = '/', i.e., default + || match path_bytes.get(route_bytes.len()) { + None => true, // exact case + Some(p) => p == &b'/', // sub-path case + } + } + _ => false, + } + }) + .max_by_key(|(route_bytes, _)| route_bytes.len()); + if let Some((_path, u)) = matched_upstream { + debug!( + "Found upstream: {:?}", + String::from_utf8(_path.to_vec()).unwrap_or_else(|_| "".to_string()) + ); + Some(u) + } else { + None + } + } +} + +#[allow(dead_code)] +#[derive(Debug, Clone)] +pub enum LoadBalance { + RoundRobin, + Random, +} +impl Default for LoadBalance { + fn default() -> Self { + Self::RoundRobin + } +} + +#[derive(Debug, Clone)] +pub struct Upstream { + pub uri: hyper::Uri, // base uri without specific path +} + +#[derive(Debug, Clone)] +pub struct UpstreamGroup { + pub upstream: Vec, + pub path: PathNameExp, + pub replace_path: Option, + pub lb: LoadBalance, + pub cnt: UpstreamCount, // counter for load balancing + pub opts: HashSet, +} + +#[derive(Debug, Clone, Default)] +pub struct UpstreamCount(Arc); + +impl UpstreamGroup { + pub fn get(&self) -> Option<&Upstream> { + match self.lb { + LoadBalance::RoundRobin => { + let idx = self.increment_cnt(); + self.upstream.get(idx) + } + LoadBalance::Random => { + let mut rng = rand::thread_rng(); + let max = self.upstream.len() - 1; + self.upstream.get(rng.gen_range(0..max)) + } + } + } + + fn current_cnt(&self) -> usize { + self.cnt.0.load(Ordering::Relaxed) + } + + fn increment_cnt(&self) -> usize { + if self.current_cnt() < self.upstream.len() - 1 { + self.cnt.0.fetch_add(1, Ordering::Relaxed) + } else { + self.cnt.0.fetch_and(0, Ordering::Relaxed) + } + } +} diff --git a/src/backend_opt.rs b/src/backend/upstream_opts.rs similarity index 100% rename from src/backend_opt.rs rename to src/backend/upstream_opts.rs diff --git a/src/config/parse.rs b/src/config/parse.rs index 4fe244e..19b080f 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -1,7 +1,6 @@ use super::toml::{ConfigToml, ReverseProxyOption}; use crate::{ - backend::{Backend, PathNameLC, ReverseProxy, UpstreamGroup}, - backend_opt::UpstreamOption, + backend::{Backend, PathNameExp, ReverseProxy, UpstreamGroup, UpstreamOption}, constants::*, error::*, globals::*, @@ -192,7 +191,7 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro } fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result { - let mut upstream: HashMap = HashMap::default(); + let mut upstream: HashMap = HashMap::default(); rp_settings.iter().for_each(|rpo| { let path = match &rpo.path { Some(p) => p.as_bytes().to_ascii_lowercase(), diff --git a/src/handler/handler_main.rs b/src/handler/handler_main.rs index b85eb55..045b190 100644 --- a/src/handler/handler_main.rs +++ b/src/handler/handler_main.rs @@ -1,7 +1,7 @@ // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy -use super::{utils_headers::*, utils_request::*, utils_response::ResLog, utils_synth_response::*}; +use super::{utils_headers::*, utils_request::*, utils_synth_response::*}; use crate::{ - backend::{ServerNameLC, UpstreamGroup}, + backend::{ServerNameExp, UpstreamGroup}, error::*, globals::Globals, log::*, @@ -39,9 +39,8 @@ where client_addr: SocketAddr, // アクセス制御用 listen_addr: SocketAddr, tls_enabled: bool, - tls_server_name: Option, + tls_server_name: Option, ) -> Result> { - req.log_debug(&client_addr, Some("(from Client)")); //////// let mut log_data = MessageLog::from(&req); log_data.client_addr(&client_addr); @@ -102,7 +101,6 @@ where return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); }; // debug!("Request to be forwarded: {:?}", req_forwarded); - req.log_debug(&client_addr, Some("(to Backend)")); log_data.xff(&req.headers().get("x-forwarded-for")); log_data.upstream(req.uri()); ////// @@ -123,14 +121,9 @@ where } }; - res_backend.log_debug(&backend.server_name, &client_addr, Some("(from Backend)")); - // let response_log = res_backend.status().to_string(); - if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { // Generate response to client if self.generate_response_forwarded(&mut res_backend).is_ok() { - // info!("{} => {}", request_log, response_log); - res_backend.log_debug(&backend.server_name, &client_addr, Some("(to Client)")); log_data.status_code(&res_backend.status()).output(); return Ok(res_backend); } else { diff --git a/src/handler/mod.rs b/src/handler/mod.rs index 4f26d27..799ef60 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -1,7 +1,6 @@ mod handler_main; mod utils_headers; mod utils_request; -mod utils_response; mod utils_synth_response; pub use handler_main::HttpMessageHandler; diff --git a/src/handler/utils_headers.rs b/src/handler/utils_headers.rs index dc956aa..acb929e 100644 --- a/src/handler/utils_headers.rs +++ b/src/handler/utils_headers.rs @@ -1,4 +1,9 @@ -use crate::{backend::UpstreamGroup, backend_opt::UpstreamOption, error::*, log::*, utils::*}; +use crate::{ + backend::{UpstreamGroup, UpstreamOption}, + error::*, + log::*, + utils::*, +}; use bytes::BufMut; use hyper::{ header::{self, HeaderMap, HeaderName, HeaderValue}, diff --git a/src/handler/utils_request.rs b/src/handler/utils_request.rs index 9ff480d..b36bc5e 100644 --- a/src/handler/utils_request.rs +++ b/src/handler/utils_request.rs @@ -1,58 +1,5 @@ -use crate::{error::*, log::*, utils::*}; +use crate::error::*; use hyper::{header, Request}; -use std::fmt::Display; - -//////////////////////////////////////////////////// -// Functions of utils for request messages -pub trait ReqLog { - fn log(self, src: &T, extra: Option<&str>); - fn log_debug(self, src: &T, extra: Option<&str>); - fn build_message(self, src: &T, extra: Option<&str>) -> String; -} -impl ReqLog for &Request { - fn log(self, src: &T, extra: Option<&str>) { - info!("{}", &self.build_message(src, extra)); - } - fn log_debug(self, src: &T, extra: Option<&str>) { - debug!("{}", &self.build_message(src, extra)); - } - fn build_message(self, src: &T, extra: Option<&str>) -> String { - let canonical_src = src.to_canonical(); - - let host = self - .headers() - .get(header::HOST) - .map_or_else(|| "", |v| v.to_str().unwrap_or("")); - let uri_scheme = self - .uri() - .scheme_str() - .map_or_else(|| "".to_string(), |v| format!("{}://", v)); - let uri_host = self.uri().host().unwrap_or(""); - let uri_pq = self.uri().path_and_query().map_or_else(|| "", |v| v.as_str()); - let ua = self - .headers() - .get(header::USER_AGENT) - .map_or_else(|| "", |v| v.to_str().unwrap_or("")); - let xff = self - .headers() - .get("x-forwarded-for") - .map_or_else(|| "", |v| v.to_str().unwrap_or("")); - - format!( - "{} <- {} -- {} {} {:?} -- ({}{}) \"{}\" \"{}\" {}", - host, - canonical_src, - self.method(), - uri_pq, - self.version(), - uri_scheme, - uri_host, - ua, - xff, - extra.unwrap_or("") - ) - } -} pub trait ParseHost { fn parse_host(&self) -> Result<&[u8]>; diff --git a/src/handler/utils_response.rs b/src/handler/utils_response.rs deleted file mode 100644 index c04b7f6..0000000 --- a/src/handler/utils_response.rs +++ /dev/null @@ -1,41 +0,0 @@ -use crate::{log::*, utils::*}; -use hyper::Response; -use std::fmt::Display; - -//////////////////////////////////////////////////// -// Functions of utils for request messages -pub trait ResLog { - fn log(self, server_name: &T1, client_addr: &T2, extra: Option<&str>); - fn log_debug(self, server_name: &T1, client_addr: &T2, extra: Option<&str>); - fn build_message( - self, - server_name: &T1, - client_addr: &T2, - extra: Option<&str>, - ) -> String; -} -impl ResLog for &Response { - fn log(self, server_name: &T1, client_addr: &T2, extra: Option<&str>) { - info!("{}", &self.build_message(server_name, client_addr, extra)); - } - fn log_debug(self, server_name: &T1, client_addr: &T2, extra: Option<&str>) { - debug!("{}", &self.build_message(server_name, client_addr, extra)); - } - fn build_message( - self, - server_name: &T1, - client_addr: &T2, - extra: Option<&str>, - ) -> String { - let canonical_client_addr = client_addr.to_canonical(); - format!( - "{} <- {} -- {} {:?} {}", - canonical_client_addr, - server_name, - self.status(), - self.version(), - // self.headers(), - extra.map_or_else(|| "", |v| v) - ) - } -} diff --git a/src/main.rs b/src/main.rs index 1e2996c..b69acd6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,6 @@ use tikv_jemallocator::Jemalloc; static GLOBAL: Jemalloc = Jemalloc; mod backend; -mod backend_opt; mod config; mod constants; mod error; @@ -17,7 +16,7 @@ mod proxy; mod utils; use crate::{ - backend::{Backend, Backends, ServerNameLC}, + backend::{Backend, Backends, ServerNameExp}, config::parse_opts, constants::*, error::*, @@ -73,7 +72,7 @@ fn main() { runtime_handle: runtime.handle().clone(), backends: Backends { default_server_name: None, - apps: HashMap::::default(), + apps: HashMap::::default(), }, sni_consistency: true, diff --git a/src/proxy/proxy_h3.rs b/src/proxy/proxy_h3.rs index 5892b1d..0599a61 100644 --- a/src/proxy/proxy_h3.rs +++ b/src/proxy/proxy_h3.rs @@ -1,5 +1,5 @@ use super::Proxy; -use crate::{backend::ServerNameLC, error::*, log::*}; +use crate::{backend::ServerNameExp, error::*, log::*}; use bytes::{Buf, Bytes}; use h3::{quic::BidiStream, server::RequestStream}; use hyper::{client::connect::Connect, Body, Request, Response}; @@ -10,7 +10,7 @@ impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - pub(super) async fn connection_serve_h3(self, conn: quinn::Connecting, tls_server_name: ServerNameLC) -> Result<()> { + pub(super) async fn connection_serve_h3(self, conn: quinn::Connecting, tls_server_name: ServerNameExp) -> Result<()> { let client_addr = conn.remote_address(); match conn.await { @@ -68,7 +68,7 @@ where req: Request<()>, stream: RequestStream, client_addr: SocketAddr, - tls_server_name: ServerNameLC, + tls_server_name: ServerNameExp, ) -> Result<()> where S: BidiStream + Send + 'static, diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index 0af1f3e..f3cf3b7 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -1,5 +1,5 @@ // use super::proxy_handler::handle_request; -use crate::{backend::ServerNameLC, error::*, globals::Globals, handler::HttpMessageHandler, log::*}; +use crate::{backend::ServerNameExp, error::*, globals::Globals, handler::HttpMessageHandler, log::*}; use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request}; use std::{net::SocketAddr, sync::Arc}; use tokio::{ @@ -50,7 +50,7 @@ where stream: I, server: Http, peer_addr: SocketAddr, - tls_server_name: Option, + tls_server_name: Option, ) where I: AsyncRead + AsyncWrite + Send + Unpin + 'static, {