diff --git a/src/proxy/proxy_handler.rs b/src/proxy/proxy_handler.rs
index d3926af..851e174 100644
--- a/src/proxy/proxy_handler.rs
+++ b/src/proxy/proxy_handler.rs
@@ -7,6 +7,7 @@ use hyper::{
Body, Request, Response, StatusCode, Uri,
};
use std::net::SocketAddr;
+use tokio::io::copy_bidirectional;
const HOP_HEADERS: &[&str] = &[
"connection",
@@ -26,7 +27,7 @@ where
{
pub async fn handle_request(
self,
- req: Request
,
+ mut req: Request,
client_addr: SocketAddr, // アクセス制御用
) -> Result> {
debug!("Handling request: {:?}", req);
@@ -61,6 +62,7 @@ where
// Upgrade in request header
let upgrade_in_request = extract_upgrade(req.headers());
+ let request_upgraded = req.extensions_mut().remove::();
// Build request from destination information
let req_forwarded = if let Ok(req) = generate_request_forwarded(
@@ -68,7 +70,7 @@ where
req,
destination_scheme_host,
path_and_query,
- upgrade_in_request,
+ &upgrade_in_request,
) {
req
} else {
@@ -87,13 +89,41 @@ where
};
debug!("Response from backend: {:?}", res_backend.status());
- // TODO: Handle StatusCode::SWITCHING_PROTOCOLS
-
- // Generate response to client
- if generate_response_forwarded(&mut res_backend).is_ok() {
- Ok(res_backend)
+ if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS {
+ // Handle StatusCode::SWITCHING_PROTOCOLS in response
+ let upgrade_in_response = extract_upgrade(res_backend.headers());
+ if upgrade_in_request == upgrade_in_response {
+ if let Some(request_upgraded) = request_upgraded {
+ let mut response_upgraded = res_backend
+ .extensions_mut()
+ .remove::()
+ .expect("Response does not have an upgrade extension")
+ .await?;
+ tokio::spawn(async move {
+ let mut request_upgraded = request_upgraded.await.expect("Failed to upgrade request");
+ copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
+ .await
+ .expect("Coping between upgraded connections failed");
+ });
+ Ok(res_backend)
+ } else {
+ error!("Request does not have an upgrade extension");
+ http_error(StatusCode::BAD_GATEWAY)
+ }
+ } else {
+ error!(
+ "Backend tried to switch to protocol {:?} when {:?} was requested",
+ upgrade_in_response, upgrade_in_request
+ );
+ http_error(StatusCode::BAD_GATEWAY)
+ }
} else {
- http_error(StatusCode::BAD_GATEWAY)
+ // Generate response to client
+ if generate_response_forwarded(&mut res_backend).is_ok() {
+ Ok(res_backend)
+ } else {
+ http_error(StatusCode::BAD_GATEWAY)
+ }
}
}
}
@@ -110,7 +140,7 @@ fn generate_request_forwarded(
mut req: Request,
destination_scheme_host: Uri,
path_and_query: String,
- upgrade: Option,
+ upgrade: &Option,
) -> Result> {
debug!("Generate request to be forwarded");