From 1d62dae7850b81a63826594cea1caa963face7a0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pascal=20Eng=C3=A9libert?= Date: Fri, 16 Jan 2026 16:45:40 +0100 Subject: [PATCH] Handle chunked HTTP correctly --- src/client.rs | 192 +++++++++++++++--------- src/http.rs | 408 ++++++++++++++++++++++++++++++++++++++++++-------- src/main.rs | 18 ++- src/record.rs | 102 ++++++++++--- src/server.rs | 183 ++++++++++++++-------- src/util.rs | 130 ++++++++++++++++ 6 files changed, 817 insertions(+), 216 deletions(-) diff --git a/src/client.rs b/src/client.rs index 8d80ac6..f19dbf8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,13 +1,15 @@ use crate::{ TlsMode, record::{Direction, Records}, + util::{ResponseStreamer, print_bin}, }; -use futures_util::StreamExt; +use futures_util::{StreamExt, TryStreamExt}; use std::{ collections::HashSet, net::ToSocketAddrs, sync::{Arc, atomic::AtomicU32}, + time::Duration, }; use tokio::{ io::AsyncWriteExt, @@ -27,6 +29,8 @@ use tokio_rustls::{ }; use tokio_util::codec::Framed; +const TIMEOUT: Duration = Duration::from_secs(60); + #[derive(Debug)] struct DummyCertVerifier; @@ -99,6 +103,7 @@ pub async fn play( let total = records.len() * repeat as usize; let mut handles = Vec::new(); let connect_to = connect_to.to_socket_addrs().unwrap().next().unwrap(); + let debug_mutex = Arc::new(Mutex::new(())); match tls_mode { TlsMode::Both | TlsMode::Client => { let mut config = tokio_rustls::rustls::ClientConfig::builder() @@ -137,54 +142,71 @@ pub async fn play( .connect(server_name.clone(), stream) .await .unwrap(); - let mut stream = Framed::new(stream, crate::http::HttpCodec {}); - for (direction, data) in records { + let mut stream = Framed::new(stream, crate::http::HttpClientCodec::new()); + for (direction, data_list) in ResponseStreamer::new(records.iter()) { match direction { Direction::ClientToServer => { - //println!("[CLT] ({id}) >> {}", data.len()); - //stream.get_mut().write_all(data).await.unwrap(); - match tokio::time::timeout( - std::time::Duration::from_millis(1000), - stream.get_mut().write_all(data), - ) - .await - { - Ok(v) => v.unwrap(), - Err(_e) => continue 'repeat, + for data in data_list { + //println!("[CLT] ({id}) >> {}", data.len()); + //stream.get_mut().write_all(data).await.unwrap(); + match tokio::time::timeout( + TIMEOUT, + stream.get_mut().write_all(data), + ) + .await + { + Ok(v) => v.unwrap(), + Err(_e) => { + println!("client timeout {id} (sending)"); + continue 'repeat; + } + } } } Direction::ServerToClient => { + let total_len: usize = + data_list.iter().map(|data| data.len()).sum::(); + let reduced_len = + total_len.saturating_sub(160 * data_list.len()).max(1); + let mut total_recv = 0; //println!("[CLT] ({id}) << {}", data.len()); // let mut buf = Vec::new(); // stream.read_buf(&mut buf).await.ok(); //let mut buf = vec![0; data.len().saturating_sub(50).max(1)]; //let resp = stream.next().await.unwrap().unwrap(); - match tokio::time::timeout( - std::time::Duration::from_millis(1000), - stream.next(), - ) - .await - { - Ok(v) => v.unwrap().unwrap(), - Err(_e) => { - // TODO fix - println!("client timeout {id}"); - break 'repeat; - } - }; - //dbg!(resp.len()); - //crate::http::decode_http(&mut buf, &mut stream).await; + while total_recv < reduced_len { + let resp = match tokio::time::timeout( + TIMEOUT, + stream.next(), + ) + .await + { + Ok(v) => v.unwrap().unwrap(), + Err(_e) => { + // TODO fix + println!( + "client timeout {}: {} / {}", + id, total_recv, total_len + ); + //print_bin(data); + break 'repeat; + } + }; + total_recv += 
resp.len(); + //dbg!(resp.len()); + //crate::http::decode_http(&mut buf, &mut stream).await; + } + /*if total_recv > total_len { + println!("received too much {}: {} / {}", id, total_recv, total_len); + }*/ } } } //stream.get_mut().shutdown().await.unwrap(); - tokio::time::timeout( - std::time::Duration::from_millis(1000), - stream.get_mut().shutdown(), - ) - .await - .unwrap() - .unwrap(); + tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown()) + .await + .unwrap() + .unwrap(); let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); println!("Client: {} / {}", cnt + 1, total); } @@ -204,6 +226,7 @@ pub async fn play( let counter = counter.clone(); let limiter = limiter.clone(); let running = running.clone(); + let debug_mutex = debug_mutex.clone(); handles.push(tokio::spawn(async move { let mut running_guard = running.lock().await; running_guard.insert(*id); @@ -212,7 +235,7 @@ pub async fn play( //let mut buf = Vec::new(); 'repeat: for _i in 0..repeat { let stream = TcpStream::connect(connect_to).await.unwrap(); - let mut stream = Framed::new(stream, crate::http::HttpCodec {}); + let mut stream = Framed::new(stream, crate::http::HttpClientCodec::new()); /*let mut skip_recv = false; for (direction, data) in records { match direction { @@ -262,54 +285,79 @@ pub async fn play( } } }*/ - for (direction, data) in records { + for (direction, data_list) in ResponseStreamer::new(records.iter()) { match direction { Direction::ClientToServer => { - //println!("[CLT] ({id}) >> {}", str::from_utf8(&data[..data.len().min(255)]).unwrap()); - //stream.get_mut().write_all(data).await.unwrap(); - match tokio::time::timeout( - std::time::Duration::from_millis(1000), - stream.get_mut().write_all(data), - ) - .await - { - Ok(v) => v.unwrap(), - Err(_e) => continue 'repeat, + for data in data_list { + //println!("[CLT] ({id}) >> {}", str::from_utf8(&data[..data.len().min(255)]).unwrap()); + //stream.get_mut().write_all(data).await.unwrap(); + match tokio::time::timeout( + TIMEOUT, + stream.get_mut().write_all(data), + ) + .await + { + Ok(v) => v.unwrap(), + Err(_e) => { + println!("client timeout {id} (sending)"); + continue 'repeat; + } + } } } Direction::ServerToClient => { + let total_len: usize = + data_list.iter().map(|data| data.len()).sum::(); + let reduced_len = + total_len.saturating_sub(160 * data_list.len()).max(1); + let mut total_recv = 0; //println!("[CLT] ({id}) << {}", data.len()); //let mut buf = Vec::new(); //stream.read_buf(&mut buf).await.ok(); //let mut buf = vec![0; data.len().saturating_sub(50).max(1)]; - let resp = match tokio::time::timeout( - std::time::Duration::from_millis(1000), - stream.next(), - ) - .await - { - Ok(v) => v.unwrap().unwrap(), - Err(_e) => { - // TODO fix - println!("client timeout {id}"); - break 'repeat; - } - }; - //let resp = stream.next().await.unwrap().unwrap(); - //dbg!(resp.len()); - //crate::http::decode_http(&mut buf, &mut stream).await; - //buf.clear(); + while total_recv < reduced_len { + let resp = match tokio::time::timeout( + TIMEOUT, + stream.next(), + ) + .await + { + Ok(v) => v.unwrap().unwrap(), + Err(_e) => { + // TODO fix + println!( + "client timeout {}: {} / {}", + id, total_recv, total_len + ); + //print_bin(data); + break 'repeat; + } + }; + total_recv += resp.len(); + /*if resp.len() != data.len() { + let guard = debug_mutex.lock().await; + println!("RECV NOT ENOUGH {} / {}", resp.len(), data.len()); + if resp.len() < 1000 && data.len() < 1000 { + //print_bin(&resp); + //println!("WANTED"); + //print_bin(data); + } + 
std::mem::drop(guard); + }*/ + //print_bin(&resp); + //let resp = stream.next().await.unwrap().unwrap(); + //dbg!(resp.len()); + //crate::http::decode_http(&mut buf, &mut stream).await; + //buf.clear(); + } } } } //stream.get_mut().shutdown().await.unwrap(); - tokio::time::timeout( - std::time::Duration::from_millis(1000), - stream.get_mut().shutdown(), - ) - .await - .unwrap() - .unwrap(); + tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown()) + .await + .unwrap() + .unwrap(); let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed); println!("Client: {} / {}", cnt + 1, total); } @@ -327,7 +375,7 @@ pub async fn play( async move { let mut last_count = 0; loop { - tokio::time::sleep(std::time::Duration::from_secs(2)).await; + tokio::time::sleep(TIMEOUT).await; println!("Running: {:?}", running.lock().await); let new_count = counter.load(std::sync::atomic::Ordering::Relaxed); if new_count == last_count { diff --git a/src/http.rs b/src/http.rs index 416499b..00b20f9 100644 --- a/src/http.rs +++ b/src/http.rs @@ -1,80 +1,66 @@ use regex::bytes::Regex; -use std::sync::LazyLock; -use tokio::io::AsyncReadExt; -use tokio_util::codec::{Decoder, Encoder}; +use std::{pin::Pin, sync::LazyLock}; +use tokio::io::{AsyncRead, AsyncReadExt}; +use tokio_util::{ + bytes::BytesMut, + codec::{Decoder, Encoder}, +}; + +use crate::util::{is_hex, parse_hex}; static REGEX_CONTENT_LENGTH: LazyLock = LazyLock::new(|| Regex::new(r#"[cC]ontent-[lL]ength: *(\d+)\r\n"#).unwrap()); +static REGEX_CHUNKED: LazyLock = + LazyLock::new(|| Regex::new(r#"[tT]ransfer-[eE]ncoding: *[cC]hunked\r\n"#).unwrap()); -pub async fn _decode_http(buf: &mut Vec, stream: &mut R) { - loop { - if let Some(mut end_index) = memchr::memmem::find(buf, b"\r\n\r\n") { - end_index += 4; - if let Some(captures) = REGEX_CONTENT_LENGTH.captures(buf) { - if let Some(content_length) = captures.get(1) { - // Read body - let content_length: usize = str::from_utf8(content_length.as_bytes()) - .unwrap() - .parse() - .unwrap(); - while buf.len() < end_index + content_length { - match tokio::time::timeout( - std::time::Duration::from_millis(500), - stream.read_buf(buf), - ) - .await - { - Ok(Ok(_n)) => {} - Ok(Err(e)) => { - println!("[http] error reading: {e:?}"); - break; - } - Err(_e) => { - // timeout - break; - } - } - } - break; - } else { - // Erroneous Content-Type - break; - } - } else { - // Header ended without Content-Type => no body - break; - } - } - match tokio::time::timeout(std::time::Duration::from_millis(500), stream.read_buf(buf)) - .await - { - Ok(Ok(n)) => { - println!("[http] read {n}"); - } - Ok(Err(e)) => { - println!("[http] error reading: {e:?}"); - break; - } - Err(_e) => { - // timeout - break; - } - } +/*pin_project! 
{ + pub struct Framer { + #[pin] + pub stream: S, + codec: C, + buf: BytesMut, } } -pub struct HttpCodec {} +impl Framer { + pub fn new(stream: S, codec: C) -> Self { + Self { + stream, + codec, + buf: BytesMut::new(), + } + } -impl Decoder for HttpCodec { + pub async fn next(&mut self) -> Option> { + self.stream.read_buf(&mut self.buf).await.unwrap(); + None + } +}*/ + +pub struct HttpClientCodec { + buf: Vec, +} + +impl HttpClientCodec { + pub fn new() -> Self { + Self { buf: Vec::new() } + } +} + +impl Decoder for HttpClientCodec { type Item = Vec; type Error = std::io::Error; fn decode( &mut self, src: &mut tokio_util::bytes::BytesMut, ) -> Result, Self::Error> { + self.buf.extend_from_slice(&src); + src.clear(); + let src = &mut self.buf; if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") { end_index += 4; if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) { + // Content-Length: simple body if let Some(content_length) = captures.get(1) { // Read body let content_length: usize = str::from_utf8(content_length.as_bytes()) @@ -82,30 +68,324 @@ impl Decoder for HttpCodec { .parse() .unwrap(); if src.len() >= end_index + content_length { - //dbg!(content_length); + let remaining = src.split_off(end_index + content_length); let out = src.to_vec(); - src.clear(); + *src = remaining; Ok(Some(out)) } else { + //dbg!("Not enough data"); Ok(None) } } else { // Invalid Content-Length Err(std::io::ErrorKind::InvalidData.into()) } + } else if REGEX_CHUNKED.is_match(&src[0..end_index]) { + // Chunked body + let mut content = &src[end_index..]; + let mut total_len = end_index; + loop { + if let Some(len_end_index) = memchr::memmem::find(content, b"\r\n") { + let len_slice = &content[0..len_end_index]; + if len_end_index < 8 && is_hex(len_slice) { + let chunk_len = parse_hex(len_slice) as usize; + if content.len() >= len_end_index + chunk_len + 4 { + total_len += len_end_index + chunk_len + 4; + // Should we check the ending CRLF? + if chunk_len == 0 { + let remaining = src.split_off(total_len); + let out = src.to_vec(); + *src = remaining; + return Ok(Some(out)); + } + // else, wait for the next chunk + content = &content[len_end_index + chunk_len + 4..]; + } else { + // Not enough data + return Ok(None); + } + } else { + // Invalid chunk length + return Err(std::io::ErrorKind::InvalidData.into()); + } + } else { + // Not enough data + return Ok(None); + } + } } else { - // Header ended without Content-Type => no body + // Header ended without Content-Type nor chunks => no body + let remaining = src.split_off(end_index); let out = src.to_vec(); - src.clear(); + *src = remaining; Ok(Some(out)) } } else { + //dbg!("Unfinished header"); Ok(None) } + + /*self.buf.extend_from_slice(&src); + src.clear(); + let src = &mut self.buf; + if self.chunked { + if let Some(len_end_index) = memchr::memmem::find(src, b"\r\n") { + let len_slice = &src[0..len_end_index]; + if len_end_index < 8 && is_hex(len_slice) { + let chunk_len = parse_hex(len_slice) as usize; + if src.len() >= len_end_index + chunk_len + 4 { + // Should we check the ending CRLF? 
+ if chunk_len == 0 { + self.chunked = false; + } + let remaining = src.split_off(len_end_index+chunk_len+4); + let out = src.to_vec(); + *src = remaining; + Ok(Some(out)) + } else { + // Not enough data + Ok(None) + } + } else { + // Invalid chunk length + Err(std::io::ErrorKind::InvalidData.into()) + } + } else { + // Not enough data + Ok(None) + } + } else { + if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") { + end_index += 4; + if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) { + if let Some(content_length) = captures.get(1) { + // Read body + let content_length: usize = str::from_utf8(content_length.as_bytes()) + .unwrap() + .parse() + .unwrap(); + if src.len() >= end_index + content_length { + if REGEX_CHUNKED.is_match(&src[0..end_index]) { + self.chunked = true; + } + //dbg!(content_length); + let remaining = src.split_off(end_index + content_length); + let out = src.to_vec(); + *src = remaining; + Ok(Some(out)) + } else { + //dbg!("Not enough data"); + Ok(None) + } + } else { + // Invalid Content-Length + Err(std::io::ErrorKind::InvalidData.into()) + } + } else { + // Header ended without Content-Type => no body + let remaining = src.split_off(end_index); + let out = src.to_vec(); + *src = remaining; + Ok(Some(out)) + } + } else { + //dbg!("Unfinished header"); + Ok(None) + } + }*/ + + /*if let Some(start_index) = memchr::memmem::find(src, b"HTTP") { + if start_index != 0 { + dbg!(start_index); + if start_index == 529 { + println!("{src:?}"); + } + } + let src2 = &src[start_index..]; + if let Some(mut end_index) = memchr::memmem::find(src2, b"\r\n\r\n") { + end_index += 4; + if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src2) { + if let Some(content_length) = captures.get(1) { + // Read body + let content_length: usize = str::from_utf8(content_length.as_bytes()) + .unwrap() + .parse() + .unwrap(); + if src2.len() >= end_index + content_length { + if src2.len() > end_index + content_length { + dbg!(src2.len(), end_index + content_length); + println!("{src2:?}"); + std::process::exit(1); + } + //dbg!(content_length); + let out = src2.to_vec(); + src.clear(); + Ok(Some(out)) + } else { + //dbg!("Not enough data"); + Ok(None) + } + } else { + // Invalid Content-Length + Err(std::io::ErrorKind::InvalidData.into()) + } + } else { + // Header ended without Content-Type => no body + let out = src2.to_vec(); + src.clear(); + Ok(Some(out)) + } + } else { + //dbg!("Unfinished header"); + Ok(None) + } + } else { + //dbg!("Unstarted header"); + Ok(None) + }*/ } } -impl Encoder> for HttpCodec { +impl Encoder> for HttpClientCodec { + type Error = std::io::Error; + fn encode( + &mut self, + _item: Vec, + _dst: &mut tokio_util::bytes::BytesMut, + ) -> Result<(), Self::Error> { + Ok(()) + } +} + +pub struct HttpServerCodec { + buf: Vec, +} + +impl HttpServerCodec { + pub fn new() -> Self { + Self { buf: Vec::new() } + } +} + +impl Decoder for HttpServerCodec { + type Item = Vec; + type Error = std::io::Error; + fn decode( + &mut self, + src: &mut tokio_util::bytes::BytesMut, + ) -> Result, Self::Error> { + self.buf.extend_from_slice(&src); + src.clear(); + let src = &mut self.buf; + if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") { + end_index += 4; + if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) { + // Content-Length: simple body + if let Some(content_length) = captures.get(1) { + // Read body + let content_length: usize = str::from_utf8(content_length.as_bytes()) + .unwrap() + .parse() + .unwrap(); + if src.len() >= end_index + 
content_length { + let remaining = src.split_off(end_index + content_length); + let out = src.to_vec(); + *src = remaining; + Ok(Some(out)) + } else { + //dbg!("Not enough data"); + Ok(None) + } + } else { + // Invalid Content-Length + Err(std::io::ErrorKind::InvalidData.into()) + } + } else if REGEX_CHUNKED.is_match(&src[0..end_index]) { + // Chunked body + let mut content = &src[end_index..]; + let mut total_len = end_index; + loop { + if let Some(len_end_index) = memchr::memmem::find(content, b"\r\n") { + let len_slice = &content[0..len_end_index]; + if len_end_index < 8 && is_hex(len_slice) { + let chunk_len = parse_hex(len_slice) as usize; + if content.len() >= len_end_index + chunk_len + 4 { + total_len += len_end_index + chunk_len + 4; + // Should we check the ending CRLF? + if chunk_len == 0 { + let remaining = src.split_off(total_len); + let out = src.to_vec(); + *src = remaining; + return Ok(Some(out)); + } + // else, wait for the next chunk + content = &content[len_end_index + chunk_len + 4..]; + } else { + // Not enough data + return Ok(None); + } + } else { + // Invalid chunk length + return Err(std::io::ErrorKind::InvalidData.into()); + } + } else { + // Not enough data + return Ok(None); + } + } + } else { + // Header ended without Content-Type nor chunks => no body + let remaining = src.split_off(end_index); + let out = src.to_vec(); + *src = remaining; + Ok(Some(out)) + } + } else { + //dbg!("Unfinished header"); + Ok(None) + } + + /*self.buf.extend_from_slice(&src); + src.clear(); + let src = &mut self.buf; + if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") { + end_index += 4; + if let Some(captures) = REGEX_CONTENT_LENGTH.captures(&src[0..end_index]) { + if let Some(content_length) = captures.get(1) { + // Read body + let content_length: usize = str::from_utf8(content_length.as_bytes()) + .unwrap() + .parse() + .unwrap(); + if src.len() >= end_index + content_length { + //dbg!(content_length); + let remaining = src.split_off(end_index + content_length); + let out = src.to_vec(); + *src = remaining; + Ok(Some(out)) + } else { + //dbg!("Not enough data"); + Ok(None) + } + } else { + // Invalid Content-Length + Err(std::io::ErrorKind::InvalidData.into()) + } + } else { + // Header ended without Content-Type => no body + let remaining = src.split_off(end_index); + let out = src.to_vec(); + *src = remaining; + Ok(Some(out)) + } + } else { + //dbg!("Unfinished header"); + Ok(None) + }*/ + } +} + +impl Encoder> for HttpServerCodec { type Error = std::io::Error; fn encode( &mut self, diff --git a/src/main.rs b/src/main.rs index 2c7a41e..55c1766 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,10 @@ +#![feature(ascii_char)] + mod client; mod http; mod record; mod server; +mod util; use record::Records; @@ -29,6 +32,8 @@ enum Subcommand { Print(OptPrint), /// Record traffic Record(OptRecord), + /// Write test record + Test(OptTest), } /// Replay from records @@ -68,6 +73,9 @@ struct OptPrint { /// Print packets #[argp(switch, short = 'p')] packets: bool, + /// Record number + #[argp(option, short = 'n')] + number: Option, } /// Record traffic @@ -75,6 +83,11 @@ struct OptPrint { #[argp(subcommand, name = "record")] struct OptRecord {} +/// Record traffic +#[derive(FromArgs)] +#[argp(subcommand, name = "test")] +struct OptTest {} + #[derive(Clone, Copy, Debug, Eq, PartialEq)] enum RunMode { Client, @@ -228,10 +241,13 @@ async fn main() { } Subcommand::Print(subopt) => { let records = record::read_record_file(&opt.record_file); - 
record::print_records(&records, subopt.packets); + record::print_records(&records, subopt.packets, subopt.number); } Subcommand::Record(_subopt) => { record::make_record(&opt.record_file); } + Subcommand::Test(_subopt) => { + record::make_test_record(&opt.record_file); + } } } diff --git a/src/record.rs b/src/record.rs index 2e73f22..38fc7bd 100644 --- a/src/record.rs +++ b/src/record.rs @@ -4,6 +4,8 @@ use std::{ sync::mpsc::{Receiver, Sender, channel}, }; +use crate::util::print_bin; + const CLIENT_TO_SERVER: u8 = b'C'; const SERVER_TO_CLIENT: u8 = b'S'; @@ -15,6 +17,39 @@ pub enum Direction { ServerToClient, } +static TEST_RECORD: &[(u64, &str, Direction, &[u8])] = + &[ + (0, "upload.wikimedia.org", Direction::ClientToServer, b"GET /aaaaaaaaa HTTP/1.1\r\nHost: upload.wikimedia.org\r\n\r\n"), + (0, "upload.wikimedia.org", Direction::ServerToClient, b"HTTP/1.1 200\r\nContent-Length: 5\r\nDate: Wed, 12 Nov 2025 13:52:58 GMT\r\n\r\nhello"), + (0, "upload.wikimedia.org", Direction::ClientToServer, b"GET /bbbbbbbbb HTTP/1.1\r\nHost: upload.wikimedia.org\r\n\r\n"), + (0, "upload.wikimedia.org", Direction::ServerToClient, b"HTTP/1.1 200\r\nContent-Length: 7\r\nDate: Wed, 12 Nov 2025 13:52:58 GMT\r\n\r\nbonjour"), + (1, "upload.wikimedia.org", Direction::ClientToServer, b"GET /ccccccccc HTTP/1.1\r\nHost: upload.wikimedia.org\r\n\r\n"), + (1, "upload.wikimedia.org", Direction::ServerToClient, b"HTTP/1.1 200\r\nTransfer-Encoding: chunked\r\nDate: Wed, 12 Nov 2025 13:52:58 GMT\r\n\r\n6\r\nbanane\r\n"), + (1, "upload.wikimedia.org", Direction::ServerToClient, b"5\r\npomme\r\n"), + (1, "upload.wikimedia.org", Direction::ServerToClient, b"0\r\n\r\n"), + ]; + +fn write_record( + file: &mut std::fs::File, + direction: Direction, + conn_id: u64, + server_name: &str, + data: &[u8], +) { + let server_name = server_name.as_bytes(); + file.write_all(&[match direction { + Direction::ClientToServer => CLIENT_TO_SERVER, + Direction::ServerToClient => SERVER_TO_CLIENT, + }]) + .unwrap(); + file.write_all(&conn_id.to_be_bytes()).unwrap(); + file.write_all(&[server_name.len() as u8]).unwrap(); + file.write_all(server_name).unwrap(); + file.write_all(&(data.len() as u64).to_be_bytes()).unwrap(); + file.write_all(&data).unwrap(); + file.flush().unwrap(); +} + pub struct Recorder { file: std::fs::File, receiver: Receiver<(u64, Option, Direction, Vec)>, @@ -43,21 +78,7 @@ impl Recorder { let Some(server_name) = server_name else { continue; }; - let server_name = server_name.as_bytes(); - self.file - .write_all(&[match direction { - Direction::ClientToServer => CLIENT_TO_SERVER, - Direction::ServerToClient => SERVER_TO_CLIENT, - }]) - .unwrap(); - self.file.write_all(&conn_id.to_be_bytes()).unwrap(); - self.file.write_all(&[server_name.len() as u8]).unwrap(); - self.file.write_all(server_name).unwrap(); - self.file - .write_all(&(data.len() as u64).to_be_bytes()) - .unwrap(); - self.file.write_all(&data).unwrap(); - self.file.flush().unwrap(); + write_record(&mut self.file, direction, conn_id, &server_name, &data); } } } @@ -186,11 +207,38 @@ pub fn read_record_file(path: &str) -> Records { println!("Error: incomplete data. stop."); break; } + + // Replace URL with unique id, to allow for better tracking by making each request unique. 
+ // (proxy may modify some headers, but not the URL) + let mut insert_id = |req_id| { + if direction == Direction::ClientToServer { + let mut spaces = buf + .iter() + .enumerate() + .filter_map(|(i, c)| if *c == b' ' { Some(i) } else { None }); + let s1 = spaces.next().unwrap(); + let s2 = spaces.next().unwrap(); + let new_url = format!("/{conn_id}-{req_id}/"); + if s2 - s1 - 1 < new_url.len() { + // Not optimal but good enough + let mut new_buf = Vec::new(); + new_buf.extend_from_slice(&buf[0..s1 + 1]); + new_buf.extend_from_slice(new_url.as_bytes()); + new_buf.extend_from_slice(&buf[s2..]); + buf = new_buf; + } else { + buf[s1 + 1..s2][0..new_url.len()].copy_from_slice(new_url.as_bytes()); + } + } + }; + match records.entry(conn_id) { btree_map::Entry::Occupied(mut entry) => { + (insert_id)(entry.get().1.len()); entry.get_mut().1.push((direction, buf)); } btree_map::Entry::Vacant(entry) => { + (insert_id)(0); entry.insert((server_name, vec![(direction, buf)])); } } @@ -198,8 +246,13 @@ pub fn read_record_file(path: &str) -> Records { records } -pub fn print_records(records: &Records, print_packets: bool) { +pub fn print_records(records: &Records, print_packets: bool, number: Option) { for (id, (server_name, records)) in records { + if let Some(number) = number + && number != *id + { + continue; + } let server_name = str::from_utf8(server_name.as_slice()).unwrap(); println!("{id} {server_name}"); for (direction, data) in records { @@ -212,7 +265,7 @@ pub fn print_records(records: &Records, print_packets: bool) { } } if print_packets { - let data_tr = if data.len() >= 256 && *direction == Direction::ServerToClient { + /*let data_tr = if data.len() >= 256 && *direction == Direction::ServerToClient { &data[0..256] } else { data.as_slice() @@ -224,8 +277,21 @@ pub fn print_records(records: &Records, print_packets: bool) { } if let Some(header_end) = memchr::memmem::find(data, b"\r\n\r\n") { println!(" --> body len: {}", data.len() - header_end - 4); - } + }*/ + print_bin(&data[0..data.len().min(8192)]); } } } } + +pub fn make_test_record(path: &str) { + let mut file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .open(path) + .unwrap(); + for (conn_id, server_name, direction, data) in TEST_RECORD { + write_record(&mut file, *direction, *conn_id, *server_name, *data); + } +} diff --git a/src/server.rs b/src/server.rs index 2df46b1..2ba5979 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,6 +1,7 @@ use crate::{ TlsMode, record::{Direction, Records}, + util::print_bin, }; use futures_util::stream::StreamExt; @@ -37,10 +38,19 @@ pub async fn play( response_map.insert((server_name.to_vec(), hash), (id, responses)); responses = Vec::new(); } - hash = Some( + let mut slashes = data + .iter() + .enumerate() + .filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None }); + let s1 = slashes.next(); + let s2 = slashes.next(); + hash = Some(if let (Some(s1), Some(s2)) = (s1, s2) { + data[s1 + 1..s2].to_vec() + } else { + panic!("Did not find URL: {:?}", &data[0..256]); tlsh::hash_buf(data) - .map_or_else(|_| data.clone(), |h| h.to_string().into_bytes()), - ); + .map_or_else(|_| data.clone(), |h| h.to_string().into_bytes()) + }); } Direction::ServerToClient => { responses.push(data); @@ -113,7 +123,7 @@ pub async fn play( .iter() { if let GeneralName::DNSName(name) = name { - resolver.add(name, cert_key.clone()).unwrap(); + resolver.add(name, cert_key.clone()).ok(); } } } @@ -189,40 +199,69 @@ pub async fn play( .unwrap() .trim_end_matches(".localhost") 
.to_string(); - let stream = accepted.into_stream(config).await.unwrap(); - let mut stream = Framed::new(stream, crate::http::HttpCodec {}); - let req = stream.next().await.unwrap().unwrap(); - let req_hash = tlsh::hash_buf(&req) - .map_or_else(|_| req.clone(), |h| h.to_string().into_bytes()); - let mut best = None; - for (i_server_name, hash) in response_map.keys() { - if i_server_name != server_name.as_bytes() { - continue; + let stream = accepted + .into_stream(config) + .await + .map_err(|e| panic!("{e:?} with name `{server_name}`")) + .unwrap(); + let mut stream = Framed::new(stream, crate::http::HttpServerCodec::new()); + //let mut previous = Vec::new(); + while let Some(req) = stream.next().await { + let req = req.unwrap(); + //println!("REQUEST"); + //print_bin(&req); + let req_hash = { + let mut slashes = req + .iter() + .enumerate() + .filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None }); + let s1 = slashes.next(); + let s2 = slashes.next(); + if let (Some(s1), Some(s2)) = (s1, s2) { + req[s1 + 1..s2].to_vec() + } else { + //println!("Previous: {:?}", &previous); + println!("Did not find URL: {:?}", &req[0..req.len().min(255)]); + tlsh::hash_buf(&req) + .map_or_else(|_| req.clone(), |h| h.to_string().into_bytes()) + } + }; + //previous = req.clone(); + let mut best = None; + for (i_server_name, hash) in response_map.keys() { + if i_server_name != server_name.as_bytes() { + continue; + } + let diff = if &req_hash == hash { + 0 + } else { + compare(&req_hash, hash) + }; + if let Some((best_hash, best_diff)) = &mut best { + if diff < *best_diff { + *best_hash = hash; + *best_diff = diff; + } + } else { + best = Some((hash, diff)); + } } - let diff = compare(&req_hash, hash); - if let Some((best_hash, best_diff)) = &mut best { - if diff < *best_diff { - *best_hash = hash; - *best_diff = diff; + let stream = stream.get_mut(); + if let Some((hash, _diff)) = best { + let (id, responses) = response_map + .get(&(server_name.as_bytes().to_vec(), hash.clone())) + .unwrap(); + //dbg!(id); + for &res in responses { + //println!("[SRV] response for ({}): {} bytes", id, res.len()); + stream.write_all(res).await.unwrap(); + stream.flush().await.unwrap(); } } else { - best = Some((hash, diff)); + println!("No response found for SNI=`{server_name}`"); } } - let stream = stream.get_mut(); - if let Some((hash, _diff)) = best { - let (id, responses) = response_map - .get(&(server_name.as_bytes().to_vec(), hash.clone())) - .unwrap(); - for &res in responses { - //println!("[SRV] response for ({}): {} bytes", id, res.len()); - stream.write_all(res).await.unwrap(); - stream.flush().await.unwrap(); - } - } else { - println!("No response found for SNI=`{server_name}`"); - } - stream.shutdown().await.unwrap(); + stream.get_mut().shutdown().await.unwrap(); }; tokio::spawn(async move { fut.await; @@ -274,39 +313,61 @@ pub async fn play( };*/ let fut = async move { //println!("[SRV] New task"); - let mut stream = Framed::new(stream, crate::http::HttpCodec {}); - let req = stream.next().await.unwrap().unwrap(); - //println!("[SRV] << {}", str::from_utf8(&req[..req.len().min(255)]).unwrap()); - let req_hash = tlsh::hash_buf(&req) - .map_or_else(|_| req.clone(), |h| h.to_string().into_bytes()); - let mut best = None; - for (i_server_name, hash) in response_map.keys() { - let diff = compare(&req_hash, hash); - if let Some((best_server_name, best_hash, best_diff)) = &mut best { - if diff < *best_diff { - *best_server_name = i_server_name; - *best_hash = hash; - *best_diff = diff; + //let mut stream = 
crate::http::Framer::new(stream, crate::http::HttpServerCodec::new()); + let mut stream = Framed::new(stream, crate::http::HttpServerCodec::new()); + //let mut previous = Vec::new(); + while let Some(req) = stream.next().await { + let req = req.unwrap(); + //println!("REQUEST"); + //print_bin(&req); + //println!("[SRV] << {}", str::from_utf8(&req[..req.len().min(255)]).unwrap()); + let req_hash = { + let mut slashes = req + .iter() + .enumerate() + .filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None }); + let s1 = slashes.next(); + let s2 = slashes.next(); + if let (Some(s1), Some(s2)) = (s1, s2) { + req[s1 + 1..s2].to_vec() + } else { + //println!("Previous: {:?}", &previous); + println!("Did not find URL: {:?}", &req[0..req.len().min(255)]); + tlsh::hash_buf(&req) + .map_or_else(|_| req.clone(), |h| h.to_string().into_bytes()) + } + }; + //previous = req.clone(); + let mut best = None; + for (i_server_name, hash) in response_map.keys() { + let diff = compare(&req_hash, hash); + if let Some((best_server_name, best_hash, best_diff)) = &mut best { + if diff < *best_diff { + *best_server_name = i_server_name; + *best_hash = hash; + *best_diff = diff; + } + } else { + best = Some((i_server_name, hash, diff)); + } + } + let stream = stream.get_mut(); + if let Some((server_name, hash, _diff)) = best { + let (id, responses) = response_map + .get(&(server_name.clone(), hash.clone())) + .unwrap(); + //dbg!(id); + for &res in responses { + //println!("[SRV] response for ({}): {} bytes", id, res.len()); + stream.write_all(res).await.unwrap(); + stream.flush().await.unwrap(); } } else { - best = Some((i_server_name, hash, diff)); + println!("[SRV] No response found"); } } - let stream = stream.get_mut(); - if let Some((server_name, hash, _diff)) = best { - let (id, responses) = response_map - .get(&(server_name.clone(), hash.clone())) - .unwrap(); - for &res in responses { - //println!("[SRV] response for ({}): {} bytes", id, res.len()); - stream.write_all(res).await.unwrap(); - stream.flush().await.unwrap(); - } - } else { - println!("[SRV] No response found"); - } //println!("Server shutdown"); - stream.shutdown().await.unwrap(); + stream.get_mut().shutdown().await.unwrap(); }; // Using a variable for the future allows it to be detected by tokio-console tokio::spawn(async move { diff --git a/src/util.rs b/src/util.rs index e69de29..b401eae 100644 --- a/src/util.rs +++ b/src/util.rs @@ -0,0 +1,130 @@ +use std::iter::Peekable; + +use crate::record::Direction; + +fn hex_digit(c: u8) -> u32 { + ((c & !(16 | 32 | 64)) + ((c & 64) >> 6) * 9) as _ +} + +pub fn parse_hex(s: &[u8]) -> u32 { + let mut r = 0; + for i in s.iter() { + r <<= 4; + r |= hex_digit(*i); + } + r +} + +pub fn is_hex(s: &[u8]) -> bool { + s.iter().all(|c| { + let c = *c | 32; + (c >= b'a' && c <= b'f') || (c >= b'0' && c <= b'9') + }) +} + +/// Print ASCII if possible +pub fn print_bin(s: &[u8]) { + if let Ok(s) = str::from_utf8(s) { + println!("{s}"); + } else { + let mut buf = String::new(); + for c in s { + if c.is_ascii_control() && *c != b'\n' { + continue; + } + if let Some(c) = c.as_ascii() { + buf.push_str(c.as_str()); + } else { + for c in std::ascii::escape_default(*c) { + buf.push(c.as_ascii().unwrap().into()); + } + } + } + println!("{buf}"); + } +} + +pub struct ResponseStreamer(Peekable); + +impl<'a, I: Iterator> ResponseStreamer { + pub fn new(inner: I) -> Self { + Self(inner.peekable()) + } +} + +impl<'a, I: Iterator)>> Iterator for ResponseStreamer { + type Item = (&'a Direction, Vec<&'a Vec>); + fn next(&mut 
self) -> Option { + let (direction, first_item) = self.0.next()?; + let mut items = vec![first_item]; + while let Some((item_direction, _item)) = self.0.peek() + && item_direction == direction + { + items.push(&self.0.next().unwrap().1); + } + Some((direction, items)) + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_hex_digit() { + assert_eq!(hex_digit(b'0'), 0); + assert_eq!(hex_digit(b'1'), 1); + assert_eq!(hex_digit(b'2'), 2); + assert_eq!(hex_digit(b'3'), 3); + assert_eq!(hex_digit(b'4'), 4); + assert_eq!(hex_digit(b'5'), 5); + assert_eq!(hex_digit(b'6'), 6); + assert_eq!(hex_digit(b'7'), 7); + assert_eq!(hex_digit(b'8'), 8); + assert_eq!(hex_digit(b'9'), 9); + assert_eq!(hex_digit(b'a'), 10); + assert_eq!(hex_digit(b'b'), 11); + assert_eq!(hex_digit(b'c'), 12); + assert_eq!(hex_digit(b'd'), 13); + assert_eq!(hex_digit(b'e'), 14); + assert_eq!(hex_digit(b'f'), 15); + assert_eq!(hex_digit(b'A'), 10); + assert_eq!(hex_digit(b'B'), 11); + assert_eq!(hex_digit(b'C'), 12); + assert_eq!(hex_digit(b'D'), 13); + assert_eq!(hex_digit(b'E'), 14); + assert_eq!(hex_digit(b'F'), 15); + } + + #[test] + fn test_parse_hex() { + assert_eq!(parse_hex(b"abc123"), 0xabc123); + assert_eq!(parse_hex(b"1"), 1); + } + + #[test] + fn test_is_hex() { + assert!(is_hex(b"0")); + assert!(is_hex(b"1")); + assert!(is_hex(b"2")); + assert!(is_hex(b"3")); + assert!(is_hex(b"4")); + assert!(is_hex(b"5")); + assert!(is_hex(b"6")); + assert!(is_hex(b"7")); + assert!(is_hex(b"8")); + assert!(is_hex(b"9")); + assert!(is_hex(b"a")); + assert!(is_hex(b"b")); + assert!(is_hex(b"c")); + assert!(is_hex(b"d")); + assert!(is_hex(b"e")); + assert!(is_hex(b"f")); + assert!(is_hex(b"A")); + assert!(is_hex(b"B")); + assert!(is_hex(b"C")); + assert!(is_hex(b"D")); + assert!(is_hex(b"E")); + assert!(is_hex(b"F")); + } +}
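Note (not part of the patch): the new HttpClientCodec/HttpServerCodec frame a Transfer-Encoding: chunked body as repeated "<hex size>\r\n<data>\r\n" chunks ending with a zero-size chunk, which is why each chunk is accounted for as len_end_index + chunk_len + 4 bytes. Below is a minimal standalone sketch of that length check, assuming the header is already fully buffered; the function name and signature are illustrative only, not code from the patch.

// Sketch of the chunked framing the codecs expect: returns Some(total message
// length, header included) once a complete chunked body is buffered, or None
// if more bytes are needed. `header_end` is the offset just past "\r\n\r\n".
fn chunked_body_len(buf: &[u8], header_end: usize) -> Option<usize> {
    let mut pos = header_end;
    loop {
        let rest = &buf[pos..];
        // Chunk-size line, terminated by CRLF.
        let size_end = rest.windows(2).position(|w| w == b"\r\n")?;
        let size = usize::from_str_radix(std::str::from_utf8(&rest[..size_end]).ok()?, 16).ok()?;
        // size line + CRLF + data + trailing CRLF (the "+ 4" in the decoders).
        let chunk_total = size_end + 2 + size + 2;
        if rest.len() < chunk_total {
            return None; // incomplete chunk, wait for more bytes
        }
        pos += chunk_total;
        if size == 0 {
            // Terminating zero-size chunk: anything after it belongs to the
            // next pipelined message, like split_off(total_len) in the patch.
            return Some(pos);
        }
    }
}

fn main() {
    let msg = b"HTTP/1.1 200\r\nTransfer-Encoding: chunked\r\n\r\n6\r\nbanane\r\n5\r\npomme\r\n0\r\n\r\n";
    let header_end = msg.windows(4).position(|w| w == b"\r\n\r\n").unwrap() + 4;
    assert_eq!(chunked_body_len(msg, header_end), Some(msg.len()));
}

Unlike the decoders in the patch, which return InvalidData on a malformed size line, the sketch simply returns None on any parse failure.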
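Note (not part of the patch): util::ResponseStreamer batches consecutive records that share a direction, so the client writes every queued client-to-server packet before waiting for the combined server-to-client data. A standalone sketch of the same grouping over a slice, with Dir standing in for record::Direction (names illustrative only):

// Consecutive (direction, payload) records with the same direction are
// yielded together as one batch, mirroring what ResponseStreamer does lazily.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
enum Dir {
    ClientToServer,
    ServerToClient,
}

fn group_by_direction(records: &[(Dir, Vec<u8>)]) -> Vec<(Dir, Vec<&Vec<u8>>)> {
    let mut out: Vec<(Dir, Vec<&Vec<u8>>)> = Vec::new();
    for (dir, data) in records {
        if let Some((last_dir, batch)) = out.last_mut() {
            if *last_dir == *dir {
                batch.push(data);
                continue;
            }
        }
        out.push((*dir, vec![data]));
    }
    out
}

fn main() {
    use Dir::*;
    let records = vec![
        (ClientToServer, b"GET /a".to_vec()),
        (ServerToClient, b"HTTP/1.1 200 ...".to_vec()),
        (ServerToClient, b"6\r\nbanane\r\n".to_vec()),
        (ServerToClient, b"0\r\n\r\n".to_vec()),
    ];
    let batches = group_by_direction(&records);
    assert_eq!(batches.len(), 2);
    assert_eq!(batches[1].1.len(), 3); // three server packets batched together
}

The real iterator does the same thing lazily with a Peekable inner iterator instead of building the whole Vec up front.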
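Note (not part of the patch): util::hex_digit leans on the ASCII layout. For '0'..'9' (0x30..0x39) bit 0x40 is clear and the low nibble is already the value; for 'a'..'f' and 'A'..'F' bit 0x40 is set and the low nibble is 1..6, so clearing bits 0x70 and adding 9 when bit 0x40 was set yields the digit value. An explicit-match equivalent, checked against the bit trick (illustrative only; like the original, it assumes is_hex() has already validated the input):

// Readable equivalent of util::hex_digit, written as an explicit match.
fn hex_digit_readable(c: u8) -> u32 {
    match c {
        b'0'..=b'9' => (c - b'0') as u32,
        b'a'..=b'f' => (c - b'a') as u32 + 10,
        b'A'..=b'F' => (c - b'A') as u32 + 10,
        _ => 0, // is_hex() is expected to reject anything else
    }
}

fn main() {
    // Same bit trick as the patch: clear bits 0x10|0x20|0x40, add 9 if bit 0x40 was set.
    let bit_trick = |c: u8| ((c & !(16 | 32 | 64)) + ((c & 64) >> 6) * 9) as u32;
    for c in b"0123456789abcdefABCDEF" {
        assert_eq!(bit_trick(*c), hex_digit_readable(*c));
    }
}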