diff --git a/Cargo.lock b/Cargo.lock index f5d0a0b..1fb2f1d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,16 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + [[package]] name = "aho-corasick" version = "1.1.3" @@ -151,6 +161,18 @@ dependencies = [ "fs_extra", ] +[[package]] +name = "base16ct" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" + +[[package]] +name = "base64ct" +version = "1.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" + [[package]] name = "bindgen" version = "0.72.1" @@ -171,23 +193,77 @@ dependencies = [ "syn", ] -[[package]] -name = "bitfield-struct" -version = "0.9.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2869c63ccf4f8bf0d485070b880e60e097fb7aeea80ee82a0a94a957e372a0b" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - [[package]] name = "bitflags" version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3" +[[package]] +name = "boring" +version = "4.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9078ed42084a408d2a1c33bb9dde775721227f1673e4daa81e7f74017b412ce3" +dependencies = [ + "bitflags", + "boring-sys", + "foreign-types 0.5.0", + "libc", + "openssl-macros", +] + +[[package]] +name = "boring-additions" +version = "0.0.1" +source = "git+https://github.com/janrueth/boring-rustls-provider.git?rev=490340afa77e2c08fc45853124f99d49f4f9f8a0#490340afa77e2c08fc45853124f99d49f4f9f8a0" +dependencies = [ + "boring", + "boring-sys", + "foreign-types 0.5.0", +] + +[[package]] +name = "boring-rustls-provider" +version = "0.0.1" +source = "git+https://github.com/janrueth/boring-rustls-provider.git?rev=490340afa77e2c08fc45853124f99d49f4f9f8a0#490340afa77e2c08fc45853124f99d49f4f9f8a0" +dependencies = [ + "aead", + "boring", + "boring-additions", + "boring-sys", + "boring-sys-additions", + "foreign-types 0.5.0", + "rustls", + "rustls-pki-types", + "spki", +] + +[[package]] +name = "boring-sys" +version = "4.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4f033d6ba6dee8b5322f5504acf758c29317790883bac078e7949dd3d670269" +dependencies = [ + "bindgen", + "cmake", + "fs_extra", + "fslock", +] + +[[package]] +name = "boring-sys-additions" +version = "0.0.1" +source = "git+https://github.com/janrueth/boring-rustls-provider.git?rev=490340afa77e2c08fc45853124f99d49f4f9f8a0#490340afa77e2c08fc45853124f99d49f4f9f8a0" +dependencies = [ + "boring-sys", +] + +[[package]] +name = "bumpalo" +version = "3.19.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dd9dc738b7a8311c7ade152424974d8115f2cdad61e8dab8dac9f2362298510" + [[package]] name = "bytes" version = "1.10.1" @@ -247,12 +323,49 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "const-oid" +version = "0.9.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" + +[[package]] +name = "core-models" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0940496e5c83c54f3b753d5317daec82e8edac71c33aaa1f666d76f518de2444" +dependencies = [ + "hax-lib", + "pastey", + "rand", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "data-encoding" version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" +[[package]] +name = "der" +version = "0.7.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c1832837b905bbfb5101e07cc24c8deddf52f93225eee6ead5f4d63d53ddcb" +dependencies = [ + "const-oid", + "zeroize", +] + [[package]] name = "der-parser" version = "10.0.0" @@ -322,21 +435,6 @@ dependencies = [ "log", ] -[[package]] -name = "fast-tlsh" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d24a9974c3be467166fe631c0a9e804e028ee7525ed27300dfa7915cc6e36311" -dependencies = [ - "bitfield-struct", - "bitflags", - "cfg-if", - "hex-simd", - "serde", - "static_assertions", - "version_check", -] - [[package]] name = "find-msvc-tools" version = "0.1.4" @@ -349,7 +447,28 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" dependencies = [ - "foreign-types-shared", + "foreign-types-shared 0.1.1", +] + +[[package]] +name = "foreign-types" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d737d9aa519fb7b749cbc3b962edcf310a8dd1f4b67c91c4f83975dbdd17d965" +dependencies = [ + "foreign-types-macros", + "foreign-types-shared 0.3.1", +] + +[[package]] +name = "foreign-types-macros" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] @@ -358,12 +477,28 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" +[[package]] +name = "foreign-types-shared" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa9a19cbb55df58761df49b23516a86d432839add4af60fc256da840f66ed35b" + [[package]] name = "fs_extra" version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" +[[package]] +name = "fslock" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "04412b8935272e3a9bae6f48c7bfff74c2911f60525404edfdd28e49884c3bfb" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "futures-core" version = "0.3.31" @@ -407,6 +542,16 @@ dependencies = [ "slab", ] +[[package]] +name = "generic-array" +version = "0.14.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bb6743198531e02858aeaea5398fcc883e71851fcbcb5a2f773e2fb6cb1edf2" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getopts" version = "0.2.24" @@ -446,13 +591,50 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" [[package]] -name = "hex-simd" -version = "0.8.0" +name = "graviola" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1f7685beb53fc20efc2605f32f5d51e9ba18b8ef237961d1760169d2290d3bee" +checksum = "b1662fcff7237fbe8c91ff2800fcce9435af25b7f0cb580f5679b31c3a1f1e7a" dependencies = [ - "outref", - "vsimd", + "cfg-if", + "getrandom 0.3.4", +] + +[[package]] +name = "hax-lib" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74d9ba66d1739c68e0219b2b2238b5c4145f491ebf181b9c6ab561a19352ae86" +dependencies = [ + "hax-lib-macros", + "num-bigint", + "num-traits", +] + +[[package]] +name = "hax-lib-macros" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24ba777a231a58d1bce1d68313fa6b6afcc7966adef23d60f45b8a2b9b688bf1" +dependencies = [ + "hax-lib-macros-types", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "hax-lib-macros-types" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "867e19177d7425140b417cd27c2e05320e727ee682e98368f88b7194e80ad515" +dependencies = [ + "proc-macro2", + "quote", + "serde", + "serde_json", + "uuid", ] [[package]] @@ -510,6 +692,16 @@ dependencies = [ "libc", ] +[[package]] +name = "js-sys" +version = "0.3.85" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c942ebf8e95485ca0d52d97da7c5a2c387d0e7f0ba4c35e93bfcaee045955b3" +dependencies = [ + "once_cell", + "wasm-bindgen", +] + [[package]] name = "lazy_static" version = "1.5.0" @@ -522,6 +714,70 @@ version = "0.2.177" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +[[package]] +name = "libcrux-intrinsics" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc9ee7ef66569dd7516454fe26de4e401c0c62073929803486b96744594b9632" +dependencies = [ + "core-models", + "hax-lib", +] + +[[package]] +name = "libcrux-ml-kem" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4bb6a88086bf11bd2ec90926c749c4a427f2e59841437dbdede8cde8a96334ab" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-secrets", + "libcrux-sha3", + "libcrux-traits", +] + +[[package]] +name = "libcrux-platform" +version = "0.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db82d058aa76ea315a3b2092f69dfbd67ddb0e462038a206e1dcd73f058c0778" +dependencies = [ + "libc", +] + +[[package]] +name = "libcrux-secrets" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4dbbf6bc9f2bc0f20dc3bea3e5c99adff3bdccf6d2a40488963da69e2ec307" +dependencies = [ + "hax-lib", +] + +[[package]] +name = "libcrux-sha3" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2400bec764d1c75b8a496d5747cffe32f1fb864a12577f0aca2f55a92021c962" +dependencies = [ + "hax-lib", + "libcrux-intrinsics", + "libcrux-platform", + "libcrux-traits", +] + +[[package]] +name = "libcrux-traits" +version = "0.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9adfd58e79d860f6b9e40e35127bfae9e5bd3ade33201d1347459011a2add034" +dependencies = [ + "libcrux-secrets", + "rand", +] + [[package]] name = "libloading" version = "0.8.9" @@ -567,12 +823,14 @@ version = "0.1.0" dependencies = [ "argp", "aws-lc-rs", + "boring-rustls-provider", "env_logger", - "fast-tlsh", "futures-util", - "memchr", - "regex", + "log", + "rustls-graviola", + "rustls-openssl", "rustls-post-quantum", + "rustls-symcrypt", "sslrelay", "static_cell", "tokio", @@ -654,7 +912,7 @@ checksum = "24ad14dd45412269e1a30f52ad8f0664f0f4f4a89ee8fe28c3b3527021ebb654" dependencies = [ "bitflags", "cfg-if", - "foreign-types", + "foreign-types 0.3.2", "libc", "once_cell", "openssl-macros", @@ -685,10 +943,10 @@ dependencies = [ ] [[package]] -name = "outref" -version = "0.5.2" +name = "pastey" +version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" +checksum = "35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" [[package]] name = "pin-project-lite" @@ -702,6 +960,27 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkcs1" +version = "0.7.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8ffb9f10fa047879315e6625af03c164b16962a5368d724ed16323b68ace47f" +dependencies = [ + "der", + "pkcs8", + "spki", +] + +[[package]] +name = "pkcs8" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" +dependencies = [ + "der", + "spki", +] + [[package]] name = "pkg-config" version = "0.3.32" @@ -729,6 +1008,15 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "prettyplease" version = "0.2.37" @@ -739,6 +1027,28 @@ dependencies = [ "syn", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -775,6 +1085,35 @@ version = "5.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom 0.3.4", +] + [[package]] name = "regex" version = "1.12.2" @@ -842,12 +1181,38 @@ dependencies = [ "aws-lc-rs", "log", "once_cell", + "ring", "rustls-pki-types", - "rustls-webpki", + "rustls-webpki 0.103.7", "subtle", "zeroize", ] +[[package]] +name = "rustls-graviola" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e81f0f2005bfba00e8088f9cb75f4b3ce3f2a31aebfaeed0b2cc05e13d01ce06" +dependencies = [ + "graviola", + "libcrux-ml-kem", + "rustls", +] + +[[package]] +name = "rustls-openssl" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "346f161084062dd5a443adfb0de9bbfab62699333e2d8424ca91da5f22ec88d1" +dependencies = [ + "foreign-types 0.3.2", + "once_cell", + "openssl", + "openssl-sys", + "rustls", + "zeroize", +] + [[package]] name = "rustls-pki-types" version = "1.12.0" @@ -865,7 +1230,33 @@ checksum = "0da3cd9229bac4fae1f589c8f875b3c891a058ddaa26eb3bde16b5e43dc174ce" dependencies = [ "aws-lc-rs", "rustls", - "rustls-webpki", + "rustls-webpki 0.103.7", +] + +[[package]] +name = "rustls-symcrypt" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1aa1fbe108897452967f5de83014edf2e42606f09283c4ba1f671904098fcbc" +dependencies = [ + "der", + "pkcs1", + "pkcs8", + "rustls", + "rustls-webpki 0.102.8", + "sec1", + "symcrypt", +] + +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted 0.9.0", ] [[package]] @@ -880,6 +1271,24 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "rustversion" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39cdef0fa800fc44525c84ccb54a029961a8215f9619753635a9c0d2538d46d" + +[[package]] +name = "sec1" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" +dependencies = [ + "base16ct", + "der", + "generic-array", + "zeroize", +] + [[package]] name = "serde" version = "1.0.228" @@ -887,6 +1296,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ "serde_core", + "serde_derive", ] [[package]] @@ -909,6 +1319,19 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + [[package]] name = "shlex" version = "1.3.0" @@ -931,6 +1354,16 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "spki" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" +dependencies = [ + "base64ct", + "der", +] + [[package]] name = "sslrelay" version = "0.6.2" @@ -938,12 +1371,6 @@ dependencies = [ "openssl", ] -[[package]] -name = "static_assertions" -version = "1.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" - [[package]] name = "static_cell" version = "2.1.1" @@ -959,6 +1386,26 @@ version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" +[[package]] +name = "symcrypt" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45f6e93e13900bae533f1f347c82d2fcf9d7ad7fffd7e58bc9fbd1b68575ca1c" +dependencies = [ + "lazy_static", + "libc", + "symcrypt-sys", +] + +[[package]] +name = "symcrypt-sys" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0878df9feb068709a68ccdb5cc3201649201e79e823566863f6ca4db4b9201aa" +dependencies = [ + "libc", +] + [[package]] name = "syn" version = "2.0.107" @@ -1081,6 +1528,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "typenum" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" + [[package]] name = "unicase" version = "2.8.1" @@ -1117,6 +1570,17 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee48d38b119b0cd71fe4141b30f5ba9c7c5d9f4e7a3a8b4a674e4b6ef789976f" +dependencies = [ + "getrandom 0.3.4", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "vcpkg" version = "0.2.15" @@ -1129,12 +1593,6 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" -[[package]] -name = "vsimd" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64" - [[package]] name = "wasi" version = "0.11.1+wasi-snapshot-preview1" @@ -1150,6 +1608,73 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "wasm-bindgen" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64024a30ec1e37399cf85a7ffefebdb72205ca1c972291c51512360d90bd8566" +dependencies = [ + "cfg-if", + "once_cell", + "rustversion", + "wasm-bindgen-macro", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "008b239d9c740232e71bd39e8ef6429d27097518b6b30bdf9086833bd5b6d608" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5256bae2d58f54820e6490f9839c49780dff84c65aeab9e772f15d5f0e913a55" +dependencies = [ + "bumpalo", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f01b580c9ac74c8d8f0c0e4afb04eeef2acf145458e52c03845ee9cd23e3d12" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-link" version = "0.2.1" @@ -1335,8 +1860,34 @@ dependencies = [ "time", ] +[[package]] +name = "zerocopy" +version = "0.8.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7456cf00f0685ad319c5b1693f291a650eaf345e941d082fc4e03df8a03996ac" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1328722bbf2115db7e19d69ebcc15e795719e2d66b60827c6a69a117365e37a0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "zeroize" version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + +[[package]] +name = "zmij" +version = "1.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff05f8caa9038894637571ae6b9e29466c1f4f829d26c9b28f869a29cbe3445" diff --git a/Cargo.toml b/Cargo.toml index 1cdac68..43e4a31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,17 +5,39 @@ edition = "2024" [dependencies] argp = "0.4.0" -aws-lc-rs = "1.14.1" #console-subscriber = "0.5.0" env_logger = "0.11.8" futures-util = "0.3.31" -memchr = "2.7.6" -regex = "1.12.2" -rustls-post-quantum = { version = "0.2.4", features = ["aws-lc-rs-unstable"] } -sslrelay = { path = "../sslrelay" } +log = "0.4.28" +sslrelay = { path = "../sslrelay", optional = true } static_cell = "2.1.1" -tlsh = { package = "fast-tlsh", version = "0.1.10", features = ["easy-functions"] } tokio = { version = "1.48.0", features = ["io-util", "macros", "net", "rt", "rt-multi-thread", "sync", "time"]} tokio-rustls = "0.26.4" tokio-util = { version = "0.7.6", features = ["codec"] } x509-parser = "0.18.0" + +# TLS impls +aws-lc-rs = { version = "1.14.1", optional = true } +boring-rustls-provider = { git = "https://github.com/janrueth/boring-rustls-provider.git", rev = "490340afa77e2c08fc45853124f99d49f4f9f8a0", optional = true } +rustls-graviola = { version = "0.3.2", optional = true } +rustls-openssl = { version = "0.3.0", default-features = false, features = ["tls12"], optional = true } +rustls-post-quantum = { version = "0.2.4", optional = true } +rustls-symcrypt = { version = "0.2.1", optional = true, features = ["chacha", "x25519"] } + +[features] +default = [ + "aws-lc", + "record", +] + +record = ["sslrelay"] + +aws-lc = ["tokio-rustls/aws-lc-rs", "rustls-post-quantum", "rustls-post-quantum/aws-lc-rs-unstable", "aws-lc-rs"] +boring = ["boring-rustls-provider"] +graviola = ["rustls-graviola"] +openssl = ["rustls-openssl"] +ring = ["tokio-rustls/ring"] +symcrypt = ["rustls-symcrypt"] + +[profile.release] +#lto = "fat" diff --git a/rustfmt.toml b/rustfmt.toml index ade1f81..931c5cd 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,5 +1,4 @@ hard_tabs = true -newline_style = "unix" unstable_features = true format_code_in_doc_comments = true diff --git a/src/client.rs b/src/client.rs index 404c523..06471fd 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,10 +1,8 @@ use crate::{ - TlsMode, record::{Direction, Records}, - util::{ResponseStreamer, print_bin}, + util::ResponseStreamer, }; -use futures_util::{StreamExt, TryStreamExt}; use std::{ collections::HashSet, net::ToSocketAddrs, @@ -14,7 +12,7 @@ use std::{ use tokio::{ io::AsyncWriteExt, net::TcpStream, - sync::{Mutex, Semaphore, oneshot}, + sync::{Mutex, Semaphore}, }; use tokio_rustls::{ TlsConnector, @@ -27,7 +25,6 @@ use tokio_rustls::{ pki_types::ServerName, }, }; -use tokio_util::codec::Framed; const TIMEOUT: Duration = Duration::from_secs(30); @@ -89,13 +86,11 @@ impl ServerCertVerifier for DummyCertVerifier { pub async fn play( records: &'static Records, - tls_mode: TlsMode, + use_tls: bool, connect_to: (String, u16), - sync_receiver: oneshot::Receiver<()>, repeat: u32, debug: bool, ) { - sync_receiver.await.unwrap(); // Semaphore used to limit the number of concurrent clients. // Its handle is released when the task panics. let limiter = Arc::new(Semaphore::new(16)); @@ -105,6 +100,10 @@ pub async fn play( let connect_to = connect_to.to_socket_addrs().unwrap().next().unwrap(); let debug_mutex = Arc::new(Mutex::new(())); + let dummy_bytes = Arc::new(vec![0x42u8; 16 * 1024 * 1024]); + + tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; + tokio::spawn({ let running = running.clone(); let counter = counter.clone(); @@ -123,303 +122,200 @@ pub async fn play( } }); - match tls_mode { - TlsMode::Both | TlsMode::Client => { - let mut config = tokio_rustls::rustls::ClientConfig::builder() - .dangerous() - .with_custom_certificate_verifier(Arc::new(DummyCertVerifier)) - .with_no_client_auth(); - let mut enable_early_data = false; - for (var, val) in std::env::vars() { - match var.as_str() { - "EARLYDATA" => enable_early_data = val == "1", - _ => {} - } - } - if enable_early_data { - config.enable_early_data = true; - } else { - config.resumption = Resumption::disabled(); - } - config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); - let config = Arc::new(config); - for _i in 0..repeat { - let mut handles = Vec::new(); - for (id, (server_name, records)) in records.iter() { - let connector = TlsConnector::from(config.clone()); - let counter = counter.clone(); - let limiter = limiter.clone(); - let running = running.clone(); - handles.push(tokio::spawn(async move { - let mut running_guard = running.lock().await; - running_guard.insert(*id); - drop(running_guard); - let limiter = limiter.acquire().await.unwrap(); - let server_name = - ServerName::try_from(String::from_utf8(server_name.clone()).unwrap()) - .unwrap(); - 'repeat: for _i in 0..1 { - let stream = TcpStream::connect(connect_to).await.unwrap(); - let stream = connector - .connect(server_name.clone(), stream) - .await - .unwrap(); - let mut stream = - Framed::new(stream, crate::http::HttpClientCodec::new()); - for (direction, data_list) in ResponseStreamer::new(records.iter()) { - match direction { - Direction::ClientToServer => { - for data in data_list { - //println!("[CLT] ({id}) >> {}", data.len()); - //stream.get_mut().write_all(data).await.unwrap(); - match tokio::time::timeout( - TIMEOUT, - stream.get_mut().write_all(data), - ) - .await - { - Ok(v) => v.unwrap(), - Err(_e) => { - println!("client timeout {id} (sending)"); - continue 'repeat; - } - } - } - } - Direction::ServerToClient => { - let total_len: usize = - data_list.iter().map(|data| data.len()).sum::(); - let reduced_len = - total_len.saturating_sub(160 * data_list.len()).max(1); - let mut total_recv = 0; - //println!("[CLT] ({id}) << {}", data.len()); - // let mut buf = Vec::new(); - // stream.read_buf(&mut buf).await.ok(); - //let mut buf = vec![0; data.len().saturating_sub(50).max(1)]; - //let resp = stream.next().await.unwrap().unwrap(); - while total_recv < reduced_len { - let resp = - match tokio::time::timeout(TIMEOUT, stream.next()) - .await - { - Ok(v) => v.unwrap().unwrap(), - Err(_e) => { - // TODO fix - println!( - "client timeout {}: {} / {}", - id, total_recv, total_len - ); - //print_bin(data); - break 'repeat; - } - }; - total_recv += resp.len(); - //dbg!(resp.len()); - //crate::http::decode_http(&mut buf, &mut stream).await; - } - /*if total_recv > total_len { - println!("received too much {}: {} / {}", id, total_recv, total_len); - }*/ - } - } - } - //stream.get_mut().shutdown().await.unwrap(); - tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown()) - .await - .unwrap() - .unwrap(); - let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - println!("Client: {} / {}", cnt + 1, total); - } - drop(limiter); - let mut running_guard = running.lock().await; - running_guard.remove(id); - drop(running_guard); - })); - //tokio::time::sleep(std::time::Duration::from_millis(500)).await; - } - - for handle in handles { - handle.await.unwrap(); - } + if use_tls { + let mut config = tokio_rustls::rustls::ClientConfig::builder() + .dangerous() + .with_custom_certificate_verifier(Arc::new(DummyCertVerifier)) + .with_no_client_auth(); + let mut enable_early_data = false; + for (var, val) in std::env::vars() { + match var.as_str() { + "EARLYDATA" => enable_early_data = val == "1", + _ => {} } } - TlsMode::None | TlsMode::Server => { - for _i in 0..repeat { - let mut handles = Vec::new(); - for (id, (_server_name, records)) in records.iter() { - /*if *id != 33 { - continue - }*/ - let counter = counter.clone(); - let limiter = limiter.clone(); - let running = running.clone(); - let debug_mutex = debug_mutex.clone(); - handles.push(tokio::spawn(async move { - let mut running_guard = running.lock().await; - running_guard.insert(*id); - drop(running_guard); - let limiter = limiter.acquire().await.unwrap(); - //let mut buf = Vec::new(); - 'repeat: for _i in 0..1 { - let stream = TcpStream::connect(connect_to).await.unwrap(); - let mut stream = - Framed::new(stream, crate::http::HttpClientCodec::new()); - /*let mut skip_recv = false; - for (direction, data) in records { - match direction { - Direction::ClientToServer => { - skip_recv = false; - println!("[CLT] ({id}) >> {}", data.len()); - stream.write_all(data).await.unwrap(); - } - Direction::ServerToClient => { - if skip_recv { + if enable_early_data { + config.enable_early_data = true; + } else { + config.resumption = Resumption::disabled(); + } + config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); + let config = Arc::new(config); + for _i in 0..repeat { + let mut handles = Vec::new(); + for (conn_id, (server_name, records)) in records.iter() { + let connector = TlsConnector::from(config.clone()); + let counter = counter.clone(); + let limiter = limiter.clone(); + let running = running.clone(); + let dummy_bytes = dummy_bytes.clone(); + handles.push(tokio::spawn(async move { + let mut running_guard = running.lock().await; + running_guard.insert(*conn_id); + drop(running_guard); + let limiter = limiter.acquire().await.unwrap(); + let server_name = + ServerName::try_from(String::from_utf8(server_name.clone()).unwrap()) + .unwrap(); + let stream = TcpStream::connect(connect_to).await.unwrap(); + let stream = connector + .connect(server_name.clone(), stream) + .await + .unwrap(); + let mut stream = crate::codec::StreamCodec::new(stream); + for (direction, reqs) in ResponseStreamer::new(records.iter()) { + match direction { + Direction::ClientToServer => { + for (req_id, len) in reqs { + //println!("[CLT] ({conn_id}) >> {}", len); + let mut data = dummy_bytes[0..len as usize].to_vec(); + data[0..4].copy_from_slice(&(len as u32).to_be_bytes()); + data[4..6].copy_from_slice(&(*conn_id as u16).to_be_bytes()); + data[6..8].copy_from_slice(&(req_id as u16).to_be_bytes()); + match tokio::time::timeout(TIMEOUT, async { + stream.get_mut().write_all(&data).await.unwrap(); + }) + .await + { + Ok(_v) => {} + Err(_e) => { + println!("client timeout {conn_id} (sending)"); continue; } - println!("[CLT] ({id}) << {}", data.len()); - //let mut buf = Vec::new(); - //stream.read_buf(&mut buf).await.ok(); - //let mut buf = vec![0; data.len().saturating_sub(50).max(1)]; - let mut buf = vec![0; data.len()]; - match tokio::time::timeout( - std::time::Duration::from_millis(500), - stream.readable(), - ) - .await - { - Ok(r) => { - r.unwrap(); - } - Err(_) => { - println!("[CLT] timeout recv ({id})"); - break; - } - } - // TODO utiliser crate::http ici - match tokio::time::timeout( - std::time::Duration::from_millis(500), - stream.read_exact(&mut buf), - ) - .await - { - Ok(r) => { - r.unwrap(); - } - Err(_) => { - println!("[CLT] skip recv ({id})"); - skip_recv = true; - } - } - } - } - }*/ - for (direction, data_list) in ResponseStreamer::new(records.iter()) { - match direction { - Direction::ClientToServer => { - for data in data_list.into_iter() { - if debug { - //println!("[CLT] ({id}) >> {}", str::from_utf8(&data[..data.len().min(255)]).unwrap()); - println!("[CLT] ({id}) >> {}", data.len()); - } - //stream.get_mut().write_all(data).await.unwrap(); - match tokio::time::timeout( - TIMEOUT, - stream.get_mut().write_all(data), - ) - .await - { - Ok(v) => v.unwrap(), - Err(_e) => { - println!("client timeout {id} (sending)"); - continue 'repeat; - } - } - } - } - Direction::ServerToClient => { - let total_len: usize = - data_list.iter().map(|data| data.len()).sum::(); - let reduced_len = - total_len.saturating_sub(160 * data_list.len()).max(1); - let mut total_recv = 0; - if debug { - println!("[CLT] ({id}) << {total_len}"); - } - //let mut buf = Vec::new(); - //stream.read_buf(&mut buf).await.ok(); - //let mut buf = vec![0; data.len().saturating_sub(50).max(1)]; - let mut resp = Vec::new(); - while total_recv < reduced_len { - resp = - match tokio::time::timeout(TIMEOUT, stream.next()) - .await - { - Ok(None) => break, - Ok(Some(v)) => v.unwrap(), - Err(_e) => { - // TODO fix - println!( - "client timeout {}: {} / {}", - id, total_recv, total_len - ); - //print_bin(data); - break 'repeat; - } - }; - total_recv += resp.len(); - /*if resp.len() != data.len() { - let guard = debug_mutex.lock().await; - println!("RECV NOT ENOUGH {} / {}", resp.len(), data.len()); - if resp.len() < 1000 && data.len() < 1000 { - //print_bin(&resp); - //println!("WANTED"); - //print_bin(data); - } - std::mem::drop(guard); - }*/ - //print_bin(&resp); - //let resp = stream.next().await.unwrap().unwrap(); - //dbg!(resp.len()); - //crate::http::decode_http(&mut buf, &mut stream).await; - //buf.clear(); - } - if total_recv < reduced_len { - println!( - "({}) RECV NOT ENOUGH {} / {}", - id, total_recv, total_len - ); - if resp.len() < 1024 { - print_bin(&resp); - } - } else if debug { - println!("[CLT] ({id}) << {total_len} OK"); - } } } } - //stream.get_mut().shutdown().await.unwrap(); - tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown()) - .await - .unwrap() - .unwrap(); - let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); - println!("Client: {} / {}", cnt + 1, total); + Direction::ServerToClient => { + let expected_total_len = + reqs.iter().map(|(_req_id, len)| *len).sum::(); + let mut total_recv = 0; + //println!("[CLT] ({conn_id}) << {}", expected_total_len); + while total_recv < expected_total_len { + let resp = + match tokio::time::timeout(TIMEOUT, stream.next()).await { + Ok(v) => v.unwrap(), + Err(_e) => { + // TODO fix + println!( + "client timeout {}: {} / {}", + conn_id, total_recv, expected_total_len + ); + //print_bin(data); + break; + } + }; + total_recv += resp.len() as u64; + } + /*if total_recv > total_len { + println!("received too much {}: {} / {}", id, total_recv, total_len); + }*/ + } } - drop(limiter); - let mut running_guard = running.lock().await; - running_guard.remove(id); - drop(running_guard); - })); - //tokio::time::sleep(std::time::Duration::from_millis(500)).await; - } + } + //stream.get_mut().shutdown().await.unwrap(); + //println!("Client shutdown"); + tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown()) + .await + .unwrap() + .unwrap(); + let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + println!("Client: {} / {}", cnt + 1, total); + drop(limiter); + let mut running_guard = running.lock().await; + running_guard.remove(conn_id); + drop(running_guard); + })); + //tokio::time::sleep(std::time::Duration::from_millis(500)).await; + } - for handle in handles { - handle.await.unwrap(); - } + for handle in handles { + handle.await.unwrap(); + } + } + } else { + for _i in 0..repeat { + let mut handles = Vec::new(); + for (conn_id, (_server_name, records)) in records.iter() { + let counter = counter.clone(); + let limiter = limiter.clone(); + let running = running.clone(); + let dummy_bytes = dummy_bytes.clone(); + handles.push(tokio::spawn(async move { + let mut running_guard = running.lock().await; + running_guard.insert(*conn_id); + drop(running_guard); + let limiter = limiter.acquire().await.unwrap(); + let stream = TcpStream::connect(connect_to).await.unwrap(); + let mut stream = crate::codec::StreamCodec::new(stream); + for (direction, reqs) in ResponseStreamer::new(records.iter()) { + match direction { + Direction::ClientToServer => { + for (req_id, len) in reqs { + let mut data = dummy_bytes[0..len as usize].to_vec(); + data[0..4].copy_from_slice(&(len as u32).to_be_bytes()); + data[4..6].copy_from_slice(&(*conn_id as u16).to_be_bytes()); + data[6..8].copy_from_slice(&(req_id as u16).to_be_bytes()); + //println!("[CLT] ({conn_id}) >> {}", len); + match tokio::time::timeout(TIMEOUT, async { + stream.get_mut().write_all(&data).await.unwrap(); + }) + .await + { + Ok(_v) => {} + Err(_e) => { + println!("client timeout {conn_id} (sending)"); + continue; + } + } + } + } + Direction::ServerToClient => { + let expected_total_len = + reqs.iter().map(|(_req_id, len)| *len).sum::(); + let mut total_recv = 0; + //println!("[CLT] ({conn_id}) << {}", expected_total_len); + while total_recv < expected_total_len { + let resp = + match tokio::time::timeout(TIMEOUT, stream.next()).await { + Ok(v) => v.unwrap(), + Err(_e) => { + // TODO fix + println!( + "client timeout {}: {} / {}", + conn_id, total_recv, expected_total_len + ); + //print_bin(data); + break; + } + }; + total_recv += resp.len() as u64; + } + /*if total_recv > total_len { + println!("received too much {}: {} / {}", id, total_recv, total_len); + }*/ + } + } + } + //stream.get_mut().shutdown().await.unwrap(); + //println!("Client shutdown"); + tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown()) + .await + .unwrap() + .unwrap(); + let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); + println!("Client: {} / {}", cnt + 1, total); + drop(limiter); + let mut running_guard = running.lock().await; + running_guard.remove(conn_id); + drop(running_guard); + })); + //tokio::time::sleep(std::time::Duration::from_millis(500)).await; + } + + for handle in handles { + handle.await.unwrap(); } } } println!("Unfinished: {:?}", running.lock().await); - std::process::exit(0); } diff --git a/src/codec.rs b/src/codec.rs new file mode 100644 index 0000000..e4d37f4 --- /dev/null +++ b/src/codec.rs @@ -0,0 +1,40 @@ +use tokio::io::{AsyncRead, AsyncReadExt}; + +pub struct StreamCodec { + stream: S, +} + +impl StreamCodec { + pub fn new(stream: S) -> Self { + Self { stream } + } + + pub async fn next(&mut self) -> Result, std::io::Error> { + let mut buf = vec![0; 8]; + self.stream.read_exact(&mut buf).await?; + let expected_len = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize; + if expected_len < 8 || expected_len > 8 * 1024 * 1024 { + return Err(std::io::ErrorKind::InvalidData.into()); + } + buf.resize(expected_len, 0); + self.stream.read_exact(&mut buf[8..expected_len]).await?; + Ok(buf) + } + + pub fn get_mut(&mut self) -> &mut S { + &mut self.stream + } +} + +/*#[cfg(test)] +mod test { + use super::*; + + use tokio_util::{bytes::BytesMut, codec::Framed}; + + #[test] + fn test_decode() { + let stream = futures_util::stream::iter([BytesMut::fr&[0u8]]); + let stream = Framed::new(stream, CustomCodec::new()); + } +}*/ diff --git a/src/http.rs b/src/http.rs deleted file mode 100644 index 00b20f9..0000000 --- a/src/http.rs +++ /dev/null @@ -1,397 +0,0 @@ -use regex::bytes::Regex; -use std::{pin::Pin, sync::LazyLock}; -use tokio::io::{AsyncRead, AsyncReadExt}; -use tokio_util::{ - bytes::BytesMut, - codec::{Decoder, Encoder}, -}; - -use crate::util::{is_hex, parse_hex}; - -static REGEX_CONTENT_LENGTH: LazyLock = - LazyLock::new(|| Regex::new(r#"[cC]ontent-[lL]ength: *(\d+)\r\n"#).unwrap()); -static REGEX_CHUNKED: LazyLock = - LazyLock::new(|| Regex::new(r#"[tT]ransfer-[eE]ncoding: *[cC]hunked\r\n"#).unwrap()); - -/*pin_project! { - pub struct Framer { - #[pin] - pub stream: S, - codec: C, - buf: BytesMut, - } -} - -impl Framer { - pub fn new(stream: S, codec: C) -> Self { - Self { - stream, - codec, - buf: BytesMut::new(), - } - } - - pub async fn next(&mut self) -> Option> { - self.stream.read_buf(&mut self.buf).await.unwrap(); - None - } -}*/ - -pub struct HttpClientCodec { - buf: Vec, -} - -impl HttpClientCodec { - pub fn new() -> Self { - Self { buf: Vec::new() } - } -} - -impl Decoder for HttpClientCodec { - type Item = Vec; - type Error = std::io::Error; - fn decode( - &mut self, - src: &mut tokio_util::bytes::BytesMut, - ) -> Result, Self::Error> { - self.buf.extend_from_slice(&src); - src.clear(); - let src = &mut self.buf; - if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") { - end_index += 4; - if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) { - // Content-Length: simple body - if let Some(content_length) = captures.get(1) { - // Read body - let content_length: usize = str::from_utf8(content_length.as_bytes()) - .unwrap() - .parse() - .unwrap(); - if src.len() >= end_index + content_length { - let remaining = src.split_off(end_index + content_length); - let out = src.to_vec(); - *src = remaining; - Ok(Some(out)) - } else { - //dbg!("Not enough data"); - Ok(None) - } - } else { - // Invalid Content-Length - Err(std::io::ErrorKind::InvalidData.into()) - } - } else if REGEX_CHUNKED.is_match(&src[0..end_index]) { - // Chunked body - let mut content = &src[end_index..]; - let mut total_len = end_index; - loop { - if let Some(len_end_index) = memchr::memmem::find(content, b"\r\n") { - let len_slice = &content[0..len_end_index]; - if len_end_index < 8 && is_hex(len_slice) { - let chunk_len = parse_hex(len_slice) as usize; - if content.len() >= len_end_index + chunk_len + 4 { - total_len += len_end_index + chunk_len + 4; - // Should we check the ending CRLF? - if chunk_len == 0 { - let remaining = src.split_off(total_len); - let out = src.to_vec(); - *src = remaining; - return Ok(Some(out)); - } - // else, wait for the next chunk - content = &content[len_end_index + chunk_len + 4..]; - } else { - // Not enough data - return Ok(None); - } - } else { - // Invalid chunk length - return Err(std::io::ErrorKind::InvalidData.into()); - } - } else { - // Not enough data - return Ok(None); - } - } - } else { - // Header ended without Content-Type nor chunks => no body - let remaining = src.split_off(end_index); - let out = src.to_vec(); - *src = remaining; - Ok(Some(out)) - } - } else { - //dbg!("Unfinished header"); - Ok(None) - } - - /*self.buf.extend_from_slice(&src); - src.clear(); - let src = &mut self.buf; - if self.chunked { - if let Some(len_end_index) = memchr::memmem::find(src, b"\r\n") { - let len_slice = &src[0..len_end_index]; - if len_end_index < 8 && is_hex(len_slice) { - let chunk_len = parse_hex(len_slice) as usize; - if src.len() >= len_end_index + chunk_len + 4 { - // Should we check the ending CRLF? - if chunk_len == 0 { - self.chunked = false; - } - let remaining = src.split_off(len_end_index+chunk_len+4); - let out = src.to_vec(); - *src = remaining; - Ok(Some(out)) - } else { - // Not enough data - Ok(None) - } - } else { - // Invalid chunk length - Err(std::io::ErrorKind::InvalidData.into()) - } - } else { - // Not enough data - Ok(None) - } - } else { - if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") { - end_index += 4; - if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) { - if let Some(content_length) = captures.get(1) { - // Read body - let content_length: usize = str::from_utf8(content_length.as_bytes()) - .unwrap() - .parse() - .unwrap(); - if src.len() >= end_index + content_length { - if REGEX_CHUNKED.is_match(&src[0..end_index]) { - self.chunked = true; - } - //dbg!(content_length); - let remaining = src.split_off(end_index + content_length); - let out = src.to_vec(); - *src = remaining; - Ok(Some(out)) - } else { - //dbg!("Not enough data"); - Ok(None) - } - } else { - // Invalid Content-Length - Err(std::io::ErrorKind::InvalidData.into()) - } - } else { - // Header ended without Content-Type => no body - let remaining = src.split_off(end_index); - let out = src.to_vec(); - *src = remaining; - Ok(Some(out)) - } - } else { - //dbg!("Unfinished header"); - Ok(None) - } - }*/ - - /*if let Some(start_index) = memchr::memmem::find(src, b"HTTP") { - if start_index != 0 { - dbg!(start_index); - if start_index == 529 { - println!("{src:?}"); - } - } - let src2 = &src[start_index..]; - if let Some(mut end_index) = memchr::memmem::find(src2, b"\r\n\r\n") { - end_index += 4; - if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src2) { - if let Some(content_length) = captures.get(1) { - // Read body - let content_length: usize = str::from_utf8(content_length.as_bytes()) - .unwrap() - .parse() - .unwrap(); - if src2.len() >= end_index + content_length { - if src2.len() > end_index + content_length { - dbg!(src2.len(), end_index + content_length); - println!("{src2:?}"); - std::process::exit(1); - } - //dbg!(content_length); - let out = src2.to_vec(); - src.clear(); - Ok(Some(out)) - } else { - //dbg!("Not enough data"); - Ok(None) - } - } else { - // Invalid Content-Length - Err(std::io::ErrorKind::InvalidData.into()) - } - } else { - // Header ended without Content-Type => no body - let out = src2.to_vec(); - src.clear(); - Ok(Some(out)) - } - } else { - //dbg!("Unfinished header"); - Ok(None) - } - } else { - //dbg!("Unstarted header"); - Ok(None) - }*/ - } -} - -impl Encoder> for HttpClientCodec { - type Error = std::io::Error; - fn encode( - &mut self, - _item: Vec, - _dst: &mut tokio_util::bytes::BytesMut, - ) -> Result<(), Self::Error> { - Ok(()) - } -} - -pub struct HttpServerCodec { - buf: Vec, -} - -impl HttpServerCodec { - pub fn new() -> Self { - Self { buf: Vec::new() } - } -} - -impl Decoder for HttpServerCodec { - type Item = Vec; - type Error = std::io::Error; - fn decode( - &mut self, - src: &mut tokio_util::bytes::BytesMut, - ) -> Result, Self::Error> { - self.buf.extend_from_slice(&src); - src.clear(); - let src = &mut self.buf; - if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") { - end_index += 4; - if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) { - // Content-Length: simple body - if let Some(content_length) = captures.get(1) { - // Read body - let content_length: usize = str::from_utf8(content_length.as_bytes()) - .unwrap() - .parse() - .unwrap(); - if src.len() >= end_index + content_length { - let remaining = src.split_off(end_index + content_length); - let out = src.to_vec(); - *src = remaining; - Ok(Some(out)) - } else { - //dbg!("Not enough data"); - Ok(None) - } - } else { - // Invalid Content-Length - Err(std::io::ErrorKind::InvalidData.into()) - } - } else if REGEX_CHUNKED.is_match(&src[0..end_index]) { - // Chunked body - let mut content = &src[end_index..]; - let mut total_len = end_index; - loop { - if let Some(len_end_index) = memchr::memmem::find(content, b"\r\n") { - let len_slice = &content[0..len_end_index]; - if len_end_index < 8 && is_hex(len_slice) { - let chunk_len = parse_hex(len_slice) as usize; - if content.len() >= len_end_index + chunk_len + 4 { - total_len += len_end_index + chunk_len + 4; - // Should we check the ending CRLF? - if chunk_len == 0 { - let remaining = src.split_off(total_len); - let out = src.to_vec(); - *src = remaining; - return Ok(Some(out)); - } - // else, wait for the next chunk - content = &content[len_end_index + chunk_len + 4..]; - } else { - // Not enough data - return Ok(None); - } - } else { - // Invalid chunk length - return Err(std::io::ErrorKind::InvalidData.into()); - } - } else { - // Not enough data - return Ok(None); - } - } - } else { - // Header ended without Content-Type nor chunks => no body - let remaining = src.split_off(end_index); - let out = src.to_vec(); - *src = remaining; - Ok(Some(out)) - } - } else { - //dbg!("Unfinished header"); - Ok(None) - } - - /*self.buf.extend_from_slice(&src); - src.clear(); - let src = &mut self.buf; - if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") { - end_index += 4; - if let Some(captures) = REGEX_CONTENT_LENGTH.captures(&src[0..end_index]) { - if let Some(content_length) = captures.get(1) { - // Read body - let content_length: usize = str::from_utf8(content_length.as_bytes()) - .unwrap() - .parse() - .unwrap(); - if src.len() >= end_index + content_length { - //dbg!(content_length); - let remaining = src.split_off(end_index + content_length); - let out = src.to_vec(); - *src = remaining; - Ok(Some(out)) - } else { - //dbg!("Not enough data"); - Ok(None) - } - } else { - // Invalid Content-Length - Err(std::io::ErrorKind::InvalidData.into()) - } - } else { - // Header ended without Content-Type => no body - let remaining = src.split_off(end_index); - let out = src.to_vec(); - *src = remaining; - Ok(Some(out)) - } - } else { - //dbg!("Unfinished header"); - Ok(None) - }*/ - } -} - -impl Encoder> for HttpServerCodec { - type Error = std::io::Error; - fn encode( - &mut self, - _item: Vec, - _dst: &mut tokio_util::bytes::BytesMut, - ) -> Result<(), Self::Error> { - Ok(()) - } -} diff --git a/src/main.rs b/src/main.rs index c331794..08d39aa 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ #![feature(ascii_char)] mod client; -mod http; +mod codec; mod record; mod server; mod util; @@ -10,8 +10,6 @@ use record::Records; use argp::FromArgs; use static_cell::StaticCell; -use tokio::sync::oneshot; -use tokio_rustls::rustls::crypto::CryptoProvider; /// Play recorded requests and responses #[derive(FromArgs)] @@ -26,11 +24,14 @@ struct Opt { #[derive(FromArgs)] #[argp(subcommand)] enum Subcommand { - /// Replay from records - Play(OptPlay), + /// Replay from records (client) + Client(OptClient), + /// Replay from records (server) + Server(OptServer), /// Print records Print(OptPrint), /// Record traffic + #[cfg(feature = "record")] Record(OptRecord), /// Remove record Remove(OptRemove), @@ -40,32 +41,44 @@ enum Subcommand { /// Replay from records #[derive(FromArgs)] -#[argp(subcommand, name = "play")] -struct OptPlay { +#[argp(subcommand, name = "client")] +struct OptClient { /// Connect to address #[argp(positional)] - forward_addr: String, + connect_addr: String, /// Connect to port #[argp(positional)] - forward_port: u16, + connect_port: u16, + /// Whether to use TLS + #[argp(switch, long = "tls")] + tls: bool, + /// Repeat N times + #[argp(option, short = 'r', default = "1")] + repeat: u32, + /// UDP end notification will be sent to this address:port + #[argp(option, short = 'n')] + notify_addr: Option, + /// Only play this record + #[argp(option)] + record: Option, + /// Print debug info + #[argp(switch, short = 'd')] + debug: bool, +} + +/// Replay from records +#[derive(FromArgs)] +#[argp(subcommand, name = "server")] +struct OptServer { /// Listen to port #[argp(positional)] listen_port: u16, /// Path to PEM certificates and keys #[argp(positional)] certs: String, - /// Where to use TLS - #[argp(positional)] - tls: String, - /// Repeat N times - #[argp(option, short = 'r', default = "1")] - repeat: u32, - /// Only play this record - #[argp(option)] - record: Option, - /// Only run these parts - #[argp(option, default = "String::from(\"both\")")] - run: String, + /// Whether to use TLS + #[argp(switch, long = "tls")] + tls: bool, /// Print debug info #[argp(switch, short = 'd')] debug: bool, @@ -75,15 +88,13 @@ struct OptPlay { #[derive(FromArgs)] #[argp(subcommand, name = "print")] struct OptPrint { - /// Print packets - #[argp(switch, short = 'p')] - packets: bool, /// Record number #[argp(option, short = 'n')] number: Option, } /// Record traffic +#[cfg(feature = "record")] #[derive(FromArgs)] #[argp(subcommand, name = "record")] struct OptRecord {} @@ -108,21 +119,6 @@ struct OptRemove { packet_number: usize, } -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -enum RunMode { - Client, - Server, - Both, -} - -#[derive(Clone, Copy, Debug, Eq, PartialEq)] -enum TlsMode { - None, - Client, - Server, - Both, -} - static RECORDS: StaticCell = StaticCell::new(); #[tokio::main] @@ -131,140 +127,49 @@ async fn main() { let opt: Opt = argp::parse_args_or_exit(argp::DEFAULT); match opt.subcommand { - Subcommand::Play(subopt) => { - let tls_mode = match subopt.tls.as_str() { - "none" => TlsMode::None, - "client" => TlsMode::Client, - "server" => TlsMode::Server, - "both" => TlsMode::Both, - _ => panic!("TLS mode must be one of none,client,server,both."), - }; - let run_mode = match subopt.run.as_str() { - "client" => RunMode::Client, - "server" => RunMode::Server, - "both" => RunMode::Both, - _ => panic!("run mode must be one of client,server,both."), - }; + Subcommand::Client(subopt) => { let records = RECORDS.init(record::read_record_file(&opt.record_file)); if let Some(only_record) = subopt.record { records.retain(|id, _| *id == only_record); } - let mut ciphers: Option> = None; - let mut kexes: Option> = None; - for (var, val) in std::env::vars() { - match var.as_str() { - "CIPHERS" => ciphers = Some(val.split(',').map(str::to_string).collect()), - "KEXES" => kexes = Some(val.split(',').map(str::to_string).collect()), - _ => {} - } - } - let mut prov = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider(); - if let Some(ciphers) = ciphers { - prov.cipher_suites.clear(); - for cipher in ciphers { - match cipher.as_str() { - "AES_256_GCM_SHA384" => prov - .cipher_suites - .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_AES_256_GCM_SHA384), - "AES_128_GCM_SHA256" => prov - .cipher_suites - .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_AES_128_GCM_SHA256), - "CHACHA20_POLY1305_SHA256" => prov - .cipher_suites - .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256), - "ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov - .cipher_suites - .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384), - "ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov - .cipher_suites - .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), - "ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov - .cipher_suites - .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), - "ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov - .cipher_suites - .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384), - "ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov - .cipher_suites - .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), - "ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov - .cipher_suites - .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), - other => { - println!("Unknown cipher `{other}`") - } - } - } - } - if let Some(kexes) = kexes { - prov.kx_groups.clear(); - for kex in kexes { - match kex.as_str() { - "X25519" => prov - .kx_groups - .push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::X25519), - "SECP256R1" => prov - .kx_groups - .push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP256R1), - "SECP384R1" => prov - .kx_groups - .push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP384R1), - "X25519MLKEM768" => prov.kx_groups.push( - tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::X25519MLKEM768, - ), - "SECP256R1MLKEM768" => prov.kx_groups.push( - tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP256R1MLKEM768, - ), - "MLKEM768" => prov - .kx_groups - .push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::MLKEM768), - other => { - println!("Unknown kex `{other}`") - } - } - } - } - CryptoProvider::install_default(prov).unwrap(); + util::init_provider(); - let (sync_sender, sync_receiver) = oneshot::channel(); //console_subscriber::init(); - let client = tokio::spawn({ - let records = &*records; - async move { - if run_mode == RunMode::Both || run_mode == RunMode::Client { - client::play( - records, - tls_mode, - (subopt.forward_addr, subopt.forward_port), - sync_receiver, - subopt.repeat, - subopt.debug, - ) - .await; - } else { - std::future::pending().await - } - } - }); - if run_mode == RunMode::Both || run_mode == RunMode::Server { - server::play( - records, - tls_mode, - &subopt.certs, - ("0.0.0.0", subopt.listen_port), - sync_sender, - subopt.debug, - ) - .await; + client::play( + records, + subopt.tls, + (subopt.connect_addr, subopt.connect_port), + subopt.repeat, + subopt.debug, + ) + .await; + if let Some(notify_addr) = subopt.notify_addr { + let socket = std::net::UdpSocket::bind("0.0.0.0:48567").unwrap(); + socket.send_to(b"done", ¬ify_addr).unwrap(); } - client.await.unwrap(); + } + Subcommand::Server(subopt) => { + let records = RECORDS.init(record::read_record_file(&opt.record_file)); + + util::init_provider(); + + //console_subscriber::init(); + server::play( + records, + subopt.tls, + &subopt.certs, + ("0.0.0.0", subopt.listen_port), + subopt.debug, + ) + .await; } Subcommand::Print(subopt) => { let records = record::read_record_file(&opt.record_file); - record::print_records(&records, subopt.packets, subopt.number); + record::print_records(&records, subopt.number); } + #[cfg(feature = "record")] Subcommand::Record(_subopt) => { record::make_record(&opt.record_file); } diff --git a/src/record.rs b/src/record.rs index a5c4a3f..37c49bd 100644 --- a/src/record.rs +++ b/src/record.rs @@ -4,12 +4,10 @@ use std::{ sync::mpsc::{Receiver, Sender, channel}, }; -use crate::util::{ResponseStreamer, print_bin}; - const CLIENT_TO_SERVER: u8 = b'C'; const SERVER_TO_CLIENT: u8 = b'S'; -pub type Records = BTreeMap, Vec<(Direction, Vec)>)>; +pub type Records = BTreeMap, Vec<(u64, Direction, u64)>)>; #[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] pub enum Direction { @@ -34,7 +32,7 @@ fn write_record( direction: Direction, conn_id: u64, server_name: &str, - data: &[u8], + len: u64, ) { let server_name = server_name.as_bytes(); file.write_all(&[match direction { @@ -45,8 +43,7 @@ fn write_record( file.write_all(&conn_id.to_be_bytes()).unwrap(); file.write_all(&[server_name.len() as u8]).unwrap(); file.write_all(server_name).unwrap(); - file.write_all(&(data.len() as u64).to_be_bytes()).unwrap(); - file.write_all(&data).unwrap(); + file.write_all(&len.to_be_bytes()).unwrap(); file.flush().unwrap(); } @@ -78,17 +75,25 @@ impl Recorder { let Some(server_name) = server_name else { continue; }; - write_record(&mut self.file, direction, conn_id, &server_name, &data); + write_record( + &mut self.file, + direction, + conn_id, + &server_name, + data.len() as u64, + ); } } } +#[cfg(feature = "record")] #[derive(Clone)] struct Handler { sender: Sender<(u64, Option, Direction, Vec)>, server_name: Option, } +#[cfg(feature = "record")] impl sslrelay::HandlerCallbacks for Handler { // DownStream non blocking callback fn ds_nb_callback(&self, in_data: Vec, conn_id: u64) { @@ -129,6 +134,7 @@ impl sslrelay::HandlerCallbacks for Handler { } } +#[cfg(feature = "record")] pub fn make_record(path: &str) { let (mut recorder, sender) = Recorder::new(path); let mut relay = sslrelay::SSLRelay::new( @@ -161,7 +167,7 @@ pub fn make_record(path: &str) { pub fn read_record_file(path: &str) -> Records { let mut file = std::fs::OpenOptions::new().read(true).open(path).unwrap(); - let mut records = BTreeMap::, Vec<(Direction, Vec)>)>::new(); + let mut records = BTreeMap::, Vec<(u64, Direction, u64)>)>::new(); loop { let mut direction = [0; 1]; if file.read(&mut direction).unwrap() != 1 { @@ -202,84 +208,39 @@ pub fn read_record_file(path: &str) -> Records { println!("Error: len too large {len}. stop."); break; } - let mut buf = vec![0; len as usize]; - if file.read(&mut buf).unwrap() != len as usize { - println!("Error: incomplete data. stop."); - break; - } - - // Replace URL with unique id, to allow for better tracking by making each request unique. - // (proxy may modify some headers, but not the URL) - let mut insert_id = |req_id| { - if direction == Direction::ClientToServer { - let mut spaces = buf - .iter() - .enumerate() - .filter_map(|(i, c)| if *c == b' ' { Some(i) } else { None }); - let s1 = spaces.next().unwrap(); - let s2 = spaces.next().unwrap(); - let new_url = format!("/{conn_id}-{req_id}/"); - if s2 - s1 - 1 < new_url.len() { - // Not optimal but good enough - let mut new_buf = Vec::new(); - new_buf.extend_from_slice(&buf[0..s1 + 1]); - new_buf.extend_from_slice(new_url.as_bytes()); - new_buf.extend_from_slice(&buf[s2..]); - buf = new_buf; - } else { - buf[s1 + 1..s2][0..new_url.len()].copy_from_slice(new_url.as_bytes()); - } - } - }; match records.entry(conn_id) { btree_map::Entry::Occupied(mut entry) => { - (insert_id)(entry.get().1.len()); - entry.get_mut().1.push((direction, buf)); + let req_id = entry.get().1.len() as u64; + entry.get_mut().1.push((req_id, direction, len)); } btree_map::Entry::Vacant(entry) => { - (insert_id)(0); - entry.insert((server_name, vec![(direction, buf)])); + let req_id = 0; + entry.insert((server_name, vec![(req_id, direction, len)])); } } } records } -pub fn print_records(records: &Records, print_packets: bool, number: Option) { - for (id, (server_name, records)) in records { +pub fn print_records(records: &Records, number: Option) { + for (conn_id, (server_name, records)) in records { if let Some(number) = number - && number != *id + && number != *conn_id { continue; } let server_name = str::from_utf8(server_name.as_slice()).unwrap(); - println!("{id} {server_name}"); - for (direction, data) in records { + println!("{conn_id} {server_name}"); + for (req_id, direction, len) in records { match direction { Direction::ClientToServer => { - println!(" >> {}", data.len()); + println!(" ({req_id}) >> {len}"); } Direction::ServerToClient => { - println!(" << {}", data.len()); + println!(" ({req_id}) << {len}"); } } - if print_packets { - /*let data_tr = if data.len() >= 256 && *direction == Direction::ServerToClient { - &data[0..256] - } else { - data.as_slice() - }; - if let Ok(data_tr) = str::from_utf8(data_tr) { - println!(" {data_tr:?}") - } else { - println!(" {data_tr:?}") - } - if let Some(header_end) = memchr::memmem::find(data, b"\r\n\r\n") { - println!(" --> body len: {}", data.len() - header_end - 4); - }*/ - print_bin(&data[0..data.len().min(8192)]); - } } } } @@ -292,7 +253,13 @@ pub fn make_test_record(path: &str) { .open(path) .unwrap(); for (conn_id, server_name, direction, data) in TEST_RECORD { - write_record(&mut file, *direction, *conn_id, *server_name, *data); + write_record( + &mut file, + *direction, + *conn_id, + server_name, + data.len() as u64, + ); } } @@ -311,9 +278,9 @@ pub fn remove_record( .unwrap(); for (conn_id, (server_name, packets)) in records.into_iter() { let server_name = String::from_utf8(server_name).unwrap(); - for (packet_id, (direction, data)) in packets.into_iter().enumerate() { + for (packet_id, (_req_id, direction, len)) in packets.into_iter().enumerate() { if conn_id != record_to_remove || packet_id != packet_to_remove { - write_record(&mut output_file, direction, conn_id, &server_name, &data); + write_record(&mut output_file, direction, conn_id, &server_name, len); } } } diff --git a/src/server.rs b/src/server.rs index 9a7467f..2fd6c7f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,12 +1,7 @@ -use crate::{ - TlsMode, - record::{Direction, Records}, - util::print_bin, -}; +use crate::record::{Direction, Records}; -use futures_util::stream::StreamExt; use std::{collections::HashMap, sync::Arc}; -use tokio::{io::AsyncWriteExt, net::TcpListener, sync::oneshot}; +use tokio::{io::AsyncWriteExt, net::TcpListener}; use tokio_rustls::rustls::{ pki_types::{ CertificateDer, PrivateKeyDer, @@ -15,414 +10,262 @@ use tokio_rustls::rustls::{ server::ResolvesServerCertUsingSni, sign::CertifiedKey, }; -use tokio_util::codec::Framed; use x509_parser::prelude::GeneralName; pub async fn play( records: &'static Records, - tls_mode: TlsMode, + use_tls: bool, cert_path: &str, listen_addr: (&str, u16), - sync_sender: oneshot::Sender<()>, - debug: bool, + _debug: bool, ) { let mut response_map = HashMap::new(); - for (id, (server_name, records)) in records.iter() { - let mut hash = None; + for (conn_id, (_server_name, records)) in records.iter() { + let mut last_client_req_id = None; let mut responses = Vec::new(); - for (direction, data) in records { + for (req_id, direction, len) in records { match direction { Direction::ClientToServer => { - if let Some(hash) = hash + if let Some(last_client_req_id) = last_client_req_id && !responses.is_empty() { - response_map.insert((server_name.to_vec(), hash), (id, responses, false)); + response_map.insert((*conn_id, last_client_req_id), (responses, false)); responses = Vec::new(); } - let mut slashes = data - .iter() - .enumerate() - .filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None }); - let s1 = slashes.next(); - let s2 = slashes.next(); - hash = Some(if let (Some(s1), Some(s2)) = (s1, s2) { - data[s1 + 1..s2].to_vec() - } else { - panic!("Did not find URL: {:?}", &data[0..256]); - tlsh::hash_buf(data) - .map_or_else(|_| data.clone(), |h| h.to_string().into_bytes()) - }); + last_client_req_id = Some(*req_id); } Direction::ServerToClient => { - responses.push(data); + responses.push((*req_id, *len)); } } } - if let Some(hash) = hash { + if let Some(last_client_req_id) = last_client_req_id { if !responses.is_empty() { - response_map.insert((server_name.to_vec(), hash), (id, responses, true)); - } else { - response_map - .get_mut(&(server_name.to_vec(), hash)) - .unwrap() - .2 = true; + response_map.insert((*conn_id, last_client_req_id), (responses, true)); + } else if let Some(entry) = response_map.get_mut(&(*conn_id, last_client_req_id)) { + entry.1 = true; } } } let response_map = Arc::new(response_map); + let dummy_bytes = Arc::new(vec![0x42u8; 16 * 1024 * 1024]); - match tls_mode { - TlsMode::Both | TlsMode::Server => { - let mut resolver = ResolvesServerCertUsingSni::new(); - let mut config = tokio_rustls::rustls::ServerConfig::builder() - .with_no_client_auth() - .with_cert_resolver(Arc::new(ResolvesServerCertUsingSni::new())); - config.max_early_data_size = 8192; - for file in std::fs::read_dir(cert_path).unwrap_or_else(|e| { - panic!("Cannot read certificate directory `{cert_path}`: {e:?}") - }) { - match file { - Ok(file) => { - if file.file_name().as_encoded_bytes().ends_with(b".crt") { - for section in - <(pem::SectionKind, Vec) as PemObject>::pem_file_iter( - file.path(), - ) + if use_tls { + let mut resolver = ResolvesServerCertUsingSni::new(); + let mut config = tokio_rustls::rustls::ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new(ResolvesServerCertUsingSni::new())); + config.max_early_data_size = 8192; + for file in std::fs::read_dir(cert_path) + .unwrap_or_else(|e| panic!("Cannot read certificate directory `{cert_path}`: {e:?}")) + { + match file { + Ok(file) => { + if file.file_name().as_encoded_bytes().ends_with(b".crt") { + for section in + <(pem::SectionKind, Vec) as PemObject>::pem_file_iter(file.path()) .unwrap() - { - let (kind, data) = section.unwrap(); - if kind == SectionKind::Certificate { - let (_rem, cert) = - x509_parser::parse_x509_certificate(&data).unwrap(); - if !cert.is_ca() { - //println!("File: {:?}", file.file_name()); - let mut key_path = file.path().to_path_buf(); - key_path.pop(); - let file_name = - file.file_name().to_str().unwrap().to_string(); - let mut key_file_name = - file_name[0..file_name.len() - 4].to_string(); - key_file_name.push_str(".key"); - let key = PrivateKeyDer::from_pem_file( - key_path.join(key_file_name), - ) - .unwrap(); - let key = config - .crypto_provider() - .key_provider - .load_private_key(key) + { + let (kind, data) = section.unwrap(); + if kind == SectionKind::Certificate { + let (_rem, cert) = + x509_parser::parse_x509_certificate(&data).unwrap(); + if !cert.is_ca() { + //println!("File: {:?}", file.file_name()); + let mut key_path = file.path().to_path_buf(); + key_path.pop(); + let file_name = file.file_name().to_str().unwrap().to_string(); + let mut key_file_name = + file_name[0..file_name.len() - 4].to_string(); + key_file_name.push_str(".key"); + let key = + PrivateKeyDer::from_pem_file(key_path.join(key_file_name)) .unwrap(); - // This wants static lifetime... - let cert_key = CertifiedKey::new( - vec![CertificateDer::from_slice(Box::leak( - data.to_vec().into_boxed_slice(), - ))], - key, - ); - for name in cert - .subject_alternative_name() - .unwrap() - .unwrap() - .value - .general_names - .iter() - { - if let GeneralName::DNSName(name) = name { - resolver.add(name, cert_key.clone()).ok(); - } + let key = config + .crypto_provider() + .key_provider + .load_private_key(key) + .unwrap(); + // This wants static lifetime... + let cert_key = CertifiedKey::new( + vec![CertificateDer::from_slice(Box::leak( + data.to_vec().into_boxed_slice(), + ))], + key, + ); + for name in cert + .subject_alternative_name() + .unwrap() + .unwrap() + .value + .general_names + .iter() + { + if let GeneralName::DNSName(name) = name { + resolver.add(name, cert_key.clone()).ok(); } } } } } } - Err(e) => eprintln!("Error listing cert directory: {e:?}"), } + Err(e) => eprintln!("Error listing cert directory: {e:?}"), } + } - // Config requires resolver, keys can be added to resolver, creating a key requires config. WTF!? - // So we have to re-create config. - let mut config = tokio_rustls::rustls::ServerConfig::builder() - .with_no_client_auth() - .with_cert_resolver(Arc::new(resolver)); - config.max_early_data_size = 8192; - config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); - let config = Arc::new(config); + // Config requires resolver, keys can be added to resolver, creating a key requires config. WTF!? + // So we have to re-create config. + let mut config = tokio_rustls::rustls::ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(Arc::new(resolver)); + config.max_early_data_size = 8192; + config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new()); + let config = Arc::new(config); - let listener = TcpListener::bind(listen_addr).await.unwrap(); - sync_sender.send(()).unwrap(); - loop { - let config = config.clone(); - let (stream, _peer_addr) = listener.accept().await.unwrap(); - let acceptor = tokio_rustls::LazyConfigAcceptor::new( - tokio_rustls::rustls::server::Acceptor::default(), - stream, - ); - //let acceptor = acceptor.clone(); - let response_map = response_map.clone(); - /*let fut = async move { - let accepted = acceptor.await.unwrap(); - let server_name = accepted.client_hello().server_name().unwrap().to_string(); - let mut stream = accepted.into_stream(config).await.unwrap(); - let mut req = Vec::new(); - http::decode_http(&mut req, &mut stream).await; - let req_hash = tlsh::hash_buf(&req) - .map_or_else(|_| req.clone(), |h| h.to_string().into_bytes()); - let mut best = None; - for (i_server_name, hash) in response_map.keys() { - if i_server_name != server_name.as_bytes() { + let listener = TcpListener::bind(listen_addr).await.unwrap(); + loop { + let config = config.clone(); + let (stream, _peer_addr) = listener.accept().await.unwrap(); + let acceptor = tokio_rustls::LazyConfigAcceptor::new( + tokio_rustls::rustls::server::Acceptor::default(), + stream, + ); + //let acceptor = acceptor.clone(); + let response_map = response_map.clone(); + let dummy_bytes = dummy_bytes.clone(); + let fut = async move { + let accepted = acceptor.await.unwrap(); + let server_name = accepted + .client_hello() + .server_name() + .unwrap() + .trim_end_matches(".localhost") + .to_string(); + let stream = accepted + .into_stream(config) + .await + .map_err(|e| panic!("{e:?} with name `{server_name}`")) + .unwrap(); + let mut stream = crate::codec::StreamCodec::new(stream); + let mut break_next = false; + //let mut previous = Vec::new(); + loop { + let Ok(req) = + tokio::time::timeout(tokio::time::Duration::from_secs(5), stream.next()) + .await + else { + if break_next { + break; + } else { continue; } - let diff = compare(&req_hash, hash); - if let Some((best_hash, best_diff)) = &mut best { - if diff < *best_diff { - *best_hash = hash; - *best_diff = diff; - } - } else { - best = Some((hash, diff)); - } + }; + let req = req.unwrap(); + if req.len() < 8 { + println!("Invalid request"); + break; } - if let Some((hash, _diff)) = best { - let (id, responses) = response_map - .get(&(server_name.as_bytes().to_vec(), hash.clone())) - .unwrap(); - for &res in responses { - println!("[SRV] response for ({}): {} bytes", id, res.len()); - stream.write_all(res).await.unwrap(); + let _expected_len = u32::from_be_bytes(req[0..4].try_into().unwrap()) as u64; + let conn_id = u16::from_be_bytes(req[4..6].try_into().unwrap()) as u64; + let req_id = u16::from_be_bytes(req[6..8].try_into().unwrap()) as u64; + //println!("REQUEST"); + //print_bin(&req); + //previous = req.clone(); + let stream = stream.get_mut(); + if let Some((responses, last)) = response_map.get(&(conn_id, req_id)) { + //dbg!(id); + for (req_id, len) in responses { + //println!("[SRV] response for ({}): {} bytes", id, res.len()); + let mut data = dummy_bytes[0..*len as usize].to_vec(); + data[0..4].copy_from_slice(&(*len as u32).to_be_bytes()); + data[4..6].copy_from_slice(&(conn_id as u16).to_be_bytes()); + data[6..8].copy_from_slice(&(*req_id as u16).to_be_bytes()); + stream.write_all(&data).await.unwrap(); stream.flush().await.unwrap(); } - } else { - println!("No response found for SNI=`{server_name}`"); - } - stream.shutdown().await.unwrap(); - };*/ - let fut = async move { - let accepted = acceptor.await.unwrap(); - let server_name = accepted - .client_hello() - .server_name() - .unwrap() - .trim_end_matches(".localhost") - .to_string(); - let stream = accepted - .into_stream(config) - .await - .map_err(|e| panic!("{e:?} with name `{server_name}`")) - .unwrap(); - let mut stream = Framed::new(stream, crate::http::HttpServerCodec::new()); - let mut break_next = false; - //let mut previous = Vec::new(); - loop { - let Ok(req) = tokio::time::timeout(tokio::time::Duration::from_secs(1), stream.next()).await else { - if break_next { - break; - } else { - continue; - } - }; - let Some(req) = req else { + if *last { + break_next = true; break; - }; - let req = req.unwrap(); - //println!("REQUEST"); - //print_bin(&req); - let req_hash = { - let mut slashes = req - .iter() - .enumerate() - .filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None }); - let s1 = slashes.next(); - let s2 = slashes.next(); - if let (Some(s1), Some(s2)) = (s1, s2) { - req[s1 + 1..s2].to_vec() - } else { - //println!("Previous: {:?}", &previous); - println!("Did not find URL: {:?}", &req[0..req.len().min(255)]); - tlsh::hash_buf(&req) - .map_or_else(|_| req.clone(), |h| h.to_string().into_bytes()) - } - }; - //previous = req.clone(); - let mut best = None; - for (i_server_name, hash) in response_map.keys() { - if i_server_name != server_name.as_bytes() { - continue; - } - let diff = if &req_hash == hash { - 0 - } else { - compare(&req_hash, hash) - }; - if let Some((best_hash, best_diff)) = &mut best { - if diff < *best_diff { - *best_hash = hash; - *best_diff = diff; - } - } else { - best = Some((hash, diff)); - } - } - let stream = stream.get_mut(); - if let Some((hash, _diff)) = best { - let (id, responses, last) = response_map - .get(&(server_name.as_bytes().to_vec(), hash.clone())) - .unwrap(); - //dbg!(id); - for &res in responses { - //println!("[SRV] response for ({}): {} bytes", id, res.len()); - stream.write_all(res).await.unwrap(); - stream.flush().await.unwrap(); - } - if *last { - break_next = true; - } - } else { - println!("No response found for SNI=`{server_name}`"); - } - } - stream.get_mut().shutdown().await.unwrap(); - }; - tokio::spawn(async move { - fut.await; - }); - } - } - TlsMode::None | TlsMode::Client => { - let listener = TcpListener::bind(listen_addr).await.unwrap_or_else(|e| { - println!("Server: Cannot listen: {e:?}"); - std::process::exit(1) - }); - sync_sender.send(()).unwrap(); - loop { - let (stream, _peer_addr) = listener.accept().await.unwrap(); - let response_map = response_map.clone(); - /*let fut = async move { - println!("[SRV] New task"); - let mut req = Vec::new(); - http::decode_http(&mut req, &mut stream).await; - let req_hash = tlsh::hash_buf(&req) - .map_or_else(|_| req.clone(), |h| h.to_string().into_bytes()); - let mut best = None; - for (i_server_name, hash) in response_map.keys() { - let diff = compare(&req_hash, hash); - if let Some((best_server_name, best_hash, best_diff)) = &mut best { - if diff < *best_diff { - *best_server_name = i_server_name; - *best_hash = hash; - *best_diff = diff; - } - } else { - best = Some((i_server_name, hash, diff)); - } - } - if let Some((server_name, hash, _diff)) = best { - let (id, responses) = response_map - .get(&(server_name.clone(), hash.clone())) - .unwrap(); - for &res in responses { - println!("[SRV] response for ({}): {} bytes", id, res.len()); - stream.write_all(res).await.unwrap(); - stream.flush().await.unwrap(); } } else { - println!("[SRV] No response found"); + println!("No response found for {conn_id}-{req_id} SNI=`{server_name}`"); } - //println!("Server shutdown"); - stream.shutdown().await.unwrap(); - };*/ - let fut = async move { - //println!("[SRV] New task"); - //let mut stream = crate::http::Framer::new(stream, crate::http::HttpServerCodec::new()); - let mut stream = Framed::new(stream, crate::http::HttpServerCodec::new()); - //let mut previous = Vec::new(); - while let Some(req) = stream.next().await { - let req = req.unwrap(); - //println!("REQUEST"); - //print_bin(&req); - //println!("[SRV] << {}", str::from_utf8(&req[..req.len().min(255)]).unwrap()); - let req_hash = { - let mut slashes = req - .iter() - .enumerate() - .filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None }); - let s1 = slashes.next(); - let s2 = slashes.next(); - if let (Some(s1), Some(s2)) = (s1, s2) { - let uniq_id = req[s1 + 1..s2].to_vec(); - if debug { - if let Ok(uniq_id) = str::from_utf8(&uniq_id) { - println!("[SRV] ({uniq_id}) << {}", req.len()); - } - } - uniq_id - } else { - //println!("Previous: {:?}", &previous); - println!("Did not find URL: {:?}", &req[0..req.len().min(255)]); - tlsh::hash_buf(&req) - .map_or_else(|_| req.clone(), |h| h.to_string().into_bytes()) - } - }; - //previous = req.clone(); - let mut best = None; - for (i_server_name, hash) in response_map.keys() { - let diff = compare(&req_hash, hash); - if let Some((best_server_name, best_hash, best_diff)) = &mut best { - if diff < *best_diff { - *best_server_name = i_server_name; - *best_hash = hash; - *best_diff = diff; - } - } else { - best = Some((i_server_name, hash, diff)); - } - } - let stream = stream.get_mut(); - if let Some((server_name, hash, _diff)) = best { - let (id, responses, last) = response_map - .get(&(server_name.clone(), hash.clone())) - .unwrap(); - //dbg!(id); - for &res in responses { - if debug { - println!("[SRV] ({id}) >> {}", res.len()); - //println!("[SRV] response for ({}): {} bytes", id, res.len()); - } - stream.write_all(res).await.unwrap(); - stream.flush().await.unwrap(); - if debug { - println!("[SRV] ({id}) >> {} OK", res.len()); - } - } - if *last { - //break; - } - } else { - println!("[SRV] No response found"); - } - } - //println!("Server shutdown"); - stream.get_mut().shutdown().await.unwrap(); - }; - // Using a variable for the future allows it to be detected by tokio-console - tokio::spawn(async move { - fut.await; - }); - } + } + stream.get_mut().shutdown().await.unwrap(); + }; + tokio::spawn(async move { + fut.await; + }); } - } -} - -fn compare(a: &[u8], b: &[u8]) -> u32 { - if let (Ok(a), Ok(b)) = (str::from_utf8(a), str::from_utf8(b)) { - if let Ok(diff) = tlsh::compare(a, b) { - return diff; - } - } - if a == b { - 0 } else { - a.len().max(b.len()) as u32 + let listener = TcpListener::bind(listen_addr).await.unwrap_or_else(|e| { + println!("Server: Cannot listen: {e:?}"); + std::process::exit(1) + }); + loop { + let (stream, _peer_addr) = listener.accept().await.unwrap(); + let response_map = response_map.clone(); + let dummy_bytes = dummy_bytes.clone(); + let fut = async move { + //println!("[SRV] New task"); + //let mut stream = Framed::new(stream, crate::codec::CustomCodec::new()); + let mut stream = crate::codec::StreamCodec::new(stream); + let mut break_next = false; + //let mut previous = Vec::new(); + loop { + let Ok(req) = + tokio::time::timeout(tokio::time::Duration::from_secs(5), stream.next()) + .await + else { + if break_next { + println!("break timeout"); + break; + } else { + println!("continue"); + continue; + } + }; + let req = req.unwrap(); + if req.len() < 8 { + println!("Invalid request"); + break; + } + let expected_len = u32::from_be_bytes(req[0..4].try_into().unwrap()) as u64; + let conn_id = u16::from_be_bytes(req[4..6].try_into().unwrap()) as u64; + let req_id = u16::from_be_bytes(req[6..8].try_into().unwrap()) as u64; + //println!("[SRV] ({conn_id}) << {expected_len}"); + //println!("REQUEST"); + //print_bin(&req); + //previous = req.clone(); + let stream = stream.get_mut(); + if let Some((responses, last)) = response_map.get(&(conn_id, req_id)) { + //dbg!(id); + for (req_id, len) in responses { + //println!("[SRV] ({conn_id}) >> {len}"); + let mut data = dummy_bytes[0..*len as usize].to_vec(); + data[0..4].copy_from_slice(&(*len as u32).to_be_bytes()); + data[4..6].copy_from_slice(&(conn_id as u16).to_be_bytes()); + data[6..8].copy_from_slice(&(*req_id as u16).to_be_bytes()); + stream.write_all(&data).await.unwrap(); + stream.flush().await.unwrap(); + } + if *last { + break_next = true; + break; + } + } else { + println!("No response found for {conn_id}-{req_id}"); + } + } + //println!("Server shutdown"); + stream.get_mut().shutdown().await.unwrap(); + }; + // Using a variable for the future allows it to be detected by tokio-console + tokio::spawn(async move { + fut.await; + }); + } } } diff --git a/src/util.rs b/src/util.rs index b401eae..1f0e718 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,26 +1,8 @@ -use std::iter::Peekable; - use crate::record::Direction; -fn hex_digit(c: u8) -> u32 { - ((c & !(16 | 32 | 64)) + ((c & 64) >> 6) * 9) as _ -} - -pub fn parse_hex(s: &[u8]) -> u32 { - let mut r = 0; - for i in s.iter() { - r <<= 4; - r |= hex_digit(*i); - } - r -} - -pub fn is_hex(s: &[u8]) -> bool { - s.iter().all(|c| { - let c = *c | 32; - (c >= b'a' && c <= b'f') || (c >= b'0' && c <= b'9') - }) -} +use log::info; +use std::iter::Peekable; +use tokio_rustls::rustls::crypto::CryptoProvider; /// Print ASCII if possible pub fn print_bin(s: &[u8]) { @@ -46,85 +28,407 @@ pub fn print_bin(s: &[u8]) { pub struct ResponseStreamer(Peekable); -impl<'a, I: Iterator> ResponseStreamer { +impl ResponseStreamer { pub fn new(inner: I) -> Self { Self(inner.peekable()) } } -impl<'a, I: Iterator)>> Iterator for ResponseStreamer { - type Item = (&'a Direction, Vec<&'a Vec>); +impl<'a, I: Iterator> Iterator for ResponseStreamer { + type Item = (Direction, Vec<(u64, u64)>); fn next(&mut self) -> Option { - let (direction, first_item) = self.0.next()?; - let mut items = vec![first_item]; - while let Some((item_direction, _item)) = self.0.peek() - && item_direction == direction + let (first_req_id, first_direction, first_len) = self.0.next()?; + let mut items = vec![(*first_req_id, *first_len)]; + while let Some((_req_id, direction, _len)) = self.0.peek() + && direction == first_direction { - items.push(&self.0.next().unwrap().1); + let (req_id, _direction, len) = self.0.next().unwrap(); + items.push((*req_id, *len)); } - Some((direction, items)) + Some((*first_direction, items)) } } -#[cfg(test)] -mod test { - use super::*; - - #[test] - fn test_hex_digit() { - assert_eq!(hex_digit(b'0'), 0); - assert_eq!(hex_digit(b'1'), 1); - assert_eq!(hex_digit(b'2'), 2); - assert_eq!(hex_digit(b'3'), 3); - assert_eq!(hex_digit(b'4'), 4); - assert_eq!(hex_digit(b'5'), 5); - assert_eq!(hex_digit(b'6'), 6); - assert_eq!(hex_digit(b'7'), 7); - assert_eq!(hex_digit(b'8'), 8); - assert_eq!(hex_digit(b'9'), 9); - assert_eq!(hex_digit(b'a'), 10); - assert_eq!(hex_digit(b'b'), 11); - assert_eq!(hex_digit(b'c'), 12); - assert_eq!(hex_digit(b'd'), 13); - assert_eq!(hex_digit(b'e'), 14); - assert_eq!(hex_digit(b'f'), 15); - assert_eq!(hex_digit(b'A'), 10); - assert_eq!(hex_digit(b'B'), 11); - assert_eq!(hex_digit(b'C'), 12); - assert_eq!(hex_digit(b'D'), 13); - assert_eq!(hex_digit(b'E'), 14); - assert_eq!(hex_digit(b'F'), 15); +pub fn init_provider() { + let mut ciphers: Option> = None; + let mut kexes: Option> = None; + for (var, val) in std::env::vars() { + match var.as_str() { + "CIPHERS" => ciphers = Some(val.split(',').map(str::to_string).collect()), + "KEXES" => kexes = Some(val.split(',').map(str::to_string).collect()), + _ => {} + } + } + // Ensure multiple provider cannot be enabled without compile error. + let _provider; + #[cfg(feature = "aws-lc")] + { + info!("Using RusTLS provider aws-lc"); + let mut prov = rustls_post_quantum::provider(); + if let Some(ciphers) = ciphers { + prov.cipher_suites.clear(); + for cipher in ciphers { + match cipher.as_str() { + "AES_256_GCM_SHA384" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_AES_256_GCM_SHA384), + "AES_128_GCM_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_AES_128_GCM_SHA256), + "CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256), + "ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384), + "ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), + "ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), + "ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384), + "ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), + "ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), + other => { + log::error!("Unknown cipher `{other}`") + } + } + } + } + if let Some(kexes) = kexes { + prov.kx_groups.clear(); + for kex in kexes { + match kex.as_str() { + "X25519" => prov + .kx_groups + .push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::X25519), + "SECP256R1" => prov + .kx_groups + .push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP256R1), + "SECP384R1" => prov + .kx_groups + .push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP384R1), + "X25519MLKEM768" => prov + .kx_groups + .push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::X25519MLKEM768), + "SECP256R1MLKEM768" => prov + .kx_groups + .push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP256R1MLKEM768), + "MLKEM768" => prov + .kx_groups + .push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::MLKEM768), + other => { + log::error!("Unknown kex `{other}`") + } + } + } + } + _provider = CryptoProvider::install_default(prov); + } + #[cfg(feature = "boring")] + { + info!("Using RusTLS provider boring"); + let mut prov = boring_rustls_provider::provider(); + if let Some(ciphers) = ciphers { + prov.cipher_suites.clear(); + for cipher in ciphers { + match cipher.as_str() { + "AES_256_GCM_SHA384" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls13( + &boring_rustls_provider::tls13::AES_256_GCM_SHA384, + )), + "AES_128_GCM_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls13( + &boring_rustls_provider::tls13::AES_128_GCM_SHA256, + )), + "CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls13( + &boring_rustls_provider::tls13::CHACHA20_POLY1305_SHA256, + )), + "ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12( + &boring_rustls_provider::tls12::ECDHE_ECDSA_AES256_GCM_SHA384, + )), + "ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12( + &boring_rustls_provider::tls12::ECDHE_ECDSA_AES128_GCM_SHA256, + )), + "ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12( + &boring_rustls_provider::tls12::ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + )), + "ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12( + &boring_rustls_provider::tls12::ECDHE_RSA_AES256_GCM_SHA384, + )), + "ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12( + &boring_rustls_provider::tls12::ECDHE_RSA_AES128_GCM_SHA256, + )), + "ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12( + &boring_rustls_provider::tls12::ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + )), + other => { + log::error!("Unknown cipher `{other}`") + } + } + } + } + if let Some(kexes) = kexes { + prov.kx_groups.clear(); + for kex in kexes { + match kex.as_str() { + "X25519" => prov + .kx_groups + .push(boring_rustls_provider::ALL_KX_GROUPS[0]), + "SECP256R1" => prov + .kx_groups + .push(boring_rustls_provider::ALL_KX_GROUPS[2]), + "SECP384R1" => prov + .kx_groups + .push(boring_rustls_provider::ALL_KX_GROUPS[3]), + other => { + log::error!("Unknown kex `{other}`") + } + } + } + } + _provider = CryptoProvider::install_default(prov); } - #[test] - fn test_parse_hex() { - assert_eq!(parse_hex(b"abc123"), 0xabc123); - assert_eq!(parse_hex(b"1"), 1); + #[cfg(feature = "graviola")] + { + info!("Using RusTLS provider graviola"); + let mut prov = rustls_graviola::default_provider(); + if let Some(ciphers) = ciphers { + prov.cipher_suites.clear(); + for cipher in ciphers { + match cipher.as_str() { + "AES_256_GCM_SHA384" => prov + .cipher_suites + .push(rustls_graviola::suites::TLS13_AES_256_GCM_SHA384), + "AES_128_GCM_SHA256" => prov + .cipher_suites + .push(rustls_graviola::suites::TLS13_AES_128_GCM_SHA256), + "CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(rustls_graviola::suites::TLS13_CHACHA20_POLY1305_SHA256), + "ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov + .cipher_suites + .push(rustls_graviola::suites::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384), + "ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov + .cipher_suites + .push(rustls_graviola::suites::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), + "ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push( + rustls_graviola::suites::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + ), + "ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov + .cipher_suites + .push(rustls_graviola::suites::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384), + "ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov + .cipher_suites + .push(rustls_graviola::suites::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), + "ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(rustls_graviola::suites::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), + other => { + log::error!("Unknown cipher `{other}`") + } + } + } + } + if let Some(kexes) = kexes { + prov.kx_groups.clear(); + for kex in kexes { + match kex.as_str() { + "X25519" => prov.kx_groups.push(&rustls_graviola::kx::X25519), + "SECP256R1" => prov.kx_groups.push(&rustls_graviola::kx::P256), + "SECP384R1" => prov.kx_groups.push(&rustls_graviola::kx::P384), + "X25519MLKEM768" => prov.kx_groups.push(rustls_graviola::kx::X25519MLKEM768), + other => { + log::error!("Unknown kex `{other}`") + } + } + } + } + _provider = CryptoProvider::install_default(prov); } - #[test] - fn test_is_hex() { - assert!(is_hex(b"0")); - assert!(is_hex(b"1")); - assert!(is_hex(b"2")); - assert!(is_hex(b"3")); - assert!(is_hex(b"4")); - assert!(is_hex(b"5")); - assert!(is_hex(b"6")); - assert!(is_hex(b"7")); - assert!(is_hex(b"8")); - assert!(is_hex(b"9")); - assert!(is_hex(b"a")); - assert!(is_hex(b"b")); - assert!(is_hex(b"c")); - assert!(is_hex(b"d")); - assert!(is_hex(b"e")); - assert!(is_hex(b"f")); - assert!(is_hex(b"A")); - assert!(is_hex(b"B")); - assert!(is_hex(b"C")); - assert!(is_hex(b"D")); - assert!(is_hex(b"E")); - assert!(is_hex(b"F")); + #[cfg(feature = "openssl")] + { + info!("Using RusTLS provider openssl"); + let mut prov = rustls_openssl::default_provider(); + if let Some(ciphers) = ciphers { + prov.cipher_suites.clear(); + for cipher in ciphers { + match cipher.as_str() { + "AES_256_GCM_SHA384" => prov + .cipher_suites + .push(rustls_openssl::cipher_suite::TLS13_AES_256_GCM_SHA384), + "AES_128_GCM_SHA256" => prov + .cipher_suites + .push(rustls_openssl::cipher_suite::TLS13_AES_128_GCM_SHA256), + "CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(rustls_openssl::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256), + "ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov.cipher_suites.push( + rustls_openssl::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, + ), + "ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov.cipher_suites.push( + rustls_openssl::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + ), + "ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push( + rustls_openssl::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256, + ), + "ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov + .cipher_suites + .push(rustls_openssl::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384), + "ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov + .cipher_suites + .push(rustls_openssl::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), + "ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push( + rustls_openssl::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + ), + other => { + log::error!("Unknown cipher `{other}`") + } + } + } + } + if let Some(kexes) = kexes { + prov.kx_groups.clear(); + for kex in kexes { + match kex.as_str() { + "X25519" => prov.kx_groups.push(rustls_openssl::kx_group::X25519), + "SECP256R1" => prov.kx_groups.push(rustls_openssl::kx_group::SECP256R1), + "SECP384R1" => prov.kx_groups.push(rustls_openssl::kx_group::SECP384R1), + "X25519MLKEM768" => prov + .kx_groups + .push(rustls_openssl::kx_group::X25519MLKEM768), + "MLKEM768" => prov.kx_groups.push(rustls_openssl::kx_group::MLKEM768), + other => { + log::error!("Unknown kex `{other}`") + } + } + } + } + _provider = CryptoProvider::install_default(prov); + } + #[cfg(feature = "ring")] + { + info!("Using RusTLS provider ring"); + let mut prov = tokio_rustls::rustls::crypto::ring::default_provider(); + if let Some(ciphers) = ciphers { + prov.cipher_suites.clear(); + for cipher in ciphers { + match cipher.as_str() { + "AES_256_GCM_SHA384" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS13_AES_256_GCM_SHA384), + "AES_128_GCM_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS13_AES_128_GCM_SHA256), + "CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256), + "ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384), + "ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), + "ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), + "ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384), + "ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), + "ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), + other => { + log::error!("Unknown cipher `{other}`") + } + } + } + } + if let Some(kexes) = kexes { + prov.kx_groups.clear(); + for kex in kexes { + match kex.as_str() { + "X25519" => prov + .kx_groups + .push(tokio_rustls::rustls::crypto::ring::kx_group::X25519), + "SECP256R1" => prov + .kx_groups + .push(tokio_rustls::rustls::crypto::ring::kx_group::SECP256R1), + "SECP384R1" => prov + .kx_groups + .push(tokio_rustls::rustls::crypto::ring::kx_group::SECP384R1), + other => { + log::error!("Unknown kex `{other}`") + } + } + } + } + _provider = CryptoProvider::install_default(prov); + } + #[cfg(feature = "symcrypt")] + { + info!("Using RusTLS provider symcrypt"); + let mut prov = rustls_symcrypt::default_symcrypt_provider(); + if let Some(ciphers) = ciphers { + prov.cipher_suites.clear(); + for cipher in ciphers { + match cipher.as_str() { + "AES_256_GCM_SHA384" => prov + .cipher_suites + .push(rustls_symcrypt::TLS13_AES_256_GCM_SHA384), + "AES_128_GCM_SHA256" => prov + .cipher_suites + .push(rustls_symcrypt::TLS13_AES_128_GCM_SHA256), + "CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(rustls_symcrypt::TLS13_CHACHA20_POLY1305_SHA256), + "ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov + .cipher_suites + .push(rustls_symcrypt::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384), + "ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov + .cipher_suites + .push(rustls_symcrypt::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256), + "ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(rustls_symcrypt::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256), + "ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov + .cipher_suites + .push(rustls_symcrypt::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384), + "ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov + .cipher_suites + .push(rustls_symcrypt::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256), + "ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov + .cipher_suites + .push(rustls_symcrypt::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256), + other => { + log::error!("Unknown cipher `{other}`") + } + } + } + } + if let Some(kexes) = kexes { + prov.kx_groups.clear(); + for kex in kexes { + match kex.as_str() { + "X25519" => prov.kx_groups.push(rustls_symcrypt::X25519), + "SECP256R1" => prov.kx_groups.push(rustls_symcrypt::SECP256R1), + "SECP384R1" => prov.kx_groups.push(rustls_symcrypt::SECP384R1), + other => { + log::error!("Unknown kex `{other}`") + } + } + } + } + _provider = CryptoProvider::install_default(prov); } }