refactor: split quic layer and h3 layer
This commit is contained in:
		
					parent
					
						
							
								bc919b9478
							
						
					
				
			
			
				commit
				
					
						fb389a6aab
					
				
			
		
					 9 changed files with 188 additions and 159 deletions
				
			
		|  | @ -44,6 +44,7 @@ where | |||
|   if proxy_config.https_port.is_some() { | ||||
|     info!("Listen port: {} (for TLS)", proxy_config.https_port.unwrap()); | ||||
|   } | ||||
|   #[cfg(feature = "http3")] | ||||
|   if proxy_config.http3 { | ||||
|     info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable."); | ||||
|   } | ||||
|  |  | |||
|  | @ -3,6 +3,8 @@ mod proxy_client_cert; | |||
| #[cfg(feature = "http3")] | ||||
| mod proxy_h3; | ||||
| mod proxy_main; | ||||
| #[cfg(feature = "http3")] | ||||
| mod proxy_quic; | ||||
| mod proxy_tls; | ||||
| mod socket; | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,7 +1,7 @@ | |||
| use super::Proxy; | ||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; | ||||
| use bytes::{Buf, Bytes}; | ||||
| use h3::{quic::BidiStream, server::RequestStream}; | ||||
| use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; | ||||
| use hyper::{client::connect::Connect, Body, Request, Response}; | ||||
| use std::net::SocketAddr; | ||||
| use tokio::time::{timeout, Duration}; | ||||
|  | @ -11,67 +11,64 @@ where | |||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub(super) async fn connection_serve_h3( | ||||
|   pub(super) async fn connection_serve_h3<C>( | ||||
|     self, | ||||
|     conn: quinn::Connecting, | ||||
|     quic_connection: C, | ||||
|     tls_server_name: ServerNameBytesExp, | ||||
|   ) -> Result<()> { | ||||
|     let client_addr = conn.remote_address(); | ||||
| 
 | ||||
|     match conn.await { | ||||
|       Ok(new_conn) => { | ||||
|         let mut h3_conn = h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn)).await?; | ||||
|         info!( | ||||
|           "QUIC/HTTP3 connection established from {:?} {:?}", | ||||
|           client_addr, tls_server_name | ||||
|         ); | ||||
|         // TODO: Is here enough to fetch server_name from NewConnection?
 | ||||
|         // to avoid deep nested call from listener_service_h3
 | ||||
|         loop { | ||||
|           // this routine follows hyperium/h3 examples https://github.com/hyperium/h3/blob/master/examples/server.rs
 | ||||
|           match h3_conn.accept().await { | ||||
|             Ok(None) => { | ||||
|               break; | ||||
|             } | ||||
|             Err(e) => { | ||||
|               warn!("HTTP/3 error on accept incoming connection: {}", e); | ||||
|               match e.get_error_level() { | ||||
|                 h3::error::ErrorLevel::ConnectionError => break, | ||||
|                 h3::error::ErrorLevel::StreamError => continue, | ||||
|               } | ||||
|             } | ||||
|             Ok(Some((req, stream))) => { | ||||
|               // We consider the connection count separately from the stream count.
 | ||||
|               // Max clients for h1/h2 = max 'stream' for h3.
 | ||||
|               let request_count = self.globals.request_count.clone(); | ||||
|               if request_count.increment() > self.globals.proxy_config.max_clients { | ||||
|                 request_count.decrement(); | ||||
|                 h3_conn.shutdown(0).await?; | ||||
|                 break; | ||||
|               } | ||||
|               debug!("Request incoming: current # {}", request_count.current()); | ||||
| 
 | ||||
|               let self_inner = self.clone(); | ||||
|               let tls_server_name_inner = tls_server_name.clone(); | ||||
|               self.globals.runtime_handle.spawn(async move { | ||||
|                 if let Err(e) = timeout( | ||||
|                   self_inner.globals.proxy_config.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
 | ||||
|                   self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), | ||||
|                 ) | ||||
|                 .await | ||||
|                 { | ||||
|                   error!("HTTP/3 failed to process stream: {}", e); | ||||
|                 } | ||||
|                 request_count.decrement(); | ||||
|                 debug!("Request processed: current # {}", request_count.current()); | ||||
|               }); | ||||
|             } | ||||
|     client_addr: SocketAddr, | ||||
|   ) -> Result<()> | ||||
|   where | ||||
|     C: ConnectionQuic<Bytes>, | ||||
|     <C as ConnectionQuic<Bytes>>::BidiStream: BidiStream<Bytes> + Send + 'static, | ||||
|     <<C as ConnectionQuic<Bytes>>::BidiStream as BidiStream<Bytes>>::RecvStream: Send, | ||||
|     <<C as ConnectionQuic<Bytes>>::BidiStream as BidiStream<Bytes>>::SendStream: Send, | ||||
|   { | ||||
|     let mut h3_conn = h3::server::Connection::<_, Bytes>::new(quic_connection).await?; | ||||
|     info!( | ||||
|       "QUIC/HTTP3 connection established from {:?} {:?}", | ||||
|       client_addr, tls_server_name | ||||
|     ); | ||||
|     // TODO: Is here enough to fetch server_name from NewConnection?
 | ||||
|     // to avoid deep nested call from listener_service_h3
 | ||||
|     loop { | ||||
|       // this routine follows hyperium/h3 examples https://github.com/hyperium/h3/blob/master/examples/server.rs
 | ||||
|       match h3_conn.accept().await { | ||||
|         Ok(None) => { | ||||
|           break; | ||||
|         } | ||||
|         Err(e) => { | ||||
|           warn!("HTTP/3 error on accept incoming connection: {}", e); | ||||
|           match e.get_error_level() { | ||||
|             h3::error::ErrorLevel::ConnectionError => break, | ||||
|             h3::error::ErrorLevel::StreamError => continue, | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|       Err(err) => { | ||||
|         warn!("QUIC accepting connection failed: {:?}", err); | ||||
|         return Err(RpxyError::QuicConn(err)); | ||||
|         Ok(Some((req, stream))) => { | ||||
|           // We consider the connection count separately from the stream count.
 | ||||
|           // Max clients for h1/h2 = max 'stream' for h3.
 | ||||
|           let request_count = self.globals.request_count.clone(); | ||||
|           if request_count.increment() > self.globals.proxy_config.max_clients { | ||||
|             request_count.decrement(); | ||||
|             h3_conn.shutdown(0).await?; | ||||
|             break; | ||||
|           } | ||||
|           debug!("Request incoming: current # {}", request_count.current()); | ||||
| 
 | ||||
|           let self_inner = self.clone(); | ||||
|           let tls_server_name_inner = tls_server_name.clone(); | ||||
|           self.globals.runtime_handle.spawn(async move { | ||||
|             if let Err(e) = timeout( | ||||
|               self_inner.globals.proxy_config.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
 | ||||
|               self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), | ||||
|             ) | ||||
|             .await | ||||
|             { | ||||
|               error!("HTTP/3 failed to process stream: {}", e); | ||||
|             } | ||||
|             request_count.decrement(); | ||||
|             debug!("Request processed: current # {}", request_count.current()); | ||||
|           }); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										124
									
								
								rpxy-lib/src/proxy/proxy_quic.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								rpxy-lib/src/proxy/proxy_quic.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,124 @@ | |||
| use super::socket::bind_udp_socket; | ||||
| use super::{ | ||||
|   crypto_service::{ServerCrypto, ServerCryptoBase}, | ||||
|   proxy_main::Proxy, | ||||
| }; | ||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; | ||||
| use hot_reload::ReloaderReceiver; | ||||
| use hyper::client::connect::Connect; | ||||
| use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; | ||||
| use rustls::ServerConfig; | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| impl<T, U> Proxy<T, U> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub(super) async fn listener_service_h3( | ||||
|     &self, | ||||
|     mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, | ||||
|   ) -> Result<()> { | ||||
|     info!("Start UDP proxy serving with HTTP/3 request for configured host names"); | ||||
|     // first set as null config server
 | ||||
|     let rustls_server_config = ServerConfig::builder() | ||||
|       .with_safe_default_cipher_suites() | ||||
|       .with_safe_default_kx_groups() | ||||
|       .with_protocol_versions(&[&rustls::version::TLS13])? | ||||
|       .with_no_client_auth() | ||||
|       .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new())); | ||||
| 
 | ||||
|     let mut transport_config_quic = TransportConfig::default(); | ||||
|     transport_config_quic | ||||
|       .max_concurrent_bidi_streams(self.globals.proxy_config.h3_max_concurrent_bidistream) | ||||
|       .max_concurrent_uni_streams(self.globals.proxy_config.h3_max_concurrent_unistream) | ||||
|       .max_idle_timeout( | ||||
|         self | ||||
|           .globals | ||||
|           .proxy_config | ||||
|           .h3_max_idle_timeout | ||||
|           .map(|v| quinn::IdleTimeout::try_from(v).unwrap()), | ||||
|       ); | ||||
| 
 | ||||
|     let mut server_config_h3 = QuicServerConfig::with_crypto(Arc::new(rustls_server_config)); | ||||
|     server_config_h3.transport = Arc::new(transport_config_quic); | ||||
|     server_config_h3.concurrent_connections(self.globals.proxy_config.h3_max_concurrent_connections); | ||||
| 
 | ||||
|     // To reuse address
 | ||||
|     let udp_socket = bind_udp_socket(&self.listening_on)?; | ||||
|     let runtime = quinn::default_runtime() | ||||
|       .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "No async runtime found"))?; | ||||
|     let endpoint = Endpoint::new( | ||||
|       quinn::EndpointConfig::default(), | ||||
|       Some(server_config_h3), | ||||
|       udp_socket, | ||||
|       runtime, | ||||
|     )?; | ||||
| 
 | ||||
|     let mut server_crypto: Option<Arc<ServerCrypto>> = None; | ||||
|     loop { | ||||
|       tokio::select! { | ||||
|         new_conn = endpoint.accept() => { | ||||
|           if server_crypto.is_none() || new_conn.is_none() { | ||||
|             continue; | ||||
|           } | ||||
|           let mut conn: quinn::Connecting = new_conn.unwrap(); | ||||
|           let Ok(hsd) = conn.handshake_data().await else { | ||||
|             continue
 | ||||
|           }; | ||||
| 
 | ||||
|           let Ok(hsd_downcast) = hsd.downcast::<HandshakeData>() else { | ||||
|             continue
 | ||||
|           }; | ||||
|           let Some(new_server_name) = hsd_downcast.server_name else { | ||||
|             warn!("HTTP/3 no SNI is given"); | ||||
|             continue; | ||||
|           }; | ||||
|           debug!( | ||||
|             "HTTP/3 connection incoming (SNI {:?})", | ||||
|             new_server_name | ||||
|           ); | ||||
|           // TODO: server_nameをここで出してどんどん深く投げていくのは効率が悪い。connecting -> connectionsの後でいいのでは?
 | ||||
|           // TODO: 通常のTLSと同じenumか何かにまとめたい
 | ||||
|           let self_clone = self.clone(); | ||||
|           self.globals.runtime_handle.spawn(async move { | ||||
|             let client_addr = conn.remote_address(); | ||||
|             let quic_connection = match conn.await { | ||||
|               Ok(new_conn) => { | ||||
|                 info!("New connection established"); | ||||
|                 h3_quinn::Connection::new(new_conn) | ||||
|               }, | ||||
|               Err(e) => { | ||||
|                 warn!("QUIC accepting connection failed: {:?}", e); | ||||
|                 return Err(RpxyError::QuicConn(e)); | ||||
|               } | ||||
|             }; | ||||
|             // Timeout is based on underlying quic
 | ||||
|             if let Err(e) = self_clone.connection_serve_h3(quic_connection, new_server_name.to_server_name_vec(), client_addr).await { | ||||
|               warn!("QUIC or HTTP/3 connection failed: {}", e); | ||||
|             }; | ||||
|             Ok(()) | ||||
|           }); | ||||
|         } | ||||
|         _ = server_crypto_rx.changed() => { | ||||
|           if server_crypto_rx.borrow().is_none() { | ||||
|             error!("Reloader is broken"); | ||||
|             break; | ||||
|           } | ||||
|           let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); | ||||
| 
 | ||||
|           server_crypto = (&cert_keys_map).try_into().ok(); | ||||
|           let Some(inner) = server_crypto.clone() else { | ||||
|             error!("Failed to update server crypto for h3"); | ||||
|             break; | ||||
|           }; | ||||
|           endpoint.set_server_config(Some(QuicServerConfig::with_crypto(inner.clone().inner_global_no_client_auth.clone()))); | ||||
| 
 | ||||
|         } | ||||
|         else => break
 | ||||
|       } | ||||
|     } | ||||
|     endpoint.wait_idle().await; | ||||
|     Ok(()) as Result<()> | ||||
|   } | ||||
| } | ||||
|  | @ -1,5 +1,3 @@ | |||
| #[cfg(feature = "http3")] | ||||
| use super::socket::bind_udp_socket; | ||||
| use super::{ | ||||
|   crypto_service::{CryptoReloader, ServerCrypto, ServerCryptoBase, SniServerCryptoMap}, | ||||
|   proxy_main::{LocalExecutor, Proxy}, | ||||
|  | @ -8,10 +6,6 @@ use super::{ | |||
| use crate::{certs::CryptoSource, constants::*, error::*, log::*, utils::BytesName}; | ||||
| use hot_reload::{ReloaderReceiver, ReloaderService}; | ||||
| use hyper::{client::connect::Connect, server::conn::Http}; | ||||
| #[cfg(feature = "http3")] | ||||
| use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; | ||||
| #[cfg(feature = "http3")] | ||||
| use rustls::ServerConfig; | ||||
| use std::sync::Arc; | ||||
| use tokio::time::{timeout, Duration}; | ||||
| 
 | ||||
|  | @ -105,99 +99,6 @@ where | |||
|     Ok(()) as Result<()> | ||||
|   } | ||||
| 
 | ||||
|   #[cfg(feature = "http3")] | ||||
|   async fn listener_service_h3(&self, mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>) -> Result<()> { | ||||
|     info!("Start UDP proxy serving with HTTP/3 request for configured host names"); | ||||
|     // first set as null config server
 | ||||
|     let rustls_server_config = ServerConfig::builder() | ||||
|       .with_safe_default_cipher_suites() | ||||
|       .with_safe_default_kx_groups() | ||||
|       .with_protocol_versions(&[&rustls::version::TLS13])? | ||||
|       .with_no_client_auth() | ||||
|       .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new())); | ||||
| 
 | ||||
|     let mut transport_config_quic = TransportConfig::default(); | ||||
|     transport_config_quic | ||||
|       .max_concurrent_bidi_streams(self.globals.proxy_config.h3_max_concurrent_bidistream) | ||||
|       .max_concurrent_uni_streams(self.globals.proxy_config.h3_max_concurrent_unistream) | ||||
|       .max_idle_timeout( | ||||
|         self | ||||
|           .globals | ||||
|           .proxy_config | ||||
|           .h3_max_idle_timeout | ||||
|           .map(|v| quinn::IdleTimeout::try_from(v).unwrap()), | ||||
|       ); | ||||
| 
 | ||||
|     let mut server_config_h3 = QuicServerConfig::with_crypto(Arc::new(rustls_server_config)); | ||||
|     server_config_h3.transport = Arc::new(transport_config_quic); | ||||
|     server_config_h3.concurrent_connections(self.globals.proxy_config.h3_max_concurrent_connections); | ||||
| 
 | ||||
|     // To reuse address
 | ||||
|     let udp_socket = bind_udp_socket(&self.listening_on)?; | ||||
|     let runtime = quinn::default_runtime() | ||||
|       .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "No async runtime found"))?; | ||||
|     let endpoint = Endpoint::new( | ||||
|       quinn::EndpointConfig::default(), | ||||
|       Some(server_config_h3), | ||||
|       udp_socket, | ||||
|       runtime, | ||||
|     )?; | ||||
| 
 | ||||
