diff --git a/bench/rpxy.toml b/bench/rpxy.toml index 9eedbc8..4c80774 100644 --- a/bench/rpxy.toml +++ b/bench/rpxy.toml @@ -14,7 +14,11 @@ server_name = 'localhost' reverse_proxy = [ # default destination if path is not specified # Array for load balancing - { upstream = [{ location = 'backend-nginx', tls = false }] }, + { upstream = [ + { location = 'backend-nginx', tls = false, upstream_options = [ + "override_host", + ] }, + ] }, # { upstream = [{ location = '192.168.100.100', tls = false }] }, ] diff --git a/config-example.toml b/config-example.toml index 61e14e9..1544ee3 100644 --- a/config-example.toml +++ b/config-example.toml @@ -37,10 +37,14 @@ reverse_proxy = [ { upstream = [ { location = 'www.google.com', tls = true }, { location = 'www.google.co.jp', tls = true }, + ], upstream_options = [ + "override_host", ] }, { path = '/maps', upstream = [ { location = 'www.bing.com', tls = true }, { location = 'www.bing.co.jp', tls = true }, + ], upstream_options = [ + "override_host", ] }, ] # Optional: TLS setting. if https_port is specified and tls is true above, this must be given. diff --git a/src/config/parse.rs b/src/config/parse.rs index a9e4022..74159a8 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -1,8 +1,14 @@ use super::toml::{ConfigToml, ReverseProxyOption}; -use crate::{backend::*, constants::*, error::*, globals::*, log::*}; +use crate::{ + constants::*, + error::*, + globals::*, + log::*, + proxy::{Backend, Backends, ReverseProxy, Upstream, UpstreamOption}, +}; use clap::Arg; use parking_lot::Mutex; -use rustc_hash::FxHashMap as HashMap; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use std::net::SocketAddr; // #[cfg(feature = "tls")] @@ -165,7 +171,18 @@ fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> Result uri: rpo.upstream.iter().map(|x| x.to_uri().unwrap()).collect(), cnt: Default::default(), lb: Default::default(), + opts: { + if let Some(opts) = &rpo.upstream_options { + opts + .iter() + .filter_map(|str| UpstreamOption::try_from(str.as_str()).ok()) + .collect::>() + } else { + Default::default() + } + }, }; + if rpo.path.is_some() { upstream.insert(rpo.path.as_ref().unwrap().to_owned(), elem); } else { diff --git a/src/config/toml.rs b/src/config/toml.rs index 9c6e945..b7ecb7b 100644 --- a/src/config/toml.rs +++ b/src/config/toml.rs @@ -40,15 +40,16 @@ pub struct TlsOption { #[derive(Deserialize, Debug, Default)] pub struct ReverseProxyOption { pub path: Option, - pub upstream: Vec, + pub upstream: Vec, + pub upstream_options: Option>, } #[derive(Deserialize, Debug, Default)] -pub struct UpstreamOption { +pub struct UpstreamParams { pub location: String, pub tls: Option, } -impl UpstreamOption { +impl UpstreamParams { pub fn to_uri(&self) -> Result { let mut scheme = "http"; if let Some(t) = self.tls { diff --git a/src/main.rs b/src/main.rs index dff5a69..cc034f4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -2,7 +2,6 @@ #[global_allocator] static GLOBAL: tikv_jemallocator::Jemalloc = tikv_jemallocator::Jemalloc; -mod backend; mod config; mod constants; mod error; @@ -11,13 +10,12 @@ mod log; mod proxy; use crate::{ - backend::{Backend, Backends}, config::parse_opts, constants::*, error::*, globals::*, log::*, - proxy::Proxy, + proxy::{Backend, Backends, Proxy}, }; use futures::future::select_all; use hyper::Client; diff --git a/src/backend.rs b/src/proxy/backend.rs similarity index 97% rename from src/backend.rs rename to src/proxy/backend.rs index d5f56ee..62189e7 100644 --- a/src/backend.rs +++ b/src/proxy/backend.rs @@ -1,7 +1,8 @@ +use super::UpstreamOption; use crate::log::*; use parking_lot::Mutex; use rand::Rng; -use rustc_hash::FxHashMap as HashMap; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use std::{ fs::File, io::{self, BufReader, Cursor, Read}, @@ -53,6 +54,7 @@ pub struct Upstream { pub uri: Vec, pub lb: LoadBalance, pub cnt: UpstreamCount, // counter for load balancing + pub opts: HashSet, } #[derive(Debug, Clone, Default)] diff --git a/src/proxy/backend_opt.rs b/src/proxy/backend_opt.rs new file mode 100644 index 0000000..4ae60bb --- /dev/null +++ b/src/proxy/backend_opt.rs @@ -0,0 +1,16 @@ +use crate::error::*; + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub enum UpstreamOption { + OverrideHost, + // TODO: Adds more options for heder override +} +impl TryFrom<&str> for UpstreamOption { + type Error = anyhow::Error; + fn try_from(val: &str) -> Result { + match val { + "override_host" => Ok(Self::OverrideHost), + _ => Err(anyhow!("Unsupported header option")), + } + } +} diff --git a/src/proxy/mod.rs b/src/proxy/mod.rs index 20b5810..5acb8f1 100644 --- a/src/proxy/mod.rs +++ b/src/proxy/mod.rs @@ -1,7 +1,11 @@ +mod backend; +mod backend_opt; #[cfg(feature = "h3")] mod proxy_h3; mod proxy_handler; mod proxy_main; mod proxy_tls; +pub use backend::*; +pub use backend_opt::UpstreamOption; pub use proxy_main::Proxy; diff --git a/src/proxy/proxy_handler.rs b/src/proxy/proxy_handler.rs index a60eb25..bd7c4eb 100644 --- a/src/proxy/proxy_handler.rs +++ b/src/proxy/proxy_handler.rs @@ -1,5 +1,5 @@ // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy -use super::Proxy; +use super::{Proxy, Upstream, UpstreamOption}; use crate::{constants::*, error::*, log::*}; use hyper::{ client::connect::Connect, @@ -68,14 +68,14 @@ where // Find reverse proxy for given path and choose one of upstream host // TODO: More flexible path matcher let path = req.uri().path(); - let upstream_uri = if let Some(upstream) = backend.reverse_proxy.upstream.get(path) { - upstream.get() + let upstream = if let Some(upstream) = backend.reverse_proxy.upstream.get(path) { + upstream } else if let Some(default_upstream) = &backend.reverse_proxy.default_upstream { - default_upstream.get() + default_upstream } else { return http_error(StatusCode::NOT_FOUND); }; - let upstream_scheme_host = if let Some(u) = upstream_uri { + let upstream_scheme_host = if let Some(u) = upstream.get() { u } else { return http_error(StatusCode::INTERNAL_SERVER_ERROR); @@ -92,6 +92,7 @@ where upstream_scheme_host, path_and_query, &upgrade_in_request, + upstream, ) { req } else { @@ -184,6 +185,7 @@ fn generate_request_forwarded( upstream_scheme_host: &Uri, path_and_query: String, upgrade: &Option, + upstream: &Upstream, ) -> Result> { debug!("Generate request to be forwarded"); @@ -206,11 +208,27 @@ fn generate_request_forwarded( remove_hop_header(headers); // X-Forwarded-For add_forwarding_header(headers, client_addr)?; + // Add te: trailer if te_trailer if te_trailers { headers.insert("te", "trailer".parse().unwrap()); } + // add "host" header of original server_name if not exist (default) + if req.headers().get(hyper::header::HOST).is_none() { + let org_host = req.uri().host().unwrap_or("none").to_owned(); + req.headers_mut().insert( + hyper::header::HOST, + HeaderValue::from_str(org_host.as_str()).unwrap(), + ); + }; + + // apply upstream-specific headers given in upstream_option + let headers = req.headers_mut(); + println!("before {:?}", headers); + apply_upstream_options_to_header(headers, client_addr, upstream_scheme_host, upstream)?; + println!("after {:?}", req); + // update uri in request *req.uri_mut() = Uri::builder() .scheme(upstream_scheme_host.scheme().unwrap().as_str()) @@ -223,7 +241,7 @@ fn generate_request_forwarded( req.headers_mut().insert("upgrade", v.parse().unwrap()); req .headers_mut() - .insert("connection", HeaderValue::from_str("upgrade")?); + .insert(hyper::header::CONNECTION, HeaderValue::from_str("upgrade")?); } // Change version to http/1.1 when destination scheme is http @@ -237,8 +255,29 @@ fn generate_request_forwarded( Ok(req) } +fn apply_upstream_options_to_header( + headers: &mut HeaderMap, + _client_addr: SocketAddr, + upstream_scheme_host: &Uri, + upstream: &Upstream, +) -> Result<()> { + upstream.opts.iter().for_each(|opt| match opt { + UpstreamOption::OverrideHost => { + let upstream_host = upstream_scheme_host.host().unwrap(); + headers + .insert( + hyper::header::HOST, + HeaderValue::from_str(upstream_host).unwrap(), + ) + .unwrap(); + } + }); + Ok(()) +} + fn add_forwarding_header(headers: &mut HeaderMap, client_addr: SocketAddr) -> Result<()> { - // TODO: Option対応? + // default process + // optional process defined by upstream_option is applied in fn apply_upstream_options let client_ip = client_addr.ip(); match headers.entry("x-forwarded-for") { hyper::header::Entry::Vacant(entry) => { diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index 22434d5..3c35693 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -1,5 +1,6 @@ // use super::proxy_handler::handle_request; -use crate::{backend::Backends, error::*, globals::Globals, log::*}; +use super::Backends; +use crate::{error::*, globals::Globals, log::*}; use hyper::{ client::connect::Connect, server::conn::Http, service::service_fn, Body, Client, Request, };