remake architecture to handle multiple tls endpoints
This commit is contained in:
		
					parent
					
						
							
								634d556ea9
							
						
					
				
			
			
				commit
				
					
						99e6bce992
					
				
			
		
					 11 changed files with 490 additions and 556 deletions
				
			
		|  | @ -48,6 +48,7 @@ hyper-trust-dns = { version = "0.4.2", default-features = false, features = [ | |||
|   "rustls-webpki", | ||||
| ], optional = true } | ||||
| hyper-tls = "0.5.0" | ||||
| rustls = "0.20.6" | ||||
| 
 | ||||
| [dev-dependencies] | ||||
| 
 | ||||
|  |  | |||
|  | @ -3,29 +3,36 @@ | |||
| #       rust-rxpy configuration        # | ||||
| #                                      # | ||||
| ######################################## | ||||
| 
 | ||||
| ################################## | ||||
| #         Global settings        # | ||||
| ################################## | ||||
| 
 | ||||
| ## Address to listen to. | ||||
| listen_addresses = ['127.0.0.1:50844', '[::1]:50844'] | ||||
| 
 | ||||
| [tls] | ||||
| tls_cert_path = 'localhost.pem' | ||||
| tls_cert_key_path = 'localhost.pem' | ||||
| ################################### | ||||
| #         Global settings         # | ||||
| ################################### | ||||
| http_port = 8080 | ||||
| https_port = 8443 | ||||
| 
 | ||||
| ################################### | ||||
| #         Backend settings        # | ||||
| ################################### | ||||
| [[backend]] | ||||
| domain = 'localhost' | ||||
| ## List of destinations to send data to. | ||||
| ## At this point, round-robin is used for load-balancing if multiple URLs are specified. | ||||
| destination = ['http://192.168.0.1:3000/', 'https://192.168.0.2:3000'] | ||||
| allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24'] | ||||
| denyhosts = ['*'] | ||||
| 
 | ||||
| [[backend]] | ||||
| domain = '127.0.0.1' | ||||
| destination = 'https://www.google.com/' | ||||
| app_name = 'localhost' # this should be option, if null then same as hostname | ||||
| hostname = 'localhost' | ||||
| redirect_to_https = true | ||||
| reverse_proxy = [ | ||||
|   { path = '*', destination = 'https://192.168.10.0:3000' }, | ||||
|   { path = '/path/to', destination = 'https://192.168.10.1:4000/path/to' }, | ||||
| ] | ||||
| ## List of destinations to send data to. | ||||
| ## At this point, round-robin is used for load-balancing if multiple URLs are specified. | ||||
| allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24'] | ||||
| denyhosts = ['*'] | ||||
| tls_cert_path = 'localhost1.pem' | ||||
| tls_cert_key_path = 'localhost1.pem' | ||||
| 
 | ||||
| 
 | ||||
| [[backend]] | ||||
| app_name = 'locahost_application' | ||||
| hostname = 'localhost.localdomain' | ||||
| redirect_to_https = true | ||||
| reverse_proxy = [{ path = '/', destination = 'https://www.google.com/' }] | ||||
| tls_cert_path = 'localhost2.pem' | ||||
| tls_cert_key_path = 'localhost2.pem' | ||||
|  |  | |||
							
								
								
									
										289
									
								
								src/acceptor.rs
									
										
									
									
									
								
							
							
						
						
									
										289
									
								
								src/acceptor.rs
									
										
									
									
									
								
							|  | @ -1,289 +0,0 @@ | |||
| use crate::{error::*, globals::Globals, log::*}; | ||||
| use futures::{ | ||||
|   task::{Context, Poll}, | ||||
|   Future, | ||||
| }; | ||||
| use hyper::{ | ||||
|   client::connect::Connect, | ||||
|   http, | ||||
|   server::conn::Http, | ||||
|   service::{service_fn, Service}, | ||||
|   Body, Client, HeaderMap, Method, Request, Response, StatusCode, | ||||
| }; | ||||
| use std::{net::SocketAddr, pin::Pin, sync::Arc}; | ||||
| use tokio::{ | ||||
|   io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, | ||||
|   net::TcpListener, | ||||
|   runtime::Handle, | ||||
|   time::Duration, | ||||
| }; | ||||
| 
 | ||||
| #[allow(clippy::unnecessary_wraps)] | ||||
| fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> { | ||||
|   let response = Response::builder() | ||||
|     .status(status_code) | ||||
|     .body(Body::empty()) | ||||
|     .unwrap(); | ||||
|   Ok(response) | ||||
| } | ||||
| 
 | ||||
| #[derive(Clone, Debug)] | ||||
| pub struct LocalExecutor { | ||||
|   runtime_handle: Handle, | ||||
| } | ||||
| 
 | ||||
