simplified

This commit is contained in:
Jun Kurihara 2022-06-17 19:01:18 -04:00
commit 634d556ea9
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
3 changed files with 156 additions and 121 deletions

View file

@ -4,8 +4,11 @@ use futures::{
Future, Future,
}; };
use hyper::{ use hyper::{
client::connect::Connect, http, server::conn::Http, Body, Client, HeaderMap, Method, Request, client::connect::Connect,
Response, StatusCode, http,
server::conn::Http,
service::{service_fn, Service},
Body, Client, HeaderMap, Method, Request, Response, StatusCode,
}; };
use std::{net::SocketAddr, pin::Pin, sync::Arc}; use std::{net::SocketAddr, pin::Pin, sync::Arc};
use tokio::{ use tokio::{
@ -48,135 +51,157 @@ where
#[derive(Clone)] #[derive(Clone)]
pub struct PacketAcceptor<T> pub struct PacketAcceptor<T>
where where
T: hyper::client::connect::Connect + Send + Sync + Clone + 'static, T: Connect + Clone + Sync + Send + 'static,
{ {
pub listening_on: SocketAddr, pub listening_on: SocketAddr,
pub forwarder: Client<T>, pub forwarder: Arc<Client<T>>,
pub globals: Arc<Globals>, pub globals: Arc<Globals>,
} }
#[allow(clippy::type_complexity)] // impl<T> Service<http::Request<Body>> for PacketAcceptor<T>
impl<T> hyper::service::Service<http::Request<Body>> for PacketAcceptor<T> // where
where // T: Connect + Clone + Sync + Send + 'static,
T: Connect + Clone + Send + Sync + 'static, // {
{ // type Response = Response<Body>;
type Response = Response<Body>;
type Error = http::Error; // type Error = http::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; // type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { // fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(())) // Poll::Ready(Ok(()))
} // }
fn call(&mut self, req: Request<Body>) -> Self::Future { // fn call(&mut self, req: Request<Body>) -> Self::Future {
debug!("\nserve: {:?}\n{:?}", self.listening_on, req); // debug!(
let self_inner = self.clone(); // "serving {:?} {:?} request to {:?}",
// req.version(),
// req.method(),
// req.uri()
// );
// let self_inner = self.clone();
// 1. check uri (domain queried host name) // // 1. check uri (domain queried host name)
// 2. build uri to forwarding target destination // // 2. build uri to forwarding target destination
// 3. build request from uri and body // // 3. build request from uri and body
// 4. send request to forwarding target // // 4. send request to forwarding target
if *req.method() == Method::GET { // if *req.method() == Method::GET {
Box::pin(async move { // Box::pin(async move {
// let uri = req.uri(); // // let uri = req.uri();
let target_uri = hyper::Uri::builder() // let target_uri = hyper::Uri::builder()
.scheme("https") // .scheme("https")
.authority("www.google.com") // .authority("www.google.com")
.path_and_query("/") // .path_and_query("/")
.build() // .build()
.unwrap(); // .unwrap();
println!("{:?}", target_uri); // println!("{:?}", target_uri);
match self_inner.forwarder.get(target_uri).await { // match self_inner.forwarder.get(target_uri).await {
Ok(res) => Ok(res), // Ok(res) => Ok(res),
Err(e) => { // Err(e) => {
error!("{:?}", e); // error!("{:?}", e);
http_error(StatusCode::INTERNAL_SERVER_ERROR) // http_error(StatusCode::INTERNAL_SERVER_ERROR)
} // }
} // }
}) // })
} else { // } else {
// let globals = &self.doh.globals; // // let globals = &self.doh.globals;
// let self_inner = self.clone(); // // let self_inner = self.clone();
// if req.uri().path() == globals.path { // // if req.uri().path() == globals.path {
// Box::pin(async move { // // Box::pin(async move {
// let mut subscriber = None; // // let mut subscriber = None;
// if self_inner.doh.globals.enable_auth_target { // // if self_inner.doh.globals.enable_auth_target {
// subscriber = match auth::authenticate( // // subscriber = match auth::authenticate(
// &self_inner.doh.globals, // // &self_inner.doh.globals,
// &req, // // &req,
// ValidationLocation::Target, // // ValidationLocation::Target,
// &self_inner.peer_addr, // // &self_inner.peer_addr,
// ) { // // ) {
// Ok((sub, aud)) => { // // Ok((sub, aud)) => {
// debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud); // // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud);
// sub // // sub
// } // // }
// Err(e) => { // // Err(e) => {
// error!("{:?}", e); // // error!("{:?}", e);
// return Ok(e); // // return Ok(e);
// } // // }
// }; // // };
// } // // }
// match *req.method() { // // match *req.method() {
// Method::POST => self_inner.doh.serve_post(req, subscriber).await, // // Method::POST => self_inner.doh.serve_post(req, subscriber).await,
// Method::GET => self_inner.doh.serve_get(req, subscriber).await, // // Method::GET => self_inner.doh.serve_get(req, subscriber).await,
// _ => http_error(StatusCode::METHOD_NOT_ALLOWED), // // _ => http_error(StatusCode::METHOD_NOT_ALLOWED),
// } // // }
// }) // // })
// } else if req.uri().path() == globals.odoh_configs_path { // // } else if req.uri().path() == globals.odoh_configs_path {
// match *req.method() { // // match *req.method() {
// Method::GET => Box::pin(async move { self_inner.doh.serve_odoh_configs().await }), // // Method::GET => Box::pin(async move { self_inner.doh.serve_odoh_configs().await }),
// _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }), // // _ => Box::pin(async { http_error(StatusCode::METHOD_NOT_ALLOWED) }),
// } // // }
// } else { // // } else {
// #[cfg(not(feature = "odoh-proxy"))] // // #[cfg(not(feature = "odoh-proxy"))]
// { // // {
// Box::pin(async { http_error(StatusCode::NOT_FOUND) }) // // Box::pin(async { http_error(StatusCode::NOT_FOUND) })
// } // // }
// #[cfg(feature = "odoh-proxy")] // // #[cfg(feature = "odoh-proxy")]
// { // // {
// if req.uri().path() == globals.odoh_proxy_path { // // if req.uri().path() == globals.odoh_proxy_path {
// Box::pin(async move { // // Box::pin(async move {
// let mut subscriber = None; // // let mut subscriber = None;
// if self_inner.doh.globals.enable_auth_proxy { // // if self_inner.doh.globals.enable_auth_proxy {
// subscriber = match auth::authenticate( // // subscriber = match auth::authenticate(
// &self_inner.doh.globals, // // &self_inner.doh.globals,
// &req, // // &req,
// ValidationLocation::Proxy, // // ValidationLocation::Proxy,
// &self_inner.peer_addr, // // &self_inner.peer_addr,
// ) { // // ) {
// Ok((sub, aud)) => { // // Ok((sub, aud)) => {
// debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud); // // debug!("Valid token or allowed ip: sub={:?}, aud={:?}", &sub, &aud);
// sub // // sub
// } // // }
// Err(e) => { // // Err(e) => {
// error!("{:?}", e); // // error!("{:?}", e);
// return Ok(e); // // return Ok(e);
// } // // }
// }; // // };
// } // // }
// // Draft: https://datatracker.ietf.org/doc/html/draft-pauly-dprive-oblivious-doh-11 // // // Draft: https://datatracker.ietf.org/doc/html/draft-pauly-dprive-oblivious-doh-11
// // Golang impl.: https://github.com/cloudflare/odoh-server-go // // // Golang impl.: https://github.com/cloudflare/odoh-server-go
// // Based on the draft and Golang implementation, only post method is allowed. // // // Based on the draft and Golang implementation, only post method is allowed.
// match *req.method() { // // match *req.method() {
// Method::POST => self_inner.doh.serve_odoh_proxy_post(req, subscriber).await, // // Method::POST => self_inner.doh.serve_odoh_proxy_post(req, subscriber).await,
// _ => http_error(StatusCode::METHOD_NOT_ALLOWED), // // _ => http_error(StatusCode::METHOD_NOT_ALLOWED),
// } // // }
// }) // // })
// } else { // // } else {
Box::pin(async { http_error(StatusCode::NOT_FOUND) }) // Box::pin(async { http_error(StatusCode::NOT_FOUND) })
} // }
// } // // }
// } // // }
// } // // }
} // }
// }
async fn handle_request(
req: Request<Body>,
client_ip: SocketAddr,
globals: Arc<Globals>,
) -> Result<Response<Body>, http::Error> {
// http_error(StatusCode::NOT_FOUND)
debug!("{:?}", req);
// if req.version() == hyper::Version::HTTP_11 {
// Ok(Response::new(Body::from("Hello World")))
// } else {
// Note: it's usually better to return a Response
// with an appropriate StatusCode instead of an Err.
// Err("not HTTP/1.1, abort connection")
http_error(StatusCode::NOT_FOUND)
// }
// });
} }
impl<T> PacketAcceptor<T> impl<T> PacketAcceptor<T>
where where
T: Connect + Clone + Send + Sync + 'static, T: Connect + Clone + Sync + Send + 'static,
{ {
pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr) pub async fn client_serve<I>(self, stream: I, server: Http<LocalExecutor>, peer_addr: SocketAddr)
where where
@ -187,13 +212,21 @@ where
clients_count.decrement(); clients_count.decrement();
return; return;
} }
self.globals.runtime_handle.clone().spawn(async move { self.globals.runtime_handle.clone().spawn(async move {
tokio::time::timeout( tokio::time::timeout(
self.globals.timeout + Duration::from_secs(1), self.globals.timeout + Duration::from_secs(1),
server.serve_connection(stream, self), // server.serve_connection(stream, self),
server.serve_connection(
stream,
service_fn(move |req: Request<Body>| {
handle_request(req, peer_addr, self.globals.clone())
}),
),
) )
.await .await
.ok(); .ok();
clients_count.decrement(); clients_count.decrement();
}); });
} }

