diff --git a/src/backend/load_balance.rs b/src/backend/load_balance.rs index bda5497..8ff8b4b 100644 --- a/src/backend/load_balance.rs +++ b/src/backend/load_balance.rs @@ -1,4 +1,5 @@ use derive_builder::Builder; +use rand::Rng; use std::sync::{ atomic::{AtomicUsize, Ordering}, Arc, @@ -13,50 +14,85 @@ pub(super) mod load_balance_options { } #[derive(Debug, Clone, Builder)] -/// Counter object as a pointer to the current serving upstream destination -pub struct LbRoundRobinCount { +/// Round Robin LB object as a pointer to the current serving upstream destination +pub struct LbRoundRobin { #[builder(default)] - cnt: Arc, + /// Pointer to the index of the last served upstream destination + ptr: Arc, #[builder(setter(custom), default)] - max_val: usize, + /// Number of upstream destinations + num_upstreams: usize, } -impl LbRoundRobinCountBuilder { - pub fn max_val(&mut self, v: &usize) -> &mut Self { - self.max_val = Some(*v); +impl LbRoundRobinBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); self } } -impl LbRoundRobinCount { +impl LbRoundRobin { /// Get a current count of upstream served - fn current_cnt(&self) -> usize { - self.cnt.load(Ordering::Relaxed) + fn current_ptr(&self) -> usize { + self.ptr.load(Ordering::Relaxed) } /// Increment the count of upstream served up to the max value - pub fn increment_cnt(&self) -> usize { - if self.current_cnt() < self.max_val - 1 { - self.cnt.fetch_add(1, Ordering::Relaxed) + pub fn increment_ptr(&self) -> usize { + if self.current_ptr() < self.num_upstreams - 1 { + self.ptr.fetch_add(1, Ordering::Relaxed) } else { // Clear the counter - self.cnt.fetch_and(0, Ordering::Relaxed) + self.ptr.fetch_and(0, Ordering::Relaxed) } } } +#[derive(Debug, Clone, Builder)] +/// Random LB object to keep the object of random pools +pub struct LbRandom { + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, +} +impl LbRandomBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } +} +impl LbRandom { + /// Returns the random index within the range + pub fn get_ptr(&self) -> usize { + let mut rng = rand::thread_rng(); + rng.gen_range(0..self.num_upstreams) + } +} + #[derive(Debug, Clone)] /// Load Balancing Option pub enum LoadBalance { /// Fix to the first upstream. Use if only one upstream destination is specified FixToFirst, /// Randomly chose one upstream server - Random, + Random(LbRandom), /// Simple round robin without session persistance - RoundRobin(LbRoundRobinCount), + RoundRobin(LbRoundRobin), /// Round robin with session persistance using cookie - StickyRoundRobin(LbRoundRobinCount), + StickyRoundRobin(LbRoundRobin), } impl Default for LoadBalance { fn default() -> Self { Self::FixToFirst } } + +impl LoadBalance { + /// Get the index of the upstream serving the incoming request + pub(super) fn get_idx(&self) -> usize { + match self { + LoadBalance::FixToFirst => 0usize, + LoadBalance::RoundRobin(ptr) => ptr.increment_ptr(), + LoadBalance::Random(v) => v.get_ptr(), + LoadBalance::StickyRoundRobin(_ptr) => 0usize, // todo!(), // TODO: TODO: TODO: TODO: tentative value + } + } +} diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index 11b186e..ac17a40 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -1,10 +1,9 @@ use super::{ - load_balance::{load_balance_options as lb_opts, LbRoundRobinCountBuilder, LoadBalance}, + load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance}, BytesName, PathNameBytesExp, UpstreamOption, }; use crate::log::*; use derive_builder::Builder; -use rand::Rng; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use std::borrow::Cow; @@ -96,16 +95,16 @@ impl UpstreamGroupBuilder { let lb = if let Some(x) = v { match x.as_str() { lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, - lb_opts::RANDOM => LoadBalance::Random, + lb_opts::RANDOM => LoadBalance::Random(LbRandomBuilder::default().num_upstreams(upstream_num).build().unwrap()), lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin( - LbRoundRobinCountBuilder::default() - .max_val(upstream_num) + LbRoundRobinBuilder::default() + .num_upstreams(upstream_num) .build() .unwrap(), ), lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( - LbRoundRobinCountBuilder::default() - .max_val(upstream_num) + LbRoundRobinBuilder::default() + .num_upstreams(upstream_num) .build() .unwrap(), ), @@ -137,18 +136,8 @@ impl UpstreamGroupBuilder { impl UpstreamGroup { /// Get an enabled option of load balancing [[LoadBalance]] pub fn get(&self) -> Option<&Upstream> { - match &self.lb { - LoadBalance::FixToFirst => self.upstream.get(0), - LoadBalance::RoundRobin(cnt) => { - let idx = cnt.increment_cnt(); - self.upstream.get(idx) - } - LoadBalance::Random => { - let mut rng = rand::thread_rng(); - let max = self.upstream.len() - 1; - self.upstream.get(rng.gen_range(0..max)) - } - LoadBalance::StickyRoundRobin(_cnt) => todo!(), // TODO: TODO: - } + let idx = self.lb.get_idx(); + debug!("Upstream of index {idx} is chosen."); + self.upstream.get(idx) } }