From 2d79be55772c05ac304780b5472cc14c4a7b9c14 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 28 Apr 2023 20:03:50 +0900 Subject: [PATCH 01/10] add some comments to design sticky session --- config-example.toml | 4 ++-- src/backend/mod.rs | 5 ++++- src/backend/upstream.rs | 24 ++++++++++++++++++++++-- src/handler/handler_main.rs | 2 ++ src/utils/bytes_name.rs | 6 +++--- 5 files changed, 33 insertions(+), 8 deletions(-) diff --git a/config-example.toml b/config-example.toml index 9b2b463..673f4db 100644 --- a/config-example.toml +++ b/config-example.toml @@ -60,8 +60,8 @@ path = '/maps' # For request path starting with "/maps", # this configuration results that any path like "/maps/org/any.ext" is mapped to "/replacing/path1/org/any.ext" # by replacing "/maps" with "/replacing/path1" for routing to the locations given in upstream array -# Note that unless "path_replaced_with" is specified, the "path" is always preserved. -# "path_replaced_with" must be start from "/" (root path) +# Note that unless "replace_path" is specified, the "path" is always preserved. +# "replace_path" must be start from "/" (root path) replace_path = "/replacing/path1" upstream = [ { location = 'www.bing.com', tls = true }, diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 409222d..a31fd15 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -28,12 +28,15 @@ use x509_parser::prelude::*; #[derive(Builder)] pub struct Backend { #[builder(setter(into))] + /// backend application name, e.g., app1 pub app_name: String, #[builder(setter(custom))] + /// server name, e.g., example.com, in String ascii lower case pub server_name: String, + /// struct of reverse proxy serving incoming request pub reverse_proxy: ReverseProxy, - // tls settings + /// tls settings #[builder(setter(custom), default)] pub tls_cert_path: Option, #[builder(setter(custom), default)] diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index c2fdb34..2a56cae 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -17,6 +17,7 @@ pub struct ReverseProxy { } impl ReverseProxy { + /// Get an appropriate upstream destination for given path string. pub fn get<'a>(&self, path_str: impl Into>) -> Option<&UpstreamGroup> { // trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、 // コスト的にこの程度で十分 @@ -52,9 +53,14 @@ impl ReverseProxy { #[allow(dead_code)] #[derive(Debug, Clone)] +/// Load Balancing Option pub enum LoadBalance { + /// Simple round robin without session persistance RoundRobin, + /// Randomly chose one upstream server Random, + /// Round robin with session persistance using cookie + StickyRoundRobin, } impl Default for LoadBalance { fn default() -> Self { @@ -63,22 +69,32 @@ impl Default for LoadBalance { } #[derive(Debug, Clone)] +/// Upstream struct just containing uri without path pub struct Upstream { - pub uri: hyper::Uri, // base uri without specific path + /// Base uri without specific path + pub uri: hyper::Uri, } #[derive(Debug, Clone, Builder)] +/// Struct serving multiple upstream servers for, e.g., load balancing. pub struct UpstreamGroup { + /// Upstream server(s) pub upstream: Vec, #[builder(setter(custom), default)] + /// Path like "/path" in [[PathNameBytesExp]] associated with the upstream server(s) pub path: PathNameBytesExp, #[builder(setter(custom), default)] + /// Path in [[PathNameBytesExp]] that will be used to replace the "path" part of incoming url pub replace_path: Option, + #[builder(default)] + /// Load balancing option pub lb: LoadBalance, #[builder(default)] - pub cnt: UpstreamCount, // counter for load balancing + /// Counter for load balancing + pub cnt: UpstreamCount, #[builder(setter(custom), default)] + /// Activated upstream options defined in [[UpstreamOption]] pub opts: HashSet, } impl UpstreamGroupBuilder { @@ -116,6 +132,7 @@ impl UpstreamGroupBuilder { pub struct UpstreamCount(Arc); impl UpstreamGroup { + /// Get an enabled option of load balancing [[LoadBalance]] pub fn get(&self) -> Option<&Upstream> { match self.lb { LoadBalance::RoundRobin => { @@ -127,13 +144,16 @@ impl UpstreamGroup { let max = self.upstream.len() - 1; self.upstream.get(rng.gen_range(0..max)) } + LoadBalance::StickyRoundRobin => todo!(), // TODO: TODO: } } + /// Get a current count of upstream served fn current_cnt(&self) -> usize { self.cnt.0.load(Ordering::Relaxed) } + /// Increment count of upstream served fn increment_cnt(&self) -> usize { if self.current_cnt() < self.upstream.len() - 1 { self.cnt.0.fetch_add(1, Ordering::Relaxed) diff --git a/src/handler/handler_main.rs b/src/handler/handler_main.rs index 2ef5665..664f5e4 100644 --- a/src/handler/handler_main.rs +++ b/src/handler/handler_main.rs @@ -266,6 +266,8 @@ where }; // Fix unique upstream destination since there could be multiple ones. + // TODO: StickyならCookieをここでgetに与える必要 + // TODO: Stickyで、Cookieが与えられなかったらset-cookie向けにcookieを返す必要。upstreamオブジェクトに含めるのも手。 let upstream_chosen = upstream_group.get().ok_or_else(|| anyhow!("Failed to get upstream"))?; // apply upstream-specific headers given in upstream_option diff --git a/src/utils/bytes_name.rs b/src/utils/bytes_name.rs index 80bc0f0..16ec7ab 100644 --- a/src/utils/bytes_name.rs +++ b/src/utils/bytes_name.rs @@ -1,5 +1,5 @@ /// Server name (hostname or ip address) representation in bytes-based struct -/// For searching hashmap or key list by exact or longest-prefix matching +/// for searching hashmap or key list by exact or longest-prefix matching #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] pub struct ServerNameBytesExp(pub Vec); // lowercase ascii bytes impl From<&[u8]> for ServerNameBytesExp { @@ -8,8 +8,8 @@ impl From<&[u8]> for ServerNameBytesExp { } } -/// Server name (hostname or ip address) representation in bytes-based struct -/// For searching hashmap or key list by exact or longest-prefix matching +/// Path name, like "/path/ok", represented in bytes-based struct +/// for searching hashmap or key list by exact or longest-prefix matching #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] pub struct PathNameBytesExp(pub Vec); // lowercase ascii bytes impl PathNameBytesExp { From f66be5fef1ff9b1810c37a572c8ed8e50dd60af8 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 26 May 2023 15:43:23 +0900 Subject: [PATCH 02/10] update default load balance option --- README.md | 5 ++++- config-example.toml | 2 ++ src/backend/load_balance.rs | 25 +++++++++++++++++++++ src/backend/mod.rs | 1 + src/backend/upstream.rs | 45 +++++++++++++++++++++---------------- src/config/parse.rs | 1 + src/config/toml.rs | 1 + 7 files changed, 60 insertions(+), 20 deletions(-) create mode 100644 src/backend/load_balance.rs diff --git a/README.md b/README.md index 3f255df..3bce303 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ revese_proxy = [ #### Load Balancing -You can specify multiple backend locations in the `reverse_proxy` array for *load-balancing*. Currently it works in the manner of round-robin. +You can specify multiple backend locations in the `reverse_proxy` array for *load-balancing* with an appropriate `load_balance` option. Currently it works in the manner of round-robin or in the random fashion. if `load_balance` is not specified, the first backend location is always chosen. ```toml [apps."app_name"] @@ -117,8 +117,11 @@ reverse_proxy = [ { location = 'app1.local:8080' }, { location = 'app2.local:8000' } ] +load_balance = 'round_robin' ``` +(TODO: Sticky session is currently being implemented) + ### Second Step: Terminating TLS First of all, you need to specify a port `listen_port_tls` listening the HTTPS traffic, separately from HTTP port (`listen_port`). Then, serving an HTTPS endpoint can be easily done for your desired application just by specifying TLS certificates and private keys in PEM files. diff --git a/config-example.toml b/config-example.toml index 673f4db..0382393 100644 --- a/config-example.toml +++ b/config-example.toml @@ -52,6 +52,7 @@ upstream = [ { location = 'www.yahoo.com', tls = true }, { location = 'www.yahoo.co.jp', tls = true }, ] +load_balance = "round_robin" # or "random" or "sticky" (sticky session) or "none" (fix to the first one, default) upstream_options = ["override_host", "convert_https_to_2"] # Non-default destination in "localhost" app, which is routed by "path" @@ -67,6 +68,7 @@ upstream = [ { location = 'www.bing.com', tls = true }, { location = 'www.bing.co.jp', tls = true }, ] +load_balance = "random" # or "round_robin" or "sticky" (sticky session) or "none" (fix to the first one, default) upstream_options = [ "override_host", "upgrade_insecure_requests", diff --git a/src/backend/load_balance.rs b/src/backend/load_balance.rs new file mode 100644 index 0000000..a189c6f --- /dev/null +++ b/src/backend/load_balance.rs @@ -0,0 +1,25 @@ +/// Constants to specify a load balance option +pub(super) mod load_balance_options { + pub const FIX_TO_FIRST: &str = "none"; + pub const ROUND_ROBIN: &str = "round_robin"; + pub const RANDOM: &str = "random"; + pub const STICKY_ROUND_ROBIN: &str = "sticky"; +} + +#[derive(Debug, Clone)] +/// Load Balancing Option +pub enum LoadBalance { + /// Fix to the first upstream. Use if only one upstream destination is specified + FixToFirst, + /// Simple round robin without session persistance + RoundRobin, // TODO: カウンタはここにいれる。randomとかには不要なので + /// Randomly chose one upstream server + Random, + /// Round robin with session persistance using cookie + StickyRoundRobin, +} +impl Default for LoadBalance { + fn default() -> Self { + Self::FixToFirst + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs index a31fd15..00a6c83 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,3 +1,4 @@ +mod load_balance; mod upstream; mod upstream_opts; diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index 2a56cae..10d0a3c 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -1,4 +1,7 @@ -use super::{BytesName, PathNameBytesExp, UpstreamOption}; +use super::{ + load_balance::{load_balance_options as lb_opts, LoadBalance}, + BytesName, PathNameBytesExp, UpstreamOption, +}; use crate::log::*; use derive_builder::Builder; use rand::Rng; @@ -51,23 +54,6 @@ impl ReverseProxy { } } -#[allow(dead_code)] -#[derive(Debug, Clone)] -/// Load Balancing Option -pub enum LoadBalance { - /// Simple round robin without session persistance - RoundRobin, - /// Randomly chose one upstream server - Random, - /// Round robin with session persistance using cookie - StickyRoundRobin, -} -impl Default for LoadBalance { - fn default() -> Self { - Self::RoundRobin - } -} - #[derive(Debug, Clone)] /// Upstream struct just containing uri without path pub struct Upstream { @@ -87,7 +73,7 @@ pub struct UpstreamGroup { /// Path in [[PathNameBytesExp]] that will be used to replace the "path" part of incoming url pub replace_path: Option, - #[builder(default)] + #[builder(setter(custom), default)] /// Load balancing option pub lb: LoadBalance, #[builder(default)] @@ -97,6 +83,7 @@ pub struct UpstreamGroup { /// Activated upstream options defined in [[UpstreamOption]] pub opts: HashSet, } + impl UpstreamGroupBuilder { pub fn path(&mut self, v: &Option) -> &mut Self { let path = match v { @@ -114,6 +101,24 @@ impl UpstreamGroupBuilder { ); self } + pub fn lb(&mut self, v: &Option) -> &mut Self { + let lb = if let Some(x) = v { + match x.as_str() { + lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, + lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin, + lb_opts::RANDOM => LoadBalance::Random, + lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin, + _ => { + error!("Specified load balancing option is invalid."); + LoadBalance::default() + } + } + } else { + LoadBalance::default() + }; + self.lb = Some(lb); + self + } pub fn opts(&mut self, v: &Option>) -> &mut Self { let opts = if let Some(opts) = v { opts @@ -128,6 +133,7 @@ impl UpstreamGroupBuilder { } } +// TODO: カウンタの移動 #[derive(Debug, Clone, Default)] pub struct UpstreamCount(Arc); @@ -135,6 +141,7 @@ impl UpstreamGroup { /// Get an enabled option of load balancing [[LoadBalance]] pub fn get(&self) -> Option<&Upstream> { match self.lb { + LoadBalance::FixToFirst => self.upstream.get(0), LoadBalance::RoundRobin => { let idx = self.increment_cnt(); self.upstream.get(idx) diff --git a/src/config/parse.rs b/src/config/parse.rs index 935f86c..d023f56 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -205,6 +205,7 @@ fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result< .upstream(rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect()) .path(&rpo.path) .replace_path(&rpo.replace_path) + .lb(&rpo.load_balance) .opts(&rpo.upstream_options) .build() .unwrap(); diff --git a/src/config/toml.rs b/src/config/toml.rs index cefacb2..6ce48b2 100644 --- a/src/config/toml.rs +++ b/src/config/toml.rs @@ -57,6 +57,7 @@ pub struct ReverseProxyOption { pub replace_path: Option, pub upstream: Vec, pub upstream_options: Option>, + pub load_balance: Option, } #[derive(Deserialize, Debug, Default)] From 5cba376394b4d6d055f58fb97338d053fa936c8b Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 3 Jun 2023 14:55:34 +0900 Subject: [PATCH 03/10] refactor: update logic of round-robin --- src/backend/load_balance.rs | 43 +++++++++++++++++++++++++++++++++- src/backend/upstream.rs | 46 ++++++++++--------------------------- src/config/parse.rs | 9 +++++--- 3 files changed, 60 insertions(+), 38 deletions(-) diff --git a/src/backend/load_balance.rs b/src/backend/load_balance.rs index a189c6f..f6b2292 100644 --- a/src/backend/load_balance.rs +++ b/src/backend/load_balance.rs @@ -1,3 +1,9 @@ +use derive_builder::Builder; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + /// Constants to specify a load balance option pub(super) mod load_balance_options { pub const FIX_TO_FIRST: &str = "none"; @@ -6,13 +12,48 @@ pub(super) mod load_balance_options { pub const STICKY_ROUND_ROBIN: &str = "sticky"; } +// +// /// Counter for load balancing +// pub cnt: UpstreamCount, + +// TODO: カウンタの移動 +#[derive(Debug, Clone, Builder)] +pub struct LbRoundRobinCount { + #[builder(default)] + cnt: Arc, + #[builder(setter(custom), default)] + max_val: usize, +} +impl LbRoundRobinCountBuilder { + pub fn max_val(&mut self, v: &usize) -> &mut Self { + self.max_val = Some(*v); + self + } +} +impl LbRoundRobinCount { + /// Get a current count of upstream served + fn current_cnt(&self) -> usize { + self.cnt.load(Ordering::Relaxed) + } + + /// Increment the count of upstream served up to the max value + pub fn increment_cnt(&self) -> usize { + if self.current_cnt() < self.max_val - 1 { + self.cnt.fetch_add(1, Ordering::Relaxed) + } else { + // Clear the counter + self.cnt.fetch_and(0, Ordering::Relaxed) + } + } +} + #[derive(Debug, Clone)] /// Load Balancing Option pub enum LoadBalance { /// Fix to the first upstream. Use if only one upstream destination is specified FixToFirst, /// Simple round robin without session persistance - RoundRobin, // TODO: カウンタはここにいれる。randomとかには不要なので + RoundRobin(LbRoundRobinCount), // TODO: カウンタはここにいれる。randomとかには不要なので /// Randomly chose one upstream server Random, /// Round robin with session persistance using cookie diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index 10d0a3c..5875189 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -1,18 +1,12 @@ use super::{ - load_balance::{load_balance_options as lb_opts, LoadBalance}, + load_balance::{load_balance_options as lb_opts, LbRoundRobinCountBuilder, LoadBalance}, BytesName, PathNameBytesExp, UpstreamOption, }; use crate::log::*; use derive_builder::Builder; use rand::Rng; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; -use std::{ - borrow::Cow, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, -}; +use std::borrow::Cow; #[derive(Debug, Clone)] pub struct ReverseProxy { @@ -76,9 +70,6 @@ pub struct UpstreamGroup { #[builder(setter(custom), default)] /// Load balancing option pub lb: LoadBalance, - #[builder(default)] - /// Counter for load balancing - pub cnt: UpstreamCount, #[builder(setter(custom), default)] /// Activated upstream options defined in [[UpstreamOption]] pub opts: HashSet, @@ -101,11 +92,16 @@ impl UpstreamGroupBuilder { ); self } - pub fn lb(&mut self, v: &Option) -> &mut Self { + pub fn lb(&mut self, v: &Option, upstream_num: &usize) -> &mut Self { let lb = if let Some(x) = v { match x.as_str() { lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, - lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin, + lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin( + LbRoundRobinCountBuilder::default() + .max_val(upstream_num) + .build() + .unwrap(), + ), lb_opts::RANDOM => LoadBalance::Random, lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin, _ => { @@ -133,17 +129,13 @@ impl UpstreamGroupBuilder { } } -// TODO: カウンタの移動 -#[derive(Debug, Clone, Default)] -pub struct UpstreamCount(Arc); - impl UpstreamGroup { /// Get an enabled option of load balancing [[LoadBalance]] pub fn get(&self) -> Option<&Upstream> { - match self.lb { + match &self.lb { LoadBalance::FixToFirst => self.upstream.get(0), - LoadBalance::RoundRobin => { - let idx = self.increment_cnt(); + LoadBalance::RoundRobin(cnt) => { + let idx = cnt.increment_cnt(); self.upstream.get(idx) } LoadBalance::Random => { @@ -154,18 +146,4 @@ impl UpstreamGroup { LoadBalance::StickyRoundRobin => todo!(), // TODO: TODO: } } - - /// Get a current count of upstream served - fn current_cnt(&self) -> usize { - self.cnt.0.load(Ordering::Relaxed) - } - - /// Increment count of upstream served - fn increment_cnt(&self) -> usize { - if self.current_cnt() < self.upstream.len() - 1 { - self.cnt.0.fetch_add(1, Ordering::Relaxed) - } else { - self.cnt.0.fetch_and(0, Ordering::Relaxed) - } - } } diff --git a/src/config/parse.rs b/src/config/parse.rs index d023f56..88892a8 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -1,6 +1,6 @@ use super::toml::{ConfigToml, ReverseProxyOption}; use crate::{ - backend::{BackendBuilder, ReverseProxy, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption}, + backend::{BackendBuilder, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption}, constants::*, error::*, globals::*, @@ -200,12 +200,15 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result { let mut upstream: HashMap = HashMap::default(); + rp_settings.iter().for_each(|rpo| { + let vec_upstream: Vec = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(); + let lb_upstream_num = vec_upstream.len(); let elem = UpstreamGroupBuilder::default() - .upstream(rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect()) + .upstream(vec_upstream) .path(&rpo.path) .replace_path(&rpo.replace_path) - .lb(&rpo.load_balance) + .lb(&rpo.load_balance, &lb_upstream_num) .opts(&rpo.upstream_options) .build() .unwrap(); From 96810a4d4fb133be4c953d8395318020a80a6e28 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Sat, 3 Jun 2023 15:01:58 +0900 Subject: [PATCH 04/10] refactor: remove unneeded comments --- src/backend/load_balance.rs | 12 ++++-------- src/backend/upstream.rs | 11 ++++++++--- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/backend/load_balance.rs b/src/backend/load_balance.rs index f6b2292..bda5497 100644 --- a/src/backend/load_balance.rs +++ b/src/backend/load_balance.rs @@ -12,12 +12,8 @@ pub(super) mod load_balance_options { pub const STICKY_ROUND_ROBIN: &str = "sticky"; } -// -// /// Counter for load balancing -// pub cnt: UpstreamCount, - -// TODO: カウンタの移動 #[derive(Debug, Clone, Builder)] +/// Counter object as a pointer to the current serving upstream destination pub struct LbRoundRobinCount { #[builder(default)] cnt: Arc, @@ -52,12 +48,12 @@ impl LbRoundRobinCount { pub enum LoadBalance { /// Fix to the first upstream. Use if only one upstream destination is specified FixToFirst, - /// Simple round robin without session persistance - RoundRobin(LbRoundRobinCount), // TODO: カウンタはここにいれる。randomとかには不要なので /// Randomly chose one upstream server Random, + /// Simple round robin without session persistance + RoundRobin(LbRoundRobinCount), /// Round robin with session persistance using cookie - StickyRoundRobin, + StickyRoundRobin(LbRoundRobinCount), } impl Default for LoadBalance { fn default() -> Self { diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index 5875189..11b186e 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -96,14 +96,19 @@ impl UpstreamGroupBuilder { let lb = if let Some(x) = v { match x.as_str() { lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, + lb_opts::RANDOM => LoadBalance::Random, lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin( LbRoundRobinCountBuilder::default() .max_val(upstream_num) .build() .unwrap(), ), - lb_opts::RANDOM => LoadBalance::Random, - lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin, + lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( + LbRoundRobinCountBuilder::default() + .max_val(upstream_num) + .build() + .unwrap(), + ), _ => { error!("Specified load balancing option is invalid."); LoadBalance::default() @@ -143,7 +148,7 @@ impl UpstreamGroup { let max = self.upstream.len() - 1; self.upstream.get(rng.gen_range(0..max)) } - LoadBalance::StickyRoundRobin => todo!(), // TODO: TODO: + LoadBalance::StickyRoundRobin(_cnt) => todo!(), // TODO: TODO: } } } From f4c59c9f2f178f6af85f648fd9212c8892d205bc Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Wed, 7 Jun 2023 15:03:06 +0900 Subject: [PATCH 05/10] fix: lb random range bug --- src/backend/load_balance.rs | 70 ++++++++++++++++++++++++++++--------- src/backend/upstream.rs | 29 +++++---------- 2 files changed, 62 insertions(+), 37 deletions(-) diff --git a/src/backend/load_balance.rs b/src/backend/load_balance.rs index bda5497..8ff8b4b 100644 --- a/src/backend/load_balance.rs +++ b/src/backend/load_balance.rs @@ -1,4 +1,5 @@ use derive_builder::Builder; +use rand::Rng; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -13,50 +14,85 @@ pub(super) mod load_balance_options { } #[derive(Debug, Clone, Builder)] -/// Counter object as a pointer to the current serving upstream destination -pub struct LbRoundRobinCount { +/// Round Robin LB object as a pointer to the current serving upstream destination +pub struct LbRoundRobin { #[builder(default)] - cnt: Arc, + /// Pointer to the index of the last served upstream destination + ptr: Arc, #[builder(setter(custom), default)] - max_val: usize, + /// Number of upstream destinations + num_upstreams: usize, } -impl LbRoundRobinCountBuilder { - pub fn max_val(&mut self, v: &usize) -> &mut Self { - self.max_val = Some(*v); +impl LbRoundRobinBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); self } } -impl LbRoundRobinCount { +impl LbRoundRobin { /// Get a current count of upstream served - fn current_cnt(&self) -> usize { - self.cnt.load(Ordering::Relaxed) + fn current_ptr(&self) -> usize { + self.ptr.load(Ordering::Relaxed) } /// Increment the count of upstream served up to the max value - pub fn increment_cnt(&self) -> usize { - if self.current_cnt() < self.max_val - 1 { - self.cnt.fetch_add(1, Ordering::Relaxed) + pub fn increment_ptr(&self) -> usize { + if self.current_ptr() < self.num_upstreams - 1 { + self.ptr.fetch_add(1, Ordering::Relaxed) } else { // Clear the counter - self.cnt.fetch_and(0, Ordering::Relaxed) + self.ptr.fetch_and(0, Ordering::Relaxed) } } } +#[derive(Debug, Clone, Builder)] +/// Random LB object to keep the object of random pools +pub struct LbRandom { + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, +} +impl LbRandomBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } +} +impl LbRandom { + /// Returns the random index within the range + pub fn get_ptr(&self) -> usize { + let mut rng = rand::thread_rng(); + rng.gen_range(0..self.num_upstreams) + } +} + #[derive(Debug, Clone)] /// Load Balancing Option pub enum LoadBalance { /// Fix to the first upstream. Use if only one upstream destination is specified FixToFirst, /// Randomly chose one upstream server - Random, + Random(LbRandom), /// Simple round robin without session persistance - RoundRobin(LbRoundRobinCount), + RoundRobin(LbRoundRobin), /// Round robin with session persistance using cookie - StickyRoundRobin(LbRoundRobinCount), + StickyRoundRobin(LbRoundRobin), } impl Default for LoadBalance { fn default() -> Self { Self::FixToFirst } } + +impl LoadBalance { + /// Get the index of the upstream serving the incoming request + pub(super) fn get_idx(&self) -> usize { + match self { + LoadBalance::FixToFirst => 0usize, + LoadBalance::RoundRobin(ptr) => ptr.increment_ptr(), + LoadBalance::Random(v) => v.get_ptr(), + LoadBalance::StickyRoundRobin(_ptr) => 0usize, // todo!(), // TODO: TODO: TODO: TODO: tentative value + } + } +} diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index 11b186e..ac17a40 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -1,10 +1,9 @@ use super::{ - load_balance::{load_balance_options as lb_opts, LbRoundRobinCountBuilder, LoadBalance}, + load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance}, BytesName, PathNameBytesExp, UpstreamOption, }; use crate::log::*; use derive_builder::Builder; -use rand::Rng; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use std::borrow::Cow; @@ -96,16 +95,16 @@ impl UpstreamGroupBuilder { let lb = if let Some(x) = v { match x.as_str() { lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, - lb_opts::RANDOM => LoadBalance::Random, + lb_opts::RANDOM => LoadBalance::Random(LbRandomBuilder::default().num_upstreams(upstream_num).build().unwrap()), lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin( - LbRoundRobinCountBuilder::default() - .max_val(upstream_num) + LbRoundRobinBuilder::default() + .num_upstreams(upstream_num) .build() .unwrap(), ), lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( - LbRoundRobinCountBuilder::default() - .max_val(upstream_num) + LbRoundRobinBuilder::default() + .num_upstreams(upstream_num) .build() .unwrap(), ), @@ -137,18 +136,8 @@ impl UpstreamGroupBuilder { impl UpstreamGroup { /// Get an enabled option of load balancing [[LoadBalance]] pub fn get(&self) -> Option<&Upstream> { - match &self.lb { - LoadBalance::FixToFirst => self.upstream.get(0), - LoadBalance::RoundRobin(cnt) => { - let idx = cnt.increment_cnt(); - self.upstream.get(idx) - } - LoadBalance::Random => { - let mut rng = rand::thread_rng(); - let max = self.upstream.len() - 1; - self.upstream.get(rng.gen_range(0..max)) - } - LoadBalance::StickyRoundRobin(_cnt) => todo!(), // TODO: TODO: - } + let idx = self.lb.get_idx(); + debug!("Upstream of index {idx} is chosen."); + self.upstream.get(idx) } } From a0ae3578d7ca4d2597ab655676adf2bbffb418b6 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 9 Jun 2023 02:18:01 +0900 Subject: [PATCH 06/10] feat: initial implementation of sticky cookie for session persistance when load-balancing --- Cargo.toml | 9 + src/backend/load_balance.rs | 183 +++++++++++++++--- src/backend/load_balance_sticky_cookie.rs | 216 ++++++++++++++++++++++ src/backend/mod.rs | 9 +- src/backend/upstream.rs | 72 +++++++- src/config/parse.rs | 16 +- src/constants.rs | 3 + src/error.rs | 3 + src/handler/handler_main.rs | 41 +++- src/handler/mod.rs | 7 + src/handler/utils_headers.rs | 70 ++++++- 11 files changed, 580 insertions(+), 49 deletions(-) create mode 100644 src/backend/load_balance_sticky_cookie.rs diff --git a/Cargo.toml b/Cargo.toml index 10400e3..33b6080 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,15 @@ h3 = { path = "./h3/h3/", optional = true } # h3-quinn = { path = "./h3/h3-quinn/", optional = true } h3-quinn = { path = "./h3-quinn/", optional = true } # Tentative to support rustls-0.21 +# cookie handling +chrono = { version = "0.4.26", default-features = false, features = [ + "unstable-locales", + "alloc", + "clock", +] } +base64 = "0.21.2" +sha2 = { version = "0.10.6", default-features = false } + [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = "0.5.0" diff --git a/src/backend/load_balance.rs b/src/backend/load_balance.rs index 8ff8b4b..647903f 100644 --- a/src/backend/load_balance.rs +++ b/src/backend/load_balance.rs @@ -1,8 +1,14 @@ +use super::{load_balance_sticky_cookie::StickyCookieConfig, LbContext, Upstream}; +use crate::{constants::STICKY_COOKIE_NAME, error::*, log::*}; use derive_builder::Builder; use rand::Rng; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, +use rustc_hash::FxHashMap as HashMap; +use std::{ + borrow::Cow, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, }; /// Constants to specify a load balance option @@ -13,6 +19,18 @@ pub(super) mod load_balance_options { pub const STICKY_ROUND_ROBIN: &str = "sticky"; } +#[derive(Debug, Clone)] +/// Pointer to upstream serving the incoming request. +/// If 'sticky cookie'-based LB is enabled and cookie must be updated/created, the new cookie is also given. +pub(super) struct PointerToUpstream { + pub ptr: usize, + pub context_lb: Option, +} +/// Trait for LB +trait LbWithPointer { + fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream; +} + #[derive(Debug, Clone, Builder)] /// Round Robin LB object as a pointer to the current serving upstream destination pub struct LbRoundRobin { @@ -29,20 +47,19 @@ impl LbRoundRobinBuilder { self } } -impl LbRoundRobin { - /// Get a current count of upstream served - fn current_ptr(&self) -> usize { - self.ptr.load(Ordering::Relaxed) - } - +impl LbWithPointer for LbRoundRobin { /// Increment the count of upstream served up to the max value - pub fn increment_ptr(&self) -> usize { - if self.current_ptr() < self.num_upstreams - 1 { + fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { + // Get a current count of upstream served + let current_ptr = self.ptr.load(Ordering::Relaxed); + + let ptr = if current_ptr < self.num_upstreams - 1 { self.ptr.fetch_add(1, Ordering::Relaxed) } else { // Clear the counter self.ptr.fetch_and(0, Ordering::Relaxed) - } + }; + PointerToUpstream { ptr, context_lb: None } } } @@ -59,11 +76,129 @@ impl LbRandomBuilder { self } } -impl LbRandom { +impl LbWithPointer for LbRandom { /// Returns the random index within the range - pub fn get_ptr(&self) -> usize { + fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { let mut rng = rand::thread_rng(); - rng.gen_range(0..self.num_upstreams) + let ptr = rng.gen_range(0..self.num_upstreams); + PointerToUpstream { ptr, context_lb: None } + } +} + +#[derive(Debug, Clone, Builder)] +/// Round Robin LB object in the sticky cookie manner +pub struct LbStickyRoundRobin { + #[builder(default)] + /// Pointer to the index of the last served upstream destination + ptr: Arc, + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, + #[builder(setter(custom))] + /// Information to build the cookie to stick clients to specific backends + pub sticky_config: StickyCookieConfig, + #[builder(setter(custom))] + /// Hashmaps: + /// - Hashmap that maps server indices to server id (string) + /// - Hashmap that maps server ids (string) to server indices, for fast reverse lookup + upstream_maps: UpstreamMap, +} +#[derive(Debug, Clone)] +pub struct UpstreamMap { + /// Hashmap that maps server indices to server id (string) + upstream_index_map: Vec, + /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup + upstream_id_map: HashMap, +} +impl LbStickyRoundRobinBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } + pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option) -> &mut Self { + self.sticky_config = Some(StickyCookieConfig { + name: STICKY_COOKIE_NAME.to_string(), // TODO: config等で変更できるように + domain: server_name.to_ascii_lowercase(), + path: if let Some(v) = path_opt { + v.to_ascii_lowercase() + } else { + "/".to_string() + }, + duration: 300, // TODO: config等で変更できるように + }); + self + } + pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self { + let upstream_index_map: Vec = upstream_vec + .iter() + .enumerate() + .map(|(i, v)| v.calculate_id_with_index(i)) + .collect(); + let mut upstream_id_map = HashMap::default(); + for (i, v) in upstream_index_map.iter().enumerate() { + upstream_id_map.insert(v.to_string(), i); + } + self.upstream_maps = Some(UpstreamMap { + upstream_index_map, + upstream_id_map, + }); + self + } +} +impl<'a> LbStickyRoundRobin { + fn simple_increment_ptr(&self) -> usize { + // Get a current count of upstream served + let current_ptr = self.ptr.load(Ordering::Relaxed); + + if current_ptr < self.num_upstreams - 1 { + self.ptr.fetch_add(1, Ordering::Relaxed) + } else { + // Clear the counter + self.ptr.fetch_and(0, Ordering::Relaxed) + } + } + /// This is always called only internally. So 'unwrap()' is executed. + fn get_server_id_from_index(&self, index: usize) -> String { + self.upstream_maps.upstream_index_map.get(index).unwrap().to_owned() + } + /// This function takes value passed from outside. So 'result' is used. + fn get_server_index_from_id(&self, id: impl Into>) -> Option { + let id_str = id.into().to_string(); + self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned()) + } +} +impl LbWithPointer for LbStickyRoundRobin { + fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream { + // If given context is None or invalid (not contained), get_ptr() is invoked to increment the pointer. + // Otherwise, get the server index indicated by the server_id inside the cookie + let ptr = match req_info { + None => { + debug!("No sticky cookie"); + self.simple_increment_ptr() + } + Some(context) => { + let server_id = &context.sticky_cookie.value.value; + if let Some(server_index) = self.get_server_index_from_id(server_id) { + debug!("Valid sticky cookie: id={}, index={}", server_id, server_index); + server_index + } else { + debug!("Invalid sticky cookie: id={}", server_id); + self.simple_increment_ptr() + } + } + }; + + // Get the server id from the ptr. + // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie). + let upstream_id = self.get_server_id_from_index(ptr); + let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap(); + let new_context = Some(LbContext { + sticky_cookie: new_cookie, + }); + PointerToUpstream { + ptr, + context_lb: new_context, + } } } @@ -77,7 +212,7 @@ pub enum LoadBalance { /// Simple round robin without session persistance RoundRobin(LbRoundRobin), /// Round robin with session persistance using cookie - StickyRoundRobin(LbRoundRobin), + StickyRoundRobin(LbStickyRoundRobin), } impl Default for LoadBalance { fn default() -> Self { @@ -87,12 +222,18 @@ impl Default for LoadBalance { impl LoadBalance { /// Get the index of the upstream serving the incoming request - pub(super) fn get_idx(&self) -> usize { + pub(super) fn get_context(&self, context_to_lb: &Option) -> PointerToUpstream { match self { - LoadBalance::FixToFirst => 0usize, - LoadBalance::RoundRobin(ptr) => ptr.increment_ptr(), - LoadBalance::Random(v) => v.get_ptr(), - LoadBalance::StickyRoundRobin(_ptr) => 0usize, // todo!(), // TODO: TODO: TODO: TODO: tentative value + LoadBalance::FixToFirst => PointerToUpstream { + ptr: 0usize, + context_lb: None, + }, + LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None), + LoadBalance::Random(ptr) => ptr.get_ptr(None), + LoadBalance::StickyRoundRobin(ptr) => { + // Generate new context if sticky round robin is enabled. + ptr.get_ptr(context_to_lb.as_ref()) + } } } } diff --git a/src/backend/load_balance_sticky_cookie.rs b/src/backend/load_balance_sticky_cookie.rs new file mode 100644 index 0000000..d293004 --- /dev/null +++ b/src/backend/load_balance_sticky_cookie.rs @@ -0,0 +1,216 @@ +use std::borrow::Cow; + +use crate::error::*; +use chrono::{TimeZone, Utc}; +use derive_builder::Builder; + +#[derive(Debug, Clone)] +/// Struct to handle the sticky cookie string, +/// - passed from Rp module (http handler) to LB module, manipulated from req, only StickyCookieValue exists. +/// - passed from LB module to Rp module (http handler), will be inserted into res, StickyCookieValue and Info exist. +pub struct LbContext { + pub sticky_cookie: StickyCookie, +} + +#[derive(Debug, Clone, Builder)] +/// Cookie value only, used for COOKIE in req +pub struct StickyCookieValue { + #[builder(setter(custom))] + /// Field name indicating sticky cookie + pub name: String, + #[builder(setter(custom))] + /// Upstream server_id + pub value: String, +} +impl<'a> StickyCookieValueBuilder { + pub fn name(&mut self, v: impl Into>) -> &mut Self { + self.name = Some(v.into().to_ascii_lowercase()); + self + } + pub fn value(&mut self, v: impl Into>) -> &mut Self { + self.value = Some(v.into().to_string()); + self + } +} +impl StickyCookieValue { + pub fn try_from(value: &str, expected_name: &str) -> Result { + if !value.starts_with(expected_name) { + return Err(RpxyError::LoadBalance( + "Failed to cookie conversion from string".to_string(), + )); + }; + let kv = value.split('=').map(|v| v.trim()).collect::>(); + if kv.len() != 2 { + return Err(RpxyError::LoadBalance("Invalid cookie structure".to_string())); + }; + if kv[1].is_empty() { + return Err(RpxyError::LoadBalance("No sticky cookie value".to_string())); + } + Ok(StickyCookieValue { + name: expected_name.to_string(), + value: kv[1].to_string(), + }) + } +} + +#[derive(Debug, Clone, Builder)] +/// Struct describing sticky cookie meta information used for SET-COOKIE in res +pub struct StickyCookieInfo { + #[builder(setter(custom))] + /// Unix time + pub expires: i64, + + #[builder(setter(custom))] + /// Domain + pub domain: String, + + #[builder(setter(custom))] + /// Path + pub path: String, +} +impl<'a> StickyCookieInfoBuilder { + pub fn domain(&mut self, v: impl Into>) -> &mut Self { + self.domain = Some(v.into().to_ascii_lowercase()); + self + } + pub fn path(&mut self, v: impl Into>) -> &mut Self { + self.path = Some(v.into().to_ascii_lowercase()); + self + } + pub fn expires(&mut self, duration_secs: i64) -> &mut Self { + let current = Utc::now().timestamp(); + self.expires = Some(current + duration_secs); + self + } +} + +#[derive(Debug, Clone, Builder)] +/// Struct describing sticky cookie +pub struct StickyCookie { + #[builder(setter(custom))] + /// Upstream server_id + pub value: StickyCookieValue, + #[builder(setter(custom), default)] + /// Upstream server_id + pub info: Option, +} + +impl<'a> StickyCookieBuilder { + pub fn value(&mut self, n: impl Into>, v: impl Into>) -> &mut Self { + self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap()); + self + } + pub fn info( + &mut self, + domain: impl Into>, + path: impl Into>, + duration_secs: i64, + ) -> &mut Self { + let info = StickyCookieInfoBuilder::default() + .domain(domain) + .path(path) + .expires(duration_secs) + .build() + .unwrap(); + self.info = Some(Some(info)); + self + } +} + +impl TryInto for StickyCookie { + type Error = RpxyError; + + fn try_into(self) -> Result { + if self.info.is_none() { + return Err(RpxyError::LoadBalance( + "Failed to cookie conversion into string: no meta information".to_string(), + )); + } + let info = self.info.unwrap(); + let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else { + return Err(RpxyError::LoadBalance("Failed to cookie conversion into string".to_string())); + }; + let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string(); + let max_age = info.expires - Utc::now().timestamp(); + + Ok(format!( + "{}={}; expires={}; Max-Age={}; path={}; domain={}; HttpOnly", + self.value.name, self.value.value, exp_str, max_age, info.path, info.domain + )) + } +} + +#[derive(Debug, Clone)] +/// Configuration to serve incoming requests in the manner of "sticky cookie". +/// Including a dictionary to map Ids included in cookie and upstream destinations, +/// and expiration of cookie. +/// "domain" and "path" in the cookie will be the same as the reverse proxy options. +pub struct StickyCookieConfig { + pub name: String, + pub domain: String, + pub path: String, + pub duration: i64, +} +impl<'a> StickyCookieConfig { + pub fn build_sticky_cookie(&self, v: impl Into>) -> Result { + StickyCookieBuilder::default() + .value(self.name.clone(), v) + .info(&self.domain, &self.path, self.duration) + .build() + .map_err(|_| RpxyError::LoadBalance("Failed to build sticky cookie from config".to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::constants::STICKY_COOKIE_NAME; + + #[test] + fn config_works() { + let config = StickyCookieConfig { + name: STICKY_COOKIE_NAME.to_string(), + domain: "example.com".to_string(), + path: "/path".to_string(), + duration: 100, + }; + let expires_unix = Utc::now().timestamp() + 100; + let sc_string: Result = config.build_sticky_cookie("test_value").unwrap().try_into(); + let expires_date_string = Utc + .timestamp_opt(expires_unix, 0) + .unwrap() + .format("%a, %d-%b-%Y %T GMT") + .to_string(); + assert_eq!( + sc_string.unwrap(), + format!( + "{}=test_value; expires={}; Max-Age={}; path=/path; domain=example.com; HttpOnly", + STICKY_COOKIE_NAME, expires_date_string, 100 + ) + ); + } + #[test] + fn to_string_works() { + let sc = StickyCookie { + value: StickyCookieValue { + name: STICKY_COOKIE_NAME.to_string(), + value: "test_value".to_string(), + }, + info: Some(StickyCookieInfo { + expires: 1686221173i64, + domain: "example.com".to_string(), + path: "/path".to_string(), + }), + }; + let sc_string: Result = sc.try_into(); + let max_age = 1686221173i64 - Utc::now().timestamp(); + assert!(sc_string.is_ok()); + assert_eq!( + sc_string.unwrap(), + format!( + "{}=test_value; expires=Thu, 08-Jun-2023 10:46:13 GMT; Max-Age={}; path=/path; domain=example.com; HttpOnly", + STICKY_COOKIE_NAME, max_age + ) + ); + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 00a6c83..9164c45 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,7 +1,14 @@ mod load_balance; +mod load_balance_sticky_cookie; mod upstream; mod upstream_opts; +pub use self::{ + load_balance::LoadBalance, + load_balance_sticky_cookie::{LbContext, StickyCookie, StickyCookieBuilder, StickyCookieValue}, + upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, + upstream_opts::UpstreamOption, +}; use crate::{ log::*, utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, @@ -21,8 +28,6 @@ use tokio_rustls::rustls::{ sign::{any_supported_type, CertifiedKey}, Certificate, PrivateKey, ServerConfig, }; -pub use upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}; -pub use upstream_opts::UpstreamOption; use x509_parser::prelude::*; /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings. diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index ac17a40..9e53e5d 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -1,12 +1,16 @@ use super::{ - load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance}, + load_balance::{ + load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LbStickyRoundRobinBuilder, LoadBalance, + }, + load_balance_sticky_cookie::LbContext, BytesName, PathNameBytesExp, UpstreamOption, }; use crate::log::*; +use base64::{engine::general_purpose, Engine as _}; use derive_builder::Builder; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use sha2::{Digest, Sha256}; use std::borrow::Cow; - #[derive(Debug, Clone)] pub struct ReverseProxy { pub upstream: HashMap, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。 @@ -53,10 +57,20 @@ pub struct Upstream { /// Base uri without specific path pub uri: hyper::Uri, } - +impl Upstream { + /// Hashing uri with index to avoid collision + pub fn calculate_id_with_index(&self, index: usize) -> String { + let mut hasher = Sha256::new(); + let uri_string = format!("{}&index={}", self.uri.clone(), index); + hasher.update(uri_string.as_bytes()); + let digest = hasher.finalize(); + general_purpose::URL_SAFE_NO_PAD.encode(digest) + } +} #[derive(Debug, Clone, Builder)] /// Struct serving multiple upstream servers for, e.g., load balancing. pub struct UpstreamGroup { + #[builder(setter(custom))] /// Upstream server(s) pub upstream: Vec, #[builder(setter(custom), default)] @@ -75,6 +89,10 @@ pub struct UpstreamGroup { } impl UpstreamGroupBuilder { + pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self { + self.upstream = Some(upstream_vec.to_vec()); + self + } pub fn path(&mut self, v: &Option) -> &mut Self { let path = match v { Some(p) => p.to_path_name_vec(), @@ -91,7 +109,15 @@ impl UpstreamGroupBuilder { ); self } - pub fn lb(&mut self, v: &Option, upstream_num: &usize) -> &mut Self { + pub fn lb( + &mut self, + v: &Option, + // upstream_num: &usize, + upstream_vec: &Vec, + server_name: &str, + path_opt: &Option, + ) -> &mut Self { + let upstream_num = &upstream_vec.len(); let lb = if let Some(x) = v { match x.as_str() { lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, @@ -103,8 +129,10 @@ impl UpstreamGroupBuilder { .unwrap(), ), lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( - LbRoundRobinBuilder::default() + LbStickyRoundRobinBuilder::default() .num_upstreams(upstream_num) + .sticky_config(server_name, path_opt) + .upstream_maps(upstream_vec) // TODO: .build() .unwrap(), ), @@ -135,9 +163,35 @@ impl UpstreamGroupBuilder { impl UpstreamGroup { /// Get an enabled option of load balancing [[LoadBalance]] - pub fn get(&self) -> Option<&Upstream> { - let idx = self.lb.get_idx(); - debug!("Upstream of index {idx} is chosen."); - self.upstream.get(idx) + pub fn get(&self, context_to_lb: &Option) -> (Option<&Upstream>, Option) { + let pointer_to_upstream = self.lb.get_context(context_to_lb); + debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); + debug!("Context to LB (Cookie in Req): {:?}", context_to_lb); + debug!( + "Context from LB (Set-Cookie in Res): {:?}", + pointer_to_upstream.context_lb + ); + ( + self.upstream.get(pointer_to_upstream.ptr), + pointer_to_upstream.context_lb, + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn calc_id_works() { + let uri = "https://www.rust-lang.org".parse::().unwrap(); + let upstream = Upstream { uri }; + assert_eq!( + "eGsjoPbactQ1eUJjafYjPT3ekYZQkaqJnHdA_FMSkgM", + upstream.calculate_id_with_index(0) + ); + assert_eq!( + "tNVXFJ9eNCT2mFgKbYq35XgH5q93QZtfU8piUiiDxVA", + upstream.calculate_id_with_index(1) + ); } } diff --git a/src/config/parse.rs b/src/config/parse.rs index 88892a8..8e4ddf7 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -99,7 +99,7 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro let mut backend_builder = BackendBuilder::default(); // reverse proxy settings ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); - let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; + let reverse_proxy = get_reverse_proxy(server_name_string, app.reverse_proxy.as_ref().unwrap())?; backend_builder .app_name(server_name_string) @@ -198,17 +198,21 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro Ok(()) } -fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result { +fn get_reverse_proxy( + server_name_string: &str, + rp_settings: &[ReverseProxyOption], +) -> std::result::Result { let mut upstream: HashMap = HashMap::default(); rp_settings.iter().for_each(|rpo| { - let vec_upstream: Vec = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(); - let lb_upstream_num = vec_upstream.len(); + let upstream_vec: Vec = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(); + // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()); + // let lb_upstream_num = vec_upstream.len(); let elem = UpstreamGroupBuilder::default() - .upstream(vec_upstream) + .upstream(&upstream_vec) .path(&rpo.path) .replace_path(&rpo.replace_path) - .lb(&rpo.load_balance, &lb_upstream_num) + .lb(&rpo.load_balance, &upstream_vec, server_name_string, &rpo.path) .opts(&rpo.upstream_options) .build() .unwrap(); diff --git a/src/constants.rs b/src/constants.rs index d2fc25f..6d4d8ad 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -24,3 +24,6 @@ pub mod H3 { pub const MAX_CONCURRENT_UNISTREAM: u32 = 64; pub const MAX_IDLE_TIMEOUT: u64 = 10; // secs } + +// For load-balancing with sticky cookie +pub const STICKY_COOKIE_NAME: &str = "rpxy_srv_id"; diff --git a/src/error.rs b/src/error.rs index 6da3b02..c5b34ad 100644 --- a/src/error.rs +++ b/src/error.rs @@ -22,6 +22,9 @@ pub enum RpxyError { #[error("TCP/UDP Proxy Layer Error: {0}")] Proxy(String), + #[error("LoadBalance Layer Error: {0}")] + LoadBalance(String), + #[error("I/O Error")] Io(#[from] io::Error), diff --git a/src/handler/handler_main.rs b/src/handler/handler_main.rs index 664f5e4..74f73a0 100644 --- a/src/handler/handler_main.rs +++ b/src/handler/handler_main.rs @@ -1,7 +1,7 @@ // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy -use super::{utils_headers::*, utils_request::*, utils_synth_response::*}; +use super::{utils_headers::*, utils_request::*, utils_synth_response::*, HandlerContext}; use crate::{ - backend::{Backend, UpstreamGroup}, + backend::{Backend, LoadBalance, UpstreamGroup}, error::*, globals::Globals, log::*, @@ -91,7 +91,7 @@ where let request_upgraded = req.extensions_mut().remove::(); // Build request from destination information - if let Err(e) = self.generate_request_forwarded( + let context = match self.generate_request_forwarded( &client_addr, &listen_addr, &mut req, @@ -99,8 +99,11 @@ where upstream_group, tls_enabled, ) { - error!("Failed to generate destination uri for reverse proxy: {}", e); - return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + Err(e) => { + error!("Failed to generate destination uri for reverse proxy: {}", e); + return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + } + Ok(v) => v, }; debug!("Request to be forwarded: {:?}", req); log_data.xff(&req.headers().get("x-forwarded-for")); @@ -123,6 +126,15 @@ where } }; + // Process reverse proxy context generated during the forwarding request generation. + if let Some(context_from_lb) = context.context_lb { + let res_headers = res_backend.headers_mut(); + if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { + error!("Failed to append context to the response given from backend: {}", e); + return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); + } + } + if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { // Generate response to client if self.generate_response_forwarded(&mut res_backend, backend).is_ok() { @@ -229,7 +241,7 @@ where upgrade: &Option, upstream_group: &UpstreamGroup, tls_enabled: bool, - ) -> Result<()> { + ) -> Result { debug!("Generate request to be forwarded"); // Add te: trailer if contained in original request @@ -265,10 +277,19 @@ where .insert(header::HOST, HeaderValue::from_str(&org_host)?); }; + ///////////////////////////////////////////// // Fix unique upstream destination since there could be multiple ones. - // TODO: StickyならCookieをここでgetに与える必要 - // TODO: Stickyで、Cookieが与えられなかったらset-cookie向けにcookieを返す必要。upstreamオブジェクトに含めるのも手。 - let upstream_chosen = upstream_group.get().ok_or_else(|| anyhow!("Failed to get upstream"))?; + let context_to_lb = if let LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { + takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? + } else { + None + }; + let (upstream_chosen_opt, context_from_lb) = upstream_group.get(&context_to_lb); + let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; + let context = HandlerContext { + context_lb: context_from_lb, + }; + ///////////////////////////////////////////// // apply upstream-specific headers given in upstream_option let headers = req.headers_mut(); @@ -321,6 +342,6 @@ where *req.version_mut() = Version::HTTP_2; } - Ok(()) + Ok(context) } } diff --git a/src/handler/mod.rs b/src/handler/mod.rs index c2225ce..fc30129 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -4,3 +4,10 @@ mod utils_request; mod utils_synth_response; pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; + +use crate::backend::LbContext; + +#[derive(Debug)] +struct HandlerContext { + context_lb: Option, +} diff --git a/src/handler/utils_headers.rs b/src/handler/utils_headers.rs index 7fc4a5f..3819386 100644 --- a/src/handler/utils_headers.rs +++ b/src/handler/utils_headers.rs @@ -1,5 +1,5 @@ use crate::{ - backend::{UpstreamGroup, UpstreamOption}, + backend::{LbContext, StickyCookie, StickyCookieValue, UpstreamGroup, UpstreamOption}, error::*, log::*, utils::*, @@ -14,6 +14,74 @@ use std::net::SocketAddr; //////////////////////////////////////////////////// // Functions to manipulate headers +/// Take sticky cookie header value from request header, +/// and returns LbContext to be forwarded to LB if exist and if needed. +/// Removing sticky cookie is needed and it must not be passed to the upstream. +pub(super) fn takeout_sticky_cookie_lb_context( + headers: &mut HeaderMap, + expected_cookie_name: &str, +) -> Result> { + let mut headers_clone = headers.clone(); + + match headers_clone.entry(hyper::header::COOKIE) { + header::Entry::Vacant(_) => Ok(None), + header::Entry::Occupied(entry) => { + let cookies_iter = entry + .iter() + .flat_map(|v| v.to_str().unwrap_or("").split(';').map(|v| v.trim())); + let (sticky_cookies, without_sticky_cookies): (Vec<_>, Vec<_>) = cookies_iter + .into_iter() + .partition(|v| v.starts_with(expected_cookie_name)); + if sticky_cookies.is_empty() { + return Ok(None); + } + if sticky_cookies.len() > 1 { + error!("Multiple sticky cookie values in request"); + return Err(RpxyError::Other(anyhow!( + "Invalid cookie: Multiple sticky cookie values" + ))); + } + let cookies_passed_to_upstream = without_sticky_cookies.join("; "); + let cookie_passed_to_lb = sticky_cookies.first().unwrap(); + headers.remove(hyper::header::COOKIE); + headers.insert(hyper::header::COOKIE, cookies_passed_to_upstream.parse()?); + + let sticky_cookie = StickyCookie { + value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?, + info: None, + }; + Ok(Some(LbContext { sticky_cookie })) + } + } +} + +/// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated. +/// Set-Cookie response header could be in multiple lines. +/// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie +pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from_lb: &LbContext) -> Result<()> { + let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?; + let new_header_val: HeaderValue = sticky_cookie_string.parse()?; + let expected_cookie_name = &context_from_lb.sticky_cookie.value.name; + match headers.entry(hyper::header::SET_COOKIE) { + header::Entry::Vacant(entry) => { + entry.insert(new_header_val); + } + header::Entry::Occupied(mut entry) => { + let mut flag = false; + for e in entry.iter_mut() { + if e.to_str().unwrap_or("").starts_with(expected_cookie_name) { + *e = new_header_val.clone(); + flag = true; + } + } + if !flag { + entry.append(new_header_val); + } + } + }; + Ok(()) +} + pub(super) fn apply_upstream_options_to_header( headers: &mut HeaderMap, _client_addr: &SocketAddr, From ba458b32de693f2ee33baf03bb7119acd1827a96 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 9 Jun 2023 02:18:26 +0900 Subject: [PATCH 07/10] refactor --- src/backend/load_balance.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/load_balance.rs b/src/backend/load_balance.rs index 647903f..b02a0f7 100644 --- a/src/backend/load_balance.rs +++ b/src/backend/load_balance.rs @@ -1,5 +1,5 @@ use super::{load_balance_sticky_cookie::StickyCookieConfig, LbContext, Upstream}; -use crate::{constants::STICKY_COOKIE_NAME, error::*, log::*}; +use crate::{constants::STICKY_COOKIE_NAME, log::*}; use derive_builder::Builder; use rand::Rng; use rustc_hash::FxHashMap as HashMap; From 8c2df78ead3d58a99e82d049dfa77ba1b2682ae1 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 9 Jun 2023 02:25:03 +0900 Subject: [PATCH 08/10] deps: serde --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 33b6080..23574c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ clap = { version = "4.3.2", features = ["std", "cargo", "wrap_help"] } rand = "0.8.5" toml = { version = "0.7.4", default-features = false, features = ["parse"] } rustc-hash = "1.1.0" -serde = { version = "1.0.163", default-features = false, features = ["derive"] } +serde = { version = "1.0.164", default-features = false, features = ["derive"] } bytes = "1.4.0" thiserror = "1.0.40" x509-parser = "0.15.0" From 5e3ac3b3615512a906345d3ca5ad7331a4512a53 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 9 Jun 2023 02:28:20 +0900 Subject: [PATCH 09/10] update docks --- CHANGELOG.md | 2 ++ TODO.md | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 80628bb..0d887b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,9 @@ ## 0.3.0 (unreleased) ### Improvement + - Update `h3` with `quinn-0.10` or higher. +- Implement the session persistance function for load balancing using sticky cookie (initial implementation). ## 0.2.0 diff --git a/TODO.md b/TODO.md index 3c44f82..3dbc6eb 100644 --- a/TODO.md +++ b/TODO.md @@ -4,7 +4,6 @@ - More flexible option for rewriting path - Refactoring - Unit tests -- Implementing load-balancing of backend apps (currently it doesn't consider to maintain session but simply rotate in a certain fashion) - Options to serve custom http_error page. - Prometheus metrics - Documentation @@ -13,4 +12,7 @@ - Currently, we took the following approach (caveats) - For Http2 and 1.1, prepare `rustls::ServerConfig` for each domain name and hence client CA cert is set for each one. - For Http3, use aggregated `rustls::ServerConfig` for multiple domain names except for ones requiring client-auth. So, if a domain name is set with client authentication, http3 doesn't work for the domain. +- Make the session-persistance option for load-balancing sophisticated. (mostly done in v0.3.0) + - add option for sticky cookie name + - add option for sticky cookie duration - etc. From 7234d3f399893fff1565addb22731c6c7f50999f Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 9 Jun 2023 02:29:09 +0900 Subject: [PATCH 10/10] submodule --- quinn | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quinn b/quinn index 65bbb1e..98f5fe2 160000 --- a/quinn +++ b/quinn @@ -1 +1 @@ -Subproject commit 65bbb1e154ad66874a7f2ed59d55a7dbaa67883b +Subproject commit 98f5fe2a3fabb9ff991f8c831e8d43de76985ff3