wip: refactor whole module in lib

This commit is contained in:
Jun Kurihara 2023-11-21 22:46:52 +09:00
commit f98c778a0c
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
42 changed files with 943 additions and 531 deletions

View file

@ -0,0 +1,393 @@
use crate::{error::*, globals::Globals, log::*, CryptoSource};
use base64::{engine::general_purpose, Engine as _};
use bytes::{Buf, Bytes, BytesMut};
use http_cache_semantics::CachePolicy;
use hyper::{
http::{Request, Response},
Body,
};
use lru::LruCache;
use sha2::{Digest, Sha256};
use std::{
fmt::Debug,
path::{Path, PathBuf},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
time::SystemTime,
};
use tokio::{
fs::{self, File},
io::{AsyncReadExt, AsyncWriteExt},
sync::RwLock,
};
#[derive(Clone, Debug)]
/// Cache target in hybrid manner of on-memory and file system
pub enum CacheFileOrOnMemory {
/// Pointer to the temporary cache file
File(PathBuf),
/// Cached body itself
OnMemory(Vec<u8>),
}
#[derive(Clone, Debug)]
/// Cache object definition
struct CacheObject {
/// Cache policy to determine if the stored cache can be used as a response to a new incoming request
pub policy: CachePolicy,
/// Cache target: on-memory object or temporary file
pub target: CacheFileOrOnMemory,
}
#[derive(Debug)]
/// Manager inner for cache on file system
struct CacheFileManagerInner {
/// Directory of temporary files
cache_dir: PathBuf,
/// Counter of current cached files
cnt: usize,
/// Async runtime
runtime_handle: tokio::runtime::Handle,
}
impl CacheFileManagerInner {
/// Build new cache file manager.
/// This first creates cache file dir if not exists, and cleans up the file inside the directory.
/// TODO: Persistent cache is really difficult. `sqlite` or something like that is needed.
async fn new(path: impl AsRef<Path>, runtime_handle: &tokio::runtime::Handle) -> Self {
let path_buf = path.as_ref().to_path_buf();
if let Err(e) = fs::remove_dir_all(path).await {
warn!("Failed to clean up the cache dir: {e}");
};
fs::create_dir_all(&path_buf).await.unwrap();
Self {
cache_dir: path_buf.clone(),
cnt: 0,
runtime_handle: runtime_handle.clone(),
}
}
/// Create a new temporary file cache
async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> Result<CacheFileOrOnMemory> {
let cache_filepath = self.cache_dir.join(cache_filename);
let Ok(mut file) = File::create(&cache_filepath).await else {
return Err(RpxyError::Cache("Failed to create file"));
};
let mut bytes_clone = body_bytes.clone();
while bytes_clone.has_remaining() {
if let Err(e) = file.write_buf(&mut bytes_clone).await {
error!("Failed to write file cache: {e}");
return Err(RpxyError::Cache("Failed to write file cache: {e}"));
};
}
self.cnt += 1;
Ok(CacheFileOrOnMemory::File(cache_filepath))
}
/// Retrieve a stored temporary file cache
async fn read(&self, path: impl AsRef<Path>) -> Result<Body> {
let Ok(mut file) = File::open(&path).await else {
warn!("Cache file object cannot be opened");
return Err(RpxyError::Cache("Cache file object cannot be opened"));
};
let (body_sender, res_body) = Body::channel();
self.runtime_handle.spawn(async move {
let mut sender = body_sender;
let mut buf = BytesMut::new();
loop {
match file.read_buf(&mut buf).await {
Ok(0) => break,
Ok(_) => sender.send_data(buf.copy_to_bytes(buf.remaining())).await?,
Err(_) => break,
};
}
Ok(()) as Result<()>
});
Ok(res_body)
}
/// Remove file
async fn remove(&mut self, path: impl AsRef<Path>) -> Result<()> {
fs::remove_file(path.as_ref()).await?;
self.cnt -= 1;
debug!("Removed a cache file at {:?} (file count: {})", path.as_ref(), self.cnt);
Ok(())
}
}
#[derive(Debug, Clone)]
/// Cache file manager outer that is responsible to handle `RwLock`
struct CacheFileManager {
inner: Arc<RwLock<CacheFileManagerInner>>,
}
impl CacheFileManager {
/// Build manager
async fn new(path: impl AsRef<Path>, runtime_handle: &tokio::runtime::Handle) -> Self {
Self {
inner: Arc::new(RwLock::new(CacheFileManagerInner::new(path, runtime_handle).await)),
}
}
/// Evict a temporary file cache
async fn evict(&self, path: impl AsRef<Path>) {
// Acquire the write lock
let mut inner = self.inner.write().await;
if let Err(e) = inner.remove(path).await {
warn!("Eviction failed during file object removal: {:?}", e);
};
}
/// Read a temporary file cache
async fn read(&self, path: impl AsRef<Path>) -> Result<Body> {
let mgr = self.inner.read().await;
mgr.read(&path).await
}
/// Create a temporary file cache
async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> Result<CacheFileOrOnMemory> {
let mut mgr = self.inner.write().await;
mgr.create(cache_filename, body_bytes).await
}
async fn count(&self) -> usize {
let mgr = self.inner.read().await;
mgr.cnt
}
}
#[derive(Debug, Clone)]
/// Lru cache manager that is responsible to handle `Mutex` as an outer of `LruCache`
struct LruCacheManager {
inner: Arc<Mutex<LruCache<String, CacheObject>>>, // TODO: keyはstring urlでいいのか疑問。全requestに対してcheckすることになりそう
cnt: Arc<AtomicUsize>,
}
impl LruCacheManager {
/// Build LruCache
fn new(cache_max_entry: usize) -> Self {
Self {
inner: Arc::new(Mutex::new(LruCache::new(
std::num::NonZeroUsize::new(cache_max_entry).unwrap(),
))),
cnt: Arc::new(AtomicUsize::default()),
}
}
/// Count entries
fn count(&self) -> usize {
self.cnt.load(Ordering::Relaxed)
}
/// Evict an entry
fn evict(&self, cache_key: &str) -> Option<(String, CacheObject)> {
let Ok(mut lock) = self.inner.lock() else {
error!("Mutex can't be locked to evict a cache entry");
return None;
};
let res = lock.pop_entry(cache_key);
self.cnt.store(lock.len(), Ordering::Relaxed);
res
}
/// Get an entry
fn get(&self, cache_key: &str) -> Result<Option<CacheObject>> {
let Ok(mut lock) = self.inner.lock() else {
error!("Mutex can't be locked for checking cache entry");
return Err(RpxyError::Cache("Mutex can't be locked for checking cache entry"));
};
let Some(cached_object) = lock.get(cache_key) else {
return Ok(None);
};
Ok(Some(cached_object.clone()))
}
/// Push an entry
fn push(&self, cache_key: &str, cache_object: CacheObject) -> Result<Option<(String, CacheObject)>> {
let Ok(mut lock) = self.inner.lock() else {
error!("Failed to acquire mutex lock for writing cache entry");
return Err(RpxyError::Cache("Failed to acquire mutex lock for writing cache entry"));
};
let res = Ok(lock.push(cache_key.to_string(), cache_object));
self.cnt.store(lock.len(), Ordering::Relaxed);
res
}
}
#[derive(Clone, Debug)]
pub struct RpxyCache {
/// Managing cache file objects through RwLock's lock mechanism for file lock
cache_file_manager: CacheFileManager,
/// Lru cache storing http message caching policy
inner: LruCacheManager,
/// Async runtime
runtime_handle: tokio::runtime::Handle,
/// Maximum size of each cache file object
max_each_size: usize,
/// Maximum size of cache object on memory
max_each_size_on_memory: usize,
}
impl RpxyCache {
/// Generate cache storage
pub async fn new<T: CryptoSource>(globals: &Globals<T>) -> Option<Self> {
if !globals.proxy_config.cache_enabled {
return None;
}
let path = globals.proxy_config.cache_dir.as_ref().unwrap();
let cache_file_manager = CacheFileManager::new(path, &globals.runtime_handle).await;
let inner = LruCacheManager::new(globals.proxy_config.cache_max_entry);
let max_each_size = globals.proxy_config.cache_max_each_size;
let mut max_each_size_on_memory = globals.proxy_config.cache_max_each_size_on_memory;
if max_each_size < max_each_size_on_memory {
warn!(
"Maximum size of on memory cache per entry must be smaller than or equal to the maximum of each file cache"
);
max_each_size_on_memory = max_each_size;
}
Some(Self {
cache_file_manager,
inner,
runtime_handle: globals.runtime_handle.clone(),
max_each_size,
max_each_size_on_memory,
})
}
/// Count cache entries
pub async fn count(&self) -> (usize, usize, usize) {
let total = self.inner.count();
let file = self.cache_file_manager.count().await;
let on_memory = total - file;
(total, on_memory, file)
}
/// Get cached response
pub async fn get<R>(&self, req: &Request<R>) -> Option<Response<Body>> {
debug!(
"Current cache status: (total, on-memory, file) = {:?}",
self.count().await
);
let cache_key = req.uri().to_string();
// First check cache chance
let Ok(Some(cached_object)) = self.inner.get(&cache_key) else {
return None;
};
// Secondly check the cache freshness as an HTTP message
let now = SystemTime::now();
let http_cache_semantics::BeforeRequest::Fresh(res_parts) = cached_object.policy.before_request(req, now) else {
// Evict stale cache entry.
// This might be okay to keep as is since it would be updated later.
// However, there is no guarantee that newly got objects will be still cacheable.
// So, we have to evict stale cache entries and cache file objects if found.
debug!("Stale cache entry: {cache_key}");
let _evicted_entry = self.inner.evict(&cache_key);
// For cache file
if let CacheFileOrOnMemory::File(path) = &cached_object.target {
self.cache_file_manager.evict(&path).await;
}
return None;
};
// Finally retrieve the file/on-memory object
match cached_object.target {
CacheFileOrOnMemory::File(path) => {
let res_body = match self.cache_file_manager.read(&path).await {
Ok(res_body) => res_body,
Err(e) => {
warn!("Failed to read from file cache: {e}");
let _evicted_entry = self.inner.evict(&cache_key);
self.cache_file_manager.evict(&path).await;
return None;
}
};
debug!("Cache hit from file: {cache_key}");
Some(Response::from_parts(res_parts, res_body))
}
CacheFileOrOnMemory::OnMemory(object) => {
debug!("Cache hit from on memory: {cache_key}");
Some(Response::from_parts(res_parts, Body::from(object)))
}
}
}
/// Put response into the cache
pub async fn put(&self, uri: &hyper::Uri, body_bytes: &Bytes, policy: &CachePolicy) -> Result<()> {
let my_cache = self.inner.clone();
let mut mgr = self.cache_file_manager.clone();
let uri = uri.clone();
let bytes_clone = body_bytes.clone();
let policy_clone = policy.clone();
let max_each_size = self.max_each_size;
let max_each_size_on_memory = self.max_each_size_on_memory;
self.runtime_handle.spawn(async move {
if bytes_clone.len() > max_each_size {
warn!("Too large to cache");
return Err(RpxyError::Cache("Too large to cache"));
}
let cache_key = derive_cache_key_from_uri(&uri);
debug!("Object of size {:?} bytes to be cached", bytes_clone.len());
let cache_object = if bytes_clone.len() > max_each_size_on_memory {
let cache_filename = derive_filename_from_uri(&uri);
let target = mgr.create(&cache_filename, &bytes_clone).await?;
debug!("Cached a new cache file: {} - {}", cache_key, cache_filename);
CacheObject {
policy: policy_clone,
target,
}
} else {
debug!("Cached a new object on memory: {}", cache_key);
CacheObject {
policy: policy_clone,
target: CacheFileOrOnMemory::OnMemory(bytes_clone.to_vec()),
}
};
if let Some((k, v)) = my_cache.push(&cache_key, cache_object)? {
if k != cache_key {
info!("Over the cache capacity. Evict least recent used entry");
if let CacheFileOrOnMemory::File(path) = v.target {
mgr.evict(&path).await;
}
}
}
Ok(())
});
Ok(())
}
}
fn derive_filename_from_uri(uri: &hyper::Uri) -> String {
let mut hasher = Sha256::new();
hasher.update(uri.to_string());
let digest = hasher.finalize();
general_purpose::URL_SAFE_NO_PAD.encode(digest)
}
fn derive_cache_key_from_uri(uri: &hyper::Uri) -> String {
uri.to_string()
}
pub fn get_policy_if_cacheable<R>(req: Option<&Request<R>>, res: Option<&Response<Body>>) -> Result<Option<CachePolicy>>
where
R: Debug,
{
// deduce cache policy from req and res
let (Some(req), Some(res)) = (req, res) else {
return Err(RpxyError::Cache("Invalid null request and/or response"));
};
let new_policy = CachePolicy::new(req, res);
if new_policy.is_storable() {
// debug!("Response is cacheable: {:?}\n{:?}", req, res.headers());
Ok(Some(new_policy))
} else {
Ok(None)
}
}

