feat: initial implementation of sticky cookie for session persistance when load-balancing
This commit is contained in:
		
					parent
					
						
							
								f4c59c9f2f
							
						
					
				
			
			
				commit
				
					
						a0ae3578d7
					
				
			
		
					 11 changed files with 580 additions and 49 deletions
				
			
		|  | @ -65,6 +65,15 @@ h3 = { path = "./h3/h3/", optional = true } | ||||||
| # h3-quinn = { path = "./h3/h3-quinn/", optional = true } | # h3-quinn = { path = "./h3/h3-quinn/", optional = true } | ||||||
| h3-quinn = { path = "./h3-quinn/", optional = true } # Tentative to support rustls-0.21 | h3-quinn = { path = "./h3-quinn/", optional = true } # Tentative to support rustls-0.21 | ||||||
| 
 | 
 | ||||||
|  | # cookie handling | ||||||
|  | chrono = { version = "0.4.26", default-features = false, features = [ | ||||||
|  |   "unstable-locales", | ||||||
|  |   "alloc", | ||||||
|  |   "clock", | ||||||
|  | ] } | ||||||
|  | base64 = "0.21.2" | ||||||
|  | sha2 = { version = "0.10.6", default-features = false } | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| [target.'cfg(not(target_env = "msvc"))'.dependencies] | [target.'cfg(not(target_env = "msvc"))'.dependencies] | ||||||
| tikv-jemallocator = "0.5.0" | tikv-jemallocator = "0.5.0" | ||||||
|  |  | ||||||
|  | @ -1,8 +1,14 @@ | ||||||
|  | use super::{load_balance_sticky_cookie::StickyCookieConfig, LbContext, Upstream}; | ||||||
|  | use crate::{constants::STICKY_COOKIE_NAME, error::*, log::*}; | ||||||
| use derive_builder::Builder; | use derive_builder::Builder; | ||||||
| use rand::Rng; | use rand::Rng; | ||||||
| use std::sync::{ | use rustc_hash::FxHashMap as HashMap; | ||||||
|  | use std::{ | ||||||
|  |   borrow::Cow, | ||||||
|  |   sync::{ | ||||||
|     atomic::{AtomicUsize, Ordering}, |     atomic::{AtomicUsize, Ordering}, | ||||||
|     Arc, |     Arc, | ||||||
|  |   }, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| /// Constants to specify a load balance option
 | /// Constants to specify a load balance option
 | ||||||
|  | @ -13,6 +19,18 @@ pub(super) mod load_balance_options { | ||||||
|   pub const STICKY_ROUND_ROBIN: &str = "sticky"; |   pub const STICKY_ROUND_ROBIN: &str = "sticky"; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | /// Pointer to upstream serving the incoming request.
 | ||||||
|  | /// If 'sticky cookie'-based LB is enabled and cookie must be updated/created, the new cookie is also given.
 | ||||||
|  | pub(super) struct PointerToUpstream { | ||||||
|  |   pub ptr: usize, | ||||||
|  |   pub context_lb: Option<LbContext>, | ||||||
|  | } | ||||||
|  | /// Trait for LB
 | ||||||
|  | trait LbWithPointer { | ||||||
|  |   fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream; | ||||||
|  | } | ||||||
|  | 
 | ||||||
| #[derive(Debug, Clone, Builder)] | #[derive(Debug, Clone, Builder)] | ||||||
| /// Round Robin LB object as a pointer to the current serving upstream destination
 | /// Round Robin LB object as a pointer to the current serving upstream destination
 | ||||||
| pub struct LbRoundRobin { | pub struct LbRoundRobin { | ||||||
|  | @ -29,20 +47,19 @@ impl LbRoundRobinBuilder { | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
| } | } | ||||||
| impl LbRoundRobin { | impl LbWithPointer for LbRoundRobin { | ||||||
|   /// Get a current count of upstream served
 |  | ||||||
|   fn current_ptr(&self) -> usize { |  | ||||||
|     self.ptr.load(Ordering::Relaxed) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Increment the count of upstream served up to the max value
 |   /// Increment the count of upstream served up to the max value
 | ||||||
|   pub fn increment_ptr(&self) -> usize { |   fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { | ||||||
|     if self.current_ptr() < self.num_upstreams - 1 { |     // Get a current count of upstream served
 | ||||||
|  |     let current_ptr = self.ptr.load(Ordering::Relaxed); | ||||||
|  | 
 | ||||||
|  |     let ptr = if current_ptr < self.num_upstreams - 1 { | ||||||
|       self.ptr.fetch_add(1, Ordering::Relaxed) |       self.ptr.fetch_add(1, Ordering::Relaxed) | ||||||
|     } else { |     } else { | ||||||
|       // Clear the counter
 |       // Clear the counter
 | ||||||
|       self.ptr.fetch_and(0, Ordering::Relaxed) |       self.ptr.fetch_and(0, Ordering::Relaxed) | ||||||
|     } |     }; | ||||||
|  |     PointerToUpstream { ptr, context_lb: None } | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -59,11 +76,129 @@ impl LbRandomBuilder { | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
| } | } | ||||||
| impl LbRandom { | impl LbWithPointer for LbRandom { | ||||||
|   /// Returns the random index within the range
 |   /// Returns the random index within the range
 | ||||||
|   pub fn get_ptr(&self) -> usize { |   fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { | ||||||
|     let mut rng = rand::thread_rng(); |     let mut rng = rand::thread_rng(); | ||||||
|     rng.gen_range(0..self.num_upstreams) |     let ptr = rng.gen_range(0..self.num_upstreams); | ||||||
|  |     PointerToUpstream { ptr, context_lb: None } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, Clone, Builder)] | ||||||
|  | /// Round Robin LB object in the sticky cookie manner
 | ||||||
|  | pub struct LbStickyRoundRobin { | ||||||
|  |   #[builder(default)] | ||||||
|  |   /// Pointer to the index of the last served upstream destination
 | ||||||
|  |   ptr: Arc<AtomicUsize>, | ||||||
|  |   #[builder(setter(custom), default)] | ||||||
|  |   /// Number of upstream destinations
 | ||||||
|  |   num_upstreams: usize, | ||||||
|  |   #[builder(setter(custom))] | ||||||
|  |   /// Information to build the cookie to stick clients to specific backends
 | ||||||
|  |   pub sticky_config: StickyCookieConfig, | ||||||
|  |   #[builder(setter(custom))] | ||||||
|  |   /// Hashmaps:
 | ||||||
|  |   /// - Hashmap that maps server indices to server id (string)
 | ||||||
|  |   /// - Hashmap that maps server ids (string) to server indices, for fast reverse lookup
 | ||||||
|  |   upstream_maps: UpstreamMap, | ||||||
|  | } | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | pub struct UpstreamMap { | ||||||
|  |   /// Hashmap that maps server indices to server id (string)
 | ||||||
|  |   upstream_index_map: Vec<String>, | ||||||
|  |   /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup
 | ||||||
|  |   upstream_id_map: HashMap<String, usize>, | ||||||
|  | } | ||||||
|  | impl LbStickyRoundRobinBuilder { | ||||||
|  |   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { | ||||||
|  |     self.num_upstreams = Some(*v); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  |   pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option<String>) -> &mut Self { | ||||||
|  |     self.sticky_config = Some(StickyCookieConfig { | ||||||
|  |       name: STICKY_COOKIE_NAME.to_string(), // TODO: config等で変更できるように
 | ||||||
|  |       domain: server_name.to_ascii_lowercase(), | ||||||
|  |       path: if let Some(v) = path_opt { | ||||||
|  |         v.to_ascii_lowercase() | ||||||
|  |       } else { | ||||||
|  |         "/".to_string() | ||||||
|  |       }, | ||||||
|  |       duration: 300, // TODO: config等で変更できるように
 | ||||||
|  |     }); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  |   pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self { | ||||||
|  |     let upstream_index_map: Vec<String> = upstream_vec | ||||||
|  |       .iter() | ||||||
|  |       .enumerate() | ||||||
|  |       .map(|(i, v)| v.calculate_id_with_index(i)) | ||||||
|  |       .collect(); | ||||||
|  |     let mut upstream_id_map = HashMap::default(); | ||||||
|  |     for (i, v) in upstream_index_map.iter().enumerate() { | ||||||
|  |       upstream_id_map.insert(v.to_string(), i); | ||||||
|  |     } | ||||||
|  |     self.upstream_maps = Some(UpstreamMap { | ||||||
|  |       upstream_index_map, | ||||||
|  |       upstream_id_map, | ||||||
|  |     }); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | impl<'a> LbStickyRoundRobin { | ||||||
|  |   fn simple_increment_ptr(&self) -> usize { | ||||||
|  |     // Get a current count of upstream served
 | ||||||
|  |     let current_ptr = self.ptr.load(Ordering::Relaxed); | ||||||
|  | 
 | ||||||
|  |     if current_ptr < self.num_upstreams - 1 { | ||||||
|  |       self.ptr.fetch_add(1, Ordering::Relaxed) | ||||||
|  |     } else { | ||||||
|  |       // Clear the counter
 | ||||||
|  |       self.ptr.fetch_and(0, Ordering::Relaxed) | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   /// This is always called only internally. So 'unwrap()' is executed.
 | ||||||
|  |   fn get_server_id_from_index(&self, index: usize) -> String { | ||||||
|  |     self.upstream_maps.upstream_index_map.get(index).unwrap().to_owned() | ||||||
|  |   } | ||||||
|  |   /// This function takes value passed from outside. So 'result' is used.
 | ||||||
|  |   fn get_server_index_from_id(&self, id: impl Into<Cow<'a, str>>) -> Option<usize> { | ||||||
|  |     let id_str = id.into().to_string(); | ||||||
|  |     self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned()) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | impl LbWithPointer for LbStickyRoundRobin { | ||||||
|  |   fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream { | ||||||
|  |     // If given context is None or invalid (not contained), get_ptr() is invoked to increment the pointer.
 | ||||||
|  |     // Otherwise, get the server index indicated by the server_id inside the cookie
 | ||||||
|  |     let ptr = match req_info { | ||||||
|  |       None => { | ||||||
|  |         debug!("No sticky cookie"); | ||||||
|  |         self.simple_increment_ptr() | ||||||
|  |       } | ||||||
|  |       Some(context) => { | ||||||
|  |         let server_id = &context.sticky_cookie.value.value; | ||||||
|  |         if let Some(server_index) = self.get_server_index_from_id(server_id) { | ||||||
|  |           debug!("Valid sticky cookie: id={}, index={}", server_id, server_index); | ||||||
|  |           server_index | ||||||
|  |         } else { | ||||||
|  |           debug!("Invalid sticky cookie: id={}", server_id); | ||||||
|  |           self.simple_increment_ptr() | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     // Get the server id from the ptr.
 | ||||||
|  |     // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie).
 | ||||||
|  |     let upstream_id = self.get_server_id_from_index(ptr); | ||||||
|  |     let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap(); | ||||||
|  |     let new_context = Some(LbContext { | ||||||
|  |       sticky_cookie: new_cookie, | ||||||
|  |     }); | ||||||
|  |     PointerToUpstream { | ||||||
|  |       ptr, | ||||||
|  |       context_lb: new_context, | ||||||
|  |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -77,7 +212,7 @@ pub enum LoadBalance { | ||||||
|   /// Simple round robin without session persistance
 |   /// Simple round robin without session persistance
 | ||||||
|   RoundRobin(LbRoundRobin), |   RoundRobin(LbRoundRobin), | ||||||
|   /// Round robin with session persistance using cookie
 |   /// Round robin with session persistance using cookie
 | ||||||
|   StickyRoundRobin(LbRoundRobin), |   StickyRoundRobin(LbStickyRoundRobin), | ||||||
| } | } | ||||||
| impl Default for LoadBalance { | impl Default for LoadBalance { | ||||||
|   fn default() -> Self { |   fn default() -> Self { | ||||||
|  | @ -87,12 +222,18 @@ impl Default for LoadBalance { | ||||||
| 
 | 
 | ||||||
| impl LoadBalance { | impl LoadBalance { | ||||||
|   /// Get the index of the upstream serving the incoming request
 |   /// Get the index of the upstream serving the incoming request
 | ||||||
|   pub(super) fn get_idx(&self) -> usize { |   pub(super) fn get_context(&self, context_to_lb: &Option<LbContext>) -> PointerToUpstream { | ||||||
|     match self { |     match self { | ||||||
|       LoadBalance::FixToFirst => 0usize, |       LoadBalance::FixToFirst => PointerToUpstream { | ||||||
|       LoadBalance::RoundRobin(ptr) => ptr.increment_ptr(), |         ptr: 0usize, | ||||||
|       LoadBalance::Random(v) => v.get_ptr(), |         context_lb: None, | ||||||
|       LoadBalance::StickyRoundRobin(_ptr) => 0usize, // todo!(), // TODO: TODO: TODO: TODO: tentative value
 |       }, | ||||||
|  |       LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None), | ||||||
|  |       LoadBalance::Random(ptr) => ptr.get_ptr(None), | ||||||
|  |       LoadBalance::StickyRoundRobin(ptr) => { | ||||||
|  |         // Generate new context if sticky round robin is enabled.
 | ||||||
|  |         ptr.get_ptr(context_to_lb.as_ref()) | ||||||
|  |       } | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
							
								
								
									
										216
									
								
								src/backend/load_balance_sticky_cookie.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										216
									
								
								src/backend/load_balance_sticky_cookie.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,216 @@ | ||||||
|  | use std::borrow::Cow; | ||||||
|  | 
 | ||||||
|  | use crate::error::*; | ||||||
|  | use chrono::{TimeZone, Utc}; | ||||||
|  | use derive_builder::Builder; | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | /// Struct to handle the sticky cookie string,
 | ||||||
|  | /// - passed from Rp module (http handler) to LB module, manipulated from req, only StickyCookieValue exists.
 | ||||||
|  | /// - passed from LB module to Rp module (http handler), will be inserted into res, StickyCookieValue and Info exist.
 | ||||||
|  | pub struct LbContext { | ||||||
|  |   pub sticky_cookie: StickyCookie, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, Clone, Builder)] | ||||||
|  | /// Cookie value only, used for COOKIE in req
 | ||||||
|  | pub struct StickyCookieValue { | ||||||
|  |   #[builder(setter(custom))] | ||||||
|  |   /// Field name indicating sticky cookie
 | ||||||
|  |   pub name: String, | ||||||
|  |   #[builder(setter(custom))] | ||||||
|  |   /// Upstream server_id
 | ||||||
|  |   pub value: String, | ||||||
|  | } | ||||||
|  | impl<'a> StickyCookieValueBuilder { | ||||||
|  |   pub fn name(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||||
|  |     self.name = Some(v.into().to_ascii_lowercase()); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  |   pub fn value(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||||
|  |     self.value = Some(v.into().to_string()); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | impl StickyCookieValue { | ||||||
|  |   pub fn try_from(value: &str, expected_name: &str) -> Result<Self> { | ||||||
|  |     if !value.starts_with(expected_name) { | ||||||
|  |       return Err(RpxyError::LoadBalance( | ||||||
|  |         "Failed to cookie conversion from string".to_string(), | ||||||
|  |       )); | ||||||
|  |     }; | ||||||
|  |     let kv = value.split('=').map(|v| v.trim()).collect::<Vec<&str>>(); | ||||||
|  |     if kv.len() != 2 { | ||||||
|  |       return Err(RpxyError::LoadBalance("Invalid cookie structure".to_string())); | ||||||
|  |     }; | ||||||
|  |     if kv[1].is_empty() { | ||||||
|  |       return Err(RpxyError::LoadBalance("No sticky cookie value".to_string())); | ||||||
|  |     } | ||||||
|  |     Ok(StickyCookieValue { | ||||||
|  |       name: expected_name.to_string(), | ||||||
|  |       value: kv[1].to_string(), | ||||||
|  |     }) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, Clone, Builder)] | ||||||
|  | /// Struct describing sticky cookie meta information used for SET-COOKIE in res
 | ||||||
|  | pub struct StickyCookieInfo { | ||||||
|  |   #[builder(setter(custom))] | ||||||
|  |   /// Unix time
 | ||||||
|  |   pub expires: i64, | ||||||
|  | 
 | ||||||
|  |   #[builder(setter(custom))] | ||||||
|  |   /// Domain
 | ||||||
|  |   pub domain: String, | ||||||
|  | 
 | ||||||
|  |   #[builder(setter(custom))] | ||||||
|  |   /// Path
 | ||||||
|  |   pub path: String, | ||||||
|  | } | ||||||
|  | impl<'a> StickyCookieInfoBuilder { | ||||||
|  |   pub fn domain(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||||
|  |     self.domain = Some(v.into().to_ascii_lowercase()); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  |   pub fn path(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||||
|  |     self.path = Some(v.into().to_ascii_lowercase()); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  |   pub fn expires(&mut self, duration_secs: i64) -> &mut Self { | ||||||
|  |     let current = Utc::now().timestamp(); | ||||||
|  |     self.expires = Some(current + duration_secs); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, Clone, Builder)] | ||||||
|  | /// Struct describing sticky cookie
 | ||||||
|  | pub struct StickyCookie { | ||||||
|  |   #[builder(setter(custom))] | ||||||
|  |   /// Upstream server_id
 | ||||||
|  |   pub value: StickyCookieValue, | ||||||
|  |   #[builder(setter(custom), default)] | ||||||
|  |   /// Upstream server_id
 | ||||||
|  |   pub info: Option<StickyCookieInfo>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<'a> StickyCookieBuilder { | ||||||
|  |   pub fn value(&mut self, n: impl Into<Cow<'a, str>>, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||||
|  |     self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap()); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  |   pub fn info( | ||||||
|  |     &mut self, | ||||||
|  |     domain: impl Into<Cow<'a, str>>, | ||||||
|  |     path: impl Into<Cow<'a, str>>, | ||||||
|  |     duration_secs: i64, | ||||||
|  |   ) -> &mut Self { | ||||||
|  |     let info = StickyCookieInfoBuilder::default() | ||||||
|  |       .domain(domain) | ||||||
|  |       .path(path) | ||||||
|  |       .expires(duration_secs) | ||||||
|  |       .build() | ||||||
|  |       .unwrap(); | ||||||
|  |     self.info = Some(Some(info)); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl TryInto<String> for StickyCookie { | ||||||
|  |   type Error = RpxyError; | ||||||
|  | 
 | ||||||
|  |   fn try_into(self) -> Result<String> { | ||||||
|  |     if self.info.is_none() { | ||||||
|  |       return Err(RpxyError::LoadBalance( | ||||||
|  |         "Failed to cookie conversion into string: no meta information".to_string(), | ||||||
|  |       )); | ||||||
|  |     } | ||||||
|  |     let info = self.info.unwrap(); | ||||||
|  |     let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else { | ||||||
|  |       return Err(RpxyError::LoadBalance("Failed to cookie conversion into string".to_string())); | ||||||
|  |     }; | ||||||
|  |     let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string(); | ||||||
|  |     let max_age = info.expires - Utc::now().timestamp(); | ||||||
|  | 
 | ||||||
|  |     Ok(format!( | ||||||
|  |       "{}={}; expires={}; Max-Age={}; path={}; domain={}; HttpOnly", | ||||||
|  |       self.value.name, self.value.value, exp_str, max_age, info.path, info.domain | ||||||
|  |     )) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | /// Configuration to serve incoming requests in the manner of "sticky cookie".
 | ||||||
|  | /// Including a dictionary to map Ids included in cookie and upstream destinations,
 | ||||||
|  | /// and expiration of cookie.
 | ||||||
|  | /// "domain" and "path" in the cookie will be the same as the reverse proxy options.
 | ||||||
|  | pub struct StickyCookieConfig { | ||||||
|  |   pub name: String, | ||||||
|  |   pub domain: String, | ||||||
|  |   pub path: String, | ||||||
|  |   pub duration: i64, | ||||||
|  | } | ||||||
|  | impl<'a> StickyCookieConfig { | ||||||
|  |   pub fn build_sticky_cookie(&self, v: impl Into<Cow<'a, str>>) -> Result<StickyCookie> { | ||||||
|  |     StickyCookieBuilder::default() | ||||||
|  |       .value(self.name.clone(), v) | ||||||
|  |       .info(&self.domain, &self.path, self.duration) | ||||||
|  |       .build() | ||||||
|  |       .map_err(|_| RpxyError::LoadBalance("Failed to build sticky cookie from config".to_string())) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(test)] | ||||||
|  | mod tests { | ||||||
|  |   use super::*; | ||||||
|  |   use crate::constants::STICKY_COOKIE_NAME; | ||||||
|  | 
 | ||||||
|  |   #[test] | ||||||
|  |   fn config_works() { | ||||||
|  |     let config = StickyCookieConfig { | ||||||
|  |       name: STICKY_COOKIE_NAME.to_string(), | ||||||
|  |       domain: "example.com".to_string(), | ||||||
|  |       path: "/path".to_string(), | ||||||
|  |       duration: 100, | ||||||
|  |     }; | ||||||
|  |     let expires_unix = Utc::now().timestamp() + 100; | ||||||
|  |     let sc_string: Result<String> = config.build_sticky_cookie("test_value").unwrap().try_into(); | ||||||
|  |     let expires_date_string = Utc | ||||||
|  |       .timestamp_opt(expires_unix, 0) | ||||||
|  |       .unwrap() | ||||||
|  |       .format("%a, %d-%b-%Y %T GMT") | ||||||
|  |       .to_string(); | ||||||
|  |     assert_eq!( | ||||||
|  |       sc_string.unwrap(), | ||||||
|  |       format!( | ||||||
|  |         "{}=test_value; expires={}; Max-Age={}; path=/path; domain=example.com; HttpOnly", | ||||||
|  |         STICKY_COOKIE_NAME, expires_date_string, 100 | ||||||
|  |       ) | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  |   #[test] | ||||||
|  |   fn to_string_works() { | ||||||
|  |     let sc = StickyCookie { | ||||||
|  |       value: StickyCookieValue { | ||||||
|  |         name: STICKY_COOKIE_NAME.to_string(), | ||||||
|  |         value: "test_value".to_string(), | ||||||
|  |       }, | ||||||
|  |       info: Some(StickyCookieInfo { | ||||||
|  |         expires: 1686221173i64, | ||||||
|  |         domain: "example.com".to_string(), | ||||||
|  |         path: "/path".to_string(), | ||||||
|  |       }), | ||||||
|  |     }; | ||||||
|  |     let sc_string: Result<String> = sc.try_into(); | ||||||
|  |     let max_age = 1686221173i64 - Utc::now().timestamp(); | ||||||
|  |     assert!(sc_string.is_ok()); | ||||||
|  |     assert_eq!( | ||||||
|  |       sc_string.unwrap(), | ||||||
|  |       format!( | ||||||
|  |         "{}=test_value; expires=Thu, 08-Jun-2023 10:46:13 GMT; Max-Age={}; path=/path; domain=example.com; HttpOnly", | ||||||
|  |         STICKY_COOKIE_NAME, max_age | ||||||
|  |       ) | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -1,7 +1,14 @@ | ||||||
| mod load_balance; | mod load_balance; | ||||||
|  | mod load_balance_sticky_cookie; | ||||||
| mod upstream; | mod upstream; | ||||||
| mod upstream_opts; | mod upstream_opts; | ||||||
| 
 | 
 | ||||||
|  | pub use self::{ | ||||||
|  |   load_balance::LoadBalance, | ||||||
|  |   load_balance_sticky_cookie::{LbContext, StickyCookie, StickyCookieBuilder, StickyCookieValue}, | ||||||
|  |   upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, | ||||||
|  |   upstream_opts::UpstreamOption, | ||||||
|  | }; | ||||||
| use crate::{ | use crate::{ | ||||||
|   log::*, |   log::*, | ||||||
|   utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, |   utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, | ||||||
|  | @ -21,8 +28,6 @@ use tokio_rustls::rustls::{ | ||||||
|   sign::{any_supported_type, CertifiedKey}, |   sign::{any_supported_type, CertifiedKey}, | ||||||
|   Certificate, PrivateKey, ServerConfig, |   Certificate, PrivateKey, ServerConfig, | ||||||
| }; | }; | ||||||
| pub use upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}; |  | ||||||
| pub use upstream_opts::UpstreamOption; |  | ||||||
| use x509_parser::prelude::*; | use x509_parser::prelude::*; | ||||||
| 
 | 
 | ||||||
| /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
 | /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
 | ||||||
|  |  | ||||||
|  | @ -1,12 +1,16 @@ | ||||||
| use super::{ | use super::{ | ||||||
|   load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance}, |   load_balance::{ | ||||||
|  |     load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LbStickyRoundRobinBuilder, LoadBalance, | ||||||
|  |   }, | ||||||
|  |   load_balance_sticky_cookie::LbContext, | ||||||
|   BytesName, PathNameBytesExp, UpstreamOption, |   BytesName, PathNameBytesExp, UpstreamOption, | ||||||
| }; | }; | ||||||
| use crate::log::*; | use crate::log::*; | ||||||
|  | use base64::{engine::general_purpose, Engine as _}; | ||||||
| use derive_builder::Builder; | use derive_builder::Builder; | ||||||
| use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; | use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; | ||||||
|  | use sha2::{Digest, Sha256}; | ||||||
| use std::borrow::Cow; | use std::borrow::Cow; | ||||||
| 
 |  | ||||||
| #[derive(Debug, Clone)] | #[derive(Debug, Clone)] | ||||||
| pub struct ReverseProxy { | pub struct ReverseProxy { | ||||||
|   pub upstream: HashMap<PathNameBytesExp, UpstreamGroup>, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。
 |   pub upstream: HashMap<PathNameBytesExp, UpstreamGroup>, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。
 | ||||||
|  | @ -53,10 +57,20 @@ pub struct Upstream { | ||||||
|   /// Base uri without specific path
 |   /// Base uri without specific path
 | ||||||
|   pub uri: hyper::Uri, |   pub uri: hyper::Uri, | ||||||
| } | } | ||||||
| 
 | impl Upstream { | ||||||
|  |   /// Hashing uri with index to avoid collision
 | ||||||
|  |   pub fn calculate_id_with_index(&self, index: usize) -> String { | ||||||
|  |     let mut hasher = Sha256::new(); | ||||||
|  |     let uri_string = format!("{}&index={}", self.uri.clone(), index); | ||||||
|  |     hasher.update(uri_string.as_bytes()); | ||||||
|  |     let digest = hasher.finalize(); | ||||||
|  |     general_purpose::URL_SAFE_NO_PAD.encode(digest) | ||||||
|  |   } | ||||||
|  | } | ||||||
| #[derive(Debug, Clone, Builder)] | #[derive(Debug, Clone, Builder)] | ||||||
| /// Struct serving multiple upstream servers for, e.g., load balancing.
 | /// Struct serving multiple upstream servers for, e.g., load balancing.
 | ||||||
| pub struct UpstreamGroup { | pub struct UpstreamGroup { | ||||||
|  |   #[builder(setter(custom))] | ||||||
|   /// Upstream server(s)
 |   /// Upstream server(s)
 | ||||||
|   pub upstream: Vec<Upstream>, |   pub upstream: Vec<Upstream>, | ||||||
|   #[builder(setter(custom), default)] |   #[builder(setter(custom), default)] | ||||||
|  | @ -75,6 +89,10 @@ pub struct UpstreamGroup { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl UpstreamGroupBuilder { | impl UpstreamGroupBuilder { | ||||||
|  |   pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self { | ||||||
|  |     self.upstream = Some(upstream_vec.to_vec()); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|   pub fn path(&mut self, v: &Option<String>) -> &mut Self { |   pub fn path(&mut self, v: &Option<String>) -> &mut Self { | ||||||
|     let path = match v { |     let path = match v { | ||||||
|       Some(p) => p.to_path_name_vec(), |       Some(p) => p.to_path_name_vec(), | ||||||
|  | @ -91,7 +109,15 @@ impl UpstreamGroupBuilder { | ||||||
|     ); |     ); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
|   pub fn lb(&mut self, v: &Option<String>, upstream_num: &usize) -> &mut Self { |   pub fn lb( | ||||||
|  |     &mut self, | ||||||
|  |     v: &Option<String>, | ||||||
|  |     // upstream_num: &usize,
 | ||||||
|  |     upstream_vec: &Vec<Upstream>, | ||||||
|  |     server_name: &str, | ||||||
|  |     path_opt: &Option<String>, | ||||||
|  |   ) -> &mut Self { | ||||||
|  |     let upstream_num = &upstream_vec.len(); | ||||||
|     let lb = if let Some(x) = v { |     let lb = if let Some(x) = v { | ||||||
|       match x.as_str() { |       match x.as_str() { | ||||||
|         lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, |         lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, | ||||||
|  | @ -103,8 +129,10 @@ impl UpstreamGroupBuilder { | ||||||
|             .unwrap(), |             .unwrap(), | ||||||
|         ), |         ), | ||||||
|         lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( |         lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( | ||||||
|           LbRoundRobinBuilder::default() |           LbStickyRoundRobinBuilder::default() | ||||||
|             .num_upstreams(upstream_num) |             .num_upstreams(upstream_num) | ||||||
|  |             .sticky_config(server_name, path_opt) | ||||||
|  |             .upstream_maps(upstream_vec) // TODO:
 | ||||||
|             .build() |             .build() | ||||||
|             .unwrap(), |             .unwrap(), | ||||||
|         ), |         ), | ||||||
|  | @ -135,9 +163,35 @@ impl UpstreamGroupBuilder { | ||||||
| 
 | 
 | ||||||
| impl UpstreamGroup { | impl UpstreamGroup { | ||||||
|   /// Get an enabled option of load balancing [[LoadBalance]]
 |   /// Get an enabled option of load balancing [[LoadBalance]]
 | ||||||
|   pub fn get(&self) -> Option<&Upstream> { |   pub fn get(&self, context_to_lb: &Option<LbContext>) -> (Option<&Upstream>, Option<LbContext>) { | ||||||
|     let idx = self.lb.get_idx(); |     let pointer_to_upstream = self.lb.get_context(context_to_lb); | ||||||
|     debug!("Upstream of index {idx} is chosen."); |     debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); | ||||||
|     self.upstream.get(idx) |     debug!("Context to LB (Cookie in Req): {:?}", context_to_lb); | ||||||
|  |     debug!( | ||||||
|  |       "Context from LB (Set-Cookie in Res): {:?}", | ||||||
|  |       pointer_to_upstream.context_lb | ||||||
|  |     ); | ||||||
|  |     ( | ||||||
|  |       self.upstream.get(pointer_to_upstream.ptr), | ||||||
|  |       pointer_to_upstream.context_lb, | ||||||
|  |     ) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(test)] | ||||||
|  | mod test { | ||||||
|  |   use super::*; | ||||||
|  |   #[test] | ||||||
|  |   fn calc_id_works() { | ||||||
|  |     let uri = "https://www.rust-lang.org".parse::<hyper::Uri>().unwrap(); | ||||||
|  |     let upstream = Upstream { uri }; | ||||||
|  |     assert_eq!( | ||||||
|  |       "eGsjoPbactQ1eUJjafYjPT3ekYZQkaqJnHdA_FMSkgM", | ||||||
|  |       upstream.calculate_id_with_index(0) | ||||||
|  |     ); | ||||||
|  |     assert_eq!( | ||||||
|  |       "tNVXFJ9eNCT2mFgKbYq35XgH5q93QZtfU8piUiiDxVA", | ||||||
|  |       upstream.calculate_id_with_index(1) | ||||||
|  |     ); | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -99,7 +99,7 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro | ||||||
|     let mut backend_builder = BackendBuilder::default(); |     let mut backend_builder = BackendBuilder::default(); | ||||||
|     // reverse proxy settings
 |     // reverse proxy settings
 | ||||||
|     ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); |     ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy"); | ||||||
|     let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?; |     let reverse_proxy = get_reverse_proxy(server_name_string, app.reverse_proxy.as_ref().unwrap())?; | ||||||
| 
 | 
 | ||||||
|     backend_builder |     backend_builder | ||||||
|       .app_name(server_name_string) |       .app_name(server_name_string) | ||||||
|  | @ -198,17 +198,21 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro | ||||||
|   Ok(()) |   Ok(()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result<ReverseProxy, anyhow::Error> { | fn get_reverse_proxy( | ||||||
|  |   server_name_string: &str, | ||||||
|  |   rp_settings: &[ReverseProxyOption], | ||||||
|  | ) -> std::result::Result<ReverseProxy, anyhow::Error> { | ||||||
|   let mut upstream: HashMap<PathNameBytesExp, UpstreamGroup> = HashMap::default(); |   let mut upstream: HashMap<PathNameBytesExp, UpstreamGroup> = HashMap::default(); | ||||||
| 
 | 
 | ||||||
|   rp_settings.iter().for_each(|rpo| { |   rp_settings.iter().for_each(|rpo| { | ||||||
|     let vec_upstream: Vec<Upstream> = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(); |     let upstream_vec: Vec<Upstream> = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect(); | ||||||
|     let lb_upstream_num = vec_upstream.len(); |     // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap());
 | ||||||
|  |     // let lb_upstream_num = vec_upstream.len();
 | ||||||
|     let elem = UpstreamGroupBuilder::default() |     let elem = UpstreamGroupBuilder::default() | ||||||
|       .upstream(vec_upstream) |       .upstream(&upstream_vec) | ||||||
|       .path(&rpo.path) |       .path(&rpo.path) | ||||||
|       .replace_path(&rpo.replace_path) |       .replace_path(&rpo.replace_path) | ||||||
|       .lb(&rpo.load_balance, &lb_upstream_num) |       .lb(&rpo.load_balance, &upstream_vec, server_name_string, &rpo.path) | ||||||
|       .opts(&rpo.upstream_options) |       .opts(&rpo.upstream_options) | ||||||
|       .build() |       .build() | ||||||
|       .unwrap(); |       .unwrap(); | ||||||
|  |  | ||||||
|  | @ -24,3 +24,6 @@ pub mod H3 { | ||||||
|   pub const MAX_CONCURRENT_UNISTREAM: u32 = 64; |   pub const MAX_CONCURRENT_UNISTREAM: u32 = 64; | ||||||
|   pub const MAX_IDLE_TIMEOUT: u64 = 10; // secs
 |   pub const MAX_IDLE_TIMEOUT: u64 = 10; // secs
 | ||||||
| } | } | ||||||
|  | 
 | ||||||
|  | // For load-balancing with sticky cookie
 | ||||||
|  | pub const STICKY_COOKIE_NAME: &str = "rpxy_srv_id"; | ||||||
|  |  | ||||||
|  | @ -22,6 +22,9 @@ pub enum RpxyError { | ||||||
|   #[error("TCP/UDP Proxy Layer Error: {0}")] |   #[error("TCP/UDP Proxy Layer Error: {0}")] | ||||||
|   Proxy(String), |   Proxy(String), | ||||||
| 
 | 
 | ||||||
|  |   #[error("LoadBalance Layer Error: {0}")] | ||||||
|  |   LoadBalance(String), | ||||||
|  | 
 | ||||||
|   #[error("I/O Error")] |   #[error("I/O Error")] | ||||||
|   Io(#[from] io::Error), |   Io(#[from] io::Error), | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,7 +1,7 @@ | ||||||
| // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
 | // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
 | ||||||
| use super::{utils_headers::*, utils_request::*, utils_synth_response::*}; | use super::{utils_headers::*, utils_request::*, utils_synth_response::*, HandlerContext}; | ||||||
| use crate::{ | use crate::{ | ||||||
|   backend::{Backend, UpstreamGroup}, |   backend::{Backend, LoadBalance, UpstreamGroup}, | ||||||
|   error::*, |   error::*, | ||||||
|   globals::Globals, |   globals::Globals, | ||||||
|   log::*, |   log::*, | ||||||
|  | @ -91,7 +91,7 @@ where | ||||||
|     let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>(); |     let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>(); | ||||||
| 
 | 
 | ||||||
|     // Build request from destination information
 |     // Build request from destination information
 | ||||||
|     if let Err(e) = self.generate_request_forwarded( |     let context = match self.generate_request_forwarded( | ||||||
|       &client_addr, |       &client_addr, | ||||||
|       &listen_addr, |       &listen_addr, | ||||||
|       &mut req, |       &mut req, | ||||||
|  | @ -99,8 +99,11 @@ where | ||||||
|       upstream_group, |       upstream_group, | ||||||
|       tls_enabled, |       tls_enabled, | ||||||
|     ) { |     ) { | ||||||
|  |       Err(e) => { | ||||||
|         error!("Failed to generate destination uri for reverse proxy: {}", e); |         error!("Failed to generate destination uri for reverse proxy: {}", e); | ||||||
|         return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); |         return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); | ||||||
|  |       } | ||||||
|  |       Ok(v) => v, | ||||||
|     }; |     }; | ||||||
|     debug!("Request to be forwarded: {:?}", req); |     debug!("Request to be forwarded: {:?}", req); | ||||||
|     log_data.xff(&req.headers().get("x-forwarded-for")); |     log_data.xff(&req.headers().get("x-forwarded-for")); | ||||||
|  | @ -123,6 +126,15 @@ where | ||||||
|       } |       } | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|  |     // Process reverse proxy context generated during the forwarding request generation.
 | ||||||
|  |     if let Some(context_from_lb) = context.context_lb { | ||||||
|  |       let res_headers = res_backend.headers_mut(); | ||||||
|  |       if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { | ||||||
|  |         error!("Failed to append context to the response given from backend: {}", e); | ||||||
|  |         return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|     if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { |     if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { | ||||||
|       // Generate response to client
 |       // Generate response to client
 | ||||||
|       if self.generate_response_forwarded(&mut res_backend, backend).is_ok() { |       if self.generate_response_forwarded(&mut res_backend, backend).is_ok() { | ||||||
|  | @ -229,7 +241,7 @@ where | ||||||
|     upgrade: &Option<String>, |     upgrade: &Option<String>, | ||||||
|     upstream_group: &UpstreamGroup, |     upstream_group: &UpstreamGroup, | ||||||
|     tls_enabled: bool, |     tls_enabled: bool, | ||||||
|   ) -> Result<()> { |   ) -> Result<HandlerContext> { | ||||||
|     debug!("Generate request to be forwarded"); |     debug!("Generate request to be forwarded"); | ||||||
| 
 | 
 | ||||||
|     // Add te: trailer if contained in original request
 |     // Add te: trailer if contained in original request
 | ||||||
|  | @ -265,10 +277,19 @@ where | ||||||
|         .insert(header::HOST, HeaderValue::from_str(&org_host)?); |         .insert(header::HOST, HeaderValue::from_str(&org_host)?); | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|  |     /////////////////////////////////////////////
 | ||||||
|     // Fix unique upstream destination since there could be multiple ones.
 |     // Fix unique upstream destination since there could be multiple ones.
 | ||||||
|     // TODO: StickyならCookieをここでgetに与える必要
 |     let context_to_lb = if let LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { | ||||||
|     // TODO: Stickyで、Cookieが与えられなかったらset-cookie向けにcookieを返す必要。upstreamオブジェクトに含めるのも手。
 |       takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? | ||||||
|     let upstream_chosen = upstream_group.get().ok_or_else(|| anyhow!("Failed to get upstream"))?; |     } else { | ||||||
|  |       None | ||||||
|  |     }; | ||||||
|  |     let (upstream_chosen_opt, context_from_lb) = upstream_group.get(&context_to_lb); | ||||||
|  |     let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; | ||||||
|  |     let context = HandlerContext { | ||||||
|  |       context_lb: context_from_lb, | ||||||
|  |     }; | ||||||
|  |     /////////////////////////////////////////////
 | ||||||
| 
 | 
 | ||||||
|     // apply upstream-specific headers given in upstream_option
 |     // apply upstream-specific headers given in upstream_option
 | ||||||
|     let headers = req.headers_mut(); |     let headers = req.headers_mut(); | ||||||
|  | @ -321,6 +342,6 @@ where | ||||||
|       *req.version_mut() = Version::HTTP_2; |       *req.version_mut() = Version::HTTP_2; | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     Ok(()) |     Ok(context) | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -4,3 +4,10 @@ mod utils_request; | ||||||
| mod utils_synth_response; | mod utils_synth_response; | ||||||
| 
 | 
 | ||||||
| pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; | pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; | ||||||
|  | 
 | ||||||
|  | use crate::backend::LbContext; | ||||||
|  | 
 | ||||||
|  | #[derive(Debug)] | ||||||
|  | struct HandlerContext { | ||||||
|  |   context_lb: Option<LbContext>, | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -1,5 +1,5 @@ | ||||||
| use crate::{ | use crate::{ | ||||||
|   backend::{UpstreamGroup, UpstreamOption}, |   backend::{LbContext, StickyCookie, StickyCookieValue, UpstreamGroup, UpstreamOption}, | ||||||
|   error::*, |   error::*, | ||||||
|   log::*, |   log::*, | ||||||
|   utils::*, |   utils::*, | ||||||
|  | @ -14,6 +14,74 @@ use std::net::SocketAddr; | ||||||
| ////////////////////////////////////////////////////
 | ////////////////////////////////////////////////////
 | ||||||
| // Functions to manipulate headers
 | // Functions to manipulate headers
 | ||||||
| 
 | 
 | ||||||
|  | /// Take sticky cookie header value from request header,
 | ||||||
|  | /// and returns LbContext to be forwarded to LB if exist and if needed.
 | ||||||
|  | /// Removing sticky cookie is needed and it must not be passed to the upstream.
 | ||||||
|  | pub(super) fn takeout_sticky_cookie_lb_context( | ||||||
|  |   headers: &mut HeaderMap, | ||||||
|  |   expected_cookie_name: &str, | ||||||
|  | ) -> Result<Option<LbContext>> { | ||||||
|  |   let mut headers_clone = headers.clone(); | ||||||
|  | 
 | ||||||
|  |   match headers_clone.entry(hyper::header::COOKIE) { | ||||||
|  |     header::Entry::Vacant(_) => Ok(None), | ||||||
|  |     header::Entry::Occupied(entry) => { | ||||||
|  |       let cookies_iter = entry | ||||||
|  |         .iter() | ||||||
|  |         .flat_map(|v| v.to_str().unwrap_or("").split(';').map(|v| v.trim())); | ||||||
|  |       let (sticky_cookies, without_sticky_cookies): (Vec<_>, Vec<_>) = cookies_iter | ||||||
|  |         .into_iter() | ||||||
|  |         .partition(|v| v.starts_with(expected_cookie_name)); | ||||||
|  |       if sticky_cookies.is_empty() { | ||||||
|  |         return Ok(None); | ||||||
|  |       } | ||||||
|  |       if sticky_cookies.len() > 1 { | ||||||
|  |         error!("Multiple sticky cookie values in request"); | ||||||
|  |         return Err(RpxyError::Other(anyhow!( | ||||||
|  |           "Invalid cookie: Multiple sticky cookie values" | ||||||
|  |         ))); | ||||||
|  |       } | ||||||
|  |       let cookies_passed_to_upstream = without_sticky_cookies.join("; "); | ||||||
|  |       let cookie_passed_to_lb = sticky_cookies.first().unwrap(); | ||||||
|  |       headers.remove(hyper::header::COOKIE); | ||||||
|  |       headers.insert(hyper::header::COOKIE, cookies_passed_to_upstream.parse()?); | ||||||
|  | 
 | ||||||
|  |       let sticky_cookie = StickyCookie { | ||||||
|  |         value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?, | ||||||
|  |         info: None, | ||||||
|  |       }; | ||||||
|  |       Ok(Some(LbContext { sticky_cookie })) | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated.
 | ||||||
|  | /// Set-Cookie response header could be in multiple lines.
 | ||||||
|  | /// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie
 | ||||||
|  | pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from_lb: &LbContext) -> Result<()> { | ||||||
|  |   let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?; | ||||||
|  |   let new_header_val: HeaderValue = sticky_cookie_string.parse()?; | ||||||
|  |   let expected_cookie_name = &context_from_lb.sticky_cookie.value.name; | ||||||
|  |   match headers.entry(hyper::header::SET_COOKIE) { | ||||||
|  |     header::Entry::Vacant(entry) => { | ||||||
|  |       entry.insert(new_header_val); | ||||||
|  |     } | ||||||
|  |     header::Entry::Occupied(mut entry) => { | ||||||
|  |       let mut flag = false; | ||||||
|  |       for e in entry.iter_mut() { | ||||||
|  |         if e.to_str().unwrap_or("").starts_with(expected_cookie_name) { | ||||||
|  |           *e = new_header_val.clone(); | ||||||
|  |           flag = true; | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       if !flag { | ||||||
|  |         entry.append(new_header_val); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  |   Ok(()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| pub(super) fn apply_upstream_options_to_header( | pub(super) fn apply_upstream_options_to_header( | ||||||
|   headers: &mut HeaderMap, |   headers: &mut HeaderMap, | ||||||
|   _client_addr: &SocketAddr, |   _client_addr: &SocketAddr, | ||||||
|  |  | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jun Kurihara
				Jun Kurihara