This commit is contained in:
Jun Kurihara 2022-06-24 19:33:46 -04:00
commit 8a89fcb2c2
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
7 changed files with 153 additions and 97 deletions

View file

@ -14,7 +14,6 @@ publish = false
[features] [features]
default = ["tls"] default = ["tls"]
tls = ["tokio-rustls", "rustls-pemfile"] tls = ["tokio-rustls", "rustls-pemfile"]
forward-hyper-trust-dns = ["hyper-trust-dns"]
[dependencies] [dependencies]
anyhow = "1.0.57" anyhow = "1.0.57"
@ -46,9 +45,9 @@ hyper-trust-dns = { version = "0.4.2", default-features = false, features = [
"dnssec-ring", "dnssec-ring",
"dns-over-https-rustls", "dns-over-https-rustls",
"rustls-webpki", "rustls-webpki",
], optional = true } ] }
hyper-tls = "0.5.0"
rustls = "0.20.6" rustls = "0.20.6"
# phf = { version = "0.10", features = ["macros"] }
[dev-dependencies] [dev-dependencies]

View file

@ -13,13 +13,13 @@ https_port = 8443
# Backend settings # # Backend settings #
################################### ###################################
[[backend]] [[application]]
app_name = 'localhost' # this should be option, if null then same as hostname app_name = 'localhost' # this should be option, if null then same as hostname
hostname = 'localhost' hostname = 'localhost'
redirect_to_https = true https_redirection = true
reverse_proxy = [ reverse_proxy = [
{ path = '*', destination = 'https://192.168.10.0:3000' }, { path = '*', destination = '192.168.10.0:3000', tls = true },
{ path = '/path/to', destination = 'https://192.168.10.1:4000/path/to' }, { path = '/path/to', destination = '192.168.10.1:4000', tls = true },
] ]
## List of destinations to send data to. ## List of destinations to send data to.
## At this point, round-robin is used for load-balancing if multiple URLs are specified. ## At this point, round-robin is used for load-balancing if multiple URLs are specified.
@ -29,10 +29,10 @@ tls_cert_path = 'localhost1.pem'
tls_cert_key_path = 'localhost1.pem' tls_cert_key_path = 'localhost1.pem'
[[backend]] [[application]]
app_name = 'locahost_application' app_name = 'locahost_application'
hostname = 'localhost.localdomain' hostname = 'localhost.localdomain'
redirect_to_https = true https_redirection = true
reverse_proxy = [{ path = '/', destination = 'https://www.google.com/' }] reverse_proxy = [{ path = '/', destination = 'www.google.com', tls = true }]
tls_cert_path = 'localhost2.pem' tls_cert_path = 'localhost2.pem'
tls_cert_key_path = 'localhost2.pem' tls_cert_key_path = 'localhost2.pem'

View file

