Initial commit

This commit is contained in:
Pascal Engélibert 2025-10-30 13:57:15 +01:00
commit feb1ec51c8
11 changed files with 2559 additions and 0 deletions

149
src/client.rs Normal file
View file

@ -0,0 +1,149 @@
use crate::{
TlsMode,
record::{Direction, Records},
};
use std::{net::ToSocketAddrs, sync::Arc};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpStream,
sync::oneshot,
};
use tokio_rustls::{
TlsConnector,
rustls::{
SignatureScheme,
client::danger::{HandshakeSignatureValid, ServerCertVerifier},
pki_types::ServerName,
},
};
#[derive(Debug)]
struct DummyCertVerifier;
impl ServerCertVerifier for DummyCertVerifier {
fn verify_server_cert(
&self,
_end_entity: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
_intermediates: &[tokio_rustls::rustls::pki_types::CertificateDer<'_>],
_server_name: &tokio_rustls::rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: tokio_rustls::rustls::pki_types::UnixTime,
) -> Result<tokio_rustls::rustls::client::danger::ServerCertVerified, tokio_rustls::rustls::Error>
{
Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
}
fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
vec![
SignatureScheme::RSA_PKCS1_SHA1,
SignatureScheme::ECDSA_SHA1_Legacy,
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
SignatureScheme::ECDSA_NISTP521_SHA512,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::ED25519,
SignatureScheme::ED448,
SignatureScheme::ML_DSA_44,
SignatureScheme::ML_DSA_65,
SignatureScheme::ML_DSA_87,
]
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
_dss: &tokio_rustls::rustls::DigitallySignedStruct,
) -> Result<
tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
tokio_rustls::rustls::Error,
> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
_dss: &tokio_rustls::rustls::DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, tokio_rustls::rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
}
pub async fn play(
records: &'static Records,
tls_mode: TlsMode,
connect_to: (String, u16),
sync_receiver: oneshot::Receiver<()>,
repeat: u32,
) {
sync_receiver.await.unwrap();
let mut handles = Vec::new();
let connect_to = connect_to.to_socket_addrs().unwrap().next().unwrap();
match tls_mode {
TlsMode::Both | TlsMode::Client => {
let config = Arc::new(
tokio_rustls::rustls::ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(Arc::new(DummyCertVerifier))
.with_no_client_auth(),
);
for (_id, (server_name, records)) in records.iter() {
let connector = TlsConnector::from(config.clone());
handles.push(tokio::spawn(async move {
let server_name =
ServerName::try_from(String::from_utf8(server_name.clone()).unwrap())
.unwrap();
for _i in 0..repeat {
let stream = TcpStream::connect(connect_to).await.unwrap();
let mut stream = connector
.connect(server_name.clone(), stream)
.await
.unwrap();
for (direction, data) in records {
match direction {
Direction::ClientToServer => {
stream.write_all(data).await.unwrap();
}
Direction::ServerToClient => {
let mut buf = Vec::new();
stream.read_buf(&mut buf).await.ok();
}
}
}
stream.shutdown().await.unwrap();
}
}));
}
}
TlsMode::None | TlsMode::Server => {
for (_id, (_server_name, records)) in records.iter() {
handles.push(tokio::spawn(async move {
for _i in 0..repeat {
let mut stream = TcpStream::connect(connect_to).await.unwrap();
for (direction, data) in records {
match direction {
Direction::ClientToServer => {
stream.write_all(data).await.unwrap();
}
Direction::ServerToClient => {
let mut buf = Vec::new();
stream.read_buf(&mut buf).await.ok();
}
}
}
stream.shutdown().await.unwrap();
}
}));
}
}
}
for handle in handles {
handle.await.unwrap();
}
//std::process::exit(0);
}

122
src/main.rs Normal file
View file

@ -0,0 +1,122 @@
#![feature(let_chains)]
mod client;
mod record;
mod server;
use record::Records;
use argp::FromArgs;
use static_cell::StaticCell;
use tokio::sync::oneshot;
/// Play recorded requests and responses
#[derive(FromArgs)]
struct Opt {
/// Path to record file
#[argp(positional)]
record_file: String,
#[argp(subcommand)]
subcommand: Subcommand,
}
#[derive(FromArgs)]
#[argp(subcommand)]
enum Subcommand {
/// Replay from records
Play(OptPlay),
/// Print records
Print(OptPrint),
/// Record traffic
Record(OptRecord),
}
/// Replay from records
#[derive(FromArgs)]
#[argp(subcommand, name = "play")]
struct OptPlay {
/// Connect to address
#[argp(positional)]
forward_addr: String,
/// Connect to port
#[argp(positional)]
forward_port: u16,
/// Listen to port
#[argp(positional)]
listen_port: u16,
/// Path to PEM certificates and keys
#[argp(positional)]
certs: String,
/// Where to use TLS
#[argp(positional)]
tls: String,
/// Repeat N times
#[argp(option, short = 'r', default = "1")]
repeat: u32,
}
/// Print records
#[derive(FromArgs)]
#[argp(subcommand, name = "print")]
struct OptPrint {
/// Print packets
#[argp(switch, short = 'p')]
packets: bool,
}
/// Record traffic
#[derive(FromArgs)]
#[argp(subcommand, name = "record")]
struct OptRecord {}
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TlsMode {
None,
Client,
Server,
Both,
}
static RECORDS: StaticCell<Records> = StaticCell::new();
#[tokio::main]
async fn main() {
let opt: Opt = argp::parse_args_or_exit(argp::DEFAULT);
match opt.subcommand {
Subcommand::Play(subopt) => {
let tls_mode = match subopt.tls.as_str() {
"none" => TlsMode::None,
"client" => TlsMode::Client,
"server" => TlsMode::Server,
"both" => TlsMode::Both,
_ => panic!("TLS mode must be one of none,client,server,both."),
};
let records = RECORDS.init(record::read_record_file(&opt.record_file));
let (sync_sender, sync_receiver) = oneshot::channel();
let client = tokio::spawn(client::play(
records,
tls_mode,
(subopt.forward_addr, subopt.forward_port),
sync_receiver,
subopt.repeat,
));
server::play(
records,
tls_mode,
&subopt.certs,
("0.0.0.0", subopt.listen_port),
sync_sender,
)
.await;
client.await.unwrap();
}
Subcommand::Print(subopt) => {
let records = record::read_record_file(&opt.record_file);
record::print_records(&records, subopt.packets);
}
Subcommand::Record(_subopt) => {
record::make_record(&opt.record_file);
}
}
}

228
src/record.rs Normal file
View file