| impl LocalExecutor { | ||||
|   fn new(runtime_handle: Handle) -> Self { | ||||
|     LocalExecutor { runtime_handle } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl<F> hyper::rt::Executor<F> for LocalExecutor | ||||
| where | ||||
|   F: std::future::Future + Send + 'static, | ||||
|   F::Output: Send, | ||||
| { | ||||
|   fn execute(&self, fut: F) { | ||||
|     self.runtime_handle.spawn(fut); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Clone)] | ||||
| pub struct PacketAcceptor<T> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub listening_on: SocketAddr, | ||||
|   pub forwarder: Arc<Client<T>>, | ||||
|   pub globals: Arc<Globals>, | ||||
| } | ||||
| 
 | ||||
| // impl<T> Service<http::Request<Body>> for PacketAcceptor<T>
 | ||||
| // where
 | ||||
| //   T: Connect + Clone + Sync + Send + 'static,
 | ||||
| // {
 | ||||
| //   type Response = Response<Body>;
 | ||||
| 
 | ||||
| //   type Error = http::Error;
 | ||||
| //   type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
 | ||||
| 
 | ||||
| //   fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
 | ||||
| //     Poll::Ready(Ok(()))
 | ||||
| //   }
 | ||||
| 
 | ||||
| //   fn call(&mut self, req: Request<Body>) -> Self::Future {
 | ||||
| //     debug!(
 | ||||
| //       "serving {:?} {:?} request to {:?}",
 | ||||
| //       req.version(),
 | ||||
| //       req.method(),
 | ||||
| //       req.uri()
 | ||||
| //     );
 | ||||
| //     let self_inner = self.clone();
 | ||||
| 
 | ||||
| //     // 1. check uri (domain queried host name)
 | ||||
| //     // 2. build uri to forwarding target destination
 | ||||
| //     // 3. build request from uri and body
 | ||||
| //     // 4. send request to forwarding target
 | ||||
| 
 | ||||
| //     if *req.method() == Method::GET {
 | ||||
| //       Box::pin(async move {
 | ||||
| //         // let uri = req.uri();
 | ||||
| //         let target_uri = hyper::Uri::builder()
 | ||||
| //           .scheme("https")
 | ||||
| //           .authority("www.google.com")
 | ||||
| //           .path_and_query("/")
 | ||||
| //           .build()
 | ||||
| //           .unwrap();
 | ||||
| //         println!("{:?}", target_uri);
 | ||||
| //         match self_inner.forwarder.get(target_uri).await {
 | ||||
| //           Ok(res) => Ok(res),
 | ||||
| //           Err(e) => {
 | ||||
| //             error!("{:?}", e);
 | ||||
| //             http_error(StatusCode::INTERNAL_SERVER_ERROR)
 | ||||
| //           }
 | ||||
| //         }
 | ||||
| //       })
 | ||||
| //     } else {
 | ||||
| //       // let globals = &self.doh.globals;
 | ||||
| //       // let self_inner = self.clone();
 | ||||
| //       // if req.uri().path() == globals.path {
 | ||||
| //       //   Box::pin(async move {
 | ||||
| //       //     let mut subscriber = None;
 | ||||
| //       //     if self_inner.doh.globals.enable_auth_target {
 | ||||
| //       //       subscriber = match auth::authenticate(
 | ||||
| //       //         &self_inner.doh.globals,
 | ||||
| //       //         &req,
 | ||||
| //       //         ValidationLocation::Target,
 | ||||
| //       //         &self_inner.peer_addr,
 | ||||
| //       //       ) {
 | ||||
| //       //         Ok((sub, aud)) => {
 | ||||
| //       //           debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud);
 | ||||
| //       //           sub
 | ||||
| //       //         }
 | ||||
| //       //         Err(e) => {
 | ||||
| //       //           error!("{:?}", e);
 | ||||
| //       //           return Ok(e);
 | ||||
| //       //         }
 | ||||
| //       //       };
 | ||||
| //       //     }
 | ||||
| //       //     match *req.method() {
 | ||||
| //       //       Method::POST => self_inner.doh.serve_post(req, subscriber).await,
 | ||||
| //       //       Method::GET => self_inner.doh.serve_get(req, subscriber).await,
 | ||||
| //       //       _ => http_error(StatusCode::METHOD_NOT_ALLOWED),
 | ||||
| //       //     }
 | ||||
| //       //   })
 | ||||
| //       // } else if req.uri().path() == globals.odoh_configs_path {
 | ||||
| //       //   match *req.method() {
 | ||||
| //       //     Method::GET => Box::pin(async move { self_inner.doh.serve_odoh_configs().await }),
 | ||||
| //       //     _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
 | ||||
| //       //   }
 | ||||
| //       // } else {
 | ||||
| //       //   #[cfg(not(feature = "odoh-proxy"))]
 | ||||
| //       //   {
 | ||||
| //       //     Box::pin(async { http_error(StatusCode::NOT_FOUND) })
 | ||||
| //       //   }
 | ||||
| //       //   #[cfg(feature = "odoh-proxy")]
 | ||||
| //       //   {
 | ||||
| //       //     if req.uri().path() == globals.odoh_proxy_path {
 | ||||
| //       //       Box::pin(async move {
 | ||||
| //       //         let mut subscriber = None;
 | ||||
| //       //         if self_inner.doh.globals.enable_auth_proxy {
 | ||||
| //       //           subscriber = match auth::authenticate(
 | ||||
| //       //             &self_inner.doh.globals,
 | ||||
| //       //             &req,
 | ||||
| //       //             ValidationLocation::Proxy,
 | ||||
| //       //             &self_inner.peer_addr,
 | ||||
| //       //           ) {
 | ||||
| //       //             Ok((sub, aud)) => {
 | ||||
| //       //               debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud);
 | ||||
| //       //               sub
 | ||||
| //       //             }
 | ||||
| //       //             Err(e) => {
 | ||||
| //       //               error!("{:?}", e);
 | ||||
| //       //               return Ok(e);
 | ||||
| //       //             }
 | ||||
| //       //           };
 | ||||
| //       //         }
 | ||||
| //       //         // Draft:        https://datatracker.ietf.org/doc/html/draft-pauly-dprive-oblivious-doh-11
 | ||||
| //       //         // Golang impl.: https://github.com/cloudflare/odoh-server-go
 | ||||
| //       //         // Based on the draft and Golang implementation, only post method is allowed.
 | ||||
| //       //         match *req.method() {
 | ||||
| //       //           Method::POST => self_inner.doh.serve_odoh_proxy_post(req, subscriber).await,
 | ||||
| //       //           _ => http_error(StatusCode::METHOD_NOT_ALLOWED),
 | ||||
| //       //         }
 | ||||
| //       //       })
 | ||||
| //       //     } else {
 | ||||
| //       Box::pin(async { http_error(StatusCode::NOT_FOUND) })
 | ||||
| //     }
 | ||||
| //     //     }
 | ||||
| //     // }
 | ||||
| //     // }
 | ||||
| //   }
 | ||||
| // }
 | ||||
| 
 | ||||
| async fn handle_request( | ||||
|   req: Request<Body>, | ||||
|   client_ip: SocketAddr, | ||||
|   globals: Arc<Globals>, | ||||
| ) -> Result<Response<Body>, http::Error> { | ||||
|   // http_error(StatusCode::NOT_FOUND)
 | ||||
|   debug!("{:?}", req); | ||||
|   // if req.version() == hyper::Version::HTTP_11 {
 | ||||
|   //   Ok(Response::new(Body::from("Hello World")))
 | ||||
|   // } else {
 | ||||
|   // Note: it's usually better to return a Response
 | ||||
|   // with an appropriate StatusCode instead of an Err.
 | ||||
|   // Err("not HTTP/1.1, abort connection")
 | ||||
|   http_error(StatusCode::NOT_FOUND) | ||||
|   // }
 | ||||
|   // });
 | ||||
| } | ||||
| 
 | ||||
| impl<T> PacketAcceptor<T> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr) | ||||
|   where | ||||
|     I: AsyncRead + AsyncWrite + Send + Unpin + 'static, | ||||
|   { | ||||
|     let clients_count = self.globals.clients_count.clone(); | ||||
|     if clients_count.increment() > self.globals.max_clients { | ||||
|       clients_count.decrement(); | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     self.globals.runtime_handle.clone().spawn(async move { | ||||
|       tokio::time::timeout( | ||||
|         self.globals.timeout + Duration::from_secs(1), | ||||
|         // server.serve_connection(stream, self),
 | ||||
|         server.serve_connection( | ||||
|           stream, | ||||
|           service_fn(move |req: Request<Body>| { | ||||
|             handle_request(req, peer_addr, self.globals.clone()) | ||||
|           }), | ||||
|         ), | ||||
|       ) | ||||
|       .await | ||||
|       .ok(); | ||||
| 
 | ||||
|       clients_count.decrement(); | ||||
|     }); | ||||
|   } | ||||
| 
 | ||||
|   async fn start_without_tls( | ||||
|     self, | ||||
|     listener: TcpListener, | ||||
|     server: Http<LocalExecutor>, | ||||
|   ) -> Result<()> { | ||||
|     let listener_service = async { | ||||
|       while let Ok((stream, _client_addr)) = listener.accept().await { | ||||
|         self | ||||
|           .clone() | ||||
|           .client_serve(stream, server.clone(), _client_addr) | ||||
|           .await; | ||||
|       } | ||||
|       Ok(()) as Result<()> | ||||
|     }; | ||||
|     listener_service.await?; | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   pub async fn start(self) -> Result<()> { | ||||
|     let tcp_listener = TcpListener::bind(&self.listening_on).await?; | ||||
| 
 | ||||
|     let mut server = Http::new(); | ||||
|     server.http1_keep_alive(self.globals.keepalive); | ||||
|     server.http2_max_concurrent_streams(self.globals.max_concurrent_streams); | ||||
|     server.pipeline_flush(true); | ||||
|     let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); | ||||
|     let server = server.with_executor(executor); | ||||
| 
 | ||||
|     let tls_enabled: bool; | ||||
|     #[cfg(not(feature = "tls"))] | ||||
|     { | ||||
|       tls_enabled = false; | ||||
|     } | ||||
|     #[cfg(feature = "tls")] | ||||
|     { | ||||
|       tls_enabled = | ||||
|         self.globals.tls_cert_path.is_some() && self.globals.tls_cert_key_path.is_some(); | ||||
|     } | ||||
|     if tls_enabled { | ||||
|       info!( | ||||
|         "Start server listening on TCP with TLS: {:?}", | ||||
|         tcp_listener.local_addr()? | ||||
|       ); | ||||
|       #[cfg(feature = "tls")] | ||||
|       self.start_with_tls(tcp_listener, server).await?; | ||||
|     } else { | ||||
|       info!( | ||||
|         "Start server listening on TCP: {:?}", | ||||
|         tcp_listener.local_addr()? | ||||
|       ); | ||||
|       self.start_without_tls(tcp_listener, server).await?; | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
|   } | ||||
| } | ||||
							
								
								
									
										134
									
								
								src/backend.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										134
									
								
								src/backend.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,134 @@ | |||
| use crate::log::*; | ||||
| use std::{ | ||||
|   collections::HashMap, | ||||
|   fs::File, | ||||
|   io::{self, BufReader, Cursor, Read}, | ||||
|   path::PathBuf, | ||||
|   sync::Mutex, | ||||
| }; | ||||
| use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; | ||||
| 
 | ||||
| pub struct Backend { | ||||
|   pub app_name: String, | ||||
|   pub hostname: String, | ||||
|   pub reverse_proxy: ReverseProxy, | ||||
|   pub redirect_to_https: Option<bool>, | ||||
|   pub tls_cert_path: Option<PathBuf>, | ||||
|   pub tls_cert_key_path: Option<PathBuf>, | ||||
|   pub server_config: Mutex<Option<ServerConfig>>, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct ReverseProxy { | ||||
|   pub default_destination_uri: hyper::Uri, | ||||
|   pub destination_uris: Option<HashMap<String, hyper::Uri>>, // TODO: url pathで引っ掛ける。
 | ||||
| } | ||||
| 
 | ||||
| impl Backend { | ||||
|   pub fn get_tls_server_config(&self) -> Option<ServerConfig> { | ||||
|     let lock = self.server_config.lock(); | ||||
|     if let Ok(opt) = lock { | ||||
|       let opt_clone = opt.clone(); | ||||
|       if let Some(sc) = opt_clone { | ||||
|         return Some(sc); | ||||
|       } | ||||
|     } | ||||
|     None | ||||
|   } | ||||
|   pub async fn update_server_config(&self) -> io::Result<()> { | ||||
|     debug!("Update TLS server config"); | ||||
|     let certs_path = self.tls_cert_path.as_ref().unwrap(); | ||||
|     let certs_keys_path = self.tls_cert_key_path.as_ref().unwrap(); | ||||
|     let certs: Vec<_> = { | ||||
|       let certs_path_str = certs_path.display().to_string(); | ||||
|       let mut reader = BufReader::new(File::open(certs_path).map_err(|e| { | ||||
|         io::Error::new( | ||||
|           e.kind(), | ||||
|           format!( | ||||
|             "Unable to load the certificates [{}]: {}", | ||||
|             certs_path_str, e | ||||
|           ), | ||||
|         ) | ||||
|       })?); | ||||
|       rustls_pemfile::certs(&mut reader).map_err(|_| { | ||||
|         io::Error::new( | ||||
|           io::ErrorKind::InvalidInput, | ||||
|           "Unable to parse the certificates", | ||||
|         ) | ||||
|       })? | ||||
|     } | ||||
|     .drain(..) | ||||
|     .map(Certificate) | ||||
|     .collect(); | ||||
|     let certs_keys: Vec<_> = { | ||||
|       let certs_keys_path_str = certs_keys_path.display().to_string(); | ||||
|       let encoded_keys = { | ||||
|         let mut encoded_keys = vec![]; | ||||
|         File::open(certs_keys_path) | ||||
|           .map_err(|e| { | ||||
|             io::Error::new( | ||||
|               e.kind(), | ||||
|               format!( | ||||
|                 "Unable to load the certificate keys [{}]: {}", | ||||
|                 certs_keys_path_str, e | ||||
|               ), | ||||
|             ) | ||||
|           })? | ||||
|           .read_to_end(&mut encoded_keys)?; | ||||
|         encoded_keys | ||||
|       }; | ||||
|       let mut reader = Cursor::new(encoded_keys); | ||||
|       let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| { | ||||
|         io::Error::new( | ||||
|           io::ErrorKind::InvalidInput, | ||||
|           "Unable to parse the certificates private keys (PKCS8)", | ||||
|         ) | ||||
|       })?; | ||||
|       reader.set_position(0); | ||||
|       let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| { | ||||
|         io::Error::new( | ||||
|           io::ErrorKind::InvalidInput, | ||||
|           "Unable to parse the certificates private keys (RSA)", | ||||
|         ) | ||||
|       })?; | ||||
|       let mut keys = pkcs8_keys; | ||||
|       keys.append(&mut rsa_keys); | ||||
|       if keys.is_empty() { | ||||
|         return Err(io::Error::new( | ||||
|           io::ErrorKind::InvalidInput, | ||||
|           "No private keys found - Make sure that they are in PKCS#8/PEM format", | ||||
|         )); | ||||
|       } | ||||
|       keys.drain(..).map(PrivateKey).collect() | ||||
|     }; | ||||
| 
 | ||||
|     let mut server_config = certs_keys | ||||
|       .into_iter() | ||||
|       .find_map(|certs_key| { | ||||
|         let server_config_builder = ServerConfig::builder() | ||||
|           .with_safe_defaults() | ||||
|           .with_no_client_auth(); | ||||
|         if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) { | ||||
|           Some(found_config) | ||||
|         } else { | ||||
|           None | ||||
|         } | ||||
|       }) | ||||
|       .ok_or_else(|| { | ||||
|         io::Error::new( | ||||
|           io::ErrorKind::InvalidInput, | ||||
|           "Unable to find a valid certificate and key", | ||||
|         ) | ||||
|       })?; | ||||
|     server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; | ||||
| 
 | ||||
|     if let Ok(mut config_store) = self.server_config.lock() { | ||||
|       *config_store = Some(server_config); | ||||
|     } else { | ||||
|       error!("Some thing wrong to write into mutex") | ||||
|     } | ||||
| 
 | ||||
|     // server_config;
 | ||||
|     Ok(()) | ||||
|   } | ||||
| } | ||||
|  | @ -1,13 +1,45 @@ | |||
| use crate::globals::Globals; | ||||
| use crate::{backend::*, constants::*, globals::*}; | ||||
| use hyper::Uri; | ||||
| use std::{collections::HashMap, sync::Mutex}; | ||||
| 
 | ||||
| #[cfg(feature = "tls")] | ||||
| // #[cfg(feature = "tls")]
 | ||||
| use std::path::PathBuf; | ||||
| 
 | ||||
| pub fn parse_opts(globals: &mut Globals) { | ||||
|   #[cfg(feature = "tls")] | ||||
|   { | ||||
|     // TODO:
 | ||||
|     globals.tls_cert_path = Some(PathBuf::from(r"localhost.pem")); | ||||
|     globals.tls_cert_key_path = Some(PathBuf::from(r"localhost.pem")); | ||||
|   } | ||||
| pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend>) { | ||||
|   // TODO:
 | ||||
|   globals.listen_sockets = LISTEN_ADDRESSES | ||||
|     .to_vec() | ||||
|     .iter() | ||||
|     .flat_map(|x| { | ||||
|       vec![ | ||||
|         format!("{}:{}", x, HTTP_LISTEN_PORT).parse().unwrap(), | ||||
|         format!("{}:{}", x, HTTPS_LISTEN_PORT).parse().unwrap(), | ||||
|       ] | ||||
|     }) | ||||
|     .collect(); | ||||
|   globals.http_port = Some(HTTP_LISTEN_PORT); | ||||
|   globals.https_port = Some(HTTPS_LISTEN_PORT); | ||||
| 
 | ||||
|   // TODO:
 | ||||
|   let mut map_example: HashMap<String, Uri> = HashMap::new(); | ||||
|   map_example.insert( | ||||
|     "/maps".to_string(), | ||||
|     "https://bing.com/".parse::<Uri>().unwrap(), | ||||
|   ); | ||||
|   backends.insert( | ||||
|     "localhost".to_string(), | ||||
|     Backend { | ||||
|       app_name: "Google except for maps".to_string(), | ||||
|       hostname: "google.com".to_string(), | ||||
|       reverse_proxy: ReverseProxy { | ||||
|         default_destination_uri: "https://google.com/".parse::<Uri>().unwrap(), | ||||
|         destination_uris: Some(map_example), | ||||
|       }, | ||||
|       redirect_to_https: None, // TODO: ここはHTTPの時のみの設定。tlsの存在とは排他的。
 | ||||
| 
 | ||||
|       tls_cert_path: Some(PathBuf::from(r"localhost1.pem")), | ||||
|       tls_cert_key_path: Some(PathBuf::from(r"localhost1.pem")), | ||||
|       server_config: Mutex::new(None), | ||||
|     }, | ||||
|   ); | ||||
| } | ||||
|  |  | |||
|  | @ -1,6 +1,8 @@ | |||
| pub const LISTEN_ADDRESSES: &[&str] = &["127.0.0.1:8443", "[::1]:8443"]; | ||||
| pub const LISTEN_ADDRESSES: &[&str] = &["0.0.0.0", "[::]"]; | ||||
| pub const HTTP_LISTEN_PORT: u32 = 8080; | ||||
| pub const HTTPS_LISTEN_PORT: u32 = 8443; | ||||
| pub const TIMEOUT_SEC: u64 = 10; | ||||
| pub const MAX_CLIENTS: usize = 512; | ||||
| pub const MAX_CONCURRENT_STREAMS: u32 = 16; | ||||
| #[cfg(feature = "tls")] | ||||
| // #[cfg(feature = "tls")]
 | ||||
| pub const CERTS_WATCH_DELAY_SECS: u32 = 10; | ||||
|  |  | |||
|  | @ -1,6 +1,4 @@ | |||
| use std::net::SocketAddr; | ||||
| #[cfg(feature = "tls")] | ||||
| use std::path::PathBuf; | ||||
| use std::sync::{ | ||||
|   atomic::{AtomicUsize, Ordering}, | ||||
|   Arc, | ||||
|  | @ -9,7 +7,9 @@ use tokio::time::Duration; | |||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct Globals { | ||||
|   pub listen_addresses: Vec<SocketAddr>, | ||||
|   pub listen_sockets: Vec<SocketAddr>, | ||||
|   pub http_port: Option<u32>, | ||||
|   pub https_port: Option<u32>, | ||||
| 
 | ||||
|   pub timeout: Duration, | ||||
|   pub max_clients: usize, | ||||
|  | @ -18,12 +18,6 @@ pub struct Globals { | |||
|   pub keepalive: bool, | ||||
| 
 | ||||
|   pub runtime_handle: tokio::runtime::Handle, | ||||
| 
 | ||||
|   #[cfg(feature = "tls")] | ||||
|   pub tls_cert_path: Option<PathBuf>, | ||||
| 
 | ||||
|   #[cfg(feature = "tls")] | ||||
|   pub tls_cert_key_path: Option<PathBuf>, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Default)] | ||||
|  |  | |||
							
								
								
									
										77
									
								
								src/main.rs
									
										
									
									
									
								
							
							
						
						
									
										77
									
								
								src/main.rs
									
										
									
									
									
								
							|  | @ -1,18 +1,23 @@ | |||
| #[global_allocator] | ||||
| static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; | ||||
| 
 | ||||
| mod acceptor; | ||||
| mod backend; | ||||
| mod config; | ||||
| mod constants; | ||||
| mod error; | ||||
| mod globals; | ||||
| mod log; | ||||
| mod proxy; | ||||
| #[cfg(feature = "tls")] | ||||
| mod tls; | ||||
| mod proxy_tls; | ||||
| 
 | ||||
| use crate::{config::parse_opts, constants::*, globals::Globals, log::*, proxy::Proxy}; | ||||
| use std::{io::Write, sync::Arc}; | ||||
| use crate::{ | ||||
|   backend::Backend, config::parse_opts, constants::*, error::*, globals::*, log::*, proxy::Proxy, | ||||
| }; | ||||
| use futures::future::select_all; | ||||
| use hyper::Client; | ||||
| #[cfg(feature = "forward-hyper-trust-dns")] | ||||
| use hyper_trust_dns::TrustDnsResolver; | ||||
| use std::{collections::HashMap, io::Write, sync::Arc}; | ||||
| use tokio::time::Duration; | ||||
| 
 | ||||
| fn main() { | ||||
|  | @ -39,35 +44,61 @@ fn main() { | |||
|   runtime_builder.thread_name("rust-rpxy"); | ||||
|   let runtime = runtime_builder.build().unwrap(); | ||||
| 
 | ||||
|   // TODO:
 | ||||
|   let listen_addresses: Vec<std::net::SocketAddr> = LISTEN_ADDRESSES | ||||
|     .to_vec() | ||||
|     .iter() | ||||
|     .map(|x| x.parse().unwrap()) | ||||
|     .collect(); | ||||
| 
 | ||||
|   runtime.block_on(async { | ||||
|     let mut globals = Globals { | ||||
|       listen_addresses, | ||||
|       listen_sockets: Vec::new(), | ||||
|       http_port: None, | ||||
|       https_port: None, | ||||
|       timeout: Duration::from_secs(TIMEOUT_SEC), | ||||
|       max_clients: MAX_CLIENTS, | ||||
|       clients_count: Default::default(), | ||||
|       max_concurrent_streams: MAX_CONCURRENT_STREAMS, | ||||
|       keepalive: true, | ||||
|       runtime_handle: runtime.handle().clone(), | ||||
| 
 | ||||
|       #[cfg(feature = "tls")] | ||||
|       tls_cert_path: None, | ||||
|       #[cfg(feature = "tls")] | ||||
|       tls_cert_key_path: None, | ||||
|     }; | ||||
| 
 | ||||
|     parse_opts(&mut globals); | ||||
|     let mut backends: HashMap<String, Backend> = HashMap::new(); | ||||
| 
 | ||||
|     let proxy = Proxy { | ||||
|       globals: Arc::new(globals), | ||||
|     }; | ||||
|     proxy.entrypoint().await.unwrap() | ||||
|     parse_opts(&mut globals, &mut backends); | ||||
| 
 | ||||
|     entrypoint(Arc::new(globals), Arc::new(backends)) | ||||
|       .await | ||||
|       .unwrap() | ||||
|   }); | ||||
|   warn!("Exit the program"); | ||||
| } | ||||
| 
 | ||||
| // entrypoint creates and spawns tasks of proxy services
 | ||||
| async fn entrypoint(globals: Arc<Globals>, backends: Arc<HashMap<String, Backend>>) -> Result<()> { | ||||
|   #[cfg(feature = "forward-hyper-trust-dns")] | ||||
|   let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector(); | ||||
|   #[cfg(not(feature = "forward-hyper-trust-dns"))] | ||||
|   let connector = hyper_tls::HttpsConnector::new(); | ||||
|   let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector)); | ||||
| 
 | ||||
|   let addresses = globals.listen_sockets.clone(); | ||||
|   let futures = select_all(addresses.into_iter().map(|addr| { | ||||
|     let mut tls_enabled = false; | ||||
|     if let Some(https_port) = globals.https_port { | ||||
|       tls_enabled = https_port == (addr.port() as u32) | ||||
|     } | ||||
| 
 | ||||
|     info!("Listen address: {:?} (TLS = {})", addr, tls_enabled); | ||||
| 
 | ||||
|     let proxy = Proxy { | ||||
|       globals: globals.clone(), | ||||
|       listening_on: addr, | ||||
|       tls_enabled, | ||||
|       backends: backends.clone(), | ||||
|       forwarder: forwarder.clone(), | ||||
|     }; | ||||
|     globals.runtime_handle.spawn(proxy.start()) | ||||
|   })); | ||||
| 
 | ||||
|   // wait for all future
 | ||||
|   if let (Ok(_), _, _) = futures.await { | ||||
|     error!("Some proxy services are down"); | ||||
|   }; | ||||
| 
 | ||||
|   Ok(()) | ||||
| } | ||||
|  |  | |||
							
								
								
									
										177
									
								
								src/proxy.rs
									
										
									
									
									
								
							
							
						
						
									
										177
									
								
								src/proxy.rs
									
										
									
									
									
								
							|  | @ -1,38 +1,159 @@ | |||
