fix: change tokio::sync::Notify to tokio_util::sync::CancellationToken

Jun Kurihara 2024-07-26 20:58:00 +09:00
commit 0950fdbd15
No known key found for this signature in database
GPG key ID: D992B3E3DE1DED23
8 changed files with 22 additions and 18 deletions
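
Background for the change, as a minimal sketch that is not taken from this repository (names are illustrative): unlike Notify::notify_waiters(), which only wakes tasks already parked on notified(), a CancellationToken stays cancelled once cancel() is called, so a task that starts waiting later still observes the shutdown, and child_token() lets one parent fan the signal out to many services.

use tokio_util::sync::CancellationToken;

#[tokio::main]
async fn main() {
    let parent = CancellationToken::new();
    let child = parent.child_token();

    // Cancel before anyone is waiting; the token remembers it.
    parent.cancel();

    // A task that starts waiting afterwards still sees the cancellation.
    // With Notify, a notify_waiters() call issued before notified() was awaited would be missed.
    tokio::spawn(async move {
        child.cancelled().await;
        println!("observed cancellation");
    })
    .await
    .unwrap();
}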

View file

@@ -30,4 +30,5 @@ rustls-acme = { path = "../submodules/rustls-acme/", default-features = false, f
   "aws-lc-rs",
 ] }
 tokio = { version = "1.39.1", default-features = false }
+tokio-util = { version = "0.7.11", default-features = false }
 tokio-stream = { version = "0.1.15", default-features = false }

View file

@@ -74,7 +74,7 @@ impl AcmeManager {
   /// Returns a Vec<JoinHandle<()>> as a tasks handles and a map of domain to ServerConfig for challenge.
   pub fn spawn_manager_tasks(
     &self,
-    term_notify: Option<Arc<tokio::sync::Notify>>,
+    cancel_token: Option<tokio_util::sync::CancellationToken>,
   ) -> (Vec<tokio::task::JoinHandle<()>>, HashMap<String, Arc<ServerConfig>>) {
     let rustls_client_config = rustls::ClientConfig::builder()
       .dangerous() // The `Verifier` we're using is actually safe
@@ -96,7 +96,7 @@ impl AcmeManager {
       let mut state = config.state();
       server_configs_for_challenge.insert(domain.to_ascii_lowercase(), state.challenge_rustls_config());
       self.runtime_handle.spawn({
-        let term_notify = term_notify.clone();
+        let cancel_token = cancel_token.clone();
         async move {
           info!("rpxy ACME manager task for {domain} started");
           // infinite loop unless the return value is None
@@ -112,10 +112,10 @@ impl AcmeManager {
             }
           }
         };
-        if let Some(notify) = term_notify.as_ref() {
+        if let Some(cancel_token) = cancel_token.as_ref() {
           tokio::select! {
             _ = task => {},
-            _ = notify.notified() => { info!("rpxy ACME manager task for {domain} terminated") }
+            _ = cancel_token.cancelled() => { info!("rpxy ACME manager task for {domain} terminated") }
           }
         } else {
           task.await;

View file

@@ -39,6 +39,7 @@ tokio = { version = "1.39.1", default-features = false, features = [
   "sync",
   "macros",
 ] }
+tokio-util = { version = "0.7.11", default-features = false }
 async-trait = "0.1.81"
 futures-util = { version = "0.3.30", default-features = false }

View file

@@ -127,11 +127,11 @@ async fn rpxy_service_with_watcher(
     .await
     .map_err(|e| anyhow!("Invalid cert configuration: {e}"))?;
 
-  // Notifier for proxy service termination
-  let term_notify = std::sync::Arc::new(tokio::sync::Notify::new());
-
   // Continuous monitoring
   loop {
+    // Notifier for proxy service termination
+    let cancel_token = tokio_util::sync::CancellationToken::new();
+
     let (cert_service, cert_rx) = cert_service_and_rx
       .as_ref()
       .map(|(s, r)| (Some(s), Some(r)))
@@ -140,7 +140,7 @@ async fn rpxy_service_with_watcher(
     #[cfg(feature = "acme")]
     let (acme_join_handles, server_config_acme_challenge) = acme_manager
       .as_ref()
-      .map(|m| m.spawn_manager_tasks(Some(term_notify.clone())))
+      .map(|m| m.spawn_manager_tasks(Some(cancel_token.child_token())))
       .unwrap_or((vec![], Default::default()));
 
     let rpxy_opts = {
@@ -150,7 +150,7 @@ async fn rpxy_service_with_watcher(
         .app_config_list(app_conf.clone())
         .cert_rx(cert_rx.cloned())
         .runtime_handle(runtime_handle.clone())
-        .term_notify(Some(term_notify.clone()))
+        .cancel_token(Some(cancel_token.child_token()))
         .server_configs_acme_challenge(std::sync::Arc::new(server_config_acme_challenge))
         .build();
@@ -216,8 +216,7 @@ async fn rpxy_service_with_watcher(
       }
       info!("Configuration updated. Terminate all spawned services and force to re-bind TCP/UDP sockets");
-      term_notify.notify_waiters();
-      // tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
+      cancel_token.cancel();
     }
     else => break
   }
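
The watcher above creates a fresh token on every reload iteration and hands child_token() to the ACME manager and the proxy builder, so a single cancel() stops the whole generation of spawned services before sockets are re-bound. A small self-contained sketch of that pattern, with made-up names, not code from the repository:

use tokio_util::sync::CancellationToken;

// Stand-in for a long-running proxy/ACME task that only stops when cancelled.
async fn run_service(cancel: CancellationToken, name: &'static str) {
    cancel.cancelled().await;
    println!("{name} terminated");
}

#[tokio::main]
async fn main() {
    for generation in 0..2 {
        // Fresh parent token per configuration generation.
        let cancel_token = CancellationToken::new();
        let handles = vec![
            tokio::spawn(run_service(cancel_token.child_token(), "proxy")),
            tokio::spawn(run_service(cancel_token.child_token(), "acme")),
        ];

        // On a configuration change: cancel the parent, and every child token fires.
        cancel_token.cancel();
        for h in handles {
            h.await.unwrap();
        }
        println!("generation {generation}: all services stopped, ready to re-bind");
    }
}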

View file

@@ -43,6 +43,7 @@ tokio = { version = "1.39.1", default-features = false, features = [
   "macros",
   "fs",
 ] }
+tokio-util = { version = "0.7.11", default-features = false }
 pin-project-lite = "0.2.14"
 async-trait = "0.1.81"

View file

@@ -2,6 +2,7 @@ use crate::{constants::*, count::RequestCount};
 use hot_reload::ReloaderReceiver;
 use rpxy_certs::ServerCryptoBase;
 use std::{net::SocketAddr, sync::Arc, time::Duration};
+use tokio_util::sync::CancellationToken;
 
 /// Global object containing proxy configurations and shared object like counters.
 /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks.
@@ -13,7 +14,7 @@ pub struct Globals {
   /// Shared context - Async task runtime handler
   pub runtime_handle: tokio::runtime::Handle,
   /// Shared context - Notify object to stop async tasks
-  pub term_notify: Option<Arc<tokio::sync::Notify>>,
+  pub cancel_token: Option<CancellationToken>,
   /// Shared context - Certificate reloader service receiver // TODO: newer one
   pub cert_reloader_rx: Option<ReloaderReceiver<ServerCryptoBase>>,

View file

@@ -23,6 +23,7 @@ use futures::future::select_all;
 use hot_reload::ReloaderReceiver;
 use rpxy_certs::ServerCryptoBase;
 use std::sync::Arc;
+use tokio_util::sync::CancellationToken;
 
 /* ------------------------------------------------ */
 pub use crate::globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri};
@@ -42,7 +43,7 @@ pub struct RpxyOptions {
   /// Async task runtime handler
   pub runtime_handle: tokio::runtime::Handle,
   /// Notify object to stop async tasks
-  pub term_notify: Option<Arc<tokio::sync::Notify>>,
+  pub cancel_token: Option<CancellationToken>,
 
   #[cfg(feature = "acme")]
   /// ServerConfig used for only ACME challenge for ACME domains
@@ -56,7 +57,7 @@ pub async fn entrypoint(
     app_config_list,
     cert_rx, // TODO:
     runtime_handle,
-    term_notify,
+    cancel_token,
     #[cfg(feature = "acme")]
     server_configs_acme_challenge,
   }: &RpxyOptions,
@@ -107,7 +108,7 @@ pub async fn entrypoint(
       proxy_config: proxy_config.clone(),
       request_count: Default::default(),
       runtime_handle: runtime_handle.clone(),
-      term_notify: term_notify.clone(),
+      cancel_token: cancel_token.clone(),
       cert_reloader_rx: cert_rx.clone(),
       #[cfg(feature = "acme")]

View file

@@ -312,13 +312,13 @@ where
       }
     };
 
-    match &self.globals.term_notify {
-      Some(term) => {
+    match &self.globals.cancel_token {
+      Some(cancel_token) => {
         select! {
           _ = proxy_service.fuse() => {
             warn!("Proxy service got down");
           }
-          _ = term.notified().fuse() => {
+          _ = cancel_token.cancelled().fuse() => {
             info!("Proxy service listening on {} receives term signal", self.listening_on);
           }
         }