Merge branch 'tmp/sticky-cookie' into feat/sticky-cookie-feature

This commit is contained in:
Jun Kurihara 2025-06-03 14:50:00 +09:00 committed by GitHub
commit d8cadf06af
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
80 changed files with 4870 additions and 867 deletions

View file

@ -36,12 +36,12 @@ post-quantum = [
]
[dependencies]
rand = "0.8.5"
rustc-hash = "2.0.0"
bytes = "1.8.0"
rand = "0.9.1"
ahash = "0.8.12"
bytes = "1.10.1"
derive_builder = "0.20.2"
futures = { version = "0.3.31", features = ["alloc", "async-await"] }
tokio = { version = "1.41.0", default-features = false, features = [
tokio = { version = "1.45.1", default-features = false, features = [
"net",
"rt-multi-thread",
"time",
@ -49,19 +49,19 @@ tokio = { version = "1.41.0", default-features = false, features = [
"macros",
"fs",
] }
tokio-util = { version = "0.7.12", default-features = false }
pin-project-lite = "0.2.15"
async-trait = "0.1.83"
tokio-util = { version = "0.7.15", default-features = false }
pin-project-lite = "0.2.16"
async-trait = "0.1.88"
# Error handling
anyhow = "1.0.91"
thiserror = "1.0.66"
anyhow = "1.0.98"
thiserror = "2.0.12"
# http for both server and client
http = "1.1.0"
http-body-util = "0.1.2"
hyper = { version = "1.5.0", default-features = false }
hyper-util = { version = "0.1.10", features = ["full"] }
http = "1.3.1"
http-body-util = "0.1.3"
hyper = { version = "1.6.0", default-features = false }
hyper-util = { version = "0.1.13", features = ["full"] }
futures-util = { version = "0.3.31", default-features = false }
futures-channel = { version = "0.3.31", default-features = false }
@ -70,7 +70,7 @@ hyper-tls = { version = "0.6.0", features = [
"alpn",
"vendored",
], optional = true }
hyper-rustls = { version = "0.27.3", default-features = false, features = [
hyper-rustls = { version = "0.27.6", default-features = false, features = [
"aws-lc-rs",
"http1",
"http2",
@ -79,40 +79,40 @@ hyper-rustls = { version = "0.27.3", default-features = false, features = [
# tls and cert management for server
rpxy-certs = { path = "../rpxy-certs/", default-features = false }
hot_reload = "0.1.6"
rustls = { version = "0.23.16", default-features = false }
rustls-post-quantum = { version = "0.1.0", optional = true }
tokio-rustls = { version = "0.26.0", features = ["early-data"] }
hot_reload = "0.1.9"
rustls = { version = "0.23.27", default-features = false }
rustls-post-quantum = { version = "0.2.2", optional = true }
tokio-rustls = { version = "0.26.2", features = ["early-data"] }
# acme
rpxy-acme = { path = "../rpxy-acme/", default-features = false, optional = true }
# logging
tracing = { version = "0.1.40" }
tracing = { version = "0.1.41" }
# http/3
quinn = { version = "0.11.5", optional = true }
h3 = { version = "0.0.6", features = ["tracing"], optional = true }
h3-quinn = { version = "0.0.7", optional = true }
s2n-quic = { version = "1.48.0", path = "../submodules/s2n-quic/quic/s2n-quic/", default-features = false, features = [
quinn = { version = "0.11.8", optional = true }
h3 = { version = "0.0.8", features = ["tracing"], optional = true }
h3-quinn = { version = "0.0.10", optional = true }
s2n-quic = { version = "1.59.0", path = "../submodules/s2n-quic/quic/s2n-quic/", default-features = false, features = [
"provider-tls-rustls",
], optional = true }
s2n-quic-core = { version = "0.48.0", path = "../submodules/s2n-quic/quic/s2n-quic-core", default-features = false, optional = true }
s2n-quic-rustls = { version = "0.48.0", path = "../submodules/s2n-quic/quic/s2n-quic-rustls", optional = true }
s2n-quic-core = { version = "0.59.0", path = "../submodules/s2n-quic/quic/s2n-quic-core", default-features = false, optional = true }
s2n-quic-rustls = { version = "0.59.0", path = "../submodules/s2n-quic/quic/s2n-quic-rustls", optional = true }
s2n-quic-h3 = { path = "../submodules/s2n-quic/quic/s2n-quic-h3/", features = [
"tracing",
], optional = true }
##########
# for UDP socket wit SO_REUSEADDR when h3 with quinn
socket2 = { version = "0.5.7", features = ["all"], optional = true }
socket2 = { version = "0.5.10", features = ["all"], optional = true }
# cache
http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics", default-features = false, optional = true }
lru = { version = "0.12.5", optional = true }
sha2 = { version = "0.10.8", default-features = false, optional = true }
lru = { version = "0.14.0", optional = true }
sha2 = { version = "0.10.9", default-features = false, optional = true }
# cookie handling for sticky cookie
chrono = { version = "0.4.38", default-features = false, features = [
chrono = { version = "0.4.41", default-features = false, features = [
"unstable-locales",
"alloc",
"clock",

View file

@ -1,11 +1,11 @@
use crate::{
AppConfig, AppConfigList,
error::*,
log::*,
name_exp::{ByteName, ServerName},
AppConfig, AppConfigList,
};
use ahash::HashMap;
use derive_builder::Builder;
use rustc_hash::FxHashMap as HashMap;
use std::borrow::Cow;
use super::upstream::PathManager;
@ -26,6 +26,7 @@ pub struct BackendApp {
pub https_redirection: Option<bool>,
/// tls settings: mutual TLS is enabled
#[builder(default)]
#[allow(unused)]
pub mutual_tls: Option<bool>,
}
impl<'a> BackendAppBuilder {

View file

@ -7,8 +7,8 @@ pub use super::{
use derive_builder::Builder;
use rand::Rng;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
atomic::{AtomicUsize, Ordering},
};
/// Constants to specify a load balance option
@ -80,8 +80,8 @@ impl LoadBalanceRandomBuilder {
impl LoadBalanceWithPointer for LoadBalanceRandom {
/// Returns the random index within the range
fn get_ptr(&self, _info: Option<&LoadBalanceContext>) -> PointerToUpstream {
let mut rng = rand::thread_rng();
let ptr = rng.gen_range(0..self.num_upstreams);
let mut rng = rand::rng();
let ptr = rng.random_range(0..self.num_upstreams);
PointerToUpstream { ptr, context: None }
}
}

View file

@ -1,16 +1,16 @@
use super::{
Upstream,
load_balance_main::{LoadBalanceContext, LoadBalanceWithPointer, PointerToUpstream},
sticky_cookie::StickyCookieConfig,
Upstream,
};
use crate::{constants::STICKY_COOKIE_NAME, log::*};
use ahash::HashMap;
use derive_builder::Builder;
use rustc_hash::FxHashMap as HashMap;
use std::{
borrow::Cow,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
atomic::{AtomicUsize, Ordering},
},
};
@ -112,13 +112,16 @@ impl LoadBalanceWithPointer for LoadBalanceSticky {
}
Some(context) => {
let server_id = &context.sticky_cookie.value.value;
if let Some(server_index) = self.get_server_index_from_id(server_id) {
debug!("Valid sticky cookie: id={}, index={}", server_id, server_index);
server_index
} else {
debug!("Invalid sticky cookie: id={}", server_id);
self.simple_increment_ptr()
}
self.get_server_index_from_id(server_id).map_or_else(
|| {
debug!("Invalid sticky cookie: id={}", server_id);
self.simple_increment_ptr()
},
|server_index| {
debug!("Valid sticky cookie: id={}, index={}", server_id, server_index);
server_index
},
)
}
};

View file

@ -9,7 +9,7 @@ use super::upstream::Upstream;
use thiserror::Error;
pub use load_balance_main::{
load_balance_options, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder,
LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder, load_balance_options,
};
#[cfg(feature = "sticky-cookie")]
pub use load_balance_sticky::LoadBalanceStickyBuilder;

View file

@ -91,12 +91,7 @@ impl<'a> StickyCookieBuilder {
self
}
/// Set the meta information of sticky cookie
pub fn info(
&mut self,
domain: impl Into<Cow<'a, str>>,
path: impl Into<Cow<'a, str>>,
duration_secs: i64,
) -> &mut Self {
pub fn info(&mut self, domain: impl Into<Cow<'a, str>>, path: impl Into<Cow<'a, str>>, duration_secs: i64) -> &mut Self {
let info = StickyCookieInfoBuilder::default()
.domain(domain)
.path(path)

View file

@ -1,7 +1,7 @@
#[cfg(feature = "sticky-cookie")]
use super::load_balance::LoadBalanceStickyBuilder;
use super::load_balance::{
load_balance_options as lb_opts, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder,
LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder, load_balance_options as lb_opts,
};
// use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption};
use super::upstream_opts::UpstreamOption;
@ -11,10 +11,10 @@ use crate::{
log::*,
name_exp::{ByteName, PathName},
};
use ahash::{HashMap, HashSet};
#[cfg(feature = "sticky-cookie")]
use base64::{engine::general_purpose, Engine as _};
use base64::{Engine as _, engine::general_purpose};
use derive_builder::Builder;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
#[cfg(feature = "sticky-cookie")]
use sha2::{Digest, Sha256};
use std::borrow::Cow;
@ -72,27 +72,22 @@ impl PathManager {
.inner
.iter()
.filter(|(route_bytes, _)| {
match path_name.starts_with(route_bytes) {
true => {
route_bytes.len() == 1 // route = '/', i.e., default
|| match path_name.get(route_bytes.len()) {
None => true, // exact case
Some(p) => p == &b'/', // sub-path case
}
}
_ => false,
path_name.starts_with(route_bytes) && {
route_bytes.len() == 1 // route = '/', i.e., default
|| path_name.get(route_bytes.len()).map_or(
true, // exact case
|p| p == &b'/'
) // sub-path case
}
})
.max_by_key(|(route_bytes, _)| route_bytes.len());
if let Some((path, u)) = matched_upstream {
matched_upstream.map(|(path, u)| {
debug!(
"Found upstream: {:?}",
path.try_into().unwrap_or_else(|_| "<none>".to_string())
);
Some(u)
} else {
None
}
u
})
}
}
@ -211,14 +206,15 @@ impl UpstreamCandidatesBuilder {
}
/// Set the activated upstream options defined in [[UpstreamOption]]
pub fn options(&mut self, v: &Option<Vec<String>>) -> &mut Self {
let opts = if let Some(opts) = v {
opts
.iter()
.filter_map(|str| UpstreamOption::try_from(str.as_str()).ok())
.collect::<HashSet<UpstreamOption>>()
} else {
Default::default()
};
let opts = v.as_ref().map_or_else(
|| Default::default(),
|opts| {
opts
.iter()
.filter_map(|str| UpstreamOption::try_from(str.as_str()).ok())
.collect::<HashSet<UpstreamOption>>()
},
);
self.options = Some(opts);
self
}

View file

@ -32,3 +32,9 @@ pub const MAX_CACHE_EACH_SIZE: usize = 65_535;
pub const MAX_CACHE_EACH_SIZE_ON_MEMORY: usize = 4_096;
// TODO: max cache size in total
/// Logging event name TODO: Other separated logs?
pub mod log_event_names {
/// access log
pub const ACCESS_LOG: &str = "rpxy::access";
}

View file

@ -1,6 +1,6 @@
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
atomic::{AtomicUsize, Ordering},
};
#[derive(Debug, Clone, Default)]

View file

@ -37,8 +37,11 @@ pub enum RpxyError {
// http/3 errors
#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
#[error("H3 error: {0}")]
H3Error(#[from] h3::Error),
#[error("h3 connection error: {0}")]
H3ConnectionError(#[from] h3::error::ConnectionError),
#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
#[error("h3 connection error: {0}")]
H3StreamError(#[from] h3::error::StreamError),
// #[cfg(feature = "http3-s2n")]
// #[error("H3 error: {0}")]
// H3Error(#[from] s2n_quic_h3::h3::Error),

View file

@ -1,10 +1,10 @@
use super::cache_error::*;
use crate::{
globals::Globals,
hyper_ext::body::{full, BoxBody, ResponseBody, UnboundedStreamBody},
hyper_ext::body::{BoxBody, ResponseBody, UnboundedStreamBody, full},
log::*,
};
use base64::{engine::general_purpose, Engine as _};
use base64::{Engine as _, engine::general_purpose};
use bytes::{Buf, Bytes, BytesMut};
use futures::channel::mpsc;
use http::{Request, Response, Uri};
@ -16,8 +16,8 @@ use sha2::{Digest, Sha256};
use std::{
path::{Path, PathBuf},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
atomic::{AtomicUsize, Ordering},
},
time::SystemTime,
};
@ -52,23 +52,30 @@ impl RpxyCache {
if !globals.proxy_config.cache_enabled {
return None;
}
let cache_dir = globals.proxy_config.cache_dir.as_ref().unwrap();
let cache_dir = match globals.proxy_config.cache_dir.as_ref() {
Some(dir) => dir,
None => {
warn!("Cache directory not set in proxy config");
return None;
}
};
let file_store = FileStore::new(&globals.runtime_handle).await;
let inner = LruCacheManager::new(globals.proxy_config.cache_max_entry);
let max_each_size = globals.proxy_config.cache_max_each_size;
let mut max_each_size_on_memory = globals.proxy_config.cache_max_each_size_on_memory;
if max_each_size < max_each_size_on_memory {
warn!(
"Maximum size of on memory cache per entry must be smaller than or equal to the maximum of each file cache"
);
warn!("Maximum size of on-memory cache per entry must be smaller than or equal to the maximum of each file cache");
max_each_size_on_memory = max_each_size;
}
if let Err(e) = fs::remove_dir_all(cache_dir).await {
warn!("Failed to clean up the cache dir: {e}");
};
fs::create_dir_all(&cache_dir).await.unwrap();
}
if let Err(e) = fs::create_dir_all(&cache_dir).await {
error!("Failed to create cache dir: {e}");
return None;
}
Some(Self {
file_store,
@ -89,12 +96,7 @@ impl RpxyCache {
}
/// Put response into the cache
pub(crate) async fn put(
&self,
uri: &hyper::Uri,
mut body: Incoming,
policy: &CachePolicy,
) -> CacheResult<UnboundedStreamBody> {
pub(crate) async fn put(&self, uri: &hyper::Uri, mut body: Incoming, policy: &CachePolicy) -> CacheResult<UnboundedStreamBody> {
let cache_manager = self.inner.clone();
let mut file_store = self.file_store.clone();
let uri = uri.clone();
@ -155,7 +157,7 @@ impl RpxyCache {
let mut hasher = Sha256::new();
hasher.update(buf.as_ref());
let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref());
debug!("Cached data: {} bytes, hash = {:?}", size, hash_bytes);
trace!("Cached data: {} bytes, hash = {:?}", size, hash_bytes);
// Create cache object
let cache_key = derive_cache_key_from_uri(&uri);
@ -188,16 +190,11 @@ impl RpxyCache {
/// Get cached response
pub(crate) async fn get<R>(&self, req: &Request<R>) -> Option<Response<ResponseBody>> {
debug!(
"Current cache status: (total, on-memory, file) = {:?}",
self.count().await
);
trace!("Current cache status: (total, on-memory, file) = {:?}", self.count().await);
let cache_key = derive_cache_key_from_uri(req.uri());
// First check cache chance
let Ok(Some(cached_object)) = self.inner.get(&cache_key) else {
return None;
};
let cached_object = self.inner.get(&cache_key).ok()??;
// Secondly check the cache freshness as an HTTP message
let now = SystemTime::now();
@ -268,25 +265,20 @@ impl FileStore {
let inner = self.inner.read().await;
inner.cnt
}
/// Create a temporary file cache
/// Create a temporary file cache, returns error if file cannot be created or written
async fn create(&mut self, cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> {
let mut inner = self.inner.write().await;
inner.create(cache_object, body_bytes).await
}
/// Evict a temporary file cache
/// Evict a temporary file cache, logs warning if removal fails
async fn evict(&self, path: impl AsRef<Path>) {
// Acquire the write lock
let mut inner = self.inner.write().await;
if let Err(e) = inner.remove(path).await {
warn!("Eviction failed during file object removal: {:?}", e);
};
}
}
/// Read a temporary file cache
async fn read(
&self,
path: impl AsRef<Path> + Send + Sync + 'static,
hash: &Bytes,
) -> CacheResult<UnboundedStreamBody> {
/// Read a temporary file cache, returns error if file cannot be opened or hash mismatches
async fn read(&self, path: impl AsRef<Path> + Send + Sync + 'static, hash: &Bytes) -> CacheResult<UnboundedStreamBody> {
let inner = self.inner.read().await;
inner.read(path, hash).await
}
@ -321,26 +313,22 @@ impl FileStoreInner {
return Err(CacheError::InvalidCacheTarget);
}
};
let Ok(mut file) = File::create(&cache_filepath).await else {
return Err(CacheError::FailedToCreateFileCache);
};
let mut file = File::create(&cache_filepath)
.await
.map_err(|_| CacheError::FailedToCreateFileCache)?;
let mut bytes_clone = body_bytes.clone();
while bytes_clone.has_remaining() {
if let Err(e) = file.write_buf(&mut bytes_clone).await {
file.write_buf(&mut bytes_clone).await.map_err(|e| {
error!("Failed to write file cache: {e}");
return Err(CacheError::FailedToWriteFileCache);
};
CacheError::FailedToWriteFileCache
})?;
}
self.cnt += 1;
Ok(())
}
/// Retrieve a stored temporary file cache
async fn read(
&self,
path: impl AsRef<Path> + Send + Sync + 'static,
hash: &Bytes,
) -> CacheResult<UnboundedStreamBody> {
async fn read(&self, path: impl AsRef<Path> + Send + Sync + 'static, hash: &Bytes) -> CacheResult<UnboundedStreamBody> {
let Ok(mut file) = File::open(&path).await else {
warn!("Cache file object cannot be opened");
return Err(CacheError::FailedToOpenCacheFile);
@ -455,11 +443,14 @@ impl LruCacheManager {
self.cnt.load(Ordering::Relaxed)
}
/// Evict an entry
/// Evict an entry from the LRU cache, logs error if mutex cannot be acquired
fn evict(&self, cache_key: &str) -> Option<(String, CacheObject)> {
let Ok(mut lock) = self.inner.lock() else {
error!("Mutex can't be locked to evict a cache entry");
return None;
let mut lock = match self.inner.lock() {
Ok(lock) => lock,
Err(_) => {
error!("Mutex can't be locked to evict a cache entry");
return None;
}
};
let res = lock.pop_entry(cache_key);
// This may be inconsistent with the actual number of entries
@ -467,24 +458,24 @@ impl LruCacheManager {
res
}
/// Push an entry
/// Push an entry into the LRU cache, returns error if mutex cannot be acquired
fn push(&self, cache_key: &str, cache_object: &CacheObject) -> CacheResult<Option<(String, CacheObject)>> {
let Ok(mut lock) = self.inner.lock() else {
let mut lock = self.inner.lock().map_err(|_| {
error!("Failed to acquire mutex lock for writing cache entry");
return Err(CacheError::FailedToAcquiredMutexLockForCache);
};
CacheError::FailedToAcquiredMutexLockForCache
})?;
let res = Ok(lock.push(cache_key.to_string(), cache_object.clone()));
// This may be inconsistent with the actual number of entries
self.cnt.store(lock.len(), Ordering::Relaxed);
res
}
/// Get an entry
/// Get an entry from the LRU cache, returns error if mutex cannot be acquired
fn get(&self, cache_key: &str) -> CacheResult<Option<CacheObject>> {
let Ok(mut lock) = self.inner.lock() else {
let mut lock = self.inner.lock().map_err(|_| {
error!("Mutex can't be locked for checking cache entry");
return Err(CacheError::FailedToAcquiredMutexLockForCheck);
};
CacheError::FailedToAcquiredMutexLockForCheck
})?;
let Some(cached_object) = lock.get(cache_key) else {
return Ok(None);
};

View file

@ -2,4 +2,4 @@ mod cache_error;
mod cache_main;
pub use cache_error::CacheError;
pub(crate) use cache_main::{get_policy_if_cacheable, RpxyCache};
pub(crate) use cache_main::{RpxyCache, get_policy_if_cacheable};

View file

@ -9,13 +9,13 @@ use async_trait::async_trait;
use http::{Request, Response, Version};
use hyper::body::{Body, Incoming};
use hyper_util::client::legacy::{
connect::{Connect, HttpConnector},
Client,
connect::{Connect, HttpConnector},
};
use std::sync::Arc;
#[cfg(feature = "cache")]
use super::cache::{get_policy_if_cacheable, RpxyCache};
use super::cache::{RpxyCache, get_policy_if_cacheable};
#[async_trait]
/// Definition of the forwarder that simply forward requests from downstream client to upstream app servers.
@ -126,9 +126,9 @@ where
warn!(
"
--------------------------------------------------------------------------------------------------
Request forwarder is working without TLS support!!!
We recommend to use this just for testing.
Please enable native-tls-backend or rustls-backend feature to enable TLS support.
Request forwarder is working without TLS support!
This mode is intended for testing only.
Enable 'native-tls-backend' or 'rustls-backend' feature for TLS support.
--------------------------------------------------------------------------------------------------"
);
let executor = LocalExecutor::new(_globals.runtime_handle.clone());
@ -159,7 +159,7 @@ where
/// Build forwarder
pub async fn try_new(_globals: &Arc<Globals>) -> RpxyResult<Self> {
// build hyper client with hyper-tls
info!("Native TLS support is enabled for the connection to backend applications");
info!("Native TLS support enabled for backend connections (native-tls)");
let executor = LocalExecutor::new(_globals.runtime_handle.clone());
let try_build_connector = |alpns: &[&str]| {
@ -209,14 +209,14 @@ where
#[cfg(feature = "webpki-roots")]
let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots();
#[cfg(feature = "webpki-roots")]
info!("Mozilla WebPKI root certs with rustls is used for the connection to backend applications");
info!("Rustls backend: Mozilla WebPKI root certs used for backend connections");
#[cfg(not(feature = "webpki-roots"))]
let builder = hyper_rustls::HttpsConnectorBuilder::new().with_platform_verifier();
#[cfg(not(feature = "webpki-roots"))]
let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_platform_verifier();
#[cfg(not(feature = "webpki-roots"))]
info!("Platform verifier with rustls is used for the connection to backend applications");
info!("Rustls backend: Platform verifier used for backend connections");
let mut http = HttpConnector::new();
http.enforce_http(false);
@ -226,7 +226,9 @@ where
let connector = builder.https_or_http().enable_all_versions().wrap_connector(http.clone());
let connector_h2 = builder_h2.https_or_http().enable_http2().wrap_connector(http);
let inner = Client::builder(LocalExecutor::new(_globals.runtime_handle.clone())).build::<_, B1>(connector);
let inner_h2 = Client::builder(LocalExecutor::new(_globals.runtime_handle.clone())).build::<_, B1>(connector_h2);
let inner_h2 = Client::builder(LocalExecutor::new(_globals.runtime_handle.clone()))
.http2_only(true)
.build::<_, B1>(connector_h2);
Ok(Self {
inner,

View file

@ -2,7 +2,6 @@ use crate::{constants::*, count::RequestCount};
use hot_reload::ReloaderReceiver;
use rpxy_certs::ServerCryptoBase;
use std::{net::SocketAddr, time::Duration};
use tokio_util::sync::CancellationToken;
/// Global object containing proxy configurations and shared object like counters.
/// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks.
@ -13,14 +12,12 @@ pub struct Globals {
pub request_count: RequestCount,
/// Shared context - Async task runtime handler
pub runtime_handle: tokio::runtime::Handle,
/// Shared context - Notify object to stop async tasks
pub cancel_token: Option<CancellationToken>,
/// Shared context - Certificate reloader service receiver // TODO: newer one
pub cert_reloader_rx: Option<ReloaderReceiver<ServerCryptoBase>>,
#[cfg(feature = "acme")]
/// ServerConfig used for only ACME challenge for ACME domains
pub server_configs_acme_challenge: std::sync::Arc<rustc_hash::FxHashMap<String, std::sync::Arc<rustls::ServerConfig>>>,
pub server_configs_acme_challenge: std::sync::Arc<ahash::HashMap<String, std::sync::Arc<rustls::ServerConfig>>>,
}
/// Configuration parameters for proxy transport and request handlers

View file

@ -1,7 +1,7 @@
use super::watch;
use crate::error::*;
use futures_channel::{mpsc, oneshot};
use futures_util::{stream::FusedStream, Future, Stream};
use futures_util::{Future, Stream, stream::FusedStream};
use http::HeaderMap;
use hyper::body::{Body, Bytes, Frame, SizeHint};
use std::{

View file

@ -1,7 +1,7 @@
use super::body::IncomingLike;
use crate::error::RpxyError;
use futures::channel::mpsc::UnboundedReceiver;
use http_body_util::{combinators, BodyExt, Empty, Full, StreamBody};
use http_body_util::{BodyExt, Empty, Full, StreamBody, combinators};
use hyper::body::{Body, Bytes, Frame, Incoming};
use std::pin::Pin;

View file

@ -12,5 +12,5 @@ pub(crate) mod rt {
#[allow(unused)]
pub(crate) mod body {
pub(crate) use super::body_incoming_like::IncomingLike;
pub(crate) use super::body_type::{empty, full, BoxBody, RequestBody, ResponseBody, UnboundedStreamBody};
pub(crate) use super::body_type::{BoxBody, RequestBody, ResponseBody, UnboundedStreamBody, empty, full};
}

View file

@ -7,8 +7,8 @@
use futures_util::task::AtomicWaker;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
atomic::{AtomicUsize, Ordering},
};
use std::task;

View file

@ -27,6 +27,7 @@ use std::sync::Arc;
use tokio_util::sync::CancellationToken;
/* ------------------------------------------------ */
pub use crate::constants::log_event_names;
pub use crate::globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri};
pub mod reexports {
pub use hyper::Uri;
@ -43,12 +44,10 @@ pub struct RpxyOptions {
pub cert_rx: Option<ReloaderReceiver<ServerCryptoBase>>, // TODO:
/// Async task runtime handler
pub runtime_handle: tokio::runtime::Handle,
/// Notify object to stop async tasks
pub cancel_token: Option<CancellationToken>,
#[cfg(feature = "acme")]
/// ServerConfig used for only ACME challenge for ACME domains
pub server_configs_acme_challenge: Arc<rustc_hash::FxHashMap<String, Arc<rustls::ServerConfig>>>,
pub server_configs_acme_challenge: Arc<ahash::HashMap<String, Arc<rustls::ServerConfig>>>,
}
/// Entrypoint that creates and spawns tasks of reverse proxy services
@ -58,10 +57,10 @@ pub async fn entrypoint(
app_config_list,
cert_rx, // TODO:
runtime_handle,
cancel_token,
#[cfg(feature = "acme")]
server_configs_acme_challenge,
}: &RpxyOptions,
cancel_token: CancellationToken,
) -> RpxyResult<()> {
#[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))]
warn!("Both \"http3-quinn\" and \"http3-s2n\" features are enabled. \"http3-quinn\" will be used");
@ -117,7 +116,6 @@ pub async fn entrypoint(
proxy_config: proxy_config.clone(),
request_count: Default::default(),
runtime_handle: runtime_handle.clone(),
cancel_token: cancel_token.clone(),
cert_reloader_rx: cert_rx.clone(),
#[cfg(feature = "acme")]
@ -153,25 +151,21 @@ pub async fn entrypoint(
message_handler: message_handler.clone(),
};
let cancel_token = globals.cancel_token.as_ref().map(|t| t.child_token());
let parent_cancel_token_clone = globals.cancel_token.clone();
let cancel_token = cancel_token.clone();
globals.runtime_handle.spawn(async move {
info!("rpxy proxy service for {listening_on} started");
if let Some(cancel_token) = cancel_token {
tokio::select! {
_ = cancel_token.cancelled() => {
debug!("rpxy proxy service for {listening_on} terminated");
Ok(())
},
proxy_res = proxy.start() => {
info!("rpxy proxy service for {listening_on} exited");
// cancel other proxy tasks
parent_cancel_token_clone.unwrap().cancel();
proxy_res
}
tokio::select! {
_ = cancel_token.cancelled() => {
debug!("rpxy proxy service for {listening_on} terminated");
Ok(())
},
proxy_res = proxy.start(cancel_token.child_token()) => {
info!("rpxy proxy service for {listening_on} exited");
// cancel other proxy tasks
cancel_token.cancel();
proxy_res
}
} else {
proxy.start().await
}
})
});
@ -186,9 +180,5 @@ pub async fn entrypoint(
}
});
// returns the first error as the representative error
if let Some(e) = errs.next() {
return Err(e);
}
Ok(())
errs.next().map_or(Ok(()), |e| Err(e))
}

View file

@ -1 +1 @@
pub use tracing::{debug, error, info, warn};
pub use tracing::{debug, error, info, trace, warn};

View file

@ -44,10 +44,7 @@ mod tests {
}
#[test]
fn ipv6_to_canonical() {
let socket = SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0xdead, 0xbeef)),
8080,
);
let socket = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0xdead, 0xbeef)), 8080);
assert_eq!(socket.to_canonical(), socket);
}
#[test]

View file

@ -71,7 +71,7 @@ where
Ok(v)
}
Err(e) => {
error!("{e}");
error!("{e}: {log_data}");
let code = StatusCode::from(e);
log_data.status_code(&code).output();
synthetic_error_response(code)
@ -107,9 +107,11 @@ where
let backend_app = match self.app_manager.apps.get(&server_name) {
Some(backend_app) => backend_app,
None => {
let Some(default_server_name) = &self.app_manager.default_server_name else {
return Err(HttpError::NoMatchingBackendApp);
};
let default_server_name = self
.app_manager
.default_server_name
.as_ref()
.ok_or(HttpError::NoMatchingBackendApp)?;
debug!("Serving by default app");
self.app_manager.apps.get(default_server_name).unwrap()
}
@ -131,9 +133,7 @@ where
// Find reverse proxy for given path and choose one of upstream host
// Longest prefix match
let path = req.uri().path();
let Some(upstream_candidates) = backend_app.path_manager.get(path) else {
return Err(HttpError::NoUpstreamCandidates);
};
let upstream_candidates = backend_app.path_manager.get(path).ok_or(HttpError::NoUpstreamCandidates)?;
// Upgrade in request header
let upgrade_in_request = extract_upgrade(req.headers());
@ -147,19 +147,17 @@ where
let req_on_upgrade = hyper::upgrade::on(&mut req);
// Build request from destination information
let _context = match self.generate_request_forwarded(
&client_addr,
&listen_addr,
&mut req,
&upgrade_in_request,
upstream_candidates,
tls_enabled,
) {
Err(e) => {
return Err(HttpError::FailedToGenerateUpstreamRequest(e.to_string()));
}
Ok(v) => v,
};
let _context = self
.generate_request_forwarded(
&client_addr,
&listen_addr,
&mut req,
&upgrade_in_request,
upstream_candidates,
tls_enabled,
)
.map_err(|e| HttpError::FailedToGenerateUpstreamRequest(e.to_string()))?;
debug!(
"Request to be forwarded: [uri {}, method: {}, version {:?}, headers {:?}]",
req.uri(),
@ -173,12 +171,12 @@ where
//////////////
// Forward request to a chosen backend
let mut res_backend = match self.forwarder.request(req).await {
Ok(v) => v,
Err(e) => {
return Err(HttpError::FailedToGetResponseFromBackend(e.to_string()));
}
};
let mut res_backend = self
.forwarder
.request(req)
.await
.map_err(|e| HttpError::FailedToGetResponseFromBackend(e.to_string()))?;
//////////////
// Process reverse proxy context generated during the forwarding request generation.
#[cfg(feature = "sticky-cookie")]
@ -191,16 +189,16 @@ where
if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS {
// Generate response to client
if let Err(e) = self.generate_response_forwarded(&mut res_backend, backend_app) {
return Err(HttpError::FailedToGenerateDownstreamResponse(e.to_string()));
}
self
.generate_response_forwarded(&mut res_backend, backend_app)
.map_err(|e| HttpError::FailedToGenerateDownstreamResponse(e.to_string()))?;
return Ok(res_backend);
}
// Handle StatusCode::SWITCHING_PROTOCOLS in response
let upgrade_in_response = extract_upgrade(res_backend.headers());
let should_upgrade = match (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) {
(Some(u_req), Some(u_res)) => u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase(),
(Some(u_req), Some(u_res)) => u_req.eq_ignore_ascii_case(u_res),
_ => false,
};

View file

@ -1,11 +1,11 @@
use super::{handler_main::HandlerContext, utils_headers::*, utils_request::update_request_line, HttpMessageHandler};
use super::{HttpMessageHandler, handler_main::HandlerContext, utils_headers::*, utils_request::update_request_line};
use crate::{
backend::{BackendApp, UpstreamCandidates},
constants::RESPONSE_HEADER_SERVER,
log::*,
};
use anyhow::{anyhow, ensure, Result};
use http::{header, HeaderValue, Request, Response, Uri};
use anyhow::{Result, anyhow, ensure};
use http::{HeaderValue, Request, Response, Uri, header};
use hyper_util::client::legacy::connect::Connect;
use std::net::SocketAddr;
@ -66,17 +66,19 @@ where
upstream_candidates: &UpstreamCandidates,
tls_enabled: bool,
) -> Result<HandlerContext> {
debug!("Generate request to be forwarded");
trace!("Generate request to be forwarded");
// Add te: trailer if contained in original request
let contains_te_trailers = {
if let Some(te) = req.headers().get(header::TE) {
te.as_bytes()
.split(|v| v == &b',' || v == &b' ')
.any(|x| x == "trailers".as_bytes())
} else {
false
}
req
.headers()
.get(header::TE)
.map(|te| {
te.as_bytes()
.split(|v| v == &b',' || v == &b' ')
.any(|x| x == "trailers".as_bytes())
})
.unwrap_or(false)
};
let original_uri = req.uri().to_string();
@ -136,11 +138,7 @@ where
let new_uri = Uri::builder()
.scheme(upstream_chosen.uri.scheme().unwrap().as_str())
.authority(upstream_chosen.uri.authority().unwrap().as_str());
let org_pq = match req.uri().path_and_query() {
Some(pq) => pq.to_string(),
None => "/".to_string(),
}
.into_bytes();
let org_pq = req.uri().path_and_query().map(|pq| pq.as_str()).unwrap_or("/").as_bytes();
// replace some parts of path if opt_replace_path is enabled for chosen upstream
let new_pq = match &upstream_candidates.replace_path {
@ -155,7 +153,7 @@ where
new_pq.extend_from_slice(&org_pq[matched_path.len()..]);
new_pq
}
None => org_pq,
None => org_pq.to_vec(),
};
*req.uri_mut() = new_uri.path_and_query(new_pq).build()?;

View file

@ -34,11 +34,7 @@ impl<T> From<&http::Request<T>> for HttpMessageLog {
client_addr: "".to_string(),
method: req.method().to_string(),
host: header_mapper(header::HOST),
p_and_q: req
.uri()
.path_and_query()
.map_or_else(|| "", |v| v.as_str())
.to_string(),
p_and_q: req.uri().path_and_query().map_or_else(|| "", |v| v.as_str()).to_string(),
version: req.version(),
uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(),
uri_host: req.uri().host().unwrap_or("").to_string(),
@ -50,6 +46,33 @@ impl<T> From<&http::Request<T>> for HttpMessageLog {
}
}
impl std::fmt::Display for HttpMessageLog {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"",
if !self.host.is_empty() {
self.host.as_str()
} else {
self.uri_host.as_str()
},
self.client_addr,
self.method,
self.p_and_q,
self.version,
self.status,
if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() {
format!("{}://{}", self.uri_scheme, self.uri_host)
} else {
"".to_string()
},
self.ua,
self.xff,
self.upstream
)
}
}
impl HttpMessageLog {
pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self {
self.client_addr = client_addr.to_canonical().to_string();
@ -74,26 +97,8 @@ impl HttpMessageLog {
pub fn output(&self) {
info!(
"{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"",
if !self.host.is_empty() {
self.host.as_str()
} else {
self.uri_host.as_str()
},
self.client_addr,
self.method,
self.p_and_q,
self.version,
self.status,
if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() {
format!("{}://{}", self.uri_scheme, self.uri_host)
} else {
"".to_string()
},
self.ua,
self.xff,
self.upstream,
// self.tls_server_name
name: crate::constants::log_event_names::ACCESS_LOG,
"{}", self
);
}
}

View file

@ -53,6 +53,7 @@ impl From<HttpError> for StatusCode {
HttpError::FailedToAddSetCookeInResponse(_) => StatusCode::INTERNAL_SERVER_ERROR,
HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR,
HttpError::FailedToUpgrade(_) => StatusCode::INTERNAL_SERVER_ERROR,
HttpError::FailedToGetResponseFromBackend(_) => StatusCode::BAD_GATEWAY,
// HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST,
// HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY,
_ => StatusCode::INTERNAL_SERVER_ERROR,

View file

@ -1,7 +1,7 @@
use super::http_result::{HttpError, HttpResult};
use crate::{
error::*,
hyper_ext::body::{empty, ResponseBody},
hyper_ext::body::{ResponseBody, empty},
name_exp::ServerName,
};
use http::{Request, Response, StatusCode, Uri};

View file

@ -3,9 +3,9 @@ use crate::{
backend::{UpstreamCandidates, UpstreamOption},
log::*,
};
use anyhow::{anyhow, Result};
use anyhow::{Result, anyhow, ensure};
use bytes::BufMut;
use http::{header, HeaderMap, HeaderName, HeaderValue, Uri};
use http::{HeaderMap, HeaderName, HeaderValue, Uri, header};
use std::{borrow::Cow, net::SocketAddr};
#[cfg(feature = "sticky-cookie")]
@ -238,10 +238,9 @@ pub(super) fn add_forwarding_header(
pub(super) fn remove_connection_header(headers: &mut HeaderMap) {
if let Some(values) = headers.get(header::CONNECTION) {
if let Ok(v) = values.clone().to_str() {
for m in v.split(',') {
if !m.is_empty() {
headers.remove(m.trim());
}
let keys = v.split(',').map(|m| m.trim()).filter(|m| !m.is_empty());
for m in keys {
headers.remove(m);
}
}
}
@ -274,13 +273,11 @@ pub(super) fn extract_upgrade(headers: &HeaderMap) -> Option<String> {
.to_str()
.unwrap_or("")
.split(',')
.any(|w| w.trim().to_ascii_lowercase() == header::UPGRADE.as_str().to_ascii_lowercase())
.any(|w| w.trim().eq_ignore_ascii_case(header::UPGRADE.as_str()))
{
if let Some(u) = headers.get(header::UPGRADE) {
if let Ok(m) = u.to_str() {
debug!("Upgrade in request header: {}", m);
return Some(m.to_owned());
}
if let Some(Ok(m)) = headers.get(header::UPGRADE).map(|u| u.to_str()) {
debug!("Upgrade in request header: {}", m);
return Some(m.to_owned());
}
}
}

View file

@ -2,8 +2,8 @@ use crate::{
backend::{Upstream, UpstreamCandidates, UpstreamOption},
log::*,
};
use anyhow::{anyhow, ensure, Result};
use http::{header, uri::Scheme, Request, Version};
use anyhow::{Result, anyhow, ensure};
use http::{Request, Version, header, uri::Scheme};
/// Trait defining parser of hostname
/// Inspect and extract hostname from either the request HOST header or request line
@ -59,6 +59,18 @@ pub(super) fn update_request_line<B>(
upstream_chosen: &Upstream,
upstream_candidates: &UpstreamCandidates,
) -> anyhow::Result<()> {
// If request is grpc, HTTP/2 is required
if req
.headers()
.get(header::CONTENT_TYPE)
.map(|v| v.as_bytes().starts_with(b"application/grpc"))
== Some(true)
{
debug!("Must be http/2 for gRPC request.");
*req.version_mut() = Version::HTTP_2;
return Ok(());
}
// If not specified (force_httpXX_upstream) and https, version is preserved except for http/3
if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) {
// Change version to http/1.1 when destination scheme is http

View file

@ -14,12 +14,11 @@ use crate::{
name_exp::ServerName,
};
use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder};
use rustc_hash::FxHashMap as HashMap;
use rustls::ServerConfig;
use std::sync::Arc;
/// SNI to ServerConfig map type
pub type SniServerCryptoMap = HashMap<ServerName, Arc<ServerConfig>>;
pub type SniServerCryptoMap = std::collections::HashMap<ServerName, Arc<ServerConfig>, ahash::RandomState>;
pub(crate) use proxy_main::Proxy;

View file

@ -33,7 +33,7 @@ where
<<C as OpenStreams<Bytes>>::BidiStream as BidiStream<Bytes>>::SendStream: Send,
{
let mut h3_conn = h3::server::Connection::<_, Bytes>::new(quic_connection).await?;
info!(
debug!(
"QUIC/HTTP3 connection established from {:?} {}",
client_addr,
<&ServerName as TryInto<String>>::try_into(&tls_server_name).unwrap_or_default()
@ -49,12 +49,17 @@ where
}
Err(e) => {
warn!("HTTP/3 error on accept incoming connection: {}", e);
match e.get_error_level() {
h3::error::ErrorLevel::ConnectionError => break,
h3::error::ErrorLevel::StreamError => continue,
}
break;
}
Ok(Some((req, stream))) => {
// Ok(Some((req, stream))) => {
Ok(Some(req_resolver)) => {
let (req, stream) = match req_resolver.resolve_request().await {
Ok((req, stream)) => (req, stream),
Err(e) => {
warn!("HTTP/3 error on resolve request in stream: {}", e);
continue;
}
};
// We consider the connection count separately from the stream count.
// Max clients for h1/h2 = max 'stream' for h3.
let request_count = self.globals.request_count.clone();
@ -63,7 +68,7 @@ where
h3_conn.shutdown(0).await?;
break;
}
debug!("Request incoming: current # {}", request_count.current());
trace!("Request incoming: current # {}", request_count.current());
let self_inner = self.clone();
let tls_server_name_inner = tls_server_name.clone();
@ -77,7 +82,7 @@ where
warn!("HTTP/3 error on serve stream: {}", e);
}
request_count.decrement();
debug!("Request processed: current # {}", request_count.current());
trace!("Request processed: current # {}", request_count.current());
});
}
}
@ -115,7 +120,7 @@ where
let mut sender = body_sender;
let mut size = 0usize;
while let Some(mut body) = recv_stream.recv_data().await? {
debug!("HTTP/3 incoming request body: remaining {}", body.remaining());
trace!("HTTP/3 incoming request body: remaining {}", body.remaining());
size += body.remaining();
if size > max_body_size {
error!(
@ -129,9 +134,9 @@ where
}
// trailers: use inner for work around. (directly get trailer)
let trailers = recv_stream.as_mut().recv_trailers().await?;
let trailers = futures_util::future::poll_fn(|cx| recv_stream.as_mut().poll_recv_trailers(cx)).await?;
if trailers.is_some() {
debug!("HTTP/3 incoming request trailers");
trace!("HTTP/3 incoming request trailers");
sender.send_trailers(trailers.unwrap()).await?;
}
Ok(()) as RpxyResult<()>
@ -154,13 +159,13 @@ where
match send_stream.send_response(new_res).await {
Ok(_) => {
debug!("HTTP/3 response to connection successful");
trace!("HTTP/3 response to connection successful");
// on-demand body streaming to downstream without expanding the object onto memory.
loop {
let frame = match new_body.frame().await {
Some(frame) => frame,
None => {
debug!("Response body finished");
trace!("Response body finished");
break;
}
}

View file

@ -11,7 +11,7 @@ use crate::{
message_handler::HttpMessageHandler,
name_exp::ServerName,
};
use futures::{select, FutureExt};
use futures::{FutureExt, select};
use http::{Request, Response};
use hyper::{
body::Incoming,
@ -22,6 +22,7 @@ use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::au
use rpxy_certs::ServerCrypto;
use std::{net::SocketAddr, sync::Arc, time::Duration};
use tokio::time::timeout;
use tokio_util::sync::CancellationToken;
/// Wrapper function to handle request for HTTP/1.1 and HTTP/2
/// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler
@ -79,7 +80,7 @@ where
request_count.decrement();
return;
}
debug!("Request incoming: current # {}", request_count.current());
trace!("Request incoming: current # {}", request_count.current());
let server_clone = self.connection_builder.clone();
let message_handler_clone = self.message_handler.clone();
@ -109,7 +110,7 @@ where
}
request_count.decrement();
debug!("Request processed: current # {}", request_count.current());
trace!("Request processed: current # {}", request_count.current());
});
}
@ -129,30 +130,56 @@ where
}
/// Start with TLS (HTTPS)
pub(super) async fn start_with_tls(&self) -> RpxyResult<()> {
pub(super) async fn start_with_tls(&self, cancel_token: CancellationToken) -> RpxyResult<()> {
// By default, TLS listener is spawned
let join_handle_tls = self.globals.runtime_handle.spawn({
let self_clone = self.clone();
let cancel_token = cancel_token.clone();
async move {
select! {
_ = self_clone.tls_listener_service().fuse() => {
error!("TCP proxy service for TLS exited");
cancel_token.cancel();
},
_ = cancel_token.cancelled().fuse() => {
debug!("Cancel token is called for TLS listener");
}
}
}
});
#[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))]
{
self.tls_listener_service().await?;
error!("TCP proxy service for TLS exited");
let _ = join_handle_tls.await;
Ok(())
}
#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
{
if self.globals.proxy_config.http3 {
select! {
_ = self.tls_listener_service().fuse() => {
error!("TCP proxy service for TLS exited");
},
_ = self.h3_listener_service().fuse() => {
error!("UDP proxy service for QUIC exited");
}
};
Ok(())
} else {
self.tls_listener_service().await?;
error!("TCP proxy service for TLS exited");
Ok(())
// If HTTP/3 is not enabled, wait for TLS listener to finish
if !self.globals.proxy_config.http3 {
let _ = join_handle_tls.await;
return Ok(());
}
// If HTTP/3 is enabled, spawn a task to handle HTTP/3 connections
let join_handle_h3 = self.globals.runtime_handle.spawn({
let self_clone = self.clone();
async move {
select! {
_ = self_clone.h3_listener_service().fuse() => {
error!("UDP proxy service for QUIC exited");
cancel_token.cancel();
},
_ = cancel_token.cancelled().fuse() => {
debug!("Cancel token is called for QUIC listener");
}
}
}
});
let _ = futures::future::join(join_handle_tls, join_handle_h3).await;
Ok(())
}
}
@ -294,7 +321,7 @@ where
let map = server_config.individual_config_map.clone().iter().map(|(k,v)| {
let server_name = ServerName::from(k.as_slice());
(server_name, v.clone())
}).collect::<rustc_hash::FxHashMap<_,_>>();
}).collect::<std::collections::HashMap<_,_,ahash::RandomState>>();
server_crypto_map = Some(Arc::new(map));
}
}
@ -303,10 +330,10 @@ where
}
/// Entrypoint for HTTP/1.1, 2 and 3 servers
pub async fn start(&self) -> RpxyResult<()> {
pub async fn start(&self, cancel_token: CancellationToken) -> RpxyResult<()> {
let proxy_service = async {
if self.tls_enabled {
self.start_with_tls().await
self.start_with_tls(cancel_token).await
} else {
self.start_without_tls().await
}

View file

@ -2,8 +2,8 @@ use super::{proxy_main::Proxy, socket::bind_udp_socket};
use crate::{error::*, log::*, name_exp::ByteName};
use hyper_util::client::legacy::connect::Connect;
use quinn::{
crypto::rustls::{HandshakeData, QuicServerConfig},
Endpoint, TransportConfig,
crypto::rustls::{HandshakeData, QuicServerConfig},
};
use rpxy_certs::ServerCrypto;
use rustls::ServerConfig;
@ -82,7 +82,7 @@ where
let client_addr = incoming.remote_address();
let quic_connection = match incoming.await {
Ok(new_conn) => {
info!("New connection established");
trace!("New connection established");
h3_quinn::Connection::new(new_conn)
},
Err(e) => {

View file

@ -110,7 +110,7 @@ where
// quic event loop. this immediately cancels when crypto is updated by tokio::select!
while let Some(new_conn) = server.accept().await {
debug!("New QUIC connection established");
trace!("New QUIC connection established");
let Ok(Some(new_server_name)) = new_conn.server_name() else {
warn!("HTTP/3 no SNI is given");
continue;

View file

@ -16,10 +16,12 @@ pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> RpxyResult<TcpSocket
}?;
tcp_socket.set_reuseaddr(true)?;
tcp_socket.set_reuseport(true)?;
if let Err(e) = tcp_socket.bind(*listening_on) {
tcp_socket.bind(*listening_on).map_err(|e| {
error!("Failed to bind TCP socket: {}", e);
return Err(RpxyError::Io(e));
};
RpxyError::Io(e)
})?;
Ok(tcp_socket)
}
@ -36,11 +38,10 @@ pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> RpxyResult<UdpSocket
socket.set_reuse_port(true)?;
socket.set_nonblocking(true)?; // This was made true inside quinn. so this line isn't necessary here. but just in case.
if let Err(e) = socket.bind(&(*listening_on).into()) {
socket.bind(&(*listening_on).into()).map_err(|e| {
error!("Failed to bind UDP socket: {}", e);
return Err(RpxyError::Io(e));
};
let udp_socket: UdpSocket = socket.into();
RpxyError::Io(e)
})?;
Ok(udp_socket)
Ok(socket.into())
}