add checking mechanism of consistency between sni and host/request line

This commit is contained in:
Jun Kurihara 2022-07-15 23:55:21 +09:00
commit d37ed57a1c
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
11 changed files with 111 additions and 69 deletions

View file

@ -1,5 +1,5 @@
use super::Proxy;
use crate::{error::*, log::*};
use crate::{backend::ServerNameLC, error::*, log::*};
use bytes::{Buf, Bytes};
use h3::{quic::BidiStream, server::RequestStream};
use hyper::{client::connect::Connect, Body, HeaderMap, Request, Response};
@ -10,13 +10,15 @@ impl<T> Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,
{
pub async fn client_serve_h3(&self, conn: quinn::Connecting) {
pub async fn client_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) {
let clients_count = self.globals.clients_count.clone();
if clients_count.increment() > self.globals.max_clients {
clients_count.decrement();
return;
}
let fut = self.clone().handle_connection_h3(conn);
let fut = self
.clone()
.handle_connection_h3(conn, tls_server_name.to_vec());
self.globals.runtime_handle.spawn(async move {
// Timeout is based on underlying quic
if let Err(e) = fut.await {
@ -27,31 +29,22 @@ where
});
}
pub async fn handle_connection_h3(self, conn: quinn::Connecting) -> Result<()> {
pub async fn handle_connection_h3(
self,
conn: quinn::Connecting,
tls_server_name: ServerNameLC,
) -> Result<()> {
let client_addr = conn.remote_address();
match conn.await {
Ok(new_conn) => {
info!("QUIC connection established from {:?} {:?}", client_addr, {
let hsd = new_conn
.connection
.handshake_data()
.ok_or_else(|| anyhow!(""))?
.downcast::<quinn::crypto::rustls::HandshakeData>()
.map_err(|_| anyhow!(""))?;
(
hsd.protocol.map_or_else(
|| "<none>".into(),
|x| String::from_utf8_lossy(&x).into_owned(),
),
hsd.server_name.map_or_else(|| "<none>".into(), |x| x),
)
});
let mut h3_conn =
h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn))
.await?;
info!("HTTP/3 connection established");
info!(
"QUIC/HTTP3 connection established from {:?} {:?}",
client_addr, tls_server_name
);
// Does this work enough?
// while let Some((req, stream)) = h3_conn
@ -73,10 +66,11 @@ where
);
let self_inner = self.clone();
let tls_server_name_inner = tls_server_name.clone();
self.globals.runtime_handle.spawn(async move {
if let Err(e) = timeout(
self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
self_inner.handle_stream_h3(req, stream, client_addr),
self_inner.handle_stream_h3(req, stream, client_addr, tls_server_name_inner),
)
.await
{
@ -99,6 +93,7 @@ where
req: Request<()>,
mut stream: RequestStream<S, Bytes>,
client_addr: SocketAddr,
tls_server_name: ServerNameLC,
) -> Result<()>
where
S: BidiStream<Bytes>,
@ -128,7 +123,13 @@ where
let res = self
.msg_handler
.clone()
.handle_request(new_req, client_addr, self.listening_on, self.tls_enabled)
.handle_request(
new_req,
client_addr,
self.listening_on,
self.tls_enabled,
Some(tls_server_name),
)
.await?;
let (new_res_parts, new_body) = res.into_parts();

View file

@ -45,8 +45,13 @@ impl<T> Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,
{
pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr)
where
pub async fn client_serve<I>(
self,
stream: I,
server: Http<LocalExecutor>,
peer_addr: SocketAddr,
tls_server_name: Option<&[u8]>,
) where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let clients_count = self.globals.clients_count.clone();
@ -55,7 +60,7 @@ where
return;
}
// let handler_inner = self.msg_handler.clone();
let inner = tls_server_name.map_or_else(|| None, |v| Some(v.to_vec()));
self.globals.runtime_handle.clone().spawn(async move {
timeout(
self.globals.proxy_timeout + Duration::from_secs(1),
@ -68,6 +73,7 @@ where
peer_addr,
self.listening_on,
self.tls_enabled,
inner.clone(),
)
}),
)
@ -88,7 +94,7 @@ where
while let Ok((stream, _client_addr)) = tcp_listener.accept().await {
self
.clone()
.client_serve(stream, server.clone(), _client_addr)
.client_serve(stream, server.clone(), _client_addr, None)
.await;
}
Ok(()) as Result<()>

View file

@ -85,7 +85,7 @@ where
};
// Finally serve the TLS connection
if let Ok(stream) = start.into_stream(server_crypto.unwrap().clone()).await {
self.clone().client_serve(stream, server.clone(), _client_addr).await
self.clone().client_serve(stream, server.clone(), _client_addr, Some(server_name.as_bytes())).await
}
}
_ = server_crypto_rx.changed().fuse() => {
@ -101,11 +101,11 @@ where
}
#[cfg(feature = "h3")]
async fn parse_sni_and_get_crypto_h3(
async fn parse_sni_and_get_crypto_h3<'a>(
&self,
peeked_conn: &mut quinn::Connecting,
server_crypto_map: &ServerCryptoMap,
) -> Option<Arc<ServerConfig>> {
server_crypto_map: &'a ServerCryptoMap,
) -> Option<(&'a ServerNameLC, &'a Arc<ServerConfig>)> {
let hsd = if let Ok(h) = peeked_conn.handshake_data().await {
h
} else {
@ -121,9 +121,8 @@ where
"HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig",
server_name
);
server_crypto_map
.get(&server_name.as_bytes().to_vec())
.cloned()
server_crypto_map.get_key_value(&server_name.into_bytes())
// .map_or_else(|| None, |(k, v)| Some((k.clone(), v.clone())));
}
#[cfg(feature = "h3")]
@ -173,19 +172,21 @@ where
continue;
}
let peeked_conn = peeked_conn.unwrap();
let is_acceptable =
if let Some(new_server_crypto) = self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await {
let new_server_name = match self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await {
Some((new_server_name, new_server_crypto)) => {
// Set ServerConfig::set_server_config for given SNI
endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto)));
true
} else {
false
};
endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto.clone())));
Some(new_server_name)
},
None => None
};
// Then acquire actual connection
let peekable_incoming = Pin::new(&mut p);
if let Some(conn) = peekable_incoming.get_mut().next().await {
if is_acceptable {
self.clone().client_serve_h3(conn).await;
if let Some(new_server_name) = new_server_name {
self.clone().client_serve_h3(conn, new_server_name).await;
}
} else {
continue;