diff --git a/.github/workflows/docker_build_push.yml b/.github/workflows/docker_build_push.yml index 2e5b38e..1dfd260 100644 --- a/.github/workflows/docker_build_push.yml +++ b/.github/workflows/docker_build_push.yml @@ -38,7 +38,7 @@ jobs: push: true tags: | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:latest - file: ./docker/amd64/Dockerfile + file: ./docker/Dockerfile.amd64 - name: Release build and push x86_64-slim if: ${{ env.BRANCH == 'main' }} @@ -48,7 +48,7 @@ jobs: push: true tags: | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:slim, ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:latest-slim - file: ./docker/amd64-slim/Dockerfile + file: ./docker/Dockerfile.amd64-slim - name: Nightly build and push x86_64 if: ${{ env.BRANCH == 'develop' }} @@ -58,4 +58,4 @@ jobs: push: true tags: | ${{ secrets.DOCKERHUB_USERNAME }}/${{ env.IMAGE_NAME }}:nightly - file: ./docker/amd64/Dockerfile + file: ./docker/Dockerfile.amd64 diff --git a/.gitignore b/.gitignore index 02474f4..6797716 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ .vscode .private +docker/log # Generated by Cargo diff --git a/.gitmodules b/.gitmodules index 0a7fc93..b9069a0 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "h3"] path = h3 url = git@github.com:junkurihara/h3.git +[submodule "quinn"] + path = quinn + url = git@github.com:junkurihara/quinn.git diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cff25b..8dea263 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,21 @@ # CHANGELOG -## 0.x.x (unreleased) +## 0.4.0 (unreleased) + + +## 0.3.0 ### Improvement +- HTTP/3 Deps: Update `h3` to work with `quinn-0.10` or higher. These crates were changed from `crates.io` to `git submodule`s as part of a workaround; this will likely move back to `crates.io` in a near-future update. +- Load Balancing: Implement a session persistence function for load balancing using a sticky cookie (initial implementation). Enabled in `default-features`. +- Docker UID:GID: Update `Dockerfile`s to allow arbitrary UID and GID (non-root users) for rpxy. They can now be set via environment variables. +- Refactor: Various minor improvements + +## 0.2.0 + +### Improvement + +- Publish a `nightly` Docker image built from the `develop` branch, along with `amd64-slim` and `amd64` images tagged `latest-slim` and `latest` built from the `main` branch. The `nightly` image is based on `amd64`. +- Update `h3` with `quinn-0.10` or higher. - Implement path replacing option for each reverse proxy backend group.
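The sticky-cookie load balancing noted above follows one rule: if a request carries a sticky cookie naming a known upstream, reuse that upstream; otherwise fall back to round robin and issue a new cookie. Below is a minimal, standard-library-only sketch of that selection logic; `StickyLb` and its fields are hypothetical names for illustration, not rpxy's actual types (those are added in `src/backend/` later in this diff).

```rust
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Hypothetical sticky round-robin selector: a request carrying a known
/// upstream id in its sticky cookie is pinned to that upstream; anything
/// else advances a shared round-robin counter.
struct StickyLb {
    counter: AtomicUsize,
    id_to_index: HashMap<String, usize>, // cookie value -> upstream index
    num_upstreams: usize,
}

impl StickyLb {
    fn select(&self, cookie_value: Option<&str>) -> usize {
        if let Some(idx) = cookie_value.and_then(|v| self.id_to_index.get(v)) {
            return *idx; // session persistence: stick to the previous upstream
        }
        // New or unrecognized session: plain round robin.
        self.counter.fetch_add(1, Ordering::Relaxed) % self.num_upstreams
    }
}

fn main() {
    let lb = StickyLb {
        counter: AtomicUsize::new(0),
        id_to_index: HashMap::from([("id-a".to_string(), 2)]),
        num_upstreams: 3,
    };
    assert_eq!(lb.select(Some("id-a")), 2); // sticky hit
    println!("fresh request -> upstream {}", lb.select(None));
}
```

rpxy's actual implementation additionally derives the cookie value by hashing the upstream URI and re-issues a Set-Cookie with an expiry on each response, as the code later in this diff shows.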
diff --git a/Cargo.toml b/Cargo.toml index 166b06c..8d955ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "rpxy" -version = "0.2.0" +version = "0.3.0" authors = ["Jun Kurihara"] homepage = "https://github.com/junkurihara/rust-rpxy" repository = "https://github.com/junkurihara/rust-rpxy" @@ -12,22 +12,23 @@ publish = false # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [features] -default = ["http3"] +default = ["http3", "sticky-cookie"] http3 = ["quinn", "h3", "h3-quinn"] +sticky-cookie = ["base64", "sha2", "chrono"] [dependencies] -anyhow = "1.0.70" -clap = { version = "4.2.1", features = ["std", "cargo", "wrap_help"] } +anyhow = "1.0.71" +clap = { version = "4.3.4", features = ["std", "cargo", "wrap_help"] } rand = "0.8.5" -toml = { version = "0.7.3", default-features = false, features = ["parse"] } +toml = { version = "0.7.4", default-features = false, features = ["parse"] } rustc-hash = "1.1.0" -serde = { version = "1.0.159", default-features = false, features = ["derive"] } +serde = { version = "1.0.164", default-features = false, features = ["derive"] } bytes = "1.4.0" thiserror = "1.0.40" x509-parser = "0.15.0" derive_builder = "0.12.0" futures = { version = "0.3.28", features = ["alloc", "async-await"] } -tokio = { version = "1.27.0", default-features = false, features = [ +tokio = { version = "1.28.2", default-features = false, features = [ "net", "rt-multi-thread", "parking_lot", @@ -37,30 +38,42 @@ tokio = { version = "1.27.0", default-features = false, features = [ ] } # http and tls -hyper = { version = "0.14.25", default-features = false, features = [ +hyper = { version = "0.14.26", default-features = false, features = [ "server", "http1", "http2", "stream", ] } -hyper-rustls = { version = "0.23.2", default-features = false, features = [ +hyper-rustls = { version = "0.24.0", default-features = false, features = [ "tokio-runtime", "webpki-tokio", "http1", "http2", ] } -tokio-rustls = { version = "0.23.4", features = ["early-data"] } +tokio-rustls = { version = "0.24.1", features = ["early-data"] } rustls-pemfile = "1.0.2" -rustls = { version = "0.20.8", default-features = false } +rustls = { version = "0.21.2", default-features = false } +webpki = "0.22.0" # logging tracing = { version = "0.1.37" } -tracing-subscriber = { version = "0.3.16", features = ["env-filter"] } +tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } # http/3 -quinn = { version = "0.9.3", optional = true } +# quinn = { version = "0.9.3", optional = true } +quinn = { path = "./quinn/quinn", optional = true } # Tentative to support rustls-0.21 h3 = { path = "./h3/h3/", optional = true } -h3-quinn = { path = "./h3/h3-quinn/", optional = true } +# h3-quinn = { path = "./h3/h3-quinn/", optional = true } +h3-quinn = { path = "./h3-quinn/", optional = true } # Tentative to support rustls-0.21 + +# cookie handling for sticky cookie +chrono = { version = "0.4.26", default-features = false, features = [ + "unstable-locales", + "alloc", + "clock", +], optional = true } +base64 = { version = "0.21.2", optional = true } +sha2 = { version = "0.10.7", default-features = false, optional = true } [target.'cfg(not(target_env = "msvc"))'.dependencies] diff --git a/README.md b/README.md index 3f255df..a6e7b7a 100644 --- a/README.md +++ b/README.md @@ -10,11 +10,11 @@ ## Introduction -`rpxy` [ahr-pik-see] is an implementation of simple and lightweight reverse-proxy with some additional features. 
The implementation is based on [`hyper`](https://github.com/hyperium/hyper), [`rustls`](https://github.com/rustls/rustls) and [`tokio`](https://github.com/tokio-rs/tokio), i.e., written in pure Rust. Our `rpxy` allows to route multiple host names to appropriate backend application servers while serving TLS connections. +`rpxy` [ahr-pik-see] is an implementation of a simple and lightweight reverse-proxy with some additional features. The implementation is based on [`hyper`](https://github.com/hyperium/hyper), [`rustls`](https://github.com/rustls/rustls) and [`tokio`](https://github.com/tokio-rs/tokio), i.e., written in pure Rust. Our `rpxy` routes multiple host names to appropriate backend application servers while serving TLS connections. - As default, `rpxy` provides the *TLS connection sanitization* by correctly binding a certificate used to establish secure channel with backend application. Specifically, it always keeps the consistency between the given SNI (server name indication) in `ClientHello` of the underlying TLS and the domain name given by the overlaid HTTP HOST header (or URL in Request line) [^1]. Additionally, as a somewhat unstable feature, our `rpxy` can handle the brand-new HTTP/3 connection thanks to [`quinn`](https://github.com/quinn-rs/quinn) and [`hyperium/h3`](https://github.com/hyperium/h3). + By default, `rpxy` provides *TLS connection sanitization* by correctly binding a certificate used to establish a secure channel with the backend application. Specifically, it always keeps the consistency between the given SNI (server name indication) in `ClientHello` of the underlying TLS and the domain name given by the overlaid HTTP HOST header (or URL in the Request line) [^1]. Additionally, as a somewhat unstable feature, our `rpxy` can handle the brand-new HTTP/3 connection thanks to [`quinn`](https://github.com/quinn-rs/quinn) and [`hyperium/h3`](https://github.com/hyperium/h3). - This project is still *work-in-progress*. But it is already working in some production environments and serves numbers of domain names. Furthermore it *significantly outperforms* NGINX and Caddy, e.g., *1.5x faster than NGINX*, in the setting of very simple HTTP reverse-proxy scenario (See [`bench`](./bench/) directory). + This project is still *work-in-progress*, but it is already working in some production environments and serves a number of domain names. Furthermore, it *significantly outperforms* NGINX and Caddy, e.g., *1.5x faster than NGINX*, in the setting of a very simple HTTP reverse-proxy scenario (see the [`bench`](./bench/) directory). [^1]: We should note that NGINX doesn't guarantee such a consistency by default. To this end, you have to add `if` statement in the configuration file in NGINX. @@ -108,7 +108,7 @@ revese_proxy = [ #### Load Balancing -You can specify multiple backend locations in the `reverse_proxy` array for *load-balancing*. Currently it works in the manner of round-robin. +You can specify multiple backend locations in the `reverse_proxy` array for *load-balancing* with an appropriate `load_balance` option. Currently it works in a round-robin manner, in a random fashion, or as round-robin with *session-persistence* using a sticky cookie. If `load_balance` is not specified, the first backend location is always chosen.
```toml [apps."app_name"] server_name = 'app1.example.com' reverse_proxy = [ { location = 'app1.local:8080' }, { location = 'app2.local:8000' } ] +load_balance = 'round_robin' # or 'random' or 'sticky' ``` ### Second Step: Terminating TLS diff --git a/TODO.md b/TODO.md index 3c44f82..90fb79d 100644 --- a/TODO.md +++ b/TODO.md @@ -3,8 +3,14 @@ - Improvement of path matcher - More flexible option for rewriting path - Refactoring + + Split `backend` module into three parts + + - backend(s): struct containing info, defined for each served domain with multiple paths + - upstream/upstream group: information on targeted destinations for each set of (a domain + a path) + - load-balance: load balancing mod for a domain + path + - Unit tests -- Implementing load-balancing of backend apps (currently it doesn't consider to maintain session but simply rotate in a certain fashion) - Options to serve custom http_error page. - Prometheus metrics - Documentation @@ -13,4 +19,7 @@ - Currently, we took the following approach (caveats) - For Http2 and 1.1, prepare `rustls::ServerConfig` for each domain name and hence client CA cert is set for each one. - For Http3, use aggregated `rustls::ServerConfig` for multiple domain names except for ones requiring client-auth. So, if a domain name is set with client authentication, http3 doesn't work for the domain. +- Make the session-persistence option for load balancing more sophisticated. (mostly done in v0.3.0) + - add option for sticky cookie name + - add option for sticky cookie duration - etc. diff --git a/config-example.toml b/config-example.toml index 9b2b463..0382393 100644 --- a/config-example.toml +++ b/config-example.toml @@ -52,6 +52,7 @@ upstream = [ { location = 'www.yahoo.com', tls = true }, { location = 'www.yahoo.co.jp', tls = true }, ] +load_balance = "round_robin" # or "random" or "sticky" (sticky session) or "none" (fix to the first one, default) upstream_options = ["override_host", "convert_https_to_2"] # Non-default destination in "localhost" app, which is routed by "path" @@ -60,13 +61,14 @@ path = '/maps' # For request path starting with "/maps", # this configuration results that any path like "/maps/org/any.ext" is mapped to "/replacing/path1/org/any.ext" # by replacing "/maps" with "/replacing/path1" for routing to the locations given in upstream array -# Note that unless "path_replaced_with" is specified, the "path" is always preserved.
+# "replace_path" must be start from "/" (root path) replace_path = "/replacing/path1" upstream = [ { location = 'www.bing.com', tls = true }, { location = 'www.bing.co.jp', tls = true }, ] +load_balance = "random" # or "round_robin" or "sticky" (sticky session) or "none" (fix to the first one, default) upstream_options = [ "override_host", "upgrade_insecure_requests", diff --git a/docker/amd64/Dockerfile b/docker/Dockerfile.amd64 similarity index 59% rename from docker/amd64/Dockerfile rename to docker/Dockerfile.amd64 index 8f7ecf7..da27439 100644 --- a/docker/amd64/Dockerfile +++ b/docker/Dockerfile.amd64 @@ -30,26 +30,27 @@ RUN apt-get update && apt-get install -qy --no-install-recommends $BUILD_DEPS && FROM base AS runner ENV TAG_NAME=amd64 -ENV RUNTIME_DEPS logrotate ca-certificates +ENV RUNTIME_DEPS logrotate ca-certificates gosu RUN apt-get update && \ apt-get install -qy --no-install-recommends $RUNTIME_DEPS && \ apt-get -qy clean && \ - apt-get -qy autoremove &&\ - rm -fr /tmp/* /var/tmp/* /var/cache/apt/* /var/lib/apt/lists/* /var/log/apt/* /var/log/*.log &&\ - mkdir -p /opt/rpxy/sbin &&\ - mkdir -p /var/log/rpxy && \ - touch /var/log/rpxy/rpxy.log + apt-get -qy autoremove && \ + rm -fr /tmp/* /var/tmp/* /var/cache/apt/* /var/lib/apt/lists/* /var/log/apt/* /var/log/*.log && \ + find / -type d -path /proc -prune -o -type f -perm /u+s -ignore_readdir_race -exec chmod u-s {} \; && \ + find / -type d -path /proc -prune -o -type f -perm /g+s -ignore_readdir_race -exec chmod g-s {} \; && \ + mkdir -p /rpxy/bin &&\ + mkdir -p /rpxy/log -COPY --from=builder /tmp/target/release/rpxy /opt/rpxy/sbin/rpxy -COPY ./docker/${TAG_NAME}/run.sh / -COPY ./docker/entrypoint.sh / +COPY --from=builder /tmp/target/release/rpxy /rpxy/bin/rpxy +COPY ./docker/run.sh /rpxy +COPY ./docker/entrypoint.sh /rpxy -RUN chmod 755 /run.sh && \ - chmod 755 /entrypoint.sh +RUN chmod +x /rpxy/run.sh && \ + chmod +x /rpxy/entrypoint.sh EXPOSE 80 443 -CMD ["/entrypoint.sh"] +CMD ["/usr/bin/bash" "/rpxy/entrypoint.sh"] -ENTRYPOINT ["/entrypoint.sh"] +ENTRYPOINT ["/usr/bin/bash", "/rpxy/entrypoint.sh"] diff --git a/docker/amd64-slim/Dockerfile b/docker/Dockerfile.amd64-slim similarity index 57% rename from docker/amd64-slim/Dockerfile rename to docker/Dockerfile.amd64-slim index 9e5b9d4..fb0246e 100644 --- a/docker/amd64-slim/Dockerfile +++ b/docker/Dockerfile.amd64-slim @@ -20,26 +20,27 @@ LABEL maintainer="Jun Kurihara" ENV TAG_NAME=amd64-slim ENV TARGET_DIR=x86_64-unknown-linux-musl -ENV RUNTIME_DEPS logrotate ca-certificates +ENV RUNTIME_DEPS logrotate ca-certificates su-exec RUN apk add --no-cache ${RUNTIME_DEPS} && \ update-ca-certificates && \ - mkdir -p /opt/rpxy/sbin &&\ - mkdir -p /var/log/rpxy && \ - touch /var/log/rpxy/rpxy.log + find / -type d -path /proc -prune -o -type f -perm /u+s -exec chmod u-s {} \; && \ + find / -type d -path /proc -prune -o -type f -perm /g+s -exec chmod g-s {} \; && \ + mkdir -p /rpxy/bin &&\ + mkdir -p /rpxy/log -COPY --from=builder /tmp/target/${TARGET_DIR}/release/rpxy /opt/rpxy/sbin/rpxy -COPY ./docker/${TAG_NAME}/run.sh / -COPY ./docker/entrypoint.sh / +COPY --from=builder /tmp/target/${TARGET_DIR}/release/rpxy /rpxy/bin/rpxy +COPY ./docker/run.sh /rpxy +COPY ./docker/entrypoint.sh /rpxy -RUN chmod 755 /run.sh && \ - chmod 755 /entrypoint.sh +RUN chmod +x /rpxy/run.sh && \ + chmod +x /rpxy/entrypoint.sh ENV SSL_CERT_FILE=/etc/ssl/certs/ca-certificates.crt ENV SSL_CERT_DIR=/etc/ssl/certs EXPOSE 80 443 -CMD ["/entrypoint.sh"] +CMD ["/rpxy/entrypoint.sh"] -ENTRYPOINT 
["/entrypoint.sh"] +ENTRYPOINT ["/rpxy/entrypoint.sh"] diff --git a/docker/amd64-slim/run.sh b/docker/amd64-slim/run.sh deleted file mode 100644 index 1d99125..0000000 --- a/docker/amd64-slim/run.sh +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env sh - -LOG_FILE=/var/log/rpxy/rpxy.log -CONFIG_FILE=/etc/rpxy.toml -LOG_SIZE=10M -LOG_NUM=10 - -# logrotate -if [ $LOGROTATE_NUM ]; then - LOG_NUM=${LOGROTATE_NUM} -fi -if [ $LOGROTATE_SIZE ]; then - LOG_SIZE=${LOGROTATE_SIZE} -fi - -cat > /etc/logrotate.conf << EOF -# see "man logrotate" for details -# rotate log files weekly -weekly -# use the adm group by default, since this is the owning group -# of /var/log/syslog. -su root adm -# keep 4 weeks worth of backlogs -rotate 4 -# create new (empty) log files after rotating old ones -create -# use date as a suffix of the rotated file -#dateext -# uncomment this if you want your log files compressed -#compress -# packages drop log rotation information into this directory -include /etc/logrotate.d -# system-specific logs may be also be configured here. -EOF - -cat > /etc/logrotate.d/rpxy.conf << EOF -${LOG_FILE} { - dateext - daily - missingok - rotate ${LOG_NUM} - notifempty - compress - delaycompress - dateformat -%Y-%m-%d-%s - size ${LOG_SIZE} - copytruncate -} -EOF - -cp -f /etc/periodic/daily/logrotate /etc/periodic/15min -crond restart - -# debug level logging -if [ -z $LOG_LEVEL ]; then - LOG_LEVEL=info -fi -echo "rpxy: Logging with level ${LOG_LEVEL}" - -RUST_LOG=${LOG_LEVEL} /opt/rpxy/sbin/rpxy --config ${CONFIG_FILE} diff --git a/docker/amd64/run.sh b/docker/amd64/run.sh deleted file mode 100644 index bace2c9..0000000 --- a/docker/amd64/run.sh +++ /dev/null @@ -1,61 +0,0 @@ - -#!/usr/bin/env sh - -LOG_FILE=/var/log/rpxy/rpxy.log -CONFIG_FILE=/etc/rpxy.toml -LOG_SIZE=10M -LOG_NUM=10 - -# logrotate -if [ $LOGROTATE_NUM ]; then - LOG_NUM=${LOGROTATE_NUM} -fi -if [ $LOGROTATE_SIZE ]; then - LOG_SIZE=${LOGROTATE_SIZE} -fi - -cat > /etc/logrotate.conf << EOF -# see "man logrotate" for details -# rotate log files weekly -weekly -# use the adm group by default, since this is the owning group -# of /var/log/syslog. -su root adm -# keep 4 weeks worth of backlogs -rotate 4 -# create new (empty) log files after rotating old ones -create -# use date as a suffix of the rotated file -#dateext -# uncomment this if you want your log files compressed -#compress -# packages drop log rotation information into this directory -include /etc/logrotate.d -# system-specific logs may be also be configured here. 
-EOF - -cat > /etc/logrotate.d/rpxy << EOF -${LOG_FILE} { - dateext - daily - missingok - rotate ${LOG_NUM} - notifempty - compress - delaycompress - dateformat -%Y-%m-%d-%s - size ${LOG_SIZE} - copytruncate -} -EOF - -cp -p /etc/cron.daily/logrotate /etc/cron.hourly/ -service cron start - -# debug level logging -if [ -z $LOG_LEVEL ]; then - LOG_LEVEL=info -fi -echo "rpxy: Logging with level ${LOG_LEVEL}" - -RUST_LOG=${LOG_LEVEL} /opt/rpxy/sbin/rpxy --config ${CONFIG_FILE} diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 9a64db2..716d0de 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -3,19 +3,24 @@ services: rpxy-rp: image: jqtype/rpxy container_name: rpxy + init: true restart: unless-stopped ports: - 127.0.0.1:8080:8080 - 127.0.0.1:8443:8443 build: context: ../ - dockerfile: ./docker/amd64/Dockerfile + dockerfile: ./docker/Dockerfile.amd64 environment: - LOG_LEVEL=debug - - LOG_TO_FILE=false + - LOG_TO_FILE=true + - HOST_USER=jun + - HOST_UID=501 + - HOST_GID=501 tty: false privileged: true volumes: + - ./log:/rpxy/log - ../example-certs/server.crt:/certs/server.crt:ro - ../example-certs/server.key:/certs/server.key:ro - ../config-example.toml:/etc/rpxy.toml:ro diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 055f745..180ab93 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -1,14 +1,143 @@ #!/usr/bin/env sh -LOG_FILE=/var/log/rpxy/rpxy.log +LOG_DIR=/rpxy/log +LOG_FILE=${LOG_DIR}/rpxy.log +LOG_SIZE=10M +LOG_NUM=10 -if [ -z ${LOG_TO_FILE} ]; then - LOG_TO_FILE=false +LOGGING=${LOG_TO_FILE:-false} +USER=${HOST_USER:-rpxy} +USER_ID=${HOST_UID:-900} +GROUP_ID=${HOST_GID:-900} + +####################################### +# Setup logrotate +function setup_logrotate () { + if [ $LOGROTATE_NUM ]; then + LOG_NUM=${LOGROTATE_NUM} + fi + if [ $LOGROTATE_SIZE ]; then + LOG_SIZE=${LOGROTATE_SIZE} + fi + + cat > /etc/logrotate.conf << EOF +# see "man logrotate" for details +# rotate log files weekly +weekly +# use the adm group by default, since this is the owning group +# of /var/log/syslog. +# su root adm +# keep 4 weeks worth of backlogs +rotate 4 +# create new (empty) log files after rotating old ones +create +# use date as a suffix of the rotated file +#dateext +# uncomment this if you want your log files compressed +#compress +# packages drop log rotation information into this directory +include /etc/logrotate.d +# system-specific logs may be also be configured here. +EOF + + cat > /etc/logrotate.d/rpxy.conf << EOF +${LOG_FILE} { + dateext + daily + missingok + rotate ${LOG_NUM} + notifempty + compress + delaycompress + dateformat -%Y-%m-%d-%s + size ${LOG_SIZE} + copytruncate + su ${USER} ${USER} +} +EOF +} + +####################################### +function setup_ubuntu () { + # Check the existence of the user, if not exist, create it. + if [ ! $(id ${USER}) ]; then + echo "rpxy: Create user ${USER} with ${USER_ID}:${GROUP_ID}" + groupadd -g ${GROUP_ID} ${USER} + useradd -u ${USER_ID} -g ${GROUP_ID} ${USER} + fi + + # for crontab when logging + if "${LOGGING}"; then + # Set up logrotate + setup_logrotate + + # Setup cron + mkdir -p /etc/cron.15min/ + cp -p /etc/cron.daily/logrotate /etc/cron.15min/ + echo "*/15 * * * * root cd / && run-parts --report /etc/cron.15min" >> /etc/crontab + # cp -p /etc/cron.daily/logrotate /etc/cron.hourly/ + service cron start + fi +} + +####################################### +function setup_alpine () { + # Check the existence of the user, if not exist, create it. + if [ ! 
$(id ${USER}) ]; then + echo "rpxy: Create user ${USER} with ${USER_ID}:${GROUP_ID}" + addgroup -g ${GROUP_ID} ${USER} + adduser -H -D -u ${USER_ID} -G ${USER} ${USER} + fi + + # for crontab when logging + if "${LOGGING}"; then + # Set up logrotate + setup_logrotate + + # Setup cron + cp -f /etc/periodic/daily/logrotate /etc/periodic/15min + crond -b -l 8 + fi +} + +####################################### + +if [ $(whoami) != "root" -o $(id -u) -ne 0 -a $(id -g) -ne 0 ]; then + echo "Do not execute 'docker run' or 'docker-compose up' with a specific user through '-u'." + echo "If you want to run 'rpxy' with a specific user, use HOST_USER, HOST_UID and HOST_GID environment variables." + exit 1 fi -if "${LOG_TO_FILE}"; then +# Check gosu or su-exec, determine linux distribution, and set up user +if [ $(command -v gosu) ]; then + # Ubuntu Linux + alias gosu='gosu' + setup_ubuntu + LINUX="Ubuntu" +elif [ $(command -v su-exec) ]; then + # Alpine Linux + alias gosu='su-exec' + setup_alpine + LINUX="Alpine" +else + echo "Unknown distribution!" + exit 1 +fi + +# Check the given user and its uid:gid +if [ $(id -u ${USER}) -ne ${USER_ID} -a $(id -g ${USER}) -ne ${GROUP_ID} ]; then + echo "${USER} exists or was previously created. However, its uid and gid are inconsistent. Please recreate your container." + exit 1 +fi + +# Change permission according to the given user +chown -R ${USER_ID}:${USER_ID} /rpxy + +# Run rpxy +echo "rpxy: Start with user: ${USER} (${USER_ID}:${GROUP_ID})" +if "${LOGGING}"; then echo "rpxy: Start with writing log file" - /run.sh 2>&1 | tee $LOG_FILE + gosu ${USER} sh -c "/rpxy/run.sh 2>&1 | tee ${LOG_FILE}" else echo "rpxy: Start without writing log file" - /run.sh 2>&1 + gosu ${USER} sh -c "/rpxy/run.sh 2>&1" fi diff --git a/docker/run.sh b/docker/run.sh new file mode 100644 index 0000000..6f83ff8 --- /dev/null +++ b/docker/run.sh @@ -0,0 +1,10 @@ +#!/usr/bin/env sh +CONFIG_FILE=/etc/rpxy.toml + +# debug level logging +if [ -z $LOG_LEVEL ]; then + LOG_LEVEL=info +fi +echo "rpxy: Logging with level ${LOG_LEVEL}" + +RUST_LOG=${LOG_LEVEL} /rpxy/bin/rpxy --config ${CONFIG_FILE} diff --git a/h3 b/h3 index 49301f1..3ef7c1a 160000 --- a/h3 +++ b/h3 @@ -1 +1 @@ -Subproject commit 49301f18e15d3acffc2a8d8bea1a8038c5f3fe6d +Subproject commit 3ef7c1a37b635e8446322d8f8d3a68580a208ad8 diff --git a/h3-quinn/Cargo.toml b/h3-quinn/Cargo.toml new file mode 100644 index 0000000..df34822 --- /dev/null +++ b/h3-quinn/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "h3-quinn" +version = "0.0.1" +rust-version = "1.59" +authors = ["Jean-Christophe BEGUE "] +edition = "2018" +documentation = "https://docs.rs/h3-quinn" +repository = "https://github.com/hyperium/h3" +readme = "../README.md" +description = "QUIC transport implementation based on Quinn." +keywords = ["http3", "quic", "h3"] +categories = ["network-programming", "web-programming"] +license = "MIT" + +[dependencies] +h3 = { version = "0.0.2", path = "../h3/h3" } +bytes = "1" +quinn = { path = "../quinn/quinn/", default-features = false, features = [ + "futures-io", +] } +quinn-proto = { path = "../quinn/quinn-proto/", default-features = false } +tokio-util = { version = "0.7.8" } +futures = { version = "0.3.27" } +tokio = { version = "1.28", features = ["io-util"], default-features = false } diff --git a/h3-quinn/src/lib.rs b/h3-quinn/src/lib.rs new file mode 100644 index 0000000..78696de --- /dev/null +++ b/h3-quinn/src/lib.rs @@ -0,0 +1,740 @@ +//! QUIC Transport implementation with Quinn +//! +//! 
This module implements QUIC traits with Quinn. +#![deny(missing_docs)] + +use std::{ + convert::TryInto, + fmt::{self, Display}, + future::Future, + pin::Pin, + sync::Arc, + task::{self, Poll}, +}; + +use bytes::{Buf, Bytes, BytesMut}; + +use futures::{ + ready, + stream::{self, BoxStream}, + StreamExt, +}; +use quinn::ReadDatagram; +pub use quinn::{ + self, crypto::Session, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError, +}; + +use h3::{ + ext::Datagram, + quic::{self, Error, StreamId, WriteBuf}, +}; +use tokio_util::sync::ReusableBoxFuture; + +/// A QUIC connection backed by Quinn +/// +/// Implements a [`quic::Connection`] backed by a [`quinn::Connection`]. +pub struct Connection { + conn: quinn::Connection, + incoming_bi: BoxStream<'static, <AcceptBi<'static> as Future>::Output>, + opening_bi: Option<BoxStream<'static, <OpenBi<'static> as Future>::Output>>, + incoming_uni: BoxStream<'static, <AcceptUni<'static> as Future>::Output>, + opening_uni: Option<BoxStream<'static, <OpenUni<'static> as Future>::Output>>, + datagrams: BoxStream<'static, <ReadDatagram<'static> as Future>::Output>, +} + +impl Connection { + /// Create a [`Connection`] from a [`quinn::Connection`] + pub fn new(conn: quinn::Connection) -> Self { + Self { + conn: conn.clone(), + incoming_bi: Box::pin(stream::unfold(conn.clone(), |conn| async { + Some((conn.accept_bi().await, conn)) + })), + opening_bi: None, + incoming_uni: Box::pin(stream::unfold(conn.clone(), |conn| async { + Some((conn.accept_uni().await, conn)) + })), + opening_uni: None, + datagrams: Box::pin(stream::unfold(conn, |conn| async { + Some((conn.read_datagram().await, conn)) + })), + } + } +} + +/// The error type for [`Connection`] +/// +/// Wraps reasons a Quinn connection might be lost. +#[derive(Debug)] +pub struct ConnectionError(quinn::ConnectionError); + +impl std::error::Error for ConnectionError {} + +impl fmt::Display for ConnectionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl Error for ConnectionError { + fn is_timeout(&self) -> bool { + matches!(self.0, quinn::ConnectionError::TimedOut) + } + + fn err_code(&self) -> Option<u64> { + match self.0 { + quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { + error_code, + .. + }) => Some(error_code.into_inner()), + _ => None, + } + } +} + +impl From<quinn::ConnectionError> for ConnectionError { + fn from(e: quinn::ConnectionError) -> Self { + Self(e) + } +} + +/// Types of errors when sending a datagram. +#[derive(Debug)] +pub enum SendDatagramError { + /// Datagrams are not supported by the peer + UnsupportedByPeer, + /// Datagrams are locally disabled + Disabled, + /// The datagram was too large to be sent.
+ TooLarge, + /// Network error + ConnectionLost(Box<dyn Error>), +} + +impl fmt::Display for SendDatagramError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SendDatagramError::UnsupportedByPeer => write!(f, "datagrams not supported by peer"), + SendDatagramError::Disabled => write!(f, "datagram support disabled"), + SendDatagramError::TooLarge => write!(f, "datagram too large"), + SendDatagramError::ConnectionLost(_) => write!(f, "connection lost"), + } + } +} + +impl std::error::Error for SendDatagramError {} + +impl Error for SendDatagramError { + fn is_timeout(&self) -> bool { + false + } + + fn err_code(&self) -> Option<u64> { + match self { + Self::ConnectionLost(err) => err.err_code(), + _ => None, + } + } +} + +impl From<quinn::SendDatagramError> for SendDatagramError { + fn from(value: quinn::SendDatagramError) -> Self { + match value { + quinn::SendDatagramError::UnsupportedByPeer => Self::UnsupportedByPeer, + quinn::SendDatagramError::Disabled => Self::Disabled, + quinn::SendDatagramError::TooLarge => Self::TooLarge, + quinn::SendDatagramError::ConnectionLost(err) => { + Self::ConnectionLost(ConnectionError::from(err).into()) + } + } + } +} + +impl<B> quic::Connection<B> for Connection +where + B: Buf, +{ + type SendStream = SendStream<B>; + type RecvStream = RecvStream; + type BidiStream = BidiStream<B>; + type OpenStreams = OpenStreams; + type Error = ConnectionError; + + fn poll_accept_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Option<Self::BidiStream>, Self::Error>> { + let (send, recv) = match ready!(self.incoming_bi.poll_next_unpin(cx)) { + Some(x) => x?, + None => return Poll::Ready(Ok(None)), + }; + Poll::Ready(Ok(Some(Self::BidiStream { + send: Self::SendStream::new(send), + recv: Self::RecvStream::new(recv), + }))) + } + + fn poll_accept_recv( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Option<Self::RecvStream>, Self::Error>> { + let recv = match ready!(self.incoming_uni.poll_next_unpin(cx)) { + Some(x) => x?, + None => return Poll::Ready(Ok(None)), + }; + Poll::Ready(Ok(Some(Self::RecvStream::new(recv)))) + } + + fn poll_open_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Self::BidiStream, Self::Error>> { + if self.opening_bi.is_none() { + self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { + Some((conn.clone().open_bi().await, conn)) + }))); + } + + let (send, recv) = + ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; + Poll::Ready(Ok(Self::BidiStream { + send: Self::SendStream::new(send), + recv: Self::RecvStream::new(recv), + })) + } + + fn poll_open_send( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Self::SendStream, Self::Error>> { + if self.opening_uni.is_none() { + self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { + Some((conn.open_uni().await, conn)) + }))); + } + + let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; + Poll::Ready(Ok(Self::SendStream::new(send))) + } + + fn opener(&self) -> Self::OpenStreams { + OpenStreams { + conn: self.conn.clone(), + opening_bi: None, + opening_uni: None, + } + } + + fn close(&mut self, code: h3::error::Code, reason: &[u8]) { + self.conn.close( + VarInt::from_u64(code.value()).expect("error code VarInt"), + reason, + ); + } +} + +impl<B> quic::SendDatagramExt<B> for Connection +where + B: Buf, +{ + type Error = SendDatagramError; + + fn send_datagram(&mut self, data: Datagram<B>) -> Result<(), SendDatagramError> { + // TODO investigate static buffer from known max datagram size + let mut buf = BytesMut::new(); + data.encode(&mut buf); + self.conn.send_datagram(buf.freeze())?; + + Ok(()) + } +}
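The `Connection` implementation above leans on one recurring pattern: each async accept/open call on `quinn::Connection` (`accept_bi`, `accept_uni`, `read_datagram`, ...) is wrapped in `futures::stream::unfold` so the poll-based `h3::quic` trait methods can drive it via `poll_next_unpin`. Below is a minimal sketch of that pattern, assuming only the `futures` and `tokio` crates; the counter stands in for the cloned `quinn::Connection` that h3-quinn threads through the closure.

```rust
use futures::{stream, StreamExt};

#[tokio::main]
async fn main() {
    // `stream::unfold` threads state through an async closure and yields one
    // item per resolved future -- the same shape as
    // `stream::unfold(conn.clone(), |conn| async { Some((conn.accept_bi().await, conn)) })`.
    let mut accepts = Box::pin(stream::unfold(0u32, |count| async move {
        let item = count + 1; // stand-in for `conn.accept_bi().await`
        Some((item, item)) // (yielded item, state for the next call)
    }));
    // In `poll_accept_bidi` this stream is driven with
    // `ready!(self.incoming_bi.poll_next_unpin(cx))` instead of `.next().await`.
    assert_eq!(accepts.next().await, Some(1));
    assert_eq!(accepts.next().await, Some(2));
}
```

Storing the stream in the struct (as `incoming_bi` etc.) keeps the in-flight future alive across polls; the lazily created `opening_bi`/`opening_uni` options apply the same idea to locally opened streams.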
+ +impl quic::RecvDatagramExt for Connection { + type Buf = Bytes; + + type Error = ConnectionError; + + #[inline] + fn poll_accept_datagram( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { + match ready!(self.datagrams.poll_next_unpin(cx)) { + Some(Ok(x)) => Poll::Ready(Ok(Some(x))), + Some(Err(e)) => Poll::Ready(Err(e.into())), + None => Poll::Ready(Ok(None)), + } + } +} + +/// Stream opener backed by a Quinn connection +/// +/// Implements [`quic::OpenStreams`] using [`quinn::Connection`], +/// [`quinn::OpenBi`], [`quinn::OpenUni`]. +pub struct OpenStreams { + conn: quinn::Connection, + opening_bi: Option<BoxStream<'static, <OpenBi<'static> as Future>::Output>>, + opening_uni: Option<BoxStream<'static, <OpenUni<'static> as Future>::Output>>, +} + +impl<B> quic::OpenStreams<B> for OpenStreams +where + B: Buf, +{ + type RecvStream = RecvStream; + type SendStream = SendStream<B>; + type BidiStream = BidiStream<B>; + type Error = ConnectionError; + + fn poll_open_bidi( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Self::BidiStream, Self::Error>> { + if self.opening_bi.is_none() { + self.opening_bi = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { + Some((conn.open_bi().await, conn)) + }))); + } + + let (send, recv) = + ready!(self.opening_bi.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; + Poll::Ready(Ok(Self::BidiStream { + send: Self::SendStream::new(send), + recv: Self::RecvStream::new(recv), + })) + } + + fn poll_open_send( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Self::SendStream, Self::Error>> { + if self.opening_uni.is_none() { + self.opening_uni = Some(Box::pin(stream::unfold(self.conn.clone(), |conn| async { + Some((conn.open_uni().await, conn)) + }))); + } + + let send = ready!(self.opening_uni.as_mut().unwrap().poll_next_unpin(cx)).unwrap()?; + Poll::Ready(Ok(Self::SendStream::new(send))) + } + + fn close(&mut self, code: h3::error::Code, reason: &[u8]) { + self.conn.close( + VarInt::from_u64(code.value()).expect("error code VarInt"), + reason, + ); + } +} + +impl Clone for OpenStreams { + fn clone(&self) -> Self { + Self { + conn: self.conn.clone(), + opening_bi: None, + opening_uni: None, + } + } +} + +/// Quinn-backed bidirectional stream +/// +/// Implements [`quic::BidiStream`] which allows the stream to be split +/// into two structs each implementing one direction.
+pub struct BidiStream<B> +where + B: Buf, +{ + send: SendStream<B>, + recv: RecvStream, +} + +impl<B> quic::BidiStream<B> for BidiStream<B> +where + B: Buf, +{ + type SendStream = SendStream<B>; + type RecvStream = RecvStream; + + fn split(self) -> (Self::SendStream, Self::RecvStream) { + (self.send, self.recv) + } +} + +impl<B: Buf> quic::RecvStream for BidiStream<B> { + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { + self.recv.poll_data(cx) + } + + fn stop_sending(&mut self, error_code: u64) { + self.recv.stop_sending(error_code) + } + + fn recv_id(&self) -> StreamId { + self.recv.recv_id() + } +} + +impl<B> quic::SendStream<B> for BidiStream<B> +where + B: Buf, +{ + type Error = SendStreamError; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + self.send.poll_ready(cx) + } + + fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + self.send.poll_finish(cx) + } + + fn reset(&mut self, reset_code: u64) { + self.send.reset(reset_code) + } + + fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> { + self.send.send_data(data) + } + + fn send_id(&self) -> StreamId { + self.send.send_id() + } +} +impl<B> quic::SendStreamUnframed<B> for BidiStream<B> +where + B: Buf, +{ + fn poll_send<D: Buf>( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll<Result<usize, Self::Error>> { + self.send.poll_send(cx, buf) + } +} + +/// Quinn-backed receive stream +/// +/// Implements a [`quic::RecvStream`] backed by a [`quinn::RecvStream`]. +pub struct RecvStream { + stream: Option<quinn::RecvStream>, + read_chunk_fut: ReadChunkFuture, +} + +type ReadChunkFuture = ReusableBoxFuture< + 'static, + ( + quinn::RecvStream, + Result<Option<quinn::Chunk>, quinn::ReadError>, + ), +>; + +impl RecvStream { + fn new(stream: quinn::RecvStream) -> Self { + Self { + stream: Some(stream), + // Should only allocate once the first time it's used + read_chunk_fut: ReusableBoxFuture::new(async { unreachable!() }), + } + } +} + +impl quic::RecvStream for RecvStream { + type Buf = Bytes; + type Error = ReadError; + + fn poll_data( + &mut self, + cx: &mut task::Context<'_>, + ) -> Poll<Result<Option<Self::Buf>, Self::Error>> { + if let Some(mut stream) = self.stream.take() { + self.read_chunk_fut.set(async move { + let chunk = stream.read_chunk(usize::MAX, true).await; + (stream, chunk) + }) + }; + + let (stream, chunk) = ready!(self.read_chunk_fut.poll(cx)); + self.stream = Some(stream); + Poll::Ready(Ok(chunk?.map(|c| c.bytes))) + } + + fn stop_sending(&mut self, error_code: u64) { + self.stream + .as_mut() + .unwrap() + .stop(VarInt::from_u64(error_code).expect("invalid error_code")) + .ok(); + } + + fn recv_id(&self) -> StreamId { + self.stream + .as_ref() + .unwrap() + .id() + .0 + .try_into() + .expect("invalid stream id") + } +} + +/// The error type for [`RecvStream`] +/// +/// Wraps errors that occur when reading from a receive stream.
+#[derive(Debug)] +pub struct ReadError(quinn::ReadError); + +impl From<ReadError> for std::io::Error { + fn from(value: ReadError) -> Self { + value.0.into() + } +} + +impl std::error::Error for ReadError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.0.source() + } +} + +impl fmt::Display for ReadError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl From<ReadError> for Arc<dyn Error> { + fn from(e: ReadError) -> Self { + Arc::new(e) + } +} + +impl From<quinn::ReadError> for ReadError { + fn from(e: quinn::ReadError) -> Self { + Self(e) + } +} + +impl Error for ReadError { + fn is_timeout(&self) -> bool { + matches!( + self.0, + quinn::ReadError::ConnectionLost(quinn::ConnectionError::TimedOut) + ) + } + + fn err_code(&self) -> Option<u64> { + match self.0 { + quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed( + quinn_proto::ApplicationClose { error_code, .. }, + )) => Some(error_code.into_inner()), + quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()), + _ => None, + } + } +} + +/// Quinn-backed send stream +/// +/// Implements a [`quic::SendStream`] backed by a [`quinn::SendStream`]. +pub struct SendStream<B: Buf> { + stream: Option<quinn::SendStream>, + writing: Option<WriteBuf<B>>, + write_fut: WriteFuture, +} + +type WriteFuture = + ReusableBoxFuture<'static, (quinn::SendStream, Result<usize, quinn::WriteError>)>; + +impl<B> SendStream<B> +where + B: Buf, +{ + fn new(stream: quinn::SendStream) -> SendStream<B> { + Self { + stream: Some(stream), + writing: None, + write_fut: ReusableBoxFuture::new(async { unreachable!() }), + } + } +} + +impl<B> quic::SendStream<B> for SendStream<B> +where + B: Buf, +{ + type Error = SendStreamError; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + if let Some(ref mut data) = self.writing { + while data.has_remaining() { + if let Some(mut stream) = self.stream.take() { + let chunk = data.chunk().to_owned(); // FIXME - avoid copy + self.write_fut.set(async move { + let ret = stream.write(&chunk).await; + (stream, ret) + }); + } + + let (stream, res) = ready!(self.write_fut.poll(cx)); + self.stream = Some(stream); + match res { + Ok(cnt) => data.advance(cnt), + Err(err) => { + return Poll::Ready(Err(SendStreamError::Write(err))); + } + } + } + } + self.writing = None; + Poll::Ready(Ok(())) + } + + fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { + self.stream + .as_mut() + .unwrap() + .poll_finish(cx) + .map_err(Into::into) + } + + fn reset(&mut self, reset_code: u64) { + let _ = self + .stream + .as_mut() + .unwrap() + .reset(VarInt::from_u64(reset_code).unwrap_or(VarInt::MAX)); + } + + fn send_data<D: Into<WriteBuf<B>>>(&mut self, data: D) -> Result<(), Self::Error> { + if self.writing.is_some() { + return Err(Self::Error::NotReady); + } + self.writing = Some(data.into()); + Ok(()) + } + + fn send_id(&self) -> StreamId { + self.stream + .as_ref() + .unwrap() + .id() + .0 + .try_into() + .expect("invalid stream id") + } +} + +impl<B> quic::SendStreamUnframed<B> for SendStream<B> +where + B: Buf, +{ + fn poll_send<D: Buf>( + &mut self, + cx: &mut task::Context<'_>, + buf: &mut D, + ) -> Poll<Result<usize, Self::Error>> { + if self.writing.is_some() { + // This signifies a bug in implementation + panic!("poll_send called while send stream is not ready") + } + + let s = Pin::new(self.stream.as_mut().unwrap()); + + let res = ready!(futures::io::AsyncWrite::poll_write(s, cx, buf.chunk())); + match res { + Ok(written) => { + buf.advance(written); + Poll::Ready(Ok(written)) + } + Err(err) => { + // We are forced to use AsyncWrite for now because we cannot store + // the result of a call to: + //
quinn::send_stream::write<'a>(&'a mut self, buf: &'a [u8]) -> Result<usize, WriteError>. + // + // This is why we have to unpack the error from io::Error instead of having it + // returned directly. This should not panic as long as quinn's AsyncWrite impl + // doesn't change. + let err = err + .into_inner() + .expect("write stream returned an empty error") + .downcast::<WriteError>() + .expect("write stream returned an error which type is not WriteError"); + + Poll::Ready(Err(SendStreamError::Write(*err))) + } + } + } +} + +/// The error type for [`SendStream`] +/// +/// Wraps errors that can happen writing to or polling a send stream. +#[derive(Debug)] +pub enum SendStreamError { + /// Errors when writing, wrapping a [`quinn::WriteError`] + Write(WriteError), + /// Error when the stream is not ready, because it is still sending + /// data from a previous call + NotReady, +} + +impl From<SendStreamError> for std::io::Error { + fn from(value: SendStreamError) -> Self { + match value { + SendStreamError::Write(err) => err.into(), + SendStreamError::NotReady => { + std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready") + } + } + } +} + +impl std::error::Error for SendStreamError {} + +impl Display for SendStreamError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self) + } +} + +impl From<WriteError> for SendStreamError { + fn from(e: WriteError) -> Self { + Self::Write(e) + } +} + +impl Error for SendStreamError { + fn is_timeout(&self) -> bool { + matches!( + self, + Self::Write(quinn::WriteError::ConnectionLost( + quinn::ConnectionError::TimedOut + )) + ) + } + + fn err_code(&self) -> Option<u64> { + match self { + Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()), + Self::Write(quinn::WriteError::ConnectionLost( + quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { + error_code, + .. + }), + )) => Some(error_code.into_inner()), + _ => None, + } + } +} + +impl From<SendStreamError> for Arc<dyn Error> { + fn from(e: SendStreamError) -> Self { + Arc::new(e) + } +} diff --git a/quinn b/quinn new file mode 160000 index 0000000..7914468 --- /dev/null +++ b/quinn @@ -0,0 +1 @@ +Subproject commit 7914468e27621633a8399c8d02fbf3f557d54df2 diff --git a/src/backend/load_balance.rs b/src/backend/load_balance.rs new file mode 100644 index 0000000..5d93f0a --- /dev/null +++ b/src/backend/load_balance.rs @@ -0,0 +1,135 @@ +#[cfg(feature = "sticky-cookie")] +pub use super::{ + load_balance_sticky::{LbStickyRoundRobin, LbStickyRoundRobinBuilder}, + sticky_cookie::StickyCookie, +}; +use derive_builder::Builder; +use rand::Rng; +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, +}; + +/// Constants to specify a load balance option +pub(super) mod load_balance_options { + pub const FIX_TO_FIRST: &str = "none"; + pub const ROUND_ROBIN: &str = "round_robin"; + pub const RANDOM: &str = "random"; + #[cfg(feature = "sticky-cookie")] + pub const STICKY_ROUND_ROBIN: &str = "sticky"; +} + +#[derive(Debug, Clone)] +/// Pointer to the upstream serving the incoming request. +/// If 'sticky cookie'-based LB is enabled and the cookie must be updated/created, the new cookie is also given.
+pub(super) struct PointerToUpstream { + pub ptr: usize, + pub context_lb: Option<LbContext>, +} +/// Trait for LB +pub(super) trait LbWithPointer { + fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream; +} + +#[derive(Debug, Clone, Builder)] +/// Round Robin LB object as a pointer to the current serving upstream destination +pub struct LbRoundRobin { + #[builder(default)] + /// Pointer to the index of the last served upstream destination + ptr: Arc<AtomicUsize>, + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, +} +impl LbRoundRobinBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } +} +impl LbWithPointer for LbRoundRobin { + /// Increment the count of upstream served up to the max value + fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { + // Get a current count of upstream served + let current_ptr = self.ptr.load(Ordering::Relaxed); + + let ptr = if current_ptr < self.num_upstreams - 1 { + self.ptr.fetch_add(1, Ordering::Relaxed) + } else { + // Clear the counter + self.ptr.fetch_and(0, Ordering::Relaxed) + }; + PointerToUpstream { ptr, context_lb: None } + } +} + +#[derive(Debug, Clone, Builder)] +/// Random LB object to keep the object of random pools +pub struct LbRandom { + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, +} +impl LbRandomBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } +} +impl LbWithPointer for LbRandom { + /// Returns a random index within the range + fn get_ptr(&self, _info: Option<&LbContext>) -> PointerToUpstream { + let mut rng = rand::thread_rng(); + let ptr = rng.gen_range(0..self.num_upstreams); + PointerToUpstream { ptr, context_lb: None } + } +} + +#[derive(Debug, Clone)] +/// Load Balancing Option +pub enum LoadBalance { + /// Fix to the first upstream. Use if only one upstream destination is specified + FixToFirst, + /// Randomly choose one upstream server + Random(LbRandom), + /// Simple round robin without session persistence + RoundRobin(LbRoundRobin), + #[cfg(feature = "sticky-cookie")] + /// Round robin with session persistence using a cookie + StickyRoundRobin(LbStickyRoundRobin), +} +impl Default for LoadBalance { + fn default() -> Self { + Self::FixToFirst + } +} + +impl LoadBalance { + /// Get the index of the upstream serving the incoming request + pub(super) fn get_context(&self, _context_to_lb: &Option<LbContext>) -> PointerToUpstream { + match self { + LoadBalance::FixToFirst => PointerToUpstream { + ptr: 0usize, + context_lb: None, + }, + LoadBalance::RoundRobin(ptr) => ptr.get_ptr(None), + LoadBalance::Random(ptr) => ptr.get_ptr(None), + #[cfg(feature = "sticky-cookie")] + LoadBalance::StickyRoundRobin(ptr) => { + // Generate new context if sticky round robin is enabled. + ptr.get_ptr(_context_to_lb.as_ref()) + } + } + } +} + +#[derive(Debug, Clone)] +/// Struct to handle the sticky cookie string, +/// - passed from Rp module (http handler) to LB module, manipulated from req, only StickyCookieValue exists. +/// - passed from LB module to Rp module (http handler), will be inserted into res, StickyCookieValue and Info exist.
+pub struct LbContext { + #[cfg(feature = "sticky-cookie")] + pub sticky_cookie: StickyCookie, + #[cfg(not(feature = "sticky-cookie"))] + pub sticky_cookie: (), +} diff --git a/src/backend/load_balance_sticky.rs b/src/backend/load_balance_sticky.rs new file mode 100644 index 0000000..32f4fe5 --- /dev/null +++ b/src/backend/load_balance_sticky.rs @@ -0,0 +1,132 @@ +use super::{ + load_balance::{LbContext, LbWithPointer, PointerToUpstream}, + sticky_cookie::StickyCookieConfig, + Upstream, +}; +use crate::{constants::STICKY_COOKIE_NAME, log::*}; +use derive_builder::Builder; +use rustc_hash::FxHashMap as HashMap; +use std::{ + borrow::Cow, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, +}; + +#[derive(Debug, Clone, Builder)] +/// Round Robin LB object in the sticky cookie manner +pub struct LbStickyRoundRobin { + #[builder(default)] + /// Pointer to the index of the last served upstream destination + ptr: Arc<AtomicUsize>, + #[builder(setter(custom), default)] + /// Number of upstream destinations + num_upstreams: usize, + #[builder(setter(custom))] + /// Information to build the cookie to stick clients to specific backends + pub sticky_config: StickyCookieConfig, + #[builder(setter(custom))] + /// Maps: + /// - Vector that maps server indices to server ids (string) + /// - Hashmap that maps server ids (string) to server indices, for fast reverse lookup + upstream_maps: UpstreamMap, +} +#[derive(Debug, Clone)] +pub struct UpstreamMap { + /// Vector that maps server indices to server ids (string) + upstream_index_map: Vec<String>, + /// Hashmap that maps server ids (string) to server indices, for fast reverse lookup + upstream_id_map: HashMap<String, usize>, +} +impl LbStickyRoundRobinBuilder { + pub fn num_upstreams(&mut self, v: &usize) -> &mut Self { + self.num_upstreams = Some(*v); + self + } + pub fn sticky_config(&mut self, server_name: &str, path_opt: &Option<String>) -> &mut Self { + self.sticky_config = Some(StickyCookieConfig { + name: STICKY_COOKIE_NAME.to_string(), // TODO: make this configurable via the config file + domain: server_name.to_ascii_lowercase(), + path: if let Some(v) = path_opt { + v.to_ascii_lowercase() + } else { + "/".to_string() + }, + duration: 300, // TODO: make this configurable via the config file + }); + self + } + pub fn upstream_maps(&mut self, upstream_vec: &[Upstream]) -> &mut Self { + let upstream_index_map: Vec<String> = upstream_vec + .iter() + .enumerate() + .map(|(i, v)| v.calculate_id_with_index(i)) + .collect(); + let mut upstream_id_map = HashMap::default(); + for (i, v) in upstream_index_map.iter().enumerate() { + upstream_id_map.insert(v.to_string(), i); + } + self.upstream_maps = Some(UpstreamMap { + upstream_index_map, + upstream_id_map, + }); + self + } +} +impl<'a> LbStickyRoundRobin { + fn simple_increment_ptr(&self) -> usize { + // Get a current count of upstream served + let current_ptr = self.ptr.load(Ordering::Relaxed); + + if current_ptr < self.num_upstreams - 1 { + self.ptr.fetch_add(1, Ordering::Relaxed) + } else { + // Clear the counter + self.ptr.fetch_and(0, Ordering::Relaxed) + } + } + /// This is always called only internally, so 'unwrap()' is safe here. + fn get_server_id_from_index(&self, index: usize) -> String { + self.upstream_maps.upstream_index_map.get(index).unwrap().to_owned() + } + /// This function takes a value passed from outside, so a fallible lookup is used.
+ fn get_server_index_from_id(&self, id: impl Into<Cow<'a, str>>) -> Option<usize> { + let id_str = id.into().to_string(); + self.upstream_maps.upstream_id_map.get(&id_str).map(|v| v.to_owned()) + } +} +impl LbWithPointer for LbStickyRoundRobin { + fn get_ptr(&self, req_info: Option<&LbContext>) -> PointerToUpstream { + // If the given context is None or invalid (not contained), the pointer is simply incremented. + // Otherwise, get the server index indicated by the server_id inside the cookie + let ptr = match req_info { + None => { + debug!("No sticky cookie"); + self.simple_increment_ptr() + } + Some(context) => { + let server_id = &context.sticky_cookie.value.value; + if let Some(server_index) = self.get_server_index_from_id(server_id) { + debug!("Valid sticky cookie: id={}, index={}", server_id, server_index); + server_index + } else { + debug!("Invalid sticky cookie: id={}", server_id); + self.simple_increment_ptr() + } + } + }; + + // Get the server id from the ptr. + // TODO: This should be simplified and optimized if ptr is not changed (id value exists in cookie). + let upstream_id = self.get_server_id_from_index(ptr); + let new_cookie = self.sticky_config.build_sticky_cookie(upstream_id).unwrap(); + let new_context = Some(LbContext { + sticky_cookie: new_cookie, + }); + PointerToUpstream { + ptr, + context_lb: new_context, + } + } +} diff --git a/src/backend/mod.rs b/src/backend/mod.rs index 6a92ec8..b7923c5 100644 --- a/src/backend/mod.rs +++ b/src/backend/mod.rs @@ -1,6 +1,18 @@ +mod load_balance; +#[cfg(feature = "sticky-cookie")] +mod load_balance_sticky; +#[cfg(feature = "sticky-cookie")] +mod sticky_cookie; mod upstream; mod upstream_opts; +#[cfg(feature = "sticky-cookie")] +pub use self::sticky_cookie::{StickyCookie, StickyCookieValue}; +pub use self::{ + load_balance::{LbContext, LoadBalance}, + upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}, + upstream_opts::UpstreamOption, +}; use crate::{ log::*, utils::{BytesName, PathNameBytesExp, ServerNameBytesExp}, @@ -20,20 +32,21 @@ use tokio_rustls::rustls::{ sign::{any_supported_type, CertifiedKey}, Certificate, PrivateKey, ServerConfig, }; -pub use upstream::{ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder}; -pub use upstream_opts::UpstreamOption; use x509_parser::prelude::*; /// Struct serving information to route incoming connections, like server name to be handled and tls certs/keys settings.
#[derive(Builder)] pub struct Backend { #[builder(setter(into))] + /// backend application name, e.g., app1 pub app_name: String, #[builder(setter(custom))] + /// server name, e.g., example.com, as an ASCII-lowercase String pub server_name: String, + /// struct of reverse proxy serving incoming request pub reverse_proxy: ReverseProxy, - // tls settings + /// TLS settings #[builder(setter(custom), default)] pub tls_cert_path: Option<PathBuf>, #[builder(setter(custom), default)] @@ -69,12 +82,9 @@ fn opt_string_to_opt_pathbuf(input: &Option<String>) -> Option<PathBuf> { impl Backend { pub fn read_certs_and_key(&self) -> io::Result<CertifiedKey> { debug!("Read TLS server certificates and private key"); - let (certs_path, certs_keys_path) = - if let (Some(c), Some(k)) = (self.tls_cert_path.as_ref(), self.tls_cert_key_path.as_ref()) { - (c, k) - } else { - return Err(io::Error::new(io::ErrorKind::Other, "Invalid certs and keys paths")); - }; + let (Some(certs_path), Some(certs_keys_path)) = (self.tls_cert_path.as_ref(), self.tls_cert_key_path.as_ref()) else { + return Err(io::Error::new(io::ErrorKind::Other, "Invalid certs and keys paths")); + }; let certs: Vec<_> = { let certs_path_str = certs_path.display().to_string(); let mut reader = BufReader::new(File::open(certs_path).map_err(|e| { @@ -144,11 +154,10 @@ impl Backend { debug!("Read CA certificates for client authentication"); // Reads client certificate and returns client let client_ca_cert_path = { - if let Some(c) = self.client_ca_cert_path.as_ref() { - c - } else { + let Some(c) = self.client_ca_cert_path.as_ref() else { return Err(io::Error::new(io::ErrorKind::Other, "Invalid certs and keys paths")); - } + }; + c }; let certs: Vec<_> = { let certs_path_str = client_ca_cert_path.display().to_string(); @@ -168,7 +177,8 @@ impl Backend { let owned_trust_anchors: Vec<_> = certs .iter() .map(|v| { - let trust_anchor = tokio_rustls::webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap(); + // let trust_anchor = tokio_rustls::webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap(); + let trust_anchor = webpki::TrustAnchor::try_from_cert_der(&v.0).unwrap(); rustls::OwnedTrustAnchor::from_subject_spki_name_constraints( trust_anchor.subject, trust_anchor.spki, @@ -264,22 +274,29 @@ impl Backends { let mut server_config_local = if client_ca_roots_local.is_empty() { // with no client auth, enable http1.1 -- 3 - let mut sc = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_cert_resolver(Arc::new(resolver_local)); + #[cfg(not(feature = "http3"))] + { + ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver_local)) + } #[cfg(feature = "http3")] { + let mut sc = ServerConfig::builder() + .with_safe_defaults() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver_local)); + sc.alpn_protocols = vec![b"h3".to_vec(), b"hq-29".to_vec()]; // TODO: remove hq-29 later?
+ sc + } - sc } else { // with client auth, enable only http1.1 and 2 // let client_certs_verifier = rustls::server::AllowAnyAnonymousOrAuthenticatedClient::new(client_ca_roots); let client_certs_verifier = rustls::server::AllowAnyAuthenticatedClient::new(client_ca_roots_local); ServerConfig::builder() .with_safe_defaults() - .with_client_cert_verifier(client_certs_verifier) + .with_client_cert_verifier(Arc::new(client_certs_verifier)) .with_cert_resolver(Arc::new(resolver_local)) }; server_config_local.alpn_protocols.push(b"h2".to_vec()); @@ -314,7 +331,7 @@ impl Backends { } #[cfg(not(feature = "http3"))] { - server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; + server_crypto_global.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec()]; } Ok(ServerCrypto { diff --git a/src/backend/sticky_cookie.rs b/src/backend/sticky_cookie.rs new file mode 100644 index 0000000..998426b --- /dev/null +++ b/src/backend/sticky_cookie.rs @@ -0,0 +1,208 @@ +use std::borrow::Cow; + +use crate::error::*; +use chrono::{TimeZone, Utc}; +use derive_builder::Builder; + +#[derive(Debug, Clone, Builder)] +/// Cookie value only, used for COOKIE in req +pub struct StickyCookieValue { + #[builder(setter(custom))] + /// Field name indicating sticky cookie + pub name: String, + #[builder(setter(custom))] + /// Upstream server_id + pub value: String, +} +impl<'a> StickyCookieValueBuilder { + pub fn name(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { + self.name = Some(v.into().to_ascii_lowercase()); + self + } + pub fn value(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { + self.value = Some(v.into().to_string()); + self + } +} +impl StickyCookieValue { + pub fn try_from(value: &str, expected_name: &str) -> Result<Self, RpxyError> { + if !value.starts_with(expected_name) { + return Err(RpxyError::LoadBalance( + "Failed to convert cookie from string".to_string(), + )); + }; + let kv = value.split('=').map(|v| v.trim()).collect::<Vec<&str>>(); + if kv.len() != 2 { + return Err(RpxyError::LoadBalance("Invalid cookie structure".to_string())); + }; + if kv[1].is_empty() { + return Err(RpxyError::LoadBalance("No sticky cookie value".to_string())); + } + Ok(StickyCookieValue { + name: expected_name.to_string(), + value: kv[1].to_string(), + }) + } +} + +#[derive(Debug, Clone, Builder)] +/// Struct describing sticky cookie meta information used for SET-COOKIE in res +pub struct StickyCookieInfo { + #[builder(setter(custom))] + /// Unix time + pub expires: i64, + + #[builder(setter(custom))] + /// Domain + pub domain: String, + + #[builder(setter(custom))] + /// Path + pub path: String, +} +impl<'a> StickyCookieInfoBuilder { + pub fn domain(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { + self.domain = Some(v.into().to_ascii_lowercase()); + self + } + pub fn path(&mut self, v: impl Into<Cow<'a, str>>) -> &mut Self { + self.path = Some(v.into().to_ascii_lowercase()); + self + } + pub fn expires(&mut self, duration_secs: i64) -> &mut Self { + let current = Utc::now().timestamp(); + self.expires = Some(current + duration_secs); + self + } +} + +#[derive(Debug, Clone, Builder)] +/// Struct describing sticky cookie +pub struct StickyCookie { + #[builder(setter(custom))] + /// Upstream server_id + pub value: StickyCookieValue, + #[builder(setter(custom), default)] + /// Sticky cookie meta information for Set-Cookie + pub info: Option<StickyCookieInfo>, +} + +impl<'a> StickyCookieBuilder { + pub fn value(&mut self, n: impl Into<Cow<'a, str>>, v: impl Into<Cow<'a, str>>) -> &mut Self { + self.value = Some(StickyCookieValueBuilder::default().name(n).value(v).build().unwrap()); + self + } + pub fn info( + &mut self, + domain: impl Into<Cow<'a, str>>,
+    path: impl Into<Cow<'a, str>>,
+    duration_secs: i64,
+  ) -> &mut Self {
+    let info = StickyCookieInfoBuilder::default()
+      .domain(domain)
+      .path(path)
+      .expires(duration_secs)
+      .build()
+      .unwrap();
+    self.info = Some(Some(info));
+    self
+  }
+}
+
+impl TryInto<String> for StickyCookie {
+  type Error = RpxyError;
+
+  fn try_into(self) -> Result<String, Self::Error> {
+    if self.info.is_none() {
+      return Err(RpxyError::LoadBalance(
+        "Failed to cookie conversion into string: no meta information".to_string(),
+      ));
+    }
+    let info = self.info.unwrap();
+    let chrono::LocalResult::Single(expires_timestamp) = Utc.timestamp_opt(info.expires, 0) else {
+      return Err(RpxyError::LoadBalance("Failed to cookie conversion into string".to_string()));
+    };
+    let exp_str = expires_timestamp.format("%a, %d-%b-%Y %T GMT").to_string();
+    let max_age = info.expires - Utc::now().timestamp();
+
+    Ok(format!(
+      "{}={}; expires={}; Max-Age={}; path={}; domain={}",
+      self.value.name, self.value.value, exp_str, max_age, info.path, info.domain
+    ))
+  }
+}
+
+#[derive(Debug, Clone)]
+/// Configuration to serve incoming requests in the "sticky cookie" manner,
+/// including a dictionary mapping the ids carried in cookies to upstream destinations,
+/// and the cookie expiration.
+/// The cookie "domain" and "path" will be the same as the reverse proxy options.
+pub struct StickyCookieConfig {
+  pub name: String,
+  pub domain: String,
+  pub path: String,
+  pub duration: i64,
+}
+impl<'a> StickyCookieConfig {
+  pub fn build_sticky_cookie(&self, v: impl Into<Cow<'a, str>>) -> Result<StickyCookie> {
+    StickyCookieBuilder::default()
+      .value(self.name.clone(), v)
+      .info(&self.domain, &self.path, self.duration)
+      .build()
+      .map_err(|_| RpxyError::LoadBalance("Failed to build sticky cookie from config".to_string()))
+  }
+}
+
+#[cfg(test)]
+mod tests {
+  use super::*;
+  use crate::constants::STICKY_COOKIE_NAME;
+
+  #[test]
+  fn config_works() {
+    let config = StickyCookieConfig {
+      name: STICKY_COOKIE_NAME.to_string(),
+      domain: "example.com".to_string(),
+      path: "/path".to_string(),
+      duration: 100,
+    };
+    let expires_unix = Utc::now().timestamp() + 100;
+    let sc_string: Result<String> = config.build_sticky_cookie("test_value").unwrap().try_into();
+    let expires_date_string = Utc
+      .timestamp_opt(expires_unix, 0)
+      .unwrap()
+      .format("%a, %d-%b-%Y %T GMT")
+      .to_string();
+    assert_eq!(
+      sc_string.unwrap(),
+      format!(
+        "{}=test_value; expires={}; Max-Age={}; path=/path; domain=example.com",
+        STICKY_COOKIE_NAME, expires_date_string, 100
+      )
+    );
+  }
+  #[test]
+  fn to_string_works() {
+    let sc = StickyCookie {
+      value: StickyCookieValue {
+        name: STICKY_COOKIE_NAME.to_string(),
+        value: "test_value".to_string(),
+      },
+      info: Some(StickyCookieInfo {
+        expires: 1686221173i64,
+        domain: "example.com".to_string(),
+        path: "/path".to_string(),
+      }),
+    };
+    let sc_string: Result<String> = sc.try_into();
+    let max_age = 1686221173i64 - Utc::now().timestamp();
+    assert!(sc_string.is_ok());
+    assert_eq!(
+      sc_string.unwrap(),
+      format!(
+        "{}=test_value; expires=Thu, 08-Jun-2023 10:46:13 GMT; Max-Age={}; path=/path; domain=example.com",
+        STICKY_COOKIE_NAME, max_age
+      )
+    );
+  }
+}
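[Editor's note] The TryInto<String> implementation above is what renders the Set-Cookie payload. Below is a minimal standalone sketch of that string layout, using only chrono 0.4 as pinned in Cargo.toml; the cookie name and attribute values here are illustrative, not taken from this diff.

    use chrono::{TimeZone, Utc};

    fn main() {
      // Fixed timestamp, the same one used in to_string_works() above.
      let expires_unix = 1_686_221_173i64;
      let expires = Utc.timestamp_opt(expires_unix, 0).unwrap();
      let exp_str = expires.format("%a, %d-%b-%Y %T GMT").to_string();
      let max_age = expires_unix - Utc::now().timestamp();
      // Layout mirrors the format! call in try_into() above.
      println!("rpxy_srv_id=abc; expires={exp_str}; Max-Age={max_age}; path=/path; domain=example.com");
    }

Note that "expires" is an absolute HTTP date while "Max-Age" is relative, so the latter is recomputed against the current clock at serialization time.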
diff --git a/src/backend/upstream.rs b/src/backend/upstream.rs
index c2fdb34..2bfd2d6 100644
--- a/src/backend/upstream.rs
+++ b/src/backend/upstream.rs
@@ -1,22 +1,22 @@
-use super::{BytesName, PathNameBytesExp, UpstreamOption};
+#[cfg(feature = "sticky-cookie")]
+use super::load_balance::LbStickyRoundRobinBuilder;
+use super::load_balance::{load_balance_options as lb_opts, LbRandomBuilder, LbRoundRobinBuilder, LoadBalance};
+use super::{BytesName, LbContext, PathNameBytesExp, UpstreamOption};
 use crate::log::*;
+#[cfg(feature = "sticky-cookie")]
+use base64::{engine::general_purpose, Engine as _};
 use derive_builder::Builder;
-use rand::Rng;
 use rustc_hash::{FxHashMap as HashMap, FxHashSet as HashSet};
-use std::{
-  borrow::Cow,
-  sync::{
-    atomic::{AtomicUsize, Ordering},
-    Arc,
-  },
-};
-
+#[cfg(feature = "sticky-cookie")]
+use sha2::{Digest, Sha256};
+use std::borrow::Cow;

 #[derive(Debug, Clone)]
 pub struct ReverseProxy {
   pub upstream: HashMap<PathNameBytesExp, UpstreamGroup>, // TODO: Not sure a HashMap is right here; the longest-prefix match via max_by_key also looks wasteful...
 }

 impl ReverseProxy {
+  /// Get an appropriate upstream destination for a given path string.
   pub fn get<'a>(&self, path_str: impl Into<Cow<'a, str>>) -> Option<&UpstreamGroup> {
     // A trie would allow longest-prefix matching, but the number of route entries is
     // expected to be small, so this is sufficient cost-wise.
@@ -50,38 +50,49 @@ impl ReverseProxy {
   }
 }

-#[allow(dead_code)]
 #[derive(Debug, Clone)]
-pub enum LoadBalance {
-  RoundRobin,
-  Random,
+/// Upstream struct containing only the uri, without a path
+pub struct Upstream {
+  /// Base uri without specific path
+  pub uri: hyper::Uri,
 }
-impl Default for LoadBalance {
-  fn default() -> Self {
-    Self::RoundRobin
+impl Upstream {
+  #[cfg(feature = "sticky-cookie")]
+  /// Hashing uri with index to avoid collision
+  pub fn calculate_id_with_index(&self, index: usize) -> String {
+    let mut hasher = Sha256::new();
+    let uri_string = format!("{}&index={}", self.uri.clone(), index);
+    hasher.update(uri_string.as_bytes());
+    let digest = hasher.finalize();
+    general_purpose::URL_SAFE_NO_PAD.encode(digest)
   }
 }
-
-#[derive(Debug, Clone)]
-pub struct Upstream {
-  pub uri: hyper::Uri, // base uri without specific path
-}
-
 #[derive(Debug, Clone, Builder)]
+/// Struct serving multiple upstream servers for, e.g., load balancing.
 pub struct UpstreamGroup {
+  #[builder(setter(custom))]
+  /// Upstream server(s)
   pub upstream: Vec<Upstream>,
   #[builder(setter(custom), default)]
+  /// Path like "/path" in [[PathNameBytesExp]] associated with the upstream server(s)
   pub path: PathNameBytesExp,
   #[builder(setter(custom), default)]
+  /// Path in [[PathNameBytesExp]] that will be used to replace the "path" part of incoming url
   pub replace_path: Option<PathNameBytesExp>,
-  #[builder(default)]
-  pub lb: LoadBalance,
-  #[builder(default)]
-  pub cnt: UpstreamCount, // counter for load balancing
+  #[builder(setter(custom), default)]
+  /// Load balancing option
+  pub lb: LoadBalance,
+  #[builder(setter(custom), default)]
+  /// Activated upstream options defined in [[UpstreamOption]]
   pub opts: HashSet<UpstreamOption>,
 }
+
 impl UpstreamGroupBuilder {
+  pub fn upstream(&mut self, upstream_vec: &[Upstream]) -> &mut Self {
+    self.upstream = Some(upstream_vec.to_vec());
+    self
+  }
   pub fn path(&mut self, v: &Option<String>) -> &mut Self {
     let path = match v {
       Some(p) => p.to_path_name_vec(),
@@ -98,6 +109,45 @@ impl UpstreamGroupBuilder {
     );
     self
   }
+  pub fn lb(
+    &mut self,
+    v: &Option<String>,
+    // upstream_num: &usize,
+    upstream_vec: &Vec<Upstream>,
+    _server_name: &str,
+    _path_opt: &Option<String>,
+  ) -> &mut Self {
+    let upstream_num = &upstream_vec.len();
+    let lb = if let Some(x) = v {
+      match x.as_str() {
+        lb_opts::FIX_TO_FIRST => LoadBalance::FixToFirst,
+        lb_opts::RANDOM => LoadBalance::Random(LbRandomBuilder::default().num_upstreams(upstream_num).build().unwrap()),
+        lb_opts::ROUND_ROBIN => LoadBalance::RoundRobin(
+          LbRoundRobinBuilder::default()
+            .num_upstreams(upstream_num)
+            .build()
+            .unwrap(),
+        ),
+        #[cfg(feature = "sticky-cookie")]
+        lb_opts::STICKY_ROUND_ROBIN => LoadBalance::StickyRoundRobin(
+          LbStickyRoundRobinBuilder::default()
+            .num_upstreams(upstream_num)
+            .sticky_config(_server_name, _path_opt)
+            .upstream_maps(upstream_vec) // TODO:
+            .build()
+            .unwrap(),
+        ),
+        _ => {
+          error!("Specified load balancing option is invalid.");
+          LoadBalance::default()
+        }
+      }
+    } else {
+      LoadBalance::default()
+    };
+    self.lb = Some(lb);
+    self
+  }
   pub fn opts(&mut self, v: &Option<Vec<String>>) -> &mut Self {
     let opts = if let Some(opts) = v {
       opts
@@ -112,33 +162,40 @@ impl UpstreamGroupBuilder {
   }
 }

-#[derive(Debug, Clone, Default)]
-pub struct UpstreamCount(Arc<AtomicUsize>);
-
 impl UpstreamGroup {
-  pub fn get(&self) -> Option<&Upstream> {
-    match self.lb {
-      LoadBalance::RoundRobin => {
-        let idx = self.increment_cnt();
-        self.upstream.get(idx)
-      }
-      LoadBalance::Random => {
-        let mut rng = rand::thread_rng();
-        let max = self.upstream.len() - 1;
-        self.upstream.get(rng.gen_range(0..max))
-      }
-    }
-  }
-
-  fn current_cnt(&self) -> usize {
-    self.cnt.0.load(Ordering::Relaxed)
-  }
-
-  fn increment_cnt(&self) -> usize {
-    if self.current_cnt() < self.upstream.len() - 1 {
-      self.cnt.0.fetch_add(1, Ordering::Relaxed)
-    } else {
-      self.cnt.0.fetch_and(0, Ordering::Relaxed)
-    }
+  /// Get an enabled option of load balancing [[LoadBalance]]
+  pub fn get(&self, context_to_lb: &Option<LbContext>) -> (Option<&Upstream>, Option<LbContext>) {
+    let pointer_to_upstream = self.lb.get_context(context_to_lb);
+    debug!("Upstream of index {} is chosen.", pointer_to_upstream.ptr);
+    debug!("Context to LB (Cookie in Req): {:?}", context_to_lb);
+    debug!(
+      "Context from LB (Set-Cookie in Res): {:?}",
+      pointer_to_upstream.context_lb
+    );
+    (
+      self.upstream.get(pointer_to_upstream.ptr),
+      pointer_to_upstream.context_lb,
+    )
+  }
+}
+
+#[cfg(test)]
+mod test {
+  #[allow(unused)]
+  use super::*;
+
+  #[cfg(feature = "sticky-cookie")]
+  #[test]
+  fn calc_id_works() {
+    let uri = "https://www.rust-lang.org".parse::<hyper::Uri>().unwrap();
+    let upstream = Upstream { uri };
+    assert_eq!(
+      "eGsjoPbactQ1eUJjafYjPT3ekYZQkaqJnHdA_FMSkgM",
+      upstream.calculate_id_with_index(0)
+    );
+    assert_eq!(
+      "tNVXFJ9eNCT2mFgKbYq35XgH5q93QZtfU8piUiiDxVA",
+      upstream.calculate_id_with_index(1)
+    );
+  }
+}
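[Editor's note] The ids asserted in calc_id_works() come from hashing the Display-formatted uri joined with the upstream index, then base64url-encoding the SHA-256 digest. A standalone sketch under the crate versions pinned in Cargo.toml (sha2 0.10, base64 0.21); calc_id is an illustrative helper, and the trailing slash in the sample input reflects how a parsed hyper::Uri typically renders.

    use base64::{engine::general_purpose, Engine as _};
    use sha2::{Digest, Sha256};

    // Illustrative re-implementation of calculate_id_with_index() above.
    fn calc_id(uri_display: &str, index: usize) -> String {
      let mut hasher = Sha256::new();
      hasher.update(format!("{uri_display}&index={index}").as_bytes());
      general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize())
    }

    fn main() {
      // Each upstream gets a distinct, URL-safe id even when base uris are identical.
      assert_ne!(
        calc_id("https://www.rust-lang.org/", 0),
        calc_id("https://www.rust-lang.org/", 1)
      );
    }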
diff --git a/src/config/parse.rs b/src/config/parse.rs
index 935f86c..8e4ddf7 100644
--- a/src/config/parse.rs
+++ b/src/config/parse.rs
@@ -1,6 +1,6 @@
 use super::toml::{ConfigToml, ReverseProxyOption};
 use crate::{
-  backend::{BackendBuilder, ReverseProxy, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption},
+  backend::{BackendBuilder, ReverseProxy, Upstream, UpstreamGroup, UpstreamGroupBuilder, UpstreamOption},
   constants::*,
   error::*,
   globals::*,
@@ -99,7 +99,7 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Error> {
     let mut backend_builder = BackendBuilder::default();
     // reverse proxy settings
     ensure!(app.reverse_proxy.is_some(), "Missing reverse_proxy");
-    let reverse_proxy = get_reverse_proxy(app.reverse_proxy.as_ref().unwrap())?;
+    let reverse_proxy = get_reverse_proxy(server_name_string, app.reverse_proxy.as_ref().unwrap())?;

     backend_builder
       .app_name(server_name_string)
@@ -198,13 +198,21 @@ pub fn parse_opts(globals: &mut Globals) -> std::result::Result<(), anyhow::Error> {
   Ok(())
 }

-fn get_reverse_proxy(rp_settings: &[ReverseProxyOption]) -> std::result::Result<ReverseProxy, anyhow::Error> {
+fn get_reverse_proxy(
+  server_name_string: &str,
+  rp_settings: &[ReverseProxyOption],
+) -> std::result::Result<ReverseProxy, anyhow::Error> {
   let mut upstream: HashMap<PathNameBytesExp, UpstreamGroup> = HashMap::default();
+
   rp_settings.iter().for_each(|rpo| {
+    let upstream_vec: Vec<Upstream> = rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect();
+    // let upstream_iter = rpo.upstream.iter().map(|x| x.to_upstream().unwrap());
+    // let lb_upstream_num = vec_upstream.len();
     let elem = UpstreamGroupBuilder::default()
-      .upstream(rpo.upstream.iter().map(|x| x.to_upstream().unwrap()).collect())
+      .upstream(&upstream_vec)
       .path(&rpo.path)
       .replace_path(&rpo.replace_path)
+      .lb(&rpo.load_balance, &upstream_vec, server_name_string, &rpo.path)
       .opts(&rpo.upstream_options)
      .build()
      .unwrap();
diff --git a/src/config/toml.rs b/src/config/toml.rs
index cefacb2..6ce48b2 100644
--- a/src/config/toml.rs
+++ b/src/config/toml.rs
@@ -57,6 +57,7 @@ pub struct ReverseProxyOption {
   pub replace_path: Option<String>,
   pub upstream: Vec<UpstreamParams>,
   pub upstream_options: Option<Vec<String>>,
+  pub load_balance: Option<String>,
 }

 #[derive(Deserialize, Debug, Default)]
diff --git a/src/constants.rs b/src/constants.rs
index d2fc25f..a29be29 100644
--- a/src/constants.rs
+++ b/src/constants.rs
@@ -24,3 +24,7 @@ pub mod H3 {
   pub const MAX_CONCURRENT_UNISTREAM: u32 = 64;
   pub const MAX_IDLE_TIMEOUT: u64 = 10; // secs
 }
+
+#[cfg(feature = "sticky-cookie")]
+/// For load-balancing with sticky cookie
+pub const STICKY_COOKIE_NAME: &str = "rpxy_srv_id";
diff --git a/src/error.rs b/src/error.rs
index 6da3b02..3fb3474 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -22,12 +22,18 @@ pub enum RpxyError {
   #[error("TCP/UDP Proxy Layer Error: {0}")]
   Proxy(String),

+  #[allow(unused)]
+  #[error("LoadBalance Layer Error: {0}")]
+  LoadBalance(String),
+
   #[error("I/O Error")]
   Io(#[from] io::Error),

+  #[cfg(feature = "http3")]
   #[error("Quic Connection Error")]
   QuicConn(#[from] quinn::ConnectionError),

+  #[cfg(feature = "http3")]
   #[error("H3 Error")]
   H3(#[from] h3::Error),
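[Editor's note] The only schema change in toml.rs above is the optional load_balance string on ReverseProxyOption. A sketch of how such an entry deserializes, using serde and toml as pinned in Cargo.toml; ReverseProxyOptionSketch is a standalone stand-in for the real struct, and "round_robin" is an assumed option string (the accepted values are defined in load_balance.rs, which is not part of this diff).

    use serde::Deserialize;

    #[derive(Deserialize, Debug)]
    struct ReverseProxyOptionSketch {
      path: Option<String>,
      replace_path: Option<String>,
      upstream: Vec<toml::Value>, // stand-in for the real upstream entry type
      upstream_options: Option<Vec<String>>,
      load_balance: Option<String>, // the new field
    }

    fn main() {
      let entry = r#"
        path = "/path"
        upstream = [{ location = "app1.local:8080" }, { location = "app2.local:8080" }]
        load_balance = "round_robin"
      "#;
      let rpo: ReverseProxyOptionSketch = toml::from_str(entry).unwrap();
      assert_eq!(rpo.load_balance.as_deref(), Some("round_robin"));
    }

Unrecognized strings fall back to LoadBalance::default() with an error log, as the lb() builder above shows.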
diff --git a/src/handler/handler_main.rs b/src/handler/handler_main.rs
index 4f60ee5..a73dcbc 100644
--- a/src/handler/handler_main.rs
+++ b/src/handler/handler_main.rs
@@ -1,5 +1,5 @@
 // Highly motivated by https://github.com/felipenoris/hyper-reverse-proxy
-use super::{utils_headers::*, utils_request::*, utils_synth_response::*};
+use super::{utils_headers::*, utils_request::*, utils_synth_response::*, HandlerContext};
 use crate::{
   backend::{Backend, UpstreamGroup},
   error::*,
@@ -91,7 +91,7 @@ where
     let request_upgraded = req.extensions_mut().remove::<hyper::upgrade::OnUpgrade>();

     // Build request from destination information
-    if let Err(e) = self.generate_request_forwarded(
+    let _context = match self.generate_request_forwarded(
       &client_addr,
       &listen_addr,
       &mut req,
@@ -99,8 +99,11 @@ where
       upstream_group,
       tls_enabled,
     ) {
-      error!("Failed to generate destination uri for reverse proxy: {}", e);
-      return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
+      Err(e) => {
+        error!("Failed to generate destination uri for reverse proxy: {}", e);
+        return self.return_with_error_log(StatusCode::SERVICE_UNAVAILABLE, &mut log_data);
+      }
+      Ok(v) => v,
     };
     debug!("Request to be forwarded: {:?}", req);
     log_data.xff(&req.headers().get("x-forwarded-for"));
@@ -123,6 +126,16 @@ where
       }
     };

+    // Process reverse proxy context generated during the forwarding request generation.
+    #[cfg(feature = "sticky-cookie")]
+    if let Some(context_from_lb) = _context.context_lb {
+      let res_headers = res_backend.headers_mut();
+      if let Err(e) = set_sticky_cookie_lb_context(res_headers, &context_from_lb) {
+        error!("Failed to append context to the response given from backend: {}", e);
+        return self.return_with_error_log(StatusCode::BAD_GATEWAY, &mut log_data);
+      }
+    }
+
     if res_backend.status() != StatusCode::SWITCHING_PROTOCOLS {
       // Generate response to client
       if self.generate_response_forwarded(&mut res_backend, backend).is_ok() {
@@ -141,9 +154,7 @@ where
       false
     } {
       if let Some(request_upgraded) = request_upgraded {
-        let onupgrade = if let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() {
-          onupgrade
-        } else {
+        let Some(onupgrade) = res_backend.extensions_mut().remove::<hyper::upgrade::OnUpgrade>() else {
           error!("Response does not have an upgrade extension");
           return self.return_with_error_log(StatusCode::INTERNAL_SERVER_ERROR, &mut log_data);
         };
@@ -231,7 +242,7 @@ where
     upgrade: &Option<String>,
     upstream_group: &UpstreamGroup,
     tls_enabled: bool,
-  ) -> Result<()> {
+  ) -> Result<HandlerContext> {
     debug!("Generate request to be forwarded");

     // Add te: trailer if contained in original request
@@ -267,8 +278,28 @@ where
         .insert(header::HOST, HeaderValue::from_str(&org_host)?);
     };

+    /////////////////////////////////////////////
     // Fix unique upstream destination since there could be multiple ones.
-    let upstream_chosen = upstream_group.get().ok_or_else(|| anyhow!("Failed to get upstream"))?;
+    #[cfg(feature = "sticky-cookie")]
+    let (upstream_chosen_opt, context_from_lb) = {
+      let context_to_lb = if let crate::backend::LoadBalance::StickyRoundRobin(lb) = &upstream_group.lb {
+        takeout_sticky_cookie_lb_context(req.headers_mut(), &lb.sticky_config.name)?
+      } else {
+        None
+      };
+      upstream_group.get(&context_to_lb)
+    };
+    #[cfg(not(feature = "sticky-cookie"))]
+    let (upstream_chosen_opt, _) = upstream_group.get(&None);
+
+    let upstream_chosen = upstream_chosen_opt.ok_or_else(|| anyhow!("Failed to get upstream"))?;
+    let context = HandlerContext {
+      #[cfg(feature = "sticky-cookie")]
+      context_lb: context_from_lb,
+      #[cfg(not(feature = "sticky-cookie"))]
+      context_lb: None,
+    };
+    /////////////////////////////////////////////

     // apply upstream-specific headers given in upstream_option
     let headers = req.headers_mut();
@@ -321,6 +352,6 @@ where
       *req.version_mut() = Version::HTTP_2;
     }

-    Ok(())
+    Ok(context)
   }
 }
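[Editor's note] The HandlerContext returned by generate_request_forwarded() is what carries the LB's Set-Cookie back onto the response; the actual header surgery is set_sticky_cookie_lb_context in the utils_headers.rs diff further below. A standalone sketch of the underlying hyper HeaderMap Entry pattern, with illustrative names and values:

    use hyper::header::{self, HeaderMap, HeaderValue};

    fn main() {
      let mut headers = HeaderMap::new();
      headers.append(header::SET_COOKIE, HeaderValue::from_static("theme=dark"));
      headers.append(header::SET_COOKIE, HeaderValue::from_static("rpxy_srv_id=old"));

      // Overwrite only the sticky entry; other Set-Cookie lines stay untouched.
      let new_val = HeaderValue::from_static("rpxy_srv_id=new; Max-Age=100");
      if let header::Entry::Occupied(mut entry) = headers.entry(header::SET_COOKIE) {
        for e in entry.iter_mut() {
          if e.to_str().unwrap_or("").starts_with("rpxy_srv_id") {
            *e = new_val.clone();
          }
        }
      }
      assert_eq!(headers.get_all(header::SET_COOKIE).iter().count(), 2);
    }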
diff --git a/src/handler/mod.rs b/src/handler/mod.rs
index c2225ce..8bec011 100644
--- a/src/handler/mod.rs
+++ b/src/handler/mod.rs
@@ -3,4 +3,15 @@ mod utils_headers;
 mod utils_request;
 mod utils_synth_response;

+#[cfg(feature = "sticky-cookie")]
+use crate::backend::LbContext;
 pub use handler_main::{HttpMessageHandler, HttpMessageHandlerBuilder, HttpMessageHandlerBuilderError};
+
+#[allow(dead_code)]
+#[derive(Debug)]
+struct HandlerContext {
+  #[cfg(feature = "sticky-cookie")]
+  context_lb: Option<LbContext>,
+  #[cfg(not(feature = "sticky-cookie"))]
+  context_lb: Option<()>,
+}
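[Editor's note] A standalone, std-only sketch of the Cookie-splitting performed by takeout_sticky_cookie_lb_context in the diff below: the sticky pair is peeled off for the load balancer while the remaining pairs are re-joined for the upstream. Names and values are illustrative.

    fn main() {
      let cookie = "other=1; rpxy_srv_id=abc; theme=dark";
      // Partition individual cookie pairs by the expected sticky cookie name.
      let (sticky, rest): (Vec<_>, Vec<_>) = cookie
        .split(';')
        .map(str::trim)
        .partition(|c| c.starts_with("rpxy_srv_id"));
      assert_eq!(sticky, vec!["rpxy_srv_id=abc"]);
      // Only this part is forwarded upstream as the new Cookie header value.
      assert_eq!(rest.join("; "), "other=1; theme=dark");
    }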
diff --git a/src/handler/utils_headers.rs b/src/handler/utils_headers.rs
index 7fc4a5f..944d4d9 100644
--- a/src/handler/utils_headers.rs
+++ b/src/handler/utils_headers.rs
@@ -1,9 +1,8 @@
-use crate::{
-  backend::{UpstreamGroup, UpstreamOption},
-  error::*,
-  log::*,
-  utils::*,
-};
+#[cfg(feature = "sticky-cookie")]
+use crate::backend::{LbContext, StickyCookie, StickyCookieValue};
+use crate::backend::{UpstreamGroup, UpstreamOption};
+
+use crate::{error::*, log::*, utils::*};
 use bytes::BufMut;
 use hyper::{
   header::{self, HeaderMap, HeaderName, HeaderValue},
@@ -14,6 +13,76 @@ use std::net::SocketAddr;
 ////////////////////////////////////////////////////
 // Functions to manipulate headers

+#[cfg(feature = "sticky-cookie")]
+/// Take the sticky cookie header value from the request header,
+/// and return an LbContext to be forwarded to the LB if one exists and is needed.
+/// The sticky cookie must be removed here; it must not be passed to the upstream.
+pub(super) fn takeout_sticky_cookie_lb_context(
+  headers: &mut HeaderMap,
+  expected_cookie_name: &str,
+) -> Result<Option<LbContext>> {
+  let mut headers_clone = headers.clone();
+
+  match headers_clone.entry(hyper::header::COOKIE) {
+    header::Entry::Vacant(_) => Ok(None),
+    header::Entry::Occupied(entry) => {
+      let cookies_iter = entry
+        .iter()
+        .flat_map(|v| v.to_str().unwrap_or("").split(';').map(|v| v.trim()));
+      let (sticky_cookies, without_sticky_cookies): (Vec<_>, Vec<_>) = cookies_iter
+        .into_iter()
+        .partition(|v| v.starts_with(expected_cookie_name));
+      if sticky_cookies.is_empty() {
+        return Ok(None);
+      }
+      if sticky_cookies.len() > 1 {
+        error!("Multiple sticky cookie values in request");
+        return Err(RpxyError::Other(anyhow!(
+          "Invalid cookie: Multiple sticky cookie values"
+        )));
+      }
+      let cookies_passed_to_upstream = without_sticky_cookies.join("; ");
+      let cookie_passed_to_lb = sticky_cookies.first().unwrap();
+      headers.remove(hyper::header::COOKIE);
+      headers.insert(hyper::header::COOKIE, cookies_passed_to_upstream.parse()?);
+
+      let sticky_cookie = StickyCookie {
+        value: StickyCookieValue::try_from(cookie_passed_to_lb, expected_cookie_name)?,
+        info: None,
+      };
+      Ok(Some(LbContext { sticky_cookie }))
+    }
+  }
+}
+
+#[cfg(feature = "sticky-cookie")]
+/// Set-Cookie if LB Sticky is enabled and the cookie is newly created/updated.
+/// The Set-Cookie response header may span multiple lines.
+/// https://developer.mozilla.org/ja/docs/Web/HTTP/Headers/Set-Cookie
+pub(super) fn set_sticky_cookie_lb_context(headers: &mut HeaderMap, context_from_lb: &LbContext) -> Result<()> {
+  let sticky_cookie_string: String = context_from_lb.sticky_cookie.clone().try_into()?;
+  let new_header_val: HeaderValue = sticky_cookie_string.parse()?;
+  let expected_cookie_name = &context_from_lb.sticky_cookie.value.name;
+  match headers.entry(hyper::header::SET_COOKIE) {
+    header::Entry::Vacant(entry) => {
+      entry.insert(new_header_val);
+    }
+    header::Entry::Occupied(mut entry) => {
+      let mut flag = false;
+      for e in entry.iter_mut() {
+        if e.to_str().unwrap_or("").starts_with(expected_cookie_name) {
+          *e = new_header_val.clone();
+          flag = true;
+        }
+      }
+      if !flag {
+        entry.append(new_header_val);
+      }
+    }
+  };
+  Ok(())
+}
+
 pub(super) fn apply_upstream_options_to_header(
   headers: &mut HeaderMap,
   _client_addr: &SocketAddr,
diff --git a/src/proxy/proxy_tls.rs b/src/proxy/proxy_tls.rs
index a3ed081..de18e0c 100644
--- a/src/proxy/proxy_tls.rs
+++ b/src/proxy/proxy_tls.rs
@@ -6,10 +6,10 @@ use crate::{
   log::*,
   utils::BytesName,
 };
-#[cfg(feature = "http3")]
 use hyper::{client::connect::Connect, server::conn::Http};
 #[cfg(feature = "http3")]
 use quinn::{crypto::rustls::HandshakeData, Endpoint, ServerConfig as QuicServerConfig, TransportConfig};
+#[cfg(feature = "http3")]
 use rustls::ServerConfig;
 use std::sync::Arc;
 use tokio::{
@@ -196,14 +196,14 @@ where
     let (tx, rx) = watch::channel::<Option<Arc<ServerCrypto>>>(None);
     #[cfg(not(feature = "http3"))]
     {
-      select! {
-        _= self.cert_service(tx).fuse() => {
+      tokio::select! {
+        _= self.cert_service(tx) => {
           error!("Cert service for TLS exited");
         },
-        _ = self.listener_service(server, rx).fuse() => {
+        _ = self.listener_service(server, rx) => {
           error!("TCP proxy service for TLS exited");
         },
-        complete => {
+        else => {
           error!("Something went wrong");
           return Ok(())
         }
diff --git a/src/utils/bytes_name.rs b/src/utils/bytes_name.rs
index 80bc0f0..16ec7ab 100644
--- a/src/utils/bytes_name.rs
+++ b/src/utils/bytes_name.rs
@@ -1,5 +1,5 @@
 /// Server name (hostname or ip address) representation in bytes-based struct
-/// For searching hashmap or key list by exact or longest-prefix matching
+/// for searching hashmap or key list by exact or longest-prefix matching
 #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)]
 pub struct ServerNameBytesExp(pub Vec<u8>); // lowercase ascii bytes
 impl From<&[u8]> for ServerNameBytesExp {
@@ -8,8 +8,8 @@ impl From<&[u8]> for ServerNameBytesExp {
   }
 }

-/// Server name (hostname or ip address) representation in bytes-based struct
-/// For searching hashmap or key list by exact or longest-prefix matching
+/// Path name, like "/path/ok", represented in bytes-based struct
+/// for searching hashmap or key list by exact or longest-prefix matching
 #[derive(Clone, Debug, PartialEq, Eq, Hash, Default)]
 pub struct PathNameBytesExp(pub Vec<u8>); // lowercase ascii bytes
 impl PathNameBytesExp {