From 9c34c259ef90a863165e4fc4e3501993cdc7ee56 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 25 Jun 2022 00:14:48 -0400 Subject: [PATCH] add upgrade handling for response --- src/proxy/proxy_handler.rs | 48 +++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 9 deletions(-) diff --git a/src/proxy/proxy_handler.rs b/src/proxy/proxy_handler.rs index d3926af..851e174 100644 --- a/src/proxy/proxy_handler.rs +++ b/src/proxy/proxy_handler.rs @@ -7,6 +7,7 @@ use hyper::{ Body, Request, Response, StatusCode, Uri, }; use std::net::SocketAddr; +use tokio::io::copy_bidirectional; const HOP_HEADERS: &[&str] = &[ "connection", @@ -26,7 +27,7 @@ where { pub async fn handle_request( self, - req: Request, + mut req: Request, client_addr: SocketAddr, // アクセス制御用 ) -> Result> { debug!("Handling request: {:?}", req); @@ -61,6 +62,7 @@ where // Upgrade in request header let upgrade_in_request = extract_upgrade(req.headers()); + let request_upgraded = req.extensions_mut().remove::(); // Build request from destination information let req_forwarded = if let Ok(req) = generate_request_forwarded( @@ -68,7 +70,7 @@ where req, destination_scheme_host, path_and_query, - upgrade_in_request, + &upgrade_in_request, ) { req } else { @@ -87,13 +89,41 @@ where }; debug!("Response from backend: {:?}", res_backend.status()); - // TODO: Handle StatusCode::SWITCHING_PROTOCOLS - - // Generate response to client - if generate_response_forwarded(&mut res_backend).is_ok() { - Ok(res_backend) + if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS { + // Handle StatusCode::SWITCHING_PROTOCOLS in response + let upgrade_in_response = extract_upgrade(res_backend.headers()); + if upgrade_in_request == upgrade_in_response { + if let Some(request_upgraded) = request_upgraded { + let mut response_upgraded = res_backend + .extensions_mut() + .remove::() + .expect("Response does not have an upgrade extension") + .await?; + tokio::spawn(async move { + let mut request_upgraded = request_upgraded.await.expect("Failed to upgrade request"); + copy_bidirectional(&mut response_upgraded, &mut request_upgraded) + .await + .expect("Coping between upgraded connections failed"); + }); + Ok(res_backend) + } else { + error!("Request does not have an upgrade extension"); + http_error(StatusCode::BAD_GATEWAY) + } + } else { + error!( + "Backend tried to switch to protocol {:?} when {:?} was requested", + upgrade_in_response, upgrade_in_request + ); + http_error(StatusCode::BAD_GATEWAY) + } } else { - http_error(StatusCode::BAD_GATEWAY) + // Generate response to client + if generate_response_forwarded(&mut res_backend).is_ok() { + Ok(res_backend) + } else { + http_error(StatusCode::BAD_GATEWAY) + } } } } @@ -110,7 +140,7 @@ fn generate_request_forwarded( mut req: Request, destination_scheme_host: Uri, path_and_query: String, - upgrade: Option, + upgrade: &Option, ) -> Result> { debug!("Generate request to be forwarded");