From 8f77ce94473d92098925a53e76b9bdc5067d696a Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Tue, 28 Nov 2023 18:04:12 +0900 Subject: [PATCH] wip: implement switching protocols (http upgrade) --- rpxy-lib/src/error.rs | 6 ++ rpxy-lib/src/message_handle/handler_main.rs | 90 ++++++++++----------- rpxy-lib/src/message_handle/http_log.rs | 10 +-- rpxy-lib/src/message_handle/http_result.rs | 10 +++ 4 files changed, 65 insertions(+), 51 deletions(-) diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index da65234..d7123a3 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -71,6 +71,12 @@ pub enum RpxyError { // Handler errors #[error("Failed to build message handler: {0}")] FailedToBuildMessageHandler(#[from] crate::message_handle::HttpMessageHandlerBuilderError), + #[error("Failed to upgrade request: {0}")] + FailedToUpgradeRequest(String), + #[error("Failed to upgrade response: {0}")] + FailedToUpgradeResponse(String), + #[error("Failed to copy bidirectional for upgraded connections: {0}")] + FailedToCopyBidirectional(String), // Upstream connection setting errors #[error("Unsupported upstream option")] diff --git a/rpxy-lib/src/message_handle/handler_main.rs b/rpxy-lib/src/message_handle/handler_main.rs index 666faa3..922e024 100644 --- a/rpxy-lib/src/message_handle/handler_main.rs +++ b/rpxy-lib/src/message_handle/handler_main.rs @@ -16,7 +16,9 @@ use crate::{ }; use derive_builder::Builder; use http::{Request, Response, StatusCode}; +use hyper_util::rt::TokioIo; use std::{net::SocketAddr, sync::Arc}; +use tokio::io::copy_bidirectional; #[allow(dead_code)] #[derive(Debug)] @@ -51,13 +53,15 @@ where /// Responsible to passthrough responses from backend applications or generate synthetic error responses. pub async fn handle_request( &self, - mut req: Request>, + req: Request>, client_addr: SocketAddr, // For access control listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, ) -> RpxyResult>> { + // preparing log data let mut log_data = HttpMessageLog::from(&req); + log_data.client_addr(&client_addr); let http_result = self .handle_request_inner( @@ -96,10 +100,6 @@ where tls_enabled: bool, tls_server_name: Option, ) -> HttpResult>> { - // preparing log data - let mut log_data = HttpMessageLog::from(&req); - log_data.client_addr(&client_addr); - // Here we start to inspect and parse with server_name let server_name = req .inspect_parse_host() @@ -207,48 +207,46 @@ where return Ok(res_backend); } - // // Handle StatusCode::SWITCHING_PROTOCOLS in response - // let upgrade_in_response = extract_upgrade(res_backend.headers()); - // let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) - // { - // u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() - // } else { - // false - // }; - // if !should_upgrade { - // error!( - // "Backend tried to switch to protocol {:?} when {:?} was requested", - // upgrade_in_response, upgrade_in_request - // ); - // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - // } - // let Some(request_upgraded) = request_upgraded else { - // error!("Request does not have an upgrade extension"); - // return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); - // }; - // let Some(onupgrade) = res_backend.extensions_mut().remove::() else { - // error!("Response does not have an upgrade extension"); - // return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - // }; + // Handle StatusCode::SWITCHING_PROTOCOLS in response + let upgrade_in_response = extract_upgrade(res_backend.headers()); + let should_upgrade = match (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) { + (Some(u_req), Some(u_res)) => u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase(), + _ => false, + }; - // self.globals.runtime_handle.spawn(async move { - // let mut response_upgraded = onupgrade.await.map_err(|e| { - // error!("Failed to upgrade response: {}", e); - // RpxyError::Hyper(e) - // })?; - // let mut request_upgraded = request_upgraded.await.map_err(|e| { - // error!("Failed to upgrade request: {}", e); - // RpxyError::Hyper(e) - // })?; - // copy_bidirectional(&mut response_upgraded, &mut request_upgraded) - // .await - // .map_err(|e| { - // error!("Coping between upgraded connections failed: {}", e); - // RpxyError::Io(e) - // })?; - // Ok(()) as Result<()> - // }); - // log_data.status_code(&res_backend.status()).output(); + if !should_upgrade { + error!( + "Backend tried to switch to protocol {:?} when {:?} was requested", + upgrade_in_response, upgrade_in_request + ); + return Err(HttpError::FailedToUpgrade); + } + let Some(request_upgraded) = request_upgraded else { + error!("Request does not have an upgrade extension"); + return Err(HttpError::NoUpgradeExtensionInRequest); + }; + let Some(onupgrade) = res_backend.extensions_mut().remove::() else { + error!("Response does not have an upgrade extension"); + return Err(HttpError::NoUpgradeExtensionInResponse); + }; + + self.globals.runtime_handle.spawn(async move { + let mut response_upgraded = TokioIo::new(onupgrade.await.map_err(|e| { + error!("Failed to upgrade response: {}", e); + RpxyError::FailedToUpgradeResponse(e.to_string()) + })?); + let mut request_upgraded = TokioIo::new(request_upgraded.await.map_err(|e| { + error!("Failed to upgrade request: {}", e); + RpxyError::FailedToUpgradeRequest(e.to_string()) + })?); + copy_bidirectional(&mut response_upgraded, &mut request_upgraded) + .await + .map_err(|e| { + error!("Coping between upgraded connections failed: {}", e); + RpxyError::FailedToCopyBidirectional(e.to_string()) + })?; + Ok(()) as RpxyResult<()> + }); Ok(res_backend) } diff --git a/rpxy-lib/src/message_handle/http_log.rs b/rpxy-lib/src/message_handle/http_log.rs index 7056c80..acda9f0 100644 --- a/rpxy-lib/src/message_handle/http_log.rs +++ b/rpxy-lib/src/message_handle/http_log.rs @@ -11,7 +11,7 @@ pub struct HttpMessageLog { pub method: String, pub host: String, pub p_and_q: String, - pub version: hyper::Version, + pub version: http::Version, pub uri_scheme: String, pub uri_host: String, pub ua: String, @@ -20,8 +20,8 @@ pub struct HttpMessageLog { pub upstream: String, } -impl From<&hyper::Request> for HttpMessageLog { - fn from(req: &hyper::Request) -> Self { +impl From<&http::Request> for HttpMessageLog { + fn from(req: &http::Request) -> Self { let header_mapper = |v: header::HeaderName| { req .headers() @@ -59,7 +59,7 @@ impl HttpMessageLog { // self.tls_server_name = tls_server_name.to_string(); // self // } - pub fn status_code(&mut self, status_code: &hyper::StatusCode) -> &mut Self { + pub fn status_code(&mut self, status_code: &http::StatusCode) -> &mut Self { self.status = status_code.to_string(); self } @@ -67,7 +67,7 @@ impl HttpMessageLog { self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); self } - pub fn upstream(&mut self, upstream: &hyper::Uri) -> &mut Self { + pub fn upstream(&mut self, upstream: &http::Uri) -> &mut Self { self.upstream = upstream.to_string(); self } diff --git a/rpxy-lib/src/message_handle/http_result.rs b/rpxy-lib/src/message_handle/http_result.rs index 07a0034..dc77565 100644 --- a/rpxy-lib/src/message_handle/http_result.rs +++ b/rpxy-lib/src/message_handle/http_result.rs @@ -28,6 +28,13 @@ pub enum HttpError { #[error("Failed to generated downstream response: {0}")] FailedToGenerateDownstreamResponse(String), + #[error("Failed to upgrade connection")] + FailedToUpgrade, + #[error("Request does not have an upgrade extension")] + NoUpgradeExtensionInRequest, + #[error("Response does not have an upgrade extension")] + NoUpgradeExtensionInResponse, + #[error(transparent)] Other(#[from] anyhow::Error), } @@ -44,6 +51,9 @@ impl From for StatusCode { HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToAddSetCookeInResponse => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, + HttpError::FailedToUpgrade => StatusCode::INTERNAL_SERVER_ERROR, + HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST, + HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY, _ => StatusCode::INTERNAL_SERVER_ERROR, } }