remake architecture to handle multiple tls endpoints
This commit is contained in:
parent
634d556ea9
commit
99e6bce992
11 changed files with 490 additions and 556 deletions
|
|
@ -48,6 +48,7 @@ hyper-trust-dns = { version = "0.4.2", default-features = false, features = [
|
||||||
"rustls-webpki",
|
"rustls-webpki",
|
||||||
], optional = true }
|
], optional = true }
|
||||||
hyper-tls = "0.5.0"
|
hyper-tls = "0.5.0"
|
||||||
|
rustls = "0.20.6"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,29 +3,36 @@
|
||||||
# rust-rxpy configuration #
|
# rust-rxpy configuration #
|
||||||
# #
|
# #
|
||||||
########################################
|
########################################
|
||||||
|
###################################
|
||||||
##################################
|
# Global settings #
|
||||||
# Global settings #
|
###################################
|
||||||
##################################
|
http_port = 8080
|
||||||
|
https_port = 8443
|
||||||
## Address to listen to.
|
|
||||||
listen_addresses = ['127.0.0.1:50844', '[::1]:50844']
|
|
||||||
|
|
||||||
[tls]
|
|
||||||
tls_cert_path = 'localhost.pem'
|
|
||||||
tls_cert_key_path = 'localhost.pem'
|
|
||||||
|
|
||||||
###################################
|
###################################
|
||||||
# Backend settings #
|
# Backend settings #
|
||||||
###################################
|
###################################
|
||||||
[[backend]]
|
|
||||||
domain = 'localhost'
|
|
||||||
## List of destinations to send data to.
|
|
||||||
## At this point, round-robin is used for load-balancing if multiple URLs are specified.
|
|
||||||
destination = ['http://192.168.0.1:3000/', 'https://192.168.0.2:3000']
|
|
||||||
allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24']
|
|
||||||
denyhosts = ['*']
|
|
||||||
|
|
||||||
[[backend]]
|
[[backend]]
|
||||||
domain = '127.0.0.1'
|
app_name = 'localhost' # this should be option, if null then same as hostname
|
||||||
destination = 'https://www.google.com/'
|
hostname = 'localhost'
|
||||||
|
redirect_to_https = true
|
||||||
|
reverse_proxy = [
|
||||||
|
{ path = '*', destination = 'https://192.168.10.0:3000' },
|
||||||
|
{ path = '/path/to', destination = 'https://192.168.10.1:4000/path/to' },
|
||||||
|
]
|
||||||
|
## List of destinations to send data to.
|
||||||
|
## At this point, round-robin is used for load-balancing if multiple URLs are specified.
|
||||||
|
allowhosts = ['127.0.0.1', '::1', '192.168.10.0/24']
|
||||||
|
denyhosts = ['*']
|
||||||
|
tls_cert_path = 'localhost1.pem'
|
||||||
|
tls_cert_key_path = 'localhost1.pem'
|
||||||
|
|
||||||
|
|
||||||
|
[[backend]]
|
||||||
|
app_name = 'locahost_application'
|
||||||
|
hostname = 'localhost.localdomain'
|
||||||
|
redirect_to_https = true
|
||||||
|
reverse_proxy = [{ path = '/', destination = 'https://www.google.com/' }]
|
||||||
|
tls_cert_path = 'localhost2.pem'
|
||||||
|
tls_cert_key_path = 'localhost2.pem'
|
||||||
|
|
|
||||||
289
src/acceptor.rs
289
src/acceptor.rs
|
|
@ -1,289 +0,0 @@
|
||||||
use crate::{error::*, globals::Globals, log::*};
|
|
||||||
use futures::{
|
|
||||||
task::{Context, Poll},
|
|
||||||
Future,
|
|
||||||
};
|
|
||||||
use hyper::{
|
|
||||||
client::connect::Connect,
|
|
||||||
http,
|
|
||||||
server::conn::Http,
|
|
||||||
service::{service_fn, Service},
|
|
||||||
Body, Client, HeaderMap, Method, Request, Response, StatusCode,
|
|
||||||
};
|
|
||||||
use std::{net::SocketAddr, pin::Pin, sync::Arc};
|
|
||||||
use tokio::{
|
|
||||||
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
|
||||||
net::TcpListener,
|
|
||||||
runtime::Handle,
|
|
||||||
time::Duration,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[allow(clippy::unnecessary_wraps)]
|
|
||||||
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
|
|
||||||
let response = Response::builder()
|
|
||||||
.status(status_code)
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
Ok(response)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug)]
|
|
||||||
pub struct LocalExecutor {
|
|
||||||
runtime_handle: Handle,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl LocalExecutor {
|
|
||||||
fn new(runtime_handle: Handle) -> Self {
|
|
||||||
LocalExecutor { runtime_handle }
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<F> hyper::rt::Executor<F> for LocalExecutor
|
|
||||||
where
|
|
||||||
F: std::future::Future + Send + 'static,
|
|
||||||
F::Output: Send,
|
|
||||||
{
|
|
||||||
fn execute(&self, fut: F) {
|
|
||||||
self.runtime_handle.spawn(fut);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub struct PacketAcceptor<T>
|
|
||||||
where
|
|
||||||
T: Connect + Clone + Sync + Send + 'static,
|
|
||||||
{
|
|
||||||
pub listening_on: SocketAddr,
|
|
||||||
pub forwarder: Arc<Client<T>>,
|
|
||||||
pub globals: Arc<Globals>,
|
|
||||||
}
|
|
||||||
|
|
||||||
// impl<T> Service<http::Request<Body>> for PacketAcceptor<T>
|
|
||||||
// where
|
|
||||||
// T: Connect + Clone + Sync + Send + 'static,
|
|
||||||
// {
|
|
||||||
// type Response = Response<Body>;
|
|
||||||
|
|
||||||
// type Error = http::Error;
|
|
||||||
// type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
|
||||||
|
|
||||||
// fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
|
||||||
// Poll::Ready(Ok(()))
|
|
||||||
// }
|
|
||||||
|
|
||||||
// fn call(&mut self, req: Request<Body>) -> Self::Future {
|
|
||||||
// debug!(
|
|
||||||
// "serving {:?} {:?} request to {:?}",
|
|
||||||
// req.version(),
|
|
||||||
// req.method(),
|
|
||||||
// req.uri()
|
|
||||||
// );
|
|
||||||
// let self_inner = self.clone();
|
|
||||||
|
|
||||||
// // 1. check uri (domain queried host name)
|
|
||||||
// // 2. build uri to forwarding target destination
|
|
||||||
// // 3. build request from uri and body
|
|
||||||
// // 4. send request to forwarding target
|
|
||||||
|
|
||||||
// if *req.method() == Method::GET {
|
|
||||||
// Box::pin(async move {
|
|
||||||
// // let uri = req.uri();
|
|
||||||
// let target_uri = hyper::Uri::builder()
|
|
||||||
// .scheme("https")
|
|
||||||
// .authority("www.google.com")
|
|
||||||
// .path_and_query("/")
|
|
||||||
// .build()
|
|
||||||
// .unwrap();
|
|
||||||
// println!("{:?}", target_uri);
|
|
||||||
// match self_inner.forwarder.get(target_uri).await {
|
|
||||||
// Ok(res) => Ok(res),
|
|
||||||
// Err(e) => {
|
|
||||||
// error!("{:?}", e);
|
|
||||||
// http_error(StatusCode::INTERNAL_SERVER_ERROR)
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// })
|
|
||||||
// } else {
|
|
||||||
// // let globals = &self.doh.globals;
|
|
||||||
// // let self_inner = self.clone();
|
|
||||||
// // if req.uri().path() == globals.path {
|
|
||||||
// // Box::pin(async move {
|
|
||||||
// // let mut subscriber = None;
|
|
||||||
// // if self_inner.doh.globals.enable_auth_target {
|
|
||||||
// // subscriber = match auth::authenticate(
|
|
||||||
// // &self_inner.doh.globals,
|
|
||||||
// // &req,
|
|
||||||
// // ValidationLocation::Target,
|
|
||||||
// // &self_inner.peer_addr,
|
|
||||||
// // ) {
|
|
||||||
// // Ok((sub, aud)) => {
|
|
||||||
// // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud);
|
|
||||||
// // sub
|
|
||||||
// // }
|
|
||||||
// // Err(e) => {
|
|
||||||
// // error!("{:?}", e);
|
|
||||||
// // return Ok(e);
|
|
||||||
// // }
|
|
||||||
// // };
|
|
||||||
// // }
|
|
||||||
// // match *req.method() {
|
|
||||||
// // Method::POST => self_inner.doh.serve_post(req, subscriber).await,
|
|
||||||
// // Method::GET => self_inner.doh.serve_get(req, subscriber).await,
|
|
||||||
// // _ => http_error(StatusCode::METHOD_NOT_ALLOWED),
|
|
||||||
// // }
|
|
||||||
// // })
|
|
||||||
// // } else if req.uri().path() == globals.odoh_configs_path {
|
|
||||||
// // match *req.method() {
|
|
||||||
// // Method::GET => Box::pin(async move { self_inner.doh.serve_odoh_configs().await }),
|
|
||||||
// // _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
|
|
||||||
// // }
|
|
||||||
// // } else {
|
|
||||||
// // #[cfg(not(feature = "odoh-proxy"))]
|
|
||||||
// // {
|
|
||||||
// // Box::pin(async { http_error(StatusCode::NOT_FOUND) })
|
|
||||||
// // }
|
|
||||||
// // #[cfg(feature = "odoh-proxy")]
|
|
||||||
// // {
|
|
||||||
// // if req.uri().path() == globals.odoh_proxy_path {
|
|
||||||
// // Box::pin(async move {
|
|
||||||
// // let mut subscriber = None;
|
|
||||||
// // if self_inner.doh.globals.enable_auth_proxy {
|
|
||||||
// // subscriber = match auth::authenticate(
|
|
||||||
// // &self_inner.doh.globals,
|
|
||||||
// // &req,
|
|
||||||
// // ValidationLocation::Proxy,
|
|
||||||
// // &self_inner.peer_addr,
|
|
||||||
// // ) {
|
|
||||||
// // Ok((sub, aud)) => {
|
|
||||||
// // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud);
|
|
||||||
// // sub
|
|
||||||
// // }
|
|
||||||
// // Err(e) => {
|
|
||||||
// // error!("{:?}", e);
|
|
||||||
// // return Ok(e);
|
|
||||||
// // }
|
|
||||||
// // };
|
|
||||||
// // }
|
|
||||||
// // // Draft: https://datatracker.ietf.org/doc/html/draft-pauly-dprive-oblivious-doh-11
|
|
||||||
// // // Golang impl.: https://github.com/cloudflare/odoh-server-go
|
|
||||||
// // // Based on the draft and Golang implementation, only post method is allowed.
|
|
||||||
// // match *req.method() {
|
|
||||||
// // Method::POST => self_inner.doh.serve_odoh_proxy_post(req, subscriber).await,
|
|
||||||
// // _ => http_error(StatusCode::METHOD_NOT_ALLOWED),
|
|
||||||
// // }
|
|
||||||
// // })
|
|
||||||
// // } else {
|
|
||||||
// Box::pin(async { http_error(StatusCode::NOT_FOUND) })
|
|
||||||
// }
|
|
||||||
// // }
|
|
||||||
// // }
|
|
||||||
// // }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
async fn handle_request(
|
|
||||||
req: Request<Body>,
|
|
||||||
client_ip: SocketAddr,
|
|
||||||
globals: Arc<Globals>,
|
|
||||||
) -> Result<Response<Body>, http::Error> {
|
|
||||||
// http_error(StatusCode::NOT_FOUND)
|
|
||||||
debug!("{:?}", req);
|
|
||||||
// if req.version() == hyper::Version::HTTP_11 {
|
|
||||||
// Ok(Response::new(Body::from("Hello World")))
|
|
||||||
// } else {
|
|
||||||
// Note: it's usually better to return a Response
|
|
||||||
// with an appropriate StatusCode instead of an Err.
|
|
||||||
// Err("not HTTP/1.1, abort connection")
|
|
||||||
http_error(StatusCode::NOT_FOUND)
|
|
||||||
// }
|
|
||||||
// });
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> PacketAcceptor<T>
|
|
||||||
where
|
|
||||||
T: Connect + Clone + Sync + Send + 'static,
|
|
||||||
{
|
|
||||||
pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr)
|
|
||||||
where
|
|
||||||
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
|
||||||
{
|
|
||||||
let clients_count = self.globals.clients_count.clone();
|
|
||||||
if clients_count.increment() > self.globals.max_clients {
|
|
||||||
clients_count.decrement();
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
self.globals.runtime_handle.clone().spawn(async move {
|
|
||||||
tokio::time::timeout(
|
|
||||||
self.globals.timeout + Duration::from_secs(1),
|
|
||||||
// server.serve_connection(stream, self),
|
|
||||||
server.serve_connection(
|
|
||||||
stream,
|
|
||||||
service_fn(move |req: Request<Body>| {
|
|
||||||
handle_request(req, peer_addr, self.globals.clone())
|
|
||||||
}),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.ok();
|
|
||||||
|
|
||||||
clients_count.decrement();
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn start_without_tls(
|
|
||||||
self,
|
|
||||||
listener: TcpListener,
|
|
||||||
server: Http<LocalExecutor>,
|
|
||||||
) -> Result<()> {
|
|
||||||
let listener_service = async {
|
|
||||||
while let Ok((stream, _client_addr)) = listener.accept().await {
|
|
||||||
self
|
|
||||||
.clone()
|
|
||||||
.client_serve(stream, server.clone(), _client_addr)
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
Ok(()) as Result<()>
|
|
||||||
};
|
|
||||||
listener_service.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn start(self) -> Result<()> {
|
|
||||||
let tcp_listener = TcpListener::bind(&self.listening_on).await?;
|
|
||||||
|
|
||||||
let mut server = Http::new();
|
|
||||||
server.http1_keep_alive(self.globals.keepalive);
|
|
||||||
server.http2_max_concurrent_streams(self.globals.max_concurrent_streams);
|
|
||||||
server.pipeline_flush(true);
|
|
||||||
let executor = LocalExecutor::new(self.globals.runtime_handle.clone());
|
|
||||||
let server = server.with_executor(executor);
|
|
||||||
|
|
||||||
let tls_enabled: bool;
|
|
||||||
#[cfg(not(feature = "tls"))]
|
|
||||||
{
|
|
||||||
tls_enabled = false;
|
|
||||||
}
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
{
|
|
||||||
tls_enabled =
|
|
||||||
self.globals.tls_cert_path.is_some() && self.globals.tls_cert_key_path.is_some();
|
|
||||||
}
|
|
||||||
if tls_enabled {
|
|
||||||
info!(
|
|
||||||
"Start server listening on TCP with TLS: {:?}",
|
|
||||||
tcp_listener.local_addr()?
|
|
||||||
);
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
self.start_with_tls(tcp_listener, server).await?;
|
|
||||||
} else {
|
|
||||||
info!(
|
|
||||||
"Start server listening on TCP: {:?}",
|
|
||||||
tcp_listener.local_addr()?
|
|
||||||
);
|
|
||||||
self.start_without_tls(tcp_listener, server).await?;
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
134
src/backend.rs
Normal file
134
src/backend.rs
Normal file
|
|
@ -0,0 +1,134 @@
|
||||||
|
use crate::log::*;
|
||||||
|
use std::{
|
||||||
|
collections::HashMap,
|
||||||
|
fs::File,
|
||||||
|
io::{self, BufReader, Cursor, Read},
|
||||||
|
path::PathBuf,
|
||||||
|
sync::Mutex,
|
||||||
|
};
|
||||||
|
use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig};
|
||||||
|
|
||||||
|
pub struct Backend {
|
||||||
|
pub app_name: String,
|
||||||
|
pub hostname: String,
|
||||||
|
pub reverse_proxy: ReverseProxy,
|
||||||
|
pub redirect_to_https: Option<bool>,
|
||||||
|
pub tls_cert_path: Option<PathBuf>,
|
||||||
|
pub tls_cert_key_path: Option<PathBuf>,
|
||||||
|
pub server_config: Mutex<Option<ServerConfig>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ReverseProxy {
|
||||||
|
pub default_destination_uri: hyper::Uri,
|
||||||
|
pub destination_uris: Option<HashMap<String, hyper::Uri>>, // TODO: url pathで引っ掛ける。
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Backend {
|
||||||
|
pub fn get_tls_server_config(&self) -> Option<ServerConfig> {
|
||||||
|
let lock = self.server_config.lock();
|
||||||
|
if let Ok(opt) = lock {
|
||||||
|
let opt_clone = opt.clone();
|
||||||
|
if let Some(sc) = opt_clone {
|
||||||
|
return Some(sc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
pub async fn update_server_config(&self) -> io::Result<()> {
|
||||||
|
debug!("Update TLS server config");
|
||||||
|
let certs_path = self.tls_cert_path.as_ref().unwrap();
|
||||||
|
let certs_keys_path = self.tls_cert_key_path.as_ref().unwrap();
|
||||||
|
let certs: Vec<_> = {
|
||||||
|
let certs_path_str = certs_path.display().to_string();
|
||||||
|
let mut reader = BufReader::new(File::open(certs_path).map_err(|e| {
|
||||||
|
io::Error::new(
|
||||||
|
e.kind(),
|
||||||
|
format!(
|
||||||
|
"Unable to load the certificates [{}]: {}",
|
||||||
|
certs_path_str, e
|
||||||
|
),
|
||||||
|
)
|
||||||
|
})?);
|
||||||
|
rustls_pemfile::certs(&mut reader).map_err(|_| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"Unable to parse the certificates",
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
.drain(..)
|
||||||
|
.map(Certificate)
|
||||||
|
.collect();
|
||||||
|
let certs_keys: Vec<_> = {
|
||||||
|
let certs_keys_path_str = certs_keys_path.display().to_string();
|
||||||
|
let encoded_keys = {
|
||||||
|
let mut encoded_keys = vec![];
|
||||||
|
File::open(certs_keys_path)
|
||||||
|
.map_err(|e| {
|
||||||
|
io::Error::new(
|
||||||
|
e.kind(),
|
||||||
|
format!(
|
||||||
|
"Unable to load the certificate keys [{}]: {}",
|
||||||
|
certs_keys_path_str, e
|
||||||
|
),
|
||||||
|
)
|
||||||
|
})?
|
||||||
|
.read_to_end(&mut encoded_keys)?;
|
||||||
|
encoded_keys
|
||||||
|
};
|
||||||
|
let mut reader = Cursor::new(encoded_keys);
|
||||||
|
let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"Unable to parse the certificates private keys (PKCS8)",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
reader.set_position(0);
|
||||||
|
let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"Unable to parse the certificates private keys (RSA)",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
let mut keys = pkcs8_keys;
|
||||||
|
keys.append(&mut rsa_keys);
|
||||||
|
if keys.is_empty() {
|
||||||
|
return Err(io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"No private keys found - Make sure that they are in PKCS#8/PEM format",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
keys.drain(..).map(PrivateKey).collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut server_config = certs_keys
|
||||||
|
.into_iter()
|
||||||
|
.find_map(|certs_key| {
|
||||||
|
let server_config_builder = ServerConfig::builder()
|
||||||
|
.with_safe_defaults()
|
||||||
|
.with_no_client_auth();
|
||||||
|
if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) {
|
||||||
|
Some(found_config)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.ok_or_else(|| {
|
||||||
|
io::Error::new(
|
||||||
|
io::ErrorKind::InvalidInput,
|
||||||
|
"Unable to find a valid certificate and key",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
||||||
|
|
||||||
|
if let Ok(mut config_store) = self.server_config.lock() {
|
||||||
|
*config_store = Some(server_config);
|
||||||
|
} else {
|
||||||
|
error!("Some thing wrong to write into mutex")
|
||||||
|
}
|
||||||
|
|
||||||
|
// server_config;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,13 +1,45 @@
|
||||||
use crate::globals::Globals;
|
use crate::{backend::*, constants::*, globals::*};
|
||||||
|
use hyper::Uri;
|
||||||
|
use std::{collections::HashMap, sync::Mutex};
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
// #[cfg(feature = "tls")]
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
pub fn parse_opts(globals: &mut Globals) {
|
pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend>) {
|
||||||
#[cfg(feature = "tls")]
|
// TODO:
|
||||||
{
|
globals.listen_sockets = LISTEN_ADDRESSES
|
||||||
// TODO:
|
.to_vec()
|
||||||
globals.tls_cert_path = Some(PathBuf::from(r"localhost.pem"));
|
.iter()
|
||||||
globals.tls_cert_key_path = Some(PathBuf::from(r"localhost.pem"));
|
.flat_map(|x| {
|
||||||
}
|
vec![
|
||||||
|
format!("{}:{}", x, HTTP_LISTEN_PORT).parse().unwrap(),
|
||||||
|
format!("{}:{}", x, HTTPS_LISTEN_PORT).parse().unwrap(),
|
||||||
|
]
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
globals.http_port = Some(HTTP_LISTEN_PORT);
|
||||||
|
globals.https_port = Some(HTTPS_LISTEN_PORT);
|
||||||
|
|
||||||
|
// TODO:
|
||||||
|
let mut map_example: HashMap<String, Uri> = HashMap::new();
|
||||||
|
map_example.insert(
|
||||||
|
"/maps".to_string(),
|
||||||
|
"https://bing.com/".parse::<Uri>().unwrap(),
|
||||||
|
);
|
||||||
|
backends.insert(
|
||||||
|
"localhost".to_string(),
|
||||||
|
Backend {
|
||||||
|
app_name: "Google except for maps".to_string(),
|
||||||
|
hostname: "google.com".to_string(),
|
||||||
|
reverse_proxy: ReverseProxy {
|
||||||
|
default_destination_uri: "https://google.com/".parse::<Uri>().unwrap(),
|
||||||
|
destination_uris: Some(map_example),
|
||||||
|
},
|
||||||
|
redirect_to_https: None, // TODO: ここはHTTPの時のみの設定。tlsの存在とは排他的。
|
||||||
|
|
||||||
|
tls_cert_path: Some(PathBuf::from(r"localhost1.pem")),
|
||||||
|
tls_cert_key_path: Some(PathBuf::from(r"localhost1.pem")),
|
||||||
|
server_config: Mutex::new(None),
|
||||||
|
},
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
pub const LISTEN_ADDRESSES: &[&str] = &["127.0.0.1:8443", "[::1]:8443"];
|
pub const LISTEN_ADDRESSES: &[&str] = &["0.0.0.0", "[::]"];
|
||||||
|
pub const HTTP_LISTEN_PORT: u32 = 8080;
|
||||||
|
pub const HTTPS_LISTEN_PORT: u32 = 8443;
|
||||||
pub const TIMEOUT_SEC: u64 = 10;
|
pub const TIMEOUT_SEC: u64 = 10;
|
||||||
pub const MAX_CLIENTS: usize = 512;
|
pub const MAX_CLIENTS: usize = 512;
|
||||||
pub const MAX_CONCURRENT_STREAMS: u32 = 16;
|
pub const MAX_CONCURRENT_STREAMS: u32 = 16;
|
||||||
#[cfg(feature = "tls")]
|
// #[cfg(feature = "tls")]
|
||||||
pub const CERTS_WATCH_DELAY_SECS: u32 = 10;
|
pub const CERTS_WATCH_DELAY_SECS: u32 = 10;
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,4 @@
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
use std::path::PathBuf;
|
|
||||||
use std::sync::{
|
use std::sync::{
|
||||||
atomic::{AtomicUsize, Ordering},
|
atomic::{AtomicUsize, Ordering},
|
||||||
Arc,
|
Arc,
|
||||||
|
|
@ -9,7 +7,9 @@ use tokio::time::Duration;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Globals {
|
pub struct Globals {
|
||||||
pub listen_addresses: Vec<SocketAddr>,
|
pub listen_sockets: Vec<SocketAddr>,
|
||||||
|
pub http_port: Option<u32>,
|
||||||
|
pub https_port: Option<u32>,
|
||||||
|
|
||||||
pub timeout: Duration,
|
pub timeout: Duration,
|
||||||
pub max_clients: usize,
|
pub max_clients: usize,
|
||||||
|
|
@ -18,12 +18,6 @@ pub struct Globals {
|
||||||
pub keepalive: bool,
|
pub keepalive: bool,
|
||||||
|
|
||||||
pub runtime_handle: tokio::runtime::Handle,
|
pub runtime_handle: tokio::runtime::Handle,
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
pub tls_cert_path: Option<PathBuf>,
|
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
pub tls_cert_key_path: Option<PathBuf>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Default)]
|
#[derive(Debug, Clone, Default)]
|
||||||
|
|
|
||||||
77
src/main.rs
77
src/main.rs
|
|
@ -1,18 +1,23 @@
|
||||||
#[global_allocator]
|
#[global_allocator]
|
||||||
static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
static ALLOC: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||||
|
|
||||||
mod acceptor;
|
mod backend;
|
||||||
mod config;
|
mod config;
|
||||||
mod constants;
|
mod constants;
|
||||||
mod error;
|
mod error;
|
||||||
mod globals;
|
mod globals;
|
||||||
mod log;
|
mod log;
|
||||||
mod proxy;
|
mod proxy;
|
||||||
#[cfg(feature = "tls")]
|
mod proxy_tls;
|
||||||
mod tls;
|
|
||||||
|
|
||||||
use crate::{config::parse_opts, constants::*, globals::Globals, log::*, proxy::Proxy};
|
use crate::{
|
||||||
use std::{io::Write, sync::Arc};
|
backend::Backend, config::parse_opts, constants::*, error::*, globals::*, log::*, proxy::Proxy,
|
||||||
|
};
|
||||||
|
use futures::future::select_all;
|
||||||
|
use hyper::Client;
|
||||||
|
#[cfg(feature = "forward-hyper-trust-dns")]
|
||||||
|
use hyper_trust_dns::TrustDnsResolver;
|
||||||
|
use std::{collections::HashMap, io::Write, sync::Arc};
|
||||||
use tokio::time::Duration;
|
use tokio::time::Duration;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
|
@ -39,35 +44,61 @@ fn main() {
|
||||||
runtime_builder.thread_name("rust-rpxy");
|
runtime_builder.thread_name("rust-rpxy");
|
||||||
let runtime = runtime_builder.build().unwrap();
|
let runtime = runtime_builder.build().unwrap();
|
||||||
|
|
||||||
// TODO:
|
|
||||||
let listen_addresses: Vec<std::net::SocketAddr> = LISTEN_ADDRESSES
|
|
||||||
.to_vec()
|
|
||||||
.iter()
|
|
||||||
.map(|x| x.parse().unwrap())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
runtime.block_on(async {
|
runtime.block_on(async {
|
||||||
let mut globals = Globals {
|
let mut globals = Globals {
|
||||||
listen_addresses,
|
listen_sockets: Vec::new(),
|
||||||
|
http_port: None,
|
||||||
|
https_port: None,
|
||||||
timeout: Duration::from_secs(TIMEOUT_SEC),
|
timeout: Duration::from_secs(TIMEOUT_SEC),
|
||||||
max_clients: MAX_CLIENTS,
|
max_clients: MAX_CLIENTS,
|
||||||
clients_count: Default::default(),
|
clients_count: Default::default(),
|
||||||
max_concurrent_streams: MAX_CONCURRENT_STREAMS,
|
max_concurrent_streams: MAX_CONCURRENT_STREAMS,
|
||||||
keepalive: true,
|
keepalive: true,
|
||||||
runtime_handle: runtime.handle().clone(),
|
runtime_handle: runtime.handle().clone(),
|
||||||
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
tls_cert_path: None,
|
|
||||||
#[cfg(feature = "tls")]
|
|
||||||
tls_cert_key_path: None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
parse_opts(&mut globals);
|
let mut backends: HashMap<String, Backend> = HashMap::new();
|
||||||
|
|
||||||
let proxy = Proxy {
|
parse_opts(&mut globals, &mut backends);
|
||||||
globals: Arc::new(globals),
|
|
||||||
};
|
entrypoint(Arc::new(globals), Arc::new(backends))
|
||||||
proxy.entrypoint().await.unwrap()
|
.await
|
||||||
|
.unwrap()
|
||||||
});
|
});
|
||||||
warn!("Exit the program");
|
warn!("Exit the program");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// entrypoint creates and spawns tasks of proxy services
|
||||||
|
async fn entrypoint(globals: Arc<Globals>, backends: Arc<HashMap<String, Backend>>) -> Result<()> {
|
||||||
|
#[cfg(feature = "forward-hyper-trust-dns")]
|
||||||
|
let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector();
|
||||||
|
#[cfg(not(feature = "forward-hyper-trust-dns"))]
|
||||||
|
let connector = hyper_tls::HttpsConnector::new();
|
||||||
|
let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector));
|
||||||
|
|
||||||
|
let addresses = globals.listen_sockets.clone();
|
||||||
|
let futures = select_all(addresses.into_iter().map(|addr| {
|
||||||
|
let mut tls_enabled = false;
|
||||||
|
if let Some(https_port) = globals.https_port {
|
||||||
|
tls_enabled = https_port == (addr.port() as u32)
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Listen address: {:?} (TLS = {})", addr, tls_enabled);
|
||||||
|
|
||||||
|
let proxy = Proxy {
|
||||||
|
globals: globals.clone(),
|
||||||
|
listening_on: addr,
|
||||||
|
tls_enabled,
|
||||||
|
backends: backends.clone(),
|
||||||
|
forwarder: forwarder.clone(),
|
||||||
|
};
|
||||||
|
globals.runtime_handle.spawn(proxy.start())
|
||||||
|
}));
|
||||||
|
|
||||||
|
// wait for all future
|
||||||
|
if let (Ok(_), _, _) = futures.await {
|
||||||
|
error!("Some proxy services are down");
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
|
||||||
177
src/proxy.rs
177
src/proxy.rs
|
|
@ -1,38 +1,159 @@
|
||||||
use crate::{acceptor::PacketAcceptor, error::*, globals::Globals, log::*};
|
use crate::{backend::Backend, error::*, globals::Globals, log::*};
|
||||||
use futures::future::select_all;
|
use futures::{
|
||||||
use hyper::Client;
|
select,
|
||||||
#[cfg(feature = "forward-hyper-trust-dns")]
|
task::{Context, Poll},
|
||||||
use hyper_trust_dns::TrustDnsResolver;
|
Future, FutureExt,
|
||||||
use std::sync::Arc;
|
};
|
||||||
|
use hyper::{
|
||||||
|
client::connect::Connect,
|
||||||
|
http,
|
||||||
|
server::conn::Http,
|
||||||
|
service::{service_fn, Service},
|
||||||
|
Body, Client, HeaderMap, Method, Request, Response, StatusCode,
|
||||||
|
};
|
||||||
|
use std::{collections::HashMap, net::SocketAddr, pin::Pin, sync::Arc};
|
||||||
|
use tokio::{
|
||||||
|
io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
|
||||||
|
net::TcpListener,
|
||||||
|
runtime::Handle,
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[allow(clippy::unnecessary_wraps)]
|
||||||
pub struct Proxy {
|
fn http_error(status_code: StatusCode) -> Result<Response<Body>, http::Error> {
|
||||||
|
let response = Response::builder()
|
||||||
|
.status(status_code)
|
||||||
|
.body(Body::empty())
|
||||||
|
.unwrap();
|
||||||
|
Ok(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct LocalExecutor {
|
||||||
|
runtime_handle: Handle,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LocalExecutor {
|
||||||
|
fn new(runtime_handle: Handle) -> Self {
|
||||||
|
LocalExecutor { runtime_handle }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> hyper::rt::Executor<F> for LocalExecutor
|
||||||
|
where
|
||||||
|
F: std::future::Future + Send + 'static,
|
||||||
|
F::Output: Send,
|
||||||
|
{
|
||||||
|
fn execute(&self, fut: F) {
|
||||||
|
self.runtime_handle.spawn(fut);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Proxy<T>
|
||||||
|
where
|
||||||
|
T: Connect + Clone + Sync + Send + 'static,
|
||||||
|
{
|
||||||
|
pub listening_on: SocketAddr,
|
||||||
|
pub tls_enabled: bool, // TCP待受がTLSかどうか
|
||||||
|
pub backends: Arc<HashMap<String, Backend>>, // TODO: hyper::uriで抜いたhostnameで引っ掛ける。Stringでいいのか?
|
||||||
|
pub forwarder: Arc<Client<T>>,
|
||||||
pub globals: Arc<Globals>,
|
pub globals: Arc<Globals>,
|
||||||
}
|
}
|
||||||
impl Proxy {
|
|
||||||
pub async fn entrypoint(self) -> Result<()> {
|
|
||||||
let addresses = self.globals.listen_addresses.clone();
|
|
||||||
let futures = select_all(addresses.into_iter().map(|addr| {
|
|
||||||
info!("Listen address: {:?}", addr);
|
|
||||||
|
|
||||||
#[cfg(feature = "forward-hyper-trust-dns")]
|
// TODO: ここでbackendの名前単位でリクエストを分岐させる
|
||||||
let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector();
|
async fn handle_request(
|
||||||
#[cfg(not(feature = "forward-hyper-trust-dns"))]
|
req: Request<Body>,
|
||||||
let connector = hyper_tls::HttpsConnector::new();
|
client_ip: SocketAddr,
|
||||||
let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector));
|
globals: Arc<Globals>,
|
||||||
|
) -> Result<Response<Body>, http::Error> {
|
||||||
|
// http_error(StatusCode::NOT_FOUND)
|
||||||
|
debug!("{:?}", req);
|
||||||
|
// if req.version() == hyper::Version::HTTP_11 {
|
||||||
|
// Ok(Response::new(Body::from("Hello World")))
|
||||||
|
// } else {
|
||||||
|
// Note: it's usually better to return a Response
|
||||||
|
// with an appropriate StatusCode instead of an Err.
|
||||||
|
// Err("not HTTP/1.1, abort connection")
|
||||||
|
http_error(StatusCode::NOT_FOUND)
|
||||||
|
// }
|
||||||
|
// });
|
||||||
|
}
|
||||||
|
|
||||||
let acceptor = PacketAcceptor {
|
impl<T> Proxy<T>
|
||||||
listening_on: addr,
|
where
|
||||||
globals: self.globals.clone(),
|
T: Connect + Clone + Sync + Send + 'static,
|
||||||
forwarder,
|
{
|
||||||
};
|
pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr)
|
||||||
self.globals.runtime_handle.spawn(acceptor.start())
|
where
|
||||||
}));
|
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
|
||||||
|
{
|
||||||
|
let clients_count = self.globals.clients_count.clone();
|
||||||
|
if clients_count.increment() > self.globals.max_clients {
|
||||||
|
clients_count.decrement();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
// wait for all future
|
self.globals.runtime_handle.clone().spawn(async move {
|
||||||
if let (Ok(_), _, _) = futures.await {
|
tokio::time::timeout(
|
||||||
error!("Some packet acceptors are down");
|
self.globals.timeout + Duration::from_secs(1),
|
||||||
|
// server.serve_connection(stream, self),
|
||||||
|
server.serve_connection(
|
||||||
|
stream,
|
||||||
|
service_fn(move |req: Request<Body>| {
|
||||||
|
handle_request(req, peer_addr, self.globals.clone())
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.ok();
|
||||||
|
|
||||||
|
clients_count.decrement();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn start_without_tls(
|
||||||
|
self,
|
||||||
|
listener: TcpListener,
|
||||||
|
server: Http<LocalExecutor>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let listener_service = async {
|
||||||
|
while let Ok((stream, _client_addr)) = listener.accept().await {
|
||||||
|
self
|
||||||
|
.clone()
|
||||||
|
.client_serve(stream, server.clone(), _client_addr)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
Ok(()) as Result<()>
|
||||||
};
|
};
|
||||||
|
listener_service.await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start(self) -> Result<()> {
|
||||||
|
let tcp_listener = TcpListener::bind(&self.listening_on).await?;
|
||||||
|
|
||||||
|
let mut server = Http::new();
|
||||||
|
server.http1_keep_alive(self.globals.keepalive);
|
||||||
|
server.http2_max_concurrent_streams(self.globals.max_concurrent_streams);
|
||||||
|
server.pipeline_flush(true);
|
||||||
|
let executor = LocalExecutor::new(self.globals.runtime_handle.clone());
|
||||||
|
let server = server.with_executor(executor);
|
||||||
|
|
||||||
|
if self.tls_enabled {
|
||||||
|
info!(
|
||||||
|
"Start TCP proxy serving with HTTPS request for configured host names: {:?}",
|
||||||
|
tcp_listener.local_addr()?
|
||||||
|
);
|
||||||
|
// #[cfg(feature = "tls")]
|
||||||
|
self.start_with_tls(tcp_listener, server).await?;
|
||||||
|
} else {
|
||||||
|
info!(
|
||||||
|
"Start TCP proxy serving with HTTP request for configured host names: {:?}",
|
||||||
|
tcp_listener.local_addr()?
|
||||||
|
);
|
||||||
|
self.start_without_tls(tcp_listener, server).await?;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
77
src/proxy_tls.rs
Normal file
77
src/proxy_tls.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
||||||
|
use crate::{
|
||||||
|
constants::CERTS_WATCH_DELAY_SECS,
|
||||||
|
error::*,
|
||||||
|
log::*,
|
||||||
|
proxy::{LocalExecutor, Proxy},
|
||||||
|
};
|
||||||
|
use futures::{future::FutureExt, join, select};
|
||||||
|
use hyper::{client::connect::Connect, server::conn::Http};
|
||||||
|
use std::{sync::Arc, time::Duration};
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
impl<T> Proxy<T>
|
||||||
|
where
|
||||||
|
T: Connect + Clone + Sync + Send + 'static,
|
||||||
|
{
|
||||||
|
pub async fn start_with_tls(
|
||||||
|
self,
|
||||||
|
listener: TcpListener,
|
||||||
|
server: Http<LocalExecutor>,
|
||||||
|
) -> Result<()> {
|
||||||
|
let cert_service = async {
|
||||||
|
info!("Start cert watch service for {}", self.listening_on);
|
||||||
|
loop {
|
||||||
|
for (hostname, backend) in self.backends.iter() {
|
||||||
|
if backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() {
|
||||||
|
if let Err(_e) = backend.update_server_config().await {
|
||||||
|
warn!("Failed to update certs for {}", hostname);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let listener_service = async {
|
||||||
|
loop {
|
||||||
|
select! {
|
||||||
|
tcp_cnx = listener.accept().fuse() => {
|
||||||
|
if tcp_cnx.is_err() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let (raw_stream, _client_addr) = tcp_cnx.unwrap();
|
||||||
|
|
||||||
|
// First check SNI
|
||||||
|
let rustls_acceptor = rustls::server::Acceptor::new().unwrap();
|
||||||
|
let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls_acceptor, raw_stream);
|
||||||
|
let start = acceptor.await.unwrap();
|
||||||
|
let client_hello = start.client_hello();
|
||||||
|
debug!("SNI in ClientHello: {:?}", client_hello.server_name());
|
||||||
|
// Find server config for given SNI
|
||||||
|
let svn = if let Some(svn) = client_hello.server_name() {
|
||||||
|
svn
|
||||||
|
} else {
|
||||||
|
info!("No SNI in ClientHello");
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let backend_serve = if let Some(backend_serve) = self.backends.get(svn){
|
||||||
|
backend_serve
|
||||||
|
} else {
|
||||||
|
info!("No configuration for the server name {} given in client_hello", svn);
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
let server_config = backend_serve.get_tls_server_config();
|
||||||
|
// Finally serve the TLS connection
|
||||||
|
if let Ok(stream) = start.into_stream(Arc::new(server_config.unwrap())).await {
|
||||||
|
self.clone().client_serve(stream, server.clone(), _client_addr).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
complete => break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(()) as Result<()>
|
||||||
|
};
|
||||||
|
|
||||||
|
join!(listener_service, cert_service).0
|
||||||
|
}
|
||||||
|
}
|
||||||
176
src/tls.rs
176
src/tls.rs
|
|
@ -1,176 +0,0 @@
|
||||||
use std::fs::File;
|
|
||||||
use std::io::{self, BufReader, Cursor, Read};
|
|
||||||
use std::path::Path;
|
|
||||||
use std::sync::Arc;
|
|
||||||
use std::time::Duration;
|
|
||||||
|
|
||||||
use futures::{future::FutureExt, join, select};
|
|
||||||
use hyper::client::connect::Connect;
|
|
||||||
use hyper::server::conn::Http;
|
|
||||||
use tokio::{
|
|
||||||
net::TcpListener,
|
|
||||||
sync::mpsc::{self, Receiver},
|
|
||||||
};
|
|
||||||
use tokio_rustls::{
|
|
||||||
rustls::{Certificate, PrivateKey, ServerConfig},
|
|
||||||
TlsAcceptor,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::acceptor::{LocalExecutor, PacketAcceptor};
|
|
||||||
use crate::constants::CERTS_WATCH_DELAY_SECS;
|
|
||||||
use crate::error::*;
|
|
||||||
|
|
||||||
pub fn create_tls_acceptor<P, P2>(certs_path: P, certs_keys_path: P2) -> io::Result<TlsAcceptor>
|
|
||||||
where
|
|
||||||
P: AsRef<Path>,
|
|
||||||
P2: AsRef<Path>,
|
|
||||||
{
|
|
||||||
let certs: Vec<_> = {
|
|
||||||
let certs_path_str = certs_path.as_ref().display().to_string();
|
|
||||||
let mut reader = BufReader::new(File::open(certs_path).map_err(|e| {
|
|
||||||
io::Error::new(
|
|
||||||
e.kind(),
|
|
||||||
format!(
|
|
||||||
"Unable to load the certificates [{}]: {}",
|
|
||||||
certs_path_str, e
|
|
||||||
),
|
|
||||||
)
|
|
||||||
})?);
|
|
||||||
rustls_pemfile::certs(&mut reader).map_err(|_| {
|
|
||||||
io::Error::new(
|
|
||||||
io::ErrorKind::InvalidInput,
|
|
||||||
"Unable to parse the certificates",
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
}
|
|
||||||
.drain(..)
|
|
||||||
.map(Certificate)
|
|
||||||
.collect();
|
|
||||||
let certs_keys: Vec<_> = {
|
|
||||||
let certs_keys_path_str = certs_keys_path.as_ref().display().to_string();
|
|
||||||
let encoded_keys = {
|
|
||||||
let mut encoded_keys = vec![];
|
|
||||||
File::open(certs_keys_path)
|
|
||||||
.map_err(|e| {
|
|
||||||
io::Error::new(
|
|
||||||
e.kind(),
|
|
||||||
format!(
|
|
||||||
"Unable to load the certificate keys [{}]: {}",
|
|
||||||
certs_keys_path_str, e
|
|
||||||
),
|
|
||||||
)
|
|
||||||
})?
|
|
||||||
.read_to_end(&mut encoded_keys)?;
|
|
||||||
encoded_keys
|
|
||||||
};
|
|
||||||
let mut reader = Cursor::new(encoded_keys);
|
|
||||||
let pkcs8_keys = rustls_pemfile::pkcs8_private_keys(&mut reader).map_err(|_| {
|
|
||||||
io::Error::new(
|
|
||||||
io::ErrorKind::InvalidInput,
|
|
||||||
"Unable to parse the certificates private keys (PKCS8)",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
reader.set_position(0);
|
|
||||||
let mut rsa_keys = rustls_pemfile::rsa_private_keys(&mut reader).map_err(|_| {
|
|
||||||
io::Error::new(
|
|
||||||
io::ErrorKind::InvalidInput,
|
|
||||||
"Unable to parse the certificates private keys (RSA)",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
let mut keys = pkcs8_keys;
|
|
||||||
keys.append(&mut rsa_keys);
|
|
||||||
if keys.is_empty() {
|
|
||||||
return Err(io::Error::new(
|
|
||||||
io::ErrorKind::InvalidInput,
|
|
||||||
"No private keys found - Make sure that they are in PKCS#8/PEM format",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
keys.drain(..).map(PrivateKey).collect()
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut server_config = certs_keys
|
|
||||||
.into_iter()
|
|
||||||
.find_map(|certs_key| {
|
|
||||||
let server_config_builder = ServerConfig::builder()
|
|
||||||
.with_safe_defaults()
|
|
||||||
.with_no_client_auth();
|
|
||||||
if let Ok(found_config) = server_config_builder.with_single_cert(certs.clone(), certs_key) {
|
|
||||||
Some(found_config)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.ok_or_else(|| {
|
|
||||||
io::Error::new(
|
|
||||||
io::ErrorKind::InvalidInput,
|
|
||||||
"Unable to find a valid certificate and key",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()];
|
|
||||||
Ok(TlsAcceptor::from(Arc::new(server_config)))
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> PacketAcceptor<T>
|
|
||||||
where
|
|
||||||
T: Connect + Clone + Sync + Send + 'static,
|
|
||||||
{
|
|
||||||
async fn start_https_service(
|
|
||||||
self,
|
|
||||||
mut tls_acceptor_receiver: Receiver<TlsAcceptor>,
|
|
||||||
listener: TcpListener,
|
|
||||||
server: Http<LocalExecutor>,
|
|
||||||
) -> Result<()> {
|
|
||||||
let mut tls_acceptor: Option<TlsAcceptor> = None;
|
|
||||||
let listener_service = async {
|
|
||||||
loop {
|
|
||||||
select! {
|
|
||||||
tcp_cnx = listener.accept().fuse() => {
|
|
||||||
if tls_acceptor.is_none() || tcp_cnx.is_err() {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
let (raw_stream, _client_addr) = tcp_cnx.unwrap();
|
|
||||||
if let Ok(stream) = tls_acceptor.as_ref().unwrap().accept(raw_stream).await {
|
|
||||||
self.clone().client_serve(stream, server.clone(), _client_addr).await
|
|
||||||
}
|
|
||||||
}
|
|
||||||
new_tls_acceptor = tls_acceptor_receiver.recv().fuse() => {
|
|
||||||
if new_tls_acceptor.is_none() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
tls_acceptor = new_tls_acceptor;
|
|
||||||
}
|
|
||||||
complete => break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(()) as Result<()>
|
|
||||||
};
|
|
||||||
listener_service.await?;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn start_with_tls(
|
|
||||||
self,
|
|
||||||
listener: TcpListener,
|
|
||||||
server: Http<LocalExecutor>,
|
|
||||||
) -> Result<()> {
|
|
||||||
let certs_path = self.globals.tls_cert_path.as_ref().unwrap().clone();
|
|
||||||
let certs_keys_path = self.globals.tls_cert_key_path.as_ref().unwrap().clone();
|
|
||||||
let (tls_acceptor_sender, tls_acceptor_receiver) = mpsc::channel(1);
|
|
||||||
let https_service = self.start_https_service(tls_acceptor_receiver, listener, server);
|
|
||||||
let cert_service = async {
|
|
||||||
loop {
|
|
||||||
match create_tls_acceptor(&certs_path, &certs_keys_path) {
|
|
||||||
Ok(tls_acceptor) => {
|
|
||||||
if tls_acceptor_sender.send(tls_acceptor).await.is_err() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(e) => eprintln!("TLS certificates error: {}", e),
|
|
||||||
}
|
|
||||||
tokio::time::sleep(Duration::from_secs(CERTS_WATCH_DELAY_SECS.into())).await;
|
|
||||||
}
|
|
||||||
Ok(()) as Result<()>
|
|
||||||
};
|
|
||||||
return join!(https_service, cert_service).0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue