use crate::{ TlsMode, record::{Direction, Records}, }; use futures_util::StreamExt; use std::{ collections::HashSet, net::ToSocketAddrs, sync::{Arc, atomic::AtomicU32}, }; use tokio::{ io::AsyncWriteExt, net::TcpStream, sync::{Mutex, Semaphore, oneshot}, }; use tokio_rustls::{ TlsConnector, rustls::{ SignatureScheme, client::danger::{HandshakeSignatureValid, ServerCertVerifier}, pki_types::ServerName, }, }; use tokio_util::codec::Framed; #[derive(Debug)] struct DummyCertVerifier; impl ServerCertVerifier for DummyCertVerifier { fn verify_server_cert( &self, _end_entity: &tokio_rustls::rustls::pki_types::CertificateDer<'_>, _intermediates: &[tokio_rustls::rustls::pki_types::CertificateDer<'_>], _server_name: &tokio_rustls::rustls::pki_types::ServerName<'_>, _ocsp_response: &[u8], _now: tokio_rustls::rustls::pki_types::UnixTime, ) -> Result { Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion()) } fn supported_verify_schemes(&self) -> Vec { vec![ SignatureScheme::RSA_PKCS1_SHA1, SignatureScheme::ECDSA_SHA1_Legacy, SignatureScheme::RSA_PKCS1_SHA256, SignatureScheme::ECDSA_NISTP256_SHA256, SignatureScheme::RSA_PKCS1_SHA384, SignatureScheme::ECDSA_NISTP384_SHA384, SignatureScheme::RSA_PKCS1_SHA512, SignatureScheme::ECDSA_NISTP521_SHA512, SignatureScheme::RSA_PSS_SHA256, SignatureScheme::RSA_PSS_SHA384, SignatureScheme::RSA_PSS_SHA512, SignatureScheme::ED25519, SignatureScheme::ED448, SignatureScheme::ML_DSA_44, SignatureScheme::ML_DSA_65, SignatureScheme::ML_DSA_87, ] } fn verify_tls12_signature( &self, _message: &[u8], _cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>, _dss: &tokio_rustls::rustls::DigitallySignedStruct, ) -> Result< tokio_rustls::rustls::client::danger::HandshakeSignatureValid, tokio_rustls::rustls::Error, > { Ok(HandshakeSignatureValid::assertion()) } fn verify_tls13_signature( &self, _message: &[u8], _cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>, _dss: &tokio_rustls::rustls::DigitallySignedStruct, ) -> Result { Ok(HandshakeSignatureValid::assertion()) } } pub async fn play( records: &'static Records, tls_mode: TlsMode, connect_to: (String, u16), sync_receiver: oneshot::Receiver<()>, repeat: u32, ) { sync_receiver.await.unwrap(); // Semaphore used to limit the number of concurrent clients. // Its handle is released when the task panics. let limiter = Arc::new(Semaphore::new(32)); let counter = Arc::new(AtomicU32::new(0)); let running = Arc::new(Mutex::new(HashSet::new())); let total = records.len() * repeat as usize; let mut handles = Vec::new(); let connect_to = connect_to.to_socket_addrs().unwrap().next().unwrap(); match tls_mode { TlsMode::Both | TlsMode::Client => { let config = Arc::new( tokio_rustls::rustls::ClientConfig::builder() .dangerous() .with_custom_certificate_verifier(Arc::new(DummyCertVerifier)) .with_no_client_auth(), ); for (id, (server_name, records)) in records.iter() { let connector = TlsConnector::from(config.clone()); let counter = counter.clone(); let limiter = limiter.clone(); let running = running.clone(); handles.push(tokio::spawn(async move { let mut running_guard = running.lock().await; running_guard.insert(*id); drop(running_guard); let limiter = limiter.acquire().await.unwrap(); let server_name = ServerName::try_from(String::from_utf8(server_name.clone()).unwrap()) .unwrap(); 'repeat: for _i in 0..repeat { let stream = TcpStream::connect(connect_to).await.unwrap(); let stream = connector .connect(server_name.clone(), stream) .await .unwrap(); let mut stream = Framed::new(stream, crate::http::HttpCodec {}); for (direction, data) in records { match direction { Direction::ClientToServer => { println!("[CLT] ({id}) >> {}", data.len()); //stream.get_mut().write_all(data).await.unwrap(); match tokio::time::timeout( std::time::Duration::from_millis(1000), stream.get_mut().write_all(data), ) .await { Ok(v) => v.unwrap(), Err(_e) => continue 'repeat, } } Direction::ServerToClient => { println!("[CLT] ({id}) << {}", data.len()); // let mut buf = Vec::new(); // stream.read_buf(&mut buf).await.ok(); //let mut buf = vec![0; data.len().saturating_sub(50).max(1)]; //let resp = stream.next().await.unwrap().unwrap(); match tokio::time::timeout( std::time::Duration::from_millis(1000), stream.next(), ) .await { Ok(v) => v.unwrap().unwrap(), Err(_e) => { // TODO fix break 'repeat; } }; //dbg!(resp.len()); //crate::http::decode_http(&mut buf, &mut stream).await; } } } //stream.get_mut().shutdown().await.unwrap(); tokio::time::timeout( std::time::Duration::from_millis(1000), stream.get_mut().shutdown(), ) .await .unwrap() .unwrap(); let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); println!("Client: {} / {}", cnt + 1, total); } drop(limiter); let mut running_guard = running.lock().await; running_guard.remove(id); drop(running_guard); })); //tokio::time::sleep(std::time::Duration::from_millis(500)).await; } } TlsMode::None | TlsMode::Server => { for (id, (_server_name, records)) in records.iter() { /*if *id != 33 { continue }*/ let counter = counter.clone(); let limiter = limiter.clone(); let running = running.clone(); handles.push(tokio::spawn(async move { let mut running_guard = running.lock().await; running_guard.insert(*id); drop(running_guard); let limiter = limiter.acquire().await.unwrap(); //let mut buf = Vec::new(); 'repeat: for _i in 0..repeat { let stream = TcpStream::connect(connect_to).await.unwrap(); let mut stream = Framed::new(stream, crate::http::HttpCodec {}); /*let mut skip_recv = false; for (direction, data) in records { match direction { Direction::ClientToServer => { skip_recv = false; println!("[CLT] ({id}) >> {}", data.len()); stream.write_all(data).await.unwrap(); } Direction::ServerToClient => { if skip_recv { continue; } println!("[CLT] ({id}) << {}", data.len()); //let mut buf = Vec::new(); //stream.read_buf(&mut buf).await.ok(); //let mut buf = vec![0; data.len().saturating_sub(50).max(1)]; let mut buf = vec![0; data.len()]; match tokio::time::timeout( std::time::Duration::from_millis(500), stream.readable(), ) .await { Ok(r) => { r.unwrap(); } Err(_) => { println!("[CLT] timeout recv ({id})"); break; } } // TODO utiliser crate::http ici match tokio::time::timeout( std::time::Duration::from_millis(500), stream.read_exact(&mut buf), ) .await { Ok(r) => { r.unwrap(); } Err(_) => { println!("[CLT] skip recv ({id})"); skip_recv = true; } } } } }*/ for (direction, data) in records { match direction { Direction::ClientToServer => { println!("[CLT] ({id}) >> {}", data.len()); //stream.get_mut().write_all(data).await.unwrap(); match tokio::time::timeout( std::time::Duration::from_millis(1000), stream.get_mut().write_all(data), ) .await { Ok(v) => v.unwrap(), Err(_e) => continue 'repeat, } } Direction::ServerToClient => { println!("[CLT] ({id}) << {}", data.len()); //let mut buf = Vec::new(); //stream.read_buf(&mut buf).await.ok(); //let mut buf = vec![0; data.len().saturating_sub(50).max(1)]; match tokio::time::timeout( std::time::Duration::from_millis(1000), stream.next(), ) .await { Ok(v) => v.unwrap().unwrap(), Err(_e) => { // TODO fix break 'repeat; } }; //let resp = stream.next().await.unwrap().unwrap(); //dbg!(resp.len()); //crate::http::decode_http(&mut buf, &mut stream).await; //buf.clear(); } } } //stream.get_mut().shutdown().await.unwrap(); tokio::time::timeout( std::time::Duration::from_millis(1000), stream.get_mut().shutdown(), ) .await .unwrap() .unwrap(); let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); println!("Client: {} / {}", cnt + 1, total); } drop(limiter); let mut running_guard = running.lock().await; running_guard.remove(id); drop(running_guard); })); //tokio::time::sleep(std::time::Duration::from_millis(500)).await; } } } tokio::spawn(async move { loop { tokio::time::sleep(std::time::Duration::from_secs(1)).await; println!("Running: {:?}", running.lock().await); } }); for handle in handles { handle.await.unwrap(); } std::process::exit(0); }