use std::{
    collections::{BTreeMap, btree_map},
    io::{BufReader, BufWriter, ErrorKind, Read, Write},
    sync::mpsc::{Receiver, Sender, channel},
};

/// Tag byte marking a client-to-server record in a record file.
const CLIENT_TO_SERVER: u8 = b'C';
/// Tag byte marking a server-to-client record in a record file.
const SERVER_TO_CLIENT: u8 = b'S';

/// Longest server name representable by the 1-byte length prefix of a record.
const MAX_SERVER_NAME_LEN: usize = u8::MAX as usize;

/// Payload lengths above this are treated as file corruption by `read_record_file`.
const MAX_RECORD_LEN: u64 = 0xfff_ffff;

/// Parsed record file: for each connection id, the raw server name bytes and the list of
/// `(packet index within the connection, direction, payload length)` in file order.
pub type Records = BTreeMap<u64, (Vec<u8>, Vec<(u64, Direction, u64)>)>;

/// Message consumed by [`Recorder`]: `(conn_id, server_name, direction, payload)`.
pub type RecordMessage = (u64, Option<String>, Direction, Vec<u8>);

/// Direction of a recorded payload.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Direction {
    ClientToServer,
    ServerToClient,
}

impl Direction {
    /// On-disk tag byte for this direction.
    fn tag(self) -> u8 {
        match self {
            Direction::ClientToServer => CLIENT_TO_SERVER,
            Direction::ServerToClient => SERVER_TO_CLIENT,
        }
    }

    /// Parses an on-disk tag byte; returns `None` for an unknown tag.
    fn from_tag(tag: u8) -> Option<Self> {
        match tag {
            CLIENT_TO_SERVER => Some(Direction::ClientToServer),
            SERVER_TO_CLIENT => Some(Direction::ServerToClient),
            _ => None,
        }
    }
}

/// Synthetic traffic used by `make_test_record`: `(conn_id, server_name, direction, payload)`.
static TEST_RECORD: &[(u64, &str, Direction, &[u8])] = &[
    (0, "upload.wikimedia.org", Direction::ClientToServer, b"GET /aaaaaaaaa HTTP/1.1\r\nHost: upload.wikimedia.org\r\n\r\n"),
    (0, "upload.wikimedia.org", Direction::ServerToClient, b"HTTP/1.1 200\r\nContent-Length: 5\r\nDate: Wed, 12 Nov 2025 13:52:58 GMT\r\n\r\nhello"),
    (0, "upload.wikimedia.org", Direction::ClientToServer, b"GET /bbbbbbbbb HTTP/1.1\r\nHost: upload.wikimedia.org\r\n\r\n"),
    (0, "upload.wikimedia.org", Direction::ServerToClient, b"HTTP/1.1 200\r\nContent-Length: 7\r\nDate: Wed, 12 Nov 2025 13:52:58 GMT\r\n\r\nbonjour"),
    (1, "upload.wikimedia.org", Direction::ClientToServer, b"GET /ccccccccc HTTP/1.1\r\nHost: upload.wikimedia.org\r\n\r\n"),
    (1, "upload.wikimedia.org", Direction::ServerToClient, b"HTTP/1.1 200\r\nTransfer-Encoding: chunked\r\nDate: Wed, 12 Nov 2025 13:52:58 GMT\r\n\r\n6\r\nbanane\r\n"),
    (1, "upload.wikimedia.org", Direction::ServerToClient, b"5\r\npomme\r\n"),
    (1, "upload.wikimedia.org", Direction::ServerToClient, b"0\r\n\r\n"),
];

/// Writes one record to `out`.
///
/// Layout: direction tag (1 byte), `conn_id` (u64 big-endian), server name length (1 byte),
/// server name bytes, payload length (u64 big-endian). Only the payload *length* is stored.
///
/// Server names longer than 255 bytes are truncated so that the length prefix always
/// matches the bytes that follow; otherwise every later record would be misparsed.
/// The record is assembled in memory and emitted with a single `write_all`.
fn write_record<W: Write>(
    out: &mut W,
    direction: Direction,
    conn_id: u64,
    server_name: &[u8],
    len: u64,
) -> std::io::Result<()> {
    let name_len = server_name.len().min(MAX_SERVER_NAME_LEN);
    let server_name = &server_name[..name_len];

    let mut record = Vec::with_capacity(1 + 8 + 1 + name_len + 8);
    record.push(direction.tag());
    record.extend_from_slice(&conn_id.to_be_bytes());
    record.push(name_len as u8); // cannot truncate: capped at u8::MAX above
    record.extend_from_slice(server_name);
    record.extend_from_slice(&len.to_be_bytes());
    out.write_all(&record)
}

/// Receives traffic notifications over a channel and appends them to a record file.
pub struct Recorder {
    file: std::fs::File,
    receiver: Receiver<RecordMessage>,
}

impl Recorder {
    /// Creates (or truncates) the record file at `path`.
    ///
    /// Returns the recorder together with the sender used to feed it messages.
    ///
    /// # Panics
    /// Panics if the file cannot be opened for writing.
    pub fn new(path: &str) -> (Self, Sender<RecordMessage>) {
        let (sender, receiver) = channel();
        let file = std::fs::OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(path)
            .unwrap_or_else(|err| panic!("cannot open record file {}: {}", path, err));
        (Self { file, receiver }, sender)
    }

    /// Writes one record per received message until every sender has been dropped.
    ///
    /// Messages without a server name are skipped. Only payload lengths are stored.
    ///
    /// # Panics
    /// Panics if writing to the record file fails.
    pub fn run(&mut self) {
        while let Ok((conn_id, server_name, direction, data)) = self.receiver.recv() {
            let Some(server_name) = server_name else {
                continue;
            };
            write_record(
                &mut self.file,
                direction,
                conn_id,
                server_name.as_bytes(),
                data.len() as u64,
            )
            .expect("failed to write record");
            self.file.flush().expect("failed to flush record file");
        }
    }
}

#[cfg(feature = "record")]
#[derive(Clone)]
struct Handler {
    sender: Sender<RecordMessage>,
    server_name: Option<String>,
}

#[cfg(feature = "record")]
impl sslrelay::HandlerCallbacks for Handler {
    // DownStream non blocking callback
    fn ds_nb_callback(&self, in_data: Vec<u8>, conn_id: u64) {
        self.sender
            .send((
                conn_id,
                self.server_name.clone(),
                Direction::ClientToServer,
                in_data,
            ))
            .unwrap();
    }

    // DownStream blocking callback
    fn ds_b_callback(&mut self, in_data: Vec<u8>, _conn_id: u64) -> sslrelay::CallbackRet {
        sslrelay::CallbackRet::Relay(in_data)
    }

    // UpStream non blocking callback
    fn us_nb_callback(&self, in_data: Vec<u8>, conn_id: u64) {
        self.sender
            .send((
                conn_id,
                self.server_name.clone(),
                Direction::ServerToClient,
                in_data,
            ))
            .unwrap();
    }

    // UpStream blocking callback
    fn us_b_callback(&mut self, in_data: Vec<u8>, _conn_id: u64) -> sslrelay::CallbackRet {
        sslrelay::CallbackRet::Relay(in_data)
    }

    fn set_server_name(&mut self, server_name: Option<&str>) {
        self.server_name = server_name.map(str::to_string);
    }
}

/// Runs a TLS relay on 127.0.0.1:443 and records traffic lengths into `path`.
#[cfg(feature = "record")]
pub fn make_record(path: &str) {
    let (mut recorder, sender) = Recorder::new(path);
    let mut relay = sslrelay::SSLRelay::new(
        Handler {
            sender,
            server_name: None,
        },
        sslrelay::RelayConfig {
            downstream_data_type: sslrelay::TCPDataType::TLS,
            upstream_data_type: sslrelay::TCPDataType::TLS,
            bind_host: "127.0.0.1".to_string(),
            bind_port: "443".to_string(),
            remote_host: |server_name| {
                server_name
                    .map(str::to_string)
                    .unwrap_or_else(|| String::from("www.apple.com"))
            },
            remote_port: "443".to_string(),
            tls_config: sslrelay::TLSConfig::FILE {
                certificate_path: "/dev/shm/exp/certs/prime256v1/all.crt".to_string(),
                private_key_path: "/dev/shm/exp/certs/prime256v1/all.key".to_string(),
            },
        },
    );
    std::thread::spawn(move || recorder.run());
    relay.start();
}

/// Fills `buf` completely from `reader`.
///
/// Returns `false` if the input ends before `buf` is full. Unlike a single `Read::read`,
/// this does not mistake a legal short read for end of file.
///
/// # Panics
/// Panics on any I/O error other than end of file.
fn read_field<R: Read>(reader: &mut R, buf: &mut [u8]) -> bool {
    match reader.read_exact(buf) {
        Ok(()) => true,
        Err(err) if err.kind() == ErrorKind::UnexpectedEof => false,
        Err(err) => panic!("failed to read record file: {}", err),
    }
}

/// Parses the record file at `path` into per-connection packet lists.
///
/// Parsing stops, keeping everything read so far, at the end of the file, or at the first
/// malformed or truncated record; in the latter case an error line is printed.
/// The server name of a connection is taken from its first record.
///
/// # Panics
/// Panics if the file cannot be opened or an I/O error other than EOF occurs.
pub fn read_record_file(path: &str) -> Records {
    let file = std::fs::File::open(path)
        .unwrap_or_else(|err| panic!("cannot open record file {}: {}", path, err));
    let mut reader = BufReader::new(file);
    let mut records = Records::new();
    loop {
        let mut tag = [0u8; 1];
        if !read_field(&mut reader, &mut tag) {
            break; // clean end of file on a record boundary
        }
        let Some(direction) = Direction::from_tag(tag[0]) else {
            println!("Error: invalid direction. stop.");
            break;
        };

        let mut conn_id = [0u8; 8];
        if !read_field(&mut reader, &mut conn_id) {
            println!("Error: incomplete conn id. stop.");
            break;
        }
        let conn_id = u64::from_be_bytes(conn_id);

        let mut server_name_len = [0u8; 1];
        if !read_field(&mut reader, &mut server_name_len) {
            println!("Error: incomplete server name len. stop.");
            break;
        }
        let mut server_name = vec![0u8; usize::from(server_name_len[0])];
        if !read_field(&mut reader, &mut server_name) {
            println!("Error: incomplete data. stop.");
            break;
        }

        let mut len = [0u8; 8];
        if !read_field(&mut reader, &mut len) {
            println!("Error: incomplete len. stop.");
            break;
        }
        let len = u64::from_be_bytes(len);
        if len > MAX_RECORD_LEN {
            println!("Error: len too large {len}. stop.");
            break;
        }

        let packets = match records.entry(conn_id) {
            btree_map::Entry::Occupied(entry) => &mut entry.into_mut().1,
            btree_map::Entry::Vacant(entry) => &mut entry.insert((server_name, Vec::new())).1,
        };
        let req_id = packets.len() as u64;
        packets.push((req_id, direction, len));
    }
    records
}

/// Rounds `n` to the nearest multiple of `1 / prec` (e.g. `prec = 100.0` keeps two decimals).
fn round(n: f64, prec: f64) -> f64 {
    (n * prec).round() / prec
}

/// Formats a byte count with a binary unit suffix (`k`, `M`, `G`), rounded to two decimals.
///
/// Counts below 1024 are printed as plain integers.
fn format_size(bytes: u64) -> String {
    const KIB: u64 = 1 << 10;
    const MIB: u64 = 1 << 20;
    const GIB: u64 = 1 << 30;

    if bytes < KIB {
        return bytes.to_string();
    }
    let (unit, suffix) = if bytes < MIB {
        (KIB, "k")
    } else if bytes < GIB {
        (MIB, "M")
    } else {
        (GIB, "G")
    };
    // Divide in floating point first so the value is rounded, not truncated, and so
    // huge counts cannot overflow as `bytes * 100` would.
    format!("{}{}", round(bytes as f64 / unit as f64, 100.0), suffix)
}

/// Prints every connection (or only connection `number`, when given) with its packets,
/// followed by total client-to-server and server-to-client byte counts.
///
/// Server names that are not valid UTF-8 are printed lossily rather than panicking.
pub fn print_records(records: &Records, number: Option<u64>) {
    let mut total_c2s: u64 = 0;
    let mut total_s2c: u64 = 0;
    for (conn_id, (server_name, packets)) in records {
        if matches!(number, Some(n) if n != *conn_id) {
            continue;
        }
        println!("{conn_id} {}", String::from_utf8_lossy(server_name));
        for &(req_id, direction, len) in packets {
            match direction {
                Direction::ClientToServer => {
                    println!(" ({req_id}) >> {len}");
                    total_c2s += len;
                }
                Direction::ServerToClient => {
                    println!(" ({req_id}) << {len}");
                    total_s2c += len;
                }
            }
        }
    }
    println!("Total:");
    println!(" >> {}", format_size(total_c2s));
    println!(" << {}", format_size(total_s2c));
}

/// Writes the synthetic `TEST_RECORD` traffic to a fresh record file at `path`.
///
/// # Panics
/// Panics if the file cannot be created or written.
pub fn make_test_record(path: &str) {
    let file = std::fs::OpenOptions::new()
        .write(true)
        .create(true)
        .truncate(true)
        .open(path)
        .unwrap_or_else(|err| panic!("cannot create record file {}: {}", path, err));
    let mut out = BufWriter::new(file);
    for &(conn_id, server_name, direction, data) in TEST_RECORD {
        write_record(
            &mut out,
            direction,
            conn_id,
            server_name.as_bytes(),
            data.len() as u64,
        )
        .expect("failed to write test record");
    }
    // Flush explicitly: BufWriter's Drop would silently swallow a write error.
    out.flush().expect("failed to flush test record");
}

/// Copies the record file at `input_path` to `output_path`, dropping the packet at
/// index `packet_to_remove` of connection `record_to_remove`.
///
/// The input is parsed fully before the output is opened, so both paths may be the same.
/// Records in the output are grouped by connection id, in ascending order. Server name
/// bytes are copied verbatim, even if they are not valid UTF-8.
///
/// # Panics
/// Panics if either file cannot be opened or written.
pub fn remove_record(
    input_path: &str,
    output_path: &str,
    record_to_remove: u64,
    packet_to_remove: usize,
) {
    let records = read_record_file(input_path);
    let file = std::fs::OpenOptions::new()
        .write(true)
        .create(true)
        .truncate(true)
        .open(output_path)
        .unwrap_or_else(|err| panic!("cannot create record file {}: {}", output_path, err));
    let mut out = BufWriter::new(file);
    for (conn_id, (server_name, packets)) in records {
        for (packet_id, (_req_id, direction, len)) in packets.into_iter().enumerate() {
            if conn_id == record_to_remove && packet_id == packet_to_remove {
                continue;
            }
            write_record(&mut out, direction, conn_id, &server_name, len)
                .expect("failed to write record");
        }
    }
    out.flush().expect("failed to flush record file");
}