Dummy data instead of HTTP

This commit is contained in:
Pascal Engélibert 2026-02-06 15:05:41 +01:00
commit dec39cf2e3
10 changed files with 1543 additions and 1413 deletions

View file

@ -1,10 +1,8 @@
use crate::{
TlsMode,
record::{Direction, Records},
util::{ResponseStreamer, print_bin},
util::ResponseStreamer,
};
use futures_util::{StreamExt, TryStreamExt};
use std::{
collections::HashSet,
net::ToSocketAddrs,
@ -14,7 +12,7 @@ use std::{
use tokio::{
io::AsyncWriteExt,
net::TcpStream,
sync::{Mutex, Semaphore, oneshot},
sync::{Mutex, Semaphore},
};
use tokio_rustls::{
TlsConnector,
@ -27,7 +25,6 @@ use tokio_rustls::{
pki_types::ServerName,
},
};
use tokio_util::codec::Framed;
const TIMEOUT: Duration = Duration::from_secs(30);
@ -89,13 +86,11 @@ impl ServerCertVerifier for DummyCertVerifier {
pub async fn play(
records: &'static Records,
tls_mode: TlsMode,
use_tls: bool,
connect_to: (String, u16),
sync_receiver: oneshot::Receiver<()>,
repeat: u32,
debug: bool,
) {
sync_receiver.await.unwrap();
// Semaphore used to limit the number of concurrent clients.
// Its handle is released when the task panics.
let limiter = Arc::new(Semaphore::new(16));
@ -105,6 +100,10 @@ pub async fn play(
let connect_to = connect_to.to_socket_addrs().unwrap().next().unwrap();
let debug_mutex = Arc::new(Mutex::new(()));
let dummy_bytes = Arc::new(vec![0x42u8; 16 * 1024 * 1024]);
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
tokio::spawn({
let running = running.clone();
let counter = counter.clone();
@ -123,303 +122,200 @@ pub async fn play(
}
});
match tls_mode {
TlsMode::Both | TlsMode::Client => {
let mut config = tokio_rustls::rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(DummyCertVerifier))
.with_no_client_auth();
let mut enable_early_data = false;
for (var, val) in std::env::vars() {
match var.as_str() {
"EARLYDATA" => enable_early_data = val == "1",
_ => {}
}
}
if enable_early_data {
config.enable_early_data = true;
} else {
config.resumption = Resumption::disabled();
}
config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new());
let config = Arc::new(config);
for _i in 0..repeat {
let mut handles = Vec::new();
for (id, (server_name, records)) in records.iter() {
let connector = TlsConnector::from(config.clone());
let counter = counter.clone();
let limiter = limiter.clone();
let running = running.clone();
handles.push(tokio::spawn(async move {
let mut running_guard = running.lock().await;
running_guard.insert(*id);
drop(running_guard);
let limiter = limiter.acquire().await.unwrap();
let server_name =
ServerName::try_from(String::from_utf8(server_name.clone()).unwrap())
.unwrap();
'repeat: for _i in 0..1 {
let stream = TcpStream::connect(connect_to).await.unwrap();
let stream = connector
.connect(server_name.clone(), stream)
.await
.unwrap();
let mut stream =
Framed::new(stream, crate::http::HttpClientCodec::new());
for (direction, data_list) in ResponseStreamer::new(records.iter()) {
match direction {
Direction::ClientToServer => {
for data in data_list {
//println!("[CLT] ({id}) >> {}", data.len());
//stream.get_mut().write_all(data).await.unwrap();
match tokio::time::timeout(
TIMEOUT,
stream.get_mut().write_all(data),
)
.await
{
Ok(v) => v.unwrap(),
Err(_e) => {
println!("client timeout {id} (sending)");
continue 'repeat;
}
}
}
}
Direction::ServerToClient => {
let total_len: usize =
data_list.iter().map(|data| data.len()).sum::<usize>();
let reduced_len =
total_len.saturating_sub(160 * data_list.len()).max(1);
let mut total_recv = 0;
//println!("[CLT] ({id}) << {}", data.len());
// let mut buf = Vec::new();
// stream.read_buf(&mut buf).await.ok();
//let mut buf = vec![0; data.len().saturating_sub(50).max(1)];
//let resp = stream.next().await.unwrap().unwrap();
while total_recv < reduced_len {
let resp =
match tokio::time::timeout(TIMEOUT, stream.next())
.await
{
Ok(v) => v.unwrap().unwrap(),
Err(_e) => {
// TODO fix
println!(
"client timeout {}: {} / {}",
id, total_recv, total_len
);
//print_bin(data);
break 'repeat;
}
};
total_recv += resp.len();
//dbg!(resp.len());
//crate::http::decode_http(&mut buf, &mut stream).await;
}
/*if total_recv > total_len {
println!("received too much {}: {} / {}", id, total_recv, total_len);
}*/
}
}
}
//stream.get_mut().shutdown().await.unwrap();
tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown())
.await
.unwrap()
.unwrap();
let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
println!("Client: {} / {}", cnt + 1, total);
}
drop(limiter);
let mut running_guard = running.lock().await;
running_guard.remove(id);
drop(running_guard);
}));
//tokio::time::sleep(std::time::Duration::from_millis(500)).await;
}
for handle in handles {
handle.await.unwrap();
}
if use_tls {
let mut config = tokio_rustls::rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(DummyCertVerifier))
.with_no_client_auth();
let mut enable_early_data = false;
for (var, val) in std::env::vars() {
match var.as_str() {
"EARLYDATA" => enable_early_data = val == "1",
_ => {}
}
}
TlsMode::None | TlsMode::Server => {
for _i in 0..repeat {
let mut handles = Vec::new();
for (id, (_server_name, records)) in records.iter() {
/*if *id != 33 {
continue
}*/
let counter = counter.clone();
let limiter = limiter.clone();
let running = running.clone();
let debug_mutex = debug_mutex.clone();
handles.push(tokio::spawn(async move {
let mut running_guard = running.lock().await;
running_guard.insert(*id);
drop(running_guard);
let limiter = limiter.acquire().await.unwrap();
//let mut buf = Vec::new();
'repeat: for _i in 0..1 {
let stream = TcpStream::connect(connect_to).await.unwrap();
let mut stream =
Framed::new(stream, crate::http::HttpClientCodec::new());
/*let mut skip_recv = false;
for (direction, data) in records {
match direction {
Direction::ClientToServer => {
skip_recv = false;
println!("[CLT] ({id}) >> {}", data.len());
stream.write_all(data).await.unwrap();
}
Direction::ServerToClient => {
if skip_recv {
if enable_early_data {
config.enable_early_data = true;
} else {
config.resumption = Resumption::disabled();
}
config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new());
let config = Arc::new(config);
for _i in 0..repeat {
let mut handles = Vec::new();
for (conn_id, (server_name, records)) in records.iter() {
let connector = TlsConnector::from(config.clone());
let counter = counter.clone();
let limiter = limiter.clone();
let running = running.clone();
let dummy_bytes = dummy_bytes.clone();
handles.push(tokio::spawn(async move {
let mut running_guard = running.lock().await;
running_guard.insert(*conn_id);
drop(running_guard);
let limiter = limiter.acquire().await.unwrap();
let server_name =
ServerName::try_from(String::from_utf8(server_name.clone()).unwrap())
.unwrap();
let stream = TcpStream::connect(connect_to).await.unwrap();
let stream = connector
.connect(server_name.clone(), stream)
.await
.unwrap();
let mut stream = crate::codec::StreamCodec::new(stream);
for (direction, reqs) in ResponseStreamer::new(records.iter()) {
match direction {
Direction::ClientToServer => {
for (req_id, len) in reqs {
//println!("[CLT] ({conn_id}) >> {}", len);
let mut data = dummy_bytes[0..len as usize].to_vec();
data[0..4].copy_from_slice(&(len as u32).to_be_bytes());
data[4..6].copy_from_slice(&(*conn_id as u16).to_be_bytes());
data[6..8].copy_from_slice(&(req_id as u16).to_be_bytes());
match tokio::time::timeout(TIMEOUT, async {
stream.get_mut().write_all(&data).await.unwrap();
})
.await
{
Ok(_v) => {}
Err(_e) => {
println!("client timeout {conn_id} (sending)");
continue;
}
println!("[CLT] ({id}) << {}", data.len());
//let mut buf = Vec::new();
//stream.read_buf(&mut buf).await.ok();
//let mut buf = vec![0; data.len().saturating_sub(50).max(1)];
let mut buf = vec![0; data.len()];
match tokio::time::timeout(
std::time::Duration::from_millis(500),
stream.readable(),
)
.await
{
Ok(r) => {
r.unwrap();
}
Err(_) => {
println!("[CLT] timeout recv ({id})");
break;
}
}
// TODO utiliser crate::http ici
match tokio::time::timeout(
std::time::Duration::from_millis(500),
stream.read_exact(&mut buf),
)
.await
{
Ok(r) => {
r.unwrap();
}
Err(_) => {
println!("[CLT] skip recv ({id})");
skip_recv = true;
}
}
}
}
}*/
for (direction, data_list) in ResponseStreamer::new(records.iter()) {
match direction {
Direction::ClientToServer => {
for data in data_list.into_iter() {
if debug {
//println!("[CLT] ({id}) >> {}", str::from_utf8(&data[..data.len().min(255)]).unwrap());
println!("[CLT] ({id}) >> {}", data.len());
}
//stream.get_mut().write_all(data).await.unwrap();
match tokio::time::timeout(
TIMEOUT,
stream.get_mut().write_all(data),
)
.await
{
Ok(v) => v.unwrap(),
Err(_e) => {
println!("client timeout {id} (sending)");
continue 'repeat;
}
}
}
}
Direction::ServerToClient => {
let total_len: usize =
data_list.iter().map(|data| data.len()).sum::<usize>();
let reduced_len =
total_len.saturating_sub(160 * data_list.len()).max(1);
let mut total_recv = 0;
if debug {
println!("[CLT] ({id}) << {total_len}");
}
//let mut buf = Vec::new();
//stream.read_buf(&mut buf).await.ok();
//let mut buf = vec![0; data.len().saturating_sub(50).max(1)];
let mut resp = Vec::new();
while total_recv < reduced_len {
resp =
match tokio::time::timeout(TIMEOUT, stream.next())
.await
{
Ok(None) => break,
Ok(Some(v)) => v.unwrap(),
Err(_e) => {
// TODO fix
println!(
"client timeout {}: {} / {}",
id, total_recv, total_len
);
//print_bin(data);
break 'repeat;
}
};
total_recv += resp.len();
/*if resp.len() != data.len() {
let guard = debug_mutex.lock().await;
println!("RECV NOT ENOUGH {} / {}", resp.len(), data.len());
if resp.len() < 1000 && data.len() < 1000 {
//print_bin(&resp);
//println!("WANTED");
//print_bin(data);
}
std::mem::drop(guard);
}*/
//print_bin(&resp);
//let resp = stream.next().await.unwrap().unwrap();
//dbg!(resp.len());
//crate::http::decode_http(&mut buf, &mut stream).await;
//buf.clear();
}
if total_recv < reduced_len {
println!(
"({}) RECV NOT ENOUGH {} / {}",
id, total_recv, total_len
);
if resp.len() < 1024 {
print_bin(&resp);
}
} else if debug {
println!("[CLT] ({id}) << {total_len} OK");
}
}
}
}
//stream.get_mut().shutdown().await.unwrap();
tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown())
.await
.unwrap()
.unwrap();
let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
println!("Client: {} / {}", cnt + 1, total);
Direction::ServerToClient => {
let expected_total_len =
reqs.iter().map(|(_req_id, len)| *len).sum::<u64>();
let mut total_recv = 0;
//println!("[CLT] ({conn_id}) << {}", expected_total_len);
while total_recv < expected_total_len {
let resp =
match tokio::time::timeout(TIMEOUT, stream.next()).await {
Ok(v) => v.unwrap(),
Err(_e) => {
// TODO fix
println!(
"client timeout {}: {} / {}",
conn_id, total_recv, expected_total_len
);
//print_bin(data);
break;
}
};
total_recv += resp.len() as u64;
}
/*if total_recv > total_len {
println!("received too much {}: {} / {}", id, total_recv, total_len);
}*/
}
}
drop(limiter);
let mut running_guard = running.lock().await;
running_guard.remove(id);
drop(running_guard);
}));
//tokio::time::sleep(std::time::Duration::from_millis(500)).await;
}
}
//stream.get_mut().shutdown().await.unwrap();
//println!("Client shutdown");
tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown())
.await
.unwrap()
.unwrap();
let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
println!("Client: {} / {}", cnt + 1, total);
drop(limiter);
let mut running_guard = running.lock().await;
running_guard.remove(conn_id);
drop(running_guard);
}));
//tokio::time::sleep(std::time::Duration::from_millis(500)).await;
}
for handle in handles {
handle.await.unwrap();
}
for handle in handles {
handle.await.unwrap();
}
}
} else {
for _i in 0..repeat {
let mut handles = Vec::new();
for (conn_id, (_server_name, records)) in records.iter() {
let counter = counter.clone();
let limiter = limiter.clone();
let running = running.clone();
let dummy_bytes = dummy_bytes.clone();
handles.push(tokio::spawn(async move {
let mut running_guard = running.lock().await;
running_guard.insert(*conn_id);
drop(running_guard);
let limiter = limiter.acquire().await.unwrap();
let stream = TcpStream::connect(connect_to).await.unwrap();
let mut stream = crate::codec::StreamCodec::new(stream);
for (direction, reqs) in ResponseStreamer::new(records.iter()) {
match direction {
Direction::ClientToServer => {
for (req_id, len) in reqs {
let mut data = dummy_bytes[0..len as usize].to_vec();
data[0..4].copy_from_slice(&(len as u32).to_be_bytes());
data[4..6].copy_from_slice(&(*conn_id as u16).to_be_bytes());
data[6..8].copy_from_slice(&(req_id as u16).to_be_bytes());
//println!("[CLT] ({conn_id}) >> {}", len);
match tokio::time::timeout(TIMEOUT, async {
stream.get_mut().write_all(&data).await.unwrap();
})
.await
{
Ok(_v) => {}
Err(_e) => {
println!("client timeout {conn_id} (sending)");
continue;
}
}
}
}
Direction::ServerToClient => {
let expected_total_len =
reqs.iter().map(|(_req_id, len)| *len).sum::<u64>();
let mut total_recv = 0;
//println!("[CLT] ({conn_id}) << {}", expected_total_len);
while total_recv < expected_total_len {
let resp =
match tokio::time::timeout(TIMEOUT, stream.next()).await {
Ok(v) => v.unwrap(),
Err(_e) => {
// TODO fix
println!(
"client timeout {}: {} / {}",
conn_id, total_recv, expected_total_len
);
//print_bin(data);
break;
}
};
total_recv += resp.len() as u64;
}
/*if total_recv > total_len {
println!("received too much {}: {} / {}", id, total_recv, total_len);
}*/
}
}
}
//stream.get_mut().shutdown().await.unwrap();
//println!("Client shutdown");
tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown())
.await
.unwrap()
.unwrap();
let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
println!("Client: {} / {}", cnt + 1, total);
drop(limiter);
let mut running_guard = running.lock().await;
running_guard.remove(conn_id);
drop(running_guard);
}));
//tokio::time::sleep(std::time::Duration::from_millis(500)).await;
}
for handle in handles {
handle.await.unwrap();
}
}
}
println!("Unfinished: {:?}", running.lock().await);
std::process::exit(0);
}

