commit
				
					
						0dfe2a2098
					
				
			
		
					 14 changed files with 324 additions and 51 deletions
				
			
		|  | @ -12,7 +12,8 @@ publish = false | ||||||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||||||
| 
 | 
 | ||||||
| [features] | [features] | ||||||
| default = [] | default = ["h3"] | ||||||
|  | h3 = ["quinn"] | ||||||
| 
 | 
 | ||||||
| [dependencies] | [dependencies] | ||||||
| anyhow = "1.0.58" | anyhow = "1.0.58" | ||||||
|  | @ -26,7 +27,7 @@ hyper = { version = "0.14.19", default-features = false, features = [ | ||||||
|   "stream", |   "stream", | ||||||
| ] } | ] } | ||||||
| log = "0.4.17" | log = "0.4.17" | ||||||
| tokio = { version = "1.19.2", features = [ | tokio = { version = "1.19.2", default-features = false, features = [ | ||||||
|   "net", |   "net", | ||||||
|   "rt-multi-thread", |   "rt-multi-thread", | ||||||
|   "parking_lot", |   "parking_lot", | ||||||
|  | @ -48,6 +49,10 @@ hyper-rustls = { version = "0.23.0", default-features = false, features = [ | ||||||
|   "http2", |   "http2", | ||||||
| ] } | ] } | ||||||
| parking_lot = "0.12.1" | parking_lot = "0.12.1" | ||||||
|  | quinn = { version = "0.8.3", optional = true } | ||||||
|  | h3 = { git = "https://github.com/hyperium/h3.git" } | ||||||
|  | h3-quinn = { git = "https://github.com/hyperium/h3.git" } | ||||||
|  | bytes = "1.1.0" | ||||||
| 
 | 
 | ||||||
| [target.'cfg(not(target_env = "msvc"))'.dependencies] | [target.'cfg(not(target_env = "msvc"))'.dependencies] | ||||||
| tikv-jemallocator = "0.5.0" | tikv-jemallocator = "0.5.0" | ||||||
|  |  | ||||||
|  | @ -2,12 +2,12 @@ | ||||||
| 
 | 
 | ||||||
| echo "----------------------------" | echo "----------------------------" | ||||||
| echo "Benchmark on rpxy" | echo "Benchmark on rpxy" | ||||||
| ab -c 100 -n 10000 http://127.0.0.1:8080/ # TODO: localhost = 127.0.0.1を解決できるように決めておかんとだめそう | ab -c 100 -n 10000 http://127.0.0.1:8080/index.html # TODO: localhost = 127.0.0.1を解決できるように決めておかんとだめそう | ||||||
| 
 | 
 | ||||||
| echo "----------------------------" | echo "----------------------------" | ||||||
| echo "Benchmark on nginx" | echo "Benchmark on nginx" | ||||||
| ab -c 100 -n 10000 http://127.0.0.1:8090/ | ab -c 100 -n 10000  http://127.0.0.1:8090/index.html | ||||||
| 
 | 
 | ||||||
| echo "----------------------------" | echo "----------------------------" | ||||||
| echo "Benchmark on caddy" | echo "Benchmark on caddy" | ||||||
| ab -c 100 -n 10000 http://127.0.0.1:8100/ | ab -c 100 -n 10000  http://127.0.0.1:8100/index.html | ||||||
|  |  | ||||||
|  | @ -1,6 +1,7 @@ | ||||||
| listen_port = 8080 | listen_port = 8080 | ||||||
| # listen_port_tls = 8443 | # listen_port_tls = 8443 | ||||||
| listen_ipv6 = false | listen_ipv6 = false | ||||||
|  | listen_only_ipv6 = false | ||||||
| 
 | 
 | ||||||
| max_concurrent_streams = 128 | max_concurrent_streams = 128 | ||||||
| max_clients = 512 | max_clients = 512 | ||||||
|  | @ -17,3 +18,7 @@ reverse_proxy = [ | ||||||
|   { upstream = [{ location = 'backend-nginx', tls = false }] }, |   { upstream = [{ location = 'backend-nginx', tls = false }] }, | ||||||
|   # { upstream = [{ location = '192.168.100.100', tls = false }] }, |   # { upstream = [{ location = '192.168.100.100', tls = false }] }, | ||||||
| ] | ] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | [experimental] | ||||||
|  | h3 = true | ||||||
|  |  | ||||||
|  | @ -16,6 +16,7 @@ max_clients = 512 | ||||||
| 
 | 
 | ||||||
| # Optional: Listen [::] | # Optional: Listen [::] | ||||||
| listen_ipv6 = false | listen_ipv6 = false | ||||||
|  | listen_only_ipv6 = false | ||||||
| 
 | 
 | ||||||
| # Optional: App that serves all plaintext http request by referring to HOSTS or request header | # Optional: App that serves all plaintext http request by referring to HOSTS or request header | ||||||
| # execpt for configured application. | # execpt for configured application. | ||||||
|  | @ -54,3 +55,9 @@ tls = { https_redirection = true, tls_cert_path = 'localhost.pem', tls_cert_key_ | ||||||
| [apps.another_localhost] | [apps.another_localhost] | ||||||
| server_name = 'localhost.localdomain' | server_name = 'localhost.localdomain' | ||||||
| reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }] | reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }] | ||||||
|  | 
 | ||||||
