diff --git a/src/proxy/backend.rs b/src/backend.rs similarity index 99% rename from src/proxy/backend.rs rename to src/backend.rs index 24d4265..07784c0 100644 --- a/src/proxy/backend.rs +++ b/src/backend.rs @@ -1,4 +1,4 @@ -use super::UpstreamOption; +use crate::backend_opt::UpstreamOption; use crate::log::*; use rand::Rng; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; diff --git a/src/proxy/backend_opt.rs b/src/backend_opt.rs similarity index 100% rename from src/proxy/backend_opt.rs rename to src/backend_opt.rs diff --git a/src/config/parse.rs b/src/config/parse.rs index 6015082..43671b8 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -1,10 +1,11 @@ use super::toml::{ConfigToml, ReverseProxyOption}; use crate::{ + backend::{Backend, ReverseProxy, Upstream}, + backend_opt::UpstreamOption, constants::*, error::*, globals::*, log::*, - proxy::{Backend, Backends, ReverseProxy, Upstream, UpstreamOption}, }; use clap::Arg; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; @@ -13,7 +14,7 @@ use std::net::SocketAddr; // #[cfg(feature = "tls")] use std::path::PathBuf; -pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> { +pub fn parse_opts(globals: &mut Globals) -> Result<()> { let _ = include_str!("../../Cargo.toml"); let options = clap::command!().arg( Arg::new("config_file") @@ -121,7 +122,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; - backends.apps.insert( + globals.backends.apps.insert( server_name.as_bytes().to_vec(), Backend { app_name: app_name.to_owned(), @@ -138,7 +139,8 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> // default backend application for plaintext http requests if let Some(d) = config.default_app { - let d_sn: Vec<&str> = backends + let d_sn: Vec<&str> = globals + .backends .apps .iter() .filter(|(_k, v)| v.app_name == d) @@ -149,7 +151,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", d, d_sn[0] ); - backends.default_server_name = Some(d_sn[0].as_bytes().to_vec()); + globals.backends.default_server_name = Some(d_sn[0].as_bytes().to_vec()); } } diff --git a/src/globals.rs b/src/globals.rs index 19dd539..a1746a6 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -1,3 +1,4 @@ +use crate::backend::Backends; use std::net::SocketAddr; use std::sync::{ atomic::{AtomicUsize, Ordering}, @@ -5,7 +6,6 @@ use std::sync::{ }; use tokio::time::Duration; -#[derive(Debug, Clone)] pub struct Globals { pub listen_sockets: Vec, pub http_port: Option, @@ -19,6 +19,8 @@ pub struct Globals { pub http3: bool, pub runtime_handle: tokio::runtime::Handle, + + pub backends: Backends, } #[derive(Debug, Clone, Default)] diff --git a/src/main.rs b/src/main.rs index 6a56710..1ab9c57 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,24 +2,29 @@ #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; +mod backend; +mod backend_opt; mod config; mod constants; mod error; mod globals; mod log; +mod msg_handler; mod proxy; mod utils; use crate::{ + backend::{Backend, Backends, ServerNameLC}, config::parse_opts, constants::*, error::*, globals::*, log::*, - proxy::{Backend, Backends, Proxy, ServerNameLC}, + proxy::Proxy, }; use futures::future::select_all; use hyper::Client; +use msg_handler::HttpMessageHandler; // use hyper_trust_dns::TrustDnsResolver; use rustc_hash::FxHashMap as HashMap; use std::{io::Write, sync::Arc}; @@ -61,24 +66,22 @@ fn main() { max_concurrent_streams: MAX_CONCURRENT_STREAMS, keepalive: true, runtime_handle: runtime.handle().clone(), + + backends: Backends { + default_server_name: None, + apps: HashMap::::default(), + }, }; - let mut backends = Backends { - default_server_name: None, - apps: HashMap::::default(), - }; + parse_opts(&mut globals).expect("Invalid configuration"); - parse_opts(&mut globals, &mut backends).expect("Invalid configuration"); - - entrypoint(Arc::new(globals), Arc::new(backends)) - .await - .unwrap() + entrypoint(Arc::new(globals)).await.unwrap() }); warn!("Exit the program"); } // entrypoint creates and spawns tasks of proxy services -async fn entrypoint(globals: Arc, backends: Arc) -> Result<()> { +async fn entrypoint(globals: Arc) -> Result<()> { // let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector(); let connector = hyper_rustls::HttpsConnectorBuilder::new() .with_webpki_roots() @@ -86,7 +89,10 @@ async fn entrypoint(globals: Arc, backends: Arc) -> Result<() .enable_http1() .enable_http2() .build(); - let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector)); + let msg_handler = HttpMessageHandler { + forwarder: Arc::new(Client::builder().build::<_, hyper::Body>(connector)), + globals: globals.clone(), + }; let addresses = globals.listen_sockets.clone(); let futures = select_all(addresses.into_iter().map(|addr| { @@ -99,8 +105,7 @@ async fn entrypoint(globals: Arc, backends: Arc) -> Result<() globals: globals.clone(), listening_on: addr, tls_enabled, - backends: backends.clone(), - forwarder: forwarder.clone(), + msg_handler: msg_handler.clone(), }; globals.runtime_handle.spawn(proxy.start()) })); diff --git a/src/proxy/proxy_handler.rs b/src/msg_handler/handler.rs similarity index 89% rename from src/proxy/proxy_handler.rs rename to src/msg_handler/handler.rs index 3276ee1..ed94aaa 100644 --- a/src/proxy/proxy_handler.rs +++ b/src/msg_handler/handler.rs @@ -1,16 +1,25 @@ // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy -use super::{utils_headers::*, utils_request::*, utils_synth_response::*, Proxy, Upstream}; -use crate::{constants::*, error::*, log::*}; +use super::{utils_headers::*, utils_request::*, utils_synth_response::*}; +use crate::{backend::Upstream, constants::*, error::*, globals::Globals, log::*}; use hyper::{ client::connect::Connect, header::{self, HeaderValue}, http::uri::Scheme, - Body, Request, Response, StatusCode, Uri, Version, + Body, Client, Request, Response, StatusCode, Uri, Version, }; -use std::net::SocketAddr; +use std::{net::SocketAddr, sync::Arc}; use tokio::io::copy_bidirectional; -impl Proxy +#[derive(Clone)] +pub struct HttpMessageHandler +where + T: Connect + Clone + Sync + Send + 'static, +{ + pub forwarder: Arc>, + pub globals: Arc, +} + +impl HttpMessageHandler where T: Connect + Clone + Sync + Send + 'static, { @@ -18,6 +27,8 @@ where self, mut req: Request, client_addr: SocketAddr, // アクセス制御用 + listen_addr: SocketAddr, + tls_enabled: bool, ) -> Result> { req.log(&client_addr, Some("(Incoming)")); @@ -26,18 +37,18 @@ where // let (server_name, _port) = parse_host_port(&req)?; let server_name_bytes = req.parse_host()?.to_ascii_lowercase(); - let backend = if let Some(be) = self.backends.apps.get(&server_name_bytes) { + let backend = if let Some(be) = self.globals.backends.apps.get(&server_name_bytes) { be - } else if let Some(default_server_name) = &self.backends.default_server_name { + } else if let Some(default_server_name) = &self.globals.backends.default_server_name { debug!("Serving by default app"); - self.backends.apps.get(default_server_name).unwrap() + self.globals.backends.apps.get(default_server_name).unwrap() } else { // info!("{} => {}", request_log, StatusCode::SERVICE_UNAVAILABLE); return http_error(StatusCode::SERVICE_UNAVAILABLE); }; // Redirect to https if !tls_enabled and redirect_to_https is true - if !self.tls_enabled && backend.https_redirection.unwrap_or(false) { + if !tls_enabled && backend.https_redirection.unwrap_or(false) { debug!("Redirect to secure connection: {}", &backend.server_name); // info!("{} => {}", request_log, StatusCode::PERMANENT_REDIRECT); return secure_redirection(&backend.server_name, self.globals.https_port, &req); @@ -68,10 +79,12 @@ where // Build request from destination information let req_forwarded = if let Ok(req) = self.generate_request_forwarded( &client_addr, + &listen_addr, req, upstream_scheme_host, &upgrade_in_request, upstream, + tls_enabled, ) { req } else { @@ -110,7 +123,7 @@ where .await?; // TODO: H3で死ぬことがある // thread 'rpxy' panicked at 'Failed to upgrade request: hyper::Error(User(ManualUpgrade))', src/proxy/proxy_handler.rs:124:63 - tokio::spawn(async move { + self.globals.runtime_handle.spawn(async move { let mut request_upgraded = request_upgraded.await.map_err(|e| { error!("Failed to upgrade request: {}", e); anyhow!("Failed to upgrade request: {}", e) @@ -181,13 +194,16 @@ where Ok(()) } + #[allow(clippy::too_many_arguments)] fn generate_request_forwarded( &self, client_addr: &SocketAddr, + listen_addr: &SocketAddr, mut req: Request, upstream_scheme_host: &Uri, upgrade: &Option, upstream: &Upstream, + tls_enabled: bool, ) -> Result> { debug!("Generate request to be forwarded"); @@ -208,7 +224,7 @@ where // delete hop headers including header.connection remove_hop_header(headers); // X-Forwarded-For - add_forwarding_header(headers, client_addr, self.tls_enabled, &self.globals)?; + add_forwarding_header(headers, client_addr, listen_addr, tls_enabled)?; // Add te: trailer if te_trailer if te_trailers { diff --git a/src/msg_handler/mod.rs b/src/msg_handler/mod.rs new file mode 100644 index 0000000..89d6c8a --- /dev/null +++ b/src/msg_handler/mod.rs @@ -0,0 +1,6 @@ +mod handler; +mod utils_headers; +mod utils_request; +mod utils_synth_response; + +pub use handler::HttpMessageHandler; diff --git a/src/proxy/utils_headers.rs b/src/msg_handler/utils_headers.rs similarity index 90% rename from src/proxy/utils_headers.rs rename to src/msg_handler/utils_headers.rs index ca1d478..a106e97 100644 --- a/src/proxy/utils_headers.rs +++ b/src/msg_handler/utils_headers.rs @@ -1,11 +1,10 @@ -use super::{Upstream, UpstreamOption}; -use crate::{error::*, globals::Globals, log::*, utils::*}; +use crate::{backend::Upstream, backend_opt::UpstreamOption, error::*, log::*, utils::*}; use bytes::BufMut; use hyper::{ header::{self, HeaderMap, HeaderName, HeaderValue}, Uri, }; -use std::{net::SocketAddr, sync::Arc}; +use std::net::SocketAddr; //////////////////////////////////////////////////// // Functions to manipulate headers @@ -71,8 +70,8 @@ pub(super) fn add_header_entry_if_not_exist( pub(super) fn add_forwarding_header( headers: &mut HeaderMap, client_addr: &SocketAddr, + listen_addr: &SocketAddr, tls: bool, - globals: &Arc, // TODO: Fix ) -> Result<()> { // default process // optional process defined by upstream_option is applied in fn apply_upstream_options @@ -92,15 +91,7 @@ pub(super) fn add_forwarding_header( )?; // If we receive X-Forwarded-Port, pass it through; otherwise, pass along the // server port the client connected to - add_header_entry_if_not_exist( - headers, - "x-forwarded-port", - if tls { - globals.https_port.unwrap().to_string() - } else { - globals.http_port.unwrap().to_string() - }, - )?; + add_header_entry_if_not_exist(headers, "x-forwarded-port", listen_addr.port().to_string())?; Ok(()) } diff --git a/src/proxy/utils_request.rs b/src/msg_handler/utils_request.rs similarity index 100% rename from src/proxy/utils_request.rs rename to src/msg_handler/utils_request.rs diff --git a/src/proxy/utils_synth_response.rs b/src/msg_handler/utils_synth_response.rs similarity index 100% rename from src/proxy/utils_synth_response.rs rename to src/msg_handler/utils_synth_response.rs diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index b466a6c..4ea42ec 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,14 +1,6 @@ -mod backend; -mod backend_opt; #[cfg(feature = "h3")] mod proxy_h3; -mod proxy_handler; mod proxy_main; mod proxy_tls; -mod utils_headers; -mod utils_request; -mod utils_synth_response; -pub use backend::*; -pub use backend_opt::UpstreamOption; pub use proxy_main::Proxy; diff --git a/src/proxy/proxy_h3.rs b/src/proxy/proxy_h3.rs index c56b35d..b9e5c30 100644 --- a/src/proxy/proxy_h3.rs +++ b/src/proxy/proxy_h3.rs @@ -10,7 +10,7 @@ impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - pub async fn client_serve_h3(self, conn: quinn::Connecting) { + pub async fn client_serve_h3(&self, conn: quinn::Connecting) { let clients_count = self.globals.clients_count.clone(); if clients_count.increment() > self.globals.max_clients { clients_count.decrement(); @@ -79,7 +79,7 @@ where let self_inner = self.clone(); self.globals.runtime_handle.spawn(async move { - if let Err(e) = self_inner.handle_request_h3(req, stream, client_addr).await { + if let Err(e) = self_inner.handle_stream_h3(req, stream, client_addr).await { error!("HTTP/3 request failed: {}", e); } // // TODO: Work around for timeout @@ -98,7 +98,7 @@ where Ok(()) } - async fn handle_request_h3( + async fn handle_stream_h3( self, req: Request<()>, mut stream: RequestStream, @@ -129,7 +129,11 @@ where }; let new_req: Request = Request::from_parts(req_parts, body); - let res = self.handle_request(new_req, client_addr).await?; + let res = self + .msg_handler + .clone() + .handle_request(new_req, client_addr, self.listening_on, self.tls_enabled) + .await?; let (new_res_parts, new_body) = res.into_parts(); let new_res = Response::from_parts(new_res_parts, ()); diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index 2686c42..56378b5 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -1,9 +1,6 @@ // use super::proxy_handler::handle_request; -use super::Backends; -use crate::{error::*, globals::Globals, log::*}; -use hyper::{ - client::connect::Connect, server::conn::Http, service::service_fn, Body, Client, Request, -}; +use crate::{error::*, globals::Globals, log::*, msg_handler::HttpMessageHandler}; +use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request}; use std::{net::SocketAddr, sync::Arc}; use tokio::{ io::{AsyncRead, AsyncWrite}, @@ -40,8 +37,7 @@ where { pub listening_on: SocketAddr, pub tls_enabled: bool, // TCP待受がTLSかどうか - pub backends: Arc, - pub forwarder: Arc>, + pub msg_handler: HttpMessageHandler, pub globals: Arc, } @@ -59,13 +55,21 @@ where return; } + // let handler_inner = self.msg_handler.clone(); self.globals.runtime_handle.clone().spawn(async move { tokio::time::timeout( self.globals.timeout + Duration::from_secs(1), server .serve_connection( stream, - service_fn(move |req: Request| self.clone().handle_request(req, peer_addr)), + service_fn(move |req: Request| { + self.msg_handler.clone().handle_request( + req, + peer_addr, + self.listening_on, + self.tls_enabled, + ) + }), ) .with_upgrades(), ) diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 586f0c7..80d5042 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -1,8 +1,5 @@ -use super::{ - proxy_main::{LocalExecutor, Proxy}, - ServerNameLC, -}; -use crate::{constants::*, error::*, log::*}; +use super::proxy_main::{LocalExecutor, Proxy}; +use crate::{backend::ServerNameLC, constants::*, error::*, log::*}; #[cfg(feature = "h3")] use futures::StreamExt; use futures::{future::FutureExt, select}; @@ -24,7 +21,7 @@ where info!("Start cert watch service"); loop { let mut hm_server_config = HashMap::>::default(); - for (server_name_bytes, backend) in self.backends.apps.iter() { + for (server_name_bytes, backend) in self.globals.backends.apps.iter() { if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { match backend.update_server_config().await { Err(_e) => { @@ -137,6 +134,7 @@ where // TODO: Work around to initially serve incoming connection // かなり適当。エラーが出たり出なかったり。原因がわからない… let next = self + .globals .backends .apps .iter() @@ -152,6 +150,7 @@ where std::str::from_utf8(initial_app_name) ); let backend_serve = self + .globals .backends .apps .get(initial_app_name)