wip: implement switching protocols (http upgrade)

This commit is contained in:
Jun Kurihara 2023-11-28 18:04:12 +09:00
commit 8f77ce9447
No known key found for this signature in database
GPG key ID: D992B3E3DE1DED23
4 changed files with 65 additions and 51 deletions

View file

@ -71,6 +71,12 @@ pub enum RpxyError {
// Handler errors // Handler errors
#[error("Failed to build message handler: {0}")] #[error("Failed to build message handler: {0}")]
FailedToBuildMessageHandler(#[from] crate::message_handle::HttpMessageHandlerBuilderError), FailedToBuildMessageHandler(#[from] crate::message_handle::HttpMessageHandlerBuilderError),
#[error("Failed to upgrade request: {0}")]
FailedToUpgradeRequest(String),
#[error("Failed to upgrade response: {0}")]
FailedToUpgradeResponse(String),
#[error("Failed to copy bidirectional for upgraded connections: {0}")]
FailedToCopyBidirectional(String),
// Upstream connection setting errors // Upstream connection setting errors
#[error("Unsupported upstream option")] #[error("Unsupported upstream option")]

View file

@ -16,7 +16,9 @@ use crate::{
}; };
use derive_builder::Builder; use derive_builder::Builder;
use http::{Request, Response, StatusCode}; use http::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use std::{net::SocketAddr, sync::Arc}; use std::{net::SocketAddr, sync::Arc};
use tokio::io::copy_bidirectional;
#[allow(dead_code)] #[allow(dead_code)]
#[derive(Debug)] #[derive(Debug)]
@ -51,13 +53,15 @@ where
/// Responsible to passthrough responses from backend applications or generate synthetic error responses. /// Responsible to passthrough responses from backend applications or generate synthetic error responses.
pub async fn handle_request( pub async fn handle_request(
&self, &self,
mut req: Request<IncomingOr<IncomingLike>>, req: Request<IncomingOr<IncomingLike>>,
client_addr: SocketAddr, // For access control client_addr: SocketAddr, // For access control
listen_addr: SocketAddr, listen_addr: SocketAddr,
tls_enabled: bool, tls_enabled: bool,
tls_server_name: Option<ServerName>, tls_server_name: Option<ServerName>,
) -> RpxyResult<Response<IncomingOr<BoxBody>>> { ) -> RpxyResult<Response<IncomingOr<BoxBody>>> {
// preparing log data
let mut log_data = HttpMessageLog::from(&req); let mut log_data = HttpMessageLog::from(&req);
log_data.client_addr(&client_addr);
let http_result = self let http_result = self
.handle_request_inner( .handle_request_inner(
@ -96,10 +100,6 @@ where
tls_enabled: bool, tls_enabled: bool,
tls_server_name: Option<ServerName>, tls_server_name: Option<ServerName>,
) -> HttpResult<Response<IncomingOr<BoxBody>>> { ) -> HttpResult<Response<IncomingOr<BoxBody>>> {
// preparing log data
let mut log_data = HttpMessageLog::from(&req);
log_data.client_addr(&client_addr);
// Here we start to inspect and parse with server_name // Here we start to inspect and parse with server_name
let server_name = req let server_name = req
.inspect_parse_host() .inspect_parse_host()
@ -207,48 +207,46 @@ where
return Ok(res_backend); return Ok(res_backend);
} }
// // Handle StatusCode::SWITCHING_PROTOCOLS in response // Handle StatusCode::SWITCHING_PROTOCOLS in response
// let upgrade_in_response = extract_upgrade(res_backend.headers()); let upgrade_in_response = extract_upgrade(res_backend.headers());
// let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) let should_upgrade = match (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) {
// { (Some(u_req), Some(u_res)) => u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase(),
// u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() _ => false,
// } else { };
// false
// };
// if !should_upgrade {
// error!(
// "Backend tried to switch to protocol {:?} when {:?} was requested",
// upgrade_in_response, upgrade_in_request
// );
// return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data);
// }
// let Some(request_upgraded) = request_upgraded else {
// error!("Request does not have an upgrade extension");
// return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data);
// };
// let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else {
// error!("Response does not have an upgrade extension");
// return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data);
// };
// self.globals.runtime_handle.spawn(async move { if !should_upgrade {
// let mut response_upgraded = onupgrade.await.map_err(|e| { error!(
// error!("Failed to upgrade response: {}", e); "Backend tried to switch to protocol {:?} when {:?} was requested",
// RpxyError::Hyper(e) upgrade_in_response, upgrade_in_request
// })?; );
// let mut request_upgraded = request_upgraded.await.map_err(|e| { return Err(HttpError::FailedToUpgrade);
// error!("Failed to upgrade request: {}", e); }
// RpxyError::Hyper(e) let Some(request_upgraded) = request_upgraded else {
// })?; error!("Request does not have an upgrade extension");
// copy_bidirectional(&mut response_upgraded, &mut request_upgraded) return Err(HttpError::NoUpgradeExtensionInRequest);
// .await };
// .map_err(|e| { let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else {
// error!("Coping between upgraded connections failed: {}", e); error!("Response does not have an upgrade extension");
// RpxyError::Io(e) return Err(HttpError::NoUpgradeExtensionInResponse);
// })?; };
// Ok(()) as Result<()>
// }); self.globals.runtime_handle.spawn(async move {
// log_data.status_code(&res_backend.status()).output(); let mut response_upgraded = TokioIo::new(onupgrade.await.map_err(|e| {
error!("Failed to upgrade response: {}", e);
RpxyError::FailedToUpgradeResponse(e.to_string())
})?);
let mut request_upgraded = TokioIo::new(request_upgraded.await.map_err(|e| {
error!("Failed to upgrade request: {}", e);
RpxyError::FailedToUpgradeRequest(e.to_string())
})?);
copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
.await
.map_err(|e| {
error!("Coping between upgraded connections failed: {}", e);
RpxyError::FailedToCopyBidirectional(e.to_string())
})?;
Ok(()) as RpxyResult<()>
});
Ok(res_backend) Ok(res_backend)
} }