|  | ################################### | ||||||
|  | #      Experimantal settings      # | ||||||
|  | ################################### | ||||||
|  | [experimental] | ||||||
|  | h3 = true | ||||||
|  |  | ||||||
|  | @ -180,7 +180,20 @@ impl Backend { | ||||||
|           "Unable to find a valid certificate and key", |           "Unable to find a valid certificate and key", | ||||||
|         ) |         ) | ||||||
|       })?; |       })?; | ||||||
|  | 
 | ||||||
|  |     #[cfg(feature = "h3")] | ||||||
|  |     { | ||||||
|  |       server_config.alpn_protocols = vec![ | ||||||
|  |         b"h3".to_vec(), | ||||||
|  |         b"hq-29".to_vec(), // quinn draft example TODO: remove later
 | ||||||
|  |         b"h2".to_vec(), | ||||||
|  |         b"http/1.1".to_vec(), | ||||||
|  |       ]; | ||||||
|  |     } | ||||||
|  |     #[cfg(not(feature = "h3"))] | ||||||
|  |     { | ||||||
|       server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; |       server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; | ||||||
|  |     } | ||||||
| 
 | 
 | ||||||
|     let mut config_store = self.server_config.lock(); |     let mut config_store = self.server_config.lock(); | ||||||
|     *config_store = Some(server_config); |     *config_store = Some(server_config); | ||||||
|  |  | ||||||
|  | @ -39,11 +39,18 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> | ||||||
|     }, |     }, | ||||||
|     anyhow!("Wrong port spec.") |     anyhow!("Wrong port spec.") | ||||||
|   ); |   ); | ||||||
|   let mut listen_addresses: Vec<&str> = LISTEN_ADDRESSES_V4.to_vec(); |   let mut listen_addresses: Vec<&str> = Vec::new(); | ||||||
|   if let Some(v) = config.listen_ipv6 { |   if let Some(v) = config.listen_only_ipv6 { | ||||||
|     if v { |     if v { | ||||||
|       listen_addresses.extend(LISTEN_ADDRESSES_V6.iter()); |       listen_addresses.extend(LISTEN_ADDRESSES_V6.iter()); | ||||||
|     } |     } | ||||||
|  |   } else if let Some(v) = config.listen_ipv6 { | ||||||
|  |     listen_addresses.extend(LISTEN_ADDRESSES_V4.iter()); | ||||||
|  |     if v { | ||||||
|  |       listen_addresses.extend(LISTEN_ADDRESSES_V6.iter()); | ||||||
|  |     } | ||||||
|  |   } else { | ||||||
|  |     listen_addresses.extend(LISTEN_ADDRESSES_V4.iter()); | ||||||
|   } |   } | ||||||
|   globals.listen_sockets = listen_addresses |   globals.listen_sockets = listen_addresses | ||||||
|     .iter() |     .iter() | ||||||
|  | @ -144,6 +151,16 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut Backends) -> Result<()> | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  |   // experimental
 | ||||||