@ -12,7 +12,7 @@ pub struct Backend {
pub app_name: String, pub app_name: String,
pub hostname: String, pub hostname: String,
pub reverse_proxy: ReverseProxy, pub reverse_proxy: ReverseProxy,
pub redirect_to_https: Option<bool>, pub https_redirection: Option<bool>,
pub tls_cert_path: Option<PathBuf>, pub tls_cert_path: Option<PathBuf>,
pub tls_cert_key_path: Option<PathBuf>, pub tls_cert_key_path: Option<PathBuf>,
pub server_config: Mutex<Option<ServerConfig>>, pub server_config: Mutex<Option<ServerConfig>>,

View file

@ -24,7 +24,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend>
let mut map_example: HashMap<String, Uri> = HashMap::new(); let mut map_example: HashMap<String, Uri> = HashMap::new();
map_example.insert( map_example.insert(
"/maps".to_string(), "/maps".to_string(),
"https://bing.com/".parse::<Uri>().unwrap(), Uri::builder().authority("www.bing.com").build().unwrap(),
); );
backends.insert( backends.insert(
"localhost".to_string(), "localhost".to_string(),
@ -32,10 +32,10 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend>
app_name: "Localhost to Google except for maps".to_string(), app_name: "Localhost to Google except for maps".to_string(),
hostname: "localhost".to_string(), hostname: "localhost".to_string(),
reverse_proxy: ReverseProxy { reverse_proxy: ReverseProxy {
default_destination_uri: "https://google.com/".parse::<Uri>().unwrap(), default_destination_uri: Uri::builder().authority("www.google.com").build().unwrap(),
destination_uris: map_example, destination_uris: map_example,
}, },
redirect_to_https: Some(true), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず https_redirection: Some(true), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず
tls_cert_path: Some(PathBuf::from(r"localhost1.pem")), tls_cert_path: Some(PathBuf::from(r"localhost1.pem")),
tls_cert_key_path: Some(PathBuf::from(r"localhost1.pem")), tls_cert_key_path: Some(PathBuf::from(r"localhost1.pem")),

View file

@ -14,7 +14,6 @@ use crate::{
}; };
use futures::future::select_all; use futures::future::select_all;
use hyper::Client; use hyper::Client;
#[cfg(feature = "forward-hyper-trust-dns")]
use hyper_trust_dns::TrustDnsResolver; use hyper_trust_dns::TrustDnsResolver;
use std::{collections::HashMap, io::Write, sync::Arc}; use std::{collections::HashMap, io::Write, sync::Arc};
use tokio::time::Duration; use tokio::time::Duration;
@ -40,7 +39,7 @@ fn main() {
let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
runtime_builder.enable_all(); runtime_builder.enable_all();
runtime_builder.thread_name("rust-rpxy"); runtime_builder.thread_name("rpxy");
let runtime = runtime_builder.build().unwrap(); let runtime = runtime_builder.build().unwrap();
runtime.block_on(async { runtime.block_on(async {
@ -69,10 +68,7 @@ fn main() {
// entrypoint creates and spawns tasks of proxy services // entrypoint creates and spawns tasks of proxy services
async fn entrypoint(globals: Arc<Globals>, backends: Arc<HashMap<String, Backend>>) -> Result<()> { async fn entrypoint(globals: Arc<Globals>, backends: Arc<HashMap<String, Backend>>) -> Result<()> {
#[cfg(feature = "forward-hyper-trust-dns")]
let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector(); let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector();
#[cfg(not(feature = "forward-hyper-trust-dns"))]
let connector = hyper_tls::HttpsConnector::new();
let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector)); let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector));
let addresses = globals.listen_sockets.clone(); let addresses = globals.listen_sockets.clone();

View file

@ -1,69 +1,113 @@
use crate::{backend::Backend, error::*, globals::Globals, log::*}; use super::Proxy;
use futures::{ use crate::{error::*, log::*};
select,
task::{Context, Poll},
Future, FutureExt,
};
use hyper::{ use hyper::{
client::connect::Connect, client::connect::Connect,
http, header::{HeaderMap, HeaderName, HeaderValue},
server::conn::Http, Body, Request, Response, StatusCode, Uri,
service::{service_fn, Service},
Body, Client, HeaderMap, Method, Request, Response, StatusCode, Uri,
};
use std::{collections::HashMap, net::SocketAddr, pin::Pin, sync::Arc};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::TcpListener,
runtime::Handle,
time::Duration,
}; };
use std::net::SocketAddr;
// pub static HEADERS: phf::Map<&'static str, HeaderName> = phf_map! {
// "CONNECTION" => HeaderName::from_static("connection"),
// "ws" => "wss",
// };
impl<T> Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,
{
// TODO: ここでbackendの名前単位でリクエストを分岐させる // TODO: ここでbackendの名前単位でリクエストを分岐させる
pub async fn handle_request( pub async fn handle_request(
self,
req: Request<Body>, req: Request<Body>,
client_ip: SocketAddr, client_ip: SocketAddr, // アクセス制御用
tls_enabled: bool,
globals: Arc<Globals>,
backends: Arc<HashMap<String, Backend>>,
) -> Result<Response<Body>> { ) -> Result<Response<Body>> {
debug!("req: {:?}", req); debug!("Handling request: {:?}", req);
// Here we start to handle with hostname // Here we start to handle with hostname
// Find backend application for given hostname // Find backend application for given hostname
let (hostname, _port) = parse_hostname_port(&req, tls_enabled)?; let (hostname, _port) = if let Ok(v) = parse_host_port(&req, self.tls_enabled) {
let path = req.uri().path(); v
let path_and_query = req.uri().path_and_query().unwrap().as_str(); } else {
println!("{:?}", path_and_query); return http_error(StatusCode::SERVICE_UNAVAILABLE);
let backend = if let Some(be) = backends.get(hostname.as_str()) { };
let backend = if let Some(be) = self.backends.get(hostname.as_str()) {
be be
} else { } else {
return http_error(StatusCode::SERVICE_UNAVAILABLE); return http_error(StatusCode::SERVICE_UNAVAILABLE);
}; };
// Redirect to https if tls_enabled is false and redirect_to_https is true // Redirect to https if tls_enabled is false and redirect_to_https is true
if !tls_enabled && backend.redirect_to_https.unwrap_or(false) { let path_and_query = req.uri().path_and_query().unwrap().as_str().to_owned();
debug!("Redirect to https: {}", hostname); if !self.tls_enabled && backend.https_redirection.unwrap_or(false) {
return https_redirection(hostname, globals.https_port, path_and_query); debug!("Redirect to secure connection: {}", hostname);
return secure_redirection(&hostname, self.globals.https_port, &path_and_query);
} }
// Find reverse proxy for given path // Find reverse proxy for given path
let destination_uri = if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) { let path = req.uri().path();
let destination_host_uri = if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) {
uri.to_owned() uri.to_owned()
} else { } else {
backend.reverse_proxy.default_destination_uri.clone() backend.reverse_proxy.default_destination_uri.clone()
}; };
debug!("destination_uri: {}", destination_uri); // TODO: Upgrade
// if req.version() == hyper::Version::HTTP_11 { // TODO: X-Forwarded-For
// Ok(Response::new(Body::from("Hello World"))) // TODO: Transfer Encoding
// } else {
// Note: it's usually better to return a Response // Build request from destination information
// with an appropriate StatusCode instead of an Err. let req_forwarded = if let Ok(req) =
// Err("not HTTP/1.1, abort connection") generate_request_forwarded(client_ip, req, destination_host_uri, path_and_query)
// http_error(StatusCode::NOT_FOUND) {
https_redirection("www.google.com".to_string(), Some(443_u16), "/") req
} else {
error!("Failed to generate destination uri for reverse proxy");
return http_error(StatusCode::SERVICE_UNAVAILABLE);
};
debug!("Request to be forwarded: {:?}", req_forwarded);
// // Forward request to
// let res_backend = match self.forwarder.request(req_forwarded).await {
// Ok(res) => res,
// Err(e) => {
// error!("Failed to get response from backend: {}", e);
// return http_error(StatusCode::BAD_REQUEST);
// } // }
// }); // };
// debug!("Response from backend: {:?}", res_backend.status());
// Ok(res_backend)
http_error(StatusCode::NOT_FOUND)
}
}
// Motivated by https://github.com/felipenoris/hyper-reverse-proxy
fn generate_request_forwarded<B: core::fmt::Debug>(
client_ip: SocketAddr,
mut req: Request<B>,
destination_host_uri: Uri,
path_and_query: String,
) -> Result<Request<B>> {
debug!("Generate request to be forwarded");
// update "host" key in request header
if req.headers().contains_key("host") {
// HTTP/1.1
req.headers_mut().insert(
"host",
HeaderValue::from_str(destination_host_uri.host().unwrap())
.map_err(|_| anyhow!("Failed to insert destination host into forwarded request"))?,
);
}
// update uri in request
*req.uri_mut() = Uri::builder()
.scheme(destination_host_uri.scheme().unwrap().as_str())
.authority(destination_host_uri.authority().unwrap().as_str())
.path_and_query(&path_and_query)
.build()?;
Ok(req)
} }
fn http_error(status_code: StatusCode) -> Result<Response<Body>> { fn http_error(status_code: StatusCode) -> Result<Response<Body>> {
@ -74,19 +118,19 @@ fn http_error(status_code: StatusCode) -> Result<Response<Body>> {
Ok(response) Ok(response)
} }
fn https_redirection( fn secure_redirection(
hostname: String, hostname: &str,
https_port: Option<u16>, tls_port: Option<u16>,
path_and_query: &str, path_and_query: &str,
) -> Result<Response<Body>> { ) -> Result<Response<Body>> {
let dest_uri: String = if let Some(https_port) = https_port { let dest_uri: String = if let Some(tls_port) = tls_port {
if https_port == 443 { if tls_port == 443 {
format!("https://{}{}", hostname, path_and_query) format!("https://{}{}", hostname, path_and_query)
} else { } else {
format!("https://{}:{}{}", hostname, https_port, path_and_query) format!("https://{}:{}{}", hostname, tls_port, path_and_query)
} }
} else { } else {
return http_error(StatusCode::SERVICE_UNAVAILABLE); bail!("Internal error! TLS port is not set internally.");
}; };
let response = Response::builder() let response = Response::builder()
.status(StatusCode::MOVED_PERMANENTLY) .status(StatusCode::MOVED_PERMANENTLY)
@ -96,7 +140,7 @@ fn https_redirection(
Ok(response) Ok(response)
} }
fn parse_hostname_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String, u16)> { fn parse_host_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String, u16)> {
let hostname_port_headers = req.headers().get("host"); let hostname_port_headers = req.headers().get("host");
let hostname_uri = req.uri().host(); let hostname_uri = req.uri().host();
let port_uri = req.uri().port_u16(); let port_uri = req.uri().port_u16();
@ -107,7 +151,6 @@ fn parse_hostname_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String
let (hostname, port) = match (hostname_uri, hostname_port_headers) { let (hostname, port) = match (hostname_uri, hostname_port_headers) {
(Some(x), _) => { (Some(x), _) => {
let hostname = hostname_uri.unwrap();
let port = if let Some(p) = port_uri { let port = if let Some(p) = port_uri {
p p
} else if tls_enabled { } else if tls_enabled {
@ -115,7 +158,7 @@ fn parse_hostname_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String
} else { } else {
80 80
}; };
(hostname.to_string(), port) (x.to_string(), port)
} }
(None, Some(x)) => { (None, Some(x)) => {
let hp_as_uri = x.to_str().unwrap().parse::<Uri>().unwrap(); let hp_as_uri = x.to_str().unwrap().parse::<Uri>().unwrap();
@ -138,3 +181,29 @@ fn parse_hostname_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String
Ok((hostname, port)) Ok((hostname, port))
} }
// fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
// #[allow(clippy::blocks_in_if_conditions)]
// if headers
// .get(&*CONNECTION_HEADER)
// .map(|value| {
// value
// .to_str()
// .unwrap()
// .split(',')
// .any(|e| e.trim() == *UPGRADE_HEADER)
// })
// .unwrap_or(false)
// {
// if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
// debug!(
// "Found upgrade header with value: {}",
// upgrade_value.to_str().unwrap().to_owned()
// );
// return Some(upgrade_value.to_str().unwrap().to_owned());
// }
// }
// None
// }

View file

@ -1,7 +1,7 @@
use super::proxy_handler::handle_request; // use super::proxy_handler::handle_request;
use crate::{backend::Backend, error::*, globals::Globals, log::*}; use crate::{backend::Backend, error::*, globals::Globals, log::*};
use hyper::{ use hyper::{
client::connect::Connect, server::conn::Http, service::service_fn, Body, Client, Method, Request, client::connect::Connect, server::conn::Http, service::service_fn, Body, Client, Request,
}; };
use std::{collections::HashMap, net::SocketAddr, sync::Arc}; use std::{collections::HashMap, net::SocketAddr, sync::Arc};
use tokio::{ use tokio::{
@ -64,15 +64,7 @@ where
// server.serve_connection(stream, self), // server.serve_connection(stream, self),
server.serve_connection( server.serve_connection(
stream, stream,
service_fn(move |req: Request<Body>| { service_fn(move |req: Request<Body>| self.clone().handle_request(req, peer_addr)),
handle_request(
req,
peer_addr,
self.tls_enabled,
self.globals.clone(),
self.backends.clone(),
)
}),
), ),
) )
.await .await