add checking mechanism of consistency between sni and host/request line

This commit is contained in:
Jun Kurihara 2022-07-15 23:55:21 +09:00
commit d37ed57a1c
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
11 changed files with 111 additions and 69 deletions

View file

@ -66,3 +66,4 @@ reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }]
###################################
[experimental]
h3 = true
ignore_sni_consistency = false # Higly recommend not to be true. If true, you ignore RFC.

View file

@ -163,6 +163,12 @@ pub fn parse_opts(globals: &mut Globals) -> Result<()> {
info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable.")
}
}
if let Some(b) = exp.ignore_sni_consistency {
globals.sni_consistency = !b;
if b {
info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC.")
}
}
}
Ok(())

View file

@ -18,6 +18,7 @@ pub struct ConfigToml {
#[derive(Deserialize, Debug, Default)]
pub struct Experimental {
pub h3: Option<bool>,
pub ignore_sni_consistency: Option<bool>,
}
#[derive(Deserialize, Debug, Default)]

View file

@ -19,6 +19,7 @@ pub struct Globals {
pub max_concurrent_streams: u32,
pub keepalive: bool,
pub http3: bool,
pub sni_consistency: bool,
pub runtime_handle: tokio::runtime::Handle,

View file

@ -60,6 +60,7 @@ fn main() {
http_port: None,
https_port: None,
http3: false,
sni_consistency: true,
// TODO: Reconsider each timeout values
proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC),

View file

@ -1,13 +1,19 @@
// Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
use super::{utils_headers::*, utils_request::*, utils_response::ResLog, utils_synth_response::*};
use crate::{backend::Upstream, constants::*, error::*, globals::Globals, log::*};
use crate::{
backend::{ServerNameLC, Upstream},
constants::*,
error::*,
globals::Globals,
log::*,
};
use hyper::{
client::connect::Connect,
header::{self, HeaderValue},
http::uri::Scheme,
Body, Client, Request, Response, StatusCode, Uri, Version,
};
use std::{net::SocketAddr, sync::Arc};
use std::{env, net::SocketAddr, sync::Arc};
use tokio::{
io::copy_bidirectional,
time::{timeout, Duration},
@ -32,14 +38,19 @@ where
client_addr: SocketAddr, // アクセス制御用
listen_addr: SocketAddr,
tls_enabled: bool,
tls_server_name: Option<ServerNameLC>,
) -> Result<Response<Body>> {
req.log_debug(&client_addr, Some("(Request from Client)"));
// Here we start to handle with server_name
// Find backend application for given server_name, and drop if incoming request is invalid as request.
// let (server_name, _port) = parse_host_port(&req)?;
let server_name_bytes = req.parse_host()?.to_ascii_lowercase();
// check consistency of between TLS SNI and HOST/Request URI Line.
if self.globals.sni_consistency
&& !server_name_bytes.eq_ignore_ascii_case(&tls_server_name.unwrap())
{
return http_error(StatusCode::MISDIRECTED_REQUEST);
}
// Find backend application for given server_name, and drop if incoming request is invalid as request.
let backend = if let Some(be) = self.globals.backends.apps.get(&server_name_bytes) {
be
} else if let Some(default_server_name) = &self.globals.backends.default_server_name {
@ -91,7 +102,7 @@ where
return http_error(StatusCode::SERVICE_UNAVAILABLE);
};
// debug!("Request to be forwarded: {:?}", req_forwarded);
req_forwarded.log(&client_addr, Some("(Request to Backend)"));
req_forwarded.log_debug(&client_addr, Some("(Request to Backend)"));
// Forward request to
let mut res_backend = {
@ -168,7 +179,7 @@ where
// Generate response to client
if self.generate_response_forwarded(&mut res_backend).is_ok() {
// info!("{} => {}", request_log, response_log);
res_backend.log(
res_backend.log_debug(
&backend.server_name,
&client_addr,
Some("(Response to Client)"),

View file

@ -19,27 +19,40 @@ impl<B> ReqLog for &Request<B> {
fn build_message<T: Display + ToCanonical>(self, src: &T, extra: Option<&str>) -> String {
let canonical_src = src.to_canonical();
let server_name = self.headers().get(header::HOST).map_or_else(
|| {
self
.uri()
.authority()
.map_or_else(|| "<none>", |au| au.as_str())
},
|h| h.to_str().unwrap_or("<none>"),
);
let host = self
.headers()
.get(header::HOST)
.map_or_else(|| "", |v| v.to_str().unwrap_or(""));
let uri_scheme = self
.uri()
.scheme_str()
.map_or_else(|| "".to_string(), |v| format!("{}://", v));
let uri_host = self.uri().host().unwrap_or("");
let uri_pq = self
.uri()
.path_and_query()
.map_or_else(|| "", |v| v.as_str());
let ua = self
.headers()
.get(header::USER_AGENT)
.map_or_else(|| "", |v| v.to_str().unwrap_or(""));
let xff = self
.headers()
.get("x-forwarded-for")
.map_or_else(|| "", |v| v.to_str().unwrap_or(""));
format!(
"{} <- {} -- {} {:?} {:?} {:?} {}",
server_name,
"{} <- {} -- {} {} {:?} -- ({}{}) \"{}\" \"{}\" {}",
host,
canonical_src,
self.method(),
uri_pq,
self.version(),
self
.uri()
.path_and_query()
.map_or_else(|| "", |v| v.as_str()),
self.headers(),
extra.map_or_else(|| "", |v| v)
uri_scheme,
uri_host,
ua,
xff,
extra.unwrap_or("")
)
}
}

View file

@ -49,12 +49,12 @@ impl<B> ResLog for &Response<B> {
) -> String {
let canonical_client_addr = client_addr.to_canonical();
format!(
"{} <- {} -- {} {:?} {:?} {}",
"{} <- {} -- {} {:?} {}",
canonical_client_addr,
server_name,
self.status(),
self.version(),
self.headers(),
// self.headers(),
extra.map_or_else(|| "", |v| v)
)
}

View file

@ -1,5 +1,5 @@
use super::Proxy;
use crate::{error::*, log::*};
use crate::{backend::ServerNameLC, error::*, log::*};
use bytes::{Buf, Bytes};
use h3::{quic::BidiStream, server::RequestStream};
use hyper::{client::connect::Connect, Body, HeaderMap, Request, Response};
@ -10,13 +10,15 @@ impl<T> Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,
{
pub async fn client_serve_h3(&self, conn: quinn::Connecting) {
pub async fn client_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) {
let clients_count = self.globals.clients_count.clone();
if clients_count.increment() > self.globals.max_clients {
clients_count.decrement();
return;
}
let fut = self.clone().handle_connection_h3(conn);
let fut = self
.clone()
.handle_connection_h3(conn, tls_server_name.to_vec());
self.globals.runtime_handle.spawn(async move {
// Timeout is based on underlying quic
if let Err(e) = fut.await {
@ -27,31 +29,22 @@ where
});
}
pub async fn handle_connection_h3(self, conn: quinn::Connecting) -> Result<()> {
pub async fn handle_connection_h3(
self,
conn: quinn::Connecting,
tls_server_name: ServerNameLC,
) -> Result<()> {
let client_addr = conn.remote_address();
match conn.await {
Ok(new_conn) => {
info!("QUIC connection established from {:?} {:?}", client_addr, {
let hsd = new_conn
.connection
.handshake_data()
.ok_or_else(|| anyhow!(""))?
.downcast::<quinn::crypto::rustls::HandshakeData>()
.map_err(|_| anyhow!(""))?;
(
hsd.protocol.map_or_else(
|| "<none>".into(),
|x| String::from_utf8_lossy(&x).into_owned(),
),
hsd.server_name.map_or_else(|| "<none>".into(), |x| x),
)
});
let mut h3_conn =
h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn))
.await?;
info!("HTTP/3 connection established");
info!(
"QUIC/HTTP3 connection established from {:?} {:?}",
client_addr, tls_server_name
);
// Does this work enough?
// while let Some((req, stream)) = h3_conn
@ -73,10 +66,11 @@ where
);
let self_inner = self.clone();
let tls_server_name_inner = tls_server_name.clone();
self.globals.runtime_handle.spawn(async move {
if let Err(e) = timeout(
self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
self_inner.handle_stream_h3(req, stream, client_addr),
self_inner.handle_stream_h3(req, stream, client_addr, tls_server_name_inner),
)
.await
{
@ -99,6 +93,7 @@ where
req: Request<()>,
mut stream: RequestStream<S, Bytes>,
client_addr: SocketAddr,
tls_server_name: ServerNameLC,
) -> Result<()>
where
S: BidiStream<Bytes>,
@ -128,7 +123,13 @@ where
let res = self
.msg_handler
.clone()
.handle_request(new_req, client_addr, self.listening_on, self.tls_enabled)
.handle_request(
new_req,
client_addr,
self.listening_on,
self.tls_enabled,
Some(tls_server_name),
)
.await?;
let (new_res_parts, new_body) = res.into_parts();

View file

@ -45,8 +45,13 @@ impl<T> Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,
{
pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr)
where
pub async fn client_serve<I>(
self,
stream: I,
server: Http<LocalExecutor>,
peer_addr: SocketAddr,
tls_server_name: Option<&[u8]>,
) where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let clients_count = self.globals.clients_count.clone();
@ -55,7 +60,7 @@ where
return;
}
// let handler_inner = self.msg_handler.clone();
let inner = tls_server_name.map_or_else(|| None, |v| Some(v.to_vec()));
self.globals.runtime_handle.clone().spawn(async move {
timeout(
self.globals.proxy_timeout + Duration::from_secs(1),
@ -68,6 +73,7 @@ where
peer_addr,
self.listening_on,
self.tls_enabled,
inner.clone(),
)
}),
)
@ -88,7 +94,7 @@ where
while let Ok((stream, _client_addr)) = tcp_listener.accept().await {
self
.clone()
.client_serve(stream, server.clone(), _client_addr)
.client_serve(stream, server.clone(), _client_addr, None)
.await;
}
Ok(()) as Result<()>

View file

@ -85,7 +85,7 @@ where
};
// Finally serve the TLS connection
if let Ok(stream) = start.into_stream(server_crypto.unwrap().clone()).await {
self.clone().client_serve(stream, server.clone(), _client_addr).await
self.clone().client_serve(stream, server.clone(), _client_addr, Some(server_name.as_bytes())).await
}
}
_ = server_crypto_rx.changed().fuse() => {
@ -101,11 +101,11 @@ where
}
#[cfg(feature = "h3")]
async fn parse_sni_and_get_crypto_h3(
async fn parse_sni_and_get_crypto_h3<'a>(
&self,
peeked_conn: &mut quinn::Connecting,
server_crypto_map: &ServerCryptoMap,
) -> Option<Arc<ServerConfig>> {
server_crypto_map: &'a ServerCryptoMap,
) -> Option<(&'a ServerNameLC, &'a Arc<ServerConfig>)> {
let hsd = if let Ok(h) = peeked_conn.handshake_data().await {
h
} else {
@ -121,9 +121,8 @@ where
"HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig",
server_name
);
server_crypto_map
.get(&server_name.as_bytes().to_vec())
.cloned()
server_crypto_map.get_key_value(&server_name.into_bytes())
// .map_or_else(|| None, |(k, v)| Some((k.clone(), v.clone())));
}
#[cfg(feature = "h3")]
@ -173,19 +172,21 @@ where
continue;
}
let peeked_conn = peeked_conn.unwrap();
let is_acceptable =
if let Some(new_server_crypto) = self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await {
let new_server_name = match self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await {
Some((new_server_name, new_server_crypto)) => {
// Set ServerConfig::set_server_config for given SNI
endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto)));
true
} else {
false
};
endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto.clone())));
Some(new_server_name)
},
None => None
};
// Then acquire actual connection
let peekable_incoming = Pin::new(&mut p);
if let Some(conn) = peekable_incoming.get_mut().next().await {
if is_acceptable {
self.clone().client_serve_h3(conn).await;
if let Some(new_server_name) = new_server_name {
self.clone().client_serve_h3(conn, new_server_name).await;
}
} else {
continue;