diff --git a/config-example.toml b/config-example.toml index 8e08318..e1c71cd 100644 --- a/config-example.toml +++ b/config-example.toml @@ -83,8 +83,8 @@ load_balance = "random" # or "round_robin" or "sticky" (sticky session) or "none upstream_options = [ "upgrade_insecure_requests", "force_http11_upstream", - "set_upstream_host", # overwrite HOST value with upstream hostname (like www.yahoo.com) - "forwarded_header" # add Forwarded header + "set_upstream_host", # overwrite HOST value with upstream hostname (like www.yahoo.com) + "forwarded_header" # add Forwarded header (by default, this is not added. However, if the incoming request has Forwarded header, it would be preserved and updated) ] ###################################################################### diff --git a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs index 9f74fe8..08227f7 100644 --- a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs +++ b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs @@ -81,7 +81,7 @@ where .unwrap_or(false) }; - let original_uri = req.uri().to_string(); + let original_uri = req.uri().clone(); let headers = req.headers_mut(); // delete headers specified in header.connection remove_connection_header(headers); @@ -98,18 +98,6 @@ where // by default, add "host" header of original server_name if not exist if req.headers().get(header::HOST).is_none() { let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); - // Omit port 80 if !tls_enabled, omit port 443 if tls_enabled - let org_host = req - .uri() - .port_u16() - .map(|port| { - if (tls_enabled && port == 443) || (!tls_enabled && port == 80) { - org_host.clone() - } else { - format!("{}:{}", org_host, port) - } - }) - .unwrap_or(org_host); req.headers_mut().insert(header::HOST, HeaderValue::from_str(&org_host)?); }; @@ -139,7 +127,7 @@ where // apply upstream-specific headers given in upstream_option let headers = req.headers_mut(); // apply upstream options to header, after X-Forwarded-For is added - apply_upstream_options_to_header(headers, &upstream_chosen.uri, upstream_candidates)?; + apply_upstream_options_to_header(headers, &upstream_chosen.uri, upstream_candidates, &original_uri)?; // update uri in request ensure!( diff --git a/rpxy-lib/src/message_handler/utils_headers.rs b/rpxy-lib/src/message_handler/utils_headers.rs index 61a1d50..7f84494 100644 --- a/rpxy-lib/src/message_handler/utils_headers.rs +++ b/rpxy-lib/src/message_handler/utils_headers.rs @@ -102,6 +102,7 @@ pub(super) fn apply_upstream_options_to_header( upstream_base_uri: &Uri, // _client_addr: &SocketAddr, upstream: &UpstreamCandidates, + original_uri: &Uri, ) -> Result<()> { for opt in upstream.options.iter() { match opt { @@ -121,11 +122,9 @@ pub(super) fn apply_upstream_options_to_header( UpstreamOption::ForwardedHeader => { // This is called after X-Forwarded-For is added // Generate RFC 7239 Forwarded header - // TODO: host is generated from x-original-uri - let host = headers.get(header::HOST).and_then(|h| h.to_str().ok()).unwrap_or("unknown"); let tls = upstream_base_uri.scheme_str() == Some("https"); - match generate_forwarded_header(headers, tls, host) { + match generate_forwarded_header(headers, tls, original_uri) { Ok(forwarded_value) => { add_header_entry_overwrite_if_exist(headers, "forwarded", forwarded_value)?; } @@ -220,7 +219,7 @@ pub(super) fn add_forwarding_header( client_addr: &SocketAddr, listen_addr: &SocketAddr, tls: bool, - uri_str: &str, + original_uri: &Uri, ) -> Result<()> { let canonical_client_addr = client_addr.to_canonical().ip().to_string(); let has_forwarded = headers.contains_key("forwarded"); @@ -241,9 +240,7 @@ pub(super) fn add_forwarding_header( // IMPORTANT: If Forwarded header exists, always update it for consistency // This ensures headers remain consistent even when forwarded_header upstream option is not specified if has_forwarded { - let host = headers.get(header::HOST).and_then(|h| h.to_str().ok()).unwrap_or("unknown"); - - match generate_forwarded_header(headers, tls, host) { + match generate_forwarded_header(headers, tls, original_uri) { Ok(forwarded_value) => { add_header_entry_overwrite_if_exist(headers, "forwarded", forwarded_value)?; } @@ -272,7 +269,7 @@ pub(super) fn add_forwarding_header( // x-forwarded-ssl add_header_entry_overwrite_if_exist(headers, "x-forwarded-ssl", if tls { "on" } else { "off" })?; // x-original-uri - add_header_entry_overwrite_if_exist(headers, "x-original-uri", uri_str.to_string())?; + add_header_entry_overwrite_if_exist(headers, "x-original-uri", original_uri.to_string())?; // proxy add_header_entry_overwrite_if_exist(headers, "proxy", "")?; @@ -325,7 +322,7 @@ fn update_xff_from_forwarded(headers: &mut HeaderMap, client_addr: &SocketAddr) /// Generate RFC 7239 Forwarded header from X-Forwarded-For /// This function assumes that the X-Forwarded-For header is present and well-formed. -fn generate_forwarded_header(headers: &HeaderMap, tls: bool, host: &str) -> Result { +fn generate_forwarded_header(headers: &HeaderMap, tls: bool, original_uri: &Uri) -> Result { let for_values = headers .get("x-forwarded-for") .and_then(|h| h.to_str().ok()) @@ -356,12 +353,27 @@ fn generate_forwarded_header(headers: &HeaderMap, tls: bool, host: &str) -> Resu "for={};proto={};host={}", for_values, if tls { "https" } else { "http" }, - host + host_from_uri(original_uri)? ); Ok(forwarded_value) } +#[inline] +/// Extract host from URI +fn host_from_uri(uri: &Uri) -> Result { + uri + .host() + .map(|host| { + if let Some(port) = uri.port_u16() { + format!("{}:{}", host, port) + } else { + host.to_string() + } + }) + .ok_or_else(|| anyhow!("No host found in URI")) +} + /// Remove connection header pub(super) fn remove_connection_header(headers: &mut HeaderMap) { if let Some(values) = headers.get(header::CONNECTION) {