40
src/codec.rs Normal file
View file

@ -0,0 +1,40 @@
use tokio::io::{AsyncRead, AsyncReadExt};
pub struct StreamCodec<S> {
stream: S,
}
impl<S: AsyncRead + Unpin> StreamCodec<S> {
pub fn new(stream: S) -> Self {
Self { stream }
}
pub async fn next(&mut self) -> Result<Vec<u8>, std::io::Error> {
let mut buf = vec![0; 8];
self.stream.read_exact(&mut buf).await?;
let expected_len = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize;
if expected_len < 8 || expected_len > 8 * 1024 * 1024 {
return Err(std::io::ErrorKind::InvalidData.into());
}
buf.resize(expected_len, 0);
self.stream.read_exact(&mut buf[8..expected_len]).await?;
Ok(buf)
}
pub fn get_mut(&mut self) -> &mut S {
&mut self.stream
}
}
/*#[cfg(test)]
mod test {
use super::*;
use tokio_util::{bytes::BytesMut, codec::Framed};
#[test]
fn test_decode() {
let stream = futures_util::stream::iter([BytesMut::fr&[0u8]]);
let stream = Framed::new(stream, CustomCodec::new());
}
}*/

View file

@ -1,397 +0,0 @@
use regex::bytes::Regex;
use std::{pin::Pin, sync::LazyLock};
use tokio::io::{AsyncRead, AsyncReadExt};
use tokio_util::{
bytes::BytesMut,
codec::{Decoder, Encoder},
};
use crate::util::{is_hex, parse_hex};
static REGEX_CONTENT_LENGTH: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r#"[cC]ontent-[lL]ength: *(\d+)\r\n"#).unwrap());
static REGEX_CHUNKED: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r#"[tT]ransfer-[eE]ncoding: *[cC]hunked\r\n"#).unwrap());
/*pin_project! {
pub struct Framer<S, C> {
#[pin]
pub stream: S,
codec: C,
buf: BytesMut,
}
}
impl<S: AsyncRead + Unpin, C: Decoder> Framer<S, C> {
pub fn new(stream: S, codec: C) -> Self {
Self {
stream,
codec,
buf: BytesMut::new(),
}
}
pub async fn next(&mut self) -> Option<Vec<u8>> {
self.stream.read_buf(&mut self.buf).await.unwrap();
None
}
}*/
pub struct HttpClientCodec {
buf: Vec<u8>,
}
impl HttpClientCodec {
pub fn new() -> Self {
Self { buf: Vec::new() }
}
}
impl Decoder for HttpClientCodec {
type Item = Vec<u8>;
type Error = std::io::Error;
fn decode(
&mut self,
src: &mut tokio_util::bytes::BytesMut,
) -> Result<Option<Self::Item>, Self::Error> {
self.buf.extend_from_slice(&src);
src.clear();
let src = &mut self.buf;
if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") {
end_index += 4;
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) {
// Content-Length: simple body
if let Some(content_length) = captures.get(1) {
// Read body
let content_length: usize = str::from_utf8(content_length.as_bytes())
.unwrap()
.parse()
.unwrap();
if src.len() >= end_index + content_length {
let remaining = src.split_off(end_index + content_length);
let out = src.to_vec();
*src = remaining;
Ok(Some(out))
} else {
//dbg!("Not enough data");
Ok(None)
}
} else {
// Invalid Content-Length
Err(std::io::ErrorKind::InvalidData.into())
}
} else if REGEX_CHUNKED.is_match(&src[0..end_index]) {
// Chunked body
let mut content = &src[end_index..];
let mut total_len = end_index;
loop {
if let Some(len_end_index) = memchr::memmem::find(content, b"\r\n") {
let len_slice = &content[0..len_end_index];
if len_end_index < 8 && is_hex(len_slice) {
let chunk_len = parse_hex(len_slice) as usize;
if content.len() >= len_end_index + chunk_len + 4 {
total_len += len_end_index + chunk_len + 4;
// Should we check the ending CRLF?
if chunk_len == 0 {
let remaining = src.split_off(total_len);
let out = src.to_vec();
*src = remaining;
return Ok(Some(out));
}
// else, wait for the next chunk
content = &content[len_end_index + chunk_len + 4..];
} else {
// Not enough data
return Ok(None);
}
} else {
// Invalid chunk length
return Err(std::io::ErrorKind::InvalidData.into());
}
} else {
// Not enough data
return Ok(None);
}
}
} else {
// Header ended without Content-Type nor chunks => no body
let remaining = src.split_off(end_index);
let out = src.to_vec();
*src = remaining;
Ok(Some(out))
}
} else {
//dbg!("Unfinished header");
Ok(None)
}
/*self.buf.extend_from_slice(&src);
src.clear();
let src = &mut self.buf;
if self.chunked {
if let Some(len_end_index) = memchr::memmem::find(src, b"\r\n") {
let len_slice = &src[0..len_end_index];
if len_end_index < 8 && is_hex(len_slice) {
let chunk_len = parse_hex(len_slice) as usize;
if src.len() >= len_end_index + chunk_len + 4 {
// Should we check the ending CRLF?
if chunk_len == 0 {
self.chunked = false;
}
let remaining = src.split_off(len_end_index+chunk_len+4);
let out = src.to_vec();
*src = remaining;
Ok(Some(out))
} else {
// Not enough data
Ok(None)
}
} else {
// Invalid chunk length
Err(std::io::ErrorKind::InvalidData.into())
}
} else {
// Not enough data
Ok(None)
}
} else {
if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") {
end_index += 4;
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) {
if let Some(content_length) = captures.get(1) {
// Read body
let content_length: usize = str::from_utf8(content_length.as_bytes())
.unwrap()
.parse()
.unwrap();
if src.len() >= end_index + content_length {
if REGEX_CHUNKED.is_match(&src[0..end_index]) {
self.chunked = true;
}
//dbg!(content_length);
let remaining = src.split_off(end_index + content_length);
let out = src.to_vec();
*src = remaining;
Ok(Some(out))
} else {
//dbg!("Not enough data");
Ok(None)
}
} else {
// Invalid Content-Length
Err(std::io::ErrorKind::InvalidData.into())
}
} else {
// Header ended without Content-Type => no body
let remaining = src.split_off(end_index);
let out = src.to_vec();
*src = remaining;
Ok(Some(out))
}
} else {
//dbg!("Unfinished header");
Ok(None)
}
}*/
/*if let Some(start_index) = memchr::memmem::find(src, b"HTTP") {
if start_index != 0 {
dbg!(start_index);
if start_index == 529 {
println!("{src:?}");
}
}
let src2 = &src[start_index..];
if let Some(mut end_index) = memchr::memmem::find(src2, b"\r\n\r\n") {
end_index += 4;
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src2) {
if let Some(content_length) = captures.get(1) {
// Read body
let content_length: usize = str::from_utf8(content_length.as_bytes())
.unwrap()
.parse()
.unwrap();
if src2.len() >= end_index + content_length {
if src2.len() > end_index + content_length {
dbg!(src2.len(), end_index + content_length);
println!("{src2:?}");
std::process::exit(1);
}
//dbg!(content_length);
let out = src2.to_vec();
src.clear();
Ok(Some(out))
} else {
//dbg!("Not enough data");
Ok(None)
}
} else {
// Invalid Content-Length
Err(std::io::ErrorKind::InvalidData.into())
}
} else {
// Header ended without Content-Type => no body
let out = src2.to_vec();
src.clear();
Ok(Some(out))
}
} else {
//dbg!("Unfinished header");
Ok(None)
}
} else {
//dbg!("Unstarted header");
Ok(None)
}*/
}
}
impl Encoder<Vec<u8>> for HttpClientCodec {
type Error = std::io::Error;
fn encode(
&mut self,
_item: Vec<u8>,
_dst: &mut tokio_util::bytes::BytesMut,
) -> Result<(), Self::Error> {
Ok(())
}
}
pub struct HttpServerCodec {
buf: Vec<u8>,
}
impl HttpServerCodec {
pub fn new() -> Self {
Self { buf: Vec::new() }
}
}
impl Decoder for HttpServerCodec {
type Item = Vec<u8>;
type Error = std::io::Error;
fn decode(
&mut self,
src: &mut tokio_util::bytes::BytesMut,
) -> Result<Option<Self::Item>, Self::Error> {
self.buf.extend_from_slice(&src);
src.clear();
let src = &mut self.buf;
if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") {
end_index += 4;
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) {
// Content-Length: simple body
if let Some(content_length) = captures.get(1) {
// Read body
let content_length: usize = str::from_utf8(content_length.as_bytes())
.unwrap()
.parse()
.unwrap();
if src.len() >= end_index + content_length {
let remaining = src.split_off(end_index + content_length);
let out = src.to_vec();
*src = remaining;
Ok(Some(out))
} else {
//dbg!("Not enough data");
Ok(None)
}
} else {
// Invalid Content-Length
Err(std::io::ErrorKind::InvalidData.into())
}
} else if REGEX_CHUNKED.is_match(&src[0..end_index]) {
// Chunked body
let mut content = &src[end_index..];
let mut total_len = end_index;
loop {
if let Some(len_end_index) = memchr::memmem::find(content, b"\r\n") {
let len_slice = &content[0..len_end_index];
if len_end_index < 8 && is_hex(len_slice) {
let chunk_len = parse_hex(len_slice) as usize;
if content.len() >= len_end_index + chunk_len + 4 {
total_len += len_end_index + chunk_len + 4;
// Should we check the ending CRLF?
if chunk_len == 0 {
let remaining = src.split_off(total_len);
let out = src.to_vec();
*src = remaining;
return Ok(Some(out));
}
// else, wait for the next chunk
content = &content[len_end_index + chunk_len + 4..];
} else {
// Not enough data
return Ok(None);
}
} else {
// Invalid chunk length
return Err(std::io::ErrorKind::InvalidData.into());
}
} else {
// Not enough data
return Ok(None);
}
}
} else {
// Header ended without Content-Type nor chunks => no body
let remaining = src.split_off(end_index);
let out = src.to_vec();
*src = remaining;
Ok(Some(out))
}
} else {
//dbg!("Unfinished header");
Ok(None)
}
/*self.buf.extend_from_slice(&src);
src.clear();
let src = &mut self.buf;
if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") {
end_index += 4;
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(&src[0..end_index]) {
if let Some(content_length) = captures.get(1) {
// Read body
let content_length: usize = str::from_utf8(content_length.as_bytes())
.unwrap()
.parse()
.unwrap();
if src.len() >= end_index + content_length {
//dbg!(content_length);
let remaining = src.split_off(end_index + content_length);
let out = src.to_vec();
*src = remaining;
Ok(Some(out))
} else {
//dbg!("Not enough data");
Ok(None)
}
} else {
// Invalid Content-Length
Err(std::io::ErrorKind::InvalidData.into())
}
} else {
// Header ended without Content-Type => no body
let remaining = src.split_off(end_index);
let out = src.to_vec();
*src = remaining;
Ok(Some(out))
}
} else {
//dbg!("Unfinished header");
Ok(None)
}*/
}
}
impl Encoder<Vec<u8>> for HttpServerCodec {
type Error = std::io::Error;
fn encode(
&mut self,
_item: Vec<u8>,
_dst: &mut tokio_util::bytes::BytesMut,
) -> Result<(), Self::Error> {
Ok(())
}
}

