fix redirection

This commit is contained in:
Jun Kurihara 2022-06-23 23:10:19 -04:00
commit c47efbfc93
No known key found for this signature in database
GPG key ID: 48ADFD173ED22B03
6 changed files with 72 additions and 49 deletions

View file

@ -21,7 +21,7 @@ pub struct Backend {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct ReverseProxy { pub struct ReverseProxy {
pub default_destination_uri: hyper::Uri, pub default_destination_uri: hyper::Uri,
pub destination_uris: Option<HashMap<String, hyper::Uri>>, // TODO: url pathで引っ掛ける。 pub destination_uris: HashMap<String, hyper::Uri>, // TODO: url pathで引っ掛ける。
} }
impl Backend { impl Backend {

View file

@ -33,7 +33,7 @@ pub fn parse_opts(globals: &mut Globals, backends: &mut HashMap<String, Backend>
hostname: "localhost".to_string(), hostname: "localhost".to_string(),
reverse_proxy: ReverseProxy { reverse_proxy: ReverseProxy {
default_destination_uri: "https://google.com/".parse::<Uri>().unwrap(), default_destination_uri: "https://google.com/".parse::<Uri>().unwrap(),
destination_uris: Some(map_example), destination_uris: map_example,
}, },
redirect_to_https: Some(true), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず redirect_to_https: Some(true), // TODO: ここはtlsが存在する時はSomeにすべき。Noneはtlsがないときのみのはず

View file

@ -1,6 +1,6 @@
pub const LISTEN_ADDRESSES: &[&str] = &["0.0.0.0", "[::]"]; pub const LISTEN_ADDRESSES: &[&str] = &["0.0.0.0", "[::]"];
pub const HTTP_LISTEN_PORT: u32 = 8080; pub const HTTP_LISTEN_PORT: u16 = 8080;
pub const HTTPS_LISTEN_PORT: u32 = 8443; pub const HTTPS_LISTEN_PORT: u16 = 8443;
pub const TIMEOUT_SEC: u64 = 10; pub const TIMEOUT_SEC: u64 = 10;
pub const MAX_CLIENTS: usize = 512; pub const MAX_CLIENTS: usize = 512;
pub const MAX_CONCURRENT_STREAMS: u32 = 16; pub const MAX_CONCURRENT_STREAMS: u32 = 16;

View file

@ -8,8 +8,8 @@ use tokio::time::Duration;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Globals { pub struct Globals {
pub listen_sockets: Vec<SocketAddr>, pub listen_sockets: Vec<SocketAddr>,
pub http_port: Option<u32>, pub http_port: Option<u16>,
pub https_port: Option<u32>, pub https_port: Option<u16>,
pub timeout: Duration, pub timeout: Duration,
pub max_clients: usize, pub max_clients: usize,

View file

@ -79,7 +79,7 @@ async fn entrypoint(globals: Arc<Globals>, backends: Arc<HashMap<String, Backend
let futures = select_all(addresses.into_iter().map(|addr| { let futures = select_all(addresses.into_iter().map(|addr| {
let mut tls_enabled = false; let mut tls_enabled = false;
if let Some(https_port) = globals.https_port { if let Some(https_port) = globals.https_port {
tls_enabled = https_port == (addr.port() as u32) tls_enabled = https_port == (addr.port() as u16)
} }
info!("Listen address: {:?} (TLS = {})", addr, tls_enabled); info!("Listen address: {:?} (TLS = {})", addr, tls_enabled);

View file

@ -27,12 +27,13 @@ pub async fn handle_request(
globals: Arc<Globals>, globals: Arc<Globals>,
backends: Arc<HashMap<String, Backend>>, backends: Arc<HashMap<String, Backend>>,
) -> Result<Response<Body>> { ) -> Result<Response<Body>> {
let headers = req.headers(); debug!("req: {:?}", req);
// Here we start to handle with hostname // Here we start to handle with hostname
// Find backend application for given hostname // Find backend application for given hostname
let (hostname, port) = parse_hostname_port(headers, tls_enabled)?; let (hostname, _port) = parse_hostname_port(&req, tls_enabled)?;
let path = req.uri().path(); let path = req.uri().path();
let path_and_query = req.uri().path_and_query().unwrap().as_str();
println!("{:?}", path_and_query);
let backend = if let Some(be) = backends.get(hostname.as_str()) { let backend = if let Some(be) = backends.get(hostname.as_str()) {
be be
} else { } else {
@ -41,31 +42,18 @@ pub async fn handle_request(
// Redirect to https if tls_enabled is false and redirect_to_https is true // Redirect to https if tls_enabled is false and redirect_to_https is true
if !tls_enabled && backend.redirect_to_https.unwrap_or(false) { if !tls_enabled && backend.redirect_to_https.unwrap_or(false) {
if let Some(https_port) = globals.https_port { debug!("Redirect to https: {}", hostname);
let dest = if https_port == 443 { return https_redirection(hostname, globals.https_port, path_and_query);
format!("https://{}{}", hostname, path)
} else {
format!(
"https://{}:{}{}",
hostname,
globals.https_port.unwrap(),
path
)
};
return https_redirection(dest);
} else {
return http_error(StatusCode::SERVICE_UNAVAILABLE);
}
} }
// Find reverse proxy for given path // Find reverse proxy for given path
// let destination_uri = if backend.reverse_proxy.destination_uris.is_some() { let destination_uri = if let Some(uri) = backend.reverse_proxy.destination_uris.get(path) {
// if let (b) = backend.re uri.to_owned()
// } else { } else {
// backend.reverse_proxy.default_destination_uri.clone(); backend.reverse_proxy.default_destination_uri.clone()
// }; };
debug!("path: {}", req.uri().path()); debug!("destination_uri: {}", destination_uri);
// if req.version() == hyper::Version::HTTP_11 { // if req.version() == hyper::Version::HTTP_11 {
// Ok(Response::new(Body::from("Hello World"))) // Ok(Response::new(Body::from("Hello World")))
// } else { // } else {
@ -73,7 +61,7 @@ pub async fn handle_request(
// with an appropriate StatusCode instead of an Err. // with an appropriate StatusCode instead of an Err.
// Err("not HTTP/1.1, abort connection") // Err("not HTTP/1.1, abort connection")
// http_error(StatusCode::NOT_FOUND) // http_error(StatusCode::NOT_FOUND)
https_redirection("https://www.google.com/".to_string()) https_redirection("www.google.com".to_string(), Some(443_u16), "/")
// } // }
// }); // });
} }
@ -86,32 +74,67 @@ fn http_error(status_code: StatusCode) -> Result<Response<Body>> {
Ok(response) Ok(response)
} }
fn https_redirection(redirect_to: String) -> Result<Response<Body>> { fn https_redirection(
hostname: String,
https_port: Option<u16>,
path_and_query: &str,
) -> Result<Response<Body>> {
let dest_uri: String = if let Some(https_port) = https_port {
if https_port == 443 {
format!("https://{}{}", hostname, path_and_query)
} else {
format!("https://{}:{}{}", hostname, https_port, path_and_query)
}
} else {
return http_error(StatusCode::SERVICE_UNAVAILABLE);
};
let response = Response::builder() let response = Response::builder()
.status(StatusCode::MOVED_PERMANENTLY) .status(StatusCode::MOVED_PERMANENTLY)
.header("Location", redirect_to) .header("Location", dest_uri)
.body(Body::empty()) .body(Body::empty())
.unwrap(); .unwrap();
Ok(response) Ok(response)
} }
fn parse_hostname_port(headers: &HeaderMap, tls_enabled: bool) -> Result<(String, u16)> { fn parse_hostname_port(req: &Request<Body>, tls_enabled: bool) -> Result<(String, u16)> {
let hostname_port = headers let hostname_port_headers = req.headers().get("host");
.get("host") let hostname_uri = req.uri().host();
.ok_or_else(|| anyhow!("No host in request header"))?; let port_uri = req.uri().port_u16();
let hp_as_uri = hostname_port.to_str().unwrap().parse::<Uri>().unwrap();
let hostname = hp_as_uri if hostname_port_headers.is_none() && hostname_uri.is_none() {
.host() bail!("No host in request header");
.ok_or_else(|| anyhow!("Failed to parse hostname"))?; }
let port = if let Some(p) = hp_as_uri.port() { let (hostname, port) = match (hostname_uri, hostname_port_headers) {
p.as_u16() (Some(x), _) => {
} else if tls_enabled { let hostname = hostname_uri.unwrap();
443 let port = if let Some(p) = port_uri {
} else { p
80 } else if tls_enabled {
443
} else {
80
};
(hostname.to_string(), port)
}
(None, Some(x)) => {
let hp_as_uri = x.to_str().unwrap().parse::<Uri>().unwrap();
let hostname = hp_as_uri
.host()
.ok_or_else(|| anyhow!("Failed to parse hostname"))?;
let port = if let Some(p) = hp_as_uri.port() {
p.as_u16()
} else if tls_enabled {
443
} else {
80
};
(hostname.to_string(), port)
}
(None, None) => {
bail!("Host unspecified in request")
}
}; };
Ok((hostname.to_string(), port)) Ok((hostname, port))
} }