diff --git a/.dockerignore b/.dockerignore
index 3538235..4294ee6 100644
--- a/.dockerignore
+++ b/.dockerignore
@@ -4,3 +4,4 @@ bench/
 .private/
 .github/
 example-certs/
+legacy-lib/
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 2b04184..ddf4a55 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -45,34 +45,34 @@ jobs:
             tags-suffix: "-s2n"

           - target: "gnu"
-            build-feature: "-native-roots"
+            build-feature: "-webpki-roots"
             platform: linux/amd64
-            tags-suffix: "-native-roots"
+            tags-suffix: "-webpki-roots"

           - target: "gnu"
-            build-feature: "-native-roots"
+            build-feature: "-webpki-roots"
             platform: linux/arm64
-            tags-suffix: "-native-roots"
+            tags-suffix: "-webpki-roots"

           - target: "musl"
-            build-feature: "-native-roots"
+            build-feature: "-webpki-roots"
             platform: linux/amd64
-            tags-suffix: "-slim-native-roots"
+            tags-suffix: "-slim-webpki-roots"

           - target: "musl"
-            build-feature: "-native-roots"
+            build-feature: "-webpki-roots"
             platform: linux/arm64
-            tags-suffix: "-slim-native-roots"
+            tags-suffix: "-slim-webpki-roots"

           - target: "gnu"
-            build-feature: "-s2n-native-roots"
+            build-feature: "-s2n-webpki-roots"
             platform: linux/amd64
-            tags-suffix: "-s2n-native-roots"
+            tags-suffix: "-s2n-webpki-roots"

           - target: "gnu"
-            build-feature: "-s2n-native-roots"
+            build-feature: "-s2n-webpki-roots"
             platform: linux/arm64
-            tags-suffix: "-s2n-native-roots"
+            tags-suffix: "-s2n-webpki-roots"

     steps:
       - run: "echo 'The release triggering workflows passed'"
@@ -81,10 +81,9 @@ jobs:
         id: "set-env"
         run: |
           if [ ${{ matrix.platform }} == 'linux/amd64' ]; then PLATFORM_MAP="x86_64"; else PLATFORM_MAP="aarch64"; fi
-          if [ ${{ github.ref_name }} == 'develop' ]; then BUILD_NAME="-nightly"; else BUILD_NAME=""; fi
-          if [ ${{ github.ref_name }} == 'develop' ]; then BUILD_IMG="nightly"; else BUILD_IMG="latest"; fi
+          if [ ${{ github.ref_name }} == 'main' ]; then BUILD_IMG="latest"; else BUILD_IMG="nightly"; fi
           echo "build_img=${BUILD_IMG}" >> $GITHUB_OUTPUT
-          echo "target_name=rpxy${BUILD_NAME}-${PLATFORM_MAP}-unknown-linux-${{ matrix.target }}${{ matrix.build-feature }}" >> $GITHUB_OUTPUT
+          echo "target_name=rpxy-${PLATFORM_MAP}-unknown-linux-${{ matrix.target }}${{ matrix.build-feature }}" >> $GITHUB_OUTPUT

       - name: "docker pull and extract binary from docker image"
         id: "extract-binary"
@@ -93,7 +92,7 @@ jobs:
           docker cp ${CONTAINER_ID}:/rpxy/bin/rpxy /tmp/${{ steps.set-env.outputs.target_name }}

       - name: "upload artifacts"
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: ${{ steps.set-env.outputs.target_name }}
           path: "/tmp/${{ steps.set-env.outputs.target_name }}"
@@ -110,7 +109,7 @@ jobs:
     needs: on-success
     steps:
       - name: check pull_request title
-        uses: kaisugi/action-regex-match@v1.0.0
+        uses: kaisugi/action-regex-match@v1.0.1
        id: regex-match
        with:
          text: ${{ github.event.client_payload.pull_request.title }}
@@ -122,7 +121,7 @@ jobs:
       - name: download artifacts
         if: ${{ steps.regex-match.outputs.match != ''}}
-        uses: actions/download-artifact@v3
+        uses: actions/download-artifact@v4
         with:
           path: /tmp/rpxy
diff --git a/.github/workflows/release_docker.yml b/.github/workflows/release_docker.yml
index 60dd7ff..2e2aff9 100644
--- a/.github/workflows/release_docker.yml
+++ b/.github/workflows/release_docker.yml
@@ -2,6 +2,7 @@ name: Release - Build and publish docker, and trigger package release
 on:
   push:
     branches:
+      - "feat/*"
       - "develop"
   pull_request:
     types: [closed]
@@ -44,7 +45,7 @@ jobs:
         - target: "s2n"
           dockerfile: ./docker/Dockerfile
          build-args: |
-            "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache"
+            "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,rustls-backend"
             "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++"
           platforms: linux/amd64,linux/arm64
           tags-suffix: "-s2n"
@@ -53,42 +54,42 @@ jobs:
             jqtype/rpxy:s2n
             ghcr.io/junkurihara/rust-rpxy:s2n

-        - target: "native-roots"
+        - target: "webpki-roots"
           dockerfile: ./docker/Dockerfile
           platforms: linux/amd64,linux/arm64
           build-args: |
-            "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots"
-          tags-suffix: "-native-roots"
+            "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,webpki-roots"
+          tags-suffix: "-webpki-roots"
           # Aliases must be used only for release builds
           aliases: |
-            jqtype/rpxy:native-roots
-            ghcr.io/junkurihara/rust-rpxy:native-roots
+            jqtype/rpxy:webpki-roots
+            ghcr.io/junkurihara/rust-rpxy:webpki-roots

-        - target: "slim-native-roots"
+        - target: "slim-webpki-roots"
           dockerfile: ./docker/Dockerfile-slim
           build-args: |
-            "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots"
+            "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,webpki-roots"
           build-contexts: |
             messense/rust-musl-cross:amd64-musl=docker-image://messense/rust-musl-cross:x86_64-musl
             messense/rust-musl-cross:arm64-musl=docker-image://messense/rust-musl-cross:aarch64-musl
           platforms: linux/amd64,linux/arm64
-          tags-suffix: "-slim-native-roots"
+          tags-suffix: "-slim-webpki-roots"
           # Aliases must be used only for release builds
           aliases: |
-            jqtype/rpxy:slim-native-roots
-            ghcr.io/junkurihara/rust-rpxy:slim-native-roots
+            jqtype/rpxy:slim-webpki-roots
+            ghcr.io/junkurihara/rust-rpxy:slim-webpki-roots

-        - target: "s2n-native-roots"
+        - target: "s2n-webpki-roots"
           dockerfile: ./docker/Dockerfile
           build-args: |
-            "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,native-roots"
+            "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,webpki-roots"
             "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++"
           platforms: linux/amd64,linux/arm64
-          tags-suffix: "-s2n-native-roots"
+          tags-suffix: "-s2n-webpki-roots"
           # Aliases must be used only for release builds
           aliases: |
-            jqtype/rpxy:s2n-native-roots
-            ghcr.io/junkurihara/rust-rpxy:s2n-native-roots
+            jqtype/rpxy:s2n-webpki-roots
+            ghcr.io/junkurihara/rust-rpxy:s2n-webpki-roots

     steps:
       - name: Checkout
@@ -135,6 +136,23 @@ jobs:
       #     platforms: linux/amd64
       #     labels: ${{ steps.meta.outputs.labels }}

+      - name: Unstable build and push from feat/* branches
+        if: ${{ startsWith(github.ref_name, 'feat/') && (github.event_name == 'push') }}
+        uses: docker/build-push-action@v5
+        with:
+          context: .
+          build-args: ${{ matrix.build-args }}
+          push: true
+          tags: |
+            ${{ env.GHCR }}/${{ env.GHCR_IMAGE_NAME }}:unstable${{ matrix.tags-suffix }}
+            ${{ env.DH_REGISTRY_NAME }}:unstable${{ matrix.tags-suffix }}
+          build-contexts: ${{ matrix.build-contexts }}
+          file: ${{ matrix.dockerfile }}
+          cache-from: type=gha,scope=rpxy-unstable-${{ matrix.target }}
+          cache-to: type=gha,mode=max,scope=rpxy-unstable-${{ matrix.target }}
+          platforms: linux/amd64
+          labels: ${{ steps.meta.outputs.labels }}
+
       - name: Nightly build and push from develop branch
         if: ${{ (github.ref_name == 'develop') && (github.event_name == 'push') }}
         uses: docker/build-push-action@v5
@@ -176,7 +194,7 @@ jobs:
     needs: build_and_push
     steps:
       - name: Repository dispatch for release
-        uses: peter-evans/repository-dispatch@v2
+        uses: peter-evans/repository-dispatch@v3
         with:
           event-type: release-event
           client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "pull_request": { "title": "${{ github.event.pull_request.title }}", "body": ${{ toJson(github.event.pull_request.body) }}, "number": "${{ github.event.pull_request.number }}", "head": "${{ github.event.pull_request.head.ref }}", "base": "${{ github.event.pull_request.base.ref}}"}}'
diff --git a/.gitmodules b/.gitmodules
index 65fcd3b..47ebad0 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,12 +1,6 @@
 [submodule "submodules/h3"]
   path = submodules/h3
   url = git@github.com:junkurihara/h3.git
-[submodule "submodules/quinn"]
-  path = submodules/quinn
-  url = git@github.com:junkurihara/quinn.git
-[submodule "submodules/s2n-quic"]
-  path = submodules/s2n-quic
-  url = git@github.com:junkurihara/s2n-quic.git
 [submodule "submodules/rusty-http-cache-semantics"]
   path = submodules/rusty-http-cache-semantics
   url = git@github.com:junkurihara/rusty-http-cache-semantics.git
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 20ac679..13653ad 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,6 +1,22 @@
 # CHANGELOG

-## 0.7.0 (unreleased)
+## 0.8.0 (Unreleased)
+
+## 0.7.0
+
+### Important Changes
+
+- Breaking: `hyper`-1.0 for both server and client modules.
+- Breaking: Removed the `override_host` upstream option and added its reverse, `keep_original_host`, together with the related option `set_upstream_host`. By default, `rpxy` keeps the `host` header given in the incoming request (adding it from the URL in the request line if absent), so `keep_original_host` simply makes this default explicit. If the `host` header must be overridden with the upstream host name (the host of the backend URI), set `set_upstream_host`. If both options are set, `keep_original_host` takes precedence since it is the explicitly specified choice (see the sketch below).
+- Breaking: Introduced the `native-tls-backend` feature to use the platform-native TLS engine when accessing backend applications.
+- Breaking: Changed the default certificate store from `webpki` to the system-native store. Accordingly, the `native-roots` feature was removed and a `webpki-roots` feature was introduced to opt back into the `webpki` root certificate store.
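As a rough sketch of the precedence rule described above (hypothetical names, not `rpxy`'s actual code — `HostPolicy` and `apply_host_policy` are invented for illustration; only the precedence is taken from the entry):

```rust
use http::{header, uri::Uri, HeaderMap, HeaderValue};

/// Hypothetical mirror of the two upstream options described above.
#[derive(PartialEq)]
enum HostPolicy {
  KeepOriginalHost,
  SetUpstreamHost,
}

/// Decide the `host` header of the outgoing request.
/// `keep_original_host` wins when both options are present.
fn apply_host_policy(
  headers: &mut HeaderMap,
  original_host: Option<HeaderValue>,
  upstream_uri: &Uri,
  policies: &[HostPolicy],
) {
  let keep = policies.contains(&HostPolicy::KeepOriginalHost);
  let set_upstream = policies.contains(&HostPolicy::SetUpstreamHost);
  if set_upstream && !keep {
    // overwrite with the backend URI's host name
    if let Some(host) = upstream_uri.host() {
      headers.insert(header::HOST, HeaderValue::from_str(host).expect("valid host"));
    }
  } else if let Some(host) = original_host {
    // default (and explicit `keep_original_host`): keep the client's value
    headers.insert(header::HOST, host);
  }
}
```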
+
+### Improvement
+
+- Redesigned: The cache structure was completely redesigned to read from cache files in a more memory-efficient way, and to bind in-memory objects to their backing files more securely via hash values.
+- Redesigned: The HTTP body handling flow was also redesigned with more memory- and time-efficient techniques, streaming bodies via `futures::stream::Stream` and `futures::channel::mpsc` instead of buffering whole objects in memory (see the sketch below).
+- Feat: Allow enabling/disabling a forced connection timeout that applies regardless of connection status (idle or not). [default: disabled]
+- Refactor: lots of minor improvements.
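The channel-based streaming named in that entry can be sketched as follows. This is a minimal illustration assuming `futures::channel::mpsc` and `http_body_util::StreamBody` (the helper name `channel_body` is invented):

```rust
use bytes::Bytes;
use futures::{channel::mpsc, StreamExt};
use http_body_util::StreamBody;
use hyper::body::Frame;

type BoxError = Box<dyn std::error::Error + Send + Sync>;

/// Build a body that is produced chunk-by-chunk through a channel,
/// so the whole object never resides in memory at once.
fn channel_body() -> (
  mpsc::Sender<Bytes>,
  StreamBody<impl futures::Stream<Item = Result<Frame<Bytes>, BoxError>>>,
) {
  let (tx, rx) = mpsc::channel::<Bytes>(16);
  // Wrap each chunk in a data frame as `StreamBody` expects.
  let body = StreamBody::new(rx.map(|chunk| Ok::<_, BoxError>(Frame::data(chunk))));
  (tx, body)
}
```

A producer task can then send chunks through `tx` (via `futures::SinkExt::send`) while the server polls the body, so only in-flight chunks are buffered.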
 ## 0.6.2
diff --git a/Cargo.toml b/Cargo.toml
index c512b18..982f81d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,5 +1,14 @@
-[workspace]
+[workspace.package]
+version = "0.7.0"
+authors = ["Jun Kurihara"]
+homepage = "https://github.com/junkurihara/rust-rpxy"
+repository = "https://github.com/junkurihara/rust-rpxy"
+license = "MIT"
+readme = "./README.md"
+edition = "2021"
+publish = false
+
+[workspace]
 members = ["rpxy-bin", "rpxy-lib"]
 exclude = ["submodules"]
 resolver = "2"
diff --git a/LICENSE b/LICENSE
index 967c341..096f5e1 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,6 +1,6 @@
 MIT License

-Copyright (c) 2023 Jun Kurihara
+Copyright (c) 2024 Jun Kurihara

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
diff --git a/README.md b/README.md
index b714fa5..20d7891 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@

 [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)
 ![Unit Test](https://github.com/junkurihara/rust-rpxy/actions/workflows/ci.yml/badge.svg)
-![Docker](https://github.com/junkurihara/rust-rpxy/actions/workflows/docker_build_push.yml/badge.svg)
+![Docker](https://github.com/junkurihara/rust-rpxy/actions/workflows/release_docker.yml/badge.svg)
 ![ShiftLeft Scan](https://github.com/junkurihara/rust-rpxy/actions/workflows/shift_left.yml/badge.svg)
 [![Docker Image Size (latest by date)](https://img.shields.io/docker/image-size/jqtype/rpxy)](https://hub.docker.com/r/jqtype/rpxy)
@@ -104,11 +104,11 @@ If you want to host multiple and distinct domain names in a single IP address/po

 ```toml
 default_application = "app1"

-[app.app1]
+[apps.app1]
 server_name = "app1.example.com"
 #...

-[app.app2]
+[apps.app2]
 server_name = "app2.example.org"
 #...
 ```
diff --git a/TODO.md b/TODO.md
index 1e25ee1..b304abd 100644
--- a/TODO.md
+++ b/TODO.md
@@ -1,9 +1,11 @@
 # TODO List

-- [Done in 0.6.0] But we need more sophistication on `Forwarder` struct. ~~Fix strategy for `h2c` requests on forwarded requests upstream. This needs to update forwarder definition. Also, maybe forwarder would have a cache corresponding to the following task.~~
-- [Initial implementation in v0.6.0] ~~**Cache option for the response with `Cache-Control: public` header directive ([#55](https://github.com/junkurihara/rust-rpxy/issues/55))**~~ Using `lru` crate might be inefficient in terms of the speed.
+- Support of `rustls-0.22`.
+- We need more sophistication on the `Forwarder` struct to handle `h2c`.
+- Caching via the `lru` crate might be inefficient in terms of speed.
   - Consider more sophisticated architecture for cache
   - Persistent cache (if possible).
+  - More secure cache file object naming
   - etc etc
 - Improvement of path matcher
 - More flexible option for rewriting path
@@ -17,7 +19,7 @@

 - Unit tests
 - Options to serve custom http_error page.
-- Prometheus metrics
+- Traces and metrics using opentelemetry (`tracing-opentelemetry` crate)
 - Documentation
 - Client certificate
   - support intermediate certificates. Currently, only client certificates directly signed by the root CA are supported.
@@ -27,15 +29,4 @@
 - Make the session-persistence option for load-balancing more sophisticated. (mostly done in v0.3.0)
   - add option for sticky cookie name
   - add option for sticky cookie duration
-
-- Done in v0.5.0 ~~Use `gchr.io`~~
-- Done in v0.5.0:
-  ~~Consideration on migrating from `quinn` and `h3-quinn` to other QUIC implementations ([#57](https://github.com/junkurihara/rust-rpxy/issues/57))~~
-- Done in v0.4.0:
-  ~~Benchmark with other reverse proxy implementations like Sozu ([#58](https://github.com/junkurihara/rust-rpxy/issues/58)) Currently, Sozu can work only on `amd64` format due to its HTTP message parser limitation... Since the main developers have only `arm64` (Apple M1) laptops, we should do that on a VPS?~~
-- Done in v0.4.0:
-  ~~Split `rpxy` source codes into `rpxy-lib` and `rpxy-bin` to make the core part (reverse proxy) isolated from the misc part like toml file loader. This is in order to make the configuration-related part more flexible (related to [#33](https://github.com/junkurihara/rust-rpxy/issues/33))~~
-- Done in 0.6.0:
-  ~~Fix dynamic reloading of configuration file~~
-
 - etc.
diff --git a/config-example.toml b/config-example.toml
index ec79f3d..c3d1e47 100644
--- a/config-example.toml
+++ b/config-example.toml
@@ -57,7 +57,7 @@ upstream = [
 ]
 load_balance = "round_robin" # or "random" or "sticky" (sticky session) or "none" (fix to the first one, default)
 upstream_options = [
-  "override_host",
+  "keep_original_host", # [default] do not overwrite the HOST value with the upstream host name (like 192.168.xx.x as seen from rpxy); takes precedence over "set_upstream_host" if both are specified.
   "force_http2_upstream", # mutually exclusive with "force_http11_upstream"
 ]
@@ -76,9 +76,9 @@ upstream = [
 ]
 load_balance = "random" # or "round_robin" or "sticky" (sticky session) or "none" (fix to the first one, default)
 upstream_options = [
-  "override_host",
   "upgrade_insecure_requests",
   "force_http11_upstream",
+  "set_upstream_host", # overwrite the HOST value with the upstream host name (like www.yahoo.com)
 ]

######################################################################
@@ -98,6 +98,11 @@ reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }]
 # We should note that this strongly depends on the client implementation.
 ignore_sni_consistency = false

+# Force a connection-handling timeout regardless of the connection status, i.e., idle or not.
+# 0 represents an infinite timeout. [default: 0]
+# Note that the idle and header-read timeouts are always applied independently of this.
+connection_handling_timeout = 0 # sec
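A minimal sketch of how such a forced timeout can be applied, assuming hypothetical helper names; only the 0-means-infinite convention is taken from the comments above (and from the parsing code in `rpxy-bin/src/config/toml.rs` later in this diff):

```rust
use std::time::Duration;

/// Map the TOML value: 0 disables the forced timeout.
fn to_timeout(secs: u64) -> Option<Duration> {
  (secs != 0).then(|| Duration::from_secs(secs))
}

/// Serve a connection, optionally aborting it after the forced timeout,
/// whether or not the connection is idle. `serve` is a stand-in for the
/// real connection-serving future.
async fn serve_with_timeout<F: std::future::Future<Output = ()>>(
  serve: F,
  timeout: Option<Duration>,
) {
  match timeout {
    Some(t) => {
      // Forcibly drop the connection when the deadline is reached.
      let _ = tokio::time::timeout(t, serve).await;
    }
    None => serve.await, // infinite: no forced timeout
  }
}
```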
+
 # If this is specified, h3 is enabled
 [experimental.h3]
 alt_svc_max_age = 3600 # sec
diff --git a/docker/docker-compose-slim.yml b/docker/docker-compose-slim.yml
index 90f5e76..57f9cc9 100644
--- a/docker/docker-compose-slim.yml
+++ b/docker/docker-compose-slim.yml
@@ -14,8 +14,8 @@ services:
       additional_contexts:
         - messense/rust-musl-cross:amd64-musl=docker-image://messense/rust-musl-cross:x86_64-musl
         - messense/rust-musl-cross:arm64-musl=docker-image://messense/rust-musl-cross:aarch64-musl
-      # args: # Uncomment to build with the native cert store
-      #   - "CARGO_FEATURES=--no-default-features --features=http3-quinn,native-roots"
+      # args: # Uncomment to build with the webpki cert store
+      #   - "CARGO_FEATURES=--no-default-features --features=http3-quinn,webpki-roots"
       dockerfile: ./docker/Dockerfile-slim # based on alpine and build x86_64-unknown-linux-musl
       platforms: # Choose your platforms
         # - "linux/amd64"
diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml
index bac5957..0c95fc6 100644
--- a/docker/docker-compose.yml
+++ b/docker/docker-compose.yml
@@ -14,8 +14,8 @@ services:
       # args: # Uncomment when building the quic-s2n version
       #   - "CARGO_FEATURES=--no-default-features --features=http3-s2n"
       #   - "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++"
-      # args: # Uncomment to build with the native cert store
-      #   - "CARGO_FEATURES=--no-default-features --features=http3-quinn,native-roots"
+      # args: # Uncomment to build with the webpki root store
+      #   - "CARGO_FEATURES=--no-default-features --features=http3-quinn,webpki-roots"
       dockerfile: ./docker/Dockerfile # based on ubuntu 22.04 and build x86_64-unknown-linux-gnu
       platforms: # Choose your platforms
         # - "linux/amd64"
diff --git a/rpxy-bin/Cargo.toml b/rpxy-bin/Cargo.toml
index fbe14dc..4ad5d1c 100644
--- a/rpxy-bin/Cargo.toml
+++ b/rpxy-bin/Cargo.toml
@@ -1,51 +1,54 @@
 [package]
 name = "rpxy"
-version = "0.6.2"
-authors = ["Jun Kurihara"]
-homepage = "https://github.com/junkurihara/rust-rpxy"
-repository = "https://github.com/junkurihara/rust-rpxy"
-license = "MIT"
-readme = "../README.md"
-edition = "2021"
-publish = false
+description = "`rpxy`: a simple and ultrafast http reverse proxy"
+version.workspace = true
+authors.workspace = true
+homepage.workspace = true
+repository.workspace = true
+license.workspace = true
+readme.workspace = true
+edition.workspace = true
+publish.workspace = true

 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

 [features]
-default = ["http3-quinn", "cache"]
+default = ["http3-quinn", "cache", "rustls-backend"]
 http3-quinn = ["rpxy-lib/http3-quinn"]
 http3-s2n = ["rpxy-lib/http3-s2n"]
+native-tls-backend = ["rpxy-lib/native-tls-backend"]
+rustls-backend = ["rpxy-lib/rustls-backend"]
+webpki-roots = ["rpxy-lib/webpki-roots"]
 cache = ["rpxy-lib/cache"]
-native-roots = ["rpxy-lib/native-roots"]

 [dependencies]
 rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [
   "sticky-cookie",
 ] }

-anyhow = "1.0.75"
+anyhow = "1.0.79"
 rustc-hash = "1.1.0"
-serde = { version = "1.0.188", default-features = false, features = ["derive"] }
-derive_builder = "0.12.0"
-tokio = { version = "1.33.0", default-features = false, features = [
+serde = { version = "1.0.196", default-features = false, features = ["derive"] }
+derive_builder = "0.20.0"
+tokio = { version = "1.36.0", default-features = false, features = [
   "net",
   "rt-multi-thread",
   "time",
   "sync",
   "macros",
 ] }
-async-trait = "0.1.73"
-rustls-pemfile = "1.0.3"
+async-trait = "0.1.77" +rustls-pemfile = "1.0.4" mimalloc = { version = "*", default-features = false } # config -clap = { version = "4.4.6", features = ["std", "cargo", "wrap_help"] } -toml = { version = "0.8", default-features = false, features = ["parse"] } -hot_reload = "0.1.4" +clap = { version = "4.5.0", features = ["std", "cargo", "wrap_help"] } +toml = { version = "0.8.10", default-features = false, features = ["parse"] } +hot_reload = "0.1.5" # logging -tracing = { version = "0.1.37" } -tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } +tracing = { version = "0.1.40" } +tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } [dev-dependencies] diff --git a/rpxy-bin/src/cert_file_reader.rs b/rpxy-bin/src/cert_file_reader.rs index 0a6a14f..ee9a591 100644 --- a/rpxy-bin/src/cert_file_reader.rs +++ b/rpxy-bin/src/cert_file_reader.rs @@ -8,7 +8,7 @@ use rpxy_lib::{ use std::{ fs::File, io::{self, BufReader, Cursor, Read}, - path::PathBuf, + path::{Path, PathBuf}, }; #[derive(Builder, Debug, Clone)] @@ -28,16 +28,16 @@ pub struct CryptoFileSource { } impl CryptoFileSourceBuilder { - pub fn tls_cert_path(&mut self, v: &str) -> &mut Self { - self.tls_cert_path = Some(PathBuf::from(v)); + pub fn tls_cert_path>(&mut self, v: T) -> &mut Self { + self.tls_cert_path = Some(v.as_ref().to_path_buf()); self } - pub fn tls_cert_key_path(&mut self, v: &str) -> &mut Self { - self.tls_cert_key_path = Some(PathBuf::from(v)); + pub fn tls_cert_key_path>(&mut self, v: T) -> &mut Self { + self.tls_cert_key_path = Some(v.as_ref().to_path_buf()); self } - pub fn client_ca_cert_path(&mut self, v: &Option) -> &mut Self { - self.client_ca_cert_path = Some(v.to_owned().as_ref().map(PathBuf::from)); + pub fn client_ca_cert_path>(&mut self, v: Option) -> &mut Self { + self.client_ca_cert_path = Some(v.map(|p| p.as_ref().to_path_buf())); self } } @@ -167,11 +167,11 @@ mod tests { async fn read_server_crt_key_files_with_client_ca_crt() { let tls_cert_path = "../example-certs/server.crt"; let tls_cert_key_path = "../example-certs/server.key"; - let client_ca_cert_path = Some("../example-certs/client.ca.crt".to_string()); + let client_ca_cert_path = Some("../example-certs/client.ca.crt"); let crypto_file_source = CryptoFileSourceBuilder::default() .tls_cert_key_path(tls_cert_key_path) .tls_cert_path(tls_cert_path) - .client_ca_cert_path(&client_ca_cert_path) + .client_ca_cert_path(client_ca_cert_path) .build(); assert!(crypto_file_source.is_ok()); diff --git a/rpxy-bin/src/config/toml.rs b/rpxy-bin/src/config/toml.rs index e678012..7cd3653 100644 --- a/rpxy-bin/src/config/toml.rs +++ b/rpxy-bin/src/config/toml.rs @@ -7,6 +7,7 @@ use rpxy_lib::{reexports::Uri, AppConfig, ProxyConfig, ReverseProxyConfig, TlsCo use rustc_hash::FxHashMap as HashMap; use serde::Deserialize; use std::{fs, net::SocketAddr}; +use tokio::time::Duration; #[derive(Deserialize, Debug, Default, PartialEq, Eq, Clone)] pub struct ConfigToml { @@ -48,6 +49,7 @@ pub struct Experimental { #[cfg(feature = "cache")] pub cache: Option, pub ignore_sni_consistency: Option, + pub connection_handling_timeout: Option, } #[derive(Deserialize, Debug, Default, PartialEq, Eq, Clone)] @@ -162,7 +164,7 @@ impl TryInto for &ConfigToml { if x == 0u64 { proxy_config.h3_max_idle_timeout = None; } else { - proxy_config.h3_max_idle_timeout = Some(tokio::time::Duration::from_secs(x)) + proxy_config.h3_max_idle_timeout = Some(Duration::from_secs(x)) } } } @@ -172,6 +174,14 @@ impl TryInto for &ConfigToml { proxy_config.sni_consistency 
diff --git a/rpxy-bin/src/config/toml.rs b/rpxy-bin/src/config/toml.rs
index e678012..7cd3653 100644
--- a/rpxy-bin/src/config/toml.rs
+++ b/rpxy-bin/src/config/toml.rs
@@ -7,6 +7,7 @@ use rpxy_lib::{reexports::Uri, AppConfig, ProxyConfig, ReverseProxyConfig, TlsCo
 use rustc_hash::FxHashMap as HashMap;
 use serde::Deserialize;
 use std::{fs, net::SocketAddr};
+use tokio::time::Duration;

 #[derive(Deserialize, Debug, Default, PartialEq, Eq, Clone)]
 pub struct ConfigToml {
@@ -48,6 +49,7 @@ pub struct Experimental {
   #[cfg(feature = "cache")]
   pub cache: Option<CacheOption>,
   pub ignore_sni_consistency: Option<bool>,
+  pub connection_handling_timeout: Option<u64>,
 }

 #[derive(Deserialize, Debug, Default, PartialEq, Eq, Clone)]
@@ -162,7 +164,7 @@ impl TryInto<ProxyConfig> for &ConfigToml {
       if x == 0u64 {
         proxy_config.h3_max_idle_timeout = None;
       } else {
-        proxy_config.h3_max_idle_timeout = Some(tokio::time::Duration::from_secs(x))
+        proxy_config.h3_max_idle_timeout = Some(Duration::from_secs(x))
       }
     }
   }
@@ -172,6 +174,14 @@ impl TryInto<ProxyConfig> for &ConfigToml {
       proxy_config.sni_consistency = !ignore;
     }

+    if let Some(timeout) = exp.connection_handling_timeout {
+      if timeout == 0u64 {
+        proxy_config.connection_handling_timeout = None;
+      } else {
+        proxy_config.connection_handling_timeout = Some(Duration::from_secs(timeout));
+      }
+    }
+
     #[cfg(feature = "cache")]
     if let Some(cache_option) = &exp.cache {
       proxy_config.cache_enabled = true;
@@ -217,7 +227,7 @@ impl Application {
     let inner = CryptoFileSourceBuilder::default()
       .tls_cert_path(tls.tls_cert_path.as_ref().unwrap())
       .tls_cert_key_path(tls.tls_cert_key_path.as_ref().unwrap())
-      .client_ca_cert_path(&tls.client_ca_cert_path)
+      .client_ca_cert_path(tls.client_ca_cert_path.as_deref())
       .build()?;

     let https_redirection = if tls.https_redirection.is_none() {
diff --git a/rpxy-bin/src/error.rs b/rpxy-bin/src/error.rs
index b559bce..9751fb5 100644
--- a/rpxy-bin/src/error.rs
+++ b/rpxy-bin/src/error.rs
@@ -1 +1,2 @@
+#[allow(unused)]
 pub use anyhow::{anyhow, bail, ensure, Context};
diff --git a/rpxy-bin/src/log.rs b/rpxy-bin/src/log.rs
index 3fcf694..f910e94 100644
--- a/rpxy-bin/src/log.rs
+++ b/rpxy-bin/src/log.rs
@@ -1,3 +1,4 @@
+#[allow(unused)]
 pub use tracing::{debug, error, info, warn};

 pub fn init_logger() {
@@ -12,10 +13,13 @@ pub fn init_logger() {
     .with_level(true)
     .compact();

-  // This limits the logger to emit only from the rpxy crate
+  // This limits the logger to emit only from the proxy crate
+  let pkg_name = env!("CARGO_PKG_NAME").replace('-', "_");
   let level_string = std::env::var(EnvFilter::DEFAULT_ENV).unwrap_or_else(|_| "info".to_string());
-  let filter_layer = EnvFilter::new(format!("{}={}", env!("CARGO_PKG_NAME"), level_string));
-  // let filter_layer = EnvFilter::from_default_env();
+  let filter_layer = EnvFilter::new(format!("{}={}", pkg_name, level_string));
+  // let filter_layer = EnvFilter::try_from_default_env()
+  //   .unwrap_or_else(|_| EnvFilter::new("info"))
+  //   .add_directive(format!("{}=trace", pkg_name).parse().unwrap());

   tracing_subscriber::registry()
     .with(format_layer)
diff --git a/rpxy-bin/src/main.rs b/rpxy-bin/src/main.rs
index f04a6f1..9aeb971 100644
--- a/rpxy-bin/src/main.rs
+++ b/rpxy-bin/src/main.rs
@@ -15,9 +15,6 @@ use crate::{
 use hot_reload::{ReloaderReceiver, ReloaderService};
 use rpxy_lib::entrypoint;

-#[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))]
-compile_error!("feature \"http3-quinn\" and feature \"http3-s2n\" cannot be enabled at the same time");
-
 fn main() {
   init_logger();

@@ -29,8 +26,8 @@ fn main() {
   runtime.block_on(async {
     // Initially load options
     let Ok(parsed_opts) = parse_opts() else {
-        error!("Invalid toml file");
-        std::process::exit(1);
+      error!("Invalid toml file");
+      std::process::exit(1);
     };

     if !parsed_opts.watch {
diff --git a/rpxy-lib/Cargo.toml b/rpxy-lib/Cargo.toml
index c7cce09..65fb9e2 100644
--- a/rpxy-lib/Cargo.toml
+++ b/rpxy-lib/Cargo.toml
@@ -1,31 +1,40 @@
 [package]
 name = "rpxy-lib"
-version = "0.6.2"
-authors = ["Jun Kurihara"]
-homepage = "https://github.com/junkurihara/rust-rpxy"
-repository = "https://github.com/junkurihara/rust-rpxy"
-license = "MIT"
-readme = "../README.md"
-edition = "2021"
-publish = false
+description = "Library of `rpxy`: a simple and ultrafast http reverse proxy"
+version.workspace = true
+authors.workspace = true
+homepage.workspace = true
+repository.workspace = true
+license.workspace = true
+readme.workspace = true
+edition.workspace = true
+publish.workspace = true

 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

 [features]
"cache"] -http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] -http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] +default = ["http3-quinn", "sticky-cookie", "cache", "rustls-backend"] +http3-quinn = ["socket2", "quinn", "h3", "h3-quinn"] +http3-s2n = [ + "h3", + "s2n-quic", + "s2n-quic-core", + "s2n-quic-rustls", + "s2n-quic-h3", +] +cache = ["http-cache-semantics", "lru", "sha2", "base64"] sticky-cookie = ["base64", "sha2", "chrono"] -cache = ["http-cache-semantics", "lru"] -native-roots = ["hyper-rustls/native-tokio"] +native-tls-backend = ["hyper-tls"] +rustls-backend = ["hyper-rustls"] +webpki-roots = ["rustls-backend", "hyper-rustls/webpki-tokio"] [dependencies] rand = "0.8.5" rustc-hash = "1.1.0" bytes = "1.5.0" -derive_builder = "0.12.0" -futures = { version = "0.3.28", features = ["alloc", "async-await"] } -tokio = { version = "1.33.0", default-features = false, features = [ +derive_builder = "0.20.0" +futures = { version = "0.3.30", features = ["alloc", "async-await"] } +tokio = { version = "1.36.0", default-features = false, features = [ "net", "rt-multi-thread", "time", @@ -33,60 +42,69 @@ tokio = { version = "1.33.0", default-features = false, features = [ "macros", "fs", ] } -async-trait = "0.1.73" -hot_reload = "0.1.4" # reloading certs +pin-project-lite = "0.2.13" +async-trait = "0.1.77" # Error handling -anyhow = "1.0.75" -thiserror = "1.0.49" +anyhow = "1.0.79" +thiserror = "1.0.57" -# http and tls -hyper = { version = "0.14.27", default-features = false, features = [ - "server", +# http for both server and client +http = "1.0.0" +http-body-util = "0.1.0" +hyper = { version = "1.1.0", default-features = false } +hyper-util = { version = "0.1.3", features = ["full"] } +futures-util = { version = "0.3.30", default-features = false } +futures-channel = { version = "0.3.30", default-features = false } + +# http client for upstream +hyper-tls = { version = "0.6.0", features = [ + "alpn", + "vendored", +], optional = true } +hyper-rustls = { version = "0.26.0", default-features = false, features = [ + "ring", + "native-tokio", "http1", "http2", - "stream", -] } -hyper-rustls = { version = "0.24.1", default-features = false, features = [ - "tokio-runtime", - "webpki-tokio", - "http1", - "http2", -] } +], optional = true } + +# tls and cert management for server +hot_reload = "0.1.5" +rustls = { version = "0.21.10", default-features = false } tokio-rustls = { version = "0.24.1", features = ["early-data"] } -rustls = { version = "0.21.7", default-features = false } webpki = "0.22.4" x509-parser = "0.15.1" # logging -tracing = { version = "0.1.37" } +tracing = { version = "0.1.40" } # http/3 -# quinn = { version = "0.9.3", optional = true } -quinn = { path = "../submodules/quinn/quinn", optional = true } # Tentative to support rustls-0.21 +quinn = { version = "0.10.2", optional = true } h3 = { path = "../submodules/h3/h3/", optional = true } -# h3-quinn = { path = "./h3/h3-quinn/", optional = true } -h3-quinn = { path = "../submodules/h3-quinn/", optional = true } # Tentative to support rustls-0.21 -# for UDP socket wit SO_REUSEADDR when h3 with quinn -socket2 = { version = "0.5.4", features = ["all"], optional = true } -s2n-quic = { path = "../submodules/s2n-quic/quic/s2n-quic/", default-features = false, features = [ +h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } +s2n-quic = { version = "1.33.0", default-features = false, features = [ "provider-tls-rustls", ], optional = true } -s2n-quic-h3 = { path = 
"../submodules/s2n-quic/quic/s2n-quic-h3/", optional = true } -s2n-quic-rustls = { path = "../submodules/s2n-quic/quic/s2n-quic-rustls/", optional = true } +s2n-quic-core = { version = "0.33.0", default-features = false, optional = true } +s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } +s2n-quic-rustls = { version = "0.33.0", optional = true } +# for UDP socket wit SO_REUSEADDR when h3 with quinn +socket2 = { version = "0.5.5", features = ["all"], optional = true } # cache http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } -lru = { version = "0.12.0", optional = true } +lru = { version = "0.12.2", optional = true } +sha2 = { version = "0.10.8", default-features = false, optional = true } # cookie handling for sticky cookie -chrono = { version = "0.4.31", default-features = false, features = [ +chrono = { version = "0.4.34", default-features = false, features = [ "unstable-locales", "alloc", "clock", ], optional = true } -base64 = { version = "0.21.4", optional = true } -sha2 = { version = "0.10.8", default-features = false, optional = true } +base64 = { version = "0.21.7", optional = true } [dev-dependencies] +tokio-test = "0.4.3" diff --git a/rpxy-lib/src/backend/backend_main.rs b/rpxy-lib/src/backend/backend_main.rs new file mode 100644 index 0000000..d9fa649 --- /dev/null +++ b/rpxy-lib/src/backend/backend_main.rs @@ -0,0 +1,136 @@ +use crate::{ + crypto::CryptoSource, + error::*, + log::*, + name_exp::{ByteName, ServerName}, + AppConfig, AppConfigList, +}; +use derive_builder::Builder; +use rustc_hash::FxHashMap as HashMap; +use std::borrow::Cow; + +use super::upstream::PathManager; + +/// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings. +#[derive(Builder)] +pub struct BackendApp +where + T: CryptoSource, +{ + #[builder(setter(into))] + /// backend application name, e.g., app1 + pub app_name: String, + #[builder(setter(custom))] + /// server name, e.g., example.com, in [[ServerName]] object + pub server_name: ServerName, + /// struct of reverse proxy serving incoming request + pub path_manager: PathManager, + /// tls settings: https redirection with 30x + #[builder(default)] + pub https_redirection: Option, + /// TLS settings: source meta for server cert, key, client ca cert + #[builder(default)] + pub crypto_source: Option, +} +impl<'a, T> BackendAppBuilder +where + T: CryptoSource, +{ + pub fn server_name(&mut self, server_name: impl Into>) -> &mut Self { + self.server_name = Some(server_name.to_server_name()); + self + } +} + +/// HashMap and some meta information for multiple Backend structs. 
+pub struct BackendAppManager<T>
+where
+  T: CryptoSource,
+{
+  /// HashMap of Backend structs, keyed by server name
+  pub apps: HashMap<ServerName, BackendApp<T>>,
+  /// for plaintext http
+  pub default_server_name: Option<ServerName>,
+}
+
+impl<T> Default for BackendAppManager<T>
+where
+  T: CryptoSource,
+{
+  fn default() -> Self {
+    Self {
+      apps: HashMap::<ServerName, BackendApp<T>>::default(),
+      default_server_name: None,
+    }
+  }
+}
+
+impl<T> TryFrom<&AppConfig<T>> for BackendApp<T>
+where
+  T: CryptoSource + Clone,
+{
+  type Error = RpxyError;
+
+  fn try_from(app_config: &AppConfig<T>) -> Result<Self, Self::Error> {
+    let mut backend_builder = BackendAppBuilder::default();
+    let path_manager = PathManager::try_from(app_config)?;
+    backend_builder
+      .app_name(app_config.app_name.clone())
+      .server_name(app_config.server_name.clone())
+      .path_manager(path_manager);
+    // TLS settings and build backend instance
+    let backend = if app_config.tls.is_none() {
+      backend_builder.build()?
+    } else {
+      let tls = app_config.tls.as_ref().unwrap();
+      backend_builder
+        .https_redirection(Some(tls.https_redirection))
+        .crypto_source(Some(tls.inner.clone()))
+        .build()?
+    };
+    Ok(backend)
+  }
+}
+
+impl<T> TryFrom<&AppConfigList<T>> for BackendAppManager<T>
+where
+  T: CryptoSource + Clone,
+{
+  type Error = RpxyError;
+
+  fn try_from(config_list: &AppConfigList<T>) -> Result<Self, Self::Error> {
+    let mut manager = Self::default();
+    for app_config in config_list.inner.iter() {
+      let backend: BackendApp<T> = BackendApp::try_from(app_config)?;
+      manager
+        .apps
+        .insert(app_config.server_name.clone().to_server_name(), backend);
+
+      info!(
+        "Registering application {} ({})",
+        &app_config.server_name, &app_config.app_name
+      );
+    }
+
+    // default backend application for plaintext http requests
+    if let Some(default_app_name) = &config_list.default_app {
+      let default_server_name = manager
+        .apps
+        .iter()
+        .filter(|(_k, v)| &v.app_name == default_app_name)
+        .map(|(_, v)| v.server_name.clone())
+        .collect::<Vec<_>>();
+
+      if !default_server_name.is_empty() {
+        info!(
+          "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).",
+          &default_app_name,
+          (&default_server_name[0]).try_into().unwrap_or_else(|_| "".to_string())
+        );
+
+        manager.default_server_name = Some(default_server_name[0].clone());
+      }
+    }
+    Ok(manager)
+  }
+}
diff --git a/rpxy-lib/src/backend/load_balance.rs b/rpxy-lib/src/backend/load_balance/load_balance_main.rs
similarity index 75%
rename from rpxy-lib/src/backend/load_balance.rs
rename to rpxy-lib/src/backend/load_balance/load_balance_main.rs
index 5d93f0a..0b3eff8 100644
--- a/rpxy-lib/src/backend/load_balance.rs
+++ b/rpxy-lib/src/backend/load_balance/load_balance_main.rs
@@ -1,6 +1,7 @@
+#[allow(unused)]
 #[cfg(feature = "sticky-cookie")]
 pub use super::{
-  load_balance_sticky::{LbStickyRoundRobin, LbStickyRoundRobinBuilder},
+  load_balance_sticky::{LoadBalanceSticky, LoadBalanceStickyBuilder},
   sticky_cookie::StickyCookie,
 };
 use derive_builder::Builder;
@@ -11,7 +12,7 @@ use std::sync::{
 };

 /// Constants to specify a load balance option
-pub(super) mod load_balance_options {
+pub mod load_balance_options {
   pub const FIX_TO_FIRST: &str = "none";
   pub const ROUND_ROBIN: &str = "round_robin";
   pub const RANDOM: &str = "random";
@@ -22,18 +23,18 @@ pub(super) mod load_balance_options {
 #[derive(Debug, Clone)]
 /// Pointer to the upstream serving the incoming request.
 /// If 'sticky cookie'-based LB is enabled and the cookie must be updated/created, the new cookie is also given.
-pub(super) struct PointerToUpstream {
+pub struct PointerToUpstream {
   pub ptr: usize,
-  pub context_lb: Option<LbContext>,
+  pub context: Option<LoadBalanceContext>,
 }
 /// Trait for LB
-pub(super) trait LbWithPointer {
-  fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream;
+pub(super) trait LoadBalanceWithPointer {
+  fn get_ptr(&self, req_info: Option<&LoadBalanceContext>) -> PointerToUpstream;
 }

 #[derive(Debug, Clone, Builder)]
 /// Round Robin LB object as a pointer to the current serving upstream destination
-pub struct LbRoundRobin {
+pub struct LoadBalanceRoundRobin {
   #[builder(default)]
   /// Pointer to the index of the last served upstream destination
   ptr: Arc<AtomicUsize>,
@@ -41,15 +42,15 @@
   /// Number of upstream destinations
   num_upstreams: usize,
 }
-impl LbRoundRobinBuilder {
+impl LoadBalanceRoundRobinBuilder {
   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self {
     self.num_upstreams = Some(*v);
     self
   }
 }
-impl LbWithPointer for LbRoundRobin {
+impl LoadBalanceWithPointer for LoadBalanceRoundRobin {
   /// Increment the count of upstreams served up to the max value
-  fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream {
+  fn get_ptr(&self, _info: Option<&LoadBalanceContext>) -> PointerToUpstream {
     // Get the current count of upstreams served
     let current_ptr = self.ptr.load(Ordering::Relaxed);

@@ -59,29 +60,29 @@
       // Clear the counter
       self.ptr.fetch_and(0, Ordering::Relaxed)
     };
-    PointerToUpstream { ptr, context_lb: None }
+    PointerToUpstream { ptr, context: None }
   }
 }

 #[derive(Debug, Clone, Builder)]
 /// Random LB object to keep the object of random pools
-pub struct LbRandom {
+pub struct LoadBalanceRandom {
   #[builder(setter(custom), default)]
   /// Number of upstream destinations
   num_upstreams: usize,
 }
-impl LbRandomBuilder {
+impl LoadBalanceRandomBuilder {
   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self {
     self.num_upstreams = Some(*v);
     self
   }
 }
-impl LbWithPointer for LbRandom {
+impl LoadBalanceWithPointer for LoadBalanceRandom {
   /// Returns a random index within the range
-  fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream {
+  fn get_ptr(&self, _info: Option<&LoadBalanceContext>) -> PointerToUpstream {
     let mut rng = rand::thread_rng();
     let ptr = rng.gen_range(0..self.num_upstreams);
-    PointerToUpstream { ptr, context_lb: None }
+    PointerToUpstream { ptr, context: None }
   }
 }
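The increment branch of the round-robin `get_ptr` falls outside the hunk above; the following self-contained sketch shows the same load-then-increment-or-reset counter technique (an assumption-level reconstruction, not a verbatim copy). Note that `fetch_add` and `fetch_and` both return the *previous* value, which is the index actually used:

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

/// Sketch of the round-robin pointer: load the current counter, then either
/// increment it or clear it once every upstream has been served.
fn next_ptr(ptr: &AtomicUsize, num_upstreams: usize) -> usize {
  let current = ptr.load(Ordering::Relaxed);
  if current < num_upstreams - 1 {
    ptr.fetch_add(1, Ordering::Relaxed) // returns `current`, leaves current + 1
  } else {
    ptr.fetch_and(0, Ordering::Relaxed) // returns `current`, clears to 0
  }
}

fn main() {
  let ptr = AtomicUsize::new(0);
  // With three upstreams, successive calls yield 0, 1, 2, 0, ...
  let served: Vec<usize> = (0..4).map(|_| next_ptr(&ptr, 3)).collect();
  assert_eq!(served, vec![0, 1, 2, 0]);
}
```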
@@ -91,12 +92,12 @@
 pub enum LoadBalance {
   /// Fix to the first upstream. Use if only one upstream destination is specified
   FixToFirst,
   /// Randomly choose one upstream server
-  Random(LbRandom),
+  Random(LoadBalanceRandom),
   /// Simple round robin without session persistence
-  RoundRobin(LbRoundRobin),
+  RoundRobin(LoadBalanceRoundRobin),
   #[cfg(feature = "sticky-cookie")]
   /// Round robin with session persistence using a cookie
-  StickyRoundRobin(LbStickyRoundRobin),
+  StickyRoundRobin(LoadBalanceSticky),
 }
 impl Default for LoadBalance {
   fn default() -> Self {
@@ -106,11 +107,11 @@ impl LoadBalance {
   /// Get the index of the upstream serving the incoming request
-  pub(super) fn get_context(&self, _context_to_lb: &Option<LbContext>) -> PointerToUpstream {
+  pub fn get_context(&self, _context_to_lb: &Option<LoadBalanceContext>) -> PointerToUpstream {
     match self {
       LoadBalance::FixToFirst => PointerToUpstream {
         ptr: 0usize,
-        context_lb: None,
+        context: None,
       },
       LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None),
       LoadBalance::Random(ptr) => ptr.get_ptr(None),
@@ -127,7 +128,7 @@
 /// Struct to handle the sticky cookie string,
 /// - passed from the Rp module (http handler) to the LB module: manipulated from the request, only StickyCookieValue exists.
 /// - passed from the LB module to the Rp module (http handler): will be inserted into the response, StickyCookieValue and Info exist.
-pub struct LbContext {
+pub struct LoadBalanceContext {
   #[cfg(feature = "sticky-cookie")]
   pub sticky_cookie: StickyCookie,
   #[cfg(not(feature = "sticky-cookie"))]
diff --git a/rpxy-lib/src/backend/load_balance_sticky.rs b/rpxy-lib/src/backend/load_balance/load_balance_sticky.rs
similarity index 85%
rename from rpxy-lib/src/backend/load_balance_sticky.rs
rename to rpxy-lib/src/backend/load_balance/load_balance_sticky.rs
index 32f4fe5..d7a9795 100644
--- a/rpxy-lib/src/backend/load_balance_sticky.rs
+++ b/rpxy-lib/src/backend/load_balance/load_balance_sticky.rs
@@ -1,5 +1,5 @@
 use super::{
-  load_balance::{LbContext, LbWithPointer, PointerToUpstream},
+  load_balance_main::{LoadBalanceContext, LoadBalanceWithPointer, PointerToUpstream},
   sticky_cookie::StickyCookieConfig,
   Upstream,
 };
@@ -16,7 +16,7 @@ use std::{

 #[derive(Debug, Clone, Builder)]
 /// Round Robin LB object in the sticky cookie manner
-pub struct LbStickyRoundRobin {
+pub struct LoadBalanceSticky {
   #[builder(default)]
   /// Pointer to the index of the last served upstream destination
   ptr: Arc<AtomicUsize>,
@@ -39,11 +39,13 @@ pub struct UpstreamMap {
   /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup
   upstream_id_map: HashMap<String, usize>,
 }
-impl LbStickyRoundRobinBuilder {
+impl LoadBalanceStickyBuilder {
+  /// Set the number of upstream destinations
   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self {
     self.num_upstreams = Some(*v);
     self
   }
+  /// Set the information to build the cookie that sticks clients to specific backends
   pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option<String>) -> &mut Self {
     self.sticky_config = Some(StickyCookieConfig {
       name: STICKY_COOKIE_NAME.to_string(), // TODO: make this configurable, e.g., via the config file
@@ -57,6 +59,7 @@
     });
     self
   }
+  /// Set the hashmaps: upstream_index_map and upstream_id_map
   pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self {
     let upstream_index_map: Vec<String> = upstream_vec
       .iter()
@@ -74,7 +77,8 @@
     self
   }
 }
-impl<'a> LbStickyRoundRobin {
+impl<'a> LoadBalanceSticky {
+  /// Increment the count of upstreams served up to the max value
   fn simple_increment_ptr(&self) -> usize {
     // Get the current count of upstreams served
    let current_ptr = self.ptr.load(Ordering::Relaxed);
@@ -96,8 +100,9 @@
     self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned())
   }
 }
-impl LbWithPointer for LbStickyRoundRobin {
-  fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream {
+impl LoadBalanceWithPointer for LoadBalanceSticky {
+  /// Get the pointer to the upstream server to serve the incoming request.
+  fn get_ptr(&self, req_info: Option<&LoadBalanceContext>) -> PointerToUpstream {
     // If the given context is None or invalid (not contained), get_ptr() is invoked to increment the pointer.
     // Otherwise, get the server index indicated by the server_id inside the cookie
     let ptr = match req_info {
@@ -121,12 +126,12 @@
     // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie).
     let upstream_id = self.get_server_id_from_index(ptr);
     let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap();
-    let new_context = Some(LbContext {
+    let new_context = Some(LoadBalanceContext {
       sticky_cookie: new_cookie,
     });
     PointerToUpstream {
       ptr,
-      context_lb: new_context,
+      context: new_context,
     }
   }
 }
diff --git a/rpxy-lib/src/backend/load_balance/mod.rs b/rpxy-lib/src/backend/load_balance/mod.rs
new file mode 100644
index 0000000..38d312b
--- /dev/null
+++ b/rpxy-lib/src/backend/load_balance/mod.rs
@@ -0,0 +1,43 @@
+mod load_balance_main;
+#[cfg(feature = "sticky-cookie")]
+mod load_balance_sticky;
+#[cfg(feature = "sticky-cookie")]
+mod sticky_cookie;
+
+use super::upstream::Upstream;
+use thiserror::Error;
+
+pub use load_balance_main::{
+  load_balance_options, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder,
+};
+#[cfg(feature = "sticky-cookie")]
+pub use load_balance_sticky::LoadBalanceStickyBuilder;
+#[cfg(feature = "sticky-cookie")]
+pub use sticky_cookie::{StickyCookie, StickyCookieValue};
+
+/// Result type for load balancing
+type LoadBalanceResult<T> = std::result::Result<T, LoadBalanceError>;
+/// Describes things that can go wrong in load balancing
+#[derive(Debug, Error)]
+pub enum LoadBalanceError {
+  // backend load balance errors
+  #[cfg(feature = "sticky-cookie")]
+  #[error("Failed to convert the sticky cookie to/from a string")]
+  FailedToConversionStickyCookie,
+
+  #[cfg(feature = "sticky-cookie")]
+  #[error("Invalid sticky cookie structure")]
+  InvalidStickyCookieStructure,
+
+  #[cfg(feature = "sticky-cookie")]
+  #[error("No sticky cookie value")]
+  NoStickyCookieValue,
+
+  #[cfg(feature = "sticky-cookie")]
+  #[error("Failed to convert the cookie into a string: no meta information")]
+  NoStickyCookieNoMetaInfo,
+
+  #[cfg(feature = "sticky-cookie")]
+  #[error("Failed to build sticky cookie from config")]
+  FailedToBuildStickyCookie,
+}
diff --git a/rpxy-lib/src/backend/sticky_cookie.rs b/rpxy-lib/src/backend/load_balance/sticky_cookie.rs
similarity index 84%
rename from rpxy-lib/src/backend/sticky_cookie.rs
rename to rpxy-lib/src/backend/load_balance/sticky_cookie.rs
index 998426b..28572b5 100644
--- a/rpxy-lib/src/backend/sticky_cookie.rs
+++ b/rpxy-lib/src/backend/load_balance/sticky_cookie.rs
@@ -1,8 +1,7 @@
-use std::borrow::Cow;
-
-use crate::error::*;
+use super::{LoadBalanceError, LoadBalanceResult};
 use chrono::{TimeZone, Utc};
 use derive_builder::Builder;
+use std::borrow::Cow;

 #[derive(Debug, Clone, Builder)]
 /// Cookie value only, used for COOKIE in req
@@ -25,18 +24,16 @@ impl<'a> StickyCookieValueBuilder {
   }
 }
 impl StickyCookieValue {
-  pub fn try_from(value: &str, expected_name: &str) -> Result<Self> {
+  pub fn try_from(value: &str, expected_name: &str) -> LoadBalanceResult<Self> {
     if !value.starts_with(expected_name) {
-      return Err(RpxyError::LoadBalance(
-        "Failed to cookie conversion from string".to_string(),
-      ));
+      return Err(LoadBalanceError::FailedToConversionStickyCookie);
     };
     let kv = value.split('=').map(|v| v.trim()).collect::<Vec<&str>>();
     if kv.len() != 2 {
-      return Err(RpxyError::LoadBalance("Invalid cookie structure".to_string()));
+      return Err(LoadBalanceError::InvalidStickyCookieStructure);
     };
     if kv[1].is_empty() {
-      return Err(RpxyError::LoadBalance("No sticky cookie value".to_string()));
+      return Err(LoadBalanceError::NoStickyCookieValue);
     }
     Ok(StickyCookieValue {
       name: expected_name.to_string(),
@@ -88,10 +85,12 @@ pub struct StickyCookie {
 }

 impl<'a> StickyCookieBuilder {
+  /// Set the value of the sticky cookie
   pub fn value(&mut self, n: impl Into<Cow<'a, str>>, v: impl Into<Cow<'a, str>>) -> &mut Self {
     self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap());
     self
   }
+  /// Set the meta information of the sticky cookie
   pub fn info(
     &mut self,
     domain: impl Into<Cow<'a, str>>,
@@ -110,17 +109,15 @@
   }
 }

 impl TryInto<String> for StickyCookie {
-  type Error = RpxyError;
+  type Error = LoadBalanceError;

-  fn try_into(self) -> Result<String> {
+  fn try_into(self) -> LoadBalanceResult<String> {
     if self.info.is_none() {
-      return Err(RpxyError::LoadBalance(
-        "Failed to cookie conversion into string: no meta information".to_string(),
-      ));
+      return Err(LoadBalanceError::NoStickyCookieNoMetaInfo);
     }
     let info = self.info.unwrap();
     let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else {
-      return Err(RpxyError::LoadBalance("Failed to cookie conversion into string".to_string()));
+      return Err(LoadBalanceError::FailedToConversionStickyCookie);
     };
     let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string();
     let max_age = info.expires - Utc::now().timestamp();
@@ -144,12 +141,12 @@ pub struct StickyCookieConfig {
   pub duration: i64,
 }
 impl<'a> StickyCookieConfig {
-  pub fn build_sticky_cookie(&self, v: impl Into<Cow<'a, str>>) -> Result<StickyCookie> {
+  pub fn build_sticky_cookie(&self, v: impl Into<Cow<'a, str>>) -> LoadBalanceResult<StickyCookie> {
     StickyCookieBuilder::default()
       .value(self.name.clone(), v)
       .info(&self.domain, &self.path, self.duration)
       .build()
-      .map_err(|_| RpxyError::LoadBalance("Failed to build sticky cookie from config".to_string()))
+      .map_err(|_| LoadBalanceError::FailedToBuildStickyCookie)
   }
 }

@@ -167,7 +164,7 @@ mod tests {
       duration: 100,
     };
     let expires_unix = Utc::now().timestamp() + 100;
-    let sc_string: Result<String> = config.build_sticky_cookie("test_value").unwrap().try_into();
+    let sc_string: LoadBalanceResult<String> = config.build_sticky_cookie("test_value").unwrap().try_into();
     let expires_date_string = Utc
       .timestamp_opt(expires_unix, 0)
       .unwrap()
@@ -194,7 +191,7 @@ mod tests {
         path: "/path".to_string(),
       }),
     };
-    let sc_string: Result<String> = sc.try_into();
+    let sc_string: LoadBalanceResult<String> = sc.try_into();
     let max_age = 1686221173i64 - Utc::now().timestamp();
     assert!(sc_string.is_ok());
     assert_eq!(
diff --git a/rpxy-lib/src/backend/mod.rs b/rpxy-lib/src/backend/mod.rs
index 73c4466..097810a 100644
--- a/rpxy-lib/src/backend/mod.rs
+++ b/rpxy-lib/src/backend/mod.rs
@@ -1,77 +1,14 @@
+mod backend_main;
 mod load_balance;
-#[cfg(feature = "sticky-cookie")]
-mod load_balance_sticky;
-#[cfg(feature = "sticky-cookie")]
-mod sticky_cookie;
 mod upstream;
 mod upstream_opts;

 #[cfg(feature = "sticky-cookie")]
-pub use self::sticky_cookie::{StickyCookie, StickyCookieValue};
-pub use self::{
-  load_balance::{LbContext, LoadBalance},
-  upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder},
+pub(crate) use self::load_balance::{StickyCookie, StickyCookieValue};
+#[allow(unused)]
+pub(crate) use self::{
+  load_balance::{LoadBalance, LoadBalanceContext},
+  upstream::{PathManager, Upstream, UpstreamCandidates},
   upstream_opts::UpstreamOption,
 };
-use crate::{
-  certs::CryptoSource,
-  utils::{BytesName, PathNameBytesExp, ServerNameBytesExp},
-};
-use derive_builder::Builder;
-use rustc_hash::FxHashMap as HashMap;
-use std::borrow::Cow;
-
-/// Struct serving information to route incoming connections, like the server name to be handled and tls certs/keys settings.
-#[derive(Builder)]
-pub struct Backend<T>
-where
-  T: CryptoSource,
-{
-  #[builder(setter(into))]
-  /// backend application name, e.g., app1
-  pub app_name: String,
-  #[builder(setter(custom))]
-  /// server name, e.g., example.com, as a String in ascii lower case
-  pub server_name: String,
-  /// struct of the reverse proxy serving incoming requests
-  pub reverse_proxy: ReverseProxy,
-
-  /// tls settings: https redirection with 30x
-  #[builder(default)]
-  pub https_redirection: Option<bool>,
-
-  /// TLS settings: source meta for server cert, key, client ca cert
-  #[builder(default)]
-  pub crypto_source: Option<T>,
-}
-impl<'a, T> BackendBuilder<T>
-where
-  T: CryptoSource,
-{
-  pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self {
-    self.server_name = Some(server_name.into().to_ascii_lowercase());
-    self
-  }
-}
-
-/// HashMap and some meta information for multiple Backend structs.
-pub struct Backends<T>
-where
-  T: CryptoSource,
-{
-  pub apps: HashMap<ServerNameBytesExp, Backend<T>>, // looked up by the host extracted via hyper::Uri
-  pub default_server_name_bytes: Option<ServerNameBytesExp>, // for plaintext http
-}
-
-impl<T> Backends<T>
-where
-  T: CryptoSource,
-{
-  #[allow(clippy::new_without_default)]
-  pub fn new() -> Self {
-    Backends {
-      apps: HashMap::<ServerNameBytesExp, Backend<T>>::default(),
-      default_server_name_bytes: None,
-    }
-  }
-}
+pub(crate) use backend_main::{BackendApp, BackendAppBuilderError, BackendAppManager};
diff --git a/rpxy-lib/src/backend/upstream.rs b/rpxy-lib/src/backend/upstream.rs
index 2bfd2d6..702be29 100644
--- a/rpxy-lib/src/backend/upstream.rs
+++ b/rpxy-lib/src/backend/upstream.rs
@@ -1,8 +1,17 @@
 #[cfg(feature = "sticky-cookie")]
-use super::load_balance::LbStickyRoundRobinBuilder;
-use super::load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance};
-use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption};
-use crate::log::*;
+use super::load_balance::LoadBalanceStickyBuilder;
+use super::load_balance::{
+  load_balance_options as lb_opts, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder,
+};
+// use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption};
+use super::upstream_opts::UpstreamOption;
+use crate::{
+  crypto::CryptoSource,
+  error::RpxyError,
+  globals::{AppConfig, UpstreamUri},
+  log::*,
+  name_exp::{ByteName, PathName},
+};
 #[cfg(feature = "sticky-cookie")]
 use base64::{engine::general_purpose, Engine as _};
 use derive_builder::Builder;
@@ -10,38 +19,79 @@ use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
 #[cfg(feature = "sticky-cookie")]
 use sha2::{Digest, Sha256};
 use std::borrow::Cow;

 #[derive(Debug, Clone)]
-pub struct ReverseProxy {
-  pub upstream: HashMap<PathNameBytesExp, UpstreamGroup>, // TODO: Not sure a HashMap is the right structure here; doing longest-prefix matching via max_by_key also seems wasteful...
-}
+/// Handler for a given path to route incoming requests to the path's corresponding upstream server(s).
+pub struct PathManager {
+  /// HashMap of upstream candidate server info, keyed by path name
+  /// TODO: Not sure a HashMap is the right structure here; doing longest-prefix matching via max_by_key also seems wasteful...
+  inner: HashMap<PathName, UpstreamCandidates>,
+}
-impl ReverseProxy {
-  /// Get an appropriate upstream destination for the given path string.
-  pub fn get<'a>(&self, path_str: impl Into<Cow<'a, str>>) -> Option<&UpstreamGroup> {
-    // A trie with longest-prefix matching could be used here, but route entries are expected to be few,
-    // so this is sufficient cost-wise.
-    let path_bytes = &path_str.to_path_name_vec();
+impl<T> TryFrom<&AppConfig<T>> for PathManager
+where
+  T: CryptoSource,
+{
+  type Error = RpxyError;
+  fn try_from(app_config: &AppConfig<T>) -> Result<Self, Self::Error> {
+    let mut inner: HashMap<PathName, UpstreamCandidates> = HashMap::default();
+
+    app_config.reverse_proxy.iter().for_each(|rpc| {
+      let upstream_vec: Vec<Upstream> = rpc.upstream.iter().map(Upstream::from).collect();
+      let elem = UpstreamCandidatesBuilder::default()
+        .upstream(&upstream_vec)
+        .path(&rpc.path)
+        .replace_path(&rpc.replace_path)
+        .load_balance(&rpc.load_balance, &upstream_vec, &app_config.server_name, &rpc.path)
+        .options(&rpc.upstream_options)
+        .build()
+        .unwrap();
+      inner.insert(elem.path.clone(), elem);
+    });
+
+    if app_config.reverse_proxy.iter().filter(|rpc| rpc.path.is_none()).count() >= 2 {
+      error!("Multiple default reverse proxy settings");
+      return Err(RpxyError::InvalidReverseProxyConfig);
+    }
+
+    if !(inner.iter().all(|(_, elem)| {
+      !(elem.options.contains(&UpstreamOption::ForceHttp11Upstream) && elem.options.contains(&UpstreamOption::ForceHttp2Upstream))
+    })) {
+      error!("Only one of force_http11 or force_http2 can be enabled");
+      return Err(RpxyError::InvalidUpstreamOptionSetting);
+    }
+
+    Ok(PathManager { inner })
+  }
+}
+
+impl PathManager {
+  /// Get appropriate upstream destinations for the given path string.
+  /// A trie with longest-prefix matching could be used here, but route entries are expected to be few,
+  /// so this should be sufficient cost-wise.
+  pub fn get<'a>(&self, path_str: impl Into<Cow<'a, str>>) -> Option<&UpstreamCandidates> {
+    let path_name = &path_str.to_path_name();
     let matched_upstream = self
-      .upstream
+      .inner
       .iter()
       .filter(|(route_bytes, _)| {
-        match path_bytes.starts_with(route_bytes) {
+        match path_name.starts_with(route_bytes) {
           true => {
             route_bytes.len() == 1 // route = '/', i.e., default
-              || match path_bytes.get(route_bytes.len()) {
-                None => true,      // exact case
-                Some(p) => p == &b'/', // sub-path case
-              }
+              || match path_name.get(route_bytes.len()) {
+                None => true,      // exact case
+                Some(p) => p == &b'/', // sub-path case
+              }
           }
           _ => false,
         }
       })
      .max_by_key(|(route_bytes, _)| route_bytes.len());
-    if let Some((_path, u)) = matched_upstream {
+    if let Some((path, u)) = matched_upstream {
       debug!(
         "Found upstream: {:?}",
-        String::from_utf8(_path.0.clone()).unwrap_or_else(|_| "".to_string())
+        path.try_into().unwrap_or_else(|_| "".to_string())
       );
       Some(u)
     } else {
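The boundary rule implemented by `PathManager::get` above, restated as a standalone sketch: a configured route matches either exactly or at a `/` boundary, and the longest match wins (via `max_by_key` in the code above):

```rust
/// "/path" matches "/path" and "/path/sub", but not "/pathology".
fn route_matches(route: &[u8], path: &[u8]) -> bool {
  path.starts_with(route)
    && (route.len() == 1                         // route is "/", i.e., the default
      || path.len() == route.len()               // exact match
      || path.get(route.len()) == Some(&b'/'))   // sub-path at a '/' boundary
}

fn main() {
  assert!(route_matches(b"/path", b"/path"));
  assert!(route_matches(b"/path", b"/path/sub"));
  assert!(!route_matches(b"/path", b"/pathology"));
  assert!(route_matches(b"/", b"/anything"));
}
```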
@@ -56,6 +106,13 @@ pub struct Upstream {
   /// Base uri without specific path
   pub uri: hyper::Uri,
 }
+impl From<&UpstreamUri> for Upstream {
+  fn from(value: &UpstreamUri) -> Self {
+    Self {
+      uri: value.inner.clone(),
+    }
+  }
+}
 impl Upstream {
   #[cfg(feature = "sticky-cookie")]
   /// Hashing uri with index to avoid collision
@@ -69,51 +126,54 @@
 }
 #[derive(Debug, Clone, Builder)]
 /// Struct serving multiple upstream servers for, e.g., load balancing.
-pub struct UpstreamGroup {
+pub struct UpstreamCandidates {
   #[builder(setter(custom))]
   /// Upstream server(s)
-  pub upstream: Vec<Upstream>,
+  pub inner: Vec<Upstream>,
+
   #[builder(setter(custom), default)]
-  /// Path like "/path" in [[PathNameBytesExp]] associated with the upstream server(s)
-  pub path: PathNameBytesExp,
+  /// Path like "/path" in [[PathName]] associated with the upstream server(s)
+  pub path: PathName,
+
   #[builder(setter(custom), default)]
-  /// Path in [[PathNameBytesExp]] that will be used to replace the "path" part of the incoming url
-  pub replace_path: Option<PathNameBytesExp>,
+  /// Path in [[PathName]] that will be used to replace the "path" part of the incoming url
+  pub replace_path: Option<PathName>,
+
   #[builder(setter(custom), default)]
   /// Load balancing option
-  pub lb: LoadBalance,
+  pub load_balance: LoadBalance,
+
   #[builder(setter(custom), default)]
   /// Activated upstream options defined in [[UpstreamOption]]
-  pub opts: HashSet<UpstreamOption>,
+  pub options: HashSet<UpstreamOption>,
 }

-impl UpstreamGroupBuilder {
+impl UpstreamCandidatesBuilder {
+  /// Set the upstream server(s)
   pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self {
-    self.upstream = Some(upstream_vec.to_vec());
+    self.inner = Some(upstream_vec.to_vec());
     self
   }
+  /// Set the path like "/path" in [[PathName]] associated with the upstream server(s); default is "/"
   pub fn path(&mut self, v: &Option<String>) -> &mut Self {
     let path = match v {
-      Some(p) => p.to_path_name_vec(),
-      None => "/".to_path_name_vec(),
+      Some(p) => p.to_path_name(),
+      None => "/".to_path_name(),
     };
     self.path = Some(path);
     self
   }
+  /// Set the path in [[PathName]] that will be used to replace the "path" part of the incoming url
   pub fn replace_path(&mut self, v: &Option<String>) -> &mut Self {
-    self.replace_path = Some(
-      v.to_owned()
-        .as_ref()
-        .map_or_else(|| None, |v| Some(v.to_path_name_vec())),
-    );
+    self.replace_path = Some(v.to_owned().as_ref().map_or_else(|| None, |v| Some(v.to_path_name())));
     self
   }
-  pub fn lb(
+  /// Set the load balancing option
+  pub fn load_balance(
     &mut self,
     v: &Option<String>,
     // upstream_num: &usize,
-    upstream_vec: &Vec<Upstream>,
+    upstream_vec: &[Upstream],
     _server_name: &str,
     _path_opt: &Option<String>,
   ) -> &mut Self {
@@ -121,16 +181,21 @@
     let lb = if let Some(x) = v {
       match x.as_str() {
         lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst,
-        lb_opts::RANDOM => LoadBalance::Random(LbRandomBuilder::default().num_upstreams(upstream_num).build().unwrap()),
+        lb_opts::RANDOM => LoadBalance::Random(
+          LoadBalanceRandomBuilder::default()
+            .num_upstreams(upstream_num)
+            .build()
+            .unwrap(),
+        ),
         lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin(
-          LbRoundRobinBuilder::default()
+          LoadBalanceRoundRobinBuilder::default()
             .num_upstreams(upstream_num)
             .build()
             .unwrap(),
         ),
         #[cfg(feature = "sticky-cookie")]
         lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin(
-          LbStickyRoundRobinBuilder::default()
+          LoadBalanceStickyBuilder::default()
            .num_upstreams(upstream_num)
            .sticky_config(_server_name, _path_opt)
            .upstream_maps(upstream_vec) // TODO:
@@ -145,10 +210,11 @@
     } else {
       LoadBalance::default()
     };
-    self.lb = Some(lb);
+    self.load_balance = Some(lb);
     self
   }
-  pub fn opts(&mut self, v: &Option<Vec<String>>) -> &mut Self {
+  /// Set the activated upstream options defined in [[UpstreamOption]]
+  pub fn options(&mut self, v: &Option<Vec<String>>) -> &mut Self {
     let opts = if let Some(opts) = v {
       opts
         .iter()
        .filter_map(|str| UpstreamOption::try_from(str.as_str()).ok())
        .collect::<HashSet<UpstreamOption>>()
     } else {
       Default::default()
     };
-    self.opts = Some(opts);
+    self.options = Some(opts);
     self
   }
 }
{ +impl UpstreamCandidates { /// Get an enabled option of load balancing [[LoadBalance]] - pub fn get(&self, context_to_lb: &Option) -> (Option<&Upstream>, Option) { - let pointer_to_upstream = self.lb.get_context(context_to_lb); + pub fn get(&self, context_to_lb: &Option) -> (Option<&Upstream>, Option) { + let pointer_to_upstream = self.load_balance.get_context(context_to_lb); debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); - debug!("Context to LB (Cookie in Req): {:?}", context_to_lb); - debug!( - "Context from LB (Set-Cookie in Res): {:?}", - pointer_to_upstream.context_lb - ); - ( - self.upstream.get(pointer_to_upstream.ptr), - pointer_to_upstream.context_lb, - ) + debug!("Context to LB (Cookie in Request): {:?}", context_to_lb); + debug!("Context from LB (Set-Cookie in Response): {:?}", pointer_to_upstream.context); + (self.inner.get(pointer_to_upstream.ptr), pointer_to_upstream.context) } } diff --git a/rpxy-lib/src/backend/upstream_opts.rs b/rpxy-lib/src/backend/upstream_opts.rs index a96bb58..68309ca 100644 --- a/rpxy-lib/src/backend/upstream_opts.rs +++ b/rpxy-lib/src/backend/upstream_opts.rs @@ -1,22 +1,30 @@ use crate::error::*; +/// Options for request message to be sent to upstream. #[derive(Debug, Clone, Hash, Eq, PartialEq)] pub enum UpstreamOption { - OverrideHost, + /// Keep original host header, which is prioritized over SetUpstreamHost + KeepOriginalHost, + /// Overwrite host header with upstream hostname + SetUpstreamHost, + /// Add upgrade-insecure-requests header UpgradeInsecureRequests, + /// Force HTTP/1.1 upstream ForceHttp11Upstream, + /// Force HTTP/2 upstream ForceHttp2Upstream, // TODO: Add more options for header override } impl TryFrom<&str> for UpstreamOption { type Error = RpxyError; - fn try_from(val: &str) -> Result { + fn try_from(val: &str) -> RpxyResult { match val { - "override_host" => Ok(Self::OverrideHost), + "keep_original_host" => Ok(Self::KeepOriginalHost), + "set_upstream_host" => Ok(Self::SetUpstreamHost), "upgrade_insecure_requests" => Ok(Self::UpgradeInsecureRequests), "force_http11_upstream" => Ok(Self::ForceHttp11Upstream), "force_http2_upstream" => Ok(Self::ForceHttp2Upstream), - _ => Err(RpxyError::Other(anyhow!("Unsupported header option"))), + _ => Err(RpxyError::UnsupportedUpstreamOption), } } } diff --git a/rpxy-lib/src/constants.rs b/rpxy-lib/src/constants.rs index ebec1fc..acc9381 100644 --- a/rpxy-lib/src/constants.rs +++ b/rpxy-lib/src/constants.rs @@ -4,8 +4,8 @@ pub const RESPONSE_HEADER_SERVER: &str = "rpxy"; pub const TCP_LISTEN_BACKLOG: u32 = 1024; // pub const HTTP_LISTEN_PORT: u16 = 8080; // pub const HTTPS_LISTEN_PORT: u16 = 8443; -pub const PROXY_TIMEOUT_SEC: u64 = 60; -pub const UPSTREAM_TIMEOUT_SEC: u64 = 60; +pub const PROXY_IDLE_TIMEOUT_SEC: u64 = 20; +pub const UPSTREAM_IDLE_TIMEOUT_SEC: u64 = 20; pub const TLS_HANDSHAKE_TIMEOUT_SEC: u64 = 15; // default as with firefox browser pub const MAX_CLIENTS: usize = 512; pub const MAX_CONCURRENT_STREAMS: u32 = 64; diff --git a/rpxy-lib/src/count.rs b/rpxy-lib/src/count.rs new file mode 100644 index 0000000..2ca4028 --- /dev/null +++ b/rpxy-lib/src/count.rs @@ -0,0 +1,31 @@ +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +#[derive(Debug, Clone, Default)] +/// Counter for serving requests +pub struct RequestCount(Arc); + +impl RequestCount { + pub fn current(&self) -> usize { + self.0.load(Ordering::Relaxed) + } + + pub fn increment(&self) -> usize { + self.0.fetch_add(1, Ordering::Relaxed) + } + + pub fn decrement(&self) -> usize { + let
mut count; + while { + count = self.0.load(Ordering::Relaxed); + count > 0 + && self + .0 + .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) + != Ok(count) + } {} + count + } +} diff --git a/rpxy-lib/src/certs.rs b/rpxy-lib/src/crypto/certs.rs similarity index 100% rename from rpxy-lib/src/certs.rs rename to rpxy-lib/src/crypto/certs.rs diff --git a/rpxy-lib/src/crypto/mod.rs b/rpxy-lib/src/crypto/mod.rs new file mode 100644 index 0000000..7b8935c --- /dev/null +++ b/rpxy-lib/src/crypto/mod.rs @@ -0,0 +1,36 @@ +mod certs; +mod service; + +use crate::{ + backend::BackendAppManager, + constants::{CERTS_WATCH_DELAY_SECS, LOAD_CERTS_ONLY_WHEN_UPDATED}, + error::RpxyResult, +}; +use hot_reload::{ReloaderReceiver, ReloaderService}; +use service::CryptoReloader; +use std::sync::Arc; + +pub use certs::{CertsAndKeys, CryptoSource}; +pub use service::{ServerCrypto, ServerCryptoBase, SniServerCryptoMap}; + +/// Result type inner of certificate reloader service +type ReloaderServiceResultInner = ( + ReloaderService, ServerCryptoBase>, + ReloaderReceiver, +); +/// Build certificate reloader service +pub(crate) async fn build_cert_reloader( + app_manager: &Arc>, +) -> RpxyResult> +where + T: CryptoSource + Clone + Send + Sync + 'static, +{ + let (cert_reloader_service, cert_reloader_rx) = ReloaderService::< + service::CryptoReloader, + service::ServerCryptoBase, + >::new( + app_manager, CERTS_WATCH_DELAY_SECS, !LOAD_CERTS_ONLY_WHEN_UPDATED + ) + .await?; + Ok((cert_reloader_service, cert_reloader_rx)) +} diff --git a/rpxy-lib/src/proxy/crypto_service.rs b/rpxy-lib/src/crypto/service.rs similarity index 92% rename from rpxy-lib/src/proxy/crypto_service.rs rename to rpxy-lib/src/crypto/service.rs index ae0f993..8eda27a 100644 --- a/rpxy-lib/src/proxy/crypto_service.rs +++ b/rpxy-lib/src/crypto/service.rs @@ -1,9 +1,5 @@ -use crate::{ - certs::{CertsAndKeys, CryptoSource}, - globals::Globals, - log::*, - utils::ServerNameBytesExp, -}; +use super::certs::{CertsAndKeys, CryptoSource}; +use crate::{backend::BackendAppManager, log::*, name_exp::ServerName}; use async_trait::async_trait; use hot_reload::*; use rustc_hash::FxHashMap as HashMap; @@ -16,15 +12,17 @@ pub struct CryptoReloader where T: CryptoSource, { - globals: Arc>, + inner: Arc>, } -pub type SniServerCryptoMap = HashMap>; +/// SNI to ServerConfig map type +pub type SniServerCryptoMap = HashMap>; +/// SNI to ServerConfig map pub struct ServerCrypto { // For Quic/HTTP3, only servers with no client authentication #[cfg(feature = "http3-quinn")] pub inner_global_no_client_auth: Arc, - #[cfg(feature = "http3-s2n")] + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] pub inner_global_no_client_auth: s2n_quic_rustls::Server, // For TLS over TCP/HTTP2 and 1.1, map of SNI to server_crypto for all given servers pub inner_local_map: Arc, @@ -33,7 +31,7 @@ pub struct ServerCrypto { /// Reloader target for the certificate reloader service #[derive(Debug, PartialEq, Eq, Clone, Default)] pub struct ServerCryptoBase { - inner: HashMap, + inner: HashMap, } #[async_trait] @@ -41,17 +39,15 @@ impl Reload for CryptoReloader where T: CryptoSource + Sync + Send, { - type Source = Arc>; + type Source = Arc>; async fn new(source: &Self::Source) -> Result> { - Ok(Self { - globals: source.clone(), - }) + Ok(Self { inner: source.clone() }) } async fn reload(&self) -> Result, ReloaderError> { let mut certs_and_keys_map = ServerCryptoBase::default(); - for (server_name_bytes_exp, backend) in self.globals.backends.apps.iter() { + for 
(server_name_bytes_exp, backend) in self.inner.apps.iter() { if let Some(crypto_source) = &backend.crypto_source { let certs_and_keys = crypto_source .read() @@ -78,7 +74,7 @@ impl TryInto> for &ServerCryptoBase { Ok(Arc::new(ServerCrypto { #[cfg(feature = "http3-quinn")] inner_global_no_client_auth: Arc::new(server_crypto_global), - #[cfg(feature = "http3-s2n")] + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] inner_global_no_client_auth: server_crypto_global, inner_local_map: Arc::new(server_crypto_local_map), })) @@ -204,7 +200,7 @@ impl ServerCryptoBase { Ok(server_crypto_global) } - #[cfg(feature = "http3-s2n")] + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] fn build_server_crypto_global(&self) -> Result> { let mut resolver_global = s2n_quic_rustls::rustls::server::ResolvesServerCertUsingSni::new(); @@ -245,7 +241,7 @@ impl ServerCryptoBase { } } -#[cfg(feature = "http3-s2n")] +#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] /// This is workaround for the version difference between rustls and s2n-quic-rustls fn parse_server_certs_and_keys_s2n( certs_and_keys: &CertsAndKeys, diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index c672682..3b1afc9 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -1,86 +1,101 @@ -pub use anyhow::{anyhow, bail, ensure, Context}; -use std::io; use thiserror::Error; -pub type Result = std::result::Result; +pub type RpxyResult = std::result::Result; /// Describes things that can go wrong in the Rpxy #[derive(Debug, Error)] pub enum RpxyError { - #[error("Proxy build error: {0}")] - ProxyBuild(#[from] crate::proxy::ProxyBuilderError), + // general errors + #[error("IO error: {0}")] + Io(#[from] std::io::Error), - #[error("Backend build error: {0}")] - BackendBuild(#[from] crate::backend::BackendBuilderError), + // TLS errors + #[error("Failed to build TLS acceptor: {0}")] + FailedToTlsHandshake(String), + #[error("No server name in ClientHello")] + NoServerNameInClientHello, + #[error("No TLS serving app: {0}")] + NoTlsServingApp(String), + #[error("Failed to update server crypto: {0}")] + FailedToUpdateServerCrypto(String), + #[error("No server crypto: {0}")] + NoServerCrypto(String), - #[error("MessageHandler build error: {0}")] - HandlerBuild(#[from] crate::handler::HttpMessageHandlerBuilderError), + // hyper errors + #[error("hyper body manipulation error: {0}")] + HyperBodyManipulationError(String), + #[error("New closed in incoming-like")] + HyperIncomingLikeNewClosed, + #[error("New body write aborted")] + HyperNewBodyWriteAborted, + #[error("Hyper error in serving request or response body type: {0}")] + HyperBodyError(#[from] hyper::Error), - #[error("Config builder error: {0}")] - ConfigBuild(&'static str), - - #[error("Http Message Handler Error: {0}")] - Handler(&'static str), - - #[error("Cache Error: {0}")] - Cache(&'static str), - - #[error("Http Request Message Error: {0}")] - Request(&'static str), - - #[error("TCP/UDP Proxy Layer Error: {0}")] - Proxy(String), - - #[allow(unused)] - #[error("LoadBalance Layer Error: {0}")] - LoadBalance(String), - - #[error("I/O Error: {0}")] - Io(#[from] io::Error), - - // #[error("Toml Deserialization Error")] - // TomlDe(#[from] toml::de::Error), - #[cfg(feature = "http3-quinn")] - #[error("Quic Connection Error [quinn]: {0}")] - QuicConn(#[from] quinn::ConnectionError), - - #[cfg(feature = "http3-s2n")] - #[error("Quic Connection Error [s2n-quic]: {0}")] - QUicConn(#[from] s2n_quic::connection::Error), + // http/3 errors + 
#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + #[error("H3 error: {0}")] + H3Error(#[from] h3::Error), + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + #[error("Exceeds max request body size for HTTP/3")] + H3TooLargeBody, #[cfg(feature = "http3-quinn")] - #[error("H3 Error [quinn]: {0}")] - H3(#[from] h3::Error), + #[error("Invalid rustls TLS version: {0}")] + QuinnInvalidTlsProtocolVersion(String), + #[cfg(feature = "http3-quinn")] + #[error("Quinn connection error: {0}")] + QuinnConnectionFailed(#[from] quinn::ConnectionError), - #[cfg(feature = "http3-s2n")] - #[error("H3 Error [s2n-quic]: {0}")] - H3(#[from] s2n_quic_h3::h3::Error), + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] + #[error("s2n-quic validation error: {0}")] + S2nQuicValidationError(#[from] s2n_quic_core::transport::parameters::ValidationError), + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] + #[error("s2n-quic connection error: {0}")] + S2nQuicConnectionError(#[from] s2n_quic_core::connection::Error), + #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] + #[error("s2n-quic start error: {0}")] + S2nQuicStartError(#[from] s2n_quic::provider::StartError), - #[error("rustls Connection Error: {0}")] - Rustls(#[from] rustls::Error), + // certificate reloader errors + #[error("No certificate reloader when building a proxy for TLS")] + NoCertificateReloader, + #[error("Certificate reload error: {0}")] + CertificateReloadError(#[from] hot_reload::ReloaderError), - #[error("Hyper Error: {0}")] - Hyper(#[from] hyper::Error), + // backend errors + #[error("Invalid reverse proxy setting")] + InvalidReverseProxyConfig, + #[error("Invalid upstream option setting")] + InvalidUpstreamOptionSetting, + #[error("Failed to build backend app: {0}")] + FailedToBuildBackendApp(#[from] crate::backend::BackendAppBuilderError), - #[error("Hyper Http Error: {0}")] - HyperHttp(#[from] hyper::http::Error), + // Handler errors + #[error("Failed to build message handler: {0}")] + FailedToBuildMessageHandler(#[from] crate::message_handler::HttpMessageHandlerBuilderError), + #[error("Failed to upgrade request: {0}")] + FailedToUpgradeRequest(String), + #[error("Failed to upgrade response: {0}")] + FailedToUpgradeResponse(String), + #[error("Failed to copy bidirectional for upgraded connections: {0}")] + FailedToCopyBidirectional(String), - #[error("Hyper Http HeaderValue Error: {0}")] - HyperHeaderValue(#[from] hyper::header::InvalidHeaderValue), + // Forwarder errors + #[error("Failed to build forwarder: {0}")] + FailedToBuildForwarder(String), + #[error("Failed to fetch from upstream: {0}")] + FailedToFetchFromUpstream(String), - #[error("Hyper Http HeaderName Error: {0}")] - HyperHeaderName(#[from] hyper::header::InvalidHeaderName), + // Upstream connection setting errors + #[error("Unsupported upstream option")] + UnsupportedUpstreamOption, - #[error(transparent)] - Other(#[from] anyhow::Error), -} - -#[allow(dead_code)] -#[derive(Debug, Error, Clone)] -pub enum ClientCertsError { - #[error("TLS Client Certificate is Required for Given SNI: {0}")] - ClientCertRequired(String), - - #[error("Inconsistent TLS Client Certificate for Given SNI: {0}")] - InconsistentClientCert(String), + // Cache error map + #[cfg(feature = "cache")] + #[error("Cache error: {0}")] + CacheError(#[from] crate::forwarder::CacheError), + + // Others + #[error("Infallible")] + Infallible(#[from] std::convert::Infallible), } diff --git a/rpxy-lib/src/forwarder/cache/cache_error.rs 
b/rpxy-lib/src/forwarder/cache/cache_error.rs new file mode 100644 index 0000000..341c928 --- /dev/null +++ b/rpxy-lib/src/forwarder/cache/cache_error.rs @@ -0,0 +1,47 @@ +use thiserror::Error; + +pub(crate) type CacheResult = std::result::Result; + +/// Describes things that can go wrong in the Rpxy +#[derive(Debug, Error)] +pub enum CacheError { + // Cache errors, + #[error("Invalid null request and/or response")] + NullRequestOrResponse, + + #[error("Failed to acquire mutex lock for cache")] + FailedToAcquiredMutexLockForCache, + + #[error("Failed to acquire mutex lock for check")] + FailedToAcquiredMutexLockForCheck, + + #[error("Failed to create file cache")] + FailedToCreateFileCache, + + #[error("Failed to write file cache")] + FailedToWriteFileCache, + + #[error("Failed to open cache file")] + FailedToOpenCacheFile, + + #[error("Too large to cache")] + TooLargeToCache, + + #[error("Failed to cache bytes: {0}")] + FailedToCacheBytes(String), + + #[error("Failed to send frame to cache {0}")] + FailedToSendFrameToCache(String), + + #[error("Failed to send frame from file cache {0}")] + FailedToSendFrameFromCache(String), + + #[error("Failed to remove cache file: {0}")] + FailedToRemoveCacheFile(String), + + #[error("Invalid cache target")] + InvalidCacheTarget, + + #[error("Hash mismatched in cache file")] + HashMismatchedInCacheFile, +} diff --git a/rpxy-lib/src/forwarder/cache/cache_main.rs b/rpxy-lib/src/forwarder/cache/cache_main.rs new file mode 100644 index 0000000..edb1ec5 --- /dev/null +++ b/rpxy-lib/src/forwarder/cache/cache_main.rs @@ -0,0 +1,527 @@ +use super::cache_error::*; +use crate::{ + globals::Globals, + hyper_ext::body::{full, BoxBody, ResponseBody, UnboundedStreamBody}, + log::*, +}; +use base64::{engine::general_purpose, Engine as _}; +use bytes::{Buf, Bytes, BytesMut}; +use futures::channel::mpsc; +use http::{Request, Response, Uri}; +use http_body_util::{BodyExt, StreamBody}; +use http_cache_semantics::CachePolicy; +use hyper::body::{Frame, Incoming}; +use lru::LruCache; +use sha2::{Digest, Sha256}; +use std::{ + path::{Path, PathBuf}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, Mutex, + }, + time::SystemTime, +}; +use tokio::{ + fs::{self, File}, + io::{AsyncReadExt, AsyncWriteExt}, + sync::RwLock, +}; + +/* ---------------------------------------------- */ +#[derive(Clone, Debug)] +/// Cache main manager +pub(crate) struct RpxyCache { + /// Inner lru cache manager storing http message caching policy + inner: LruCacheManager, + /// Managing cache file objects through RwLock's lock mechanism for file lock + file_store: FileStore, + /// Async runtime + runtime_handle: tokio::runtime::Handle, + /// Maximum size of each cache file object + max_each_size: usize, + /// Maximum size of cache object on memory + max_each_size_on_memory: usize, + /// Cache directory path + cache_dir: PathBuf, +} + +impl RpxyCache { + #[allow(unused)] + /// Generate cache storage + pub(crate) async fn new(globals: &Globals) -> Option { + if !globals.proxy_config.cache_enabled { + return None; + } + let cache_dir = globals.proxy_config.cache_dir.as_ref().unwrap(); + let file_store = FileStore::new(&globals.runtime_handle).await; + let inner = LruCacheManager::new(globals.proxy_config.cache_max_entry); + + let max_each_size = globals.proxy_config.cache_max_each_size; + let mut max_each_size_on_memory = globals.proxy_config.cache_max_each_size_on_memory; + if max_each_size < max_each_size_on_memory { + warn!( + "Maximum size of on memory cache per entry must be smaller than or equal 
to the maximum of each file cache" + ); + max_each_size_on_memory = max_each_size; + } + + if let Err(e) = fs::remove_dir_all(cache_dir).await { + warn!("Failed to clean up the cache dir: {e}"); + }; + fs::create_dir_all(&cache_dir).await.unwrap(); + + Some(Self { + file_store, + inner, + runtime_handle: globals.runtime_handle.clone(), + max_each_size, + max_each_size_on_memory, + cache_dir: cache_dir.clone(), + }) + } + + /// Count cache entries + pub(crate) async fn count(&self) -> (usize, usize, usize) { + let total = self.inner.count(); + let file = self.file_store.count().await; + let on_memory = total - file; + (total, on_memory, file) + } + + /// Put response into the cache + pub(crate) async fn put( + &self, + uri: &hyper::Uri, + mut body: Incoming, + policy: &CachePolicy, + ) -> CacheResult { + let cache_manager = self.inner.clone(); + let mut file_store = self.file_store.clone(); + let uri = uri.clone(); + let policy_clone = policy.clone(); + let max_each_size = self.max_each_size; + let max_each_size_on_memory = self.max_each_size_on_memory; + let cache_dir = self.cache_dir.clone(); + + let (body_tx, body_rx) = mpsc::unbounded::, hyper::Error>>(); + + self.runtime_handle.spawn(async move { + let mut size = 0usize; + let mut buf = BytesMut::new(); + + loop { + let frame = match body.frame().await { + Some(frame) => frame, + None => { + debug!("Response body finished"); + break; + } + }; + let frame_size = frame.as_ref().map(|f| { + if f.is_data() { + f.data_ref().map(|bytes| bytes.remaining()).unwrap_or_default() + } else { + 0 + } + }); + size += frame_size.unwrap_or_default(); + + // check size + if size > max_each_size { + warn!("Too large to cache"); + return Err(CacheError::TooLargeToCache); + } + frame + .as_ref() + .map(|f| { + if f.is_data() { + let data_bytes = f.data_ref().unwrap().clone(); + // debug!("cache data bytes of {} bytes", data_bytes.len()); + // We do not use stream-type buffering since it needs to lock file during operation. + buf.extend(data_bytes.as_ref()); + } + }) + .map_err(|e| CacheError::FailedToCacheBytes(e.to_string()))?; + + // send the data downstream for use in the response + body_tx + .unbounded_send(frame) + .map_err(|e| CacheError::FailedToSendFrameToCache(e.to_string()))?; + } + + let buf = buf.freeze(); + // Calculate hash of the cached data, after all data is received. + // In-operation calculation is possible but it blocks sending data. + let mut hasher = Sha256::new(); + hasher.update(buf.as_ref()); + let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); + debug!("Cached data: {} bytes, hash = {:?}", size, hash_bytes); + + // Create cache object + let cache_key = derive_cache_key_from_uri(&uri); + let cache_object = CacheObject { + policy: policy_clone, + target: CacheFileOrOnMemory::build(&cache_dir, &uri, &buf, max_each_size_on_memory), + hash: hash_bytes, + }; + + if let Some((k, v)) = cache_manager.push(&cache_key, &cache_object)? { + if k != cache_key { + info!("Over the cache capacity.
Evicting the least recently used entry"); + if let CacheFileOrOnMemory::File(path) = v.target { + file_store.evict(&path).await; + } + } + } + // store cache object to file + if let CacheFileOrOnMemory::File(_) = cache_object.target { + file_store.create(&cache_object, &buf).await?; + } + + Ok(()) as CacheResult<()> + }); + + let stream_body = StreamBody::new(body_rx); + + Ok(stream_body) + } + + /// Get cached response + pub(crate) async fn get(&self, req: &Request) -> Option> { + debug!( + "Current cache status: (total, on-memory, file) = {:?}", + self.count().await + ); + let cache_key = derive_cache_key_from_uri(req.uri()); + + // First, check for a cached entry + let Ok(Some(cached_object)) = self.inner.get(&cache_key) else { + return None; + }; + + // Second, check the cache freshness as an HTTP message + let now = SystemTime::now(); + let http_cache_semantics::BeforeRequest::Fresh(res_parts) = cached_object.policy.before_request(req, now) else { + // Evict stale cache entry. + // This might be okay to keep as is since it would be updated later. + // However, there is no guarantee that newly fetched objects will still be cacheable. + // So, we have to evict stale cache entries and cache file objects if found. + debug!("Stale cache entry: {cache_key}"); + let _evicted_entry = self.inner.evict(&cache_key); + // For cache file + if let CacheFileOrOnMemory::File(path) = &cached_object.target { + self.file_store.evict(&path).await; + } + return None; + }; + + // Finally, retrieve the file/on-memory object + let response_body = match cached_object.target { + CacheFileOrOnMemory::File(path) => { + let stream_body = match self.file_store.read(path.clone(), &cached_object.hash).await { + Ok(s) => s, + Err(e) => { + warn!("Failed to read from file cache: {e}"); + let _evicted_entry = self.inner.evict(&cache_key); + self.file_store.evict(path).await; + return None; + } + }; + debug!("Cache hit from file: {cache_key}"); + ResponseBody::Streamed(stream_body) + } + CacheFileOrOnMemory::OnMemory(object) => { + debug!("Cache hit from on memory: {cache_key}"); + let mut hasher = Sha256::new(); + hasher.update(object.as_ref()); + let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); + if hash_bytes != cached_object.hash { + warn!("Hash mismatched.
Cache object is corrupted"); + let _evicted_entry = self.inner.evict(&cache_key); + return None; + } + ResponseBody::Boxed(BoxBody::new(full(object))) + } + }; + Some(Response::from_parts(res_parts, response_body)) + } +} + +/* ---------------------------------------------- */ +#[derive(Debug, Clone)] +/// Cache file manager outer that is responsible for handling `RwLock` +struct FileStore { + /// Inner file store main object + inner: Arc>, +} +impl FileStore { + #[allow(unused)] + /// Build manager + async fn new(runtime_handle: &tokio::runtime::Handle) -> Self { + Self { + inner: Arc::new(RwLock::new(FileStoreInner::new(runtime_handle).await)), + } + } + + /// Count file cache entries + async fn count(&self) -> usize { + let inner = self.inner.read().await; + inner.cnt + } + /// Create a temporary file cache + async fn create(&mut self, cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> { + let mut inner = self.inner.write().await; + inner.create(cache_object, body_bytes).await + } + /// Evict a temporary file cache + async fn evict(&self, path: impl AsRef) { + // Acquire the write lock + let mut inner = self.inner.write().await; + if let Err(e) = inner.remove(path).await { + warn!("Eviction failed during file object removal: {:?}", e); + }; + } + /// Read a temporary file cache + async fn read( + &self, + path: impl AsRef + Send + Sync + 'static, + hash: &Bytes, + ) -> CacheResult { + let inner = self.inner.read().await; + inner.read(path, hash).await + } +} + +#[derive(Debug, Clone)] +/// Manager inner for cache on file system +struct FileStoreInner { + /// Counter of current cached files + cnt: usize, + /// Async runtime + runtime_handle: tokio::runtime::Handle, +} + +impl FileStoreInner { + #[allow(unused)] + /// Build new cache file manager. + /// This first creates the cache file dir if it does not exist, and cleans up the files inside the directory. + /// TODO: Persistent cache is really difficult. `sqlite` or something like that is needed.
+ async fn new(runtime_handle: &tokio::runtime::Handle) -> Self { + Self { + cnt: 0, + runtime_handle: runtime_handle.clone(), + } + } + + /// Create a new temporary file cache + async fn create(&mut self, cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> { + let cache_filepath = match cache_object.target { + CacheFileOrOnMemory::File(ref path) => path.clone(), + CacheFileOrOnMemory::OnMemory(_) => { + return Err(CacheError::InvalidCacheTarget); + } + }; + let Ok(mut file) = File::create(&cache_filepath).await else { + return Err(CacheError::FailedToCreateFileCache); + }; + let mut bytes_clone = body_bytes.clone(); + while bytes_clone.has_remaining() { + if let Err(e) = file.write_buf(&mut bytes_clone).await { + error!("Failed to write file cache: {e}"); + return Err(CacheError::FailedToWriteFileCache); + }; + } + self.cnt += 1; + Ok(()) + } + + /// Retrieve a stored temporary file cache + async fn read( + &self, + path: impl AsRef + Send + Sync + 'static, + hash: &Bytes, + ) -> CacheResult { + let Ok(mut file) = File::open(&path).await else { + warn!("Cache file object cannot be opened"); + return Err(CacheError::FailedToOpenCacheFile); + }; + let hash_clone = hash.clone(); + let mut self_clone = self.clone(); + + let (body_tx, body_rx) = mpsc::unbounded::, hyper::Error>>(); + + self.runtime_handle.spawn(async move { + let mut hasher = Sha256::new(); + let mut buf = BytesMut::new(); + loop { + match file.read_buf(&mut buf).await { + Ok(0) => break, + Ok(_) => { + let bytes = buf.copy_to_bytes(buf.remaining()); + hasher.update(bytes.as_ref()); + body_tx + .unbounded_send(Ok(Frame::data(bytes))) + .map_err(|e| CacheError::FailedToSendFrameFromCache(e.to_string()))? + } + Err(_) => break, + }; + } + let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); + if hash_bytes != hash_clone { + warn!("Hash mismatched. Cache object is corrupted. 
Forcibly removing the cache file."); + // only file can be evicted + let _evicted_entry = self_clone.remove(&path).await; + return Err(CacheError::HashMismatchedInCacheFile); + } + Ok(()) as CacheResult<()> + }); + + let stream_body = StreamBody::new(body_rx); + + Ok(stream_body) + } + + /// Remove file + async fn remove(&mut self, path: impl AsRef) -> CacheResult<()> { + fs::remove_file(path.as_ref()) + .await + .map_err(|e| CacheError::FailedToRemoveCacheFile(e.to_string()))?; + self.cnt -= 1; + debug!("Removed a cache file at {:?} (file count: {})", path.as_ref(), self.cnt); + + Ok(()) + } +} + +/* ---------------------------------------------- */ + +#[derive(Clone, Debug)] +/// Cache target in hybrid manner of on-memory and file system +pub(crate) enum CacheFileOrOnMemory { + /// Pointer to the temporary cache file + File(PathBuf), + /// Cached body itself + OnMemory(Bytes), +} + +impl CacheFileOrOnMemory { + /// Get cache object target + fn build(cache_dir: &Path, uri: &Uri, object: &Bytes, max_each_size_on_memory: usize) -> Self { + if object.len() > max_each_size_on_memory { + let cache_filename = derive_filename_from_uri(uri); + let cache_filepath = cache_dir.join(cache_filename); + CacheFileOrOnMemory::File(cache_filepath) + } else { + CacheFileOrOnMemory::OnMemory(object.clone()) + } + } +} + +#[derive(Clone, Debug)] +/// Cache object definition +struct CacheObject { + /// Cache policy to determine if the stored cache can be used as a response to a new incoming request + policy: CachePolicy, + /// Cache target: on-memory object or temporary file + target: CacheFileOrOnMemory, + /// SHA256 hash of target to strongly bind the cache metadata (this object) and file target + hash: Bytes, +} + +/* ---------------------------------------------- */ +#[derive(Debug, Clone)] +/// Lru cache manager that is responsible for handling `Mutex` as an outer layer of `LruCache` +struct LruCacheManager { + /// Inner lru cache manager main object + inner: Arc>>, // TODO: unsure whether a plain URL string is the right key; it looks like every request would be checked against it + /// Counter of current cached object (total) + cnt: Arc, +} + +impl LruCacheManager { + #[allow(unused)] + /// Build LruCache + fn new(cache_max_entry: usize) -> Self { + Self { + inner: Arc::new(Mutex::new(LruCache::new( + std::num::NonZeroUsize::new(cache_max_entry).unwrap(), + ))), + cnt: Default::default(), + } + } + + /// Count entries + fn count(&self) -> usize { + self.cnt.load(Ordering::Relaxed) + } + + /// Evict an entry + fn evict(&self, cache_key: &str) -> Option<(String, CacheObject)> { + let Ok(mut lock) = self.inner.lock() else { + error!("Mutex can't be locked to evict a cache entry"); + return None; + }; + let res = lock.pop_entry(cache_key); + // This may be inconsistent with the actual number of entries + self.cnt.store(lock.len(), Ordering::Relaxed); + res + } + + /// Push an entry + fn push(&self, cache_key: &str, cache_object: &CacheObject) -> CacheResult> { + let Ok(mut lock) = self.inner.lock() else { + error!("Failed to acquire mutex lock for writing cache entry"); + return Err(CacheError::FailedToAcquiredMutexLockForCache); + }; + let res = Ok(lock.push(cache_key.to_string(), cache_object.clone())); + // This may be inconsistent with the actual number of entries + self.cnt.store(lock.len(), Ordering::Relaxed); + res + } + + /// Get an entry + fn get(&self, cache_key: &str) -> CacheResult> { + let Ok(mut lock) = self.inner.lock() else { + error!("Mutex can't be locked for checking cache entry"); + return Err(CacheError::FailedToAcquiredMutexLockForCheck); + }; + let
Some(cached_object) = lock.get(cache_key) else { + return Ok(None); + }; + Ok(Some(cached_object.clone())) + } +} + +/* ---------------------------------------------- */ +/// Generate cache policy if the response is cacheable +pub(crate) fn get_policy_if_cacheable( + req: Option<&Request>, + res: Option<&Response>, +) -> CacheResult> +// where +// B1: core::fmt::Debug, +{ + // deduce cache policy from req and res + let (Some(req), Some(res)) = (req, res) else { + return Err(CacheError::NullRequestOrResponse); + }; + + let new_policy = CachePolicy::new(req, res); + if new_policy.is_storable() { + // debug!("Response is cacheable: {:?}\n{:?}", req, res.headers()); + Ok(Some(new_policy)) + } else { + Ok(None) + } +} + +fn derive_filename_from_uri(uri: &hyper::Uri) -> String { + let mut hasher = Sha256::new(); + hasher.update(uri.to_string()); + let digest = hasher.finalize(); + general_purpose::URL_SAFE_NO_PAD.encode(digest) +} + +fn derive_cache_key_from_uri(uri: &hyper::Uri) -> String { + uri.to_string() +} diff --git a/rpxy-lib/src/forwarder/cache/mod.rs b/rpxy-lib/src/forwarder/cache/mod.rs new file mode 100644 index 0000000..076eaa3 --- /dev/null +++ b/rpxy-lib/src/forwarder/cache/mod.rs @@ -0,0 +1,5 @@ +mod cache_error; +mod cache_main; + +pub use cache_error::CacheError; +pub(crate) use cache_main::{get_policy_if_cacheable, RpxyCache}; diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs new file mode 100644 index 0000000..9be7b00 --- /dev/null +++ b/rpxy-lib/src/forwarder/client.rs @@ -0,0 +1,255 @@ +#[allow(unused)] +use crate::{ + error::{RpxyError, RpxyResult}, + globals::Globals, + hyper_ext::{body::ResponseBody, rt::LocalExecutor}, + log::*, +}; +use async_trait::async_trait; +use http::{Request, Response, Version}; +use hyper::body::{Body, Incoming}; +use hyper_util::client::legacy::{ + connect::{Connect, HttpConnector}, + Client, +}; +use std::sync::Arc; + +#[cfg(feature = "cache")] +use super::cache::{get_policy_if_cacheable, RpxyCache}; + +#[async_trait] +/// Definition of the forwarder that simply forwards requests from the downstream client to upstream app servers. +pub trait ForwardRequest { + type Error; + async fn request(&self, req: Request) -> Result, Self::Error>; +} + +/// Forwarder HTTP client struct responsible for cache handling +pub struct Forwarder { + #[cfg(feature = "cache")] + cache: Option, + inner: Client, + inner_h2: Client, // `h2c` or http/2-only client is defined separately +} + +#[async_trait] +impl ForwardRequest for Forwarder +where + C: Send + Sync + Connect + Clone + 'static, + B1: Body + Send + Sync + Unpin + 'static, + ::Data: Send, + ::Error: Into>, +{ + type Error = RpxyError; + + async fn request(&self, req: Request) -> Result, Self::Error> { + // TODO: cache handling + #[cfg(feature = "cache")] + { + let mut synth_req = None; + if self.cache.is_some() { + // try reading from cache + if let Some(cached_response) = self.cache.as_ref().unwrap().get(&req).await { + // if found, return it as response. + info!("Cache hit - Return from cache"); + return Ok(cached_response); + }; + + // Synthetic request copy used just for caching (cannot clone request object...)
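
/* ---------------------------------------------- */
// A minimal sketch of the decision flow behind `get_policy_if_cacheable` above
// and `RpxyCache::get`: `http_cache_semantics` decides storability when the
// response is first seen, and freshness on every later lookup. `Fresh` hands
// back the stored response parts to serve; anything else makes rpxy evict and
// refetch. `fresh_parts_for` is an illustrative helper, not an rpxy API.
use http_cache_semantics::{BeforeRequest, CachePolicy};
use std::time::SystemTime;

/// Returns the stored response parts if the cached entry may serve `req` as-is.
fn fresh_parts_for(policy: &CachePolicy, req: &http::Request<()>) -> Option<http::response::Parts> {
  match policy.before_request(req, SystemTime::now()) {
    BeforeRequest::Fresh(parts) => Some(parts),
    // Stale: evict instead of revalidating, as `RpxyCache::get` does above.
    BeforeRequest::Stale { .. } => None,
  }
}
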
+ synth_req = Some(build_synth_req_for_cache(&req)); + } + let res = self.request_directly(req).await; + + if self.cache.is_none() { + return res.map(|inner| inner.map(ResponseBody::Incoming)); + } + + // check cacheability and store it if cacheable + let Ok(Some(cache_policy)) = get_policy_if_cacheable(synth_req.as_ref(), res.as_ref().ok()) else { + return res.map(|inner| inner.map(ResponseBody::Incoming)); + }; + let (parts, body) = res.unwrap().into_parts(); + + // Get the streamed body without waiting for the whole body to arrive; + // caching is done simultaneously. + let stream_body = self + .cache + .as_ref() + .unwrap() + .put(synth_req.unwrap().uri(), body, &cache_policy) + .await?; + + // response with body being cached in background + let new_res = Response::from_parts(parts, ResponseBody::Streamed(stream_body)); + Ok(new_res) + } + + // No cache handling + #[cfg(not(feature = "cache"))] + { + self + .request_directly(req) + .await + .map(|inner| inner.map(ResponseBody::Incoming)) + } + } +} + +impl Forwarder +where + C: Send + Sync + Connect + Clone + 'static, + B1: Body + Send + Unpin + 'static, + ::Data: Send, + ::Error: Into>, +{ + async fn request_directly(&self, req: Request) -> RpxyResult> { + // TODO: This 'match' condition is always evaluated at every 'request' invocation. So, it is inefficient. + // Needs to be reconsidered. Currently, this is a kind of workaround. + // This possibly relates to https://github.com/hyperium/hyper/issues/2417. + match req.version() { + Version::HTTP_2 => self.inner_h2.request(req).await, // handles `h2c` requests + _ => self.inner.request(req).await, + } + .map_err(|e| RpxyError::FailedToFetchFromUpstream(e.to_string())) + } +} + +#[cfg(not(any(feature = "native-tls-backend", feature = "rustls-backend")))] +impl Forwarder +where + B: Body + Send + Unpin + 'static, + ::Data: Send, + ::Error: Into>, +{ + /// Build inner client with http + pub async fn try_new(_globals: &Arc) -> RpxyResult { + warn!( + " +-------------------------------------------------------------------------------------------------- +Request forwarder is working without TLS support!!! +We recommend using this only for testing. +Please enable the native-tls-backend or rustls-backend feature to enable TLS support.
+--------------------------------------------------------------------------------------------------" + ); + let executor = LocalExecutor::new(_globals.runtime_handle.clone()); + let mut http = HttpConnector::new(); + http.enforce_http(true); + http.set_reuse_address(true); + http.set_keepalive(Some(_globals.proxy_config.upstream_idle_timeout)); + let inner = Client::builder(executor).build::<_, B>(http); + let inner_h2 = inner.clone(); + + Ok(Self { + inner, + inner_h2, + #[cfg(feature = "cache")] + cache: RpxyCache::new(_globals).await, + }) + } +} + +#[cfg(all(feature = "native-tls-backend", not(feature = "rustls-backend")))] +/// Build forwarder with hyper-tls (native-tls) +impl Forwarder, B1> +where + B1: Body + Send + Unpin + 'static, + ::Data: Send, + ::Error: Into>, +{ + /// Build forwarder + pub async fn try_new(_globals: &Arc) -> RpxyResult { + // build hyper client with hyper-tls + info!("Native TLS support is enabled for the connection to backend applications"); + let executor = LocalExecutor::new(_globals.runtime_handle.clone()); + + let try_build_connector = |alpns: &[&str]| { + hyper_tls::native_tls::TlsConnector::builder() + .request_alpns(alpns) + .build() + .map_err(|e| RpxyError::FailedToBuildForwarder(e.to_string())) + .map(|tls| { + let mut http = HttpConnector::new(); + http.enforce_http(false); + http.set_reuse_address(true); + http.set_keepalive(Some(_globals.proxy_config.upstream_idle_timeout)); + hyper_tls::HttpsConnector::from((http, tls.into())) + }) + }; + + let connector = try_build_connector(&["h2", "http/1.1"])?; + let inner = Client::builder(executor.clone()).build::<_, B1>(connector); + + let connector_h2 = try_build_connector(&["h2"])?; + let inner_h2 = Client::builder(executor.clone()) + .http2_only(true) + .build::<_, B1>(connector_h2); + + Ok(Self { + inner, + inner_h2, + #[cfg(feature = "cache")] + cache: RpxyCache::new(_globals).await, + }) + } +} + +#[cfg(feature = "rustls-backend")] +/// Build forwarder with hyper-rustls (rustls) +impl Forwarder, B1> +where + B1: Body + Send + Unpin + 'static, + ::Data: Send, + ::Error: Into>, +{ + /// Build forwarder + pub async fn try_new(_globals: &Arc) -> RpxyResult { + // build hyper client with rustls and webpki, only https is allowed + #[cfg(feature = "rustls-backend-webpki")] + let builder = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots(); + #[cfg(feature = "rustls-backend-webpki")] + let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots(); + #[cfg(feature = "rustls-backend-webpki")] + info!("Mozilla WebPKI root certs with rustls are used for the connection to backend applications"); + + #[cfg(not(feature = "rustls-backend-webpki"))] + let builder = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots()?; + #[cfg(not(feature = "rustls-backend-webpki"))] + let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots()?; + #[cfg(not(feature = "rustls-backend-webpki"))] + info!("Native cert store with rustls is used for the connection to backend applications"); + + let mut http = HttpConnector::new(); + http.enforce_http(false); + http.set_reuse_address(true); + http.set_keepalive(Some(_globals.proxy_config.upstream_idle_timeout)); + + let connector = builder + .https_or_http() + .enable_all_versions() + .wrap_connector(http.clone()); + let connector_h2 = builder_h2.https_or_http().enable_http2().wrap_connector(http); + let inner = Client::builder(LocalExecutor::new(_globals.runtime_handle.clone())).build::<_, B1>(connector); + let inner_h2 =
Client::builder(LocalExecutor::new(_globals.runtime_handle.clone())).build::<_, B1>(connector_h2); + + Ok(Self { + inner, + inner_h2, + #[cfg(feature = "cache")] + cache: RpxyCache::new(_globals).await, + }) + } +} + +#[cfg(feature = "cache")] +/// Build synthetic request to cache +fn build_synth_req_for_cache(req: &Request) -> Request<()> { + let mut builder = Request::builder() + .method(req.method()) + .uri(req.uri()) + .version(req.version()); + // TODO: omit extensions. Is this approach correct? + for (header_key, header_value) in req.headers() { + builder = builder.header(header_key, header_value); + } + builder.body(()).unwrap() +} diff --git a/rpxy-lib/src/forwarder/mod.rs b/rpxy-lib/src/forwarder/mod.rs new file mode 100644 index 0000000..26aa0c9 --- /dev/null +++ b/rpxy-lib/src/forwarder/mod.rs @@ -0,0 +1,11 @@ +#[cfg(feature = "cache")] +mod cache; +mod client; + +use crate::hyper_ext::body::RequestBody; + +pub(crate) type Forwarder = client::Forwarder; +pub(crate) use client::ForwardRequest; + +#[cfg(feature = "cache")] +pub(crate) use cache::CacheError; diff --git a/rpxy-lib/src/globals.rs b/rpxy-lib/src/globals.rs index d1c0130..e4bff9e 100644 --- a/rpxy-lib/src/globals.rs +++ b/rpxy-lib/src/globals.rs @@ -1,57 +1,53 @@ use crate::{ - backend::{ - Backend, BackendBuilder, Backends, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption, - }, - certs::CryptoSource, constants::*, - error::RpxyError, - log::*, - utils::{BytesName, PathNameBytesExp}, + count::RequestCount, + crypto::{CryptoSource, ServerCryptoBase}, }; -use rustc_hash::FxHashMap as HashMap; -use std::net::SocketAddr; -use std::sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, -}; -use tokio::time::Duration; +use hot_reload::ReloaderReceiver; +use std::{net::SocketAddr, sync::Arc, time::Duration}; /// Global object containing proxy configurations and shared objects like counters. /// But note that in Globals, we do not have Mutex and RwLock. It is, indeed, the context shared among async tasks. -pub struct Globals -where - T: CryptoSource, -{ +pub struct Globals { /// Configuration parameters for proxy transport and request handlers - pub proxy_config: ProxyConfig, // TODO: wrap proxy_config in an Arc and just reuse that; maybe backends too?
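
/* ---------------------------------------------- */
// A minimal sketch of how the shared `RequestCount` below is meant to gate
// connections, assuming the semantics in `count.rs` above: `increment()` wraps
// `fetch_add` and returns the previous value, while `decrement()` retries its
// CAS loop so the counter never underflows under concurrency. `try_admit` is a
// hypothetical helper, not an rpxy API.
fn try_admit(globals: &Globals) -> bool {
  // A previous value at or above max_clients means the proxy is already full.
  if globals.request_count.increment() >= globals.proxy_config.max_clients {
    globals.request_count.decrement(); // roll back our reservation
    return false;
  }
  true
}
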
- - /// Backend application objects to which http request handler forward incoming requests - pub backends: Backends, - + pub proxy_config: ProxyConfig, /// Shared context - Counter for serving requests pub request_count: RequestCount, - /// Shared context - Async task runtime handle pub runtime_handle: tokio::runtime::Handle, + /// Shared context - Notify object to stop async tasks + pub term_notify: Option>, + /// Shared context - Certificate reloader service receiver + pub cert_reloader_rx: Option>, } /// Configuration parameters for proxy transport and request handlers #[derive(PartialEq, Eq, Clone)] pub struct ProxyConfig { - pub listen_sockets: Vec, // when instantiate server - pub http_port: Option, // when instantiate server - pub https_port: Option, // when instantiate server - pub tcp_listen_backlog: u32, // when instantiate server + /// listen socket addresses + pub listen_sockets: Vec, + /// http port + pub http_port: Option, + /// https port + pub https_port: Option, + /// tcp listen backlog + pub tcp_listen_backlog: u32, - pub proxy_timeout: Duration, // when serving requests at Proxy - pub upstream_timeout: Duration, // when serving requests at Handler + /// Idle timeout as an HTTP server, used as the keep alive interval and timeout for reading request header + pub proxy_idle_timeout: Duration, + /// Idle timeout as an HTTP client, used as the keep alive interval for upstream connections + pub upstream_idle_timeout: Duration, pub max_clients: usize, // when serving requests pub max_concurrent_streams: u32, // when instantiate server pub keepalive: bool, // when instantiate server // experimentals + /// SNI consistency check pub sni_consistency: bool, // Handler + /// Connection handling timeout + /// Timeout to handle a connection: the total time to receive a request, serve it, and send the response. This might limit the max length of a response, as sketched below.
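
/* ---------------------------------------------- */
// A minimal sketch of enforcing the connection handling timeout described
// above, assuming the tokio runtime used throughout rpxy: wrapping the whole
// per-connection future in `tokio::time::timeout` bounds receive + serve +
// send, which is exactly why it can also cap the effective response length.
// `with_connection_deadline` and `serve` are hypothetical, not rpxy APIs.
async fn with_connection_deadline<F: std::future::Future<Output = ()>>(
  deadline: Option<std::time::Duration>,
  serve: F,
) {
  match deadline {
    // Err(Elapsed) here means the connection is torn down mid-response.
    Some(d) => {
      let _ = tokio::time::timeout(d, serve).await;
    }
    None => serve.await,
  }
}
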
+ pub connection_handling_timeout: Option, #[cfg(feature = "cache")] pub cache_enabled: bool, @@ -90,14 +86,15 @@ impl Default for ProxyConfig { tcp_listen_backlog: TCP_LISTEN_BACKLOG, // TODO: Reconsider each timeout values - proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC), - upstream_timeout: Duration::from_secs(UPSTREAM_TIMEOUT_SEC), + proxy_idle_timeout: Duration::from_secs(PROXY_IDLE_TIMEOUT_SEC), + upstream_idle_timeout: Duration::from_secs(UPSTREAM_IDLE_TIMEOUT_SEC), max_clients: MAX_CLIENTS, max_concurrent_streams: MAX_CONCURRENT_STREAMS, keepalive: true, sni_consistency: true, + connection_handling_timeout: None, #[cfg(feature = "cache")] cache_enabled: false, @@ -137,44 +134,6 @@ where pub inner: Vec>, pub default_app: Option, } -impl TryInto> for AppConfigList -where - T: CryptoSource + Clone, -{ - type Error = RpxyError; - - fn try_into(self) -> Result, Self::Error> { - let mut backends = Backends::new(); - for app_config in self.inner.iter() { - let backend = app_config.try_into()?; - backends - .apps - .insert(app_config.server_name.clone().to_server_name_vec(), backend); - info!( - "Registering application {} ({})", - &app_config.server_name, &app_config.app_name - ); - } - - // default backend application for plaintext http requests - if let Some(d) = self.default_app { - let d_sn: Vec<&str> = backends - .apps - .iter() - .filter(|(_k, v)| v.app_name == d) - .map(|(_, v)| v.server_name.as_ref()) - .collect(); - if !d_sn.is_empty() { - info!( - "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", - d, d_sn[0] - ); - backends.default_server_name_bytes = Some(d_sn[0].to_server_name_vec()); - } - } - Ok(backends) - } -} /// Configuration parameters for single backend application #[derive(PartialEq, Eq, Clone)] @@ -187,77 +146,6 @@ where pub reverse_proxy: Vec, pub tls: Option>, } -impl TryInto> for &AppConfig -where - T: CryptoSource + Clone, -{ - type Error = RpxyError; - - fn try_into(self) -> Result, Self::Error> { - // backend builder - let mut backend_builder = BackendBuilder::default(); - // reverse proxy settings - let reverse_proxy = self.try_into()?; - - backend_builder - .app_name(self.app_name.clone()) - .server_name(self.server_name.clone()) - .reverse_proxy(reverse_proxy); - - // TLS settings and build backend instance - let backend = if self.tls.is_none() { - backend_builder.build().map_err(RpxyError::BackendBuild)? - } else { - let tls = self.tls.as_ref().unwrap(); - - backend_builder - .https_redirection(Some(tls.https_redirection)) - .crypto_source(Some(tls.inner.clone())) - .build()? 
- }; - Ok(backend) - } -} -impl TryInto for &AppConfig -where - T: CryptoSource + Clone, -{ - type Error = RpxyError; - - fn try_into(self) -> Result { - let mut upstream: HashMap = HashMap::default(); - - self.reverse_proxy.iter().for_each(|rpo| { - let upstream_vec: Vec = rpo.upstream.iter().map(|x| x.try_into().unwrap()).collect(); - // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()); - // let lb_upstream_num = vec_upstream.len(); - let elem = UpstreamGroupBuilder::default() - .upstream(&upstream_vec) - .path(&rpo.path) - .replace_path(&rpo.replace_path) - .lb(&rpo.load_balance, &upstream_vec, &self.server_name, &rpo.path) - .opts(&rpo.upstream_options) - .build() - .unwrap(); - - upstream.insert(elem.path.clone(), elem); - }); - if self.reverse_proxy.iter().filter(|rpo| rpo.path.is_none()).count() >= 2 { - error!("Multiple default reverse proxy setting"); - return Err(RpxyError::ConfigBuild("Invalid reverse proxy setting")); - } - - if !(upstream.iter().all(|(_, elem)| { - !(elem.opts.contains(&UpstreamOption::ForceHttp11Upstream) - && elem.opts.contains(&UpstreamOption::ForceHttp2Upstream)) - })) { - error!("Either one of force_http11 or force_http2 can be enabled"); - return Err(RpxyError::ConfigBuild("Invalid upstream option setting")); - } - - Ok(ReverseProxy { upstream }) - } -} /// Configuration parameters for single reverse proxy corresponding to the path #[derive(PartialEq, Eq, Clone)] @@ -272,16 +160,7 @@ pub struct ReverseProxyConfig { /// Configuration parameters for single upstream destination from a reverse proxy #[derive(PartialEq, Eq, Clone)] pub struct UpstreamUri { - pub inner: hyper::Uri, -} -impl TryInto for &UpstreamUri { - type Error = anyhow::Error; - - fn try_into(self) -> std::result::Result { - Ok(Upstream { - uri: self.inner.clone(), - }) - } + pub inner: http::Uri, } /// Configuration parameters on TLS for a single backend application @@ -293,30 +172,3 @@ where pub inner: T, pub https_redirection: bool, } - -#[derive(Debug, Clone, Default)] -/// Counter for serving requests -pub struct RequestCount(Arc); - -impl RequestCount { - pub fn current(&self) -> usize { - self.0.load(Ordering::Relaxed) - } - - pub fn increment(&self) -> usize { - self.0.fetch_add(1, Ordering::Relaxed) - } - - pub fn decrement(&self) -> usize { - let mut count; - while { - count = self.0.load(Ordering::Relaxed); - count > 0 - && self - .0 - .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) - != Ok(count) - } {} - count - } -} diff --git a/rpxy-lib/src/handler/cache.rs b/rpxy-lib/src/handler/cache.rs deleted file mode 100644 index 44cdc11..0000000 --- a/rpxy-lib/src/handler/cache.rs +++ /dev/null @@ -1,393 +0,0 @@ -use crate::{error::*, globals::Globals, log::*, CryptoSource}; -use base64::{engine::general_purpose, Engine as _}; -use bytes::{Buf, Bytes, BytesMut}; -use http_cache_semantics::CachePolicy; -use hyper::{ - http::{Request, Response}, - Body, -}; -use lru::LruCache; -use sha2::{Digest, Sha256}; -use std::{ - fmt::Debug, - path::{Path, PathBuf}, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, Mutex, - }, - time::SystemTime, -}; -use tokio::{ - fs::{self, File}, - io::{AsyncReadExt, AsyncWriteExt}, - sync::RwLock, -}; - -#[derive(Clone, Debug)] -/// Cache target in hybrid manner of on-memory and file system -pub enum CacheFileOrOnMemory { - /// Pointer to the temporary cache file - File(PathBuf), - /// Cached body itself - OnMemory(Vec), -} - -#[derive(Clone, Debug)] -/// Cache object definition -struct CacheObject { - /// 
Cache policy to determine if the stored cache can be used as a response to a new incoming request - pub policy: CachePolicy, - /// Cache target: on-memory object or temporary file - pub target: CacheFileOrOnMemory, -} - -#[derive(Debug)] -/// Manager inner for cache on file system -struct CacheFileManagerInner { - /// Directory of temporary files - cache_dir: PathBuf, - /// Counter of current cached files - cnt: usize, - /// Async runtime - runtime_handle: tokio::runtime::Handle, -} - -impl CacheFileManagerInner { - /// Build new cache file manager. - /// This first creates cache file dir if not exists, and cleans up the file inside the directory. - /// TODO: Persistent cache is really difficult. `sqlite` or something like that is needed. - async fn new(path: impl AsRef, runtime_handle: &tokio::runtime::Handle) -> Self { - let path_buf = path.as_ref().to_path_buf(); - if let Err(e) = fs::remove_dir_all(path).await { - warn!("Failed to clean up the cache dir: {e}"); - }; - fs::create_dir_all(&path_buf).await.unwrap(); - Self { - cache_dir: path_buf.clone(), - cnt: 0, - runtime_handle: runtime_handle.clone(), - } - } - - /// Create a new temporary file cache - async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> Result { - let cache_filepath = self.cache_dir.join(cache_filename); - let Ok(mut file) = File::create(&cache_filepath).await else { - return Err(RpxyError::Cache("Failed to create file")); - }; - let mut bytes_clone = body_bytes.clone(); - while bytes_clone.has_remaining() { - if let Err(e) = file.write_buf(&mut bytes_clone).await { - error!("Failed to write file cache: {e}"); - return Err(RpxyError::Cache("Failed to write file cache: {e}")); - }; - } - self.cnt += 1; - Ok(CacheFileOrOnMemory::File(cache_filepath)) - } - - /// Retrieve a stored temporary file cache - async fn read(&self, path: impl AsRef) -> Result { - let Ok(mut file) = File::open(&path).await else { - warn!("Cache file object cannot be opened"); - return Err(RpxyError::Cache("Cache file object cannot be opened")); - }; - let (body_sender, res_body) = Body::channel(); - self.runtime_handle.spawn(async move { - let mut sender = body_sender; - let mut buf = BytesMut::new(); - loop { - match file.read_buf(&mut buf).await { - Ok(0) => break, - Ok(_) => sender.send_data(buf.copy_to_bytes(buf.remaining())).await?, - Err(_) => break, - }; - } - Ok(()) as Result<()> - }); - - Ok(res_body) - } - - /// Remove file - async fn remove(&mut self, path: impl AsRef) -> Result<()> { - fs::remove_file(path.as_ref()).await?; - self.cnt -= 1; - debug!("Removed a cache file at {:?} (file count: {})", path.as_ref(), self.cnt); - - Ok(()) - } -} - -#[derive(Debug, Clone)] -/// Cache file manager outer that is responsible to handle `RwLock` -struct CacheFileManager { - inner: Arc>, -} - -impl CacheFileManager { - /// Build manager - async fn new(path: impl AsRef, runtime_handle: &tokio::runtime::Handle) -> Self { - Self { - inner: Arc::new(RwLock::new(CacheFileManagerInner::new(path, runtime_handle).await)), - } - } - /// Evict a temporary file cache - async fn evict(&self, path: impl AsRef) { - // Acquire the write lock - let mut inner = self.inner.write().await; - if let Err(e) = inner.remove(path).await { - warn!("Eviction failed during file object removal: {:?}", e); - }; - } - /// Read a temporary file cache - async fn read(&self, path: impl AsRef) -> Result { - let mgr = self.inner.read().await; - mgr.read(&path).await - } - /// Create a temporary file cache - async fn create(&mut self, cache_filename: &str, 
body_bytes: &Bytes) -> Result { - let mut mgr = self.inner.write().await; - mgr.create(cache_filename, body_bytes).await - } - async fn count(&self) -> usize { - let mgr = self.inner.read().await; - mgr.cnt - } -} - -#[derive(Debug, Clone)] -/// Lru cache manager that is responsible to handle `Mutex` as an outer of `LruCache` -struct LruCacheManager { - inner: Arc>>, // TODO: keyはstring urlでいいのか疑問。全requestに対してcheckすることになりそう - cnt: Arc, -} - -impl LruCacheManager { - /// Build LruCache - fn new(cache_max_entry: usize) -> Self { - Self { - inner: Arc::new(Mutex::new(LruCache::new( - std::num::NonZeroUsize::new(cache_max_entry).unwrap(), - ))), - cnt: Arc::new(AtomicUsize::default()), - } - } - /// Count entries - fn count(&self) -> usize { - self.cnt.load(Ordering::Relaxed) - } - /// Evict an entry - fn evict(&self, cache_key: &str) -> Option<(String, CacheObject)> { - let Ok(mut lock) = self.inner.lock() else { - error!("Mutex can't be locked to evict a cache entry"); - return None; - }; - let res = lock.pop_entry(cache_key); - self.cnt.store(lock.len(), Ordering::Relaxed); - res - } - /// Get an entry - fn get(&self, cache_key: &str) -> Result> { - let Ok(mut lock) = self.inner.lock() else { - error!("Mutex can't be locked for checking cache entry"); - return Err(RpxyError::Cache("Mutex can't be locked for checking cache entry")); - }; - let Some(cached_object) = lock.get(cache_key) else { - return Ok(None); - }; - Ok(Some(cached_object.clone())) - } - /// Push an entry - fn push(&self, cache_key: &str, cache_object: CacheObject) -> Result> { - let Ok(mut lock) = self.inner.lock() else { - error!("Failed to acquire mutex lock for writing cache entry"); - return Err(RpxyError::Cache("Failed to acquire mutex lock for writing cache entry")); - }; - let res = Ok(lock.push(cache_key.to_string(), cache_object)); - self.cnt.store(lock.len(), Ordering::Relaxed); - res - } -} - -#[derive(Clone, Debug)] -pub struct RpxyCache { - /// Managing cache file objects through RwLock's lock mechanism for file lock - cache_file_manager: CacheFileManager, - /// Lru cache storing http message caching policy - inner: LruCacheManager, - /// Async runtime - runtime_handle: tokio::runtime::Handle, - /// Maximum size of each cache file object - max_each_size: usize, - /// Maximum size of cache object on memory - max_each_size_on_memory: usize, -} - -impl RpxyCache { - /// Generate cache storage - pub async fn new(globals: &Globals) -> Option { - if !globals.proxy_config.cache_enabled { - return None; - } - - let path = globals.proxy_config.cache_dir.as_ref().unwrap(); - let cache_file_manager = CacheFileManager::new(path, &globals.runtime_handle).await; - let inner = LruCacheManager::new(globals.proxy_config.cache_max_entry); - - let max_each_size = globals.proxy_config.cache_max_each_size; - let mut max_each_size_on_memory = globals.proxy_config.cache_max_each_size_on_memory; - if max_each_size < max_each_size_on_memory { - warn!( - "Maximum size of on memory cache per entry must be smaller than or equal to the maximum of each file cache" - ); - max_each_size_on_memory = max_each_size; - } - - Some(Self { - cache_file_manager, - inner, - runtime_handle: globals.runtime_handle.clone(), - max_each_size, - max_each_size_on_memory, - }) - } - - /// Count cache entries - pub async fn count(&self) -> (usize, usize, usize) { - let total = self.inner.count(); - let file = self.cache_file_manager.count().await; - let on_memory = total - file; - (total, on_memory, file) - } - - /// Get cached response - pub async fn 
get(&self, req: &Request) -> Option> { - debug!( - "Current cache status: (total, on-memory, file) = {:?}", - self.count().await - ); - let cache_key = req.uri().to_string(); - - // First check cache chance - let Ok(Some(cached_object)) = self.inner.get(&cache_key) else { - return None; - }; - - // Secondly check the cache freshness as an HTTP message - let now = SystemTime::now(); - let http_cache_semantics::BeforeRequest::Fresh(res_parts) = cached_object.policy.before_request(req, now) else { - // Evict stale cache entry. - // This might be okay to keep as is since it would be updated later. - // However, there is no guarantee that newly got objects will be still cacheable. - // So, we have to evict stale cache entries and cache file objects if found. - debug!("Stale cache entry: {cache_key}"); - let _evicted_entry = self.inner.evict(&cache_key); - // For cache file - if let CacheFileOrOnMemory::File(path) = &cached_object.target { - self.cache_file_manager.evict(&path).await; - } - return None; - }; - - // Finally retrieve the file/on-memory object - match cached_object.target { - CacheFileOrOnMemory::File(path) => { - let res_body = match self.cache_file_manager.read(&path).await { - Ok(res_body) => res_body, - Err(e) => { - warn!("Failed to read from file cache: {e}"); - let _evicted_entry = self.inner.evict(&cache_key); - self.cache_file_manager.evict(&path).await; - return None; - } - }; - - debug!("Cache hit from file: {cache_key}"); - Some(Response::from_parts(res_parts, res_body)) - } - CacheFileOrOnMemory::OnMemory(object) => { - debug!("Cache hit from on memory: {cache_key}"); - Some(Response::from_parts(res_parts, Body::from(object))) - } - } - } - - /// Put response into the cache - pub async fn put(&self, uri: &hyper::Uri, body_bytes: &Bytes, policy: &CachePolicy) -> Result<()> { - let my_cache = self.inner.clone(); - let mut mgr = self.cache_file_manager.clone(); - let uri = uri.clone(); - let bytes_clone = body_bytes.clone(); - let policy_clone = policy.clone(); - let max_each_size = self.max_each_size; - let max_each_size_on_memory = self.max_each_size_on_memory; - - self.runtime_handle.spawn(async move { - if bytes_clone.len() > max_each_size { - warn!("Too large to cache"); - return Err(RpxyError::Cache("Too large to cache")); - } - let cache_key = derive_cache_key_from_uri(&uri); - - debug!("Object of size {:?} bytes to be cached", bytes_clone.len()); - - let cache_object = if bytes_clone.len() > max_each_size_on_memory { - let cache_filename = derive_filename_from_uri(&uri); - let target = mgr.create(&cache_filename, &bytes_clone).await?; - debug!("Cached a new cache file: {} - {}", cache_key, cache_filename); - CacheObject { - policy: policy_clone, - target, - } - } else { - debug!("Cached a new object on memory: {}", cache_key); - CacheObject { - policy: policy_clone, - target: CacheFileOrOnMemory::OnMemory(bytes_clone.to_vec()), - } - }; - - if let Some((k, v)) = my_cache.push(&cache_key, cache_object)? { - if k != cache_key { - info!("Over the cache capacity. 
Evict least recent used entry"); - if let CacheFileOrOnMemory::File(path) = v.target { - mgr.evict(&path).await; - } - } - } - Ok(()) - }); - - Ok(()) - } -} - -fn derive_filename_from_uri(uri: &hyper::Uri) -> String { - let mut hasher = Sha256::new(); - hasher.update(uri.to_string()); - let digest = hasher.finalize(); - general_purpose::URL_SAFE_NO_PAD.encode(digest) -} - -fn derive_cache_key_from_uri(uri: &hyper::Uri) -> String { - uri.to_string() -} - -pub fn get_policy_if_cacheable(req: Option<&Request>, res: Option<&Response>) -> Result> -where - R: Debug, -{ - // deduce cache policy from req and res - let (Some(req), Some(res)) = (req, res) else { - return Err(RpxyError::Cache("Invalid null request and/or response")); - }; - - let new_policy = CachePolicy::new(req, res); - if new_policy.is_storable() { - // debug!("Response is cacheable: {:?}\n{:?}", req, res.headers()); - Ok(Some(new_policy)) - } else { - Ok(None) - } -} diff --git a/rpxy-lib/src/handler/forwarder.rs b/rpxy-lib/src/handler/forwarder.rs deleted file mode 100644 index 4764d36..0000000 --- a/rpxy-lib/src/handler/forwarder.rs +++ /dev/null @@ -1,147 +0,0 @@ -#[cfg(feature = "cache")] -use super::cache::{get_policy_if_cacheable, RpxyCache}; -use crate::{error::RpxyError, globals::Globals, log::*, CryptoSource}; -use async_trait::async_trait; -#[cfg(feature = "cache")] -use bytes::Buf; -use hyper::{ - body::{Body, HttpBody}, - client::{connect::Connect, HttpConnector}, - http::Version, - Client, Request, Response, -}; -use hyper_rustls::HttpsConnector; - -#[cfg(feature = "cache")] -/// Build synthetic request to cache -fn build_synth_req_for_cache(req: &Request) -> Request<()> { - let mut builder = Request::builder() - .method(req.method()) - .uri(req.uri()) - .version(req.version()); - // TODO: omit extensions. is this approach correct? - for (header_key, header_value) in req.headers() { - builder = builder.header(header_key, header_value); - } - builder.body(()).unwrap() -} - -#[async_trait] -/// Definition of the forwarder that simply forward requests from downstream client to upstream app servers. -pub trait ForwardRequest { - type Error; - async fn request(&self, req: Request) -> Result, Self::Error>; -} - -/// Forwarder struct responsible to cache handling -pub struct Forwarder -where - C: Connect + Clone + Sync + Send + 'static, -{ - #[cfg(feature = "cache")] - cache: Option, - inner: Client, - inner_h2: Client, // `h2c` or http/2-only client is defined separately -} - -#[async_trait] -impl ForwardRequest for Forwarder -where - B: HttpBody + Send + Sync + 'static, - B::Data: Send, - B::Error: Into>, - C: Connect + Clone + Sync + Send + 'static, -{ - type Error = RpxyError; - - #[cfg(feature = "cache")] - async fn request(&self, req: Request) -> Result, Self::Error> { - let mut synth_req = None; - if self.cache.is_some() { - if let Some(cached_response) = self.cache.as_ref().unwrap().get(&req).await { - // if found, return it as response. - info!("Cache hit - Return from cache"); - return Ok(cached_response); - }; - - // Synthetic request copy used just for caching (cannot clone request object...) - synth_req = Some(build_synth_req_for_cache(&req)); - } - - // TODO: This 'match' condition is always evaluated at every 'request' invocation. So, it is inefficient. - // Needs to be reconsidered. Currently, this is a kind of work around. - // This possibly relates to https://github.com/hyperium/hyper/issues/2417. 
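A note on the TODO above: the deleted forwarder keeps two clients, a general-purpose one and an http/2-only one, and picks between them per request because `h2c` (cleartext HTTP/2 with prior knowledge) cannot share a connection pool with version-negotiated traffic. A minimal sketch of that pattern, assuming hyper 0.14's `Client` API as this deleted module uses; `VersionDispatcher` and its methods are hypothetical names for illustration, not the patch's API:

```rust
use hyper::{client::HttpConnector, Body, Client, Request, Response, Version};

/// Hypothetical illustration of the dual-client dispatch in the deleted forwarder.
struct VersionDispatcher {
  /// Handles HTTP/1.1 and ALPN-negotiated HTTP/2.
  inner: Client<HttpConnector, Body>,
  /// HTTP/2 prior-knowledge only, required for `h2c` upstreams.
  inner_h2: Client<HttpConnector, Body>,
}

impl VersionDispatcher {
  fn new() -> Self {
    Self {
      inner: Client::builder().build_http(),
      inner_h2: Client::builder().http2_only(true).build_http(),
    }
  }

  /// Dispatch on the version the downstream client spoke, as the `match` below does.
  async fn request(&self, req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
    match req.version() {
      Version::HTTP_2 => self.inner_h2.request(req).await, // `h2c` path
      _ => self.inner.request(req).await,
    }
  }
}
```

The `match` itself is cheap; the inefficiency the comment refers to is that the workaround forces two separate connection pools and re-evaluates the dispatch on every call.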
- let res = match req.version() { - Version::HTTP_2 => self.inner_h2.request(req).await.map_err(RpxyError::Hyper), // handles `h2c` requests - _ => self.inner.request(req).await.map_err(RpxyError::Hyper), - }; - - if self.cache.is_none() { - return res; - } - - // check cacheability and store it if cacheable - let Ok(Some(cache_policy)) = get_policy_if_cacheable(synth_req.as_ref(), res.as_ref().ok()) else { - return res; - }; - let (parts, body) = res.unwrap().into_parts(); - let Ok(mut bytes) = hyper::body::aggregate(body).await else { - return Err(RpxyError::Cache("Failed to write byte buffer")); - }; - let aggregated = bytes.copy_to_bytes(bytes.remaining()); - - if let Err(cache_err) = self - .cache - .as_ref() - .unwrap() - .put(synth_req.unwrap().uri(), &aggregated, &cache_policy) - .await - { - error!("{:?}", cache_err); - }; - - // res - Ok(Response::from_parts(parts, Body::from(aggregated))) - } - - #[cfg(not(feature = "cache"))] - async fn request(&self, req: Request) -> Result, Self::Error> { - match req.version() { - Version::HTTP_2 => self.inner_h2.request(req).await.map_err(RpxyError::Hyper), // handles `h2c` requests - _ => self.inner.request(req).await.map_err(RpxyError::Hyper), - } - } -} - -impl Forwarder, Body> { - /// Build forwarder - pub async fn new(_globals: &std::sync::Arc>) -> Self { - #[cfg(feature = "native-roots")] - let builder = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots(); - #[cfg(feature = "native-roots")] - let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots(); - #[cfg(feature = "native-roots")] - info!("Native cert store is used for the connection to backend applications"); - - #[cfg(not(feature = "native-roots"))] - let builder = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots(); - #[cfg(not(feature = "native-roots"))] - let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots(); - #[cfg(not(feature = "native-roots"))] - info!("Mozilla WebPKI root certs is used for the connection to backend applications"); - - let connector = builder.https_or_http().enable_http1().enable_http2().build(); - let connector_h2 = builder_h2.https_or_http().enable_http2().build(); - - let inner = Client::builder().build::<_, Body>(connector); - let inner_h2 = Client::builder().http2_only(true).build::<_, Body>(connector_h2); - - #[cfg(feature = "cache")] - { - let cache = RpxyCache::new(_globals).await; - Self { inner, inner_h2, cache } - } - #[cfg(not(feature = "cache"))] - Self { inner, inner_h2 } - } -} diff --git a/rpxy-lib/src/handler/handler_main.rs b/rpxy-lib/src/handler/handler_main.rs deleted file mode 100644 index 8b13dc7..0000000 --- a/rpxy-lib/src/handler/handler_main.rs +++ /dev/null @@ -1,380 +0,0 @@ -// Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy -use super::{ - forwarder::{ForwardRequest, Forwarder}, - utils_headers::*, - utils_request::*, - utils_synth_response::*, - HandlerContext, -}; -use crate::{ - backend::{Backend, UpstreamGroup}, - certs::CryptoSource, - constants::RESPONSE_HEADER_SERVER, - error::*, - globals::Globals, - log::*, - utils::ServerNameBytesExp, -}; -use derive_builder::Builder; -use hyper::{ - client::connect::Connect, - header::{self, HeaderValue}, - http::uri::Scheme, - Body, Request, Response, StatusCode, Uri, Version, -}; -use std::{net::SocketAddr, sync::Arc}; -use tokio::{io::copy_bidirectional, time::timeout}; - -#[derive(Clone, Builder)] -/// HTTP message handler for requests from clients and responses from backend applications, 
-/// responsible to manipulate and forward messages to upstream backends and downstream clients.
-pub struct HttpMessageHandler<T, U>
-where
-  T: Connect + Clone + Sync + Send + 'static,
-  U: CryptoSource + Clone,
-{
-  forwarder: Arc<Forwarder<T, Body>>,
-  globals: Arc<Globals<U>>,
-}
-
-impl<T, U> HttpMessageHandler<T, U>
-where
-  T: Connect + Clone + Sync + Send + 'static,
-  U: CryptoSource + Clone,
-{
-  /// Return with an arbitrary status code of error and log message
-  fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result<Response<Body>> {
-    log_data.status_code(&status_code).output();
-    http_error(status_code)
-  }
-
-  /// Handle incoming request message from a client
-  pub async fn handle_request(
-    &self,
-    mut req: Request<Body>,
-    client_addr: SocketAddr, // For access control
-    listen_addr: SocketAddr,
-    tls_enabled: bool,
-    tls_server_name: Option<ServerNameBytesExp>,
-  ) -> Result<Response<Body>> {
-    ////////
-    let mut log_data = MessageLog::from(&req);
-    log_data.client_addr(&client_addr);
-    //////
-
-    // Here we start to handle with server_name
-    let server_name = if let Ok(v) = req.parse_host() {
-      ServerNameBytesExp::from(v)
-    } else {
-      return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data);
-    };
-    // check consistency of between TLS SNI and HOST/Request URI Line.
-    #[allow(clippy::collapsible_if)]
-    if tls_enabled && self.globals.proxy_config.sni_consistency {
-      if server_name != tls_server_name.unwrap_or_default() {
-        return self.return_with_error_log(StatusCode::MISDIRECTED_REQUEST, &mut log_data);
-      }
-    }
-    // Find backend application for given server_name, and drop if incoming request is invalid as request.
-    let backend = match self.globals.backends.apps.get(&server_name) {
-      Some(be) => be,
-      None => {
-        let Some(default_server_name) = &self.globals.backends.default_server_name_bytes else {
-          return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
-        };
-        debug!("Serving by default app");
-        self.globals.backends.apps.get(default_server_name).unwrap()
-      }
-    };
-
-    // Redirect to https if !tls_enabled and redirect_to_https is true
-    if !tls_enabled && backend.https_redirection.unwrap_or(false) {
-      debug!("Redirect to secure connection: {}", &backend.server_name);
-      log_data.status_code(&StatusCode::PERMANENT_REDIRECT).output();
-      return secure_redirection(&backend.server_name, self.globals.proxy_config.https_port, &req);
-    }
-
-    // Find reverse proxy for given path and choose one of upstream host
-    // Longest prefix match
-    let path = req.uri().path();
-    let Some(upstream_group) = backend.reverse_proxy.get(path) else {
-      return self.return_with_error_log(StatusCode::NOT_FOUND, &mut log_data)
-    };
-
-    // Upgrade in request header
-    let upgrade_in_request = extract_upgrade(req.headers());
-    let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();
-
-    // Build request from destination information
-    let _context = match self.generate_request_forwarded(
-      &client_addr,
-      &listen_addr,
-      &mut req,
-      &upgrade_in_request,
-      upstream_group,
-      tls_enabled,
-    ) {
-      Err(e) => {
-        error!("Failed to generate destination uri for reverse proxy: {}", e);
-        return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
-      }
-      Ok(v) => v,
-    };
-    debug!("Request to be forwarded: {:?}", req);
-    log_data.xff(&req.headers().get("x-forwarded-for"));
-    log_data.upstream(req.uri());
-    //////
-
-    // Forward request to a chosen backend
-    let mut res_backend = {
-      let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else {
-        return
self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data); - }; - match result { - Ok(res) => res, - Err(e) => { - error!("Failed to get response from backend: {}", e); - return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); - } - } - }; - - // Process reverse proxy context generated during the forwarding request generation. - #[cfg(feature = "sticky-cookie")] - if let Some(context_from_lb) = _context.context_lb { - let res_headers = res_backend.headers_mut(); - if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { - error!("Failed to append context to the response given from backend: {}", e); - return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); - } - } - - if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { - // Generate response to client - if self.generate_response_forwarded(&mut res_backend, backend).is_err() { - return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - } - log_data.status_code(&res_backend.status()).output(); - return Ok(res_backend); - } - - // Handle StatusCode::SWITCHING_PROTOCOLS in response - let upgrade_in_response = extract_upgrade(res_backend.headers()); - let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) - { - u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() - } else { - false - }; - if !should_upgrade { - error!( - "Backend tried to switch to protocol {:?} when {:?} was requested", - upgrade_in_response, upgrade_in_request - ); - return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - } - let Some(request_upgraded) = request_upgraded else { - error!("Request does not have an upgrade extension"); - return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); - }; - let Some(onupgrade) = res_backend.extensions_mut().remove::() else { - error!("Response does not have an upgrade extension"); - return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); - }; - - self.globals.runtime_handle.spawn(async move { - let mut response_upgraded = onupgrade.await.map_err(|e| { - error!("Failed to upgrade response: {}", e); - RpxyError::Hyper(e) - })?; - let mut request_upgraded = request_upgraded.await.map_err(|e| { - error!("Failed to upgrade request: {}", e); - RpxyError::Hyper(e) - })?; - copy_bidirectional(&mut response_upgraded, &mut request_upgraded) - .await - .map_err(|e| { - error!("Coping between upgraded connections failed: {}", e); - RpxyError::Io(e) - })?; - Ok(()) as Result<()> - }); - log_data.status_code(&res_backend.status()).output(); - Ok(res_backend) - } - - //////////////////////////////////////////////////// - // Functions to generate messages - //////////////////////////////////////////////////// - - /// Manipulate a response message sent from a backend application to forward downstream to a client. 
- fn generate_response_forwarded(&self, response: &mut Response, chosen_backend: &Backend) -> Result<()> - where - B: core::fmt::Debug, - { - let headers = response.headers_mut(); - remove_connection_header(headers); - remove_hop_header(headers); - add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; - - #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] - { - // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled - // TODO: This is a workaround for avoiding a client authentication in HTTP/3 - if self.globals.proxy_config.http3 - && chosen_backend - .crypto_source - .as_ref() - .is_some_and(|v| !v.is_mutual_tls()) - { - if let Some(port) = self.globals.proxy_config.https_port { - add_header_entry_overwrite_if_exist( - headers, - header::ALT_SVC.as_str(), - format!( - "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", - port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age - ), - )?; - } - } else { - // remove alt-svc to disallow requests via http3 - headers.remove(header::ALT_SVC.as_str()); - } - } - #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] - { - if let Some(port) = self.globals.proxy_config.https_port { - headers.remove(header::ALT_SVC.as_str()); - } - } - - Ok(()) - } - - #[allow(clippy::too_many_arguments)] - /// Manipulate a request message sent from a client to forward upstream to a backend application - fn generate_request_forwarded( - &self, - client_addr: &SocketAddr, - listen_addr: &SocketAddr, - req: &mut Request, - upgrade: &Option, - upstream_group: &UpstreamGroup, - tls_enabled: bool, - ) -> Result { - debug!("Generate request to be forwarded"); - - // Add te: trailer if contained in original request - let contains_te_trailers = { - if let Some(te) = req.headers().get(header::TE) { - te.as_bytes() - .split(|v| v == &b',' || v == &b' ') - .any(|x| x == "trailers".as_bytes()) - } else { - false - } - }; - - let uri = req.uri().to_string(); - let headers = req.headers_mut(); - // delete headers specified in header.connection - remove_connection_header(headers); - // delete hop headers including header.connection - remove_hop_header(headers); - // X-Forwarded-For - add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?; - - // Add te: trailer if te_trailer - if contains_te_trailers { - headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap()); - } - - // add "host" header of original server_name if not exist (default) - if req.headers().get(header::HOST).is_none() { - let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); - req - .headers_mut() - .insert(header::HOST, HeaderValue::from_str(&org_host)?); - }; - - ///////////////////////////////////////////// - // Fix unique upstream destination since there could be multiple ones. - #[cfg(feature = "sticky-cookie")] - let (upstream_chosen_opt, context_from_lb) = { - let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { - takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? 
- } else { - None - }; - upstream_group.get(&context_to_lb) - }; - #[cfg(not(feature = "sticky-cookie"))] - let (upstream_chosen_opt, _) = upstream_group.get(&None); - - let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; - let context = HandlerContext { - #[cfg(feature = "sticky-cookie")] - context_lb: context_from_lb, - #[cfg(not(feature = "sticky-cookie"))] - context_lb: None, - }; - ///////////////////////////////////////////// - - // apply upstream-specific headers given in upstream_option - let headers = req.headers_mut(); - apply_upstream_options_to_header(headers, client_addr, upstream_group, &upstream_chosen.uri)?; - - // update uri in request - if !(upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some()) { - return Err(RpxyError::Handler("Upstream uri `scheme` and `authority` is broken")); - }; - let new_uri = Uri::builder() - .scheme(upstream_chosen.uri.scheme().unwrap().as_str()) - .authority(upstream_chosen.uri.authority().unwrap().as_str()); - let org_pq = match req.uri().path_and_query() { - Some(pq) => pq.to_string(), - None => "/".to_string(), - } - .into_bytes(); - - // replace some parts of path if opt_replace_path is enabled for chosen upstream - let new_pq = match &upstream_group.replace_path { - Some(new_path) => { - let matched_path: &[u8] = upstream_group.path.as_ref(); - if matched_path.is_empty() || org_pq.len() < matched_path.len() { - return Err(RpxyError::Handler("Upstream uri `path and query` is broken")); - }; - let mut new_pq = Vec::::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); - new_pq.extend_from_slice(new_path.as_ref()); - new_pq.extend_from_slice(&org_pq[matched_path.len()..]); - new_pq - } - None => org_pq, - }; - *req.uri_mut() = new_uri.path_and_query(new_pq).build()?; - - // upgrade - if let Some(v) = upgrade { - req.headers_mut().insert(header::UPGRADE, v.parse()?); - req - .headers_mut() - .insert(header::CONNECTION, HeaderValue::from_str("upgrade")?); - } - - // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3 - if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { - // Change version to http/1.1 when destination scheme is http - debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); - *req.version_mut() = Version::HTTP_11; - } else if req.version() == Version::HTTP_3 { - // HTTP/3 is always https - debug!("HTTP/3 is currently unsupported for request to upstream."); - *req.version_mut() = Version::HTTP_2; - } - - apply_upstream_options_to_request_line(req, upstream_group)?; - - Ok(context) - } -} diff --git a/rpxy-lib/src/handler/mod.rs b/rpxy-lib/src/handler/mod.rs deleted file mode 100644 index 84e0226..0000000 --- a/rpxy-lib/src/handler/mod.rs +++ /dev/null @@ -1,24 +0,0 @@ -#[cfg(feature = "cache")] -mod cache; -mod forwarder; -mod handler_main; -mod utils_headers; -mod utils_request; -mod utils_synth_response; - -#[cfg(feature = "sticky-cookie")] -use crate::backend::LbContext; -pub use { - forwarder::Forwarder, - handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}, -}; - -#[allow(dead_code)] -#[derive(Debug)] -/// Context object to handle sticky cookies at HTTP message handler -struct HandlerContext { - #[cfg(feature = "sticky-cookie")] - context_lb: Option, - #[cfg(not(feature = "sticky-cookie"))] - context_lb: Option<()>, -} diff --git a/rpxy-lib/src/handler/utils_request.rs b/rpxy-lib/src/handler/utils_request.rs 
deleted file mode 100644 index 6204f41..0000000 --- a/rpxy-lib/src/handler/utils_request.rs +++ /dev/null @@ -1,64 +0,0 @@ -use crate::{ - backend::{UpstreamGroup, UpstreamOption}, - error::*, -}; -use hyper::{header, Request}; - -//////////////////////////////////////////////////// -// Functions to manipulate request line - -/// Apply upstream options in request line, specified in the configuration -pub(super) fn apply_upstream_options_to_request_line(req: &mut Request, upstream: &UpstreamGroup) -> Result<()> { - for opt in upstream.opts.iter() { - match opt { - UpstreamOption::ForceHttp11Upstream => *req.version_mut() = hyper::Version::HTTP_11, - UpstreamOption::ForceHttp2Upstream => { - // case: h2c -> https://www.rfc-editor.org/rfc/rfc9113.txt - // Upgrade from HTTP/1.1 to HTTP/2 is deprecated. So, http-2 prior knowledge is required. - *req.version_mut() = hyper::Version::HTTP_2; - } - _ => (), - } - } - - Ok(()) -} - -/// Trait defining parser of hostname -pub trait ParseHost { - fn parse_host(&self) -> Result<&[u8]>; -} -impl ParseHost for Request { - /// Extract hostname from either the request HOST header or request line - fn parse_host(&self) -> Result<&[u8]> { - let headers_host = self.headers().get(header::HOST); - let uri_host = self.uri().host(); - // let uri_port = self.uri().port_u16(); - - if !(!(headers_host.is_none() && uri_host.is_none())) { - return Err(RpxyError::Request("No host in request header")); - } - - // prioritize server_name in uri - uri_host.map_or_else( - || { - let m = headers_host.unwrap().as_bytes(); - if m.starts_with(&[b'[']) { - // v6 address with bracket case. if port is specified, always it is in this case. - let mut iter = m.split(|ptr| ptr == &b'[' || ptr == &b']'); - iter.next().ok_or(RpxyError::Request("Invalid Host"))?; // first item is always blank - iter.next().ok_or(RpxyError::Request("Invalid Host")) - } else if m.len() - m.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 { - // v6 address case, if 2 or more ':' is contained - Ok(m) - } else { - // v4 address or hostname - m.split(|colon| colon == &b':') - .next() - .ok_or(RpxyError::Request("Invalid Host")) - } - }, - |v| Ok(v.as_bytes()), - ) - } -} diff --git a/rpxy-lib/src/handler/utils_synth_response.rs b/rpxy-lib/src/handler/utils_synth_response.rs deleted file mode 100644 index baa6987..0000000 --- a/rpxy-lib/src/handler/utils_synth_response.rs +++ /dev/null @@ -1,35 +0,0 @@ -// Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy -use crate::error::*; -use hyper::{Body, Request, Response, StatusCode, Uri}; - -//////////////////////////////////////////////////// -// Functions to create response (error or redirect) - -/// Generate a synthetic response message of a certain error status code -pub(super) fn http_error(status_code: StatusCode) -> Result> { - let response = Response::builder().status(status_code).body(Body::empty())?; - Ok(response) -} - -/// Generate synthetic response message of a redirection to https host with 301 -pub(super) fn secure_redirection( - server_name: &str, - tls_port: Option, - req: &Request, -) -> Result> { - let pq = match req.uri().path_and_query() { - Some(x) => x.as_str(), - _ => "", - }; - let new_uri = Uri::builder().scheme("https").path_and_query(pq); - let dest_uri = match tls_port { - Some(443) | None => new_uri.authority(server_name), - Some(p) => new_uri.authority(format!("{server_name}:{p}")), - } - .build()?; - let response = Response::builder() - .status(StatusCode::MOVED_PERMANENTLY) - .header("Location", 
dest_uri.to_string())
-    .body(Body::empty())?;
-  Ok(response)
-}
diff --git a/rpxy-lib/src/hyper_ext/body_incoming_like.rs b/rpxy-lib/src/hyper_ext/body_incoming_like.rs
new file mode 100644
index 0000000..9307b7f
--- /dev/null
+++ b/rpxy-lib/src/hyper_ext/body_incoming_like.rs
@@ -0,0 +1,370 @@
+use super::watch;
+use crate::error::*;
+use futures_channel::{mpsc, oneshot};
+use futures_util::{stream::FusedStream, Future, Stream};
+use http::HeaderMap;
+use hyper::body::{Body, Bytes, Frame, SizeHint};
+use std::{
+  pin::Pin,
+  task::{Context, Poll},
+};
+
+////////////////////////////////////////////////////////////
+/// Incoming-like body to handle incoming request body
+/// ported from https://github.com/hyperium/hyper/blob/master/src/body/incoming.rs
+pub struct IncomingLike {
+  content_length: DecodedLength,
+  want_tx: watch::Sender,
+  data_rx: mpsc::Receiver<Result<Bytes, RpxyError>>,
+  trailers_rx: oneshot::Receiver<HeaderMap>,
+}
+
+macro_rules! ready {
+  ($e:expr) => {
+    match $e {
+      Poll::Ready(v) => v,
+      Poll::Pending => return Poll::Pending,
+    }
+  };
+}
+
+type BodySender = mpsc::Sender<Result<Bytes, RpxyError>>;
+type TrailersSender = oneshot::Sender<HeaderMap>;
+
+const MAX_LEN: u64 = std::u64::MAX - 2;
+#[derive(Clone, Copy, PartialEq, Eq)]
+pub(crate) struct DecodedLength(u64);
+impl DecodedLength {
+  pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX);
+  pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1);
+  pub(crate) const ZERO: DecodedLength = DecodedLength(0);
+
+  #[allow(dead_code)]
+  pub(crate) fn new(len: u64) -> Self {
+    debug_assert!(len <= MAX_LEN);
+    DecodedLength(len)
+  }
+
+  pub(crate) fn sub_if(&mut self, amt: u64) {
+    match *self {
+      DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (),
+      DecodedLength(ref mut known) => {
+        *known -= amt;
+      }
+    }
+  }
+  /// Converts to an Option<u64> representing a Known or Unknown length.
+  pub(crate) fn into_opt(self) -> Option<u64> {
+    match self {
+      DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => None,
+      DecodedLength(known) => Some(known),
+    }
+  }
+}
+pub(crate) struct Sender {
+  want_rx: watch::Receiver,
+  data_tx: BodySender,
+  trailers_tx: Option<TrailersSender>,
+}
+
+const WANT_PENDING: usize = 1;
+const WANT_READY: usize = 2;
+
+impl IncomingLike {
+  /// Create a `Body` stream with an associated sender half.
+  ///
+  /// Useful when wanting to stream chunks from another thread.
+  #[inline]
+  #[allow(unused)]
+  pub(crate) fn channel() -> (Sender, IncomingLike) {
+    Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false)
+  }
+
+  pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, IncomingLike) {
+    let (data_tx, data_rx) = mpsc::channel(0);
+    let (trailers_tx, trailers_rx) = oneshot::channel();
+
+    // If wanter is true, `Sender::poll_ready()` won't become ready
+    // until the `Body` has been polled for data once.
+    let want = if wanter { WANT_PENDING } else { WANT_READY };
+
+    let (want_tx, want_rx) = watch::channel(want);
+
+    let tx = Sender {
+      want_rx,
+      data_tx,
+      trailers_tx: Some(trailers_tx),
+    };
+    let rx = IncomingLike {
+      content_length,
+      want_tx,
+      data_rx,
+      trailers_rx,
+    };
+
+    (tx, rx)
+  }
+}
+
+impl Body for IncomingLike {
+  type Data = Bytes;
+  type Error = RpxyError;
+
+  fn poll_frame(
+    mut self: Pin<&mut Self>,
+    cx: &mut Context<'_>,
+  ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
+    self.want_tx.send(WANT_READY);
+
+    if !self.data_rx.is_terminated() {
+      if let Some(chunk) = ready!(Pin::new(&mut self.data_rx).poll_next(cx)?)
{ + self.content_length.sub_if(chunk.len() as u64); + return Poll::Ready(Some(Ok(Frame::data(chunk)))); + } + } + + // check trailers after data is terminated + match ready!(Pin::new(&mut self.trailers_rx).poll(cx)) { + Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))), + Err(_) => Poll::Ready(None), + } + } + + fn is_end_stream(&self) -> bool { + self.content_length == DecodedLength::ZERO + } + + fn size_hint(&self) -> SizeHint { + macro_rules! opt_len { + ($content_length:expr) => {{ + let mut hint = SizeHint::default(); + + if let Some(content_length) = $content_length.into_opt() { + hint.set_exact(content_length); + } + + hint + }}; + } + + opt_len!(self.content_length) + } +} + +impl Sender { + /// Check to see if this `Sender` can send more data. + pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + // Check if the receiver end has tried polling for the body yet + ready!(self.poll_want(cx)?); + self + .data_tx + .poll_ready(cx) + .map_err(|_| RpxyError::HyperIncomingLikeNewClosed) + } + + fn poll_want(&mut self, cx: &mut Context<'_>) -> Poll> { + match self.want_rx.load(cx) { + WANT_READY => Poll::Ready(Ok(())), + WANT_PENDING => Poll::Pending, + watch::CLOSED => Poll::Ready(Err(RpxyError::HyperIncomingLikeNewClosed)), + unexpected => unreachable!("want_rx value: {}", unexpected), + } + } + + async fn ready(&mut self) -> RpxyResult<()> { + futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await + } + + /// Send data on data channel when it is ready. + #[allow(unused)] + pub(crate) async fn send_data(&mut self, chunk: Bytes) -> RpxyResult<()> { + self.ready().await?; + self + .data_tx + .try_send(Ok(chunk)) + .map_err(|_| RpxyError::HyperIncomingLikeNewClosed) + } + + /// Send trailers on trailers channel. + #[allow(unused)] + pub(crate) async fn send_trailers(&mut self, trailers: HeaderMap) -> RpxyResult<()> { + let tx = match self.trailers_tx.take() { + Some(tx) => tx, + None => return Err(RpxyError::HyperIncomingLikeNewClosed), + }; + tx.send(trailers).map_err(|_| RpxyError::HyperIncomingLikeNewClosed) + } + + /// Try to send data on this channel. + /// + /// # Errors + /// + /// Returns `Err(Bytes)` if the channel could not (currently) accept + /// another `Bytes`. + /// + /// # Note + /// + /// This is mostly useful for when trying to send from some other thread + /// that doesn't have an async context. If in an async context, prefer + /// `send_data()` instead. + #[allow(unused)] + pub(crate) fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> { + self + .data_tx + .try_send(Ok(chunk)) + .map_err(|err| err.into_inner().expect("just sent Ok")) + } + + #[allow(unused)] + pub(crate) fn abort(mut self) { + self.send_error(RpxyError::HyperNewBodyWriteAborted); + } + + pub(crate) fn send_error(&mut self, err: RpxyError) { + let _ = self + .data_tx + // clone so the send works even if buffer is full + .clone() + .try_send(Err(err)); + } +} + +#[cfg(test)] +mod tests { + use std::mem; + use std::task::Poll; + + use super::{Body, DecodedLength, IncomingLike, Sender, SizeHint}; + use crate::error::RpxyError; + use http_body_util::BodyExt; + + #[test] + fn test_size_of() { + // These are mostly to help catch *accidentally* increasing + // the size by too much. 
+ + let body_size = mem::size_of::(); + let body_expected_size = mem::size_of::() * 5; + assert!( + body_size <= body_expected_size, + "Body size = {} <= {}", + body_size, + body_expected_size, + ); + + //assert_eq!(body_size, mem::size_of::>(), "Option"); + + assert_eq!(mem::size_of::(), mem::size_of::() * 5, "Sender"); + + assert_eq!( + mem::size_of::(), + mem::size_of::>(), + "Option" + ); + } + #[test] + fn size_hint() { + fn eq(body: IncomingLike, b: SizeHint, note: &str) { + let a = body.size_hint(); + assert_eq!(a.lower(), b.lower(), "lower for {:?}", note); + assert_eq!(a.upper(), b.upper(), "upper for {:?}", note); + } + + eq(IncomingLike::channel().1, SizeHint::new(), "channel"); + + eq( + IncomingLike::new_channel(DecodedLength::new(4), /*wanter =*/ false).1, + SizeHint::with_exact(4), + "channel with length", + ); + } + + #[tokio::test] + async fn channel_abort() { + let (tx, mut rx) = IncomingLike::channel(); + + tx.abort(); + + match rx.frame().await.unwrap() { + Err(RpxyError::HyperNewBodyWriteAborted) => true, + unexpected => panic!("unexpected: {:?}", unexpected), + }; + } + + #[tokio::test] + async fn channel_abort_when_buffer_is_full() { + let (mut tx, mut rx) = IncomingLike::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + // buffer is full, but can still send abort + tx.abort(); + + let chunk1 = rx.frame().await.expect("item 1").expect("chunk 1").into_data().unwrap(); + assert_eq!(chunk1, "chunk 1"); + + match rx.frame().await.unwrap() { + Err(RpxyError::HyperNewBodyWriteAborted) => true, + unexpected => panic!("unexpected: {:?}", unexpected), + }; + } + + #[test] + fn channel_buffers_one() { + let (mut tx, _rx) = IncomingLike::channel(); + + tx.try_send_data("chunk 1".into()).expect("send 1"); + + // buffer is now full + let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2"); + assert_eq!(chunk2, "chunk 2"); + } + + #[tokio::test] + async fn channel_empty() { + let (_, mut rx) = IncomingLike::channel(); + + assert!(rx.frame().await.is_none()); + } + + #[test] + fn channel_ready() { + let (mut tx, _rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ false); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!(tx_ready.poll().is_ready(), "tx is ready immediately"); + } + + #[test] + fn channel_wanter() { + let (mut tx, mut rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + let mut rx_data = tokio_test::task::spawn(rx.frame()); + + assert!(tx_ready.poll().is_pending(), "tx isn't ready before rx has been polled"); + + assert!(rx_data.poll().is_pending(), "poll rx.data"); + assert!(tx_ready.is_woken(), "rx poll wakes tx"); + + assert!(tx_ready.poll().is_ready(), "tx is ready after rx has been polled"); + } + + #[test] + + fn channel_notices_closure() { + let (mut tx, rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); + + let mut tx_ready = tokio_test::task::spawn(tx.ready()); + + assert!(tx_ready.poll().is_pending(), "tx isn't ready before rx has been polled"); + + drop(rx); + assert!(tx_ready.is_woken(), "dropping rx wakes tx"); + + match tx_ready.poll() { + Poll::Ready(Err(RpxyError::HyperIncomingLikeNewClosed)) => (), + unexpected => panic!("tx poll ready unexpected: {:?}", unexpected), + } + } +} diff --git a/rpxy-lib/src/hyper_ext/body_type.rs b/rpxy-lib/src/hyper_ext/body_type.rs new file mode 100644 index 0000000..ca44756 --- /dev/null +++ b/rpxy-lib/src/hyper_ext/body_type.rs @@ 
-0,0 +1,75 @@
+use super::body::IncomingLike;
+use crate::error::RpxyError;
+use futures::channel::mpsc::UnboundedReceiver;
+use http_body_util::{combinators, BodyExt, Empty, Full, StreamBody};
+use hyper::body::{Body, Bytes, Frame, Incoming};
+use std::pin::Pin;
+
+/// Type for synthetic boxed body
+pub type BoxBody = combinators::BoxBody<Bytes, RpxyError>;
+
+/// helper function to build an empty body
+pub(crate) fn empty() -> BoxBody {
+  Empty::<Bytes>::new().map_err(|never| match never {}).boxed()
+}
+
+/// helper function to build a full body
+pub(crate) fn full(body: Bytes) -> BoxBody {
+  Full::new(body).map_err(|never| match never {}).boxed()
+}
+
+#[allow(unused)]
+/* ------------------------------------ */
+/// Request body used in this project
+/// - Incoming: just a type that only forwards the downstream request body to upstream.
+/// - IncomingLike: an Incoming-like type in which a channel is used
+pub enum RequestBody {
+  Incoming(Incoming),
+  IncomingLike(IncomingLike),
+}
+
+impl Body for RequestBody {
+  type Data = bytes::Bytes;
+  type Error = RpxyError;
+
+  fn poll_frame(
+    self: Pin<&mut Self>,
+    cx: &mut std::task::Context<'_>,
+  ) -> std::task::Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
+    match self.get_mut() {
+      RequestBody::Incoming(incoming) => Pin::new(incoming).poll_frame(cx).map_err(RpxyError::HyperBodyError),
+      RequestBody::IncomingLike(incoming_like) => Pin::new(incoming_like).poll_frame(cx),
+    }
+  }
+}
+
+/* ------------------------------------ */
+pub type UnboundedStreamBody = StreamBody<UnboundedReceiver<Result<Frame<Bytes>, hyper::Error>>>;
+
+#[allow(unused)]
+/// Response body used in this project
+/// - Incoming: just a type that only forwards the upstream response body to downstream.
+/// - Boxed: a type that is generated from cache or synthetic response body, e.g., small byte object.
+/// - Streamed: another type that is generated from stream, e.g., large byte object.
+pub enum ResponseBody { + Incoming(Incoming), + Boxed(BoxBody), + Streamed(UnboundedStreamBody), +} + +impl Body for ResponseBody { + type Data = bytes::Bytes; + type Error = RpxyError; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + match self.get_mut() { + ResponseBody::Incoming(incoming) => Pin::new(incoming).poll_frame(cx), + ResponseBody::Boxed(boxed) => Pin::new(boxed).poll_frame(cx), + ResponseBody::Streamed(streamed) => Pin::new(streamed).poll_frame(cx), + } + .map_err(RpxyError::HyperBodyError) + } +} diff --git a/rpxy-lib/src/hyper_ext/executor.rs b/rpxy-lib/src/hyper_ext/executor.rs new file mode 100644 index 0000000..579251e --- /dev/null +++ b/rpxy-lib/src/hyper_ext/executor.rs @@ -0,0 +1,23 @@ +use tokio::runtime::Handle; + +#[derive(Clone)] +/// Executor for hyper +pub struct LocalExecutor { + runtime_handle: Handle, +} + +impl LocalExecutor { + pub fn new(runtime_handle: Handle) -> Self { + LocalExecutor { runtime_handle } + } +} + +impl hyper::rt::Executor for LocalExecutor +where + F: std::future::Future + Send + 'static, + F::Output: Send, +{ + fn execute(&self, fut: F) { + self.runtime_handle.spawn(fut); + } +} diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs new file mode 100644 index 0000000..a4c5196 --- /dev/null +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -0,0 +1,16 @@ +mod body_incoming_like; +mod body_type; +mod executor; +mod tokio_timer; +mod watch; + +#[allow(unused)] +pub(crate) mod rt { + pub(crate) use super::executor::LocalExecutor; + pub(crate) use super::tokio_timer::{TokioSleep, TokioTimer}; +} +#[allow(unused)] +pub(crate) mod body { + pub(crate) use super::body_incoming_like::IncomingLike; + pub(crate) use super::body_type::{empty, full, BoxBody, RequestBody, ResponseBody, UnboundedStreamBody}; +} diff --git a/rpxy-lib/src/hyper_ext/tokio_timer.rs b/rpxy-lib/src/hyper_ext/tokio_timer.rs new file mode 100644 index 0000000..53a1af7 --- /dev/null +++ b/rpxy-lib/src/hyper_ext/tokio_timer.rs @@ -0,0 +1,55 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use hyper::rt::{Sleep, Timer}; +use pin_project_lite::pin_project; + +#[derive(Clone, Debug)] +pub struct TokioTimer; + +impl Timer for TokioTimer { + fn sleep(&self, duration: Duration) -> Pin> { + Box::pin(TokioSleep { + inner: tokio::time::sleep(duration), + }) + } + + fn sleep_until(&self, deadline: Instant) -> Pin> { + Box::pin(TokioSleep { + inner: tokio::time::sleep_until(deadline.into()), + }) + } + + fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { + if let Some(sleep) = sleep.as_mut().downcast_mut_pin::() { + sleep.reset(new_deadline) + } + } +} + +pin_project! { + pub(crate) struct TokioSleep { + #[pin] + pub(crate) inner: tokio::time::Sleep, + } +} + +impl Future for TokioSleep { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().inner.poll(cx) + } +} + +impl Sleep for TokioSleep {} + +impl TokioSleep { + pub fn reset(self: Pin<&mut Self>, deadline: Instant) { + self.project().inner.as_mut().reset(deadline.into()); + } +} diff --git a/rpxy-lib/src/hyper_ext/watch.rs b/rpxy-lib/src/hyper_ext/watch.rs new file mode 100644 index 0000000..d5e1c7e --- /dev/null +++ b/rpxy-lib/src/hyper_ext/watch.rs @@ -0,0 +1,67 @@ +//! An SPSC broadcast channel. +//! +//! - The value can only be a `usize`. +//! - The consumer is only notified if the value is different. +//! 
- The value `0` is reserved for closed. +// from https://github.com/hyperium/hyper/blob/master/src/common/watch.rs + +use futures_util::task::AtomicWaker; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; +use std::task; + +type Value = usize; + +pub(super) const CLOSED: usize = 0; + +pub(super) fn channel(initial: Value) -> (Sender, Receiver) { + debug_assert!(initial != CLOSED, "watch::channel initial state of 0 is reserved"); + + let shared = Arc::new(Shared { + value: AtomicUsize::new(initial), + waker: AtomicWaker::new(), + }); + + (Sender { shared: shared.clone() }, Receiver { shared }) +} + +pub(super) struct Sender { + shared: Arc, +} + +pub(super) struct Receiver { + shared: Arc, +} + +struct Shared { + value: AtomicUsize, + waker: AtomicWaker, +} + +impl Sender { + pub(super) fn send(&mut self, value: Value) { + if self.shared.value.swap(value, Ordering::SeqCst) != value { + self.shared.waker.wake(); + } + } +} + +impl Drop for Sender { + fn drop(&mut self) { + self.send(CLOSED); + } +} + +impl Receiver { + pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value { + self.shared.waker.register(cx.waker()); + self.shared.value.load(Ordering::SeqCst) + } + + #[allow(dead_code)] + pub(crate) fn peek(&self) -> Value { + self.shared.value.load(Ordering::Relaxed) + } +} diff --git a/rpxy-lib/src/lib.rs b/rpxy-lib/src/lib.rs index fd242c5..115b78a 100644 --- a/rpxy-lib/src/lib.rs +++ b/rpxy-lib/src/lib.rs @@ -1,26 +1,25 @@ mod backend; -mod certs; mod constants; +mod count; +mod crypto; mod error; +mod forwarder; mod globals; -mod handler; +mod hyper_ext; mod log; +mod message_handler; +mod name_exp; mod proxy; -mod utils; use crate::{ - error::*, - globals::Globals, - handler::{Forwarder, HttpMessageHandlerBuilder}, - log::*, - proxy::ProxyBuilder, + crypto::build_cert_reloader, error::*, forwarder::Forwarder, globals::Globals, log::*, + message_handler::HttpMessageHandlerBuilder, proxy::Proxy, }; use futures::future::select_all; -// use hyper_trust_dns::TrustDnsResolver; use std::sync::Arc; pub use crate::{ - certs::{CertsAndKeys, CryptoSource}, + crypto::{CertsAndKeys, CryptoSource}, globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri}, }; pub mod reexports { @@ -28,19 +27,22 @@ pub mod reexports { pub use rustls::{Certificate, PrivateKey}; } -#[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] -compile_error!("feature \"http3-quinn\" and feature \"http3-s2n\" cannot be enabled at the same time"); - /// Entrypoint that creates and spawns tasks of reverse proxy services pub async fn entrypoint( proxy_config: &ProxyConfig, app_config_list: &AppConfigList, runtime_handle: &tokio::runtime::Handle, term_notify: Option>, -) -> Result<()> +) -> RpxyResult<()> where T: CryptoSource + Clone + Send + Sync + 'static, { + #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] + warn!("Both \"http3-quinn\" and \"http3-s2n\" features are enabled. \"http3-quinn\" will be used"); + + #[cfg(all(feature = "native-tls-backend", feature = "rustls-backend"))] + warn!("Both \"native-tls-backend\" and \"rustls-backend\" features are enabled. 
\"rustls-backend\" will be used"); + // For initial message logging if proxy_config.listen_sockets.iter().any(|addr| addr.is_ipv6()) { info!("Listen both IPv4 and IPv6") @@ -53,6 +55,12 @@ where if proxy_config.https_port.is_some() { info!("Listen port: {} (for TLS)", proxy_config.https_port.unwrap()); } + if proxy_config.connection_handling_timeout.is_some() { + info!( + "Force connection handling timeout: {:?} sec", + proxy_config.connection_handling_timeout.unwrap_or_default().as_secs() + ); + } #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] if proxy_config.http3 { info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable."); @@ -62,52 +70,81 @@ where } #[cfg(feature = "cache")] if proxy_config.cache_enabled { - info!( - "Cache is enabled: cache dir = {:?}", - proxy_config.cache_dir.as_ref().unwrap() - ); + info!("Cache is enabled: cache dir = {:?}", proxy_config.cache_dir.as_ref().unwrap()); } else { info!("Cache is disabled") } - // build global + // 1. build backends, and make it contained in Arc + let app_manager = Arc::new(backend::BackendAppManager::try_from(app_config_list)?); + + // 2. build crypto reloader service + let (cert_reloader_service, cert_reloader_rx) = match proxy_config.https_port { + Some(_) => { + let (s, r) = build_cert_reloader(&app_manager).await?; + (Some(s), Some(r)) + } + None => (None, None), + }; + + // 3. build global shared context let globals = Arc::new(Globals { proxy_config: proxy_config.clone(), - backends: app_config_list.clone().try_into()?, request_count: Default::default(), runtime_handle: runtime_handle.clone(), + term_notify: term_notify.clone(), + cert_reloader_rx: cert_reloader_rx.clone(), }); - // build message handler including a request forwarder - let msg_handler = Arc::new( + // 4. build message handler containing Arc-ed http_client and backends, and make it contained in Arc as well + let forwarder = Arc::new(Forwarder::try_new(&globals).await?); + let message_handler = Arc::new( HttpMessageHandlerBuilder::default() - .forwarder(Arc::new(Forwarder::new(&globals).await)) .globals(globals.clone()) + .app_manager(app_manager.clone()) + .forwarder(forwarder) .build()?, ); + // 5. spawn each proxy for a given socket with copied Arc-ed message_handler. + // build hyper connection builder shared with proxy instances + let connection_builder = proxy::connection_builder(&globals); + + // spawn each proxy for a given socket with copied Arc-ed backend, message_handler and connection builder. 
let addresses = globals.proxy_config.listen_sockets.clone(); - let futures = select_all(addresses.into_iter().map(|addr| { + let futures_iter = addresses.into_iter().map(|listening_on| { let mut tls_enabled = false; if let Some(https_port) = globals.proxy_config.https_port { - tls_enabled = https_port == addr.port() + tls_enabled = https_port == listening_on.port() } - - let proxy = ProxyBuilder::default() - .globals(globals.clone()) - .listening_on(addr) - .tls_enabled(tls_enabled) - .msg_handler(msg_handler.clone()) - .build() - .unwrap(); - - globals.runtime_handle.spawn(proxy.start(term_notify.clone())) - })); + let proxy = Proxy { + globals: globals.clone(), + listening_on, + tls_enabled, + connection_builder: connection_builder.clone(), + message_handler: message_handler.clone(), + }; + globals.runtime_handle.spawn(async move { proxy.start().await }) + }); // wait for all future - if let (Ok(Err(e)), _, _) = futures.await { - error!("Some proxy services are down: {:?}", e); - }; + match cert_reloader_service { + Some(cert_service) => { + tokio::select! { + _ = cert_service.start() => { + error!("Certificate reloader service got down"); + } + _ = select_all(futures_iter) => { + error!("Some proxy services are down"); + } + } + } + None => { + if let (Ok(Err(e)), _, _) = select_all(futures_iter).await { + error!("Some proxy services are down: {}", e); + } + } + } Ok(()) } diff --git a/rpxy-lib/src/log.rs b/rpxy-lib/src/log.rs index 6b8afbe..c55b5c2 100644 --- a/rpxy-lib/src/log.rs +++ b/rpxy-lib/src/log.rs @@ -1,98 +1 @@ -use crate::utils::ToCanonical; -use hyper::header; -use std::net::SocketAddr; pub use tracing::{debug, error, info, warn}; - -#[derive(Debug, Clone)] -pub struct MessageLog { - // pub tls_server_name: String, - pub client_addr: String, - pub method: String, - pub host: String, - pub p_and_q: String, - pub version: hyper::Version, - pub uri_scheme: String, - pub uri_host: String, - pub ua: String, - pub xff: String, - pub status: String, - pub upstream: String, -} - -impl From<&hyper::Request> for MessageLog { - fn from(req: &hyper::Request) -> Self { - let header_mapper = |v: header::HeaderName| { - req - .headers() - .get(v) - .map_or_else(|| "", |s| s.to_str().unwrap_or("")) - .to_string() - }; - Self { - // tls_server_name: "".to_string(), - client_addr: "".to_string(), - method: req.method().to_string(), - host: header_mapper(header::HOST), - p_and_q: req - .uri() - .path_and_query() - .map_or_else(|| "", |v| v.as_str()) - .to_string(), - version: req.version(), - uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(), - uri_host: req.uri().host().unwrap_or("").to_string(), - ua: header_mapper(header::USER_AGENT), - xff: header_mapper(header::HeaderName::from_static("x-forwarded-for")), - status: "".to_string(), - upstream: "".to_string(), - } - } -} - -impl MessageLog { - pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self { - self.client_addr = client_addr.to_canonical().to_string(); - self - } - // pub fn tls_server_name(&mut self, tls_server_name: &str) -> &mut Self { - // self.tls_server_name = tls_server_name.to_string(); - // self - // } - pub fn status_code(&mut self, status_code: &hyper::StatusCode) -> &mut Self { - self.status = status_code.to_string(); - self - } - pub fn xff(&mut self, xff: &Option<&header::HeaderValue>) -> &mut Self { - self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); - self - } - pub fn upstream(&mut self, upstream: &hyper::Uri) -> &mut Self { - self.upstream = upstream.to_string(); 
- self - } - - pub fn output(&self) { - info!( - "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"", - if !self.host.is_empty() { - self.host.as_str() - } else { - self.uri_host.as_str() - }, - self.client_addr, - self.method, - self.p_and_q, - self.version, - self.status, - if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() { - format!("{}://{}", self.uri_scheme, self.uri_host) - } else { - "".to_string() - }, - self.ua, - self.xff, - self.upstream, - // self.tls_server_name - ); - } -} diff --git a/rpxy-lib/src/utils/socket_addr.rs b/rpxy-lib/src/message_handler/canonical_address.rs similarity index 96% rename from rpxy-lib/src/utils/socket_addr.rs rename to rpxy-lib/src/message_handler/canonical_address.rs index 105fc55..32dad78 100644 --- a/rpxy-lib/src/utils/socket_addr.rs +++ b/rpxy-lib/src/message_handler/canonical_address.rs @@ -1,5 +1,6 @@ use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +/// Trait to convert an IP address to its canonical form pub trait ToCanonical { fn to_canonical(&self) -> Self; } diff --git a/rpxy-lib/src/message_handler/handler_main.rs b/rpxy-lib/src/message_handler/handler_main.rs new file mode 100644 index 0000000..c46ac85 --- /dev/null +++ b/rpxy-lib/src/message_handler/handler_main.rs @@ -0,0 +1,248 @@ +use super::{ + http_log::HttpMessageLog, + http_result::{HttpError, HttpResult}, + synthetic_response::{secure_redirection_response, synthetic_error_response}, + utils_headers::*, + utils_request::InspectParseHost, +}; +use crate::{ + backend::{BackendAppManager, LoadBalanceContext}, + crypto::CryptoSource, + error::*, + forwarder::{ForwardRequest, Forwarder}, + globals::Globals, + hyper_ext::body::{RequestBody, ResponseBody}, + log::*, + name_exp::ServerName, +}; +use derive_builder::Builder; +use http::{Request, Response, StatusCode}; +use hyper_util::{client::legacy::connect::Connect, rt::TokioIo}; +use std::{net::SocketAddr, sync::Arc}; +use tokio::io::copy_bidirectional; + +#[allow(dead_code)] +#[derive(Debug)] +/// Context object to handle sticky cookies at HTTP message handler +pub(super) struct HandlerContext { + #[cfg(feature = "sticky-cookie")] + pub(super) context_lb: Option, + #[cfg(not(feature = "sticky-cookie"))] + pub(super) context_lb: Option<()>, +} + +#[derive(Clone, Builder)] +/// HTTP message handler for requests from clients and responses from backend applications, +/// responsible to manipulate and forward messages to upstream backends and downstream clients. +pub struct HttpMessageHandler +where + C: Send + Sync + Connect + Clone + 'static, + U: CryptoSource + Clone, +{ + forwarder: Arc>, + pub(super) globals: Arc, + app_manager: Arc>, +} + +impl HttpMessageHandler +where + C: Send + Sync + Connect + Clone + 'static, + U: CryptoSource + Clone, +{ + /// Handle incoming request message from a client. + /// Responsible to passthrough responses from backend applications or generate synthetic error responses. 
+ pub async fn handle_request( + &self, + req: Request, + client_addr: SocketAddr, // For access control + listen_addr: SocketAddr, + tls_enabled: bool, + tls_server_name: Option, + ) -> RpxyResult> { + // preparing log data + let mut log_data = HttpMessageLog::from(&req); + log_data.client_addr(&client_addr); + + let http_result = self + .handle_request_inner( + &mut log_data, + req, + client_addr, + listen_addr, + tls_enabled, + tls_server_name, + ) + .await; + + // passthrough or synthetic response + match http_result { + Ok(v) => { + log_data.status_code(&v.status()).output(); + Ok(v) + } + Err(e) => { + error!("{e}"); + let code = StatusCode::from(e); + log_data.status_code(&code).output(); + synthetic_error_response(code) + } + } + } + + /// Handle inner with no synthetic error response. + /// Synthetic response is generated by caller. + async fn handle_request_inner( + &self, + log_data: &mut HttpMessageLog, + mut req: Request, + client_addr: SocketAddr, // For access control + listen_addr: SocketAddr, + tls_enabled: bool, + tls_server_name: Option, + ) -> HttpResult> { + // Here we start to inspect and parse with server_name + let server_name = req + .inspect_parse_host() + .map(|v| ServerName::from(v.as_slice())) + .map_err(|_e| HttpError::InvalidHostInRequestHeader)?; + + // check consistency of between TLS SNI and HOST/Request URI Line. + #[allow(clippy::collapsible_if)] + if tls_enabled && self.globals.proxy_config.sni_consistency { + if server_name != tls_server_name.unwrap_or_default() { + return Err(HttpError::SniHostInconsistency); + } + } + // Find backend application for given server_name, and drop if incoming request is invalid as request. + let backend_app = match self.app_manager.apps.get(&server_name) { + Some(backend_app) => backend_app, + None => { + let Some(default_server_name) = &self.app_manager.default_server_name else { + return Err(HttpError::NoMatchingBackendApp); + }; + debug!("Serving by default app"); + self.app_manager.apps.get(default_server_name).unwrap() + } + }; + + // Redirect to https if !tls_enabled and redirect_to_https is true + if !tls_enabled && backend_app.https_redirection.unwrap_or(false) { + debug!( + "Redirect to secure connection: {}", + <&ServerName as TryInto>::try_into(&backend_app.server_name).unwrap_or_default() + ); + return secure_redirection_response(&backend_app.server_name, self.globals.proxy_config.https_port, &req); + } + + // Find reverse proxy for given path and choose one of upstream host + // Longest prefix match + let path = req.uri().path(); + let Some(upstream_candidates) = backend_app.path_manager.get(path) else { + return Err(HttpError::NoUpstreamCandidates); + }; + + // Upgrade in request header + let upgrade_in_request = extract_upgrade(req.headers()); + if upgrade_in_request.is_some() && req.version() != http::Version::HTTP_11 { + return Err(HttpError::FailedToUpgrade(format!( + "Unsupported HTTP version: {:?}", + req.version() + ))); + } + // let request_upgraded = req.extensions_mut().remove::(); + let req_on_upgrade = hyper::upgrade::on(&mut req); + + // Build request from destination information + let _context = match self.generate_request_forwarded( + &client_addr, + &listen_addr, + &mut req, + &upgrade_in_request, + upstream_candidates, + tls_enabled, + ) { + Err(e) => { + return Err(HttpError::FailedToGenerateUpstreamRequest(e.to_string())); + } + Ok(v) => v, + }; + debug!( + "Request to be forwarded: [uri {}, method: {}, version {:?}, headers {:?}]", + req.uri(), + req.method(), + req.version(), + 
req.headers()
+    );
+    log_data.xff(&req.headers().get("x-forwarded-for"));
+    log_data.upstream(req.uri());
+    //////
+
+    //////////////
+    // Forward request to a chosen backend
+    let mut res_backend = match self.forwarder.request(req).await {
+      Ok(v) => v,
+      Err(e) => {
+        return Err(HttpError::FailedToGetResponseFromBackend(e.to_string()));
+      }
+    };
+    //////////////
+    // Process reverse proxy context generated during the forwarding request generation.
+    #[cfg(feature = "sticky-cookie")]
+    if let Some(context_from_lb) = _context.context_lb {
+      let res_headers = res_backend.headers_mut();
+      if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) {
+        return Err(HttpError::FailedToAddSetCookeInResponse(e.to_string()));
+      }
+    }
+
+    if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS {
+      // Generate response to client
+      if let Err(e) = self.generate_response_forwarded(&mut res_backend, backend_app) {
+        return Err(HttpError::FailedToGenerateDownstreamResponse(e.to_string()));
+      }
+      return Ok(res_backend);
+    }
+
+    // Handle StatusCode::SWITCHING_PROTOCOLS in response
+    let upgrade_in_response = extract_upgrade(res_backend.headers());
+    let should_upgrade = match (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) {
+      (Some(u_req), Some(u_res)) => u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase(),
+      _ => false,
+    };
+
+    if !should_upgrade {
+      return Err(HttpError::FailedToUpgrade(format!(
+        "Backend tried to switch to protocol {:?} when {:?} was requested",
+        upgrade_in_response, upgrade_in_request
+      )));
+    }
+    // let Some(request_upgraded) = request_upgraded else {
+    //   return Err(HttpError::NoUpgradeExtensionInRequest);
+    // };
+
+    // let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else {
+    //   return Err(HttpError::NoUpgradeExtensionInResponse);
+    // };
+    let res_on_upgrade = hyper::upgrade::on(&mut res_backend);
+
+    self.globals.runtime_handle.spawn(async move {
+      let mut response_upgraded = TokioIo::new(res_on_upgrade.await.map_err(|e| {
+        error!("Failed to upgrade response: {}", e);
+        RpxyError::FailedToUpgradeResponse(e.to_string())
+      })?);
+      let mut request_upgraded = TokioIo::new(req_on_upgrade.await.map_err(|e| {
+        error!("Failed to upgrade request: {}", e);
+        RpxyError::FailedToUpgradeRequest(e.to_string())
+      })?);
+      copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
+        .await
+        .map_err(|e| {
+          error!("Copying between upgraded connections failed: {}", e);
+          RpxyError::FailedToCopyBidirectional(e.to_string())
+        })?;
+      Ok(()) as RpxyResult<()>
+    });
+
+    Ok(res_backend)
+  }
+}
diff --git a/rpxy-lib/src/message_handler/handler_manipulate_messages.rs b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs
new file mode 100644
index 0000000..ecfd53c
--- /dev/null
+++ b/rpxy-lib/src/message_handler/handler_manipulate_messages.rs
@@ -0,0 +1,185 @@
+use super::{handler_main::HandlerContext, utils_headers::*, utils_request::update_request_line, HttpMessageHandler};
+use crate::{
+  backend::{BackendApp, UpstreamCandidates},
+  constants::RESPONSE_HEADER_SERVER,
+  log::*,
+  CryptoSource,
+};
+use anyhow::{anyhow, ensure, Result};
+use http::{header, HeaderValue, Request, Response, Uri};
+use hyper_util::client::legacy::connect::Connect;
+use std::net::SocketAddr;
+
+impl<C, U> HttpMessageHandler<C, U>
+where
+  C: Send + Sync + Connect + Clone + 'static,
+  U: CryptoSource + Clone,
+{
+  ////////////////////////////////////////////////////
+  // Functions to generate messages
+  ////////////////////////////////////////////////////
+
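Before the manipulation functions that follow, it may help to see the exact `alt-svc` value they emit to advertise HTTP/3. This is an illustrative helper only, with `h3_alt_svc` being a hypothetical name (the patch itself inlines the `format!` call):

```rust
/// Hypothetical helper showing the Alt-Svc value generated below:
/// both the final `h3` and draft `h3-29` identifiers point at the TLS port.
fn h3_alt_svc(port: u16, max_age: u32) -> String {
  format!("h3=\":{port}\"; ma={max_age}, h3-29=\":{port}\"; ma={max_age}")
}

fn main() {
  // For port 443 and a one-hour max-age:
  assert_eq!(h3_alt_svc(443, 3600), r#"h3=":443"; ma=3600, h3-29=":443"; ma=3600"#);
}
```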
+  #[allow(unused_variables)]
+  /// Manipulate a response message sent from a backend application to forward downstream to a client.
+  pub(super) fn generate_response_forwarded(
+    &self,
+    response: &mut Response<ResponseBody>,
+    backend_app: &BackendApp<U>,
+  ) -> Result<()> {
+    let headers = response.headers_mut();
+    remove_connection_header(headers);
+    remove_hop_header(headers);
+    add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?;
+
+    #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
+    {
+      // Manipulate the ALT_SVC header to allow h3 in the response message only when mutual TLS is not enabled
+      // TODO: This is a workaround for avoiding client authentication in HTTP/3
+      if self.globals.proxy_config.http3 && backend_app.crypto_source.as_ref().is_some_and(|v| !v.is_mutual_tls()) {
+        if let Some(port) = self.globals.proxy_config.https_port {
+          add_header_entry_overwrite_if_exist(
+            headers,
+            header::ALT_SVC.as_str(),
+            format!(
+              "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}",
+              port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age
+            ),
+          )?;
+        }
+      } else {
+        // remove alt-svc to disallow requests via http3
+        headers.remove(header::ALT_SVC.as_str());
+      }
+    }
+    #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))]
+    {
+      if self.globals.proxy_config.https_port.is_some() {
+        headers.remove(header::ALT_SVC.as_str());
+      }
+    }
+
+    Ok(())
+  }
+
+  #[allow(clippy::too_many_arguments)]
+  /// Manipulate a request message sent from a client to forward upstream to a backend application
+  pub(super) fn generate_request_forwarded(
+    &self,
+    client_addr: &SocketAddr,
+    listen_addr: &SocketAddr,
+    req: &mut Request<RequestBody>,
+    upgrade: &Option<String>,
+    upstream_candidates: &UpstreamCandidates,
+    tls_enabled: bool,
+  ) -> Result<HandlerContext> {
+    debug!("Generate request to be forwarded");
+
+    // Add `te: trailers` if contained in the original request
+    let contains_te_trailers = {
+      if let Some(te) = req.headers().get(header::TE) {
+        te.as_bytes()
+          .split(|v| v == &b',' || v == &b' ')
+          .any(|x| x == "trailers".as_bytes())
+      } else {
+        false
+      }
+    };
+
+    let original_uri = req.uri().to_string();
+    let headers = req.headers_mut();
+    // delete headers specified in header.connection
+    remove_connection_header(headers);
+    // delete hop headers including header.connection
+    remove_hop_header(headers);
+    // X-Forwarded-For
+    add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &original_uri)?;
+
+    // Add `te: trailers` if te_trailer
+    if contains_te_trailers {
+      headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap());
+    }
+
+    // by default, add a "host" header carrying the original server_name if none exists
+    if req.headers().get(header::HOST).is_none() {
+      let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned();
+      req
+        .headers_mut()
+        .insert(header::HOST, HeaderValue::from_str(&org_host)?);
+    };
+    debug!("Host header after manipulation: {:?}", req.headers().get(header::HOST));
+
+    /////////////////////////////////////////////
+    // Fix unique upstream destination since there could be multiple ones.
+    #[cfg(feature = "sticky-cookie")]
+    let (upstream_chosen_opt, context_from_lb) = {
+      let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_candidates.load_balance {
+        takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)?
+      } else {
+        None
+      };
+      upstream_candidates.get(&context_to_lb)
+    };
+    #[cfg(not(feature = "sticky-cookie"))]
+    let (upstream_chosen_opt, _) = upstream_candidates.get(&None);
+
+    let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?;
+    let context = HandlerContext {
+      #[cfg(feature = "sticky-cookie")]
+      context_lb: context_from_lb,
+      #[cfg(not(feature = "sticky-cookie"))]
+      context_lb: None,
+    };
+    /////////////////////////////////////////////
+
+    // apply upstream-specific headers given in upstream_option
+    let headers = req.headers_mut();
+    // apply upstream options to header
+    apply_upstream_options_to_header(headers, &upstream_chosen.uri, upstream_candidates)?;
+
+    // update uri in request
+    ensure!(
+      upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some(),
+      "Upstream uri `scheme` and `authority` are broken"
+    );
+
+    let new_uri = Uri::builder()
+      .scheme(upstream_chosen.uri.scheme().unwrap().as_str())
+      .authority(upstream_chosen.uri.authority().unwrap().as_str());
+    let org_pq = match req.uri().path_and_query() {
+      Some(pq) => pq.to_string(),
+      None => "/".to_string(),
+    }
+    .into_bytes();
+
+    // replace some parts of path if opt_replace_path is enabled for chosen upstream
+    let new_pq = match &upstream_candidates.replace_path {
+      Some(new_path) => {
+        let matched_path: &[u8] = upstream_candidates.path.as_ref();
+        ensure!(
+          !matched_path.is_empty() && org_pq.len() >= matched_path.len(),
+          "Upstream uri `path and query` is broken"
+        );
+        let mut new_pq = Vec::<u8>::with_capacity(org_pq.len() - matched_path.len() + new_path.len());
+        new_pq.extend_from_slice(new_path.as_ref());
+        new_pq.extend_from_slice(&org_pq[matched_path.len()..]);
+        new_pq
+      }
+      None => org_pq,
+    };
+    *req.uri_mut() = new_uri.path_and_query(new_pq).build()?;
+
+    // upgrade
+    if let Some(v) = upgrade {
+      req.headers_mut().insert(header::UPGRADE, v.parse()?);
+      req
+        .headers_mut()
+        .insert(header::CONNECTION, HeaderValue::from_static("upgrade"));
+    }
+    if upgrade.is_none() {
+      // can update the request line, i.e., the http version, only when not upgrading (http/1.1)
+      update_request_line(req, upstream_chosen, upstream_candidates)?;
+    }
+
+    Ok(context)
+  }
+}
diff --git a/rpxy-lib/src/message_handler/http_log.rs b/rpxy-lib/src/message_handler/http_log.rs
new file mode 100644
index 0000000..acda9f0
--- /dev/null
+++ b/rpxy-lib/src/message_handler/http_log.rs
@@ -0,0 +1,99 @@
+use super::canonical_address::ToCanonical;
+use crate::log::*;
+use http::header;
+use std::net::SocketAddr;
+
+/// Struct to log HTTP messages
+#[derive(Debug, Clone)]
+pub struct HttpMessageLog {
+  // pub tls_server_name: String,
+  pub client_addr: String,
+  pub method: String,
+  pub host: String,
+  pub p_and_q: String,
+  pub version: http::Version,
+  pub uri_scheme: String,
+  pub uri_host: String,
+  pub ua: String,
+  pub xff: String,
+  pub status: String,
+  pub upstream: String,
+}
+
+impl<B> From<&http::Request<B>> for HttpMessageLog {
+  fn from(req: &http::Request<B>) -> Self {
+    let header_mapper = |v: header::HeaderName| {
+      req
+        .headers()
+        .get(v)
+        .map_or_else(|| "", |s| s.to_str().unwrap_or(""))
+        .to_string()
+    };
+    Self {
+      // tls_server_name: "".to_string(),
+      client_addr: "".to_string(),
+      method: req.method().to_string(),
+      host: header_mapper(header::HOST),
+      p_and_q: req
+        .uri()
+        .path_and_query()
+        .map_or_else(|| "", |v| v.as_str())
+        .to_string(),
+      version: req.version(),
+      uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(),
+      uri_host: req.uri().host().unwrap_or("").to_string(),
+      ua: header_mapper(header::USER_AGENT),
+      xff: header_mapper(header::HeaderName::from_static("x-forwarded-for")),
+      status: "".to_string(),
+      upstream: "".to_string(),
+    }
+  }
+}
+
+impl HttpMessageLog {
+  pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self {
+    self.client_addr = client_addr.to_canonical().to_string();
+    self
+  }
+  // pub fn tls_server_name(&mut self, tls_server_name: &str) -> &mut Self {
+  //   self.tls_server_name = tls_server_name.to_string();
+  //   self
+  // }
+  pub fn status_code(&mut self, status_code: &http::StatusCode) -> &mut Self {
+    self.status = status_code.to_string();
+    self
+  }
+  pub fn xff(&mut self, xff: &Option<&header::HeaderValue>) -> &mut Self {
+    self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string();
+    self
+  }
+  pub fn upstream(&mut self, upstream: &http::Uri) -> &mut Self {
+    self.upstream = upstream.to_string();
+    self
+  }
+
+  pub fn output(&self) {
+    info!(
+      "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"",
+      if !self.host.is_empty() {
+        self.host.as_str()
+      } else {
+        self.uri_host.as_str()
+      },
+      self.client_addr,
+      self.method,
+      self.p_and_q,
+      self.version,
+      self.status,
+      if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() {
+        format!("{}://{}", self.uri_scheme, self.uri_host)
+      } else {
+        "".to_string()
+      },
+      self.ua,
+      self.xff,
+      self.upstream,
+      // self.tls_server_name
+    );
+  }
+}
diff --git a/rpxy-lib/src/message_handler/http_result.rs b/rpxy-lib/src/message_handler/http_result.rs
new file mode 100644
index 0000000..98cdb45
--- /dev/null
+++ b/rpxy-lib/src/message_handler/http_result.rs
@@ -0,0 +1,61 @@
+use http::StatusCode;
+use thiserror::Error;
+
+/// HTTP result type, where T is typically a hyper::Response.
+/// HttpError is used to generate a synthetic error response.
+pub(crate) type HttpResult<T> = std::result::Result<T, HttpError>;
+
+/// Describes things that can go wrong in the forwarder
+#[derive(Debug, Error)]
+pub enum HttpError {
+  // #[error("No host is given in request header")]
+  // NoHostInRequestHeader,
+  #[error("Invalid host in request header")]
+  InvalidHostInRequestHeader,
+  #[error("SNI and Host header mismatch")]
+  SniHostInconsistency,
+  #[error("No matching backend app")]
+  NoMatchingBackendApp,
+  #[error("Failed to redirect: {0}")]
+  FailedToRedirect(String),
+  #[error("No upstream candidates")]
+  NoUpstreamCandidates,
+  #[error("Failed to generate upstream request for backend application: {0}")]
+  FailedToGenerateUpstreamRequest(String),
+  #[error("Failed to get response from backend: {0}")]
+  FailedToGetResponseFromBackend(String),
+
+  #[error("Failed to add set-cookie header in response: {0}")]
+  FailedToAddSetCookieInResponse(String),
+  #[error("Failed to generate downstream response for clients: {0}")]
+  FailedToGenerateDownstreamResponse(String),
+
+  #[error("Failed to upgrade connection: {0}")]
+  FailedToUpgrade(String),
+  // #[error("Request does not have an upgrade extension")]
+  // NoUpgradeExtensionInRequest,
+  // #[error("Response does not have an upgrade extension")]
+  // NoUpgradeExtensionInResponse,
+  #[error(transparent)]
+  Other(#[from] anyhow::Error),
+}
+
+impl From<HttpError> for StatusCode {
+  fn from(e: HttpError) -> StatusCode {
+    match e {
+      // HttpError::NoHostInRequestHeader => StatusCode::BAD_REQUEST,
+      HttpError::InvalidHostInRequestHeader => StatusCode::BAD_REQUEST,
+      HttpError::SniHostInconsistency => StatusCode::MISDIRECTED_REQUEST,
+      HttpError::NoMatchingBackendApp => StatusCode::SERVICE_UNAVAILABLE,
+      HttpError::FailedToRedirect(_) => StatusCode::INTERNAL_SERVER_ERROR,
+      HttpError::NoUpstreamCandidates => StatusCode::NOT_FOUND,
+      HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR,
+      HttpError::FailedToAddSetCookieInResponse(_) => StatusCode::INTERNAL_SERVER_ERROR,
+      HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR,
+      HttpError::FailedToUpgrade(_) => StatusCode::INTERNAL_SERVER_ERROR,
+      // HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST,
+      // HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY,
+      _ => StatusCode::INTERNAL_SERVER_ERROR,
+    }
+  }
+}
diff --git a/rpxy-lib/src/message_handler/mod.rs b/rpxy-lib/src/message_handler/mod.rs
new file mode 100644
index 0000000..edeba27
--- /dev/null
+++ b/rpxy-lib/src/message_handler/mod.rs
@@ -0,0 +1,11 @@
+mod canonical_address;
+mod handler_main;
+mod handler_manipulate_messages;
+mod http_log;
+mod http_result;
+mod synthetic_response;
+mod utils_headers;
+mod utils_request;
+
+pub use handler_main::HttpMessageHandlerBuilderError;
+pub(crate) use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder};
diff --git a/rpxy-lib/src/message_handler/synthetic_response.rs b/rpxy-lib/src/message_handler/synthetic_response.rs
new file mode 100644
index 0000000..a955a2d
--- /dev/null
+++ b/rpxy-lib/src/message_handler/synthetic_response.rs
@@ -0,0 +1,42 @@
+use super::http_result::{HttpError, HttpResult};
+use crate::{
+  error::*,
+  hyper_ext::body::{empty, ResponseBody},
+  name_exp::ServerName,
+};
+use http::{Request, Response, StatusCode, Uri};
+
+/// build http response with status code of 4xx and 5xx
+pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult<Response<ResponseBody>> {
+  let res = Response::builder()
+    .status(status_code)
+    .body(ResponseBody::Boxed(empty()))
+    .unwrap();
+  Ok(res)
+}
+
+/// Generate synthetic response message of a redirection to https host with 301
+pub(super) fn secure_redirection_response<B>(
+  server_name: &ServerName,
+  tls_port: Option<u16>,
+  req: &Request<B>,
+) -> HttpResult<Response<ResponseBody>> {
+  let server_name: String = server_name.try_into().unwrap_or_default();
+  let pq = match req.uri().path_and_query() {
+    Some(x) => x.as_str(),
+    _ => "",
+  };
+  let new_uri = Uri::builder().scheme("https").path_and_query(pq);
+  let dest_uri = match tls_port {
+    Some(443) | None => new_uri.authority(server_name),
+    Some(p) => new_uri.authority(format!("{server_name}:{p}")),
+  }
+  .build()
+  .map_err(|e| HttpError::FailedToRedirect(e.to_string()))?;
+  let response = Response::builder()
+    .status(StatusCode::MOVED_PERMANENTLY)
+    .header("Location", dest_uri.to_string())
+    .body(ResponseBody::Boxed(empty()))
+    .map_err(|e| HttpError::FailedToRedirect(e.to_string()))?;
+  Ok(response)
+}
diff --git a/rpxy-lib/src/handler/utils_headers.rs b/rpxy-lib/src/message_handler/utils_headers.rs
similarity index 81%
rename from rpxy-lib/src/handler/utils_headers.rs
rename to rpxy-lib/src/message_handler/utils_headers.rs
index 6a09c1d..9be45e5 100644
--- a/rpxy-lib/src/handler/utils_headers.rs
+++ b/rpxy-lib/src/message_handler/utils_headers.rs
@@ -1,26 +1,27 @@
-#[cfg(feature = "sticky-cookie")]
-use crate::backend::{LbContext, StickyCookie, StickyCookieValue};
-use crate::backend::{UpstreamGroup, UpstreamOption};
-
-use crate::{error::*, log::*, utils::*};
-use bytes::BufMut;
-use hyper::{
-  header::{self, HeaderMap, HeaderName, HeaderValue},
-  Uri,
+use super::canonical_address::ToCanonical;
+use crate::{
+  backend::{UpstreamCandidates, UpstreamOption},
+
log::*, }; +use anyhow::{anyhow, ensure, Result}; +use bytes::BufMut; +use http::{header, HeaderMap, HeaderName, HeaderValue, Uri}; use std::{borrow::Cow, net::SocketAddr}; -//////////////////////////////////////////////////// -// Functions to manipulate headers +#[cfg(feature = "sticky-cookie")] +use crate::backend::{LoadBalanceContext, StickyCookie, StickyCookieValue}; +// use crate::backend::{UpstreamGroup, UpstreamOption}; +// //////////////////////////////////////////////////// +// // Functions to manipulate headers #[cfg(feature = "sticky-cookie")] /// Take sticky cookie header value from request header, -/// and returns LbContext to be forwarded to LB if exist and if needed. +/// and returns LoadBalanceContext to be forwarded to LB if exist and if needed. /// Removing sticky cookie is needed and it must not be passed to the upstream. pub(super) fn takeout_sticky_cookie_lb_context( headers: &mut HeaderMap, expected_cookie_name: &str, -) -> Result> { +) -> Result> { let mut headers_clone = headers.clone(); match headers_clone.entry(header::COOKIE) { @@ -35,12 +36,11 @@ pub(super) fn takeout_sticky_cookie_lb_context( if sticky_cookies.is_empty() { return Ok(None); } - if sticky_cookies.len() > 1 { - error!("Multiple sticky cookie values in request"); - return Err(RpxyError::Other(anyhow!( - "Invalid cookie: Multiple sticky cookie values" - ))); - } + ensure!( + sticky_cookies.len() == 1, + "Invalid cookie: Multiple sticky cookie values" + ); + let cookies_passed_to_upstream = without_sticky_cookies.join("; "); let cookie_passed_to_lb = sticky_cookies.first().unwrap(); headers.remove(header::COOKIE); @@ -50,7 +50,7 @@ pub(super) fn takeout_sticky_cookie_lb_context( value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?, info: None, }; - Ok(Some(LbContext { sticky_cookie })) + Ok(Some(LoadBalanceContext { sticky_cookie })) } } } @@ -59,7 +59,10 @@ pub(super) fn takeout_sticky_cookie_lb_context( /// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated. /// Set-Cookie response header could be in multiple lines. /// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie -pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from_lb: &LbContext) -> Result<()> { +pub(super) fn set_sticky_cookie_lb_context( + headers: &mut HeaderMap, + context_from_lb: &LoadBalanceContext, +) -> Result<()> { let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?; let new_header_val: HeaderValue = sticky_cookie_string.parse()?; let expected_cookie_name = &context_from_lb.sticky_cookie.value.name; @@ -83,23 +86,37 @@ pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from Ok(()) } +/// overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy) +fn override_host_header(headers: &mut HeaderMap, upstream_base_uri: &Uri) -> Result<()> { + let mut upstream_host = upstream_base_uri + .host() + .ok_or_else(|| anyhow!("No hostname is given"))? 
+    .to_string();
+  // add port if it is not the default
+  if let Some(port) = upstream_base_uri.port_u16() {
+    upstream_host = format!("{}:{}", upstream_host, port);
+  }
+
+  // overwrite the host header; this removes all existing HOST header values
+  headers.insert(header::HOST, HeaderValue::from_str(&upstream_host)?);
+  Ok(())
+}
+
 /// Apply options to request header, which are specified in the configuration
 pub(super) fn apply_upstream_options_to_header(
   headers: &mut HeaderMap,
-  _client_addr: &SocketAddr,
-  upstream: &UpstreamGroup,
   upstream_base_uri: &Uri,
+  // _client_addr: &SocketAddr,
+  upstream: &UpstreamCandidates,
 ) -> Result<()> {
-  for opt in upstream.opts.iter() {
+  for opt in upstream.options.iter() {
     match opt {
-      UpstreamOption::OverrideHost => {
-        // overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy)
-        let upstream_host = upstream_base_uri
-          .host()
-          .ok_or_else(|| anyhow!("No hostname is given in override_host option"))?;
-        headers
-          .insert(header::HOST, HeaderValue::from_str(upstream_host)?)
-          .ok_or_else(|| anyhow!("Failed to insert host header in override_host option"))?;
+      UpstreamOption::SetUpstreamHost => {
+        // prioritize KeepOriginalHost
+        if !upstream.options.contains(&UpstreamOption::KeepOriginalHost) {
+          // overwrite the host header; this removes all existing HOST header values
+          override_host_header(headers, upstream_base_uri)?;
+        }
       }
       UpstreamOption::UpgradeInsecureRequests => {
         // add upgrade-insecure-requests in request header if not exist
diff --git a/rpxy-lib/src/message_handler/utils_request.rs b/rpxy-lib/src/message_handler/utils_request.rs
new file mode 100644
index 0000000..8939433
--- /dev/null
+++ b/rpxy-lib/src/message_handler/utils_request.rs
@@ -0,0 +1,86 @@
+use crate::{
+  backend::{Upstream, UpstreamCandidates, UpstreamOption},
+  log::*,
+};
+use anyhow::{anyhow, ensure, Result};
+use http::{header, uri::Scheme, Request, Version};
+
+/// Trait defining a hostname parser:
+/// inspect and extract the hostname from either the request HOST header or the request line
+pub trait InspectParseHost {
+  type Error;
+  fn inspect_parse_host(&self) -> Result<Vec<u8>, Self::Error>;
+}
+impl<B> InspectParseHost for Request<B> {
+  type Error = anyhow::Error;
+  /// Inspect and extract the hostname from either the request HOST header or the request line
+  fn inspect_parse_host(&self) -> Result<Vec<u8>> {
+    let drop_port = |v: &[u8]| {
+      if v.starts_with(&[b'[']) {
+        // IPv6 address in brackets; when a port is specified, the host always takes this form
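+        // e.g., b"[2001:db8::1]:8443" splits at '[' and ']' into
+        // ["", "2001:db8::1", ":8443"], so the second item is the address itself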
+        let mut iter = v.split(|ptr| ptr == &b'[' || ptr == &b']');
+        iter.next().ok_or(anyhow!("Invalid Host header"))?; // first item is always blank
+        iter.next().ok_or(anyhow!("Invalid Host header")).map(|b| b.to_owned())
+      } else if v.len() - v.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 {
+        // bare IPv6 address: two or more ':' are contained
+        Ok(v.to_owned())
+      } else {
+        // v4 address or hostname
+        v.split(|colon| colon == &b':')
+          .next()
+          .ok_or(anyhow!("Invalid Host header"))
+          .map(|v| v.to_ascii_lowercase())
+      }
+    };
+
+    let headers_host = self.headers().get(header::HOST).map(|v| drop_port(v.as_bytes()));
+    let uri_host = self.uri().host().map(|v| drop_port(v.as_bytes()));
+    // let uri_port = self.uri().port_u16();
+
+    // if both the Host header and the uri host are present, they must agree; otherwise use whichever exists
+    match (headers_host, uri_host) {
+      (Some(Ok(hh)), Some(Ok(hu))) => {
+        ensure!(hh == hu, "Host header and uri host mismatch");
+        Ok(hh)
+      }
+      (Some(Ok(hh)), None) => Ok(hh),
+      (None, Some(Ok(hu))) => Ok(hu),
+      _ => Err(anyhow!("Neither Host header nor uri host is valid")),
+    }
+  }
+}
+
+////////////////////////////////////////////////////
+// Functions to manipulate request line
+
+/// Update the request line, e.g., the version, applying upstream options specified in the configuration
+pub(super) fn update_request_line<B>(
+  req: &mut Request<B>,
+  upstream_chosen: &Upstream,
+  upstream_candidates: &UpstreamCandidates,
+) -> anyhow::Result<()> {
+  // Unless force_httpXX_upstream is specified, the version is preserved for https upstreams (except for HTTP/3)
+  if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) {
+    // Change version to http/1.1 when destination scheme is http
+    debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled.");
+    *req.version_mut() = Version::HTTP_11;
+  } else if req.version() == Version::HTTP_3 {
+    // HTTP/3 is always https
+    debug!("HTTP/3 is currently unsupported for requests to upstream.");
+    *req.version_mut() = Version::HTTP_2;
+  }
+
+  for opt in upstream_candidates.options.iter() {
+    match opt {
+      UpstreamOption::ForceHttp11Upstream => *req.version_mut() = Version::HTTP_11,
+      UpstreamOption::ForceHttp2Upstream => {
+        // case: h2c -> https://www.rfc-editor.org/rfc/rfc9113.txt
+        // Upgrade from HTTP/1.1 to HTTP/2 is deprecated, so HTTP/2 prior knowledge is required.
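+        // Hence no Upgrade dance is attempted here: the version is set directly,
+        // assuming the upstream accepts HTTP/2 with prior knowledge.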
+        *req.version_mut() = Version::HTTP_2;
+      }
+      _ => (),
+    }
+  }
+
+  Ok(())
+}
diff --git a/rpxy-lib/src/name_exp.rs b/rpxy-lib/src/name_exp.rs
new file mode 100644
index 0000000..8ed17e2
--- /dev/null
+++ b/rpxy-lib/src/name_exp.rs
@@ -0,0 +1,160 @@
+use std::borrow::Cow;
+
+/// Server name (hostname or ip address) representation in bytes-based struct
+/// for searching hashmap or key list by exact or longest-prefix matching
+#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)]
+pub struct ServerName {
+  inner: Vec<u8>, // lowercase ascii bytes
+}
+impl From<&str> for ServerName {
+  fn from(s: &str) -> Self {
+    let name = s.bytes().collect::<Vec<u8>>().to_ascii_lowercase();
+    Self { inner: name }
+  }
+}
+impl From<&[u8]> for ServerName {
+  fn from(b: &[u8]) -> Self {
+    Self {
+      inner: b.to_ascii_lowercase(),
+    }
+  }
+}
+impl TryInto<String> for &ServerName {
+  type Error = anyhow::Error;
+  fn try_into(self) -> Result<String, Self::Error> {
+    let s = std::str::from_utf8(&self.inner)?;
+    Ok(s.to_string())
+  }
+}
+impl AsRef<[u8]> for ServerName {
+  fn as_ref(&self) -> &[u8] {
+    self.inner.as_ref()
+  }
+}
+
+/// Path name, like "/path/ok", represented in bytes-based struct
+/// for searching hashmap or key list by exact or longest-prefix matching
+#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)]
+pub struct PathName {
+  inner: Vec<u8>, // lowercase ascii bytes
+}
+impl From<&str> for PathName {
+  fn from(s: &str) -> Self {
+    let name = s.bytes().collect::<Vec<u8>>().to_ascii_lowercase();
+    Self { inner: name }
+  }
+}
+impl From<&[u8]> for PathName {
+  fn from(b: &[u8]) -> Self {
+    Self {
+      inner: b.to_ascii_lowercase(),
+    }
+  }
+}
+impl TryInto<String> for &PathName {
+  type Error = anyhow::Error;
+  fn try_into(self) -> Result<String, Self::Error> {
+    let s = std::str::from_utf8(&self.inner)?;
+    Ok(s.to_string())
+  }
+}
+impl AsRef<[u8]> for PathName {
+  fn as_ref(&self) -> &[u8] {
+    self.inner.as_ref()
+  }
+}
+impl PathName {
+  pub fn len(&self) -> usize {
+    self.inner.len()
+  }
+  pub fn is_empty(&self) -> bool {
+    self.inner.len() == 0
+  }
+  pub fn get<I>(&self, index: I) -> Option<&I::Output>
+  where
+    I: std::slice::SliceIndex<[u8]>,
+  {
+    self.inner.get(index)
+  }
+  pub fn starts_with(&self, needle: &Self) -> bool {
+    self.inner.starts_with(&needle.inner)
+  }
+}
+
+/// Trait to express names in ascii-lowercased bytes
+pub trait ByteName {
+  type OutputServer: Send + Sync + 'static;
+  type OutputPath;
+  fn to_server_name(self) -> Self::OutputServer;
+  fn to_path_name(self) -> Self::OutputPath;
+}
+
+impl<'a, T: Into<Cow<'a, str>>> ByteName for T {
+  type OutputServer = ServerName;
+  type OutputPath = PathName;
+
+  fn to_server_name(self) -> Self::OutputServer {
+    ServerName::from(self.into().as_ref())
+  }
+
+  fn to_path_name(self) -> Self::OutputPath {
+    PathName::from(self.into().as_ref())
+  }
+}
+
+#[cfg(test)]
+mod tests {
+  use super::*;
+  #[test]
+  fn bytes_name_str_works() {
+    let s = "OK_string";
+    let bn = s.to_path_name();
+    let bn_lc = s.to_server_name();
+
+    assert_eq!("ok_string".as_bytes(), bn.as_ref());
+    assert_eq!("ok_string".as_bytes(), bn_lc.as_ref());
+  }
+
+  #[test]
+  fn from_works() {
+    let s = "OK_string".to_server_name();
+    let m = ServerName::from("OK_strinG".as_bytes());
+    assert_eq!(s, m);
+    assert_eq!(s.as_ref(), "ok_string".as_bytes());
+    assert_eq!(m.as_ref(), "ok_string".as_bytes());
+  }
+
+  #[test]
+  fn get_works() {
+    let s = "OK_str".to_path_name();
+    let i = s.get(0);
+    assert_eq!(Some(&"o".as_bytes()[0]), i);
+    let i = s.get(1);
+    assert_eq!(Some(&"k".as_bytes()[0]), i);
+    let i = s.get(2);
+    assert_eq!(Some(&"_".as_bytes()[0]), i);
+    let i =
s.get(3); + assert_eq!(Some(&"s".as_bytes()[0]), i); + let i = s.get(4); + assert_eq!(Some(&"t".as_bytes()[0]), i); + let i = s.get(5); + assert_eq!(Some(&"r".as_bytes()[0]), i); + let i = s.get(6); + assert_eq!(None, i); + } + + #[test] + fn start_with_works() { + let s = "OK_str".to_path_name(); + let correct = "OK".to_path_name(); + let incorrect = "KO".to_path_name(); + assert!(s.starts_with(&correct)); + assert!(!s.starts_with(&incorrect)); + } + + #[test] + fn as_ref_works() { + let s = "OK_str".to_path_name(); + assert_eq!(s.as_ref(), "ok_str".as_bytes()); + } +} diff --git a/rpxy-lib/src/proxy/mod.rs b/rpxy-lib/src/proxy/mod.rs index 0551b62..2cc9b75 100644 --- a/rpxy-lib/src/proxy/mod.rs +++ b/rpxy-lib/src/proxy/mod.rs @@ -1,13 +1,42 @@ -mod crypto_service; -mod proxy_client_cert; -#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] -mod proxy_h3; mod proxy_main; -#[cfg(feature = "http3-quinn")] -mod proxy_quic_quinn; -#[cfg(feature = "http3-s2n")] -mod proxy_quic_s2n; -mod proxy_tls; mod socket; -pub use proxy_main::{Proxy, ProxyBuilder, ProxyBuilderError}; +#[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] +mod proxy_h3; +#[cfg(feature = "http3-quinn")] +mod proxy_quic_quinn; +#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] +mod proxy_quic_s2n; + +use crate::{ + globals::Globals, + hyper_ext::rt::{LocalExecutor, TokioTimer}, +}; +use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; +use std::sync::Arc; + +pub(crate) use proxy_main::Proxy; + +/// build connection builder shared with proxy instances +pub(crate) fn connection_builder(globals: &Arc) -> Arc> { + let executor = LocalExecutor::new(globals.runtime_handle.clone()); + let mut http_server = server::conn::auto::Builder::new(executor); + http_server + .http1() + .keep_alive(globals.proxy_config.keepalive) + .header_read_timeout(globals.proxy_config.proxy_idle_timeout) + .timer(TokioTimer) + .pipeline_flush(true); + http_server + .http2() + .max_concurrent_streams(globals.proxy_config.max_concurrent_streams); + + if globals.proxy_config.keepalive { + http_server + .http2() + .keep_alive_interval(Some(globals.proxy_config.proxy_idle_timeout)) + .keep_alive_timeout(globals.proxy_config.proxy_idle_timeout + std::time::Duration::from_secs(1)) + .timer(TokioTimer); + } + Arc::new(http_server) +} diff --git a/rpxy-lib/src/proxy/proxy_client_cert.rs b/rpxy-lib/src/proxy/proxy_client_cert.rs deleted file mode 100644 index dfba4ce..0000000 --- a/rpxy-lib/src/proxy/proxy_client_cert.rs +++ /dev/null @@ -1,47 +0,0 @@ -use crate::{error::*, log::*}; -use rustc_hash::FxHashSet as HashSet; -use rustls::Certificate; -use x509_parser::extensions::ParsedExtension; -use x509_parser::prelude::*; - -#[allow(dead_code)] -// TODO: consider move this function to the layer of handle_request (L7) to return 403 -pub(super) fn check_client_authentication( - client_certs: Option<&[Certificate]>, - client_ca_keyids_set_for_sni: Option<&HashSet>>, -) -> std::result::Result<(), ClientCertsError> { - let Some(client_ca_keyids_set) = client_ca_keyids_set_for_sni else { - // No client cert settings for given server name - return Ok(()); - }; - - let Some(client_certs) = client_certs else { - error!("Client certificate is needed for given server name"); - return Err(ClientCertsError::ClientCertRequired( - "Client certificate is needed for given server name".to_string(), - )); - }; - debug!("Incoming TLS client is (temporarily) authenticated via client cert"); - - // Check client certificate key ids - let 
mut client_certs_parsed_iter = client_certs.iter().filter_map(|d| parse_x509_certificate(&d.0).ok()); - let match_server_crypto_and_client_cert = client_certs_parsed_iter.any(|c| { - let mut filtered = c.1.iter_extensions().filter_map(|e| { - if let ParsedExtension::AuthorityKeyIdentifier(key_id) = e.parsed_extension() { - key_id.key_identifier.as_ref() - } else { - None - } - }); - filtered.any(|id| client_ca_keyids_set.contains(id.0)) - }); - - if !match_server_crypto_and_client_cert { - error!("Inconsistent client certificate was provided for SNI"); - return Err(ClientCertsError::InconsistentClientCert( - "Inconsistent client certificate was provided for SNI".to_string(), - )); - } - - Ok(()) -} diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index fd07521..1e0f24f 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -1,25 +1,33 @@ -use super::Proxy; -use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; +use super::proxy_main::Proxy; +use crate::{ + crypto::CryptoSource, + error::*, + hyper_ext::body::{IncomingLike, RequestBody}, + log::*, + name_exp::ServerName, +}; use bytes::{Buf, Bytes}; +use http::{Request, Response}; +use http_body_util::BodyExt; +use hyper_util::client::legacy::connect::Connect; +use std::net::SocketAddr; + #[cfg(feature = "http3-quinn")] use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; -use hyper::{client::connect::Connect, Body, Request, Response}; -#[cfg(feature = "http3-s2n")] +#[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; -use std::net::SocketAddr; -use tokio::time::{timeout, Duration}; -impl Proxy +impl Proxy where T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { - pub(super) async fn connection_serve_h3( + pub(super) async fn h3_serve_connection( &self, quic_connection: C, - tls_server_name: ServerNameBytesExp, + tls_server_name: ServerName, client_addr: SocketAddr, - ) -> Result<()> + ) -> RpxyResult<()> where C: ConnectionQuic, >::BidiStream: BidiStream + Send + 'static, @@ -28,9 +36,11 @@ where { let mut h3_conn = h3::server::Connection::<_, Bytes>::new(quic_connection).await?; info!( - "QUIC/HTTP3 connection established from {:?} {:?}", - client_addr, tls_server_name + "QUIC/HTTP3 connection established from {:?} {}", + client_addr, + <&ServerName as TryInto>::try_into(&tls_server_name).unwrap_or_default() ); + // TODO: Is here enough to fetch server_name from NewConnection? 
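+  // (For now, the SNI extracted at QUIC accept time is simply passed in as `tls_server_name`.)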
// to avoid deep nested call from listener_service_h3 loop { @@ -60,13 +70,13 @@ where let self_inner = self.clone(); let tls_server_name_inner = tls_server_name.clone(); self.globals.runtime_handle.spawn(async move { - if let Err(e) = timeout( - self_inner.globals.proxy_config.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2 - self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), - ) - .await - { - error!("HTTP/3 failed to process stream: {}", e); + let fut = self_inner.h3_serve_stream(req, stream, client_addr, tls_server_name_inner); + if let Some(connection_handling_timeout) = self_inner.globals.proxy_config.connection_handling_timeout { + if let Err(e) = tokio::time::timeout(connection_handling_timeout, fut).await { + warn!("HTTP/3 error on serve stream: {}", e); + }; + } else if let Err(e) = fut.await { + warn!("HTTP/3 error on serve stream: {}", e); } request_count.decrement(); debug!("Request processed: current # {}", request_count.current()); @@ -78,13 +88,17 @@ where Ok(()) } - async fn stream_serve_h3( + /// Serves a request stream from a client + /// Body in hyper-0.14 was changed to Incoming in hyper-1.0, and it is not accessible from outside. + /// Thus, we needed to implement IncomingLike trait using channel. Also, the backend handler must feed the body in the form of + /// Either as body. + async fn h3_serve_stream( &self, req: Request<()>, stream: RequestStream, client_addr: SocketAddr, - tls_server_name: ServerNameBytesExp, - ) -> Result<()> + tls_server_name: ServerName, + ) -> RpxyResult<()> where S: BidiStream + Send + 'static, >::RecvStream: Send, @@ -94,7 +108,7 @@ where let (mut send_stream, mut recv_stream) = stream.split(); // generate streamed body with trailers using channel - let (body_sender, req_body) = Body::channel(); + let (body_sender, req_body) = IncomingLike::channel(); // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1 // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn. 
@@ -107,10 +121,10 @@ where size += body.remaining(); if size > max_body_size { error!( - "Exceeds max request body size for HTTP/3: received {}, maximum_allowd {}", + "Exceeds max request body size for HTTP/3: received {}, maximum_allowed {}", size, max_body_size ); - return Err(RpxyError::Proxy("Exceeds max request body size for HTTP/3".to_string())); + return Err(RpxyError::H3TooLargeBody); } // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes sender.send_data(body.copy_to_bytes(body.remaining())).await?; @@ -122,13 +136,12 @@ where debug!("HTTP/3 incoming request trailers"); sender.send_trailers(trailers.unwrap()).await?; } - Ok(()) + Ok(()) as RpxyResult<()> }); - let new_req: Request = Request::from_parts(req_parts, req_body); + let new_req: Request = Request::from_parts(req_parts, RequestBody::IncomingLike(req_body)); let res = self - .msg_handler - .clone() + .message_handler .handle_request( new_req, client_addr, @@ -138,21 +151,33 @@ where ) .await?; - let (new_res_parts, new_body) = res.into_parts(); + let (new_res_parts, mut new_body) = res.into_parts(); let new_res = Response::from_parts(new_res_parts, ()); match send_stream.send_response(new_res).await { Ok(_) => { debug!("HTTP/3 response to connection successful"); - // aggregate body without copying - let mut body_data = hyper::body::aggregate(new_body).await?; + // on-demand body streaming to downstream without expanding the object onto memory. + loop { + let frame = match new_body.frame().await { + Some(frame) => frame, + None => { + debug!("Response body finished"); + break; + } + } + .map_err(|e| RpxyError::HyperBodyManipulationError(e.to_string()))?; - // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes - send_stream - .send_data(body_data.copy_to_bytes(body_data.remaining())) - .await?; - - // TODO: needs handling trailer? should be included in body from handler. 
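+          // Each frame from the backend body is either data or trailers; both are
+          // forwarded to the h3 send stream as they arrive, so the response body is
+          // never buffered in full.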
+ if frame.is_data() { + let data = frame.into_data().unwrap_or_default(); + // debug!("Write data to HTTP/3 stream"); + send_stream.send_data(data).await?; + } else if frame.is_trailers() { + let trailers = frame.into_trailers().unwrap_or_default(); + // debug!("Write trailer to HTTP/3 stream"); + send_stream.send_trailers(trailers).await?; + } + } } Err(err) => { error!("Unable to send response to connection peer: {:?}", err); diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index bd52ea9..67eeb30 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -1,78 +1,81 @@ use super::socket::bind_tcp_socket; use crate::{ - certs::CryptoSource, error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp, + constants::TLS_HANDSHAKE_TIMEOUT_SEC, + crypto::{CryptoSource, ServerCrypto, SniServerCryptoMap}, + error::*, + globals::Globals, + hyper_ext::{ + body::{RequestBody, ResponseBody}, + rt::LocalExecutor, + }, + log::*, + message_handler::HttpMessageHandler, + name_exp::ServerName, }; -use derive_builder::{self, Builder}; -use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request}; -use std::{net::SocketAddr, sync::Arc}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - runtime::Handle, - sync::Notify, - time::{timeout, Duration}, +use futures::{select, FutureExt}; +use http::{Request, Response}; +use hyper::{ + body::Incoming, + rt::{Read, Write}, + service::service_fn, }; +use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; +use std::{net::SocketAddr, sync::Arc, time::Duration}; +use tokio::time::timeout; + +/// Wrapper function to handle request for HTTP/1.1 and HTTP/2 +/// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler +async fn serve_request( + req: Request, + handler: Arc>, + client_addr: SocketAddr, + listen_addr: SocketAddr, + tls_enabled: bool, + tls_server_name: Option, +) -> RpxyResult> +where + T: Send + Sync + Connect + Clone, + U: CryptoSource + Clone, +{ + handler + .handle_request( + req.map(RequestBody::Incoming), + client_addr, + listen_addr, + tls_enabled, + tls_server_name, + ) + .await +} #[derive(Clone)] -pub struct LocalExecutor { - runtime_handle: Handle, -} - -impl LocalExecutor { - fn new(runtime_handle: Handle) -> Self { - LocalExecutor { runtime_handle } - } -} - -impl hyper::rt::Executor for LocalExecutor +/// Proxy main object responsible to serve requests received from clients at the given socket address. 
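+/// One instance is created per listen socket address; TLS and cleartext listeners
+/// share this type and branch in `start()`.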
+pub(crate) struct Proxy where - F: std::future::Future + Send + 'static, - F::Output: Send, -{ - fn execute(&self, fut: F) { - self.runtime_handle.spawn(fut); - } -} - -#[derive(Clone, Builder)] -pub struct Proxy -where - T: Connect + Clone + Sync + Send + 'static, + T: Send + Sync + Connect + Clone + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { + /// global context shared among async tasks + pub globals: Arc, + /// listen socket address pub listening_on: SocketAddr, - pub tls_enabled: bool, // TCP待受がTLSかどうか - pub msg_handler: Arc>, - pub globals: Arc>, + /// whether TLS is enabled or not + pub tls_enabled: bool, + /// hyper connection builder serving http request + pub connection_builder: Arc>, + /// message handler serving incoming http request + pub message_handler: Arc>, } -impl Proxy +impl Proxy where - T: Connect + Clone + Sync + Send + 'static, - U: CryptoSource + Clone + Sync + Send, + T: Send + Sync + Connect + Clone + 'static, + U: CryptoSource + Clone + Sync + Send + 'static, { - /// Wrapper function to handle request - async fn serve( - handler: Arc>, - req: Request, - client_addr: SocketAddr, - listen_addr: SocketAddr, - tls_enabled: bool, - tls_server_name: Option, - ) -> Result> { - handler - .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) - .await - } - /// Serves requests from clients - pub(super) fn client_serve( - self, - stream: I, - server: Http, - peer_addr: SocketAddr, - tls_server_name: Option, - ) where - I: AsyncRead + AsyncWrite + Send + Unpin + 'static, + fn serve_connection(&self, stream: I, peer_addr: SocketAddr, tls_server_name: Option) + where + I: Read + Write + Send + Unpin + 'static, { let request_count = self.globals.request_count.clone(); if request_count.increment() > self.globals.proxy_config.max_clients { @@ -81,27 +84,32 @@ where } debug!("Request incoming: current # {}", request_count.current()); + let server_clone = self.connection_builder.clone(); + let message_handler_clone = self.message_handler.clone(); + let tls_enabled = self.tls_enabled; + let listening_on = self.listening_on; + let handling_timeout = self.globals.proxy_config.connection_handling_timeout; + self.globals.runtime_handle.clone().spawn(async move { - timeout( - self.globals.proxy_config.proxy_timeout + Duration::from_secs(1), - server - .serve_connection( - stream, - service_fn(move |req: Request| { - Self::serve( - self.msg_handler.clone(), - req, - peer_addr, - self.listening_on, - self.tls_enabled, - tls_server_name.clone(), - ) - }), + let fut = server_clone.serve_connection_with_upgrades( + stream, + service_fn(move |req: Request| { + serve_request( + req, + message_handler_clone.clone(), + peer_addr, + listening_on, + tls_enabled, + tls_server_name.clone(), ) - .with_upgrades(), - ) - .await - .ok(); + }), + ); + + if let Some(handling_timeout) = handling_timeout { + timeout(handling_timeout, fut).await.ok(); + } else { + fut.await.ok(); + } request_count.decrement(); debug!("Request processed: current # {}", request_count.current()); @@ -109,47 +117,149 @@ where } /// Start without TLS (HTTP cleartext) - async fn start_without_tls(self, server: Http) -> Result<()> { + async fn start_without_tls(&self) -> RpxyResult<()> { let listener_service = async { let tcp_socket = bind_tcp_socket(&self.listening_on)?; let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; info!("Start TCP proxy serving with HTTP request for configured host names"); - while let Ok((stream, _client_addr)) = 
tcp_listener.accept().await { - self.clone().client_serve(stream, server.clone(), _client_addr, None); + while let Ok((stream, client_addr)) = tcp_listener.accept().await { + self.serve_connection(TokioIo::new(stream), client_addr, None); } - Ok(()) as Result<()> + Ok(()) as RpxyResult<()> }; listener_service.await?; Ok(()) } - /// Entrypoint for HTTP/1.1 and HTTP/2 servers - pub async fn start(self, term_notify: Option>) -> Result<()> { - let mut server = Http::new(); - server.http1_keep_alive(self.globals.proxy_config.keepalive); - server.http2_max_concurrent_streams(self.globals.proxy_config.max_concurrent_streams); - server.pipeline_flush(true); - let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); - let server = server.with_executor(executor); + /// Start with TLS (HTTPS) + pub(super) async fn start_with_tls(&self) -> RpxyResult<()> { + #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] + { + self.tls_listener_service().await?; + error!("TCP proxy service for TLS exited"); + Ok(()) + } + #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] + { + if self.globals.proxy_config.http3 { + select! { + _ = self.tls_listener_service().fuse() => { + error!("TCP proxy service for TLS exited"); + }, + _ = self.h3_listener_service().fuse() => { + error!("UDP proxy service for QUIC exited"); + } + }; + Ok(()) + } else { + self.tls_listener_service().await?; + error!("TCP proxy service for TLS exited"); + Ok(()) + } + } + } - let listening_on = self.listening_on; + // TCP Listener Service, i.e., http/2 and http/1.1 + async fn tls_listener_service(&self) -> RpxyResult<()> { + let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { + return Err(RpxyError::NoCertificateReloader); + }; + let tcp_socket = bind_tcp_socket(&self.listening_on)?; + let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; + info!("Start TCP proxy serving with HTTPS request for configured host names"); + let mut server_crypto_map: Option> = None; + loop { + select! { + tcp_cnx = tcp_listener.accept().fuse() => { + if tcp_cnx.is_err() || server_crypto_map.is_none() { + continue; + } + let (raw_stream, client_addr) = tcp_cnx.unwrap(); + let sc_map_inner = server_crypto_map.clone(); + let self_inner = self.clone(); + + // spawns async handshake to avoid blocking thread by sequential handshake. + let handshake_fut = async move { + let acceptor = tokio_rustls::LazyConfigAcceptor::new(tokio_rustls::rustls::server::Acceptor::default(), raw_stream).await; + if let Err(e) = acceptor { + return Err(RpxyError::FailedToTlsHandshake(e.to_string())); + } + let start = acceptor.unwrap(); + let client_hello = start.client_hello(); + let sni = client_hello.server_name(); + debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", sni.unwrap_or("None")); + let server_name = sni.map(ServerName::from); + if server_name.is_none(){ + return Err(RpxyError::NoServerNameInClientHello); + } + let server_crypto = sc_map_inner.as_ref().unwrap().get(server_name.as_ref().unwrap()); + if server_crypto.is_none() { + return Err(RpxyError::NoTlsServingApp(server_name.as_ref().unwrap().try_into().unwrap_or_default())); + } + let stream = match start.into_stream(server_crypto.unwrap().clone()).await { + Ok(s) => TokioIo::new(s), + Err(e) => { + return Err(RpxyError::FailedToTlsHandshake(e.to_string())); + } + }; + Ok((stream, client_addr, server_name)) + }; + + self.globals.runtime_handle.spawn( async move { + // timeout is introduced to avoid get stuck here. 
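+            // A stalled ClientHello would otherwise pin this task indefinitely, since
+            // the lazy acceptor completes only after the full hello has been read.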
+ let Ok(v) = timeout( + Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), + handshake_fut + ).await else { + error!("Timeout to handshake TLS"); + return; + }; + match v { + Ok((stream, client_addr, server_name)) => { + self_inner.serve_connection(stream, client_addr, server_name); + } + Err(e) => { + error!("{}", e); + } + } + }); + } + _ = server_crypto_rx.changed().fuse() => { + if server_crypto_rx.borrow().is_none() { + error!("Reloader is broken"); + break; + } + let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); + let Some(server_crypto): Option> = (&cert_keys_map).try_into().ok() else { + error!("Failed to update server crypto"); + break; + }; + server_crypto_map = Some(server_crypto.inner_local_map.clone()); + } + } + } + Ok(()) + } + + /// Entrypoint for HTTP/1.1, 2 and 3 servers + pub async fn start(&self) -> RpxyResult<()> { let proxy_service = async { if self.tls_enabled { - self.start_with_tls(server).await + self.start_with_tls().await } else { - self.start_without_tls(server).await + self.start_without_tls().await } }; - match term_notify { + match &self.globals.term_notify { Some(term) => { - tokio::select! { - _ = proxy_service => { + select! { + _ = proxy_service.fuse() => { warn!("Proxy service got down"); } - _ = term.notified() => { - info!("Proxy service listening on {} receives term signal", listening_on); + _ = term.notified().fuse() => { + info!("Proxy service listening on {} receives term signal", self.listening_on); } } } @@ -159,8 +269,6 @@ where } } - // proxy_service.await?; - Ok(()) } } diff --git a/rpxy-lib/src/proxy/proxy_quic_quinn.rs b/rpxy-lib/src/proxy/proxy_quic_quinn.rs index fb08420..9c4bf4e 100644 --- a/rpxy-lib/src/proxy/proxy_quic_quinn.rs +++ b/rpxy-lib/src/proxy/proxy_quic_quinn.rs @@ -1,30 +1,32 @@ +use super::proxy_main::Proxy; use super::socket::bind_udp_socket; -use super::{ - crypto_service::{ServerCrypto, ServerCryptoBase}, - proxy_main::Proxy, +use crate::{ + crypto::{CryptoSource, ServerCrypto}, + error::*, + log::*, + name_exp::ByteName, }; -use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; -use hot_reload::ReloaderReceiver; -use hyper::client::connect::Connect; +use hyper_util::client::legacy::connect::Connect; use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; use rustls::ServerConfig; use std::sync::Arc; -impl Proxy +impl Proxy where - T: Connect + Clone + Sync + Send + 'static, + T: Send + Sync + Connect + Clone + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { - pub(super) async fn listener_service_h3( - &self, - mut server_crypto_rx: ReloaderReceiver, - ) -> Result<()> { + pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { + let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { + return Err(RpxyError::NoCertificateReloader); + }; info!("Start UDP proxy serving with HTTP/3 request for configured host names [quinn]"); // first set as null config server let rustls_server_config = ServerConfig::builder() .with_safe_default_cipher_suites() .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13])? + .with_protocol_versions(&[&rustls::version::TLS13]) + .map_err(|e| RpxyError::QuinnInvalidTlsProtocolVersion(e.to_string()))? 
.with_no_client_auth() .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new())); @@ -90,11 +92,11 @@ where }, Err(e) => { warn!("QUIC accepting connection failed: {:?}", e); - return Err(RpxyError::QuicConn(e)); + return Err(RpxyError::QuinnConnectionFailed(e)); } }; // Timeout is based on underlying quic - if let Err(e) = self_clone.connection_serve_h3(quic_connection, new_server_name.to_server_name_vec(), client_addr).await { + if let Err(e) = self_clone.h3_serve_connection(quic_connection, new_server_name.to_server_name(), client_addr).await { warn!("QUIC or HTTP/3 connection failed: {}", e); }; Ok(()) @@ -119,6 +121,6 @@ where } } endpoint.wait_idle().await; - Ok(()) as Result<()> + Ok(()) as RpxyResult<()> } } diff --git a/rpxy-lib/src/proxy/proxy_quic_s2n.rs b/rpxy-lib/src/proxy/proxy_quic_s2n.rs index e0c41a5..13a8802 100644 --- a/rpxy-lib/src/proxy/proxy_quic_s2n.rs +++ b/rpxy-lib/src/proxy/proxy_quic_s2n.rs @@ -1,22 +1,27 @@ -use super::{ - crypto_service::{ServerCrypto, ServerCryptoBase}, - proxy_main::Proxy, +use super::proxy_main::Proxy; +use crate::{ + crypto::CryptoSource, + crypto::{ServerCrypto, ServerCryptoBase}, + error::*, + log::*, + name_exp::ByteName, }; -use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; +use anyhow::anyhow; use hot_reload::ReloaderReceiver; -use hyper::client::connect::Connect; +use hyper_util::client::legacy::connect::Connect; use s2n_quic::provider; use std::sync::Arc; -impl Proxy +impl Proxy where T: Connect + Clone + Sync + Send + 'static, U: CryptoSource + Clone + Sync + Send + 'static, { - pub(super) async fn listener_service_h3( - &self, - mut server_crypto_rx: ReloaderReceiver, - ) -> Result<()> { + /// Start UDP proxy serving with HTTP/3 request for configured host names + pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { + let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { + return Err(RpxyError::NoCertificateReloader); + }; info!("Start UDP proxy serving with HTTP/3 request for configured host names [s2n-quic]"); // initially wait for receipt @@ -29,7 +34,7 @@ where // event loop loop { tokio::select! 
{ - v = self.serve_connection(&server_crypto) => { + v = self.h3_listener_service_inner(&server_crypto) => { if let Err(e) = v { error!("Quic connection event loop illegally shutdown [s2n-quic] {e}"); break; @@ -51,20 +56,25 @@ where Ok(()) } - fn receive_server_crypto(&self, server_crypto_rx: ReloaderReceiver) -> Result> { + /// Receive server crypto from reloader + fn receive_server_crypto( + &self, + server_crypto_rx: ReloaderReceiver, + ) -> RpxyResult> { let cert_keys_map = server_crypto_rx.borrow().clone().ok_or_else(|| { error!("Reloader is broken"); - RpxyError::Other(anyhow!("Reloader is broken")) + RpxyError::CertificateReloadError(anyhow!("Reloader is broken").into()) })?; let server_crypto: Option> = (&cert_keys_map).try_into().ok(); server_crypto.ok_or_else(|| { error!("Failed to update server crypto for h3 [s2n-quic]"); - RpxyError::Other(anyhow!("Failed to update server crypto for h3 [s2n-quic]")) + RpxyError::FailedToUpdateServerCrypto("Failed to update server crypto for h3 [s2n-quic]".to_string()) }) } - async fn serve_connection(&self, server_crypto: &Option>) -> Result<()> { + /// Event loop for UDP proxy serving with HTTP/3 request for configured host names + async fn h3_listener_service_inner(&self, server_crypto: &Option>) -> RpxyResult<()> { // setup UDP socket let io = provider::io::tokio::Builder::default() .with_receive_address(self.listening_on)? @@ -73,18 +83,13 @@ where // setup limits let mut limits = provider::limits::Limits::default() - .with_max_open_local_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64) - .map_err(|e| anyhow!(e))? - .with_max_open_remote_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64) - .map_err(|e| anyhow!(e))? - .with_max_open_local_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64) - .map_err(|e| anyhow!(e))? - .with_max_open_remote_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64) - .map_err(|e| anyhow!(e))? - .with_max_active_connection_ids(self.globals.proxy_config.h3_max_concurrent_connections as u64) - .map_err(|e| anyhow!(e))?; + .with_max_open_local_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64)? + .with_max_open_remote_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64)? + .with_max_open_local_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64)? + .with_max_open_remote_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64)? + .with_max_active_connection_ids(self.globals.proxy_config.h3_max_concurrent_connections as u64)?; limits = if let Some(v) = self.globals.proxy_config.h3_max_idle_timeout { - limits.with_max_idle_timeout(v).map_err(|e| anyhow!(e))? + limits.with_max_idle_timeout(v)? } else { limits }; @@ -92,27 +97,25 @@ where // setup tls let Some(server_crypto) = server_crypto else { warn!("No server crypto is given [s2n-quic]"); - return Err(RpxyError::Other(anyhow!("No server crypto is given [s2n-quic]"))); + return Err(RpxyError::NoServerCrypto( + "No server crypto is given [s2n-quic]".to_string(), + )); }; let tls = server_crypto.inner_global_no_client_auth.clone(); let mut server = s2n_quic::Server::builder() - .with_tls(tls) - .map_err(|e| anyhow::anyhow!(e))? - .with_io(io) - .map_err(|e| anyhow!(e))? - .with_limits(limits) - .map_err(|e| anyhow!(e))? - .start() - .map_err(|e| anyhow!(e))?; + .with_tls(tls)? + .with_io(io)? 
+ .with_limits(limits)? + .start()?; // quic event loop. this immediately cancels when crypto is updated by tokio::select! while let Some(new_conn) = server.accept().await { debug!("New QUIC connection established"); let Ok(Some(new_server_name)) = new_conn.server_name() else { - warn!("HTTP/3 no SNI is given"); - continue; - }; + warn!("HTTP/3 no SNI is given"); + continue; + }; debug!("HTTP/3 connection incoming (SNI {:?})", new_server_name); let self_clone = self.clone(); @@ -121,12 +124,12 @@ where let quic_connection = s2n_quic_h3::Connection::new(new_conn); // Timeout is based on underlying quic if let Err(e) = self_clone - .connection_serve_h3(quic_connection, new_server_name.to_server_name_vec(), client_addr) + .h3_serve_connection(quic_connection, new_server_name.to_server_name(), client_addr) .await { warn!("QUIC or HTTP/3 connection failed: {}", e); }; - Ok(()) as Result<()> + Ok(()) as RpxyResult<()> }); } diff --git a/rpxy-lib/src/proxy/proxy_tls.rs b/rpxy-lib/src/proxy/proxy_tls.rs deleted file mode 100644 index 7c5d601..0000000 --- a/rpxy-lib/src/proxy/proxy_tls.rs +++ /dev/null @@ -1,163 +0,0 @@ -use super::{ - crypto_service::{CryptoReloader, ServerCrypto, ServerCryptoBase, SniServerCryptoMap}, - proxy_main::{LocalExecutor, Proxy}, - socket::bind_tcp_socket, -}; -use crate::{certs::CryptoSource, constants::*, error::*, log::*, utils::BytesName}; -use hot_reload::{ReloaderReceiver, ReloaderService}; -use hyper::{client::connect::Connect, server::conn::Http}; -use std::sync::Arc; -use tokio::time::{timeout, Duration}; - -impl Proxy -where - T: Connect + Clone + Sync + Send + 'static, - U: CryptoSource + Clone + Sync + Send + 'static, -{ - // TCP Listener Service, i.e., http/2 and http/1.1 - async fn listener_service( - &self, - server: Http, - mut server_crypto_rx: ReloaderReceiver, - ) -> Result<()> { - let tcp_socket = bind_tcp_socket(&self.listening_on)?; - let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; - info!("Start TCP proxy serving with HTTPS request for configured host names"); - - let mut server_crypto_map: Option> = None; - loop { - tokio::select! { - tcp_cnx = tcp_listener.accept() => { - if tcp_cnx.is_err() || server_crypto_map.is_none() { - continue; - } - let (raw_stream, client_addr) = tcp_cnx.unwrap(); - let sc_map_inner = server_crypto_map.clone(); - let server_clone = server.clone(); - let self_inner = self.clone(); - - // spawns async handshake to avoid blocking thread by sequential handshake. 
- let handshake_fut = async move { - let acceptor = tokio_rustls::LazyConfigAcceptor::new(tokio_rustls::rustls::server::Acceptor::default(), raw_stream).await; - if let Err(e) = acceptor { - return Err(RpxyError::Proxy(format!("Failed to handshake TLS: {e}"))); - } - let start = acceptor.unwrap(); - let client_hello = start.client_hello(); - let server_name = client_hello.server_name(); - debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", server_name); - let server_name_in_bytes = server_name.map_or_else(|| None, |v| Some(v.to_server_name_vec())); - if server_name_in_bytes.is_none(){ - return Err(RpxyError::Proxy("No SNI is given".to_string())); - } - let server_crypto = sc_map_inner.as_ref().unwrap().get(server_name_in_bytes.as_ref().unwrap()); - if server_crypto.is_none() { - return Err(RpxyError::Proxy(format!("No TLS serving app for {:?}", server_name.unwrap()))); - } - let stream = match start.into_stream(server_crypto.unwrap().clone()).await { - Ok(s) => s, - Err(e) => { - return Err(RpxyError::Proxy(format!("Failed to handshake TLS: {e}"))); - } - }; - self_inner.client_serve(stream, server_clone, client_addr, server_name_in_bytes); - Ok(()) - }; - - self.globals.runtime_handle.spawn( async move { - // timeout is introduced to avoid get stuck here. - match timeout( - Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), - handshake_fut - ).await { - Ok(a) => { - if let Err(e) = a { - error!("{}", e); - } - }, - Err(e) => { - error!("Timeout to handshake TLS: {}", e); - } - }; - }); - } - _ = server_crypto_rx.changed() => { - if server_crypto_rx.borrow().is_none() { - error!("Reloader is broken"); - break; - } - let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); - let Some(server_crypto): Option> = (&cert_keys_map).try_into().ok() else { - error!("Failed to update server crypto"); - break; - }; - server_crypto_map = Some(server_crypto.inner_local_map.clone()); - } - else => break - } - } - Ok(()) as Result<()> - } - - pub async fn start_with_tls(self, server: Http) -> Result<()> { - let (cert_reloader_service, cert_reloader_rx) = ReloaderService::, ServerCryptoBase>::new( - &self.globals.clone(), - CERTS_WATCH_DELAY_SECS, - !LOAD_CERTS_ONLY_WHEN_UPDATED, - ) - .await - .map_err(|e| anyhow::anyhow!(e))?; - - #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] - { - tokio::select! { - _= cert_reloader_service.start() => { - error!("Cert service for TLS exited"); - }, - _ = self.listener_service(server, cert_reloader_rx) => { - error!("TCP proxy service for TLS exited"); - }, - else => { - error!("Something went wrong"); - return Ok(()) - } - }; - Ok(()) - } - #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] - { - if self.globals.proxy_config.http3 { - tokio::select! { - _= cert_reloader_service.start() => { - error!("Cert service for TLS exited"); - }, - _ = self.listener_service(server, cert_reloader_rx.clone()) => { - error!("TCP proxy service for TLS exited"); - }, - _= self.listener_service_h3(cert_reloader_rx) => { - error!("UDP proxy service for QUIC exited"); - }, - else => { - error!("Something went wrong"); - return Ok(()) - } - }; - Ok(()) - } else { - tokio::select! 
{
-        _= cert_reloader_service.start() => {
-          error!("Cert service for TLS exited");
-        },
-        _ = self.listener_service(server, cert_reloader_rx) => {
-          error!("TCP proxy service for TLS exited");
-        },
-        else => {
-          error!("Something went wrong");
-          return Ok(())
-        }
-      };
-      Ok(())
-    }
-  }
-}
diff --git a/rpxy-lib/src/proxy/socket.rs b/rpxy-lib/src/proxy/socket.rs
index 9e4c8f9..322b42b 100644
--- a/rpxy-lib/src/proxy/socket.rs
+++ b/rpxy-lib/src/proxy/socket.rs
@@ -8,7 +8,7 @@ use tokio::net::TcpSocket;
 /// Bind TCP socket to the given `SocketAddr`, and returns the TCP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options.
 /// This option is required to re-bind the socket address when the proxy instance is reconstructed.
-pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> Result<TcpSocket> {
+pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> RpxyResult<TcpSocket> {
   let tcp_socket = if listening_on.is_ipv6() {
     TcpSocket::new_v6()
   } else {
@@ -26,7 +26,7 @@ pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> Result<TcpSocket> {
 #[cfg(feature = "http3-quinn")]
 /// Bind UDP socket to the given `SocketAddr`, and returns the UDP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options.
 /// This option is required to re-bind the socket address when the proxy instance is reconstructed.
-pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> Result<UdpSocket> {
+pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> RpxyResult<UdpSocket> {
   let socket = if listening_on.is_ipv6() {
     Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP))
   } else {
@@ -34,6 +34,7 @@ pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> Result<UdpSocket> {
   }?;
   socket.set_reuse_address(true)?; // This isn't necessary?
   socket.set_reuse_port(true)?;
+  socket.set_nonblocking(true)?; // quinn already sets this to true internally, so it isn't strictly necessary here; kept just in case.
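Aside from the diff itself: the hunk above is essentially the whole recipe for a re-bindable UDP listener socket. Below is a minimal standalone sketch of the same steps using the socket2 crate; the helper name `bind_udp` and the simplified error handling are illustrative, not rpxy's actual `bind_udp_socket`.

```rust
use socket2::{Domain, Protocol, Socket, Type};
use std::net::{SocketAddr, UdpSocket};

// Hypothetical helper mirroring the options set in the hunk above.
fn bind_udp(listening_on: SocketAddr) -> std::io::Result<UdpSocket> {
    let domain = if listening_on.is_ipv6() { Domain::IPV6 } else { Domain::IPV4 };
    let socket = Socket::new(domain, Type::DGRAM, Some(Protocol::UDP))?;
    socket.set_reuse_address(true)?; // allow re-binding the address right after a restart
    socket.set_reuse_port(true)?;    // unix-only; lets a reconstructed proxy instance share the port
    socket.set_nonblocking(true)?;   // quinn/tokio expect a non-blocking fd
    socket.bind(&listening_on.into())?;
    Ok(socket.into())
}

fn main() -> std::io::Result<()> {
    // Port 0 just picks a free port for the demo; rpxy binds its configured listener address.
    let socket = bind_udp("127.0.0.1:0".parse().unwrap())?;
    println!("bound UDP socket on {:?}", socket.local_addr()?);
    Ok(())
}
```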
if let Err(e) = socket.bind(&(*listening_on).into()) { error!("Failed to bind UDP socket: {}", e); diff --git a/rpxy-lib/src/utils/bytes_name.rs b/rpxy-lib/src/utils/bytes_name.rs deleted file mode 100644 index 5d2fef5..0000000 --- a/rpxy-lib/src/utils/bytes_name.rs +++ /dev/null @@ -1,123 +0,0 @@ -/// Server name (hostname or ip address) representation in bytes-based struct -/// for searching hashmap or key list by exact or longest-prefix matching -#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] -pub struct ServerNameBytesExp(pub Vec); // lowercase ascii bytes -impl From<&[u8]> for ServerNameBytesExp { - fn from(b: &[u8]) -> Self { - Self(b.to_ascii_lowercase()) - } -} -impl TryInto for &ServerNameBytesExp { - type Error = anyhow::Error; - fn try_into(self) -> Result { - let s = std::str::from_utf8(&self.0)?; - Ok(s.to_string()) - } -} - -/// Path name, like "/path/ok", represented in bytes-based struct -/// for searching hashmap or key list by exact or longest-prefix matching -#[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] -pub struct PathNameBytesExp(pub Vec); // lowercase ascii bytes -impl PathNameBytesExp { - pub fn len(&self) -> usize { - self.0.len() - } - pub fn is_empty(&self) -> bool { - self.0.len() == 0 - } - pub fn get(&self, index: I) -> Option<&I::Output> - where - I: std::slice::SliceIndex<[u8]>, - { - self.0.get(index) - } - pub fn starts_with(&self, needle: &Self) -> bool { - self.0.starts_with(&needle.0) - } -} -impl AsRef<[u8]> for PathNameBytesExp { - fn as_ref(&self) -> &[u8] { - self.0.as_ref() - } -} - -/// Trait to express names in ascii-lowercased bytes -pub trait BytesName { - type OutputSv: Send + Sync + 'static; - type OutputPath; - fn to_server_name_vec(self) -> Self::OutputSv; - fn to_path_name_vec(self) -> Self::OutputPath; -} - -impl<'a, T: Into>> BytesName for T { - type OutputSv = ServerNameBytesExp; - type OutputPath = PathNameBytesExp; - - fn to_server_name_vec(self) -> Self::OutputSv { - let name = self.into().bytes().collect::>().to_ascii_lowercase(); - ServerNameBytesExp(name) - } - - fn to_path_name_vec(self) -> Self::OutputPath { - let name = self.into().bytes().collect::>().to_ascii_lowercase(); - PathNameBytesExp(name) - } -} - -#[cfg(test)] -mod tests { - use super::*; - #[test] - fn bytes_name_str_works() { - let s = "OK_string"; - let bn = s.to_path_name_vec(); - let bn_lc = s.to_server_name_vec(); - - assert_eq!(Vec::from("ok_string".as_bytes()), bn.0); - assert_eq!(Vec::from("ok_string".as_bytes()), bn_lc.0); - } - - #[test] - fn from_works() { - let s = "OK_string".to_server_name_vec(); - let m = ServerNameBytesExp::from("OK_strinG".as_bytes()); - assert_eq!(s, m); - assert_eq!(s.0, "ok_string".as_bytes().to_vec()); - assert_eq!(m.0, "ok_string".as_bytes().to_vec()); - } - - #[test] - fn get_works() { - let s = "OK_str".to_path_name_vec(); - let i = s.get(0); - assert_eq!(Some(&"o".as_bytes()[0]), i); - let i = s.get(1); - assert_eq!(Some(&"k".as_bytes()[0]), i); - let i = s.get(2); - assert_eq!(Some(&"_".as_bytes()[0]), i); - let i = s.get(3); - assert_eq!(Some(&"s".as_bytes()[0]), i); - let i = s.get(4); - assert_eq!(Some(&"t".as_bytes()[0]), i); - let i = s.get(5); - assert_eq!(Some(&"r".as_bytes()[0]), i); - let i = s.get(6); - assert_eq!(None, i); - } - - #[test] - fn start_with_works() { - let s = "OK_str".to_path_name_vec(); - let correct = "OK".to_path_name_vec(); - let incorrect = "KO".to_path_name_vec(); - assert!(s.starts_with(&correct)); - assert!(!s.starts_with(&incorrect)); - } - - #[test] - fn as_ref_works() { - 
let s = "OK_str".to_path_name_vec(); - assert_eq!(s.as_ref(), "ok_str".as_bytes()); - } -} diff --git a/rpxy-lib/src/utils/mod.rs b/rpxy-lib/src/utils/mod.rs deleted file mode 100644 index ed8d4ff..0000000 --- a/rpxy-lib/src/utils/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -mod bytes_name; -mod socket_addr; - -pub use bytes_name::{BytesName, PathNameBytesExp, ServerNameBytesExp}; -pub use socket_addr::ToCanonical; diff --git a/submodules/h3 b/submodules/h3 index a57ed22..c11410c 160000 --- a/submodules/h3 +++ b/submodules/h3 @@ -1 +1 @@ -Subproject commit a57ed224ac5d17a635eb71eb6f83c1196f581a51 +Subproject commit c11410c76e738a62e62e7766b82f814547621f6f diff --git a/submodules/h3-quinn/Cargo.toml b/submodules/h3-quinn/Cargo.toml deleted file mode 100644 index df34822..0000000 --- a/submodules/h3-quinn/Cargo.toml +++ /dev/null @@ -1,24 +0,0 @@ -[package] -name = "h3-quinn" -version = "0.0.1" -rust-version = "1.59" -authors = ["Jean-Christophe BEGUE "] -edition = "2018" -documentation = "https://docs.rs/h3-quinn" -repository = "https://github.com/hyperium/h3" -readme = "../README.md" -description = "QUIC transport implementation based on Quinn." -keywords = ["http3", "quic", "h3"] -categories = ["network-programming", "web-programming"] -license = "MIT" - -[dependencies] -h3 = { version = "0.0.2", path = "../h3/h3" } -bytes = "1" -quinn = { path = "../quinn/quinn/", default-features = false, features = [ - "futures-io", -] } -quinn-proto = { path = "../quinn/quinn-proto/", default-features = false } -tokio-util = { version = "0.7.8" } -futures = { version = "0.3.27" } -tokio = { version = "1.28", features = ["io-util"], default-features = false } diff --git a/submodules/h3-quinn/src/lib.rs b/submodules/h3-quinn/src/lib.rs deleted file mode 100644 index 78696de..0000000 --- a/submodules/h3-quinn/src/lib.rs +++ /dev/null @@ -1,740 +0,0 @@ -//! QUIC Transport implementation with Quinn -//! -//! This module implements QUIC traits with Quinn. -#![deny(missing_docs)] - -use std::{ - convert::TryInto, - fmt::{self, Display}, - future::Future, - pin::Pin, - sync::Arc, - task::{self, Poll}, -}; - -use bytes::{Buf, Bytes, BytesMut}; - -use futures::{ - ready, - stream::{self, BoxStream}, - StreamExt, -}; -use quinn::ReadDatagram; -pub use quinn::{ - self, crypto::Session, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError, -}; - -use h3::{ - ext::Datagram, - quic::{self, Error, StreamId, WriteBuf}, -}; -use tokio_util::sync::ReusableBoxFuture; - -/// A QUIC connection backed by Quinn -/// -/// Implements a [`quic::Connection`] backed by a [`quinn::Connection`]. 
-pub struct Connection { - conn: quinn::Connection, - incoming_bi: BoxStream<'static, as Future>::Output>, - opening_bi: Option as Future>::Output>>, - incoming_uni: BoxStream<'static, as Future>::Output>, - opening_uni: Option as Future>::Output>>, - datagrams: BoxStream<'static, as Future>::Output>, -} - -impl Connection { - /// Create a [`Connection`] from a [`quinn::NewConnection`] - pub fn new(conn: quinn::Connection) -> Self { - Self { - conn: conn.clone(), - incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async { - Some((conn.accept_bi().await, conn)) - })), - opening_bi: None, - incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async { - Some((conn.accept_uni().await, conn)) - })), - opening_uni: None, - datagrams: Box::pin(stream::unfold(conn, |conn| async { - Some((conn.read_datagram().await, conn)) - })), - } - } -} - -/// The error type for [`Connection`] -/// -/// Wraps reasons a Quinn connection might be lost. -#[derive(Debug)] -pub struct ConnectionError(quinn::ConnectionError); - -impl std::error::Error for ConnectionError {} - -impl fmt::Display for ConnectionError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl Error for ConnectionError { - fn is_timeout(&self) -> bool { - matches!(self.0, quinn::ConnectionError::TimedOut) - } - - fn err_code(&self) -> Option { - match self.0 { - quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { - error_code, - .. - }) => Some(error_code.into_inner()), - _ => None, - } - } -} - -impl From for ConnectionError { - fn from(e: quinn::ConnectionError) -> Self { - Self(e) - } -} - -/// Types of errors when sending a datagram. -#[derive(Debug)] -pub enum SendDatagramError { - /// Datagrams are not supported by the peer - UnsupportedByPeer, - /// Datagrams are locally disabled - Disabled, - /// The datagram was too large to be sent. 
- TooLarge, - /// Network error - ConnectionLost(Box), -} - -impl fmt::Display for SendDatagramError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"), - SendDatagramError::Disabled => write!(f, "datagram support disabled"), - SendDatagramError::TooLarge => write!(f, "datagram too large"), - SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"), - } - } -} - -impl std::error::Error for SendDatagramError {} - -impl Error for SendDatagramError { - fn is_timeout(&self) -> bool { - false - } - - fn err_code(&self) -> Option { - match self { - Self::ConnectionLost(err) => err.err_code(), - _ => None, - } - } -} - -impl From for SendDatagramError { - fn from(value: quinn::SendDatagramError) -> Self { - match value { - quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer, - quinn::SendDatagramError::Disabled => Self::Disabled, - quinn::SendDatagramError::TooLarge => Self::TooLarge, - quinn::SendDatagramError::ConnectionLost(err) => { - Self::ConnectionLost(ConnectionError::from(err).into()) - } - } - } -} - -impl quic::Connection for Connection -where - B: Buf, -{ - type SendStream = SendStream; - type RecvStream = RecvStream; - type BidiStream = BidiStream; - type OpenStreams = OpenStreams; - type Error = ConnectionError; - - fn poll_accept_bidi( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) { - Some(x) => x?, - None => return Poll::Ready(Ok(None)), - }; - Poll::Ready(Ok(Some(Self::BidiStream { - send: Self::SendStream::new(send), - recv: Self::RecvStream::new(recv), - }))) - } - - fn poll_accept_recv( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) { - Some(x) => x?, - None => return Poll::Ready(Ok(None)), - }; - Poll::Ready(Ok(Some(Self::RecvStream::new(recv)))) - } - - fn poll_open_bidi( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_bi.is_none() { - self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.clone().open_bi().await, conn)) - }))); - } - - let (send, recv) = - ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::BidiStream { - send: Self::SendStream::new(send), - recv: Self::RecvStream::new(recv), - })) - } - - fn poll_open_send( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_uni.is_none() { - self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.open_uni().await, conn)) - }))); - } - - let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::SendStream::new(send))) - } - - fn opener(&self) -> Self::OpenStreams { - OpenStreams { - conn: self.conn.clone(), - opening_bi: None, - opening_uni: None, - } - } - - fn close(&mut self, code: h3::error::Code, reason: &[u8]) { - self.conn.close( - VarInt::from_u64(code.value()).expect("error code VarInt"), - reason, - ); - } -} - -impl quic::SendDatagramExt for Connection -where - B: Buf, -{ - type Error = SendDatagramError; - - fn send_datagram(&mut self, data: Datagram) -> Result<(), SendDatagramError> { - // TODO investigate static buffer from known max datagram size - let mut buf = BytesMut::new(); - data.encode(&mut buf); - self.conn.send_datagram(buf.freeze())?; - - Ok(()) - } -} 
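One idiom in the deleted h3-quinn code above is worth calling out: `Connection::new` wraps each async accept call (`accept_bi`, `accept_uni`, `read_datagram`) in `futures::stream::unfold`, turning it into a pollable `Stream` that the manual `poll_accept_*` methods can drive with `poll_next_unpin`. A self-contained sketch of that idiom follows; the plain counter is a stand-in for quinn's connection handle, not real quinn API usage.

```rust
use futures::{stream, StreamExt};

#[tokio::main]
async fn main() {
    // `stream::unfold` threads state (here a counter) through an async closure,
    // yielding one item per "accept" until the closure returns None.
    let mut incoming = Box::pin(stream::unfold(0u32, |next| async move {
        if next >= 3 {
            return None; // pretend the peer closed after three streams
        }
        // A real implementation would `conn.accept_bi().await` here.
        Some((next, next + 1))
    }));

    // h3-quinn polls this with `poll_next_unpin`; `next().await` drives the same path.
    while let Some(id) = incoming.next().await {
        println!("accepted stream {id}");
    }
}
```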
- -impl quic::RecvDatagramExt for Connection { - type Buf = Bytes; - - type Error = ConnectionError; - - #[inline] - fn poll_accept_datagram( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - match ready!(self.datagrams.poll_next_unpin(cx)) { - Some(Ok(x)) => Poll::Ready(Ok(Some(x))), - Some(Err(e)) => Poll::Ready(Err(e.into())), - None => Poll::Ready(Ok(None)), - } - } -} - -/// Stream opener backed by a Quinn connection -/// -/// Implements [`quic::OpenStreams`] using [`quinn::Connection`], -/// [`quinn::OpenBi`], [`quinn::OpenUni`]. -pub struct OpenStreams { - conn: quinn::Connection, - opening_bi: Option as Future>::Output>>, - opening_uni: Option as Future>::Output>>, -} - -impl quic::OpenStreams for OpenStreams -where - B: Buf, -{ - type RecvStream = RecvStream; - type SendStream = SendStream; - type BidiStream = BidiStream; - type Error = ConnectionError; - - fn poll_open_bidi( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_bi.is_none() { - self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.open_bi().await, conn)) - }))); - } - - let (send, recv) = - ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::BidiStream { - send: Self::SendStream::new(send), - recv: Self::RecvStream::new(recv), - })) - } - - fn poll_open_send( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll> { - if self.opening_uni.is_none() { - self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { - Some((conn.open_uni().await, conn)) - }))); - } - - let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; - Poll::Ready(Ok(Self::SendStream::new(send))) - } - - fn close(&mut self, code: h3::error::Code, reason: &[u8]) { - self.conn.close( - VarInt::from_u64(code.value()).expect("error code VarInt"), - reason, - ); - } -} - -impl Clone for OpenStreams { - fn clone(&self) -> Self { - Self { - conn: self.conn.clone(), - opening_bi: None, - opening_uni: None, - } - } -} - -/// Quinn-backed bidirectional stream -/// -/// Implements [`quic::BidiStream`] which allows the stream to be split -/// into two structs each implementing one direction. 
-pub struct BidiStream -where - B: Buf, -{ - send: SendStream, - recv: RecvStream, -} - -impl quic::BidiStream for BidiStream -where - B: Buf, -{ - type SendStream = SendStream; - type RecvStream = RecvStream; - - fn split(self) -> (Self::SendStream, Self::RecvStream) { - (self.send, self.recv) - } -} - -impl quic::RecvStream for BidiStream { - type Buf = Bytes; - type Error = ReadError; - - fn poll_data( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - self.recv.poll_data(cx) - } - - fn stop_sending(&mut self, error_code: u64) { - self.recv.stop_sending(error_code) - } - - fn recv_id(&self) -> StreamId { - self.recv.recv_id() - } -} - -impl quic::SendStream for BidiStream -where - B: Buf, -{ - type Error = SendStreamError; - - fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.send.poll_ready(cx) - } - - fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.send.poll_finish(cx) - } - - fn reset(&mut self, reset_code: u64) { - self.send.reset(reset_code) - } - - fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { - self.send.send_data(data) - } - - fn send_id(&self) -> StreamId { - self.send.send_id() - } -} -impl quic::SendStreamUnframed for BidiStream -where - B: Buf, -{ - fn poll_send( - &mut self, - cx: &mut task::Context<'_>, - buf: &mut D, - ) -> Poll> { - self.send.poll_send(cx, buf) - } -} - -/// Quinn-backed receive stream -/// -/// Implements a [`quic::RecvStream`] backed by a [`quinn::RecvStream`]. -pub struct RecvStream { - stream: Option, - read_chunk_fut: ReadChunkFuture, -} - -type ReadChunkFuture = ReusableBoxFuture< - 'static, - ( - quinn::RecvStream, - Result, quinn::ReadError>, - ), ->; - -impl RecvStream { - fn new(stream: quinn::RecvStream) -> Self { - Self { - stream: Some(stream), - // Should only allocate once the first time it's used - read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }), - } - } -} - -impl quic::RecvStream for RecvStream { - type Buf = Bytes; - type Error = ReadError; - - fn poll_data( - &mut self, - cx: &mut task::Context<'_>, - ) -> Poll, Self::Error>> { - if let Some(mut stream) = self.stream.take() { - self.read_chunk_fut.set(async move { - let chunk = stream.read_chunk(usize::MAX, true).await; - (stream, chunk) - }) - }; - - let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx)); - self.stream = Some(stream); - Poll::Ready(Ok(chunk?.map(|c| c.bytes))) - } - - fn stop_sending(&mut self, error_code: u64) { - self.stream - .as_mut() - .unwrap() - .stop(VarInt::from_u64(error_code).expect("invalid error_code")) - .ok(); - } - - fn recv_id(&self) -> StreamId { - self.stream - .as_ref() - .unwrap() - .id() - .0 - .try_into() - .expect("invalid stream id") - } -} - -/// The error type for [`RecvStream`] -/// -/// Wraps errors that occur when reading from a receive stream. 
-#[derive(Debug)] -pub struct ReadError(quinn::ReadError); - -impl From for std::io::Error { - fn from(value: ReadError) -> Self { - value.0.into() - } -} - -impl std::error::Error for ReadError { - fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { - self.0.source() - } -} - -impl fmt::Display for ReadError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -impl From for Arc { - fn from(e: ReadError) -> Self { - Arc::new(e) - } -} - -impl From for ReadError { - fn from(e: quinn::ReadError) -> Self { - Self(e) - } -} - -impl Error for ReadError { - fn is_timeout(&self) -> bool { - matches!( - self.0, - quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut) - ) - } - - fn err_code(&self) -> Option { - match self.0 { - quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed( - quinn_proto::ApplicationClose { error_code, .. }, - )) => Some(error_code.into_inner()), - quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()), - _ => None, - } - } -} - -/// Quinn-backed send stream -/// -/// Implements a [`quic::SendStream`] backed by a [`quinn::SendStream`]. -pub struct SendStream { - stream: Option, - writing: Option>, - write_fut: WriteFuture, -} - -type WriteFuture = - ReusableBoxFuture<'static, (quinn::SendStream, Result)>; - -impl SendStream -where - B: Buf, -{ - fn new(stream: quinn::SendStream) -> SendStream { - Self { - stream: Some(stream), - writing: None, - write_fut: ReusableBoxFuture::new(async { unreachable!() }), - } - } -} - -impl quic::SendStream for SendStream -where - B: Buf, -{ - type Error = SendStreamError; - - fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { - if let Some(ref mut data) = self.writing { - while data.has_remaining() { - if let Some(mut stream) = self.stream.take() { - let chunk = data.chunk().to_owned(); // FIXME - avoid copy - self.write_fut.set(async move { - let ret = stream.write(&chunk).await; - (stream, ret) - }); - } - - let (stream, res) = ready!(self.write_fut.poll(cx)); - self.stream = Some(stream); - match res { - Ok(cnt) => data.advance(cnt), - Err(err) => { - return Poll::Ready(Err(SendStreamError::Write(err))); - } - } - } - } - self.writing = None; - Poll::Ready(Ok(())) - } - - fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.stream - .as_mut() - .unwrap() - .poll_finish(cx) - .map_err(Into::into) - } - - fn reset(&mut self, reset_code: u64) { - let _ = self - .stream - .as_mut() - .unwrap() - .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX)); - } - - fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { - if self.writing.is_some() { - return Err(Self::Error::NotReady); - } - self.writing = Some(data.into()); - Ok(()) - } - - fn send_id(&self) -> StreamId { - self.stream - .as_ref() - .unwrap() - .id() - .0 - .try_into() - .expect("invalid stream id") - } -} - -impl quic::SendStreamUnframed for SendStream -where - B: Buf, -{ - fn poll_send( - &mut self, - cx: &mut task::Context<'_>, - buf: &mut D, - ) -> Poll> { - if self.writing.is_some() { - // This signifies a bug in implementation - panic!("poll_send called while send stream is not ready") - } - - let s = Pin::new(self.stream.as_mut().unwrap()); - - let res = ready!(futures::io::AsyncWrite::poll_write(s, cx, buf.chunk())); - match res { - Ok(written) => { - buf.advance(written); - Poll::Ready(Ok(written)) - } - Err(err) => { - // We are forced to use AsyncWrite for now because we cannot store - // the result of a call to: - // 
quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Result. - // - // This is why we have to unpack the error from io::Error instead of having it - // returned directly. This should not panic as long as quinn's AsyncWrite impl - // doesn't change. - let err = err - .into_inner() - .expect("write stream returned an empty error") - .downcast::() - .expect("write stream returned an error which type is not WriteError"); - - Poll::Ready(Err(SendStreamError::Write(*err))) - } - } - } -} - -/// The error type for [`SendStream`] -/// -/// Wraps errors that can happen writing to or polling a send stream. -#[derive(Debug)] -pub enum SendStreamError { - /// Errors when writing, wrapping a [`quinn::WriteError`] - Write(WriteError), - /// Error when the stream is not ready, because it is still sending - /// data from a previous call - NotReady, -} - -impl From for std::io::Error { - fn from(value: SendStreamError) -> Self { - match value { - SendStreamError::Write(err) => err.into(), - SendStreamError::NotReady => { - std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready") - } - } - } -} - -impl std::error::Error for SendStreamError {} - -impl Display for SendStreamError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) - } -} - -impl From for SendStreamError { - fn from(e: WriteError) -> Self { - Self::Write(e) - } -} - -impl Error for SendStreamError { - fn is_timeout(&self) -> bool { - matches!( - self, - Self::Write(quinn::WriteError::ConnectionLost( - quinn::ConnectionError::TimedOut - )) - ) - } - - fn err_code(&self) -> Option { - match self { - Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()), - Self::Write(quinn::WriteError::ConnectionLost( - quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { - error_code, - .. 
-            }),
-        )) => Some(error_code.into_inner()),
-            _ => None,
-        }
-    }
-}
-
-impl From<SendStreamError> for Arc<dyn Error> {
-    fn from(e: SendStreamError) -> Self {
-        Arc::new(e)
-    }
-}
diff --git a/submodules/quinn b/submodules/quinn
deleted file mode 160000
index e1e1e6e..0000000
--- a/submodules/quinn
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit e1e1e6e392a382fbded42ca010505fecb8fe3655
diff --git a/submodules/rusty-http-cache-semantics b/submodules/rusty-http-cache-semantics
index 3cd0917..88d23c2 160000
--- a/submodules/rusty-http-cache-semantics
+++ b/submodules/rusty-http-cache-semantics
@@ -1 +1 @@
-Subproject commit 3cd09170305753309d86e88b9427827cca0de0dd
+Subproject commit 88d23c2f5a3ac36295dff4a804968c43932ba46b
diff --git a/submodules/s2n-quic b/submodules/s2n-quic
deleted file mode 160000
index c88e64b..0000000
--- a/submodules/s2n-quic
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit c88e64b6c58891651954834207d974de80e9bba8
diff --git a/submodules/s2n-quic-h3/Cargo.toml b/submodules/s2n-quic-h3/Cargo.toml
new file mode 100644
index 0000000..59a5403
--- /dev/null
+++ b/submodules/s2n-quic-h3/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "s2n-quic-h3"
+# this is an unpublished internal crate so the version should not be changed
+version = "0.1.0"
+authors = ["AWS s2n"]
+edition = "2021"
+rust-version = "1.63"
+license = "Apache-2.0"
+# this contains an http3 implementation for testing purposes and should not be published
+publish = false
+
+[dependencies]
+bytes = { version = "1", default-features = false }
+futures = { version = "0.3", default-features = false }
+h3 = { path = "../h3/h3/" }
+s2n-quic = "1.33.0"
+s2n-quic-core = "0.33.0"
diff --git a/submodules/s2n-quic-h3/README.md b/submodules/s2n-quic-h3/README.md
new file mode 100644
index 0000000..aed9475
--- /dev/null
+++ b/submodules/s2n-quic-h3/README.md
@@ -0,0 +1,10 @@
+# s2n-quic-h3
+
+This is an internal crate used by [s2n-quic](https://github.com/aws/s2n-quic) written as a proof of concept for implementing HTTP3 on top of s2n-quic. The API is not currently stable and should not be used directly.
+
+## License
+
+This project is licensed under the [Apache-2.0 License][license-url].
+
+[license-badge]: https://img.shields.io/badge/license-apache-blue.svg
+[license-url]: https://aws.amazon.com/apache-2-0/
diff --git a/submodules/s2n-quic-h3/src/lib.rs b/submodules/s2n-quic-h3/src/lib.rs
new file mode 100644
index 0000000..c85f197
--- /dev/null
+++ b/submodules/s2n-quic-h3/src/lib.rs
@@ -0,0 +1,7 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+mod s2n_quic;
+
+pub use self::s2n_quic::*;
+pub use h3;
diff --git a/submodules/s2n-quic-h3/src/s2n_quic.rs b/submodules/s2n-quic-h3/src/s2n_quic.rs
new file mode 100644
index 0000000..dffa19b
--- /dev/null
+++ b/submodules/s2n-quic-h3/src/s2n_quic.rs
@@ -0,0 +1,506 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// SPDX-License-Identifier: Apache-2.0 + +use bytes::{Buf, Bytes}; +use futures::ready; +use h3::quic::{self, Error, StreamId, WriteBuf}; +use s2n_quic::stream::{BidirectionalStream, ReceiveStream}; +use s2n_quic_core::varint::VarInt; +use std::{ + convert::TryInto, + fmt::{self, Display}, + sync::Arc, + task::{self, Poll}, +}; + +pub struct Connection { + conn: s2n_quic::connection::Handle, + bidi_acceptor: s2n_quic::connection::BidirectionalStreamAcceptor, + recv_acceptor: s2n_quic::connection::ReceiveStreamAcceptor, +} + +impl Connection { + pub fn new(new_conn: s2n_quic::Connection) -> Self { + let (handle, acceptor) = new_conn.split(); + let (bidi, recv) = acceptor.split(); + + Self { + conn: handle, + bidi_acceptor: bidi, + recv_acceptor: recv, + } + } +} + +#[derive(Debug)] +pub struct ConnectionError(s2n_quic::connection::Error); + +impl std::error::Error for ConnectionError {} + +impl fmt::Display for ConnectionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Error for ConnectionError { + fn is_timeout(&self) -> bool { + matches!(self.0, s2n_quic::connection::Error::IdleTimerExpired { .. }) + } + + fn err_code(&self) -> Option { + match self.0 { + s2n_quic::connection::Error::Application { error, .. } => Some(error.into()), + _ => None, + } + } +} + +impl From for ConnectionError { + fn from(e: s2n_quic::connection::Error) -> Self { + Self(e) + } +} + +impl quic::Connection for Connection +where + B: Buf, +{ + type BidiStream = BidiStream; + type SendStream = SendStream; + type RecvStream = RecvStream; + type OpenStreams = OpenStreams; + type Error = ConnectionError; + + fn poll_accept_recv( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + let recv = match ready!(self.recv_acceptor.poll_accept_receive_stream(cx))? { + Some(x) => x, + None => return Poll::Ready(Ok(None)), + }; + Poll::Ready(Ok(Some(Self::RecvStream::new(recv)))) + } + + fn poll_accept_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + let (recv, send) = match ready!(self.bidi_acceptor.poll_accept_bidirectional_stream(cx))? 
{ + Some(x) => x.split(), + None => return Poll::Ready(Ok(None)), + }; + Poll::Ready(Ok(Some(Self::BidiStream { + send: Self::SendStream::new(send), + recv: Self::RecvStream::new(recv), + }))) + } + + fn poll_open_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?; + Ok(stream.into()).into() + } + + fn poll_open_send( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_send_stream(cx))?; + Ok(stream.into()).into() + } + + fn opener(&self) -> Self::OpenStreams { + OpenStreams { + conn: self.conn.clone(), + } + } + + fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { + self.conn.close( + code.value() + .try_into() + .expect("s2n-quic supports error codes up to 2^62-1"), + ); + } +} + +pub struct OpenStreams { + conn: s2n_quic::connection::Handle, +} + +impl quic::OpenStreams for OpenStreams +where + B: Buf, +{ + type BidiStream = BidiStream; + type SendStream = SendStream; + type RecvStream = RecvStream; + type Error = ConnectionError; + + fn poll_open_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?; + Ok(stream.into()).into() + } + + fn poll_open_send( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll> { + let stream = ready!(self.conn.poll_open_send_stream(cx))?; + Ok(stream.into()).into() + } + + fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { + self.conn.close( + code.value() + .try_into() + .unwrap_or_else(|_| VarInt::MAX.into()), + ); + } +} + +impl Clone for OpenStreams { + fn clone(&self) -> Self { + Self { + conn: self.conn.clone(), + } + } +} + +pub struct BidiStream +where + B: Buf, +{ + send: SendStream, + recv: RecvStream, +} + +impl quic::BidiStream for BidiStream +where + B: Buf, +{ + type SendStream = SendStream; + type RecvStream = RecvStream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + (self.send, self.recv) + } +} + +impl quic::RecvStream for BidiStream +where + B: Buf, +{ + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll, Self::Error>> { + self.recv.poll_data(cx) + } + + fn stop_sending(&mut self, error_code: u64) { + self.recv.stop_sending(error_code) + } + + fn recv_id(&self) -> StreamId { + self.recv.stream.id().try_into().expect("invalid stream id") + } +} + +impl quic::SendStream for BidiStream +where + B: Buf, +{ + type Error = SendStreamError; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + self.send.poll_ready(cx) + } + + fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { + self.send.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.send.reset(reset_code) + } + + fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { + self.send.send_data(data) + } + + fn send_id(&self) -> StreamId { + self.send.stream.id().try_into().expect("invalid stream id") + } +} + +impl From for BidiStream +where + B: Buf, +{ + fn from(bidi: BidirectionalStream) -> Self { + let (recv, send) = bidi.split(); + BidiStream { + send: send.into(), + recv: recv.into(), + } + } +} + +pub struct RecvStream { + stream: s2n_quic::stream::ReceiveStream, +} + +impl RecvStream { + fn new(stream: s2n_quic::stream::ReceiveStream) -> Self { + Self { stream } + } +} + +impl quic::RecvStream for RecvStream { + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut task::Context<'_>, + ) -> 
Poll, Self::Error>> { + let buf = ready!(self.stream.poll_receive(cx))?; + Ok(buf).into() + } + + fn stop_sending(&mut self, error_code: u64) { + let _ = self.stream.stop_sending( + s2n_quic::application::Error::new(error_code) + .expect("s2n-quic supports error codes up to 2^62-1"), + ); + } + + fn recv_id(&self) -> StreamId { + self.stream.id().try_into().expect("invalid stream id") + } +} + +impl From for RecvStream { + fn from(recv: ReceiveStream) -> Self { + RecvStream::new(recv) + } +} + +#[derive(Debug)] +pub struct ReadError(s2n_quic::stream::Error); + +impl std::error::Error for ReadError {} + +impl fmt::Display for ReadError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl From for Arc { + fn from(e: ReadError) -> Self { + Arc::new(e) + } +} + +impl From for ReadError { + fn from(e: s2n_quic::stream::Error) -> Self { + Self(e) + } +} + +impl Error for ReadError { + fn is_timeout(&self) -> bool { + matches!( + self.0, + s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::IdleTimerExpired { .. }, + .. + } + ) + } + + fn err_code(&self) -> Option { + match self.0 { + s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::Application { error, .. }, + .. + } => Some(error.into()), + s2n_quic::stream::Error::StreamReset { error, .. } => Some(error.into()), + _ => None, + } + } +} + +pub struct SendStream { + stream: s2n_quic::stream::SendStream, + chunk: Option, + buf: Option>, // TODO: Replace with buf: PhantomData + // after https://github.com/hyperium/h3/issues/78 is resolved +} + +impl SendStream +where + B: Buf, +{ + fn new(stream: s2n_quic::stream::SendStream) -> SendStream { + Self { + stream, + chunk: None, + buf: Default::default(), + } + } +} + +impl quic::SendStream for SendStream +where + B: Buf, +{ + type Error = SendStreamError; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + loop { + // try to flush the current chunk if we have one + if let Some(chunk) = self.chunk.as_mut() { + ready!(self.stream.poll_send(chunk, cx))?; + + // s2n-quic will take the whole chunk on send, even if it exceeds the limits + debug_assert!(chunk.is_empty()); + self.chunk = None; + } + + // try to take the next chunk from the WriteBuf + if let Some(ref mut data) = self.buf { + let len = data.chunk().len(); + + // if the write buf is empty, then clear it and break + if len == 0 { + self.buf = None; + break; + } + + // copy the first chunk from WriteBuf and prepare it to flush + let chunk = data.copy_to_bytes(len); + self.chunk = Some(chunk); + + // loop back around to flush the chunk + continue; + } + + // if we didn't have either a chunk or WriteBuf, then we're ready + break; + } + + Poll::Ready(Ok(())) + + // TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved + // self.available_bytes = ready!(self.stream.poll_send_ready(cx))?; + // Poll::Ready(Ok(())) + } + + fn send_data>>(&mut self, data: D) -> Result<(), Self::Error> { + if self.buf.is_some() { + return Err(Self::Error::NotReady); + } + self.buf = Some(data.into()); + Ok(()) + + // TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved + // let mut data = data.into(); + // while self.available_bytes > 0 && data.has_remaining() { + // let len = data.chunk().len(); + // let chunk = data.copy_to_bytes(len); + // self.stream.send_data(chunk)?; + // self.available_bytes = self.available_bytes.saturating_sub(len); + // } + // Ok(()) + } + + fn 
poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { + // ensure all chunks are flushed to the QUIC stream before finishing + ready!(self.poll_ready(cx))?; + self.stream.finish()?; + Ok(()).into() + } + + fn reset(&mut self, reset_code: u64) { + let _ = self + .stream + .reset(reset_code.try_into().unwrap_or_else(|_| VarInt::MAX.into())); + } + + fn send_id(&self) -> StreamId { + self.stream.id().try_into().expect("invalid stream id") + } +} + +impl From for SendStream +where + B: Buf, +{ + fn from(send: s2n_quic::stream::SendStream) -> Self { + SendStream::new(send) + } +} + +#[derive(Debug)] +pub enum SendStreamError { + Write(s2n_quic::stream::Error), + NotReady, +} + +impl std::error::Error for SendStreamError {} + +impl Display for SendStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{self:?}") + } +} + +impl From for SendStreamError { + fn from(e: s2n_quic::stream::Error) -> Self { + Self::Write(e) + } +} + +impl Error for SendStreamError { + fn is_timeout(&self) -> bool { + matches!( + self, + Self::Write(s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::IdleTimerExpired { .. }, + .. + }) + ) + } + + fn err_code(&self) -> Option { + match self { + Self::Write(s2n_quic::stream::Error::StreamReset { error, .. }) => { + Some((*error).into()) + } + Self::Write(s2n_quic::stream::Error::ConnectionError { + error: s2n_quic::connection::Error::Application { error, .. }, + .. + }) => Some((*error).into()), + _ => None, + } + } +} + +impl From for Arc { + fn from(e: SendStreamError) -> Self { + Arc::new(e) + } +}
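
A closing design note on the `s2n_quic.rs` shim added above: h3 hands `SendStream` a `WriteBuf` whose chunks need not be contiguous, so `poll_ready` drains it chunk by chunk, copying each into an owned `Bytes` before passing it to the QUIC stream (the inline TODOs reference hyperium/h3#78 as the path to removing that copy). Below is a rough standalone sketch of that drain loop; `drain_in_chunks` is a hypothetical helper for illustration, not part of the crate.

```rust
use bytes::{Buf, Bytes};

// Hypothetical helper mirroring the drain loop in `SendStream::poll_ready`:
// copy each contiguous chunk of a `Buf` into an owned `Bytes`, since the
// QUIC stream takes ownership of whatever it sends.
fn drain_in_chunks(mut data: impl Buf) -> Vec<Bytes> {
    let mut chunks = Vec::new();
    while data.has_remaining() {
        let len = data.chunk().len(); // size of the next contiguous chunk
        chunks.push(data.copy_to_bytes(len)); // owned copy; advances the cursor
    }
    chunks
}

fn main() {
    // Two non-contiguous chunks, roughly how a WriteBuf can hold a frame
    // header followed by its payload.
    let data = Bytes::from_static(b"frame-header").chain(Bytes::from_static(b"payload"));
    let chunks = drain_in_chunks(data);
    assert_eq!(chunks.len(), 2);
    println!("{chunks:?}");
}
```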