diff --git a/Cargo.toml b/Cargo.toml index f8515ea..f2858bd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,8 +17,8 @@ http3 = ["quinn", "h3", "h3-quinn"] [dependencies] env_logger = "0.10.0" -anyhow = "1.0.66" -clap = { version = "4.0.29", features = ["std", "cargo", "wrap_help"] } +anyhow = "1.0.68" +clap = { version = "4.1.1", features = ["std", "cargo", "wrap_help"] } futures = { version = "0.3.25", features = ["alloc", "async-await"] } hyper = { version = "0.14.23", default-features = false, features = [ "server", @@ -27,7 +27,7 @@ hyper = { version = "0.14.23", default-features = false, features = [ "stream", ] } log = "0.4.17" -tokio = { version = "1.23.0", default-features = false, features = [ +tokio = { version = "1.24.2", default-features = false, features = [ "net", "rt-multi-thread", "parking_lot", @@ -36,12 +36,12 @@ tokio = { version = "1.23.0", default-features = false, features = [ "macros", ] } tokio-rustls = { version = "0.23.4", features = ["early-data"] } -rustls-pemfile = "1.0.1" -rustls = { version = "0.20.7", default-features = false } +rustls-pemfile = "1.0.2" +rustls = { version = "0.20.8", default-features = false } rand = "0.8.5" -toml = { version = "0.5.9", default-features = false } +toml = { version = "0.5.11", default-features = false } rustc-hash = "1.1.0" -serde = { version = "1.0.150", default-features = false, features = ["derive"] } +serde = { version = "1.0.152", default-features = false, features = ["derive"] } hyper-rustls = { version = "0.23.2", default-features = false, features = [ "tokio-runtime", "webpki-tokio", @@ -49,10 +49,10 @@ hyper-rustls = { version = "0.23.2", default-features = false, features = [ "http2", ] } bytes = "1.3.0" -quinn = { version = "0.8.5", optional = true } +quinn = { version = "0.9.3", optional = true } h3 = { path = "./h3/h3/", optional = true } -h3-quinn = { path = "./h3/h3-quinn/", optional = true } -thiserror = "1.0.37" +h3-quinn = { path = "./h3-quinn/", optional = true } +thiserror = "1.0.38" x509-parser = "0.14.0" derive_builder = "0.12.0" diff --git a/h3 b/h3 index 885144b..bf6320e 160000 --- a/h3 +++ b/h3 @@ -1 +1 @@ -Subproject commit 885144bfcb0414bcb86b63cec9c65a611c7e56ee +Subproject commit bf6320e8a82eff4bb12b80da11f714ea643861f4 diff --git a/h3-quinn/.gitignore b/h3-quinn/.gitignore new file mode 100644 index 0000000..1e7caa9 --- /dev/null +++ b/h3-quinn/.gitignore @@ -0,0 +1,2 @@ +Cargo.lock +target/ diff --git a/h3-quinn/Cargo.toml b/h3-quinn/Cargo.toml new file mode 100644 index 0000000..9c28e3e --- /dev/null +++ b/h3-quinn/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "h3-quinn" +version = "0.0.0" +authors = ["Original from hyperium/h3", "Jun Kurihara"] +edition = "2021" +publish = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +h3 = { path = "../h3/h3/" } +bytes = "1.3.0" +futures-util = { version = "0.3.25", default-features = false, features = [ + "io", +] } +quinn = { version = "0.9.3", default-features = false } +quinn-proto = { version = "0.9.2", default-features = false } diff --git a/h3-quinn/src/lib.rs b/h3-quinn/src/lib.rs new file mode 100644 index 0000000..26f1327 --- /dev/null +++ b/h3-quinn/src/lib.rs @@ -0,0 +1,405 @@ +//! [`h3::quic`] traits implemented with Quinn +//! Copied from the original pull request at `hyperium/h3`: https://github.com/hyperium/h3/pull/145. +//! Currently maintained by Jun. + +#![deny(missing_docs)] + +use std::{ + convert::TryInto, + fmt::{self, Display}, + sync::Arc, + task::{Context, Poll}, +}; + +use bytes::{Buf, Bytes}; +use futures_util::{ready, FutureExt}; +use h3::quic::{self, Error, StreamId, WriteBuf}; +use quinn::VarInt; + +pub use quinn::{self}; + +/// QUIC connection +/// +/// A [`quic::Connection`] backed by [`quinn::Connection`]. +pub struct Connection { + conn: quinn::Connection, +} + +impl Connection { + /// Create a [`Connection`] from a [`quinn::Connection`] + pub fn new(conn: quinn::Connection) -> Self { + Self { conn } + } +} + +impl quic::Connection for Connection { + type OpenStreams = OpenStreams; + type BidiStream = BidiStream; + type SendStream = SendStream; + type RecvStream = RecvStream; + type Error = ConnectionError; + + fn poll_accept_bidi(&mut self, cx: &mut Context<'_>) -> Poll, Self::Error>> { + Poll::Ready(match ready!(Box::pin(self.conn.accept_bi()).poll_unpin(cx)) { + Ok((send, recv)) => Ok(Some(Self::BidiStream::new( + Self::SendStream::new(send), + Self::RecvStream::new(recv), + ))), + Err(e) => Err(ConnectionError(e)), + }) + } + + fn poll_accept_recv(&mut self, cx: &mut Context<'_>) -> Poll, Self::Error>> { + Poll::Ready(match ready!(Box::pin(self.conn.accept_uni()).poll_unpin(cx)) { + Ok(recv) => Ok(Some(Self::RecvStream::new(recv))), + Err(e) => Err(ConnectionError(e)), + }) + } + + fn poll_open_bidi(&mut self, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(match ready!(Box::pin(self.conn.open_bi()).poll_unpin(cx)) { + Ok((send, recv)) => Ok(Self::BidiStream::new( + Self::SendStream::new(send), + Self::RecvStream::new(recv), + )), + Err(e) => Err(ConnectionError(e)), + }) + } + + fn poll_open_send(&mut self, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(match ready!(Box::pin(self.conn.open_uni()).poll_unpin(cx)) { + Ok(send) => Ok(Self::SendStream::new(send)), + Err(e) => Err(ConnectionError(e)), + }) + } + + fn opener(&self) -> Self::OpenStreams { + Self::OpenStreams { + conn: self.conn.clone(), + } + } + + fn close(&mut self, code: h3::error::Code, reason: &[u8]) { + self + .conn + .close(VarInt::from_u64(code.value()).expect("Invalid error code"), reason); + } +} + +/// Stream opener +/// +/// Implements [`quic::OpenStreams`]. +pub struct OpenStreams { + conn: quinn::Connection, +} + +impl quic::OpenStreams for OpenStreams { + type BidiStream = BidiStream; + type SendStream = SendStream; + type RecvStream = RecvStream; + type Error = ConnectionError; + + fn poll_open_bidi(&mut self, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(match ready!(Box::pin(self.conn.open_bi()).poll_unpin(cx)) { + Ok((send, recv)) => Ok(Self::BidiStream::new( + Self::SendStream::new(send), + Self::RecvStream::new(recv), + )), + Err(e) => Err(ConnectionError(e)), + }) + } + + fn poll_open_send(&mut self, cx: &mut Context<'_>) -> Poll> { + Poll::Ready(match ready!(Box::pin(self.conn.open_uni()).poll_unpin(cx)) { + Ok(send) => Ok(Self::SendStream::new(send)), + Err(e) => Err(ConnectionError(e)), + }) + } + + fn close(&mut self, code: h3::error::Code, reason: &[u8]) { + self + .conn + .close(VarInt::from_u64(code.value()).expect("Invalid error code"), reason); + } +} + +impl Clone for OpenStreams { + fn clone(&self) -> Self { + Self { + conn: self.conn.clone(), + } + } +} + +/// Stream that can be used to send and receive data +/// +/// A [`quic::BidiStream`], which can be split into one send-only and +/// one receive-only stream. +pub struct BidiStream { + send: SendStream, + recv: RecvStream, +} + +impl BidiStream { + fn new(send: SendStream, recv: RecvStream) -> Self { + Self { send, recv } + } +} + +impl quic::BidiStream for BidiStream { + type SendStream = SendStream; + type RecvStream = RecvStream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + (self.send, self.recv) + } +} + +impl quic::SendStream for BidiStream { + type Error = SendError; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.send.poll_ready(cx) + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + self.send.send_data(data) + } + + fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + self.send.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.send.reset(reset_code) + } + + fn id(&self) -> StreamId { + self.send.id() + } +} + +impl quic::RecvStream for BidiStream { + type Buf = Bytes; + type Error = RecvError; + + fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll, Self::Error>> { + self.recv.poll_data(cx) + } + + fn stop_sending(&mut self, err_code: u64) { + self.recv.stop_sending(err_code) + } +} + +/// Send-only stream +/// +/// A [`quic::SendStream`] backed by [`quinn::SendStream`]. +pub struct SendStream { + stream: quinn::SendStream, + writing: Option>, +} + +impl SendStream { + fn new(stream: quinn::SendStream) -> Self { + Self { stream, writing: None } + } +} + +impl quic::SendStream for SendStream { + type Error = SendError; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + if let Some(ref mut data) = self.writing { + while data.has_remaining() { + match ready!({ + let mut write_fut = Box::pin(self.stream.write(data.chunk())); + write_fut.poll_unpin(cx) + }) { + Ok(cnt) => data.advance(cnt), + Err(e) => return Poll::Ready(Err(Self::Error::Write(e))), + } + } + } + + self.writing = None; + Poll::Ready(Ok(())) + } + + fn send_data>>(&mut self, data: T) -> Result<(), Self::Error> { + match self.writing { + Some(_) => Err(Self::Error::NotReady), + None => { + self.writing = Some(data.into()); + Ok(()) + } + } + } + + fn poll_finish(&mut self, cx: &mut Context<'_>) -> Poll> { + Box::pin(self.stream.finish()).poll_unpin(cx).map_err(Into::into) + } + + fn reset(&mut self, reset_code: u64) { + let _ = self.stream.reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX)); + } + + fn id(&self) -> StreamId { + self.stream.id().0.try_into().expect("Invalid stream id") + } +} + +/// Receive-only stream +/// +/// A [`quic::RecvStream`] backed by [`quinn::RecvStream`]. +pub struct RecvStream { + stream: quinn::RecvStream, +} + +impl RecvStream { + fn new(stream: quinn::RecvStream) -> Self { + Self { stream } + } +} + +impl quic::RecvStream for RecvStream { + type Buf = Bytes; + type Error = RecvError; + + fn poll_data(&mut self, cx: &mut Context<'_>) -> Poll, Self::Error>> { + let data = ready!(Box::pin(self.stream.read_chunk(usize::MAX, true)).poll_unpin(cx))?; + Poll::Ready(Ok(data.map(|ch| ch.bytes))) + } + + fn stop_sending(&mut self, err_code: u64) { + let _ = self + .stream + .stop(VarInt::from_u64(err_code).expect("Invalid error code")); + } +} + +/// The error type for [`quic::Connection::Error`] +/// +/// Used by [`Connection`]. +#[derive(Debug)] +pub struct ConnectionError(quinn::ConnectionError); + +impl Error for ConnectionError { + fn is_timeout(&self) -> bool { + matches!(self.0, quinn::ConnectionError::TimedOut) + } + + fn err_code(&self) -> Option { + match self.0 { + quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { error_code, .. }) => { + Some(error_code.into_inner()) + } + _ => None, + } + } +} + +impl std::error::Error for ConnectionError {} + +impl Display for ConnectionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl From for ConnectionError { + fn from(e: quinn::ConnectionError) -> Self { + Self(e) + } +} + +/// The error type for [`quic::SendStream::Error`] +/// +/// Used by [`SendStream`] and [`BidiStream`]. +#[derive(Debug)] +pub enum SendError { + /// For write errors, wrapping a [`quinn::WriteError`] + Write(quinn::WriteError), + /// For trying to send when stream is not ready, because it is + /// still sending data from the previous call + NotReady, +} + +impl Error for SendError { + fn is_timeout(&self) -> bool { + matches!( + self, + Self::Write(quinn::WriteError::ConnectionLost(quinn::ConnectionError::TimedOut)) + ) + } + + fn err_code(&self) -> Option { + match self { + Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()), + Self::Write(quinn::WriteError::ConnectionLost(quinn::ConnectionError::ApplicationClosed( + quinn_proto::ApplicationClose { error_code, .. }, + ))) => Some(error_code.into_inner()), + _ => None, + } + } +} + +impl std::error::Error for SendError {} + +impl Display for SendError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:?}", self) + } +} + +impl From for SendError { + fn from(e: quinn::WriteError) -> Self { + Self::Write(e) + } +} + +/// The error type for [`quic::RecvStream::Error`] +/// +/// Used by [`RecvStream`] and [`BidiStream`]. +#[derive(Debug)] +pub struct RecvError(quinn::ReadError); + +impl Error for RecvError { + fn is_timeout(&self) -> bool { + matches!( + self.0, + quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut), + ) + } + + fn err_code(&self) -> Option { + match self.0 { + quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { + error_code, + .. + })) => Some(error_code.into_inner()), + quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()), + _ => None, + } + } +} + +impl std::error::Error for RecvError {} + +impl fmt::Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl From for Arc { + fn from(e: RecvError) -> Self { + Arc::new(e) + } +} + +impl From for RecvError { + fn from(e: quinn::ReadError) -> Self { + Self(e) + } +} diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 416ca8a..407a30f 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -15,8 +15,6 @@ use tokio::{ time::{sleep, timeout, Duration}, }; -#[cfg(feature = "http3")] -use futures::StreamExt; #[cfg(feature = "http3")] use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; @@ -136,12 +134,13 @@ where let mut server_config_h3 = QuicServerConfig::with_crypto(Arc::new(rustls_server_config)); server_config_h3.transport = Arc::new(transport_config_quic); server_config_h3.concurrent_connections(self.globals.h3_max_concurrent_connections); - let (endpoint, mut incoming) = Endpoint::server(server_config_h3, self.listening_on)?; + // let (endpoint, mut incoming) = Endpoint::server(server_config_h3, self.listening_on)?; + let endpoint = Endpoint::server(server_config_h3, self.listening_on)?; let mut server_crypto: Option> = None; loop { tokio::select! { - new_conn = incoming.next() => { + new_conn = endpoint.accept() => { if server_crypto.is_none() || new_conn.is_none() { continue; }