add checking mechanism of consistency between sni and host/request line
This commit is contained in:
parent
4f5a1cbf91
commit
d37ed57a1c
11 changed files with 111 additions and 69 deletions
|
|
@ -66,3 +66,4 @@ reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }]
|
|||
###################################
|
||||
[experimental]
|
||||
h3 = true
|
||||
ignore_sni_consistency = false # Higly recommend not to be true. If true, you ignore RFC.
|
||||
|
|
|
|||
|
|
@ -163,6 +163,12 @@ pub fn parse_opts(globals: &mut Globals) -> Result<()> {
|
|||
info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable.")
|
||||
}
|
||||
}
|
||||
if let Some(b) = exp.ignore_sni_consistency {
|
||||
globals.sni_consistency = !b;
|
||||
if b {
|
||||
info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ pub struct ConfigToml {
|
|||
#[derive(Deserialize, Debug, Default)]
|
||||
pub struct Experimental {
|
||||
pub h3: Option<bool>,
|
||||
pub ignore_sni_consistency: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug, Default)]
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ pub struct Globals {
|
|||
pub max_concurrent_streams: u32,
|
||||
pub keepalive: bool,
|
||||
pub http3: bool,
|
||||
pub sni_consistency: bool,
|
||||
|
||||
pub runtime_handle: tokio::runtime::Handle,
|
||||
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ fn main() {
|
|||
http_port: None,
|
||||
https_port: None,
|
||||
http3: false,
|
||||
sni_consistency: true,
|
||||
|
||||
// TODO: Reconsider each timeout values
|
||||
proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC),
|
||||
|
|
|
|||
|
|
@ -1,13 +1,19 @@
|
|||
// Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
|
||||
use super::{utils_headers::*, utils_request::*, utils_response::ResLog, utils_synth_response::*};
|
||||
use crate::{backend::Upstream, constants::*, error::*, globals::Globals, log::*};
|
||||
use crate::{
|
||||
backend::{ServerNameLC, Upstream},
|
||||
constants::*,
|
||||
error::*,
|
||||
globals::Globals,
|
||||
log::*,
|
||||
};
|
||||
use hyper::{
|
||||
client::connect::Connect,
|
||||
header::{self, HeaderValue},
|
||||
http::uri::Scheme,
|
||||
Body, Client, Request, Response, StatusCode, Uri, Version,
|
||||
};
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::{env, net::SocketAddr, sync::Arc};
|
||||
use tokio::{
|
||||
io::copy_bidirectional,
|
||||
time::{timeout, Duration},
|
||||
|
|
@ -32,14 +38,19 @@ where
|
|||
client_addr: SocketAddr, // アクセス制御用
|
||||
listen_addr: SocketAddr,
|
||||
tls_enabled: bool,
|
||||
tls_server_name: Option<ServerNameLC>,
|
||||
) -> Result<Response<Body>> {
|
||||
req.log_debug(&client_addr, Some("(Request from Client)"));
|
||||
|
||||
// Here we start to handle with server_name
|
||||
// Find backend application for given server_name, and drop if incoming request is invalid as request.
|
||||
// let (server_name, _port) = parse_host_port(&req)?;
|
||||
let server_name_bytes = req.parse_host()?.to_ascii_lowercase();
|
||||
|
||||
// check consistency of between TLS SNI and HOST/Request URI Line.
|
||||
if self.globals.sni_consistency
|
||||
&& !server_name_bytes.eq_ignore_ascii_case(&tls_server_name.unwrap())
|
||||
{
|
||||
return http_error(StatusCode::MISDIRECTED_REQUEST);
|
||||
}
|
||||
// Find backend application for given server_name, and drop if incoming request is invalid as request.
|
||||
let backend = if let Some(be) = self.globals.backends.apps.get(&server_name_bytes) {
|
||||
be
|
||||
} else if let Some(default_server_name) = &self.globals.backends.default_server_name {
|
||||
|
|
@ -91,7 +102,7 @@ where
|
|||
return http_error(StatusCode::SERVICE_UNAVAILABLE);
|
||||
};
|
||||
// debug!("Request to be forwarded: {:?}", req_forwarded);
|
||||
req_forwarded.log(&client_addr, Some("(Request to Backend)"));
|
||||
req_forwarded.log_debug(&client_addr, Some("(Request to Backend)"));
|
||||
|
||||
// Forward request to
|
||||
let mut res_backend = {
|
||||
|
|
@ -168,7 +179,7 @@ where
|
|||
// Generate response to client
|
||||
if self.generate_response_forwarded(&mut res_backend).is_ok() {
|
||||
// info!("{} => {}", request_log, response_log);
|
||||
res_backend.log(
|
||||
res_backend.log_debug(
|
||||
&backend.server_name,
|
||||
&client_addr,
|
||||
Some("(Response to Client)"),
|
||||
|
|
|
|||
|
|
@ -19,27 +19,40 @@ impl<B> ReqLog for &Request<B> {
|
|||
fn build_message<T: Display + ToCanonical>(self, src: &T, extra: Option<&str>) -> String {
|
||||
let canonical_src = src.to_canonical();
|
||||
|
||||
let server_name = self.headers().get(header::HOST).map_or_else(
|
||||
|| {
|
||||
self
|
||||
.uri()
|
||||
.authority()
|
||||
.map_or_else(|| "<none>", |au| au.as_str())
|
||||
},
|
||||
|h| h.to_str().unwrap_or("<none>"),
|
||||
);
|
||||
let host = self
|
||||
.headers()
|
||||
.get(header::HOST)
|
||||
.map_or_else(|| "", |v| v.to_str().unwrap_or(""));
|
||||
let uri_scheme = self
|
||||
.uri()
|
||||
.scheme_str()
|
||||
.map_or_else(|| "".to_string(), |v| format!("{}://", v));
|
||||
let uri_host = self.uri().host().unwrap_or("");
|
||||
let uri_pq = self
|
||||
.uri()
|
||||
.path_and_query()
|
||||
.map_or_else(|| "", |v| v.as_str());
|
||||
let ua = self
|
||||
.headers()
|
||||
.get(header::USER_AGENT)
|
||||
.map_or_else(|| "", |v| v.to_str().unwrap_or(""));
|
||||
let xff = self
|
||||
.headers()
|
||||
.get("x-forwarded-for")
|
||||
.map_or_else(|| "", |v| v.to_str().unwrap_or(""));
|
||||
|
||||
format!(
|
||||
"{} <- {} -- {} {:?} {:?} {:?} {}",
|
||||
server_name,
|
||||
"{} <- {} -- {} {} {:?} -- ({}{}) \"{}\" \"{}\" {}",
|
||||
host,
|
||||
canonical_src,
|
||||
self.method(),
|
||||
uri_pq,
|
||||
self.version(),
|
||||
self
|
||||
.uri()
|
||||
.path_and_query()
|
||||
.map_or_else(|| "", |v| v.as_str()),
|
||||
self.headers(),
|
||||
extra.map_or_else(|| "", |v| v)
|
||||
uri_scheme,
|
||||
uri_host,
|
||||
ua,
|
||||
xff,
|
||||
extra.unwrap_or("")
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,12 +49,12 @@ impl<B> ResLog for &Response<B> {
|
|||
) -> String {
|
||||
let canonical_client_addr = client_addr.to_canonical();
|
||||
format!(
|
||||
"{} <- {} -- {} {:?} {:?} {}",
|
||||
"{} <- {} -- {} {:?} {}",
|
||||
canonical_client_addr,
|
||||
server_name,
|
||||
self.status(),
|
||||
self.version(),
|
||||
self.headers(),
|
||||
// self.headers(),
|
||||
extra.map_or_else(|| "", |v| v)
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use super::Proxy;
|
||||
use crate::{error::*, log::*};
|
||||
use crate::{backend::ServerNameLC, error::*, log::*};
|
||||
use bytes::{Buf, Bytes};
|
||||
use h3::{quic::BidiStream, server::RequestStream};
|
||||
use hyper::{client::connect::Connect, Body, HeaderMap, Request, Response};
|
||||
|
|
@ -10,13 +10,15 @@ impl<T> Proxy<T>
|
|||
where
|
||||
T: Connect + Clone + Sync + Send + 'static,
|
||||
{
|
||||
pub async fn client_serve_h3(&self, conn: quinn::Connecting) {
|
||||
pub async fn client_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) {
|
||||
let clients_count = self.globals.clients_count.clone();
|
||||
if clients_count.increment() > self.globals.max_clients {
|
||||
clients_count.decrement();
|
||||
return;
|
||||
}
|
||||
let fut = self.clone().handle_connection_h3(conn);
|
||||
let fut = self
|
||||
.clone()
|
||||
.handle_connection_h3(conn, tls_server_name.to_vec());
|
||||
self.globals.runtime_handle.spawn(async move {
|
||||
// Timeout is based on underlying quic
|
||||
if let Err(e) = fut.await {
|
||||
|
|
@ -27,31 +29,22 @@ where
|
|||
});
|
||||
}
|
||||
|
||||
pub async fn handle_connection_h3(self, conn: quinn::Connecting) -> Result<()> {
|
||||
pub async fn handle_connection_h3(
|
||||
self,
|
||||
conn: quinn::Connecting,
|
||||
tls_server_name: ServerNameLC,
|
||||
) -> Result<()> {
|
||||
let client_addr = conn.remote_address();
|
||||
|
||||
match conn.await {
|
||||
Ok(new_conn) => {
|
||||
info!("QUIC connection established from {:?} {:?}", client_addr, {
|
||||
let hsd = new_conn
|
||||
.connection
|
||||
.handshake_data()
|
||||
.ok_or_else(|| anyhow!(""))?
|
||||
.downcast::<quinn::crypto::rustls::HandshakeData>()
|
||||
.map_err(|_| anyhow!(""))?;
|
||||
(
|
||||
hsd.protocol.map_or_else(
|
||||
|| "<none>".into(),
|
||||
|x| String::from_utf8_lossy(&x).into_owned(),
|
||||
),
|
||||
hsd.server_name.map_or_else(|| "<none>".into(), |x| x),
|
||||
)
|
||||
});
|
||||
|
||||
let mut h3_conn =
|
||||
h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn))
|
||||
.await?;
|
||||
info!("HTTP/3 connection established");
|
||||
info!(
|
||||
"QUIC/HTTP3 connection established from {:?} {:?}",
|
||||
client_addr, tls_server_name
|
||||
);
|
||||
|
||||
// Does this work enough?
|
||||
// while let Some((req, stream)) = h3_conn
|
||||
|
|
@ -73,10 +66,11 @@ where
|
|||
);
|
||||
|
||||
let self_inner = self.clone();
|
||||
let tls_server_name_inner = tls_server_name.clone();
|
||||
self.globals.runtime_handle.spawn(async move {
|
||||
if let Err(e) = timeout(
|
||||
self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
|
||||
self_inner.handle_stream_h3(req, stream, client_addr),
|
||||
self_inner.handle_stream_h3(req, stream, client_addr, tls_server_name_inner),
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
|
@ -99,6 +93,7 @@ where
|
|||
req: Request<()>,
|
||||
mut stream: RequestStream<S, Bytes>,
|
||||
client_addr: SocketAddr,
|
||||
tls_server_name: ServerNameLC,
|
||||
) -> Result<()>
|
||||
where
|
||||
S: BidiStream<Bytes>,
|
||||
|
|
@ -128,7 +123,13 @@ where
|
|||
let res = self
|
||||
.msg_handler
|
||||
.clone()
|
||||
.handle_request(new_req, client_addr, self.listening_on, self.tls_enabled)
|
||||
.handle_request(
|
||||
new_req,
|
||||
client_addr,
|
||||
self.listening_on,
|
||||
self.tls_enabled,
|
||||
Some(tls_server_name),
|
||||
)
|
||||
.await?;
|
||||
|
||||
let (new_res_parts, new_body) = res.into_parts();
|
||||
|
|
|
|||
|
|
@ -45,8 +45,13 @@ impl<T> Proxy<T>
|
|||
where
|
||||
T: Connect + Clone + Sync + Send + 'static,
|
||||
{
|
||||
pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr)
|
||||
where
|
||||
pub async fn client_serve<I>(
|
||||
self,
|
||||
stream: I,
|
||||
server: Http<LocalExecutor>,
|
||||
peer_addr: SocketAddr,
|
||||
tls_server_name: Option<&[u8]>,
|
||||
) where
|
||||
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
let clients_count = self.globals.clients_count.clone();
|
||||
|
|
@ -55,7 +60,7 @@ where
|
|||
return;
|
||||
}
|
||||
|
||||
// let handler_inner = self.msg_handler.clone();
|
||||
let inner = tls_server_name.map_or_else(|| None, |v| Some(v.to_vec()));
|
||||
self.globals.runtime_handle.clone().spawn(async move {
|
||||
timeout(
|
||||
self.globals.proxy_timeout + Duration::from_secs(1),
|
||||
|
|
@ -68,6 +73,7 @@ where
|
|||
peer_addr,
|
||||
self.listening_on,
|
||||
self.tls_enabled,
|
||||
inner.clone(),
|
||||
)
|
||||
}),
|
||||
)
|
||||
|
|
@ -88,7 +94,7 @@ where
|
|||
while let Ok((stream, _client_addr)) = tcp_listener.accept().await {
|
||||
self
|
||||
.clone()
|
||||
.client_serve(stream, server.clone(), _client_addr)
|
||||
.client_serve(stream, server.clone(), _client_addr, None)
|
||||
.await;
|
||||
}
|
||||
Ok(()) as Result<()>
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ where
|
|||
};
|
||||
// Finally serve the TLS connection
|
||||
if let Ok(stream) = start.into_stream(server_crypto.unwrap().clone()).await {
|
||||
self.clone().client_serve(stream, server.clone(), _client_addr).await
|
||||
self.clone().client_serve(stream, server.clone(), _client_addr, Some(server_name.as_bytes())).await
|
||||
}
|
||||
}
|
||||
_ = server_crypto_rx.changed().fuse() => {
|
||||
|
|
@ -101,11 +101,11 @@ where
|
|||
}
|
||||
|
||||
#[cfg(feature = "h3")]
|
||||
async fn parse_sni_and_get_crypto_h3(
|
||||
async fn parse_sni_and_get_crypto_h3<'a>(
|
||||
&self,
|
||||
peeked_conn: &mut quinn::Connecting,
|
||||
server_crypto_map: &ServerCryptoMap,
|
||||
) -> Option<Arc<ServerConfig>> {
|
||||
server_crypto_map: &'a ServerCryptoMap,
|
||||
) -> Option<(&'a ServerNameLC, &'a Arc<ServerConfig>)> {
|
||||
let hsd = if let Ok(h) = peeked_conn.handshake_data().await {
|
||||
h
|
||||
} else {
|
||||
|
|
@ -121,9 +121,8 @@ where
|
|||
"HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig",
|
||||
server_name
|
||||
);
|
||||
server_crypto_map
|
||||
.get(&server_name.as_bytes().to_vec())
|
||||
.cloned()
|
||||
server_crypto_map.get_key_value(&server_name.into_bytes())
|
||||
// .map_or_else(|| None, |(k, v)| Some((k.clone(), v.clone())));
|
||||
}
|
||||
|
||||
#[cfg(feature = "h3")]
|
||||
|
|
@ -173,19 +172,21 @@ where
|
|||
continue;
|
||||
}
|
||||
let peeked_conn = peeked_conn.unwrap();
|
||||
let is_acceptable =
|
||||
if let Some(new_server_crypto) = self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await {
|
||||
|
||||
let new_server_name = match self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await {
|
||||
Some((new_server_name, new_server_crypto)) => {
|
||||
// Set ServerConfig::set_server_config for given SNI
|
||||
endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto)));
|
||||
true
|
||||
} else {
|
||||
false
|
||||
};
|
||||
endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto.clone())));
|
||||
Some(new_server_name)
|
||||
},
|
||||
None => None
|
||||
};
|
||||
|
||||
// Then acquire actual connection
|
||||
let peekable_incoming = Pin::new(&mut p);
|
||||
if let Some(conn) = peekable_incoming.get_mut().next().await {
|
||||
if is_acceptable {
|
||||
self.clone().client_serve_h3(conn).await;
|
||||
if let Some(new_server_name) = new_server_name {
|
||||
self.clone().client_serve_h3(conn, new_server_name).await;
|
||||
}
|
||||
} else {
|
||||
continue;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue