diff --git a/rpxy-acme/Cargo.toml b/rpxy-acme/Cargo.toml index d1d318c..e996225 100644 --- a/rpxy-acme/Cargo.toml +++ b/rpxy-acme/Cargo.toml @@ -30,4 +30,5 @@ rustls-acme = { path = "../submodules/rustls-acme/", default-features = false, f "aws-lc-rs", ] } tokio = { version = "1.39.1", default-features = false } +tokio-util = { version = "0.7.11", default-features = false } tokio-stream = { version = "0.1.15", default-features = false } diff --git a/rpxy-acme/src/manager.rs b/rpxy-acme/src/manager.rs index e54731c..9242743 100644 --- a/rpxy-acme/src/manager.rs +++ b/rpxy-acme/src/manager.rs @@ -74,7 +74,7 @@ impl AcmeManager { /// Returns a Vec> as a tasks handles and a map of domain to ServerConfig for challenge. pub fn spawn_manager_tasks( &self, - term_notify: Option>, + cancel_token: Option, ) -> (Vec>, HashMap>) { let rustls_client_config = rustls::ClientConfig::builder() .dangerous() // The `Verifier` we're using is actually safe @@ -96,7 +96,7 @@ impl AcmeManager { let mut state = config.state(); server_configs_for_challenge.insert(domain.to_ascii_lowercase(), state.challenge_rustls_config()); self.runtime_handle.spawn({ - let term_notify = term_notify.clone(); + let cancel_token = cancel_token.clone(); async move { info!("rpxy ACME manager task for {domain} started"); // infinite loop unless the return value is None @@ -112,10 +112,10 @@ impl AcmeManager { } } }; - if let Some(notify) = term_notify.as_ref() { + if let Some(cancel_token) = cancel_token.as_ref() { tokio::select! { _ = task => {}, - _ = notify.notified() => { info!("rpxy ACME manager task for {domain} terminated") } + _ = cancel_token.cancelled() => { info!("rpxy ACME manager task for {domain} terminated") } } } else { task.await; diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml index 73911d8..e42212e 100644 --- a/rpxy-bin/Cargo.toml +++ b/rpxy-bin/Cargo.toml @@ -39,6 +39,7 @@ tokio = { version = "1.39.1", default-features = false, features = [ "sync", "macros", ] } +tokio-util = { version = "0.7.11", default-features = false } async-trait = "0.1.81" futures-util = { version = "0.3.30", default-features = false } diff --git a/rpxy-bin/src/main.rs b/rpxy-bin/src/main.rs index eff2648..d7f7121 100644 --- a/rpxy-bin/src/main.rs +++ b/rpxy-bin/src/main.rs @@ -127,11 +127,11 @@ async fn rpxy_service_with_watcher( .await .map_err(|e| anyhow!("Invalid cert configuration: {e}"))?; - // Notifier for proxy service termination - let term_notify = std::sync::Arc::new(tokio::sync::Notify::new()); - // Continuous monitoring loop { + // Notifier for proxy service termination + let cancel_token = tokio_util::sync::CancellationToken::new(); + let (cert_service, cert_rx) = cert_service_and_rx .as_ref() .map(|(s, r)| (Some(s), Some(r))) @@ -140,7 +140,7 @@ async fn rpxy_service_with_watcher( #[cfg(feature = "acme")] let (acme_join_handles, server_config_acme_challenge) = acme_manager .as_ref() - .map(|m| m.spawn_manager_tasks(Some(term_notify.clone()))) + .map(|m| m.spawn_manager_tasks(Some(cancel_token.child_token()))) .unwrap_or((vec![], Default::default())); let rpxy_opts = { @@ -150,7 +150,7 @@ async fn rpxy_service_with_watcher( .app_config_list(app_conf.clone()) .cert_rx(cert_rx.cloned()) .runtime_handle(runtime_handle.clone()) - .term_notify(Some(term_notify.clone())) + .cancel_token(Some(cancel_token.child_token())) .server_configs_acme_challenge(std::sync::Arc::new(server_config_acme_challenge)) .build(); @@ -216,8 +216,7 @@ async fn rpxy_service_with_watcher( } info!("Configuration updated. Terminate all spawned services and force to re-bind TCP/UDP sockets"); - term_notify.notify_waiters(); - // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + cancel_token.cancel(); } else => break } diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml index 3386d32..452a322 100644 --- a/rpxy-lib/Cargo.toml +++ b/rpxy-lib/Cargo.toml @@ -43,6 +43,7 @@ tokio = { version = "1.39.1", default-features = false, features = [ "macros", "fs", ] } +tokio-util = { version = "0.7.11", default-features = false } pin-project-lite = "0.2.14" async-trait = "0.1.81" diff --git a/rpxy-lib/src/globals.rs b/rpxy-lib/src/globals.rs index 8c5e093..de1983d 100644 --- a/rpxy-lib/src/globals.rs +++ b/rpxy-lib/src/globals.rs @@ -2,6 +2,7 @@ use crate::{constants::*, count::RequestCount}; use hot_reload::ReloaderReceiver; use rpxy_certs::ServerCryptoBase; use std::{net::SocketAddr, sync::Arc, time::Duration}; +use tokio_util::sync::CancellationToken; /// Global object containing proxy configurations and shared object like counters. /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks. @@ -13,7 +14,7 @@ pub struct Globals { /// Shared context - Async task runtime handler pub runtime_handle: tokio::runtime::Handle, /// Shared context - Notify object to stop async tasks - pub term_notify: Option>, + pub cancel_token: Option, /// Shared context - Certificate reloader service receiver // TODO: newer one pub cert_reloader_rx: Option>, diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index 9dd78da..3a6097d 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -23,6 +23,7 @@ use futures::future::select_all; use hot_reload::ReloaderReceiver; use rpxy_certs::ServerCryptoBase; use std::sync::Arc; +use tokio_util::sync::CancellationToken; /* ------------------------------------------------ */ pub use crate::globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri}; @@ -42,7 +43,7 @@ pub struct RpxyOptions { /// Async task runtime handler pub runtime_handle: tokio::runtime::Handle, /// Notify object to stop async tasks - pub term_notify: Option>, + pub cancel_token: Option, #[cfg(feature = "acme")] /// ServerConfig used for only ACME challenge for ACME domains @@ -56,7 +57,7 @@ pub async fn entrypoint( app_config_list, cert_rx, // TODO: runtime_handle, - term_notify, + cancel_token, #[cfg(feature = "acme")] server_configs_acme_challenge, }: &RpxyOptions, @@ -107,7 +108,7 @@ pub async fn entrypoint( proxy_config: proxy_config.clone(), request_count: Default::default(), runtime_handle: runtime_handle.clone(), - term_notify: term_notify.clone(), + cancel_token: cancel_token.clone(), cert_reloader_rx: cert_rx.clone(), #[cfg(feature = "acme")] diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index 3690d35..9be175d 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -312,13 +312,13 @@ where } }; - match &self.globals.term_notify { - Some(term) => { + match &self.globals.cancel_token { + Some(cancel_token) => { select! { _ = proxy_service.fuse() => { warn!("Proxy service got down"); } - _ = term.notified().fuse() => { + _ = cancel_token.cancelled().fuse() => { info!("Proxy service listening on {} receives term signal", self.listening_on); } }