View file

@ -0,0 +1,16 @@
use http::StatusCode;
use thiserror::Error;
pub type HttpResult<T> = std::result::Result<T, HttpError>;
/// Describes things that can go wrong in the handler
#[derive(Debug, Error)]
pub enum HttpError {}
impl From<HttpError> for StatusCode {
fn from(e: HttpError) -> StatusCode {
match e {
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}
}

View file

@ -0,0 +1,147 @@
#[cfg(feature = "cache")]
use super::cache::{get_policy_if_cacheable, RpxyCache};
use crate::{error::RpxyError, globals::Globals, log::*, CryptoSource};
use async_trait::async_trait;
#[cfg(feature = "cache")]
use bytes::Buf;
use hyper::{
body::{Body, HttpBody},
client::{connect::Connect, HttpConnector},
http::Version,
Client, Request, Response,
};
use hyper_rustls::HttpsConnector;
#[cfg(feature = "cache")]
/// Build synthetic request to cache
fn build_synth_req_for_cache<T>(req: &Request<T>) -> Request<()> {
let mut builder = Request::builder()
.method(req.method())
.uri(req.uri())
.version(req.version());
// TODO: omit extensions. is this approach correct?
for (header_key, header_value) in req.headers() {
builder = builder.header(header_key, header_value);
}
builder.body(()).unwrap()
}
#[async_trait]
/// Definition of the forwarder that simply forward requests from downstream client to upstream app servers.
pub trait ForwardRequest<B> {
type Error;
async fn request(&self, req: Request<B>) -> Result<Response<Body>, Self::Error>;
}
/// Forwarder struct responsible to cache handling
pub struct Forwarder<C, B = Body>
where
C: Connect + Clone + Sync + Send + 'static,
{
#[cfg(feature = "cache")]
cache: Option<RpxyCache>,
inner: Client<C, B>,
inner_h2: Client<C, B>, // `h2c` or http/2-only client is defined separately
}
#[async_trait]
impl<C, B> ForwardRequest<B> for Forwarder<C, B>
where
B: HttpBody + Send + Sync + 'static,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
C: Connect + Clone + Sync + Send + 'static,
{
type Error = RpxyError;
#[cfg(feature = "cache")]
async fn request(&self, req: Request<B>) -> Result<Response<Body>, Self::Error> {
let mut synth_req = None;
if self.cache.is_some() {
if let Some(cached_response) = self.cache.as_ref().unwrap().get(&req).await {
// if found, return it as response.
info!("Cache hit - Return from cache");
return Ok(cached_response);
};
// Synthetic request copy used just for caching (cannot clone request object...)
synth_req = Some(build_synth_req_for_cache(&req));
}
// TODO: This 'match' condition is always evaluated at every 'request' invocation. So, it is inefficient.
// Needs to be reconsidered. Currently, this is a kind of work around.
// This possibly relates to https://github.com/hyperium/hyper/issues/2417.
let res = match req.version() {
Version::HTTP_2 => self.inner_h2.request(req).await.map_err(RpxyError::Hyper), // handles `h2c` requests
_ => self.inner.request(req).await.map_err(RpxyError::Hyper),
};
if self.cache.is_none() {
return res;
}
// check cacheability and store it if cacheable
let Ok(Some(cache_policy)) = get_policy_if_cacheable(synth_req.as_ref(), res.as_ref().ok()) else {
return res;
};
let (parts, body) = res.unwrap().into_parts();
let Ok(mut bytes) = hyper::body::aggregate(body).await else {
return Err(RpxyError::Cache("Failed to write byte buffer"));
};
let aggregated = bytes.copy_to_bytes(bytes.remaining());
if let Err(cache_err) = self
.cache
.as_ref()
.unwrap()
.put(synth_req.unwrap().uri(), &aggregated, &cache_policy)
.await
{
error!("{:?}", cache_err);
};
// res
Ok(Response::from_parts(parts, Body::from(aggregated)))
}
#[cfg(not(feature = "cache"))]
async fn request(&self, req: Request<B>) -> Result<Response<Body>, Self::Error> {
match req.version() {
Version::HTTP_2 => self.inner_h2.request(req).await.map_err(RpxyError::Hyper), // handles `h2c` requests
_ => self.inner.request(req).await.map_err(RpxyError::Hyper),
}
}
}
impl Forwarder<HttpsConnector<HttpConnector>, Body> {
/// Build forwarder
pub async fn new<T: CryptoSource>(_globals: &std::sync::Arc<Globals<T>>) -> Self {
#[cfg(feature = "native-roots")]
let builder = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots();
#[cfg(feature = "native-roots")]
let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots();
#[cfg(feature = "native-roots")]
info!("Native cert store is used for the connection to backend applications");
#[cfg(not(feature = "native-roots"))]
let builder = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots();
#[cfg(not(feature = "native-roots"))]
let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots();
#[cfg(not(feature = "native-roots"))]
info!("Mozilla WebPKI root certs is used for the connection to backend applications");
let connector = builder.https_or_http().enable_http1().enable_http2().build();
let connector_h2 = builder_h2.https_or_http().enable_http2().build();
let inner = Client::builder().build::<_, Body>(connector);
let inner_h2 = Client::builder().http2_only(true).build::<_, Body>(connector_h2);
#[cfg(feature = "cache")]
{
let cache = RpxyCache::new(_globals).await;
Self { inner, inner_h2, cache }
}
#[cfg(not(feature = "cache"))]
Self { inner, inner_h2 }
}
}

View file

@ -0,0 +1,384 @@
// Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
use super::{
error::*,
// forwarder::{ForwardRequest, Forwarder},
utils_headers::*,
utils_request::*,
// utils_synth_response::*,
HandlerContext,
};
use crate::{
backend::{Backend, UpstreamGroup},
certs::CryptoSource,
constants::RESPONSE_HEADER_SERVER,
error::*,
globals::Globals,
log::*,
utils::ServerNameBytesExp,
};
use derive_builder::Builder;
use http::{
header::{self, HeaderValue},
uri::Scheme,
Request, Response, StatusCode, Uri, Version,
};
use hyper::body::Incoming;
use hyper_util::client::legacy::connect::Connect;
use std::{net::SocketAddr, sync::Arc};
use tokio::{io::copy_bidirectional, time::timeout};
#[derive(Clone, Builder)]
/// HTTP message handler for requests from clients and responses from backend applications,
/// responsible to manipulate and forward messages to upstream backends and downstream clients.
// pub struct HttpMessageHandler<T, U>
pub struct HttpMessageHandler<U>
where
// T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone,
{
// forwarder: Arc<Forwarder<T>>,
globals: Arc<Globals<U>>,
}
impl<U> HttpMessageHandler<U>
where
// T: Connect + Clone + Sync + Send + 'static,
U: CryptoSource + Clone,
{
// /// Return with an arbitrary status code of error and log message
// fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result<Response<Body>> {
// log_data.status_code(&status_code).output();
// http_error(status_code)
// }
/// Handle incoming request message from a client
pub async fn handle_request(
&self,
mut req: Request<Incoming>,
client_addr: SocketAddr, // アクセス制御用
listen_addr: SocketAddr,
tls_enabled: bool,
tls_server_name: Option<ServerNameBytesExp>,
) -> Result<HttpResult<Response<Incoming>>> {
////////
let mut log_data = MessageLog::from(&req);
log_data.client_addr(&client_addr);
//////
// // Here we start to handle with server_name
// let server_name = if let Ok(v) = req.parse_host() {
// ServerNameBytesExp::from(v)
// } else {
// return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data);
// };
// // check consistency of between TLS SNI and HOST/Request URI Line.
// #[allow(clippy::collapsible_if)]
// if tls_enabled && self.globals.proxy_config.sni_consistency {
// if server_name != tls_server_name.unwrap_or_default() {
// return self.return_with_error_log(StatusCode::MISDIRECTED_REQUEST, &mut log_data);
// }
// }
// // Find backend application for given server_name, and drop if incoming request is invalid as request.
// let backend = match self.globals.backends.apps.get(&server_name) {
// Some(be) => be,
// None => {
// let Some(default_server_name) = &self.globals.backends.default_server_name_bytes else {
// return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
// };
// debug!("Serving by default app");
// self.globals.backends.apps.get(default_server_name).unwrap()
// }
// };
// // Redirect to https if !tls_enabled and redirect_to_https is true
// if !tls_enabled && backend.https_redirection.unwrap_or(false) {
// debug!("Redirect to secure connection: {}", &backend.server_name);
// log_data.status_code(&StatusCode::PERMANENT_REDIRECT).output();
// return secure_redirection(&backend.server_name, self.globals.proxy_config.https_port, &req);
// }
// // Find reverse proxy for given path and choose one of upstream host
// // Longest prefix match
// let path = req.uri().path();
// let Some(upstream_group) = backend.reverse_proxy.get(path) else {
// return self.return_with_error_log(StatusCode::NOT_FOUND, &mut log_data);
// };
// // Upgrade in request header
// let upgrade_in_request = extract_upgrade(req.headers());
// let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();
// // Build request from destination information
// let _context = match self.generate_request_forwarded(
// &client_addr,
// &listen_addr,
// &mut req,
// &upgrade_in_request,
// upstream_group,
// tls_enabled,
// ) {
// Err(e) => {
// error!("Failed to generate destination uri for reverse proxy: {}", e);
// return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
// }
// Ok(v) => v,
// };
// debug!("Request to be forwarded: {:?}", req);
// log_data.xff(&req.headers().get("x-forwarded-for"));
// log_data.upstream(req.uri());
// //////
// // Forward request to a chosen backend
// let mut res_backend = {
// let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else {
// return self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data);
// };
// match result {
// Ok(res) => res,
// Err(e) => {
// error!("Failed to get response from backend: {}", e);
// return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
// }
// }
// };
// // Process reverse proxy context generated during the forwarding request generation.
// #[cfg(feature = "sticky-cookie")]
// if let Some(context_from_lb) = _context.context_lb {
// let res_headers = res_backend.headers_mut();
// if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) {
// error!("Failed to append context to the response given from backend: {}", e);
// return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data);
// }
// }
// if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS {
// // Generate response to client
// if self.generate_response_forwarded(&mut res_backend, backend).is_err() {
// return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data);
// }
// log_data.status_code(&res_backend.status()).output();
// return Ok(res_backend);
// }
// // Handle StatusCode::SWITCHING_PROTOCOLS in response
// let upgrade_in_response = extract_upgrade(res_backend.headers());
// let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref())
// {
// u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase()
// } else {
// false
// };
// if !should_upgrade {
// error!(
// "Backend tried to switch to protocol {:?} when {:?} was requested",
// upgrade_in_response, upgrade_in_request
// );
// return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data);
// }
// let Some(request_upgraded) = request_upgraded else {
// error!("Request does not have an upgrade extension");
// return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data);
// };
// let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else {
// error!("Response does not have an upgrade extension");
// return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data);
// };
// self.globals.runtime_handle.spawn(async move {
// let mut response_upgraded = onupgrade.await.map_err(|e| {
// error!("Failed to upgrade response: {}", e);
// RpxyError::Hyper(e)
// })?;
// let mut request_upgraded = request_upgraded.await.map_err(|e| {
// error!("Failed to upgrade request: {}", e);
// RpxyError::Hyper(e)
// })?;
// copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
// .await
// .map_err(|e| {
// error!("Coping between upgraded connections failed: {}", e);
// RpxyError::Io(e)
// })?;
// Ok(()) as Result<()>
// });
// log_data.status_code(&res_backend.status()).output();
// Ok(res_backend)
todo!()
}
////////////////////////////////////////////////////
// Functions to generate messages
////////////////////////////////////////////////////
// /// Manipulate a response message sent from a backend application to forward downstream to a client.
// fn generate_response_forwarded<B>(&self, response: &mut Response<B>, chosen_backend: &Backend<U>) -> Result<()>
// where
// B: core::fmt::Debug,
// {
// let headers = response.headers_mut();
// remove_connection_header(headers);
// remove_hop_header(headers);
// add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?;
// #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
// {
// // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled
// // TODO: This is a workaround for avoiding a client authentication in HTTP/3
// if self.globals.proxy_config.http3
// && chosen_backend
// .crypto_source
// .as_ref()
// .is_some_and(|v| !v.is_mutual_tls())
// {
// if let Some(port) = self.globals.proxy_config.https_port {
// add_header_entry_overwrite_if_exist(
// headers,
// header::ALT_SVC.as_str(),
// format!(
// "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}",
// port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age
// ),
// )?;
// }
// } else {
// // remove alt-svc to disallow requests via http3
// headers.remove(header::ALT_SVC.as_str());
// }
// }
// #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))]
// {
// if let Some(port) = self.globals.proxy_config.https_port {
// headers.remove(header::ALT_SVC.as_str());
// }
// }
// Ok(())
// }
// #[allow(clippy::too_many_arguments)]
// /// Manipulate a request message sent from a client to forward upstream to a backend application
// fn generate_request_forwarded<B>(
// &self,
// client_addr: &SocketAddr,
// listen_addr: &SocketAddr,
// req: &mut Request<B>,
// upgrade: &Option<String>,
// upstream_group: &UpstreamGroup,
// tls_enabled: bool,
// ) -> Result<HandlerContext> {
// debug!("Generate request to be forwarded");
// // Add te: trailer if contained in original request
// let contains_te_trailers = {
// if let Some(te) = req.headers().get(header::TE) {
// te.as_bytes()
// .split(|v| v == &b',' || v == &b' ')
// .any(|x| x == "trailers".as_bytes())
// } else {
// false
// }
// };
// let uri = req.uri().to_string();
// let headers = req.headers_mut();
// // delete headers specified in header.connection
// remove_connection_header(headers);
// // delete hop headers including header.connection
// remove_hop_header(headers);
// // X-Forwarded-For
// add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?;
// // Add te: trailer if te_trailer
// if contains_te_trailers {
// headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap());
// }
// // add "host" header of original server_name if not exist (default)
// if req.headers().get(header::HOST).is_none() {
// let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned();
// req
// .headers_mut()
// .insert(header::HOST, HeaderValue::from_str(&org_host)?);
// };
// /////////////////////////////////////////////
// // Fix unique upstream destination since there could be multiple ones.
// #[cfg(feature = "sticky-cookie")]
// let (upstream_chosen_opt, context_from_lb) = {
// let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb {
// takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)?
// } else {
// None
// };
// upstream_group.get(&context_to_lb)
// };
// #[cfg(not(feature = "sticky-cookie"))]
// let (upstream_chosen_opt, _) = upstream_group.get(&None);
// let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?;
// let context = HandlerContext {
// #[cfg(feature = "sticky-cookie")]
// context_lb: context_from_lb,
// #[cfg(not(feature = "sticky-cookie"))]
// context_lb: None,
// };
// /////////////////////////////////////////////
// // apply upstream-specific headers given in upstream_option
// let headers = req.headers_mut();
// apply_upstream_options_to_header(headers, client_addr, upstream_group, &upstream_chosen.uri)?;
// // update uri in request
// if !(upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some()) {
// return Err(RpxyError::Handler("Upstream uri `scheme` and `authority` is broken"));
// };
// let new_uri = Uri::builder()
// .scheme(upstream_chosen.uri.scheme().unwrap().as_str())
// .authority(upstream_chosen.uri.authority().unwrap().as_str());
// let org_pq = match req.uri().path_and_query() {
// Some(pq) => pq.to_string(),
// None => "/".to_string(),
// }
// .into_bytes();
// // replace some parts of path if opt_replace_path is enabled for chosen upstream
// let new_pq = match &upstream_group.replace_path {
// Some(new_path) => {
// let matched_path: &[u8] = upstream_group.path.as_ref();
// if matched_path.is_empty() || org_pq.len() < matched_path.len() {
// return Err(RpxyError::Handler("Upstream uri `path and query` is broken"));
// };
// let mut new_pq = Vec::<u8>::with_capacity(org_pq.len() - matched_path.len() + new_path.len());
// new_pq.extend_from_slice(new_path.as_ref());
// new_pq.extend_from_slice(&org_pq[matched_path.len()..]);
// new_pq
// }
// None => org_pq,
// };
// *req.uri_mut() = new_uri.path_and_query(new_pq).build()?;
// // upgrade
// if let Some(v) = upgrade {
// req.headers_mut().insert(header::UPGRADE, v.parse()?);
// req
// .headers_mut()
// .insert(header::CONNECTION, HeaderValue::from_str("upgrade")?);
// }
// // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3
// if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) {
// // Change version to http/1.1 when destination scheme is http
// debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled.");
// *req.version_mut() = Version::HTTP_11;
// } else if req.version() == Version::HTTP_3 {
// // HTTP/3 is always https
// debug!("HTTP/3 is currently unsupported for request to upstream.");
// *req.version_mut() = Version::HTTP_2;
// }
// apply_upstream_options_to_request_line(req, upstream_group)?;
// Ok(context)
// }
}

View file

@ -0,0 +1,22 @@
#[cfg(feature = "cache")]
// mod cache;
mod error;
// mod forwarder;
mod handler_main;
mod utils_headers;
mod utils_request;
// mod utils_synth_response;
#[cfg(feature = "sticky-cookie")]
use crate::backend::LbContext;
pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError};
#[allow(dead_code)]
#[derive(Debug)]
/// Context object to handle sticky cookies at HTTP message handler
struct HandlerContext {
#[cfg(feature = "sticky-cookie")]
context_lb: Option<LbContext>,
#[cfg(not(feature = "sticky-cookie"))]
context_lb: Option<()>,
}

