diff --git a/src/backend.rs b/src/backend.rs index 07784c0..4cbbdee 100644 --- a/src/backend.rs +++ b/src/backend.rs @@ -1,8 +1,8 @@ -use crate::backend_opt::UpstreamOption; -use crate::log::*; +use crate::{backend_opt::UpstreamOption, log::*}; use rand::Rng; use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; use std::{ + borrow::Cow, fs::File, io::{self, BufReader, Cursor, Read}, path::PathBuf, @@ -15,6 +15,7 @@ use tokio_rustls::rustls::{Certificate, PrivateKey, ServerConfig}; // server name (hostname or ip address) in ascii lower case pub type ServerNameLC = Vec; +pub type PathNameLC = Vec; pub struct Backends { pub apps: HashMap, // TODO: hyper::uriで抜いたhostで引っ掛ける。Stringでいいのか? @@ -34,8 +35,42 @@ pub struct Backend { #[derive(Debug, Clone)] pub struct ReverseProxy { - pub default_upstream: Option, - pub upstream: HashMap, + pub upstream: HashMap, +} + +impl ReverseProxy { + pub fn get<'a>(&self, path_str: impl Into>) -> Option<&Upstream> { + // trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、 + // コスト的にこの程度で十分 + let path_lc = path_str.into().to_ascii_lowercase(); + let path_bytes = path_lc.as_bytes(); + + let matched_upstream = self + .upstream + .iter() + .filter(|(route_bytes, _)| { + match path_bytes.starts_with(route_bytes) { + true => { + route_bytes.len() == 1 // route = '/', i.e., default + || match path_bytes.get(route_bytes.len()) { + None => true, // exact case + Some(p) => p == &b'/', // sub-path case + } + } + _ => false, + } + }) + .max_by_key(|(route_bytes, _)| route_bytes.len()); + if let Some((_path, u)) = matched_upstream { + debug!( + "Found upstream: {:?}", + String::from_utf8(_path.to_vec()).unwrap_or_else(|_| "".to_string()) + ); + Some(u) + } else { + None + } + } } #[allow(dead_code)] diff --git a/src/config/parse.rs b/src/config/parse.rs index 43671b8..484b941 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -1,6 +1,6 @@ use super::toml::{ConfigToml, ReverseProxyOption}; use crate::{ - backend::{Backend, ReverseProxy, Upstream}, + backend::{Backend, PathNameLC, ReverseProxy, Upstream}, backend_opt::UpstreamOption, constants::*, error::*, @@ -169,8 +169,7 @@ pub fn parse_opts(globals: &mut Globals) -> Result<()> { } fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> Result { - let mut upstream: HashMap = HashMap::default(); - let mut default_upstream: Option = None; + let mut upstream: HashMap = HashMap::default(); rp_settings.iter().for_each(|rpo| { let elem = Upstream { uri: rpo.upstream.iter().map(|x| x.to_uri().unwrap()).collect(), @@ -189,17 +188,17 @@ fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> Result }; if rpo.path.is_some() { - upstream.insert(rpo.path.as_ref().unwrap().to_owned(), elem); + upstream.insert( + rpo.path.as_ref().unwrap().as_bytes().to_ascii_lowercase(), + elem, + ); } else { - default_upstream = Some(elem) + upstream.insert("/".as_bytes().to_ascii_lowercase(), elem); } }); ensure!( rp_settings.iter().filter(|rpo| rpo.path.is_none()).count() < 2, "Multiple default reverse proxy setting" ); - Ok(ReverseProxy { - default_upstream, - upstream, - }) + Ok(ReverseProxy { upstream }) } diff --git a/src/msg_handler/handler.rs b/src/msg_handler/handler.rs index 34adf51..f00cc1a 100644 --- a/src/msg_handler/handler.rs +++ b/src/msg_handler/handler.rs @@ -54,14 +54,11 @@ where return secure_redirection(&backend.server_name, self.globals.https_port, &req); } - /////////////////////// // Find reverse proxy for given path and choose one of upstream host - // TODO: More flexible path matcher + // Longest prefix match let path = req.uri().path(); - let upstream = if let Some(upstream) = backend.reverse_proxy.upstream.get(path) { + let upstream = if let Some(upstream) = backend.reverse_proxy.get(path) { upstream - } else if let Some(default_upstream) = &backend.reverse_proxy.default_upstream { - default_upstream } else { return http_error(StatusCode::NOT_FOUND); }; @@ -70,7 +67,6 @@ where } else { return http_error(StatusCode::INTERNAL_SERVER_ERROR); }; - /////////////////////// // Upgrade in request header let upgrade_in_request = extract_upgrade(req.headers());