restructuring src dir

This commit is contained in:
Jun Kurihara 2022-07-27 20:33:37 +09:00
commit b56bf54318
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
12 changed files with 146 additions and 232 deletions

View file

@ -1,31 +1,27 @@
use crate::{backend_opt::UpstreamOption, log::*};
use rand::Rng;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
mod upstream;
mod upstream_opts;
use crate::log::*;
use rustc_hash::FxHashMap as HashMap;
use std::{
borrow::Cow,
fs::File,
io::{self, BufReader, Cursor, Read},
path::PathBuf,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
sync::Arc,
};
use tokio_rustls::rustls::{
server::ResolvesServerCertUsingSni,
sign::{any_supported_type, CertifiedKey},
Certificate, PrivateKey, ServerConfig,
};
pub use upstream::{ReverseProxy, Upstream, UpstreamGroup};
pub use upstream_opts::UpstreamOption;
// server name (hostname or ip address) in ascii lower case
pub type ServerNameLC = Vec<u8>;
pub type PathNameLC = Vec<u8>;
pub struct Backends {
pub apps: HashMap<ServerNameLC, Backend>, // hyper::uriで抜いたhostで引っ掛ける
pub default_server_name: Option<ServerNameLC>, // for plaintext http
}
// server name (hostname or ip address) and path name representation in backends
pub type ServerNameExp = Vec<u8>; // lowercase ascii bytes
pub type PathNameExp = Vec<u8>; // lowercase ascii bytes
/// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
pub struct Backend {
pub app_name: String,
pub server_name: String,
@ -37,104 +33,6 @@ pub struct Backend {
pub https_redirection: Option<bool>,
}
#[derive(Debug, Clone)]
pub struct ReverseProxy {
pub upstream: HashMap<PathNameLC, UpstreamGroup>, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。
}
impl ReverseProxy {
pub fn get<'a>(&self, path_str: impl Into<Cow<'a, str>>) -> Option<&UpstreamGroup> {
// trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、
// コスト的にこの程度で十分
let path_lc = path_str.into().to_ascii_lowercase();
let path_bytes = path_lc.as_bytes();
let matched_upstream = self
.upstream
.iter()
.filter(|(route_bytes, _)| {
match path_bytes.starts_with(route_bytes) {
true => {
route_bytes.len() == 1 // route = '/', i.e., default
|| match path_bytes.get(route_bytes.len()) {
None => true, // exact case
Some(p) => p == &b'/', // sub-path case
}
}
_ => false,
}
})
.max_by_key(|(route_bytes, _)| route_bytes.len());
if let Some((_path, u)) = matched_upstream {
debug!(
"Found upstream: {:?}",
String::from_utf8(_path.to_vec()).unwrap_or_else(|_| "<none>".to_string())
);
Some(u)
} else {
None
}
}
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum LoadBalance {
RoundRobin,
Random,
}
impl Default for LoadBalance {
fn default() -> Self {
Self::RoundRobin
}
}
#[derive(Debug, Clone)]
pub struct Upstream {
pub uri: hyper::Uri, // base uri without specific path
}
#[derive(Debug, Clone)]
pub struct UpstreamGroup {
pub upstream: Vec<Upstream>,
pub path: PathNameLC,
pub replace_path: Option<PathNameLC>,
pub lb: LoadBalance,
pub cnt: UpstreamCount, // counter for load balancing
pub opts: HashSet<UpstreamOption>,
}
#[derive(Debug, Clone, Default)]
pub struct UpstreamCount(Arc<AtomicUsize>);
impl UpstreamGroup {
pub fn get(&self) -> Option<&Upstream> {
match self.lb {
LoadBalance::RoundRobin => {
let idx = self.increment_cnt();
self.upstream.get(idx)
}
LoadBalance::Random => {
let mut rng = rand::thread_rng();
let max = self.upstream.len() - 1;
self.upstream.get(rng.gen_range(0..max))
}
}
}
fn current_cnt(&self) -> usize {
self.cnt.0.load(Ordering::Relaxed)
}
fn increment_cnt(&self) -> usize {
if self.current_cnt() < self.upstream.len() - 1 {
self.cnt.0.fetch_add(1, Ordering::Relaxed)
} else {
self.cnt.0.fetch_and(0, Ordering::Relaxed)
}
}
}
impl Backend {
pub fn read_certs_and_key(&self) -> io::Result<CertifiedKey> {
debug!("Read TLS server certificates and private key");
@ -210,6 +108,12 @@ impl Backend {
}
}
/// HashMap and some meta information for multiple Backend structs.
pub struct Backends {
pub apps: HashMap<ServerNameExp, Backend>, // hyper::uriで抜いたhostで引っ掛ける
pub default_server_name: Option<ServerNameExp>, // for plaintext http
}
impl Backends {
pub async fn generate_server_crypto_with_cert_resolver(&self) -> Result<ServerConfig, anyhow::Error> {
let mut resolver = ResolvesServerCertUsingSni::new();

109
src/backend/upstream.rs Normal file
View file

@ -0,0 +1,109 @@
use super::{PathNameExp, UpstreamOption};
use crate::log::*;
use rand::Rng;
use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
use std::{
borrow::Cow,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
#[derive(Debug, Clone)]
pub struct ReverseProxy {
pub upstream: HashMap<PathNameExp, UpstreamGroup>, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。
}
impl ReverseProxy {
pub fn get<'a>(&self, path_str: impl Into<Cow<'a, str>>) -> Option<&UpstreamGroup> {
// trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、
// コスト的にこの程度で十分
let path_lc = path_str.into().to_ascii_lowercase();
let path_bytes = path_lc.as_bytes();
let matched_upstream = self
.upstream
.iter()
.filter(|(route_bytes, _)| {
match path_bytes.starts_with(route_bytes) {
true => {
route_bytes.len() == 1 // route = '/', i.e., default
|| match path_bytes.get(route_bytes.len()) {
None => true, // exact case
Some(p) => p == &b'/', // sub-path case
}
}
_ => false,
}
})
.max_by_key(|(route_bytes, _)| route_bytes.len());
if let Some((_path, u)) = matched_upstream {
debug!(
"Found upstream: {:?}",
String::from_utf8(_path.to_vec()).unwrap_or_else(|_| "<none>".to_string())
);
Some(u)
} else {
None
}
}
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub enum LoadBalance {
RoundRobin,
Random,
}
impl Default for LoadBalance {
fn default() -> Self {
Self::RoundRobin
}
}
#[derive(Debug, Clone)]
pub struct Upstream {
pub uri: hyper::Uri, // base uri without specific path
}
#[derive(Debug, Clone)]
pub struct UpstreamGroup {
pub upstream: Vec<Upstream>,
pub path: PathNameExp,
pub replace_path: Option<PathNameExp>,
pub lb: LoadBalance,
pub cnt: UpstreamCount, // counter for load balancing
pub opts: HashSet<UpstreamOption>,
}
#[derive(Debug, Clone, Default)]
pub struct UpstreamCount(Arc<AtomicUsize>);
impl UpstreamGroup {
pub fn get(&self) -> Option<&Upstream> {
match self.lb {
LoadBalance::RoundRobin => {
let idx = self.increment_cnt();
self.upstream.get(idx)
}
LoadBalance::Random => {
let mut rng = rand::thread_rng();
let max = self.upstream.len() - 1;
self.upstream.get(rng.gen_range(0..max))
}
}
}
fn current_cnt(&self) -> usize {
self.cnt.0.load(Ordering::Relaxed)
}
fn increment_cnt(&self) -> usize {
if self.current_cnt() < self.upstream.len() - 1 {
self.cnt.0.fetch_add(1, Ordering::Relaxed)
} else {
self.cnt.0.fetch_and(0, Ordering::Relaxed)
}
}
}

View file

@ -1,7 +1,6 @@
use super::toml::{ConfigToml, ReverseProxyOption};
use crate::{
backend::{Backend, PathNameLC, ReverseProxy, UpstreamGroup},
backend_opt::UpstreamOption,
backend::{Backend, PathNameExp, ReverseProxy, UpstreamGroup, UpstreamOption},
constants::*,
error::*,
globals::*,
@ -192,7 +191,7 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Erro
}
fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result<ReverseProxy, anyhow::Error> {
let mut upstream: HashMap<PathNameLC, UpstreamGroup> = HashMap::default();
let mut upstream: HashMap<PathNameExp, UpstreamGroup> = HashMap::default();
rp_settings.iter().for_each(|rpo| {
let path = match &rpo.path {
Some(p) => p.as_bytes().to_ascii_lowercase(),

View file

@ -1,7 +1,7 @@
// Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
use super::{utils_headers::*, utils_request::*, utils_response::ResLog, utils_synth_response::*};
use super::{utils_headers::*, utils_request::*, utils_synth_response::*};
use crate::{
backend::{ServerNameLC, UpstreamGroup},
backend::{ServerNameExp, UpstreamGroup},
error::*,
globals::Globals,
log::*,
@ -39,9 +39,8 @@ where
client_addr: SocketAddr, // アクセス制御用
listen_addr: SocketAddr,
tls_enabled: bool,
tls_server_name: Option<ServerNameLC>,
tls_server_name: Option<ServerNameExp>,
) -> Result<Response<Body>> {
req.log_debug(&client_addr, Some("(from Client)"));
////////
let mut log_data = MessageLog::from(&req);
log_data.client_addr(&client_addr);
@ -102,7 +101,6 @@ where
return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
};
// debug!("Request to be forwarded: {:?}", req_forwarded);
req.log_debug(&client_addr, Some("(to Backend)"));
log_data.xff(&req.headers().get("x-forwarded-for"));
log_data.upstream(req.uri());
//////
@ -123,14 +121,9 @@ where
}
};
res_backend.log_debug(&backend.server_name, &client_addr, Some("(from Backend)"));
// let response_log = res_backend.status().to_string();
if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS {
// Generate response to client
if self.generate_response_forwarded(&mut res_backend).is_ok() {
// info!("{} => {}", request_log, response_log);
res_backend.log_debug(&backend.server_name, &client_addr, Some("(to Client)"));
log_data.status_code(&res_backend.status()).output();
return Ok(res_backend);
} else {

View file

@ -1,7 +1,6 @@
mod handler_main;
mod utils_headers;
mod utils_request;
mod utils_response;
mod utils_synth_response;
pub use handler_main::HttpMessageHandler;

View file

@ -1,4 +1,9 @@
use crate::{backend::UpstreamGroup, backend_opt::UpstreamOption, error::*, log::*, utils::*};
use crate::{
backend::{UpstreamGroup, UpstreamOption},
error::*,
log::*,
utils::*,
};
use bytes::BufMut;
use hyper::{
header::{self, HeaderMap, HeaderName, HeaderValue},

View file

@ -1,58 +1,5 @@
use crate::{error::*, log::*, utils::*};
use crate::error::*;
use hyper::{header, Request};
use std::fmt::Display;
////////////////////////////////////////////////////
// Functions of utils for request messages
pub trait ReqLog {
fn log<T: Display + ToCanonical>(self, src: &T, extra: Option<&str>);
fn log_debug<T: Display + ToCanonical>(self, src: &T, extra: Option<&str>);
fn build_message<T: Display + ToCanonical>(self, src: &T, extra: Option<&str>) -> String;
}
impl<B> ReqLog for &Request<B> {
fn log<T: Display + ToCanonical>(self, src: &T, extra: Option<&str>) {
info!("{}", &self.build_message(src, extra));
}
fn log_debug<T: Display + ToCanonical>(self, src: &T, extra: Option<&str>) {
debug!("{}", &self.build_message(src, extra));
}
fn build_message<T: Display + ToCanonical>(self, src: &T, extra: Option<&str>) -> String {
let canonical_src = src.to_canonical();
let host = self
.headers()
.get(header::HOST)
.map_or_else(|| "", |v| v.to_str().unwrap_or(""));
let uri_scheme = self
.uri()
.scheme_str()
.map_or_else(|| "".to_string(), |v| format!("{}://", v));
let uri_host = self.uri().host().unwrap_or("");
let uri_pq = self.uri().path_and_query().map_or_else(|| "", |v| v.as_str());
let ua = self
.headers()
.get(header::USER_AGENT)
.map_or_else(|| "", |v| v.to_str().unwrap_or(""));
let xff = self
.headers()
.get("x-forwarded-for")
.map_or_else(|| "", |v| v.to_str().unwrap_or(""));
format!(
"{} <- {} -- {} {} {:?} -- ({}{}) \"{}\" \"{}\" {}",
host,
canonical_src,
self.method(),
uri_pq,
self.version(),
uri_scheme,
uri_host,
ua,
xff,
extra.unwrap_or("")
)
}
}
pub trait ParseHost {
fn parse_host(&self) -> Result<&[u8]>;

View file

@ -1,41 +0,0 @@
use crate::{log::*, utils::*};
use hyper::Response;
use std::fmt::Display;
////////////////////////////////////////////////////
// Functions of utils for request messages
pub trait ResLog {
fn log<T1: Display, T2: Display + ToCanonical>(self, server_name: &T1, client_addr: &T2, extra: Option<&str>);
fn log_debug<T1: Display, T2: Display + ToCanonical>(self, server_name: &T1, client_addr: &T2, extra: Option<&str>);
fn build_message<T1: Display, T2: Display + ToCanonical>(
self,
server_name: &T1,
client_addr: &T2,
extra: Option<&str>,
) -> String;
}
impl<B> ResLog for &Response<B> {
fn log<T1: Display, T2: Display + ToCanonical>(self, server_name: &T1, client_addr: &T2, extra: Option<&str>) {
info!("{}", &self.build_message(server_name, client_addr, extra));
}
fn log_debug<T1: Display, T2: Display + ToCanonical>(self, server_name: &T1, client_addr: &T2, extra: Option<&str>) {
debug!("{}", &self.build_message(server_name, client_addr, extra));
}
fn build_message<T1: Display, T2: Display + ToCanonical>(
self,
server_name: &T1,
client_addr: &T2,
extra: Option<&str>,
) -> String {
let canonical_client_addr = client_addr.to_canonical();
format!(
"{} <- {} -- {} {:?} {}",
canonical_client_addr,
server_name,
self.status(),
self.version(),
// self.headers(),
extra.map_or_else(|| "", |v| v)
)
}
}

View file

@ -6,7 +6,6 @@ use tikv_jemallocator::Jemalloc;
static GLOBAL: Jemalloc = Jemalloc;
mod backend;
mod backend_opt;
mod config;
mod constants;
mod error;
@ -17,7 +16,7 @@ mod proxy;
mod utils;
use crate::{
backend::{Backend, Backends, ServerNameLC},
backend::{Backend, Backends, ServerNameExp},
config::parse_opts,
constants::*,
error::*,
@ -73,7 +72,7 @@ fn main() {
runtime_handle: runtime.handle().clone(),
backends: Backends {
default_server_name: None,
apps: HashMap::<ServerNameLC, Backend>::default(),
apps: HashMap::<ServerNameExp, Backend>::default(),
},
sni_consistency: true,

View file

@ -1,5 +1,5 @@
use super::Proxy;
use crate::{backend::ServerNameLC, error::*, log::*};
use crate::{backend::ServerNameExp, error::*, log::*};
use bytes::{Buf, Bytes};
use h3::{quic::BidiStream, server::RequestStream};
use hyper::{client::connect::Connect, Body, Request, Response};
@ -10,7 +10,7 @@ impl<T> Proxy<T>
where
T: Connect + Clone + Sync + Send + 'static,
{
pub(super) async fn connection_serve_h3(self, conn: quinn::Connecting, tls_server_name: ServerNameLC) -> Result<()> {
pub(super) async fn connection_serve_h3(self, conn: quinn::Connecting, tls_server_name: ServerNameExp) -> Result<()> {
let client_addr = conn.remote_address();
match conn.await {
@ -68,7 +68,7 @@ where
req: Request<()>,
stream: RequestStream<S, Bytes>,
client_addr: SocketAddr,
tls_server_name: ServerNameLC,
tls_server_name: ServerNameExp,
) -> Result<()>
where
S: BidiStream<Bytes> + Send + 'static,

View file

@ -1,5 +1,5 @@
// use super::proxy_handler::handle_request;
use crate::{backend::ServerNameLC, error::*, globals::Globals, handler::HttpMessageHandler, log::*};
use crate::{backend::ServerNameExp, error::*, globals::Globals, handler::HttpMessageHandler, log::*};
use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request};
use std::{net::SocketAddr, sync::Arc};
use tokio::{
@ -50,7 +50,7 @@ where
stream: I,
server: Http<LocalExecutor>,
peer_addr: SocketAddr,
tls_server_name: Option<ServerNameLC>,
tls_server_name: Option<ServerNameExp>,
) where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{