View file

@ -0,0 +1,276 @@
#[cfg(feature = "sticky-cookie")]
use crate::backend::{LbContext, StickyCookie, StickyCookieValue};
use crate::backend::{UpstreamGroup, UpstreamOption};
use crate::{error::*, log::*, utils::*};
use bytes::BufMut;
use hyper::{
header::{self, HeaderMap, HeaderName, HeaderValue},
Uri,
};
use std::{borrow::Cow, net::SocketAddr};
////////////////////////////////////////////////////
// Functions to manipulate headers
#[cfg(feature = "sticky-cookie")]
/// Take sticky cookie header value from request header,
/// and returns LbContext to be forwarded to LB if exist and if needed.
/// Removing sticky cookie is needed and it must not be passed to the upstream.
pub(super) fn takeout_sticky_cookie_lb_context(
headers: &mut HeaderMap,
expected_cookie_name: &str,
) -> Result<Option<LbContext>> {
let mut headers_clone = headers.clone();
match headers_clone.entry(header::COOKIE) {
header::Entry::Vacant(_) => Ok(None),
header::Entry::Occupied(entry) => {
let cookies_iter = entry
.iter()
.flat_map(|v| v.to_str().unwrap_or("").split(';').map(|v| v.trim()));
let (sticky_cookies, without_sticky_cookies): (Vec<_>, Vec<_>) = cookies_iter
.into_iter()
.partition(|v| v.starts_with(expected_cookie_name));
if sticky_cookies.is_empty() {
return Ok(None);
}
if sticky_cookies.len() > 1 {
error!("Multiple sticky cookie values in request");
return Err(RpxyError::Other(anyhow!(
"Invalid cookie: Multiple sticky cookie values"
)));
}
let cookies_passed_to_upstream = without_sticky_cookies.join("; ");
let cookie_passed_to_lb = sticky_cookies.first().unwrap();
headers.remove(header::COOKIE);
headers.insert(header::COOKIE, cookies_passed_to_upstream.parse()?);
let sticky_cookie = StickyCookie {
value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?,
info: None,
};
Ok(Some(LbContext { sticky_cookie }))
}
}
}
#[cfg(feature = "sticky-cookie")]
/// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated.
/// Set-Cookie response header could be in multiple lines.
/// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie
pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from_lb: &LbContext) -> Result<()> {
let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?;
let new_header_val: HeaderValue = sticky_cookie_string.parse()?;
let expected_cookie_name = &context_from_lb.sticky_cookie.value.name;
match headers.entry(header::SET_COOKIE) {
header::Entry::Vacant(entry) => {
entry.insert(new_header_val);
}
header::Entry::Occupied(mut entry) => {
let mut flag = false;
for e in entry.iter_mut() {
if e.to_str().unwrap_or("").starts_with(expected_cookie_name) {
*e = new_header_val.clone();
flag = true;
}
}
if !flag {
entry.append(new_header_val);
}
}
};
Ok(())
}
/// Apply options to request header, which are specified in the configuration
pub(super) fn apply_upstream_options_to_header(
headers: &mut HeaderMap,
_client_addr: &SocketAddr,
upstream: &UpstreamGroup,
upstream_base_uri: &Uri,
) -> Result<()> {
for opt in upstream.opts.iter() {
match opt {
UpstreamOption::OverrideHost => {
// overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy)
let upstream_host = upstream_base_uri
.host()
.ok_or_else(|| anyhow!("No hostname is given in override_host option"))?;
headers
.insert(header::HOST, HeaderValue::from_str(upstream_host)?)
.ok_or_else(|| anyhow!("Failed to insert host header in override_host option"))?;
}
UpstreamOption::UpgradeInsecureRequests => {
// add upgrade-insecure-requests in request header if not exist
headers
.entry(header::UPGRADE_INSECURE_REQUESTS)
.or_insert(HeaderValue::from_bytes(&[b'1']).unwrap());
}
_ => (),
}
}
Ok(())
}
/// Append header entry with comma according to [RFC9110](https://datatracker.ietf.org/doc/html/rfc9110)
pub(super) fn append_header_entry_with_comma(headers: &mut HeaderMap, key: &str, value: &str) -> Result<()> {
match headers.entry(HeaderName::from_bytes(key.as_bytes())?) {
header::Entry::Vacant(entry) => {
entry.insert(value.parse::<HeaderValue>()?);
}
header::Entry::Occupied(mut entry) => {
// entry.append(value.parse::<HeaderValue>()?);
let mut new_value = Vec::<u8>::with_capacity(entry.get().as_bytes().len() + 2 + value.len());
new_value.put_slice(entry.get().as_bytes());
new_value.put_slice(&[b',', b' ']);
new_value.put_slice(value.as_bytes());
entry.insert(HeaderValue::from_bytes(&new_value)?);
}
}
Ok(())
}
/// Add header entry if not exist
pub(super) fn add_header_entry_if_not_exist(
headers: &mut HeaderMap,
key: impl Into<Cow<'static, str>>,
value: impl Into<Cow<'static, str>>,
) -> Result<()> {
match headers.entry(HeaderName::from_bytes(key.into().as_bytes())?) {
header::Entry::Vacant(entry) => {
entry.insert(value.into().parse::<HeaderValue>()?);
}
header::Entry::Occupied(_) => (),
};
Ok(())
}
/// Overwrite header entry if exist
pub(super) fn add_header_entry_overwrite_if_exist(
headers: &mut HeaderMap,
key: impl Into<Cow<'static, str>>,
value: impl Into<Cow<'static, str>>,
) -> Result<()> {
match headers.entry(HeaderName::from_bytes(key.into().as_bytes())?) {
header::Entry::Vacant(entry) => {
entry.insert(value.into().parse::<HeaderValue>()?);
}
header::Entry::Occupied(mut entry) => {
entry.insert(HeaderValue::from_bytes(value.into().as_bytes())?);
}
}
Ok(())
}
/// Align cookie values in single line
/// Sometimes violates [RFC6265](https://www.rfc-editor.org/rfc/rfc6265#section-5.4) (for http/1.1).
/// This is allowed in RFC7540 (for http/2) as mentioned [here](https://stackoverflow.com/questions/4843556/in-http-specification-what-is-the-string-that-separates-cookies).
pub(super) fn make_cookie_single_line(headers: &mut HeaderMap) -> Result<()> {
let cookies = headers
.iter()
.filter(|(k, _)| **k == header::COOKIE)
.map(|(_, v)| v.to_str().unwrap_or(""))
.collect::<Vec<_>>()
.join("; ");
if !cookies.is_empty() {
headers.remove(header::COOKIE);
headers.insert(header::COOKIE, HeaderValue::from_bytes(cookies.as_bytes())?);
}
Ok(())
}
/// Add forwarding headers like `x-forwarded-for`.
pub(super) fn add_forwarding_header(
headers: &mut HeaderMap,
client_addr: &SocketAddr,
listen_addr: &SocketAddr,
tls: bool,
uri_str: &str,
) -> Result<()> {
// default process
// optional process defined by upstream_option is applied in fn apply_upstream_options
let canonical_client_addr = client_addr.to_canonical().ip().to_string();
append_header_entry_with_comma(headers, "x-forwarded-for", &canonical_client_addr)?;
// Single line cookie header
// TODO: This should be only for HTTP/1.1. For 2+, this can be multi-lined.
make_cookie_single_line(headers)?;
/////////// As Nginx
// If we receive X-Forwarded-Proto, pass it through; otherwise, pass along the
// scheme used to connect to this server
add_header_entry_if_not_exist(headers, "x-forwarded-proto", if tls { "https" } else { "http" })?;
// If we receive X-Forwarded-Port, pass it through; otherwise, pass along the
// server port the client connected to
add_header_entry_if_not_exist(headers, "x-forwarded-port", listen_addr.port().to_string())?;
/////////// As Nginx-Proxy
// x-real-ip
add_header_entry_overwrite_if_exist(headers, "x-real-ip", canonical_client_addr)?;
// x-forwarded-ssl
add_header_entry_overwrite_if_exist(headers, "x-forwarded-ssl", if tls { "on" } else { "off" })?;
// x-original-uri
add_header_entry_overwrite_if_exist(headers, "x-original-uri", uri_str.to_string())?;
// proxy
add_header_entry_overwrite_if_exist(headers, "proxy", "")?;
Ok(())
}
/// Remove connection header
pub(super) fn remove_connection_header(headers: &mut HeaderMap) {
if let Some(values) = headers.get(header::CONNECTION) {
if let Ok(v) = values.clone().to_str() {
for m in v.split(',') {
if !m.is_empty() {
headers.remove(m.trim());
}
}
}
}
}
/// Hop header values which are removed at proxy
const HOP_HEADERS: &[&str] = &[
"connection",
"te",
"trailer",
"keep-alive",
"proxy-connection",
"proxy-authenticate",
"proxy-authorization",
"transfer-encoding",
"upgrade",
];
/// Remove hop headers
pub(super) fn remove_hop_header(headers: &mut HeaderMap) {
HOP_HEADERS.iter().for_each(|key| {
headers.remove(*key);
});
}
/// Extract upgrade header value if exist
pub(super) fn extract_upgrade(headers: &HeaderMap) -> Option<String> {
if let Some(c) = headers.get(header::CONNECTION) {
if c
.to_str()
.unwrap_or("")
.split(',')
.any(|w| w.trim().to_ascii_lowercase() == header::UPGRADE.as_str().to_ascii_lowercase())
{
if let Some(u) = headers.get(header::UPGRADE) {
if let Ok(m) = u.to_str() {
debug!("Upgrade in request header: {}", m);
return Some(m.to_owned());
}
}
}
}
None
}

