diff --git a/src/proxy/proxy_h3.rs b/src/proxy/proxy_h3.rs index 7df9f34..ba30414 100644 --- a/src/proxy/proxy_h3.rs +++ b/src/proxy/proxy_h3.rs @@ -1,14 +1,9 @@ use super::Proxy; use crate::{error::*, log::*}; -use bytes::{Buf, Bytes, BytesMut}; -use futures::{FutureExt, StreamExt}; +use bytes::{Buf, Bytes}; use h3::{quic::BidiStream, server::RequestStream}; -use hyper::body::HttpBody; -use hyper::http::request; -use hyper::Response; -use hyper::{client::connect::Connect, Body, Request}; -use std::{ascii, str}; -use std::{net::SocketAddr, path::PathBuf, sync::Arc}; +use hyper::{client::connect::Connect, Body, HeaderMap, Request, Response}; +use std::net::SocketAddr; impl Proxy where @@ -19,103 +14,47 @@ where match conn.await { Ok(new_conn) => { - debug!( - "HTTP/3 connection established from {:?} {:?}", - client_addr, - { - let hsd = new_conn - .connection - .handshake_data() - .ok_or_else(|| anyhow!(""))? - .downcast::() - .map_err(|_| anyhow!(""))?; - ( - hsd.protocol.map_or_else( - || "".into(), - |x| String::from_utf8_lossy(&x).into_owned(), - ), - hsd.server_name.map_or_else(|| "".into(), |x| x), - ) - } - ); + info!("QUIC connection established from {:?} {:?}", client_addr, { + let hsd = new_conn + .connection + .handshake_data() + .ok_or_else(|| anyhow!(""))? + .downcast::() + .map_err(|_| anyhow!(""))?; + ( + hsd.protocol.map_or_else( + || "".into(), + |x| String::from_utf8_lossy(&x).into_owned(), + ), + hsd.server_name.map_or_else(|| "".into(), |x| x), + ) + }); let mut h3_conn = h3::server::Connection::<_, bytes::Bytes>::new(h3_quinn::Connection::new(new_conn)) .await?; + info!("HTTP/3 connection established"); - // let self_inner = self.clone(); - while let Some((req, stream)) = h3_conn.accept().await? { - debug!("New request: {:?}", req); + while let Some((req, stream)) = h3_conn + .accept() + .await + .map_err(|e| anyhow!("HTTP/3 accept failed (likely timeout): {}", e))? + { + info!("HTTP/3 new request received"); let self_inner = self.clone(); self.globals.runtime_handle.spawn(async move { - let res = self_inner.handle_request_h3(req, stream, client_addr).await; - // if let Err(e) = handle_request(req, stream).await { - // error!("HTTP/3 request failed: {}", e); - // } - // }); - // tokio::spawn(async { - // if let Err(e) = handle_request(req, stream, root).await { - // error!("request failed: {}", e); - // } + if let Err(e) = self_inner.handle_request_h3(req, stream, client_addr).await { + error!("HTTP/3 request failed: {}", e); + } }); } } Err(err) => { - warn!("HTTP/3 accepting connection failed: {:?}", err); + warn!("QUIC accepting connection failed: {:?}", err); } } - // let quinn::NewConnection { - // connection, - // mut bi_streams, - // .. - // } = conn.await?; - // async { - // debug!( - // "HTTP/3 connection established from {:?} (ALPN {:?}, SNI: {:?})", - // connection.remote_address(), - // connection - // .handshake_data() - // .unwrap() - // .downcast::() - // .unwrap() - // .protocol - // .map_or_else( - // || "".into(), - // |x| String::from_utf8_lossy(&x).into_owned() - // ), - // connection - // .handshake_data() - // .unwrap() - // .downcast::() - // .unwrap() - // .server_name - // .map_or_else(|| "".into(), |x| x) - // ); - // // Each stream initiated by the client constitutes a new request. - // while let Some(stream) = bi_streams.next().await { - // let stream = match stream { - // Err(quinn::ConnectionError::ApplicationClosed { .. }) => { - // debug!("HTTP/3 connection closed"); - // return Ok(()); - // } - // Err(e) => { - // return Err(e); - // } - // Ok(s) => s, - // }; - // let fut = handle_request_h3(stream); - // tokio::spawn(async move { - // if let Err(e) = fut.await { - // error!("failed: {reason}", reason = e.to_string()); - // } - // }); - // } - // Ok(()) - // } - // .await?; - // Ok(()) Ok(()) } @@ -130,25 +69,27 @@ where { let (req_parts, _) = req.into_parts(); - let body = if let Some(request_body) = stream.recv_data().await? { - let chunk = request_body.chunk(); - Body::from(chunk.to_owned()) - } else { + // TODO: h3 -> h2/http1.1などのプロトコル変換がなければ、bodyはBytes単位で直でsend_dataして転送した方がいい。やむなし。 + let mut body_chunk: Vec = Vec::new(); + while let Some(request_body) = stream.recv_data().await? { + body_chunk.extend_from_slice(request_body.chunk()); + } + let body = if body_chunk.is_empty() { Body::default() + } else { + debug!("HTTP/3 request with non-empty body"); + Body::from(body_chunk) + }; + // trailers + let trailers = if let Some(trailers) = stream.recv_trailers().await? { + debug!("HTTP/3 request with trailers"); + trailers + } else { + HeaderMap::new() }; - let mut new_req: Request = Request::from_parts(req_parts, body); - if let Some(request_trailers) = stream.recv_trailers().await? { - let headers = new_req.headers_mut(); - for (ok, v) in request_trailers { - if let Some(k) = ok { - headers.insert(k, v); - } - } - }; - + let new_req: Request = Request::from_parts(req_parts, body); let res = self.handle_request(new_req, client_addr).await?; - println!("{:?}", res); let (new_res_parts, new_body) = res.into_parts(); let new_res = Response::from_parts(new_res_parts, ()); @@ -158,57 +99,12 @@ where debug!("HTTP/3 response to connection successful"); let data = hyper::body::to_bytes(new_body).await?; stream.send_data(data).await?; + stream.send_trailers(trailers).await?; } Err(err) => { error!("Unable to send response to connection peer: {:?}", err); } } - Ok(()) + Ok(stream.finish().await?) } } - -// TODO: -// async fn handle_request_h3((mut send, recv): (quinn::SendStream, quinn::RecvStream)) -> Result<()> { -// let req = recv -// .read_to_end(64 * 1024) -// .await -// .map_err(|e| anyhow!("failed reading request: {}", e))?; - -// // let hyper_req = hyper::Request::try_from(req.clone()); - -// let mut escaped = String::new(); -// for &x in &req[..] { -// let part = ascii::escape_default(x).collect::>(); -// escaped.push_str(str::from_utf8(&part).unwrap()); -// } -// info!("content = {:?}", escaped); -// // Execute the request -// let resp = process_get(&req).unwrap_or_else(|e| { -// error!("failed: {}", e); -// format!("failed to process request: {}\n", e).into_bytes() -// }); -// // Write the response -// send -// .write_all(&resp) -// .await -// .map_err(|e| anyhow!("failed to send response: {}", e))?; -// // Gracefully terminate the stream -// send -// .finish() -// .await -// .map_err(|e| anyhow!("failed to shutdown stream: {}", e))?; -// info!("complete"); -// Ok(()) -// } - -// fn process_get(x: &[u8]) -> Result> { -// if x.len() < 4 || &x[0..4] != b"GET " { -// bail!("missing GET"); -// } -// if x[4..].len() < 2 || &x[x.len() - 2..] != b"\r\n" { -// bail!("missing \\r\\n"); -// } - -// let data = b"hello world!".to_vec(); -// Ok(data) -// } diff --git a/src/proxy/proxy_handler.rs b/src/proxy/proxy_handler.rs index 08be2c6..3b947fa 100644 --- a/src/proxy/proxy_handler.rs +++ b/src/proxy/proxy_handler.rs @@ -98,6 +98,14 @@ where return http_error(StatusCode::BAD_REQUEST); } }; + #[cfg(feature = "h3")] + { + if let Some(port) = self.globals.https_port { + res_backend + .headers_mut() + .insert("alt-svc", format!("h3=\":{}\"", port).parse().unwrap()); + } + } debug!("Response from backend: {:?}", res_backend.status()); if res_backend.status() == StatusCode::SWITCHING_PROTOCOLS { @@ -156,12 +164,12 @@ fn generate_request_forwarded( debug!("Generate request to be forwarded"); // Add te: trailer if contained in original request - let te_trailer = { + let te_trailers = { if let Some(te) = req.headers().get("te") { te.to_str() .unwrap() .split(',') - .any(|x| x.trim() == "trailer") + .any(|x| x.trim() == "trailers") } else { false } @@ -175,7 +183,7 @@ fn generate_request_forwarded( // X-Forwarded-For add_forwarding_header(headers, client_addr)?; // Add te: trailer if te_trailer - if te_trailer { + if te_trailers { headers.insert("te", "trailer".parse().unwrap()); } diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs index 48b0b84..c1abba1 100644 --- a/src/proxy/proxy_tls.rs +++ b/src/proxy/proxy_tls.rs @@ -4,6 +4,7 @@ use crate::{constants::CERTS_WATCH_DELAY_SECS, error::*, log::*}; use futures::StreamExt; use futures::{future::FutureExt, join, select}; use hyper::{client::connect::Connect, server::conn::Http}; +use rustls::ServerConfig; use std::{sync::Arc, time::Duration}; use tokio::net::TcpListener; @@ -58,25 +59,13 @@ where info!("No SNI in ClientHello"); continue; }; - let backend_serve = if let Some(backend_serve) = self.backends.apps.get(svn){ - backend_serve - } else { - info!("No configuration for the server name {} given in client_hello", svn); - continue; - }; - - if backend_serve.tls_cert_path.is_none() { // at least cert does exit - debug!("SNI indicates a site that doesn't support TLS."); - continue; - } - let server_config = if let Some(p) = backend_serve.get_tls_server_config(){ + let server_crypto = if let Some(p) = self.fetch_server_crypto(svn) { p } else { - error!("Failed to load server config"); continue; }; // Finally serve the TLS connection - if let Ok(stream) = start.into_stream(Arc::new(server_config)).await { + if let Ok(stream) = start.into_stream(Arc::new(server_crypto)).await { self.clone().client_serve(stream, server.clone(), _client_addr).await } } @@ -86,11 +75,23 @@ where Ok(()) as Result<()> }; - /////////////////////// TODO:!!!!! + /////////////////////// #[cfg(feature = "h3")] let listener_service_h3 = async { - // TODO: とりあえずデフォルトのserver_cryptoが必要になりそう - let backend_serve = self.backends.apps.get("localhost").unwrap(); + // TODO: Work around to initially serve incoming connection + let tls_app_names: Vec = self + .backends + .apps + .iter() + .filter(|&(_, backend)| { + backend.tls_cert_key_path.is_some() && backend.tls_cert_path.is_some() + }) + .map(|(name, _)| name.to_string()) + .collect(); + ensure!(!tls_app_names.is_empty(), "No TLS supported app"); + let initial_app_name = tls_app_names.get(0).unwrap().as_str(); + info!("Initial app_name: {}", initial_app_name); + let backend_serve = self.backends.apps.get(initial_app_name).unwrap(); let server_crypto = backend_serve.get_tls_server_config().unwrap(); let server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); @@ -98,23 +99,39 @@ where quinn::Endpoint::server(server_config_h3, self.listening_on).unwrap(); debug!("HTTP/3 UDP listening on {}", endpoint.local_addr().unwrap()); + // let peekable_incoming = incoming.peekable(); + while let Some(mut conn) = incoming.next().await { - debug!("HTTP/3 connection incoming"); let hsd = conn.handshake_data().await; - let hsd_downcast = hsd - .unwrap() + let hsd_downcast = hsd? .downcast::() .unwrap(); - debug!("HTTP/3 SNI: {:?}", hsd_downcast.server_name); - // TODO: ServerConfig::set_server_configでSNIに応じて再セット + let svn = if let Some(sni) = hsd_downcast.server_name { + info!("HTTP/3 connection incoming (SNI {:?})", sni); + sni + } else { + debug!("HTTP/3 no SNI is given"); + continue; + }; + + let new_server_crypto = if let Some(p) = self.fetch_server_crypto(&svn) { + p + } else { + continue; + }; + // Set ServerConfig::set_server_config for given SNI + let new_server_config_h3 = quinn::ServerConfig::with_crypto(Arc::new(new_server_crypto)); + endpoint.set_server_config(Some(new_server_config_h3)); let fut = self.clone().client_serve_h3(conn); self.globals.runtime_handle.spawn(async { if let Err(e) = fut.await { - error!("connection failed: {reason}", reason = e.to_string()) + warn!("QUIC or HTTP/3 connection failed: {}", e) } }); } + endpoint.wait_idle().await; + Ok(()) as Result<()> }; #[cfg(not(feature = "h3"))] @@ -126,4 +143,28 @@ where join!(listener_service, cert_service, listener_service_h3).0 } } + + fn fetch_server_crypto(&self, server_name: &str) -> Option { + let backend_serve = if let Some(backend_serve) = self.backends.apps.get(server_name) { + backend_serve + } else { + warn!( + "No configuration for the server name {} given in client_hello", + server_name + ); + return None; + }; + + if backend_serve.tls_cert_path.is_none() { + // at least cert does exit + warn!("SNI indicates a site that doesn't support TLS."); + return None; + } + if let Some(p) = backend_serve.get_tls_server_config() { + Some(p) + } else { + error!("Failed to load server config"); + None + } + } }