View file

@ -19,11 +19,12 @@ impl Proxy {
let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector(); let connector = TrustDnsResolver::default().into_rustls_webpki_https_connector();
#[cfg(not(feature = "forward-hyper-trust-dns"))] #[cfg(not(feature = "forward-hyper-trust-dns"))]
let connector = hyper_tls::HttpsConnector::new(); let connector = hyper_tls::HttpsConnector::new();
let forwarder = Arc::new(Client::builder().build::<_, hyper::Body>(connector));
let acceptor = PacketAcceptor { let acceptor = PacketAcceptor {
listening_on: addr, listening_on: addr,
globals: self.globals.clone(), globals: self.globals.clone(),
forwarder: Client::builder().build::<_, hyper::Body>(connector), forwarder,
}; };
self.globals.runtime_handle.spawn(acceptor.start()) self.globals.runtime_handle.spawn(acceptor.start())
})); }));

View file

@ -5,7 +5,8 @@ use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use futures::{future::FutureExt, join, select}; use futures::{future::FutureExt, join, select};
use hyper::{client::connect::Connect, server::conn::Http}; use hyper::client::connect::Connect;
use hyper::server::conn::Http;
use tokio::{ use tokio::{
net::TcpListener, net::TcpListener,
sync::mpsc::{self, Receiver}, sync::mpsc::{self, Receiver},
@ -111,7 +112,7 @@ where
impl<T> PacketAcceptor<T> impl<T> PacketAcceptor<T>
where where
T: Connect + Clone + Send + Sync + 'static, T: Connect + Clone + Sync + Send + 'static,
{ {
async fn start_https_service( async fn start_https_service(
self, self,