View file

@ -0,0 +1,64 @@
use crate::{
backend::{UpstreamGroup, UpstreamOption},
error::*,
};
use hyper::{header, Request};
////////////////////////////////////////////////////
// Functions to manipulate request line
/// Apply upstream options in request line, specified in the configuration
pub(super) fn apply_upstream_options_to_request_line<B>(req: &mut Request<B>, upstream: &UpstreamGroup) -> Result<()> {
for opt in upstream.opts.iter() {
match opt {
UpstreamOption::ForceHttp11Upstream => *req.version_mut() = hyper::Version::HTTP_11,
UpstreamOption::ForceHttp2Upstream => {
// case: h2c -> https://www.rfc-editor.org/rfc/rfc9113.txt
// Upgrade from HTTP/1.1 to HTTP/2 is deprecated. So, http-2 prior knowledge is required.
*req.version_mut() = hyper::Version::HTTP_2;
}
_ => (),
}
}
Ok(())
}
/// Trait defining parser of hostname
pub trait ParseHost {
fn parse_host(&self) -> Result<&[u8]>;
}
impl<B> ParseHost for Request<B> {
/// Extract hostname from either the request HOST header or request line
fn parse_host(&self) -> Result<&[u8]> {
let headers_host = self.headers().get(header::HOST);
let uri_host = self.uri().host();
// let uri_port = self.uri().port_u16();
if !(!(headers_host.is_none() && uri_host.is_none())) {
return Err(RpxyError::Request("No host in request header"));
}
// prioritize server_name in uri
uri_host.map_or_else(
|| {
let m = headers_host.unwrap().as_bytes();
if m.starts_with(&[b'[']) {
// v6 address with bracket case. if port is specified, always it is in this case.
let mut iter = m.split(|ptr| ptr == &b'[' || ptr == &b']');
iter.next().ok_or(RpxyError::Request("Invalid Host"))?; // first item is always blank
iter.next().ok_or(RpxyError::Request("Invalid Host"))
} else if m.len() - m.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 {
// v6 address case, if 2 or more ':' is contained
Ok(m)
} else {
// v4 address or hostname
m.split(|colon| colon == &b':')
.next()
.ok_or(RpxyError::Request("Invalid Host"))
}
},
|v| Ok(v.as_bytes()),
)
}
}

View file

@ -0,0 +1,35 @@
// Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
use crate::error::*;
use hyper::{Body, Request, Response, StatusCode, Uri};
////////////////////////////////////////////////////
// Functions to create response (error or redirect)
/// Generate a synthetic response message of a certain error status code
pub(super) fn http_error(status_code: StatusCode) -> Result<Response<Body>> {
let response = Response::builder().status(status_code).body(Body::empty())?;
Ok(response)
}
/// Generate synthetic response message of a redirection to https host with 301
pub(super) fn secure_redirection<B>(
server_name: &str,
tls_port: Option<u16>,
req: &Request<B>,
) -> Result<Response<Body>> {
let pq = match req.uri().path_and_query() {
Some(x) => x.as_str(),
_ => "",
};
let new_uri = Uri::builder().scheme("https").path_and_query(pq);
let dest_uri = match tls_port {
Some(443) | None => new_uri.authority(server_name),
Some(p) => new_uri.authority(format!("{server_name}:{p}")),
}
.build()?;
let response = Response::builder()
.status(StatusCode::MOVED_PERMANENTLY)
.header("Location", dest_uri.to_string())
.body(Body::empty())?;
Ok(response)
}