View file

@ -1,7 +1,7 @@
#![feature(ascii_char)]
mod client;
mod http;
mod codec;
mod record;
mod server;
mod util;
@ -10,8 +10,6 @@ use record::Records;
use argp::FromArgs;
use static_cell::StaticCell;
use tokio::sync::oneshot;
use tokio_rustls::rustls::crypto::CryptoProvider;
/// Play recorded requests and responses
#[derive(FromArgs)]
@ -26,11 +24,14 @@ struct Opt {
#[derive(FromArgs)]
#[argp(subcommand)]
enum Subcommand {
/// Replay from records
Play(OptPlay),
/// Replay from records (client)
Client(OptClient),
/// Replay from records (server)
Server(OptServer),
/// Print records
Print(OptPrint),
/// Record traffic
#[cfg(feature = "record")]
Record(OptRecord),
/// Remove record
Remove(OptRemove),
@ -40,32 +41,44 @@ enum Subcommand {
/// Replay from records
#[derive(FromArgs)]
#[argp(subcommand, name = "play")]
struct OptPlay {
#[argp(subcommand, name = "client")]
struct OptClient {
/// Connect to address
#[argp(positional)]
forward_addr: String,
connect_addr: String,
/// Connect to port
#[argp(positional)]
forward_port: u16,
connect_port: u16,
/// Whether to use TLS
#[argp(switch, long = "tls")]
tls: bool,
/// Repeat N times
#[argp(option, short = 'r', default = "1")]
repeat: u32,
/// UDP end notification will be sent to this address:port
#[argp(option, short = 'n')]
notify_addr: Option<String>,
/// Only play this record
#[argp(option)]
record: Option<u64>,
/// Print debug info
#[argp(switch, short = 'd')]
debug: bool,
}
/// Replay from records
#[derive(FromArgs)]
#[argp(subcommand, name = "server")]
struct OptServer {
/// Listen to port
#[argp(positional)]
listen_port: u16,
/// Path to PEM certificates and keys
#[argp(positional)]
certs: String,
/// Where to use TLS
#[argp(positional)]
tls: String,
/// Repeat N times
#[argp(option, short = 'r', default = "1")]
repeat: u32,
/// Only play this record
#[argp(option)]
record: Option<u64>,
/// Only run these parts
#[argp(option, default = "String::from(\"both\")")]
run: String,
/// Whether to use TLS
#[argp(switch, long = "tls")]
tls: bool,
/// Print debug info
#[argp(switch, short = 'd')]
debug: bool,
@ -75,15 +88,13 @@ struct OptPlay {
#[derive(FromArgs)]
#[argp(subcommand, name = "print")]
struct OptPrint {
/// Print packets
#[argp(switch, short = 'p')]
packets: bool,
/// Record number
#[argp(option, short = 'n')]
number: Option<u64>,
}
/// Record traffic
#[cfg(feature = "record")]
#[derive(FromArgs)]
#[argp(subcommand, name = "record")]
struct OptRecord {}
@ -108,21 +119,6 @@ struct OptRemove {
packet_number: usize,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum RunMode {
Client,
Server,
Both,
}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TlsMode {
None,
Client,
Server,
Both,
}
static RECORDS: StaticCell<Records> = StaticCell::new();
#[tokio::main]
@ -131,140 +127,49 @@ async fn main() {
let opt: Opt = argp::parse_args_or_exit(argp::DEFAULT);
match opt.subcommand {
Subcommand::Play(subopt) => {
let tls_mode = match subopt.tls.as_str() {
"none" => TlsMode::None,
"client" => TlsMode::Client,
"server" => TlsMode::Server,
"both" => TlsMode::Both,
_ => panic!("TLS mode must be one of none,client,server,both."),
};
let run_mode = match subopt.run.as_str() {
"client" => RunMode::Client,
"server" => RunMode::Server,
"both" => RunMode::Both,
_ => panic!("run mode must be one of client,server,both."),
};
Subcommand::Client(subopt) => {
let records = RECORDS.init(record::read_record_file(&opt.record_file));
if let Some(only_record) = subopt.record {
records.retain(|id, _| *id == only_record);
}
let mut ciphers: Option<Vec<String>> = None;
let mut kexes: Option<Vec<String>> = None;
for (var, val) in std::env::vars() {
match var.as_str() {
"CIPHERS" => ciphers = Some(val.split(',').map(str::to_string).collect()),
"KEXES" => kexes = Some(val.split(',').map(str::to_string).collect()),
_ => {}
}
}
let mut prov = tokio_rustls::rustls::crypto::aws_lc_rs::default_provider();
if let Some(ciphers) = ciphers {
prov.cipher_suites.clear();
for cipher in ciphers {
match cipher.as_str() {
"AES_256_GCM_SHA384" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_AES_256_GCM_SHA384),
"AES_128_GCM_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_AES_128_GCM_SHA256),
"CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256),
"ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384),
"ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256),
"ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256),
"ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384),
"ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256),
"ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256),
other => {
println!("Unknown cipher `{other}`")
}
}
}
}
if let Some(kexes) = kexes {
prov.kx_groups.clear();
for kex in kexes {
match kex.as_str() {
"X25519" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::X25519),
"SECP256R1" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP256R1),
"SECP384R1" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP384R1),
"X25519MLKEM768" => prov.kx_groups.push(
tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::X25519MLKEM768,
),
"SECP256R1MLKEM768" => prov.kx_groups.push(
tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP256R1MLKEM768,
),
"MLKEM768" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::MLKEM768),
other => {
println!("Unknown kex `{other}`")
}
}
}
}
CryptoProvider::install_default(prov).unwrap();
util::init_provider();
let (sync_sender, sync_receiver) = oneshot::channel();
//console_subscriber::init();
let client = tokio::spawn({
let records = &*records;
async move {
if run_mode == RunMode::Both || run_mode == RunMode::Client {
client::play(
records,
tls_mode,
(subopt.forward_addr, subopt.forward_port),
sync_receiver,
subopt.repeat,
subopt.debug,
)
.await;
} else {
std::future::pending().await
}
}
});
if run_mode == RunMode::Both || run_mode == RunMode::Server {
server::play(
records,
tls_mode,
&subopt.certs,
("0.0.0.0", subopt.listen_port),
sync_sender,
subopt.debug,
)
.await;
client::play(
records,
subopt.tls,
(subopt.connect_addr, subopt.connect_port),
subopt.repeat,
subopt.debug,
)
.await;
if let Some(notify_addr) = subopt.notify_addr {
let socket = std::net::UdpSocket::bind("0.0.0.0:48567").unwrap();
socket.send_to(b"done", &notify_addr).unwrap();
}
client.await.unwrap();
}
Subcommand::Server(subopt) => {
let records = RECORDS.init(record::read_record_file(&opt.record_file));
util::init_provider();
//console_subscriber::init();
server::play(
records,
subopt.tls,
&subopt.certs,
("0.0.0.0", subopt.listen_port),
subopt.debug,
)
.await;
}
Subcommand::Print(subopt) => {
let records = record::read_record_file(&opt.record_file);
record::print_records(&records, subopt.packets, subopt.number);
record::print_records(&records, subopt.number);
}
#[cfg(feature = "record")]
Subcommand::Record(_subopt) => {
record::make_record(&opt.record_file);
}

View file

@ -4,12 +4,10 @@ use std::{
sync::mpsc::{Receiver, Sender, channel},
};
use crate::util::{ResponseStreamer, print_bin};
const CLIENT_TO_SERVER: u8 = b'C';
const SERVER_TO_CLIENT: u8 = b'S';
pub type Records = BTreeMap<u64, (Vec<u8>, Vec<(Direction, Vec<u8>)>)>;
pub type Records = BTreeMap<u64, (Vec<u8>, Vec<(u64, Direction, u64)>)>;
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Direction {
@ -34,7 +32,7 @@ fn write_record(
direction: Direction,
conn_id: u64,
server_name: &str,
data: &[u8],
len: u64,
) {
let server_name = server_name.as_bytes();
file.write_all(&[match direction {
@ -45,8 +43,7 @@ fn write_record(
file.write_all(&conn_id.to_be_bytes()).unwrap();
file.write_all(&[server_name.len() as u8]).unwrap();
file.write_all(server_name).unwrap();
file.write_all(&(data.len() as u64).to_be_bytes()).unwrap();
file.write_all(&data).unwrap();
file.write_all(&len.to_be_bytes()).unwrap();
file.flush().unwrap();
}
@ -78,17 +75,25 @@ impl Recorder {
let Some(server_name) = server_name else {
continue;
};
write_record(&mut self.file, direction, conn_id, &server_name, &data);
write_record(
&mut self.file,
direction,
conn_id,
&server_name,
data.len() as u64,
);
}
}
}
#[cfg(feature = "record")]
#[derive(Clone)]
struct Handler {
sender: Sender<(u64, Option<String>, Direction, Vec<u8>)>,
server_name: Option<String>,
}
#[cfg(feature = "record")]
impl sslrelay::HandlerCallbacks for Handler {
// DownStream non blocking callback
fn ds_nb_callback(&self, in_data: Vec<u8>, conn_id: u64) {
@ -129,6 +134,7 @@ impl sslrelay::HandlerCallbacks for Handler {
}
}
#[cfg(feature = "record")]
pub fn make_record(path: &str) {
let (mut recorder, sender) = Recorder::new(path);
let mut relay = sslrelay::SSLRelay::new(
@ -161,7 +167,7 @@ pub fn make_record(path: &str) {
pub fn read_record_file(path: &str) -> Records {
let mut file = std::fs::OpenOptions::new().read(true).open(path).unwrap();
let mut records = BTreeMap::<u64, (Vec<u8>, Vec<(Direction, Vec<u8>)>)>::new();
let mut records = BTreeMap::<u64, (Vec<u8>, Vec<(u64, Direction, u64)>)>::new();
loop {
let mut direction = [0; 1];
if file.read(&mut direction).unwrap() != 1 {
@ -202,84 +208,39 @@ pub fn read_record_file(path: &str) -> Records {
println!("Error: len too large {len}. stop.");
break;
}
let mut buf = vec![0; len as usize];
if file.read(&mut buf).unwrap() != len as usize {
println!("Error: incomplete data. stop.");
break;
}
// Replace URL with unique id, to allow for better tracking by making each request unique.
// (proxy may modify some headers, but not the URL)
let mut insert_id = |req_id| {
if direction == Direction::ClientToServer {
let mut spaces = buf
.iter()
.enumerate()
.filter_map(|(i, c)| if *c == b' ' { Some(i) } else { None });
let s1 = spaces.next().unwrap();
let s2 = spaces.next().unwrap();
let new_url = format!("/{conn_id}-{req_id}/");
if s2 - s1 - 1 < new_url.len() {
// Not optimal but good enough
let mut new_buf = Vec::new();
new_buf.extend_from_slice(&buf[0..s1 + 1]);
new_buf.extend_from_slice(new_url.as_bytes());
new_buf.extend_from_slice(&buf[s2..]);
buf = new_buf;
} else {
buf[s1 + 1..s2][0..new_url.len()].copy_from_slice(new_url.as_bytes());
}
}
};
match records.entry(conn_id) {
btree_map::Entry::Occupied(mut entry) => {
(insert_id)(entry.get().1.len());
entry.get_mut().1.push((direction, buf));
let req_id = entry.get().1.len() as u64;
entry.get_mut().1.push((req_id, direction, len));
}
btree_map::Entry::Vacant(entry) => {
(insert_id)(0);
entry.insert((server_name, vec![(direction, buf)]));
let req_id = 0;
entry.insert((server_name, vec![(req_id, direction, len)]));
}
}
}
records
}
pub fn print_records(records: &Records, print_packets: bool, number: Option<u64>) {
for (id, (server_name, records)) in records {
pub fn print_records(records: &Records, number: Option<u64>) {
for (conn_id, (server_name, records)) in records {
if let Some(number) = number
&& number != *id
&& number != *conn_id
{
continue;
}
let server_name = str::from_utf8(server_name.as_slice()).unwrap();
println!("{id} {server_name}");
for (direction, data) in records {
println!("{conn_id} {server_name}");
for (req_id, direction, len) in records {
match direction {
Direction::ClientToServer => {
println!(" >> {}", data.len());
println!(" ({req_id}) >> {len}");
}
Direction::ServerToClient => {
println!(" << {}", data.len());
println!(" ({req_id}) << {len}");
}
}
if print_packets {
/*let data_tr = if data.len() >= 256 && *direction == Direction::ServerToClient {
&data[0..256]
} else {
data.as_slice()
};
if let Ok(data_tr) = str::from_utf8(data_tr) {
println!(" {data_tr:?}")
} else {
println!(" {data_tr:?}")
}
if let Some(header_end) = memchr::memmem::find(data, b"\r\n\r\n") {
println!(" --> body len: {}", data.len() - header_end - 4);
}*/
print_bin(&data[0..data.len().min(8192)]);
}
}
}
}
@ -292,7 +253,13 @@ pub fn make_test_record(path: &str) {
.open(path)
.unwrap();
for (conn_id, server_name, direction, data) in TEST_RECORD {
write_record(&mut file, *direction, *conn_id, *server_name, *data);
write_record(
&mut file,
*direction,
*conn_id,
server_name,
data.len() as u64,
);
}
}
@ -311,9 +278,9 @@ pub fn remove_record(
.unwrap();
for (conn_id, (server_name, packets)) in records.into_iter() {
let server_name = String::from_utf8(server_name).unwrap();
for (packet_id, (direction, data)) in packets.into_iter().enumerate() {
for (packet_id, (_req_id, direction, len)) in packets.into_iter().enumerate() {
if conn_id != record_to_remove || packet_id != packet_to_remove {
write_record(&mut output_file, direction, conn_id, &server_name, &data);
write_record(&mut output_file, direction, conn_id, &server_name, len);
}
}
}

View file

@ -1,12 +1,7 @@
use crate::{
TlsMode,
record::{Direction, Records},
util::print_bin,
};
use crate::record::{Direction, Records};
use futures_util::stream::StreamExt;
use std::{collections::HashMap, sync::Arc};
use tokio::{io::AsyncWriteExt, net::TcpListener, sync::oneshot};
use tokio::{io::AsyncWriteExt, net::TcpListener};
use tokio_rustls::rustls::{
pki_types::{
CertificateDer, PrivateKeyDer,
@ -15,414 +10,262 @@ use tokio_rustls::rustls::{
server::ResolvesServerCertUsingSni,
sign::CertifiedKey,
};
use tokio_util::codec::Framed;
use x509_parser::prelude::GeneralName;
pub async fn play(
records: &'static Records,
tls_mode: TlsMode,
use_tls: bool,
cert_path: &str,
listen_addr: (&str, u16),
sync_sender: oneshot::Sender<()>,
debug: bool,
_debug: bool,
) {
let mut response_map = HashMap::new();
for (id, (server_name, records)) in records.iter() {
let mut hash = None;
for (conn_id, (_server_name, records)) in records.iter() {
let mut last_client_req_id = None;
let mut responses = Vec::new();
for (direction, data) in records {
for (req_id, direction, len) in records {
match direction {
Direction::ClientToServer => {
if let Some(hash) = hash
if let Some(last_client_req_id) = last_client_req_id
&& !responses.is_empty()
{
response_map.insert((server_name.to_vec(), hash), (id, responses, false));
response_map.insert((*conn_id, last_client_req_id), (responses, false));
responses = Vec::new();
}
let mut slashes = data
.iter()
.enumerate()
.filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None });
let s1 = slashes.next();
let s2 = slashes.next();
hash = Some(if let (Some(s1), Some(s2)) = (s1, s2) {
data[s1 + 1..s2].to_vec()
} else {
panic!("Did not find URL: {:?}", &data[0..256]);
tlsh::hash_buf(data)
.map_or_else(|_| data.clone(), |h| h.to_string().into_bytes())
});
last_client_req_id = Some(*req_id);
}
Direction::ServerToClient => {
responses.push(data);
responses.push((*req_id, *len));
}
}
}
if let Some(hash) = hash {
if let Some(last_client_req_id) = last_client_req_id {
if !responses.is_empty() {
response_map.insert((server_name.to_vec(), hash), (id, responses, true));
} else {
response_map
.get_mut(&(server_name.to_vec(), hash))
.unwrap()
.2 = true;
response_map.insert((*conn_id, last_client_req_id), (responses, true));
} else if let Some(entry) = response_map.get_mut(&(*conn_id, last_client_req_id)) {
entry.1 = true;
}
}
}
let response_map = Arc::new(response_map);
let dummy_bytes = Arc::new(vec![0x42u8; 16 * 1024 * 1024]);
match tls_mode {
TlsMode::Both | TlsMode::Server => {
let mut resolver = ResolvesServerCertUsingSni::new();
let mut config = tokio_rustls::rustls::ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(ResolvesServerCertUsingSni::new()));
config.max_early_data_size = 8192;
for file in std::fs::read_dir(cert_path).unwrap_or_else(|e| {
panic!("Cannot read certificate directory `{cert_path}`: {e:?}")
}) {
match file {
Ok(file) => {
if file.file_name().as_encoded_bytes().ends_with(b".crt") {
for section in
<(pem::SectionKind, Vec<u8>) as PemObject>::pem_file_iter(
file.path(),
)
if use_tls {
let mut resolver = ResolvesServerCertUsingSni::new();
let mut config = tokio_rustls::rustls::ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(ResolvesServerCertUsingSni::new()));
config.max_early_data_size = 8192;
for file in std::fs::read_dir(cert_path)
.unwrap_or_else(|e| panic!("Cannot read certificate directory `{cert_path}`: {e:?}"))
{
match file {
Ok(file) => {
if file.file_name().as_encoded_bytes().ends_with(b".crt") {
for section in
<(pem::SectionKind, Vec<u8>) as PemObject>::pem_file_iter(file.path())
.unwrap()
{
let (kind, data) = section.unwrap();
if kind == SectionKind::Certificate {
let (_rem, cert) =
x509_parser::parse_x509_certificate(&data).unwrap();
if !cert.is_ca() {
//println!("File: {:?}", file.file_name());
let mut key_path = file.path().to_path_buf();
key_path.pop();
let file_name =
file.file_name().to_str().unwrap().to_string();
let mut key_file_name =
file_name[0..file_name.len() - 4].to_string();
key_file_name.push_str(".key");
let key = PrivateKeyDer::from_pem_file(
key_path.join(key_file_name),
)
.unwrap();
let key = config
.crypto_provider()
.key_provider
.load_private_key(key)
{
let (kind, data) = section.unwrap();
if kind == SectionKind::Certificate {
let (_rem, cert) =
x509_parser::parse_x509_certificate(&data).unwrap();
if !cert.is_ca() {
//println!("File: {:?}", file.file_name());
let mut key_path = file.path().to_path_buf();
key_path.pop();
let file_name = file.file_name().to_str().unwrap().to_string();
let mut key_file_name =
file_name[0..file_name.len() - 4].to_string();
key_file_name.push_str(".key");
let key =
PrivateKeyDer::from_pem_file(key_path.join(key_file_name))
.unwrap();
// This wants static lifetime...
let cert_key = CertifiedKey::new(
vec![CertificateDer::from_slice(Box::leak(
data.to_vec().into_boxed_slice(),
))],
key,
);
for name in cert
.subject_alternative_name()
.unwrap()
.unwrap()
.value
.general_names
.iter()
{
if let GeneralName::DNSName(name) = name {
resolver.add(name, cert_key.clone()).ok();
}
let key = config
.crypto_provider()
.key_provider
.load_private_key(key)
.unwrap();
// This wants static lifetime...
let cert_key = CertifiedKey::new(
vec![CertificateDer::from_slice(Box::leak(
data.to_vec().into_boxed_slice(),
))],
key,
);
for name in cert
.subject_alternative_name()
.unwrap()
.unwrap()
.value
.general_names
.iter()
{
if let GeneralName::DNSName(name) = name {
resolver.add(name, cert_key.clone()).ok();
}
}
}
}
}
}
Err(e) => eprintln!("Error listing cert directory: {e:?}"),
}
Err(e) => eprintln!("Error listing cert directory: {e:?}"),
}
}
// Config requires resolver, keys can be added to resolver, creating a key requires config. WTF!?
// So we have to re-create config.
let mut config = tokio_rustls::rustls::ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver));
config.max_early_data_size = 8192;
config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new());
let config = Arc::new(config);
// Config requires resolver, keys can be added to resolver, creating a key requires config. WTF!?
// So we have to re-create config.
let mut config = tokio_rustls::rustls::ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver));
config.max_early_data_size = 8192;
config.key_log = Arc::new(tokio_rustls::rustls::KeyLogFile::new());
let config = Arc::new(config);
let listener = TcpListener::bind(listen_addr).await.unwrap();
sync_sender.send(()).unwrap();
loop {
let config = config.clone();
let (stream, _peer_addr) = listener.accept().await.unwrap();
let acceptor = tokio_rustls::LazyConfigAcceptor::new(
tokio_rustls::rustls::server::Acceptor::default(),
stream,
);
//let acceptor = acceptor.clone();
let response_map = response_map.clone();
/*let fut = async move {
let accepted = acceptor.await.unwrap();
let server_name = accepted.client_hello().server_name().unwrap().to_string();
let mut stream = accepted.into_stream(config).await.unwrap();
let mut req = Vec::new();
http::decode_http(&mut req, &mut stream).await;
let req_hash = tlsh::hash_buf(&req)
.map_or_else(|_| req.clone(), |h| h.to_string().into_bytes());
let mut best = None;
for (i_server_name, hash) in response_map.keys() {
if i_server_name != server_name.as_bytes() {
let listener = TcpListener::bind(listen_addr).await.unwrap();
loop {
let config = config.clone();
let (stream, _peer_addr) = listener.accept().await.unwrap();
let acceptor = tokio_rustls::LazyConfigAcceptor::new(
tokio_rustls::rustls::server::Acceptor::default(),
stream,
);
//let acceptor = acceptor.clone();
let response_map = response_map.clone();
let dummy_bytes = dummy_bytes.clone();
let fut = async move {
let accepted = acceptor.await.unwrap();
let server_name = accepted
.client_hello()
.server_name()
.unwrap()
.trim_end_matches(".localhost")
.to_string();
let stream = accepted
.into_stream(config)
.await
.map_err(|e| panic!("{e:?} with name `{server_name}`"))
.unwrap();
let mut stream = crate::codec::StreamCodec::new(stream);
let mut break_next = false;
//let mut previous = Vec::new();
loop {
let Ok(req) =
tokio::time::timeout(tokio::time::Duration::from_secs(5), stream.next())
.await
else {
if break_next {
break;
} else {
continue;
}
let diff = compare(&req_hash, hash);
if let Some((best_hash, best_diff)) = &mut best {
if diff < *best_diff {
*best_hash = hash;
*best_diff = diff;
}
} else {
best = Some((hash, diff));
}
};
let req = req.unwrap();
if req.len() < 8 {
println!("Invalid request");
break;
}
if let Some((hash, _diff)) = best {
let (id, responses) = response_map
.get(&(server_name.as_bytes().to_vec(), hash.clone()))
.unwrap();
for &res in responses {
println!("[SRV] response for ({}): {} bytes", id, res.len());
stream.write_all(res).await.unwrap();
let _expected_len = u32::from_be_bytes(req[0..4].try_into().unwrap()) as u64;
let conn_id = u16::from_be_bytes(req[4..6].try_into().unwrap()) as u64;
let req_id = u16::from_be_bytes(req[6..8].try_into().unwrap()) as u64;
//println!("REQUEST");
//print_bin(&req);
//previous = req.clone();
let stream = stream.get_mut();
if let Some((responses, last)) = response_map.get(&(conn_id, req_id)) {
//dbg!(id);
for (req_id, len) in responses {
//println!("[SRV] response for ({}): {} bytes", id, res.len());
let mut data = dummy_bytes[0..*len as usize].to_vec();
data[0..4].copy_from_slice(&(*len as u32).to_be_bytes());
data[4..6].copy_from_slice(&(conn_id as u16).to_be_bytes());
data[6..8].copy_from_slice(&(*req_id as u16).to_be_bytes());
stream.write_all(&data).await.unwrap();
stream.flush().await.unwrap();
}
} else {
println!("No response found for SNI=`{server_name}`");
}
stream.shutdown().await.unwrap();
};*/
let fut = async move {
let accepted = acceptor.await.unwrap();
let server_name = accepted
.client_hello()
.server_name()
.unwrap()
.trim_end_matches(".localhost")
.to_string();
let stream = accepted
.into_stream(config)
.await
.map_err(|e| panic!("{e:?} with name `{server_name}`"))
.unwrap();
let mut stream = Framed::new(stream, crate::http::HttpServerCodec::new());
let mut break_next = false;
//let mut previous = Vec::new();
loop {
let Ok(req) = tokio::time::timeout(tokio::time::Duration::from_secs(1), stream.next()).await else {
if break_next {
break;
} else {
continue;
}
};
let Some(req) = req else {
if *last {
break_next = true;
break;
};
let req = req.unwrap();
//println!("REQUEST");
//print_bin(&req);
let req_hash = {
let mut slashes = req
.iter()
.enumerate()
.filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None });
let s1 = slashes.next();
let s2 = slashes.next();
if let (Some(s1), Some(s2)) = (s1, s2) {
req[s1 + 1..s2].to_vec()
} else {
//println!("Previous: {:?}", &previous);
println!("Did not find URL: {:?}", &req[0..req.len().min(255)]);
tlsh::hash_buf(&req)
.map_or_else(|_| req.clone(), |h| h.to_string().into_bytes())
}
};
//previous = req.clone();
let mut best = None;
for (i_server_name, hash) in response_map.keys() {
if i_server_name != server_name.as_bytes() {
continue;
}
let diff = if &req_hash == hash {
0
} else {
compare(&req_hash, hash)
};
if let Some((best_hash, best_diff)) = &mut best {
if diff < *best_diff {
*best_hash = hash;
*best_diff = diff;
}
} else {
best = Some((hash, diff));
}
}
let stream = stream.get_mut();
if let Some((hash, _diff)) = best {
let (id, responses, last) = response_map
.get(&(server_name.as_bytes().to_vec(), hash.clone()))
.unwrap();
//dbg!(id);
for &res in responses {
//println!("[SRV] response for ({}): {} bytes", id, res.len());
stream.write_all(res).await.unwrap();
stream.flush().await.unwrap();
}
if *last {
break_next = true;
}
} else {
println!("No response found for SNI=`{server_name}`");
}
}
stream.get_mut().shutdown().await.unwrap();
};
tokio::spawn(async move {
fut.await;
});
}
}
TlsMode::None | TlsMode::Client => {
let listener = TcpListener::bind(listen_addr).await.unwrap_or_else(|e| {
println!("Server: Cannot listen: {e:?}");
std::process::exit(1)
});
sync_sender.send(()).unwrap();
loop {
let (stream, _peer_addr) = listener.accept().await.unwrap();
let response_map = response_map.clone();
/*let fut = async move {
println!("[SRV] New task");
let mut req = Vec::new();
http::decode_http(&mut req, &mut stream).await;
let req_hash = tlsh::hash_buf(&req)
.map_or_else(|_| req.clone(), |h| h.to_string().into_bytes());
let mut best = None;
for (i_server_name, hash) in response_map.keys() {
let diff = compare(&req_hash, hash);
if let Some((best_server_name, best_hash, best_diff)) = &mut best {
if diff < *best_diff {
*best_server_name = i_server_name;
*best_hash = hash;
*best_diff = diff;
}
} else {
best = Some((i_server_name, hash, diff));
}
}
if let Some((server_name, hash, _diff)) = best {
let (id, responses) = response_map
.get(&(server_name.clone(), hash.clone()))
.unwrap();
for &res in responses {
println!("[SRV] response for ({}): {} bytes", id, res.len());
stream.write_all(res).await.unwrap();
stream.flush().await.unwrap();
}
} else {
println!("[SRV] No response found");
println!("No response found for {conn_id}-{req_id} SNI=`{server_name}`");
}
//println!("Server shutdown");
stream.shutdown().await.unwrap();
};*/
let fut = async move {
//println!("[SRV] New task");
//let mut stream = crate::http::Framer::new(stream, crate::http::HttpServerCodec::new());
let mut stream = Framed::new(stream, crate::http::HttpServerCodec::new());
//let mut previous = Vec::new();
while let Some(req) = stream.next().await {
let req = req.unwrap();
//println!("REQUEST");
//print_bin(&req);
//println!("[SRV] << {}", str::from_utf8(&req[..req.len().min(255)]).unwrap());
let req_hash = {
let mut slashes = req
.iter()
.enumerate()
.filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None });
let s1 = slashes.next();
let s2 = slashes.next();
if let (Some(s1), Some(s2)) = (s1, s2) {
let uniq_id = req[s1 + 1..s2].to_vec();
if debug {
if let Ok(uniq_id) = str::from_utf8(&uniq_id) {
println!("[SRV] ({uniq_id}) << {}", req.len());
}
}
uniq_id
} else {
//println!("Previous: {:?}", &previous);
println!("Did not find URL: {:?}", &req[0..req.len().min(255)]);
tlsh::hash_buf(&req)
.map_or_else(|_| req.clone(), |h| h.to_string().into_bytes())
}
};
//previous = req.clone();
let mut best = None;
for (i_server_name, hash) in response_map.keys() {
let diff = compare(&req_hash, hash);
if let Some((best_server_name, best_hash, best_diff)) = &mut best {
if diff < *best_diff {
*best_server_name = i_server_name;
*best_hash = hash;
*best_diff = diff;
}
} else {
best = Some((i_server_name, hash, diff));
}
}
let stream = stream.get_mut();
if let Some((server_name, hash, _diff)) = best {
let (id, responses, last) = response_map
.get(&(server_name.clone(), hash.clone()))
.unwrap();
//dbg!(id);
for &res in responses {
if debug {
println!("[SRV] ({id}) >> {}", res.len());
//println!("[SRV] response for ({}): {} bytes", id, res.len());
}
stream.write_all(res).await.unwrap();
stream.flush().await.unwrap();
if debug {
println!("[SRV] ({id}) >> {} OK", res.len());
}
}
if *last {
//break;
}
} else {
println!("[SRV] No response found");
}
}
//println!("Server shutdown");
stream.get_mut().shutdown().await.unwrap();
};
// Using a variable for the future allows it to be detected by tokio-console
tokio::spawn(async move {
fut.await;
});
}
}
stream.get_mut().shutdown().await.unwrap();
};
tokio::spawn(async move {
fut.await;
});
}
}
}
fn compare(a: &[u8], b: &[u8]) -> u32 {
if let (Ok(a), Ok(b)) = (str::from_utf8(a), str::from_utf8(b)) {
if let Ok(diff) = tlsh::compare(a, b) {
return diff;
}
}
if a == b {
0
} else {
a.len().max(b.len()) as u32
let listener = TcpListener::bind(listen_addr).await.unwrap_or_else(|e| {
println!("Server: Cannot listen: {e:?}");
std::process::exit(1)
});
loop {
let (stream, _peer_addr) = listener.accept().await.unwrap();
let response_map = response_map.clone();
let dummy_bytes = dummy_bytes.clone();
let fut = async move {
//println!("[SRV] New task");
//let mut stream = Framed::new(stream, crate::codec::CustomCodec::new());
let mut stream = crate::codec::StreamCodec::new(stream);
let mut break_next = false;
//let mut previous = Vec::new();
loop {
let Ok(req) =
tokio::time::timeout(tokio::time::Duration::from_secs(5), stream.next())
.await
else {
if break_next {
println!("break timeout");
break;
} else {
println!("continue");
continue;
}
};
let req = req.unwrap();
if req.len() < 8 {
println!("Invalid request");
break;
}
let expected_len = u32::from_be_bytes(req[0..4].try_into().unwrap()) as u64;
let conn_id = u16::from_be_bytes(req[4..6].try_into().unwrap()) as u64;
let req_id = u16::from_be_bytes(req[6..8].try_into().unwrap()) as u64;
//println!("[SRV] ({conn_id}) << {expected_len}");
//println!("REQUEST");
//print_bin(&req);
//previous = req.clone();
let stream = stream.get_mut();
if let Some((responses, last)) = response_map.get(&(conn_id, req_id)) {
//dbg!(id);
for (req_id, len) in responses {
//println!("[SRV] ({conn_id}) >> {len}");
let mut data = dummy_bytes[0..*len as usize].to_vec();
data[0..4].copy_from_slice(&(*len as u32).to_be_bytes());
data[4..6].copy_from_slice(&(conn_id as u16).to_be_bytes());
data[6..8].copy_from_slice(&(*req_id as u16).to_be_bytes());
stream.write_all(&data).await.unwrap();
stream.flush().await.unwrap();
}
if *last {
break_next = true;
break;
}
} else {
println!("No response found for {conn_id}-{req_id}");
}
}
//println!("Server shutdown");
stream.get_mut().shutdown().await.unwrap();
};
// Using a variable for the future allows it to be detected by tokio-console
tokio::spawn(async move {
fut.await;
});
}
}
}

View file

@ -1,26 +1,8 @@
use std::iter::Peekable;
use crate::record::Direction;
fn hex_digit(c: u8) -> u32 {
((c & !(16 | 32 | 64)) + ((c & 64) >> 6) * 9) as _
}
pub fn parse_hex(s: &[u8]) -> u32 {
let mut r = 0;
for i in s.iter() {
r <<= 4;
r |= hex_digit(*i);
}
r
}
pub fn is_hex(s: &[u8]) -> bool {
s.iter().all(|c| {
let c = *c | 32;
(c >= b'a' && c <= b'f') || (c >= b'0' && c <= b'9')
})
}
use log::info;
use std::iter::Peekable;
use tokio_rustls::rustls::crypto::CryptoProvider;
/// Print ASCII if possible
pub fn print_bin(s: &[u8]) {
@ -46,85 +28,407 @@ pub fn print_bin(s: &[u8]) {
pub struct ResponseStreamer<I: Iterator>(Peekable<I>);
impl<'a, I: Iterator> ResponseStreamer<I> {
impl<I: Iterator> ResponseStreamer<I> {
pub fn new(inner: I) -> Self {
Self(inner.peekable())
}
}
impl<'a, I: Iterator<Item = &'a (Direction, Vec<u8>)>> Iterator for ResponseStreamer<I> {
type Item = (&'a Direction, Vec<&'a Vec<u8>>);
impl<'a, I: Iterator<Item = &'a (u64, Direction, u64)>> Iterator for ResponseStreamer<I> {
type Item = (Direction, Vec<(u64, u64)>);
fn next(&mut self) -> Option<Self::Item> {
let (direction, first_item) = self.0.next()?;
let mut items = vec![first_item];
while let Some((item_direction, _item)) = self.0.peek()
&& item_direction == direction
let (first_req_id, first_direction, first_len) = self.0.next()?;
let mut items = vec![(*first_req_id, *first_len)];
while let Some((_req_id, direction, _len)) = self.0.peek()
&& direction == first_direction
{
items.push(&self.0.next().unwrap().1);
let (req_id, _direction, len) = self.0.next().unwrap();
items.push((*req_id, *len));
}
Some((direction, items))
Some((*first_direction, items))
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_hex_digit() {
assert_eq!(hex_digit(b'0'), 0);
assert_eq!(hex_digit(b'1'), 1);
assert_eq!(hex_digit(b'2'), 2);
assert_eq!(hex_digit(b'3'), 3);
assert_eq!(hex_digit(b'4'), 4);
assert_eq!(hex_digit(b'5'), 5);
assert_eq!(hex_digit(b'6'), 6);
assert_eq!(hex_digit(b'7'), 7);
assert_eq!(hex_digit(b'8'), 8);
assert_eq!(hex_digit(b'9'), 9);
assert_eq!(hex_digit(b'a'), 10);
assert_eq!(hex_digit(b'b'), 11);
assert_eq!(hex_digit(b'c'), 12);
assert_eq!(hex_digit(b'd'), 13);
assert_eq!(hex_digit(b'e'), 14);
assert_eq!(hex_digit(b'f'), 15);
assert_eq!(hex_digit(b'A'), 10);
assert_eq!(hex_digit(b'B'), 11);
assert_eq!(hex_digit(b'C'), 12);
assert_eq!(hex_digit(b'D'), 13);
assert_eq!(hex_digit(b'E'), 14);
assert_eq!(hex_digit(b'F'), 15);
pub fn init_provider() {
let mut ciphers: Option<Vec<String>> = None;
let mut kexes: Option<Vec<String>> = None;
for (var, val) in std::env::vars() {
match var.as_str() {
"CIPHERS" => ciphers = Some(val.split(',').map(str::to_string).collect()),
"KEXES" => kexes = Some(val.split(',').map(str::to_string).collect()),
_ => {}
}
}
// Ensure multiple provider cannot be enabled without compile error.
let _provider;
#[cfg(feature = "aws-lc")]
{
info!("Using RusTLS provider aws-lc");
let mut prov = rustls_post_quantum::provider();
if let Some(ciphers) = ciphers {
prov.cipher_suites.clear();
for cipher in ciphers {
match cipher.as_str() {
"AES_256_GCM_SHA384" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_AES_256_GCM_SHA384),
"AES_128_GCM_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_AES_128_GCM_SHA256),
"CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256),
"ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384),
"ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256),
"ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256),
"ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384),
"ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256),
"ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::aws_lc_rs::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256),
other => {
log::error!("Unknown cipher `{other}`")
}
}
}
}
if let Some(kexes) = kexes {
prov.kx_groups.clear();
for kex in kexes {
match kex.as_str() {
"X25519" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::X25519),
"SECP256R1" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP256R1),
"SECP384R1" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP384R1),
"X25519MLKEM768" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::X25519MLKEM768),
"SECP256R1MLKEM768" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::SECP256R1MLKEM768),
"MLKEM768" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::aws_lc_rs::kx_group::MLKEM768),
other => {
log::error!("Unknown kex `{other}`")
}
}
}
}
_provider = CryptoProvider::install_default(prov);
}
#[cfg(feature = "boring")]
{
info!("Using RusTLS provider boring");
let mut prov = boring_rustls_provider::provider();
if let Some(ciphers) = ciphers {
prov.cipher_suites.clear();
for cipher in ciphers {
match cipher.as_str() {
"AES_256_GCM_SHA384" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls13(
&boring_rustls_provider::tls13::AES_256_GCM_SHA384,
)),
"AES_128_GCM_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls13(
&boring_rustls_provider::tls13::AES_128_GCM_SHA256,
)),
"CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls13(
&boring_rustls_provider::tls13::CHACHA20_POLY1305_SHA256,
)),
"ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12(
&boring_rustls_provider::tls12::ECDHE_ECDSA_AES256_GCM_SHA384,
)),
"ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12(
&boring_rustls_provider::tls12::ECDHE_ECDSA_AES128_GCM_SHA256,
)),
"ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12(
&boring_rustls_provider::tls12::ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
)),
"ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12(
&boring_rustls_provider::tls12::ECDHE_RSA_AES256_GCM_SHA384,
)),
"ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12(
&boring_rustls_provider::tls12::ECDHE_RSA_AES128_GCM_SHA256,
)),
"ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push(tokio_rustls::rustls::SupportedCipherSuite::Tls12(
&boring_rustls_provider::tls12::ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
)),
other => {
log::error!("Unknown cipher `{other}`")
}
}
}
}
if let Some(kexes) = kexes {
prov.kx_groups.clear();
for kex in kexes {
match kex.as_str() {
"X25519" => prov
.kx_groups
.push(boring_rustls_provider::ALL_KX_GROUPS[0]),
"SECP256R1" => prov
.kx_groups
.push(boring_rustls_provider::ALL_KX_GROUPS[2]),
"SECP384R1" => prov
.kx_groups
.push(boring_rustls_provider::ALL_KX_GROUPS[3]),
other => {
log::error!("Unknown kex `{other}`")
}
}
}
}
_provider = CryptoProvider::install_default(prov);
}
#[test]
fn test_parse_hex() {
assert_eq!(parse_hex(b"abc123"), 0xabc123);
assert_eq!(parse_hex(b"1"), 1);
#[cfg(feature = "graviola")]
{
info!("Using RusTLS provider graviola");
let mut prov = rustls_graviola::default_provider();
if let Some(ciphers) = ciphers {
prov.cipher_suites.clear();
for cipher in ciphers {
match cipher.as_str() {
"AES_256_GCM_SHA384" => prov
.cipher_suites
.push(rustls_graviola::suites::TLS13_AES_256_GCM_SHA384),
"AES_128_GCM_SHA256" => prov
.cipher_suites
.push(rustls_graviola::suites::TLS13_AES_128_GCM_SHA256),
"CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(rustls_graviola::suites::TLS13_CHACHA20_POLY1305_SHA256),
"ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(rustls_graviola::suites::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384),
"ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(rustls_graviola::suites::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256),
"ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push(
rustls_graviola::suites::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
),
"ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(rustls_graviola::suites::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384),
"ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(rustls_graviola::suites::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256),
"ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(rustls_graviola::suites::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256),
other => {
log::error!("Unknown cipher `{other}`")
}
}
}
}
if let Some(kexes) = kexes {
prov.kx_groups.clear();
for kex in kexes {
match kex.as_str() {
"X25519" => prov.kx_groups.push(&rustls_graviola::kx::X25519),
"SECP256R1" => prov.kx_groups.push(&rustls_graviola::kx::P256),
"SECP384R1" => prov.kx_groups.push(&rustls_graviola::kx::P384),
"X25519MLKEM768" => prov.kx_groups.push(rustls_graviola::kx::X25519MLKEM768),
other => {
log::error!("Unknown kex `{other}`")
}
}
}
}
_provider = CryptoProvider::install_default(prov);
}
#[test]
fn test_is_hex() {
assert!(is_hex(b"0"));
assert!(is_hex(b"1"));
assert!(is_hex(b"2"));
assert!(is_hex(b"3"));
assert!(is_hex(b"4"));
assert!(is_hex(b"5"));
assert!(is_hex(b"6"));
assert!(is_hex(b"7"));
assert!(is_hex(b"8"));
assert!(is_hex(b"9"));
assert!(is_hex(b"a"));
assert!(is_hex(b"b"));
assert!(is_hex(b"c"));
assert!(is_hex(b"d"));
assert!(is_hex(b"e"));
assert!(is_hex(b"f"));
assert!(is_hex(b"A"));
assert!(is_hex(b"B"));
assert!(is_hex(b"C"));
assert!(is_hex(b"D"));
assert!(is_hex(b"E"));
assert!(is_hex(b"F"));
#[cfg(feature = "openssl")]
{
info!("Using RusTLS provider openssl");
let mut prov = rustls_openssl::default_provider();
if let Some(ciphers) = ciphers {
prov.cipher_suites.clear();
for cipher in ciphers {
match cipher.as_str() {
"AES_256_GCM_SHA384" => prov
.cipher_suites
.push(rustls_openssl::cipher_suite::TLS13_AES_256_GCM_SHA384),
"AES_128_GCM_SHA256" => prov
.cipher_suites
.push(rustls_openssl::cipher_suite::TLS13_AES_128_GCM_SHA256),
"CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(rustls_openssl::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256),
"ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov.cipher_suites.push(
rustls_openssl::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
),
"ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov.cipher_suites.push(
rustls_openssl::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
),
"ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push(
rustls_openssl::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
),
"ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(rustls_openssl::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384),
"ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(rustls_openssl::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256),
"ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov.cipher_suites.push(
rustls_openssl::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
),
other => {
log::error!("Unknown cipher `{other}`")
}
}
}
}
if let Some(kexes) = kexes {
prov.kx_groups.clear();
for kex in kexes {
match kex.as_str() {
"X25519" => prov.kx_groups.push(rustls_openssl::kx_group::X25519),
"SECP256R1" => prov.kx_groups.push(rustls_openssl::kx_group::SECP256R1),
"SECP384R1" => prov.kx_groups.push(rustls_openssl::kx_group::SECP384R1),
"X25519MLKEM768" => prov
.kx_groups
.push(rustls_openssl::kx_group::X25519MLKEM768),
"MLKEM768" => prov.kx_groups.push(rustls_openssl::kx_group::MLKEM768),
other => {
log::error!("Unknown kex `{other}`")
}
}
}
}
_provider = CryptoProvider::install_default(prov);
}
#[cfg(feature = "ring")]
{
info!("Using RusTLS provider ring");
let mut prov = tokio_rustls::rustls::crypto::ring::default_provider();
if let Some(ciphers) = ciphers {
prov.cipher_suites.clear();
for cipher in ciphers {
match cipher.as_str() {
"AES_256_GCM_SHA384" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS13_AES_256_GCM_SHA384),
"AES_128_GCM_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS13_AES_128_GCM_SHA256),
"CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS13_CHACHA20_POLY1305_SHA256),
"ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384),
"ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256),
"ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256),
"ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384),
"ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256),
"ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(tokio_rustls::rustls::crypto::ring::cipher_suite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256),
other => {
log::error!("Unknown cipher `{other}`")
}
}
}
}
if let Some(kexes) = kexes {
prov.kx_groups.clear();
for kex in kexes {
match kex.as_str() {
"X25519" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::ring::kx_group::X25519),
"SECP256R1" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::ring::kx_group::SECP256R1),
"SECP384R1" => prov
.kx_groups
.push(tokio_rustls::rustls::crypto::ring::kx_group::SECP384R1),
other => {
log::error!("Unknown kex `{other}`")
}
}
}
}
_provider = CryptoProvider::install_default(prov);
}
#[cfg(feature = "symcrypt")]
{
info!("Using RusTLS provider symcrypt");
let mut prov = rustls_symcrypt::default_symcrypt_provider();
if let Some(ciphers) = ciphers {
prov.cipher_suites.clear();
for cipher in ciphers {
match cipher.as_str() {
"AES_256_GCM_SHA384" => prov
.cipher_suites
.push(rustls_symcrypt::TLS13_AES_256_GCM_SHA384),
"AES_128_GCM_SHA256" => prov
.cipher_suites
.push(rustls_symcrypt::TLS13_AES_128_GCM_SHA256),
"CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(rustls_symcrypt::TLS13_CHACHA20_POLY1305_SHA256),
"ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(rustls_symcrypt::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384),
"ECDHE_ECDSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(rustls_symcrypt::TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256),
"ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(rustls_symcrypt::TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256),
"ECDHE_RSA_WITH_AES_256_GCM_SHA384" => prov
.cipher_suites
.push(rustls_symcrypt::TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384),
"ECDHE_RSA_WITH_AES_128_GCM_SHA256" => prov
.cipher_suites
.push(rustls_symcrypt::TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256),
"ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256" => prov
.cipher_suites
.push(rustls_symcrypt::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256),
other => {
log::error!("Unknown cipher `{other}`")
}
}
}
}
if let Some(kexes) = kexes {
prov.kx_groups.clear();
for kex in kexes {
match kex.as_str() {
"X25519" => prov.kx_groups.push(rustls_symcrypt::X25519),
"SECP256R1" => prov.kx_groups.push(rustls_symcrypt::SECP256R1),
"SECP384R1" => prov.kx_groups.push(rustls_symcrypt::SECP384R1),
other => {
log::error!("Unknown kex `{other}`")
}
}
}
}
_provider = CryptoProvider::install_default(prov);
}
}