add checking mechanism of consistency between sni and host/request line
This commit is contained in:
		
					parent
					
						
							
								4f5a1cbf91
							
						
					
				
			
			
				commit
				
					
						d37ed57a1c
					
				
			
		
					 11 changed files with 111 additions and 69 deletions
				
			
		|  | @ -163,6 +163,12 @@ pub fn parse_opts(globals: &mut Globals) -> Result<()> { | |||
|         info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable.") | ||||
|       } | ||||
|     } | ||||
|     if let Some(b) = exp.ignore_sni_consistency { | ||||
|       globals.sni_consistency = !b; | ||||
|       if b { | ||||
|         info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC.") | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   Ok(()) | ||||
|  |  | |||
|  | @ -18,6 +18,7 @@ pub struct ConfigToml { | |||
| #[derive(Deserialize, Debug, Default)] | ||||
| pub struct Experimental { | ||||
|   pub h3: Option<bool>, | ||||
|   pub ignore_sni_consistency: Option<bool>, | ||||
| } | ||||
| 
 | ||||
| #[derive(Deserialize, Debug, Default)] | ||||
|  |  | |||
|  | @ -19,6 +19,7 @@ pub struct Globals { | |||
|   pub max_concurrent_streams: u32, | ||||
|   pub keepalive: bool, | ||||
|   pub http3: bool, | ||||
|   pub sni_consistency: bool, | ||||
| 
 | ||||
|   pub runtime_handle: tokio::runtime::Handle, | ||||
| 
 | ||||
|  |  | |||
|  | @ -60,6 +60,7 @@ fn main() { | |||
|       http_port: None, | ||||
|       https_port: None, | ||||
|       http3: false, | ||||
|       sni_consistency: true, | ||||
| 
 | ||||
|       // TODO: Reconsider each timeout values
 | ||||
|       proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC), | ||||
|  |  | |||
|  | @ -1,13 +1,19 @@ | |||
| // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
 | ||||
| use super::{utils_headers::*, utils_request::*, utils_response::ResLog, utils_synth_response::*}; | ||||
| use crate::{backend::Upstream, constants::*, error::*, globals::Globals, log::*}; | ||||
| use crate::{ | ||||
|   backend::{ServerNameLC, Upstream}, | ||||
|   constants::*, | ||||
|   error::*, | ||||
|   globals::Globals, | ||||
|   log::*, | ||||
| }; | ||||
| use hyper::{ | ||||
|   client::connect::Connect, | ||||
|   header::{self, HeaderValue}, | ||||
|   http::uri::Scheme, | ||||
|   Body, Client, Request, Response, StatusCode, Uri, Version, | ||||
| }; | ||||
| use std::{net::SocketAddr, sync::Arc}; | ||||
| use std::{env, net::SocketAddr, sync::Arc}; | ||||
| use tokio::{ | ||||
|   io::copy_bidirectional, | ||||
|   time::{timeout, Duration}, | ||||
|  | @ -32,14 +38,19 @@ where | |||
|     client_addr: SocketAddr, // アクセス制御用
 | ||||
|     listen_addr: SocketAddr, | ||||
|     tls_enabled: bool, | ||||
|     tls_server_name: Option<ServerNameLC>, | ||||
|   ) -> Result<Response<Body>> { | ||||
|     req.log_debug(&client_addr, Some("(Request from Client)")); | ||||
| 
 | ||||
|     // Here we start to handle with server_name
 | ||||
|     // Find backend application for given server_name, and drop if incoming request is invalid as request.
 | ||||
|     // let (server_name, _port) = parse_host_port(&req)?;
 | ||||
|     let server_name_bytes = req.parse_host()?.to_ascii_lowercase(); | ||||
| 
 | ||||
|     // check consistency of between TLS SNI and HOST/Request URI Line.
 | ||||
|     if self.globals.sni_consistency | ||||
|       && !server_name_bytes.eq_ignore_ascii_case(&tls_server_name.unwrap()) | ||||
|     { | ||||
|       return http_error(StatusCode::MISDIRECTED_REQUEST); | ||||
|     } | ||||
|     // Find backend application for given server_name, and drop if incoming request is invalid as request.
 | ||||
|     let backend = if let Some(be) = self.globals.backends.apps.get(&server_name_bytes) { | ||||
|       be | ||||
|     } else if let Some(default_server_name) = &self.globals.backends.default_server_name { | ||||
|  | @ -91,7 +102,7 @@ where | |||
|       return http_error(StatusCode::SERVICE_UNAVAILABLE); | ||||
|     }; | ||||
|     // debug!("Request to be forwarded: {:?}", req_forwarded);
 | ||||
|     req_forwarded.log(&client_addr, Some("(Request to Backend)")); | ||||
|     req_forwarded.log_debug(&client_addr, Some("(Request to Backend)")); | ||||
| 
 | ||||
|     // Forward request to
 | ||||
|     let mut res_backend = { | ||||
|  | @ -168,7 +179,7 @@ where | |||
|       // Generate response to client
 | ||||
|       if self.generate_response_forwarded(&mut res_backend).is_ok() { | ||||
|         // info!("{} => {}", request_log, response_log);
 | ||||
|         res_backend.log( | ||||
|         res_backend.log_debug( | ||||
|           &backend.server_name, | ||||
|           &client_addr, | ||||
|           Some("(Response to Client)"), | ||||
|  |  | |||
|  | @ -19,27 +19,40 @@ impl<B> ReqLog for &Request<B> { | |||
|   fn build_message<T: Display + ToCanonical>(self, src: &T, extra: Option<&str>) -> String { | ||||
|     let canonical_src = src.to_canonical(); | ||||
| 
 | ||||
|     let server_name = self.headers().get(header::HOST).map_or_else( | ||||
|       || { | ||||
|         self | ||||
|           .uri() | ||||
|           .authority() | ||||
|           .map_or_else(|| "<none>", |au| au.as_str()) | ||||
|       }, | ||||
|       |h| h.to_str().unwrap_or("<none>"), | ||||
|     ); | ||||
|     let host = self | ||||
|       .headers() | ||||
|       .get(header::HOST) | ||||
|       .map_or_else(|| "", |v| v.to_str().unwrap_or("")); | ||||
|     let uri_scheme = self | ||||
|       .uri() | ||||
|       .scheme_str() | ||||
|       .map_or_else(|| "".to_string(), |v| format!("{}://", v)); | ||||
|     let uri_host = self.uri().host().unwrap_or(""); | ||||
|     let uri_pq = self | ||||
|       .uri() | ||||
|       .path_and_query() | ||||
|       .map_or_else(|| "", |v| v.as_str()); | ||||
|     let ua = self | ||||
|       .headers() | ||||
|       .get(header::USER_AGENT) | ||||
|       .map_or_else(|| "", |v| v.to_str().unwrap_or("")); | ||||
|     let xff = self | ||||
|       .headers() | ||||
|       .get("x-forwarded-for") | ||||
|       .map_or_else(|| "", |v| v.to_str().unwrap_or("")); | ||||
| 
 | ||||
|     format!( | ||||
|       "{} <- {} -- {} {:?} {:?} {:?} {}", | ||||
|       server_name, | ||||
|       "{} <- {} -- {} {} {:?} -- ({}{}) \"{}\" \"{}\" {}", | ||||
|       host, | ||||
|       canonical_src, | ||||
|       self.method(), | ||||
|       uri_pq, | ||||
|       self.version(), | ||||
|       self | ||||
|         .uri() | ||||
|         .path_and_query() | ||||
|         .map_or_else(|| "", |v| v.as_str()), | ||||
|       self.headers(), | ||||
|       extra.map_or_else(|| "", |v| v) | ||||
|       uri_scheme, | ||||
|       uri_host, | ||||
|       ua, | ||||
|       xff, | ||||
|       extra.unwrap_or("") | ||||
|     ) | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -49,12 +49,12 @@ impl<B> ResLog for &Response<B> { | |||
|   ) -> String { | ||||
|     let canonical_client_addr = client_addr.to_canonical(); | ||||
|     format!( | ||||
|       "{} <- {} -- {} {:?} {:?} {}", | ||||
|       "{} <- {} -- {} {:?} {}", | ||||
|       canonical_client_addr, | ||||
|       server_name, | ||||
|       self.status(), | ||||
|       self.version(), | ||||
|       self.headers(), | ||||
|       // self.headers(),
 | ||||
|       extra.map_or_else(|| "", |v| v) | ||||
|     ) | ||||
|   } | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| use super::Proxy; | ||||
| use crate::{error::*, log::*}; | ||||
| use crate::{backend::ServerNameLC, error::*, log::*}; | ||||
| use bytes::{Buf, Bytes}; | ||||
| use h3::{quic::BidiStream, server::RequestStream}; | ||||
| use hyper::{client::connect::Connect, Body, HeaderMap, Request, Response}; | ||||
|  | @ -10,13 +10,15 @@ impl<T> Proxy<T> | |||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub async fn client_serve_h3(&self, conn: quinn::Connecting) { | ||||
|   pub async fn client_serve_h3(&self, conn: quinn::Connecting, tls_server_name: &[u8]) { | ||||
|     let clients_count = self.globals.clients_count.clone(); | ||||
|     if clients_count.increment() > self.globals.max_clients { | ||||
|       clients_count.decrement(); | ||||
|       return; | ||||
|     } | ||||
|     let fut = self.clone().handle_connection_h3(conn); | ||||
|     let fut = self | ||||
|       .clone() | ||||
|       .handle_connection_h3(conn, tls_server_name.to_vec()); | ||||
|     self.globals.runtime_handle.spawn(async move { | ||||
|       // Timeout is based on underlying quic
 | ||||
|       if let Err(e) = fut.await { | ||||
|  | @ -27,31 +29,22 @@ where | |||
|     }); | ||||
|   } | ||||
| 
 | ||||
|   pub async fn handle_connection_h3(self, conn: quinn::Connecting) -> Result<()> { | ||||
|   pub async fn handle_connection_h3( | ||||
|     self, | ||||
|     conn: quinn::Connecting, | ||||
|     tls_server_name: ServerNameLC, | ||||
|   ) -> Result<()> { | ||||
|     let client_addr = conn.remote_address(); | ||||
| 
 | ||||
|     match conn.await { | ||||
|       Ok(new_conn) => { | ||||
|         info!("QUIC connection established from {:?} {:?}", client_addr, { | ||||
|           let hsd = new_conn | ||||
|             .connection | ||||
|             .handshake_data() | ||||
|             .ok_or_else(|| anyhow!(""))? | ||||
|             .downcast::<quinn::crypto::rustls::HandshakeData>() | ||||
|             .map_err(|_| anyhow!(""))?; | ||||
|           ( | ||||
|             hsd.protocol.map_or_else( | ||||
|               || "<none>".into(), | ||||
|               |x| String::from_utf8_lossy(&x).into_owned(), | ||||
|             ), | ||||
|             hsd.server_name.map_or_else(|| "<none>".into(), |x| x), | ||||
|           ) | ||||
|         }); | ||||
| 
 | ||||
|         let mut h3_conn = | ||||
|           h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn)) | ||||
|             .await?; | ||||
|         info!("HTTP/3 connection established"); | ||||
|         info!( | ||||
|           "QUIC/HTTP3 connection established from {:?} {:?}", | ||||
|           client_addr, tls_server_name | ||||
|         ); | ||||
| 
 | ||||
|         // Does this work enough?
 | ||||
|         // while let Some((req, stream)) = h3_conn
 | ||||
|  | @ -73,10 +66,11 @@ where | |||
|           ); | ||||
| 
 | ||||
|           let self_inner = self.clone(); | ||||
|           let tls_server_name_inner = tls_server_name.clone(); | ||||
|           self.globals.runtime_handle.spawn(async move { | ||||
|             if let Err(e) = timeout( | ||||
|               self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
 | ||||
|               self_inner.handle_stream_h3(req, stream, client_addr), | ||||
|               self_inner.handle_stream_h3(req, stream, client_addr, tls_server_name_inner), | ||||
|             ) | ||||
|             .await | ||||
|             { | ||||
|  | @ -99,6 +93,7 @@ where | |||
|     req: Request<()>, | ||||
|     mut stream: RequestStream<S, Bytes>, | ||||
|     client_addr: SocketAddr, | ||||
|     tls_server_name: ServerNameLC, | ||||
|   ) -> Result<()> | ||||
|   where | ||||
|     S: BidiStream<Bytes>, | ||||
|  | @ -128,7 +123,13 @@ where | |||
|     let res = self | ||||
|       .msg_handler | ||||
|       .clone() | ||||
|       .handle_request(new_req, client_addr, self.listening_on, self.tls_enabled) | ||||
|       .handle_request( | ||||
|         new_req, | ||||
|         client_addr, | ||||
|         self.listening_on, | ||||
|         self.tls_enabled, | ||||
|         Some(tls_server_name), | ||||
|       ) | ||||
|       .await?; | ||||
| 
 | ||||
|     let (new_res_parts, new_body) = res.into_parts(); | ||||
|  |  | |||
|  | @ -45,8 +45,13 @@ impl<T> Proxy<T> | |||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr) | ||||
|   where | ||||
|   pub async fn client_serve<I>( | ||||
|     self, | ||||
|     stream: I, | ||||
|     server: Http<LocalExecutor>, | ||||
|     peer_addr: SocketAddr, | ||||
|     tls_server_name: Option<&[u8]>, | ||||
|   ) where | ||||
|     I: AsyncRead + AsyncWrite + Send + Unpin + 'static, | ||||
|   { | ||||
|     let clients_count = self.globals.clients_count.clone(); | ||||
|  | @ -55,7 +60,7 @@ where | |||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     // let handler_inner = self.msg_handler.clone();
 | ||||
|     let inner = tls_server_name.map_or_else(|| None, |v| Some(v.to_vec())); | ||||
|     self.globals.runtime_handle.clone().spawn(async move { | ||||
|       timeout( | ||||
|         self.globals.proxy_timeout + Duration::from_secs(1), | ||||
|  | @ -68,6 +73,7 @@ where | |||
|                 peer_addr, | ||||
|                 self.listening_on, | ||||
|                 self.tls_enabled, | ||||
|                 inner.clone(), | ||||
|               ) | ||||
|             }), | ||||
|           ) | ||||
|  | @ -88,7 +94,7 @@ where | |||
|       while let Ok((stream, _client_addr)) = tcp_listener.accept().await { | ||||
|         self | ||||
|           .clone() | ||||
|           .client_serve(stream, server.clone(), _client_addr) | ||||
|           .client_serve(stream, server.clone(), _client_addr, None) | ||||
|           .await; | ||||
|       } | ||||
|       Ok(()) as Result<()> | ||||
|  |  | |||
|  | @ -85,7 +85,7 @@ where | |||
|           }; | ||||
|           // Finally serve the TLS connection
 | ||||
|           if let Ok(stream) = start.into_stream(server_crypto.unwrap().clone()).await { | ||||
|             self.clone().client_serve(stream, server.clone(), _client_addr).await | ||||
|             self.clone().client_serve(stream, server.clone(), _client_addr, Some(server_name.as_bytes())).await | ||||
|           } | ||||
|         } | ||||
|         _ = server_crypto_rx.changed().fuse() => { | ||||
|  | @ -101,11 +101,11 @@ where | |||
|   } | ||||
| 
 | ||||
|   #[cfg(feature = "h3")] | ||||
|   async fn parse_sni_and_get_crypto_h3( | ||||
|   async fn parse_sni_and_get_crypto_h3<'a>( | ||||
|     &self, | ||||
|     peeked_conn: &mut quinn::Connecting, | ||||
|     server_crypto_map: &ServerCryptoMap, | ||||
|   ) -> Option<Arc<ServerConfig>> { | ||||
|     server_crypto_map: &'a ServerCryptoMap, | ||||
|   ) -> Option<(&'a ServerNameLC, &'a Arc<ServerConfig>)> { | ||||
|     let hsd = if let Ok(h) = peeked_conn.handshake_data().await { | ||||
|       h | ||||
|     } else { | ||||
|  | @ -121,9 +121,8 @@ where | |||
|       "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", | ||||
|       server_name | ||||
|     ); | ||||
|     server_crypto_map | ||||
|       .get(&server_name.as_bytes().to_vec()) | ||||
|       .cloned() | ||||
|     server_crypto_map.get_key_value(&server_name.into_bytes()) | ||||
|     // .map_or_else(|| None, |(k, v)| Some((k.clone(), v.clone())));
 | ||||
|   } | ||||
| 
 | ||||
|   #[cfg(feature = "h3")] | ||||
|  | @ -173,19 +172,21 @@ where | |||
|             continue; | ||||
|           } | ||||
|           let peeked_conn = peeked_conn.unwrap(); | ||||
|           let is_acceptable = | ||||
|             if let Some(new_server_crypto) = self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await { | ||||
| 
 | ||||
|           let new_server_name = match self.parse_sni_and_get_crypto_h3(peeked_conn, server_crypto_map.as_ref().unwrap()).await { | ||||
|             Some((new_server_name, new_server_crypto)) => { | ||||
|               // Set ServerConfig::set_server_config for given SNI
 | ||||
|               endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto))); | ||||
|               true | ||||
|             } else { | ||||
|               false | ||||
|             }; | ||||
|               endpoint.set_server_config(Some(quinn::ServerConfig::with_crypto(new_server_crypto.clone()))); | ||||
|               Some(new_server_name) | ||||
|             }, | ||||
|             None => None | ||||
|           }; | ||||
| 
 | ||||
|           // Then acquire actual connection
 | ||||
|           let peekable_incoming = Pin::new(&mut p); | ||||
|           if let Some(conn) = peekable_incoming.get_mut().next().await { | ||||
|             if is_acceptable { | ||||
|               self.clone().client_serve_h3(conn).await; | ||||
|             if let Some(new_server_name) = new_server_name { | ||||
|               self.clone().client_serve_h3(conn, new_server_name).await; | ||||
|             } | ||||
|           } else { | ||||
|             continue; | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jun Kurihara
				Jun Kurihara