totally refactored

This commit is contained in:
Jun Kurihara 2022-07-09 01:01:00 +09:00
commit 907d7e574b
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
13 changed files with 372 additions and 252 deletions

View file

@ -1,29 +1,29 @@
#!/bin/sh #!/bin/sh
# echo "----------------------------"
# echo "Benchmark on rpxy"
# ab -c 16 -n 10000 http://127.0.0.1:8080/index.html
# echo "----------------------------"
# echo "Benchmark on nginx"
# ab -c 16 -n 10000 http://127.0.0.1:8090/index.html
# echo "----------------------------"
# echo "Benchmark on caddy"
# ab -c 16 -n 10000 http://127.0.0.1:8100/index.html
echo "----------------------------" echo "----------------------------"
echo "Benchmark on rpxy" echo "Benchmark on rpxy"
#wrk -t8 -c100 -d30s http://127.0.0.1:8080/index.html ab -c 100 -n 10000 http://127.0.0.1:8080/index.html
rewrk -c 256 -t 8 -d 15s -h http://127.0.0.1:8080 --pct
echo "----------------------------" echo "----------------------------"
echo "Benchmark on nginx" echo "Benchmark on nginx"
# wrk -t8 -c100 -d30s http://127.0.0.1:8090/index.html ab -c 100 -n 10000 http://127.0.0.1:8090/index.html
rewrk -c 256 -t 8 -d 15s -h http://127.0.0.1:8090 --pct
echo "----------------------------" echo "----------------------------"
echo "Benchmark on caddy" echo "Benchmark on caddy"
# wrk -t8 -c100 -d30s http://127.0.0.1:8100/index.html ab -c 100 -n 10000 http://127.0.0.1:8100/index.html
rewrk -c 256 -t 8 -d 15s -h http://127.0.0.1:8100 --pct
# echo "----------------------------"
# echo "Benchmark on rpxy"
# #wrk -t8 -c100 -d30s http://127.0.0.1:8080/index.html
# rewrk -c 256 -t 4 -d 15s -h http://127.0.0.1:8080 --pct
# echo "----------------------------"
# echo "Benchmark on nginx"
# # wrk -t8 -c100 -d30s http://127.0.0.1:8090/index.html
# rewrk -c 256 -t 4 -d 15s -h http://127.0.0.1:8090 --pct
# echo "----------------------------"
# echo "Benchmark on caddy"
# # wrk -t8 -c100 -d30s http://127.0.0.1:8100/index.html
# rewrk -c 256 -t 4 -d 15s -h http://127.0.0.1:8100 --pct

View file

@ -25,10 +25,12 @@ services:
build: build:
context: ../ context: ../
restart: unless-stopped restart: unless-stopped
environment:
- LOG_LEVEL=debug
- LOG_TO_FILE=false
ports: ports:
- 127.0.0.1:8080:8080 - 127.0.0.1:8080:8080
tty: false tty: false
privileged: true
volumes: volumes:
- ./rpxy.toml:/etc/rpxy.toml:ro - ./rpxy.toml:/etc/rpxy.toml:ro
networks: networks:

View file

@ -1,4 +1,14 @@
#!/usr/bin/env bash #!/usr/bin/env bash
LOG_FILE=/var/log/rpxy/rpxy.log LOG_FILE=/var/log/rpxy/rpxy.log
/run.sh 2>&1 | tee $LOG_FILE if [ -z ${LOG_TO_FILE} ]; then
LOG_TO_FILE=false
fi
if "${LOG_TO_FILE}"; then
echo "rpxy: Start with writing log file"
/run.sh 2>&1 | tee $LOG_FILE
else
echo "rpxy: Start without writing log file"
/run.sh 2>&1
fi

View file

