commit
				
					
						dd0d88d7c0
					
				
			
		
					 83 changed files with 4167 additions and 3122 deletions
				
			
		|  | @ -4,3 +4,4 @@ bench/ | ||||||
| .private/ | .private/ | ||||||
| .github/ | .github/ | ||||||
| example-certs/ | example-certs/ | ||||||
|  | legacy-lib/ | ||||||
|  |  | ||||||
							
								
								
									
										35
									
								
								.github/workflows/release.yml
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										35
									
								
								.github/workflows/release.yml
									
										
									
									
										vendored
									
									
								
							|  | @ -45,34 +45,34 @@ jobs: | ||||||
|             tags-suffix: "-s2n" |             tags-suffix: "-s2n" | ||||||
| 
 | 
 | ||||||
|           - target: "gnu" |           - target: "gnu" | ||||||
|             build-feature: "-native-roots" |             build-feature: "-webpki-roots" | ||||||
|             platform: linux/amd64 |             platform: linux/amd64 | ||||||
|             tags-suffix: "-native-roots" |             tags-suffix: "-webpki-roots" | ||||||
| 
 | 
 | ||||||
|           - target: "gnu" |           - target: "gnu" | ||||||
|             build-feature: "-native-roots" |             build-feature: "-webpki-roots" | ||||||
|             platform: linux/arm64 |             platform: linux/arm64 | ||||||
|             tags-suffix: "-native-roots" |             tags-suffix: "-webpki-roots" | ||||||
| 
 | 
 | ||||||
|           - target: "musl" |           - target: "musl" | ||||||
|             build-feature: "-native-roots" |             build-feature: "-webpki-roots" | ||||||
|             platform: linux/amd64 |             platform: linux/amd64 | ||||||
|             tags-suffix: "-slim-native-roots" |             tags-suffix: "-slim-webpki-roots" | ||||||
| 
 | 
 | ||||||
|           - target: "musl" |           - target: "musl" | ||||||
|             build-feature: "-native-roots" |             build-feature: "-webpki-roots" | ||||||
|             platform: linux/arm64 |             platform: linux/arm64 | ||||||
|             tags-suffix: "-slim-native-roots" |             tags-suffix: "-slim-webpki-roots" | ||||||
| 
 | 
 | ||||||
|           - target: "gnu" |           - target: "gnu" | ||||||
|             build-feature: "-s2n-native-roots" |             build-feature: "-s2n-webpki-roots" | ||||||
|             platform: linux/amd64 |             platform: linux/amd64 | ||||||
|             tags-suffix: "-s2n-native-roots" |             tags-suffix: "-s2n-webpki-roots" | ||||||
| 
 | 
 | ||||||
|           - target: "gnu" |           - target: "gnu" | ||||||
|             build-feature: "-s2n-native-roots" |             build-feature: "-s2n-webpki-roots" | ||||||
|             platform: linux/arm64 |             platform: linux/arm64 | ||||||
|             tags-suffix: "-s2n-native-roots" |             tags-suffix: "-s2n-webpki-roots" | ||||||
| 
 | 
 | ||||||
|     steps: |     steps: | ||||||
|       - run: "echo 'The relese triggering workflows passed'" |       - run: "echo 'The relese triggering workflows passed'" | ||||||
|  | @ -81,10 +81,9 @@ jobs: | ||||||
|         id: "set-env" |         id: "set-env" | ||||||
|         run: | |         run: | | ||||||
|           if [ ${{ matrix.platform }} == 'linux/amd64' ]; then PLATFORM_MAP="x86_64"; else PLATFORM_MAP="aarch64"; fi |           if [ ${{ matrix.platform }} == 'linux/amd64' ]; then PLATFORM_MAP="x86_64"; else PLATFORM_MAP="aarch64"; fi | ||||||
|           if [ ${{ github.ref_name }} == 'develop' ]; then BUILD_NAME="-nightly"; else BUILD_NAME=""; fi |           if [ ${{ github.ref_name }} == 'main' ]; then BUILD_IMG="latest"; else BUILD_IMG="nightly"; fi | ||||||
|           if [ ${{ github.ref_name }} == 'develop' ]; then BUILD_IMG="nightly"; else BUILD_IMG="latest"; fi |  | ||||||
|           echo "build_img=${BUILD_IMG}" >> $GITHUB_OUTPUT |           echo "build_img=${BUILD_IMG}" >> $GITHUB_OUTPUT | ||||||
|           echo "target_name=rpxy${BUILD_NAME}-${PLATFORM_MAP}-unknown-linux-${{ matrix.target }}${{ matrix.build-feature }}" >> $GITHUB_OUTPUT |           echo "target_name=rpxy-${PLATFORM_MAP}-unknown-linux-${{ matrix.target }}${{ matrix.build-feature }}" >> $GITHUB_OUTPUT | ||||||
| 
 | 
 | ||||||
|       - name: "docker pull and extract binary from docker image" |       - name: "docker pull and extract binary from docker image" | ||||||
|         id: "extract-binary" |         id: "extract-binary" | ||||||
|  | @ -93,7 +92,7 @@ jobs: | ||||||
|           docker cp ${CONTAINER_ID}:/rpxy/bin/rpxy /tmp/${{ steps.set-env.outputs.target_name }} |           docker cp ${CONTAINER_ID}:/rpxy/bin/rpxy /tmp/${{ steps.set-env.outputs.target_name }} | ||||||
| 
 | 
 | ||||||
|       - name: "upload artifacts" |       - name: "upload artifacts" | ||||||
|         uses: actions/upload-artifact@v3 |         uses: actions/upload-artifact@v4 | ||||||
|         with: |         with: | ||||||
|           name: ${{ steps.set-env.outputs.target_name }} |           name: ${{ steps.set-env.outputs.target_name }} | ||||||
|           path: "/tmp/${{ steps.set-env.outputs.target_name }}" |           path: "/tmp/${{ steps.set-env.outputs.target_name }}" | ||||||
|  | @ -110,7 +109,7 @@ jobs: | ||||||
|     needs: on-success |     needs: on-success | ||||||
|     steps: |     steps: | ||||||
|       - name: check pull_request title |       - name: check pull_request title | ||||||
|         uses: kaisugi/action-regex-match@v1.0.0 |         uses: kaisugi/action-regex-match@v1.0.1 | ||||||
|         id: regex-match |         id: regex-match | ||||||
|         with: |         with: | ||||||
|           text: ${{ github.event.client_payload.pull_request.title }} |           text: ${{ github.event.client_payload.pull_request.title }} | ||||||
|  | @ -122,7 +121,7 @@ jobs: | ||||||
| 
 | 
 | ||||||
|       - name: download artifacts |       - name: download artifacts | ||||||
|         if: ${{ steps.regex-match.outputs.match != ''}} |         if: ${{ steps.regex-match.outputs.match != ''}} | ||||||
|         uses: actions/download-artifact@v3 |         uses: actions/download-artifact@v4 | ||||||
|         with: |         with: | ||||||
|           path: /tmp/rpxy |           path: /tmp/rpxy | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										52
									
								
								.github/workflows/release_docker.yml
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										52
									
								
								.github/workflows/release_docker.yml
									
										
									
									
										vendored
									
									
								
							|  | @ -2,6 +2,7 @@ name: Release - Build and publish docker, and trigger package release | ||||||
| on: | on: | ||||||
|   push: |   push: | ||||||
|     branches: |     branches: | ||||||
|  |       - "feat/*" | ||||||
|       - "develop" |       - "develop" | ||||||
|   pull_request: |   pull_request: | ||||||
|     types: [closed] |     types: [closed] | ||||||
|  | @ -44,7 +45,7 @@ jobs: | ||||||
|           - target: "s2n" |           - target: "s2n" | ||||||
|             dockerfile: ./docker/Dockerfile |             dockerfile: ./docker/Dockerfile | ||||||
|             build-args: | |             build-args: | | ||||||
|               "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache" |               "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,rustls-backend" | ||||||
|               "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" |               "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" | ||||||
|             platforms: linux/amd64,linux/arm64 |             platforms: linux/amd64,linux/arm64 | ||||||
|             tags-suffix: "-s2n" |             tags-suffix: "-s2n" | ||||||
|  | @ -53,42 +54,42 @@ jobs: | ||||||
|               jqtype/rpxy:s2n |               jqtype/rpxy:s2n | ||||||
|               ghcr.io/junkurihara/rust-rpxy:s2n |               ghcr.io/junkurihara/rust-rpxy:s2n | ||||||
| 
 | 
 | ||||||
|           - target: "native-roots" |           - target: "webpki-roots" | ||||||
|             dockerfile: ./docker/Dockerfile |             dockerfile: ./docker/Dockerfile | ||||||
|             platforms: linux/amd64,linux/arm64 |             platforms: linux/amd64,linux/arm64 | ||||||
|             build-args: | |             build-args: | | ||||||
|               "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots" |               "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,webpki-roots" | ||||||
|             tags-suffix: "-native-roots" |             tags-suffix: "-webpki-roots" | ||||||
|             # Aliases must be used only for release builds |             # Aliases must be used only for release builds | ||||||
|             aliases: | |             aliases: | | ||||||
|               jqtype/rpxy:native-roots |               jqtype/rpxy:webpki-roots | ||||||
|               ghcr.io/junkurihara/rust-rpxy:native-roots |               ghcr.io/junkurihara/rust-rpxy:webpki-roots | ||||||
| 
 | 
 | ||||||
|           - target: "slim-native-roots" |           - target: "slim-webpki-roots" | ||||||
|             dockerfile: ./docker/Dockerfile-slim |             dockerfile: ./docker/Dockerfile-slim | ||||||
|             build-args: | |             build-args: | | ||||||
|               "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,native-roots" |               "CARGO_FEATURES=--no-default-features --features=http3-quinn,cache,webpki-roots" | ||||||
|             build-contexts: | |             build-contexts: | | ||||||
|               messense/rust-musl-cross:amd64-musl=docker-image://messense/rust-musl-cross:x86_64-musl |               messense/rust-musl-cross:amd64-musl=docker-image://messense/rust-musl-cross:x86_64-musl | ||||||
|               messense/rust-musl-cross:arm64-musl=docker-image://messense/rust-musl-cross:aarch64-musl |               messense/rust-musl-cross:arm64-musl=docker-image://messense/rust-musl-cross:aarch64-musl | ||||||
|             platforms: linux/amd64,linux/arm64 |             platforms: linux/amd64,linux/arm64 | ||||||
|             tags-suffix: "-slim-native-roots" |             tags-suffix: "-slim-webpki-roots" | ||||||
|             # Aliases must be used only for release builds |             # Aliases must be used only for release builds | ||||||
|             aliases: | |             aliases: | | ||||||
|               jqtype/rpxy:slim-native-roots |               jqtype/rpxy:slim-webpki-roots | ||||||
|               ghcr.io/junkurihara/rust-rpxy:slim-native-roots |               ghcr.io/junkurihara/rust-rpxy:slim-webpki-roots | ||||||
| 
 | 
 | ||||||
|           - target: "s2n-native-roots" |           - target: "s2n-webpki-roots" | ||||||
|             dockerfile: ./docker/Dockerfile |             dockerfile: ./docker/Dockerfile | ||||||
|             build-args: | |             build-args: | | ||||||
|               "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,native-roots" |               "CARGO_FEATURES=--no-default-features --features=http3-s2n,cache,webpki-roots" | ||||||
|               "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" |               "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" | ||||||
|             platforms: linux/amd64,linux/arm64 |             platforms: linux/amd64,linux/arm64 | ||||||
|             tags-suffix: "-s2n-native-roots" |             tags-suffix: "-s2n-webpki-roots" | ||||||
|             # Aliases must be used only for release builds |             # Aliases must be used only for release builds | ||||||
|             aliases: | |             aliases: | | ||||||
|               jqtype/rpxy:s2n-native-roots |               jqtype/rpxy:s2n-webpki-roots | ||||||
|               ghcr.io/junkurihara/rust-rpxy:s2n-native-roots |               ghcr.io/junkurihara/rust-rpxy:s2n-webpki-roots | ||||||
| 
 | 
 | ||||||
|     steps: |     steps: | ||||||
|       - name: Checkout |       - name: Checkout | ||||||
|  | @ -135,6 +136,23 @@ jobs: | ||||||
|       #     platforms: linux/amd64 |       #     platforms: linux/amd64 | ||||||
|       #     labels: ${{ steps.meta.outputs.labels }} |       #     labels: ${{ steps.meta.outputs.labels }} | ||||||
| 
 | 
 | ||||||
|  |       - name: Unstable build and push from develop branch | ||||||
|  |         if: ${{ startsWith(github.ref_name, 'feat/') && (github.event_name == 'push') }} | ||||||
|  |         uses: docker/build-push-action@v5 | ||||||
|  |         with: | ||||||
|  |           context: . | ||||||
|  |           build-args: ${{ matrix.build-args }} | ||||||
|  |           push: true | ||||||
|  |           tags: | | ||||||
|  |             ${{ env.GHCR }}/${{ env.GHCR_IMAGE_NAME }}:unstable${{ matrix.tags-suffix }} | ||||||
|  |             ${{ env.DH_REGISTRY_NAME }}:unstable${{ matrix.tags-suffix }} | ||||||
|  |           build-contexts: ${{ matrix.build-contexts }} | ||||||
|  |           file: ${{ matrix.dockerfile }} | ||||||
|  |           cache-from: type=gha,scope=rpxy-unstable-${{ matrix.target }} | ||||||
|  |           cache-to: type=gha,mode=max,scope=rpxy-unstable-${{ matrix.target }} | ||||||
|  |           platforms: linux/amd64 | ||||||
|  |           labels: ${{ steps.meta.outputs.labels }} | ||||||
|  | 
 | ||||||
|       - name: Nightly build and push from develop branch |       - name: Nightly build and push from develop branch | ||||||
|         if: ${{ (github.ref_name == 'develop') && (github.event_name == 'push') }} |         if: ${{ (github.ref_name == 'develop') && (github.event_name == 'push') }} | ||||||
|         uses: docker/build-push-action@v5 |         uses: docker/build-push-action@v5 | ||||||
|  | @ -176,7 +194,7 @@ jobs: | ||||||
|     needs: build_and_push |     needs: build_and_push | ||||||
|     steps: |     steps: | ||||||
|       - name: Repository dispatch for release |       - name: Repository dispatch for release | ||||||
|         uses: peter-evans/repository-dispatch@v2 |         uses: peter-evans/repository-dispatch@v3 | ||||||
|         with: |         with: | ||||||
|           event-type: release-event |           event-type: release-event | ||||||
|           client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "pull_request": { "title": "${{ github.event.pull_request.title }}", "body": ${{ toJson(github.event.pull_request.body) }}, "number": "${{ github.event.pull_request.number }}", "head": "${{ github.event.pull_request.head.ref }}", "base": "${{ github.event.pull_request.base.ref}}"}}' |           client-payload: '{"ref": "${{ github.ref }}", "sha": "${{ github.sha }}", "pull_request": { "title": "${{ github.event.pull_request.title }}", "body": ${{ toJson(github.event.pull_request.body) }}, "number": "${{ github.event.pull_request.number }}", "head": "${{ github.event.pull_request.head.ref }}", "base": "${{ github.event.pull_request.base.ref}}"}}' | ||||||
|  |  | ||||||
							
								
								
									
										6
									
								
								.gitmodules
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										6
									
								
								.gitmodules
									
										
									
									
										vendored
									
									
								
							|  | @ -1,12 +1,6 @@ | ||||||
| [submodule "submodules/h3"] | [submodule "submodules/h3"] | ||||||
| 	path = submodules/h3 | 	path = submodules/h3 | ||||||
| 	url = git@github.com:junkurihara/h3.git | 	url = git@github.com:junkurihara/h3.git | ||||||
| [submodule "submodules/quinn"] |  | ||||||
| 	path = submodules/quinn |  | ||||||
| 	url = git@github.com:junkurihara/quinn.git |  | ||||||
| [submodule "submodules/s2n-quic"] |  | ||||||
| 	path = submodules/s2n-quic |  | ||||||
| 	url = git@github.com:junkurihara/s2n-quic.git |  | ||||||
| [submodule "submodules/rusty-http-cache-semantics"] | [submodule "submodules/rusty-http-cache-semantics"] | ||||||
| 	path = submodules/rusty-http-cache-semantics | 	path = submodules/rusty-http-cache-semantics | ||||||
| 	url = git@github.com:junkurihara/rusty-http-cache-semantics.git | 	url = git@github.com:junkurihara/rusty-http-cache-semantics.git | ||||||
|  |  | ||||||
							
								
								
									
										18
									
								
								CHANGELOG.md
									
										
									
									
									
								
							
							
						
						
									
										18
									
								
								CHANGELOG.md
									
										
									
									
									
								
							|  | @ -1,6 +1,22 @@ | ||||||
| # CHANGELOG | # CHANGELOG | ||||||
| 
 | 
 | ||||||
| ## 0.7.0  (unreleased) | ## 0.8.0 (Unreleased) | ||||||
|  | 
 | ||||||
|  | ## 0.7.0 | ||||||
|  | 
 | ||||||
|  | ### Important Changes | ||||||
|  | 
 | ||||||
|  | - Breaking: `hyper`-1.0 for both server and client modules. | ||||||
|  | - Breaking: Remove `override_host` option in upstream options. Add a reverse option, i.e., `keep_original_host`, and the similar option `set_upstream_host`. While `keep_original_host` can be explicitly specified, `rpxy` keeps the original `host` given by the incoming request by default. Then, the original `host` header is maintained or added from the value of url request line. If `host` header needs to be overridden with the upstream host name (backend uri's host name), `set_upstream_host` has to be set. If both of `set_upstream_host` and `keep_original_host` are set, `keep_original_host` is prioritized since it is explicitly specified. | ||||||
|  | - Breaking: Introduced `native-tls-backend` feature to use the native TLS engine to access backend applications. | ||||||
|  | - Breaking: Changed the policy of the default cert store from `webpki` to the system-native store. Thus we terminated the feature `native-roots` and introduced `webpki-roots` feature to use `webpki` root cert store. | ||||||
|  | 
 | ||||||
|  | ### Improvement | ||||||
|  | 
 | ||||||
|  | - Redesigned: Cache structure is totally redesigned with more memory-efficient way to read from cache file, and more secure way to strongly bind memory-objects with files with hash values. | ||||||
|  | - Redesigned: HTTP body handling flow is also redesigned with more memory-and-time efficient techniques without putting the whole objects on memory by using `futures::stream::Stream` and `futures::channel::mpsc` | ||||||
|  | - Feat: Allow to disable/enable forced-connection-timeout regardless of connection status (idle or not). [default: disabled] | ||||||
|  | - Refactor: lots of minor improvements | ||||||
| 
 | 
 | ||||||
| ## 0.6.2 | ## 0.6.2 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
							
								
								
									
										11
									
								
								Cargo.toml
									
										
									
									
									
								
							
							
						
						
									
										11
									
								
								Cargo.toml
									
										
									
									
									
								
							|  | @ -1,5 +1,14 @@ | ||||||
| [workspace] | [workspace.package] | ||||||
|  | version = "0.7.0" | ||||||
|  | authors = ["Jun Kurihara"] | ||||||
|  | homepage = "https://github.com/junkurihara/rust-rpxy" | ||||||
|  | repository = "https://github.com/junkurihara/rust-rpxy" | ||||||
|  | license = "MIT" | ||||||
|  | readme = "./README.md" | ||||||
|  | edition = "2021" | ||||||
|  | publish = false | ||||||
| 
 | 
 | ||||||
|  | [workspace] | ||||||
| members = ["rpxy-bin", "rpxy-lib"] | members = ["rpxy-bin", "rpxy-lib"] | ||||||
| exclude = ["submodules"] | exclude = ["submodules"] | ||||||
| resolver = "2" | resolver = "2" | ||||||
|  |  | ||||||
							
								
								
									
										2
									
								
								LICENSE
									
										
									
									
									
								
							
							
						
						
									
										2
									
								
								LICENSE
									
										
									
									
									
								
							|  | @ -1,6 +1,6 @@ | ||||||
| MIT License | MIT License | ||||||
| 
 | 
 | ||||||
| Copyright (c) 2023 Jun Kurihara | Copyright (c) 2024 Jun Kurihara | ||||||
| 
 | 
 | ||||||
| Permission is hereby granted, free of charge, to any person obtaining a copy | Permission is hereby granted, free of charge, to any person obtaining a copy | ||||||
| of this software and associated documentation files (the "Software"), to deal | of this software and associated documentation files (the "Software"), to deal | ||||||
|  |  | ||||||
|  | @ -2,7 +2,7 @@ | ||||||
| 
 | 
 | ||||||
| [](LICENSE) | [](LICENSE) | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
| [](https://hub.docker.com/r/jqtype/rpxy) | [](https://hub.docker.com/r/jqtype/rpxy) | ||||||
| 
 | 
 | ||||||
|  | @ -104,11 +104,11 @@ If you want to host multiple and distinct domain names in a single IP address/po | ||||||
| ```toml | ```toml | ||||||
| default_application = "app1" | default_application = "app1" | ||||||
| 
 | 
 | ||||||
| [app.app1] | [apps.app1] | ||||||
| server_name = "app1.example.com" | server_name = "app1.example.com" | ||||||
| #... | #... | ||||||
| 
 | 
 | ||||||
| [app.app2] | [apps.app2] | ||||||
| server_name = "app2.example.org" | server_name = "app2.example.org" | ||||||
| #... | #... | ||||||
| ``` | ``` | ||||||
|  |  | ||||||
							
								
								
									
										19
									
								
								TODO.md
									
										
									
									
									
								
							
							
						
						
									
										19
									
								
								TODO.md
									
										
									
									
									
								
							|  | @ -1,9 +1,11 @@ | ||||||
| # TODO List | # TODO List | ||||||
| 
 | 
 | ||||||
| - [Done in 0.6.0] But we need more sophistication on `Forwarder` struct. ~~Fix strategy for `h2c` requests on forwarded requests upstream. This needs to update forwarder definition. Also, maybe forwarder would have a cache corresponding to the following task.~~ | - Support of `rustls-0.22`. | ||||||
| - [Initial implementation in v0.6.0] ~~**Cache option for the response with `Cache-Control: public` header directive ([#55](https://github.com/junkurihara/rust-rpxy/issues/55))**~~ Using `lru` crate might be inefficient in terms of the speed. | - We need more sophistication on `Forwarder` struct to handle `h2c`. | ||||||
|  | - Cache using `lru` crate might be inefficient in terms of the speed. | ||||||
|   - Consider more sophisticated architecture for cache |   - Consider more sophisticated architecture for cache | ||||||
|   - Persistent cache (if possible). |   - Persistent cache (if possible). | ||||||
|  |   - More secure cache file object naming | ||||||
|   - etc etc |   - etc etc | ||||||
| - Improvement of path matcher | - Improvement of path matcher | ||||||
| - More flexible option for rewriting path | - More flexible option for rewriting path | ||||||
|  | @ -17,7 +19,7 @@ | ||||||
| 
 | 
 | ||||||
| - Unit tests | - Unit tests | ||||||
| - Options to serve custom http_error page. | - Options to serve custom http_error page. | ||||||
| - Prometheus metrics | - Traces and metrics using opentelemetry (`tracing-opentelemetry` crate) | ||||||
| - Documentation | - Documentation | ||||||
| - Client certificate | - Client certificate | ||||||
|   - support intermediate certificate. Currently, only supports client certificates directly signed by root CA. |   - support intermediate certificate. Currently, only supports client certificates directly signed by root CA. | ||||||
|  | @ -27,15 +29,4 @@ | ||||||
| - Make the session-persistance option for load-balancing sophisticated. (mostly done in v0.3.0) | - Make the session-persistance option for load-balancing sophisticated. (mostly done in v0.3.0) | ||||||
|   - add option for sticky cookie name |   - add option for sticky cookie name | ||||||
|   - add option for sticky cookie duration |   - add option for sticky cookie duration | ||||||
| 
 |  | ||||||
| - Done in v0.5.0 ~~Use `gchr.io`~~ |  | ||||||
| - Done in v0.5.0: |  | ||||||
|   ~~Consideration on migrating from `quinn` and `h3-quinn` to other QUIC implementations ([#57](https://github.com/junkurihara/rust-rpxy/issues/57))~~ |  | ||||||
| - Done in v0.4.0: |  | ||||||
|   ~~Benchmark with other reverse proxy implementations like Sozu ([#58](https://github.com/junkurihara/rust-rpxy/issues/58)) Currently, Sozu can work only on `amd64` format due to its HTTP message parser limitation... Since the main developer have only `arm64` (Apple M1) laptops, so we should do that on VPS?~~ |  | ||||||
| - Done in v0.4.0: |  | ||||||
|   ~~Split `rpxy` source codes into `rpxy-lib` and `rpxy-bin` to make the core part (reverse proxy) isolated from the misc part like toml file loader. This is in order to make the configuration-related part more flexible (related to [#33](https://github.com/junkurihara/rust-rpxy/issues/33))~~ |  | ||||||
| - Done in 0.6.0: |  | ||||||
|   ~~Fix dynamic reloading of configuration file~~ |  | ||||||
| 
 |  | ||||||
| - etc. | - etc. | ||||||
|  |  | ||||||
|  | @ -57,7 +57,7 @@ upstream = [ | ||||||
| ] | ] | ||||||
| load_balance = "round_robin" # or "random" or "sticky" (sticky session) or "none" (fix to the first one, default) | load_balance = "round_robin" # or "random" or "sticky" (sticky session) or "none" (fix to the first one, default) | ||||||
| upstream_options = [ | upstream_options = [ | ||||||
|   "override_host", |   "keep_original_host",   # [default] do not overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy), which is prior to "set_upstream_host" if both are specified. | ||||||
|   "force_http2_upstream", # mutually exclusive with "force_http11_upstream" |   "force_http2_upstream", # mutually exclusive with "force_http11_upstream" | ||||||
| ] | ] | ||||||
| 
 | 
 | ||||||
|  | @ -76,9 +76,9 @@ upstream = [ | ||||||
| ] | ] | ||||||
| load_balance = "random" # or "round_robin" or "sticky" (sticky session) or "none" (fix to the first one, default) | load_balance = "random" # or "round_robin" or "sticky" (sticky session) or "none" (fix to the first one, default) | ||||||
| upstream_options = [ | upstream_options = [ | ||||||
|   "override_host", |  | ||||||
|   "upgrade_insecure_requests", |   "upgrade_insecure_requests", | ||||||
|   "force_http11_upstream", |   "force_http11_upstream", | ||||||
|  |   "set_upstream_host",         # overwrite HOST value with upstream hostname (like www.yahoo.com) | ||||||
| ] | ] | ||||||
| ###################################################################### | ###################################################################### | ||||||
| 
 | 
 | ||||||
|  | @ -98,6 +98,11 @@ reverse_proxy = [{ upstream = [{ location = 'www.google.com', tls = true }] }] | ||||||
| # We should note that this strongly depends on the client implementation. | # We should note that this strongly depends on the client implementation. | ||||||
| ignore_sni_consistency = false | ignore_sni_consistency = false | ||||||
| 
 | 
 | ||||||
|  | # Force connection handling timeout regardless of the connection status, i.e., idle or not. | ||||||
|  | # 0 represents an infinite timeout. [default: 0] | ||||||
|  | # Note that idel and header read timeouts are always specified independently of this. | ||||||
|  | connection_handling_timeout = 0 # sec | ||||||
|  | 
 | ||||||
| # If this specified, h3 is enabled | # If this specified, h3 is enabled | ||||||
| [experimental.h3] | [experimental.h3] | ||||||
| alt_svc_max_age = 3600             # sec | alt_svc_max_age = 3600             # sec | ||||||
|  |  | ||||||
|  | @ -14,8 +14,8 @@ services: | ||||||
|       additional_contexts: |       additional_contexts: | ||||||
|         - messense/rust-musl-cross:amd64-musl=docker-image://messense/rust-musl-cross:x86_64-musl |         - messense/rust-musl-cross:amd64-musl=docker-image://messense/rust-musl-cross:x86_64-musl | ||||||
|         - messense/rust-musl-cross:arm64-musl=docker-image://messense/rust-musl-cross:aarch64-musl |         - messense/rust-musl-cross:arm64-musl=docker-image://messense/rust-musl-cross:aarch64-musl | ||||||
|       # args: # Uncomment when build with native cert store |       # args: # Uncomment when build with webpki cert store | ||||||
|       #   - "CARGO_FEATURES=--no-default-features --features=http3-quinn,native-roots" |       #   - "CARGO_FEATURES=--no-default-features --features=http3-quinn,webpki-roots" | ||||||
|       dockerfile: ./docker/Dockerfile-slim # based on alpine and build x86_64-unknown-linux-musl |       dockerfile: ./docker/Dockerfile-slim # based on alpine and build x86_64-unknown-linux-musl | ||||||
|       platforms: # Choose your platforms |       platforms: # Choose your platforms | ||||||
|         # - "linux/amd64" |         # - "linux/amd64" | ||||||
|  |  | ||||||
|  | @ -14,8 +14,8 @@ services: | ||||||
|       # args: # Uncomment when build quic-s2n version |       # args: # Uncomment when build quic-s2n version | ||||||
|       #   - "CARGO_FEATURES=--no-default-features --features=http3-s2n" |       #   - "CARGO_FEATURES=--no-default-features --features=http3-s2n" | ||||||
|       #   - "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" |       #   - "ADDITIONAL_DEPS=pkg-config libssl-dev cmake libclang1 gcc g++" | ||||||
|       # args: # Uncomment when build with native cert store |       # args: # Uncomment when build with webpki root store | ||||||
|       #   - "CARGO_FEATURES=--no-default-features --features=http3-quinn,native-roots" |       #   - "CARGO_FEATURES=--no-default-features --features=http3-quinn,webpki-roots" | ||||||
|       dockerfile: ./docker/Dockerfile # based on ubuntu 22.04 and build x86_64-unknown-linux-gnu |       dockerfile: ./docker/Dockerfile # based on ubuntu 22.04 and build x86_64-unknown-linux-gnu | ||||||
|       platforms: # Choose your platforms |       platforms: # Choose your platforms | ||||||
|         # - "linux/amd64" |         # - "linux/amd64" | ||||||
|  |  | ||||||
|  | @ -1,51 +1,54 @@ | ||||||
| [package] | [package] | ||||||
| name = "rpxy" | name = "rpxy" | ||||||
| version = "0.6.2" | description = "`rpxy`: a simple and ultrafast http reverse proxy" | ||||||
| authors = ["Jun Kurihara"] | version.workspace = true | ||||||
| homepage = "https://github.com/junkurihara/rust-rpxy" | authors.workspace = true | ||||||
| repository = "https://github.com/junkurihara/rust-rpxy" | homepage.workspace = true | ||||||
| license = "MIT" | repository.workspace = true | ||||||
| readme = "../README.md" | license.workspace = true | ||||||
| edition = "2021" | readme.workspace = true | ||||||
| publish = false | edition.workspace = true | ||||||
|  | publish.workspace = true | ||||||
| 
 | 
 | ||||||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||||||
| 
 | 
 | ||||||
| [features] | [features] | ||||||
| default = ["http3-quinn", "cache"] | default = ["http3-quinn", "cache", "rustls-backend"] | ||||||
| http3-quinn = ["rpxy-lib/http3-quinn"] | http3-quinn = ["rpxy-lib/http3-quinn"] | ||||||
| http3-s2n = ["rpxy-lib/http3-s2n"] | http3-s2n = ["rpxy-lib/http3-s2n"] | ||||||
|  | native-tls-backend = ["rpxy-lib/native-tls-backend"] | ||||||
|  | rustls-backend = ["rpxy-lib/rustls-backend"] | ||||||
|  | webpki-roots = ["rpxy-lib/webpki-roots"] | ||||||
| cache = ["rpxy-lib/cache"] | cache = ["rpxy-lib/cache"] | ||||||
| native-roots = ["rpxy-lib/native-roots"] |  | ||||||
| 
 | 
 | ||||||
| [dependencies] | [dependencies] | ||||||
| rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [ | rpxy-lib = { path = "../rpxy-lib/", default-features = false, features = [ | ||||||
|   "sticky-cookie", |   "sticky-cookie", | ||||||
| ] } | ] } | ||||||
| 
 | 
 | ||||||
| anyhow = "1.0.75" | anyhow = "1.0.79" | ||||||
| rustc-hash = "1.1.0" | rustc-hash = "1.1.0" | ||||||
| serde = { version = "1.0.188", default-features = false, features = ["derive"] } | serde = { version = "1.0.196", default-features = false, features = ["derive"] } | ||||||
| derive_builder = "0.12.0" | derive_builder = "0.20.0" | ||||||
| tokio = { version = "1.33.0", default-features = false, features = [ | tokio = { version = "1.36.0", default-features = false, features = [ | ||||||
|   "net", |   "net", | ||||||
|   "rt-multi-thread", |   "rt-multi-thread", | ||||||
|   "time", |   "time", | ||||||
|   "sync", |   "sync", | ||||||
|   "macros", |   "macros", | ||||||
| ] } | ] } | ||||||
| async-trait = "0.1.73" | async-trait = "0.1.77" | ||||||
| rustls-pemfile = "1.0.3" | rustls-pemfile = "1.0.4" | ||||||
| mimalloc = { version = "*", default-features = false } | mimalloc = { version = "*", default-features = false } | ||||||
| 
 | 
 | ||||||
| # config | # config | ||||||
| clap = { version = "4.4.6", features = ["std", "cargo", "wrap_help"] } | clap = { version = "4.5.0", features = ["std", "cargo", "wrap_help"] } | ||||||
| toml = { version = "0.8", default-features = false, features = ["parse"] } | toml = { version = "0.8.10", default-features = false, features = ["parse"] } | ||||||
| hot_reload = "0.1.4" | hot_reload = "0.1.5" | ||||||
| 
 | 
 | ||||||
| # logging | # logging | ||||||
| tracing = { version = "0.1.37" } | tracing = { version = "0.1.40" } | ||||||
| tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } | tracing-subscriber = { version = "0.3.18", features = ["env-filter"] } | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| [dev-dependencies] | [dev-dependencies] | ||||||
|  |  | ||||||
|  | @ -8,7 +8,7 @@ use rpxy_lib::{ | ||||||
| use std::{ | use std::{ | ||||||
|   fs::File, |   fs::File, | ||||||
|   io::{self, BufReader, Cursor, Read}, |   io::{self, BufReader, Cursor, Read}, | ||||||
|   path::PathBuf, |   path::{Path, PathBuf}, | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| #[derive(Builder, Debug, Clone)] | #[derive(Builder, Debug, Clone)] | ||||||
|  | @ -28,16 +28,16 @@ pub struct CryptoFileSource { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl CryptoFileSourceBuilder { | impl CryptoFileSourceBuilder { | ||||||
|   pub fn tls_cert_path(&mut self, v: &str) -> &mut Self { |   pub fn tls_cert_path<T: AsRef<Path>>(&mut self, v: T) -> &mut Self { | ||||||
|     self.tls_cert_path = Some(PathBuf::from(v)); |     self.tls_cert_path = Some(v.as_ref().to_path_buf()); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
|   pub fn tls_cert_key_path(&mut self, v: &str) -> &mut Self { |   pub fn tls_cert_key_path<T: AsRef<Path>>(&mut self, v: T) -> &mut Self { | ||||||
|     self.tls_cert_key_path = Some(PathBuf::from(v)); |     self.tls_cert_key_path = Some(v.as_ref().to_path_buf()); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
|   pub fn client_ca_cert_path(&mut self, v: &Option<String>) -> &mut Self { |   pub fn client_ca_cert_path<T: AsRef<Path>>(&mut self, v: Option<T>) -> &mut Self { | ||||||
|     self.client_ca_cert_path = Some(v.to_owned().as_ref().map(PathBuf::from)); |     self.client_ca_cert_path = Some(v.map(|p| p.as_ref().to_path_buf())); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  | @ -167,11 +167,11 @@ mod tests { | ||||||
|   async fn read_server_crt_key_files_with_client_ca_crt() { |   async fn read_server_crt_key_files_with_client_ca_crt() { | ||||||
|     let tls_cert_path = "../example-certs/server.crt"; |     let tls_cert_path = "../example-certs/server.crt"; | ||||||
|     let tls_cert_key_path = "../example-certs/server.key"; |     let tls_cert_key_path = "../example-certs/server.key"; | ||||||
|     let client_ca_cert_path = Some("../example-certs/client.ca.crt".to_string()); |     let client_ca_cert_path = Some("../example-certs/client.ca.crt"); | ||||||
|     let crypto_file_source = CryptoFileSourceBuilder::default() |     let crypto_file_source = CryptoFileSourceBuilder::default() | ||||||
|       .tls_cert_key_path(tls_cert_key_path) |       .tls_cert_key_path(tls_cert_key_path) | ||||||
|       .tls_cert_path(tls_cert_path) |       .tls_cert_path(tls_cert_path) | ||||||
|       .client_ca_cert_path(&client_ca_cert_path) |       .client_ca_cert_path(client_ca_cert_path) | ||||||
|       .build(); |       .build(); | ||||||
|     assert!(crypto_file_source.is_ok()); |     assert!(crypto_file_source.is_ok()); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -7,6 +7,7 @@ use rpxy_lib::{reexports::Uri, AppConfig, ProxyConfig, ReverseProxyConfig, TlsCo | ||||||
| use rustc_hash::FxHashMap as HashMap; | use rustc_hash::FxHashMap as HashMap; | ||||||
| use serde::Deserialize; | use serde::Deserialize; | ||||||
| use std::{fs, net::SocketAddr}; | use std::{fs, net::SocketAddr}; | ||||||
|  | use tokio::time::Duration; | ||||||
| 
 | 
 | ||||||
| #[derive(Deserialize, Debug, Default, PartialEq, Eq, Clone)] | #[derive(Deserialize, Debug, Default, PartialEq, Eq, Clone)] | ||||||
| pub struct ConfigToml { | pub struct ConfigToml { | ||||||
|  | @ -48,6 +49,7 @@ pub struct Experimental { | ||||||
|   #[cfg(feature = "cache")] |   #[cfg(feature = "cache")] | ||||||
|   pub cache: Option<CacheOption>, |   pub cache: Option<CacheOption>, | ||||||
|   pub ignore_sni_consistency: Option<bool>, |   pub ignore_sni_consistency: Option<bool>, | ||||||
|  |   pub connection_handling_timeout: Option<u64>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[derive(Deserialize, Debug, Default, PartialEq, Eq, Clone)] | #[derive(Deserialize, Debug, Default, PartialEq, Eq, Clone)] | ||||||
|  | @ -162,7 +164,7 @@ impl TryInto<ProxyConfig> for &ConfigToml { | ||||||
|             if x == 0u64 { |             if x == 0u64 { | ||||||
|               proxy_config.h3_max_idle_timeout = None; |               proxy_config.h3_max_idle_timeout = None; | ||||||
|             } else { |             } else { | ||||||
|               proxy_config.h3_max_idle_timeout = Some(tokio::time::Duration::from_secs(x)) |               proxy_config.h3_max_idle_timeout = Some(Duration::from_secs(x)) | ||||||
|             } |             } | ||||||
|           } |           } | ||||||
|         } |         } | ||||||
|  | @ -172,6 +174,14 @@ impl TryInto<ProxyConfig> for &ConfigToml { | ||||||
|         proxy_config.sni_consistency = !ignore; |         proxy_config.sni_consistency = !ignore; | ||||||
|       } |       } | ||||||
| 
 | 
 | ||||||
|  |       if let Some(timeout) = exp.connection_handling_timeout { | ||||||
|  |         if timeout == 0u64 { | ||||||
|  |           proxy_config.connection_handling_timeout = None; | ||||||
|  |         } else { | ||||||
|  |           proxy_config.connection_handling_timeout = Some(Duration::from_secs(timeout)); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|       #[cfg(feature = "cache")] |       #[cfg(feature = "cache")] | ||||||
|       if let Some(cache_option) = &exp.cache { |       if let Some(cache_option) = &exp.cache { | ||||||
|         proxy_config.cache_enabled = true; |         proxy_config.cache_enabled = true; | ||||||
|  | @ -217,7 +227,7 @@ impl Application { | ||||||
|       let inner = CryptoFileSourceBuilder::default() |       let inner = CryptoFileSourceBuilder::default() | ||||||
|         .tls_cert_path(tls.tls_cert_path.as_ref().unwrap()) |         .tls_cert_path(tls.tls_cert_path.as_ref().unwrap()) | ||||||
|         .tls_cert_key_path(tls.tls_cert_key_path.as_ref().unwrap()) |         .tls_cert_key_path(tls.tls_cert_key_path.as_ref().unwrap()) | ||||||
|         .client_ca_cert_path(&tls.client_ca_cert_path) |         .client_ca_cert_path(tls.client_ca_cert_path.as_deref()) | ||||||
|         .build()?; |         .build()?; | ||||||
| 
 | 
 | ||||||
|       let https_redirection = if tls.https_redirection.is_none() { |       let https_redirection = if tls.https_redirection.is_none() { | ||||||
|  |  | ||||||
|  | @ -1 +1,2 @@ | ||||||
|  | #[allow(unused)] | ||||||
| pub use anyhow::{anyhow, bail, ensure, Context}; | pub use anyhow::{anyhow, bail, ensure, Context}; | ||||||
|  |  | ||||||
|  | @ -1,3 +1,4 @@ | ||||||
|  | #[allow(unused)] | ||||||
| pub use tracing::{debug, error, info, warn}; | pub use tracing::{debug, error, info, warn}; | ||||||
| 
 | 
 | ||||||
| pub fn init_logger() { | pub fn init_logger() { | ||||||
|  | @ -12,10 +13,13 @@ pub fn init_logger() { | ||||||
|     .with_level(true) |     .with_level(true) | ||||||
|     .compact(); |     .compact(); | ||||||
| 
 | 
 | ||||||
|   // This limits the logger to emits only rpxy crate
 |   // This limits the logger to emits only proxy crate
 | ||||||
|  |   let pkg_name = env!("CARGO_PKG_NAME").replace('-', "_"); | ||||||
|   let level_string = std::env::var(EnvFilter::DEFAULT_ENV).unwrap_or_else(|_| "info".to_string()); |   let level_string = std::env::var(EnvFilter::DEFAULT_ENV).unwrap_or_else(|_| "info".to_string()); | ||||||
|   let filter_layer = EnvFilter::new(format!("{}={}", env!("CARGO_PKG_NAME"), level_string)); |   let filter_layer = EnvFilter::new(format!("{}={}", pkg_name, level_string)); | ||||||
|   // let filter_layer = EnvFilter::from_default_env();
 |   // let filter_layer = EnvFilter::try_from_default_env()
 | ||||||
|  |   //   .unwrap_or_else(|_| EnvFilter::new("info"))
 | ||||||
|  |   //   .add_directive(format!("{}=trace", pkg_name).parse().unwrap());
 | ||||||
| 
 | 
 | ||||||
|   tracing_subscriber::registry() |   tracing_subscriber::registry() | ||||||
|     .with(format_layer) |     .with(format_layer) | ||||||
|  |  | ||||||
|  | @ -15,9 +15,6 @@ use crate::{ | ||||||
| use hot_reload::{ReloaderReceiver, ReloaderService}; | use hot_reload::{ReloaderReceiver, ReloaderService}; | ||||||
| use rpxy_lib::entrypoint; | use rpxy_lib::entrypoint; | ||||||
| 
 | 
 | ||||||
| #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] |  | ||||||
| compile_error!("feature \"http3-quinn\" and feature \"http3-s2n\" cannot be enabled at the same time"); |  | ||||||
| 
 |  | ||||||
| fn main() { | fn main() { | ||||||
|   init_logger(); |   init_logger(); | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,31 +1,40 @@ | ||||||
| [package] | [package] | ||||||
| name = "rpxy-lib" | name = "rpxy-lib" | ||||||
| version = "0.6.2" | description = "Library of `rpxy`: a simple and ultrafast http reverse proxy" | ||||||
| authors = ["Jun Kurihara"] | version.workspace = true | ||||||
| homepage = "https://github.com/junkurihara/rust-rpxy" | authors.workspace = true | ||||||
| repository = "https://github.com/junkurihara/rust-rpxy" | homepage.workspace = true | ||||||
| license = "MIT" | repository.workspace = true | ||||||
| readme = "../README.md" | license.workspace = true | ||||||
| edition = "2021" | readme.workspace = true | ||||||
| publish = false | edition.workspace = true | ||||||
|  | publish.workspace = true | ||||||
| 
 | 
 | ||||||
| # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||||||
| 
 | 
 | ||||||
| [features] | [features] | ||||||
| default = ["http3-quinn", "sticky-cookie", "cache"] | default = ["http3-quinn", "sticky-cookie", "cache", "rustls-backend"] | ||||||
| http3-quinn = ["quinn", "h3", "h3-quinn", "socket2"] | http3-quinn = ["socket2", "quinn", "h3", "h3-quinn"] | ||||||
| http3-s2n = ["h3", "s2n-quic", "s2n-quic-rustls", "s2n-quic-h3"] | http3-s2n = [ | ||||||
|  |   "h3", | ||||||
|  |   "s2n-quic", | ||||||
|  |   "s2n-quic-core", | ||||||
|  |   "s2n-quic-rustls", | ||||||
|  |   "s2n-quic-h3", | ||||||
|  | ] | ||||||
|  | cache = ["http-cache-semantics", "lru", "sha2", "base64"] | ||||||
| sticky-cookie = ["base64", "sha2", "chrono"] | sticky-cookie = ["base64", "sha2", "chrono"] | ||||||
| cache = ["http-cache-semantics", "lru"] | native-tls-backend = ["hyper-tls"] | ||||||
| native-roots = ["hyper-rustls/native-tokio"] | rustls-backend = ["hyper-rustls"] | ||||||
|  | webpki-roots = ["rustls-backend", "hyper-rustls/webpki-tokio"] | ||||||
| 
 | 
 | ||||||
| [dependencies] | [dependencies] | ||||||
| rand = "0.8.5" | rand = "0.8.5" | ||||||
| rustc-hash = "1.1.0" | rustc-hash = "1.1.0" | ||||||
| bytes = "1.5.0" | bytes = "1.5.0" | ||||||
| derive_builder = "0.12.0" | derive_builder = "0.20.0" | ||||||
| futures = { version = "0.3.28", features = ["alloc", "async-await"] } | futures = { version = "0.3.30", features = ["alloc", "async-await"] } | ||||||
| tokio = { version = "1.33.0", default-features = false, features = [ | tokio = { version = "1.36.0", default-features = false, features = [ | ||||||
|   "net", |   "net", | ||||||
|   "rt-multi-thread", |   "rt-multi-thread", | ||||||
|   "time", |   "time", | ||||||
|  | @ -33,60 +42,69 @@ tokio = { version = "1.33.0", default-features = false, features = [ | ||||||
|   "macros", |   "macros", | ||||||
|   "fs", |   "fs", | ||||||
| ] } | ] } | ||||||
| async-trait = "0.1.73" | pin-project-lite = "0.2.13" | ||||||
| hot_reload = "0.1.4" # reloading certs | async-trait = "0.1.77" | ||||||
| 
 | 
 | ||||||
| # Error handling | # Error handling | ||||||
| anyhow = "1.0.75" | anyhow = "1.0.79" | ||||||
| thiserror = "1.0.49" | thiserror = "1.0.57" | ||||||
| 
 | 
 | ||||||
| # http and tls | # http for both server and client | ||||||
| hyper = { version = "0.14.27", default-features = false, features = [ | http = "1.0.0" | ||||||
|   "server", | http-body-util = "0.1.0" | ||||||
|  | hyper = { version = "1.1.0", default-features = false } | ||||||
|  | hyper-util = { version = "0.1.3", features = ["full"] } | ||||||
|  | futures-util = { version = "0.3.30", default-features = false } | ||||||
|  | futures-channel = { version = "0.3.30", default-features = false } | ||||||
|  | 
 | ||||||
|  | # http client for upstream | ||||||
|  | hyper-tls = { version = "0.6.0", features = [ | ||||||
|  |   "alpn", | ||||||
|  |   "vendored", | ||||||
|  | ], optional = true } | ||||||
|  | hyper-rustls = { version = "0.26.0", default-features = false, features = [ | ||||||
|  |   "ring", | ||||||
|  |   "native-tokio", | ||||||
|   "http1", |   "http1", | ||||||
|   "http2", |   "http2", | ||||||
|   "stream", | ], optional = true } | ||||||
| ] } | 
 | ||||||
| hyper-rustls = { version = "0.24.1", default-features = false, features = [ | # tls and cert management for server | ||||||
|   "tokio-runtime", | hot_reload = "0.1.5" | ||||||
|   "webpki-tokio", | rustls = { version = "0.21.10", default-features = false } | ||||||
|   "http1", |  | ||||||
|   "http2", |  | ||||||
| ] } |  | ||||||
| tokio-rustls = { version = "0.24.1", features = ["early-data"] } | tokio-rustls = { version = "0.24.1", features = ["early-data"] } | ||||||
| rustls = { version = "0.21.7", default-features = false } |  | ||||||
| webpki = "0.22.4" | webpki = "0.22.4" | ||||||
| x509-parser = "0.15.1" | x509-parser = "0.15.1" | ||||||
| 
 | 
 | ||||||
| # logging | # logging | ||||||
| tracing = { version = "0.1.37" } | tracing = { version = "0.1.40" } | ||||||
| 
 | 
 | ||||||
| # http/3 | # http/3 | ||||||
| # quinn = { version = "0.9.3", optional = true } | quinn = { version = "0.10.2", optional = true } | ||||||
| quinn = { path = "../submodules/quinn/quinn", optional = true } # Tentative to support rustls-0.21 |  | ||||||
| h3 = { path = "../submodules/h3/h3/", optional = true } | h3 = { path = "../submodules/h3/h3/", optional = true } | ||||||
| # h3-quinn = { path = "./h3/h3-quinn/", optional = true } | h3-quinn = { path = "../submodules/h3/h3-quinn/", optional = true } | ||||||
| h3-quinn = { path = "../submodules/h3-quinn/", optional = true } # Tentative to support rustls-0.21 | s2n-quic = { version = "1.33.0", default-features = false, features = [ | ||||||
| # for UDP socket wit SO_REUSEADDR when h3 with quinn |  | ||||||
| socket2 = { version = "0.5.4", features = ["all"], optional = true } |  | ||||||
| s2n-quic = { path = "../submodules/s2n-quic/quic/s2n-quic/", default-features = false, features = [ |  | ||||||
|   "provider-tls-rustls", |   "provider-tls-rustls", | ||||||
| ], optional = true } | ], optional = true } | ||||||
| s2n-quic-h3 = { path = "../submodules/s2n-quic/quic/s2n-quic-h3/", optional = true } | s2n-quic-core = { version = "0.33.0", default-features = false, optional = true } | ||||||
| s2n-quic-rustls = { path = "../submodules/s2n-quic/quic/s2n-quic-rustls/", optional = true } | s2n-quic-h3 = { path = "../submodules/s2n-quic-h3/", optional = true } | ||||||
|  | s2n-quic-rustls = { version = "0.33.0", optional = true } | ||||||
|  | # for UDP socket wit SO_REUSEADDR when h3 with quinn | ||||||
|  | socket2 = { version = "0.5.5", features = ["all"], optional = true } | ||||||
| 
 | 
 | ||||||
| # cache | # cache | ||||||
| http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } | http-cache-semantics = { path = "../submodules/rusty-http-cache-semantics/", optional = true } | ||||||
| lru = { version = "0.12.0", optional = true } | lru = { version = "0.12.2", optional = true } | ||||||
|  | sha2 = { version = "0.10.8", default-features = false, optional = true } | ||||||
| 
 | 
 | ||||||
| # cookie handling for sticky cookie | # cookie handling for sticky cookie | ||||||
| chrono = { version = "0.4.31", default-features = false, features = [ | chrono = { version = "0.4.34", default-features = false, features = [ | ||||||
|   "unstable-locales", |   "unstable-locales", | ||||||
|   "alloc", |   "alloc", | ||||||
|   "clock", |   "clock", | ||||||
| ], optional = true } | ], optional = true } | ||||||
| base64 = { version = "0.21.4", optional = true } | base64 = { version = "0.21.7", optional = true } | ||||||
| sha2 = { version = "0.10.8", default-features = false, optional = true } |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| [dev-dependencies] | [dev-dependencies] | ||||||
|  | tokio-test = "0.4.3" | ||||||
|  |  | ||||||
							
								
								
									
										136
									
								
								rpxy-lib/src/backend/backend_main.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										136
									
								
								rpxy-lib/src/backend/backend_main.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,136 @@ | ||||||
|  | use crate::{ | ||||||
|  |   crypto::CryptoSource, | ||||||
|  |   error::*, | ||||||
|  |   log::*, | ||||||
|  |   name_exp::{ByteName, ServerName}, | ||||||
|  |   AppConfig, AppConfigList, | ||||||
|  | }; | ||||||
|  | use derive_builder::Builder; | ||||||
|  | use rustc_hash::FxHashMap as HashMap; | ||||||
|  | use std::borrow::Cow; | ||||||
|  | 
 | ||||||
|  | use super::upstream::PathManager; | ||||||
|  | 
 | ||||||
|  | /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
 | ||||||
|  | #[derive(Builder)] | ||||||
|  | pub struct BackendApp<T> | ||||||
|  | where | ||||||
|  |   T: CryptoSource, | ||||||
|  | { | ||||||
|  |   #[builder(setter(into))] | ||||||
|  |   /// backend application name, e.g., app1
 | ||||||
|  |   pub app_name: String, | ||||||
|  |   #[builder(setter(custom))] | ||||||
|  |   /// server name, e.g., example.com, in [[ServerName]] object
 | ||||||
|  |   pub server_name: ServerName, | ||||||
|  |   /// struct of reverse proxy serving incoming request
 | ||||||
|  |   pub path_manager: PathManager, | ||||||
|  |   /// tls settings: https redirection with 30x
 | ||||||
|  |   #[builder(default)] | ||||||
|  |   pub https_redirection: Option<bool>, | ||||||
|  |   /// TLS settings: source meta for server cert, key, client ca cert
 | ||||||
|  |   #[builder(default)] | ||||||
|  |   pub crypto_source: Option<T>, | ||||||
|  | } | ||||||
|  | impl<'a, T> BackendAppBuilder<T> | ||||||
|  | where | ||||||
|  |   T: CryptoSource, | ||||||
|  | { | ||||||
|  |   pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self { | ||||||
|  |     self.server_name = Some(server_name.to_server_name()); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// HashMap and some meta information for multiple Backend structs.
 | ||||||
|  | pub struct BackendAppManager<T> | ||||||
|  | where | ||||||
|  |   T: CryptoSource, | ||||||
|  | { | ||||||
|  |   /// HashMap of Backend structs, key is server name
 | ||||||
|  |   pub apps: HashMap<ServerName, BackendApp<T>>, | ||||||
|  |   /// for plaintext http
 | ||||||
|  |   pub default_server_name: Option<ServerName>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<T> Default for BackendAppManager<T> | ||||||
|  | where | ||||||
|  |   T: CryptoSource, | ||||||
|  | { | ||||||
|  |   fn default() -> Self { | ||||||
|  |     Self { | ||||||
|  |       apps: HashMap::<ServerName, BackendApp<T>>::default(), | ||||||
|  |       default_server_name: None, | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<T> TryFrom<&AppConfig<T>> for BackendApp<T> | ||||||
|  | where | ||||||
|  |   T: CryptoSource + Clone, | ||||||
|  | { | ||||||
|  |   type Error = RpxyError; | ||||||
|  | 
 | ||||||
|  |   fn try_from(app_config: &AppConfig<T>) -> Result<Self, Self::Error> { | ||||||
|  |     let mut backend_builder = BackendAppBuilder::default(); | ||||||
|  |     let path_manager = PathManager::try_from(app_config)?; | ||||||
|  |     backend_builder | ||||||
|  |       .app_name(app_config.app_name.clone()) | ||||||
|  |       .server_name(app_config.server_name.clone()) | ||||||
|  |       .path_manager(path_manager); | ||||||
|  |     // TLS settings and build backend instance
 | ||||||
|  |     let backend = if app_config.tls.is_none() { | ||||||
|  |       backend_builder.build()? | ||||||
|  |     } else { | ||||||
|  |       let tls = app_config.tls.as_ref().unwrap(); | ||||||
|  |       backend_builder | ||||||
|  |         .https_redirection(Some(tls.https_redirection)) | ||||||
|  |         .crypto_source(Some(tls.inner.clone())) | ||||||
|  |         .build()? | ||||||
|  |     }; | ||||||
|  |     Ok(backend) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<T> TryFrom<&AppConfigList<T>> for BackendAppManager<T> | ||||||
|  | where | ||||||
|  |   T: CryptoSource + Clone, | ||||||
|  | { | ||||||
|  |   type Error = RpxyError; | ||||||
|  | 
 | ||||||
|  |   fn try_from(config_list: &AppConfigList<T>) -> Result<Self, Self::Error> { | ||||||
|  |     let mut manager = Self::default(); | ||||||
|  |     for app_config in config_list.inner.iter() { | ||||||
|  |       let backend: BackendApp<T> = BackendApp::try_from(app_config)?; | ||||||
|  |       manager | ||||||
|  |         .apps | ||||||
|  |         .insert(app_config.server_name.clone().to_server_name(), backend); | ||||||
|  | 
 | ||||||
|  |       info!( | ||||||
|  |         "Registering application {} ({})", | ||||||
|  |         &app_config.server_name, &app_config.app_name | ||||||
|  |       ); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // default backend application for plaintext http requests
 | ||||||
|  |     if let Some(default_app_name) = &config_list.default_app { | ||||||
|  |       let default_server_name = manager | ||||||
|  |         .apps | ||||||
|  |         .iter() | ||||||
|  |         .filter(|(_k, v)| &v.app_name == default_app_name) | ||||||
|  |         .map(|(_, v)| v.server_name.clone()) | ||||||
|  |         .collect::<Vec<_>>(); | ||||||
|  | 
 | ||||||
|  |       if !default_server_name.is_empty() { | ||||||
|  |         info!( | ||||||
|  |           "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", | ||||||
|  |           &default_app_name, | ||||||
|  |           (&default_server_name[0]).try_into().unwrap_or_else(|_| "".to_string()) | ||||||
|  |         ); | ||||||
|  | 
 | ||||||
|  |         manager.default_server_name = Some(default_server_name[0].clone()); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     Ok(manager) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -1,6 +1,7 @@ | ||||||
|  | #[allow(unused)] | ||||||
| #[cfg(feature = "sticky-cookie")] | #[cfg(feature = "sticky-cookie")] | ||||||
| pub use super::{ | pub use super::{ | ||||||
|   load_balance_sticky::{LbStickyRoundRobin, LbStickyRoundRobinBuilder}, |   load_balance_sticky::{LoadBalanceSticky, LoadBalanceStickyBuilder}, | ||||||
|   sticky_cookie::StickyCookie, |   sticky_cookie::StickyCookie, | ||||||
| }; | }; | ||||||
| use derive_builder::Builder; | use derive_builder::Builder; | ||||||
|  | @ -11,7 +12,7 @@ use std::sync::{ | ||||||
| }; | }; | ||||||
| 
 | 
 | ||||||
| /// Constants to specify a load balance option
 | /// Constants to specify a load balance option
 | ||||||
| pub(super) mod load_balance_options { | pub mod load_balance_options { | ||||||
|   pub const FIX_TO_FIRST: &str = "none"; |   pub const FIX_TO_FIRST: &str = "none"; | ||||||
|   pub const ROUND_ROBIN: &str = "round_robin"; |   pub const ROUND_ROBIN: &str = "round_robin"; | ||||||
|   pub const RANDOM: &str = "random"; |   pub const RANDOM: &str = "random"; | ||||||
|  | @ -22,18 +23,18 @@ pub(super) mod load_balance_options { | ||||||
| #[derive(Debug, Clone)] | #[derive(Debug, Clone)] | ||||||
| /// Pointer to upstream serving the incoming request.
 | /// Pointer to upstream serving the incoming request.
 | ||||||
| /// If 'sticky cookie'-based LB is enabled and cookie must be updated/created, the new cookie is also given.
 | /// If 'sticky cookie'-based LB is enabled and cookie must be updated/created, the new cookie is also given.
 | ||||||
| pub(super) struct PointerToUpstream { | pub struct PointerToUpstream { | ||||||
|   pub ptr: usize, |   pub ptr: usize, | ||||||
|   pub context_lb: Option<LbContext>, |   pub context: Option<LoadBalanceContext>, | ||||||
| } | } | ||||||
| /// Trait for LB
 | /// Trait for LB
 | ||||||
| pub(super) trait LbWithPointer { | pub(super) trait LoadBalanceWithPointer { | ||||||
|   fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream; |   fn get_ptr(&self, req_info: Option<&LoadBalanceContext>) -> PointerToUpstream; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[derive(Debug, Clone, Builder)] | #[derive(Debug, Clone, Builder)] | ||||||
| /// Round Robin LB object as a pointer to the current serving upstream destination
 | /// Round Robin LB object as a pointer to the current serving upstream destination
 | ||||||
| pub struct LbRoundRobin { | pub struct LoadBalanceRoundRobin { | ||||||
|   #[builder(default)] |   #[builder(default)] | ||||||
|   /// Pointer to the index of the last served upstream destination
 |   /// Pointer to the index of the last served upstream destination
 | ||||||
|   ptr: Arc<AtomicUsize>, |   ptr: Arc<AtomicUsize>, | ||||||
|  | @ -41,15 +42,15 @@ pub struct LbRoundRobin { | ||||||
|   /// Number of upstream destinations
 |   /// Number of upstream destinations
 | ||||||
|   num_upstreams: usize, |   num_upstreams: usize, | ||||||
| } | } | ||||||
| impl LbRoundRobinBuilder { | impl LoadBalanceRoundRobinBuilder { | ||||||
|   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { |   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { | ||||||
|     self.num_upstreams = Some(*v); |     self.num_upstreams = Some(*v); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
| } | } | ||||||
| impl LbWithPointer for LbRoundRobin { | impl LoadBalanceWithPointer for LoadBalanceRoundRobin { | ||||||
|   /// Increment the count of upstream served up to the max value
 |   /// Increment the count of upstream served up to the max value
 | ||||||
|   fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { |   fn get_ptr(&self, _info: Option<&LoadBalanceContext>) -> PointerToUpstream { | ||||||
|     // Get a current count of upstream served
 |     // Get a current count of upstream served
 | ||||||
|     let current_ptr = self.ptr.load(Ordering::Relaxed); |     let current_ptr = self.ptr.load(Ordering::Relaxed); | ||||||
| 
 | 
 | ||||||
|  | @ -59,29 +60,29 @@ impl LbWithPointer for LbRoundRobin { | ||||||
|       // Clear the counter
 |       // Clear the counter
 | ||||||
|       self.ptr.fetch_and(0, Ordering::Relaxed) |       self.ptr.fetch_and(0, Ordering::Relaxed) | ||||||
|     }; |     }; | ||||||
|     PointerToUpstream { ptr, context_lb: None } |     PointerToUpstream { ptr, context: None } | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[derive(Debug, Clone, Builder)] | #[derive(Debug, Clone, Builder)] | ||||||
| /// Random LB object to keep the object of random pools
 | /// Random LB object to keep the object of random pools
 | ||||||
| pub struct LbRandom { | pub struct LoadBalanceRandom { | ||||||
|   #[builder(setter(custom), default)] |   #[builder(setter(custom), default)] | ||||||
|   /// Number of upstream destinations
 |   /// Number of upstream destinations
 | ||||||
|   num_upstreams: usize, |   num_upstreams: usize, | ||||||
| } | } | ||||||
| impl LbRandomBuilder { | impl LoadBalanceRandomBuilder { | ||||||
|   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { |   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { | ||||||
|     self.num_upstreams = Some(*v); |     self.num_upstreams = Some(*v); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
| } | } | ||||||
| impl LbWithPointer for LbRandom { | impl LoadBalanceWithPointer for LoadBalanceRandom { | ||||||
|   /// Returns the random index within the range
 |   /// Returns the random index within the range
 | ||||||
|   fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { |   fn get_ptr(&self, _info: Option<&LoadBalanceContext>) -> PointerToUpstream { | ||||||
|     let mut rng = rand::thread_rng(); |     let mut rng = rand::thread_rng(); | ||||||
|     let ptr = rng.gen_range(0..self.num_upstreams); |     let ptr = rng.gen_range(0..self.num_upstreams); | ||||||
|     PointerToUpstream { ptr, context_lb: None } |     PointerToUpstream { ptr, context: None } | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -91,12 +92,12 @@ pub enum LoadBalance { | ||||||
|   /// Fix to the first upstream. Use if only one upstream destination is specified
 |   /// Fix to the first upstream. Use if only one upstream destination is specified
 | ||||||
|   FixToFirst, |   FixToFirst, | ||||||
|   /// Randomly chose one upstream server
 |   /// Randomly chose one upstream server
 | ||||||
|   Random(LbRandom), |   Random(LoadBalanceRandom), | ||||||
|   /// Simple round robin without session persistance
 |   /// Simple round robin without session persistance
 | ||||||
|   RoundRobin(LbRoundRobin), |   RoundRobin(LoadBalanceRoundRobin), | ||||||
|   #[cfg(feature = "sticky-cookie")] |   #[cfg(feature = "sticky-cookie")] | ||||||
|   /// Round robin with session persistance using cookie
 |   /// Round robin with session persistance using cookie
 | ||||||
|   StickyRoundRobin(LbStickyRoundRobin), |   StickyRoundRobin(LoadBalanceSticky), | ||||||
| } | } | ||||||
| impl Default for LoadBalance { | impl Default for LoadBalance { | ||||||
|   fn default() -> Self { |   fn default() -> Self { | ||||||
|  | @ -106,11 +107,11 @@ impl Default for LoadBalance { | ||||||
| 
 | 
 | ||||||
| impl LoadBalance { | impl LoadBalance { | ||||||
|   /// Get the index of the upstream serving the incoming request
 |   /// Get the index of the upstream serving the incoming request
 | ||||||
|   pub(super) fn get_context(&self, _context_to_lb: &Option<LbContext>) -> PointerToUpstream { |   pub fn get_context(&self, _context_to_lb: &Option<LoadBalanceContext>) -> PointerToUpstream { | ||||||
|     match self { |     match self { | ||||||
|       LoadBalance::FixToFirst => PointerToUpstream { |       LoadBalance::FixToFirst => PointerToUpstream { | ||||||
|         ptr: 0usize, |         ptr: 0usize, | ||||||
|         context_lb: None, |         context: None, | ||||||
|       }, |       }, | ||||||
|       LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None), |       LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None), | ||||||
|       LoadBalance::Random(ptr) => ptr.get_ptr(None), |       LoadBalance::Random(ptr) => ptr.get_ptr(None), | ||||||
|  | @ -127,7 +128,7 @@ impl LoadBalance { | ||||||
| /// Struct to handle the sticky cookie string,
 | /// Struct to handle the sticky cookie string,
 | ||||||
| /// - passed from Rp module (http handler) to LB module, manipulated from req, only StickyCookieValue exists.
 | /// - passed from Rp module (http handler) to LB module, manipulated from req, only StickyCookieValue exists.
 | ||||||
| /// - passed from LB module to Rp module (http handler), will be inserted into res, StickyCookieValue and Info exist.
 | /// - passed from LB module to Rp module (http handler), will be inserted into res, StickyCookieValue and Info exist.
 | ||||||
| pub struct LbContext { | pub struct LoadBalanceContext { | ||||||
|   #[cfg(feature = "sticky-cookie")] |   #[cfg(feature = "sticky-cookie")] | ||||||
|   pub sticky_cookie: StickyCookie, |   pub sticky_cookie: StickyCookie, | ||||||
|   #[cfg(not(feature = "sticky-cookie"))] |   #[cfg(not(feature = "sticky-cookie"))] | ||||||
|  | @ -1,5 +1,5 @@ | ||||||
| use super::{ | use super::{ | ||||||
|   load_balance::{LbContext, LbWithPointer, PointerToUpstream}, |   load_balance_main::{LoadBalanceContext, LoadBalanceWithPointer, PointerToUpstream}, | ||||||
|   sticky_cookie::StickyCookieConfig, |   sticky_cookie::StickyCookieConfig, | ||||||
|   Upstream, |   Upstream, | ||||||
| }; | }; | ||||||
|  | @ -16,7 +16,7 @@ use std::{ | ||||||
| 
 | 
 | ||||||
| #[derive(Debug, Clone, Builder)] | #[derive(Debug, Clone, Builder)] | ||||||
| /// Round Robin LB object in the sticky cookie manner
 | /// Round Robin LB object in the sticky cookie manner
 | ||||||
| pub struct LbStickyRoundRobin { | pub struct LoadBalanceSticky { | ||||||
|   #[builder(default)] |   #[builder(default)] | ||||||
|   /// Pointer to the index of the last served upstream destination
 |   /// Pointer to the index of the last served upstream destination
 | ||||||
|   ptr: Arc<AtomicUsize>, |   ptr: Arc<AtomicUsize>, | ||||||
|  | @ -39,11 +39,13 @@ pub struct UpstreamMap { | ||||||
|   /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup
 |   /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup
 | ||||||
|   upstream_id_map: HashMap<String, usize>, |   upstream_id_map: HashMap<String, usize>, | ||||||
| } | } | ||||||
| impl LbStickyRoundRobinBuilder { | impl LoadBalanceStickyBuilder { | ||||||
|  |   /// Set the number of upstream destinations
 | ||||||
|   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { |   pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { | ||||||
|     self.num_upstreams = Some(*v); |     self.num_upstreams = Some(*v); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
|  |   /// Set the information to build the cookie to stick clients to specific backends
 | ||||||
|   pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option<String>) -> &mut Self { |   pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option<String>) -> &mut Self { | ||||||
|     self.sticky_config = Some(StickyCookieConfig { |     self.sticky_config = Some(StickyCookieConfig { | ||||||
|       name: STICKY_COOKIE_NAME.to_string(), // TODO: config等で変更できるように
 |       name: STICKY_COOKIE_NAME.to_string(), // TODO: config等で変更できるように
 | ||||||
|  | @ -57,6 +59,7 @@ impl LbStickyRoundRobinBuilder { | ||||||
|     }); |     }); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
|  |   /// Set the hashmaps: upstream_index_map and upstream_id_map
 | ||||||
|   pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self { |   pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self { | ||||||
|     let upstream_index_map: Vec<String> = upstream_vec |     let upstream_index_map: Vec<String> = upstream_vec | ||||||
|       .iter() |       .iter() | ||||||
|  | @ -74,7 +77,8 @@ impl LbStickyRoundRobinBuilder { | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
| } | } | ||||||
| impl<'a> LbStickyRoundRobin { | impl<'a> LoadBalanceSticky { | ||||||
|  |   /// Increment the count of upstream served up to the max value
 | ||||||
|   fn simple_increment_ptr(&self) -> usize { |   fn simple_increment_ptr(&self) -> usize { | ||||||
|     // Get a current count of upstream served
 |     // Get a current count of upstream served
 | ||||||
|     let current_ptr = self.ptr.load(Ordering::Relaxed); |     let current_ptr = self.ptr.load(Ordering::Relaxed); | ||||||
|  | @ -96,8 +100,9 @@ impl<'a> LbStickyRoundRobin { | ||||||
|     self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned()) |     self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned()) | ||||||
|   } |   } | ||||||
| } | } | ||||||
| impl LbWithPointer for LbStickyRoundRobin { | impl LoadBalanceWithPointer for LoadBalanceSticky { | ||||||
|   fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream { |   /// Get the pointer to the upstream server to serve the incoming request.
 | ||||||
|  |   fn get_ptr(&self, req_info: Option<&LoadBalanceContext>) -> PointerToUpstream { | ||||||
|     // If given context is None or invalid (not contained), get_ptr() is invoked to increment the pointer.
 |     // If given context is None or invalid (not contained), get_ptr() is invoked to increment the pointer.
 | ||||||
|     // Otherwise, get the server index indicated by the server_id inside the cookie
 |     // Otherwise, get the server index indicated by the server_id inside the cookie
 | ||||||
|     let ptr = match req_info { |     let ptr = match req_info { | ||||||
|  | @ -121,12 +126,12 @@ impl LbWithPointer for LbStickyRoundRobin { | ||||||
|     // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie).
 |     // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie).
 | ||||||
|     let upstream_id = self.get_server_id_from_index(ptr); |     let upstream_id = self.get_server_id_from_index(ptr); | ||||||
|     let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap(); |     let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap(); | ||||||
|     let new_context = Some(LbContext { |     let new_context = Some(LoadBalanceContext { | ||||||
|       sticky_cookie: new_cookie, |       sticky_cookie: new_cookie, | ||||||
|     }); |     }); | ||||||
|     PointerToUpstream { |     PointerToUpstream { | ||||||
|       ptr, |       ptr, | ||||||
|       context_lb: new_context, |       context: new_context, | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
							
								
								
									
										43
									
								
								rpxy-lib/src/backend/load_balance/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								rpxy-lib/src/backend/load_balance/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,43 @@ | ||||||
|  | mod load_balance_main; | ||||||
|  | #[cfg(feature = "sticky-cookie")] | ||||||
|  | mod load_balance_sticky; | ||||||
|  | #[cfg(feature = "sticky-cookie")] | ||||||
|  | mod sticky_cookie; | ||||||
|  | 
 | ||||||
|  | use super::upstream::Upstream; | ||||||
|  | use thiserror::Error; | ||||||
|  | 
 | ||||||
|  | pub use load_balance_main::{ | ||||||
|  |   load_balance_options, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder, | ||||||
|  | }; | ||||||
|  | #[cfg(feature = "sticky-cookie")] | ||||||
|  | pub use load_balance_sticky::LoadBalanceStickyBuilder; | ||||||
|  | #[cfg(feature = "sticky-cookie")] | ||||||
|  | pub use sticky_cookie::{StickyCookie, StickyCookieValue}; | ||||||
|  | 
 | ||||||
|  | /// Result type for load balancing
 | ||||||
|  | type LoadBalanceResult<T> = std::result::Result<T, LoadBalanceError>; | ||||||
|  | /// Describes things that can go wrong in the Load Balance
 | ||||||
|  | #[derive(Debug, Error)] | ||||||
|  | pub enum LoadBalanceError { | ||||||
|  |   // backend load balance errors
 | ||||||
|  |   #[cfg(feature = "sticky-cookie")] | ||||||
|  |   #[error("Failed to cookie conversion to/from string")] | ||||||
|  |   FailedToConversionStickyCookie, | ||||||
|  | 
 | ||||||
|  |   #[cfg(feature = "sticky-cookie")] | ||||||
|  |   #[error("Invalid cookie structure")] | ||||||
|  |   InvalidStickyCookieStructure, | ||||||
|  | 
 | ||||||
|  |   #[cfg(feature = "sticky-cookie")] | ||||||
|  |   #[error("No sticky cookie value")] | ||||||
|  |   NoStickyCookieValue, | ||||||
|  | 
 | ||||||
|  |   #[cfg(feature = "sticky-cookie")] | ||||||
|  |   #[error("Failed to cookie conversion into string: no meta information")] | ||||||
|  |   NoStickyCookieNoMetaInfo, | ||||||
|  | 
 | ||||||
|  |   #[cfg(feature = "sticky-cookie")] | ||||||
|  |   #[error("Failed to build sticky cookie from config")] | ||||||
|  |   FailedToBuildStickyCookie, | ||||||
|  | } | ||||||
|  | @ -1,8 +1,7 @@ | ||||||
| use std::borrow::Cow; | use super::{LoadBalanceError, LoadBalanceResult}; | ||||||
| 
 |  | ||||||
| use crate::error::*; |  | ||||||
| use chrono::{TimeZone, Utc}; | use chrono::{TimeZone, Utc}; | ||||||
| use derive_builder::Builder; | use derive_builder::Builder; | ||||||
|  | use std::borrow::Cow; | ||||||
| 
 | 
 | ||||||
| #[derive(Debug, Clone, Builder)] | #[derive(Debug, Clone, Builder)] | ||||||
| /// Cookie value only, used for COOKIE in req
 | /// Cookie value only, used for COOKIE in req
 | ||||||
|  | @ -25,18 +24,16 @@ impl<'a> StickyCookieValueBuilder { | ||||||
|   } |   } | ||||||
| } | } | ||||||
| impl StickyCookieValue { | impl StickyCookieValue { | ||||||
|   pub fn try_from(value: &str, expected_name: &str) -> Result<Self> { |   pub fn try_from(value: &str, expected_name: &str) -> LoadBalanceResult<Self> { | ||||||
|     if !value.starts_with(expected_name) { |     if !value.starts_with(expected_name) { | ||||||
|       return Err(RpxyError::LoadBalance( |       return Err(LoadBalanceError::FailedToConversionStickyCookie); | ||||||
|         "Failed to cookie conversion from string".to_string(), |  | ||||||
|       )); |  | ||||||
|     }; |     }; | ||||||
|     let kv = value.split('=').map(|v| v.trim()).collect::<Vec<&str>>(); |     let kv = value.split('=').map(|v| v.trim()).collect::<Vec<&str>>(); | ||||||
|     if kv.len() != 2 { |     if kv.len() != 2 { | ||||||
|       return Err(RpxyError::LoadBalance("Invalid cookie structure".to_string())); |       return Err(LoadBalanceError::InvalidStickyCookieStructure); | ||||||
|     }; |     }; | ||||||
|     if kv[1].is_empty() { |     if kv[1].is_empty() { | ||||||
|       return Err(RpxyError::LoadBalance("No sticky cookie value".to_string())); |       return Err(LoadBalanceError::NoStickyCookieValue); | ||||||
|     } |     } | ||||||
|     Ok(StickyCookieValue { |     Ok(StickyCookieValue { | ||||||
|       name: expected_name.to_string(), |       name: expected_name.to_string(), | ||||||
|  | @ -88,10 +85,12 @@ pub struct StickyCookie { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl<'a> StickyCookieBuilder { | impl<'a> StickyCookieBuilder { | ||||||
|  |   /// Set the value of sticky cookie
 | ||||||
|   pub fn value(&mut self, n: impl Into<Cow<'a, str>>, v: impl Into<Cow<'a, str>>) -> &mut Self { |   pub fn value(&mut self, n: impl Into<Cow<'a, str>>, v: impl Into<Cow<'a, str>>) -> &mut Self { | ||||||
|     self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap()); |     self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap()); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
|  |   /// Set the meta information of sticky cookie
 | ||||||
|   pub fn info( |   pub fn info( | ||||||
|     &mut self, |     &mut self, | ||||||
|     domain: impl Into<Cow<'a, str>>, |     domain: impl Into<Cow<'a, str>>, | ||||||
|  | @ -110,17 +109,15 @@ impl<'a> StickyCookieBuilder { | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl TryInto<String> for StickyCookie { | impl TryInto<String> for StickyCookie { | ||||||
|   type Error = RpxyError; |   type Error = LoadBalanceError; | ||||||
| 
 | 
 | ||||||
|   fn try_into(self) -> Result<String> { |   fn try_into(self) -> LoadBalanceResult<String> { | ||||||
|     if self.info.is_none() { |     if self.info.is_none() { | ||||||
|       return Err(RpxyError::LoadBalance( |       return Err(LoadBalanceError::NoStickyCookieNoMetaInfo); | ||||||
|         "Failed to cookie conversion into string: no meta information".to_string(), |  | ||||||
|       )); |  | ||||||
|     } |     } | ||||||
|     let info = self.info.unwrap(); |     let info = self.info.unwrap(); | ||||||
|     let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else { |     let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else { | ||||||
|       return Err(RpxyError::LoadBalance("Failed to cookie conversion into string".to_string())); |       return Err(LoadBalanceError::FailedToConversionStickyCookie); | ||||||
|     }; |     }; | ||||||
|     let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string(); |     let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string(); | ||||||
|     let max_age = info.expires - Utc::now().timestamp(); |     let max_age = info.expires - Utc::now().timestamp(); | ||||||
|  | @ -144,12 +141,12 @@ pub struct StickyCookieConfig { | ||||||
|   pub duration: i64, |   pub duration: i64, | ||||||
| } | } | ||||||
| impl<'a> StickyCookieConfig { | impl<'a> StickyCookieConfig { | ||||||
|   pub fn build_sticky_cookie(&self, v: impl Into<Cow<'a, str>>) -> Result<StickyCookie> { |   pub fn build_sticky_cookie(&self, v: impl Into<Cow<'a, str>>) -> LoadBalanceResult<StickyCookie> { | ||||||
|     StickyCookieBuilder::default() |     StickyCookieBuilder::default() | ||||||
|       .value(self.name.clone(), v) |       .value(self.name.clone(), v) | ||||||
|       .info(&self.domain, &self.path, self.duration) |       .info(&self.domain, &self.path, self.duration) | ||||||
|       .build() |       .build() | ||||||
|       .map_err(|_| RpxyError::LoadBalance("Failed to build sticky cookie from config".to_string())) |       .map_err(|_| LoadBalanceError::FailedToBuildStickyCookie) | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | @ -167,7 +164,7 @@ mod tests { | ||||||
|       duration: 100, |       duration: 100, | ||||||
|     }; |     }; | ||||||
|     let expires_unix = Utc::now().timestamp() + 100; |     let expires_unix = Utc::now().timestamp() + 100; | ||||||
|     let sc_string: Result<String> = config.build_sticky_cookie("test_value").unwrap().try_into(); |     let sc_string: LoadBalanceResult<String> = config.build_sticky_cookie("test_value").unwrap().try_into(); | ||||||
|     let expires_date_string = Utc |     let expires_date_string = Utc | ||||||
|       .timestamp_opt(expires_unix, 0) |       .timestamp_opt(expires_unix, 0) | ||||||
|       .unwrap() |       .unwrap() | ||||||
|  | @ -194,7 +191,7 @@ mod tests { | ||||||
|         path: "/path".to_string(), |         path: "/path".to_string(), | ||||||
|       }), |       }), | ||||||
|     }; |     }; | ||||||
|     let sc_string: Result<String> = sc.try_into(); |     let sc_string: LoadBalanceResult<String> = sc.try_into(); | ||||||
|     let max_age = 1686221173i64 - Utc::now().timestamp(); |     let max_age = 1686221173i64 - Utc::now().timestamp(); | ||||||
|     assert!(sc_string.is_ok()); |     assert!(sc_string.is_ok()); | ||||||
|     assert_eq!( |     assert_eq!( | ||||||
|  | @ -1,77 +1,14 @@ | ||||||
|  | mod backend_main; | ||||||
| mod load_balance; | mod load_balance; | ||||||
| #[cfg(feature = "sticky-cookie")] |  | ||||||
| mod load_balance_sticky; |  | ||||||
| #[cfg(feature = "sticky-cookie")] |  | ||||||
| mod sticky_cookie; |  | ||||||
| mod upstream; | mod upstream; | ||||||
| mod upstream_opts; | mod upstream_opts; | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "sticky-cookie")] | #[cfg(feature = "sticky-cookie")] | ||||||
| pub use self::sticky_cookie::{StickyCookie, StickyCookieValue}; | pub(crate) use self::load_balance::{StickyCookie, StickyCookieValue}; | ||||||
| pub use self::{ | #[allow(unused)] | ||||||
|   load_balance::{LbContext, LoadBalance}, | pub(crate) use self::{ | ||||||
|   upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, |   load_balance::{LoadBalance, LoadBalanceContext}, | ||||||
|  |   upstream::{PathManager, Upstream, UpstreamCandidates}, | ||||||
|   upstream_opts::UpstreamOption, |   upstream_opts::UpstreamOption, | ||||||
| }; | }; | ||||||
| use crate::{ | pub(crate) use backend_main::{BackendApp, BackendAppBuilderError, BackendAppManager}; | ||||||
|   certs::CryptoSource, |  | ||||||
|   utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, |  | ||||||
| }; |  | ||||||
| use derive_builder::Builder; |  | ||||||
| use rustc_hash::FxHashMap as HashMap; |  | ||||||
| use std::borrow::Cow; |  | ||||||
| 
 |  | ||||||
| /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
 |  | ||||||
| #[derive(Builder)] |  | ||||||
| pub struct Backend<T> |  | ||||||
| where |  | ||||||
|   T: CryptoSource, |  | ||||||
| { |  | ||||||
|   #[builder(setter(into))] |  | ||||||
|   /// backend application name, e.g., app1
 |  | ||||||
|   pub app_name: String, |  | ||||||
|   #[builder(setter(custom))] |  | ||||||
|   /// server name, e.g., example.com, in String ascii lower case
 |  | ||||||
|   pub server_name: String, |  | ||||||
|   /// struct of reverse proxy serving incoming request
 |  | ||||||
|   pub reverse_proxy: ReverseProxy, |  | ||||||
| 
 |  | ||||||
|   /// tls settings: https redirection with 30x
 |  | ||||||
|   #[builder(default)] |  | ||||||
|   pub https_redirection: Option<bool>, |  | ||||||
| 
 |  | ||||||
|   /// TLS settings: source meta for server cert, key, client ca cert
 |  | ||||||
|   #[builder(default)] |  | ||||||
|   pub crypto_source: Option<T>, |  | ||||||
| } |  | ||||||
| impl<'a, T> BackendBuilder<T> |  | ||||||
| where |  | ||||||
|   T: CryptoSource, |  | ||||||
| { |  | ||||||
|   pub fn server_name(&mut self, server_name: impl Into<Cow<'a, str>>) -> &mut Self { |  | ||||||
|     self.server_name = Some(server_name.into().to_ascii_lowercase()); |  | ||||||
|     self |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// HashMap and some meta information for multiple Backend structs.
 |  | ||||||
| pub struct Backends<T> |  | ||||||
| where |  | ||||||
|   T: CryptoSource, |  | ||||||
| { |  | ||||||
|   pub apps: HashMap<ServerNameBytesExp, Backend<T>>, // hyper::uriで抜いたhostで引っ掛ける
 |  | ||||||
|   pub default_server_name_bytes: Option<ServerNameBytesExp>, // for plaintext http
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<T> Backends<T> |  | ||||||
| where |  | ||||||
|   T: CryptoSource, |  | ||||||
| { |  | ||||||
|   #[allow(clippy::new_without_default)] |  | ||||||
|   pub fn new() -> Self { |  | ||||||
|     Backends { |  | ||||||
|       apps: HashMap::<ServerNameBytesExp, Backend<T>>::default(), |  | ||||||
|       default_server_name_bytes: None, |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -1,8 +1,17 @@ | ||||||
| #[cfg(feature = "sticky-cookie")] | #[cfg(feature = "sticky-cookie")] | ||||||
| use super::load_balance::LbStickyRoundRobinBuilder; | use super::load_balance::LoadBalanceStickyBuilder; | ||||||
| use super::load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance}; | use super::load_balance::{ | ||||||
| use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption}; |   load_balance_options as lb_opts, LoadBalance, LoadBalanceContext, LoadBalanceRandomBuilder, LoadBalanceRoundRobinBuilder, | ||||||
| use crate::log::*; | }; | ||||||
|  | // use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption};
 | ||||||
|  | use super::upstream_opts::UpstreamOption; | ||||||
|  | use crate::{ | ||||||
|  |   crypto::CryptoSource, | ||||||
|  |   error::RpxyError, | ||||||
|  |   globals::{AppConfig, UpstreamUri}, | ||||||
|  |   log::*, | ||||||
|  |   name_exp::{ByteName, PathName}, | ||||||
|  | }; | ||||||
| #[cfg(feature = "sticky-cookie")] | #[cfg(feature = "sticky-cookie")] | ||||||
| use base64::{engine::general_purpose, Engine as _}; | use base64::{engine::general_purpose, Engine as _}; | ||||||
| use derive_builder::Builder; | use derive_builder::Builder; | ||||||
|  | @ -10,26 +19,67 @@ use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet}; | ||||||
| #[cfg(feature = "sticky-cookie")] | #[cfg(feature = "sticky-cookie")] | ||||||
| use sha2::{Digest, Sha256}; | use sha2::{Digest, Sha256}; | ||||||
| use std::borrow::Cow; | use std::borrow::Cow; | ||||||
|  | 
 | ||||||
| #[derive(Debug, Clone)] | #[derive(Debug, Clone)] | ||||||
| pub struct ReverseProxy { | /// Handler for given path to route incoming request to path's corresponding upstream server(s).
 | ||||||
|   pub upstream: HashMap<PathNameBytesExp, UpstreamGroup>, // TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。
 | pub struct PathManager { | ||||||
|  |   /// HashMap of upstream candidate server info, key is path name
 | ||||||
|  |   /// TODO: HashMapでいいのかは疑問。max_by_keyでlongest prefix matchしてるのも無駄っぽいが。。。
 | ||||||
|  |   inner: HashMap<PathName, UpstreamCandidates>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl ReverseProxy { | impl<T> TryFrom<&AppConfig<T>> for PathManager | ||||||
|   /// Get an appropriate upstream destination for given path string.
 | where | ||||||
|   pub fn get<'a>(&self, path_str: impl Into<Cow<'a, str>>) -> Option<&UpstreamGroup> { |   T: CryptoSource, | ||||||
|     // trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、
 | { | ||||||
|     // コスト的にこの程度で十分
 |   type Error = RpxyError; | ||||||
|     let path_bytes = &path_str.to_path_name_vec(); |   fn try_from(app_config: &AppConfig<T>) -> Result<Self, Self::Error> { | ||||||
|  |     let mut inner: HashMap<PathName, UpstreamCandidates> = HashMap::default(); | ||||||
|  | 
 | ||||||
|  |     app_config.reverse_proxy.iter().for_each(|rpc| { | ||||||
|  |       let upstream_vec: Vec<Upstream> = rpc.upstream.iter().map(Upstream::from).collect(); | ||||||
|  |       let elem = UpstreamCandidatesBuilder::default() | ||||||
|  |         .upstream(&upstream_vec) | ||||||
|  |         .path(&rpc.path) | ||||||
|  |         .replace_path(&rpc.replace_path) | ||||||
|  |         .load_balance(&rpc.load_balance, &upstream_vec, &app_config.server_name, &rpc.path) | ||||||
|  |         .options(&rpc.upstream_options) | ||||||
|  |         .build() | ||||||
|  |         .unwrap(); | ||||||
|  |       inner.insert(elem.path.clone(), elem); | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     if app_config.reverse_proxy.iter().filter(|rpc| rpc.path.is_none()).count() >= 2 { | ||||||
|  |       error!("Multiple default reverse proxy setting"); | ||||||
|  |       return Err(RpxyError::InvalidReverseProxyConfig); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if !(inner.iter().all(|(_, elem)| { | ||||||
|  |       !(elem.options.contains(&UpstreamOption::ForceHttp11Upstream) && elem.options.contains(&UpstreamOption::ForceHttp2Upstream)) | ||||||
|  |     })) { | ||||||
|  |       error!("Either one of force_http11 or force_http2 can be enabled"); | ||||||
|  |       return Err(RpxyError::InvalidUpstreamOptionSetting); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     Ok(PathManager { inner }) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl PathManager { | ||||||
|  |   /// Get an appropriate upstream destinations for given path string.
 | ||||||
|  |   /// trie使ってlongest prefix match させてもいいけどルート記述は少ないと思われるので、
 | ||||||
|  |   /// コスト的にこの程度で十分では。
 | ||||||
|  |   pub fn get<'a>(&self, path_str: impl Into<Cow<'a, str>>) -> Option<&UpstreamCandidates> { | ||||||
|  |     let path_name = &path_str.to_path_name(); | ||||||
| 
 | 
 | ||||||
|     let matched_upstream = self |     let matched_upstream = self | ||||||
|       .upstream |       .inner | ||||||
|       .iter() |       .iter() | ||||||
|       .filter(|(route_bytes, _)| { |       .filter(|(route_bytes, _)| { | ||||||
|         match path_bytes.starts_with(route_bytes) { |         match path_name.starts_with(route_bytes) { | ||||||
|           true => { |           true => { | ||||||
|             route_bytes.len() == 1 // route = '/', i.e., default
 |             route_bytes.len() == 1 // route = '/', i.e., default
 | ||||||
|             || match path_bytes.get(route_bytes.len()) { |               || match path_name.get(route_bytes.len()) { | ||||||
|                 None => true, // exact case
 |                 None => true, // exact case
 | ||||||
|                 Some(p) => p == &b'/', // sub-path case
 |                 Some(p) => p == &b'/', // sub-path case
 | ||||||
|               } |               } | ||||||
|  | @ -38,10 +88,10 @@ impl ReverseProxy { | ||||||
|         } |         } | ||||||
|       }) |       }) | ||||||
|       .max_by_key(|(route_bytes, _)| route_bytes.len()); |       .max_by_key(|(route_bytes, _)| route_bytes.len()); | ||||||
|     if let Some((_path, u)) = matched_upstream { |     if let Some((path, u)) = matched_upstream { | ||||||
|       debug!( |       debug!( | ||||||
|         "Found upstream: {:?}", |         "Found upstream: {:?}", | ||||||
|         String::from_utf8(_path.0.clone()).unwrap_or_else(|_| "<none>".to_string()) |         path.try_into().unwrap_or_else(|_| "<none>".to_string()) | ||||||
|       ); |       ); | ||||||
|       Some(u) |       Some(u) | ||||||
|     } else { |     } else { | ||||||
|  | @ -56,6 +106,13 @@ pub struct Upstream { | ||||||
|   /// Base uri without specific path
 |   /// Base uri without specific path
 | ||||||
|   pub uri: hyper::Uri, |   pub uri: hyper::Uri, | ||||||
| } | } | ||||||
|  | impl From<&UpstreamUri> for Upstream { | ||||||
|  |   fn from(value: &UpstreamUri) -> Self { | ||||||
|  |     Self { | ||||||
|  |       uri: value.inner.clone(), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
| impl Upstream { | impl Upstream { | ||||||
|   #[cfg(feature = "sticky-cookie")] |   #[cfg(feature = "sticky-cookie")] | ||||||
|   /// Hashing uri with index to avoid collision
 |   /// Hashing uri with index to avoid collision
 | ||||||
|  | @ -69,51 +126,54 @@ impl Upstream { | ||||||
| } | } | ||||||
| #[derive(Debug, Clone, Builder)] | #[derive(Debug, Clone, Builder)] | ||||||
| /// Struct serving multiple upstream servers for, e.g., load balancing.
 | /// Struct serving multiple upstream servers for, e.g., load balancing.
 | ||||||
| pub struct UpstreamGroup { | pub struct UpstreamCandidates { | ||||||
|   #[builder(setter(custom))] |   #[builder(setter(custom))] | ||||||
|   /// Upstream server(s)
 |   /// Upstream server(s)
 | ||||||
|   pub upstream: Vec<Upstream>, |   pub inner: Vec<Upstream>, | ||||||
|  | 
 | ||||||
|   #[builder(setter(custom), default)] |   #[builder(setter(custom), default)] | ||||||
|   /// Path like "/path" in [[PathNameBytesExp]] associated with the upstream server(s)
 |   /// Path like "/path" in [[PathName]] associated with the upstream server(s)
 | ||||||
|   pub path: PathNameBytesExp, |   pub path: PathName, | ||||||
|  | 
 | ||||||
|   #[builder(setter(custom), default)] |   #[builder(setter(custom), default)] | ||||||
|   /// Path in [[PathNameBytesExp]] that will be used to replace the "path" part of incoming url
 |   /// Path in [[PathName]] that will be used to replace the "path" part of incoming url
 | ||||||
|   pub replace_path: Option<PathNameBytesExp>, |   pub replace_path: Option<PathName>, | ||||||
| 
 | 
 | ||||||
|   #[builder(setter(custom), default)] |   #[builder(setter(custom), default)] | ||||||
|   /// Load balancing option
 |   /// Load balancing option
 | ||||||
|   pub lb: LoadBalance, |   pub load_balance: LoadBalance, | ||||||
|  | 
 | ||||||
|   #[builder(setter(custom), default)] |   #[builder(setter(custom), default)] | ||||||
|   /// Activated upstream options defined in [[UpstreamOption]]
 |   /// Activated upstream options defined in [[UpstreamOption]]
 | ||||||
|   pub opts: HashSet<UpstreamOption>, |   pub options: HashSet<UpstreamOption>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl UpstreamGroupBuilder { | impl UpstreamCandidatesBuilder { | ||||||
|  |   /// Set the upstream server(s)
 | ||||||
|   pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self { |   pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self { | ||||||
|     self.upstream = Some(upstream_vec.to_vec()); |     self.inner = Some(upstream_vec.to_vec()); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
|  |   /// Set the path like "/path" in [[PathName]] associated with the upstream server(s), default is "/"
 | ||||||
|   pub fn path(&mut self, v: &Option<String>) -> &mut Self { |   pub fn path(&mut self, v: &Option<String>) -> &mut Self { | ||||||
|     let path = match v { |     let path = match v { | ||||||
|       Some(p) => p.to_path_name_vec(), |       Some(p) => p.to_path_name(), | ||||||
|       None => "/".to_path_name_vec(), |       None => "/".to_path_name(), | ||||||
|     }; |     }; | ||||||
|     self.path = Some(path); |     self.path = Some(path); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
|  |   /// Set the path in [[PathName]] that will be used to replace the "path" part of incoming url
 | ||||||
|   pub fn replace_path(&mut self, v: &Option<String>) -> &mut Self { |   pub fn replace_path(&mut self, v: &Option<String>) -> &mut Self { | ||||||
|     self.replace_path = Some( |     self.replace_path = Some(v.to_owned().as_ref().map_or_else(|| None, |v| Some(v.to_path_name()))); | ||||||
|       v.to_owned() |  | ||||||
|         .as_ref() |  | ||||||
|         .map_or_else(|| None, |v| Some(v.to_path_name_vec())), |  | ||||||
|     ); |  | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
|   pub fn lb( |   /// Set the load balancing option
 | ||||||
|  |   pub fn load_balance( | ||||||
|     &mut self, |     &mut self, | ||||||
|     v: &Option<String>, |     v: &Option<String>, | ||||||
|     // upstream_num: &usize,
 |     // upstream_num: &usize,
 | ||||||
|     upstream_vec: &Vec<Upstream>, |     upstream_vec: &[Upstream], | ||||||
|     _server_name: &str, |     _server_name: &str, | ||||||
|     _path_opt: &Option<String>, |     _path_opt: &Option<String>, | ||||||
|   ) -> &mut Self { |   ) -> &mut Self { | ||||||
|  | @ -121,16 +181,21 @@ impl UpstreamGroupBuilder { | ||||||
|     let lb = if let Some(x) = v { |     let lb = if let Some(x) = v { | ||||||
|       match x.as_str() { |       match x.as_str() { | ||||||
|         lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, |         lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst, | ||||||
|         lb_opts::RANDOM => LoadBalance::Random(LbRandomBuilder::default().num_upstreams(upstream_num).build().unwrap()), |         lb_opts::RANDOM => LoadBalance::Random( | ||||||
|  |           LoadBalanceRandomBuilder::default() | ||||||
|  |             .num_upstreams(upstream_num) | ||||||
|  |             .build() | ||||||
|  |             .unwrap(), | ||||||
|  |         ), | ||||||
|         lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin( |         lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin( | ||||||
|           LbRoundRobinBuilder::default() |           LoadBalanceRoundRobinBuilder::default() | ||||||
|             .num_upstreams(upstream_num) |             .num_upstreams(upstream_num) | ||||||
|             .build() |             .build() | ||||||
|             .unwrap(), |             .unwrap(), | ||||||
|         ), |         ), | ||||||
|         #[cfg(feature = "sticky-cookie")] |         #[cfg(feature = "sticky-cookie")] | ||||||
|         lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( |         lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin( | ||||||
|           LbStickyRoundRobinBuilder::default() |           LoadBalanceStickyBuilder::default() | ||||||
|             .num_upstreams(upstream_num) |             .num_upstreams(upstream_num) | ||||||
|             .sticky_config(_server_name, _path_opt) |             .sticky_config(_server_name, _path_opt) | ||||||
|             .upstream_maps(upstream_vec) // TODO:
 |             .upstream_maps(upstream_vec) // TODO:
 | ||||||
|  | @ -145,10 +210,11 @@ impl UpstreamGroupBuilder { | ||||||
|     } else { |     } else { | ||||||
|       LoadBalance::default() |       LoadBalance::default() | ||||||
|     }; |     }; | ||||||
|     self.lb = Some(lb); |     self.load_balance = Some(lb); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
|   pub fn opts(&mut self, v: &Option<Vec<String>>) -> &mut Self { |   /// Set the activated upstream options defined in [[UpstreamOption]]
 | ||||||
|  |   pub fn options(&mut self, v: &Option<Vec<String>>) -> &mut Self { | ||||||
|     let opts = if let Some(opts) = v { |     let opts = if let Some(opts) = v { | ||||||
|       opts |       opts | ||||||
|         .iter() |         .iter() | ||||||
|  | @ -157,25 +223,19 @@ impl UpstreamGroupBuilder { | ||||||
|     } else { |     } else { | ||||||
|       Default::default() |       Default::default() | ||||||
|     }; |     }; | ||||||
|     self.opts = Some(opts); |     self.options = Some(opts); | ||||||
|     self |     self | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| impl UpstreamGroup { | impl UpstreamCandidates { | ||||||
|   /// Get an enabled option of load balancing [[LoadBalance]]
 |   /// Get an enabled option of load balancing [[LoadBalance]]
 | ||||||
|   pub fn get(&self, context_to_lb: &Option<LbContext>) -> (Option<&Upstream>, Option<LbContext>) { |   pub fn get(&self, context_to_lb: &Option<LoadBalanceContext>) -> (Option<&Upstream>, Option<LoadBalanceContext>) { | ||||||
|     let pointer_to_upstream = self.lb.get_context(context_to_lb); |     let pointer_to_upstream = self.load_balance.get_context(context_to_lb); | ||||||
|     debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); |     debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr); | ||||||
|     debug!("Context to LB (Cookie in Req): {:?}", context_to_lb); |     debug!("Context to LB (Cookie in Request): {:?}", context_to_lb); | ||||||
|     debug!( |     debug!("Context from LB (Set-Cookie in Response): {:?}", pointer_to_upstream.context); | ||||||
|       "Context from LB (Set-Cookie in Res): {:?}", |     (self.inner.get(pointer_to_upstream.ptr), pointer_to_upstream.context) | ||||||
|       pointer_to_upstream.context_lb |  | ||||||
|     ); |  | ||||||
|     ( |  | ||||||
|       self.upstream.get(pointer_to_upstream.ptr), |  | ||||||
|       pointer_to_upstream.context_lb, |  | ||||||
|     ) |  | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,22 +1,30 @@ | ||||||
| use crate::error::*; | use crate::error::*; | ||||||
| 
 | 
 | ||||||
|  | /// Options for request message to be sent to upstream.
 | ||||||
| #[derive(Debug, Clone, Hash, Eq, PartialEq)] | #[derive(Debug, Clone, Hash, Eq, PartialEq)] | ||||||
| pub enum UpstreamOption { | pub enum UpstreamOption { | ||||||
|   OverrideHost, |   /// Keep original host header, which is prioritized over SetUpstreamHost
 | ||||||
|  |   KeepOriginalHost, | ||||||
|  |   /// Overwrite host header with upstream hostname
 | ||||||
|  |   SetUpstreamHost, | ||||||
|  |   /// Add upgrade-insecure-requests header
 | ||||||
|   UpgradeInsecureRequests, |   UpgradeInsecureRequests, | ||||||
|  |   /// Force HTTP/1.1 upstream
 | ||||||
|   ForceHttp11Upstream, |   ForceHttp11Upstream, | ||||||
|  |   /// Force HTTP/2 upstream
 | ||||||
|   ForceHttp2Upstream, |   ForceHttp2Upstream, | ||||||
|   // TODO: Adds more options for heder override
 |   // TODO: Adds more options for heder override
 | ||||||
| } | } | ||||||
| impl TryFrom<&str> for UpstreamOption { | impl TryFrom<&str> for UpstreamOption { | ||||||
|   type Error = RpxyError; |   type Error = RpxyError; | ||||||
|   fn try_from(val: &str) -> Result<Self> { |   fn try_from(val: &str) -> RpxyResult<Self> { | ||||||
|     match val { |     match val { | ||||||
|       "override_host" => Ok(Self::OverrideHost), |       "keep_original_host" => Ok(Self::KeepOriginalHost), | ||||||
|  |       "set_upstream_host" => Ok(Self::SetUpstreamHost), | ||||||
|       "upgrade_insecure_requests" => Ok(Self::UpgradeInsecureRequests), |       "upgrade_insecure_requests" => Ok(Self::UpgradeInsecureRequests), | ||||||
|       "force_http11_upstream" => Ok(Self::ForceHttp11Upstream), |       "force_http11_upstream" => Ok(Self::ForceHttp11Upstream), | ||||||
|       "force_http2_upstream" => Ok(Self::ForceHttp2Upstream), |       "force_http2_upstream" => Ok(Self::ForceHttp2Upstream), | ||||||
|       _ => Err(RpxyError::Other(anyhow!("Unsupported header option"))), |       _ => Err(RpxyError::UnsupportedUpstreamOption), | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -4,8 +4,8 @@ pub const RESPONSE_HEADER_SERVER: &str = "rpxy"; | ||||||
| pub const TCP_LISTEN_BACKLOG: u32 = 1024; | pub const TCP_LISTEN_BACKLOG: u32 = 1024; | ||||||
| // pub const HTTP_LISTEN_PORT: u16 = 8080;
 | // pub const HTTP_LISTEN_PORT: u16 = 8080;
 | ||||||
| // pub const HTTPS_LISTEN_PORT: u16 = 8443;
 | // pub const HTTPS_LISTEN_PORT: u16 = 8443;
 | ||||||
| pub const PROXY_TIMEOUT_SEC: u64 = 60; | pub const PROXY_IDLE_TIMEOUT_SEC: u64 = 20; | ||||||
| pub const UPSTREAM_TIMEOUT_SEC: u64 = 60; | pub const UPSTREAM_IDLE_TIMEOUT_SEC: u64 = 20; | ||||||
| pub const TLS_HANDSHAKE_TIMEOUT_SEC: u64 = 15; // default as with firefox browser
 | pub const TLS_HANDSHAKE_TIMEOUT_SEC: u64 = 15; // default as with firefox browser
 | ||||||
| pub const MAX_CLIENTS: usize = 512; | pub const MAX_CLIENTS: usize = 512; | ||||||
| pub const MAX_CONCURRENT_STREAMS: u32 = 64; | pub const MAX_CONCURRENT_STREAMS: u32 = 64; | ||||||
|  |  | ||||||
							
								
								
									
										31
									
								
								rpxy-lib/src/count.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								rpxy-lib/src/count.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,31 @@ | ||||||
|  | use std::sync::{ | ||||||
|  |   atomic::{AtomicUsize, Ordering}, | ||||||
|  |   Arc, | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, Clone, Default)] | ||||||
|  | /// Counter for serving requests
 | ||||||
|  | pub struct RequestCount(Arc<AtomicUsize>); | ||||||
|  | 
 | ||||||
|  | impl RequestCount { | ||||||
|  |   pub fn current(&self) -> usize { | ||||||
|  |     self.0.load(Ordering::Relaxed) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   pub fn increment(&self) -> usize { | ||||||
|  |     self.0.fetch_add(1, Ordering::Relaxed) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   pub fn decrement(&self) -> usize { | ||||||
|  |     let mut count; | ||||||
|  |     while { | ||||||
|  |       count = self.0.load(Ordering::Relaxed); | ||||||
|  |       count > 0 | ||||||
|  |         && self | ||||||
|  |           .0 | ||||||
|  |           .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) | ||||||
|  |           != Ok(count) | ||||||
|  |     } {} | ||||||
|  |     count | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										36
									
								
								rpxy-lib/src/crypto/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								rpxy-lib/src/crypto/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,36 @@ | ||||||
|  | mod certs; | ||||||
|  | mod service; | ||||||
|  | 
 | ||||||
|  | use crate::{ | ||||||
|  |   backend::BackendAppManager, | ||||||
|  |   constants::{CERTS_WATCH_DELAY_SECS, LOAD_CERTS_ONLY_WHEN_UPDATED}, | ||||||
|  |   error::RpxyResult, | ||||||
|  | }; | ||||||
|  | use hot_reload::{ReloaderReceiver, ReloaderService}; | ||||||
|  | use service::CryptoReloader; | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
|  | pub use certs::{CertsAndKeys, CryptoSource}; | ||||||
|  | pub use service::{ServerCrypto, ServerCryptoBase, SniServerCryptoMap}; | ||||||
|  | 
 | ||||||
|  | /// Result type inner of certificate reloader service
 | ||||||
|  | type ReloaderServiceResultInner<T> = ( | ||||||
|  |   ReloaderService<CryptoReloader<T>, ServerCryptoBase>, | ||||||
|  |   ReloaderReceiver<ServerCryptoBase>, | ||||||
|  | ); | ||||||
|  | /// Build certificate reloader service
 | ||||||
|  | pub(crate) async fn build_cert_reloader<T>( | ||||||
|  |   app_manager: &Arc<BackendAppManager<T>>, | ||||||
|  | ) -> RpxyResult<ReloaderServiceResultInner<T>> | ||||||
|  | where | ||||||
|  |   T: CryptoSource + Clone + Send + Sync + 'static, | ||||||
|  | { | ||||||
|  |   let (cert_reloader_service, cert_reloader_rx) = ReloaderService::< | ||||||
|  |     service::CryptoReloader<T>, | ||||||
|  |     service::ServerCryptoBase, | ||||||
|  |   >::new( | ||||||
|  |     app_manager, CERTS_WATCH_DELAY_SECS, !LOAD_CERTS_ONLY_WHEN_UPDATED | ||||||
|  |   ) | ||||||
|  |   .await?; | ||||||
|  |   Ok((cert_reloader_service, cert_reloader_rx)) | ||||||
|  | } | ||||||
|  | @ -1,9 +1,5 @@ | ||||||
| use crate::{ | use super::certs::{CertsAndKeys, CryptoSource}; | ||||||
|   certs::{CertsAndKeys, CryptoSource}, | use crate::{backend::BackendAppManager, log::*, name_exp::ServerName}; | ||||||
|   globals::Globals, |  | ||||||
|   log::*, |  | ||||||
|   utils::ServerNameBytesExp, |  | ||||||
| }; |  | ||||||
| use async_trait::async_trait; | use async_trait::async_trait; | ||||||
| use hot_reload::*; | use hot_reload::*; | ||||||
| use rustc_hash::FxHashMap as HashMap; | use rustc_hash::FxHashMap as HashMap; | ||||||
|  | @ -16,15 +12,17 @@ pub struct CryptoReloader<T> | ||||||
| where | where | ||||||
|   T: CryptoSource, |   T: CryptoSource, | ||||||
| { | { | ||||||
|   globals: Arc<Globals<T>>, |   inner: Arc<BackendAppManager<T>>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| pub type SniServerCryptoMap = HashMap<ServerNameBytesExp, Arc<ServerConfig>>; | /// SNI to ServerConfig map type
 | ||||||
|  | pub type SniServerCryptoMap = HashMap<ServerName, Arc<ServerConfig>>; | ||||||
|  | /// SNI to ServerConfig map
 | ||||||
| pub struct ServerCrypto { | pub struct ServerCrypto { | ||||||
|   // For Quic/HTTP3, only servers with no client authentication
 |   // For Quic/HTTP3, only servers with no client authentication
 | ||||||
|   #[cfg(feature = "http3-quinn")] |   #[cfg(feature = "http3-quinn")] | ||||||
|   pub inner_global_no_client_auth: Arc<ServerConfig>, |   pub inner_global_no_client_auth: Arc<ServerConfig>, | ||||||
|   #[cfg(feature = "http3-s2n")] |   #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||||
|   pub inner_global_no_client_auth: s2n_quic_rustls::Server, |   pub inner_global_no_client_auth: s2n_quic_rustls::Server, | ||||||
|   // For TLS over TCP/HTTP2 and 1.1, map of SNI to server_crypto for all given servers
 |   // For TLS over TCP/HTTP2 and 1.1, map of SNI to server_crypto for all given servers
 | ||||||
|   pub inner_local_map: Arc<SniServerCryptoMap>, |   pub inner_local_map: Arc<SniServerCryptoMap>, | ||||||
|  | @ -33,7 +31,7 @@ pub struct ServerCrypto { | ||||||
| /// Reloader target for the certificate reloader service
 | /// Reloader target for the certificate reloader service
 | ||||||
| #[derive(Debug, PartialEq, Eq, Clone, Default)] | #[derive(Debug, PartialEq, Eq, Clone, Default)] | ||||||
| pub struct ServerCryptoBase { | pub struct ServerCryptoBase { | ||||||
|   inner: HashMap<ServerNameBytesExp, CertsAndKeys>, |   inner: HashMap<ServerName, CertsAndKeys>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[async_trait] | #[async_trait] | ||||||
|  | @ -41,17 +39,15 @@ impl<T> Reload<ServerCryptoBase> for CryptoReloader<T> | ||||||
| where | where | ||||||
|   T: CryptoSource + Sync + Send, |   T: CryptoSource + Sync + Send, | ||||||
| { | { | ||||||
|   type Source = Arc<Globals<T>>; |   type Source = Arc<BackendAppManager<T>>; | ||||||
|   async fn new(source: &Self::Source) -> Result<Self, ReloaderError<ServerCryptoBase>> { |   async fn new(source: &Self::Source) -> Result<Self, ReloaderError<ServerCryptoBase>> { | ||||||
|     Ok(Self { |     Ok(Self { inner: source.clone() }) | ||||||
|       globals: source.clone(), |  | ||||||
|     }) |  | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   async fn reload(&self) -> Result<Option<ServerCryptoBase>, ReloaderError<ServerCryptoBase>> { |   async fn reload(&self) -> Result<Option<ServerCryptoBase>, ReloaderError<ServerCryptoBase>> { | ||||||
|     let mut certs_and_keys_map = ServerCryptoBase::default(); |     let mut certs_and_keys_map = ServerCryptoBase::default(); | ||||||
| 
 | 
 | ||||||
|     for (server_name_bytes_exp, backend) in self.globals.backends.apps.iter() { |     for (server_name_bytes_exp, backend) in self.inner.apps.iter() { | ||||||
|       if let Some(crypto_source) = &backend.crypto_source { |       if let Some(crypto_source) = &backend.crypto_source { | ||||||
|         let certs_and_keys = crypto_source |         let certs_and_keys = crypto_source | ||||||
|           .read() |           .read() | ||||||
|  | @ -78,7 +74,7 @@ impl TryInto<Arc<ServerCrypto>> for &ServerCryptoBase { | ||||||
|     Ok(Arc::new(ServerCrypto { |     Ok(Arc::new(ServerCrypto { | ||||||
|       #[cfg(feature = "http3-quinn")] |       #[cfg(feature = "http3-quinn")] | ||||||
|       inner_global_no_client_auth: Arc::new(server_crypto_global), |       inner_global_no_client_auth: Arc::new(server_crypto_global), | ||||||
|       #[cfg(feature = "http3-s2n")] |       #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||||
|       inner_global_no_client_auth: server_crypto_global, |       inner_global_no_client_auth: server_crypto_global, | ||||||
|       inner_local_map: Arc::new(server_crypto_local_map), |       inner_local_map: Arc::new(server_crypto_local_map), | ||||||
|     })) |     })) | ||||||
|  | @ -204,7 +200,7 @@ impl ServerCryptoBase { | ||||||
|     Ok(server_crypto_global) |     Ok(server_crypto_global) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   #[cfg(feature = "http3-s2n")] |   #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||||
|   fn build_server_crypto_global(&self) -> Result<s2n_quic_rustls::Server, ReloaderError<ServerCryptoBase>> { |   fn build_server_crypto_global(&self) -> Result<s2n_quic_rustls::Server, ReloaderError<ServerCryptoBase>> { | ||||||
|     let mut resolver_global = s2n_quic_rustls::rustls::server::ResolvesServerCertUsingSni::new(); |     let mut resolver_global = s2n_quic_rustls::rustls::server::ResolvesServerCertUsingSni::new(); | ||||||
| 
 | 
 | ||||||
|  | @ -245,7 +241,7 @@ impl ServerCryptoBase { | ||||||
|   } |   } | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[cfg(feature = "http3-s2n")] | #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||||
| /// This is workaround for the version difference between rustls and s2n-quic-rustls
 | /// This is workaround for the version difference between rustls and s2n-quic-rustls
 | ||||||
| fn parse_server_certs_and_keys_s2n( | fn parse_server_certs_and_keys_s2n( | ||||||
|   certs_and_keys: &CertsAndKeys, |   certs_and_keys: &CertsAndKeys, | ||||||
|  | @ -1,86 +1,101 @@ | ||||||
| pub use anyhow::{anyhow, bail, ensure, Context}; |  | ||||||
| use std::io; |  | ||||||
| use thiserror::Error; | use thiserror::Error; | ||||||
| 
 | 
 | ||||||
| pub type Result<T> = std::result::Result<T, RpxyError>; | pub type RpxyResult<T> = std::result::Result<T, RpxyError>; | ||||||
| 
 | 
 | ||||||
| /// Describes things that can go wrong in the Rpxy
 | /// Describes things that can go wrong in the Rpxy
 | ||||||
| #[derive(Debug, Error)] | #[derive(Debug, Error)] | ||||||
| pub enum RpxyError { | pub enum RpxyError { | ||||||
|   #[error("Proxy build error: {0}")] |   // general errors
 | ||||||
|   ProxyBuild(#[from] crate::proxy::ProxyBuilderError), |   #[error("IO error: {0}")] | ||||||
|  |   Io(#[from] std::io::Error), | ||||||
| 
 | 
 | ||||||
|   #[error("Backend build error: {0}")] |   // TLS errors
 | ||||||
|   BackendBuild(#[from] crate::backend::BackendBuilderError), |   #[error("Failed to build TLS acceptor: {0}")] | ||||||
|  |   FailedToTlsHandshake(String), | ||||||
|  |   #[error("No server name in ClientHello")] | ||||||
|  |   NoServerNameInClientHello, | ||||||
|  |   #[error("No TLS serving app: {0}")] | ||||||
|  |   NoTlsServingApp(String), | ||||||
|  |   #[error("Failed to update server crypto: {0}")] | ||||||
|  |   FailedToUpdateServerCrypto(String), | ||||||
|  |   #[error("No server crypto: {0}")] | ||||||
|  |   NoServerCrypto(String), | ||||||
| 
 | 
 | ||||||
|   #[error("MessageHandler build error: {0}")] |   // hyper errors
 | ||||||
|   HandlerBuild(#[from] crate::handler::HttpMessageHandlerBuilderError), |   #[error("hyper body manipulation error: {0}")] | ||||||
|  |   HyperBodyManipulationError(String), | ||||||
|  |   #[error("New closed in incoming-like")] | ||||||
|  |   HyperIncomingLikeNewClosed, | ||||||
|  |   #[error("New body write aborted")] | ||||||
|  |   HyperNewBodyWriteAborted, | ||||||
|  |   #[error("Hyper error in serving request or response body type: {0}")] | ||||||
|  |   HyperBodyError(#[from] hyper::Error), | ||||||
| 
 | 
 | ||||||
|   #[error("Config builder error: {0}")] |   // http/3 errors
 | ||||||
|   ConfigBuild(&'static str), |   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||||
| 
 |   #[error("H3 error: {0}")] | ||||||
|   #[error("Http Message Handler Error: {0}")] |   H3Error(#[from] h3::Error), | ||||||
|   Handler(&'static str), |   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||||
| 
 |   #[error("Exceeds max request body size for HTTP/3")] | ||||||
|   #[error("Cache Error: {0}")] |   H3TooLargeBody, | ||||||
|   Cache(&'static str), |  | ||||||
| 
 |  | ||||||
|   #[error("Http Request Message Error: {0}")] |  | ||||||
|   Request(&'static str), |  | ||||||
| 
 |  | ||||||
|   #[error("TCP/UDP Proxy Layer Error: {0}")] |  | ||||||
|   Proxy(String), |  | ||||||
| 
 |  | ||||||
|   #[allow(unused)] |  | ||||||
|   #[error("LoadBalance Layer Error: {0}")] |  | ||||||
|   LoadBalance(String), |  | ||||||
| 
 |  | ||||||
|   #[error("I/O Error: {0}")] |  | ||||||
|   Io(#[from] io::Error), |  | ||||||
| 
 |  | ||||||
|   // #[error("Toml Deserialization Error")]
 |  | ||||||
|   // TomlDe(#[from] toml::de::Error),
 |  | ||||||
|   #[cfg(feature = "http3-quinn")] |  | ||||||
|   #[error("Quic Connection Error [quinn]: {0}")] |  | ||||||
|   QuicConn(#[from] quinn::ConnectionError), |  | ||||||
| 
 |  | ||||||
|   #[cfg(feature = "http3-s2n")] |  | ||||||
|   #[error("Quic Connection Error [s2n-quic]: {0}")] |  | ||||||
|   QUicConn(#[from] s2n_quic::connection::Error), |  | ||||||
| 
 | 
 | ||||||
|   #[cfg(feature = "http3-quinn")] |   #[cfg(feature = "http3-quinn")] | ||||||
|   #[error("H3 Error [quinn]: {0}")] |   #[error("Invalid rustls TLS version: {0}")] | ||||||
|   H3(#[from] h3::Error), |   QuinnInvalidTlsProtocolVersion(String), | ||||||
|  |   #[cfg(feature = "http3-quinn")] | ||||||
|  |   #[error("Quinn connection error: {0}")] | ||||||
|  |   QuinnConnectionFailed(#[from] quinn::ConnectionError), | ||||||
| 
 | 
 | ||||||
|   #[cfg(feature = "http3-s2n")] |   #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||||
|   #[error("H3 Error [s2n-quic]: {0}")] |   #[error("s2n-quic validation error: {0}")] | ||||||
|   H3(#[from] s2n_quic_h3::h3::Error), |   S2nQuicValidationError(#[from] s2n_quic_core::transport::parameters::ValidationError), | ||||||
|  |   #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||||
|  |   #[error("s2n-quic connection error: {0}")] | ||||||
|  |   S2nQuicConnectionError(#[from] s2n_quic_core::connection::Error), | ||||||
|  |   #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||||
|  |   #[error("s2n-quic start error: {0}")] | ||||||
|  |   S2nQuicStartError(#[from] s2n_quic::provider::StartError), | ||||||
| 
 | 
 | ||||||
|   #[error("rustls Connection Error: {0}")] |   // certificate reloader errors
 | ||||||
|   Rustls(#[from] rustls::Error), |   #[error("No certificate reloader when building a proxy for TLS")] | ||||||
|  |   NoCertificateReloader, | ||||||
|  |   #[error("Certificate reload error: {0}")] | ||||||
|  |   CertificateReloadError(#[from] hot_reload::ReloaderError<crate::crypto::ServerCryptoBase>), | ||||||
| 
 | 
 | ||||||
|   #[error("Hyper Error: {0}")] |   // backend errors
 | ||||||
|   Hyper(#[from] hyper::Error), |   #[error("Invalid reverse proxy setting")] | ||||||
|  |   InvalidReverseProxyConfig, | ||||||
|  |   #[error("Invalid upstream option setting")] | ||||||
|  |   InvalidUpstreamOptionSetting, | ||||||
|  |   #[error("Failed to build backend app: {0}")] | ||||||
|  |   FailedToBuildBackendApp(#[from] crate::backend::BackendAppBuilderError), | ||||||
| 
 | 
 | ||||||
|   #[error("Hyper Http Error: {0}")] |   // Handler errors
 | ||||||
|   HyperHttp(#[from] hyper::http::Error), |   #[error("Failed to build message handler: {0}")] | ||||||
|  |   FailedToBuildMessageHandler(#[from] crate::message_handler::HttpMessageHandlerBuilderError), | ||||||
|  |   #[error("Failed to upgrade request: {0}")] | ||||||
|  |   FailedToUpgradeRequest(String), | ||||||
|  |   #[error("Failed to upgrade response: {0}")] | ||||||
|  |   FailedToUpgradeResponse(String), | ||||||
|  |   #[error("Failed to copy bidirectional for upgraded connections: {0}")] | ||||||
|  |   FailedToCopyBidirectional(String), | ||||||
| 
 | 
 | ||||||
|   #[error("Hyper Http HeaderValue Error: {0}")] |   // Forwarder errors
 | ||||||
|   HyperHeaderValue(#[from] hyper::header::InvalidHeaderValue), |   #[error("Failed to build forwarder: {0}")] | ||||||
|  |   FailedToBuildForwarder(String), | ||||||
|  |   #[error("Failed to fetch from upstream: {0}")] | ||||||
|  |   FailedToFetchFromUpstream(String), | ||||||
| 
 | 
 | ||||||
|   #[error("Hyper Http HeaderName Error: {0}")] |   // Upstream connection setting errors
 | ||||||
|   HyperHeaderName(#[from] hyper::header::InvalidHeaderName), |   #[error("Unsupported upstream option")] | ||||||
|  |   UnsupportedUpstreamOption, | ||||||
| 
 | 
 | ||||||
|   #[error(transparent)] |   // Cache error map
 | ||||||
|   Other(#[from] anyhow::Error), |   #[cfg(feature = "cache")] | ||||||
| } |   #[error("Cache error: {0}")] | ||||||
| 
 |   CacheError(#[from] crate::forwarder::CacheError), | ||||||
| #[allow(dead_code)] | 
 | ||||||
| #[derive(Debug, Error, Clone)] |   // Others
 | ||||||
| pub enum ClientCertsError { |   #[error("Infallible")] | ||||||
|   #[error("TLS Client Certificate is Required for Given SNI: {0}")] |   Infallible(#[from] std::convert::Infallible), | ||||||
|   ClientCertRequired(String), |  | ||||||
| 
 |  | ||||||
|   #[error("Inconsistent TLS Client Certificate for Given SNI: {0}")] |  | ||||||
|   InconsistentClientCert(String), |  | ||||||
| } | } | ||||||
|  |  | ||||||
							
								
								
									
										47
									
								
								rpxy-lib/src/forwarder/cache/cache_error.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								rpxy-lib/src/forwarder/cache/cache_error.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,47 @@ | ||||||
|  | use thiserror::Error; | ||||||
|  | 
 | ||||||
|  | pub(crate) type CacheResult<T> = std::result::Result<T, CacheError>; | ||||||
|  | 
 | ||||||
|  | /// Describes things that can go wrong in the Rpxy
 | ||||||
|  | #[derive(Debug, Error)] | ||||||
|  | pub enum CacheError { | ||||||
|  |   // Cache errors,
 | ||||||
|  |   #[error("Invalid null request and/or response")] | ||||||
|  |   NullRequestOrResponse, | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to acquire mutex lock for cache")] | ||||||
|  |   FailedToAcquiredMutexLockForCache, | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to acquire mutex lock for check")] | ||||||
|  |   FailedToAcquiredMutexLockForCheck, | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to create file cache")] | ||||||
|  |   FailedToCreateFileCache, | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to write file cache")] | ||||||
|  |   FailedToWriteFileCache, | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to open cache file")] | ||||||
|  |   FailedToOpenCacheFile, | ||||||
|  | 
 | ||||||
|  |   #[error("Too large to cache")] | ||||||
|  |   TooLargeToCache, | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to cache bytes: {0}")] | ||||||
|  |   FailedToCacheBytes(String), | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to send frame to cache {0}")] | ||||||
|  |   FailedToSendFrameToCache(String), | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to send frame from file cache {0}")] | ||||||
|  |   FailedToSendFrameFromCache(String), | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to remove cache file: {0}")] | ||||||
|  |   FailedToRemoveCacheFile(String), | ||||||
|  | 
 | ||||||
|  |   #[error("Invalid cache target")] | ||||||
|  |   InvalidCacheTarget, | ||||||
|  | 
 | ||||||
|  |   #[error("Hash mismatched in cache file")] | ||||||
|  |   HashMismatchedInCacheFile, | ||||||
|  | } | ||||||
							
								
								
									
										527
									
								
								rpxy-lib/src/forwarder/cache/cache_main.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										527
									
								
								rpxy-lib/src/forwarder/cache/cache_main.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,527 @@ | ||||||
|  | use super::cache_error::*; | ||||||
|  | use crate::{ | ||||||
|  |   globals::Globals, | ||||||
|  |   hyper_ext::body::{full, BoxBody, ResponseBody, UnboundedStreamBody}, | ||||||
|  |   log::*, | ||||||
|  | }; | ||||||
|  | use base64::{engine::general_purpose, Engine as _}; | ||||||
|  | use bytes::{Buf, Bytes, BytesMut}; | ||||||
|  | use futures::channel::mpsc; | ||||||
|  | use http::{Request, Response, Uri}; | ||||||
|  | use http_body_util::{BodyExt, StreamBody}; | ||||||
|  | use http_cache_semantics::CachePolicy; | ||||||
|  | use hyper::body::{Frame, Incoming}; | ||||||
|  | use lru::LruCache; | ||||||
|  | use sha2::{Digest, Sha256}; | ||||||
|  | use std::{ | ||||||
|  |   path::{Path, PathBuf}, | ||||||
|  |   sync::{ | ||||||
|  |     atomic::{AtomicUsize, Ordering}, | ||||||
|  |     Arc, Mutex, | ||||||
|  |   }, | ||||||
|  |   time::SystemTime, | ||||||
|  | }; | ||||||
|  | use tokio::{ | ||||||
|  |   fs::{self, File}, | ||||||
|  |   io::{AsyncReadExt, AsyncWriteExt}, | ||||||
|  |   sync::RwLock, | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | /* ---------------------------------------------- */ | ||||||
|  | #[derive(Clone, Debug)] | ||||||
|  | /// Cache main manager
 | ||||||
|  | pub(crate) struct RpxyCache { | ||||||
|  |   /// Inner lru cache manager storing http message caching policy
 | ||||||
|  |   inner: LruCacheManager, | ||||||
|  |   /// Managing cache file objects through RwLock's lock mechanism for file lock
 | ||||||
|  |   file_store: FileStore, | ||||||
|  |   /// Async runtime
 | ||||||
|  |   runtime_handle: tokio::runtime::Handle, | ||||||
|  |   /// Maximum size of each cache file object
 | ||||||
|  |   max_each_size: usize, | ||||||
|  |   /// Maximum size of cache object on memory
 | ||||||
|  |   max_each_size_on_memory: usize, | ||||||
|  |   /// Cache directory path
 | ||||||
|  |   cache_dir: PathBuf, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl RpxyCache { | ||||||
|  |   #[allow(unused)] | ||||||
|  |   /// Generate cache storage
 | ||||||
|  |   pub(crate) async fn new(globals: &Globals) -> Option<Self> { | ||||||
|  |     if !globals.proxy_config.cache_enabled { | ||||||
|  |       return None; | ||||||
|  |     } | ||||||
|  |     let cache_dir = globals.proxy_config.cache_dir.as_ref().unwrap(); | ||||||
|  |     let file_store = FileStore::new(&globals.runtime_handle).await; | ||||||
|  |     let inner = LruCacheManager::new(globals.proxy_config.cache_max_entry); | ||||||
|  | 
 | ||||||
|  |     let max_each_size = globals.proxy_config.cache_max_each_size; | ||||||
|  |     let mut max_each_size_on_memory = globals.proxy_config.cache_max_each_size_on_memory; | ||||||
|  |     if max_each_size < max_each_size_on_memory { | ||||||
|  |       warn!( | ||||||
|  |         "Maximum size of on memory cache per entry must be smaller than or equal to the maximum of each file cache" | ||||||
|  |       ); | ||||||
|  |       max_each_size_on_memory = max_each_size; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if let Err(e) = fs::remove_dir_all(cache_dir).await { | ||||||
|  |       warn!("Failed to clean up the cache dir: {e}"); | ||||||
|  |     }; | ||||||
|  |     fs::create_dir_all(&cache_dir).await.unwrap(); | ||||||
|  | 
 | ||||||
|  |     Some(Self { | ||||||
|  |       file_store, | ||||||
|  |       inner, | ||||||
|  |       runtime_handle: globals.runtime_handle.clone(), | ||||||
|  |       max_each_size, | ||||||
|  |       max_each_size_on_memory, | ||||||
|  |       cache_dir: cache_dir.clone(), | ||||||
|  |     }) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Count cache entries
 | ||||||
|  |   pub(crate) async fn count(&self) -> (usize, usize, usize) { | ||||||
|  |     let total = self.inner.count(); | ||||||
|  |     let file = self.file_store.count().await; | ||||||
|  |     let on_memory = total - file; | ||||||
|  |     (total, on_memory, file) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Put response into the cache
 | ||||||
|  |   pub(crate) async fn put( | ||||||
|  |     &self, | ||||||
|  |     uri: &hyper::Uri, | ||||||
|  |     mut body: Incoming, | ||||||
|  |     policy: &CachePolicy, | ||||||
|  |   ) -> CacheResult<UnboundedStreamBody> { | ||||||
|  |     let cache_manager = self.inner.clone(); | ||||||
|  |     let mut file_store = self.file_store.clone(); | ||||||
|  |     let uri = uri.clone(); | ||||||
|  |     let policy_clone = policy.clone(); | ||||||
|  |     let max_each_size = self.max_each_size; | ||||||
|  |     let max_each_size_on_memory = self.max_each_size_on_memory; | ||||||
|  |     let cache_dir = self.cache_dir.clone(); | ||||||
|  | 
 | ||||||
|  |     let (body_tx, body_rx) = mpsc::unbounded::<Result<Frame<Bytes>, hyper::Error>>(); | ||||||
|  | 
 | ||||||
|  |     self.runtime_handle.spawn(async move { | ||||||
|  |       let mut size = 0usize; | ||||||
|  |       let mut buf = BytesMut::new(); | ||||||
|  | 
 | ||||||
|  |       loop { | ||||||
|  |         let frame = match body.frame().await { | ||||||
|  |           Some(frame) => frame, | ||||||
|  |           None => { | ||||||
|  |             debug!("Response body finished"); | ||||||
|  |             break; | ||||||
|  |           } | ||||||
|  |         }; | ||||||
|  |         let frame_size = frame.as_ref().map(|f| { | ||||||
|  |           if f.is_data() { | ||||||
|  |             f.data_ref().map(|bytes| bytes.remaining()).unwrap_or_default() | ||||||
|  |           } else { | ||||||
|  |             0 | ||||||
|  |           } | ||||||
|  |         }); | ||||||
|  |         size += frame_size.unwrap_or_default(); | ||||||
|  | 
 | ||||||
|  |         // check size
 | ||||||
|  |         if size > max_each_size { | ||||||
|  |           warn!("Too large to cache"); | ||||||
|  |           return Err(CacheError::TooLargeToCache); | ||||||
|  |         } | ||||||
|  |         frame | ||||||
|  |           .as_ref() | ||||||
|  |           .map(|f| { | ||||||
|  |             if f.is_data() { | ||||||
|  |               let data_bytes = f.data_ref().unwrap().clone(); | ||||||
|  |               // debug!("cache data bytes of {} bytes", data_bytes.len());
 | ||||||
|  |               // We do not use stream-type buffering since it needs to lock file during operation.
 | ||||||
|  |               buf.extend(data_bytes.as_ref()); | ||||||
|  |             } | ||||||
|  |           }) | ||||||
|  |           .map_err(|e| CacheError::FailedToCacheBytes(e.to_string()))?; | ||||||
|  | 
 | ||||||
|  |         // send data to use response downstream
 | ||||||
|  |         body_tx | ||||||
|  |           .unbounded_send(frame) | ||||||
|  |           .map_err(|e| CacheError::FailedToSendFrameToCache(e.to_string()))?; | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       let buf = buf.freeze(); | ||||||
|  |       // Calculate hash of the cached data, after all data is received.
 | ||||||
|  |       // In-operation calculation is possible but it blocks sending data.
 | ||||||
|  |       let mut hasher = Sha256::new(); | ||||||
|  |       hasher.update(buf.as_ref()); | ||||||
|  |       let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); | ||||||
|  |       debug!("Cached data: {} bytes, hash = {:?}", size, hash_bytes); | ||||||
|  | 
 | ||||||
|  |       // Create cache object
 | ||||||
|  |       let cache_key = derive_cache_key_from_uri(&uri); | ||||||
|  |       let cache_object = CacheObject { | ||||||
|  |         policy: policy_clone, | ||||||
|  |         target: CacheFileOrOnMemory::build(&cache_dir, &uri, &buf, max_each_size_on_memory), | ||||||
|  |         hash: hash_bytes, | ||||||
|  |       }; | ||||||
|  | 
 | ||||||
|  |       if let Some((k, v)) = cache_manager.push(&cache_key, &cache_object)? { | ||||||
|  |         if k != cache_key { | ||||||
|  |           info!("Over the cache capacity. Evict least recent used entry"); | ||||||
|  |           if let CacheFileOrOnMemory::File(path) = v.target { | ||||||
|  |             file_store.evict(&path).await; | ||||||
|  |           } | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |       // store cache object to file
 | ||||||
|  |       if let CacheFileOrOnMemory::File(_) = cache_object.target { | ||||||
|  |         file_store.create(&cache_object, &buf).await?; | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       Ok(()) as CacheResult<()> | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     let stream_body = StreamBody::new(body_rx); | ||||||
|  | 
 | ||||||
|  |     Ok(stream_body) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Get cached response
 | ||||||
|  |   pub(crate) async fn get<R>(&self, req: &Request<R>) -> Option<Response<ResponseBody>> { | ||||||
|  |     debug!( | ||||||
|  |       "Current cache status: (total, on-memory, file) = {:?}", | ||||||
|  |       self.count().await | ||||||
|  |     ); | ||||||
|  |     let cache_key = derive_cache_key_from_uri(req.uri()); | ||||||
|  | 
 | ||||||
|  |     // First check cache chance
 | ||||||
|  |     let Ok(Some(cached_object)) = self.inner.get(&cache_key) else { | ||||||
|  |       return None; | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     // Secondly check the cache freshness as an HTTP message
 | ||||||
|  |     let now = SystemTime::now(); | ||||||
|  |     let http_cache_semantics::BeforeRequest::Fresh(res_parts) = cached_object.policy.before_request(req, now) else { | ||||||
|  |       // Evict stale cache entry.
 | ||||||
|  |       // This might be okay to keep as is since it would be updated later.
 | ||||||
|  |       // However, there is no guarantee that newly got objects will be still cacheable.
 | ||||||
|  |       // So, we have to evict stale cache entries and cache file objects if found.
 | ||||||
|  |       debug!("Stale cache entry: {cache_key}"); | ||||||
|  |       let _evicted_entry = self.inner.evict(&cache_key); | ||||||
|  |       // For cache file
 | ||||||
|  |       if let CacheFileOrOnMemory::File(path) = &cached_object.target { | ||||||
|  |         self.file_store.evict(&path).await; | ||||||
|  |       } | ||||||
|  |       return None; | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     // Finally retrieve the file/on-memory object
 | ||||||
|  |     let response_body = match cached_object.target { | ||||||
|  |       CacheFileOrOnMemory::File(path) => { | ||||||
|  |         let stream_body = match self.file_store.read(path.clone(), &cached_object.hash).await { | ||||||
|  |           Ok(s) => s, | ||||||
|  |           Err(e) => { | ||||||
|  |             warn!("Failed to read from file cache: {e}"); | ||||||
|  |             let _evicted_entry = self.inner.evict(&cache_key); | ||||||
|  |             self.file_store.evict(path).await; | ||||||
|  |             return None; | ||||||
|  |           } | ||||||
|  |         }; | ||||||
|  |         debug!("Cache hit from file: {cache_key}"); | ||||||
|  |         ResponseBody::Streamed(stream_body) | ||||||
|  |       } | ||||||
|  |       CacheFileOrOnMemory::OnMemory(object) => { | ||||||
|  |         debug!("Cache hit from on memory: {cache_key}"); | ||||||
|  |         let mut hasher = Sha256::new(); | ||||||
|  |         hasher.update(object.as_ref()); | ||||||
|  |         let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); | ||||||
|  |         if hash_bytes != cached_object.hash { | ||||||
|  |           warn!("Hash mismatched. Cache object is corrupted"); | ||||||
|  |           let _evicted_entry = self.inner.evict(&cache_key); | ||||||
|  |           return None; | ||||||
|  |         } | ||||||
|  |         ResponseBody::Boxed(BoxBody::new(full(object))) | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  |     Some(Response::from_parts(res_parts, response_body)) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ---------------------------------------------- */ | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | /// Cache file manager outer that is responsible to handle `RwLock`
 | ||||||
|  | struct FileStore { | ||||||
|  |   /// Inner file store main object
 | ||||||
|  |   inner: Arc<RwLock<FileStoreInner>>, | ||||||
|  | } | ||||||
|  | impl FileStore { | ||||||
|  |   #[allow(unused)] | ||||||
|  |   /// Build manager
 | ||||||
|  |   async fn new(runtime_handle: &tokio::runtime::Handle) -> Self { | ||||||
|  |     Self { | ||||||
|  |       inner: Arc::new(RwLock::new(FileStoreInner::new(runtime_handle).await)), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Count file cache entries
 | ||||||
|  |   async fn count(&self) -> usize { | ||||||
|  |     let inner = self.inner.read().await; | ||||||
|  |     inner.cnt | ||||||
|  |   } | ||||||
|  |   /// Create a temporary file cache
 | ||||||
|  |   async fn create(&mut self, cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> { | ||||||
|  |     let mut inner = self.inner.write().await; | ||||||
|  |     inner.create(cache_object, body_bytes).await | ||||||
|  |   } | ||||||
|  |   /// Evict a temporary file cache
 | ||||||
|  |   async fn evict(&self, path: impl AsRef<Path>) { | ||||||
|  |     // Acquire the write lock
 | ||||||
|  |     let mut inner = self.inner.write().await; | ||||||
|  |     if let Err(e) = inner.remove(path).await { | ||||||
|  |       warn!("Eviction failed during file object removal: {:?}", e); | ||||||
|  |     }; | ||||||
|  |   } | ||||||
|  |   /// Read a temporary file cache
 | ||||||
|  |   async fn read( | ||||||
|  |     &self, | ||||||
|  |     path: impl AsRef<Path> + Send + Sync + 'static, | ||||||
|  |     hash: &Bytes, | ||||||
|  |   ) -> CacheResult<UnboundedStreamBody> { | ||||||
|  |     let inner = self.inner.read().await; | ||||||
|  |     inner.read(path, hash).await | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | /// Manager inner for cache on file system
 | ||||||
|  | struct FileStoreInner { | ||||||
|  |   /// Counter of current cached files
 | ||||||
|  |   cnt: usize, | ||||||
|  |   /// Async runtime
 | ||||||
|  |   runtime_handle: tokio::runtime::Handle, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl FileStoreInner { | ||||||
|  |   #[allow(unused)] | ||||||
|  |   /// Build new cache file manager.
 | ||||||
|  |   /// This first creates cache file dir if not exists, and cleans up the file inside the directory.
 | ||||||
|  |   /// TODO: Persistent cache is really difficult. `sqlite` or something like that is needed.
 | ||||||
|  |   async fn new(runtime_handle: &tokio::runtime::Handle) -> Self { | ||||||
|  |     Self { | ||||||
|  |       cnt: 0, | ||||||
|  |       runtime_handle: runtime_handle.clone(), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Create a new temporary file cache
 | ||||||
|  |   async fn create(&mut self, cache_object: &CacheObject, body_bytes: &Bytes) -> CacheResult<()> { | ||||||
|  |     let cache_filepath = match cache_object.target { | ||||||
|  |       CacheFileOrOnMemory::File(ref path) => path.clone(), | ||||||
|  |       CacheFileOrOnMemory::OnMemory(_) => { | ||||||
|  |         return Err(CacheError::InvalidCacheTarget); | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  |     let Ok(mut file) = File::create(&cache_filepath).await else { | ||||||
|  |       return Err(CacheError::FailedToCreateFileCache); | ||||||
|  |     }; | ||||||
|  |     let mut bytes_clone = body_bytes.clone(); | ||||||
|  |     while bytes_clone.has_remaining() { | ||||||
|  |       if let Err(e) = file.write_buf(&mut bytes_clone).await { | ||||||
|  |         error!("Failed to write file cache: {e}"); | ||||||
|  |         return Err(CacheError::FailedToWriteFileCache); | ||||||
|  |       }; | ||||||
|  |     } | ||||||
|  |     self.cnt += 1; | ||||||
|  |     Ok(()) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Retrieve a stored temporary file cache
 | ||||||
|  |   async fn read( | ||||||
|  |     &self, | ||||||
|  |     path: impl AsRef<Path> + Send + Sync + 'static, | ||||||
|  |     hash: &Bytes, | ||||||
|  |   ) -> CacheResult<UnboundedStreamBody> { | ||||||
|  |     let Ok(mut file) = File::open(&path).await else { | ||||||
|  |       warn!("Cache file object cannot be opened"); | ||||||
|  |       return Err(CacheError::FailedToOpenCacheFile); | ||||||
|  |     }; | ||||||
|  |     let hash_clone = hash.clone(); | ||||||
|  |     let mut self_clone = self.clone(); | ||||||
|  | 
 | ||||||
|  |     let (body_tx, body_rx) = mpsc::unbounded::<Result<Frame<Bytes>, hyper::Error>>(); | ||||||
|  | 
 | ||||||
|  |     self.runtime_handle.spawn(async move { | ||||||
|  |       let mut hasher = Sha256::new(); | ||||||
|  |       let mut buf = BytesMut::new(); | ||||||
|  |       loop { | ||||||
|  |         match file.read_buf(&mut buf).await { | ||||||
|  |           Ok(0) => break, | ||||||
|  |           Ok(_) => { | ||||||
|  |             let bytes = buf.copy_to_bytes(buf.remaining()); | ||||||
|  |             hasher.update(bytes.as_ref()); | ||||||
|  |             body_tx | ||||||
|  |               .unbounded_send(Ok(Frame::data(bytes))) | ||||||
|  |               .map_err(|e| CacheError::FailedToSendFrameFromCache(e.to_string()))? | ||||||
|  |           } | ||||||
|  |           Err(_) => break, | ||||||
|  |         }; | ||||||
|  |       } | ||||||
|  |       let hash_bytes = Bytes::copy_from_slice(hasher.finalize().as_ref()); | ||||||
|  |       if hash_bytes != hash_clone { | ||||||
|  |         warn!("Hash mismatched. Cache object is corrupted. Force to remove the cache file."); | ||||||
|  |         // only file can be evicted
 | ||||||
|  |         let _evicted_entry = self_clone.remove(&path).await; | ||||||
|  |         return Err(CacheError::HashMismatchedInCacheFile); | ||||||
|  |       } | ||||||
|  |       Ok(()) as CacheResult<()> | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     let stream_body = StreamBody::new(body_rx); | ||||||
|  | 
 | ||||||
|  |     Ok(stream_body) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Remove file
 | ||||||
|  |   async fn remove(&mut self, path: impl AsRef<Path>) -> CacheResult<()> { | ||||||
|  |     fs::remove_file(path.as_ref()) | ||||||
|  |       .await | ||||||
|  |       .map_err(|e| CacheError::FailedToRemoveCacheFile(e.to_string()))?; | ||||||
|  |     self.cnt -= 1; | ||||||
|  |     debug!("Removed a cache file at {:?} (file count: {})", path.as_ref(), self.cnt); | ||||||
|  | 
 | ||||||
|  |     Ok(()) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ---------------------------------------------- */ | ||||||
|  | 
 | ||||||
|  | #[derive(Clone, Debug)] | ||||||
|  | /// Cache target in hybrid manner of on-memory and file system
 | ||||||
|  | pub(crate) enum CacheFileOrOnMemory { | ||||||
|  |   /// Pointer to the temporary cache file
 | ||||||
|  |   File(PathBuf), | ||||||
|  |   /// Cached body itself
 | ||||||
|  |   OnMemory(Bytes), | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl CacheFileOrOnMemory { | ||||||
|  |   /// Get cache object target
 | ||||||
|  |   fn build(cache_dir: &Path, uri: &Uri, object: &Bytes, max_each_size_on_memory: usize) -> Self { | ||||||
|  |     if object.len() > max_each_size_on_memory { | ||||||
|  |       let cache_filename = derive_filename_from_uri(uri); | ||||||
|  |       let cache_filepath = cache_dir.join(cache_filename); | ||||||
|  |       CacheFileOrOnMemory::File(cache_filepath) | ||||||
|  |     } else { | ||||||
|  |       CacheFileOrOnMemory::OnMemory(object.clone()) | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Clone, Debug)] | ||||||
|  | /// Cache object definition
 | ||||||
|  | struct CacheObject { | ||||||
|  |   /// Cache policy to determine if the stored cache can be used as a response to a new incoming request
 | ||||||
|  |   policy: CachePolicy, | ||||||
|  |   /// Cache target: on-memory object or temporary file
 | ||||||
|  |   target: CacheFileOrOnMemory, | ||||||
|  |   /// SHA256 hash of target to strongly bind the cache metadata (this object) and file target
 | ||||||
|  |   hash: Bytes, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ---------------------------------------------- */ | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | /// Lru cache manager that is responsible to handle `Mutex` as an outer of `LruCache`
 | ||||||
|  | struct LruCacheManager { | ||||||
|  |   /// Inner lru cache manager main object
 | ||||||
|  |   inner: Arc<Mutex<LruCache<String, CacheObject>>>, // TODO: keyはstring urlでいいのか疑問。全requestに対してcheckすることになりそう
 | ||||||
|  |   /// Counter of current cached object (total)
 | ||||||
|  |   cnt: Arc<AtomicUsize>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl LruCacheManager { | ||||||
|  |   #[allow(unused)] | ||||||
|  |   /// Build LruCache
 | ||||||
|  |   fn new(cache_max_entry: usize) -> Self { | ||||||
|  |     Self { | ||||||
|  |       inner: Arc::new(Mutex::new(LruCache::new( | ||||||
|  |         std::num::NonZeroUsize::new(cache_max_entry).unwrap(), | ||||||
|  |       ))), | ||||||
|  |       cnt: Default::default(), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Count entries
 | ||||||
|  |   fn count(&self) -> usize { | ||||||
|  |     self.cnt.load(Ordering::Relaxed) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Evict an entry
 | ||||||
|  |   fn evict(&self, cache_key: &str) -> Option<(String, CacheObject)> { | ||||||
|  |     let Ok(mut lock) = self.inner.lock() else { | ||||||
|  |       error!("Mutex can't be locked to evict a cache entry"); | ||||||
|  |       return None; | ||||||
|  |     }; | ||||||
|  |     let res = lock.pop_entry(cache_key); | ||||||
|  |     // This may be inconsistent with the actual number of entries
 | ||||||
|  |     self.cnt.store(lock.len(), Ordering::Relaxed); | ||||||
|  |     res | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Push an entry
 | ||||||
|  |   fn push(&self, cache_key: &str, cache_object: &CacheObject) -> CacheResult<Option<(String, CacheObject)>> { | ||||||
|  |     let Ok(mut lock) = self.inner.lock() else { | ||||||
|  |       error!("Failed to acquire mutex lock for writing cache entry"); | ||||||
|  |       return Err(CacheError::FailedToAcquiredMutexLockForCache); | ||||||
|  |     }; | ||||||
|  |     let res = Ok(lock.push(cache_key.to_string(), cache_object.clone())); | ||||||
|  |     // This may be inconsistent with the actual number of entries
 | ||||||
|  |     self.cnt.store(lock.len(), Ordering::Relaxed); | ||||||
|  |     res | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Get an entry
 | ||||||
|  |   fn get(&self, cache_key: &str) -> CacheResult<Option<CacheObject>> { | ||||||
|  |     let Ok(mut lock) = self.inner.lock() else { | ||||||
|  |       error!("Mutex can't be locked for checking cache entry"); | ||||||
|  |       return Err(CacheError::FailedToAcquiredMutexLockForCheck); | ||||||
|  |     }; | ||||||
|  |     let Some(cached_object) = lock.get(cache_key) else { | ||||||
|  |       return Ok(None); | ||||||
|  |     }; | ||||||
|  |     Ok(Some(cached_object.clone())) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ---------------------------------------------- */ | ||||||
|  | /// Generate cache policy if the response is cacheable
 | ||||||
|  | pub(crate) fn get_policy_if_cacheable<B1, B2>( | ||||||
|  |   req: Option<&Request<B1>>, | ||||||
|  |   res: Option<&Response<B2>>, | ||||||
|  | ) -> CacheResult<Option<CachePolicy>> | ||||||
|  | // where
 | ||||||
|  | //   B1: core::fmt::Debug,
 | ||||||
|  | { | ||||||
|  |   // deduce cache policy from req and res
 | ||||||
|  |   let (Some(req), Some(res)) = (req, res) else { | ||||||
|  |     return Err(CacheError::NullRequestOrResponse); | ||||||
|  |   }; | ||||||
|  | 
 | ||||||
|  |   let new_policy = CachePolicy::new(req, res); | ||||||
|  |   if new_policy.is_storable() { | ||||||
|  |     // debug!("Response is cacheable: {:?}\n{:?}", req, res.headers());
 | ||||||
|  |     Ok(Some(new_policy)) | ||||||
|  |   } else { | ||||||
|  |     Ok(None) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | fn derive_filename_from_uri(uri: &hyper::Uri) -> String { | ||||||
|  |   let mut hasher = Sha256::new(); | ||||||
|  |   hasher.update(uri.to_string()); | ||||||
|  |   let digest = hasher.finalize(); | ||||||
|  |   general_purpose::URL_SAFE_NO_PAD.encode(digest) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | fn derive_cache_key_from_uri(uri: &hyper::Uri) -> String { | ||||||
|  |   uri.to_string() | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								rpxy-lib/src/forwarder/cache/mod.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								rpxy-lib/src/forwarder/cache/mod.rs
									
										
									
									
										vendored
									
									
										Normal file
									
								
							|  | @ -0,0 +1,5 @@ | ||||||
|  | mod cache_error; | ||||||
|  | mod cache_main; | ||||||
|  | 
 | ||||||
|  | pub use cache_error::CacheError; | ||||||
|  | pub(crate) use cache_main::{get_policy_if_cacheable, RpxyCache}; | ||||||
							
								
								
									
										255
									
								
								rpxy-lib/src/forwarder/client.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										255
									
								
								rpxy-lib/src/forwarder/client.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,255 @@ | ||||||
|  | #[allow(unused)] | ||||||
|  | use crate::{ | ||||||
|  |   error::{RpxyError, RpxyResult}, | ||||||
|  |   globals::Globals, | ||||||
|  |   hyper_ext::{body::ResponseBody, rt::LocalExecutor}, | ||||||
|  |   log::*, | ||||||
|  | }; | ||||||
|  | use async_trait::async_trait; | ||||||
|  | use http::{Request, Response, Version}; | ||||||
|  | use hyper::body::{Body, Incoming}; | ||||||
|  | use hyper_util::client::legacy::{ | ||||||
|  |   connect::{Connect, HttpConnector}, | ||||||
|  |   Client, | ||||||
|  | }; | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "cache")] | ||||||
|  | use super::cache::{get_policy_if_cacheable, RpxyCache}; | ||||||
|  | 
 | ||||||
|  | #[async_trait] | ||||||
|  | /// Definition of the forwarder that simply forward requests from downstream client to upstream app servers.
 | ||||||
|  | pub trait ForwardRequest<B1, B2> { | ||||||
|  |   type Error; | ||||||
|  |   async fn request(&self, req: Request<B1>) -> Result<Response<B2>, Self::Error>; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// Forwarder http client struct responsible to cache handling
 | ||||||
|  | pub struct Forwarder<C, B> { | ||||||
|  |   #[cfg(feature = "cache")] | ||||||
|  |   cache: Option<RpxyCache>, | ||||||
|  |   inner: Client<C, B>, | ||||||
|  |   inner_h2: Client<C, B>, // `h2c` or http/2-only client is defined separately
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[async_trait] | ||||||
|  | impl<C, B1> ForwardRequest<B1, ResponseBody> for Forwarder<C, B1> | ||||||
|  | where | ||||||
|  |   C: Send + Sync + Connect + Clone + 'static, | ||||||
|  |   B1: Body + Send + Sync + Unpin + 'static, | ||||||
|  |   <B1 as Body>::Data: Send, | ||||||
|  |   <B1 as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>, | ||||||
|  | { | ||||||
|  |   type Error = RpxyError; | ||||||
|  | 
 | ||||||
|  |   async fn request(&self, req: Request<B1>) -> Result<Response<ResponseBody>, Self::Error> { | ||||||
|  |     // TODO: cache handling
 | ||||||
|  |     #[cfg(feature = "cache")] | ||||||
|  |     { | ||||||
|  |       let mut synth_req = None; | ||||||
|  |       if self.cache.is_some() { | ||||||
|  |         // try reading from cache
 | ||||||
|  |         if let Some(cached_response) = self.cache.as_ref().unwrap().get(&req).await { | ||||||
|  |           // if found, return it as response.
 | ||||||
|  |           info!("Cache hit - Return from cache"); | ||||||
|  |           return Ok(cached_response); | ||||||
|  |         }; | ||||||
|  | 
 | ||||||
|  |         // Synthetic request copy used just for caching (cannot clone request object...)
 | ||||||
|  |         synth_req = Some(build_synth_req_for_cache(&req)); | ||||||
|  |       } | ||||||
|  |       let res = self.request_directly(req).await; | ||||||
|  | 
 | ||||||
|  |       if self.cache.is_none() { | ||||||
|  |         return res.map(|inner| inner.map(ResponseBody::Incoming)); | ||||||
|  |       } | ||||||
|  | 
 | ||||||
|  |       // check cacheability and store it if cacheable
 | ||||||
|  |       let Ok(Some(cache_policy)) = get_policy_if_cacheable(synth_req.as_ref(), res.as_ref().ok()) else { | ||||||
|  |         return res.map(|inner| inner.map(ResponseBody::Incoming)); | ||||||
|  |       }; | ||||||
|  |       let (parts, body) = res.unwrap().into_parts(); | ||||||
|  | 
 | ||||||
|  |       // Get streamed body without waiting for the arrival of the body,
 | ||||||
|  |       // which is done simultaneously with caching.
 | ||||||
|  |       let stream_body = self | ||||||
|  |         .cache | ||||||
|  |         .as_ref() | ||||||
|  |         .unwrap() | ||||||
|  |         .put(synth_req.unwrap().uri(), body, &cache_policy) | ||||||
|  |         .await?; | ||||||
|  | 
 | ||||||
|  |       // response with body being cached in background
 | ||||||
|  |       let new_res = Response::from_parts(parts, ResponseBody::Streamed(stream_body)); | ||||||
|  |       Ok(new_res) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // No cache handling
 | ||||||
|  |     #[cfg(not(feature = "cache"))] | ||||||
|  |     { | ||||||
|  |       self | ||||||
|  |         .request_directly(req) | ||||||
|  |         .await | ||||||
|  |         .map(|inner| inner.map(ResponseBody::Incoming)) | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<C, B1> Forwarder<C, B1> | ||||||
|  | where | ||||||
|  |   C: Send + Sync + Connect + Clone + 'static, | ||||||
|  |   B1: Body + Send + Unpin + 'static, | ||||||
|  |   <B1 as Body>::Data: Send, | ||||||
|  |   <B1 as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>, | ||||||
|  | { | ||||||
|  |   async fn request_directly(&self, req: Request<B1>) -> RpxyResult<Response<Incoming>> { | ||||||
|  |     // TODO: This 'match' condition is always evaluated at every 'request' invocation. So, it is inefficient.
 | ||||||
|  |     // Needs to be reconsidered. Currently, this is a kind of work around.
 | ||||||
|  |     // This possibly relates to https://github.com/hyperium/hyper/issues/2417.
 | ||||||
|  |     match req.version() { | ||||||
|  |       Version::HTTP_2 => self.inner_h2.request(req).await, // handles `h2c` requests
 | ||||||
|  |       _ => self.inner.request(req).await, | ||||||
|  |     } | ||||||
|  |     .map_err(|e| RpxyError::FailedToFetchFromUpstream(e.to_string())) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(not(any(feature = "native-tls-backend", feature = "rustls-backend")))] | ||||||
|  | impl<B> Forwarder<HttpConnector, B> | ||||||
|  | where | ||||||
|  |   B: Body + Send + Unpin + 'static, | ||||||
|  |   <B as Body>::Data: Send, | ||||||
|  |   <B as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>, | ||||||
|  | { | ||||||
|  |   /// Build inner client with http
 | ||||||
|  |   pub async fn try_new(_globals: &Arc<Globals>) -> RpxyResult<Self> { | ||||||
|  |     warn!( | ||||||
|  |       " | ||||||
|  | -------------------------------------------------------------------------------------------------- | ||||||
|  | Request forwarder is working without TLS support!!! | ||||||
|  | We recommend to use this just for testing. | ||||||
|  | Please enable native-tls-backend or rustls-backend feature to enable TLS support. | ||||||
|  | --------------------------------------------------------------------------------------------------" | ||||||
|  |     ); | ||||||
|  |     let executor = LocalExecutor::new(_globals.runtime_handle.clone()); | ||||||
|  |     let mut http = HttpConnector::new(); | ||||||
|  |     http.enforce_http(true); | ||||||
|  |     http.set_reuse_address(true); | ||||||
|  |     http.set_keepalive(Some(_globals.proxy_config.upstream_idle_timeout)); | ||||||
|  |     let inner = Client::builder(executor).build::<_, B>(http); | ||||||
|  |     let inner_h2 = inner.clone(); | ||||||
|  | 
 | ||||||
|  |     Ok(Self { | ||||||
|  |       inner, | ||||||
|  |       inner_h2, | ||||||
|  |       #[cfg(feature = "cache")] | ||||||
|  |       cache: RpxyCache::new(_globals).await, | ||||||
|  |     }) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(all(feature = "native-tls-backend", not(feature = "rustls-backend")))] | ||||||
|  | /// Build forwarder with hyper-tls (native-tls)
 | ||||||
|  | impl<B1> Forwarder<hyper_tls::HttpsConnector<HttpConnector>, B1> | ||||||
|  | where | ||||||
|  |   B1: Body + Send + Unpin + 'static, | ||||||
|  |   <B1 as Body>::Data: Send, | ||||||
|  |   <B1 as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>, | ||||||
|  | { | ||||||
|  |   /// Build forwarder
 | ||||||
|  |   pub async fn try_new(_globals: &Arc<Globals>) -> RpxyResult<Self> { | ||||||
|  |     // build hyper client with hyper-tls
 | ||||||
|  |     info!("Native TLS support is enabled for the connection to backend applications"); | ||||||
|  |     let executor = LocalExecutor::new(_globals.runtime_handle.clone()); | ||||||
|  | 
 | ||||||
|  |     let try_build_connector = |alpns: &[&str]| { | ||||||
|  |       hyper_tls::native_tls::TlsConnector::builder() | ||||||
|  |         .request_alpns(alpns) | ||||||
|  |         .build() | ||||||
|  |         .map_err(|e| RpxyError::FailedToBuildForwarder(e.to_string())) | ||||||
|  |         .map(|tls| { | ||||||
|  |           let mut http = HttpConnector::new(); | ||||||
|  |           http.enforce_http(false); | ||||||
|  |           http.set_reuse_address(true); | ||||||
|  |           http.set_keepalive(Some(_globals.proxy_config.upstream_idle_timeout)); | ||||||
|  |           hyper_tls::HttpsConnector::from((http, tls.into())) | ||||||
|  |         }) | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     let connector = try_build_connector(&["h2", "http/1.1"])?; | ||||||
|  |     let inner = Client::builder(executor.clone()).build::<_, B1>(connector); | ||||||
|  | 
 | ||||||
|  |     let connector_h2 = try_build_connector(&["h2"])?; | ||||||
|  |     let inner_h2 = Client::builder(executor.clone()) | ||||||
|  |       .http2_only(true) | ||||||
|  |       .build::<_, B1>(connector_h2); | ||||||
|  | 
 | ||||||
|  |     Ok(Self { | ||||||
|  |       inner, | ||||||
|  |       inner_h2, | ||||||
|  |       #[cfg(feature = "cache")] | ||||||
|  |       cache: RpxyCache::new(_globals).await, | ||||||
|  |     }) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "rustls-backend")] | ||||||
|  | /// Build forwarder with hyper-rustls (rustls)
 | ||||||
|  | impl<B1> Forwarder<hyper_rustls::HttpsConnector<HttpConnector>, B1> | ||||||
|  | where | ||||||
|  |   B1: Body + Send + Unpin + 'static, | ||||||
|  |   <B1 as Body>::Data: Send, | ||||||
|  |   <B1 as Body>::Error: Into<Box<(dyn std::error::Error + Send + Sync + 'static)>>, | ||||||
|  | { | ||||||
|  |   /// Build forwarder
 | ||||||
|  |   pub async fn try_new(_globals: &Arc<Globals>) -> RpxyResult<Self> { | ||||||
|  |     // build hyper client with rustls and webpki, only https is allowed
 | ||||||
|  |     #[cfg(feature = "rustls-backend-webpki")] | ||||||
|  |     let builder = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots(); | ||||||
|  |     #[cfg(feature = "rustls-backend-webpki")] | ||||||
|  |     let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots(); | ||||||
|  |     #[cfg(feature = "rustls-backend-webpki")] | ||||||
|  |     info!("Mozilla WebPKI root certs with rustls is used for the connection to backend applications"); | ||||||
|  | 
 | ||||||
|  |     #[cfg(not(feature = "rustls-backend-webpki"))] | ||||||
|  |     let builder = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots()?; | ||||||
|  |     #[cfg(not(feature = "rustls-backend-webpki"))] | ||||||
|  |     let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots()?; | ||||||
|  |     #[cfg(not(feature = "rustls-backend-webpki"))] | ||||||
|  |     info!("Native cert store with rustls is used for the connection to backend applications"); | ||||||
|  | 
 | ||||||
|  |     let mut http = HttpConnector::new(); | ||||||
|  |     http.enforce_http(false); | ||||||
|  |     http.set_reuse_address(true); | ||||||
|  |     http.set_keepalive(Some(_globals.proxy_config.upstream_idle_timeout)); | ||||||
|  | 
 | ||||||
|  |     let connector = builder | ||||||
|  |       .https_or_http() | ||||||
|  |       .enable_all_versions() | ||||||
|  |       .wrap_connector(http.clone()); | ||||||
|  |     let connector_h2 = builder_h2.https_or_http().enable_http2().wrap_connector(http); | ||||||
|  |     let inner = Client::builder(LocalExecutor::new(_globals.runtime_handle.clone())).build::<_, B1>(connector); | ||||||
|  |     let inner_h2 = Client::builder(LocalExecutor::new(_globals.runtime_handle.clone())).build::<_, B1>(connector_h2); | ||||||
|  | 
 | ||||||
|  |     Ok(Self { | ||||||
|  |       inner, | ||||||
|  |       inner_h2, | ||||||
|  |       #[cfg(feature = "cache")] | ||||||
|  |       cache: RpxyCache::new(_globals).await, | ||||||
|  |     }) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "cache")] | ||||||
|  | /// Build synthetic request to cache
 | ||||||
|  | fn build_synth_req_for_cache<T>(req: &Request<T>) -> Request<()> { | ||||||
|  |   let mut builder = Request::builder() | ||||||
|  |     .method(req.method()) | ||||||
|  |     .uri(req.uri()) | ||||||
|  |     .version(req.version()); | ||||||
|  |   // TODO: omit extensions. is this approach correct?
 | ||||||
|  |   for (header_key, header_value) in req.headers() { | ||||||
|  |     builder = builder.header(header_key, header_value); | ||||||
|  |   } | ||||||
|  |   builder.body(()).unwrap() | ||||||
|  | } | ||||||
							
								
								
									
										11
									
								
								rpxy-lib/src/forwarder/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								rpxy-lib/src/forwarder/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,11 @@ | ||||||
|  | #[cfg(feature = "cache")] | ||||||
|  | mod cache; | ||||||
|  | mod client; | ||||||
|  | 
 | ||||||
|  | use crate::hyper_ext::body::RequestBody; | ||||||
|  | 
 | ||||||
|  | pub(crate) type Forwarder<C> = client::Forwarder<C, RequestBody>; | ||||||
|  | pub(crate) use client::ForwardRequest; | ||||||
|  | 
 | ||||||
|  | #[cfg(feature = "cache")] | ||||||
|  | pub(crate) use cache::CacheError; | ||||||
|  | @ -1,57 +1,53 @@ | ||||||
| use crate::{ | use crate::{ | ||||||
|   backend::{ |  | ||||||
|     Backend, BackendBuilder, Backends, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption, |  | ||||||
|   }, |  | ||||||
|   certs::CryptoSource, |  | ||||||
|   constants::*, |   constants::*, | ||||||
|   error::RpxyError, |   count::RequestCount, | ||||||
|   log::*, |   crypto::{CryptoSource, ServerCryptoBase}, | ||||||
|   utils::{BytesName, PathNameBytesExp}, |  | ||||||
| }; | }; | ||||||
| use rustc_hash::FxHashMap as HashMap; | use hot_reload::ReloaderReceiver; | ||||||
| use std::net::SocketAddr; | use std::{net::SocketAddr, sync::Arc, time::Duration}; | ||||||
| use std::sync::{ |  | ||||||
|   atomic::{AtomicUsize, Ordering}, |  | ||||||
|   Arc, |  | ||||||
| }; |  | ||||||
| use tokio::time::Duration; |  | ||||||
| 
 | 
 | ||||||
| /// Global object containing proxy configurations and shared object like counters.
 | /// Global object containing proxy configurations and shared object like counters.
 | ||||||
| /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks.
 | /// But note that in Globals, we do not have Mutex and RwLock. It is indeed, the context shared among async tasks.
 | ||||||
| pub struct Globals<T> | pub struct Globals { | ||||||
| where |  | ||||||
|   T: CryptoSource, |  | ||||||
| { |  | ||||||
|   /// Configuration parameters for proxy transport and request handlers
 |   /// Configuration parameters for proxy transport and request handlers
 | ||||||
|   pub proxy_config: ProxyConfig, // TODO: proxy configはarcに包んでこいつだけ使いまわせばいいように変えていく。backendsも?
 |   pub proxy_config: ProxyConfig, | ||||||
| 
 |  | ||||||
|   /// Backend application objects to which http request handler forward incoming requests
 |  | ||||||
|   pub backends: Backends<T>, |  | ||||||
| 
 |  | ||||||
|   /// Shared context - Counter for serving requests
 |   /// Shared context - Counter for serving requests
 | ||||||
|   pub request_count: RequestCount, |   pub request_count: RequestCount, | ||||||
| 
 |  | ||||||
|   /// Shared context - Async task runtime handler
 |   /// Shared context - Async task runtime handler
 | ||||||
|   pub runtime_handle: tokio::runtime::Handle, |   pub runtime_handle: tokio::runtime::Handle, | ||||||
|  |   /// Shared context - Notify object to stop async tasks
 | ||||||
|  |   pub term_notify: Option<Arc<tokio::sync::Notify>>, | ||||||
|  |   /// Shared context - Certificate reloader service receiver
 | ||||||
|  |   pub cert_reloader_rx: Option<ReloaderReceiver<ServerCryptoBase>>, | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// Configuration parameters for proxy transport and request handlers
 | /// Configuration parameters for proxy transport and request handlers
 | ||||||
| #[derive(PartialEq, Eq, Clone)] | #[derive(PartialEq, Eq, Clone)] | ||||||
| pub struct ProxyConfig { | pub struct ProxyConfig { | ||||||
|   pub listen_sockets: Vec<SocketAddr>, // when instantiate server
 |   /// listen socket addresses
 | ||||||
|   pub http_port: Option<u16>,          // when instantiate server
 |   pub listen_sockets: Vec<SocketAddr>, | ||||||
|   pub https_port: Option<u16>,         // when instantiate server
 |   /// http port
 | ||||||
|   pub tcp_listen_backlog: u32,         // when instantiate server
 |   pub http_port: Option<u16>, | ||||||
|  |   /// https port
 | ||||||
|  |   pub https_port: Option<u16>, | ||||||
|  |   /// tcp listen backlog
 | ||||||
|  |   pub tcp_listen_backlog: u32, | ||||||
| 
 | 
 | ||||||
|   pub proxy_timeout: Duration,    // when serving requests at Proxy
 |   /// Idle timeout as an HTTP server, used as the keep alive interval and timeout for reading request header
 | ||||||
|   pub upstream_timeout: Duration, // when serving requests at Handler
 |   pub proxy_idle_timeout: Duration, | ||||||
|  |   /// Idle timeout as an HTTP client, used as the keep alive interval for upstream connections
 | ||||||
|  |   pub upstream_idle_timeout: Duration, | ||||||
| 
 | 
 | ||||||
|   pub max_clients: usize,          // when serving requests
 |   pub max_clients: usize,          // when serving requests
 | ||||||
|   pub max_concurrent_streams: u32, // when instantiate server
 |   pub max_concurrent_streams: u32, // when instantiate server
 | ||||||
|   pub keepalive: bool,             // when instantiate server
 |   pub keepalive: bool,             // when instantiate server
 | ||||||
| 
 | 
 | ||||||
|   // experimentals
 |   // experimentals
 | ||||||
|  |   /// SNI consistency check
 | ||||||
|   pub sni_consistency: bool, // Handler
 |   pub sni_consistency: bool, // Handler
 | ||||||
|  |   /// Connection handling timeout
 | ||||||
|  |   /// timeout to handle a connection, total time of receive request, serve, and send response. this might limits the max length of response.
 | ||||||
|  |   pub connection_handling_timeout: Option<Duration>, | ||||||
| 
 | 
 | ||||||
|   #[cfg(feature = "cache")] |   #[cfg(feature = "cache")] | ||||||
|   pub cache_enabled: bool, |   pub cache_enabled: bool, | ||||||
|  | @ -90,14 +86,15 @@ impl Default for ProxyConfig { | ||||||
|       tcp_listen_backlog: TCP_LISTEN_BACKLOG, |       tcp_listen_backlog: TCP_LISTEN_BACKLOG, | ||||||
| 
 | 
 | ||||||
|       // TODO: Reconsider each timeout values
 |       // TODO: Reconsider each timeout values
 | ||||||
|       proxy_timeout: Duration::from_secs(PROXY_TIMEOUT_SEC), |       proxy_idle_timeout: Duration::from_secs(PROXY_IDLE_TIMEOUT_SEC), | ||||||
|       upstream_timeout: Duration::from_secs(UPSTREAM_TIMEOUT_SEC), |       upstream_idle_timeout: Duration::from_secs(UPSTREAM_IDLE_TIMEOUT_SEC), | ||||||
| 
 | 
 | ||||||
|       max_clients: MAX_CLIENTS, |       max_clients: MAX_CLIENTS, | ||||||
|       max_concurrent_streams: MAX_CONCURRENT_STREAMS, |       max_concurrent_streams: MAX_CONCURRENT_STREAMS, | ||||||
|       keepalive: true, |       keepalive: true, | ||||||
| 
 | 
 | ||||||
|       sni_consistency: true, |       sni_consistency: true, | ||||||
|  |       connection_handling_timeout: None, | ||||||
| 
 | 
 | ||||||
|       #[cfg(feature = "cache")] |       #[cfg(feature = "cache")] | ||||||
|       cache_enabled: false, |       cache_enabled: false, | ||||||
|  | @ -137,44 +134,6 @@ where | ||||||
|   pub inner: Vec<AppConfig<T>>, |   pub inner: Vec<AppConfig<T>>, | ||||||
|   pub default_app: Option<String>, |   pub default_app: Option<String>, | ||||||
| } | } | ||||||
| impl<T> TryInto<Backends<T>> for AppConfigList<T> |  | ||||||
| where |  | ||||||
|   T: CryptoSource + Clone, |  | ||||||
| { |  | ||||||
|   type Error = RpxyError; |  | ||||||
| 
 |  | ||||||
|   fn try_into(self) -> Result<Backends<T>, Self::Error> { |  | ||||||
|     let mut backends = Backends::new(); |  | ||||||
|     for app_config in self.inner.iter() { |  | ||||||
|       let backend = app_config.try_into()?; |  | ||||||
|       backends |  | ||||||
|         .apps |  | ||||||
|         .insert(app_config.server_name.clone().to_server_name_vec(), backend); |  | ||||||
|       info!( |  | ||||||
|         "Registering application {} ({})", |  | ||||||
|         &app_config.server_name, &app_config.app_name |  | ||||||
|       ); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // default backend application for plaintext http requests
 |  | ||||||
|     if let Some(d) = self.default_app { |  | ||||||
|       let d_sn: Vec<&str> = backends |  | ||||||
|         .apps |  | ||||||
|         .iter() |  | ||||||
|         .filter(|(_k, v)| v.app_name == d) |  | ||||||
|         .map(|(_, v)| v.server_name.as_ref()) |  | ||||||
|         .collect(); |  | ||||||
|       if !d_sn.is_empty() { |  | ||||||
|         info!( |  | ||||||
|           "Serving plaintext http for requests to unconfigured server_name by app {} (server_name: {}).", |  | ||||||
|           d, d_sn[0] |  | ||||||
|         ); |  | ||||||
|         backends.default_server_name_bytes = Some(d_sn[0].to_server_name_vec()); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|     Ok(backends) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| /// Configuration parameters for single backend application
 | /// Configuration parameters for single backend application
 | ||||||
| #[derive(PartialEq, Eq, Clone)] | #[derive(PartialEq, Eq, Clone)] | ||||||
|  | @ -187,77 +146,6 @@ where | ||||||
|   pub reverse_proxy: Vec<ReverseProxyConfig>, |   pub reverse_proxy: Vec<ReverseProxyConfig>, | ||||||
|   pub tls: Option<TlsConfig<T>>, |   pub tls: Option<TlsConfig<T>>, | ||||||
| } | } | ||||||
| impl<T> TryInto<Backend<T>> for &AppConfig<T> |  | ||||||
| where |  | ||||||
|   T: CryptoSource + Clone, |  | ||||||
| { |  | ||||||
|   type Error = RpxyError; |  | ||||||
| 
 |  | ||||||
|   fn try_into(self) -> Result<Backend<T>, Self::Error> { |  | ||||||
|     // backend builder
 |  | ||||||
|     let mut backend_builder = BackendBuilder::default(); |  | ||||||
|     // reverse proxy settings
 |  | ||||||
|     let reverse_proxy = self.try_into()?; |  | ||||||
| 
 |  | ||||||
|     backend_builder |  | ||||||
|       .app_name(self.app_name.clone()) |  | ||||||
|       .server_name(self.server_name.clone()) |  | ||||||
|       .reverse_proxy(reverse_proxy); |  | ||||||
| 
 |  | ||||||
|     // TLS settings and build backend instance
 |  | ||||||
|     let backend = if self.tls.is_none() { |  | ||||||
|       backend_builder.build().map_err(RpxyError::BackendBuild)? |  | ||||||
|     } else { |  | ||||||
|       let tls = self.tls.as_ref().unwrap(); |  | ||||||
| 
 |  | ||||||
|       backend_builder |  | ||||||
|         .https_redirection(Some(tls.https_redirection)) |  | ||||||
|         .crypto_source(Some(tls.inner.clone())) |  | ||||||
|         .build()? |  | ||||||
|     }; |  | ||||||
|     Ok(backend) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| impl<T> TryInto<ReverseProxy> for &AppConfig<T> |  | ||||||
| where |  | ||||||
|   T: CryptoSource + Clone, |  | ||||||
| { |  | ||||||
|   type Error = RpxyError; |  | ||||||
| 
 |  | ||||||
|   fn try_into(self) -> Result<ReverseProxy, Self::Error> { |  | ||||||
|     let mut upstream: HashMap<PathNameBytesExp, UpstreamGroup> = HashMap::default(); |  | ||||||
| 
 |  | ||||||
|     self.reverse_proxy.iter().for_each(|rpo| { |  | ||||||
|       let upstream_vec: Vec<Upstream> = rpo.upstream.iter().map(|x| x.try_into().unwrap()).collect(); |  | ||||||
|       // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap());
 |  | ||||||
|       // let lb_upstream_num = vec_upstream.len();
 |  | ||||||
|       let elem = UpstreamGroupBuilder::default() |  | ||||||
|         .upstream(&upstream_vec) |  | ||||||
|         .path(&rpo.path) |  | ||||||
|         .replace_path(&rpo.replace_path) |  | ||||||
|         .lb(&rpo.load_balance, &upstream_vec, &self.server_name, &rpo.path) |  | ||||||
|         .opts(&rpo.upstream_options) |  | ||||||
|         .build() |  | ||||||
|         .unwrap(); |  | ||||||
| 
 |  | ||||||
|       upstream.insert(elem.path.clone(), elem); |  | ||||||
|     }); |  | ||||||
|     if self.reverse_proxy.iter().filter(|rpo| rpo.path.is_none()).count() >= 2 { |  | ||||||
|       error!("Multiple default reverse proxy setting"); |  | ||||||
|       return Err(RpxyError::ConfigBuild("Invalid reverse proxy setting")); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     if !(upstream.iter().all(|(_, elem)| { |  | ||||||
|       !(elem.opts.contains(&UpstreamOption::ForceHttp11Upstream) |  | ||||||
|         && elem.opts.contains(&UpstreamOption::ForceHttp2Upstream)) |  | ||||||
|     })) { |  | ||||||
|       error!("Either one of force_http11 or force_http2 can be enabled"); |  | ||||||
|       return Err(RpxyError::ConfigBuild("Invalid upstream option setting")); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     Ok(ReverseProxy { upstream }) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 | 
 | ||||||
| /// Configuration parameters for single reverse proxy corresponding to the path
 | /// Configuration parameters for single reverse proxy corresponding to the path
 | ||||||
| #[derive(PartialEq, Eq, Clone)] | #[derive(PartialEq, Eq, Clone)] | ||||||
|  | @ -272,16 +160,7 @@ pub struct ReverseProxyConfig { | ||||||
| /// Configuration parameters for single upstream destination from a reverse proxy
 | /// Configuration parameters for single upstream destination from a reverse proxy
 | ||||||
| #[derive(PartialEq, Eq, Clone)] | #[derive(PartialEq, Eq, Clone)] | ||||||
| pub struct UpstreamUri { | pub struct UpstreamUri { | ||||||
|   pub inner: hyper::Uri, |   pub inner: http::Uri, | ||||||
| } |  | ||||||
| impl TryInto<Upstream> for &UpstreamUri { |  | ||||||
|   type Error = anyhow::Error; |  | ||||||
| 
 |  | ||||||
|   fn try_into(self) -> std::result::Result<Upstream, Self::Error> { |  | ||||||
|     Ok(Upstream { |  | ||||||
|       uri: self.inner.clone(), |  | ||||||
|     }) |  | ||||||
|   } |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| /// Configuration parameters on TLS for a single backend application
 | /// Configuration parameters on TLS for a single backend application
 | ||||||
|  | @ -293,30 +172,3 @@ where | ||||||
|   pub inner: T, |   pub inner: T, | ||||||
|   pub https_redirection: bool, |   pub https_redirection: bool, | ||||||
| } | } | ||||||
| 
 |  | ||||||
| #[derive(Debug, Clone, Default)] |  | ||||||
| /// Counter for serving requests
 |  | ||||||
| pub struct RequestCount(Arc<AtomicUsize>); |  | ||||||
| 
 |  | ||||||
| impl RequestCount { |  | ||||||
|   pub fn current(&self) -> usize { |  | ||||||
|     self.0.load(Ordering::Relaxed) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   pub fn increment(&self) -> usize { |  | ||||||
|     self.0.fetch_add(1, Ordering::Relaxed) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   pub fn decrement(&self) -> usize { |  | ||||||
|     let mut count; |  | ||||||
|     while { |  | ||||||
|       count = self.0.load(Ordering::Relaxed); |  | ||||||
|       count > 0 |  | ||||||
|         && self |  | ||||||
|           .0 |  | ||||||
|           .compare_exchange(count, count - 1, Ordering::Relaxed, Ordering::Relaxed) |  | ||||||
|           != Ok(count) |  | ||||||
|     } {} |  | ||||||
|     count |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -1,393 +0,0 @@ | ||||||
| use crate::{error::*, globals::Globals, log::*, CryptoSource}; |  | ||||||
| use base64::{engine::general_purpose, Engine as _}; |  | ||||||
| use bytes::{Buf, Bytes, BytesMut}; |  | ||||||
| use http_cache_semantics::CachePolicy; |  | ||||||
| use hyper::{ |  | ||||||
|   http::{Request, Response}, |  | ||||||
|   Body, |  | ||||||
| }; |  | ||||||
| use lru::LruCache; |  | ||||||
| use sha2::{Digest, Sha256}; |  | ||||||
| use std::{ |  | ||||||
|   fmt::Debug, |  | ||||||
|   path::{Path, PathBuf}, |  | ||||||
|   sync::{ |  | ||||||
|     atomic::{AtomicUsize, Ordering}, |  | ||||||
|     Arc, Mutex, |  | ||||||
|   }, |  | ||||||
|   time::SystemTime, |  | ||||||
| }; |  | ||||||
| use tokio::{ |  | ||||||
|   fs::{self, File}, |  | ||||||
|   io::{AsyncReadExt, AsyncWriteExt}, |  | ||||||
|   sync::RwLock, |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| #[derive(Clone, Debug)] |  | ||||||
| /// Cache target in hybrid manner of on-memory and file system
 |  | ||||||
| pub enum CacheFileOrOnMemory { |  | ||||||
|   /// Pointer to the temporary cache file
 |  | ||||||
|   File(PathBuf), |  | ||||||
|   /// Cached body itself
 |  | ||||||
|   OnMemory(Vec<u8>), |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[derive(Clone, Debug)] |  | ||||||
| /// Cache object definition
 |  | ||||||
| struct CacheObject { |  | ||||||
|   /// Cache policy to determine if the stored cache can be used as a response to a new incoming request
 |  | ||||||
|   pub policy: CachePolicy, |  | ||||||
|   /// Cache target: on-memory object or temporary file
 |  | ||||||
|   pub target: CacheFileOrOnMemory, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[derive(Debug)] |  | ||||||
| /// Manager inner for cache on file system
 |  | ||||||
| struct CacheFileManagerInner { |  | ||||||
|   /// Directory of temporary files
 |  | ||||||
|   cache_dir: PathBuf, |  | ||||||
|   /// Counter of current cached files
 |  | ||||||
|   cnt: usize, |  | ||||||
|   /// Async runtime
 |  | ||||||
|   runtime_handle: tokio::runtime::Handle, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl CacheFileManagerInner { |  | ||||||
|   /// Build new cache file manager.
 |  | ||||||
|   /// This first creates cache file dir if not exists, and cleans up the file inside the directory.
 |  | ||||||
|   /// TODO: Persistent cache is really difficult. `sqlite` or something like that is needed.
 |  | ||||||
|   async fn new(path: impl AsRef<Path>, runtime_handle: &tokio::runtime::Handle) -> Self { |  | ||||||
|     let path_buf = path.as_ref().to_path_buf(); |  | ||||||
|     if let Err(e) = fs::remove_dir_all(path).await { |  | ||||||
|       warn!("Failed to clean up the cache dir: {e}"); |  | ||||||
|     }; |  | ||||||
|     fs::create_dir_all(&path_buf).await.unwrap(); |  | ||||||
|     Self { |  | ||||||
|       cache_dir: path_buf.clone(), |  | ||||||
|       cnt: 0, |  | ||||||
|       runtime_handle: runtime_handle.clone(), |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Create a new temporary file cache
 |  | ||||||
|   async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> Result<CacheFileOrOnMemory> { |  | ||||||
|     let cache_filepath = self.cache_dir.join(cache_filename); |  | ||||||
|     let Ok(mut file) = File::create(&cache_filepath).await else { |  | ||||||
|       return Err(RpxyError::Cache("Failed to create file")); |  | ||||||
|     }; |  | ||||||
|     let mut bytes_clone = body_bytes.clone(); |  | ||||||
|     while bytes_clone.has_remaining() { |  | ||||||
|       if let Err(e) = file.write_buf(&mut bytes_clone).await { |  | ||||||
|         error!("Failed to write file cache: {e}"); |  | ||||||
|         return Err(RpxyError::Cache("Failed to write file cache: {e}")); |  | ||||||
|       }; |  | ||||||
|     } |  | ||||||
|     self.cnt += 1; |  | ||||||
|     Ok(CacheFileOrOnMemory::File(cache_filepath)) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Retrieve a stored temporary file cache
 |  | ||||||
|   async fn read(&self, path: impl AsRef<Path>) -> Result<Body> { |  | ||||||
|     let Ok(mut file) = File::open(&path).await else { |  | ||||||
|       warn!("Cache file object cannot be opened"); |  | ||||||
|       return Err(RpxyError::Cache("Cache file object cannot be opened")); |  | ||||||
|     }; |  | ||||||
|     let (body_sender, res_body) = Body::channel(); |  | ||||||
|     self.runtime_handle.spawn(async move { |  | ||||||
|       let mut sender = body_sender; |  | ||||||
|       let mut buf = BytesMut::new(); |  | ||||||
|       loop { |  | ||||||
|         match file.read_buf(&mut buf).await { |  | ||||||
|           Ok(0) => break, |  | ||||||
|           Ok(_) => sender.send_data(buf.copy_to_bytes(buf.remaining())).await?, |  | ||||||
|           Err(_) => break, |  | ||||||
|         }; |  | ||||||
|       } |  | ||||||
|       Ok(()) as Result<()> |  | ||||||
|     }); |  | ||||||
| 
 |  | ||||||
|     Ok(res_body) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Remove file
 |  | ||||||
|   async fn remove(&mut self, path: impl AsRef<Path>) -> Result<()> { |  | ||||||
|     fs::remove_file(path.as_ref()).await?; |  | ||||||
|     self.cnt -= 1; |  | ||||||
|     debug!("Removed a cache file at {:?} (file count: {})", path.as_ref(), self.cnt); |  | ||||||
| 
 |  | ||||||
|     Ok(()) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[derive(Debug, Clone)] |  | ||||||
| /// Cache file manager outer that is responsible to handle `RwLock`
 |  | ||||||
| struct CacheFileManager { |  | ||||||
|   inner: Arc<RwLock<CacheFileManagerInner>>, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl CacheFileManager { |  | ||||||
|   /// Build manager
 |  | ||||||
|   async fn new(path: impl AsRef<Path>, runtime_handle: &tokio::runtime::Handle) -> Self { |  | ||||||
|     Self { |  | ||||||
|       inner: Arc::new(RwLock::new(CacheFileManagerInner::new(path, runtime_handle).await)), |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|   /// Evict a temporary file cache
 |  | ||||||
|   async fn evict(&self, path: impl AsRef<Path>) { |  | ||||||
|     // Acquire the write lock
 |  | ||||||
|     let mut inner = self.inner.write().await; |  | ||||||
|     if let Err(e) = inner.remove(path).await { |  | ||||||
|       warn!("Eviction failed during file object removal: {:?}", e); |  | ||||||
|     }; |  | ||||||
|   } |  | ||||||
|   /// Read a temporary file cache
 |  | ||||||
|   async fn read(&self, path: impl AsRef<Path>) -> Result<Body> { |  | ||||||
|     let mgr = self.inner.read().await; |  | ||||||
|     mgr.read(&path).await |  | ||||||
|   } |  | ||||||
|   /// Create a temporary file cache
 |  | ||||||
|   async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> Result<CacheFileOrOnMemory> { |  | ||||||
|     let mut mgr = self.inner.write().await; |  | ||||||
|     mgr.create(cache_filename, body_bytes).await |  | ||||||
|   } |  | ||||||
|   async fn count(&self) -> usize { |  | ||||||
|     let mgr = self.inner.read().await; |  | ||||||
|     mgr.cnt |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[derive(Debug, Clone)] |  | ||||||
| /// Lru cache manager that is responsible to handle `Mutex` as an outer of `LruCache`
 |  | ||||||
| struct LruCacheManager { |  | ||||||
|   inner: Arc<Mutex<LruCache<String, CacheObject>>>, // TODO: keyはstring urlでいいのか疑問。全requestに対してcheckすることになりそう
 |  | ||||||
|   cnt: Arc<AtomicUsize>, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl LruCacheManager { |  | ||||||
|   /// Build LruCache
 |  | ||||||
|   fn new(cache_max_entry: usize) -> Self { |  | ||||||
|     Self { |  | ||||||
|       inner: Arc::new(Mutex::new(LruCache::new( |  | ||||||
|         std::num::NonZeroUsize::new(cache_max_entry).unwrap(), |  | ||||||
|       ))), |  | ||||||
|       cnt: Arc::new(AtomicUsize::default()), |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
|   /// Count entries
 |  | ||||||
|   fn count(&self) -> usize { |  | ||||||
|     self.cnt.load(Ordering::Relaxed) |  | ||||||
|   } |  | ||||||
|   /// Evict an entry
 |  | ||||||
|   fn evict(&self, cache_key: &str) -> Option<(String, CacheObject)> { |  | ||||||
|     let Ok(mut lock) = self.inner.lock() else { |  | ||||||
|         error!("Mutex can't be locked to evict a cache entry"); |  | ||||||
|         return None; |  | ||||||
|       }; |  | ||||||
|     let res = lock.pop_entry(cache_key); |  | ||||||
|     self.cnt.store(lock.len(), Ordering::Relaxed); |  | ||||||
|     res |  | ||||||
|   } |  | ||||||
|   /// Get an entry
 |  | ||||||
|   fn get(&self, cache_key: &str) -> Result<Option<CacheObject>> { |  | ||||||
|     let Ok(mut lock) = self.inner.lock() else { |  | ||||||
|       error!("Mutex can't be locked for checking cache entry"); |  | ||||||
|       return Err(RpxyError::Cache("Mutex can't be locked for checking cache entry")); |  | ||||||
|     }; |  | ||||||
|     let Some(cached_object) = lock.get(cache_key) else { |  | ||||||
|       return Ok(None); |  | ||||||
|     }; |  | ||||||
|     Ok(Some(cached_object.clone())) |  | ||||||
|   } |  | ||||||
|   /// Push an entry
 |  | ||||||
|   fn push(&self, cache_key: &str, cache_object: CacheObject) -> Result<Option<(String, CacheObject)>> { |  | ||||||
|     let Ok(mut lock) = self.inner.lock() else { |  | ||||||
|       error!("Failed to acquire mutex lock for writing cache entry"); |  | ||||||
|       return Err(RpxyError::Cache("Failed to acquire mutex lock for writing cache entry")); |  | ||||||
|     }; |  | ||||||
|     let res = Ok(lock.push(cache_key.to_string(), cache_object)); |  | ||||||
|     self.cnt.store(lock.len(), Ordering::Relaxed); |  | ||||||
|     res |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[derive(Clone, Debug)] |  | ||||||
| pub struct RpxyCache { |  | ||||||
|   /// Managing cache file objects through RwLock's lock mechanism for file lock
 |  | ||||||
|   cache_file_manager: CacheFileManager, |  | ||||||
|   /// Lru cache storing http message caching policy
 |  | ||||||
|   inner: LruCacheManager, |  | ||||||
|   /// Async runtime
 |  | ||||||
|   runtime_handle: tokio::runtime::Handle, |  | ||||||
|   /// Maximum size of each cache file object
 |  | ||||||
|   max_each_size: usize, |  | ||||||
|   /// Maximum size of cache object on memory
 |  | ||||||
|   max_each_size_on_memory: usize, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl RpxyCache { |  | ||||||
|   /// Generate cache storage
 |  | ||||||
|   pub async fn new<T: CryptoSource>(globals: &Globals<T>) -> Option<Self> { |  | ||||||
|     if !globals.proxy_config.cache_enabled { |  | ||||||
|       return None; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     let path = globals.proxy_config.cache_dir.as_ref().unwrap(); |  | ||||||
|     let cache_file_manager = CacheFileManager::new(path, &globals.runtime_handle).await; |  | ||||||
|     let inner = LruCacheManager::new(globals.proxy_config.cache_max_entry); |  | ||||||
| 
 |  | ||||||
|     let max_each_size = globals.proxy_config.cache_max_each_size; |  | ||||||
|     let mut max_each_size_on_memory = globals.proxy_config.cache_max_each_size_on_memory; |  | ||||||
|     if max_each_size < max_each_size_on_memory { |  | ||||||
|       warn!( |  | ||||||
|         "Maximum size of on memory cache per entry must be smaller than or equal to the maximum of each file cache" |  | ||||||
|       ); |  | ||||||
|       max_each_size_on_memory = max_each_size; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     Some(Self { |  | ||||||
|       cache_file_manager, |  | ||||||
|       inner, |  | ||||||
|       runtime_handle: globals.runtime_handle.clone(), |  | ||||||
|       max_each_size, |  | ||||||
|       max_each_size_on_memory, |  | ||||||
|     }) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Count cache entries
 |  | ||||||
|   pub async fn count(&self) -> (usize, usize, usize) { |  | ||||||
|     let total = self.inner.count(); |  | ||||||
|     let file = self.cache_file_manager.count().await; |  | ||||||
|     let on_memory = total - file; |  | ||||||
|     (total, on_memory, file) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Get cached response
 |  | ||||||
|   pub async fn get<R>(&self, req: &Request<R>) -> Option<Response<Body>> { |  | ||||||
|     debug!( |  | ||||||
|       "Current cache status: (total, on-memory, file) = {:?}", |  | ||||||
|       self.count().await |  | ||||||
|     ); |  | ||||||
|     let cache_key = req.uri().to_string(); |  | ||||||
| 
 |  | ||||||
|     // First check cache chance
 |  | ||||||
|     let Ok(Some(cached_object)) = self.inner.get(&cache_key) else { |  | ||||||
|       return None; |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     // Secondly check the cache freshness as an HTTP message
 |  | ||||||
|     let now = SystemTime::now(); |  | ||||||
|     let http_cache_semantics::BeforeRequest::Fresh(res_parts) = cached_object.policy.before_request(req, now) else { |  | ||||||
|       // Evict stale cache entry.
 |  | ||||||
|       // This might be okay to keep as is since it would be updated later.
 |  | ||||||
|       // However, there is no guarantee that newly got objects will be still cacheable.
 |  | ||||||
|       // So, we have to evict stale cache entries and cache file objects if found.
 |  | ||||||
|       debug!("Stale cache entry: {cache_key}"); |  | ||||||
|       let _evicted_entry = self.inner.evict(&cache_key); |  | ||||||
|       // For cache file
 |  | ||||||
|       if let CacheFileOrOnMemory::File(path) = &cached_object.target { |  | ||||||
|         self.cache_file_manager.evict(&path).await; |  | ||||||
|       } |  | ||||||
|       return None; |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     // Finally retrieve the file/on-memory object
 |  | ||||||
|     match cached_object.target { |  | ||||||
|       CacheFileOrOnMemory::File(path) => { |  | ||||||
|         let res_body = match self.cache_file_manager.read(&path).await { |  | ||||||
|           Ok(res_body) => res_body, |  | ||||||
|           Err(e) => { |  | ||||||
|             warn!("Failed to read from file cache: {e}"); |  | ||||||
|             let _evicted_entry = self.inner.evict(&cache_key); |  | ||||||
|             self.cache_file_manager.evict(&path).await; |  | ||||||
|             return None; |  | ||||||
|           } |  | ||||||
|         }; |  | ||||||
| 
 |  | ||||||
|         debug!("Cache hit from file: {cache_key}"); |  | ||||||
|         Some(Response::from_parts(res_parts, res_body)) |  | ||||||
|       } |  | ||||||
|       CacheFileOrOnMemory::OnMemory(object) => { |  | ||||||
|         debug!("Cache hit from on memory: {cache_key}"); |  | ||||||
|         Some(Response::from_parts(res_parts, Body::from(object))) |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Put response into the cache
 |  | ||||||
|   pub async fn put(&self, uri: &hyper::Uri, body_bytes: &Bytes, policy: &CachePolicy) -> Result<()> { |  | ||||||
|     let my_cache = self.inner.clone(); |  | ||||||
|     let mut mgr = self.cache_file_manager.clone(); |  | ||||||
|     let uri = uri.clone(); |  | ||||||
|     let bytes_clone = body_bytes.clone(); |  | ||||||
|     let policy_clone = policy.clone(); |  | ||||||
|     let max_each_size = self.max_each_size; |  | ||||||
|     let max_each_size_on_memory = self.max_each_size_on_memory; |  | ||||||
| 
 |  | ||||||
|     self.runtime_handle.spawn(async move { |  | ||||||
|       if bytes_clone.len() > max_each_size { |  | ||||||
|         warn!("Too large to cache"); |  | ||||||
|         return Err(RpxyError::Cache("Too large to cache")); |  | ||||||
|       } |  | ||||||
|       let cache_key = derive_cache_key_from_uri(&uri); |  | ||||||
| 
 |  | ||||||
|       debug!("Object of size {:?} bytes to be cached", bytes_clone.len()); |  | ||||||
| 
 |  | ||||||
|       let cache_object = if bytes_clone.len() > max_each_size_on_memory { |  | ||||||
|         let cache_filename = derive_filename_from_uri(&uri); |  | ||||||
|         let target = mgr.create(&cache_filename, &bytes_clone).await?; |  | ||||||
|         debug!("Cached a new cache file: {} - {}", cache_key, cache_filename); |  | ||||||
|         CacheObject { |  | ||||||
|           policy: policy_clone, |  | ||||||
|           target, |  | ||||||
|         } |  | ||||||
|       } else { |  | ||||||
|         debug!("Cached a new object on memory: {}", cache_key); |  | ||||||
|         CacheObject { |  | ||||||
|           policy: policy_clone, |  | ||||||
|           target: CacheFileOrOnMemory::OnMemory(bytes_clone.to_vec()), |  | ||||||
|         } |  | ||||||
|       }; |  | ||||||
| 
 |  | ||||||
|       if let Some((k, v)) = my_cache.push(&cache_key, cache_object)? { |  | ||||||
|         if k != cache_key { |  | ||||||
|           info!("Over the cache capacity. Evict least recent used entry"); |  | ||||||
|           if let CacheFileOrOnMemory::File(path) = v.target { |  | ||||||
|             mgr.evict(&path).await; |  | ||||||
|           } |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|       Ok(()) |  | ||||||
|     }); |  | ||||||
| 
 |  | ||||||
|     Ok(()) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| fn derive_filename_from_uri(uri: &hyper::Uri) -> String { |  | ||||||
|   let mut hasher = Sha256::new(); |  | ||||||
|   hasher.update(uri.to_string()); |  | ||||||
|   let digest = hasher.finalize(); |  | ||||||
|   general_purpose::URL_SAFE_NO_PAD.encode(digest) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| fn derive_cache_key_from_uri(uri: &hyper::Uri) -> String { |  | ||||||
|   uri.to_string() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| pub fn get_policy_if_cacheable<R>(req: Option<&Request<R>>, res: Option<&Response<Body>>) -> Result<Option<CachePolicy>> |  | ||||||
| where |  | ||||||
|   R: Debug, |  | ||||||
| { |  | ||||||
|   // deduce cache policy from req and res
 |  | ||||||
|   let (Some(req), Some(res)) = (req, res) else { |  | ||||||
|       return Err(RpxyError::Cache("Invalid null request and/or response")); |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|   let new_policy = CachePolicy::new(req, res); |  | ||||||
|   if new_policy.is_storable() { |  | ||||||
|     // debug!("Response is cacheable: {:?}\n{:?}", req, res.headers());
 |  | ||||||
|     Ok(Some(new_policy)) |  | ||||||
|   } else { |  | ||||||
|     Ok(None) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  | @ -1,147 +0,0 @@ | ||||||
| #[cfg(feature = "cache")] |  | ||||||
| use super::cache::{get_policy_if_cacheable, RpxyCache}; |  | ||||||
| use crate::{error::RpxyError, globals::Globals, log::*, CryptoSource}; |  | ||||||
| use async_trait::async_trait; |  | ||||||
| #[cfg(feature = "cache")] |  | ||||||
| use bytes::Buf; |  | ||||||
| use hyper::{ |  | ||||||
|   body::{Body, HttpBody}, |  | ||||||
|   client::{connect::Connect, HttpConnector}, |  | ||||||
|   http::Version, |  | ||||||
|   Client, Request, Response, |  | ||||||
| }; |  | ||||||
| use hyper_rustls::HttpsConnector; |  | ||||||
| 
 |  | ||||||
| #[cfg(feature = "cache")] |  | ||||||
| /// Build synthetic request to cache
 |  | ||||||
| fn build_synth_req_for_cache<T>(req: &Request<T>) -> Request<()> { |  | ||||||
|   let mut builder = Request::builder() |  | ||||||
|     .method(req.method()) |  | ||||||
|     .uri(req.uri()) |  | ||||||
|     .version(req.version()); |  | ||||||
|   // TODO: omit extensions. is this approach correct?
 |  | ||||||
|   for (header_key, header_value) in req.headers() { |  | ||||||
|     builder = builder.header(header_key, header_value); |  | ||||||
|   } |  | ||||||
|   builder.body(()).unwrap() |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[async_trait] |  | ||||||
| /// Definition of the forwarder that simply forward requests from downstream client to upstream app servers.
 |  | ||||||
| pub trait ForwardRequest<B> { |  | ||||||
|   type Error; |  | ||||||
|   async fn request(&self, req: Request<B>) -> Result<Response<Body>, Self::Error>; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Forwarder struct responsible to cache handling
 |  | ||||||
| pub struct Forwarder<C, B = Body> |  | ||||||
| where |  | ||||||
|   C: Connect + Clone + Sync + Send + 'static, |  | ||||||
| { |  | ||||||
|   #[cfg(feature = "cache")] |  | ||||||
|   cache: Option<RpxyCache>, |  | ||||||
|   inner: Client<C, B>, |  | ||||||
|   inner_h2: Client<C, B>, // `h2c` or http/2-only client is defined separately
 |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[async_trait] |  | ||||||
| impl<C, B> ForwardRequest<B> for Forwarder<C, B> |  | ||||||
| where |  | ||||||
|   B: HttpBody + Send + Sync + 'static, |  | ||||||
|   B::Data: Send, |  | ||||||
|   B::Error: Into<Box<dyn std::error::Error + Send + Sync>>, |  | ||||||
|   C: Connect + Clone + Sync + Send + 'static, |  | ||||||
| { |  | ||||||
|   type Error = RpxyError; |  | ||||||
| 
 |  | ||||||
|   #[cfg(feature = "cache")] |  | ||||||
|   async fn request(&self, req: Request<B>) -> Result<Response<Body>, Self::Error> { |  | ||||||
|     let mut synth_req = None; |  | ||||||
|     if self.cache.is_some() { |  | ||||||
|       if let Some(cached_response) = self.cache.as_ref().unwrap().get(&req).await { |  | ||||||
|         // if found, return it as response.
 |  | ||||||
|         info!("Cache hit - Return from cache"); |  | ||||||
|         return Ok(cached_response); |  | ||||||
|       }; |  | ||||||
| 
 |  | ||||||
|       // Synthetic request copy used just for caching (cannot clone request object...)
 |  | ||||||
|       synth_req = Some(build_synth_req_for_cache(&req)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // TODO: This 'match' condition is always evaluated at every 'request' invocation. So, it is inefficient.
 |  | ||||||
|     // Needs to be reconsidered. Currently, this is a kind of work around.
 |  | ||||||
|     // This possibly relates to https://github.com/hyperium/hyper/issues/2417.
 |  | ||||||
|     let res = match req.version() { |  | ||||||
|       Version::HTTP_2 => self.inner_h2.request(req).await.map_err(RpxyError::Hyper), // handles `h2c` requests
 |  | ||||||
|       _ => self.inner.request(req).await.map_err(RpxyError::Hyper), |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     if self.cache.is_none() { |  | ||||||
|       return res; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // check cacheability and store it if cacheable
 |  | ||||||
|     let Ok(Some(cache_policy)) = get_policy_if_cacheable(synth_req.as_ref(), res.as_ref().ok()) else { |  | ||||||
|       return res; |  | ||||||
|     }; |  | ||||||
|     let (parts, body) = res.unwrap().into_parts(); |  | ||||||
|     let Ok(mut bytes) = hyper::body::aggregate(body).await else { |  | ||||||
|       return Err(RpxyError::Cache("Failed to write byte buffer")); |  | ||||||
|     }; |  | ||||||
|     let aggregated = bytes.copy_to_bytes(bytes.remaining()); |  | ||||||
| 
 |  | ||||||
|     if let Err(cache_err) = self |  | ||||||
|       .cache |  | ||||||
|       .as_ref() |  | ||||||
|       .unwrap() |  | ||||||
|       .put(synth_req.unwrap().uri(), &aggregated, &cache_policy) |  | ||||||
|       .await |  | ||||||
|     { |  | ||||||
|       error!("{:?}", cache_err); |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     // res
 |  | ||||||
|     Ok(Response::from_parts(parts, Body::from(aggregated))) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   #[cfg(not(feature = "cache"))] |  | ||||||
|   async fn request(&self, req: Request<B>) -> Result<Response<Body>, Self::Error> { |  | ||||||
|     match req.version() { |  | ||||||
|       Version::HTTP_2 => self.inner_h2.request(req).await.map_err(RpxyError::Hyper), // handles `h2c` requests
 |  | ||||||
|       _ => self.inner.request(req).await.map_err(RpxyError::Hyper), |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Forwarder<HttpsConnector<HttpConnector>, Body> { |  | ||||||
|   /// Build forwarder
 |  | ||||||
|   pub async fn new<T: CryptoSource>(_globals: &std::sync::Arc<Globals<T>>) -> Self { |  | ||||||
|     #[cfg(feature = "native-roots")] |  | ||||||
|     let builder = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots(); |  | ||||||
|     #[cfg(feature = "native-roots")] |  | ||||||
|     let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_native_roots(); |  | ||||||
|     #[cfg(feature = "native-roots")] |  | ||||||
|     info!("Native cert store is used for the connection to backend applications"); |  | ||||||
| 
 |  | ||||||
|     #[cfg(not(feature = "native-roots"))] |  | ||||||
|     let builder = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots(); |  | ||||||
|     #[cfg(not(feature = "native-roots"))] |  | ||||||
|     let builder_h2 = hyper_rustls::HttpsConnectorBuilder::new().with_webpki_roots(); |  | ||||||
|     #[cfg(not(feature = "native-roots"))] |  | ||||||
|     info!("Mozilla WebPKI root certs is used for the connection to backend applications"); |  | ||||||
| 
 |  | ||||||
|     let connector = builder.https_or_http().enable_http1().enable_http2().build(); |  | ||||||
|     let connector_h2 = builder_h2.https_or_http().enable_http2().build(); |  | ||||||
| 
 |  | ||||||
|     let inner = Client::builder().build::<_, Body>(connector); |  | ||||||
|     let inner_h2 = Client::builder().http2_only(true).build::<_, Body>(connector_h2); |  | ||||||
| 
 |  | ||||||
|     #[cfg(feature = "cache")] |  | ||||||
|     { |  | ||||||
|       let cache = RpxyCache::new(_globals).await; |  | ||||||
|       Self { inner, inner_h2, cache } |  | ||||||
|     } |  | ||||||
|     #[cfg(not(feature = "cache"))] |  | ||||||
|     Self { inner, inner_h2 } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  | @ -1,380 +0,0 @@ | ||||||
| // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
 |  | ||||||
| use super::{ |  | ||||||
|   forwarder::{ForwardRequest, Forwarder}, |  | ||||||
|   utils_headers::*, |  | ||||||
|   utils_request::*, |  | ||||||
|   utils_synth_response::*, |  | ||||||
|   HandlerContext, |  | ||||||
| }; |  | ||||||
| use crate::{ |  | ||||||
|   backend::{Backend, UpstreamGroup}, |  | ||||||
|   certs::CryptoSource, |  | ||||||
|   constants::RESPONSE_HEADER_SERVER, |  | ||||||
|   error::*, |  | ||||||
|   globals::Globals, |  | ||||||
|   log::*, |  | ||||||
|   utils::ServerNameBytesExp, |  | ||||||
| }; |  | ||||||
| use derive_builder::Builder; |  | ||||||
| use hyper::{ |  | ||||||
|   client::connect::Connect, |  | ||||||
|   header::{self, HeaderValue}, |  | ||||||
|   http::uri::Scheme, |  | ||||||
|   Body, Request, Response, StatusCode, Uri, Version, |  | ||||||
| }; |  | ||||||
| use std::{net::SocketAddr, sync::Arc}; |  | ||||||
| use tokio::{io::copy_bidirectional, time::timeout}; |  | ||||||
| 
 |  | ||||||
| #[derive(Clone, Builder)] |  | ||||||
| /// HTTP message handler for requests from clients and responses from backend applications,
 |  | ||||||
| /// responsible to manipulate and forward messages to upstream backends and downstream clients.
 |  | ||||||
| pub struct HttpMessageHandler<T, U> |  | ||||||
| where |  | ||||||
|   T: Connect + Clone + Sync + Send + 'static, |  | ||||||
|   U: CryptoSource + Clone, |  | ||||||
| { |  | ||||||
|   forwarder: Arc<Forwarder<T>>, |  | ||||||
|   globals: Arc<Globals<U>>, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<T, U> HttpMessageHandler<T, U> |  | ||||||
| where |  | ||||||
|   T: Connect + Clone + Sync + Send + 'static, |  | ||||||
|   U: CryptoSource + Clone, |  | ||||||
| { |  | ||||||
|   /// Return with an arbitrary status code of error and log message
 |  | ||||||
|   fn return_with_error_log(&self, status_code: StatusCode, log_data: &mut MessageLog) -> Result<Response<Body>> { |  | ||||||
|     log_data.status_code(&status_code).output(); |  | ||||||
|     http_error(status_code) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   /// Handle incoming request message from a client
 |  | ||||||
|   pub async fn handle_request( |  | ||||||
|     &self, |  | ||||||
|     mut req: Request<Body>, |  | ||||||
|     client_addr: SocketAddr, // アクセス制御用
 |  | ||||||
|     listen_addr: SocketAddr, |  | ||||||
|     tls_enabled: bool, |  | ||||||
|     tls_server_name: Option<ServerNameBytesExp>, |  | ||||||
|   ) -> Result<Response<Body>> { |  | ||||||
|     ////////
 |  | ||||||
|     let mut log_data = MessageLog::from(&req); |  | ||||||
|     log_data.client_addr(&client_addr); |  | ||||||
|     //////
 |  | ||||||
| 
 |  | ||||||
|     // Here we start to handle with server_name
 |  | ||||||
|     let server_name = if let Ok(v) = req.parse_host() { |  | ||||||
|       ServerNameBytesExp::from(v) |  | ||||||
|     } else { |  | ||||||
|       return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); |  | ||||||
|     }; |  | ||||||
|     // check consistency of between TLS SNI and HOST/Request URI Line.
 |  | ||||||
|     #[allow(clippy::collapsible_if)] |  | ||||||
|     if tls_enabled && self.globals.proxy_config.sni_consistency { |  | ||||||
|       if server_name != tls_server_name.unwrap_or_default() { |  | ||||||
|         return self.return_with_error_log(StatusCode::MISDIRECTED_REQUEST, &mut log_data); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|     // Find backend application for given server_name, and drop if incoming request is invalid as request.
 |  | ||||||
|     let backend = match self.globals.backends.apps.get(&server_name) { |  | ||||||
|       Some(be) => be, |  | ||||||
|       None => { |  | ||||||
|         let Some(default_server_name) = &self.globals.backends.default_server_name_bytes else { |  | ||||||
|           return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); |  | ||||||
|         }; |  | ||||||
|         debug!("Serving by default app"); |  | ||||||
|         self.globals.backends.apps.get(default_server_name).unwrap() |  | ||||||
|       } |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     // Redirect to https if !tls_enabled and redirect_to_https is true
 |  | ||||||
|     if !tls_enabled && backend.https_redirection.unwrap_or(false) { |  | ||||||
|       debug!("Redirect to secure connection: {}", &backend.server_name); |  | ||||||
|       log_data.status_code(&StatusCode::PERMANENT_REDIRECT).output(); |  | ||||||
|       return secure_redirection(&backend.server_name, self.globals.proxy_config.https_port, &req); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // Find reverse proxy for given path and choose one of upstream host
 |  | ||||||
|     // Longest prefix match
 |  | ||||||
|     let path = req.uri().path(); |  | ||||||
|     let Some(upstream_group) = backend.reverse_proxy.get(path) else { |  | ||||||
|       return self.return_with_error_log(StatusCode::NOT_FOUND, &mut log_data) |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     // Upgrade in request header
 |  | ||||||
|     let upgrade_in_request = extract_upgrade(req.headers()); |  | ||||||
|     let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>(); |  | ||||||
| 
 |  | ||||||
|     // Build request from destination information
 |  | ||||||
|     let _context = match self.generate_request_forwarded( |  | ||||||
|       &client_addr, |  | ||||||
|       &listen_addr, |  | ||||||
|       &mut req, |  | ||||||
|       &upgrade_in_request, |  | ||||||
|       upstream_group, |  | ||||||
|       tls_enabled, |  | ||||||
|     ) { |  | ||||||
|       Err(e) => { |  | ||||||
|         error!("Failed to generate destination uri for reverse proxy: {}", e); |  | ||||||
|         return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); |  | ||||||
|       } |  | ||||||
|       Ok(v) => v, |  | ||||||
|     }; |  | ||||||
|     debug!("Request to be forwarded: {:?}", req); |  | ||||||
|     log_data.xff(&req.headers().get("x-forwarded-for")); |  | ||||||
|     log_data.upstream(req.uri()); |  | ||||||
|     //////
 |  | ||||||
| 
 |  | ||||||
|     // Forward request to a chosen backend
 |  | ||||||
|     let mut res_backend = { |  | ||||||
|       let Ok(result) = timeout(self.globals.proxy_config.upstream_timeout, self.forwarder.request(req)).await else { |  | ||||||
|         return self.return_with_error_log(StatusCode::GATEWAY_TIMEOUT, &mut log_data); |  | ||||||
|       }; |  | ||||||
|       match result { |  | ||||||
|         Ok(res) => res, |  | ||||||
|         Err(e) => { |  | ||||||
|           error!("Failed to get response from backend: {}", e); |  | ||||||
|           return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data); |  | ||||||
|         } |  | ||||||
|       } |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     // Process reverse proxy context generated during the forwarding request generation.
 |  | ||||||
|     #[cfg(feature = "sticky-cookie")] |  | ||||||
|     if let Some(context_from_lb) = _context.context_lb { |  | ||||||
|       let res_headers = res_backend.headers_mut(); |  | ||||||
|       if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { |  | ||||||
|         error!("Failed to append context to the response given from backend: {}", e); |  | ||||||
|         return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { |  | ||||||
|       // Generate response to client
 |  | ||||||
|       if self.generate_response_forwarded(&mut res_backend, backend).is_err() { |  | ||||||
|         return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); |  | ||||||
|       } |  | ||||||
|       log_data.status_code(&res_backend.status()).output(); |  | ||||||
|       return Ok(res_backend); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // Handle StatusCode::SWITCHING_PROTOCOLS in response
 |  | ||||||
|     let upgrade_in_response = extract_upgrade(res_backend.headers()); |  | ||||||
|     let should_upgrade = if let (Some(u_req), Some(u_res)) = (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) |  | ||||||
|     { |  | ||||||
|       u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase() |  | ||||||
|     } else { |  | ||||||
|       false |  | ||||||
|     }; |  | ||||||
|     if !should_upgrade { |  | ||||||
|       error!( |  | ||||||
|         "Backend tried to switch to protocol {:?} when {:?} was requested", |  | ||||||
|         upgrade_in_response, upgrade_in_request |  | ||||||
|       ); |  | ||||||
|       return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); |  | ||||||
|     } |  | ||||||
|     let Some(request_upgraded) = request_upgraded else { |  | ||||||
|       error!("Request does not have an upgrade extension"); |  | ||||||
|       return self.return_with_error_log(StatusCode::BAD_REQUEST, &mut log_data); |  | ||||||
|     }; |  | ||||||
|     let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else { |  | ||||||
|       error!("Response does not have an upgrade extension"); |  | ||||||
|       return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data); |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     self.globals.runtime_handle.spawn(async move { |  | ||||||
|       let mut response_upgraded = onupgrade.await.map_err(|e| { |  | ||||||
|         error!("Failed to upgrade response: {}", e); |  | ||||||
|         RpxyError::Hyper(e) |  | ||||||
|       })?; |  | ||||||
|       let mut request_upgraded = request_upgraded.await.map_err(|e| { |  | ||||||
|         error!("Failed to upgrade request: {}", e); |  | ||||||
|         RpxyError::Hyper(e) |  | ||||||
|       })?; |  | ||||||
|       copy_bidirectional(&mut response_upgraded, &mut request_upgraded) |  | ||||||
|         .await |  | ||||||
|         .map_err(|e| { |  | ||||||
|           error!("Coping between upgraded connections failed: {}", e); |  | ||||||
|           RpxyError::Io(e) |  | ||||||
|         })?; |  | ||||||
|       Ok(()) as Result<()> |  | ||||||
|     }); |  | ||||||
|     log_data.status_code(&res_backend.status()).output(); |  | ||||||
|     Ok(res_backend) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   ////////////////////////////////////////////////////
 |  | ||||||
|   // Functions to generate messages
 |  | ||||||
|   ////////////////////////////////////////////////////
 |  | ||||||
| 
 |  | ||||||
|   /// Manipulate a response message sent from a backend application to forward downstream to a client.
 |  | ||||||
|   fn generate_response_forwarded<B>(&self, response: &mut Response<B>, chosen_backend: &Backend<U>) -> Result<()> |  | ||||||
|   where |  | ||||||
|     B: core::fmt::Debug, |  | ||||||
|   { |  | ||||||
|     let headers = response.headers_mut(); |  | ||||||
|     remove_connection_header(headers); |  | ||||||
|     remove_hop_header(headers); |  | ||||||
|     add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; |  | ||||||
| 
 |  | ||||||
|     #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] |  | ||||||
|     { |  | ||||||
|       // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled
 |  | ||||||
|       // TODO: This is a workaround for avoiding a client authentication in HTTP/3
 |  | ||||||
|       if self.globals.proxy_config.http3 |  | ||||||
|         && chosen_backend |  | ||||||
|           .crypto_source |  | ||||||
|           .as_ref() |  | ||||||
|           .is_some_and(|v| !v.is_mutual_tls()) |  | ||||||
|       { |  | ||||||
|         if let Some(port) = self.globals.proxy_config.https_port { |  | ||||||
|           add_header_entry_overwrite_if_exist( |  | ||||||
|             headers, |  | ||||||
|             header::ALT_SVC.as_str(), |  | ||||||
|             format!( |  | ||||||
|               "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", |  | ||||||
|               port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age |  | ||||||
|             ), |  | ||||||
|           )?; |  | ||||||
|         } |  | ||||||
|       } else { |  | ||||||
|         // remove alt-svc to disallow requests via http3
 |  | ||||||
|         headers.remove(header::ALT_SVC.as_str()); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|     #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] |  | ||||||
|     { |  | ||||||
|       if let Some(port) = self.globals.proxy_config.https_port { |  | ||||||
|         headers.remove(header::ALT_SVC.as_str()); |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     Ok(()) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   #[allow(clippy::too_many_arguments)] |  | ||||||
|   /// Manipulate a request message sent from a client to forward upstream to a backend application
 |  | ||||||
|   fn generate_request_forwarded<B>( |  | ||||||
|     &self, |  | ||||||
|     client_addr: &SocketAddr, |  | ||||||
|     listen_addr: &SocketAddr, |  | ||||||
|     req: &mut Request<B>, |  | ||||||
|     upgrade: &Option<String>, |  | ||||||
|     upstream_group: &UpstreamGroup, |  | ||||||
|     tls_enabled: bool, |  | ||||||
|   ) -> Result<HandlerContext> { |  | ||||||
|     debug!("Generate request to be forwarded"); |  | ||||||
| 
 |  | ||||||
|     // Add te: trailer if contained in original request
 |  | ||||||
|     let contains_te_trailers = { |  | ||||||
|       if let Some(te) = req.headers().get(header::TE) { |  | ||||||
|         te.as_bytes() |  | ||||||
|           .split(|v| v == &b',' || v == &b' ') |  | ||||||
|           .any(|x| x == "trailers".as_bytes()) |  | ||||||
|       } else { |  | ||||||
|         false |  | ||||||
|       } |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     let uri = req.uri().to_string(); |  | ||||||
|     let headers = req.headers_mut(); |  | ||||||
|     // delete headers specified in header.connection
 |  | ||||||
|     remove_connection_header(headers); |  | ||||||
|     // delete hop headers including header.connection
 |  | ||||||
|     remove_hop_header(headers); |  | ||||||
|     // X-Forwarded-For
 |  | ||||||
|     add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &uri)?; |  | ||||||
| 
 |  | ||||||
|     // Add te: trailer if te_trailer
 |  | ||||||
|     if contains_te_trailers { |  | ||||||
|       headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // add "host" header of original server_name if not exist (default)
 |  | ||||||
|     if req.headers().get(header::HOST).is_none() { |  | ||||||
|       let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); |  | ||||||
|       req |  | ||||||
|         .headers_mut() |  | ||||||
|         .insert(header::HOST, HeaderValue::from_str(&org_host)?); |  | ||||||
|     }; |  | ||||||
| 
 |  | ||||||
|     /////////////////////////////////////////////
 |  | ||||||
|     // Fix unique upstream destination since there could be multiple ones.
 |  | ||||||
|     #[cfg(feature = "sticky-cookie")] |  | ||||||
|     let (upstream_chosen_opt, context_from_lb) = { |  | ||||||
|       let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb { |  | ||||||
|         takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? |  | ||||||
|       } else { |  | ||||||
|         None |  | ||||||
|       }; |  | ||||||
|       upstream_group.get(&context_to_lb) |  | ||||||
|     }; |  | ||||||
|     #[cfg(not(feature = "sticky-cookie"))] |  | ||||||
|     let (upstream_chosen_opt, _) = upstream_group.get(&None); |  | ||||||
| 
 |  | ||||||
|     let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; |  | ||||||
|     let context = HandlerContext { |  | ||||||
|       #[cfg(feature = "sticky-cookie")] |  | ||||||
|       context_lb: context_from_lb, |  | ||||||
|       #[cfg(not(feature = "sticky-cookie"))] |  | ||||||
|       context_lb: None, |  | ||||||
|     }; |  | ||||||
|     /////////////////////////////////////////////
 |  | ||||||
| 
 |  | ||||||
|     // apply upstream-specific headers given in upstream_option
 |  | ||||||
|     let headers = req.headers_mut(); |  | ||||||
|     apply_upstream_options_to_header(headers, client_addr, upstream_group, &upstream_chosen.uri)?; |  | ||||||
| 
 |  | ||||||
|     // update uri in request
 |  | ||||||
|     if !(upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some()) { |  | ||||||
|       return Err(RpxyError::Handler("Upstream uri `scheme` and `authority` is broken")); |  | ||||||
|     }; |  | ||||||
|     let new_uri = Uri::builder() |  | ||||||
|       .scheme(upstream_chosen.uri.scheme().unwrap().as_str()) |  | ||||||
|       .authority(upstream_chosen.uri.authority().unwrap().as_str()); |  | ||||||
|     let org_pq = match req.uri().path_and_query() { |  | ||||||
|       Some(pq) => pq.to_string(), |  | ||||||
|       None => "/".to_string(), |  | ||||||
|     } |  | ||||||
|     .into_bytes(); |  | ||||||
| 
 |  | ||||||
|     // replace some parts of path if opt_replace_path is enabled for chosen upstream
 |  | ||||||
|     let new_pq = match &upstream_group.replace_path { |  | ||||||
|       Some(new_path) => { |  | ||||||
|         let matched_path: &[u8] = upstream_group.path.as_ref(); |  | ||||||
|         if matched_path.is_empty() || org_pq.len() < matched_path.len() { |  | ||||||
|           return Err(RpxyError::Handler("Upstream uri `path and query` is broken")); |  | ||||||
|         }; |  | ||||||
|         let mut new_pq = Vec::<u8>::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); |  | ||||||
|         new_pq.extend_from_slice(new_path.as_ref()); |  | ||||||
|         new_pq.extend_from_slice(&org_pq[matched_path.len()..]); |  | ||||||
|         new_pq |  | ||||||
|       } |  | ||||||
|       None => org_pq, |  | ||||||
|     }; |  | ||||||
|     *req.uri_mut() = new_uri.path_and_query(new_pq).build()?; |  | ||||||
| 
 |  | ||||||
|     // upgrade
 |  | ||||||
|     if let Some(v) = upgrade { |  | ||||||
|       req.headers_mut().insert(header::UPGRADE, v.parse()?); |  | ||||||
|       req |  | ||||||
|         .headers_mut() |  | ||||||
|         .insert(header::CONNECTION, HeaderValue::from_str("upgrade")?); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3
 |  | ||||||
|     if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { |  | ||||||
|       // Change version to http/1.1 when destination scheme is http
 |  | ||||||
|       debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); |  | ||||||
|       *req.version_mut() = Version::HTTP_11; |  | ||||||
|     } else if req.version() == Version::HTTP_3 { |  | ||||||
|       // HTTP/3 is always https
 |  | ||||||
|       debug!("HTTP/3 is currently unsupported for request to upstream."); |  | ||||||
|       *req.version_mut() = Version::HTTP_2; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     apply_upstream_options_to_request_line(req, upstream_group)?; |  | ||||||
| 
 |  | ||||||
|     Ok(context) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  | @ -1,24 +0,0 @@ | ||||||
| #[cfg(feature = "cache")] |  | ||||||
| mod cache; |  | ||||||
| mod forwarder; |  | ||||||
| mod handler_main; |  | ||||||
| mod utils_headers; |  | ||||||
| mod utils_request; |  | ||||||
| mod utils_synth_response; |  | ||||||
| 
 |  | ||||||
| #[cfg(feature = "sticky-cookie")] |  | ||||||
| use crate::backend::LbContext; |  | ||||||
| pub use { |  | ||||||
|   forwarder::Forwarder, |  | ||||||
|   handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError}, |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| #[allow(dead_code)] |  | ||||||
| #[derive(Debug)] |  | ||||||
| /// Context object to handle sticky cookies at HTTP message handler
 |  | ||||||
| struct HandlerContext { |  | ||||||
|   #[cfg(feature = "sticky-cookie")] |  | ||||||
|   context_lb: Option<LbContext>, |  | ||||||
|   #[cfg(not(feature = "sticky-cookie"))] |  | ||||||
|   context_lb: Option<()>, |  | ||||||
| } |  | ||||||
|  | @ -1,64 +0,0 @@ | ||||||
| use crate::{ |  | ||||||
|   backend::{UpstreamGroup, UpstreamOption}, |  | ||||||
|   error::*, |  | ||||||
| }; |  | ||||||
| use hyper::{header, Request}; |  | ||||||
| 
 |  | ||||||
| ////////////////////////////////////////////////////
 |  | ||||||
| // Functions to manipulate request line
 |  | ||||||
| 
 |  | ||||||
| /// Apply upstream options in request line, specified in the configuration
 |  | ||||||
| pub(super) fn apply_upstream_options_to_request_line<B>(req: &mut Request<B>, upstream: &UpstreamGroup) -> Result<()> { |  | ||||||
|   for opt in upstream.opts.iter() { |  | ||||||
|     match opt { |  | ||||||
|       UpstreamOption::ForceHttp11Upstream => *req.version_mut() = hyper::Version::HTTP_11, |  | ||||||
|       UpstreamOption::ForceHttp2Upstream => { |  | ||||||
|         // case: h2c -> https://www.rfc-editor.org/rfc/rfc9113.txt
 |  | ||||||
|         // Upgrade from HTTP/1.1 to HTTP/2 is deprecated. So, http-2 prior knowledge is required.
 |  | ||||||
|         *req.version_mut() = hyper::Version::HTTP_2; |  | ||||||
|       } |  | ||||||
|       _ => (), |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   Ok(()) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Trait defining parser of hostname
 |  | ||||||
| pub trait ParseHost { |  | ||||||
|   fn parse_host(&self) -> Result<&[u8]>; |  | ||||||
| } |  | ||||||
| impl<B> ParseHost for Request<B> { |  | ||||||
|   /// Extract hostname from either the request HOST header or request line
 |  | ||||||
|   fn parse_host(&self) -> Result<&[u8]> { |  | ||||||
|     let headers_host = self.headers().get(header::HOST); |  | ||||||
|     let uri_host = self.uri().host(); |  | ||||||
|     // let uri_port = self.uri().port_u16();
 |  | ||||||
| 
 |  | ||||||
|     if !(!(headers_host.is_none() && uri_host.is_none())) { |  | ||||||
|       return Err(RpxyError::Request("No host in request header")); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // prioritize server_name in uri
 |  | ||||||
|     uri_host.map_or_else( |  | ||||||
|       || { |  | ||||||
|         let m = headers_host.unwrap().as_bytes(); |  | ||||||
|         if m.starts_with(&[b'[']) { |  | ||||||
|           // v6 address with bracket case. if port is specified, always it is in this case.
 |  | ||||||
|           let mut iter = m.split(|ptr| ptr == &b'[' || ptr == &b']'); |  | ||||||
|           iter.next().ok_or(RpxyError::Request("Invalid Host"))?; // first item is always blank
 |  | ||||||
|           iter.next().ok_or(RpxyError::Request("Invalid Host")) |  | ||||||
|         } else if m.len() - m.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 { |  | ||||||
|           // v6 address case, if 2 or more ':' is contained
 |  | ||||||
|           Ok(m) |  | ||||||
|         } else { |  | ||||||
|           // v4 address or hostname
 |  | ||||||
|           m.split(|colon| colon == &b':') |  | ||||||
|             .next() |  | ||||||
|             .ok_or(RpxyError::Request("Invalid Host")) |  | ||||||
|         } |  | ||||||
|       }, |  | ||||||
|       |v| Ok(v.as_bytes()), |  | ||||||
|     ) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  | @ -1,35 +0,0 @@ | ||||||
| // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
 |  | ||||||
| use crate::error::*; |  | ||||||
| use hyper::{Body, Request, Response, StatusCode, Uri}; |  | ||||||
| 
 |  | ||||||
| ////////////////////////////////////////////////////
 |  | ||||||
| // Functions to create response (error or redirect)
 |  | ||||||
| 
 |  | ||||||
| /// Generate a synthetic response  message of a certain error status code
 |  | ||||||
| pub(super) fn http_error(status_code: StatusCode) -> Result<Response<Body>> { |  | ||||||
|   let response = Response::builder().status(status_code).body(Body::empty())?; |  | ||||||
|   Ok(response) |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Generate synthetic response message of a redirection to https host with 301
 |  | ||||||
| pub(super) fn secure_redirection<B>( |  | ||||||
|   server_name: &str, |  | ||||||
|   tls_port: Option<u16>, |  | ||||||
|   req: &Request<B>, |  | ||||||
| ) -> Result<Response<Body>> { |  | ||||||
|   let pq = match req.uri().path_and_query() { |  | ||||||
|     Some(x) => x.as_str(), |  | ||||||
|     _ => "", |  | ||||||
|   }; |  | ||||||
|   let new_uri = Uri::builder().scheme("https").path_and_query(pq); |  | ||||||
|   let dest_uri = match tls_port { |  | ||||||
|     Some(443) | None => new_uri.authority(server_name), |  | ||||||
|     Some(p) => new_uri.authority(format!("{server_name}:{p}")), |  | ||||||
|   } |  | ||||||
|   .build()?; |  | ||||||
|   let response = Response::builder() |  | ||||||
|     .status(StatusCode::MOVED_PERMANENTLY) |  | ||||||
|     .header("Location", dest_uri.to_string()) |  | ||||||
|     .body(Body::empty())?; |  | ||||||
|   Ok(response) |  | ||||||
| } |  | ||||||
							
								
								
									
										370
									
								
								rpxy-lib/src/hyper_ext/body_incoming_like.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										370
									
								
								rpxy-lib/src/hyper_ext/body_incoming_like.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,370 @@ | ||||||
|  | use super::watch; | ||||||
|  | use crate::error::*; | ||||||
|  | use futures_channel::{mpsc, oneshot}; | ||||||
|  | use futures_util::{stream::FusedStream, Future, Stream}; | ||||||
|  | use http::HeaderMap; | ||||||
|  | use hyper::body::{Body, Bytes, Frame, SizeHint}; | ||||||
|  | use std::{ | ||||||
|  |   pin::Pin, | ||||||
|  |   task::{Context, Poll}, | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////////////
 | ||||||
|  | /// Incoming like body to handle incoming request body
 | ||||||
|  | /// ported from https://github.com/hyperium/hyper/blob/master/src/body/incoming.rs
 | ||||||
|  | pub struct IncomingLike { | ||||||
|  |   content_length: DecodedLength, | ||||||
|  |   want_tx: watch::Sender, | ||||||
|  |   data_rx: mpsc::Receiver<Result<Bytes, RpxyError>>, | ||||||
|  |   trailers_rx: oneshot::Receiver<HeaderMap>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | macro_rules! ready { | ||||||
|  |   ($e:expr) => { | ||||||
|  |     match $e { | ||||||
|  |       Poll::Ready(v) => v, | ||||||
|  |       Poll::Pending => return Poll::Pending, | ||||||
|  |     } | ||||||
|  |   }; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type BodySender = mpsc::Sender<Result<Bytes, RpxyError>>; | ||||||
|  | type TrailersSender = oneshot::Sender<HeaderMap>; | ||||||
|  | 
 | ||||||
|  | const MAX_LEN: u64 = std::u64::MAX - 2; | ||||||
|  | #[derive(Clone, Copy, PartialEq, Eq)] | ||||||
|  | pub(crate) struct DecodedLength(u64); | ||||||
|  | impl DecodedLength { | ||||||
|  |   pub(crate) const CLOSE_DELIMITED: DecodedLength = DecodedLength(::std::u64::MAX); | ||||||
|  |   pub(crate) const CHUNKED: DecodedLength = DecodedLength(::std::u64::MAX - 1); | ||||||
|  |   pub(crate) const ZERO: DecodedLength = DecodedLength(0); | ||||||
|  | 
 | ||||||
|  |   #[allow(dead_code)] | ||||||
|  |   pub(crate) fn new(len: u64) -> Self { | ||||||
|  |     debug_assert!(len <= MAX_LEN); | ||||||
|  |     DecodedLength(len) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   pub(crate) fn sub_if(&mut self, amt: u64) { | ||||||
|  |     match *self { | ||||||
|  |       DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => (), | ||||||
|  |       DecodedLength(ref mut known) => { | ||||||
|  |         *known -= amt; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   /// Converts to an Option<u64> representing a Known or Unknown length.
 | ||||||
|  |   pub(crate) fn into_opt(self) -> Option<u64> { | ||||||
|  |     match self { | ||||||
|  |       DecodedLength::CHUNKED | DecodedLength::CLOSE_DELIMITED => None, | ||||||
|  |       DecodedLength(known) => Some(known), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | pub(crate) struct Sender { | ||||||
|  |   want_rx: watch::Receiver, | ||||||
|  |   data_tx: BodySender, | ||||||
|  |   trailers_tx: Option<TrailersSender>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | const WANT_PENDING: usize = 1; | ||||||
|  | const WANT_READY: usize = 2; | ||||||
|  | 
 | ||||||
|  | impl IncomingLike { | ||||||
|  |   /// Create a `Body` stream with an associated sender half.
 | ||||||
|  |   ///
 | ||||||
|  |   /// Useful when wanting to stream chunks from another thread.
 | ||||||
|  |   #[inline] | ||||||
|  |   #[allow(unused)] | ||||||
|  |   pub(crate) fn channel() -> (Sender, IncomingLike) { | ||||||
|  |     Self::new_channel(DecodedLength::CHUNKED, /*wanter =*/ false) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   pub(crate) fn new_channel(content_length: DecodedLength, wanter: bool) -> (Sender, IncomingLike) { | ||||||
|  |     let (data_tx, data_rx) = mpsc::channel(0); | ||||||
|  |     let (trailers_tx, trailers_rx) = oneshot::channel(); | ||||||
|  | 
 | ||||||
|  |     // If wanter is true, `Sender::poll_ready()` won't becoming ready
 | ||||||
|  |     // until the `Body` has been polled for data once.
 | ||||||
|  |     let want = if wanter { WANT_PENDING } else { WANT_READY }; | ||||||
|  | 
 | ||||||
|  |     let (want_tx, want_rx) = watch::channel(want); | ||||||
|  | 
 | ||||||
|  |     let tx = Sender { | ||||||
|  |       want_rx, | ||||||
|  |       data_tx, | ||||||
|  |       trailers_tx: Some(trailers_tx), | ||||||
|  |     }; | ||||||
|  |     let rx = IncomingLike { | ||||||
|  |       content_length, | ||||||
|  |       want_tx, | ||||||
|  |       data_rx, | ||||||
|  |       trailers_rx, | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     (tx, rx) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Body for IncomingLike { | ||||||
|  |   type Data = Bytes; | ||||||
|  |   type Error = RpxyError; | ||||||
|  | 
 | ||||||
|  |   fn poll_frame( | ||||||
|  |     mut self: Pin<&mut Self>, | ||||||
|  |     cx: &mut Context<'_>, | ||||||
|  |   ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { | ||||||
|  |     self.want_tx.send(WANT_READY); | ||||||
|  | 
 | ||||||
|  |     if !self.data_rx.is_terminated() { | ||||||
|  |       if let Some(chunk) = ready!(Pin::new(&mut self.data_rx).poll_next(cx)?) { | ||||||
|  |         self.content_length.sub_if(chunk.len() as u64); | ||||||
|  |         return Poll::Ready(Some(Ok(Frame::data(chunk)))); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // check trailers after data is terminated
 | ||||||
|  |     match ready!(Pin::new(&mut self.trailers_rx).poll(cx)) { | ||||||
|  |       Ok(t) => Poll::Ready(Some(Ok(Frame::trailers(t)))), | ||||||
|  |       Err(_) => Poll::Ready(None), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   fn is_end_stream(&self) -> bool { | ||||||
|  |     self.content_length == DecodedLength::ZERO | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   fn size_hint(&self) -> SizeHint { | ||||||
|  |     macro_rules! opt_len { | ||||||
|  |       ($content_length:expr) => {{ | ||||||
|  |         let mut hint = SizeHint::default(); | ||||||
|  | 
 | ||||||
|  |         if let Some(content_length) = $content_length.into_opt() { | ||||||
|  |           hint.set_exact(content_length); | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         hint | ||||||
|  |       }}; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     opt_len!(self.content_length) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Sender { | ||||||
|  |   /// Check to see if this `Sender` can send more data.
 | ||||||
|  |   pub(crate) fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<RpxyResult<()>> { | ||||||
|  |     // Check if the receiver end has tried polling for the body yet
 | ||||||
|  |     ready!(self.poll_want(cx)?); | ||||||
|  |     self | ||||||
|  |       .data_tx | ||||||
|  |       .poll_ready(cx) | ||||||
|  |       .map_err(|_| RpxyError::HyperIncomingLikeNewClosed) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   fn poll_want(&mut self, cx: &mut Context<'_>) -> Poll<RpxyResult<()>> { | ||||||
|  |     match self.want_rx.load(cx) { | ||||||
|  |       WANT_READY => Poll::Ready(Ok(())), | ||||||
|  |       WANT_PENDING => Poll::Pending, | ||||||
|  |       watch::CLOSED => Poll::Ready(Err(RpxyError::HyperIncomingLikeNewClosed)), | ||||||
|  |       unexpected => unreachable!("want_rx value: {}", unexpected), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   async fn ready(&mut self) -> RpxyResult<()> { | ||||||
|  |     futures_util::future::poll_fn(|cx| self.poll_ready(cx)).await | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Send data on data channel when it is ready.
 | ||||||
|  |   #[allow(unused)] | ||||||
|  |   pub(crate) async fn send_data(&mut self, chunk: Bytes) -> RpxyResult<()> { | ||||||
|  |     self.ready().await?; | ||||||
|  |     self | ||||||
|  |       .data_tx | ||||||
|  |       .try_send(Ok(chunk)) | ||||||
|  |       .map_err(|_| RpxyError::HyperIncomingLikeNewClosed) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Send trailers on trailers channel.
 | ||||||
|  |   #[allow(unused)] | ||||||
|  |   pub(crate) async fn send_trailers(&mut self, trailers: HeaderMap) -> RpxyResult<()> { | ||||||
|  |     let tx = match self.trailers_tx.take() { | ||||||
|  |       Some(tx) => tx, | ||||||
|  |       None => return Err(RpxyError::HyperIncomingLikeNewClosed), | ||||||
|  |     }; | ||||||
|  |     tx.send(trailers).map_err(|_| RpxyError::HyperIncomingLikeNewClosed) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Try to send data on this channel.
 | ||||||
|  |   ///
 | ||||||
|  |   /// # Errors
 | ||||||
|  |   ///
 | ||||||
|  |   /// Returns `Err(Bytes)` if the channel could not (currently) accept
 | ||||||
|  |   /// another `Bytes`.
 | ||||||
|  |   ///
 | ||||||
|  |   /// # Note
 | ||||||
|  |   ///
 | ||||||
|  |   /// This is mostly useful for when trying to send from some other thread
 | ||||||
|  |   /// that doesn't have an async context. If in an async context, prefer
 | ||||||
|  |   /// `send_data()` instead.
 | ||||||
|  |   #[allow(unused)] | ||||||
|  |   pub(crate) fn try_send_data(&mut self, chunk: Bytes) -> Result<(), Bytes> { | ||||||
|  |     self | ||||||
|  |       .data_tx | ||||||
|  |       .try_send(Ok(chunk)) | ||||||
|  |       .map_err(|err| err.into_inner().expect("just sent Ok")) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[allow(unused)] | ||||||
|  |   pub(crate) fn abort(mut self) { | ||||||
|  |     self.send_error(RpxyError::HyperNewBodyWriteAborted); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   pub(crate) fn send_error(&mut self, err: RpxyError) { | ||||||
|  |     let _ = self | ||||||
|  |       .data_tx | ||||||
|  |       // clone so the send works even if buffer is full
 | ||||||
|  |       .clone() | ||||||
|  |       .try_send(Err(err)); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(test)] | ||||||
|  | mod tests { | ||||||
|  |   use std::mem; | ||||||
|  |   use std::task::Poll; | ||||||
|  | 
 | ||||||
|  |   use super::{Body, DecodedLength, IncomingLike, Sender, SizeHint}; | ||||||
|  |   use crate::error::RpxyError; | ||||||
|  |   use http_body_util::BodyExt; | ||||||
|  | 
 | ||||||
|  |   #[test] | ||||||
|  |   fn test_size_of() { | ||||||
|  |     // These are mostly to help catch *accidentally* increasing
 | ||||||
|  |     // the size by too much.
 | ||||||
|  | 
 | ||||||
|  |     let body_size = mem::size_of::<IncomingLike>(); | ||||||
|  |     let body_expected_size = mem::size_of::<u64>() * 5; | ||||||
|  |     assert!( | ||||||
|  |       body_size <= body_expected_size, | ||||||
|  |       "Body size = {} <= {}", | ||||||
|  |       body_size, | ||||||
|  |       body_expected_size, | ||||||
|  |     ); | ||||||
|  | 
 | ||||||
|  |     //assert_eq!(body_size, mem::size_of::<Option<Incoming>>(), "Option<Incoming>");
 | ||||||
|  | 
 | ||||||
|  |     assert_eq!(mem::size_of::<Sender>(), mem::size_of::<usize>() * 5, "Sender"); | ||||||
|  | 
 | ||||||
|  |     assert_eq!( | ||||||
|  |       mem::size_of::<Sender>(), | ||||||
|  |       mem::size_of::<Option<Sender>>(), | ||||||
|  |       "Option<Sender>" | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  |   #[test] | ||||||
|  |   fn size_hint() { | ||||||
|  |     fn eq(body: IncomingLike, b: SizeHint, note: &str) { | ||||||
|  |       let a = body.size_hint(); | ||||||
|  |       assert_eq!(a.lower(), b.lower(), "lower for {:?}", note); | ||||||
|  |       assert_eq!(a.upper(), b.upper(), "upper for {:?}", note); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     eq(IncomingLike::channel().1, SizeHint::new(), "channel"); | ||||||
|  | 
 | ||||||
|  |     eq( | ||||||
|  |       IncomingLike::new_channel(DecodedLength::new(4), /*wanter =*/ false).1, | ||||||
|  |       SizeHint::with_exact(4), | ||||||
|  |       "channel with length", | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[tokio::test] | ||||||
|  |   async fn channel_abort() { | ||||||
|  |     let (tx, mut rx) = IncomingLike::channel(); | ||||||
|  | 
 | ||||||
|  |     tx.abort(); | ||||||
|  | 
 | ||||||
|  |     match rx.frame().await.unwrap() { | ||||||
|  |       Err(RpxyError::HyperNewBodyWriteAborted) => true, | ||||||
|  |       unexpected => panic!("unexpected: {:?}", unexpected), | ||||||
|  |     }; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[tokio::test] | ||||||
|  |   async fn channel_abort_when_buffer_is_full() { | ||||||
|  |     let (mut tx, mut rx) = IncomingLike::channel(); | ||||||
|  | 
 | ||||||
|  |     tx.try_send_data("chunk 1".into()).expect("send 1"); | ||||||
|  |     // buffer is full, but can still send abort
 | ||||||
|  |     tx.abort(); | ||||||
|  | 
 | ||||||
|  |     let chunk1 = rx.frame().await.expect("item 1").expect("chunk 1").into_data().unwrap(); | ||||||
|  |     assert_eq!(chunk1, "chunk 1"); | ||||||
|  | 
 | ||||||
|  |     match rx.frame().await.unwrap() { | ||||||
|  |       Err(RpxyError::HyperNewBodyWriteAborted) => true, | ||||||
|  |       unexpected => panic!("unexpected: {:?}", unexpected), | ||||||
|  |     }; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[test] | ||||||
|  |   fn channel_buffers_one() { | ||||||
|  |     let (mut tx, _rx) = IncomingLike::channel(); | ||||||
|  | 
 | ||||||
|  |     tx.try_send_data("chunk 1".into()).expect("send 1"); | ||||||
|  | 
 | ||||||
|  |     // buffer is now full
 | ||||||
|  |     let chunk2 = tx.try_send_data("chunk 2".into()).expect_err("send 2"); | ||||||
|  |     assert_eq!(chunk2, "chunk 2"); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[tokio::test] | ||||||
|  |   async fn channel_empty() { | ||||||
|  |     let (_, mut rx) = IncomingLike::channel(); | ||||||
|  | 
 | ||||||
|  |     assert!(rx.frame().await.is_none()); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[test] | ||||||
|  |   fn channel_ready() { | ||||||
|  |     let (mut tx, _rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ false); | ||||||
|  | 
 | ||||||
|  |     let mut tx_ready = tokio_test::task::spawn(tx.ready()); | ||||||
|  | 
 | ||||||
|  |     assert!(tx_ready.poll().is_ready(), "tx is ready immediately"); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[test] | ||||||
|  |   fn channel_wanter() { | ||||||
|  |     let (mut tx, mut rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); | ||||||
|  | 
 | ||||||
|  |     let mut tx_ready = tokio_test::task::spawn(tx.ready()); | ||||||
|  |     let mut rx_data = tokio_test::task::spawn(rx.frame()); | ||||||
|  | 
 | ||||||
|  |     assert!(tx_ready.poll().is_pending(), "tx isn't ready before rx has been polled"); | ||||||
|  | 
 | ||||||
|  |     assert!(rx_data.poll().is_pending(), "poll rx.data"); | ||||||
|  |     assert!(tx_ready.is_woken(), "rx poll wakes tx"); | ||||||
|  | 
 | ||||||
|  |     assert!(tx_ready.poll().is_ready(), "tx is ready after rx has been polled"); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[test] | ||||||
|  | 
 | ||||||
|  |   fn channel_notices_closure() { | ||||||
|  |     let (mut tx, rx) = IncomingLike::new_channel(DecodedLength::CHUNKED, /*wanter = */ true); | ||||||
|  | 
 | ||||||
|  |     let mut tx_ready = tokio_test::task::spawn(tx.ready()); | ||||||
|  | 
 | ||||||
|  |     assert!(tx_ready.poll().is_pending(), "tx isn't ready before rx has been polled"); | ||||||
|  | 
 | ||||||
|  |     drop(rx); | ||||||
|  |     assert!(tx_ready.is_woken(), "dropping rx wakes tx"); | ||||||
|  | 
 | ||||||
|  |     match tx_ready.poll() { | ||||||
|  |       Poll::Ready(Err(RpxyError::HyperIncomingLikeNewClosed)) => (), | ||||||
|  |       unexpected => panic!("tx poll ready unexpected: {:?}", unexpected), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										75
									
								
								rpxy-lib/src/hyper_ext/body_type.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										75
									
								
								rpxy-lib/src/hyper_ext/body_type.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,75 @@ | ||||||
|  | use super::body::IncomingLike; | ||||||
|  | use crate::error::RpxyError; | ||||||
|  | use futures::channel::mpsc::UnboundedReceiver; | ||||||
|  | use http_body_util::{combinators, BodyExt, Empty, Full, StreamBody}; | ||||||
|  | use hyper::body::{Body, Bytes, Frame, Incoming}; | ||||||
|  | use std::pin::Pin; | ||||||
|  | 
 | ||||||
|  | /// Type for synthetic boxed body
 | ||||||
|  | pub type BoxBody = combinators::BoxBody<Bytes, hyper::Error>; | ||||||
|  | 
 | ||||||
|  | /// helper function to build a empty body
 | ||||||
|  | pub(crate) fn empty() -> BoxBody { | ||||||
|  |   Empty::<Bytes>::new().map_err(|never| match never {}).boxed() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// helper function to build a full body
 | ||||||
|  | pub(crate) fn full(body: Bytes) -> BoxBody { | ||||||
|  |   Full::new(body).map_err(|never| match never {}).boxed() | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[allow(unused)] | ||||||
|  | /* ------------------------------------ */ | ||||||
|  | /// Request body used in this project
 | ||||||
|  | /// - Incoming: just a type that only forwards the downstream request body to upstream.
 | ||||||
|  | /// - IncomingLike: a Incoming-like type in which channel is used
 | ||||||
|  | pub enum RequestBody { | ||||||
|  |   Incoming(Incoming), | ||||||
|  |   IncomingLike(IncomingLike), | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Body for RequestBody { | ||||||
|  |   type Data = bytes::Bytes; | ||||||
|  |   type Error = RpxyError; | ||||||
|  | 
 | ||||||
|  |   fn poll_frame( | ||||||
|  |     self: Pin<&mut Self>, | ||||||
|  |     cx: &mut std::task::Context<'_>, | ||||||
|  |   ) -> std::task::Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { | ||||||
|  |     match self.get_mut() { | ||||||
|  |       RequestBody::Incoming(incoming) => Pin::new(incoming).poll_frame(cx).map_err(RpxyError::HyperBodyError), | ||||||
|  |       RequestBody::IncomingLike(incoming_like) => Pin::new(incoming_like).poll_frame(cx), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /* ------------------------------------ */ | ||||||
|  | pub type UnboundedStreamBody = StreamBody<UnboundedReceiver<Result<Frame<bytes::Bytes>, hyper::Error>>>; | ||||||
|  | 
 | ||||||
|  | #[allow(unused)] | ||||||
|  | /// Response body use in this project
 | ||||||
|  | /// - Incoming: just a type that only forwards the upstream response body to downstream.
 | ||||||
|  | /// - Boxed: a type that is generated from cache or synthetic response body, e.g.,, small byte object.
 | ||||||
|  | /// - Streamed: another type that is generated from stream, e.g., large byte object.
 | ||||||
|  | pub enum ResponseBody { | ||||||
|  |   Incoming(Incoming), | ||||||
|  |   Boxed(BoxBody), | ||||||
|  |   Streamed(UnboundedStreamBody), | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Body for ResponseBody { | ||||||
|  |   type Data = bytes::Bytes; | ||||||
|  |   type Error = RpxyError; | ||||||
|  | 
 | ||||||
|  |   fn poll_frame( | ||||||
|  |     self: Pin<&mut Self>, | ||||||
|  |     cx: &mut std::task::Context<'_>, | ||||||
|  |   ) -> std::task::Poll<Option<Result<Frame<Self::Data>, Self::Error>>> { | ||||||
|  |     match self.get_mut() { | ||||||
|  |       ResponseBody::Incoming(incoming) => Pin::new(incoming).poll_frame(cx), | ||||||
|  |       ResponseBody::Boxed(boxed) => Pin::new(boxed).poll_frame(cx), | ||||||
|  |       ResponseBody::Streamed(streamed) => Pin::new(streamed).poll_frame(cx), | ||||||
|  |     } | ||||||
|  |     .map_err(RpxyError::HyperBodyError) | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										23
									
								
								rpxy-lib/src/hyper_ext/executor.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										23
									
								
								rpxy-lib/src/hyper_ext/executor.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,23 @@ | ||||||
|  | use tokio::runtime::Handle; | ||||||
|  | 
 | ||||||
|  | #[derive(Clone)] | ||||||
|  | /// Executor for hyper
 | ||||||
|  | pub struct LocalExecutor { | ||||||
|  |   runtime_handle: Handle, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl LocalExecutor { | ||||||
|  |   pub fn new(runtime_handle: Handle) -> Self { | ||||||
|  |     LocalExecutor { runtime_handle } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<F> hyper::rt::Executor<F> for LocalExecutor | ||||||
|  | where | ||||||
|  |   F: std::future::Future + Send + 'static, | ||||||
|  |   F::Output: Send, | ||||||
|  | { | ||||||
|  |   fn execute(&self, fut: F) { | ||||||
|  |     self.runtime_handle.spawn(fut); | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										16
									
								
								rpxy-lib/src/hyper_ext/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								rpxy-lib/src/hyper_ext/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,16 @@ | ||||||
|  | mod body_incoming_like; | ||||||
|  | mod body_type; | ||||||
|  | mod executor; | ||||||
|  | mod tokio_timer; | ||||||
|  | mod watch; | ||||||
|  | 
 | ||||||
|  | #[allow(unused)] | ||||||
|  | pub(crate) mod rt { | ||||||
|  |   pub(crate) use super::executor::LocalExecutor; | ||||||
|  |   pub(crate) use super::tokio_timer::{TokioSleep, TokioTimer}; | ||||||
|  | } | ||||||
|  | #[allow(unused)] | ||||||
|  | pub(crate) mod body { | ||||||
|  |   pub(crate) use super::body_incoming_like::IncomingLike; | ||||||
|  |   pub(crate) use super::body_type::{empty, full, BoxBody, RequestBody, ResponseBody, UnboundedStreamBody}; | ||||||
|  | } | ||||||
							
								
								
									
										55
									
								
								rpxy-lib/src/hyper_ext/tokio_timer.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								rpxy-lib/src/hyper_ext/tokio_timer.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,55 @@ | ||||||
|  | use std::{ | ||||||
|  |   future::Future, | ||||||
|  |   pin::Pin, | ||||||
|  |   task::{Context, Poll}, | ||||||
|  |   time::{Duration, Instant}, | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | use hyper::rt::{Sleep, Timer}; | ||||||
|  | use pin_project_lite::pin_project; | ||||||
|  | 
 | ||||||
|  | #[derive(Clone, Debug)] | ||||||
|  | pub struct TokioTimer; | ||||||
|  | 
 | ||||||
|  | impl Timer for TokioTimer { | ||||||
|  |   fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> { | ||||||
|  |     Box::pin(TokioSleep { | ||||||
|  |       inner: tokio::time::sleep(duration), | ||||||
|  |     }) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> { | ||||||
|  |     Box::pin(TokioSleep { | ||||||
|  |       inner: tokio::time::sleep_until(deadline.into()), | ||||||
|  |     }) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) { | ||||||
|  |     if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() { | ||||||
|  |       sleep.reset(new_deadline) | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pin_project! { | ||||||
|  |     pub(crate) struct TokioSleep { | ||||||
|  |         #[pin] | ||||||
|  |         pub(crate) inner: tokio::time::Sleep, | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Future for TokioSleep { | ||||||
|  |   type Output = (); | ||||||
|  | 
 | ||||||
|  |   fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { | ||||||
|  |     self.project().inner.poll(cx) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Sleep for TokioSleep {} | ||||||
|  | 
 | ||||||
|  | impl TokioSleep { | ||||||
|  |   pub fn reset(self: Pin<&mut Self>, deadline: Instant) { | ||||||
|  |     self.project().inner.as_mut().reset(deadline.into()); | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										67
									
								
								rpxy-lib/src/hyper_ext/watch.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								rpxy-lib/src/hyper_ext/watch.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,67 @@ | ||||||
|  | //! An SPSC broadcast channel.
 | ||||||
|  | //!
 | ||||||
|  | //! - The value can only be a `usize`.
 | ||||||
|  | //! - The consumer is only notified if the value is different.
 | ||||||
|  | //! - The value `0` is reserved for closed.
 | ||||||
|  | // from https://github.com/hyperium/hyper/blob/master/src/common/watch.rs
 | ||||||
|  | 
 | ||||||
|  | use futures_util::task::AtomicWaker; | ||||||
|  | use std::sync::{ | ||||||
|  |   atomic::{AtomicUsize, Ordering}, | ||||||
|  |   Arc, | ||||||
|  | }; | ||||||
|  | use std::task; | ||||||
|  | 
 | ||||||
|  | type Value = usize; | ||||||
|  | 
 | ||||||
|  | pub(super) const CLOSED: usize = 0; | ||||||
|  | 
 | ||||||
|  | pub(super) fn channel(initial: Value) -> (Sender, Receiver) { | ||||||
|  |   debug_assert!(initial != CLOSED, "watch::channel initial state of 0 is reserved"); | ||||||
|  | 
 | ||||||
|  |   let shared = Arc::new(Shared { | ||||||
|  |     value: AtomicUsize::new(initial), | ||||||
|  |     waker: AtomicWaker::new(), | ||||||
|  |   }); | ||||||
|  | 
 | ||||||
|  |   (Sender { shared: shared.clone() }, Receiver { shared }) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub(super) struct Sender { | ||||||
|  |   shared: Arc<Shared>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub(super) struct Receiver { | ||||||
|  |   shared: Arc<Shared>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | struct Shared { | ||||||
|  |   value: AtomicUsize, | ||||||
|  |   waker: AtomicWaker, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Sender { | ||||||
|  |   pub(super) fn send(&mut self, value: Value) { | ||||||
|  |     if self.shared.value.swap(value, Ordering::SeqCst) != value { | ||||||
|  |       self.shared.waker.wake(); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Drop for Sender { | ||||||
|  |   fn drop(&mut self) { | ||||||
|  |     self.send(CLOSED); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Receiver { | ||||||
|  |   pub(crate) fn load(&mut self, cx: &mut task::Context<'_>) -> Value { | ||||||
|  |     self.shared.waker.register(cx.waker()); | ||||||
|  |     self.shared.value.load(Ordering::SeqCst) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[allow(dead_code)] | ||||||
|  |   pub(crate) fn peek(&self) -> Value { | ||||||
|  |     self.shared.value.load(Ordering::Relaxed) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -1,26 +1,25 @@ | ||||||
| mod backend; | mod backend; | ||||||
| mod certs; |  | ||||||
| mod constants; | mod constants; | ||||||
|  | mod count; | ||||||
|  | mod crypto; | ||||||
| mod error; | mod error; | ||||||
|  | mod forwarder; | ||||||
| mod globals; | mod globals; | ||||||
| mod handler; | mod hyper_ext; | ||||||
| mod log; | mod log; | ||||||
|  | mod message_handler; | ||||||
|  | mod name_exp; | ||||||
| mod proxy; | mod proxy; | ||||||
| mod utils; |  | ||||||
| 
 | 
 | ||||||
| use crate::{ | use crate::{ | ||||||
|   error::*, |   crypto::build_cert_reloader, error::*, forwarder::Forwarder, globals::Globals, log::*, | ||||||
|   globals::Globals, |   message_handler::HttpMessageHandlerBuilder, proxy::Proxy, | ||||||
|   handler::{Forwarder, HttpMessageHandlerBuilder}, |  | ||||||
|   log::*, |  | ||||||
|   proxy::ProxyBuilder, |  | ||||||
| }; | }; | ||||||
| use futures::future::select_all; | use futures::future::select_all; | ||||||
| // use hyper_trust_dns::TrustDnsResolver;
 |  | ||||||
| use std::sync::Arc; | use std::sync::Arc; | ||||||
| 
 | 
 | ||||||
| pub use crate::{ | pub use crate::{ | ||||||
|   certs::{CertsAndKeys, CryptoSource}, |   crypto::{CertsAndKeys, CryptoSource}, | ||||||
|   globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri}, |   globals::{AppConfig, AppConfigList, ProxyConfig, ReverseProxyConfig, TlsConfig, UpstreamUri}, | ||||||
| }; | }; | ||||||
| pub mod reexports { | pub mod reexports { | ||||||
|  | @ -28,19 +27,22 @@ pub mod reexports { | ||||||
|   pub use rustls::{Certificate, PrivateKey}; |   pub use rustls::{Certificate, PrivateKey}; | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] |  | ||||||
| compile_error!("feature \"http3-quinn\" and feature \"http3-s2n\" cannot be enabled at the same time"); |  | ||||||
| 
 |  | ||||||
| /// Entrypoint that creates and spawns tasks of reverse proxy services
 | /// Entrypoint that creates and spawns tasks of reverse proxy services
 | ||||||
| pub async fn entrypoint<T>( | pub async fn entrypoint<T>( | ||||||
|   proxy_config: &ProxyConfig, |   proxy_config: &ProxyConfig, | ||||||
|   app_config_list: &AppConfigList<T>, |   app_config_list: &AppConfigList<T>, | ||||||
|   runtime_handle: &tokio::runtime::Handle, |   runtime_handle: &tokio::runtime::Handle, | ||||||
|   term_notify: Option<Arc<tokio::sync::Notify>>, |   term_notify: Option<Arc<tokio::sync::Notify>>, | ||||||
| ) -> Result<()> | ) -> RpxyResult<()> | ||||||
| where | where | ||||||
|   T: CryptoSource + Clone + Send + Sync + 'static, |   T: CryptoSource + Clone + Send + Sync + 'static, | ||||||
| { | { | ||||||
|  |   #[cfg(all(feature = "http3-quinn", feature = "http3-s2n"))] | ||||||
|  |   warn!("Both \"http3-quinn\" and \"http3-s2n\" features are enabled. \"http3-quinn\" will be used"); | ||||||
|  | 
 | ||||||
|  |   #[cfg(all(feature = "native-tls-backend", feature = "rustls-backend"))] | ||||||
|  |   warn!("Both \"native-tls-backend\" and \"rustls-backend\" features are enabled. \"rustls-backend\" will be used"); | ||||||
|  | 
 | ||||||
|   // For initial message logging
 |   // For initial message logging
 | ||||||
|   if proxy_config.listen_sockets.iter().any(|addr| addr.is_ipv6()) { |   if proxy_config.listen_sockets.iter().any(|addr| addr.is_ipv6()) { | ||||||
|     info!("Listen both IPv4 and IPv6") |     info!("Listen both IPv4 and IPv6") | ||||||
|  | @ -53,6 +55,12 @@ where | ||||||
|   if proxy_config.https_port.is_some() { |   if proxy_config.https_port.is_some() { | ||||||
|     info!("Listen port: {} (for TLS)", proxy_config.https_port.unwrap()); |     info!("Listen port: {} (for TLS)", proxy_config.https_port.unwrap()); | ||||||
|   } |   } | ||||||
|  |   if proxy_config.connection_handling_timeout.is_some() { | ||||||
|  |     info!( | ||||||
|  |       "Force connection handling timeout: {:?} sec", | ||||||
|  |       proxy_config.connection_handling_timeout.unwrap_or_default().as_secs() | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] |   #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||||
|   if proxy_config.http3 { |   if proxy_config.http3 { | ||||||
|     info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable."); |     info!("Experimental HTTP/3.0 is enabled. Note it is still very unstable."); | ||||||
|  | @ -62,52 +70,81 @@ where | ||||||
|   } |   } | ||||||
|   #[cfg(feature = "cache")] |   #[cfg(feature = "cache")] | ||||||
|   if proxy_config.cache_enabled { |   if proxy_config.cache_enabled { | ||||||
|     info!( |     info!("Cache is enabled: cache dir = {:?}", proxy_config.cache_dir.as_ref().unwrap()); | ||||||
|       "Cache is enabled: cache dir = {:?}", |  | ||||||
|       proxy_config.cache_dir.as_ref().unwrap() |  | ||||||
|     ); |  | ||||||
|   } else { |   } else { | ||||||
|     info!("Cache is disabled") |     info!("Cache is disabled") | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   // build global
 |   // 1. build backends, and make it contained in Arc
 | ||||||
|  |   let app_manager = Arc::new(backend::BackendAppManager::try_from(app_config_list)?); | ||||||
|  | 
 | ||||||
|  |   // 2. build crypto reloader service
 | ||||||
|  |   let (cert_reloader_service, cert_reloader_rx) = match proxy_config.https_port { | ||||||
|  |     Some(_) => { | ||||||
|  |       let (s, r) = build_cert_reloader(&app_manager).await?; | ||||||
|  |       (Some(s), Some(r)) | ||||||
|  |     } | ||||||
|  |     None => (None, None), | ||||||
|  |   }; | ||||||
|  | 
 | ||||||
|  |   // 3. build global shared context
 | ||||||
|   let globals = Arc::new(Globals { |   let globals = Arc::new(Globals { | ||||||
|     proxy_config: proxy_config.clone(), |     proxy_config: proxy_config.clone(), | ||||||
|     backends: app_config_list.clone().try_into()?, |  | ||||||
|     request_count: Default::default(), |     request_count: Default::default(), | ||||||
|     runtime_handle: runtime_handle.clone(), |     runtime_handle: runtime_handle.clone(), | ||||||
|  |     term_notify: term_notify.clone(), | ||||||
|  |     cert_reloader_rx: cert_reloader_rx.clone(), | ||||||
|   }); |   }); | ||||||
| 
 | 
 | ||||||
|   // build message handler including a request forwarder
 |   // 4. build message handler containing Arc-ed http_client and backends, and make it contained in Arc as well
 | ||||||
|   let msg_handler = Arc::new( |   let forwarder = Arc::new(Forwarder::try_new(&globals).await?); | ||||||
|  |   let message_handler = Arc::new( | ||||||
|     HttpMessageHandlerBuilder::default() |     HttpMessageHandlerBuilder::default() | ||||||
|       .forwarder(Arc::new(Forwarder::new(&globals).await)) |  | ||||||
|       .globals(globals.clone()) |       .globals(globals.clone()) | ||||||
|  |       .app_manager(app_manager.clone()) | ||||||
|  |       .forwarder(forwarder) | ||||||
|       .build()?, |       .build()?, | ||||||
|   ); |   ); | ||||||
| 
 | 
 | ||||||
|  |   // 5. spawn each proxy for a given socket with copied Arc-ed message_handler.
 | ||||||
|  |   // build hyper connection builder shared with proxy instances
 | ||||||
|  |   let connection_builder = proxy::connection_builder(&globals); | ||||||
|  | 
 | ||||||
|  |   // spawn each proxy for a given socket with copied Arc-ed backend, message_handler and connection builder.
 | ||||||
|   let addresses = globals.proxy_config.listen_sockets.clone(); |   let addresses = globals.proxy_config.listen_sockets.clone(); | ||||||
|   let futures = select_all(addresses.into_iter().map(|addr| { |   let futures_iter = addresses.into_iter().map(|listening_on| { | ||||||
|     let mut tls_enabled = false; |     let mut tls_enabled = false; | ||||||
|     if let Some(https_port) = globals.proxy_config.https_port { |     if let Some(https_port) = globals.proxy_config.https_port { | ||||||
|       tls_enabled = https_port == addr.port() |       tls_enabled = https_port == listening_on.port() | ||||||
|     } |     } | ||||||
| 
 |     let proxy = Proxy { | ||||||
|     let proxy = ProxyBuilder::default() |       globals: globals.clone(), | ||||||
|       .globals(globals.clone()) |       listening_on, | ||||||
|       .listening_on(addr) |       tls_enabled, | ||||||
|       .tls_enabled(tls_enabled) |       connection_builder: connection_builder.clone(), | ||||||
|       .msg_handler(msg_handler.clone()) |       message_handler: message_handler.clone(), | ||||||
|       .build() |     }; | ||||||
|       .unwrap(); |     globals.runtime_handle.spawn(async move { proxy.start().await }) | ||||||
| 
 |   }); | ||||||
|     globals.runtime_handle.spawn(proxy.start(term_notify.clone())) |  | ||||||
|   })); |  | ||||||
| 
 | 
 | ||||||
|   // wait for all future
 |   // wait for all future
 | ||||||
|   if let (Ok(Err(e)), _, _) = futures.await { |   match cert_reloader_service { | ||||||
|     error!("Some proxy services are down: {:?}", e); |     Some(cert_service) => { | ||||||
|   }; |       tokio::select! { | ||||||
|  |         _ = cert_service.start() => { | ||||||
|  |           error!("Certificate reloader service got down"); | ||||||
|  |         } | ||||||
|  |         _ = select_all(futures_iter) => { | ||||||
|  |           error!("Some proxy services are down"); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     None => { | ||||||
|  |       if let (Ok(Err(e)), _, _) = select_all(futures_iter).await { | ||||||
|  |         error!("Some proxy services are down: {}", e); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|   Ok(()) |   Ok(()) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,98 +1 @@ | ||||||
| use crate::utils::ToCanonical; |  | ||||||
| use hyper::header; |  | ||||||
| use std::net::SocketAddr; |  | ||||||
| pub use tracing::{debug, error, info, warn}; | pub use tracing::{debug, error, info, warn}; | ||||||
| 
 |  | ||||||
| #[derive(Debug, Clone)] |  | ||||||
| pub struct MessageLog { |  | ||||||
|   // pub tls_server_name: String,
 |  | ||||||
|   pub client_addr: String, |  | ||||||
|   pub method: String, |  | ||||||
|   pub host: String, |  | ||||||
|   pub p_and_q: String, |  | ||||||
|   pub version: hyper::Version, |  | ||||||
|   pub uri_scheme: String, |  | ||||||
|   pub uri_host: String, |  | ||||||
|   pub ua: String, |  | ||||||
|   pub xff: String, |  | ||||||
|   pub status: String, |  | ||||||
|   pub upstream: String, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<T> From<&hyper::Request<T>> for MessageLog { |  | ||||||
|   fn from(req: &hyper::Request<T>) -> Self { |  | ||||||
|     let header_mapper = |v: header::HeaderName| { |  | ||||||
|       req |  | ||||||
|         .headers() |  | ||||||
|         .get(v) |  | ||||||
|         .map_or_else(|| "", |s| s.to_str().unwrap_or("")) |  | ||||||
|         .to_string() |  | ||||||
|     }; |  | ||||||
|     Self { |  | ||||||
|       // tls_server_name: "".to_string(),
 |  | ||||||
|       client_addr: "".to_string(), |  | ||||||
|       method: req.method().to_string(), |  | ||||||
|       host: header_mapper(header::HOST), |  | ||||||
|       p_and_q: req |  | ||||||
|         .uri() |  | ||||||
|         .path_and_query() |  | ||||||
|         .map_or_else(|| "", |v| v.as_str()) |  | ||||||
|         .to_string(), |  | ||||||
|       version: req.version(), |  | ||||||
|       uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(), |  | ||||||
|       uri_host: req.uri().host().unwrap_or("").to_string(), |  | ||||||
|       ua: header_mapper(header::USER_AGENT), |  | ||||||
|       xff: header_mapper(header::HeaderName::from_static("x-forwarded-for")), |  | ||||||
|       status: "".to_string(), |  | ||||||
|       upstream: "".to_string(), |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl MessageLog { |  | ||||||
|   pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self { |  | ||||||
|     self.client_addr = client_addr.to_canonical().to_string(); |  | ||||||
|     self |  | ||||||
|   } |  | ||||||
|   // pub fn tls_server_name(&mut self, tls_server_name: &str) -> &mut Self {
 |  | ||||||
|   //   self.tls_server_name = tls_server_name.to_string();
 |  | ||||||
|   //   self
 |  | ||||||
|   // }
 |  | ||||||
|   pub fn status_code(&mut self, status_code: &hyper::StatusCode) -> &mut Self { |  | ||||||
|     self.status = status_code.to_string(); |  | ||||||
|     self |  | ||||||
|   } |  | ||||||
|   pub fn xff(&mut self, xff: &Option<&header::HeaderValue>) -> &mut Self { |  | ||||||
|     self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); |  | ||||||
|     self |  | ||||||
|   } |  | ||||||
|   pub fn upstream(&mut self, upstream: &hyper::Uri) -> &mut Self { |  | ||||||
|     self.upstream = upstream.to_string(); |  | ||||||
|     self |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   pub fn output(&self) { |  | ||||||
|     info!( |  | ||||||
|       "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"", |  | ||||||
|       if !self.host.is_empty() { |  | ||||||
|         self.host.as_str() |  | ||||||
|       } else { |  | ||||||
|         self.uri_host.as_str() |  | ||||||
|       }, |  | ||||||
|       self.client_addr, |  | ||||||
|       self.method, |  | ||||||
|       self.p_and_q, |  | ||||||
|       self.version, |  | ||||||
|       self.status, |  | ||||||
|       if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() { |  | ||||||
|         format!("{}://{}", self.uri_scheme, self.uri_host) |  | ||||||
|       } else { |  | ||||||
|         "".to_string() |  | ||||||
|       }, |  | ||||||
|       self.ua, |  | ||||||
|       self.xff, |  | ||||||
|       self.upstream, |  | ||||||
|       // self.tls_server_name
 |  | ||||||
|     ); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  |  | ||||||
|  | @ -1,5 +1,6 @@ | ||||||
| use std::net::{IpAddr, Ipv4Addr, SocketAddr}; | use std::net::{IpAddr, Ipv4Addr, SocketAddr}; | ||||||
| 
 | 
 | ||||||
|  | /// Trait to convert an IP address to its canonical form
 | ||||||
| pub trait ToCanonical { | pub trait ToCanonical { | ||||||
|   fn to_canonical(&self) -> Self; |   fn to_canonical(&self) -> Self; | ||||||
| } | } | ||||||
							
								
								
									
										248
									
								
								rpxy-lib/src/message_handler/handler_main.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										248
									
								
								rpxy-lib/src/message_handler/handler_main.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,248 @@ | ||||||
|  | use super::{ | ||||||
|  |   http_log::HttpMessageLog, | ||||||
|  |   http_result::{HttpError, HttpResult}, | ||||||
|  |   synthetic_response::{secure_redirection_response, synthetic_error_response}, | ||||||
|  |   utils_headers::*, | ||||||
|  |   utils_request::InspectParseHost, | ||||||
|  | }; | ||||||
|  | use crate::{ | ||||||
|  |   backend::{BackendAppManager, LoadBalanceContext}, | ||||||
|  |   crypto::CryptoSource, | ||||||
|  |   error::*, | ||||||
|  |   forwarder::{ForwardRequest, Forwarder}, | ||||||
|  |   globals::Globals, | ||||||
|  |   hyper_ext::body::{RequestBody, ResponseBody}, | ||||||
|  |   log::*, | ||||||
|  |   name_exp::ServerName, | ||||||
|  | }; | ||||||
|  | use derive_builder::Builder; | ||||||
|  | use http::{Request, Response, StatusCode}; | ||||||
|  | use hyper_util::{client::legacy::connect::Connect, rt::TokioIo}; | ||||||
|  | use std::{net::SocketAddr, sync::Arc}; | ||||||
|  | use tokio::io::copy_bidirectional; | ||||||
|  | 
 | ||||||
|  | #[allow(dead_code)] | ||||||
|  | #[derive(Debug)] | ||||||
|  | /// Context object to handle sticky cookies at HTTP message handler
 | ||||||
|  | pub(super) struct HandlerContext { | ||||||
|  |   #[cfg(feature = "sticky-cookie")] | ||||||
|  |   pub(super) context_lb: Option<LoadBalanceContext>, | ||||||
|  |   #[cfg(not(feature = "sticky-cookie"))] | ||||||
|  |   pub(super) context_lb: Option<()>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Clone, Builder)] | ||||||
|  | /// HTTP message handler for requests from clients and responses from backend applications,
 | ||||||
|  | /// responsible to manipulate and forward messages to upstream backends and downstream clients.
 | ||||||
|  | pub struct HttpMessageHandler<U, C> | ||||||
|  | where | ||||||
|  |   C: Send + Sync + Connect + Clone + 'static, | ||||||
|  |   U: CryptoSource + Clone, | ||||||
|  | { | ||||||
|  |   forwarder: Arc<Forwarder<C>>, | ||||||
|  |   pub(super) globals: Arc<Globals>, | ||||||
|  |   app_manager: Arc<BackendAppManager<U>>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<U, C> HttpMessageHandler<U, C> | ||||||
|  | where | ||||||
|  |   C: Send + Sync + Connect + Clone + 'static, | ||||||
|  |   U: CryptoSource + Clone, | ||||||
|  | { | ||||||
|  |   /// Handle incoming request message from a client.
 | ||||||
|  |   /// Responsible to passthrough responses from backend applications or generate synthetic error responses.
 | ||||||
|  |   pub async fn handle_request( | ||||||
|  |     &self, | ||||||
|  |     req: Request<RequestBody>, | ||||||
|  |     client_addr: SocketAddr, // For access control
 | ||||||
|  |     listen_addr: SocketAddr, | ||||||
|  |     tls_enabled: bool, | ||||||
|  |     tls_server_name: Option<ServerName>, | ||||||
|  |   ) -> RpxyResult<Response<ResponseBody>> { | ||||||
|  |     // preparing log data
 | ||||||
|  |     let mut log_data = HttpMessageLog::from(&req); | ||||||
|  |     log_data.client_addr(&client_addr); | ||||||
|  | 
 | ||||||
|  |     let http_result = self | ||||||
|  |       .handle_request_inner( | ||||||
|  |         &mut log_data, | ||||||
|  |         req, | ||||||
|  |         client_addr, | ||||||
|  |         listen_addr, | ||||||
|  |         tls_enabled, | ||||||
|  |         tls_server_name, | ||||||
|  |       ) | ||||||
|  |       .await; | ||||||
|  | 
 | ||||||
|  |     // passthrough or synthetic response
 | ||||||
|  |     match http_result { | ||||||
|  |       Ok(v) => { | ||||||
|  |         log_data.status_code(&v.status()).output(); | ||||||
|  |         Ok(v) | ||||||
|  |       } | ||||||
|  |       Err(e) => { | ||||||
|  |         error!("{e}"); | ||||||
|  |         let code = StatusCode::from(e); | ||||||
|  |         log_data.status_code(&code).output(); | ||||||
|  |         synthetic_error_response(code) | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Handle inner with no synthetic error response.
 | ||||||
|  |   /// Synthetic response is generated by caller.
 | ||||||
|  |   async fn handle_request_inner( | ||||||
|  |     &self, | ||||||
|  |     log_data: &mut HttpMessageLog, | ||||||
|  |     mut req: Request<RequestBody>, | ||||||
|  |     client_addr: SocketAddr, // For access control
 | ||||||
|  |     listen_addr: SocketAddr, | ||||||
|  |     tls_enabled: bool, | ||||||
|  |     tls_server_name: Option<ServerName>, | ||||||
|  |   ) -> HttpResult<Response<ResponseBody>> { | ||||||
|  |     // Here we start to inspect and parse with server_name
 | ||||||
|  |     let server_name = req | ||||||
|  |       .inspect_parse_host() | ||||||
|  |       .map(|v| ServerName::from(v.as_slice())) | ||||||
|  |       .map_err(|_e| HttpError::InvalidHostInRequestHeader)?; | ||||||
|  | 
 | ||||||
|  |     // check consistency of between TLS SNI and HOST/Request URI Line.
 | ||||||
|  |     #[allow(clippy::collapsible_if)] | ||||||
|  |     if tls_enabled && self.globals.proxy_config.sni_consistency { | ||||||
|  |       if server_name != tls_server_name.unwrap_or_default() { | ||||||
|  |         return Err(HttpError::SniHostInconsistency); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     // Find backend application for given server_name, and drop if incoming request is invalid as request.
 | ||||||
|  |     let backend_app = match self.app_manager.apps.get(&server_name) { | ||||||
|  |       Some(backend_app) => backend_app, | ||||||
|  |       None => { | ||||||
|  |         let Some(default_server_name) = &self.app_manager.default_server_name else { | ||||||
|  |           return Err(HttpError::NoMatchingBackendApp); | ||||||
|  |         }; | ||||||
|  |         debug!("Serving by default app"); | ||||||
|  |         self.app_manager.apps.get(default_server_name).unwrap() | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     // Redirect to https if !tls_enabled and redirect_to_https is true
 | ||||||
|  |     if !tls_enabled && backend_app.https_redirection.unwrap_or(false) { | ||||||
|  |       debug!( | ||||||
|  |         "Redirect to secure connection: {}", | ||||||
|  |         <&ServerName as TryInto<String>>::try_into(&backend_app.server_name).unwrap_or_default() | ||||||
|  |       ); | ||||||
|  |       return secure_redirection_response(&backend_app.server_name, self.globals.proxy_config.https_port, &req); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Find reverse proxy for given path and choose one of upstream host
 | ||||||
|  |     // Longest prefix match
 | ||||||
|  |     let path = req.uri().path(); | ||||||
|  |     let Some(upstream_candidates) = backend_app.path_manager.get(path) else { | ||||||
|  |       return Err(HttpError::NoUpstreamCandidates); | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     // Upgrade in request header
 | ||||||
|  |     let upgrade_in_request = extract_upgrade(req.headers()); | ||||||
|  |     if upgrade_in_request.is_some() && req.version() != http::Version::HTTP_11 { | ||||||
|  |       return Err(HttpError::FailedToUpgrade(format!( | ||||||
|  |         "Unsupported HTTP version: {:?}", | ||||||
|  |         req.version() | ||||||
|  |       ))); | ||||||
|  |     } | ||||||
|  |     // let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();
 | ||||||
|  |     let req_on_upgrade = hyper::upgrade::on(&mut req); | ||||||
|  | 
 | ||||||
|  |     // Build request from destination information
 | ||||||
|  |     let _context = match self.generate_request_forwarded( | ||||||
|  |       &client_addr, | ||||||
|  |       &listen_addr, | ||||||
|  |       &mut req, | ||||||
|  |       &upgrade_in_request, | ||||||
|  |       upstream_candidates, | ||||||
|  |       tls_enabled, | ||||||
|  |     ) { | ||||||
|  |       Err(e) => { | ||||||
|  |         return Err(HttpError::FailedToGenerateUpstreamRequest(e.to_string())); | ||||||
|  |       } | ||||||
|  |       Ok(v) => v, | ||||||
|  |     }; | ||||||
|  |     debug!( | ||||||
|  |       "Request to be forwarded: [uri {}, method: {}, version {:?}, headers {:?}]", | ||||||
|  |       req.uri(), | ||||||
|  |       req.method(), | ||||||
|  |       req.version(), | ||||||
|  |       req.headers() | ||||||
|  |     ); | ||||||
|  |     log_data.xff(&req.headers().get("x-forwarded-for")); | ||||||
|  |     log_data.upstream(req.uri()); | ||||||
|  |     //////
 | ||||||
|  | 
 | ||||||
|  |     //////////////
 | ||||||
|  |     // Forward request to a chosen backend
 | ||||||
|  |     let mut res_backend = match self.forwarder.request(req).await { | ||||||
|  |       Ok(v) => v, | ||||||
|  |       Err(e) => { | ||||||
|  |         return Err(HttpError::FailedToGetResponseFromBackend(e.to_string())); | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  |     //////////////
 | ||||||
|  |     // Process reverse proxy context generated during the forwarding request generation.
 | ||||||
|  |     #[cfg(feature = "sticky-cookie")] | ||||||
|  |     if let Some(context_from_lb) = _context.context_lb { | ||||||
|  |       let res_headers = res_backend.headers_mut(); | ||||||
|  |       if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) { | ||||||
|  |         return Err(HttpError::FailedToAddSetCookeInResponse(e.to_string())); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS { | ||||||
|  |       // Generate response to client
 | ||||||
|  |       if let Err(e) = self.generate_response_forwarded(&mut res_backend, backend_app) { | ||||||
|  |         return Err(HttpError::FailedToGenerateDownstreamResponse(e.to_string())); | ||||||
|  |       } | ||||||
|  |       return Ok(res_backend); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // Handle StatusCode::SWITCHING_PROTOCOLS in response
 | ||||||
|  |     let upgrade_in_response = extract_upgrade(res_backend.headers()); | ||||||
|  |     let should_upgrade = match (upgrade_in_request.as_ref(), upgrade_in_response.as_ref()) { | ||||||
|  |       (Some(u_req), Some(u_res)) => u_req.to_ascii_lowercase() == u_res.to_ascii_lowercase(), | ||||||
|  |       _ => false, | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     if !should_upgrade { | ||||||
|  |       return Err(HttpError::FailedToUpgrade(format!( | ||||||
|  |         "Backend tried to switch to protocol {:?} when {:?} was requested", | ||||||
|  |         upgrade_in_response, upgrade_in_request | ||||||
|  |       ))); | ||||||
|  |     } | ||||||
|  |     // let Some(request_upgraded) = request_upgraded else {
 | ||||||
|  |     //   return Err(HttpError::NoUpgradeExtensionInRequest);
 | ||||||
|  |     // };
 | ||||||
|  | 
 | ||||||
|  |     // let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else {
 | ||||||
|  |     //   return Err(HttpError::NoUpgradeExtensionInResponse);
 | ||||||
|  |     // };
 | ||||||
|  |     let res_on_upgrade = hyper::upgrade::on(&mut res_backend); | ||||||
|  | 
 | ||||||
|  |     self.globals.runtime_handle.spawn(async move { | ||||||
|  |       let mut response_upgraded = TokioIo::new(res_on_upgrade.await.map_err(|e| { | ||||||
|  |         error!("Failed to upgrade response: {}", e); | ||||||
|  |         RpxyError::FailedToUpgradeResponse(e.to_string()) | ||||||
|  |       })?); | ||||||
|  |       let mut request_upgraded = TokioIo::new(req_on_upgrade.await.map_err(|e| { | ||||||
|  |         error!("Failed to upgrade request: {}", e); | ||||||
|  |         RpxyError::FailedToUpgradeRequest(e.to_string()) | ||||||
|  |       })?); | ||||||
|  |       copy_bidirectional(&mut response_upgraded, &mut request_upgraded) | ||||||
|  |         .await | ||||||
|  |         .map_err(|e| { | ||||||
|  |           error!("Coping between upgraded connections failed: {}", e); | ||||||
|  |           RpxyError::FailedToCopyBidirectional(e.to_string()) | ||||||
|  |         })?; | ||||||
|  |       Ok(()) as RpxyResult<()> | ||||||
|  |     }); | ||||||
|  | 
 | ||||||
|  |     Ok(res_backend) | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										185
									
								
								rpxy-lib/src/message_handler/handler_manipulate_messages.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										185
									
								
								rpxy-lib/src/message_handler/handler_manipulate_messages.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,185 @@ | ||||||
|  | use super::{handler_main::HandlerContext, utils_headers::*, utils_request::update_request_line, HttpMessageHandler}; | ||||||
|  | use crate::{ | ||||||
|  |   backend::{BackendApp, UpstreamCandidates}, | ||||||
|  |   constants::RESPONSE_HEADER_SERVER, | ||||||
|  |   log::*, | ||||||
|  |   CryptoSource, | ||||||
|  | }; | ||||||
|  | use anyhow::{anyhow, ensure, Result}; | ||||||
|  | use http::{header, HeaderValue, Request, Response, Uri}; | ||||||
|  | use hyper_util::client::legacy::connect::Connect; | ||||||
|  | use std::net::SocketAddr; | ||||||
|  | 
 | ||||||
|  | impl<U, C> HttpMessageHandler<U, C> | ||||||
|  | where | ||||||
|  |   C: Send + Sync + Connect + Clone + 'static, | ||||||
|  |   U: CryptoSource + Clone, | ||||||
|  | { | ||||||
|  |   ////////////////////////////////////////////////////
 | ||||||
|  |   // Functions to generate messages
 | ||||||
|  |   ////////////////////////////////////////////////////
 | ||||||
|  | 
 | ||||||
|  |   #[allow(unused_variables)] | ||||||
|  |   /// Manipulate a response message sent from a backend application to forward downstream to a client.
 | ||||||
|  |   pub(super) fn generate_response_forwarded<B>( | ||||||
|  |     &self, | ||||||
|  |     response: &mut Response<B>, | ||||||
|  |     backend_app: &BackendApp<U>, | ||||||
|  |   ) -> Result<()> { | ||||||
|  |     let headers = response.headers_mut(); | ||||||
|  |     remove_connection_header(headers); | ||||||
|  |     remove_hop_header(headers); | ||||||
|  |     add_header_entry_overwrite_if_exist(headers, "server", RESPONSE_HEADER_SERVER)?; | ||||||
|  | 
 | ||||||
|  |     #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||||
|  |     { | ||||||
|  |       // Manipulate ALT_SVC allowing h3 in response message only when mutual TLS is not enabled
 | ||||||
|  |       // TODO: This is a workaround for avoiding a client authentication in HTTP/3
 | ||||||
|  |       if self.globals.proxy_config.http3 && backend_app.crypto_source.as_ref().is_some_and(|v| !v.is_mutual_tls()) { | ||||||
|  |         if let Some(port) = self.globals.proxy_config.https_port { | ||||||
|  |           add_header_entry_overwrite_if_exist( | ||||||
|  |             headers, | ||||||
|  |             header::ALT_SVC.as_str(), | ||||||
|  |             format!( | ||||||
|  |               "h3=\":{}\"; ma={}, h3-29=\":{}\"; ma={}", | ||||||
|  |               port, self.globals.proxy_config.h3_alt_svc_max_age, port, self.globals.proxy_config.h3_alt_svc_max_age | ||||||
|  |             ), | ||||||
|  |           )?; | ||||||
|  |         } | ||||||
|  |       } else { | ||||||
|  |         // remove alt-svc to disallow requests via http3
 | ||||||
|  |         headers.remove(header::ALT_SVC.as_str()); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] | ||||||
|  |     { | ||||||
|  |       if self.globals.proxy_config.https_port.is_some() { | ||||||
|  |         headers.remove(header::ALT_SVC.as_str()); | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     Ok(()) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[allow(clippy::too_many_arguments)] | ||||||
|  |   /// Manipulate a request message sent from a client to forward upstream to a backend application
 | ||||||
|  |   pub(super) fn generate_request_forwarded<B>( | ||||||
|  |     &self, | ||||||
|  |     client_addr: &SocketAddr, | ||||||
|  |     listen_addr: &SocketAddr, | ||||||
|  |     req: &mut Request<B>, | ||||||
|  |     upgrade: &Option<String>, | ||||||
|  |     upstream_candidates: &UpstreamCandidates, | ||||||
|  |     tls_enabled: bool, | ||||||
|  |   ) -> Result<HandlerContext> { | ||||||
|  |     debug!("Generate request to be forwarded"); | ||||||
|  | 
 | ||||||
|  |     // Add te: trailer if contained in original request
 | ||||||
|  |     let contains_te_trailers = { | ||||||
|  |       if let Some(te) = req.headers().get(header::TE) { | ||||||
|  |         te.as_bytes() | ||||||
|  |           .split(|v| v == &b',' || v == &b' ') | ||||||
|  |           .any(|x| x == "trailers".as_bytes()) | ||||||
|  |       } else { | ||||||
|  |         false | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     let original_uri = req.uri().to_string(); | ||||||
|  |     let headers = req.headers_mut(); | ||||||
|  |     // delete headers specified in header.connection
 | ||||||
|  |     remove_connection_header(headers); | ||||||
|  |     // delete hop headers including header.connection
 | ||||||
|  |     remove_hop_header(headers); | ||||||
|  |     // X-Forwarded-For
 | ||||||
|  |     add_forwarding_header(headers, client_addr, listen_addr, tls_enabled, &original_uri)?; | ||||||
|  | 
 | ||||||
|  |     // Add te: trailer if te_trailer
 | ||||||
|  |     if contains_te_trailers { | ||||||
|  |       headers.insert(header::TE, HeaderValue::from_bytes("trailers".as_bytes()).unwrap()); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     // by default, add "host" header of original server_name if not exist
 | ||||||
|  |     if req.headers().get(header::HOST).is_none() { | ||||||
|  |       let org_host = req.uri().host().ok_or_else(|| anyhow!("Invalid request"))?.to_owned(); | ||||||
|  |       req | ||||||
|  |         .headers_mut() | ||||||
|  |         .insert(header::HOST, HeaderValue::from_str(&org_host)?); | ||||||
|  |     }; | ||||||
|  |     println!("{:?}", req.headers().get(header::HOST)); | ||||||
|  | 
 | ||||||
|  |     /////////////////////////////////////////////
 | ||||||
|  |     // Fix unique upstream destination since there could be multiple ones.
 | ||||||
|  |     #[cfg(feature = "sticky-cookie")] | ||||||
|  |     let (upstream_chosen_opt, context_from_lb) = { | ||||||
|  |       let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_candidates.load_balance { | ||||||
|  |         takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)? | ||||||
|  |       } else { | ||||||
|  |         None | ||||||
|  |       }; | ||||||
|  |       upstream_candidates.get(&context_to_lb) | ||||||
|  |     }; | ||||||
|  |     #[cfg(not(feature = "sticky-cookie"))] | ||||||
|  |     let (upstream_chosen_opt, _) = upstream_candidates.get(&None); | ||||||
|  | 
 | ||||||
|  |     let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?; | ||||||
|  |     let context = HandlerContext { | ||||||
|  |       #[cfg(feature = "sticky-cookie")] | ||||||
|  |       context_lb: context_from_lb, | ||||||
|  |       #[cfg(not(feature = "sticky-cookie"))] | ||||||
|  |       context_lb: None, | ||||||
|  |     }; | ||||||
|  |     /////////////////////////////////////////////
 | ||||||
|  | 
 | ||||||
|  |     // apply upstream-specific headers given in upstream_option
 | ||||||
|  |     let headers = req.headers_mut(); | ||||||
|  |     // apply upstream options to header
 | ||||||
|  |     apply_upstream_options_to_header(headers, &upstream_chosen.uri, upstream_candidates)?; | ||||||
|  | 
 | ||||||
|  |     // update uri in request
 | ||||||
|  |     ensure!( | ||||||
|  |       upstream_chosen.uri.authority().is_some() && upstream_chosen.uri.scheme().is_some(), | ||||||
|  |       "Upstream uri `scheme` and `authority` is broken" | ||||||
|  |     ); | ||||||
|  | 
 | ||||||
|  |     let new_uri = Uri::builder() | ||||||
|  |       .scheme(upstream_chosen.uri.scheme().unwrap().as_str()) | ||||||
|  |       .authority(upstream_chosen.uri.authority().unwrap().as_str()); | ||||||
|  |     let org_pq = match req.uri().path_and_query() { | ||||||
|  |       Some(pq) => pq.to_string(), | ||||||
|  |       None => "/".to_string(), | ||||||
|  |     } | ||||||
|  |     .into_bytes(); | ||||||
|  | 
 | ||||||
|  |     // replace some parts of path if opt_replace_path is enabled for chosen upstream
 | ||||||
|  |     let new_pq = match &upstream_candidates.replace_path { | ||||||
|  |       Some(new_path) => { | ||||||
|  |         let matched_path: &[u8] = upstream_candidates.path.as_ref(); | ||||||
|  |         ensure!( | ||||||
|  |           !matched_path.is_empty() && org_pq.len() >= matched_path.len(), | ||||||
|  |           "Upstream uri `path and query` is broken" | ||||||
|  |         ); | ||||||
|  |         let mut new_pq = Vec::<u8>::with_capacity(org_pq.len() - matched_path.len() + new_path.len()); | ||||||
|  |         new_pq.extend_from_slice(new_path.as_ref()); | ||||||
|  |         new_pq.extend_from_slice(&org_pq[matched_path.len()..]); | ||||||
|  |         new_pq | ||||||
|  |       } | ||||||
|  |       None => org_pq, | ||||||
|  |     }; | ||||||
|  |     *req.uri_mut() = new_uri.path_and_query(new_pq).build()?; | ||||||
|  | 
 | ||||||
|  |     // upgrade
 | ||||||
|  |     if let Some(v) = upgrade { | ||||||
|  |       req.headers_mut().insert(header::UPGRADE, v.parse()?); | ||||||
|  |       req | ||||||
|  |         .headers_mut() | ||||||
|  |         .insert(header::CONNECTION, HeaderValue::from_static("upgrade")); | ||||||
|  |     } | ||||||
|  |     if upgrade.is_none() { | ||||||
|  |       // can update request line i.e., http version, only if not upgrade (http 1.1)
 | ||||||
|  |       update_request_line(req, upstream_chosen, upstream_candidates)?; | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     Ok(context) | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										99
									
								
								rpxy-lib/src/message_handler/http_log.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										99
									
								
								rpxy-lib/src/message_handler/http_log.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,99 @@ | ||||||
|  | use super::canonical_address::ToCanonical; | ||||||
|  | use crate::log::*; | ||||||
|  | use http::header; | ||||||
|  | use std::net::SocketAddr; | ||||||
|  | 
 | ||||||
|  | /// Struct to log HTTP messages
 | ||||||
|  | #[derive(Debug, Clone)] | ||||||
|  | pub struct HttpMessageLog { | ||||||
|  |   // pub tls_server_name: String,
 | ||||||
|  |   pub client_addr: String, | ||||||
|  |   pub method: String, | ||||||
|  |   pub host: String, | ||||||
|  |   pub p_and_q: String, | ||||||
|  |   pub version: http::Version, | ||||||
|  |   pub uri_scheme: String, | ||||||
|  |   pub uri_host: String, | ||||||
|  |   pub ua: String, | ||||||
|  |   pub xff: String, | ||||||
|  |   pub status: String, | ||||||
|  |   pub upstream: String, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<T> From<&http::Request<T>> for HttpMessageLog { | ||||||
|  |   fn from(req: &http::Request<T>) -> Self { | ||||||
|  |     let header_mapper = |v: header::HeaderName| { | ||||||
|  |       req | ||||||
|  |         .headers() | ||||||
|  |         .get(v) | ||||||
|  |         .map_or_else(|| "", |s| s.to_str().unwrap_or("")) | ||||||
|  |         .to_string() | ||||||
|  |     }; | ||||||
|  |     Self { | ||||||
|  |       // tls_server_name: "".to_string(),
 | ||||||
|  |       client_addr: "".to_string(), | ||||||
|  |       method: req.method().to_string(), | ||||||
|  |       host: header_mapper(header::HOST), | ||||||
|  |       p_and_q: req | ||||||
|  |         .uri() | ||||||
|  |         .path_and_query() | ||||||
|  |         .map_or_else(|| "", |v| v.as_str()) | ||||||
|  |         .to_string(), | ||||||
|  |       version: req.version(), | ||||||
|  |       uri_scheme: req.uri().scheme_str().unwrap_or("").to_string(), | ||||||
|  |       uri_host: req.uri().host().unwrap_or("").to_string(), | ||||||
|  |       ua: header_mapper(header::USER_AGENT), | ||||||
|  |       xff: header_mapper(header::HeaderName::from_static("x-forwarded-for")), | ||||||
|  |       status: "".to_string(), | ||||||
|  |       upstream: "".to_string(), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl HttpMessageLog { | ||||||
|  |   pub fn client_addr(&mut self, client_addr: &SocketAddr) -> &mut Self { | ||||||
|  |     self.client_addr = client_addr.to_canonical().to_string(); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  |   // pub fn tls_server_name(&mut self, tls_server_name: &str) -> &mut Self {
 | ||||||
|  |   //   self.tls_server_name = tls_server_name.to_string();
 | ||||||
|  |   //   self
 | ||||||
|  |   // }
 | ||||||
|  |   pub fn status_code(&mut self, status_code: &http::StatusCode) -> &mut Self { | ||||||
|  |     self.status = status_code.to_string(); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  |   pub fn xff(&mut self, xff: &Option<&header::HeaderValue>) -> &mut Self { | ||||||
|  |     self.xff = xff.map_or_else(|| "", |v| v.to_str().unwrap_or("")).to_string(); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  |   pub fn upstream(&mut self, upstream: &http::Uri) -> &mut Self { | ||||||
|  |     self.upstream = upstream.to_string(); | ||||||
|  |     self | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   pub fn output(&self) { | ||||||
|  |     info!( | ||||||
|  |       "{} <- {} -- {} {} {:?} -- {} -- {} \"{}\", \"{}\" \"{}\"", | ||||||
|  |       if !self.host.is_empty() { | ||||||
|  |         self.host.as_str() | ||||||
|  |       } else { | ||||||
|  |         self.uri_host.as_str() | ||||||
|  |       }, | ||||||
|  |       self.client_addr, | ||||||
|  |       self.method, | ||||||
|  |       self.p_and_q, | ||||||
|  |       self.version, | ||||||
|  |       self.status, | ||||||
|  |       if !self.uri_scheme.is_empty() && !self.uri_host.is_empty() { | ||||||
|  |         format!("{}://{}", self.uri_scheme, self.uri_host) | ||||||
|  |       } else { | ||||||
|  |         "".to_string() | ||||||
|  |       }, | ||||||
|  |       self.ua, | ||||||
|  |       self.xff, | ||||||
|  |       self.upstream, | ||||||
|  |       // self.tls_server_name
 | ||||||
|  |     ); | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										61
									
								
								rpxy-lib/src/message_handler/http_result.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								rpxy-lib/src/message_handler/http_result.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,61 @@ | ||||||
|  | use http::StatusCode; | ||||||
|  | use thiserror::Error; | ||||||
|  | 
 | ||||||
|  | /// HTTP result type, T is typically a hyper::Response
 | ||||||
|  | /// HttpError is used to generate a synthetic error response
 | ||||||
|  | pub(crate) type HttpResult<T> = std::result::Result<T, HttpError>; | ||||||
|  | 
 | ||||||
|  | /// Describes things that can go wrong in the forwarder
 | ||||||
|  | #[derive(Debug, Error)] | ||||||
|  | pub enum HttpError { | ||||||
|  |   // #[error("No host is give in request header")]
 | ||||||
|  |   // NoHostInRequestHeader,
 | ||||||
|  |   #[error("Invalid host in request header")] | ||||||
|  |   InvalidHostInRequestHeader, | ||||||
|  |   #[error("SNI and Host header mismatch")] | ||||||
|  |   SniHostInconsistency, | ||||||
|  |   #[error("No matching backend app")] | ||||||
|  |   NoMatchingBackendApp, | ||||||
|  |   #[error("Failed to redirect: {0}")] | ||||||
|  |   FailedToRedirect(String), | ||||||
|  |   #[error("No upstream candidates")] | ||||||
|  |   NoUpstreamCandidates, | ||||||
|  |   #[error("Failed to generate upstream request for backend application: {0}")] | ||||||
|  |   FailedToGenerateUpstreamRequest(String), | ||||||
|  |   #[error("Failed to get response from backend: {0}")] | ||||||
|  |   FailedToGetResponseFromBackend(String), | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to add set-cookie header in response {0}")] | ||||||
|  |   FailedToAddSetCookeInResponse(String), | ||||||
|  |   #[error("Failed to generated downstream response for clients: {0}")] | ||||||
|  |   FailedToGenerateDownstreamResponse(String), | ||||||
|  | 
 | ||||||
|  |   #[error("Failed to upgrade connection: {0}")] | ||||||
|  |   FailedToUpgrade(String), | ||||||
|  |   // #[error("Request does not have an upgrade extension")]
 | ||||||
|  |   // NoUpgradeExtensionInRequest,
 | ||||||
|  |   // #[error("Response does not have an upgrade extension")]
 | ||||||
|  |   // NoUpgradeExtensionInResponse,
 | ||||||
|  |   #[error(transparent)] | ||||||
|  |   Other(#[from] anyhow::Error), | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl From<HttpError> for StatusCode { | ||||||
|  |   fn from(e: HttpError) -> StatusCode { | ||||||
|  |     match e { | ||||||
|  |       // HttpError::NoHostInRequestHeader => StatusCode::BAD_REQUEST,
 | ||||||
|  |       HttpError::InvalidHostInRequestHeader => StatusCode::BAD_REQUEST, | ||||||
|  |       HttpError::SniHostInconsistency => StatusCode::MISDIRECTED_REQUEST, | ||||||
|  |       HttpError::NoMatchingBackendApp => StatusCode::SERVICE_UNAVAILABLE, | ||||||
|  |       HttpError::FailedToRedirect(_) => StatusCode::INTERNAL_SERVER_ERROR, | ||||||
|  |       HttpError::NoUpstreamCandidates => StatusCode::NOT_FOUND, | ||||||
|  |       HttpError::FailedToGenerateUpstreamRequest(_) => StatusCode::INTERNAL_SERVER_ERROR, | ||||||
|  |       HttpError::FailedToAddSetCookeInResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, | ||||||
|  |       HttpError::FailedToGenerateDownstreamResponse(_) => StatusCode::INTERNAL_SERVER_ERROR, | ||||||
|  |       HttpError::FailedToUpgrade(_) => StatusCode::INTERNAL_SERVER_ERROR, | ||||||
|  |       // HttpError::NoUpgradeExtensionInRequest => StatusCode::BAD_REQUEST,
 | ||||||
|  |       // HttpError::NoUpgradeExtensionInResponse => StatusCode::BAD_GATEWAY,
 | ||||||
|  |       _ => StatusCode::INTERNAL_SERVER_ERROR, | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
							
								
								
									
										11
									
								
								rpxy-lib/src/message_handler/mod.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										11
									
								
								rpxy-lib/src/message_handler/mod.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,11 @@ | ||||||
|  | mod canonical_address; | ||||||
|  | mod handler_main; | ||||||
|  | mod handler_manipulate_messages; | ||||||
|  | mod http_log; | ||||||
|  | mod http_result; | ||||||
|  | mod synthetic_response; | ||||||
|  | mod utils_headers; | ||||||
|  | mod utils_request; | ||||||
|  | 
 | ||||||
|  | pub use handler_main::HttpMessageHandlerBuilderError; | ||||||
|  | pub(crate) use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder}; | ||||||
							
								
								
									
										42
									
								
								rpxy-lib/src/message_handler/synthetic_response.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								rpxy-lib/src/message_handler/synthetic_response.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,42 @@ | ||||||
|  | use super::http_result::{HttpError, HttpResult}; | ||||||
|  | use crate::{ | ||||||
|  |   error::*, | ||||||
|  |   hyper_ext::body::{empty, ResponseBody}, | ||||||
|  |   name_exp::ServerName, | ||||||
|  | }; | ||||||
|  | use http::{Request, Response, StatusCode, Uri}; | ||||||
|  | 
 | ||||||
|  | /// build http response with status code of 4xx and 5xx
 | ||||||
|  | pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult<Response<ResponseBody>> { | ||||||
|  |   let res = Response::builder() | ||||||
|  |     .status(status_code) | ||||||
|  |     .body(ResponseBody::Boxed(empty())) | ||||||
|  |     .unwrap(); | ||||||
|  |   Ok(res) | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// Generate synthetic response message of a redirection to https host with 301
 | ||||||
|  | pub(super) fn secure_redirection_response<B>( | ||||||
|  |   server_name: &ServerName, | ||||||
|  |   tls_port: Option<u16>, | ||||||
|  |   req: &Request<B>, | ||||||
|  | ) -> HttpResult<Response<ResponseBody>> { | ||||||
|  |   let server_name: String = server_name.try_into().unwrap_or_default(); | ||||||
|  |   let pq = match req.uri().path_and_query() { | ||||||
|  |     Some(x) => x.as_str(), | ||||||
|  |     _ => "", | ||||||
|  |   }; | ||||||
|  |   let new_uri = Uri::builder().scheme("https").path_and_query(pq); | ||||||
|  |   let dest_uri = match tls_port { | ||||||
|  |     Some(443) | None => new_uri.authority(server_name), | ||||||
|  |     Some(p) => new_uri.authority(format!("{server_name}:{p}")), | ||||||
|  |   } | ||||||
|  |   .build() | ||||||
|  |   .map_err(|e| HttpError::FailedToRedirect(e.to_string()))?; | ||||||
|  |   let response = Response::builder() | ||||||
|  |     .status(StatusCode::MOVED_PERMANENTLY) | ||||||
|  |     .header("Location", dest_uri.to_string()) | ||||||
|  |     .body(ResponseBody::Boxed(empty())) | ||||||
|  |     .map_err(|e| HttpError::FailedToRedirect(e.to_string()))?; | ||||||
|  |   Ok(response) | ||||||
|  | } | ||||||
|  | @ -1,26 +1,27 @@ | ||||||
| #[cfg(feature = "sticky-cookie")] | use super::canonical_address::ToCanonical; | ||||||
| use crate::backend::{LbContext, StickyCookie, StickyCookieValue}; | use crate::{ | ||||||
| use crate::backend::{UpstreamGroup, UpstreamOption}; |   backend::{UpstreamCandidates, UpstreamOption}, | ||||||
| 
 |   log::*, | ||||||
| use crate::{error::*, log::*, utils::*}; |  | ||||||
| use bytes::BufMut; |  | ||||||
| use hyper::{ |  | ||||||
|   header::{self, HeaderMap, HeaderName, HeaderValue}, |  | ||||||
|   Uri, |  | ||||||
| }; | }; | ||||||
|  | use anyhow::{anyhow, ensure, Result}; | ||||||
|  | use bytes::BufMut; | ||||||
|  | use http::{header, HeaderMap, HeaderName, HeaderValue, Uri}; | ||||||
| use std::{borrow::Cow, net::SocketAddr}; | use std::{borrow::Cow, net::SocketAddr}; | ||||||
| 
 | 
 | ||||||
| ////////////////////////////////////////////////////
 | #[cfg(feature = "sticky-cookie")] | ||||||
| // Functions to manipulate headers
 | use crate::backend::{LoadBalanceContext, StickyCookie, StickyCookieValue}; | ||||||
|  | // use crate::backend::{UpstreamGroup, UpstreamOption};
 | ||||||
| 
 | 
 | ||||||
|  | // ////////////////////////////////////////////////////
 | ||||||
|  | // // Functions to manipulate headers
 | ||||||
| #[cfg(feature = "sticky-cookie")] | #[cfg(feature = "sticky-cookie")] | ||||||
| /// Take sticky cookie header value from request header,
 | /// Take sticky cookie header value from request header,
 | ||||||
| /// and returns LbContext to be forwarded to LB if exist and if needed.
 | /// and returns LoadBalanceContext to be forwarded to LB if exist and if needed.
 | ||||||
| /// Removing sticky cookie is needed and it must not be passed to the upstream.
 | /// Removing sticky cookie is needed and it must not be passed to the upstream.
 | ||||||
| pub(super) fn takeout_sticky_cookie_lb_context( | pub(super) fn takeout_sticky_cookie_lb_context( | ||||||
|   headers: &mut HeaderMap, |   headers: &mut HeaderMap, | ||||||
|   expected_cookie_name: &str, |   expected_cookie_name: &str, | ||||||
| ) -> Result<Option<LbContext>> { | ) -> Result<Option<LoadBalanceContext>> { | ||||||
|   let mut headers_clone = headers.clone(); |   let mut headers_clone = headers.clone(); | ||||||
| 
 | 
 | ||||||
|   match headers_clone.entry(header::COOKIE) { |   match headers_clone.entry(header::COOKIE) { | ||||||
|  | @ -35,12 +36,11 @@ pub(super) fn takeout_sticky_cookie_lb_context( | ||||||
|       if sticky_cookies.is_empty() { |       if sticky_cookies.is_empty() { | ||||||
|         return Ok(None); |         return Ok(None); | ||||||
|       } |       } | ||||||
|       if sticky_cookies.len() > 1 { |       ensure!( | ||||||
|         error!("Multiple sticky cookie values in request"); |         sticky_cookies.len() == 1, | ||||||
|         return Err(RpxyError::Other(anyhow!( |  | ||||||
|         "Invalid cookie: Multiple sticky cookie values" |         "Invalid cookie: Multiple sticky cookie values" | ||||||
|         ))); |       ); | ||||||
|       } | 
 | ||||||
|       let cookies_passed_to_upstream = without_sticky_cookies.join("; "); |       let cookies_passed_to_upstream = without_sticky_cookies.join("; "); | ||||||
|       let cookie_passed_to_lb = sticky_cookies.first().unwrap(); |       let cookie_passed_to_lb = sticky_cookies.first().unwrap(); | ||||||
|       headers.remove(header::COOKIE); |       headers.remove(header::COOKIE); | ||||||
|  | @ -50,7 +50,7 @@ pub(super) fn takeout_sticky_cookie_lb_context( | ||||||
|         value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?, |         value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?, | ||||||
|         info: None, |         info: None, | ||||||
|       }; |       }; | ||||||
|       Ok(Some(LbContext { sticky_cookie })) |       Ok(Some(LoadBalanceContext { sticky_cookie })) | ||||||
|     } |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  | @ -59,7 +59,10 @@ pub(super) fn takeout_sticky_cookie_lb_context( | ||||||
| /// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated.
 | /// Set-Cookie if LB Sticky is enabled and if cookie is newly created/updated.
 | ||||||
| /// Set-Cookie response header could be in multiple lines.
 | /// Set-Cookie response header could be in multiple lines.
 | ||||||
| /// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie
 | /// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie
 | ||||||
| pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from_lb: &LbContext) -> Result<()> { | pub(super) fn set_sticky_cookie_lb_context( | ||||||
|  |   headers: &mut HeaderMap, | ||||||
|  |   context_from_lb: &LoadBalanceContext, | ||||||
|  | ) -> Result<()> { | ||||||
|   let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?; |   let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?; | ||||||
|   let new_header_val: HeaderValue = sticky_cookie_string.parse()?; |   let new_header_val: HeaderValue = sticky_cookie_string.parse()?; | ||||||
|   let expected_cookie_name = &context_from_lb.sticky_cookie.value.name; |   let expected_cookie_name = &context_from_lb.sticky_cookie.value.name; | ||||||
|  | @ -83,23 +86,37 @@ pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from | ||||||
|   Ok(()) |   Ok(()) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | /// overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy)
 | ||||||
|  | fn override_host_header(headers: &mut HeaderMap, upstream_base_uri: &Uri) -> Result<()> { | ||||||
|  |   let mut upstream_host = upstream_base_uri | ||||||
|  |     .host() | ||||||
|  |     .ok_or_else(|| anyhow!("No hostname is given"))? | ||||||
|  |     .to_string(); | ||||||
|  |   // add port if it is not default
 | ||||||
|  |   if let Some(port) = upstream_base_uri.port_u16() { | ||||||
|  |     upstream_host = format!("{}:{}", upstream_host, port); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   // overwrite host header, this removes all the HOST header values
 | ||||||
|  |   headers.insert(header::HOST, HeaderValue::from_str(&upstream_host)?); | ||||||
|  |   Ok(()) | ||||||
|  | } | ||||||
|  | 
 | ||||||
| /// Apply options to request header, which are specified in the configuration
 | /// Apply options to request header, which are specified in the configuration
 | ||||||
| pub(super) fn apply_upstream_options_to_header( | pub(super) fn apply_upstream_options_to_header( | ||||||
|   headers: &mut HeaderMap, |   headers: &mut HeaderMap, | ||||||
|   _client_addr: &SocketAddr, |  | ||||||
|   upstream: &UpstreamGroup, |  | ||||||
|   upstream_base_uri: &Uri, |   upstream_base_uri: &Uri, | ||||||
|  |   // _client_addr: &SocketAddr,
 | ||||||
|  |   upstream: &UpstreamCandidates, | ||||||
| ) -> Result<()> { | ) -> Result<()> { | ||||||
|   for opt in upstream.opts.iter() { |   for opt in upstream.options.iter() { | ||||||
|     match opt { |     match opt { | ||||||
|       UpstreamOption::OverrideHost => { |       UpstreamOption::SetUpstreamHost => { | ||||||
|         // overwrite HOST value with upstream hostname (like 192.168.xx.x seen from rpxy)
 |         // prioritize KeepOriginalHost
 | ||||||
|         let upstream_host = upstream_base_uri |         if !upstream.options.contains(&UpstreamOption::KeepOriginalHost) { | ||||||
|           .host() |           // overwrite host header, this removes all the HOST header values
 | ||||||
|           .ok_or_else(|| anyhow!("No hostname is given in override_host option"))?; |           override_host_header(headers, upstream_base_uri)?; | ||||||
|         headers |         } | ||||||
|           .insert(header::HOST, HeaderValue::from_str(upstream_host)?) |  | ||||||
|           .ok_or_else(|| anyhow!("Failed to insert host header in override_host option"))?; |  | ||||||
|       } |       } | ||||||
|       UpstreamOption::UpgradeInsecureRequests => { |       UpstreamOption::UpgradeInsecureRequests => { | ||||||
|         // add upgrade-insecure-requests in request header if not exist
 |         // add upgrade-insecure-requests in request header if not exist
 | ||||||
							
								
								
									
										86
									
								
								rpxy-lib/src/message_handler/utils_request.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								rpxy-lib/src/message_handler/utils_request.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,86 @@ | ||||||
|  | use crate::{ | ||||||
|  |   backend::{Upstream, UpstreamCandidates, UpstreamOption}, | ||||||
|  |   log::*, | ||||||
|  | }; | ||||||
|  | use anyhow::{anyhow, ensure, Result}; | ||||||
|  | use http::{header, uri::Scheme, Request, Version}; | ||||||
|  | 
 | ||||||
|  | /// Trait defining parser of hostname
 | ||||||
|  | /// Inspect and extract hostname from either the request HOST header or request line
 | ||||||
|  | pub trait InspectParseHost { | ||||||
|  |   type Error; | ||||||
|  |   fn inspect_parse_host(&self) -> Result<Vec<u8>, Self::Error>; | ||||||
|  | } | ||||||
|  | impl<B> InspectParseHost for Request<B> { | ||||||
|  |   type Error = anyhow::Error; | ||||||
|  |   /// Inspect and extract hostname from either the request HOST header or request line
 | ||||||
|  |   fn inspect_parse_host(&self) -> Result<Vec<u8>> { | ||||||
|  |     let drop_port = |v: &[u8]| { | ||||||
|  |       if v.starts_with(&[b'[']) { | ||||||
|  |         // v6 address with bracket case. if port is specified, always it is in this case.
 | ||||||
|  |         let mut iter = v.split(|ptr| ptr == &b'[' || ptr == &b']'); | ||||||
|  |         iter.next().ok_or(anyhow!("Invalid Host header"))?; // first item is always blank
 | ||||||
|  |         iter.next().ok_or(anyhow!("Invalid Host header")).map(|b| b.to_owned()) | ||||||
|  |       } else if v.len() - v.split(|v| v == &b':').fold(0, |acc, s| acc + s.len()) >= 2 { | ||||||
|  |         // v6 address case, if 2 or more ':' is contained
 | ||||||
|  |         Ok(v.to_owned()) | ||||||
|  |       } else { | ||||||
|  |         // v4 address or hostname
 | ||||||
|  |         v.split(|colon| colon == &b':') | ||||||
|  |           .next() | ||||||
|  |           .ok_or(anyhow!("Invalid Host header")) | ||||||
|  |           .map(|v| v.to_ascii_lowercase()) | ||||||
|  |       } | ||||||
|  |     }; | ||||||
|  | 
 | ||||||
|  |     let headers_host = self.headers().get(header::HOST).map(|v| drop_port(v.as_bytes())); | ||||||
|  |     let uri_host = self.uri().host().map(|v| drop_port(v.as_bytes())); | ||||||
|  |     // let uri_port = self.uri().port_u16();
 | ||||||
|  | 
 | ||||||
|  |     // prioritize server_name in uri
 | ||||||
|  |     match (headers_host, uri_host) { | ||||||
|  |       (Some(Ok(hh)), Some(Ok(hu))) => { | ||||||
|  |         ensure!(hh == hu, "Host header and uri host mismatch"); | ||||||
|  |         Ok(hh) | ||||||
|  |       } | ||||||
|  |       (Some(Ok(hh)), None) => Ok(hh), | ||||||
|  |       (None, Some(Ok(hu))) => Ok(hu), | ||||||
|  |       _ => Err(anyhow!("Neither Host header nor uri host is valid")), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | ////////////////////////////////////////////////////
 | ||||||
|  | // Functions to manipulate request line
 | ||||||
|  | 
 | ||||||
|  | /// Update request line, e.g., version, and apply upstream options to request line, specified in the configuration
 | ||||||
|  | pub(super) fn update_request_line<B>( | ||||||
|  |   req: &mut Request<B>, | ||||||
|  |   upstream_chosen: &Upstream, | ||||||
|  |   upstream_candidates: &UpstreamCandidates, | ||||||
|  | ) -> anyhow::Result<()> { | ||||||
|  |   // If not specified (force_httpXX_upstream) and https, version is preserved except for http/3
 | ||||||
|  |   if upstream_chosen.uri.scheme() == Some(&Scheme::HTTP) { | ||||||
|  |     // Change version to http/1.1 when destination scheme is http
 | ||||||
|  |     debug!("Change version to http/1.1 when destination scheme is http unless upstream option enabled."); | ||||||
|  |     *req.version_mut() = Version::HTTP_11; | ||||||
|  |   } else if req.version() == Version::HTTP_3 { | ||||||
|  |     // HTTP/3 is always https
 | ||||||
|  |     debug!("HTTP/3 is currently unsupported for request to upstream."); | ||||||
|  |     *req.version_mut() = Version::HTTP_2; | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   for opt in upstream_candidates.options.iter() { | ||||||
|  |     match opt { | ||||||
|  |       UpstreamOption::ForceHttp11Upstream => *req.version_mut() = Version::HTTP_11, | ||||||
|  |       UpstreamOption::ForceHttp2Upstream => { | ||||||
|  |         // case: h2c -> https://www.rfc-editor.org/rfc/rfc9113.txt
 | ||||||
|  |         // Upgrade from HTTP/1.1 to HTTP/2 is deprecated. So, http-2 prior knowledge is required.
 | ||||||
|  |         *req.version_mut() = Version::HTTP_2; | ||||||
|  |       } | ||||||
|  |       _ => (), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   Ok(()) | ||||||
|  | } | ||||||
							
								
								
									
										160
									
								
								rpxy-lib/src/name_exp.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										160
									
								
								rpxy-lib/src/name_exp.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,160 @@ | ||||||
|  | use std::borrow::Cow; | ||||||
|  | 
 | ||||||
|  | /// Server name (hostname or ip address) representation in bytes-based struct
 | ||||||
|  | /// for searching hashmap or key list by exact or longest-prefix matching
 | ||||||
|  | #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] | ||||||
|  | pub struct ServerName { | ||||||
|  |   inner: Vec<u8>, // lowercase ascii bytes
 | ||||||
|  | } | ||||||
|  | impl From<&str> for ServerName { | ||||||
|  |   fn from(s: &str) -> Self { | ||||||
|  |     let name = s.bytes().collect::<Vec<u8>>().to_ascii_lowercase(); | ||||||
|  |     Self { inner: name } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | impl From<&[u8]> for ServerName { | ||||||
|  |   fn from(b: &[u8]) -> Self { | ||||||
|  |     Self { | ||||||
|  |       inner: b.to_ascii_lowercase(), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | impl TryInto<String> for &ServerName { | ||||||
|  |   type Error = anyhow::Error; | ||||||
|  |   fn try_into(self) -> Result<String, Self::Error> { | ||||||
|  |     let s = std::str::from_utf8(&self.inner)?; | ||||||
|  |     Ok(s.to_string()) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | impl AsRef<[u8]> for ServerName { | ||||||
|  |   fn as_ref(&self) -> &[u8] { | ||||||
|  |     self.inner.as_ref() | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// Path name, like "/path/ok", represented in bytes-based struct
 | ||||||
|  | /// for searching hashmap or key list by exact or longest-prefix matching
 | ||||||
|  | #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] | ||||||
|  | pub struct PathName { | ||||||
|  |   inner: Vec<u8>, // lowercase ascii bytes
 | ||||||
|  | } | ||||||
|  | impl From<&str> for PathName { | ||||||
|  |   fn from(s: &str) -> Self { | ||||||
|  |     let name = s.bytes().collect::<Vec<u8>>().to_ascii_lowercase(); | ||||||
|  |     Self { inner: name } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | impl From<&[u8]> for PathName { | ||||||
|  |   fn from(b: &[u8]) -> Self { | ||||||
|  |     Self { | ||||||
|  |       inner: b.to_ascii_lowercase(), | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | impl TryInto<String> for &PathName { | ||||||
|  |   type Error = anyhow::Error; | ||||||
|  |   fn try_into(self) -> Result<String, Self::Error> { | ||||||
|  |     let s = std::str::from_utf8(&self.inner)?; | ||||||
|  |     Ok(s.to_string()) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | impl AsRef<[u8]> for PathName { | ||||||
|  |   fn as_ref(&self) -> &[u8] { | ||||||
|  |     self.inner.as_ref() | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | impl PathName { | ||||||
|  |   pub fn len(&self) -> usize { | ||||||
|  |     self.inner.len() | ||||||
|  |   } | ||||||
|  |   pub fn is_empty(&self) -> bool { | ||||||
|  |     self.inner.len() == 0 | ||||||
|  |   } | ||||||
|  |   pub fn get<I>(&self, index: I) -> Option<&I::Output> | ||||||
|  |   where | ||||||
|  |     I: std::slice::SliceIndex<[u8]>, | ||||||
|  |   { | ||||||
|  |     self.inner.get(index) | ||||||
|  |   } | ||||||
|  |   pub fn starts_with(&self, needle: &Self) -> bool { | ||||||
|  |     self.inner.starts_with(&needle.inner) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | /// Trait to express names in ascii-lowercased bytes
 | ||||||
|  | pub trait ByteName { | ||||||
|  |   type OutputServer: Send + Sync + 'static; | ||||||
|  |   type OutputPath; | ||||||
|  |   fn to_server_name(self) -> Self::OutputServer; | ||||||
|  |   fn to_path_name(self) -> Self::OutputPath; | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<'a, T: Into<Cow<'a, str>>> ByteName for T { | ||||||
|  |   type OutputServer = ServerName; | ||||||
|  |   type OutputPath = PathName; | ||||||
|  | 
 | ||||||
|  |   fn to_server_name(self) -> Self::OutputServer { | ||||||
|  |     ServerName::from(self.into().as_ref()) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   fn to_path_name(self) -> Self::OutputPath { | ||||||
|  |     PathName::from(self.into().as_ref()) | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[cfg(test)] | ||||||
|  | mod tests { | ||||||
|  |   use super::*; | ||||||
|  |   #[test] | ||||||
|  |   fn bytes_name_str_works() { | ||||||
|  |     let s = "OK_string"; | ||||||
|  |     let bn = s.to_path_name(); | ||||||
|  |     let bn_lc = s.to_server_name(); | ||||||
|  | 
 | ||||||
|  |     assert_eq!("ok_string".as_bytes(), bn.as_ref()); | ||||||
|  |     assert_eq!("ok_string".as_bytes(), bn_lc.as_ref()); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[test] | ||||||
|  |   fn from_works() { | ||||||
|  |     let s = "OK_string".to_server_name(); | ||||||
|  |     let m = ServerName::from("OK_strinG".as_bytes()); | ||||||
|  |     assert_eq!(s, m); | ||||||
|  |     assert_eq!(s.as_ref(), "ok_string".as_bytes()); | ||||||
|  |     assert_eq!(m.as_ref(), "ok_string".as_bytes()); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[test] | ||||||
|  |   fn get_works() { | ||||||
|  |     let s = "OK_str".to_path_name(); | ||||||
|  |     let i = s.get(0); | ||||||
|  |     assert_eq!(Some(&"o".as_bytes()[0]), i); | ||||||
|  |     let i = s.get(1); | ||||||
|  |     assert_eq!(Some(&"k".as_bytes()[0]), i); | ||||||
|  |     let i = s.get(2); | ||||||
|  |     assert_eq!(Some(&"_".as_bytes()[0]), i); | ||||||
|  |     let i = s.get(3); | ||||||
|  |     assert_eq!(Some(&"s".as_bytes()[0]), i); | ||||||
|  |     let i = s.get(4); | ||||||
|  |     assert_eq!(Some(&"t".as_bytes()[0]), i); | ||||||
|  |     let i = s.get(5); | ||||||
|  |     assert_eq!(Some(&"r".as_bytes()[0]), i); | ||||||
|  |     let i = s.get(6); | ||||||
|  |     assert_eq!(None, i); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[test] | ||||||
|  |   fn start_with_works() { | ||||||
|  |     let s = "OK_str".to_path_name(); | ||||||
|  |     let correct = "OK".to_path_name(); | ||||||
|  |     let incorrect = "KO".to_path_name(); | ||||||
|  |     assert!(s.starts_with(&correct)); | ||||||
|  |     assert!(!s.starts_with(&incorrect)); | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   #[test] | ||||||
|  |   fn as_ref_works() { | ||||||
|  |     let s = "OK_str".to_path_name(); | ||||||
|  |     assert_eq!(s.as_ref(), "ok_str".as_bytes()); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | @ -1,13 +1,42 @@ | ||||||
| mod crypto_service; |  | ||||||
| mod proxy_client_cert; |  | ||||||
| #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] |  | ||||||
| mod proxy_h3; |  | ||||||
| mod proxy_main; | mod proxy_main; | ||||||
| #[cfg(feature = "http3-quinn")] |  | ||||||
| mod proxy_quic_quinn; |  | ||||||
| #[cfg(feature = "http3-s2n")] |  | ||||||
| mod proxy_quic_s2n; |  | ||||||
| mod proxy_tls; |  | ||||||
| mod socket; | mod socket; | ||||||
| 
 | 
 | ||||||
| pub use proxy_main::{Proxy, ProxyBuilder, ProxyBuilderError}; | #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||||
|  | mod proxy_h3; | ||||||
|  | #[cfg(feature = "http3-quinn")] | ||||||
|  | mod proxy_quic_quinn; | ||||||
|  | #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||||
|  | mod proxy_quic_s2n; | ||||||
|  | 
 | ||||||
|  | use crate::{ | ||||||
|  |   globals::Globals, | ||||||
|  |   hyper_ext::rt::{LocalExecutor, TokioTimer}, | ||||||
|  | }; | ||||||
|  | use hyper_util::server::{self, conn::auto::Builder as ConnectionBuilder}; | ||||||
|  | use std::sync::Arc; | ||||||
|  | 
 | ||||||
|  | pub(crate) use proxy_main::Proxy; | ||||||
|  | 
 | ||||||
|  | /// build connection builder shared with proxy instances
 | ||||||
|  | pub(crate) fn connection_builder(globals: &Arc<Globals>) -> Arc<ConnectionBuilder<LocalExecutor>> { | ||||||
|  |   let executor = LocalExecutor::new(globals.runtime_handle.clone()); | ||||||
|  |   let mut http_server = server::conn::auto::Builder::new(executor); | ||||||
|  |   http_server | ||||||
|  |     .http1() | ||||||
|  |     .keep_alive(globals.proxy_config.keepalive) | ||||||
|  |     .header_read_timeout(globals.proxy_config.proxy_idle_timeout) | ||||||
|  |     .timer(TokioTimer) | ||||||
|  |     .pipeline_flush(true); | ||||||
|  |   http_server | ||||||
|  |     .http2() | ||||||
|  |     .max_concurrent_streams(globals.proxy_config.max_concurrent_streams); | ||||||
|  | 
 | ||||||
|  |   if globals.proxy_config.keepalive { | ||||||
|  |     http_server | ||||||
|  |       .http2() | ||||||
|  |       .keep_alive_interval(Some(globals.proxy_config.proxy_idle_timeout)) | ||||||
|  |       .keep_alive_timeout(globals.proxy_config.proxy_idle_timeout + std::time::Duration::from_secs(1)) | ||||||
|  |       .timer(TokioTimer); | ||||||
|  |   } | ||||||
|  |   Arc::new(http_server) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | @ -1,47 +0,0 @@ | ||||||
| use crate::{error::*, log::*}; |  | ||||||
| use rustc_hash::FxHashSet as HashSet; |  | ||||||
| use rustls::Certificate; |  | ||||||
| use x509_parser::extensions::ParsedExtension; |  | ||||||
| use x509_parser::prelude::*; |  | ||||||
| 
 |  | ||||||
| #[allow(dead_code)] |  | ||||||
| // TODO: consider move this function to the layer of handle_request (L7) to return 403
 |  | ||||||
| pub(super) fn check_client_authentication( |  | ||||||
|   client_certs: Option<&[Certificate]>, |  | ||||||
|   client_ca_keyids_set_for_sni: Option<&HashSet<Vec<u8>>>, |  | ||||||
| ) -> std::result::Result<(), ClientCertsError> { |  | ||||||
|   let Some(client_ca_keyids_set) = client_ca_keyids_set_for_sni else { |  | ||||||
|     // No client cert settings for given server name
 |  | ||||||
|     return Ok(()); |  | ||||||
|   }; |  | ||||||
| 
 |  | ||||||
|   let Some(client_certs) = client_certs else { |  | ||||||
|     error!("Client certificate is needed for given server name"); |  | ||||||
|     return Err(ClientCertsError::ClientCertRequired( |  | ||||||
|       "Client certificate is needed for given server name".to_string(), |  | ||||||
|     )); |  | ||||||
|   }; |  | ||||||
|   debug!("Incoming TLS client is (temporarily) authenticated via client cert"); |  | ||||||
| 
 |  | ||||||
|   // Check client certificate key ids
 |  | ||||||
|   let mut client_certs_parsed_iter = client_certs.iter().filter_map(|d| parse_x509_certificate(&d.0).ok()); |  | ||||||
|   let match_server_crypto_and_client_cert = client_certs_parsed_iter.any(|c| { |  | ||||||
|     let mut filtered = c.1.iter_extensions().filter_map(|e| { |  | ||||||
|       if let ParsedExtension::AuthorityKeyIdentifier(key_id) = e.parsed_extension() { |  | ||||||
|         key_id.key_identifier.as_ref() |  | ||||||
|       } else { |  | ||||||
|         None |  | ||||||
|       } |  | ||||||
|     }); |  | ||||||
|     filtered.any(|id| client_ca_keyids_set.contains(id.0)) |  | ||||||
|   }); |  | ||||||
| 
 |  | ||||||
|   if !match_server_crypto_and_client_cert { |  | ||||||
|     error!("Inconsistent client certificate was provided for SNI"); |  | ||||||
|     return Err(ClientCertsError::InconsistentClientCert( |  | ||||||
|       "Inconsistent client certificate was provided for SNI".to_string(), |  | ||||||
|     )); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   Ok(()) |  | ||||||
| } |  | ||||||
|  | @ -1,25 +1,33 @@ | ||||||
| use super::Proxy; | use super::proxy_main::Proxy; | ||||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::ServerNameBytesExp}; | use crate::{ | ||||||
|  |   crypto::CryptoSource, | ||||||
|  |   error::*, | ||||||
|  |   hyper_ext::body::{IncomingLike, RequestBody}, | ||||||
|  |   log::*, | ||||||
|  |   name_exp::ServerName, | ||||||
|  | }; | ||||||
| use bytes::{Buf, Bytes}; | use bytes::{Buf, Bytes}; | ||||||
|  | use http::{Request, Response}; | ||||||
|  | use http_body_util::BodyExt; | ||||||
|  | use hyper_util::client::legacy::connect::Connect; | ||||||
|  | use std::net::SocketAddr; | ||||||
|  | 
 | ||||||
| #[cfg(feature = "http3-quinn")] | #[cfg(feature = "http3-quinn")] | ||||||
| use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; | use h3::{quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; | ||||||
| use hyper::{client::connect::Connect, Body, Request, Response}; | #[cfg(all(feature = "http3-s2n", not(feature = "http3-quinn")))] | ||||||
| #[cfg(feature = "http3-s2n")] |  | ||||||
| use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; | use s2n_quic_h3::h3::{self, quic::BidiStream, quic::Connection as ConnectionQuic, server::RequestStream}; | ||||||
| use std::net::SocketAddr; |  | ||||||
| use tokio::time::{timeout, Duration}; |  | ||||||
| 
 | 
 | ||||||
| impl<T, U> Proxy<T, U> | impl<U, T> Proxy<U, T> | ||||||
| where | where | ||||||
|   T: Connect + Clone + Sync + Send + 'static, |   T: Connect + Clone + Sync + Send + 'static, | ||||||
|   U: CryptoSource + Clone + Sync + Send + 'static, |   U: CryptoSource + Clone + Sync + Send + 'static, | ||||||
| { | { | ||||||
|   pub(super) async fn connection_serve_h3<C>( |   pub(super) async fn h3_serve_connection<C>( | ||||||
|     &self, |     &self, | ||||||
|     quic_connection: C, |     quic_connection: C, | ||||||
|     tls_server_name: ServerNameBytesExp, |     tls_server_name: ServerName, | ||||||
|     client_addr: SocketAddr, |     client_addr: SocketAddr, | ||||||
|   ) -> Result<()> |   ) -> RpxyResult<()> | ||||||
|   where |   where | ||||||
|     C: ConnectionQuic<Bytes>, |     C: ConnectionQuic<Bytes>, | ||||||
|     <C as ConnectionQuic<Bytes>>::BidiStream: BidiStream<Bytes> + Send + 'static, |     <C as ConnectionQuic<Bytes>>::BidiStream: BidiStream<Bytes> + Send + 'static, | ||||||
|  | @ -28,9 +36,11 @@ where | ||||||
|   { |   { | ||||||
|     let mut h3_conn = h3::server::Connection::<_, Bytes>::new(quic_connection).await?; |     let mut h3_conn = h3::server::Connection::<_, Bytes>::new(quic_connection).await?; | ||||||
|     info!( |     info!( | ||||||
|       "QUIC/HTTP3 connection established from {:?} {:?}", |       "QUIC/HTTP3 connection established from {:?} {}", | ||||||
|       client_addr, tls_server_name |       client_addr, | ||||||
|  |       <&ServerName as TryInto<String>>::try_into(&tls_server_name).unwrap_or_default() | ||||||
|     ); |     ); | ||||||
|  | 
 | ||||||
|     // TODO: Is here enough to fetch server_name from NewConnection?
 |     // TODO: Is here enough to fetch server_name from NewConnection?
 | ||||||
|     // to avoid deep nested call from listener_service_h3
 |     // to avoid deep nested call from listener_service_h3
 | ||||||
|     loop { |     loop { | ||||||
|  | @ -60,13 +70,13 @@ where | ||||||
|           let self_inner = self.clone(); |           let self_inner = self.clone(); | ||||||
|           let tls_server_name_inner = tls_server_name.clone(); |           let tls_server_name_inner = tls_server_name.clone(); | ||||||
|           self.globals.runtime_handle.spawn(async move { |           self.globals.runtime_handle.spawn(async move { | ||||||
|             if let Err(e) = timeout( |             let fut = self_inner.h3_serve_stream(req, stream, client_addr, tls_server_name_inner); | ||||||
|               self_inner.globals.proxy_config.proxy_timeout + Duration::from_secs(1), // timeout per stream are considered as same as one in http2
 |             if let Some(connection_handling_timeout) = self_inner.globals.proxy_config.connection_handling_timeout { | ||||||
|               self_inner.stream_serve_h3(req, stream, client_addr, tls_server_name_inner), |               if let Err(e) = tokio::time::timeout(connection_handling_timeout, fut).await { | ||||||
|             ) |                 warn!("HTTP/3 error on serve stream: {}", e); | ||||||
|             .await |               }; | ||||||
|             { |             } else if let Err(e) = fut.await { | ||||||
|               error!("HTTP/3 failed to process stream: {}", e); |               warn!("HTTP/3 error on serve stream: {}", e); | ||||||
|             } |             } | ||||||
|             request_count.decrement(); |             request_count.decrement(); | ||||||
|             debug!("Request processed: current # {}", request_count.current()); |             debug!("Request processed: current # {}", request_count.current()); | ||||||
|  | @ -78,13 +88,17 @@ where | ||||||
|     Ok(()) |     Ok(()) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   async fn stream_serve_h3<S>( |   /// Serves a request stream from a client
 | ||||||
|  |   /// Body in hyper-0.14 was changed to Incoming in hyper-1.0, and it is not accessible from outside.
 | ||||||
|  |   /// Thus, we needed to implement IncomingLike trait using channel. Also, the backend handler must feed the body in the form of
 | ||||||
|  |   /// Either<Incoming, IncomingLike> as body.
 | ||||||
|  |   async fn h3_serve_stream<S>( | ||||||
|     &self, |     &self, | ||||||
|     req: Request<()>, |     req: Request<()>, | ||||||
|     stream: RequestStream<S, Bytes>, |     stream: RequestStream<S, Bytes>, | ||||||
|     client_addr: SocketAddr, |     client_addr: SocketAddr, | ||||||
|     tls_server_name: ServerNameBytesExp, |     tls_server_name: ServerName, | ||||||
|   ) -> Result<()> |   ) -> RpxyResult<()> | ||||||
|   where |   where | ||||||
|     S: BidiStream<Bytes> + Send + 'static, |     S: BidiStream<Bytes> + Send + 'static, | ||||||
|     <S as BidiStream<Bytes>>::RecvStream: Send, |     <S as BidiStream<Bytes>>::RecvStream: Send, | ||||||
|  | @ -94,7 +108,7 @@ where | ||||||
|     let (mut send_stream, mut recv_stream) = stream.split(); |     let (mut send_stream, mut recv_stream) = stream.split(); | ||||||
| 
 | 
 | ||||||
|     // generate streamed body with trailers using channel
 |     // generate streamed body with trailers using channel
 | ||||||
|     let (body_sender, req_body) = Body::channel(); |     let (body_sender, req_body) = IncomingLike::channel(); | ||||||
| 
 | 
 | ||||||
|     // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1
 |     // Buffering and sending body through channel for protocol conversion like h3 -> h2/http1.1
 | ||||||
|     // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn.
 |     // The underling buffering, i.e., buffer given by the API recv_data.await?, is handled by quinn.
 | ||||||
|  | @ -107,10 +121,10 @@ where | ||||||
|         size += body.remaining(); |         size += body.remaining(); | ||||||
|         if size > max_body_size { |         if size > max_body_size { | ||||||
|           error!( |           error!( | ||||||
|             "Exceeds max request body size for HTTP/3: received {}, maximum_allowd {}", |             "Exceeds max request body size for HTTP/3: received {}, maximum_allowed {}", | ||||||
|             size, max_body_size |             size, max_body_size | ||||||
|           ); |           ); | ||||||
|           return Err(RpxyError::Proxy("Exceeds max request body size for HTTP/3".to_string())); |           return Err(RpxyError::H3TooLargeBody); | ||||||
|         } |         } | ||||||
|         // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes
 |         // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes
 | ||||||
|         sender.send_data(body.copy_to_bytes(body.remaining())).await?; |         sender.send_data(body.copy_to_bytes(body.remaining())).await?; | ||||||
|  | @ -122,13 +136,12 @@ where | ||||||
|         debug!("HTTP/3 incoming request trailers"); |         debug!("HTTP/3 incoming request trailers"); | ||||||
|         sender.send_trailers(trailers.unwrap()).await?; |         sender.send_trailers(trailers.unwrap()).await?; | ||||||
|       } |       } | ||||||
|       Ok(()) |       Ok(()) as RpxyResult<()> | ||||||
|     }); |     }); | ||||||
| 
 | 
 | ||||||
|     let new_req: Request<Body> = Request::from_parts(req_parts, req_body); |     let new_req: Request<RequestBody> = Request::from_parts(req_parts, RequestBody::IncomingLike(req_body)); | ||||||
|     let res = self |     let res = self | ||||||
|       .msg_handler |       .message_handler | ||||||
|       .clone() |  | ||||||
|       .handle_request( |       .handle_request( | ||||||
|         new_req, |         new_req, | ||||||
|         client_addr, |         client_addr, | ||||||
|  | @ -138,21 +151,33 @@ where | ||||||
|       ) |       ) | ||||||
|       .await?; |       .await?; | ||||||
| 
 | 
 | ||||||
|     let (new_res_parts, new_body) = res.into_parts(); |     let (new_res_parts, mut new_body) = res.into_parts(); | ||||||
|     let new_res = Response::from_parts(new_res_parts, ()); |     let new_res = Response::from_parts(new_res_parts, ()); | ||||||
| 
 | 
 | ||||||
|     match send_stream.send_response(new_res).await { |     match send_stream.send_response(new_res).await { | ||||||
|       Ok(_) => { |       Ok(_) => { | ||||||
|         debug!("HTTP/3 response to connection successful"); |         debug!("HTTP/3 response to connection successful"); | ||||||
|         // aggregate body without copying
 |         // on-demand body streaming to downstream without expanding the object onto memory.
 | ||||||
|         let mut body_data = hyper::body::aggregate(new_body).await?; |         loop { | ||||||
|  |           let frame = match new_body.frame().await { | ||||||
|  |             Some(frame) => frame, | ||||||
|  |             None => { | ||||||
|  |               debug!("Response body finished"); | ||||||
|  |               break; | ||||||
|  |             } | ||||||
|  |           } | ||||||
|  |           .map_err(|e| RpxyError::HyperBodyManipulationError(e.to_string()))?; | ||||||
| 
 | 
 | ||||||
|         // create stream body to save memory, shallow copy (increment of ref-count) to Bytes using copy_to_bytes
 |           if frame.is_data() { | ||||||
|         send_stream |             let data = frame.into_data().unwrap_or_default(); | ||||||
|           .send_data(body_data.copy_to_bytes(body_data.remaining())) |             // debug!("Write data to HTTP/3 stream");
 | ||||||
|           .await?; |             send_stream.send_data(data).await?; | ||||||
| 
 |           } else if frame.is_trailers() { | ||||||
|         // TODO: needs handling trailer? should be included in body from handler.
 |             let trailers = frame.into_trailers().unwrap_or_default(); | ||||||
|  |             // debug!("Write trailer to HTTP/3 stream");
 | ||||||
|  |             send_stream.send_trailers(trailers).await?; | ||||||
|  |           } | ||||||
|  |         } | ||||||
|       } |       } | ||||||
|       Err(err) => { |       Err(err) => { | ||||||
|         error!("Unable to send response to connection peer: {:?}", err); |         error!("Unable to send response to connection peer: {:?}", err); | ||||||
|  |  | ||||||
|  | @ -1,78 +1,81 @@ | ||||||
| use super::socket::bind_tcp_socket; | use super::socket::bind_tcp_socket; | ||||||
| use crate::{ | use crate::{ | ||||||
|   certs::CryptoSource, error::*, globals::Globals, handler::HttpMessageHandler, log::*, utils::ServerNameBytesExp, |   constants::TLS_HANDSHAKE_TIMEOUT_SEC, | ||||||
|  |   crypto::{CryptoSource, ServerCrypto, SniServerCryptoMap}, | ||||||
|  |   error::*, | ||||||
|  |   globals::Globals, | ||||||
|  |   hyper_ext::{ | ||||||
|  |     body::{RequestBody, ResponseBody}, | ||||||
|  |     rt::LocalExecutor, | ||||||
|  |   }, | ||||||
|  |   log::*, | ||||||
|  |   message_handler::HttpMessageHandler, | ||||||
|  |   name_exp::ServerName, | ||||||
| }; | }; | ||||||
| use derive_builder::{self, Builder}; | use futures::{select, FutureExt}; | ||||||
| use hyper::{client::connect::Connect, server::conn::Http, service::service_fn, Body, Request}; | use http::{Request, Response}; | ||||||
| use std::{net::SocketAddr, sync::Arc}; | use hyper::{ | ||||||
| use tokio::{ |   body::Incoming, | ||||||
|   io::{AsyncRead, AsyncWrite}, |   rt::{Read, Write}, | ||||||
|   runtime::Handle, |   service::service_fn, | ||||||
|   sync::Notify, |  | ||||||
|   time::{timeout, Duration}, |  | ||||||
| }; | }; | ||||||
|  | use hyper_util::{client::legacy::connect::Connect, rt::TokioIo, server::conn::auto::Builder as ConnectionBuilder}; | ||||||
|  | use std::{net::SocketAddr, sync::Arc, time::Duration}; | ||||||
|  | use tokio::time::timeout; | ||||||
| 
 | 
 | ||||||
| #[derive(Clone)] | /// Wrapper function to handle request for HTTP/1.1 and HTTP/2
 | ||||||
| pub struct LocalExecutor { | /// HTTP/3 is handled in proxy_h3.rs which directly calls the message handler
 | ||||||
|   runtime_handle: Handle, | async fn serve_request<U, T>( | ||||||
| } |   req: Request<Incoming>, | ||||||
| 
 |   handler: Arc<HttpMessageHandler<U, T>>, | ||||||
| impl LocalExecutor { |  | ||||||
|   fn new(runtime_handle: Handle) -> Self { |  | ||||||
|     LocalExecutor { runtime_handle } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<F> hyper::rt::Executor<F> for LocalExecutor |  | ||||||
| where |  | ||||||
|   F: std::future::Future + Send + 'static, |  | ||||||
|   F::Output: Send, |  | ||||||
| { |  | ||||||
|   fn execute(&self, fut: F) { |  | ||||||
|     self.runtime_handle.spawn(fut); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[derive(Clone, Builder)] |  | ||||||
| pub struct Proxy<T, U> |  | ||||||
| where |  | ||||||
|   T: Connect + Clone + Sync + Send + 'static, |  | ||||||
|   U: CryptoSource + Clone + Sync + Send + 'static, |  | ||||||
| { |  | ||||||
|   pub listening_on: SocketAddr, |  | ||||||
|   pub tls_enabled: bool, // TCP待受がTLSかどうか
 |  | ||||||
|   pub msg_handler: Arc<HttpMessageHandler<T, U>>, |  | ||||||
|   pub globals: Arc<Globals<U>>, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<T, U> Proxy<T, U> |  | ||||||
| where |  | ||||||
|   T: Connect + Clone + Sync + Send + 'static, |  | ||||||
|   U: CryptoSource + Clone + Sync + Send, |  | ||||||
| { |  | ||||||
|   /// Wrapper function to handle request
 |  | ||||||
|   async fn serve( |  | ||||||
|     handler: Arc<HttpMessageHandler<T, U>>, |  | ||||||
|     req: Request<Body>, |  | ||||||
|   client_addr: SocketAddr, |   client_addr: SocketAddr, | ||||||
|   listen_addr: SocketAddr, |   listen_addr: SocketAddr, | ||||||
|   tls_enabled: bool, |   tls_enabled: bool, | ||||||
|     tls_server_name: Option<ServerNameBytesExp>, |   tls_server_name: Option<ServerName>, | ||||||
|   ) -> Result<hyper::Response<Body>> { | ) -> RpxyResult<Response<ResponseBody>> | ||||||
|  | where | ||||||
|  |   T: Send + Sync + Connect + Clone, | ||||||
|  |   U: CryptoSource + Clone, | ||||||
|  | { | ||||||
|   handler |   handler | ||||||
|       .handle_request(req, client_addr, listen_addr, tls_enabled, tls_server_name) |     .handle_request( | ||||||
|  |       req.map(RequestBody::Incoming), | ||||||
|  |       client_addr, | ||||||
|  |       listen_addr, | ||||||
|  |       tls_enabled, | ||||||
|  |       tls_server_name, | ||||||
|  |     ) | ||||||
|     .await |     .await | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
|  | #[derive(Clone)] | ||||||
|  | /// Proxy main object responsible to serve requests received from clients at the given socket address.
 | ||||||
|  | pub(crate) struct Proxy<U, T, E = LocalExecutor> | ||||||
|  | where | ||||||
|  |   T: Send + Sync + Connect + Clone + 'static, | ||||||
|  |   U: CryptoSource + Clone + Sync + Send + 'static, | ||||||
|  | { | ||||||
|  |   /// global context shared among async tasks
 | ||||||
|  |   pub globals: Arc<Globals>, | ||||||
|  |   /// listen socket address
 | ||||||
|  |   pub listening_on: SocketAddr, | ||||||
|  |   /// whether TLS is enabled or not
 | ||||||
|  |   pub tls_enabled: bool, | ||||||
|  |   /// hyper connection builder serving http request
 | ||||||
|  |   pub connection_builder: Arc<ConnectionBuilder<E>>, | ||||||
|  |   /// message handler serving incoming http request
 | ||||||
|  |   pub message_handler: Arc<HttpMessageHandler<U, T>>, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<U, T> Proxy<U, T> | ||||||
|  | where | ||||||
|  |   T: Send + Sync + Connect + Clone + 'static, | ||||||
|  |   U: CryptoSource + Clone + Sync + Send + 'static, | ||||||
|  | { | ||||||
|   /// Serves requests from clients
 |   /// Serves requests from clients
 | ||||||
|   pub(super) fn client_serve<I>( |   fn serve_connection<I>(&self, stream: I, peer_addr: SocketAddr, tls_server_name: Option<ServerName>) | ||||||
|     self, |   where | ||||||
|     stream: I, |     I: Read + Write + Send + Unpin + 'static, | ||||||
|     server: Http<LocalExecutor>, |  | ||||||
|     peer_addr: SocketAddr, |  | ||||||
|     tls_server_name: Option<ServerNameBytesExp>, |  | ||||||
|   ) where |  | ||||||
|     I: AsyncRead + AsyncWrite + Send + Unpin + 'static, |  | ||||||
|   { |   { | ||||||
|     let request_count = self.globals.request_count.clone(); |     let request_count = self.globals.request_count.clone(); | ||||||
|     if request_count.increment() > self.globals.proxy_config.max_clients { |     if request_count.increment() > self.globals.proxy_config.max_clients { | ||||||
|  | @ -81,27 +84,32 @@ where | ||||||
|     } |     } | ||||||
|     debug!("Request incoming: current # {}", request_count.current()); |     debug!("Request incoming: current # {}", request_count.current()); | ||||||
| 
 | 
 | ||||||
|  |     let server_clone = self.connection_builder.clone(); | ||||||
|  |     let message_handler_clone = self.message_handler.clone(); | ||||||
|  |     let tls_enabled = self.tls_enabled; | ||||||
|  |     let listening_on = self.listening_on; | ||||||
|  |     let handling_timeout = self.globals.proxy_config.connection_handling_timeout; | ||||||
|  | 
 | ||||||
|     self.globals.runtime_handle.clone().spawn(async move { |     self.globals.runtime_handle.clone().spawn(async move { | ||||||
|       timeout( |       let fut = server_clone.serve_connection_with_upgrades( | ||||||
|         self.globals.proxy_config.proxy_timeout + Duration::from_secs(1), |  | ||||||
|         server |  | ||||||
|           .serve_connection( |  | ||||||
|         stream, |         stream, | ||||||
|             service_fn(move |req: Request<Body>| { |         service_fn(move |req: Request<Incoming>| { | ||||||
|               Self::serve( |           serve_request( | ||||||
|                 self.msg_handler.clone(), |  | ||||||
|             req, |             req, | ||||||
|  |             message_handler_clone.clone(), | ||||||
|             peer_addr, |             peer_addr, | ||||||
|                 self.listening_on, |             listening_on, | ||||||
|                 self.tls_enabled, |             tls_enabled, | ||||||
|             tls_server_name.clone(), |             tls_server_name.clone(), | ||||||
|           ) |           ) | ||||||
|         }), |         }), | ||||||
|           ) |       ); | ||||||
|           .with_upgrades(), | 
 | ||||||
|       ) |       if let Some(handling_timeout) = handling_timeout { | ||||||
|       .await |         timeout(handling_timeout, fut).await.ok(); | ||||||
|       .ok(); |       } else { | ||||||
|  |         fut.await.ok(); | ||||||
|  |       } | ||||||
| 
 | 
 | ||||||
|       request_count.decrement(); |       request_count.decrement(); | ||||||
|       debug!("Request processed: current # {}", request_count.current()); |       debug!("Request processed: current # {}", request_count.current()); | ||||||
|  | @ -109,47 +117,149 @@ where | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /// Start without TLS (HTTP cleartext)
 |   /// Start without TLS (HTTP cleartext)
 | ||||||
|   async fn start_without_tls(self, server: Http<LocalExecutor>) -> Result<()> { |   async fn start_without_tls(&self) -> RpxyResult<()> { | ||||||
|     let listener_service = async { |     let listener_service = async { | ||||||
|       let tcp_socket = bind_tcp_socket(&self.listening_on)?; |       let tcp_socket = bind_tcp_socket(&self.listening_on)?; | ||||||
|       let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; |       let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; | ||||||
|       info!("Start TCP proxy serving with HTTP request for configured host names"); |       info!("Start TCP proxy serving with HTTP request for configured host names"); | ||||||
|       while let Ok((stream, _client_addr)) = tcp_listener.accept().await { |       while let Ok((stream, client_addr)) = tcp_listener.accept().await { | ||||||
|         self.clone().client_serve(stream, server.clone(), _client_addr, None); |         self.serve_connection(TokioIo::new(stream), client_addr, None); | ||||||
|       } |       } | ||||||
|       Ok(()) as Result<()> |       Ok(()) as RpxyResult<()> | ||||||
|     }; |     }; | ||||||
|     listener_service.await?; |     listener_service.await?; | ||||||
|     Ok(()) |     Ok(()) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   /// Entrypoint for HTTP/1.1 and HTTP/2 servers
 |   /// Start with TLS (HTTPS)
 | ||||||
|   pub async fn start(self, term_notify: Option<Arc<Notify>>) -> Result<()> { |   pub(super) async fn start_with_tls(&self) -> RpxyResult<()> { | ||||||
|     let mut server = Http::new(); |     #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] | ||||||
|     server.http1_keep_alive(self.globals.proxy_config.keepalive); |     { | ||||||
|     server.http2_max_concurrent_streams(self.globals.proxy_config.max_concurrent_streams); |       self.tls_listener_service().await?; | ||||||
|     server.pipeline_flush(true); |       error!("TCP proxy service for TLS exited"); | ||||||
|     let executor = LocalExecutor::new(self.globals.runtime_handle.clone()); |       Ok(()) | ||||||
|     let server = server.with_executor(executor); |     } | ||||||
|  |     #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] | ||||||
|  |     { | ||||||
|  |       if self.globals.proxy_config.http3 { | ||||||
|  |         select! { | ||||||
|  |           _ = self.tls_listener_service().fuse() => { | ||||||
|  |             error!("TCP proxy service for TLS exited"); | ||||||
|  |           }, | ||||||
|  |           _ = self.h3_listener_service().fuse() => { | ||||||
|  |             error!("UDP proxy service for QUIC exited"); | ||||||
|  |           } | ||||||
|  |         }; | ||||||
|  |         Ok(()) | ||||||
|  |       } else { | ||||||
|  |         self.tls_listener_service().await?; | ||||||
|  |         error!("TCP proxy service for TLS exited"); | ||||||
|  |         Ok(()) | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
| 
 | 
 | ||||||
|     let listening_on = self.listening_on; |   // TCP Listener Service, i.e., http/2 and http/1.1
 | ||||||
|  |   async fn tls_listener_service(&self) -> RpxyResult<()> { | ||||||
|  |     let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { | ||||||
|  |       return Err(RpxyError::NoCertificateReloader); | ||||||
|  |     }; | ||||||
|  |     let tcp_socket = bind_tcp_socket(&self.listening_on)?; | ||||||
|  |     let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; | ||||||
|  |     info!("Start TCP proxy serving with HTTPS request for configured host names"); | ||||||
| 
 | 
 | ||||||
|  |     let mut server_crypto_map: Option<Arc<SniServerCryptoMap>> = None; | ||||||
|  |     loop { | ||||||
|  |       select! { | ||||||
|  |         tcp_cnx = tcp_listener.accept().fuse() => { | ||||||
|  |           if tcp_cnx.is_err() || server_crypto_map.is_none() { | ||||||
|  |             continue; | ||||||
|  |           } | ||||||
|  |           let (raw_stream, client_addr) = tcp_cnx.unwrap(); | ||||||
|  |           let sc_map_inner = server_crypto_map.clone(); | ||||||
|  |           let self_inner = self.clone(); | ||||||
|  | 
 | ||||||
|  |           // spawns async handshake to avoid blocking thread by sequential handshake.
 | ||||||
|  |           let handshake_fut = async move { | ||||||
|  |             let acceptor = tokio_rustls::LazyConfigAcceptor::new(tokio_rustls::rustls::server::Acceptor::default(), raw_stream).await; | ||||||
|  |             if let Err(e) = acceptor { | ||||||
|  |               return Err(RpxyError::FailedToTlsHandshake(e.to_string())); | ||||||
|  |             } | ||||||
|  |             let start = acceptor.unwrap(); | ||||||
|  |             let client_hello = start.client_hello(); | ||||||
|  |             let sni = client_hello.server_name(); | ||||||
|  |             debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", sni.unwrap_or("None")); | ||||||
|  |             let server_name = sni.map(ServerName::from); | ||||||
|  |             if server_name.is_none(){ | ||||||
|  |               return Err(RpxyError::NoServerNameInClientHello); | ||||||
|  |             } | ||||||
|  |             let server_crypto = sc_map_inner.as_ref().unwrap().get(server_name.as_ref().unwrap()); | ||||||
|  |             if server_crypto.is_none() { | ||||||
|  |               return Err(RpxyError::NoTlsServingApp(server_name.as_ref().unwrap().try_into().unwrap_or_default())); | ||||||
|  |             } | ||||||
|  |             let stream = match start.into_stream(server_crypto.unwrap().clone()).await { | ||||||
|  |               Ok(s) => TokioIo::new(s), | ||||||
|  |               Err(e) => { | ||||||
|  |                 return Err(RpxyError::FailedToTlsHandshake(e.to_string())); | ||||||
|  |               } | ||||||
|  |             }; | ||||||
|  |             Ok((stream, client_addr, server_name)) | ||||||
|  |           }; | ||||||
|  | 
 | ||||||
|  |           self.globals.runtime_handle.spawn( async move { | ||||||
|  |             // timeout is introduced to avoid get stuck here.
 | ||||||
|  |             let Ok(v) = timeout( | ||||||
|  |               Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), | ||||||
|  |               handshake_fut | ||||||
|  |             ).await else { | ||||||
|  |               error!("Timeout to handshake TLS"); | ||||||
|  |               return; | ||||||
|  |             }; | ||||||
|  |             match v { | ||||||
|  |               Ok((stream, client_addr, server_name)) => { | ||||||
|  |                 self_inner.serve_connection(stream, client_addr, server_name); | ||||||
|  |               } | ||||||
|  |               Err(e) => { | ||||||
|  |                 error!("{}", e); | ||||||
|  |               } | ||||||
|  |             } | ||||||
|  |           }); | ||||||
|  |         } | ||||||
|  |         _ = server_crypto_rx.changed().fuse() => { | ||||||
|  |           if server_crypto_rx.borrow().is_none() { | ||||||
|  |             error!("Reloader is broken"); | ||||||
|  |             break; | ||||||
|  |           } | ||||||
|  |           let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); | ||||||
|  |           let Some(server_crypto): Option<Arc<ServerCrypto>> = (&cert_keys_map).try_into().ok() else { | ||||||
|  |             error!("Failed to update server crypto"); | ||||||
|  |             break; | ||||||
|  |           }; | ||||||
|  |           server_crypto_map = Some(server_crypto.inner_local_map.clone()); | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     Ok(()) | ||||||
|  |   } | ||||||
|  | 
 | ||||||
|  |   /// Entrypoint for HTTP/1.1, 2 and 3 servers
 | ||||||
|  |   pub async fn start(&self) -> RpxyResult<()> { | ||||||
|     let proxy_service = async { |     let proxy_service = async { | ||||||
|       if self.tls_enabled { |       if self.tls_enabled { | ||||||
|         self.start_with_tls(server).await |         self.start_with_tls().await | ||||||
|       } else { |       } else { | ||||||
|         self.start_without_tls(server).await |         self.start_without_tls().await | ||||||
|       } |       } | ||||||
|     }; |     }; | ||||||
| 
 | 
 | ||||||
|     match term_notify { |     match &self.globals.term_notify { | ||||||
|       Some(term) => { |       Some(term) => { | ||||||
|         tokio::select! { |         select! { | ||||||
|           _ = proxy_service => { |           _ = proxy_service.fuse() => { | ||||||
|             warn!("Proxy service got down"); |             warn!("Proxy service got down"); | ||||||
|           } |           } | ||||||
|           _ = term.notified() => { |           _ = term.notified().fuse() => { | ||||||
|             info!("Proxy service listening on {} receives term signal", listening_on); |             info!("Proxy service listening on {} receives term signal", self.listening_on); | ||||||
|           } |           } | ||||||
|         } |         } | ||||||
|       } |       } | ||||||
|  | @ -159,8 +269,6 @@ where | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|     // proxy_service.await?;
 |  | ||||||
| 
 |  | ||||||
|     Ok(()) |     Ok(()) | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,30 +1,32 @@ | ||||||
|  | use super::proxy_main::Proxy; | ||||||
| use super::socket::bind_udp_socket; | use super::socket::bind_udp_socket; | ||||||
| use super::{ | use crate::{ | ||||||
|   crypto_service::{ServerCrypto, ServerCryptoBase}, |   crypto::{CryptoSource, ServerCrypto}, | ||||||
|   proxy_main::Proxy, |   error::*, | ||||||
|  |   log::*, | ||||||
|  |   name_exp::ByteName, | ||||||
| }; | }; | ||||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; | use hyper_util::client::legacy::connect::Connect; | ||||||
| use hot_reload::ReloaderReceiver; |  | ||||||
| use hyper::client::connect::Connect; |  | ||||||
| use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; | use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig}; | ||||||
| use rustls::ServerConfig; | use rustls::ServerConfig; | ||||||
| use std::sync::Arc; | use std::sync::Arc; | ||||||
| 
 | 
 | ||||||
| impl<T, U> Proxy<T, U> | impl<U, T> Proxy<U, T> | ||||||
| where | where | ||||||
|   T: Connect + Clone + Sync + Send + 'static, |   T: Send + Sync + Connect + Clone + 'static, | ||||||
|   U: CryptoSource + Clone + Sync + Send + 'static, |   U: CryptoSource + Clone + Sync + Send + 'static, | ||||||
| { | { | ||||||
|   pub(super) async fn listener_service_h3( |   pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { | ||||||
|     &self, |     let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { | ||||||
|     mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, |       return Err(RpxyError::NoCertificateReloader); | ||||||
|   ) -> Result<()> { |     }; | ||||||
|     info!("Start UDP proxy serving with HTTP/3 request for configured host names [quinn]"); |     info!("Start UDP proxy serving with HTTP/3 request for configured host names [quinn]"); | ||||||
|     // first set as null config server
 |     // first set as null config server
 | ||||||
|     let rustls_server_config = ServerConfig::builder() |     let rustls_server_config = ServerConfig::builder() | ||||||
|       .with_safe_default_cipher_suites() |       .with_safe_default_cipher_suites() | ||||||
|       .with_safe_default_kx_groups() |       .with_safe_default_kx_groups() | ||||||
|       .with_protocol_versions(&[&rustls::version::TLS13])? |       .with_protocol_versions(&[&rustls::version::TLS13]) | ||||||
|  |       .map_err(|e| RpxyError::QuinnInvalidTlsProtocolVersion(e.to_string()))? | ||||||
|       .with_no_client_auth() |       .with_no_client_auth() | ||||||
|       .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new())); |       .with_cert_resolver(Arc::new(rustls::server::ResolvesServerCertUsingSni::new())); | ||||||
| 
 | 
 | ||||||
|  | @ -90,11 +92,11 @@ where | ||||||
|               }, |               }, | ||||||
|               Err(e) => { |               Err(e) => { | ||||||
|                 warn!("QUIC accepting connection failed: {:?}", e); |                 warn!("QUIC accepting connection failed: {:?}", e); | ||||||
|                 return Err(RpxyError::QuicConn(e)); |                 return Err(RpxyError::QuinnConnectionFailed(e)); | ||||||
|               } |               } | ||||||
|             }; |             }; | ||||||
|             // Timeout is based on underlying quic
 |             // Timeout is based on underlying quic
 | ||||||
|             if let Err(e) = self_clone.connection_serve_h3(quic_connection, new_server_name.to_server_name_vec(), client_addr).await { |             if let Err(e) = self_clone.h3_serve_connection(quic_connection, new_server_name.to_server_name(), client_addr).await { | ||||||
|               warn!("QUIC or HTTP/3 connection failed: {}", e); |               warn!("QUIC or HTTP/3 connection failed: {}", e); | ||||||
|             }; |             }; | ||||||
|             Ok(()) |             Ok(()) | ||||||
|  | @ -119,6 +121,6 @@ where | ||||||
|       } |       } | ||||||
|     } |     } | ||||||
|     endpoint.wait_idle().await; |     endpoint.wait_idle().await; | ||||||
|     Ok(()) as Result<()> |     Ok(()) as RpxyResult<()> | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | @ -1,22 +1,27 @@ | ||||||
| use super::{ | use super::proxy_main::Proxy; | ||||||
|   crypto_service::{ServerCrypto, ServerCryptoBase}, | use crate::{ | ||||||
|   proxy_main::Proxy, |   crypto::CryptoSource, | ||||||
|  |   crypto::{ServerCrypto, ServerCryptoBase}, | ||||||
|  |   error::*, | ||||||
|  |   log::*, | ||||||
|  |   name_exp::ByteName, | ||||||
| }; | }; | ||||||
| use crate::{certs::CryptoSource, error::*, log::*, utils::BytesName}; | use anyhow::anyhow; | ||||||
| use hot_reload::ReloaderReceiver; | use hot_reload::ReloaderReceiver; | ||||||
| use hyper::client::connect::Connect; | use hyper_util::client::legacy::connect::Connect; | ||||||
| use s2n_quic::provider; | use s2n_quic::provider; | ||||||
| use std::sync::Arc; | use std::sync::Arc; | ||||||
| 
 | 
 | ||||||
| impl<T, U> Proxy<T, U> | impl<U, T> Proxy<U, T> | ||||||
| where | where | ||||||
|   T: Connect + Clone + Sync + Send + 'static, |   T: Connect + Clone + Sync + Send + 'static, | ||||||
|   U: CryptoSource + Clone + Sync + Send + 'static, |   U: CryptoSource + Clone + Sync + Send + 'static, | ||||||
| { | { | ||||||
|   pub(super) async fn listener_service_h3( |   /// Start UDP proxy serving with HTTP/3 request for configured host names
 | ||||||
|     &self, |   pub(super) async fn h3_listener_service(&self) -> RpxyResult<()> { | ||||||
|     mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, |     let Some(mut server_crypto_rx) = self.globals.cert_reloader_rx.clone() else { | ||||||
|   ) -> Result<()> { |       return Err(RpxyError::NoCertificateReloader); | ||||||
|  |     }; | ||||||
|     info!("Start UDP proxy serving with HTTP/3 request for configured host names [s2n-quic]"); |     info!("Start UDP proxy serving with HTTP/3 request for configured host names [s2n-quic]"); | ||||||
| 
 | 
 | ||||||
|     // initially wait for receipt
 |     // initially wait for receipt
 | ||||||
|  | @ -29,7 +34,7 @@ where | ||||||
|     // event loop
 |     // event loop
 | ||||||
|     loop { |     loop { | ||||||
|       tokio::select! { |       tokio::select! { | ||||||
|         v = self.serve_connection(&server_crypto) => { |         v = self.h3_listener_service_inner(&server_crypto) => { | ||||||
|           if let Err(e) = v { |           if let Err(e) = v { | ||||||
|             error!("Quic connection event loop illegally shutdown [s2n-quic] {e}"); |             error!("Quic connection event loop illegally shutdown [s2n-quic] {e}"); | ||||||
|             break; |             break; | ||||||
|  | @ -51,20 +56,25 @@ where | ||||||
|     Ok(()) |     Ok(()) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   fn receive_server_crypto(&self, server_crypto_rx: ReloaderReceiver<ServerCryptoBase>) -> Result<Arc<ServerCrypto>> { |   /// Receive server crypto from reloader
 | ||||||
|  |   fn receive_server_crypto( | ||||||
|  |     &self, | ||||||
|  |     server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, | ||||||
|  |   ) -> RpxyResult<Arc<ServerCrypto>> { | ||||||
|     let cert_keys_map = server_crypto_rx.borrow().clone().ok_or_else(|| { |     let cert_keys_map = server_crypto_rx.borrow().clone().ok_or_else(|| { | ||||||
|       error!("Reloader is broken"); |       error!("Reloader is broken"); | ||||||
|       RpxyError::Other(anyhow!("Reloader is broken")) |       RpxyError::CertificateReloadError(anyhow!("Reloader is broken").into()) | ||||||
|     })?; |     })?; | ||||||
| 
 | 
 | ||||||
|     let server_crypto: Option<Arc<ServerCrypto>> = (&cert_keys_map).try_into().ok(); |     let server_crypto: Option<Arc<ServerCrypto>> = (&cert_keys_map).try_into().ok(); | ||||||
|     server_crypto.ok_or_else(|| { |     server_crypto.ok_or_else(|| { | ||||||
|       error!("Failed to update server crypto for h3 [s2n-quic]"); |       error!("Failed to update server crypto for h3 [s2n-quic]"); | ||||||
|       RpxyError::Other(anyhow!("Failed to update server crypto for h3 [s2n-quic]")) |       RpxyError::FailedToUpdateServerCrypto("Failed to update server crypto for h3 [s2n-quic]".to_string()) | ||||||
|     }) |     }) | ||||||
|   } |   } | ||||||
| 
 | 
 | ||||||
|   async fn serve_connection(&self, server_crypto: &Option<Arc<ServerCrypto>>) -> Result<()> { |   /// Event loop for UDP proxy serving with HTTP/3 request for configured host names
 | ||||||
|  |   async fn h3_listener_service_inner(&self, server_crypto: &Option<Arc<ServerCrypto>>) -> RpxyResult<()> { | ||||||
|     // setup UDP socket
 |     // setup UDP socket
 | ||||||
|     let io = provider::io::tokio::Builder::default() |     let io = provider::io::tokio::Builder::default() | ||||||
|       .with_receive_address(self.listening_on)? |       .with_receive_address(self.listening_on)? | ||||||
|  | @ -73,18 +83,13 @@ where | ||||||
| 
 | 
 | ||||||
|     // setup limits
 |     // setup limits
 | ||||||
|     let mut limits = provider::limits::Limits::default() |     let mut limits = provider::limits::Limits::default() | ||||||
|       .with_max_open_local_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64) |       .with_max_open_local_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64)? | ||||||
|       .map_err(|e| anyhow!(e))? |       .with_max_open_remote_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64)? | ||||||
|       .with_max_open_remote_bidirectional_streams(self.globals.proxy_config.h3_max_concurrent_bidistream as u64) |       .with_max_open_local_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64)? | ||||||
|       .map_err(|e| anyhow!(e))? |       .with_max_open_remote_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64)? | ||||||
|       .with_max_open_local_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64) |       .with_max_active_connection_ids(self.globals.proxy_config.h3_max_concurrent_connections as u64)?; | ||||||
|       .map_err(|e| anyhow!(e))? |  | ||||||
|       .with_max_open_remote_unidirectional_streams(self.globals.proxy_config.h3_max_concurrent_unistream as u64) |  | ||||||
|       .map_err(|e| anyhow!(e))? |  | ||||||
|       .with_max_active_connection_ids(self.globals.proxy_config.h3_max_concurrent_connections as u64) |  | ||||||
|       .map_err(|e| anyhow!(e))?; |  | ||||||
|     limits = if let Some(v) = self.globals.proxy_config.h3_max_idle_timeout { |     limits = if let Some(v) = self.globals.proxy_config.h3_max_idle_timeout { | ||||||
|       limits.with_max_idle_timeout(v).map_err(|e| anyhow!(e))? |       limits.with_max_idle_timeout(v)? | ||||||
|     } else { |     } else { | ||||||
|       limits |       limits | ||||||
|     }; |     }; | ||||||
|  | @ -92,19 +97,17 @@ where | ||||||
|     // setup tls
 |     // setup tls
 | ||||||
|     let Some(server_crypto) = server_crypto else { |     let Some(server_crypto) = server_crypto else { | ||||||
|       warn!("No server crypto is given [s2n-quic]"); |       warn!("No server crypto is given [s2n-quic]"); | ||||||
|       return Err(RpxyError::Other(anyhow!("No server crypto is given [s2n-quic]"))); |       return Err(RpxyError::NoServerCrypto( | ||||||
|  |         "No server crypto is given [s2n-quic]".to_string(), | ||||||
|  |       )); | ||||||
|     }; |     }; | ||||||
|     let tls = server_crypto.inner_global_no_client_auth.clone(); |     let tls = server_crypto.inner_global_no_client_auth.clone(); | ||||||
| 
 | 
 | ||||||
|     let mut server = s2n_quic::Server::builder() |     let mut server = s2n_quic::Server::builder() | ||||||
|       .with_tls(tls) |       .with_tls(tls)? | ||||||
|       .map_err(|e| anyhow::anyhow!(e))? |       .with_io(io)? | ||||||
|       .with_io(io) |       .with_limits(limits)? | ||||||
|       .map_err(|e| anyhow!(e))? |       .start()?; | ||||||
|       .with_limits(limits) |  | ||||||
|       .map_err(|e| anyhow!(e))? |  | ||||||
|       .start() |  | ||||||
|       .map_err(|e| anyhow!(e))?; |  | ||||||
| 
 | 
 | ||||||
|     // quic event loop. this immediately cancels when crypto is updated by tokio::select!
 |     // quic event loop. this immediately cancels when crypto is updated by tokio::select!
 | ||||||
|     while let Some(new_conn) = server.accept().await { |     while let Some(new_conn) = server.accept().await { | ||||||
|  | @ -121,12 +124,12 @@ where | ||||||
|         let quic_connection = s2n_quic_h3::Connection::new(new_conn); |         let quic_connection = s2n_quic_h3::Connection::new(new_conn); | ||||||
|         // Timeout is based on underlying quic
 |         // Timeout is based on underlying quic
 | ||||||
|         if let Err(e) = self_clone |         if let Err(e) = self_clone | ||||||
|           .connection_serve_h3(quic_connection, new_server_name.to_server_name_vec(), client_addr) |           .h3_serve_connection(quic_connection, new_server_name.to_server_name(), client_addr) | ||||||
|           .await |           .await | ||||||
|         { |         { | ||||||
|           warn!("QUIC or HTTP/3 connection failed: {}", e); |           warn!("QUIC or HTTP/3 connection failed: {}", e); | ||||||
|         }; |         }; | ||||||
|         Ok(()) as Result<()> |         Ok(()) as RpxyResult<()> | ||||||
|       }); |       }); | ||||||
|     } |     } | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,163 +0,0 @@ | ||||||
| use super::{ |  | ||||||
|   crypto_service::{CryptoReloader, ServerCrypto, ServerCryptoBase, SniServerCryptoMap}, |  | ||||||
|   proxy_main::{LocalExecutor, Proxy}, |  | ||||||
|   socket::bind_tcp_socket, |  | ||||||
| }; |  | ||||||
| use crate::{certs::CryptoSource, constants::*, error::*, log::*, utils::BytesName}; |  | ||||||
| use hot_reload::{ReloaderReceiver, ReloaderService}; |  | ||||||
| use hyper::{client::connect::Connect, server::conn::Http}; |  | ||||||
| use std::sync::Arc; |  | ||||||
| use tokio::time::{timeout, Duration}; |  | ||||||
| 
 |  | ||||||
| impl<T, U> Proxy<T, U> |  | ||||||
| where |  | ||||||
|   T: Connect + Clone + Sync + Send + 'static, |  | ||||||
|   U: CryptoSource + Clone + Sync + Send + 'static, |  | ||||||
| { |  | ||||||
|   // TCP Listener Service, i.e., http/2 and http/1.1
 |  | ||||||
|   async fn listener_service( |  | ||||||
|     &self, |  | ||||||
|     server: Http<LocalExecutor>, |  | ||||||
|     mut server_crypto_rx: ReloaderReceiver<ServerCryptoBase>, |  | ||||||
|   ) -> Result<()> { |  | ||||||
|     let tcp_socket = bind_tcp_socket(&self.listening_on)?; |  | ||||||
|     let tcp_listener = tcp_socket.listen(self.globals.proxy_config.tcp_listen_backlog)?; |  | ||||||
|     info!("Start TCP proxy serving with HTTPS request for configured host names"); |  | ||||||
| 
 |  | ||||||
|     let mut server_crypto_map: Option<Arc<SniServerCryptoMap>> = None; |  | ||||||
|     loop { |  | ||||||
|       tokio::select! { |  | ||||||
|         tcp_cnx = tcp_listener.accept() => { |  | ||||||
|           if tcp_cnx.is_err() || server_crypto_map.is_none() { |  | ||||||
|             continue; |  | ||||||
|           } |  | ||||||
|           let (raw_stream, client_addr) = tcp_cnx.unwrap(); |  | ||||||
|           let sc_map_inner = server_crypto_map.clone(); |  | ||||||
|           let server_clone = server.clone(); |  | ||||||
|           let self_inner = self.clone(); |  | ||||||
| 
 |  | ||||||
|           // spawns async handshake to avoid blocking thread by sequential handshake.
 |  | ||||||
|           let handshake_fut = async move { |  | ||||||
|             let acceptor = tokio_rustls::LazyConfigAcceptor::new(tokio_rustls::rustls::server::Acceptor::default(), raw_stream).await; |  | ||||||
|             if let Err(e) = acceptor { |  | ||||||
|               return Err(RpxyError::Proxy(format!("Failed to handshake TLS: {e}"))); |  | ||||||
|             } |  | ||||||
|             let start = acceptor.unwrap(); |  | ||||||
|             let client_hello = start.client_hello(); |  | ||||||
|             let server_name = client_hello.server_name(); |  | ||||||
|             debug!("HTTP/2 or 1.1: SNI in ClientHello: {:?}", server_name); |  | ||||||
|             let server_name_in_bytes = server_name.map_or_else(|| None, |v| Some(v.to_server_name_vec())); |  | ||||||
|             if server_name_in_bytes.is_none(){ |  | ||||||
|               return Err(RpxyError::Proxy("No SNI is given".to_string())); |  | ||||||
|             } |  | ||||||
|             let server_crypto = sc_map_inner.as_ref().unwrap().get(server_name_in_bytes.as_ref().unwrap()); |  | ||||||
|             if server_crypto.is_none() { |  | ||||||
|               return Err(RpxyError::Proxy(format!("No TLS serving app for {:?}", server_name.unwrap()))); |  | ||||||
|             } |  | ||||||
|             let stream = match start.into_stream(server_crypto.unwrap().clone()).await { |  | ||||||
|               Ok(s) => s, |  | ||||||
|               Err(e) => { |  | ||||||
|                 return Err(RpxyError::Proxy(format!("Failed to handshake TLS: {e}"))); |  | ||||||
|               } |  | ||||||
|             }; |  | ||||||
|             self_inner.client_serve(stream, server_clone, client_addr, server_name_in_bytes); |  | ||||||
|             Ok(()) |  | ||||||
|           }; |  | ||||||
| 
 |  | ||||||
|           self.globals.runtime_handle.spawn( async move { |  | ||||||
|             // timeout is introduced to avoid get stuck here.
 |  | ||||||
|             match timeout( |  | ||||||
|               Duration::from_secs(TLS_HANDSHAKE_TIMEOUT_SEC), |  | ||||||
|               handshake_fut |  | ||||||
|             ).await { |  | ||||||
|               Ok(a) => { |  | ||||||
|                 if let Err(e) = a { |  | ||||||
|                   error!("{}", e); |  | ||||||
|                 } |  | ||||||
|               }, |  | ||||||
|               Err(e) => { |  | ||||||
|                 error!("Timeout to handshake TLS: {}", e); |  | ||||||
|               } |  | ||||||
|             }; |  | ||||||
|           }); |  | ||||||
|         } |  | ||||||
|         _ = server_crypto_rx.changed() => { |  | ||||||
|           if server_crypto_rx.borrow().is_none() { |  | ||||||
|             error!("Reloader is broken"); |  | ||||||
|             break; |  | ||||||
|           } |  | ||||||
|           let cert_keys_map = server_crypto_rx.borrow().clone().unwrap(); |  | ||||||
|           let Some(server_crypto): Option<Arc<ServerCrypto>> = (&cert_keys_map).try_into().ok() else { |  | ||||||
|             error!("Failed to update server crypto"); |  | ||||||
|             break; |  | ||||||
|           }; |  | ||||||
|           server_crypto_map = Some(server_crypto.inner_local_map.clone()); |  | ||||||
|         } |  | ||||||
|         else => break
 |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|     Ok(()) as Result<()> |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   pub async fn start_with_tls(self, server: Http<LocalExecutor>) -> Result<()> { |  | ||||||
|     let (cert_reloader_service, cert_reloader_rx) = ReloaderService::<CryptoReloader<U>, ServerCryptoBase>::new( |  | ||||||
|       &self.globals.clone(), |  | ||||||
|       CERTS_WATCH_DELAY_SECS, |  | ||||||
|       !LOAD_CERTS_ONLY_WHEN_UPDATED, |  | ||||||
|     ) |  | ||||||
|     .await |  | ||||||
|     .map_err(|e| anyhow::anyhow!(e))?; |  | ||||||
| 
 |  | ||||||
|     #[cfg(not(any(feature = "http3-quinn", feature = "http3-s2n")))] |  | ||||||
|     { |  | ||||||
|       tokio::select! { |  | ||||||
|         _= cert_reloader_service.start() => { |  | ||||||
|           error!("Cert service for TLS exited"); |  | ||||||
|         }, |  | ||||||
|         _ = self.listener_service(server, cert_reloader_rx) => { |  | ||||||
|           error!("TCP proxy service for TLS exited"); |  | ||||||
|         }, |  | ||||||
|         else => { |  | ||||||
|           error!("Something went wrong"); |  | ||||||
|           return Ok(()) |  | ||||||
|         } |  | ||||||
|       }; |  | ||||||
|       Ok(()) |  | ||||||
|     } |  | ||||||
|     #[cfg(any(feature = "http3-quinn", feature = "http3-s2n"))] |  | ||||||
|     { |  | ||||||
|       if self.globals.proxy_config.http3 { |  | ||||||
|         tokio::select! { |  | ||||||
|           _= cert_reloader_service.start() => { |  | ||||||
|             error!("Cert service for TLS exited"); |  | ||||||
|           }, |  | ||||||
|           _ = self.listener_service(server, cert_reloader_rx.clone()) => { |  | ||||||
|             error!("TCP proxy service for TLS exited"); |  | ||||||
|           }, |  | ||||||
|           _= self.listener_service_h3(cert_reloader_rx) => { |  | ||||||
|             error!("UDP proxy service for QUIC exited"); |  | ||||||
|           }, |  | ||||||
|           else => { |  | ||||||
|             error!("Something went wrong"); |  | ||||||
|             return Ok(()) |  | ||||||
|           } |  | ||||||
|         }; |  | ||||||
|         Ok(()) |  | ||||||
|       } else { |  | ||||||
|         tokio::select! { |  | ||||||
|           _= cert_reloader_service.start() => { |  | ||||||
|             error!("Cert service for TLS exited"); |  | ||||||
|           }, |  | ||||||
|           _ = self.listener_service(server, cert_reloader_rx) => { |  | ||||||
|             error!("TCP proxy service for TLS exited"); |  | ||||||
|           }, |  | ||||||
|           else => { |  | ||||||
|             error!("Something went wrong"); |  | ||||||
|             return Ok(()) |  | ||||||
|           } |  | ||||||
|         }; |  | ||||||
|         Ok(()) |  | ||||||
|       } |  | ||||||
|     } |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  | @ -8,7 +8,7 @@ use tokio::net::TcpSocket; | ||||||
| 
 | 
 | ||||||
| /// Bind TCP socket to the given `SocketAddr`, and returns the TCP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options.
 | /// Bind TCP socket to the given `SocketAddr`, and returns the TCP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options.
 | ||||||
| /// This option is required to re-bind the socket address when the proxy instance is reconstructed.
 | /// This option is required to re-bind the socket address when the proxy instance is reconstructed.
 | ||||||
| pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> Result<TcpSocket> { | pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> RpxyResult<TcpSocket> { | ||||||
|   let tcp_socket = if listening_on.is_ipv6() { |   let tcp_socket = if listening_on.is_ipv6() { | ||||||
|     TcpSocket::new_v6() |     TcpSocket::new_v6() | ||||||
|   } else { |   } else { | ||||||
|  | @ -26,7 +26,7 @@ pub(super) fn bind_tcp_socket(listening_on: &SocketAddr) -> Result<TcpSocket> { | ||||||
| #[cfg(feature = "http3-quinn")] | #[cfg(feature = "http3-quinn")] | ||||||
| /// Bind UDP socket to the given `SocketAddr`, and returns the UDP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options.
 | /// Bind UDP socket to the given `SocketAddr`, and returns the UDP socket with `SO_REUSEADDR` and `SO_REUSEPORT` options.
 | ||||||
| /// This option is required to re-bind the socket address when the proxy instance is reconstructed.
 | /// This option is required to re-bind the socket address when the proxy instance is reconstructed.
 | ||||||
| pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> Result<UdpSocket> { | pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> RpxyResult<UdpSocket> { | ||||||
|   let socket = if listening_on.is_ipv6() { |   let socket = if listening_on.is_ipv6() { | ||||||
|     Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) |     Socket::new(Domain::IPV6, Type::DGRAM, Some(Protocol::UDP)) | ||||||
|   } else { |   } else { | ||||||
|  | @ -34,6 +34,7 @@ pub(super) fn bind_udp_socket(listening_on: &SocketAddr) -> Result<UdpSocket> { | ||||||
|   }?; |   }?; | ||||||
|   socket.set_reuse_address(true)?; // This isn't necessary?
 |   socket.set_reuse_address(true)?; // This isn't necessary?
 | ||||||
|   socket.set_reuse_port(true)?; |   socket.set_reuse_port(true)?; | ||||||
|  |   socket.set_nonblocking(true)?; // This was made true inside quinn. so this line isn't necessary here. but just in case.
 | ||||||
| 
 | 
 | ||||||
|   if let Err(e) = socket.bind(&(*listening_on).into()) { |   if let Err(e) = socket.bind(&(*listening_on).into()) { | ||||||
|     error!("Failed to bind UDP socket: {}", e); |     error!("Failed to bind UDP socket: {}", e); | ||||||
|  |  | ||||||
|  | @ -1,123 +0,0 @@ | ||||||
| /// Server name (hostname or ip address) representation in bytes-based struct
 |  | ||||||
| /// for searching hashmap or key list by exact or longest-prefix matching
 |  | ||||||
| #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] |  | ||||||
| pub struct ServerNameBytesExp(pub Vec<u8>); // lowercase ascii bytes
 |  | ||||||
| impl From<&[u8]> for ServerNameBytesExp { |  | ||||||
|   fn from(b: &[u8]) -> Self { |  | ||||||
|     Self(b.to_ascii_lowercase()) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| impl TryInto<String> for &ServerNameBytesExp { |  | ||||||
|   type Error = anyhow::Error; |  | ||||||
|   fn try_into(self) -> Result<String, Self::Error> { |  | ||||||
|     let s = std::str::from_utf8(&self.0)?; |  | ||||||
|     Ok(s.to_string()) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Path name, like "/path/ok", represented in bytes-based struct
 |  | ||||||
| /// for searching hashmap or key list by exact or longest-prefix matching
 |  | ||||||
| #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)] |  | ||||||
| pub struct PathNameBytesExp(pub Vec<u8>); // lowercase ascii bytes
 |  | ||||||
| impl PathNameBytesExp { |  | ||||||
|   pub fn len(&self) -> usize { |  | ||||||
|     self.0.len() |  | ||||||
|   } |  | ||||||
|   pub fn is_empty(&self) -> bool { |  | ||||||
|     self.0.len() == 0 |  | ||||||
|   } |  | ||||||
|   pub fn get<I>(&self, index: I) -> Option<&I::Output> |  | ||||||
|   where |  | ||||||
|     I: std::slice::SliceIndex<[u8]>, |  | ||||||
|   { |  | ||||||
|     self.0.get(index) |  | ||||||
|   } |  | ||||||
|   pub fn starts_with(&self, needle: &Self) -> bool { |  | ||||||
|     self.0.starts_with(&needle.0) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| impl AsRef<[u8]> for PathNameBytesExp { |  | ||||||
|   fn as_ref(&self) -> &[u8] { |  | ||||||
|     self.0.as_ref() |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Trait to express names in ascii-lowercased bytes
 |  | ||||||
| pub trait BytesName { |  | ||||||
|   type OutputSv: Send + Sync + 'static; |  | ||||||
|   type OutputPath; |  | ||||||
|   fn to_server_name_vec(self) -> Self::OutputSv; |  | ||||||
|   fn to_path_name_vec(self) -> Self::OutputPath; |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<'a, T: Into<std::borrow::Cow<'a, str>>> BytesName for T { |  | ||||||
|   type OutputSv = ServerNameBytesExp; |  | ||||||
|   type OutputPath = PathNameBytesExp; |  | ||||||
| 
 |  | ||||||
|   fn to_server_name_vec(self) -> Self::OutputSv { |  | ||||||
|     let name = self.into().bytes().collect::<Vec<u8>>().to_ascii_lowercase(); |  | ||||||
|     ServerNameBytesExp(name) |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   fn to_path_name_vec(self) -> Self::OutputPath { |  | ||||||
|     let name = self.into().bytes().collect::<Vec<u8>>().to_ascii_lowercase(); |  | ||||||
|     PathNameBytesExp(name) |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| #[cfg(test)] |  | ||||||
| mod tests { |  | ||||||
|   use super::*; |  | ||||||
|   #[test] |  | ||||||
|   fn bytes_name_str_works() { |  | ||||||
|     let s = "OK_string"; |  | ||||||
|     let bn = s.to_path_name_vec(); |  | ||||||
|     let bn_lc = s.to_server_name_vec(); |  | ||||||
| 
 |  | ||||||
|     assert_eq!(Vec::from("ok_string".as_bytes()), bn.0); |  | ||||||
|     assert_eq!(Vec::from("ok_string".as_bytes()), bn_lc.0); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   #[test] |  | ||||||
|   fn from_works() { |  | ||||||
|     let s = "OK_string".to_server_name_vec(); |  | ||||||
|     let m = ServerNameBytesExp::from("OK_strinG".as_bytes()); |  | ||||||
|     assert_eq!(s, m); |  | ||||||
|     assert_eq!(s.0, "ok_string".as_bytes().to_vec()); |  | ||||||
|     assert_eq!(m.0, "ok_string".as_bytes().to_vec()); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   #[test] |  | ||||||
|   fn get_works() { |  | ||||||
|     let s = "OK_str".to_path_name_vec(); |  | ||||||
|     let i = s.get(0); |  | ||||||
|     assert_eq!(Some(&"o".as_bytes()[0]), i); |  | ||||||
|     let i = s.get(1); |  | ||||||
|     assert_eq!(Some(&"k".as_bytes()[0]), i); |  | ||||||
|     let i = s.get(2); |  | ||||||
|     assert_eq!(Some(&"_".as_bytes()[0]), i); |  | ||||||
|     let i = s.get(3); |  | ||||||
|     assert_eq!(Some(&"s".as_bytes()[0]), i); |  | ||||||
|     let i = s.get(4); |  | ||||||
|     assert_eq!(Some(&"t".as_bytes()[0]), i); |  | ||||||
|     let i = s.get(5); |  | ||||||
|     assert_eq!(Some(&"r".as_bytes()[0]), i); |  | ||||||
|     let i = s.get(6); |  | ||||||
|     assert_eq!(None, i); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   #[test] |  | ||||||
|   fn start_with_works() { |  | ||||||
|     let s = "OK_str".to_path_name_vec(); |  | ||||||
|     let correct = "OK".to_path_name_vec(); |  | ||||||
|     let incorrect = "KO".to_path_name_vec(); |  | ||||||
|     assert!(s.starts_with(&correct)); |  | ||||||
|     assert!(!s.starts_with(&incorrect)); |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   #[test] |  | ||||||
|   fn as_ref_works() { |  | ||||||
|     let s = "OK_str".to_path_name_vec(); |  | ||||||
|     assert_eq!(s.as_ref(), "ok_str".as_bytes()); |  | ||||||
|   } |  | ||||||
| } |  | ||||||
|  | @ -1,5 +0,0 @@ | ||||||
| mod bytes_name; |  | ||||||
| mod socket_addr; |  | ||||||
| 
 |  | ||||||
| pub use bytes_name::{BytesName, PathNameBytesExp, ServerNameBytesExp}; |  | ||||||
| pub use socket_addr::ToCanonical; |  | ||||||
|  | @ -1 +1 @@ | ||||||
| Subproject commit a57ed224ac5d17a635eb71eb6f83c1196f581a51 | Subproject commit c11410c76e738a62e62e7766b82f814547621f6f | ||||||
|  | @ -1,24 +0,0 @@ | ||||||
| [package] |  | ||||||
| name = "h3-quinn" |  | ||||||
| version = "0.0.1" |  | ||||||
| rust-version = "1.59" |  | ||||||
| authors = ["Jean-Christophe BEGUE <jc.begue@pm.me>"] |  | ||||||
| edition = "2018" |  | ||||||
| documentation = "https://docs.rs/h3-quinn" |  | ||||||
| repository = "https://github.com/hyperium/h3" |  | ||||||
| readme = "../README.md" |  | ||||||
| description = "QUIC transport implementation based on Quinn." |  | ||||||
| keywords = ["http3", "quic", "h3"] |  | ||||||
| categories = ["network-programming", "web-programming"] |  | ||||||
| license = "MIT" |  | ||||||
| 
 |  | ||||||
| [dependencies] |  | ||||||
| h3 = { version = "0.0.2", path = "../h3/h3" } |  | ||||||
| bytes = "1" |  | ||||||
| quinn = { path = "../quinn/quinn/", default-features = false, features = [ |  | ||||||
|   "futures-io", |  | ||||||
| ] } |  | ||||||
| quinn-proto = { path = "../quinn/quinn-proto/", default-features = false } |  | ||||||
| tokio-util = { version = "0.7.8" } |  | ||||||
| futures = { version = "0.3.27" } |  | ||||||
| tokio = { version = "1.28", features = ["io-util"], default-features = false } |  | ||||||
|  | @ -1,740 +0,0 @@ | ||||||
| //! QUIC Transport implementation with Quinn
 |  | ||||||
| //!
 |  | ||||||
| //! This module implements QUIC traits with Quinn.
 |  | ||||||
| #![deny(missing_docs)] |  | ||||||
| 
 |  | ||||||
| use std::{ |  | ||||||
|     convert::TryInto, |  | ||||||
|     fmt::{self, Display}, |  | ||||||
|     future::Future, |  | ||||||
|     pin::Pin, |  | ||||||
|     sync::Arc, |  | ||||||
|     task::{self, Poll}, |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| use bytes::{Buf, Bytes, BytesMut}; |  | ||||||
| 
 |  | ||||||
| use futures::{ |  | ||||||
|     ready, |  | ||||||
|     stream::{self, BoxStream}, |  | ||||||
|     StreamExt, |  | ||||||
| }; |  | ||||||
| use quinn::ReadDatagram; |  | ||||||
| pub use quinn::{ |  | ||||||
|     self, crypto::Session, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError, |  | ||||||
| }; |  | ||||||
| 
 |  | ||||||
| use h3::{ |  | ||||||
|     ext::Datagram, |  | ||||||
|     quic::{self, Error, StreamId, WriteBuf}, |  | ||||||
| }; |  | ||||||
| use tokio_util::sync::ReusableBoxFuture; |  | ||||||
| 
 |  | ||||||
| /// A QUIC connection backed by Quinn
 |  | ||||||
| ///
 |  | ||||||
| /// Implements a [`quic::Connection`] backed by a [`quinn::Connection`].
 |  | ||||||
| pub struct Connection { |  | ||||||
|     conn: quinn::Connection, |  | ||||||
|     incoming_bi: BoxStream<'static, <AcceptBi<'static> as Future>::Output>, |  | ||||||
|     opening_bi: Option<BoxStream<'static, <OpenBi<'static> as Future>::Output>>, |  | ||||||
|     incoming_uni: BoxStream<'static, <AcceptUni<'static> as Future>::Output>, |  | ||||||
|     opening_uni: Option<BoxStream<'static, <OpenUni<'static> as Future>::Output>>, |  | ||||||
|     datagrams: BoxStream<'static, <ReadDatagram<'static> as Future>::Output>, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Connection { |  | ||||||
|     /// Create a [`Connection`] from a [`quinn::NewConnection`]
 |  | ||||||
|     pub fn new(conn: quinn::Connection) -> Self { |  | ||||||
|         Self { |  | ||||||
|             conn: conn.clone(), |  | ||||||
|             incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async { |  | ||||||
|                 Some((conn.accept_bi().await, conn)) |  | ||||||
|             })), |  | ||||||
|             opening_bi: None, |  | ||||||
|             incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async { |  | ||||||
|                 Some((conn.accept_uni().await, conn)) |  | ||||||
|             })), |  | ||||||
|             opening_uni: None, |  | ||||||
|             datagrams: Box::pin(stream::unfold(conn, |conn| async { |  | ||||||
|                 Some((conn.read_datagram().await, conn)) |  | ||||||
|             })), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// The error type for [`Connection`]
 |  | ||||||
| ///
 |  | ||||||
| /// Wraps reasons a Quinn connection might be lost.
 |  | ||||||
| #[derive(Debug)] |  | ||||||
| pub struct ConnectionError(quinn::ConnectionError); |  | ||||||
| 
 |  | ||||||
| impl std::error::Error for ConnectionError {} |  | ||||||
| 
 |  | ||||||
| impl fmt::Display for ConnectionError { |  | ||||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |  | ||||||
|         self.0.fmt(f) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Error for ConnectionError { |  | ||||||
|     fn is_timeout(&self) -> bool { |  | ||||||
|         matches!(self.0, quinn::ConnectionError::TimedOut) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn err_code(&self) -> Option<u64> { |  | ||||||
|         match self.0 { |  | ||||||
|             quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { |  | ||||||
|                 error_code, |  | ||||||
|                 .. |  | ||||||
|             }) => Some(error_code.into_inner()), |  | ||||||
|             _ => None, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl From<quinn::ConnectionError> for ConnectionError { |  | ||||||
|     fn from(e: quinn::ConnectionError) -> Self { |  | ||||||
|         Self(e) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Types of errors when sending a datagram.
 |  | ||||||
| #[derive(Debug)] |  | ||||||
| pub enum SendDatagramError { |  | ||||||
|     /// Datagrams are not supported by the peer
 |  | ||||||
|     UnsupportedByPeer, |  | ||||||
|     /// Datagrams are locally disabled
 |  | ||||||
|     Disabled, |  | ||||||
|     /// The datagram was too large to be sent.
 |  | ||||||
|     TooLarge, |  | ||||||
|     /// Network error
 |  | ||||||
|     ConnectionLost(Box<dyn Error>), |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl fmt::Display for SendDatagramError { |  | ||||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |  | ||||||
|         match self { |  | ||||||
|             SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"), |  | ||||||
|             SendDatagramError::Disabled => write!(f, "datagram support disabled"), |  | ||||||
|             SendDatagramError::TooLarge => write!(f, "datagram too large"), |  | ||||||
|             SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl std::error::Error for SendDatagramError {} |  | ||||||
| 
 |  | ||||||
| impl Error for SendDatagramError { |  | ||||||
|     fn is_timeout(&self) -> bool { |  | ||||||
|         false |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn err_code(&self) -> Option<u64> { |  | ||||||
|         match self { |  | ||||||
|             Self::ConnectionLost(err) => err.err_code(), |  | ||||||
|             _ => None, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl From<quinn::SendDatagramError> for SendDatagramError { |  | ||||||
|     fn from(value: quinn::SendDatagramError) -> Self { |  | ||||||
|         match value { |  | ||||||
|             quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer, |  | ||||||
|             quinn::SendDatagramError::Disabled => Self::Disabled, |  | ||||||
|             quinn::SendDatagramError::TooLarge => Self::TooLarge, |  | ||||||
|             quinn::SendDatagramError::ConnectionLost(err) => { |  | ||||||
|                 Self::ConnectionLost(ConnectionError::from(err).into()) |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<B> quic::Connection<B> for Connection |  | ||||||
| where |  | ||||||
|     B: Buf, |  | ||||||
| { |  | ||||||
|     type SendStream = SendStream<B>; |  | ||||||
|     type RecvStream = RecvStream; |  | ||||||
|     type BidiStream = BidiStream<B>; |  | ||||||
|     type OpenStreams = OpenStreams; |  | ||||||
|     type Error = ConnectionError; |  | ||||||
| 
 |  | ||||||
|     fn poll_accept_bidi( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|     ) -> Poll<Result<Option<Self::BidiStream>, Self::Error>> { |  | ||||||
|         let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) { |  | ||||||
|             Some(x) => x?, |  | ||||||
|             None => return Poll::Ready(Ok(None)), |  | ||||||
|         }; |  | ||||||
|         Poll::Ready(Ok(Some(Self::BidiStream { |  | ||||||
|             send: Self::SendStream::new(send), |  | ||||||
|             recv: Self::RecvStream::new(recv), |  | ||||||
|         }))) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn poll_accept_recv( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|     ) -> Poll<Result<Option<Self::RecvStream>, Self::Error>> { |  | ||||||
|         let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) { |  | ||||||
|             Some(x) => x?, |  | ||||||
|             None => return Poll::Ready(Ok(None)), |  | ||||||
|         }; |  | ||||||
|         Poll::Ready(Ok(Some(Self::RecvStream::new(recv)))) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn poll_open_bidi( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|     ) -> Poll<Result<Self::BidiStream, Self::Error>> { |  | ||||||
|         if self.opening_bi.is_none() { |  | ||||||
|             self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { |  | ||||||
|                 Some((conn.clone().open_bi().await, conn)) |  | ||||||
|             }))); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let (send, recv) = |  | ||||||
|             ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; |  | ||||||
|         Poll::Ready(Ok(Self::BidiStream { |  | ||||||
|             send: Self::SendStream::new(send), |  | ||||||
|             recv: Self::RecvStream::new(recv), |  | ||||||
|         })) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn poll_open_send( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|     ) -> Poll<Result<Self::SendStream, Self::Error>> { |  | ||||||
|         if self.opening_uni.is_none() { |  | ||||||
|             self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { |  | ||||||
|                 Some((conn.open_uni().await, conn)) |  | ||||||
|             }))); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; |  | ||||||
|         Poll::Ready(Ok(Self::SendStream::new(send))) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn opener(&self) -> Self::OpenStreams { |  | ||||||
|         OpenStreams { |  | ||||||
|             conn: self.conn.clone(), |  | ||||||
|             opening_bi: None, |  | ||||||
|             opening_uni: None, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn close(&mut self, code: h3::error::Code, reason: &[u8]) { |  | ||||||
|         self.conn.close( |  | ||||||
|             VarInt::from_u64(code.value()).expect("error code VarInt"), |  | ||||||
|             reason, |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<B> quic::SendDatagramExt<B> for Connection |  | ||||||
| where |  | ||||||
|     B: Buf, |  | ||||||
| { |  | ||||||
|     type Error = SendDatagramError; |  | ||||||
| 
 |  | ||||||
|     fn send_datagram(&mut self, data: Datagram<B>) -> Result<(), SendDatagramError> { |  | ||||||
|         // TODO investigate static buffer from known max datagram size
 |  | ||||||
|         let mut buf = BytesMut::new(); |  | ||||||
|         data.encode(&mut buf); |  | ||||||
|         self.conn.send_datagram(buf.freeze())?; |  | ||||||
| 
 |  | ||||||
|         Ok(()) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl quic::RecvDatagramExt for Connection { |  | ||||||
|     type Buf = Bytes; |  | ||||||
| 
 |  | ||||||
|     type Error = ConnectionError; |  | ||||||
| 
 |  | ||||||
|     #[inline] |  | ||||||
|     fn poll_accept_datagram( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|     ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { |  | ||||||
|         match ready!(self.datagrams.poll_next_unpin(cx)) { |  | ||||||
|             Some(Ok(x)) => Poll::Ready(Ok(Some(x))), |  | ||||||
|             Some(Err(e)) => Poll::Ready(Err(e.into())), |  | ||||||
|             None => Poll::Ready(Ok(None)), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Stream opener backed by a Quinn connection
 |  | ||||||
| ///
 |  | ||||||
| /// Implements [`quic::OpenStreams`] using [`quinn::Connection`],
 |  | ||||||
| /// [`quinn::OpenBi`], [`quinn::OpenUni`].
 |  | ||||||
| pub struct OpenStreams { |  | ||||||
|     conn: quinn::Connection, |  | ||||||
|     opening_bi: Option<BoxStream<'static, <OpenBi<'static> as Future>::Output>>, |  | ||||||
|     opening_uni: Option<BoxStream<'static, <OpenUni<'static> as Future>::Output>>, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<B> quic::OpenStreams<B> for OpenStreams |  | ||||||
| where |  | ||||||
|     B: Buf, |  | ||||||
| { |  | ||||||
|     type RecvStream = RecvStream; |  | ||||||
|     type SendStream = SendStream<B>; |  | ||||||
|     type BidiStream = BidiStream<B>; |  | ||||||
|     type Error = ConnectionError; |  | ||||||
| 
 |  | ||||||
|     fn poll_open_bidi( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|     ) -> Poll<Result<Self::BidiStream, Self::Error>> { |  | ||||||
|         if self.opening_bi.is_none() { |  | ||||||
|             self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { |  | ||||||
|                 Some((conn.open_bi().await, conn)) |  | ||||||
|             }))); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let (send, recv) = |  | ||||||
|             ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; |  | ||||||
|         Poll::Ready(Ok(Self::BidiStream { |  | ||||||
|             send: Self::SendStream::new(send), |  | ||||||
|             recv: Self::RecvStream::new(recv), |  | ||||||
|         })) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn poll_open_send( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|     ) -> Poll<Result<Self::SendStream, Self::Error>> { |  | ||||||
|         if self.opening_uni.is_none() { |  | ||||||
|             self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { |  | ||||||
|                 Some((conn.open_uni().await, conn)) |  | ||||||
|             }))); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; |  | ||||||
|         Poll::Ready(Ok(Self::SendStream::new(send))) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn close(&mut self, code: h3::error::Code, reason: &[u8]) { |  | ||||||
|         self.conn.close( |  | ||||||
|             VarInt::from_u64(code.value()).expect("error code VarInt"), |  | ||||||
|             reason, |  | ||||||
|         ); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Clone for OpenStreams { |  | ||||||
|     fn clone(&self) -> Self { |  | ||||||
|         Self { |  | ||||||
|             conn: self.conn.clone(), |  | ||||||
|             opening_bi: None, |  | ||||||
|             opening_uni: None, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Quinn-backed bidirectional stream
 |  | ||||||
| ///
 |  | ||||||
| /// Implements [`quic::BidiStream`] which allows the stream to be split
 |  | ||||||
| /// into two structs each implementing one direction.
 |  | ||||||
| pub struct BidiStream<B> |  | ||||||
| where |  | ||||||
|     B: Buf, |  | ||||||
| { |  | ||||||
|     send: SendStream<B>, |  | ||||||
|     recv: RecvStream, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<B> quic::BidiStream<B> for BidiStream<B> |  | ||||||
| where |  | ||||||
|     B: Buf, |  | ||||||
| { |  | ||||||
|     type SendStream = SendStream<B>; |  | ||||||
|     type RecvStream = RecvStream; |  | ||||||
| 
 |  | ||||||
|     fn split(self) -> (Self::SendStream, Self::RecvStream) { |  | ||||||
|         (self.send, self.recv) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<B: Buf> quic::RecvStream for BidiStream<B> { |  | ||||||
|     type Buf = Bytes; |  | ||||||
|     type Error = ReadError; |  | ||||||
| 
 |  | ||||||
|     fn poll_data( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|     ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { |  | ||||||
|         self.recv.poll_data(cx) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn stop_sending(&mut self, error_code: u64) { |  | ||||||
|         self.recv.stop_sending(error_code) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn recv_id(&self) -> StreamId { |  | ||||||
|         self.recv.recv_id() |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<B> quic::SendStream<B> for BidiStream<B> |  | ||||||
| where |  | ||||||
|     B: Buf, |  | ||||||
| { |  | ||||||
|     type Error = SendStreamError; |  | ||||||
| 
 |  | ||||||
|     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { |  | ||||||
|         self.send.poll_ready(cx) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { |  | ||||||
|         self.send.poll_finish(cx) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn reset(&mut self, reset_code: u64) { |  | ||||||
|         self.send.reset(reset_code) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> { |  | ||||||
|         self.send.send_data(data) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn send_id(&self) -> StreamId { |  | ||||||
|         self.send.send_id() |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| impl<B> quic::SendStreamUnframed<B> for BidiStream<B> |  | ||||||
| where |  | ||||||
|     B: Buf, |  | ||||||
| { |  | ||||||
|     fn poll_send<D: Buf>( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|         buf: &mut D, |  | ||||||
|     ) -> Poll<Result<usize, Self::Error>> { |  | ||||||
|         self.send.poll_send(cx, buf) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Quinn-backed receive stream
 |  | ||||||
| ///
 |  | ||||||
| /// Implements a [`quic::RecvStream`] backed by a [`quinn::RecvStream`].
 |  | ||||||
| pub struct RecvStream { |  | ||||||
|     stream: Option<quinn::RecvStream>, |  | ||||||
|     read_chunk_fut: ReadChunkFuture, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type ReadChunkFuture = ReusableBoxFuture< |  | ||||||
|     'static, |  | ||||||
|     ( |  | ||||||
|         quinn::RecvStream, |  | ||||||
|         Result<Option<quinn::Chunk>, quinn::ReadError>, |  | ||||||
|     ), |  | ||||||
| >; |  | ||||||
| 
 |  | ||||||
| impl RecvStream { |  | ||||||
|     fn new(stream: quinn::RecvStream) -> Self { |  | ||||||
|         Self { |  | ||||||
|             stream: Some(stream), |  | ||||||
|             // Should only allocate once the first time it's used
 |  | ||||||
|             read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl quic::RecvStream for RecvStream { |  | ||||||
|     type Buf = Bytes; |  | ||||||
|     type Error = ReadError; |  | ||||||
| 
 |  | ||||||
|     fn poll_data( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|     ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { |  | ||||||
|         if let Some(mut stream) = self.stream.take() { |  | ||||||
|             self.read_chunk_fut.set(async move { |  | ||||||
|                 let chunk = stream.read_chunk(usize::MAX, true).await; |  | ||||||
|                 (stream, chunk) |  | ||||||
|             }) |  | ||||||
|         }; |  | ||||||
| 
 |  | ||||||
|         let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx)); |  | ||||||
|         self.stream = Some(stream); |  | ||||||
|         Poll::Ready(Ok(chunk?.map(|c| c.bytes))) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn stop_sending(&mut self, error_code: u64) { |  | ||||||
|         self.stream |  | ||||||
|             .as_mut() |  | ||||||
|             .unwrap() |  | ||||||
|             .stop(VarInt::from_u64(error_code).expect("invalid error_code")) |  | ||||||
|             .ok(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn recv_id(&self) -> StreamId { |  | ||||||
|         self.stream |  | ||||||
|             .as_ref() |  | ||||||
|             .unwrap() |  | ||||||
|             .id() |  | ||||||
|             .0 |  | ||||||
|             .try_into() |  | ||||||
|             .expect("invalid stream id") |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// The error type for [`RecvStream`]
 |  | ||||||
| ///
 |  | ||||||
| /// Wraps errors that occur when reading from a receive stream.
 |  | ||||||
| #[derive(Debug)] |  | ||||||
| pub struct ReadError(quinn::ReadError); |  | ||||||
| 
 |  | ||||||
| impl From<ReadError> for std::io::Error { |  | ||||||
|     fn from(value: ReadError) -> Self { |  | ||||||
|         value.0.into() |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl std::error::Error for ReadError { |  | ||||||
|     fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { |  | ||||||
|         self.0.source() |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl fmt::Display for ReadError { |  | ||||||
|     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { |  | ||||||
|         self.0.fmt(f) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl From<ReadError> for Arc<dyn Error> { |  | ||||||
|     fn from(e: ReadError) -> Self { |  | ||||||
|         Arc::new(e) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl From<quinn::ReadError> for ReadError { |  | ||||||
|     fn from(e: quinn::ReadError) -> Self { |  | ||||||
|         Self(e) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Error for ReadError { |  | ||||||
|     fn is_timeout(&self) -> bool { |  | ||||||
|         matches!( |  | ||||||
|             self.0, |  | ||||||
|             quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut) |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn err_code(&self) -> Option<u64> { |  | ||||||
|         match self.0 { |  | ||||||
|             quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed( |  | ||||||
|                 quinn_proto::ApplicationClose { error_code, .. }, |  | ||||||
|             )) => Some(error_code.into_inner()), |  | ||||||
|             quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()), |  | ||||||
|             _ => None, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// Quinn-backed send stream
 |  | ||||||
| ///
 |  | ||||||
| /// Implements a [`quic::SendStream`] backed by a [`quinn::SendStream`].
 |  | ||||||
| pub struct SendStream<B: Buf> { |  | ||||||
|     stream: Option<quinn::SendStream>, |  | ||||||
|     writing: Option<WriteBuf<B>>, |  | ||||||
|     write_fut: WriteFuture, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| type WriteFuture = |  | ||||||
|     ReusableBoxFuture<'static, (quinn::SendStream, Result<usize, quinn::WriteError>)>; |  | ||||||
| 
 |  | ||||||
| impl<B> SendStream<B> |  | ||||||
| where |  | ||||||
|     B: Buf, |  | ||||||
| { |  | ||||||
|     fn new(stream: quinn::SendStream) -> SendStream<B> { |  | ||||||
|         Self { |  | ||||||
|             stream: Some(stream), |  | ||||||
|             writing: None, |  | ||||||
|             write_fut: ReusableBoxFuture::new(async { unreachable!() }), |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<B> quic::SendStream<B> for SendStream<B> |  | ||||||
| where |  | ||||||
|     B: Buf, |  | ||||||
| { |  | ||||||
|     type Error = SendStreamError; |  | ||||||
| 
 |  | ||||||
|     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { |  | ||||||
|         if let Some(ref mut data) = self.writing { |  | ||||||
|             while data.has_remaining() { |  | ||||||
|                 if let Some(mut stream) = self.stream.take() { |  | ||||||
|                     let chunk = data.chunk().to_owned(); // FIXME - avoid copy
 |  | ||||||
|                     self.write_fut.set(async move { |  | ||||||
|                         let ret = stream.write(&chunk).await; |  | ||||||
|                         (stream, ret) |  | ||||||
|                     }); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 let (stream, res) = ready!(self.write_fut.poll(cx)); |  | ||||||
|                 self.stream = Some(stream); |  | ||||||
|                 match res { |  | ||||||
|                     Ok(cnt) => data.advance(cnt), |  | ||||||
|                     Err(err) => { |  | ||||||
|                         return Poll::Ready(Err(SendStreamError::Write(err))); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         self.writing = None; |  | ||||||
|         Poll::Ready(Ok(())) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { |  | ||||||
|         self.stream |  | ||||||
|             .as_mut() |  | ||||||
|             .unwrap() |  | ||||||
|             .poll_finish(cx) |  | ||||||
|             .map_err(Into::into) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn reset(&mut self, reset_code: u64) { |  | ||||||
|         let _ = self |  | ||||||
|             .stream |  | ||||||
|             .as_mut() |  | ||||||
|             .unwrap() |  | ||||||
|             .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> { |  | ||||||
|         if self.writing.is_some() { |  | ||||||
|             return Err(Self::Error::NotReady); |  | ||||||
|         } |  | ||||||
|         self.writing = Some(data.into()); |  | ||||||
|         Ok(()) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn send_id(&self) -> StreamId { |  | ||||||
|         self.stream |  | ||||||
|             .as_ref() |  | ||||||
|             .unwrap() |  | ||||||
|             .id() |  | ||||||
|             .0 |  | ||||||
|             .try_into() |  | ||||||
|             .expect("invalid stream id") |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl<B> quic::SendStreamUnframed<B> for SendStream<B> |  | ||||||
| where |  | ||||||
|     B: Buf, |  | ||||||
| { |  | ||||||
|     fn poll_send<D: Buf>( |  | ||||||
|         &mut self, |  | ||||||
|         cx: &mut task::Context<'_>, |  | ||||||
|         buf: &mut D, |  | ||||||
|     ) -> Poll<Result<usize, Self::Error>> { |  | ||||||
|         if self.writing.is_some() { |  | ||||||
|             // This signifies a bug in implementation
 |  | ||||||
|             panic!("poll_send called while send stream is not ready") |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         let s = Pin::new(self.stream.as_mut().unwrap()); |  | ||||||
| 
 |  | ||||||
|         let res = ready!(futures::io::AsyncWrite::poll_write(s, cx, buf.chunk())); |  | ||||||
|         match res { |  | ||||||
|             Ok(written) => { |  | ||||||
|                 buf.advance(written); |  | ||||||
|                 Poll::Ready(Ok(written)) |  | ||||||
|             } |  | ||||||
|             Err(err) => { |  | ||||||
|                 // We are forced to use AsyncWrite for now because we cannot store
 |  | ||||||
|                 // the result of a call to:
 |  | ||||||
|                 // quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Result<usize, WriteError>.
 |  | ||||||
|                 //
 |  | ||||||
|                 // This is why we have to unpack the error from io::Error instead of having it
 |  | ||||||
|                 // returned directly. This should not panic as long as quinn's AsyncWrite impl
 |  | ||||||
|                 // doesn't change.
 |  | ||||||
|                 let err = err |  | ||||||
|                     .into_inner() |  | ||||||
|                     .expect("write stream returned an empty error") |  | ||||||
|                     .downcast::<WriteError>() |  | ||||||
|                     .expect("write stream returned an error which type is not WriteError"); |  | ||||||
| 
 |  | ||||||
|                 Poll::Ready(Err(SendStreamError::Write(*err))) |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| /// The error type for [`SendStream`]
 |  | ||||||
| ///
 |  | ||||||
| /// Wraps errors that can happen writing to or polling a send stream.
 |  | ||||||
| #[derive(Debug)] |  | ||||||
| pub enum SendStreamError { |  | ||||||
|     /// Errors when writing, wrapping a [`quinn::WriteError`]
 |  | ||||||
|     Write(WriteError), |  | ||||||
|     /// Error when the stream is not ready, because it is still sending
 |  | ||||||
|     /// data from a previous call
 |  | ||||||
|     NotReady, |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl From<SendStreamError> for std::io::Error { |  | ||||||
|     fn from(value: SendStreamError) -> Self { |  | ||||||
|         match value { |  | ||||||
|             SendStreamError::Write(err) => err.into(), |  | ||||||
|             SendStreamError::NotReady => { |  | ||||||
|                 std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready") |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl std::error::Error for SendStreamError {} |  | ||||||
| 
 |  | ||||||
| impl Display for SendStreamError { |  | ||||||
|     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { |  | ||||||
|         write!(f, "{:?}", self) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl From<WriteError> for SendStreamError { |  | ||||||
|     fn from(e: WriteError) -> Self { |  | ||||||
|         Self::Write(e) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl Error for SendStreamError { |  | ||||||
|     fn is_timeout(&self) -> bool { |  | ||||||
|         matches!( |  | ||||||
|             self, |  | ||||||
|             Self::Write(quinn::WriteError::ConnectionLost( |  | ||||||
|                 quinn::ConnectionError::TimedOut |  | ||||||
|             )) |  | ||||||
|         ) |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     fn err_code(&self) -> Option<u64> { |  | ||||||
|         match self { |  | ||||||
|             Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()), |  | ||||||
|             Self::Write(quinn::WriteError::ConnectionLost( |  | ||||||
|                 quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { |  | ||||||
|                     error_code, |  | ||||||
|                     .. |  | ||||||
|                 }), |  | ||||||
|             )) => Some(error_code.into_inner()), |  | ||||||
|             _ => None, |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| impl From<SendStreamError> for Arc<dyn Error> { |  | ||||||
|     fn from(e: SendStreamError) -> Self { |  | ||||||
|         Arc::new(e) |  | ||||||
|     } |  | ||||||
| } |  | ||||||
|  | @ -1 +0,0 @@ | ||||||
| Subproject commit e1e1e6e392a382fbded42ca010505fecb8fe3655 |  | ||||||
|  | @ -1 +1 @@ | ||||||
| Subproject commit 3cd09170305753309d86e88b9427827cca0de0dd | Subproject commit 88d23c2f5a3ac36295dff4a804968c43932ba46b | ||||||
|  | @ -1 +0,0 @@ | ||||||
| Subproject commit c88e64b6c58891651954834207d974de80e9bba8 |  | ||||||
							
								
								
									
										17
									
								
								submodules/s2n-quic-h3/Cargo.toml
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								submodules/s2n-quic-h3/Cargo.toml
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,17 @@ | ||||||
|  | [package] | ||||||
|  | name = "s2n-quic-h3" | ||||||
|  | # this in an unpublished internal crate so the version should not be changed | ||||||
|  | version = "0.1.0" | ||||||
|  | authors = ["AWS s2n"] | ||||||
|  | edition = "2021" | ||||||
|  | rust-version = "1.63" | ||||||
|  | license = "Apache-2.0" | ||||||
|  | # this contains an http3 implementation for testing purposes and should not be published | ||||||
|  | publish = false | ||||||
|  | 
 | ||||||
|  | [dependencies] | ||||||
|  | bytes = { version = "1", default-features = false } | ||||||
|  | futures = { version = "0.3", default-features = false } | ||||||
|  | h3 = { path = "../h3/h3/" } | ||||||
|  | s2n-quic = "1.33.0" | ||||||
|  | s2n-quic-core = "0.33.0" | ||||||
							
								
								
									
										10
									
								
								submodules/s2n-quic-h3/README.md
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								submodules/s2n-quic-h3/README.md
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,10 @@ | ||||||
|  | # s2n-quic-h3 | ||||||
|  | 
 | ||||||
|  | This is an internal crate used by [s2n-quic](https://github.com/aws/s2n-quic) written as a proof of concept for implementing HTTP3 on top of s2n-quic. The API is not currently stable and should not be used directly. | ||||||
|  | 
 | ||||||
|  | ## License | ||||||
|  | 
 | ||||||
|  | This project is licensed under the [Apache-2.0 License][license-url]. | ||||||
|  | 
 | ||||||
|  | [license-badge]: https://img.shields.io/badge/license-apache-blue.svg | ||||||
|  | [license-url]: https://aws.amazon.com/apache-2-0/ | ||||||
							
								
								
									
										7
									
								
								submodules/s2n-quic-h3/src/lib.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								submodules/s2n-quic-h3/src/lib.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,7 @@ | ||||||
|  | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 | ||||||
|  | // SPDX-License-Identifier: Apache-2.0
 | ||||||
|  | 
 | ||||||
|  | mod s2n_quic; | ||||||
|  | 
 | ||||||
|  | pub use self::s2n_quic::*; | ||||||
|  | pub use h3; | ||||||
							
								
								
									
										506
									
								
								submodules/s2n-quic-h3/src/s2n_quic.rs
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										506
									
								
								submodules/s2n-quic-h3/src/s2n_quic.rs
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,506 @@ | ||||||
|  | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 | ||||||
|  | // SPDX-License-Identifier: Apache-2.0
 | ||||||
|  | 
 | ||||||
|  | use bytes::{Buf, Bytes}; | ||||||
|  | use futures::ready; | ||||||
|  | use h3::quic::{self, Error, StreamId, WriteBuf}; | ||||||
|  | use s2n_quic::stream::{BidirectionalStream, ReceiveStream}; | ||||||
|  | use s2n_quic_core::varint::VarInt; | ||||||
|  | use std::{ | ||||||
|  |     convert::TryInto, | ||||||
|  |     fmt::{self, Display}, | ||||||
|  |     sync::Arc, | ||||||
|  |     task::{self, Poll}, | ||||||
|  | }; | ||||||
|  | 
 | ||||||
|  | pub struct Connection { | ||||||
|  |     conn: s2n_quic::connection::Handle, | ||||||
|  |     bidi_acceptor: s2n_quic::connection::BidirectionalStreamAcceptor, | ||||||
|  |     recv_acceptor: s2n_quic::connection::ReceiveStreamAcceptor, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Connection { | ||||||
|  |     pub fn new(new_conn: s2n_quic::Connection) -> Self { | ||||||
|  |         let (handle, acceptor) = new_conn.split(); | ||||||
|  |         let (bidi, recv) = acceptor.split(); | ||||||
|  | 
 | ||||||
|  |         Self { | ||||||
|  |             conn: handle, | ||||||
|  |             bidi_acceptor: bidi, | ||||||
|  |             recv_acceptor: recv, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug)] | ||||||
|  | pub struct ConnectionError(s2n_quic::connection::Error); | ||||||
|  | 
 | ||||||
|  | impl std::error::Error for ConnectionError {} | ||||||
|  | 
 | ||||||
|  | impl fmt::Display for ConnectionError { | ||||||
|  |     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||||
|  |         self.0.fmt(f) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Error for ConnectionError { | ||||||
|  |     fn is_timeout(&self) -> bool { | ||||||
|  |         matches!(self.0, s2n_quic::connection::Error::IdleTimerExpired { .. }) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn err_code(&self) -> Option<u64> { | ||||||
|  |         match self.0 { | ||||||
|  |             s2n_quic::connection::Error::Application { error, .. } => Some(error.into()), | ||||||
|  |             _ => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl From<s2n_quic::connection::Error> for ConnectionError { | ||||||
|  |     fn from(e: s2n_quic::connection::Error) -> Self { | ||||||
|  |         Self(e) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<B> quic::Connection<B> for Connection | ||||||
|  | where | ||||||
|  |     B: Buf, | ||||||
|  | { | ||||||
|  |     type BidiStream = BidiStream<B>; | ||||||
|  |     type SendStream = SendStream<B>; | ||||||
|  |     type RecvStream = RecvStream; | ||||||
|  |     type OpenStreams = OpenStreams; | ||||||
|  |     type Error = ConnectionError; | ||||||
|  | 
 | ||||||
|  |     fn poll_accept_recv( | ||||||
|  |         &mut self, | ||||||
|  |         cx: &mut task::Context<'_>, | ||||||
|  |     ) -> Poll<Result<Option<Self::RecvStream>, Self::Error>> { | ||||||
|  |         let recv = match ready!(self.recv_acceptor.poll_accept_receive_stream(cx))? { | ||||||
|  |             Some(x) => x, | ||||||
|  |             None => return Poll::Ready(Ok(None)), | ||||||
|  |         }; | ||||||
|  |         Poll::Ready(Ok(Some(Self::RecvStream::new(recv)))) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn poll_accept_bidi( | ||||||
|  |         &mut self, | ||||||
|  |         cx: &mut task::Context<'_>, | ||||||
|  |     ) -> Poll<Result<Option<Self::BidiStream>, Self::Error>> { | ||||||
|  |         let (recv, send) = match ready!(self.bidi_acceptor.poll_accept_bidirectional_stream(cx))? { | ||||||
|  |             Some(x) => x.split(), | ||||||
|  |             None => return Poll::Ready(Ok(None)), | ||||||
|  |         }; | ||||||
|  |         Poll::Ready(Ok(Some(Self::BidiStream { | ||||||
|  |             send: Self::SendStream::new(send), | ||||||
|  |             recv: Self::RecvStream::new(recv), | ||||||
|  |         }))) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn poll_open_bidi( | ||||||
|  |         &mut self, | ||||||
|  |         cx: &mut task::Context<'_>, | ||||||
|  |     ) -> Poll<Result<Self::BidiStream, Self::Error>> { | ||||||
|  |         let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?; | ||||||
|  |         Ok(stream.into()).into() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn poll_open_send( | ||||||
|  |         &mut self, | ||||||
|  |         cx: &mut task::Context<'_>, | ||||||
|  |     ) -> Poll<Result<Self::SendStream, Self::Error>> { | ||||||
|  |         let stream = ready!(self.conn.poll_open_send_stream(cx))?; | ||||||
|  |         Ok(stream.into()).into() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn opener(&self) -> Self::OpenStreams { | ||||||
|  |         OpenStreams { | ||||||
|  |             conn: self.conn.clone(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { | ||||||
|  |         self.conn.close( | ||||||
|  |             code.value() | ||||||
|  |                 .try_into() | ||||||
|  |                 .expect("s2n-quic supports error codes up to 2^62-1"), | ||||||
|  |         ); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct OpenStreams { | ||||||
|  |     conn: s2n_quic::connection::Handle, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<B> quic::OpenStreams<B> for OpenStreams | ||||||
|  | where | ||||||
|  |     B: Buf, | ||||||
|  | { | ||||||
|  |     type BidiStream = BidiStream<B>; | ||||||
|  |     type SendStream = SendStream<B>; | ||||||
|  |     type RecvStream = RecvStream; | ||||||
|  |     type Error = ConnectionError; | ||||||
|  | 
 | ||||||
|  |     fn poll_open_bidi( | ||||||
|  |         &mut self, | ||||||
|  |         cx: &mut task::Context<'_>, | ||||||
|  |     ) -> Poll<Result<Self::BidiStream, Self::Error>> { | ||||||
|  |         let stream = ready!(self.conn.poll_open_bidirectional_stream(cx))?; | ||||||
|  |         Ok(stream.into()).into() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn poll_open_send( | ||||||
|  |         &mut self, | ||||||
|  |         cx: &mut task::Context<'_>, | ||||||
|  |     ) -> Poll<Result<Self::SendStream, Self::Error>> { | ||||||
|  |         let stream = ready!(self.conn.poll_open_send_stream(cx))?; | ||||||
|  |         Ok(stream.into()).into() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn close(&mut self, code: h3::error::Code, _reason: &[u8]) { | ||||||
|  |         self.conn.close( | ||||||
|  |             code.value() | ||||||
|  |                 .try_into() | ||||||
|  |                 .unwrap_or_else(|_| VarInt::MAX.into()), | ||||||
|  |         ); | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Clone for OpenStreams { | ||||||
|  |     fn clone(&self) -> Self { | ||||||
|  |         Self { | ||||||
|  |             conn: self.conn.clone(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct BidiStream<B> | ||||||
|  | where | ||||||
|  |     B: Buf, | ||||||
|  | { | ||||||
|  |     send: SendStream<B>, | ||||||
|  |     recv: RecvStream, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<B> quic::BidiStream<B> for BidiStream<B> | ||||||
|  | where | ||||||
|  |     B: Buf, | ||||||
|  | { | ||||||
|  |     type SendStream = SendStream<B>; | ||||||
|  |     type RecvStream = RecvStream; | ||||||
|  | 
 | ||||||
|  |     fn split(self) -> (Self::SendStream, Self::RecvStream) { | ||||||
|  |         (self.send, self.recv) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<B> quic::RecvStream for BidiStream<B> | ||||||
|  | where | ||||||
|  |     B: Buf, | ||||||
|  | { | ||||||
|  |     type Buf = Bytes; | ||||||
|  |     type Error = ReadError; | ||||||
|  | 
 | ||||||
|  |     fn poll_data( | ||||||
|  |         &mut self, | ||||||
|  |         cx: &mut task::Context<'_>, | ||||||
|  |     ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { | ||||||
|  |         self.recv.poll_data(cx) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn stop_sending(&mut self, error_code: u64) { | ||||||
|  |         self.recv.stop_sending(error_code) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn recv_id(&self) -> StreamId { | ||||||
|  |         self.recv.stream.id().try_into().expect("invalid stream id") | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<B> quic::SendStream<B> for BidiStream<B> | ||||||
|  | where | ||||||
|  |     B: Buf, | ||||||
|  | { | ||||||
|  |     type Error = SendStreamError; | ||||||
|  | 
 | ||||||
|  |     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||||
|  |         self.send.poll_ready(cx) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||||
|  |         self.send.poll_finish(cx) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn reset(&mut self, reset_code: u64) { | ||||||
|  |         self.send.reset(reset_code) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> { | ||||||
|  |         self.send.send_data(data) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn send_id(&self) -> StreamId { | ||||||
|  |         self.send.stream.id().try_into().expect("invalid stream id") | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<B> From<BidirectionalStream> for BidiStream<B> | ||||||
|  | where | ||||||
|  |     B: Buf, | ||||||
|  | { | ||||||
|  |     fn from(bidi: BidirectionalStream) -> Self { | ||||||
|  |         let (recv, send) = bidi.split(); | ||||||
|  |         BidiStream { | ||||||
|  |             send: send.into(), | ||||||
|  |             recv: recv.into(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct RecvStream { | ||||||
|  |     stream: s2n_quic::stream::ReceiveStream, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl RecvStream { | ||||||
|  |     fn new(stream: s2n_quic::stream::ReceiveStream) -> Self { | ||||||
|  |         Self { stream } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl quic::RecvStream for RecvStream { | ||||||
|  |     type Buf = Bytes; | ||||||
|  |     type Error = ReadError; | ||||||
|  | 
 | ||||||
|  |     fn poll_data( | ||||||
|  |         &mut self, | ||||||
|  |         cx: &mut task::Context<'_>, | ||||||
|  |     ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { | ||||||
|  |         let buf = ready!(self.stream.poll_receive(cx))?; | ||||||
|  |         Ok(buf).into() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn stop_sending(&mut self, error_code: u64) { | ||||||
|  |         let _ = self.stream.stop_sending( | ||||||
|  |             s2n_quic::application::Error::new(error_code) | ||||||
|  |                 .expect("s2n-quic supports error codes up to 2^62-1"), | ||||||
|  |         ); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn recv_id(&self) -> StreamId { | ||||||
|  |         self.stream.id().try_into().expect("invalid stream id") | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl From<ReceiveStream> for RecvStream { | ||||||
|  |     fn from(recv: ReceiveStream) -> Self { | ||||||
|  |         RecvStream::new(recv) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug)] | ||||||
|  | pub struct ReadError(s2n_quic::stream::Error); | ||||||
|  | 
 | ||||||
|  | impl std::error::Error for ReadError {} | ||||||
|  | 
 | ||||||
|  | impl fmt::Display for ReadError { | ||||||
|  |     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { | ||||||
|  |         self.0.fmt(f) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl From<ReadError> for Arc<dyn Error> { | ||||||
|  |     fn from(e: ReadError) -> Self { | ||||||
|  |         Arc::new(e) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl From<s2n_quic::stream::Error> for ReadError { | ||||||
|  |     fn from(e: s2n_quic::stream::Error) -> Self { | ||||||
|  |         Self(e) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Error for ReadError { | ||||||
|  |     fn is_timeout(&self) -> bool { | ||||||
|  |         matches!( | ||||||
|  |             self.0, | ||||||
|  |             s2n_quic::stream::Error::ConnectionError { | ||||||
|  |                 error: s2n_quic::connection::Error::IdleTimerExpired { .. }, | ||||||
|  |                 .. | ||||||
|  |             } | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn err_code(&self) -> Option<u64> { | ||||||
|  |         match self.0 { | ||||||
|  |             s2n_quic::stream::Error::ConnectionError { | ||||||
|  |                 error: s2n_quic::connection::Error::Application { error, .. }, | ||||||
|  |                 .. | ||||||
|  |             } => Some(error.into()), | ||||||
|  |             s2n_quic::stream::Error::StreamReset { error, .. } => Some(error.into()), | ||||||
|  |             _ => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | pub struct SendStream<B: Buf> { | ||||||
|  |     stream: s2n_quic::stream::SendStream, | ||||||
|  |     chunk: Option<Bytes>, | ||||||
|  |     buf: Option<WriteBuf<B>>, // TODO: Replace with buf: PhantomData<B>
 | ||||||
|  |                               //       after https://github.com/hyperium/h3/issues/78 is resolved
 | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<B> SendStream<B> | ||||||
|  | where | ||||||
|  |     B: Buf, | ||||||
|  | { | ||||||
|  |     fn new(stream: s2n_quic::stream::SendStream) -> SendStream<B> { | ||||||
|  |         Self { | ||||||
|  |             stream, | ||||||
|  |             chunk: None, | ||||||
|  |             buf: Default::default(), | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<B> quic::SendStream<B> for SendStream<B> | ||||||
|  | where | ||||||
|  |     B: Buf, | ||||||
|  | { | ||||||
|  |     type Error = SendStreamError; | ||||||
|  | 
 | ||||||
|  |     fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||||
|  |         loop { | ||||||
|  |             // try to flush the current chunk if we have one
 | ||||||
|  |             if let Some(chunk) = self.chunk.as_mut() { | ||||||
|  |                 ready!(self.stream.poll_send(chunk, cx))?; | ||||||
|  | 
 | ||||||
|  |                 // s2n-quic will take the whole chunk on send, even if it exceeds the limits
 | ||||||
|  |                 debug_assert!(chunk.is_empty()); | ||||||
|  |                 self.chunk = None; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // try to take the next chunk from the WriteBuf
 | ||||||
|  |             if let Some(ref mut data) = self.buf { | ||||||
|  |                 let len = data.chunk().len(); | ||||||
|  | 
 | ||||||
|  |                 // if the write buf is empty, then clear it and break
 | ||||||
|  |                 if len == 0 { | ||||||
|  |                     self.buf = None; | ||||||
|  |                     break; | ||||||
|  |                 } | ||||||
|  | 
 | ||||||
|  |                 // copy the first chunk from WriteBuf and prepare it to flush
 | ||||||
|  |                 let chunk = data.copy_to_bytes(len); | ||||||
|  |                 self.chunk = Some(chunk); | ||||||
|  | 
 | ||||||
|  |                 // loop back around to flush the chunk
 | ||||||
|  |                 continue; | ||||||
|  |             } | ||||||
|  | 
 | ||||||
|  |             // if we didn't have either a chunk or WriteBuf, then we're ready
 | ||||||
|  |             break; | ||||||
|  |         } | ||||||
|  | 
 | ||||||
|  |         Poll::Ready(Ok(())) | ||||||
|  | 
 | ||||||
|  |         // TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved
 | ||||||
|  |         // self.available_bytes = ready!(self.stream.poll_send_ready(cx))?;
 | ||||||
|  |         // Poll::Ready(Ok(()))
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> { | ||||||
|  |         if self.buf.is_some() { | ||||||
|  |             return Err(Self::Error::NotReady); | ||||||
|  |         } | ||||||
|  |         self.buf = Some(data.into()); | ||||||
|  |         Ok(()) | ||||||
|  | 
 | ||||||
|  |         // TODO: Replace with following after https://github.com/hyperium/h3/issues/78 is resolved
 | ||||||
|  |         // let mut data = data.into();
 | ||||||
|  |         // while self.available_bytes > 0 && data.has_remaining() {
 | ||||||
|  |         //     let len = data.chunk().len();
 | ||||||
|  |         //     let chunk = data.copy_to_bytes(len);
 | ||||||
|  |         //     self.stream.send_data(chunk)?;
 | ||||||
|  |         //     self.available_bytes = self.available_bytes.saturating_sub(len);
 | ||||||
|  |         // }
 | ||||||
|  |         // Ok(())
 | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | ||||||
|  |         // ensure all chunks are flushed to the QUIC stream before finishing
 | ||||||
|  |         ready!(self.poll_ready(cx))?; | ||||||
|  |         self.stream.finish()?; | ||||||
|  |         Ok(()).into() | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn reset(&mut self, reset_code: u64) { | ||||||
|  |         let _ = self | ||||||
|  |             .stream | ||||||
|  |             .reset(reset_code.try_into().unwrap_or_else(|_| VarInt::MAX.into())); | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn send_id(&self) -> StreamId { | ||||||
|  |         self.stream.id().try_into().expect("invalid stream id") | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl<B> From<s2n_quic::stream::SendStream> for SendStream<B> | ||||||
|  | where | ||||||
|  |     B: Buf, | ||||||
|  | { | ||||||
|  |     fn from(send: s2n_quic::stream::SendStream) -> Self { | ||||||
|  |         SendStream::new(send) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | #[derive(Debug)] | ||||||
|  | pub enum SendStreamError { | ||||||
|  |     Write(s2n_quic::stream::Error), | ||||||
|  |     NotReady, | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl std::error::Error for SendStreamError {} | ||||||
|  | 
 | ||||||
|  | impl Display for SendStreamError { | ||||||
|  |     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { | ||||||
|  |         write!(f, "{self:?}") | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl From<s2n_quic::stream::Error> for SendStreamError { | ||||||
|  |     fn from(e: s2n_quic::stream::Error) -> Self { | ||||||
|  |         Self::Write(e) | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl Error for SendStreamError { | ||||||
|  |     fn is_timeout(&self) -> bool { | ||||||
|  |         matches!( | ||||||
|  |             self, | ||||||
|  |             Self::Write(s2n_quic::stream::Error::ConnectionError { | ||||||
|  |                 error: s2n_quic::connection::Error::IdleTimerExpired { .. }, | ||||||
|  |                 .. | ||||||
|  |             }) | ||||||
|  |         ) | ||||||
|  |     } | ||||||
|  | 
 | ||||||
|  |     fn err_code(&self) -> Option<u64> { | ||||||
|  |         match self { | ||||||
|  |             Self::Write(s2n_quic::stream::Error::StreamReset { error, .. }) => { | ||||||
|  |                 Some((*error).into()) | ||||||
|  |             } | ||||||
|  |             Self::Write(s2n_quic::stream::Error::ConnectionError { | ||||||
|  |                 error: s2n_quic::connection::Error::Application { error, .. }, | ||||||
|  |                 .. | ||||||
|  |             }) => Some((*error).into()), | ||||||
|  |             _ => None, | ||||||
|  |         } | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | impl From<SendStreamError> for Arc<dyn Error> { | ||||||
|  |     fn from(e: SendStreamError) -> Self { | ||||||
|  |         Arc::new(e) | ||||||
|  |     } | ||||||
|  | } | ||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue
	
	 Jun Kurihara
				Jun Kurihara