From 1a708692271a3d27408777f80902760414bcb787 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Thu, 28 Jul 2022 21:46:53 +0900 Subject: [PATCH] change path name and server name to specific struct to find longest prefix and exact matching in hashtables --- src/backend/upstream.rs | 4 +- src/config/parse.rs | 2 +- src/handler/handler_main.rs | 10 ++-- src/proxy/proxy_tls.rs | 2 +- src/utils/bytes_name.rs | 91 ++++++++++++++++++++++++++++++++----- 5 files changed, 89 insertions(+), 20 deletions(-) diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs index 4c588f9..aef0d83 100644 --- a/src/backend/upstream.rs +++ b/src/backend/upstream.rs @@ -19,7 +19,7 @@ impl ReverseProxy { pub fn get<'a>(&self, path_str: impl Into>) -> Option<&UpstreamGroup> { // trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、 // コスト的にこの程度で十分 - let path_bytes = &(path_str.to_path_name_vec())[..]; + let path_bytes = &path_str.to_path_name_vec(); let matched_upstream = self .upstream @@ -40,7 +40,7 @@ impl ReverseProxy { if let Some((_path, u)) = matched_upstream { debug!( "Found upstream: {:?}", - String::from_utf8(_path.to_vec()).unwrap_or_else(|_| "".to_string()) + String::from_utf8(_path.0.clone()).unwrap_or_else(|_| "".to_string()) ); Some(u) } else { diff --git a/src/config/parse.rs b/src/config/parse.rs index 9f1d289..f0efe75 100644 --- a/src/config/parse.rs +++ b/src/config/parse.rs @@ -151,7 +151,7 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", d, d_sn[0] ); - globals.backends.default_server_name_bytes = Some(d_sn[0].as_bytes().to_vec()); + globals.backends.default_server_name_bytes = Some(d_sn[0].to_server_name_vec()); } } diff --git a/src/handler/handler_main.rs b/src/handler/handler_main.rs index 373c363..7c156fe 100644 --- a/src/handler/handler_main.rs +++ b/src/handler/handler_main.rs @@ -42,20 +42,20 @@ where ////// // Here we start to 
handle with server_name - let server_name_bytes = if let Ok(v) = req.parse_host() { - v.to_ascii_lowercase() + let server_name = if let Ok(v) = req.parse_host() { + ServerNameBytesExp::from(v) } else { return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); }; // check consistency of between TLS SNI and HOST/Request URI Line. #[allow(clippy::collapsible_if)] if tls_enabled && self.globals.sni_consistency { - if !server_name_bytes.eq_ignore_ascii_case(&tls_server_name.unwrap_or_default()) { + if server_name != tls_server_name.unwrap_or_default() { return self.return_with_error_log(StatusCode::MISDIRECTED_REQUEST, &mut log_data); } } // Find backend application for given server_name, and drop if incoming request is invalid as request. - let backend = if let Some(be) = self.globals.backends.apps.get(&server_name_bytes) { + let backend = if let Some(be) = self.globals.backends.apps.get(&server_name) { be } else if let Some(default_server_name) = &self.globals.backends.default_server_name_bytes { debug!("Serving by default app"); @@ -271,7 +271,7 @@ where return Err(RpxyError::Handler("Upstream uri `path and query` is broken")); }; let mut new_pq = Vec::::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); - new_pq.extend_from_slice(new_path); + new_pq.extend_from_slice(new_path.as_ref()); new_pq.extend_from_slice(&org_pq[matched_path.len()..]); new_pq } diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 0e63a10..be71acc 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -148,7 +148,7 @@ where }; debug!( "HTTP/3 connection incoming (SNI {:?})", - new_server_name + new_server_name.0 ); // TODO: server_nameをここで出してどんどん深く投げていくのは効率が悪い。connecting -> connectionsの後でいいのでは? 
// TODO: 通常のTLSと同じenumか何かにまとめたい diff --git a/src/utils/bytes_name.rs b/src/utils/bytes_name.rs index c7c8def..0858f48 100644 --- a/src/utils/bytes_name.rs +++ b/src/utils/bytes_name.rs @@ -1,12 +1,38 @@ -// Server name (hostname or ip address) and path name representation in backends -// For searching hashmap or key list by exact or longest-prefix matching -pub type ServerNameBytesExp = Vec; // lowercase ascii bytes +/// Server name (hostname or ip address) representation in bytes-based struct +/// For searching hashmap or key list by exact or longest-prefix matching +#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] +pub struct ServerNameBytesExp(pub Vec); // lowercase ascii bytes +impl From<&[u8]> for ServerNameBytesExp { + fn from(b: &[u8]) -> Self { + Self(b.to_ascii_lowercase()) + } +} -// #[derive(Clone, Debug)] -// pub struct ServerNameBytesExp(Vec); - -pub type PathNameBytesExp = Vec; // lowercase ascii bytes +/// Path name representation in bytes-based struct +/// For searching hashmap or key list by exact or longest-prefix matching +#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] +pub struct PathNameBytesExp(pub Vec); // lowercase ascii bytes +impl PathNameBytesExp { + pub fn len(&self) -> usize { + self.0.len() + } + pub fn get(&self, index: I) -> Option<&I::Output> + where + I: std::slice::SliceIndex<[u8]>, + { + (&self.0).get(index) + } + pub fn starts_with(&self, needle: &Self) -> bool { + self.0.starts_with(&needle.0) + } +} +impl AsRef<[u8]> for PathNameBytesExp { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} +/// Trait to express names in ascii-lowercased bytes pub trait BytesName { type OutputSv: Send + Sync + 'static; type OutputPath; @@ -20,12 +46,12 @@ impl<'a, T: Into>> BytesName for T { fn to_server_name_vec(self) -> Self::OutputSv { let name = self.into().bytes().collect::>().to_ascii_lowercase(); - name + ServerNameBytesExp(name) } fn to_path_name_vec(self) -> Self::OutputPath { let name = 
self.into().bytes().collect::>().to_ascii_lowercase(); - name + PathNameBytesExp(name) } } @@ -38,7 +64,50 @@ mod tests { let bn = s.to_path_name_vec(); let bn_lc = s.to_server_name_vec(); - assert_eq!(Vec::from(s.as_bytes()), bn); - assert_eq!(Vec::from(s.as_bytes()), bn_lc); + assert_eq!(Vec::from("ok_string".as_bytes()), bn.0); + assert_eq!(Vec::from("ok_string".as_bytes()), bn_lc.0); + } + + #[test] + fn from_works() { + let s = "OK_string".to_server_name_vec(); + let m = ServerNameBytesExp::from("OK_strinG".as_bytes()); + assert_eq!(s, m); + assert_eq!(s.0, "ok_string".as_bytes().to_vec()); + assert_eq!(m.0, "ok_string".as_bytes().to_vec()); + } + + #[test] + fn get_works() { + let s = "OK_str".to_path_name_vec(); + let i = s.get(0); + assert_eq!(Some(&"o".as_bytes()[0]), i); + let i = s.get(1); + assert_eq!(Some(&"k".as_bytes()[0]), i); + let i = s.get(2); + assert_eq!(Some(&"_".as_bytes()[0]), i); + let i = s.get(3); + assert_eq!(Some(&"s".as_bytes()[0]), i); + let i = s.get(4); + assert_eq!(Some(&"t".as_bytes()[0]), i); + let i = s.get(5); + assert_eq!(Some(&"r".as_bytes()[0]), i); + let i = s.get(6); + assert_eq!(None, i); + } + + #[test] + fn start_with_works() { + let s = "OK_str".to_path_name_vec(); + let correct = "OK".to_path_name_vec(); + let incorrect = "KO".to_path_name_vec(); + assert!(s.starts_with(&correct)); + assert!(!s.starts_with(&incorrect)); + } + + #[test] + fn as_ref_works() { + let s = "OK_str".to_path_name_vec(); + assert_eq!(s.as_ref(), "ok_str".as_bytes()); } }