@ -0,0 +1,228 @@
use std::{
collections::{BTreeMap, btree_map},
io::{Read, Write},
sync::mpsc::{Receiver, Sender, channel},
};
const CLIENT_TO_SERVER: u8 = b'C';
const SERVER_TO_CLIENT: u8 = b'S';
pub type Records = BTreeMap<u64, (Vec<u8>, Vec<(Direction, Vec<u8>)>)>;
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Direction {
ClientToServer,
ServerToClient,
}
pub struct Recorder {
file: std::fs::File,
receiver: Receiver<(u64, Option<String>, Direction, Vec<u8>)>,
}
impl Recorder {
#[allow(clippy::type_complexity)]
pub fn new(path: &str) -> (Self, Sender<(u64, Option<String>, Direction, Vec<u8>)>) {
let (sender, receiver) = channel();
(
Self {
file: std::fs::OpenOptions::new()
.create(true)
.write(true)
.truncate(true)
.open(path)
.unwrap(),
receiver,
},
sender,
)
}
pub fn run(&mut self) {
while let Ok((conn_id, server_name, direction, data)) = self.receiver.recv() {
let Some(server_name) = server_name else {
continue;
};
let server_name = server_name.as_bytes();
self.file
.write_all(&[match direction {
Direction::ClientToServer => CLIENT_TO_SERVER,
Direction::ServerToClient => SERVER_TO_CLIENT,
}])
.unwrap();
self.file.write_all(&conn_id.to_be_bytes()).unwrap();
self.file.write_all(&[server_name.len() as u8]).unwrap();
self.file.write_all(server_name).unwrap();
self.file
.write_all(&(data.len() as u64).to_be_bytes())
.unwrap();
self.file.write_all(&data).unwrap();
self.file.flush().unwrap();
}
}
}
#[derive(Clone)]
struct Handler {
sender: Sender<(u64, Option<String>, Direction, Vec<u8>)>,
server_name: Option<String>,
}
impl sslrelay::HandlerCallbacks for Handler {
// DownStream non blocking callback
fn ds_nb_callback(&self, in_data: Vec<u8>, conn_id: u64) {
self.sender
.send((
conn_id,
self.server_name.clone(),
Direction::ClientToServer,
in_data,
))
.unwrap();
}
// DownStream blocking callback
fn ds_b_callback(&mut self, in_data: Vec<u8>, _conn_id: u64) -> sslrelay::CallbackRet {
sslrelay::CallbackRet::Relay(in_data)
}
// UpStream non blocking callback
fn us_nb_callback(&self, in_data: Vec<u8>, conn_id: u64) {
self.sender
.send((
conn_id,
self.server_name.clone(),
Direction::ServerToClient,
in_data,
))
.unwrap();
}
// UpStream blocking callback
fn us_b_callback(&mut self, in_data: Vec<u8>, _conn_id: u64) -> sslrelay::CallbackRet {
sslrelay::CallbackRet::Relay(in_data)
}
fn set_server_name(&mut self, server_name: Option<&str>) {
self.server_name = server_name.map(str::to_string);
}
}
pub fn make_record(path: &str) {
let (mut recorder, sender) = Recorder::new(path);
let mut relay = sslrelay::SSLRelay::new(
Handler {
sender,
server_name: None,
},
sslrelay::RelayConfig {
downstream_data_type: sslrelay::TCPDataType::TLS,
upstream_data_type: sslrelay::TCPDataType::TLS,
bind_host: "127.0.0.1".to_string(),
bind_port: "443".to_string(),
remote_host: |server_name| {
server_name
.map(str::to_string)
.unwrap_or_else(|| String::from("www.apple.com"))
},
remote_port: "443".to_string(),
tls_config: sslrelay::TLSConfig::FILE {
certificate_path: "/dev/shm/exp/certs/prime256v1/all.crt".to_string(),
private_key_path: "/dev/shm/exp/certs/prime256v1/all.key".to_string(),
},
},
);
std::thread::spawn(move || recorder.run());
relay.start();
}
pub fn read_record_file(path: &str) -> Records {
let mut file = std::fs::OpenOptions::new().read(true).open(path).unwrap();
let mut records = BTreeMap::<u64, (Vec<u8>, Vec<(Direction, Vec<u8>)>)>::new();
loop {
let mut direction = [0; 1];
if file.read(&mut direction).unwrap() != 1 {
break;
}
let direction = match direction[0] {
CLIENT_TO_SERVER => Direction::ClientToServer,
SERVER_TO_CLIENT => Direction::ServerToClient,
_ => {
println!("Error: invalid direction. stop.");
break;
}
};
let mut conn_id = [0; 8];
if file.read(&mut conn_id).unwrap() != 8 {
println!("Error: incomplete conn id. stop.");
break;
}
let conn_id = u64::from_be_bytes(conn_id);
let mut server_name_len = [0];
if file.read(&mut server_name_len).unwrap() != 1 {
println!("Error: incomplete server name len. stop.");
break;
}
let server_name_len = server_name_len[0] as usize;
let mut server_name = vec![0; server_name_len];
if file.read(&mut server_name).unwrap() != server_name_len {
println!("Error: incomplete data. stop.");
break;
}
let mut len = [0; 8];
if file.read(&mut len).unwrap() != 8 {
println!("Error: incomplete len. stop.");
break;
}
let len = u64::from_be_bytes(len);
if len > 0xfff_ffff {
println!("Error: len too large {len}. stop.");
break;
}
let mut buf = vec![0; len as usize];
if file.read(&mut buf).unwrap() != len as usize {
println!("Error: incomplete data. stop.");
break;
}
match records.entry(conn_id) {
btree_map::Entry::Occupied(mut entry) => {
entry.get_mut().1.push((direction, buf));
}
btree_map::Entry::Vacant(entry) => {
entry.insert((server_name, vec![(direction, buf)]));
}
}
}
records
}
pub fn print_records(records: &Records, print_packets: bool) {
for (id, (server_name, records)) in records {
let server_name = str::from_utf8(server_name.as_slice()).unwrap();
println!("{id} {server_name}");
for (direction, data) in records {
match direction {
Direction::ClientToServer => {
println!(" >> {}", data.len());
}
Direction::ServerToClient => {
println!(" << {}", data.len());
}
}
if print_packets {
let data = if data.len() >= 256 {
&data[0..256]
} else {
data.as_slice()
};
if let Ok(data) = str::from_utf8(data) {
println!(" {data:?}")
} else {
println!(" {data:?}")
}
}
}
}
}

220
src/server.rs Normal file
View file

@ -0,0 +1,220 @@
use crate::{
TlsMode,
record::{Direction, Records},
};
use std::{collections::HashMap, sync::Arc};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpListener,
sync::oneshot,
};
use tokio_rustls::rustls::{
pki_types::{
CertificateDer, PrivateKeyDer,
pem::{self, PemObject, SectionKind},
},
server::ResolvesServerCertUsingSni,
sign::CertifiedKey,
};
use x509_parser::prelude::GeneralName;
pub async fn play(
records: &'static Records,
tls_mode: TlsMode,
cert_path: &str,
listen_addr: (&str, u16),
sync_sender: oneshot::Sender<()>,
) {
let mut response_map = HashMap::new();
for (_id, (server_name, records)) in records.iter() {
let mut hash = None;
let mut responses = Vec::new();
for (direction, data) in records {
match direction {
Direction::ClientToServer => {
if let Some(hash) = hash
&& !responses.is_empty()
{
response_map.insert((server_name.to_vec(), hash), responses);
responses = Vec::new();
}
hash = Some(
tlsh::hash_buf(data)
.map_or_else(|_| data.clone(), |h| h.to_string().into_bytes()),
);
}
Direction::ServerToClient => {
responses.push(data);
}
}
}
if let Some(hash) = hash
&& !responses.is_empty()
{
response_map.insert((server_name.to_vec(), hash), responses);
}
}
match tls_mode {
TlsMode::Both | TlsMode::Server => {
let mut resolver = ResolvesServerCertUsingSni::new();
let config = tokio_rustls::rustls::ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(ResolvesServerCertUsingSni::new()));
for file in std::fs::read_dir(cert_path).unwrap() {
match file {
Ok(file) => {
if file.file_name().as_encoded_bytes().ends_with(b".crt") {
for section in
<(pem::SectionKind, Vec<u8>) as PemObject>::pem_file_iter(
file.path(),
)
.unwrap()
{
let (kind, data) = section.unwrap();
if kind == SectionKind::Certificate {
let (_rem, cert) =
x509_parser::parse_x509_certificate(&data).unwrap();
if !cert.is_ca() {
println!("File: {:?}", file.file_name());
let mut key_path = file.path().to_path_buf();
key_path.pop();
let file_name =
file.file_name().to_str().unwrap().to_string();
let mut key_file_name =
file_name[0..file_name.len() - 4].to_string();
key_file_name.push_str(".key");
let key = PrivateKeyDer::from_pem_file(
key_path.join(key_file_name),
)
.unwrap();
let key = config
.crypto_provider()
.key_provider
.load_private_key(key)
.unwrap();
// This wants static lifetime...
let cert_key = CertifiedKey::new(
vec![CertificateDer::from_slice(Box::leak(
data.to_vec().into_boxed_slice(),
))],
key,
);
for name in cert
.subject_alternative_name()
.unwrap()
.unwrap()
.value
.general_names
.iter()
{
if let GeneralName::DNSName(name) = name {
resolver.add(dbg!(name), cert_key.clone()).unwrap();
}
}
}
}
}
}
}
Err(e) => eprintln!("Error listing cert directory: {e:?}"),
}
}
// Config requires resolver, keys can be added to resolver, creating a key requires config. WTF!?
// So we have to re-create config.
let config = Arc::new(
tokio_rustls::rustls::ServerConfig::builder()
.with_no_client_auth()
.with_cert_resolver(Arc::new(resolver)),
);
//let acceptor = tokio_rustls::TlsAcceptor::from(Arc::new(config));
let listener = TcpListener::bind(listen_addr).await.unwrap();
let response_map = Arc::new(response_map);
sync_sender.send(()).unwrap();
loop {
let config = config.clone();
let (stream, _peer_addr) = listener.accept().await.unwrap();
let acceptor = tokio_rustls::LazyConfigAcceptor::new(
tokio_rustls::rustls::server::Acceptor::default(),
stream,
);
//let acceptor = acceptor.clone();
let response_map = response_map.clone();
let fut = async move {
/*let mut server_name = None;
let mut stream = acceptor
.accept_with(stream, |conn| {
server_name = conn.server_name().map(String::from)
})
.await
.unwrap();
let server_name = server_name.unwrap();*/
let accepted = acceptor.await.unwrap();
let server_name = accepted.client_hello().server_name().unwrap().to_string();
let mut stream = accepted.into_stream(config).await.unwrap();
let mut req = Vec::new();
// TODO if there is a body
while !req.ends_with(b"\r\n\r\n") {
if stream.read_buf(&mut req).await.unwrap() == 0 {
break;
}
}
if let Ok(req) = str::from_utf8(&req) {
println!("{req}");
}
let req_hash = tlsh::hash_buf(&req)
.map_or_else(|_| req.clone(), |h| h.to_string().into_bytes());
let mut best = None;
for (i_server_name, hash) in response_map.keys() {
if i_server_name != server_name.as_bytes() {
continue;
}
let diff = compare(&req_hash, hash);
if let Some((best_hash, best_diff)) = &mut best {
if diff < *best_diff {
*best_hash = hash;
*best_diff = diff;
}
} else {
best = Some((hash, diff));
}
}
if let Some((hash, _diff)) = best {
let responses = response_map
.get(&(server_name.as_bytes().to_vec(), hash.clone()))
.unwrap();
for &res in responses {
stream.write_all(res).await.unwrap();
stream.flush().await.unwrap();
}
} else {
eprintln!("No response found for SNI=`{server_name}`");
}
stream.shutdown().await.unwrap();
};
tokio::spawn(async move {
fut.await;
});
}
}
TlsMode::None | TlsMode::Client => {
// TODO
}
}
}
fn compare(a: &[u8], b: &[u8]) -> u32 {
if let (Ok(a), Ok(b)) = (str::from_utf8(a), str::from_utf8(b)) {
if let Ok(diff) = tlsh::compare(a, b) {
return diff;
}
}
if a == b {
0
} else {
a.len().max(b.len()) as u32
}
}

0
src/util.rs Normal file
View file