diff --git a/Cargo.toml b/Cargo.toml index 7f16166..b64f78e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -18,7 +18,7 @@ http3 = ["quinn", "h3", "h3-quinn"] [dependencies] env_logger = "0.9.1" anyhow = "1.0.65" -clap = { version = "4.0.12", features = ["std", "cargo", "wrap_help"] } +clap = { version = "4.0.13", features = ["std", "cargo", "wrap_help"] } futures = { version = "0.3.24", features = ["alloc", "async-await"] } hyper = { version = "0.14.20", default-features = false, features = [ "server", diff --git a/src/error.rs b/src/error.rs index 3771627..7a39c9e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -43,3 +43,12 @@ pub enum RpxyError { #[error(transparent)] Other(#[from] anyhow::Error), } + +#[derive(Debug, Error, Clone)] +pub enum ClientCertsError { + #[error("TLS Client Certificate is Required for Given SNI: {0}")] + ClientCertRequired(String), + + #[error("Inconsistent TLS Client Certificate for Given SNI: {0}")] + InconsistentClientCert(String), +} diff --git a/src/handler/handler_main.rs b/src/handler/handler_main.rs index b6d0146..251f898 100644 --- a/src/handler/handler_main.rs +++ b/src/handler/handler_main.rs @@ -35,11 +35,26 @@ where listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, + tls_client_auth_result: Option>, ) -> Result> { //////// let mut log_data = MessageLog::from(&req); log_data.client_addr(&client_addr); ////// + // First check client auth result if exist + if let Some(res) = tls_client_auth_result { + match res { + Err(ClientCertsError::ClientCertRequired(_)) => { + // Client cert is required for the TLS server name + return self.return_with_error_log(StatusCode::FORBIDDEN, &mut log_data); + } + Err(ClientCertsError::InconsistentClientCert(_)) => { + // Client cert provided was inconsistent to the TLS server name + return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); + } + _ => (), + } + } // Here we start to handle with server_name let server_name = if let Ok(v) = req.parse_host() { diff --git a/src/proxy/proxy_client_cert.rs b/src/proxy/proxy_client_cert.rs index c77b0f9..aa212c1 100644 --- a/src/proxy/proxy_client_cert.rs +++ b/src/proxy/proxy_client_cert.rs @@ -7,9 +7,9 @@ use x509_parser::prelude::*; // TODO: consider move this function to the layer of handle_request (L7) to return 403 pub(super) fn check_client_authentication( client_certs: Option<&[Certificate]>, - client_certs_setting_for_sni: Option<&HashSet>>, -) -> Result<()> { - let client_ca_keyids_set = match client_certs_setting_for_sni { + client_ca_keyids_set_for_sni: Option<&HashSet>>, +) -> std::result::Result<(), ClientCertsError> { + let client_ca_keyids_set = match client_ca_keyids_set_for_sni { Some(c) => c, None => { // No client cert settings for given server name @@ -23,9 +23,8 @@ pub(super) fn check_client_authentication( c } None => { - // TODO: return 403 here error!("Client certificate is needed for given server name"); - return Err(RpxyError::Proxy( + return Err(ClientCertsError::ClientCertRequired( "Client certificate is needed for given server name".to_string(), )); } @@ -45,9 +44,8 @@ pub(super) fn check_client_authentication( }); if !match_server_crypto_and_client_cert { - // TODO: return 403 here error!("Inconsistent client certificate was provided for SNI"); - return Err(RpxyError::Proxy( + return Err(ClientCertsError::InconsistentClientCert( "Inconsistent client certificate was provided for SNI".to_string(), )); } diff --git a/src/proxy/proxy_h3.rs b/src/proxy/proxy_h3.rs index 63369cf..d0aaf5d 100644 --- a/src/proxy/proxy_h3.rs +++ b/src/proxy/proxy_h3.rs @@ -21,7 +21,6 @@ where match conn.await { Ok(new_conn) => { // Check client certificates - // TODO: consider move this function to the layer of handle_request (L7) to return 403 let cc = { // https://docs.rs/quinn/latest/quinn/struct.Connection.html let client_certs_setting_for_sni = sni_cc_map.get(&tls_server_name); @@ -34,7 +33,8 @@ where }; (client_certs, client_certs_setting_for_sni) }; - check_client_authentication(cc.0.as_ref().map(AsRef::as_ref), cc.1)?; + // TODO: pass this value to the layer of handle_request (L7) to return 403 + let tls_client_auth_result = check_client_authentication(cc.0.as_ref().map(AsRef::as_ref), cc.1); let mut h3_conn = h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn)).await?; info!( @@ -61,10 +61,17 @@ where let self_inner = self.clone(); let tls_server_name_inner = tls_server_name.clone(); + let tls_client_auth_result_inner = tls_client_auth_result.clone(); self.globals.runtime_handle.spawn(async move { if let Err(e) = timeout( self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2 - self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), + self_inner.stream_serve_h3( + req, + stream, + client_addr, + tls_server_name_inner, + tls_client_auth_result_inner, + ), ) .await { @@ -90,6 +97,7 @@ where stream: RequestStream, client_addr: SocketAddr, tls_server_name: ServerNameBytesExp, + tls_client_auth_result: std::result::Result<(), ClientCertsError>, ) -> Result<()> where S: BidiStream + Send + 'static, @@ -141,6 +149,7 @@ where self.listening_on, self.tls_enabled, Some(tls_server_name), + Some(tls_client_auth_result), ) .await?; diff --git a/src/proxy/proxy_main.rs b/src/proxy/proxy_main.rs index 964ad70..8501902 100644 --- a/src/proxy/proxy_main.rs +++ b/src/proxy/proxy_main.rs @@ -51,6 +51,7 @@ where server: Http, peer_addr: SocketAddr, tls_server_name: Option, + tls_client_auth_result: Option>, ) where I: AsyncRead + AsyncWrite + Send + Unpin + 'static, { @@ -74,6 +75,7 @@ where self.listening_on, self.tls_enabled, tls_server_name.clone(), + tls_client_auth_result.clone(), ) }), ) @@ -92,7 +94,9 @@ where let tcp_listener = TcpListener::bind(&self.listening_on).await?; info!("Start TCP proxy serving with HTTP request for configured host names"); while let Ok((stream, _client_addr)) = tcp_listener.accept().await { - self.clone().client_serve(stream, server.clone(), _client_addr, None); + self + .clone() + .client_serve(stream, server.clone(), _client_addr, None, None); } Ok(()) as Result<()> }; diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 6e3200c..b90c8a3 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -92,13 +92,14 @@ where } else { ////////////////////////////// // Check client certificate - // TODO: consider move this function to the layer of handle_request (L7) to return 403 let client_certs = conn.peer_certificates(); - let client_certs_setting_for_sni = sni_cc_map.get(&server_name.clone().unwrap()); - check_client_authentication(client_certs, client_certs_setting_for_sni)?; + let client_ca_keyids_set_for_sni = sni_cc_map.get(&server_name.clone().unwrap()); + // TODO: pass this value to the layer of handle_request (L7) to return 403 + let client_certs_auth_result = check_client_authentication(client_certs, client_ca_keyids_set_for_sni); ////////////////////////////// // this immediately spawns another future to actually handle stream. so it is okay to introduce timeout for handshake. - self_inner.client_serve(stream, server_clone, client_addr, server_name); // TODO: don't want to pass copied value... + // TODO: don't want to pass copied value... + self_inner.client_serve(stream, server_clone, client_addr, server_name, Some(client_certs_auth_result)); Ok(()) } };