@ -52,12 +52,9 @@ cp -p /etc/cron.daily/logrotate /etc/cron.hourly/
service cron start service cron start
# debug level logging # debug level logging
if [ -z $LOG_LEVEL ]; then
LOG_LEVEL=info LOG_LEVEL=info
if [ ${DEBUG} ]; then fi
echo "Logging in debug mode" echo "rpxy: Logging with level ${LOG_LEVEL}"
LOG_LEVEL=debug
fi
echo "Start rpxy"
RUST_LOG=${LOG_LEVEL} /opt/rpxy/sbin/rpxy --config ${CONFIG_FILE} RUST_LOG=${LOG_LEVEL} /opt/rpxy/sbin/rpxy --config ${CONFIG_FILE}

View file

@ -10,7 +10,8 @@ services:
build: build:
context: ./ context: ./
environment: environment:
- DEBUG=true - LOG_LEVEL=info
- LOG_TO_FILE
tty: false tty: false
privileged: true privileged: true
volumes: volumes:

View file

@ -8,6 +8,7 @@ mod error;
mod globals; mod globals;
mod log; mod log;
mod proxy; mod proxy;
mod utils;
use crate::{ use crate::{
config::parse_opts, config::parse_opts,

View file

@ -5,6 +5,9 @@ mod proxy_h3;
mod proxy_handler; mod proxy_handler;
mod proxy_main; mod proxy_main;
mod proxy_tls; mod proxy_tls;
mod utils_headers;
mod utils_request;
mod utils_synth_response;
pub use backend::*; pub use backend::*;
pub use backend_opt::UpstreamOption; pub use backend_opt::UpstreamOption;

View file

@ -1,27 +1,15 @@
// Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
use super::{Proxy, Upstream, UpstreamOption}; use super::{utils_headers::*, utils_request::*, utils_synth_response::*, Proxy, Upstream};
use crate::{constants::*, error::*, log::*}; use crate::{constants::*, error::*, log::*};
use hyper::{ use hyper::{
client::connect::Connect, client::connect::Connect,
header::{HeaderMap, HeaderValue}, header::{self, HeaderValue},
http::uri::Scheme, http::uri::Scheme,
Body, Request, Response, StatusCode, Uri, Version, Body, Request, Response, StatusCode, Uri, Version,
}; };
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::io::copy_bidirectional; use tokio::io::copy_bidirectional;
const HOP_HEADERS: &[&str] = &[
"connection",
"te",
"trailer",
"keep-alive",
"proxy-connection",
"proxy-authenticate",
"proxy-authorization",
"transfer-encoding",
"upgrade",
];
impl<T> Proxy<T> impl<T> Proxy<T>
where where
T: Connect + Clone + Sync + Send + 'static, T: Connect + Clone + Sync + Send + 'static,
@ -31,42 +19,34 @@ where
mut req: Request<Body>, mut req: Request<Body>,
client_addr: SocketAddr, // アクセス制御用 client_addr: SocketAddr, // アクセス制御用
) -> Result<Response<Body>> { ) -> Result<Response<Body>> {
info!( let request_log = log_request_msg(&req, client_addr);
"Handling {:?} request from {}: {} {:?} {} {:?}",
req.version(),
client_addr,
req.method(),
req.headers().get("host").map_or_else(
|| req.uri().host().unwrap_or("<none>"),
|h| h.to_str().unwrap()
),
req.uri(),
req
.headers()
.get("user-agent")
.map_or_else(|| "<none>", |ua| ua.to_str().unwrap())
);
// Here we start to handle with server_name // Here we start to handle with server_name
// Find backend application for given server_name // Find backend application for given server_name
let (server_name, _port) = if let Ok(v) = parse_host_port(&req, self.tls_enabled) { let (server_name, _port) = if let Ok(v) = parse_host_port(&req) {
v v
} else { } else {
return http_error(StatusCode::SERVICE_UNAVAILABLE); info!("{} => {}", request_log, StatusCode::BAD_REQUEST);
}; return http_error(StatusCode::BAD_REQUEST);
let backend = if let Some(be) = self.backends.apps.get(server_name.as_str()) {
be
} else if let Some(default_be) = &self.backends.default_app {
debug!("Serving by default app: {}", default_be);
self.backends.apps.get(default_be).unwrap()
} else {
return http_error(StatusCode::SERVICE_UNAVAILABLE);
}; };
// Redirect to https if tls_enabled is false and redirect_to_https is true if !self.backends.apps.contains_key(&server_name) && self.backends.default_app.is_none() {
let path_and_query = req.uri().path_and_query().unwrap().as_str().to_owned(); info!("{} => {}", request_log, StatusCode::SERVICE_UNAVAILABLE);
return http_error(StatusCode::SERVICE_UNAVAILABLE);
}
let backend = if let Some(be) = self.backends.apps.get(&server_name) {
be
} else {
let default_be = self.backends.default_app.as_ref().unwrap();
debug!("Serving by default app: {}", default_be);
self.backends.apps.get(default_be).unwrap()
};
// Redirect to https if !tls_enabled and redirect_to_https is true
if !self.tls_enabled && backend.https_redirection.unwrap_or(false) { if !self.tls_enabled && backend.https_redirection.unwrap_or(false) {
debug!("Redirect to secure connection: {}", server_name); debug!("Redirect to secure connection: {}", server_name);
return secure_redirection(&server_name, self.globals.https_port, &path_and_query); info!("{} => {}", request_log, StatusCode::PERMANENT_REDIRECT);
return secure_redirection(&server_name, self.globals.https_port, &req);
} }
// Find reverse proxy for given path and choose one of upstream host // Find reverse proxy for given path and choose one of upstream host
@ -94,7 +74,6 @@ where
client_addr, client_addr,
req, req,
upstream_scheme_host, upstream_scheme_host,
path_and_query,
&upgrade_in_request, &upgrade_in_request,
upstream, upstream,
) { ) {
@ -117,19 +96,19 @@ where
{ {
if self.globals.http3 { if self.globals.http3 {
if let Some(port) = self.globals.https_port { if let Some(port) = self.globals.https_port {
res_backend.headers_mut().insert( let alt_svc_value = HeaderValue::from_str(&format!(
hyper::header::ALT_SVC, "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}",
format!( port, H3_ALT_SVC_MAX_AGE, port, H3_ALT_SVC_MAX_AGE
"h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", ))
port, H3_ALT_SVC_MAX_AGE, port, H3_ALT_SVC_MAX_AGE .unwrap();
) res_backend
.parse() .headers_mut()
.unwrap(), .insert(header::ALT_SVC, alt_svc_value);
);
} }
} }
} }
debug!("Response from backend: {:?}", res_backend.status()); debug!("Response from backend: {:?}", res_backend.status());
let response_log = res_backend.status().to_string();
if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS { if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS {
// Handle StatusCode::SWITCHING_PROTOCOLS in response // Handle StatusCode::SWITCHING_PROTOCOLS in response
@ -153,29 +132,37 @@ where
.map_err(|e| anyhow!("Coping between upgraded connections failed: {}", e))?; // TODO: any response code? .map_err(|e| anyhow!("Coping between upgraded connections failed: {}", e))?; // TODO: any response code?
Ok(()) as Result<()> Ok(()) as Result<()>
}); });
info!("{} => {}", request_log, response_log);
Ok(res_backend) Ok(res_backend)
} else { } else {
error!("Request does not have an upgrade extension"); error!("Request does not have an upgrade extension");
http_error(StatusCode::BAD_GATEWAY) info!("{} => {}", request_log, StatusCode::BAD_REQUEST);
http_error(StatusCode::BAD_REQUEST)
} }
} else { } else {
error!( error!(
"Backend tried to switch to protocol {:?} when {:?} was requested", "Backend tried to switch to protocol {:?} when {:?} was requested",
upgrade_in_response, upgrade_in_request upgrade_in_response, upgrade_in_request
); );
http_error(StatusCode::BAD_GATEWAY) info!("{} => {}", request_log, StatusCode::SERVICE_UNAVAILABLE);
http_error(StatusCode::SERVICE_UNAVAILABLE)
} }
} else { } else {
// Generate response to client // Generate response to client
if generate_response_forwarded(&mut res_backend).is_ok() { if generate_response_forwarded(&mut res_backend).is_ok() {
info!("{} => {}", request_log, response_log);
Ok(res_backend) Ok(res_backend)
} else { } else {
info!("{} => {}", request_log, StatusCode::BAD_GATEWAY);
http_error(StatusCode::BAD_GATEWAY) http_error(StatusCode::BAD_GATEWAY)
} }
} }
} }
} }
////////////////////////////////////////////////////
// Functions to generate messages
fn generate_response_forwarded<B: core::fmt::Debug>(response: &mut Response<B>) -> Result<()> { fn generate_response_forwarded<B: core::fmt::Debug>(response: &mut Response<B>) -> Result<()> {
let headers = response.headers_mut(); let headers = response.headers_mut();
remove_hop_header(headers); remove_hop_header(headers);
@ -193,7 +180,6 @@ fn generate_request_forwarded<B: core::fmt::Debug>(
client_addr: SocketAddr, client_addr: SocketAddr,
mut req: Request<B>, mut req: Request<B>,
upstream_scheme_host: &Uri, upstream_scheme_host: &Uri,
path_and_query: String,
upgrade: &Option<String>, upgrade: &Option<String>,
upstream: &Upstream, upstream: &Upstream,
) -> Result<Request<B>> { ) -> Result<Request<B>> {
@ -201,11 +187,8 @@ fn generate_request_forwarded<B: core::fmt::Debug>(
// Add te: trailer if contained in original request // Add te: trailer if contained in original request
let te_trailers = { let te_trailers = {
if let Some(te) = req.headers().get("te") { if let Some(te) = req.headers().get(header::TE) {
te.to_str() te.to_str()?.split(',').any(|x| x.trim() == "trailers")
.unwrap()
.split(',')
.any(|x| x.trim() == "trailers")
} else { } else {
false false
} }
@ -222,16 +205,15 @@ fn generate_request_forwarded<B: core::fmt::Debug>(
// Add te: trailer if te_trailer // Add te: trailer if te_trailer
if te_trailers { if te_trailers {
headers.insert("te", "trailer".parse().unwrap()); headers.insert(header::TE, "trailer".parse()?);
} }
// add "host" header of original server_name if not exist (default) // add "host" header of original server_name if not exist (default)
if req.headers().get(hyper::header::HOST).is_none() { if req.headers().get(header::HOST).is_none() {
let org_host = req.uri().host().unwrap_or("none").to_owned(); let org_host = req.uri().host().unwrap_or("none").to_owned();
req.headers_mut().insert( req
hyper::header::HOST, .headers_mut()
HeaderValue::from_str(org_host.as_str()).unwrap(), .insert(header::HOST, HeaderValue::from_str(org_host.as_str())?);
);
}; };
// apply upstream-specific headers given in upstream_option // apply upstream-specific headers given in upstream_option
@ -239,18 +221,23 @@ fn generate_request_forwarded<B: core::fmt::Debug>(
apply_upstream_options_to_header(headers, client_addr, upstream_scheme_host, upstream)?; apply_upstream_options_to_header(headers, client_addr, upstream_scheme_host, upstream)?;
// update uri in request // update uri in request
*req.uri_mut() = Uri::builder() ensure!(upstream_scheme_host.authority().is_some() && upstream_scheme_host.scheme().is_some());
let new_uri = Uri::builder()
.scheme(upstream_scheme_host.scheme().unwrap().as_str()) .scheme(upstream_scheme_host.scheme().unwrap().as_str())
.authority(upstream_scheme_host.authority().unwrap().as_str()) .authority(upstream_scheme_host.authority().unwrap().as_str());
.path_and_query(&path_and_query) let pq = req.uri().path_and_query();
.build()?; *req.uri_mut() = match pq {
None => new_uri,
Some(x) => new_uri.path_and_query(x.to_owned()),
}
.build()?;
// upgrade // upgrade
if let Some(v) = upgrade { if let Some(v) = upgrade {
req.headers_mut().insert("upgrade", v.parse().unwrap()); req.headers_mut().insert("upgrade", v.parse()?);
req req
.headers_mut() .headers_mut()
.insert(hyper::header::CONNECTION, HeaderValue::from_str("upgrade")?); .insert(header::CONNECTION, HeaderValue::from_str("upgrade")?);
} }
// Change version to http/1.1 when destination scheme is http // Change version to http/1.1 when destination scheme is http
@ -263,155 +250,3 @@ fn generate_request_forwarded<B: core::fmt::Debug>(
Ok(req) Ok(req)
} }
fn apply_upstream_options_to_header(
headers: &mut HeaderMap,
_client_addr: SocketAddr,
upstream_scheme_host: &Uri,
upstream: &Upstream,
) -> Result<()> {
upstream.opts.iter().for_each(|opt| match opt {
UpstreamOption::OverrideHost => {
let upstream_host = upstream_scheme_host.host().unwrap();
headers
.insert(
hyper::header::HOST,
HeaderValue::from_str(upstream_host).unwrap(),
)
.unwrap();
}
});
Ok(())
}
fn append_header_entry(headers: &mut HeaderMap, key: &'static str, value: &str) -> Result<()> {
match headers.entry(key) {
hyper::header::Entry::Vacant(entry) => {
entry.insert(value.parse::<HeaderValue>()?);
}
hyper::header::Entry::Occupied(mut entry) => {
entry.append(value.parse::<HeaderValue>()?);
}
}
Ok(())
}
fn add_forwarding_header(headers: &mut HeaderMap, client_addr: SocketAddr) -> Result<()> {
// default process
// optional process defined by upstream_option is applied in fn apply_upstream_options
append_header_entry(headers, "x-forwarded-for", &client_addr.ip().to_string())?;
Ok(())
}
fn remove_connection_header(headers: &mut HeaderMap) {
if headers.get("connection").is_some() {
let v = headers.get("connection").cloned().unwrap();
for m in v.to_str().unwrap().split(',') {
if !m.is_empty() {
headers.remove(m.trim());
}
}
}
}
fn remove_hop_header(headers: &mut HeaderMap) {
HOP_HEADERS.iter().for_each(|key| {
headers.remove(*key);
});
}
fn http_error(status_code: StatusCode) -> Result<Response<Body>> {
let response = Response::builder()
.status(status_code)
.body(Body::empty())
.unwrap();
Ok(response)
}
fn extract_upgrade(headers: &HeaderMap) -> Option<String> {
if let Some(c) = headers.get("connection") {
if c
.to_str()
.unwrap_or("")
.split(',')
.into_iter()
.any(|w| w.trim().to_ascii_lowercase() == "upgrade")
{
if let Some(u) = headers.get("upgrade") {
let m = u.to_str().unwrap().to_string();
debug!("Upgrade in request header: {}", m);
return Some(m);
}
}
}
None
}
fn secure_redirection(
server_name: &str,
tls_port: Option<u16>,
path_and_query: &str,
) -> Result<Response<Body>> {
let dest_uri: String = if let Some(tls_port) = tls_port {
if tls_port == 443 {
format!("https://{}{}", server_name, path_and_query)
} else {
format!("https://{}:{}{}", server_name, tls_port, path_and_query)
}
} else {
bail!("Internal error! TLS port is not set internally.");
};
let response = Response::builder()
.status(StatusCode::MOVED_PERMANENTLY)
.header("Location", dest_uri)
.body(Body::empty())
.unwrap();
Ok(response)
}
fn parse_host_port<B: core::fmt::Debug>(
req: &Request<B>,
tls_enabled: bool,
) -> Result<(String, u16)> {
let host_port_headers = req.headers().get("host");
let host_uri = req.uri().host();
let port_uri = req.uri().port_u16();
if host_port_headers.is_none() && host_uri.is_none() {
bail!("No host in request header");
}
let (host, port) = match (host_uri, host_port_headers) {
(Some(x), _) => {
let port = if let Some(p) = port_uri {
p
} else if tls_enabled {
443
} else {
80
};
(x.to_string(), port)
}
(None, Some(x)) => {
let hp_as_uri = x.to_str().unwrap().parse::<Uri>().unwrap();
let host = hp_as_uri
.host()
.ok_or_else(|| anyhow!("Failed to parse host"))?;
let port = if let Some(p) = hp_as_uri.port() {
p.as_u16()
} else if tls_enabled {
443
} else {
80
};
(host.to_string(), port)
}
(None, None) => {
bail!("Host unspecified in request")
}
};
Ok((host, port))
}

112
src/proxy/utils_headers.rs Normal file
View file

@ -0,0 +1,112 @@
use super::{Upstream, UpstreamOption};
use crate::{error::*, log::*, utils::*};
use hyper::{
header::{self, HeaderMap, HeaderValue},
Uri,
};
use std::net::SocketAddr;
////////////////////////////////////////////////////
// Functions to manipulate headers
pub(super) fn apply_upstream_options_to_header(
headers: &mut HeaderMap,
_client_addr: SocketAddr,
upstream_scheme_host: &Uri,
upstream: &Upstream,
) -> Result<()> {
for opt in upstream.opts.iter() {
match opt {
UpstreamOption::OverrideHost => {
let upstream_host = upstream_scheme_host.host().ok_or_else(|| anyhow!("none"))?;
headers
.insert(header::HOST, HeaderValue::from_str(upstream_host)?)
.ok_or_else(|| anyhow!("none"))?;
}
}
}
Ok(())
}
pub(super) fn append_header_entry(
headers: &mut HeaderMap,
key: &'static str,
value: &str,
) -> Result<()> {
match headers.entry(key) {
header::Entry::Vacant(entry) => {
entry.insert(value.parse::<HeaderValue>()?);
}
header::Entry::Occupied(mut entry) => {
entry.append(value.parse::<HeaderValue>()?);
}
}
Ok(())
}
pub(super) fn add_forwarding_header(
headers: &mut HeaderMap,
client_addr: SocketAddr,
) -> Result<()> {
// default process
// optional process defined by upstream_option is applied in fn apply_upstream_options
append_header_entry(
headers,
"x-forwarded-for",
&client_addr.to_canonical().ip().to_string(),
)?;
Ok(())
}
pub(super) fn remove_connection_header(headers: &mut HeaderMap) {
if let Some(values) = headers.get(header::CONNECTION) {
if let Ok(v) = values.clone().to_str() {
for m in v.split(',') {
if !m.is_empty() {
headers.remove(m.trim());
}
}
}
}
}
const HOP_HEADERS: &[&str] = &[
"connection",
"te",
"trailer",
"keep-alive",
"proxy-connection",
"proxy-authenticate",
"proxy-authorization",
"transfer-encoding",
"upgrade",
];
pub(super) fn remove_hop_header(headers: &mut HeaderMap) {
HOP_HEADERS.iter().for_each(|key| {
headers.remove(*key);
});
}
pub(super) fn extract_upgrade(headers: &HeaderMap) -> Option<String> {
if let Some(c) = headers.get(header::CONNECTION) {
if c
.to_str()
.unwrap_or("")
.split(',')
.into_iter()
.any(|w| w.trim().to_ascii_lowercase() == header::UPGRADE.as_str().to_ascii_lowercase())
{
if let Some(u) = headers.get(header::UPGRADE) {
if let Ok(m) = u.to_str() {
debug!("Upgrade in request header: {}", m);
return Some(m.to_owned());
}
}
}
}
None
}

View file

@ -0,0 +1,58 @@
use crate::{error::*, utils::*};
use hyper::{header, Request, Uri};
use std::net::SocketAddr;
////////////////////////////////////////////////////
// Functions of utils for request messages
pub(super) fn log_request_msg<B>(req: &Request<B>, client_addr: SocketAddr) -> String {
let server_name = req.headers().get(header::HOST).map_or_else(
|| {
req
.uri()
.authority()
.map_or_else(|| "<none>", |au| au.as_str())
},
|h| h.to_str().unwrap_or("<none>"),
);
return format!(
"{} <- {} -- {} {:?} {:?} ({:?})",
server_name,
client_addr.to_canonical(),
req.method(),
req.version(),
req
.uri()
.path_and_query()
.map_or_else(|| "", |v| v.as_str()),
req.headers()
);
}
pub(super) fn parse_host_port<B: core::fmt::Debug>(
req: &Request<B>,
) -> Result<(String, Option<u16>)> {
let headers_host = req.headers().get("host");
let uri_host = req.uri().host();
let uri_port = req.uri().port_u16();
ensure!(
!(headers_host.is_none() && uri_host.is_none()),
"No host in request header"
);
// prioritize server_name in uri
if let Some(v) = uri_host {
Ok((v.to_string(), uri_port))
} else {
let uri_from_host = headers_host.unwrap().to_str()?.parse::<Uri>()?;
Ok((
uri_from_host
.host()
.ok_or_else(|| anyhow!("Failed to parse host"))?
.to_string(),
uri_from_host.port_u16(),
))
}
}

View file

@ -0,0 +1,35 @@
// Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
use crate::error::*;
use hyper::{Body, Request, Response, StatusCode, Uri};
////////////////////////////////////////////////////
// Functions to create response (error or redirect)
pub(super) fn http_error(status_code: StatusCode) -> Result<Response<Body>> {
let response = Response::builder()
.status(status_code)
.body(Body::empty())?;
Ok(response)
}
pub(super) fn secure_redirection<B>(
server_name: &str,
tls_port: Option<u16>,
req: &Request<B>,
) -> Result<Response<Body>> {
let pq = match req.uri().path_and_query() {
Some(x) => x.as_str(),
_ => "",
};
let new_uri = Uri::builder().scheme("https").path_and_query(pq);
let dest_uri = match tls_port {
Some(443) | None => new_uri.authority(server_name),
Some(p) => new_uri.authority(format!("{}:{}", server_name, p)),
}
.build()?;
let response = Response::builder()
.status(StatusCode::MOVED_PERMANENTLY)
.header("Location", dest_uri.to_string())
.body(Body::empty())?;
Ok(response)
}

3
src/utils/mod.rs Normal file
View file

@ -0,0 +1,3 @@
mod socket_addr;
pub use socket_addr::ToCanonical;

63
src/utils/socket_addr.rs Normal file
View file

@ -0,0 +1,63 @@
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
pub trait ToCanonical {
fn to_canonical(&self) -> Self;
}
impl ToCanonical for SocketAddr {
fn to_canonical(&self) -> Self {
match self {
SocketAddr::V4(_) => *self,
SocketAddr::V6(v6) => match v6.ip().to_ipv4() {
Some(mapped) => {
if mapped == Ipv4Addr::new(0, 0, 0, 1) {
*self
} else {
SocketAddr::new(IpAddr::V4(mapped), self.port())
}
}
None => *self,
},
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::net::Ipv6Addr;
#[test]
fn ipv4_loopback_to_canonical() {
let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
assert_eq!(socket.to_canonical(), socket);
}
#[test]
fn ipv6_loopback_to_canonical() {
let socket = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080);
assert_eq!(socket.to_canonical(), socket);
}
#[test]
fn ipv4_to_canonical() {
let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080);
assert_eq!(socket.to_canonical(), socket);
}
#[test]
fn ipv6_to_canonical() {
let socket = SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0xdead, 0xbeef)),
8080,
);
assert_eq!(socket.to_canonical(), socket);
}
#[test]
fn ipv4_mapped_to_ipv6_to_canonical() {
let socket = SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc00a, 0x2ff)),
8080,
);
assert_eq!(
socket.to_canonical(),
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 10, 2, 255)), 8080)
);
}
}