use hyper::upgrade::on

This commit is contained in:
Jun Kurihara 2023-12-15 15:36:00 +09:00
commit d85d7e6c39
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
2 changed files with 18 additions and 16 deletions

View file

@ -143,7 +143,8 @@ where
// Upgrade in request header // Upgrade in request header
let upgrade_in_request = extract_upgrade(req.headers()); let upgrade_in_request = extract_upgrade(req.headers());
let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>(); // let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();
let req_on_upgrade = hyper::upgrade::on(&mut req);
// Build request from destination information // Build request from destination information
let _context = match self.generate_request_forwarded( let _context = match self.generate_request_forwarded(
@ -209,19 +210,21 @@ where
upgrade_in_response, upgrade_in_request upgrade_in_response, upgrade_in_request
))); )));
} }
let Some(request_upgraded) = request_upgraded else { // let Some(request_upgraded) = request_upgraded else {
return Err(HttpError::NoUpgradeExtensionInRequest); // return Err(HttpError::NoUpgradeExtensionInRequest);
}; // };
let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else {
return Err(HttpError::NoUpgradeExtensionInResponse); // let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else {
}; // return Err(HttpError::NoUpgradeExtensionInResponse);
// };
let res_on_upgrade = hyper::upgrade::on(&mut res_backend);
self.globals.runtime_handle.spawn(async move { self.globals.runtime_handle.spawn(async move {
let mut response_upgraded = TokioIo::new(onupgrade.await.map_err(|e| { let mut response_upgraded = TokioIo::new(res_on_upgrade.await.map_err(|e| {
error!("Failed to upgrade response: {}", e); error!("Failed to upgrade response: {}", e);
RpxyError::FailedToUpgradeResponse(e.to_string()) RpxyError::FailedToUpgradeResponse(e.to_string())
})?); })?);
let mut request_upgraded = TokioIo::new(request_upgraded.await.map_err(|e| { let mut request_upgraded = TokioIo::new(req_on_upgrade.await.map_err(|e| {
error!("Failed to upgrade request: {}", e); error!("Failed to upgrade request: {}", e);
RpxyError::FailedToUpgradeRequest(e.to_string()) RpxyError::FailedToUpgradeRequest(e.to_string())
})?); })?);

View file

@ -32,11 +32,10 @@ pub enum HttpError {
#[error("Failed to upgrade connection: {0}")] #[error("Failed to upgrade connection: {0}")]
FailedToUpgrade(String), FailedToUpgrade(String),
#[error("Request does not have an upgrade extension")] // #[error("Request does not have an upgrade extension")]
NoUpgradeExtensionInRequest, // NoUpgradeExtensionInRequest,
#[error("Response does not have an upgrade extension")] // #[error("Response does not have an upgrade extension")]
NoUpgradeExtensionInResponse, // NoUpgradeExtensionInResponse,
#[error(transparent)] #[error(transparent)]
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
@ -54,8 +53,8 @@ impl From<HttpError> for StatusCode {
HttpError::FailedToAddSetCookeInResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToAddSetCookeInResponse(_) => StatusCode::INTERNAL_SERVER_ERROR,
HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR,
HttpError::FailedToUpgrade(_) => StatusCode::INTERNAL_SERVER_ERROR, HttpError::FailedToUpgrade(_) => StatusCode::INTERNAL_SERVER_ERROR,
HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST, // HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST,
HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY, // HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY,
_ => StatusCode::INTERNAL_SERVER_ERROR, _ => StatusCode::INTERNAL_SERVER_ERROR,
} }
} }