|  |   if let Some(exp) = config.experimental { | ||||||
|  |     if let Some(b) = exp.h3 { | ||||||
|  |       globals.http3 = b; | ||||||
|  |       if b { | ||||||
|  |         info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable.") | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|   Ok(()) |   Ok(()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -8,10 +8,17 @@ pub struct ConfigToml { | ||||||
|   pub listen_port: Option<u16>, |   pub listen_port: Option<u16>, | ||||||
|   pub listen_port_tls: Option<u16>, |   pub listen_port_tls: Option<u16>, | ||||||
|   pub listen_ipv6: Option<bool>, |   pub listen_ipv6: Option<bool>, | ||||||
|  |   pub listen_only_ipv6: Option<bool>, | ||||||
|   pub max_concurrent_streams: Option<u32>, |   pub max_concurrent_streams: Option<u32>, | ||||||
|   pub max_clients: Option<u32>, |   pub max_clients: Option<u32>, | ||||||
|   pub apps: Option<Apps>, |   pub apps: Option<Apps>, | ||||||
|   pub default_app: Option<String>, |   pub default_app: Option<String>, | ||||||
|  |   pub experimental: Option<Experimental>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Deserialize, Debug, Default)] | ||||||
|  | pub struct Experimental { | ||||||
|  |   pub h3: Option<bool>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[derive(Deserialize, Debug, Default)] | #[derive(Deserialize, Debug, Default)] | ||||||
|  |  | ||||||
|  | @ -16,6 +16,7 @@ pub struct Globals { | ||||||
|   pub clients_count: ClientsCount, |   pub clients_count: ClientsCount, | ||||||
|   pub max_concurrent_streams: u32, |   pub max_concurrent_streams: u32, | ||||||
|   pub keepalive: bool, |   pub keepalive: bool, | ||||||
|  |   pub http3: bool, | ||||||
| 
 | 
 | ||||||
|   pub runtime_handle: tokio::runtime::Handle, |   pub runtime_handle: tokio::runtime::Handle, | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -55,6 +55,7 @@ fn main() { | ||||||
|       listen_sockets: Vec::new(), |       listen_sockets: Vec::new(), | ||||||
|       http_port: None, |       http_port: None, | ||||||
|       https_port: None, |       https_port: None, | ||||||
|  |       http3: false, | ||||||
|       timeout: Duration::from_secs(TIMEOUT_SEC), |       timeout: Duration::from_secs(TIMEOUT_SEC), | ||||||
|       max_clients: MAX_CLIENTS, |       max_clients: MAX_CLIENTS, | ||||||
|       clients_count: Default::default(), |       clients_count: Default::default(), | ||||||
|  |  | ||||||
|  | @ -1,3 +1,5 @@ | ||||||
|  | #[cfg(feature = "h3")] | ||||||
|  | mod proxy_h3; | ||||||
| mod proxy_handler; | mod proxy_handler; | ||||||
| mod proxy_main; | mod proxy_main; | ||||||
| mod proxy_tls; | mod proxy_tls; | ||||||
|  |  | ||||||
							
								
								
									
										110
									
								
								src/proxy/proxy_h3.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										110
									
								
								src/proxy/proxy_h3.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,110 @@ | ||||||
|  | use super::Proxy; | ||||||
|  | use crate::{error::*, log::*}; | ||||||
|  | use bytes::{Buf, Bytes}; | ||||||
|  | use h3::{quic::BidiStream, server::RequestStream}; | ||||||
|  | use hyper::{client::connect::Connect, Body, HeaderMap, Request, Response}; | ||||||
|  | use std::net::SocketAddr; | ||||||
|  | 
 | ||||||
|  | impl<T> Proxy<T> | ||||||
|  | where | ||||||
|  |   T: Connect + Clone + Sync + Send + 'static, | ||||||
|  | { | ||||||
|  |   pub async fn client_serve_h3(self, conn: quinn::Connecting) -> Result<()> { | ||||||
|  |     let client_addr = conn.remote_address(); | ||||||
|  | 
 | ||||||
|  |     match conn.await { | ||||||
|  |       Ok(new_conn) => { | ||||||
|  |         info!("QUIC connection established from {:?} {:?}", client_addr, { | ||||||
|  |           let hsd = new_conn | ||||||
|  |             .connection | ||||||
|  |             .handshake_data() | ||||||
|  |             .ok_or_else(|| anyhow!(""))? | ||||||
|  |             .downcast::<quinn::crypto::rustls::HandshakeData>() | ||||||
|  |             .map_err(|_| anyhow!(""))?; | ||||||
|  |           ( | ||||||
|  |             hsd.protocol.map_or_else( | ||||||
|  |               || "<none>".into(), | ||||||
|  |               |x| String::from_utf8_lossy(&x).into_owned(), | ||||||
|  |             ), | ||||||
|  |             hsd.server_name.map_or_else(|| "<none>".into(), |x| x), | ||||||
|  |           ) | ||||||
|  |         }); | ||||||
|  | 
 | ||||||
|  |         let mut h3_conn = | ||||||
|  |           h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn)) | ||||||
|  |             .await?; | ||||||
|  |         info!("HTTP/3 connection established"); | ||||||
|  | 
 | ||||||
|  |         while let Some((req, stream)) = h3_conn | ||||||
|  |           .accept() | ||||||
|  |           .await | ||||||
|  |           .map_err(|e| anyhow!("HTTP/3 accept failed: {}", e))? | ||||||
|  |         { | ||||||
|  |           info!("HTTP/3 new request received"); | ||||||
|  | 
 | ||||||
|  |           let self_inner = self.clone(); | ||||||
|  |           self.globals.runtime_handle.spawn(async move { | ||||||
|  |             if let Err(e) = self_inner.handle_request_h3(req, stream, client_addr).await { | ||||||
|  |               error!("HTTP/3 request failed: {}", e); | ||||||
|  |             } | ||||||
|  |           }); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       Err(err) => { | ||||||
|  |         warn!("QUIC accepting connection failed: {:?}", err); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     Ok(()) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   async fn handle_request_h3<S>( | ||||||
|  |     self, | ||||||
|  |     req: Request<()>, | ||||||
|  |     mut stream: RequestStream<S, Bytes>, | ||||||
|  |     client_addr: SocketAddr, | ||||||
|  |   ) -> Result<()> | ||||||
|  |   where | ||||||
|  |     S: BidiStream<Bytes>, | ||||||
|  |   { | ||||||
|  |     let (req_parts, _) = req.into_parts(); | ||||||
|  | 
 | ||||||
|  |     // TODO: h3 -> h2/http1.1などのプロトコル変換がなければ、bodyはBytes単位で直でsend_dataして転送した方がいい。やむなし。
 | ||||||
|  |     let mut body_chunk: Vec<u8> = Vec::new(); | ||||||
|  |     while let Some(request_body) = stream.recv_data().await? { | ||||||
|  |       body_chunk.extend_from_slice(request_body.chunk()); | ||||||
|  |     } | ||||||
|  |     let body = if body_chunk.is_empty() { | ||||||
|  |       Body::default() | ||||||
|  |     } else { | ||||||
|  |       debug!("HTTP/3 request with non-empty body"); | ||||||
|  |       Body::from(body_chunk) | ||||||
|  |     }; | ||||||
|  |     // trailers
 | ||||||
|  |     let trailers = if let Some(trailers) = stream.recv_trailers().await? { | ||||||
|  |       debug!("HTTP/3 request with trailers"); | ||||||
|  |       trailers | ||||||
|  |     } else { | ||||||
|  |       HeaderMap::new() | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     let new_req: Request<Body> = Request::from_parts(req_parts, body); | ||||||
|  |     let res = self.handle_request(new_req, client_addr).await?; | ||||||
|  | 
 | ||||||
|  |     let (new_res_parts, new_body) = res.into_parts(); | ||||||
|  |     let new_res = Response::from_parts(new_res_parts, ()); | ||||||
|  | 
 | ||||||
|  |     match stream.send_response(new_res).await { | ||||||
|  |       Ok(_) => { | ||||||
|  |         debug!("HTTP/3 response to connection successful"); | ||||||
|  |         let data = hyper::body::to_bytes(new_body).await?; | ||||||
|  |         stream.send_data(data).await?; | ||||||
|  |         stream.send_trailers(trailers).await?; | ||||||
|  |       } | ||||||
|  |       Err(err) => { | ||||||
|  |         error!("Unable to send response to connection peer: {:?}", err); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     Ok(stream.finish().await?) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -98,6 +98,14 @@ where | ||||||
|         return http_error(StatusCode::BAD_REQUEST); |         return http_error(StatusCode::BAD_REQUEST); | ||||||
|       } |       } | ||||||
|     }; |     }; | ||||||
|  |     #[cfg(feature = "h3")] | ||||||
|  |     { | ||||||
|  |       if let Some(port) = self.globals.https_port { | ||||||
|  |         res_backend | ||||||
|  |           .headers_mut() | ||||||
|  |           .insert("alt-svc", format!("h3=\":{}\"", port).parse().unwrap()); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|     debug!("Response from backend: {:?}", res_backend.status()); |     debug!("Response from backend: {:?}", res_backend.status()); | ||||||
| 
 | 
 | ||||||
|     if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS { |     if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS { | ||||||
|  | @ -156,12 +164,12 @@ fn generate_request_forwarded<B: core::fmt::Debug>( | ||||||
|   debug!("Generate request to be forwarded"); |   debug!("Generate request to be forwarded"); | ||||||
| 
 | 
 | ||||||
|   // Add te: trailer if contained in original request
 |   // Add te: trailer if contained in original request
 | ||||||
|   let te_trailer = { |   let te_trailers = { | ||||||
|     if let Some(te) = req.headers().get("te") { |     if let Some(te) = req.headers().get("te") { | ||||||
|       te.to_str() |       te.to_str() | ||||||
|         .unwrap() |         .unwrap() | ||||||
|         .split(',') |         .split(',') | ||||||
|         .any(|x| x.trim() == "trailer") |         .any(|x| x.trim() == "trailers") | ||||||
|     } else { |     } else { | ||||||
|       false |       false | ||||||
|     } |     } | ||||||
|  | @ -175,7 +183,7 @@ fn generate_request_forwarded<B: core::fmt::Debug>( | ||||||
|   // X-Forwarded-For
 |   // X-Forwarded-For
 | ||||||
|   add_forwarding_header(headers, client_addr)?; |   add_forwarding_header(headers, client_addr)?; | ||||||
|   // Add te: trailer if te_trailer
 |   // Add te: trailer if te_trailer
 | ||||||
|   if te_trailer { |   if te_trailers { | ||||||
|     headers.insert("te", "trailer".parse().unwrap()); |     headers.insert("te", "trailer".parse().unwrap()); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|  | @ -200,6 +208,9 @@ fn generate_request_forwarded<B: core::fmt::Debug>( | ||||||
|   // Change version to http/1.1 when destination scheme is http
 |   // Change version to http/1.1 when destination scheme is http
 | ||||||
|   if req.version() != Version::HTTP_11 && upstream_scheme_host.scheme() == Some(&Scheme::HTTP) { |   if req.version() != Version::HTTP_11 && upstream_scheme_host.scheme() == Some(&Scheme::HTTP) { | ||||||
|     *req.version_mut() = Version::HTTP_11; |     *req.version_mut() = Version::HTTP_11; | ||||||
|  |   } else if req.version() == Version::HTTP_3 { | ||||||
|  |     debug!("HTTP/3 is currently unsupported for request to upstream. Use HTTP/2."); | ||||||
|  |     *req.version_mut() = Version::HTTP_2; | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   Ok(req) |   Ok(req) | ||||||
|  | @ -290,7 +301,10 @@ fn secure_redirection( | ||||||
|   Ok(response) |   Ok(response) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| fn parse_host_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String, u16)> { | fn parse_host_port<B: core::fmt::Debug>( | ||||||
|  |   req: &Request<B>, | ||||||
|  |   tls_enabled: bool, | ||||||
|  | ) -> Result<(String, u16)> { | ||||||
|   let host_port_headers = req.headers().get("host"); |   let host_port_headers = req.headers().get("host"); | ||||||
|   let host_uri = req.uri().host(); |   let host_uri = req.uri().host(); | ||||||
|   let port_uri = req.uri().port_u16(); |   let port_uri = req.uri().port_u16(); | ||||||
|  |  | ||||||
|  | @ -74,13 +74,14 @@ where | ||||||
|     }); |     }); | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   async fn start_without_tls( |   async fn start_without_tls(self, server: Http<LocalExecutor>) -> Result<()> { | ||||||
|     self, |  | ||||||
|     listener: TcpListener, |  | ||||||
|     server: Http<LocalExecutor>, |  | ||||||
|   ) -> Result<()> { |  | ||||||
|     let listener_service = async { |     let listener_service = async { | ||||||
|       while let Ok((stream, _client_addr)) = listener.accept().await { |       let tcp_listener = TcpListener::bind(&self.listening_on).await?; | ||||||
|  |       info!( | ||||||
|  |         "Start TCP proxy serving with HTTP request for configured host names: {:?}", | ||||||
|  |         tcp_listener.local_addr()? | ||||||
|  |       ); | ||||||
|  |       while let Ok((stream, _client_addr)) = tcp_listener.accept().await { | ||||||
|         self |         self | ||||||
|           .clone() |           .clone() | ||||||
|           .client_serve(stream, server.clone(), _client_addr) |           .client_serve(stream, server.clone(), _client_addr) | ||||||
|  | @ -93,8 +94,6 @@ where | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   pub async fn start(self) -> Result<()> { |   pub async fn start(self) -> Result<()> { | ||||||
|     let tcp_listener = TcpListener::bind(&self.listening_on).await?; |  | ||||||
| 
 |  | ||||||
|     let mut server = Http::new(); |     let mut server = Http::new(); | ||||||
|     server.http1_keep_alive(self.globals.keepalive); |     server.http1_keep_alive(self.globals.keepalive); | ||||||
|     server.http2_max_concurrent_streams(self.globals.max_concurrent_streams); |     server.http2_max_concurrent_streams(self.globals.max_concurrent_streams); | ||||||
|  | @ -103,18 +102,10 @@ where | ||||||
|     let server = server.with_executor(executor); |     let server = server.with_executor(executor); | ||||||
| 
 | 
 | ||||||
|     if self.tls_enabled { |     if self.tls_enabled { | ||||||
|       info!( |  | ||||||
|         "Start TCP proxy serving with HTTPS request for configured host names: {:?}", |  | ||||||
|         tcp_listener.local_addr()? |  | ||||||
|       ); |  | ||||||
|       // #[cfg(feature = "tls")]
 |       // #[cfg(feature = "tls")]
 | ||||||
|       self.start_with_tls(tcp_listener, server).await?; |       self.start_with_tls(server).await?; | ||||||
|     } else { |     } else { | ||||||
|       info!( |       self.start_without_tls(server).await?; | ||||||
|         "Start TCP proxy serving with HTTP request for configured host names: {:?}", |  | ||||||
|         tcp_listener.local_addr()? |  | ||||||
|       ); |  | ||||||
|       self.start_without_tls(tcp_listener, server).await?; |  | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     Ok(()) |     Ok(()) | ||||||
|  |  | ||||||
|  | @ -1,7 +1,10 @@ | ||||||
| use super::proxy_main::{LocalExecutor, Proxy}; | use super::proxy_main::{LocalExecutor, Proxy}; | ||||||
| use crate::{constants::CERTS_WATCH_DELAY_SECS, error::*, log::*}; | use crate::{constants::CERTS_WATCH_DELAY_SECS, error::*, log::*}; | ||||||
|  | #[cfg(feature = "h3")] | ||||||
|  | use futures::StreamExt; | ||||||
| use futures::{future::FutureExt, join, select}; | use futures::{future::FutureExt, join, select}; | ||||||
| use hyper::{client::connect::Connect, server::conn::Http}; | use hyper::{client::connect::Connect, server::conn::Http}; | ||||||
|  | use rustls::ServerConfig; | ||||||
| use std::{sync::Arc, time::Duration}; | use std::{sync::Arc, time::Duration}; | ||||||
| use tokio::net::TcpListener; | use tokio::net::TcpListener; | ||||||
| 
 | 
 | ||||||
|  | @ -9,11 +12,7 @@ impl<T> Proxy<T> | ||||||
| where | where | ||||||
|   T: Connect + Clone + Sync + Send + 'static, |   T: Connect + Clone + Sync + Send + 'static, | ||||||
| { | { | ||||||
|   pub async fn start_with_tls( |   pub async fn start_with_tls(self, server: Http<LocalExecutor>) -> Result<()> { | ||||||
|     self, |  | ||||||
|     listener: TcpListener, |  | ||||||
|     server: Http<LocalExecutor>, |  | ||||||
|   ) -> Result<()> { |  | ||||||
|     let cert_service = async { |     let cert_service = async { | ||||||
|       info!("Start cert watch service for {}", self.listening_on); |       info!("Start cert watch service for {}", self.listening_on); | ||||||
|       loop { |       loop { | ||||||
|  | @ -28,10 +27,17 @@ where | ||||||
|       } |       } | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|  |     // TCP Listener Service, i.e., http/2 and http/1.1
 | ||||||
|     let listener_service = async { |     let listener_service = async { | ||||||
|  |       let tcp_listener = TcpListener::bind(&self.listening_on).await?; | ||||||
|  |       info!( | ||||||
|  |         "Start TCP proxy serving with HTTPS request for configured host names: {:?}", | ||||||
|  |         tcp_listener.local_addr()? | ||||||
|  |       ); | ||||||
|  | 
 | ||||||
|       loop { |       loop { | ||||||
|         select! { |         select! { | ||||||
|           tcp_cnx = listener.accept().fuse() => { |           tcp_cnx = tcp_listener.accept().fuse() => { | ||||||
|             if tcp_cnx.is_err() { |             if tcp_cnx.is_err() { | ||||||
|               continue; |               continue; | ||||||
|             } |             } | ||||||
|  | @ -53,25 +59,13 @@ where | ||||||
|               info!("No SNI in ClientHello"); |               info!("No SNI in ClientHello"); | ||||||
|               continue; |               continue; | ||||||
|             }; |             }; | ||||||
|             let backend_serve = if let Some(backend_serve) = self.backends.apps.get(svn){ |             let server_crypto = if let Some(p) = self.fetch_server_crypto(svn) { | ||||||
|               backend_serve |  | ||||||
|             } else { |  | ||||||
|               info!("No configuration for the server name {} given in client_hello", svn); |  | ||||||
|               continue; |  | ||||||
|             }; |  | ||||||
| 
 |  | ||||||
|             if backend_serve.tls_cert_path.is_none() { // at least cert does exit
 |  | ||||||
|               debug!("SNI indicates a site that doesn't support TLS."); |  | ||||||
|               continue; |  | ||||||
|             } |  | ||||||
|             let server_config = if let Some(p) = backend_serve.get_tls_server_config(){ |  | ||||||
|               p |               p | ||||||
|             } else { |             } else { | ||||||
|               error!("Failed to load server config"); |  | ||||||
|               continue; |               continue; | ||||||
|             }; |             }; | ||||||
|             // Finally serve the TLS connection
 |             // Finally serve the TLS connection
 | ||||||
|             if let Ok(stream) = start.into_stream(Arc::new(server_config)).await { |             if let Ok(stream) = start.into_stream(Arc::new(server_crypto)).await { | ||||||
|               self.clone().client_serve(stream, server.clone(), _client_addr).await |               self.clone().client_serve(stream, server.clone(), _client_addr).await | ||||||
|             } |             } | ||||||
|           } |           } | ||||||
|  | @ -81,6 +75,112 @@ where | ||||||
|       Ok(()) as Result<()> |       Ok(()) as Result<()> | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|  |     ///////////////////////
 | ||||||
|  |     #[cfg(feature = "h3")] | ||||||
|  |     let listener_service_h3 = async { | ||||||
|  |       // TODO: Work around to initially serve incoming connection
 | ||||||
|  |       // かなり適当。エラーが出たり出なかったり。原因がわからない…
 | ||||||
|  |       let tls_app_names: Vec<String> = self | ||||||
|  |         .backends | ||||||
|  |         .apps | ||||||
|  |         .iter() | ||||||
|  |         .filter(|&(_, backend)| { | ||||||
|  |           backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() | ||||||
|  |         }) | ||||||
|  |         .map(|(name, _)| name.to_string()) | ||||||
|  |         .collect(); | ||||||
|  |       ensure!(!tls_app_names.is_empty(), "No TLS supported app"); | ||||||
|  |       let initial_app_name = tls_app_names.get(0).unwrap().as_str(); | ||||||
|  |       info!("Initial app_name: {}", initial_app_name); | ||||||
|  |       let backend_serve = self.backends.apps.get(initial_app_name).unwrap(); | ||||||
|  |       let server_crypto = backend_serve.get_tls_server_config().unwrap(); | ||||||
|  |       let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); | ||||||
|  | 
 | ||||||
|  |       let (endpoint, incoming) = | ||||||
|  |         quinn::Endpoint::server(server_config_h3, self.listening_on).unwrap(); | ||||||
|  |       debug!("HTTP/3 UDP listening on {}", endpoint.local_addr().unwrap()); | ||||||
|  | 
 | ||||||
|  |       let mut p = incoming.peekable(); | ||||||
|  |       loop { | ||||||
|  |         // TODO: Not sure if this properly works to handle multiple "server_name"s to host multiple hosts.
 | ||||||
|  |         // peek() should work for that.
 | ||||||
|  |         if let Some(peeked_conn) = std::pin::Pin::new(&mut p).peek_mut().await { | ||||||
|  |           let hsd = peeked_conn.handshake_data().await; | ||||||
|  |           let hsd_downcast = hsd? | ||||||
|  |             .downcast::<quinn::crypto::rustls::HandshakeData>() | ||||||
|  |             .unwrap(); | ||||||
|  |           let svn = if let Some(sni) = hsd_downcast.server_name { | ||||||
|  |             sni | ||||||
|  |           } else { | ||||||
|  |             debug!("HTTP/3 no SNI is given"); | ||||||
|  |             continue; | ||||||
|  |           }; | ||||||
|  |           let new_server_crypto = if let Some(p) = self.fetch_server_crypto(&svn) { | ||||||
|  |             p | ||||||
|  |           } else { | ||||||
|  |             continue; | ||||||
|  |           }; | ||||||
|  |           // Set ServerConfig::set_server_config for given SNI
 | ||||||
|  |           let mut new_server_config_h3 = | ||||||
|  |             quinn::ServerConfig::with_crypto(Arc::new(new_server_crypto)); | ||||||
|  |           if svn == "localhost" { | ||||||
|  |             new_server_config_h3.concurrent_connections(512); | ||||||
|  |           } | ||||||
|  |           info!( | ||||||
|  |             "HTTP/3 connection incoming (SNI {:?}): Overwrite ServerConfig", | ||||||
|  |             svn | ||||||
|  |           ); | ||||||
|  |           endpoint.set_server_config(Some(new_server_config_h3)); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         // Then acquire actual connection
 | ||||||
|  |         let peekable_incoming = std::pin::Pin::new(&mut p); | ||||||
|  |         if let Some(conn) = peekable_incoming.get_mut().next().await { | ||||||
|  |           let fut = self.clone().client_serve_h3(conn); | ||||||
|  |           self.globals.runtime_handle.spawn(async { | ||||||
|  |             if let Err(e) = fut.await { | ||||||
|  |               warn!("QUIC or HTTP/3 connection failed: {}", e) | ||||||
|  |             } | ||||||
|  |           }); | ||||||
|  |         } else { | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       endpoint.wait_idle().await; | ||||||
|  |       Ok(()) as Result<()> | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     #[cfg(not(feature = "h3"))] | ||||||
|  |     { | ||||||
|       join!(listener_service, cert_service).0 |       join!(listener_service, cert_service).0 | ||||||
|     } |     } | ||||||
|  |     #[cfg(feature = "h3")] | ||||||
|  |     { | ||||||
|  |       join!(listener_service, cert_service, listener_service_h3).0 | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   fn fetch_server_crypto(&self, server_name: &str) -> Option<ServerConfig> { | ||||||
|  |     let backend_serve = if let Some(backend_serve) = self.backends.apps.get(server_name) { | ||||||
|  |       backend_serve | ||||||
|  |     } else { | ||||||
|  |       warn!( | ||||||
|  |         "No configuration for the server name {} given in client_hello", | ||||||
|  |         server_name | ||||||
|  |       ); | ||||||
|  |       return None; | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     if backend_serve.tls_cert_path.is_none() { | ||||||
|  |       // at least cert does exit
 | ||||||
|  |       warn!("SNI indicates a site that doesn't support TLS."); | ||||||
|  |       return None; | ||||||
|  |     } | ||||||
|  |     if let Some(p) = backend_serve.get_tls_server_config() { | ||||||
|  |       Some(p) | ||||||
|  |     } else { | ||||||
|  |       error!("Failed to load server config"); | ||||||
|  |       None | ||||||
|  |     } | ||||||
|  |   } | ||||||
| } | } | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jun Kurihara
				Jun Kurihara