rust-rpxy/rpxy-lib/src/forwarder/client.rs
Pascal Engélibert 9518cc8b73
Some checks failed
Unit Test / test (push) Has been cancelled
ShiftLeft Scan / Scan-Build (push) Has been cancelled
Early data option
2026-01-08 11:05:06 +01:00

268 lines
9.4 KiB
Rust

#[allow(unused)]
use crate::{
error::{RpxyError, RpxyResult},
globals::Globals,
hyper_ext::{body::ResponseBody, rt::LocalExecutor},
log::*,
};
use async_trait::async_trait;
use http::{Request, Response, Version};
use hyper::body::{Body, Incoming};
use hyper_util::client::legacy::{
Client,
connect::{Connect, HttpConnector},
};
use std::sync::Arc;
#[cfg(feature = "cache")]
use super::cache::{RpxyCache, get_policy_if_cacheable};
#[async_trait]
/// Definition of the forwarder that simply forward requests from downstream client to upstream app servers.
pub trait ForwardRequest<B1, B2> {
type Error;
async fn request(&self, req: Request<B1>) -> Result<Response<B2>, Self::Error>;
}
/// Forwarder: HTTP client wrapper that sends requests to upstream app servers and is also
/// responsible for response caching when the `cache` feature is enabled.
pub struct Forwarder<C, B> {
  /// Response cache (only with the `cache` feature). When `None`, requests are passed straight
  /// through to the upstream without any cache lookup or store.
  #[cfg(feature = "cache")]
  cache: Option<RpxyCache>,
  /// Client used for every request whose version is not HTTP/2.
  inner: Client<C, B>,
  /// Client used for HTTP/2 requests.
  inner_h2: Client<C, B>, // `h2c` or http/2-only client is defined separately
}
#[async_trait]
impl<C, B1> ForwardRequest<B1, ResponseBody> for Forwarder<C, B1>
where
  C: Send + Sync + Connect + Clone + 'static,
  B1: Body + Send + Sync + Unpin + 'static,
  <B1 as Body>::Data: Send,
  <B1 as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
  type Error = RpxyError;

  /// Forward `req` to the upstream.
  ///
  /// With the `cache` feature and a configured cache:
  /// - a cache hit is returned directly without contacting the upstream;
  /// - otherwise the upstream response is returned, and if it is cacheable its body is streamed
  ///   to the caller while being stored in the cache at the same time.
  ///
  /// # Errors
  /// Upstream failures (`FailedToFetchFromUpstream`) and cache store failures are returned as
  /// `RpxyError`.
  async fn request(&self, req: Request<B1>) -> Result<Response<ResponseBody>, Self::Error> {
    #[cfg(feature = "cache")]
    {
      // No cache configured: plain pass-through.
      let Some(cache) = self.cache.as_ref() else {
        return self
          .request_directly(req)
          .await
          .map(|inner| inner.map(ResponseBody::Incoming));
      };

      // try reading from cache
      if let Some(cached_response) = cache.get(&req).await {
        info!("Cache hit - Return from cache");
        return Ok(cached_response);
      }

      // Synthetic request copy used just for caching (the request object cannot be cloned,
      // and it is consumed by the upstream call below).
      let synth_req = build_synth_req_for_cache(&req);
      let res = self.request_directly(req).await;

      // Check cacheability. Non-cacheable responses (and policy evaluation errors) are passed
      // through untouched, i.e. caching is best-effort.
      let Ok(Some(cache_policy)) = get_policy_if_cacheable(Some(&synth_req), res.as_ref().ok()) else {
        return res.map(|inner| inner.map(ResponseBody::Incoming));
      };
      // Propagate an upstream error instead of panicking if a policy were ever produced without
      // a successful response.
      let (parts, body) = res?.into_parts();

      // Get streamed body without waiting for the arrival of the body,
      // which is done simultaneously with caching.
      let stream_body = cache.put(synth_req.uri(), body, &cache_policy).await?;

      // response with body being cached in background
      Ok(Response::from_parts(parts, ResponseBody::Streamed(stream_body)))
    }

    // No cache handling
    #[cfg(not(feature = "cache"))]
    {
      self
        .request_directly(req)
        .await
        .map(|inner| inner.map(ResponseBody::Incoming))
    }
  }
}
impl<C, B1> Forwarder<C, B1>
where
  C: Send + Sync + Connect + Clone + 'static,
  B1: Body + Send + Unpin + 'static,
  <B1 as Body>::Data: Send,
  <B1 as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
  /// Send `req` to the upstream without consulting the cache.
  ///
  /// HTTP/2 requests go through `inner_h2` (which also handles `h2c`); every other version goes
  /// through `inner`.
  ///
  /// # Errors
  /// Any client error is mapped to `RpxyError::FailedToFetchFromUpstream`.
  async fn request_directly(&self, req: Request<B1>) -> RpxyResult<Response<Incoming>> {
    // Logging must never panic. The Host header may be absent, or may contain bytes that
    // `HeaderValue::to_str` rejects (non-visible-ASCII).
    debug!(
      "About to send request with Host header: {}",
      req
        .headers()
        .get(hyper::header::HOST)
        .and_then(|value| value.to_str().ok())
        .unwrap_or("<absent or non-ASCII>")
    );
    // TODO: This 'match' condition is always evaluated at every 'request' invocation. So, it is inefficient.
    // Needs to be reconsidered. Currently, this is a kind of work around.
    // This possibly relates to https://github.com/hyperium/hyper/issues/2417.
    match req.version() {
      Version::HTTP_2 => self.inner_h2.request(req).await, // handles `h2c` requests
      _ => self.inner.request(req).await,
    }
    .map_err(|e| RpxyError::FailedToFetchFromUpstream(e.to_string()))
  }
}
#[cfg(not(any(feature = "native-tls-backend", feature = "rustls-backend")))]
impl<B> Forwarder<HttpConnector, B>
where
  B: Body + Send + Unpin + 'static,
  <B as Body>::Data: Send,
  <B as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
  /// Build inner client with http
  ///
  /// Plain-HTTP only (no TLS); intended for testing. Two clients share the same connector
  /// settings:
  /// - `inner` serves non-HTTP/2 requests;
  /// - `inner_h2` is HTTP/2-only (prior knowledge, i.e. `h2c`) and serves HTTP/2 requests.
  pub async fn try_new(_globals: &Arc<Globals>) -> RpxyResult<Self> {
    warn!(
      "
--------------------------------------------------------------------------------------------------
Request forwarder is working without TLS support!
This mode is intended for testing only.
Enable 'native-tls-backend' or 'rustls-backend' feature for TLS support.
--------------------------------------------------------------------------------------------------"
    );
    let executor = LocalExecutor::new(_globals.runtime_handle.clone());
    let mut http = HttpConnector::new();
    http.enforce_http(true);
    http.set_reuse_address(true);
    http.set_keepalive(Some(_globals.proxy_config.upstream_idle_timeout));

    let inner = Client::builder(executor.clone()).build::<_, B>(http.clone());
    // A clone of `inner` would open HTTP/1 connections, on which hyper's client rejects HTTP/2
    // requests. `h2c` therefore needs its own HTTP/2-only client, as in the TLS builders.
    let inner_h2 = Client::builder(executor).http2_only(true).build::<_, B>(http);

    Ok(Self {
      inner,
      inner_h2,
      #[cfg(feature = "cache")]
      cache: RpxyCache::new(_globals).await,
    })
  }
}
#[cfg(all(feature = "native-tls-backend", not(feature = "rustls-backend")))]
/// Forwarder construction backed by hyper-tls (native-tls).
impl<B1> Forwarder<hyper_tls::HttpsConnector<HttpConnector>, B1>
where
  B1: Body + Send + Unpin + 'static,
  <B1 as Body>::Data: Send,
  <B1 as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
  /// Build the forwarder on top of native-tls.
  ///
  /// Two clients are produced, each with its own TLS connector:
  /// - `inner` advertises `h2` and `http/1.1` via ALPN;
  /// - `inner_h2` advertises only `h2` and is restricted to HTTP/2.
  ///
  /// # Errors
  /// Returns `RpxyError::FailedToBuildForwarder` if a native-tls connector cannot be built.
  pub async fn try_new(_globals: &Arc<Globals>) -> RpxyResult<Self> {
    // build hyper client with hyper-tls
    info!("Native TLS support enabled for backend connections (native-tls)");
    let executor = LocalExecutor::new(_globals.runtime_handle.clone());

    // TCP settings are identical for both clients; only the ALPN list differs.
    let make_https = |alpn_protocols: &[&str]| -> RpxyResult<hyper_tls::HttpsConnector<HttpConnector>> {
      let tls = hyper_tls::native_tls::TlsConnector::builder()
        .request_alpns(alpn_protocols)
        .build()
        .map_err(|e| RpxyError::FailedToBuildForwarder(e.to_string()))?;
      let mut tcp = HttpConnector::new();
      tcp.enforce_http(false);
      tcp.set_reuse_address(true);
      tcp.set_keepalive(Some(_globals.proxy_config.upstream_idle_timeout));
      Ok(hyper_tls::HttpsConnector::from((tcp, tls.into())))
    };

    let general_connector = make_https(&["h2", "http/1.1"])?;
    let inner = Client::builder(executor.clone()).build::<_, B1>(general_connector);

    let h2_connector = make_https(&["h2"])?;
    let inner_h2 = Client::builder(executor).http2_only(true).build::<_, B1>(h2_connector);

    Ok(Self {
      inner,
      inner_h2,
      #[cfg(feature = "cache")]
      cache: RpxyCache::new(_globals).await,
    })
  }
}
#[cfg(feature = "rustls-backend")]
/// Build forwarder with hyper-rustls (rustls)
impl<B1> Forwarder<hyper_rustls::HttpsConnector<HttpConnector>, B1>
where
  B1: Body + Send + Unpin + 'static,
  <B1 as Body>::Data: Send,
  <B1 as Body>::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
  /// Build forwarder
  ///
  /// Upstream certificates are verified in one of two ways:
  /// - against Mozilla's WebPKI root certificates, when the `webpki-roots` feature is enabled;
  /// - by the platform verifier, otherwise.
  ///
  /// Both setups share one TLS client configuration, so `proxy_config.enable_early_data` applies
  /// either way. When it is set, TLS early data is enabled; when it is unset, session resumption
  /// is disabled. Both `https` and plain `http` upstreams are accepted (`https_or_http`).
  ///
  /// # Errors
  /// Returns `RpxyError::FailedToBuildForwarder` if the platform verifier cannot be set up.
  pub async fn try_new(globals: &Arc<Globals>) -> RpxyResult<Self> {
    use hyper_rustls::ConfigBuilderExt;
    use rustls::{ClientConfig, client::Resumption};

    // Choose the root-of-trust source. Everything after this point is shared, so the early-data
    // option is honored regardless of which source is compiled in.
    #[cfg(feature = "webpki-roots")]
    let config_builder = {
      info!("Rustls backend: Mozilla WebPKI root certs used for backend connections");
      ClientConfig::builder().with_webpki_roots()
    };
    #[cfg(not(feature = "webpki-roots"))]
    let config_builder = {
      let verified = ClientConfig::builder()
        .try_with_platform_verifier()
        .map_err(|e| RpxyError::FailedToBuildForwarder(e.to_string()))?;
      info!("Rustls backend: Platform verifier used for backend connections");
      verified
    };

    let mut client_config = config_builder.with_no_client_auth();
    if globals.proxy_config.enable_early_data {
      client_config.enable_early_data = true;
    } else {
      client_config.resumption = Resumption::disabled();
    }

    // Each connector builder takes its own copy of the TLS configuration.
    let builder = hyper_rustls::HttpsConnectorBuilder::new().with_tls_config(client_config.clone());
    let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_tls_config(client_config);

    let mut http = HttpConnector::new();
    http.enforce_http(false);
    http.set_reuse_address(true);
    http.set_keepalive(Some(globals.proxy_config.upstream_idle_timeout));
    let connector = builder.https_or_http().enable_all_versions().wrap_connector(http.clone());
    let connector_h2 = builder_h2.https_or_http().enable_http2().wrap_connector(http);

    let executor = LocalExecutor::new(globals.runtime_handle.clone());
    let inner = Client::builder(executor.clone()).build::<_, B1>(connector);
    let inner_h2 = Client::builder(executor)
      .http2_only(true)
      .set_host(false)
      .build::<_, B1>(connector_h2);

    Ok(Self {
      inner,
      inner_h2,
      #[cfg(feature = "cache")]
      cache: RpxyCache::new(globals).await,
    })
  }
}
#[cfg(feature = "cache")]
/// Produce a body-less copy of `req` for cache bookkeeping.
///
/// Method, URI, version and every header (including repeated header values) are copied.
/// Request extensions are not copied.
fn build_synth_req_for_cache<T>(req: &Request<T>) -> Request<()> {
  // TODO: omit extensions. is this approach correct?
  let mut synth = Request::new(());
  *synth.method_mut() = req.method().clone();
  *synth.uri_mut() = req.uri().clone();
  *synth.version_mut() = req.version();
  *synth.headers_mut() = req.headers().clone();
  synth
}