add upgrade handling for response

This commit is contained in:
Jun Kurihara 2022-06-25 00:14:48 -04:00
commit 9c34c259ef
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03

View file

@ -7,6 +7,7 @@ use hyper::{
Body, Request, Response, StatusCode, Uri, Body, Request, Response, StatusCode, Uri,
}; };
use std::net::SocketAddr; use std::net::SocketAddr;
use tokio::io::copy_bidirectional;
const HOP_HEADERS: &[&str] = &[ const HOP_HEADERS: &[&str] = &[
"connection", "connection",
@ -26,7 +27,7 @@ where
{ {
pub async fn handle_request( pub async fn handle_request(
self, self,
req: Request<Body>, mut req: Request<Body>,
client_addr: SocketAddr, // アクセス制御用 client_addr: SocketAddr, // アクセス制御用
) -> Result<Response<Body>> { ) -> Result<Response<Body>> {
debug!("Handling request: {:?}", req); debug!("Handling request: {:?}", req);
@ -61,6 +62,7 @@ where
// Upgrade in request header // Upgrade in request header
let upgrade_in_request = extract_upgrade(req.headers()); let upgrade_in_request = extract_upgrade(req.headers());
let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();
// Build request from destination information // Build request from destination information
let req_forwarded = if let Ok(req) = generate_request_forwarded( let req_forwarded = if let Ok(req) = generate_request_forwarded(
@ -68,7 +70,7 @@ where
req, req,
destination_scheme_host, destination_scheme_host,
path_and_query, path_and_query,
upgrade_in_request, &upgrade_in_request,
) { ) {
req req
} else { } else {
@ -87,8 +89,35 @@ where
}; };
debug!("Response from backend: {:?}", res_backend.status()); debug!("Response from backend: {:?}", res_backend.status());
// TODO: Handle StatusCode::SWITCHING_PROTOCOLS if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS {
// Handle StatusCode::SWITCHING_PROTOCOLS in response
let upgrade_in_response = extract_upgrade(res_backend.headers());
if upgrade_in_request == upgrade_in_response {
if let Some(request_upgraded) = request_upgraded {
let mut response_upgraded = res_backend
.extensions_mut()
.remove::<hyper::upgrade::OnUpgrade>()
.expect("Response does not have an upgrade extension")
.await?;
tokio::spawn(async move {
let mut request_upgraded = request_upgraded.await.expect("Failed to upgrade request");
copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
.await
.expect("Coping between upgraded connections failed");
});
Ok(res_backend)
} else {
error!("Request does not have an upgrade extension");
http_error(StatusCode::BAD_GATEWAY)
}
} else {
error!(
"Backend tried to switch to protocol {:?} when {:?} was requested",
upgrade_in_response, upgrade_in_request
);
http_error(StatusCode::BAD_GATEWAY)
}
} else {
// Generate response to client // Generate response to client
if generate_response_forwarded(&mut res_backend).is_ok() { if generate_response_forwarded(&mut res_backend).is_ok() {
Ok(res_backend) Ok(res_backend)
@ -96,6 +125,7 @@ where
http_error(StatusCode::BAD_GATEWAY) http_error(StatusCode::BAD_GATEWAY)
} }
} }
}
} }
fn generate_response_forwarded<B: core::fmt::Debug>(response: &mut Response<B>) -> Result<()> { fn generate_response_forwarded<B: core::fmt::Debug>(response: &mut Response<B>) -> Result<()> {
@ -110,7 +140,7 @@ fn generate_request_forwarded<B: core::fmt::Debug>(
mut req: Request<B>, mut req: Request<B>,
destination_scheme_host: Uri, destination_scheme_host: Uri,
path_and_query: String, path_and_query: String,
upgrade: Option<String>, upgrade: &Option<String>,
) -> Result<Request<B>> { ) -> Result<Request<B>> {
debug!("Generate request to be forwarded"); debug!("Generate request to be forwarded");