diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 847d3fe..3469162 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -50,6 +50,8 @@ http = "1.0.0" http-body-util = "0.1.0" hyper = { version = "1.0.1", default-features = false } hyper-util = { version = "0.1.1", features = ["full"] } +futures-util = { version = "0.3.29", default-features = false } +futures-channel = { version = "0.3.29", default-features = false } # hyper-rustls = { version = "0.24.2", default-features = false, features = [ # "tokio-runtime", # "webpki-tokio", diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 8152e0d..37f35e0 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -25,11 +25,16 @@ pub enum RpxyError { // hyper errors #[error("hyper body manipulation error: {0}")] HyperBodyManipulationError(String), + #[error("New closed in incoming-like")] + HyperIncomingLikeNewClosed, // http/3 errors #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] #[error("H3 error: {0}")] H3Error(#[from] h3::Error), + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + #[error("Exceeds max request body size for HTTP/3")] + H3TooLargeBody, #[cfg(feature = "http3-quinn")] #[error("Invalid rustls TLS version: {0}")] diff --git a/rpxy-lib/src/hyper_ext/body_incoming_like.rs b/rpxy-lib/src/hyper_ext/body_incoming_like.rs new file mode 100644 index 0000000..2fced25 --- /dev/null +++ b/rpxy-lib/src/hyper_ext/body_incoming_like.rs @@ -0,0 +1,189 @@ +use super::watch; +use crate::error::*; +use futures_channel::{mpsc, oneshot}; +use futures_util::{stream::FusedStream, Future, Stream}; +use http::HeaderMap; +use hyper::body::{Body, Bytes, Frame, SizeHint}; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +//////////////////////////////////////////////////////////// +/// Incoming like body to handle incoming request body +pub struct IncomingLike { + content_length: DecodedLength, + want_tx: watch::Sender, + data_rx: mpsc::Receiver>, + trailers_rx: oneshot::Receiver, +} + +macro_rules! ready { + ($e:expr) => { + match $e { + Poll::Ready(v) => v, + Poll::Pending => return Poll::Pending, + } + }; +} + +type BodySender = mpsc::Sender>; +type TrailersSender = oneshot::Sender; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub(crate) struct DecodedLength(u64); +impl DecodedLength { + pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX); + pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); + pub(crate) const ZERO: DecodedLength = DecodedLength(0); + + pub(crate) fn sub_if(&mut self, amt: u64) { + match *self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), + DecodedLength(ref mut known) => { + *known -= amt; + } + } + } + /// Converts to an Option representing a Known or Unknown length. + pub(crate) fn into_opt(self) -> Option { + match self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => None, + DecodedLength(known) => Some(known), + } + } +} +pub(crate) struct Sender { + want_rx: watch::Receiver, + data_tx: BodySender, + trailers_tx: Option, +} + +const WANT_PENDING: usize = 1; +const WANT_READY: usize = 2; + +impl IncomingLike { + /// Create a `Body` stream with an associated sender half. + /// + /// Useful when wanting to stream chunks from another thread. + #[inline] + #[allow(unused)] + pub(crate) fn channel() -> (Sender, IncomingLike) { + Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false) + } + + pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, IncomingLike) { + let (data_tx, data_rx) = mpsc::channel(0); + let (trailers_tx, trailers_rx) = oneshot::channel(); + + // If wanter is true, `Sender::poll_ready()` won't becoming ready + // until the `Body` has been polled for data once. + let want = if wanter { WANT_PENDING } else { WANT_READY }; + + let (want_tx, want_rx) = watch::channel(want); + + let tx = Sender { + want_rx, + data_tx, + trailers_tx: Some(trailers_tx), + }; + let rx = IncomingLike { + content_length, + want_tx, + data_rx, + trailers_rx, + }; + + (tx, rx) + } +} + +impl Body for IncomingLike { + type Data = Bytes; + type Error = hyper::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + self.want_tx.send(WANT_READY); + + if !self.data_rx.is_terminated() { + if let Some(chunk) = ready!(Pin::new(&mut self.data_rx).poll_next(cx)?) { + self.content_length.sub_if(chunk.len() as u64); + return Poll::Ready(Some(Ok(Frame::data(chunk)))); + } + } + + // check trailers after data is terminated + match ready!(Pin::new(&mut self.trailers_rx).poll(cx)) { + Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))), + Err(_) => Poll::Ready(None), + } + } + + fn is_end_stream(&self) -> bool { + self.content_length == DecodedLength::ZERO + } + + fn size_hint(&self) -> SizeHint { + macro_rules! opt_len { + ($content_length:expr) => {{ + let mut hint = SizeHint::default(); + + if let Some(content_length) = $content_length.into_opt() { + hint.set_exact(content_length); + } + + hint + }}; + } + + opt_len!(self.content_length) + } +} + +impl Sender { + /// Check to see if this `Sender` can send more data. + pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + // Check if the receiver end has tried polling for the body yet + ready!(self.poll_want(cx)?); + self + .data_tx + .poll_ready(cx) + .map_err(|_| RpxyError::HyperIncomingLikeNewClosed) + } + + fn poll_want(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.want_rx.load(cx) { + WANT_READY => Poll::Ready(Ok(())), + WANT_PENDING => Poll::Pending, + watch::CLOSED => Poll::Ready(Err(RpxyError::HyperIncomingLikeNewClosed)), + unexpected => unreachable!("want_rx value: {}", unexpected), + } + } + + async fn ready(&mut self) -> RpxyResult<()> { + futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await + } + + /// Send data on data channel when it is ready. + #[allow(unused)] + pub(crate) async fn send_data(&mut self, chunk: Bytes) -> RpxyResult<()> { + self.ready().await?; + self + .data_tx + .try_send(Ok(chunk)) + .map_err(|_| RpxyError::HyperIncomingLikeNewClosed) + } + + /// Send trailers on trailers channel. + #[allow(unused)] + pub(crate) async fn send_trailers(&mut self, trailers: HeaderMap) -> RpxyResult<()> { + let tx = match self.trailers_tx.take() { + Some(tx) => tx, + None => return Err(RpxyError::HyperIncomingLikeNewClosed), + }; + tx.send(trailers).map_err(|_| RpxyError::HyperIncomingLikeNewClosed) + } +} diff --git a/rpxy-lib/src/hyper_ext/body_type.rs b/rpxy-lib/src/hyper_ext/body_type.rs new file mode 100644 index 0000000..ba1bdc2 --- /dev/null +++ b/rpxy-lib/src/hyper_ext/body_type.rs @@ -0,0 +1,41 @@ +use crate::error::*; +use http::{Response, StatusCode}; +use http_body_util::{combinators, BodyExt, Either, Empty, Full}; +use hyper::body::{Bytes, Incoming}; + +/// Type for synthetic boxed body +pub(crate) type BoxBody = combinators::BoxBody; +/// Type for either passthrough body or given body type, specifically synthetic boxed body +pub(crate) type IncomingOr = Either; + +/// helper function to build http response with passthrough body +pub(crate) fn passthrough_response(response: Response) -> RpxyResult>> +where + B: hyper::body::Body, +{ + Ok(response.map(IncomingOr::Left)) +} + +/// helper function to build http response with synthetic body +pub(crate) fn synthetic_response(response: Response) -> RpxyResult>> { + Ok(response.map(IncomingOr::Right)) +} + +/// build http response with status code of 4xx and 5xx +pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult>> { + let res = Response::builder() + .status(status_code) + .body(IncomingOr::Right(BoxBody::new(empty()))) + .unwrap(); + Ok(res) +} + +/// helper function to build a empty body +fn empty() -> BoxBody { + Empty::::new().map_err(|never| match never {}).boxed() +} + +/// helper function to build a full body +pub(crate) fn full(body: Bytes) -> BoxBody { + Full::new(body).map_err(|never| match never {}).boxed() +} diff --git a/rpxy-lib/src/hyper_executor.rs b/rpxy-lib/src/hyper_ext/executor.rs similarity index 100% rename from rpxy-lib/src/hyper_executor.rs rename to rpxy-lib/src/hyper_ext/executor.rs diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs new file mode 100644 index 0000000..19511a1 --- /dev/null +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -0,0 +1,13 @@ +mod body_incoming_like; +mod body_type; +mod executor; +mod watch; + +pub(crate) mod rt { + pub(crate) use super::executor::LocalExecutor; +} +pub(crate) mod body { + pub(crate) use super::body_incoming_like::IncomingLike; + pub(crate) use super::body_type::{BoxBody, IncomingOr}; +} +pub(crate) use body_type::{full, passthrough_response, synthetic_error_response, synthetic_response}; diff --git a/rpxy-lib/src/hyper_ext/watch.rs b/rpxy-lib/src/hyper_ext/watch.rs new file mode 100644 index 0000000..d5e1c7e --- /dev/null +++ b/rpxy-lib/src/hyper_ext/watch.rs @@ -0,0 +1,67 @@ +//! An SPSC broadcast channel. +//! +//! - The value can only be a `usize`. +//! - The consumer is only notified if the value is different. +//! - The value `0` is reserved for closed. +// from https://github.com/hyperium/hyper/blob/master/src/common/watch.rs + +use futures_util::task::AtomicWaker; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::task; + +type Value = usize; + +pub(super) const CLOSED: usize = 0; + +pub(super) fn channel(initial: Value) -> (Sender, Receiver) { + debug_assert!(initial != CLOSED, "watch::channel initial state of 0 is reserved"); + + let shared = Arc::new(Shared { + value: AtomicUsize::new(initial), + waker: AtomicWaker::new(), + }); + + (Sender { shared: shared.clone() }, Receiver { shared }) +} + +pub(super) struct Sender { + shared: Arc, +} + +pub(super) struct Receiver { + shared: Arc, +} + +struct Shared { + value: AtomicUsize, + waker: AtomicWaker, +} + +impl Sender { + pub(super) fn send(&mut self, value: Value) { + if self.shared.value.swap(value, Ordering::SeqCst) != value { + self.shared.waker.wake(); + } + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.send(CLOSED); + } +} + +impl Receiver { + pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value { + self.shared.waker.register(cx.waker()); + self.shared.value.load(Ordering::SeqCst) + } + + #[allow(dead_code)] + pub(crate) fn peek(&self) -> Value { + self.shared.value.load(Ordering::Relaxed) + } +} diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index 1f5fa37..706f7b2 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -4,7 +4,7 @@ mod count; mod crypto; mod error; mod globals; -mod hyper_executor; +mod hyper_ext; mod log; mod name_exp; mod proxy; diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index 5b1ad61..e4ac6f7 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -6,7 +6,7 @@ mod proxy_quic_quinn; mod proxy_quic_s2n; mod socket; -use crate::{globals::Globals, hyper_executor::LocalExecutor}; +use crate::{globals::Globals, hyper_ext::rt::LocalExecutor}; use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; use std::sync::Arc; diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 056cd4b..6ca6528 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -1,6 +1,14 @@ use super::proxy_main::Proxy; -use crate::{error::*, log::*, name_exp::ServerName}; -use bytes::Bytes; +use crate::{ + error::*, + hyper_ext::{ + body::{IncomingLike, IncomingOr}, + full, synthetic_response, + }, + log::*, + name_exp::ServerName, +}; +use bytes::{Buf, Bytes}; use http::{Request, Response}; use http_body_util::BodyExt; use std::{net::SocketAddr, time::Duration}; @@ -11,7 +19,6 @@ use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestSt #[cfg(feature = "http3-s2n")] use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; -// use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; // use futures::Stream; // use hyper_util::client::legacy::connect::Connect; @@ -111,48 +118,41 @@ impl Proxy { // split stream and async body handling let (mut send_stream, mut recv_stream) = stream.split(); - // let max_body_size = self.globals.proxy_config.h3_request_max_body_size; - // // let max = body_stream.size_hint().upper().unwrap_or(u64::MAX); - // // if max > max_body_size as u64 { - // // return Err(HttpError::TooLargeRequestBody); - // // } + // generate streamed body with trailers using channel + let (body_sender, req_body) = IncomingLike::channel(); - // let new_req = Request::from_parts(req_parts, body_stream); + // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1 + // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn. + let max_body_size = self.globals.proxy_config.h3_request_max_body_size; + self.globals.runtime_handle.spawn(async move { + let mut sender = body_sender; + let mut size = 0usize; + while let Some(mut body) = recv_stream.recv_data().await? { + debug!("HTTP/3 incoming request body: remaining {}", body.remaining()); + size += body.remaining(); + if size > max_body_size { + error!( + "Exceeds max request body size for HTTP/3: received {}, maximum_allowd {}", + size, max_body_size + ); + return Err(RpxyError::H3TooLargeBody); + } + // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes + sender.send_data(body.copy_to_bytes(body.remaining())).await?; + } - // // generate streamed body with trailers using channel - // let (body_sender, req_body) = Incoming::channel(); + // trailers: use inner for work around. (directly get trailer) + let trailers = recv_stream.as_mut().recv_trailers().await?; + if trailers.is_some() { + debug!("HTTP/3 incoming request trailers"); + sender.send_trailers(trailers.unwrap()).await?; + } + Ok(()) as RpxyResult<()> + }); - // // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1 - // // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn. - // let max_body_size = self.globals.proxy_config.h3_request_max_body_size; - // self.globals.runtime_handle.spawn(async move { - // // let mut sender = body_sender; - // let mut size = 0usize; - // while let Some(mut body) = recv_stream.recv_data().await? { - // debug!("HTTP/3 incoming request body: remaining {}", body.remaining()); - // size += body.remaining(); - // if size > max_body_size { - // error!( - // "Exceeds max request body size for HTTP/3: received {}, maximum_allowd {}", - // size, max_body_size - // ); - // return Err(RpxyError::Proxy("Exceeds max request body size for HTTP/3".to_string())); - // } - // // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes - // // sender.send_data(body.copy_to_bytes(body.remaining())).await?; - // } + let mut new_req: Request> = Request::from_parts(req_parts, IncomingOr::Right(req_body)); - // // trailers: use inner for work around. (directly get trailer) - // let trailers = recv_stream.as_mut().recv_trailers().await?; - // if trailers.is_some() { - // debug!("HTTP/3 incoming request trailers"); - // // sender.send_trailers(trailers.unwrap()).await?; - // } - // Ok(()) - // }); - - // let new_req: Request = Request::from_parts(req_parts, req_body); - // let res = self + // let res = selfw // .msg_handler // .clone() // .handle_request( @@ -165,8 +165,9 @@ impl Proxy { // .await?; // TODO: TODO: TODO: remove later - let body = full(hyper::body::Bytes::from("hello h3 echo")); - let res = Response::builder().body(body).unwrap(); + let body = full(Bytes::from("hello h3 echo")); + // here response is IncomingOr from message handler + let res = synthetic_response(Response::builder().body(body).unwrap())?; ///////////////// let (new_res_parts, new_body) = res.into_parts(); @@ -193,13 +194,3 @@ impl Proxy { Ok(send_stream.finish().await?) } } - -////////////// -/// TODO: remove later -/// helper function to build a full body -use http_body_util::Full; -pub(crate) type BoxBody = http_body_util::combinators::BoxBody; -pub fn full(body: hyper::body::Bytes) -> BoxBody { - Full::new(body).map_err(|never| match never {}).boxed() -} -////////////// diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index 0d6eb83..cc6636d 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -4,7 +4,12 @@ use crate::{ crypto::{ServerCrypto, SniServerCryptoMap}, error::*, globals::Globals, - hyper_executor::LocalExecutor, + hyper_ext::{ + body::{BoxBody, IncomingOr}, + full, + rt::LocalExecutor, + synthetic_response, + }, log::*, name_exp::ServerName, }; @@ -22,14 +27,14 @@ use tokio::time::timeout; /// Wrapper function to handle request for HTTP/1.1 and HTTP/2 /// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler async fn serve_request( - req: Request, + mut req: Request, // handler: Arc>, // handler: Arc>, client_addr: SocketAddr, listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, -) -> RpxyResult> { +) -> RpxyResult>> { // match handler // .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) // .await? @@ -37,19 +42,14 @@ async fn serve_request( // Ok(res) => passthrough_response(res), // Err(e) => synthetic_error_response(StatusCode::from(e)), // } + + ////////////// + // TODO: remove later let body = full(hyper::body::Bytes::from("hello")); let res = Response::builder().body(body).unwrap(); - Ok(res) + synthetic_response(res) + ////////////// } -////////////// -/// TODO: remove later -/// helper function to build a full body -use http_body_util::{BodyExt, Full}; -pub(crate) type BoxBody = http_body_util::combinators::BoxBody; -pub fn full(body: hyper::body::Bytes) -> BoxBody { - Full::new(body).map_err(|never| match never {}).boxed() -} -////////////// #[derive(Clone)] /// Proxy main object responsible to serve requests received from clients at the given socket address.