wip
This commit is contained in:
		
					parent
					
						
							
								c47efbfc93
							
						
					
				
			
			
				commit
				
					
						8a89fcb2c2
					
				
			
		
					 7 changed files with 153 additions and 97 deletions
				
			
		|  | @ -12,7 +12,7 @@ pub struct Backend { | |||
|   pub app_name: String, | ||||
|   pub hostname: String, | ||||
|   pub reverse_proxy: ReverseProxy, | ||||
|   pub redirect_to_https: Option<bool>, | ||||
|   pub https_redirection: Option<bool>, | ||||
|   pub tls_cert_path: Option<PathBuf>, | ||||
|   pub tls_cert_key_path: Option<PathBuf>, | ||||
|   pub server_config: Mutex<Option<ServerConfig>>, | ||||
|  |  | |||
|  | @ -24,7 +24,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend> | |||
|   let mut map_example: HashMap<String, Uri> = HashMap::new(); | ||||
|   map_example.insert( | ||||
|     "/maps".to_string(), | ||||
|     "https://bing.com/".parse::<Uri>().unwrap(), | ||||
|     Uri::builder().authority("www.bing.com").build().unwrap(), | ||||
|   ); | ||||
|   backends.insert( | ||||
|     "localhost".to_string(), | ||||
|  | @ -32,10 +32,10 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend> | |||
|       app_name: "Localhost to Google except for maps".to_string(), | ||||
|       hostname: "localhost".to_string(), | ||||
|       reverse_proxy: ReverseProxy { | ||||
|         default_destination_uri: "https://google.com/".parse::<Uri>().unwrap(), | ||||
|         default_destination_uri: Uri::builder().authority("www.google.com").build().unwrap(), | ||||
|         destination_uris: map_example, | ||||
|       }, | ||||
|       redirect_to_https: Some(true), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず
 | ||||
|       https_redirection: Some(true), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず
 | ||||
| 
 | ||||
|       tls_cert_path: Some(PathBuf::from(r"localhost1.pem")), | ||||
|       tls_cert_key_path: Some(PathBuf::from(r"localhost1.pem")), | ||||
|  |  | |||
|  | @ -14,7 +14,6 @@ use crate::{ | |||
| }; | ||||
| use futures::future::select_all; | ||||
| use hyper::Client; | ||||
| #[cfg(feature = "forward-hyper-trust-dns")] | ||||
| use hyper_trust_dns::TrustDnsResolver; | ||||
| use std::{collections::HashMap, io::Write, sync::Arc}; | ||||
| use tokio::time::Duration; | ||||
|  | @ -40,7 +39,7 @@ fn main() { | |||
| 
 | ||||
|   let mut runtime_builder = tokio::runtime::Builder::new_multi_thread(); | ||||
|   runtime_builder.enable_all(); | ||||
|   runtime_builder.thread_name("rust-rpxy"); | ||||
|   runtime_builder.thread_name("rpxy"); | ||||
|   let runtime = runtime_builder.build().unwrap(); | ||||
| 
 | ||||
|   runtime.block_on(async { | ||||
|  | @ -69,10 +68,7 @@ fn main() { | |||
| 
 | ||||
| // entrypoint creates and spawns tasks of proxy services
 | ||||
| async fn entrypoint(globals: Arc<Globals>, backends: Arc<HashMap<String, Backend>>) -> Result<()> { | ||||
|   #[cfg(feature = "forward-hyper-trust-dns")] | ||||
|   let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector(); | ||||
|   #[cfg(not(feature = "forward-hyper-trust-dns"))] | ||||
|   let connector = hyper_tls::HttpsConnector::new(); | ||||
|   let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector)); | ||||
| 
 | ||||
|   let addresses = globals.listen_sockets.clone(); | ||||
|  |  | |||
|  | @ -1,69 +1,113 @@ | |||
| use crate::{backend::Backend, error::*, globals::Globals, log::*}; | ||||
| use futures::{ | ||||
|   select, | ||||
|   task::{Context, Poll}, | ||||
|   Future, FutureExt, | ||||
| }; | ||||
| use super::Proxy; | ||||
| use crate::{error::*, log::*}; | ||||
| use hyper::{ | ||||
|   client::connect::Connect, | ||||
|   http, | ||||
|   server::conn::Http, | ||||
|   service::{service_fn, Service}, | ||||
|   Body, Client, HeaderMap, Method, Request, Response, StatusCode, Uri, | ||||
| }; | ||||
| use std::{collections::HashMap, net::SocketAddr, pin::Pin, sync::Arc}; | ||||
| use tokio::{ | ||||
|   io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}, | ||||
|   net::TcpListener, | ||||
|   runtime::Handle, | ||||
|   time::Duration, | ||||
|   header::{HeaderMap, HeaderName, HeaderValue}, | ||||
|   Body, Request, Response, StatusCode, Uri, | ||||
| }; | ||||
| use std::net::SocketAddr; | ||||
| 
 | ||||
| // TODO: ここでbackendの名前単位でリクエストを分岐させる
 | ||||
| pub async fn handle_request( | ||||
|   req: Request<Body>, | ||||
| // pub static HEADERS: phf::Map<&'static str, HeaderName> = phf_map! {
 | ||||
| //   "CONNECTION" => HeaderName::from_static("connection"),
 | ||||
| //   "ws" => "wss",
 | ||||
| // };
 | ||||
| 
 | ||||
| impl<T> Proxy<T> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   // TODO: ここでbackendの名前単位でリクエストを分岐させる
 | ||||
|   pub async fn handle_request( | ||||
|     self, | ||||
|     req: Request<Body>, | ||||
|     client_ip: SocketAddr, // アクセス制御用
 | ||||
|   ) -> Result<Response<Body>> { | ||||
|     debug!("Handling request: {:?}", req); | ||||
|     // Here we start to handle with hostname
 | ||||
|     // Find backend application for given hostname
 | ||||
|     let (hostname, _port) = if let Ok(v) = parse_host_port(&req, self.tls_enabled) { | ||||
|       v | ||||
|     } else { | ||||
|       return http_error(StatusCode::SERVICE_UNAVAILABLE); | ||||
|     }; | ||||
|     let backend = if let Some(be) = self.backends.get(hostname.as_str()) { | ||||
|       be | ||||
|     } else { | ||||
|       return http_error(StatusCode::SERVICE_UNAVAILABLE); | ||||
|     }; | ||||
| 
 | ||||
|     // Redirect to https if tls_enabled is false and redirect_to_https is true
 | ||||
|     let path_and_query = req.uri().path_and_query().unwrap().as_str().to_owned(); | ||||
|     if !self.tls_enabled && backend.https_redirection.unwrap_or(false) { | ||||
|       debug!("Redirect to secure connection: {}", hostname); | ||||
|       return secure_redirection(&hostname, self.globals.https_port, &path_and_query); | ||||
|     } | ||||
| 
 | ||||
|     // Find reverse proxy for given path
 | ||||
|     let path = req.uri().path(); | ||||
|     let destination_host_uri = if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) { | ||||
|       uri.to_owned() | ||||
|     } else { | ||||
|       backend.reverse_proxy.default_destination_uri.clone() | ||||
|     }; | ||||
| 
 | ||||
|     // TODO: Upgrade
 | ||||
|     // TODO: X-Forwarded-For
 | ||||
|     // TODO: Transfer Encoding
 | ||||
| 
 | ||||
|     // Build request from destination information
 | ||||
|     let req_forwarded = if let Ok(req) = | ||||
|       generate_request_forwarded(client_ip, req, destination_host_uri, path_and_query) | ||||
|     { | ||||
|       req | ||||
|     } else { | ||||
|       error!("Failed to generate destination uri for reverse proxy"); | ||||
|       return http_error(StatusCode::SERVICE_UNAVAILABLE); | ||||
|     }; | ||||
|     debug!("Request to be forwarded: {:?}", req_forwarded); | ||||
| 
 | ||||
|     // // Forward request to
 | ||||
|     // let res_backend = match self.forwarder.request(req_forwarded).await {
 | ||||
|     //   Ok(res) => res,
 | ||||
|     //   Err(e) => {
 | ||||
|     //     error!("Failed to get response from backend: {}", e);
 | ||||
|     //     return http_error(StatusCode::BAD_REQUEST);
 | ||||
|     //   }
 | ||||
|     // };
 | ||||
|     // debug!("Response from backend: {:?}", res_backend.status());
 | ||||
|     // Ok(res_backend)
 | ||||
| 
 | ||||
|     http_error(StatusCode::NOT_FOUND) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| // Motivated by https://github.com/felipenoris/hyper-reverse-proxy
 | ||||
| fn generate_request_forwarded<B: core::fmt::Debug>( | ||||
|   client_ip: SocketAddr, | ||||
|   tls_enabled: bool, | ||||
|   globals: Arc<Globals>, | ||||
|   backends: Arc<HashMap<String, Backend>>, | ||||
| ) -> Result<Response<Body>> { | ||||
|   debug!("req: {:?}", req); | ||||
|   // Here we start to handle with hostname
 | ||||
|   // Find backend application for given hostname
 | ||||
|   let (hostname, _port) = parse_hostname_port(&req, tls_enabled)?; | ||||
|   let path = req.uri().path(); | ||||
|   let path_and_query = req.uri().path_and_query().unwrap().as_str(); | ||||
|   println!("{:?}", path_and_query); | ||||
|   let backend = if let Some(be) = backends.get(hostname.as_str()) { | ||||
|     be | ||||
|   } else { | ||||
|     return http_error(StatusCode::SERVICE_UNAVAILABLE); | ||||
|   }; | ||||
|   mut req: Request<B>, | ||||
|   destination_host_uri: Uri, | ||||
|   path_and_query: String, | ||||
| ) -> Result<Request<B>> { | ||||
|   debug!("Generate request to be forwarded"); | ||||
| 
 | ||||
|   // Redirect to https if tls_enabled is false and redirect_to_https is true
 | ||||
|   if !tls_enabled && backend.redirect_to_https.unwrap_or(false) { | ||||
|     debug!("Redirect to https: {}", hostname); | ||||
|     return https_redirection(hostname, globals.https_port, path_and_query); | ||||
|   // update "host" key in request header
 | ||||
|   if req.headers().contains_key("host") { | ||||
|     // HTTP/1.1
 | ||||
|     req.headers_mut().insert( | ||||
|       "host", | ||||
|       HeaderValue::from_str(destination_host_uri.host().unwrap()) | ||||
|         .map_err(|_| anyhow!("Failed to insert destination host into forwarded request"))?, | ||||
|     ); | ||||
|   } | ||||
| 
 | ||||
|   // Find reverse proxy for given path
 | ||||
|   let destination_uri = if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) { | ||||
|     uri.to_owned() | ||||
|   } else { | ||||
|     backend.reverse_proxy.default_destination_uri.clone() | ||||
|   }; | ||||
|   // update uri in request
 | ||||
|   *req.uri_mut() = Uri::builder() | ||||
|     .scheme(destination_host_uri.scheme().unwrap().as_str()) | ||||
|     .authority(destination_host_uri.authority().unwrap().as_str()) | ||||
|     .path_and_query(&path_and_query) | ||||
|     .build()?; | ||||
| 
 | ||||
|   debug!("destination_uri: {}", destination_uri); | ||||
|   // if req.version() == hyper::Version::HTTP_11 {
 | ||||
|   //   Ok(Response::new(Body::from("Hello World")))
 | ||||
|   // } else {
 | ||||
|   // Note: it's usually better to return a Response
 | ||||
|   // with an appropriate StatusCode instead of an Err.
 | ||||
|   // Err("not HTTP/1.1, abort connection")
 | ||||
|   // http_error(StatusCode::NOT_FOUND)
 | ||||
|   https_redirection("www.google.com".to_string(), Some(443_u16), "/") | ||||
|   // }
 | ||||
|   // });
 | ||||
|   Ok(req) | ||||
| } | ||||
| 
 | ||||
| fn http_error(status_code: StatusCode) -> Result<Response<Body>> { | ||||
|  | @ -74,19 +118,19 @@ fn http_error(status_code: StatusCode) -> Result<Response<Body>> { | |||
|   Ok(response) | ||||
| } | ||||
| 
 | ||||
| fn https_redirection( | ||||
|   hostname: String, | ||||
|   https_port: Option<u16>, | ||||
| fn secure_redirection( | ||||
|   hostname: &str, | ||||
|   tls_port: Option<u16>, | ||||
|   path_and_query: &str, | ||||
| ) -> Result<Response<Body>> { | ||||
|   let dest_uri: String = if let Some(https_port) = https_port { | ||||
|     if https_port == 443 { | ||||
|   let dest_uri: String = if let Some(tls_port) = tls_port { | ||||
|     if tls_port == 443 { | ||||
|       format!("https://{}{}", hostname, path_and_query) | ||||
|     } else { | ||||
|       format!("https://{}:{}{}", hostname, https_port, path_and_query) | ||||
|       format!("https://{}:{}{}", hostname, tls_port, path_and_query) | ||||
|     } | ||||
|   } else { | ||||
|     return http_error(StatusCode::SERVICE_UNAVAILABLE); | ||||
|     bail!("Internal error! TLS port is not set internally."); | ||||
|   }; | ||||
|   let response = Response::builder() | ||||
|     .status(StatusCode::MOVED_PERMANENTLY) | ||||
|  | @ -96,7 +140,7 @@ fn https_redirection( | |||
|   Ok(response) | ||||
| } | ||||
| 
 | ||||
| fn parse_hostname_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String, u16)> { | ||||
| fn parse_host_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String, u16)> { | ||||
|   let hostname_port_headers = req.headers().get("host"); | ||||
|   let hostname_uri = req.uri().host(); | ||||
|   let port_uri = req.uri().port_u16(); | ||||
|  | @ -107,7 +151,6 @@ fn parse_hostname_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String | |||
| 
 | ||||
|   let (hostname, port) = match (hostname_uri, hostname_port_headers) { | ||||
|     (Some(x), _) => { | ||||
|       let hostname = hostname_uri.unwrap(); | ||||
|       let port = if let Some(p) = port_uri { | ||||
|         p | ||||
|       } else if tls_enabled { | ||||
|  | @ -115,7 +158,7 @@ fn parse_hostname_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String | |||
|       } else { | ||||
|         80 | ||||
|       }; | ||||
|       (hostname.to_string(), port) | ||||
|       (x.to_string(), port) | ||||
|     } | ||||
|     (None, Some(x)) => { | ||||
|       let hp_as_uri = x.to_str().unwrap().parse::<Uri>().unwrap(); | ||||
|  | @ -138,3 +181,29 @@ fn parse_hostname_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String | |||
| 
 | ||||
|   Ok((hostname, port)) | ||||
| } | ||||
| 
 | ||||
| // fn get_upgrade_type(headers: &HeaderMap) -> Option<String> {
 | ||||
| //   #[allow(clippy::blocks_in_if_conditions)]
 | ||||
| //   if headers
 | ||||
| //     .get(&*CONNECTION_HEADER)
 | ||||
| //     .map(|value| {
 | ||||
| //       value
 | ||||
| //         .to_str()
 | ||||
| //         .unwrap()
 | ||||
| //         .split(',')
 | ||||
| //         .any(|e| e.trim() == *UPGRADE_HEADER)
 | ||||
| //     })
 | ||||
| //     .unwrap_or(false)
 | ||||
| //   {
 | ||||
| //     if let Some(upgrade_value) = headers.get(&*UPGRADE_HEADER) {
 | ||||
| //       debug!(
 | ||||
| //         "Found upgrade header with value: {}",
 | ||||
| //         upgrade_value.to_str().unwrap().to_owned()
 | ||||
| //       );
 | ||||
| 
 | ||||
| //       return Some(upgrade_value.to_str().unwrap().to_owned());
 | ||||
| //     }
 | ||||
| //   }
 | ||||
| 
 | ||||
| //   None
 | ||||
| // }
 | ||||
|  |  | |||
|  | @ -1,7 +1,7 @@ | |||
| use super::proxy_handler::handle_request; | ||||
| // use super::proxy_handler::handle_request;
 | ||||
| use crate::{backend::Backend, error::*, globals::Globals, log::*}; | ||||
| use hyper::{ | ||||
|   client::connect::Connect, server::conn::Http, service::service_fn, Body, Client, Method, Request, | ||||
|   client::connect::Connect, server::conn::Http, service::service_fn, Body, Client, Request, | ||||
| }; | ||||
| use std::{collections::HashMap, net::SocketAddr, sync::Arc}; | ||||
| use tokio::{ | ||||
|  | @ -64,15 +64,7 @@ where | |||
|         // server.serve_connection(stream, self),
 | ||||
|         server.serve_connection( | ||||
|           stream, | ||||
|           service_fn(move |req: Request<Body>| { | ||||
|             handle_request( | ||||
|               req, | ||||
|               peer_addr, | ||||
|               self.tls_enabled, | ||||
|               self.globals.clone(), | ||||
|               self.backends.clone(), | ||||
|             ) | ||||
|           }), | ||||
|           service_fn(move |req: Request<Body>| self.clone().handle_request(req, peer_addr)), | ||||
|         ), | ||||
|       ) | ||||
|       .await | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jun Kurihara
				Jun Kurihara