rust-rpxy/src/proxy/proxy_handler.rs
2022-07-06 14:58:49 +09:00

369 lines
11 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Heavily inspired by https://github.com/felipenoris/hyper-reverse-proxy
use super::Proxy;
use crate::{constants::*, error::*, log::*};
use hyper::{
client::connect::Connect,
header::{HeaderMap, HeaderValue},
http::uri::Scheme,
Body, Request, Response, StatusCode, Uri, Version,
};
use std::net::SocketAddr;
use tokio::io::copy_bidirectional;
/// Header names treated as hop-by-hop by this proxy.
///
/// Every entry is stripped (via `remove_hop_header`) from requests before they are
/// forwarded upstream and from backend responses before they are returned to the client.
/// Names are lowercase; listed alphabetically for easier review.
const HOP_HEADERS: &[&str] = &[
    "connection",
    "keep-alive",
    "proxy-authenticate",
    "proxy-authorization",
    "proxy-connection",
    "te",
    "trailer",
    "transfer-encoding",
    "upgrade",
];
impl<T> Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,
{
pub async fn handle_request(
self,
mut req: Request<Body>,
client_addr: SocketAddr, // アクセス制御用
) -> Result<Response<Body>> {
info!(
"Handling {:?} request from {}: {} {} {:?}",
req.version(),
client_addr,
req.method(),
req.uri(),
req
.headers()
.get("user-agent")
.map_or_else(|| "<none>", |ua| ua.to_str().unwrap())
);
// Here we start to handle with server_name
// Find backend application for given server_name
let (server_name, _port) = if let Ok(v) = parse_host_port(&req, self.tls_enabled) {
v
} else {
return http_error(StatusCode::SERVICE_UNAVAILABLE);
};
let backend = if let Some(be) = self.backends.apps.get(server_name.as_str()) {
be
} else if let Some(default_be) = &self.backends.default_app {
debug!("Serving by default app: {}", default_be);
self.backends.apps.get(default_be).unwrap()
} else {
return http_error(StatusCode::SERVICE_UNAVAILABLE);
};
// Redirect to https if tls_enabled is false and redirect_to_https is true
let path_and_query = req.uri().path_and_query().unwrap().as_str().to_owned();
if !self.tls_enabled && backend.https_redirection.unwrap_or(false) {
debug!("Redirect to secure connection: {}", server_name);
return secure_redirection(&server_name, self.globals.https_port, &path_and_query);
}
// Find reverse proxy for given path and choose one of upstream host
// TODO: More flexible path matcher
let path = req.uri().path();
let upstream_uri = if let Some(upstream) = backend.reverse_proxy.upstream.get(path) {
upstream.get()
} else if let Some(default_upstream) = &backend.reverse_proxy.default_upstream {
default_upstream.get()
} else {
return http_error(StatusCode::NOT_FOUND);
};
let upstream_scheme_host = if let Some(u) = upstream_uri {
u
} else {
return http_error(StatusCode::INTERNAL_SERVER_ERROR);
};
// Upgrade in request header
let upgrade_in_request = extract_upgrade(req.headers());
let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();
// Build request from destination information
let req_forwarded = if let Ok(req) = generate_request_forwarded(
client_addr,
req,
upstream_scheme_host,
path_and_query,
&upgrade_in_request,
) {
req
} else {
error!("Failed to generate destination uri for reverse proxy");
return http_error(StatusCode::SERVICE_UNAVAILABLE);
};
debug!("Request to be forwarded: {:?}", req_forwarded);
// Forward request to
let mut res_backend = match self.forwarder.request(req_forwarded).await {
Ok(res) => res,
Err(e) => {
error!("Failed to get response from backend: {}", e);
return http_error(StatusCode::BAD_REQUEST);
}
};
#[cfg(feature = "h3")]
{
if self.globals.http3 {
if let Some(port) = self.globals.https_port {
res_backend.headers_mut().insert(
hyper::header::ALT_SVC,
format!(
"h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}",
port, H3_ALT_SVC_MAX_AGE, port, H3_ALT_SVC_MAX_AGE
)
.parse()
.unwrap(),
);
}
}
}
debug!("Response from backend: {:?}", res_backend.status());
if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS {
// Handle StatusCode::SWITCHING_PROTOCOLS in response
let upgrade_in_response = extract_upgrade(res_backend.headers());
if upgrade_in_request == upgrade_in_response {
if let Some(request_upgraded) = request_upgraded {
let mut response_upgraded = res_backend
.extensions_mut()
.remove::<hyper::upgrade::OnUpgrade>()
.ok_or_else(|| anyhow!("Response does not have an upgrade extension"))? // TODO: any response code?
.await?;
// TODO: H3で死ぬことがある
// thread 'rpxy' panicked at 'Failed to upgrade request: hyper::Error(User(ManualUpgrade))', src/proxy/proxy_handler.rs:124:63
tokio::spawn(async move {
let mut request_upgraded = request_upgraded.await.map_err(|e| {
error!("Failed to upgrade request: {}", e);
anyhow!("Failed to upgrade request: {}", e)
})?; // TODO: any response code?
copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
.await
.map_err(|e| anyhow!("Coping between upgraded connections failed: {}", e))?; // TODO: any response code?
Ok(()) as Result<()>
});
Ok(res_backend)
} else {
error!("Request does not have an upgrade extension");
http_error(StatusCode::BAD_GATEWAY)
}
} else {
error!(
"Backend tried to switch to protocol {:?} when {:?} was requested",
upgrade_in_response, upgrade_in_request
);
http_error(StatusCode::BAD_GATEWAY)
}
} else {
// Generate response to client
if generate_response_forwarded(&mut res_backend).is_ok() {
Ok(res_backend)
} else {
http_error(StatusCode::BAD_GATEWAY)
}
}
}
}
/// Strips hop-by-hop headers from a backend response before it is returned to the client.
///
/// Headers named in `Connection` are removed first, while `Connection` itself is still
/// present. The fixed hop-by-hop set (which includes `Connection`) is removed afterwards.
/// Running these in the opposite order would delete `Connection` before its listed headers
/// could be found. This matches the order used for forwarded requests.
///
/// Currently always returns `Ok(())`.
fn generate_response_forwarded<B: core::fmt::Debug>(response: &mut Response<B>) -> Result<()> {
    let headers = response.headers_mut();
    remove_connection_header(headers);
    remove_hop_header(headers);
    Ok(())
}
fn generate_request_forwarded<B: core::fmt::Debug>(
client_addr: SocketAddr,
mut req: Request<B>,
upstream_scheme_host: &Uri,
path_and_query: String,
upgrade: &Option<String>,
) -> Result<Request<B>> {
debug!("Generate request to be forwarded");
// Add te: trailer if contained in original request
let te_trailers = {
if let Some(te) = req.headers().get("te") {
te.to_str()
.unwrap()
.split(',')
.any(|x| x.trim() == "trailers")
} else {
false
}
};
let headers = req.headers_mut();
// delete headers specified in header.connection
remove_connection_header(headers);
// delete hop headers including header.connection
remove_hop_header(headers);
// X-Forwarded-For
add_forwarding_header(headers, client_addr)?;
// Add te: trailer if te_trailer
if te_trailers {
headers.insert("te", "trailer".parse().unwrap());
}
// update uri in request
*req.uri_mut() = Uri::builder()
.scheme(upstream_scheme_host.scheme().unwrap().as_str())
.authority(upstream_scheme_host.authority().unwrap().as_str())
.path_and_query(&path_and_query)
.build()?;
// upgrade
if let Some(v) = upgrade {
req.headers_mut().insert("upgrade", v.parse().unwrap());
req
.headers_mut()
.insert("connection", HeaderValue::from_str("upgrade")?);
}
// Change version to http/1.1 when destination scheme is http
if req.version() != Version::HTTP_11 && upstream_scheme_host.scheme() == Some(&Scheme::HTTP) {
*req.version_mut() = Version::HTTP_11;
} else if req.version() == Version::HTTP_3 {
debug!("HTTP/3 is currently unsupported for request to upstream. Use HTTP/2.");
*req.version_mut() = Version::HTTP_2;
}
Ok(req)
}
fn add_forwarding_header(headers: &mut HeaderMap, client_addr: SocketAddr) -> Result<()> {
// TODO: Option対応
let client_ip = client_addr.ip();
match headers.entry("x-forwarded-for") {
hyper::header::Entry::Vacant(entry) => {
entry.insert(client_ip.to_string().parse()?);
}
hyper::header::Entry::Occupied(entry) => {
let client_ip_str = client_ip.to_string();
let mut addr = String::with_capacity(entry.get().as_bytes().len() + 2 + client_ip_str.len());
addr.push_str(std::str::from_utf8(entry.get().as_bytes()).unwrap());
addr.push(',');
addr.push(' ');
addr.push_str(&client_ip_str);
}
}
Ok(())
}
/// Removes every header whose name is listed in the `Connection` header(s).
///
/// `Connection` itself is left in place; callers remove it afterwards together with the
/// other hop-by-hop headers. All `Connection` headers are honoured. Values that are not
/// visible ASCII are ignored rather than causing a panic. Empty list elements (e.g.
/// `"a, , b"`) are skipped.
fn remove_connection_header(headers: &mut HeaderMap) {
    // Collect the names first: the map cannot be mutated while its values are borrowed.
    let listed: Vec<String> = headers
        .get_all("connection")
        .iter()
        .filter_map(|v| v.to_str().ok())
        .flat_map(|v| v.split(','))
        .map(str::trim)
        .filter(|name| !name.is_empty())
        .map(String::from)
        .collect();
    for name in &listed {
        headers.remove(name.as_str());
    }
}
/// Deletes every header named in `HOP_HEADERS` from `headers`; absent names are ignored.
fn remove_hop_header(headers: &mut HeaderMap) {
    for name in HOP_HEADERS {
        headers.remove(*name);
    }
}
/// Builds a response with an empty body and the given `status_code`.
///
/// Never fails. The `Result` wrapper lets request handlers `return` it directly.
fn http_error(status_code: StatusCode) -> Result<Response<Body>> {
    let mut empty_response = Response::new(Body::empty());
    *empty_response.status_mut() = status_code;
    Ok(empty_response)
}
/// Returns the `Upgrade` header value when a `Connection` header contains the `upgrade` token.
///
/// Returns `None` in three cases:
/// - no `Connection` header lists `upgrade` (compared case-insensitively);
/// - `Upgrade` is missing;
/// - a value is not visible ASCII (this is handled instead of panicking).
///
/// Used for both request and response headers.
fn extract_upgrade(headers: &HeaderMap) -> Option<String> {
    let connection_has_upgrade = headers
        .get_all("connection")
        .iter()
        .filter_map(|c| c.to_str().ok())
        .flat_map(|c| c.split(','))
        .any(|token| token.trim().eq_ignore_ascii_case("upgrade"));
    if !connection_has_upgrade {
        return None;
    }
    let m = headers.get("upgrade")?.to_str().ok()?.to_string();
    debug!("Upgrade in request header: {}", m);
    Some(m)
}
fn secure_redirection(
server_name: &str,
tls_port: Option<u16>,
path_and_query: &str,
) -> Result<Response<Body>> {
let dest_uri: String = if let Some(tls_port) = tls_port {
if tls_port == 443 {
format!("https://{}{}", server_name, path_and_query)
} else {
format!("https://{}:{}{}", server_name, tls_port, path_and_query)
}
} else {
bail!("Internal error! TLS port is not set internally.");
};
let response = Response::builder()
.status(StatusCode::MOVED_PERMANENTLY)
.header("Location", dest_uri)
.body(Body::empty())
.unwrap();
Ok(response)
}
/// Extracts the target host name and port of a request.
///
/// The host from the request URI takes precedence; otherwise the `Host` header is parsed.
/// When no port is given, 443 is assumed if `tls_enabled`, else 80.
///
/// # Errors
/// Fails in two cases, reported as errors rather than panics:
/// - neither the URI nor a `Host` header provides a host;
/// - the client-supplied `Host` header is not a valid ASCII authority.
fn parse_host_port<B: core::fmt::Debug>(
    req: &Request<B>,
    tls_enabled: bool,
) -> Result<(String, u16)> {
    let default_port = if tls_enabled { 443 } else { 80 };

    if let Some(host) = req.uri().host() {
        let port = req.uri().port_u16().unwrap_or(default_port);
        return Ok((host.to_string(), port));
    }

    let host_header = req
        .headers()
        .get("host")
        .ok_or_else(|| anyhow!("No host in request header"))?;
    // The Host header is client-controlled: reject malformed values instead of panicking.
    let hp_as_uri = host_header
        .to_str()
        .map_err(|e| anyhow!("Invalid host header: {}", e))?
        .parse::<Uri>()
        .map_err(|e| anyhow!("Failed to parse host header: {}", e))?;
    let host = hp_as_uri
        .host()
        .ok_or_else(|| anyhow!("Failed to parse host"))?;
    let port = hp_as_uri.port_u16().unwrap_or(default_port);
    Ok((host.to_string(), port))
}