| use crate::{acceptor::PacketAcceptor, error::*, globals::Globals, log::*}; | ||||
| use futures::future::select_all; | ||||
| use hyper::Client; | ||||
| #[cfg(feature = "forward-hyper-trust-dns")] | ||||
| use hyper_trust_dns::TrustDnsResolver; | ||||
| use std::sync::Arc; | ||||
| use crate::{backend::Backend, error::*, globals::Globals, log::*}; | ||||
| use futures::{ | ||||
|   select, | ||||
|   task::{Context, Poll}, | ||||
|   Future, FutureExt, | ||||
| }; | ||||
| use hyper::{ | ||||
|   client::connect::Connect, | ||||
|   http, | ||||
|   server::conn::Http, | ||||
|   service::{service_fn, Service}, | ||||
|   Body, Client, HeaderMap, Method, Request, Response, StatusCode, | ||||
| }; | ||||
| use std::{collections::HashMap, net::SocketAddr, pin::Pin, sync::Arc}; | ||||
| use tokio::{ | ||||
|   io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, | ||||
|   net::TcpListener, | ||||
|   runtime::Handle, | ||||
|   time::Duration, | ||||
| }; | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct Proxy { | ||||
| #[allow(clippy::unnecessary_wraps)] | ||||
| fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> { | ||||
|   let response = Response::builder() | ||||
|     .status(status_code) | ||||
|     .body(Body::empty()) | ||||
|     .unwrap(); | ||||
|   Ok(response) | ||||
| } | ||||
| 
 | ||||
| #[derive(Clone, Debug)] | ||||
| pub struct LocalExecutor { | ||||
|   runtime_handle: Handle, | ||||
| } | ||||
| 
 | ||||
| impl LocalExecutor { | ||||
|   fn new(runtime_handle: Handle) -> Self { | ||||
|     LocalExecutor { runtime_handle } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl<F> hyper::rt::Executor<F> for LocalExecutor | ||||
| where | ||||
|   F: std::future::Future + Send + 'static, | ||||
|   F::Output: Send, | ||||
| { | ||||
|   fn execute(&self, fut: F) { | ||||
|     self.runtime_handle.spawn(fut); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Clone)] | ||||
| pub struct Proxy<T> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub listening_on: SocketAddr, | ||||
|   pub tls_enabled: bool,                       // TCP待受がTLSかどうか
 | ||||
|   pub backends: Arc<HashMap<String, Backend>>, // TODO: hyper::uriで抜いたhostnameで引っ掛ける。Stringでいいのか?
 | ||||
|   pub forwarder: Arc<Client<T>>, | ||||
|   pub globals: Arc<Globals>, | ||||
| } | ||||
| impl Proxy { | ||||
|   pub async fn entrypoint(self) -> Result<()> { | ||||
|     let addresses = self.globals.listen_addresses.clone(); | ||||
|     let futures = select_all(addresses.into_iter().map(|addr| { | ||||
|       info!("Listen address: {:?}", addr); | ||||
| 
 | ||||
|       #[cfg(feature = "forward-hyper-trust-dns")] | ||||
|       let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector(); | ||||
|       #[cfg(not(feature = "forward-hyper-trust-dns"))] | ||||
|       let connector = hyper_tls::HttpsConnector::new(); | ||||
|       let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector)); | ||||
| // TODO: ここでbackendの名前単位でリクエストを分岐させる
 | ||||
| async fn handle_request( | ||||
|   req: Request<Body>, | ||||
|   client_ip: SocketAddr, | ||||
|   globals: Arc<Globals>, | ||||
| ) -> Result<Response<Body>, http::Error> { | ||||
|   // http_error(StatusCode::NOT_FOUND)
 | ||||
|   debug!("{:?}", req); | ||||
|   // if req.version() == hyper::Version::HTTP_11 {
 | ||||
|   //   Ok(Response::new(Body::from("Hello World")))
 | ||||
|   // } else {
 | ||||
|   // Note: it's usually better to return a Response
 | ||||
|   // with an appropriate StatusCode instead of an Err.
 | ||||
|   // Err("not HTTP/1.1, abort connection")
 | ||||
|   http_error(StatusCode::NOT_FOUND) | ||||
|   // }
 | ||||
|   // });
 | ||||
| } | ||||
| 
 | ||||
|       let acceptor = PacketAcceptor { | ||||
|         listening_on: addr, | ||||
|         globals: self.globals.clone(), | ||||
|         forwarder, | ||||
|       }; | ||||
|       self.globals.runtime_handle.spawn(acceptor.start()) | ||||
|     })); | ||||
| impl<T> Proxy<T> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr) | ||||
|   where | ||||
|     I: AsyncRead + AsyncWrite + Send + Unpin + 'static, | ||||
|   { | ||||
|     let clients_count = self.globals.clients_count.clone(); | ||||
|     if clients_count.increment() > self.globals.max_clients { | ||||
|       clients_count.decrement(); | ||||
|       return; | ||||
|     } | ||||
| 
 | ||||
|     // wait for all future
 | ||||
|     if let (Ok(_), _, _) = futures.await { | ||||
|       error!("Some packet acceptors are down"); | ||||
|     self.globals.runtime_handle.clone().spawn(async move { | ||||
|       tokio::time::timeout( | ||||
|         self.globals.timeout + Duration::from_secs(1), | ||||
|         // server.serve_connection(stream, self),
 | ||||
|         server.serve_connection( | ||||
|           stream, | ||||
|           service_fn(move |req: Request<Body>| { | ||||
|             handle_request(req, peer_addr, self.globals.clone()) | ||||
|           }), | ||||
|         ), | ||||
|       ) | ||||
|       .await | ||||
|       .ok(); | ||||
| 
 | ||||
|       clients_count.decrement(); | ||||
|     }); | ||||
|   } | ||||
| 
 | ||||
|   async fn start_without_tls( | ||||
|     self, | ||||
|     listener: TcpListener, | ||||
|     server: Http<LocalExecutor>, | ||||
|   ) -> Result<()> { | ||||
|     let listener_service = async { | ||||
|       while let Ok((stream, _client_addr)) = listener.accept().await { | ||||
|         self | ||||
|           .clone() | ||||
|           .client_serve(stream, server.clone(), _client_addr) | ||||
|           .await; | ||||
|       } | ||||
|       Ok(()) as Result<()> | ||||
|     }; | ||||
|     listener_service.await?; | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   pub async fn start(self) -> Result<()> { | ||||
|     let tcp_listener = TcpListener::bind(&self.listening_on).await?; | ||||
| 
 | ||||
|     let mut server = Http::new(); | ||||
|     server.http1_keep_alive(self.globals.keepalive); | ||||
|     server.http2_max_concurrent_streams(self.globals.max_concurrent_streams); | ||||
|     server.pipeline_flush(true); | ||||
|     let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); | ||||
|     let server = server.with_executor(executor); | ||||
| 
 | ||||
|     if self.tls_enabled { | ||||
|       info!( | ||||
|         "Start TCP proxy serving with HTTPS request for configured host names: {:?}", | ||||
|         tcp_listener.local_addr()? | ||||
|       ); | ||||
|       // #[cfg(feature = "tls")]
 | ||||
|       self.start_with_tls(tcp_listener, server).await?; | ||||
|     } else { | ||||
|       info!( | ||||
|         "Start TCP proxy serving with HTTP request for configured host names: {:?}", | ||||
|         tcp_listener.local_addr()? | ||||
|       ); | ||||
|       self.start_without_tls(tcp_listener, server).await?; | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
|   } | ||||
|  |  | |||
							
								
								
									
										77
									
								
								src/proxy_tls.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								src/proxy_tls.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,77 @@ | |||
| use crate::{ | ||||
|   constants::CERTS_WATCH_DELAY_SECS, | ||||
|   error::*, | ||||
|   log::*, | ||||
|   proxy::{LocalExecutor, Proxy}, | ||||
| }; | ||||
| use futures::{future::FutureExt, join, select}; | ||||
| use hyper::{client::connect::Connect, server::conn::Http}; | ||||
| use std::{sync::Arc, time::Duration}; | ||||
| use tokio::net::TcpListener; | ||||
| 
 | ||||
| impl<T> Proxy<T> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub async fn start_with_tls( | ||||
|     self, | ||||
|     listener: TcpListener, | ||||
|     server: Http<LocalExecutor>, | ||||
|   ) -> Result<()> { | ||||
|     let cert_service = async { | ||||
|       info!("Start cert watch service for {}", self.listening_on); | ||||
|       loop { | ||||
|         for (hostname, backend) in self.backends.iter() { | ||||
|           if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { | ||||
|             if let Err(_e) = backend.update_server_config().await { | ||||
|               warn!("Failed to update certs for {}", hostname); | ||||
|             } | ||||
|           } | ||||
|         } | ||||
|         tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     let listener_service = async { | ||||
|       loop { | ||||
|         select! { | ||||
|           tcp_cnx = listener.accept().fuse() => { | ||||
|             if tcp_cnx.is_err() { | ||||
|               continue; | ||||
|             } | ||||
|             let (raw_stream, _client_addr) = tcp_cnx.unwrap(); | ||||
| 
 | ||||
|             // First check SNI
 | ||||
|             let rustls_acceptor = rustls::server::Acceptor::new().unwrap(); | ||||
|             let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls_acceptor, raw_stream); | ||||
|             let start = acceptor.await.unwrap(); | ||||
|             let client_hello = start.client_hello(); | ||||
|             debug!("SNI in ClientHello: {:?}", client_hello.server_name()); | ||||
|             // Find server config for given SNI
 | ||||
|             let svn = if let Some(svn) = client_hello.server_name() { | ||||
|               svn | ||||
|             } else { | ||||
|               info!("No SNI in ClientHello"); | ||||
|               continue; | ||||
|             }; | ||||
|             let backend_serve = if let Some(backend_serve) = self.backends.get(svn){ | ||||
|               backend_serve | ||||
|             } else { | ||||
|               info!("No configuration for the server name {} given in client_hello", svn); | ||||
|               continue; | ||||
|             }; | ||||
|             let server_config = backend_serve.get_tls_server_config(); | ||||
|             // Finally serve the TLS connection
 | ||||
|             if let Ok(stream) = start.into_stream(Arc::new(server_config.unwrap())).await { | ||||
|               self.clone().client_serve(stream, server.clone(), _client_addr).await | ||||
|             } | ||||
|           } | ||||
|           complete => break
 | ||||
|         } | ||||
|       } | ||||
|       Ok(()) as Result<()> | ||||
|     }; | ||||
| 
 | ||||
|     join!(listener_service, cert_service).0 | ||||
|   } | ||||
| } | ||||
							
								
								
									
										176
									
								
								src/tls.rs
									
										
									
									
									
								
							
							
						
						
									
										176
									
								
								src/tls.rs
									
										
									
									
									
								
							|  | @ -1,176 +0,0 @@ | |||
| use std::fs::File; | ||||
| use std::io::{self, BufReader, Cursor, Read}; | ||||
| use std::path::Path; | ||||
| use std::sync::Arc; | ||||
| use std::time::Duration; | ||||
| 
 | ||||
| use futures::{future::FutureExt, join, select}; | ||||
| use hyper::client::connect::Connect; | ||||
| use hyper::server::conn::Http; | ||||
| use tokio::{ | ||||
|   net::TcpListener, | ||||
|   sync::mpsc::{self, Receiver}, | ||||
| }; | ||||
| use tokio_rustls::{ | ||||
|   rustls::{Certificate, PrivateKey, ServerConfig}, | ||||
|   TlsAcceptor, | ||||
| }; | ||||
| 
 | ||||
| use crate::acceptor::{LocalExecutor, PacketAcceptor}; | ||||
| use crate::constants::CERTS_WATCH_DELAY_SECS; | ||||
| use crate::error::*; | ||||
| 
 | ||||
| pub fn create_tls_acceptor<P, P2>(certs_path: P, certs_keys_path: P2) -> io::Result<TlsAcceptor> | ||||
| where | ||||
|   P: AsRef<Path>, | ||||
|   P2: AsRef<Path>, | ||||
| { | ||||
|   let certs: Vec<_> = { | ||||
|     let certs_path_str = certs_path.as_ref().display().to_string(); | ||||
|     let mut reader = BufReader::new(File::open(certs_path).map_err(|e| { | ||||
|       io::Error::new( | ||||
|         e.kind(), | ||||
|         format!( | ||||
|           "Unable to load the certificates [{}]: {}", | ||||
|           certs_path_str, e | ||||
|         ), | ||||
|       ) | ||||
|     })?); | ||||
|     rustls_pemfile::certs(&mut reader).map_err(|_| { | ||||
|       io::Error::new( | ||||
|         io::ErrorKind::InvalidInput, | ||||
|         "Unable to parse the certificates", | ||||
|       ) | ||||
|     })? | ||||
|   } | ||||
|   .drain(..) | ||||
|   .map(Certificate) | ||||
|   .collect(); | ||||
|   let certs_keys: Vec<_> = { | ||||
|     let certs_keys_path_str = certs_keys_path.as_ref().display().to_string(); | ||||
|     let encoded_keys = { | ||||
|       let mut encoded_keys = vec![]; | ||||
|       File::open(certs_keys_path) | ||||
|         .map_err(|e| { | ||||
|           io::Error::new( | ||||
|             e.kind(), | ||||
|             format!( | ||||
|               "Unable to load the certificate keys [{}]: {}", | ||||
|               certs_keys_path_str, e | ||||
|             ), | ||||
|           ) | ||||
|         })? | ||||
|         .read_to_end(&mut encoded_keys)?; | ||||
|       encoded_keys | ||||
|     }; | ||||
|     let mut reader = Cursor::new(encoded_keys); | ||||
|     let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| { | ||||
|       io::Error::new( | ||||
|         io::ErrorKind::InvalidInput, | ||||
|         "Unable to parse the certificates private keys (PKCS8)", | ||||
|       ) | ||||
|     })?; | ||||
|     reader.set_position(0); | ||||
|     let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| { | ||||
|       io::Error::new( | ||||
|         io::ErrorKind::InvalidInput, | ||||
|         "Unable to parse the certificates private keys (RSA)", | ||||
|       ) | ||||
|     })?; | ||||
|     let mut keys = pkcs8_keys; | ||||
|     keys.append(&mut rsa_keys); | ||||
|     if keys.is_empty() { | ||||
|       return Err(io::Error::new( | ||||
|         io::ErrorKind::InvalidInput, | ||||
|         "No private keys found - Make sure that they are in PKCS#8/PEM format", | ||||
|       )); | ||||
|     } | ||||
|     keys.drain(..).map(PrivateKey).collect() | ||||
|   }; | ||||
| 
 | ||||
|   let mut server_config = certs_keys | ||||
|     .into_iter() | ||||
|     .find_map(|certs_key| { | ||||
|       let server_config_builder = ServerConfig::builder() | ||||
|         .with_safe_defaults() | ||||
|         .with_no_client_auth(); | ||||
|       if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) { | ||||
|         Some(found_config) | ||||
|       } else { | ||||
|         None | ||||
|       } | ||||
|     }) | ||||
|     .ok_or_else(|| { | ||||
|       io::Error::new( | ||||
|         io::ErrorKind::InvalidInput, | ||||
|         "Unable to find a valid certificate and key", | ||||
|       ) | ||||
|     })?; | ||||
|   server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; | ||||
|   Ok(TlsAcceptor::from(Arc::new(server_config))) | ||||
| } | ||||
| 
 | ||||
