Merge branch 'tmp/sticky-cookie' into feat/sticky-cookie-feature
This commit is contained in: commit d8cadf06af
80 changed files with 4870 additions and 867 deletions
@@ -1,11 +1,11 @@
 use crate::{
+  AppConfig, AppConfigList,
   error::*,
   log::*,
   name_exp::{ByteName, ServerName},
-  AppConfig, AppConfigList,
 };
+use ahash::HashMap;
 use derive_builder::Builder;
-use rustc_hash::FxHashMap as HashMap;
 use std::borrow::Cow;
 
 use super::upstream::PathManager;
@@ -26,6 +26,7 @@ pub struct BackendApp {
   pub https_redirection: Option<bool>,
   /// tls settings: mutual TLS is enabled
   #[builder(default)]
+  #[allow(unused)]
   pub mutual_tls: Option<bool>,
 }
 impl<'a> BackendAppBuilder {
@@ -7,8 +7,8 @@ pub use super::{
 use derive_builder::Builder;
 use rand::Rng;
 use std::sync::{
-  atomic::{AtomicUsize, Ordering},
   Arc,
+  atomic::{AtomicUsize, Ordering},
 };
 
 /// Constants to specify a load balance option
@@ -80,8 +80,8 @@ impl LoadBalanceRandomBuilder {
 impl LoadBalanceWithPointer for LoadBalanceRandom {
   /// Returns the random index within the range
   fn get_ptr(&self, _info: Option<&LoadBalanceContext>) -> PointerToUpstream {
-    let mut rng = rand::thread_rng();
-    let ptr = rng.gen_range(0..self.num_upstreams);
+    let mut rng = rand::rng();
+    let ptr = rng.random_range(0..self.num_upstreams);
     PointerToUpstream { ptr, context: None }
   }
 }
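Note: the hunk above tracks the rand 0.8 → 0.9 API rename (`thread_rng()` → `rng()`, `gen_range()` → `random_range()`). A minimal standalone sketch of the new calls, assuming rand 0.9 (the helper name is illustrative, not part of the diff):

```rust
use rand::Rng; // trait providing random_range()

// Illustrative helper mirroring get_ptr() above.
fn pick_upstream(num_upstreams: usize) -> usize {
  // rand 0.9: rand::rng() replaces rand::thread_rng(),
  // and random_range() replaces gen_range().
  let mut rng = rand::rng();
  rng.random_range(0..num_upstreams)
}
```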
@@ -1,16 +1,16 @@
 use super::{
+  Upstream,
   load_balance_main::{LoadBalanceContext, LoadBalanceWithPointer, PointerToUpstream},
   sticky_cookie::StickyCookieConfig,
-  Upstream,
 };
 use crate::{constants::STICKY_COOKIE_NAME, log::*};
+use ahash::HashMap;
 use derive_builder::Builder;
-use rustc_hash::FxHashMap as HashMap;
 use std::{
   borrow::Cow,
   sync::{
-    atomic::{AtomicUsize, Ordering},
     Arc,
+    atomic::{AtomicUsize, Ordering},
   },
 };
 
@@ -112,13 +112,16 @@ impl LoadBalanceWithPointer for LoadBalanceSticky {
       }
       Some(context) => {
         let server_id = &context.sticky_cookie.value.value;
-        if let Some(server_index) = self.get_server_index_from_id(server_id) {
-          debug!("Valid sticky cookie: id={}, index={}", server_id, server_index);
-          server_index
-        } else {
-          debug!("Invalid sticky cookie: id={}", server_id);
-          self.simple_increment_ptr()
-        }
+        self.get_server_index_from_id(server_id).map_or_else(
+          || {
+            debug!("Invalid sticky cookie: id={}", server_id);
+            self.simple_increment_ptr()
+          },
+          |server_index| {
+            debug!("Valid sticky cookie: id={}, index={}", server_id, server_index);
+            server_index
+          },
+        )
       }
     };
 
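Note: the refactor above replaces an `if let/else` with `Option::map_or_else`, which takes the `None` closure first and the `Some` closure second. A minimal sketch (names are illustrative):

```rust
// Illustrative: fall back to a default index when no valid cookie maps to a server.
fn choose(index_from_cookie: Option<usize>, fallback: usize) -> usize {
  index_from_cookie.map_or_else(
    || fallback, // None arm: invalid or missing sticky cookie
    |i| i,       // Some arm: use the resolved server index
  )
}
```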
@@ -9,7 +9,7 @@ use super::upstream::Upstream;
 use thiserror::Error;
 
 pub use load_balance_main::{
-  load_balance_options, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder,
+  LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder, load_balance_options,
 };
 #[cfg(feature = "sticky-cookie")]
 pub use load_balance_sticky::LoadBalanceStickyBuilder;
@@ -91,12 +91,7 @@ impl<'a> StickyCookieBuilder {
     self
   }
   /// Set the meta information of sticky cookie
-  pub fn info(
-    &mut self,
-    domain: impl Into<Cow<'a, str>>,
-    path: impl Into<Cow<'a, str>>,
-    duration_secs: i64,
-  ) -> &mut Self {
+  pub fn info(&mut self, domain: impl Into<Cow<'a, str>>, path: impl Into<Cow<'a, str>>, duration_secs: i64) -> &mut Self {
     let info = StickyCookieInfoBuilder::default()
       .domain(domain)
       .path(path)
@@ -1,7 +1,7 @@
 #[cfg(feature = "sticky-cookie")]
 use super::load_balance::LoadBalanceStickyBuilder;
 use super::load_balance::{
-  load_balance_options as lb_opts, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder,
+  LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder, load_balance_options as lb_opts,
 };
 // use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption};
 use super::upstream_opts::UpstreamOption;
@@ -11,10 +11,10 @@ use crate::{
   log::*,
   name_exp::{ByteName, PathName},
 };
+use ahash::{HashMap, HashSet};
 #[cfg(feature = "sticky-cookie")]
-use base64::{engine::general_purpose, Engine as _};
+use base64::{Engine as _, engine::general_purpose};
 use derive_builder::Builder;
-use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
 #[cfg(feature = "sticky-cookie")]
 use sha2::{Digest, Sha256};
 use std::borrow::Cow;
@@ -72,27 +72,22 @@ impl PathManager {
       .inner
       .iter()
       .filter(|(route_bytes, _)| {
-        match path_name.starts_with(route_bytes) {
-          true => {
-            route_bytes.len() == 1 // route = '/', i.e., default
-              || match path_name.get(route_bytes.len()) {
-                None => true, // exact case
-                Some(p) => p == &b'/', // sub-path case
-              }
-          }
-          _ => false,
+        path_name.starts_with(route_bytes) && {
+          route_bytes.len() == 1 // route = '/', i.e., default
+          || path_name.get(route_bytes.len()).map_or(
+            true, // exact case
+            |p| p == &b'/'
+          ) // sub-path case
         }
       })
       .max_by_key(|(route_bytes, _)| route_bytes.len());
-    if let Some((path, u)) = matched_upstream {
+    matched_upstream.map(|(path, u)| {
       debug!(
         "Found upstream: {:?}",
         path.try_into().unwrap_or_else(|_| "<none>".to_string())
       );
-      Some(u)
-    } else {
-      None
-    }
+      u
+    })
   }
 }
 
@@ -211,14 +206,15 @@ impl UpstreamCandidatesBuilder {
   }
   /// Set the activated upstream options defined in [[UpstreamOption]]
   pub fn options(&mut self, v: &Option<Vec<String>>) -> &mut Self {
-    let opts = if let Some(opts) = v {
-      opts
-        .iter()
-        .filter_map(|str| UpstreamOption::try_from(str.as_str()).ok())
-        .collect::<HashSet<UpstreamOption>>()
-    } else {
-      Default::default()
-    };
+    let opts = v.as_ref().map_or_else(
+      || Default::default(),
+      |opts| {
+        opts
+          .iter()
+          .filter_map(|str| UpstreamOption::try_from(str.as_str()).ok())
+          .collect::<HashSet<UpstreamOption>>()
+      },
+    );
     self.options = Some(opts);
     self
   }
@@ -32,3 +32,9 @@ pub const MAX_CACHE_EACH_SIZE: usize = 65_535;
 pub const MAX_CACHE_EACH_SIZE_ON_MEMORY: usize = 4_096;
 
 // TODO: max cache size in total
+
+/// Logging event name TODO: Other separated logs?
+pub mod log_event_names {
+  /// access log
+  pub const ACCESS_LOG: &str = "rpxy::access";
+}
@@ -1,6 +1,6 @@
 use std::sync::{
-  atomic::{AtomicUsize, Ordering},
   Arc,
+  atomic::{AtomicUsize, Ordering},
 };
 
 #[derive(Debug, Clone, Default)]
@@ -37,8 +37,11 @@ pub enum RpxyError {
 
   // http/3 errors
   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
-  #[error("H3 error: {0}")]
-  H3Error(#[from] h3::Error),
+  #[error("h3 connection error: {0}")]
+  H3ConnectionError(#[from] h3::error::ConnectionError),
+  #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
+  #[error("h3 stream error: {0}")]
+  H3StreamError(#[from] h3::error::StreamError),
   // #[cfg(feature = "http3-s2n")]
   // #[error("H3 error: {0}")]
   // H3Error(#[from] s2n_quic_h3::h3::Error),
rpxy-lib/src/forwarder/cache/cache_main.rs (vendored): 103 changed lines
@@ -1,10 +1,10 @@
 use super::cache_error::*;
 use crate::{
   globals::Globals,
-  hyper_ext::body::{full, BoxBody, ResponseBody, UnboundedStreamBody},
+  hyper_ext::body::{BoxBody, ResponseBody, UnboundedStreamBody, full},
   log::*,
 };
-use base64::{engine::general_purpose, Engine as _};
+use base64::{Engine as _, engine::general_purpose};
 use bytes::{Buf, Bytes, BytesMut};
 use futures::channel::mpsc;
 use http::{Request, Response, Uri};
@@ -16,8 +16,8 @@ use sha2::{Digest, Sha256};
 use std::{
   path::{Path, PathBuf},
   sync::{
-    atomic::{AtomicUsize, Ordering},
     Arc, Mutex,
+    atomic::{AtomicUsize, Ordering},
   },
   time::SystemTime,
 };
@@ -52,23 +52,30 @@ impl RpxyCache {
     if !globals.proxy_config.cache_enabled {
       return None;
     }
-    let cache_dir = globals.proxy_config.cache_dir.as_ref().unwrap();
+    let cache_dir = match globals.proxy_config.cache_dir.as_ref() {
+      Some(dir) => dir,
+      None => {
+        warn!("Cache directory not set in proxy config");
+        return None;
+      }
+    };
     let file_store = FileStore::new(&globals.runtime_handle).await;
     let inner = LruCacheManager::new(globals.proxy_config.cache_max_entry);
 
     let max_each_size = globals.proxy_config.cache_max_each_size;
     let mut max_each_size_on_memory = globals.proxy_config.cache_max_each_size_on_memory;
     if max_each_size < max_each_size_on_memory {
-      warn!(
-        "Maximum size of on memory cache per entry must be smaller than or equal to the maximum of each file cache"
-      );
+      warn!("Maximum size of on-memory cache per entry must be smaller than or equal to the maximum of each file cache");
       max_each_size_on_memory = max_each_size;
     }
 
     if let Err(e) = fs::remove_dir_all(cache_dir).await {
       warn!("Failed to clean up the cache dir: {e}");
-    };
-    fs::create_dir_all(&cache_dir).await.unwrap();
+    }
+    if let Err(e) = fs::create_dir_all(&cache_dir).await {
+      error!("Failed to create cache dir: {e}");
+      return None;
+    }
 
     Some(Self {
       file_store,
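Note: the constructor above now degrades gracefully instead of panicking: a missing `cache_dir` or a failed `create_dir_all` yields `None` rather than an `unwrap` panic. A minimal sketch of the pattern, with a hypothetical config type for illustration:

```rust
use std::path::PathBuf;

// Hypothetical config type, not the crate's real one.
struct Config {
  cache_dir: Option<PathBuf>,
}

fn init_cache(cfg: &Config) -> Option<PathBuf> {
  let dir = match cfg.cache_dir.as_ref() {
    Some(dir) => dir,
    None => {
      eprintln!("Cache directory not set in proxy config");
      return None; // previously .unwrap() would panic here
    }
  };
  Some(dir.clone())
}
```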
@@ -89,12 +96,7 @@ impl RpxyCache {
   }
 
   /// Put response into the cache
-  pub(crate) async fn put(
-    &self,
-    uri: &hyper::Uri,
-    mut body: Incoming,
-    policy: &CachePolicy,
-  ) -> CacheResult<UnboundedStreamBody> {
+  pub(crate) async fn put(&self, uri: &hyper::Uri, mut body: Incoming, policy: &CachePolicy) -> CacheResult<UnboundedStreamBody> {
     let cache_manager = self.inner.clone();
     let mut file_store = self.file_store.clone();
     let uri = uri.clone();
@@ -155,7 +157,7 @@ impl RpxyCache {
       let mut hasher = Sha256::new();
       hasher.update(buf.as_ref());
       let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref());
-      debug!("Cached data: {} bytes, hash = {:?}", size, hash_bytes);
+      trace!("Cached data: {} bytes, hash = {:?}", size, hash_bytes);
 
       // Create cache object
       let cache_key = derive_cache_key_from_uri(&uri);
@@ -188,16 +190,11 @@ impl RpxyCache {
 
   /// Get cached response
   pub(crate) async fn get<R>(&self, req: &Request<R>) -> Option<Response<ResponseBody>> {
-    debug!(
-      "Current cache status: (total, on-memory, file) = {:?}",
-      self.count().await
-    );
+    trace!("Current cache status: (total, on-memory, file) = {:?}", self.count().await);
     let cache_key = derive_cache_key_from_uri(req.uri());
 
     // First check cache chance
-    let Ok(Some(cached_object)) = self.inner.get(&cache_key) else {
-      return None;
-    };
+    let cached_object = self.inner.get(&cache_key).ok()??;
 
     // Secondly check the cache freshness as an HTTP message
     let now = SystemTime::now();
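Note: `self.inner.get(&cache_key).ok()??` does two steps at once: `.ok()` maps `Result<Option<T>, E>` to `Option<Option<T>>`, and the double `?` unwraps both layers inside a function returning `Option`. A minimal sketch:

```rust
// Illustrative: Err(_) -> None, Ok(None) -> None, Ok(Some(v)) -> Some(v).
fn get_entry(res: Result<Option<u32>, ()>) -> Option<u32> {
  let v = res.ok()??;
  Some(v)
}
```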
@@ -268,25 +265,20 @@ impl FileStore {
     let inner = self.inner.read().await;
     inner.cnt
   }
-  /// Create a temporary file cache
+  /// Create a temporary file cache, returns error if file cannot be created or written
   async fn create(&mut self, cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> {
     let mut inner = self.inner.write().await;
     inner.create(cache_object, body_bytes).await
   }
-  /// Evict a temporary file cache
+  /// Evict a temporary file cache, logs warning if removal fails
   async fn evict(&self, path: impl AsRef<Path>) {
-    // Acquire the write lock
     let mut inner = self.inner.write().await;
     if let Err(e) = inner.remove(path).await {
       warn!("Eviction failed during file object removal: {:?}", e);
-    };
+    }
   }
-  /// Read a temporary file cache
-  async fn read(
-    &self,
-    path: impl AsRef<Path> + Send + Sync + 'static,
-    hash: &Bytes,
-  ) -> CacheResult<UnboundedStreamBody> {
+  /// Read a temporary file cache, returns error if file cannot be opened or hash mismatches
+  async fn read(&self, path: impl AsRef<Path> + Send + Sync + 'static, hash: &Bytes) -> CacheResult<UnboundedStreamBody> {
     let inner = self.inner.read().await;
     inner.read(path, hash).await
   }
@@ -321,26 +313,22 @@ impl FileStoreInner {
         return Err(CacheError::InvalidCacheTarget);
       }
     };
-    let Ok(mut file) = File::create(&cache_filepath).await else {
-      return Err(CacheError::FailedToCreateFileCache);
-    };
+    let mut file = File::create(&cache_filepath)
+      .await
+      .map_err(|_| CacheError::FailedToCreateFileCache)?;
     let mut bytes_clone = body_bytes.clone();
     while bytes_clone.has_remaining() {
-      if let Err(e) = file.write_buf(&mut bytes_clone).await {
+      file.write_buf(&mut bytes_clone).await.map_err(|e| {
         error!("Failed to write file cache: {e}");
-        return Err(CacheError::FailedToWriteFileCache);
-      };
+        CacheError::FailedToWriteFileCache
+      })?;
     }
     self.cnt += 1;
     Ok(())
   }
 
   /// Retrieve a stored temporary file cache
-  async fn read(
-    &self,
-    path: impl AsRef<Path> + Send + Sync + 'static,
-    hash: &Bytes,
-  ) -> CacheResult<UnboundedStreamBody> {
+  async fn read(&self, path: impl AsRef<Path> + Send + Sync + 'static, hash: &Bytes) -> CacheResult<UnboundedStreamBody> {
     let Ok(mut file) = File::open(&path).await else {
       warn!("Cache file object cannot be opened");
       return Err(CacheError::FailedToOpenCacheFile);
@@ -455,11 +443,14 @@ impl LruCacheManager {
     self.cnt.load(Ordering::Relaxed)
   }
 
-  /// Evict an entry
+  /// Evict an entry from the LRU cache, logs error if mutex cannot be acquired
   fn evict(&self, cache_key: &str) -> Option<(String, CacheObject)> {
-    let Ok(mut lock) = self.inner.lock() else {
-      error!("Mutex can't be locked to evict a cache entry");
-      return None;
+    let mut lock = match self.inner.lock() {
+      Ok(lock) => lock,
+      Err(_) => {
+        error!("Mutex can't be locked to evict a cache entry");
+        return None;
+      }
     };
     let res = lock.pop_entry(cache_key);
     // This may be inconsistent with the actual number of entries
@@ -467,24 +458,24 @@ impl LruCacheManager {
     res
   }
 
-  /// Push an entry
+  /// Push an entry into the LRU cache, returns error if mutex cannot be acquired
   fn push(&self, cache_key: &str, cache_object: &CacheObject) -> CacheResult<Option<(String, CacheObject)>> {
-    let Ok(mut lock) = self.inner.lock() else {
+    let mut lock = self.inner.lock().map_err(|_| {
       error!("Failed to acquire mutex lock for writing cache entry");
-      return Err(CacheError::FailedToAcquiredMutexLockForCache);
-    };
+      CacheError::FailedToAcquiredMutexLockForCache
+    })?;
     let res = Ok(lock.push(cache_key.to_string(), cache_object.clone()));
     // This may be inconsistent with the actual number of entries
     self.cnt.store(lock.len(), Ordering::Relaxed);
     res
   }
 
-  /// Get an entry
+  /// Get an entry from the LRU cache, returns error if mutex cannot be acquired
   fn get(&self, cache_key: &str) -> CacheResult<Option<CacheObject>> {
-    let Ok(mut lock) = self.inner.lock() else {
+    let mut lock = self.inner.lock().map_err(|_| {
       error!("Mutex can't be locked for checking cache entry");
-      return Err(CacheError::FailedToAcquiredMutexLockForCheck);
-    };
+      CacheError::FailedToAcquiredMutexLockForCheck
+    })?;
     let Some(cached_object) = lock.get(cache_key) else {
       return Ok(None);
     };
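Note: the pattern above converts a mutex `PoisonError` into a domain error with `map_err`, then propagates it with `?`, replacing the `let Ok(..) else` form. A minimal sketch with a hypothetical error type:

```rust
use std::sync::Mutex;

// Hypothetical domain error for illustration.
#[derive(Debug)]
enum CacheError {
  Poisoned,
}

fn len_under_lock(m: &Mutex<Vec<u8>>) -> Result<usize, CacheError> {
  // map_err discards the PoisonError and substitutes the domain error;
  // `?` propagates it to the caller.
  let guard = m.lock().map_err(|_| CacheError::Poisoned)?;
  Ok(guard.len())
}
```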
rpxy-lib/src/forwarder/cache/mod.rs (vendored): 2 changed lines
@@ -2,4 +2,4 @@ mod cache_error;
 mod cache_main;
 
 pub use cache_error::CacheError;
-pub(crate) use cache_main::{get_policy_if_cacheable, RpxyCache};
+pub(crate) use cache_main::{RpxyCache, get_policy_if_cacheable};
@@ -9,13 +9,13 @@ use async_trait::async_trait;
 use http::{Request, Response, Version};
 use hyper::body::{Body, Incoming};
 use hyper_util::client::legacy::{
-  connect::{Connect, HttpConnector},
   Client,
+  connect::{Connect, HttpConnector},
 };
 use std::sync::Arc;
 
 #[cfg(feature = "cache")]
-use super::cache::{get_policy_if_cacheable, RpxyCache};
+use super::cache::{RpxyCache, get_policy_if_cacheable};
 
 #[async_trait]
 /// Definition of the forwarder that simply forward requests from downstream client to upstream app servers.
@@ -126,9 +126,9 @@ where
     warn!(
       "
 --------------------------------------------------------------------------------------------------
-Request forwarder is working without TLS support!!!
-We recommend to use this just for testing.
-Please enable native-tls-backend or rustls-backend feature to enable TLS support.
+Request forwarder is working without TLS support!
+This mode is intended for testing only.
+Enable 'native-tls-backend' or 'rustls-backend' feature for TLS support.
 --------------------------------------------------------------------------------------------------"
     );
     let executor = LocalExecutor::new(_globals.runtime_handle.clone());
@@ -159,7 +159,7 @@ where
   /// Build forwarder
   pub async fn try_new(_globals: &Arc<Globals>) -> RpxyResult<Self> {
     // build hyper client with hyper-tls
-    info!("Native TLS support is enabled for the connection to backend applications");
+    info!("Native TLS support enabled for backend connections (native-tls)");
     let executor = LocalExecutor::new(_globals.runtime_handle.clone());
 
     let try_build_connector = |alpns: &[&str]| {
@@ -209,14 +209,14 @@ where
     #[cfg(feature = "webpki-roots")]
     let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots();
     #[cfg(feature = "webpki-roots")]
-    info!("Mozilla WebPKI root certs with rustls is used for the connection to backend applications");
+    info!("Rustls backend: Mozilla WebPKI root certs used for backend connections");
 
     #[cfg(not(feature = "webpki-roots"))]
     let builder = hyper_rustls::HttpsConnectorBuilder::new().with_platform_verifier();
     #[cfg(not(feature = "webpki-roots"))]
     let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_platform_verifier();
     #[cfg(not(feature = "webpki-roots"))]
-    info!("Platform verifier with rustls is used for the connection to backend applications");
+    info!("Rustls backend: Platform verifier used for backend connections");
 
     let mut http = HttpConnector::new();
     http.enforce_http(false);
@@ -226,7 +226,9 @@ where
     let connector = builder.https_or_http().enable_all_versions().wrap_connector(http.clone());
     let connector_h2 = builder_h2.https_or_http().enable_http2().wrap_connector(http);
     let inner = Client::builder(LocalExecutor::new(_globals.runtime_handle.clone())).build::<_, B1>(connector);
-    let inner_h2 = Client::builder(LocalExecutor::new(_globals.runtime_handle.clone())).build::<_, B1>(connector_h2);
+    let inner_h2 = Client::builder(LocalExecutor::new(_globals.runtime_handle.clone()))
+      .http2_only(true)
+      .build::<_, B1>(connector_h2);
 
     Ok(Self {
       inner,
@@ -2,7 +2,6 @@ use crate::{constants::*, count::RequestCount};
 use hot_reload::ReloaderReceiver;
 use rpxy_certs::ServerCryptoBase;
 use std::{net::SocketAddr, time::Duration};
-use tokio_util::sync::CancellationToken;
 
 /// Global object containing proxy configurations and shared object like counters.
 /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks.
@@ -13,14 +12,12 @@ pub struct Globals {
   pub request_count: RequestCount,
   /// Shared context - Async task runtime handler
   pub runtime_handle: tokio::runtime::Handle,
-  /// Shared context - Notify object to stop async tasks
-  pub cancel_token: Option<CancellationToken>,
   /// Shared context - Certificate reloader service receiver // TODO: newer one
   pub cert_reloader_rx: Option<ReloaderReceiver<ServerCryptoBase>>,
 
   #[cfg(feature = "acme")]
   /// ServerConfig used for only ACME challenge for ACME domains
-  pub server_configs_acme_challenge: std::sync::Arc<rustc_hash::FxHashMap<String, std::sync::Arc<rustls::ServerConfig>>>,
+  pub server_configs_acme_challenge: std::sync::Arc<ahash::HashMap<String, std::sync::Arc<rustls::ServerConfig>>>,
 }
 
 /// Configuration parameters for proxy transport and request handlers
@@ -1,7 +1,7 @@
 use super::watch;
 use crate::error::*;
 use futures_channel::{mpsc, oneshot};
-use futures_util::{stream::FusedStream, Future, Stream};
+use futures_util::{Future, Stream, stream::FusedStream};
 use http::HeaderMap;
 use hyper::body::{Body, Bytes, Frame, SizeHint};
 use std::{
@@ -1,7 +1,7 @@
 use super::body::IncomingLike;
 use crate::error::RpxyError;
 use futures::channel::mpsc::UnboundedReceiver;
-use http_body_util::{combinators, BodyExt, Empty, Full, StreamBody};
+use http_body_util::{BodyExt, Empty, Full, StreamBody, combinators};
 use hyper::body::{Body, Bytes, Frame, Incoming};
 use std::pin::Pin;
 
@@ -12,5 +12,5 @@ pub(crate) mod rt {
 #[allow(unused)]
 pub(crate) mod body {
   pub(crate) use super::body_incoming_like::IncomingLike;
-  pub(crate) use super::body_type::{empty, full, BoxBody, RequestBody, ResponseBody, UnboundedStreamBody};
+  pub(crate) use super::body_type::{BoxBody, RequestBody, ResponseBody, UnboundedStreamBody, empty, full};
 }
@@ -7,8 +7,8 @@
 
 use futures_util::task::AtomicWaker;
 use std::sync::{
-  atomic::{AtomicUsize, Ordering},
   Arc,
+  atomic::{AtomicUsize, Ordering},
 };
 use std::task;
 
@@ -27,6 +27,7 @@ use std::sync::Arc;
 use tokio_util::sync::CancellationToken;
 
 /* ------------------------------------------------ */
+pub use crate::constants::log_event_names;
 pub use crate::globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri};
 pub mod reexports {
   pub use hyper::Uri;
@@ -43,12 +44,10 @@ pub struct RpxyOptions {
   pub cert_rx: Option<ReloaderReceiver<ServerCryptoBase>>, // TODO:
   /// Async task runtime handler
   pub runtime_handle: tokio::runtime::Handle,
-  /// Notify object to stop async tasks
-  pub cancel_token: Option<CancellationToken>,
 
   #[cfg(feature = "acme")]
   /// ServerConfig used for only ACME challenge for ACME domains
-  pub server_configs_acme_challenge: Arc<rustc_hash::FxHashMap<String, Arc<rustls::ServerConfig>>>,
+  pub server_configs_acme_challenge: Arc<ahash::HashMap<String, Arc<rustls::ServerConfig>>>,
 }
 
 /// Entrypoint that creates and spawns tasks of reverse proxy services
@@ -58,10 +57,10 @@ pub async fn entrypoint(
     app_config_list,
     cert_rx, // TODO:
     runtime_handle,
-    cancel_token,
     #[cfg(feature = "acme")]
     server_configs_acme_challenge,
   }: &RpxyOptions,
+  cancel_token: CancellationToken,
 ) -> RpxyResult<()> {
   #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))]
   warn!("Both \"http3-quinn\" and \"http3-s2n\" features are enabled. \"http3-quinn\" will be used");
@@ -117,7 +116,6 @@ pub async fn entrypoint(
     proxy_config: proxy_config.clone(),
     request_count: Default::default(),
     runtime_handle: runtime_handle.clone(),
-    cancel_token: cancel_token.clone(),
     cert_reloader_rx: cert_rx.clone(),
 
     #[cfg(feature = "acme")]
@@ -153,25 +151,21 @@ pub async fn entrypoint(
       message_handler: message_handler.clone(),
     };
 
-    let cancel_token = globals.cancel_token.as_ref().map(|t| t.child_token());
-    let parent_cancel_token_clone = globals.cancel_token.clone();
+    let cancel_token = cancel_token.clone();
     globals.runtime_handle.spawn(async move {
       info!("rpxy proxy service for {listening_on} started");
-      if let Some(cancel_token) = cancel_token {
-        tokio::select! {
-          _ = cancel_token.cancelled() => {
-            debug!("rpxy proxy service for {listening_on} terminated");
-            Ok(())
-          },
-          proxy_res = proxy.start() => {
-            info!("rpxy proxy service for {listening_on} exited");
-            // cancel other proxy tasks
-            parent_cancel_token_clone.unwrap().cancel();
-            proxy_res
-          }
+      tokio::select! {
+        _ = cancel_token.cancelled() => {
+          debug!("rpxy proxy service for {listening_on} terminated");
+          Ok(())
+        },
+        proxy_res = proxy.start(cancel_token.child_token()) => {
+          info!("rpxy proxy service for {listening_on} exited");
+          // cancel other proxy tasks
+          cancel_token.cancel();
+          proxy_res
         }
-      } else {
-        proxy.start().await
       }
-    })
+    });
@@ -186,9 +180,5 @@ pub async fn entrypoint(
     }
   });
   // returns the first error as the representative error
-  if let Some(e) = errs.next() {
-    return Err(e);
-  }
-
-  Ok(())
+  errs.next().map_or(Ok(()), |e| Err(e))
 }
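Note: `errs.next().map_or(Ok(()), |e| Err(e))` returns the first collected error as the representative result, or `Ok(())` when no task failed. A minimal sketch:

```rust
// Illustrative: Ok(()) if the iterator is empty, otherwise Err(first error).
fn first_error(mut errs: impl Iterator<Item = String>) -> Result<(), String> {
  errs.next().map_or(Ok(()), Err)
}
```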
@@ -1 +1 @@
-pub use tracing::{debug, error, info, warn};
+pub use tracing::{debug, error, info, trace, warn};
@@ -44,10 +44,7 @@ mod tests {
   }
   #[test]
   fn ipv6_to_canonical() {
-    let socket = SocketAddr::new(
-      IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0xdead, 0xbeef)),
-      8080,
-    );
+    let socket = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0xdead, 0xbeef)), 8080);
     assert_eq!(socket.to_canonical(), socket);
   }
   #[test]
@@ -71,7 +71,7 @@ where
         Ok(v)
       }
       Err(e) => {
-        error!("{e}");
+        error!("{e}: {log_data}");
         let code = StatusCode::from(e);
         log_data.status_code(&code).output();
         synthetic_error_response(code)
@@ -107,9 +107,11 @@ where
     let backend_app = match self.app_manager.apps.get(&server_name) {
       Some(backend_app) => backend_app,
       None => {
-        let Some(default_server_name) = &self.app_manager.default_server_name else {
-          return Err(HttpError::NoMatchingBackendApp);
-        };
+        let default_server_name = self
+          .app_manager
+          .default_server_name
+          .as_ref()
+          .ok_or(HttpError::NoMatchingBackendApp)?;
         debug!("Serving by default app");
         self.app_manager.apps.get(default_server_name).unwrap()
       }
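Note: `ok_or` turns an `Option` into a `Result` so the usual `?` propagation applies, replacing the `let Some(..) else { return Err(..) }` form here and in the upstream-candidates lookup below. A minimal sketch:

```rust
// Illustrative: None becomes the given error value.
fn default_name(opt: Option<&str>) -> Result<&str, &'static str> {
  opt.ok_or("no matching backend app")
}
```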
@@ -131,9 +133,7 @@ where
     // Find reverse proxy for given path and choose one of upstream host
     // Longest prefix match
     let path = req.uri().path();
-    let Some(upstream_candidates) = backend_app.path_manager.get(path) else {
-      return Err(HttpError::NoUpstreamCandidates);
-    };
+    let upstream_candidates = backend_app.path_manager.get(path).ok_or(HttpError::NoUpstreamCandidates)?;
 
     // Upgrade in request header
     let upgrade_in_request = extract_upgrade(req.headers());
@@ -147,19 +147,17 @@ where
     let req_on_upgrade = hyper::upgrade::on(&mut req);
 
     // Build request from destination information
-    let _context = match self.generate_request_forwarded(
-      &client_addr,
-      &listen_addr,
-      &mut req,
-      &upgrade_in_request,
-      upstream_candidates,
-      tls_enabled,
-    ) {
-      Err(e) => {
-        return Err(HttpError::FailedToGenerateUpstreamRequest(e.to_string()));
-      }
-      Ok(v) => v,
-    };
+    let _context = self
+      .generate_request_forwarded(
+        &client_addr,
+        &listen_addr,
+        &mut req,
+        &upgrade_in_request,
+        upstream_candidates,
+        tls_enabled,
+      )
+      .map_err(|e| HttpError::FailedToGenerateUpstreamRequest(e.to_string()))?;
 
     debug!(
       "Request to be forwarded: [uri {}, method: {}, version {:?}, headers {:?}]",
       req.uri(),
@@ -173,12 +171,12 @@ where
 
     //////////////
     // Forward request to a chosen backend
-    let mut res_backend = match self.forwarder.request(req).await {
-      Ok(v) => v,
-      Err(e) => {
-        return Err(HttpError::FailedToGetResponseFromBackend(e.to_string()));
-      }
-    };
+    let mut res_backend = self
+      .forwarder
+      .request(req)
+      .await
+      .map_err(|e| HttpError::FailedToGetResponseFromBackend(e.to_string()))?;
 
     //////////////
     // Process reverse proxy context generated during the forwarding request generation.
     #[cfg(feature = "sticky-cookie")]
@@ -191,16 +189,16 @@ where
 
     if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS {
       // Generate response to client
-      if let Err(e) = self.generate_response_forwarded(&mut res_backend, backend_app) {
-        return Err(HttpError::FailedToGenerateDownstreamResponse(e.to_string()));
-      }
+      self
+        .generate_response_forwarded(&mut res_backend, backend_app)
+        .map_err(|e| HttpError::FailedToGenerateDownstreamResponse(e.to_string()))?;
       return Ok(res_backend);
     }
 
     // Handle StatusCode::SWITCHING_PROTOCOLS in response
     let upgrade_in_response = extract_upgrade(res_backend.headers());
     let should_upgrade = match (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) {
-      (Some(u_req), Some(u_res)) => u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase(),
+      (Some(u_req), Some(u_res)) => u_req.eq_ignore_ascii_case(u_res),
       _ => false,
     };
 
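Note: `eq_ignore_ascii_case` compares byte-by-byte without allocating lowercase copies, unlike the `to_ascii_lowercase() == ...` form it replaces here and in the header utilities below. A minimal sketch:

```rust
// Illustrative: allocation-free case-insensitive match of Upgrade protocol names.
fn same_protocol(a: &str, b: &str) -> bool {
  a.eq_ignore_ascii_case(b)
}
```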
@@ -1,11 +1,11 @@
-use super::{handler_main::HandlerContext, utils_headers::*, utils_request::update_request_line, HttpMessageHandler};
+use super::{HttpMessageHandler, handler_main::HandlerContext, utils_headers::*, utils_request::update_request_line};
 use crate::{
   backend::{BackendApp, UpstreamCandidates},
   constants::RESPONSE_HEADER_SERVER,
   log::*,
 };
-use anyhow::{anyhow, ensure, Result};
-use http::{header, HeaderValue, Request, Response, Uri};
+use anyhow::{Result, anyhow, ensure};
+use http::{HeaderValue, Request, Response, Uri, header};
 use hyper_util::client::legacy::connect::Connect;
 use std::net::SocketAddr;
 
@@ -66,17 +66,19 @@ where
     upstream_candidates: &UpstreamCandidates,
     tls_enabled: bool,
   ) -> Result<HandlerContext> {
-    debug!("Generate request to be forwarded");
+    trace!("Generate request to be forwarded");
 
     // Add te: trailer if contained in original request
     let contains_te_trailers = {
-      if let Some(te) = req.headers().get(header::TE) {
-        te.as_bytes()
-          .split(|v| v == &b',' || v == &b' ')
-          .any(|x| x == "trailers".as_bytes())
-      } else {
-        false
-      }
+      req
+        .headers()
+        .get(header::TE)
+        .map(|te| {
+          te.as_bytes()
+            .split(|v| v == &b',' || v == &b' ')
+            .any(|x| x == "trailers".as_bytes())
+        })
+        .unwrap_or(false)
     };
 
     let original_uri = req.uri().to_string();
@@ -136,11 +138,7 @@ where
     let new_uri = Uri::builder()
       .scheme(upstream_chosen.uri.scheme().unwrap().as_str())
      .authority(upstream_chosen.uri.authority().unwrap().as_str());
-    let org_pq = match req.uri().path_and_query() {
-      Some(pq) => pq.to_string(),
-      None => "/".to_string(),
-    }
-    .into_bytes();
+    let org_pq = req.uri().path_and_query().map(|pq| pq.as_str()).unwrap_or("/").as_bytes();
 
     // replace some parts of path if opt_replace_path is enabled for chosen upstream
     let new_pq = match &upstream_candidates.replace_path {
@@ -155,7 +153,7 @@ where
         new_pq.extend_from_slice(&org_pq[matched_path.len()..]);
         new_pq
       }
-      None => org_pq,
+      None => org_pq.to_vec(),
     };
     *req.uri_mut() = new_uri.path_and_query(new_pq).build()?;
 
@@ -34,11 +34,7 @@ impl<T> From<&http::Request<T>> for HttpMessageLog {
       client_addr: "".to_string(),
       method: req.method().to_string(),
       host: header_mapper(header::HOST),
-      p_and_q: req
-        .uri()
-        .path_and_query()
-        .map_or_else(|| "", |v| v.as_str())
-        .to_string(),
+      p_and_q: req.uri().path_and_query().map_or_else(|| "", |v| v.as_str()).to_string(),
       version: req.version(),
       uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(),
       uri_host: req.uri().host().unwrap_or("").to_string(),
@@ -50,6 +46,33 @@ impl<T> From<&http::Request<T>> for HttpMessageLog {
   }
 }
 
+impl std::fmt::Display for HttpMessageLog {
+  fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+    write!(
+      f,
+      "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"",
+      if !self.host.is_empty() {
+        self.host.as_str()
+      } else {
+        self.uri_host.as_str()
+      },
+      self.client_addr,
+      self.method,
+      self.p_and_q,
+      self.version,
+      self.status,
+      if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() {
+        format!("{}://{}", self.uri_scheme, self.uri_host)
+      } else {
+        "".to_string()
+      },
+      self.ua,
+      self.xff,
+      self.upstream
+    )
+  }
+}
+
 impl HttpMessageLog {
   pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self {
     self.client_addr = client_addr.to_canonical().to_string();
@@ -74,26 +97,8 @@ impl HttpMessageLog {
 
   pub fn output(&self) {
     info!(
-      "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"",
-      if !self.host.is_empty() {
-        self.host.as_str()
-      } else {
-        self.uri_host.as_str()
-      },
-      self.client_addr,
-      self.method,
-      self.p_and_q,
-      self.version,
-      self.status,
-      if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() {
-        format!("{}://{}", self.uri_scheme, self.uri_host)
-      } else {
-        "".to_string()
-      },
-      self.ua,
-      self.xff,
-      self.upstream,
-      // self.tls_server_name
+      name: crate::constants::log_event_names::ACCESS_LOG,
+      "{}", self
     );
   }
 }
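Note: the access log is now emitted as a named tracing event that formats the struct via the new `Display` impl above, so subscribers can route access logs separately by event name. A minimal sketch, assuming the tracing crate (the stand-in type is illustrative):

```rust
use std::fmt;

// Illustrative stand-in for HttpMessageLog.
struct AccessLog(&'static str);

impl fmt::Display for AccessLog {
  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
    write!(f, "{}", self.0)
  }
}

fn emit(log: &AccessLog) {
  // name: overrides the event name, letting a subscriber filter "rpxy::access".
  tracing::info!(name: "rpxy::access", "{}", log);
}
```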
@@ -53,6 +53,7 @@ impl From<HttpError> for StatusCode {
       HttpError::FailedToAddSetCookeInResponse(_) => StatusCode::INTERNAL_SERVER_ERROR,
       HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR,
       HttpError::FailedToUpgrade(_) => StatusCode::INTERNAL_SERVER_ERROR,
+      HttpError::FailedToGetResponseFromBackend(_) => StatusCode::BAD_GATEWAY,
       // HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST,
       // HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY,
       _ => StatusCode::INTERNAL_SERVER_ERROR,
@@ -1,7 +1,7 @@
 use super::http_result::{HttpError, HttpResult};
 use crate::{
   error::*,
-  hyper_ext::body::{empty, ResponseBody},
+  hyper_ext::body::{ResponseBody, empty},
   name_exp::ServerName,
 };
 use http::{Request, Response, StatusCode, Uri};
@@ -3,9 +3,9 @@ use crate::{
   backend::{UpstreamCandidates, UpstreamOption},
   log::*,
 };
-use anyhow::{anyhow, Result};
+use anyhow::{Result, anyhow, ensure};
 use bytes::BufMut;
-use http::{header, HeaderMap, HeaderName, HeaderValue, Uri};
+use http::{HeaderMap, HeaderName, HeaderValue, Uri, header};
 use std::{borrow::Cow, net::SocketAddr};
 
 #[cfg(feature = "sticky-cookie")]
@@ -238,10 +238,9 @@ pub(super) fn add_forwarding_header(
 pub(super) fn remove_connection_header(headers: &mut HeaderMap) {
   if let Some(values) = headers.get(header::CONNECTION) {
     if let Ok(v) = values.clone().to_str() {
-      for m in v.split(',') {
-        if !m.is_empty() {
-          headers.remove(m.trim());
-        }
+      let keys = v.split(',').map(|m| m.trim()).filter(|m| !m.is_empty());
+      for m in keys {
+        headers.remove(m);
       }
     }
   }
@@ -274,13 +273,11 @@ pub(super) fn extract_upgrade(headers: &HeaderMap) -> Option<String> {
       .to_str()
       .unwrap_or("")
      .split(',')
-      .any(|w| w.trim().to_ascii_lowercase() == header::UPGRADE.as_str().to_ascii_lowercase())
+      .any(|w| w.trim().eq_ignore_ascii_case(header::UPGRADE.as_str()))
     {
-      if let Some(u) = headers.get(header::UPGRADE) {
-        if let Ok(m) = u.to_str() {
-          debug!("Upgrade in request header: {}", m);
-          return Some(m.to_owned());
-        }
+      if let Some(Ok(m)) = headers.get(header::UPGRADE).map(|u| u.to_str()) {
+        debug!("Upgrade in request header: {}", m);
+        return Some(m.to_owned());
       }
     }
   }
@@ -2,8 +2,8 @@ use crate::{
   backend::{Upstream, UpstreamCandidates, UpstreamOption},
   log::*,
 };
-use anyhow::{anyhow, ensure, Result};
-use http::{header, uri::Scheme, Request, Version};
+use anyhow::{Result, anyhow, ensure};
+use http::{Request, Version, header, uri::Scheme};
 
 /// Trait defining parser of hostname
 /// Inspect and extract hostname from either the request HOST header or request line
@@ -59,6 +59,18 @@ pub(super) fn update_request_line<B>(
   upstream_chosen: &Upstream,
   upstream_candidates: &UpstreamCandidates,
 ) -> anyhow::Result<()> {
+  // If request is grpc, HTTP/2 is required
+  if req
+    .headers()
+    .get(header::CONTENT_TYPE)
+    .map(|v| v.as_bytes().starts_with(b"application/grpc"))
+    == Some(true)
+  {
+    debug!("Must be http/2 for gRPC request.");
+    *req.version_mut() = Version::HTTP_2;
+    return Ok(());
+  }
+
   // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3
   if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) {
     // Change version to http/1.1 when destination scheme is http
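Note: gRPC is only defined over HTTP/2 (it depends on trailers), so the new block above pins the forwarded request to HTTP/2 whenever the content type starts with `application/grpc`, which also covers subtypes such as `application/grpc+proto`. A minimal sketch of the detection:

```rust
// Illustrative: detect a gRPC request by its content-type prefix.
fn is_grpc(content_type: Option<&[u8]>) -> bool {
  content_type.map(|v| v.starts_with(b"application/grpc")).unwrap_or(false)
}
```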
@@ -14,12 +14,11 @@ use crate::{
   name_exp::ServerName,
 };
 use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder};
-use rustc_hash::FxHashMap as HashMap;
 use rustls::ServerConfig;
 use std::sync::Arc;
 
 /// SNI to ServerConfig map type
-pub type SniServerCryptoMap = HashMap<ServerName, Arc<ServerConfig>>;
+pub type SniServerCryptoMap = std::collections::HashMap<ServerName, Arc<ServerConfig>, ahash::RandomState>;
 
 pub(crate) use proxy_main::Proxy;
 
@@ -33,7 +33,7 @@ where
   <<C as OpenStreams<Bytes>>::BidiStream as BidiStream<Bytes>>::SendStream: Send,
 {
   let mut h3_conn = h3::server::Connection::<_, Bytes>::new(quic_connection).await?;
-  info!(
+  debug!(
     "QUIC/HTTP3 connection established from {:?} {}",
     client_addr,
     <&ServerName as TryInto<String>>::try_into(&tls_server_name).unwrap_or_default()
@@ -49,12 +49,17 @@ where
       }
       Err(e) => {
         warn!("HTTP/3 error on accept incoming connection: {}", e);
-        match e.get_error_level() {
-          h3::error::ErrorLevel::ConnectionError => break,
-          h3::error::ErrorLevel::StreamError => continue,
-        }
+        break;
       }
-      Ok(Some((req, stream))) => {
+      // Ok(Some((req, stream))) => {
+      Ok(Some(req_resolver)) => {
+        let (req, stream) = match req_resolver.resolve_request().await {
+          Ok((req, stream)) => (req, stream),
+          Err(e) => {
+            warn!("HTTP/3 error on resolve request in stream: {}", e);
+            continue;
+          }
+        };
         // We consider the connection count separately from the stream count.
         // Max clients for h1/h2 = max 'stream' for h3.
         let request_count = self.globals.request_count.clone();
@@ -63,7 +68,7 @@ where
           h3_conn.shutdown(0).await?;
           break;
         }
-        debug!("Request incoming: current # {}", request_count.current());
+        trace!("Request incoming: current # {}", request_count.current());
 
         let self_inner = self.clone();
         let tls_server_name_inner = tls_server_name.clone();
@@ -77,7 +82,7 @@ where
             warn!("HTTP/3 error on serve stream: {}", e);
           }
           request_count.decrement();
-          debug!("Request processed: current # {}", request_count.current());
+          trace!("Request processed: current # {}", request_count.current());
         });
       }
     }
@@ -115,7 +120,7 @@ where
     let mut sender = body_sender;
     let mut size = 0usize;
     while let Some(mut body) = recv_stream.recv_data().await? {
-      debug!("HTTP/3 incoming request body: remaining {}", body.remaining());
+      trace!("HTTP/3 incoming request body: remaining {}", body.remaining());
       size += body.remaining();
       if size > max_body_size {
         error!(
@@ -129,9 +134,9 @@ where
     }
 
     // trailers: use inner for work around. (directly get trailer)
-    let trailers = recv_stream.as_mut().recv_trailers().await?;
+    let trailers = futures_util::future::poll_fn(|cx| recv_stream.as_mut().poll_recv_trailers(cx)).await?;
     if trailers.is_some() {
-      debug!("HTTP/3 incoming request trailers");
+      trace!("HTTP/3 incoming request trailers");
       sender.send_trailers(trailers.unwrap()).await?;
     }
     Ok(()) as RpxyResult<()>
@@ -154,13 +159,13 @@ where
 
     match send_stream.send_response(new_res).await {
       Ok(_) => {
-        debug!("HTTP/3 response to connection successful");
+        trace!("HTTP/3 response to connection successful");
         // on-demand body streaming to downstream without expanding the object onto memory.
         loop {
           let frame = match new_body.frame().await {
             Some(frame) => frame,
             None => {
-              debug!("Response body finished");
+              trace!("Response body finished");
               break;
             }
           }
@@ -11,7 +11,7 @@ use crate::{
   message_handler::HttpMessageHandler,
   name_exp::ServerName,
 };
-use futures::{select, FutureExt};
+use futures::{FutureExt, select};
 use http::{Request, Response};
 use hyper::{
   body::Incoming,
@@ -22,6 +22,7 @@ use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::au
 use rpxy_certs::ServerCrypto;
 use std::{net::SocketAddr, sync::Arc, time::Duration};
 use tokio::time::timeout;
+use tokio_util::sync::CancellationToken;
 
 /// Wrapper function to handle request for HTTP/1.1 and HTTP/2
 /// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler
@@ -79,7 +80,7 @@ where
         request_count.decrement();
         return;
       }
-      debug!("Request incoming: current # {}", request_count.current());
+      trace!("Request incoming: current # {}", request_count.current());
 
       let server_clone = self.connection_builder.clone();
       let message_handler_clone = self.message_handler.clone();
@@ -109,7 +110,7 @@ where
       }
 
       request_count.decrement();
-      debug!("Request processed: current # {}", request_count.current());
+      trace!("Request processed: current # {}", request_count.current());
     });
   }
 
@@ -129,30 +130,56 @@ where
   }
 
   /// Start with TLS (HTTPS)
-  pub(super) async fn start_with_tls(&self) -> RpxyResult<()> {
+  pub(super) async fn start_with_tls(&self, cancel_token: CancellationToken) -> RpxyResult<()> {
+    // By default, TLS listener is spawned
+    let join_handle_tls = self.globals.runtime_handle.spawn({
+      let self_clone = self.clone();
+      let cancel_token = cancel_token.clone();
+      async move {
+        select! {
+          _ = self_clone.tls_listener_service().fuse() => {
+            error!("TCP proxy service for TLS exited");
+            cancel_token.cancel();
+          },
+          _ = cancel_token.cancelled().fuse() => {
+            debug!("Cancel token is called for TLS listener");
+          }
+        }
+      }
+    });
+
     #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))]
     {
-      self.tls_listener_service().await?;
-      error!("TCP proxy service for TLS exited");
+      let _ = join_handle_tls.await;
       Ok(())
     }
 
     #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
     {
-      if self.globals.proxy_config.http3 {
-        select! {
-          _ = self.tls_listener_service().fuse() => {
-            error!("TCP proxy service for TLS exited");
-          },
-          _ = self.h3_listener_service().fuse() => {
-            error!("UDP proxy service for QUIC exited");
-          }
-        };
-        Ok(())
-      } else {
-        self.tls_listener_service().await?;
-        error!("TCP proxy service for TLS exited");
-        Ok(())
+      // If HTTP/3 is not enabled, wait for TLS listener to finish
+      if !self.globals.proxy_config.http3 {
+        let _ = join_handle_tls.await;
+        return Ok(());
       }
+
+      // If HTTP/3 is enabled, spawn a task to handle HTTP/3 connections
+      let join_handle_h3 = self.globals.runtime_handle.spawn({
+        let self_clone = self.clone();
+        async move {
+          select! {
+            _ = self_clone.h3_listener_service().fuse() => {
+              error!("UDP proxy service for QUIC exited");
+              cancel_token.cancel();
+            },
+            _ = cancel_token.cancelled().fuse() => {
+              debug!("Cancel token is called for QUIC listener");
+            }
+          }
+        }
+      });
+      let _ = futures::future::join(join_handle_tls, join_handle_h3).await;
+
       Ok(())
     }
   }
 
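Note: the restructure above spawns each listener with a shared `CancellationToken`: when one service exits on its own, it cancels the token, and the sibling task observes `cancelled()` and winds down. A minimal sketch, assuming tokio and tokio_util:

```rust
use tokio_util::sync::CancellationToken;

// Illustrative listener task mirroring the select! pattern above.
async fn run_listener(cancel: CancellationToken) {
  tokio::select! {
    _ = async { /* listener service loop */ } => {
      // the service ended on its own: propagate shutdown to sibling tasks
      cancel.cancel();
    }
    _ = cancel.cancelled() => {
      // a sibling task requested shutdown
    }
  }
}
```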
@@ -294,7 +321,7 @@ where
           let map = server_config.individual_config_map.clone().iter().map(|(k,v)| {
             let server_name = ServerName::from(k.as_slice());
             (server_name, v.clone())
-          }).collect::<rustc_hash::FxHashMap<_,_>>();
+          }).collect::<std::collections::HashMap<_,_,ahash::RandomState>>();
           server_crypto_map = Some(Arc::new(map));
         }
       }
@@ -303,10 +330,10 @@ where
   }
 
   /// Entrypoint for HTTP/1.1, 2 and 3 servers
-  pub async fn start(&self) -> RpxyResult<()> {
+  pub async fn start(&self, cancel_token: CancellationToken) -> RpxyResult<()> {
     let proxy_service = async {
       if self.tls_enabled {
-        self.start_with_tls().await
+        self.start_with_tls(cancel_token).await
       } else {
         self.start_without_tls().await
       }
@@ -2,8 +2,8 @@ use super::{proxy_main::Proxy, socket::bind_udp_socket};
 use crate::{error::*, log::*, name_exp::ByteName};
 use hyper_util::client::legacy::connect::Connect;
 use quinn::{
-  crypto::rustls::{HandshakeData, QuicServerConfig},
   Endpoint, TransportConfig,
+  crypto::rustls::{HandshakeData, QuicServerConfig},
 };
 use rpxy_certs::ServerCrypto;
 use rustls::ServerConfig;
@@ -82,7 +82,7 @@ where
       let client_addr = incoming.remote_address();
       let quic_connection = match incoming.await {
         Ok(new_conn) => {
-          info!("New connection established");
+          trace!("New connection established");
           h3_quinn::Connection::new(new_conn)
         },
         Err(e) => {
@@ -110,7 +110,7 @@ where
 
     // quic event loop. this immediately cancels when crypto is updated by tokio::select!
     while let Some(new_conn) = server.accept().await {
-      debug!("New QUIC connection established");
+      trace!("New QUIC connection established");
       let Ok(Some(new_server_name)) = new_conn.server_name() else {
         warn!("HTTP/3 no SNI is given");
         continue;
@@ -16,10 +16,12 @@ pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> RpxyResult<TcpSocket
   }?;
   tcp_socket.set_reuseaddr(true)?;
   tcp_socket.set_reuseport(true)?;
-  if let Err(e) = tcp_socket.bind(*listening_on) {
+
+  tcp_socket.bind(*listening_on).map_err(|e| {
     error!("Failed to bind TCP socket: {}", e);
-    return Err(RpxyError::Io(e));
-  };
+    RpxyError::Io(e)
+  })?;
+
   Ok(tcp_socket)
 }
 
@@ -36,11 +38,10 @@ pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> RpxyResult<UdpSocket
   socket.set_reuse_port(true)?;
   socket.set_nonblocking(true)?; // This was made true inside quinn. so this line isn't necessary here. but just in case.
 
-  if let Err(e) = socket.bind(&(*listening_on).into()) {
+  socket.bind(&(*listening_on).into()).map_err(|e| {
     error!("Failed to bind UDP socket: {}", e);
-    return Err(RpxyError::Io(e));
-  };
-  let udp_socket: UdpSocket = socket.into();
+    RpxyError::Io(e)
+  })?;
 
-  Ok(udp_socket)
+  Ok(socket.into())
 }