diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 583ac59..51631af 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -97,4 +97,4 @@ sha2 = { version = "0.10.8", default-features = false, optional = true } [dev-dependencies] -# http and tls +tokio-test = "0.4.3" diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 37f35e0..49a1d1c 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -27,6 +27,8 @@ pub enum RpxyError { HyperBodyManipulationError(String), #[error("New closed in incoming-like")] HyperIncomingLikeNewClosed, + #[error("New body write aborted")] + HyperNewBodyWriteAborted, // http/3 errors #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] diff --git a/rpxy-lib/src/hyper_ext/body_incoming_like.rs b/rpxy-lib/src/hyper_ext/body_incoming_like.rs index 2fced25..9307b7f 100644 --- a/rpxy-lib/src/hyper_ext/body_incoming_like.rs +++ b/rpxy-lib/src/hyper_ext/body_incoming_like.rs @@ -11,10 +11,11 @@ use std::{ //////////////////////////////////////////////////////////// /// Incoming like body to handle incoming request body +/// ported from https://github.com/hyperium/hyper/blob/master/src/body/incoming.rs pub struct IncomingLike { content_length: DecodedLength, want_tx: watch::Sender, - data_rx: mpsc::Receiver>, + data_rx: mpsc::Receiver>, trailers_rx: oneshot::Receiver, } @@ -27,9 +28,10 @@ macro_rules! ready { }; } -type BodySender = mpsc::Sender>; +type BodySender = mpsc::Sender>; type TrailersSender = oneshot::Sender; +const MAX_LEN: u64 = std::u64::MAX - 2; #[derive(Clone, Copy, PartialEq, Eq)] pub(crate) struct DecodedLength(u64); impl DecodedLength { @@ -37,6 +39,12 @@ impl DecodedLength { pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); pub(crate) const ZERO: DecodedLength = DecodedLength(0); + #[allow(dead_code)] + pub(crate) fn new(len: u64) -> Self { + debug_assert!(len <= MAX_LEN); + DecodedLength(len) + } + pub(crate) fn sub_if(&mut self, amt: u64) { match *self { DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), @@ -100,7 +108,7 @@ impl IncomingLike { impl Body for IncomingLike { type Data = Bytes; - type Error = hyper::Error; + type Error = RpxyError; fn poll_frame( mut self: Pin<&mut Self>, @@ -186,4 +194,177 @@ impl Sender { }; tx.send(trailers).map_err(|_| RpxyError::HyperIncomingLikeNewClosed) } + + /// Try to send data on this channel. + /// + /// # Errors + /// + /// Returns `Err(Bytes)` if the channel could not (currently) accept + /// another `Bytes`. + /// + /// # Note + /// + /// This is mostly useful for when trying to send from some other thread + /// that doesn't have an async context. If in an async context, prefer + /// `send_data()` instead. + #[allow(unused)] + pub(crate) fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> { + self + .data_tx + .try_send(Ok(chunk)) + .map_err(|err| err.into_inner().expect("just sent Ok")) + } + + #[allow(unused)] + pub(crate) fn abort(mut self) { + self.send_error(RpxyError::HyperNewBodyWriteAborted); + } + + pub(crate) fn send_error(&mut self, err: RpxyError) { + let _ = self + .data_tx + // clone so the send works even if buffer is full + .clone() + .try_send(Err(err)); + } +} + +#[cfg(test)] +mod tests { + use std::mem; + use std::task::Poll; + + use super::{Body, DecodedLength, IncomingLike, Sender, SizeHint}; + use crate::error::RpxyError; + use http_body_util::BodyExt; + + #[test] + fn test_size_of() { + // These are mostly to help catch *accidentally* increasing + // the size by too much. + + let body_size = mem::size_of::(); + let body_expected_size = mem::size_of::() * 5; + assert!( + body_size <= body_expected_size, + "Body size = {} <= {}", + body_size, + body_expected_size, + ); + + //assert_eq!(body_size, mem::size_of::>(), "Option"); + + assert_eq!(mem::size_of::(), mem::size_of::() * 5, "Sender"); + + assert_eq!( + mem::size_of::(), + mem::size_of::>(), + "Option" + ); + } + #[test] + fn size_hint() { + fn eq(body: IncomingLike, b: SizeHint, note: &str) { + let a = body.size_hint(); + assert_eq!(a.lower(), b.lower(), "lower for {:?}", note); + assert_eq!(a.upper(), b.upper(), "upper for {:?}", note); + } + + eq(IncomingLike::channel().1, SizeHint::new(), "channel"); + + eq( + IncomingLike::new_channel(DecodedLength::new(4), /*wanter =*/ false).1, + SizeHint::with_exact(4), + "channel with length", + ); + } + + #[tokio::test] + async fn channel_abort() { + let (tx, mut rx) = IncomingLike::channel(); + + tx.abort(); + + match rx.frame().await.unwrap() { + Err(RpxyError::HyperNewBodyWriteAborted) => true, + unexpected => panic!("unexpected: {:?}", unexpected), + }; + } + + #[tokio::test] + async fn channel_abort_when_buffer_is_full() { + let (mut tx, mut rx) = IncomingLike::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + // buffer is full, but can still send abort + tx.abort(); + + let chunk1 = rx.frame().await.expect("item 1").expect("chunk 1").into_data().unwrap(); + assert_eq!(chunk1, "chunk 1"); + + match rx.frame().await.unwrap() { + Err(RpxyError::HyperNewBodyWriteAborted) => true, + unexpected => panic!("unexpected: {:?}", unexpected), + }; + } + + #[test] + fn channel_buffers_one() { + let (mut tx, _rx) = IncomingLike::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + + // buffer is now full + let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2"); + assert_eq!(chunk2, "chunk 2"); + } + + #[tokio::test] + async fn channel_empty() { + let (_, mut rx) = IncomingLike::channel(); + + assert!(rx.frame().await.is_none()); + } + + #[test] + fn channel_ready() { + let (mut tx, _rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ false); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!(tx_ready.poll().is_ready(), "tx is ready immediately"); + } + + #[test] + fn channel_wanter() { + let (mut tx, mut rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + let mut rx_data = tokio_test::task::spawn(rx.frame()); + + assert!(tx_ready.poll().is_pending(), "tx isn't ready before rx has been polled"); + + assert!(rx_data.poll().is_pending(), "poll rx.data"); + assert!(tx_ready.is_woken(), "rx poll wakes tx"); + + assert!(tx_ready.poll().is_ready(), "tx is ready after rx has been polled"); + } + + #[test] + + fn channel_notices_closure() { + let (mut tx, rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!(tx_ready.poll().is_pending(), "tx isn't ready before rx has been polled"); + + drop(rx); + assert!(tx_ready.is_woken(), "dropping rx wakes tx"); + + match tx_ready.poll() { + Poll::Ready(Err(RpxyError::HyperIncomingLikeNewClosed)) => (), + unexpected => panic!("tx poll ready unexpected: {:?}", unexpected), + } + } }