Merge pull request #132 from junkurihara/feat/hyper-1.0
Feat hyper 1.0 again
This commit is contained in:
		
				commit
				
					
						6a298d7a9b
					
				
			
		
					 101 changed files with 7355 additions and 2011 deletions
				
			
		|  | @ -4,3 +4,4 @@ bench/ | |||
| .private/ | ||||
| .github/ | ||||
| example-certs/ | ||||
| legacy-lib/ | ||||
|  |  | |||
							
								
								
									
										56
									
								
								.github/workflows/release.yml
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										56
									
								
								.github/workflows/release.yml
									
										
									
									
										vendored
									
									
								
							|  | @ -44,35 +44,35 @@ jobs: | |||
|             platform: linux/arm64 | ||||
|             tags-suffix: "-s2n" | ||||
| 
 | ||||
|           - target: "gnu" | ||||
|             build-feature: "-native-roots" | ||||
|             platform: linux/amd64 | ||||
|             tags-suffix: "-native-roots" | ||||
|           # - target: "gnu" | ||||
|           #   build-feature: "-native-roots" | ||||
|           #   platform: linux/amd64 | ||||
|           #   tags-suffix: "-native-roots" | ||||
| 
 | ||||
|           - target: "gnu" | ||||
|             build-feature: "-native-roots" | ||||
|             platform: linux/arm64 | ||||
|             tags-suffix: "-native-roots" | ||||
|           # - target: "gnu" | ||||
|           #   build-feature: "-native-roots" | ||||
|           #   platform: linux/arm64 | ||||
|           #   tags-suffix: "-native-roots" | ||||
| 
 | ||||
|           - target: "musl" | ||||
|             build-feature: "-native-roots" | ||||
|             platform: linux/amd64 | ||||
|             tags-suffix: "-slim-native-roots" | ||||
|           # - target: "musl" | ||||
|           #   build-feature: "-native-roots" | ||||
|           #   platform: linux/amd64 | ||||
|           #   tags-suffix: "-slim-native-roots" | ||||
| 
 | ||||
|           - target: "musl" | ||||
|             build-feature: "-native-roots" | ||||
|             platform: linux/arm64 | ||||
|             tags-suffix: "-slim-native-roots" | ||||
|           # - target: "musl" | ||||
|           #   build-feature: "-native-roots" | ||||
|           #   platform: linux/arm64 | ||||
|           #   tags-suffix: "-slim-native-roots" | ||||
| 
 | ||||
|           - target: "gnu" | ||||
|             build-feature: "-s2n-native-roots" | ||||
|             platform: linux/amd64 | ||||
|             tags-suffix: "-s2n-native-roots" | ||||
|           # - target: "gnu" | ||||
|           #   build-feature: "-s2n-native-roots" | ||||
|           #   platform: linux/amd64 | ||||
|           #   tags-suffix: "-s2n-native-roots" | ||||
| 
 | ||||
|           - target: "gnu" | ||||
|             build-feature: "-s2n-native-roots" | ||||
|             platform: linux/arm64 | ||||
|             tags-suffix: "-s2n-native-roots" | ||||
|           # - target: "gnu" | ||||
|           #   build-feature: "-s2n-native-roots" | ||||
|           #   platform: linux/arm64 | ||||
|           #   tags-suffix: "-s2n-native-roots" | ||||
| 
 | ||||
|     steps: | ||||
|       - run: "echo 'The relese triggering workflows passed'" | ||||
|  | @ -81,8 +81,8 @@ jobs: | |||
|         id: "set-env" | ||||
|         run: | | ||||
|           if [ ${{ matrix.platform }} == 'linux/amd64' ]; then PLATFORM_MAP="x86_64"; else PLATFORM_MAP="aarch64"; fi | ||||
|           if [ ${{ github.ref_name }} == 'develop' ]; then BUILD_NAME="-nightly"; else BUILD_NAME=""; fi | ||||
|           if [ ${{ github.ref_name }} == 'develop' ]; then BUILD_IMG="nightly"; else BUILD_IMG="latest"; fi | ||||
|           if [ ${{ github.ref_name == 'develop' && github.event.client_payload.pull_request.head == 'develop' && github.event.client_payload.pull_request.base == 'main' }} || ${{ github.ref_name == 'main' }}]; then BUILD_NAME=""; else BUILD_NAME="-nightly"; fi | ||||
|           if [ ${{ github.ref_name }} == 'main' ]; then BUILD_IMG="latest"; else BUILD_IMG="nightly"; fi | ||||
|           echo "build_img=${BUILD_IMG}" >> $GITHUB_OUTPUT | ||||
|           echo "target_name=rpxy${BUILD_NAME}-${PLATFORM_MAP}-unknown-linux-${{ matrix.target }}${{ matrix.build-feature }}" >> $GITHUB_OUTPUT | ||||
| 
 | ||||
|  | @ -93,7 +93,7 @@ jobs: | |||
|           docker cp ${CONTAINER_ID}:/rpxy/bin/rpxy /tmp/${{ steps.set-env.outputs.target_name }} | ||||
| 
 | ||||
|       - name: "upload artifacts" | ||||
|         uses: actions/upload-artifact@v3 | ||||
|         uses: actions/upload-artifact@v4 | ||||
|         with: | ||||
|           name: ${{ steps.set-env.outputs.target_name }} | ||||
|           path: "/tmp/${{ steps.set-env.outputs.target_name }}" | ||||
|  | @ -122,7 +122,7 @@ jobs: | |||
| 
 | ||||
|       - name: download artifacts | ||||
|         if: ${{ steps.regex-match.outputs.match != ''}} | ||||
|         uses: actions/download-artifact@v3 | ||||
|         uses: actions/download-artifact@v4 | ||||
|         with: | ||||
|           path: /tmp/rpxy | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										88
									
								
								.github/workflows/release_docker.yml
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										88
									
								
								.github/workflows/release_docker.yml
									
										
									
									
										vendored
									
									
								
							|  | @ -2,6 +2,7 @@ name: Release - Build and publish docker, and trigger package release | |||
| on: | ||||
|   push: | ||||
|     branches: | ||||
|       - "feat/*" | ||||
|       - "develop" | ||||
|   pull_request: | ||||
|     types: [closed] | ||||
|  | @ -44,7 +45,7 @@ jobs: | |||
|           - target: "s2n" | ||||
|             dockerfile: ./docker/Dockerfile | ||||
|             build-args: | | ||||
|               "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache" | ||||
|               "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,native-tls-backend" | ||||
|               "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" | ||||
|             platforms: linux/amd64,linux/arm64 | ||||
|             tags-suffix: "-s2n" | ||||
|  | @ -53,42 +54,42 @@ jobs: | |||
|               jqtype/rpxy:s2n | ||||
|               ghcr.io/junkurihara/rust-rpxy:s2n | ||||
| 
 | ||||
|           - target: "native-roots" | ||||
|             dockerfile: ./docker/Dockerfile | ||||
|             platforms: linux/amd64,linux/arm64 | ||||
|             build-args: | | ||||
|               "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots" | ||||
|             tags-suffix: "-native-roots" | ||||
|             # Aliases must be used only for release builds | ||||
|             aliases: | | ||||
|               jqtype/rpxy:native-roots | ||||
|               ghcr.io/junkurihara/rust-rpxy:native-roots | ||||
|           # - target: "native-roots" | ||||
|           #   dockerfile: ./docker/Dockerfile | ||||
|           #   platforms: linux/amd64,linux/arm64 | ||||
|           #   build-args: | | ||||
|           #     "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots" | ||||
|           #   tags-suffix: "-native-roots" | ||||
|           #   # Aliases must be used only for release builds | ||||
|           #   aliases: | | ||||
|           #     jqtype/rpxy:native-roots | ||||
|           #     ghcr.io/junkurihara/rust-rpxy:native-roots | ||||
| 
 | ||||
|           - target: "slim-native-roots" | ||||
|             dockerfile: ./docker/Dockerfile-slim | ||||
|             build-args: | | ||||
|               "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots" | ||||
|             build-contexts: | | ||||
|               messense/rust-musl-cross:amd64-musl=docker-image://messense/rust-musl-cross:x86_64-musl | ||||
|               messense/rust-musl-cross:arm64-musl=docker-image://messense/rust-musl-cross:aarch64-musl | ||||
|             platforms: linux/amd64,linux/arm64 | ||||
|             tags-suffix: "-slim-native-roots" | ||||
|             # Aliases must be used only for release builds | ||||
|             aliases: | | ||||
|               jqtype/rpxy:slim-native-roots | ||||
|               ghcr.io/junkurihara/rust-rpxy:slim-native-roots | ||||
|           # - target: "slim-native-roots" | ||||
|           #   dockerfile: ./docker/Dockerfile-slim | ||||
|           #   build-args: | | ||||
|           #     "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots" | ||||
|           #   build-contexts: | | ||||
|           #     messense/rust-musl-cross:amd64-musl=docker-image://messense/rust-musl-cross:x86_64-musl | ||||
|           #     messense/rust-musl-cross:arm64-musl=docker-image://messense/rust-musl-cross:aarch64-musl | ||||
|           #   platforms: linux/amd64,linux/arm64 | ||||
|           #   tags-suffix: "-slim-native-roots" | ||||
|           #   # Aliases must be used only for release builds | ||||
|           #   aliases: | | ||||
|           #     jqtype/rpxy:slim-native-roots | ||||
|           #     ghcr.io/junkurihara/rust-rpxy:slim-native-roots | ||||
| 
 | ||||
|           - target: "s2n-native-roots" | ||||
|             dockerfile: ./docker/Dockerfile | ||||
|             build-args: | | ||||
|               "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,native-roots" | ||||
|               "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" | ||||
|             platforms: linux/amd64,linux/arm64 | ||||
|             tags-suffix: "-s2n-native-roots" | ||||
|             # Aliases must be used only for release builds | ||||
|             aliases: | | ||||
|               jqtype/rpxy:s2n-native-roots | ||||
|               ghcr.io/junkurihara/rust-rpxy:s2n-native-roots | ||||
|           # - target: "s2n-native-roots" | ||||
|           #   dockerfile: ./docker/Dockerfile | ||||
|           #   build-args: | | ||||
|           #     "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,native-roots" | ||||
|           #     "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" | ||||
|           #   platforms: linux/amd64,linux/arm64 | ||||
|           #   tags-suffix: "-s2n-native-roots" | ||||
|           #   # Aliases must be used only for release builds | ||||
|           #   aliases: | | ||||
|           #     jqtype/rpxy:s2n-native-roots | ||||
|           #     ghcr.io/junkurihara/rust-rpxy:s2n-native-roots | ||||
| 
 | ||||
|     steps: | ||||
|       - name: Checkout | ||||
|  | @ -135,6 +136,23 @@ jobs: | |||
|       #     platforms: linux/amd64 | ||||
|       #     labels: ${{ steps.meta.outputs.labels }} | ||||
| 
 | ||||
|       - name: Unstable build and push from develop branch | ||||
|         if: ${{ startsWith(github.ref_name, 'feat/') && (github.event_name == 'push') }} | ||||
|         uses: docker/build-push-action@v5 | ||||
|         with: | ||||
|           context: . | ||||
|           build-args: ${{ matrix.build-args }} | ||||
|           push: true | ||||
|           tags: | | ||||
|             ${{ env.GHCR }}/${{ env.GHCR_IMAGE_NAME }}:unstable${{ matrix.tags-suffix }} | ||||
|             ${{ env.DH_REGISTRY_NAME }}:unstable${{ matrix.tags-suffix }} | ||||
|           build-contexts: ${{ matrix.build-contexts }} | ||||
|           file: ${{ matrix.dockerfile }} | ||||
|           cache-from: type=gha,scope=rpxy-unstable-${{ matrix.target }} | ||||
|           cache-to: type=gha,mode=max,scope=rpxy-unstable-${{ matrix.target }} | ||||
|           platforms: linux/amd64 | ||||
|           labels: ${{ steps.meta.outputs.labels }} | ||||
| 
 | ||||
|       - name: Nightly build and push from develop branch | ||||
|         if: ${{ (github.ref_name == 'develop') && (github.event_name == 'push') }} | ||||
|         uses: docker/build-push-action@v5 | ||||
|  |  | |||
							
								
								
									
										6
									
								
								.gitmodules
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.gitmodules
									
										
									
									
										vendored
									
									
								
							|  | @ -1,12 +1,6 @@ | |||
| [submodule "submodules/h3"] | ||||
| 	path = submodules/h3 | ||||
| 	url = git@github.com:junkurihara/h3.git | ||||
| [submodule "submodules/quinn"] | ||||
| 	path = submodules/quinn | ||||
| 	url = git@github.com:junkurihara/quinn.git | ||||
| [submodule "submodules/s2n-quic"] | ||||
| 	path = submodules/s2n-quic | ||||
| 	url = git@github.com:junkurihara/s2n-quic.git | ||||
| [submodule "submodules/rusty-http-cache-semantics"] | ||||
| 	path = submodules/rusty-http-cache-semantics | ||||
| 	url = git@github.com:junkurihara/rusty-http-cache-semantics.git | ||||
|  |  | |||
|  | @ -2,6 +2,13 @@ | |||
| 
 | ||||
| ## 0.7.0  (unreleased) | ||||
| 
 | ||||
| - Breaking: `hyper`-1.0 for both server and client modules. | ||||
| - Breaking: Remove `override_host` option in upstream options. Add a reverse option, i.e., `keep_original_host`. That is, `rpxy` always override the host header by the upstream hostname (backend uri host name) by default. If this reverse option specified, original `host` header is maintained or added from the value of url request line. | ||||
| - Breaking: Introduced `native-tls-backend` feature to use the native TLS engine to access backend applications. | ||||
| - Redesigned: Cache structure is totally redesigned with more memory-efficient way to read from cache file, and more secure way to strongly bind memory-objects with files with hash values. | ||||
| - Redesigned: HTTP body handling flow is also redesigned with more memory-and-time efficient techniques without putting the whole objects on memory by using `futures::stream::Stream` and `futures::channel::mpsc` | ||||
| - Refactor: lots of minor improvements | ||||
| 
 | ||||
| ## 0.6.2 | ||||
| 
 | ||||
| ### Improvement | ||||
|  |  | |||
|  | @ -104,11 +104,11 @@ If you want to host multiple and distinct domain names in a single IP address/po | |||
| ```toml | ||||
| default_application = "app1" | ||||
| 
 | ||||
| [app.app1] | ||||
| [apps.app1] | ||||
| server_name = "app1.example.com" | ||||
| #... | ||||
| 
 | ||||
| [app.app2] | ||||
| [apps.app2] | ||||
| server_name = "app2.example.org" | ||||
| #... | ||||
| ``` | ||||
|  |  | |||
							
								
								
									
										2
									
								
								TODO.md
									
										
									
									
									
								
							
							
						
						
									
										2
									
								
								TODO.md
									
										
									
									
									
								
							|  | @ -1,9 +1,11 @@ | |||
| # TODO List | ||||
| 
 | ||||
| - Support of `rustls-0.22` along with `hyper-1.0`. Maybe `hyper-rustls` is the most difficult part. | ||||
| - [Done in 0.6.0] But we need more sophistication on `Forwarder` struct. ~~Fix strategy for `h2c` requests on forwarded requests upstream. This needs to update forwarder definition. Also, maybe forwarder would have a cache corresponding to the following task.~~ | ||||
| - [Initial implementation in v0.6.0] ~~**Cache option for the response with `Cache-Control: public` header directive ([#55](https://github.com/junkurihara/rust-rpxy/issues/55))**~~ Using `lru` crate might be inefficient in terms of the speed. | ||||
|   - Consider more sophisticated architecture for cache | ||||
|   - Persistent cache (if possible). | ||||
|   - More secure cache file object naming | ||||
|   - etc etc | ||||
| - Improvement of path matcher | ||||
| - More flexible option for rewriting path | ||||
|  |  | |||
|  | @ -57,7 +57,7 @@ upstream = [ | |||
| ] | ||||
| load_balance = "round_robin" # or "random" or "sticky" (sticky session) or "none" (fix to the first one, default) | ||||
| upstream_options = [ | ||||
|   "override_host", | ||||
|   "keep_original_host",   # do not overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy) | ||||
|   "force_http2_upstream", # mutually exclusive with "force_http11_upstream" | ||||
| ] | ||||
| 
 | ||||
|  | @ -76,7 +76,7 @@ upstream = [ | |||
| ] | ||||
| load_balance = "random" # or "round_robin" or "sticky" (sticky session) or "none" (fix to the first one, default) | ||||
| upstream_options = [ | ||||
|   "override_host", | ||||
|   "disable_override_host", | ||||
|   "upgrade_insecure_requests", | ||||
|   "force_http11_upstream", | ||||
| ] | ||||
|  |  | |||
							
								
								
									
										89
									
								
								legacy-lib/Cargo.toml
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										89
									
								
								legacy-lib/Cargo.toml
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,89 @@ | |||
| [package] | ||||
| name = "rpxy-lib-legacy" | ||||
| version = "0.6.2" | ||||
| authors = ["Jun Kurihara"] | ||||
| homepage = "https://github.com/junkurihara/rust-rpxy" | ||||
| repository = "https://github.com/junkurihara/rust-rpxy" | ||||
| license = "MIT" | ||||
| readme = "../README.md" | ||||
| edition = "2021" | ||||
| publish = false | ||||
| 
 | ||||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||||
| 
 | ||||
| [features] | ||||
| default = ["http3-quinn", "sticky-cookie", "cache"] | ||||
| http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] | ||||
| http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] | ||||
| sticky-cookie = ["base64", "sha2", "chrono"] | ||||
| cache = ["http-cache-semantics", "lru"] | ||||
| native-roots = ["hyper-rustls/native-tokio"] | ||||
| 
 | ||||
| [dependencies] | ||||
| rand = "0.8.5" | ||||
| rustc-hash = "1.1.0" | ||||
| bytes = "1.5.0" | ||||
| derive_builder = "0.12.0" | ||||
| futures = { version = "0.3.29", features = ["alloc", "async-await"] } | ||||
| tokio = { version = "1.34.0", default-features = false, features = [ | ||||
|   "net", | ||||
|   "rt-multi-thread", | ||||
|   "time", | ||||
|   "sync", | ||||
|   "macros", | ||||
|   "fs", | ||||
| ] } | ||||
| async-trait = "0.1.74" | ||||
| hot_reload = "0.1.4" # reloading certs | ||||
| 
 | ||||
| # Error handling | ||||
| anyhow = "1.0.75" | ||||
| thiserror = "1.0.50" | ||||
| 
 | ||||
| # http and tls | ||||
| http = "1.0.0" | ||||
| http-body-util = "0.1.0" | ||||
| hyper = { version = "1.0.1", default-features = false } | ||||
| hyper-util = { version = "0.1.1", features = ["full"] } | ||||
| hyper-rustls = { version = "0.24.2", default-features = false, features = [ | ||||
|   "tokio-runtime", | ||||
|   "webpki-tokio", | ||||
|   "http1", | ||||
|   "http2", | ||||
| ] } | ||||
| tokio-rustls = { version = "0.24.1", features = ["early-data"] } | ||||
| rustls = { version = "0.21.9", default-features = false } | ||||
| webpki = "0.22.4" | ||||
| x509-parser = "0.15.1" | ||||
| 
 | ||||
| # logging | ||||
| tracing = { version = "0.1.40" } | ||||
| 
 | ||||
| # http/3 | ||||
| quinn = { version = "0.10.2", optional = true } | ||||
| h3 = { path = "../submodules/h3/h3/", optional = true } | ||||
| h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } | ||||
| s2n-quic = { version = "1.31.0", default-features = false, features = [ | ||||
|   "provider-tls-rustls", | ||||
| ], optional = true } | ||||
| s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } | ||||
| s2n-quic-rustls = { version = "0.31.0", optional = true } | ||||
| # for UDP socket wit SO_REUSEADDR when h3 with quinn | ||||
| socket2 = { version = "0.5.5", features = ["all"], optional = true } | ||||
| 
 | ||||
| # cache | ||||
| http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } | ||||
| lru = { version = "0.12.0", optional = true } | ||||
| 
 | ||||
| # cookie handling for sticky cookie | ||||
| chrono = { version = "0.4.31", default-features = false, features = [ | ||||
|   "unstable-locales", | ||||
|   "alloc", | ||||
|   "clock", | ||||
| ], optional = true } | ||||
| base64 = { version = "0.21.5", optional = true } | ||||
| sha2 = { version = "0.10.8", default-features = false, optional = true } | ||||
| 
 | ||||
| 
 | ||||
| [dev-dependencies] | ||||
| # http and tls | ||||
							
								
								
									
										77
									
								
								legacy-lib/src/backend/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										77
									
								
								legacy-lib/src/backend/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,77 @@ | |||
| mod load_balance; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| mod load_balance_sticky; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| mod sticky_cookie; | ||||
| mod upstream; | ||||
| mod upstream_opts; | ||||
| 
 | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| pub use self::sticky_cookie::{StickyCookie, StickyCookieValue}; | ||||
| pub use self::{ | ||||
|   load_balance::{LbContext, LoadBalance}, | ||||
|   upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, | ||||
|   upstream_opts::UpstreamOption, | ||||
| }; | ||||
| use crate::{ | ||||
|   certs::CryptoSource, | ||||
|   utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, | ||||
| }; | ||||
| use derive_builder::Builder; | ||||
| use rustc_hash::FxHashMap as HashMap; | ||||
| use std::borrow::Cow; | ||||
| 
 | ||||
| /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
 | ||||
| #[derive(Builder)] | ||||
| pub struct Backend<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   #[builder(setter(into))] | ||||
|   /// backend application name, e.g., app1
 | ||||
|   pub app_name: String, | ||||
|   #[builder(setter(custom))] | ||||
|   /// server name, e.g., example.com, in String ascii lower case
 | ||||
|   pub server_name: String, | ||||
|   /// struct of reverse proxy serving incoming request
 | ||||
|   pub reverse_proxy: ReverseProxy, | ||||
| 
 | ||||
|   /// tls settings: https redirection with 30x
 | ||||
|   #[builder(default)] | ||||
|   pub https_redirection: Option<bool>, | ||||
| 
 | ||||
|   /// TLS settings: source meta for server cert, key, client ca cert
 | ||||
|   #[builder(default)] | ||||
|   pub crypto_source: Option<T>, | ||||
| } | ||||
| impl<'a, T> BackendBuilder<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.server_name = Some(server_name.into().to_ascii_lowercase()); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// HashMap and some meta information for multiple Backend structs.
 | ||||
| pub struct Backends<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   pub apps: HashMap<ServerNameBytesExp, Backend<T>>, // hyper::uriで抜いたhostで引っ掛ける
 | ||||
|   pub default_server_name_bytes: Option<ServerNameBytesExp>, // for plaintext http
 | ||||
| } | ||||
| 
 | ||||
| impl<T> Backends<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   #[allow(clippy::new_without_default)] | ||||
|   pub fn new() -> Self { | ||||
|     Backends { | ||||
|       apps: HashMap::<ServerNameBytesExp, Backend<T>>::default(), | ||||
|       default_server_name_bytes: None, | ||||
|     } | ||||
|   } | ||||
| } | ||||
							
								
								
									
										201
									
								
								legacy-lib/src/backend/upstream.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								legacy-lib/src/backend/upstream.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,201 @@ | |||
| #[cfg(feature = "sticky-cookie")] | ||||
| use super::load_balance::LbStickyRoundRobinBuilder; | ||||
| use super::load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance}; | ||||
| use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption}; | ||||
| use crate::log::*; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| use base64::{engine::general_purpose, Engine as _}; | ||||
| use derive_builder::Builder; | ||||
| use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| use sha2::{Digest, Sha256}; | ||||
| use std::borrow::Cow; | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct ReverseProxy { | ||||
|   pub upstream: HashMap<PathNameBytesExp, UpstreamGroup>, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。
 | ||||
| } | ||||
| 
 | ||||
| impl ReverseProxy { | ||||
|   /// Get an appropriate upstream destination for given path string.
 | ||||
|   pub fn get<'a>(&self, path_str: impl Into<Cow<'a, str>>) -> Option<&UpstreamGroup> { | ||||
|     // trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、
 | ||||
|     // コスト的にこの程度で十分
 | ||||
|     let path_bytes = &path_str.to_path_name_vec(); | ||||
| 
 | ||||
|     let matched_upstream = self | ||||
|       .upstream | ||||
|       .iter() | ||||
|       .filter(|(route_bytes, _)| { | ||||
|         match path_bytes.starts_with(route_bytes) { | ||||
|           true => { | ||||
|             route_bytes.len() == 1 // route = '/', i.e., default
 | ||||
|             || match path_bytes.get(route_bytes.len()) { | ||||
|               None => true, // exact case
 | ||||
|               Some(p) => p == &b'/', // sub-path case
 | ||||
|             } | ||||
|           } | ||||
|           _ => false, | ||||
|         } | ||||
|       }) | ||||
|       .max_by_key(|(route_bytes, _)| route_bytes.len()); | ||||
|     if let Some((_path, u)) = matched_upstream { | ||||
|       debug!( | ||||
|         "Found upstream: {:?}", | ||||
|         String::from_utf8(_path.0.clone()).unwrap_or_else(|_| "<none>".to_string()) | ||||
|       ); | ||||
|       Some(u) | ||||
|     } else { | ||||
|       None | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| /// Upstream struct just containing uri without path
 | ||||
| pub struct Upstream { | ||||
|   /// Base uri without specific path
 | ||||
|   pub uri: hyper::Uri, | ||||
| } | ||||
| impl Upstream { | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   /// Hashing uri with index to avoid collision
 | ||||
|   pub fn calculate_id_with_index(&self, index: usize) -> String { | ||||
|     let mut hasher = Sha256::new(); | ||||
|     let uri_string = format!("{}&index={}", self.uri.clone(), index); | ||||
|     hasher.update(uri_string.as_bytes()); | ||||
|     let digest = hasher.finalize(); | ||||
|     general_purpose::URL_SAFE_NO_PAD.encode(digest) | ||||
|   } | ||||
| } | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Struct serving multiple upstream servers for, e.g., load balancing.
 | ||||
| pub struct UpstreamGroup { | ||||
|   #[builder(setter(custom))] | ||||
|   /// Upstream server(s)
 | ||||
|   pub upstream: Vec<Upstream>, | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Path like "/path" in [[PathNameBytesExp]] associated with the upstream server(s)
 | ||||
|   pub path: PathNameBytesExp, | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Path in [[PathNameBytesExp]] that will be used to replace the "path" part of incoming url
 | ||||
|   pub replace_path: Option<PathNameBytesExp>, | ||||
| 
 | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Load balancing option
 | ||||
|   pub lb: LoadBalance, | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Activated upstream options defined in [[UpstreamOption]]
 | ||||
|   pub opts: HashSet<UpstreamOption>, | ||||
| } | ||||
| 
 | ||||
| impl UpstreamGroupBuilder { | ||||
|   pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self { | ||||
|     self.upstream = Some(upstream_vec.to_vec()); | ||||
|     self | ||||
|   } | ||||
|   pub fn path(&mut self, v: &Option<String>) -> &mut Self { | ||||
|     let path = match v { | ||||
|       Some(p) => p.to_path_name_vec(), | ||||
|       None => "/".to_path_name_vec(), | ||||
|     }; | ||||
|     self.path = Some(path); | ||||
|     self | ||||
|   } | ||||
|   pub fn replace_path(&mut self, v: &Option<String>) -> &mut Self { | ||||
|     self.replace_path = Some( | ||||
|       v.to_owned() | ||||
|         .as_ref() | ||||
|         .map_or_else(|| None, |v| Some(v.to_path_name_vec())), | ||||
|     ); | ||||
|     self | ||||
|   } | ||||
|   pub fn lb( | ||||
|     &mut self, | ||||
|     v: &Option<String>, | ||||
|     // upstream_num: &usize,
 | ||||
|     upstream_vec: &Vec<Upstream>, | ||||
|     _server_name: &str, | ||||
|     _path_opt: &Option<String>, | ||||
|   ) -> &mut Self { | ||||
|     let upstream_num = &upstream_vec.len(); | ||||
|     let lb = if let Some(x) = v { | ||||
|       match x.as_str() { | ||||
|         lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, | ||||
|         lb_opts::RANDOM => LoadBalance::Random(LbRandomBuilder::default().num_upstreams(upstream_num).build().unwrap()), | ||||
|         lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin( | ||||
|           LbRoundRobinBuilder::default() | ||||
|             .num_upstreams(upstream_num) | ||||
|             .build() | ||||
|             .unwrap(), | ||||
|         ), | ||||
|         #[cfg(feature = "sticky-cookie")] | ||||
|         lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( | ||||
|           LbStickyRoundRobinBuilder::default() | ||||
|             .num_upstreams(upstream_num) | ||||
|             .sticky_config(_server_name, _path_opt) | ||||
|             .upstream_maps(upstream_vec) // TODO:
 | ||||
|             .build() | ||||
|             .unwrap(), | ||||
|         ), | ||||
|         _ => { | ||||
|           error!("Specified load balancing option is invalid."); | ||||
|           LoadBalance::default() | ||||
|         } | ||||
|       } | ||||
|     } else { | ||||
|       LoadBalance::default() | ||||
|     }; | ||||
|     self.lb = Some(lb); | ||||
|     self | ||||
|   } | ||||
|   pub fn opts(&mut self, v: &Option<Vec<String>>) -> &mut Self { | ||||
|     let opts = if let Some(opts) = v { | ||||
|       opts | ||||
|         .iter() | ||||
|         .filter_map(|str| UpstreamOption::try_from(str.as_str()).ok()) | ||||
|         .collect::<HashSet<UpstreamOption>>() | ||||
|     } else { | ||||
|       Default::default() | ||||
|     }; | ||||
|     self.opts = Some(opts); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl UpstreamGroup { | ||||
|   /// Get an enabled option of load balancing [[LoadBalance]]
 | ||||
|   pub fn get(&self, context_to_lb: &Option<LbContext>) -> (Option<&Upstream>, Option<LbContext>) { | ||||
|     let pointer_to_upstream = self.lb.get_context(context_to_lb); | ||||
|     debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); | ||||
|     debug!("Context to LB (Cookie in Req): {:?}", context_to_lb); | ||||
|     debug!( | ||||
|       "Context from LB (Set-Cookie in Res): {:?}", | ||||
|       pointer_to_upstream.context_lb | ||||
|     ); | ||||
|     ( | ||||
|       self.upstream.get(pointer_to_upstream.ptr), | ||||
|       pointer_to_upstream.context_lb, | ||||
|     ) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod test { | ||||
|   #[allow(unused)] | ||||
|   use super::*; | ||||
| 
 | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   #[test] | ||||
|   fn calc_id_works() { | ||||
|     let uri = "https://www.rust-lang.org".parse::<hyper::Uri>().unwrap(); | ||||
|     let upstream = Upstream { uri }; | ||||
|     assert_eq!( | ||||
|       "eGsjoPbactQ1eUJjafYjPT3ekYZQkaqJnHdA_FMSkgM", | ||||
|       upstream.calculate_id_with_index(0) | ||||
|     ); | ||||
|     assert_eq!( | ||||
|       "tNVXFJ9eNCT2mFgKbYq35XgH5q93QZtfU8piUiiDxVA", | ||||
|       upstream.calculate_id_with_index(1) | ||||
|     ); | ||||
|   } | ||||
| } | ||||
							
								
								
									
										22
									
								
								legacy-lib/src/backend/upstream_opts.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								legacy-lib/src/backend/upstream_opts.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,22 @@ | |||
| use crate::error::*; | ||||
| 
 | ||||
| #[derive(Debug, Clone, Hash, Eq, PartialEq)] | ||||
| pub enum UpstreamOption { | ||||
|   OverrideHost, | ||||
|   UpgradeInsecureRequests, | ||||
|   ForceHttp11Upstream, | ||||
|   ForceHttp2Upstream, | ||||
|   // TODO: Adds more options for heder override
 | ||||
| } | ||||
| impl TryFrom<&str> for UpstreamOption { | ||||
|   type Error = RpxyError; | ||||
|   fn try_from(val: &str) -> Result<Self> { | ||||
|     match val { | ||||
|       "override_host" => Ok(Self::OverrideHost), | ||||
|       "upgrade_insecure_requests" => Ok(Self::UpgradeInsecureRequests), | ||||
|       "force_http11_upstream" => Ok(Self::ForceHttp11Upstream), | ||||
|       "force_http2_upstream" => Ok(Self::ForceHttp2Upstream), | ||||
|       _ => Err(RpxyError::Other(anyhow!("Unsupported header option"))), | ||||
|     } | ||||
|   } | ||||
| } | ||||
							
								
								
									
										45
									
								
								legacy-lib/src/constants.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								legacy-lib/src/constants.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,45 @@ | |||
| pub const RESPONSE_HEADER_SERVER: &str = "rpxy"; | ||||
| // pub const LISTEN_ADDRESSES_V4: &[&str] = &["0.0.0.0"];
 | ||||
| // pub const LISTEN_ADDRESSES_V6: &[&str] = &["[::]"];
 | ||||
| pub const TCP_LISTEN_BACKLOG: u32 = 1024; | ||||
| // pub const HTTP_LISTEN_PORT: u16 = 8080;
 | ||||
| // pub const HTTPS_LISTEN_PORT: u16 = 8443;
 | ||||
| pub const PROXY_TIMEOUT_SEC: u64 = 60; | ||||
| pub const UPSTREAM_TIMEOUT_SEC: u64 = 60; | ||||
| pub const TLS_HANDSHAKE_TIMEOUT_SEC: u64 = 15; // default as with firefox browser
 | ||||
| pub const MAX_CLIENTS: usize = 512; | ||||
| pub const MAX_CONCURRENT_STREAMS: u32 = 64; | ||||
| pub const CERTS_WATCH_DELAY_SECS: u32 = 60; | ||||
| pub const LOAD_CERTS_ONLY_WHEN_UPDATED: bool = true; | ||||
| 
 | ||||
| // #[cfg(feature = "http3")]
 | ||||
| // pub const H3_RESPONSE_BUF_SIZE: usize = 65_536; // 64KB
 | ||||
| // #[cfg(feature = "http3")]
 | ||||
| // pub const H3_REQUEST_BUF_SIZE: usize = 65_536; // 64KB // handled by quinn
 | ||||
| 
 | ||||
| #[allow(non_snake_case)] | ||||
| #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
| pub mod H3 { | ||||
|   pub const ALT_SVC_MAX_AGE: u32 = 3600; | ||||
|   pub const REQUEST_MAX_BODY_SIZE: usize = 268_435_456; // 256MB
 | ||||
|   pub const MAX_CONCURRENT_CONNECTIONS: u32 = 4096; | ||||
|   pub const MAX_CONCURRENT_BIDISTREAM: u32 = 64; | ||||
|   pub const MAX_CONCURRENT_UNISTREAM: u32 = 64; | ||||
|   pub const MAX_IDLE_TIMEOUT: u64 = 10; // secs
 | ||||
| } | ||||
| 
 | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| /// For load-balancing with sticky cookie
 | ||||
| pub const STICKY_COOKIE_NAME: &str = "rpxy_srv_id"; | ||||
| 
 | ||||
| #[cfg(feature = "cache")] | ||||
| // # of entries in cache
 | ||||
| pub const MAX_CACHE_ENTRY: usize = 1_000; | ||||
| #[cfg(feature = "cache")] | ||||
| // max size for each file in bytes
 | ||||
| pub const MAX_CACHE_EACH_SIZE: usize = 65_535; | ||||
| #[cfg(feature = "cache")] | ||||
| // on memory cache if less than or equel to
 | ||||
| pub const MAX_CACHE_EACH_SIZE_ON_MEMORY: usize = 4_096; | ||||
| 
 | ||||
| // TODO: max cache size in total
 | ||||
							
								
								
									
										86
									
								
								legacy-lib/src/error.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								legacy-lib/src/error.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,86 @@ | |||
| pub use anyhow::{anyhow, bail, ensure, Context}; | ||||
| use std::io; | ||||
| use thiserror::Error; | ||||
| 
 | ||||
| pub type Result<T> = std::result::Result<T, RpxyError>; | ||||
| 
 | ||||
| /// Describes things that can go wrong in the Rpxy
 | ||||
| #[derive(Debug, Error)] | ||||
| pub enum RpxyError { | ||||
|   #[error("Proxy build error: {0}")] | ||||
|   ProxyBuild(#[from] crate::proxy::ProxyBuilderError), | ||||
| 
 | ||||
|   #[error("Backend build error: {0}")] | ||||
|   BackendBuild(#[from] crate::backend::BackendBuilderError), | ||||
| 
 | ||||
|   #[error("MessageHandler build error: {0}")] | ||||
|   HandlerBuild(#[from] crate::handler::HttpMessageHandlerBuilderError), | ||||
| 
 | ||||
|   #[error("Config builder error: {0}")] | ||||
|   ConfigBuild(&'static str), | ||||
| 
 | ||||
|   #[error("Http Message Handler Error: {0}")] | ||||
|   Handler(&'static str), | ||||
| 
 | ||||
|   #[error("Cache Error: {0}")] | ||||
|   Cache(&'static str), | ||||
| 
 | ||||
|   #[error("Http Request Message Error: {0}")] | ||||
|   Request(&'static str), | ||||
| 
 | ||||
|   #[error("TCP/UDP Proxy Layer Error: {0}")] | ||||
|   Proxy(String), | ||||
| 
 | ||||
|   #[allow(unused)] | ||||
|   #[error("LoadBalance Layer Error: {0}")] | ||||
|   LoadBalance(String), | ||||
| 
 | ||||
|   #[error("I/O Error: {0}")] | ||||
|   Io(#[from] io::Error), | ||||
| 
 | ||||
|   // #[error("Toml Deserialization Error")]
 | ||||
|   // TomlDe(#[from] toml::de::Error),
 | ||||
|   #[cfg(feature = "http3-quinn")] | ||||
|   #[error("Quic Connection Error [quinn]: {0}")] | ||||
|   QuicConn(#[from] quinn::ConnectionError), | ||||
| 
 | ||||
|   #[cfg(feature = "http3-s2n")] | ||||
|   #[error("Quic Connection Error [s2n-quic]: {0}")] | ||||
|   QUicConn(#[from] s2n_quic::connection::Error), | ||||
| 
 | ||||
|   #[cfg(feature = "http3-quinn")] | ||||
|   #[error("H3 Error [quinn]: {0}")] | ||||
|   H3(#[from] h3::Error), | ||||
| 
 | ||||
|   #[cfg(feature = "http3-s2n")] | ||||
|   #[error("H3 Error [s2n-quic]: {0}")] | ||||
|   H3(#[from] s2n_quic_h3::h3::Error), | ||||
| 
 | ||||
|   #[error("rustls Connection Error: {0}")] | ||||
|   Rustls(#[from] rustls::Error), | ||||
| 
 | ||||
|   #[error("Hyper Error: {0}")] | ||||
|   Hyper(#[from] hyper::Error), | ||||
| 
 | ||||
|   #[error("Hyper Http Error: {0}")] | ||||
|   HyperHttp(#[from] hyper::http::Error), | ||||
| 
 | ||||
|   #[error("Hyper Http HeaderValue Error: {0}")] | ||||
|   HyperHeaderValue(#[from] hyper::header::InvalidHeaderValue), | ||||
| 
 | ||||
|   #[error("Hyper Http HeaderName Error: {0}")] | ||||
|   HyperHeaderName(#[from] hyper::header::InvalidHeaderName), | ||||
| 
 | ||||
|   #[error(transparent)] | ||||
|   Other(#[from] anyhow::Error), | ||||
| } | ||||
| 
 | ||||
| #[allow(dead_code)] | ||||
| #[derive(Debug, Error, Clone)] | ||||
| pub enum ClientCertsError { | ||||
|   #[error("TLS Client Certificate is Required for Given SNI: {0}")] | ||||
|   ClientCertRequired(String), | ||||
| 
 | ||||
|   #[error("Inconsistent TLS Client Certificate for Given SNI: {0}")] | ||||
|   InconsistentClientCert(String), | ||||
| } | ||||
							
								
								
									
										325
									
								
								legacy-lib/src/globals.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										325
									
								
								legacy-lib/src/globals.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,325 @@ | |||
| use crate::{ | ||||
|   backend::{ | ||||
|     Backend, BackendBuilder, Backends, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption, | ||||
|   }, | ||||
|   certs::CryptoSource, | ||||
|   constants::*, | ||||
|   error::RpxyError, | ||||
|   log::*, | ||||
|   utils::{BytesName, PathNameBytesExp}, | ||||
| }; | ||||
| use rustc_hash::FxHashMap as HashMap; | ||||
| use std::net::SocketAddr; | ||||
| use std::sync::{ | ||||
|   atomic::{AtomicUsize, Ordering}, | ||||
|   Arc, | ||||
| }; | ||||
| use tokio::time::Duration; | ||||
| 
 | ||||
| /// Global object containing proxy configurations and shared object like counters.
 | ||||
| /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks.
 | ||||
| pub struct Globals<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   /// Configuration parameters for proxy transport and request handlers
 | ||||
|   pub proxy_config: ProxyConfig, // TODO: proxy configはarcに包んでこいつだけ使いまわせばいいように変えていく。backendsも?
 | ||||
| 
 | ||||
|   /// Backend application objects to which http request handler forward incoming requests
 | ||||
|   pub backends: Backends<T>, | ||||
| 
 | ||||
|   /// Shared context - Counter for serving requests
 | ||||
|   pub request_count: RequestCount, | ||||
| 
 | ||||
|   /// Shared context - Async task runtime handler
 | ||||
|   pub runtime_handle: tokio::runtime::Handle, | ||||
| 
 | ||||
|   /// Shared context - Notify object to stop async tasks
 | ||||
|   pub term_notify: Option<Arc<tokio::sync::Notify>>, | ||||
| } | ||||
| 
 | ||||
| /// Configuration parameters for proxy transport and request handlers
 | ||||
| #[derive(PartialEq, Eq, Clone)] | ||||
| pub struct ProxyConfig { | ||||
|   pub listen_sockets: Vec<SocketAddr>, // when instantiate server
 | ||||
|   pub http_port: Option<u16>,          // when instantiate server
 | ||||
|   pub https_port: Option<u16>,         // when instantiate server
 | ||||
|   pub tcp_listen_backlog: u32,         // when instantiate server
 | ||||
| 
 | ||||
|   pub proxy_timeout: Duration,    // when serving requests at Proxy
 | ||||
|   pub upstream_timeout: Duration, // when serving requests at Handler
 | ||||
| 
 | ||||
|   pub max_clients: usize,          // when serving requests
 | ||||
|   pub max_concurrent_streams: u32, // when instantiate server
 | ||||
|   pub keepalive: bool,             // when instantiate server
 | ||||
| 
 | ||||
|   // experimentals
 | ||||
|   pub sni_consistency: bool, // Handler
 | ||||
| 
 | ||||
|   #[cfg(feature = "cache")] | ||||
|   pub cache_enabled: bool, | ||||
|   #[cfg(feature = "cache")] | ||||
|   pub cache_dir: Option<std::path::PathBuf>, | ||||
|   #[cfg(feature = "cache")] | ||||
|   pub cache_max_entry: usize, | ||||
|   #[cfg(feature = "cache")] | ||||
|   pub cache_max_each_size: usize, | ||||
|   #[cfg(feature = "cache")] | ||||
|   pub cache_max_each_size_on_memory: usize, | ||||
| 
 | ||||
|   // All need to make packet acceptor
 | ||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   pub http3: bool, | ||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   pub h3_alt_svc_max_age: u32, | ||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   pub h3_request_max_body_size: usize, | ||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   pub h3_max_concurrent_bidistream: u32, | ||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   pub h3_max_concurrent_unistream: u32, | ||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   pub h3_max_concurrent_connections: u32, | ||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   pub h3_max_idle_timeout: Option<Duration>, | ||||
| } | ||||
| 
 | ||||
| impl Default for ProxyConfig { | ||||
|   fn default() -> Self { | ||||
|     Self { | ||||
|       listen_sockets: Vec::new(), | ||||
|       http_port: None, | ||||
|       https_port: None, | ||||
|       tcp_listen_backlog: TCP_LISTEN_BACKLOG, | ||||
| 
 | ||||
|       // TODO: Reconsider each timeout values
 | ||||
|       proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC), | ||||
|       upstream_timeout: Duration::from_secs(UPSTREAM_TIMEOUT_SEC), | ||||
| 
 | ||||
|       max_clients: MAX_CLIENTS, | ||||
|       max_concurrent_streams: MAX_CONCURRENT_STREAMS, | ||||
|       keepalive: true, | ||||
| 
 | ||||
|       sni_consistency: true, | ||||
| 
 | ||||
|       #[cfg(feature = "cache")] | ||||
|       cache_enabled: false, | ||||
|       #[cfg(feature = "cache")] | ||||
|       cache_dir: None, | ||||
|       #[cfg(feature = "cache")] | ||||
|       cache_max_entry: MAX_CACHE_ENTRY, | ||||
|       #[cfg(feature = "cache")] | ||||
|       cache_max_each_size: MAX_CACHE_EACH_SIZE, | ||||
|       #[cfg(feature = "cache")] | ||||
|       cache_max_each_size_on_memory: MAX_CACHE_EACH_SIZE_ON_MEMORY, | ||||
| 
 | ||||
|       #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|       http3: false, | ||||
|       #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|       h3_alt_svc_max_age: H3::ALT_SVC_MAX_AGE, | ||||
|       #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|       h3_request_max_body_size: H3::REQUEST_MAX_BODY_SIZE, | ||||
|       #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|       h3_max_concurrent_connections: H3::MAX_CONCURRENT_CONNECTIONS, | ||||
|       #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|       h3_max_concurrent_bidistream: H3::MAX_CONCURRENT_BIDISTREAM, | ||||
|       #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|       h3_max_concurrent_unistream: H3::MAX_CONCURRENT_UNISTREAM, | ||||
|       #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|       h3_max_idle_timeout: Some(Duration::from_secs(H3::MAX_IDLE_TIMEOUT)), | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// Configuration parameters for backend applications
 | ||||
| #[derive(PartialEq, Eq, Clone)] | ||||
| pub struct AppConfigList<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   pub inner: Vec<AppConfig<T>>, | ||||
|   pub default_app: Option<String>, | ||||
| } | ||||
| impl<T> TryInto<Backends<T>> for AppConfigList<T> | ||||
| where | ||||
|   T: CryptoSource + Clone, | ||||
| { | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn try_into(self) -> Result<Backends<T>, Self::Error> { | ||||
|     let mut backends = Backends::new(); | ||||
|     for app_config in self.inner.iter() { | ||||
|       let backend = app_config.try_into()?; | ||||
|       backends | ||||
|         .apps | ||||
|         .insert(app_config.server_name.clone().to_server_name_vec(), backend); | ||||
|       info!( | ||||
|         "Registering application {} ({})", | ||||
|         &app_config.server_name, &app_config.app_name | ||||
|       ); | ||||
|     } | ||||
| 
 | ||||
|     // default backend application for plaintext http requests
 | ||||
|     if let Some(d) = self.default_app { | ||||
|       let d_sn: Vec<&str> = backends | ||||
|         .apps | ||||
|         .iter() | ||||
|         .filter(|(_k, v)| v.app_name == d) | ||||
|         .map(|(_, v)| v.server_name.as_ref()) | ||||
|         .collect(); | ||||
|       if !d_sn.is_empty() { | ||||
|         info!( | ||||
|           "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", | ||||
|           d, d_sn[0] | ||||
|         ); | ||||
|         backends.default_server_name_bytes = Some(d_sn[0].to_server_name_vec()); | ||||
|       } | ||||
|     } | ||||
|     Ok(backends) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// Configuration parameters for single backend application
 | ||||
| #[derive(PartialEq, Eq, Clone)] | ||||
| pub struct AppConfig<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   pub app_name: String, | ||||
|   pub server_name: String, | ||||
|   pub reverse_proxy: Vec<ReverseProxyConfig>, | ||||
|   pub tls: Option<TlsConfig<T>>, | ||||
| } | ||||
| impl<T> TryInto<Backend<T>> for &AppConfig<T> | ||||
| where | ||||
|   T: CryptoSource + Clone, | ||||
| { | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn try_into(self) -> Result<Backend<T>, Self::Error> { | ||||
|     // backend builder
 | ||||
|     let mut backend_builder = BackendBuilder::default(); | ||||
|     // reverse proxy settings
 | ||||
|     let reverse_proxy = self.try_into()?; | ||||
| 
 | ||||
|     backend_builder | ||||
|       .app_name(self.app_name.clone()) | ||||
|       .server_name(self.server_name.clone()) | ||||
|       .reverse_proxy(reverse_proxy); | ||||
| 
 | ||||
|     // TLS settings and build backend instance
 | ||||
|     let backend = if self.tls.is_none() { | ||||
|       backend_builder.build().map_err(RpxyError::BackendBuild)? | ||||
|     } else { | ||||
|       let tls = self.tls.as_ref().unwrap(); | ||||
| 
 | ||||
|       backend_builder | ||||
|         .https_redirection(Some(tls.https_redirection)) | ||||
|         .crypto_source(Some(tls.inner.clone())) | ||||
|         .build()? | ||||
|     }; | ||||
|     Ok(backend) | ||||
|   } | ||||
| } | ||||
| impl<T> TryInto<ReverseProxy> for &AppConfig<T> | ||||
| where | ||||
|   T: CryptoSource + Clone, | ||||
| { | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn try_into(self) -> Result<ReverseProxy, Self::Error> { | ||||
|     let mut upstream: HashMap<PathNameBytesExp, UpstreamGroup> = HashMap::default(); | ||||
| 
 | ||||
|     self.reverse_proxy.iter().for_each(|rpo| { | ||||
|       let upstream_vec: Vec<Upstream> = rpo.upstream.iter().map(|x| x.try_into().unwrap()).collect(); | ||||
|       // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap());
 | ||||
|       // let lb_upstream_num = vec_upstream.len();
 | ||||
|       let elem = UpstreamGroupBuilder::default() | ||||
|         .upstream(&upstream_vec) | ||||
|         .path(&rpo.path) | ||||
|         .replace_path(&rpo.replace_path) | ||||
|         .lb(&rpo.load_balance, &upstream_vec, &self.server_name, &rpo.path) | ||||
|         .opts(&rpo.upstream_options) | ||||
|         .build() | ||||
|         .unwrap(); | ||||
| 
 | ||||
|       upstream.insert(elem.path.clone(), elem); | ||||
|     }); | ||||
|     if self.reverse_proxy.iter().filter(|rpo| rpo.path.is_none()).count() >= 2 { | ||||
|       error!("Multiple default reverse proxy setting"); | ||||
|       return Err(RpxyError::ConfigBuild("Invalid reverse proxy setting")); | ||||
|     } | ||||
| 
 | ||||
|     if !(upstream.iter().all(|(_, elem)| { | ||||
|       !(elem.opts.contains(&UpstreamOption::ForceHttp11Upstream) | ||||
|         && elem.opts.contains(&UpstreamOption::ForceHttp2Upstream)) | ||||
|     })) { | ||||
|       error!("Either one of force_http11 or force_http2 can be enabled"); | ||||
|       return Err(RpxyError::ConfigBuild("Invalid upstream option setting")); | ||||
|     } | ||||
| 
 | ||||
|     Ok(ReverseProxy { upstream }) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// Configuration parameters for single reverse proxy corresponding to the path
 | ||||
| #[derive(PartialEq, Eq, Clone)] | ||||
| pub struct ReverseProxyConfig { | ||||
|   pub path: Option<String>, | ||||
|   pub replace_path: Option<String>, | ||||
|   pub upstream: Vec<UpstreamUri>, | ||||
|   pub upstream_options: Option<Vec<String>>, | ||||
|   pub load_balance: Option<String>, | ||||
| } | ||||
| 
 | ||||
| /// Configuration parameters for single upstream destination from a reverse proxy
 | ||||
| #[derive(PartialEq, Eq, Clone)] | ||||
| pub struct UpstreamUri { | ||||
|   pub inner: hyper::Uri, | ||||
| } | ||||
| impl TryInto<Upstream> for &UpstreamUri { | ||||
|   type Error = anyhow::Error; | ||||
| 
 | ||||
|   fn try_into(self) -> std::result::Result<Upstream, Self::Error> { | ||||
|     Ok(Upstream { | ||||
|       uri: self.inner.clone(), | ||||
|     }) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// Configuration parameters on TLS for a single backend application
 | ||||
| #[derive(PartialEq, Eq, Clone)] | ||||
| pub struct TlsConfig<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   pub inner: T, | ||||
|   pub https_redirection: bool, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Default)] | ||||
| /// Counter for serving requests
 | ||||
| pub struct RequestCount(Arc<AtomicUsize>); | ||||
| 
 | ||||
| impl RequestCount { | ||||
|   pub fn current(&self) -> usize { | ||||
|     self.0.load(Ordering::Relaxed) | ||||
|   } | ||||
| 
 | ||||
|   pub fn increment(&self) -> usize { | ||||
|     self.0.fetch_add(1, Ordering::Relaxed) | ||||
|   } | ||||
| 
 | ||||
|   pub fn decrement(&self) -> usize { | ||||
|     let mut count; | ||||
|     while { | ||||
|       count = self.0.load(Ordering::Relaxed); | ||||
|       count > 0 | ||||
|         && self | ||||
|           .0 | ||||
|           .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) | ||||
|           != Ok(count) | ||||
|     } {} | ||||
|     count | ||||
|   } | ||||
| } | ||||
							
								
								
									
										16
									
								
								legacy-lib/src/handler/error.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								legacy-lib/src/handler/error.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,16 @@ | |||
| use http::StatusCode; | ||||
| use thiserror::Error; | ||||
| 
 | ||||
| pub type HttpResult<T> = std::result::Result<T, HttpError>; | ||||
| 
 | ||||
| /// Describes things that can go wrong in the handler
 | ||||
| #[derive(Debug, Error)] | ||||
| pub enum HttpError {} | ||||
| 
 | ||||
| impl From<HttpError> for StatusCode { | ||||
|   fn from(e: HttpError) -> StatusCode { | ||||
|     match e { | ||||
|       _ => StatusCode::INTERNAL_SERVER_ERROR, | ||||
|     } | ||||
|   } | ||||
| } | ||||
							
								
								
									
										384
									
								
								legacy-lib/src/handler/handler_main.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										384
									
								
								legacy-lib/src/handler/handler_main.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,384 @@ | |||
| // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
 | ||||
| use super::{ | ||||
|   error::*, | ||||
|   // forwarder::{ForwardRequest, Forwarder},
 | ||||
|   utils_headers::*, | ||||
|   utils_request::*, | ||||
|   // utils_synth_response::*,
 | ||||
|   HandlerContext, | ||||
| }; | ||||
| use crate::{ | ||||
|   backend::{Backend, UpstreamGroup}, | ||||
|   certs::CryptoSource, | ||||
|   constants::RESPONSE_HEADER_SERVER, | ||||
|   error::*, | ||||
|   globals::Globals, | ||||
|   log::*, | ||||
|   utils::ServerNameBytesExp, | ||||
| }; | ||||
| use derive_builder::Builder; | ||||
| use http::{ | ||||
|   header::{self, HeaderValue}, | ||||
|   uri::Scheme, | ||||
|   Request, Response, StatusCode, Uri, Version, | ||||
| }; | ||||
| use hyper::body::Incoming; | ||||
| use hyper_util::client::legacy::connect::Connect; | ||||
| use std::{net::SocketAddr, sync::Arc}; | ||||
| use tokio::{io::copy_bidirectional, time::timeout}; | ||||
| 
 | ||||
| #[derive(Clone, Builder)] | ||||
| /// HTTP message handler for requests from clients and responses from backend applications,
 | ||||
| /// responsible to manipulate and forward messages to upstream backends and downstream clients.
 | ||||
| // pub struct HttpMessageHandler<T, U>
 | ||||
| pub struct HttpMessageHandler<U> | ||||
| where | ||||
|   // T: Connect + Clone + Sync + Send + 'static,
 | ||||
|   U: CryptoSource + Clone, | ||||
| { | ||||
|   // forwarder: Arc<Forwarder<T>>,
 | ||||
|   globals: Arc<Globals<U>>, | ||||
| } | ||||
| 
 | ||||
| impl<U> HttpMessageHandler<U> | ||||
| where | ||||
|   // T: Connect + Clone + Sync + Send + 'static,
 | ||||
|   U: CryptoSource + Clone, | ||||
| { | ||||
|   // /// Return with an arbitrary status code of error and log message
 | ||||
|   // fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result<Response<Body>> {
 | ||||
|   //   log_data.status_code(&status_code).output();
 | ||||
|   //   http_error(status_code)
 | ||||
|   // }
 | ||||
| 
 | ||||
|   /// Handle incoming request message from a client
 | ||||
|   pub async fn handle_request( | ||||
|     &self, | ||||
|     mut req: Request<Incoming>, | ||||
|     client_addr: SocketAddr, // アクセス制御用
 | ||||
|     listen_addr: SocketAddr, | ||||
|     tls_enabled: bool, | ||||
|     tls_server_name: Option<ServerNameBytesExp>, | ||||
|   ) -> Result<HttpResult<Response<Incoming>>> { | ||||
|     ////////
 | ||||
|     let mut log_data = MessageLog::from(&req); | ||||
|     log_data.client_addr(&client_addr); | ||||
|     //////
 | ||||
| 
 | ||||
|     // // Here we start to handle with server_name
 | ||||
|     // let server_name = if let Ok(v) = req.parse_host() {
 | ||||
|     //   ServerNameBytesExp::from(v)
 | ||||
|     // } else {
 | ||||
|     //   return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data);
 | ||||
|     // };
 | ||||
|     // // check consistency of between TLS SNI and HOST/Request URI Line.
 | ||||
|     // #[allow(clippy::collapsible_if)]
 | ||||
|     // if tls_enabled && self.globals.proxy_config.sni_consistency {
 | ||||
|     //   if server_name != tls_server_name.unwrap_or_default() {
 | ||||
|     //     return self.return_with_error_log(StatusCode::MISDIRECTED_REQUEST, &mut log_data);
 | ||||
|     //   }
 | ||||
|     // }
 | ||||
|     // // Find backend application for given server_name, and drop if incoming request is invalid as request.
 | ||||
|     // let backend = match self.globals.backends.apps.get(&server_name) {
 | ||||
|     //   Some(be) => be,
 | ||||
|     //   None => {
 | ||||
|     //     let Some(default_server_name) = &self.globals.backends.default_server_name_bytes else {
 | ||||
|     //       return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
 | ||||
|     //     };
 | ||||
|     //     debug!("Serving by default app");
 | ||||
|     //     self.globals.backends.apps.get(default_server_name).unwrap()
 | ||||
|     //   }
 | ||||
|     // };
 | ||||
| 
 | ||||
|     // // Redirect to https if !tls_enabled and redirect_to_https is true
 | ||||
|     // if !tls_enabled && backend.https_redirection.unwrap_or(false) {
 | ||||
|     //   debug!("Redirect to secure connection: {}", &backend.server_name);
 | ||||
|     //   log_data.status_code(&StatusCode::PERMANENT_REDIRECT).output();
 | ||||
|     //   return secure_redirection(&backend.server_name, self.globals.proxy_config.https_port, &req);
 | ||||
|     // }
 | ||||
| 
 | ||||
|     // // Find reverse proxy for given path and choose one of upstream host
 | ||||
|     // // Longest prefix match
 | ||||
|     // let path = req.uri().path();
 | ||||
|     // let Some(upstream_group) = backend.reverse_proxy.get(path) else {
 | ||||
|     //   return self.return_with_error_log(StatusCode::NOT_FOUND, &mut log_data);
 | ||||
|     // };
 | ||||
| 
 | ||||
|     // // Upgrade in request header
 | ||||
|     // let upgrade_in_request = extract_upgrade(req.headers());
 | ||||
|     // let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();
 | ||||
| 
 | ||||
|     // // Build request from destination information
 | ||||
|     // let _context = match self.generate_request_forwarded(
 | ||||
|     //   &client_addr,
 | ||||
|     //   &listen_addr,
 | ||||
|     //   &mut req,
 | ||||
|     //   &upgrade_in_request,
 | ||||
|     //   upstream_group,
 | ||||
|     //   tls_enabled,
 | ||||
|     // ) {
 | ||||
|     //   Err(e) => {
 | ||||
|     //     error!("Failed to generate destination uri for reverse proxy: {}", e);
 | ||||
|     //     return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
 | ||||
|     //   }
 | ||||
|     //   Ok(v) => v,
 | ||||
|     // };
 | ||||
|     // debug!("Request to be forwarded: {:?}", req);
 | ||||
|     // log_data.xff(&req.headers().get("x-forwarded-for"));
 | ||||
|     // log_data.upstream(req.uri());
 | ||||
|     // //////
 | ||||
| 
 | ||||
|     // // Forward request to a chosen backend
 | ||||
|     // let mut res_backend = {
 | ||||
|     //   let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else {
 | ||||
|     //     return self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data);
 | ||||
|     //   };
 | ||||
|     //   match result {
 | ||||
|     //     Ok(res) => res,
 | ||||
|     //     Err(e) => {
 | ||||
|     //       error!("Failed to get response from backend: {}", e);
 | ||||
|     //       return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
 | ||||
|     //     }
 | ||||
|     //   }
 | ||||
|     // };
 | ||||
| 
 | ||||
|     // // Process reverse proxy context generated during the forwarding request generation.
 | ||||
|     // #[cfg(feature = "sticky-cookie")]
 | ||||
|     // if let Some(context_from_lb) = _context.context_lb {
 | ||||
|     //   let res_headers = res_backend.headers_mut();
 | ||||
|     //   if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) {
 | ||||
|     //     error!("Failed to append context to the response given from backend: {}", e);
 | ||||
|     //     return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data);
 | ||||
|     //   }
 | ||||
|     // }
 | ||||
| 
 | ||||
|     // if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS {
 | ||||
|     //   // Generate response to client
 | ||||
|     //   if self.generate_response_forwarded(&mut res_backend, backend).is_err() {
 | ||||
|     //     return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data);
 | ||||
|     //   }
 | ||||
|     //   log_data.status_code(&res_backend.status()).output();
 | ||||
|     //   return Ok(res_backend);
 | ||||
|     // }
 | ||||
| 
 | ||||
|     // // Handle StatusCode::SWITCHING_PROTOCOLS in response
 | ||||
|     // let upgrade_in_response = extract_upgrade(res_backend.headers());
 | ||||
|     // let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref())
 | ||||
|     // {
 | ||||
|     //   u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase()
 | ||||
|     // } else {
 | ||||
|     //   false
 | ||||
|     // };
 | ||||
|     // if !should_upgrade {
 | ||||
|     //   error!(
 | ||||
|     //     "Backend tried to switch to protocol {:?} when {:?} was requested",
 | ||||
|     //     upgrade_in_response, upgrade_in_request
 | ||||
|     //   );
 | ||||
|     //   return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data);
 | ||||
|     // }
 | ||||
|     // let Some(request_upgraded) = request_upgraded else {
 | ||||
|     //   error!("Request does not have an upgrade extension");
 | ||||
|     //   return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data);
 | ||||
|     // };
 | ||||
|     // let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else {
 | ||||
|     //   error!("Response does not have an upgrade extension");
 | ||||
|     //   return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data);
 | ||||
|     // };
 | ||||
| 
 | ||||
|     // self.globals.runtime_handle.spawn(async move {
 | ||||
|     //   let mut response_upgraded = onupgrade.await.map_err(|e| {
 | ||||
|     //     error!("Failed to upgrade response: {}", e);
 | ||||
|     //     RpxyError::Hyper(e)
 | ||||
|     //   })?;
 | ||||
|     //   let mut request_upgraded = request_upgraded.await.map_err(|e| {
 | ||||
|     //     error!("Failed to upgrade request: {}", e);
 | ||||
|     //     RpxyError::Hyper(e)
 | ||||
|     //   })?;
 | ||||
|     //   copy_bidirectional(&mut response_upgraded, &mut request_upgraded)
 | ||||
|     //     .await
 | ||||
|     //     .map_err(|e| {
 | ||||
|     //       error!("Coping between upgraded connections failed: {}", e);
 | ||||
|     //       RpxyError::Io(e)
 | ||||
|     //     })?;
 | ||||
|     //   Ok(()) as Result<()>
 | ||||
|     // });
 | ||||
|     // log_data.status_code(&res_backend.status()).output();
 | ||||
|     // Ok(res_backend)
 | ||||
|     todo!() | ||||
|   } | ||||
| 
 | ||||
|   ////////////////////////////////////////////////////
 | ||||
|   // Functions to generate messages
 | ||||
|   ////////////////////////////////////////////////////
 | ||||
| 
 | ||||
|   // /// Manipulate a response message sent from a backend application to forward downstream to a client.
 | ||||
|   // fn generate_response_forwarded<B>(&self, response: &mut Response<B>, chosen_backend: &Backend<U>) -> Result<()>
 | ||||
|   // where
 | ||||
|   //   B: core::fmt::Debug,
 | ||||
|   // {
 | ||||
|   //   let headers = response.headers_mut();
 | ||||
|   //   remove_connection_header(headers);
 | ||||
|   //   remove_hop_header(headers);
 | ||||
|   //   add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?;
 | ||||
| 
 | ||||
|   //   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))]
 | ||||
|   //   {
 | ||||
|   //     // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled
 | ||||
|   //     // TODO: This is a workaround for avoiding a client authentication in HTTP/3
 | ||||
|   //     if self.globals.proxy_config.http3
 | ||||
|   //       && chosen_backend
 | ||||
|   //         .crypto_source
 | ||||
|   //         .as_ref()
 | ||||
|   //         .is_some_and(|v| !v.is_mutual_tls())
 | ||||
|   //     {
 | ||||
|   //       if let Some(port) = self.globals.proxy_config.https_port {
 | ||||
|   //         add_header_entry_overwrite_if_exist(
 | ||||
|   //           headers,
 | ||||
|   //           header::ALT_SVC.as_str(),
 | ||||
|   //           format!(
 | ||||
|   //             "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}",
 | ||||
|   //             port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age
 | ||||
|   //           ),
 | ||||
|   //         )?;
 | ||||
|   //       }
 | ||||
|   //     } else {
 | ||||
|   //       // remove alt-svc to disallow requests via http3
 | ||||
|   //       headers.remove(header::ALT_SVC.as_str());
 | ||||
|   //     }
 | ||||
|   //   }
 | ||||
|   //   #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))]
 | ||||
|   //   {
 | ||||
|   //     if let Some(port) = self.globals.proxy_config.https_port {
 | ||||
|   //       headers.remove(header::ALT_SVC.as_str());
 | ||||
|   //     }
 | ||||
|   //   }
 | ||||
| 
 | ||||
|   //   Ok(())
 | ||||
|   // }
 | ||||
| 
 | ||||
|   // #[allow(clippy::too_many_arguments)]
 | ||||
|   // /// Manipulate a request message sent from a client to forward upstream to a backend application
 | ||||
|   // fn generate_request_forwarded<B>(
 | ||||
|   //   &self,
 | ||||
|   //   client_addr: &SocketAddr,
 | ||||
|   //   listen_addr: &SocketAddr,
 | ||||
|   //   req: &mut Request<B>,
 | ||||
|   //   upgrade: &Option<String>,
 | ||||
|   //   upstream_group: &UpstreamGroup,
 | ||||
|   //   tls_enabled: bool,
 | ||||
|   // ) -> Result<HandlerContext> {
 | ||||
|   //   debug!("Generate request to be forwarded");
 | ||||
| 
 | ||||
|   //   // Add te: trailer if contained in original request
 | ||||
|   //   let contains_te_trailers = {
 | ||||
|   //     if let Some(te) = req.headers().get(header::TE) {
 | ||||
|   //       te.as_bytes()
 | ||||
|   //         .split(|v| v == &b',' || v == &b' ')
 | ||||
|   //         .any(|x| x == "trailers".as_bytes())
 | ||||
|   //     } else {
 | ||||
|   //       false
 | ||||
|   //     }
 | ||||
|   //   };
 | ||||
| 
 | ||||
|   //   let uri = req.uri().to_string();
 | ||||
|   //   let headers = req.headers_mut();
 | ||||
|   //   // delete headers specified in header.connection
 | ||||
|   //   remove_connection_header(headers);
 | ||||
|   //   // delete hop headers including header.connection
 | ||||
|   //   remove_hop_header(headers);
 | ||||
|   //   // X-Forwarded-For
 | ||||
|   //   add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?;
 | ||||
| 
 | ||||
|   //   // Add te: trailer if te_trailer
 | ||||
|   //   if contains_te_trailers {
 | ||||
|   //     headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap());
 | ||||
|   //   }
 | ||||
| 
 | ||||
|   //   // add "host" header of original server_name if not exist (default)
 | ||||
|   //   if req.headers().get(header::HOST).is_none() {
 | ||||
|   //     let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned();
 | ||||
|   //     req
 | ||||
|   //       .headers_mut()
 | ||||
|   //       .insert(header::HOST, HeaderValue::from_str(&org_host)?);
 | ||||
|   //   };
 | ||||
| 
 | ||||
|   //   /////////////////////////////////////////////
 | ||||
|   //   // Fix unique upstream destination since there could be multiple ones.
 | ||||
|   //   #[cfg(feature = "sticky-cookie")]
 | ||||
|   //   let (upstream_chosen_opt, context_from_lb) = {
 | ||||
|   //     let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb {
 | ||||
|   //       takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)?
 | ||||
|   //     } else {
 | ||||
|   //       None
 | ||||
|   //     };
 | ||||
|   //     upstream_group.get(&context_to_lb)
 | ||||
|   //   };
 | ||||
|   //   #[cfg(not(feature = "sticky-cookie"))]
 | ||||
|   //   let (upstream_chosen_opt, _) = upstream_group.get(&None);
 | ||||
| 
 | ||||
|   //   let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?;
 | ||||
|   //   let context = HandlerContext {
 | ||||
|   //     #[cfg(feature = "sticky-cookie")]
 | ||||
|   //     context_lb: context_from_lb,
 | ||||
|   //     #[cfg(not(feature = "sticky-cookie"))]
 | ||||
|   //     context_lb: None,
 | ||||
|   //   };
 | ||||
|   //   /////////////////////////////////////////////
 | ||||
| 
 | ||||
|   //   // apply upstream-specific headers given in upstream_option
 | ||||
|   //   let headers = req.headers_mut();
 | ||||
|   //   apply_upstream_options_to_header(headers, client_addr, upstream_group, &upstream_chosen.uri)?;
 | ||||
| 
 | ||||
|   //   // update uri in request
 | ||||
|   //   if !(upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some()) {
 | ||||
|   //     return Err(RpxyError::Handler("Upstream uri `scheme` and `authority` is broken"));
 | ||||
|   //   };
 | ||||
|   //   let new_uri = Uri::builder()
 | ||||
|   //     .scheme(upstream_chosen.uri.scheme().unwrap().as_str())
 | ||||
|   //     .authority(upstream_chosen.uri.authority().unwrap().as_str());
 | ||||
|   //   let org_pq = match req.uri().path_and_query() {
 | ||||
|   //     Some(pq) => pq.to_string(),
 | ||||
|   //     None => "/".to_string(),
 | ||||
|   //   }
 | ||||
|   //   .into_bytes();
 | ||||
| 
 | ||||
|   //   // replace some parts of path if opt_replace_path is enabled for chosen upstream
 | ||||
|   //   let new_pq = match &upstream_group.replace_path {
 | ||||
|   //     Some(new_path) => {
 | ||||
|   //       let matched_path: &[u8] = upstream_group.path.as_ref();
 | ||||
|   //       if matched_path.is_empty() || org_pq.len() < matched_path.len() {
 | ||||
|   //         return Err(RpxyError::Handler("Upstream uri `path and query` is broken"));
 | ||||
|   //       };
 | ||||
|   //       let mut new_pq = Vec::<u8>::with_capacity(org_pq.len() - matched_path.len() + new_path.len());
 | ||||
|   //       new_pq.extend_from_slice(new_path.as_ref());
 | ||||
|   //       new_pq.extend_from_slice(&org_pq[matched_path.len()..]);
 | ||||
|   //       new_pq
 | ||||
|   //     }
 | ||||
|   //     None => org_pq,
 | ||||
|   //   };
 | ||||
|   //   *req.uri_mut() = new_uri.path_and_query(new_pq).build()?;
 | ||||
| 
 | ||||
|   //   // upgrade
 | ||||
|   //   if let Some(v) = upgrade {
 | ||||
|   //     req.headers_mut().insert(header::UPGRADE, v.parse()?);
 | ||||
|   //     req
 | ||||
|   //       .headers_mut()
 | ||||
|   //       .insert(header::CONNECTION, HeaderValue::from_str("upgrade")?);
 | ||||
|   //   }
 | ||||
| 
 | ||||
|   //   // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3
 | ||||
|   //   if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) {
 | ||||
|   //     // Change version to http/1.1 when destination scheme is http
 | ||||
|   //     debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled.");
 | ||||
|   //     *req.version_mut() = Version::HTTP_11;
 | ||||
|   //   } else if req.version() == Version::HTTP_3 {
 | ||||
|   //     // HTTP/3 is always https
 | ||||
|   //     debug!("HTTP/3 is currently unsupported for request to upstream.");
 | ||||
|   //     *req.version_mut() = Version::HTTP_2;
 | ||||
|   //   }
 | ||||
| 
 | ||||
|   //   apply_upstream_options_to_request_line(req, upstream_group)?;
 | ||||
| 
 | ||||
|   //   Ok(context)
 | ||||
|   // }
 | ||||
| } | ||||
|  | @ -1,17 +1,15 @@ | |||
| #[cfg(feature = "cache")] | ||||
| mod cache; | ||||
| mod forwarder; | ||||
| // mod cache;
 | ||||
| mod error; | ||||
| // mod forwarder;
 | ||||
| mod handler_main; | ||||
| mod utils_headers; | ||||
| mod utils_request; | ||||
| mod utils_synth_response; | ||||
| // mod utils_synth_response;
 | ||||
| 
 | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| use crate::backend::LbContext; | ||||
| pub use { | ||||
|   forwarder::Forwarder, | ||||
|   handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}, | ||||
| }; | ||||
| pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}; | ||||
| 
 | ||||
| #[allow(dead_code)] | ||||
| #[derive(Debug)] | ||||
							
								
								
									
										45
									
								
								legacy-lib/src/hyper_executor.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								legacy-lib/src/hyper_executor.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,45 @@ | |||
| use std::sync::Arc; | ||||
| 
 | ||||
| use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; | ||||
| use tokio::runtime::Handle; | ||||
| 
 | ||||
| use crate::{globals::Globals, CryptoSource}; | ||||
| 
 | ||||
| #[derive(Clone)] | ||||
| /// Executor for hyper
 | ||||
| pub struct LocalExecutor { | ||||
|   runtime_handle: Handle, | ||||
| } | ||||
| 
 | ||||
| impl LocalExecutor { | ||||
|   pub fn new(runtime_handle: Handle) -> Self { | ||||
|     LocalExecutor { runtime_handle } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl<F> hyper::rt::Executor<F> for LocalExecutor | ||||
| where | ||||
|   F: std::future::Future + Send + 'static, | ||||
|   F::Output: Send, | ||||
| { | ||||
|   fn execute(&self, fut: F) { | ||||
|     self.runtime_handle.spawn(fut); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// build connection builder shared with proxy instances
 | ||||
| pub(crate) fn build_http_server<T>(globals: &Arc<Globals<T>>) -> ConnectionBuilder<LocalExecutor> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   let executor = LocalExecutor::new(globals.runtime_handle.clone()); | ||||
|   let mut http_server = server::conn::auto::Builder::new(executor); | ||||
|   http_server | ||||
|     .http1() | ||||
|     .keep_alive(globals.proxy_config.keepalive) | ||||
|     .pipeline_flush(true); | ||||
|   http_server | ||||
|     .http2() | ||||
|     .max_concurrent_streams(globals.proxy_config.max_concurrent_streams); | ||||
|   http_server | ||||
| } | ||||
							
								
								
									
										112
									
								
								legacy-lib/src/lib.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										112
									
								
								legacy-lib/src/lib.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,112 @@ | |||
| mod backend; | ||||
| mod certs; | ||||
| mod constants; | ||||
| mod error; | ||||
| mod globals; | ||||
| mod handler; | ||||
| mod hyper_executor; | ||||
| mod log; | ||||
| mod proxy; | ||||
| mod utils; | ||||
| 
 | ||||
| use crate::{error::*, globals::Globals, handler::HttpMessageHandlerBuilder, log::*, proxy::ProxyBuilder}; | ||||
| use futures::future::select_all; | ||||
| use hyper_executor::build_http_server; | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| pub use crate::{ | ||||
|   certs::{CertsAndKeys, CryptoSource}, | ||||
|   globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri}, | ||||
| }; | ||||
| pub mod reexports { | ||||
|   pub use hyper::Uri; | ||||
|   pub use rustls::{Certificate, PrivateKey}; | ||||
| } | ||||
| 
 | ||||
| #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
| compile_error!("feature \"http3-quinn\" and feature \"http3-s2n\" cannot be enabled at the same time"); | ||||
| 
 | ||||
| /// Entrypoint that creates and spawns tasks of reverse proxy services
 | ||||
| pub async fn entrypoint<T>( | ||||
|   proxy_config: &ProxyConfig, | ||||
|   app_config_list: &AppConfigList<T>, | ||||
|   runtime_handle: &tokio::runtime::Handle, | ||||
|   term_notify: Option<Arc<tokio::sync::Notify>>, | ||||
| ) -> Result<()> | ||||
| where | ||||
|   T: CryptoSource + Clone + Send + Sync + 'static, | ||||
| { | ||||
|   // For initial message logging
 | ||||
|   if proxy_config.listen_sockets.iter().any(|addr| addr.is_ipv6()) { | ||||
|     info!("Listen both IPv4 and IPv6") | ||||
|   } else { | ||||
|     info!("Listen IPv4") | ||||
|   } | ||||
|   if proxy_config.http_port.is_some() { | ||||
|     info!("Listen port: {}", proxy_config.http_port.unwrap()); | ||||
|   } | ||||
|   if proxy_config.https_port.is_some() { | ||||
|     info!("Listen port: {} (for TLS)", proxy_config.https_port.unwrap()); | ||||
|   } | ||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   if proxy_config.http3 { | ||||
|     info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable."); | ||||
|   } | ||||
|   if !proxy_config.sni_consistency { | ||||
|     info!("Ignore consistency between TLS SNI and Host header (or Request line). Note it violates RFC."); | ||||
|   } | ||||
|   #[cfg(feature = "cache")] | ||||
|   if proxy_config.cache_enabled { | ||||
|     info!( | ||||
|       "Cache is enabled: cache dir = {:?}", | ||||
|       proxy_config.cache_dir.as_ref().unwrap() | ||||
|     ); | ||||
|   } else { | ||||
|     info!("Cache is disabled") | ||||
|   } | ||||
| 
 | ||||
|   // build global
 | ||||
|   let globals = Arc::new(Globals { | ||||
|     proxy_config: proxy_config.clone(), | ||||
|     backends: app_config_list.clone().try_into()?, | ||||
|     request_count: Default::default(), | ||||
|     runtime_handle: runtime_handle.clone(), | ||||
|     term_notify: term_notify.clone(), | ||||
|   }); | ||||
| 
 | ||||
|   // build message handler including a request forwarder
 | ||||
|   let msg_handler = Arc::new( | ||||
|     HttpMessageHandlerBuilder::default() | ||||
|       // .forwarder(Arc::new(Forwarder::new(&globals).await))
 | ||||
|       .globals(globals.clone()) | ||||
|       .build()?, | ||||
|   ); | ||||
| 
 | ||||
|   let http_server = Arc::new(build_http_server(&globals)); | ||||
| 
 | ||||
|   let addresses = globals.proxy_config.listen_sockets.clone(); | ||||
|   let futures = select_all(addresses.into_iter().map(|addr| { | ||||
|     let mut tls_enabled = false; | ||||
|     if let Some(https_port) = globals.proxy_config.https_port { | ||||
|       tls_enabled = https_port == addr.port() | ||||
|     } | ||||
| 
 | ||||
|     let proxy = ProxyBuilder::default() | ||||
|       .globals(globals.clone()) | ||||
|       .listening_on(addr) | ||||
|       .tls_enabled(tls_enabled) | ||||
|       .http_server(http_server.clone()) | ||||
|       .msg_handler(msg_handler.clone()) | ||||
|       .build() | ||||
|       .unwrap(); | ||||
| 
 | ||||
|     globals.runtime_handle.spawn(async move { proxy.start().await }) | ||||
|   })); | ||||
| 
 | ||||
|   // wait for all future
 | ||||
|   if let (Ok(Err(e)), _, _) = futures.await { | ||||
|     error!("Some proxy services are down: {}", e); | ||||
|   }; | ||||
| 
 | ||||
|   Ok(()) | ||||
| } | ||||
							
								
								
									
										98
									
								
								legacy-lib/src/log.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										98
									
								
								legacy-lib/src/log.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,98 @@ | |||
| use crate::utils::ToCanonical; | ||||
| use hyper::header; | ||||
| use std::net::SocketAddr; | ||||
| pub use tracing::{debug, error, info, warn}; | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct MessageLog { | ||||
|   // pub tls_server_name: String,
 | ||||
|   pub client_addr: String, | ||||
|   pub method: String, | ||||
|   pub host: String, | ||||
|   pub p_and_q: String, | ||||
|   pub version: hyper::Version, | ||||
|   pub uri_scheme: String, | ||||
|   pub uri_host: String, | ||||
|   pub ua: String, | ||||
|   pub xff: String, | ||||
|   pub status: String, | ||||
|   pub upstream: String, | ||||
| } | ||||
| 
 | ||||
| impl<T> From<&hyper::Request<T>> for MessageLog { | ||||
|   fn from(req: &hyper::Request<T>) -> Self { | ||||
|     let header_mapper = |v: header::HeaderName| { | ||||
|       req | ||||
|         .headers() | ||||
|         .get(v) | ||||
|         .map_or_else(|| "", |s| s.to_str().unwrap_or("")) | ||||
|         .to_string() | ||||
|     }; | ||||
|     Self { | ||||
|       // tls_server_name: "".to_string(),
 | ||||
|       client_addr: "".to_string(), | ||||
|       method: req.method().to_string(), | ||||
|       host: header_mapper(header::HOST), | ||||
|       p_and_q: req | ||||
|         .uri() | ||||
|         .path_and_query() | ||||
|         .map_or_else(|| "", |v| v.as_str()) | ||||
|         .to_string(), | ||||
|       version: req.version(), | ||||
|       uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(), | ||||
|       uri_host: req.uri().host().unwrap_or("").to_string(), | ||||
|       ua: header_mapper(header::USER_AGENT), | ||||
|       xff: header_mapper(header::HeaderName::from_static("x-forwarded-for")), | ||||
|       status: "".to_string(), | ||||
|       upstream: "".to_string(), | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl MessageLog { | ||||
|   pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self { | ||||
|     self.client_addr = client_addr.to_canonical().to_string(); | ||||
|     self | ||||
|   } | ||||
|   // pub fn tls_server_name(&mut self, tls_server_name: &str) -> &mut Self {
 | ||||
|   //   self.tls_server_name = tls_server_name.to_string();
 | ||||
|   //   self
 | ||||
|   // }
 | ||||
|   pub fn status_code(&mut self, status_code: &hyper::StatusCode) -> &mut Self { | ||||
|     self.status = status_code.to_string(); | ||||
|     self | ||||
|   } | ||||
|   pub fn xff(&mut self, xff: &Option<&header::HeaderValue>) -> &mut Self { | ||||
|     self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); | ||||
|     self | ||||
|   } | ||||
|   pub fn upstream(&mut self, upstream: &hyper::Uri) -> &mut Self { | ||||
|     self.upstream = upstream.to_string(); | ||||
|     self | ||||
|   } | ||||
| 
 | ||||
|   pub fn output(&self) { | ||||
|     info!( | ||||
|       "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"", | ||||
|       if !self.host.is_empty() { | ||||
|         self.host.as_str() | ||||
|       } else { | ||||
|         self.uri_host.as_str() | ||||
|       }, | ||||
|       self.client_addr, | ||||
|       self.method, | ||||
|       self.p_and_q, | ||||
|       self.version, | ||||
|       self.status, | ||||
|       if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() { | ||||
|         format!("{}://{}", self.uri_scheme, self.uri_host) | ||||
|       } else { | ||||
|         "".to_string() | ||||
|       }, | ||||
|       self.ua, | ||||
|       self.xff, | ||||
|       self.upstream, | ||||
|       // self.tls_server_name
 | ||||
|     ); | ||||
|   } | ||||
| } | ||||
							
								
								
									
										42
									
								
								legacy-lib/src/proxy/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								legacy-lib/src/proxy/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,42 @@ | |||
| mod crypto_service; | ||||
| mod proxy_client_cert; | ||||
| #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
| mod proxy_h3; | ||||
| mod proxy_main; | ||||
| #[cfg(feature = "http3-quinn")] | ||||
| mod proxy_quic_quinn; | ||||
| #[cfg(feature = "http3-s2n")] | ||||
| mod proxy_quic_s2n; | ||||
| mod proxy_tls; | ||||
| mod socket; | ||||
| 
 | ||||
| use crate::error::*; | ||||
| use http::{Response, StatusCode}; | ||||
| use http_body_util::{combinators, BodyExt, Either, Empty}; | ||||
| use hyper::body::{Bytes, Incoming}; | ||||
| 
 | ||||
| pub use proxy_main::{Proxy, ProxyBuilder, ProxyBuilderError}; | ||||
| 
 | ||||
| /// Type for synthetic boxed body
 | ||||
| type BoxBody = combinators::BoxBody<Bytes, hyper::Error>; | ||||
| /// Type for either passthrough body or synthetic body
 | ||||
| type EitherBody = Either<Incoming, BoxBody>; | ||||
| 
 | ||||
| /// helper function to build http response with passthrough body
 | ||||
| fn passthrough_response(response: Response<Incoming>) -> Result<Response<EitherBody>> { | ||||
|   Ok(response.map(EitherBody::Left)) | ||||
| } | ||||
| 
 | ||||
| /// build http response with status code of 4xx and 5xx
 | ||||
| fn synthetic_error_response(status_code: StatusCode) -> Result<Response<EitherBody>> { | ||||
|   let res = Response::builder() | ||||
|     .status(status_code) | ||||
|     .body(EitherBody::Right(BoxBody::new(empty()))) | ||||
|     .unwrap(); | ||||
|   Ok(res) | ||||
| } | ||||
| 
 | ||||
| /// helper function to build a empty body
 | ||||
| fn empty() -> BoxBody { | ||||
|   Empty::<Bytes>::new().map_err(|never| match never {}).boxed() | ||||
| } | ||||
							
								
								
									
										186
									
								
								legacy-lib/src/proxy/proxy_h3.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										186
									
								
								legacy-lib/src/proxy/proxy_h3.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,186 @@ | |||
| use super::Proxy; | ||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; | ||||
| use bytes::{Buf, Bytes}; | ||||
| use futures::Stream; | ||||
| #[cfg(feature = "http3-quinn")] | ||||
| use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; | ||||
| use http::{Request, Response}; | ||||
| use http_body_util::{BodyExt, BodyStream, StreamBody}; | ||||
| use hyper::body::{Body, Incoming}; | ||||
| use hyper_util::client::legacy::connect::Connect; | ||||
| #[cfg(feature = "http3-s2n")] | ||||
| use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; | ||||
| use std::net::SocketAddr; | ||||
| use tokio::time::{timeout, Duration}; | ||||
| 
 | ||||
| impl<U> Proxy<U> | ||||
| where | ||||
|   // T: Connect + Clone + Sync + Send + 'static,
 | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub(super) async fn connection_serve_h3<C>( | ||||
|     &self, | ||||
|     quic_connection: C, | ||||
|     tls_server_name: ServerNameBytesExp, | ||||
|     client_addr: SocketAddr, | ||||
|   ) -> Result<()> | ||||
|   where | ||||
|     C: ConnectionQuic<Bytes>, | ||||
|     <C as ConnectionQuic<Bytes>>::BidiStream: BidiStream<Bytes> + Send + 'static, | ||||
|     <<C as ConnectionQuic<Bytes>>::BidiStream as BidiStream<Bytes>>::RecvStream: Send, | ||||
|     <<C as ConnectionQuic<Bytes>>::BidiStream as BidiStream<Bytes>>::SendStream: Send, | ||||
|   { | ||||
|     let mut h3_conn = h3::server::Connection::<_, Bytes>::new(quic_connection).await?; | ||||
|     info!( | ||||
|       "QUIC/HTTP3 connection established from {:?} {:?}", | ||||
|       client_addr, tls_server_name | ||||
|     ); | ||||
|     // TODO: Is here enough to fetch server_name from NewConnection?
 | ||||
|     // to avoid deep nested call from listener_service_h3
 | ||||
|     loop { | ||||
|       // this routine follows hyperium/h3 examples https://github.com/hyperium/h3/blob/master/examples/server.rs
 | ||||
|       match h3_conn.accept().await { | ||||
|         Ok(None) => { | ||||
|           break; | ||||
|         } | ||||
|         Err(e) => { | ||||
|           warn!("HTTP/3 error on accept incoming connection: {}", e); | ||||
|           match e.get_error_level() { | ||||
|             h3::error::ErrorLevel::ConnectionError => break, | ||||
|             h3::error::ErrorLevel::StreamError => continue, | ||||
|           } | ||||
|         } | ||||
|         Ok(Some((req, stream))) => { | ||||
|           // We consider the connection count separately from the stream count.
 | ||||
|           // Max clients for h1/h2 = max 'stream' for h3.
 | ||||
|           let request_count = self.globals.request_count.clone(); | ||||
|           if request_count.increment() > self.globals.proxy_config.max_clients { | ||||
|             request_count.decrement(); | ||||
|             h3_conn.shutdown(0).await?; | ||||
|             break; | ||||
|           } | ||||
|           debug!("Request incoming: current # {}", request_count.current()); | ||||
| 
 | ||||
|           let self_inner = self.clone(); | ||||
|           let tls_server_name_inner = tls_server_name.clone(); | ||||
|           self.globals.runtime_handle.spawn(async move { | ||||
|             if let Err(e) = timeout( | ||||
|               self_inner.globals.proxy_config.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
 | ||||
|               self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), | ||||
|             ) | ||||
|             .await | ||||
|             { | ||||
|               error!("HTTP/3 failed to process stream: {}", e); | ||||
|             } | ||||
|             request_count.decrement(); | ||||
|             debug!("Request processed: current # {}", request_count.current()); | ||||
|           }); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   async fn stream_serve_h3<S>( | ||||
|     &self, | ||||
|     req: Request<()>, | ||||
|     stream: RequestStream<S, Bytes>, | ||||
|     client_addr: SocketAddr, | ||||
|     tls_server_name: ServerNameBytesExp, | ||||
|   ) -> Result<()> | ||||
|   where | ||||
|     S: BidiStream<Bytes> + Send + 'static, | ||||
|     <S as BidiStream<Bytes>>::RecvStream: Send, | ||||
|   { | ||||
|     println!("stream_serve_h3"); | ||||
|     let (req_parts, _) = req.into_parts(); | ||||
|     // split stream and async body handling
 | ||||
|     let (mut send_stream, mut recv_stream) = stream.split(); | ||||
| 
 | ||||
|     // let max_body_size = self.globals.proxy_config.h3_request_max_body_size;
 | ||||
|     // // let max = body_stream.size_hint().upper().unwrap_or(u64::MAX);
 | ||||
|     // // if max > max_body_size as u64 {
 | ||||
|     // //   return Err(HttpError::TooLargeRequestBody);
 | ||||
|     // // }
 | ||||
|     // let new_req = Request::from_parts(req_parts, body_stream);
 | ||||
| 
 | ||||
|     ////////////////////
 | ||||
|     // TODO: TODO: TODO: TODO:
 | ||||
|     // TODO: Body in hyper-0.14 was changed to Incoming in hyper-1.0, and it is not accessible from outside.
 | ||||
|     // Thus, we need to implement IncomingLike trait using channel. Also, the backend handler must feed the body in the form of
 | ||||
|     // Either<Incoming, IncomingLike> as body.
 | ||||
|     // Also, the downstream from the backend handler could be Incoming, but will be wrapped as Either<Incoming, ()/Empty> as well due to H3.
 | ||||
|     // Result<Either<_,_>, E> type includes E as HttpError to generate the status code and related Response<BoxBody>.
 | ||||
|     // Thus to handle synthetic error messages in BoxBody, the serve() function outputs Response<Either<Either<Incoming, ()/Empty>, BoxBody>>>.
 | ||||
|     ////////////////////
 | ||||
| 
 | ||||
|     // // generate streamed body with trailers using channel
 | ||||
|     // let (body_sender, req_body) = Incoming::channel();
 | ||||
| 
 | ||||
|     // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1
 | ||||
|     // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn.
 | ||||
|     let max_body_size = self.globals.proxy_config.h3_request_max_body_size; | ||||
|     self.globals.runtime_handle.spawn(async move { | ||||
|       // let mut sender = body_sender;
 | ||||
|       let mut size = 0usize; | ||||
|       while let Some(mut body) = recv_stream.recv_data().await? { | ||||
|         debug!("HTTP/3 incoming request body: remaining {}", body.remaining()); | ||||
|         size += body.remaining(); | ||||
|         if size > max_body_size { | ||||
|           error!( | ||||
|             "Exceeds max request body size for HTTP/3: received {}, maximum_allowd {}", | ||||
|             size, max_body_size | ||||
|           ); | ||||
|           return Err(RpxyError::Proxy("Exceeds max request body size for HTTP/3".to_string())); | ||||
|         } | ||||
|         // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes
 | ||||
|         // sender.send_data(body.copy_to_bytes(body.remaining())).await?;
 | ||||
|       } | ||||
| 
 | ||||
|       // trailers: use inner for work around. (directly get trailer)
 | ||||
|       let trailers = recv_stream.as_mut().recv_trailers().await?; | ||||
|       if trailers.is_some() { | ||||
|         debug!("HTTP/3 incoming request trailers"); | ||||
|         // sender.send_trailers(trailers.unwrap()).await?;
 | ||||
|       } | ||||
|       Ok(()) | ||||
|     }); | ||||
| 
 | ||||
|     // let new_req: Request<Incoming> = Request::from_parts(req_parts, req_body);
 | ||||
|     // let res = self
 | ||||
|     //   .msg_handler
 | ||||
|     //   .clone()
 | ||||
|     //   .handle_request(
 | ||||
|     //     new_req,
 | ||||
|     //     client_addr,
 | ||||
|     //     self.listening_on,
 | ||||
|     //     self.tls_enabled,
 | ||||
|     //     Some(tls_server_name),
 | ||||
|     //   )
 | ||||
|     //   .await?;
 | ||||
| 
 | ||||
|     // let (new_res_parts, new_body) = res.into_parts();
 | ||||
|     // let new_res = Response::from_parts(new_res_parts, ());
 | ||||
| 
 | ||||
|     // match send_stream.send_response(new_res).await {
 | ||||
|     //   Ok(_) => {
 | ||||
|     //     debug!("HTTP/3 response to connection successful");
 | ||||
|     //     // aggregate body without copying
 | ||||
|     //     let body_data = new_body.collect().await?.aggregate();
 | ||||
| 
 | ||||
|     //     // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes
 | ||||
|     //     send_stream
 | ||||
|     //       .send_data(body_data.copy_to_bytes(body_data.remaining()))
 | ||||
|     //       .await?;
 | ||||
| 
 | ||||
|     //     // TODO: needs handling trailer? should be included in body from handler.
 | ||||
|     //   }
 | ||||
|     //   Err(err) => {
 | ||||
|     //     error!("Unable to send response to connection peer: {:?}", err);
 | ||||
|     //   }
 | ||||
|     // }
 | ||||
|     // Ok(send_stream.finish().await?)
 | ||||
|     todo!() | ||||
|   } | ||||
| } | ||||
							
								
								
									
										150
									
								
								legacy-lib/src/proxy/proxy_main.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										150
									
								
								legacy-lib/src/proxy/proxy_main.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,150 @@ | |||
| use super::{passthrough_response, socket::bind_tcp_socket, synthetic_error_response, EitherBody}; | ||||
| use crate::{ | ||||
|   certs::CryptoSource, error::*, globals::Globals, handler::HttpMessageHandler, hyper_executor::LocalExecutor, log::*, | ||||
|   utils::ServerNameBytesExp, | ||||
| }; | ||||
| use derive_builder::{self, Builder}; | ||||
| use http::{Request, StatusCode}; | ||||
| use hyper::{ | ||||
|   body::Incoming, | ||||
|   rt::{Read, Write}, | ||||
|   service::service_fn, | ||||
| }; | ||||
| use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; | ||||
| use std::{net::SocketAddr, sync::Arc}; | ||||
| use tokio::time::{timeout, Duration}; | ||||
| 
 | ||||
| #[derive(Clone, Builder)] | ||||
| /// Proxy main object
 | ||||
| pub struct Proxy<U> | ||||
| where | ||||
|   // T: Connect + Clone + Sync + Send + 'static,
 | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub listening_on: SocketAddr, | ||||
|   pub tls_enabled: bool, // TCP待受がTLSかどうか
 | ||||
|   /// hyper server receiving http request
 | ||||
|   pub http_server: Arc<ConnectionBuilder<LocalExecutor>>, | ||||
|   // pub msg_handler: Arc<HttpMessageHandler<U>>,
 | ||||
|   pub msg_handler: Arc<HttpMessageHandler<U>>, | ||||
|   pub globals: Arc<Globals<U>>, | ||||
| } | ||||
| 
 | ||||
| /// Wrapper function to handle request
 | ||||
| async fn serve_request<U>( | ||||
|   req: Request<Incoming>, | ||||
|   // handler: Arc<HttpMessageHandler<T, U>>,
 | ||||
|   handler: Arc<HttpMessageHandler<U>>, | ||||
|   client_addr: SocketAddr, | ||||
|   listen_addr: SocketAddr, | ||||
|   tls_enabled: bool, | ||||
|   tls_server_name: Option<ServerNameBytesExp>, | ||||
| ) -> Result<hyper::Response<EitherBody>> | ||||
| where | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   match handler | ||||
|     .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) | ||||
|     .await? | ||||
|   { | ||||
|     Ok(res) => passthrough_response(res), | ||||
|     Err(e) => synthetic_error_response(StatusCode::from(e)), | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl<U> Proxy<U> | ||||
| where | ||||
|   // T: Connect + Clone + Sync + Send + 'static,
 | ||||
|   U: CryptoSource + Clone + Sync + Send, | ||||
| { | ||||
|   /// Serves requests from clients
 | ||||
|   pub(super) fn serve_connection<I>( | ||||
|     &self, | ||||
|     stream: I, | ||||
|     peer_addr: SocketAddr, | ||||
|     tls_server_name: Option<ServerNameBytesExp>, | ||||
|   ) where | ||||
|     I: Read + Write + Send + Unpin + 'static, | ||||
|   { | ||||
|     let request_count = self.globals.request_count.clone(); | ||||
|     if request_count.increment() > self.globals.proxy_config.max_clients { | ||||
|       request_count.decrement(); | ||||
|       return; | ||||
|     } | ||||
|     debug!("Request incoming: current # {}", request_count.current()); | ||||
| 
 | ||||
|     let server_clone = self.http_server.clone(); | ||||
|     let msg_handler_clone = self.msg_handler.clone(); | ||||
|     let timeout_sec = self.globals.proxy_config.proxy_timeout; | ||||
|     let tls_enabled = self.tls_enabled; | ||||
|     let listening_on = self.listening_on; | ||||
|     self.globals.runtime_handle.clone().spawn(async move { | ||||
|       timeout( | ||||
|         timeout_sec + Duration::from_secs(1), | ||||
|         server_clone.serve_connection_with_upgrades( | ||||
|           stream, | ||||
|           service_fn(move |req: Request<Incoming>| { | ||||
|             serve_request( | ||||
|               req, | ||||
|               msg_handler_clone.clone(), | ||||
|               peer_addr, | ||||
|               listening_on, | ||||
|               tls_enabled, | ||||
|               tls_server_name.clone(), | ||||
|             ) | ||||
|           }), | ||||
|         ), | ||||
|       ) | ||||
|       .await | ||||
|       .ok(); | ||||
| 
 | ||||
|       request_count.decrement(); | ||||
|       debug!("Request processed: current # {}", request_count.current()); | ||||
|     }); | ||||
|   } | ||||
| 
 | ||||
|   /// Start without TLS (HTTP cleartext)
 | ||||
|   async fn start_without_tls(&self) -> Result<()> { | ||||
|     let listener_service = async { | ||||
|       let tcp_socket = bind_tcp_socket(&self.listening_on)?; | ||||
|       let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; | ||||
|       info!("Start TCP proxy serving with HTTP request for configured host names"); | ||||
|       while let Ok((stream, client_addr)) = tcp_listener.accept().await { | ||||
|         self.serve_connection(TokioIo::new(stream), client_addr, None); | ||||
|       } | ||||
|       Ok(()) as Result<()> | ||||
|     }; | ||||
|     listener_service.await?; | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   /// Entrypoint for HTTP/1.1 and HTTP/2 servers
 | ||||
|   pub async fn start(&self) -> Result<()> { | ||||
|     let proxy_service = async { | ||||
|       if self.tls_enabled { | ||||
|         self.start_with_tls().await | ||||
|       } else { | ||||
|         self.start_without_tls().await | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     match &self.globals.term_notify { | ||||
|       Some(term) => { | ||||
|         tokio::select! { | ||||
|           _ = proxy_service => { | ||||
|             warn!("Proxy service got down"); | ||||
|           } | ||||
|           _ = term.notified() => { | ||||
|             info!("Proxy service listening on {} receives term signal", self.listening_on); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|       None => { | ||||
|         proxy_service.await?; | ||||
|         warn!("Proxy service got down"); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
|   } | ||||
| } | ||||
							
								
								
									
										124
									
								
								legacy-lib/src/proxy/proxy_quic_quinn.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										124
									
								
								legacy-lib/src/proxy/proxy_quic_quinn.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,124 @@ | |||
| use super::socket::bind_udp_socket; | ||||
| use super::{ | ||||
|   crypto_service::{ServerCrypto, ServerCryptoBase}, | ||||
|   proxy_main::Proxy, | ||||
| }; | ||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; | ||||
| use hot_reload::ReloaderReceiver; | ||||
| use hyper_util::client::legacy::connect::Connect; | ||||
| use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; | ||||
| use rustls::ServerConfig; | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| impl<U> Proxy<U> | ||||
| where | ||||
|   // T: Connect + Clone + Sync + Send + 'static,
 | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub(super) async fn listener_service_h3( | ||||
|     &self, | ||||
|     mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, | ||||
|   ) -> Result<()> { | ||||
|     info!("Start UDP proxy serving with HTTP/3 request for configured host names [quinn]"); | ||||
|     // first set as null config server
 | ||||
|     let rustls_server_config = ServerConfig::builder() | ||||
|       .with_safe_default_cipher_suites() | ||||
|       .with_safe_default_kx_groups() | ||||
|       .with_protocol_versions(&[&rustls::version::TLS13])? | ||||
|       .with_no_client_auth() | ||||
|       .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new())); | ||||
| 
 | ||||
|     let mut transport_config_quic = TransportConfig::default(); | ||||
|     transport_config_quic | ||||
|       .max_concurrent_bidi_streams(self.globals.proxy_config.h3_max_concurrent_bidistream.into()) | ||||
|       .max_concurrent_uni_streams(self.globals.proxy_config.h3_max_concurrent_unistream.into()) | ||||
|       .max_idle_timeout( | ||||
|         self | ||||
|           .globals | ||||
|           .proxy_config | ||||
|           .h3_max_idle_timeout | ||||
|           .map(|v| quinn::IdleTimeout::try_from(v).unwrap()), | ||||
|       ); | ||||
| 
 | ||||
|     let mut server_config_h3 = QuicServerConfig::with_crypto(Arc::new(rustls_server_config)); | ||||
|     server_config_h3.transport = Arc::new(transport_config_quic); | ||||
|     server_config_h3.concurrent_connections(self.globals.proxy_config.h3_max_concurrent_connections); | ||||
| 
 | ||||
|     // To reuse address
 | ||||
|     let udp_socket = bind_udp_socket(&self.listening_on)?; | ||||
|     let runtime = quinn::default_runtime() | ||||
|       .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::Other, "No async runtime found"))?; | ||||
|     let endpoint = Endpoint::new( | ||||
|       quinn::EndpointConfig::default(), | ||||
|       Some(server_config_h3), | ||||
|       udp_socket, | ||||
|       runtime, | ||||
|     )?; | ||||
| 
 | ||||
|     let mut server_crypto: Option<Arc<ServerCrypto>> = None; | ||||
|     loop { | ||||
|       tokio::select! { | ||||
|         new_conn = endpoint.accept() => { | ||||
|           if server_crypto.is_none() || new_conn.is_none() { | ||||
|             continue; | ||||
|           } | ||||
|           let mut conn: quinn::Connecting = new_conn.unwrap(); | ||||
|           let Ok(hsd) = conn.handshake_data().await else { | ||||
|             continue
 | ||||
|           }; | ||||
| 
 | ||||
|           let Ok(hsd_downcast) = hsd.downcast::<HandshakeData>() else { | ||||
|             continue
 | ||||
|           }; | ||||
|           let Some(new_server_name) = hsd_downcast.server_name else { | ||||
|             warn!("HTTP/3 no SNI is given"); | ||||
|             continue; | ||||
|           }; | ||||
|           debug!( | ||||
|             "HTTP/3 connection incoming (SNI {:?})", | ||||
|             new_server_name | ||||
|           ); | ||||
|           // TODO: server_nameをここで出してどんどん深く投げていくのは効率が悪い。connecting -> connectionsの後でいいのでは?
 | ||||
|           // TODO: 通常のTLSと同じenumか何かにまとめたい
 | ||||
|           let self_clone = self.clone(); | ||||
|           self.globals.runtime_handle.spawn(async move { | ||||
|             let client_addr = conn.remote_address(); | ||||
|             let quic_connection = match conn.await { | ||||
|               Ok(new_conn) => { | ||||
|                 info!("New connection established"); | ||||
|                 h3_quinn::Connection::new(new_conn) | ||||
|               }, | ||||
|               Err(e) => { | ||||
|                 warn!("QUIC accepting connection failed: {:?}", e); | ||||
|                 return Err(RpxyError::QuicConn(e)); | ||||
|               } | ||||
|             }; | ||||
|             // Timeout is based on underlying quic
 | ||||
|             if let Err(e) = self_clone.connection_serve_h3(quic_connection, new_server_name.to_server_name_vec(), client_addr).await { | ||||
|               warn!("QUIC or HTTP/3 connection failed: {}", e); | ||||
|             }; | ||||
|             Ok(()) | ||||
|           }); | ||||
|         } | ||||
|         _ = server_crypto_rx.changed() => { | ||||
|           if server_crypto_rx.borrow().is_none() { | ||||
|             error!("Reloader is broken"); | ||||
|             break; | ||||
|           } | ||||
|           let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); | ||||
| 
 | ||||
|           server_crypto = (&cert_keys_map).try_into().ok(); | ||||
|           let Some(inner) = server_crypto.clone() else { | ||||
|             error!("Failed to update server crypto for h3"); | ||||
|             break; | ||||
|           }; | ||||
|           endpoint.set_server_config(Some(QuicServerConfig::with_crypto(inner.clone().inner_global_no_client_auth.clone()))); | ||||
| 
 | ||||
|         } | ||||
|         else => break
 | ||||
|       } | ||||
|     } | ||||
|     endpoint.wait_idle().await; | ||||
|     Ok(()) as Result<()> | ||||
|   } | ||||
| } | ||||
							
								
								
									
										135
									
								
								legacy-lib/src/proxy/proxy_quic_s2n.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								legacy-lib/src/proxy/proxy_quic_s2n.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,135 @@ | |||
| use super::{ | ||||
|   crypto_service::{ServerCrypto, ServerCryptoBase}, | ||||
|   proxy_main::Proxy, | ||||
| }; | ||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; | ||||
| use hot_reload::ReloaderReceiver; | ||||
| use hyper_util::client::legacy::connect::Connect; | ||||
| use s2n_quic::provider; | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| impl<U> Proxy<U> | ||||
| where | ||||
|   // T: Connect + Clone + Sync + Send + 'static,
 | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub(super) async fn listener_service_h3( | ||||
|     &self, | ||||
|     mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, | ||||
|   ) -> Result<()> { | ||||
|     info!("Start UDP proxy serving with HTTP/3 request for configured host names [s2n-quic]"); | ||||
| 
 | ||||
|     // initially wait for receipt
 | ||||
|     let mut server_crypto: Option<Arc<ServerCrypto>> = { | ||||
|       let _ = server_crypto_rx.changed().await; | ||||
|       let sc = self.receive_server_crypto(server_crypto_rx.clone())?; | ||||
|       Some(sc) | ||||
|     }; | ||||
| 
 | ||||
|     // event loop
 | ||||
|     loop { | ||||
|       tokio::select! { | ||||
|         v = self.listener_service_h3_inner(&server_crypto) => { | ||||
|           if let Err(e) = v { | ||||
|             error!("Quic connection event loop illegally shutdown [s2n-quic] {e}"); | ||||
|             break; | ||||
|           } | ||||
|         } | ||||
|         _ = server_crypto_rx.changed() => { | ||||
|           server_crypto = match self.receive_server_crypto(server_crypto_rx.clone()) { | ||||
|             Ok(sc) => Some(sc), | ||||
|             Err(e) => { | ||||
|               error!("{e}"); | ||||
|               break; | ||||
|             } | ||||
|           }; | ||||
|         } | ||||
|         else => break
 | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   fn receive_server_crypto(&self, server_crypto_rx: ReloaderReceiver<ServerCryptoBase>) -> Result<Arc<ServerCrypto>> { | ||||
|     let cert_keys_map = server_crypto_rx.borrow().clone().ok_or_else(|| { | ||||
|       error!("Reloader is broken"); | ||||
|       RpxyError::Other(anyhow!("Reloader is broken")) | ||||
|     })?; | ||||
| 
 | ||||
|     let server_crypto: Option<Arc<ServerCrypto>> = (&cert_keys_map).try_into().ok(); | ||||
|     server_crypto.ok_or_else(|| { | ||||
|       error!("Failed to update server crypto for h3 [s2n-quic]"); | ||||
|       RpxyError::Other(anyhow!("Failed to update server crypto for h3 [s2n-quic]")) | ||||
|     }) | ||||
|   } | ||||
| 
 | ||||
|   async fn listener_service_h3_inner(&self, server_crypto: &Option<Arc<ServerCrypto>>) -> Result<()> { | ||||
|     // setup UDP socket
 | ||||
|     let io = provider::io::tokio::Builder::default() | ||||
|       .with_receive_address(self.listening_on)? | ||||
|       .with_reuse_port()? | ||||
|       .build()?; | ||||
| 
 | ||||
|     // setup limits
 | ||||
|     let mut limits = provider::limits::Limits::default() | ||||
|       .with_max_open_local_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .with_max_open_remote_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .with_max_open_local_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .with_max_open_remote_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .with_max_active_connection_ids(self.globals.proxy_config.h3_max_concurrent_connections as u64) | ||||
|       .map_err(|e| anyhow!(e))?; | ||||
|     limits = if let Some(v) = self.globals.proxy_config.h3_max_idle_timeout { | ||||
|       limits.with_max_idle_timeout(v).map_err(|e| anyhow!(e))? | ||||
|     } else { | ||||
|       limits | ||||
|     }; | ||||
| 
 | ||||
|     // setup tls
 | ||||
|     let Some(server_crypto) = server_crypto else { | ||||
|       warn!("No server crypto is given [s2n-quic]"); | ||||
|       return Err(RpxyError::Other(anyhow!("No server crypto is given [s2n-quic]"))); | ||||
|     }; | ||||
|     let tls = server_crypto.inner_global_no_client_auth.clone(); | ||||
| 
 | ||||
|     let mut server = s2n_quic::Server::builder() | ||||
|       .with_tls(tls) | ||||
|       .map_err(|e| anyhow::anyhow!(e))? | ||||
|       .with_io(io) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .with_limits(limits) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .start() | ||||
|       .map_err(|e| anyhow!(e))?; | ||||
| 
 | ||||
|     // quic event loop. this immediately cancels when crypto is updated by tokio::select!
 | ||||
|     while let Some(new_conn) = server.accept().await { | ||||
|       debug!("New QUIC connection established"); | ||||
|       let Ok(Some(new_server_name)) = new_conn.server_name() else { | ||||
|         warn!("HTTP/3 no SNI is given"); | ||||
|         continue; | ||||
|       }; | ||||
|       debug!("HTTP/3 connection incoming (SNI {:?})", new_server_name); | ||||
|       let self_clone = self.clone(); | ||||
| 
 | ||||
|       self.globals.runtime_handle.spawn(async move { | ||||
|         let client_addr = new_conn.remote_addr()?; | ||||
|         let quic_connection = s2n_quic_h3::Connection::new(new_conn); | ||||
|         // Timeout is based on underlying quic
 | ||||
|         if let Err(e) = self_clone | ||||
|           .connection_serve_h3(quic_connection, new_server_name.to_server_name_vec(), client_addr) | ||||
|           .await | ||||
|         { | ||||
|           warn!("QUIC or HTTP/3 connection failed: {}", e); | ||||
|         }; | ||||
|         Ok(()) as Result<()> | ||||
|       }); | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
|   } | ||||
| } | ||||
|  | @ -1,25 +1,21 @@ | |||
| use super::{ | ||||
|   crypto_service::{CryptoReloader, ServerCrypto, ServerCryptoBase, SniServerCryptoMap}, | ||||
|   proxy_main::{LocalExecutor, Proxy}, | ||||
|   proxy_main::Proxy, | ||||
|   socket::bind_tcp_socket, | ||||
| }; | ||||
| use crate::{certs::CryptoSource, constants::*, error::*, log::*, utils::BytesName}; | ||||
| use hot_reload::{ReloaderReceiver, ReloaderService}; | ||||
| use hyper::{client::connect::Connect, server::conn::Http}; | ||||
| use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; | ||||
| use std::sync::Arc; | ||||
| use tokio::time::{timeout, Duration}; | ||||
| 
 | ||||
| impl<T, U> Proxy<T, U> | ||||
| impl<U> Proxy<U> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
|   // T: Connect + Clone + Sync + Send + 'static,
 | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   // TCP Listener Service, i.e., http/2 and http/1.1
 | ||||
|   async fn listener_service( | ||||
|     &self, | ||||
|     server: Http<LocalExecutor>, | ||||
|     mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, | ||||
|   ) -> Result<()> { | ||||
|   async fn listener_service(&self, mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>) -> Result<()> { | ||||
|     let tcp_socket = bind_tcp_socket(&self.listening_on)?; | ||||
|     let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; | ||||
|     info!("Start TCP proxy serving with HTTPS request for configured host names"); | ||||
|  | @ -33,7 +29,6 @@ where | |||
|           } | ||||
|           let (raw_stream, client_addr) = tcp_cnx.unwrap(); | ||||
|           let sc_map_inner = server_crypto_map.clone(); | ||||
|           let server_clone = server.clone(); | ||||
|           let self_inner = self.clone(); | ||||
| 
 | ||||
|           // spawns async handshake to avoid blocking thread by sequential handshake.
 | ||||
|  | @ -55,30 +50,27 @@ where | |||
|               return Err(RpxyError::Proxy(format!("No TLS serving app for {:?}", server_name.unwrap()))); | ||||
|             } | ||||
|             let stream = match start.into_stream(server_crypto.unwrap().clone()).await { | ||||
|               Ok(s) => s, | ||||
|               Ok(s) => TokioIo::new(s), | ||||
|               Err(e) => { | ||||
|                 return Err(RpxyError::Proxy(format!("Failed to handshake TLS: {e}"))); | ||||
|               } | ||||
|             }; | ||||
|             self_inner.client_serve(stream, server_clone, client_addr, server_name_in_bytes); | ||||
|             self_inner.serve_connection(stream, client_addr, server_name_in_bytes); | ||||
|             Ok(()) | ||||
|           }; | ||||
| 
 | ||||
|           self.globals.runtime_handle.spawn( async move { | ||||
|             // timeout is introduced to avoid get stuck here.
 | ||||
|             match timeout( | ||||
|             let Ok(v) = timeout( | ||||
|               Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), | ||||
|               handshake_fut | ||||
|             ).await { | ||||
|               Ok(a) => { | ||||
|                 if let Err(e) = a { | ||||
|             ).await else { | ||||
|               error!("Timeout to handshake TLS"); | ||||
|               return; | ||||
|             }; | ||||
|             if let Err(e) = v { | ||||
|               error!("{}", e); | ||||
|             } | ||||
|               }, | ||||
|               Err(e) => { | ||||
|                 error!("Timeout to handshake TLS: {}", e); | ||||
|               } | ||||
|             }; | ||||
|           }); | ||||
|         } | ||||
|         _ = server_crypto_rx.changed() => { | ||||
|  | @ -99,7 +91,7 @@ where | |||
|     Ok(()) as Result<()> | ||||
|   } | ||||
| 
 | ||||
|   pub async fn start_with_tls(self, server: Http<LocalExecutor>) -> Result<()> { | ||||
|   pub async fn start_with_tls(&self) -> Result<()> { | ||||
|     let (cert_reloader_service, cert_reloader_rx) = ReloaderService::<CryptoReloader<U>, ServerCryptoBase>::new( | ||||
|       &self.globals.clone(), | ||||
|       CERTS_WATCH_DELAY_SECS, | ||||
|  | @ -114,7 +106,7 @@ where | |||
|         _= cert_reloader_service.start() => { | ||||
|           error!("Cert service for TLS exited"); | ||||
|         }, | ||||
|         _ = self.listener_service(server, cert_reloader_rx) => { | ||||
|         _ = self.listener_service(cert_reloader_rx) => { | ||||
|           error!("TCP proxy service for TLS exited"); | ||||
|         }, | ||||
|         else => { | ||||
|  | @ -131,7 +123,7 @@ where | |||
|           _= cert_reloader_service.start() => { | ||||
|             error!("Cert service for TLS exited"); | ||||
|           }, | ||||
|           _ = self.listener_service(server, cert_reloader_rx.clone()) => { | ||||
|           _ = self.listener_service(cert_reloader_rx.clone()) => { | ||||
|             error!("TCP proxy service for TLS exited"); | ||||
|           }, | ||||
|           _= self.listener_service_h3(cert_reloader_rx) => { | ||||
|  | @ -148,7 +140,7 @@ where | |||
|           _= cert_reloader_service.start() => { | ||||
|             error!("Cert service for TLS exited"); | ||||
|           }, | ||||
|           _ = self.listener_service(server, cert_reloader_rx) => { | ||||
|           _ = self.listener_service(cert_reloader_rx) => { | ||||
|             error!("TCP proxy service for TLS exited"); | ||||
|           }, | ||||
|           else => { | ||||
							
								
								
									
										46
									
								
								legacy-lib/src/proxy/socket.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								legacy-lib/src/proxy/socket.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,46 @@ | |||
| use crate::{error::*, log::*}; | ||||
| #[cfg(feature = "http3-quinn")] | ||||
| use socket2::{Domain, Protocol, Socket, Type}; | ||||
| use std::net::SocketAddr; | ||||
| #[cfg(feature = "http3-quinn")] | ||||
| use std::net::UdpSocket; | ||||
| use tokio::net::TcpSocket; | ||||
| 
 | ||||
| /// Bind TCP socket to the given `SocketAddr`, and returns the TCP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options.
 | ||||
| /// This option is required to re-bind the socket address when the proxy instance is reconstructed.
 | ||||
| pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> Result<TcpSocket> { | ||||
|   let tcp_socket = if listening_on.is_ipv6() { | ||||
|     TcpSocket::new_v6() | ||||
|   } else { | ||||
|     TcpSocket::new_v4() | ||||
|   }?; | ||||
|   tcp_socket.set_reuseaddr(true)?; | ||||
|   tcp_socket.set_reuseport(true)?; | ||||
|   if let Err(e) = tcp_socket.bind(*listening_on) { | ||||
|     error!("Failed to bind TCP socket: {}", e); | ||||
|     return Err(RpxyError::Io(e)); | ||||
|   }; | ||||
|   Ok(tcp_socket) | ||||
| } | ||||
| 
 | ||||
| #[cfg(feature = "http3-quinn")] | ||||
| /// Bind UDP socket to the given `SocketAddr`, and returns the UDP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options.
 | ||||
| /// This option is required to re-bind the socket address when the proxy instance is reconstructed.
 | ||||
| pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> Result<UdpSocket> { | ||||
|   let socket = if listening_on.is_ipv6() { | ||||
|     Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) | ||||
|   } else { | ||||
|     Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)) | ||||
|   }?; | ||||
|   socket.set_reuse_address(true)?; // This isn't necessary?
 | ||||
|   socket.set_reuse_port(true)?; | ||||
|   socket.set_nonblocking(true)?; // This was made true inside quinn. so this line isn't necessary here. but just in case.
 | ||||
| 
 | ||||
|   if let Err(e) = socket.bind(&(*listening_on).into()) { | ||||
|     error!("Failed to bind UDP socket: {}", e); | ||||
|     return Err(RpxyError::Io(e)); | ||||
|   }; | ||||
|   let udp_socket: UdpSocket = socket.into(); | ||||
| 
 | ||||
|   Ok(udp_socket) | ||||
| } | ||||
|  | @ -1,6 +1,6 @@ | |||
| [package] | ||||
| name = "rpxy" | ||||
| version = "0.6.2" | ||||
| version = "0.7.0-alpha.0" | ||||
| authors = ["Jun Kurihara"] | ||||
| homepage = "https://github.com/junkurihara/rust-rpxy" | ||||
| repository = "https://github.com/junkurihara/rust-rpxy" | ||||
|  | @ -12,9 +12,12 @@ publish = false | |||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||||
| 
 | ||||
| [features] | ||||
| default = ["http3-quinn", "cache"] | ||||
| default = ["http3-quinn", "cache", "native-tls-backend"] | ||||
| http3-quinn = ["rpxy-lib/http3-quinn"] | ||||
| http3-s2n = ["rpxy-lib/http3-s2n"] | ||||
| native-tls-backend = ["rpxy-lib/native-tls-backend"] | ||||
| # Not yet implemented | ||||
| rustls-backend = ["rpxy-lib/rustls-backend"] | ||||
| cache = ["rpxy-lib/cache"] | ||||
| native-roots = ["rpxy-lib/native-roots"] | ||||
| 
 | ||||
|  | @ -25,9 +28,9 @@ rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [ | |||
| 
 | ||||
| anyhow = "1.0.75" | ||||
| rustc-hash = "1.1.0" | ||||
| serde = { version = "1.0.192", default-features = false, features = ["derive"] } | ||||
| serde = { version = "1.0.193", default-features = false, features = ["derive"] } | ||||
| derive_builder = "0.12.0" | ||||
| tokio = { version = "1.33.0", default-features = false, features = [ | ||||
| tokio = { version = "1.35.0", default-features = false, features = [ | ||||
|   "net", | ||||
|   "rt-multi-thread", | ||||
|   "time", | ||||
|  | @ -35,17 +38,17 @@ tokio = { version = "1.33.0", default-features = false, features = [ | |||
|   "macros", | ||||
| ] } | ||||
| async-trait = "0.1.74" | ||||
| rustls-pemfile = "1.0.3" | ||||
| rustls-pemfile = "1.0.4" | ||||
| mimalloc = { version = "*", default-features = false } | ||||
| 
 | ||||
| # config | ||||
| clap = { version = "4.4.7", features = ["std", "cargo", "wrap_help"] } | ||||
| toml = { version = "0.8", default-features = false, features = ["parse"] } | ||||
| clap = { version = "4.4.11", features = ["std", "cargo", "wrap_help"] } | ||||
| toml = { version = "0.8.8", default-features = false, features = ["parse"] } | ||||
| hot_reload = "0.1.4" | ||||
| 
 | ||||
| # logging | ||||
| tracing = { version = "0.1.40" } | ||||
| tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } | ||||
| tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } | ||||
| 
 | ||||
| 
 | ||||
| [dev-dependencies] | ||||
|  |  | |||
|  | @ -12,10 +12,13 @@ pub fn init_logger() { | |||
|     .with_level(true) | ||||
|     .compact(); | ||||
| 
 | ||||
|   // This limits the logger to emits only rpxy crate
 | ||||
|   // This limits the logger to emits only proxy crate
 | ||||
|   let pkg_name = env!("CARGO_PKG_NAME").replace('-', "_"); | ||||
|   let level_string = std::env::var(EnvFilter::DEFAULT_ENV).unwrap_or_else(|_| "info".to_string()); | ||||
|   let filter_layer = EnvFilter::new(format!("{}={}", env!("CARGO_PKG_NAME"), level_string)); | ||||
|   // let filter_layer = EnvFilter::from_default_env();
 | ||||
|   let filter_layer = EnvFilter::new(format!("{}={}", pkg_name, level_string)); | ||||
|   // let filter_layer = EnvFilter::try_from_default_env()
 | ||||
|   //   .unwrap_or_else(|_| EnvFilter::new("info"))
 | ||||
|   //   .add_directive(format!("{}=trace", pkg_name).parse().unwrap());
 | ||||
| 
 | ||||
|   tracing_subscriber::registry() | ||||
|     .with(format_layer) | ||||
|  |  | |||
|  | @ -15,9 +15,6 @@ use crate::{ | |||
| use hot_reload::{ReloaderReceiver, ReloaderService}; | ||||
| use rpxy_lib::entrypoint; | ||||
| 
 | ||||
| #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
| compile_error!("feature \"http3-quinn\" and feature \"http3-s2n\" cannot be enabled at the same time"); | ||||
| 
 | ||||
| fn main() { | ||||
|   init_logger(); | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,6 +1,6 @@ | |||
| [package] | ||||
| name = "rpxy-lib" | ||||
| version = "0.6.2" | ||||
| version = "0.7.0-alpha.0" | ||||
| authors = ["Jun Kurihara"] | ||||
| homepage = "https://github.com/junkurihara/rust-rpxy" | ||||
| repository = "https://github.com/junkurihara/rust-rpxy" | ||||
|  | @ -12,12 +12,20 @@ publish = false | |||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||||
| 
 | ||||
| [features] | ||||
| default = ["http3-quinn", "sticky-cookie", "cache"] | ||||
| http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] | ||||
| http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] | ||||
| default = ["http3-quinn", "sticky-cookie", "cache", "native-tls-backend"] | ||||
| http3-quinn = ["socket2", "quinn", "h3", "h3-quinn"] | ||||
| http3-s2n = [ | ||||
|   "h3", | ||||
|   "s2n-quic", | ||||
|   "s2n-quic-core", | ||||
|   "s2n-quic-rustls", | ||||
|   "s2n-quic-h3", | ||||
| ] | ||||
| cache = ["http-cache-semantics", "lru", "sha2", "base64"] | ||||
| sticky-cookie = ["base64", "sha2", "chrono"] | ||||
| cache = ["http-cache-semantics", "lru"] | ||||
| native-roots = ["hyper-rustls/native-tokio"] | ||||
| native-tls-backend = ["hyper-tls"] | ||||
| rustls-backend = [] # not implemented yet | ||||
| native-roots = [] #"hyper-rustls/native-tokio"] # not implemented yet | ||||
| 
 | ||||
| [dependencies] | ||||
| rand = "0.8.5" | ||||
|  | @ -25,7 +33,7 @@ rustc-hash = "1.1.0" | |||
| bytes = "1.5.0" | ||||
| derive_builder = "0.12.0" | ||||
| futures = { version = "0.3.29", features = ["alloc", "async-await"] } | ||||
| tokio = { version = "1.33.0", default-features = false, features = [ | ||||
| tokio = { version = "1.35.0", default-features = false, features = [ | ||||
|   "net", | ||||
|   "rt-multi-thread", | ||||
|   "time", | ||||
|  | @ -33,28 +41,44 @@ tokio = { version = "1.33.0", default-features = false, features = [ | |||
|   "macros", | ||||
|   "fs", | ||||
| ] } | ||||
| pin-project-lite = "0.2.13" | ||||
| async-trait = "0.1.74" | ||||
| hot_reload = "0.1.4" # reloading certs | ||||
| 
 | ||||
| # Error handling | ||||
| anyhow = "1.0.75" | ||||
| thiserror = "1.0.50" | ||||
| 
 | ||||
| # http and tls | ||||
| hyper = { version = "0.14.27", default-features = false, features = [ | ||||
|   "server", | ||||
|   "http1", | ||||
|   "http2", | ||||
|   "stream", | ||||
| ] } | ||||
| hyper-rustls = { version = "0.24.2", default-features = false, features = [ | ||||
|   "tokio-runtime", | ||||
|   "webpki-tokio", | ||||
|   "http1", | ||||
|   "http2", | ||||
| ] } | ||||
| # http for both server and client | ||||
| http = "1.0.0" | ||||
| http-body-util = "0.1.0" | ||||
| hyper = { version = "1.0.1", default-features = false } | ||||
| # hyper-util = { version = "0.1.1", features = ["full"] } | ||||
| hyper-util = { git = "https://github.com/junkurihara/hyper-util", features = [ | ||||
|   "full", | ||||
| ], rev = "99409f5c4059633b7e2fa8b9c2e6c110b0f2f64b" } | ||||
| futures-util = { version = "0.3.29", default-features = false } | ||||
| futures-channel = { version = "0.3.29", default-features = false } | ||||
| 
 | ||||
| # http client for upstream | ||||
| hyper-tls = { git = "https://github.com/junkurihara/hyper-tls", features = [ | ||||
|   "alpn", | ||||
|   "vendored", | ||||
| ], rev = "06fb462ee67ec349936ceb64849d64d05e58458a", optional = true } | ||||
| # hyper-tls = { version = "0.6.0", features = [ | ||||
| #   "alpn", | ||||
| #   "vendored", | ||||
| # ], optional = true } | ||||
| # hyper-rustls = { version = "0.24.2", default-features = false, features = [ | ||||
| #   "tokio-runtime", | ||||
| #   "webpki-tokio", | ||||
| #   "http1", | ||||
| #   "http2", | ||||
| # ] } | ||||
| 
 | ||||
| # tls and cert management for server | ||||
| hot_reload = "0.1.4" | ||||
| rustls = { version = "0.21.10", default-features = false } | ||||
| tokio-rustls = { version = "0.24.1", features = ["early-data"] } | ||||
| rustls = { version = "0.21.8", default-features = false } | ||||
| webpki = "0.22.4" | ||||
| x509-parser = "0.15.1" | ||||
| 
 | ||||
|  | @ -62,22 +86,22 @@ x509-parser = "0.15.1" | |||
| tracing = { version = "0.1.40" } | ||||
| 
 | ||||
| # http/3 | ||||
| # quinn = { version = "0.9.3", optional = true } | ||||
| quinn = { path = "../submodules/quinn/quinn", optional = true } # Tentative to support rustls-0.21 | ||||
| quinn = { version = "0.10.2", optional = true } | ||||
| h3 = { path = "../submodules/h3/h3/", optional = true } | ||||
| # h3-quinn = { path = "./h3/h3-quinn/", optional = true } | ||||
| h3-quinn = { path = "../submodules/h3-quinn/", optional = true } # Tentative to support rustls-0.21 | ||||
| # for UDP socket wit SO_REUSEADDR when h3 with quinn | ||||
| socket2 = { version = "0.5.5", features = ["all"], optional = true } | ||||
| s2n-quic = { path = "../submodules/s2n-quic/quic/s2n-quic/", default-features = false, features = [ | ||||
| h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } | ||||
| s2n-quic = { version = "1.32.0", default-features = false, features = [ | ||||
|   "provider-tls-rustls", | ||||
| ], optional = true } | ||||
| s2n-quic-h3 = { path = "../submodules/s2n-quic/quic/s2n-quic-h3/", optional = true } | ||||
| s2n-quic-rustls = { path = "../submodules/s2n-quic/quic/s2n-quic-rustls/", optional = true } | ||||
| s2n-quic-core = { version = "0.32.0", default-features = false, optional = true } | ||||
| s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } | ||||
| s2n-quic-rustls = { version = "0.32.0", optional = true } | ||||
| # for UDP socket wit SO_REUSEADDR when h3 with quinn | ||||
| socket2 = { version = "0.5.5", features = ["all"], optional = true } | ||||
| 
 | ||||
| # cache | ||||
| http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } | ||||
| lru = { version = "0.12.0", optional = true } | ||||
| lru = { version = "0.12.1", optional = true } | ||||
| sha2 = { version = "0.10.8", default-features = false, optional = true } | ||||
| 
 | ||||
| # cookie handling for sticky cookie | ||||
| chrono = { version = "0.4.31", default-features = false, features = [ | ||||
|  | @ -86,7 +110,7 @@ chrono = { version = "0.4.31", default-features = false, features = [ | |||
|   "clock", | ||||
| ], optional = true } | ||||
| base64 = { version = "0.21.5", optional = true } | ||||
| sha2 = { version = "0.10.8", default-features = false, optional = true } | ||||
| 
 | ||||
| 
 | ||||
| [dev-dependencies] | ||||
| tokio-test = "0.4.3" | ||||
|  |  | |||
							
								
								
									
										136
									
								
								rpxy-lib/src/backend/backend_main.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										136
									
								
								rpxy-lib/src/backend/backend_main.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,136 @@ | |||
| use crate::{ | ||||
|   crypto::CryptoSource, | ||||
|   error::*, | ||||
|   log::*, | ||||
|   name_exp::{ByteName, ServerName}, | ||||
|   AppConfig, AppConfigList, | ||||
| }; | ||||
| use derive_builder::Builder; | ||||
| use rustc_hash::FxHashMap as HashMap; | ||||
| use std::borrow::Cow; | ||||
| 
 | ||||
| use super::upstream::PathManager; | ||||
| 
 | ||||
| /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
 | ||||
| #[derive(Builder)] | ||||
| pub struct BackendApp<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   #[builder(setter(into))] | ||||
|   /// backend application name, e.g., app1
 | ||||
|   pub app_name: String, | ||||
|   #[builder(setter(custom))] | ||||
|   /// server name, e.g., example.com, in [[ServerName]] object
 | ||||
|   pub server_name: ServerName, | ||||
|   /// struct of reverse proxy serving incoming request
 | ||||
|   pub path_manager: PathManager, | ||||
|   /// tls settings: https redirection with 30x
 | ||||
|   #[builder(default)] | ||||
|   pub https_redirection: Option<bool>, | ||||
|   /// TLS settings: source meta for server cert, key, client ca cert
 | ||||
|   #[builder(default)] | ||||
|   pub crypto_source: Option<T>, | ||||
| } | ||||
| impl<'a, T> BackendAppBuilder<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.server_name = Some(server_name.to_server_name()); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// HashMap and some meta information for multiple Backend structs.
 | ||||
| pub struct BackendAppManager<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   /// HashMap of Backend structs, key is server name
 | ||||
|   pub apps: HashMap<ServerName, BackendApp<T>>, | ||||
|   /// for plaintext http
 | ||||
|   pub default_server_name: Option<ServerName>, | ||||
| } | ||||
| 
 | ||||
| impl<T> Default for BackendAppManager<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   fn default() -> Self { | ||||
|     Self { | ||||
|       apps: HashMap::<ServerName, BackendApp<T>>::default(), | ||||
|       default_server_name: None, | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl<T> TryFrom<&AppConfig<T>> for BackendApp<T> | ||||
| where | ||||
|   T: CryptoSource + Clone, | ||||
| { | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn try_from(app_config: &AppConfig<T>) -> Result<Self, Self::Error> { | ||||
|     let mut backend_builder = BackendAppBuilder::default(); | ||||
|     let path_manager = PathManager::try_from(app_config)?; | ||||
|     backend_builder | ||||
|       .app_name(app_config.app_name.clone()) | ||||
|       .server_name(app_config.server_name.clone()) | ||||
|       .path_manager(path_manager); | ||||
|     // TLS settings and build backend instance
 | ||||
|     let backend = if app_config.tls.is_none() { | ||||
|       backend_builder.build()? | ||||
|     } else { | ||||
|       let tls = app_config.tls.as_ref().unwrap(); | ||||
|       backend_builder | ||||
|         .https_redirection(Some(tls.https_redirection)) | ||||
|         .crypto_source(Some(tls.inner.clone())) | ||||
|         .build()? | ||||
|     }; | ||||
|     Ok(backend) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl<T> TryFrom<&AppConfigList<T>> for BackendAppManager<T> | ||||
| where | ||||
|   T: CryptoSource + Clone, | ||||
| { | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn try_from(config_list: &AppConfigList<T>) -> Result<Self, Self::Error> { | ||||
|     let mut manager = Self::default(); | ||||
|     for app_config in config_list.inner.iter() { | ||||
|       let backend: BackendApp<T> = BackendApp::try_from(app_config)?; | ||||
|       manager | ||||
|         .apps | ||||
|         .insert(app_config.server_name.clone().to_server_name(), backend); | ||||
| 
 | ||||
|       info!( | ||||
|         "Registering application {} ({})", | ||||
|         &app_config.server_name, &app_config.app_name | ||||
|       ); | ||||
|     } | ||||
| 
 | ||||
|     // default backend application for plaintext http requests
 | ||||
|     if let Some(default_app_name) = &config_list.default_app { | ||||
|       let default_server_name = manager | ||||
|         .apps | ||||
|         .iter() | ||||
|         .filter(|(_k, v)| &v.app_name == default_app_name) | ||||
|         .map(|(_, v)| v.server_name.clone()) | ||||
|         .collect::<Vec<_>>(); | ||||
| 
 | ||||
|       if !default_server_name.is_empty() { | ||||
|         info!( | ||||
|           "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", | ||||
|           &default_app_name, | ||||
|           (&default_server_name[0]).try_into().unwrap_or_else(|_| "".to_string()) | ||||
|         ); | ||||
| 
 | ||||
|         manager.default_server_name = Some(default_server_name[0].clone()); | ||||
|       } | ||||
|     } | ||||
|     Ok(manager) | ||||
|   } | ||||
| } | ||||
							
								
								
									
										135
									
								
								rpxy-lib/src/backend/load_balance/load_balance_main.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										135
									
								
								rpxy-lib/src/backend/load_balance/load_balance_main.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,135 @@ | |||
| #[cfg(feature = "sticky-cookie")] | ||||
| pub use super::{ | ||||
|   load_balance_sticky::{LoadBalanceSticky, LoadBalanceStickyBuilder}, | ||||
|   sticky_cookie::StickyCookie, | ||||
| }; | ||||
| use derive_builder::Builder; | ||||
| use rand::Rng; | ||||
| use std::sync::{ | ||||
|   atomic::{AtomicUsize, Ordering}, | ||||
|   Arc, | ||||
| }; | ||||
| 
 | ||||
| /// Constants to specify a load balance option
 | ||||
| pub mod load_balance_options { | ||||
|   pub const FIX_TO_FIRST: &str = "none"; | ||||
|   pub const ROUND_ROBIN: &str = "round_robin"; | ||||
|   pub const RANDOM: &str = "random"; | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   pub const STICKY_ROUND_ROBIN: &str = "sticky"; | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| /// Pointer to upstream serving the incoming request.
 | ||||
| /// If 'sticky cookie'-based LB is enabled and cookie must be updated/created, the new cookie is also given.
 | ||||
| pub struct PointerToUpstream { | ||||
|   pub ptr: usize, | ||||
|   pub context: Option<LoadBalanceContext>, | ||||
| } | ||||
| /// Trait for LB
 | ||||
| pub(super) trait LoadBalanceWithPointer { | ||||
|   fn get_ptr(&self, req_info: Option<&LoadBalanceContext>) -> PointerToUpstream; | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Round Robin LB object as a pointer to the current serving upstream destination
 | ||||
| pub struct LoadBalanceRoundRobin { | ||||
|   #[builder(default)] | ||||
|   /// Pointer to the index of the last served upstream destination
 | ||||
|   ptr: Arc<AtomicUsize>, | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Number of upstream destinations
 | ||||
|   num_upstreams: usize, | ||||
| } | ||||
| impl LoadBalanceRoundRobinBuilder { | ||||
|   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { | ||||
|     self.num_upstreams = Some(*v); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| impl LoadBalanceWithPointer for LoadBalanceRoundRobin { | ||||
|   /// Increment the count of upstream served up to the max value
 | ||||
|   fn get_ptr(&self, _info: Option<&LoadBalanceContext>) -> PointerToUpstream { | ||||
|     // Get a current count of upstream served
 | ||||
|     let current_ptr = self.ptr.load(Ordering::Relaxed); | ||||
| 
 | ||||
|     let ptr = if current_ptr < self.num_upstreams - 1 { | ||||
|       self.ptr.fetch_add(1, Ordering::Relaxed) | ||||
|     } else { | ||||
|       // Clear the counter
 | ||||
|       self.ptr.fetch_and(0, Ordering::Relaxed) | ||||
|     }; | ||||
|     PointerToUpstream { ptr, context: None } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Random LB object to keep the object of random pools
 | ||||
| pub struct LoadBalanceRandom { | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Number of upstream destinations
 | ||||
|   num_upstreams: usize, | ||||
| } | ||||
| impl LoadBalanceRandomBuilder { | ||||
|   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { | ||||
|     self.num_upstreams = Some(*v); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| impl LoadBalanceWithPointer for LoadBalanceRandom { | ||||
|   /// Returns the random index within the range
 | ||||
|   fn get_ptr(&self, _info: Option<&LoadBalanceContext>) -> PointerToUpstream { | ||||
|     let mut rng = rand::thread_rng(); | ||||
|     let ptr = rng.gen_range(0..self.num_upstreams); | ||||
|     PointerToUpstream { ptr, context: None } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| /// Load Balancing Option
 | ||||
| pub enum LoadBalance { | ||||
|   /// Fix to the first upstream. Use if only one upstream destination is specified
 | ||||
|   FixToFirst, | ||||
|   /// Randomly chose one upstream server
 | ||||
|   Random(LoadBalanceRandom), | ||||
|   /// Simple round robin without session persistance
 | ||||
|   RoundRobin(LoadBalanceRoundRobin), | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   /// Round robin with session persistance using cookie
 | ||||
|   StickyRoundRobin(LoadBalanceSticky), | ||||
| } | ||||
| impl Default for LoadBalance { | ||||
|   fn default() -> Self { | ||||
|     Self::FixToFirst | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl LoadBalance { | ||||
|   /// Get the index of the upstream serving the incoming request
 | ||||
|   pub fn get_context(&self, _context_to_lb: &Option<LoadBalanceContext>) -> PointerToUpstream { | ||||
|     match self { | ||||
|       LoadBalance::FixToFirst => PointerToUpstream { | ||||
|         ptr: 0usize, | ||||
|         context: None, | ||||
|       }, | ||||
|       LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None), | ||||
|       LoadBalance::Random(ptr) => ptr.get_ptr(None), | ||||
|       #[cfg(feature = "sticky-cookie")] | ||||
|       LoadBalance::StickyRoundRobin(ptr) => { | ||||
|         // Generate new context if sticky round robin is enabled.
 | ||||
|         ptr.get_ptr(_context_to_lb.as_ref()) | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| /// Struct to handle the sticky cookie string,
 | ||||
| /// - passed from Rp module (http handler) to LB module, manipulated from req, only StickyCookieValue exists.
 | ||||
| /// - passed from LB module to Rp module (http handler), will be inserted into res, StickyCookieValue and Info exist.
 | ||||
| pub struct LoadBalanceContext { | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   pub sticky_cookie: StickyCookie, | ||||
|   #[cfg(not(feature = "sticky-cookie"))] | ||||
|   pub sticky_cookie: (), | ||||
| } | ||||
							
								
								
									
										137
									
								
								rpxy-lib/src/backend/load_balance/load_balance_sticky.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										137
									
								
								rpxy-lib/src/backend/load_balance/load_balance_sticky.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,137 @@ | |||
| use super::{ | ||||
|   load_balance_main::{LoadBalanceContext, LoadBalanceWithPointer, PointerToUpstream}, | ||||
|   sticky_cookie::StickyCookieConfig, | ||||
|   Upstream, | ||||
| }; | ||||
| use crate::{constants::STICKY_COOKIE_NAME, log::*}; | ||||
| use derive_builder::Builder; | ||||
| use rustc_hash::FxHashMap as HashMap; | ||||
| use std::{ | ||||
|   borrow::Cow, | ||||
|   sync::{ | ||||
|     atomic::{AtomicUsize, Ordering}, | ||||
|     Arc, | ||||
|   }, | ||||
| }; | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Round Robin LB object in the sticky cookie manner
 | ||||
| pub struct LoadBalanceSticky { | ||||
|   #[builder(default)] | ||||
|   /// Pointer to the index of the last served upstream destination
 | ||||
|   ptr: Arc<AtomicUsize>, | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Number of upstream destinations
 | ||||
|   num_upstreams: usize, | ||||
|   #[builder(setter(custom))] | ||||
|   /// Information to build the cookie to stick clients to specific backends
 | ||||
|   pub sticky_config: StickyCookieConfig, | ||||
|   #[builder(setter(custom))] | ||||
|   /// Hashmaps:
 | ||||
|   /// - Hashmap that maps server indices to server id (string)
 | ||||
|   /// - Hashmap that maps server ids (string) to server indices, for fast reverse lookup
 | ||||
|   upstream_maps: UpstreamMap, | ||||
| } | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct UpstreamMap { | ||||
|   /// Hashmap that maps server indices to server id (string)
 | ||||
|   upstream_index_map: Vec<String>, | ||||
|   /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup
 | ||||
|   upstream_id_map: HashMap<String, usize>, | ||||
| } | ||||
| impl LoadBalanceStickyBuilder { | ||||
|   /// Set the number of upstream destinations
 | ||||
|   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { | ||||
|     self.num_upstreams = Some(*v); | ||||
|     self | ||||
|   } | ||||
|   /// Set the information to build the cookie to stick clients to specific backends
 | ||||
|   pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option<String>) -> &mut Self { | ||||
|     self.sticky_config = Some(StickyCookieConfig { | ||||
|       name: STICKY_COOKIE_NAME.to_string(), // TODO: config等で変更できるように
 | ||||
|       domain: server_name.to_ascii_lowercase(), | ||||
|       path: if let Some(v) = path_opt { | ||||
|         v.to_ascii_lowercase() | ||||
|       } else { | ||||
|         "/".to_string() | ||||
|       }, | ||||
|       duration: 300, // TODO: config等で変更できるように
 | ||||
|     }); | ||||
|     self | ||||
|   } | ||||
|   /// Set the hashmaps: upstream_index_map and upstream_id_map
 | ||||
|   pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self { | ||||
|     let upstream_index_map: Vec<String> = upstream_vec | ||||
|       .iter() | ||||
|       .enumerate() | ||||
|       .map(|(i, v)| v.calculate_id_with_index(i)) | ||||
|       .collect(); | ||||
|     let mut upstream_id_map = HashMap::default(); | ||||
|     for (i, v) in upstream_index_map.iter().enumerate() { | ||||
|       upstream_id_map.insert(v.to_string(), i); | ||||
|     } | ||||
|     self.upstream_maps = Some(UpstreamMap { | ||||
|       upstream_index_map, | ||||
|       upstream_id_map, | ||||
|     }); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| impl<'a> LoadBalanceSticky { | ||||
|   /// Increment the count of upstream served up to the max value
 | ||||
|   fn simple_increment_ptr(&self) -> usize { | ||||
|     // Get a current count of upstream served
 | ||||
|     let current_ptr = self.ptr.load(Ordering::Relaxed); | ||||
| 
 | ||||
|     if current_ptr < self.num_upstreams - 1 { | ||||
|       self.ptr.fetch_add(1, Ordering::Relaxed) | ||||
|     } else { | ||||
|       // Clear the counter
 | ||||
|       self.ptr.fetch_and(0, Ordering::Relaxed) | ||||
|     } | ||||
|   } | ||||
|   /// This is always called only internally. So 'unwrap()' is executed.
 | ||||
|   fn get_server_id_from_index(&self, index: usize) -> String { | ||||
|     self.upstream_maps.upstream_index_map.get(index).unwrap().to_owned() | ||||
|   } | ||||
|   /// This function takes value passed from outside. So 'result' is used.
 | ||||
|   fn get_server_index_from_id(&self, id: impl Into<Cow<'a, str>>) -> Option<usize> { | ||||
|     let id_str = id.into().to_string(); | ||||
|     self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned()) | ||||
|   } | ||||
| } | ||||
| impl LoadBalanceWithPointer for LoadBalanceSticky { | ||||
|   /// Get the pointer to the upstream server to serve the incoming request.
 | ||||
|   fn get_ptr(&self, req_info: Option<&LoadBalanceContext>) -> PointerToUpstream { | ||||
|     // If given context is None or invalid (not contained), get_ptr() is invoked to increment the pointer.
 | ||||
|     // Otherwise, get the server index indicated by the server_id inside the cookie
 | ||||
|     let ptr = match req_info { | ||||
|       None => { | ||||
|         debug!("No sticky cookie"); | ||||
|         self.simple_increment_ptr() | ||||
|       } | ||||
|       Some(context) => { | ||||
|         let server_id = &context.sticky_cookie.value.value; | ||||
|         if let Some(server_index) = self.get_server_index_from_id(server_id) { | ||||
|           debug!("Valid sticky cookie: id={}, index={}", server_id, server_index); | ||||
|           server_index | ||||
|         } else { | ||||
|           debug!("Invalid sticky cookie: id={}", server_id); | ||||
|           self.simple_increment_ptr() | ||||
|         } | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     // Get the server id from the ptr.
 | ||||
|     // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie).
 | ||||
|     let upstream_id = self.get_server_id_from_index(ptr); | ||||
|     let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap(); | ||||
|     let new_context = Some(LoadBalanceContext { | ||||
|       sticky_cookie: new_cookie, | ||||
|     }); | ||||
|     PointerToUpstream { | ||||
|       ptr, | ||||
|       context: new_context, | ||||
|     } | ||||
|   } | ||||
| } | ||||
							
								
								
									
										43
									
								
								rpxy-lib/src/backend/load_balance/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								rpxy-lib/src/backend/load_balance/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,43 @@ | |||
| mod load_balance_main; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| mod load_balance_sticky; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| mod sticky_cookie; | ||||
| 
 | ||||
| use super::upstream::Upstream; | ||||
| use thiserror::Error; | ||||
| 
 | ||||
| pub use load_balance_main::{ | ||||
|   load_balance_options, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder, | ||||
| }; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| pub use load_balance_sticky::LoadBalanceStickyBuilder; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| pub use sticky_cookie::{StickyCookie, StickyCookieValue}; | ||||
| 
 | ||||
| /// Result type for load balancing
 | ||||
| type LoadBalanceResult<T> = std::result::Result<T, LoadBalanceError>; | ||||
| /// Describes things that can go wrong in the Load Balance
 | ||||
| #[derive(Debug, Error)] | ||||
| pub enum LoadBalanceError { | ||||
|   // backend load balance errors
 | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   #[error("Failed to cookie conversion to/from string")] | ||||
|   FailedToConversionStickyCookie, | ||||
| 
 | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   #[error("Invalid cookie structure")] | ||||
|   InvalidStickyCookieStructure, | ||||
| 
 | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   #[error("No sticky cookie value")] | ||||
|   NoStickyCookieValue, | ||||
| 
 | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   #[error("Failed to cookie conversion into string: no meta information")] | ||||
|   NoStickyCookieNoMetaInfo, | ||||
| 
 | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   #[error("Failed to build sticky cookie from config")] | ||||
|   FailedToBuildStickyCookie, | ||||
| } | ||||
							
								
								
									
										205
									
								
								rpxy-lib/src/backend/load_balance/sticky_cookie.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										205
									
								
								rpxy-lib/src/backend/load_balance/sticky_cookie.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,205 @@ | |||
| use super::{LoadBalanceError, LoadBalanceResult}; | ||||
| use chrono::{TimeZone, Utc}; | ||||
| use derive_builder::Builder; | ||||
| use std::borrow::Cow; | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Cookie value only, used for COOKIE in req
 | ||||
| pub struct StickyCookieValue { | ||||
|   #[builder(setter(custom))] | ||||
|   /// Field name indicating sticky cookie
 | ||||
|   pub name: String, | ||||
|   #[builder(setter(custom))] | ||||
|   /// Upstream server_id
 | ||||
|   pub value: String, | ||||
| } | ||||
| impl<'a> StickyCookieValueBuilder { | ||||
|   pub fn name(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.name = Some(v.into().to_ascii_lowercase()); | ||||
|     self | ||||
|   } | ||||
|   pub fn value(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.value = Some(v.into().to_string()); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| impl StickyCookieValue { | ||||
|   pub fn try_from(value: &str, expected_name: &str) -> LoadBalanceResult<Self> { | ||||
|     if !value.starts_with(expected_name) { | ||||
|       return Err(LoadBalanceError::FailedToConversionStickyCookie); | ||||
|     }; | ||||
|     let kv = value.split('=').map(|v| v.trim()).collect::<Vec<&str>>(); | ||||
|     if kv.len() != 2 { | ||||
|       return Err(LoadBalanceError::InvalidStickyCookieStructure); | ||||
|     }; | ||||
|     if kv[1].is_empty() { | ||||
|       return Err(LoadBalanceError::NoStickyCookieValue); | ||||
|     } | ||||
|     Ok(StickyCookieValue { | ||||
|       name: expected_name.to_string(), | ||||
|       value: kv[1].to_string(), | ||||
|     }) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Struct describing sticky cookie meta information used for SET-COOKIE in res
 | ||||
| pub struct StickyCookieInfo { | ||||
|   #[builder(setter(custom))] | ||||
|   /// Unix time
 | ||||
|   pub expires: i64, | ||||
| 
 | ||||
|   #[builder(setter(custom))] | ||||
|   /// Domain
 | ||||
|   pub domain: String, | ||||
| 
 | ||||
|   #[builder(setter(custom))] | ||||
|   /// Path
 | ||||
|   pub path: String, | ||||
| } | ||||
| impl<'a> StickyCookieInfoBuilder { | ||||
|   pub fn domain(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.domain = Some(v.into().to_ascii_lowercase()); | ||||
|     self | ||||
|   } | ||||
|   pub fn path(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.path = Some(v.into().to_ascii_lowercase()); | ||||
|     self | ||||
|   } | ||||
|   pub fn expires(&mut self, duration_secs: i64) -> &mut Self { | ||||
|     let current = Utc::now().timestamp(); | ||||
|     self.expires = Some(current + duration_secs); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Struct describing sticky cookie
 | ||||
| pub struct StickyCookie { | ||||
|   #[builder(setter(custom))] | ||||
|   /// Upstream server_id
 | ||||
|   pub value: StickyCookieValue, | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Upstream server_id
 | ||||
|   pub info: Option<StickyCookieInfo>, | ||||
| } | ||||
| 
 | ||||
| impl<'a> StickyCookieBuilder { | ||||
|   /// Set the value of sticky cookie
 | ||||
|   pub fn value(&mut self, n: impl Into<Cow<'a, str>>, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap()); | ||||
|     self | ||||
|   } | ||||
|   /// Set the meta information of sticky cookie
 | ||||
|   pub fn info( | ||||
|     &mut self, | ||||
|     domain: impl Into<Cow<'a, str>>, | ||||
|     path: impl Into<Cow<'a, str>>, | ||||
|     duration_secs: i64, | ||||
|   ) -> &mut Self { | ||||
|     let info = StickyCookieInfoBuilder::default() | ||||
|       .domain(domain) | ||||
|       .path(path) | ||||
|       .expires(duration_secs) | ||||
|       .build() | ||||
|       .unwrap(); | ||||
|     self.info = Some(Some(info)); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl TryInto<String> for StickyCookie { | ||||
|   type Error = LoadBalanceError; | ||||
| 
 | ||||
|   fn try_into(self) -> LoadBalanceResult<String> { | ||||
|     if self.info.is_none() { | ||||
|       return Err(LoadBalanceError::NoStickyCookieNoMetaInfo); | ||||
|     } | ||||
|     let info = self.info.unwrap(); | ||||
|     let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else { | ||||
|       return Err(LoadBalanceError::FailedToConversionStickyCookie); | ||||
|     }; | ||||
|     let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string(); | ||||
|     let max_age = info.expires - Utc::now().timestamp(); | ||||
| 
 | ||||
|     Ok(format!( | ||||
|       "{}={}; expires={}; Max-Age={}; path={}; domain={}", | ||||
|       self.value.name, self.value.value, exp_str, max_age, info.path, info.domain | ||||
|     )) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| /// Configuration to serve incoming requests in the manner of "sticky cookie".
 | ||||
| /// Including a dictionary to map Ids included in cookie and upstream destinations,
 | ||||
| /// and expiration of cookie.
 | ||||
| /// "domain" and "path" in the cookie will be the same as the reverse proxy options.
 | ||||
| pub struct StickyCookieConfig { | ||||
|   pub name: String, | ||||
|   pub domain: String, | ||||
|   pub path: String, | ||||
|   pub duration: i64, | ||||
| } | ||||
| impl<'a> StickyCookieConfig { | ||||
|   pub fn build_sticky_cookie(&self, v: impl Into<Cow<'a, str>>) -> LoadBalanceResult<StickyCookie> { | ||||
|     StickyCookieBuilder::default() | ||||
|       .value(self.name.clone(), v) | ||||
|       .info(&self.domain, &self.path, self.duration) | ||||
|       .build() | ||||
|       .map_err(|_| LoadBalanceError::FailedToBuildStickyCookie) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|   use super::*; | ||||
|   use crate::constants::STICKY_COOKIE_NAME; | ||||
| 
 | ||||
|   #[test] | ||||
|   fn config_works() { | ||||
|     let config = StickyCookieConfig { | ||||
|       name: STICKY_COOKIE_NAME.to_string(), | ||||
|       domain: "example.com".to_string(), | ||||
|       path: "/path".to_string(), | ||||
|       duration: 100, | ||||
|     }; | ||||
|     let expires_unix = Utc::now().timestamp() + 100; | ||||
|     let sc_string: LoadBalanceResult<String> = config.build_sticky_cookie("test_value").unwrap().try_into(); | ||||
|     let expires_date_string = Utc | ||||
|       .timestamp_opt(expires_unix, 0) | ||||
|       .unwrap() | ||||
|       .format("%a, %d-%b-%Y %T GMT") | ||||
|       .to_string(); | ||||
|     assert_eq!( | ||||
|       sc_string.unwrap(), | ||||
|       format!( | ||||
|         "{}=test_value; expires={}; Max-Age={}; path=/path; domain=example.com", | ||||
|         STICKY_COOKIE_NAME, expires_date_string, 100 | ||||
|       ) | ||||
|     ); | ||||
|   } | ||||
|   #[test] | ||||
|   fn to_string_works() { | ||||
|     let sc = StickyCookie { | ||||
|       value: StickyCookieValue { | ||||
|         name: STICKY_COOKIE_NAME.to_string(), | ||||
|         value: "test_value".to_string(), | ||||
|       }, | ||||
|       info: Some(StickyCookieInfo { | ||||
|         expires: 1686221173i64, | ||||
|         domain: "example.com".to_string(), | ||||
|         path: "/path".to_string(), | ||||
|       }), | ||||
|     }; | ||||
|     let sc_string: LoadBalanceResult<String> = sc.try_into(); | ||||
|     let max_age = 1686221173i64 - Utc::now().timestamp(); | ||||
|     assert!(sc_string.is_ok()); | ||||
|     assert_eq!( | ||||
|       sc_string.unwrap(), | ||||
|       format!( | ||||
|         "{}=test_value; expires=Thu, 08-Jun-2023 10:46:13 GMT; Max-Age={}; path=/path; domain=example.com", | ||||
|         STICKY_COOKIE_NAME, max_age | ||||
|       ) | ||||
|     ); | ||||
|   } | ||||
| } | ||||
|  | @ -1,77 +1,14 @@ | |||
| mod backend_main; | ||||
| mod load_balance; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| mod load_balance_sticky; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| mod sticky_cookie; | ||||
| mod upstream; | ||||
| mod upstream_opts; | ||||
| 
 | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| pub use self::sticky_cookie::{StickyCookie, StickyCookieValue}; | ||||
| pub use self::{ | ||||
|   load_balance::{LbContext, LoadBalance}, | ||||
|   upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, | ||||
| pub(crate) use self::load_balance::{StickyCookie, StickyCookieValue}; | ||||
| #[allow(unused)] | ||||
| pub(crate) use self::{ | ||||
|   load_balance::{LoadBalance, LoadBalanceContext}, | ||||
|   upstream::{PathManager, Upstream, UpstreamCandidates}, | ||||
|   upstream_opts::UpstreamOption, | ||||
| }; | ||||
| use crate::{ | ||||
|   certs::CryptoSource, | ||||
|   utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, | ||||
| }; | ||||
| use derive_builder::Builder; | ||||
| use rustc_hash::FxHashMap as HashMap; | ||||
| use std::borrow::Cow; | ||||
| 
 | ||||
| /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
 | ||||
| #[derive(Builder)] | ||||
| pub struct Backend<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   #[builder(setter(into))] | ||||
|   /// backend application name, e.g., app1
 | ||||
|   pub app_name: String, | ||||
|   #[builder(setter(custom))] | ||||
|   /// server name, e.g., example.com, in String ascii lower case
 | ||||
|   pub server_name: String, | ||||
|   /// struct of reverse proxy serving incoming request
 | ||||
|   pub reverse_proxy: ReverseProxy, | ||||
| 
 | ||||
|   /// tls settings: https redirection with 30x
 | ||||
|   #[builder(default)] | ||||
|   pub https_redirection: Option<bool>, | ||||
| 
 | ||||
|   /// TLS settings: source meta for server cert, key, client ca cert
 | ||||
|   #[builder(default)] | ||||
|   pub crypto_source: Option<T>, | ||||
| } | ||||
| impl<'a, T> BackendBuilder<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self { | ||||
|     self.server_name = Some(server_name.into().to_ascii_lowercase()); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// HashMap and some meta information for multiple Backend structs.
 | ||||
| pub struct Backends<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   pub apps: HashMap<ServerNameBytesExp, Backend<T>>, // hyper::uriで抜いたhostで引っ掛ける
 | ||||
|   pub default_server_name_bytes: Option<ServerNameBytesExp>, // for plaintext http
 | ||||
| } | ||||
| 
 | ||||
| impl<T> Backends<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   #[allow(clippy::new_without_default)] | ||||
|   pub fn new() -> Self { | ||||
|     Backends { | ||||
|       apps: HashMap::<ServerNameBytesExp, Backend<T>>::default(), | ||||
|       default_server_name_bytes: None, | ||||
|     } | ||||
|   } | ||||
| } | ||||
| pub(crate) use backend_main::{BackendApp, BackendAppBuilderError, BackendAppManager}; | ||||
|  |  | |||
|  | @ -1,8 +1,18 @@ | |||
| #[cfg(feature = "sticky-cookie")] | ||||
| use super::load_balance::LbStickyRoundRobinBuilder; | ||||
| use super::load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance}; | ||||
| use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption}; | ||||
| use crate::log::*; | ||||
| use super::load_balance::LoadBalanceStickyBuilder; | ||||
| use super::load_balance::{ | ||||
|   load_balance_options as lb_opts, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, | ||||
|   LoadBalanceRoundRobinBuilder, | ||||
| }; | ||||
| // use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption};
 | ||||
| use super::upstream_opts::UpstreamOption; | ||||
| use crate::{ | ||||
|   crypto::CryptoSource, | ||||
|   error::RpxyError, | ||||
|   globals::{AppConfig, UpstreamUri}, | ||||
|   log::*, | ||||
|   name_exp::{ByteName, PathName}, | ||||
| }; | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| use base64::{engine::general_purpose, Engine as _}; | ||||
| use derive_builder::Builder; | ||||
|  | @ -10,26 +20,68 @@ use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; | |||
| #[cfg(feature = "sticky-cookie")] | ||||
| use sha2::{Digest, Sha256}; | ||||
| use std::borrow::Cow; | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct ReverseProxy { | ||||
|   pub upstream: HashMap<PathNameBytesExp, UpstreamGroup>, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。
 | ||||
| /// Handler for given path to route incoming request to path's corresponding upstream server(s).
 | ||||
| pub struct PathManager { | ||||
|   /// HashMap of upstream candidate server info, key is path name
 | ||||
|   /// TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。
 | ||||
|   inner: HashMap<PathName, UpstreamCandidates>, | ||||
| } | ||||
| 
 | ||||
| impl ReverseProxy { | ||||
|   /// Get an appropriate upstream destination for given path string.
 | ||||
|   pub fn get<'a>(&self, path_str: impl Into<Cow<'a, str>>) -> Option<&UpstreamGroup> { | ||||
|     // trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、
 | ||||
|     // コスト的にこの程度で十分
 | ||||
|     let path_bytes = &path_str.to_path_name_vec(); | ||||
| impl<T> TryFrom<&AppConfig<T>> for PathManager | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   type Error = RpxyError; | ||||
|   fn try_from(app_config: &AppConfig<T>) -> Result<Self, Self::Error> { | ||||
|     let mut inner: HashMap<PathName, UpstreamCandidates> = HashMap::default(); | ||||
| 
 | ||||
|     app_config.reverse_proxy.iter().for_each(|rpc| { | ||||
|       let upstream_vec: Vec<Upstream> = rpc.upstream.iter().map(Upstream::from).collect(); | ||||
|       let elem = UpstreamCandidatesBuilder::default() | ||||
|         .upstream(&upstream_vec) | ||||
|         .path(&rpc.path) | ||||
|         .replace_path(&rpc.replace_path) | ||||
|         .load_balance(&rpc.load_balance, &upstream_vec, &app_config.server_name, &rpc.path) | ||||
|         .options(&rpc.upstream_options) | ||||
|         .build() | ||||
|         .unwrap(); | ||||
|       inner.insert(elem.path.clone(), elem); | ||||
|     }); | ||||
| 
 | ||||
|     if app_config.reverse_proxy.iter().filter(|rpc| rpc.path.is_none()).count() >= 2 { | ||||
|       error!("Multiple default reverse proxy setting"); | ||||
|       return Err(RpxyError::InvalidReverseProxyConfig); | ||||
|     } | ||||
| 
 | ||||
|     if !(inner.iter().all(|(_, elem)| { | ||||
|       !(elem.options.contains(&UpstreamOption::ForceHttp11Upstream) | ||||
|         && elem.options.contains(&UpstreamOption::ForceHttp2Upstream)) | ||||
|     })) { | ||||
|       error!("Either one of force_http11 or force_http2 can be enabled"); | ||||
|       return Err(RpxyError::InvalidUpstreamOptionSetting); | ||||
|     } | ||||
| 
 | ||||
|     Ok(PathManager { inner }) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl PathManager { | ||||
|   /// Get an appropriate upstream destinations for given path string.
 | ||||
|   /// trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、
 | ||||
|   /// コスト的にこの程度で十分では。
 | ||||
|   pub fn get<'a>(&self, path_str: impl Into<Cow<'a, str>>) -> Option<&UpstreamCandidates> { | ||||
|     let path_name = &path_str.to_path_name(); | ||||
| 
 | ||||
|     let matched_upstream = self | ||||
|       .upstream | ||||
|       .inner | ||||
|       .iter() | ||||
|       .filter(|(route_bytes, _)| { | ||||
|         match path_bytes.starts_with(route_bytes) { | ||||
|         match path_name.starts_with(route_bytes) { | ||||
|           true => { | ||||
|             route_bytes.len() == 1 // route = '/', i.e., default
 | ||||
|             || match path_bytes.get(route_bytes.len()) { | ||||
|               || match path_name.get(route_bytes.len()) { | ||||
|                 None => true, // exact case
 | ||||
|                 Some(p) => p == &b'/', // sub-path case
 | ||||
|               } | ||||
|  | @ -38,10 +90,10 @@ impl ReverseProxy { | |||
|         } | ||||
|       }) | ||||
|       .max_by_key(|(route_bytes, _)| route_bytes.len()); | ||||
|     if let Some((_path, u)) = matched_upstream { | ||||
|     if let Some((path, u)) = matched_upstream { | ||||
|       debug!( | ||||
|         "Found upstream: {:?}", | ||||
|         String::from_utf8(_path.0.clone()).unwrap_or_else(|_| "<none>".to_string()) | ||||
|         path.try_into().unwrap_or_else(|_| "<none>".to_string()) | ||||
|       ); | ||||
|       Some(u) | ||||
|     } else { | ||||
|  | @ -56,6 +108,13 @@ pub struct Upstream { | |||
|   /// Base uri without specific path
 | ||||
|   pub uri: hyper::Uri, | ||||
| } | ||||
| impl From<&UpstreamUri> for Upstream { | ||||
|   fn from(value: &UpstreamUri) -> Self { | ||||
|     Self { | ||||
|       uri: value.inner.clone(), | ||||
|     } | ||||
|   } | ||||
| } | ||||
| impl Upstream { | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   /// Hashing uri with index to avoid collision
 | ||||
|  | @ -69,47 +128,50 @@ impl Upstream { | |||
| } | ||||
| #[derive(Debug, Clone, Builder)] | ||||
| /// Struct serving multiple upstream servers for, e.g., load balancing.
 | ||||
| pub struct UpstreamGroup { | ||||
| pub struct UpstreamCandidates { | ||||
|   #[builder(setter(custom))] | ||||
|   /// Upstream server(s)
 | ||||
|   pub upstream: Vec<Upstream>, | ||||
|   pub inner: Vec<Upstream>, | ||||
| 
 | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Path like "/path" in [[PathNameBytesExp]] associated with the upstream server(s)
 | ||||
|   pub path: PathNameBytesExp, | ||||
|   /// Path like "/path" in [[PathName]] associated with the upstream server(s)
 | ||||
|   pub path: PathName, | ||||
| 
 | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Path in [[PathNameBytesExp]] that will be used to replace the "path" part of incoming url
 | ||||
|   pub replace_path: Option<PathNameBytesExp>, | ||||
|   /// Path in [[PathName]] that will be used to replace the "path" part of incoming url
 | ||||
|   pub replace_path: Option<PathName>, | ||||
| 
 | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Load balancing option
 | ||||
|   pub lb: LoadBalance, | ||||
|   pub load_balance: LoadBalance, | ||||
| 
 | ||||
|   #[builder(setter(custom), default)] | ||||
|   /// Activated upstream options defined in [[UpstreamOption]]
 | ||||
|   pub opts: HashSet<UpstreamOption>, | ||||
|   pub options: HashSet<UpstreamOption>, | ||||
| } | ||||
| 
 | ||||
| impl UpstreamGroupBuilder { | ||||
| impl UpstreamCandidatesBuilder { | ||||
|   /// Set the upstream server(s)
 | ||||
|   pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self { | ||||
|     self.upstream = Some(upstream_vec.to_vec()); | ||||
|     self.inner = Some(upstream_vec.to_vec()); | ||||
|     self | ||||
|   } | ||||
|   /// Set the path like "/path" in [[PathName]] associated with the upstream server(s), default is "/"
 | ||||
|   pub fn path(&mut self, v: &Option<String>) -> &mut Self { | ||||
|     let path = match v { | ||||
|       Some(p) => p.to_path_name_vec(), | ||||
|       None => "/".to_path_name_vec(), | ||||
|       Some(p) => p.to_path_name(), | ||||
|       None => "/".to_path_name(), | ||||
|     }; | ||||
|     self.path = Some(path); | ||||
|     self | ||||
|   } | ||||
|   /// Set the path in [[PathName]] that will be used to replace the "path" part of incoming url
 | ||||
|   pub fn replace_path(&mut self, v: &Option<String>) -> &mut Self { | ||||
|     self.replace_path = Some( | ||||
|       v.to_owned() | ||||
|         .as_ref() | ||||
|         .map_or_else(|| None, |v| Some(v.to_path_name_vec())), | ||||
|     ); | ||||
|     self.replace_path = Some(v.to_owned().as_ref().map_or_else(|| None, |v| Some(v.to_path_name()))); | ||||
|     self | ||||
|   } | ||||
|   pub fn lb( | ||||
|   /// Set the load balancing option
 | ||||
|   pub fn load_balance( | ||||
|     &mut self, | ||||
|     v: &Option<String>, | ||||
|     // upstream_num: &usize,
 | ||||
|  | @ -121,16 +183,21 @@ impl UpstreamGroupBuilder { | |||
|     let lb = if let Some(x) = v { | ||||
|       match x.as_str() { | ||||
|         lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, | ||||
|         lb_opts::RANDOM => LoadBalance::Random(LbRandomBuilder::default().num_upstreams(upstream_num).build().unwrap()), | ||||
|         lb_opts::RANDOM => LoadBalance::Random( | ||||
|           LoadBalanceRandomBuilder::default() | ||||
|             .num_upstreams(upstream_num) | ||||
|             .build() | ||||
|             .unwrap(), | ||||
|         ), | ||||
|         lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin( | ||||
|           LbRoundRobinBuilder::default() | ||||
|           LoadBalanceRoundRobinBuilder::default() | ||||
|             .num_upstreams(upstream_num) | ||||
|             .build() | ||||
|             .unwrap(), | ||||
|         ), | ||||
|         #[cfg(feature = "sticky-cookie")] | ||||
|         lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( | ||||
|           LbStickyRoundRobinBuilder::default() | ||||
|           LoadBalanceStickyBuilder::default() | ||||
|             .num_upstreams(upstream_num) | ||||
|             .sticky_config(_server_name, _path_opt) | ||||
|             .upstream_maps(upstream_vec) // TODO:
 | ||||
|  | @ -145,10 +212,11 @@ impl UpstreamGroupBuilder { | |||
|     } else { | ||||
|       LoadBalance::default() | ||||
|     }; | ||||
|     self.lb = Some(lb); | ||||
|     self.load_balance = Some(lb); | ||||
|     self | ||||
|   } | ||||
|   pub fn opts(&mut self, v: &Option<Vec<String>>) -> &mut Self { | ||||
|   /// Set the activated upstream options defined in [[UpstreamOption]]
 | ||||
|   pub fn options(&mut self, v: &Option<Vec<String>>) -> &mut Self { | ||||
|     let opts = if let Some(opts) = v { | ||||
|       opts | ||||
|         .iter() | ||||
|  | @ -157,25 +225,22 @@ impl UpstreamGroupBuilder { | |||
|     } else { | ||||
|       Default::default() | ||||
|     }; | ||||
|     self.opts = Some(opts); | ||||
|     self.options = Some(opts); | ||||
|     self | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl UpstreamGroup { | ||||
| impl UpstreamCandidates { | ||||
|   /// Get an enabled option of load balancing [[LoadBalance]]
 | ||||
|   pub fn get(&self, context_to_lb: &Option<LbContext>) -> (Option<&Upstream>, Option<LbContext>) { | ||||
|     let pointer_to_upstream = self.lb.get_context(context_to_lb); | ||||
|   pub fn get(&self, context_to_lb: &Option<LoadBalanceContext>) -> (Option<&Upstream>, Option<LoadBalanceContext>) { | ||||
|     let pointer_to_upstream = self.load_balance.get_context(context_to_lb); | ||||
|     debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); | ||||
|     debug!("Context to LB (Cookie in Req): {:?}", context_to_lb); | ||||
|     debug!("Context to LB (Cookie in Request): {:?}", context_to_lb); | ||||
|     debug!( | ||||
|       "Context from LB (Set-Cookie in Res): {:?}", | ||||
|       pointer_to_upstream.context_lb | ||||
|       "Context from LB (Set-Cookie in Response): {:?}", | ||||
|       pointer_to_upstream.context | ||||
|     ); | ||||
|     ( | ||||
|       self.upstream.get(pointer_to_upstream.ptr), | ||||
|       pointer_to_upstream.context_lb, | ||||
|     ) | ||||
|     (self.inner.get(pointer_to_upstream.ptr), pointer_to_upstream.context) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
|  |  | |||
|  | @ -2,7 +2,7 @@ use crate::error::*; | |||
| 
 | ||||
| #[derive(Debug, Clone, Hash, Eq, PartialEq)] | ||||
| pub enum UpstreamOption { | ||||
|   OverrideHost, | ||||
|   KeepOriginalHost, | ||||
|   UpgradeInsecureRequests, | ||||
|   ForceHttp11Upstream, | ||||
|   ForceHttp2Upstream, | ||||
|  | @ -10,13 +10,13 @@ pub enum UpstreamOption { | |||
| } | ||||
| impl TryFrom<&str> for UpstreamOption { | ||||
|   type Error = RpxyError; | ||||
|   fn try_from(val: &str) -> Result<Self> { | ||||
|   fn try_from(val: &str) -> RpxyResult<Self> { | ||||
|     match val { | ||||
|       "override_host" => Ok(Self::OverrideHost), | ||||
|       "keep_original_host" => Ok(Self::KeepOriginalHost), | ||||
|       "upgrade_insecure_requests" => Ok(Self::UpgradeInsecureRequests), | ||||
|       "force_http11_upstream" => Ok(Self::ForceHttp11Upstream), | ||||
|       "force_http2_upstream" => Ok(Self::ForceHttp2Upstream), | ||||
|       _ => Err(RpxyError::Other(anyhow!("Unsupported header option"))), | ||||
|       _ => Err(RpxyError::UnsupportedUpstreamOption), | ||||
|     } | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -4,14 +4,16 @@ pub const RESPONSE_HEADER_SERVER: &str = "rpxy"; | |||
| pub const TCP_LISTEN_BACKLOG: u32 = 1024; | ||||
| // pub const HTTP_LISTEN_PORT: u16 = 8080;
 | ||||
| // pub const HTTPS_LISTEN_PORT: u16 = 8443;
 | ||||
| pub const PROXY_TIMEOUT_SEC: u64 = 60; | ||||
| pub const UPSTREAM_TIMEOUT_SEC: u64 = 60; | ||||
| pub const PROXY_IDLE_TIMEOUT_SEC: u64 = 20; | ||||
| pub const UPSTREAM_IDLE_TIMEOUT_SEC: u64 = 20; | ||||
| pub const TLS_HANDSHAKE_TIMEOUT_SEC: u64 = 15; // default as with firefox browser
 | ||||
| pub const MAX_CLIENTS: usize = 512; | ||||
| pub const MAX_CONCURRENT_STREAMS: u32 = 64; | ||||
| pub const CERTS_WATCH_DELAY_SECS: u32 = 60; | ||||
| pub const LOAD_CERTS_ONLY_WHEN_UPDATED: bool = true; | ||||
| 
 | ||||
| pub const CONNECTION_TIMEOUT_SEC: u64 = 30; // timeout to serve a connection. this might limits the max length of response.
 | ||||
| 
 | ||||
| // #[cfg(feature = "http3")]
 | ||||
| // pub const H3_RESPONSE_BUF_SIZE: usize = 65_536; // 64KB
 | ||||
| // #[cfg(feature = "http3")]
 | ||||
|  |  | |||
							
								
								
									
										31
									
								
								rpxy-lib/src/count.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								rpxy-lib/src/count.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,31 @@ | |||
| use std::sync::{ | ||||
|   atomic::{AtomicUsize, Ordering}, | ||||
|   Arc, | ||||
| }; | ||||
| 
 | ||||
| #[derive(Debug, Clone, Default)] | ||||
| /// Counter for serving requests
 | ||||
| pub struct RequestCount(Arc<AtomicUsize>); | ||||
| 
 | ||||
| impl RequestCount { | ||||
|   pub fn current(&self) -> usize { | ||||
|     self.0.load(Ordering::Relaxed) | ||||
|   } | ||||
| 
 | ||||
|   pub fn increment(&self) -> usize { | ||||
|     self.0.fetch_add(1, Ordering::Relaxed) | ||||
|   } | ||||
| 
 | ||||
|   pub fn decrement(&self) -> usize { | ||||
|     let mut count; | ||||
|     while { | ||||
|       count = self.0.load(Ordering::Relaxed); | ||||
|       count > 0 | ||||
|         && self | ||||
|           .0 | ||||
|           .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) | ||||
|           != Ok(count) | ||||
|     } {} | ||||
|     count | ||||
|   } | ||||
| } | ||||
							
								
								
									
										91
									
								
								rpxy-lib/src/crypto/certs.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								rpxy-lib/src/crypto/certs.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,91 @@ | |||
| use async_trait::async_trait; | ||||
| use rustc_hash::FxHashSet as HashSet; | ||||
| use rustls::{ | ||||
|   sign::{any_supported_type, CertifiedKey}, | ||||
|   Certificate, OwnedTrustAnchor, PrivateKey, | ||||
| }; | ||||
| use std::io; | ||||
| use x509_parser::prelude::*; | ||||
| 
 | ||||
| #[async_trait] | ||||
| // Trait to read certs and keys anywhere from KVS, file, sqlite, etc.
 | ||||
| pub trait CryptoSource { | ||||
|   type Error; | ||||
| 
 | ||||
|   /// read crypto materials from source
 | ||||
|   async fn read(&self) -> Result<CertsAndKeys, Self::Error>; | ||||
| 
 | ||||
|   /// Returns true when mutual tls is enabled
 | ||||
|   fn is_mutual_tls(&self) -> bool; | ||||
| } | ||||
| 
 | ||||
| /// Certificates and private keys in rustls loaded from files
 | ||||
| #[derive(Debug, PartialEq, Eq, Clone)] | ||||
| pub struct CertsAndKeys { | ||||
|   pub certs: Vec<Certificate>, | ||||
|   pub cert_keys: Vec<PrivateKey>, | ||||
|   pub client_ca_certs: Option<Vec<Certificate>>, | ||||
| } | ||||
| 
 | ||||
| impl CertsAndKeys { | ||||
|   pub fn parse_server_certs_and_keys(&self) -> Result<CertifiedKey, anyhow::Error> { | ||||
|     // for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() {
 | ||||
|     let signing_key = self | ||||
|       .cert_keys | ||||
|       .iter() | ||||
|       .find_map(|k| { | ||||
|         if let Ok(sk) = any_supported_type(k) { | ||||
|           Some(sk) | ||||
|         } else { | ||||
|           None | ||||
|         } | ||||
|       }) | ||||
|       .ok_or_else(|| { | ||||
|         io::Error::new( | ||||
|           io::ErrorKind::InvalidInput, | ||||
|           "Unable to find a valid certificate and key", | ||||
|         ) | ||||
|       })?; | ||||
|     Ok(CertifiedKey::new(self.certs.clone(), signing_key)) | ||||
|   } | ||||
| 
 | ||||
|   pub fn parse_client_ca_certs(&self) -> Result<(Vec<OwnedTrustAnchor>, HashSet<Vec<u8>>), anyhow::Error> { | ||||
|     let certs = self.client_ca_certs.as_ref().ok_or(anyhow::anyhow!("No client cert"))?; | ||||
| 
 | ||||
|     let owned_trust_anchors: Vec<_> = certs | ||||
|       .iter() | ||||
|       .map(|v| { | ||||
|         // let trust_anchor = tokio_rustls::webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap();
 | ||||
|         let trust_anchor = webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap(); | ||||
|         rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( | ||||
|           trust_anchor.subject, | ||||
|           trust_anchor.spki, | ||||
|           trust_anchor.name_constraints, | ||||
|         ) | ||||
|       }) | ||||
|       .collect(); | ||||
| 
 | ||||
|     // TODO: SKID is not used currently
 | ||||
|     let subject_key_identifiers: HashSet<_> = certs | ||||
|       .iter() | ||||
|       .filter_map(|v| { | ||||
|         // retrieve ca key id (subject key id)
 | ||||
|         let cert = parse_x509_certificate(&v.0).unwrap().1; | ||||
|         let subject_key_ids = cert | ||||
|           .iter_extensions() | ||||
|           .filter_map(|ext| match ext.parsed_extension() { | ||||
|             ParsedExtension::SubjectKeyIdentifier(skid) => Some(skid), | ||||
|             _ => None, | ||||
|           }) | ||||
|           .collect::<Vec<_>>(); | ||||
|         if !subject_key_ids.is_empty() { | ||||
|           Some(subject_key_ids[0].0.to_owned()) | ||||
|         } else { | ||||
|           None | ||||
|         } | ||||
|       }) | ||||
|       .collect(); | ||||
| 
 | ||||
|     Ok((owned_trust_anchors, subject_key_identifiers)) | ||||
|   } | ||||
| } | ||||
							
								
								
									
										36
									
								
								rpxy-lib/src/crypto/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								rpxy-lib/src/crypto/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,36 @@ | |||
| mod certs; | ||||
| mod service; | ||||
| 
 | ||||
| use crate::{ | ||||
|   backend::BackendAppManager, | ||||
|   constants::{CERTS_WATCH_DELAY_SECS, LOAD_CERTS_ONLY_WHEN_UPDATED}, | ||||
|   error::RpxyResult, | ||||
| }; | ||||
| use hot_reload::{ReloaderReceiver, ReloaderService}; | ||||
| use service::CryptoReloader; | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| pub use certs::{CertsAndKeys, CryptoSource}; | ||||
| pub use service::{ServerCrypto, ServerCryptoBase, SniServerCryptoMap}; | ||||
| 
 | ||||
| /// Result type inner of certificate reloader service
 | ||||
| type ReloaderServiceResultInner<T> = ( | ||||
|   ReloaderService<CryptoReloader<T>, ServerCryptoBase>, | ||||
|   ReloaderReceiver<ServerCryptoBase>, | ||||
| ); | ||||
| /// Build certificate reloader service
 | ||||
| pub(crate) async fn build_cert_reloader<T>( | ||||
|   app_manager: &Arc<BackendAppManager<T>>, | ||||
| ) -> RpxyResult<ReloaderServiceResultInner<T>> | ||||
| where | ||||
|   T: CryptoSource + Clone + Send + Sync + 'static, | ||||
| { | ||||
|   let (cert_reloader_service, cert_reloader_rx) = ReloaderService::< | ||||
|     service::CryptoReloader<T>, | ||||
|     service::ServerCryptoBase, | ||||
|   >::new( | ||||
|     app_manager, CERTS_WATCH_DELAY_SECS, !LOAD_CERTS_ONLY_WHEN_UPDATED | ||||
|   ) | ||||
|   .await?; | ||||
|   Ok((cert_reloader_service, cert_reloader_rx)) | ||||
| } | ||||
							
								
								
									
										272
									
								
								rpxy-lib/src/crypto/service.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										272
									
								
								rpxy-lib/src/crypto/service.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,272 @@ | |||
| use super::certs::{CertsAndKeys, CryptoSource}; | ||||
| use crate::{backend::BackendAppManager, log::*, name_exp::ServerName}; | ||||
| use async_trait::async_trait; | ||||
| use hot_reload::*; | ||||
| use rustc_hash::FxHashMap as HashMap; | ||||
| use rustls::{server::ResolvesServerCertUsingSni, sign::CertifiedKey, RootCertStore, ServerConfig}; | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| #[derive(Clone)] | ||||
| /// Reloader service for certificates and keys for TLS
 | ||||
| pub struct CryptoReloader<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
|   inner: Arc<BackendAppManager<T>>, | ||||
| } | ||||
| 
 | ||||
| /// SNI to ServerConfig map type
 | ||||
| pub type SniServerCryptoMap = HashMap<ServerName, Arc<ServerConfig>>; | ||||
| /// SNI to ServerConfig map
 | ||||
| pub struct ServerCrypto { | ||||
|   // For Quic/HTTP3, only servers with no client authentication
 | ||||
|   #[cfg(feature = "http3-quinn")] | ||||
|   pub inner_global_no_client_auth: Arc<ServerConfig>, | ||||
|   #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||
|   pub inner_global_no_client_auth: s2n_quic_rustls::Server, | ||||
|   // For TLS over TCP/HTTP2 and 1.1, map of SNI to server_crypto for all given servers
 | ||||
|   pub inner_local_map: Arc<SniServerCryptoMap>, | ||||
| } | ||||
| 
 | ||||
| /// Reloader target for the certificate reloader service
 | ||||
| #[derive(Debug, PartialEq, Eq, Clone, Default)] | ||||
| pub struct ServerCryptoBase { | ||||
|   inner: HashMap<ServerName, CertsAndKeys>, | ||||
| } | ||||
| 
 | ||||
| #[async_trait] | ||||
| impl<T> Reload<ServerCryptoBase> for CryptoReloader<T> | ||||
| where | ||||
|   T: CryptoSource + Sync + Send, | ||||
| { | ||||
|   type Source = Arc<BackendAppManager<T>>; | ||||
|   async fn new(source: &Self::Source) -> Result<Self, ReloaderError<ServerCryptoBase>> { | ||||
|     Ok(Self { inner: source.clone() }) | ||||
|   } | ||||
| 
 | ||||
|   async fn reload(&self) -> Result<Option<ServerCryptoBase>, ReloaderError<ServerCryptoBase>> { | ||||
|     let mut certs_and_keys_map = ServerCryptoBase::default(); | ||||
| 
 | ||||
|     for (server_name_bytes_exp, backend) in self.inner.apps.iter() { | ||||
|       if let Some(crypto_source) = &backend.crypto_source { | ||||
|         let certs_and_keys = crypto_source | ||||
|           .read() | ||||
|           .await | ||||
|           .map_err(|_e| ReloaderError::<ServerCryptoBase>::Reload("Failed to reload cert, key or ca cert"))?; | ||||
|         certs_and_keys_map | ||||
|           .inner | ||||
|           .insert(server_name_bytes_exp.to_owned(), certs_and_keys); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     Ok(Some(certs_and_keys_map)) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl TryInto<Arc<ServerCrypto>> for &ServerCryptoBase { | ||||
|   type Error = anyhow::Error; | ||||
| 
 | ||||
|   fn try_into(self) -> Result<Arc<ServerCrypto>, Self::Error> { | ||||
|     #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|     let server_crypto_global = self.build_server_crypto_global()?; | ||||
|     let server_crypto_local_map: SniServerCryptoMap = self.build_server_crypto_local_map()?; | ||||
| 
 | ||||
|     Ok(Arc::new(ServerCrypto { | ||||
|       #[cfg(feature = "http3-quinn")] | ||||
|       inner_global_no_client_auth: Arc::new(server_crypto_global), | ||||
|       #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||
|       inner_global_no_client_auth: server_crypto_global, | ||||
|       inner_local_map: Arc::new(server_crypto_local_map), | ||||
|     })) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl ServerCryptoBase { | ||||
|   fn build_server_crypto_local_map(&self) -> Result<SniServerCryptoMap, ReloaderError<ServerCryptoBase>> { | ||||
|     let mut server_crypto_local_map: SniServerCryptoMap = HashMap::default(); | ||||
| 
 | ||||
|     for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() { | ||||
|       let server_name: String = server_name_bytes_exp.try_into()?; | ||||
| 
 | ||||
|       // Parse server certificates and private keys
 | ||||
|       let Ok(certified_key): Result<CertifiedKey, _> = certs_and_keys.parse_server_certs_and_keys() else { | ||||
|         warn!("Failed to add certificate for {}", server_name); | ||||
|         continue; | ||||
|       }; | ||||
| 
 | ||||
|       let mut resolver_local = ResolvesServerCertUsingSni::new(); | ||||
|       let mut client_ca_roots_local = RootCertStore::empty(); | ||||
| 
 | ||||
|       // add server certificate and key
 | ||||
|       if let Err(e) = resolver_local.add(server_name.as_str(), certified_key.to_owned()) { | ||||
|         error!( | ||||
|           "{}: Failed to read some certificates and keys {}", | ||||
|           server_name.as_str(), | ||||
|           e | ||||
|         ) | ||||
|       } | ||||
| 
 | ||||
|       // add client certificate if specified
 | ||||
|       if certs_and_keys.client_ca_certs.is_some() { | ||||
|         // add client certificate if specified
 | ||||
|         match certs_and_keys.parse_client_ca_certs() { | ||||
|           Ok((owned_trust_anchors, _subject_key_ids)) => { | ||||
|             client_ca_roots_local.add_trust_anchors(owned_trust_anchors.into_iter()); | ||||
|           } | ||||
|           Err(e) => { | ||||
|             warn!( | ||||
|               "Failed to add client CA certificate for {}: {}", | ||||
|               server_name.as_str(), | ||||
|               e | ||||
|             ); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
| 
 | ||||
|       let mut server_config_local = if client_ca_roots_local.is_empty() { | ||||
|         // with no client auth, enable http1.1 -- 3
 | ||||
|         #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] | ||||
|         { | ||||
|           ServerConfig::builder() | ||||
|             .with_safe_defaults() | ||||
|             .with_no_client_auth() | ||||
|             .with_cert_resolver(Arc::new(resolver_local)) | ||||
|         } | ||||
|         #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|         { | ||||
|           let mut sc = ServerConfig::builder() | ||||
|             .with_safe_defaults() | ||||
|             .with_no_client_auth() | ||||
|             .with_cert_resolver(Arc::new(resolver_local)); | ||||
|           sc.alpn_protocols = vec![b"h3".to_vec(), b"hq-29".to_vec()]; // TODO: remove hq-29 later?
 | ||||
|           sc | ||||
|         } | ||||
|       } else { | ||||
|         // with client auth, enable only http1.1 and 2
 | ||||
|         // let client_certs_verifier = rustls::server::AllowAnyAnonymousOrAuthenticatedClient::new(client_ca_roots);
 | ||||
|         let client_certs_verifier = rustls::server::AllowAnyAuthenticatedClient::new(client_ca_roots_local); | ||||
|         ServerConfig::builder() | ||||
|           .with_safe_defaults() | ||||
|           .with_client_cert_verifier(Arc::new(client_certs_verifier)) | ||||
|           .with_cert_resolver(Arc::new(resolver_local)) | ||||
|       }; | ||||
|       server_config_local.alpn_protocols.push(b"h2".to_vec()); | ||||
|       server_config_local.alpn_protocols.push(b"http/1.1".to_vec()); | ||||
| 
 | ||||
|       server_crypto_local_map.insert(server_name_bytes_exp.to_owned(), Arc::new(server_config_local)); | ||||
|     } | ||||
|     Ok(server_crypto_local_map) | ||||
|   } | ||||
| 
 | ||||
|   #[cfg(feature = "http3-quinn")] | ||||
|   fn build_server_crypto_global(&self) -> Result<ServerConfig, ReloaderError<ServerCryptoBase>> { | ||||
|     let mut resolver_global = ResolvesServerCertUsingSni::new(); | ||||
| 
 | ||||
|     for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() { | ||||
|       let server_name: String = server_name_bytes_exp.try_into()?; | ||||
| 
 | ||||
|       // Parse server certificates and private keys
 | ||||
|       let Ok(certified_key): Result<CertifiedKey, _> = certs_and_keys.parse_server_certs_and_keys() else { | ||||
|         warn!("Failed to add certificate for {}", server_name); | ||||
|         continue; | ||||
|       }; | ||||
| 
 | ||||
|       if certs_and_keys.client_ca_certs.is_none() { | ||||
|         // aggregated server config for no client auth server for http3
 | ||||
|         if let Err(e) = resolver_global.add(server_name.as_str(), certified_key) { | ||||
|           error!( | ||||
|             "{}: Failed to read some certificates and keys {}", | ||||
|             server_name.as_str(), | ||||
|             e | ||||
|           ) | ||||
|         } | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     //////////////
 | ||||
|     let mut server_crypto_global = ServerConfig::builder() | ||||
|       .with_safe_defaults() | ||||
|       .with_no_client_auth() | ||||
|       .with_cert_resolver(Arc::new(resolver_global)); | ||||
| 
 | ||||
|     //////////////////////////////
 | ||||
| 
 | ||||
|     server_crypto_global.alpn_protocols = vec![ | ||||
|       b"h3".to_vec(), | ||||
|       b"hq-29".to_vec(), // TODO: remove later?
 | ||||
|       b"h2".to_vec(), | ||||
|       b"http/1.1".to_vec(), | ||||
|     ]; | ||||
|     Ok(server_crypto_global) | ||||
|   } | ||||
| 
 | ||||
|   #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||
|   fn build_server_crypto_global(&self) -> Result<s2n_quic_rustls::Server, ReloaderError<ServerCryptoBase>> { | ||||
|     let mut resolver_global = s2n_quic_rustls::rustls::server::ResolvesServerCertUsingSni::new(); | ||||
| 
 | ||||
|     for (server_name_bytes_exp, certs_and_keys) in self.inner.iter() { | ||||
|       let server_name: String = server_name_bytes_exp.try_into()?; | ||||
| 
 | ||||
|       // Parse server certificates and private keys
 | ||||
|       let Ok(certified_key) = parse_server_certs_and_keys_s2n(certs_and_keys) else { | ||||
|         warn!("Failed to add certificate for {}", server_name); | ||||
|         continue; | ||||
|       }; | ||||
| 
 | ||||
|       if certs_and_keys.client_ca_certs.is_none() { | ||||
|         // aggregated server config for no client auth server for http3
 | ||||
|         if let Err(e) = resolver_global.add(server_name.as_str(), certified_key) { | ||||
|           error!( | ||||
|             "{}: Failed to read some certificates and keys {}", | ||||
|             server_name.as_str(), | ||||
|             e | ||||
|           ) | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     let alpn = vec![ | ||||
|       b"h3".to_vec(), | ||||
|       b"hq-29".to_vec(), // TODO: remove later?
 | ||||
|       b"h2".to_vec(), | ||||
|       b"http/1.1".to_vec(), | ||||
|     ]; | ||||
|     let server_crypto_global = s2n_quic::provider::tls::rustls::Server::builder() | ||||
|       .with_cert_resolver(Arc::new(resolver_global)) | ||||
|       .map_err(|e| anyhow::anyhow!(e))? | ||||
|       .with_application_protocols(alpn.iter()) | ||||
|       .map_err(|e| anyhow::anyhow!(e))? | ||||
|       .build() | ||||
|       .map_err(|e| anyhow::anyhow!(e))?; | ||||
|     Ok(server_crypto_global) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||
| /// This is workaround for the version difference between rustls and s2n-quic-rustls
 | ||||
| fn parse_server_certs_and_keys_s2n( | ||||
|   certs_and_keys: &CertsAndKeys, | ||||
| ) -> Result<s2n_quic_rustls::rustls::sign::CertifiedKey, anyhow::Error> { | ||||
|   let signing_key = certs_and_keys | ||||
|     .cert_keys | ||||
|     .iter() | ||||
|     .find_map(|k| { | ||||
|       let s2n_private_key = s2n_quic_rustls::PrivateKey(k.0.clone()); | ||||
|       if let Ok(sk) = s2n_quic_rustls::rustls::sign::any_supported_type(&s2n_private_key) { | ||||
|         Some(sk) | ||||
|       } else { | ||||
|         None | ||||
|       } | ||||
|     }) | ||||
|     .ok_or_else(|| { | ||||
|       std::io::Error::new( | ||||
|         std::io::ErrorKind::InvalidInput, | ||||
|         "Unable to find a valid certificate and key", | ||||
|       ) | ||||
|     })?; | ||||
|   let certs: Vec<_> = certs_and_keys | ||||
|     .certs | ||||
|     .iter() | ||||
|     .map(|c| s2n_quic_rustls::rustls::Certificate(c.0.clone())) | ||||
|     .collect(); | ||||
|   Ok(s2n_quic_rustls::rustls::sign::CertifiedKey::new(certs, signing_key)) | ||||
| } | ||||
|  | @ -1,86 +1,101 @@ | |||
| pub use anyhow::{anyhow, bail, ensure, Context}; | ||||
| use std::io; | ||||
| use thiserror::Error; | ||||
| 
 | ||||
| pub type Result<T> = std::result::Result<T, RpxyError>; | ||||
| pub type RpxyResult<T> = std::result::Result<T, RpxyError>; | ||||
| 
 | ||||
| /// Describes things that can go wrong in the Rpxy
 | ||||
| #[derive(Debug, Error)] | ||||
| pub enum RpxyError { | ||||
|   #[error("Proxy build error: {0}")] | ||||
|   ProxyBuild(#[from] crate::proxy::ProxyBuilderError), | ||||
|   // general errors
 | ||||
|   #[error("IO error: {0}")] | ||||
|   Io(#[from] std::io::Error), | ||||
| 
 | ||||
|   #[error("Backend build error: {0}")] | ||||
|   BackendBuild(#[from] crate::backend::BackendBuilderError), | ||||
|   // TLS errors
 | ||||
|   #[error("Failed to build TLS acceptor: {0}")] | ||||
|   FailedToTlsHandshake(String), | ||||
|   #[error("No server name in ClientHello")] | ||||
|   NoServerNameInClientHello, | ||||
|   #[error("No TLS serving app: {0}")] | ||||
|   NoTlsServingApp(String), | ||||
|   #[error("Failed to update server crypto: {0}")] | ||||
|   FailedToUpdateServerCrypto(String), | ||||
|   #[error("No server crypto: {0}")] | ||||
|   NoServerCrypto(String), | ||||
| 
 | ||||
|   #[error("MessageHandler build error: {0}")] | ||||
|   HandlerBuild(#[from] crate::handler::HttpMessageHandlerBuilderError), | ||||
|   // hyper errors
 | ||||
|   #[error("hyper body manipulation error: {0}")] | ||||
|   HyperBodyManipulationError(String), | ||||
|   #[error("New closed in incoming-like")] | ||||
|   HyperIncomingLikeNewClosed, | ||||
|   #[error("New body write aborted")] | ||||
|   HyperNewBodyWriteAborted, | ||||
|   #[error("Hyper error in serving request or response body type: {0}")] | ||||
|   HyperBodyError(#[from] hyper::Error), | ||||
| 
 | ||||
|   #[error("Config builder error: {0}")] | ||||
|   ConfigBuild(&'static str), | ||||
| 
 | ||||
|   #[error("Http Message Handler Error: {0}")] | ||||
|   Handler(&'static str), | ||||
| 
 | ||||
|   #[error("Cache Error: {0}")] | ||||
|   Cache(&'static str), | ||||
| 
 | ||||
|   #[error("Http Request Message Error: {0}")] | ||||
|   Request(&'static str), | ||||
| 
 | ||||
|   #[error("TCP/UDP Proxy Layer Error: {0}")] | ||||
|   Proxy(String), | ||||
| 
 | ||||
|   #[allow(unused)] | ||||
|   #[error("LoadBalance Layer Error: {0}")] | ||||
|   LoadBalance(String), | ||||
| 
 | ||||
|   #[error("I/O Error: {0}")] | ||||
|   Io(#[from] io::Error), | ||||
| 
 | ||||
|   // #[error("Toml Deserialization Error")]
 | ||||
|   // TomlDe(#[from] toml::de::Error),
 | ||||
|   #[cfg(feature = "http3-quinn")] | ||||
|   #[error("Quic Connection Error [quinn]: {0}")] | ||||
|   QuicConn(#[from] quinn::ConnectionError), | ||||
| 
 | ||||
|   #[cfg(feature = "http3-s2n")] | ||||
|   #[error("Quic Connection Error [s2n-quic]: {0}")] | ||||
|   QUicConn(#[from] s2n_quic::connection::Error), | ||||
|   // http/3 errors
 | ||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   #[error("H3 error: {0}")] | ||||
|   H3Error(#[from] h3::Error), | ||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   #[error("Exceeds max request body size for HTTP/3")] | ||||
|   H3TooLargeBody, | ||||
| 
 | ||||
|   #[cfg(feature = "http3-quinn")] | ||||
|   #[error("H3 Error [quinn]: {0}")] | ||||
|   H3(#[from] h3::Error), | ||||
|   #[error("Invalid rustls TLS version: {0}")] | ||||
|   QuinnInvalidTlsProtocolVersion(String), | ||||
|   #[cfg(feature = "http3-quinn")] | ||||
|   #[error("Quinn connection error: {0}")] | ||||
|   QuinnConnectionFailed(#[from] quinn::ConnectionError), | ||||
| 
 | ||||
|   #[cfg(feature = "http3-s2n")] | ||||
|   #[error("H3 Error [s2n-quic]: {0}")] | ||||
|   H3(#[from] s2n_quic_h3::h3::Error), | ||||
|   #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||
|   #[error("s2n-quic validation error: {0}")] | ||||
|   S2nQuicValidationError(#[from] s2n_quic_core::transport::parameters::ValidationError), | ||||
|   #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||
|   #[error("s2n-quic connection error: {0}")] | ||||
|   S2nQuicConnectionError(#[from] s2n_quic_core::connection::Error), | ||||
|   #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||
|   #[error("s2n-quic start error: {0}")] | ||||
|   S2nQuicStartError(#[from] s2n_quic::provider::StartError), | ||||
| 
 | ||||
|   #[error("rustls Connection Error: {0}")] | ||||
|   Rustls(#[from] rustls::Error), | ||||
|   // certificate reloader errors
 | ||||
|   #[error("No certificate reloader when building a proxy for TLS")] | ||||
|   NoCertificateReloader, | ||||
|   #[error("Certificate reload error: {0}")] | ||||
|   CertificateReloadError(#[from] hot_reload::ReloaderError<crate::crypto::ServerCryptoBase>), | ||||
| 
 | ||||
|   #[error("Hyper Error: {0}")] | ||||
|   Hyper(#[from] hyper::Error), | ||||
|   // backend errors
 | ||||
|   #[error("Invalid reverse proxy setting")] | ||||
|   InvalidReverseProxyConfig, | ||||
|   #[error("Invalid upstream option setting")] | ||||
|   InvalidUpstreamOptionSetting, | ||||
|   #[error("Failed to build backend app: {0}")] | ||||
|   FailedToBuildBackendApp(#[from] crate::backend::BackendAppBuilderError), | ||||
| 
 | ||||
|   #[error("Hyper Http Error: {0}")] | ||||
|   HyperHttp(#[from] hyper::http::Error), | ||||
|   // Handler errors
 | ||||
|   #[error("Failed to build message handler: {0}")] | ||||
|   FailedToBuildMessageHandler(#[from] crate::message_handler::HttpMessageHandlerBuilderError), | ||||
|   #[error("Failed to upgrade request: {0}")] | ||||
|   FailedToUpgradeRequest(String), | ||||
|   #[error("Failed to upgrade response: {0}")] | ||||
|   FailedToUpgradeResponse(String), | ||||
|   #[error("Failed to copy bidirectional for upgraded connections: {0}")] | ||||
|   FailedToCopyBidirectional(String), | ||||
| 
 | ||||
|   #[error("Hyper Http HeaderValue Error: {0}")] | ||||
|   HyperHeaderValue(#[from] hyper::header::InvalidHeaderValue), | ||||
|   // Forwarder errors
 | ||||
|   #[error("Failed to build forwarder: {0}")] | ||||
|   FailedToBuildForwarder(String), | ||||
|   #[error("Failed to fetch from upstream: {0}")] | ||||
|   FailedToFetchFromUpstream(String), | ||||
| 
 | ||||
|   #[error("Hyper Http HeaderName Error: {0}")] | ||||
|   HyperHeaderName(#[from] hyper::header::InvalidHeaderName), | ||||
|   // Upstream connection setting errors
 | ||||
|   #[error("Unsupported upstream option")] | ||||
|   UnsupportedUpstreamOption, | ||||
| 
 | ||||
|   #[error(transparent)] | ||||
|   Other(#[from] anyhow::Error), | ||||
| } | ||||
| 
 | ||||
| #[allow(dead_code)] | ||||
| #[derive(Debug, Error, Clone)] | ||||
| pub enum ClientCertsError { | ||||
|   #[error("TLS Client Certificate is Required for Given SNI: {0}")] | ||||
|   ClientCertRequired(String), | ||||
| 
 | ||||
|   #[error("Inconsistent TLS Client Certificate for Given SNI: {0}")] | ||||
|   InconsistentClientCert(String), | ||||
|   // Cache error map
 | ||||
|   #[cfg(feature = "cache")] | ||||
|   #[error("Cache error: {0}")] | ||||
|   CacheError(#[from] crate::forwarder::CacheError), | ||||
| 
 | ||||
|   // Others
 | ||||
|   #[error("Infallible")] | ||||
|   Infallible(#[from] std::convert::Infallible), | ||||
| } | ||||
|  |  | |||
							
								
								
									
										47
									
								
								rpxy-lib/src/forwarder/cache/cache_error.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								rpxy-lib/src/forwarder/cache/cache_error.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,47 @@ | |||
| use thiserror::Error; | ||||
| 
 | ||||
| pub(crate) type CacheResult<T> = std::result::Result<T, CacheError>; | ||||
| 
 | ||||
| /// Describes things that can go wrong in the Rpxy
 | ||||
| #[derive(Debug, Error)] | ||||
| pub enum CacheError { | ||||
|   // Cache errors,
 | ||||
|   #[error("Invalid null request and/or response")] | ||||
|   NullRequestOrResponse, | ||||
| 
 | ||||
|   #[error("Failed to acquire mutex lock for cache")] | ||||
|   FailedToAcquiredMutexLockForCache, | ||||
| 
 | ||||
|   #[error("Failed to acquire mutex lock for check")] | ||||
|   FailedToAcquiredMutexLockForCheck, | ||||
| 
 | ||||
|   #[error("Failed to create file cache")] | ||||
|   FailedToCreateFileCache, | ||||
| 
 | ||||
|   #[error("Failed to write file cache")] | ||||
|   FailedToWriteFileCache, | ||||
| 
 | ||||
|   #[error("Failed to open cache file")] | ||||
|   FailedToOpenCacheFile, | ||||
| 
 | ||||
|   #[error("Too large to cache")] | ||||
|   TooLargeToCache, | ||||
| 
 | ||||
|   #[error("Failed to cache bytes: {0}")] | ||||
|   FailedToCacheBytes(String), | ||||
| 
 | ||||
|   #[error("Failed to send frame to cache {0}")] | ||||
|   FailedToSendFrameToCache(String), | ||||
| 
 | ||||
|   #[error("Failed to send frame from file cache {0}")] | ||||
|   FailedToSendFrameFromCache(String), | ||||
| 
 | ||||
|   #[error("Failed to remove cache file: {0}")] | ||||
|   FailedToRemoveCacheFile(String), | ||||
| 
 | ||||
|   #[error("Invalid cache target")] | ||||
|   InvalidCacheTarget, | ||||
| 
 | ||||
|   #[error("Hash mismatched in cache file")] | ||||
|   HashMismatchedInCacheFile, | ||||
| } | ||||
							
								
								
									
										523
									
								
								rpxy-lib/src/forwarder/cache/cache_main.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										523
									
								
								rpxy-lib/src/forwarder/cache/cache_main.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,523 @@ | |||
| use super::cache_error::*; | ||||
| use crate::{ | ||||
|   globals::Globals, | ||||
|   hyper_ext::body::{full, BoxBody, ResponseBody, UnboundedStreamBody}, | ||||
|   log::*, | ||||
| }; | ||||
| use base64::{engine::general_purpose, Engine as _}; | ||||
| use bytes::{Buf, Bytes, BytesMut}; | ||||
| use futures::channel::mpsc; | ||||
| use http::{Request, Response, Uri}; | ||||
| use http_body_util::{BodyExt, StreamBody}; | ||||
| use http_cache_semantics::CachePolicy; | ||||
| use hyper::body::{Frame, Incoming}; | ||||
| use lru::LruCache; | ||||
| use sha2::{Digest, Sha256}; | ||||
| use std::{ | ||||
|   path::{Path, PathBuf}, | ||||
|   sync::{ | ||||
|     atomic::{AtomicUsize, Ordering}, | ||||
|     Arc, Mutex, | ||||
|   }, | ||||
|   time::SystemTime, | ||||
| }; | ||||
| use tokio::{ | ||||
|   fs::{self, File}, | ||||
|   io::{AsyncReadExt, AsyncWriteExt}, | ||||
|   sync::RwLock, | ||||
| }; | ||||
| 
 | ||||
| /* ---------------------------------------------- */ | ||||
| #[derive(Clone, Debug)] | ||||
| /// Cache main manager
 | ||||
| pub(crate) struct RpxyCache { | ||||
|   /// Inner lru cache manager storing http message caching policy
 | ||||
|   inner: LruCacheManager, | ||||
|   /// Managing cache file objects through RwLock's lock mechanism for file lock
 | ||||
|   file_store: FileStore, | ||||
|   /// Async runtime
 | ||||
|   runtime_handle: tokio::runtime::Handle, | ||||
|   /// Maximum size of each cache file object
 | ||||
|   max_each_size: usize, | ||||
|   /// Maximum size of cache object on memory
 | ||||
|   max_each_size_on_memory: usize, | ||||
|   /// Cache directory path
 | ||||
|   cache_dir: PathBuf, | ||||
| } | ||||
| 
 | ||||
| impl RpxyCache { | ||||
|   /// Generate cache storage
 | ||||
|   pub(crate) async fn new(globals: &Globals) -> Option<Self> { | ||||
|     if !globals.proxy_config.cache_enabled { | ||||
|       return None; | ||||
|     } | ||||
|     let cache_dir = globals.proxy_config.cache_dir.as_ref().unwrap(); | ||||
|     let file_store = FileStore::new(&globals.runtime_handle).await; | ||||
|     let inner = LruCacheManager::new(globals.proxy_config.cache_max_entry); | ||||
| 
 | ||||
|     let max_each_size = globals.proxy_config.cache_max_each_size; | ||||
|     let mut max_each_size_on_memory = globals.proxy_config.cache_max_each_size_on_memory; | ||||
|     if max_each_size < max_each_size_on_memory { | ||||
|       warn!( | ||||
|         "Maximum size of on memory cache per entry must be smaller than or equal to the maximum of each file cache" | ||||
|       ); | ||||
|       max_each_size_on_memory = max_each_size; | ||||
|     } | ||||
| 
 | ||||
|     if let Err(e) = fs::remove_dir_all(cache_dir).await { | ||||
|       warn!("Failed to clean up the cache dir: {e}"); | ||||
|     }; | ||||
|     fs::create_dir_all(&cache_dir).await.unwrap(); | ||||
| 
 | ||||
|     Some(Self { | ||||
|       file_store, | ||||
|       inner, | ||||
|       runtime_handle: globals.runtime_handle.clone(), | ||||
|       max_each_size, | ||||
|       max_each_size_on_memory, | ||||
|       cache_dir: cache_dir.clone(), | ||||
|     }) | ||||
|   } | ||||
| 
 | ||||
|   /// Count cache entries
 | ||||
|   pub(crate) async fn count(&self) -> (usize, usize, usize) { | ||||
|     let total = self.inner.count(); | ||||
|     let file = self.file_store.count().await; | ||||
|     let on_memory = total - file; | ||||
|     (total, on_memory, file) | ||||
|   } | ||||
| 
 | ||||
|   /// Put response into the cache
 | ||||
|   pub(crate) async fn put( | ||||
|     &self, | ||||
|     uri: &hyper::Uri, | ||||
|     mut body: Incoming, | ||||
|     policy: &CachePolicy, | ||||
|   ) -> CacheResult<UnboundedStreamBody> { | ||||
|     let cache_manager = self.inner.clone(); | ||||
|     let mut file_store = self.file_store.clone(); | ||||
|     let uri = uri.clone(); | ||||
|     let policy_clone = policy.clone(); | ||||
|     let max_each_size = self.max_each_size; | ||||
|     let max_each_size_on_memory = self.max_each_size_on_memory; | ||||
|     let cache_dir = self.cache_dir.clone(); | ||||
| 
 | ||||
|     let (body_tx, body_rx) = mpsc::unbounded::<Result<Frame<Bytes>, hyper::Error>>(); | ||||
| 
 | ||||
|     self.runtime_handle.spawn(async move { | ||||
|       let mut size = 0usize; | ||||
|       let mut buf = BytesMut::new(); | ||||
| 
 | ||||
|       loop { | ||||
|         let frame = match body.frame().await { | ||||
|           Some(frame) => frame, | ||||
|           None => { | ||||
|             debug!("Response body finished"); | ||||
|             break; | ||||
|           } | ||||
|         }; | ||||
|         let frame_size = frame.as_ref().map(|f| { | ||||
|           if f.is_data() { | ||||
|             f.data_ref().map(|bytes| bytes.remaining()).unwrap_or_default() | ||||
|           } else { | ||||
|             0 | ||||
|           } | ||||
|         }); | ||||
|         size += frame_size.unwrap_or_default(); | ||||
| 
 | ||||
|         // check size
 | ||||
|         if size > max_each_size { | ||||
|           warn!("Too large to cache"); | ||||
|           return Err(CacheError::TooLargeToCache); | ||||
|         } | ||||
|         frame | ||||
|           .as_ref() | ||||
|           .map(|f| { | ||||
|             if f.is_data() { | ||||
|               let data_bytes = f.data_ref().unwrap().clone(); | ||||
|               // debug!("cache data bytes of {} bytes", data_bytes.len());
 | ||||
|               // We do not use stream-type buffering since it needs to lock file during operation.
 | ||||
|               buf.extend(data_bytes.as_ref()); | ||||
|             } | ||||
|           }) | ||||
|           .map_err(|e| CacheError::FailedToCacheBytes(e.to_string()))?; | ||||
| 
 | ||||
|         // send data to use response downstream
 | ||||
|         body_tx | ||||
|           .unbounded_send(frame) | ||||
|           .map_err(|e| CacheError::FailedToSendFrameToCache(e.to_string()))?; | ||||
|       } | ||||
| 
 | ||||
|       let buf = buf.freeze(); | ||||
|       // Calculate hash of the cached data, after all data is received.
 | ||||
|       // In-operation calculation is possible but it blocks sending data.
 | ||||
|       let mut hasher = Sha256::new(); | ||||
|       hasher.update(buf.as_ref()); | ||||
|       let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); | ||||
|       debug!("Cached data: {} bytes, hash = {:?}", size, hash_bytes); | ||||
| 
 | ||||
|       // Create cache object
 | ||||
|       let cache_key = derive_cache_key_from_uri(&uri); | ||||
|       let cache_object = CacheObject { | ||||
|         policy: policy_clone, | ||||
|         target: CacheFileOrOnMemory::build(&cache_dir, &uri, &buf, max_each_size_on_memory), | ||||
|         hash: hash_bytes, | ||||
|       }; | ||||
| 
 | ||||
|       if let Some((k, v)) = cache_manager.push(&cache_key, &cache_object)? { | ||||
|         if k != cache_key { | ||||
|           info!("Over the cache capacity. Evict least recent used entry"); | ||||
|           if let CacheFileOrOnMemory::File(path) = v.target { | ||||
|             file_store.evict(&path).await; | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|       // store cache object to file
 | ||||
|       if let CacheFileOrOnMemory::File(_) = cache_object.target { | ||||
|         file_store.create(&cache_object, &buf).await?; | ||||
|       } | ||||
| 
 | ||||
|       Ok(()) as CacheResult<()> | ||||
|     }); | ||||
| 
 | ||||
|     let stream_body = StreamBody::new(body_rx); | ||||
| 
 | ||||
|     Ok(stream_body) | ||||
|   } | ||||
| 
 | ||||
|   /// Get cached response
 | ||||
|   pub(crate) async fn get<R>(&self, req: &Request<R>) -> Option<Response<ResponseBody>> { | ||||
|     debug!( | ||||
|       "Current cache status: (total, on-memory, file) = {:?}", | ||||
|       self.count().await | ||||
|     ); | ||||
|     let cache_key = derive_cache_key_from_uri(req.uri()); | ||||
| 
 | ||||
|     // First check cache chance
 | ||||
|     let Ok(Some(cached_object)) = self.inner.get(&cache_key) else { | ||||
|       return None; | ||||
|     }; | ||||
| 
 | ||||
|     // Secondly check the cache freshness as an HTTP message
 | ||||
|     let now = SystemTime::now(); | ||||
|     let http_cache_semantics::BeforeRequest::Fresh(res_parts) = cached_object.policy.before_request(req, now) else { | ||||
|       // Evict stale cache entry.
 | ||||
|       // This might be okay to keep as is since it would be updated later.
 | ||||
|       // However, there is no guarantee that newly got objects will be still cacheable.
 | ||||
|       // So, we have to evict stale cache entries and cache file objects if found.
 | ||||
|       debug!("Stale cache entry: {cache_key}"); | ||||
|       let _evicted_entry = self.inner.evict(&cache_key); | ||||
|       // For cache file
 | ||||
|       if let CacheFileOrOnMemory::File(path) = &cached_object.target { | ||||
|         self.file_store.evict(&path).await; | ||||
|       } | ||||
|       return None; | ||||
|     }; | ||||
| 
 | ||||
|     // Finally retrieve the file/on-memory object
 | ||||
|     let response_body = match cached_object.target { | ||||
|       CacheFileOrOnMemory::File(path) => { | ||||
|         let stream_body = match self.file_store.read(path.clone(), &cached_object.hash).await { | ||||
|           Ok(s) => s, | ||||
|           Err(e) => { | ||||
|             warn!("Failed to read from file cache: {e}"); | ||||
|             let _evicted_entry = self.inner.evict(&cache_key); | ||||
|             self.file_store.evict(path).await; | ||||
|             return None; | ||||
|           } | ||||
|         }; | ||||
|         debug!("Cache hit from file: {cache_key}"); | ||||
|         ResponseBody::Streamed(stream_body) | ||||
|       } | ||||
|       CacheFileOrOnMemory::OnMemory(object) => { | ||||
|         debug!("Cache hit from on memory: {cache_key}"); | ||||
|         let mut hasher = Sha256::new(); | ||||
|         hasher.update(object.as_ref()); | ||||
|         let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); | ||||
|         if hash_bytes != cached_object.hash { | ||||
|           warn!("Hash mismatched. Cache object is corrupted"); | ||||
|           let _evicted_entry = self.inner.evict(&cache_key); | ||||
|           return None; | ||||
|         } | ||||
|         ResponseBody::Boxed(BoxBody::new(full(object))) | ||||
|       } | ||||
|     }; | ||||
|     Some(Response::from_parts(res_parts, response_body)) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ---------------------------------------------- */ | ||||
| #[derive(Debug, Clone)] | ||||
| /// Cache file manager outer that is responsible to handle `RwLock`
 | ||||
| struct FileStore { | ||||
|   /// Inner file store main object
 | ||||
|   inner: Arc<RwLock<FileStoreInner>>, | ||||
| } | ||||
| impl FileStore { | ||||
|   /// Build manager
 | ||||
|   async fn new(runtime_handle: &tokio::runtime::Handle) -> Self { | ||||
|     Self { | ||||
|       inner: Arc::new(RwLock::new(FileStoreInner::new(runtime_handle).await)), | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /// Count file cache entries
 | ||||
|   async fn count(&self) -> usize { | ||||
|     let inner = self.inner.read().await; | ||||
|     inner.cnt | ||||
|   } | ||||
|   /// Create a temporary file cache
 | ||||
|   async fn create(&mut self, cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> { | ||||
|     let mut inner = self.inner.write().await; | ||||
|     inner.create(cache_object, body_bytes).await | ||||
|   } | ||||
|   /// Evict a temporary file cache
 | ||||
|   async fn evict(&self, path: impl AsRef<Path>) { | ||||
|     // Acquire the write lock
 | ||||
|     let mut inner = self.inner.write().await; | ||||
|     if let Err(e) = inner.remove(path).await { | ||||
|       warn!("Eviction failed during file object removal: {:?}", e); | ||||
|     }; | ||||
|   } | ||||
|   /// Read a temporary file cache
 | ||||
|   async fn read( | ||||
|     &self, | ||||
|     path: impl AsRef<Path> + Send + Sync + 'static, | ||||
|     hash: &Bytes, | ||||
|   ) -> CacheResult<UnboundedStreamBody> { | ||||
|     let inner = self.inner.read().await; | ||||
|     inner.read(path, hash).await | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| /// Manager inner for cache on file system
 | ||||
| struct FileStoreInner { | ||||
|   /// Counter of current cached files
 | ||||
|   cnt: usize, | ||||
|   /// Async runtime
 | ||||
|   runtime_handle: tokio::runtime::Handle, | ||||
| } | ||||
| 
 | ||||
| impl FileStoreInner { | ||||
|   /// Build new cache file manager.
 | ||||
|   /// This first creates cache file dir if not exists, and cleans up the file inside the directory.
 | ||||
|   /// TODO: Persistent cache is really difficult. `sqlite` or something like that is needed.
 | ||||
|   async fn new(runtime_handle: &tokio::runtime::Handle) -> Self { | ||||
|     Self { | ||||
|       cnt: 0, | ||||
|       runtime_handle: runtime_handle.clone(), | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /// Create a new temporary file cache
 | ||||
|   async fn create(&mut self, cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> { | ||||
|     let cache_filepath = match cache_object.target { | ||||
|       CacheFileOrOnMemory::File(ref path) => path.clone(), | ||||
|       CacheFileOrOnMemory::OnMemory(_) => { | ||||
|         return Err(CacheError::InvalidCacheTarget); | ||||
|       } | ||||
|     }; | ||||
|     let Ok(mut file) = File::create(&cache_filepath).await else { | ||||
|       return Err(CacheError::FailedToCreateFileCache); | ||||
|     }; | ||||
|     let mut bytes_clone = body_bytes.clone(); | ||||
|     while bytes_clone.has_remaining() { | ||||
|       if let Err(e) = file.write_buf(&mut bytes_clone).await { | ||||
|         error!("Failed to write file cache: {e}"); | ||||
|         return Err(CacheError::FailedToWriteFileCache); | ||||
|       }; | ||||
|     } | ||||
|     self.cnt += 1; | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   /// Retrieve a stored temporary file cache
 | ||||
|   async fn read( | ||||
|     &self, | ||||
|     path: impl AsRef<Path> + Send + Sync + 'static, | ||||
|     hash: &Bytes, | ||||
|   ) -> CacheResult<UnboundedStreamBody> { | ||||
|     let Ok(mut file) = File::open(&path).await else { | ||||
|       warn!("Cache file object cannot be opened"); | ||||
|       return Err(CacheError::FailedToOpenCacheFile); | ||||
|     }; | ||||
|     let hash_clone = hash.clone(); | ||||
|     let mut self_clone = self.clone(); | ||||
| 
 | ||||
|     let (body_tx, body_rx) = mpsc::unbounded::<Result<Frame<Bytes>, hyper::Error>>(); | ||||
| 
 | ||||
|     self.runtime_handle.spawn(async move { | ||||
|       let mut hasher = Sha256::new(); | ||||
|       let mut buf = BytesMut::new(); | ||||
|       loop { | ||||
|         match file.read_buf(&mut buf).await { | ||||
|           Ok(0) => break, | ||||
|           Ok(_) => { | ||||
|             let bytes = buf.copy_to_bytes(buf.remaining()); | ||||
|             hasher.update(bytes.as_ref()); | ||||
|             body_tx | ||||
|               .unbounded_send(Ok(Frame::data(bytes))) | ||||
|               .map_err(|e| CacheError::FailedToSendFrameFromCache(e.to_string()))? | ||||
|           } | ||||
|           Err(_) => break, | ||||
|         }; | ||||
|       } | ||||
|       let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); | ||||
|       if hash_bytes != hash_clone { | ||||
|         warn!("Hash mismatched. Cache object is corrupted. Force to remove the cache file."); | ||||
|         // only file can be evicted
 | ||||
|         let _evicted_entry = self_clone.remove(&path).await; | ||||
|         return Err(CacheError::HashMismatchedInCacheFile); | ||||
|       } | ||||
|       Ok(()) as CacheResult<()> | ||||
|     }); | ||||
| 
 | ||||
|     let stream_body = StreamBody::new(body_rx); | ||||
| 
 | ||||
|     Ok(stream_body) | ||||
|   } | ||||
| 
 | ||||
|   /// Remove file
 | ||||
|   async fn remove(&mut self, path: impl AsRef<Path>) -> CacheResult<()> { | ||||
|     fs::remove_file(path.as_ref()) | ||||
|       .await | ||||
|       .map_err(|e| CacheError::FailedToRemoveCacheFile(e.to_string()))?; | ||||
|     self.cnt -= 1; | ||||
|     debug!("Removed a cache file at {:?} (file count: {})", path.as_ref(), self.cnt); | ||||
| 
 | ||||
|     Ok(()) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ---------------------------------------------- */ | ||||
| 
 | ||||
| #[derive(Clone, Debug)] | ||||
| /// Cache target in hybrid manner of on-memory and file system
 | ||||
| pub(crate) enum CacheFileOrOnMemory { | ||||
|   /// Pointer to the temporary cache file
 | ||||
|   File(PathBuf), | ||||
|   /// Cached body itself
 | ||||
|   OnMemory(Bytes), | ||||
| } | ||||
| 
 | ||||
| impl CacheFileOrOnMemory { | ||||
|   /// Get cache object target
 | ||||
|   fn build(cache_dir: &Path, uri: &Uri, object: &Bytes, max_each_size_on_memory: usize) -> Self { | ||||
|     if object.len() > max_each_size_on_memory { | ||||
|       let cache_filename = derive_filename_from_uri(uri); | ||||
|       let cache_filepath = cache_dir.join(cache_filename); | ||||
|       CacheFileOrOnMemory::File(cache_filepath) | ||||
|     } else { | ||||
|       CacheFileOrOnMemory::OnMemory(object.clone()) | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Clone, Debug)] | ||||
| /// Cache object definition
 | ||||
| struct CacheObject { | ||||
|   /// Cache policy to determine if the stored cache can be used as a response to a new incoming request
 | ||||
|   policy: CachePolicy, | ||||
|   /// Cache target: on-memory object or temporary file
 | ||||
|   target: CacheFileOrOnMemory, | ||||
|   /// SHA256 hash of target to strongly bind the cache metadata (this object) and file target
 | ||||
|   hash: Bytes, | ||||
| } | ||||
| 
 | ||||
| /* ---------------------------------------------- */ | ||||
| #[derive(Debug, Clone)] | ||||
| /// Lru cache manager that is responsible to handle `Mutex` as an outer of `LruCache`
 | ||||
| struct LruCacheManager { | ||||
|   /// Inner lru cache manager main object
 | ||||
|   inner: Arc<Mutex<LruCache<String, CacheObject>>>, // TODO: keyはstring urlでいいのか疑問。全requestに対してcheckすることになりそう
 | ||||
|   /// Counter of current cached object (total)
 | ||||
|   cnt: Arc<AtomicUsize>, | ||||
| } | ||||
| 
 | ||||
| impl LruCacheManager { | ||||
|   /// Build LruCache
 | ||||
|   fn new(cache_max_entry: usize) -> Self { | ||||
|     Self { | ||||
|       inner: Arc::new(Mutex::new(LruCache::new( | ||||
|         std::num::NonZeroUsize::new(cache_max_entry).unwrap(), | ||||
|       ))), | ||||
|       cnt: Default::default(), | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /// Count entries
 | ||||
|   fn count(&self) -> usize { | ||||
|     self.cnt.load(Ordering::Relaxed) | ||||
|   } | ||||
| 
 | ||||
|   /// Evict an entry
 | ||||
|   fn evict(&self, cache_key: &str) -> Option<(String, CacheObject)> { | ||||
|     let Ok(mut lock) = self.inner.lock() else { | ||||
|       error!("Mutex can't be locked to evict a cache entry"); | ||||
|       return None; | ||||
|     }; | ||||
|     let res = lock.pop_entry(cache_key); | ||||
|     // This may be inconsistent with the actual number of entries
 | ||||
|     self.cnt.store(lock.len(), Ordering::Relaxed); | ||||
|     res | ||||
|   } | ||||
| 
 | ||||
|   /// Push an entry
 | ||||
|   fn push(&self, cache_key: &str, cache_object: &CacheObject) -> CacheResult<Option<(String, CacheObject)>> { | ||||
|     let Ok(mut lock) = self.inner.lock() else { | ||||
|       error!("Failed to acquire mutex lock for writing cache entry"); | ||||
|       return Err(CacheError::FailedToAcquiredMutexLockForCache); | ||||
|     }; | ||||
|     let res = Ok(lock.push(cache_key.to_string(), cache_object.clone())); | ||||
|     // This may be inconsistent with the actual number of entries
 | ||||
|     self.cnt.store(lock.len(), Ordering::Relaxed); | ||||
|     res | ||||
|   } | ||||
| 
 | ||||
|   /// Get an entry
 | ||||
|   fn get(&self, cache_key: &str) -> CacheResult<Option<CacheObject>> { | ||||
|     let Ok(mut lock) = self.inner.lock() else { | ||||
|       error!("Mutex can't be locked for checking cache entry"); | ||||
|       return Err(CacheError::FailedToAcquiredMutexLockForCheck); | ||||
|     }; | ||||
|     let Some(cached_object) = lock.get(cache_key) else { | ||||
|       return Ok(None); | ||||
|     }; | ||||
|     Ok(Some(cached_object.clone())) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ---------------------------------------------- */ | ||||
| /// Generate cache policy if the response is cacheable
 | ||||
| pub(crate) fn get_policy_if_cacheable<B1, B2>( | ||||
|   req: Option<&Request<B1>>, | ||||
|   res: Option<&Response<B2>>, | ||||
| ) -> CacheResult<Option<CachePolicy>> | ||||
| // where
 | ||||
| //   B1: core::fmt::Debug,
 | ||||
| { | ||||
|   // deduce cache policy from req and res
 | ||||
|   let (Some(req), Some(res)) = (req, res) else { | ||||
|     return Err(CacheError::NullRequestOrResponse); | ||||
|   }; | ||||
| 
 | ||||
|   let new_policy = CachePolicy::new(req, res); | ||||
|   if new_policy.is_storable() { | ||||
|     // debug!("Response is cacheable: {:?}\n{:?}", req, res.headers());
 | ||||
|     Ok(Some(new_policy)) | ||||
|   } else { | ||||
|     Ok(None) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| fn derive_filename_from_uri(uri: &hyper::Uri) -> String { | ||||
|   let mut hasher = Sha256::new(); | ||||
|   hasher.update(uri.to_string()); | ||||
|   let digest = hasher.finalize(); | ||||
|   general_purpose::URL_SAFE_NO_PAD.encode(digest) | ||||
| } | ||||
| 
 | ||||
| fn derive_cache_key_from_uri(uri: &hyper::Uri) -> String { | ||||
|   uri.to_string() | ||||
| } | ||||
							
								
								
									
										5
									
								
								rpxy-lib/src/forwarder/cache/mod.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								rpxy-lib/src/forwarder/cache/mod.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,5 @@ | |||
| mod cache_error; | ||||
| mod cache_main; | ||||
| 
 | ||||
| pub use cache_error::CacheError; | ||||
| pub(crate) use cache_main::{get_policy_if_cacheable, RpxyCache}; | ||||
							
								
								
									
										238
									
								
								rpxy-lib/src/forwarder/client.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										238
									
								
								rpxy-lib/src/forwarder/client.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,238 @@ | |||
| use crate::{ | ||||
|   error::{RpxyError, RpxyResult}, | ||||
|   globals::Globals, | ||||
|   hyper_ext::{body::ResponseBody, rt::LocalExecutor}, | ||||
|   log::*, | ||||
| }; | ||||
| use async_trait::async_trait; | ||||
| use http::{Request, Response, Version}; | ||||
| use hyper::body::{Body, Incoming}; | ||||
| use hyper_util::client::legacy::{ | ||||
|   connect::{Connect, HttpConnector}, | ||||
|   Client, | ||||
| }; | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| #[cfg(feature = "cache")] | ||||
| use super::cache::{get_policy_if_cacheable, RpxyCache}; | ||||
| 
 | ||||
| #[async_trait] | ||||
| /// Definition of the forwarder that simply forward requests from downstream client to upstream app servers.
 | ||||
| pub trait ForwardRequest<B1, B2> { | ||||
|   type Error; | ||||
|   async fn request(&self, req: Request<B1>) -> Result<Response<B2>, Self::Error>; | ||||
| } | ||||
| 
 | ||||
| /// Forwarder http client struct responsible to cache handling
 | ||||
| pub struct Forwarder<C, B> { | ||||
|   #[cfg(feature = "cache")] | ||||
|   cache: Option<RpxyCache>, | ||||
|   inner: Client<C, B>, | ||||
|   inner_h2: Client<C, B>, // `h2c` or http/2-only client is defined separately
 | ||||
| } | ||||
| 
 | ||||
| #[async_trait] | ||||
| impl<C, B1> ForwardRequest<B1, ResponseBody> for Forwarder<C, B1> | ||||
| where | ||||
|   C: Send + Sync + Connect + Clone + 'static, | ||||
|   B1: Body + Send + Sync + Unpin + 'static, | ||||
|   <B1 as Body>::Data: Send, | ||||
|   <B1 as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>, | ||||
| { | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   async fn request(&self, req: Request<B1>) -> Result<Response<ResponseBody>, Self::Error> { | ||||
|     // TODO: cache handling
 | ||||
|     #[cfg(feature = "cache")] | ||||
|     { | ||||
|       let mut synth_req = None; | ||||
|       if self.cache.is_some() { | ||||
|         // try reading from cache
 | ||||
|         if let Some(cached_response) = self.cache.as_ref().unwrap().get(&req).await { | ||||
|           // if found, return it as response.
 | ||||
|           info!("Cache hit - Return from cache"); | ||||
|           return Ok(cached_response); | ||||
|         }; | ||||
| 
 | ||||
|         // Synthetic request copy used just for caching (cannot clone request object...)
 | ||||
|         synth_req = Some(build_synth_req_for_cache(&req)); | ||||
|       } | ||||
|       let res = self.request_directly(req).await; | ||||
| 
 | ||||
|       if self.cache.is_none() { | ||||
|         return res.map(|inner| inner.map(ResponseBody::Incoming)); | ||||
|       } | ||||
| 
 | ||||
|       // check cacheability and store it if cacheable
 | ||||
|       let Ok(Some(cache_policy)) = get_policy_if_cacheable(synth_req.as_ref(), res.as_ref().ok()) else { | ||||
|         return res.map(|inner| inner.map(ResponseBody::Incoming)); | ||||
|       }; | ||||
|       let (parts, body) = res.unwrap().into_parts(); | ||||
| 
 | ||||
|       // Get streamed body without waiting for the arrival of the body,
 | ||||
|       // which is done simultaneously with caching.
 | ||||
|       let stream_body = self | ||||
|         .cache | ||||
|         .as_ref() | ||||
|         .unwrap() | ||||
|         .put(synth_req.unwrap().uri(), body, &cache_policy) | ||||
|         .await?; | ||||
| 
 | ||||
|       // response with body being cached in background
 | ||||
|       let new_res = Response::from_parts(parts, ResponseBody::Streamed(stream_body)); | ||||
|       Ok(new_res) | ||||
|     } | ||||
| 
 | ||||
|     // No cache handling
 | ||||
|     #[cfg(not(feature = "cache"))] | ||||
|     { | ||||
|       self | ||||
|         .request_directly(req) | ||||
|         .await | ||||
|         .map(|inner| inner.map(ResponseBody::Incoming)) | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl<C, B1> Forwarder<C, B1> | ||||
| where | ||||
|   C: Send + Sync + Connect + Clone + 'static, | ||||
|   B1: Body + Send + Unpin + 'static, | ||||
|   <B1 as Body>::Data: Send, | ||||
|   <B1 as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>, | ||||
| { | ||||
|   async fn request_directly(&self, req: Request<B1>) -> RpxyResult<Response<Incoming>> { | ||||
|     // TODO: This 'match' condition is always evaluated at every 'request' invocation. So, it is inefficient.
 | ||||
|     // Needs to be reconsidered. Currently, this is a kind of work around.
 | ||||
|     // This possibly relates to https://github.com/hyperium/hyper/issues/2417.
 | ||||
|     match req.version() { | ||||
|       Version::HTTP_2 => self.inner_h2.request(req).await, // handles `h2c` requests
 | ||||
|       _ => self.inner.request(req).await, | ||||
|     } | ||||
|     .map_err(|e| RpxyError::FailedToFetchFromUpstream(e.to_string())) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(not(any(feature = "native-tls-backend", feature = "rustls-backend")))] | ||||
| impl<B> Forwarder<HttpConnector, B> | ||||
| where | ||||
|   B: Body + Send + Unpin + 'static, | ||||
|   <B as Body>::Data: Send, | ||||
|   <B as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>, | ||||
| { | ||||
|   /// Build inner client with http
 | ||||
|   pub async fn try_new(_globals: &Arc<Globals>) -> RpxyResult<Self> { | ||||
|     warn!( | ||||
|       " | ||||
| -------------------------------------------------------------------------------------------------- | ||||
| Request forwarder is working without TLS support!!! | ||||
| We recommend to use this just for testing. | ||||
| Please enable native-tls-backend or rustls-backend feature to enable TLS support. | ||||
| --------------------------------------------------------------------------------------------------" | ||||
|     ); | ||||
|     let executor = LocalExecutor::new(_globals.runtime_handle.clone()); | ||||
|     let mut http = HttpConnector::new(); | ||||
|     http.set_reuse_address(true); | ||||
|     let inner = Client::builder(executor).build::<_, B>(http); | ||||
|     let inner_h2 = inner.clone(); | ||||
| 
 | ||||
|     Ok(Self { | ||||
|       inner, | ||||
|       inner_h2, | ||||
|       #[cfg(feature = "cache")] | ||||
|       cache: RpxyCache::new(_globals).await, | ||||
|     }) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(all(feature = "native-tls-backend", not(feature = "rustls-backend")))] | ||||
| /// Build forwarder with hyper-tls (native-tls)
 | ||||
| impl<B1> Forwarder<hyper_tls::HttpsConnector<HttpConnector>, B1> | ||||
| where | ||||
|   B1: Body + Send + Unpin + 'static, | ||||
|   <B1 as Body>::Data: Send, | ||||
|   <B1 as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>, | ||||
| { | ||||
|   /// Build forwarder
 | ||||
|   pub async fn try_new(_globals: &Arc<Globals>) -> RpxyResult<Self> { | ||||
|     // build hyper client with hyper-tls
 | ||||
|     info!("Native TLS support is enabled for the connection to backend applications"); | ||||
|     let executor = LocalExecutor::new(_globals.runtime_handle.clone()); | ||||
| 
 | ||||
|     let try_build_connector = |alpns: &[&str]| { | ||||
|       hyper_tls::native_tls::TlsConnector::builder() | ||||
|         .request_alpns(alpns) | ||||
|         .build() | ||||
|         .map_err(|e| RpxyError::FailedToBuildForwarder(e.to_string())) | ||||
|         .map(|tls| { | ||||
|           let mut http = HttpConnector::new(); | ||||
|           http.enforce_http(false); | ||||
|           http.set_reuse_address(true); | ||||
|           http.set_keepalive(Some(_globals.proxy_config.upstream_idle_timeout)); | ||||
|           hyper_tls::HttpsConnector::from((http, tls.into())) | ||||
|         }) | ||||
|     }; | ||||
| 
 | ||||
|     let connector = try_build_connector(&["h2", "http/1.1"])?; | ||||
|     let inner = Client::builder(executor.clone()).build::<_, B1>(connector); | ||||
| 
 | ||||
|     let connector_h2 = try_build_connector(&["h2"])?; | ||||
|     let inner_h2 = Client::builder(executor.clone()) | ||||
|       .http2_only(true) | ||||
|       .build::<_, B1>(connector_h2); | ||||
| 
 | ||||
|     Ok(Self { | ||||
|       inner, | ||||
|       inner_h2, | ||||
|       #[cfg(feature = "cache")] | ||||
|       cache: RpxyCache::new(_globals).await, | ||||
|     }) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(feature = "rustls-backend")] | ||||
| /// Build forwarder with hyper-rustls (rustls)
 | ||||
| impl<B1> Forwarder<HttpConnector, B1> | ||||
| where | ||||
|   B1: Body + Send + Unpin + 'static, | ||||
|   <B1 as Body>::Data: Send, | ||||
|   <B1 as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>, | ||||
| { | ||||
|   /// Build forwarder
 | ||||
|   pub async fn try_new(_globals: &Arc<Globals>) -> RpxyResult<Self> { | ||||
|     todo!("Not implemented yet. Please use native-tls-backend feature for now."); | ||||
|     // #[cfg(feature = "native-roots")]
 | ||||
|     // let builder = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots();
 | ||||
|     // #[cfg(feature = "native-roots")]
 | ||||
|     // let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots();
 | ||||
|     // #[cfg(feature = "native-roots")]
 | ||||
|     // info!("Native cert store is used for the connection to backend applications");
 | ||||
| 
 | ||||
|     // #[cfg(not(feature = "native-roots"))]
 | ||||
|     // let builder = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots();
 | ||||
|     // #[cfg(not(feature = "native-roots"))]
 | ||||
|     // let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots();
 | ||||
|     // #[cfg(not(feature = "native-roots"))]
 | ||||
|     // info!("Mozilla WebPKI root certs is used for the connection to backend applications");
 | ||||
| 
 | ||||
|     // let connector = builder.https_or_http().enable_http1().enable_http2().build();
 | ||||
|     // let connector_h2 = builder_h2.https_or_http().enable_http2().build();
 | ||||
| 
 | ||||
|     // let inner = Client::builder().build::<_, Body>(connector);
 | ||||
|     // let inner_h2 = Client::builder().http2_only(true).build::<_, Body>(connector_h2);
 | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(feature = "cache")] | ||||
| /// Build synthetic request to cache
 | ||||
| fn build_synth_req_for_cache<T>(req: &Request<T>) -> Request<()> { | ||||
|   let mut builder = Request::builder() | ||||
|     .method(req.method()) | ||||
|     .uri(req.uri()) | ||||
|     .version(req.version()); | ||||
|   // TODO: omit extensions. is this approach correct?
 | ||||
|   for (header_key, header_value) in req.headers() { | ||||
|     builder = builder.header(header_key, header_value); | ||||
|   } | ||||
|   builder.body(()).unwrap() | ||||
| } | ||||
							
								
								
									
										11
									
								
								rpxy-lib/src/forwarder/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								rpxy-lib/src/forwarder/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,11 @@ | |||
| #[cfg(feature = "cache")] | ||||
| mod cache; | ||||
| mod client; | ||||
| 
 | ||||
| use crate::hyper_ext::body::RequestBody; | ||||
| 
 | ||||
| pub(crate) type Forwarder<C> = client::Forwarder<C, RequestBody>; | ||||
| pub(crate) use client::ForwardRequest; | ||||
| 
 | ||||
| #[cfg(feature = "cache")] | ||||
| pub(crate) use cache::CacheError; | ||||
|  | @ -1,50 +1,42 @@ | |||
| use crate::{ | ||||
|   backend::{ | ||||
|     Backend, BackendBuilder, Backends, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption, | ||||
|   }, | ||||
|   certs::CryptoSource, | ||||
|   constants::*, | ||||
|   error::RpxyError, | ||||
|   log::*, | ||||
|   utils::{BytesName, PathNameBytesExp}, | ||||
|   count::RequestCount, | ||||
|   crypto::{CryptoSource, ServerCryptoBase}, | ||||
| }; | ||||
| use rustc_hash::FxHashMap as HashMap; | ||||
| use std::net::SocketAddr; | ||||
| use std::sync::{ | ||||
|   atomic::{AtomicUsize, Ordering}, | ||||
|   Arc, | ||||
| }; | ||||
| use tokio::time::Duration; | ||||
| use hot_reload::ReloaderReceiver; | ||||
| use std::{net::SocketAddr, sync::Arc, time::Duration}; | ||||
| 
 | ||||
| /// Global object containing proxy configurations and shared object like counters.
 | ||||
| /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks.
 | ||||
| pub struct Globals<T> | ||||
| where | ||||
|   T: CryptoSource, | ||||
| { | ||||
| pub struct Globals { | ||||
|   /// Configuration parameters for proxy transport and request handlers
 | ||||
|   pub proxy_config: ProxyConfig, // TODO: proxy configはarcに包んでこいつだけ使いまわせばいいように変えていく。backendsも?
 | ||||
| 
 | ||||
|   /// Backend application objects to which http request handler forward incoming requests
 | ||||
|   pub backends: Backends<T>, | ||||
| 
 | ||||
|   pub proxy_config: ProxyConfig, | ||||
|   /// Shared context - Counter for serving requests
 | ||||
|   pub request_count: RequestCount, | ||||
| 
 | ||||
|   /// Shared context - Async task runtime handler
 | ||||
|   pub runtime_handle: tokio::runtime::Handle, | ||||
|   /// Shared context - Notify object to stop async tasks
 | ||||
|   pub term_notify: Option<Arc<tokio::sync::Notify>>, | ||||
|   /// Shared context - Certificate reloader service receiver
 | ||||
|   pub cert_reloader_rx: Option<ReloaderReceiver<ServerCryptoBase>>, | ||||
| } | ||||
| 
 | ||||
| /// Configuration parameters for proxy transport and request handlers
 | ||||
| #[derive(PartialEq, Eq, Clone)] | ||||
| pub struct ProxyConfig { | ||||
|   pub listen_sockets: Vec<SocketAddr>, // when instantiate server
 | ||||
|   pub http_port: Option<u16>,          // when instantiate server
 | ||||
|   pub https_port: Option<u16>,         // when instantiate server
 | ||||
|   pub tcp_listen_backlog: u32,         // when instantiate server
 | ||||
|   /// listen socket addresses
 | ||||
|   pub listen_sockets: Vec<SocketAddr>, | ||||
|   /// http port
 | ||||
|   pub http_port: Option<u16>, | ||||
|   /// https port
 | ||||
|   pub https_port: Option<u16>, | ||||
|   /// tcp listen backlog
 | ||||
|   pub tcp_listen_backlog: u32, | ||||
| 
 | ||||
|   pub proxy_timeout: Duration,    // when serving requests at Proxy
 | ||||
|   pub upstream_timeout: Duration, // when serving requests at Handler
 | ||||
|   /// Idle timeout as an HTTP server, used as the keep alive interval and timeout for reading request header
 | ||||
|   pub proxy_idle_timeout: Duration, | ||||
|   /// Idle timeout as an HTTP client, used as the keep alive interval for upstream connections
 | ||||
|   pub upstream_idle_timeout: Duration, | ||||
| 
 | ||||
|   pub max_clients: usize,          // when serving requests
 | ||||
|   pub max_concurrent_streams: u32, // when instantiate server
 | ||||
|  | @ -90,8 +82,8 @@ impl Default for ProxyConfig { | |||
|       tcp_listen_backlog: TCP_LISTEN_BACKLOG, | ||||
| 
 | ||||
|       // TODO: Reconsider each timeout values
 | ||||
|       proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC), | ||||
|       upstream_timeout: Duration::from_secs(UPSTREAM_TIMEOUT_SEC), | ||||
|       proxy_idle_timeout: Duration::from_secs(PROXY_IDLE_TIMEOUT_SEC), | ||||
|       upstream_idle_timeout: Duration::from_secs(UPSTREAM_IDLE_TIMEOUT_SEC), | ||||
| 
 | ||||
|       max_clients: MAX_CLIENTS, | ||||
|       max_concurrent_streams: MAX_CONCURRENT_STREAMS, | ||||
|  | @ -137,44 +129,6 @@ where | |||
|   pub inner: Vec<AppConfig<T>>, | ||||
|   pub default_app: Option<String>, | ||||
| } | ||||
| impl<T> TryInto<Backends<T>> for AppConfigList<T> | ||||
| where | ||||
|   T: CryptoSource + Clone, | ||||
| { | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn try_into(self) -> Result<Backends<T>, Self::Error> { | ||||
|     let mut backends = Backends::new(); | ||||
|     for app_config in self.inner.iter() { | ||||
|       let backend = app_config.try_into()?; | ||||
|       backends | ||||
|         .apps | ||||
|         .insert(app_config.server_name.clone().to_server_name_vec(), backend); | ||||
|       info!( | ||||
|         "Registering application {} ({})", | ||||
|         &app_config.server_name, &app_config.app_name | ||||
|       ); | ||||
|     } | ||||
| 
 | ||||
|     // default backend application for plaintext http requests
 | ||||
|     if let Some(d) = self.default_app { | ||||
|       let d_sn: Vec<&str> = backends | ||||
|         .apps | ||||
|         .iter() | ||||
|         .filter(|(_k, v)| v.app_name == d) | ||||
|         .map(|(_, v)| v.server_name.as_ref()) | ||||
|         .collect(); | ||||
|       if !d_sn.is_empty() { | ||||
|         info!( | ||||
|           "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", | ||||
|           d, d_sn[0] | ||||
|         ); | ||||
|         backends.default_server_name_bytes = Some(d_sn[0].to_server_name_vec()); | ||||
|       } | ||||
|     } | ||||
|     Ok(backends) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// Configuration parameters for single backend application
 | ||||
| #[derive(PartialEq, Eq, Clone)] | ||||
|  | @ -187,77 +141,6 @@ where | |||
|   pub reverse_proxy: Vec<ReverseProxyConfig>, | ||||
|   pub tls: Option<TlsConfig<T>>, | ||||
| } | ||||
| impl<T> TryInto<Backend<T>> for &AppConfig<T> | ||||
| where | ||||
|   T: CryptoSource + Clone, | ||||
| { | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn try_into(self) -> Result<Backend<T>, Self::Error> { | ||||
|     // backend builder
 | ||||
|     let mut backend_builder = BackendBuilder::default(); | ||||
|     // reverse proxy settings
 | ||||
|     let reverse_proxy = self.try_into()?; | ||||
| 
 | ||||
|     backend_builder | ||||
|       .app_name(self.app_name.clone()) | ||||
|       .server_name(self.server_name.clone()) | ||||
|       .reverse_proxy(reverse_proxy); | ||||
| 
 | ||||
|     // TLS settings and build backend instance
 | ||||
|     let backend = if self.tls.is_none() { | ||||
|       backend_builder.build().map_err(RpxyError::BackendBuild)? | ||||
|     } else { | ||||
|       let tls = self.tls.as_ref().unwrap(); | ||||
| 
 | ||||
|       backend_builder | ||||
|         .https_redirection(Some(tls.https_redirection)) | ||||
|         .crypto_source(Some(tls.inner.clone())) | ||||
|         .build()? | ||||
|     }; | ||||
|     Ok(backend) | ||||
|   } | ||||
| } | ||||
| impl<T> TryInto<ReverseProxy> for &AppConfig<T> | ||||
| where | ||||
|   T: CryptoSource + Clone, | ||||
| { | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn try_into(self) -> Result<ReverseProxy, Self::Error> { | ||||
|     let mut upstream: HashMap<PathNameBytesExp, UpstreamGroup> = HashMap::default(); | ||||
| 
 | ||||
|     self.reverse_proxy.iter().for_each(|rpo| { | ||||
|       let upstream_vec: Vec<Upstream> = rpo.upstream.iter().map(|x| x.try_into().unwrap()).collect(); | ||||
|       // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap());
 | ||||
|       // let lb_upstream_num = vec_upstream.len();
 | ||||
|       let elem = UpstreamGroupBuilder::default() | ||||
|         .upstream(&upstream_vec) | ||||
|         .path(&rpo.path) | ||||
|         .replace_path(&rpo.replace_path) | ||||
|         .lb(&rpo.load_balance, &upstream_vec, &self.server_name, &rpo.path) | ||||
|         .opts(&rpo.upstream_options) | ||||
|         .build() | ||||
|         .unwrap(); | ||||
| 
 | ||||
|       upstream.insert(elem.path.clone(), elem); | ||||
|     }); | ||||
|     if self.reverse_proxy.iter().filter(|rpo| rpo.path.is_none()).count() >= 2 { | ||||
|       error!("Multiple default reverse proxy setting"); | ||||
|       return Err(RpxyError::ConfigBuild("Invalid reverse proxy setting")); | ||||
|     } | ||||
| 
 | ||||
|     if !(upstream.iter().all(|(_, elem)| { | ||||
|       !(elem.opts.contains(&UpstreamOption::ForceHttp11Upstream) | ||||
|         && elem.opts.contains(&UpstreamOption::ForceHttp2Upstream)) | ||||
|     })) { | ||||
|       error!("Either one of force_http11 or force_http2 can be enabled"); | ||||
|       return Err(RpxyError::ConfigBuild("Invalid upstream option setting")); | ||||
|     } | ||||
| 
 | ||||
|     Ok(ReverseProxy { upstream }) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// Configuration parameters for single reverse proxy corresponding to the path
 | ||||
| #[derive(PartialEq, Eq, Clone)] | ||||
|  | @ -272,16 +155,7 @@ pub struct ReverseProxyConfig { | |||
| /// Configuration parameters for single upstream destination from a reverse proxy
 | ||||
| #[derive(PartialEq, Eq, Clone)] | ||||
| pub struct UpstreamUri { | ||||
|   pub inner: hyper::Uri, | ||||
| } | ||||
| impl TryInto<Upstream> for &UpstreamUri { | ||||
|   type Error = anyhow::Error; | ||||
| 
 | ||||
|   fn try_into(self) -> std::result::Result<Upstream, Self::Error> { | ||||
|     Ok(Upstream { | ||||
|       uri: self.inner.clone(), | ||||
|     }) | ||||
|   } | ||||
|   pub inner: http::Uri, | ||||
| } | ||||
| 
 | ||||
| /// Configuration parameters on TLS for a single backend application
 | ||||
|  | @ -293,30 +167,3 @@ where | |||
|   pub inner: T, | ||||
|   pub https_redirection: bool, | ||||
| } | ||||
| 
 | ||||
| #[derive(Debug, Clone, Default)] | ||||
| /// Counter for serving requests
 | ||||
| pub struct RequestCount(Arc<AtomicUsize>); | ||||
| 
 | ||||
| impl RequestCount { | ||||
|   pub fn current(&self) -> usize { | ||||
|     self.0.load(Ordering::Relaxed) | ||||
|   } | ||||
| 
 | ||||
|   pub fn increment(&self) -> usize { | ||||
|     self.0.fetch_add(1, Ordering::Relaxed) | ||||
|   } | ||||
| 
 | ||||
|   pub fn decrement(&self) -> usize { | ||||
|     let mut count; | ||||
|     while { | ||||
|       count = self.0.load(Ordering::Relaxed); | ||||
|       count > 0 | ||||
|         && self | ||||
|           .0 | ||||
|           .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) | ||||
|           != Ok(count) | ||||
|     } {} | ||||
|     count | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -1,380 +0,0 @@ | |||
| // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
 | ||||
| use super::{ | ||||
|   forwarder::{ForwardRequest, Forwarder}, | ||||
|   utils_headers::*, | ||||
|   utils_request::*, | ||||
|   utils_synth_response::*, | ||||
|   HandlerContext, | ||||
| }; | ||||
| use crate::{ | ||||
|   backend::{Backend, UpstreamGroup}, | ||||
|   certs::CryptoSource, | ||||
|   constants::RESPONSE_HEADER_SERVER, | ||||
|   error::*, | ||||
|   globals::Globals, | ||||
|   log::*, | ||||
|   utils::ServerNameBytesExp, | ||||
| }; | ||||
| use derive_builder::Builder; | ||||
| use hyper::{ | ||||
|   client::connect::Connect, | ||||
|   header::{self, HeaderValue}, | ||||
|   http::uri::Scheme, | ||||
|   Body, Request, Response, StatusCode, Uri, Version, | ||||
| }; | ||||
| use std::{net::SocketAddr, sync::Arc}; | ||||
| use tokio::{io::copy_bidirectional, time::timeout}; | ||||
| 
 | ||||
| #[derive(Clone, Builder)] | ||||
| /// HTTP message handler for requests from clients and responses from backend applications,
 | ||||
| /// responsible to manipulate and forward messages to upstream backends and downstream clients.
 | ||||
| pub struct HttpMessageHandler<T, U> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
|   U: CryptoSource + Clone, | ||||
| { | ||||
|   forwarder: Arc<Forwarder<T>>, | ||||
|   globals: Arc<Globals<U>>, | ||||
| } | ||||
| 
 | ||||
| impl<T, U> HttpMessageHandler<T, U> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
|   U: CryptoSource + Clone, | ||||
| { | ||||
|   /// Return with an arbitrary status code of error and log message
 | ||||
|   fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result<Response<Body>> { | ||||
|     log_data.status_code(&status_code).output(); | ||||
|     http_error(status_code) | ||||
|   } | ||||
| 
 | ||||
|   /// Handle incoming request message from a client
 | ||||
|   pub async fn handle_request( | ||||
|     &self, | ||||
|     mut req: Request<Body>, | ||||
|     client_addr: SocketAddr, // アクセス制御用
 | ||||
|     listen_addr: SocketAddr, | ||||
|     tls_enabled: bool, | ||||
|     tls_server_name: Option<ServerNameBytesExp>, | ||||
|   ) -> Result<Response<Body>> { | ||||
|     ////////
 | ||||
|     let mut log_data = MessageLog::from(&req); | ||||
|     log_data.client_addr(&client_addr); | ||||
|     //////
 | ||||
| 
 | ||||
|     // Here we start to handle with server_name
 | ||||
|     let server_name = if let Ok(v) = req.parse_host() { | ||||
|       ServerNameBytesExp::from(v) | ||||
|     } else { | ||||
|       return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); | ||||
|     }; | ||||
|     // check consistency of between TLS SNI and HOST/Request URI Line.
 | ||||
|     #[allow(clippy::collapsible_if)] | ||||
|     if tls_enabled && self.globals.proxy_config.sni_consistency { | ||||
|       if server_name != tls_server_name.unwrap_or_default() { | ||||
|         return self.return_with_error_log(StatusCode::MISDIRECTED_REQUEST, &mut log_data); | ||||
|       } | ||||
|     } | ||||
|     // Find backend application for given server_name, and drop if incoming request is invalid as request.
 | ||||
|     let backend = match self.globals.backends.apps.get(&server_name) { | ||||
|       Some(be) => be, | ||||
|       None => { | ||||
|         let Some(default_server_name) = &self.globals.backends.default_server_name_bytes else { | ||||
|           return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); | ||||
|         }; | ||||
|         debug!("Serving by default app"); | ||||
|         self.globals.backends.apps.get(default_server_name).unwrap() | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     // Redirect to https if !tls_enabled and redirect_to_https is true
 | ||||
|     if !tls_enabled && backend.https_redirection.unwrap_or(false) { | ||||
|       debug!("Redirect to secure connection: {}", &backend.server_name); | ||||
|       log_data.status_code(&StatusCode::PERMANENT_REDIRECT).output(); | ||||
|       return secure_redirection(&backend.server_name, self.globals.proxy_config.https_port, &req); | ||||
|     } | ||||
| 
 | ||||
|     // Find reverse proxy for given path and choose one of upstream host
 | ||||
|     // Longest prefix match
 | ||||
|     let path = req.uri().path(); | ||||
|     let Some(upstream_group) = backend.reverse_proxy.get(path) else { | ||||
|       return self.return_with_error_log(StatusCode::NOT_FOUND, &mut log_data) | ||||
|     }; | ||||
| 
 | ||||
|     // Upgrade in request header
 | ||||
|     let upgrade_in_request = extract_upgrade(req.headers()); | ||||
|     let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>(); | ||||
| 
 | ||||
|     // Build request from destination information
 | ||||
|     let _context = match self.generate_request_forwarded( | ||||
|       &client_addr, | ||||
|       &listen_addr, | ||||
|       &mut req, | ||||
|       &upgrade_in_request, | ||||
|       upstream_group, | ||||
|       tls_enabled, | ||||
|     ) { | ||||
|       Err(e) => { | ||||
|         error!("Failed to generate destination uri for reverse proxy: {}", e); | ||||
|         return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); | ||||
|       } | ||||
|       Ok(v) => v, | ||||
|     }; | ||||
|     debug!("Request to be forwarded: {:?}", req); | ||||
|     log_data.xff(&req.headers().get("x-forwarded-for")); | ||||
|     log_data.upstream(req.uri()); | ||||
|     //////
 | ||||
| 
 | ||||
|     // Forward request to a chosen backend
 | ||||
|     let mut res_backend = { | ||||
|       let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else { | ||||
|         return self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data); | ||||
|       }; | ||||
|       match result { | ||||
|         Ok(res) => res, | ||||
|         Err(e) => { | ||||
|           error!("Failed to get response from backend: {}", e); | ||||
|           return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); | ||||
|         } | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     // Process reverse proxy context generated during the forwarding request generation.
 | ||||
|     #[cfg(feature = "sticky-cookie")] | ||||
|     if let Some(context_from_lb) = _context.context_lb { | ||||
|       let res_headers = res_backend.headers_mut(); | ||||
|       if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { | ||||
|         error!("Failed to append context to the response given from backend: {}", e); | ||||
|         return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { | ||||
|       // Generate response to client
 | ||||
|       if self.generate_response_forwarded(&mut res_backend, backend).is_err() { | ||||
|         return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); | ||||
|       } | ||||
|       log_data.status_code(&res_backend.status()).output(); | ||||
|       return Ok(res_backend); | ||||
|     } | ||||
| 
 | ||||
|     // Handle StatusCode::SWITCHING_PROTOCOLS in response
 | ||||
|     let upgrade_in_response = extract_upgrade(res_backend.headers()); | ||||
|     let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) | ||||
|     { | ||||
|       u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() | ||||
|     } else { | ||||
|       false | ||||
|     }; | ||||
|     if !should_upgrade { | ||||
|       error!( | ||||
|         "Backend tried to switch to protocol {:?} when {:?} was requested", | ||||
|         upgrade_in_response, upgrade_in_request | ||||
|       ); | ||||
|       return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); | ||||
|     } | ||||
|     let Some(request_upgraded) = request_upgraded else { | ||||
|       error!("Request does not have an upgrade extension"); | ||||
|       return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); | ||||
|     }; | ||||
|     let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else { | ||||
|       error!("Response does not have an upgrade extension"); | ||||
|       return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); | ||||
|     }; | ||||
| 
 | ||||
|     self.globals.runtime_handle.spawn(async move { | ||||
|       let mut response_upgraded = onupgrade.await.map_err(|e| { | ||||
|         error!("Failed to upgrade response: {}", e); | ||||
|         RpxyError::Hyper(e) | ||||
|       })?; | ||||
|       let mut request_upgraded = request_upgraded.await.map_err(|e| { | ||||
|         error!("Failed to upgrade request: {}", e); | ||||
|         RpxyError::Hyper(e) | ||||
|       })?; | ||||
|       copy_bidirectional(&mut response_upgraded, &mut request_upgraded) | ||||
|         .await | ||||
|         .map_err(|e| { | ||||
|           error!("Coping between upgraded connections failed: {}", e); | ||||
|           RpxyError::Io(e) | ||||
|         })?; | ||||
|       Ok(()) as Result<()> | ||||
|     }); | ||||
|     log_data.status_code(&res_backend.status()).output(); | ||||
|     Ok(res_backend) | ||||
|   } | ||||
| 
 | ||||
|   ////////////////////////////////////////////////////
 | ||||
|   // Functions to generate messages
 | ||||
|   ////////////////////////////////////////////////////
 | ||||
| 
 | ||||
|   /// Manipulate a response message sent from a backend application to forward downstream to a client.
 | ||||
|   fn generate_response_forwarded<B>(&self, response: &mut Response<B>, chosen_backend: &Backend<U>) -> Result<()> | ||||
|   where | ||||
|     B: core::fmt::Debug, | ||||
|   { | ||||
|     let headers = response.headers_mut(); | ||||
|     remove_connection_header(headers); | ||||
|     remove_hop_header(headers); | ||||
|     add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; | ||||
| 
 | ||||
|     #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|     { | ||||
|       // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled
 | ||||
|       // TODO: This is a workaround for avoiding a client authentication in HTTP/3
 | ||||
|       if self.globals.proxy_config.http3 | ||||
|         && chosen_backend | ||||
|           .crypto_source | ||||
|           .as_ref() | ||||
|           .is_some_and(|v| !v.is_mutual_tls()) | ||||
|       { | ||||
|         if let Some(port) = self.globals.proxy_config.https_port { | ||||
|           add_header_entry_overwrite_if_exist( | ||||
|             headers, | ||||
|             header::ALT_SVC.as_str(), | ||||
|             format!( | ||||
|               "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", | ||||
|               port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age | ||||
|             ), | ||||
|           )?; | ||||
|         } | ||||
|       } else { | ||||
|         // remove alt-svc to disallow requests via http3
 | ||||
|         headers.remove(header::ALT_SVC.as_str()); | ||||
|       } | ||||
|     } | ||||
|     #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] | ||||
|     { | ||||
|       if let Some(port) = self.globals.proxy_config.https_port { | ||||
|         headers.remove(header::ALT_SVC.as_str()); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   #[allow(clippy::too_many_arguments)] | ||||
|   /// Manipulate a request message sent from a client to forward upstream to a backend application
 | ||||
|   fn generate_request_forwarded<B>( | ||||
|     &self, | ||||
|     client_addr: &SocketAddr, | ||||
|     listen_addr: &SocketAddr, | ||||
|     req: &mut Request<B>, | ||||
|     upgrade: &Option<String>, | ||||
|     upstream_group: &UpstreamGroup, | ||||
|     tls_enabled: bool, | ||||
|   ) -> Result<HandlerContext> { | ||||
|     debug!("Generate request to be forwarded"); | ||||
| 
 | ||||
|     // Add te: trailer if contained in original request
 | ||||
|     let contains_te_trailers = { | ||||
|       if let Some(te) = req.headers().get(header::TE) { | ||||
|         te.as_bytes() | ||||
|           .split(|v| v == &b',' || v == &b' ') | ||||
|           .any(|x| x == "trailers".as_bytes()) | ||||
|       } else { | ||||
|         false | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     let uri = req.uri().to_string(); | ||||
|     let headers = req.headers_mut(); | ||||
|     // delete headers specified in header.connection
 | ||||
|     remove_connection_header(headers); | ||||
|     // delete hop headers including header.connection
 | ||||
|     remove_hop_header(headers); | ||||
|     // X-Forwarded-For
 | ||||
|     add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?; | ||||
| 
 | ||||
|     // Add te: trailer if te_trailer
 | ||||
|     if contains_te_trailers { | ||||
|       headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap()); | ||||
|     } | ||||
| 
 | ||||
|     // add "host" header of original server_name if not exist (default)
 | ||||
|     if req.headers().get(header::HOST).is_none() { | ||||
|       let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); | ||||
|       req | ||||
|         .headers_mut() | ||||
|         .insert(header::HOST, HeaderValue::from_str(&org_host)?); | ||||
|     }; | ||||
| 
 | ||||
|     /////////////////////////////////////////////
 | ||||
|     // Fix unique upstream destination since there could be multiple ones.
 | ||||
|     #[cfg(feature = "sticky-cookie")] | ||||
|     let (upstream_chosen_opt, context_from_lb) = { | ||||
|       let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { | ||||
|         takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? | ||||
|       } else { | ||||
|         None | ||||
|       }; | ||||
|       upstream_group.get(&context_to_lb) | ||||
|     }; | ||||
|     #[cfg(not(feature = "sticky-cookie"))] | ||||
|     let (upstream_chosen_opt, _) = upstream_group.get(&None); | ||||
| 
 | ||||
|     let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; | ||||
|     let context = HandlerContext { | ||||
|       #[cfg(feature = "sticky-cookie")] | ||||
|       context_lb: context_from_lb, | ||||
|       #[cfg(not(feature = "sticky-cookie"))] | ||||
|       context_lb: None, | ||||
|     }; | ||||
|     /////////////////////////////////////////////
 | ||||
| 
 | ||||
|     // apply upstream-specific headers given in upstream_option
 | ||||
|     let headers = req.headers_mut(); | ||||
|     apply_upstream_options_to_header(headers, client_addr, upstream_group, &upstream_chosen.uri)?; | ||||
| 
 | ||||
|     // update uri in request
 | ||||
|     if !(upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some()) { | ||||
|       return Err(RpxyError::Handler("Upstream uri `scheme` and `authority` is broken")); | ||||
|     }; | ||||
|     let new_uri = Uri::builder() | ||||
|       .scheme(upstream_chosen.uri.scheme().unwrap().as_str()) | ||||
|       .authority(upstream_chosen.uri.authority().unwrap().as_str()); | ||||
|     let org_pq = match req.uri().path_and_query() { | ||||
|       Some(pq) => pq.to_string(), | ||||
|       None => "/".to_string(), | ||||
|     } | ||||
|     .into_bytes(); | ||||
| 
 | ||||
|     // replace some parts of path if opt_replace_path is enabled for chosen upstream
 | ||||
|     let new_pq = match &upstream_group.replace_path { | ||||
|       Some(new_path) => { | ||||
|         let matched_path: &[u8] = upstream_group.path.as_ref(); | ||||
|         if matched_path.is_empty() || org_pq.len() < matched_path.len() { | ||||
|           return Err(RpxyError::Handler("Upstream uri `path and query` is broken")); | ||||
|         }; | ||||
|         let mut new_pq = Vec::<u8>::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); | ||||
|         new_pq.extend_from_slice(new_path.as_ref()); | ||||
|         new_pq.extend_from_slice(&org_pq[matched_path.len()..]); | ||||
|         new_pq | ||||
|       } | ||||
|       None => org_pq, | ||||
|     }; | ||||
|     *req.uri_mut() = new_uri.path_and_query(new_pq).build()?; | ||||
| 
 | ||||
|     // upgrade
 | ||||
|     if let Some(v) = upgrade { | ||||
|       req.headers_mut().insert(header::UPGRADE, v.parse()?); | ||||
|       req | ||||
|         .headers_mut() | ||||
|         .insert(header::CONNECTION, HeaderValue::from_str("upgrade")?); | ||||
|     } | ||||
| 
 | ||||
|     // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3
 | ||||
|     if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { | ||||
|       // Change version to http/1.1 when destination scheme is http
 | ||||
|       debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); | ||||
|       *req.version_mut() = Version::HTTP_11; | ||||
|     } else if req.version() == Version::HTTP_3 { | ||||
|       // HTTP/3 is always https
 | ||||
|       debug!("HTTP/3 is currently unsupported for request to upstream."); | ||||
|       *req.version_mut() = Version::HTTP_2; | ||||
|     } | ||||
| 
 | ||||
|     apply_upstream_options_to_request_line(req, upstream_group)?; | ||||
| 
 | ||||
|     Ok(context) | ||||
|   } | ||||
| } | ||||
							
								
								
									
										370
									
								
								rpxy-lib/src/hyper_ext/body_incoming_like.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										370
									
								
								rpxy-lib/src/hyper_ext/body_incoming_like.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,370 @@ | |||
| use super::watch; | ||||
| use crate::error::*; | ||||
| use futures_channel::{mpsc, oneshot}; | ||||
| use futures_util::{stream::FusedStream, Future, Stream}; | ||||
| use http::HeaderMap; | ||||
| use hyper::body::{Body, Bytes, Frame, SizeHint}; | ||||
| use std::{ | ||||
|   pin::Pin, | ||||
|   task::{Context, Poll}, | ||||
| }; | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////
 | ||||
| /// Incoming like body to handle incoming request body
 | ||||
| /// ported from https://github.com/hyperium/hyper/blob/master/src/body/incoming.rs
 | ||||
| pub struct IncomingLike { | ||||
|   content_length: DecodedLength, | ||||
|   want_tx: watch::Sender, | ||||
|   data_rx: mpsc::Receiver<Result<Bytes, RpxyError>>, | ||||
|   trailers_rx: oneshot::Receiver<HeaderMap>, | ||||
| } | ||||
| 
 | ||||
| macro_rules! ready { | ||||
|   ($e:expr) => { | ||||
|     match $e { | ||||
|       Poll::Ready(v) => v, | ||||
|       Poll::Pending => return Poll::Pending, | ||||
|     } | ||||
|   }; | ||||
| } | ||||
| 
 | ||||
| type BodySender = mpsc::Sender<Result<Bytes, RpxyError>>; | ||||
| type TrailersSender = oneshot::Sender<HeaderMap>; | ||||
| 
 | ||||
| const MAX_LEN: u64 = std::u64::MAX - 2; | ||||
| #[derive(Clone, Copy, PartialEq, Eq)] | ||||
| pub(crate) struct DecodedLength(u64); | ||||
| impl DecodedLength { | ||||
|   pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX); | ||||
|   pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); | ||||
|   pub(crate) const ZERO: DecodedLength = DecodedLength(0); | ||||
| 
 | ||||
|   #[allow(dead_code)] | ||||
|   pub(crate) fn new(len: u64) -> Self { | ||||
|     debug_assert!(len <= MAX_LEN); | ||||
|     DecodedLength(len) | ||||
|   } | ||||
| 
 | ||||
|   pub(crate) fn sub_if(&mut self, amt: u64) { | ||||
|     match *self { | ||||
|       DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), | ||||
|       DecodedLength(ref mut known) => { | ||||
|         *known -= amt; | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   /// Converts to an Option<u64> representing a Known or Unknown length.
 | ||||
|   pub(crate) fn into_opt(self) -> Option<u64> { | ||||
|     match self { | ||||
|       DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => None, | ||||
|       DecodedLength(known) => Some(known), | ||||
|     } | ||||
|   } | ||||
| } | ||||
| pub(crate) struct Sender { | ||||
|   want_rx: watch::Receiver, | ||||
|   data_tx: BodySender, | ||||
|   trailers_tx: Option<TrailersSender>, | ||||
| } | ||||
| 
 | ||||
| const WANT_PENDING: usize = 1; | ||||
| const WANT_READY: usize = 2; | ||||
| 
 | ||||
| impl IncomingLike { | ||||
|   /// Create a `Body` stream with an associated sender half.
 | ||||
|   ///
 | ||||
|   /// Useful when wanting to stream chunks from another thread.
 | ||||
|   #[inline] | ||||
|   #[allow(unused)] | ||||
|   pub(crate) fn channel() -> (Sender, IncomingLike) { | ||||
|     Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false) | ||||
|   } | ||||
| 
 | ||||
|   pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, IncomingLike) { | ||||
|     let (data_tx, data_rx) = mpsc::channel(0); | ||||
|     let (trailers_tx, trailers_rx) = oneshot::channel(); | ||||
| 
 | ||||
|     // If wanter is true, `Sender::poll_ready()` won't becoming ready
 | ||||
|     // until the `Body` has been polled for data once.
 | ||||
|     let want = if wanter { WANT_PENDING } else { WANT_READY }; | ||||
| 
 | ||||
|     let (want_tx, want_rx) = watch::channel(want); | ||||
| 
 | ||||
|     let tx = Sender { | ||||
|       want_rx, | ||||
|       data_tx, | ||||
|       trailers_tx: Some(trailers_tx), | ||||
|     }; | ||||
|     let rx = IncomingLike { | ||||
|       content_length, | ||||
|       want_tx, | ||||
|       data_rx, | ||||
|       trailers_rx, | ||||
|     }; | ||||
| 
 | ||||
|     (tx, rx) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl Body for IncomingLike { | ||||
|   type Data = Bytes; | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn poll_frame( | ||||
|     mut self: Pin<&mut Self>, | ||||
|     cx: &mut Context<'_>, | ||||
|   ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { | ||||
|     self.want_tx.send(WANT_READY); | ||||
| 
 | ||||
|     if !self.data_rx.is_terminated() { | ||||
|       if let Some(chunk) = ready!(Pin::new(&mut self.data_rx).poll_next(cx)?) { | ||||
|         self.content_length.sub_if(chunk.len() as u64); | ||||
|         return Poll::Ready(Some(Ok(Frame::data(chunk)))); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     // check trailers after data is terminated
 | ||||
|     match ready!(Pin::new(&mut self.trailers_rx).poll(cx)) { | ||||
|       Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))), | ||||
|       Err(_) => Poll::Ready(None), | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   fn is_end_stream(&self) -> bool { | ||||
|     self.content_length == DecodedLength::ZERO | ||||
|   } | ||||
| 
 | ||||
|   fn size_hint(&self) -> SizeHint { | ||||
|     macro_rules! opt_len { | ||||
|       ($content_length:expr) => {{ | ||||
|         let mut hint = SizeHint::default(); | ||||
| 
 | ||||
|         if let Some(content_length) = $content_length.into_opt() { | ||||
|           hint.set_exact(content_length); | ||||
|         } | ||||
| 
 | ||||
|         hint | ||||
|       }}; | ||||
|     } | ||||
| 
 | ||||
|     opt_len!(self.content_length) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl Sender { | ||||
|   /// Check to see if this `Sender` can send more data.
 | ||||
|   pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<RpxyResult<()>> { | ||||
|     // Check if the receiver end has tried polling for the body yet
 | ||||
|     ready!(self.poll_want(cx)?); | ||||
|     self | ||||
|       .data_tx | ||||
|       .poll_ready(cx) | ||||
|       .map_err(|_| RpxyError::HyperIncomingLikeNewClosed) | ||||
|   } | ||||
| 
 | ||||
|   fn poll_want(&mut self, cx: &mut Context<'_>) -> Poll<RpxyResult<()>> { | ||||
|     match self.want_rx.load(cx) { | ||||
|       WANT_READY => Poll::Ready(Ok(())), | ||||
|       WANT_PENDING => Poll::Pending, | ||||
|       watch::CLOSED => Poll::Ready(Err(RpxyError::HyperIncomingLikeNewClosed)), | ||||
|       unexpected => unreachable!("want_rx value: {}", unexpected), | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   async fn ready(&mut self) -> RpxyResult<()> { | ||||
|     futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await | ||||
|   } | ||||
| 
 | ||||
|   /// Send data on data channel when it is ready.
 | ||||
|   #[allow(unused)] | ||||
|   pub(crate) async fn send_data(&mut self, chunk: Bytes) -> RpxyResult<()> { | ||||
|     self.ready().await?; | ||||
|     self | ||||
|       .data_tx | ||||
|       .try_send(Ok(chunk)) | ||||
|       .map_err(|_| RpxyError::HyperIncomingLikeNewClosed) | ||||
|   } | ||||
| 
 | ||||
|   /// Send trailers on trailers channel.
 | ||||
|   #[allow(unused)] | ||||
|   pub(crate) async fn send_trailers(&mut self, trailers: HeaderMap) -> RpxyResult<()> { | ||||
|     let tx = match self.trailers_tx.take() { | ||||
|       Some(tx) => tx, | ||||
|       None => return Err(RpxyError::HyperIncomingLikeNewClosed), | ||||
|     }; | ||||
|     tx.send(trailers).map_err(|_| RpxyError::HyperIncomingLikeNewClosed) | ||||
|   } | ||||
| 
 | ||||
|   /// Try to send data on this channel.
 | ||||
|   ///
 | ||||
|   /// # Errors
 | ||||
|   ///
 | ||||
|   /// Returns `Err(Bytes)` if the channel could not (currently) accept
 | ||||
|   /// another `Bytes`.
 | ||||
|   ///
 | ||||
|   /// # Note
 | ||||
|   ///
 | ||||
|   /// This is mostly useful for when trying to send from some other thread
 | ||||
|   /// that doesn't have an async context. If in an async context, prefer
 | ||||
|   /// `send_data()` instead.
 | ||||
|   #[allow(unused)] | ||||
|   pub(crate) fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> { | ||||
|     self | ||||
|       .data_tx | ||||
|       .try_send(Ok(chunk)) | ||||
|       .map_err(|err| err.into_inner().expect("just sent Ok")) | ||||
|   } | ||||
| 
 | ||||
|   #[allow(unused)] | ||||
|   pub(crate) fn abort(mut self) { | ||||
|     self.send_error(RpxyError::HyperNewBodyWriteAborted); | ||||
|   } | ||||
| 
 | ||||
|   pub(crate) fn send_error(&mut self, err: RpxyError) { | ||||
|     let _ = self | ||||
|       .data_tx | ||||
|       // clone so the send works even if buffer is full
 | ||||
|       .clone() | ||||
|       .try_send(Err(err)); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|   use std::mem; | ||||
|   use std::task::Poll; | ||||
| 
 | ||||
|   use super::{Body, DecodedLength, IncomingLike, Sender, SizeHint}; | ||||
|   use crate::error::RpxyError; | ||||
|   use http_body_util::BodyExt; | ||||
| 
 | ||||
|   #[test] | ||||
|   fn test_size_of() { | ||||
|     // These are mostly to help catch *accidentally* increasing
 | ||||
|     // the size by too much.
 | ||||
| 
 | ||||
|     let body_size = mem::size_of::<IncomingLike>(); | ||||
|     let body_expected_size = mem::size_of::<u64>() * 5; | ||||
|     assert!( | ||||
|       body_size <= body_expected_size, | ||||
|       "Body size = {} <= {}", | ||||
|       body_size, | ||||
|       body_expected_size, | ||||
|     ); | ||||
| 
 | ||||
|     //assert_eq!(body_size, mem::size_of::<Option<Incoming>>(), "Option<Incoming>");
 | ||||
| 
 | ||||
|     assert_eq!(mem::size_of::<Sender>(), mem::size_of::<usize>() * 5, "Sender"); | ||||
| 
 | ||||
|     assert_eq!( | ||||
|       mem::size_of::<Sender>(), | ||||
|       mem::size_of::<Option<Sender>>(), | ||||
|       "Option<Sender>" | ||||
|     ); | ||||
|   } | ||||
|   #[test] | ||||
|   fn size_hint() { | ||||
|     fn eq(body: IncomingLike, b: SizeHint, note: &str) { | ||||
|       let a = body.size_hint(); | ||||
|       assert_eq!(a.lower(), b.lower(), "lower for {:?}", note); | ||||
|       assert_eq!(a.upper(), b.upper(), "upper for {:?}", note); | ||||
|     } | ||||
| 
 | ||||
|     eq(IncomingLike::channel().1, SizeHint::new(), "channel"); | ||||
| 
 | ||||
|     eq( | ||||
|       IncomingLike::new_channel(DecodedLength::new(4), /*wanter =*/ false).1, | ||||
|       SizeHint::with_exact(4), | ||||
|       "channel with length", | ||||
|     ); | ||||
|   } | ||||
| 
 | ||||
|   #[tokio::test] | ||||
|   async fn channel_abort() { | ||||
|     let (tx, mut rx) = IncomingLike::channel(); | ||||
| 
 | ||||
|     tx.abort(); | ||||
| 
 | ||||
|     match rx.frame().await.unwrap() { | ||||
|       Err(RpxyError::HyperNewBodyWriteAborted) => true, | ||||
|       unexpected => panic!("unexpected: {:?}", unexpected), | ||||
|     }; | ||||
|   } | ||||
| 
 | ||||
|   #[tokio::test] | ||||
|   async fn channel_abort_when_buffer_is_full() { | ||||
|     let (mut tx, mut rx) = IncomingLike::channel(); | ||||
| 
 | ||||
|     tx.try_send_data("chunk 1".into()).expect("send 1"); | ||||
|     // buffer is full, but can still send abort
 | ||||
|     tx.abort(); | ||||
| 
 | ||||
|     let chunk1 = rx.frame().await.expect("item 1").expect("chunk 1").into_data().unwrap(); | ||||
|     assert_eq!(chunk1, "chunk 1"); | ||||
| 
 | ||||
|     match rx.frame().await.unwrap() { | ||||
|       Err(RpxyError::HyperNewBodyWriteAborted) => true, | ||||
|       unexpected => panic!("unexpected: {:?}", unexpected), | ||||
|     }; | ||||
|   } | ||||
| 
 | ||||
|   #[test] | ||||
|   fn channel_buffers_one() { | ||||
|     let (mut tx, _rx) = IncomingLike::channel(); | ||||
| 
 | ||||
|     tx.try_send_data("chunk 1".into()).expect("send 1"); | ||||
| 
 | ||||
|     // buffer is now full
 | ||||
|     let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2"); | ||||
|     assert_eq!(chunk2, "chunk 2"); | ||||
|   } | ||||
| 
 | ||||
|   #[tokio::test] | ||||
|   async fn channel_empty() { | ||||
|     let (_, mut rx) = IncomingLike::channel(); | ||||
| 
 | ||||
|     assert!(rx.frame().await.is_none()); | ||||
|   } | ||||
| 
 | ||||
|   #[test] | ||||
|   fn channel_ready() { | ||||
|     let (mut tx, _rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ false); | ||||
| 
 | ||||
|     let mut tx_ready = tokio_test::task::spawn(tx.ready()); | ||||
| 
 | ||||
|     assert!(tx_ready.poll().is_ready(), "tx is ready immediately"); | ||||
|   } | ||||
| 
 | ||||
|   #[test] | ||||
|   fn channel_wanter() { | ||||
|     let (mut tx, mut rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); | ||||
| 
 | ||||
|     let mut tx_ready = tokio_test::task::spawn(tx.ready()); | ||||
|     let mut rx_data = tokio_test::task::spawn(rx.frame()); | ||||
| 
 | ||||
|     assert!(tx_ready.poll().is_pending(), "tx isn't ready before rx has been polled"); | ||||
| 
 | ||||
|     assert!(rx_data.poll().is_pending(), "poll rx.data"); | ||||
|     assert!(tx_ready.is_woken(), "rx poll wakes tx"); | ||||
| 
 | ||||
|     assert!(tx_ready.poll().is_ready(), "tx is ready after rx has been polled"); | ||||
|   } | ||||
| 
 | ||||
|   #[test] | ||||
| 
 | ||||
|   fn channel_notices_closure() { | ||||
|     let (mut tx, rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); | ||||
| 
 | ||||
|     let mut tx_ready = tokio_test::task::spawn(tx.ready()); | ||||
| 
 | ||||
|     assert!(tx_ready.poll().is_pending(), "tx isn't ready before rx has been polled"); | ||||
| 
 | ||||
|     drop(rx); | ||||
|     assert!(tx_ready.is_woken(), "dropping rx wakes tx"); | ||||
| 
 | ||||
|     match tx_ready.poll() { | ||||
|       Poll::Ready(Err(RpxyError::HyperIncomingLikeNewClosed)) => (), | ||||
|       unexpected => panic!("tx poll ready unexpected: {:?}", unexpected), | ||||
|     } | ||||
|   } | ||||
| } | ||||
							
								
								
									
										75
									
								
								rpxy-lib/src/hyper_ext/body_type.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								rpxy-lib/src/hyper_ext/body_type.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,75 @@ | |||
| use super::body::IncomingLike; | ||||
| use crate::error::RpxyError; | ||||
| use futures::channel::mpsc::UnboundedReceiver; | ||||
| use http_body_util::{combinators, BodyExt, Empty, Full, StreamBody}; | ||||
| use hyper::body::{Body, Bytes, Frame, Incoming}; | ||||
| use std::pin::Pin; | ||||
| 
 | ||||
| /// Type for synthetic boxed body
 | ||||
| pub type BoxBody = combinators::BoxBody<Bytes, hyper::Error>; | ||||
| 
 | ||||
| /// helper function to build a empty body
 | ||||
| pub(crate) fn empty() -> BoxBody { | ||||
|   Empty::<Bytes>::new().map_err(|never| match never {}).boxed() | ||||
| } | ||||
| 
 | ||||
| /// helper function to build a full body
 | ||||
| pub(crate) fn full(body: Bytes) -> BoxBody { | ||||
|   Full::new(body).map_err(|never| match never {}).boxed() | ||||
| } | ||||
| 
 | ||||
| #[allow(unused)] | ||||
| /* ------------------------------------ */ | ||||
| /// Request body used in this project
 | ||||
| /// - Incoming: just a type that only forwards the downstream request body to upstream.
 | ||||
| /// - IncomingLike: a Incoming-like type in which channel is used
 | ||||
| pub enum RequestBody { | ||||
|   Incoming(Incoming), | ||||
|   IncomingLike(IncomingLike), | ||||
| } | ||||
| 
 | ||||
| impl Body for RequestBody { | ||||
|   type Data = bytes::Bytes; | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn poll_frame( | ||||
|     self: Pin<&mut Self>, | ||||
|     cx: &mut std::task::Context<'_>, | ||||
|   ) -> std::task::Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { | ||||
|     match self.get_mut() { | ||||
|       RequestBody::Incoming(incoming) => Pin::new(incoming).poll_frame(cx).map_err(RpxyError::HyperBodyError), | ||||
|       RequestBody::IncomingLike(incoming_like) => Pin::new(incoming_like).poll_frame(cx), | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /* ------------------------------------ */ | ||||
| pub type UnboundedStreamBody = StreamBody<UnboundedReceiver<Result<Frame<bytes::Bytes>, hyper::Error>>>; | ||||
| 
 | ||||
| #[allow(unused)] | ||||
| /// Response body use in this project
 | ||||
| /// - Incoming: just a type that only forwards the upstream response body to downstream.
 | ||||
| /// - Boxed: a type that is generated from cache or synthetic response body, e.g.,, small byte object.
 | ||||
| /// - Streamed: another type that is generated from stream, e.g., large byte object.
 | ||||
| pub enum ResponseBody { | ||||
|   Incoming(Incoming), | ||||
|   Boxed(BoxBody), | ||||
|   Streamed(UnboundedStreamBody), | ||||
| } | ||||
| 
 | ||||
| impl Body for ResponseBody { | ||||
|   type Data = bytes::Bytes; | ||||
|   type Error = RpxyError; | ||||
| 
 | ||||
|   fn poll_frame( | ||||
|     self: Pin<&mut Self>, | ||||
|     cx: &mut std::task::Context<'_>, | ||||
|   ) -> std::task::Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { | ||||
|     match self.get_mut() { | ||||
|       ResponseBody::Incoming(incoming) => Pin::new(incoming).poll_frame(cx), | ||||
|       ResponseBody::Boxed(boxed) => Pin::new(boxed).poll_frame(cx), | ||||
|       ResponseBody::Streamed(streamed) => Pin::new(streamed).poll_frame(cx), | ||||
|     } | ||||
|     .map_err(RpxyError::HyperBodyError) | ||||
|   } | ||||
| } | ||||
							
								
								
									
										23
									
								
								rpxy-lib/src/hyper_ext/executor.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								rpxy-lib/src/hyper_ext/executor.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,23 @@ | |||
| use tokio::runtime::Handle; | ||||
| 
 | ||||
| #[derive(Clone)] | ||||
| /// Executor for hyper
 | ||||
| pub struct LocalExecutor { | ||||
|   runtime_handle: Handle, | ||||
| } | ||||
| 
 | ||||
| impl LocalExecutor { | ||||
|   pub fn new(runtime_handle: Handle) -> Self { | ||||
|     LocalExecutor { runtime_handle } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl<F> hyper::rt::Executor<F> for LocalExecutor | ||||
| where | ||||
|   F: std::future::Future + Send + 'static, | ||||
|   F::Output: Send, | ||||
| { | ||||
|   fn execute(&self, fut: F) { | ||||
|     self.runtime_handle.spawn(fut); | ||||
|   } | ||||
| } | ||||
							
								
								
									
										16
									
								
								rpxy-lib/src/hyper_ext/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								rpxy-lib/src/hyper_ext/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,16 @@ | |||
| mod body_incoming_like; | ||||
| mod body_type; | ||||
| mod executor; | ||||
| mod tokio_timer; | ||||
| mod watch; | ||||
| 
 | ||||
| #[allow(unused)] | ||||
| pub(crate) mod rt { | ||||
|   pub(crate) use super::executor::LocalExecutor; | ||||
|   pub(crate) use super::tokio_timer::{TokioSleep, TokioTimer}; | ||||
| } | ||||
| #[allow(unused)] | ||||
| pub(crate) mod body { | ||||
|   pub(crate) use super::body_incoming_like::IncomingLike; | ||||
|   pub(crate) use super::body_type::{empty, full, BoxBody, RequestBody, ResponseBody, UnboundedStreamBody}; | ||||
| } | ||||
							
								
								
									
										55
									
								
								rpxy-lib/src/hyper_ext/tokio_timer.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								rpxy-lib/src/hyper_ext/tokio_timer.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,55 @@ | |||
| use std::{ | ||||
|   future::Future, | ||||
|   pin::Pin, | ||||
|   task::{Context, Poll}, | ||||
|   time::{Duration, Instant}, | ||||
| }; | ||||
| 
 | ||||
| use hyper::rt::{Sleep, Timer}; | ||||
| use pin_project_lite::pin_project; | ||||
| 
 | ||||
| #[derive(Clone, Debug)] | ||||
| pub struct TokioTimer; | ||||
| 
 | ||||
| impl Timer for TokioTimer { | ||||
|   fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> { | ||||
|     Box::pin(TokioSleep { | ||||
|       inner: tokio::time::sleep(duration), | ||||
|     }) | ||||
|   } | ||||
| 
 | ||||
|   fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> { | ||||
|     Box::pin(TokioSleep { | ||||
|       inner: tokio::time::sleep_until(deadline.into()), | ||||
|     }) | ||||
|   } | ||||
| 
 | ||||
|   fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) { | ||||
|     if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() { | ||||
|       sleep.reset(new_deadline) | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| pin_project! { | ||||
|     pub(crate) struct TokioSleep { | ||||
|         #[pin] | ||||
|         pub(crate) inner: tokio::time::Sleep, | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Future for TokioSleep { | ||||
|   type Output = (); | ||||
| 
 | ||||
|   fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||||
|     self.project().inner.poll(cx) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl Sleep for TokioSleep {} | ||||
| 
 | ||||
| impl TokioSleep { | ||||
|   pub fn reset(self: Pin<&mut Self>, deadline: Instant) { | ||||
|     self.project().inner.as_mut().reset(deadline.into()); | ||||
|   } | ||||
| } | ||||
							
								
								
									
										67
									
								
								rpxy-lib/src/hyper_ext/watch.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								rpxy-lib/src/hyper_ext/watch.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,67 @@ | |||
| //! An SPSC broadcast channel.
 | ||||
| //!
 | ||||
| //! - The value can only be a `usize`.
 | ||||
| //! - The consumer is only notified if the value is different.
 | ||||
| //! - The value `0` is reserved for closed.
 | ||||
| // from https://github.com/hyperium/hyper/blob/master/src/common/watch.rs
 | ||||
| 
 | ||||
| use futures_util::task::AtomicWaker; | ||||
| use std::sync::{ | ||||
|   atomic::{AtomicUsize, Ordering}, | ||||
|   Arc, | ||||
| }; | ||||
| use std::task; | ||||
| 
 | ||||
| type Value = usize; | ||||
| 
 | ||||
| pub(super) const CLOSED: usize = 0; | ||||
| 
 | ||||
| pub(super) fn channel(initial: Value) -> (Sender, Receiver) { | ||||
|   debug_assert!(initial != CLOSED, "watch::channel initial state of 0 is reserved"); | ||||
| 
 | ||||
|   let shared = Arc::new(Shared { | ||||
|     value: AtomicUsize::new(initial), | ||||
|     waker: AtomicWaker::new(), | ||||
|   }); | ||||
| 
 | ||||
|   (Sender { shared: shared.clone() }, Receiver { shared }) | ||||
| } | ||||
| 
 | ||||
| pub(super) struct Sender { | ||||
|   shared: Arc<Shared>, | ||||
| } | ||||
| 
 | ||||
| pub(super) struct Receiver { | ||||
|   shared: Arc<Shared>, | ||||
| } | ||||
| 
 | ||||
| struct Shared { | ||||
|   value: AtomicUsize, | ||||
|   waker: AtomicWaker, | ||||
| } | ||||
| 
 | ||||
| impl Sender { | ||||
|   pub(super) fn send(&mut self, value: Value) { | ||||
|     if self.shared.value.swap(value, Ordering::SeqCst) != value { | ||||
|       self.shared.waker.wake(); | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl Drop for Sender { | ||||
|   fn drop(&mut self) { | ||||
|     self.send(CLOSED); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl Receiver { | ||||
|   pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value { | ||||
|     self.shared.waker.register(cx.waker()); | ||||
|     self.shared.value.load(Ordering::SeqCst) | ||||
|   } | ||||
| 
 | ||||
|   #[allow(dead_code)] | ||||
|   pub(crate) fn peek(&self) -> Value { | ||||
|     self.shared.value.load(Ordering::Relaxed) | ||||
|   } | ||||
| } | ||||
|  | @ -1,26 +1,25 @@ | |||
| mod backend; | ||||
| mod certs; | ||||
| mod constants; | ||||
| mod count; | ||||
| mod crypto; | ||||
| mod error; | ||||
| mod forwarder; | ||||
| mod globals; | ||||
| mod handler; | ||||
| mod hyper_ext; | ||||
| mod log; | ||||
| mod message_handler; | ||||
| mod name_exp; | ||||
| mod proxy; | ||||
| mod utils; | ||||
| 
 | ||||
| use crate::{ | ||||
|   error::*, | ||||
|   globals::Globals, | ||||
|   handler::{Forwarder, HttpMessageHandlerBuilder}, | ||||
|   log::*, | ||||
|   proxy::ProxyBuilder, | ||||
|   crypto::build_cert_reloader, error::*, forwarder::Forwarder, globals::Globals, log::*, | ||||
|   message_handler::HttpMessageHandlerBuilder, proxy::Proxy, | ||||
| }; | ||||
| use futures::future::select_all; | ||||
| // use hyper_trust_dns::TrustDnsResolver;
 | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| pub use crate::{ | ||||
|   certs::{CertsAndKeys, CryptoSource}, | ||||
|   crypto::{CertsAndKeys, CryptoSource}, | ||||
|   globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri}, | ||||
| }; | ||||
| pub mod reexports { | ||||
|  | @ -28,19 +27,22 @@ pub mod reexports { | |||
|   pub use rustls::{Certificate, PrivateKey}; | ||||
| } | ||||
| 
 | ||||
| #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
| compile_error!("feature \"http3-quinn\" and feature \"http3-s2n\" cannot be enabled at the same time"); | ||||
| 
 | ||||
| /// Entrypoint that creates and spawns tasks of reverse proxy services
 | ||||
| pub async fn entrypoint<T>( | ||||
|   proxy_config: &ProxyConfig, | ||||
|   app_config_list: &AppConfigList<T>, | ||||
|   runtime_handle: &tokio::runtime::Handle, | ||||
|   term_notify: Option<Arc<tokio::sync::Notify>>, | ||||
| ) -> Result<()> | ||||
| ) -> RpxyResult<()> | ||||
| where | ||||
|   T: CryptoSource + Clone + Send + Sync + 'static, | ||||
| { | ||||
|   #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|   warn!("Both \"http3-quinn\" and \"http3-s2n\" features are enabled. \"http3-quinn\" will be used"); | ||||
| 
 | ||||
|   #[cfg(all(feature = "native-tls-backend", feature = "rustls-backend"))] | ||||
|   warn!("Both \"native-tls-backend\" and \"rustls-backend\" features are enabled. \"rustls-backend\" will be used"); | ||||
| 
 | ||||
|   // For initial message logging
 | ||||
|   if proxy_config.listen_sockets.iter().any(|addr| addr.is_ipv6()) { | ||||
|     info!("Listen both IPv4 and IPv6") | ||||
|  | @ -70,44 +72,76 @@ where | |||
|     info!("Cache is disabled") | ||||
|   } | ||||
| 
 | ||||
|   // build global
 | ||||
|   // 1. build backends, and make it contained in Arc
 | ||||
|   let app_manager = Arc::new(backend::BackendAppManager::try_from(app_config_list)?); | ||||
| 
 | ||||
|   // 2. build crypto reloader service
 | ||||
|   let (cert_reloader_service, cert_reloader_rx) = match proxy_config.https_port { | ||||
|     Some(_) => { | ||||
|       let (s, r) = build_cert_reloader(&app_manager).await?; | ||||
|       (Some(s), Some(r)) | ||||
|     } | ||||
|     None => (None, None), | ||||
|   }; | ||||
| 
 | ||||
|   // 3. build global shared context
 | ||||
|   let globals = Arc::new(Globals { | ||||
|     proxy_config: proxy_config.clone(), | ||||
|     backends: app_config_list.clone().try_into()?, | ||||
|     request_count: Default::default(), | ||||
|     runtime_handle: runtime_handle.clone(), | ||||
|     term_notify: term_notify.clone(), | ||||
|     cert_reloader_rx: cert_reloader_rx.clone(), | ||||
|   }); | ||||
| 
 | ||||
|   // build message handler including a request forwarder
 | ||||
|   let msg_handler = Arc::new( | ||||
|   // 4. build message handler containing Arc-ed http_client and backends, and make it contained in Arc as well
 | ||||
|   let forwarder = Arc::new(Forwarder::try_new(&globals).await?); | ||||
|   let message_handler = Arc::new( | ||||
|     HttpMessageHandlerBuilder::default() | ||||
|       .forwarder(Arc::new(Forwarder::new(&globals).await)) | ||||
|       .globals(globals.clone()) | ||||
|       .app_manager(app_manager.clone()) | ||||
|       .forwarder(forwarder) | ||||
|       .build()?, | ||||
|   ); | ||||
| 
 | ||||
|   // 5. spawn each proxy for a given socket with copied Arc-ed message_handler.
 | ||||
|   // build hyper connection builder shared with proxy instances
 | ||||
|   let connection_builder = proxy::connection_builder(&globals); | ||||
| 
 | ||||
|   // spawn each proxy for a given socket with copied Arc-ed backend, message_handler and connection builder.
 | ||||
|   let addresses = globals.proxy_config.listen_sockets.clone(); | ||||
|   let futures = select_all(addresses.into_iter().map(|addr| { | ||||
|   let futures_iter = addresses.into_iter().map(|listening_on| { | ||||
|     let mut tls_enabled = false; | ||||
|     if let Some(https_port) = globals.proxy_config.https_port { | ||||
|       tls_enabled = https_port == addr.port() | ||||
|       tls_enabled = https_port == listening_on.port() | ||||
|     } | ||||
| 
 | ||||
|     let proxy = ProxyBuilder::default() | ||||
|       .globals(globals.clone()) | ||||
|       .listening_on(addr) | ||||
|       .tls_enabled(tls_enabled) | ||||
|       .msg_handler(msg_handler.clone()) | ||||
|       .build() | ||||
|       .unwrap(); | ||||
| 
 | ||||
|     globals.runtime_handle.spawn(proxy.start(term_notify.clone())) | ||||
|   })); | ||||
|     let proxy = Proxy { | ||||
|       globals: globals.clone(), | ||||
|       listening_on, | ||||
|       tls_enabled, | ||||
|       connection_builder: connection_builder.clone(), | ||||
|       message_handler: message_handler.clone(), | ||||
|     }; | ||||
|     globals.runtime_handle.spawn(async move { proxy.start().await }) | ||||
|   }); | ||||
| 
 | ||||
|   // wait for all future
 | ||||
|   if let (Ok(Err(e)), _, _) = futures.await { | ||||
|     error!("Some proxy services are down: {:?}", e); | ||||
|   }; | ||||
|   match cert_reloader_service { | ||||
|     Some(cert_service) => { | ||||
|       tokio::select! { | ||||
|         _ = cert_service.start() => { | ||||
|           error!("Certificate reloader service got down"); | ||||
|         } | ||||
|         _ = select_all(futures_iter) => { | ||||
|           error!("Some proxy services are down"); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     None => { | ||||
|       if let (Ok(Err(e)), _, _) = select_all(futures_iter).await { | ||||
|         error!("Some proxy services are down: {}", e); | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   Ok(()) | ||||
| } | ||||
|  |  | |||
|  | @ -1,98 +1 @@ | |||
| use crate::utils::ToCanonical; | ||||
| use hyper::header; | ||||
| use std::net::SocketAddr; | ||||
| pub use tracing::{debug, error, info, warn}; | ||||
| 
 | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct MessageLog { | ||||
|   // pub tls_server_name: String,
 | ||||
|   pub client_addr: String, | ||||
|   pub method: String, | ||||
|   pub host: String, | ||||
|   pub p_and_q: String, | ||||
|   pub version: hyper::Version, | ||||
|   pub uri_scheme: String, | ||||
|   pub uri_host: String, | ||||
|   pub ua: String, | ||||
|   pub xff: String, | ||||
|   pub status: String, | ||||
|   pub upstream: String, | ||||
| } | ||||
| 
 | ||||
| impl<T> From<&hyper::Request<T>> for MessageLog { | ||||
|   fn from(req: &hyper::Request<T>) -> Self { | ||||
|     let header_mapper = |v: header::HeaderName| { | ||||
|       req | ||||
|         .headers() | ||||
|         .get(v) | ||||
|         .map_or_else(|| "", |s| s.to_str().unwrap_or("")) | ||||
|         .to_string() | ||||
|     }; | ||||
|     Self { | ||||
|       // tls_server_name: "".to_string(),
 | ||||
|       client_addr: "".to_string(), | ||||
|       method: req.method().to_string(), | ||||
|       host: header_mapper(header::HOST), | ||||
|       p_and_q: req | ||||
|         .uri() | ||||
|         .path_and_query() | ||||
|         .map_or_else(|| "", |v| v.as_str()) | ||||
|         .to_string(), | ||||
|       version: req.version(), | ||||
|       uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(), | ||||
|       uri_host: req.uri().host().unwrap_or("").to_string(), | ||||
|       ua: header_mapper(header::USER_AGENT), | ||||
|       xff: header_mapper(header::HeaderName::from_static("x-forwarded-for")), | ||||
|       status: "".to_string(), | ||||
|       upstream: "".to_string(), | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl MessageLog { | ||||
|   pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self { | ||||
|     self.client_addr = client_addr.to_canonical().to_string(); | ||||
|     self | ||||
|   } | ||||
|   // pub fn tls_server_name(&mut self, tls_server_name: &str) -> &mut Self {
 | ||||
|   //   self.tls_server_name = tls_server_name.to_string();
 | ||||
|   //   self
 | ||||
|   // }
 | ||||
|   pub fn status_code(&mut self, status_code: &hyper::StatusCode) -> &mut Self { | ||||
|     self.status = status_code.to_string(); | ||||
|     self | ||||
|   } | ||||
|   pub fn xff(&mut self, xff: &Option<&header::HeaderValue>) -> &mut Self { | ||||
|     self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); | ||||
|     self | ||||
|   } | ||||
|   pub fn upstream(&mut self, upstream: &hyper::Uri) -> &mut Self { | ||||
|     self.upstream = upstream.to_string(); | ||||
|     self | ||||
|   } | ||||
| 
 | ||||
|   pub fn output(&self) { | ||||
|     info!( | ||||
|       "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"", | ||||
|       if !self.host.is_empty() { | ||||
|         self.host.as_str() | ||||
|       } else { | ||||
|         self.uri_host.as_str() | ||||
|       }, | ||||
|       self.client_addr, | ||||
|       self.method, | ||||
|       self.p_and_q, | ||||
|       self.version, | ||||
|       self.status, | ||||
|       if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() { | ||||
|         format!("{}://{}", self.uri_scheme, self.uri_host) | ||||
|       } else { | ||||
|         "".to_string() | ||||
|       }, | ||||
|       self.ua, | ||||
|       self.xff, | ||||
|       self.upstream, | ||||
|       // self.tls_server_name
 | ||||
|     ); | ||||
|   } | ||||
| } | ||||
|  |  | |||
							
								
								
									
										61
									
								
								rpxy-lib/src/message_handler/canonical_address.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								rpxy-lib/src/message_handler/canonical_address.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,61 @@ | |||
| use std::net::{IpAddr, Ipv4Addr, SocketAddr}; | ||||
| 
 | ||||
| /// Trait to convert an IP address to its canonical form
 | ||||
| pub trait ToCanonical { | ||||
|   fn to_canonical(&self) -> Self; | ||||
| } | ||||
| 
 | ||||
| impl ToCanonical for SocketAddr { | ||||
|   fn to_canonical(&self) -> Self { | ||||
|     match self { | ||||
|       SocketAddr::V4(_) => *self, | ||||
|       SocketAddr::V6(v6) => match v6.ip().to_ipv4() { | ||||
|         Some(mapped) => { | ||||
|           if mapped == Ipv4Addr::new(0, 0, 0, 1) { | ||||
|             *self | ||||
|           } else { | ||||
|             SocketAddr::new(IpAddr::V4(mapped), self.port()) | ||||
|           } | ||||
|         } | ||||
|         None => *self, | ||||
|       }, | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|   use super::*; | ||||
|   use std::net::Ipv6Addr; | ||||
|   #[test] | ||||
|   fn ipv4_loopback_to_canonical() { | ||||
|     let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080); | ||||
|     assert_eq!(socket.to_canonical(), socket); | ||||
|   } | ||||
|   #[test] | ||||
|   fn ipv6_loopback_to_canonical() { | ||||
|     let socket = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 8080); | ||||
|     assert_eq!(socket.to_canonical(), socket); | ||||
|   } | ||||
|   #[test] | ||||
|   fn ipv4_to_canonical() { | ||||
|     let socket = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1)), 8080); | ||||
|     assert_eq!(socket.to_canonical(), socket); | ||||
|   } | ||||
|   #[test] | ||||
|   fn ipv6_to_canonical() { | ||||
|     let socket = SocketAddr::new( | ||||
|       IpAddr::V6(Ipv6Addr::new(0x2001, 0x0db8, 0, 0, 0, 0, 0xdead, 0xbeef)), | ||||
|       8080, | ||||
|     ); | ||||
|     assert_eq!(socket.to_canonical(), socket); | ||||
|   } | ||||
|   #[test] | ||||
|   fn ipv4_mapped_to_ipv6_to_canonical() { | ||||
|     let socket = SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0xffff, 0xc00a, 0x2ff)), 8080); | ||||
|     assert_eq!( | ||||
|       socket.to_canonical(), | ||||
|       SocketAddr::new(IpAddr::V4(Ipv4Addr::new(192, 10, 2, 255)), 8080) | ||||
|     ); | ||||
|   } | ||||
| } | ||||
							
								
								
									
										248
									
								
								rpxy-lib/src/message_handler/handler_main.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										248
									
								
								rpxy-lib/src/message_handler/handler_main.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,248 @@ | |||
| use super::{ | ||||
|   http_log::HttpMessageLog, | ||||
|   http_result::{HttpError, HttpResult}, | ||||
|   synthetic_response::{secure_redirection_response, synthetic_error_response}, | ||||
|   utils_headers::*, | ||||
|   utils_request::InspectParseHost, | ||||
| }; | ||||
| use crate::{ | ||||
|   backend::{BackendAppManager, LoadBalanceContext}, | ||||
|   crypto::CryptoSource, | ||||
|   error::*, | ||||
|   forwarder::{ForwardRequest, Forwarder}, | ||||
|   globals::Globals, | ||||
|   hyper_ext::body::{RequestBody, ResponseBody}, | ||||
|   log::*, | ||||
|   name_exp::ServerName, | ||||
| }; | ||||
| use derive_builder::Builder; | ||||
| use http::{Request, Response, StatusCode}; | ||||
| use hyper_util::{client::legacy::connect::Connect, rt::TokioIo}; | ||||
| use std::{net::SocketAddr, sync::Arc}; | ||||
| use tokio::io::copy_bidirectional; | ||||
| 
 | ||||
| #[allow(dead_code)] | ||||
| #[derive(Debug)] | ||||
| /// Context object to handle sticky cookies at HTTP message handler
 | ||||
| pub(super) struct HandlerContext { | ||||
|   #[cfg(feature = "sticky-cookie")] | ||||
|   pub(super) context_lb: Option<LoadBalanceContext>, | ||||
|   #[cfg(not(feature = "sticky-cookie"))] | ||||
|   pub(super) context_lb: Option<()>, | ||||
| } | ||||
| 
 | ||||
| #[derive(Clone, Builder)] | ||||
| /// HTTP message handler for requests from clients and responses from backend applications,
 | ||||
| /// responsible to manipulate and forward messages to upstream backends and downstream clients.
 | ||||
| pub struct HttpMessageHandler<U, C> | ||||
| where | ||||
|   C: Send + Sync + Connect + Clone + 'static, | ||||
|   U: CryptoSource + Clone, | ||||
| { | ||||
|   forwarder: Arc<Forwarder<C>>, | ||||
|   pub(super) globals: Arc<Globals>, | ||||
|   app_manager: Arc<BackendAppManager<U>>, | ||||
| } | ||||
| 
 | ||||
| impl<U, C> HttpMessageHandler<U, C> | ||||
| where | ||||
|   C: Send + Sync + Connect + Clone + 'static, | ||||
|   U: CryptoSource + Clone, | ||||
| { | ||||
|   /// Handle incoming request message from a client.
 | ||||
|   /// Responsible to passthrough responses from backend applications or generate synthetic error responses.
 | ||||
|   pub async fn handle_request( | ||||
|     &self, | ||||
|     req: Request<RequestBody>, | ||||
|     client_addr: SocketAddr, // For access control
 | ||||
|     listen_addr: SocketAddr, | ||||
|     tls_enabled: bool, | ||||
|     tls_server_name: Option<ServerName>, | ||||
|   ) -> RpxyResult<Response<ResponseBody>> { | ||||
|     // preparing log data
 | ||||
|     let mut log_data = HttpMessageLog::from(&req); | ||||
|     log_data.client_addr(&client_addr); | ||||
| 
 | ||||
|     let http_result = self | ||||
|       .handle_request_inner( | ||||
|         &mut log_data, | ||||
|         req, | ||||
|         client_addr, | ||||
|         listen_addr, | ||||
|         tls_enabled, | ||||
|         tls_server_name, | ||||
|       ) | ||||
|       .await; | ||||
| 
 | ||||
|     // passthrough or synthetic response
 | ||||
|     match http_result { | ||||
|       Ok(v) => { | ||||
|         log_data.status_code(&v.status()).output(); | ||||
|         Ok(v) | ||||
|       } | ||||
|       Err(e) => { | ||||
|         error!("{e}"); | ||||
|         let code = StatusCode::from(e); | ||||
|         log_data.status_code(&code).output(); | ||||
|         synthetic_error_response(code) | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   /// Handle inner with no synthetic error response.
 | ||||
|   /// Synthetic response is generated by caller.
 | ||||
|   async fn handle_request_inner( | ||||
|     &self, | ||||
|     log_data: &mut HttpMessageLog, | ||||
|     mut req: Request<RequestBody>, | ||||
|     client_addr: SocketAddr, // For access control
 | ||||
|     listen_addr: SocketAddr, | ||||
|     tls_enabled: bool, | ||||
|     tls_server_name: Option<ServerName>, | ||||
|   ) -> HttpResult<Response<ResponseBody>> { | ||||
|     // Here we start to inspect and parse with server_name
 | ||||
|     let server_name = req | ||||
|       .inspect_parse_host() | ||||
|       .map(|v| ServerName::from(v.as_slice())) | ||||
|       .map_err(|_e| HttpError::InvalidHostInRequestHeader)?; | ||||
| 
 | ||||
|     // check consistency of between TLS SNI and HOST/Request URI Line.
 | ||||
|     #[allow(clippy::collapsible_if)] | ||||
|     if tls_enabled && self.globals.proxy_config.sni_consistency { | ||||
|       if server_name != tls_server_name.unwrap_or_default() { | ||||
|         return Err(HttpError::SniHostInconsistency); | ||||
|       } | ||||
|     } | ||||
|     // Find backend application for given server_name, and drop if incoming request is invalid as request.
 | ||||
|     let backend_app = match self.app_manager.apps.get(&server_name) { | ||||
|       Some(backend_app) => backend_app, | ||||
|       None => { | ||||
|         let Some(default_server_name) = &self.app_manager.default_server_name else { | ||||
|           return Err(HttpError::NoMatchingBackendApp); | ||||
|         }; | ||||
|         debug!("Serving by default app"); | ||||
|         self.app_manager.apps.get(default_server_name).unwrap() | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     // Redirect to https if !tls_enabled and redirect_to_https is true
 | ||||
|     if !tls_enabled && backend_app.https_redirection.unwrap_or(false) { | ||||
|       debug!( | ||||
|         "Redirect to secure connection: {}", | ||||
|         <&ServerName as TryInto<String>>::try_into(&backend_app.server_name).unwrap_or_default() | ||||
|       ); | ||||
|       return secure_redirection_response(&backend_app.server_name, self.globals.proxy_config.https_port, &req); | ||||
|     } | ||||
| 
 | ||||
|     // Find reverse proxy for given path and choose one of upstream host
 | ||||
|     // Longest prefix match
 | ||||
|     let path = req.uri().path(); | ||||
|     let Some(upstream_candidates) = backend_app.path_manager.get(path) else { | ||||
|       return Err(HttpError::NoUpstreamCandidates); | ||||
|     }; | ||||
| 
 | ||||
|     // Upgrade in request header
 | ||||
|     let upgrade_in_request = extract_upgrade(req.headers()); | ||||
|     if upgrade_in_request.is_some() && req.version() != http::Version::HTTP_11 { | ||||
|       return Err(HttpError::FailedToUpgrade(format!( | ||||
|         "Unsupported HTTP version: {:?}", | ||||
|         req.version() | ||||
|       ))); | ||||
|     } | ||||
|     // let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();
 | ||||
|     let req_on_upgrade = hyper::upgrade::on(&mut req); | ||||
| 
 | ||||
|     // Build request from destination information
 | ||||
|     let _context = match self.generate_request_forwarded( | ||||
|       &client_addr, | ||||
|       &listen_addr, | ||||
|       &mut req, | ||||
|       &upgrade_in_request, | ||||
|       upstream_candidates, | ||||
|       tls_enabled, | ||||
|     ) { | ||||
|       Err(e) => { | ||||
|         return Err(HttpError::FailedToGenerateUpstreamRequest(e.to_string())); | ||||
|       } | ||||
|       Ok(v) => v, | ||||
|     }; | ||||
|     debug!( | ||||
|       "Request to be forwarded: [uri {}, method: {}, version {:?}, headers {:?}]", | ||||
|       req.uri(), | ||||
|       req.method(), | ||||
|       req.version(), | ||||
|       req.headers() | ||||
|     ); | ||||
|     log_data.xff(&req.headers().get("x-forwarded-for")); | ||||
|     log_data.upstream(req.uri()); | ||||
|     //////
 | ||||
| 
 | ||||
|     //////////////
 | ||||
|     // Forward request to a chosen backend
 | ||||
|     let mut res_backend = match self.forwarder.request(req).await { | ||||
|       Ok(v) => v, | ||||
|       Err(e) => { | ||||
|         return Err(HttpError::FailedToGetResponseFromBackend(e.to_string())); | ||||
|       } | ||||
|     }; | ||||
|     //////////////
 | ||||
|     // Process reverse proxy context generated during the forwarding request generation.
 | ||||
|     #[cfg(feature = "sticky-cookie")] | ||||
|     if let Some(context_from_lb) = _context.context_lb { | ||||
|       let res_headers = res_backend.headers_mut(); | ||||
|       if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { | ||||
|         return Err(HttpError::FailedToAddSetCookeInResponse(e.to_string())); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { | ||||
|       // Generate response to client
 | ||||
|       if let Err(e) = self.generate_response_forwarded(&mut res_backend, backend_app) { | ||||
|         return Err(HttpError::FailedToGenerateDownstreamResponse(e.to_string())); | ||||
|       } | ||||
|       return Ok(res_backend); | ||||
|     } | ||||
| 
 | ||||
|     // Handle StatusCode::SWITCHING_PROTOCOLS in response
 | ||||
|     let upgrade_in_response = extract_upgrade(res_backend.headers()); | ||||
|     let should_upgrade = match (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) { | ||||
|       (Some(u_req), Some(u_res)) => u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase(), | ||||
|       _ => false, | ||||
|     }; | ||||
| 
 | ||||
|     if !should_upgrade { | ||||
|       return Err(HttpError::FailedToUpgrade(format!( | ||||
|         "Backend tried to switch to protocol {:?} when {:?} was requested", | ||||
|         upgrade_in_response, upgrade_in_request | ||||
|       ))); | ||||
|     } | ||||
|     // let Some(request_upgraded) = request_upgraded else {
 | ||||
|     //   return Err(HttpError::NoUpgradeExtensionInRequest);
 | ||||
|     // };
 | ||||
| 
 | ||||
|     // let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else {
 | ||||
|     //   return Err(HttpError::NoUpgradeExtensionInResponse);
 | ||||
|     // };
 | ||||
|     let res_on_upgrade = hyper::upgrade::on(&mut res_backend); | ||||
| 
 | ||||
|     self.globals.runtime_handle.spawn(async move { | ||||
|       let mut response_upgraded = TokioIo::new(res_on_upgrade.await.map_err(|e| { | ||||
|         error!("Failed to upgrade response: {}", e); | ||||
|         RpxyError::FailedToUpgradeResponse(e.to_string()) | ||||
|       })?); | ||||
|       let mut request_upgraded = TokioIo::new(req_on_upgrade.await.map_err(|e| { | ||||
|         error!("Failed to upgrade request: {}", e); | ||||
|         RpxyError::FailedToUpgradeRequest(e.to_string()) | ||||
|       })?); | ||||
|       copy_bidirectional(&mut response_upgraded, &mut request_upgraded) | ||||
|         .await | ||||
|         .map_err(|e| { | ||||
|           error!("Coping between upgraded connections failed: {}", e); | ||||
|           RpxyError::FailedToCopyBidirectional(e.to_string()) | ||||
|         })?; | ||||
|       Ok(()) as RpxyResult<()> | ||||
|     }); | ||||
| 
 | ||||
|     Ok(res_backend) | ||||
|   } | ||||
| } | ||||
							
								
								
									
										188
									
								
								rpxy-lib/src/message_handler/handler_manipulate_messages.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										188
									
								
								rpxy-lib/src/message_handler/handler_manipulate_messages.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,188 @@ | |||
| use super::{handler_main::HandlerContext, utils_headers::*, utils_request::update_request_line, HttpMessageHandler}; | ||||
| use crate::{ | ||||
|   backend::{BackendApp, UpstreamCandidates}, | ||||
|   constants::RESPONSE_HEADER_SERVER, | ||||
|   log::*, | ||||
|   CryptoSource, | ||||
| }; | ||||
| use anyhow::{anyhow, ensure, Result}; | ||||
| use http::{header, HeaderValue, Request, Response, Uri}; | ||||
| use hyper_util::client::legacy::connect::Connect; | ||||
| use std::net::SocketAddr; | ||||
| 
 | ||||
| impl<U, C> HttpMessageHandler<U, C> | ||||
| where | ||||
|   C: Send + Sync + Connect + Clone + 'static, | ||||
|   U: CryptoSource + Clone, | ||||
| { | ||||
|   ////////////////////////////////////////////////////
 | ||||
|   // Functions to generate messages
 | ||||
|   ////////////////////////////////////////////////////
 | ||||
| 
 | ||||
|   #[allow(unused_variables)] | ||||
|   /// Manipulate a response message sent from a backend application to forward downstream to a client.
 | ||||
|   pub(super) fn generate_response_forwarded<B>( | ||||
|     &self, | ||||
|     response: &mut Response<B>, | ||||
|     backend_app: &BackendApp<U>, | ||||
|   ) -> Result<()> { | ||||
|     let headers = response.headers_mut(); | ||||
|     remove_connection_header(headers); | ||||
|     remove_hop_header(headers); | ||||
|     add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; | ||||
| 
 | ||||
|     #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|     { | ||||
|       // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled
 | ||||
|       // TODO: This is a workaround for avoiding a client authentication in HTTP/3
 | ||||
|       if self.globals.proxy_config.http3 && backend_app.crypto_source.as_ref().is_some_and(|v| !v.is_mutual_tls()) { | ||||
|         if let Some(port) = self.globals.proxy_config.https_port { | ||||
|           add_header_entry_overwrite_if_exist( | ||||
|             headers, | ||||
|             header::ALT_SVC.as_str(), | ||||
|             format!( | ||||
|               "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", | ||||
|               port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age | ||||
|             ), | ||||
|           )?; | ||||
|         } | ||||
|       } else { | ||||
|         // remove alt-svc to disallow requests via http3
 | ||||
|         headers.remove(header::ALT_SVC.as_str()); | ||||
|       } | ||||
|     } | ||||
|     #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] | ||||
|     { | ||||
|       if self.globals.proxy_config.https_port.is_some() { | ||||
|         headers.remove(header::ALT_SVC.as_str()); | ||||
|       } | ||||
|     } | ||||
| 
 | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   #[allow(clippy::too_many_arguments)] | ||||
|   /// Manipulate a request message sent from a client to forward upstream to a backend application
 | ||||
|   pub(super) fn generate_request_forwarded<B>( | ||||
|     &self, | ||||
|     client_addr: &SocketAddr, | ||||
|     listen_addr: &SocketAddr, | ||||
|     req: &mut Request<B>, | ||||
|     upgrade: &Option<String>, | ||||
|     upstream_candidates: &UpstreamCandidates, | ||||
|     tls_enabled: bool, | ||||
|   ) -> Result<HandlerContext> { | ||||
|     debug!("Generate request to be forwarded"); | ||||
| 
 | ||||
|     // Add te: trailer if contained in original request
 | ||||
|     let contains_te_trailers = { | ||||
|       if let Some(te) = req.headers().get(header::TE) { | ||||
|         te.as_bytes() | ||||
|           .split(|v| v == &b',' || v == &b' ') | ||||
|           .any(|x| x == "trailers".as_bytes()) | ||||
|       } else { | ||||
|         false | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     let original_uri = req.uri().to_string(); | ||||
|     let headers = req.headers_mut(); | ||||
|     // delete headers specified in header.connection
 | ||||
|     remove_connection_header(headers); | ||||
|     // delete hop headers including header.connection
 | ||||
|     remove_hop_header(headers); | ||||
|     // X-Forwarded-For
 | ||||
|     add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &original_uri)?; | ||||
| 
 | ||||
|     // Add te: trailer if te_trailer
 | ||||
|     if contains_te_trailers { | ||||
|       headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap()); | ||||
|     } | ||||
| 
 | ||||
|     // add "host" header of original server_name if not exist (default)
 | ||||
|     if req.headers().get(header::HOST).is_none() { | ||||
|       let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); | ||||
|       req | ||||
|         .headers_mut() | ||||
|         .insert(header::HOST, HeaderValue::from_str(&org_host)?); | ||||
|     }; | ||||
|     let original_host_header = req.headers().get(header::HOST).unwrap().clone(); | ||||
| 
 | ||||
|     /////////////////////////////////////////////
 | ||||
|     // Fix unique upstream destination since there could be multiple ones.
 | ||||
|     #[cfg(feature = "sticky-cookie")] | ||||
|     let (upstream_chosen_opt, context_from_lb) = { | ||||
|       let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_candidates.load_balance { | ||||
|         takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? | ||||
|       } else { | ||||
|         None | ||||
|       }; | ||||
|       upstream_candidates.get(&context_to_lb) | ||||
|     }; | ||||
|     #[cfg(not(feature = "sticky-cookie"))] | ||||
|     let (upstream_chosen_opt, _) = upstream_candidates.get(&None); | ||||
| 
 | ||||
|     let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; | ||||
|     let context = HandlerContext { | ||||
|       #[cfg(feature = "sticky-cookie")] | ||||
|       context_lb: context_from_lb, | ||||
|       #[cfg(not(feature = "sticky-cookie"))] | ||||
|       context_lb: None, | ||||
|     }; | ||||
|     /////////////////////////////////////////////
 | ||||
| 
 | ||||
|     // apply upstream-specific headers given in upstream_option
 | ||||
|     let headers = req.headers_mut(); | ||||
|     // by default, host header is overwritten with upstream hostname
 | ||||
|     override_host_header(headers, &upstream_chosen.uri)?; | ||||
|     // apply upstream options to header
 | ||||
|     apply_upstream_options_to_header(headers, &original_host_header, upstream_candidates)?; | ||||
| 
 | ||||
|     // update uri in request
 | ||||
|     ensure!( | ||||
|       upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some(), | ||||
|       "Upstream uri `scheme` and `authority` is broken" | ||||
|     ); | ||||
| 
 | ||||
|     let new_uri = Uri::builder() | ||||
|       .scheme(upstream_chosen.uri.scheme().unwrap().as_str()) | ||||
|       .authority(upstream_chosen.uri.authority().unwrap().as_str()); | ||||
|     let org_pq = match req.uri().path_and_query() { | ||||
|       Some(pq) => pq.to_string(), | ||||
|       None => "/".to_string(), | ||||
|     } | ||||
|     .into_bytes(); | ||||
| 
 | ||||
|     // replace some parts of path if opt_replace_path is enabled for chosen upstream
 | ||||
|     let new_pq = match &upstream_candidates.replace_path { | ||||
|       Some(new_path) => { | ||||
|         let matched_path: &[u8] = upstream_candidates.path.as_ref(); | ||||
|         ensure!( | ||||
|           !matched_path.is_empty() && org_pq.len() >= matched_path.len(), | ||||
|           "Upstream uri `path and query` is broken" | ||||
|         ); | ||||
|         let mut new_pq = Vec::<u8>::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); | ||||
|         new_pq.extend_from_slice(new_path.as_ref()); | ||||
|         new_pq.extend_from_slice(&org_pq[matched_path.len()..]); | ||||
|         new_pq | ||||
|       } | ||||
|       None => org_pq, | ||||
|     }; | ||||
|     *req.uri_mut() = new_uri.path_and_query(new_pq).build()?; | ||||
| 
 | ||||
|     // upgrade
 | ||||
|     if let Some(v) = upgrade { | ||||
|       req.headers_mut().insert(header::UPGRADE, v.parse()?); | ||||
|       req | ||||
|         .headers_mut() | ||||
|         .insert(header::CONNECTION, HeaderValue::from_static("upgrade")); | ||||
|     } | ||||
|     if upgrade.is_none() { | ||||
|       // can update request line i.e., http version, only if not upgrade (http 1.1)
 | ||||
|       update_request_line(req, upstream_chosen, upstream_candidates)?; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     Ok(context) | ||||
|   } | ||||
| } | ||||
							
								
								
									
										99
									
								
								rpxy-lib/src/message_handler/http_log.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								rpxy-lib/src/message_handler/http_log.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,99 @@ | |||
| use super::canonical_address::ToCanonical; | ||||
| use crate::log::*; | ||||
| use http::header; | ||||
| use std::net::SocketAddr; | ||||
| 
 | ||||
| /// Struct to log HTTP messages
 | ||||
| #[derive(Debug, Clone)] | ||||
| pub struct HttpMessageLog { | ||||
|   // pub tls_server_name: String,
 | ||||
|   pub client_addr: String, | ||||
|   pub method: String, | ||||
|   pub host: String, | ||||
|   pub p_and_q: String, | ||||
|   pub version: http::Version, | ||||
|   pub uri_scheme: String, | ||||
|   pub uri_host: String, | ||||
|   pub ua: String, | ||||
|   pub xff: String, | ||||
|   pub status: String, | ||||
|   pub upstream: String, | ||||
| } | ||||
| 
 | ||||
| impl<T> From<&http::Request<T>> for HttpMessageLog { | ||||
|   fn from(req: &http::Request<T>) -> Self { | ||||
|     let header_mapper = |v: header::HeaderName| { | ||||
|       req | ||||
|         .headers() | ||||
|         .get(v) | ||||
|         .map_or_else(|| "", |s| s.to_str().unwrap_or("")) | ||||
|         .to_string() | ||||
|     }; | ||||
|     Self { | ||||
|       // tls_server_name: "".to_string(),
 | ||||
|       client_addr: "".to_string(), | ||||
|       method: req.method().to_string(), | ||||
|       host: header_mapper(header::HOST), | ||||
|       p_and_q: req | ||||
|         .uri() | ||||
|         .path_and_query() | ||||
|         .map_or_else(|| "", |v| v.as_str()) | ||||
|         .to_string(), | ||||
|       version: req.version(), | ||||
|       uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(), | ||||
|       uri_host: req.uri().host().unwrap_or("").to_string(), | ||||
|       ua: header_mapper(header::USER_AGENT), | ||||
|       xff: header_mapper(header::HeaderName::from_static("x-forwarded-for")), | ||||
|       status: "".to_string(), | ||||
|       upstream: "".to_string(), | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl HttpMessageLog { | ||||
|   pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self { | ||||
|     self.client_addr = client_addr.to_canonical().to_string(); | ||||
|     self | ||||
|   } | ||||
|   // pub fn tls_server_name(&mut self, tls_server_name: &str) -> &mut Self {
 | ||||
|   //   self.tls_server_name = tls_server_name.to_string();
 | ||||
|   //   self
 | ||||
|   // }
 | ||||
|   pub fn status_code(&mut self, status_code: &http::StatusCode) -> &mut Self { | ||||
|     self.status = status_code.to_string(); | ||||
|     self | ||||
|   } | ||||
|   pub fn xff(&mut self, xff: &Option<&header::HeaderValue>) -> &mut Self { | ||||
|     self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); | ||||
|     self | ||||
|   } | ||||
|   pub fn upstream(&mut self, upstream: &http::Uri) -> &mut Self { | ||||
|     self.upstream = upstream.to_string(); | ||||
|     self | ||||
|   } | ||||
| 
 | ||||
|   pub fn output(&self) { | ||||
|     info!( | ||||
|       "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"", | ||||
|       if !self.host.is_empty() { | ||||
|         self.host.as_str() | ||||
|       } else { | ||||
|         self.uri_host.as_str() | ||||
|       }, | ||||
|       self.client_addr, | ||||
|       self.method, | ||||
|       self.p_and_q, | ||||
|       self.version, | ||||
|       self.status, | ||||
|       if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() { | ||||
|         format!("{}://{}", self.uri_scheme, self.uri_host) | ||||
|       } else { | ||||
|         "".to_string() | ||||
|       }, | ||||
|       self.ua, | ||||
|       self.xff, | ||||
|       self.upstream, | ||||
|       // self.tls_server_name
 | ||||
|     ); | ||||
|   } | ||||
| } | ||||
							
								
								
									
										61
									
								
								rpxy-lib/src/message_handler/http_result.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								rpxy-lib/src/message_handler/http_result.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,61 @@ | |||
| use http::StatusCode; | ||||
| use thiserror::Error; | ||||
| 
 | ||||
| /// HTTP result type, T is typically a hyper::Response
 | ||||
| /// HttpError is used to generate a synthetic error response
 | ||||
| pub(crate) type HttpResult<T> = std::result::Result<T, HttpError>; | ||||
| 
 | ||||
| /// Describes things that can go wrong in the forwarder
 | ||||
| #[derive(Debug, Error)] | ||||
| pub enum HttpError { | ||||
|   // #[error("No host is give in request header")]
 | ||||
|   // NoHostInRequestHeader,
 | ||||
|   #[error("Invalid host in request header")] | ||||
|   InvalidHostInRequestHeader, | ||||
|   #[error("SNI and Host header mismatch")] | ||||
|   SniHostInconsistency, | ||||
|   #[error("No matching backend app")] | ||||
|   NoMatchingBackendApp, | ||||
|   #[error("Failed to redirect: {0}")] | ||||
|   FailedToRedirect(String), | ||||
|   #[error("No upstream candidates")] | ||||
|   NoUpstreamCandidates, | ||||
|   #[error("Failed to generate upstream request for backend application: {0}")] | ||||
|   FailedToGenerateUpstreamRequest(String), | ||||
|   #[error("Failed to get response from backend: {0}")] | ||||
|   FailedToGetResponseFromBackend(String), | ||||
| 
 | ||||
|   #[error("Failed to add set-cookie header in response {0}")] | ||||
|   FailedToAddSetCookeInResponse(String), | ||||
|   #[error("Failed to generated downstream response for clients: {0}")] | ||||
|   FailedToGenerateDownstreamResponse(String), | ||||
| 
 | ||||
|   #[error("Failed to upgrade connection: {0}")] | ||||
|   FailedToUpgrade(String), | ||||
|   // #[error("Request does not have an upgrade extension")]
 | ||||
|   // NoUpgradeExtensionInRequest,
 | ||||
|   // #[error("Response does not have an upgrade extension")]
 | ||||
|   // NoUpgradeExtensionInResponse,
 | ||||
|   #[error(transparent)] | ||||
|   Other(#[from] anyhow::Error), | ||||
| } | ||||
| 
 | ||||
| impl From<HttpError> for StatusCode { | ||||
|   fn from(e: HttpError) -> StatusCode { | ||||
|     match e { | ||||
|       // HttpError::NoHostInRequestHeader => StatusCode::BAD_REQUEST,
 | ||||
|       HttpError::InvalidHostInRequestHeader => StatusCode::BAD_REQUEST, | ||||
|       HttpError::SniHostInconsistency => StatusCode::MISDIRECTED_REQUEST, | ||||
|       HttpError::NoMatchingBackendApp => StatusCode::SERVICE_UNAVAILABLE, | ||||
|       HttpError::FailedToRedirect(_) => StatusCode::INTERNAL_SERVER_ERROR, | ||||
|       HttpError::NoUpstreamCandidates => StatusCode::NOT_FOUND, | ||||
|       HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR, | ||||
|       HttpError::FailedToAddSetCookeInResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, | ||||
|       HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, | ||||
|       HttpError::FailedToUpgrade(_) => StatusCode::INTERNAL_SERVER_ERROR, | ||||
|       // HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST,
 | ||||
|       // HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY,
 | ||||
|       _ => StatusCode::INTERNAL_SERVER_ERROR, | ||||
|     } | ||||
|   } | ||||
| } | ||||
							
								
								
									
										11
									
								
								rpxy-lib/src/message_handler/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								rpxy-lib/src/message_handler/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,11 @@ | |||
| mod canonical_address; | ||||
| mod handler_main; | ||||
| mod handler_manipulate_messages; | ||||
| mod http_log; | ||||
| mod http_result; | ||||
| mod synthetic_response; | ||||
| mod utils_headers; | ||||
| mod utils_request; | ||||
| 
 | ||||
| pub use handler_main::HttpMessageHandlerBuilderError; | ||||
| pub(crate) use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder}; | ||||
							
								
								
									
										42
									
								
								rpxy-lib/src/message_handler/synthetic_response.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								rpxy-lib/src/message_handler/synthetic_response.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,42 @@ | |||
| use super::http_result::{HttpError, HttpResult}; | ||||
| use crate::{ | ||||
|   error::*, | ||||
|   hyper_ext::body::{empty, ResponseBody}, | ||||
|   name_exp::ServerName, | ||||
| }; | ||||
| use http::{Request, Response, StatusCode, Uri}; | ||||
| 
 | ||||
| /// build http response with status code of 4xx and 5xx
 | ||||
| pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult<Response<ResponseBody>> { | ||||
|   let res = Response::builder() | ||||
|     .status(status_code) | ||||
|     .body(ResponseBody::Boxed(empty())) | ||||
|     .unwrap(); | ||||
|   Ok(res) | ||||
| } | ||||
| 
 | ||||
| /// Generate synthetic response message of a redirection to https host with 301
 | ||||
| pub(super) fn secure_redirection_response<B>( | ||||
|   server_name: &ServerName, | ||||
|   tls_port: Option<u16>, | ||||
|   req: &Request<B>, | ||||
| ) -> HttpResult<Response<ResponseBody>> { | ||||
|   let server_name: String = server_name.try_into().unwrap_or_default(); | ||||
|   let pq = match req.uri().path_and_query() { | ||||
|     Some(x) => x.as_str(), | ||||
|     _ => "", | ||||
|   }; | ||||
|   let new_uri = Uri::builder().scheme("https").path_and_query(pq); | ||||
|   let dest_uri = match tls_port { | ||||
|     Some(443) | None => new_uri.authority(server_name), | ||||
|     Some(p) => new_uri.authority(format!("{server_name}:{p}")), | ||||
|   } | ||||
|   .build() | ||||
|   .map_err(|e| HttpError::FailedToRedirect(e.to_string()))?; | ||||
|   let response = Response::builder() | ||||
|     .status(StatusCode::MOVED_PERMANENTLY) | ||||
|     .header("Location", dest_uri.to_string()) | ||||
|     .body(ResponseBody::Boxed(empty())) | ||||
|     .map_err(|e| HttpError::FailedToRedirect(e.to_string()))?; | ||||
|   Ok(response) | ||||
| } | ||||
							
								
								
									
										293
									
								
								rpxy-lib/src/message_handler/utils_headers.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										293
									
								
								rpxy-lib/src/message_handler/utils_headers.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,293 @@ | |||
| use super::canonical_address::ToCanonical; | ||||
| use crate::{ | ||||
|   backend::{UpstreamCandidates, UpstreamOption}, | ||||
|   log::*, | ||||
| }; | ||||
| use anyhow::{anyhow, ensure, Result}; | ||||
| use bytes::BufMut; | ||||
| use http::{header, HeaderMap, HeaderName, HeaderValue, Uri}; | ||||
| use std::{borrow::Cow, net::SocketAddr}; | ||||
| 
 | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| use crate::backend::{LoadBalanceContext, StickyCookie, StickyCookieValue}; | ||||
| // use crate::backend::{UpstreamGroup, UpstreamOption};
 | ||||
| 
 | ||||
| // ////////////////////////////////////////////////////
 | ||||
| // // Functions to manipulate headers
 | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| /// Take sticky cookie header value from request header,
 | ||||
| /// and returns LoadBalanceContext to be forwarded to LB if exist and if needed.
 | ||||
| /// Removing sticky cookie is needed and it must not be passed to the upstream.
 | ||||
| pub(super) fn takeout_sticky_cookie_lb_context( | ||||
|   headers: &mut HeaderMap, | ||||
|   expected_cookie_name: &str, | ||||
| ) -> Result<Option<LoadBalanceContext>> { | ||||
|   let mut headers_clone = headers.clone(); | ||||
| 
 | ||||
|   match headers_clone.entry(header::COOKIE) { | ||||
|     header::Entry::Vacant(_) => Ok(None), | ||||
|     header::Entry::Occupied(entry) => { | ||||
|       let cookies_iter = entry | ||||
|         .iter() | ||||
|         .flat_map(|v| v.to_str().unwrap_or("").split(';').map(|v| v.trim())); | ||||
|       let (sticky_cookies, without_sticky_cookies): (Vec<_>, Vec<_>) = cookies_iter | ||||
|         .into_iter() | ||||
|         .partition(|v| v.starts_with(expected_cookie_name)); | ||||
|       if sticky_cookies.is_empty() { | ||||
|         return Ok(None); | ||||
|       } | ||||
|       ensure!( | ||||
|         sticky_cookies.len() == 1, | ||||
|         "Invalid cookie: Multiple sticky cookie values" | ||||
|       ); | ||||
| 
 | ||||
|       let cookies_passed_to_upstream = without_sticky_cookies.join("; "); | ||||
|       let cookie_passed_to_lb = sticky_cookies.first().unwrap(); | ||||
|       headers.remove(header::COOKIE); | ||||
|       headers.insert(header::COOKIE, cookies_passed_to_upstream.parse()?); | ||||
| 
 | ||||
|       let sticky_cookie = StickyCookie { | ||||
|         value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?, | ||||
|         info: None, | ||||
|       }; | ||||
|       Ok(Some(LoadBalanceContext { sticky_cookie })) | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(feature = "sticky-cookie")] | ||||
| /// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated.
 | ||||
| /// Set-Cookie response header could be in multiple lines.
 | ||||
| /// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie
 | ||||
| pub(super) fn set_sticky_cookie_lb_context( | ||||
|   headers: &mut HeaderMap, | ||||
|   context_from_lb: &LoadBalanceContext, | ||||
| ) -> Result<()> { | ||||
|   let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?; | ||||
|   let new_header_val: HeaderValue = sticky_cookie_string.parse()?; | ||||
|   let expected_cookie_name = &context_from_lb.sticky_cookie.value.name; | ||||
|   match headers.entry(header::SET_COOKIE) { | ||||
|     header::Entry::Vacant(entry) => { | ||||
|       entry.insert(new_header_val); | ||||
|     } | ||||
|     header::Entry::Occupied(mut entry) => { | ||||
|       let mut flag = false; | ||||
|       for e in entry.iter_mut() { | ||||
|         if e.to_str().unwrap_or("").starts_with(expected_cookie_name) { | ||||
|           *e = new_header_val.clone(); | ||||
|           flag = true; | ||||
|         } | ||||
|       } | ||||
|       if !flag { | ||||
|         entry.append(new_header_val); | ||||
|       } | ||||
|     } | ||||
|   }; | ||||
|   Ok(()) | ||||
| } | ||||
| 
 | ||||
| /// default: overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy)
 | ||||
| pub(super) fn override_host_header(headers: &mut HeaderMap, upstream_base_uri: &Uri) -> Result<()> { | ||||
|   let mut upstream_host = upstream_base_uri | ||||
|     .host() | ||||
|     .ok_or_else(|| anyhow!("No hostname is given"))? | ||||
|     .to_string(); | ||||
|   // add port if it is not default
 | ||||
|   if let Some(port) = upstream_base_uri.port_u16() { | ||||
|     upstream_host = format!("{}:{}", upstream_host, port); | ||||
|   } | ||||
| 
 | ||||
|   // overwrite host header, this removes all the HOST header values
 | ||||
|   headers.insert(header::HOST, HeaderValue::from_str(&upstream_host)?); | ||||
|   Ok(()) | ||||
| } | ||||
| 
 | ||||
| /// Apply options to request header, which are specified in the configuration
 | ||||
| pub(super) fn apply_upstream_options_to_header( | ||||
|   headers: &mut HeaderMap, | ||||
|   original_host_header: &HeaderValue, | ||||
|   // _client_addr: &SocketAddr,
 | ||||
|   upstream: &UpstreamCandidates, | ||||
|   // _upstream_base_uri: &Uri,
 | ||||
| ) -> Result<()> { | ||||
|   for opt in upstream.options.iter() { | ||||
|     match opt { | ||||
|       UpstreamOption::KeepOriginalHost => { | ||||
|         // revert hostname
 | ||||
|         headers | ||||
|           .insert(header::HOST, original_host_header.to_owned()) | ||||
|           .ok_or_else(|| anyhow!("Failed to revert host header in keep_original_host option"))?; | ||||
|       } | ||||
|       UpstreamOption::UpgradeInsecureRequests => { | ||||
|         // add upgrade-insecure-requests in request header if not exist
 | ||||
|         headers | ||||
|           .entry(header::UPGRADE_INSECURE_REQUESTS) | ||||
|           .or_insert(HeaderValue::from_bytes(&[b'1']).unwrap()); | ||||
|       } | ||||
|       _ => (), | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   Ok(()) | ||||
| } | ||||
| 
 | ||||
| /// Append header entry with comma according to [RFC9110](https://datatracker.ietf.org/doc/html/rfc9110)
 | ||||
| pub(super) fn append_header_entry_with_comma(headers: &mut HeaderMap, key: &str, value: &str) -> Result<()> { | ||||
|   match headers.entry(HeaderName::from_bytes(key.as_bytes())?) { | ||||
|     header::Entry::Vacant(entry) => { | ||||
|       entry.insert(value.parse::<HeaderValue>()?); | ||||
|     } | ||||
|     header::Entry::Occupied(mut entry) => { | ||||
|       // entry.append(value.parse::<HeaderValue>()?);
 | ||||
|       let mut new_value = Vec::<u8>::with_capacity(entry.get().as_bytes().len() + 2 + value.len()); | ||||
|       new_value.put_slice(entry.get().as_bytes()); | ||||
|       new_value.put_slice(&[b',', b' ']); | ||||
|       new_value.put_slice(value.as_bytes()); | ||||
|       entry.insert(HeaderValue::from_bytes(&new_value)?); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   Ok(()) | ||||
| } | ||||
| 
 | ||||
| /// Add header entry if not exist
 | ||||
| pub(super) fn add_header_entry_if_not_exist( | ||||
|   headers: &mut HeaderMap, | ||||
|   key: impl Into<Cow<'static, str>>, | ||||
|   value: impl Into<Cow<'static, str>>, | ||||
| ) -> Result<()> { | ||||
|   match headers.entry(HeaderName::from_bytes(key.into().as_bytes())?) { | ||||
|     header::Entry::Vacant(entry) => { | ||||
|       entry.insert(value.into().parse::<HeaderValue>()?); | ||||
|     } | ||||
|     header::Entry::Occupied(_) => (), | ||||
|   }; | ||||
| 
 | ||||
|   Ok(()) | ||||
| } | ||||
| 
 | ||||
| /// Overwrite header entry if exist
 | ||||
| pub(super) fn add_header_entry_overwrite_if_exist( | ||||
|   headers: &mut HeaderMap, | ||||
|   key: impl Into<Cow<'static, str>>, | ||||
|   value: impl Into<Cow<'static, str>>, | ||||
| ) -> Result<()> { | ||||
|   match headers.entry(HeaderName::from_bytes(key.into().as_bytes())?) { | ||||
|     header::Entry::Vacant(entry) => { | ||||
|       entry.insert(value.into().parse::<HeaderValue>()?); | ||||
|     } | ||||
|     header::Entry::Occupied(mut entry) => { | ||||
|       entry.insert(HeaderValue::from_bytes(value.into().as_bytes())?); | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   Ok(()) | ||||
| } | ||||
| 
 | ||||
| /// Align cookie values in single line
 | ||||
| /// Sometimes violates [RFC6265](https://www.rfc-editor.org/rfc/rfc6265#section-5.4) (for http/1.1).
 | ||||
| /// This is allowed in RFC7540 (for http/2) as mentioned [here](https://stackoverflow.com/questions/4843556/in-http-specification-what-is-the-string-that-separates-cookies).
 | ||||
| pub(super) fn make_cookie_single_line(headers: &mut HeaderMap) -> Result<()> { | ||||
|   let cookies = headers | ||||
|     .iter() | ||||
|     .filter(|(k, _)| **k == header::COOKIE) | ||||
|     .map(|(_, v)| v.to_str().unwrap_or("")) | ||||
|     .collect::<Vec<_>>() | ||||
|     .join("; "); | ||||
|   if !cookies.is_empty() { | ||||
|     headers.remove(header::COOKIE); | ||||
|     headers.insert(header::COOKIE, HeaderValue::from_bytes(cookies.as_bytes())?); | ||||
|   } | ||||
|   Ok(()) | ||||
| } | ||||
| 
 | ||||
| /// Add forwarding headers like `x-forwarded-for`.
 | ||||
| pub(super) fn add_forwarding_header( | ||||
|   headers: &mut HeaderMap, | ||||
|   client_addr: &SocketAddr, | ||||
|   listen_addr: &SocketAddr, | ||||
|   tls: bool, | ||||
|   uri_str: &str, | ||||
| ) -> Result<()> { | ||||
|   // default process
 | ||||
|   // optional process defined by upstream_option is applied in fn apply_upstream_options
 | ||||
|   let canonical_client_addr = client_addr.to_canonical().ip().to_string(); | ||||
|   append_header_entry_with_comma(headers, "x-forwarded-for", &canonical_client_addr)?; | ||||
| 
 | ||||
|   // Single line cookie header
 | ||||
|   // TODO: This should be only for HTTP/1.1. For 2+, this can be multi-lined.
 | ||||
|   make_cookie_single_line(headers)?; | ||||
| 
 | ||||
|   /////////// As Nginx
 | ||||
|   // If we receive X-Forwarded-Proto, pass it through; otherwise, pass along the
 | ||||
|   // scheme used to connect to this server
 | ||||
|   add_header_entry_if_not_exist(headers, "x-forwarded-proto", if tls { "https" } else { "http" })?; | ||||
|   // If we receive X-Forwarded-Port, pass it through; otherwise, pass along the
 | ||||
|   // server port the client connected to
 | ||||
|   add_header_entry_if_not_exist(headers, "x-forwarded-port", listen_addr.port().to_string())?; | ||||
| 
 | ||||
|   /////////// As Nginx-Proxy
 | ||||
|   // x-real-ip
 | ||||
|   add_header_entry_overwrite_if_exist(headers, "x-real-ip", canonical_client_addr)?; | ||||
|   // x-forwarded-ssl
 | ||||
|   add_header_entry_overwrite_if_exist(headers, "x-forwarded-ssl", if tls { "on" } else { "off" })?; | ||||
|   // x-original-uri
 | ||||
|   add_header_entry_overwrite_if_exist(headers, "x-original-uri", uri_str.to_string())?; | ||||
|   // proxy
 | ||||
|   add_header_entry_overwrite_if_exist(headers, "proxy", "")?; | ||||
| 
 | ||||
|   Ok(()) | ||||
| } | ||||
| 
 | ||||
| /// Remove connection header
 | ||||
| pub(super) fn remove_connection_header(headers: &mut HeaderMap) { | ||||
|   if let Some(values) = headers.get(header::CONNECTION) { | ||||
|     if let Ok(v) = values.clone().to_str() { | ||||
|       for m in v.split(',') { | ||||
|         if !m.is_empty() { | ||||
|           headers.remove(m.trim()); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// Hop header values which are removed at proxy
 | ||||
| const HOP_HEADERS: &[&str] = &[ | ||||
|   "connection", | ||||
|   "te", | ||||
|   "trailer", | ||||
|   "keep-alive", | ||||
|   "proxy-connection", | ||||
|   "proxy-authenticate", | ||||
|   "proxy-authorization", | ||||
|   "transfer-encoding", | ||||
|   "upgrade", | ||||
| ]; | ||||
| 
 | ||||
| /// Remove hop headers
 | ||||
| pub(super) fn remove_hop_header(headers: &mut HeaderMap) { | ||||
|   HOP_HEADERS.iter().for_each(|key| { | ||||
|     headers.remove(*key); | ||||
|   }); | ||||
| } | ||||
| 
 | ||||
| /// Extract upgrade header value if exist
 | ||||
| pub(super) fn extract_upgrade(headers: &HeaderMap) -> Option<String> { | ||||
|   if let Some(c) = headers.get(header::CONNECTION) { | ||||
|     if c | ||||
|       .to_str() | ||||
|       .unwrap_or("") | ||||
|       .split(',') | ||||
|       .any(|w| w.trim().to_ascii_lowercase() == header::UPGRADE.as_str().to_ascii_lowercase()) | ||||
|     { | ||||
|       if let Some(u) = headers.get(header::UPGRADE) { | ||||
|         if let Ok(m) = u.to_str() { | ||||
|           debug!("Upgrade in request header: {}", m); | ||||
|           return Some(m.to_owned()); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|   } | ||||
|   None | ||||
| } | ||||
							
								
								
									
										86
									
								
								rpxy-lib/src/message_handler/utils_request.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								rpxy-lib/src/message_handler/utils_request.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,86 @@ | |||
| use crate::{ | ||||
|   backend::{Upstream, UpstreamCandidates, UpstreamOption}, | ||||
|   log::*, | ||||
| }; | ||||
| use anyhow::{anyhow, ensure, Result}; | ||||
| use http::{header, uri::Scheme, Request, Version}; | ||||
| 
 | ||||
| /// Trait defining parser of hostname
 | ||||
| /// Inspect and extract hostname from either the request HOST header or request line
 | ||||
| pub trait InspectParseHost { | ||||
|   type Error; | ||||
|   fn inspect_parse_host(&self) -> Result<Vec<u8>, Self::Error>; | ||||
| } | ||||
| impl<B> InspectParseHost for Request<B> { | ||||
|   type Error = anyhow::Error; | ||||
|   /// Inspect and extract hostname from either the request HOST header or request line
 | ||||
|   fn inspect_parse_host(&self) -> Result<Vec<u8>> { | ||||
|     let drop_port = |v: &[u8]| { | ||||
|       if v.starts_with(&[b'[']) { | ||||
|         // v6 address with bracket case. if port is specified, always it is in this case.
 | ||||
|         let mut iter = v.split(|ptr| ptr == &b'[' || ptr == &b']'); | ||||
|         iter.next().ok_or(anyhow!("Invalid Host header"))?; // first item is always blank
 | ||||
|         iter.next().ok_or(anyhow!("Invalid Host header")).map(|b| b.to_owned()) | ||||
|       } else if v.len() - v.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 { | ||||
|         // v6 address case, if 2 or more ':' is contained
 | ||||
|         Ok(v.to_owned()) | ||||
|       } else { | ||||
|         // v4 address or hostname
 | ||||
|         v.split(|colon| colon == &b':') | ||||
|           .next() | ||||
|           .ok_or(anyhow!("Invalid Host header")) | ||||
|           .map(|v| v.to_ascii_lowercase()) | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     let headers_host = self.headers().get(header::HOST).map(|v| drop_port(v.as_bytes())); | ||||
|     let uri_host = self.uri().host().map(|v| drop_port(v.as_bytes())); | ||||
|     // let uri_port = self.uri().port_u16();
 | ||||
| 
 | ||||
|     // prioritize server_name in uri
 | ||||
|     match (headers_host, uri_host) { | ||||
|       (Some(Ok(hh)), Some(Ok(hu))) => { | ||||
|         ensure!(hh == hu, "Host header and uri host mismatch"); | ||||
|         Ok(hh) | ||||
|       } | ||||
|       (Some(Ok(hh)), None) => Ok(hh), | ||||
|       (None, Some(Ok(hu))) => Ok(hu), | ||||
|       _ => Err(anyhow!("Neither Host header nor uri host is valid")), | ||||
|     } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////
 | ||||
| // Functions to manipulate request line
 | ||||
| 
 | ||||
| /// Update request line, e.g., version, and apply upstream options to request line, specified in the configuration
 | ||||
| pub(super) fn update_request_line<B>( | ||||
|   req: &mut Request<B>, | ||||
|   upstream_chosen: &Upstream, | ||||
|   upstream_candidates: &UpstreamCandidates, | ||||
| ) -> anyhow::Result<()> { | ||||
|   // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3
 | ||||
|   if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { | ||||
|     // Change version to http/1.1 when destination scheme is http
 | ||||
|     debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); | ||||
|     *req.version_mut() = Version::HTTP_11; | ||||
|   } else if req.version() == Version::HTTP_3 { | ||||
|     // HTTP/3 is always https
 | ||||
|     debug!("HTTP/3 is currently unsupported for request to upstream."); | ||||
|     *req.version_mut() = Version::HTTP_2; | ||||
|   } | ||||
| 
 | ||||
|   for opt in upstream_candidates.options.iter() { | ||||
|     match opt { | ||||
|       UpstreamOption::ForceHttp11Upstream => *req.version_mut() = Version::HTTP_11, | ||||
|       UpstreamOption::ForceHttp2Upstream => { | ||||
|         // case: h2c -> https://www.rfc-editor.org/rfc/rfc9113.txt
 | ||||
|         // Upgrade from HTTP/1.1 to HTTP/2 is deprecated. So, http-2 prior knowledge is required.
 | ||||
|         *req.version_mut() = Version::HTTP_2; | ||||
|       } | ||||
|       _ => (), | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|   Ok(()) | ||||
| } | ||||
							
								
								
									
										160
									
								
								rpxy-lib/src/name_exp.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										160
									
								
								rpxy-lib/src/name_exp.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,160 @@ | |||
| use std::borrow::Cow; | ||||
| 
 | ||||
| /// Server name (hostname or ip address) representation in bytes-based struct
 | ||||
| /// for searching hashmap or key list by exact or longest-prefix matching
 | ||||
| #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] | ||||
| pub struct ServerName { | ||||
|   inner: Vec<u8>, // lowercase ascii bytes
 | ||||
| } | ||||
| impl From<&str> for ServerName { | ||||
|   fn from(s: &str) -> Self { | ||||
|     let name = s.bytes().collect::<Vec<u8>>().to_ascii_lowercase(); | ||||
|     Self { inner: name } | ||||
|   } | ||||
| } | ||||
| impl From<&[u8]> for ServerName { | ||||
|   fn from(b: &[u8]) -> Self { | ||||
|     Self { | ||||
|       inner: b.to_ascii_lowercase(), | ||||
|     } | ||||
|   } | ||||
| } | ||||
| impl TryInto<String> for &ServerName { | ||||
|   type Error = anyhow::Error; | ||||
|   fn try_into(self) -> Result<String, Self::Error> { | ||||
|     let s = std::str::from_utf8(&self.inner)?; | ||||
|     Ok(s.to_string()) | ||||
|   } | ||||
| } | ||||
| impl AsRef<[u8]> for ServerName { | ||||
|   fn as_ref(&self) -> &[u8] { | ||||
|     self.inner.as_ref() | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// Path name, like "/path/ok", represented in bytes-based struct
 | ||||
| /// for searching hashmap or key list by exact or longest-prefix matching
 | ||||
| #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] | ||||
| pub struct PathName { | ||||
|   inner: Vec<u8>, // lowercase ascii bytes
 | ||||
| } | ||||
| impl From<&str> for PathName { | ||||
|   fn from(s: &str) -> Self { | ||||
|     let name = s.bytes().collect::<Vec<u8>>().to_ascii_lowercase(); | ||||
|     Self { inner: name } | ||||
|   } | ||||
| } | ||||
| impl From<&[u8]> for PathName { | ||||
|   fn from(b: &[u8]) -> Self { | ||||
|     Self { | ||||
|       inner: b.to_ascii_lowercase(), | ||||
|     } | ||||
|   } | ||||
| } | ||||
| impl TryInto<String> for &PathName { | ||||
|   type Error = anyhow::Error; | ||||
|   fn try_into(self) -> Result<String, Self::Error> { | ||||
|     let s = std::str::from_utf8(&self.inner)?; | ||||
|     Ok(s.to_string()) | ||||
|   } | ||||
| } | ||||
| impl AsRef<[u8]> for PathName { | ||||
|   fn as_ref(&self) -> &[u8] { | ||||
|     self.inner.as_ref() | ||||
|   } | ||||
| } | ||||
| impl PathName { | ||||
|   pub fn len(&self) -> usize { | ||||
|     self.inner.len() | ||||
|   } | ||||
|   pub fn is_empty(&self) -> bool { | ||||
|     self.inner.len() == 0 | ||||
|   } | ||||
|   pub fn get<I>(&self, index: I) -> Option<&I::Output> | ||||
|   where | ||||
|     I: std::slice::SliceIndex<[u8]>, | ||||
|   { | ||||
|     self.inner.get(index) | ||||
|   } | ||||
|   pub fn starts_with(&self, needle: &Self) -> bool { | ||||
|     self.inner.starts_with(&needle.inner) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| /// Trait to express names in ascii-lowercased bytes
 | ||||
| pub trait ByteName { | ||||
|   type OutputServer: Send + Sync + 'static; | ||||
|   type OutputPath; | ||||
|   fn to_server_name(self) -> Self::OutputServer; | ||||
|   fn to_path_name(self) -> Self::OutputPath; | ||||
| } | ||||
| 
 | ||||
| impl<'a, T: Into<Cow<'a, str>>> ByteName for T { | ||||
|   type OutputServer = ServerName; | ||||
|   type OutputPath = PathName; | ||||
| 
 | ||||
|   fn to_server_name(self) -> Self::OutputServer { | ||||
|     ServerName::from(self.into().as_ref()) | ||||
|   } | ||||
| 
 | ||||
|   fn to_path_name(self) -> Self::OutputPath { | ||||
|     PathName::from(self.into().as_ref()) | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[cfg(test)] | ||||
| mod tests { | ||||
|   use super::*; | ||||
|   #[test] | ||||
|   fn bytes_name_str_works() { | ||||
|     let s = "OK_string"; | ||||
|     let bn = s.to_path_name(); | ||||
|     let bn_lc = s.to_server_name(); | ||||
| 
 | ||||
|     assert_eq!("ok_string".as_bytes(), bn.as_ref()); | ||||
|     assert_eq!("ok_string".as_bytes(), bn_lc.as_ref()); | ||||
|   } | ||||
| 
 | ||||
|   #[test] | ||||
|   fn from_works() { | ||||
|     let s = "OK_string".to_server_name(); | ||||
|     let m = ServerName::from("OK_strinG".as_bytes()); | ||||
|     assert_eq!(s, m); | ||||
|     assert_eq!(s.as_ref(), "ok_string".as_bytes()); | ||||
|     assert_eq!(m.as_ref(), "ok_string".as_bytes()); | ||||
|   } | ||||
| 
 | ||||
|   #[test] | ||||
|   fn get_works() { | ||||
|     let s = "OK_str".to_path_name(); | ||||
|     let i = s.get(0); | ||||
|     assert_eq!(Some(&"o".as_bytes()[0]), i); | ||||
|     let i = s.get(1); | ||||
|     assert_eq!(Some(&"k".as_bytes()[0]), i); | ||||
|     let i = s.get(2); | ||||
|     assert_eq!(Some(&"_".as_bytes()[0]), i); | ||||
|     let i = s.get(3); | ||||
|     assert_eq!(Some(&"s".as_bytes()[0]), i); | ||||
|     let i = s.get(4); | ||||
|     assert_eq!(Some(&"t".as_bytes()[0]), i); | ||||
|     let i = s.get(5); | ||||
|     assert_eq!(Some(&"r".as_bytes()[0]), i); | ||||
|     let i = s.get(6); | ||||
|     assert_eq!(None, i); | ||||
|   } | ||||
| 
 | ||||
|   #[test] | ||||
|   fn start_with_works() { | ||||
|     let s = "OK_str".to_path_name(); | ||||
|     let correct = "OK".to_path_name(); | ||||
|     let incorrect = "KO".to_path_name(); | ||||
|     assert!(s.starts_with(&correct)); | ||||
|     assert!(!s.starts_with(&incorrect)); | ||||
|   } | ||||
| 
 | ||||
|   #[test] | ||||
|   fn as_ref_works() { | ||||
|     let s = "OK_str".to_path_name(); | ||||
|     assert_eq!(s.as_ref(), "ok_str".as_bytes()); | ||||
|   } | ||||
| } | ||||
|  | @ -1,13 +1,36 @@ | |||
| mod crypto_service; | ||||
| mod proxy_client_cert; | ||||
| #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
| mod proxy_h3; | ||||
| mod proxy_main; | ||||
| #[cfg(feature = "http3-quinn")] | ||||
| mod proxy_quic_quinn; | ||||
| #[cfg(feature = "http3-s2n")] | ||||
| mod proxy_quic_s2n; | ||||
| mod proxy_tls; | ||||
| mod socket; | ||||
| 
 | ||||
| pub use proxy_main::{Proxy, ProxyBuilder, ProxyBuilderError}; | ||||
| #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
| mod proxy_h3; | ||||
| #[cfg(feature = "http3-quinn")] | ||||
| mod proxy_quic_quinn; | ||||
| #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||
| mod proxy_quic_s2n; | ||||
| 
 | ||||
| use crate::{ | ||||
|   globals::Globals, | ||||
|   hyper_ext::rt::{LocalExecutor, TokioTimer}, | ||||
| }; | ||||
| use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| pub(crate) use proxy_main::Proxy; | ||||
| 
 | ||||
| /// build connection builder shared with proxy instances
 | ||||
| pub(crate) fn connection_builder(globals: &Arc<Globals>) -> Arc<ConnectionBuilder<LocalExecutor>> { | ||||
|   let executor = LocalExecutor::new(globals.runtime_handle.clone()); | ||||
|   let mut http_server = server::conn::auto::Builder::new(executor); | ||||
|   http_server | ||||
|     .http1() | ||||
|     .keep_alive(globals.proxy_config.keepalive) | ||||
|     .header_read_timeout(globals.proxy_config.proxy_idle_timeout) | ||||
|     .timer(TokioTimer) | ||||
|     .pipeline_flush(true); | ||||
|   http_server | ||||
|     .http2() | ||||
|     .keep_alive_interval(Some(globals.proxy_config.proxy_idle_timeout)) | ||||
|     .timer(TokioTimer) | ||||
|     .max_concurrent_streams(globals.proxy_config.max_concurrent_streams); | ||||
|   Arc::new(http_server) | ||||
| } | ||||
|  |  | |||
|  | @ -1,25 +1,34 @@ | |||
| use super::Proxy; | ||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; | ||||
| use super::proxy_main::Proxy; | ||||
| use crate::{ | ||||
|   constants::CONNECTION_TIMEOUT_SEC, | ||||
|   crypto::CryptoSource, | ||||
|   error::*, | ||||
|   hyper_ext::body::{IncomingLike, RequestBody}, | ||||
|   log::*, | ||||
|   name_exp::ServerName, | ||||
| }; | ||||
| use bytes::{Buf, Bytes}; | ||||
| use http::{Request, Response}; | ||||
| use http_body_util::BodyExt; | ||||
| use hyper_util::client::legacy::connect::Connect; | ||||
| use std::{net::SocketAddr, time::Duration}; | ||||
| 
 | ||||
| #[cfg(feature = "http3-quinn")] | ||||
| use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; | ||||
| use hyper::{client::connect::Connect, Body, Request, Response}; | ||||
| #[cfg(feature = "http3-s2n")] | ||||
| #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||
| use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; | ||||
| use std::net::SocketAddr; | ||||
| use tokio::time::{timeout, Duration}; | ||||
| 
 | ||||
| impl<T, U> Proxy<T, U> | ||||
| impl<U, T> Proxy<U, T> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub(super) async fn connection_serve_h3<C>( | ||||
|   pub(super) async fn h3_serve_connection<C>( | ||||
|     &self, | ||||
|     quic_connection: C, | ||||
|     tls_server_name: ServerNameBytesExp, | ||||
|     tls_server_name: ServerName, | ||||
|     client_addr: SocketAddr, | ||||
|   ) -> Result<()> | ||||
|   ) -> RpxyResult<()> | ||||
|   where | ||||
|     C: ConnectionQuic<Bytes>, | ||||
|     <C as ConnectionQuic<Bytes>>::BidiStream: BidiStream<Bytes> + Send + 'static, | ||||
|  | @ -28,9 +37,11 @@ where | |||
|   { | ||||
|     let mut h3_conn = h3::server::Connection::<_, Bytes>::new(quic_connection).await?; | ||||
|     info!( | ||||
|       "QUIC/HTTP3 connection established from {:?} {:?}", | ||||
|       client_addr, tls_server_name | ||||
|       "QUIC/HTTP3 connection established from {:?} {}", | ||||
|       client_addr, | ||||
|       <&ServerName as TryInto<String>>::try_into(&tls_server_name).unwrap_or_default() | ||||
|     ); | ||||
| 
 | ||||
|     // TODO: Is here enough to fetch server_name from NewConnection?
 | ||||
|     // to avoid deep nested call from listener_service_h3
 | ||||
|     loop { | ||||
|  | @ -60,13 +71,13 @@ where | |||
|           let self_inner = self.clone(); | ||||
|           let tls_server_name_inner = tls_server_name.clone(); | ||||
|           self.globals.runtime_handle.spawn(async move { | ||||
|             if let Err(e) = timeout( | ||||
|               self_inner.globals.proxy_config.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
 | ||||
|               self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), | ||||
|             if let Err(e) = tokio::time::timeout( | ||||
|               Duration::from_secs(CONNECTION_TIMEOUT_SEC + 1), // just in case...
 | ||||
|               self_inner.h3_serve_stream(req, stream, client_addr, tls_server_name_inner), | ||||
|             ) | ||||
|             .await | ||||
|             { | ||||
|               error!("HTTP/3 failed to process stream: {}", e); | ||||
|               warn!("HTTP/3 error on serve stream: {}", e); | ||||
|             } | ||||
|             request_count.decrement(); | ||||
|             debug!("Request processed: current # {}", request_count.current()); | ||||
|  | @ -78,13 +89,17 @@ where | |||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   async fn stream_serve_h3<S>( | ||||
|   /// Serves a request stream from a client
 | ||||
|   /// Body in hyper-0.14 was changed to Incoming in hyper-1.0, and it is not accessible from outside.
 | ||||
|   /// Thus, we needed to implement IncomingLike trait using channel. Also, the backend handler must feed the body in the form of
 | ||||
|   /// Either<Incoming, IncomingLike> as body.
 | ||||
|   async fn h3_serve_stream<S>( | ||||
|     &self, | ||||
|     req: Request<()>, | ||||
|     stream: RequestStream<S, Bytes>, | ||||
|     client_addr: SocketAddr, | ||||
|     tls_server_name: ServerNameBytesExp, | ||||
|   ) -> Result<()> | ||||
|     tls_server_name: ServerName, | ||||
|   ) -> RpxyResult<()> | ||||
|   where | ||||
|     S: BidiStream<Bytes> + Send + 'static, | ||||
|     <S as BidiStream<Bytes>>::RecvStream: Send, | ||||
|  | @ -94,7 +109,7 @@ where | |||
|     let (mut send_stream, mut recv_stream) = stream.split(); | ||||
| 
 | ||||
|     // generate streamed body with trailers using channel
 | ||||
|     let (body_sender, req_body) = Body::channel(); | ||||
|     let (body_sender, req_body) = IncomingLike::channel(); | ||||
| 
 | ||||
|     // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1
 | ||||
|     // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn.
 | ||||
|  | @ -107,10 +122,10 @@ where | |||
|         size += body.remaining(); | ||||
|         if size > max_body_size { | ||||
|           error!( | ||||
|             "Exceeds max request body size for HTTP/3: received {}, maximum_allowd {}", | ||||
|             "Exceeds max request body size for HTTP/3: received {}, maximum_allowed {}", | ||||
|             size, max_body_size | ||||
|           ); | ||||
|           return Err(RpxyError::Proxy("Exceeds max request body size for HTTP/3".to_string())); | ||||
|           return Err(RpxyError::H3TooLargeBody); | ||||
|         } | ||||
|         // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes
 | ||||
|         sender.send_data(body.copy_to_bytes(body.remaining())).await?; | ||||
|  | @ -122,13 +137,12 @@ where | |||
|         debug!("HTTP/3 incoming request trailers"); | ||||
|         sender.send_trailers(trailers.unwrap()).await?; | ||||
|       } | ||||
|       Ok(()) | ||||
|       Ok(()) as RpxyResult<()> | ||||
|     }); | ||||
| 
 | ||||
|     let new_req: Request<Body> = Request::from_parts(req_parts, req_body); | ||||
|     let new_req: Request<RequestBody> = Request::from_parts(req_parts, RequestBody::IncomingLike(req_body)); | ||||
|     let res = self | ||||
|       .msg_handler | ||||
|       .clone() | ||||
|       .message_handler | ||||
|       .handle_request( | ||||
|         new_req, | ||||
|         client_addr, | ||||
|  | @ -138,21 +152,33 @@ where | |||
|       ) | ||||
|       .await?; | ||||
| 
 | ||||
|     let (new_res_parts, new_body) = res.into_parts(); | ||||
|     let (new_res_parts, mut new_body) = res.into_parts(); | ||||
|     let new_res = Response::from_parts(new_res_parts, ()); | ||||
| 
 | ||||
|     match send_stream.send_response(new_res).await { | ||||
|       Ok(_) => { | ||||
|         debug!("HTTP/3 response to connection successful"); | ||||
|         // aggregate body without copying
 | ||||
|         let mut body_data = hyper::body::aggregate(new_body).await?; | ||||
|         // on-demand body streaming to downstream without expanding the object onto memory.
 | ||||
|         loop { | ||||
|           let frame = match new_body.frame().await { | ||||
|             Some(frame) => frame, | ||||
|             None => { | ||||
|               debug!("Response body finished"); | ||||
|               break; | ||||
|             } | ||||
|           } | ||||
|           .map_err(|e| RpxyError::HyperBodyManipulationError(e.to_string()))?; | ||||
| 
 | ||||
|         // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes
 | ||||
|         send_stream | ||||
|           .send_data(body_data.copy_to_bytes(body_data.remaining())) | ||||
|           .await?; | ||||
| 
 | ||||
|         // TODO: needs handling trailer? should be included in body from handler.
 | ||||
|           if frame.is_data() { | ||||
|             let data = frame.into_data().unwrap_or_default(); | ||||
|             // debug!("Write data to HTTP/3 stream");
 | ||||
|             send_stream.send_data(data).await?; | ||||
|           } else if frame.is_trailers() { | ||||
|             let trailers = frame.into_trailers().unwrap_or_default(); | ||||
|             // debug!("Write trailer to HTTP/3 stream");
 | ||||
|             send_stream.send_trailers(trailers).await?; | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|       Err(err) => { | ||||
|         error!("Unable to send response to connection peer: {:?}", err); | ||||
|  |  | |||
|  | @ -1,78 +1,81 @@ | |||
| use super::socket::bind_tcp_socket; | ||||
| use crate::{ | ||||
|   certs::CryptoSource, error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp, | ||||
|   constants::{CONNECTION_TIMEOUT_SEC, TLS_HANDSHAKE_TIMEOUT_SEC}, | ||||
|   crypto::{CryptoSource, ServerCrypto, SniServerCryptoMap}, | ||||
|   error::*, | ||||
|   globals::Globals, | ||||
|   hyper_ext::{ | ||||
|     body::{RequestBody, ResponseBody}, | ||||
|     rt::LocalExecutor, | ||||
|   }, | ||||
|   log::*, | ||||
|   message_handler::HttpMessageHandler, | ||||
|   name_exp::ServerName, | ||||
| }; | ||||
| use derive_builder::{self, Builder}; | ||||
| use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request}; | ||||
| use std::{net::SocketAddr, sync::Arc}; | ||||
| use tokio::{ | ||||
|   io::{AsyncRead, AsyncWrite}, | ||||
|   runtime::Handle, | ||||
|   sync::Notify, | ||||
|   time::{timeout, Duration}, | ||||
| use futures::{select, FutureExt}; | ||||
| use http::{Request, Response}; | ||||
| use hyper::{ | ||||
|   body::Incoming, | ||||
|   rt::{Read, Write}, | ||||
|   service::service_fn, | ||||
| }; | ||||
| use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; | ||||
| use std::{net::SocketAddr, sync::Arc, time::Duration}; | ||||
| use tokio::time::timeout; | ||||
| 
 | ||||
| #[derive(Clone)] | ||||
| pub struct LocalExecutor { | ||||
|   runtime_handle: Handle, | ||||
| } | ||||
| 
 | ||||
| impl LocalExecutor { | ||||
|   fn new(runtime_handle: Handle) -> Self { | ||||
|     LocalExecutor { runtime_handle } | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| impl<F> hyper::rt::Executor<F> for LocalExecutor | ||||
| where | ||||
|   F: std::future::Future + Send + 'static, | ||||
|   F::Output: Send, | ||||
| { | ||||
|   fn execute(&self, fut: F) { | ||||
|     self.runtime_handle.spawn(fut); | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Clone, Builder)] | ||||
| pub struct Proxy<T, U> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub listening_on: SocketAddr, | ||||
|   pub tls_enabled: bool, // TCP待受がTLSかどうか
 | ||||
|   pub msg_handler: Arc<HttpMessageHandler<T, U>>, | ||||
|   pub globals: Arc<Globals<U>>, | ||||
| } | ||||
| 
 | ||||
| impl<T, U> Proxy<T, U> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
|   U: CryptoSource + Clone + Sync + Send, | ||||
| { | ||||
|   /// Wrapper function to handle request
 | ||||
|   async fn serve( | ||||
|     handler: Arc<HttpMessageHandler<T, U>>, | ||||
|     req: Request<Body>, | ||||
| /// Wrapper function to handle request for HTTP/1.1 and HTTP/2
 | ||||
| /// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler
 | ||||
| async fn serve_request<U, T>( | ||||
|   req: Request<Incoming>, | ||||
|   handler: Arc<HttpMessageHandler<U, T>>, | ||||
|   client_addr: SocketAddr, | ||||
|   listen_addr: SocketAddr, | ||||
|   tls_enabled: bool, | ||||
|     tls_server_name: Option<ServerNameBytesExp>, | ||||
|   ) -> Result<hyper::Response<Body>> { | ||||
|   tls_server_name: Option<ServerName>, | ||||
| ) -> RpxyResult<Response<ResponseBody>> | ||||
| where | ||||
|   T: Send + Sync + Connect + Clone, | ||||
|   U: CryptoSource + Clone, | ||||
| { | ||||
|   handler | ||||
|       .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) | ||||
|     .handle_request( | ||||
|       req.map(RequestBody::Incoming), | ||||
|       client_addr, | ||||
|       listen_addr, | ||||
|       tls_enabled, | ||||
|       tls_server_name, | ||||
|     ) | ||||
|     .await | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| #[derive(Clone)] | ||||
| /// Proxy main object responsible to serve requests received from clients at the given socket address.
 | ||||
| pub(crate) struct Proxy<U, T, E = LocalExecutor> | ||||
| where | ||||
|   T: Send + Sync + Connect + Clone + 'static, | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   /// global context shared among async tasks
 | ||||
|   pub globals: Arc<Globals>, | ||||
|   /// listen socket address
 | ||||
|   pub listening_on: SocketAddr, | ||||
|   /// whether TLS is enabled or not
 | ||||
|   pub tls_enabled: bool, | ||||
|   /// hyper connection builder serving http request
 | ||||
|   pub connection_builder: Arc<ConnectionBuilder<E>>, | ||||
|   /// message handler serving incoming http request
 | ||||
|   pub message_handler: Arc<HttpMessageHandler<U, T>>, | ||||
| } | ||||
| 
 | ||||
| impl<U, T> Proxy<U, T> | ||||
| where | ||||
|   T: Send + Sync + Connect + Clone + 'static, | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   /// Serves requests from clients
 | ||||
|   pub(super) fn client_serve<I>( | ||||
|     self, | ||||
|     stream: I, | ||||
|     server: Http<LocalExecutor>, | ||||
|     peer_addr: SocketAddr, | ||||
|     tls_server_name: Option<ServerNameBytesExp>, | ||||
|   ) where | ||||
|     I: AsyncRead + AsyncWrite + Send + Unpin + 'static, | ||||
|   fn serve_connection<I>(&self, stream: I, peer_addr: SocketAddr, tls_server_name: Option<ServerName>) | ||||
|   where | ||||
|     I: Read + Write + Send + Unpin + 'static, | ||||
|   { | ||||
|     let request_count = self.globals.request_count.clone(); | ||||
|     if request_count.increment() > self.globals.proxy_config.max_clients { | ||||
|  | @ -81,24 +84,27 @@ where | |||
|     } | ||||
|     debug!("Request incoming: current # {}", request_count.current()); | ||||
| 
 | ||||
|     let server_clone = self.connection_builder.clone(); | ||||
|     let message_handler_clone = self.message_handler.clone(); | ||||
|     let tls_enabled = self.tls_enabled; | ||||
|     let listening_on = self.listening_on; | ||||
|     let timeout_sec = Duration::from_secs(CONNECTION_TIMEOUT_SEC + 1); // just in case...
 | ||||
|     self.globals.runtime_handle.clone().spawn(async move { | ||||
|       timeout( | ||||
|         self.globals.proxy_config.proxy_timeout + Duration::from_secs(1), | ||||
|         server | ||||
|           .serve_connection( | ||||
|         timeout_sec + Duration::from_secs(1), // just in case...
 | ||||
|         server_clone.serve_connection_with_upgrades( | ||||
|           stream, | ||||
|             service_fn(move |req: Request<Body>| { | ||||
|               Self::serve( | ||||
|                 self.msg_handler.clone(), | ||||
|           service_fn(move |req: Request<Incoming>| { | ||||
|             serve_request( | ||||
|               req, | ||||
|               message_handler_clone.clone(), | ||||
|               peer_addr, | ||||
|                 self.listening_on, | ||||
|                 self.tls_enabled, | ||||
|               listening_on, | ||||
|               tls_enabled, | ||||
|               tls_server_name.clone(), | ||||
|             ) | ||||
|           }), | ||||
|           ) | ||||
|           .with_upgrades(), | ||||
|         ), | ||||
|       ) | ||||
|       .await | ||||
|       .ok(); | ||||
|  | @ -109,47 +115,149 @@ where | |||
|   } | ||||
| 
 | ||||
|   /// Start without TLS (HTTP cleartext)
 | ||||
|   async fn start_without_tls(self, server: Http<LocalExecutor>) -> Result<()> { | ||||
|   async fn start_without_tls(&self) -> RpxyResult<()> { | ||||
|     let listener_service = async { | ||||
|       let tcp_socket = bind_tcp_socket(&self.listening_on)?; | ||||
|       let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; | ||||
|       info!("Start TCP proxy serving with HTTP request for configured host names"); | ||||
|       while let Ok((stream, _client_addr)) = tcp_listener.accept().await { | ||||
|         self.clone().client_serve(stream, server.clone(), _client_addr, None); | ||||
|       while let Ok((stream, client_addr)) = tcp_listener.accept().await { | ||||
|         self.serve_connection(TokioIo::new(stream), client_addr, None); | ||||
|       } | ||||
|       Ok(()) as Result<()> | ||||
|       Ok(()) as RpxyResult<()> | ||||
|     }; | ||||
|     listener_service.await?; | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   /// Entrypoint for HTTP/1.1 and HTTP/2 servers
 | ||||
|   pub async fn start(self, term_notify: Option<Arc<Notify>>) -> Result<()> { | ||||
|     let mut server = Http::new(); | ||||
|     server.http1_keep_alive(self.globals.proxy_config.keepalive); | ||||
|     server.http2_max_concurrent_streams(self.globals.proxy_config.max_concurrent_streams); | ||||
|     server.pipeline_flush(true); | ||||
|     let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); | ||||
|     let server = server.with_executor(executor); | ||||
|   /// Start with TLS (HTTPS)
 | ||||
|   pub(super) async fn start_with_tls(&self) -> RpxyResult<()> { | ||||
|     #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] | ||||
|     { | ||||
|       self.tls_listener_service().await?; | ||||
|       error!("TCP proxy service for TLS exited"); | ||||
|       Ok(()) | ||||
|     } | ||||
|     #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||
|     { | ||||
|       if self.globals.proxy_config.http3 { | ||||
|         select! { | ||||
|           _ = self.tls_listener_service().fuse() => { | ||||
|             error!("TCP proxy service for TLS exited"); | ||||
|           }, | ||||
|           _ = self.h3_listener_service().fuse() => { | ||||
|             error!("UDP proxy service for QUIC exited"); | ||||
|           } | ||||
|         }; | ||||
|         Ok(()) | ||||
|       } else { | ||||
|         self.tls_listener_service().await?; | ||||
|         error!("TCP proxy service for TLS exited"); | ||||
|         Ok(()) | ||||
|       } | ||||
|     } | ||||
|   } | ||||
| 
 | ||||
|     let listening_on = self.listening_on; | ||||
|   // TCP Listener Service, i.e., http/2 and http/1.1
 | ||||
|   async fn tls_listener_service(&self) -> RpxyResult<()> { | ||||
|     let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { | ||||
|       return Err(RpxyError::NoCertificateReloader); | ||||
|     }; | ||||
|     let tcp_socket = bind_tcp_socket(&self.listening_on)?; | ||||
|     let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; | ||||
|     info!("Start TCP proxy serving with HTTPS request for configured host names"); | ||||
| 
 | ||||
|     let mut server_crypto_map: Option<Arc<SniServerCryptoMap>> = None; | ||||
|     loop { | ||||
|       select! { | ||||
|         tcp_cnx = tcp_listener.accept().fuse() => { | ||||
|           if tcp_cnx.is_err() || server_crypto_map.is_none() { | ||||
|             continue; | ||||
|           } | ||||
|           let (raw_stream, client_addr) = tcp_cnx.unwrap(); | ||||
|           let sc_map_inner = server_crypto_map.clone(); | ||||
|           let self_inner = self.clone(); | ||||
| 
 | ||||
|           // spawns async handshake to avoid blocking thread by sequential handshake.
 | ||||
|           let handshake_fut = async move { | ||||
|             let acceptor = tokio_rustls::LazyConfigAcceptor::new(tokio_rustls::rustls::server::Acceptor::default(), raw_stream).await; | ||||
|             if let Err(e) = acceptor { | ||||
|               return Err(RpxyError::FailedToTlsHandshake(e.to_string())); | ||||
|             } | ||||
|             let start = acceptor.unwrap(); | ||||
|             let client_hello = start.client_hello(); | ||||
|             let sni = client_hello.server_name(); | ||||
|             debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", sni.unwrap_or("None")); | ||||
|             let server_name = sni.map(ServerName::from); | ||||
|             if server_name.is_none(){ | ||||
|               return Err(RpxyError::NoServerNameInClientHello); | ||||
|             } | ||||
|             let server_crypto = sc_map_inner.as_ref().unwrap().get(server_name.as_ref().unwrap()); | ||||
|             if server_crypto.is_none() { | ||||
|               return Err(RpxyError::NoTlsServingApp(server_name.as_ref().unwrap().try_into().unwrap_or_default())); | ||||
|             } | ||||
|             let stream = match start.into_stream(server_crypto.unwrap().clone()).await { | ||||
|               Ok(s) => TokioIo::new(s), | ||||
|               Err(e) => { | ||||
|                 return Err(RpxyError::FailedToTlsHandshake(e.to_string())); | ||||
|               } | ||||
|             }; | ||||
|             Ok((stream, client_addr, server_name)) | ||||
|           }; | ||||
| 
 | ||||
|           self.globals.runtime_handle.spawn( async move { | ||||
|             // timeout is introduced to avoid get stuck here.
 | ||||
|             let Ok(v) = timeout( | ||||
|               Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), | ||||
|               handshake_fut | ||||
|             ).await else { | ||||
|               error!("Timeout to handshake TLS"); | ||||
|               return; | ||||
|             }; | ||||
|             match v { | ||||
|               Ok((stream, client_addr, server_name)) => { | ||||
|                 self_inner.serve_connection(stream, client_addr, server_name); | ||||
|               } | ||||
|               Err(e) => { | ||||
|                 error!("{}", e); | ||||
|               } | ||||
|             } | ||||
|           }); | ||||
|         } | ||||
|         _ = server_crypto_rx.changed().fuse() => { | ||||
|           if server_crypto_rx.borrow().is_none() { | ||||
|             error!("Reloader is broken"); | ||||
|             break; | ||||
|           } | ||||
|           let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); | ||||
|           let Some(server_crypto): Option<Arc<ServerCrypto>> = (&cert_keys_map).try_into().ok() else { | ||||
|             error!("Failed to update server crypto"); | ||||
|             break; | ||||
|           }; | ||||
|           server_crypto_map = Some(server_crypto.inner_local_map.clone()); | ||||
|         } | ||||
|       } | ||||
|     } | ||||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   /// Entrypoint for HTTP/1.1, 2 and 3 servers
 | ||||
|   pub async fn start(&self) -> RpxyResult<()> { | ||||
|     let proxy_service = async { | ||||
|       if self.tls_enabled { | ||||
|         self.start_with_tls(server).await | ||||
|         self.start_with_tls().await | ||||
|       } else { | ||||
|         self.start_without_tls(server).await | ||||
|         self.start_without_tls().await | ||||
|       } | ||||
|     }; | ||||
| 
 | ||||
|     match term_notify { | ||||
|     match &self.globals.term_notify { | ||||
|       Some(term) => { | ||||
|         tokio::select! { | ||||
|           _ = proxy_service => { | ||||
|         select! { | ||||
|           _ = proxy_service.fuse() => { | ||||
|             warn!("Proxy service got down"); | ||||
|           } | ||||
|           _ = term.notified() => { | ||||
|             info!("Proxy service listening on {} receives term signal", listening_on); | ||||
|           _ = term.notified().fuse() => { | ||||
|             info!("Proxy service listening on {} receives term signal", self.listening_on); | ||||
|           } | ||||
|         } | ||||
|       } | ||||
|  | @ -159,8 +267,6 @@ where | |||
|       } | ||||
|     } | ||||
| 
 | ||||
|     // proxy_service.await?;
 | ||||
| 
 | ||||
|     Ok(()) | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -1,30 +1,32 @@ | |||
| use super::proxy_main::Proxy; | ||||
| use super::socket::bind_udp_socket; | ||||
| use super::{ | ||||
|   crypto_service::{ServerCrypto, ServerCryptoBase}, | ||||
|   proxy_main::Proxy, | ||||
| use crate::{ | ||||
|   crypto::{CryptoSource, ServerCrypto}, | ||||
|   error::*, | ||||
|   log::*, | ||||
|   name_exp::ByteName, | ||||
| }; | ||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; | ||||
| use hot_reload::ReloaderReceiver; | ||||
| use hyper::client::connect::Connect; | ||||
| use hyper_util::client::legacy::connect::Connect; | ||||
| use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; | ||||
| use rustls::ServerConfig; | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| impl<T, U> Proxy<T, U> | ||||
| impl<U, T> Proxy<U, T> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
|   T: Send + Sync + Connect + Clone + 'static, | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub(super) async fn listener_service_h3( | ||||
|     &self, | ||||
|     mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, | ||||
|   ) -> Result<()> { | ||||
|   pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { | ||||
|     let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { | ||||
|       return Err(RpxyError::NoCertificateReloader); | ||||
|     }; | ||||
|     info!("Start UDP proxy serving with HTTP/3 request for configured host names [quinn]"); | ||||
|     // first set as null config server
 | ||||
|     let rustls_server_config = ServerConfig::builder() | ||||
|       .with_safe_default_cipher_suites() | ||||
|       .with_safe_default_kx_groups() | ||||
|       .with_protocol_versions(&[&rustls::version::TLS13])? | ||||
|       .with_protocol_versions(&[&rustls::version::TLS13]) | ||||
|       .map_err(|e| RpxyError::QuinnInvalidTlsProtocolVersion(e.to_string()))? | ||||
|       .with_no_client_auth() | ||||
|       .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new())); | ||||
| 
 | ||||
|  | @ -90,11 +92,11 @@ where | |||
|               }, | ||||
|               Err(e) => { | ||||
|                 warn!("QUIC accepting connection failed: {:?}", e); | ||||
|                 return Err(RpxyError::QuicConn(e)); | ||||
|                 return Err(RpxyError::QuinnConnectionFailed(e)); | ||||
|               } | ||||
|             }; | ||||
|             // Timeout is based on underlying quic
 | ||||
|             if let Err(e) = self_clone.connection_serve_h3(quic_connection, new_server_name.to_server_name_vec(), client_addr).await { | ||||
|             if let Err(e) = self_clone.h3_serve_connection(quic_connection, new_server_name.to_server_name(), client_addr).await { | ||||
|               warn!("QUIC or HTTP/3 connection failed: {}", e); | ||||
|             }; | ||||
|             Ok(()) | ||||
|  | @ -119,6 +121,6 @@ where | |||
|       } | ||||
|     } | ||||
|     endpoint.wait_idle().await; | ||||
|     Ok(()) as Result<()> | ||||
|     Ok(()) as RpxyResult<()> | ||||
|   } | ||||
| } | ||||
|  |  | |||
|  | @ -1,22 +1,27 @@ | |||
| use super::{ | ||||
|   crypto_service::{ServerCrypto, ServerCryptoBase}, | ||||
|   proxy_main::Proxy, | ||||
| use super::proxy_main::Proxy; | ||||
| use crate::{ | ||||
|   crypto::CryptoSource, | ||||
|   crypto::{ServerCrypto, ServerCryptoBase}, | ||||
|   error::*, | ||||
|   log::*, | ||||
|   name_exp::ByteName, | ||||
| }; | ||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; | ||||
| use anyhow::anyhow; | ||||
| use hot_reload::ReloaderReceiver; | ||||
| use hyper::client::connect::Connect; | ||||
| use hyper_util::client::legacy::connect::Connect; | ||||
| use s2n_quic::provider; | ||||
| use std::sync::Arc; | ||||
| 
 | ||||
| impl<T, U> Proxy<T, U> | ||||
| impl<U, T> Proxy<U, T> | ||||
| where | ||||
|   T: Connect + Clone + Sync + Send + 'static, | ||||
|   U: CryptoSource + Clone + Sync + Send + 'static, | ||||
| { | ||||
|   pub(super) async fn listener_service_h3( | ||||
|     &self, | ||||
|     mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, | ||||
|   ) -> Result<()> { | ||||
|   /// Start UDP proxy serving with HTTP/3 request for configured host names
 | ||||
|   pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { | ||||
|     let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { | ||||
|       return Err(RpxyError::NoCertificateReloader); | ||||
|     }; | ||||
|     info!("Start UDP proxy serving with HTTP/3 request for configured host names [s2n-quic]"); | ||||
| 
 | ||||
|     // initially wait for receipt
 | ||||
|  | @ -29,7 +34,7 @@ where | |||
|     // event loop
 | ||||
|     loop { | ||||
|       tokio::select! { | ||||
|         v = self.serve_connection(&server_crypto) => { | ||||
|         v = self.h3_listener_service_inner(&server_crypto) => { | ||||
|           if let Err(e) = v { | ||||
|             error!("Quic connection event loop illegally shutdown [s2n-quic] {e}"); | ||||
|             break; | ||||
|  | @ -51,20 +56,25 @@ where | |||
|     Ok(()) | ||||
|   } | ||||
| 
 | ||||
|   fn receive_server_crypto(&self, server_crypto_rx: ReloaderReceiver<ServerCryptoBase>) -> Result<Arc<ServerCrypto>> { | ||||
|   /// Receive server crypto from reloader
 | ||||
|   fn receive_server_crypto( | ||||
|     &self, | ||||
|     server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, | ||||
|   ) -> RpxyResult<Arc<ServerCrypto>> { | ||||
|     let cert_keys_map = server_crypto_rx.borrow().clone().ok_or_else(|| { | ||||
|       error!("Reloader is broken"); | ||||
|       RpxyError::Other(anyhow!("Reloader is broken")) | ||||
|       RpxyError::CertificateReloadError(anyhow!("Reloader is broken").into()) | ||||
|     })?; | ||||
| 
 | ||||
|     let server_crypto: Option<Arc<ServerCrypto>> = (&cert_keys_map).try_into().ok(); | ||||
|     server_crypto.ok_or_else(|| { | ||||
|       error!("Failed to update server crypto for h3 [s2n-quic]"); | ||||
|       RpxyError::Other(anyhow!("Failed to update server crypto for h3 [s2n-quic]")) | ||||
|       RpxyError::FailedToUpdateServerCrypto("Failed to update server crypto for h3 [s2n-quic]".to_string()) | ||||
|     }) | ||||
|   } | ||||
| 
 | ||||
|   async fn serve_connection(&self, server_crypto: &Option<Arc<ServerCrypto>>) -> Result<()> { | ||||
|   /// Event loop for UDP proxy serving with HTTP/3 request for configured host names
 | ||||
|   async fn h3_listener_service_inner(&self, server_crypto: &Option<Arc<ServerCrypto>>) -> RpxyResult<()> { | ||||
|     // setup UDP socket
 | ||||
|     let io = provider::io::tokio::Builder::default() | ||||
|       .with_receive_address(self.listening_on)? | ||||
|  | @ -73,18 +83,13 @@ where | |||
| 
 | ||||
|     // setup limits
 | ||||
|     let mut limits = provider::limits::Limits::default() | ||||
|       .with_max_open_local_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .with_max_open_remote_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .with_max_open_local_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .with_max_open_remote_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .with_max_active_connection_ids(self.globals.proxy_config.h3_max_concurrent_connections as u64) | ||||
|       .map_err(|e| anyhow!(e))?; | ||||
|       .with_max_open_local_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64)? | ||||
|       .with_max_open_remote_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64)? | ||||
|       .with_max_open_local_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64)? | ||||
|       .with_max_open_remote_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64)? | ||||
|       .with_max_active_connection_ids(self.globals.proxy_config.h3_max_concurrent_connections as u64)?; | ||||
|     limits = if let Some(v) = self.globals.proxy_config.h3_max_idle_timeout { | ||||
|       limits.with_max_idle_timeout(v).map_err(|e| anyhow!(e))? | ||||
|       limits.with_max_idle_timeout(v)? | ||||
|     } else { | ||||
|       limits | ||||
|     }; | ||||
|  | @ -92,19 +97,17 @@ where | |||
|     // setup tls
 | ||||
|     let Some(server_crypto) = server_crypto else { | ||||
|       warn!("No server crypto is given [s2n-quic]"); | ||||
|       return Err(RpxyError::Other(anyhow!("No server crypto is given [s2n-quic]"))); | ||||
|       return Err(RpxyError::NoServerCrypto( | ||||
|         "No server crypto is given [s2n-quic]".to_string(), | ||||
|       )); | ||||
|     }; | ||||
|     let tls = server_crypto.inner_global_no_client_auth.clone(); | ||||
| 
 | ||||
|     let mut server = s2n_quic::Server::builder() | ||||
|       .with_tls(tls) | ||||
|       .map_err(|e| anyhow::anyhow!(e))? | ||||
|       .with_io(io) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .with_limits(limits) | ||||
|       .map_err(|e| anyhow!(e))? | ||||
|       .start() | ||||
|       .map_err(|e| anyhow!(e))?; | ||||
|       .with_tls(tls)? | ||||
|       .with_io(io)? | ||||
|       .with_limits(limits)? | ||||
|       .start()?; | ||||
| 
 | ||||
|     // quic event loop. this immediately cancels when crypto is updated by tokio::select!
 | ||||
|     while let Some(new_conn) = server.accept().await { | ||||
|  | @ -121,12 +124,12 @@ where | |||
|         let quic_connection = s2n_quic_h3::Connection::new(new_conn); | ||||
|         // Timeout is based on underlying quic
 | ||||
|         if let Err(e) = self_clone | ||||
|           .connection_serve_h3(quic_connection, new_server_name.to_server_name_vec(), client_addr) | ||||
|           .h3_serve_connection(quic_connection, new_server_name.to_server_name(), client_addr) | ||||
|           .await | ||||
|         { | ||||
|           warn!("QUIC or HTTP/3 connection failed: {}", e); | ||||
|         }; | ||||
|         Ok(()) as Result<()> | ||||
|         Ok(()) as RpxyResult<()> | ||||
|       }); | ||||
|     } | ||||
| 
 | ||||
|  |  | |||
|  | @ -8,7 +8,7 @@ use tokio::net::TcpSocket; | |||
| 
 | ||||
| /// Bind TCP socket to the given `SocketAddr`, and returns the TCP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options.
 | ||||
| /// This option is required to re-bind the socket address when the proxy instance is reconstructed.
 | ||||
| pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> Result<TcpSocket> { | ||||
| pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> RpxyResult<TcpSocket> { | ||||
|   let tcp_socket = if listening_on.is_ipv6() { | ||||
|     TcpSocket::new_v6() | ||||
|   } else { | ||||
|  | @ -26,7 +26,7 @@ pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> Result<TcpSocket> { | |||
| #[cfg(feature = "http3-quinn")] | ||||
| /// Bind UDP socket to the given `SocketAddr`, and returns the UDP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options.
 | ||||
| /// This option is required to re-bind the socket address when the proxy instance is reconstructed.
 | ||||
| pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> Result<UdpSocket> { | ||||
| pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> RpxyResult<UdpSocket> { | ||||
|   let socket = if listening_on.is_ipv6() { | ||||
|     Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) | ||||
|   } else { | ||||
|  |  | |||
|  | @ -1 +1 @@ | |||
| Subproject commit b86df1220775d13b89cead99e787944b55991b1e | ||||
| Subproject commit 5c161952b02e663f31f9b83829bafa7a047b6627 | ||||
|  | @ -1,24 +0,0 @@ | |||
| [package] | ||||
| name = "h3-quinn" | ||||
| version = "0.0.4" | ||||
| rust-version = "1.63" | ||||
| authors = ["Jean-Christophe BEGUE <jc.begue@pm.me>"] | ||||
| edition = "2018" | ||||
| documentation = "https://docs.rs/h3-quinn" | ||||
| repository = "https://github.com/hyperium/h3" | ||||
| readme = "../README.md" | ||||
| description = "QUIC transport implementation based on Quinn." | ||||
| keywords = ["http3", "quic", "h3"] | ||||
| categories = ["network-programming", "web-programming"] | ||||
| license = "MIT" | ||||
| 
 | ||||
| [dependencies] | ||||
| h3 = { version = "0.0.3", path = "../h3/h3" } | ||||
| bytes = "1" | ||||
| quinn = { path = "../quinn/quinn/", default-features = false, features = [ | ||||
|   "futures-io", | ||||
| ] } | ||||
| quinn-proto = { path = "../quinn/quinn-proto/", default-features = false } | ||||
| tokio-util = { version = "0.7.9" } | ||||
| futures = { version = "0.3.28" } | ||||
| tokio = { version = "1.33.0", features = ["io-util"], default-features = false } | ||||
|  | @ -1,740 +0,0 @@ | |||
| //! QUIC Transport implementation with Quinn
 | ||||
| //!
 | ||||
| //! This module implements QUIC traits with Quinn.
 | ||||
| #![deny(missing_docs)] | ||||
| 
 | ||||
| use std::{ | ||||
|     convert::TryInto, | ||||
|     fmt::{self, Display}, | ||||
|     future::Future, | ||||
|     pin::Pin, | ||||
|     sync::Arc, | ||||
|     task::{self, Poll}, | ||||
| }; | ||||
| 
 | ||||
| use bytes::{Buf, Bytes, BytesMut}; | ||||
| 
 | ||||
| use futures::{ | ||||
|     ready, | ||||
|     stream::{self, BoxStream}, | ||||
|     StreamExt, | ||||
| }; | ||||
| use quinn::ReadDatagram; | ||||
| pub use quinn::{ | ||||
|     self, crypto::Session, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError, | ||||
| }; | ||||
| 
 | ||||
| use h3::{ | ||||
|     ext::Datagram, | ||||
|     quic::{self, Error, StreamId, WriteBuf}, | ||||
| }; | ||||
| use tokio_util::sync::ReusableBoxFuture; | ||||
| 
 | ||||
| /// A QUIC connection backed by Quinn
 | ||||
| ///
 | ||||
| /// Implements a [`quic::Connection`] backed by a [`quinn::Connection`].
 | ||||
| pub struct Connection { | ||||
|     conn: quinn::Connection, | ||||
|     incoming_bi: BoxStream<'static, <AcceptBi<'static> as Future>::Output>, | ||||
|     opening_bi: Option<BoxStream<'static, <OpenBi<'static> as Future>::Output>>, | ||||
|     incoming_uni: BoxStream<'static, <AcceptUni<'static> as Future>::Output>, | ||||
|     opening_uni: Option<BoxStream<'static, <OpenUni<'static> as Future>::Output>>, | ||||
|     datagrams: BoxStream<'static, <ReadDatagram<'static> as Future>::Output>, | ||||
| } | ||||
| 
 | ||||
| impl Connection { | ||||
|     /// Create a [`Connection`] from a [`quinn::NewConnection`]
 | ||||
|     pub fn new(conn: quinn::Connection) -> Self { | ||||
|         Self { | ||||
|             conn: conn.clone(), | ||||
|             incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async { | ||||
|                 Some((conn.accept_bi().await, conn)) | ||||
|             })), | ||||
|             opening_bi: None, | ||||
|             incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async { | ||||
|                 Some((conn.accept_uni().await, conn)) | ||||
|             })), | ||||
|             opening_uni: None, | ||||
|             datagrams: Box::pin(stream::unfold(conn, |conn| async { | ||||
|                 Some((conn.read_datagram().await, conn)) | ||||
|             })), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The error type for [`Connection`]
 | ||||
| ///
 | ||||
| /// Wraps reasons a Quinn connection might be lost.
 | ||||
| #[derive(Debug)] | ||||
| pub struct ConnectionError(quinn::ConnectionError); | ||||
| 
 | ||||
| impl std::error::Error for ConnectionError {} | ||||
| 
 | ||||
| impl fmt::Display for ConnectionError { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         self.0.fmt(f) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Error for ConnectionError { | ||||
|     fn is_timeout(&self) -> bool { | ||||
|         matches!(self.0, quinn::ConnectionError::TimedOut) | ||||
|     } | ||||
| 
 | ||||
|     fn err_code(&self) -> Option<u64> { | ||||
|         match self.0 { | ||||
|             quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { | ||||
|                 error_code, | ||||
|                 .. | ||||
|             }) => Some(error_code.into_inner()), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<quinn::ConnectionError> for ConnectionError { | ||||
|     fn from(e: quinn::ConnectionError) -> Self { | ||||
|         Self(e) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Types of errors when sending a datagram.
 | ||||
| #[derive(Debug)] | ||||
| pub enum SendDatagramError { | ||||
|     /// Datagrams are not supported by the peer
 | ||||
|     UnsupportedByPeer, | ||||
|     /// Datagrams are locally disabled
 | ||||
|     Disabled, | ||||
|     /// The datagram was too large to be sent.
 | ||||
|     TooLarge, | ||||
|     /// Network error
 | ||||
|     ConnectionLost(Box<dyn Error>), | ||||
| } | ||||
| 
 | ||||
| impl fmt::Display for SendDatagramError { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         match self { | ||||
|             SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"), | ||||
|             SendDatagramError::Disabled => write!(f, "datagram support disabled"), | ||||
|             SendDatagramError::TooLarge => write!(f, "datagram too large"), | ||||
|             SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl std::error::Error for SendDatagramError {} | ||||
| 
 | ||||
| impl Error for SendDatagramError { | ||||
|     fn is_timeout(&self) -> bool { | ||||
|         false | ||||
|     } | ||||
| 
 | ||||
|     fn err_code(&self) -> Option<u64> { | ||||
|         match self { | ||||
|             Self::ConnectionLost(err) => err.err_code(), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<quinn::SendDatagramError> for SendDatagramError { | ||||
|     fn from(value: quinn::SendDatagramError) -> Self { | ||||
|         match value { | ||||
|             quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer, | ||||
|             quinn::SendDatagramError::Disabled => Self::Disabled, | ||||
|             quinn::SendDatagramError::TooLarge => Self::TooLarge, | ||||
|             quinn::SendDatagramError::ConnectionLost(err) => { | ||||
|                 Self::ConnectionLost(ConnectionError::from(err).into()) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<B> quic::Connection<B> for Connection | ||||
| where | ||||
|     B: Buf, | ||||
| { | ||||
|     type SendStream = SendStream<B>; | ||||
|     type RecvStream = RecvStream; | ||||
|     type BidiStream = BidiStream<B>; | ||||
|     type OpenStreams = OpenStreams; | ||||
|     type Error = ConnectionError; | ||||
| 
 | ||||
|     fn poll_accept_bidi( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|     ) -> Poll<Result<Option<Self::BidiStream>, Self::Error>> { | ||||
|         let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) { | ||||
|             Some(x) => x?, | ||||
|             None => return Poll::Ready(Ok(None)), | ||||
|         }; | ||||
|         Poll::Ready(Ok(Some(Self::BidiStream { | ||||
|             send: Self::SendStream::new(send), | ||||
|             recv: Self::RecvStream::new(recv), | ||||
|         }))) | ||||
|     } | ||||
| 
 | ||||
|     fn poll_accept_recv( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|     ) -> Poll<Result<Option<Self::RecvStream>, Self::Error>> { | ||||
|         let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) { | ||||
|             Some(x) => x?, | ||||
|             None => return Poll::Ready(Ok(None)), | ||||
|         }; | ||||
|         Poll::Ready(Ok(Some(Self::RecvStream::new(recv)))) | ||||
|     } | ||||
| 
 | ||||
|     fn poll_open_bidi( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|     ) -> Poll<Result<Self::BidiStream, Self::Error>> { | ||||
|         if self.opening_bi.is_none() { | ||||
|             self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { | ||||
|                 Some((conn.clone().open_bi().await, conn)) | ||||
|             }))); | ||||
|         } | ||||
| 
 | ||||
|         let (send, recv) = | ||||
|             ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; | ||||
|         Poll::Ready(Ok(Self::BidiStream { | ||||
|             send: Self::SendStream::new(send), | ||||
|             recv: Self::RecvStream::new(recv), | ||||
|         })) | ||||
|     } | ||||
| 
 | ||||
|     fn poll_open_send( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|     ) -> Poll<Result<Self::SendStream, Self::Error>> { | ||||
|         if self.opening_uni.is_none() { | ||||
|             self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { | ||||
|                 Some((conn.open_uni().await, conn)) | ||||
|             }))); | ||||
|         } | ||||
| 
 | ||||
|         let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; | ||||
|         Poll::Ready(Ok(Self::SendStream::new(send))) | ||||
|     } | ||||
| 
 | ||||
|     fn opener(&self) -> Self::OpenStreams { | ||||
|         OpenStreams { | ||||
|             conn: self.conn.clone(), | ||||
|             opening_bi: None, | ||||
|             opening_uni: None, | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     fn close(&mut self, code: h3::error::Code, reason: &[u8]) { | ||||
|         self.conn.close( | ||||
|             VarInt::from_u64(code.value()).expect("error code VarInt"), | ||||
|             reason, | ||||
|         ); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<B> quic::SendDatagramExt<B> for Connection | ||||
| where | ||||
|     B: Buf, | ||||
| { | ||||
|     type Error = SendDatagramError; | ||||
| 
 | ||||
|     fn send_datagram(&mut self, data: Datagram<B>) -> Result<(), SendDatagramError> { | ||||
|         // TODO investigate static buffer from known max datagram size
 | ||||
|         let mut buf = BytesMut::new(); | ||||
|         data.encode(&mut buf); | ||||
|         self.conn.send_datagram(buf.freeze())?; | ||||
| 
 | ||||
|         Ok(()) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl quic::RecvDatagramExt for Connection { | ||||
|     type Buf = Bytes; | ||||
| 
 | ||||
|     type Error = ConnectionError; | ||||
| 
 | ||||
|     #[inline] | ||||
|     fn poll_accept_datagram( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|     ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { | ||||
|         match ready!(self.datagrams.poll_next_unpin(cx)) { | ||||
|             Some(Ok(x)) => Poll::Ready(Ok(Some(x))), | ||||
|             Some(Err(e)) => Poll::Ready(Err(e.into())), | ||||
|             None => Poll::Ready(Ok(None)), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Stream opener backed by a Quinn connection
 | ||||
| ///
 | ||||
| /// Implements [`quic::OpenStreams`] using [`quinn::Connection`],
 | ||||
| /// [`quinn::OpenBi`], [`quinn::OpenUni`].
 | ||||
| pub struct OpenStreams { | ||||
|     conn: quinn::Connection, | ||||
|     opening_bi: Option<BoxStream<'static, <OpenBi<'static> as Future>::Output>>, | ||||
|     opening_uni: Option<BoxStream<'static, <OpenUni<'static> as Future>::Output>>, | ||||
| } | ||||
| 
 | ||||
| impl<B> quic::OpenStreams<B> for OpenStreams | ||||
| where | ||||
|     B: Buf, | ||||
| { | ||||
|     type RecvStream = RecvStream; | ||||
|     type SendStream = SendStream<B>; | ||||
|     type BidiStream = BidiStream<B>; | ||||
|     type Error = ConnectionError; | ||||
| 
 | ||||
|     fn poll_open_bidi( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|     ) -> Poll<Result<Self::BidiStream, Self::Error>> { | ||||
|         if self.opening_bi.is_none() { | ||||
|             self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { | ||||
|                 Some((conn.open_bi().await, conn)) | ||||
|             }))); | ||||
|         } | ||||
| 
 | ||||
|         let (send, recv) = | ||||
|             ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; | ||||
|         Poll::Ready(Ok(Self::BidiStream { | ||||
|             send: Self::SendStream::new(send), | ||||
|             recv: Self::RecvStream::new(recv), | ||||
|         })) | ||||
|     } | ||||
| 
 | ||||
|     fn poll_open_send( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|     ) -> Poll<Result<Self::SendStream, Self::Error>> { | ||||
|         if self.opening_uni.is_none() { | ||||
|             self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { | ||||
|                 Some((conn.open_uni().await, conn)) | ||||
|             }))); | ||||
|         } | ||||
| 
 | ||||
|         let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; | ||||
|         Poll::Ready(Ok(Self::SendStream::new(send))) | ||||
|     } | ||||
| 
 | ||||
|     fn close(&mut self, code: h3::error::Code, reason: &[u8]) { | ||||
|         self.conn.close( | ||||
|             VarInt::from_u64(code.value()).expect("error code VarInt"), | ||||
|             reason, | ||||
|         ); | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Clone for OpenStreams { | ||||
|     fn clone(&self) -> Self { | ||||
|         Self { | ||||
|             conn: self.conn.clone(), | ||||
|             opening_bi: None, | ||||
|             opening_uni: None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Quinn-backed bidirectional stream
 | ||||
| ///
 | ||||
| /// Implements [`quic::BidiStream`] which allows the stream to be split
 | ||||
| /// into two structs each implementing one direction.
 | ||||
| pub struct BidiStream<B> | ||||
| where | ||||
|     B: Buf, | ||||
| { | ||||
|     send: SendStream<B>, | ||||
|     recv: RecvStream, | ||||
| } | ||||
| 
 | ||||
| impl<B> quic::BidiStream<B> for BidiStream<B> | ||||
| where | ||||
|     B: Buf, | ||||
| { | ||||
|     type SendStream = SendStream<B>; | ||||
|     type RecvStream = RecvStream; | ||||
| 
 | ||||
|     fn split(self) -> (Self::SendStream, Self::RecvStream) { | ||||
|         (self.send, self.recv) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<B: Buf> quic::RecvStream for BidiStream<B> { | ||||
|     type Buf = Bytes; | ||||
|     type Error = ReadError; | ||||
| 
 | ||||
|     fn poll_data( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|     ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { | ||||
|         self.recv.poll_data(cx) | ||||
|     } | ||||
| 
 | ||||
|     fn stop_sending(&mut self, error_code: u64) { | ||||
|         self.recv.stop_sending(error_code) | ||||
|     } | ||||
| 
 | ||||
|     fn recv_id(&self) -> StreamId { | ||||
|         self.recv.recv_id() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<B> quic::SendStream<B> for BidiStream<B> | ||||
| where | ||||
|     B: Buf, | ||||
| { | ||||
|     type Error = SendStreamError; | ||||
| 
 | ||||
|     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||
|         self.send.poll_ready(cx) | ||||
|     } | ||||
| 
 | ||||
|     fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||
|         self.send.poll_finish(cx) | ||||
|     } | ||||
| 
 | ||||
|     fn reset(&mut self, reset_code: u64) { | ||||
|         self.send.reset(reset_code) | ||||
|     } | ||||
| 
 | ||||
|     fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> { | ||||
|         self.send.send_data(data) | ||||
|     } | ||||
| 
 | ||||
|     fn send_id(&self) -> StreamId { | ||||
|         self.send.send_id() | ||||
|     } | ||||
| } | ||||
| impl<B> quic::SendStreamUnframed<B> for BidiStream<B> | ||||
| where | ||||
|     B: Buf, | ||||
| { | ||||
|     fn poll_send<D: Buf>( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|         buf: &mut D, | ||||
|     ) -> Poll<Result<usize, Self::Error>> { | ||||
|         self.send.poll_send(cx, buf) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Quinn-backed receive stream
 | ||||
| ///
 | ||||
| /// Implements a [`quic::RecvStream`] backed by a [`quinn::RecvStream`].
 | ||||
| pub struct RecvStream { | ||||
|     stream: Option<quinn::RecvStream>, | ||||
|     read_chunk_fut: ReadChunkFuture, | ||||
| } | ||||
| 
 | ||||
| type ReadChunkFuture = ReusableBoxFuture< | ||||
|     'static, | ||||
|     ( | ||||
|         quinn::RecvStream, | ||||
|         Result<Option<quinn::Chunk>, quinn::ReadError>, | ||||
|     ), | ||||
| >; | ||||
| 
 | ||||
| impl RecvStream { | ||||
|     fn new(stream: quinn::RecvStream) -> Self { | ||||
|         Self { | ||||
|             stream: Some(stream), | ||||
|             // Should only allocate once the first time it's used
 | ||||
|             read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl quic::RecvStream for RecvStream { | ||||
|     type Buf = Bytes; | ||||
|     type Error = ReadError; | ||||
| 
 | ||||
|     fn poll_data( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|     ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { | ||||
|         if let Some(mut stream) = self.stream.take() { | ||||
|             self.read_chunk_fut.set(async move { | ||||
|                 let chunk = stream.read_chunk(usize::MAX, true).await; | ||||
|                 (stream, chunk) | ||||
|             }) | ||||
|         }; | ||||
| 
 | ||||
|         let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx)); | ||||
|         self.stream = Some(stream); | ||||
|         Poll::Ready(Ok(chunk?.map(|c| c.bytes))) | ||||
|     } | ||||
| 
 | ||||
|     fn stop_sending(&mut self, error_code: u64) { | ||||
|         self.stream | ||||
|             .as_mut() | ||||
|             .unwrap() | ||||
|             .stop(VarInt::from_u64(error_code).expect("invalid error_code")) | ||||
|             .ok(); | ||||
|     } | ||||
| 
 | ||||
|     fn recv_id(&self) -> StreamId { | ||||
|         self.stream | ||||
|             .as_ref() | ||||
|             .unwrap() | ||||
|             .id() | ||||
|             .0 | ||||
|             .try_into() | ||||
|             .expect("invalid stream id") | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The error type for [`RecvStream`]
 | ||||
| ///
 | ||||
| /// Wraps errors that occur when reading from a receive stream.
 | ||||
| #[derive(Debug)] | ||||
| pub struct ReadError(quinn::ReadError); | ||||
| 
 | ||||
| impl From<ReadError> for std::io::Error { | ||||
|     fn from(value: ReadError) -> Self { | ||||
|         value.0.into() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl std::error::Error for ReadError { | ||||
|     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { | ||||
|         self.0.source() | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl fmt::Display for ReadError { | ||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||
|         self.0.fmt(f) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<ReadError> for Arc<dyn Error> { | ||||
|     fn from(e: ReadError) -> Self { | ||||
|         Arc::new(e) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<quinn::ReadError> for ReadError { | ||||
|     fn from(e: quinn::ReadError) -> Self { | ||||
|         Self(e) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Error for ReadError { | ||||
|     fn is_timeout(&self) -> bool { | ||||
|         matches!( | ||||
|             self.0, | ||||
|             quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut) | ||||
|         ) | ||||
|     } | ||||
| 
 | ||||
|     fn err_code(&self) -> Option<u64> { | ||||
|         match self.0 { | ||||
|             quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed( | ||||
|                 quinn_proto::ApplicationClose { error_code, .. }, | ||||
|             )) => Some(error_code.into_inner()), | ||||
|             quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// Quinn-backed send stream
 | ||||
| ///
 | ||||
| /// Implements a [`quic::SendStream`] backed by a [`quinn::SendStream`].
 | ||||
| pub struct SendStream<B: Buf> { | ||||
|     stream: Option<quinn::SendStream>, | ||||
|     writing: Option<WriteBuf<B>>, | ||||
|     write_fut: WriteFuture, | ||||
| } | ||||
| 
 | ||||
| type WriteFuture = | ||||
|     ReusableBoxFuture<'static, (quinn::SendStream, Result<usize, quinn::WriteError>)>; | ||||
| 
 | ||||
| impl<B> SendStream<B> | ||||
| where | ||||
|     B: Buf, | ||||
| { | ||||
|     fn new(stream: quinn::SendStream) -> SendStream<B> { | ||||
|         Self { | ||||
|             stream: Some(stream), | ||||
|             writing: None, | ||||
|             write_fut: ReusableBoxFuture::new(async { unreachable!() }), | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<B> quic::SendStream<B> for SendStream<B> | ||||
| where | ||||
|     B: Buf, | ||||
| { | ||||
|     type Error = SendStreamError; | ||||
| 
 | ||||
|     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||
|         if let Some(ref mut data) = self.writing { | ||||
|             while data.has_remaining() { | ||||
|                 if let Some(mut stream) = self.stream.take() { | ||||
|                     let chunk = data.chunk().to_owned(); // FIXME - avoid copy
 | ||||
|                     self.write_fut.set(async move { | ||||
|                         let ret = stream.write(&chunk).await; | ||||
|                         (stream, ret) | ||||
|                     }); | ||||
|                 } | ||||
| 
 | ||||
|                 let (stream, res) = ready!(self.write_fut.poll(cx)); | ||||
|                 self.stream = Some(stream); | ||||
|                 match res { | ||||
|                     Ok(cnt) => data.advance(cnt), | ||||
|                     Err(err) => { | ||||
|                         return Poll::Ready(Err(SendStreamError::Write(err))); | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         self.writing = None; | ||||
|         Poll::Ready(Ok(())) | ||||
|     } | ||||
| 
 | ||||
|     fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||
|         self.stream | ||||
|             .as_mut() | ||||
|             .unwrap() | ||||
|             .poll_finish(cx) | ||||
|             .map_err(Into::into) | ||||
|     } | ||||
| 
 | ||||
|     fn reset(&mut self, reset_code: u64) { | ||||
|         let _ = self | ||||
|             .stream | ||||
|             .as_mut() | ||||
|             .unwrap() | ||||
|             .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX)); | ||||
|     } | ||||
| 
 | ||||
|     fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> { | ||||
|         if self.writing.is_some() { | ||||
|             return Err(Self::Error::NotReady); | ||||
|         } | ||||
|         self.writing = Some(data.into()); | ||||
|         Ok(()) | ||||
|     } | ||||
| 
 | ||||
|     fn send_id(&self) -> StreamId { | ||||
|         self.stream | ||||
|             .as_ref() | ||||
|             .unwrap() | ||||
|             .id() | ||||
|             .0 | ||||
|             .try_into() | ||||
|             .expect("invalid stream id") | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl<B> quic::SendStreamUnframed<B> for SendStream<B> | ||||
| where | ||||
|     B: Buf, | ||||
| { | ||||
|     fn poll_send<D: Buf>( | ||||
|         &mut self, | ||||
|         cx: &mut task::Context<'_>, | ||||
|         buf: &mut D, | ||||
|     ) -> Poll<Result<usize, Self::Error>> { | ||||
|         if self.writing.is_some() { | ||||
|             // This signifies a bug in implementation
 | ||||
|             panic!("poll_send called while send stream is not ready") | ||||
|         } | ||||
| 
 | ||||
|         let s = Pin::new(self.stream.as_mut().unwrap()); | ||||
| 
 | ||||
|         let res = ready!(futures::io::AsyncWrite::poll_write(s, cx, buf.chunk())); | ||||
|         match res { | ||||
|             Ok(written) => { | ||||
|                 buf.advance(written); | ||||
|                 Poll::Ready(Ok(written)) | ||||
|             } | ||||
|             Err(err) => { | ||||
|                 // We are forced to use AsyncWrite for now because we cannot store
 | ||||
|                 // the result of a call to:
 | ||||
|                 // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Result<usize, WriteError>.
 | ||||
|                 //
 | ||||
|                 // This is why we have to unpack the error from io::Error instead of having it
 | ||||
|                 // returned directly. This should not panic as long as quinn's AsyncWrite impl
 | ||||
|                 // doesn't change.
 | ||||
|                 let err = err | ||||
|                     .into_inner() | ||||
|                     .expect("write stream returned an empty error") | ||||
|                     .downcast::<WriteError>() | ||||
|                     .expect("write stream returned an error which type is not WriteError"); | ||||
| 
 | ||||
|                 Poll::Ready(Err(SendStreamError::Write(*err))) | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| /// The error type for [`SendStream`]
 | ||||
| ///
 | ||||
| /// Wraps errors that can happen writing to or polling a send stream.
 | ||||
| #[derive(Debug)] | ||||
| pub enum SendStreamError { | ||||
|     /// Errors when writing, wrapping a [`quinn::WriteError`]
 | ||||
|     Write(WriteError), | ||||
|     /// Error when the stream is not ready, because it is still sending
 | ||||
|     /// data from a previous call
 | ||||
|     NotReady, | ||||
| } | ||||
| 
 | ||||
| impl From<SendStreamError> for std::io::Error { | ||||
|     fn from(value: SendStreamError) -> Self { | ||||
|         match value { | ||||
|             SendStreamError::Write(err) => err.into(), | ||||
|             SendStreamError::NotReady => { | ||||
|                 std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready") | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl std::error::Error for SendStreamError {} | ||||
| 
 | ||||
| impl Display for SendStreamError { | ||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||
|         write!(f, "{:?}", self) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<WriteError> for SendStreamError { | ||||
|     fn from(e: WriteError) -> Self { | ||||
|         Self::Write(e) | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl Error for SendStreamError { | ||||
|     fn is_timeout(&self) -> bool { | ||||
|         matches!( | ||||
|             self, | ||||
|             Self::Write(quinn::WriteError::ConnectionLost( | ||||
|                 quinn::ConnectionError::TimedOut | ||||
|             )) | ||||
|         ) | ||||
|     } | ||||
| 
 | ||||
|     fn err_code(&self) -> Option<u64> { | ||||
|         match self { | ||||
|             Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()), | ||||
|             Self::Write(quinn::WriteError::ConnectionLost( | ||||
|                 quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { | ||||
|                     error_code, | ||||
|                     .. | ||||
|                 }), | ||||
|             )) => Some(error_code.into_inner()), | ||||
|             _ => None, | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| impl From<SendStreamError> for Arc<dyn Error> { | ||||
|     fn from(e: SendStreamError) -> Self { | ||||
|         Arc::new(e) | ||||
|     } | ||||
| } | ||||
|  | @ -1 +0,0 @@ | |||
| Subproject commit 6d80efeeae60b96ff330ae6a70e8cc9291fcc615 | ||||
|  | @ -1 +1 @@ | |||
| Subproject commit 3cd09170305753309d86e88b9427827cca0de0dd | ||||
| Subproject commit 88d23c2f5a3ac36295dff4a804968c43932ba46b | ||||
|  | @ -1 +0,0 @@ | |||
| Subproject commit 30027eeacc7b620da62fc4825b94afd57ab0c7be | ||||
							
								
								
									
										17
									
								
								submodules/s2n-quic-h3/Cargo.toml
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								submodules/s2n-quic-h3/Cargo.toml
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,17 @@ | |||
| [package] | ||||
| name = "s2n-quic-h3" | ||||
| # this in an unpublished internal crate so the version should not be changed | ||||
| version = "0.1.0" | ||||
| authors = ["AWS s2n"] | ||||
| edition = "2021" | ||||
| rust-version = "1.63" | ||||
| license = "Apache-2.0" | ||||
| # this contains an http3 implementation for testing purposes and should not be published | ||||
| publish = false | ||||
| 
 | ||||
| [dependencies] | ||||
| bytes = { version = "1", default-features = false } | ||||
| futures = { version = "0.3", default-features = false } | ||||
| h3 = { path = "../h3/h3/" } | ||||
| s2n-quic = "1.32.0" | ||||
| s2n-quic-core = "0.32.0" | ||||
							
								
								
									
										10
									
								
								submodules/s2n-quic-h3/README.md
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								submodules/s2n-quic-h3/README.md
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,10 @@ | |||
| # s2n-quic-h3 | ||||
| 
 | ||||
| This is an internal crate used by [s2n-quic](https://github.com/aws/s2n-quic) written as a proof of concept for implementing HTTP3 on top of s2n-quic. The API is not currently stable and should not be used directly. | ||||
| 
 | ||||
| ## License | ||||
| 
 | ||||
| This project is licensed under the [Apache-2.0 License][license-url]. | ||||
| 
 | ||||
| [license-badge]: https://img.shields.io/badge/license-apache-blue.svg | ||||
| [license-url]: https://aws.amazon.com/apache-2-0/ | ||||
							
								
								
									
										7
									
								
								submodules/s2n-quic-h3/src/lib.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								submodules/s2n-quic-h3/src/lib.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,7 @@ | |||
| // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 | ||||
| // SPDX-License-Identifier: Apache-2.0
 | ||||
| 
 | ||||
| mod s2n_quic; | ||||
| 
 | ||||
| pub use self::s2n_quic::*; | ||||
| pub use h3; | ||||
Some files were not shown because too many files have changed in this diff Show more
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jun Kurihara
				Jun Kurihara