add upgrade and x-forwarded-for

This commit is contained in:
Jun Kurihara 2022-06-24 23:21:54 -04:00
commit 8db9e647e3
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
2 changed files with 110 additions and 26 deletions

View file

@ -24,7 +24,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend>
let mut map_example: HashMap<String, Uri> = HashMap::new();
map_example.insert(
"/maps".to_string(),
Uri::builder().authority("www.bing.com").build().unwrap(),
"https://www.bing.com".parse::<Uri>().unwrap(),
);
backends.insert(
"localhost".to_string(),
@ -32,7 +32,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend>
app_name: "Localhost to Google except for maps".to_string(),
hostname: "localhost".to_string(),
reverse_proxy: ReverseProxy {
default_destination_uri: Uri::builder().authority("www.google.com").build().unwrap(),
default_destination_uri: "https://www.google.com".parse::<Uri>().unwrap(),
destination_uris: map_example,
},
https_redirection: Some(true), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず

View file

@ -1,26 +1,33 @@
// Motivated by https://github.com/felipenoris/hyper-reverse-proxy
use super::Proxy;
use crate::{error::*, log::*};
use hyper::{
client::connect::Connect,
header::{HeaderMap, HeaderName, HeaderValue},
header::{HeaderMap, HeaderValue},
Body, Request, Response, StatusCode, Uri,
};
use std::net::SocketAddr;
// pub static HEADERS: phf::Map<&'static str, HeaderName> = phf_map! {
// "CONNECTION" => HeaderName::from_static("connection"),
// "ws" => "wss",
// };
const HOP_HEADERS: &[&str] = &[
"connection",
"te",
"trailer",
"keep-alive",
"proxy-connection",
"proxy-authenticate",
"proxy-authorization",
"transfer-encoding",
"upgrade",
];
impl<T> Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,
{
// TODO: ここでbackendの名前単位でリクエストを分岐させる
pub async fn handle_request(
self,
req: Request<Body>,
client_ip: SocketAddr, // アクセス制御用
client_addr: SocketAddr, // アクセス制御用
) -> Result<Response<Body>> {
debug!("Handling request: {:?}", req);
// Here we start to handle with hostname
@ -45,20 +52,24 @@ where
// Find reverse proxy for given path
let path = req.uri().path();
let destination_host_uri = if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) {
uri.to_owned()
} else {
backend.reverse_proxy.default_destination_uri.clone()
};
let destination_scheme_host =
if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) {
uri.to_owned()
} else {
backend.reverse_proxy.default_destination_uri.clone()
};
// TODO: Upgrade
// TODO: X-Forwarded-For
// TODO: Transfer Encoding
// Upgrade in request header
let upgrade_in_request = extract_upgrade(req.headers());
// Build request from destination information
let req_forwarded = if let Ok(req) =
generate_request_forwarded(client_ip, req, destination_host_uri, path_and_query)
{
let req_forwarded = if let Ok(req) = generate_request_forwarded(
client_addr,
req,
destination_scheme_host,
path_and_query,
upgrade_in_request,
) {
req
} else {
error!("Failed to generate destination uri for reverse proxy");
@ -81,35 +92,89 @@ where
}
}
// Motivated by https://github.com/felipenoris/hyper-reverse-proxy
fn generate_request_forwarded<B: core::fmt::Debug>(
client_ip: SocketAddr,
client_addr: SocketAddr,
mut req: Request<B>,
destination_host_uri: Uri,
destination_scheme_host: Uri,
path_and_query: String,
upgrade: Option<String>,
) -> Result<Request<B>> {
debug!("Generate request to be forwarded");
// TODO: Transfer Encoding
// update "host" key in request header
if req.headers().contains_key("host") {
// HTTP/1.1
req.headers_mut().insert(
"host",
HeaderValue::from_str(destination_host_uri.host().unwrap())
HeaderValue::from_str(destination_scheme_host.host().unwrap())
.map_err(|_| anyhow!("Failed to insert destination host into forwarded request"))?,
);
}
let headers = req.headers_mut();
// delete headers specified in header.connection
remove_connection_header(headers);
// delete hop headers including header.connection
remove_hop_header(headers);
// X-Forwarded-For
add_forwarding_header(headers, client_addr)?;
// update uri in request
*req.uri_mut() = Uri::builder()
.scheme(destination_host_uri.scheme().unwrap().as_str())
.authority(destination_host_uri.authority().unwrap().as_str())
.scheme(destination_scheme_host.scheme().unwrap().as_str())
.authority(destination_scheme_host.authority().unwrap().as_str())
.path_and_query(&path_and_query)
.build()?;
// upgrade
if let Some(v) = upgrade {
req.headers_mut().insert("upgrade", v.parse().unwrap());
req
.headers_mut()
.insert("connection", HeaderValue::from_str("upgrade")?);
}
Ok(req)
}
fn add_forwarding_header(headers: &mut HeaderMap, client_addr: SocketAddr) -> Result<()> {
let client_ip = client_addr.ip();
match headers.entry("x-forwarded-for") {
hyper::header::Entry::Vacant(entry) => {
entry.insert(client_ip.to_string().parse()?);
}
hyper::header::Entry::Occupied(entry) => {
let client_ip_str = client_ip.to_string();
let mut addr = String::with_capacity(entry.get().as_bytes().len() + 2 + client_ip_str.len());
addr.push_str(std::str::from_utf8(entry.get().as_bytes()).unwrap());
addr.push(',');
addr.push(' ');
addr.push_str(&client_ip_str);
}
}
Ok(())
}
fn remove_connection_header(headers: &mut HeaderMap) {
if headers.get("connection").is_some() {
let v = headers.get("connection").cloned().unwrap();
for m in v.to_str().unwrap().split(',') {
if !m.is_empty() {
headers.remove(m.trim());
}
}
}
}
fn remove_hop_header(headers: &mut HeaderMap) {
let _ = HOP_HEADERS.iter().for_each(|key| {
headers.remove(*key);
});
}
fn http_error(status_code: StatusCode) -> Result<Response<Body>> {
let response = Response::builder()
.status(status_code)
@ -118,6 +183,25 @@ fn http_error(status_code: StatusCode) -> Result<Response<Body>> {
Ok(response)
}
fn extract_upgrade(headers: &HeaderMap) -> Option<String> {
if let Some(c) = headers.get("connection") {
if c
.to_str()
.unwrap_or("")
.split(',')
.into_iter()
.any(|w| w.trim().to_ascii_lowercase() == "upgrade")
{
if let Some(u) = headers.get("upgrade") {
let m = u.to_str().unwrap().to_string();
debug!("Upgrade in request header: {}", m);
return Some(m);
}
}
}
None
}
fn secure_redirection(
hostname: &str,
tls_port: Option<u16>,