diff --git a/Cargo.toml b/Cargo.toml index d333ff1..68f94c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -48,6 +48,7 @@ hyper-trust-dns = { version = "0.4.2", default-features = false, features = [ "rustls-webpki", ], optional = true } hyper-tls = "0.5.0" +rustls = "0.20.6" [dev-dependencies] diff --git a/config-example.toml b/config-example.toml index b7b3d51..38cb49e 100644 --- a/config-example.toml +++ b/config-example.toml @@ -3,29 +3,36 @@ # rust-rxpy configuration # # # ######################################## - -################################## -# Global settings # -################################## - -## Address to listen to. -listen_addresses = ['127.0.0.1:50844', '[::1]:50844'] - -[tls] -tls_cert_path = 'localhost.pem' -tls_cert_key_path = 'localhost.pem' +################################### +# Global settings # +################################### +http_port = 8080 +https_port = 8443 ################################### # Backend settings # ################################### -[[backend]] -domain = 'localhost' -## List of destinations to send data to. -## At this point, round-robin is used for load-balancing if multiple URLs are specified. -destination = ['http://192.168.0.1:3000/', 'https://192.168.0.2:3000'] -allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24'] -denyhosts = ['*'] [[backend]] -domain = '127.0.0.1' -destination = 'https://www.google.com/' +app_name = 'localhost' # this should be option, if null then same as hostname +hostname = 'localhost' +redirect_to_https = true +reverse_proxy = [ + { path = '*', destination = 'https://192.168.10.0:3000' }, + { path = '/path/to', destination = 'https://192.168.10.1:4000/path/to' }, +] +## List of destinations to send data to. +## At this point, round-robin is used for load-balancing if multiple URLs are specified. +allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24'] +denyhosts = ['*'] +tls_cert_path = 'localhost1.pem' +tls_cert_key_path = 'localhost1.pem' + + +[[backend]] +app_name = 'locahost_application' +hostname = 'localhost.localdomain' +redirect_to_https = true +reverse_proxy = [{ path = '/', destination = 'https://www.google.com/' }] +tls_cert_path = 'localhost2.pem' +tls_cert_key_path = 'localhost2.pem' diff --git a/src/acceptor.rs b/src/acceptor.rs deleted file mode 100644 index f79ad86..0000000 --- a/src/acceptor.rs +++ /dev/null @@ -1,289 +0,0 @@ -use crate::{error::*, globals::Globals, log::*}; -use futures::{ - task::{Context, Poll}, - Future, -}; -use hyper::{ - client::connect::Connect, - http, - server::conn::Http, - service::{service_fn, Service}, - Body, Client, HeaderMap, Method, Request, Response, StatusCode, -}; -use std::{net::SocketAddr, pin::Pin, sync::Arc}; -use tokio::{ - io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, - net::TcpListener, - runtime::Handle, - time::Duration, -}; - -#[allow(clippy::unnecessary_wraps)] -fn http_error(status_code: StatusCode) -> Result, http::Error> { - let response = Response::builder() - .status(status_code) - .body(Body::empty()) - .unwrap(); - Ok(response) -} - -#[derive(Clone, Debug)] -pub struct LocalExecutor { - runtime_handle: Handle, -} - -impl LocalExecutor { - fn new(runtime_handle: Handle) -> Self { - LocalExecutor { runtime_handle } - } -} - -impl hyper::rt::Executor for LocalExecutor -where - F: std::future::Future + Send + 'static, - F::Output: Send, -{ - fn execute(&self, fut: F) { - self.runtime_handle.spawn(fut); - } -} - -#[derive(Clone)] -pub struct PacketAcceptor -where - T: Connect + Clone + Sync + Send + 'static, -{ - pub listening_on: SocketAddr, - pub forwarder: Arc>, - pub globals: Arc, -} - -// impl Service> for PacketAcceptor -// where -// T: Connect + Clone + Sync + Send + 'static, -// { -// type Response = Response; - -// type Error = http::Error; -// type Future = Pin> + Send>>; - -// fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { -// Poll::Ready(Ok(())) -// } - -// fn call(&mut self, req: Request) -> Self::Future { -// debug!( -// "serving {:?} {:?} request to {:?}", -// req.version(), -// req.method(), -// req.uri() -// ); -// let self_inner = self.clone(); - -// // 1. check uri (domain queried host name) -// // 2. build uri to forwarding target destination -// // 3. build request from uri and body -// // 4. send request to forwarding target - -// if *req.method() == Method::GET { -// Box::pin(async move { -// // let uri = req.uri(); -// let target_uri = hyper::Uri::builder() -// .scheme("https") -// .authority("www.google.com") -// .path_and_query("/") -// .build() -// .unwrap(); -// println!("{:?}", target_uri); -// match self_inner.forwarder.get(target_uri).await { -// Ok(res) => Ok(res), -// Err(e) => { -// error!("{:?}", e); -// http_error(StatusCode::INTERNAL_SERVER_ERROR) -// } -// } -// }) -// } else { -// // let globals = &self.doh.globals; -// // let self_inner = self.clone(); -// // if req.uri().path() == globals.path { -// // Box::pin(async move { -// // let mut subscriber = None; -// // if self_inner.doh.globals.enable_auth_target { -// // subscriber = match auth::authenticate( -// // &self_inner.doh.globals, -// // &req, -// // ValidationLocation::Target, -// // &self_inner.peer_addr, -// // ) { -// // Ok((sub, aud)) => { -// // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud); -// // sub -// // } -// // Err(e) => { -// // error!("{:?}", e); -// // return Ok(e); -// // } -// // }; -// // } -// // match *req.method() { -// // Method::POST => self_inner.doh.serve_post(req, subscriber).await, -// // Method::GET => self_inner.doh.serve_get(req, subscriber).await, -// // _ => http_error(StatusCode::METHOD_NOT_ALLOWED), -// // } -// // }) -// // } else if req.uri().path() == globals.odoh_configs_path { -// // match *req.method() { -// // Method::GET => Box::pin(async move { self_inner.doh.serve_odoh_configs().await }), -// // _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), -// // } -// // } else { -// // #[cfg(not(feature = "odoh-proxy"))] -// // { -// // Box::pin(async { http_error(StatusCode::NOT_FOUND) }) -// // } -// // #[cfg(feature = "odoh-proxy")] -// // { -// // if req.uri().path() == globals.odoh_proxy_path { -// // Box::pin(async move { -// // let mut subscriber = None; -// // if self_inner.doh.globals.enable_auth_proxy { -// // subscriber = match auth::authenticate( -// // &self_inner.doh.globals, -// // &req, -// // ValidationLocation::Proxy, -// // &self_inner.peer_addr, -// // ) { -// // Ok((sub, aud)) => { -// // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud); -// // sub -// // } -// // Err(e) => { -// // error!("{:?}", e); -// // return Ok(e); -// // } -// // }; -// // } -// // // Draft: https://datatracker.ietf.org/doc/html/draft-pauly-dprive-oblivious-doh-11 -// // // Golang impl.: https://github.com/cloudflare/odoh-server-go -// // // Based on the draft and Golang implementation, only post method is allowed. -// // match *req.method() { -// // Method::POST => self_inner.doh.serve_odoh_proxy_post(req, subscriber).await, -// // _ => http_error(StatusCode::METHOD_NOT_ALLOWED), -// // } -// // }) -// // } else { -// Box::pin(async { http_error(StatusCode::NOT_FOUND) }) -// } -// // } -// // } -// // } -// } -// } - -async fn handle_request( - req: Request, - client_ip: SocketAddr, - globals: Arc, -) -> Result, http::Error> { - // http_error(StatusCode::NOT_FOUND) - debug!("{:?}", req); - // if req.version() == hyper::Version::HTTP_11 { - // Ok(Response::new(Body::from("Hello World"))) - // } else { - // Note: it's usually better to return a Response - // with an appropriate StatusCode instead of an Err. - // Err("not HTTP/1.1, abort connection") - http_error(StatusCode::NOT_FOUND) - // } - // }); -} - -impl PacketAcceptor -where - T: Connect + Clone + Sync + Send + 'static, -{ - pub async fn client_serve(self, stream: I, server: Http, peer_addr: SocketAddr) - where - I: AsyncRead + AsyncWrite + Send + Unpin + 'static, - { - let clients_count = self.globals.clients_count.clone(); - if clients_count.increment() > self.globals.max_clients { - clients_count.decrement(); - return; - } - - self.globals.runtime_handle.clone().spawn(async move { - tokio::time::timeout( - self.globals.timeout + Duration::from_secs(1), - // server.serve_connection(stream, self), - server.serve_connection( - stream, - service_fn(move |req: Request| { - handle_request(req, peer_addr, self.globals.clone()) - }), - ), - ) - .await - .ok(); - - clients_count.decrement(); - }); - } - - async fn start_without_tls( - self, - listener: TcpListener, - server: Http, - ) -> Result<()> { - let listener_service = async { - while let Ok((stream, _client_addr)) = listener.accept().await { - self - .clone() - .client_serve(stream, server.clone(), _client_addr) - .await; - } - Ok(()) as Result<()> - }; - listener_service.await?; - Ok(()) - } - - pub async fn start(self) -> Result<()> { - let tcp_listener = TcpListener::bind(&self.listening_on).await?; - - let mut server = Http::new(); - server.http1_keep_alive(self.globals.keepalive); - server.http2_max_concurrent_streams(self.globals.max_concurrent_streams); - server.pipeline_flush(true); - let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); - let server = server.with_executor(executor); - - let tls_enabled: bool; - #[cfg(not(feature = "tls"))] - { - tls_enabled = false; - } - #[cfg(feature = "tls")] - { - tls_enabled = - self.globals.tls_cert_path.is_some() && self.globals.tls_cert_key_path.is_some(); - } - if tls_enabled { - info!( - "Start server listening on TCP with TLS: {:?}", - tcp_listener.local_addr()? - ); - #[cfg(feature = "tls")] - self.start_with_tls(tcp_listener, server).await?; - } else { - info!( - "Start server listening on TCP: {:?}", - tcp_listener.local_addr()? - ); - self.start_without_tls(tcp_listener, server).await?; - } - - Ok(()) - } -} diff --git a/src/backend.rs b/src/backend.rs new file mode 100644 index 0000000..bcf2de4 --- /dev/null +++ b/src/backend.rs @@ -0,0 +1,134 @@ +use crate::log::*; +use std::{ + collections::HashMap, + fs::File, + io::{self, BufReader, Cursor, Read}, + path::PathBuf, + sync::Mutex, +}; +use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; + +pub struct Backend { + pub app_name: String, + pub hostname: String, + pub reverse_proxy: ReverseProxy, + pub redirect_to_https: Option, + pub tls_cert_path: Option, + pub tls_cert_key_path: Option, + pub server_config: Mutex>, +} + +#[derive(Debug, Clone)] +pub struct ReverseProxy { + pub default_destination_uri: hyper::Uri, + pub destination_uris: Option>, // TODO: url pathで引っ掛ける。 +} + +impl Backend { + pub fn get_tls_server_config(&self) -> Option { + let lock = self.server_config.lock(); + if let Ok(opt) = lock { + let opt_clone = opt.clone(); + if let Some(sc) = opt_clone { + return Some(sc); + } + } + None + } + pub async fn update_server_config(&self) -> io::Result<()> { + debug!("Update TLS server config"); + let certs_path = self.tls_cert_path.as_ref().unwrap(); + let certs_keys_path = self.tls_cert_key_path.as_ref().unwrap(); + let certs: Vec<_> = { + let certs_path_str = certs_path.display().to_string(); + let mut reader = BufReader::new(File::open(certs_path).map_err(|e| { + io::Error::new( + e.kind(), + format!( + "Unable to load the certificates [{}]: {}", + certs_path_str, e + ), + ) + })?); + rustls_pemfile::certs(&mut reader).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to parse the certificates", + ) + })? + } + .drain(..) + .map(Certificate) + .collect(); + let certs_keys: Vec<_> = { + let certs_keys_path_str = certs_keys_path.display().to_string(); + let encoded_keys = { + let mut encoded_keys = vec![]; + File::open(certs_keys_path) + .map_err(|e| { + io::Error::new( + e.kind(), + format!( + "Unable to load the certificate keys [{}]: {}", + certs_keys_path_str, e + ), + ) + })? + .read_to_end(&mut encoded_keys)?; + encoded_keys + }; + let mut reader = Cursor::new(encoded_keys); + let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to parse the certificates private keys (PKCS8)", + ) + })?; + reader.set_position(0); + let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to parse the certificates private keys (RSA)", + ) + })?; + let mut keys = pkcs8_keys; + keys.append(&mut rsa_keys); + if keys.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "No private keys found - Make sure that they are in PKCS#8/PEM format", + )); + } + keys.drain(..).map(PrivateKey).collect() + }; + + let mut server_config = certs_keys + .into_iter() + .find_map(|certs_key| { + let server_config_builder = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth(); + if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) { + Some(found_config) + } else { + None + } + }) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to find a valid certificate and key", + ) + })?; + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + + if let Ok(mut config_store) = self.server_config.lock() { + *config_store = Some(server_config); + } else { + error!("Some thing wrong to write into mutex") + } + + // server_config; + Ok(()) + } +} diff --git a/src/config.rs b/src/config.rs index 51503be..800d645 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,13 +1,45 @@ -use crate::globals::Globals; +use crate::{backend::*, constants::*, globals::*}; +use hyper::Uri; +use std::{collections::HashMap, sync::Mutex}; -#[cfg(feature = "tls")] +// #[cfg(feature = "tls")] use std::path::PathBuf; -pub fn parse_opts(globals: &mut Globals) { - #[cfg(feature = "tls")] - { - // TODO: - globals.tls_cert_path = Some(PathBuf::from(r"localhost.pem")); - globals.tls_cert_key_path = Some(PathBuf::from(r"localhost.pem")); - } +pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap) { + // TODO: + globals.listen_sockets = LISTEN_ADDRESSES + .to_vec() + .iter() + .flat_map(|x| { + vec![ + format!("{}:{}", x, HTTP_LISTEN_PORT).parse().unwrap(), + format!("{}:{}", x, HTTPS_LISTEN_PORT).parse().unwrap(), + ] + }) + .collect(); + globals.http_port = Some(HTTP_LISTEN_PORT); + globals.https_port = Some(HTTPS_LISTEN_PORT); + + // TODO: + let mut map_example: HashMap = HashMap::new(); + map_example.insert( + "/maps".to_string(), + "https://bing.com/".parse::().unwrap(), + ); + backends.insert( + "localhost".to_string(), + Backend { + app_name: "Google except for maps".to_string(), + hostname: "google.com".to_string(), + reverse_proxy: ReverseProxy { + default_destination_uri: "https://google.com/".parse::().unwrap(), + destination_uris: Some(map_example), + }, + redirect_to_https: None, // TODO: ここはHTTPの時のみの設定。tlsの存在とは排他的。 + + tls_cert_path: Some(PathBuf::from(r"localhost1.pem")), + tls_cert_key_path: Some(PathBuf::from(r"localhost1.pem")), + server_config: Mutex::new(None), + }, + ); } diff --git a/src/constants.rs b/src/constants.rs index 434dd55..1f5c5be 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,6 +1,8 @@ -pub const LISTEN_ADDRESSES: &[&str] = &["127.0.0.1:8443", "[::1]:8443"]; +pub const LISTEN_ADDRESSES: &[&str] = &["0.0.0.0", "[::]"]; +pub const HTTP_LISTEN_PORT: u32 = 8080; +pub const HTTPS_LISTEN_PORT: u32 = 8443; pub const TIMEOUT_SEC: u64 = 10; pub const MAX_CLIENTS: usize = 512; pub const MAX_CONCURRENT_STREAMS: u32 = 16; -#[cfg(feature = "tls")] +// #[cfg(feature = "tls")] pub const CERTS_WATCH_DELAY_SECS: u32 = 10; diff --git a/src/globals.rs b/src/globals.rs index 2398988..5508c7d 100644 --- a/src/globals.rs +++ b/src/globals.rs @@ -1,6 +1,4 @@ use std::net::SocketAddr; -#[cfg(feature = "tls")] -use std::path::PathBuf; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -9,7 +7,9 @@ use tokio::time::Duration; #[derive(Debug, Clone)] pub struct Globals { - pub listen_addresses: Vec, + pub listen_sockets: Vec, + pub http_port: Option, + pub https_port: Option, pub timeout: Duration, pub max_clients: usize, @@ -18,12 +18,6 @@ pub struct Globals { pub keepalive: bool, pub runtime_handle: tokio::runtime::Handle, - - #[cfg(feature = "tls")] - pub tls_cert_path: Option, - - #[cfg(feature = "tls")] - pub tls_cert_key_path: Option, } #[derive(Debug, Clone, Default)] diff --git a/src/main.rs b/src/main.rs index 77e2d37..0047535 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,18 +1,23 @@ #[global_allocator] static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; -mod acceptor; +mod backend; mod config; mod constants; mod error; mod globals; mod log; mod proxy; -#[cfg(feature = "tls")] -mod tls; +mod proxy_tls; -use crate::{config::parse_opts, constants::*, globals::Globals, log::*, proxy::Proxy}; -use std::{io::Write, sync::Arc}; +use crate::{ + backend::Backend, config::parse_opts, constants::*, error::*, globals::*, log::*, proxy::Proxy, +}; +use futures::future::select_all; +use hyper::Client; +#[cfg(feature = "forward-hyper-trust-dns")] +use hyper_trust_dns::TrustDnsResolver; +use std::{collections::HashMap, io::Write, sync::Arc}; use tokio::time::Duration; fn main() { @@ -39,35 +44,61 @@ fn main() { runtime_builder.thread_name("rust-rpxy"); let runtime = runtime_builder.build().unwrap(); - // TODO: - let listen_addresses: Vec = LISTEN_ADDRESSES - .to_vec() - .iter() - .map(|x| x.parse().unwrap()) - .collect(); - runtime.block_on(async { let mut globals = Globals { - listen_addresses, + listen_sockets: Vec::new(), + http_port: None, + https_port: None, timeout: Duration::from_secs(TIMEOUT_SEC), max_clients: MAX_CLIENTS, clients_count: Default::default(), max_concurrent_streams: MAX_CONCURRENT_STREAMS, keepalive: true, runtime_handle: runtime.handle().clone(), - - #[cfg(feature = "tls")] - tls_cert_path: None, - #[cfg(feature = "tls")] - tls_cert_key_path: None, }; - parse_opts(&mut globals); + let mut backends: HashMap = HashMap::new(); - let proxy = Proxy { - globals: Arc::new(globals), - }; - proxy.entrypoint().await.unwrap() + parse_opts(&mut globals, &mut backends); + + entrypoint(Arc::new(globals), Arc::new(backends)) + .await + .unwrap() }); warn!("Exit the program"); } + +// entrypoint creates and spawns tasks of proxy services +async fn entrypoint(globals: Arc, backends: Arc>) -> Result<()> { + #[cfg(feature = "forward-hyper-trust-dns")] + let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector(); + #[cfg(not(feature = "forward-hyper-trust-dns"))] + let connector = hyper_tls::HttpsConnector::new(); + let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector)); + + let addresses = globals.listen_sockets.clone(); + let futures = select_all(addresses.into_iter().map(|addr| { + let mut tls_enabled = false; + if let Some(https_port) = globals.https_port { + tls_enabled = https_port == (addr.port() as u32) + } + + info!("Listen address: {:?} (TLS = {})", addr, tls_enabled); + + let proxy = Proxy { + globals: globals.clone(), + listening_on: addr, + tls_enabled, + backends: backends.clone(), + forwarder: forwarder.clone(), + }; + globals.runtime_handle.spawn(proxy.start()) + })); + + // wait for all future + if let (Ok(_), _, _) = futures.await { + error!("Some proxy services are down"); + }; + + Ok(()) +} diff --git a/src/proxy.rs b/src/proxy.rs index a4b1cfc..2f8a325 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -1,38 +1,159 @@ -use crate::{acceptor::PacketAcceptor, error::*, globals::Globals, log::*}; -use futures::future::select_all; -use hyper::Client; -#[cfg(feature = "forward-hyper-trust-dns")] -use hyper_trust_dns::TrustDnsResolver; -use std::sync::Arc; +use crate::{backend::Backend, error::*, globals::Globals, log::*}; +use futures::{ + select, + task::{Context, Poll}, + Future, FutureExt, +}; +use hyper::{ + client::connect::Connect, + http, + server::conn::Http, + service::{service_fn, Service}, + Body, Client, HeaderMap, Method, Request, Response, StatusCode, +}; +use std::{collections::HashMap, net::SocketAddr, pin::Pin, sync::Arc}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + net::TcpListener, + runtime::Handle, + time::Duration, +}; -#[derive(Debug, Clone)] -pub struct Proxy { +#[allow(clippy::unnecessary_wraps)] +fn http_error(status_code: StatusCode) -> Result, http::Error> { + let response = Response::builder() + .status(status_code) + .body(Body::empty()) + .unwrap(); + Ok(response) +} + +#[derive(Clone, Debug)] +pub struct LocalExecutor { + runtime_handle: Handle, +} + +impl LocalExecutor { + fn new(runtime_handle: Handle) -> Self { + LocalExecutor { runtime_handle } + } +} + +impl hyper::rt::Executor for LocalExecutor +where + F: std::future::Future + Send + 'static, + F::Output: Send, +{ + fn execute(&self, fut: F) { + self.runtime_handle.spawn(fut); + } +} + +#[derive(Clone)] +pub struct Proxy +where + T: Connect + Clone + Sync + Send + 'static, +{ + pub listening_on: SocketAddr, + pub tls_enabled: bool, // TCP待受がTLSかどうか + pub backends: Arc>, // TODO: hyper::uriで抜いたhostnameで引っ掛ける。Stringでいいのか? + pub forwarder: Arc>, pub globals: Arc, } -impl Proxy { - pub async fn entrypoint(self) -> Result<()> { - let addresses = self.globals.listen_addresses.clone(); - let futures = select_all(addresses.into_iter().map(|addr| { - info!("Listen address: {:?}", addr); - #[cfg(feature = "forward-hyper-trust-dns")] - let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector(); - #[cfg(not(feature = "forward-hyper-trust-dns"))] - let connector = hyper_tls::HttpsConnector::new(); - let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector)); +// TODO: ここでbackendの名前単位でリクエストを分岐させる +async fn handle_request( + req: Request, + client_ip: SocketAddr, + globals: Arc, +) -> Result, http::Error> { + // http_error(StatusCode::NOT_FOUND) + debug!("{:?}", req); + // if req.version() == hyper::Version::HTTP_11 { + // Ok(Response::new(Body::from("Hello World"))) + // } else { + // Note: it's usually better to return a Response + // with an appropriate StatusCode instead of an Err. + // Err("not HTTP/1.1, abort connection") + http_error(StatusCode::NOT_FOUND) + // } + // }); +} - let acceptor = PacketAcceptor { - listening_on: addr, - globals: self.globals.clone(), - forwarder, - }; - self.globals.runtime_handle.spawn(acceptor.start()) - })); +impl Proxy +where + T: Connect + Clone + Sync + Send + 'static, +{ + pub async fn client_serve(self, stream: I, server: Http, peer_addr: SocketAddr) + where + I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + let clients_count = self.globals.clients_count.clone(); + if clients_count.increment() > self.globals.max_clients { + clients_count.decrement(); + return; + } - // wait for all future - if let (Ok(_), _, _) = futures.await { - error!("Some packet acceptors are down"); + self.globals.runtime_handle.clone().spawn(async move { + tokio::time::timeout( + self.globals.timeout + Duration::from_secs(1), + // server.serve_connection(stream, self), + server.serve_connection( + stream, + service_fn(move |req: Request| { + handle_request(req, peer_addr, self.globals.clone()) + }), + ), + ) + .await + .ok(); + + clients_count.decrement(); + }); + } + + async fn start_without_tls( + self, + listener: TcpListener, + server: Http, + ) -> Result<()> { + let listener_service = async { + while let Ok((stream, _client_addr)) = listener.accept().await { + self + .clone() + .client_serve(stream, server.clone(), _client_addr) + .await; + } + Ok(()) as Result<()> }; + listener_service.await?; + Ok(()) + } + + pub async fn start(self) -> Result<()> { + let tcp_listener = TcpListener::bind(&self.listening_on).await?; + + let mut server = Http::new(); + server.http1_keep_alive(self.globals.keepalive); + server.http2_max_concurrent_streams(self.globals.max_concurrent_streams); + server.pipeline_flush(true); + let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); + let server = server.with_executor(executor); + + if self.tls_enabled { + info!( + "Start TCP proxy serving with HTTPS request for configured host names: {:?}", + tcp_listener.local_addr()? + ); + // #[cfg(feature = "tls")] + self.start_with_tls(tcp_listener, server).await?; + } else { + info!( + "Start TCP proxy serving with HTTP request for configured host names: {:?}", + tcp_listener.local_addr()? + ); + self.start_without_tls(tcp_listener, server).await?; + } Ok(()) } diff --git a/src/proxy_tls.rs b/src/proxy_tls.rs new file mode 100644 index 0000000..e3cc95e --- /dev/null +++ b/src/proxy_tls.rs @@ -0,0 +1,77 @@ +use crate::{ + constants::CERTS_WATCH_DELAY_SECS, + error::*, + log::*, + proxy::{LocalExecutor, Proxy}, +}; +use futures::{future::FutureExt, join, select}; +use hyper::{client::connect::Connect, server::conn::Http}; +use std::{sync::Arc, time::Duration}; +use tokio::net::TcpListener; + +impl Proxy +where + T: Connect + Clone + Sync + Send + 'static, +{ + pub async fn start_with_tls( + self, + listener: TcpListener, + server: Http, + ) -> Result<()> { + let cert_service = async { + info!("Start cert watch service for {}", self.listening_on); + loop { + for (hostname, backend) in self.backends.iter() { + if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() { + if let Err(_e) = backend.update_server_config().await { + warn!("Failed to update certs for {}", hostname); + } + } + } + tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; + } + }; + + let listener_service = async { + loop { + select! { + tcp_cnx = listener.accept().fuse() => { + if tcp_cnx.is_err() { + continue; + } + let (raw_stream, _client_addr) = tcp_cnx.unwrap(); + + // First check SNI + let rustls_acceptor = rustls::server::Acceptor::new().unwrap(); + let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls_acceptor, raw_stream); + let start = acceptor.await.unwrap(); + let client_hello = start.client_hello(); + debug!("SNI in ClientHello: {:?}", client_hello.server_name()); + // Find server config for given SNI + let svn = if let Some(svn) = client_hello.server_name() { + svn + } else { + info!("No SNI in ClientHello"); + continue; + }; + let backend_serve = if let Some(backend_serve) = self.backends.get(svn){ + backend_serve + } else { + info!("No configuration for the server name {} given in client_hello", svn); + continue; + }; + let server_config = backend_serve.get_tls_server_config(); + // Finally serve the TLS connection + if let Ok(stream) = start.into_stream(Arc::new(server_config.unwrap())).await { + self.clone().client_serve(stream, server.clone(), _client_addr).await + } + } + complete => break + } + } + Ok(()) as Result<()> + }; + + join!(listener_service, cert_service).0 + } +} diff --git a/src/tls.rs b/src/tls.rs deleted file mode 100644 index 7e642c7..0000000 --- a/src/tls.rs +++ /dev/null @@ -1,176 +0,0 @@ -use std::fs::File; -use std::io::{self, BufReader, Cursor, Read}; -use std::path::Path; -use std::sync::Arc; -use std::time::Duration; - -use futures::{future::FutureExt, join, select}; -use hyper::client::connect::Connect; -use hyper::server::conn::Http; -use tokio::{ - net::TcpListener, - sync::mpsc::{self, Receiver}, -}; -use tokio_rustls::{ - rustls::{Certificate, PrivateKey, ServerConfig}, - TlsAcceptor, -}; - -use crate::acceptor::{LocalExecutor, PacketAcceptor}; -use crate::constants::CERTS_WATCH_DELAY_SECS; -use crate::error::*; - -pub fn create_tls_acceptor(certs_path: P, certs_keys_path: P2) -> io::Result -where - P: AsRef, - P2: AsRef, -{ - let certs: Vec<_> = { - let certs_path_str = certs_path.as_ref().display().to_string(); - let mut reader = BufReader::new(File::open(certs_path).map_err(|e| { - io::Error::new( - e.kind(), - format!( - "Unable to load the certificates [{}]: {}", - certs_path_str, e - ), - ) - })?); - rustls_pemfile::certs(&mut reader).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unable to parse the certificates", - ) - })? - } - .drain(..) - .map(Certificate) - .collect(); - let certs_keys: Vec<_> = { - let certs_keys_path_str = certs_keys_path.as_ref().display().to_string(); - let encoded_keys = { - let mut encoded_keys = vec![]; - File::open(certs_keys_path) - .map_err(|e| { - io::Error::new( - e.kind(), - format!( - "Unable to load the certificate keys [{}]: {}", - certs_keys_path_str, e - ), - ) - })? - .read_to_end(&mut encoded_keys)?; - encoded_keys - }; - let mut reader = Cursor::new(encoded_keys); - let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unable to parse the certificates private keys (PKCS8)", - ) - })?; - reader.set_position(0); - let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unable to parse the certificates private keys (RSA)", - ) - })?; - let mut keys = pkcs8_keys; - keys.append(&mut rsa_keys); - if keys.is_empty() { - return Err(io::Error::new( - io::ErrorKind::InvalidInput, - "No private keys found - Make sure that they are in PKCS#8/PEM format", - )); - } - keys.drain(..).map(PrivateKey).collect() - }; - - let mut server_config = certs_keys - .into_iter() - .find_map(|certs_key| { - let server_config_builder = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth(); - if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) { - Some(found_config) - } else { - None - } - }) - .ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unable to find a valid certificate and key", - ) - })?; - server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; - Ok(TlsAcceptor::from(Arc::new(server_config))) -} - -impl PacketAcceptor -where - T: Connect + Clone + Sync + Send + 'static, -{ - async fn start_https_service( - self, - mut tls_acceptor_receiver: Receiver, - listener: TcpListener, - server: Http, - ) -> Result<()> { - let mut tls_acceptor: Option = None; - let listener_service = async { - loop { - select! { - tcp_cnx = listener.accept().fuse() => { - if tls_acceptor.is_none() || tcp_cnx.is_err() { - continue; - } - let (raw_stream, _client_addr) = tcp_cnx.unwrap(); - if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await { - self.clone().client_serve(stream, server.clone(), _client_addr).await - } - } - new_tls_acceptor = tls_acceptor_receiver.recv().fuse() => { - if new_tls_acceptor.is_none() { - break; - } - tls_acceptor = new_tls_acceptor; - } - complete => break - } - } - Ok(()) as Result<()> - }; - listener_service.await?; - Ok(()) - } - - pub async fn start_with_tls( - self, - listener: TcpListener, - server: Http, - ) -> Result<()> { - let certs_path = self.globals.tls_cert_path.as_ref().unwrap().clone(); - let certs_keys_path = self.globals.tls_cert_key_path.as_ref().unwrap().clone(); - let (tls_acceptor_sender, tls_acceptor_receiver) = mpsc::channel(1); - let https_service = self.start_https_service(tls_acceptor_receiver, listener, server); - let cert_service = async { - loop { - match create_tls_acceptor(&certs_path, &certs_keys_path) { - Ok(tls_acceptor) => { - if tls_acceptor_sender.send(tls_acceptor).await.is_err() { - break; - } - } - Err(e) => eprintln!("TLS certificates error: {}", e), - } - tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; - } - Ok(()) as Result<()> - }; - return join!(https_service, cert_service).0; - } -}