From d37ed57a1cd80efd5f22ff69dae3dd103e1bf69a Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Jul 2022 23:55:21 +0900 Subject: [PATCH] add checking mechanism of consistency between sni and host/request line --- config-example.toml | 1 + src/config/parse.rs | 6 ++++ src/config/toml.rs | 1 + src/globals.rs | 1 + src/main.rs | 1 + src/msg_handler/handler.rs | 25 +++++++++++----- src/msg_handler/utils_request.rs | 47 ++++++++++++++++++++----------- src/msg_handler/utils_response.rs | 4 +-- src/proxy/proxy_h3.rs | 47 ++++++++++++++++--------------- src/proxy/proxy_main.rs | 14 ++++++--- src/proxy/proxy_tls.rs | 33 +++++++++++----------- 11 files changed, 111 insertions(+), 69 deletions(-) diff --git a/config-example.toml b/config-example.toml index 3f567fd..68c0e8a 100644 --- a/config-example.toml +++ b/config-example.toml @@ -66,3 +66,4 @@ reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }] ################################### [experimental] h3 = true +ignore_sni_consistency = false # Higly recommend not to be true. If true, you ignore RFC. diff --git a/src/config/parse.rs b/src/config/parse.rs index 484b941..fc0b19f 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -163,6 +163,12 @@ pub fn parse_opts(globals: &mut Globals) -> Result<()> { info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable.") } } + if let Some(b) = exp.ignore_sni_consistency { + globals.sni_consistency = !b; + if b { + info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC.") + } + } } Ok(()) diff --git a/src/config/toml.rs b/src/config/toml.rs index 784da8c..19813ab 100644 --- a/src/config/toml.rs +++ b/src/config/toml.rs @@ -18,6 +18,7 @@ pub struct ConfigToml { #[derive(Deserialize, Debug, Default)] pub struct Experimental { pub h3: Option, + pub ignore_sni_consistency: Option, } #[derive(Deserialize, Debug, Default)] diff --git a/src/globals.rs b/src/globals.rs index 846390b..abec03f 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -19,6 +19,7 @@ pub struct Globals { pub max_concurrent_streams: u32, pub keepalive: bool, pub http3: bool, + pub sni_consistency: bool, pub runtime_handle: tokio::runtime::Handle, diff --git a/src/main.rs b/src/main.rs index 1abb2fd..591062d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -60,6 +60,7 @@ fn main() { http_port: None, https_port: None, http3: false, + sni_consistency: true, // TODO: Reconsider each timeout values proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC), diff --git a/src/msg_handler/handler.rs b/src/msg_handler/handler.rs index eaab9b9..135b971 100644 --- a/src/msg_handler/handler.rs +++ b/src/msg_handler/handler.rs @@ -1,13 +1,19 @@ // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy use super::{utils_headers::*, utils_request::*, utils_response::ResLog, utils_synth_response::*}; -use crate::{backend::Upstream, constants::*, error::*, globals::Globals, log::*}; +use crate::{ + backend::{ServerNameLC, Upstream}, + constants::*, + error::*, + globals::Globals, + log::*, +}; use hyper::{ client::connect::Connect, header::{self, HeaderValue}, http::uri::Scheme, Body, Client, Request, Response, StatusCode, Uri, Version, }; -use std::{net::SocketAddr, sync::Arc}; +use std::{env, net::SocketAddr, sync::Arc}; use tokio::{ io::copy_bidirectional, time::{timeout, Duration}, @@ -32,14 +38,19 @@ where client_addr: SocketAddr, // アクセス制御用 listen_addr: SocketAddr, tls_enabled: bool, + tls_server_name: Option, ) -> Result> { req.log_debug(&client_addr, Some("(Request from Client)")); // Here we start to handle with server_name - // Find backend application for given server_name, and drop if incoming request is invalid as request. - // let (server_name, _port) = parse_host_port(&req)?; let server_name_bytes = req.parse_host()?.to_ascii_lowercase(); - + // check consistency of between TLS SNI and HOST/Request URI Line. + if self.globals.sni_consistency + && !server_name_bytes.eq_ignore_ascii_case(&tls_server_name.unwrap()) + { + return http_error(StatusCode::MISDIRECTED_REQUEST); + } + // Find backend application for given server_name, and drop if incoming request is invalid as request. let backend = if let Some(be) = self.globals.backends.apps.get(&server_name_bytes) { be } else if let Some(default_server_name) = &self.globals.backends.default_server_name { @@ -91,7 +102,7 @@ where return http_error(StatusCode::SERVICE_UNAVAILABLE); }; // debug!("Request to be forwarded: {:?}", req_forwarded); - req_forwarded.log(&client_addr, Some("(Request to Backend)")); + req_forwarded.log_debug(&client_addr, Some("(Request to Backend)")); // Forward request to let mut res_backend = { @@ -168,7 +179,7 @@ where // Generate response to client if self.generate_response_forwarded(&mut res_backend).is_ok() { // info!("{} => {}", request_log, response_log); - res_backend.log( + res_backend.log_debug( &backend.server_name, &client_addr, Some("(Response to Client)"), diff --git a/src/msg_handler/utils_request.rs b/src/msg_handler/utils_request.rs index fe9433d..d68054e 100644 --- a/src/msg_handler/utils_request.rs +++ b/src/msg_handler/utils_request.rs @@ -19,27 +19,40 @@ impl ReqLog for &Request { fn build_message(self, src: &T, extra: Option<&str>) -> String { let canonical_src = src.to_canonical(); - let server_name = self.headers().get(header::HOST).map_or_else( - || { - self - .uri() - .authority() - .map_or_else(|| "", |au| au.as_str()) - }, - |h| h.to_str().unwrap_or(""), - ); + let host = self + .headers() + .get(header::HOST) + .map_or_else(|| "", |v| v.to_str().unwrap_or("")); + let uri_scheme = self + .uri() + .scheme_str() + .map_or_else(|| "".to_string(), |v| format!("{}://", v)); + let uri_host = self.uri().host().unwrap_or(""); + let uri_pq = self + .uri() + .path_and_query() + .map_or_else(|| "", |v| v.as_str()); + let ua = self + .headers() + .get(header::USER_AGENT) + .map_or_else(|| "", |v| v.to_str().unwrap_or("")); + let xff = self + .headers() + .get("x-forwarded-for") + .map_or_else(|| "", |v| v.to_str().unwrap_or("")); + format!( - "{} <- {} -- {} {:?} {:?} {:?} {}", - server_name, + "{} <- {} -- {} {} {:?} -- ({}{}) \"{}\" \"{}\" {}", + host, canonical_src, self.method(), + uri_pq, self.version(), - self - .uri() - .path_and_query() - .map_or_else(|| "", |v| v.as_str()), - self.headers(), - extra.map_or_else(|| "", |v| v) + uri_scheme, + uri_host, + ua, + xff, + extra.unwrap_or("") ) } } diff --git a/src/msg_handler/utils_response.rs b/src/msg_handler/utils_response.rs index 3bcf89f..9be933f 100644 --- a/src/msg_handler/utils_response.rs +++ b/src/msg_handler/utils_response.rs @@ -49,12 +49,12 @@ impl ResLog for &Response { ) -> String { let canonical_client_addr = client_addr.to_canonical(); format!( - "{} <- {} -- {} {:?} {:?} {}", + "{} <- {} -- {} {:?} {}", canonical_client_addr, server_name, self.status(), self.version(), - self.headers(), + // self.headers(), extra.map_or_else(|| "", |v| v) ) } diff --git a/src/proxy/proxy_h3.rs b/src/proxy/proxy_h3.rs index af9284c..069e533 100644 --- a/src/proxy/proxy_h3.rs +++ b/src/proxy/proxy_h3.rs @@ -1,5 +1,5 @@ use super::Proxy; -use crate::{error::*, log::*}; +use crate::{backend::ServerNameLC, error::*, log::*}; use bytes::{Buf, Bytes}; use h3::{quic::BidiStream, server::RequestStream}; use hyper::{client::connect::Connect, Body, HeaderMap, Request, Response}; @@ -10,13 +10,15 @@ impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - pub async fn client_serve_h3(&self, conn: quinn::Connecting) { + pub async fn client_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) { let clients_count = self.globals.clients_count.clone(); if clients_count.increment() > self.globals.max_clients { clients_count.decrement(); return; } - let fut = self.clone().handle_connection_h3(conn); + let fut = self + .clone() + .handle_connection_h3(conn, tls_server_name.to_vec()); self.globals.runtime_handle.spawn(async move { // Timeout is based on underlying quic if let Err(e) = fut.await { @@ -27,31 +29,22 @@ where }); } - pub async fn handle_connection_h3(self, conn: quinn::Connecting) -> Result<()> { + pub async fn handle_connection_h3( + self, + conn: quinn::Connecting, + tls_server_name: ServerNameLC, + ) -> Result<()> { let client_addr = conn.remote_address(); match conn.await { Ok(new_conn) => { - info!("QUIC connection established from {:?} {:?}", client_addr, { - let hsd = new_conn - .connection - .handshake_data() - .ok_or_else(|| anyhow!(""))? - .downcast::() - .map_err(|_| anyhow!(""))?; - ( - hsd.protocol.map_or_else( - || "".into(), - |x| String::from_utf8_lossy(&x).into_owned(), - ), - hsd.server_name.map_or_else(|| "".into(), |x| x), - ) - }); - let mut h3_conn = h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn)) .await?; - info!("HTTP/3 connection established"); + info!( + "QUIC/HTTP3 connection established from {:?} {:?}", + client_addr, tls_server_name + ); // Does this work enough? // while let Some((req, stream)) = h3_conn @@ -73,10 +66,11 @@ where ); let self_inner = self.clone(); + let tls_server_name_inner = tls_server_name.clone(); self.globals.runtime_handle.spawn(async move { if let Err(e) = timeout( self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2 - self_inner.handle_stream_h3(req, stream, client_addr), + self_inner.handle_stream_h3(req, stream, client_addr, tls_server_name_inner), ) .await { @@ -99,6 +93,7 @@ where req: Request<()>, mut stream: RequestStream, client_addr: SocketAddr, + tls_server_name: ServerNameLC, ) -> Result<()> where S: BidiStream, @@ -128,7 +123,13 @@ where let res = self .msg_handler .clone() - .handle_request(new_req, client_addr, self.listening_on, self.tls_enabled) + .handle_request( + new_req, + client_addr, + self.listening_on, + self.tls_enabled, + Some(tls_server_name), + ) .await?; let (new_res_parts, new_body) = res.into_parts(); diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index 366045c..70054a3 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -45,8 +45,13 @@ impl Proxy where T: Connect + Clone + Sync + Send + 'static, { - pub async fn client_serve(self, stream: I, server: Http, peer_addr: SocketAddr) - where + pub async fn client_serve( + self, + stream: I, + server: Http, + peer_addr: SocketAddr, + tls_server_name: Option<&[u8]>, + ) where I: AsyncRead + AsyncWrite + Send + Unpin + 'static, { let clients_count = self.globals.clients_count.clone(); @@ -55,7 +60,7 @@ where return; } - // let handler_inner = self.msg_handler.clone(); + let inner = tls_server_name.map_or_else(|| None, |v| Some(v.to_vec())); self.globals.runtime_handle.clone().spawn(async move { timeout( self.globals.proxy_timeout + Duration::from_secs(1), @@ -68,6 +73,7 @@ where peer_addr, self.listening_on, self.tls_enabled, + inner.clone(), ) }), ) @@ -88,7 +94,7 @@ where while let Ok((stream, _client_addr)) = tcp_listener.accept().await { self .clone() - .client_serve(stream, server.clone(), _client_addr) + .client_serve(stream, server.clone(), _client_addr, None) .await; } Ok(()) as Result<()> diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index c4a9d70..bc38b3c 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -85,7 +85,7 @@ where }; // Finally serve the TLS connection if let Ok(stream) = start.into_stream(server_crypto.unwrap().clone()).await { - self.clone().client_serve(stream, server.clone(), _client_addr).await + self.clone().client_serve(stream, server.clone(), _client_addr, Some(server_name.as_bytes())).await } } _ = server_crypto_rx.changed().fuse() => { @@ -101,11 +101,11 @@ where } #[cfg(feature = "h3")] - async fn parse_sni_and_get_crypto_h3( + async fn parse_sni_and_get_crypto_h3<'a>( &self, peeked_conn: &mut quinn::Connecting, - server_crypto_map: &ServerCryptoMap, - ) -> Option> { + server_crypto_map: &'a ServerCryptoMap, + ) -> Option<(&'a ServerNameLC, &'a Arc)> { let hsd = if let Ok(h) = peeked_conn.handshake_data().await { h } else { @@ -121,9 +121,8 @@ where "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", server_name ); - server_crypto_map - .get(&server_name.as_bytes().to_vec()) - .cloned() + server_crypto_map.get_key_value(&server_name.into_bytes()) + // .map_or_else(|| None, |(k, v)| Some((k.clone(), v.clone()))); } #[cfg(feature = "h3")] @@ -173,19 +172,21 @@ where continue; } let peeked_conn = peeked_conn.unwrap(); - let is_acceptable = - if let Some(new_server_crypto) = self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await { + + let new_server_name = match self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await { + Some((new_server_name, new_server_crypto)) => { // Set ServerConfig::set_server_config for given SNI - endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto))); - true - } else { - false - }; + endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto.clone()))); + Some(new_server_name) + }, + None => None + }; + // Then acquire actual connection let peekable_incoming = Pin::new(&mut p); if let Some(conn) = peekable_incoming.get_mut().next().await { - if is_acceptable { - self.clone().client_serve_h3(conn).await; + if let Some(new_server_name) = new_server_name { + self.clone().client_serve_h3(conn, new_server_name).await; } } else { continue;