some refactor

This commit is contained in:
Jun Kurihara 2022-07-09 23:57:08 +09:00
commit 9335fc702e
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
6 changed files with 146 additions and 88 deletions

View file

@ -91,7 +91,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()>
// each app // each app
for (app_name, app) in apps.0.iter() { for (app_name, app) in apps.0.iter() {
ensure!(app.server_name.is_some(), "Missing server_name"); ensure!(app.server_name.is_some(), "Missing server_name");
let server_name = app.server_name.as_ref().unwrap(); let server_name = app.server_name.as_ref().unwrap().to_ascii_lowercase();
// TLS settings // TLS settings
let (tls_cert_path, tls_cert_key_path, https_redirection) = if app.tls.is_none() { let (tls_cert_path, tls_cert_key_path, https_redirection) = if app.tls.is_none() {
@ -122,7 +122,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()>
let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?;
backends.apps.insert( backends.apps.insert(
server_name.to_owned(), server_name.as_bytes().to_vec(),
Backend { Backend {
app_name: app_name.to_owned(), app_name: app_name.to_owned(),
server_name: server_name.to_owned(), server_name: server_name.to_owned(),
@ -149,7 +149,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()>
"Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).",
d, d_sn[0] d, d_sn[0]
); );
backends.default_app = Some(d_sn[0].to_owned()); backends.default_server_name = Some(d_sn[0].as_bytes().to_vec());
} }
} }

View file

@ -16,7 +16,7 @@ use crate::{
error::*, error::*,
globals::*, globals::*,
log::*, log::*,
proxy::{Backend, Backends, Proxy}, proxy::{Backend, Backends, Proxy, ServerNameLC},
}; };
use futures::future::select_all; use futures::future::select_all;
use hyper::Client; use hyper::Client;
@ -64,8 +64,8 @@ fn main() {
}; };
let mut backends = Backends { let mut backends = Backends {
default_app: None, default_server_name: None,
apps: HashMap::<String, Backend>::default(), apps: HashMap::<ServerNameLC, Backend>::default(),
}; };
parse_opts(&mut globals, &mut backends).expect("Invalid configuration"); parse_opts(&mut globals, &mut backends).expect("Invalid configuration");

View file