| impl<T> PacketAcceptor<T> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   async fn start_https_service( | ||||
|     self, | ||||
|     mut tls_acceptor_receiver: Receiver<TlsAcceptor>, | ||||
|     listener: TcpListener, | ||||
|     server: Http<LocalExecutor>, | ||||
|   ) -> Result<()> { | ||||
|     let mut tls_acceptor: Option<TlsAcceptor> = None; | ||||
|     let listener_service = async { | ||||
|       loop { | ||||
|         select! { | ||||
|           tcp_cnx = listener.accept().fuse() => { | ||||
|             if tls_acceptor.is_none() || tcp_cnx.is_err() { | ||||
|               continue; | ||||
|             } | ||||
|             let (raw_stream, _client_addr) = tcp_cnx.unwrap(); | ||||
|             if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await { | ||||
|               self.clone().client_serve(stream, server.clone(), _client_addr).await | ||||
|             } | ||||
|           } | ||||
|           new_tls_acceptor = tls_acceptor_receiver.recv().fuse() => { | ||||
|             if new_tls_acceptor.is_none() { | ||||
|                 break; | ||||
|             } | ||||
|             tls_acceptor = new_tls_acceptor; | ||||
|           } | ||||
|           complete => break
 | ||||
|         } | ||||
|       } | ||||
|       Ok(()) as Result<()> | ||||
|     }; | ||||
|     listener_service.await?; | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   pub async fn start_with_tls( | ||||
|     self, | ||||
|     listener: TcpListener, | ||||
|     server: Http<LocalExecutor>, | ||||
|   ) -> Result<()> { | ||||
|     let certs_path = self.globals.tls_cert_path.as_ref().unwrap().clone(); | ||||
|     let certs_keys_path = self.globals.tls_cert_key_path.as_ref().unwrap().clone(); | ||||
|     let (tls_acceptor_sender, tls_acceptor_receiver) = mpsc::channel(1); | ||||
|     let https_service = self.start_https_service(tls_acceptor_receiver, listener, server); | ||||
|     let cert_service = async { | ||||
|       loop { | ||||
|         match create_tls_acceptor(&certs_path, &certs_keys_path) { | ||||
|           Ok(tls_acceptor) => { | ||||
|             if tls_acceptor_sender.send(tls_acceptor).await.is_err() { | ||||
|               break; | ||||
|             } | ||||
|           } | ||||
|           Err(e) => eprintln!("TLS certificates error: {}", e), | ||||
|         } | ||||
|         tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; | ||||
|       } | ||||
|       Ok(()) as Result<()> | ||||
|     }; | ||||
|     return join!(https_service, cert_service).0; | ||||
|   } | ||||
| } | ||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jun Kurihara
				Jun Kurihara