View file

@ -11,7 +11,7 @@ pub struct HttpMessageLog {
pub method: String, pub method: String,
pub host: String, pub host: String,
pub p_and_q: String, pub p_and_q: String,
pub version: hyper::Version, pub version: http::Version,
pub uri_scheme: String, pub uri_scheme: String,
pub uri_host: String, pub uri_host: String,
pub ua: String, pub ua: String,
@ -20,8 +20,8 @@ pub struct HttpMessageLog {
pub upstream: String, pub upstream: String,
} }
impl<T> From<&hyper::Request<T>> for HttpMessageLog { impl<T> From<&http::Request<T>> for HttpMessageLog {
fn from(req: &hyper::Request<T>) -> Self { fn from(req: &http::Request<T>) -> Self {
let header_mapper = |v: header::HeaderName| { let header_mapper = |v: header::HeaderName| {
req req
.headers() .headers()
@ -59,7 +59,7 @@ impl HttpMessageLog {
// self.tls_server_name = tls_server_name.to_string(); // self.tls_server_name = tls_server_name.to_string();
// self // self
// } // }
pub fn status_code(&mut self, status_code: &hyper::StatusCode) -> &mut Self { pub fn status_code(&mut self, status_code: &http::StatusCode) -> &mut Self {
self.status = status_code.to_string(); self.status = status_code.to_string();
self self
} }
@ -67,7 +67,7 @@ impl HttpMessageLog {
self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string();
self self
} }
pub fn upstream(&mut self, upstream: &hyper::Uri) -> &mut Self { pub fn upstream(&mut self, upstream: &http::Uri) -> &mut Self {
self.upstream = upstream.to_string(); self.upstream = upstream.to_string();
self self
} }

View file

@ -28,6 +28,13 @@ pub enum HttpError {
#[error("Failed to generated downstream response: {0}")] #[error("Failed to generated downstream response: {0}")]
FailedToGenerateDownstreamResponse(String), FailedToGenerateDownstreamResponse(String),
#[error("Failed to upgrade connection")]
FailedToUpgrade,
#[error("Request does not have an upgrade extension")]
NoUpgradeExtensionInRequest,
#[error("Response does not have an upgrade extension")]
NoUpgradeExtensionInResponse,
#[error(transparent)] #[error(transparent)]
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
@ -44,6 +51,9 @@ impl From<HttpError> for StatusCode {
HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR,
HttpError::FailedToAddSetCookeInResponse => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToAddSetCookeInResponse => StatusCode::INTERNAL_SERVER_ERROR,
HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR,
HttpError::FailedToUpgrade => StatusCode::INTERNAL_SERVER_ERROR,
HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST,
HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY,
_ => StatusCode::INTERNAL_SERVER_ERROR, _ => StatusCode::INTERNAL_SERVER_ERROR,
} }
} }