From 634d556ea9887b4a55c74b9731acd5af5dfb5f10 Mon Sep 17 00:00:00 2001 From: Jun Kurihara Date: Fri, 17 Jun 2022 19:01:18 -0400 Subject: [PATCH] simplified --- src/acceptor.rs | 269 +++++++++++++++++++++++++++--------------------- src/proxy.rs | 3 +- src/tls.rs | 5 +- 3 files changed, 156 insertions(+), 121 deletions(-) diff --git a/src/acceptor.rs b/src/acceptor.rs index 3bf0c09..f79ad86 100644 --- a/src/acceptor.rs +++ b/src/acceptor.rs @@ -4,8 +4,11 @@ use futures::{ Future, }; use hyper::{ - client::connect::Connect, http, server::conn::Http, Body, Client, HeaderMap, Method, Request, - Response, StatusCode, + client::connect::Connect, + http, + server::conn::Http, + service::{service_fn, Service}, + Body, Client, HeaderMap, Method, Request, Response, StatusCode, }; use std::{net::SocketAddr, pin::Pin, sync::Arc}; use tokio::{ @@ -48,135 +51,157 @@ where #[derive(Clone)] pub struct PacketAcceptor where - T: hyper::client::connect::Connect + Send + Sync + Clone + 'static, + T: Connect + Clone + Sync + Send + 'static, { pub listening_on: SocketAddr, - pub forwarder: Client, + pub forwarder: Arc>, pub globals: Arc, } -#[allow(clippy::type_complexity)] -impl hyper::service::Service> for PacketAcceptor -where - T: Connect + Clone + Send + Sync + 'static, -{ - type Response = Response; +// impl Service> for PacketAcceptor +// where +// T: Connect + Clone + Sync + Send + 'static, +// { +// type Response = Response; - type Error = http::Error; - type Future = Pin> + Send>>; +// type Error = http::Error; +// type Future = Pin> + Send>>; - fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } +// fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll> { +// Poll::Ready(Ok(())) +// } - fn call(&mut self, req: Request) -> Self::Future { - debug!("\nserve: {:?}\n{:?}", self.listening_on, req); - let self_inner = self.clone(); +// fn call(&mut self, req: Request) -> Self::Future { +// debug!( +// "serving {:?} {:?} request to {:?}", +// req.version(), +// req.method(), +// req.uri() +// ); +// let self_inner = self.clone(); - // 1. check uri (domain queried host name) - // 2. build uri to forwarding target destination - // 3. build request from uri and body - // 4. send request to forwarding target +// // 1. check uri (domain queried host name) +// // 2. build uri to forwarding target destination +// // 3. build request from uri and body +// // 4. send request to forwarding target - if *req.method() == Method::GET { - Box::pin(async move { - // let uri = req.uri(); - let target_uri = hyper::Uri::builder() - .scheme("https") - .authority("www.google.com") - .path_and_query("/") - .build() - .unwrap(); - println!("{:?}", target_uri); - match self_inner.forwarder.get(target_uri).await { - Ok(res) => Ok(res), - Err(e) => { - error!("{:?}", e); - http_error(StatusCode::INTERNAL_SERVER_ERROR) - } - } - }) - } else { - // let globals = &self.doh.globals; - // let self_inner = self.clone(); - // if req.uri().path() == globals.path { - // Box::pin(async move { - // let mut subscriber = None; - // if self_inner.doh.globals.enable_auth_target { - // subscriber = match auth::authenticate( - // &self_inner.doh.globals, - // &req, - // ValidationLocation::Target, - // &self_inner.peer_addr, - // ) { - // Ok((sub, aud)) => { - // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud); - // sub - // } - // Err(e) => { - // error!("{:?}", e); - // return Ok(e); - // } - // }; - // } - // match *req.method() { - // Method::POST => self_inner.doh.serve_post(req, subscriber).await, - // Method::GET => self_inner.doh.serve_get(req, subscriber).await, - // _ => http_error(StatusCode::METHOD_NOT_ALLOWED), - // } - // }) - // } else if req.uri().path() == globals.odoh_configs_path { - // match *req.method() { - // Method::GET => Box::pin(async move { self_inner.doh.serve_odoh_configs().await }), - // _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), - // } - // } else { - // #[cfg(not(feature = "odoh-proxy"))] - // { - // Box::pin(async { http_error(StatusCode::NOT_FOUND) }) - // } - // #[cfg(feature = "odoh-proxy")] - // { - // if req.uri().path() == globals.odoh_proxy_path { - // Box::pin(async move { - // let mut subscriber = None; - // if self_inner.doh.globals.enable_auth_proxy { - // subscriber = match auth::authenticate( - // &self_inner.doh.globals, - // &req, - // ValidationLocation::Proxy, - // &self_inner.peer_addr, - // ) { - // Ok((sub, aud)) => { - // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud); - // sub - // } - // Err(e) => { - // error!("{:?}", e); - // return Ok(e); - // } - // }; - // } - // // Draft: https://datatracker.ietf.org/doc/html/draft-pauly-dprive-oblivious-doh-11 - // // Golang impl.: https://github.com/cloudflare/odoh-server-go - // // Based on the draft and Golang implementation, only post method is allowed. - // match *req.method() { - // Method::POST => self_inner.doh.serve_odoh_proxy_post(req, subscriber).await, - // _ => http_error(StatusCode::METHOD_NOT_ALLOWED), - // } - // }) - // } else { - Box::pin(async { http_error(StatusCode::NOT_FOUND) }) - } - // } - // } - // } - } +// if *req.method() == Method::GET { +// Box::pin(async move { +// // let uri = req.uri(); +// let target_uri = hyper::Uri::builder() +// .scheme("https") +// .authority("www.google.com") +// .path_and_query("/") +// .build() +// .unwrap(); +// println!("{:?}", target_uri); +// match self_inner.forwarder.get(target_uri).await { +// Ok(res) => Ok(res), +// Err(e) => { +// error!("{:?}", e); +// http_error(StatusCode::INTERNAL_SERVER_ERROR) +// } +// } +// }) +// } else { +// // let globals = &self.doh.globals; +// // let self_inner = self.clone(); +// // if req.uri().path() == globals.path { +// // Box::pin(async move { +// // let mut subscriber = None; +// // if self_inner.doh.globals.enable_auth_target { +// // subscriber = match auth::authenticate( +// // &self_inner.doh.globals, +// // &req, +// // ValidationLocation::Target, +// // &self_inner.peer_addr, +// // ) { +// // Ok((sub, aud)) => { +// // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud); +// // sub +// // } +// // Err(e) => { +// // error!("{:?}", e); +// // return Ok(e); +// // } +// // }; +// // } +// // match *req.method() { +// // Method::POST => self_inner.doh.serve_post(req, subscriber).await, +// // Method::GET => self_inner.doh.serve_get(req, subscriber).await, +// // _ => http_error(StatusCode::METHOD_NOT_ALLOWED), +// // } +// // }) +// // } else if req.uri().path() == globals.odoh_configs_path { +// // match *req.method() { +// // Method::GET => Box::pin(async move { self_inner.doh.serve_odoh_configs().await }), +// // _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), +// // } +// // } else { +// // #[cfg(not(feature = "odoh-proxy"))] +// // { +// // Box::pin(async { http_error(StatusCode::NOT_FOUND) }) +// // } +// // #[cfg(feature = "odoh-proxy")] +// // { +// // if req.uri().path() == globals.odoh_proxy_path { +// // Box::pin(async move { +// // let mut subscriber = None; +// // if self_inner.doh.globals.enable_auth_proxy { +// // subscriber = match auth::authenticate( +// // &self_inner.doh.globals, +// // &req, +// // ValidationLocation::Proxy, +// // &self_inner.peer_addr, +// // ) { +// // Ok((sub, aud)) => { +// // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud); +// // sub +// // } +// // Err(e) => { +// // error!("{:?}", e); +// // return Ok(e); +// // } +// // }; +// // } +// // // Draft: https://datatracker.ietf.org/doc/html/draft-pauly-dprive-oblivious-doh-11 +// // // Golang impl.: https://github.com/cloudflare/odoh-server-go +// // // Based on the draft and Golang implementation, only post method is allowed. +// // match *req.method() { +// // Method::POST => self_inner.doh.serve_odoh_proxy_post(req, subscriber).await, +// // _ => http_error(StatusCode::METHOD_NOT_ALLOWED), +// // } +// // }) +// // } else { +// Box::pin(async { http_error(StatusCode::NOT_FOUND) }) +// } +// // } +// // } +// // } +// } +// } + +async fn handle_request( + req: Request, + client_ip: SocketAddr, + globals: Arc, +) -> Result, http::Error> { + // http_error(StatusCode::NOT_FOUND) + debug!("{:?}", req); + // if req.version() == hyper::Version::HTTP_11 { + // Ok(Response::new(Body::from("Hello World"))) + // } else { + // Note: it's usually better to return a Response + // with an appropriate StatusCode instead of an Err. + // Err("not HTTP/1.1, abort connection") + http_error(StatusCode::NOT_FOUND) + // } + // }); } impl PacketAcceptor where - T: Connect + Clone + Send + Sync + 'static, + T: Connect + Clone + Sync + Send + 'static, { pub async fn client_serve(self, stream: I, server: Http, peer_addr: SocketAddr) where @@ -187,13 +212,21 @@ where clients_count.decrement(); return; } + self.globals.runtime_handle.clone().spawn(async move { tokio::time::timeout( self.globals.timeout + Duration::from_secs(1), - server.serve_connection(stream, self), + // server.serve_connection(stream, self), + server.serve_connection( + stream, + service_fn(move |req: Request| { + handle_request(req, peer_addr, self.globals.clone()) + }), + ), ) .await .ok(); + clients_count.decrement(); }); } diff --git a/src/proxy.rs b/src/proxy.rs index e0943fd..a4b1cfc 100644 --- a/src/proxy.rs +++ b/src/proxy.rs @@ -19,11 +19,12 @@ impl Proxy { let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector(); #[cfg(not(feature = "forward-hyper-trust-dns"))] let connector = hyper_tls::HttpsConnector::new(); + let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector)); let acceptor = PacketAcceptor { listening_on: addr, globals: self.globals.clone(), - forwarder: Client::builder().build::<_, hyper::Body>(connector), + forwarder, }; self.globals.runtime_handle.spawn(acceptor.start()) })); diff --git a/src/tls.rs b/src/tls.rs index b503bfb..7e642c7 100644 --- a/src/tls.rs +++ b/src/tls.rs @@ -5,7 +5,8 @@ use std::sync::Arc; use std::time::Duration; use futures::{future::FutureExt, join, select}; -use hyper::{client::connect::Connect, server::conn::Http}; +use hyper::client::connect::Connect; +use hyper::server::conn::Http; use tokio::{ net::TcpListener, sync::mpsc::{self, Receiver}, @@ -111,7 +112,7 @@ where impl PacketAcceptor where - T: Connect + Clone + Send + Sync + 'static, + T: Connect + Clone + Sync + Send + 'static, { async fn start_https_service( self,