From b639e79b4d3eb969622091f62389ad02861e5f55 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 18 Nov 2023 14:42:13 +0900 Subject: [PATCH 01/50] wip: implemented hyper-1.0 for http/1.1 and http/2. todo: http/3 and backend handler --- .gitmodules | 6 - rpxy-bin/Cargo.toml | 2 +- rpxy-lib/Cargo.toml | 33 +- rpxy-lib/src/globals.rs | 3 + rpxy-lib/src/handler/error.rs | 16 + rpxy-lib/src/handler/handler_main.rs | 608 ++++++++++---------- rpxy-lib/src/handler/mod.rs | 12 +- rpxy-lib/src/hyper_executor.rs | 45 ++ rpxy-lib/src/lib.rs | 22 +- rpxy-lib/src/proxy/mod.rs | 29 + rpxy-lib/src/proxy/proxy_h3.rs | 99 ++-- rpxy-lib/src/proxy/proxy_main.rs | 160 +++--- rpxy-lib/src/proxy/proxy_quic_quinn.rs | 6 +- rpxy-lib/src/proxy/proxy_quic_s2n.rs | 16 +- rpxy-lib/src/proxy/proxy_tls.rs | 44 +- submodules/h3 | 2 +- submodules/h3-quinn/Cargo.toml | 24 - submodules/h3-quinn/src/lib.rs | 740 ------------------------- submodules/quinn | 1 - submodules/s2n-quic | 1 - submodules/s2n-quic-h3/Cargo.toml | 17 + submodules/s2n-quic-h3/README.md | 10 + submodules/s2n-quic-h3/src/lib.rs | 7 + submodules/s2n-quic-h3/src/s2n_quic.rs | 506 +++++++++++++++++ 24 files changed, 1134 insertions(+), 1275 deletions(-) create mode 100644 rpxy-lib/src/handler/error.rs create mode 100644 rpxy-lib/src/hyper_executor.rs delete mode 100644 submodules/h3-quinn/Cargo.toml delete mode 100644 submodules/h3-quinn/src/lib.rs delete mode 160000 submodules/quinn delete mode 160000 submodules/s2n-quic create mode 100644 submodules/s2n-quic-h3/Cargo.toml create mode 100644 submodules/s2n-quic-h3/README.md create mode 100644 submodules/s2n-quic-h3/src/lib.rs create mode 100644 submodules/s2n-quic-h3/src/s2n_quic.rs diff --git a/.gitmodules b/.gitmodules index 65fcd3b..47ebad0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,12 +1,6 @@ [submodule "submodules/h3"] path = submodules/h3 url = git@github.com:junkurihara/h3.git -[submodule "submodules/quinn"] - path = submodules/quinn - url = 
git@github.com:junkurihara/quinn.git -[submodule "submodules/s2n-quic"] - path = submodules/s2n-quic - url = git@github.com:junkurihara/s2n-quic.git [submodule "submodules/rusty-http-cache-semantics"] path = submodules/rusty-http-cache-semantics url = git@github.com:junkurihara/rusty-http-cache-semantics.git diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index 36c53b1..1848e5e 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -12,7 +12,7 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-quinn", "cache"] +default = ["http3-s2n", "cache"] http3-quinn = ["rpxy-lib/http3-quinn"] http3-s2n = ["rpxy-lib/http3-s2n"] cache = ["rpxy-lib/cache"] diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 7f10e60..b4b475d 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -12,7 +12,7 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-quinn", "sticky-cookie", "cache"] +default = ["http3-s2n", "sticky-cookie", "cache"] http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] sticky-cookie = ["base64", "sha2", "chrono"] @@ -25,7 +25,7 @@ rustc-hash = "1.1.0" bytes = "1.5.0" derive_builder = "0.12.0" futures = { version = "0.3.29", features = ["alloc", "async-await"] } -tokio = { version = "1.33.0", default-features = false, features = [ +tokio = { version = "1.34.0", default-features = false, features = [ "net", "rt-multi-thread", "time", @@ -41,12 +41,10 @@ anyhow = "1.0.75" thiserror = "1.0.50" # http and tls -hyper = { version = "0.14.27", default-features = false, features = [ - "server", - "http1", - "http2", - "stream", -] } +http = "1.0.0" +http-body-util = "0.1.0" +hyper = { version = "1.0.1", default-features = false } +hyper-util = { version = "0.1.0", features = ["full"] } hyper-rustls = { 
version = "0.24.2", default-features = false, features = [ "tokio-runtime", "webpki-tokio", @@ -54,7 +52,7 @@ hyper-rustls = { version = "0.24.2", default-features = false, features = [ "http2", ] } tokio-rustls = { version = "0.24.1", features = ["early-data"] } -rustls = { version = "0.21.8", default-features = false } +rustls = { version = "0.21.9", default-features = false } webpki = "0.22.4" x509-parser = "0.15.1" @@ -62,18 +60,16 @@ x509-parser = "0.15.1" tracing = { version = "0.1.40" } # http/3 -# quinn = { version = "0.9.3", optional = true } -quinn = { path = "../submodules/quinn/quinn", optional = true } # Tentative to support rustls-0.21 +quinn = { version = "0.10.2", optional = true } h3 = { path = "../submodules/h3/h3/", optional = true } -# h3-quinn = { path = "./h3/h3-quinn/", optional = true } -h3-quinn = { path = "../submodules/h3-quinn/", optional = true } # Tentative to support rustls-0.21 -# for UDP socket wit SO_REUSEADDR when h3 with quinn -socket2 = { version = "0.5.5", features = ["all"], optional = true } -s2n-quic = { path = "../submodules/s2n-quic/quic/s2n-quic/", default-features = false, features = [ +h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } +s2n-quic = { version = "1.31.0", default-features = false, features = [ "provider-tls-rustls", ], optional = true } -s2n-quic-h3 = { path = "../submodules/s2n-quic/quic/s2n-quic-h3/", optional = true } -s2n-quic-rustls = { path = "../submodules/s2n-quic/quic/s2n-quic-rustls/", optional = true } +s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } +s2n-quic-rustls = { version = "0.31.0", optional = true } +# for UDP socket with SO_REUSEADDR when h3 with quinn +socket2 = { version = "0.5.5", features = ["all"], optional = true } # cache http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } @@ -90,3 +86,4 @@ sha2 = { version = "0.10.8", default-features = false, optional = true } [dev-dependencies] +# http and tls 
diff --git a/rpxy-lib/src/globals.rs b/rpxy-lib/src/globals.rs index d1c0130..02605a6 100644 --- a/rpxy-lib/src/globals.rs +++ b/rpxy-lib/src/globals.rs @@ -33,6 +33,9 @@ where /// Shared context - Async task runtime handler pub runtime_handle: tokio::runtime::Handle, + + /// Shared context - Notify object to stop async tasks + pub term_notify: Option>, } /// Configuration parameters for proxy transport and request handlers diff --git a/rpxy-lib/src/handler/error.rs b/rpxy-lib/src/handler/error.rs new file mode 100644 index 0000000..8fb9d79 --- /dev/null +++ b/rpxy-lib/src/handler/error.rs @@ -0,0 +1,16 @@ +use http::StatusCode; +use thiserror::Error; + +pub type HttpResult = std::result::Result; + +/// Describes things that can go wrong in the handler +#[derive(Debug, Error)] +pub enum HttpError {} + +impl From for StatusCode { + fn from(e: HttpError) -> StatusCode { + match e { + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} diff --git a/rpxy-lib/src/handler/handler_main.rs b/rpxy-lib/src/handler/handler_main.rs index 8b13dc7..2720c2f 100644 --- a/rpxy-lib/src/handler/handler_main.rs +++ b/rpxy-lib/src/handler/handler_main.rs @@ -1,9 +1,10 @@ // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy use super::{ - forwarder::{ForwardRequest, Forwarder}, + error::*, + // forwarder::{ForwardRequest, Forwarder}, utils_headers::*, utils_request::*, - utils_synth_response::*, + // utils_synth_response::*, HandlerContext, }; use crate::{ @@ -16,365 +17,368 @@ use crate::{ utils::ServerNameBytesExp, }; use derive_builder::Builder; -use hyper::{ - client::connect::Connect, +use http::{ header::{self, HeaderValue}, - http::uri::Scheme, - Body, Request, Response, StatusCode, Uri, Version, + uri::Scheme, + Request, Response, StatusCode, Uri, Version, }; +use hyper::body::Incoming; +use hyper_util::client::legacy::connect::Connect; use std::{net::SocketAddr, sync::Arc}; use tokio::{io::copy_bidirectional, time::timeout}; #[derive(Clone, Builder)] /// 
HTTP message handler for requests from clients and responses from backend applications, /// responsible to manipulate and forward messages to upstream backends and downstream clients. -pub struct HttpMessageHandler +// pub struct HttpMessageHandler +pub struct HttpMessageHandler where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone, { - forwarder: Arc>, + // forwarder: Arc>, globals: Arc>, } -impl HttpMessageHandler +impl HttpMessageHandler where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone, { - /// Return with an arbitrary status code of error and log message - fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result> { - log_data.status_code(&status_code).output(); - http_error(status_code) - } + // /// Return with an arbitrary status code of error and log message + // fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result> { + // log_data.status_code(&status_code).output(); + // http_error(status_code) + // } /// Handle incoming request message from a client pub async fn handle_request( &self, - mut req: Request, + mut req: Request, client_addr: SocketAddr, // アクセス制御用 listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, - ) -> Result> { + ) -> Result>> { //////// let mut log_data = MessageLog::from(&req); log_data.client_addr(&client_addr); ////// - // Here we start to handle with server_name - let server_name = if let Ok(v) = req.parse_host() { - ServerNameBytesExp::from(v) - } else { - return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); - }; - // check consistency of between TLS SNI and HOST/Request URI Line. 
- #[allow(clippy::collapsible_if)] - if tls_enabled && self.globals.proxy_config.sni_consistency { - if server_name != tls_server_name.unwrap_or_default() { - return self.return_with_error_log(StatusCode::MISDIRECTED_REQUEST, &mut log_data); - } - } - // Find backend application for given server_name, and drop if incoming request is invalid as request. - let backend = match self.globals.backends.apps.get(&server_name) { - Some(be) => be, - None => { - let Some(default_server_name) = &self.globals.backends.default_server_name_bytes else { - return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); - }; - debug!("Serving by default app"); - self.globals.backends.apps.get(default_server_name).unwrap() - } - }; + // // Here we start to handle with server_name + // let server_name = if let Ok(v) = req.parse_host() { + // ServerNameBytesExp::from(v) + // } else { + // return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); + // }; + // // check consistency of between TLS SNI and HOST/Request URI Line. + // #[allow(clippy::collapsible_if)] + // if tls_enabled && self.globals.proxy_config.sni_consistency { + // if server_name != tls_server_name.unwrap_or_default() { + // return self.return_with_error_log(StatusCode::MISDIRECTED_REQUEST, &mut log_data); + // } + // } + // // Find backend application for given server_name, and drop if incoming request is invalid as request. 
+ // let backend = match self.globals.backends.apps.get(&server_name) { + // Some(be) => be, + // None => { + // let Some(default_server_name) = &self.globals.backends.default_server_name_bytes else { + // return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + // }; + // debug!("Serving by default app"); + // self.globals.backends.apps.get(default_server_name).unwrap() + // } + // }; - // Redirect to https if !tls_enabled and redirect_to_https is true - if !tls_enabled && backend.https_redirection.unwrap_or(false) { - debug!("Redirect to secure connection: {}", &backend.server_name); - log_data.status_code(&StatusCode::PERMANENT_REDIRECT).output(); - return secure_redirection(&backend.server_name, self.globals.proxy_config.https_port, &req); - } + // // Redirect to https if !tls_enabled and redirect_to_https is true + // if !tls_enabled && backend.https_redirection.unwrap_or(false) { + // debug!("Redirect to secure connection: {}", &backend.server_name); + // log_data.status_code(&StatusCode::PERMANENT_REDIRECT).output(); + // return secure_redirection(&backend.server_name, self.globals.proxy_config.https_port, &req); + // } - // Find reverse proxy for given path and choose one of upstream host - // Longest prefix match - let path = req.uri().path(); - let Some(upstream_group) = backend.reverse_proxy.get(path) else { - return self.return_with_error_log(StatusCode::NOT_FOUND, &mut log_data) - }; + // // Find reverse proxy for given path and choose one of upstream host + // // Longest prefix match + // let path = req.uri().path(); + // let Some(upstream_group) = backend.reverse_proxy.get(path) else { + // return self.return_with_error_log(StatusCode::NOT_FOUND, &mut log_data); + // }; - // Upgrade in request header - let upgrade_in_request = extract_upgrade(req.headers()); - let request_upgraded = req.extensions_mut().remove::(); + // // Upgrade in request header + // let upgrade_in_request = extract_upgrade(req.headers()); + // let 
request_upgraded = req.extensions_mut().remove::(); - // Build request from destination information - let _context = match self.generate_request_forwarded( - &client_addr, - &listen_addr, - &mut req, - &upgrade_in_request, - upstream_group, - tls_enabled, - ) { - Err(e) => { - error!("Failed to generate destination uri for reverse proxy: {}", e); - return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); - } - Ok(v) => v, - }; - debug!("Request to be forwarded: {:?}", req); - log_data.xff(&req.headers().get("x-forwarded-for")); - log_data.upstream(req.uri()); - ////// + // // Build request from destination information + // let _context = match self.generate_request_forwarded( + // &client_addr, + // &listen_addr, + // &mut req, + // &upgrade_in_request, + // upstream_group, + // tls_enabled, + // ) { + // Err(e) => { + // error!("Failed to generate destination uri for reverse proxy: {}", e); + // return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + // } + // Ok(v) => v, + // }; + // debug!("Request to be forwarded: {:?}", req); + // log_data.xff(&req.headers().get("x-forwarded-for")); + // log_data.upstream(req.uri()); + // ////// - // Forward request to a chosen backend - let mut res_backend = { - let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else { - return self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data); - }; - match result { - Ok(res) => res, - Err(e) => { - error!("Failed to get response from backend: {}", e); - return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); - } - } - }; + // // Forward request to a chosen backend + // let mut res_backend = { + // let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else { + // return self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data); + // }; + // match result { + // Ok(res) => res, + 
// Err(e) => { + // error!("Failed to get response from backend: {}", e); + // return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + // } + // } + // }; - // Process reverse proxy context generated during the forwarding request generation. - #[cfg(feature = "sticky-cookie")] - if let Some(context_from_lb) = _context.context_lb { - let res_headers = res_backend.headers_mut(); - if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { - error!("Failed to append context to the response given from backend: {}", e); - return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); - } - } + // // Process reverse proxy context generated during the forwarding request generation. + // #[cfg(feature = "sticky-cookie")] + // if let Some(context_from_lb) = _context.context_lb { + // let res_headers = res_backend.headers_mut(); + // if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { + // error!("Failed to append context to the response given from backend: {}", e); + // return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); + // } + // } - if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { - // Generate response to client - if self.generate_response_forwarded(&mut res_backend, backend).is_err() { - return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - } - log_data.status_code(&res_backend.status()).output(); - return Ok(res_backend); - } + // if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { + // // Generate response to client + // if self.generate_response_forwarded(&mut res_backend, backend).is_err() { + // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); + // } + // log_data.status_code(&res_backend.status()).output(); + // return Ok(res_backend); + // } - // Handle StatusCode::SWITCHING_PROTOCOLS in response - let upgrade_in_response = extract_upgrade(res_backend.headers()); - 
let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) - { - u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() - } else { - false - }; - if !should_upgrade { - error!( - "Backend tried to switch to protocol {:?} when {:?} was requested", - upgrade_in_response, upgrade_in_request - ); - return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - } - let Some(request_upgraded) = request_upgraded else { - error!("Request does not have an upgrade extension"); - return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); - }; - let Some(onupgrade) = res_backend.extensions_mut().remove::() else { - error!("Response does not have an upgrade extension"); - return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - }; + // // Handle StatusCode::SWITCHING_PROTOCOLS in response + // let upgrade_in_response = extract_upgrade(res_backend.headers()); + // let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) + // { + // u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() + // } else { + // false + // }; + // if !should_upgrade { + // error!( + // "Backend tried to switch to protocol {:?} when {:?} was requested", + // upgrade_in_response, upgrade_in_request + // ); + // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); + // } + // let Some(request_upgraded) = request_upgraded else { + // error!("Request does not have an upgrade extension"); + // return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); + // }; + // let Some(onupgrade) = res_backend.extensions_mut().remove::() else { + // error!("Response does not have an upgrade extension"); + // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); + // }; - self.globals.runtime_handle.spawn(async move { - let mut response_upgraded = 
onupgrade.await.map_err(|e| { - error!("Failed to upgrade response: {}", e); - RpxyError::Hyper(e) - })?; - let mut request_upgraded = request_upgraded.await.map_err(|e| { - error!("Failed to upgrade request: {}", e); - RpxyError::Hyper(e) - })?; - copy_bidirectional(&mut response_upgraded, &mut request_upgraded) - .await - .map_err(|e| { - error!("Coping between upgraded connections failed: {}", e); - RpxyError::Io(e) - })?; - Ok(()) as Result<()> - }); - log_data.status_code(&res_backend.status()).output(); - Ok(res_backend) + // self.globals.runtime_handle.spawn(async move { + // let mut response_upgraded = onupgrade.await.map_err(|e| { + // error!("Failed to upgrade response: {}", e); + // RpxyError::Hyper(e) + // })?; + // let mut request_upgraded = request_upgraded.await.map_err(|e| { + // error!("Failed to upgrade request: {}", e); + // RpxyError::Hyper(e) + // })?; + // copy_bidirectional(&mut response_upgraded, &mut request_upgraded) + // .await + // .map_err(|e| { + // error!("Copying between upgraded connections failed: {}", e); + // RpxyError::Io(e) + // })?; + // Ok(()) as Result<()> + // }); + // log_data.status_code(&res_backend.status()).output(); + // Ok(res_backend) + todo!() } //////////////////////////////////////////////////// // Functions to generate messages //////////////////////////////////////////////////// - /// Manipulate a response message sent from a backend application to forward downstream to a client. - fn generate_response_forwarded(&self, response: &mut Response, chosen_backend: &Backend) -> Result<()> - where - B: core::fmt::Debug, - { - let headers = response.headers_mut(); - remove_connection_header(headers); - remove_hop_header(headers); - add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; + // /// Manipulate a response message sent from a backend application to forward downstream to a client. 
+ // fn generate_response_forwarded(&self, response: &mut Response, chosen_backend: &Backend) -> Result<()> + // where + // B: core::fmt::Debug, + // { + // let headers = response.headers_mut(); + // remove_connection_header(headers); + // remove_hop_header(headers); + // add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; - #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] - { - // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled - // TODO: This is a workaround for avoiding a client authentication in HTTP/3 - if self.globals.proxy_config.http3 - && chosen_backend - .crypto_source - .as_ref() - .is_some_and(|v| !v.is_mutual_tls()) - { - if let Some(port) = self.globals.proxy_config.https_port { - add_header_entry_overwrite_if_exist( - headers, - header::ALT_SVC.as_str(), - format!( - "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", - port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age - ), - )?; - } - } else { - // remove alt-svc to disallow requests via http3 - headers.remove(header::ALT_SVC.as_str()); - } - } - #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] - { - if let Some(port) = self.globals.proxy_config.https_port { - headers.remove(header::ALT_SVC.as_str()); - } - } + // #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + // { + // // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled + // // TODO: This is a workaround for avoiding a client authentication in HTTP/3 + // if self.globals.proxy_config.http3 + // && chosen_backend + // .crypto_source + // .as_ref() + // .is_some_and(|v| !v.is_mutual_tls()) + // { + // if let Some(port) = self.globals.proxy_config.https_port { + // add_header_entry_overwrite_if_exist( + // headers, + // header::ALT_SVC.as_str(), + // format!( + // "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", + // port, self.globals.proxy_config.h3_alt_svc_max_age, 
port, self.globals.proxy_config.h3_alt_svc_max_age + // ), + // )?; + // } + // } else { + // // remove alt-svc to disallow requests via http3 + // headers.remove(header::ALT_SVC.as_str()); + // } + // } + // #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] + // { + // if let Some(port) = self.globals.proxy_config.https_port { + // headers.remove(header::ALT_SVC.as_str()); + // } + // } - Ok(()) - } + // Ok(()) + // } - #[allow(clippy::too_many_arguments)] - /// Manipulate a request message sent from a client to forward upstream to a backend application - fn generate_request_forwarded( - &self, - client_addr: &SocketAddr, - listen_addr: &SocketAddr, - req: &mut Request, - upgrade: &Option, - upstream_group: &UpstreamGroup, - tls_enabled: bool, - ) -> Result { - debug!("Generate request to be forwarded"); + // #[allow(clippy::too_many_arguments)] + // /// Manipulate a request message sent from a client to forward upstream to a backend application + // fn generate_request_forwarded( + // &self, + // client_addr: &SocketAddr, + // listen_addr: &SocketAddr, + // req: &mut Request, + // upgrade: &Option, + // upstream_group: &UpstreamGroup, + // tls_enabled: bool, + // ) -> Result { + // debug!("Generate request to be forwarded"); - // Add te: trailer if contained in original request - let contains_te_trailers = { - if let Some(te) = req.headers().get(header::TE) { - te.as_bytes() - .split(|v| v == &b',' || v == &b' ') - .any(|x| x == "trailers".as_bytes()) - } else { - false - } - }; + // // Add te: trailer if contained in original request + // let contains_te_trailers = { + // if let Some(te) = req.headers().get(header::TE) { + // te.as_bytes() + // .split(|v| v == &b',' || v == &b' ') + // .any(|x| x == "trailers".as_bytes()) + // } else { + // false + // } + // }; - let uri = req.uri().to_string(); - let headers = req.headers_mut(); - // delete headers specified in header.connection - remove_connection_header(headers); - // delete hop headers 
including header.connection - remove_hop_header(headers); - // X-Forwarded-For - add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?; + // let uri = req.uri().to_string(); + // let headers = req.headers_mut(); + // // delete headers specified in header.connection + // remove_connection_header(headers); + // // delete hop headers including header.connection + // remove_hop_header(headers); + // // X-Forwarded-For + // add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?; - // Add te: trailer if te_trailer - if contains_te_trailers { - headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap()); - } + // // Add te: trailer if te_trailer + // if contains_te_trailers { + // headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap()); + // } - // add "host" header of original server_name if not exist (default) - if req.headers().get(header::HOST).is_none() { - let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); - req - .headers_mut() - .insert(header::HOST, HeaderValue::from_str(&org_host)?); - }; + // // add "host" header of original server_name if not exist (default) + // if req.headers().get(header::HOST).is_none() { + // let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); + // req + // .headers_mut() + // .insert(header::HOST, HeaderValue::from_str(&org_host)?); + // }; - ///////////////////////////////////////////// - // Fix unique upstream destination since there could be multiple ones. - #[cfg(feature = "sticky-cookie")] - let (upstream_chosen_opt, context_from_lb) = { - let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { - takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? 
- } else { - None - }; - upstream_group.get(&context_to_lb) - }; - #[cfg(not(feature = "sticky-cookie"))] - let (upstream_chosen_opt, _) = upstream_group.get(&None); + // ///////////////////////////////////////////// + // // Fix unique upstream destination since there could be multiple ones. + // #[cfg(feature = "sticky-cookie")] + // let (upstream_chosen_opt, context_from_lb) = { + // let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { + // takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? + // } else { + // None + // }; + // upstream_group.get(&context_to_lb) + // }; + // #[cfg(not(feature = "sticky-cookie"))] + // let (upstream_chosen_opt, _) = upstream_group.get(&None); - let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; - let context = HandlerContext { - #[cfg(feature = "sticky-cookie")] - context_lb: context_from_lb, - #[cfg(not(feature = "sticky-cookie"))] - context_lb: None, - }; - ///////////////////////////////////////////// + // let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; + // let context = HandlerContext { + // #[cfg(feature = "sticky-cookie")] + // context_lb: context_from_lb, + // #[cfg(not(feature = "sticky-cookie"))] + // context_lb: None, + // }; + // ///////////////////////////////////////////// - // apply upstream-specific headers given in upstream_option - let headers = req.headers_mut(); - apply_upstream_options_to_header(headers, client_addr, upstream_group, &upstream_chosen.uri)?; + // // apply upstream-specific headers given in upstream_option + // let headers = req.headers_mut(); + // apply_upstream_options_to_header(headers, client_addr, upstream_group, &upstream_chosen.uri)?; - // update uri in request - if !(upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some()) { - return Err(RpxyError::Handler("Upstream uri `scheme` and `authority` is 
broken")); - }; - let new_uri = Uri::builder() - .scheme(upstream_chosen.uri.scheme().unwrap().as_str()) - .authority(upstream_chosen.uri.authority().unwrap().as_str()); - let org_pq = match req.uri().path_and_query() { - Some(pq) => pq.to_string(), - None => "/".to_string(), - } - .into_bytes(); + // // update uri in request + // if !(upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some()) { + // return Err(RpxyError::Handler("Upstream uri `scheme` and `authority` is broken")); + // }; + // let new_uri = Uri::builder() + // .scheme(upstream_chosen.uri.scheme().unwrap().as_str()) + // .authority(upstream_chosen.uri.authority().unwrap().as_str()); + // let org_pq = match req.uri().path_and_query() { + // Some(pq) => pq.to_string(), + // None => "/".to_string(), + // } + // .into_bytes(); - // replace some parts of path if opt_replace_path is enabled for chosen upstream - let new_pq = match &upstream_group.replace_path { - Some(new_path) => { - let matched_path: &[u8] = upstream_group.path.as_ref(); - if matched_path.is_empty() || org_pq.len() < matched_path.len() { - return Err(RpxyError::Handler("Upstream uri `path and query` is broken")); - }; - let mut new_pq = Vec::::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); - new_pq.extend_from_slice(new_path.as_ref()); - new_pq.extend_from_slice(&org_pq[matched_path.len()..]); - new_pq - } - None => org_pq, - }; - *req.uri_mut() = new_uri.path_and_query(new_pq).build()?; + // // replace some parts of path if opt_replace_path is enabled for chosen upstream + // let new_pq = match &upstream_group.replace_path { + // Some(new_path) => { + // let matched_path: &[u8] = upstream_group.path.as_ref(); + // if matched_path.is_empty() || org_pq.len() < matched_path.len() { + // return Err(RpxyError::Handler("Upstream uri `path and query` is broken")); + // }; + // let mut new_pq = Vec::::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); + // 
new_pq.extend_from_slice(new_path.as_ref()); + // new_pq.extend_from_slice(&org_pq[matched_path.len()..]); + // new_pq + // } + // None => org_pq, + // }; + // *req.uri_mut() = new_uri.path_and_query(new_pq).build()?; - // upgrade - if let Some(v) = upgrade { - req.headers_mut().insert(header::UPGRADE, v.parse()?); - req - .headers_mut() - .insert(header::CONNECTION, HeaderValue::from_str("upgrade")?); - } + // // upgrade + // if let Some(v) = upgrade { + // req.headers_mut().insert(header::UPGRADE, v.parse()?); + // req + // .headers_mut() + // .insert(header::CONNECTION, HeaderValue::from_str("upgrade")?); + // } - // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3 - if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { - // Change version to http/1.1 when destination scheme is http - debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); - *req.version_mut() = Version::HTTP_11; - } else if req.version() == Version::HTTP_3 { - // HTTP/3 is always https - debug!("HTTP/3 is currently unsupported for request to upstream."); - *req.version_mut() = Version::HTTP_2; - } + // // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3 + // if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { + // // Change version to http/1.1 when destination scheme is http + // debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); + // *req.version_mut() = Version::HTTP_11; + // } else if req.version() == Version::HTTP_3 { + // // HTTP/3 is always https + // debug!("HTTP/3 is currently unsupported for request to upstream."); + // *req.version_mut() = Version::HTTP_2; + // } - apply_upstream_options_to_request_line(req, upstream_group)?; + // apply_upstream_options_to_request_line(req, upstream_group)?; - Ok(context) - } + // Ok(context) + // } } diff --git a/rpxy-lib/src/handler/mod.rs 
b/rpxy-lib/src/handler/mod.rs index 84e0226..2ae5aba 100644 --- a/rpxy-lib/src/handler/mod.rs +++ b/rpxy-lib/src/handler/mod.rs @@ -1,17 +1,15 @@ #[cfg(feature = "cache")] -mod cache; -mod forwarder; +// mod cache; +mod error; +// mod forwarder; mod handler_main; mod utils_headers; mod utils_request; -mod utils_synth_response; +// mod utils_synth_response; #[cfg(feature = "sticky-cookie")] use crate::backend::LbContext; -pub use { - forwarder::Forwarder, - handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}, -}; +pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; #[allow(dead_code)] #[derive(Debug)] diff --git a/rpxy-lib/src/hyper_executor.rs b/rpxy-lib/src/hyper_executor.rs new file mode 100644 index 0000000..152bbe9 --- /dev/null +++ b/rpxy-lib/src/hyper_executor.rs @@ -0,0 +1,45 @@ +use std::sync::Arc; + +use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; +use tokio::runtime::Handle; + +use crate::{globals::Globals, CryptoSource}; + +#[derive(Clone)] +/// Executor for hyper +pub struct LocalExecutor { + runtime_handle: Handle, +} + +impl LocalExecutor { + pub fn new(runtime_handle: Handle) -> Self { + LocalExecutor { runtime_handle } + } +} + +impl hyper::rt::Executor for LocalExecutor +where + F: std::future::Future + Send + 'static, + F::Output: Send, +{ + fn execute(&self, fut: F) { + self.runtime_handle.spawn(fut); + } +} + +/// build connection builder shared with proxy instances +pub(crate) fn build_http_server(globals: &Arc>) -> ConnectionBuilder +where + T: CryptoSource, +{ + let executor = LocalExecutor::new(globals.runtime_handle.clone()); + let mut http_server = server::conn::auto::Builder::new(executor); + http_server + .http1() + .keep_alive(globals.proxy_config.keepalive) + .pipeline_flush(true); + http_server + .http2() + .max_concurrent_streams(globals.proxy_config.max_concurrent_streams); + http_server +} diff --git 
a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index fd242c5..7f7ade2 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -4,20 +4,16 @@ mod constants; mod error; mod globals; mod handler; +mod hyper_executor; mod log; mod proxy; mod utils; -use crate::{ - error::*, - globals::Globals, - handler::{Forwarder, HttpMessageHandlerBuilder}, - log::*, - proxy::ProxyBuilder, -}; +use crate::{error::*, globals::Globals, handler::HttpMessageHandlerBuilder, log::*, proxy::ProxyBuilder}; use futures::future::select_all; +use hyper_executor::build_http_server; // use hyper_trust_dns::TrustDnsResolver; -use std::sync::Arc; +use std::{sync::Arc, time::Duration}; pub use crate::{ certs::{CertsAndKeys, CryptoSource}, @@ -76,16 +72,19 @@ where backends: app_config_list.clone().try_into()?, request_count: Default::default(), runtime_handle: runtime_handle.clone(), + term_notify: term_notify.clone(), }); // build message handler including a request forwarder let msg_handler = Arc::new( HttpMessageHandlerBuilder::default() - .forwarder(Arc::new(Forwarder::new(&globals).await)) + // .forwarder(Arc::new(Forwarder::new(&globals).await)) .globals(globals.clone()) .build()?, ); + let http_server = Arc::new(build_http_server(&globals)); + let addresses = globals.proxy_config.listen_sockets.clone(); let futures = select_all(addresses.into_iter().map(|addr| { let mut tls_enabled = false; @@ -97,16 +96,17 @@ where .globals(globals.clone()) .listening_on(addr) .tls_enabled(tls_enabled) + .http_server(http_server.clone()) .msg_handler(msg_handler.clone()) .build() .unwrap(); - globals.runtime_handle.spawn(proxy.start(term_notify.clone())) + globals.runtime_handle.spawn(async move { proxy.start().await }) })); // wait for all future if let (Ok(Err(e)), _, _) = futures.await { - error!("Some proxy services are down: {:?}", e); + error!("Some proxy services are down: {}", e); }; Ok(()) diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index 0551b62..c89c394 100644 --- 
a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -10,4 +10,33 @@ mod proxy_quic_s2n; mod proxy_tls; mod socket; +use crate::error::*; +use http::{Response, StatusCode}; +use http_body_util::{combinators, BodyExt, Either, Empty}; +use hyper::body::{Bytes, Incoming}; + pub use proxy_main::{Proxy, ProxyBuilder, ProxyBuilderError}; + +/// Type for synthetic boxed body +type BoxBody = combinators::BoxBody; +/// Type for either passthrough body or synthetic body +type EitherBody = Either; + +/// helper function to build http response with passthrough body +fn passthrough_response(response: Response) -> Result> { + Ok(response.map(EitherBody::Left)) +} + +/// build http response with status code of 4xx and 5xx +fn synthetic_error_response(status_code: StatusCode) -> Result> { + let res = Response::builder() + .status(status_code) + .body(EitherBody::Right(BoxBody::new(empty()))) + .unwrap(); + Ok(res) +} + +/// helper function to build a empty body +fn empty() -> BoxBody { + Empty::::new().map_err(|never| match never {}).boxed() +} diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index fd07521..699938b 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -1,17 +1,21 @@ use super::Proxy; use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; use bytes::{Buf, Bytes}; +use futures::Stream; #[cfg(feature = "http3-quinn")] use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; -use hyper::{client::connect::Connect, Body, Request, Response}; +use http::{Request, Response}; +use http_body_util::{BodyExt, BodyStream, StreamBody}; +use hyper::body::{Body, Incoming}; +use hyper_util::client::legacy::connect::Connect; #[cfg(feature = "http3-s2n")] use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; use std::net::SocketAddr; use tokio::time::{timeout, Duration}; -impl Proxy +impl Proxy where - T: Connect 
+ Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { pub(super) async fn connection_serve_h3( @@ -89,18 +93,36 @@ where S: BidiStream + Send + 'static, >::RecvStream: Send, { + println!("stream_serve_h3"); let (req_parts, _) = req.into_parts(); // split stream and async body handling let (mut send_stream, mut recv_stream) = stream.split(); - // generate streamed body with trailers using channel - let (body_sender, req_body) = Body::channel(); + // let max_body_size = self.globals.proxy_config.h3_request_max_body_size; + // // let max = body_stream.size_hint().upper().unwrap_or(u64::MAX); + // // if max > max_body_size as u64 { + // // return Err(HttpError::TooLargeRequestBody); + // // } + // let new_req = Request::from_parts(req_parts, body_stream); + + //////////////////// + // TODO: TODO: TODO: TODO: + // TODO: Body in hyper-0.14 was changed to Incoming in hyper-1.0, and it is not accessible from outside. + // Thus, we need to implement IncomingLike trait using channel. Also, the backend handler must feed the body in the form of + // Either as body. + // Also, the downstream from the backend handler could be Incoming, but will be wrapped as Either as well due to H3. + // Result, E> type includes E as HttpError to generate the status code and related Response. + // Thus to handle synthetic error messages in BoxBody, the serve() function outputs Response, BoxBody>>>. + //////////////////// + + // // generate streamed body with trailers using channel + // let (body_sender, req_body) = Incoming::channel(); // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1 // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn. 
let max_body_size = self.globals.proxy_config.h3_request_max_body_size; self.globals.runtime_handle.spawn(async move { - let mut sender = body_sender; + // let mut sender = body_sender; let mut size = 0usize; while let Some(mut body) = recv_stream.recv_data().await? { debug!("HTTP/3 incoming request body: remaining {}", body.remaining()); @@ -113,51 +135,52 @@ where return Err(RpxyError::Proxy("Exceeds max request body size for HTTP/3".to_string())); } // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes - sender.send_data(body.copy_to_bytes(body.remaining())).await?; + // sender.send_data(body.copy_to_bytes(body.remaining())).await?; } // trailers: use inner for work around. (directly get trailer) let trailers = recv_stream.as_mut().recv_trailers().await?; if trailers.is_some() { debug!("HTTP/3 incoming request trailers"); - sender.send_trailers(trailers.unwrap()).await?; + // sender.send_trailers(trailers.unwrap()).await?; } Ok(()) }); - let new_req: Request = Request::from_parts(req_parts, req_body); - let res = self - .msg_handler - .clone() - .handle_request( - new_req, - client_addr, - self.listening_on, - self.tls_enabled, - Some(tls_server_name), - ) - .await?; + // let new_req: Request = Request::from_parts(req_parts, req_body); + // let res = self + // .msg_handler + // .clone() + // .handle_request( + // new_req, + // client_addr, + // self.listening_on, + // self.tls_enabled, + // Some(tls_server_name), + // ) + // .await?; - let (new_res_parts, new_body) = res.into_parts(); - let new_res = Response::from_parts(new_res_parts, ()); + // let (new_res_parts, new_body) = res.into_parts(); + // let new_res = Response::from_parts(new_res_parts, ()); - match send_stream.send_response(new_res).await { - Ok(_) => { - debug!("HTTP/3 response to connection successful"); - // aggregate body without copying - let mut body_data = hyper::body::aggregate(new_body).await?; + // match 
send_stream.send_response(new_res).await { + // Ok(_) => { + // debug!("HTTP/3 response to connection successful"); + // // aggregate body without copying + // let body_data = new_body.collect().await?.aggregate(); - // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes - send_stream - .send_data(body_data.copy_to_bytes(body_data.remaining())) - .await?; + // // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes + // send_stream + // .send_data(body_data.copy_to_bytes(body_data.remaining())) + // .await?; - // TODO: needs handling trailer? should be included in body from handler. - } - Err(err) => { - error!("Unable to send response to connection peer: {:?}", err); - } - } - Ok(send_stream.finish().await?) + // // TODO: needs handling trailer? should be included in body from handler. + // } + // Err(err) => { + // error!("Unable to send response to connection peer: {:?}", err); + // } + // } + // Ok(send_stream.finish().await?) 
+ todo!() } } diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index bd52ea9..ec1008a 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -1,78 +1,70 @@ -use super::socket::bind_tcp_socket; +use super::{passthrough_response, socket::bind_tcp_socket, synthetic_error_response, EitherBody}; use crate::{ - certs::CryptoSource, error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp, + certs::CryptoSource, error::*, globals::Globals, handler::HttpMessageHandler, hyper_executor::LocalExecutor, log::*, + utils::ServerNameBytesExp, }; use derive_builder::{self, Builder}; -use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request}; -use std::{net::SocketAddr, sync::Arc}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - runtime::Handle, - sync::Notify, - time::{timeout, Duration}, +use http::{Request, StatusCode}; +use hyper::{ + body::Incoming, + rt::{Read, Write}, + service::service_fn, }; - -#[derive(Clone)] -pub struct LocalExecutor { - runtime_handle: Handle, -} - -impl LocalExecutor { - fn new(runtime_handle: Handle) -> Self { - LocalExecutor { runtime_handle } - } -} - -impl hyper::rt::Executor for LocalExecutor -where - F: std::future::Future + Send + 'static, - F::Output: Send, -{ - fn execute(&self, fut: F) { - self.runtime_handle.spawn(fut); - } -} +use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; +use std::{net::SocketAddr, sync::Arc}; +use tokio::time::{timeout, Duration}; #[derive(Clone, Builder)] -pub struct Proxy +/// Proxy main object +pub struct Proxy where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { pub listening_on: SocketAddr, pub tls_enabled: bool, // TCP待受がTLSかどうか - pub msg_handler: Arc>, + /// hyper server receiving http request + pub http_server: Arc>, 
+ // pub msg_handler: Arc>, + pub msg_handler: Arc>, pub globals: Arc>, } -impl Proxy +/// Wrapper function to handle request +async fn serve_request( + req: Request, + // handler: Arc>, + handler: Arc>, + client_addr: SocketAddr, + listen_addr: SocketAddr, + tls_enabled: bool, + tls_server_name: Option, +) -> Result> where - T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone + Sync + Send + 'static, +{ + match handler + .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) + .await? + { + Ok(res) => passthrough_response(res), + Err(e) => synthetic_error_response(StatusCode::from(e)), + } +} + +impl Proxy +where + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send, { - /// Wrapper function to handle request - async fn serve( - handler: Arc>, - req: Request, - client_addr: SocketAddr, - listen_addr: SocketAddr, - tls_enabled: bool, - tls_server_name: Option, - ) -> Result> { - handler - .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) - .await - } - /// Serves requests from clients - pub(super) fn client_serve( - self, + pub(super) fn serve_connection( + &self, stream: I, - server: Http, peer_addr: SocketAddr, tls_server_name: Option, ) where - I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + I: Read + Write + Send + Unpin + 'static, { let request_count = self.globals.request_count.clone(); if request_count.increment() > self.globals.proxy_config.max_clients { @@ -81,24 +73,27 @@ where } debug!("Request incoming: current # {}", request_count.current()); + let server_clone = self.http_server.clone(); + let msg_handler_clone = self.msg_handler.clone(); + let timeout_sec = self.globals.proxy_config.proxy_timeout; + let tls_enabled = self.tls_enabled; + let listening_on = self.listening_on; self.globals.runtime_handle.clone().spawn(async move { timeout( - self.globals.proxy_config.proxy_timeout + Duration::from_secs(1), - server - .serve_connection( - stream, - 
service_fn(move |req: Request| { - Self::serve( - self.msg_handler.clone(), - req, - peer_addr, - self.listening_on, - self.tls_enabled, - tls_server_name.clone(), - ) - }), - ) - .with_upgrades(), + timeout_sec + Duration::from_secs(1), + server_clone.serve_connection_with_upgrades( + stream, + service_fn(move |req: Request| { + serve_request( + req, + msg_handler_clone.clone(), + peer_addr, + listening_on, + tls_enabled, + tls_server_name.clone(), + ) + }), + ), ) .await .ok(); @@ -109,13 +104,13 @@ where } /// Start without TLS (HTTP cleartext) - async fn start_without_tls(self, server: Http) -> Result<()> { + async fn start_without_tls(&self) -> Result<()> { let listener_service = async { let tcp_socket = bind_tcp_socket(&self.listening_on)?; let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; info!("Start TCP proxy serving with HTTP request for configured host names"); - while let Ok((stream, _client_addr)) = tcp_listener.accept().await { - self.clone().client_serve(stream, server.clone(), _client_addr, None); + while let Ok((stream, client_addr)) = tcp_listener.accept().await { + self.serve_connection(TokioIo::new(stream), client_addr, None); } Ok(()) as Result<()> }; @@ -124,32 +119,23 @@ where } /// Entrypoint for HTTP/1.1 and HTTP/2 servers - pub async fn start(self, term_notify: Option>) -> Result<()> { - let mut server = Http::new(); - server.http1_keep_alive(self.globals.proxy_config.keepalive); - server.http2_max_concurrent_streams(self.globals.proxy_config.max_concurrent_streams); - server.pipeline_flush(true); - let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); - let server = server.with_executor(executor); - - let listening_on = self.listening_on; - + pub async fn start(&self) -> Result<()> { let proxy_service = async { if self.tls_enabled { - self.start_with_tls(server).await + self.start_with_tls().await } else { - self.start_without_tls(server).await + self.start_without_tls().await } }; - 
match term_notify { + match &self.globals.term_notify { Some(term) => { tokio::select! { _ = proxy_service => { warn!("Proxy service got down"); } _ = term.notified() => { - info!("Proxy service listening on {} receives term signal", listening_on); + info!("Proxy service listening on {} receives term signal", self.listening_on); } } } @@ -159,8 +145,6 @@ where } } - // proxy_service.await?; - Ok(()) } } diff --git a/rpxy-lib/src/proxy/proxy_quic_quinn.rs b/rpxy-lib/src/proxy/proxy_quic_quinn.rs index fb08420..1828e5f 100644 --- a/rpxy-lib/src/proxy/proxy_quic_quinn.rs +++ b/rpxy-lib/src/proxy/proxy_quic_quinn.rs @@ -5,14 +5,14 @@ use super::{ }; use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; use hot_reload::ReloaderReceiver; -use hyper::client::connect::Connect; +use hyper_util::client::legacy::connect::Connect; use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; use rustls::ServerConfig; use std::sync::Arc; -impl Proxy +impl Proxy where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { pub(super) async fn listener_service_h3( diff --git a/rpxy-lib/src/proxy/proxy_quic_s2n.rs b/rpxy-lib/src/proxy/proxy_quic_s2n.rs index e0c41a5..d1d1580 100644 --- a/rpxy-lib/src/proxy/proxy_quic_s2n.rs +++ b/rpxy-lib/src/proxy/proxy_quic_s2n.rs @@ -4,13 +4,13 @@ use super::{ }; use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; use hot_reload::ReloaderReceiver; -use hyper::client::connect::Connect; +use hyper_util::client::legacy::connect::Connect; use s2n_quic::provider; use std::sync::Arc; -impl Proxy +impl Proxy where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { pub(super) async fn listener_service_h3( @@ -29,7 +29,7 @@ where // event loop loop { tokio::select! 
{ - v = self.serve_connection(&server_crypto) => { + v = self.listener_service_h3_inner(&server_crypto) => { if let Err(e) = v { error!("Quic connection event loop illegally shutdown [s2n-quic] {e}"); break; @@ -64,7 +64,7 @@ where }) } - async fn serve_connection(&self, server_crypto: &Option>) -> Result<()> { + async fn listener_service_h3_inner(&self, server_crypto: &Option>) -> Result<()> { // setup UDP socket let io = provider::io::tokio::Builder::default() .with_receive_address(self.listening_on)? @@ -110,9 +110,9 @@ where while let Some(new_conn) = server.accept().await { debug!("New QUIC connection established"); let Ok(Some(new_server_name)) = new_conn.server_name() else { - warn!("HTTP/3 no SNI is given"); - continue; - }; + warn!("HTTP/3 no SNI is given"); + continue; + }; debug!("HTTP/3 connection incoming (SNI {:?})", new_server_name); let self_clone = self.clone(); diff --git a/rpxy-lib/src/proxy/proxy_tls.rs b/rpxy-lib/src/proxy/proxy_tls.rs index 7c5d601..6ed6212 100644 --- a/rpxy-lib/src/proxy/proxy_tls.rs +++ b/rpxy-lib/src/proxy/proxy_tls.rs @@ -1,25 +1,21 @@ use super::{ crypto_service::{CryptoReloader, ServerCrypto, ServerCryptoBase, SniServerCryptoMap}, - proxy_main::{LocalExecutor, Proxy}, + proxy_main::Proxy, socket::bind_tcp_socket, }; use crate::{certs::CryptoSource, constants::*, error::*, log::*, utils::BytesName}; use hot_reload::{ReloaderReceiver, ReloaderService}; -use hyper::{client::connect::Connect, server::conn::Http}; +use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; use std::sync::Arc; use tokio::time::{timeout, Duration}; -impl Proxy +impl Proxy where - T: Connect + Clone + Sync + Send + 'static, + // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { // TCP Listener Service, i.e., http/2 and http/1.1 - async fn listener_service( - &self, - server: Http, - mut server_crypto_rx: ReloaderReceiver, - ) -> Result<()> { + 
async fn listener_service(&self, mut server_crypto_rx: ReloaderReceiver) -> Result<()> { let tcp_socket = bind_tcp_socket(&self.listening_on)?; let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; info!("Start TCP proxy serving with HTTPS request for configured host names"); @@ -33,7 +29,6 @@ where } let (raw_stream, client_addr) = tcp_cnx.unwrap(); let sc_map_inner = server_crypto_map.clone(); - let server_clone = server.clone(); let self_inner = self.clone(); // spawns async handshake to avoid blocking thread by sequential handshake. @@ -55,30 +50,27 @@ where return Err(RpxyError::Proxy(format!("No TLS serving app for {:?}", server_name.unwrap()))); } let stream = match start.into_stream(server_crypto.unwrap().clone()).await { - Ok(s) => s, + Ok(s) => TokioIo::new(s), Err(e) => { return Err(RpxyError::Proxy(format!("Failed to handshake TLS: {e}"))); } }; - self_inner.client_serve(stream, server_clone, client_addr, server_name_in_bytes); + self_inner.serve_connection(stream, client_addr, server_name_in_bytes); Ok(()) }; self.globals.runtime_handle.spawn( async move { // timeout is introduced to avoid get stuck here. 
- match timeout( + let Ok(v) = timeout( Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), handshake_fut - ).await { - Ok(a) => { - if let Err(e) = a { - error!("{}", e); - } - }, - Err(e) => { - error!("Timeout to handshake TLS: {}", e); - } + ).await else { + error!("Timeout to handshake TLS"); + return; }; + if let Err(e) = v { + error!("{}", e); + } }); } _ = server_crypto_rx.changed() => { @@ -99,7 +91,7 @@ where Ok(()) as Result<()> } - pub async fn start_with_tls(self, server: Http) -> Result<()> { + pub async fn start_with_tls(&self) -> Result<()> { let (cert_reloader_service, cert_reloader_rx) = ReloaderService::, ServerCryptoBase>::new( &self.globals.clone(), CERTS_WATCH_DELAY_SECS, @@ -114,7 +106,7 @@ where _= cert_reloader_service.start() => { error!("Cert service for TLS exited"); }, - _ = self.listener_service(server, cert_reloader_rx) => { + _ = self.listener_service(cert_reloader_rx) => { error!("TCP proxy service for TLS exited"); }, else => { @@ -131,7 +123,7 @@ where _= cert_reloader_service.start() => { error!("Cert service for TLS exited"); }, - _ = self.listener_service(server, cert_reloader_rx.clone()) => { + _ = self.listener_service(cert_reloader_rx.clone()) => { error!("TCP proxy service for TLS exited"); }, _= self.listener_service_h3(cert_reloader_rx) => { @@ -148,7 +140,7 @@ where _= cert_reloader_service.start() => { error!("Cert service for TLS exited"); }, - _ = self.listener_service(server, cert_reloader_rx) => { + _ = self.listener_service(cert_reloader_rx) => { error!("TCP proxy service for TLS exited"); }, else => { diff --git a/submodules/h3 b/submodules/h3 index b86df12..5c16195 160000 --- a/submodules/h3 +++ b/submodules/h3 @@ -1 +1 @@ -Subproject commit b86df1220775d13b89cead99e787944b55991b1e +Subproject commit 5c161952b02e663f31f9b83829bafa7a047b6627 diff --git a/submodules/h3-quinn/Cargo.toml b/submodules/h3-quinn/Cargo.toml deleted file mode 100644 index abbb21e..0000000 --- a/submodules/h3-quinn/Cargo.toml +++ /dev/null @@ 
-1,24 +0,0 @@ -[package] -name = "h3-quinn" -version = "0.0.4" -rust-version = "1.63" -authors = ["Jean-Christophe BEGUE "] -edition = "2018" -documentation = "https://docs.rs/h3-quinn" -repository = "https://github.com/hyperium/h3" -readme = "../README.md" -description = "QUIC transport implementation based on Quinn." -keywords = ["http3", "quic", "h3"] -categories = ["network-programming", "web-programming"] -license = "MIT" - -[dependencies] -h3 = { version = "0.0.3", path = "../h3/h3" } -bytes = "1" -quinn = { path = "../quinn/quinn/", default-features = false, features = [ - "futures-io", -] } -quinn-proto = { path = "../quinn/quinn-proto/", default-features = false } -tokio-util = { version = "0.7.9" } -futures = { version = "0.3.28" } -tokio = { version = "1.33.0", features = ["io-util"], default-features = false } diff --git a/submodules/h3-quinn/src/lib.rs b/submodules/h3-quinn/src/lib.rs deleted file mode 100644 index 78696de..0000000 --- a/submodules/h3-quinn/src/lib.rs +++ /dev/null @@ -1,740 +0,0 @@ -//! QUIC Transport implementation with Quinn -//! -//! This module implements QUIC traits with Quinn. -#![deny(missing_docs)] - -use std::{ - convert::TryInto, - fmt::{self, Display}, - future::Future, - pin::Pin, - sync::Arc, - task::{self, Poll}, -}; - -use bytes::{Buf, Bytes, BytesMut}; - -use futures::{ - ready, - stream::{self, BoxStream}, - StreamExt, -}; -use quinn::ReadDatagram; -pub use quinn::{ - self, crypto::Session, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError, -}; - -use h3::{ - ext::Datagram, - quic::{self, Error, StreamId, WriteBuf}, -}; -use tokio_util::sync::ReusableBoxFuture; - -/// A QUIC connection backed by Quinn -/// -/// Implements a [`quic::Connection`] backed by a [`quinn::Connection`]. 
-pub struct Connection { - conn: quinn::Connection, - incoming_bi: BoxStream<'static, as Future>::Output>, - opening_bi: Option as Future>::Output>>, - incoming_uni: BoxStream<'static, as Future>::Output>, - opening_uni: Option as Future>::Output>>, - datagrams: BoxStream<'static, as Future>::Output>, -} - -impl Connection { - /// Create a [`Connection`] from a [`quinn::NewConnection`] - pub fn new(conn: quinn::Connection) -> Self { - Self { - conn: conn.clone(), - incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async { - Some((conn.accept_bi().await, conn)) - })), - opening_bi: None, - incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async { - Some((conn.accept_uni().await, conn)) - })), - opening_uni: None, - datagrams: Box::pin(stream::unfold(conn, |conn| async { - Some((conn.read_datagram().await, conn)) - })), - } - } -} - -/// The error type for [`Connection`] -/// -/// Wraps reasons a Quinn connection might be lost. -#[derive(Debug)] -pub struct ConnectionError(quinn::ConnectionError); - -impl std::error::Error for ConnectionError {} - -impl fmt::Display for ConnectionError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl Error for ConnectionError { - fn is_timeout(&self) -> bool { - matches!(self.0, quinn::ConnectionError::TimedOut) - } - - fn err_code(&self) -> Option { - match self.0 { - quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { - error_code, - .. - }) => Some(error_code.into_inner()), - _ => None, - } - } -} - -impl From for ConnectionError { - fn from(e: quinn::ConnectionError) -> Self { - Self(e) - } -} - -/// Types of errors when sending a datagram. -#[derive(Debug)] -pub enum SendDatagramError { - /// Datagrams are not supported by the peer - UnsupportedByPeer, - /// Datagrams are locally disabled - Disabled, - /// The datagram was too large to be sent. 
- TooLarge, - /// Network error - ConnectionLost(Box), -} - -impl fmt::Display for SendDatagramError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"), - SendDatagramError::Disabled => write!(f, "datagram support disabled"), - SendDatagramError::TooLarge => write!(f, "datagram too large"), - SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"), - } - } -} - -impl std::error::Error for SendDatagramError {} - -impl Error for SendDatagramError { - fn is_timeout(&self) -> bool { - false - } - - fn err_code(&self) -> Option { - match self { - Self::ConnectionLost(err) => err.err_code(), - _ => None, - } - } -} - -impl From for SendDatagramError { - fn from(value: quinn::SendDatagramError) -> Self { - match value { - quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer, - quinn::SendDatagramError::Disabled => Self::Disabled, - quinn::SendDatagramError::TooLarge => Self::TooLarge, - quinn::SendDatagramError::ConnectionLost(err) => { - Self::ConnectionLost(ConnectionError::from(err).into()) - } - } - } -} - -impl quic::Connection for Connection -where - B: Buf, -{ - type SendStream = SendStream; - type RecvStream = RecvStream; - type BidiStream = BidiStream; - type OpenStreams = OpenStreams; - type Error = ConnectionError; - - fn poll_accept_bidi( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) { - Some(x) => x?, - None => return Poll::Ready(Ok(None)), - }; - Poll::Ready(Ok(Some(Self::BidiStream { - send: Self::SendStream::new(send), - recv: Self::RecvStream::new(recv), - }))) - } - - fn poll_accept_recv( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) { - Some(x) => x?, - None => return Poll::Ready(Ok(None)), - }; - 
Poll::Ready(Ok(Some(Self::RecvStream::new(recv)))) - } - - fn poll_open_bidi( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_bi.is_none() { - self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.clone().open_bi().await, conn)) - }))); - } - - let (send, recv) = - ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::BidiStream { - send: Self::SendStream::new(send), - recv: Self::RecvStream::new(recv), - })) - } - - fn poll_open_send( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_uni.is_none() { - self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.open_uni().await, conn)) - }))); - } - - let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::SendStream::new(send))) - } - - fn opener(&self) -> Self::OpenStreams { - OpenStreams { - conn: self.conn.clone(), - opening_bi: None, - opening_uni: None, - } - } - - fn close(&mut self, code: h3::error::Code, reason: &[u8]) { - self.conn.close( - VarInt::from_u64(code.value()).expect("error code VarInt"), - reason, - ); - } -} - -impl quic::SendDatagramExt for Connection -where - B: Buf, -{ - type Error = SendDatagramError; - - fn send_datagram(&mut self, data: Datagram) -> Result<(), SendDatagramError> { - // TODO investigate static buffer from known max datagram size - let mut buf = BytesMut::new(); - data.encode(&mut buf); - self.conn.send_datagram(buf.freeze())?; - - Ok(()) - } -} - -impl quic::RecvDatagramExt for Connection { - type Buf = Bytes; - - type Error = ConnectionError; - - #[inline] - fn poll_accept_datagram( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - match ready!(self.datagrams.poll_next_unpin(cx)) { - Some(Ok(x)) => Poll::Ready(Ok(Some(x))), - Some(Err(e)) => Poll::Ready(Err(e.into())), - None => Poll::Ready(Ok(None)), - } - } -} - -/// Stream 
opener backed by a Quinn connection -/// -/// Implements [`quic::OpenStreams`] using [`quinn::Connection`], -/// [`quinn::OpenBi`], [`quinn::OpenUni`]. -pub struct OpenStreams { - conn: quinn::Connection, - opening_bi: Option as Future>::Output>>, - opening_uni: Option as Future>::Output>>, -} - -impl quic::OpenStreams for OpenStreams -where - B: Buf, -{ - type RecvStream = RecvStream; - type SendStream = SendStream; - type BidiStream = BidiStream; - type Error = ConnectionError; - - fn poll_open_bidi( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_bi.is_none() { - self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.open_bi().await, conn)) - }))); - } - - let (send, recv) = - ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::BidiStream { - send: Self::SendStream::new(send), - recv: Self::RecvStream::new(recv), - })) - } - - fn poll_open_send( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_uni.is_none() { - self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.open_uni().await, conn)) - }))); - } - - let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::SendStream::new(send))) - } - - fn close(&mut self, code: h3::error::Code, reason: &[u8]) { - self.conn.close( - VarInt::from_u64(code.value()).expect("error code VarInt"), - reason, - ); - } -} - -impl Clone for OpenStreams { - fn clone(&self) -> Self { - Self { - conn: self.conn.clone(), - opening_bi: None, - opening_uni: None, - } - } -} - -/// Quinn-backed bidirectional stream -/// -/// Implements [`quic::BidiStream`] which allows the stream to be split -/// into two structs each implementing one direction. 
-pub struct BidiStream -where - B: Buf, -{ - send: SendStream, - recv: RecvStream, -} - -impl quic::BidiStream for BidiStream -where - B: Buf, -{ - type SendStream = SendStream; - type RecvStream = RecvStream; - - fn split(self) -> (Self::SendStream, Self::RecvStream) { - (self.send, self.recv) - } -} - -impl quic::RecvStream for BidiStream { - type Buf = Bytes; - type Error = ReadError; - - fn poll_data( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - self.recv.poll_data(cx) - } - - fn stop_sending(&mut self, error_code: u64) { - self.recv.stop_sending(error_code) - } - - fn recv_id(&self) -> StreamId { - self.recv.recv_id() - } -} - -impl quic::SendStream for BidiStream -where - B: Buf, -{ - type Error = SendStreamError; - - fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.send.poll_ready(cx) - } - - fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.send.poll_finish(cx) - } - - fn reset(&mut self, reset_code: u64) { - self.send.reset(reset_code) - } - - fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { - self.send.send_data(data) - } - - fn send_id(&self) -> StreamId { - self.send.send_id() - } -} -impl quic::SendStreamUnframed for BidiStream -where - B: Buf, -{ - fn poll_send( - &mut self, - cx: &mut task::Context<'_>, - buf: &mut D, - ) -> Poll> { - self.send.poll_send(cx, buf) - } -} - -/// Quinn-backed receive stream -/// -/// Implements a [`quic::RecvStream`] backed by a [`quinn::RecvStream`]. 
-pub struct RecvStream { - stream: Option, - read_chunk_fut: ReadChunkFuture, -} - -type ReadChunkFuture = ReusableBoxFuture< - 'static, - ( - quinn::RecvStream, - Result, quinn::ReadError>, - ), ->; - -impl RecvStream { - fn new(stream: quinn::RecvStream) -> Self { - Self { - stream: Some(stream), - // Should only allocate once the first time it's used - read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }), - } - } -} - -impl quic::RecvStream for RecvStream { - type Buf = Bytes; - type Error = ReadError; - - fn poll_data( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - if let Some(mut stream) = self.stream.take() { - self.read_chunk_fut.set(async move { - let chunk = stream.read_chunk(usize::MAX, true).await; - (stream, chunk) - }) - }; - - let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx)); - self.stream = Some(stream); - Poll::Ready(Ok(chunk?.map(|c| c.bytes))) - } - - fn stop_sending(&mut self, error_code: u64) { - self.stream - .as_mut() - .unwrap() - .stop(VarInt::from_u64(error_code).expect("invalid error_code")) - .ok(); - } - - fn recv_id(&self) -> StreamId { - self.stream - .as_ref() - .unwrap() - .id() - .0 - .try_into() - .expect("invalid stream id") - } -} - -/// The error type for [`RecvStream`] -/// -/// Wraps errors that occur when reading from a receive stream. 
-#[derive(Debug)] -pub struct ReadError(quinn::ReadError); - -impl From for std::io::Error { - fn from(value: ReadError) -> Self { - value.0.into() - } -} - -impl std::error::Error for ReadError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.0.source() - } -} - -impl fmt::Display for ReadError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl From for Arc { - fn from(e: ReadError) -> Self { - Arc::new(e) - } -} - -impl From for ReadError { - fn from(e: quinn::ReadError) -> Self { - Self(e) - } -} - -impl Error for ReadError { - fn is_timeout(&self) -> bool { - matches!( - self.0, - quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut) - ) - } - - fn err_code(&self) -> Option { - match self.0 { - quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed( - quinn_proto::ApplicationClose { error_code, .. }, - )) => Some(error_code.into_inner()), - quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()), - _ => None, - } - } -} - -/// Quinn-backed send stream -/// -/// Implements a [`quic::SendStream`] backed by a [`quinn::SendStream`]. 
-pub struct SendStream { - stream: Option, - writing: Option>, - write_fut: WriteFuture, -} - -type WriteFuture = - ReusableBoxFuture<'static, (quinn::SendStream, Result)>; - -impl SendStream -where - B: Buf, -{ - fn new(stream: quinn::SendStream) -> SendStream { - Self { - stream: Some(stream), - writing: None, - write_fut: ReusableBoxFuture::new(async { unreachable!() }), - } - } -} - -impl quic::SendStream for SendStream -where - B: Buf, -{ - type Error = SendStreamError; - - fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - if let Some(ref mut data) = self.writing { - while data.has_remaining() { - if let Some(mut stream) = self.stream.take() { - let chunk = data.chunk().to_owned(); // FIXME - avoid copy - self.write_fut.set(async move { - let ret = stream.write(&chunk).await; - (stream, ret) - }); - } - - let (stream, res) = ready!(self.write_fut.poll(cx)); - self.stream = Some(stream); - match res { - Ok(cnt) => data.advance(cnt), - Err(err) => { - return Poll::Ready(Err(SendStreamError::Write(err))); - } - } - } - } - self.writing = None; - Poll::Ready(Ok(())) - } - - fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.stream - .as_mut() - .unwrap() - .poll_finish(cx) - .map_err(Into::into) - } - - fn reset(&mut self, reset_code: u64) { - let _ = self - .stream - .as_mut() - .unwrap() - .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX)); - } - - fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { - if self.writing.is_some() { - return Err(Self::Error::NotReady); - } - self.writing = Some(data.into()); - Ok(()) - } - - fn send_id(&self) -> StreamId { - self.stream - .as_ref() - .unwrap() - .id() - .0 - .try_into() - .expect("invalid stream id") - } -} - -impl quic::SendStreamUnframed for SendStream -where - B: Buf, -{ - fn poll_send( - &mut self, - cx: &mut task::Context<'_>, - buf: &mut D, - ) -> Poll> { - if self.writing.is_some() { - // This signifies a bug in implementation - panic!("poll_send 
called while send stream is not ready") - } - - let s = Pin::new(self.stream.as_mut().unwrap()); - - let res = ready!(futures::io::AsyncWrite::poll_write(s, cx, buf.chunk())); - match res { - Ok(written) => { - buf.advance(written); - Poll::Ready(Ok(written)) - } - Err(err) => { - // We are forced to use AsyncWrite for now because we cannot store - // the result of a call to: - // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Result. - // - // This is why we have to unpack the error from io::Error instead of having it - // returned directly. This should not panic as long as quinn's AsyncWrite impl - // doesn't change. - let err = err - .into_inner() - .expect("write stream returned an empty error") - .downcast::() - .expect("write stream returned an error which type is not WriteError"); - - Poll::Ready(Err(SendStreamError::Write(*err))) - } - } - } -} - -/// The error type for [`SendStream`] -/// -/// Wraps errors that can happen writing to or polling a send stream. -#[derive(Debug)] -pub enum SendStreamError { - /// Errors when writing, wrapping a [`quinn::WriteError`] - Write(WriteError), - /// Error when the stream is not ready, because it is still sending - /// data from a previous call - NotReady, -} - -impl From for std::io::Error { - fn from(value: SendStreamError) -> Self { - match value { - SendStreamError::Write(err) => err.into(), - SendStreamError::NotReady => { - std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready") - } - } - } -} - -impl std::error::Error for SendStreamError {} - -impl Display for SendStreamError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} - -impl From for SendStreamError { - fn from(e: WriteError) -> Self { - Self::Write(e) - } -} - -impl Error for SendStreamError { - fn is_timeout(&self) -> bool { - matches!( - self, - Self::Write(quinn::WriteError::ConnectionLost( - quinn::ConnectionError::TimedOut - )) - ) - } - - fn err_code(&self) 
-> Option { - match self { - Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()), - Self::Write(quinn::WriteError::ConnectionLost( - quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { - error_code, - .. - }), - )) => Some(error_code.into_inner()), - _ => None, - } - } -} - -impl From for Arc { - fn from(e: SendStreamError) -> Self { - Arc::new(e) - } -} diff --git a/submodules/quinn b/submodules/quinn deleted file mode 160000 index 6d80efe..0000000 --- a/submodules/quinn +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 6d80efeeae60b96ff330ae6a70e8cc9291fcc615 diff --git a/submodules/s2n-quic b/submodules/s2n-quic deleted file mode 160000 index 30027ee..0000000 --- a/submodules/s2n-quic +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 30027eeacc7b620da62fc4825b94afd57ab0c7be diff --git a/submodules/s2n-quic-h3/Cargo.toml b/submodules/s2n-quic-h3/Cargo.toml new file mode 100644 index 0000000..fecfd10 --- /dev/null +++ b/submodules/s2n-quic-h3/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "s2n-quic-h3" +# this in an unpublished internal crate so the version should not be changed +version = "0.1.0" +authors = ["AWS s2n"] +edition = "2021" +rust-version = "1.63" +license = "Apache-2.0" +# this contains an http3 implementation for testing purposes and should not be published +publish = false + +[dependencies] +bytes = { version = "1", default-features = false } +futures = { version = "0.3", default-features = false } +h3 = { path = "../h3/h3/" } +s2n-quic = "1.31.0" +s2n-quic-core = "0.31.0" diff --git a/submodules/s2n-quic-h3/README.md b/submodules/s2n-quic-h3/README.md new file mode 100644 index 0000000..aed9475 --- /dev/null +++ b/submodules/s2n-quic-h3/README.md @@ -0,0 +1,10 @@ +# s2n-quic-h3 + +This is an internal crate used by [s2n-quic](https://github.com/aws/s2n-quic) written as a proof of concept for implementing HTTP3 on top of s2n-quic. The API is not currently stable and should not be used directly. 
+ +## License + +This project is licensed under the [Apache-2.0 License][license-url]. + +[license-badge]: https://img.shields.io/badge/license-apache-blue.svg +[license-url]: https://aws.amazon.com/apache-2-0/ diff --git a/submodules/s2n-quic-h3/src/lib.rs b/submodules/s2n-quic-h3/src/lib.rs new file mode 100644 index 0000000..c85f197 --- /dev/null +++ b/submodules/s2n-quic-h3/src/lib.rs @@ -0,0 +1,7 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +mod s2n_quic; + +pub use self::s2n_quic::*; +pub use h3; diff --git a/submodules/s2n-quic-h3/src/s2n_quic.rs b/submodules/s2n-quic-h3/src/s2n_quic.rs new file mode 100644 index 0000000..dffa19b --- /dev/null +++ b/submodules/s2n-quic-h3/src/s2n_quic.rs @@ -0,0 +1,506 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 + +use bytes::{Buf, Bytes}; +use futures::ready; +use h3::quic::{self, Error, StreamId, WriteBuf}; +use s2n_quic::stream::{BidirectionalStream, ReceiveStream}; +use s2n_quic_core::varint::VarInt; +use std::{ + convert::TryInto, + fmt::{self, Display}, + sync::Arc, + task::{self, Poll}, +}; + +pub struct Connection { + conn: s2n_quic::connection::Handle, + bidi_acceptor: s2n_quic::connection::BidirectionalStreamAcceptor, + recv_acceptor: s2n_quic::connection::ReceiveStreamAcceptor, +} + +impl Connection { + pub fn new(new_conn: s2n_quic::Connection) -> Self { + let (handle, acceptor) = new_conn.split(); + let (bidi, recv) = acceptor.split(); + + Self { + conn: handle, + bidi_acceptor: bidi, + recv_acceptor: recv, + } + } +} + +#[derive(Debug)] +pub struct ConnectionError(s2n_quic::connection::Error); + +impl std::error::Error for ConnectionError {} + +impl fmt::Display for ConnectionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Error for ConnectionError { + fn is_timeout(&self) -> bool { + matches!(self.0, 
s2n_quic::connection::Error::IdleTimerExpired { .. }) + } + + fn err_code(&self) -> Option { + match self.0 { + s2n_quic::connection::Error::Application { error, .. } => Some(error.into()), + _ => None, + } + } +} + +impl From for ConnectionError { + fn from(e: s2n_quic::connection::Error) -> Self { + Self(e) + } +} + +impl quic::Connection for Connection +where + B: Buf, +{ + type BidiStream = BidiStream; + type SendStream = SendStream; + type RecvStream = RecvStream; + type OpenStreams = OpenStreams; + type Error = ConnectionError; + + fn poll_accept_recv( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + let recv = match ready!(self.recv_acceptor.poll_accept_receive_stream(cx))? { + Some(x) => x, + None => return Poll::Ready(Ok(None)), + }; + Poll::Ready(Ok(Some(Self::RecvStream::new(recv)))) + } + + fn poll_accept_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + let (recv, send) = match ready!(self.bidi_acceptor.poll_accept_bidirectional_stream(cx))? 
{ + Some(x) => x.split(), + None => return Poll::Ready(Ok(None)), + }; + Poll::Ready(Ok(Some(Self::BidiStream { + send: Self::SendStream::new(send), + recv: Self::RecvStream::new(recv), + }))) + } + + fn poll_open_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?; + Ok(stream.into()).into() + } + + fn poll_open_send( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_send_stream(cx))?; + Ok(stream.into()).into() + } + + fn opener(&self) -> Self::OpenStreams { + OpenStreams { + conn: self.conn.clone(), + } + } + + fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { + self.conn.close( + code.value() + .try_into() + .expect("s2n-quic supports error codes up to 2^62-1"), + ); + } +} + +pub struct OpenStreams { + conn: s2n_quic::connection::Handle, +} + +impl quic::OpenStreams for OpenStreams +where + B: Buf, +{ + type BidiStream = BidiStream; + type SendStream = SendStream; + type RecvStream = RecvStream; + type Error = ConnectionError; + + fn poll_open_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?; + Ok(stream.into()).into() + } + + fn poll_open_send( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_send_stream(cx))?; + Ok(stream.into()).into() + } + + fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { + self.conn.close( + code.value() + .try_into() + .unwrap_or_else(|_| VarInt::MAX.into()), + ); + } +} + +impl Clone for OpenStreams { + fn clone(&self) -> Self { + Self { + conn: self.conn.clone(), + } + } +} + +pub struct BidiStream +where + B: Buf, +{ + send: SendStream, + recv: RecvStream, +} + +impl quic::BidiStream for BidiStream +where + B: Buf, +{ + type SendStream = SendStream; + type RecvStream = RecvStream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + 
(self.send, self.recv) + } +} + +impl quic::RecvStream for BidiStream +where + B: Buf, +{ + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + self.recv.poll_data(cx) + } + + fn stop_sending(&mut self, error_code: u64) { + self.recv.stop_sending(error_code) + } + + fn recv_id(&self) -> StreamId { + self.recv.stream.id().try_into().expect("invalid stream id") + } +} + +impl quic::SendStream for BidiStream +where + B: Buf, +{ + type Error = SendStreamError; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + self.send.poll_ready(cx) + } + + fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { + self.send.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.send.reset(reset_code) + } + + fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { + self.send.send_data(data) + } + + fn send_id(&self) -> StreamId { + self.send.stream.id().try_into().expect("invalid stream id") + } +} + +impl From for BidiStream +where + B: Buf, +{ + fn from(bidi: BidirectionalStream) -> Self { + let (recv, send) = bidi.split(); + BidiStream { + send: send.into(), + recv: recv.into(), + } + } +} + +pub struct RecvStream { + stream: s2n_quic::stream::ReceiveStream, +} + +impl RecvStream { + fn new(stream: s2n_quic::stream::ReceiveStream) -> Self { + Self { stream } + } +} + +impl quic::RecvStream for RecvStream { + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + let buf = ready!(self.stream.poll_receive(cx))?; + Ok(buf).into() + } + + fn stop_sending(&mut self, error_code: u64) { + let _ = self.stream.stop_sending( + s2n_quic::application::Error::new(error_code) + .expect("s2n-quic supports error codes up to 2^62-1"), + ); + } + + fn recv_id(&self) -> StreamId { + self.stream.id().try_into().expect("invalid stream id") + } +} + +impl From for RecvStream { + fn 
from(recv: ReceiveStream) -> Self { + RecvStream::new(recv) + } +} + +#[derive(Debug)] +pub struct ReadError(s2n_quic::stream::Error); + +impl std::error::Error for ReadError {} + +impl fmt::Display for ReadError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl From for Arc { + fn from(e: ReadError) -> Self { + Arc::new(e) + } +} + +impl From for ReadError { + fn from(e: s2n_quic::stream::Error) -> Self { + Self(e) + } +} + +impl Error for ReadError { + fn is_timeout(&self) -> bool { + matches!( + self.0, + s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::IdleTimerExpired { .. }, + .. + } + ) + } + + fn err_code(&self) -> Option { + match self.0 { + s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::Application { error, .. }, + .. + } => Some(error.into()), + s2n_quic::stream::Error::StreamReset { error, .. } => Some(error.into()), + _ => None, + } + } +} + +pub struct SendStream { + stream: s2n_quic::stream::SendStream, + chunk: Option, + buf: Option>, // TODO: Replace with buf: PhantomData + // after https://github.com/hyperium/h3/issues/78 is resolved +} + +impl SendStream +where + B: Buf, +{ + fn new(stream: s2n_quic::stream::SendStream) -> SendStream { + Self { + stream, + chunk: None, + buf: Default::default(), + } + } +} + +impl quic::SendStream for SendStream +where + B: Buf, +{ + type Error = SendStreamError; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + loop { + // try to flush the current chunk if we have one + if let Some(chunk) = self.chunk.as_mut() { + ready!(self.stream.poll_send(chunk, cx))?; + + // s2n-quic will take the whole chunk on send, even if it exceeds the limits + debug_assert!(chunk.is_empty()); + self.chunk = None; + } + + // try to take the next chunk from the WriteBuf + if let Some(ref mut data) = self.buf { + let len = data.chunk().len(); + + // if the write buf is empty, then clear it and break + if len 
== 0 { + self.buf = None; + break; + } + + // copy the first chunk from WriteBuf and prepare it to flush + let chunk = data.copy_to_bytes(len); + self.chunk = Some(chunk); + + // loop back around to flush the chunk + continue; + } + + // if we didn't have either a chunk or WriteBuf, then we're ready + break; + } + + Poll::Ready(Ok(())) + + // TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved + // self.available_bytes = ready!(self.stream.poll_send_ready(cx))?; + // Poll::Ready(Ok(())) + } + + fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { + if self.buf.is_some() { + return Err(Self::Error::NotReady); + } + self.buf = Some(data.into()); + Ok(()) + + // TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved + // let mut data = data.into(); + // while self.available_bytes > 0 && data.has_remaining() { + // let len = data.chunk().len(); + // let chunk = data.copy_to_bytes(len); + // self.stream.send_data(chunk)?; + // self.available_bytes = self.available_bytes.saturating_sub(len); + // } + // Ok(()) + } + + fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { + // ensure all chunks are flushed to the QUIC stream before finishing + ready!(self.poll_ready(cx))?; + self.stream.finish()?; + Ok(()).into() + } + + fn reset(&mut self, reset_code: u64) { + let _ = self + .stream + .reset(reset_code.try_into().unwrap_or_else(|_| VarInt::MAX.into())); + } + + fn send_id(&self) -> StreamId { + self.stream.id().try_into().expect("invalid stream id") + } +} + +impl From for SendStream +where + B: Buf, +{ + fn from(send: s2n_quic::stream::SendStream) -> Self { + SendStream::new(send) + } +} + +#[derive(Debug)] +pub enum SendStreamError { + Write(s2n_quic::stream::Error), + NotReady, +} + +impl std::error::Error for SendStreamError {} + +impl Display for SendStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl From 
for SendStreamError { + fn from(e: s2n_quic::stream::Error) -> Self { + Self::Write(e) + } +} + +impl Error for SendStreamError { + fn is_timeout(&self) -> bool { + matches!( + self, + Self::Write(s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::IdleTimerExpired { .. }, + .. + }) + ) + } + + fn err_code(&self) -> Option { + match self { + Self::Write(s2n_quic::stream::Error::StreamReset { error, .. }) => { + Some((*error).into()) + } + Self::Write(s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::Application { error, .. }, + .. + }) => Some((*error).into()), + _ => None, + } + } +} + +impl From for Arc { + fn from(e: SendStreamError) -> Self { + Arc::new(e) + } +} From 7bc6e309350052264f3a63f3cec815c9ddbb35c5 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 21 Nov 2023 21:36:52 +0900 Subject: [PATCH 02/50] chore: deps --- rpxy-bin/Cargo.toml | 12 ++++++------ rpxy-lib/Cargo.toml | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index 1848e5e..dd6e744 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -25,9 +25,9 @@ rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [ anyhow = "1.0.75" rustc-hash = "1.1.0" -serde = { version = "1.0.192", default-features = false, features = ["derive"] } +serde = { version = "1.0.193", default-features = false, features = ["derive"] } derive_builder = "0.12.0" -tokio = { version = "1.33.0", default-features = false, features = [ +tokio = { version = "1.34.0", default-features = false, features = [ "net", "rt-multi-thread", "time", @@ -35,17 +35,17 @@ tokio = { version = "1.33.0", default-features = false, features = [ "macros", ] } async-trait = "0.1.74" -rustls-pemfile = "1.0.3" +rustls-pemfile = "1.0.4" mimalloc = { version = "*", default-features = false } # config -clap = { version = "4.4.7", features = ["std", "cargo", "wrap_help"] } -toml = { version = "0.8", 
default-features = false, features = ["parse"] } +clap = { version = "4.4.8", features = ["std", "cargo", "wrap_help"] } +toml = { version = "0.8.8", default-features = false, features = ["parse"] } hot_reload = "0.1.4" # logging tracing = { version = "0.1.40" } -tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } [dev-dependencies] diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index b4b475d..fae0c3c 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -44,7 +44,7 @@ thiserror = "1.0.50" http = "1.0.0" http-body-util = "0.1.0" hyper = { version = "1.0.1", default-features = false } -hyper-util = { version = "0.1.0", features = ["full"] } +hyper-util = { version = "0.1.1", features = ["full"] } hyper-rustls = { version = "0.24.2", default-features = false, features = [ "tokio-runtime", "webpki-tokio", From f98c778a0cb2c8887fe76b1eff0b73468e7a3af7 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 21 Nov 2023 22:46:52 +0900 Subject: [PATCH 03/50] wip: refactor whole module in lib --- Cargo.toml | 2 +- legacy-lib/Cargo.toml | 89 +++++ .../src/backend/load_balance.rs | 0 .../src/backend/load_balance_sticky.rs | 0 {rpxy-lib => legacy-lib}/src/backend/mod.rs | 0 .../src/backend/sticky_cookie.rs | 0 .../src/backend/upstream.rs | 0 .../src/backend/upstream_opts.rs | 0 legacy-lib/src/certs.rs | 91 +++++ legacy-lib/src/constants.rs | 45 +++ legacy-lib/src/error.rs | 86 +++++ legacy-lib/src/globals.rs | 325 ++++++++++++++++++ {rpxy-lib => legacy-lib}/src/handler/cache.rs | 0 {rpxy-lib => legacy-lib}/src/handler/error.rs | 0 .../src/handler/forwarder.rs | 0 .../src/handler/handler_main.rs | 0 {rpxy-lib => legacy-lib}/src/handler/mod.rs | 0 .../src/handler/utils_headers.rs | 0 .../src/handler/utils_request.rs | 0 .../src/handler/utils_synth_response.rs | 0 .../src/hyper_executor.rs | 0 legacy-lib/src/lib.rs | 112 ++++++ legacy-lib/src/log.rs | 98 ++++++ 
.../src/proxy/crypto_service.rs | 0 {rpxy-lib => legacy-lib}/src/proxy/mod.rs | 0 .../src/proxy/proxy_client_cert.rs | 0 .../src/proxy/proxy_h3.rs | 0 .../src/proxy/proxy_main.rs | 0 .../src/proxy/proxy_quic_quinn.rs | 0 .../src/proxy/proxy_quic_s2n.rs | 0 .../src/proxy/proxy_tls.rs | 0 {rpxy-lib => legacy-lib}/src/proxy/socket.rs | 0 .../src/utils/bytes_name.rs | 0 {rpxy-lib => legacy-lib}/src/utils/mod.rs | 0 .../src/utils/socket_addr.rs | 0 rpxy-bin/Cargo.toml | 12 +- rpxy-lib/Cargo.toml | 94 ++--- rpxy-lib/src/certs.rs | 71 +--- rpxy-lib/src/error.rs | 80 +---- rpxy-lib/src/globals.rs | 188 +--------- rpxy-lib/src/lib.rs | 84 ++--- rpxy-lib/src/log.rs | 97 ------ 42 files changed, 943 insertions(+), 531 deletions(-) create mode 100644 legacy-lib/Cargo.toml rename {rpxy-lib => legacy-lib}/src/backend/load_balance.rs (100%) rename {rpxy-lib => legacy-lib}/src/backend/load_balance_sticky.rs (100%) rename {rpxy-lib => legacy-lib}/src/backend/mod.rs (100%) rename {rpxy-lib => legacy-lib}/src/backend/sticky_cookie.rs (100%) rename {rpxy-lib => legacy-lib}/src/backend/upstream.rs (100%) rename {rpxy-lib => legacy-lib}/src/backend/upstream_opts.rs (100%) create mode 100644 legacy-lib/src/certs.rs create mode 100644 legacy-lib/src/constants.rs create mode 100644 legacy-lib/src/error.rs create mode 100644 legacy-lib/src/globals.rs rename {rpxy-lib => legacy-lib}/src/handler/cache.rs (100%) rename {rpxy-lib => legacy-lib}/src/handler/error.rs (100%) rename {rpxy-lib => legacy-lib}/src/handler/forwarder.rs (100%) rename {rpxy-lib => legacy-lib}/src/handler/handler_main.rs (100%) rename {rpxy-lib => legacy-lib}/src/handler/mod.rs (100%) rename {rpxy-lib => legacy-lib}/src/handler/utils_headers.rs (100%) rename {rpxy-lib => legacy-lib}/src/handler/utils_request.rs (100%) rename {rpxy-lib => legacy-lib}/src/handler/utils_synth_response.rs (100%) rename {rpxy-lib => legacy-lib}/src/hyper_executor.rs (100%) create mode 100644 legacy-lib/src/lib.rs create mode 100644 
legacy-lib/src/log.rs rename {rpxy-lib => legacy-lib}/src/proxy/crypto_service.rs (100%) rename {rpxy-lib => legacy-lib}/src/proxy/mod.rs (100%) rename {rpxy-lib => legacy-lib}/src/proxy/proxy_client_cert.rs (100%) rename {rpxy-lib => legacy-lib}/src/proxy/proxy_h3.rs (100%) rename {rpxy-lib => legacy-lib}/src/proxy/proxy_main.rs (100%) rename {rpxy-lib => legacy-lib}/src/proxy/proxy_quic_quinn.rs (100%) rename {rpxy-lib => legacy-lib}/src/proxy/proxy_quic_s2n.rs (100%) rename {rpxy-lib => legacy-lib}/src/proxy/proxy_tls.rs (100%) rename {rpxy-lib => legacy-lib}/src/proxy/socket.rs (100%) rename {rpxy-lib => legacy-lib}/src/utils/bytes_name.rs (100%) rename {rpxy-lib => legacy-lib}/src/utils/mod.rs (100%) rename {rpxy-lib => legacy-lib}/src/utils/socket_addr.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index c512b18..7868088 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] -members = ["rpxy-bin", "rpxy-lib"] +members = ["rpxy-bin", "rpxy-lib", "legacy-lib"] exclude = ["submodules"] resolver = "2" diff --git a/legacy-lib/Cargo.toml b/legacy-lib/Cargo.toml new file mode 100644 index 0000000..c975fb6 --- /dev/null +++ b/legacy-lib/Cargo.toml @@ -0,0 +1,89 @@ +[package] +name = "rpxy-lib-legacy" +version = "0.6.2" +authors = ["Jun Kurihara"] +homepage = "https://github.com/junkurihara/rust-rpxy" +repository = "https://github.com/junkurihara/rust-rpxy" +license = "MIT" +readme = "../README.md" +edition = "2021" +publish = false + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[features] +default = ["http3-s2n", "sticky-cookie", "cache"] +http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] +http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] +sticky-cookie = ["base64", "sha2", "chrono"] +cache = ["http-cache-semantics", "lru"] +native-roots = ["hyper-rustls/native-tokio"] + +[dependencies] +rand = "0.8.5" +rustc-hash = "1.1.0" +bytes = "1.5.0" +derive_builder = "0.12.0" +futures = { 
version = "0.3.29", features = ["alloc", "async-await"] } +tokio = { version = "1.34.0", default-features = false, features = [ + "net", + "rt-multi-thread", + "time", + "sync", + "macros", + "fs", +] } +async-trait = "0.1.74" +hot_reload = "0.1.4" # reloading certs + +# Error handling +anyhow = "1.0.75" +thiserror = "1.0.50" + +# http and tls +http = "1.0.0" +http-body-util = "0.1.0" +hyper = { version = "1.0.1", default-features = false } +hyper-util = { version = "0.1.1", features = ["full"] } +hyper-rustls = { version = "0.24.2", default-features = false, features = [ + "tokio-runtime", + "webpki-tokio", + "http1", + "http2", +] } +tokio-rustls = { version = "0.24.1", features = ["early-data"] } +rustls = { version = "0.21.9", default-features = false } +webpki = "0.22.4" +x509-parser = "0.15.1" + +# logging +tracing = { version = "0.1.40" } + +# http/3 +quinn = { version = "0.10.2", optional = true } +h3 = { path = "../submodules/h3/h3/", optional = true } +h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } +s2n-quic = { version = "1.31.0", default-features = false, features = [ + "provider-tls-rustls", +], optional = true } +s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } +s2n-quic-rustls = { version = "0.31.0", optional = true } +# for UDP socket wit SO_REUSEADDR when h3 with quinn +socket2 = { version = "0.5.5", features = ["all"], optional = true } + +# cache +http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } +lru = { version = "0.12.0", optional = true } + +# cookie handling for sticky cookie +chrono = { version = "0.4.31", default-features = false, features = [ + "unstable-locales", + "alloc", + "clock", +], optional = true } +base64 = { version = "0.21.5", optional = true } +sha2 = { version = "0.10.8", default-features = false, optional = true } + + +[dev-dependencies] +# http and tls diff --git a/rpxy-lib/src/backend/load_balance.rs 
b/legacy-lib/src/backend/load_balance.rs similarity index 100% rename from rpxy-lib/src/backend/load_balance.rs rename to legacy-lib/src/backend/load_balance.rs diff --git a/rpxy-lib/src/backend/load_balance_sticky.rs b/legacy-lib/src/backend/load_balance_sticky.rs similarity index 100% rename from rpxy-lib/src/backend/load_balance_sticky.rs rename to legacy-lib/src/backend/load_balance_sticky.rs diff --git a/rpxy-lib/src/backend/mod.rs b/legacy-lib/src/backend/mod.rs similarity index 100% rename from rpxy-lib/src/backend/mod.rs rename to legacy-lib/src/backend/mod.rs diff --git a/rpxy-lib/src/backend/sticky_cookie.rs b/legacy-lib/src/backend/sticky_cookie.rs similarity index 100% rename from rpxy-lib/src/backend/sticky_cookie.rs rename to legacy-lib/src/backend/sticky_cookie.rs diff --git a/rpxy-lib/src/backend/upstream.rs b/legacy-lib/src/backend/upstream.rs similarity index 100% rename from rpxy-lib/src/backend/upstream.rs rename to legacy-lib/src/backend/upstream.rs diff --git a/rpxy-lib/src/backend/upstream_opts.rs b/legacy-lib/src/backend/upstream_opts.rs similarity index 100% rename from rpxy-lib/src/backend/upstream_opts.rs rename to legacy-lib/src/backend/upstream_opts.rs diff --git a/legacy-lib/src/certs.rs b/legacy-lib/src/certs.rs new file mode 100644 index 0000000..c9cfafd --- /dev/null +++ b/legacy-lib/src/certs.rs @@ -0,0 +1,91 @@ +use async_trait::async_trait; +use rustc_hash::FxHashSet as HashSet; +use rustls::{ + sign::{any_supported_type, CertifiedKey}, + Certificate, OwnedTrustAnchor, PrivateKey, +}; +use std::io; +use x509_parser::prelude::*; + +#[async_trait] +// Trait to read certs and keys anywhere from KVS, file, sqlite, etc. 
+pub trait CryptoSource { + type Error; + + /// read crypto materials from source + async fn read(&self) -> Result; + + /// Returns true when mutual tls is enabled + fn is_mutual_tls(&self) -> bool; +} + +/// Certificates and private keys in rustls loaded from files +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct CertsAndKeys { + pub certs: Vec, + pub cert_keys: Vec, + pub client_ca_certs: Option>, +} + +impl CertsAndKeys { + pub fn parse_server_certs_and_keys(&self) -> Result { + // for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() { + let signing_key = self + .cert_keys + .iter() + .find_map(|k| { + if let Ok(sk) = any_supported_type(k) { + Some(sk) + } else { + None + } + }) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to find a valid certificate and key", + ) + })?; + Ok(CertifiedKey::new(self.certs.clone(), signing_key)) + } + + pub fn parse_client_ca_certs(&self) -> Result<(Vec, HashSet>), anyhow::Error> { + let certs = self.client_ca_certs.as_ref().ok_or(anyhow::anyhow!("No client cert"))?; + + let owned_trust_anchors: Vec<_> = certs + .iter() + .map(|v| { + // let trust_anchor = tokio_rustls::webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap(); + let trust_anchor = webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap(); + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + trust_anchor.subject, + trust_anchor.spki, + trust_anchor.name_constraints, + ) + }) + .collect(); + + // TODO: SKID is not used currently + let subject_key_identifiers: HashSet<_> = certs + .iter() + .filter_map(|v| { + // retrieve ca key id (subject key id) + let cert = parse_x509_certificate(&v.0).unwrap().1; + let subject_key_ids = cert + .iter_extensions() + .filter_map(|ext| match ext.parsed_extension() { + ParsedExtension::SubjectKeyIdentifier(skid) => Some(skid), + _ => None, + }) + .collect::>(); + if !subject_key_ids.is_empty() { + Some(subject_key_ids[0].0.to_owned()) + } else { + None + } + }) + .collect(); + + 
Ok((owned_trust_anchors, subject_key_identifiers)) + } +} diff --git a/legacy-lib/src/constants.rs b/legacy-lib/src/constants.rs new file mode 100644 index 0000000..ebec1fc --- /dev/null +++ b/legacy-lib/src/constants.rs @@ -0,0 +1,45 @@ +pub const RESPONSE_HEADER_SERVER: &str = "rpxy"; +// pub const LISTEN_ADDRESSES_V4: &[&str] = &["0.0.0.0"]; +// pub const LISTEN_ADDRESSES_V6: &[&str] = &["[::]"]; +pub const TCP_LISTEN_BACKLOG: u32 = 1024; +// pub const HTTP_LISTEN_PORT: u16 = 8080; +// pub const HTTPS_LISTEN_PORT: u16 = 8443; +pub const PROXY_TIMEOUT_SEC: u64 = 60; +pub const UPSTREAM_TIMEOUT_SEC: u64 = 60; +pub const TLS_HANDSHAKE_TIMEOUT_SEC: u64 = 15; // default as with firefox browser +pub const MAX_CLIENTS: usize = 512; +pub const MAX_CONCURRENT_STREAMS: u32 = 64; +pub const CERTS_WATCH_DELAY_SECS: u32 = 60; +pub const LOAD_CERTS_ONLY_WHEN_UPDATED: bool = true; + +// #[cfg(feature = "http3")] +// pub const H3_RESPONSE_BUF_SIZE: usize = 65_536; // 64KB +// #[cfg(feature = "http3")] +// pub const H3_REQUEST_BUF_SIZE: usize = 65_536; // 64KB // handled by quinn + +#[allow(non_snake_case)] +#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] +pub mod H3 { + pub const ALT_SVC_MAX_AGE: u32 = 3600; + pub const REQUEST_MAX_BODY_SIZE: usize = 268_435_456; // 256MB + pub const MAX_CONCURRENT_CONNECTIONS: u32 = 4096; + pub const MAX_CONCURRENT_BIDISTREAM: u32 = 64; + pub const MAX_CONCURRENT_UNISTREAM: u32 = 64; + pub const MAX_IDLE_TIMEOUT: u64 = 10; // secs +} + +#[cfg(feature = "sticky-cookie")] +/// For load-balancing with sticky cookie +pub const STICKY_COOKIE_NAME: &str = "rpxy_srv_id"; + +#[cfg(feature = "cache")] +// # of entries in cache +pub const MAX_CACHE_ENTRY: usize = 1_000; +#[cfg(feature = "cache")] +// max size for each file in bytes +pub const MAX_CACHE_EACH_SIZE: usize = 65_535; +#[cfg(feature = "cache")] +// on memory cache if less than or equal to +pub const MAX_CACHE_EACH_SIZE_ON_MEMORY: usize = 4_096; + +// TODO: max cache size in total
diff --git a/legacy-lib/src/error.rs b/legacy-lib/src/error.rs new file mode 100644 index 0000000..c672682 --- /dev/null +++ b/legacy-lib/src/error.rs @@ -0,0 +1,86 @@ +pub use anyhow::{anyhow, bail, ensure, Context}; +use std::io; +use thiserror::Error; + +pub type Result = std::result::Result; + +/// Describes things that can go wrong in the Rpxy +#[derive(Debug, Error)] +pub enum RpxyError { + #[error("Proxy build error: {0}")] + ProxyBuild(#[from] crate::proxy::ProxyBuilderError), + + #[error("Backend build error: {0}")] + BackendBuild(#[from] crate::backend::BackendBuilderError), + + #[error("MessageHandler build error: {0}")] + HandlerBuild(#[from] crate::handler::HttpMessageHandlerBuilderError), + + #[error("Config builder error: {0}")] + ConfigBuild(&'static str), + + #[error("Http Message Handler Error: {0}")] + Handler(&'static str), + + #[error("Cache Error: {0}")] + Cache(&'static str), + + #[error("Http Request Message Error: {0}")] + Request(&'static str), + + #[error("TCP/UDP Proxy Layer Error: {0}")] + Proxy(String), + + #[allow(unused)] + #[error("LoadBalance Layer Error: {0}")] + LoadBalance(String), + + #[error("I/O Error: {0}")] + Io(#[from] io::Error), + + // #[error("Toml Deserialization Error")] + // TomlDe(#[from] toml::de::Error), + #[cfg(feature = "http3-quinn")] + #[error("Quic Connection Error [quinn]: {0}")] + QuicConn(#[from] quinn::ConnectionError), + + #[cfg(feature = "http3-s2n")] + #[error("Quic Connection Error [s2n-quic]: {0}")] + QUicConn(#[from] s2n_quic::connection::Error), + + #[cfg(feature = "http3-quinn")] + #[error("H3 Error [quinn]: {0}")] + H3(#[from] h3::Error), + + #[cfg(feature = "http3-s2n")] + #[error("H3 Error [s2n-quic]: {0}")] + H3(#[from] s2n_quic_h3::h3::Error), + + #[error("rustls Connection Error: {0}")] + Rustls(#[from] rustls::Error), + + #[error("Hyper Error: {0}")] + Hyper(#[from] hyper::Error), + + #[error("Hyper Http Error: {0}")] + HyperHttp(#[from] hyper::http::Error), + + #[error("Hyper Http 
HeaderValue Error: {0}")] + HyperHeaderValue(#[from] hyper::header::InvalidHeaderValue), + + #[error("Hyper Http HeaderName Error: {0}")] + HyperHeaderName(#[from] hyper::header::InvalidHeaderName), + + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +#[allow(dead_code)] +#[derive(Debug, Error, Clone)] +pub enum ClientCertsError { + #[error("TLS Client Certificate is Required for Given SNI: {0}")] + ClientCertRequired(String), + + #[error("Inconsistent TLS Client Certificate for Given SNI: {0}")] + InconsistentClientCert(String), +} diff --git a/legacy-lib/src/globals.rs b/legacy-lib/src/globals.rs new file mode 100644 index 0000000..02605a6 --- /dev/null +++ b/legacy-lib/src/globals.rs @@ -0,0 +1,325 @@ +use crate::{ + backend::{ + Backend, BackendBuilder, Backends, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption, + }, + certs::CryptoSource, + constants::*, + error::RpxyError, + log::*, + utils::{BytesName, PathNameBytesExp}, +}; +use rustc_hash::FxHashMap as HashMap; +use std::net::SocketAddr; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use tokio::time::Duration; + +/// Global object containing proxy configurations and shared object like counters. +/// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks. +pub struct Globals +where + T: CryptoSource, +{ + /// Configuration parameters for proxy transport and request handlers + pub proxy_config: ProxyConfig, // TODO: proxy configはarcに包んでこいつだけ使いまわせばいいように変えていく。backendsも? 
+ + /// Backend application objects to which http request handler forward incoming requests + pub backends: Backends, + + /// Shared context - Counter for serving requests + pub request_count: RequestCount, + + /// Shared context - Async task runtime handler + pub runtime_handle: tokio::runtime::Handle, + + /// Shared context - Notify object to stop async tasks + pub term_notify: Option>, +} + +/// Configuration parameters for proxy transport and request handlers +#[derive(PartialEq, Eq, Clone)] +pub struct ProxyConfig { + pub listen_sockets: Vec, // when instantiate server + pub http_port: Option, // when instantiate server + pub https_port: Option, // when instantiate server + pub tcp_listen_backlog: u32, // when instantiate server + + pub proxy_timeout: Duration, // when serving requests at Proxy + pub upstream_timeout: Duration, // when serving requests at Handler + + pub max_clients: usize, // when serving requests + pub max_concurrent_streams: u32, // when instantiate server + pub keepalive: bool, // when instantiate server + + // experimentals + pub sni_consistency: bool, // Handler + + #[cfg(feature = "cache")] + pub cache_enabled: bool, + #[cfg(feature = "cache")] + pub cache_dir: Option, + #[cfg(feature = "cache")] + pub cache_max_entry: usize, + #[cfg(feature = "cache")] + pub cache_max_each_size: usize, + #[cfg(feature = "cache")] + pub cache_max_each_size_on_memory: usize, + + // All need to make packet acceptor + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + pub http3: bool, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + pub h3_alt_svc_max_age: u32, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + pub h3_request_max_body_size: usize, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + pub h3_max_concurrent_bidistream: u32, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + pub h3_max_concurrent_unistream: u32, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + pub 
h3_max_concurrent_connections: u32, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + pub h3_max_idle_timeout: Option, +} + +impl Default for ProxyConfig { + fn default() -> Self { + Self { + listen_sockets: Vec::new(), + http_port: None, + https_port: None, + tcp_listen_backlog: TCP_LISTEN_BACKLOG, + + // TODO: Reconsider each timeout values + proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC), + upstream_timeout: Duration::from_secs(UPSTREAM_TIMEOUT_SEC), + + max_clients: MAX_CLIENTS, + max_concurrent_streams: MAX_CONCURRENT_STREAMS, + keepalive: true, + + sni_consistency: true, + + #[cfg(feature = "cache")] + cache_enabled: false, + #[cfg(feature = "cache")] + cache_dir: None, + #[cfg(feature = "cache")] + cache_max_entry: MAX_CACHE_ENTRY, + #[cfg(feature = "cache")] + cache_max_each_size: MAX_CACHE_EACH_SIZE, + #[cfg(feature = "cache")] + cache_max_each_size_on_memory: MAX_CACHE_EACH_SIZE_ON_MEMORY, + + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + http3: false, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + h3_alt_svc_max_age: H3::ALT_SVC_MAX_AGE, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + h3_request_max_body_size: H3::REQUEST_MAX_BODY_SIZE, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + h3_max_concurrent_connections: H3::MAX_CONCURRENT_CONNECTIONS, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + h3_max_concurrent_bidistream: H3::MAX_CONCURRENT_BIDISTREAM, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + h3_max_concurrent_unistream: H3::MAX_CONCURRENT_UNISTREAM, + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + h3_max_idle_timeout: Some(Duration::from_secs(H3::MAX_IDLE_TIMEOUT)), + } + } +} + +/// Configuration parameters for backend applications +#[derive(PartialEq, Eq, Clone)] +pub struct AppConfigList +where + T: CryptoSource, +{ + pub inner: Vec>, + pub default_app: Option, +} +impl TryInto> for AppConfigList +where + T: 
CryptoSource + Clone, +{ + type Error = RpxyError; + + fn try_into(self) -> Result, Self::Error> { + let mut backends = Backends::new(); + for app_config in self.inner.iter() { + let backend = app_config.try_into()?; + backends + .apps + .insert(app_config.server_name.clone().to_server_name_vec(), backend); + info!( + "Registering application {} ({})", + &app_config.server_name, &app_config.app_name + ); + } + + // default backend application for plaintext http requests + if let Some(d) = self.default_app { + let d_sn: Vec<&str> = backends + .apps + .iter() + .filter(|(_k, v)| v.app_name == d) + .map(|(_, v)| v.server_name.as_ref()) + .collect(); + if !d_sn.is_empty() { + info!( + "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", + d, d_sn[0] + ); + backends.default_server_name_bytes = Some(d_sn[0].to_server_name_vec()); + } + } + Ok(backends) + } +} + +/// Configuration parameters for single backend application +#[derive(PartialEq, Eq, Clone)] +pub struct AppConfig +where + T: CryptoSource, +{ + pub app_name: String, + pub server_name: String, + pub reverse_proxy: Vec, + pub tls: Option>, +} +impl TryInto> for &AppConfig +where + T: CryptoSource + Clone, +{ + type Error = RpxyError; + + fn try_into(self) -> Result, Self::Error> { + // backend builder + let mut backend_builder = BackendBuilder::default(); + // reverse proxy settings + let reverse_proxy = self.try_into()?; + + backend_builder + .app_name(self.app_name.clone()) + .server_name(self.server_name.clone()) + .reverse_proxy(reverse_proxy); + + // TLS settings and build backend instance + let backend = if self.tls.is_none() { + backend_builder.build().map_err(RpxyError::BackendBuild)? + } else { + let tls = self.tls.as_ref().unwrap(); + + backend_builder + .https_redirection(Some(tls.https_redirection)) + .crypto_source(Some(tls.inner.clone())) + .build()? 
+ }; + Ok(backend) + } +} +impl TryInto for &AppConfig +where + T: CryptoSource + Clone, +{ + type Error = RpxyError; + + fn try_into(self) -> Result { + let mut upstream: HashMap = HashMap::default(); + + self.reverse_proxy.iter().for_each(|rpo| { + let upstream_vec: Vec = rpo.upstream.iter().map(|x| x.try_into().unwrap()).collect(); + // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()); + // let lb_upstream_num = vec_upstream.len(); + let elem = UpstreamGroupBuilder::default() + .upstream(&upstream_vec) + .path(&rpo.path) + .replace_path(&rpo.replace_path) + .lb(&rpo.load_balance, &upstream_vec, &self.server_name, &rpo.path) + .opts(&rpo.upstream_options) + .build() + .unwrap(); + + upstream.insert(elem.path.clone(), elem); + }); + if self.reverse_proxy.iter().filter(|rpo| rpo.path.is_none()).count() >= 2 { + error!("Multiple default reverse proxy setting"); + return Err(RpxyError::ConfigBuild("Invalid reverse proxy setting")); + } + + if !(upstream.iter().all(|(_, elem)| { + !(elem.opts.contains(&UpstreamOption::ForceHttp11Upstream) + && elem.opts.contains(&UpstreamOption::ForceHttp2Upstream)) + })) { + error!("Either one of force_http11 or force_http2 can be enabled"); + return Err(RpxyError::ConfigBuild("Invalid upstream option setting")); + } + + Ok(ReverseProxy { upstream }) + } +} + +/// Configuration parameters for single reverse proxy corresponding to the path +#[derive(PartialEq, Eq, Clone)] +pub struct ReverseProxyConfig { + pub path: Option, + pub replace_path: Option, + pub upstream: Vec, + pub upstream_options: Option>, + pub load_balance: Option, +} + +/// Configuration parameters for single upstream destination from a reverse proxy +#[derive(PartialEq, Eq, Clone)] +pub struct UpstreamUri { + pub inner: hyper::Uri, +} +impl TryInto for &UpstreamUri { + type Error = anyhow::Error; + + fn try_into(self) -> std::result::Result { + Ok(Upstream { + uri: self.inner.clone(), + }) + } +} + +/// Configuration parameters on TLS for a 
single backend application +#[derive(PartialEq, Eq, Clone)] +pub struct TlsConfig +where + T: CryptoSource, +{ + pub inner: T, + pub https_redirection: bool, +} + +#[derive(Debug, Clone, Default)] +/// Counter for serving requests +pub struct RequestCount(Arc); + +impl RequestCount { + pub fn current(&self) -> usize { + self.0.load(Ordering::Relaxed) + } + + pub fn increment(&self) -> usize { + self.0.fetch_add(1, Ordering::Relaxed) + } + + pub fn decrement(&self) -> usize { + let mut count; + while { + count = self.0.load(Ordering::Relaxed); + count > 0 + && self + .0 + .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) + != Ok(count) + } {} + count + } +} diff --git a/rpxy-lib/src/handler/cache.rs b/legacy-lib/src/handler/cache.rs similarity index 100% rename from rpxy-lib/src/handler/cache.rs rename to legacy-lib/src/handler/cache.rs diff --git a/rpxy-lib/src/handler/error.rs b/legacy-lib/src/handler/error.rs similarity index 100% rename from rpxy-lib/src/handler/error.rs rename to legacy-lib/src/handler/error.rs diff --git a/rpxy-lib/src/handler/forwarder.rs b/legacy-lib/src/handler/forwarder.rs similarity index 100% rename from rpxy-lib/src/handler/forwarder.rs rename to legacy-lib/src/handler/forwarder.rs diff --git a/rpxy-lib/src/handler/handler_main.rs b/legacy-lib/src/handler/handler_main.rs similarity index 100% rename from rpxy-lib/src/handler/handler_main.rs rename to legacy-lib/src/handler/handler_main.rs diff --git a/rpxy-lib/src/handler/mod.rs b/legacy-lib/src/handler/mod.rs similarity index 100% rename from rpxy-lib/src/handler/mod.rs rename to legacy-lib/src/handler/mod.rs diff --git a/rpxy-lib/src/handler/utils_headers.rs b/legacy-lib/src/handler/utils_headers.rs similarity index 100% rename from rpxy-lib/src/handler/utils_headers.rs rename to legacy-lib/src/handler/utils_headers.rs diff --git a/rpxy-lib/src/handler/utils_request.rs b/legacy-lib/src/handler/utils_request.rs similarity index 100% rename from 
rpxy-lib/src/handler/utils_request.rs rename to legacy-lib/src/handler/utils_request.rs diff --git a/rpxy-lib/src/handler/utils_synth_response.rs b/legacy-lib/src/handler/utils_synth_response.rs similarity index 100% rename from rpxy-lib/src/handler/utils_synth_response.rs rename to legacy-lib/src/handler/utils_synth_response.rs diff --git a/rpxy-lib/src/hyper_executor.rs b/legacy-lib/src/hyper_executor.rs similarity index 100% rename from rpxy-lib/src/hyper_executor.rs rename to legacy-lib/src/hyper_executor.rs diff --git a/legacy-lib/src/lib.rs b/legacy-lib/src/lib.rs new file mode 100644 index 0000000..a9f48ab --- /dev/null +++ b/legacy-lib/src/lib.rs @@ -0,0 +1,112 @@ +mod backend; +mod certs; +mod constants; +mod error; +mod globals; +mod handler; +mod hyper_executor; +mod log; +mod proxy; +mod utils; + +use crate::{error::*, globals::Globals, handler::HttpMessageHandlerBuilder, log::*, proxy::ProxyBuilder}; +use futures::future::select_all; +use hyper_executor::build_http_server; +use std::sync::Arc; + +pub use crate::{ + certs::{CertsAndKeys, CryptoSource}, + globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri}, +}; +pub mod reexports { + pub use hyper::Uri; + pub use rustls::{Certificate, PrivateKey}; +} + +#[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] +compile_error!("feature \"http3-quinn\" and feature \"http3-s2n\" cannot be enabled at the same time"); + +/// Entrypoint that creates and spawns tasks of reverse proxy services +pub async fn entrypoint( + proxy_config: &ProxyConfig, + app_config_list: &AppConfigList, + runtime_handle: &tokio::runtime::Handle, + term_notify: Option>, +) -> Result<()> +where + T: CryptoSource + Clone + Send + Sync + 'static, +{ + // For initial message logging + if proxy_config.listen_sockets.iter().any(|addr| addr.is_ipv6()) { + info!("Listen both IPv4 and IPv6") + } else { + info!("Listen IPv4") + } + if proxy_config.http_port.is_some() { + info!("Listen port: {}", 
proxy_config.http_port.unwrap()); + } + if proxy_config.https_port.is_some() { + info!("Listen port: {} (for TLS)", proxy_config.https_port.unwrap()); + } + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + if proxy_config.http3 { + info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable."); + } + if !proxy_config.sni_consistency { + info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC."); + } + #[cfg(feature = "cache")] + if proxy_config.cache_enabled { + info!( + "Cache is enabled: cache dir = {:?}", + proxy_config.cache_dir.as_ref().unwrap() + ); + } else { + info!("Cache is disabled") + } + + // build global + let globals = Arc::new(Globals { + proxy_config: proxy_config.clone(), + backends: app_config_list.clone().try_into()?, + request_count: Default::default(), + runtime_handle: runtime_handle.clone(), + term_notify: term_notify.clone(), + }); + + // build message handler including a request forwarder + let msg_handler = Arc::new( + HttpMessageHandlerBuilder::default() + // .forwarder(Arc::new(Forwarder::new(&globals).await)) + .globals(globals.clone()) + .build()?, + ); + + let http_server = Arc::new(build_http_server(&globals)); + + let addresses = globals.proxy_config.listen_sockets.clone(); + let futures = select_all(addresses.into_iter().map(|addr| { + let mut tls_enabled = false; + if let Some(https_port) = globals.proxy_config.https_port { + tls_enabled = https_port == addr.port() + } + + let proxy = ProxyBuilder::default() + .globals(globals.clone()) + .listening_on(addr) + .tls_enabled(tls_enabled) + .http_server(http_server.clone()) + .msg_handler(msg_handler.clone()) + .build() + .unwrap(); + + globals.runtime_handle.spawn(async move { proxy.start().await }) + })); + + // wait for all future + if let (Ok(Err(e)), _, _) = futures.await { + error!("Some proxy services are down: {}", e); + }; + + Ok(()) +} diff --git a/legacy-lib/src/log.rs b/legacy-lib/src/log.rs new file mode 
100644 index 0000000..6b8afbe --- /dev/null +++ b/legacy-lib/src/log.rs @@ -0,0 +1,98 @@ +use crate::utils::ToCanonical; +use hyper::header; +use std::net::SocketAddr; +pub use tracing::{debug, error, info, warn}; + +#[derive(Debug, Clone)] +pub struct MessageLog { + // pub tls_server_name: String, + pub client_addr: String, + pub method: String, + pub host: String, + pub p_and_q: String, + pub version: hyper::Version, + pub uri_scheme: String, + pub uri_host: String, + pub ua: String, + pub xff: String, + pub status: String, + pub upstream: String, +} + +impl From<&hyper::Request> for MessageLog { + fn from(req: &hyper::Request) -> Self { + let header_mapper = |v: header::HeaderName| { + req + .headers() + .get(v) + .map_or_else(|| "", |s| s.to_str().unwrap_or("")) + .to_string() + }; + Self { + // tls_server_name: "".to_string(), + client_addr: "".to_string(), + method: req.method().to_string(), + host: header_mapper(header::HOST), + p_and_q: req + .uri() + .path_and_query() + .map_or_else(|| "", |v| v.as_str()) + .to_string(), + version: req.version(), + uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(), + uri_host: req.uri().host().unwrap_or("").to_string(), + ua: header_mapper(header::USER_AGENT), + xff: header_mapper(header::HeaderName::from_static("x-forwarded-for")), + status: "".to_string(), + upstream: "".to_string(), + } + } +} + +impl MessageLog { + pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self { + self.client_addr = client_addr.to_canonical().to_string(); + self + } + // pub fn tls_server_name(&mut self, tls_server_name: &str) -> &mut Self { + // self.tls_server_name = tls_server_name.to_string(); + // self + // } + pub fn status_code(&mut self, status_code: &hyper::StatusCode) -> &mut Self { + self.status = status_code.to_string(); + self + } + pub fn xff(&mut self, xff: &Option<&header::HeaderValue>) -> &mut Self { + self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); + self + } + pub fn 
upstream(&mut self, upstream: &hyper::Uri) -> &mut Self { + self.upstream = upstream.to_string(); + self + } + + pub fn output(&self) { + info!( + "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"", + if !self.host.is_empty() { + self.host.as_str() + } else { + self.uri_host.as_str() + }, + self.client_addr, + self.method, + self.p_and_q, + self.version, + self.status, + if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() { + format!("{}://{}", self.uri_scheme, self.uri_host) + } else { + "".to_string() + }, + self.ua, + self.xff, + self.upstream, + // self.tls_server_name + ); + } +} diff --git a/rpxy-lib/src/proxy/crypto_service.rs b/legacy-lib/src/proxy/crypto_service.rs similarity index 100% rename from rpxy-lib/src/proxy/crypto_service.rs rename to legacy-lib/src/proxy/crypto_service.rs diff --git a/rpxy-lib/src/proxy/mod.rs b/legacy-lib/src/proxy/mod.rs similarity index 100% rename from rpxy-lib/src/proxy/mod.rs rename to legacy-lib/src/proxy/mod.rs diff --git a/rpxy-lib/src/proxy/proxy_client_cert.rs b/legacy-lib/src/proxy/proxy_client_cert.rs similarity index 100% rename from rpxy-lib/src/proxy/proxy_client_cert.rs rename to legacy-lib/src/proxy/proxy_client_cert.rs diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/legacy-lib/src/proxy/proxy_h3.rs similarity index 100% rename from rpxy-lib/src/proxy/proxy_h3.rs rename to legacy-lib/src/proxy/proxy_h3.rs diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/legacy-lib/src/proxy/proxy_main.rs similarity index 100% rename from rpxy-lib/src/proxy/proxy_main.rs rename to legacy-lib/src/proxy/proxy_main.rs diff --git a/rpxy-lib/src/proxy/proxy_quic_quinn.rs b/legacy-lib/src/proxy/proxy_quic_quinn.rs similarity index 100% rename from rpxy-lib/src/proxy/proxy_quic_quinn.rs rename to legacy-lib/src/proxy/proxy_quic_quinn.rs diff --git a/rpxy-lib/src/proxy/proxy_quic_s2n.rs b/legacy-lib/src/proxy/proxy_quic_s2n.rs similarity index 100% rename from rpxy-lib/src/proxy/proxy_quic_s2n.rs rename to 
legacy-lib/src/proxy/proxy_quic_s2n.rs diff --git a/rpxy-lib/src/proxy/proxy_tls.rs b/legacy-lib/src/proxy/proxy_tls.rs similarity index 100% rename from rpxy-lib/src/proxy/proxy_tls.rs rename to legacy-lib/src/proxy/proxy_tls.rs diff --git a/rpxy-lib/src/proxy/socket.rs b/legacy-lib/src/proxy/socket.rs similarity index 100% rename from rpxy-lib/src/proxy/socket.rs rename to legacy-lib/src/proxy/socket.rs diff --git a/rpxy-lib/src/utils/bytes_name.rs b/legacy-lib/src/utils/bytes_name.rs similarity index 100% rename from rpxy-lib/src/utils/bytes_name.rs rename to legacy-lib/src/utils/bytes_name.rs diff --git a/rpxy-lib/src/utils/mod.rs b/legacy-lib/src/utils/mod.rs similarity index 100% rename from rpxy-lib/src/utils/mod.rs rename to legacy-lib/src/utils/mod.rs diff --git a/rpxy-lib/src/utils/socket_addr.rs b/legacy-lib/src/utils/socket_addr.rs similarity index 100% rename from rpxy-lib/src/utils/socket_addr.rs rename to legacy-lib/src/utils/socket_addr.rs diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index dd6e744..0d72e56 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -12,15 +12,15 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-s2n", "cache"] -http3-quinn = ["rpxy-lib/http3-quinn"] -http3-s2n = ["rpxy-lib/http3-s2n"] -cache = ["rpxy-lib/cache"] -native-roots = ["rpxy-lib/native-roots"] +# default = ["http3-s2n", "cache"] +# http3-quinn = ["rpxy-lib/http3-quinn"] +# http3-s2n = ["rpxy-lib/http3-s2n"] +# cache = ["rpxy-lib/cache"] +# native-roots = ["rpxy-lib/native-roots"] [dependencies] rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [ - "sticky-cookie", + # "sticky-cookie", ] } anyhow = "1.0.75" diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index fae0c3c..3db2b42 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -12,19 +12,19 @@ publish = false # See more keys and their definitions at 
https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-s2n", "sticky-cookie", "cache"] -http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] -http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] -sticky-cookie = ["base64", "sha2", "chrono"] -cache = ["http-cache-semantics", "lru"] -native-roots = ["hyper-rustls/native-tokio"] +# default = ["http3-s2n", "sticky-cookie", "cache"] +# http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] +# http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] +# sticky-cookie = ["base64", "sha2", "chrono"] +# cache = ["http-cache-semantics", "lru"] +# native-roots = ["hyper-rustls/native-tokio"] [dependencies] -rand = "0.8.5" -rustc-hash = "1.1.0" -bytes = "1.5.0" -derive_builder = "0.12.0" -futures = { version = "0.3.29", features = ["alloc", "async-await"] } +# rand = "0.8.5" +# rustc-hash = "1.1.0" +# bytes = "1.5.0" +# derive_builder = "0.12.0" +# futures = { version = "0.3.29", features = ["alloc", "async-await"] } tokio = { version = "1.34.0", default-features = false, features = [ "net", "rt-multi-thread", @@ -34,7 +34,7 @@ tokio = { version = "1.34.0", default-features = false, features = [ "fs", ] } async-trait = "0.1.74" -hot_reload = "0.1.4" # reloading certs +# hot_reload = "0.1.4" # reloading certs # Error handling anyhow = "1.0.75" @@ -42,48 +42,48 @@ thiserror = "1.0.50" # http and tls http = "1.0.0" -http-body-util = "0.1.0" +# http-body-util = "0.1.0" hyper = { version = "1.0.1", default-features = false } -hyper-util = { version = "0.1.1", features = ["full"] } -hyper-rustls = { version = "0.24.2", default-features = false, features = [ - "tokio-runtime", - "webpki-tokio", - "http1", - "http2", -] } -tokio-rustls = { version = "0.24.1", features = ["early-data"] } +# hyper-util = { version = "0.1.1", features = ["full"] } +# hyper-rustls = { version = "0.24.2", default-features = false, features = [ +# "tokio-runtime", +# "webpki-tokio", +# "http1", +# "http2", +# 
] } +# tokio-rustls = { version = "0.24.1", features = ["early-data"] } rustls = { version = "0.21.9", default-features = false } -webpki = "0.22.4" -x509-parser = "0.15.1" +# webpki = "0.22.4" +# x509-parser = "0.15.1" # logging tracing = { version = "0.1.40" } -# http/3 -quinn = { version = "0.10.2", optional = true } -h3 = { path = "../submodules/h3/h3/", optional = true } -h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } -s2n-quic = { version = "1.31.0", default-features = false, features = [ - "provider-tls-rustls", -], optional = true } -s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } -s2n-quic-rustls = { version = "0.31.0", optional = true } -# for UDP socket wit SO_REUSEADDR when h3 with quinn -socket2 = { version = "0.5.5", features = ["all"], optional = true } +# # http/3 +# quinn = { version = "0.10.2", optional = true } +# h3 = { path = "../submodules/h3/h3/", optional = true } +# h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } +# s2n-quic = { version = "1.31.0", default-features = false, features = [ +# "provider-tls-rustls", +# ], optional = true } +# s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } +# s2n-quic-rustls = { version = "0.31.0", optional = true } +# # for UDP socket wit SO_REUSEADDR when h3 with quinn +# socket2 = { version = "0.5.5", features = ["all"], optional = true } -# cache -http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } -lru = { version = "0.12.0", optional = true } +# # cache +# http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } +# lru = { version = "0.12.0", optional = true } -# cookie handling for sticky cookie -chrono = { version = "0.4.31", default-features = false, features = [ - "unstable-locales", - "alloc", - "clock", -], optional = true } -base64 = { version = "0.21.5", optional = true } -sha2 = { version = "0.10.8", default-features = false, 
optional = true } +# # cookie handling for sticky cookie +# chrono = { version = "0.4.31", default-features = false, features = [ +# "unstable-locales", +# "alloc", +# "clock", +# ], optional = true } +# base64 = { version = "0.21.5", optional = true } +# sha2 = { version = "0.10.8", default-features = false, optional = true } -[dev-dependencies] -# http and tls +# [dev-dependencies] +# # http and tls diff --git a/rpxy-lib/src/certs.rs b/rpxy-lib/src/certs.rs index c9cfafd..b93aa8f 100644 --- a/rpxy-lib/src/certs.rs +++ b/rpxy-lib/src/certs.rs @@ -1,11 +1,5 @@ use async_trait::async_trait; -use rustc_hash::FxHashSet as HashSet; -use rustls::{ - sign::{any_supported_type, CertifiedKey}, - Certificate, OwnedTrustAnchor, PrivateKey, -}; -use std::io; -use x509_parser::prelude::*; +use rustls::{Certificate, PrivateKey}; #[async_trait] // Trait to read certs and keys anywhere from KVS, file, sqlite, etc. @@ -26,66 +20,3 @@ pub struct CertsAndKeys { pub cert_keys: Vec, pub client_ca_certs: Option>, } - -impl CertsAndKeys { - pub fn parse_server_certs_and_keys(&self) -> Result { - // for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() { - let signing_key = self - .cert_keys - .iter() - .find_map(|k| { - if let Ok(sk) = any_supported_type(k) { - Some(sk) - } else { - None - } - }) - .ok_or_else(|| { - io::Error::new( - io::ErrorKind::InvalidInput, - "Unable to find a valid certificate and key", - ) - })?; - Ok(CertifiedKey::new(self.certs.clone(), signing_key)) - } - - pub fn parse_client_ca_certs(&self) -> Result<(Vec, HashSet>), anyhow::Error> { - let certs = self.client_ca_certs.as_ref().ok_or(anyhow::anyhow!("No client cert"))?; - - let owned_trust_anchors: Vec<_> = certs - .iter() - .map(|v| { - // let trust_anchor = tokio_rustls::webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap(); - let trust_anchor = webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap(); - rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( - trust_anchor.subject, - 
trust_anchor.spki, - trust_anchor.name_constraints, - ) - }) - .collect(); - - // TODO: SKID is not used currently - let subject_key_identifiers: HashSet<_> = certs - .iter() - .filter_map(|v| { - // retrieve ca key id (subject key id) - let cert = parse_x509_certificate(&v.0).unwrap().1; - let subject_key_ids = cert - .iter_extensions() - .filter_map(|ext| match ext.parsed_extension() { - ParsedExtension::SubjectKeyIdentifier(skid) => Some(skid), - _ => None, - }) - .collect::>(); - if !subject_key_ids.is_empty() { - Some(subject_key_ids[0].0.to_owned()) - } else { - None - } - }) - .collect(); - - Ok((owned_trust_anchors, subject_key_identifiers)) - } -} diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index c672682..34769dc 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -1,86 +1,8 @@ pub use anyhow::{anyhow, bail, ensure, Context}; -use std::io; use thiserror::Error; pub type Result = std::result::Result; /// Describes things that can go wrong in the Rpxy #[derive(Debug, Error)] -pub enum RpxyError { - #[error("Proxy build error: {0}")] - ProxyBuild(#[from] crate::proxy::ProxyBuilderError), - - #[error("Backend build error: {0}")] - BackendBuild(#[from] crate::backend::BackendBuilderError), - - #[error("MessageHandler build error: {0}")] - HandlerBuild(#[from] crate::handler::HttpMessageHandlerBuilderError), - - #[error("Config builder error: {0}")] - ConfigBuild(&'static str), - - #[error("Http Message Handler Error: {0}")] - Handler(&'static str), - - #[error("Cache Error: {0}")] - Cache(&'static str), - - #[error("Http Request Message Error: {0}")] - Request(&'static str), - - #[error("TCP/UDP Proxy Layer Error: {0}")] - Proxy(String), - - #[allow(unused)] - #[error("LoadBalance Layer Error: {0}")] - LoadBalance(String), - - #[error("I/O Error: {0}")] - Io(#[from] io::Error), - - // #[error("Toml Deserialization Error")] - // TomlDe(#[from] toml::de::Error), - #[cfg(feature = "http3-quinn")] - #[error("Quic Connection Error 
[quinn]: {0}")] - QuicConn(#[from] quinn::ConnectionError), - - #[cfg(feature = "http3-s2n")] - #[error("Quic Connection Error [s2n-quic]: {0}")] - QUicConn(#[from] s2n_quic::connection::Error), - - #[cfg(feature = "http3-quinn")] - #[error("H3 Error [quinn]: {0}")] - H3(#[from] h3::Error), - - #[cfg(feature = "http3-s2n")] - #[error("H3 Error [s2n-quic]: {0}")] - H3(#[from] s2n_quic_h3::h3::Error), - - #[error("rustls Connection Error: {0}")] - Rustls(#[from] rustls::Error), - - #[error("Hyper Error: {0}")] - Hyper(#[from] hyper::Error), - - #[error("Hyper Http Error: {0}")] - HyperHttp(#[from] hyper::http::Error), - - #[error("Hyper Http HeaderValue Error: {0}")] - HyperHeaderValue(#[from] hyper::header::InvalidHeaderValue), - - #[error("Hyper Http HeaderName Error: {0}")] - HyperHeaderName(#[from] hyper::header::InvalidHeaderName), - - #[error(transparent)] - Other(#[from] anyhow::Error), -} - -#[allow(dead_code)] -#[derive(Debug, Error, Clone)] -pub enum ClientCertsError { - #[error("TLS Client Certificate is Required for Given SNI: {0}")] - ClientCertRequired(String), - - #[error("Inconsistent TLS Client Certificate for Given SNI: {0}")] - InconsistentClientCert(String), -} +pub enum RpxyError {} diff --git a/rpxy-lib/src/globals.rs b/rpxy-lib/src/globals.rs index 02605a6..2db2805 100644 --- a/rpxy-lib/src/globals.rs +++ b/rpxy-lib/src/globals.rs @@ -1,42 +1,5 @@ -use crate::{ - backend::{ - Backend, BackendBuilder, Backends, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption, - }, - certs::CryptoSource, - constants::*, - error::RpxyError, - log::*, - utils::{BytesName, PathNameBytesExp}, -}; -use rustc_hash::FxHashMap as HashMap; -use std::net::SocketAddr; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}; -use tokio::time::Duration; - -/// Global object containing proxy configurations and shared object like counters. -/// But note that in Globals, we do not have Mutex and RwLock. 
It is indeed, the context shared among async tasks. -pub struct Globals -where - T: CryptoSource, -{ - /// Configuration parameters for proxy transport and request handlers - pub proxy_config: ProxyConfig, // TODO: proxy configはarcに包んでこいつだけ使いまわせばいいように変えていく。backendsも? - - /// Backend application objects to which http request handler forward incoming requests - pub backends: Backends, - - /// Shared context - Counter for serving requests - pub request_count: RequestCount, - - /// Shared context - Async task runtime handler - pub runtime_handle: tokio::runtime::Handle, - - /// Shared context - Notify object to stop async tasks - pub term_notify: Option>, -} +use crate::{certs::CryptoSource, constants::*}; +use std::{net::SocketAddr, time::Duration}; /// Configuration parameters for proxy transport and request handlers #[derive(PartialEq, Eq, Clone)] @@ -140,44 +103,6 @@ where pub inner: Vec>, pub default_app: Option, } -impl TryInto> for AppConfigList -where - T: CryptoSource + Clone, -{ - type Error = RpxyError; - - fn try_into(self) -> Result, Self::Error> { - let mut backends = Backends::new(); - for app_config in self.inner.iter() { - let backend = app_config.try_into()?; - backends - .apps - .insert(app_config.server_name.clone().to_server_name_vec(), backend); - info!( - "Registering application {} ({})", - &app_config.server_name, &app_config.app_name - ); - } - - // default backend application for plaintext http requests - if let Some(d) = self.default_app { - let d_sn: Vec<&str> = backends - .apps - .iter() - .filter(|(_k, v)| v.app_name == d) - .map(|(_, v)| v.server_name.as_ref()) - .collect(); - if !d_sn.is_empty() { - info!( - "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", - d, d_sn[0] - ); - backends.default_server_name_bytes = Some(d_sn[0].to_server_name_vec()); - } - } - Ok(backends) - } -} /// Configuration parameters for single backend application #[derive(PartialEq, Eq, Clone)] @@ -190,77 +115,6 @@ 
where pub reverse_proxy: Vec, pub tls: Option>, } -impl TryInto> for &AppConfig -where - T: CryptoSource + Clone, -{ - type Error = RpxyError; - - fn try_into(self) -> Result, Self::Error> { - // backend builder - let mut backend_builder = BackendBuilder::default(); - // reverse proxy settings - let reverse_proxy = self.try_into()?; - - backend_builder - .app_name(self.app_name.clone()) - .server_name(self.server_name.clone()) - .reverse_proxy(reverse_proxy); - - // TLS settings and build backend instance - let backend = if self.tls.is_none() { - backend_builder.build().map_err(RpxyError::BackendBuild)? - } else { - let tls = self.tls.as_ref().unwrap(); - - backend_builder - .https_redirection(Some(tls.https_redirection)) - .crypto_source(Some(tls.inner.clone())) - .build()? - }; - Ok(backend) - } -} -impl TryInto for &AppConfig -where - T: CryptoSource + Clone, -{ - type Error = RpxyError; - - fn try_into(self) -> Result { - let mut upstream: HashMap = HashMap::default(); - - self.reverse_proxy.iter().for_each(|rpo| { - let upstream_vec: Vec = rpo.upstream.iter().map(|x| x.try_into().unwrap()).collect(); - // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()); - // let lb_upstream_num = vec_upstream.len(); - let elem = UpstreamGroupBuilder::default() - .upstream(&upstream_vec) - .path(&rpo.path) - .replace_path(&rpo.replace_path) - .lb(&rpo.load_balance, &upstream_vec, &self.server_name, &rpo.path) - .opts(&rpo.upstream_options) - .build() - .unwrap(); - - upstream.insert(elem.path.clone(), elem); - }); - if self.reverse_proxy.iter().filter(|rpo| rpo.path.is_none()).count() >= 2 { - error!("Multiple default reverse proxy setting"); - return Err(RpxyError::ConfigBuild("Invalid reverse proxy setting")); - } - - if !(upstream.iter().all(|(_, elem)| { - !(elem.opts.contains(&UpstreamOption::ForceHttp11Upstream) - && elem.opts.contains(&UpstreamOption::ForceHttp2Upstream)) - })) { - error!("Either one of force_http11 or force_http2 can be 
enabled"); - return Err(RpxyError::ConfigBuild("Invalid upstream option setting")); - } - - Ok(ReverseProxy { upstream }) - } -} /// Configuration parameters for single reverse proxy corresponding to the path #[derive(PartialEq, Eq, Clone)] @@ -275,16 +129,7 @@ pub struct ReverseProxyConfig { /// Configuration parameters for single upstream destination from a reverse proxy #[derive(PartialEq, Eq, Clone)] pub struct UpstreamUri { - pub inner: hyper::Uri, -} -impl TryInto for &UpstreamUri { - type Error = anyhow::Error; - - fn try_into(self) -> std::result::Result { - Ok(Upstream { - uri: self.inner.clone(), - }) - } + pub inner: http::Uri, } /// Configuration parameters on TLS for a single backend application @@ -296,30 +141,3 @@ where pub inner: T, pub https_redirection: bool, } - -#[derive(Debug, Clone, Default)] -/// Counter for serving requests -pub struct RequestCount(Arc); - -impl RequestCount { - pub fn current(&self) -> usize { - self.0.load(Ordering::Relaxed) - } - - pub fn increment(&self) -> usize { - self.0.fetch_add(1, Ordering::Relaxed) - } - - pub fn decrement(&self) -> usize { - let mut count; - while { - count = self.0.load(Ordering::Relaxed); - count > 0 - && self - .0 - .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) - != Ok(count) - } {} - count - } -} diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index 7f7ade2..b201b52 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -1,19 +1,11 @@ -mod backend; mod certs; mod constants; mod error; mod globals; -mod handler; -mod hyper_executor; mod log; -mod proxy; -mod utils; -use crate::{error::*, globals::Globals, handler::HttpMessageHandlerBuilder, log::*, proxy::ProxyBuilder}; -use futures::future::select_all; -use hyper_executor::build_http_server; -// use hyper_trust_dns::TrustDnsResolver; -use std::{sync::Arc, time::Duration}; +use crate::{error::*, log::*}; +use std::sync::Arc; pub use crate::{ certs::{CertsAndKeys, CryptoSource}, @@ -66,48 +58,48 @@ 
where info!("Cache is disabled") } - // build global - let globals = Arc::new(Globals { - proxy_config: proxy_config.clone(), - backends: app_config_list.clone().try_into()?, - request_count: Default::default(), - runtime_handle: runtime_handle.clone(), - term_notify: term_notify.clone(), - }); + // // build global + // let globals = Arc::new(Globals { + // proxy_config: proxy_config.clone(), + // backends: app_config_list.clone().try_into()?, + // request_count: Default::default(), + // runtime_handle: runtime_handle.clone(), + // term_notify: term_notify.clone(), + // }); - // build message handler including a request forwarder - let msg_handler = Arc::new( - HttpMessageHandlerBuilder::default() - // .forwarder(Arc::new(Forwarder::new(&globals).await)) - .globals(globals.clone()) - .build()?, - ); + // // build message handler including a request forwarder + // let msg_handler = Arc::new( + // HttpMessageHandlerBuilder::default() + // // .forwarder(Arc::new(Forwarder::new(&globals).await)) + // .globals(globals.clone()) + // .build()?, + // ); - let http_server = Arc::new(build_http_server(&globals)); + // let http_server = Arc::new(build_http_server(&globals)); - let addresses = globals.proxy_config.listen_sockets.clone(); - let futures = select_all(addresses.into_iter().map(|addr| { - let mut tls_enabled = false; - if let Some(https_port) = globals.proxy_config.https_port { - tls_enabled = https_port == addr.port() - } + // let addresses = globals.proxy_config.listen_sockets.clone(); + // let futures = select_all(addresses.into_iter().map(|addr| { + // let mut tls_enabled = false; + // if let Some(https_port) = globals.proxy_config.https_port { + // tls_enabled = https_port == addr.port() + // } - let proxy = ProxyBuilder::default() - .globals(globals.clone()) - .listening_on(addr) - .tls_enabled(tls_enabled) - .http_server(http_server.clone()) - .msg_handler(msg_handler.clone()) - .build() - .unwrap(); + // let proxy = ProxyBuilder::default() + // 
.globals(globals.clone()) + // .listening_on(addr) + // .tls_enabled(tls_enabled) + // .http_server(http_server.clone()) + // .msg_handler(msg_handler.clone()) + // .build() + // .unwrap(); - globals.runtime_handle.spawn(async move { proxy.start().await }) - })); + // globals.runtime_handle.spawn(async move { proxy.start().await }) + // })); - // wait for all future - if let (Ok(Err(e)), _, _) = futures.await { - error!("Some proxy services are down: {}", e); - }; + // // wait for all future + // if let (Ok(Err(e)), _, _) = futures.await { + // error!("Some proxy services are down: {}", e); + // }; Ok(()) } diff --git a/rpxy-lib/src/log.rs b/rpxy-lib/src/log.rs index 6b8afbe..c55b5c2 100644 --- a/rpxy-lib/src/log.rs +++ b/rpxy-lib/src/log.rs @@ -1,98 +1 @@ -use crate::utils::ToCanonical; -use hyper::header; -use std::net::SocketAddr; pub use tracing::{debug, error, info, warn}; - -#[derive(Debug, Clone)] -pub struct MessageLog { - // pub tls_server_name: String, - pub client_addr: String, - pub method: String, - pub host: String, - pub p_and_q: String, - pub version: hyper::Version, - pub uri_scheme: String, - pub uri_host: String, - pub ua: String, - pub xff: String, - pub status: String, - pub upstream: String, -} - -impl From<&hyper::Request> for MessageLog { - fn from(req: &hyper::Request) -> Self { - let header_mapper = |v: header::HeaderName| { - req - .headers() - .get(v) - .map_or_else(|| "", |s| s.to_str().unwrap_or("")) - .to_string() - }; - Self { - // tls_server_name: "".to_string(), - client_addr: "".to_string(), - method: req.method().to_string(), - host: header_mapper(header::HOST), - p_and_q: req - .uri() - .path_and_query() - .map_or_else(|| "", |v| v.as_str()) - .to_string(), - version: req.version(), - uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(), - uri_host: req.uri().host().unwrap_or("").to_string(), - ua: header_mapper(header::USER_AGENT), - xff: header_mapper(header::HeaderName::from_static("x-forwarded-for")), - status: 
"".to_string(), - upstream: "".to_string(), - } - } -} - -impl MessageLog { - pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self { - self.client_addr = client_addr.to_canonical().to_string(); - self - } - // pub fn tls_server_name(&mut self, tls_server_name: &str) -> &mut Self { - // self.tls_server_name = tls_server_name.to_string(); - // self - // } - pub fn status_code(&mut self, status_code: &hyper::StatusCode) -> &mut Self { - self.status = status_code.to_string(); - self - } - pub fn xff(&mut self, xff: &Option<&header::HeaderValue>) -> &mut Self { - self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); - self - } - pub fn upstream(&mut self, upstream: &hyper::Uri) -> &mut Self { - self.upstream = upstream.to_string(); - self - } - - pub fn output(&self) { - info!( - "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"", - if !self.host.is_empty() { - self.host.as_str() - } else { - self.uri_host.as_str() - }, - self.client_addr, - self.method, - self.p_and_q, - self.version, - self.status, - if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() { - format!("{}://{}", self.uri_scheme, self.uri_host) - } else { - "".to_string() - }, - self.ua, - self.xff, - self.upstream, - // self.tls_server_name - ); - } -} From de91c7a68f0d30ae57f757eb03e9d825ddd4a648 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Wed, 22 Nov 2023 22:48:14 +0900 Subject: [PATCH 04/50] wip: refactoring all the structure and improve error messages --- rpxy-bin/Cargo.toml | 10 ++--- rpxy-lib/Cargo.toml | 20 ++++----- rpxy-lib/src/count.rs | 31 ++++++++++++++ rpxy-lib/src/error.rs | 7 ++- rpxy-lib/src/globals.rs | 29 ++++++++++--- rpxy-lib/src/hyper_executor.rs | 23 ++++++++++ rpxy-lib/src/lib.rs | 73 +++++++++++++++++--------------- rpxy-lib/src/proxy/mod.rs | 22 ++++++++++ rpxy-lib/src/proxy/proxy_main.rs | 63 +++++++++++++++++++++++++++ rpxy-lib/src/proxy/socket.rs | 46 ++++++++++++++++++++ 10 files changed, 268 insertions(+), 56 
deletions(-) create mode 100644 rpxy-lib/src/count.rs create mode 100644 rpxy-lib/src/hyper_executor.rs create mode 100644 rpxy-lib/src/proxy/mod.rs create mode 100644 rpxy-lib/src/proxy/proxy_main.rs create mode 100644 rpxy-lib/src/proxy/socket.rs diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index 0d72e56..8cddc71 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -12,11 +12,11 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -# default = ["http3-s2n", "cache"] -# http3-quinn = ["rpxy-lib/http3-quinn"] -# http3-s2n = ["rpxy-lib/http3-s2n"] -# cache = ["rpxy-lib/cache"] -# native-roots = ["rpxy-lib/native-roots"] +default = ["http3-s2n", "cache"] +http3-quinn = ["rpxy-lib/http3-quinn"] +http3-s2n = ["rpxy-lib/http3-s2n"] +cache = ["rpxy-lib/cache"] +native-roots = ["rpxy-lib/native-roots"] [dependencies] rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [ diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 3db2b42..cb6b5b6 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -12,19 +12,19 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -# default = ["http3-s2n", "sticky-cookie", "cache"] -# http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] -# http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] -# sticky-cookie = ["base64", "sha2", "chrono"] -# cache = ["http-cache-semantics", "lru"] -# native-roots = ["hyper-rustls/native-tokio"] +default = ["http3-s2n", "sticky-cookie", "cache"] +http3-quinn = ["socket2"] #"quinn", "h3", "h3-quinn", ] +http3-s2n = [] #"h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] +sticky-cookie = [] #"base64", "sha2", "chrono"] +cache = [] #"http-cache-semantics", "lru"] +native-roots = [] #"hyper-rustls/native-tokio"] [dependencies] # rand = "0.8.5" # rustc-hash = "1.1.0" # bytes = "1.5.0" # derive_builder = 
"0.12.0" -# futures = { version = "0.3.29", features = ["alloc", "async-await"] } +futures = { version = "0.3.29", features = ["alloc", "async-await"] } tokio = { version = "1.34.0", default-features = false, features = [ "net", "rt-multi-thread", @@ -44,7 +44,7 @@ thiserror = "1.0.50" http = "1.0.0" # http-body-util = "0.1.0" hyper = { version = "1.0.1", default-features = false } -# hyper-util = { version = "0.1.1", features = ["full"] } +hyper-util = { version = "0.1.1", features = ["full"] } # hyper-rustls = { version = "0.24.2", default-features = false, features = [ # "tokio-runtime", # "webpki-tokio", @@ -68,8 +68,8 @@ tracing = { version = "0.1.40" } # ], optional = true } # s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } # s2n-quic-rustls = { version = "0.31.0", optional = true } -# # for UDP socket wit SO_REUSEADDR when h3 with quinn -# socket2 = { version = "0.5.5", features = ["all"], optional = true } +# for UDP socket wit SO_REUSEADDR when h3 with quinn +socket2 = { version = "0.5.5", features = ["all"], optional = true } # # cache # http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } diff --git a/rpxy-lib/src/count.rs b/rpxy-lib/src/count.rs new file mode 100644 index 0000000..2ca4028 --- /dev/null +++ b/rpxy-lib/src/count.rs @@ -0,0 +1,31 @@ +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +#[derive(Debug, Clone, Default)] +/// Counter for serving requests +pub struct RequestCount(Arc); + +impl RequestCount { + pub fn current(&self) -> usize { + self.0.load(Ordering::Relaxed) + } + + pub fn increment(&self) -> usize { + self.0.fetch_add(1, Ordering::Relaxed) + } + + pub fn decrement(&self) -> usize { + let mut count; + while { + count = self.0.load(Ordering::Relaxed); + count > 0 + && self + .0 + .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) + != Ok(count) + } {} + count + } +} diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs 
index 34769dc..7662a9d 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -1,8 +1,11 @@ pub use anyhow::{anyhow, bail, ensure, Context}; use thiserror::Error; -pub type Result = std::result::Result; +pub type RpxyResult = std::result::Result; /// Describes things that can go wrong in the Rpxy #[derive(Debug, Error)] -pub enum RpxyError {} +pub enum RpxyError { + #[error("IO error: {0}")] + Io(#[from] std::io::Error), +} diff --git a/rpxy-lib/src/globals.rs b/rpxy-lib/src/globals.rs index 2db2805..88d6cbf 100644 --- a/rpxy-lib/src/globals.rs +++ b/rpxy-lib/src/globals.rs @@ -1,13 +1,30 @@ -use crate::{certs::CryptoSource, constants::*}; -use std::{net::SocketAddr, time::Duration}; +use crate::{certs::CryptoSource, constants::*, count::RequestCount}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; + +/// Global object containing proxy configurations and shared object like counters. +/// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks. 
+pub struct Globals { + /// Configuration parameters for proxy transport and request handlers + pub proxy_config: ProxyConfig, + /// Shared context - Counter for serving requests + pub request_count: RequestCount, + /// Shared context - Async task runtime handler + pub runtime_handle: tokio::runtime::Handle, + /// Shared context - Notify object to stop async tasks + pub term_notify: Option>, +} /// Configuration parameters for proxy transport and request handlers #[derive(PartialEq, Eq, Clone)] pub struct ProxyConfig { - pub listen_sockets: Vec, // when instantiate server - pub http_port: Option, // when instantiate server - pub https_port: Option, // when instantiate server - pub tcp_listen_backlog: u32, // when instantiate server + /// listen socket addresses + pub listen_sockets: Vec, + /// http port + pub http_port: Option, + /// https port + pub https_port: Option, + /// tcp listen backlog + pub tcp_listen_backlog: u32, pub proxy_timeout: Duration, // when serving requests at Proxy pub upstream_timeout: Duration, // when serving requests at Handler diff --git a/rpxy-lib/src/hyper_executor.rs b/rpxy-lib/src/hyper_executor.rs new file mode 100644 index 0000000..579251e --- /dev/null +++ b/rpxy-lib/src/hyper_executor.rs @@ -0,0 +1,23 @@ +use tokio::runtime::Handle; + +#[derive(Clone)] +/// Executor for hyper +pub struct LocalExecutor { + runtime_handle: Handle, +} + +impl LocalExecutor { + pub fn new(runtime_handle: Handle) -> Self { + LocalExecutor { runtime_handle } + } +} + +impl hyper::rt::Executor for LocalExecutor +where + F: std::future::Future + Send + 'static, + F::Output: Send, +{ + fn execute(&self, fut: F) { + self.runtime_handle.spawn(fut); + } +} diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index b201b52..e0a9d21 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -1,10 +1,14 @@ mod certs; mod constants; +mod count; mod error; mod globals; +mod hyper_executor; mod log; +mod proxy; -use crate::{error::*, log::*}; +use 
crate::{error::*, globals::Globals, log::*, proxy::Proxy}; +use futures::future::select_all; use std::sync::Arc; pub use crate::{ @@ -25,7 +29,7 @@ pub async fn entrypoint( app_config_list: &AppConfigList, runtime_handle: &tokio::runtime::Handle, term_notify: Option>, -) -> Result<()> +) -> RpxyResult<()> where T: CryptoSource + Clone + Send + Sync + 'static, { @@ -58,15 +62,18 @@ where info!("Cache is disabled") } - // // build global - // let globals = Arc::new(Globals { - // proxy_config: proxy_config.clone(), - // backends: app_config_list.clone().try_into()?, - // request_count: Default::default(), - // runtime_handle: runtime_handle.clone(), - // term_notify: term_notify.clone(), - // }); + // build global shared context + let globals = Arc::new(Globals { + proxy_config: proxy_config.clone(), + request_count: Default::default(), + runtime_handle: runtime_handle.clone(), + term_notify: term_notify.clone(), + }); + // TODO: 1. build backends, and make it contained in Arc + // app_config_list: app_config_list.clone(), + + // TODO: 2. build message handler with Arc-ed http_client and backends, and make it contained in Arc as well // // build message handler including a request forwarder // let msg_handler = Arc::new( // HttpMessageHandlerBuilder::default() @@ -75,31 +82,31 @@ where // .build()?, // ); - // let http_server = Arc::new(build_http_server(&globals)); + // TODO: 3. spawn each proxy for a given socket with copied Arc-ed message_handler. + // build hyper connection builder shared with proxy instances + let connection_builder = proxy::connection_builder(&globals); - // let addresses = globals.proxy_config.listen_sockets.clone(); - // let futures = select_all(addresses.into_iter().map(|addr| { - // let mut tls_enabled = false; - // if let Some(https_port) = globals.proxy_config.https_port { - // tls_enabled = https_port == addr.port() - // } + // spawn each proxy for a given socket with copied Arc-ed backend, message_handler and connection builder. 
+ let addresses = globals.proxy_config.listen_sockets.clone(); + let futures_iter = addresses.into_iter().map(|listening_on| { + let mut tls_enabled = false; + if let Some(https_port) = globals.proxy_config.https_port { + tls_enabled = https_port == listening_on.port() + } + let proxy = Proxy { + globals: globals.clone(), + listening_on, + tls_enabled, + connection_builder: connection_builder.clone(), + // TODO: message_handler + }; + globals.runtime_handle.spawn(async move { proxy.start().await }) + }); - // let proxy = ProxyBuilder::default() - // .globals(globals.clone()) - // .listening_on(addr) - // .tls_enabled(tls_enabled) - // .http_server(http_server.clone()) - // .msg_handler(msg_handler.clone()) - // .build() - // .unwrap(); - - // globals.runtime_handle.spawn(async move { proxy.start().await }) - // })); - - // // wait for all future - // if let (Ok(Err(e)), _, _) = futures.await { - // error!("Some proxy services are down: {}", e); - // }; + // wait for all future + if let (Ok(Err(e)), _, _) = select_all(futures_iter).await { + error!("Some proxy services are down: {}", e); + }; Ok(()) } diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs new file mode 100644 index 0000000..9b55e1d --- /dev/null +++ b/rpxy-lib/src/proxy/mod.rs @@ -0,0 +1,22 @@ +mod proxy_main; +mod socket; + +use crate::{globals::Globals, hyper_executor::LocalExecutor}; +use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; +use std::sync::Arc; + +pub(crate) use proxy_main::Proxy; + +/// build connection builder shared with proxy instances +pub(crate) fn connection_builder(globals: &Arc) -> Arc> { + let executor = LocalExecutor::new(globals.runtime_handle.clone()); + let mut http_server = server::conn::auto::Builder::new(executor); + http_server + .http1() + .keep_alive(globals.proxy_config.keepalive) + .pipeline_flush(true); + http_server + .http2() + .max_concurrent_streams(globals.proxy_config.max_concurrent_streams); + Arc::new(http_server) +} 
diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs new file mode 100644 index 0000000..416ecac --- /dev/null +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -0,0 +1,63 @@ +use super::socket::bind_tcp_socket; +use crate::{error::RpxyResult, globals::Globals, hyper_executor::LocalExecutor, log::*}; +use hyper_util::server::conn::auto::Builder as ConnectionBuilder; +use std::{net::SocketAddr, sync::Arc}; + +/// Proxy main object responsible to serve requests received from clients at the given socket address. +pub(crate) struct Proxy { + /// global context shared among async tasks + pub globals: Arc, + /// listen socket address + pub listening_on: SocketAddr, + /// whether TLS is enabled or not + pub tls_enabled: bool, + /// hyper connection builder serving http request + pub connection_builder: Arc>, +} + +impl Proxy { + /// Start without TLS (HTTP cleartext) + async fn start_without_tls(&self) -> RpxyResult<()> { + let listener_service = async { + let tcp_socket = bind_tcp_socket(&self.listening_on)?; + let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; + info!("Start TCP proxy serving with HTTP request for configured host names"); + while let Ok((stream, client_addr)) = tcp_listener.accept().await { + // self.serve_connection(TokioIo::new(stream), client_addr, None); + } + Ok(()) as RpxyResult<()> + }; + listener_service.await?; + Ok(()) + } + + /// Entrypoint for HTTP/1.1, 2 and 3 servers + pub async fn start(&self) -> RpxyResult<()> { + let proxy_service = async { + // if self.tls_enabled { + // self.start_with_tls().await + // } else { + self.start_without_tls().await + // } + }; + + match &self.globals.term_notify { + Some(term) => { + tokio::select! 
{ + _ = proxy_service => { + warn!("Proxy service got down"); + } + _ = term.notified() => { + info!("Proxy service listening on {} receives term signal", self.listening_on); + } + } + } + None => { + proxy_service.await?; + warn!("Proxy service got down"); + } + } + + Ok(()) + } +} diff --git a/rpxy-lib/src/proxy/socket.rs b/rpxy-lib/src/proxy/socket.rs new file mode 100644 index 0000000..322b42b --- /dev/null +++ b/rpxy-lib/src/proxy/socket.rs @@ -0,0 +1,46 @@ +use crate::{error::*, log::*}; +#[cfg(feature = "http3-quinn")] +use socket2::{Domain, Protocol, Socket, Type}; +use std::net::SocketAddr; +#[cfg(feature = "http3-quinn")] +use std::net::UdpSocket; +use tokio::net::TcpSocket; + +/// Bind TCP socket to the given `SocketAddr`, and returns the TCP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options. +/// This option is required to re-bind the socket address when the proxy instance is reconstructed. +pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> RpxyResult { + let tcp_socket = if listening_on.is_ipv6() { + TcpSocket::new_v6() + } else { + TcpSocket::new_v4() + }?; + tcp_socket.set_reuseaddr(true)?; + tcp_socket.set_reuseport(true)?; + if let Err(e) = tcp_socket.bind(*listening_on) { + error!("Failed to bind TCP socket: {}", e); + return Err(RpxyError::Io(e)); + }; + Ok(tcp_socket) +} + +#[cfg(feature = "http3-quinn")] +/// Bind UDP socket to the given `SocketAddr`, and returns the UDP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options. +/// This option is required to re-bind the socket address when the proxy instance is reconstructed. +pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> RpxyResult { + let socket = if listening_on.is_ipv6() { + Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) + } else { + Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) + }?; + socket.set_reuse_address(true)?; // This isn't necessary? 
+ socket.set_reuse_port(true)?; + socket.set_nonblocking(true)?; // This was made true inside quinn. so this line isn't necessary here. but just in case. + + if let Err(e) = socket.bind(&(*listening_on).into()) { + error!("Failed to bind UDP socket: {}", e); + return Err(RpxyError::Io(e)); + }; + let udp_socket: UdpSocket = socket.into(); + + Ok(udp_socket) +} From 3c6e4e575722e192bf99a4059c8eb47f50bce19e Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 24 Nov 2023 17:57:33 +0900 Subject: [PATCH 05/50] wip: implemented backend --- rpxy-lib/Cargo.toml | 26 +- rpxy-lib/src/backend/backend_main.rs | 136 +++++++++ .../backend/load_balance/load_balance_main.rs | 135 +++++++++ .../load_balance/load_balance_sticky.rs | 137 +++++++++ rpxy-lib/src/backend/load_balance/mod.rs | 41 +++ .../src/backend/load_balance/sticky_cookie.rs | 205 ++++++++++++++ rpxy-lib/src/backend/mod.rs | 14 + rpxy-lib/src/backend/upstream.rs | 266 ++++++++++++++++++ rpxy-lib/src/backend/upstream_opts.rs | 22 ++ rpxy-lib/src/error.rs | 11 + rpxy-lib/src/lib.rs | 6 +- rpxy-lib/src/name_exp.rs | 160 +++++++++++ rpxy-lib/src/proxy/crypto_service.rs | 0 rpxy-lib/src/proxy/mod.rs | 1 + rpxy-lib/src/proxy/proxy_main.rs | 30 +- rpxy-lib/src/proxy/proxy_tls.rs | 6 + 16 files changed, 1173 insertions(+), 23 deletions(-) create mode 100644 rpxy-lib/src/backend/backend_main.rs create mode 100644 rpxy-lib/src/backend/load_balance/load_balance_main.rs create mode 100644 rpxy-lib/src/backend/load_balance/load_balance_sticky.rs create mode 100644 rpxy-lib/src/backend/load_balance/mod.rs create mode 100644 rpxy-lib/src/backend/load_balance/sticky_cookie.rs create mode 100644 rpxy-lib/src/backend/mod.rs create mode 100644 rpxy-lib/src/backend/upstream.rs create mode 100644 rpxy-lib/src/backend/upstream_opts.rs create mode 100644 rpxy-lib/src/name_exp.rs create mode 100644 rpxy-lib/src/proxy/crypto_service.rs create mode 100644 rpxy-lib/src/proxy/proxy_tls.rs diff --git a/rpxy-lib/Cargo.toml 
b/rpxy-lib/Cargo.toml index cb6b5b6..2de993d 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -15,15 +15,15 @@ publish = false default = ["http3-s2n", "sticky-cookie", "cache"] http3-quinn = ["socket2"] #"quinn", "h3", "h3-quinn", ] http3-s2n = [] #"h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] -sticky-cookie = [] #"base64", "sha2", "chrono"] +sticky-cookie = ["base64", "sha2", "chrono"] cache = [] #"http-cache-semantics", "lru"] native-roots = [] #"hyper-rustls/native-tokio"] [dependencies] -# rand = "0.8.5" -# rustc-hash = "1.1.0" +rand = "0.8.5" +rustc-hash = "1.1.0" # bytes = "1.5.0" -# derive_builder = "0.12.0" +derive_builder = "0.12.0" futures = { version = "0.3.29", features = ["alloc", "async-await"] } tokio = { version = "1.34.0", default-features = false, features = [ "net", @@ -34,13 +34,13 @@ tokio = { version = "1.34.0", default-features = false, features = [ "fs", ] } async-trait = "0.1.74" -# hot_reload = "0.1.4" # reloading certs # Error handling anyhow = "1.0.75" thiserror = "1.0.50" # http and tls +hot_reload = "0.1.4" # reloading certs http = "1.0.0" # http-body-util = "0.1.0" hyper = { version = "1.0.1", default-features = false } @@ -75,14 +75,14 @@ socket2 = { version = "0.5.5", features = ["all"], optional = true } # http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } # lru = { version = "0.12.0", optional = true } -# # cookie handling for sticky cookie -# chrono = { version = "0.4.31", default-features = false, features = [ -# "unstable-locales", -# "alloc", -# "clock", -# ], optional = true } -# base64 = { version = "0.21.5", optional = true } -# sha2 = { version = "0.10.8", default-features = false, optional = true } +# cookie handling for sticky cookie +chrono = { version = "0.4.31", default-features = false, features = [ + "unstable-locales", + "alloc", + "clock", +], optional = true } +base64 = { version = "0.21.5", optional = true } +sha2 = { version = "0.10.8", 
default-features = false, optional = true } # [dev-dependencies] diff --git a/rpxy-lib/src/backend/backend_main.rs b/rpxy-lib/src/backend/backend_main.rs new file mode 100644 index 0000000..695a063 --- /dev/null +++ b/rpxy-lib/src/backend/backend_main.rs @@ -0,0 +1,136 @@ +use crate::{ + certs::CryptoSource, + error::*, + log::*, + name_exp::{ByteName, ServerName}, + AppConfig, AppConfigList, +}; +use derive_builder::Builder; +use rustc_hash::FxHashMap as HashMap; +use std::borrow::Cow; + +use super::upstream::PathManager; + +/// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings. +#[derive(Builder)] +pub struct BackendApp +where + T: CryptoSource, +{ + #[builder(setter(into))] + /// backend application name, e.g., app1 + pub app_name: String, + #[builder(setter(custom))] + /// server name, e.g., example.com, in [[ServerName]] object + pub server_name: ServerName, + /// struct of reverse proxy serving incoming request + pub path_manager: PathManager, + /// tls settings: https redirection with 30x + #[builder(default)] + pub https_redirection: Option, + /// TLS settings: source meta for server cert, key, client ca cert + #[builder(default)] + pub crypto_source: Option, +} +impl<'a, T> BackendAppBuilder +where + T: CryptoSource, +{ + pub fn server_name(&mut self, server_name: impl Into>) -> &mut Self { + self.server_name = Some(server_name.to_server_name()); + self + } +} + +/// HashMap and some meta information for multiple Backend structs. 
+pub struct BackendAppManager +where + T: CryptoSource, +{ + /// HashMap of Backend structs, key is server name + pub apps: HashMap>, + /// for plaintext http + pub default_server_name: Option, +} + +impl Default for BackendAppManager +where + T: CryptoSource, +{ + fn default() -> Self { + Self { + apps: HashMap::>::default(), + default_server_name: None, + } + } +} + +impl TryFrom<&AppConfig> for BackendApp +where + T: CryptoSource + Clone, +{ + type Error = RpxyError; + + fn try_from(app_config: &AppConfig) -> Result { + let mut backend_builder = BackendAppBuilder::default(); + let path_manager = PathManager::try_from(app_config)?; + backend_builder + .app_name(app_config.app_name.clone()) + .server_name(app_config.server_name.clone()) + .path_manager(path_manager); + // TLS settings and build backend instance + let backend = if app_config.tls.is_none() { + backend_builder.build()? + } else { + let tls = app_config.tls.as_ref().unwrap(); + backend_builder + .https_redirection(Some(tls.https_redirection)) + .crypto_source(Some(tls.inner.clone())) + .build()? 
+ }; + Ok(backend) + } +} + +impl TryFrom<&AppConfigList> for BackendAppManager +where + T: CryptoSource + Clone, +{ + type Error = RpxyError; + + fn try_from(config_list: &AppConfigList) -> Result { + let mut manager = Self::default(); + for app_config in config_list.inner.iter() { + let backend: BackendApp = BackendApp::try_from(app_config)?; + manager + .apps + .insert(app_config.server_name.clone().to_server_name(), backend); + + info!( + "Registering application {} ({})", + &app_config.server_name, &app_config.app_name + ); + } + + // default backend application for plaintext http requests + if let Some(default_app_name) = &config_list.default_app { + let default_server_name = manager + .apps + .iter() + .filter(|(_k, v)| &v.app_name == default_app_name) + .map(|(_, v)| v.server_name.clone()) + .collect::>(); + + if !default_server_name.is_empty() { + info!( + "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", + &default_app_name, + (&default_server_name[0]).try_into().unwrap_or_else(|_| "".to_string()) + ); + + manager.default_server_name = Some(default_server_name[0].clone()); + } + } + Ok(manager) + } +} diff --git a/rpxy-lib/src/backend/load_balance/load_balance_main.rs b/rpxy-lib/src/backend/load_balance/load_balance_main.rs new file mode 100644 index 0000000..8ee1600 --- /dev/null +++ b/rpxy-lib/src/backend/load_balance/load_balance_main.rs @@ -0,0 +1,135 @@ +#[cfg(feature = "sticky-cookie")] +pub use super::{ + load_balance_sticky::{LoadBalanceSticky, LoadBalanceStickyBuilder}, + sticky_cookie::StickyCookie, +}; +use derive_builder::Builder; +use rand::Rng; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +/// Constants to specify a load balance option +pub mod load_balance_options { + pub const FIX_TO_FIRST: &str = "none"; + pub const ROUND_ROBIN: &str = "round_robin"; + pub const RANDOM: &str = "random"; + #[cfg(feature = "sticky-cookie")] + pub const STICKY_ROUND_ROBIN: &str = "sticky"; +} 
+ +#[derive(Debug, Clone)] +/// Pointer to upstream serving the incoming request. +/// If 'sticky cookie'-based LB is enabled and cookie must be updated/created, the new cookie is also given. +pub struct PointerToUpstream { + pub ptr: usize, + pub context: Option, +} +/// Trait for LB +pub(super) trait LoadBalanceWithPointer { + fn get_ptr(&self, req_info: Option<&LoadBalanceContext>) -> PointerToUpstream; +} + +#[derive(Debug, Clone, Builder)] +/// Round Robin LB object as a pointer to the current serving upstream destination +pub struct LoadBalanceRoundRobin { + #[builder(default)] + /// Pointer to the index of the last served upstream destination + ptr: Arc, + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, +} +impl LoadBalanceRoundRobinBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } +} +impl LoadBalanceWithPointer for LoadBalanceRoundRobin { + /// Increment the count of upstream served up to the max value + fn get_ptr(&self, _info: Option<&LoadBalanceContext>) -> PointerToUpstream { + // Get a current count of upstream served + let current_ptr = self.ptr.load(Ordering::Relaxed); + + let ptr = if current_ptr < self.num_upstreams - 1 { + self.ptr.fetch_add(1, Ordering::Relaxed) + } else { + // Clear the counter + self.ptr.fetch_and(0, Ordering::Relaxed) + }; + PointerToUpstream { ptr, context: None } + } +} + +#[derive(Debug, Clone, Builder)] +/// Random LB object to keep the object of random pools +pub struct LoadBalanceRandom { + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, +} +impl LoadBalanceRandomBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } +} +impl LoadBalanceWithPointer for LoadBalanceRandom { + /// Returns the random index within the range + fn get_ptr(&self, _info: Option<&LoadBalanceContext>) -> PointerToUpstream { + 
let mut rng = rand::thread_rng(); + let ptr = rng.gen_range(0..self.num_upstreams); + PointerToUpstream { ptr, context: None } + } +} + +#[derive(Debug, Clone)] +/// Load Balancing Option +pub enum LoadBalance { + /// Fix to the first upstream. Use if only one upstream destination is specified + FixToFirst, + /// Randomly chose one upstream server + Random(LoadBalanceRandom), + /// Simple round robin without session persistance + RoundRobin(LoadBalanceRoundRobin), + #[cfg(feature = "sticky-cookie")] + /// Round robin with session persistance using cookie + StickyRoundRobin(LoadBalanceSticky), +} +impl Default for LoadBalance { + fn default() -> Self { + Self::FixToFirst + } +} + +impl LoadBalance { + /// Get the index of the upstream serving the incoming request + pub fn get_context(&self, _context_to_lb: &Option) -> PointerToUpstream { + match self { + LoadBalance::FixToFirst => PointerToUpstream { + ptr: 0usize, + context: None, + }, + LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None), + LoadBalance::Random(ptr) => ptr.get_ptr(None), + #[cfg(feature = "sticky-cookie")] + LoadBalance::StickyRoundRobin(ptr) => { + // Generate new context if sticky round robin is enabled. + ptr.get_ptr(_context_to_lb.as_ref()) + } + } + } +} + +#[derive(Debug, Clone)] +/// Struct to handle the sticky cookie string, +/// - passed from Rp module (http handler) to LB module, manipulated from req, only StickyCookieValue exists. +/// - passed from LB module to Rp module (http handler), will be inserted into res, StickyCookieValue and Info exist. 
+pub struct LoadBalanceContext { + #[cfg(feature = "sticky-cookie")] + pub sticky_cookie: StickyCookie, + #[cfg(not(feature = "sticky-cookie"))] + pub sticky_cookie: (), +} diff --git a/rpxy-lib/src/backend/load_balance/load_balance_sticky.rs b/rpxy-lib/src/backend/load_balance/load_balance_sticky.rs new file mode 100644 index 0000000..d7a9795 --- /dev/null +++ b/rpxy-lib/src/backend/load_balance/load_balance_sticky.rs @@ -0,0 +1,137 @@ +use super::{ + load_balance_main::{LoadBalanceContext, LoadBalanceWithPointer, PointerToUpstream}, + sticky_cookie::StickyCookieConfig, + Upstream, +}; +use crate::{constants::STICKY_COOKIE_NAME, log::*}; +use derive_builder::Builder; +use rustc_hash::FxHashMap as HashMap; +use std::{ + borrow::Cow, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +#[derive(Debug, Clone, Builder)] +/// Round Robin LB object in the sticky cookie manner +pub struct LoadBalanceSticky { + #[builder(default)] + /// Pointer to the index of the last served upstream destination + ptr: Arc, + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, + #[builder(setter(custom))] + /// Information to build the cookie to stick clients to specific backends + pub sticky_config: StickyCookieConfig, + #[builder(setter(custom))] + /// Hashmaps: + /// - Hashmap that maps server indices to server id (string) + /// - Hashmap that maps server ids (string) to server indices, for fast reverse lookup + upstream_maps: UpstreamMap, +} +#[derive(Debug, Clone)] +pub struct UpstreamMap { + /// Hashmap that maps server indices to server id (string) + upstream_index_map: Vec, + /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup + upstream_id_map: HashMap, +} +impl LoadBalanceStickyBuilder { + /// Set the number of upstream destinations + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } + /// Set the information to build the cookie to stick 
clients to specific backends + pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option) -> &mut Self { + self.sticky_config = Some(StickyCookieConfig { + name: STICKY_COOKIE_NAME.to_string(), // TODO: config等で変更できるように + domain: server_name.to_ascii_lowercase(), + path: if let Some(v) = path_opt { + v.to_ascii_lowercase() + } else { + "/".to_string() + }, + duration: 300, // TODO: config等で変更できるように + }); + self + } + /// Set the hashmaps: upstream_index_map and upstream_id_map + pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self { + let upstream_index_map: Vec = upstream_vec + .iter() + .enumerate() + .map(|(i, v)| v.calculate_id_with_index(i)) + .collect(); + let mut upstream_id_map = HashMap::default(); + for (i, v) in upstream_index_map.iter().enumerate() { + upstream_id_map.insert(v.to_string(), i); + } + self.upstream_maps = Some(UpstreamMap { + upstream_index_map, + upstream_id_map, + }); + self + } +} +impl<'a> LoadBalanceSticky { + /// Increment the count of upstream served up to the max value + fn simple_increment_ptr(&self) -> usize { + // Get a current count of upstream served + let current_ptr = self.ptr.load(Ordering::Relaxed); + + if current_ptr < self.num_upstreams - 1 { + self.ptr.fetch_add(1, Ordering::Relaxed) + } else { + // Clear the counter + self.ptr.fetch_and(0, Ordering::Relaxed) + } + } + /// This is always called only internally. So 'unwrap()' is executed. + fn get_server_id_from_index(&self, index: usize) -> String { + self.upstream_maps.upstream_index_map.get(index).unwrap().to_owned() + } + /// This function takes value passed from outside. So 'result' is used. + fn get_server_index_from_id(&self, id: impl Into>) -> Option { + let id_str = id.into().to_string(); + self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned()) + } +} +impl LoadBalanceWithPointer for LoadBalanceSticky { + /// Get the pointer to the upstream server to serve the incoming request. 
+ fn get_ptr(&self, req_info: Option<&LoadBalanceContext>) -> PointerToUpstream { + // If given context is None or invalid (not contained), get_ptr() is invoked to increment the pointer. + // Otherwise, get the server index indicated by the server_id inside the cookie + let ptr = match req_info { + None => { + debug!("No sticky cookie"); + self.simple_increment_ptr() + } + Some(context) => { + let server_id = &context.sticky_cookie.value.value; + if let Some(server_index) = self.get_server_index_from_id(server_id) { + debug!("Valid sticky cookie: id={}, index={}", server_id, server_index); + server_index + } else { + debug!("Invalid sticky cookie: id={}", server_id); + self.simple_increment_ptr() + } + } + }; + + // Get the server id from the ptr. + // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie). + let upstream_id = self.get_server_id_from_index(ptr); + let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap(); + let new_context = Some(LoadBalanceContext { + sticky_cookie: new_cookie, + }); + PointerToUpstream { + ptr, + context: new_context, + } + } +} diff --git a/rpxy-lib/src/backend/load_balance/mod.rs b/rpxy-lib/src/backend/load_balance/mod.rs new file mode 100644 index 0000000..d876517 --- /dev/null +++ b/rpxy-lib/src/backend/load_balance/mod.rs @@ -0,0 +1,41 @@ +mod load_balance_main; +#[cfg(feature = "sticky-cookie")] +mod load_balance_sticky; +#[cfg(feature = "sticky-cookie")] +mod sticky_cookie; + +use super::upstream::Upstream; +use thiserror::Error; + +pub use load_balance_main::{ + load_balance_options, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder, +}; +#[cfg(feature = "sticky-cookie")] +pub use load_balance_sticky::LoadBalanceStickyBuilder; + +/// Result type for load balancing +type LoadBalanceResult = std::result::Result; +/// Describes things that can go wrong in the Load Balance +#[derive(Debug, Error)] +pub enum LoadBalanceError { + 
// backend load balance errors + #[cfg(feature = "sticky-cookie")] + #[error("Failed to cookie conversion to/from string")] + FailedToConversionStickyCookie, + + #[cfg(feature = "sticky-cookie")] + #[error("Invalid cookie structure")] + InvalidStickyCookieStructure, + + #[cfg(feature = "sticky-cookie")] + #[error("No sticky cookie value")] + NoStickyCookieValue, + + #[cfg(feature = "sticky-cookie")] + #[error("Failed to cookie conversion into string: no meta information")] + NoStickyCookieNoMetaInfo, + + #[cfg(feature = "sticky-cookie")] + #[error("Failed to build sticky cookie from config")] + FailedToBuildStickyCookie, +} diff --git a/rpxy-lib/src/backend/load_balance/sticky_cookie.rs b/rpxy-lib/src/backend/load_balance/sticky_cookie.rs new file mode 100644 index 0000000..28572b5 --- /dev/null +++ b/rpxy-lib/src/backend/load_balance/sticky_cookie.rs @@ -0,0 +1,205 @@ +use super::{LoadBalanceError, LoadBalanceResult}; +use chrono::{TimeZone, Utc}; +use derive_builder::Builder; +use std::borrow::Cow; + +#[derive(Debug, Clone, Builder)] +/// Cookie value only, used for COOKIE in req +pub struct StickyCookieValue { + #[builder(setter(custom))] + /// Field name indicating sticky cookie + pub name: String, + #[builder(setter(custom))] + /// Upstream server_id + pub value: String, +} +impl<'a> StickyCookieValueBuilder { + pub fn name(&mut self, v: impl Into>) -> &mut Self { + self.name = Some(v.into().to_ascii_lowercase()); + self + } + pub fn value(&mut self, v: impl Into>) -> &mut Self { + self.value = Some(v.into().to_string()); + self + } +} +impl StickyCookieValue { + pub fn try_from(value: &str, expected_name: &str) -> LoadBalanceResult { + if !value.starts_with(expected_name) { + return Err(LoadBalanceError::FailedToConversionStickyCookie); + }; + let kv = value.split('=').map(|v| v.trim()).collect::>(); + if kv.len() != 2 { + return Err(LoadBalanceError::InvalidStickyCookieStructure); + }; + if kv[1].is_empty() { + return 
Err(LoadBalanceError::NoStickyCookieValue); + } + Ok(StickyCookieValue { + name: expected_name.to_string(), + value: kv[1].to_string(), + }) + } +} + +#[derive(Debug, Clone, Builder)] +/// Struct describing sticky cookie meta information used for SET-COOKIE in res +pub struct StickyCookieInfo { + #[builder(setter(custom))] + /// Unix time + pub expires: i64, + + #[builder(setter(custom))] + /// Domain + pub domain: String, + + #[builder(setter(custom))] + /// Path + pub path: String, +} +impl<'a> StickyCookieInfoBuilder { + pub fn domain(&mut self, v: impl Into>) -> &mut Self { + self.domain = Some(v.into().to_ascii_lowercase()); + self + } + pub fn path(&mut self, v: impl Into>) -> &mut Self { + self.path = Some(v.into().to_ascii_lowercase()); + self + } + pub fn expires(&mut self, duration_secs: i64) -> &mut Self { + let current = Utc::now().timestamp(); + self.expires = Some(current + duration_secs); + self + } +} + +#[derive(Debug, Clone, Builder)] +/// Struct describing sticky cookie +pub struct StickyCookie { + #[builder(setter(custom))] + /// Upstream server_id + pub value: StickyCookieValue, + #[builder(setter(custom), default)] + /// Upstream server_id + pub info: Option, +} + +impl<'a> StickyCookieBuilder { + /// Set the value of sticky cookie + pub fn value(&mut self, n: impl Into>, v: impl Into>) -> &mut Self { + self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap()); + self + } + /// Set the meta information of sticky cookie + pub fn info( + &mut self, + domain: impl Into>, + path: impl Into>, + duration_secs: i64, + ) -> &mut Self { + let info = StickyCookieInfoBuilder::default() + .domain(domain) + .path(path) + .expires(duration_secs) + .build() + .unwrap(); + self.info = Some(Some(info)); + self + } +} + +impl TryInto for StickyCookie { + type Error = LoadBalanceError; + + fn try_into(self) -> LoadBalanceResult { + if self.info.is_none() { + return Err(LoadBalanceError::NoStickyCookieNoMetaInfo); + } + let info = 
self.info.unwrap(); + let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else { + return Err(LoadBalanceError::FailedToConversionStickyCookie); + }; + let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string(); + let max_age = info.expires - Utc::now().timestamp(); + + Ok(format!( + "{}={}; expires={}; Max-Age={}; path={}; domain={}", + self.value.name, self.value.value, exp_str, max_age, info.path, info.domain + )) + } +} + +#[derive(Debug, Clone)] +/// Configuration to serve incoming requests in the manner of "sticky cookie". +/// Including a dictionary to map Ids included in cookie and upstream destinations, +/// and expiration of cookie. +/// "domain" and "path" in the cookie will be the same as the reverse proxy options. +pub struct StickyCookieConfig { + pub name: String, + pub domain: String, + pub path: String, + pub duration: i64, +} +impl<'a> StickyCookieConfig { + pub fn build_sticky_cookie(&self, v: impl Into>) -> LoadBalanceResult { + StickyCookieBuilder::default() + .value(self.name.clone(), v) + .info(&self.domain, &self.path, self.duration) + .build() + .map_err(|_| LoadBalanceError::FailedToBuildStickyCookie) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::constants::STICKY_COOKIE_NAME; + + #[test] + fn config_works() { + let config = StickyCookieConfig { + name: STICKY_COOKIE_NAME.to_string(), + domain: "example.com".to_string(), + path: "/path".to_string(), + duration: 100, + }; + let expires_unix = Utc::now().timestamp() + 100; + let sc_string: LoadBalanceResult = config.build_sticky_cookie("test_value").unwrap().try_into(); + let expires_date_string = Utc + .timestamp_opt(expires_unix, 0) + .unwrap() + .format("%a, %d-%b-%Y %T GMT") + .to_string(); + assert_eq!( + sc_string.unwrap(), + format!( + "{}=test_value; expires={}; Max-Age={}; path=/path; domain=example.com", + STICKY_COOKIE_NAME, expires_date_string, 100 + ) + ); + } + #[test] + fn to_string_works() { + let sc = 
StickyCookie { + value: StickyCookieValue { + name: STICKY_COOKIE_NAME.to_string(), + value: "test_value".to_string(), + }, + info: Some(StickyCookieInfo { + expires: 1686221173i64, + domain: "example.com".to_string(), + path: "/path".to_string(), + }), + }; + let sc_string: LoadBalanceResult = sc.try_into(); + let max_age = 1686221173i64 - Utc::now().timestamp(); + assert!(sc_string.is_ok()); + assert_eq!( + sc_string.unwrap(), + format!( + "{}=test_value; expires=Thu, 08-Jun-2023 10:46:13 GMT; Max-Age={}; path=/path; domain=example.com", + STICKY_COOKIE_NAME, max_age + ) + ); + } +} diff --git a/rpxy-lib/src/backend/mod.rs b/rpxy-lib/src/backend/mod.rs new file mode 100644 index 0000000..68f97a8 --- /dev/null +++ b/rpxy-lib/src/backend/mod.rs @@ -0,0 +1,14 @@ +mod backend_main; +mod load_balance; +mod upstream; +mod upstream_opts; + +pub use backend_main::{BackendAppBuilderError, BackendAppManager}; +pub use upstream::Upstream; +// #[cfg(feature = "sticky-cookie")] +// pub use sticky_cookie::{StickyCookie, StickyCookieValue}; +// pub use self::{ +// load_balance::{LbContext, LoadBalance}, +// upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, +// upstream_opts::UpstreamOption, +// }; diff --git a/rpxy-lib/src/backend/upstream.rs b/rpxy-lib/src/backend/upstream.rs new file mode 100644 index 0000000..91e392d --- /dev/null +++ b/rpxy-lib/src/backend/upstream.rs @@ -0,0 +1,266 @@ +#[cfg(feature = "sticky-cookie")] +use super::load_balance::LoadBalanceStickyBuilder; +use super::load_balance::{ + load_balance_options as lb_opts, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, + LoadBalanceRoundRobinBuilder, +}; +// use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption}; +use super::upstream_opts::UpstreamOption; +use crate::{ + certs::CryptoSource, + error::RpxyError, + globals::{AppConfig, UpstreamUri}, + log::*, + name_exp::{ByteName, PathName}, +}; +#[cfg(feature = "sticky-cookie")] +use base64::{engine::general_purpose, 
Engine as _}; +use derive_builder::Builder; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +#[cfg(feature = "sticky-cookie")] +use sha2::{Digest, Sha256}; +use std::borrow::Cow; + +#[derive(Debug, Clone)] +/// Handler for given path to route incoming request to path's corresponding upstream server(s). +pub struct PathManager { + /// HashMap of upstream candidate server info, key is path name + /// TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。 + inner: HashMap, +} + +impl TryFrom<&AppConfig> for PathManager +where + T: CryptoSource, +{ + type Error = RpxyError; + fn try_from(app_config: &AppConfig) -> Result { + let mut inner: HashMap = HashMap::default(); + + app_config.reverse_proxy.iter().for_each(|rpc| { + let upstream_vec: Vec = rpc.upstream.iter().map(Upstream::from).collect(); + let elem = UpstreamCandidatesBuilder::default() + .upstream(&upstream_vec) + .path(&rpc.path) + .replace_path(&rpc.replace_path) + .load_balance(&rpc.load_balance, &upstream_vec, &app_config.server_name, &rpc.path) + .options(&rpc.upstream_options) + .build() + .unwrap(); + inner.insert(elem.path.clone(), elem); + }); + + if app_config.reverse_proxy.iter().filter(|rpc| rpc.path.is_none()).count() >= 2 { + error!("Multiple default reverse proxy setting"); + return Err(RpxyError::InvalidReverseProxyConfig); + } + + if !(inner.iter().all(|(_, elem)| { + !(elem.options.contains(&UpstreamOption::ForceHttp11Upstream) + && elem.options.contains(&UpstreamOption::ForceHttp2Upstream)) + })) { + error!("Either one of force_http11 or force_http2 can be enabled"); + return Err(RpxyError::InvalidUpstreamOptionSetting); + } + + Ok(PathManager { inner }) + } +} + +impl PathManager { + /// Get an appropriate upstream destinations for given path string. 
+ /// trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、 + /// コスト的にこの程度で十分では。 + pub fn get<'a>(&self, path_str: impl Into>) -> Option<&UpstreamCandidates> { + let path_name = &path_str.to_path_name(); + + let matched_upstream = self + .inner + .iter() + .filter(|(route_bytes, _)| { + match path_name.starts_with(route_bytes) { + true => { + route_bytes.len() == 1 // route = '/', i.e., default + || match path_name.get(route_bytes.len()) { + None => true, // exact case + Some(p) => p == &b'/', // sub-path case + } + } + _ => false, + } + }) + .max_by_key(|(route_bytes, _)| route_bytes.len()); + if let Some((path, u)) = matched_upstream { + debug!( + "Found upstream: {:?}", + path.try_into().unwrap_or_else(|_| "".to_string()) + ); + Some(u) + } else { + None + } + } +} + +#[derive(Debug, Clone)] +/// Upstream struct just containing uri without path +pub struct Upstream { + /// Base uri without specific path + pub uri: hyper::Uri, +} +impl From<&UpstreamUri> for Upstream { + fn from(value: &UpstreamUri) -> Self { + Self { + uri: value.inner.clone(), + } + } +} +impl Upstream { + #[cfg(feature = "sticky-cookie")] + /// Hashing uri with index to avoid collision + pub fn calculate_id_with_index(&self, index: usize) -> String { + let mut hasher = Sha256::new(); + let uri_string = format!("{}&index={}", self.uri.clone(), index); + hasher.update(uri_string.as_bytes()); + let digest = hasher.finalize(); + general_purpose::URL_SAFE_NO_PAD.encode(digest) + } +} +#[derive(Debug, Clone, Builder)] +/// Struct serving multiple upstream servers for, e.g., load balancing. 
+pub struct UpstreamCandidates { + #[builder(setter(custom))] + /// Upstream server(s) + pub inner: Vec, + + #[builder(setter(custom), default)] + /// Path like "/path" in [[PathName]] associated with the upstream server(s) + pub path: PathName, + + #[builder(setter(custom), default)] + /// Path in [[PathName]] that will be used to replace the "path" part of incoming url + pub replace_path: Option, + + #[builder(setter(custom), default)] + /// Load balancing option + pub load_balance: LoadBalance, + + #[builder(setter(custom), default)] + /// Activated upstream options defined in [[UpstreamOption]] + pub options: HashSet, +} + +impl UpstreamCandidatesBuilder { + /// Set the upstream server(s) + pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self { + self.inner = Some(upstream_vec.to_vec()); + self + } + /// Set the path like "/path" in [[PathName]] associated with the upstream server(s), default is "/" + pub fn path(&mut self, v: &Option) -> &mut Self { + let path = match v { + Some(p) => p.to_path_name(), + None => "/".to_path_name(), + }; + self.path = Some(path); + self + } + /// Set the path in [[PathName]] that will be used to replace the "path" part of incoming url + pub fn replace_path(&mut self, v: &Option) -> &mut Self { + self.replace_path = Some(v.to_owned().as_ref().map_or_else(|| None, |v| Some(v.to_path_name()))); + self + } + /// Set the load balancing option + pub fn load_balance( + &mut self, + v: &Option, + // upstream_num: &usize, + upstream_vec: &Vec, + _server_name: &str, + _path_opt: &Option, + ) -> &mut Self { + let upstream_num = &upstream_vec.len(); + let lb = if let Some(x) = v { + match x.as_str() { + lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, + lb_opts::RANDOM => LoadBalance::Random( + LoadBalanceRandomBuilder::default() + .num_upstreams(upstream_num) + .build() + .unwrap(), + ), + lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin( + LoadBalanceRoundRobinBuilder::default() + .num_upstreams(upstream_num) + .build() + 
.unwrap(), + ), + #[cfg(feature = "sticky-cookie")] + lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( + LoadBalanceStickyBuilder::default() + .num_upstreams(upstream_num) + .sticky_config(_server_name, _path_opt) + .upstream_maps(upstream_vec) // TODO: + .build() + .unwrap(), + ), + _ => { + error!("Specified load balancing option is invalid."); + LoadBalance::default() + } + } + } else { + LoadBalance::default() + }; + self.load_balance = Some(lb); + self + } + /// Set the activated upstream options defined in [[UpstreamOption]] + pub fn options(&mut self, v: &Option>) -> &mut Self { + let opts = if let Some(opts) = v { + opts + .iter() + .filter_map(|str| UpstreamOption::try_from(str.as_str()).ok()) + .collect::>() + } else { + Default::default() + }; + self.options = Some(opts); + self + } +} + +impl UpstreamCandidates { + /// Get an enabled option of load balancing [[LoadBalance]] + pub fn get(&self, context_to_lb: &Option) -> (Option<&Upstream>, Option) { + let pointer_to_upstream = self.load_balance.get_context(context_to_lb); + debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); + debug!("Context to LB (Cookie in Request): {:?}", context_to_lb); + debug!( + "Context from LB (Set-Cookie in Response): {:?}", + pointer_to_upstream.context + ); + (self.inner.get(pointer_to_upstream.ptr), pointer_to_upstream.context) + } +} + +#[cfg(test)] +mod test { + #[allow(unused)] + use super::*; + + #[cfg(feature = "sticky-cookie")] + #[test] + fn calc_id_works() { + let uri = "https://www.rust-lang.org".parse::().unwrap(); + let upstream = Upstream { uri }; + assert_eq!( + "eGsjoPbactQ1eUJjafYjPT3ekYZQkaqJnHdA_FMSkgM", + upstream.calculate_id_with_index(0) + ); + assert_eq!( + "tNVXFJ9eNCT2mFgKbYq35XgH5q93QZtfU8piUiiDxVA", + upstream.calculate_id_with_index(1) + ); + } +} diff --git a/rpxy-lib/src/backend/upstream_opts.rs b/rpxy-lib/src/backend/upstream_opts.rs new file mode 100644 index 0000000..3f5fbc8 --- /dev/null +++ 
b/rpxy-lib/src/backend/upstream_opts.rs @@ -0,0 +1,22 @@ +use crate::error::*; + +#[derive(Debug, Clone, Hash, Eq, PartialEq)] +pub enum UpstreamOption { + OverrideHost, + UpgradeInsecureRequests, + ForceHttp11Upstream, + ForceHttp2Upstream, + // TODO: Adds more options for heder override +} +impl TryFrom<&str> for UpstreamOption { + type Error = RpxyError; + fn try_from(val: &str) -> RpxyResult { + match val { + "override_host" => Ok(Self::OverrideHost), + "upgrade_insecure_requests" => Ok(Self::UpgradeInsecureRequests), + "force_http11_upstream" => Ok(Self::ForceHttp11Upstream), + "force_http2_upstream" => Ok(Self::ForceHttp2Upstream), + _ => Err(RpxyError::UnsupportedUpstreamOption), + } + } +} diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 7662a9d..bc730a9 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -8,4 +8,15 @@ pub type RpxyResult = std::result::Result; pub enum RpxyError { #[error("IO error: {0}")] Io(#[from] std::io::Error), + + // backend errors + #[error("Invalid reverse proxy setting")] + InvalidReverseProxyConfig, + #[error("Invalid upstream option setting")] + InvalidUpstreamOptionSetting, + #[error("Failed to build backend app")] + FailedToBuildBackendApp(#[from] crate::backend::BackendAppBuilderError), + + #[error("Unsupported upstream option")] + UnsupportedUpstreamOption, } diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index e0a9d21..28e08c0 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -1,3 +1,4 @@ +mod backend; mod certs; mod constants; mod count; @@ -5,6 +6,7 @@ mod error; mod globals; mod hyper_executor; mod log; +mod name_exp; mod proxy; use crate::{error::*, globals::Globals, log::*, proxy::Proxy}; @@ -70,8 +72,8 @@ where term_notify: term_notify.clone(), }); - // TODO: 1. build backends, and make it contained in Arc - // app_config_list: app_config_list.clone(), + // 1. 
build backends, and make it contained in Arc + let app_manager = Arc::new(backend::BackendAppManager::try_from(app_config_list)?); // TODO: 2. build message handler with Arc-ed http_client and backends, and make it contained in Arc as well // // build message handler including a request forwarder diff --git a/rpxy-lib/src/name_exp.rs b/rpxy-lib/src/name_exp.rs new file mode 100644 index 0000000..8ed17e2 --- /dev/null +++ b/rpxy-lib/src/name_exp.rs @@ -0,0 +1,160 @@ +use std::borrow::Cow; + +/// Server name (hostname or ip address) representation in bytes-based struct +/// for searching hashmap or key list by exact or longest-prefix matching +#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] +pub struct ServerName { + inner: Vec, // lowercase ascii bytes +} +impl From<&str> for ServerName { + fn from(s: &str) -> Self { + let name = s.bytes().collect::>().to_ascii_lowercase(); + Self { inner: name } + } +} +impl From<&[u8]> for ServerName { + fn from(b: &[u8]) -> Self { + Self { + inner: b.to_ascii_lowercase(), + } + } +} +impl TryInto for &ServerName { + type Error = anyhow::Error; + fn try_into(self) -> Result { + let s = std::str::from_utf8(&self.inner)?; + Ok(s.to_string()) + } +} +impl AsRef<[u8]> for ServerName { + fn as_ref(&self) -> &[u8] { + self.inner.as_ref() + } +} + +/// Path name, like "/path/ok", represented in bytes-based struct +/// for searching hashmap or key list by exact or longest-prefix matching +#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] +pub struct PathName { + inner: Vec, // lowercase ascii bytes +} +impl From<&str> for PathName { + fn from(s: &str) -> Self { + let name = s.bytes().collect::>().to_ascii_lowercase(); + Self { inner: name } + } +} +impl From<&[u8]> for PathName { + fn from(b: &[u8]) -> Self { + Self { + inner: b.to_ascii_lowercase(), + } + } +} +impl TryInto for &PathName { + type Error = anyhow::Error; + fn try_into(self) -> Result { + let s = std::str::from_utf8(&self.inner)?; + Ok(s.to_string()) + } +} +impl 
AsRef<[u8]> for PathName { + fn as_ref(&self) -> &[u8] { + self.inner.as_ref() + } +} +impl PathName { + pub fn len(&self) -> usize { + self.inner.len() + } + pub fn is_empty(&self) -> bool { + self.inner.len() == 0 + } + pub fn get(&self, index: I) -> Option<&I::Output> + where + I: std::slice::SliceIndex<[u8]>, + { + self.inner.get(index) + } + pub fn starts_with(&self, needle: &Self) -> bool { + self.inner.starts_with(&needle.inner) + } +} + +/// Trait to express names in ascii-lowercased bytes +pub trait ByteName { + type OutputServer: Send + Sync + 'static; + type OutputPath; + fn to_server_name(self) -> Self::OutputServer; + fn to_path_name(self) -> Self::OutputPath; +} + +impl<'a, T: Into>> ByteName for T { + type OutputServer = ServerName; + type OutputPath = PathName; + + fn to_server_name(self) -> Self::OutputServer { + ServerName::from(self.into().as_ref()) + } + + fn to_path_name(self) -> Self::OutputPath { + PathName::from(self.into().as_ref()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn bytes_name_str_works() { + let s = "OK_string"; + let bn = s.to_path_name(); + let bn_lc = s.to_server_name(); + + assert_eq!("ok_string".as_bytes(), bn.as_ref()); + assert_eq!("ok_string".as_bytes(), bn_lc.as_ref()); + } + + #[test] + fn from_works() { + let s = "OK_string".to_server_name(); + let m = ServerName::from("OK_strinG".as_bytes()); + assert_eq!(s, m); + assert_eq!(s.as_ref(), "ok_string".as_bytes()); + assert_eq!(m.as_ref(), "ok_string".as_bytes()); + } + + #[test] + fn get_works() { + let s = "OK_str".to_path_name(); + let i = s.get(0); + assert_eq!(Some(&"o".as_bytes()[0]), i); + let i = s.get(1); + assert_eq!(Some(&"k".as_bytes()[0]), i); + let i = s.get(2); + assert_eq!(Some(&"_".as_bytes()[0]), i); + let i = s.get(3); + assert_eq!(Some(&"s".as_bytes()[0]), i); + let i = s.get(4); + assert_eq!(Some(&"t".as_bytes()[0]), i); + let i = s.get(5); + assert_eq!(Some(&"r".as_bytes()[0]), i); + let i = s.get(6); + assert_eq!(None, i); + } 
+ + #[test] + fn start_with_works() { + let s = "OK_str".to_path_name(); + let correct = "OK".to_path_name(); + let incorrect = "KO".to_path_name(); + assert!(s.starts_with(&correct)); + assert!(!s.starts_with(&incorrect)); + } + + #[test] + fn as_ref_works() { + let s = "OK_str".to_path_name(); + assert_eq!(s.as_ref(), "ok_str".as_bytes()); + } +} diff --git a/rpxy-lib/src/proxy/crypto_service.rs b/rpxy-lib/src/proxy/crypto_service.rs new file mode 100644 index 0000000..e69de29 diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index 9b55e1d..9718cc1 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -1,5 +1,6 @@ mod proxy_main; mod socket; +mod proxy_tls; use crate::{globals::Globals, hyper_executor::LocalExecutor}; use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index 416ecac..a024bd7 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -1,10 +1,11 @@ use super::socket::bind_tcp_socket; -use crate::{error::RpxyResult, globals::Globals, hyper_executor::LocalExecutor, log::*}; +use crate::{error::RpxyResult, globals::Globals, log::*}; +use hot_reload::{ReloaderReceiver, ReloaderService}; use hyper_util::server::conn::auto::Builder as ConnectionBuilder; use std::{net::SocketAddr, sync::Arc}; /// Proxy main object responsible to serve requests received from clients at the given socket address. 
-pub(crate) struct Proxy { +pub(crate) struct Proxy { /// global context shared among async tasks pub globals: Arc, /// listen socket address @@ -15,7 +16,7 @@ pub(crate) struct Proxy { pub connection_builder: Arc>, } -impl Proxy { +impl Proxy { /// Start without TLS (HTTP cleartext) async fn start_without_tls(&self) -> RpxyResult<()> { let listener_service = async { @@ -31,14 +32,27 @@ impl Proxy { Ok(()) } + /// Start with TLS (HTTPS) + pub(super) async fn start_with_tls(&self) -> RpxyResult<()> { + // let (cert_reloader_service, cert_reloader_rx) = ReloaderService::, ServerCryptoBase>::new( + // &self.globals.clone(), + // CERTS_WATCH_DELAY_SECS, + // !LOAD_CERTS_ONLY_WHEN_UPDATED, + // ) + // .await + // .map_err(|e| anyhow::anyhow!(e))?; + loop {} + Ok(()) + } + /// Entrypoint for HTTP/1.1, 2 and 3 servers pub async fn start(&self) -> RpxyResult<()> { let proxy_service = async { - // if self.tls_enabled { - // self.start_with_tls().await - // } else { - self.start_without_tls().await - // } + if self.tls_enabled { + self.start_with_tls().await + } else { + self.start_without_tls().await + } }; match &self.globals.term_notify { diff --git a/rpxy-lib/src/proxy/proxy_tls.rs b/rpxy-lib/src/proxy/proxy_tls.rs new file mode 100644 index 0000000..f67ad8d --- /dev/null +++ b/rpxy-lib/src/proxy/proxy_tls.rs @@ -0,0 +1,6 @@ +use super::proxy_main::Proxy; +use crate::{log::*, error::*}; + +impl Proxy{ + +} From 5576389acbc2bc0ddbc5b4250a9c37c76616b028 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 24 Nov 2023 19:17:02 +0900 Subject: [PATCH 06/50] wip: implemented crypto reloader, as separated object from proxy itself --- rpxy-lib/Cargo.toml | 36 ++-- rpxy-lib/src/backend/backend_main.rs | 2 +- rpxy-lib/src/backend/upstream.rs | 2 +- rpxy-lib/src/certs.rs | 22 --- rpxy-lib/src/crypto/certs.rs | 91 +++++++++ rpxy-lib/src/crypto/mod.rs | 36 ++++ rpxy-lib/src/crypto/service.rs | 272 +++++++++++++++++++++++++++ rpxy-lib/src/error.rs | 5 +- rpxy-lib/src/globals.rs | 11 
+- rpxy-lib/src/lib.rs | 44 ++++- rpxy-lib/src/proxy/crypto_service.rs | 0 rpxy-lib/src/proxy/mod.rs | 2 +- rpxy-lib/src/proxy/proxy_main.rs | 1 - 13 files changed, 468 insertions(+), 56 deletions(-) delete mode 100644 rpxy-lib/src/certs.rs create mode 100644 rpxy-lib/src/crypto/certs.rs create mode 100644 rpxy-lib/src/crypto/mod.rs create mode 100644 rpxy-lib/src/crypto/service.rs delete mode 100644 rpxy-lib/src/proxy/crypto_service.rs diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 2de993d..c0cb403 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -13,11 +13,11 @@ publish = false [features] default = ["http3-s2n", "sticky-cookie", "cache"] -http3-quinn = ["socket2"] #"quinn", "h3", "h3-quinn", ] -http3-s2n = [] #"h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] +http3-quinn = ["socket2", "quinn", "h3", "h3-quinn"] +http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] sticky-cookie = ["base64", "sha2", "chrono"] -cache = [] #"http-cache-semantics", "lru"] -native-roots = [] #"hyper-rustls/native-tokio"] +cache = [] #"http-cache-semantics", "lru"] +native-roots = [] #"hyper-rustls/native-tokio"] [dependencies] rand = "0.8.5" @@ -39,8 +39,7 @@ async-trait = "0.1.74" anyhow = "1.0.75" thiserror = "1.0.50" -# http and tls -hot_reload = "0.1.4" # reloading certs +# http http = "1.0.0" # http-body-util = "0.1.0" hyper = { version = "1.0.1", default-features = false } @@ -52,22 +51,25 @@ hyper-util = { version = "0.1.1", features = ["full"] } # "http2", # ] } # tokio-rustls = { version = "0.24.1", features = ["early-data"] } + +# tls and cert management +hot_reload = "0.1.4" rustls = { version = "0.21.9", default-features = false } -# webpki = "0.22.4" -# x509-parser = "0.15.1" +webpki = "0.22.4" +x509-parser = "0.15.1" # logging tracing = { version = "0.1.40" } -# # http/3 -# quinn = { version = "0.10.2", optional = true } -# h3 = { path = "../submodules/h3/h3/", optional = true } -# h3-quinn = { path = 
"../submodules/h3/h3-quinn/", optional = true } -# s2n-quic = { version = "1.31.0", default-features = false, features = [ -# "provider-tls-rustls", -# ], optional = true } -# s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } -# s2n-quic-rustls = { version = "0.31.0", optional = true } +# http/3 +quinn = { version = "0.10.2", optional = true } +h3 = { path = "../submodules/h3/h3/", optional = true } +h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } +s2n-quic = { version = "1.31.0", default-features = false, features = [ + "provider-tls-rustls", +], optional = true } +s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } +s2n-quic-rustls = { version = "0.31.0", optional = true } # for UDP socket wit SO_REUSEADDR when h3 with quinn socket2 = { version = "0.5.5", features = ["all"], optional = true } diff --git a/rpxy-lib/src/backend/backend_main.rs b/rpxy-lib/src/backend/backend_main.rs index 695a063..d9fa649 100644 --- a/rpxy-lib/src/backend/backend_main.rs +++ b/rpxy-lib/src/backend/backend_main.rs @@ -1,5 +1,5 @@ use crate::{ - certs::CryptoSource, + crypto::CryptoSource, error::*, log::*, name_exp::{ByteName, ServerName}, diff --git a/rpxy-lib/src/backend/upstream.rs b/rpxy-lib/src/backend/upstream.rs index 91e392d..ac50d69 100644 --- a/rpxy-lib/src/backend/upstream.rs +++ b/rpxy-lib/src/backend/upstream.rs @@ -7,7 +7,7 @@ use super::load_balance::{ // use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption}; use super::upstream_opts::UpstreamOption; use crate::{ - certs::CryptoSource, + crypto::CryptoSource, error::RpxyError, globals::{AppConfig, UpstreamUri}, log::*, diff --git a/rpxy-lib/src/certs.rs b/rpxy-lib/src/certs.rs deleted file mode 100644 index b93aa8f..0000000 --- a/rpxy-lib/src/certs.rs +++ /dev/null @@ -1,22 +0,0 @@ -use async_trait::async_trait; -use rustls::{Certificate, PrivateKey}; - -#[async_trait] -// Trait to read certs and keys anywhere from KVS, file, sqlite, etc. 
-pub trait CryptoSource { - type Error; - - /// read crypto materials from source - async fn read(&self) -> Result; - - /// Returns true when mutual tls is enabled - fn is_mutual_tls(&self) -> bool; -} - -/// Certificates and private keys in rustls loaded from files -#[derive(Debug, PartialEq, Eq, Clone)] -pub struct CertsAndKeys { - pub certs: Vec, - pub cert_keys: Vec, - pub client_ca_certs: Option>, -} diff --git a/rpxy-lib/src/crypto/certs.rs b/rpxy-lib/src/crypto/certs.rs new file mode 100644 index 0000000..c9cfafd --- /dev/null +++ b/rpxy-lib/src/crypto/certs.rs @@ -0,0 +1,91 @@ +use async_trait::async_trait; +use rustc_hash::FxHashSet as HashSet; +use rustls::{ + sign::{any_supported_type, CertifiedKey}, + Certificate, OwnedTrustAnchor, PrivateKey, +}; +use std::io; +use x509_parser::prelude::*; + +#[async_trait] +// Trait to read certs and keys anywhere from KVS, file, sqlite, etc. +pub trait CryptoSource { + type Error; + + /// read crypto materials from source + async fn read(&self) -> Result; + + /// Returns true when mutual tls is enabled + fn is_mutual_tls(&self) -> bool; +} + +/// Certificates and private keys in rustls loaded from files +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct CertsAndKeys { + pub certs: Vec, + pub cert_keys: Vec, + pub client_ca_certs: Option>, +} + +impl CertsAndKeys { + pub fn parse_server_certs_and_keys(&self) -> Result { + // for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() { + let signing_key = self + .cert_keys + .iter() + .find_map(|k| { + if let Ok(sk) = any_supported_type(k) { + Some(sk) + } else { + None + } + }) + .ok_or_else(|| { + io::Error::new( + io::ErrorKind::InvalidInput, + "Unable to find a valid certificate and key", + ) + })?; + Ok(CertifiedKey::new(self.certs.clone(), signing_key)) + } + + pub fn parse_client_ca_certs(&self) -> Result<(Vec, HashSet>), anyhow::Error> { + let certs = self.client_ca_certs.as_ref().ok_or(anyhow::anyhow!("No client cert"))?; + + let owned_trust_anchors: 
Vec<_> = certs + .iter() + .map(|v| { + // let trust_anchor = tokio_rustls::webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap(); + let trust_anchor = webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap(); + rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( + trust_anchor.subject, + trust_anchor.spki, + trust_anchor.name_constraints, + ) + }) + .collect(); + + // TODO: SKID is not used currently + let subject_key_identifiers: HashSet<_> = certs + .iter() + .filter_map(|v| { + // retrieve ca key id (subject key id) + let cert = parse_x509_certificate(&v.0).unwrap().1; + let subject_key_ids = cert + .iter_extensions() + .filter_map(|ext| match ext.parsed_extension() { + ParsedExtension::SubjectKeyIdentifier(skid) => Some(skid), + _ => None, + }) + .collect::>(); + if !subject_key_ids.is_empty() { + Some(subject_key_ids[0].0.to_owned()) + } else { + None + } + }) + .collect(); + + Ok((owned_trust_anchors, subject_key_identifiers)) + } +} diff --git a/rpxy-lib/src/crypto/mod.rs b/rpxy-lib/src/crypto/mod.rs new file mode 100644 index 0000000..1f6566d --- /dev/null +++ b/rpxy-lib/src/crypto/mod.rs @@ -0,0 +1,36 @@ +mod certs; +mod service; + +use crate::{ + backend::BackendAppManager, + constants::{CERTS_WATCH_DELAY_SECS, LOAD_CERTS_ONLY_WHEN_UPDATED}, + error::RpxyResult, +}; +use hot_reload::{ReloaderReceiver, ReloaderService}; +use service::CryptoReloader; +use std::sync::Arc; + +pub use certs::{CertsAndKeys, CryptoSource}; +pub use service::ServerCryptoBase; + +/// Result type inner of certificate reloader service +type ReloaderServiceResultInner = ( + ReloaderService, ServerCryptoBase>, + ReloaderReceiver, +); +/// Build certificate reloader service +pub(crate) async fn build_cert_reloader( + app_manager: &Arc>, +) -> RpxyResult> +where + T: CryptoSource + Clone + Send + Sync + 'static, +{ + let (cert_reloader_service, cert_reloader_rx) = ReloaderService::< + service::CryptoReloader, + service::ServerCryptoBase, + >::new( + app_manager, 
CERTS_WATCH_DELAY_SECS, !LOAD_CERTS_ONLY_WHEN_UPDATED + ) + .await?; + Ok((cert_reloader_service, cert_reloader_rx)) +} diff --git a/rpxy-lib/src/crypto/service.rs b/rpxy-lib/src/crypto/service.rs new file mode 100644 index 0000000..0736b0e --- /dev/null +++ b/rpxy-lib/src/crypto/service.rs @@ -0,0 +1,272 @@ +use super::certs::{CertsAndKeys, CryptoSource}; +use crate::{backend::BackendAppManager, log::*, name_exp::ServerName}; +use async_trait::async_trait; +use hot_reload::*; +use rustc_hash::FxHashMap as HashMap; +use rustls::{server::ResolvesServerCertUsingSni, sign::CertifiedKey, RootCertStore, ServerConfig}; +use std::sync::Arc; + +#[derive(Clone)] +/// Reloader service for certificates and keys for TLS +pub struct CryptoReloader +where + T: CryptoSource, +{ + inner: Arc>, +} + +/// SNI to ServerConfig map type +pub type SniServerCryptoMap = HashMap>; +/// SNI to ServerConfig map +pub struct ServerCrypto { + // For Quic/HTTP3, only servers with no client authentication + #[cfg(feature = "http3-quinn")] + pub inner_global_no_client_auth: Arc, + #[cfg(feature = "http3-s2n")] + pub inner_global_no_client_auth: s2n_quic_rustls::Server, + // For TLS over TCP/HTTP2 and 1.1, map of SNI to server_crypto for all given servers + pub inner_local_map: Arc, +} + +/// Reloader target for the certificate reloader service +#[derive(Debug, PartialEq, Eq, Clone, Default)] +pub struct ServerCryptoBase { + inner: HashMap, +} + +#[async_trait] +impl Reload for CryptoReloader +where + T: CryptoSource + Sync + Send, +{ + type Source = Arc>; + async fn new(source: &Self::Source) -> Result> { + Ok(Self { inner: source.clone() }) + } + + async fn reload(&self) -> Result, ReloaderError> { + let mut certs_and_keys_map = ServerCryptoBase::default(); + + for (server_name_bytes_exp, backend) in self.inner.apps.iter() { + if let Some(crypto_source) = &backend.crypto_source { + let certs_and_keys = crypto_source + .read() + .await + .map_err(|_e| ReloaderError::::Reload("Failed to reload 
cert, key or ca cert"))?; + certs_and_keys_map + .inner + .insert(server_name_bytes_exp.to_owned(), certs_and_keys); + } + } + + Ok(Some(certs_and_keys_map)) + } +} + +impl TryInto> for &ServerCryptoBase { + type Error = anyhow::Error; + + fn try_into(self) -> Result, Self::Error> { + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + let server_crypto_global = self.build_server_crypto_global()?; + let server_crypto_local_map: SniServerCryptoMap = self.build_server_crypto_local_map()?; + + Ok(Arc::new(ServerCrypto { + #[cfg(feature = "http3-quinn")] + inner_global_no_client_auth: Arc::new(server_crypto_global), + #[cfg(feature = "http3-s2n")] + inner_global_no_client_auth: server_crypto_global, + inner_local_map: Arc::new(server_crypto_local_map), + })) + } +} + +impl ServerCryptoBase { + fn build_server_crypto_local_map(&self) -> Result> { + let mut server_crypto_local_map: SniServerCryptoMap = HashMap::default(); + + for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() { + let server_name: String = server_name_bytes_exp.try_into()?; + + // Parse server certificates and private keys + let Ok(certified_key): Result = certs_and_keys.parse_server_certs_and_keys() else { + warn!("Failed to add certificate for {}", server_name); + continue; + }; + + let mut resolver_local = ResolvesServerCertUsingSni::new(); + let mut client_ca_roots_local = RootCertStore::empty(); + + // add server certificate and key + if let Err(e) = resolver_local.add(server_name.as_str(), certified_key.to_owned()) { + error!( + "{}: Failed to read some certificates and keys {}", + server_name.as_str(), + e + ) + } + + // add client certificate if specified + if certs_and_keys.client_ca_certs.is_some() { + // add client certificate if specified + match certs_and_keys.parse_client_ca_certs() { + Ok((owned_trust_anchors, _subject_key_ids)) => { + client_ca_roots_local.add_trust_anchors(owned_trust_anchors.into_iter()); + } + Err(e) => { + warn!( + "Failed to add client CA 
certificate for {}: {}", + server_name.as_str(), + e + ); + } + } + } + + let mut server_config_local = if client_ca_roots_local.is_empty() { + // with no client auth, enable http1.1 -- 3 + #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] + { + ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver_local)) + } + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + { + let mut sc = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver_local)); + sc.alpn_protocols = vec![b"h3".to_vec(), b"hq-29".to_vec()]; // TODO: remove hq-29 later? + sc + } + } else { + // with client auth, enable only http1.1 and 2 + // let client_certs_verifier = rustls::server::AllowAnyAnonymousOrAuthenticatedClient::new(client_ca_roots); + let client_certs_verifier = rustls::server::AllowAnyAuthenticatedClient::new(client_ca_roots_local); + ServerConfig::builder() + .with_safe_defaults() + .with_client_cert_verifier(Arc::new(client_certs_verifier)) + .with_cert_resolver(Arc::new(resolver_local)) + }; + server_config_local.alpn_protocols.push(b"h2".to_vec()); + server_config_local.alpn_protocols.push(b"http/1.1".to_vec()); + + server_crypto_local_map.insert(server_name_bytes_exp.to_owned(), Arc::new(server_config_local)); + } + Ok(server_crypto_local_map) + } + + #[cfg(feature = "http3-quinn")] + fn build_server_crypto_global(&self) -> Result> { + let mut resolver_global = ResolvesServerCertUsingSni::new(); + + for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() { + let server_name: String = server_name_bytes_exp.try_into()?; + + // Parse server certificates and private keys + let Ok(certified_key): Result = certs_and_keys.parse_server_certs_and_keys() else { + warn!("Failed to add certificate for {}", server_name); + continue; + }; + + if certs_and_keys.client_ca_certs.is_none() { + // aggregated server config for no client auth server for 
http3 + if let Err(e) = resolver_global.add(server_name.as_str(), certified_key) { + error!( + "{}: Failed to read some certificates and keys {}", + server_name.as_str(), + e + ) + } + } + } + + ////////////// + let mut server_crypto_global = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver_global)); + + ////////////////////////////// + + server_crypto_global.alpn_protocols = vec![ + b"h3".to_vec(), + b"hq-29".to_vec(), // TODO: remove later? + b"h2".to_vec(), + b"http/1.1".to_vec(), + ]; + Ok(server_crypto_global) + } + + #[cfg(feature = "http3-s2n")] + fn build_server_crypto_global(&self) -> Result> { + let mut resolver_global = s2n_quic_rustls::rustls::server::ResolvesServerCertUsingSni::new(); + + for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() { + let server_name: String = server_name_bytes_exp.try_into()?; + + // Parse server certificates and private keys + let Ok(certified_key) = parse_server_certs_and_keys_s2n(certs_and_keys) else { + warn!("Failed to add certificate for {}", server_name); + continue; + }; + + if certs_and_keys.client_ca_certs.is_none() { + // aggregated server config for no client auth server for http3 + if let Err(e) = resolver_global.add(server_name.as_str(), certified_key) { + error!( + "{}: Failed to read some certificates and keys {}", + server_name.as_str(), + e + ) + } + } + } + let alpn = vec![ + b"h3".to_vec(), + b"hq-29".to_vec(), // TODO: remove later? + b"h2".to_vec(), + b"http/1.1".to_vec(), + ]; + let server_crypto_global = s2n_quic::provider::tls::rustls::Server::builder() + .with_cert_resolver(Arc::new(resolver_global)) + .map_err(|e| anyhow::anyhow!(e))? + .with_application_protocols(alpn.iter()) + .map_err(|e| anyhow::anyhow!(e))? 
+ .build() + .map_err(|e| anyhow::anyhow!(e))?; + Ok(server_crypto_global) + } +} + +#[cfg(feature = "http3-s2n")] +/// This is workaround for the version difference between rustls and s2n-quic-rustls +fn parse_server_certs_and_keys_s2n( + certs_and_keys: &CertsAndKeys, +) -> Result { + let signing_key = certs_and_keys + .cert_keys + .iter() + .find_map(|k| { + let s2n_private_key = s2n_quic_rustls::PrivateKey(k.0.clone()); + if let Ok(sk) = s2n_quic_rustls::rustls::sign::any_supported_type(&s2n_private_key) { + Some(sk) + } else { + None + } + }) + .ok_or_else(|| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Unable to find a valid certificate and key", + ) + })?; + let certs: Vec<_> = certs_and_keys + .certs + .iter() + .map(|c| s2n_quic_rustls::rustls::Certificate(c.0.clone())) + .collect(); + Ok(s2n_quic_rustls::rustls::sign::CertifiedKey::new(certs, signing_key)) +} diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index bc730a9..05e5db1 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -9,12 +9,15 @@ pub enum RpxyError { #[error("IO error: {0}")] Io(#[from] std::io::Error), + #[error("Certificate reload error: {0}")] + CertificateReloadError(#[from] hot_reload::ReloaderError), + // backend errors #[error("Invalid reverse proxy setting")] InvalidReverseProxyConfig, #[error("Invalid upstream option setting")] InvalidUpstreamOptionSetting, - #[error("Failed to build backend app")] + #[error("Failed to build backend app: {0}")] FailedToBuildBackendApp(#[from] crate::backend::BackendAppBuilderError), #[error("Unsupported upstream option")] diff --git a/rpxy-lib/src/globals.rs b/rpxy-lib/src/globals.rs index 88d6cbf..71a2dca 100644 --- a/rpxy-lib/src/globals.rs +++ b/rpxy-lib/src/globals.rs @@ -1,9 +1,14 @@ -use crate::{certs::CryptoSource, constants::*, count::RequestCount}; +use crate::{ + constants::*, + count::RequestCount, + crypto::{CryptoSource, ServerCryptoBase}, +}; +use hot_reload::ReloaderReceiver; use 
std::{net::SocketAddr, sync::Arc, time::Duration}; /// Global object containing proxy configurations and shared object like counters. /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks. -pub struct Globals { +pub(crate) struct Globals { /// Configuration parameters for proxy transport and request handlers pub proxy_config: ProxyConfig, /// Shared context - Counter for serving requests @@ -12,6 +17,8 @@ pub struct Globals { pub runtime_handle: tokio::runtime::Handle, /// Shared context - Notify object to stop async tasks pub term_notify: Option>, + /// Shared context - Certificate reloader service receiver + pub cert_reloader_rx: Option>, } /// Configuration parameters for proxy transport and request handlers diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index 28e08c0..1f5fa37 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -1,7 +1,7 @@ mod backend; -mod certs; mod constants; mod count; +mod crypto; mod error; mod globals; mod hyper_executor; @@ -9,12 +9,12 @@ mod log; mod name_exp; mod proxy; -use crate::{error::*, globals::Globals, log::*, proxy::Proxy}; +use crate::{crypto::build_cert_reloader, error::*, globals::Globals, log::*, proxy::Proxy}; use futures::future::select_all; use std::sync::Arc; pub use crate::{ - certs::{CertsAndKeys, CryptoSource}, + crypto::{CertsAndKeys, CryptoSource}, globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri}, }; pub mod reexports { @@ -64,17 +64,27 @@ where info!("Cache is disabled") } - // build global shared context + // 1. build backends, and make it contained in Arc + let app_manager = Arc::new(backend::BackendAppManager::try_from(app_config_list)?); + + // 2. 
build crypto reloader service + let (cert_reloader_service, cert_reloader_rx) = match proxy_config.https_port { + Some(_) => { + let (s, r) = build_cert_reloader(&app_manager).await?; + (Some(s), Some(r)) + } + None => (None, None), + }; + + // 3. build global shared context let globals = Arc::new(Globals { proxy_config: proxy_config.clone(), request_count: Default::default(), runtime_handle: runtime_handle.clone(), term_notify: term_notify.clone(), + cert_reloader_rx: cert_reloader_rx.clone(), }); - // 1. build backends, and make it contained in Arc - let app_manager = Arc::new(backend::BackendAppManager::try_from(app_config_list)?); - // TODO: 2. build message handler with Arc-ed http_client and backends, and make it contained in Arc as well // // build message handler including a request forwarder // let msg_handler = Arc::new( @@ -106,9 +116,23 @@ where }); // wait for all future - if let (Ok(Err(e)), _, _) = select_all(futures_iter).await { - error!("Some proxy services are down: {}", e); - }; + match cert_reloader_service { + Some(cert_service) => { + tokio::select! 
{ + _ = cert_service.start() => { + error!("Certificate reloader service got down"); + } + _ = select_all(futures_iter) => { + error!("Some proxy services are down"); + } + } + } + None => { + if let (Ok(Err(e)), _, _) = select_all(futures_iter).await { + error!("Some proxy services are down: {}", e); + } + } + } Ok(()) } diff --git a/rpxy-lib/src/proxy/crypto_service.rs b/rpxy-lib/src/proxy/crypto_service.rs deleted file mode 100644 index e69de29..0000000 diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index 9718cc1..2ca21e4 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -1,6 +1,6 @@ mod proxy_main; -mod socket; mod proxy_tls; +mod socket; use crate::{globals::Globals, hyper_executor::LocalExecutor}; use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index a024bd7..5aea172 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -1,6 +1,5 @@ use super::socket::bind_tcp_socket; use crate::{error::RpxyResult, globals::Globals, log::*}; -use hot_reload::{ReloaderReceiver, ReloaderService}; use hyper_util::server::conn::auto::Builder as ConnectionBuilder; use std::{net::SocketAddr, sync::Arc}; From 1dc88ce0564cbb23ef35b792a6194c62b13e6e39 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 24 Nov 2023 22:23:57 +0900 Subject: [PATCH 07/50] wip: tested with synthetic echo response from h3 --- legacy-lib/Cargo.toml | 2 +- rpxy-bin/Cargo.toml | 6 +- rpxy-lib/Cargo.toml | 23 ++- rpxy-lib/src/crypto/mod.rs | 2 +- rpxy-lib/src/error.rs | 47 ++++++ rpxy-lib/src/proxy/mod.rs | 6 +- rpxy-lib/src/proxy/proxy_h3.rs | 205 +++++++++++++++++++++++ rpxy-lib/src/proxy/proxy_main.rs | 219 +++++++++++++++++++++++-- rpxy-lib/src/proxy/proxy_quic_quinn.rs | 121 ++++++++++++++ rpxy-lib/src/proxy/proxy_quic_s2n.rs | 132 +++++++++++++++ rpxy-lib/src/proxy/proxy_tls.rs | 6 - 11 files changed, 732 insertions(+), 
37 deletions(-) create mode 100644 rpxy-lib/src/proxy/proxy_h3.rs create mode 100644 rpxy-lib/src/proxy/proxy_quic_quinn.rs create mode 100644 rpxy-lib/src/proxy/proxy_quic_s2n.rs delete mode 100644 rpxy-lib/src/proxy/proxy_tls.rs diff --git a/legacy-lib/Cargo.toml b/legacy-lib/Cargo.toml index c975fb6..00f1edb 100644 --- a/legacy-lib/Cargo.toml +++ b/legacy-lib/Cargo.toml @@ -12,7 +12,7 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-s2n", "sticky-cookie", "cache"] +default = ["http3-quinn", "sticky-cookie", "cache"] http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] sticky-cookie = ["base64", "sha2", "chrono"] diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index 8cddc71..00c2ef4 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rpxy" -version = "0.6.2" +version = "0.7.0" authors = ["Jun Kurihara"] homepage = "https://github.com/junkurihara/rust-rpxy" repository = "https://github.com/junkurihara/rust-rpxy" @@ -12,7 +12,7 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-s2n", "cache"] +default = ["http3-quinn", "cache"] http3-quinn = ["rpxy-lib/http3-quinn"] http3-s2n = ["rpxy-lib/http3-s2n"] cache = ["rpxy-lib/cache"] @@ -20,7 +20,7 @@ native-roots = ["rpxy-lib/native-roots"] [dependencies] rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [ - # "sticky-cookie", + "sticky-cookie", ] } anyhow = "1.0.75" diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index c0cb403..847d3fe 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rpxy-lib" -version = "0.6.2" +version = "0.7.0" authors = ["Jun Kurihara"] homepage = "https://github.com/junkurihara/rust-rpxy" repository = 
"https://github.com/junkurihara/rust-rpxy" @@ -12,17 +12,23 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-s2n", "sticky-cookie", "cache"] +default = ["http3-quinn", "sticky-cookie", "cache"] http3-quinn = ["socket2", "quinn", "h3", "h3-quinn"] -http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] +http3-s2n = [ + "h3", + "s2n-quic", + "s2n-quic-core", + "s2n-quic-rustls", + "s2n-quic-h3", +] sticky-cookie = ["base64", "sha2", "chrono"] -cache = [] #"http-cache-semantics", "lru"] -native-roots = [] #"hyper-rustls/native-tokio"] +cache = [] #"http-cache-semantics", "lru"] +native-roots = [] #"hyper-rustls/native-tokio"] [dependencies] rand = "0.8.5" rustc-hash = "1.1.0" -# bytes = "1.5.0" +bytes = "1.5.0" derive_builder = "0.12.0" futures = { version = "0.3.29", features = ["alloc", "async-await"] } tokio = { version = "1.34.0", default-features = false, features = [ @@ -41,7 +47,7 @@ thiserror = "1.0.50" # http http = "1.0.0" -# http-body-util = "0.1.0" +http-body-util = "0.1.0" hyper = { version = "1.0.1", default-features = false } hyper-util = { version = "0.1.1", features = ["full"] } # hyper-rustls = { version = "0.24.2", default-features = false, features = [ @@ -50,11 +56,11 @@ hyper-util = { version = "0.1.1", features = ["full"] } # "http1", # "http2", # ] } -# tokio-rustls = { version = "0.24.1", features = ["early-data"] } # tls and cert management hot_reload = "0.1.4" rustls = { version = "0.21.9", default-features = false } +tokio-rustls = { version = "0.24.1", features = ["early-data"] } webpki = "0.22.4" x509-parser = "0.15.1" @@ -68,6 +74,7 @@ h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } s2n-quic = { version = "1.31.0", default-features = false, features = [ "provider-tls-rustls", ], optional = true } +s2n-quic-core = { version = "0.31.0", default-features = false, optional = true } s2n-quic-h3 = { path = 
"../submodules/s2n-quic-h3/", optional = true } s2n-quic-rustls = { version = "0.31.0", optional = true } # for UDP socket wit SO_REUSEADDR when h3 with quinn diff --git a/rpxy-lib/src/crypto/mod.rs b/rpxy-lib/src/crypto/mod.rs index 1f6566d..7b8935c 100644 --- a/rpxy-lib/src/crypto/mod.rs +++ b/rpxy-lib/src/crypto/mod.rs @@ -11,7 +11,7 @@ use service::CryptoReloader; use std::sync::Arc; pub use certs::{CertsAndKeys, CryptoSource}; -pub use service::ServerCryptoBase; +pub use service::{ServerCrypto, ServerCryptoBase, SniServerCryptoMap}; /// Result type inner of certificate reloader service type ReloaderServiceResultInner = ( diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 05e5db1..8152e0d 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -6,9 +6,51 @@ pub type RpxyResult = std::result::Result; /// Describes things that can go wrong in the Rpxy #[derive(Debug, Error)] pub enum RpxyError { + // general errors #[error("IO error: {0}")] Io(#[from] std::io::Error), + // TLS errors + #[error("Failed to build TLS acceptor: {0}")] + FailedToTlsHandshake(String), + #[error("No server name in ClientHello")] + NoServerNameInClientHello, + #[error("No TLS serving app: {0}")] + NoTlsServingApp(String), + #[error("Failed to update server crypto: {0}")] + FailedToUpdateServerCrypto(String), + #[error("No server crypto: {0}")] + NoServerCrypto(String), + + // hyper errors + #[error("hyper body manipulation error: {0}")] + HyperBodyManipulationError(String), + + // http/3 errors + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + #[error("H3 error: {0}")] + H3Error(#[from] h3::Error), + + #[cfg(feature = "http3-quinn")] + #[error("Invalid rustls TLS version: {0}")] + QuinnInvalidTlsProtocolVersion(String), + #[cfg(feature = "http3-quinn")] + #[error("Quinn connection error: {0}")] + QuinnConnectionFailed(#[from] quinn::ConnectionError), + + #[cfg(feature = "http3-s2n")] + #[error("s2n-quic validation error: {0}")] + 
S2nQuicValidationError(#[from] s2n_quic_core::transport::parameters::ValidationError), + #[cfg(feature = "http3-s2n")] + #[error("s2n-quic connection error: {0}")] + S2nQuicConnectionError(#[from] s2n_quic_core::connection::Error), + #[cfg(feature = "http3-s2n")] + #[error("s2n-quic start error: {0}")] + S2nQuicStartError(#[from] s2n_quic::provider::StartError), + + // certificate reloader errors + #[error("No certificate reloader when building a proxy for TLS")] + NoCertificateReloader, #[error("Certificate reload error: {0}")] CertificateReloadError(#[from] hot_reload::ReloaderError), @@ -20,6 +62,11 @@ pub enum RpxyError { #[error("Failed to build backend app: {0}")] FailedToBuildBackendApp(#[from] crate::backend::BackendAppBuilderError), + // Upstream connection setting errors #[error("Unsupported upstream option")] UnsupportedUpstreamOption, + + // Others + #[error("Infallible")] + Infallible(#[from] std::convert::Infallible), } diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index 2ca21e4..5b1ad61 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -1,5 +1,9 @@ +mod proxy_h3; mod proxy_main; -mod proxy_tls; +#[cfg(feature = "http3-quinn")] +mod proxy_quic_quinn; +#[cfg(feature = "http3-s2n")] +mod proxy_quic_s2n; mod socket; use crate::{globals::Globals, hyper_executor::LocalExecutor}; diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs new file mode 100644 index 0000000..056cd4b --- /dev/null +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -0,0 +1,205 @@ +use super::proxy_main::Proxy; +use crate::{error::*, log::*, name_exp::ServerName}; +use bytes::Bytes; +use http::{Request, Response}; +use http_body_util::BodyExt; +use std::{net::SocketAddr, time::Duration}; +use tokio::time::timeout; + +#[cfg(feature = "http3-quinn")] +use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; +#[cfg(feature = "http3-s2n")] +use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection 
as ConnectionQuic, server::RequestStream}; + +// use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; +// use futures::Stream; +// use hyper_util::client::legacy::connect::Connect; + +// impl Proxy +// where +// // T: Connect + Clone + Sync + Send + 'static, +// U: CryptoSource + Clone + Sync + Send + 'static, +// { + +impl Proxy { + pub(super) async fn h3_serve_connection( + &self, + quic_connection: C, + tls_server_name: ServerName, + client_addr: SocketAddr, + ) -> RpxyResult<()> + where + C: ConnectionQuic, + >::BidiStream: BidiStream + Send + 'static, + <>::BidiStream as BidiStream>::RecvStream: Send, + <>::BidiStream as BidiStream>::SendStream: Send, + { + let mut h3_conn = h3::server::Connection::<_, Bytes>::new(quic_connection).await?; + info!( + "QUIC/HTTP3 connection established from {:?} {}", + client_addr, + <&ServerName as TryInto>::try_into(&tls_server_name).unwrap_or_default() + ); + + // TODO: Is here enough to fetch server_name from NewConnection? + // to avoid deep nested call from listener_service_h3 + loop { + // this routine follows hyperium/h3 examples https://github.com/hyperium/h3/blob/master/examples/server.rs + match h3_conn.accept().await { + Ok(None) => { + break; + } + Err(e) => { + warn!("HTTP/3 error on accept incoming connection: {}", e); + match e.get_error_level() { + h3::error::ErrorLevel::ConnectionError => break, + h3::error::ErrorLevel::StreamError => continue, + } + } + Ok(Some((req, stream))) => { + // We consider the connection count separately from the stream count. + // Max clients for h1/h2 = max 'stream' for h3. 
+ let request_count = self.globals.request_count.clone(); + if request_count.increment() > self.globals.proxy_config.max_clients { + request_count.decrement(); + h3_conn.shutdown(0).await?; + break; + } + debug!("Request incoming: current # {}", request_count.current()); + + let self_inner = self.clone(); + let tls_server_name_inner = tls_server_name.clone(); + self.globals.runtime_handle.spawn(async move { + if let Err(e) = timeout( + self_inner.globals.proxy_config.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2 + self_inner.h3_serve_stream(req, stream, client_addr, tls_server_name_inner), + ) + .await + { + error!("HTTP/3 failed to process stream: {}", e); + } + request_count.decrement(); + debug!("Request processed: current # {}", request_count.current()); + }); + } + } + } + + Ok(()) + } + + /// Serves a request stream from a client + /// TODO: TODO: TODO: TODO: + /// TODO: Body in hyper-0.14 was changed to Incoming in hyper-1.0, and it is not accessible from outside. + /// Thus, we need to implement IncomingLike trait using channel. Also, the backend handler must feed the body in the form of + /// Either as body. + /// Also, the downstream from the backend handler could be Incoming, but will be wrapped as Either as well due to H3. + /// Result, E> type includes E as HttpError to generate the status code and related Response. + /// Thus to handle synthetic error messages in BoxBody, the serve() function outputs Response, BoxBody>>>. 
+ async fn h3_serve_stream( + &self, + req: Request<()>, + stream: RequestStream, + client_addr: SocketAddr, + tls_server_name: ServerName, + ) -> RpxyResult<()> + where + S: BidiStream + Send + 'static, + >::RecvStream: Send, + { + let (req_parts, _) = req.into_parts(); + // split stream and async body handling + let (mut send_stream, mut recv_stream) = stream.split(); + + // let max_body_size = self.globals.proxy_config.h3_request_max_body_size; + // // let max = body_stream.size_hint().upper().unwrap_or(u64::MAX); + // // if max > max_body_size as u64 { + // // return Err(HttpError::TooLargeRequestBody); + // // } + + // let new_req = Request::from_parts(req_parts, body_stream); + + // // generate streamed body with trailers using channel + // let (body_sender, req_body) = Incoming::channel(); + + // // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1 + // // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn. + // let max_body_size = self.globals.proxy_config.h3_request_max_body_size; + // self.globals.runtime_handle.spawn(async move { + // // let mut sender = body_sender; + // let mut size = 0usize; + // while let Some(mut body) = recv_stream.recv_data().await? { + // debug!("HTTP/3 incoming request body: remaining {}", body.remaining()); + // size += body.remaining(); + // if size > max_body_size { + // error!( + // "Exceeds max request body size for HTTP/3: received {}, maximum_allowd {}", + // size, max_body_size + // ); + // return Err(RpxyError::Proxy("Exceeds max request body size for HTTP/3".to_string())); + // } + // // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes + // // sender.send_data(body.copy_to_bytes(body.remaining())).await?; + // } + + // // trailers: use inner for work around. 
(directly get trailer) + // let trailers = recv_stream.as_mut().recv_trailers().await?; + // if trailers.is_some() { + // debug!("HTTP/3 incoming request trailers"); + // // sender.send_trailers(trailers.unwrap()).await?; + // } + // Ok(()) + // }); + + // let new_req: Request = Request::from_parts(req_parts, req_body); + // let res = self + // .msg_handler + // .clone() + // .handle_request( + // new_req, + // client_addr, + // self.listening_on, + // self.tls_enabled, + // Some(tls_server_name), + // ) + // .await?; + + // TODO: TODO: TODO: remove later + let body = full(hyper::body::Bytes::from("hello h3 echo")); + let res = Response::builder().body(body).unwrap(); + ///////////////// + + let (new_res_parts, new_body) = res.into_parts(); + let new_res = Response::from_parts(new_res_parts, ()); + + match send_stream.send_response(new_res).await { + Ok(_) => { + debug!("HTTP/3 response to connection successful"); + // aggregate body without copying + let body_data = new_body + .collect() + .await + .map_err(|e| RpxyError::HyperBodyManipulationError(e.to_string()))?; + + // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes inside to_bytes() + send_stream.send_data(body_data.to_bytes()).await?; + + // TODO: needs handling trailer? should be included in body from handler. + } + Err(err) => { + error!("Unable to send response to connection peer: {:?}", err); + } + } + Ok(send_stream.finish().await?) 
+ } +} + +////////////// +/// TODO: remove later +/// helper function to build a full body +use http_body_util::Full; +pub(crate) type BoxBody = http_body_util::combinators::BoxBody; +pub fn full(body: hyper::body::Bytes) -> BoxBody { + Full::new(body).map_err(|never| match never {}).boxed() +} +////////////// diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index 5aea172..0d6eb83 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -1,10 +1,59 @@ use super::socket::bind_tcp_socket; -use crate::{error::RpxyResult, globals::Globals, log::*}; -use hyper_util::server::conn::auto::Builder as ConnectionBuilder; -use std::{net::SocketAddr, sync::Arc}; +use crate::{ + constants::TLS_HANDSHAKE_TIMEOUT_SEC, + crypto::{ServerCrypto, SniServerCryptoMap}, + error::*, + globals::Globals, + hyper_executor::LocalExecutor, + log::*, + name_exp::ServerName, +}; +use futures::{select, FutureExt}; +use http::{Request, Response}; +use hyper::{ + body::Incoming, + rt::{Read, Write}, + service::service_fn, +}; +use hyper_util::{rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; +use tokio::time::timeout; +/// Wrapper function to handle request for HTTP/1.1 and HTTP/2 +/// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler +async fn serve_request( + req: Request, + // handler: Arc>, + // handler: Arc>, + client_addr: SocketAddr, + listen_addr: SocketAddr, + tls_enabled: bool, + tls_server_name: Option, +) -> RpxyResult> { + // match handler + // .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) + // .await? 
+ // { + // Ok(res) => passthrough_response(res), + // Err(e) => synthetic_error_response(StatusCode::from(e)), + // } + let body = full(hyper::body::Bytes::from("hello")); + let res = Response::builder().body(body).unwrap(); + Ok(res) +} +////////////// +/// TODO: remove later +/// helper function to build a full body +use http_body_util::{BodyExt, Full}; +pub(crate) type BoxBody = http_body_util::combinators::BoxBody; +pub fn full(body: hyper::body::Bytes) -> BoxBody { + Full::new(body).map_err(|never| match never {}).boxed() +} +////////////// + +#[derive(Clone)] /// Proxy main object responsible to serve requests received from clients at the given socket address. -pub(crate) struct Proxy { +pub(crate) struct Proxy { /// global context shared among async tasks pub globals: Arc, /// listen socket address @@ -15,7 +64,49 @@ pub(crate) struct Proxy { pub connection_builder: Arc>, } -impl Proxy { +impl Proxy { + /// Serves requests from clients + fn serve_connection(&self, stream: I, peer_addr: SocketAddr, tls_server_name: Option) + where + I: Read + Write + Send + Unpin + 'static, + { + let request_count = self.globals.request_count.clone(); + if request_count.increment() > self.globals.proxy_config.max_clients { + request_count.decrement(); + return; + } + debug!("Request incoming: current # {}", request_count.current()); + + let server_clone = self.connection_builder.clone(); + // let msg_handler_clone = self.msg_handler.clone(); + let timeout_sec = self.globals.proxy_config.proxy_timeout; + let tls_enabled = self.tls_enabled; + let listening_on = self.listening_on; + self.globals.runtime_handle.clone().spawn(async move { + timeout( + timeout_sec + Duration::from_secs(1), + server_clone.serve_connection_with_upgrades( + stream, + service_fn(move |req: Request| { + serve_request( + req, + // msg_handler_clone.clone(), + peer_addr, + listening_on, + tls_enabled, + tls_server_name.clone(), + ) + }), + ), + ) + .await + .ok(); + + request_count.decrement(); + 
debug!("Request processed: current # {}", request_count.current()); + }); + } + /// Start without TLS (HTTP cleartext) async fn start_without_tls(&self) -> RpxyResult<()> { let listener_service = async { @@ -23,7 +114,7 @@ impl Proxy { let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; info!("Start TCP proxy serving with HTTP request for configured host names"); while let Ok((stream, client_addr)) = tcp_listener.accept().await { - // self.serve_connection(TokioIo::new(stream), client_addr, None); + self.serve_connection(TokioIo::new(stream), client_addr, None); } Ok(()) as RpxyResult<()> }; @@ -33,14 +124,108 @@ impl Proxy { /// Start with TLS (HTTPS) pub(super) async fn start_with_tls(&self) -> RpxyResult<()> { - // let (cert_reloader_service, cert_reloader_rx) = ReloaderService::, ServerCryptoBase>::new( - // &self.globals.clone(), - // CERTS_WATCH_DELAY_SECS, - // !LOAD_CERTS_ONLY_WHEN_UPDATED, - // ) - // .await - // .map_err(|e| anyhow::anyhow!(e))?; - loop {} + #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] + { + self.tls_listener_service().await?; + error!("TCP proxy service for TLS exited"); + Ok(()) + } + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + { + if self.globals.proxy_config.http3 { + select! 
{ + _ = self.tls_listener_service().fuse() => { + error!("TCP proxy service for TLS exited"); + }, + _ = self.h3_listener_service().fuse() => { + error!("UDP proxy service for QUIC exited"); + } + }; + Ok(()) + } else { + self.tls_listener_service().await?; + error!("TCP proxy service for TLS exited"); + Ok(()) + } + } + } + + // TCP Listener Service, i.e., http/2 and http/1.1 + async fn tls_listener_service(&self) -> RpxyResult<()> { + let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { + return Err(RpxyError::NoCertificateReloader); + }; + let tcp_socket = bind_tcp_socket(&self.listening_on)?; + let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; + info!("Start TCP proxy serving with HTTPS request for configured host names"); + + let mut server_crypto_map: Option> = None; + loop { + select! { + tcp_cnx = tcp_listener.accept().fuse() => { + if tcp_cnx.is_err() || server_crypto_map.is_none() { + continue; + } + let (raw_stream, client_addr) = tcp_cnx.unwrap(); + let sc_map_inner = server_crypto_map.clone(); + let self_inner = self.clone(); + + // spawns async handshake to avoid blocking thread by sequential handshake. 
+ let handshake_fut = async move { + let acceptor = tokio_rustls::LazyConfigAcceptor::new(tokio_rustls::rustls::server::Acceptor::default(), raw_stream).await; + if let Err(e) = acceptor { + return Err(RpxyError::FailedToTlsHandshake(e.to_string())); + } + let start = acceptor.unwrap(); + let client_hello = start.client_hello(); + let sni = client_hello.server_name(); + debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", sni.unwrap_or("None")); + let server_name = sni.map(ServerName::from); + if server_name.is_none(){ + return Err(RpxyError::NoServerNameInClientHello); + } + let server_crypto = sc_map_inner.as_ref().unwrap().get(server_name.as_ref().unwrap()); + if server_crypto.is_none() { + return Err(RpxyError::NoTlsServingApp(server_name.as_ref().unwrap().try_into().unwrap_or_default())); + } + let stream = match start.into_stream(server_crypto.unwrap().clone()).await { + Ok(s) => TokioIo::new(s), + Err(e) => { + return Err(RpxyError::FailedToTlsHandshake(e.to_string())); + } + }; + self_inner.serve_connection(stream, client_addr, server_name); + Ok(()) as RpxyResult<()> + }; + + self.globals.runtime_handle.spawn( async move { + // timeout is introduced to avoid get stuck here. + let Ok(v) = timeout( + Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), + handshake_fut + ).await else { + error!("Timeout to handshake TLS"); + return; + }; + if let Err(e) = v { + error!("{}", e); + } + }); + } + _ = server_crypto_rx.changed().fuse() => { + if server_crypto_rx.borrow().is_none() { + error!("Reloader is broken"); + break; + } + let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); + let Some(server_crypto): Option> = (&cert_keys_map).try_into().ok() else { + error!("Failed to update server crypto"); + break; + }; + server_crypto_map = Some(server_crypto.inner_local_map.clone()); + } + } + } Ok(()) } @@ -56,11 +241,11 @@ impl Proxy { match &self.globals.term_notify { Some(term) => { - tokio::select! { - _ = proxy_service => { + select! 
{ + _ = proxy_service.fuse() => { warn!("Proxy service got down"); } - _ = term.notified() => { + _ = term.notified().fuse() => { info!("Proxy service listening on {} receives term signal", self.listening_on); } } diff --git a/rpxy-lib/src/proxy/proxy_quic_quinn.rs b/rpxy-lib/src/proxy/proxy_quic_quinn.rs new file mode 100644 index 0000000..bde3b00 --- /dev/null +++ b/rpxy-lib/src/proxy/proxy_quic_quinn.rs @@ -0,0 +1,121 @@ +use super::proxy_main::Proxy; +use super::socket::bind_udp_socket; +use crate::{crypto::ServerCrypto, error::*, log::*, name_exp::ByteName}; +// use hyper_util::client::legacy::connect::Connect; +use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; +use rustls::ServerConfig; +use std::sync::Arc; + +impl Proxy +// where +// // T: Connect + Clone + Sync + Send + 'static, +// U: CryptoSource + Clone + Sync + Send + 'static, +{ + pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { + let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { + return Err(RpxyError::NoCertificateReloader); + }; + info!("Start UDP proxy serving with HTTP/3 request for configured host names [quinn]"); + // first set as null config server + let rustls_server_config = ServerConfig::builder() + .with_safe_default_cipher_suites() + .with_safe_default_kx_groups() + .with_protocol_versions(&[&rustls::version::TLS13]) + .map_err(|e| RpxyError::QuinnInvalidTlsProtocolVersion(e.to_string()))? 
+ .with_no_client_auth() + .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new())); + + let mut transport_config_quic = TransportConfig::default(); + transport_config_quic + .max_concurrent_bidi_streams(self.globals.proxy_config.h3_max_concurrent_bidistream.into()) + .max_concurrent_uni_streams(self.globals.proxy_config.h3_max_concurrent_unistream.into()) + .max_idle_timeout( + self + .globals + .proxy_config + .h3_max_idle_timeout + .map(|v| quinn::IdleTimeout::try_from(v).unwrap()), + ); + + let mut server_config_h3 = QuicServerConfig::with_crypto(Arc::new(rustls_server_config)); + server_config_h3.transport = Arc::new(transport_config_quic); + server_config_h3.concurrent_connections(self.globals.proxy_config.h3_max_concurrent_connections); + + // To reuse address + let udp_socket = bind_udp_socket(&self.listening_on)?; + let runtime = quinn::default_runtime() + .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "No async runtime found"))?; + let endpoint = Endpoint::new( + quinn::EndpointConfig::default(), + Some(server_config_h3), + udp_socket, + runtime, + )?; + + let mut server_crypto: Option> = None; + loop { + tokio::select! { + new_conn = endpoint.accept() => { + if server_crypto.is_none() || new_conn.is_none() { + continue; + } + let mut conn: quinn::Connecting = new_conn.unwrap(); + let Ok(hsd) = conn.handshake_data().await else { + continue + }; + + let Ok(hsd_downcast) = hsd.downcast::() else { + continue + }; + let Some(new_server_name) = hsd_downcast.server_name else { + warn!("HTTP/3 no SNI is given"); + continue; + }; + debug!( + "HTTP/3 connection incoming (SNI {:?})", + new_server_name + ); + // TODO: server_nameをここで出してどんどん深く投げていくのは効率が悪い。connecting -> connectionsの後でいいのでは? 
+ // TODO: 通常のTLSと同じenumか何かにまとめたい + let self_clone = self.clone(); + self.globals.runtime_handle.spawn(async move { + let client_addr = conn.remote_address(); + let quic_connection = match conn.await { + Ok(new_conn) => { + info!("New connection established"); + h3_quinn::Connection::new(new_conn) + }, + Err(e) => { + warn!("QUIC accepting connection failed: {:?}", e); + return Err(RpxyError::QuinnConnectionFailed(e)); + } + }; + // Timeout is based on underlying quic + if let Err(e) = self_clone.h3_serve_connection(quic_connection, new_server_name.to_server_name(), client_addr).await { + warn!("QUIC or HTTP/3 connection failed: {}", e); + }; + Ok(()) + }); + } + _ = server_crypto_rx.changed() => { + if server_crypto_rx.borrow().is_none() { + error!("Reloader is broken"); + break; + } + let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); + + server_crypto = (&cert_keys_map).try_into().ok(); + let Some(inner) = server_crypto.clone() else { + error!("Failed to update server crypto for h3"); + break; + }; + endpoint.set_server_config(Some(QuicServerConfig::with_crypto(inner.clone().inner_global_no_client_auth.clone()))); + + } + else => break + } + } + endpoint.wait_idle().await; + Ok(()) as RpxyResult<()> + } +} diff --git a/rpxy-lib/src/proxy/proxy_quic_s2n.rs b/rpxy-lib/src/proxy/proxy_quic_s2n.rs new file mode 100644 index 0000000..32be619 --- /dev/null +++ b/rpxy-lib/src/proxy/proxy_quic_s2n.rs @@ -0,0 +1,132 @@ +use super::proxy_main::Proxy; +use crate::{ + crypto::{ServerCrypto, ServerCryptoBase}, + error::*, + log::*, + name_exp::ByteName, +}; +use hot_reload::ReloaderReceiver; +use std::sync::Arc; +// use hyper_util::client::legacy::connect::Connect; +use s2n_quic::provider; + +impl Proxy { + /// Start UDP proxy serving with HTTP/3 request for configured host names + pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { + let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { + return 
Err(RpxyError::NoCertificateReloader); + }; + info!("Start UDP proxy serving with HTTP/3 request for configured host names [s2n-quic]"); + + // initially wait for receipt + let mut server_crypto: Option> = { + let _ = server_crypto_rx.changed().await; + let sc = self.receive_server_crypto(server_crypto_rx.clone())?; + Some(sc) + }; + + // event loop + loop { + tokio::select! { + v = self.h3_listener_service_inner(&server_crypto) => { + if let Err(e) = v { + error!("Quic connection event loop illegally shutdown [s2n-quic] {e}"); + break; + } + } + _ = server_crypto_rx.changed() => { + server_crypto = match self.receive_server_crypto(server_crypto_rx.clone()) { + Ok(sc) => Some(sc), + Err(e) => { + error!("{e}"); + break; + } + }; + } + else => break + } + } + + Ok(()) + } + + /// Receive server crypto from reloader + fn receive_server_crypto( + &self, + server_crypto_rx: ReloaderReceiver, + ) -> RpxyResult> { + let cert_keys_map = server_crypto_rx.borrow().clone().ok_or_else(|| { + error!("Reloader is broken"); + RpxyError::CertificateReloadError(anyhow!("Reloader is broken").into()) + })?; + + let server_crypto: Option> = (&cert_keys_map).try_into().ok(); + server_crypto.ok_or_else(|| { + error!("Failed to update server crypto for h3 [s2n-quic]"); + RpxyError::FailedToUpdateServerCrypto("Failed to update server crypto for h3 [s2n-quic]".to_string()) + }) + } + + /// Event loop for UDP proxy serving with HTTP/3 request for configured host names + async fn h3_listener_service_inner(&self, server_crypto: &Option>) -> RpxyResult<()> { + // setup UDP socket + let io = provider::io::tokio::Builder::default() + .with_receive_address(self.listening_on)? + .with_reuse_port()? + .build()?; + + // setup limits + let mut limits = provider::limits::Limits::default() + .with_max_open_local_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64)? 
+ .with_max_open_remote_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64)? + .with_max_open_local_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64)? + .with_max_open_remote_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64)? + .with_max_active_connection_ids(self.globals.proxy_config.h3_max_concurrent_connections as u64)?; + limits = if let Some(v) = self.globals.proxy_config.h3_max_idle_timeout { + limits.with_max_idle_timeout(v)? + } else { + limits + }; + + // setup tls + let Some(server_crypto) = server_crypto else { + warn!("No server crypto is given [s2n-quic]"); + return Err(RpxyError::NoServerCrypto( + "No server crypto is given [s2n-quic]".to_string(), + )); + }; + let tls = server_crypto.inner_global_no_client_auth.clone(); + + let mut server = s2n_quic::Server::builder() + .with_tls(tls)? + .with_io(io)? + .with_limits(limits)? + .start()?; + + // quic event loop. this immediately cancels when crypto is updated by tokio::select! 
+ while let Some(new_conn) = server.accept().await { + debug!("New QUIC connection established"); + let Ok(Some(new_server_name)) = new_conn.server_name() else { + warn!("HTTP/3 no SNI is given"); + continue; + }; + debug!("HTTP/3 connection incoming (SNI {:?})", new_server_name); + let self_clone = self.clone(); + + self.globals.runtime_handle.spawn(async move { + let client_addr = new_conn.remote_addr()?; + let quic_connection = s2n_quic_h3::Connection::new(new_conn); + // Timeout is based on underlying quic + if let Err(e) = self_clone + .h3_serve_connection(quic_connection, new_server_name.to_server_name(), client_addr) + .await + { + warn!("QUIC or HTTP/3 connection failed: {}", e); + }; + Ok(()) as RpxyResult<()> + }); + } + + Ok(()) + } +} diff --git a/rpxy-lib/src/proxy/proxy_tls.rs b/rpxy-lib/src/proxy/proxy_tls.rs deleted file mode 100644 index f67ad8d..0000000 --- a/rpxy-lib/src/proxy/proxy_tls.rs +++ /dev/null @@ -1,6 +0,0 @@ -use super::proxy_main::Proxy; -use crate::{log::*, error::*}; - -impl Proxy{ - -} From 4b6f63e09fccf2ca0f697e720c0fed9797a65f1d Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 25 Nov 2023 00:26:00 +0900 Subject: [PATCH 08/50] wip: implemented incoming-like body for asynchronous operation in http/3 --- rpxy-lib/Cargo.toml | 2 + rpxy-lib/src/error.rs | 5 + rpxy-lib/src/hyper_ext/body_incoming_like.rs | 189 ++++++++++++++++++ rpxy-lib/src/hyper_ext/body_type.rs | 41 ++++ .../executor.rs} | 0 rpxy-lib/src/hyper_ext/mod.rs | 13 ++ rpxy-lib/src/hyper_ext/watch.rs | 67 +++++++ rpxy-lib/src/lib.rs | 2 +- rpxy-lib/src/proxy/mod.rs | 2 +- rpxy-lib/src/proxy/proxy_h3.rs | 97 ++++----- rpxy-lib/src/proxy/proxy_main.rs | 26 +-- 11 files changed, 376 insertions(+), 68 deletions(-) create mode 100644 rpxy-lib/src/hyper_ext/body_incoming_like.rs create mode 100644 rpxy-lib/src/hyper_ext/body_type.rs rename rpxy-lib/src/{hyper_executor.rs => hyper_ext/executor.rs} (100%) create mode 100644 rpxy-lib/src/hyper_ext/mod.rs create mode 100644 
rpxy-lib/src/hyper_ext/watch.rs diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 847d3fe..3469162 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -50,6 +50,8 @@ http = "1.0.0" http-body-util = "0.1.0" hyper = { version = "1.0.1", default-features = false } hyper-util = { version = "0.1.1", features = ["full"] } +futures-util = { version = "0.3.29", default-features = false } +futures-channel = { version = "0.3.29", default-features = false } # hyper-rustls = { version = "0.24.2", default-features = false, features = [ # "tokio-runtime", # "webpki-tokio", diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 8152e0d..37f35e0 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -25,11 +25,16 @@ pub enum RpxyError { // hyper errors #[error("hyper body manipulation error: {0}")] HyperBodyManipulationError(String), + #[error("New closed in incoming-like")] + HyperIncomingLikeNewClosed, // http/3 errors #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] #[error("H3 error: {0}")] H3Error(#[from] h3::Error), + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + #[error("Exceeds max request body size for HTTP/3")] + H3TooLargeBody, #[cfg(feature = "http3-quinn")] #[error("Invalid rustls TLS version: {0}")] diff --git a/rpxy-lib/src/hyper_ext/body_incoming_like.rs b/rpxy-lib/src/hyper_ext/body_incoming_like.rs new file mode 100644 index 0000000..2fced25 --- /dev/null +++ b/rpxy-lib/src/hyper_ext/body_incoming_like.rs @@ -0,0 +1,189 @@ +use super::watch; +use crate::error::*; +use futures_channel::{mpsc, oneshot}; +use futures_util::{stream::FusedStream, Future, Stream}; +use http::HeaderMap; +use hyper::body::{Body, Bytes, Frame, SizeHint}; +use std::{ + pin::Pin, + task::{Context, Poll}, +}; + +//////////////////////////////////////////////////////////// +/// Incoming like body to handle incoming request body +pub struct IncomingLike { + content_length: DecodedLength, + want_tx: watch::Sender, + data_rx: 
mpsc::Receiver>, + trailers_rx: oneshot::Receiver, +} + +macro_rules! ready { + ($e:expr) => { + match $e { + Poll::Ready(v) => v, + Poll::Pending => return Poll::Pending, + } + }; +} + +type BodySender = mpsc::Sender>; +type TrailersSender = oneshot::Sender; + +#[derive(Clone, Copy, PartialEq, Eq)] +pub(crate) struct DecodedLength(u64); +impl DecodedLength { + pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX); + pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); + pub(crate) const ZERO: DecodedLength = DecodedLength(0); + + pub(crate) fn sub_if(&mut self, amt: u64) { + match *self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), + DecodedLength(ref mut known) => { + *known -= amt; + } + } + } + /// Converts to an Option representing a Known or Unknown length. + pub(crate) fn into_opt(self) -> Option { + match self { + DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => None, + DecodedLength(known) => Some(known), + } + } +} +pub(crate) struct Sender { + want_rx: watch::Receiver, + data_tx: BodySender, + trailers_tx: Option, +} + +const WANT_PENDING: usize = 1; +const WANT_READY: usize = 2; + +impl IncomingLike { + /// Create a `Body` stream with an associated sender half. + /// + /// Useful when wanting to stream chunks from another thread. + #[inline] + #[allow(unused)] + pub(crate) fn channel() -> (Sender, IncomingLike) { + Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false) + } + + pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, IncomingLike) { + let (data_tx, data_rx) = mpsc::channel(0); + let (trailers_tx, trailers_rx) = oneshot::channel(); + + // If wanter is true, `Sender::poll_ready()` won't becoming ready + // until the `Body` has been polled for data once. 
+ let want = if wanter { WANT_PENDING } else { WANT_READY }; + + let (want_tx, want_rx) = watch::channel(want); + + let tx = Sender { + want_rx, + data_tx, + trailers_tx: Some(trailers_tx), + }; + let rx = IncomingLike { + content_length, + want_tx, + data_rx, + trailers_rx, + }; + + (tx, rx) + } +} + +impl Body for IncomingLike { + type Data = Bytes; + type Error = hyper::Error; + + fn poll_frame( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll, Self::Error>>> { + self.want_tx.send(WANT_READY); + + if !self.data_rx.is_terminated() { + if let Some(chunk) = ready!(Pin::new(&mut self.data_rx).poll_next(cx)?) { + self.content_length.sub_if(chunk.len() as u64); + return Poll::Ready(Some(Ok(Frame::data(chunk)))); + } + } + + // check trailers after data is terminated + match ready!(Pin::new(&mut self.trailers_rx).poll(cx)) { + Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))), + Err(_) => Poll::Ready(None), + } + } + + fn is_end_stream(&self) -> bool { + self.content_length == DecodedLength::ZERO + } + + fn size_hint(&self) -> SizeHint { + macro_rules! opt_len { + ($content_length:expr) => {{ + let mut hint = SizeHint::default(); + + if let Some(content_length) = $content_length.into_opt() { + hint.set_exact(content_length); + } + + hint + }}; + } + + opt_len!(self.content_length) + } +} + +impl Sender { + /// Check to see if this `Sender` can send more data. 
+ pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + // Check if the receiver end has tried polling for the body yet + ready!(self.poll_want(cx)?); + self + .data_tx + .poll_ready(cx) + .map_err(|_| RpxyError::HyperIncomingLikeNewClosed) + } + + fn poll_want(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.want_rx.load(cx) { + WANT_READY => Poll::Ready(Ok(())), + WANT_PENDING => Poll::Pending, + watch::CLOSED => Poll::Ready(Err(RpxyError::HyperIncomingLikeNewClosed)), + unexpected => unreachable!("want_rx value: {}", unexpected), + } + } + + async fn ready(&mut self) -> RpxyResult<()> { + futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await + } + + /// Send data on data channel when it is ready. + #[allow(unused)] + pub(crate) async fn send_data(&mut self, chunk: Bytes) -> RpxyResult<()> { + self.ready().await?; + self + .data_tx + .try_send(Ok(chunk)) + .map_err(|_| RpxyError::HyperIncomingLikeNewClosed) + } + + /// Send trailers on trailers channel. 
+ #[allow(unused)] + pub(crate) async fn send_trailers(&mut self, trailers: HeaderMap) -> RpxyResult<()> { + let tx = match self.trailers_tx.take() { + Some(tx) => tx, + None => return Err(RpxyError::HyperIncomingLikeNewClosed), + }; + tx.send(trailers).map_err(|_| RpxyError::HyperIncomingLikeNewClosed) + } +} diff --git a/rpxy-lib/src/hyper_ext/body_type.rs b/rpxy-lib/src/hyper_ext/body_type.rs new file mode 100644 index 0000000..ba1bdc2 --- /dev/null +++ b/rpxy-lib/src/hyper_ext/body_type.rs @@ -0,0 +1,41 @@ +use crate::error::*; +use http::{Response, StatusCode}; +use http_body_util::{combinators, BodyExt, Either, Empty, Full}; +use hyper::body::{Bytes, Incoming}; + +/// Type for synthetic boxed body +pub(crate) type BoxBody = combinators::BoxBody; +/// Type for either passthrough body or given body type, specifically synthetic boxed body +pub(crate) type IncomingOr = Either; + +/// helper function to build http response with passthrough body +pub(crate) fn passthrough_response(response: Response) -> RpxyResult>> +where + B: hyper::body::Body, +{ + Ok(response.map(IncomingOr::Left)) +} + +/// helper function to build http response with synthetic body +pub(crate) fn synthetic_response(response: Response) -> RpxyResult>> { + Ok(response.map(IncomingOr::Right)) +} + +/// build http response with status code of 4xx and 5xx +pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult>> { + let res = Response::builder() + .status(status_code) + .body(IncomingOr::Right(BoxBody::new(empty()))) + .unwrap(); + Ok(res) +} + +/// helper function to build a empty body +fn empty() -> BoxBody { + Empty::::new().map_err(|never| match never {}).boxed() +} + +/// helper function to build a full body +pub(crate) fn full(body: Bytes) -> BoxBody { + Full::new(body).map_err(|never| match never {}).boxed() +} diff --git a/rpxy-lib/src/hyper_executor.rs b/rpxy-lib/src/hyper_ext/executor.rs similarity index 100% rename from rpxy-lib/src/hyper_executor.rs rename to 
rpxy-lib/src/hyper_ext/executor.rs diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs new file mode 100644 index 0000000..19511a1 --- /dev/null +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -0,0 +1,13 @@ +mod body_incoming_like; +mod body_type; +mod executor; +mod watch; + +pub(crate) mod rt { + pub(crate) use super::executor::LocalExecutor; +} +pub(crate) mod body { + pub(crate) use super::body_incoming_like::IncomingLike; + pub(crate) use super::body_type::{BoxBody, IncomingOr}; +} +pub(crate) use body_type::{full, passthrough_response, synthetic_error_response, synthetic_response}; diff --git a/rpxy-lib/src/hyper_ext/watch.rs b/rpxy-lib/src/hyper_ext/watch.rs new file mode 100644 index 0000000..d5e1c7e --- /dev/null +++ b/rpxy-lib/src/hyper_ext/watch.rs @@ -0,0 +1,67 @@ +//! An SPSC broadcast channel. +//! +//! - The value can only be a `usize`. +//! - The consumer is only notified if the value is different. +//! - The value `0` is reserved for closed. +// from https://github.com/hyperium/hyper/blob/master/src/common/watch.rs + +use futures_util::task::AtomicWaker; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::task; + +type Value = usize; + +pub(super) const CLOSED: usize = 0; + +pub(super) fn channel(initial: Value) -> (Sender, Receiver) { + debug_assert!(initial != CLOSED, "watch::channel initial state of 0 is reserved"); + + let shared = Arc::new(Shared { + value: AtomicUsize::new(initial), + waker: AtomicWaker::new(), + }); + + (Sender { shared: shared.clone() }, Receiver { shared }) +} + +pub(super) struct Sender { + shared: Arc, +} + +pub(super) struct Receiver { + shared: Arc, +} + +struct Shared { + value: AtomicUsize, + waker: AtomicWaker, +} + +impl Sender { + pub(super) fn send(&mut self, value: Value) { + if self.shared.value.swap(value, Ordering::SeqCst) != value { + self.shared.waker.wake(); + } + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.send(CLOSED); + } +} + +impl Receiver { + 
pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value { + self.shared.waker.register(cx.waker()); + self.shared.value.load(Ordering::SeqCst) + } + + #[allow(dead_code)] + pub(crate) fn peek(&self) -> Value { + self.shared.value.load(Ordering::Relaxed) + } +} diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index 1f5fa37..706f7b2 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -4,7 +4,7 @@ mod count; mod crypto; mod error; mod globals; -mod hyper_executor; +mod hyper_ext; mod log; mod name_exp; mod proxy; diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index 5b1ad61..e4ac6f7 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -6,7 +6,7 @@ mod proxy_quic_quinn; mod proxy_quic_s2n; mod socket; -use crate::{globals::Globals, hyper_executor::LocalExecutor}; +use crate::{globals::Globals, hyper_ext::rt::LocalExecutor}; use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; use std::sync::Arc; diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 056cd4b..6ca6528 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -1,6 +1,14 @@ use super::proxy_main::Proxy; -use crate::{error::*, log::*, name_exp::ServerName}; -use bytes::Bytes; +use crate::{ + error::*, + hyper_ext::{ + body::{IncomingLike, IncomingOr}, + full, synthetic_response, + }, + log::*, + name_exp::ServerName, +}; +use bytes::{Buf, Bytes}; use http::{Request, Response}; use http_body_util::BodyExt; use std::{net::SocketAddr, time::Duration}; @@ -11,7 +19,6 @@ use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestSt #[cfg(feature = "http3-s2n")] use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; -// use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; // use futures::Stream; // use hyper_util::client::legacy::connect::Connect; @@ -111,48 +118,41 @@ impl Proxy 
{ // split stream and async body handling let (mut send_stream, mut recv_stream) = stream.split(); - // let max_body_size = self.globals.proxy_config.h3_request_max_body_size; - // // let max = body_stream.size_hint().upper().unwrap_or(u64::MAX); - // // if max > max_body_size as u64 { - // // return Err(HttpError::TooLargeRequestBody); - // // } + // generate streamed body with trailers using channel + let (body_sender, req_body) = IncomingLike::channel(); - // let new_req = Request::from_parts(req_parts, body_stream); + // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1 + // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn. + let max_body_size = self.globals.proxy_config.h3_request_max_body_size; + self.globals.runtime_handle.spawn(async move { + let mut sender = body_sender; + let mut size = 0usize; + while let Some(mut body) = recv_stream.recv_data().await? { + debug!("HTTP/3 incoming request body: remaining {}", body.remaining()); + size += body.remaining(); + if size > max_body_size { + error!( + "Exceeds max request body size for HTTP/3: received {}, maximum_allowd {}", + size, max_body_size + ); + return Err(RpxyError::H3TooLargeBody); + } + // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes + sender.send_data(body.copy_to_bytes(body.remaining())).await?; + } - // // generate streamed body with trailers using channel - // let (body_sender, req_body) = Incoming::channel(); + // trailers: use inner for work around. 
(directly get trailer) + let trailers = recv_stream.as_mut().recv_trailers().await?; + if trailers.is_some() { + debug!("HTTP/3 incoming request trailers"); + sender.send_trailers(trailers.unwrap()).await?; + } + Ok(()) as RpxyResult<()> + }); - // // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1 - // // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn. - // let max_body_size = self.globals.proxy_config.h3_request_max_body_size; - // self.globals.runtime_handle.spawn(async move { - // // let mut sender = body_sender; - // let mut size = 0usize; - // while let Some(mut body) = recv_stream.recv_data().await? { - // debug!("HTTP/3 incoming request body: remaining {}", body.remaining()); - // size += body.remaining(); - // if size > max_body_size { - // error!( - // "Exceeds max request body size for HTTP/3: received {}, maximum_allowd {}", - // size, max_body_size - // ); - // return Err(RpxyError::Proxy("Exceeds max request body size for HTTP/3".to_string())); - // } - // // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes - // // sender.send_data(body.copy_to_bytes(body.remaining())).await?; - // } + let mut new_req: Request> = Request::from_parts(req_parts, IncomingOr::Right(req_body)); - // // trailers: use inner for work around. 
(directly get trailer) - // let trailers = recv_stream.as_mut().recv_trailers().await?; - // if trailers.is_some() { - // debug!("HTTP/3 incoming request trailers"); - // // sender.send_trailers(trailers.unwrap()).await?; - // } - // Ok(()) - // }); - - // let new_req: Request = Request::from_parts(req_parts, req_body); - // let res = self + // let res = selfw // .msg_handler // .clone() // .handle_request( @@ -165,8 +165,9 @@ impl Proxy { // .await?; // TODO: TODO: TODO: remove later - let body = full(hyper::body::Bytes::from("hello h3 echo")); - let res = Response::builder().body(body).unwrap(); + let body = full(Bytes::from("hello h3 echo")); + // here response is IncomingOr from message handler + let res = synthetic_response(Response::builder().body(body).unwrap())?; ///////////////// let (new_res_parts, new_body) = res.into_parts(); @@ -193,13 +194,3 @@ impl Proxy { Ok(send_stream.finish().await?) } } - -////////////// -/// TODO: remove later -/// helper function to build a full body -use http_body_util::Full; -pub(crate) type BoxBody = http_body_util::combinators::BoxBody; -pub fn full(body: hyper::body::Bytes) -> BoxBody { - Full::new(body).map_err(|never| match never {}).boxed() -} -////////////// diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index 0d6eb83..cc6636d 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -4,7 +4,12 @@ use crate::{ crypto::{ServerCrypto, SniServerCryptoMap}, error::*, globals::Globals, - hyper_executor::LocalExecutor, + hyper_ext::{ + body::{BoxBody, IncomingOr}, + full, + rt::LocalExecutor, + synthetic_response, + }, log::*, name_exp::ServerName, }; @@ -22,14 +27,14 @@ use tokio::time::timeout; /// Wrapper function to handle request for HTTP/1.1 and HTTP/2 /// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler async fn serve_request( - req: Request, + mut req: Request, // handler: Arc>, // handler: Arc>, client_addr: SocketAddr, 
listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, -) -> RpxyResult> { +) -> RpxyResult>> { // match handler // .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) // .await? @@ -37,19 +42,14 @@ async fn serve_request( // Ok(res) => passthrough_response(res), // Err(e) => synthetic_error_response(StatusCode::from(e)), // } + + ////////////// + // TODO: remove later let body = full(hyper::body::Bytes::from("hello")); let res = Response::builder().body(body).unwrap(); - Ok(res) + synthetic_response(res) + ////////////// } -////////////// -/// TODO: remove later -/// helper function to build a full body -use http_body_util::{BodyExt, Full}; -pub(crate) type BoxBody = http_body_util::combinators::BoxBody; -pub fn full(body: hyper::body::Bytes) -> BoxBody { - Full::new(body).map_err(|never| match never {}).boxed() -} -////////////// #[derive(Clone)] /// Proxy main object responsible to serve requests received from clients at the given socket address. 
From b8cec687b23dc65e5d0bbf0896b2f6dc2049c548 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 25 Nov 2023 00:37:54 +0900 Subject: [PATCH 09/50] wip: add stub for message handler --- rpxy-lib/Cargo.toml | 6 +++--- rpxy-lib/src/lib.rs | 1 + rpxy-lib/src/message_handler/mod.rs | 0 3 files changed, 4 insertions(+), 3 deletions(-) create mode 100644 rpxy-lib/src/message_handler/mod.rs diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 3469162..583ac59 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -84,7 +84,7 @@ socket2 = { version = "0.5.5", features = ["all"], optional = true } # # cache # http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } -# lru = { version = "0.12.0", optional = true } +# lru = { version = "0.12.1", optional = true } # cookie handling for sticky cookie chrono = { version = "0.4.31", default-features = false, features = [ @@ -96,5 +96,5 @@ base64 = { version = "0.21.5", optional = true } sha2 = { version = "0.10.8", default-features = false, optional = true } -# [dev-dependencies] -# # http and tls +[dev-dependencies] +# http and tls diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index 706f7b2..4e63cbe 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -6,6 +6,7 @@ mod error; mod globals; mod hyper_ext; mod log; +mod message_handler; mod name_exp; mod proxy; diff --git a/rpxy-lib/src/message_handler/mod.rs b/rpxy-lib/src/message_handler/mod.rs new file mode 100644 index 0000000..e69de29 From e8d67bfc41d66d80f789a2d789cb99fbbff6b85c Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 25 Nov 2023 02:33:50 +0900 Subject: [PATCH 10/50] wip: add tests for incoming-like body --- rpxy-lib/Cargo.toml | 2 +- rpxy-lib/src/error.rs | 2 + rpxy-lib/src/hyper_ext/body_incoming_like.rs | 187 ++++++++++++++++++- 3 files changed, 187 insertions(+), 4 deletions(-) diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 583ac59..51631af 100644 --- 
a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -97,4 +97,4 @@ sha2 = { version = "0.10.8", default-features = false, optional = true } [dev-dependencies] -# http and tls +tokio-test = "0.4.3" diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 37f35e0..49a1d1c 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -27,6 +27,8 @@ pub enum RpxyError { HyperBodyManipulationError(String), #[error("New closed in incoming-like")] HyperIncomingLikeNewClosed, + #[error("New body write aborted")] + HyperNewBodyWriteAborted, // http/3 errors #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] diff --git a/rpxy-lib/src/hyper_ext/body_incoming_like.rs b/rpxy-lib/src/hyper_ext/body_incoming_like.rs index 2fced25..9307b7f 100644 --- a/rpxy-lib/src/hyper_ext/body_incoming_like.rs +++ b/rpxy-lib/src/hyper_ext/body_incoming_like.rs @@ -11,10 +11,11 @@ use std::{ //////////////////////////////////////////////////////////// /// Incoming like body to handle incoming request body +/// ported from https://github.com/hyperium/hyper/blob/master/src/body/incoming.rs pub struct IncomingLike { content_length: DecodedLength, want_tx: watch::Sender, - data_rx: mpsc::Receiver>, + data_rx: mpsc::Receiver>, trailers_rx: oneshot::Receiver, } @@ -27,9 +28,10 @@ macro_rules! 
ready { }; } -type BodySender = mpsc::Sender>; +type BodySender = mpsc::Sender>; type TrailersSender = oneshot::Sender; +const MAX_LEN: u64 = std::u64::MAX - 2; #[derive(Clone, Copy, PartialEq, Eq)] pub(crate) struct DecodedLength(u64); impl DecodedLength { @@ -37,6 +39,12 @@ impl DecodedLength { pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); pub(crate) const ZERO: DecodedLength = DecodedLength(0); + #[allow(dead_code)] + pub(crate) fn new(len: u64) -> Self { + debug_assert!(len <= MAX_LEN); + DecodedLength(len) + } + pub(crate) fn sub_if(&mut self, amt: u64) { match *self { DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), @@ -100,7 +108,7 @@ impl IncomingLike { impl Body for IncomingLike { type Data = Bytes; - type Error = hyper::Error; + type Error = RpxyError; fn poll_frame( mut self: Pin<&mut Self>, @@ -186,4 +194,177 @@ impl Sender { }; tx.send(trailers).map_err(|_| RpxyError::HyperIncomingLikeNewClosed) } + + /// Try to send data on this channel. + /// + /// # Errors + /// + /// Returns `Err(Bytes)` if the channel could not (currently) accept + /// another `Bytes`. + /// + /// # Note + /// + /// This is mostly useful for when trying to send from some other thread + /// that doesn't have an async context. If in an async context, prefer + /// `send_data()` instead. 
+ #[allow(unused)] + pub(crate) fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> { + self + .data_tx + .try_send(Ok(chunk)) + .map_err(|err| err.into_inner().expect("just sent Ok")) + } + + #[allow(unused)] + pub(crate) fn abort(mut self) { + self.send_error(RpxyError::HyperNewBodyWriteAborted); + } + + pub(crate) fn send_error(&mut self, err: RpxyError) { + let _ = self + .data_tx + // clone so the send works even if buffer is full + .clone() + .try_send(Err(err)); + } +} + +#[cfg(test)] +mod tests { + use std::mem; + use std::task::Poll; + + use super::{Body, DecodedLength, IncomingLike, Sender, SizeHint}; + use crate::error::RpxyError; + use http_body_util::BodyExt; + + #[test] + fn test_size_of() { + // These are mostly to help catch *accidentally* increasing + // the size by too much. + + let body_size = mem::size_of::(); + let body_expected_size = mem::size_of::() * 5; + assert!( + body_size <= body_expected_size, + "Body size = {} <= {}", + body_size, + body_expected_size, + ); + + //assert_eq!(body_size, mem::size_of::>(), "Option"); + + assert_eq!(mem::size_of::(), mem::size_of::() * 5, "Sender"); + + assert_eq!( + mem::size_of::(), + mem::size_of::>(), + "Option" + ); + } + #[test] + fn size_hint() { + fn eq(body: IncomingLike, b: SizeHint, note: &str) { + let a = body.size_hint(); + assert_eq!(a.lower(), b.lower(), "lower for {:?}", note); + assert_eq!(a.upper(), b.upper(), "upper for {:?}", note); + } + + eq(IncomingLike::channel().1, SizeHint::new(), "channel"); + + eq( + IncomingLike::new_channel(DecodedLength::new(4), /*wanter =*/ false).1, + SizeHint::with_exact(4), + "channel with length", + ); + } + + #[tokio::test] + async fn channel_abort() { + let (tx, mut rx) = IncomingLike::channel(); + + tx.abort(); + + match rx.frame().await.unwrap() { + Err(RpxyError::HyperNewBodyWriteAborted) => true, + unexpected => panic!("unexpected: {:?}", unexpected), + }; + } + + #[tokio::test] + async fn channel_abort_when_buffer_is_full() { + let 
(mut tx, mut rx) = IncomingLike::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + // buffer is full, but can still send abort + tx.abort(); + + let chunk1 = rx.frame().await.expect("item 1").expect("chunk 1").into_data().unwrap(); + assert_eq!(chunk1, "chunk 1"); + + match rx.frame().await.unwrap() { + Err(RpxyError::HyperNewBodyWriteAborted) => true, + unexpected => panic!("unexpected: {:?}", unexpected), + }; + } + + #[test] + fn channel_buffers_one() { + let (mut tx, _rx) = IncomingLike::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + + // buffer is now full + let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2"); + assert_eq!(chunk2, "chunk 2"); + } + + #[tokio::test] + async fn channel_empty() { + let (_, mut rx) = IncomingLike::channel(); + + assert!(rx.frame().await.is_none()); + } + + #[test] + fn channel_ready() { + let (mut tx, _rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ false); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!(tx_ready.poll().is_ready(), "tx is ready immediately"); + } + + #[test] + fn channel_wanter() { + let (mut tx, mut rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + let mut rx_data = tokio_test::task::spawn(rx.frame()); + + assert!(tx_ready.poll().is_pending(), "tx isn't ready before rx has been polled"); + + assert!(rx_data.poll().is_pending(), "poll rx.data"); + assert!(tx_ready.is_woken(), "rx poll wakes tx"); + + assert!(tx_ready.poll().is_ready(), "tx is ready after rx has been polled"); + } + + #[test] + + fn channel_notices_closure() { + let (mut tx, rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!(tx_ready.poll().is_pending(), "tx isn't ready before rx has been polled"); + + drop(rx); + assert!(tx_ready.is_woken(), "dropping rx wakes 
tx"); + + match tx_ready.poll() { + Poll::Ready(Err(RpxyError::HyperIncomingLikeNewClosed)) => (), + unexpected => panic!("tx poll ready unexpected: {:?}", unexpected), + } + } } From a9ce26ae7657f431e039a0864bacdb7a82c205a2 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Mon, 27 Nov 2023 15:39:19 +0900 Subject: [PATCH 11/50] wip: implementing message handler --- rpxy-lib/src/error.rs | 4 + rpxy-lib/src/hyper_ext/body_type.rs | 26 +--- rpxy-lib/src/hyper_ext/mod.rs | 3 +- rpxy-lib/src/lib.rs | 26 ++-- .../src/message_handle/canonical_address.rs | 61 ++++++++ rpxy-lib/src/message_handle/handler.rs | 131 ++++++++++++++++++ rpxy-lib/src/message_handle/http_log.rs | 99 +++++++++++++ rpxy-lib/src/message_handle/http_result.rs | 36 +++++ rpxy-lib/src/message_handle/mod.rs | 8 ++ .../src/message_handle/synthetic_response.rs | 57 ++++++++ rpxy-lib/src/message_handle/utils_request.rs | 43 ++++++ rpxy-lib/src/message_handler/mod.rs | 0 rpxy-lib/src/proxy/proxy_h3.rs | 47 +++---- rpxy-lib/src/proxy/proxy_main.rs | 56 ++++---- rpxy-lib/src/proxy/proxy_quic_quinn.rs | 15 +- rpxy-lib/src/proxy/proxy_quic_s2n.rs | 6 +- 16 files changed, 520 insertions(+), 98 deletions(-) create mode 100644 rpxy-lib/src/message_handle/canonical_address.rs create mode 100644 rpxy-lib/src/message_handle/handler.rs create mode 100644 rpxy-lib/src/message_handle/http_log.rs create mode 100644 rpxy-lib/src/message_handle/http_result.rs create mode 100644 rpxy-lib/src/message_handle/mod.rs create mode 100644 rpxy-lib/src/message_handle/synthetic_response.rs create mode 100644 rpxy-lib/src/message_handle/utils_request.rs delete mode 100644 rpxy-lib/src/message_handler/mod.rs diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 49a1d1c..a05a612 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -69,6 +69,10 @@ pub enum RpxyError { #[error("Failed to build backend app: {0}")] FailedToBuildBackendApp(#[from] crate::backend::BackendAppBuilderError), + // Handler errors + 
#[error("Failed to build message handler: {0}")] + FailedToBuildMessageHandler(#[from] crate::message_handle::HttpMessageHandlerBuilderError), + // Upstream connection setting errors #[error("Unsupported upstream option")] UnsupportedUpstreamOption, diff --git a/rpxy-lib/src/hyper_ext/body_type.rs b/rpxy-lib/src/hyper_ext/body_type.rs index ba1bdc2..516e569 100644 --- a/rpxy-lib/src/hyper_ext/body_type.rs +++ b/rpxy-lib/src/hyper_ext/body_type.rs @@ -1,5 +1,3 @@ -use crate::error::*; -use http::{Response, StatusCode}; use http_body_util::{combinators, BodyExt, Either, Empty, Full}; use hyper::body::{Bytes, Incoming}; @@ -8,30 +6,8 @@ pub(crate) type BoxBody = combinators::BoxBody; /// Type for either passthrough body or given body type, specifically synthetic boxed body pub(crate) type IncomingOr = Either; -/// helper function to build http response with passthrough body -pub(crate) fn passthrough_response(response: Response) -> RpxyResult>> -where - B: hyper::body::Body, -{ - Ok(response.map(IncomingOr::Left)) -} - -/// helper function to build http response with synthetic body -pub(crate) fn synthetic_response(response: Response) -> RpxyResult>> { - Ok(response.map(IncomingOr::Right)) -} - -/// build http response with status code of 4xx and 5xx -pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult>> { - let res = Response::builder() - .status(status_code) - .body(IncomingOr::Right(BoxBody::new(empty()))) - .unwrap(); - Ok(res) -} - /// helper function to build a empty body -fn empty() -> BoxBody { +pub(crate) fn empty() -> BoxBody { Empty::::new().map_err(|never| match never {}).boxed() } diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs index 19511a1..a39aef9 100644 --- a/rpxy-lib/src/hyper_ext/mod.rs +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -8,6 +8,5 @@ pub(crate) mod rt { } pub(crate) mod body { pub(crate) use super::body_incoming_like::IncomingLike; - pub(crate) use super::body_type::{BoxBody, IncomingOr}; + 
pub(crate) use super::body_type::{empty, full, BoxBody, IncomingOr}; } -pub(crate) use body_type::{full, passthrough_response, synthetic_error_response, synthetic_response}; diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index 4e63cbe..c45327a 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -6,11 +6,14 @@ mod error; mod globals; mod hyper_ext; mod log; -mod message_handler; +mod message_handle; mod name_exp; mod proxy; -use crate::{crypto::build_cert_reloader, error::*, globals::Globals, log::*, proxy::Proxy}; +use crate::{ + crypto::build_cert_reloader, error::*, globals::Globals, log::*, message_handle::HttpMessageHandlerBuilder, + proxy::Proxy, +}; use futures::future::select_all; use std::sync::Arc; @@ -86,16 +89,15 @@ where cert_reloader_rx: cert_reloader_rx.clone(), }); - // TODO: 2. build message handler with Arc-ed http_client and backends, and make it contained in Arc as well - // // build message handler including a request forwarder - // let msg_handler = Arc::new( - // HttpMessageHandlerBuilder::default() - // // .forwarder(Arc::new(Forwarder::new(&globals).await)) - // .globals(globals.clone()) - // .build()?, - // ); + // 4. build message handler containing Arc-ed http_client and backends, and make it contained in Arc as well + let message_handler = Arc::new( + HttpMessageHandlerBuilder::default() + .globals(globals.clone()) + .app_manager(app_manager.clone()) + .build()?, + ); - // TODO: 3. spawn each proxy for a given socket with copied Arc-ed message_handler. + // 5. spawn each proxy for a given socket with copied Arc-ed message_handler. 
// build hyper connection builder shared with proxy instances let connection_builder = proxy::connection_builder(&globals); @@ -111,7 +113,7 @@ where listening_on, tls_enabled, connection_builder: connection_builder.clone(), - // TODO: message_handler + message_handler: message_handler.clone(), }; globals.runtime_handle.spawn(async move { proxy.start().await }) }); diff --git a/rpxy-lib/src/message_handle/canonical_address.rs b/rpxy-lib/src/message_handle/canonical_address.rs new file mode 100644 index 0000000..32dad78 --- /dev/null +++ b/rpxy-lib/src/message_handle/canonical_address.rs @@ -0,0 +1,61 @@ +use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + +/// Trait to convert an IP address to its canonical form +pub trait ToCanonical { + fn to_canonical(&self) -> Self; +} + +impl ToCanonical for SocketAddr { + fn to_canonical(&self) -> Self { + match self { + SocketAddr::V4(_) => *self, + SocketAddr::V6(v6) => match v6.ip().to_ipv4() { + Some(mapped) => { + if mapped == Ipv4Addr::new(0, 0, 0, 1) { + *self + } else { + SocketAddr::new(IpAddr::V4(mapped), self.port()) + } + } + None => *self, + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::net::Ipv6Addr; + #[test] + fn ipv4_loopback_to_canonical() { + let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); + assert_eq!(socket.to_canonical(), socket); + } + #[test] + fn ipv6_loopback_to_canonical() { + let socket = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080); + assert_eq!(socket.to_canonical(), socket); + } + #[test] + fn ipv4_to_canonical() { + let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080); + assert_eq!(socket.to_canonical(), socket); + } + #[test] + fn ipv6_to_canonical() { + let socket = SocketAddr::new( + IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0xdead, 0xbeef)), + 8080, + ); + assert_eq!(socket.to_canonical(), socket); + } + #[test] + fn ipv4_mapped_to_ipv6_to_canonical() { + let socket = 
SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc00a, 0x2ff)), 8080); + assert_eq!( + socket.to_canonical(), + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 10, 2, 255)), 8080) + ); + } +} diff --git a/rpxy-lib/src/message_handle/handler.rs b/rpxy-lib/src/message_handle/handler.rs new file mode 100644 index 0000000..145d8ba --- /dev/null +++ b/rpxy-lib/src/message_handle/handler.rs @@ -0,0 +1,131 @@ +use super::{ + http_log::HttpMessageLog, + http_result::{HttpError, HttpResult}, + synthetic_response::{secure_redirection_response, synthetic_error_response}, + utils_request::ParseHost, +}; +use crate::{ + backend::BackendAppManager, + crypto::CryptoSource, + error::*, + globals::Globals, + hyper_ext::body::{BoxBody, IncomingLike, IncomingOr}, + log::*, + name_exp::ServerName, +}; +use derive_builder::Builder; +use http::{Request, Response, StatusCode}; +use std::{net::SocketAddr, sync::Arc}; + +#[derive(Clone, Builder)] +/// HTTP message handler for requests from clients and responses from backend applications, +/// responsible to manipulate and forward messages to upstream backends and downstream clients. +// pub struct HttpMessageHandler +pub struct HttpMessageHandler +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone, +{ + // forwarder: Arc>, + globals: Arc, + app_manager: Arc>, +} + +impl HttpMessageHandler +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone, +{ + /// Handle incoming request message from a client. + /// Responsible to passthrough responses from backend applications or generate synthetic error responses. 
+ pub async fn handle_request( + &self, + mut req: Request>, + client_addr: SocketAddr, // For access control + listen_addr: SocketAddr, + tls_enabled: bool, + tls_server_name: Option, + ) -> RpxyResult>> { + let mut log_data = HttpMessageLog::from(&req); + + let http_result = self + .handle_request_inner( + &mut log_data, + req, + client_addr, + listen_addr, + tls_enabled, + tls_server_name, + ) + .await; + + // passthrough or synthetic response + match http_result { + Ok(v) => { + log_data.status_code(&v.status()).output(); + Ok(v) + } + Err(e) => { + debug!("{e}"); + let code = StatusCode::from(e); + log_data.status_code(&code).output(); + synthetic_error_response(code) + } + } + } + + /// Handle inner with no synthetic error response. + /// Synthetic response is generated by caller. + async fn handle_request_inner( + &self, + mut log_data: &mut HttpMessageLog, + mut req: Request>, + client_addr: SocketAddr, // For access control + listen_addr: SocketAddr, + tls_enabled: bool, + tls_server_name: Option, + ) -> HttpResult>> { + // preparing log data + let mut log_data = HttpMessageLog::from(&req); + log_data.client_addr(&client_addr); + + // Here we start to handle with server_name + let server_name = req.parse_host().map(ServerName::from)?; + + // check consistency of between TLS SNI and HOST/Request URI Line. + #[allow(clippy::collapsible_if)] + if tls_enabled && self.globals.proxy_config.sni_consistency { + if server_name != tls_server_name.unwrap_or_default() { + return Err(HttpError::SniHostInconsistency); + } + } + // Find backend application for given server_name, and drop if incoming request is invalid as request. 
+ let backend_app = match self.app_manager.apps.get(&server_name) { + Some(backend_app) => backend_app, + None => { + let Some(default_server_name) = &self.app_manager.default_server_name else { + return Err(HttpError::NoMatchingBackendApp); + }; + debug!("Serving by default app"); + self.app_manager.apps.get(default_server_name).unwrap() + } + }; + + // Redirect to https if !tls_enabled and redirect_to_https is true + if !tls_enabled && backend_app.https_redirection.unwrap_or(false) { + debug!( + "Redirect to secure connection: {}", + <&ServerName as TryInto>::try_into(&backend_app.server_name).unwrap_or_default() + ); + return secure_redirection_response(&backend_app.server_name, self.globals.proxy_config.https_port, &req); + } + + ////////////// + // // TODO: remove later + let body = crate::hyper_ext::body::full(hyper::body::Bytes::from("not yet implemented")); + let res = super::synthetic_response::synthetic_response(Response::builder().body(body).unwrap()); + Ok(res) + ////////////// + // todo!() + } +} diff --git a/rpxy-lib/src/message_handle/http_log.rs b/rpxy-lib/src/message_handle/http_log.rs new file mode 100644 index 0000000..7056c80 --- /dev/null +++ b/rpxy-lib/src/message_handle/http_log.rs @@ -0,0 +1,99 @@ +use super::canonical_address::ToCanonical; +use crate::log::*; +use http::header; +use std::net::SocketAddr; + +/// Struct to log HTTP messages +#[derive(Debug, Clone)] +pub struct HttpMessageLog { + // pub tls_server_name: String, + pub client_addr: String, + pub method: String, + pub host: String, + pub p_and_q: String, + pub version: hyper::Version, + pub uri_scheme: String, + pub uri_host: String, + pub ua: String, + pub xff: String, + pub status: String, + pub upstream: String, +} + +impl From<&hyper::Request> for HttpMessageLog { + fn from(req: &hyper::Request) -> Self { + let header_mapper = |v: header::HeaderName| { + req + .headers() + .get(v) + .map_or_else(|| "", |s| s.to_str().unwrap_or("")) + .to_string() + }; + Self { + // 
tls_server_name: "".to_string(), + client_addr: "".to_string(), + method: req.method().to_string(), + host: header_mapper(header::HOST), + p_and_q: req + .uri() + .path_and_query() + .map_or_else(|| "", |v| v.as_str()) + .to_string(), + version: req.version(), + uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(), + uri_host: req.uri().host().unwrap_or("").to_string(), + ua: header_mapper(header::USER_AGENT), + xff: header_mapper(header::HeaderName::from_static("x-forwarded-for")), + status: "".to_string(), + upstream: "".to_string(), + } + } +} + +impl HttpMessageLog { + pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self { + self.client_addr = client_addr.to_canonical().to_string(); + self + } + // pub fn tls_server_name(&mut self, tls_server_name: &str) -> &mut Self { + // self.tls_server_name = tls_server_name.to_string(); + // self + // } + pub fn status_code(&mut self, status_code: &hyper::StatusCode) -> &mut Self { + self.status = status_code.to_string(); + self + } + pub fn xff(&mut self, xff: &Option<&header::HeaderValue>) -> &mut Self { + self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); + self + } + pub fn upstream(&mut self, upstream: &hyper::Uri) -> &mut Self { + self.upstream = upstream.to_string(); + self + } + + pub fn output(&self) { + info!( + "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"", + if !self.host.is_empty() { + self.host.as_str() + } else { + self.uri_host.as_str() + }, + self.client_addr, + self.method, + self.p_and_q, + self.version, + self.status, + if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() { + format!("{}://{}", self.uri_scheme, self.uri_host) + } else { + "".to_string() + }, + self.ua, + self.xff, + self.upstream, + // self.tls_server_name + ); + } +} diff --git a/rpxy-lib/src/message_handle/http_result.rs b/rpxy-lib/src/message_handle/http_result.rs new file mode 100644 index 0000000..8e9d6b4 --- /dev/null +++ 
b/rpxy-lib/src/message_handle/http_result.rs @@ -0,0 +1,36 @@ +use http::StatusCode; +use thiserror::Error; + +/// HTTP result type, T is typically a hyper::Response +/// HttpError is used to generate a synthetic error response +pub(crate) type HttpResult = std::result::Result; + +/// Describes things that can go wrong in the forwarder +#[derive(Debug, Error)] +pub enum HttpError { + #[error("No host is given in request header")] + NoHostInRequestHeader, + #[error("Invalid host in request header")] + InvalidHostInRequestHeader, + #[error("SNI and Host header mismatch")] + SniHostInconsistency, + #[error("No matching backend app")] + NoMatchingBackendApp, + #[error("Failed to redirect: {0}")] + FailedToRedirect(String), + + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl From for StatusCode { + fn from(e: HttpError) -> StatusCode { + match e { + HttpError::NoHostInRequestHeader => StatusCode::BAD_REQUEST, + HttpError::InvalidHostInRequestHeader => StatusCode::BAD_REQUEST, + HttpError::SniHostInconsistency => StatusCode::MISDIRECTED_REQUEST, + HttpError::NoMatchingBackendApp => StatusCode::SERVICE_UNAVAILABLE, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} diff --git a/rpxy-lib/src/message_handle/mod.rs b/rpxy-lib/src/message_handle/mod.rs new file mode 100644 index 0000000..f00b417 --- /dev/null +++ b/rpxy-lib/src/message_handle/mod.rs @@ -0,0 +1,8 @@ +mod canonical_address; +mod handler; +mod http_log; +mod http_result; +mod synthetic_response; +mod utils_request; + +pub(crate) use handler::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; diff --git a/rpxy-lib/src/message_handle/synthetic_response.rs b/rpxy-lib/src/message_handle/synthetic_response.rs new file mode 100644 index 0000000..0038997 --- /dev/null +++ b/rpxy-lib/src/message_handle/synthetic_response.rs @@ -0,0 +1,57 @@ +use crate::{ + error::*, + hyper_ext::body::{empty, BoxBody, IncomingOr}, + name_exp::ServerName, +}; +use http::{Request, Response, 
StatusCode, Uri}; +use hyper::body::Incoming; + +use super::http_result::{HttpError, HttpResult}; + +/// helper function to build http response with passthrough body +pub(crate) fn passthrough_response(response: Response) -> Response> +where + B: hyper::body::Body, +{ + response.map(IncomingOr::Left) +} + +/// helper function to build http response with synthetic body +pub(crate) fn synthetic_response(response: Response) -> Response> { + response.map(IncomingOr::Right) +} + +/// build http response with status code of 4xx and 5xx +pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult>> { + let res = Response::builder() + .status(status_code) + .body(IncomingOr::Right(empty())) + .unwrap(); + Ok(res) +} + +/// Generate synthetic response message of a redirection to https host with 301 +pub(super) fn secure_redirection_response( + server_name: &ServerName, + tls_port: Option, + req: &Request, +) -> HttpResult>> { + let server_name: String = server_name.try_into().unwrap_or_default(); + let pq = match req.uri().path_and_query() { + Some(x) => x.as_str(), + _ => "", + }; + let new_uri = Uri::builder().scheme("https").path_and_query(pq); + let dest_uri = match tls_port { + Some(443) | None => new_uri.authority(server_name), + Some(p) => new_uri.authority(format!("{server_name}:{p}")), + } + .build() + .map_err(|e| HttpError::FailedToRedirect(e.to_string()))?; + let response = Response::builder() + .status(StatusCode::MOVED_PERMANENTLY) + .header("Location", dest_uri.to_string()) + .body(IncomingOr::Right(empty())) + .map_err(|e| HttpError::FailedToRedirect(e.to_string()))?; + Ok(response) +} diff --git a/rpxy-lib/src/message_handle/utils_request.rs b/rpxy-lib/src/message_handle/utils_request.rs new file mode 100644 index 0000000..a8f9bd4 --- /dev/null +++ b/rpxy-lib/src/message_handle/utils_request.rs @@ -0,0 +1,43 @@ +use super::http_result::*; +use http::{header, Request}; + +/// Trait defining parser of hostname +pub trait ParseHost { + type 
Error; + fn parse_host(&self) -> Result<&[u8], Self::Error>; +} +impl ParseHost for Request { + type Error = HttpError; + /// Extract hostname from either the request HOST header or request line + fn parse_host(&self) -> HttpResult<&[u8]> { + let headers_host = self.headers().get(header::HOST); + let uri_host = self.uri().host(); + // let uri_port = self.uri().port_u16(); + + if !(!(headers_host.is_none() && uri_host.is_none())) { + return Err(HttpError::NoHostInRequestHeader); + } + + // prioritize server_name in uri + uri_host.map_or_else( + || { + let m = headers_host.unwrap().as_bytes(); + if m.starts_with(&[b'[']) { + // v6 address with bracket case. if port is specified, always it is in this case. + let mut iter = m.split(|ptr| ptr == &b'[' || ptr == &b']'); + iter.next().ok_or(HttpError::InvalidHostInRequestHeader)?; // first item is always blank + iter.next().ok_or(HttpError::InvalidHostInRequestHeader) + } else if m.len() - m.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 { + // v6 address case, if 2 or more ':' is contained + Ok(m) + } else { + // v4 address or hostname + m.split(|colon| colon == &b':') + .next() + .ok_or(HttpError::InvalidHostInRequestHeader) + } + }, + |v| Ok(v.as_bytes()), + ) + } +} diff --git a/rpxy-lib/src/message_handler/mod.rs b/rpxy-lib/src/message_handler/mod.rs deleted file mode 100644 index e69de29..0000000 diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 6ca6528..922b857 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -1,10 +1,8 @@ use super::proxy_main::Proxy; use crate::{ + crypto::CryptoSource, error::*, - hyper_ext::{ - body::{IncomingLike, IncomingOr}, - full, synthetic_response, - }, + hyper_ext::body::{IncomingLike, IncomingOr}, log::*, name_exp::ServerName, }; @@ -22,13 +20,11 @@ use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic // use futures::Stream; // use hyper_util::client::legacy::connect::Connect; 
-// impl Proxy -// where -// // T: Connect + Clone + Sync + Send + 'static, -// U: CryptoSource + Clone + Sync + Send + 'static, -// { - -impl Proxy { +impl Proxy +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone + Sync + Send + 'static, +{ pub(super) async fn h3_serve_connection( &self, quic_connection: C, @@ -151,24 +147,17 @@ impl Proxy { }); let mut new_req: Request> = Request::from_parts(req_parts, IncomingOr::Right(req_body)); - - // let res = selfw - // .msg_handler - // .clone() - // .handle_request( - // new_req, - // client_addr, - // self.listening_on, - // self.tls_enabled, - // Some(tls_server_name), - // ) - // .await?; - - // TODO: TODO: TODO: remove later - let body = full(Bytes::from("hello h3 echo")); - // here response is IncomingOr from message handler - let res = synthetic_response(Response::builder().body(body).unwrap())?; - ///////////////// + // Response> wrapped by RpxyResult + let res = self + .message_handler + .handle_request( + new_req, + client_addr, + self.listening_on, + self.tls_enabled, + Some(tls_server_name), + ) + .await?; let (new_res_parts, new_body) = res.into_parts(); let new_res = Response::from_parts(new_res_parts, ()); diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index cc6636d..abdea64 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -1,16 +1,15 @@ use super::socket::bind_tcp_socket; use crate::{ constants::TLS_HANDSHAKE_TIMEOUT_SEC, - crypto::{ServerCrypto, SniServerCryptoMap}, + crypto::{CryptoSource, ServerCrypto, SniServerCryptoMap}, error::*, globals::Globals, hyper_ext::{ body::{BoxBody, IncomingOr}, - full, rt::LocalExecutor, - synthetic_response, }, log::*, + message_handle::HttpMessageHandler, name_exp::ServerName, }; use futures::{select, FutureExt}; @@ -26,34 +25,37 @@ use tokio::time::timeout; /// Wrapper function to handle request for HTTP/1.1 and HTTP/2 /// HTTP/3 is handled in proxy_h3.rs which 
directly calls the message handler -async fn serve_request( +async fn serve_request( mut req: Request, // handler: Arc>, - // handler: Arc>, + handler: Arc>, client_addr: SocketAddr, listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, -) -> RpxyResult>> { - // match handler - // .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) - // .await? - // { - // Ok(res) => passthrough_response(res), - // Err(e) => synthetic_error_response(StatusCode::from(e)), - // } - - ////////////// - // TODO: remove later - let body = full(hyper::body::Bytes::from("hello")); - let res = Response::builder().body(body).unwrap(); - synthetic_response(res) - ////////////// +) -> RpxyResult>> +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone, +{ + handler + .handle_request( + req.map(IncomingOr::Left), + client_addr, + listen_addr, + tls_enabled, + tls_server_name, + ) + .await } #[derive(Clone)] /// Proxy main object responsible to serve requests received from clients at the given socket address. 
-pub(crate) struct Proxy { +pub(crate) struct Proxy +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone + Sync + Send + 'static, +{ /// global context shared among async tasks pub globals: Arc, /// listen socket address @@ -62,9 +64,15 @@ pub(crate) struct Proxy { pub tls_enabled: bool, /// hyper connection builder serving http request pub connection_builder: Arc>, + /// message handler serving incoming http request + pub message_handler: Arc>, } -impl Proxy { +impl Proxy +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone + Sync + Send + 'static, +{ /// Serves requests from clients fn serve_connection(&self, stream: I, peer_addr: SocketAddr, tls_server_name: Option) where @@ -78,7 +86,7 @@ impl Proxy { debug!("Request incoming: current # {}", request_count.current()); let server_clone = self.connection_builder.clone(); - // let msg_handler_clone = self.msg_handler.clone(); + let message_handler_clone = self.message_handler.clone(); let timeout_sec = self.globals.proxy_config.proxy_timeout; let tls_enabled = self.tls_enabled; let listening_on = self.listening_on; @@ -90,7 +98,7 @@ impl Proxy { service_fn(move |req: Request| { serve_request( req, - // msg_handler_clone.clone(), + message_handler_clone.clone(), peer_addr, listening_on, tls_enabled, diff --git a/rpxy-lib/src/proxy/proxy_quic_quinn.rs b/rpxy-lib/src/proxy/proxy_quic_quinn.rs index bde3b00..8380f6e 100644 --- a/rpxy-lib/src/proxy/proxy_quic_quinn.rs +++ b/rpxy-lib/src/proxy/proxy_quic_quinn.rs @@ -1,15 +1,20 @@ use super::proxy_main::Proxy; use super::socket::bind_udp_socket; -use crate::{crypto::ServerCrypto, error::*, log::*, name_exp::ByteName}; +use crate::{ + crypto::{CryptoSource, ServerCrypto}, + error::*, + log::*, + name_exp::ByteName, +}; // use hyper_util::client::legacy::connect::Connect; use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; use rustls::ServerConfig; use 
std::sync::Arc; -impl Proxy -// where -// // T: Connect + Clone + Sync + Send + 'static, -// U: CryptoSource + Clone + Sync + Send + 'static, +impl Proxy +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone + Sync + Send + 'static, { pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { diff --git a/rpxy-lib/src/proxy/proxy_quic_s2n.rs b/rpxy-lib/src/proxy/proxy_quic_s2n.rs index 32be619..3ab41d0 100644 --- a/rpxy-lib/src/proxy/proxy_quic_s2n.rs +++ b/rpxy-lib/src/proxy/proxy_quic_s2n.rs @@ -10,7 +10,11 @@ use std::sync::Arc; // use hyper_util::client::legacy::connect::Connect; use s2n_quic::provider; -impl Proxy { +impl Proxy +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone + Sync + Send + 'static, +{ /// Start UDP proxy serving with HTTP/3 request for configured host names pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { From c4cf40be4ef47c63ddd1454a41db32b5f5666fb0 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 28 Nov 2023 00:51:05 +0900 Subject: [PATCH 12/50] wip: implementing message handler, finished to implement request manipulation --- README.md | 4 +- config-example.toml | 6 +- rpxy-lib/src/backend/load_balance/mod.rs | 2 + rpxy-lib/src/backend/mod.rs | 15 +- rpxy-lib/src/backend/upstream_opts.rs | 4 +- rpxy-lib/src/error.rs | 1 - rpxy-lib/src/message_handle/handler.rs | 131 -------- rpxy-lib/src/message_handle/handler_main.rs | 255 +++++++++++++++ .../handler_manipulate_messages.rs | 195 ++++++++++++ rpxy-lib/src/message_handle/http_result.rs | 11 + rpxy-lib/src/message_handle/mod.rs | 6 +- rpxy-lib/src/message_handle/utils_headers.rs | 292 ++++++++++++++++++ rpxy-lib/src/message_handle/utils_request.rs | 94 ++++-- 13 files changed, 834 insertions(+), 182 deletions(-) delete mode 100644 
rpxy-lib/src/message_handle/handler.rs create mode 100644 rpxy-lib/src/message_handle/handler_main.rs create mode 100644 rpxy-lib/src/message_handle/handler_manipulate_messages.rs create mode 100644 rpxy-lib/src/message_handle/utils_headers.rs diff --git a/README.md b/README.md index ef5d0fe..20d7891 100644 --- a/README.md +++ b/README.md @@ -104,11 +104,11 @@ If you want to host multiple and distinct domain names in a single IP address/po ```toml default_application = "app1" -[app.app1] +[apps.app1] server_name = "app1.example.com" #... -[app.app2] +[apps.app2] server_name = "app2.example.org" #... ``` diff --git a/config-example.toml b/config-example.toml index ec79f3d..458061c 100644 --- a/config-example.toml +++ b/config-example.toml @@ -57,8 +57,8 @@ upstream = [ ] load_balance = "round_robin" # or "random" or "sticky" (sticky session) or "none" (fix to the first one, default) upstream_options = [ - "override_host", - "force_http2_upstream", # mutually exclusive with "force_http11_upstream" + "disable_override_host", # do not overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy) + "force_http2_upstream", # mutually exclusive with "force_http11_upstream" ] # Non-default destination in "localhost" app, which is routed by "path" @@ -76,7 +76,7 @@ upstream = [ ] load_balance = "random" # or "round_robin" or "sticky" (sticky session) or "none" (fix to the first one, default) upstream_options = [ - "override_host", + "disable_override_host", "upgrade_insecure_requests", "force_http11_upstream", ] diff --git a/rpxy-lib/src/backend/load_balance/mod.rs b/rpxy-lib/src/backend/load_balance/mod.rs index d876517..38d312b 100644 --- a/rpxy-lib/src/backend/load_balance/mod.rs +++ b/rpxy-lib/src/backend/load_balance/mod.rs @@ -12,6 +12,8 @@ pub use load_balance_main::{ }; #[cfg(feature = "sticky-cookie")] pub use load_balance_sticky::LoadBalanceStickyBuilder; +#[cfg(feature = "sticky-cookie")] +pub use sticky_cookie::{StickyCookie, StickyCookieValue}; 
/// Result type for load balancing type LoadBalanceResult = std::result::Result; diff --git a/rpxy-lib/src/backend/mod.rs b/rpxy-lib/src/backend/mod.rs index 68f97a8..3d3ddbf 100644 --- a/rpxy-lib/src/backend/mod.rs +++ b/rpxy-lib/src/backend/mod.rs @@ -3,12 +3,11 @@ mod load_balance; mod upstream; mod upstream_opts; -pub use backend_main::{BackendAppBuilderError, BackendAppManager}; -pub use upstream::Upstream; // #[cfg(feature = "sticky-cookie")] -// pub use sticky_cookie::{StickyCookie, StickyCookieValue}; -// pub use self::{ -// load_balance::{LbContext, LoadBalance}, -// upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, -// upstream_opts::UpstreamOption, -// }; +// pub use self::load_balance::{StickyCookie, StickyCookieValue}; +pub(crate) use self::{ + load_balance::{LoadBalance, LoadBalanceContext, StickyCookie, StickyCookieValue}, + upstream::{PathManager, Upstream, UpstreamCandidates}, + upstream_opts::UpstreamOption, +}; +pub(crate) use backend_main::{BackendAppBuilderError, BackendAppManager}; diff --git a/rpxy-lib/src/backend/upstream_opts.rs b/rpxy-lib/src/backend/upstream_opts.rs index 3f5fbc8..f19acb4 100644 --- a/rpxy-lib/src/backend/upstream_opts.rs +++ b/rpxy-lib/src/backend/upstream_opts.rs @@ -2,7 +2,7 @@ use crate::error::*; #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub enum UpstreamOption { - OverrideHost, + DisableOverrideHost, UpgradeInsecureRequests, ForceHttp11Upstream, ForceHttp2Upstream, @@ -12,7 +12,7 @@ impl TryFrom<&str> for UpstreamOption { type Error = RpxyError; fn try_from(val: &str) -> RpxyResult { match val { - "override_host" => Ok(Self::OverrideHost), + "disable_override_host" => Ok(Self::DisableOverrideHost), + "upgrade_insecure_requests" => Ok(Self::UpgradeInsecureRequests), "force_http11_upstream" => Ok(Self::ForceHttp11Upstream), "force_http2_upstream" => Ok(Self::ForceHttp2Upstream), diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index a05a612..da65234 100644 --- a/rpxy-lib/src/error.rs 
+++ b/rpxy-lib/src/error.rs @@ -1,4 +1,3 @@ -pub use anyhow::{anyhow, bail, ensure, Context}; use thiserror::Error; pub type RpxyResult = std::result::Result; diff --git a/rpxy-lib/src/message_handle/handler.rs b/rpxy-lib/src/message_handle/handler.rs deleted file mode 100644 index 145d8ba..0000000 --- a/rpxy-lib/src/message_handle/handler.rs +++ /dev/null @@ -1,131 +0,0 @@ -use super::{ - http_log::HttpMessageLog, - http_result::{HttpError, HttpResult}, - synthetic_response::{secure_redirection_response, synthetic_error_response}, - utils_request::ParseHost, -}; -use crate::{ - backend::BackendAppManager, - crypto::CryptoSource, - error::*, - globals::Globals, - hyper_ext::body::{BoxBody, IncomingLike, IncomingOr}, - log::*, - name_exp::ServerName, -}; -use derive_builder::Builder; -use http::{Request, Response, StatusCode}; -use std::{net::SocketAddr, sync::Arc}; - -#[derive(Clone, Builder)] -/// HTTP message handler for requests from clients and responses from backend applications, -/// responsible to manipulate and forward messages to upstream backends and downstream clients. -// pub struct HttpMessageHandler -pub struct HttpMessageHandler -where - // T: Connect + Clone + Sync + Send + 'static, - U: CryptoSource + Clone, -{ - // forwarder: Arc>, - globals: Arc, - app_manager: Arc>, -} - -impl HttpMessageHandler -where - // T: Connect + Clone + Sync + Send + 'static, - U: CryptoSource + Clone, -{ - /// Handle incoming request message from a client. - /// Responsible to passthrough responses from backend applications or generate synthetic error responses. 
- pub async fn handle_request( - &self, - mut req: Request>, - client_addr: SocketAddr, // For access control - listen_addr: SocketAddr, - tls_enabled: bool, - tls_server_name: Option, - ) -> RpxyResult>> { - let mut log_data = HttpMessageLog::from(&req); - - let http_result = self - .handle_request_inner( - &mut log_data, - req, - client_addr, - listen_addr, - tls_enabled, - tls_server_name, - ) - .await; - - // passthrough or synthetic response - match http_result { - Ok(v) => { - log_data.status_code(&v.status()).output(); - Ok(v) - } - Err(e) => { - debug!("{e}"); - let code = StatusCode::from(e); - log_data.status_code(&code).output(); - synthetic_error_response(code) - } - } - } - - /// Handle inner with no synthetic error response. - /// Synthetic response is generated by caller. - async fn handle_request_inner( - &self, - mut log_data: &mut HttpMessageLog, - mut req: Request>, - client_addr: SocketAddr, // For access control - listen_addr: SocketAddr, - tls_enabled: bool, - tls_server_name: Option, - ) -> HttpResult>> { - // preparing log data - let mut log_data = HttpMessageLog::from(&req); - log_data.client_addr(&client_addr); - - // Here we start to handle with server_name - let server_name = req.parse_host().map(ServerName::from)?; - - // check consistency of between TLS SNI and HOST/Request URI Line. - #[allow(clippy::collapsible_if)] - if tls_enabled && self.globals.proxy_config.sni_consistency { - if server_name != tls_server_name.unwrap_or_default() { - return Err(HttpError::SniHostInconsistency); - } - } - // Find backend application for given server_name, and drop if incoming request is invalid as request. 
- let backend_app = match self.app_manager.apps.get(&server_name) { - Some(backend_app) => backend_app, - None => { - let Some(default_server_name) = &self.app_manager.default_server_name else { - return Err(HttpError::NoMatchingBackendApp); - }; - debug!("Serving by default app"); - self.app_manager.apps.get(default_server_name).unwrap() - } - }; - - // Redirect to https if !tls_enabled and redirect_to_https is true - if !tls_enabled && backend_app.https_redirection.unwrap_or(false) { - debug!( - "Redirect to secure connection: {}", - <&ServerName as TryInto>::try_into(&backend_app.server_name).unwrap_or_default() - ); - return secure_redirection_response(&backend_app.server_name, self.globals.proxy_config.https_port, &req); - } - - ////////////// - // // TODO: remove later - let body = crate::hyper_ext::body::full(hyper::body::Bytes::from("not yet implemented")); - let res = super::synthetic_response::synthetic_response(Response::builder().body(body).unwrap()); - Ok(res) - ////////////// - // todo!() - } -} diff --git a/rpxy-lib/src/message_handle/handler_main.rs b/rpxy-lib/src/message_handle/handler_main.rs new file mode 100644 index 0000000..5be08f1 --- /dev/null +++ b/rpxy-lib/src/message_handle/handler_main.rs @@ -0,0 +1,255 @@ +use super::{ + http_log::HttpMessageLog, + http_result::{HttpError, HttpResult}, + synthetic_response::{secure_redirection_response, synthetic_error_response}, + utils_headers::*, + utils_request::InspectParseHost, +}; +use crate::{ + backend::{BackendAppManager, LoadBalanceContext}, + crypto::CryptoSource, + error::*, + globals::Globals, + hyper_ext::body::{BoxBody, IncomingLike, IncomingOr}, + log::*, + name_exp::ServerName, +}; +use derive_builder::Builder; +use http::{Request, Response, StatusCode}; +use std::{net::SocketAddr, sync::Arc}; + +#[allow(dead_code)] +#[derive(Debug)] +/// Context object to handle sticky cookies at HTTP message handler +pub(super) struct HandlerContext { + #[cfg(feature = "sticky-cookie")] + pub(super) 
context_lb: Option, + #[cfg(not(feature = "sticky-cookie"))] + pub(super) context_lb: Option<()>, +} + +#[derive(Clone, Builder)] +/// HTTP message handler for requests from clients and responses from backend applications, +/// responsible to manipulate and forward messages to upstream backends and downstream clients. +// pub struct HttpMessageHandler +pub struct HttpMessageHandler +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone, +{ + // forwarder: Arc>, + globals: Arc, + app_manager: Arc>, +} + +impl HttpMessageHandler +where + // T: Connect + Clone + Sync + Send + 'static, + U: CryptoSource + Clone, +{ + /// Handle incoming request message from a client. + /// Responsible to passthrough responses from backend applications or generate synthetic error responses. + pub async fn handle_request( + &self, + mut req: Request>, + client_addr: SocketAddr, // For access control + listen_addr: SocketAddr, + tls_enabled: bool, + tls_server_name: Option, + ) -> RpxyResult>> { + let mut log_data = HttpMessageLog::from(&req); + + let http_result = self + .handle_request_inner( + &mut log_data, + req, + client_addr, + listen_addr, + tls_enabled, + tls_server_name, + ) + .await; + + // passthrough or synthetic response + match http_result { + Ok(v) => { + log_data.status_code(&v.status()).output(); + Ok(v) + } + Err(e) => { + debug!("{e}"); + let code = StatusCode::from(e); + log_data.status_code(&code).output(); + synthetic_error_response(code) + } + } + } + + /// Handle inner with no synthetic error response. + /// Synthetic response is generated by caller. 
+ async fn handle_request_inner( + &self, + log_data: &mut HttpMessageLog, + mut req: Request>, + client_addr: SocketAddr, // For access control + listen_addr: SocketAddr, + tls_enabled: bool, + tls_server_name: Option, + ) -> HttpResult>> { + // preparing log data + let mut log_data = HttpMessageLog::from(&req); + log_data.client_addr(&client_addr); + + // Here we start to inspect and parse with server_name + let server_name = req + .inspect_parse_host() + .map(|v| ServerName::from(v.as_slice())) + .map_err(|_e| HttpError::InvalidHostInRequestHeader)?; + + // check consistency of between TLS SNI and HOST/Request URI Line. + #[allow(clippy::collapsible_if)] + if tls_enabled && self.globals.proxy_config.sni_consistency { + if server_name != tls_server_name.unwrap_or_default() { + return Err(HttpError::SniHostInconsistency); + } + } + // Find backend application for given server_name, and drop if incoming request is invalid as request. + let backend_app = match self.app_manager.apps.get(&server_name) { + Some(backend_app) => backend_app, + None => { + let Some(default_server_name) = &self.app_manager.default_server_name else { + return Err(HttpError::NoMatchingBackendApp); + }; + debug!("Serving by default app"); + self.app_manager.apps.get(default_server_name).unwrap() + } + }; + + // Redirect to https if !tls_enabled and redirect_to_https is true + if !tls_enabled && backend_app.https_redirection.unwrap_or(false) { + debug!( + "Redirect to secure connection: {}", + <&ServerName as TryInto>::try_into(&backend_app.server_name).unwrap_or_default() + ); + return secure_redirection_response(&backend_app.server_name, self.globals.proxy_config.https_port, &req); + } + + // Find reverse proxy for given path and choose one of upstream host + // Longest prefix match + let path = req.uri().path(); + let Some(upstream_candidates) = backend_app.path_manager.get(path) else { + return Err(HttpError::NoUpstreamCandidates); + }; + + // Upgrade in request header + let 
upgrade_in_request = extract_upgrade(req.headers()); + let request_upgraded = req.extensions_mut().remove::(); + + // Build request from destination information + let _context = match self.generate_request_forwarded( + &client_addr, + &listen_addr, + &mut req, + &upgrade_in_request, + upstream_candidates, + tls_enabled, + ) { + Err(e) => { + error!("Failed to generate destination uri for backend application: {}", e); + return Err(HttpError::FailedToGenerateUpstreamRequest(e.to_string())); + } + Ok(v) => v, + }; + debug!( + "Request to be forwarded: uri {}, version {:?}, headers {:?}", + req.uri(), + req.version(), + req.headers() + ); + log_data.xff(&req.headers().get("x-forwarded-for")); + log_data.upstream(req.uri()); + ////// + + ////////////// + // // TODO: remove later + let body = crate::hyper_ext::body::full(hyper::body::Bytes::from("not yet implemented")); + let mut res_backend = super::synthetic_response::synthetic_response(Response::builder().body(body).unwrap()); + // // Forward request to a chosen backend + // let mut res_backend = { + // let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else { + // return self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data); + // }; + // match result { + // Ok(res) => res, + // Err(e) => { + // error!("Failed to get response from backend: {}", e); + // return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + // } + // } + // }; + ////////////// + // Process reverse proxy context generated during the forwarding request generation. 
+ #[cfg(feature = "sticky-cookie")] + if let Some(context_from_lb) = _context.context_lb { + let res_headers = res_backend.headers_mut(); + if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { + error!("Failed to append context to the response given from backend: {}", e); + return Err(HttpError::FailedToAddSetCookeInResponse); + } + } + + if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { + // // Generate response to client + // if self.generate_response_forwarded(&mut res_backend, backend).is_err() { + // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); + // } + // log_data.status_code(&res_backend.status()).output(); + // return Ok(res_backend); + } + + // // Handle StatusCode::SWITCHING_PROTOCOLS in response + // let upgrade_in_response = extract_upgrade(res_backend.headers()); + // let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) + // { + // u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() + // } else { + // false + // }; + // if !should_upgrade { + // error!( + // "Backend tried to switch to protocol {:?} when {:?} was requested", + // upgrade_in_response, upgrade_in_request + // ); + // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); + // } + // let Some(request_upgraded) = request_upgraded else { + // error!("Request does not have an upgrade extension"); + // return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); + // }; + // let Some(onupgrade) = res_backend.extensions_mut().remove::() else { + // error!("Response does not have an upgrade extension"); + // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); + // }; + + // self.globals.runtime_handle.spawn(async move { + // let mut response_upgraded = onupgrade.await.map_err(|e| { + // error!("Failed to upgrade response: {}", e); + // RpxyError::Hyper(e) + // })?; + // let 
 mut request_upgraded = request_upgraded.await.map_err(|e| { + // error!("Failed to upgrade request: {}", e); + // RpxyError::Hyper(e) + // })?; + // copy_bidirectional(&mut response_upgraded, &mut request_upgraded) + // .await + // .map_err(|e| { + // error!("Copying between upgraded connections failed: {}", e); + // RpxyError::Io(e) + // })?; + // Ok(()) as Result<()> + // }); + // log_data.status_code(&res_backend.status()).output(); + + Ok(res_backend) + } +} diff --git a/rpxy-lib/src/message_handle/handler_manipulate_messages.rs b/rpxy-lib/src/message_handle/handler_manipulate_messages.rs new file mode 100644 index 0000000..28a62b1 --- /dev/null +++ b/rpxy-lib/src/message_handle/handler_manipulate_messages.rs @@ -0,0 +1,195 @@ +use super::{ + handler_main::HandlerContext, utils_headers::*, utils_request::apply_upstream_options_to_request_line, + HttpMessageHandler, +}; +use crate::{backend::UpstreamCandidates, log::*, CryptoSource}; +use anyhow::{anyhow, ensure, Result}; +use http::{header, uri::Scheme, HeaderValue, Request, Uri, Version}; +use std::net::SocketAddr; + +impl HttpMessageHandler +where + U: CryptoSource + Clone, +{ + //////////////////////////////////////////////////// + // Functions to generate messages + //////////////////////////////////////////////////// + + // /// Manipulate a response message sent from a backend application to forward downstream to a client. 
+ // fn generate_response_forwarded(&self, response: &mut Response, chosen_backend: &Backend) -> Result<()> { + // where + // B: core::fmt::Debug, + // { + // let headers = response.headers_mut(); + // remove_connection_header(headers); + // remove_hop_header(headers); + // add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; + + // #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + // { + // // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled + // // TODO: This is a workaround for avoiding a client authentication in HTTP/3 + // if self.globals.proxy_config.http3 + // && chosen_backend + // .crypto_source + // .as_ref() + // .is_some_and(|v| !v.is_mutual_tls()) + // { + // if let Some(port) = self.globals.proxy_config.https_port { + // add_header_entry_overwrite_if_exist( + // headers, + // header::ALT_SVC.as_str(), + // format!( + // "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", + // port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age + // ), + // )?; + // } + // } else { + // // remove alt-svc to disallow requests via http3 + // headers.remove(header::ALT_SVC.as_str()); + // } + // } + // #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] + // { + // if let Some(port) = self.globals.proxy_config.https_port { + // headers.remove(header::ALT_SVC.as_str()); + // } + // } + + // Ok(()) + // todo!() + // } + + #[allow(clippy::too_many_arguments)] + /// Manipulate a request message sent from a client to forward upstream to a backend application + pub(super) fn generate_request_forwarded( + &self, + client_addr: &SocketAddr, + listen_addr: &SocketAddr, + req: &mut Request, + upgrade: &Option, + upstream_candidates: &UpstreamCandidates, + tls_enabled: bool, + ) -> Result { + debug!("Generate request to be forwarded"); + + // Add te: trailer if contained in original request + let contains_te_trailers = { + if let Some(te) = 
req.headers().get(header::TE) { + te.as_bytes() + .split(|v| v == &b',' || v == &b' ') + .any(|x| x == "trailers".as_bytes()) + } else { + false + } + }; + + let uri = req.uri().to_string(); + let headers = req.headers_mut(); + // delete headers specified in header.connection + remove_connection_header(headers); + // delete hop headers including header.connection + remove_hop_header(headers); + // X-Forwarded-For + add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?; + + // Add te: trailer if te_trailer + if contains_te_trailers { + headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap()); + } + + // add "host" header of original server_name if not exist (default) + if req.headers().get(header::HOST).is_none() { + let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); + req + .headers_mut() + .insert(header::HOST, HeaderValue::from_str(&org_host)?); + }; + + ///////////////////////////////////////////// + // Fix unique upstream destination since there could be multiple ones. + #[cfg(feature = "sticky-cookie")] + let (upstream_chosen_opt, context_from_lb) = { + let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_candidates.load_balance { + takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? 
+ } else { + None + }; + upstream_candidates.get(&context_to_lb) + }; + #[cfg(not(feature = "sticky-cookie"))] + let (upstream_chosen_opt, _) = upstream_candidates.get(&None); + + let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; + let context = HandlerContext { + #[cfg(feature = "sticky-cookie")] + context_lb: context_from_lb, + #[cfg(not(feature = "sticky-cookie"))] + context_lb: None, + }; + ///////////////////////////////////////////// + + // apply upstream-specific headers given in upstream_option + let headers = req.headers_mut(); + // by default, host header is overwritten with upstream hostname + override_host_header(headers, &upstream_chosen.uri)?; + // apply upstream options to header + apply_upstream_options_to_header(headers, upstream_candidates)?; + + // update uri in request + ensure!( + upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some(), + "Upstream uri `scheme` and `authority` is broken" + ); + + let new_uri = Uri::builder() + .scheme(upstream_chosen.uri.scheme().unwrap().as_str()) + .authority(upstream_chosen.uri.authority().unwrap().as_str()); + let org_pq = match req.uri().path_and_query() { + Some(pq) => pq.to_string(), + None => "/".to_string(), + } + .into_bytes(); + + // replace some parts of path if opt_replace_path is enabled for chosen upstream + let new_pq = match &upstream_candidates.replace_path { + Some(new_path) => { + let matched_path: &[u8] = upstream_candidates.path.as_ref(); + ensure!( + !matched_path.is_empty() && org_pq.len() >= matched_path.len(), + "Upstream uri `path and query` is broken" + ); + let mut new_pq = Vec::::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); + new_pq.extend_from_slice(new_path.as_ref()); + new_pq.extend_from_slice(&org_pq[matched_path.len()..]); + new_pq + } + None => org_pq, + }; + *req.uri_mut() = new_uri.path_and_query(new_pq).build()?; + + // upgrade + if let Some(v) = upgrade { + 
req.headers_mut().insert(header::UPGRADE, v.parse()?); + req + .headers_mut() + .insert(header::CONNECTION, HeaderValue::from_static("upgrade")); + } + + // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3 + if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { + // Change version to http/1.1 when destination scheme is http + debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); + *req.version_mut() = Version::HTTP_11; + } else if req.version() == Version::HTTP_3 { + // HTTP/3 is always https + debug!("HTTP/3 is currently unsupported for request to upstream."); + *req.version_mut() = Version::HTTP_2; + } + + apply_upstream_options_to_request_line(req, upstream_candidates)?; + + Ok(context) + } +} diff --git a/rpxy-lib/src/message_handle/http_result.rs b/rpxy-lib/src/message_handle/http_result.rs index 8e9d6b4..3f9df23 100644 --- a/rpxy-lib/src/message_handle/http_result.rs +++ b/rpxy-lib/src/message_handle/http_result.rs @@ -18,6 +18,13 @@ pub enum HttpError { NoMatchingBackendApp, #[error("Failed to redirect: {0}")] FailedToRedirect(String), + #[error("No upstream candidates")] + NoUpstreamCandidates, + #[error("Failed to generate upstream request: {0}")] + FailedToGenerateUpstreamRequest(String), + + #[error("Failed to add set-cookie header in response")] + FailedToAddSetCookeInResponse, #[error(transparent)] Other(#[from] anyhow::Error), @@ -30,6 +37,10 @@ impl From for StatusCode { HttpError::InvalidHostInRequestHeader => StatusCode::BAD_REQUEST, HttpError::SniHostInconsistency => StatusCode::MISDIRECTED_REQUEST, HttpError::NoMatchingBackendApp => StatusCode::SERVICE_UNAVAILABLE, + HttpError::FailedToRedirect(_) => StatusCode::INTERNAL_SERVER_ERROR, + HttpError::NoUpstreamCandidates => StatusCode::NOT_FOUND, + HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR, + HttpError::FailedToAddSetCookeInResponse => 
StatusCode::INTERNAL_SERVER_ERROR, _ => StatusCode::INTERNAL_SERVER_ERROR, } } diff --git a/rpxy-lib/src/message_handle/mod.rs b/rpxy-lib/src/message_handle/mod.rs index f00b417..a9cb195 100644 --- a/rpxy-lib/src/message_handle/mod.rs +++ b/rpxy-lib/src/message_handle/mod.rs @@ -1,8 +1,10 @@ mod canonical_address; -mod handler; +mod handler_main; +mod handler_manipulate_messages; mod http_log; mod http_result; mod synthetic_response; +mod utils_headers; mod utils_request; -pub(crate) use handler::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; +pub(crate) use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; diff --git a/rpxy-lib/src/message_handle/utils_headers.rs b/rpxy-lib/src/message_handle/utils_headers.rs new file mode 100644 index 0000000..32bc7f3 --- /dev/null +++ b/rpxy-lib/src/message_handle/utils_headers.rs @@ -0,0 +1,292 @@ +use super::canonical_address::ToCanonical; +use crate::{ + backend::{UpstreamCandidates, UpstreamOption}, + log::*, +}; +use anyhow::{anyhow, ensure, Result}; +use bytes::BufMut; +use http::{header, HeaderMap, HeaderName, HeaderValue, Uri}; +use std::{borrow::Cow, net::SocketAddr}; + +#[cfg(feature = "sticky-cookie")] +use crate::backend::{LoadBalanceContext, StickyCookie, StickyCookieValue}; +// use crate::backend::{UpstreamGroup, UpstreamOption}; + +// //////////////////////////////////////////////////// +// // Functions to manipulate headers +#[cfg(feature = "sticky-cookie")] +/// Take sticky cookie header value from request header, +/// and returns LoadBalanceContext to be forwarded to LB if exist and if needed. +/// Removing sticky cookie is needed and it must not be passed to the upstream. 
+pub(super) fn takeout_sticky_cookie_lb_context( + headers: &mut HeaderMap, + expected_cookie_name: &str, +) -> Result> { + let mut headers_clone = headers.clone(); + + match headers_clone.entry(header::COOKIE) { + header::Entry::Vacant(_) => Ok(None), + header::Entry::Occupied(entry) => { + let cookies_iter = entry + .iter() + .flat_map(|v| v.to_str().unwrap_or("").split(';').map(|v| v.trim())); + let (sticky_cookies, without_sticky_cookies): (Vec<_>, Vec<_>) = cookies_iter + .into_iter() + .partition(|v| v.starts_with(expected_cookie_name)); + if sticky_cookies.is_empty() { + return Ok(None); + } + ensure!( + sticky_cookies.len() == 1, + "Invalid cookie: Multiple sticky cookie values" + ); + + let cookies_passed_to_upstream = without_sticky_cookies.join("; "); + let cookie_passed_to_lb = sticky_cookies.first().unwrap(); + headers.remove(header::COOKIE); + headers.insert(header::COOKIE, cookies_passed_to_upstream.parse()?); + + let sticky_cookie = StickyCookie { + value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?, + info: None, + }; + Ok(Some(LoadBalanceContext { sticky_cookie })) + } + } +} + +#[cfg(feature = "sticky-cookie")] +/// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated. +/// Set-Cookie response header could be in multiple lines. 
+/// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie +pub(super) fn set_sticky_cookie_lb_context( + headers: &mut HeaderMap, + context_from_lb: &LoadBalanceContext, +) -> Result<()> { + let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?; + let new_header_val: HeaderValue = sticky_cookie_string.parse()?; + let expected_cookie_name = &context_from_lb.sticky_cookie.value.name; + match headers.entry(header::SET_COOKIE) { + header::Entry::Vacant(entry) => { + entry.insert(new_header_val); + } + header::Entry::Occupied(mut entry) => { + let mut flag = false; + for e in entry.iter_mut() { + if e.to_str().unwrap_or("").starts_with(expected_cookie_name) { + *e = new_header_val.clone(); + flag = true; + } + } + if !flag { + entry.append(new_header_val); + } + } + }; + Ok(()) +} + +/// default: overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy) +pub(super) fn override_host_header(headers: &mut HeaderMap, upstream_base_uri: &Uri) -> Result<()> { + let mut upstream_host = upstream_base_uri + .host() + .ok_or_else(|| anyhow!("No hostname is given"))? 
+ .to_string(); + // add port if it is not default + if let Some(port) = upstream_base_uri.port_u16() { + upstream_host = format!("{}:{}", upstream_host, port); + } + + // overwrite host header, this removes all the HOST header values + headers.insert(header::HOST, HeaderValue::from_str(&upstream_host)?); + Ok(()) +} + +/// Apply options to request header, which are specified in the configuration +pub(super) fn apply_upstream_options_to_header( + headers: &mut HeaderMap, + // _client_addr: &SocketAddr, + upstream: &UpstreamCandidates, + // _upstream_base_uri: &Uri, +) -> Result<()> { + for opt in upstream.options.iter() { + match opt { + UpstreamOption::DisableOverrideHost => { + // simply remove HOST header value + headers + .remove(header::HOST) + .ok_or_else(|| anyhow!("Failed to remove host header in disable_override_host option"))?; + } + UpstreamOption::UpgradeInsecureRequests => { + // add upgrade-insecure-requests in request header if not exist + headers + .entry(header::UPGRADE_INSECURE_REQUESTS) + .or_insert(HeaderValue::from_bytes(&[b'1']).unwrap()); + } + _ => (), + } + } + + Ok(()) +} + +/// Append header entry with comma according to [RFC9110](https://datatracker.ietf.org/doc/html/rfc9110) +pub(super) fn append_header_entry_with_comma(headers: &mut HeaderMap, key: &str, value: &str) -> Result<()> { + match headers.entry(HeaderName::from_bytes(key.as_bytes())?) 
{ + header::Entry::Vacant(entry) => { + entry.insert(value.parse::()?); + } + header::Entry::Occupied(mut entry) => { + // entry.append(value.parse::()?); + let mut new_value = Vec::::with_capacity(entry.get().as_bytes().len() + 2 + value.len()); + new_value.put_slice(entry.get().as_bytes()); + new_value.put_slice(&[b',', b' ']); + new_value.put_slice(value.as_bytes()); + entry.insert(HeaderValue::from_bytes(&new_value)?); + } + } + + Ok(()) +} + +/// Add header entry if not exist +pub(super) fn add_header_entry_if_not_exist( + headers: &mut HeaderMap, + key: impl Into>, + value: impl Into>, +) -> Result<()> { + match headers.entry(HeaderName::from_bytes(key.into().as_bytes())?) { + header::Entry::Vacant(entry) => { + entry.insert(value.into().parse::()?); + } + header::Entry::Occupied(_) => (), + }; + + Ok(()) +} + +/// Overwrite header entry if exist +pub(super) fn add_header_entry_overwrite_if_exist( + headers: &mut HeaderMap, + key: impl Into>, + value: impl Into>, +) -> Result<()> { + match headers.entry(HeaderName::from_bytes(key.into().as_bytes())?) { + header::Entry::Vacant(entry) => { + entry.insert(value.into().parse::()?); + } + header::Entry::Occupied(mut entry) => { + entry.insert(HeaderValue::from_bytes(value.into().as_bytes())?); + } + } + + Ok(()) +} + +/// Align cookie values in single line +/// Sometimes violates [RFC6265](https://www.rfc-editor.org/rfc/rfc6265#section-5.4) (for http/1.1). +/// This is allowed in RFC7540 (for http/2) as mentioned [here](https://stackoverflow.com/questions/4843556/in-http-specification-what-is-the-string-that-separates-cookies). 
+pub(super) fn make_cookie_single_line(headers: &mut HeaderMap) -> Result<()> { + let cookies = headers + .iter() + .filter(|(k, _)| **k == header::COOKIE) + .map(|(_, v)| v.to_str().unwrap_or("")) + .collect::>() + .join("; "); + if !cookies.is_empty() { + headers.remove(header::COOKIE); + headers.insert(header::COOKIE, HeaderValue::from_bytes(cookies.as_bytes())?); + } + Ok(()) +} + +/// Add forwarding headers like `x-forwarded-for`. +pub(super) fn add_forwarding_header( + headers: &mut HeaderMap, + client_addr: &SocketAddr, + listen_addr: &SocketAddr, + tls: bool, + uri_str: &str, +) -> Result<()> { + // default process + // optional process defined by upstream_option is applied in fn apply_upstream_options + let canonical_client_addr = client_addr.to_canonical().ip().to_string(); + append_header_entry_with_comma(headers, "x-forwarded-for", &canonical_client_addr)?; + + // Single line cookie header + // TODO: This should be only for HTTP/1.1. For 2+, this can be multi-lined. + make_cookie_single_line(headers)?; + + /////////// As Nginx + // If we receive X-Forwarded-Proto, pass it through; otherwise, pass along the + // scheme used to connect to this server + add_header_entry_if_not_exist(headers, "x-forwarded-proto", if tls { "https" } else { "http" })?; + // If we receive X-Forwarded-Port, pass it through; otherwise, pass along the + // server port the client connected to + add_header_entry_if_not_exist(headers, "x-forwarded-port", listen_addr.port().to_string())?; + + /////////// As Nginx-Proxy + // x-real-ip + add_header_entry_overwrite_if_exist(headers, "x-real-ip", canonical_client_addr)?; + // x-forwarded-ssl + add_header_entry_overwrite_if_exist(headers, "x-forwarded-ssl", if tls { "on" } else { "off" })?; + // x-original-uri + add_header_entry_overwrite_if_exist(headers, "x-original-uri", uri_str.to_string())?; + // proxy + add_header_entry_overwrite_if_exist(headers, "proxy", "")?; + + Ok(()) +} + +/// Remove connection header +pub(super) fn 
remove_connection_header(headers: &mut HeaderMap) { + if let Some(values) = headers.get(header::CONNECTION) { + if let Ok(v) = values.clone().to_str() { + for m in v.split(',') { + if !m.is_empty() { + headers.remove(m.trim()); + } + } + } + } +} + +/// Hop header values which are removed at proxy +const HOP_HEADERS: &[&str] = &[ + "connection", + "te", + "trailer", + "keep-alive", + "proxy-connection", + "proxy-authenticate", + "proxy-authorization", + "transfer-encoding", + "upgrade", +]; + +/// Remove hop headers +pub(super) fn remove_hop_header(headers: &mut HeaderMap) { + HOP_HEADERS.iter().for_each(|key| { + headers.remove(*key); + }); +} + +/// Extract upgrade header value if exist +pub(super) fn extract_upgrade(headers: &HeaderMap) -> Option { + if let Some(c) = headers.get(header::CONNECTION) { + if c + .to_str() + .unwrap_or("") + .split(',') + .any(|w| w.trim().to_ascii_lowercase() == header::UPGRADE.as_str().to_ascii_lowercase()) + { + if let Some(u) = headers.get(header::UPGRADE) { + if let Ok(m) = u.to_str() { + debug!("Upgrade in request header: {}", m); + return Some(m.to_owned()); + } + } + } + } + None +} diff --git a/rpxy-lib/src/message_handle/utils_request.rs b/rpxy-lib/src/message_handle/utils_request.rs index a8f9bd4..37d5e0b 100644 --- a/rpxy-lib/src/message_handle/utils_request.rs +++ b/rpxy-lib/src/message_handle/utils_request.rs @@ -1,43 +1,71 @@ -use super::http_result::*; +use crate::backend::{UpstreamCandidates, UpstreamOption}; +use anyhow::{anyhow, ensure, Result}; use http::{header, Request}; /// Trait defining parser of hostname -pub trait ParseHost { +/// Inspect and extract hostname from either the request HOST header or request line +pub trait InspectParseHost { type Error; - fn parse_host(&self) -> Result<&[u8], Self::Error>; + fn inspect_parse_host(&self) -> Result, Self::Error>; } -impl ParseHost for Request { - type Error = HttpError; - /// Extract hostname from either the request HOST header or request line - fn 
parse_host(&self) -> HttpResult<&[u8]> { - let headers_host = self.headers().get(header::HOST); - let uri_host = self.uri().host(); +impl InspectParseHost for Request { + type Error = anyhow::Error; + /// Inspect and extract hostname from either the request HOST header or request line + fn inspect_parse_host(&self) -> Result> { + let drop_port = |v: &[u8]| { + if v.starts_with(&[b'[']) { + // v6 address with bracket case. if port is specified, always it is in this case. + let mut iter = v.split(|ptr| ptr == &b'[' || ptr == &b']'); + iter.next().ok_or(anyhow!("Invalid Host header"))?; // first item is always blank + iter.next().ok_or(anyhow!("Invalid Host header")).map(|b| b.to_owned()) + } else if v.len() - v.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 { + // v6 address case, if 2 or more ':' is contained + Ok(v.to_owned()) + } else { + // v4 address or hostname + v.split(|colon| colon == &b':') + .next() + .ok_or(anyhow!("Invalid Host header")) + .map(|v| v.to_ascii_lowercase()) + } + }; + + let headers_host = self.headers().get(header::HOST).map(|v| drop_port(v.as_bytes())); + let uri_host = self.uri().host().map(|v| drop_port(v.as_bytes())); // let uri_port = self.uri().port_u16(); - if !(!(headers_host.is_none() && uri_host.is_none())) { - return Err(HttpError::NoHostInRequestHeader); - } - // prioritize server_name in uri - uri_host.map_or_else( - || { - let m = headers_host.unwrap().as_bytes(); - if m.starts_with(&[b'[']) { - // v6 address with bracket case. if port is specified, always it is in this case. 
- let mut iter = m.split(|ptr| ptr == &b'[' || ptr == &b']'); - iter.next().ok_or(HttpError::InvalidHostInRequestHeader)?; // first item is always blank - iter.next().ok_or(HttpError::InvalidHostInRequestHeader) - } else if m.len() - m.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 { - // v6 address case, if 2 or more ':' is contained - Ok(m) - } else { - // v4 address or hostname - m.split(|colon| colon == &b':') - .next() - .ok_or(HttpError::InvalidHostInRequestHeader) - } - }, - |v| Ok(v.as_bytes()), - ) + match (headers_host, uri_host) { + (Some(Ok(hh)), Some(Ok(hu))) => { + ensure!(hh == hu, "Host header and uri host mismatch"); + Ok(hh) + } + (Some(Ok(hh)), None) => Ok(hh), + (None, Some(Ok(hu))) => Ok(hu), + _ => Err(anyhow!("Neither Host header nor uri host is valid")), + } } } + +//////////////////////////////////////////////////// +// Functions to manipulate request line + +/// Apply upstream options in request line, specified in the configuration +pub(super) fn apply_upstream_options_to_request_line( + req: &mut Request, + upstream: &UpstreamCandidates, +) -> anyhow::Result<()> { + for opt in upstream.options.iter() { + match opt { + UpstreamOption::ForceHttp11Upstream => *req.version_mut() = hyper::Version::HTTP_11, + UpstreamOption::ForceHttp2Upstream => { + // case: h2c -> https://www.rfc-editor.org/rfc/rfc9113.txt + // Upgrade from HTTP/1.1 to HTTP/2 is deprecated. So, http-2 prior knowledge is required. 
+ *req.version_mut() = hyper::Version::HTTP_2; + } + _ => (), + } + } + + Ok(()) +} From f9453fe9568933bc77a3bf5024a1e3b3cf390909 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 28 Nov 2023 01:00:14 +0900 Subject: [PATCH 13/50] wip: fix private type --- rpxy-lib/src/message_handle/handler_main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rpxy-lib/src/message_handle/handler_main.rs b/rpxy-lib/src/message_handle/handler_main.rs index 5be08f1..1c62bd4 100644 --- a/rpxy-lib/src/message_handle/handler_main.rs +++ b/rpxy-lib/src/message_handle/handler_main.rs @@ -32,7 +32,7 @@ pub(super) struct HandlerContext { /// HTTP message handler for requests from clients and responses from backend applications, /// responsible to manipulate and forward messages to upstream backends and downstream clients. // pub struct HttpMessageHandler -pub struct HttpMessageHandler +pub(crate) struct HttpMessageHandler where // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone, From ab4ac3b00edfb78815e7062c7bff8772a6119b97 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 28 Nov 2023 01:17:12 +0900 Subject: [PATCH 14/50] fix private type again --- rpxy-lib/src/globals.rs | 2 +- rpxy-lib/src/message_handle/handler_main.rs | 2 +- rpxy-lib/src/message_handle/mod.rs | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/rpxy-lib/src/globals.rs b/rpxy-lib/src/globals.rs index 71a2dca..86fdc46 100644 --- a/rpxy-lib/src/globals.rs +++ b/rpxy-lib/src/globals.rs @@ -8,7 +8,7 @@ use std::{net::SocketAddr, sync::Arc, time::Duration}; /// Global object containing proxy configurations and shared object like counters. /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks. 
-pub(crate) struct Globals { +pub struct Globals { /// Configuration parameters for proxy transport and request handlers pub proxy_config: ProxyConfig, /// Shared context - Counter for serving requests diff --git a/rpxy-lib/src/message_handle/handler_main.rs b/rpxy-lib/src/message_handle/handler_main.rs index 1c62bd4..5be08f1 100644 --- a/rpxy-lib/src/message_handle/handler_main.rs +++ b/rpxy-lib/src/message_handle/handler_main.rs @@ -32,7 +32,7 @@ pub(super) struct HandlerContext { /// HTTP message handler for requests from clients and responses from backend applications, /// responsible to manipulate and forward messages to upstream backends and downstream clients. // pub struct HttpMessageHandler -pub(crate) struct HttpMessageHandler +pub struct HttpMessageHandler where // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone, diff --git a/rpxy-lib/src/message_handle/mod.rs b/rpxy-lib/src/message_handle/mod.rs index a9cb195..edeba27 100644 --- a/rpxy-lib/src/message_handle/mod.rs +++ b/rpxy-lib/src/message_handle/mod.rs @@ -7,4 +7,5 @@ mod synthetic_response; mod utils_headers; mod utils_request; -pub(crate) use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; +pub use handler_main::HttpMessageHandlerBuilderError; +pub(crate) use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder}; From f0b0dbc252f644ba96a57d423a1fad2a7c048f61 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 28 Nov 2023 16:56:23 +0900 Subject: [PATCH 15/50] wip: manipulate response header --- rpxy-lib/src/backend/mod.rs | 2 +- rpxy-lib/src/message_handle/handler_main.rs | 16 ++-- .../handler_manipulate_messages.rs | 90 +++++++++---------- rpxy-lib/src/message_handle/http_result.rs | 3 + 4 files changed, 57 insertions(+), 54 deletions(-) diff --git a/rpxy-lib/src/backend/mod.rs b/rpxy-lib/src/backend/mod.rs index 3d3ddbf..788960d 100644 --- a/rpxy-lib/src/backend/mod.rs +++ b/rpxy-lib/src/backend/mod.rs @@ -10,4 +10,4 @@ 
pub(crate) use self::{ upstream::{PathManager, Upstream, UpstreamCandidates}, upstream_opts::UpstreamOption, }; -pub(crate) use backend_main::{BackendAppBuilderError, BackendAppManager}; +pub(crate) use backend_main::{BackendApp, BackendAppBuilderError, BackendAppManager}; diff --git a/rpxy-lib/src/message_handle/handler_main.rs b/rpxy-lib/src/message_handle/handler_main.rs index 5be08f1..666faa3 100644 --- a/rpxy-lib/src/message_handle/handler_main.rs +++ b/rpxy-lib/src/message_handle/handler_main.rs @@ -38,7 +38,7 @@ where U: CryptoSource + Clone, { // forwarder: Arc>, - globals: Arc, + pub(super) globals: Arc, app_manager: Arc>, } @@ -155,7 +155,7 @@ where tls_enabled, ) { Err(e) => { - error!("Failed to generate destination uri for backend application: {}", e); + error!("Failed to generate upstream request for backend application: {}", e); return Err(HttpError::FailedToGenerateUpstreamRequest(e.to_string())); } Ok(v) => v, @@ -199,12 +199,12 @@ where } if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { - // // Generate response to client - // if self.generate_response_forwarded(&mut res_backend, backend).is_err() { - // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - // } - // log_data.status_code(&res_backend.status()).output(); - // return Ok(res_backend); + // Generate response to client + if let Err(e) = self.generate_response_forwarded(&mut res_backend, backend_app) { + error!("Failed to generate downstream response for clients: {}", e); + return Err(HttpError::FailedToGenerateDownstreamResponse(e.to_string())); + } + return Ok(res_backend); } // // Handle StatusCode::SWITCHING_PROTOCOLS in response diff --git a/rpxy-lib/src/message_handle/handler_manipulate_messages.rs b/rpxy-lib/src/message_handle/handler_manipulate_messages.rs index 28a62b1..a33e58f 100644 --- a/rpxy-lib/src/message_handle/handler_manipulate_messages.rs +++ b/rpxy-lib/src/message_handle/handler_manipulate_messages.rs @@ -2,9 +2,14 @@ use 
super::{ handler_main::HandlerContext, utils_headers::*, utils_request::apply_upstream_options_to_request_line, HttpMessageHandler, }; -use crate::{backend::UpstreamCandidates, log::*, CryptoSource}; +use crate::{ + backend::{BackendApp, UpstreamCandidates}, + constants::RESPONSE_HEADER_SERVER, + log::*, + CryptoSource, +}; use anyhow::{anyhow, ensure, Result}; -use http::{header, uri::Scheme, HeaderValue, Request, Uri, Version}; +use http::{header, uri::Scheme, HeaderValue, Request, Response, Uri, Version}; use std::net::SocketAddr; impl HttpMessageHandler @@ -15,51 +20,46 @@ where // Functions to generate messages //////////////////////////////////////////////////// - // /// Manipulate a response message sent from a backend application to forward downstream to a client. - // fn generate_response_forwarded(&self, response: &mut Response, chosen_backend: &Backend) -> Result<()> { - // where - // B: core::fmt::Debug, - // { - // let headers = response.headers_mut(); - // remove_connection_header(headers); - // remove_hop_header(headers); - // add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; + /// Manipulate a response message sent from a backend application to forward downstream to a client. 
+ pub(super) fn generate_response_forwarded( + &self, + response: &mut Response, + backend_app: &BackendApp, + ) -> Result<()> { + let headers = response.headers_mut(); + remove_connection_header(headers); + remove_hop_header(headers); + add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; - // #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] - // { - // // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled - // // TODO: This is a workaround for avoiding a client authentication in HTTP/3 - // if self.globals.proxy_config.http3 - // && chosen_backend - // .crypto_source - // .as_ref() - // .is_some_and(|v| !v.is_mutual_tls()) - // { - // if let Some(port) = self.globals.proxy_config.https_port { - // add_header_entry_overwrite_if_exist( - // headers, - // header::ALT_SVC.as_str(), - // format!( - // "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", - // port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age - // ), - // )?; - // } - // } else { - // // remove alt-svc to disallow requests via http3 - // headers.remove(header::ALT_SVC.as_str()); - // } - // } - // #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] - // { - // if let Some(port) = self.globals.proxy_config.https_port { - // headers.remove(header::ALT_SVC.as_str()); - // } - // } + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + { + // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled + // TODO: This is a workaround for avoiding a client authentication in HTTP/3 + if self.globals.proxy_config.http3 && backend_app.crypto_source.as_ref().is_some_and(|v| !v.is_mutual_tls()) { + if let Some(port) = self.globals.proxy_config.https_port { + add_header_entry_overwrite_if_exist( + headers, + header::ALT_SVC.as_str(), + format!( + "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", + port, self.globals.proxy_config.h3_alt_svc_max_age, port, 
self.globals.proxy_config.h3_alt_svc_max_age + ), + )?; + } + } else { + // remove alt-svc to disallow requests via http3 + headers.remove(header::ALT_SVC.as_str()); + } + } + #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] + { + if self.globals.proxy_config.https_port.is_some() { + headers.remove(header::ALT_SVC.as_str()); + } + } - // Ok(()) - // todo!() - // } + Ok(()) + } #[allow(clippy::too_many_arguments)] /// Manipulate a request message sent from a client to forward upstream to a backend application diff --git a/rpxy-lib/src/message_handle/http_result.rs b/rpxy-lib/src/message_handle/http_result.rs index 3f9df23..07a0034 100644 --- a/rpxy-lib/src/message_handle/http_result.rs +++ b/rpxy-lib/src/message_handle/http_result.rs @@ -25,6 +25,8 @@ pub enum HttpError { #[error("Failed to add set-cookie header in response")] FailedToAddSetCookeInResponse, + #[error("Failed to generated downstream response: {0}")] + FailedToGenerateDownstreamResponse(String), #[error(transparent)] Other(#[from] anyhow::Error), @@ -41,6 +43,7 @@ impl From for StatusCode { HttpError::NoUpstreamCandidates => StatusCode::NOT_FOUND, HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToAddSetCookeInResponse => StatusCode::INTERNAL_SERVER_ERROR, + HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, _ => StatusCode::INTERNAL_SERVER_ERROR, } } From 8f77ce94473d92098925a53e76b9bdc5067d696a Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 28 Nov 2023 18:04:12 +0900 Subject: [PATCH 16/50] wip: implement switching protocols (http upgrade) --- rpxy-lib/src/error.rs | 6 ++ rpxy-lib/src/message_handle/handler_main.rs | 90 ++++++++++----------- rpxy-lib/src/message_handle/http_log.rs | 10 +-- rpxy-lib/src/message_handle/http_result.rs | 10 +++ 4 files changed, 65 insertions(+), 51 deletions(-) diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index da65234..d7123a3 100644 --- 
a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -71,6 +71,12 @@ pub enum RpxyError { // Handler errors #[error("Failed to build message handler: {0}")] FailedToBuildMessageHandler(#[from] crate::message_handle::HttpMessageHandlerBuilderError), + #[error("Failed to upgrade request: {0}")] + FailedToUpgradeRequest(String), + #[error("Failed to upgrade response: {0}")] + FailedToUpgradeResponse(String), + #[error("Failed to copy bidirectional for upgraded connections: {0}")] + FailedToCopyBidirectional(String), // Upstream connection setting errors #[error("Unsupported upstream option")] diff --git a/rpxy-lib/src/message_handle/handler_main.rs b/rpxy-lib/src/message_handle/handler_main.rs index 666faa3..922e024 100644 --- a/rpxy-lib/src/message_handle/handler_main.rs +++ b/rpxy-lib/src/message_handle/handler_main.rs @@ -16,7 +16,9 @@ use crate::{ }; use derive_builder::Builder; use http::{Request, Response, StatusCode}; +use hyper_util::rt::TokioIo; use std::{net::SocketAddr, sync::Arc}; +use tokio::io::copy_bidirectional; #[allow(dead_code)] #[derive(Debug)] @@ -51,13 +53,15 @@ where /// Responsible to passthrough responses from backend applications or generate synthetic error responses. 
pub async fn handle_request( &self, - mut req: Request>, + req: Request>, client_addr: SocketAddr, // For access control listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, ) -> RpxyResult>> { + // preparing log data let mut log_data = HttpMessageLog::from(&req); + log_data.client_addr(&client_addr); let http_result = self .handle_request_inner( @@ -96,10 +100,6 @@ where tls_enabled: bool, tls_server_name: Option, ) -> HttpResult>> { - // preparing log data - let mut log_data = HttpMessageLog::from(&req); - log_data.client_addr(&client_addr); - // Here we start to inspect and parse with server_name let server_name = req .inspect_parse_host() @@ -207,48 +207,46 @@ where return Ok(res_backend); } - // // Handle StatusCode::SWITCHING_PROTOCOLS in response - // let upgrade_in_response = extract_upgrade(res_backend.headers()); - // let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) - // { - // u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() - // } else { - // false - // }; - // if !should_upgrade { - // error!( - // "Backend tried to switch to protocol {:?} when {:?} was requested", - // upgrade_in_response, upgrade_in_request - // ); - // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - // } - // let Some(request_upgraded) = request_upgraded else { - // error!("Request does not have an upgrade extension"); - // return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); - // }; - // let Some(onupgrade) = res_backend.extensions_mut().remove::() else { - // error!("Response does not have an upgrade extension"); - // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - // }; + // Handle StatusCode::SWITCHING_PROTOCOLS in response + let upgrade_in_response = extract_upgrade(res_backend.headers()); + let should_upgrade = match (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) { + 
(Some(u_req), Some(u_res)) => u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase(), + _ => false, + }; - // self.globals.runtime_handle.spawn(async move { - // let mut response_upgraded = onupgrade.await.map_err(|e| { - // error!("Failed to upgrade response: {}", e); - // RpxyError::Hyper(e) - // })?; - // let mut request_upgraded = request_upgraded.await.map_err(|e| { - // error!("Failed to upgrade request: {}", e); - // RpxyError::Hyper(e) - // })?; - // copy_bidirectional(&mut response_upgraded, &mut request_upgraded) - // .await - // .map_err(|e| { - // error!("Coping between upgraded connections failed: {}", e); - // RpxyError::Io(e) - // })?; - // Ok(()) as Result<()> - // }); - // log_data.status_code(&res_backend.status()).output(); + if !should_upgrade { + error!( + "Backend tried to switch to protocol {:?} when {:?} was requested", + upgrade_in_response, upgrade_in_request + ); + return Err(HttpError::FailedToUpgrade); + } + let Some(request_upgraded) = request_upgraded else { + error!("Request does not have an upgrade extension"); + return Err(HttpError::NoUpgradeExtensionInRequest); + }; + let Some(onupgrade) = res_backend.extensions_mut().remove::() else { + error!("Response does not have an upgrade extension"); + return Err(HttpError::NoUpgradeExtensionInResponse); + }; + + self.globals.runtime_handle.spawn(async move { + let mut response_upgraded = TokioIo::new(onupgrade.await.map_err(|e| { + error!("Failed to upgrade response: {}", e); + RpxyError::FailedToUpgradeResponse(e.to_string()) + })?); + let mut request_upgraded = TokioIo::new(request_upgraded.await.map_err(|e| { + error!("Failed to upgrade request: {}", e); + RpxyError::FailedToUpgradeRequest(e.to_string()) + })?); + copy_bidirectional(&mut response_upgraded, &mut request_upgraded) + .await + .map_err(|e| { + error!("Coping between upgraded connections failed: {}", e); + RpxyError::FailedToCopyBidirectional(e.to_string()) + })?; + Ok(()) as RpxyResult<()> + }); Ok(res_backend) } diff 
--git a/rpxy-lib/src/message_handle/http_log.rs b/rpxy-lib/src/message_handle/http_log.rs index 7056c80..acda9f0 100644 --- a/rpxy-lib/src/message_handle/http_log.rs +++ b/rpxy-lib/src/message_handle/http_log.rs @@ -11,7 +11,7 @@ pub struct HttpMessageLog { pub method: String, pub host: String, pub p_and_q: String, - pub version: hyper::Version, + pub version: http::Version, pub uri_scheme: String, pub uri_host: String, pub ua: String, @@ -20,8 +20,8 @@ pub struct HttpMessageLog { pub upstream: String, } -impl From<&hyper::Request> for HttpMessageLog { - fn from(req: &hyper::Request) -> Self { +impl From<&http::Request> for HttpMessageLog { + fn from(req: &http::Request) -> Self { let header_mapper = |v: header::HeaderName| { req .headers() @@ -59,7 +59,7 @@ impl HttpMessageLog { // self.tls_server_name = tls_server_name.to_string(); // self // } - pub fn status_code(&mut self, status_code: &hyper::StatusCode) -> &mut Self { + pub fn status_code(&mut self, status_code: &http::StatusCode) -> &mut Self { self.status = status_code.to_string(); self } @@ -67,7 +67,7 @@ impl HttpMessageLog { self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); self } - pub fn upstream(&mut self, upstream: &hyper::Uri) -> &mut Self { + pub fn upstream(&mut self, upstream: &http::Uri) -> &mut Self { self.upstream = upstream.to_string(); self } diff --git a/rpxy-lib/src/message_handle/http_result.rs b/rpxy-lib/src/message_handle/http_result.rs index 07a0034..dc77565 100644 --- a/rpxy-lib/src/message_handle/http_result.rs +++ b/rpxy-lib/src/message_handle/http_result.rs @@ -28,6 +28,13 @@ pub enum HttpError { #[error("Failed to generated downstream response: {0}")] FailedToGenerateDownstreamResponse(String), + #[error("Failed to upgrade connection")] + FailedToUpgrade, + #[error("Request does not have an upgrade extension")] + NoUpgradeExtensionInRequest, + #[error("Response does not have an upgrade extension")] + NoUpgradeExtensionInResponse, + #[error(transparent)] 
Other(#[from] anyhow::Error), } @@ -44,6 +51,9 @@ impl From for StatusCode { HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToAddSetCookeInResponse => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, + HttpError::FailedToUpgrade => StatusCode::INTERNAL_SERVER_ERROR, + HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST, + HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY, _ => StatusCode::INTERNAL_SERVER_ERROR, } } From f020ece60dd3cf9e26ff728f2870a02abfd1e6db Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 28 Nov 2023 18:11:37 +0900 Subject: [PATCH 17/50] chore: change mod name --- rpxy-lib/src/error.rs | 2 +- rpxy-lib/src/lib.rs | 4 ++-- .../{message_handle => message_handler}/canonical_address.rs | 0 .../src/{message_handle => message_handler}/handler_main.rs | 0 .../handler_manipulate_messages.rs | 0 rpxy-lib/src/{message_handle => message_handler}/http_log.rs | 0 .../src/{message_handle => message_handler}/http_result.rs | 0 rpxy-lib/src/{message_handle => message_handler}/mod.rs | 0 .../{message_handle => message_handler}/synthetic_response.rs | 0 .../src/{message_handle => message_handler}/utils_headers.rs | 0 .../src/{message_handle => message_handler}/utils_request.rs | 0 rpxy-lib/src/proxy/proxy_main.rs | 2 +- 12 files changed, 4 insertions(+), 4 deletions(-) rename rpxy-lib/src/{message_handle => message_handler}/canonical_address.rs (100%) rename rpxy-lib/src/{message_handle => message_handler}/handler_main.rs (100%) rename rpxy-lib/src/{message_handle => message_handler}/handler_manipulate_messages.rs (100%) rename rpxy-lib/src/{message_handle => message_handler}/http_log.rs (100%) rename rpxy-lib/src/{message_handle => message_handler}/http_result.rs (100%) rename rpxy-lib/src/{message_handle => message_handler}/mod.rs (100%) rename rpxy-lib/src/{message_handle => 
message_handler}/synthetic_response.rs (100%) rename rpxy-lib/src/{message_handle => message_handler}/utils_headers.rs (100%) rename rpxy-lib/src/{message_handle => message_handler}/utils_request.rs (100%) diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index d7123a3..ee47e4f 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -70,7 +70,7 @@ pub enum RpxyError { // Handler errors #[error("Failed to build message handler: {0}")] - FailedToBuildMessageHandler(#[from] crate::message_handle::HttpMessageHandlerBuilderError), + FailedToBuildMessageHandler(#[from] crate::message_handler::HttpMessageHandlerBuilderError), #[error("Failed to upgrade request: {0}")] FailedToUpgradeRequest(String), #[error("Failed to upgrade response: {0}")] diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index c45327a..d13ade1 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -6,12 +6,12 @@ mod error; mod globals; mod hyper_ext; mod log; -mod message_handle; +mod message_handler; mod name_exp; mod proxy; use crate::{ - crypto::build_cert_reloader, error::*, globals::Globals, log::*, message_handle::HttpMessageHandlerBuilder, + crypto::build_cert_reloader, error::*, globals::Globals, log::*, message_handler::HttpMessageHandlerBuilder, proxy::Proxy, }; use futures::future::select_all; diff --git a/rpxy-lib/src/message_handle/canonical_address.rs b/rpxy-lib/src/message_handler/canonical_address.rs similarity index 100% rename from rpxy-lib/src/message_handle/canonical_address.rs rename to rpxy-lib/src/message_handler/canonical_address.rs diff --git a/rpxy-lib/src/message_handle/handler_main.rs b/rpxy-lib/src/message_handler/handler_main.rs similarity index 100% rename from rpxy-lib/src/message_handle/handler_main.rs rename to rpxy-lib/src/message_handler/handler_main.rs diff --git a/rpxy-lib/src/message_handle/handler_manipulate_messages.rs b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs similarity index 100% rename from 
rpxy-lib/src/message_handle/handler_manipulate_messages.rs rename to rpxy-lib/src/message_handler/handler_manipulate_messages.rs diff --git a/rpxy-lib/src/message_handle/http_log.rs b/rpxy-lib/src/message_handler/http_log.rs similarity index 100% rename from rpxy-lib/src/message_handle/http_log.rs rename to rpxy-lib/src/message_handler/http_log.rs diff --git a/rpxy-lib/src/message_handle/http_result.rs b/rpxy-lib/src/message_handler/http_result.rs similarity index 100% rename from rpxy-lib/src/message_handle/http_result.rs rename to rpxy-lib/src/message_handler/http_result.rs diff --git a/rpxy-lib/src/message_handle/mod.rs b/rpxy-lib/src/message_handler/mod.rs similarity index 100% rename from rpxy-lib/src/message_handle/mod.rs rename to rpxy-lib/src/message_handler/mod.rs diff --git a/rpxy-lib/src/message_handle/synthetic_response.rs b/rpxy-lib/src/message_handler/synthetic_response.rs similarity index 100% rename from rpxy-lib/src/message_handle/synthetic_response.rs rename to rpxy-lib/src/message_handler/synthetic_response.rs diff --git a/rpxy-lib/src/message_handle/utils_headers.rs b/rpxy-lib/src/message_handler/utils_headers.rs similarity index 100% rename from rpxy-lib/src/message_handle/utils_headers.rs rename to rpxy-lib/src/message_handler/utils_headers.rs diff --git a/rpxy-lib/src/message_handle/utils_request.rs b/rpxy-lib/src/message_handler/utils_request.rs similarity index 100% rename from rpxy-lib/src/message_handle/utils_request.rs rename to rpxy-lib/src/message_handler/utils_request.rs diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index abdea64..f8f04c0 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -9,7 +9,7 @@ use crate::{ rt::LocalExecutor, }, log::*, - message_handle::HttpMessageHandler, + message_handler::HttpMessageHandler, name_exp::ServerName, }; use futures::{select, FutureExt}; From a9f5e0ede58700d7dde6350d11603cb0d64cc33b Mon Sep 17 00:00:00 2001 From: Jun 
Kurihara Date: Tue, 28 Nov 2023 22:22:40 +0900 Subject: [PATCH 18/50] feat: client (wip), still unstable for http2 due to alpn issues --- CHANGELOG.md | 3 + rpxy-bin/src/log.rs | 6 +- rpxy-lib/Cargo.toml | 1 + rpxy-lib/src/error.rs | 4 + rpxy-lib/src/forwarder/client.rs | 117 ++++++++++++++++++ rpxy-lib/src/forwarder/mod.rs | 8 ++ rpxy-lib/src/hyper_ext/body_type.rs | 14 +++ rpxy-lib/src/hyper_ext/mod.rs | 4 +- rpxy-lib/src/lib.rs | 7 +- rpxy-lib/src/message_handler/handler_main.rs | 51 ++++---- rpxy-lib/src/message_handler/http_result.rs | 21 ++-- .../src/message_handler/synthetic_response.rs | 17 +-- rpxy-lib/src/message_handler/utils_request.rs | 4 +- rpxy-lib/src/proxy/proxy_h3.rs | 10 +- rpxy-lib/src/proxy/proxy_main.rs | 2 +- 15 files changed, 199 insertions(+), 70 deletions(-) create mode 100644 rpxy-lib/src/forwarder/client.rs create mode 100644 rpxy-lib/src/forwarder/mod.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 20ac679..de8871f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## 0.7.0 (unreleased) +- Breaking: `hyper`-1.0 for both server and client modules. +- Breaking: Remove `override_host` option in upstream options. Add a reverse option, i.e., `disable_override_host`. That is, `rpxy` always override the host header by the upstream hostname by default. 
+ ## 0.6.2 ### Improvement diff --git a/rpxy-bin/src/log.rs b/rpxy-bin/src/log.rs index 3fcf694..fd7b5cb 100644 --- a/rpxy-bin/src/log.rs +++ b/rpxy-bin/src/log.rs @@ -13,9 +13,9 @@ pub fn init_logger() { .compact(); // This limits the logger to emits only rpxy crate - let level_string = std::env::var(EnvFilter::DEFAULT_ENV).unwrap_or_else(|_| "info".to_string()); - let filter_layer = EnvFilter::new(format!("{}={}", env!("CARGO_PKG_NAME"), level_string)); - // let filter_layer = EnvFilter::from_default_env(); + // let level_string = std::env::var(EnvFilter::DEFAULT_ENV).unwrap_or_else(|_| "info".to_string()); + // let filter_layer = EnvFilter::new(format!("{}={}", env!("CARGO_PKG_NAME"), level_string)); + let filter_layer = EnvFilter::from_default_env(); tracing_subscriber::registry() .with(format_layer) diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 51631af..5df1060 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -58,6 +58,7 @@ futures-channel = { version = "0.3.29", default-features = false } # "http1", # "http2", # ] } +hyper-tls = { version = "0.6.0", features = ["alpn"] } # tls and cert management hot_reload = "0.1.4" diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index ee47e4f..438e7bb 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -78,6 +78,10 @@ pub enum RpxyError { #[error("Failed to copy bidirectional for upgraded connections: {0}")] FailedToCopyBidirectional(String), + // Forwarder errors + #[error("Failed to fetch from upstream: {0}")] + FailedToFetchFromUpstream(String), + // Upstream connection setting errors #[error("Unsupported upstream option")] UnsupportedUpstreamOption, diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs new file mode 100644 index 0000000..57538fa --- /dev/null +++ b/rpxy-lib/src/forwarder/client.rs @@ -0,0 +1,117 @@ +use crate::{ + error::{RpxyError, RpxyResult}, + globals::Globals, + hyper_ext::{ + body::{wrap_incoming_body_response, 
IncomingOr}, + rt::LocalExecutor, + }, +}; +use async_trait::async_trait; +use http::{Request, Response, Version}; +use hyper::body::Body; +use hyper_tls::HttpsConnector; +use hyper_util::client::legacy::{ + connect::{Connect, HttpConnector}, + Client, +}; +use std::sync::Arc; + +#[async_trait] +/// Definition of the forwarder that simply forward requests from downstream client to upstream app servers. +pub trait ForwardRequest { + type Error; + async fn request(&self, req: Request) -> Result, Self::Error>; +} + +/// Forwarder http client struct responsible to cache handling +pub struct Forwarder { + // #[cfg(feature = "cache")] + // cache: Option, + inner: Client, + inner_h2: Client, // `h2c` or http/2-only client is defined separately +} + +#[async_trait] +impl ForwardRequest> for Forwarder +where + C: Send + Sync + Connect + Clone + 'static, + B1: Body + Send + Unpin + 'static, + ::Data: Send, + ::Error: Into>, + B2: Body, +{ + type Error = RpxyError; + + async fn request(&self, req: Request) -> Result>, Self::Error> { + self.request_directly(req).await + } +} + +impl Forwarder +where + C: Send + Sync + Connect + Clone + 'static, + B1: Body + Send + Unpin + 'static, + ::Data: Send, + ::Error: Into>, +{ + async fn request_directly(&self, req: Request) -> RpxyResult>> { + match req.version() { + Version::HTTP_2 => self.inner_h2.request(req).await, // handles `h2c` requests + _ => self.inner.request(req).await, + } + .map_err(|e| RpxyError::FailedToFetchFromUpstream(e.to_string())) + .map(wrap_incoming_body_response::) + } +} + +impl Forwarder, B1> +where + B1: Body + Send + Unpin + 'static, + ::Data: Send, + ::Error: Into>, +{ + /// Build forwarder + pub async fn new(_globals: &Arc) -> Self { + // build hyper client with hyper-tls + // TODO: Frame size errorが取れない > H2 どうしようもない。。。。 hyper_rustlsのリリース待ち? 
+ let connector = HttpsConnector::new(); + let executor = LocalExecutor::new(_globals.runtime_handle.clone().clone()); + let inner = Client::builder(executor.clone()).build::<_, B1>(connector); + + let connector = HttpsConnector::new(); + let executor = LocalExecutor::new(_globals.runtime_handle.clone()); + let inner_h2 = Client::builder(executor) + .http2_adaptive_window(true) + .http2_only(true) + .set_host(true) + .build::<_, B1>(connector); + + // #[cfg(feature = "native-roots")] + // let builder = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots(); + // #[cfg(feature = "native-roots")] + // let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots(); + // #[cfg(feature = "native-roots")] + // info!("Native cert store is used for the connection to backend applications"); + + // #[cfg(not(feature = "native-roots"))] + // let builder = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots(); + // #[cfg(not(feature = "native-roots"))] + // let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots(); + // #[cfg(not(feature = "native-roots"))] + // info!("Mozilla WebPKI root certs is used for the connection to backend applications"); + + // let connector = builder.https_or_http().enable_http1().enable_http2().build(); + // let connector_h2 = builder_h2.https_or_http().enable_http2().build(); + + // let inner = Client::builder().build::<_, Body>(connector); + // let inner_h2 = Client::builder().http2_only(true).build::<_, Body>(connector_h2); + + // #[cfg(feature = "cache")] + // { + // let cache = RpxyCache::new(_globals).await; + // Self { inner, inner_h2, cache } + // } + // #[cfg(not(feature = "cache"))] + Self { inner, inner_h2 } + } +} diff --git a/rpxy-lib/src/forwarder/mod.rs b/rpxy-lib/src/forwarder/mod.rs new file mode 100644 index 0000000..1cb67fd --- /dev/null +++ b/rpxy-lib/src/forwarder/mod.rs @@ -0,0 +1,8 @@ +mod client; + +use crate::hyper_ext::body::{IncomingLike, IncomingOr}; +use 
hyper_tls::HttpsConnector; +use hyper_util::client::legacy::connect::HttpConnector; +pub type Forwarder = client::Forwarder, IncomingOr>; + +pub use client::ForwardRequest; diff --git a/rpxy-lib/src/hyper_ext/body_type.rs b/rpxy-lib/src/hyper_ext/body_type.rs index 516e569..9616306 100644 --- a/rpxy-lib/src/hyper_ext/body_type.rs +++ b/rpxy-lib/src/hyper_ext/body_type.rs @@ -1,3 +1,4 @@ +use http::Response; use http_body_util::{combinators, BodyExt, Either, Empty, Full}; use hyper::body::{Bytes, Incoming}; @@ -6,6 +7,19 @@ pub(crate) type BoxBody = combinators::BoxBody; /// Type for either passthrough body or given body type, specifically synthetic boxed body pub(crate) type IncomingOr = Either; +/// helper function to build http response with passthrough body +pub(crate) fn wrap_incoming_body_response(response: Response) -> Response> +where + B: hyper::body::Body, +{ + response.map(IncomingOr::Left) +} + +/// helper function to build http response with synthetic body +pub(crate) fn wrap_synthetic_body_response(response: Response) -> Response> { + response.map(IncomingOr::Right) +} + /// helper function to build a empty body pub(crate) fn empty() -> BoxBody { Empty::::new().map_err(|never| match never {}).boxed() diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs index a39aef9..e1b5ae8 100644 --- a/rpxy-lib/src/hyper_ext/mod.rs +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -8,5 +8,7 @@ pub(crate) mod rt { } pub(crate) mod body { pub(crate) use super::body_incoming_like::IncomingLike; - pub(crate) use super::body_type::{empty, full, BoxBody, IncomingOr}; + pub(crate) use super::body_type::{ + empty, full, wrap_incoming_body_response, wrap_synthetic_body_response, BoxBody, IncomingOr, + }; } diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index d13ade1..c78e8c8 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -3,6 +3,7 @@ mod constants; mod count; mod crypto; mod error; +mod forwarder; mod globals; mod hyper_ext; mod log; @@ -11,8 
+12,8 @@ mod name_exp; mod proxy; use crate::{ - crypto::build_cert_reloader, error::*, globals::Globals, log::*, message_handler::HttpMessageHandlerBuilder, - proxy::Proxy, + crypto::build_cert_reloader, error::*, forwarder::Forwarder, globals::Globals, log::*, + message_handler::HttpMessageHandlerBuilder, proxy::Proxy, }; use futures::future::select_all; use std::sync::Arc; @@ -90,10 +91,12 @@ where }); // 4. build message handler containing Arc-ed http_client and backends, and make it contained in Arc as well + let forwarder = Arc::new(Forwarder::new(&globals).await); let message_handler = Arc::new( HttpMessageHandlerBuilder::default() .globals(globals.clone()) .app_manager(app_manager.clone()) + .forwarder(forwarder) .build()?, ); diff --git a/rpxy-lib/src/message_handler/handler_main.rs b/rpxy-lib/src/message_handler/handler_main.rs index 922e024..d94d2c9 100644 --- a/rpxy-lib/src/message_handler/handler_main.rs +++ b/rpxy-lib/src/message_handler/handler_main.rs @@ -9,6 +9,7 @@ use crate::{ backend::{BackendAppManager, LoadBalanceContext}, crypto::CryptoSource, error::*, + forwarder::{ForwardRequest, Forwarder}, globals::Globals, hyper_ext::body::{BoxBody, IncomingLike, IncomingOr}, log::*, @@ -18,7 +19,7 @@ use derive_builder::Builder; use http::{Request, Response, StatusCode}; use hyper_util::rt::TokioIo; use std::{net::SocketAddr, sync::Arc}; -use tokio::io::copy_bidirectional; +use tokio::{io::copy_bidirectional, time::timeout}; #[allow(dead_code)] #[derive(Debug)] @@ -36,10 +37,9 @@ pub(super) struct HandlerContext { // pub struct HttpMessageHandler pub struct HttpMessageHandler where - // T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone, { - // forwarder: Arc>, + forwarder: Arc, pub(super) globals: Arc, app_manager: Arc>, } @@ -81,7 +81,7 @@ where Ok(v) } Err(e) => { - debug!("{e}"); + error!("{e}"); let code = StatusCode::from(e); log_data.status_code(&code).output(); synthetic_error_response(code) @@ -155,14 +155,14 @@ where 
tls_enabled, ) { Err(e) => { - error!("Failed to generate upstream request for backend application: {}", e); return Err(HttpError::FailedToGenerateUpstreamRequest(e.to_string())); } Ok(v) => v, }; debug!( - "Request to be forwarded: uri {}, version {:?}, headers {:?}", + "Request to be forwarded: [uri {}, method: {}, version {:?}, headers {:?}]", req.uri(), + req.method(), req.version(), req.headers() ); @@ -171,37 +171,31 @@ where ////// ////////////// - // // TODO: remove later - let body = crate::hyper_ext::body::full(hyper::body::Bytes::from("not yet implemented")); - let mut res_backend = super::synthetic_response::synthetic_response(Response::builder().body(body).unwrap()); - // // Forward request to a chosen backend - // let mut res_backend = { - // let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else { - // return self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data); - // }; - // match result { - // Ok(res) => res, - // Err(e) => { - // error!("Failed to get response from backend: {}", e); - // return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); - // } - // } - // }; + // Forward request to a chosen backend + let mut res_backend = { + let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else { + return Err(HttpError::TimeoutUpstreamRequest); + }; + match result { + Ok(res) => res, + Err(e) => { + return Err(HttpError::FailedToGetResponseFromBackend(e.to_string())); + } + } + }; ////////////// // Process reverse proxy context generated during the forwarding request generation. 
#[cfg(feature = "sticky-cookie")] if let Some(context_from_lb) = _context.context_lb { let res_headers = res_backend.headers_mut(); if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { - error!("Failed to append context to the response given from backend: {}", e); - return Err(HttpError::FailedToAddSetCookeInResponse); + return Err(HttpError::FailedToAddSetCookeInResponse(e.to_string())); } } if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { // Generate response to client if let Err(e) = self.generate_response_forwarded(&mut res_backend, backend_app) { - error!("Failed to generate downstream response for clients: {}", e); return Err(HttpError::FailedToGenerateDownstreamResponse(e.to_string())); } return Ok(res_backend); @@ -215,18 +209,15 @@ where }; if !should_upgrade { - error!( + return Err(HttpError::FailedToUpgrade(format!( "Backend tried to switch to protocol {:?} when {:?} was requested", upgrade_in_response, upgrade_in_request - ); - return Err(HttpError::FailedToUpgrade); + ))); } let Some(request_upgraded) = request_upgraded else { - error!("Request does not have an upgrade extension"); return Err(HttpError::NoUpgradeExtensionInRequest); }; let Some(onupgrade) = res_backend.extensions_mut().remove::() else { - error!("Response does not have an upgrade extension"); return Err(HttpError::NoUpgradeExtensionInResponse); }; diff --git a/rpxy-lib/src/message_handler/http_result.rs b/rpxy-lib/src/message_handler/http_result.rs index dc77565..857ab55 100644 --- a/rpxy-lib/src/message_handler/http_result.rs +++ b/rpxy-lib/src/message_handler/http_result.rs @@ -20,16 +20,20 @@ pub enum HttpError { FailedToRedirect(String), #[error("No upstream candidates")] NoUpstreamCandidates, - #[error("Failed to generate upstream request: {0}")] + #[error("Failed to generate upstream request for backend application: {0}")] FailedToGenerateUpstreamRequest(String), + #[error("Timeout in upstream request")] + TimeoutUpstreamRequest, + 
#[error("Failed to get response from backend: {0}")] + FailedToGetResponseFromBackend(String), - #[error("Failed to add set-cookie header in response")] - FailedToAddSetCookeInResponse, - #[error("Failed to generated downstream response: {0}")] + #[error("Failed to add set-cookie header in response {0}")] + FailedToAddSetCookeInResponse(String), + #[error("Failed to generated downstream response for clients: {0}")] FailedToGenerateDownstreamResponse(String), - #[error("Failed to upgrade connection")] - FailedToUpgrade, + #[error("Failed to upgrade connection: {0}")] + FailedToUpgrade(String), #[error("Request does not have an upgrade extension")] NoUpgradeExtensionInRequest, #[error("Response does not have an upgrade extension")] @@ -49,9 +53,10 @@ impl From for StatusCode { HttpError::FailedToRedirect(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::NoUpstreamCandidates => StatusCode::NOT_FOUND, HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR, - HttpError::FailedToAddSetCookeInResponse => StatusCode::INTERNAL_SERVER_ERROR, + HttpError::TimeoutUpstreamRequest => StatusCode::GATEWAY_TIMEOUT, + HttpError::FailedToAddSetCookeInResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, - HttpError::FailedToUpgrade => StatusCode::INTERNAL_SERVER_ERROR, + HttpError::FailedToUpgrade(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST, HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY, _ => StatusCode::INTERNAL_SERVER_ERROR, diff --git a/rpxy-lib/src/message_handler/synthetic_response.rs b/rpxy-lib/src/message_handler/synthetic_response.rs index 0038997..60aeeec 100644 --- a/rpxy-lib/src/message_handler/synthetic_response.rs +++ b/rpxy-lib/src/message_handler/synthetic_response.rs @@ -1,25 +1,10 @@ +use super::http_result::{HttpError, HttpResult}; use crate::{ error::*, 
hyper_ext::body::{empty, BoxBody, IncomingOr}, name_exp::ServerName, }; use http::{Request, Response, StatusCode, Uri}; -use hyper::body::Incoming; - -use super::http_result::{HttpError, HttpResult}; - -/// helper function to build http response with passthrough body -pub(crate) fn passthrough_response(response: Response) -> Response> -where - B: hyper::body::Body, -{ - response.map(IncomingOr::Left) -} - -/// helper function to build http response with synthetic body -pub(crate) fn synthetic_response(response: Response) -> Response> { - response.map(IncomingOr::Right) -} /// build http response with status code of 4xx and 5xx pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult>> { diff --git a/rpxy-lib/src/message_handler/utils_request.rs b/rpxy-lib/src/message_handler/utils_request.rs index 37d5e0b..aa4ce42 100644 --- a/rpxy-lib/src/message_handler/utils_request.rs +++ b/rpxy-lib/src/message_handler/utils_request.rs @@ -57,11 +57,11 @@ pub(super) fn apply_upstream_options_to_request_line( ) -> anyhow::Result<()> { for opt in upstream.options.iter() { match opt { - UpstreamOption::ForceHttp11Upstream => *req.version_mut() = hyper::Version::HTTP_11, + UpstreamOption::ForceHttp11Upstream => *req.version_mut() = http::Version::HTTP_11, UpstreamOption::ForceHttp2Upstream => { // case: h2c -> https://www.rfc-editor.org/rfc/rfc9113.txt // Upgrade from HTTP/1.1 to HTTP/2 is deprecated. So, http-2 prior knowledge is required. - *req.version_mut() = hyper::Version::HTTP_2; + *req.version_mut() = http::Version::HTTP_2; } _ => (), } diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 922b857..5b77263 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -92,13 +92,9 @@ where } /// Serves a request stream from a client - /// TODO: TODO: TODO: TODO: - /// TODO: Body in hyper-0.14 was changed to Incoming in hyper-1.0, and it is not accessible from outside. 
- /// Thus, we need to implement IncomingLike trait using channel. Also, the backend handler must feed the body in the form of + /// Body in hyper-0.14 was changed to Incoming in hyper-1.0, and it is not accessible from outside. + /// Thus, we needed to implement IncomingLike trait using channel. Also, the backend handler must feed the body in the form of /// Either as body. - /// Also, the downstream from the backend handler could be Incoming, but will be wrapped as Either as well due to H3. - /// Result, E> type includes E as HttpError to generate the status code and related Response. - /// Thus to handle synthetic error messages in BoxBody, the serve() function outputs Response, BoxBody>>>. async fn h3_serve_stream( &self, req: Request<()>, @@ -146,7 +142,7 @@ where Ok(()) as RpxyResult<()> }); - let mut new_req: Request> = Request::from_parts(req_parts, IncomingOr::Right(req_body)); + let new_req: Request> = Request::from_parts(req_parts, IncomingOr::Right(req_body)); // Response> wrapped by RpxyResult let res = self .message_handler diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index f8f04c0..28ad76e 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -26,7 +26,7 @@ use tokio::time::timeout; /// Wrapper function to handle request for HTTP/1.1 and HTTP/2 /// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler async fn serve_request( - mut req: Request, + req: Request, // handler: Arc>, handler: Arc>, client_addr: SocketAddr, From 0741990154d9980b6eae2da11036a3d106400e2d Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 28 Nov 2023 22:32:49 +0900 Subject: [PATCH 19/50] wip: fix sync --- rpxy-lib/src/forwarder/client.rs | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index 57538fa..576999b 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs 
@@ -35,7 +35,7 @@ pub struct Forwarder { impl ForwardRequest> for Forwarder where C: Send + Sync + Connect + Clone + 'static, - B1: Body + Send + Unpin + 'static, + B1: Body + Send + Sync + Unpin + 'static, ::Data: Send, ::Error: Into>, B2: Body, @@ -80,11 +80,7 @@ where let connector = HttpsConnector::new(); let executor = LocalExecutor::new(_globals.runtime_handle.clone()); - let inner_h2 = Client::builder(executor) - .http2_adaptive_window(true) - .http2_only(true) - .set_host(true) - .build::<_, B1>(connector); + let inner_h2 = Client::builder(executor).http2_only(true).build::<_, B1>(connector); // #[cfg(feature = "native-roots")] // let builder = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots(); From 48a84a77cba7081edcb81f0839cb135061916da0 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Wed, 29 Nov 2023 17:24:07 +0900 Subject: [PATCH 20/50] implement native-tls client --- rpxy-lib/Cargo.toml | 9 ++++-- rpxy-lib/src/error.rs | 2 ++ rpxy-lib/src/forwarder/client.rs | 32 ++++++++++++++----- rpxy-lib/src/forwarder/mod.rs | 4 +-- rpxy-lib/src/lib.rs | 2 +- rpxy-lib/src/message_handler/handler_main.rs | 12 +++---- .../handler_manipulate_messages.rs | 24 ++++---------- rpxy-lib/src/message_handler/utils_request.rs | 31 +++++++++++++----- rpxy-lib/src/proxy/proxy_h3.rs | 8 ++--- rpxy-lib/src/proxy/proxy_main.rs | 19 ++++++----- rpxy-lib/src/proxy/proxy_quic_quinn.rs | 6 ++-- rpxy-lib/src/proxy/proxy_quic_s2n.rs | 10 +++--- 12 files changed, 90 insertions(+), 69 deletions(-) diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 5df1060..dd21b39 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -45,22 +45,25 @@ async-trait = "0.1.74" anyhow = "1.0.75" thiserror = "1.0.50" -# http +# http for both server and client http = "1.0.0" http-body-util = "0.1.0" hyper = { version = "1.0.1", default-features = false } hyper-util = { version = "0.1.1", features = ["full"] } futures-util = { version = "0.3.29", default-features = false } 
futures-channel = { version = "0.3.29", default-features = false } + +# http client +hyper-tls = { version = "0.6.0", features = ["alpn"] } +tokio-native-tls = { version = "0.3.1" } # hyper-rustls = { version = "0.24.2", default-features = false, features = [ # "tokio-runtime", # "webpki-tokio", # "http1", # "http2", # ] } -hyper-tls = { version = "0.6.0", features = ["alpn"] } -# tls and cert management +# tls and cert management for server hot_reload = "0.1.4" rustls = { version = "0.21.9", default-features = false } tokio-rustls = { version = "0.24.1", features = ["early-data"] } diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 438e7bb..a19ca2c 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -79,6 +79,8 @@ pub enum RpxyError { FailedToCopyBidirectional(String), // Forwarder errors + #[error("Failed to build forwarder: {0}")] + FailedToBuildForwarder(String), #[error("Failed to fetch from upstream: {0}")] FailedToFetchFromUpstream(String), diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index 576999b..aa89749 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -43,6 +43,8 @@ where type Error = RpxyError; async fn request(&self, req: Request) -> Result>, Self::Error> { + // TODO: cache handling + self.request_directly(req).await } } @@ -64,6 +66,7 @@ where } } +/// Build forwarder with hyper-tls (native-tls) impl Forwarder, B1> where B1: Body + Send + Unpin + 'static, @@ -71,16 +74,29 @@ where ::Error: Into>, { /// Build forwarder - pub async fn new(_globals: &Arc) -> Self { + pub async fn try_new(_globals: &Arc) -> RpxyResult { // build hyper client with hyper-tls - // TODO: Frame size errorが取れない > H2 どうしようもない。。。。 hyper_rustlsのリリース待ち? 
- let connector = HttpsConnector::new(); - let executor = LocalExecutor::new(_globals.runtime_handle.clone().clone()); + let executor = LocalExecutor::new(_globals.runtime_handle.clone()); + + let try_build_connector = |alpns: &[&str]| { + hyper_tls::native_tls::TlsConnector::builder() + .request_alpns(alpns) + .build() + .map_err(|e| RpxyError::FailedToBuildForwarder(e.to_string())) + .map(|tls| { + let mut http = HttpConnector::new(); + http.enforce_http(false); + HttpsConnector::from((http, tls.into())) + }) + }; + + let connector = try_build_connector(&["h2", "http/1.1"])?; let inner = Client::builder(executor.clone()).build::<_, B1>(connector); - let connector = HttpsConnector::new(); - let executor = LocalExecutor::new(_globals.runtime_handle.clone()); - let inner_h2 = Client::builder(executor).http2_only(true).build::<_, B1>(connector); + let connector_h2 = try_build_connector(&["h2"])?; + let inner_h2 = Client::builder(executor.clone()) + .http2_only(true) + .build::<_, B1>(connector_h2); // #[cfg(feature = "native-roots")] // let builder = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots(); @@ -108,6 +124,6 @@ where // Self { inner, inner_h2, cache } // } // #[cfg(not(feature = "cache"))] - Self { inner, inner_h2 } + Ok(Self { inner, inner_h2 }) } } diff --git a/rpxy-lib/src/forwarder/mod.rs b/rpxy-lib/src/forwarder/mod.rs index 1cb67fd..13d37eb 100644 --- a/rpxy-lib/src/forwarder/mod.rs +++ b/rpxy-lib/src/forwarder/mod.rs @@ -1,8 +1,6 @@ mod client; use crate::hyper_ext::body::{IncomingLike, IncomingOr}; -use hyper_tls::HttpsConnector; -use hyper_util::client::legacy::connect::HttpConnector; -pub type Forwarder = client::Forwarder, IncomingOr>; +pub type Forwarder = client::Forwarder>; pub use client::ForwardRequest; diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index c78e8c8..da2cabc 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -91,7 +91,7 @@ where }); // 4. 
build message handler containing Arc-ed http_client and backends, and make it contained in Arc as well - let forwarder = Arc::new(Forwarder::new(&globals).await); + let forwarder = Arc::new(Forwarder::try_new(&globals).await?); let message_handler = Arc::new( HttpMessageHandlerBuilder::default() .globals(globals.clone()) diff --git a/rpxy-lib/src/message_handler/handler_main.rs b/rpxy-lib/src/message_handler/handler_main.rs index d94d2c9..a9fae01 100644 --- a/rpxy-lib/src/message_handler/handler_main.rs +++ b/rpxy-lib/src/message_handler/handler_main.rs @@ -17,7 +17,7 @@ use crate::{ }; use derive_builder::Builder; use http::{Request, Response, StatusCode}; -use hyper_util::rt::TokioIo; +use hyper_util::{client::legacy::connect::Connect, rt::TokioIo}; use std::{net::SocketAddr, sync::Arc}; use tokio::{io::copy_bidirectional, time::timeout}; @@ -34,19 +34,19 @@ pub(super) struct HandlerContext { #[derive(Clone, Builder)] /// HTTP message handler for requests from clients and responses from backend applications, /// responsible to manipulate and forward messages to upstream backends and downstream clients. -// pub struct HttpMessageHandler -pub struct HttpMessageHandler +pub struct HttpMessageHandler where + C: Send + Sync + Connect + Clone + 'static, U: CryptoSource + Clone, { - forwarder: Arc, + forwarder: Arc>, pub(super) globals: Arc, app_manager: Arc>, } -impl HttpMessageHandler +impl HttpMessageHandler where - // T: Connect + Clone + Sync + Send + 'static, + C: Send + Sync + Connect + Clone + 'static, U: CryptoSource + Clone, { /// Handle incoming request message from a client. 
diff --git a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs index a33e58f..46e572c 100644 --- a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs +++ b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs @@ -1,7 +1,4 @@ -use super::{ - handler_main::HandlerContext, utils_headers::*, utils_request::apply_upstream_options_to_request_line, - HttpMessageHandler, -}; +use super::{handler_main::HandlerContext, utils_headers::*, utils_request::update_request_line, HttpMessageHandler}; use crate::{ backend::{BackendApp, UpstreamCandidates}, constants::RESPONSE_HEADER_SERVER, @@ -9,11 +6,13 @@ use crate::{ CryptoSource, }; use anyhow::{anyhow, ensure, Result}; -use http::{header, uri::Scheme, HeaderValue, Request, Response, Uri, Version}; +use http::{header, HeaderValue, Request, Response, Uri}; +use hyper_util::client::legacy::connect::Connect; use std::net::SocketAddr; -impl HttpMessageHandler +impl HttpMessageHandler where + C: Send + Sync + Connect + Clone + 'static, U: CryptoSource + Clone, { //////////////////////////////////////////////////// @@ -177,18 +176,7 @@ where .insert(header::CONNECTION, HeaderValue::from_static("upgrade")); } - // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3 - if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { - // Change version to http/1.1 when destination scheme is http - debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); - *req.version_mut() = Version::HTTP_11; - } else if req.version() == Version::HTTP_3 { - // HTTP/3 is always https - debug!("HTTP/3 is currently unsupported for request to upstream."); - *req.version_mut() = Version::HTTP_2; - } - - apply_upstream_options_to_request_line(req, upstream_candidates)?; + update_request_line(req, upstream_chosen, upstream_candidates)?; Ok(context) } diff --git 
a/rpxy-lib/src/message_handler/utils_request.rs b/rpxy-lib/src/message_handler/utils_request.rs index aa4ce42..8939433 100644 --- a/rpxy-lib/src/message_handler/utils_request.rs +++ b/rpxy-lib/src/message_handler/utils_request.rs @@ -1,6 +1,9 @@ -use crate::backend::{UpstreamCandidates, UpstreamOption}; +use crate::{ + backend::{Upstream, UpstreamCandidates, UpstreamOption}, + log::*, +}; use anyhow::{anyhow, ensure, Result}; -use http::{header, Request}; +use http::{header, uri::Scheme, Request, Version}; /// Trait defining parser of hostname /// Inspect and extract hostname from either the request HOST header or request line @@ -50,18 +53,30 @@ impl InspectParseHost for Request { //////////////////////////////////////////////////// // Functions to manipulate request line -/// Apply upstream options in request line, specified in the configuration -pub(super) fn apply_upstream_options_to_request_line( +/// Update request line, e.g., version, and apply upstream options to request line, specified in the configuration +pub(super) fn update_request_line( req: &mut Request, - upstream: &UpstreamCandidates, + upstream_chosen: &Upstream, + upstream_candidates: &UpstreamCandidates, ) -> anyhow::Result<()> { - for opt in upstream.options.iter() { + // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3 + if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { + // Change version to http/1.1 when destination scheme is http + debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); + *req.version_mut() = Version::HTTP_11; + } else if req.version() == Version::HTTP_3 { + // HTTP/3 is always https + debug!("HTTP/3 is currently unsupported for request to upstream."); + *req.version_mut() = Version::HTTP_2; + } + + for opt in upstream_candidates.options.iter() { match opt { - UpstreamOption::ForceHttp11Upstream => *req.version_mut() = http::Version::HTTP_11, + UpstreamOption::ForceHttp11Upstream 
=> *req.version_mut() = Version::HTTP_11, UpstreamOption::ForceHttp2Upstream => { // case: h2c -> https://www.rfc-editor.org/rfc/rfc9113.txt // Upgrade from HTTP/1.1 to HTTP/2 is deprecated. So, http-2 prior knowledge is required. - *req.version_mut() = http::Version::HTTP_2; + *req.version_mut() = Version::HTTP_2; } _ => (), } diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 5b77263..813eaa8 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -9,6 +9,7 @@ use crate::{ use bytes::{Buf, Bytes}; use http::{Request, Response}; use http_body_util::BodyExt; +use hyper_util::client::legacy::connect::Connect; use std::{net::SocketAddr, time::Duration}; use tokio::time::timeout; @@ -17,12 +18,9 @@ use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestSt #[cfg(feature = "http3-s2n")] use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; -// use futures::Stream; -// use hyper_util::client::legacy::connect::Connect; - -impl Proxy +impl Proxy where - // T: Connect + Clone + Sync + Send + 'static, + T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { pub(super) async fn h3_serve_connection( diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index 28ad76e..9c8e3a5 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -19,23 +19,22 @@ use hyper::{ rt::{Read, Write}, service::service_fn, }; -use hyper_util::{rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; +use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; use std::{net::SocketAddr, sync::Arc, time::Duration}; use tokio::time::timeout; /// Wrapper function to handle request for HTTP/1.1 and HTTP/2 /// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler -async fn serve_request( +async fn 
serve_request( req: Request, - // handler: Arc>, - handler: Arc>, + handler: Arc>, client_addr: SocketAddr, listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, ) -> RpxyResult>> where - // T: Connect + Clone + Sync + Send + 'static, + T: Send + Sync + Connect + Clone, U: CryptoSource + Clone, { handler @@ -51,9 +50,9 @@ where #[derive(Clone)] /// Proxy main object responsible to serve requests received from clients at the given socket address. -pub(crate) struct Proxy +pub(crate) struct Proxy where - // T: Connect + Clone + Sync + Send + 'static, + T: Send + Sync + Connect + Clone + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { /// global context shared among async tasks @@ -65,12 +64,12 @@ where /// hyper connection builder serving http request pub connection_builder: Arc>, /// message handler serving incoming http request - pub message_handler: Arc>, + pub message_handler: Arc>, } -impl Proxy +impl Proxy where - // T: Connect + Clone + Sync + Send + 'static, + T: Send + Sync + Connect + Clone + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { /// Serves requests from clients diff --git a/rpxy-lib/src/proxy/proxy_quic_quinn.rs b/rpxy-lib/src/proxy/proxy_quic_quinn.rs index 8380f6e..9c4bf4e 100644 --- a/rpxy-lib/src/proxy/proxy_quic_quinn.rs +++ b/rpxy-lib/src/proxy/proxy_quic_quinn.rs @@ -6,14 +6,14 @@ use crate::{ log::*, name_exp::ByteName, }; -// use hyper_util::client::legacy::connect::Connect; +use hyper_util::client::legacy::connect::Connect; use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; use rustls::ServerConfig; use std::sync::Arc; -impl Proxy +impl Proxy where - // T: Connect + Clone + Sync + Send + 'static, + T: Send + Sync + Connect + Clone + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { diff --git a/rpxy-lib/src/proxy/proxy_quic_s2n.rs b/rpxy-lib/src/proxy/proxy_quic_s2n.rs 
index 3ab41d0..13a8802 100644 --- a/rpxy-lib/src/proxy/proxy_quic_s2n.rs +++ b/rpxy-lib/src/proxy/proxy_quic_s2n.rs @@ -1,18 +1,20 @@ use super::proxy_main::Proxy; use crate::{ + crypto::CryptoSource, crypto::{ServerCrypto, ServerCryptoBase}, error::*, log::*, name_exp::ByteName, }; +use anyhow::anyhow; use hot_reload::ReloaderReceiver; -use std::sync::Arc; -// use hyper_util::client::legacy::connect::Connect; +use hyper_util::client::legacy::connect::Connect; use s2n_quic::provider; +use std::sync::Arc; -impl Proxy +impl Proxy where - // T: Connect + Clone + Sync + Send + 'static, + T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { /// Start UDP proxy serving with HTTP/3 request for configured host names From a6f9fc7065c292d19deaf264cd171b77c819a3e9 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Wed, 29 Nov 2023 17:56:34 +0900 Subject: [PATCH 21/50] remove unnecessary deps --- rpxy-lib/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index dd21b39..67f217a 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -55,7 +55,6 @@ futures-channel = { version = "0.3.29", default-features = false } # http client hyper-tls = { version = "0.6.0", features = ["alpn"] } -tokio-native-tls = { version = "0.3.1" } # hyper-rustls = { version = "0.24.2", default-features = false, features = [ # "tokio-runtime", # "webpki-tokio", From deb4c2850e6bc8055bdbce90176fc5dc962844d7 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Thu, 30 Nov 2023 10:39:38 +0900 Subject: [PATCH 22/50] wip: set_reuse_addr for client --- rpxy-lib/src/forwarder/client.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index aa89749..c587be9 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -86,6 +86,7 @@ where .map(|tls| { let mut http = HttpConnector::new(); http.enforce_http(false); +
http.set_reuse_address(true); HttpsConnector::from((http, tls.into())) }) }; From 2a48c64ff41a10c6562efbcf627d4d79feb4a4aa Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Mon, 4 Dec 2023 18:08:14 +0900 Subject: [PATCH 23/50] deps --- rpxy-bin/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index 00c2ef4..50fb549 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -39,7 +39,7 @@ rustls-pemfile = "1.0.4" mimalloc = { version = "*", default-features = false } # config -clap = { version = "4.4.8", features = ["std", "cargo", "wrap_help"] } +clap = { version = "4.4.10", features = ["std", "cargo", "wrap_help"] } toml = { version = "0.8.8", default-features = false, features = ["parse"] } hot_reload = "0.1.4" From f58ce97f1a29b99f920e2e549ad7ec9960ecdd97 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Mon, 4 Dec 2023 23:30:18 +0900 Subject: [PATCH 24/50] chore: empty feature rustls --- rpxy-bin/Cargo.toml | 1 + rpxy-lib/Cargo.toml | 8 ++-- rpxy-lib/src/forwarder/client.rs | 66 ++++++++++++++++++++++++++------ 3 files changed, 61 insertions(+), 14 deletions(-) diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index 50fb549..1512b7d 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -21,6 +21,7 @@ native-roots = ["rpxy-lib/native-roots"] [dependencies] rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [ "sticky-cookie", + "native-tls-backend", ] } anyhow = "1.0.75" diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 67f217a..7c6cf24 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -12,7 +12,7 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-quinn", "sticky-cookie", "cache"] +default = ["http3-quinn", "sticky-cookie", "cache", "native-tls-backend"] http3-quinn = ["socket2", "quinn", "h3", "h3-quinn"] http3-s2n = [ "h3", @@ -22,6 +22,8 
@@ http3-s2n = [ "s2n-quic-h3", ] sticky-cookie = ["base64", "sha2", "chrono"] +native-tls-backend = ["hyper-tls"] +rustls-backend = [] cache = [] #"http-cache-semantics", "lru"] native-roots = [] #"hyper-rustls/native-tokio"] @@ -53,8 +55,8 @@ hyper-util = { version = "0.1.1", features = ["full"] } futures-util = { version = "0.3.29", default-features = false } futures-channel = { version = "0.3.29", default-features = false } -# http client -hyper-tls = { version = "0.6.0", features = ["alpn"] } +# http client for upstream +hyper-tls = { version = "0.6.0", features = ["alpn"], optional = true } # hyper-rustls = { version = "0.24.2", default-features = false, features = [ # "tokio-runtime", # "webpki-tokio", diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index c587be9..3d0f995 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -5,11 +5,11 @@ use crate::{ body::{wrap_incoming_body_response, IncomingOr}, rt::LocalExecutor, }, + log::*, }; use async_trait::async_trait; use http::{Request, Response, Version}; use hyper::body::Body; -use hyper_tls::HttpsConnector; use hyper_util::client::legacy::{ connect::{Connect, HttpConnector}, Client, @@ -66,8 +66,38 @@ where } } +#[cfg(not(any(feature = "native-tls-backend", feature = "rustls-backend")))] +impl Forwarder +where + B: Body + Send + Unpin + 'static, + ::Data: Send, + ::Error: Into>, +{ + /// Build inner client with http + pub fn try_new(_globals: &Arc) -> RpxyResult { + warn!( + " +-------------------------------------------------------------------------------------------------- +Request forwarder is working without TLS support!!! +We recommend to use this just for testing. +Please enable native-tls-backend or rustls-backend feature to enable TLS support. 
+--------------------------------------------------------------------------------------------------" + ); + let executor = LocalExecutor::new(_globals.runtime_handle.clone()); + let mut http = HttpConnector::new(); + http.set_reuse_address(true); + let inner = Client::builder(executor).build::<_, B>(http); + + Ok(Self { + inner, + inner_h2: inner.clone(), + }) + } +} + +#[cfg(feature = "native-tls-backend")] /// Build forwarder with hyper-tls (native-tls) -impl Forwarder, B1> +impl Forwarder, B1> where B1: Body + Send + Unpin + 'static, ::Data: Send, @@ -76,6 +106,7 @@ where /// Build forwarder pub async fn try_new(_globals: &Arc) -> RpxyResult { // build hyper client with hyper-tls + info!("Native TLS support is enabled for the connection to backend applications"); let executor = LocalExecutor::new(_globals.runtime_handle.clone()); let try_build_connector = |alpns: &[&str]| { @@ -87,7 +118,7 @@ where let mut http = HttpConnector::new(); http.enforce_http(false); http.set_reuse_address(true); - HttpsConnector::from((http, tls.into())) + hyper_tls::HttpsConnector::from((http, tls.into())) }) }; @@ -99,6 +130,27 @@ where .http2_only(true) .build::<_, B1>(connector_h2); + // #[cfg(feature = "cache")] + // { + // let cache = RpxyCache::new(_globals).await; + // Self { inner, inner_h2, cache } + // } + // #[cfg(not(feature = "cache"))] + Ok(Self { inner, inner_h2 }) + } +} + +#[cfg(feature = "rustls-backend")] +/// Build forwarder with hyper-rustls (rustls) +impl Forwarder, B1> +where + B1: Body + Send + Unpin + 'static, + ::Data: Send, + ::Error: Into>, +{ + /// Build forwarder + pub async fn try_new(_globals: &Arc) -> RpxyResult { + todo!("Not implemented yet. 
Please use native-tls-backend feature for now."); // #[cfg(feature = "native-roots")] // let builder = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots(); // #[cfg(feature = "native-roots")] @@ -118,13 +170,5 @@ where // let inner = Client::builder().build::<_, Body>(connector); // let inner_h2 = Client::builder().http2_only(true).build::<_, Body>(connector_h2); - - // #[cfg(feature = "cache")] - // { - // let cache = RpxyCache::new(_globals).await; - // Self { inner, inner_h2, cache } - // } - // #[cfg(not(feature = "cache"))] - Ok(Self { inner, inner_h2 }) } } From 4aa149a2611466e63ec8da257075652c667180cc Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 8 Dec 2023 18:13:09 +0900 Subject: [PATCH 25/50] deps except for rustls family --- rpxy-bin/Cargo.toml | 2 +- rpxy-lib/Cargo.toml | 6 +++--- submodules/s2n-quic-h3/Cargo.toml | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index 1512b7d..a6e5720 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -40,7 +40,7 @@ rustls-pemfile = "1.0.4" mimalloc = { version = "*", default-features = false } # config -clap = { version = "4.4.10", features = ["std", "cargo", "wrap_help"] } +clap = { version = "4.4.11", features = ["std", "cargo", "wrap_help"] } toml = { version = "0.8.8", default-features = false, features = ["parse"] } hot_reload = "0.1.4" diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 7c6cf24..92449ab 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -78,12 +78,12 @@ tracing = { version = "0.1.40" } quinn = { version = "0.10.2", optional = true } h3 = { path = "../submodules/h3/h3/", optional = true } h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } -s2n-quic = { version = "1.31.0", default-features = false, features = [ +s2n-quic = { version = "1.32.0", default-features = false, features = [ "provider-tls-rustls", ], optional = true } -s2n-quic-core = { version = "0.31.0", 
default-features = false, optional = true } +s2n-quic-core = { version = "0.32.0", default-features = false, optional = true } s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } -s2n-quic-rustls = { version = "0.31.0", optional = true } +s2n-quic-rustls = { version = "0.32.0", optional = true } # for UDP socket wit SO_REUSEADDR when h3 with quinn socket2 = { version = "0.5.5", features = ["all"], optional = true } diff --git a/submodules/s2n-quic-h3/Cargo.toml b/submodules/s2n-quic-h3/Cargo.toml index fecfd10..e863ca5 100644 --- a/submodules/s2n-quic-h3/Cargo.toml +++ b/submodules/s2n-quic-h3/Cargo.toml @@ -13,5 +13,5 @@ publish = false bytes = { version = "1", default-features = false } futures = { version = "0.3", default-features = false } h3 = { path = "../h3/h3/" } -s2n-quic = "1.31.0" -s2n-quic-core = "0.31.0" +s2n-quic = "1.32.0" +s2n-quic-core = "0.32.0" From f7142828ac05c3b3113c32f9a3563458405470e6 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 9 Dec 2023 00:17:59 +0900 Subject: [PATCH 26/50] chore: prioritize http3-quinn over http3-s2n when both features are enabled (avoid compile error) --- rpxy-bin/src/main.rs | 7 ++----- rpxy-lib/src/crypto/service.rs | 8 ++++---- rpxy-lib/src/error.rs | 6 +++--- rpxy-lib/src/lib.rs | 6 +++--- rpxy-lib/src/proxy/mod.rs | 2 +- rpxy-lib/src/proxy/proxy_h3.rs | 2 +- 6 files changed, 14 insertions(+), 17 deletions(-) diff --git a/rpxy-bin/src/main.rs b/rpxy-bin/src/main.rs index f04a6f1..9aeb971 100644 --- a/rpxy-bin/src/main.rs +++ b/rpxy-bin/src/main.rs @@ -15,9 +15,6 @@ use crate::{ use hot_reload::{ReloaderReceiver, ReloaderService}; use rpxy_lib::entrypoint; -#[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] -compile_error!("feature \"http3-quinn\" and feature \"http3-s2n\" cannot be enabled at the same time"); - fn main() { init_logger(); @@ -29,8 +26,8 @@ fn main() { runtime.block_on(async { // Initially load options let Ok(parsed_opts) = parse_opts() else { - error!("Invalid 
toml file"); - std::process::exit(1); + error!("Invalid toml file"); + std::process::exit(1); }; if !parsed_opts.watch { diff --git a/rpxy-lib/src/crypto/service.rs b/rpxy-lib/src/crypto/service.rs index 0736b0e..8eda27a 100644 --- a/rpxy-lib/src/crypto/service.rs +++ b/rpxy-lib/src/crypto/service.rs @@ -22,7 +22,7 @@ pub struct ServerCrypto { // For Quic/HTTP3, only servers with no client authentication #[cfg(feature = "http3-quinn")] pub inner_global_no_client_auth: Arc, - #[cfg(feature = "http3-s2n")] + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] pub inner_global_no_client_auth: s2n_quic_rustls::Server, // For TLS over TCP/HTTP2 and 1.1, map of SNI to server_crypto for all given servers pub inner_local_map: Arc, @@ -74,7 +74,7 @@ impl TryInto> for &ServerCryptoBase { Ok(Arc::new(ServerCrypto { #[cfg(feature = "http3-quinn")] inner_global_no_client_auth: Arc::new(server_crypto_global), - #[cfg(feature = "http3-s2n")] + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] inner_global_no_client_auth: server_crypto_global, inner_local_map: Arc::new(server_crypto_local_map), })) @@ -200,7 +200,7 @@ impl ServerCryptoBase { Ok(server_crypto_global) } - #[cfg(feature = "http3-s2n")] + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] fn build_server_crypto_global(&self) -> Result> { let mut resolver_global = s2n_quic_rustls::rustls::server::ResolvesServerCertUsingSni::new(); @@ -241,7 +241,7 @@ impl ServerCryptoBase { } } -#[cfg(feature = "http3-s2n")] +#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] /// This is workaround for the version difference between rustls and s2n-quic-rustls fn parse_server_certs_and_keys_s2n( certs_and_keys: &CertsAndKeys, diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index a19ca2c..843845d 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -44,13 +44,13 @@ pub enum RpxyError { #[error("Quinn connection error: {0}")] QuinnConnectionFailed(#[from] 
quinn::ConnectionError), - #[cfg(feature = "http3-s2n")] + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] #[error("s2n-quic validation error: {0}")] S2nQuicValidationError(#[from] s2n_quic_core::transport::parameters::ValidationError), - #[cfg(feature = "http3-s2n")] + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] #[error("s2n-quic connection error: {0}")] S2nQuicConnectionError(#[from] s2n_quic_core::connection::Error), - #[cfg(feature = "http3-s2n")] + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] #[error("s2n-quic start error: {0}")] S2nQuicStartError(#[from] s2n_quic::provider::StartError), diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index da2cabc..4f66634 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -27,9 +27,6 @@ pub mod reexports { pub use rustls::{Certificate, PrivateKey}; } -#[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] -compile_error!("feature \"http3-quinn\" and feature \"http3-s2n\" cannot be enabled at the same time"); - /// Entrypoint that creates and spawns tasks of reverse proxy services pub async fn entrypoint( proxy_config: &ProxyConfig, @@ -40,6 +37,9 @@ pub async fn entrypoint( where T: CryptoSource + Clone + Send + Sync + 'static, { + #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] + warn!("Both \"http3-quinn\" and \"http3-s2n\" features are enabled. 
\"http3-quinn\" will be used"); + // For initial message logging if proxy_config.listen_sockets.iter().any(|addr| addr.is_ipv6()) { info!("Listen both IPv4 and IPv6") diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index e4ac6f7..389df43 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -2,7 +2,7 @@ mod proxy_h3; mod proxy_main; #[cfg(feature = "http3-quinn")] mod proxy_quic_quinn; -#[cfg(feature = "http3-s2n")] +#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] mod proxy_quic_s2n; mod socket; diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 813eaa8..8abb710 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -15,7 +15,7 @@ use tokio::time::timeout; #[cfg(feature = "http3-quinn")] use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; -#[cfg(feature = "http3-s2n")] +#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; impl Proxy From 6030bebac5a2dcf004d401f717871c2ba1824c95 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 9 Dec 2023 00:41:32 +0900 Subject: [PATCH 27/50] chore: prioritize rustls-backend while it is not supported (non-default) --- rpxy-lib/src/forwarder/client.rs | 2 +- rpxy-lib/src/lib.rs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index 3d0f995..22c2320 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -95,7 +95,7 @@ Please enable native-tls-backend or rustls-backend feature to enable TLS support } } -#[cfg(feature = "native-tls-backend")] +#[cfg(all(feature = "native-tls-backend", not(feature = "rustls-backend")))] /// Build forwarder with hyper-tls (native-tls) impl Forwarder, B1> where diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs 
index 4f66634..6336e19 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -40,6 +40,9 @@ where #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] warn!("Both \"http3-quinn\" and \"http3-s2n\" features are enabled. \"http3-quinn\" will be used"); + #[cfg(all(feature = "native-tls-backend", feature = "rustls-backend"))] + warn!("Both \"native-tls-backend\" and \"rustls-backend\" features are enabled. \"rustls-backend\" will be used"); + // For initial message logging if proxy_config.listen_sockets.iter().any(|addr| addr.is_ipv6()) { info!("Listen both IPv4 and IPv6") From f5197d08692e59795f5ecf2f00bc90c109ce60c7 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 9 Dec 2023 03:34:00 +0900 Subject: [PATCH 28/50] wip: refactoring the cache logic --- rpxy-lib/Cargo.toml | 8 +- rpxy-lib/src/error.rs | 9 ++ rpxy-lib/src/forwarder/cache.rs | 161 ++++++++++++++++++++++++++ rpxy-lib/src/forwarder/client.rs | 102 +++++++++++++--- rpxy-lib/src/forwarder/mod.rs | 1 + submodules/rusty-http-cache-semantics | 2 +- 6 files changed, 261 insertions(+), 22 deletions(-) create mode 100644 rpxy-lib/src/forwarder/cache.rs diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 92449ab..fe715e0 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -24,7 +24,7 @@ http3-s2n = [ sticky-cookie = ["base64", "sha2", "chrono"] native-tls-backend = ["hyper-tls"] rustls-backend = [] -cache = [] #"http-cache-semantics", "lru"] +cache = ["http-cache-semantics", "lru"] native-roots = [] #"hyper-rustls/native-tokio"] [dependencies] @@ -66,7 +66,7 @@ hyper-tls = { version = "0.6.0", features = ["alpn"], optional = true } # tls and cert management for server hot_reload = "0.1.4" -rustls = { version = "0.21.9", default-features = false } +rustls = { version = "0.21.10", default-features = false } tokio-rustls = { version = "0.24.1", features = ["early-data"] } webpki = "0.22.4" x509-parser = "0.15.1" @@ -88,8 +88,8 @@ s2n-quic-rustls = { version = "0.32.0", optional = 
true } socket2 = { version = "0.5.5", features = ["all"], optional = true } # # cache -# http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } -# lru = { version = "0.12.1", optional = true } +http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } +lru = { version = "0.12.1", optional = true } # cookie handling for sticky cookie chrono = { version = "0.4.31", default-features = false, features = [ diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 843845d..4cbc463 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -84,6 +84,15 @@ pub enum RpxyError { #[error("Failed to fetch from upstream: {0}")] FailedToFetchFromUpstream(String), + // Cache errors, + #[cfg(feature = "cache")] + #[error("Invalid null request and/or response")] + NullRequestOrResponse, + + #[cfg(feature = "cache")] + #[error("Failed to write byte buffer")] + FailedToWriteByteBufferForCache, + // Upstream connection setting errors #[error("Unsupported upstream option")] UnsupportedUpstreamOption, diff --git a/rpxy-lib/src/forwarder/cache.rs b/rpxy-lib/src/forwarder/cache.rs new file mode 100644 index 0000000..73bed7b --- /dev/null +++ b/rpxy-lib/src/forwarder/cache.rs @@ -0,0 +1,161 @@ +use crate::{error::*, globals::Globals, log::*}; +use http::{Request, Response}; +use http_cache_semantics::CachePolicy; +use lru::LruCache; +use std::{ + path::{Path, PathBuf}, + sync::{atomic::AtomicUsize, Arc, Mutex}, +}; +use tokio::{fs, sync::RwLock}; + +/* ---------------------------------------------- */ +#[derive(Clone, Debug)] +pub struct RpxyCache { + /// Lru cache storing http message caching policy + inner: LruCacheManager, + /// Managing cache file objects through RwLock's lock mechanism for file lock + file_store: FileStore, + /// Async runtime + runtime_handle: tokio::runtime::Handle, + /// Maximum size of each cache file object + max_each_size: usize, + /// Maximum size of cache object on 
memory + max_each_size_on_memory: usize, +} + +impl RpxyCache { + /// Generate cache storage + pub async fn new(globals: &Globals) -> Option { + if !globals.proxy_config.cache_enabled { + return None; + } + let path = globals.proxy_config.cache_dir.as_ref().unwrap(); + let file_store = FileStore::new(path, &globals.runtime_handle).await; + let inner = LruCacheManager::new(globals.proxy_config.cache_max_entry); + + let max_each_size = globals.proxy_config.cache_max_each_size; + let mut max_each_size_on_memory = globals.proxy_config.cache_max_each_size_on_memory; + if max_each_size < max_each_size_on_memory { + warn!( + "Maximum size of on memory cache per entry must be smaller than or equal to the maximum of each file cache" + ); + max_each_size_on_memory = max_each_size; + } + + Some(Self { + file_store, + inner, + runtime_handle: globals.runtime_handle.clone(), + max_each_size, + max_each_size_on_memory, + }) + } +} + +/* ---------------------------------------------- */ +#[derive(Debug, Clone)] +/// Cache file manager outer that is responsible to handle `RwLock` +struct FileStore { + inner: Arc>, +} +impl FileStore { + /// Build manager + async fn new(path: impl AsRef, runtime_handle: &tokio::runtime::Handle) -> Self { + Self { + inner: Arc::new(RwLock::new(FileStoreInner::new(path, runtime_handle).await)), + } + } +} + +#[derive(Debug)] +/// Manager inner for cache on file system +struct FileStoreInner { + /// Directory of temporary files + cache_dir: PathBuf, + /// Counter of current cached files + cnt: usize, + /// Async runtime + runtime_handle: tokio::runtime::Handle, +} + +impl FileStoreInner { + /// Build new cache file manager. + /// This first creates cache file dir if not exists, and cleans up the file inside the directory. + /// TODO: Persistent cache is really difficult. `sqlite` or something like that is needed. 
+ async fn new(path: impl AsRef, runtime_handle: &tokio::runtime::Handle) -> Self { + let path_buf = path.as_ref().to_path_buf(); + if let Err(e) = fs::remove_dir_all(path).await { + warn!("Failed to clean up the cache dir: {e}"); + }; + fs::create_dir_all(&path_buf).await.unwrap(); + Self { + cache_dir: path_buf.clone(), + cnt: 0, + runtime_handle: runtime_handle.clone(), + } + } +} + +/* ---------------------------------------------- */ + +#[derive(Clone, Debug)] +/// Cache target in hybrid manner of on-memory and file system +pub enum CacheFileOrOnMemory { + /// Pointer to the temporary cache file + File(PathBuf), + /// Cached body itself + OnMemory(Vec), +} + +#[derive(Clone, Debug)] +/// Cache object definition +struct CacheObject { + /// Cache policy to determine if the stored cache can be used as a response to a new incoming request + pub policy: CachePolicy, + /// Cache target: on-memory object or temporary file + pub target: CacheFileOrOnMemory, + /// SHA256 hash of target to strongly bind the cache metadata (this object) and file target + pub hash: Vec, +} + +/* ---------------------------------------------- */ +#[derive(Debug, Clone)] +/// Lru cache manager that is responsible to handle `Mutex` as an outer of `LruCache` +struct LruCacheManager { + inner: Arc>>, // TODO: keyはstring urlでいいのか疑問。全requestに対してcheckすることになりそう + cnt: Arc, +} + +impl LruCacheManager { + /// Build LruCache + fn new(cache_max_entry: usize) -> Self { + Self { + inner: Arc::new(Mutex::new(LruCache::new( + std::num::NonZeroUsize::new(cache_max_entry).unwrap(), + ))), + cnt: Arc::new(AtomicUsize::default()), + } + } +} + +/* ---------------------------------------------- */ +pub fn get_policy_if_cacheable( + req: Option<&Request>, + res: Option<&Response>, +) -> RpxyResult> +// where +// B1: core::fmt::Debug, +{ + // deduce cache policy from req and res + let (Some(req), Some(res)) = (req, res) else { + return Err(RpxyError::NullRequestOrResponse); + }; + + let new_policy = 
CachePolicy::new(req, res); + if new_policy.is_storable() { + // debug!("Response is cacheable: {:?}\n{:?}", req, res.headers()); + Ok(Some(new_policy)) + } else { + Ok(None) + } +} diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index 22c2320..820523e 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -9,13 +9,20 @@ use crate::{ }; use async_trait::async_trait; use http::{Request, Response, Version}; -use hyper::body::Body; +use hyper::body::{Body, Incoming}; use hyper_util::client::legacy::{ connect::{Connect, HttpConnector}, Client, }; use std::sync::Arc; +#[cfg(feature = "cache")] +use super::cache::{get_policy_if_cacheable, RpxyCache}; +#[cfg(feature = "cache")] +use crate::hyper_ext::body::{full, BoxBody}; +#[cfg(feature = "cache")] +use http_body_util::BodyExt; + #[async_trait] /// Definition of the forwarder that simply forward requests from downstream client to upstream app servers. pub trait ForwardRequest { @@ -25,27 +32,71 @@ pub trait ForwardRequest { /// Forwarder http client struct responsible to cache handling pub struct Forwarder { - // #[cfg(feature = "cache")] - // cache: Option, + #[cfg(feature = "cache")] + cache: Option, inner: Client, inner_h2: Client, // `h2c` or http/2-only client is defined separately } #[async_trait] -impl ForwardRequest> for Forwarder +impl ForwardRequest> for Forwarder where C: Send + Sync + Connect + Clone + 'static, B1: Body + Send + Sync + Unpin + 'static, ::Data: Send, ::Error: Into>, - B2: Body, { type Error = RpxyError; - async fn request(&self, req: Request) -> Result>, Self::Error> { + async fn request(&self, req: Request) -> Result>, Self::Error> { // TODO: cache handling + #[cfg(feature = "cache")] + { + let mut synth_req = None; + if self.cache.is_some() { + // if let Some(cached_response) = self.cache.as_ref().unwrap().get(&req).await { + // // if found, return it as response. 
+ // info!("Cache hit - Return from cache"); + // return Ok(cached_response); + // }; - self.request_directly(req).await + // Synthetic request copy used just for caching (cannot clone request object...) + synth_req = Some(build_synth_req_for_cache(&req)); + } + let res = self.request_directly(req).await; + + if self.cache.is_none() { + return res.map(wrap_incoming_body_response::); + } + + // check cacheability and store it if cacheable + let Ok(Some(cache_policy)) = get_policy_if_cacheable(synth_req.as_ref(), res.as_ref().ok()) else { + return res.map(wrap_incoming_body_response::); + }; + let (parts, body) = res.unwrap().into_parts(); + let Ok(bytes) = body.collect().await.map(|v| v.to_bytes()) else { + return Err(RpxyError::FailedToWriteByteBufferForCache); + }; + + // if let Err(cache_err) = self + // .cache + // .as_ref() + // .unwrap() + // .put(synth_req.unwrap().uri(), &bytes, &cache_policy) + // .await + // { + // error!("{:?}", cache_err); + // }; + + // response with cached body + Ok(Response::from_parts(parts, IncomingOr::Right(full(bytes)))) + } + + // No cache handling + #[cfg(not(feature = "cache"))] + { + self.request_directly(req).await.map(wrap_incoming_body_response::) + } } } @@ -56,13 +107,15 @@ where ::Data: Send, ::Error: Into>, { - async fn request_directly(&self, req: Request) -> RpxyResult>> { + async fn request_directly(&self, req: Request) -> RpxyResult> { + // TODO: This 'match' condition is always evaluated at every 'request' invocation. So, it is inefficient. + // Needs to be reconsidered. Currently, this is a kind of work around. + // This possibly relates to https://github.com/hyperium/hyper/issues/2417. 
match req.version() { Version::HTTP_2 => self.inner_h2.request(req).await, // handles `h2c` requests _ => self.inner.request(req).await, } .map_err(|e| RpxyError::FailedToFetchFromUpstream(e.to_string())) - .map(wrap_incoming_body_response::) } } @@ -90,7 +143,9 @@ Please enable native-tls-backend or rustls-backend feature to enable TLS support Ok(Self { inner, - inner_h2: inner.clone(), + inner_h2, + #[cfg(feature = "cache")] + cache: RpxyCache::new(_globals).await, }) } } @@ -130,13 +185,12 @@ where .http2_only(true) .build::<_, B1>(connector_h2); - // #[cfg(feature = "cache")] - // { - // let cache = RpxyCache::new(_globals).await; - // Self { inner, inner_h2, cache } - // } - // #[cfg(not(feature = "cache"))] - Ok(Self { inner, inner_h2 }) + Ok(Self { + inner, + inner_h2, + #[cfg(feature = "cache")] + cache: RpxyCache::new(_globals).await, + }) } } @@ -172,3 +226,17 @@ where // let inner_h2 = Client::builder().http2_only(true).build::<_, Body>(connector_h2); } } + +#[cfg(feature = "cache")] +/// Build synthetic request to cache +fn build_synth_req_for_cache(req: &Request) -> Request<()> { + let mut builder = Request::builder() + .method(req.method()) + .uri(req.uri()) + .version(req.version()); + // TODO: omit extensions. is this approach correct? 
+ for (header_key, header_value) in req.headers() { + builder = builder.header(header_key, header_value); + } + builder.body(()).unwrap() +} diff --git a/rpxy-lib/src/forwarder/mod.rs b/rpxy-lib/src/forwarder/mod.rs index 13d37eb..e901c7d 100644 --- a/rpxy-lib/src/forwarder/mod.rs +++ b/rpxy-lib/src/forwarder/mod.rs @@ -1,3 +1,4 @@ +mod cache; mod client; use crate::hyper_ext::body::{IncomingLike, IncomingOr}; diff --git a/submodules/rusty-http-cache-semantics b/submodules/rusty-http-cache-semantics index 3cd0917..88d23c2 160000 --- a/submodules/rusty-http-cache-semantics +++ b/submodules/rusty-http-cache-semantics @@ -1 +1 @@ -Subproject commit 3cd09170305753309d86e88b9427827cca0de0dd +Subproject commit 88d23c2f5a3ac36295dff4a804968c43932ba46b From cdcb1b13dacc88fb004d1166070f05f087434fed Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 9 Dec 2023 03:41:32 +0900 Subject: [PATCH 29/50] wip: chore: fix bug for unused --- rpxy-lib/src/forwarder/client.rs | 11 +++++++---- rpxy-lib/src/forwarder/mod.rs | 1 + rpxy-lib/src/hyper_ext/mod.rs | 1 + 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index 820523e..9aab75a 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -2,7 +2,7 @@ use crate::{ error::{RpxyError, RpxyResult}, globals::Globals, hyper_ext::{ - body::{wrap_incoming_body_response, IncomingOr}, + body::{wrap_incoming_body_response, BoxBody, IncomingOr}, rt::LocalExecutor, }, log::*, @@ -19,7 +19,7 @@ use std::sync::Arc; #[cfg(feature = "cache")] use super::cache::{get_policy_if_cacheable, RpxyCache}; #[cfg(feature = "cache")] -use crate::hyper_ext::body::{full, BoxBody}; +use crate::hyper_ext::body::{full, wrap_synthetic_body_response}; #[cfg(feature = "cache")] use http_body_util::BodyExt; @@ -89,13 +89,16 @@ where // }; // response with cached body - Ok(Response::from_parts(parts, IncomingOr::Right(full(bytes)))) + 
Ok(wrap_synthetic_body_response(Response::from_parts(parts, full(bytes)))) } // No cache handling #[cfg(not(feature = "cache"))] { - self.request_directly(req).await.map(wrap_incoming_body_response::) + self + .request_directly(req) + .await + .map(wrap_incoming_body_response::) } } } diff --git a/rpxy-lib/src/forwarder/mod.rs b/rpxy-lib/src/forwarder/mod.rs index e901c7d..286cb40 100644 --- a/rpxy-lib/src/forwarder/mod.rs +++ b/rpxy-lib/src/forwarder/mod.rs @@ -1,3 +1,4 @@ +#[cfg(feature = "cache")] mod cache; mod client; diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs index e1b5ae8..e6c81e7 100644 --- a/rpxy-lib/src/hyper_ext/mod.rs +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -8,6 +8,7 @@ pub(crate) mod rt { } pub(crate) mod body { pub(crate) use super::body_incoming_like::IncomingLike; + #[allow(unused)] pub(crate) use super::body_type::{ empty, full, wrap_incoming_body_response, wrap_synthetic_body_response, BoxBody, IncomingOr, }; From d473b44556ed051ae812aa81de79a3cad2872a92 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 9 Dec 2023 10:17:31 +0900 Subject: [PATCH 30/50] add comment --- rpxy-lib/src/forwarder/client.rs | 5 +++++ rpxy-lib/src/proxy/proxy_h3.rs | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index 9aab75a..5718f2e 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -74,10 +74,15 @@ where return res.map(wrap_incoming_body_response::); }; let (parts, body) = res.unwrap().into_parts(); + let Ok(bytes) = body.collect().await.map(|v| v.to_bytes()) else { return Err(RpxyError::FailedToWriteByteBufferForCache); }; + // TODO: this is inefficient. needs to be reconsidered to avoid unnecessary copy and should spawn async task to store cache. + // We may need to use the same logic as h3. + // Is bytes.clone() enough? 
+ // if let Err(cache_err) = self // .cache // .as_ref() diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 8abb710..d194b1f 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -122,7 +122,7 @@ where size += body.remaining(); if size > max_body_size { error!( - "Exceeds max request body size for HTTP/3: received {}, maximum_allowd {}", + "Exceeds max request body size for HTTP/3: received {}, maximum_allowed {}", size, max_body_size ); return Err(RpxyError::H3TooLargeBody); From ed33c5d4f119b26fb75b2ea64ccff22ed3bf0913 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 9 Dec 2023 12:14:59 +0900 Subject: [PATCH 31/50] wip: implement on-memory cache as is --- rpxy-lib/Cargo.toml | 6 ++--- rpxy-lib/src/error.rs | 4 ++++ rpxy-lib/src/forwarder/cache.rs | 42 ++++++++++++++++++++++++++++++--- 3 files changed, 46 insertions(+), 6 deletions(-) diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index fe715e0..65c7c58 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -24,7 +24,7 @@ http3-s2n = [ sticky-cookie = ["base64", "sha2", "chrono"] native-tls-backend = ["hyper-tls"] rustls-backend = [] -cache = ["http-cache-semantics", "lru"] +cache = ["http-cache-semantics", "lru", "sha2", "base64"] native-roots = [] #"hyper-rustls/native-tokio"] [dependencies] @@ -87,9 +87,10 @@ s2n-quic-rustls = { version = "0.32.0", optional = true } # for UDP socket wit SO_REUSEADDR when h3 with quinn socket2 = { version = "0.5.5", features = ["all"], optional = true } -# # cache +# cache http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } lru = { version = "0.12.1", optional = true } +sha2 = { version = "0.10.8", default-features = false, optional = true } # cookie handling for sticky cookie chrono = { version = "0.4.31", default-features = false, features = [ @@ -98,7 +99,6 @@ chrono = { version = "0.4.31", default-features = false, features = [ "clock", ], optional = 
true } base64 = { version = "0.21.5", optional = true } -sha2 = { version = "0.10.8", default-features = false, optional = true } [dev-dependencies] diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 4cbc463..2763d1e 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -93,6 +93,10 @@ pub enum RpxyError { #[error("Failed to write byte buffer")] FailedToWriteByteBufferForCache, + #[cfg(feature = "cache")] + #[error("Failed to acquire mutex lock for cache")] + FailedToAcquiredMutexLockForCache, + // Upstream connection setting errors #[error("Unsupported upstream option")] UnsupportedUpstreamOption, diff --git a/rpxy-lib/src/forwarder/cache.rs b/rpxy-lib/src/forwarder/cache.rs index 73bed7b..ea29a41 100644 --- a/rpxy-lib/src/forwarder/cache.rs +++ b/rpxy-lib/src/forwarder/cache.rs @@ -4,14 +4,18 @@ use http_cache_semantics::CachePolicy; use lru::LruCache; use std::{ path::{Path, PathBuf}, - sync::{atomic::AtomicUsize, Arc, Mutex}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, }; use tokio::{fs, sync::RwLock}; /* ---------------------------------------------- */ #[derive(Clone, Debug)] +/// Cache main manager pub struct RpxyCache { - /// Lru cache storing http message caching policy + /// Inner lru cache manager storing http message caching policy inner: LruCacheManager, /// Managing cache file objects through RwLock's lock mechanism for file lock file_store: FileStore, @@ -122,7 +126,9 @@ struct CacheObject { #[derive(Debug, Clone)] /// Lru cache manager that is responsible to handle `Mutex` as an outer of `LruCache` struct LruCacheManager { + /// Inner lru cache manager main object inner: Arc>>, // TODO: keyはstring urlでいいのか疑問。全requestに対してcheckすることになりそう + /// Counter of current cached object (total) cnt: Arc, } @@ -133,12 +139,42 @@ impl LruCacheManager { inner: Arc::new(Mutex::new(LruCache::new( std::num::NonZeroUsize::new(cache_max_entry).unwrap(), ))), - cnt: Arc::new(AtomicUsize::default()), + cnt: Default::default(), 
} } + + /// Count entries + fn count(&self) -> usize { + self.cnt.load(Ordering::Relaxed) + } + + /// Evict an entry + fn evict(&self, cache_key: &str) -> Option<(String, CacheObject)> { + let Ok(mut lock) = self.inner.lock() else { + error!("Mutex can't be locked to evict a cache entry"); + return None; + }; + let res = lock.pop_entry(cache_key); + // This may be inconsistent with the actual number of entries + self.cnt.store(lock.len(), Ordering::Relaxed); + res + } + + /// Push an entry + fn push(&self, cache_key: &str, cache_object: CacheObject) -> RpxyResult> { + let Ok(mut lock) = self.inner.lock() else { + error!("Failed to acquire mutex lock for writing cache entry"); + return Err(RpxyError::FailedToAcquiredMutexLockForCache); + }; + let res = Ok(lock.push(cache_key.to_string(), cache_object)); + // This may be inconsistent with the actual number of entries + self.cnt.store(lock.len(), Ordering::Relaxed); + res + } } /* ---------------------------------------------- */ +/// Generate cache policy if the response is cacheable pub fn get_policy_if_cacheable( req: Option<&Request>, res: Option<&Response>, From cc48394e7302076aac1794fc180374850a2de643 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Mon, 11 Dec 2023 16:52:54 +0900 Subject: [PATCH 32/50] wip: feat: update h3 response reader to use async stream --- rpxy-lib/src/error.rs | 12 ++++ rpxy-lib/src/forwarder/cache.rs | 104 ++++++++++++++++++++++++++++++- rpxy-lib/src/forwarder/client.rs | 3 + rpxy-lib/src/proxy/proxy_h3.rs | 34 +++++++--- 4 files changed, 145 insertions(+), 8 deletions(-) diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 2763d1e..343cf04 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -97,6 +97,18 @@ pub enum RpxyError { #[error("Failed to acquire mutex lock for cache")] FailedToAcquiredMutexLockForCache, + #[cfg(feature = "cache")] + #[error("Failed to create file cache")] + FailedToCreateFileCache, + + #[cfg(feature = "cache")] + #[error("Failed to 
write file cache")] + FailedToWriteFileCache, + + #[cfg(feature = "cache")] + #[error("Failed to open cache file")] + FailedToOpenCacheFile, + // Upstream connection setting errors #[error("Unsupported upstream option")] UnsupportedUpstreamOption, diff --git a/rpxy-lib/src/forwarder/cache.rs b/rpxy-lib/src/forwarder/cache.rs index ea29a41..03755e6 100644 --- a/rpxy-lib/src/forwarder/cache.rs +++ b/rpxy-lib/src/forwarder/cache.rs @@ -1,15 +1,22 @@ use crate::{error::*, globals::Globals, log::*}; +use bytes::{Buf, Bytes, BytesMut}; use http::{Request, Response}; +use http_body_util::StreamBody; use http_cache_semantics::CachePolicy; use lru::LruCache; use std::{ + convert::Infallible, path::{Path, PathBuf}, sync::{ atomic::{AtomicUsize, Ordering}, Arc, Mutex, }, }; -use tokio::{fs, sync::RwLock}; +use tokio::{ + fs::{self, File}, + io::{AsyncReadExt, AsyncWriteExt}, + sync::RwLock, +}; /* ---------------------------------------------- */ #[derive(Clone, Debug)] @@ -54,6 +61,14 @@ impl RpxyCache { max_each_size_on_memory, }) } + + /// Count cache entries + pub async fn count(&self) -> (usize, usize, usize) { + let total = self.inner.count(); + let file = self.file_store.count().await; + let on_memory = total - file; + (total, on_memory, file) + } } /* ---------------------------------------------- */ @@ -71,6 +86,32 @@ impl FileStore { } } +impl FileStore { + /// Count file cache entries + async fn count(&self) -> usize { + let inner = self.inner.read().await; + inner.cnt + } + /// Create a temporary file cache + async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> RpxyResult { + let mut inner = self.inner.write().await; + inner.create(cache_filename, body_bytes).await + } + // /// Evict a temporary file cache + // async fn evict(&self, path: impl AsRef) { + // // Acquire the write lock + // let mut inner = self.inner.write().await; + // if let Err(e) = inner.remove(path).await { + // warn!("Eviction failed during file object removal: {:?}", e); + 
// }; + // } + // /// Read a temporary file cache + // async fn read(&self, path: impl AsRef) -> RpxyResult { + // let inner = self.inner.read().await; + // inner.read(&path).await + // } +} + #[derive(Debug)] /// Manager inner for cache on file system struct FileStoreInner { @@ -98,6 +139,67 @@ impl FileStoreInner { runtime_handle: runtime_handle.clone(), } } + + /// Create a new temporary file cache + async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> RpxyResult { + let cache_filepath = self.cache_dir.join(cache_filename); + let Ok(mut file) = File::create(&cache_filepath).await else { + return Err(RpxyError::FailedToCreateFileCache); + }; + let mut bytes_clone = body_bytes.clone(); + while bytes_clone.has_remaining() { + if let Err(e) = file.write_buf(&mut bytes_clone).await { + error!("Failed to write file cache: {e}"); + return Err(RpxyError::FailedToWriteFileCache); + }; + } + self.cnt += 1; + Ok(CacheFileOrOnMemory::File(cache_filepath)) + } + + /// Retrieve a stored temporary file cache + async fn read(&self, path: impl AsRef) -> RpxyResult<()> { + let Ok(mut file) = File::open(&path).await else { + warn!("Cache file object cannot be opened"); + return Err(RpxyError::FailedToOpenCacheFile); + }; + + /* ----------------------------- */ + // PoC for streaming body + use futures::channel::mpsc; + let (tx, rx) = mpsc::unbounded::, Infallible>>(); + + // let (body_sender, res_body) = Body::channel(); + self.runtime_handle.spawn(async move { + // let mut sender = body_sender; + let mut buf = BytesMut::new(); + loop { + match file.read_buf(&mut buf).await { + Ok(0) => break, + Ok(_) => tx + .unbounded_send(Ok(hyper::body::Frame::data(buf.copy_to_bytes(buf.remaining())))) + .map_err(|e| anyhow::anyhow!("Failed to read cache file: {e}"))?, + //sender.send_data(buf.copy_to_bytes(buf.remaining())).await?, + Err(_) => break, + }; + } + Ok(()) as anyhow::Result<()> + }); + + let mut rx = http_body_util::StreamBody::new(rx); + // TODO: 
結局incominglikeなbodystreamを定義することになる。これだったらh3と合わせて自分で定義した方が良さそう。 + // typeが長すぎるのでwrapperを作った方がいい。 + // let response = Response::builder() + // .status(200) + // .header("content-type", "application/octet-stream") + // .body(rx) + // .unwrap(); + + todo!() + /* ----------------------------- */ + + // Ok(res_body) + } } /* ---------------------------------------------- */ diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index 5718f2e..c6c6218 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -75,9 +75,12 @@ where }; let (parts, body) = res.unwrap().into_parts(); + // TODO: This is inefficient since current strategy needs to copy the whole body onto memory to cache it. + // This should be handled by copying buffer simultaneously while forwarding response to downstream. let Ok(bytes) = body.collect().await.map(|v| v.to_bytes()) else { return Err(RpxyError::FailedToWriteByteBufferForCache); }; + let bytes_clone = bytes.clone(); // TODO: this is inefficient. needs to be reconsidered to avoid unnecessary copy and should spawn async task to store cache. // We may need to use the same logic as h3. 
diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index d194b1f..1846d67 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -153,20 +153,40 @@ where ) .await?; - let (new_res_parts, new_body) = res.into_parts(); + let (new_res_parts, mut new_body) = res.into_parts(); let new_res = Response::from_parts(new_res_parts, ()); match send_stream.send_response(new_res).await { Ok(_) => { debug!("HTTP/3 response to connection successful"); - // aggregate body without copying - let body_data = new_body - .collect() - .await + loop { + let frame = match new_body.frame().await { + Some(frame) => frame, + None => { + debug!("Response body finished"); + break; + } + } .map_err(|e| RpxyError::HyperBodyManipulationError(e.to_string()))?; - // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes inside to_bytes() - send_stream.send_data(body_data.to_bytes()).await?; + if frame.is_data() { + let data = frame.into_data().unwrap_or_default(); + debug!("Write data to HTTP/3 stream"); + send_stream.send_data(data).await?; + } else if frame.is_trailers() { + let trailers = frame.into_trailers().unwrap_or_default(); + debug!("Write trailer to HTTP/3 stream"); + send_stream.send_trailers(trailers).await?; + } + } + // // aggregate body without copying + // let body_data = new_body + // .collect() + // .await + // .map_err(|e| RpxyError::HyperBodyManipulationError(e.to_string()))?; + + // // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes inside to_bytes() + // send_stream.send_data(body_data.to_bytes()).await?; // TODO: needs handling trailer? should be included in body from handler. 
} From d526ce6cb478fd89c5b76759053ecb4a3799b682 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Mon, 11 Dec 2023 18:23:08 +0900 Subject: [PATCH 33/50] wip: refactor: reconsider timeouts of connections --- rpxy-lib/src/constants.rs | 4 ++-- rpxy-lib/src/forwarder/client.rs | 2 ++ rpxy-lib/src/globals.rs | 10 ++++---- rpxy-lib/src/message_handler/handler_main.rs | 15 ++++-------- rpxy-lib/src/message_handler/http_result.rs | 3 --- rpxy-lib/src/proxy/mod.rs | 2 ++ rpxy-lib/src/proxy/proxy_h3.rs | 13 ++++------ rpxy-lib/src/proxy/proxy_main.rs | 25 ++++++++++---------- 8 files changed, 35 insertions(+), 39 deletions(-) diff --git a/rpxy-lib/src/constants.rs b/rpxy-lib/src/constants.rs index ebec1fc..acc9381 100644 --- a/rpxy-lib/src/constants.rs +++ b/rpxy-lib/src/constants.rs @@ -4,8 +4,8 @@ pub const RESPONSE_HEADER_SERVER: &str = "rpxy"; pub const TCP_LISTEN_BACKLOG: u32 = 1024; // pub const HTTP_LISTEN_PORT: u16 = 8080; // pub const HTTPS_LISTEN_PORT: u16 = 8443; -pub const PROXY_TIMEOUT_SEC: u64 = 60; -pub const UPSTREAM_TIMEOUT_SEC: u64 = 60; +pub const PROXY_IDLE_TIMEOUT_SEC: u64 = 20; +pub const UPSTREAM_IDLE_TIMEOUT_SEC: u64 = 20; pub const TLS_HANDSHAKE_TIMEOUT_SEC: u64 = 15; // default as with firefox browser pub const MAX_CLIENTS: usize = 512; pub const MAX_CONCURRENT_STREAMS: u32 = 64; diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index c6c6218..8b86f9f 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -8,6 +8,7 @@ use crate::{ log::*, }; use async_trait::async_trait; +use chrono::Duration; use http::{Request, Response, Version}; use hyper::body::{Body, Incoming}; use hyper_util::client::legacy::{ @@ -184,6 +185,7 @@ where let mut http = HttpConnector::new(); http.enforce_http(false); http.set_reuse_address(true); + http.set_keepalive(Some(_globals.proxy_config.upstream_idle_timeout)); hyper_tls::HttpsConnector::from((http, tls.into())) }) }; diff --git a/rpxy-lib/src/globals.rs 
b/rpxy-lib/src/globals.rs index 86fdc46..9cd62b3 100644 --- a/rpxy-lib/src/globals.rs +++ b/rpxy-lib/src/globals.rs @@ -33,8 +33,10 @@ pub struct ProxyConfig { /// tcp listen backlog pub tcp_listen_backlog: u32, - pub proxy_timeout: Duration, // when serving requests at Proxy - pub upstream_timeout: Duration, // when serving requests at Handler + /// Idle timeout as an HTTP server, used as the keep alive interval and timeout for reading request header + pub proxy_idle_timeout: Duration, + /// Idle timeout as an HTTP client, used as the keep alive interval for upstream connections + pub upstream_idle_timeout: Duration, pub max_clients: usize, // when serving requests pub max_concurrent_streams: u32, // when instantiate server @@ -80,8 +82,8 @@ impl Default for ProxyConfig { tcp_listen_backlog: TCP_LISTEN_BACKLOG, // TODO: Reconsider each timeout values - proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC), - upstream_timeout: Duration::from_secs(UPSTREAM_TIMEOUT_SEC), + proxy_idle_timeout: Duration::from_secs(PROXY_IDLE_TIMEOUT_SEC), + upstream_idle_timeout: Duration::from_secs(UPSTREAM_IDLE_TIMEOUT_SEC), max_clients: MAX_CLIENTS, max_concurrent_streams: MAX_CONCURRENT_STREAMS, diff --git a/rpxy-lib/src/message_handler/handler_main.rs b/rpxy-lib/src/message_handler/handler_main.rs index a9fae01..251411b 100644 --- a/rpxy-lib/src/message_handler/handler_main.rs +++ b/rpxy-lib/src/message_handler/handler_main.rs @@ -19,7 +19,7 @@ use derive_builder::Builder; use http::{Request, Response, StatusCode}; use hyper_util::{client::legacy::connect::Connect, rt::TokioIo}; use std::{net::SocketAddr, sync::Arc}; -use tokio::{io::copy_bidirectional, time::timeout}; +use tokio::io::copy_bidirectional; #[allow(dead_code)] #[derive(Debug)] @@ -172,15 +172,10 @@ where ////////////// // Forward request to a chosen backend - let mut res_backend = { - let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else { - return 
Err(HttpError::TimeoutUpstreamRequest); - }; - match result { - Ok(res) => res, - Err(e) => { - return Err(HttpError::FailedToGetResponseFromBackend(e.to_string())); - } + let mut res_backend = match self.forwarder.request(req).await { + Ok(v) => v, + Err(e) => { + return Err(HttpError::FailedToGetResponseFromBackend(e.to_string())); } }; ////////////// diff --git a/rpxy-lib/src/message_handler/http_result.rs b/rpxy-lib/src/message_handler/http_result.rs index 857ab55..ec48200 100644 --- a/rpxy-lib/src/message_handler/http_result.rs +++ b/rpxy-lib/src/message_handler/http_result.rs @@ -22,8 +22,6 @@ pub enum HttpError { NoUpstreamCandidates, #[error("Failed to generate upstream request for backend application: {0}")] FailedToGenerateUpstreamRequest(String), - #[error("Timeout in upstream request")] - TimeoutUpstreamRequest, #[error("Failed to get response from backend: {0}")] FailedToGetResponseFromBackend(String), @@ -53,7 +51,6 @@ impl From for StatusCode { HttpError::FailedToRedirect(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::NoUpstreamCandidates => StatusCode::NOT_FOUND, HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR, - HttpError::TimeoutUpstreamRequest => StatusCode::GATEWAY_TIMEOUT, HttpError::FailedToAddSetCookeInResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToUpgrade(_) => StatusCode::INTERNAL_SERVER_ERROR, diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index 389df43..d1aa5c3 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -19,9 +19,11 @@ pub(crate) fn connection_builder(globals: &Arc) -> Arc| { serve_request( @@ -104,10 +102,9 @@ where tls_server_name.clone(), ) }), - ), - ) - .await - .ok(); + ) + .await + .ok(); request_count.decrement(); debug!("Request processed: current # {}", request_count.current()); @@ -201,8 +198,7 @@ where return 
Err(RpxyError::FailedToTlsHandshake(e.to_string())); } }; - self_inner.serve_connection(stream, client_addr, server_name); - Ok(()) as RpxyResult<()> + Ok((stream, client_addr, server_name)) }; self.globals.runtime_handle.spawn( async move { @@ -214,8 +210,13 @@ where error!("Timeout to handshake TLS"); return; }; - if let Err(e) = v { - error!("{}", e); + match v { + Ok((stream, client_addr, server_name)) => { + self_inner.serve_connection(stream, client_addr, server_name); + } + Err(e) => { + error!("{}", e); + } } }); } From b8f3034014231b8aab945f027a499c3773b653fc Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Mon, 11 Dec 2023 18:40:31 +0900 Subject: [PATCH 34/50] wip: fix keep alive timeouts --- rpxy-lib/Cargo.toml | 1 + rpxy-lib/src/hyper_ext/mod.rs | 2 ++ rpxy-lib/src/proxy/mod.rs | 7 ++++++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 65c7c58..f30f4bb 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -41,6 +41,7 @@ tokio = { version = "1.34.0", default-features = false, features = [ "macros", "fs", ] } +pin-project-lite = "0.2.13" async-trait = "0.1.74" # Error handling diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs index e6c81e7..cfa2b70 100644 --- a/rpxy-lib/src/hyper_ext/mod.rs +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -1,10 +1,12 @@ mod body_incoming_like; mod body_type; mod executor; +mod tokio_timer; mod watch; pub(crate) mod rt { pub(crate) use super::executor::LocalExecutor; + pub(crate) use super::tokio_timer::{TokioSleep, TokioTimer}; } pub(crate) mod body { pub(crate) use super::body_incoming_like::IncomingLike; diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index d1aa5c3..a7c1ec8 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -6,7 +6,10 @@ mod proxy_quic_quinn; mod proxy_quic_s2n; mod socket; -use crate::{globals::Globals, hyper_ext::rt::LocalExecutor}; +use crate::{ + globals::Globals, + 
hyper_ext::rt::{LocalExecutor, TokioTimer}, +}; use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; use std::sync::Arc; @@ -20,10 +23,12 @@ pub(crate) fn connection_builder(globals: &Arc) -> Arc Date: Mon, 11 Dec 2023 18:41:17 +0900 Subject: [PATCH 35/50] add tokio timer --- rpxy-lib/src/hyper_ext/mod.rs | 3 +- rpxy-lib/src/hyper_ext/tokio_timer.rs | 55 +++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 1 deletion(-) create mode 100644 rpxy-lib/src/hyper_ext/tokio_timer.rs diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs index cfa2b70..922776c 100644 --- a/rpxy-lib/src/hyper_ext/mod.rs +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -4,13 +4,14 @@ mod executor; mod tokio_timer; mod watch; +#[allow(unused)] pub(crate) mod rt { pub(crate) use super::executor::LocalExecutor; pub(crate) use super::tokio_timer::{TokioSleep, TokioTimer}; } +#[allow(unused)] pub(crate) mod body { pub(crate) use super::body_incoming_like::IncomingLike; - #[allow(unused)] pub(crate) use super::body_type::{ empty, full, wrap_incoming_body_response, wrap_synthetic_body_response, BoxBody, IncomingOr, }; diff --git a/rpxy-lib/src/hyper_ext/tokio_timer.rs b/rpxy-lib/src/hyper_ext/tokio_timer.rs new file mode 100644 index 0000000..53a1af7 --- /dev/null +++ b/rpxy-lib/src/hyper_ext/tokio_timer.rs @@ -0,0 +1,55 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use hyper::rt::{Sleep, Timer}; +use pin_project_lite::pin_project; + +#[derive(Clone, Debug)] +pub struct TokioTimer; + +impl Timer for TokioTimer { + fn sleep(&self, duration: Duration) -> Pin> { + Box::pin(TokioSleep { + inner: tokio::time::sleep(duration), + }) + } + + fn sleep_until(&self, deadline: Instant) -> Pin> { + Box::pin(TokioSleep { + inner: tokio::time::sleep_until(deadline.into()), + }) + } + + fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { + if let Some(sleep) = sleep.as_mut().downcast_mut_pin::() { + 
sleep.reset(new_deadline) + } + } +} + +pin_project! { + pub(crate) struct TokioSleep { + #[pin] + pub(crate) inner: tokio::time::Sleep, + } +} + +impl Future for TokioSleep { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().inner.poll(cx) + } +} + +impl Sleep for TokioSleep {} + +impl TokioSleep { + pub fn reset(self: Pin<&mut Self>, deadline: Instant) { + self.project().inner.as_mut().reset(deadline.into()); + } +} From 008b62a9256a9b71e0752ab4433e0fb43600e20e Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 12 Dec 2023 19:58:33 +0900 Subject: [PATCH 36/50] wip: feat: define response body enum --- rpxy-lib/src/error.rs | 30 +----- rpxy-lib/src/forwarder/cache/cache_error.rs | 35 +++++++ .../{cache.rs => cache/cache_main.rs} | 97 ++++++++++++++++--- rpxy-lib/src/forwarder/cache/mod.rs | 5 + rpxy-lib/src/forwarder/client.rs | 53 ++++------ rpxy-lib/src/forwarder/mod.rs | 7 +- rpxy-lib/src/hyper_ext/body_type.rs | 67 ++++++++++--- rpxy-lib/src/hyper_ext/mod.rs | 4 +- rpxy-lib/src/message_handler/handler_main.rs | 6 +- .../src/message_handler/synthetic_response.rs | 10 +- rpxy-lib/src/proxy/proxy_h3.rs | 1 - rpxy-lib/src/proxy/proxy_main.rs | 4 +- 12 files changed, 215 insertions(+), 104 deletions(-) create mode 100644 rpxy-lib/src/forwarder/cache/cache_error.rs rename rpxy-lib/src/forwarder/{cache.rs => cache/cache_main.rs} (75%) create mode 100644 rpxy-lib/src/forwarder/cache/mod.rs diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 343cf04..f63a06c 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -84,35 +84,15 @@ pub enum RpxyError { #[error("Failed to fetch from upstream: {0}")] FailedToFetchFromUpstream(String), - // Cache errors, - #[cfg(feature = "cache")] - #[error("Invalid null request and/or response")] - NullRequestOrResponse, - - #[cfg(feature = "cache")] - #[error("Failed to write byte buffer")] - FailedToWriteByteBufferForCache, - - #[cfg(feature = "cache")] - 
#[error("Failed to acquire mutex lock for cache")] - FailedToAcquiredMutexLockForCache, - - #[cfg(feature = "cache")] - #[error("Failed to create file cache")] - FailedToCreateFileCache, - - #[cfg(feature = "cache")] - #[error("Failed to write file cache")] - FailedToWriteFileCache, - - #[cfg(feature = "cache")] - #[error("Failed to open cache file")] - FailedToOpenCacheFile, - // Upstream connection setting errors #[error("Unsupported upstream option")] UnsupportedUpstreamOption, + // Cache error map + #[cfg(feature = "cache")] + #[error("Cache error: {0}")] + CacheError(#[from] crate::forwarder::CacheError), + // Others #[error("Infallible")] Infallible(#[from] std::convert::Infallible), diff --git a/rpxy-lib/src/forwarder/cache/cache_error.rs b/rpxy-lib/src/forwarder/cache/cache_error.rs new file mode 100644 index 0000000..bb2ffa6 --- /dev/null +++ b/rpxy-lib/src/forwarder/cache/cache_error.rs @@ -0,0 +1,35 @@ +use thiserror::Error; + +pub type CacheResult = std::result::Result; + +/// Describes things that can go wrong in the Rpxy +#[derive(Debug, Error)] +pub enum CacheError { + // Cache errors, + #[error("Invalid null request and/or response")] + NullRequestOrResponse, + + #[error("Failed to write byte buffer")] + FailedToWriteByteBufferForCache, + + #[error("Failed to acquire mutex lock for cache")] + FailedToAcquiredMutexLockForCache, + + #[error("Failed to create file cache")] + FailedToCreateFileCache, + + #[error("Failed to write file cache")] + FailedToWriteFileCache, + + #[error("Failed to open cache file")] + FailedToOpenCacheFile, + + #[error("Too large to cache")] + TooLargeToCache, + + #[error("Failed to cache bytes: {0}")] + FailedToCacheBytes(String), + + #[error("Failed to send frame to cache {0}")] + FailedToSendFrameToCache(String), +} diff --git a/rpxy-lib/src/forwarder/cache.rs b/rpxy-lib/src/forwarder/cache/cache_main.rs similarity index 75% rename from rpxy-lib/src/forwarder/cache.rs rename to rpxy-lib/src/forwarder/cache/cache_main.rs 
index 03755e6..2bc4548 100644 --- a/rpxy-lib/src/forwarder/cache.rs +++ b/rpxy-lib/src/forwarder/cache/cache_main.rs @@ -1,8 +1,11 @@ -use crate::{error::*, globals::Globals, log::*}; +use super::cache_error::*; +use crate::{globals::Globals, hyper_ext::body::UnboundedStreamBody, log::*}; use bytes::{Buf, Bytes, BytesMut}; +use futures::channel::mpsc; use http::{Request, Response}; -use http_body_util::StreamBody; +use http_body_util::{BodyExt, StreamBody}; use http_cache_semantics::CachePolicy; +use hyper::body::{Body, Frame, Incoming}; use lru::LruCache; use std::{ convert::Infallible, @@ -69,6 +72,73 @@ impl RpxyCache { let on_memory = total - file; (total, on_memory, file) } + + /// Put response into the cache + pub async fn put( + &self, + uri: &hyper::Uri, + mut body: Incoming, + policy: &CachePolicy, + ) -> CacheResult { + let my_cache = self.inner.clone(); + let mut file_store = self.file_store.clone(); + let uri = uri.clone(); + let policy_clone = policy.clone(); + let max_each_size = self.max_each_size; + let max_each_size_on_memory = self.max_each_size_on_memory; + + let (body_tx, body_rx) = mpsc::unbounded::, hyper::Error>>(); + + self.runtime_handle.spawn(async move { + let mut size = 0usize; + loop { + let frame = match body.frame().await { + Some(frame) => frame, + None => { + debug!("Response body finished"); + break; + } + }; + let frame_size = frame.as_ref().map(|f| { + if f.is_data() { + f.data_ref().map(|bytes| bytes.remaining()).unwrap_or_default() + } else { + 0 + } + }); + size += frame_size.unwrap_or_default(); + + // check size + if size > max_each_size { + warn!("Too large to cache"); + return Err(CacheError::TooLargeToCache); + } + frame + .as_ref() + .map(|f| { + if f.is_data() { + let data_bytes = f.data_ref().unwrap().clone(); + println!("ddddde"); + // TODO: cache data bytes as file or on memory + // fileにするかmemoryにするかの判断はある程度までバッファしてやってという手を使うことになる。途中までキャッシュしたやつはどうするかとかいう判断も必要。 + // ファイルとObjectのbindをどうやってするか + } + }) + .map_err(|e| 
CacheError::FailedToCacheBytes(e.to_string()))?; + + // send data to use response downstream + body_tx + .unbounded_send(frame) + .map_err(|e| CacheError::FailedToSendFrameToCache(e.to_string()))?; + } + + Ok(()) as CacheResult<()> + }); + + let stream_body = StreamBody::new(body_rx); + + Ok(stream_body) + } } /* ---------------------------------------------- */ @@ -93,7 +163,7 @@ impl FileStore { inner.cnt } /// Create a temporary file cache - async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> RpxyResult { + async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> CacheResult { let mut inner = self.inner.write().await; inner.create(cache_filename, body_bytes).await } @@ -106,7 +176,7 @@ impl FileStore { // }; // } // /// Read a temporary file cache - // async fn read(&self, path: impl AsRef) -> RpxyResult { + // async fn read(&self, path: impl AsRef) -> CacheResult { // let inner = self.inner.read().await; // inner.read(&path).await // } @@ -141,16 +211,16 @@ impl FileStoreInner { } /// Create a new temporary file cache - async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> RpxyResult { + async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> CacheResult { let cache_filepath = self.cache_dir.join(cache_filename); let Ok(mut file) = File::create(&cache_filepath).await else { - return Err(RpxyError::FailedToCreateFileCache); + return Err(CacheError::FailedToCreateFileCache); }; let mut bytes_clone = body_bytes.clone(); while bytes_clone.has_remaining() { if let Err(e) = file.write_buf(&mut bytes_clone).await { error!("Failed to write file cache: {e}"); - return Err(RpxyError::FailedToWriteFileCache); + return Err(CacheError::FailedToWriteFileCache); }; } self.cnt += 1; @@ -158,15 +228,14 @@ impl FileStoreInner { } /// Retrieve a stored temporary file cache - async fn read(&self, path: impl AsRef) -> RpxyResult<()> { + async fn read(&self, path: impl AsRef) -> CacheResult<()> { let Ok(mut file) = 
File::open(&path).await else { warn!("Cache file object cannot be opened"); - return Err(RpxyError::FailedToOpenCacheFile); + return Err(CacheError::FailedToOpenCacheFile); }; /* ----------------------------- */ // PoC for streaming body - use futures::channel::mpsc; let (tx, rx) = mpsc::unbounded::, Infallible>>(); // let (body_sender, res_body) = Body::channel(); @@ -263,10 +332,10 @@ impl LruCacheManager { } /// Push an entry - fn push(&self, cache_key: &str, cache_object: CacheObject) -> RpxyResult> { + fn push(&self, cache_key: &str, cache_object: CacheObject) -> CacheResult> { let Ok(mut lock) = self.inner.lock() else { error!("Failed to acquire mutex lock for writing cache entry"); - return Err(RpxyError::FailedToAcquiredMutexLockForCache); + return Err(CacheError::FailedToAcquiredMutexLockForCache); }; let res = Ok(lock.push(cache_key.to_string(), cache_object)); // This may be inconsistent with the actual number of entries @@ -280,13 +349,13 @@ impl LruCacheManager { pub fn get_policy_if_cacheable( req: Option<&Request>, res: Option<&Response>, -) -> RpxyResult> +) -> CacheResult> // where // B1: core::fmt::Debug, { // deduce cache policy from req and res let (Some(req), Some(res)) = (req, res) else { - return Err(RpxyError::NullRequestOrResponse); + return Err(CacheError::NullRequestOrResponse); }; let new_policy = CachePolicy::new(req, res); diff --git a/rpxy-lib/src/forwarder/cache/mod.rs b/rpxy-lib/src/forwarder/cache/mod.rs new file mode 100644 index 0000000..cfe5a1b --- /dev/null +++ b/rpxy-lib/src/forwarder/cache/mod.rs @@ -0,0 +1,5 @@ +mod cache_error; +mod cache_main; + +pub use cache_error::CacheError; +pub use cache_main::{get_policy_if_cacheable, CacheFileOrOnMemory, RpxyCache}; diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index 8b86f9f..8d2e307 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -1,14 +1,10 @@ use crate::{ error::{RpxyError, RpxyResult}, globals::Globals, 
- hyper_ext::{ - body::{wrap_incoming_body_response, BoxBody, IncomingOr}, - rt::LocalExecutor, - }, + hyper_ext::{body::ResponseBody, rt::LocalExecutor}, log::*, }; use async_trait::async_trait; -use chrono::Duration; use http::{Request, Response, Version}; use hyper::body::{Body, Incoming}; use hyper_util::client::legacy::{ @@ -19,10 +15,6 @@ use std::sync::Arc; #[cfg(feature = "cache")] use super::cache::{get_policy_if_cacheable, RpxyCache}; -#[cfg(feature = "cache")] -use crate::hyper_ext::body::{full, wrap_synthetic_body_response}; -#[cfg(feature = "cache")] -use http_body_util::BodyExt; #[async_trait] /// Definition of the forwarder that simply forward requests from downstream client to upstream app servers. @@ -40,7 +32,7 @@ pub struct Forwarder { } #[async_trait] -impl ForwardRequest> for Forwarder +impl ForwardRequest for Forwarder where C: Send + Sync + Connect + Clone + 'static, B1: Body + Send + Sync + Unpin + 'static, @@ -49,7 +41,7 @@ where { type Error = RpxyError; - async fn request(&self, req: Request) -> Result>, Self::Error> { + async fn request(&self, req: Request) -> Result, Self::Error> { // TODO: cache handling #[cfg(feature = "cache")] { @@ -67,38 +59,27 @@ where let res = self.request_directly(req).await; if self.cache.is_none() { - return res.map(wrap_incoming_body_response::); + return res.map(|inner| inner.map(ResponseBody::Incoming)); } // check cacheability and store it if cacheable let Ok(Some(cache_policy)) = get_policy_if_cacheable(synth_req.as_ref(), res.as_ref().ok()) else { - return res.map(wrap_incoming_body_response::); + return res.map(|inner| inner.map(ResponseBody::Incoming)); }; let (parts, body) = res.unwrap().into_parts(); - // TODO: This is inefficient since current strategy needs to copy the whole body onto memory to cache it. - // This should be handled by copying buffer simultaneously while forwarding response to downstream. 
- let Ok(bytes) = body.collect().await.map(|v| v.to_bytes()) else { - return Err(RpxyError::FailedToWriteByteBufferForCache); - }; - let bytes_clone = bytes.clone(); + // Get streamed body without waiting for the arrival of the body, + // which is done simultaneously with caching. + let stream_body = self + .cache + .as_ref() + .unwrap() + .put(synth_req.unwrap().uri(), body, &cache_policy) + .await?; - // TODO: this is inefficient. needs to be reconsidered to avoid unnecessary copy and should spawn async task to store cache. - // We may need to use the same logic as h3. - // Is bytes.clone() enough? - - // if let Err(cache_err) = self - // .cache - // .as_ref() - // .unwrap() - // .put(synth_req.unwrap().uri(), &bytes, &cache_policy) - // .await - // { - // error!("{:?}", cache_err); - // }; - - // response with cached body - Ok(wrap_synthetic_body_response(Response::from_parts(parts, full(bytes)))) + // response with body being cached in background + let new_res = Response::from_parts(parts, ResponseBody::Streamed(stream_body)); + Ok(new_res) } // No cache handling @@ -107,7 +88,7 @@ where self .request_directly(req) .await - .map(wrap_incoming_body_response::) + .map(|inner| inner.map(ResponseBody::Incoming)) } } } diff --git a/rpxy-lib/src/forwarder/mod.rs b/rpxy-lib/src/forwarder/mod.rs index 286cb40..d53cd73 100644 --- a/rpxy-lib/src/forwarder/mod.rs +++ b/rpxy-lib/src/forwarder/mod.rs @@ -3,6 +3,9 @@ mod cache; mod client; use crate::hyper_ext::body::{IncomingLike, IncomingOr}; -pub type Forwarder = client::Forwarder>; -pub use client::ForwardRequest; +pub(crate) type Forwarder = client::Forwarder>; +pub(crate) use client::ForwardRequest; + +#[cfg(feature = "cache")] +pub(crate) use cache::CacheError; diff --git a/rpxy-lib/src/hyper_ext/body_type.rs b/rpxy-lib/src/hyper_ext/body_type.rs index 9616306..c1eb54b 100644 --- a/rpxy-lib/src/hyper_ext/body_type.rs +++ b/rpxy-lib/src/hyper_ext/body_type.rs @@ -1,24 +1,25 @@ -use http::Response; +// use 
http::Response; use http_body_util::{combinators, BodyExt, Either, Empty, Full}; -use hyper::body::{Bytes, Incoming}; +use hyper::body::{Body, Bytes, Incoming}; +use std::pin::Pin; /// Type for synthetic boxed body pub(crate) type BoxBody = combinators::BoxBody; /// Type for either passthrough body or given body type, specifically synthetic boxed body pub(crate) type IncomingOr = Either; -/// helper function to build http response with passthrough body -pub(crate) fn wrap_incoming_body_response(response: Response) -> Response> -where - B: hyper::body::Body, -{ - response.map(IncomingOr::Left) -} +// /// helper function to build http response with passthrough body +// pub(crate) fn wrap_incoming_body_response(response: Response) -> Response> +// where +// B: hyper::body::Body, +// { +// response.map(IncomingOr::Left) +// } -/// helper function to build http response with synthetic body -pub(crate) fn wrap_synthetic_body_response(response: Response) -> Response> { - response.map(IncomingOr::Right) -} +// /// helper function to build http response with synthetic body +// pub(crate) fn wrap_synthetic_body_response(response: Response) -> Response> { +// response.map(IncomingOr::Right) +// } /// helper function to build a empty body pub(crate) fn empty() -> BoxBody { @@ -29,3 +30,43 @@ pub(crate) fn empty() -> BoxBody { pub(crate) fn full(body: Bytes) -> BoxBody { Full::new(body).map_err(|never| match never {}).boxed() } + +/* ------------------------------------ */ +#[cfg(feature = "cache")] +use futures::channel::mpsc::UnboundedReceiver; +#[cfg(feature = "cache")] +use http_body_util::StreamBody; +#[cfg(feature = "cache")] +use hyper::body::Frame; + +#[cfg(feature = "cache")] +pub(crate) type UnboundedStreamBody = StreamBody, hyper::Error>>>; + +/// Response body use in this project +/// - Incoming: just a type that only forwards the upstream response body to downstream. +/// - BoxedCache: a type that is generated from cache, e.g.,, small byte object. 
+/// - StreamedCache: another type that is generated from cache as stream, e.g., large byte object. +pub(crate) enum ResponseBody { + Incoming(Incoming), + Boxed(BoxBody), + #[cfg(feature = "cache")] + Streamed(UnboundedStreamBody), +} + +impl Body for ResponseBody { + type Data = bytes::Bytes; + type Error = hyper::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + match self.get_mut() { + ResponseBody::Incoming(incoming) => Pin::new(incoming).poll_frame(cx), + #[cfg(feature = "cache")] + ResponseBody::Boxed(boxed) => Pin::new(boxed).poll_frame(cx), + #[cfg(feature = "cache")] + ResponseBody::Streamed(streamed) => Pin::new(streamed).poll_frame(cx), + } + } +} diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs index 922776c..8b3776c 100644 --- a/rpxy-lib/src/hyper_ext/mod.rs +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -12,7 +12,5 @@ pub(crate) mod rt { #[allow(unused)] pub(crate) mod body { pub(crate) use super::body_incoming_like::IncomingLike; - pub(crate) use super::body_type::{ - empty, full, wrap_incoming_body_response, wrap_synthetic_body_response, BoxBody, IncomingOr, - }; + pub(crate) use super::body_type::{empty, full, BoxBody, IncomingOr, ResponseBody, UnboundedStreamBody}; } diff --git a/rpxy-lib/src/message_handler/handler_main.rs b/rpxy-lib/src/message_handler/handler_main.rs index 251411b..b5ae87d 100644 --- a/rpxy-lib/src/message_handler/handler_main.rs +++ b/rpxy-lib/src/message_handler/handler_main.rs @@ -11,7 +11,7 @@ use crate::{ error::*, forwarder::{ForwardRequest, Forwarder}, globals::Globals, - hyper_ext::body::{BoxBody, IncomingLike, IncomingOr}, + hyper_ext::body::{IncomingLike, IncomingOr, ResponseBody}, log::*, name_exp::ServerName, }; @@ -58,7 +58,7 @@ where listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, - ) -> RpxyResult>> { + ) -> RpxyResult> { // preparing log data let mut log_data = HttpMessageLog::from(&req); 
log_data.client_addr(&client_addr); @@ -99,7 +99,7 @@ where listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, - ) -> HttpResult>> { + ) -> HttpResult> { // Here we start to inspect and parse with server_name let server_name = req .inspect_parse_host() diff --git a/rpxy-lib/src/message_handler/synthetic_response.rs b/rpxy-lib/src/message_handler/synthetic_response.rs index 60aeeec..a955a2d 100644 --- a/rpxy-lib/src/message_handler/synthetic_response.rs +++ b/rpxy-lib/src/message_handler/synthetic_response.rs @@ -1,16 +1,16 @@ use super::http_result::{HttpError, HttpResult}; use crate::{ error::*, - hyper_ext::body::{empty, BoxBody, IncomingOr}, + hyper_ext::body::{empty, ResponseBody}, name_exp::ServerName, }; use http::{Request, Response, StatusCode, Uri}; /// build http response with status code of 4xx and 5xx -pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult>> { +pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult> { let res = Response::builder() .status(status_code) - .body(IncomingOr::Right(empty())) + .body(ResponseBody::Boxed(empty())) .unwrap(); Ok(res) } @@ -20,7 +20,7 @@ pub(super) fn secure_redirection_response( server_name: &ServerName, tls_port: Option, req: &Request, -) -> HttpResult>> { +) -> HttpResult> { let server_name: String = server_name.try_into().unwrap_or_default(); let pq = match req.uri().path_and_query() { Some(x) => x.as_str(), @@ -36,7 +36,7 @@ pub(super) fn secure_redirection_response( let response = Response::builder() .status(StatusCode::MOVED_PERMANENTLY) .header("Location", dest_uri.to_string()) - .body(IncomingOr::Right(empty())) + .body(ResponseBody::Boxed(empty())) .map_err(|e| HttpError::FailedToRedirect(e.to_string()))?; Ok(response) } diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 342c995..61328b2 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -138,7 +138,6 @@ where }); let new_req: 
Request> = Request::from_parts(req_parts, IncomingOr::Right(req_body)); - // Response> wrapped by RpxyResult let res = self .message_handler .handle_request( diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index 96ec0be..2d7a649 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -5,7 +5,7 @@ use crate::{ error::*, globals::Globals, hyper_ext::{ - body::{BoxBody, IncomingOr}, + body::{IncomingOr, ResponseBody}, rt::LocalExecutor, }, log::*, @@ -32,7 +32,7 @@ async fn serve_request( listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, -) -> RpxyResult>> +) -> RpxyResult> where T: Send + Sync + Connect + Clone, U: CryptoSource + Clone, From 1c18f3836a38d552b64b2441dd5f3e5e5e11bc9a Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 12 Dec 2023 20:17:13 +0900 Subject: [PATCH 37/50] wip: feat: change request body from either to explicit enum --- rpxy-lib/src/error.rs | 2 + rpxy-lib/src/forwarder/cache/cache_main.rs | 2 +- rpxy-lib/src/forwarder/mod.rs | 4 +- rpxy-lib/src/hyper_ext/body_type.rs | 51 ++++++++++++-------- rpxy-lib/src/hyper_ext/mod.rs | 2 +- rpxy-lib/src/message_handler/handler_main.rs | 6 +-- rpxy-lib/src/proxy/proxy_h3.rs | 15 ++---- rpxy-lib/src/proxy/proxy_main.rs | 4 +- 8 files changed, 45 insertions(+), 41 deletions(-) diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index f63a06c..3b1afc9 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -28,6 +28,8 @@ pub enum RpxyError { HyperIncomingLikeNewClosed, #[error("New body write aborted")] HyperNewBodyWriteAborted, + #[error("Hyper error in serving request or response body type: {0}")] + HyperBodyError(#[from] hyper::Error), // http/3 errors #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] diff --git a/rpxy-lib/src/forwarder/cache/cache_main.rs b/rpxy-lib/src/forwarder/cache/cache_main.rs index 2bc4548..659ac41 100644 --- a/rpxy-lib/src/forwarder/cache/cache_main.rs +++ 
b/rpxy-lib/src/forwarder/cache/cache_main.rs @@ -118,7 +118,7 @@ impl RpxyCache { .map(|f| { if f.is_data() { let data_bytes = f.data_ref().unwrap().clone(); - println!("ddddde"); + debug!("cache data bytes of {} bytes", data_bytes.len()) // TODO: cache data bytes as file or on memory // fileにするかmemoryにするかの判断はある程度までバッファしてやってという手を使うことになる。途中までキャッシュしたやつはどうするかとかいう判断も必要。 // ファイルとObjectのbindをどうやってするか diff --git a/rpxy-lib/src/forwarder/mod.rs b/rpxy-lib/src/forwarder/mod.rs index d53cd73..26aa0c9 100644 --- a/rpxy-lib/src/forwarder/mod.rs +++ b/rpxy-lib/src/forwarder/mod.rs @@ -2,9 +2,9 @@ mod cache; mod client; -use crate::hyper_ext::body::{IncomingLike, IncomingOr}; +use crate::hyper_ext::body::RequestBody; -pub(crate) type Forwarder = client::Forwarder>; +pub(crate) type Forwarder = client::Forwarder; pub(crate) use client::ForwardRequest; #[cfg(feature = "cache")] diff --git a/rpxy-lib/src/hyper_ext/body_type.rs b/rpxy-lib/src/hyper_ext/body_type.rs index c1eb54b..a143eac 100644 --- a/rpxy-lib/src/hyper_ext/body_type.rs +++ b/rpxy-lib/src/hyper_ext/body_type.rs @@ -1,25 +1,11 @@ -// use http::Response; -use http_body_util::{combinators, BodyExt, Either, Empty, Full}; +use super::body::IncomingLike; +use crate::error::RpxyError; +use http_body_util::{combinators, BodyExt, Empty, Full}; use hyper::body::{Body, Bytes, Incoming}; use std::pin::Pin; /// Type for synthetic boxed body pub(crate) type BoxBody = combinators::BoxBody; -/// Type for either passthrough body or given body type, specifically synthetic boxed body -pub(crate) type IncomingOr = Either; - -// /// helper function to build http response with passthrough body -// pub(crate) fn wrap_incoming_body_response(response: Response) -> Response> -// where -// B: hyper::body::Body, -// { -// response.map(IncomingOr::Left) -// } - -// /// helper function to build http response with synthetic body -// pub(crate) fn wrap_synthetic_body_response(response: Response) -> Response> { -// response.map(IncomingOr::Right) 
-// } /// helper function to build a empty body pub(crate) fn empty() -> BoxBody { @@ -31,6 +17,30 @@ pub(crate) fn full(body: Bytes) -> BoxBody { Full::new(body).map_err(|never| match never {}).boxed() } +/* ------------------------------------ */ +/// Request body used in this project +/// - Incoming: just a type that only forwards the downstream request body to upstream. +/// - IncomingLike: a Incoming-like type in which channel is used +pub(crate) enum RequestBody { + Incoming(Incoming), + IncomingLike(IncomingLike), +} + +impl Body for RequestBody { + type Data = bytes::Bytes; + type Error = RpxyError; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + match self.get_mut() { + RequestBody::Incoming(incoming) => Pin::new(incoming).poll_frame(cx).map_err(RpxyError::HyperBodyError), + RequestBody::IncomingLike(incoming_like) => Pin::new(incoming_like).poll_frame(cx), + } + } +} + /* ------------------------------------ */ #[cfg(feature = "cache")] use futures::channel::mpsc::UnboundedReceiver; @@ -44,8 +54,8 @@ pub(crate) type UnboundedStreamBody = StreamBody, @@ -68,5 +78,6 @@ impl Body for ResponseBody { #[cfg(feature = "cache")] ResponseBody::Streamed(streamed) => Pin::new(streamed).poll_frame(cx), } + .map_err(RpxyError::HyperBodyError) } } diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs index 8b3776c..a4c5196 100644 --- a/rpxy-lib/src/hyper_ext/mod.rs +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -12,5 +12,5 @@ pub(crate) mod rt { #[allow(unused)] pub(crate) mod body { pub(crate) use super::body_incoming_like::IncomingLike; - pub(crate) use super::body_type::{empty, full, BoxBody, IncomingOr, ResponseBody, UnboundedStreamBody}; + pub(crate) use super::body_type::{empty, full, BoxBody, RequestBody, ResponseBody, UnboundedStreamBody}; } diff --git a/rpxy-lib/src/message_handler/handler_main.rs b/rpxy-lib/src/message_handler/handler_main.rs index b5ae87d..ceb5db4 100644 
--- a/rpxy-lib/src/message_handler/handler_main.rs +++ b/rpxy-lib/src/message_handler/handler_main.rs @@ -11,7 +11,7 @@ use crate::{ error::*, forwarder::{ForwardRequest, Forwarder}, globals::Globals, - hyper_ext::body::{IncomingLike, IncomingOr, ResponseBody}, + hyper_ext::body::{RequestBody, ResponseBody}, log::*, name_exp::ServerName, }; @@ -53,7 +53,7 @@ where /// Responsible to passthrough responses from backend applications or generate synthetic error responses. pub async fn handle_request( &self, - req: Request>, + req: Request, client_addr: SocketAddr, // For access control listen_addr: SocketAddr, tls_enabled: bool, @@ -94,7 +94,7 @@ where async fn handle_request_inner( &self, log_data: &mut HttpMessageLog, - mut req: Request>, + mut req: Request, client_addr: SocketAddr, // For access control listen_addr: SocketAddr, tls_enabled: bool, diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 61328b2..0295430 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -2,7 +2,7 @@ use super::proxy_main::Proxy; use crate::{ crypto::CryptoSource, error::*, - hyper_ext::body::{IncomingLike, IncomingOr}, + hyper_ext::body::{IncomingLike, RequestBody}, log::*, name_exp::ServerName, }; @@ -137,7 +137,7 @@ where Ok(()) as RpxyResult<()> }); - let new_req: Request> = Request::from_parts(req_parts, IncomingOr::Right(req_body)); + let new_req: Request = Request::from_parts(req_parts, RequestBody::IncomingLike(req_body)); let res = self .message_handler .handle_request( @@ -155,6 +155,7 @@ where match send_stream.send_response(new_res).await { Ok(_) => { debug!("HTTP/3 response to connection successful"); + // on-demand body streaming to downstream without expanding the object onto memory. 
loop { let frame = match new_body.frame().await { Some(frame) => frame, @@ -175,16 +176,6 @@ where send_stream.send_trailers(trailers).await?; } } - // // aggregate body without copying - // let body_data = new_body - // .collect() - // .await - // .map_err(|e| RpxyError::HyperBodyManipulationError(e.to_string()))?; - - // // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes inside to_bytes() - // send_stream.send_data(body_data.to_bytes()).await?; - - // TODO: needs handling trailer? should be included in body from handler. } Err(err) => { error!("Unable to send response to connection peer: {:?}", err); diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index 2d7a649..4fea840 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -5,7 +5,7 @@ use crate::{ error::*, globals::Globals, hyper_ext::{ - body::{IncomingOr, ResponseBody}, + body::{RequestBody, ResponseBody}, rt::LocalExecutor, }, log::*, @@ -39,7 +39,7 @@ where { handler .handle_request( - req.map(IncomingOr::Left), + req.map(RequestBody::Incoming), client_addr, listen_addr, tls_enabled, From 8dd6af6bc5d4dfb53959f7283f73689582ecab07 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 12 Dec 2023 22:15:34 +0900 Subject: [PATCH 38/50] wip: feat: refactored cache implementation for put --- rpxy-lib/src/forwarder/cache/cache_error.rs | 12 ++ rpxy-lib/src/forwarder/cache/cache_main.rs | 189 +++++++++++++------- rpxy-lib/src/forwarder/client.rs | 1 + 3 files changed, 142 insertions(+), 60 deletions(-) diff --git a/rpxy-lib/src/forwarder/cache/cache_error.rs b/rpxy-lib/src/forwarder/cache/cache_error.rs index bb2ffa6..5f6146a 100644 --- a/rpxy-lib/src/forwarder/cache/cache_error.rs +++ b/rpxy-lib/src/forwarder/cache/cache_error.rs @@ -15,6 +15,9 @@ pub enum CacheError { #[error("Failed to acquire mutex lock for cache")] FailedToAcquiredMutexLockForCache, + #[error("Failed to acquire mutex lock 
for check")] + FailedToAcquiredMutexLockForCheck, + #[error("Failed to create file cache")] FailedToCreateFileCache, @@ -32,4 +35,13 @@ pub enum CacheError { #[error("Failed to send frame to cache {0}")] FailedToSendFrameToCache(String), + + #[error("Failed to send frame from file cache {0}")] + FailedToSendFrameFromCache(String), + + #[error("Failed to remove cache file: {0}")] + FailedToRemoveCacheFile(String), + + #[error("Invalid cache target")] + InvalidCacheTarget, } diff --git a/rpxy-lib/src/forwarder/cache/cache_main.rs b/rpxy-lib/src/forwarder/cache/cache_main.rs index 659ac41..c16f1d6 100644 --- a/rpxy-lib/src/forwarder/cache/cache_main.rs +++ b/rpxy-lib/src/forwarder/cache/cache_main.rs @@ -1,14 +1,15 @@ use super::cache_error::*; use crate::{globals::Globals, hyper_ext::body::UnboundedStreamBody, log::*}; +use base64::{engine::general_purpose, Engine as _}; use bytes::{Buf, Bytes, BytesMut}; use futures::channel::mpsc; -use http::{Request, Response}; +use http::{Request, Response, Uri}; use http_body_util::{BodyExt, StreamBody}; use http_cache_semantics::CachePolicy; use hyper::body::{Body, Frame, Incoming}; use lru::LruCache; +use sha2::{Digest, Sha256}; use std::{ - convert::Infallible, path::{Path, PathBuf}, sync::{ atomic::{AtomicUsize, Ordering}, @@ -35,6 +36,8 @@ pub struct RpxyCache { max_each_size: usize, /// Maximum size of cache object on memory max_each_size_on_memory: usize, + /// Cache directory path + cache_dir: PathBuf, } impl RpxyCache { @@ -43,8 +46,8 @@ impl RpxyCache { if !globals.proxy_config.cache_enabled { return None; } - let path = globals.proxy_config.cache_dir.as_ref().unwrap(); - let file_store = FileStore::new(path, &globals.runtime_handle).await; + let cache_dir = globals.proxy_config.cache_dir.as_ref().unwrap(); + let file_store = FileStore::new(&globals.runtime_handle).await; let inner = LruCacheManager::new(globals.proxy_config.cache_max_entry); let max_each_size = globals.proxy_config.cache_max_each_size; @@ -56,12 
+59,18 @@ impl RpxyCache { max_each_size_on_memory = max_each_size; } + if let Err(e) = fs::remove_dir_all(cache_dir).await { + warn!("Failed to clean up the cache dir: {e}"); + }; + fs::create_dir_all(&cache_dir).await.unwrap(); + Some(Self { file_store, inner, runtime_handle: globals.runtime_handle.clone(), max_each_size, max_each_size_on_memory, + cache_dir: cache_dir.clone(), }) } @@ -80,17 +89,20 @@ impl RpxyCache { mut body: Incoming, policy: &CachePolicy, ) -> CacheResult { - let my_cache = self.inner.clone(); + let cache_manager = self.inner.clone(); let mut file_store = self.file_store.clone(); let uri = uri.clone(); let policy_clone = policy.clone(); let max_each_size = self.max_each_size; let max_each_size_on_memory = self.max_each_size_on_memory; + let cache_dir = self.cache_dir.clone(); let (body_tx, body_rx) = mpsc::unbounded::, hyper::Error>>(); self.runtime_handle.spawn(async move { let mut size = 0usize; + let mut buf = BytesMut::new(); + loop { let frame = match body.frame().await { Some(frame) => frame, @@ -118,10 +130,9 @@ impl RpxyCache { .map(|f| { if f.is_data() { let data_bytes = f.data_ref().unwrap().clone(); - debug!("cache data bytes of {} bytes", data_bytes.len()) - // TODO: cache data bytes as file or on memory - // fileにするかmemoryにするかの判断はある程度までバッファしてやってという手を使うことになる。途中までキャッシュしたやつはどうするかとかいう判断も必要。 - // ファイルとObjectのbindをどうやってするか + debug!("cache data bytes of {} bytes", data_bytes.len()); + // We do not use stream-type buffering since it needs to lock file during operation. + buf.extend(data_bytes.as_ref()); } }) .map_err(|e| CacheError::FailedToCacheBytes(e.to_string()))?; @@ -132,6 +143,35 @@ impl RpxyCache { .map_err(|e| CacheError::FailedToSendFrameToCache(e.to_string()))?; } + let buf = buf.freeze(); + // Calculate hash of the cached data, after all data is received. + // In-operation calculation is possible but it blocks sending data. 
+ let mut hasher = Sha256::new(); + hasher.update(buf.as_ref()); + let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); + debug!("Cached data: {} bytes, hash = {:?}", size, hash_bytes); + + // Create cache object + let cache_key = derive_cache_key_from_uri(&uri); + let cache_object = CacheObject { + policy: policy_clone, + target: CacheFileOrOnMemory::build(&cache_dir, &uri, &buf, max_each_size_on_memory), + hash: hash_bytes, + }; + + if let Some((k, v)) = cache_manager.push(&cache_key, &cache_object)? { + if k != cache_key { + info!("Over the cache capacity. Evict least recent used entry"); + if let CacheFileOrOnMemory::File(path) = v.target { + file_store.evict(&path).await; + } + } + } + // store cache object to file + if let CacheFileOrOnMemory::File(_) = cache_object.target { + file_store.create(&cache_object, &buf).await?; + } + Ok(()) as CacheResult<()> }); @@ -145,36 +185,35 @@ impl RpxyCache { #[derive(Debug, Clone)] /// Cache file manager outer that is responsible to handle `RwLock` struct FileStore { + /// Inner file store main object inner: Arc>, } impl FileStore { /// Build manager - async fn new(path: impl AsRef, runtime_handle: &tokio::runtime::Handle) -> Self { + async fn new(runtime_handle: &tokio::runtime::Handle) -> Self { Self { - inner: Arc::new(RwLock::new(FileStoreInner::new(path, runtime_handle).await)), + inner: Arc::new(RwLock::new(FileStoreInner::new(runtime_handle).await)), } } -} -impl FileStore { /// Count file cache entries async fn count(&self) -> usize { let inner = self.inner.read().await; inner.cnt } /// Create a temporary file cache - async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> CacheResult { + async fn create(&mut self, ref cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> { let mut inner = self.inner.write().await; - inner.create(cache_filename, body_bytes).await + inner.create(cache_object, body_bytes).await + } + /// Evict a temporary file cache + async fn 
evict(&self, path: impl AsRef) { + // Acquire the write lock + let mut inner = self.inner.write().await; + if let Err(e) = inner.remove(path).await { + warn!("Eviction failed during file object removal: {:?}", e); + }; } - // /// Evict a temporary file cache - // async fn evict(&self, path: impl AsRef) { - // // Acquire the write lock - // let mut inner = self.inner.write().await; - // if let Err(e) = inner.remove(path).await { - // warn!("Eviction failed during file object removal: {:?}", e); - // }; - // } // /// Read a temporary file cache // async fn read(&self, path: impl AsRef) -> CacheResult { // let inner = self.inner.read().await; @@ -185,8 +224,6 @@ impl FileStore { #[derive(Debug)] /// Manager inner for cache on file system struct FileStoreInner { - /// Directory of temporary files - cache_dir: PathBuf, /// Counter of current cached files cnt: usize, /// Async runtime @@ -197,22 +234,21 @@ impl FileStoreInner { /// Build new cache file manager. /// This first creates cache file dir if not exists, and cleans up the file inside the directory. /// TODO: Persistent cache is really difficult. `sqlite` or something like that is needed. 
- async fn new(path: impl AsRef, runtime_handle: &tokio::runtime::Handle) -> Self { - let path_buf = path.as_ref().to_path_buf(); - if let Err(e) = fs::remove_dir_all(path).await { - warn!("Failed to clean up the cache dir: {e}"); - }; - fs::create_dir_all(&path_buf).await.unwrap(); + async fn new(runtime_handle: &tokio::runtime::Handle) -> Self { Self { - cache_dir: path_buf.clone(), cnt: 0, runtime_handle: runtime_handle.clone(), } } /// Create a new temporary file cache - async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> CacheResult { - let cache_filepath = self.cache_dir.join(cache_filename); + async fn create(&mut self, cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> { + let cache_filepath = match cache_object.target { + CacheFileOrOnMemory::File(ref path) => path.clone(), + CacheFileOrOnMemory::OnMemory(_) => { + return Err(CacheError::InvalidCacheTarget); + } + }; let Ok(mut file) = File::create(&cache_filepath).await else { return Err(CacheError::FailedToCreateFileCache); }; @@ -224,50 +260,47 @@ impl FileStoreInner { }; } self.cnt += 1; - Ok(CacheFileOrOnMemory::File(cache_filepath)) + Ok(()) } /// Retrieve a stored temporary file cache - async fn read(&self, path: impl AsRef) -> CacheResult<()> { + async fn read(&self, path: impl AsRef) -> CacheResult { let Ok(mut file) = File::open(&path).await else { warn!("Cache file object cannot be opened"); return Err(CacheError::FailedToOpenCacheFile); }; - /* ----------------------------- */ - // PoC for streaming body - let (tx, rx) = mpsc::unbounded::, Infallible>>(); + let (body_tx, body_rx) = mpsc::unbounded::, hyper::Error>>(); - // let (body_sender, res_body) = Body::channel(); self.runtime_handle.spawn(async move { // let mut sender = body_sender; let mut buf = BytesMut::new(); loop { match file.read_buf(&mut buf).await { Ok(0) => break, - Ok(_) => tx - .unbounded_send(Ok(hyper::body::Frame::data(buf.copy_to_bytes(buf.remaining())))) - .map_err(|e| 
anyhow::anyhow!("Failed to read cache file: {e}"))?, - //sender.send_data(buf.copy_to_bytes(buf.remaining())).await?, + Ok(_) => body_tx + .unbounded_send(Ok(Frame::data(buf.copy_to_bytes(buf.remaining())))) + .map_err(|e| CacheError::FailedToSendFrameFromCache(e.to_string()))?, Err(_) => break, }; } - Ok(()) as anyhow::Result<()> + Ok(()) as CacheResult<()> }); - let mut rx = http_body_util::StreamBody::new(rx); - // TODO: 結局incominglikeなbodystreamを定義することになる。これだったらh3と合わせて自分で定義した方が良さそう。 - // typeが長すぎるのでwrapperを作った方がいい。 - // let response = Response::builder() - // .status(200) - // .header("content-type", "application/octet-stream") - // .body(rx) - // .unwrap(); + let stream_body = StreamBody::new(body_rx); - todo!() - /* ----------------------------- */ + Ok(stream_body) + } - // Ok(res_body) + /// Remove file + async fn remove(&mut self, path: impl AsRef) -> CacheResult<()> { + fs::remove_file(path.as_ref()) + .await + .map_err(|e| CacheError::FailedToRemoveCacheFile(e.to_string()))?; + self.cnt -= 1; + debug!("Removed a cache file at {:?} (file count: {})", path.as_ref(), self.cnt); + + Ok(()) } } @@ -279,7 +312,20 @@ pub enum CacheFileOrOnMemory { /// Pointer to the temporary cache file File(PathBuf), /// Cached body itself - OnMemory(Vec), + OnMemory(Bytes), +} + +impl CacheFileOrOnMemory { + /// Get cache object target + fn build(cache_dir: &Path, uri: &Uri, object: &Bytes, max_each_size_on_memory: usize) -> Self { + if object.len() > max_each_size_on_memory { + let cache_filename = derive_filename_from_uri(uri); + let cache_filepath = cache_dir.join(cache_filename); + CacheFileOrOnMemory::File(cache_filepath) + } else { + CacheFileOrOnMemory::OnMemory(object.clone()) + } + } } #[derive(Clone, Debug)] @@ -290,7 +336,7 @@ struct CacheObject { /// Cache target: on-memory object or temporary file pub target: CacheFileOrOnMemory, /// SHA256 hash of target to strongly bind the cache metadata (this object) and file target - pub hash: Vec, + pub hash: Bytes, } /* 
---------------------------------------------- */ @@ -332,16 +378,28 @@ impl LruCacheManager { } /// Push an entry - fn push(&self, cache_key: &str, cache_object: CacheObject) -> CacheResult> { + fn push(&self, cache_key: &str, cache_object: &CacheObject) -> CacheResult> { let Ok(mut lock) = self.inner.lock() else { error!("Failed to acquire mutex lock for writing cache entry"); return Err(CacheError::FailedToAcquiredMutexLockForCache); }; - let res = Ok(lock.push(cache_key.to_string(), cache_object)); + let res = Ok(lock.push(cache_key.to_string(), cache_object.clone())); // This may be inconsistent with the actual number of entries self.cnt.store(lock.len(), Ordering::Relaxed); res } + + /// Get an entry + fn get(&self, cache_key: &str) -> CacheResult> { + let Ok(mut lock) = self.inner.lock() else { + error!("Mutex can't be locked for checking cache entry"); + return Err(CacheError::FailedToAcquiredMutexLockForCheck); + }; + let Some(cached_object) = lock.get(cache_key) else { + return Ok(None); + }; + Ok(Some(cached_object.clone())) + } } /* ---------------------------------------------- */ @@ -366,3 +424,14 @@ pub fn get_policy_if_cacheable( Ok(None) } } + +fn derive_filename_from_uri(uri: &hyper::Uri) -> String { + let mut hasher = Sha256::new(); + hasher.update(uri.to_string()); + let digest = hasher.finalize(); + general_purpose::URL_SAFE_NO_PAD.encode(digest) +} + +fn derive_cache_key_from_uri(uri: &hyper::Uri) -> String { + uri.to_string() +} diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index 8d2e307..c6f1ca9 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -47,6 +47,7 @@ where { let mut synth_req = None; if self.cache.is_some() { + // TODO: try reading from cache // if let Some(cached_response) = self.cache.as_ref().unwrap().get(&req).await { // // if found, return it as response. 
// info!("Cache hit - Return from cache"); From bd29c9dc1da19918bc9e253b20435787a3346ff9 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 12 Dec 2023 22:50:24 +0900 Subject: [PATCH 39/50] wip: feat: implemented cache --- rpxy-lib/src/backend/mod.rs | 7 +- rpxy-lib/src/forwarder/cache/cache_error.rs | 3 + rpxy-lib/src/forwarder/cache/cache_main.rs | 114 +++++++++++++++++--- rpxy-lib/src/forwarder/client.rs | 12 +-- 4 files changed, 113 insertions(+), 23 deletions(-) diff --git a/rpxy-lib/src/backend/mod.rs b/rpxy-lib/src/backend/mod.rs index 788960d..097810a 100644 --- a/rpxy-lib/src/backend/mod.rs +++ b/rpxy-lib/src/backend/mod.rs @@ -3,10 +3,11 @@ mod load_balance; mod upstream; mod upstream_opts; -// #[cfg(feature = "sticky-cookie")] -// pub use self::load_balance::{StickyCookie, StickyCookieValue}; +#[cfg(feature = "sticky-cookie")] +pub(crate) use self::load_balance::{StickyCookie, StickyCookieValue}; +#[allow(unused)] pub(crate) use self::{ - load_balance::{LoadBalance, LoadBalanceContext, StickyCookie, StickyCookieValue}, + load_balance::{LoadBalance, LoadBalanceContext}, upstream::{PathManager, Upstream, UpstreamCandidates}, upstream_opts::UpstreamOption, }; diff --git a/rpxy-lib/src/forwarder/cache/cache_error.rs b/rpxy-lib/src/forwarder/cache/cache_error.rs index 5f6146a..35eae83 100644 --- a/rpxy-lib/src/forwarder/cache/cache_error.rs +++ b/rpxy-lib/src/forwarder/cache/cache_error.rs @@ -44,4 +44,7 @@ pub enum CacheError { #[error("Invalid cache target")] InvalidCacheTarget, + + #[error("Hash mismatched in cache file")] + HashMismatchedInCacheFile, } diff --git a/rpxy-lib/src/forwarder/cache/cache_main.rs b/rpxy-lib/src/forwarder/cache/cache_main.rs index c16f1d6..3c85d0c 100644 --- a/rpxy-lib/src/forwarder/cache/cache_main.rs +++ b/rpxy-lib/src/forwarder/cache/cache_main.rs @@ -1,12 +1,16 @@ use super::cache_error::*; -use crate::{globals::Globals, hyper_ext::body::UnboundedStreamBody, log::*}; +use crate::{ + globals::Globals, + 
hyper_ext::body::{full, BoxBody, ResponseBody, UnboundedStreamBody}, + log::*, +}; use base64::{engine::general_purpose, Engine as _}; use bytes::{Buf, Bytes, BytesMut}; use futures::channel::mpsc; use http::{Request, Response, Uri}; use http_body_util::{BodyExt, StreamBody}; use http_cache_semantics::CachePolicy; -use hyper::body::{Body, Frame, Incoming}; +use hyper::body::{Frame, Incoming}; use lru::LruCache; use sha2::{Digest, Sha256}; use std::{ @@ -15,6 +19,7 @@ use std::{ atomic::{AtomicUsize, Ordering}, Arc, Mutex, }, + time::SystemTime, }; use tokio::{ fs::{self, File}, @@ -179,6 +184,66 @@ impl RpxyCache { Ok(stream_body) } + + /// Get cached response + pub async fn get(&self, req: &Request) -> Option> { + debug!( + "Current cache status: (total, on-memory, file) = {:?}", + self.count().await + ); + let cache_key = derive_cache_key_from_uri(req.uri()); + + // First check cache chance + let Ok(Some(cached_object)) = self.inner.get(&cache_key) else { + return None; + }; + + // Secondly check the cache freshness as an HTTP message + let now = SystemTime::now(); + let http_cache_semantics::BeforeRequest::Fresh(res_parts) = cached_object.policy.before_request(req, now) else { + // Evict stale cache entry. + // This might be okay to keep as is since it would be updated later. + // However, there is no guarantee that newly got objects will be still cacheable. + // So, we have to evict stale cache entries and cache file objects if found. 
+ debug!("Stale cache entry: {cache_key}"); + let _evicted_entry = self.inner.evict(&cache_key); + // For cache file + if let CacheFileOrOnMemory::File(path) = &cached_object.target { + self.file_store.evict(&path).await; + } + return None; + }; + + // Finally retrieve the file/on-memory object + let response_body = match cached_object.target { + CacheFileOrOnMemory::File(path) => { + let stream_body = match self.file_store.read(path.clone(), &cached_object.hash).await { + Ok(s) => s, + Err(e) => { + warn!("Failed to read from file cache: {e}"); + let _evicted_entry = self.inner.evict(&cache_key); + self.file_store.evict(path).await; + return None; + } + }; + debug!("Cache hit from file: {cache_key}"); + ResponseBody::Streamed(stream_body) + } + CacheFileOrOnMemory::OnMemory(object) => { + debug!("Cache hit from on memory: {cache_key}"); + let mut hasher = Sha256::new(); + hasher.update(object.as_ref()); + let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); + if hash_bytes != cached_object.hash { + warn!("Hash mismatched. 
Cache object is corrupted"); + let _evicted_entry = self.inner.evict(&cache_key); + return None; + } + ResponseBody::Boxed(BoxBody::new(full(object))) + } + }; + Some(Response::from_parts(res_parts, response_body)) + } } /* ---------------------------------------------- */ @@ -202,7 +267,7 @@ impl FileStore { inner.cnt } /// Create a temporary file cache - async fn create(&mut self, ref cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> { + async fn create(&mut self, cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> { let mut inner = self.inner.write().await; inner.create(cache_object, body_bytes).await } @@ -214,14 +279,18 @@ impl FileStore { warn!("Eviction failed during file object removal: {:?}", e); }; } - // /// Read a temporary file cache - // async fn read(&self, path: impl AsRef) -> CacheResult { - // let inner = self.inner.read().await; - // inner.read(&path).await - // } + /// Read a temporary file cache + async fn read( + &self, + path: impl AsRef + Send + Sync + 'static, + hash: &Bytes, + ) -> CacheResult { + let inner = self.inner.read().await; + inner.read(path, hash).await + } } -#[derive(Debug)] +#[derive(Debug, Clone)] /// Manager inner for cache on file system struct FileStoreInner { /// Counter of current cached files @@ -264,26 +333,43 @@ impl FileStoreInner { } /// Retrieve a stored temporary file cache - async fn read(&self, path: impl AsRef) -> CacheResult { + async fn read( + &self, + path: impl AsRef + Send + Sync + 'static, + hash: &Bytes, + ) -> CacheResult { let Ok(mut file) = File::open(&path).await else { warn!("Cache file object cannot be opened"); return Err(CacheError::FailedToOpenCacheFile); }; + let hash_clone = hash.clone(); + let mut self_clone = self.clone(); let (body_tx, body_rx) = mpsc::unbounded::, hyper::Error>>(); self.runtime_handle.spawn(async move { - // let mut sender = body_sender; + let mut hasher = Sha256::new(); let mut buf = BytesMut::new(); loop { match file.read_buf(&mut 
buf).await { Ok(0) => break, - Ok(_) => body_tx - .unbounded_send(Ok(Frame::data(buf.copy_to_bytes(buf.remaining())))) - .map_err(|e| CacheError::FailedToSendFrameFromCache(e.to_string()))?, + Ok(_) => { + let bytes = buf.copy_to_bytes(buf.remaining()); + hasher.update(bytes.as_ref()); + body_tx + .unbounded_send(Ok(Frame::data(bytes))) + .map_err(|e| CacheError::FailedToSendFrameFromCache(e.to_string()))? + } Err(_) => break, }; } + let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); + if hash_bytes != hash_clone { + warn!("Hash mismatched. Cache object is corrupted. Force to remove the cache file."); + // only file can be evicted + let _evicted_entry = self_clone.remove(&path).await; + return Err(CacheError::HashMismatchedInCacheFile); + } Ok(()) as CacheResult<()> }); diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index c6f1ca9..c8a8ec7 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -47,12 +47,12 @@ where { let mut synth_req = None; if self.cache.is_some() { - // TODO: try reading from cache - // if let Some(cached_response) = self.cache.as_ref().unwrap().get(&req).await { - // // if found, return it as response. - // info!("Cache hit - Return from cache"); - // return Ok(cached_response); - // }; + // try reading from cache + if let Some(cached_response) = self.cache.as_ref().unwrap().get(&req).await { + // if found, return it as response. + info!("Cache hit - Return from cache"); + return Ok(cached_response); + }; // Synthetic request copy used just for caching (cannot clone request object...) 
synth_req = Some(build_synth_req_for_cache(&req)); From 92638ccd2aa98f4d61b4ba88fcaef7387017b54a Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 12 Dec 2023 22:59:58 +0900 Subject: [PATCH 40/50] wip: update changelog and todo --- CHANGELOG.md | 4 ++++ TODO.md | 1 + 2 files changed, 5 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index de8871f..c0b2649 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ - Breaking: `hyper`-1.0 for both server and client modules. - Breaking: Remove `override_host` option in upstream options. Add a reverse option, i.e., `disable_override_host`. That is, `rpxy` always override the host header by the upstream hostname by default. +- Breaking: Introduced `hyper-tls-backend` feature to use the native TLS engine to access backend applications. +- Redesigned: Cache structure is totally redesigned with more memory-efficient way to read from cache file, and more secure way to strongly bind memory-objects with files with hash values. +- Redesigned: HTTP body handling flow is also redesigned with more memory-and-time efficient techniques without putting the whole objects on memory by using `futures::stream::Stream` and `futures::channel::mpsc` +- Refactor: lots of minor improvements ## 0.6.2 diff --git a/TODO.md b/TODO.md index 1e25ee1..1a159fd 100644 --- a/TODO.md +++ b/TODO.md @@ -4,6 +4,7 @@ - [Initial implementation in v0.6.0] ~~**Cache option for the response with `Cache-Control: public` header directive ([#55](https://github.com/junkurihara/rust-rpxy/issues/55))**~~ Using `lru` crate might be inefficient in terms of the speed. - Consider more sophisticated architecture for cache - Persistent cache (if possible). 
+ - More secure cache file object naming - etc etc - Improvement of path matcher - More flexible option for rewriting path From f41a2213f9d5467ddc6e9156835ed917044e0342 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Dec 2023 13:22:57 +0900 Subject: [PATCH 41/50] preparing nightly-build --- .dockerignore | 1 + .github/workflows/release.yml | 56 +++++++-------- .github/workflows/release_docker.yml | 70 +++++++++---------- Cargo.toml | 2 +- rpxy-bin/Cargo.toml | 10 +-- rpxy-lib/Cargo.toml | 15 ++-- rpxy-lib/src/forwarder/cache/cache_error.rs | 5 +- rpxy-lib/src/forwarder/cache/cache_main.rs | 20 +++--- rpxy-lib/src/forwarder/cache/mod.rs | 2 +- rpxy-lib/src/forwarder/client.rs | 5 +- rpxy-lib/src/hyper_ext/body_type.rs | 26 +++---- .../handler_manipulate_messages.rs | 1 + rpxy-lib/src/proxy/mod.rs | 6 +- 13 files changed, 109 insertions(+), 110 deletions(-) diff --git a/.dockerignore b/.dockerignore index 3538235..4294ee6 100644 --- a/.dockerignore +++ b/.dockerignore @@ -4,3 +4,4 @@ bench/ .private/ .github/ example-certs/ +legacy-lib/ diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 2b04184..cee7007 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -44,35 +44,35 @@ jobs: platform: linux/arm64 tags-suffix: "-s2n" - - target: "gnu" - build-feature: "-native-roots" - platform: linux/amd64 - tags-suffix: "-native-roots" + # - target: "gnu" + # build-feature: "-native-roots" + # platform: linux/amd64 + # tags-suffix: "-native-roots" - - target: "gnu" - build-feature: "-native-roots" - platform: linux/arm64 - tags-suffix: "-native-roots" + # - target: "gnu" + # build-feature: "-native-roots" + # platform: linux/arm64 + # tags-suffix: "-native-roots" - - target: "musl" - build-feature: "-native-roots" - platform: linux/amd64 - tags-suffix: "-slim-native-roots" + # - target: "musl" + # build-feature: "-native-roots" + # platform: linux/amd64 + # tags-suffix: "-slim-native-roots" - - target: "musl" - 
build-feature: "-native-roots" - platform: linux/arm64 - tags-suffix: "-slim-native-roots" + # - target: "musl" + # build-feature: "-native-roots" + # platform: linux/arm64 + # tags-suffix: "-slim-native-roots" - - target: "gnu" - build-feature: "-s2n-native-roots" - platform: linux/amd64 - tags-suffix: "-s2n-native-roots" + # - target: "gnu" + # build-feature: "-s2n-native-roots" + # platform: linux/amd64 + # tags-suffix: "-s2n-native-roots" - - target: "gnu" - build-feature: "-s2n-native-roots" - platform: linux/arm64 - tags-suffix: "-s2n-native-roots" + # - target: "gnu" + # build-feature: "-s2n-native-roots" + # platform: linux/arm64 + # tags-suffix: "-s2n-native-roots" steps: - run: "echo 'The relese triggering workflows passed'" @@ -81,8 +81,8 @@ jobs: id: "set-env" run: | if [ ${{ matrix.platform }} == 'linux/amd64' ]; then PLATFORM_MAP="x86_64"; else PLATFORM_MAP="aarch64"; fi - if [ ${{ github.ref_name }} == 'develop' ]; then BUILD_NAME="-nightly"; else BUILD_NAME=""; fi - if [ ${{ github.ref_name }} == 'develop' ]; then BUILD_IMG="nightly"; else BUILD_IMG="latest"; fi + if [ ${{ github.ref_name == 'develop' && github.event.client_payload.pull_request.head == 'develop' && github.event.client_payload.pull_request.base == 'main' }} || ${{ github.ref_name == 'main' }}]; then BUILD_NAME=""; else BUILD_NAME="-nightly"; fi + if [ ${{ github.ref_name }} == 'main' ]; then BUILD_IMG="latest"; else BUILD_IMG="nightly"; fi echo "build_img=${BUILD_IMG}" >> $GITHUB_OUTPUT echo "target_name=rpxy${BUILD_NAME}-${PLATFORM_MAP}-unknown-linux-${{ matrix.target }}${{ matrix.build-feature }}" >> $GITHUB_OUTPUT @@ -93,7 +93,7 @@ jobs: docker cp ${CONTAINER_ID}:/rpxy/bin/rpxy /tmp/${{ steps.set-env.outputs.target_name }} - name: "upload artifacts" - uses: actions/upload-artifact@v3 + uses: actions/upload-artifact@v4 with: name: ${{ steps.set-env.outputs.target_name }} path: "/tmp/${{ steps.set-env.outputs.target_name }}" @@ -122,7 +122,7 @@ jobs: - name: download artifacts if: 
${{ steps.regex-match.outputs.match != ''}} - uses: actions/download-artifact@v3 + uses: actions/download-artifact@v4 with: path: /tmp/rpxy diff --git a/.github/workflows/release_docker.yml b/.github/workflows/release_docker.yml index 60dd7ff..391d7cd 100644 --- a/.github/workflows/release_docker.yml +++ b/.github/workflows/release_docker.yml @@ -44,7 +44,7 @@ jobs: - target: "s2n" dockerfile: ./docker/Dockerfile build-args: | - "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache" + "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,native-tls-backend" "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" platforms: linux/amd64,linux/arm64 tags-suffix: "-s2n" @@ -53,42 +53,42 @@ jobs: jqtype/rpxy:s2n ghcr.io/junkurihara/rust-rpxy:s2n - - target: "native-roots" - dockerfile: ./docker/Dockerfile - platforms: linux/amd64,linux/arm64 - build-args: | - "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots" - tags-suffix: "-native-roots" - # Aliases must be used only for release builds - aliases: | - jqtype/rpxy:native-roots - ghcr.io/junkurihara/rust-rpxy:native-roots + # - target: "native-roots" + # dockerfile: ./docker/Dockerfile + # platforms: linux/amd64,linux/arm64 + # build-args: | + # "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots" + # tags-suffix: "-native-roots" + # # Aliases must be used only for release builds + # aliases: | + # jqtype/rpxy:native-roots + # ghcr.io/junkurihara/rust-rpxy:native-roots - - target: "slim-native-roots" - dockerfile: ./docker/Dockerfile-slim - build-args: | - "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots" - build-contexts: | - messense/rust-musl-cross:amd64-musl=docker-image://messense/rust-musl-cross:x86_64-musl - messense/rust-musl-cross:arm64-musl=docker-image://messense/rust-musl-cross:aarch64-musl - platforms: linux/amd64,linux/arm64 - tags-suffix: "-slim-native-roots" - # Aliases must be used only 
for release builds - aliases: | - jqtype/rpxy:slim-native-roots - ghcr.io/junkurihara/rust-rpxy:slim-native-roots + # - target: "slim-native-roots" + # dockerfile: ./docker/Dockerfile-slim + # build-args: | + # "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots" + # build-contexts: | + # messense/rust-musl-cross:amd64-musl=docker-image://messense/rust-musl-cross:x86_64-musl + # messense/rust-musl-cross:arm64-musl=docker-image://messense/rust-musl-cross:aarch64-musl + # platforms: linux/amd64,linux/arm64 + # tags-suffix: "-slim-native-roots" + # # Aliases must be used only for release builds + # aliases: | + # jqtype/rpxy:slim-native-roots + # ghcr.io/junkurihara/rust-rpxy:slim-native-roots - - target: "s2n-native-roots" - dockerfile: ./docker/Dockerfile - build-args: | - "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,native-roots" - "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" - platforms: linux/amd64,linux/arm64 - tags-suffix: "-s2n-native-roots" - # Aliases must be used only for release builds - aliases: | - jqtype/rpxy:s2n-native-roots - ghcr.io/junkurihara/rust-rpxy:s2n-native-roots + # - target: "s2n-native-roots" + # dockerfile: ./docker/Dockerfile + # build-args: | + # "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,native-roots" + # "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" + # platforms: linux/amd64,linux/arm64 + # tags-suffix: "-s2n-native-roots" + # # Aliases must be used only for release builds + # aliases: | + # jqtype/rpxy:s2n-native-roots + # ghcr.io/junkurihara/rust-rpxy:s2n-native-roots steps: - name: Checkout diff --git a/Cargo.toml b/Cargo.toml index 7868088..c512b18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] -members = ["rpxy-bin", "rpxy-lib", "legacy-lib"] +members = ["rpxy-bin", "rpxy-lib"] exclude = ["submodules"] resolver = "2" diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index a6e5720..6ec94f1 100644 
--- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rpxy" -version = "0.7.0" +version = "0.7.0-alpha.0" authors = ["Jun Kurihara"] homepage = "https://github.com/junkurihara/rust-rpxy" repository = "https://github.com/junkurihara/rust-rpxy" @@ -12,23 +12,25 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3-quinn", "cache"] +default = ["http3-quinn", "cache", "native-tls-backend"] http3-quinn = ["rpxy-lib/http3-quinn"] http3-s2n = ["rpxy-lib/http3-s2n"] +native-tls-backend = ["rpxy-lib/native-tls-backend"] +# Not yet implemented +rustls-backend = ["rpxy-lib/rustls-backend"] cache = ["rpxy-lib/cache"] native-roots = ["rpxy-lib/native-roots"] [dependencies] rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [ "sticky-cookie", - "native-tls-backend", ] } anyhow = "1.0.75" rustc-hash = "1.1.0" serde = { version = "1.0.193", default-features = false, features = ["derive"] } derive_builder = "0.12.0" -tokio = { version = "1.34.0", default-features = false, features = [ +tokio = { version = "1.35.0", default-features = false, features = [ "net", "rt-multi-thread", "time", diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index f30f4bb..22e091f 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rpxy-lib" -version = "0.7.0" +version = "0.7.0-alpha.0" authors = ["Jun Kurihara"] homepage = "https://github.com/junkurihara/rust-rpxy" repository = "https://github.com/junkurihara/rust-rpxy" @@ -21,11 +21,11 @@ http3-s2n = [ "s2n-quic-rustls", "s2n-quic-h3", ] +cache = ["http-cache-semantics", "lru", "sha2", "base64"] sticky-cookie = ["base64", "sha2", "chrono"] native-tls-backend = ["hyper-tls"] -rustls-backend = [] -cache = ["http-cache-semantics", "lru", "sha2", "base64"] -native-roots = [] #"hyper-rustls/native-tokio"] +rustls-backend = [] # not implemented yet +native-roots = [] 
#"hyper-rustls/native-tokio"] # not implemented yet [dependencies] rand = "0.8.5" @@ -33,7 +33,7 @@ rustc-hash = "1.1.0" bytes = "1.5.0" derive_builder = "0.12.0" futures = { version = "0.3.29", features = ["alloc", "async-await"] } -tokio = { version = "1.34.0", default-features = false, features = [ +tokio = { version = "1.35.0", default-features = false, features = [ "net", "rt-multi-thread", "time", @@ -57,7 +57,10 @@ futures-util = { version = "0.3.29", default-features = false } futures-channel = { version = "0.3.29", default-features = false } # http client for upstream -hyper-tls = { version = "0.6.0", features = ["alpn"], optional = true } +hyper-tls = { version = "0.6.0", features = [ + "alpn", + "vendored", +], optional = true } # hyper-rustls = { version = "0.24.2", default-features = false, features = [ # "tokio-runtime", # "webpki-tokio", diff --git a/rpxy-lib/src/forwarder/cache/cache_error.rs b/rpxy-lib/src/forwarder/cache/cache_error.rs index 35eae83..341c928 100644 --- a/rpxy-lib/src/forwarder/cache/cache_error.rs +++ b/rpxy-lib/src/forwarder/cache/cache_error.rs @@ -1,6 +1,6 @@ use thiserror::Error; -pub type CacheResult = std::result::Result; +pub(crate) type CacheResult = std::result::Result; /// Describes things that can go wrong in the Rpxy #[derive(Debug, Error)] @@ -9,9 +9,6 @@ pub enum CacheError { #[error("Invalid null request and/or response")] NullRequestOrResponse, - #[error("Failed to write byte buffer")] - FailedToWriteByteBufferForCache, - #[error("Failed to acquire mutex lock for cache")] FailedToAcquiredMutexLockForCache, diff --git a/rpxy-lib/src/forwarder/cache/cache_main.rs b/rpxy-lib/src/forwarder/cache/cache_main.rs index 3c85d0c..f3fc463 100644 --- a/rpxy-lib/src/forwarder/cache/cache_main.rs +++ b/rpxy-lib/src/forwarder/cache/cache_main.rs @@ -30,7 +30,7 @@ use tokio::{ /* ---------------------------------------------- */ #[derive(Clone, Debug)] /// Cache main manager -pub struct RpxyCache { +pub(crate) struct RpxyCache { 
/// Inner lru cache manager storing http message caching policy inner: LruCacheManager, /// Managing cache file objects through RwLock's lock mechanism for file lock @@ -47,7 +47,7 @@ pub struct RpxyCache { impl RpxyCache { /// Generate cache storage - pub async fn new(globals: &Globals) -> Option { + pub(crate) async fn new(globals: &Globals) -> Option { if !globals.proxy_config.cache_enabled { return None; } @@ -80,7 +80,7 @@ impl RpxyCache { } /// Count cache entries - pub async fn count(&self) -> (usize, usize, usize) { + pub(crate) async fn count(&self) -> (usize, usize, usize) { let total = self.inner.count(); let file = self.file_store.count().await; let on_memory = total - file; @@ -88,7 +88,7 @@ impl RpxyCache { } /// Put response into the cache - pub async fn put( + pub(crate) async fn put( &self, uri: &hyper::Uri, mut body: Incoming, @@ -186,7 +186,7 @@ impl RpxyCache { } /// Get cached response - pub async fn get(&self, req: &Request) -> Option> { + pub(crate) async fn get(&self, req: &Request) -> Option> { debug!( "Current cache status: (total, on-memory, file) = {:?}", self.count().await @@ -394,7 +394,7 @@ impl FileStoreInner { #[derive(Clone, Debug)] /// Cache target in hybrid manner of on-memory and file system -pub enum CacheFileOrOnMemory { +pub(crate) enum CacheFileOrOnMemory { /// Pointer to the temporary cache file File(PathBuf), /// Cached body itself @@ -418,11 +418,11 @@ impl CacheFileOrOnMemory { /// Cache object definition struct CacheObject { /// Cache policy to determine if the stored cache can be used as a response to a new incoming request - pub policy: CachePolicy, + policy: CachePolicy, /// Cache target: on-memory object or temporary file - pub target: CacheFileOrOnMemory, + target: CacheFileOrOnMemory, /// SHA256 hash of target to strongly bind the cache metadata (this object) and file target - pub hash: Bytes, + hash: Bytes, } /* ---------------------------------------------- */ @@ -490,7 +490,7 @@ impl LruCacheManager { /* 
---------------------------------------------- */ /// Generate cache policy if the response is cacheable -pub fn get_policy_if_cacheable( +pub(crate) fn get_policy_if_cacheable( req: Option<&Request>, res: Option<&Response>, ) -> CacheResult> diff --git a/rpxy-lib/src/forwarder/cache/mod.rs b/rpxy-lib/src/forwarder/cache/mod.rs index cfe5a1b..076eaa3 100644 --- a/rpxy-lib/src/forwarder/cache/mod.rs +++ b/rpxy-lib/src/forwarder/cache/mod.rs @@ -2,4 +2,4 @@ mod cache_error; mod cache_main; pub use cache_error::CacheError; -pub use cache_main::{get_policy_if_cacheable, CacheFileOrOnMemory, RpxyCache}; +pub(crate) use cache_main::{get_policy_if_cacheable, RpxyCache}; diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index c8a8ec7..26c2276 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -121,7 +121,7 @@ where ::Error: Into>, { /// Build inner client with http - pub fn try_new(_globals: &Arc) -> RpxyResult { + pub async fn try_new(_globals: &Arc) -> RpxyResult { warn!( " -------------------------------------------------------------------------------------------------- @@ -134,6 +134,7 @@ Please enable native-tls-backend or rustls-backend feature to enable TLS support let mut http = HttpConnector::new(); http.set_reuse_address(true); let inner = Client::builder(executor).build::<_, B>(http); + let inner_h2 = inner.clone(); Ok(Self { inner, @@ -191,7 +192,7 @@ where #[cfg(feature = "rustls-backend")] /// Build forwarder with hyper-rustls (rustls) -impl Forwarder, B1> +impl Forwarder where B1: Body + Send + Unpin + 'static, ::Data: Send, diff --git a/rpxy-lib/src/hyper_ext/body_type.rs b/rpxy-lib/src/hyper_ext/body_type.rs index a143eac..ca44756 100644 --- a/rpxy-lib/src/hyper_ext/body_type.rs +++ b/rpxy-lib/src/hyper_ext/body_type.rs @@ -1,11 +1,12 @@ use super::body::IncomingLike; use crate::error::RpxyError; -use http_body_util::{combinators, BodyExt, Empty, Full}; -use hyper::body::{Body, Bytes, 
Incoming}; +use futures::channel::mpsc::UnboundedReceiver; +use http_body_util::{combinators, BodyExt, Empty, Full, StreamBody}; +use hyper::body::{Body, Bytes, Frame, Incoming}; use std::pin::Pin; /// Type for synthetic boxed body -pub(crate) type BoxBody = combinators::BoxBody; +pub type BoxBody = combinators::BoxBody; /// helper function to build a empty body pub(crate) fn empty() -> BoxBody { @@ -17,11 +18,12 @@ pub(crate) fn full(body: Bytes) -> BoxBody { Full::new(body).map_err(|never| match never {}).boxed() } +#[allow(unused)] /* ------------------------------------ */ /// Request body used in this project /// - Incoming: just a type that only forwards the downstream request body to upstream. /// - IncomingLike: a Incoming-like type in which channel is used -pub(crate) enum RequestBody { +pub enum RequestBody { Incoming(Incoming), IncomingLike(IncomingLike), } @@ -42,24 +44,16 @@ impl Body for RequestBody { } /* ------------------------------------ */ -#[cfg(feature = "cache")] -use futures::channel::mpsc::UnboundedReceiver; -#[cfg(feature = "cache")] -use http_body_util::StreamBody; -#[cfg(feature = "cache")] -use hyper::body::Frame; - -#[cfg(feature = "cache")] -pub(crate) type UnboundedStreamBody = StreamBody, hyper::Error>>>; +pub type UnboundedStreamBody = StreamBody, hyper::Error>>>; +#[allow(unused)] /// Response body use in this project /// - Incoming: just a type that only forwards the upstream response body to downstream. /// - Boxed: a type that is generated from cache or synthetic response body, e.g.,, small byte object. /// - Streamed: another type that is generated from stream, e.g., large byte object. 
-pub(crate) enum ResponseBody { +pub enum ResponseBody { Incoming(Incoming), Boxed(BoxBody), - #[cfg(feature = "cache")] Streamed(UnboundedStreamBody), } @@ -73,9 +67,7 @@ impl Body for ResponseBody { ) -> std::task::Poll, Self::Error>>> { match self.get_mut() { ResponseBody::Incoming(incoming) => Pin::new(incoming).poll_frame(cx), - #[cfg(feature = "cache")] ResponseBody::Boxed(boxed) => Pin::new(boxed).poll_frame(cx), - #[cfg(feature = "cache")] ResponseBody::Streamed(streamed) => Pin::new(streamed).poll_frame(cx), } .map_err(RpxyError::HyperBodyError) diff --git a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs index 46e572c..143b3e8 100644 --- a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs +++ b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs @@ -19,6 +19,7 @@ where // Functions to generate messages //////////////////////////////////////////////////// + #[allow(unused_variables)] /// Manipulate a response message sent from a backend application to forward downstream to a client. pub(super) fn generate_response_forwarded( &self, diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index a7c1ec8..3f19059 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -1,10 +1,12 @@ -mod proxy_h3; mod proxy_main; +mod socket; + +#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] +mod proxy_h3; #[cfg(feature = "http3-quinn")] mod proxy_quic_quinn; #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] mod proxy_quic_s2n; -mod socket; use crate::{ globals::Globals, From db658723373916d8754ded965acc1e430eee21ea Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Dec 2023 13:30:51 +0900 Subject: [PATCH 42/50] update docs. preparing 0.7.0-alpha.0. 
--- CHANGELOG.md | 2 +- TODO.md | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c0b2649..d3d55c1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,7 +4,7 @@ - Breaking: `hyper`-1.0 for both server and client modules. - Breaking: Remove `override_host` option in upstream options. Add a reverse option, i.e., `disable_override_host`. That is, `rpxy` always override the host header by the upstream hostname by default. -- Breaking: Introduced `hyper-tls-backend` feature to use the native TLS engine to access backend applications. +- Breaking: Introduced `native-tls-backend` feature to use the native TLS engine to access backend applications. - Redesigned: Cache structure is totally redesigned with more memory-efficient way to read from cache file, and more secure way to strongly bind memory-objects with files with hash values. - Redesigned: HTTP body handling flow is also redesigned with more memory-and-time efficient techniques without putting the whole objects on memory by using `futures::stream::Stream` and `futures::channel::mpsc` - Refactor: lots of minor improvements diff --git a/TODO.md b/TODO.md index 1a159fd..031c8aa 100644 --- a/TODO.md +++ b/TODO.md @@ -1,5 +1,6 @@ # TODO List +- Support of `rustls-0.22` along with `hyper-1.0`. Maybe `hyper-rustls` is the most difficult part. - [Done in 0.6.0] But we need more sophistication on `Forwarder` struct. ~~Fix strategy for `h2c` requests on forwarded requests upstream. This needs to update forwarder definition. Also, maybe forwarder would have a cache corresponding to the following task.~~ - [Initial implementation in v0.6.0] ~~**Cache option for the response with `Cache-Control: public` header directive ([#55](https://github.com/junkurihara/rust-rpxy/issues/55))**~~ Using `lru` crate might be inefficient in terms of the speed. 
- Consider more sophisticated architecture for cache From 47a3f4c301b818d38f99d8b67c80283c3b973b68 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Dec 2023 13:55:18 +0900 Subject: [PATCH 43/50] limit logs --- rpxy-bin/src/log.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/rpxy-bin/src/log.rs b/rpxy-bin/src/log.rs index fd7b5cb..978d686 100644 --- a/rpxy-bin/src/log.rs +++ b/rpxy-bin/src/log.rs @@ -12,10 +12,13 @@ pub fn init_logger() { .with_level(true) .compact(); - // This limits the logger to emits only rpxy crate + // This limits the logger to emits only proxy crate + let pkg_name = env!("CARGO_PKG_NAME").replace('-', "_"); // let level_string = std::env::var(EnvFilter::DEFAULT_ENV).unwrap_or_else(|_| "info".to_string()); - // let filter_layer = EnvFilter::new(format!("{}={}", env!("CARGO_PKG_NAME"), level_string)); - let filter_layer = EnvFilter::from_default_env(); + // let filter_layer = EnvFilter::new(format!("{}={}", pkg_name, level_string)); + let filter_layer = EnvFilter::try_from_default_env() + .unwrap_or_else(|_| EnvFilter::new("info")) + .add_directive(format!("{}=trace", pkg_name).parse().unwrap()); tracing_subscriber::registry() .with(format_layer) From 1a2a91325619e5056818f48c00244cf847c218d4 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Dec 2023 14:45:40 +0900 Subject: [PATCH 44/50] fix disableoverridehost option --- CHANGELOG.md | 2 +- config-example.toml | 4 ++-- rpxy-bin/src/log.rs | 10 +++++----- rpxy-lib/src/backend/upstream_opts.rs | 4 ++-- rpxy-lib/src/forwarder/cache/cache_main.rs | 2 +- .../src/message_handler/handler_manipulate_messages.rs | 7 ++++--- rpxy-lib/src/message_handler/utils_headers.rs | 9 +++++---- rpxy-lib/src/proxy/proxy_h3.rs | 4 ++-- 8 files changed, 22 insertions(+), 20 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d3d55c1..2a6deae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,7 @@ ## 0.7.0 (unreleased) - Breaking: `hyper`-1.0 for both server 
and client modules. -- Breaking: Remove `override_host` option in upstream options. Add a reverse option, i.e., `disable_override_host`. That is, `rpxy` always override the host header by the upstream hostname by default. +- Breaking: Remove `override_host` option in upstream options. Add a reverse option, i.e., `keep_original_host`. That is, `rpxy` always override the host header by the upstream hostname (backend uri host name) by default. If this reverse option specified, original `host` header is maintained or added from the value of url request line. - Breaking: Introduced `native-tls-backend` feature to use the native TLS engine to access backend applications. - Redesigned: Cache structure is totally redesigned with more memory-efficient way to read from cache file, and more secure way to strongly bind memory-objects with files with hash values. - Redesigned: HTTP body handling flow is also redesigned with more memory-and-time efficient techniques without putting the whole objects on memory by using `futures::stream::Stream` and `futures::channel::mpsc` diff --git a/config-example.toml b/config-example.toml index 458061c..7460c35 100644 --- a/config-example.toml +++ b/config-example.toml @@ -57,8 +57,8 @@ upstream = [ ] load_balance = "round_robin" # or "random" or "sticky" (sticky session) or "none" (fix to the first one, default) upstream_options = [ - "disable_override_host", # do not overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy) - "force_http2_upstream", # mutually exclusive with "force_http11_upstream" + "keep_original_host", # do not overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy) + "force_http2_upstream", # mutually exclusive with "force_http11_upstream" ] # Non-default destination in "localhost" app, which is routed by "path" diff --git a/rpxy-bin/src/log.rs b/rpxy-bin/src/log.rs index 978d686..0ca3b99 100644 --- a/rpxy-bin/src/log.rs +++ b/rpxy-bin/src/log.rs @@ -14,11 +14,11 @@ pub fn 
init_logger() { // This limits the logger to emits only proxy crate let pkg_name = env!("CARGO_PKG_NAME").replace('-', "_"); - // let level_string = std::env::var(EnvFilter::DEFAULT_ENV).unwrap_or_else(|_| "info".to_string()); - // let filter_layer = EnvFilter::new(format!("{}={}", pkg_name, level_string)); - let filter_layer = EnvFilter::try_from_default_env() - .unwrap_or_else(|_| EnvFilter::new("info")) - .add_directive(format!("{}=trace", pkg_name).parse().unwrap()); + let level_string = std::env::var(EnvFilter::DEFAULT_ENV).unwrap_or_else(|_| "info".to_string()); + let filter_layer = EnvFilter::new(format!("{}={}", pkg_name, level_string)); + // let filter_layer = EnvFilter::try_from_default_env() + // .unwrap_or_else(|_| EnvFilter::new("info")) + // .add_directive(format!("{}=trace", pkg_name).parse().unwrap()); tracing_subscriber::registry() .with(format_layer) diff --git a/rpxy-lib/src/backend/upstream_opts.rs b/rpxy-lib/src/backend/upstream_opts.rs index f19acb4..c4c3db5 100644 --- a/rpxy-lib/src/backend/upstream_opts.rs +++ b/rpxy-lib/src/backend/upstream_opts.rs @@ -2,7 +2,7 @@ use crate::error::*; #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub enum UpstreamOption { - DisableOverrideHost, + KeepOriginalHost, UpgradeInsecureRequests, ForceHttp11Upstream, ForceHttp2Upstream, @@ -12,7 +12,7 @@ impl TryFrom<&str> for UpstreamOption { type Error = RpxyError; fn try_from(val: &str) -> RpxyResult { match val { - "diaable_override_host" => Ok(Self::DisableOverrideHost), + "keep_original_host" => Ok(Self::KeepOriginalHost), "upgrade_insecure_requests" => Ok(Self::UpgradeInsecureRequests), "force_http11_upstream" => Ok(Self::ForceHttp11Upstream), "force_http2_upstream" => Ok(Self::ForceHttp2Upstream), diff --git a/rpxy-lib/src/forwarder/cache/cache_main.rs b/rpxy-lib/src/forwarder/cache/cache_main.rs index f3fc463..02aec93 100644 --- a/rpxy-lib/src/forwarder/cache/cache_main.rs +++ b/rpxy-lib/src/forwarder/cache/cache_main.rs @@ -135,7 +135,7 @@ impl RpxyCache { 
.map(|f| { if f.is_data() { let data_bytes = f.data_ref().unwrap().clone(); - debug!("cache data bytes of {} bytes", data_bytes.len()); + // debug!("cache data bytes of {} bytes", data_bytes.len()); // We do not use stream-type buffering since it needs to lock file during operation. buf.extend(data_bytes.as_ref()); } diff --git a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs index 143b3e8..529a17a 100644 --- a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs +++ b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs @@ -85,14 +85,14 @@ where } }; - let uri = req.uri().to_string(); + let original_uri = req.uri().to_string(); let headers = req.headers_mut(); // delete headers specified in header.connection remove_connection_header(headers); // delete hop headers including header.connection remove_hop_header(headers); // X-Forwarded-For - add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?; + add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &original_uri)?; // Add te: trailer if te_trailer if contains_te_trailers { @@ -106,6 +106,7 @@ where .headers_mut() .insert(header::HOST, HeaderValue::from_str(&org_host)?); }; + let original_host_header = req.headers().get(header::HOST).unwrap().clone(); ///////////////////////////////////////////// // Fix unique upstream destination since there could be multiple ones. 
@@ -135,7 +136,7 @@ where // by default, host header is overwritten with upstream hostname override_host_header(headers, &upstream_chosen.uri)?; // apply upstream options to header - apply_upstream_options_to_header(headers, upstream_candidates)?; + apply_upstream_options_to_header(headers, &original_host_header, upstream_candidates)?; // update uri in request ensure!( diff --git a/rpxy-lib/src/message_handler/utils_headers.rs b/rpxy-lib/src/message_handler/utils_headers.rs index 32bc7f3..df2d57b 100644 --- a/rpxy-lib/src/message_handler/utils_headers.rs +++ b/rpxy-lib/src/message_handler/utils_headers.rs @@ -105,17 +105,18 @@ pub(super) fn override_host_header(headers: &mut HeaderMap, upstream_base_uri: & /// Apply options to request header, which are specified in the configuration pub(super) fn apply_upstream_options_to_header( headers: &mut HeaderMap, + original_host_header: &HeaderValue, // _client_addr: &SocketAddr, upstream: &UpstreamCandidates, // _upstream_base_uri: &Uri, ) -> Result<()> { for opt in upstream.options.iter() { match opt { - UpstreamOption::DisableOverrideHost => { - // simply remove HOST header value + UpstreamOption::KeepOriginalHost => { + // revert hostname headers - .remove(header::HOST) - .ok_or_else(|| anyhow!("Failed to remove host header in disable_override_host option"))?; + .insert(header::HOST, original_host_header.to_owned()) + .ok_or_else(|| anyhow!("Failed to revert host header in keep_original_host option"))?; } UpstreamOption::UpgradeInsecureRequests => { // add upgrade-insecure-requests in request header if not exist diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 0295430..7e02f32 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -168,11 +168,11 @@ where if frame.is_data() { let data = frame.into_data().unwrap_or_default(); - debug!("Write data to HTTP/3 stream"); + // debug!("Write data to HTTP/3 stream"); send_stream.send_data(data).await?; } else if 
frame.is_trailers() { let trailers = frame.into_trailers().unwrap_or_default(); - debug!("Write trailer to HTTP/3 stream"); + // debug!("Write trailer to HTTP/3 stream"); send_stream.send_trailers(trailers).await?; } } From d85d7e6c390e5e6c2febca02c3c50ab7b9a8a937 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Dec 2023 15:36:00 +0900 Subject: [PATCH 45/50] use hyper::upgrade::on --- rpxy-lib/src/message_handler/handler_main.rs | 21 +++++++++++--------- rpxy-lib/src/message_handler/http_result.rs | 13 ++++++------ 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/rpxy-lib/src/message_handler/handler_main.rs b/rpxy-lib/src/message_handler/handler_main.rs index ceb5db4..5fb978d 100644 --- a/rpxy-lib/src/message_handler/handler_main.rs +++ b/rpxy-lib/src/message_handler/handler_main.rs @@ -143,7 +143,8 @@ where // Upgrade in request header let upgrade_in_request = extract_upgrade(req.headers()); - let request_upgraded = req.extensions_mut().remove::(); + // let request_upgraded = req.extensions_mut().remove::(); + let req_on_upgrade = hyper::upgrade::on(&mut req); // Build request from destination information let _context = match self.generate_request_forwarded( @@ -209,19 +210,21 @@ where upgrade_in_response, upgrade_in_request ))); } - let Some(request_upgraded) = request_upgraded else { - return Err(HttpError::NoUpgradeExtensionInRequest); - }; - let Some(onupgrade) = res_backend.extensions_mut().remove::() else { - return Err(HttpError::NoUpgradeExtensionInResponse); - }; + // let Some(request_upgraded) = request_upgraded else { + // return Err(HttpError::NoUpgradeExtensionInRequest); + // }; + + // let Some(onupgrade) = res_backend.extensions_mut().remove::() else { + // return Err(HttpError::NoUpgradeExtensionInResponse); + // }; + let res_on_upgrade = hyper::upgrade::on(&mut res_backend); self.globals.runtime_handle.spawn(async move { - let mut response_upgraded = TokioIo::new(onupgrade.await.map_err(|e| { + let mut response_upgraded = 
TokioIo::new(res_on_upgrade.await.map_err(|e| { error!("Failed to upgrade response: {}", e); RpxyError::FailedToUpgradeResponse(e.to_string()) })?); - let mut request_upgraded = TokioIo::new(request_upgraded.await.map_err(|e| { + let mut request_upgraded = TokioIo::new(req_on_upgrade.await.map_err(|e| { error!("Failed to upgrade request: {}", e); RpxyError::FailedToUpgradeRequest(e.to_string()) })?); diff --git a/rpxy-lib/src/message_handler/http_result.rs b/rpxy-lib/src/message_handler/http_result.rs index ec48200..691c087 100644 --- a/rpxy-lib/src/message_handler/http_result.rs +++ b/rpxy-lib/src/message_handler/http_result.rs @@ -32,11 +32,10 @@ pub enum HttpError { #[error("Failed to upgrade connection: {0}")] FailedToUpgrade(String), - #[error("Request does not have an upgrade extension")] - NoUpgradeExtensionInRequest, - #[error("Response does not have an upgrade extension")] - NoUpgradeExtensionInResponse, - + // #[error("Request does not have an upgrade extension")] + // NoUpgradeExtensionInRequest, + // #[error("Response does not have an upgrade extension")] + // NoUpgradeExtensionInResponse, #[error(transparent)] Other(#[from] anyhow::Error), } @@ -54,8 +53,8 @@ impl From for StatusCode { HttpError::FailedToAddSetCookeInResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToUpgrade(_) => StatusCode::INTERNAL_SERVER_ERROR, - HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST, - HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY, + // HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST, + // HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY, _ => StatusCode::INTERNAL_SERVER_ERROR, } } From 1c1b50d213e80df6256c293d71125eb4f1be2dcf Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Dec 2023 16:29:51 +0900 Subject: [PATCH 46/50] limit upgrade only for http1.1 request --- 
rpxy-lib/src/message_handler/handler_main.rs | 6 ++++++ rpxy-lib/src/message_handler/handler_manipulate_messages.rs | 5 ++++- rpxy-lib/src/message_handler/http_result.rs | 6 +++--- 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/rpxy-lib/src/message_handler/handler_main.rs b/rpxy-lib/src/message_handler/handler_main.rs index 5fb978d..c46ac85 100644 --- a/rpxy-lib/src/message_handler/handler_main.rs +++ b/rpxy-lib/src/message_handler/handler_main.rs @@ -143,6 +143,12 @@ where // Upgrade in request header let upgrade_in_request = extract_upgrade(req.headers()); + if upgrade_in_request.is_some() && req.version() != http::Version::HTTP_11 { + return Err(HttpError::FailedToUpgrade(format!( + "Unsupported HTTP version: {:?}", + req.version() + ))); + } // let request_upgraded = req.extensions_mut().remove::(); let req_on_upgrade = hyper::upgrade::on(&mut req); diff --git a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs index 529a17a..a0b37e0 100644 --- a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs +++ b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs @@ -177,8 +177,11 @@ where .headers_mut() .insert(header::CONNECTION, HeaderValue::from_static("upgrade")); } + if upgrade.is_none() { + // can update request line i.e., http version, only if not upgrade (http 1.1) + update_request_line(req, upstream_chosen, upstream_candidates)?; + } - update_request_line(req, upstream_chosen, upstream_candidates)?; Ok(context) } diff --git a/rpxy-lib/src/message_handler/http_result.rs b/rpxy-lib/src/message_handler/http_result.rs index 691c087..98cdb45 100644 --- a/rpxy-lib/src/message_handler/http_result.rs +++ b/rpxy-lib/src/message_handler/http_result.rs @@ -8,8 +8,8 @@ pub(crate) type HttpResult = std::result::Result; /// Describes things that can go wrong in the forwarder #[derive(Debug, Error)] pub enum HttpError { - #[error("No host is give nin request header")] - 
NoHostInRequestHeader, + // #[error("No host is given in request header")] + // NoHostInRequestHeader, #[error("Invalid host in request header")] InvalidHostInRequestHeader, #[error("SNI and Host header mismatch")] @@ -43,7 +43,7 @@ pub enum HttpError { impl From for StatusCode { fn from(e: HttpError) -> StatusCode { match e { - HttpError::NoHostInRequestHeader => StatusCode::BAD_REQUEST, + // HttpError::NoHostInRequestHeader => StatusCode::BAD_REQUEST, HttpError::InvalidHostInRequestHeader => StatusCode::BAD_REQUEST, HttpError::SniHostInconsistency => StatusCode::MISDIRECTED_REQUEST, HttpError::NoMatchingBackendApp => StatusCode::SERVICE_UNAVAILABLE, From 78a5487293ce1e2c9c89ba97ba069086f818e429 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Dec 2023 16:50:49 +0900 Subject: [PATCH 47/50] add unstable build for testing --- .github/workflows/release_docker.yml | 18 ++++++++++++++++++ rpxy-lib/src/constants.rs | 2 ++ rpxy-lib/src/proxy/proxy_h3.rs | 11 +++++++---- rpxy-lib/src/proxy/proxy_main.rs | 15 +++++++------ 4 files changed, 36 insertions(+), 10 deletions(-) diff --git a/.github/workflows/release_docker.yml b/.github/workflows/release_docker.yml index 391d7cd..076e6f7 100644 --- a/.github/workflows/release_docker.yml +++ b/.github/workflows/release_docker.yml @@ -2,6 +2,7 @@ name: Release - Build and publish docker, and trigger package release on: push: branches: + - "feat/*" - "develop" pull_request: types: [closed] @@ -135,6 +136,23 @@ jobs: # platforms: linux/amd64 # labels: ${{ steps.meta.outputs.labels }} + - name: Unstable build and push from develop branch + if: ${{ (github.ref_name == 'feat/*') && (github.event_name == 'push') }} + uses: docker/build-push-action@v5 + with: + context: . 
+ build-args: ${{ matrix.build-args }} + push: true + tags: | + ${{ env.GHCR }}/${{ env.GHCR_IMAGE_NAME }}:unstable${{ matrix.tags-suffix }} + ${{ env.DH_REGISTRY_NAME }}:unstable${{ matrix.tags-suffix }} + build-contexts: ${{ matrix.build-contexts }} + file: ${{ matrix.dockerfile }} + cache-from: type=gha,scope=rpxy-unstable-${{ matrix.target }} + cache-to: type=gha,mode=max,scope=rpxy-unstable-${{ matrix.target }} + platforms: ${{ matrix.platforms }} + labels: ${{ steps.meta.outputs.labels }} + - name: Nightly build and push from develop branch if: ${{ (github.ref_name == 'develop') && (github.event_name == 'push') }} uses: docker/build-push-action@v5 diff --git a/rpxy-lib/src/constants.rs b/rpxy-lib/src/constants.rs index acc9381..064f5fd 100644 --- a/rpxy-lib/src/constants.rs +++ b/rpxy-lib/src/constants.rs @@ -12,6 +12,8 @@ pub const MAX_CONCURRENT_STREAMS: u32 = 64; pub const CERTS_WATCH_DELAY_SECS: u32 = 60; pub const LOAD_CERTS_ONLY_WHEN_UPDATED: bool = true; +pub const CONNECTION_TIMEOUT_SEC: u64 = 30; // timeout to serve a connection. this might limit the max length of response. 
+ // #[cfg(feature = "http3")] // pub const H3_RESPONSE_BUF_SIZE: usize = 65_536; // 64KB // #[cfg(feature = "http3")] diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 7e02f32..90457e3 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -1,5 +1,6 @@ use super::proxy_main::Proxy; use crate::{ + constants::CONNECTION_TIMEOUT_SEC, crypto::CryptoSource, error::*, hyper_ext::body::{IncomingLike, RequestBody}, @@ -10,7 +11,7 @@ use bytes::{Buf, Bytes}; use http::{Request, Response}; use http_body_util::BodyExt; use hyper_util::client::legacy::connect::Connect; -use std::net::SocketAddr; +use std::{net::SocketAddr, time::Duration}; #[cfg(feature = "http3-quinn")] use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; @@ -70,9 +71,11 @@ where let self_inner = self.clone(); let tls_server_name_inner = tls_server_name.clone(); self.globals.runtime_handle.spawn(async move { - if let Err(e) = self_inner - .h3_serve_stream(req, stream, client_addr, tls_server_name_inner) - .await + if let Err(e) = tokio::time::timeout( + Duration::from_secs(CONNECTION_TIMEOUT_SEC + 1), // just in case... 
+ self_inner.h3_serve_stream(req, stream, client_addr, tls_server_name_inner), + ) + .await { warn!("HTTP/3 error on serve stream: {}", e); } diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index 4fea840..61176f3 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -1,6 +1,6 @@ use super::socket::bind_tcp_socket; use crate::{ - constants::TLS_HANDSHAKE_TIMEOUT_SEC, + constants::{CONNECTION_TIMEOUT_SEC, TLS_HANDSHAKE_TIMEOUT_SEC}, crypto::{CryptoSource, ServerCrypto, SniServerCryptoMap}, error::*, globals::Globals, @@ -88,9 +88,11 @@ where let message_handler_clone = self.message_handler.clone(); let tls_enabled = self.tls_enabled; let listening_on = self.listening_on; + let timeout_sec = Duration::from_secs(CONNECTION_TIMEOUT_SEC + 1); // just in case... self.globals.runtime_handle.clone().spawn(async move { - server_clone - .serve_connection_with_upgrades( + timeout( + timeout_sec + Duration::from_secs(1), // just in case... 
+ server_clone.serve_connection_with_upgrades( stream, service_fn(move |req: Request| { serve_request( @@ -102,9 +104,10 @@ where tls_server_name.clone(), ) }), - ) - .await - .ok(); + ), + ) + .await + .ok(); request_count.decrement(); debug!("Request processed: current # {}", request_count.current()); From f509bc70b022a8442aecb834387b8523b576bdd4 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Dec 2023 16:54:14 +0900 Subject: [PATCH 48/50] add unstable build for testing --- .github/workflows/release_docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release_docker.yml b/.github/workflows/release_docker.yml index 076e6f7..ab78117 100644 --- a/.github/workflows/release_docker.yml +++ b/.github/workflows/release_docker.yml @@ -137,7 +137,7 @@ jobs: # labels: ${{ steps.meta.outputs.labels }} - name: Unstable build and push from develop branch - if: ${{ (github.ref_name == 'feat/*') && (github.event_name == 'push') }} + if: ${{ startsWith(github.ref_name, 'feat/') && (github.event_name == 'push') }} uses: docker/build-push-action@v5 with: context: . 
From 7e37b8177186fded4f82aa2c35411c07975090f4 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Dec 2023 16:55:13 +0900 Subject: [PATCH 49/50] limit amd64 for unstable build --- .github/workflows/release_docker.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release_docker.yml b/.github/workflows/release_docker.yml index ab78117..4fe9d14 100644 --- a/.github/workflows/release_docker.yml +++ b/.github/workflows/release_docker.yml @@ -150,7 +150,7 @@ jobs: file: ${{ matrix.dockerfile }} cache-from: type=gha,scope=rpxy-unstable-${{ matrix.target }} cache-to: type=gha,mode=max,scope=rpxy-unstable-${{ matrix.target }} - platforms: ${{ matrix.platforms }} + platforms: linux/amd64 labels: ${{ steps.meta.outputs.labels }} - name: Nightly build and push from develop branch From 3dc20af5d8717e4990ea9aa17e63f2c9d97aec6c Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 15 Dec 2023 20:16:48 +0900 Subject: [PATCH 50/50] fix: fix bug for http upgrade (due to hyper-util bug) --- rpxy-lib/Cargo.toml | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 22e091f..3e3b1b0 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -52,15 +52,22 @@ thiserror = "1.0.50" http = "1.0.0" http-body-util = "0.1.0" hyper = { version = "1.0.1", default-features = false } -hyper-util = { version = "0.1.1", features = ["full"] } +# hyper-util = { version = "0.1.1", features = ["full"] } +hyper-util = { git = "https://github.com/junkurihara/hyper-util", features = [ + "full", +], rev = "99409f5c4059633b7e2fa8b9c2e6c110b0f2f64b" } futures-util = { version = "0.3.29", default-features = false } futures-channel = { version = "0.3.29", default-features = false } # http client for upstream -hyper-tls = { version = "0.6.0", features = [ +hyper-tls = { git = "https://github.com/junkurihara/hyper-tls", features = [ "alpn", "vendored", -], optional = true } +], rev = 
"06fb462ee67ec349936ceb64849d64d05e58458a", optional = true } +# hyper-tls = { version = "0.6.0", features = [ +# "alpn", +# "vendored", +# ], optional = true } # hyper-rustls = { version = "0.24.2", default-features = false, features = [ # "tokio-runtime", # "webpki-tokio",