initial commit
This commit is contained in:
parent
8a427950ee
commit
819b944a46
12 changed files with 653 additions and 0 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -1,3 +1,7 @@
|
|||
.vscode
|
||||
.private
|
||||
|
||||
|
||||
# Generated by Cargo
|
||||
# will have compiled files and executables
|
||||
/target/
|
||||
|
|
|
|||
52
Cargo.toml
Normal file
52
Cargo.toml
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
[package]
|
||||
name = "rust-rpxy"
|
||||
version = "0.1.0"
|
||||
authors = ["Jun Kurihara"]
|
||||
homepage = "https://github.com/junkurihara/rust-rpxy"
|
||||
repository = "https://github.com/junkurihara/rust-rpxy"
|
||||
license = "MIT"
|
||||
readme = "README.md"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[features]
|
||||
default = ["tls"]
|
||||
tls = ["tokio-rustls", "rustls-pemfile"]
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.57"
|
||||
clap = { version = "3.2.4", features = ["std", "cargo", "wrap_help"] }
|
||||
env_logger = "0.9.0"
|
||||
futures = "0.3.21"
|
||||
hyper = { version = "0.14.19", default-features = false, features = [
|
||||
"server",
|
||||
"http1",
|
||||
"http2",
|
||||
"stream",
|
||||
] }
|
||||
log = "0.4.17"
|
||||
mimalloc = { version = "0.1.29", default-features = false }
|
||||
tokio = { version = "1.19.2", features = [
|
||||
"net",
|
||||
"rt-multi-thread",
|
||||
"parking_lot",
|
||||
"time",
|
||||
"sync",
|
||||
"macros",
|
||||
] }
|
||||
tokio-rustls = { version = "0.23.4", features = [
|
||||
"early-data",
|
||||
], optional = true }
|
||||
rustls-pemfile = { version = "1.0.0", optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
|
||||
|
||||
[profile.release]
|
||||
codegen-units = 1
|
||||
incremental = false
|
||||
lto = "fat"
|
||||
opt-level = 3
|
||||
panic = "abort"
|
||||
31
config-example.toml
Normal file
31
config-example.toml
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
########################################
|
||||
# #
|
||||
# rust-rxpy configuration #
|
||||
# #
|
||||
########################################
|
||||
|
||||
##################################
|
||||
# Global settings #
|
||||
##################################
|
||||
|
||||
## Address to listen to.
|
||||
listen_addresses = ['127.0.0.1:50844', '[::1]:50844']
|
||||
|
||||
[tls]
|
||||
tls_cert_path = 'localhost.pem'
|
||||
tls_cert_key_path = 'localhost.pem'
|
||||
|
||||
###################################
|
||||
# Backend settings #
|
||||
###################################
|
||||
[[backend]]
|
||||
domain = 'localhost'
|
||||
## List of destinations to send data to.
|
||||
## At this point, round-robin is used for load-balancing if multiple URLs are specified.
|
||||
destination = ['http://192.168.0.1:3000/', 'https://192.168.0.2:3000']
|
||||
allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24']
|
||||
denyhosts = ['*']
|
||||
|
||||
[[backend]]
|
||||
domain = '127.0.0.1'
|
||||
destination = 'https://www.google.com/'
|
||||
219
src/acceptor.rs
Normal file
219
src/acceptor.rs
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
use crate::{error::*, globals::Globals, log::*};
|
||||
|
||||
use futures::{
|
||||
task::{Context, Poll},
|
||||
Future,
|
||||
};
|
||||
use hyper::http;
|
||||
use hyper::server::conn::Http;
|
||||
use hyper::{Body, HeaderMap, Method, Request, Response, StatusCode};
|
||||
use std::{net::SocketAddr, pin::Pin, sync::Arc};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
||||
net::TcpListener,
|
||||
runtime::Handle,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
#[allow(clippy::unnecessary_wraps)]
|
||||
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
|
||||
let response = Response::builder()
|
||||
.status(status_code)
|
||||
.body(Body::empty())
|
||||
.unwrap();
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct LocalExecutor {
|
||||
runtime_handle: Handle,
|
||||
}
|
||||
|
||||
impl LocalExecutor {
|
||||
fn new(runtime_handle: Handle) -> Self {
|
||||
LocalExecutor { runtime_handle }
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> hyper::rt::Executor<F> for LocalExecutor
|
||||
where
|
||||
F: std::future::Future + Send + 'static,
|
||||
F::Output: Send,
|
||||
{
|
||||
fn execute(&self, fut: F) {
|
||||
self.runtime_handle.spawn(fut);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PacketAcceptor {
|
||||
pub listening_on: SocketAddr,
|
||||
pub globals: Arc<Globals>,
|
||||
}
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
impl hyper::service::Service<http::Request<Body>> for PacketAcceptor {
|
||||
type Response = Response<Body>;
|
||||
|
||||
type Error = http::Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, req: Request<Body>) -> Self::Future {
|
||||
debug!("\nserve:{:?}\n{:?}", self.listening_on, req);
|
||||
// let globals = &self.doh.globals;
|
||||
// let self_inner = self.clone();
|
||||
// if req.uri().path() == globals.path {
|
||||
// Box::pin(async move {
|
||||
// let mut subscriber = None;
|
||||
// if self_inner.doh.globals.enable_auth_target {
|
||||
// subscriber = match auth::authenticate(
|
||||
// &self_inner.doh.globals,
|
||||
// &req,
|
||||
// ValidationLocation::Target,
|
||||
// &self_inner.peer_addr,
|
||||
// ) {
|
||||
// Ok((sub, aud)) => {
|
||||
// debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud);
|
||||
// sub
|
||||
// }
|
||||
// Err(e) => {
|
||||
// error!("{:?}", e);
|
||||
// return Ok(e);
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
// match *req.method() {
|
||||
// Method::POST => self_inner.doh.serve_post(req, subscriber).await,
|
||||
// Method::GET => self_inner.doh.serve_get(req, subscriber).await,
|
||||
// _ => http_error(StatusCode::METHOD_NOT_ALLOWED),
|
||||
// }
|
||||
// })
|
||||
// } else if req.uri().path() == globals.odoh_configs_path {
|
||||
// match *req.method() {
|
||||
// Method::GET => Box::pin(async move { self_inner.doh.serve_odoh_configs().await }),
|
||||
// _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
|
||||
// }
|
||||
// } else {
|
||||
// #[cfg(not(feature = "odoh-proxy"))]
|
||||
// {
|
||||
// Box::pin(async { http_error(StatusCode::NOT_FOUND) })
|
||||
// }
|
||||
// #[cfg(feature = "odoh-proxy")]
|
||||
// {
|
||||
// if req.uri().path() == globals.odoh_proxy_path {
|
||||
// Box::pin(async move {
|
||||
// let mut subscriber = None;
|
||||
// if self_inner.doh.globals.enable_auth_proxy {
|
||||
// subscriber = match auth::authenticate(
|
||||
// &self_inner.doh.globals,
|
||||
// &req,
|
||||
// ValidationLocation::Proxy,
|
||||
// &self_inner.peer_addr,
|
||||
// ) {
|
||||
// Ok((sub, aud)) => {
|
||||
// debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud);
|
||||
// sub
|
||||
// }
|
||||
// Err(e) => {
|
||||
// error!("{:?}", e);
|
||||
// return Ok(e);
|
||||
// }
|
||||
// };
|
||||
// }
|
||||
// // Draft: https://datatracker.ietf.org/doc/html/draft-pauly-dprive-oblivious-doh-11
|
||||
// // Golang impl.: https://github.com/cloudflare/odoh-server-go
|
||||
// // Based on the draft and Golang implementation, only post method is allowed.
|
||||
// match *req.method() {
|
||||
// Method::POST => self_inner.doh.serve_odoh_proxy_post(req, subscriber).await,
|
||||
// _ => http_error(StatusCode::METHOD_NOT_ALLOWED),
|
||||
// }
|
||||
// })
|
||||
// } else {
|
||||
Box::pin(async { http_error(StatusCode::NOT_FOUND) })
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
impl PacketAcceptor {
|
||||
pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr)
|
||||
where
|
||||
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
let clients_count = self.globals.clients_count.clone();
|
||||
if clients_count.increment() > self.globals.max_clients {
|
||||
clients_count.decrement();
|
||||
return;
|
||||
}
|
||||
self.globals.runtime_handle.clone().spawn(async move {
|
||||
tokio::time::timeout(
|
||||
self.globals.timeout + Duration::from_secs(1),
|
||||
server.serve_connection(stream, self),
|
||||
)
|
||||
.await
|
||||
.ok();
|
||||
clients_count.decrement();
|
||||
});
|
||||
}
|
||||
|
||||
async fn start_without_tls(
|
||||
self,
|
||||
listener: TcpListener,
|
||||
server: Http<LocalExecutor>,
|
||||
) -> Result<()> {
|
||||
let listener_service = async {
|
||||
while let Ok((stream, _client_addr)) = listener.accept().await {
|
||||
self
|
||||
.clone()
|
||||
.client_serve(stream, server.clone(), _client_addr)
|
||||
.await;
|
||||
}
|
||||
Ok(()) as Result<()>
|
||||
};
|
||||
listener_service.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn start(self) -> Result<()> {
|
||||
let tcp_listener = TcpListener::bind(&self.listening_on).await?;
|
||||
|
||||
let mut server = Http::new();
|
||||
server.http1_keep_alive(self.globals.keepalive);
|
||||
server.http2_max_concurrent_streams(self.globals.max_concurrent_streams);
|
||||
server.pipeline_flush(true);
|
||||
let executor = LocalExecutor::new(self.globals.runtime_handle.clone());
|
||||
let server = server.with_executor(executor);
|
||||
|
||||
let tls_enabled: bool;
|
||||
#[cfg(not(feature = "tls"))]
|
||||
{
|
||||
tls_enabled = false;
|
||||
}
|
||||
#[cfg(feature = "tls")]
|
||||
{
|
||||
tls_enabled =
|
||||
self.globals.tls_cert_path.is_some() && self.globals.tls_cert_key_path.is_some();
|
||||
}
|
||||
if tls_enabled {
|
||||
info!(
|
||||
"Start server listening on TCP with TLS: {:?}",
|
||||
tcp_listener.local_addr()?
|
||||
);
|
||||
#[cfg(feature = "tls")]
|
||||
self.start_with_tls(tcp_listener, server).await?;
|
||||
} else {
|
||||
info!(
|
||||
"Start server listening on TCP: {:?}",
|
||||
tcp_listener.local_addr()?
|
||||
);
|
||||
self.start_without_tls(tcp_listener, server).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
13
src/config.rs
Normal file
13
src/config.rs
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
use crate::globals::Globals;
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub fn parse_opts(globals: &mut Globals) {
|
||||
#[cfg(feature = "tls")]
|
||||
{
|
||||
// TODO:
|
||||
globals.tls_cert_path = Some(PathBuf::from(r"localhost.pem"));
|
||||
globals.tls_cert_key_path = Some(PathBuf::from(r"localhost.pem"));
|
||||
}
|
||||
}
|
||||
6
src/constants.rs
Normal file
6
src/constants.rs
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
pub const LISTEN_ADDRESSES: &[&str] = &["127.0.0.1:8443", "[::1]:8443"];
|
||||
pub const TIMEOUT_SEC: u64 = 10;
|
||||
pub const MAX_CLIENTS: usize = 512;
|
||||
pub const MAX_CONCURRENT_STREAMS: u32 = 16;
|
||||
#[cfg(feature = "tls")]
|
||||
pub const CERTS_WATCH_DELAY_SECS: u32 = 10;
|
||||
1
src/error.rs
Normal file
1
src/error.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
pub use anyhow::{anyhow, bail, ensure, Context, Error, Result};
|
||||
53
src/globals.rs
Normal file
53
src/globals.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
use std::net::SocketAddr;
|
||||
#[cfg(feature = "tls")]
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
use tokio::time::Duration;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Globals {
|
||||
pub listen_addresses: Vec<SocketAddr>,
|
||||
|
||||
pub timeout: Duration,
|
||||
pub max_clients: usize,
|
||||
pub clients_count: ClientsCount,
|
||||
pub max_concurrent_streams: u32,
|
||||
pub keepalive: bool,
|
||||
|
||||
pub runtime_handle: tokio::runtime::Handle,
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub tls_cert_path: Option<PathBuf>,
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
pub tls_cert_key_path: Option<PathBuf>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ClientsCount(Arc<AtomicUsize>);
|
||||
|
||||
impl ClientsCount {
|
||||
pub fn current(&self) -> usize {
|
||||
self.0.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn increment(&self) -> usize {
|
||||
self.0.fetch_add(1, Ordering::Relaxed)
|
||||
}
|
||||
|
||||
pub fn decrement(&self) -> usize {
|
||||
let mut count;
|
||||
while {
|
||||
count = self.0.load(Ordering::Relaxed);
|
||||
count > 0
|
||||
&& self
|
||||
.0
|
||||
.compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed)
|
||||
!= Ok(count)
|
||||
} {}
|
||||
count
|
||||
}
|
||||
}
|
||||
1
src/log.rs
Normal file
1
src/log.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
pub use log::{debug, error, info, warn};
|
||||
73
src/main.rs
Normal file
73
src/main.rs
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
#[global_allocator]
|
||||
static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
|
||||
mod acceptor;
|
||||
mod config;
|
||||
mod constants;
|
||||
mod error;
|
||||
mod globals;
|
||||
mod log;
|
||||
mod proxy;
|
||||
#[cfg(feature = "tls")]
|
||||
mod tls;
|
||||
|
||||
use crate::{config::parse_opts, constants::*, globals::Globals, log::*, proxy::Proxy};
|
||||
use std::{io::Write, sync::Arc};
|
||||
use tokio::time::Duration;
|
||||
|
||||
fn main() {
|
||||
// env::set_var("RUST_LOG", "info");
|
||||
env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("info"))
|
||||
.format(|buf, record| {
|
||||
let ts = buf.timestamp();
|
||||
writeln!(
|
||||
buf,
|
||||
"{} [{}] {}",
|
||||
ts,
|
||||
record.level(),
|
||||
// record.target(),
|
||||
record.args(),
|
||||
// record.file().unwrap_or("unknown"),
|
||||
// record.line().unwrap_or(0),
|
||||
)
|
||||
})
|
||||
.init();
|
||||
info!("Start http (reverse) proxy");
|
||||
|
||||
let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
|
||||
runtime_builder.enable_all();
|
||||
runtime_builder.thread_name("rust-rpxy");
|
||||
let runtime = runtime_builder.build().unwrap();
|
||||
|
||||
// TODO:
|
||||
let listen_addresses: Vec<std::net::SocketAddr> = LISTEN_ADDRESSES
|
||||
.to_vec()
|
||||
.iter()
|
||||
.map(|x| x.parse().unwrap())
|
||||
.collect();
|
||||
|
||||
runtime.block_on(async {
|
||||
let mut globals = Globals {
|
||||
listen_addresses,
|
||||
timeout: Duration::from_secs(TIMEOUT_SEC),
|
||||
max_clients: MAX_CLIENTS,
|
||||
clients_count: Default::default(),
|
||||
max_concurrent_streams: MAX_CONCURRENT_STREAMS,
|
||||
keepalive: true,
|
||||
runtime_handle: runtime.handle().clone(),
|
||||
|
||||
#[cfg(feature = "tls")]
|
||||
tls_cert_path: None,
|
||||
#[cfg(feature = "tls")]
|
||||
tls_cert_key_path: None,
|
||||
};
|
||||
|
||||
parse_opts(&mut globals);
|
||||
|
||||
let proxy = Proxy {
|
||||
globals: Arc::new(globals),
|
||||
};
|
||||
proxy.entrypoint().await.unwrap()
|
||||
});
|
||||
warn!("Exit the program");
|
||||
}
|
||||
28
src/proxy.rs
Normal file
28
src/proxy.rs
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
use crate::{acceptor::PacketAcceptor, error::*, globals::Globals, log::*};
|
||||
use futures::future::select_all;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Proxy {
|
||||
pub globals: Arc<Globals>,
|
||||
}
|
||||
impl Proxy {
|
||||
pub async fn entrypoint(self) -> Result<()> {
|
||||
let addresses = self.globals.listen_addresses.clone();
|
||||
let futures = select_all(addresses.into_iter().map(|addr| {
|
||||
info!("Listen address: {:?}", addr);
|
||||
let acceptor = PacketAcceptor {
|
||||
listening_on: addr,
|
||||
globals: self.globals.clone(),
|
||||
};
|
||||
self.globals.runtime_handle.spawn(acceptor.start())
|
||||
}));
|
||||
|
||||
// wait for all future
|
||||
if let (Ok(_), _, _) = futures.await {
|
||||
error!("Some packet acceptors are down");
|
||||
};
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
172
src/tls.rs
Normal file
172
src/tls.rs
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
use std::fs::File;
|
||||
use std::io::{self, BufReader, Cursor, Read};
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::{future::FutureExt, join, select};
|
||||
use hyper::server::conn::Http;
|
||||
use tokio::{
|
||||
net::TcpListener,
|
||||
sync::mpsc::{self, Receiver},
|
||||
};
|
||||
use tokio_rustls::{
|
||||
rustls::{Certificate, PrivateKey, ServerConfig},
|
||||
TlsAcceptor,
|
||||
};
|
||||
|
||||
use crate::acceptor::{LocalExecutor, PacketAcceptor};
|
||||
use crate::constants::CERTS_WATCH_DELAY_SECS;
|
||||
use crate::error::*;
|
||||
|
||||
pub fn create_tls_acceptor<P, P2>(certs_path: P, certs_keys_path: P2) -> io::Result<TlsAcceptor>
|
||||
where
|
||||
P: AsRef<Path>,
|
||||
P2: AsRef<Path>,
|
||||
{
|
||||
let certs: Vec<_> = {
|
||||
let certs_path_str = certs_path.as_ref().display().to_string();
|
||||
let mut reader = BufReader::new(File::open(certs_path).map_err(|e| {
|
||||
io::Error::new(
|
||||
e.kind(),
|
||||
format!(
|
||||
"Unable to load the certificates [{}]: {}",
|
||||
certs_path_str, e
|
||||
),
|
||||
)
|
||||
})?);
|
||||
rustls_pemfile::certs(&mut reader).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Unable to parse the certificates",
|
||||
)
|
||||
})?
|
||||
}
|
||||
.drain(..)
|
||||
.map(Certificate)
|
||||
.collect();
|
||||
let certs_keys: Vec<_> = {
|
||||
let certs_keys_path_str = certs_keys_path.as_ref().display().to_string();
|
||||
let encoded_keys = {
|
||||
let mut encoded_keys = vec![];
|
||||
File::open(certs_keys_path)
|
||||
.map_err(|e| {
|
||||
io::Error::new(
|
||||
e.kind(),
|
||||
format!(
|
||||
"Unable to load the certificate keys [{}]: {}",
|
||||
certs_keys_path_str, e
|
||||
),
|
||||
)
|
||||
})?
|
||||
.read_to_end(&mut encoded_keys)?;
|
||||
encoded_keys
|
||||
};
|
||||
let mut reader = Cursor::new(encoded_keys);
|
||||
let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Unable to parse the certificates private keys (PKCS8)",
|
||||
)
|
||||
})?;
|
||||
reader.set_position(0);
|
||||
let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Unable to parse the certificates private keys (RSA)",
|
||||
)
|
||||
})?;
|
||||
let mut keys = pkcs8_keys;
|
||||
keys.append(&mut rsa_keys);
|
||||
if keys.is_empty() {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"No private keys found - Make sure that they are in PKCS#8/PEM format",
|
||||
));
|
||||
}
|
||||
keys.drain(..).map(PrivateKey).collect()
|
||||
};
|
||||
|
||||
let mut server_config = certs_keys
|
||||
.into_iter()
|
||||
.find_map(|certs_key| {
|
||||
let server_config_builder = ServerConfig::builder()
|
||||
.with_safe_defaults()
|
||||
.with_no_client_auth();
|
||||
if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) {
|
||||
Some(found_config)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidInput,
|
||||
"Unable to find a valid certificate and key",
|
||||
)
|
||||
})?;
|
||||
server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||
Ok(TlsAcceptor::from(Arc::new(server_config)))
|
||||
}
|
||||
|
||||
impl PacketAcceptor {
|
||||
async fn start_https_service(
|
||||
self,
|
||||
mut tls_acceptor_receiver: Receiver<TlsAcceptor>,
|
||||
listener: TcpListener,
|
||||
server: Http<LocalExecutor>,
|
||||
) -> Result<()> {
|
||||
let mut tls_acceptor: Option<TlsAcceptor> = None;
|
||||
let listener_service = async {
|
||||
loop {
|
||||
select! {
|
||||
tcp_cnx = listener.accept().fuse() => {
|
||||
if tls_acceptor.is_none() || tcp_cnx.is_err() {
|
||||
continue;
|
||||
}
|
||||
let (raw_stream, _client_addr) = tcp_cnx.unwrap();
|
||||
if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await {
|
||||
self.clone().client_serve(stream, server.clone(), _client_addr).await
|
||||
}
|
||||
}
|
||||
new_tls_acceptor = tls_acceptor_receiver.recv().fuse() => {
|
||||
if new_tls_acceptor.is_none() {
|
||||
break;
|
||||
}
|
||||
tls_acceptor = new_tls_acceptor;
|
||||
}
|
||||
complete => break
|
||||
}
|
||||
}
|
||||
Ok(()) as Result<()>
|
||||
};
|
||||
listener_service.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn start_with_tls(
|
||||
self,
|
||||
listener: TcpListener,
|
||||
server: Http<LocalExecutor>,
|
||||
) -> Result<()> {
|
||||
let certs_path = self.globals.tls_cert_path.as_ref().unwrap().clone();
|
||||
let certs_keys_path = self.globals.tls_cert_key_path.as_ref().unwrap().clone();
|
||||
let (tls_acceptor_sender, tls_acceptor_receiver) = mpsc::channel(1);
|
||||
let https_service = self.start_https_service(tls_acceptor_receiver, listener, server);
|
||||
let cert_service = async {
|
||||
loop {
|
||||
match create_tls_acceptor(&certs_path, &certs_keys_path) {
|
||||
Ok(tls_acceptor) => {
|
||||
if tls_acceptor_sender.send(tls_acceptor).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => eprintln!("TLS certificates error: {}", e),
|
||||
}
|
||||
tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await;
|
||||
}
|
||||
Ok(()) as Result<()>
|
||||
};
|
||||
return join!(https_service, cert_service).0;
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue