diff --git a/Cargo.toml b/Cargo.toml index 10400e3..33b6080 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,6 +65,15 @@ h3 = { path = "./h3/h3/", optional = true } # h3-quinn = { path = "./h3/h3-quinn/", optional = true } h3-quinn = { path = "./h3-quinn/", optional = true } # Tentative to support rustls-0.21 +# cookie handling +chrono = { version = "0.4.26", default-features = false, features = [ + "unstable-locales", + "alloc", + "clock", +] } +base64 = "0.21.2" +sha2 = { version = "0.10.6", default-features = false } + [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = "0.5.0" diff --git a/src/backend/load_balance.rs b/src/backend/load_balance.rs index 8ff8b4b..647903f 100644 --- a/src/backend/load_balance.rs +++ b/src/backend/load_balance.rs @@ -1,8 +1,14 @@ +use super::{load_balance_sticky_cookie::StickyCookieConfig, LbContext, Upstream}; +use crate::{constants::STICKY_COOKIE_NAME, error::*, log::*}; use derive_builder::Builder; use rand::Rng; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, +use rustc_hash::FxHashMap as HashMap; +use std::{ + borrow::Cow, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, }; /// Constants to specify a load balance option @@ -13,6 +19,18 @@ pub(super) mod load_balance_options { pub const STICKY_ROUND_ROBIN: &str = "sticky"; } +#[derive(Debug, Clone)] +/// Pointer to upstream serving the incoming request. +/// If 'sticky cookie'-based LB is enabled and cookie must be updated/created, the new cookie is also given. +pub(super) struct PointerToUpstream { + pub ptr: usize, + pub context_lb: Option, +} +/// Trait for LB +trait LbWithPointer { + fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream; +} + #[derive(Debug, Clone, Builder)] /// Round Robin LB object as a pointer to the current serving upstream destination pub struct LbRoundRobin { @@ -29,20 +47,19 @@ impl LbRoundRobinBuilder { self } } -impl LbRoundRobin { - /// Get a current count of upstream served - fn current_ptr(&self) -> usize { - self.ptr.load(Ordering::Relaxed) - } - +impl LbWithPointer for LbRoundRobin { /// Increment the count of upstream served up to the max value - pub fn increment_ptr(&self) -> usize { - if self.current_ptr() < self.num_upstreams - 1 { + fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { + // Get a current count of upstream served + let current_ptr = self.ptr.load(Ordering::Relaxed); + + let ptr = if current_ptr < self.num_upstreams - 1 { self.ptr.fetch_add(1, Ordering::Relaxed) } else { // Clear the counter self.ptr.fetch_and(0, Ordering::Relaxed) - } + }; + PointerToUpstream { ptr, context_lb: None } } } @@ -59,11 +76,129 @@ impl LbRandomBuilder { self } } -impl LbRandom { +impl LbWithPointer for LbRandom { /// Returns the random index within the range - pub fn get_ptr(&self) -> usize { + fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { let mut rng = rand::thread_rng(); - rng.gen_range(0..self.num_upstreams) + let ptr = rng.gen_range(0..self.num_upstreams); + PointerToUpstream { ptr, context_lb: None } + } +} + +#[derive(Debug, Clone, Builder)] +/// Round Robin LB object in the sticky cookie manner +pub struct LbStickyRoundRobin { + #[builder(default)] + /// Pointer to the index of the last served upstream destination + ptr: Arc, + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, + #[builder(setter(custom))] + /// Information to build the cookie to stick clients to specific backends + pub sticky_config: StickyCookieConfig, + #[builder(setter(custom))] + /// Hashmaps: + /// - Hashmap that maps server indices to server id (string) + /// - Hashmap that maps server ids (string) to server indices, for fast reverse lookup + upstream_maps: UpstreamMap, +} +#[derive(Debug, Clone)] +pub struct UpstreamMap { + /// Hashmap that maps server indices to server id (string) + upstream_index_map: Vec, + /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup + upstream_id_map: HashMap, +} +impl LbStickyRoundRobinBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } + pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option) -> &mut Self { + self.sticky_config = Some(StickyCookieConfig { + name: STICKY_COOKIE_NAME.to_string(), // TODO: config等で変更できるように + domain: server_name.to_ascii_lowercase(), + path: if let Some(v) = path_opt { + v.to_ascii_lowercase() + } else { + "/".to_string() + }, + duration: 300, // TODO: config等で変更できるように + }); + self + } + pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self { + let upstream_index_map: Vec = upstream_vec + .iter() + .enumerate() + .map(|(i, v)| v.calculate_id_with_index(i)) + .collect(); + let mut upstream_id_map = HashMap::default(); + for (i, v) in upstream_index_map.iter().enumerate() { + upstream_id_map.insert(v.to_string(), i); + } + self.upstream_maps = Some(UpstreamMap { + upstream_index_map, + upstream_id_map, + }); + self + } +} +impl<'a> LbStickyRoundRobin { + fn simple_increment_ptr(&self) -> usize { + // Get a current count of upstream served + let current_ptr = self.ptr.load(Ordering::Relaxed); + + if current_ptr < self.num_upstreams - 1 { + self.ptr.fetch_add(1, Ordering::Relaxed) + } else { + // Clear the counter + self.ptr.fetch_and(0, Ordering::Relaxed) + } + } + /// This is always called only internally. So 'unwrap()' is executed. + fn get_server_id_from_index(&self, index: usize) -> String { + self.upstream_maps.upstream_index_map.get(index).unwrap().to_owned() + } + /// This function takes value passed from outside. So 'result' is used. + fn get_server_index_from_id(&self, id: impl Into>) -> Option { + let id_str = id.into().to_string(); + self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned()) + } +} +impl LbWithPointer for LbStickyRoundRobin { + fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream { + // If given context is None or invalid (not contained), get_ptr() is invoked to increment the pointer. + // Otherwise, get the server index indicated by the server_id inside the cookie + let ptr = match req_info { + None => { + debug!("No sticky cookie"); + self.simple_increment_ptr() + } + Some(context) => { + let server_id = &context.sticky_cookie.value.value; + if let Some(server_index) = self.get_server_index_from_id(server_id) { + debug!("Valid sticky cookie: id={}, index={}", server_id, server_index); + server_index + } else { + debug!("Invalid sticky cookie: id={}", server_id); + self.simple_increment_ptr() + } + } + }; + + // Get the server id from the ptr. + // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie). + let upstream_id = self.get_server_id_from_index(ptr); + let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap(); + let new_context = Some(LbContext { + sticky_cookie: new_cookie, + }); + PointerToUpstream { + ptr, + context_lb: new_context, + } } } @@ -77,7 +212,7 @@ pub enum LoadBalance { /// Simple round robin without session persistance RoundRobin(LbRoundRobin), /// Round robin with session persistance using cookie - StickyRoundRobin(LbRoundRobin), + StickyRoundRobin(LbStickyRoundRobin), } impl Default for LoadBalance { fn default() -> Self { @@ -87,12 +222,18 @@ impl Default for LoadBalance { impl LoadBalance { /// Get the index of the upstream serving the incoming request - pub(super) fn get_idx(&self) -> usize { + pub(super) fn get_context(&self, context_to_lb: &Option) -> PointerToUpstream { match self { - LoadBalance::FixToFirst => 0usize, - LoadBalance::RoundRobin(ptr) => ptr.increment_ptr(), - LoadBalance::Random(v) => v.get_ptr(), - LoadBalance::StickyRoundRobin(_ptr) => 0usize, // todo!(), // TODO: TODO: TODO: TODO: tentative value + LoadBalance::FixToFirst => PointerToUpstream { + ptr: 0usize, + context_lb: None, + }, + LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None), + LoadBalance::Random(ptr) => ptr.get_ptr(None), + LoadBalance::StickyRoundRobin(ptr) => { + // Generate new context if sticky round robin is enabled. + ptr.get_ptr(context_to_lb.as_ref()) + } } } } diff --git a/src/backend/load_balance_sticky_cookie.rs b/src/backend/load_balance_sticky_cookie.rs new file mode 100644 index 0000000..d293004 --- /dev/null +++ b/src/backend/load_balance_sticky_cookie.rs @@ -0,0 +1,216 @@ +use std::borrow::Cow; + +use crate::error::*; +use chrono::{TimeZone, Utc}; +use derive_builder::Builder; + +#[derive(Debug, Clone)] +/// Struct to handle the sticky cookie string, +/// - passed from Rp module (http handler) to LB module, manipulated from req, only StickyCookieValue exists. +/// - passed from LB module to Rp module (http handler), will be inserted into res, StickyCookieValue and Info exist. +pub struct LbContext { + pub sticky_cookie: StickyCookie, +} + +#[derive(Debug, Clone, Builder)] +/// Cookie value only, used for COOKIE in req +pub struct StickyCookieValue { + #[builder(setter(custom))] + /// Field name indicating sticky cookie + pub name: String, + #[builder(setter(custom))] + /// Upstream server_id + pub value: String, +} +impl<'a> StickyCookieValueBuilder { + pub fn name(&mut self, v: impl Into>) -> &mut Self { + self.name = Some(v.into().to_ascii_lowercase()); + self + } + pub fn value(&mut self, v: impl Into>) -> &mut Self { + self.value = Some(v.into().to_string()); + self + } +} +impl StickyCookieValue { + pub fn try_from(value: &str, expected_name: &str) -> Result { + if !value.starts_with(expected_name) { + return Err(RpxyError::LoadBalance( + "Failed to cookie conversion from string".to_string(), + )); + }; + let kv = value.split('=').map(|v| v.trim()).collect::>(); + if kv.len() != 2 { + return Err(RpxyError::LoadBalance("Invalid cookie structure".to_string())); + }; + if kv[1].is_empty() { + return Err(RpxyError::LoadBalance("No sticky cookie value".to_string())); + } + Ok(StickyCookieValue { + name: expected_name.to_string(), + value: kv[1].to_string(), + }) + } +} + +#[derive(Debug, Clone, Builder)] +/// Struct describing sticky cookie meta information used for SET-COOKIE in res +pub struct StickyCookieInfo { + #[builder(setter(custom))] + /// Unix time + pub expires: i64, + + #[builder(setter(custom))] + /// Domain + pub domain: String, + + #[builder(setter(custom))] + /// Path + pub path: String, +} +impl<'a> StickyCookieInfoBuilder { + pub fn domain(&mut self, v: impl Into>) -> &mut Self { + self.domain = Some(v.into().to_ascii_lowercase()); + self + } + pub fn path(&mut self, v: impl Into>) -> &mut Self { + self.path = Some(v.into().to_ascii_lowercase()); + self + } + pub fn expires(&mut self, duration_secs: i64) -> &mut Self { + let current = Utc::now().timestamp(); + self.expires = Some(current + duration_secs); + self + } +} + +#[derive(Debug, Clone, Builder)] +/// Struct describing sticky cookie +pub struct StickyCookie { + #[builder(setter(custom))] + /// Upstream server_id + pub value: StickyCookieValue, + #[builder(setter(custom), default)] + /// Upstream server_id + pub info: Option, +} + +impl<'a> StickyCookieBuilder { + pub fn value(&mut self, n: impl Into>, v: impl Into>) -> &mut Self { + self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap()); + self + } + pub fn info( + &mut self, + domain: impl Into>, + path: impl Into>, + duration_secs: i64, + ) -> &mut Self { + let info = StickyCookieInfoBuilder::default() + .domain(domain) + .path(path) + .expires(duration_secs) + .build() + .unwrap(); + self.info = Some(Some(info)); + self + } +} + +impl TryInto for StickyCookie { + type Error = RpxyError; + + fn try_into(self) -> Result { + if self.info.is_none() { + return Err(RpxyError::LoadBalance( + "Failed to cookie conversion into string: no meta information".to_string(), + )); + } + let info = self.info.unwrap(); + let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else { + return Err(RpxyError::LoadBalance("Failed to cookie conversion into string".to_string())); + }; + let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string(); + let max_age = info.expires - Utc::now().timestamp(); + + Ok(format!( + "{}={}; expires={}; Max-Age={}; path={}; domain={}; HttpOnly", + self.value.name, self.value.value, exp_str, max_age, info.path, info.domain + )) + } +} + +#[derive(Debug, Clone)] +/// Configuration to serve incoming requests in the manner of "sticky cookie". +/// Including a dictionary to map Ids included in cookie and upstream destinations, +/// and expiration of cookie. +/// "domain" and "path" in the cookie will be the same as the reverse proxy options. +pub struct StickyCookieConfig { + pub name: String, + pub domain: String, + pub path: String, + pub duration: i64, +} +impl<'a> StickyCookieConfig { + pub fn build_sticky_cookie(&self, v: impl Into>) -> Result { + StickyCookieBuilder::default() + .value(self.name.clone(), v) + .info(&self.domain, &self.path, self.duration) + .build() + .map_err(|_| RpxyError::LoadBalance("Failed to build sticky cookie from config".to_string())) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::constants::STICKY_COOKIE_NAME; + + #[test] + fn config_works() { + let config = StickyCookieConfig { + name: STICKY_COOKIE_NAME.to_string(), + domain: "example.com".to_string(), + path: "/path".to_string(), + duration: 100, + }; + let expires_unix = Utc::now().timestamp() + 100; + let sc_string: Result = config.build_sticky_cookie("test_value").unwrap().try_into(); + let expires_date_string = Utc + .timestamp_opt(expires_unix, 0) + .unwrap() + .format("%a, %d-%b-%Y %T GMT") + .to_string(); + assert_eq!( + sc_string.unwrap(), + format!( + "{}=test_value; expires={}; Max-Age={}; path=/path; domain=example.com; HttpOnly", + STICKY_COOKIE_NAME, expires_date_string, 100 + ) + ); + } + #[test] + fn to_string_works() { + let sc = StickyCookie { + value: StickyCookieValue { + name: STICKY_COOKIE_NAME.to_string(), + value: "test_value".to_string(), + }, + info: Some(StickyCookieInfo { + expires: 1686221173i64, + domain: "example.com".to_string(), + path: "/path".to_string(), + }), + }; + let sc_string: Result = sc.try_into(); + let max_age = 1686221173i64 - Utc::now().timestamp(); + assert!(sc_string.is_ok()); + assert_eq!( + sc_string.unwrap(), + format!( + "{}=test_value; expires=Thu, 08-Jun-2023 10:46:13 GMT; Max-Age={}; path=/path; domain=example.com; HttpOnly", + STICKY_COOKIE_NAME, max_age + ) + ); + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 00a6c83..9164c45 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,7 +1,14 @@ mod load_balance; +mod load_balance_sticky_cookie; mod upstream; mod upstream_opts; +pub use self::{ + load_balance::LoadBalance, + load_balance_sticky_cookie::{LbContext, StickyCookie, StickyCookieBuilder, StickyCookieValue}, + upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, + upstream_opts::UpstreamOption, +}; use crate::{ log::*, utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, @@ -21,8 +28,6 @@ use tokio_rustls::rustls::{ sign::{any_supported_type, CertifiedKey}, Certificate, PrivateKey, ServerConfig, }; -pub use upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}; -pub use upstream_opts::UpstreamOption; use x509_parser::prelude::*; /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings. diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index ac17a40..9e53e5d 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -1,12 +1,16 @@ use super::{ - load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance}, + load_balance::{ + load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LbStickyRoundRobinBuilder, LoadBalance, + }, + load_balance_sticky_cookie::LbContext, BytesName, PathNameBytesExp, UpstreamOption, }; use crate::log::*; +use base64::{engine::general_purpose, Engine as _}; use derive_builder::Builder; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; +use sha2::{Digest, Sha256}; use std::borrow::Cow; - #[derive(Debug, Clone)] pub struct ReverseProxy { pub upstream: HashMap, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。 @@ -53,10 +57,20 @@ pub struct Upstream { /// Base uri without specific path pub uri: hyper::Uri, } - +impl Upstream { + /// Hashing uri with index to avoid collision + pub fn calculate_id_with_index(&self, index: usize) -> String { + let mut hasher = Sha256::new(); + let uri_string = format!("{}&index={}", self.uri.clone(), index); + hasher.update(uri_string.as_bytes()); + let digest = hasher.finalize(); + general_purpose::URL_SAFE_NO_PAD.encode(digest) + } +} #[derive(Debug, Clone, Builder)] /// Struct serving multiple upstream servers for, e.g., load balancing. pub struct UpstreamGroup { + #[builder(setter(custom))] /// Upstream server(s) pub upstream: Vec, #[builder(setter(custom), default)] @@ -75,6 +89,10 @@ pub struct UpstreamGroup { } impl UpstreamGroupBuilder { + pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self { + self.upstream = Some(upstream_vec.to_vec()); + self + } pub fn path(&mut self, v: &Option) -> &mut Self { let path = match v { Some(p) => p.to_path_name_vec(), @@ -91,7 +109,15 @@ impl UpstreamGroupBuilder { ); self } - pub fn lb(&mut self, v: &Option, upstream_num: &usize) -> &mut Self { + pub fn lb( + &mut self, + v: &Option, + // upstream_num: &usize, + upstream_vec: &Vec, + server_name: &str, + path_opt: &Option, + ) -> &mut Self { + let upstream_num = &upstream_vec.len(); let lb = if let Some(x) = v { match x.as_str() { lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, @@ -103,8 +129,10 @@ impl UpstreamGroupBuilder { .unwrap(), ), lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( - LbRoundRobinBuilder::default() + LbStickyRoundRobinBuilder::default() .num_upstreams(upstream_num) + .sticky_config(server_name, path_opt) + .upstream_maps(upstream_vec) // TODO: .build() .unwrap(), ), @@ -135,9 +163,35 @@ impl UpstreamGroupBuilder { impl UpstreamGroup { /// Get an enabled option of load balancing [[LoadBalance]] - pub fn get(&self) -> Option<&Upstream> { - let idx = self.lb.get_idx(); - debug!("Upstream of index {idx} is chosen."); - self.upstream.get(idx) + pub fn get(&self, context_to_lb: &Option) -> (Option<&Upstream>, Option) { + let pointer_to_upstream = self.lb.get_context(context_to_lb); + debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); + debug!("Context to LB (Cookie in Req): {:?}", context_to_lb); + debug!( + "Context from LB (Set-Cookie in Res): {:?}", + pointer_to_upstream.context_lb + ); + ( + self.upstream.get(pointer_to_upstream.ptr), + pointer_to_upstream.context_lb, + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + #[test] + fn calc_id_works() { + let uri = "https://www.rust-lang.org".parse::().unwrap(); + let upstream = Upstream { uri }; + assert_eq!( + "eGsjoPbactQ1eUJjafYjPT3ekYZQkaqJnHdA_FMSkgM", + upstream.calculate_id_with_index(0) + ); + assert_eq!( + "tNVXFJ9eNCT2mFgKbYq35XgH5q93QZtfU8piUiiDxVA", + upstream.calculate_id_with_index(1) + ); } } diff --git a/src/config/parse.rs b/src/config/parse.rs index 88892a8..8e4ddf7 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -99,7 +99,7 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro let mut backend_builder = BackendBuilder::default(); // reverse proxy settings ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); - let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; + let reverse_proxy = get_reverse_proxy(server_name_string, app.reverse_proxy.as_ref().unwrap())?; backend_builder .app_name(server_name_string) @@ -198,17 +198,21 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro Ok(()) } -fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result { +fn get_reverse_proxy( + server_name_string: &str, + rp_settings: &[ReverseProxyOption], +) -> std::result::Result { let mut upstream: HashMap = HashMap::default(); rp_settings.iter().for_each(|rpo| { - let vec_upstream: Vec = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(); - let lb_upstream_num = vec_upstream.len(); + let upstream_vec: Vec = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(); + // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()); + // let lb_upstream_num = vec_upstream.len(); let elem = UpstreamGroupBuilder::default() - .upstream(vec_upstream) + .upstream(&upstream_vec) .path(&rpo.path) .replace_path(&rpo.replace_path) - .lb(&rpo.load_balance, &lb_upstream_num) + .lb(&rpo.load_balance, &upstream_vec, server_name_string, &rpo.path) .opts(&rpo.upstream_options) .build() .unwrap(); diff --git a/src/constants.rs b/src/constants.rs index d2fc25f..6d4d8ad 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -24,3 +24,6 @@ pub mod H3 { pub const MAX_CONCURRENT_UNISTREAM: u32 = 64; pub const MAX_IDLE_TIMEOUT: u64 = 10; // secs } + +// For load-balancing with sticky cookie +pub const STICKY_COOKIE_NAME: &str = "rpxy_srv_id"; diff --git a/src/error.rs b/src/error.rs index 6da3b02..c5b34ad 100644 --- a/src/error.rs +++ b/src/error.rs @@ -22,6 +22,9 @@ pub enum RpxyError { #[error("TCP/UDP Proxy Layer Error: {0}")] Proxy(String), + #[error("LoadBalance Layer Error: {0}")] + LoadBalance(String), + #[error("I/O Error")] Io(#[from] io::Error), diff --git a/src/handler/handler_main.rs b/src/handler/handler_main.rs index 664f5e4..74f73a0 100644 --- a/src/handler/handler_main.rs +++ b/src/handler/handler_main.rs @@ -1,7 +1,7 @@ // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy -use super::{utils_headers::*, utils_request::*, utils_synth_response::*}; +use super::{utils_headers::*, utils_request::*, utils_synth_response::*, HandlerContext}; use crate::{ - backend::{Backend, UpstreamGroup}, + backend::{Backend, LoadBalance, UpstreamGroup}, error::*, globals::Globals, log::*, @@ -91,7 +91,7 @@ where let request_upgraded = req.extensions_mut().remove::(); // Build request from destination information - if let Err(e) = self.generate_request_forwarded( + let context = match self.generate_request_forwarded( &client_addr, &listen_addr, &mut req, @@ -99,8 +99,11 @@ where upstream_group, tls_enabled, ) { - error!("Failed to generate destination uri for reverse proxy: {}", e); - return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + Err(e) => { + error!("Failed to generate destination uri for reverse proxy: {}", e); + return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); + } + Ok(v) => v, }; debug!("Request to be forwarded: {:?}", req); log_data.xff(&req.headers().get("x-forwarded-for")); @@ -123,6 +126,15 @@ where } }; + // Process reverse proxy context generated during the forwarding request generation. + if let Some(context_from_lb) = context.context_lb { + let res_headers = res_backend.headers_mut(); + if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { + error!("Failed to append context to the response given from backend: {}", e); + return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); + } + } + if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { // Generate response to client if self.generate_response_forwarded(&mut res_backend, backend).is_ok() { @@ -229,7 +241,7 @@ where upgrade: &Option, upstream_group: &UpstreamGroup, tls_enabled: bool, - ) -> Result<()> { + ) -> Result { debug!("Generate request to be forwarded"); // Add te: trailer if contained in original request @@ -265,10 +277,19 @@ where .insert(header::HOST, HeaderValue::from_str(&org_host)?); }; + ///////////////////////////////////////////// // Fix unique upstream destination since there could be multiple ones. - // TODO: StickyならCookieをここでgetに与える必要 - // TODO: Stickyで、Cookieが与えられなかったらset-cookie向けにcookieを返す必要。upstreamオブジェクトに含めるのも手。 - let upstream_chosen = upstream_group.get().ok_or_else(|| anyhow!("Failed to get upstream"))?; + let context_to_lb = if let LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { + takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? + } else { + None + }; + let (upstream_chosen_opt, context_from_lb) = upstream_group.get(&context_to_lb); + let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; + let context = HandlerContext { + context_lb: context_from_lb, + }; + ///////////////////////////////////////////// // apply upstream-specific headers given in upstream_option let headers = req.headers_mut(); @@ -321,6 +342,6 @@ where *req.version_mut() = Version::HTTP_2; } - Ok(()) + Ok(context) } } diff --git a/src/handler/mod.rs b/src/handler/mod.rs index c2225ce..fc30129 100644 --- a/src/handler/mod.rs +++ b/src/handler/mod.rs @@ -4,3 +4,10 @@ mod utils_request; mod utils_synth_response; pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; + +use crate::backend::LbContext; + +#[derive(Debug)] +struct HandlerContext { + context_lb: Option, +} diff --git a/src/handler/utils_headers.rs b/src/handler/utils_headers.rs index 7fc4a5f..3819386 100644 --- a/src/handler/utils_headers.rs +++ b/src/handler/utils_headers.rs @@ -1,5 +1,5 @@ use crate::{ - backend::{UpstreamGroup, UpstreamOption}, + backend::{LbContext, StickyCookie, StickyCookieValue, UpstreamGroup, UpstreamOption}, error::*, log::*, utils::*, @@ -14,6 +14,74 @@ use std::net::SocketAddr; //////////////////////////////////////////////////// // Functions to manipulate headers +/// Take sticky cookie header value from request header, +/// and returns LbContext to be forwarded to LB if exist and if needed. +/// Removing sticky cookie is needed and it must not be passed to the upstream. +pub(super) fn takeout_sticky_cookie_lb_context( + headers: &mut HeaderMap, + expected_cookie_name: &str, +) -> Result> { + let mut headers_clone = headers.clone(); + + match headers_clone.entry(hyper::header::COOKIE) { + header::Entry::Vacant(_) => Ok(None), + header::Entry::Occupied(entry) => { + let cookies_iter = entry + .iter() + .flat_map(|v| v.to_str().unwrap_or("").split(';').map(|v| v.trim())); + let (sticky_cookies, without_sticky_cookies): (Vec<_>, Vec<_>) = cookies_iter + .into_iter() + .partition(|v| v.starts_with(expected_cookie_name)); + if sticky_cookies.is_empty() { + return Ok(None); + } + if sticky_cookies.len() > 1 { + error!("Multiple sticky cookie values in request"); + return Err(RpxyError::Other(anyhow!( + "Invalid cookie: Multiple sticky cookie values" + ))); + } + let cookies_passed_to_upstream = without_sticky_cookies.join("; "); + let cookie_passed_to_lb = sticky_cookies.first().unwrap(); + headers.remove(hyper::header::COOKIE); + headers.insert(hyper::header::COOKIE, cookies_passed_to_upstream.parse()?); + + let sticky_cookie = StickyCookie { + value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?, + info: None, + }; + Ok(Some(LbContext { sticky_cookie })) + } + } +} + +/// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated. +/// Set-Cookie response header could be in multiple lines. +/// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie +pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from_lb: &LbContext) -> Result<()> { + let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?; + let new_header_val: HeaderValue = sticky_cookie_string.parse()?; + let expected_cookie_name = &context_from_lb.sticky_cookie.value.name; + match headers.entry(hyper::header::SET_COOKIE) { + header::Entry::Vacant(entry) => { + entry.insert(new_header_val); + } + header::Entry::Occupied(mut entry) => { + let mut flag = false; + for e in entry.iter_mut() { + if e.to_str().unwrap_or("").starts_with(expected_cookie_name) { + *e = new_header_val.clone(); + flag = true; + } + } + if !flag { + entry.append(new_header_val); + } + } + }; + Ok(()) +} + pub(super) fn apply_upstream_options_to_header( headers: &mut HeaderMap, _client_addr: &SocketAddr,