|     let mut server_crypto: Option<Arc<ServerCrypto>> = None; | ||||
|     loop { | ||||
|       tokio::select! { | ||||
|         new_conn = endpoint.accept() => { | ||||
|           if server_crypto.is_none() || new_conn.is_none() { | ||||
|             continue; | ||||
|           } | ||||
|           let mut conn: quinn::Connecting = new_conn.unwrap(); | ||||
|           let Ok(hsd) = conn.handshake_data().await else { | ||||
|             continue
 | ||||
|           }; | ||||
| 
 | ||||
|           let Ok(hsd_downcast) = hsd.downcast::<HandshakeData>() else { | ||||
|             continue
 | ||||
|           }; | ||||
|           let Some(new_server_name) = hsd_downcast.server_name else { | ||||
|             warn!("HTTP/3 no SNI is given"); | ||||
|             continue; | ||||
|           }; | ||||
|           debug!( | ||||
|             "HTTP/3 connection incoming (SNI {:?})", | ||||
|             new_server_name | ||||
|           ); | ||||
|           // TODO: server_nameをここで出してどんどん深く投げていくのは効率が悪い。connecting -> connectionsの後でいいのでは?
 | ||||
|           // TODO: 通常のTLSと同じenumか何かにまとめたい
 | ||||
|           let fut = self.clone().connection_serve_h3(conn, new_server_name.to_server_name_vec()); | ||||
|           self.globals.runtime_handle.spawn(async move { | ||||
|             // Timeout is based on underlying quic
 | ||||
|             if let Err(e) = fut.await { | ||||
|               warn!("QUIC or HTTP/3 connection failed: {}", e) | ||||
|             } | ||||
|           }); | ||||
|         } | ||||
|         _ = server_crypto_rx.changed() => { | ||||
|           if server_crypto_rx.borrow().is_none() { | ||||
|             error!("Reloader is broken"); | ||||
|             break; | ||||
|           } | ||||
|           let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); | ||||
| 
 | ||||
|           server_crypto = (&cert_keys_map).try_into().ok(); | ||||
|           let Some(inner) = server_crypto.clone() else { | ||||
|             error!("Failed to update server crypto for h3"); | ||||
|             break; | ||||
|           }; | ||||
|           endpoint.set_server_config(Some(QuicServerConfig::with_crypto(inner.clone().inner_global_no_client_auth.clone()))); | ||||
| 
 | ||||
|         } | ||||
|         else => break
 | ||||
|       } | ||||
|     } | ||||
|     endpoint.wait_idle().await; | ||||
|     Ok(()) as Result<()> | ||||
|   } | ||||
| 
 | ||||
|   pub async fn start_with_tls(self, server: Http<LocalExecutor>) -> Result<()> { | ||||
|     let (cert_reloader_service, cert_reloader_rx) = ReloaderService::<CryptoReloader<U>, ServerCryptoBase>::new( | ||||
|       &self.globals.clone(), | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jun Kurihara
				Jun Kurihara