diff --git a/Cargo.toml b/Cargo.toml index 3d2d26e..21211a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ hyper-trust-dns = { version = "0.4.2", default-features = false, features = [ "rustls-webpki", ] } rustls = "0.20.6" +rand = "0.8.5" [dev-dependencies] diff --git a/config-example.toml b/config-example.toml index e3eace9..b8b846f 100644 --- a/config-example.toml +++ b/config-example.toml @@ -14,13 +14,20 @@ https_port = 8443 ################################### [[application]] -app_name = 'localhost' # this should be option, if null then same as hostname +app_name = 'localhost' # this should be option, if null then same as server_name hostname = 'localhost' https_redirection = true reverse_proxy = [ # default destination if path is not specified - { destination = 'www.google.com', tls = true }, - { destination = 'www.bing.com', path = '/maps', tls = true }, + # TODO: Array for load balancing + { upstream = [ + { location = 'www.google.com', tls = true }, + { location = 'www.google.co.jp', tls = true }, + ] }, + { path = '/maps', upstream = [ + { location = 'www.bing.com', tls = true }, + { location = 'www.bing.co.jp', tls = true }, + ] }, ] ## List of destinations to send data to. ## At this point, round-robin is used for load-balancing if multiple URLs are specified. @@ -34,6 +41,6 @@ tls_cert_key_path = 'localhost.pem' app_name = 'locahost_application' hostname = 'localhost.localdomain' https_redirection = true -reverse_proxy = [{ destination = 'www.google.com', tls = true }] +reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }] tls_cert_path = 'localhost.pem' tls_cert_key_path = 'localhost.pem' diff --git a/src/backend.rs b/src/backend.rs index 2d13479..aebb31e 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,10 +1,14 @@ use crate::log::*; +use rand::Rng; use std::{ collections::HashMap, fs::File, io::{self, BufReader, Cursor, Read}, path::PathBuf, - sync::Mutex, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, }; use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; @@ -20,8 +24,58 @@ pub struct Backend { #[derive(Debug, Clone)] pub struct ReverseProxy { - pub default_destination_uri: hyper::Uri, - pub destination_uris: HashMap, // TODO: url pathで引っ掛ける。 + pub default_upstream: Upstream, + pub upstream: HashMap, +} + +#[allow(dead_code)] +#[derive(Debug, Clone)] +pub enum LoadBalance { + RoundRobin, + Random, +} +impl Default for LoadBalance { + fn default() -> Self { + Self::RoundRobin + } +} + +#[derive(Debug, Clone)] +pub struct Upstream { + pub uri: Vec, + pub lb: LoadBalance, + pub cnt: UpstreamCount, // counter for load balancing +} + +#[derive(Debug, Clone, Default)] +pub struct UpstreamCount(Arc); + +impl Upstream { + pub fn get(&self) -> Option<&hyper::Uri> { + match self.lb { + LoadBalance::RoundRobin => { + let idx = self.increment_cnt(); + self.uri.get(idx) + } + LoadBalance::Random => { + let mut rng = rand::thread_rng(); + let max = self.uri.len() - 1; + self.uri.get(rng.gen_range(0..max)) + } + } + } + + fn current_cnt(&self) -> usize { + self.cnt.0.load(Ordering::Relaxed) + } + + fn increment_cnt(&self) -> usize { + if self.current_cnt() < self.uri.len() - 1 { + self.cnt.0.fetch_add(1, Ordering::Relaxed) + } else { + self.cnt.0.fetch_and(0, Ordering::Relaxed) + } + } } impl Backend { diff --git a/src/config/parse.rs b/src/config/parse.rs index 5c5b905..207bcf9 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -21,10 +21,17 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap globals.https_port = Some(HTTPS_LISTEN_PORT); // TODO: - let mut map_example: HashMap = HashMap::new(); + let mut map_example: HashMap = HashMap::new(); map_example.insert( "/maps".to_string(), - "https://www.bing.com".parse::().unwrap(), + Upstream { + uri: vec![ + "https://www.bing.com".parse::().unwrap(), + "https://www.bing.co.jp".parse::().unwrap(), + ], + cnt: Default::default(), + lb: Default::default(), + }, ); backends.insert( "localhost".to_string(), @@ -32,9 +39,16 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap app_name: "Localhost to Google except for maps".to_string(), hostname: "localhost".to_string(), reverse_proxy: ReverseProxy { - // default_destination_uri: "https://www.google.com".parse::().unwrap(), - default_destination_uri: "http://abehiroshi.la.coocan.jp/".parse::().unwrap(), // httpのみの場合の好例 - destination_uris: map_example, + default_upstream: Upstream { + uri: vec![ + "https://www.google.com".parse::().unwrap(), + "https://www.google.co.jp".parse::().unwrap(), + ], + cnt: Default::default(), + lb: Default::default(), + }, + // default_upstream_uri: vec!["http://abehiroshi.la.coocan.jp/".parse::().unwrap()], // httpのみの場合の好例 + upstream: map_example, }, https_redirection: Some(false), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず diff --git a/src/proxy/proxy_handler.rs b/src/proxy/proxy_handler.rs index c3c4c7b..fc4bd55 100644 --- a/src/proxy/proxy_handler.rs +++ b/src/proxy/proxy_handler.rs @@ -52,14 +52,18 @@ where return secure_redirection(&hostname, self.globals.https_port, &path_and_query); } - // Find reverse proxy for given path + // Find reverse proxy for given path and choose one of upstream host let path = req.uri().path(); - let destination_scheme_host = - if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) { - uri.to_owned() - } else { - backend.reverse_proxy.default_destination_uri.clone() - }; + let upstream_uri = if let Some(upstream) = backend.reverse_proxy.upstream.get(path) { + upstream.get() + } else { + backend.reverse_proxy.default_upstream.get() + }; + let upstream_scheme_host = if let Some(u) = upstream_uri { + u + } else { + return http_error(StatusCode::INTERNAL_SERVER_ERROR); + }; // Upgrade in request header let upgrade_in_request = extract_upgrade(req.headers()); @@ -69,7 +73,7 @@ where let req_forwarded = if let Ok(req) = generate_request_forwarded( client_addr, req, - destination_scheme_host, + upstream_scheme_host, path_and_query, &upgrade_in_request, ) { @@ -139,7 +143,7 @@ fn generate_response_forwarded(response: &mut Response) fn generate_request_forwarded( client_addr: SocketAddr, mut req: Request, - destination_scheme_host: Uri, + upstream_scheme_host: &Uri, path_and_query: String, upgrade: &Option, ) -> Result> { @@ -174,8 +178,8 @@ fn generate_request_forwarded( // update uri in request *req.uri_mut() = Uri::builder() - .scheme(destination_scheme_host.scheme().unwrap().as_str()) - .authority(destination_scheme_host.authority().unwrap().as_str()) + .scheme(upstream_scheme_host.scheme().unwrap().as_str()) + .authority(upstream_scheme_host.authority().unwrap().as_str()) .path_and_query(&path_and_query) .build()?; @@ -188,7 +192,7 @@ fn generate_request_forwarded( } // Change version to http/1.1 when destination scheme is http - if req.version() != Version::HTTP_11 && destination_scheme_host.scheme() == Some(&Scheme::HTTP) { + if req.version() != Version::HTTP_11 && upstream_scheme_host.scheme() == Some(&Scheme::HTTP) { *req.version_mut() = Version::HTTP_11; }