@ -13,9 +13,12 @@ use std::{
}; };
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig};
// server name (hostname or ip address) in ascii lower case
pub type ServerNameLC = Vec<u8>;
pub struct Backends { pub struct Backends {
pub apps: HashMap<String, Backend>, // TODO: hyper::uriで抜いたhostで引っ掛ける。Stringでいいのか pub apps: HashMap<ServerNameLC, Backend>, // TODO: hyper::uriで抜いたhostで引っ掛ける。Stringでいいのか
pub default_app: Option<String>, // for plaintext http pub default_server_name: Option<ServerNameLC>, // for plaintext http
} }
pub struct Backend { pub struct Backend {

View file

@ -23,27 +23,27 @@ where
// Here we start to handle with server_name // Here we start to handle with server_name
// Find backend application for given server_name, and drop if incoming request is invalid as request. // Find backend application for given server_name, and drop if incoming request is invalid as request.
let (server_name, _port) = parse_host_port(&req)?; // let (server_name, _port) = parse_host_port(&req)?;
let server_name_bytes = req.parse_host()?.to_ascii_lowercase();
if !self.backends.apps.contains_key(&server_name) && self.backends.default_app.is_none() { let backend = if let Some(be) = self.backends.apps.get(&server_name_bytes) {
be
} else if let Some(default_server_name) = &self.backends.default_server_name {
debug!("Serving by default app");
self.backends.apps.get(default_server_name).unwrap()
} else {
// info!("{} => {}", request_log, StatusCode::SERVICE_UNAVAILABLE); // info!("{} => {}", request_log, StatusCode::SERVICE_UNAVAILABLE);
return http_error(StatusCode::SERVICE_UNAVAILABLE); return http_error(StatusCode::SERVICE_UNAVAILABLE);
}
let backend = if let Some(be) = self.backends.apps.get(&server_name) {
be
} else {
let default_be = self.backends.default_app.as_ref().unwrap();
debug!("Serving by default app: {}", default_be);
self.backends.apps.get(default_be).unwrap()
}; };
// Redirect to https if !tls_enabled and redirect_to_https is true // Redirect to https if !tls_enabled and redirect_to_https is true
if !self.tls_enabled && backend.https_redirection.unwrap_or(false) { if !self.tls_enabled && backend.https_redirection.unwrap_or(false) {
debug!("Redirect to secure connection: {}", server_name); debug!("Redirect to secure connection: {}", &backend.server_name);
// info!("{} => {}", request_log, StatusCode::PERMANENT_REDIRECT); // info!("{} => {}", request_log, StatusCode::PERMANENT_REDIRECT);
return secure_redirection(&server_name, self.globals.https_port, &req); return secure_redirection(&backend.server_name, self.globals.https_port, &req);
} }
///////////////////////
// Find reverse proxy for given path and choose one of upstream host // Find reverse proxy for given path and choose one of upstream host
// TODO: More flexible path matcher // TODO: More flexible path matcher
let path = req.uri().path(); let path = req.uri().path();
@ -59,6 +59,7 @@ where
} else { } else {
return http_error(StatusCode::INTERNAL_SERVER_ERROR); return http_error(StatusCode::INTERNAL_SERVER_ERROR);
}; };
///////////////////////
// Upgrade in request header // Upgrade in request header
let upgrade_in_request = extract_upgrade(req.headers()); let upgrade_in_request = extract_upgrade(req.headers());
@ -77,7 +78,8 @@ where
error!("Failed to generate destination uri for reverse proxy"); error!("Failed to generate destination uri for reverse proxy");
return http_error(StatusCode::SERVICE_UNAVAILABLE); return http_error(StatusCode::SERVICE_UNAVAILABLE);
}; };
debug!("Request to be forwarded: {:?}", req_forwarded); // debug!("Request to be forwarded: {:?}", req_forwarded);
req_forwarded.log(&client_addr, Some("Forwarding"));
// Forward request to // Forward request to
let mut res_backend = match self.forwarder.request(req_forwarded).await { let mut res_backend = match self.forwarder.request(req_forwarded).await {
@ -87,21 +89,6 @@ where
return http_error(StatusCode::BAD_REQUEST); return http_error(StatusCode::BAD_REQUEST);
} }
}; };
#[cfg(feature = "h3")]
{
if self.globals.http3 {
if let Some(port) = self.globals.https_port {
let alt_svc_value = HeaderValue::from_str(&format!(
"h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}",
port, H3_ALT_SVC_MAX_AGE, port, H3_ALT_SVC_MAX_AGE
))
.unwrap();
res_backend
.headers_mut()
.insert(header::ALT_SVC, alt_svc_value);
}
}
}
debug!("Response from backend: {:?}", res_backend.status()); debug!("Response from backend: {:?}", res_backend.status());
// let response_log = res_backend.status().to_string(); // let response_log = res_backend.status().to_string();
@ -175,6 +162,21 @@ where
"server", "server",
&format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")), &format!("{}/{}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")),
)?; )?;
#[cfg(feature = "h3")]
{
if self.globals.http3 {
if let Some(port) = self.globals.https_port {
append_header_entry(
headers,
header::ALT_SVC.as_str(),
&format!(
"h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}",
port, H3_ALT_SVC_MAX_AGE, port, H3_ALT_SVC_MAX_AGE
),
)?;
}
}
}
Ok(()) Ok(())
} }
@ -192,7 +194,9 @@ where
// Add te: trailer if contained in original request // Add te: trailer if contained in original request
let te_trailers = { let te_trailers = {
if let Some(te) = req.headers().get(header::TE) { if let Some(te) = req.headers().get(header::TE) {
te.to_str()?.split(',').any(|x| x.trim() == "trailers") te.as_bytes()
.split(|v| v == &b',' || v == &b' ')
.any(|x| x == "trailers".as_bytes())
} else { } else {
false false
} }
@ -205,7 +209,6 @@ where
remove_hop_header(headers); remove_hop_header(headers);
// X-Forwarded-For // X-Forwarded-For
add_forwarding_header(headers, client_addr, self.tls_enabled)?; add_forwarding_header(headers, client_addr, self.tls_enabled)?;
// println!("{:?}", headers);
// Add te: trailer if te_trailer // Add te: trailer if te_trailer
if te_trailers { if te_trailers {
@ -214,10 +217,14 @@ where
// add "host" header of original server_name if not exist (default) // add "host" header of original server_name if not exist (default)
if req.headers().get(header::HOST).is_none() { if req.headers().get(header::HOST).is_none() {
let org_host = req.uri().host().unwrap_or("none").to_owned(); let org_host = req
.uri()
.host()
.ok_or_else(|| anyhow!("Invalid request"))?
.to_owned();
req req
.headers_mut() .headers_mut()
.insert(header::HOST, HeaderValue::from_str(org_host.as_str())?); .insert(header::HOST, HeaderValue::from_str(&org_host)?);
}; };
// apply upstream-specific headers given in upstream_option // apply upstream-specific headers given in upstream_option

View file

@ -1,4 +1,7 @@
use super::proxy_main::{LocalExecutor, Proxy}; use super::{
proxy_main::{LocalExecutor, Proxy},
ServerNameLC,
};
use crate::{constants::*, error::*, log::*}; use crate::{constants::*, error::*, log::*};
#[cfg(feature = "h3")] #[cfg(feature = "h3")]
use futures::StreamExt; use futures::StreamExt;
@ -9,7 +12,7 @@ use rustls::ServerConfig;
use std::sync::Arc; use std::sync::Arc;
use tokio::{net::TcpListener, sync::watch, time::Duration}; use tokio::{net::TcpListener, sync::watch, time::Duration};
type ServerCryptoMap = HashMap<String, ServerConfig>; type ServerCryptoMap = HashMap<ServerNameLC, Arc<ServerConfig>>;
impl<T> Proxy<T> impl<T> Proxy<T>
where where
@ -18,16 +21,19 @@ where
async fn cert_service(&self, server_crypto_tx: watch::Sender<Option<ServerCryptoMap>>) { async fn cert_service(&self, server_crypto_tx: watch::Sender<Option<ServerCryptoMap>>) {
info!("Start cert watch service"); info!("Start cert watch service");
loop { loop {
let mut hm_server_config = HashMap::<String, ServerConfig>::default(); let mut hm_server_config = HashMap::<ServerNameLC, Arc<ServerConfig>>::default();
for (server_name, backend) in self.backends.apps.iter() { for (server_name_bytes, backend) in self.backends.apps.iter() {
if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() {
match backend.update_server_config().await { match backend.update_server_config().await {
Err(_e) => { Err(_e) => {
error!("Failed to update certs for {}: {}", server_name, _e); error!(
"Failed to update certs for {}: {}",
&backend.server_name, _e
);
break; break;
} }
Ok(server_config) => { Ok(server_config) => {
hm_server_config.insert(server_name.to_owned(), server_config); hm_server_config.insert(server_name_bytes.to_vec(), Arc::new(server_config));
} }
} }
} }
@ -66,21 +72,20 @@ where
let start = acceptor.unwrap(); let start = acceptor.unwrap();
let client_hello = start.client_hello(); let client_hello = start.client_hello();
debug!("SNI in ClientHello: {:?}", client_hello.server_name());
// Find server config for given SNI // Find server config for given SNI
let svn = if let Some(svn) = client_hello.server_name() { if client_hello.server_name().is_none(){
svn
} else {
info!("No SNI in ClientHello"); info!("No SNI in ClientHello");
continue; continue;
}; }
let server_crypto = if let Some(p) = server_crypto_map.as_ref().unwrap().get(svn) { let server_name = client_hello.server_name().unwrap().to_ascii_lowercase();
p.to_owned() debug!("SNI in ClientHello: {:?}", server_name);
} else { let server_crypto = server_crypto_map.as_ref().unwrap().get(server_name.as_bytes());
if server_crypto.is_none() {
debug!("No TLS serving app for {}", server_name);
continue; continue;
}; };
// Finally serve the TLS connection // Finally serve the TLS connection
if let Ok(stream) = start.into_stream(Arc::new(server_crypto)).await { if let Ok(stream) = start.into_stream(server_crypto.unwrap().clone()).await {
self.clone().client_serve(stream, server.clone(), _client_addr).await self.clone().client_serve(stream, server.clone(), _client_addr).await
} }
} }
@ -101,7 +106,7 @@ where
&self, &self,
peeked_conn: &mut quinn::Connecting, peeked_conn: &mut quinn::Connecting,
server_crypto_map: &ServerCryptoMap, server_crypto_map: &ServerCryptoMap,
) -> Option<ServerConfig> { ) -> Option<Arc<ServerConfig>> {
let hsd = if let Ok(h) = peeked_conn.handshake_data().await { let hsd = if let Ok(h) = peeked_conn.handshake_data().await {
h h
} else { } else {
@ -112,12 +117,14 @@ where
} else { } else {
return None; return None;
}; };
let server_name = hsd_downcast.server_name?; let server_name = hsd_downcast.server_name?.to_ascii_lowercase();
info!( info!(
"HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig",
server_name server_name
); );
server_crypto_map.get(&server_name).cloned() server_crypto_map
.get(&server_name.as_bytes().to_vec())
.cloned()
} }
#[cfg(feature = "h3")] #[cfg(feature = "h3")]
@ -127,20 +134,20 @@ where
) -> Result<()> { ) -> Result<()> {
// TODO: Work around to initially serve incoming connection // TODO: Work around to initially serve incoming connection
// かなり適当。エラーが出たり出なかったり。原因がわからない… // かなり適当。エラーが出たり出なかったり。原因がわからない…
let tls_app_names: Vec<String> = self let next = self
.backends .backends
.apps .apps
.iter() .iter()
.filter(|&(_, backend)| { .filter(|&(_, backend)| {
backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some()
}) })
.map(|(name, _)| name.to_string()) .map(|(name, _)| name)
.collect(); .next();
ensure!(!tls_app_names.is_empty(), "No TLS supported app"); ensure!(next.is_some(), "No TLS supported app");
let initial_app_name = tls_app_names.get(0).ok_or_else(|| anyhow!(""))?.as_str(); let initial_app_name = next.ok_or_else(|| anyhow!(""))?;
debug!( debug!(
"HTTP/3 SNI multiplexer initial app_name: {}", "HTTP/3 SNI multiplexer initial app_name: {:?}",
initial_app_name String::from_utf8(initial_app_name.to_vec())
); );
let backend_serve = self let backend_serve = self
.backends .backends
@ -168,7 +175,7 @@ where
let is_acceptable = let is_acceptable =
if let Some(new_server_crypto) = self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await { if let Some(new_server_crypto) = self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await {
// Set ServerConfig::set_server_config for given SNI // Set ServerConfig::set_server_config for given SNI
endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(Arc::new(new_server_crypto)))); endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto)));
true true
} else { } else {
false false

View file

@ -1,5 +1,5 @@
use crate::{error::*, log::*, utils::*}; use crate::{error::*, log::*, utils::*};
use hyper::{header, Request, Uri}; use hyper::{header, Request};
use std::fmt::Display; use std::fmt::Display;
//////////////////////////////////////////////////// ////////////////////////////////////////////////////
@ -36,12 +36,14 @@ impl<B> MsgLog for &Request<B> {
} }
} }
pub(super) fn parse_host_port<B: core::fmt::Debug>( pub trait ParseHost {
req: &Request<B>, fn parse_host(&self) -> Result<&[u8]>;
) -> Result<(String, Option<u16>)> { }
let headers_host = req.headers().get("host"); impl<B> ParseHost for Request<B> {
let uri_host = req.uri().host(); fn parse_host(&self) -> Result<&[u8]> {
let uri_port = req.uri().port_u16(); let headers_host = self.headers().get(header::HOST);
let uri_host = self.uri().host();
// let uri_port = self.uri().port_u16();
ensure!( ensure!(
!(headers_host.is_none() && uri_host.is_none()), !(headers_host.is_none() && uri_host.is_none()),
@ -49,16 +51,55 @@ pub(super) fn parse_host_port<B: core::fmt::Debug>(
); );
// prioritize server_name in uri // prioritize server_name in uri
if let Some(v) = uri_host { uri_host.map_or_else(
Ok((v.to_string(), uri_port)) || {
let m = headers_host.unwrap().as_bytes();
if m.starts_with(&[b'[']) {
println!("v6 bracket");
// v6 address with bracket case. if port is specified, always it is in this case.
let mut iter = m.split(|ptr| ptr == &b'[' || ptr == &b']');
iter.next().ok_or_else(|| anyhow!("Invalid Host"))?; // first item is always blank
iter.next().ok_or_else(|| anyhow!("Invalid Host"))
} else if m.len() - m.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 {
println!("v6 non-bracket");
// v6 address case, if 2 or more ':' is contained
Ok(m)
} else { } else {
let uri_from_host = headers_host.unwrap().to_str()?.parse::<Uri>()?; // v4 address or hostname
Ok(( m.split(|colon| colon == &b':')
uri_from_host .into_iter()
.host() .next()
.ok_or_else(|| anyhow!("Failed to parse host"))? .ok_or_else(|| anyhow!("Invalid Host"))
.to_string(), }
uri_from_host.port_u16(), },
)) |v| Ok(v.as_bytes()),
)
} }
} }
// pub(super) fn parse_host_port<B: core::fmt::Debug>(
// req: &Request<B>,
// ) -> Result<(String, Option<u16>)> {
// let headers_host = req.headers().get("host");
// let uri_host = req.uri().host();
// let uri_port = req.uri().port_u16();
// ensure!(
// !(headers_host.is_none() && uri_host.is_none()),
// "No host in request header"
// );
// // prioritize server_name in uri
// if let Some(v) = uri_host {
// Ok((v.to_string(), uri_port))
// } else {
// let uri_from_host = headers_host.unwrap().to_str()?.parse::<Uri>()?;
// Ok((
// uri_from_host
// .host()
// .ok_or_else(|| anyhow!("Failed to parse host"))?
// .to_string(),
// uri_from_host.port_u16(),
// ))
// }
// }