From 819b944a46cfb21fae7a0f32b0c8579ee4428213 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Thu, 16 Jun 2022 17:13:28 -0400 Subject: [PATCH] initial commit --- .gitignore | 4 + Cargo.toml | 52 +++++++++++ config-example.toml | 31 +++++++ src/acceptor.rs | 219 ++++++++++++++++++++++++++++++++++++++++++++ src/config.rs | 13 +++ src/constants.rs | 6 ++ src/error.rs | 1 + src/globals.rs | 53 +++++++++++ src/log.rs | 1 + src/main.rs | 73 +++++++++++++++ src/proxy.rs | 28 ++++++ src/tls.rs | 172 ++++++++++++++++++++++++++++++++++ 12 files changed, 653 insertions(+) create mode 100644 Cargo.toml create mode 100644 config-example.toml create mode 100644 src/acceptor.rs create mode 100644 src/config.rs create mode 100644 src/constants.rs create mode 100644 src/error.rs create mode 100644 src/globals.rs create mode 100644 src/log.rs create mode 100644 src/main.rs create mode 100644 src/proxy.rs create mode 100644 src/tls.rs diff --git a/.gitignore b/.gitignore index 088ba6b..02474f4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +.vscode +.private + + # Generated by Cargo # will have compiled files and executables /target/ diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..f361f17 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,52 @@ +[package] +name = "rust-rpxy" +version = "0.1.0" +authors = ["Jun Kurihara"] +homepage = "https://github.com/junkurihara/rust-rpxy" +repository = "https://github.com/junkurihara/rust-rpxy" +license = "MIT" +readme = "README.md" +edition = "2021" +publish = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +default = ["tls"] +tls = ["tokio-rustls", "rustls-pemfile"] + +[dependencies] +anyhow = "1.0.57" +clap = { version = "3.2.4", features = ["std", "cargo", "wrap_help"] } +env_logger = "0.9.0" +futures = "0.3.21" +hyper = { version = "0.14.19", default-features = false, features = [ + "server", + "http1", + "http2", + "stream", +] } +log = "0.4.17" +mimalloc = { version = "0.1.29", default-features = false } +tokio = { version = "1.19.2", features = [ + "net", + "rt-multi-thread", + "parking_lot", + "time", + "sync", + "macros", +] } +tokio-rustls = { version = "0.23.4", features = [ + "early-data", +], optional = true } +rustls-pemfile = { version = "1.0.0", optional = true } + +[dev-dependencies] + + +[profile.release] +codegen-units = 1 +incremental = false +lto = "fat" +opt-level = 3 +panic = "abort" diff --git a/config-example.toml b/config-example.toml new file mode 100644 index 0000000..b7b3d51 --- /dev/null +++ b/config-example.toml @@ -0,0 +1,31 @@ +######################################## +# # +# rust-rxpy configuration # +# # +######################################## + +################################## +# Global settings # +################################## + +## Address to listen to. +listen_addresses = ['127.0.0.1:50844', '[::1]:50844'] + +[tls] +tls_cert_path = 'localhost.pem' +tls_cert_key_path = 'localhost.pem' + +################################### +# Backend settings # +################################### +[[backend]] +domain = 'localhost' +## List of destinations to send data to. +## At this point, round-robin is used for load-balancing if multiple URLs are specified. +destination = ['http://192.168.0.1:3000/', 'https://192.168.0.2:3000'] +allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24'] +denyhosts = ['*'] + +[[backend]] +domain = '127.0.0.1' +destination = 'https://www.google.com/' diff --git a/src/acceptor.rs b/src/acceptor.rs new file mode 100644 index 0000000..003ba57 --- /dev/null +++ b/src/acceptor.rs @@ -0,0 +1,219 @@ +use crate::{error::*, globals::Globals, log::*}; + +use futures::{ + task::{Context, Poll}, + Future, +}; +use hyper::http; +use hyper::server::conn::Http; +use hyper::{Body, HeaderMap, Method, Request, Response, StatusCode}; +use std::{net::SocketAddr, pin::Pin, sync::Arc}; +use tokio::{ + io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, + net::TcpListener, + runtime::Handle, + time::Duration, +}; + +#[allow(clippy::unnecessary_wraps)] +fn http_error(status_code: StatusCode) -> Result, http::Error> { + let response = Response::builder() + .status(status_code) + .body(Body::empty()) + .unwrap(); + Ok(response) +} + +#[derive(Clone, Debug)] +pub struct LocalExecutor { + runtime_handle: Handle, +} + +impl LocalExecutor { + fn new(runtime_handle: Handle) -> Self { + LocalExecutor { runtime_handle } + } +} + +impl hyper::rt::Executor for LocalExecutor +where + F: std::future::Future + Send + 'static, + F::Output: Send, +{ + fn execute(&self, fut: F) { + self.runtime_handle.spawn(fut); + } +} + +#[derive(Clone)] +pub struct PacketAcceptor { + pub listening_on: SocketAddr, + pub globals: Arc, +} + +#[allow(clippy::type_complexity)] +impl hyper::service::Service> for PacketAcceptor { + type Response = Response; + + type Error = http::Error; + type Future = Pin> + Send>>; + + fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + debug!("\nserve:{:?}\n{:?}", self.listening_on, req); + // let globals = &self.doh.globals; + // let self_inner = self.clone(); + // if req.uri().path() == globals.path { + // Box::pin(async move { + // let mut subscriber = None; + // if self_inner.doh.globals.enable_auth_target { + // subscriber = match auth::authenticate( + // &self_inner.doh.globals, + // &req, + // ValidationLocation::Target, + // &self_inner.peer_addr, + // ) { + // Ok((sub, aud)) => { + // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud); + // sub + // } + // Err(e) => { + // error!("{:?}", e); + // return Ok(e); + // } + // }; + // } + // match *req.method() { + // Method::POST => self_inner.doh.serve_post(req, subscriber).await, + // Method::GET => self_inner.doh.serve_get(req, subscriber).await, + // _ => http_error(StatusCode::METHOD_NOT_ALLOWED), + // } + // }) + // } else if req.uri().path() == globals.odoh_configs_path { + // match *req.method() { + // Method::GET => Box::pin(async move { self_inner.doh.serve_odoh_configs().await }), + // _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), + // } + // } else { + // #[cfg(not(feature = "odoh-proxy"))] + // { + // Box::pin(async { http_error(StatusCode::NOT_FOUND) }) + // } + // #[cfg(feature = "odoh-proxy")] + // { + // if req.uri().path() == globals.odoh_proxy_path { + // Box::pin(async move { + // let mut subscriber = None; + // if self_inner.doh.globals.enable_auth_proxy { + // subscriber = match auth::authenticate( + // &self_inner.doh.globals, + // &req, + // ValidationLocation::Proxy, + // &self_inner.peer_addr, + // ) { + // Ok((sub, aud)) => { + // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud); + // sub + // } + // Err(e) => { + // error!("{:?}", e); + // return Ok(e); + // } + // }; + // } + // // Draft: https://datatracker.ietf.org/doc/html/draft-pauly-dprive-oblivious-doh-11 + // // Golang impl.: https://github.com/cloudflare/odoh-server-go + // // Based on the draft and Golang implementation, only post method is allowed. + // match *req.method() { + // Method::POST => self_inner.doh.serve_odoh_proxy_post(req, subscriber).await, + // _ => http_error(StatusCode::METHOD_NOT_ALLOWED), + // } + // }) + // } else { + Box::pin(async { http_error(StatusCode::NOT_FOUND) }) + // } + // } + // } + } +} + +impl PacketAcceptor { + pub async fn client_serve(self, stream: I, server: Http, peer_addr: SocketAddr) + where + I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + { + let clients_count = self.globals.clients_count.clone(); + if clients_count.increment() > self.globals.max_clients { + clients_count.decrement(); + return; + } + self.globals.runtime_handle.clone().spawn(async move { + tokio::time::timeout( + self.globals.timeout + Duration::from_secs(1), + server.serve_connection(stream, self), + ) + .await + .ok(); + clients_count.decrement(); + }); + } + + async fn start_without_tls( + self, + listener: TcpListener, + server: Http, + ) -> Result<()> { + let listener_service = async { + while let Ok((stream, _client_addr)) = listener.accept().await { + self + .clone() + .client_serve(stream, server.clone(), _client_addr) + .await; + } + Ok(()) as Result<()> + }; + listener_service.await?; + Ok(()) + } + + pub async fn start(self) -> Result<()> { + let tcp_listener = TcpListener::bind(&self.listening_on).await?; + + let mut server = Http::new(); + server.http1_keep_alive(self.globals.keepalive); + server.http2_max_concurrent_streams(self.globals.max_concurrent_streams); + server.pipeline_flush(true); + let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); + let server = server.with_executor(executor); + + let tls_enabled: bool; + #[cfg(not(feature = "tls"))] + { + tls_enabled = false; + } + #[cfg(feature = "tls")] + { + tls_enabled = + self.globals.tls_cert_path.is_some() && self.globals.tls_cert_key_path.is_some(); + } + if tls_enabled { + info!( + "Start server listening on TCP with TLS: {:?}", + tcp_listener.local_addr()? + ); + #[cfg(feature = "tls")] + self.start_with_tls(tcp_listener, server).await?; + } else { + info!( + "Start server listening on TCP: {:?}", + tcp_listener.local_addr()? + ); + self.start_without_tls(tcp_listener, server).await?; + } + + Ok(()) + } +} diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..51503be --- /dev/null +++ b/src/config.rs @@ -0,0 +1,13 @@ +use crate::globals::Globals; + +#[cfg(feature = "tls")] +use std::path::PathBuf; + +pub fn parse_opts(globals: &mut Globals) { + #[cfg(feature = "tls")] + { + // TODO: + globals.tls_cert_path = Some(PathBuf::from(r"localhost.pem")); + globals.tls_cert_key_path = Some(PathBuf::from(r"localhost.pem")); + } +} diff --git a/src/constants.rs b/src/constants.rs new file mode 100644 index 0000000..434dd55 --- /dev/null +++ b/src/constants.rs @@ -0,0 +1,6 @@ +pub const LISTEN_ADDRESSES: &[&str] = &["127.0.0.1:8443", "[::1]:8443"]; +pub const TIMEOUT_SEC: u64 = 10; +pub const MAX_CLIENTS: usize = 512; +pub const MAX_CONCURRENT_STREAMS: u32 = 16; +#[cfg(feature = "tls")] +pub const CERTS_WATCH_DELAY_SECS: u32 = 10; diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..6ac9a6b --- /dev/null +++ b/src/error.rs @@ -0,0 +1 @@ +pub use anyhow::{anyhow, bail, ensure, Context, Error, Result}; diff --git a/src/globals.rs b/src/globals.rs new file mode 100644 index 0000000..2398988 --- /dev/null +++ b/src/globals.rs @@ -0,0 +1,53 @@ +use std::net::SocketAddr; +#[cfg(feature = "tls")] +use std::path::PathBuf; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use tokio::time::Duration; + +#[derive(Debug, Clone)] +pub struct Globals { + pub listen_addresses: Vec, + + pub timeout: Duration, + pub max_clients: usize, + pub clients_count: ClientsCount, + pub max_concurrent_streams: u32, + pub keepalive: bool, + + pub runtime_handle: tokio::runtime::Handle, + + #[cfg(feature = "tls")] + pub tls_cert_path: Option, + + #[cfg(feature = "tls")] + pub tls_cert_key_path: Option, +} + +#[derive(Debug, Clone, Default)] +pub struct ClientsCount(Arc); + +impl ClientsCount { + pub fn current(&self) -> usize { + self.0.load(Ordering::Relaxed) + } + + pub fn increment(&self) -> usize { + self.0.fetch_add(1, Ordering::Relaxed) + } + + pub fn decrement(&self) -> usize { + let mut count; + while { + count = self.0.load(Ordering::Relaxed); + count > 0 + && self + .0 + .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) + != Ok(count) + } {} + count + } +} diff --git a/src/log.rs b/src/log.rs new file mode 100644 index 0000000..cb4253e --- /dev/null +++ b/src/log.rs @@ -0,0 +1 @@ +pub use log::{debug, error, info, warn}; diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..77e2d37 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,73 @@ +#[global_allocator] +static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc; + +mod acceptor; +mod config; +mod constants; +mod error; +mod globals; +mod log; +mod proxy; +#[cfg(feature = "tls")] +mod tls; + +use crate::{config::parse_opts, constants::*, globals::Globals, log::*, proxy::Proxy}; +use std::{io::Write, sync::Arc}; +use tokio::time::Duration; + +fn main() { + // env::set_var("RUST_LOG", "info"); + env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info")) + .format(|buf, record| { + let ts = buf.timestamp(); + writeln!( + buf, + "{} [{}] {}", + ts, + record.level(), + // record.target(), + record.args(), + // record.file().unwrap_or("unknown"), + // record.line().unwrap_or(0), + ) + }) + .init(); + info!("Start http (reverse) proxy"); + + let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); + runtime_builder.enable_all(); + runtime_builder.thread_name("rust-rpxy"); + let runtime = runtime_builder.build().unwrap(); + + // TODO: + let listen_addresses: Vec = LISTEN_ADDRESSES + .to_vec() + .iter() + .map(|x| x.parse().unwrap()) + .collect(); + + runtime.block_on(async { + let mut globals = Globals { + listen_addresses, + timeout: Duration::from_secs(TIMEOUT_SEC), + max_clients: MAX_CLIENTS, + clients_count: Default::default(), + max_concurrent_streams: MAX_CONCURRENT_STREAMS, + keepalive: true, + runtime_handle: runtime.handle().clone(), + + #[cfg(feature = "tls")] + tls_cert_path: None, + #[cfg(feature = "tls")] + tls_cert_key_path: None, + }; + + parse_opts(&mut globals); + + let proxy = Proxy { + globals: Arc::new(globals), + }; + proxy.entrypoint().await.unwrap() + }); + warn!("Exit the program"); +} diff --git a/src/proxy.rs b/src/proxy.rs new file mode 100644 index 0000000..ed2f9fb --- /dev/null +++ b/src/proxy.rs @@ -0,0 +1,28 @@ +use crate::{acceptor::PacketAcceptor, error::*, globals::Globals, log::*}; +use futures::future::select_all; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct Proxy { + pub globals: Arc, +} +impl Proxy { + pub async fn entrypoint(self) -> Result<()> { + let addresses = self.globals.listen_addresses.clone(); + let futures = select_all(addresses.into_iter().map(|addr| { + info!("Listen address: {:?}", addr); + let acceptor = PacketAcceptor { + listening_on: addr, + globals: self.globals.clone(), + }; + self.globals.runtime_handle.spawn(acceptor.start()) + })); + + // wait for all future + if let (Ok(_), _, _) = futures.await { + error!("Some packet acceptors are down"); + }; + + Ok(()) + } +} diff --git a/src/tls.rs b/src/tls.rs new file mode 100644 index 0000000..687e949 --- /dev/null +++ b/src/tls.rs @@ -0,0 +1,172 @@ +use std::fs::File; +use std::io::{self, BufReader, Cursor, Read}; +use std::path::Path; +use std::sync::Arc; +use std::time::Duration; + +use futures::{future::FutureExt, join, select}; +use hyper::server::conn::Http; +use tokio::{ + net::TcpListener, + sync::mpsc::{self, Receiver}, +}; +use tokio_rustls::{ + rustls::{Certificate, PrivateKey, ServerConfig}, + TlsAcceptor, +}; + +use crate::acceptor::{LocalExecutor, PacketAcceptor}; +use crate::constants::CERTS_WATCH_DELAY_SECS; +use crate::error::*; + +pub fn create_tls_acceptor(certs_path: P, certs_keys_path: P2) -> io::Result +where + P: AsRef, + P2: AsRef, +{ + let certs: Vec<_> = { + let certs_path_str = certs_path.as_ref().display().to_string(); + let mut reader = BufReader::new(File::open(certs_path).map_err(|e| { + io::Error::new( + e.kind(), + format!( + "Unable to load the certificates [{}]: {}", + certs_path_str, e + ), + ) + })?); + rustls_pemfile::certs(&mut reader).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to parse the certificates", + ) + })? + } + .drain(..) + .map(Certificate) + .collect(); + let certs_keys: Vec<_> = { + let certs_keys_path_str = certs_keys_path.as_ref().display().to_string(); + let encoded_keys = { + let mut encoded_keys = vec![]; + File::open(certs_keys_path) + .map_err(|e| { + io::Error::new( + e.kind(), + format!( + "Unable to load the certificate keys [{}]: {}", + certs_keys_path_str, e + ), + ) + })? + .read_to_end(&mut encoded_keys)?; + encoded_keys + }; + let mut reader = Cursor::new(encoded_keys); + let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to parse the certificates private keys (PKCS8)", + ) + })?; + reader.set_position(0); + let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to parse the certificates private keys (RSA)", + ) + })?; + let mut keys = pkcs8_keys; + keys.append(&mut rsa_keys); + if keys.is_empty() { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "No private keys found - Make sure that they are in PKCS#8/PEM format", + )); + } + keys.drain(..).map(PrivateKey).collect() + }; + + let mut server_config = certs_keys + .into_iter() + .find_map(|certs_key| { + let server_config_builder = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth(); + if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) { + Some(found_config) + } else { + None + } + }) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to find a valid certificate and key", + ) + })?; + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + Ok(TlsAcceptor::from(Arc::new(server_config))) +} + +impl PacketAcceptor { + async fn start_https_service( + self, + mut tls_acceptor_receiver: Receiver, + listener: TcpListener, + server: Http, + ) -> Result<()> { + let mut tls_acceptor: Option = None; + let listener_service = async { + loop { + select! { + tcp_cnx = listener.accept().fuse() => { + if tls_acceptor.is_none() || tcp_cnx.is_err() { + continue; + } + let (raw_stream, _client_addr) = tcp_cnx.unwrap(); + if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await { + self.clone().client_serve(stream, server.clone(), _client_addr).await + } + } + new_tls_acceptor = tls_acceptor_receiver.recv().fuse() => { + if new_tls_acceptor.is_none() { + break; + } + tls_acceptor = new_tls_acceptor; + } + complete => break + } + } + Ok(()) as Result<()> + }; + listener_service.await?; + Ok(()) + } + + pub async fn start_with_tls( + self, + listener: TcpListener, + server: Http, + ) -> Result<()> { + let certs_path = self.globals.tls_cert_path.as_ref().unwrap().clone(); + let certs_keys_path = self.globals.tls_cert_key_path.as_ref().unwrap().clone(); + let (tls_acceptor_sender, tls_acceptor_receiver) = mpsc::channel(1); + let https_service = self.start_https_service(tls_acceptor_receiver, listener, server); + let cert_service = async { + loop { + match create_tls_acceptor(&certs_path, &certs_keys_path) { + Ok(tls_acceptor) => { + if tls_acceptor_sender.send(tls_acceptor).await.is_err() { + break; + } + } + Err(e) => eprintln!("TLS certificates error: {}", e), + } + tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await; + } + Ok(()) as Result<()> + }; + return join!(https_service, cert_service).0; + } +}