diff --git a/CHANGELOG.md b/CHANGELOG.md index 80628bb..0d887b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,9 @@ ## 0.3.0 (unreleased) ### Improvement + - Update `h3` with `quinn-0.10` or higher. +- Implement the session persistence function for load balancing using a sticky cookie (initial implementation). ## 0.2.0 diff --git a/Cargo.toml b/Cargo.toml index 10400e3..23574c6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ clap = { version = "4.3.2", features = ["std", "cargo", "wrap_help"] } rand = "0.8.5" toml = { version = "0.7.4", default-features = false, features = ["parse"] } rustc-hash = "1.1.0" -serde = { version = "1.0.163", default-features = false, features = ["derive"] } +serde = { version = "1.0.164", default-features = false, features = ["derive"] } bytes = "1.4.0" thiserror = "1.0.40" x509-parser = "0.15.0" @@ -65,6 +65,15 @@ h3 = { path = "./h3/h3/", optional = true } # h3-quinn = { path = "./h3/h3-quinn/", optional = true } h3-quinn = { path = "./h3-quinn/", optional = true } # Tentative to support rustls-0.21 +# cookie handling +chrono = { version = "0.4.26", default-features = false, features = [ + "unstable-locales", + "alloc", + "clock", +] } +base64 = "0.21.2" +sha2 = { version = "0.10.6", default-features = false } + [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = "0.5.0" diff --git a/README.md b/README.md index 9c8011e..53ffc91 100644 --- a/README.md +++ b/README.md @@ -108,7 +108,7 @@ revese_proxy = [ #### Load Balancing -You can specify multiple backend locations in the `reverse_proxy` array for *load-balancing*. Currently it works in the manner of round-robin. +You can specify multiple backend locations in the `reverse_proxy` array for *load-balancing* with an appropriate `load_balance` option. Currently it works in a round-robin or random fashion. If `load_balance` is not specified, the first backend location is always chosen. 
```toml [apps."app_name"] @@ -117,8 +117,11 @@ reverse_proxy = [ { location = 'app1.local:8080' }, { location = 'app2.local:8000' } ] +load_balance = 'round_robin' ``` +(TODO: Sticky session is currently being implemented) + ### Second Step: Terminating TLS First of all, you need to specify a port `listen_port_tls` listening the HTTPS traffic, separately from HTTP port (`listen_port`). Then, serving an HTTPS endpoint can be easily done for your desired application just by specifying TLS certificates and private keys in PEM files. diff --git a/TODO.md b/TODO.md index 3c44f82..3dbc6eb 100644 --- a/TODO.md +++ b/TODO.md @@ -4,7 +4,6 @@ - More flexible option for rewriting path - Refactoring - Unit tests -- Implementing load-balancing of backend apps (currently it doesn't consider to maintain session but simply rotate in a certain fashion) - Options to serve custom http_error page. - Prometheus metrics - Documentation @@ -13,4 +12,7 @@ - Currently, we took the following approach (caveats) - For Http2 and 1.1, prepare `rustls::ServerConfig` for each domain name and hence client CA cert is set for each one. - For Http3, use aggregated `rustls::ServerConfig` for multiple domain names except for ones requiring client-auth. So, if a domain name is set with client authentication, http3 doesn't work for the domain. +- Make the session-persistence option for load-balancing more sophisticated. (mostly done in v0.3.0) + - add option for sticky cookie name + - add option for sticky cookie duration - etc. 
diff --git a/config-example.toml b/config-example.toml index 9b2b463..0382393 100644 --- a/config-example.toml +++ b/config-example.toml @@ -52,6 +52,7 @@ upstream = [ { location = 'www.yahoo.com', tls = true }, { location = 'www.yahoo.co.jp', tls = true }, ] +load_balance = "round_robin" # or "random" or "sticky" (sticky session) or "none" (fix to the first one, default) upstream_options = ["override_host", "convert_https_to_2"] # Non-default destination in "localhost" app, which is routed by "path" @@ -60,13 +61,14 @@ path = '/maps' # For request path starting with "/maps", # this configuration results that any path like "/maps/org/any.ext" is mapped to "/replacing/path1/org/any.ext" # by replacing "/maps" with "/replacing/path1" for routing to the locations given in upstream array -# Note that unless "path_replaced_with" is specified, the "path" is always preserved. -# "path_replaced_with" must be start from "/" (root path) +# Note that unless "replace_path" is specified, the "path" is always preserved. 
+# "replace_path" must be start from "/" (root path) replace_path = "/replacing/path1" upstream = [ { location = 'www.bing.com', tls = true }, { location = 'www.bing.co.jp', tls = true }, ] +load_balance = "random" # or "round_robin" or "sticky" (sticky session) or "none" (fix to the first one, default) upstream_options = [ "override_host", "upgrade_insecure_requests", diff --git a/quinn b/quinn index 65bbb1e..98f5fe2 160000 --- a/quinn +++ b/quinn @@ -1 +1 @@ -Subproject commit 65bbb1e154ad66874a7f2ed59d55a7dbaa67883b +Subproject commit 98f5fe2a3fabb9ff991f8c831e8d43de76985ff3 diff --git a/src/backend/load_balance.rs b/src/backend/load_balance.rs new file mode 100644 index 0000000..b02a0f7 --- /dev/null +++ b/src/backend/load_balance.rs @@ -0,0 +1,239 @@ +use super::{load_balance_sticky_cookie::StickyCookieConfig, LbContext, Upstream}; +use crate::{constants::STICKY_COOKIE_NAME, log::*}; +use derive_builder::Builder; +use rand::Rng; +use rustc_hash::FxHashMap as HashMap; +use std::{ + borrow::Cow, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +/// Constants to specify a load balance option +pub(super) mod load_balance_options { + pub const FIX_TO_FIRST: &str = "none"; + pub const ROUND_ROBIN: &str = "round_robin"; + pub const RANDOM: &str = "random"; + pub const STICKY_ROUND_ROBIN: &str = "sticky"; +} + +#[derive(Debug, Clone)] +/// Pointer to upstream serving the incoming request. +/// If 'sticky cookie'-based LB is enabled and cookie must be updated/created, the new cookie is also given. 
+pub(super) struct PointerToUpstream { + pub ptr: usize, + pub context_lb: Option, +} +/// Trait for LB +trait LbWithPointer { + fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream; +} + +#[derive(Debug, Clone, Builder)] +/// Round Robin LB object as a pointer to the current serving upstream destination +pub struct LbRoundRobin { + #[builder(default)] + /// Pointer to the index of the last served upstream destination + ptr: Arc, + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, +} +impl LbRoundRobinBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } +} +impl LbWithPointer for LbRoundRobin { + /// Increment the count of upstream served up to the max value + fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { + // Get a current count of upstream served + let current_ptr = self.ptr.load(Ordering::Relaxed); + + let ptr = if current_ptr < self.num_upstreams - 1 { + self.ptr.fetch_add(1, Ordering::Relaxed) + } else { + // Clear the counter + self.ptr.fetch_and(0, Ordering::Relaxed) + }; + PointerToUpstream { ptr, context_lb: None } + } +} + +#[derive(Debug, Clone, Builder)] +/// Random LB object to keep the object of random pools +pub struct LbRandom { + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, +} +impl LbRandomBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } +} +impl LbWithPointer for LbRandom { + /// Returns the random index within the range + fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { + let mut rng = rand::thread_rng(); + let ptr = rng.gen_range(0..self.num_upstreams); + PointerToUpstream { ptr, context_lb: None } + } +} + +#[derive(Debug, Clone, Builder)] +/// Round Robin LB object in the sticky cookie manner +pub struct LbStickyRoundRobin { + #[builder(default)] + /// Pointer to 
the index of the last served upstream destination + ptr: Arc, + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, + #[builder(setter(custom))] + /// Information to build the cookie to stick clients to specific backends + pub sticky_config: StickyCookieConfig, + #[builder(setter(custom))] + /// Hashmaps: + /// - Hashmap that maps server indices to server id (string) + /// - Hashmap that maps server ids (string) to server indices, for fast reverse lookup + upstream_maps: UpstreamMap, +} +#[derive(Debug, Clone)] +pub struct UpstreamMap { + /// Hashmap that maps server indices to server id (string) + upstream_index_map: Vec, + /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup + upstream_id_map: HashMap, +} +impl LbStickyRoundRobinBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } + pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option) -> &mut Self { + self.sticky_config = Some(StickyCookieConfig { + name: STICKY_COOKIE_NAME.to_string(), // TODO: config等で変更できるように + domain: server_name.to_ascii_lowercase(), + path: if let Some(v) = path_opt { + v.to_ascii_lowercase() + } else { + "/".to_string() + }, + duration: 300, // TODO: config等で変更できるように + }); + self + } + pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self { + let upstream_index_map: Vec = upstream_vec + .iter() + .enumerate() + .map(|(i, v)| v.calculate_id_with_index(i)) + .collect(); + let mut upstream_id_map = HashMap::default(); + for (i, v) in upstream_index_map.iter().enumerate() { + upstream_id_map.insert(v.to_string(), i); + } + self.upstream_maps = Some(UpstreamMap { + upstream_index_map, + upstream_id_map, + }); + self + } +} +impl<'a> LbStickyRoundRobin { + fn simple_increment_ptr(&self) -> usize { + // Get a current count of upstream served + let current_ptr = self.ptr.load(Ordering::Relaxed); + + if current_ptr < 
self.num_upstreams - 1 { + self.ptr.fetch_add(1, Ordering::Relaxed) + } else { + // Clear the counter + self.ptr.fetch_and(0, Ordering::Relaxed) + } + } + /// This is always called only internally. So 'unwrap()' is executed. + fn get_server_id_from_index(&self, index: usize) -> String { + self.upstream_maps.upstream_index_map.get(index).unwrap().to_owned() + } + /// This function takes value passed from outside. So 'result' is used. + fn get_server_index_from_id(&self, id: impl Into>) -> Option { + let id_str = id.into().to_string(); + self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned()) + } +} +impl LbWithPointer for LbStickyRoundRobin { + fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream { + // If given context is None or invalid (not contained), get_ptr() is invoked to increment the pointer. + // Otherwise, get the server index indicated by the server_id inside the cookie + let ptr = match req_info { + None => { + debug!("No sticky cookie"); + self.simple_increment_ptr() + } + Some(context) => { + let server_id = &context.sticky_cookie.value.value; + if let Some(server_index) = self.get_server_index_from_id(server_id) { + debug!("Valid sticky cookie: id={}, index={}", server_id, server_index); + server_index + } else { + debug!("Invalid sticky cookie: id={}", server_id); + self.simple_increment_ptr() + } + } + }; + + // Get the server id from the ptr. + // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie). + let upstream_id = self.get_server_id_from_index(ptr); + let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap(); + let new_context = Some(LbContext { + sticky_cookie: new_cookie, + }); + PointerToUpstream { + ptr, + context_lb: new_context, + } + } +} + +#[derive(Debug, Clone)] +/// Load Balancing Option +pub enum LoadBalance { + /// Fix to the first upstream. 
Use if only one upstream destination is specified + FixToFirst, + /// Randomly chose one upstream server + Random(LbRandom), + /// Simple round robin without session persistance + RoundRobin(LbRoundRobin), + /// Round robin with session persistance using cookie + StickyRoundRobin(LbStickyRoundRobin), +} +impl Default for LoadBalance { + fn default() -> Self { + Self::FixToFirst + } +} + +impl LoadBalance { + /// Get the index of the upstream serving the incoming request + pub(super) fn get_context(&self, context_to_lb: &Option) -> PointerToUpstream { + match self { + LoadBalance::FixToFirst => PointerToUpstream { + ptr: 0usize, + context_lb: None, + }, + LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None), + LoadBalance::Random(ptr) => ptr.get_ptr(None), + LoadBalance::StickyRoundRobin(ptr) => { + // Generate new context if sticky round robin is enabled. + ptr.get_ptr(context_to_lb.as_ref()) + } + } + } +} diff --git a/src/backend/load_balance_sticky_cookie.rs b/src/backend/load_balance_sticky_cookie.rs new file mode 100644 index 0000000..d293004 --- /dev/null +++ b/src/backend/load_balance_sticky_cookie.rs @@ -0,0 +1,216 @@ +use std::borrow::Cow; + +use crate::error::*; +use chrono::{TimeZone, Utc}; +use derive_builder::Builder; + +#[derive(Debug, Clone)] +/// Struct to handle the sticky cookie string, +/// - passed from Rp module (http handler) to LB module, manipulated from req, only StickyCookieValue exists. +/// - passed from LB module to Rp module (http handler), will be inserted into res, StickyCookieValue and Info exist. 
+pub struct LbContext { + pub sticky_cookie: StickyCookie, +} + +#[derive(Debug, Clone, Builder)] +/// Cookie value only, used for COOKIE in req +pub struct StickyCookieValue { + #[builder(setter(custom))] + /// Field name indicating sticky cookie + pub name: String, + #[builder(setter(custom))] + /// Upstream server_id + pub value: String, +} +impl<'a> StickyCookieValueBuilder { + pub fn name(&mut self, v: impl Into>) -> &mut Self { + self.name = Some(v.into().to_ascii_lowercase()); + self + } + pub fn value(&mut self, v: impl Into>) -> &mut Self { + self.value = Some(v.into().to_string()); + self + } +} +impl StickyCookieValue { + pub fn try_from(value: &str, expected_name: &str) -> Result { + if !value.starts_with(expected_name) { + return Err(RpxyError::LoadBalance( + "Failed to cookie conversion from string".to_string(), + )); + }; + let kv = value.split('=').map(|v| v.trim()).collect::>(); + if kv.len() != 2 { + return Err(RpxyError::LoadBalance("Invalid cookie structure".to_string())); + }; + if kv[1].is_empty() { + return Err(RpxyError::LoadBalance("No sticky cookie value".to_string())); + } + Ok(StickyCookieValue { + name: expected_name.to_string(), + value: kv[1].to_string(), + }) + } +} + +#[derive(Debug, Clone, Builder)] +/// Struct describing sticky cookie meta information used for SET-COOKIE in res +pub struct StickyCookieInfo { + #[builder(setter(custom))] + /// Unix time + pub expires: i64, + + #[builder(setter(custom))] + /// Domain + pub domain: String, + + #[builder(setter(custom))] + /// Path + pub path: String, +} +impl<'a> StickyCookieInfoBuilder { + pub fn domain(&mut self, v: impl Into>) -> &mut Self { + self.domain = Some(v.into().to_ascii_lowercase()); + self + } + pub fn path(&mut self, v: impl Into>) -> &mut Self { + self.path = Some(v.into().to_ascii_lowercase()); + self + } + pub fn expires(&mut self, duration_secs: i64) -> &mut Self { + let current = Utc::now().timestamp(); + self.expires = Some(current + duration_secs); + self + } +} 
+ +#[derive(Debug, Clone, Builder)] +/// Struct describing sticky cookie +pub struct StickyCookie { + #[builder(setter(custom))] + /// Upstream server_id + pub value: StickyCookieValue, + #[builder(setter(custom), default)] + /// Upstream server_id + pub info: Option, +} + +impl<'a> StickyCookieBuilder { + pub fn value(&mut self, n: impl Into>, v: impl Into>) -> &mut Self { + self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap()); + self + } + pub fn info( + &mut self, + domain: impl Into>, + path: impl Into>, + duration_secs: i64, + ) -> &mut Self { + let info = StickyCookieInfoBuilder::default() + .domain(domain) + .path(path) + .expires(duration_secs) + .build() + .unwrap(); + self.info = Some(Some(info)); + self + } +} + +impl TryInto for StickyCookie { + type Error = RpxyError; + + fn try_into(self) -> Result { + if self.info.is_none() { + return Err(RpxyError::LoadBalance( + "Failed to cookie conversion into string: no meta information".to_string(), + )); + } + let info = self.info.unwrap(); + let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else { + return Err(RpxyError::LoadBalance("Failed to cookie conversion into string".to_string())); + }; + let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string(); + let max_age = info.expires - Utc::now().timestamp(); + + Ok(format!( + "{}={}; expires={}; Max-Age={}; path={}; domain={}; HttpOnly", + self.value.name, self.value.value, exp_str, max_age, info.path, info.domain + )) + } +} + +#[derive(Debug, Clone)] +/// Configuration to serve incoming requests in the manner of "sticky cookie". +/// Including a dictionary to map Ids included in cookie and upstream destinations, +/// and expiration of cookie. +/// "domain" and "path" in the cookie will be the same as the reverse proxy options. 
+pub struct StickyCookieConfig { + pub name: String, + pub domain: String, + pub path: String, + pub duration: i64, +} +impl<'a> StickyCookieConfig { + pub fn build_sticky_cookie(&self, v: impl Into>) -> Result { + StickyCookieBuilder::default() + .value(self.name.clone(), v) + .info(&self.domain, &self.path, self.duration) + .build() + .map_err(|_| RpxyError::LoadBalance("Failed to build sticky cookie from config".to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::constants::STICKY_COOKIE_NAME; + + #[test] + fn config_works() { + let config = StickyCookieConfig { + name: STICKY_COOKIE_NAME.to_string(), + domain: "example.com".to_string(), + path: "/path".to_string(), + duration: 100, + }; + let expires_unix = Utc::now().timestamp() + 100; + let sc_string: Result = config.build_sticky_cookie("test_value").unwrap().try_into(); + let expires_date_string = Utc + .timestamp_opt(expires_unix, 0) + .unwrap() + .format("%a, %d-%b-%Y %T GMT") + .to_string(); + assert_eq!( + sc_string.unwrap(), + format!( + "{}=test_value; expires={}; Max-Age={}; path=/path; domain=example.com; HttpOnly", + STICKY_COOKIE_NAME, expires_date_string, 100 + ) + ); + } + #[test] + fn to_string_works() { + let sc = StickyCookie { + value: StickyCookieValue { + name: STICKY_COOKIE_NAME.to_string(), + value: "test_value".to_string(), + }, + info: Some(StickyCookieInfo { + expires: 1686221173i64, + domain: "example.com".to_string(), + path: "/path".to_string(), + }), + }; + let sc_string: Result = sc.try_into(); + let max_age = 1686221173i64 - Utc::now().timestamp(); + assert!(sc_string.is_ok()); + assert_eq!( + sc_string.unwrap(), + format!( + "{}=test_value; expires=Thu, 08-Jun-2023 10:46:13 GMT; Max-Age={}; path=/path; domain=example.com; HttpOnly", + STICKY_COOKIE_NAME, max_age + ) + ); + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 409222d..9164c45 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,6 +1,14 @@ +mod load_balance; +mod 
load_balance_sticky_cookie; mod upstream; mod upstream_opts; +pub use self::{ + load_balance::LoadBalance, + load_balance_sticky_cookie::{LbContext, StickyCookie, StickyCookieBuilder, StickyCookieValue}, + upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, + upstream_opts::UpstreamOption, +}; use crate::{ log::*, utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, @@ -20,20 +28,21 @@ use tokio_rustls::rustls::{ sign::{any_supported_type, CertifiedKey}, Certificate, PrivateKey, ServerConfig, }; -pub use upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}; -pub use upstream_opts::UpstreamOption; use x509_parser::prelude::*; /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings. #[derive(Builder)] pub struct Backend { #[builder(setter(into))] + /// backend application name, e.g., app1 pub app_name: String, #[builder(setter(custom))] + /// server name, e.g., example.com, in String ascii lower case pub server_name: String, + /// struct of reverse proxy serving incoming request pub reverse_proxy: ReverseProxy, - // tls settings + /// tls settings #[builder(setter(custom), default)] pub tls_cert_path: Option, #[builder(setter(custom), default)] diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index c2fdb34..9e53e5d 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -1,22 +1,23 @@ -use super::{BytesName, PathNameBytesExp, UpstreamOption}; -use crate::log::*; -use derive_builder::Builder; -use rand::Rng; -use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; -use std::{ - borrow::Cow, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, +use super::{ + load_balance::{ + load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LbStickyRoundRobinBuilder, LoadBalance, }, + load_balance_sticky_cookie::LbContext, + BytesName, PathNameBytesExp, UpstreamOption, }; - +use crate::log::*; +use 
base64::{engine::general_purpose, Engine as _}; +use derive_builder::Builder; +use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use sha2::{Digest, Sha256}; +use std::borrow::Cow; #[derive(Debug, Clone)] pub struct ReverseProxy { pub upstream: HashMap, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。 } impl ReverseProxy { + /// Get an appropriate upstream destination for given path string. pub fn get<'a>(&self, path_str: impl Into>) -> Option<&UpstreamGroup> { // trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、 // コスト的にこの程度で十分 @@ -50,38 +51,48 @@ impl ReverseProxy { } } -#[allow(dead_code)] #[derive(Debug, Clone)] -pub enum LoadBalance { - RoundRobin, - Random, +/// Upstream struct just containing uri without path +pub struct Upstream { + /// Base uri without specific path + pub uri: hyper::Uri, } -impl Default for LoadBalance { - fn default() -> Self { - Self::RoundRobin +impl Upstream { + /// Hashing uri with index to avoid collision + pub fn calculate_id_with_index(&self, index: usize) -> String { + let mut hasher = Sha256::new(); + let uri_string = format!("{}&index={}", self.uri.clone(), index); + hasher.update(uri_string.as_bytes()); + let digest = hasher.finalize(); + general_purpose::URL_SAFE_NO_PAD.encode(digest) } } - -#[derive(Debug, Clone)] -pub struct Upstream { - pub uri: hyper::Uri, // base uri without specific path -} - #[derive(Debug, Clone, Builder)] +/// Struct serving multiple upstream servers for, e.g., load balancing. 
pub struct UpstreamGroup { + #[builder(setter(custom))] + /// Upstream server(s) pub upstream: Vec, #[builder(setter(custom), default)] + /// Path like "/path" in [[PathNameBytesExp]] associated with the upstream server(s) pub path: PathNameBytesExp, #[builder(setter(custom), default)] + /// Path in [[PathNameBytesExp]] that will be used to replace the "path" part of incoming url pub replace_path: Option, - #[builder(default)] - pub lb: LoadBalance, - #[builder(default)] - pub cnt: UpstreamCount, // counter for load balancing + #[builder(setter(custom), default)] + /// Load balancing option + pub lb: LoadBalance, + #[builder(setter(custom), default)] + /// Activated upstream options defined in [[UpstreamOption]] pub opts: HashSet, } + impl UpstreamGroupBuilder { + pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self { + self.upstream = Some(upstream_vec.to_vec()); + self + } pub fn path(&mut self, v: &Option) -> &mut Self { let path = match v { Some(p) => p.to_path_name_vec(), @@ -98,6 +109,44 @@ impl UpstreamGroupBuilder { ); self } + pub fn lb( + &mut self, + v: &Option, + // upstream_num: &usize, + upstream_vec: &Vec, + server_name: &str, + path_opt: &Option, + ) -> &mut Self { + let upstream_num = &upstream_vec.len(); + let lb = if let Some(x) = v { + match x.as_str() { + lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, + lb_opts::RANDOM => LoadBalance::Random(LbRandomBuilder::default().num_upstreams(upstream_num).build().unwrap()), + lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin( + LbRoundRobinBuilder::default() + .num_upstreams(upstream_num) + .build() + .unwrap(), + ), + lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( + LbStickyRoundRobinBuilder::default() + .num_upstreams(upstream_num) + .sticky_config(server_name, path_opt) + .upstream_maps(upstream_vec) // TODO: + .build() + .unwrap(), + ), + _ => { + error!("Specified load balancing option is invalid."); + LoadBalance::default() + } + } + } else { + LoadBalance::default() + 
}; + self.lb = Some(lb); + self + } pub fn opts(&mut self, v: &Option>) -> &mut Self { let opts = if let Some(opts) = v { opts @@ -112,33 +161,37 @@ impl UpstreamGroupBuilder { } } -#[derive(Debug, Clone, Default)] -pub struct UpstreamCount(Arc); - impl UpstreamGroup { - pub fn get(&self) -> Option<&Upstream> { - match self.lb { - LoadBalance::RoundRobin => { - let idx = self.increment_cnt(); - self.upstream.get(idx) - } - LoadBalance::Random => { - let mut rng = rand::thread_rng(); - let max = self.upstream.len() - 1; - self.upstream.get(rng.gen_range(0..max)) - } - } - } - - fn current_cnt(&self) -> usize { - self.cnt.0.load(Ordering::Relaxed) - } - - fn increment_cnt(&self) -> usize { - if self.current_cnt() < self.upstream.len() - 1 { - self.cnt.0.fetch_add(1, Ordering::Relaxed) - } else { - self.cnt.0.fetch_and(0, Ordering::Relaxed) - } + /// Get an enabled option of load balancing [[LoadBalance]] + pub fn get(&self, context_to_lb: &Option) -> (Option<&Upstream>, Option) { + let pointer_to_upstream = self.lb.get_context(context_to_lb); + debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); + debug!("Context to LB (Cookie in Req): {:?}", context_to_lb); + debug!( + "Context from LB (Set-Cookie in Res): {:?}", + pointer_to_upstream.context_lb + ); + ( + self.upstream.get(pointer_to_upstream.ptr), + pointer_to_upstream.context_lb, + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn calc_id_works() { + let uri = "https://www.rust-lang.org".parse::().unwrap(); + let upstream = Upstream { uri }; + assert_eq!( + "eGsjoPbactQ1eUJjafYjPT3ekYZQkaqJnHdA_FMSkgM", + upstream.calculate_id_with_index(0) + ); + assert_eq!( + "tNVXFJ9eNCT2mFgKbYq35XgH5q93QZtfU8piUiiDxVA", + upstream.calculate_id_with_index(1) + ); } } diff --git a/src/config/parse.rs b/src/config/parse.rs index 935f86c..8e4ddf7 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -1,6 +1,6 @@ use super::toml::{ConfigToml, ReverseProxyOption}; use crate::{ - 
backend::{BackendBuilder, ReverseProxy, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption}, + backend::{BackendBuilder, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption}, constants::*, error::*, globals::*, @@ -99,7 +99,7 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro let mut backend_builder = BackendBuilder::default(); // reverse proxy settings ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); - let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; + let reverse_proxy = get_reverse_proxy(server_name_string, app.reverse_proxy.as_ref().unwrap())?; backend_builder .app_name(server_name_string) @@ -198,13 +198,21 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro Ok(()) } -fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result { +fn get_reverse_proxy( + server_name_string: &str, + rp_settings: &[ReverseProxyOption], +) -> std::result::Result { let mut upstream: HashMap = HashMap::default(); + rp_settings.iter().for_each(|rpo| { + let upstream_vec: Vec = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(); + // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()); + // let lb_upstream_num = vec_upstream.len(); let elem = UpstreamGroupBuilder::default() - .upstream(rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect()) + .upstream(&upstream_vec) .path(&rpo.path) .replace_path(&rpo.replace_path) + .lb(&rpo.load_balance, &upstream_vec, server_name_string, &rpo.path) .opts(&rpo.upstream_options) .build() .unwrap(); diff --git a/src/config/toml.rs b/src/config/toml.rs index cefacb2..6ce48b2 100644 --- a/src/config/toml.rs +++ b/src/config/toml.rs @@ -57,6 +57,7 @@ pub struct ReverseProxyOption { pub replace_path: Option, pub upstream: Vec, pub upstream_options: Option>, + pub load_balance: Option, } #[derive(Deserialize, Debug, Default)] diff --git a/src/constants.rs 
b/src/constants.rs index d2fc25f..6d4d8ad 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -24,3 +24,6 @@ pub mod H3 { pub const MAX_CONCURRENT_UNISTREAM: u32 = 64; pub const MAX_IDLE_TIMEOUT: u64 = 10; // secs } + +// For load-balancing with sticky cookie +pub const STICKY_COOKIE_NAME: &str = "rpxy_srv_id"; diff --git a/src/error.rs b/src/error.rs index 6da3b02..c5b34ad 100644 --- a/src/error.rs +++ b/src/error.rs @@ -22,6 +22,9 @@ pub enum RpxyError { #[error("TCP/UDP Proxy Layer Error: {0}")] Proxy(String), + #[error("LoadBalance Layer Error: {0}")] + LoadBalance(String), + #[error("I/O Error")] Io(#[from] io::Error), diff --git a/src/handler/handler_main.rs b/src/handler/handler_main.rs index 2ef5665..74f73a0 100644 --- a/src/handler/handler_main.rs +++ b/src/handler/handler_main.rs @@ -1,7 +1,7 @@ // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy -use super::{utils_headers::*, utils_request::*, utils_synth_response::*}; +use super::{utils_headers::*, utils_request::*, utils_synth_response::*, HandlerContext}; use crate::{ - backend::{Backend, UpstreamGroup}, + backend::{Backend, LoadBalance, UpstreamGroup}, error::*, globals::Globals, log::*, @@ -91,7 +91,7 @@ where let request_upgraded = req.extensions_mut().remove::(); // Build request from destination information - if let Err(e) = self.generate_request_forwarded( + let context = match self.generate_request_forwarded( &client_addr, &listen_addr, &mut req, @@ -99,8 +99,11 @@ where upstream_group, tls_enabled, ) { - error!("Failed to generate destination uri for reverse proxy: {}", e); - return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + Err(e) => { + error!("Failed to generate destination uri for reverse proxy: {}", e); + return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + } + Ok(v) => v, }; debug!("Request to be forwarded: {:?}", req); log_data.xff(&req.headers().get("x-forwarded-for")); @@ -123,6 +126,15 @@ 
where } }; + // Process reverse proxy context generated during the forwarding request generation. + if let Some(context_from_lb) = context.context_lb { + let res_headers = res_backend.headers_mut(); + if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { + error!("Failed to append context to the response given from backend: {}", e); + return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); + } + } + if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { // Generate response to client if self.generate_response_forwarded(&mut res_backend, backend).is_ok() { @@ -229,7 +241,7 @@ where upgrade: &Option, upstream_group: &UpstreamGroup, tls_enabled: bool, - ) -> Result<()> { + ) -> Result { debug!("Generate request to be forwarded"); // Add te: trailer if contained in original request @@ -265,8 +277,19 @@ where .insert(header::HOST, HeaderValue::from_str(&org_host)?); }; + ///////////////////////////////////////////// // Fix unique upstream destination since there could be multiple ones. - let upstream_chosen = upstream_group.get().ok_or_else(|| anyhow!("Failed to get upstream"))?; + let context_to_lb = if let LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { + takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? 
+ } else { + None + }; + let (upstream_chosen_opt, context_from_lb) = upstream_group.get(&context_to_lb); + let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; + let context = HandlerContext { + context_lb: context_from_lb, + }; + ///////////////////////////////////////////// // apply upstream-specific headers given in upstream_option let headers = req.headers_mut(); @@ -319,6 +342,6 @@ where *req.version_mut() = Version::HTTP_2; } - Ok(()) + Ok(context) } } diff --git a/src/handler/mod.rs b/src/handler/mod.rs index c2225ce..fc30129 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -4,3 +4,10 @@ mod utils_request; mod utils_synth_response; pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; + +use crate::backend::LbContext; + +#[derive(Debug)] +struct HandlerContext { + context_lb: Option, +} diff --git a/src/handler/utils_headers.rs b/src/handler/utils_headers.rs index 7fc4a5f..3819386 100644 --- a/src/handler/utils_headers.rs +++ b/src/handler/utils_headers.rs @@ -1,5 +1,5 @@ use crate::{ - backend::{UpstreamGroup, UpstreamOption}, + backend::{LbContext, StickyCookie, StickyCookieValue, UpstreamGroup, UpstreamOption}, error::*, log::*, utils::*, @@ -14,6 +14,74 @@ use std::net::SocketAddr; //////////////////////////////////////////////////// // Functions to manipulate headers +/// Take sticky cookie header value from request header, +/// and returns LbContext to be forwarded to LB if exist and if needed. +/// Removing sticky cookie is needed and it must not be passed to the upstream. 
+///
+/// Returns `Ok(None)` when the request carries no sticky cookie (headers are left untouched).
+/// Returns an error when more than one sticky cookie is present, when the sticky cookie is
+/// rejected by `StickyCookieValue::try_from`, or when the rebuilt `Cookie` value cannot be parsed.
+/// A cookie counts as the sticky cookie only if its name equals `expected_cookie_name` exactly.
+pub(super) fn takeout_sticky_cookie_lb_context(
+  headers: &mut HeaderMap,
+  expected_cookie_name: &str,
+) -> Result<Option<LbContext>> {
+  // Exact-name match: a plain prefix test would also capture unrelated cookies such as
+  // `<name>_other=...` and strip them from the request forwarded to the upstream.
+  // A bare `<name>` without '=' is still handed to the LB parser so that it is validated there.
+  let is_sticky = |cookie: &str| {
+    cookie
+      .strip_prefix(expected_cookie_name)
+      .map_or(false, |rest| rest.is_empty() || rest.starts_with('='))
+  };
+
+  // Read the Cookie header values in place instead of cloning the whole header map.
+  // Empty segments (including values that `to_str` cannot decode) are skipped.
+  let (sticky_cookies, without_sticky_cookies): (Vec<&str>, Vec<&str>) = headers
+    .get_all(hyper::header::COOKIE)
+    .iter()
+    .flat_map(|v| v.to_str().unwrap_or("").split(';'))
+    .map(str::trim)
+    .filter(|cookie| !cookie.is_empty())
+    .partition(|&cookie| is_sticky(cookie));
+
+  let cookie_passed_to_lb = match sticky_cookies.as_slice() {
+    [] => return Ok(None),
+    [only] => only,
+    _ => {
+      error!("Multiple sticky cookie values in request");
+      return Err(RpxyError::Other(anyhow!(
+        "Invalid cookie: Multiple sticky cookie values"
+      )));
+    }
+  };
+  let value = StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?;
+
+  // Forward only the remaining cookies; never emit an empty `Cookie:` header to the upstream.
+  let cookies_passed_to_upstream = without_sticky_cookies.join("; ");
+  headers.remove(hyper::header::COOKIE);
+  if !cookies_passed_to_upstream.is_empty() {
+    headers.insert(hyper::header::COOKIE, cookies_passed_to_upstream.parse()?);
+  }
+
+  Ok(Some(LbContext {
+    sticky_cookie: StickyCookie { value, info: None },
+  }))
+}
+
+/// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated.
+/// Set-Cookie response header could be in multiple lines.
+/// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie
+///
+/// An existing `Set-Cookie` value is replaced only when its cookie name equals the sticky cookie
+/// name exactly; otherwise the sticky cookie is appended as an additional `Set-Cookie` value.
+/// Returns an error when the sticky cookie cannot be serialized into a valid header value.
+pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from_lb: &LbContext) -> Result<()> {
+  let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?;
+  let new_header_val: HeaderValue = sticky_cookie_string.parse()?;
+  let expected_cookie_name = &context_from_lb.sticky_cookie.value.name;
+  match headers.entry(hyper::header::SET_COOKIE) {
+    header::Entry::Vacant(entry) => {
+      entry.insert(new_header_val);
+    }
+    header::Entry::Occupied(mut entry) => {
+      let mut replaced = false;
+      for existing in entry.iter_mut() {
+        // `Set-Cookie: <name>=<value>; ...` -- require '=' right after the name so that a backend
+        // cookie merely sharing the prefix (e.g. `<name>_other=...`) is never overwritten.
+        let is_sticky = existing
+          .to_str()
+          .unwrap_or("")
+          .strip_prefix(expected_cookie_name)
+          .map_or(false, |rest| rest.starts_with('='));
+        if is_sticky {
+          *existing = new_header_val.clone();
+          replaced = true;
+        }
+      }
+      if !replaced {
+        entry.append(new_header_val);
+      }
+    }
+  };
+  Ok(())
+}
+
 pub(super) fn apply_upstream_options_to_header(
   headers: &mut HeaderMap,
   _client_addr: &SocketAddr,
diff --git a/src/utils/bytes_name.rs b/src/utils/bytes_name.rs
index 80bc0f0..16ec7ab 100644
--- a/src/utils/bytes_name.rs
+++ b/src/utils/bytes_name.rs
@@ -1,5 +1,5 @@
 /// Server name (hostname or ip address) representation in bytes-based struct
-/// For searching hashmap or key list by exact or longest-prefix matching
+/// for searching hashmap or key list by exact or longest-prefix matching
 #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)]
 pub struct ServerNameBytesExp(pub Vec); // lowercase ascii bytes
 impl From<&[u8]> for ServerNameBytesExp {
@@ -8,8 +8,8 @@ impl From<&[u8]> for ServerNameBytesExp {
   }
 }
 
-/// Server name (hostname or ip address) representation in bytes-based struct
-/// For searching hashmap or key list by exact or longest-prefix matching
+/// Path name, like "/path/ok", represented in bytes-based struct
+/// for searching hashmap or key list by exact or longest-prefix matching
 #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)]
 pub struct PathNameBytesExp(pub Vec); // lowercase ascii bytes
 impl PathNameBytesExp {