From 5656af0781658da20301d604b8aac7e5703ac4a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Eng=C3=A9libert?= Date: Thu, 6 Nov 2025 17:27:16 +0100 Subject: [PATCH] Fix more bugs --- Cargo.toml | 2 +- src/client.rs | 103 +++++++++++++++++++++++++++++++++++++++----------- src/http.rs | 3 -- src/main.rs | 2 - 4 files changed, 82 insertions(+), 28 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 52929e0..e3ca750 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,7 @@ console-subscriber = "0.5.0" futures-util = "0.3.31" memchr = "2.7.6" regex = "1.12.2" -sslrelay = { path = "../sslrelay-lib" } +sslrelay = { path = "../sslrelay" } static_cell = "2.1.1" tlsh = { package = "fast-tlsh", version = "0.1.10", features = ["easy-functions"] } tokio = { version = "1.48.0", features = ["io-util", "macros", "net", "rt", "rt-multi-thread", "sync", "time", "tracing"]} diff --git a/src/client.rs b/src/client.rs index c2aef79..8f8f6c0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -5,13 +5,14 @@ use crate::{ use futures_util::StreamExt; use std::{ + collections::HashSet, net::ToSocketAddrs, sync::{Arc, atomic::AtomicU32}, }; use tokio::{ io::AsyncWriteExt, net::TcpStream, - sync::{Semaphore, oneshot}, + sync::{Mutex, Semaphore, oneshot}, }; use tokio_rustls::{ TlsConnector, @@ -91,6 +92,7 @@ pub async fn play( // Its handle is released when the task panics. let limiter = Arc::new(Semaphore::new(32)); let counter = Arc::new(AtomicU32::new(0)); + let running = Arc::new(Mutex::new(HashSet::new())); let total = records.len() * repeat as usize; let mut handles = Vec::new(); let connect_to = connect_to.to_socket_addrs().unwrap().next().unwrap(); @@ -106,12 +108,16 @@ pub async fn play( let connector = TlsConnector::from(config.clone()); let counter = counter.clone(); let limiter = limiter.clone(); + let running = running.clone(); handles.push(tokio::spawn(async move { + let mut running_guard = running.lock().await; + running_guard.insert(*id); + drop(running_guard); let limiter = limiter.acquire().await.unwrap(); let server_name = ServerName::try_from(String::from_utf8(server_name.clone()).unwrap()) .unwrap(); - for _i in 0..repeat { + 'repeat: for _i in 0..repeat { let stream = TcpStream::connect(connect_to).await.unwrap(); let stream = connector .connect(server_name.clone(), stream) @@ -122,7 +128,16 @@ pub async fn play( match direction { Direction::ClientToServer => { println!("[CLT] ({id}) >> {}", data.len()); - stream.get_mut().write_all(data).await.unwrap(); + //stream.get_mut().write_all(data).await.unwrap(); + match tokio::time::timeout( + std::time::Duration::from_millis(1000), + stream.get_mut().write_all(data), + ) + .await + { + Ok(v) => v.unwrap(), + Err(_e) => continue 'repeat, + } } Direction::ServerToClient => { println!("[CLT] ({id}) << {}", data.len()); @@ -130,24 +145,38 @@ pub async fn play( // stream.read_buf(&mut buf).await.ok(); //let mut buf = vec![0; data.len().saturating_sub(50).max(1)]; //let resp = stream.next().await.unwrap().unwrap(); - let resp = tokio::time::timeout( - std::time::Duration::from_millis(500), + match tokio::time::timeout( + std::time::Duration::from_millis(1000), stream.next(), ) .await - .unwrap() - .unwrap() - .unwrap(); - dbg!(resp.len()); + { + Ok(v) => v.unwrap().unwrap(), + Err(_e) => { + // TODO fix + break 'repeat; + } + }; + //dbg!(resp.len()); //crate::http::decode_http(&mut buf, &mut stream).await; } } } - stream.get_mut().shutdown().await.unwrap(); + //stream.get_mut().shutdown().await.unwrap(); + tokio::time::timeout( + std::time::Duration::from_millis(1000), + stream.get_mut().shutdown(), + ) + .await + .unwrap() + .unwrap(); let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); println!("Client: {} / {}", cnt + 1, total); } drop(limiter); + let mut running_guard = running.lock().await; + running_guard.remove(id); + drop(running_guard); })); //tokio::time::sleep(std::time::Duration::from_millis(500)).await; } @@ -159,12 +188,14 @@ pub async fn play( }*/ let counter = counter.clone(); let limiter = limiter.clone(); + let running = running.clone(); handles.push(tokio::spawn(async move { - dbg!(limiter.available_permits()); + let mut running_guard = running.lock().await; + running_guard.insert(*id); + drop(running_guard); let limiter = limiter.acquire().await.unwrap(); //let mut buf = Vec::new(); - for _i in 0..repeat { - dbg!(); + 'repeat: for _i in 0..repeat { let stream = TcpStream::connect(connect_to).await.unwrap(); let mut stream = Framed::new(stream, crate::http::HttpCodec {}); /*let mut skip_recv = false; @@ -220,39 +251,67 @@ pub async fn play( match direction { Direction::ClientToServer => { println!("[CLT] ({id}) >> {}", data.len()); - stream.get_mut().write_all(data).await.unwrap(); + //stream.get_mut().write_all(data).await.unwrap(); + match tokio::time::timeout( + std::time::Duration::from_millis(1000), + stream.get_mut().write_all(data), + ) + .await + { + Ok(v) => v.unwrap(), + Err(_e) => continue 'repeat, + } } Direction::ServerToClient => { println!("[CLT] ({id}) << {}", data.len()); //let mut buf = Vec::new(); //stream.read_buf(&mut buf).await.ok(); //let mut buf = vec![0; data.len().saturating_sub(50).max(1)]; - let resp = tokio::time::timeout( - std::time::Duration::from_millis(500), + match tokio::time::timeout( + std::time::Duration::from_millis(1000), stream.next(), ) .await - .unwrap() - .unwrap() - .unwrap(); + { + Ok(v) => v.unwrap().unwrap(), + Err(_e) => { + // TODO fix + break 'repeat; + } + }; //let resp = stream.next().await.unwrap().unwrap(); - dbg!(resp.len()); + //dbg!(resp.len()); //crate::http::decode_http(&mut buf, &mut stream).await; //buf.clear(); } } } - dbg!(); - stream.get_mut().shutdown().await.unwrap(); + //stream.get_mut().shutdown().await.unwrap(); + tokio::time::timeout( + std::time::Duration::from_millis(1000), + stream.get_mut().shutdown(), + ) + .await + .unwrap() + .unwrap(); let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); println!("Client: {} / {}", cnt + 1, total); } drop(limiter); + let mut running_guard = running.lock().await; + running_guard.remove(id); + drop(running_guard); })); //tokio::time::sleep(std::time::Duration::from_millis(500)).await; } } } + tokio::spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + println!("Running: {:?}", running.lock().await); + } + }); for handle in handles { handle.await.unwrap(); } diff --git a/src/http.rs b/src/http.rs index b868c9e..416499b 100644 --- a/src/http.rs +++ b/src/http.rs @@ -7,9 +7,7 @@ static REGEX_CONTENT_LENGTH: LazyLock = LazyLock::new(|| Regex::new(r#"[cC]ontent-[lL]ength: *(\d+)\r\n"#).unwrap()); pub async fn _decode_http(buf: &mut Vec, stream: &mut R) { - dbg!(); loop { - dbg!(); if let Some(mut end_index) = memchr::memmem::find(buf, b"\r\n\r\n") { end_index += 4; if let Some(captures) = REGEX_CONTENT_LENGTH.captures(buf) { @@ -20,7 +18,6 @@ pub async fn _decode_http(buf: &mut Vec, stream: &m .parse() .unwrap(); while buf.len() < end_index + content_length { - dbg!(); match tokio::time::timeout( std::time::Duration::from_millis(500), stream.read_buf(buf), diff --git a/src/main.rs b/src/main.rs index c70603a..ba21e47 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,3 @@ -#![feature(let_chains)] - mod client; mod http; mod record;