update response to invalid client certificate or no client certificate

This commit is contained in:
Jun Kurihara 2022-10-12 15:16:40 +09:00
commit c765da33db
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
7 changed files with 52 additions and 16 deletions

View file

@ -18,7 +18,7 @@ http3 = ["quinn", "h3", "h3-quinn"]
[dependencies] [dependencies]
env_logger = "0.9.1" env_logger = "0.9.1"
anyhow = "1.0.65" anyhow = "1.0.65"
clap = { version = "4.0.12", features = ["std", "cargo", "wrap_help"] } clap = { version = "4.0.13", features = ["std", "cargo", "wrap_help"] }
futures = { version = "0.3.24", features = ["alloc", "async-await"] } futures = { version = "0.3.24", features = ["alloc", "async-await"] }
hyper = { version = "0.14.20", default-features = false, features = [ hyper = { version = "0.14.20", default-features = false, features = [
"server", "server",

View file

@ -43,3 +43,12 @@ pub enum RpxyError {
#[error(transparent)] #[error(transparent)]
Other(#[from] anyhow::Error), Other(#[from] anyhow::Error),
} }
#[derive(Debug, Error, Clone)]
pub enum ClientCertsError {
#[error("TLS Client Certificate is Required for Given SNI: {0}")]
ClientCertRequired(String),
#[error("Inconsistent TLS Client Certificate for Given SNI: {0}")]
InconsistentClientCert(String),
}

View file

@ -35,11 +35,26 @@ where
listen_addr: SocketAddr, listen_addr: SocketAddr,
tls_enabled: bool, tls_enabled: bool,
tls_server_name: Option<ServerNameBytesExp>, tls_server_name: Option<ServerNameBytesExp>,
tls_client_auth_result: Option<std::result::Result<(), ClientCertsError>>,
) -> Result<Response<Body>> { ) -> Result<Response<Body>> {
//////// ////////
let mut log_data = MessageLog::from(&req); let mut log_data = MessageLog::from(&req);
log_data.client_addr(&client_addr); log_data.client_addr(&client_addr);
////// //////
// First check client auth result if exist
if let Some(res) = tls_client_auth_result {
match res {
Err(ClientCertsError::ClientCertRequired(_)) => {
// Client cert is required for the TLS server name
return self.return_with_error_log(StatusCode::FORBIDDEN, &mut log_data);
}
Err(ClientCertsError::InconsistentClientCert(_)) => {
// Client cert provided was inconsistent to the TLS server name
return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data);
}
_ => (),
}
}
// Here we start to handle with server_name // Here we start to handle with server_name
let server_name = if let Ok(v) = req.parse_host() { let server_name = if let Ok(v) = req.parse_host() {

View file

@ -7,9 +7,9 @@ use x509_parser::prelude::*;
// TODO: consider move this function to the layer of handle_request (L7) to return 403 // TODO: consider move this function to the layer of handle_request (L7) to return 403
pub(super) fn check_client_authentication( pub(super) fn check_client_authentication(
client_certs: Option<&[Certificate]>, client_certs: Option<&[Certificate]>,
client_certs_setting_for_sni: Option<&HashSet<Vec<u8>>>, client_ca_keyids_set_for_sni: Option<&HashSet<Vec<u8>>>,
) -> Result<()> { ) -> std::result::Result<(), ClientCertsError> {
let client_ca_keyids_set = match client_certs_setting_for_sni { let client_ca_keyids_set = match client_ca_keyids_set_for_sni {
Some(c) => c, Some(c) => c,
None => { None => {
// No client cert settings for given server name // No client cert settings for given server name
@ -23,9 +23,8 @@ pub(super) fn check_client_authentication(
c c
} }
None => { None => {
// TODO: return 403 here
error!("Client certificate is needed for given server name"); error!("Client certificate is needed for given server name");
return Err(RpxyError::Proxy( return Err(ClientCertsError::ClientCertRequired(
"Client certificate is needed for given server name".to_string(), "Client certificate is needed for given server name".to_string(),
)); ));
} }
@ -45,9 +44,8 @@ pub(super) fn check_client_authentication(
}); });
if !match_server_crypto_and_client_cert { if !match_server_crypto_and_client_cert {
// TODO: return 403 here
error!("Inconsistent client certificate was provided for SNI"); error!("Inconsistent client certificate was provided for SNI");
return Err(RpxyError::Proxy( return Err(ClientCertsError::InconsistentClientCert(
"Inconsistent client certificate was provided for SNI".to_string(), "Inconsistent client certificate was provided for SNI".to_string(),
)); ));
} }

View file

@ -21,7 +21,6 @@ where
match conn.await { match conn.await {
Ok(new_conn) => { Ok(new_conn) => {
// Check client certificates // Check client certificates
// TODO: consider move this function to the layer of handle_request (L7) to return 403
let cc = { let cc = {
// https://docs.rs/quinn/latest/quinn/struct.Connection.html // https://docs.rs/quinn/latest/quinn/struct.Connection.html
let client_certs_setting_for_sni = sni_cc_map.get(&tls_server_name); let client_certs_setting_for_sni = sni_cc_map.get(&tls_server_name);
@ -34,7 +33,8 @@ where
}; };
(client_certs, client_certs_setting_for_sni) (client_certs, client_certs_setting_for_sni)
}; };
check_client_authentication(cc.0.as_ref().map(AsRef::as_ref), cc.1)?; // TODO: pass this value to the layer of handle_request (L7) to return 403
let tls_client_auth_result = check_client_authentication(cc.0.as_ref().map(AsRef::as_ref), cc.1);
let mut h3_conn = h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn)).await?; let mut h3_conn = h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn)).await?;
info!( info!(
@ -61,10 +61,17 @@ where
let self_inner = self.clone(); let self_inner = self.clone();
let tls_server_name_inner = tls_server_name.clone(); let tls_server_name_inner = tls_server_name.clone();
let tls_client_auth_result_inner = tls_client_auth_result.clone();
self.globals.runtime_handle.spawn(async move { self.globals.runtime_handle.spawn(async move {
if let Err(e) = timeout( if let Err(e) = timeout(
self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2 self_inner.globals.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), self_inner.stream_serve_h3(
req,
stream,
client_addr,
tls_server_name_inner,
tls_client_auth_result_inner,
),
) )
.await .await
{ {
@ -90,6 +97,7 @@ where
stream: RequestStream<S, Bytes>, stream: RequestStream<S, Bytes>,
client_addr: SocketAddr, client_addr: SocketAddr,
tls_server_name: ServerNameBytesExp, tls_server_name: ServerNameBytesExp,
tls_client_auth_result: std::result::Result<(), ClientCertsError>,
) -> Result<()> ) -> Result<()>
where where
S: BidiStream<Bytes> + Send + 'static, S: BidiStream<Bytes> + Send + 'static,
@ -141,6 +149,7 @@ where
self.listening_on, self.listening_on,
self.tls_enabled, self.tls_enabled,
Some(tls_server_name), Some(tls_server_name),
Some(tls_client_auth_result),
) )
.await?; .await?;

View file

@ -51,6 +51,7 @@ where
server: Http<LocalExecutor>, server: Http<LocalExecutor>,
peer_addr: SocketAddr, peer_addr: SocketAddr,
tls_server_name: Option<ServerNameBytesExp>, tls_server_name: Option<ServerNameBytesExp>,
tls_client_auth_result: Option<std::result::Result<(), ClientCertsError>>,
) where ) where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static, I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{ {
@ -74,6 +75,7 @@ where
self.listening_on, self.listening_on,
self.tls_enabled, self.tls_enabled,
tls_server_name.clone(), tls_server_name.clone(),
tls_client_auth_result.clone(),
) )
}), }),
) )
@ -92,7 +94,9 @@ where
let tcp_listener = TcpListener::bind(&self.listening_on).await?; let tcp_listener = TcpListener::bind(&self.listening_on).await?;
info!("Start TCP proxy serving with HTTP request for configured host names"); info!("Start TCP proxy serving with HTTP request for configured host names");
while let Ok((stream, _client_addr)) = tcp_listener.accept().await { while let Ok((stream, _client_addr)) = tcp_listener.accept().await {
self.clone().client_serve(stream, server.clone(), _client_addr, None); self
.clone()
.client_serve(stream, server.clone(), _client_addr, None, None);
} }
Ok(()) as Result<()> Ok(()) as Result<()>
}; };

View file

@ -92,13 +92,14 @@ where
} else { } else {
////////////////////////////// //////////////////////////////
// Check client certificate // Check client certificate
// TODO: consider move this function to the layer of handle_request (L7) to return 403
let client_certs = conn.peer_certificates(); let client_certs = conn.peer_certificates();
let client_certs_setting_for_sni = sni_cc_map.get(&server_name.clone().unwrap()); let client_ca_keyids_set_for_sni = sni_cc_map.get(&server_name.clone().unwrap());
check_client_authentication(client_certs, client_certs_setting_for_sni)?; // TODO: pass this value to the layer of handle_request (L7) to return 403
let client_certs_auth_result = check_client_authentication(client_certs, client_ca_keyids_set_for_sni);
////////////////////////////// //////////////////////////////
// this immediately spawns another future to actually handle stream. so it is okay to introduce timeout for handshake. // this immediately spawns another future to actually handle stream. so it is okay to introduce timeout for handshake.
self_inner.client_serve(stream, server_clone, client_addr, server_name); // TODO: don't want to pass copied value... // TODO: don't want to pass copied value...
self_inner.client_serve(stream, server_clone, client_addr, server_name, Some(client_certs_auth_result));
Ok(()) Ok(())
} }
}; };