feat: initial implementation of sticky cookie for session persistance when load-balancing
This commit is contained in:
		
					parent
					
						
							
								f4c59c9f2f
							
						
					
				
			
			
				commit
				
					
						a0ae3578d7
					
				
			
		
					 11 changed files with 580 additions and 49 deletions
				
			
		|  | @ -65,6 +65,15 @@ h3 = { path = "./h3/h3/", optional = true } | |||
| # h3-quinn = { path = "./h3/h3-quinn/", optional = true } | ||||
| h3-quinn = { path = "./h3-quinn/", optional = true } # Tentative to support rustls-0.21 | ||||
| 
 | ||||
| # cookie handling | ||||
| chrono = { version = "0.4.26", default-features = false, features = [ | ||||
|   "unstable-locales", | ||||
|   "alloc", | ||||
|   "clock", | ||||
| ] } | ||||
| base64 = "0.21.2" | ||||
| sha2 = { version = "0.10.6", default-features = false } | ||||
| 
 | ||||
| 
 | ||||
| [target.'cfg(not(target_env = "msvc"))'.dependencies] | ||||
| tikv-jemallocator = "0.5.0" | ||||
|  |  | |||
|  | @ -1,8 +1,14 @@ | |||
| use super::{load_balance_sticky_cookie::StickyCookieConfig, LbContext, Upstream}; | ||||
| use crate::{constants::STICKY_COOKIE_NAME, error::*, log::*}; | ||||
| use derive_builder::Builder; | ||||
| use rand::Rng; | ||||
| use std::sync::{ | ||||
|   atomic::{AtomicUsize, Ordering}, | ||||
|   Arc, | ||||
| use rustc_hash::FxHashMap as HashMap; | ||||
| use std::{ | ||||
|   borrow::Cow, | ||||
|   sync::{ | ||||
|     atomic::{AtomicUsize, Ordering}, | ||||
|     Arc, | ||||
|   }, | ||||
| }; | ||||
| 
 | ||||
| /// Constants to specify a load balance option
 | ||||
|  | @ -13,6 +19,18 @@ pub(super) mod load_balance_options { | |||
|   pub const STICKY_ROUND_ROBIN: &str = "sticky"; | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| /// Pointer to upstream serving the incoming request.
 | ||||
| /// If 'sticky cookie'-based LB is enabled and cookie must be updated/created, the new cookie is also given.
 | ||||
| pub(super) struct PointerToUpstream { | ||||
|   pub ptr: usize, | ||||
|   pub context_lb: Option<LbContext>, | ||||
| } | ||||
| /// Trait for LB
 | ||||
| trait LbWithPointer { | ||||
|   fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream; | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Round Robin LB object as a pointer to the current serving upstream destination
 | ||||
| pub struct LbRoundRobin { | ||||
|  | @ -29,20 +47,19 @@ impl LbRoundRobinBuilder { | |||
|     self | ||||
|   } | ||||
| } | ||||
| impl LbRoundRobin { | ||||
|   /// Get a current count of upstream served
 | ||||
|   fn current_ptr(&self) -> usize { | ||||
|     self.ptr.load(Ordering::Relaxed) | ||||
|   } | ||||
| 
 | ||||
| impl LbWithPointer for LbRoundRobin { | ||||
|   /// Increment the count of upstream served up to the max value
 | ||||
|   pub fn increment_ptr(&self) -> usize { | ||||
|     if self.current_ptr() < self.num_upstreams - 1 { | ||||
|   fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { | ||||
|     // Get a current count of upstream served
 | ||||
|     let current_ptr = self.ptr.load(Ordering::Relaxed); | ||||
| 
 | ||||
|     let ptr = if current_ptr < self.num_upstreams - 1 { | ||||
|       self.ptr.fetch_add(1, Ordering::Relaxed) | ||||
|     } else { | ||||
|       // Clear the counter
 | ||||
|       self.ptr.fetch_and(0, Ordering::Relaxed) | ||||
|     } | ||||
|     }; | ||||
|     PointerToUpstream { ptr, context_lb: None } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
|  | @ -59,11 +76,129 @@ impl LbRandomBuilder { | |||
|     self | ||||
|   } | ||||
| } | ||||
| impl LbRandom { | ||||
| impl LbWithPointer for LbRandom { | ||||
|   /// Returns the random index within the range
 | ||||
|   pub fn get_ptr(&self) -> usize { | ||||
|   fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { | ||||
|     let mut rng = rand::thread_rng(); | ||||
|     rng.gen_range(0..self.num_upstreams) | ||||
|     let ptr = rng.gen_range(0..self.num_upstreams); | ||||
|     PointerToUpstream { ptr, context_lb: None } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Round Robin LB object in the sticky cookie manner
 | ||||
| pub struct LbStickyRoundRobin { | ||||
|   #[builder(default)] | ||||
|   /// Pointer to the index of the last served upstream destination
 | ||||
|   ptr: Arc<AtomicUsize>, | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Number of upstream destinations
 | ||||
|   num_upstreams: usize, | ||||
|   #[builder(setter(custom))] | ||||
|   /// Information to build the cookie to stick clients to specific backends
 | ||||
|   pub sticky_config: StickyCookieConfig, | ||||
|   #[builder(setter(custom))] | ||||
|   /// Hashmaps:
 | ||||
|   /// - Hashmap that maps server indices to server id (string)
 | ||||
|   /// - Hashmap that maps server ids (string) to server indices, for fast reverse lookup
 | ||||
|   upstream_maps: UpstreamMap, | ||||
| } | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct UpstreamMap { | ||||
|   /// Hashmap that maps server indices to server id (string)
 | ||||
|   upstream_index_map: Vec<String>, | ||||
|   /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup
 | ||||
|   upstream_id_map: HashMap<String, usize>, | ||||
| } | ||||
| impl LbStickyRoundRobinBuilder { | ||||
|   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { | ||||
|     self.num_upstreams = Some(*v); | ||||
|     self | ||||
|   } | ||||
|   pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option<String>) -> &mut Self { | ||||
|     self.sticky_config = Some(StickyCookieConfig { | ||||
|       name: STICKY_COOKIE_NAME.to_string(), // TODO: config等で変更できるように
 | ||||
|       domain: server_name.to_ascii_lowercase(), | ||||
|       path: if let Some(v) = path_opt { | ||||
|         v.to_ascii_lowercase() | ||||
|       } else { | ||||
|         "/".to_string() | ||||
|       }, | ||||
|       duration: 300, // TODO: config等で変更できるように
 | ||||
|     }); | ||||
|     self | ||||
|   } | ||||
|   pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self { | ||||
|     let upstream_index_map: Vec<String> = upstream_vec | ||||
|       .iter() | ||||
|       .enumerate() | ||||
|       .map(|(i, v)| v.calculate_id_with_index(i)) | ||||
|       .collect(); | ||||
|     let mut upstream_id_map = HashMap::default(); | ||||
|     for (i, v) in upstream_index_map.iter().enumerate() { | ||||
|       upstream_id_map.insert(v.to_string(), i); | ||||
|     } | ||||
|     self.upstream_maps = Some(UpstreamMap { | ||||
|       upstream_index_map, | ||||
|       upstream_id_map, | ||||
|     }); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| impl<'a> LbStickyRoundRobin { | ||||
|   fn simple_increment_ptr(&self) -> usize { | ||||
|     // Get a current count of upstream served
 | ||||
|     let current_ptr = self.ptr.load(Ordering::Relaxed); | ||||
| 
 | ||||
|     if current_ptr < self.num_upstreams - 1 { | ||||
|       self.ptr.fetch_add(1, Ordering::Relaxed) | ||||
|     } else { | ||||
|       // Clear the counter
 | ||||
|       self.ptr.fetch_and(0, Ordering::Relaxed) | ||||
|     } | ||||
|   } | ||||
|   /// This is always called only internally. So 'unwrap()' is executed.
 | ||||
|   fn get_server_id_from_index(&self, index: usize) -> String { | ||||
|     self.upstream_maps.upstream_index_map.get(index).unwrap().to_owned() | ||||
|   } | ||||
|   /// This function takes value passed from outside. So 'result' is used.
 | ||||
|   fn get_server_index_from_id(&self, id: impl Into<Cow<'a, str>>) -> Option<usize> { | ||||
|     let id_str = id.into().to_string(); | ||||
|     self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned()) | ||||
|   } | ||||
| } | ||||
| impl LbWithPointer for LbStickyRoundRobin { | ||||
|   fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream { | ||||
|     // If given context is None or invalid (not contained), get_ptr() is invoked to increment the pointer.
 | ||||
|     // Otherwise, get the server index indicated by the server_id inside the cookie
 | ||||
|     let ptr = match req_info { | ||||
|       None => { | ||||
|         debug!("No sticky cookie"); | ||||
|         self.simple_increment_ptr() | ||||
|       } | ||||
|       Some(context) => { | ||||
|         let server_id = &context.sticky_cookie.value.value; | ||||
|         if let Some(server_index) = self.get_server_index_from_id(server_id) { | ||||
|           debug!("Valid sticky cookie: id={}, index={}", server_id, server_index); | ||||
|           server_index | ||||
|         } else { | ||||
|           debug!("Invalid sticky cookie: id={}", server_id); | ||||
|           self.simple_increment_ptr() | ||||
|         } | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     // Get the server id from the ptr.
 | ||||
|     // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie).
 | ||||
|     let upstream_id = self.get_server_id_from_index(ptr); | ||||
|     let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap(); | ||||
|     let new_context = Some(LbContext { | ||||
|       sticky_cookie: new_cookie, | ||||
|     }); | ||||
|     PointerToUpstream { | ||||
|       ptr, | ||||
|       context_lb: new_context, | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
|  | @ -77,7 +212,7 @@ pub enum LoadBalance { | |||
|   /// Simple round robin without session persistance
 | ||||
|   RoundRobin(LbRoundRobin), | ||||
|   /// Round robin with session persistance using cookie
 | ||||
|   StickyRoundRobin(LbRoundRobin), | ||||
|   StickyRoundRobin(LbStickyRoundRobin), | ||||
| } | ||||
| impl Default for LoadBalance { | ||||
|   fn default() -> Self { | ||||
|  | @ -87,12 +222,18 @@ impl Default for LoadBalance { | |||
| 
 | ||||
| impl LoadBalance { | ||||
|   /// Get the index of the upstream serving the incoming request
 | ||||
|   pub(super) fn get_idx(&self) -> usize { | ||||
|   pub(super) fn get_context(&self, context_to_lb: &Option<LbContext>) -> PointerToUpstream { | ||||
|     match self { | ||||
|       LoadBalance::FixToFirst => 0usize, | ||||
|       LoadBalance::RoundRobin(ptr) => ptr.increment_ptr(), | ||||
|       LoadBalance::Random(v) => v.get_ptr(), | ||||
|       LoadBalance::StickyRoundRobin(_ptr) => 0usize, // todo!(), // TODO: TODO: TODO: TODO: tentative value
 | ||||
|       LoadBalance::FixToFirst => PointerToUpstream { | ||||
|         ptr: 0usize, | ||||
|         context_lb: None, | ||||
|       }, | ||||
|       LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None), | ||||
|       LoadBalance::Random(ptr) => ptr.get_ptr(None), | ||||
|       LoadBalance::StickyRoundRobin(ptr) => { | ||||
|         // Generate new context if sticky round robin is enabled.
 | ||||
|         ptr.get_ptr(context_to_lb.as_ref()) | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  |  | |||
							
								
								
									
										216
									
								
								src/backend/load_balance_sticky_cookie.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										216
									
								
								src/backend/load_balance_sticky_cookie.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,216 @@ | |||
| use std::borrow::Cow; | ||||
| 
 | ||||
| use crate::error::*; | ||||
| use chrono::{TimeZone, Utc}; | ||||
| use derive_builder::Builder; | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| /// Struct to handle the sticky cookie string,
 | ||||
| /// - passed from Rp module (http handler) to LB module, manipulated from req, only StickyCookieValue exists.
 | ||||
| /// - passed from LB module to Rp module (http handler), will be inserted into res, StickyCookieValue and Info exist.
 | ||||
| pub struct LbContext { | ||||
|   pub sticky_cookie: StickyCookie, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Cookie value only, used for COOKIE in req
 | ||||
| pub struct StickyCookieValue { | ||||
|   #[builder(setter(custom))] | ||||
|   /// Field name indicating sticky cookie
 | ||||
|   pub name: String, | ||||
|   #[builder(setter(custom))] | ||||
|   /// Upstream server_id
 | ||||
|   pub value: String, | ||||
| } | ||||
| impl<'a> StickyCookieValueBuilder { | ||||
|   pub fn name(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.name = Some(v.into().to_ascii_lowercase()); | ||||
|     self | ||||
|   } | ||||
|   pub fn value(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.value = Some(v.into().to_string()); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| impl StickyCookieValue { | ||||
|   pub fn try_from(value: &str, expected_name: &str) -> Result<Self> { | ||||
|     if !value.starts_with(expected_name) { | ||||
|       return Err(RpxyError::LoadBalance( | ||||
|         "Failed to cookie conversion from string".to_string(), | ||||
|       )); | ||||
|     }; | ||||
|     let kv = value.split('=').map(|v| v.trim()).collect::<Vec<&str>>(); | ||||
|     if kv.len() != 2 { | ||||
|       return Err(RpxyError::LoadBalance("Invalid cookie structure".to_string())); | ||||
|     }; | ||||
|     if kv[1].is_empty() { | ||||
|       return Err(RpxyError::LoadBalance("No sticky cookie value".to_string())); | ||||
|     } | ||||
|     Ok(StickyCookieValue { | ||||
|       name: expected_name.to_string(), | ||||
|       value: kv[1].to_string(), | ||||
|     }) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Struct describing sticky cookie meta information used for SET-COOKIE in res
 | ||||
| pub struct StickyCookieInfo { | ||||
|   #[builder(setter(custom))] | ||||
|   /// Unix time
 | ||||
|   pub expires: i64, | ||||
| 
 | ||||
|   #[builder(setter(custom))] | ||||
|   /// Domain
 | ||||
|   pub domain: String, | ||||
| 
 | ||||
|   #[builder(setter(custom))] | ||||
|   /// Path
 | ||||
|   pub path: String, | ||||
| } | ||||
| impl<'a> StickyCookieInfoBuilder { | ||||
|   pub fn domain(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.domain = Some(v.into().to_ascii_lowercase()); | ||||
|     self | ||||
|   } | ||||
|   pub fn path(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.path = Some(v.into().to_ascii_lowercase()); | ||||
|     self | ||||
|   } | ||||
|   pub fn expires(&mut self, duration_secs: i64) -> &mut Self { | ||||
|     let current = Utc::now().timestamp(); | ||||
|     self.expires = Some(current + duration_secs); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Struct describing sticky cookie
 | ||||
| pub struct StickyCookie { | ||||
|   #[builder(setter(custom))] | ||||
|   /// Upstream server_id
 | ||||
|   pub value: StickyCookieValue, | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Upstream server_id
 | ||||
|   pub info: Option<StickyCookieInfo>, | ||||
| } | ||||
| 
 | ||||
| impl<'a> StickyCookieBuilder { | ||||
|   pub fn value(&mut self, n: impl Into<Cow<'a, str>>, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap()); | ||||
|     self | ||||
|   } | ||||
|   pub fn info( | ||||
|     &mut self, | ||||
|     domain: impl Into<Cow<'a, str>>, | ||||
|     path: impl Into<Cow<'a, str>>, | ||||
|     duration_secs: i64, | ||||
|   ) -> &mut Self { | ||||
|     let info = StickyCookieInfoBuilder::default() | ||||
|       .domain(domain) | ||||
|       .path(path) | ||||
|       .expires(duration_secs) | ||||
|       .build() | ||||
|       .unwrap(); | ||||
|     self.info = Some(Some(info)); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl TryInto<String> for StickyCookie { | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn try_into(self) -> Result<String> { | ||||
|     if self.info.is_none() { | ||||
|       return Err(RpxyError::LoadBalance( | ||||
|         "Failed to cookie conversion into string: no meta information".to_string(), | ||||
|       )); | ||||
|     } | ||||
|     let info = self.info.unwrap(); | ||||
|     let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else { | ||||
|       return Err(RpxyError::LoadBalance("Failed to cookie conversion into string".to_string())); | ||||
|     }; | ||||
|     let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string(); | ||||
|     let max_age = info.expires - Utc::now().timestamp(); | ||||
| 
 | ||||
|     Ok(format!( | ||||
|       "{}={}; expires={}; Max-Age={}; path={}; domain={}; HttpOnly", | ||||
|       self.value.name, self.value.value, exp_str, max_age, info.path, info.domain | ||||
|     )) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| /// Configuration to serve incoming requests in the manner of "sticky cookie".
 | ||||
| /// Including a dictionary to map Ids included in cookie and upstream destinations,
 | ||||
| /// and expiration of cookie.
 | ||||
| /// "domain" and "path" in the cookie will be the same as the reverse proxy options.
 | ||||
| pub struct StickyCookieConfig { | ||||
|   pub name: String, | ||||
|   pub domain: String, | ||||
|   pub path: String, | ||||
|   pub duration: i64, | ||||
| } | ||||
| impl<'a> StickyCookieConfig { | ||||
|   pub fn build_sticky_cookie(&self, v: impl Into<Cow<'a, str>>) -> Result<StickyCookie> { | ||||
|     StickyCookieBuilder::default() | ||||
|       .value(self.name.clone(), v) | ||||
|       .info(&self.domain, &self.path, self.duration) | ||||
|       .build() | ||||
|       .map_err(|_| RpxyError::LoadBalance("Failed to build sticky cookie from config".to_string())) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|   use super::*; | ||||
|   use crate::constants::STICKY_COOKIE_NAME; | ||||
| 
 | ||||
|   #[test] | ||||
|   fn config_works() { | ||||
|     let config = StickyCookieConfig { | ||||
|       name: STICKY_COOKIE_NAME.to_string(), | ||||
|       domain: "example.com".to_string(), | ||||
|       path: "/path".to_string(), | ||||
|       duration: 100, | ||||
|     }; | ||||
|     let expires_unix = Utc::now().timestamp() + 100; | ||||
|     let sc_string: Result<String> = config.build_sticky_cookie("test_value").unwrap().try_into(); | ||||
|     let expires_date_string = Utc | ||||
|       .timestamp_opt(expires_unix, 0) | ||||
|       .unwrap() | ||||
|       .format("%a, %d-%b-%Y %T GMT") | ||||
|       .to_string(); | ||||
|     assert_eq!( | ||||
|       sc_string.unwrap(), | ||||
|       format!( | ||||
|         "{}=test_value; expires={}; Max-Age={}; path=/path; domain=example.com; HttpOnly", | ||||
|         STICKY_COOKIE_NAME, expires_date_string, 100 | ||||
|       ) | ||||
|     ); | ||||
|   } | ||||
|   #[test] | ||||
|   fn to_string_works() { | ||||
|     let sc = StickyCookie { | ||||
|       value: StickyCookieValue { | ||||
|         name: STICKY_COOKIE_NAME.to_string(), | ||||
|         value: "test_value".to_string(), | ||||
|       }, | ||||
|       info: Some(StickyCookieInfo { | ||||
|         expires: 1686221173i64, | ||||
|         domain: "example.com".to_string(), | ||||
|         path: "/path".to_string(), | ||||
|       }), | ||||
|     }; | ||||
|     let sc_string: Result<String> = sc.try_into(); | ||||
|     let max_age = 1686221173i64 - Utc::now().timestamp(); | ||||
|     assert!(sc_string.is_ok()); | ||||
|     assert_eq!( | ||||
|       sc_string.unwrap(), | ||||
|       format!( | ||||
|         "{}=test_value; expires=Thu, 08-Jun-2023 10:46:13 GMT; Max-Age={}; path=/path; domain=example.com; HttpOnly", | ||||
|         STICKY_COOKIE_NAME, max_age | ||||
|       ) | ||||
|     ); | ||||
|   } | ||||
| } | ||||
|  | @ -1,7 +1,14 @@ | |||
| mod load_balance; | ||||
| mod load_balance_sticky_cookie; | ||||
| mod upstream; | ||||
| mod upstream_opts; | ||||
| 
 | ||||
| pub use self::{ | ||||
|   load_balance::LoadBalance, | ||||
|   load_balance_sticky_cookie::{LbContext, StickyCookie, StickyCookieBuilder, StickyCookieValue}, | ||||
|   upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, | ||||
|   upstream_opts::UpstreamOption, | ||||
| }; | ||||
| use crate::{ | ||||
|   log::*, | ||||
|   utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, | ||||
|  | @ -21,8 +28,6 @@ use tokio_rustls::rustls::{ | |||
|   sign::{any_supported_type, CertifiedKey}, | ||||
|   Certificate, PrivateKey, ServerConfig, | ||||
| }; | ||||
| pub use upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}; | ||||
| pub use upstream_opts::UpstreamOption; | ||||
| use x509_parser::prelude::*; | ||||
| 
 | ||||
| /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
 | ||||
|  |  | |||
|  | @ -1,12 +1,16 @@ | |||
| use super::{ | ||||
|   load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance}, | ||||
|   load_balance::{ | ||||
|     load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LbStickyRoundRobinBuilder, LoadBalance, | ||||
|   }, | ||||
|   load_balance_sticky_cookie::LbContext, | ||||
|   BytesName, PathNameBytesExp, UpstreamOption, | ||||
| }; | ||||
| use crate::log::*; | ||||
| use base64::{engine::general_purpose, Engine as _}; | ||||
| use derive_builder::Builder; | ||||
| use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; | ||||
| use sha2::{Digest, Sha256}; | ||||
| use std::borrow::Cow; | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct ReverseProxy { | ||||
|   pub upstream: HashMap<PathNameBytesExp, UpstreamGroup>, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。
 | ||||
|  | @ -53,10 +57,20 @@ pub struct Upstream { | |||
|   /// Base uri without specific path
 | ||||
|   pub uri: hyper::Uri, | ||||
| } | ||||
| 
 | ||||
| impl Upstream { | ||||
|   /// Hashing uri with index to avoid collision
 | ||||
|   pub fn calculate_id_with_index(&self, index: usize) -> String { | ||||
|     let mut hasher = Sha256::new(); | ||||
|     let uri_string = format!("{}&index={}", self.uri.clone(), index); | ||||
|     hasher.update(uri_string.as_bytes()); | ||||
|     let digest = hasher.finalize(); | ||||
|     general_purpose::URL_SAFE_NO_PAD.encode(digest) | ||||
|   } | ||||
| } | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Struct serving multiple upstream servers for, e.g., load balancing.
 | ||||
| pub struct UpstreamGroup { | ||||
|   #[builder(setter(custom))] | ||||
|   /// Upstream server(s)
 | ||||
|   pub upstream: Vec<Upstream>, | ||||
|   #[builder(setter(custom), default)] | ||||
|  | @ -75,6 +89,10 @@ pub struct UpstreamGroup { | |||
| } | ||||
| 
 | ||||
| impl UpstreamGroupBuilder { | ||||
|   pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self { | ||||
|     self.upstream = Some(upstream_vec.to_vec()); | ||||
|     self | ||||
|   } | ||||
|   pub fn path(&mut self, v: &Option<String>) -> &mut Self { | ||||
|     let path = match v { | ||||
|       Some(p) => p.to_path_name_vec(), | ||||
|  | @ -91,7 +109,15 @@ impl UpstreamGroupBuilder { | |||
|     ); | ||||
|     self | ||||
|   } | ||||
|   pub fn lb(&mut self, v: &Option<String>, upstream_num: &usize) -> &mut Self { | ||||
|   pub fn lb( | ||||
|     &mut self, | ||||
|     v: &Option<String>, | ||||
|     // upstream_num: &usize,
 | ||||
|     upstream_vec: &Vec<Upstream>, | ||||
|     server_name: &str, | ||||
|     path_opt: &Option<String>, | ||||
|   ) -> &mut Self { | ||||
|     let upstream_num = &upstream_vec.len(); | ||||
|     let lb = if let Some(x) = v { | ||||
|       match x.as_str() { | ||||
|         lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, | ||||
|  | @ -103,8 +129,10 @@ impl UpstreamGroupBuilder { | |||
|             .unwrap(), | ||||
|         ), | ||||
|         lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( | ||||
|           LbRoundRobinBuilder::default() | ||||
|           LbStickyRoundRobinBuilder::default() | ||||
|             .num_upstreams(upstream_num) | ||||
|             .sticky_config(server_name, path_opt) | ||||
|             .upstream_maps(upstream_vec) // TODO:
 | ||||
|             .build() | ||||
|             .unwrap(), | ||||
|         ), | ||||
|  | @ -135,9 +163,35 @@ impl UpstreamGroupBuilder { | |||
| 
 | ||||
| impl UpstreamGroup { | ||||
|   /// Get an enabled option of load balancing [[LoadBalance]]
 | ||||
|   pub fn get(&self) -> Option<&Upstream> { | ||||
|     let idx = self.lb.get_idx(); | ||||
|     debug!("Upstream of index {idx} is chosen."); | ||||
|     self.upstream.get(idx) | ||||
|   pub fn get(&self, context_to_lb: &Option<LbContext>) -> (Option<&Upstream>, Option<LbContext>) { | ||||
|     let pointer_to_upstream = self.lb.get_context(context_to_lb); | ||||
|     debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); | ||||
|     debug!("Context to LB (Cookie in Req): {:?}", context_to_lb); | ||||
|     debug!( | ||||
|       "Context from LB (Set-Cookie in Res): {:?}", | ||||
|       pointer_to_upstream.context_lb | ||||
|     ); | ||||
|     ( | ||||
|       self.upstream.get(pointer_to_upstream.ptr), | ||||
|       pointer_to_upstream.context_lb, | ||||
|     ) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod test { | ||||
|   use super::*; | ||||
|   #[test] | ||||
|   fn calc_id_works() { | ||||
|     let uri = "https://www.rust-lang.org".parse::<hyper::Uri>().unwrap(); | ||||
|     let upstream = Upstream { uri }; | ||||
|     assert_eq!( | ||||
|       "eGsjoPbactQ1eUJjafYjPT3ekYZQkaqJnHdA_FMSkgM", | ||||
|       upstream.calculate_id_with_index(0) | ||||
|     ); | ||||
|     assert_eq!( | ||||
|       "tNVXFJ9eNCT2mFgKbYq35XgH5q93QZtfU8piUiiDxVA", | ||||
|       upstream.calculate_id_with_index(1) | ||||
|     ); | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -99,7 +99,7 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro | |||
|     let mut backend_builder = BackendBuilder::default(); | ||||
|     // reverse proxy settings
 | ||||
|     ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); | ||||
|     let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; | ||||
|     let reverse_proxy = get_reverse_proxy(server_name_string, app.reverse_proxy.as_ref().unwrap())?; | ||||
| 
 | ||||
|     backend_builder | ||||
|       .app_name(server_name_string) | ||||
|  | @ -198,17 +198,21 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro | |||
|   Ok(()) | ||||
| } | ||||
| 
 | ||||
| fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result<ReverseProxy, anyhow::Error> { | ||||
| fn get_reverse_proxy( | ||||
|   server_name_string: &str, | ||||
|   rp_settings: &[ReverseProxyOption], | ||||
| ) -> std::result::Result<ReverseProxy, anyhow::Error> { | ||||
|   let mut upstream: HashMap<PathNameBytesExp, UpstreamGroup> = HashMap::default(); | ||||
| 
 | ||||
|   rp_settings.iter().for_each(|rpo| { | ||||
|     let vec_upstream: Vec<Upstream> = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(); | ||||
|     let lb_upstream_num = vec_upstream.len(); | ||||
|     let upstream_vec: Vec<Upstream> = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(); | ||||
|     // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap());
 | ||||
|     // let lb_upstream_num = vec_upstream.len();
 | ||||
|     let elem = UpstreamGroupBuilder::default() | ||||
|       .upstream(vec_upstream) | ||||
|       .upstream(&upstream_vec) | ||||
|       .path(&rpo.path) | ||||
|       .replace_path(&rpo.replace_path) | ||||
|       .lb(&rpo.load_balance, &lb_upstream_num) | ||||
|       .lb(&rpo.load_balance, &upstream_vec, server_name_string, &rpo.path) | ||||
|       .opts(&rpo.upstream_options) | ||||
|       .build() | ||||
|       .unwrap(); | ||||
|  |  | |||
|  | @ -24,3 +24,6 @@ pub mod H3 { | |||
|   pub const MAX_CONCURRENT_UNISTREAM: u32 = 64; | ||||
|   pub const MAX_IDLE_TIMEOUT: u64 = 10; // secs
 | ||||
| } | ||||
| 
 | ||||
| // For load-balancing with sticky cookie
 | ||||
| pub const STICKY_COOKIE_NAME: &str = "rpxy_srv_id"; | ||||
|  |  | |||
|  | @ -22,6 +22,9 @@ pub enum RpxyError { | |||
|   #[error("TCP/UDP Proxy Layer Error: {0}")] | ||||
|   Proxy(String), | ||||
| 
 | ||||
|   #[error("LoadBalance Layer Error: {0}")] | ||||
|   LoadBalance(String), | ||||
| 
 | ||||
|   #[error("I/O Error")] | ||||
|   Io(#[from] io::Error), | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,7 +1,7 @@ | |||
| // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
 | ||||
| use super::{utils_headers::*, utils_request::*, utils_synth_response::*}; | ||||
| use super::{utils_headers::*, utils_request::*, utils_synth_response::*, HandlerContext}; | ||||
| use crate::{ | ||||
|   backend::{Backend, UpstreamGroup}, | ||||
|   backend::{Backend, LoadBalance, UpstreamGroup}, | ||||
|   error::*, | ||||
|   globals::Globals, | ||||
|   log::*, | ||||
|  | @ -91,7 +91,7 @@ where | |||
|     let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>(); | ||||
| 
 | ||||
|     // Build request from destination information
 | ||||
|     if let Err(e) = self.generate_request_forwarded( | ||||
|     let context = match self.generate_request_forwarded( | ||||
|       &client_addr, | ||||
|       &listen_addr, | ||||
|       &mut req, | ||||
|  | @ -99,8 +99,11 @@ where | |||
|       upstream_group, | ||||
|       tls_enabled, | ||||
|     ) { | ||||
|       error!("Failed to generate destination uri for reverse proxy: {}", e); | ||||
|       return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); | ||||
|       Err(e) => { | ||||
|         error!("Failed to generate destination uri for reverse proxy: {}", e); | ||||
|         return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); | ||||
|       } | ||||
|       Ok(v) => v, | ||||
|     }; | ||||
|     debug!("Request to be forwarded: {:?}", req); | ||||
|     log_data.xff(&req.headers().get("x-forwarded-for")); | ||||
|  | @ -123,6 +126,15 @@ where | |||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     // Process reverse proxy context generated during the forwarding request generation.
 | ||||
|     if let Some(context_from_lb) = context.context_lb { | ||||
|       let res_headers = res_backend.headers_mut(); | ||||
|       if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { | ||||
|         error!("Failed to append context to the response given from backend: {}", e); | ||||
|         return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { | ||||
|       // Generate response to client
 | ||||
|       if self.generate_response_forwarded(&mut res_backend, backend).is_ok() { | ||||
|  | @ -229,7 +241,7 @@ where | |||
|     upgrade: &Option<String>, | ||||
|     upstream_group: &UpstreamGroup, | ||||
|     tls_enabled: bool, | ||||
|   ) -> Result<()> { | ||||
|   ) -> Result<HandlerContext> { | ||||
|     debug!("Generate request to be forwarded"); | ||||
| 
 | ||||
|     // Add te: trailer if contained in original request
 | ||||
|  | @ -265,10 +277,19 @@ where | |||
|         .insert(header::HOST, HeaderValue::from_str(&org_host)?); | ||||
|     }; | ||||
| 
 | ||||
|     /////////////////////////////////////////////
 | ||||
|     // Fix unique upstream destination since there could be multiple ones.
 | ||||
|     // TODO: StickyならCookieをここでgetに与える必要
 | ||||
|     // TODO: Stickyで、Cookieが与えられなかったらset-cookie向けにcookieを返す必要。upstreamオブジェクトに含めるのも手。
 | ||||
|     let upstream_chosen = upstream_group.get().ok_or_else(|| anyhow!("Failed to get upstream"))?; | ||||
|     let context_to_lb = if let LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { | ||||
|       takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? | ||||
|     } else { | ||||
|       None | ||||
|     }; | ||||
|     let (upstream_chosen_opt, context_from_lb) = upstream_group.get(&context_to_lb); | ||||
|     let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; | ||||
|     let context = HandlerContext { | ||||
|       context_lb: context_from_lb, | ||||
|     }; | ||||
|     /////////////////////////////////////////////
 | ||||
| 
 | ||||
|     // apply upstream-specific headers given in upstream_option
 | ||||
|     let headers = req.headers_mut(); | ||||
|  | @ -321,6 +342,6 @@ where | |||
|       *req.version_mut() = Version::HTTP_2; | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
|     Ok(context) | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -4,3 +4,10 @@ mod utils_request; | |||
| mod utils_synth_response; | ||||
| 
 | ||||
| pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; | ||||
| 
 | ||||
| use crate::backend::LbContext; | ||||
| 
 | ||||
| #[derive(Debug)] | ||||
| struct HandlerContext { | ||||
|   context_lb: Option<LbContext>, | ||||
| } | ||||
|  |  | |||
|  | @ -1,5 +1,5 @@ | |||
| use crate::{ | ||||
|   backend::{UpstreamGroup, UpstreamOption}, | ||||
|   backend::{LbContext, StickyCookie, StickyCookieValue, UpstreamGroup, UpstreamOption}, | ||||
|   error::*, | ||||
|   log::*, | ||||
|   utils::*, | ||||
|  | @ -14,6 +14,74 @@ use std::net::SocketAddr; | |||
| ////////////////////////////////////////////////////
 | ||||
| // Functions to manipulate headers
 | ||||
| 
 | ||||
| /// Take sticky cookie header value from request header,
 | ||||
| /// and returns LbContext to be forwarded to LB if exist and if needed.
 | ||||
| /// Removing sticky cookie is needed and it must not be passed to the upstream.
 | ||||
| pub(super) fn takeout_sticky_cookie_lb_context( | ||||
|   headers: &mut HeaderMap, | ||||
|   expected_cookie_name: &str, | ||||
| ) -> Result<Option<LbContext>> { | ||||
|   let mut headers_clone = headers.clone(); | ||||
| 
 | ||||
|   match headers_clone.entry(hyper::header::COOKIE) { | ||||
|     header::Entry::Vacant(_) => Ok(None), | ||||
|     header::Entry::Occupied(entry) => { | ||||
|       let cookies_iter = entry | ||||
|         .iter() | ||||
|         .flat_map(|v| v.to_str().unwrap_or("").split(';').map(|v| v.trim())); | ||||
|       let (sticky_cookies, without_sticky_cookies): (Vec<_>, Vec<_>) = cookies_iter | ||||
|         .into_iter() | ||||
|         .partition(|v| v.starts_with(expected_cookie_name)); | ||||
|       if sticky_cookies.is_empty() { | ||||
|         return Ok(None); | ||||
|       } | ||||
|       if sticky_cookies.len() > 1 { | ||||
|         error!("Multiple sticky cookie values in request"); | ||||
|         return Err(RpxyError::Other(anyhow!( | ||||
|           "Invalid cookie: Multiple sticky cookie values" | ||||
|         ))); | ||||
|       } | ||||
|       let cookies_passed_to_upstream = without_sticky_cookies.join("; "); | ||||
|       let cookie_passed_to_lb = sticky_cookies.first().unwrap(); | ||||
|       headers.remove(hyper::header::COOKIE); | ||||
|       headers.insert(hyper::header::COOKIE, cookies_passed_to_upstream.parse()?); | ||||
| 
 | ||||
|       let sticky_cookie = StickyCookie { | ||||
|         value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?, | ||||
|         info: None, | ||||
|       }; | ||||
|       Ok(Some(LbContext { sticky_cookie })) | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated.
 | ||||
| /// Set-Cookie response header could be in multiple lines.
 | ||||
| /// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie
 | ||||
| pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from_lb: &LbContext) -> Result<()> { | ||||
|   let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?; | ||||
|   let new_header_val: HeaderValue = sticky_cookie_string.parse()?; | ||||
|   let expected_cookie_name = &context_from_lb.sticky_cookie.value.name; | ||||
|   match headers.entry(hyper::header::SET_COOKIE) { | ||||
|     header::Entry::Vacant(entry) => { | ||||
|       entry.insert(new_header_val); | ||||
|     } | ||||
|     header::Entry::Occupied(mut entry) => { | ||||
|       let mut flag = false; | ||||
|       for e in entry.iter_mut() { | ||||
|         if e.to_str().unwrap_or("").starts_with(expected_cookie_name) { | ||||
|           *e = new_header_val.clone(); | ||||
|           flag = true; | ||||
|         } | ||||
|       } | ||||
|       if !flag { | ||||
|         entry.append(new_header_val); | ||||
|       } | ||||
|     } | ||||
|   }; | ||||
|   Ok(()) | ||||
| } | ||||
| 
 | ||||
| pub(super) fn apply_upstream_options_to_header( | ||||
|   headers: &mut HeaderMap, | ||||
|   _client_addr: &SocketAddr, | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jun Kurihara
				Jun Kurihara