Handle chunked HTTP correctly
This commit is contained in:
parent
dc2819e028
commit
1d62dae785
6 changed files with 817 additions and 216 deletions
192
src/client.rs
192
src/client.rs
|
|
@ -1,13 +1,15 @@
|
|||
use crate::{
|
||||
TlsMode,
|
||||
record::{Direction, Records},
|
||||
util::{ResponseStreamer, print_bin},
|
||||
};
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use futures_util::{StreamExt, TryStreamExt};
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
net::ToSocketAddrs,
|
||||
sync::{Arc, atomic::AtomicU32},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{
|
||||
io::AsyncWriteExt,
|
||||
|
|
@ -27,6 +29,8 @@ use tokio_rustls::{
|
|||
};
|
||||
use tokio_util::codec::Framed;
|
||||
|
||||
const TIMEOUT: Duration = Duration::from_secs(60);
|
||||
|
||||
#[derive(Debug)]
|
||||
struct DummyCertVerifier;
|
||||
|
||||
|
|
@ -99,6 +103,7 @@ pub async fn play(
|
|||
let total = records.len() * repeat as usize;
|
||||
let mut handles = Vec::new();
|
||||
let connect_to = connect_to.to_socket_addrs().unwrap().next().unwrap();
|
||||
let debug_mutex = Arc::new(Mutex::new(()));
|
||||
match tls_mode {
|
||||
TlsMode::Both | TlsMode::Client => {
|
||||
let mut config = tokio_rustls::rustls::ClientConfig::builder()
|
||||
|
|
@ -137,54 +142,71 @@ pub async fn play(
|
|||
.connect(server_name.clone(), stream)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut stream = Framed::new(stream, crate::http::HttpCodec {});
|
||||
for (direction, data) in records {
|
||||
let mut stream = Framed::new(stream, crate::http::HttpClientCodec::new());
|
||||
for (direction, data_list) in ResponseStreamer::new(records.iter()) {
|
||||
match direction {
|
||||
Direction::ClientToServer => {
|
||||
//println!("[CLT] ({id}) >> {}", data.len());
|
||||
//stream.get_mut().write_all(data).await.unwrap();
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
stream.get_mut().write_all(data),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v.unwrap(),
|
||||
Err(_e) => continue 'repeat,
|
||||
for data in data_list {
|
||||
//println!("[CLT] ({id}) >> {}", data.len());
|
||||
//stream.get_mut().write_all(data).await.unwrap();
|
||||
match tokio::time::timeout(
|
||||
TIMEOUT,
|
||||
stream.get_mut().write_all(data),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v.unwrap(),
|
||||
Err(_e) => {
|
||||
println!("client timeout {id} (sending)");
|
||||
continue 'repeat;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Direction::ServerToClient => {
|
||||
let total_len: usize =
|
||||
data_list.iter().map(|data| data.len()).sum::<usize>();
|
||||
let reduced_len =
|
||||
total_len.saturating_sub(160 * data_list.len()).max(1);
|
||||
let mut total_recv = 0;
|
||||
//println!("[CLT] ({id}) << {}", data.len());
|
||||
// let mut buf = Vec::new();
|
||||
// stream.read_buf(&mut buf).await.ok();
|
||||
//let mut buf = vec![0; data.len().saturating_sub(50).max(1)];
|
||||
//let resp = stream.next().await.unwrap().unwrap();
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
stream.next(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v.unwrap().unwrap(),
|
||||
Err(_e) => {
|
||||
// TODO fix
|
||||
println!("client timeout {id}");
|
||||
break 'repeat;
|
||||
}
|
||||
};
|
||||
//dbg!(resp.len());
|
||||
//crate::http::decode_http(&mut buf, &mut stream).await;
|
||||
while total_recv < reduced_len {
|
||||
let resp = match tokio::time::timeout(
|
||||
TIMEOUT,
|
||||
stream.next(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v.unwrap().unwrap(),
|
||||
Err(_e) => {
|
||||
// TODO fix
|
||||
println!(
|
||||
"client timeout {}: {} / {}",
|
||||
id, total_recv, total_len
|
||||
);
|
||||
//print_bin(data);
|
||||
break 'repeat;
|
||||
}
|
||||
};
|
||||
total_recv += resp.len();
|
||||
//dbg!(resp.len());
|
||||
//crate::http::decode_http(&mut buf, &mut stream).await;
|
||||
}
|
||||
/*if total_recv > total_len {
|
||||
println!("received too much {}: {} / {}", id, total_recv, total_len);
|
||||
}*/
|
||||
}
|
||||
}
|
||||
}
|
||||
//stream.get_mut().shutdown().await.unwrap();
|
||||
tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
stream.get_mut().shutdown(),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
println!("Client: {} / {}", cnt + 1, total);
|
||||
}
|
||||
|
|
@ -204,6 +226,7 @@ pub async fn play(
|
|||
let counter = counter.clone();
|
||||
let limiter = limiter.clone();
|
||||
let running = running.clone();
|
||||
let debug_mutex = debug_mutex.clone();
|
||||
handles.push(tokio::spawn(async move {
|
||||
let mut running_guard = running.lock().await;
|
||||
running_guard.insert(*id);
|
||||
|
|
@ -212,7 +235,7 @@ pub async fn play(
|
|||
//let mut buf = Vec::new();
|
||||
'repeat: for _i in 0..repeat {
|
||||
let stream = TcpStream::connect(connect_to).await.unwrap();
|
||||
let mut stream = Framed::new(stream, crate::http::HttpCodec {});
|
||||
let mut stream = Framed::new(stream, crate::http::HttpClientCodec::new());
|
||||
/*let mut skip_recv = false;
|
||||
for (direction, data) in records {
|
||||
match direction {
|
||||
|
|
@ -262,54 +285,79 @@ pub async fn play(
|
|||
}
|
||||
}
|
||||
}*/
|
||||
for (direction, data) in records {
|
||||
for (direction, data_list) in ResponseStreamer::new(records.iter()) {
|
||||
match direction {
|
||||
Direction::ClientToServer => {
|
||||
//println!("[CLT] ({id}) >> {}", str::from_utf8(&data[..data.len().min(255)]).unwrap());
|
||||
//stream.get_mut().write_all(data).await.unwrap();
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
stream.get_mut().write_all(data),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v.unwrap(),
|
||||
Err(_e) => continue 'repeat,
|
||||
for data in data_list {
|
||||
//println!("[CLT] ({id}) >> {}", str::from_utf8(&data[..data.len().min(255)]).unwrap());
|
||||
//stream.get_mut().write_all(data).await.unwrap();
|
||||
match tokio::time::timeout(
|
||||
TIMEOUT,
|
||||
stream.get_mut().write_all(data),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v.unwrap(),
|
||||
Err(_e) => {
|
||||
println!("client timeout {id} (sending)");
|
||||
continue 'repeat;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Direction::ServerToClient => {
|
||||
let total_len: usize =
|
||||
data_list.iter().map(|data| data.len()).sum::<usize>();
|
||||
let reduced_len =
|
||||
total_len.saturating_sub(160 * data_list.len()).max(1);
|
||||
let mut total_recv = 0;
|
||||
//println!("[CLT] ({id}) << {}", data.len());
|
||||
//let mut buf = Vec::new();
|
||||
//stream.read_buf(&mut buf).await.ok();
|
||||
//let mut buf = vec![0; data.len().saturating_sub(50).max(1)];
|
||||
let resp = match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
stream.next(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v.unwrap().unwrap(),
|
||||
Err(_e) => {
|
||||
// TODO fix
|
||||
println!("client timeout {id}");
|
||||
break 'repeat;
|
||||
}
|
||||
};
|
||||
//let resp = stream.next().await.unwrap().unwrap();
|
||||
//dbg!(resp.len());
|
||||
//crate::http::decode_http(&mut buf, &mut stream).await;
|
||||
//buf.clear();
|
||||
while total_recv < reduced_len {
|
||||
let resp = match tokio::time::timeout(
|
||||
TIMEOUT,
|
||||
stream.next(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(v) => v.unwrap().unwrap(),
|
||||
Err(_e) => {
|
||||
// TODO fix
|
||||
println!(
|
||||
"client timeout {}: {} / {}",
|
||||
id, total_recv, total_len
|
||||
);
|
||||
//print_bin(data);
|
||||
break 'repeat;
|
||||
}
|
||||
};
|
||||
total_recv += resp.len();
|
||||
/*if resp.len() != data.len() {
|
||||
let guard = debug_mutex.lock().await;
|
||||
println!("RECV NOT ENOUGH {} / {}", resp.len(), data.len());
|
||||
if resp.len() < 1000 && data.len() < 1000 {
|
||||
//print_bin(&resp);
|
||||
//println!("WANTED");
|
||||
//print_bin(data);
|
||||
}
|
||||
std::mem::drop(guard);
|
||||
}*/
|
||||
//print_bin(&resp);
|
||||
//let resp = stream.next().await.unwrap().unwrap();
|
||||
//dbg!(resp.len());
|
||||
//crate::http::decode_http(&mut buf, &mut stream).await;
|
||||
//buf.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//stream.get_mut().shutdown().await.unwrap();
|
||||
tokio::time::timeout(
|
||||
std::time::Duration::from_millis(1000),
|
||||
stream.get_mut().shutdown(),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
tokio::time::timeout(TIMEOUT, stream.get_mut().shutdown())
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let cnt = counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
println!("Client: {} / {}", cnt + 1, total);
|
||||
}
|
||||
|
|
@ -327,7 +375,7 @@ pub async fn play(
|
|||
async move {
|
||||
let mut last_count = 0;
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
tokio::time::sleep(TIMEOUT).await;
|
||||
println!("Running: {:?}", running.lock().await);
|
||||
let new_count = counter.load(std::sync::atomic::Ordering::Relaxed);
|
||||
if new_count == last_count {
|
||||
|
|
|
|||
408
src/http.rs
408
src/http.rs
|
|
@ -1,80 +1,66 @@
|
|||
use regex::bytes::Regex;
|
||||
use std::sync::LazyLock;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio_util::codec::{Decoder, Encoder};
|
||||
use std::{pin::Pin, sync::LazyLock};
|
||||
use tokio::io::{AsyncRead, AsyncReadExt};
|
||||
use tokio_util::{
|
||||
bytes::BytesMut,
|
||||
codec::{Decoder, Encoder},
|
||||
};
|
||||
|
||||
use crate::util::{is_hex, parse_hex};
|
||||
|
||||
static REGEX_CONTENT_LENGTH: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r#"[cC]ontent-[lL]ength: *(\d+)\r\n"#).unwrap());
|
||||
static REGEX_CHUNKED: LazyLock<Regex> =
|
||||
LazyLock::new(|| Regex::new(r#"[tT]ransfer-[eE]ncoding: *[cC]hunked\r\n"#).unwrap());
|
||||
|
||||
pub async fn _decode_http<R: AsyncReadExt + Unpin>(buf: &mut Vec<u8>, stream: &mut R) {
|
||||
loop {
|
||||
if let Some(mut end_index) = memchr::memmem::find(buf, b"\r\n\r\n") {
|
||||
end_index += 4;
|
||||
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(buf) {
|
||||
if let Some(content_length) = captures.get(1) {
|
||||
// Read body
|
||||
let content_length: usize = str::from_utf8(content_length.as_bytes())
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap();
|
||||
while buf.len() < end_index + content_length {
|
||||
match tokio::time::timeout(
|
||||
std::time::Duration::from_millis(500),
|
||||
stream.read_buf(buf),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(Ok(_n)) => {}
|
||||
Ok(Err(e)) => {
|
||||
println!("[http] error reading: {e:?}");
|
||||
break;
|
||||
}
|
||||
Err(_e) => {
|
||||
// timeout
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
} else {
|
||||
// Erroneous Content-Type
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
// Header ended without Content-Type => no body
|
||||
break;
|
||||
}
|
||||
}
|
||||
match tokio::time::timeout(std::time::Duration::from_millis(500), stream.read_buf(buf))
|
||||
.await
|
||||
{
|
||||
Ok(Ok(n)) => {
|
||||
println!("[http] read {n}");
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
println!("[http] error reading: {e:?}");
|
||||
break;
|
||||
}
|
||||
Err(_e) => {
|
||||
// timeout
|
||||
break;
|
||||
}
|
||||
}
|
||||
/*pin_project! {
|
||||
pub struct Framer<S, C> {
|
||||
#[pin]
|
||||
pub stream: S,
|
||||
codec: C,
|
||||
buf: BytesMut,
|
||||
}
|
||||
}
|
||||
|
||||
pub struct HttpCodec {}
|
||||
impl<S: AsyncRead + Unpin, C: Decoder> Framer<S, C> {
|
||||
pub fn new(stream: S, codec: C) -> Self {
|
||||
Self {
|
||||
stream,
|
||||
codec,
|
||||
buf: BytesMut::new(),
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for HttpCodec {
|
||||
pub async fn next(&mut self) -> Option<Vec<u8>> {
|
||||
self.stream.read_buf(&mut self.buf).await.unwrap();
|
||||
None
|
||||
}
|
||||
}*/
|
||||
|
||||
pub struct HttpClientCodec {
|
||||
buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl HttpClientCodec {
|
||||
pub fn new() -> Self {
|
||||
Self { buf: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for HttpClientCodec {
|
||||
type Item = Vec<u8>;
|
||||
type Error = std::io::Error;
|
||||
fn decode(
|
||||
&mut self,
|
||||
src: &mut tokio_util::bytes::BytesMut,
|
||||
) -> Result<Option<Self::Item>, Self::Error> {
|
||||
self.buf.extend_from_slice(&src);
|
||||
src.clear();
|
||||
let src = &mut self.buf;
|
||||
if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") {
|
||||
end_index += 4;
|
||||
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) {
|
||||
// Content-Length: simple body
|
||||
if let Some(content_length) = captures.get(1) {
|
||||
// Read body
|
||||
let content_length: usize = str::from_utf8(content_length.as_bytes())
|
||||
|
|
@ -82,30 +68,324 @@ impl Decoder for HttpCodec {
|
|||
.parse()
|
||||
.unwrap();
|
||||
if src.len() >= end_index + content_length {
|
||||
//dbg!(content_length);
|
||||
let remaining = src.split_off(end_index + content_length);
|
||||
let out = src.to_vec();
|
||||
src.clear();
|
||||
*src = remaining;
|
||||
Ok(Some(out))
|
||||
} else {
|
||||
//dbg!("Not enough data");
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
// Invalid Content-Length
|
||||
Err(std::io::ErrorKind::InvalidData.into())
|
||||
}
|
||||
} else if REGEX_CHUNKED.is_match(&src[0..end_index]) {
|
||||
// Chunked body
|
||||
let mut content = &src[end_index..];
|
||||
let mut total_len = end_index;
|
||||
loop {
|
||||
if let Some(len_end_index) = memchr::memmem::find(content, b"\r\n") {
|
||||
let len_slice = &content[0..len_end_index];
|
||||
if len_end_index < 8 && is_hex(len_slice) {
|
||||
let chunk_len = parse_hex(len_slice) as usize;
|
||||
if content.len() >= len_end_index + chunk_len + 4 {
|
||||
total_len += len_end_index + chunk_len + 4;
|
||||
// Should we check the ending CRLF?
|
||||
if chunk_len == 0 {
|
||||
let remaining = src.split_off(total_len);
|
||||
let out = src.to_vec();
|
||||
*src = remaining;
|
||||
return Ok(Some(out));
|
||||
}
|
||||
// else, wait for the next chunk
|
||||
content = &content[len_end_index + chunk_len + 4..];
|
||||
} else {
|
||||
// Not enough data
|
||||
return Ok(None);
|
||||
}
|
||||
} else {
|
||||
// Invalid chunk length
|
||||
return Err(std::io::ErrorKind::InvalidData.into());
|
||||
}
|
||||
} else {
|
||||
// Not enough data
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Header ended without Content-Type => no body
|
||||
// Header ended without Content-Type nor chunks => no body
|
||||
let remaining = src.split_off(end_index);
|
||||
let out = src.to_vec();
|
||||
src.clear();
|
||||
*src = remaining;
|
||||
Ok(Some(out))
|
||||
}
|
||||
} else {
|
||||
//dbg!("Unfinished header");
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/*self.buf.extend_from_slice(&src);
|
||||
src.clear();
|
||||
let src = &mut self.buf;
|
||||
if self.chunked {
|
||||
if let Some(len_end_index) = memchr::memmem::find(src, b"\r\n") {
|
||||
let len_slice = &src[0..len_end_index];
|
||||
if len_end_index < 8 && is_hex(len_slice) {
|
||||
let chunk_len = parse_hex(len_slice) as usize;
|
||||
if src.len() >= len_end_index + chunk_len + 4 {
|
||||
// Should we check the ending CRLF?
|
||||
if chunk_len == 0 {
|
||||
self.chunked = false;
|
||||
}
|
||||
let remaining = src.split_off(len_end_index+chunk_len+4);
|
||||
let out = src.to_vec();
|
||||
*src = remaining;
|
||||
Ok(Some(out))
|
||||
} else {
|
||||
// Not enough data
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
// Invalid chunk length
|
||||
Err(std::io::ErrorKind::InvalidData.into())
|
||||
}
|
||||
} else {
|
||||
// Not enough data
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") {
|
||||
end_index += 4;
|
||||
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) {
|
||||
if let Some(content_length) = captures.get(1) {
|
||||
// Read body
|
||||
let content_length: usize = str::from_utf8(content_length.as_bytes())
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap();
|
||||
if src.len() >= end_index + content_length {
|
||||
if REGEX_CHUNKED.is_match(&src[0..end_index]) {
|
||||
self.chunked = true;
|
||||
}
|
||||
//dbg!(content_length);
|
||||
let remaining = src.split_off(end_index + content_length);
|
||||
let out = src.to_vec();
|
||||
*src = remaining;
|
||||
Ok(Some(out))
|
||||
} else {
|
||||
//dbg!("Not enough data");
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
// Invalid Content-Length
|
||||
Err(std::io::ErrorKind::InvalidData.into())
|
||||
}
|
||||
} else {
|
||||
// Header ended without Content-Type => no body
|
||||
let remaining = src.split_off(end_index);
|
||||
let out = src.to_vec();
|
||||
*src = remaining;
|
||||
Ok(Some(out))
|
||||
}
|
||||
} else {
|
||||
//dbg!("Unfinished header");
|
||||
Ok(None)
|
||||
}
|
||||
}*/
|
||||
|
||||
/*if let Some(start_index) = memchr::memmem::find(src, b"HTTP") {
|
||||
if start_index != 0 {
|
||||
dbg!(start_index);
|
||||
if start_index == 529 {
|
||||
println!("{src:?}");
|
||||
}
|
||||
}
|
||||
let src2 = &src[start_index..];
|
||||
if let Some(mut end_index) = memchr::memmem::find(src2, b"\r\n\r\n") {
|
||||
end_index += 4;
|
||||
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src2) {
|
||||
if let Some(content_length) = captures.get(1) {
|
||||
// Read body
|
||||
let content_length: usize = str::from_utf8(content_length.as_bytes())
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap();
|
||||
if src2.len() >= end_index + content_length {
|
||||
if src2.len() > end_index + content_length {
|
||||
dbg!(src2.len(), end_index + content_length);
|
||||
println!("{src2:?}");
|
||||
std::process::exit(1);
|
||||
}
|
||||
//dbg!(content_length);
|
||||
let out = src2.to_vec();
|
||||
src.clear();
|
||||
Ok(Some(out))
|
||||
} else {
|
||||
//dbg!("Not enough data");
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
// Invalid Content-Length
|
||||
Err(std::io::ErrorKind::InvalidData.into())
|
||||
}
|
||||
} else {
|
||||
// Header ended without Content-Type => no body
|
||||
let out = src2.to_vec();
|
||||
src.clear();
|
||||
Ok(Some(out))
|
||||
}
|
||||
} else {
|
||||
//dbg!("Unfinished header");
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
//dbg!("Unstarted header");
|
||||
Ok(None)
|
||||
}*/
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Vec<u8>> for HttpCodec {
|
||||
impl Encoder<Vec<u8>> for HttpClientCodec {
|
||||
type Error = std::io::Error;
|
||||
fn encode(
|
||||
&mut self,
|
||||
_item: Vec<u8>,
|
||||
_dst: &mut tokio_util::bytes::BytesMut,
|
||||
) -> Result<(), Self::Error> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub struct HttpServerCodec {
|
||||
buf: Vec<u8>,
|
||||
}
|
||||
|
||||
impl HttpServerCodec {
|
||||
pub fn new() -> Self {
|
||||
Self { buf: Vec::new() }
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for HttpServerCodec {
|
||||
type Item = Vec<u8>;
|
||||
type Error = std::io::Error;
|
||||
fn decode(
|
||||
&mut self,
|
||||
src: &mut tokio_util::bytes::BytesMut,
|
||||
) -> Result<Option<Self::Item>, Self::Error> {
|
||||
self.buf.extend_from_slice(&src);
|
||||
src.clear();
|
||||
let src = &mut self.buf;
|
||||
if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") {
|
||||
end_index += 4;
|
||||
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(src) {
|
||||
// Content-Length: simple body
|
||||
if let Some(content_length) = captures.get(1) {
|
||||
// Read body
|
||||
let content_length: usize = str::from_utf8(content_length.as_bytes())
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap();
|
||||
if src.len() >= end_index + content_length {
|
||||
let remaining = src.split_off(end_index + content_length);
|
||||
let out = src.to_vec();
|
||||
*src = remaining;
|
||||
Ok(Some(out))
|
||||
} else {
|
||||
//dbg!("Not enough data");
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
// Invalid Content-Length
|
||||
Err(std::io::ErrorKind::InvalidData.into())
|
||||
}
|
||||
} else if REGEX_CHUNKED.is_match(&src[0..end_index]) {
|
||||
// Chunked body
|
||||
let mut content = &src[end_index..];
|
||||
let mut total_len = end_index;
|
||||
loop {
|
||||
if let Some(len_end_index) = memchr::memmem::find(content, b"\r\n") {
|
||||
let len_slice = &content[0..len_end_index];
|
||||
if len_end_index < 8 && is_hex(len_slice) {
|
||||
let chunk_len = parse_hex(len_slice) as usize;
|
||||
if content.len() >= len_end_index + chunk_len + 4 {
|
||||
total_len += len_end_index + chunk_len + 4;
|
||||
// Should we check the ending CRLF?
|
||||
if chunk_len == 0 {
|
||||
let remaining = src.split_off(total_len);
|
||||
let out = src.to_vec();
|
||||
*src = remaining;
|
||||
return Ok(Some(out));
|
||||
}
|
||||
// else, wait for the next chunk
|
||||
content = &content[len_end_index + chunk_len + 4..];
|
||||
} else {
|
||||
// Not enough data
|
||||
return Ok(None);
|
||||
}
|
||||
} else {
|
||||
// Invalid chunk length
|
||||
return Err(std::io::ErrorKind::InvalidData.into());
|
||||
}
|
||||
} else {
|
||||
// Not enough data
|
||||
return Ok(None);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Header ended without Content-Type nor chunks => no body
|
||||
let remaining = src.split_off(end_index);
|
||||
let out = src.to_vec();
|
||||
*src = remaining;
|
||||
Ok(Some(out))
|
||||
}
|
||||
} else {
|
||||
//dbg!("Unfinished header");
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/*self.buf.extend_from_slice(&src);
|
||||
src.clear();
|
||||
let src = &mut self.buf;
|
||||
if let Some(mut end_index) = memchr::memmem::find(src, b"\r\n\r\n") {
|
||||
end_index += 4;
|
||||
if let Some(captures) = REGEX_CONTENT_LENGTH.captures(&src[0..end_index]) {
|
||||
if let Some(content_length) = captures.get(1) {
|
||||
// Read body
|
||||
let content_length: usize = str::from_utf8(content_length.as_bytes())
|
||||
.unwrap()
|
||||
.parse()
|
||||
.unwrap();
|
||||
if src.len() >= end_index + content_length {
|
||||
//dbg!(content_length);
|
||||
let remaining = src.split_off(end_index + content_length);
|
||||
let out = src.to_vec();
|
||||
*src = remaining;
|
||||
Ok(Some(out))
|
||||
} else {
|
||||
//dbg!("Not enough data");
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
// Invalid Content-Length
|
||||
Err(std::io::ErrorKind::InvalidData.into())
|
||||
}
|
||||
} else {
|
||||
// Header ended without Content-Type => no body
|
||||
let remaining = src.split_off(end_index);
|
||||
let out = src.to_vec();
|
||||
*src = remaining;
|
||||
Ok(Some(out))
|
||||
}
|
||||
} else {
|
||||
//dbg!("Unfinished header");
|
||||
Ok(None)
|
||||
}*/
|
||||
}
|
||||
}
|
||||
|
||||
impl Encoder<Vec<u8>> for HttpServerCodec {
|
||||
type Error = std::io::Error;
|
||||
fn encode(
|
||||
&mut self,
|
||||
|
|
|
|||
18
src/main.rs
18
src/main.rs
|
|
@ -1,7 +1,10 @@
|
|||
#![feature(ascii_char)]
|
||||
|
||||
mod client;
|
||||
mod http;
|
||||
mod record;
|
||||
mod server;
|
||||
mod util;
|
||||
|
||||
use record::Records;
|
||||
|
||||
|
|
@ -29,6 +32,8 @@ enum Subcommand {
|
|||
Print(OptPrint),
|
||||
/// Record traffic
|
||||
Record(OptRecord),
|
||||
/// Write test record
|
||||
Test(OptTest),
|
||||
}
|
||||
|
||||
/// Replay from records
|
||||
|
|
@ -68,6 +73,9 @@ struct OptPrint {
|
|||
/// Print packets
|
||||
#[argp(switch, short = 'p')]
|
||||
packets: bool,
|
||||
/// Record number
|
||||
#[argp(option, short = 'n')]
|
||||
number: Option<u64>,
|
||||
}
|
||||
|
||||
/// Record traffic
|
||||
|
|
@ -75,6 +83,11 @@ struct OptPrint {
|
|||
#[argp(subcommand, name = "record")]
|
||||
struct OptRecord {}
|
||||
|
||||
/// Record traffic
|
||||
#[derive(FromArgs)]
|
||||
#[argp(subcommand, name = "test")]
|
||||
struct OptTest {}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
enum RunMode {
|
||||
Client,
|
||||
|
|
@ -228,10 +241,13 @@ async fn main() {
|
|||
}
|
||||
Subcommand::Print(subopt) => {
|
||||
let records = record::read_record_file(&opt.record_file);
|
||||
record::print_records(&records, subopt.packets);
|
||||
record::print_records(&records, subopt.packets, subopt.number);
|
||||
}
|
||||
Subcommand::Record(_subopt) => {
|
||||
record::make_record(&opt.record_file);
|
||||
}
|
||||
Subcommand::Test(_subopt) => {
|
||||
record::make_test_record(&opt.record_file);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
102
src/record.rs
102
src/record.rs
|
|
@ -4,6 +4,8 @@ use std::{
|
|||
sync::mpsc::{Receiver, Sender, channel},
|
||||
};
|
||||
|
||||
use crate::util::print_bin;
|
||||
|
||||
const CLIENT_TO_SERVER: u8 = b'C';
|
||||
const SERVER_TO_CLIENT: u8 = b'S';
|
||||
|
||||
|
|
@ -15,6 +17,39 @@ pub enum Direction {
|
|||
ServerToClient,
|
||||
}
|
||||
|
||||
static TEST_RECORD: &[(u64, &str, Direction, &[u8])] =
|
||||
&[
|
||||
(0, "upload.wikimedia.org", Direction::ClientToServer, b"GET /aaaaaaaaa HTTP/1.1\r\nHost: upload.wikimedia.org\r\n\r\n"),
|
||||
(0, "upload.wikimedia.org", Direction::ServerToClient, b"HTTP/1.1 200\r\nContent-Length: 5\r\nDate: Wed, 12 Nov 2025 13:52:58 GMT\r\n\r\nhello"),
|
||||
(0, "upload.wikimedia.org", Direction::ClientToServer, b"GET /bbbbbbbbb HTTP/1.1\r\nHost: upload.wikimedia.org\r\n\r\n"),
|
||||
(0, "upload.wikimedia.org", Direction::ServerToClient, b"HTTP/1.1 200\r\nContent-Length: 7\r\nDate: Wed, 12 Nov 2025 13:52:58 GMT\r\n\r\nbonjour"),
|
||||
(1, "upload.wikimedia.org", Direction::ClientToServer, b"GET /ccccccccc HTTP/1.1\r\nHost: upload.wikimedia.org\r\n\r\n"),
|
||||
(1, "upload.wikimedia.org", Direction::ServerToClient, b"HTTP/1.1 200\r\nTransfer-Encoding: chunked\r\nDate: Wed, 12 Nov 2025 13:52:58 GMT\r\n\r\n6\r\nbanane\r\n"),
|
||||
(1, "upload.wikimedia.org", Direction::ServerToClient, b"5\r\npomme\r\n"),
|
||||
(1, "upload.wikimedia.org", Direction::ServerToClient, b"0\r\n\r\n"),
|
||||
];
|
||||
|
||||
fn write_record(
|
||||
file: &mut std::fs::File,
|
||||
direction: Direction,
|
||||
conn_id: u64,
|
||||
server_name: &str,
|
||||
data: &[u8],
|
||||
) {
|
||||
let server_name = server_name.as_bytes();
|
||||
file.write_all(&[match direction {
|
||||
Direction::ClientToServer => CLIENT_TO_SERVER,
|
||||
Direction::ServerToClient => SERVER_TO_CLIENT,
|
||||
}])
|
||||
.unwrap();
|
||||
file.write_all(&conn_id.to_be_bytes()).unwrap();
|
||||
file.write_all(&[server_name.len() as u8]).unwrap();
|
||||
file.write_all(server_name).unwrap();
|
||||
file.write_all(&(data.len() as u64).to_be_bytes()).unwrap();
|
||||
file.write_all(&data).unwrap();
|
||||
file.flush().unwrap();
|
||||
}
|
||||
|
||||
pub struct Recorder {
|
||||
file: std::fs::File,
|
||||
receiver: Receiver<(u64, Option<String>, Direction, Vec<u8>)>,
|
||||
|
|
@ -43,21 +78,7 @@ impl Recorder {
|
|||
let Some(server_name) = server_name else {
|
||||
continue;
|
||||
};
|
||||
let server_name = server_name.as_bytes();
|
||||
self.file
|
||||
.write_all(&[match direction {
|
||||
Direction::ClientToServer => CLIENT_TO_SERVER,
|
||||
Direction::ServerToClient => SERVER_TO_CLIENT,
|
||||
}])
|
||||
.unwrap();
|
||||
self.file.write_all(&conn_id.to_be_bytes()).unwrap();
|
||||
self.file.write_all(&[server_name.len() as u8]).unwrap();
|
||||
self.file.write_all(server_name).unwrap();
|
||||
self.file
|
||||
.write_all(&(data.len() as u64).to_be_bytes())
|
||||
.unwrap();
|
||||
self.file.write_all(&data).unwrap();
|
||||
self.file.flush().unwrap();
|
||||
write_record(&mut self.file, direction, conn_id, &server_name, &data);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -186,11 +207,38 @@ pub fn read_record_file(path: &str) -> Records {
|
|||
println!("Error: incomplete data. stop.");
|
||||
break;
|
||||
}
|
||||
|
||||
// Replace URL with unique id, to allow for better tracking by making each request unique.
|
||||
// (proxy may modify some headers, but not the URL)
|
||||
let mut insert_id = |req_id| {
|
||||
if direction == Direction::ClientToServer {
|
||||
let mut spaces = buf
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, c)| if *c == b' ' { Some(i) } else { None });
|
||||
let s1 = spaces.next().unwrap();
|
||||
let s2 = spaces.next().unwrap();
|
||||
let new_url = format!("/{conn_id}-{req_id}/");
|
||||
if s2 - s1 - 1 < new_url.len() {
|
||||
// Not optimal but good enough
|
||||
let mut new_buf = Vec::new();
|
||||
new_buf.extend_from_slice(&buf[0..s1 + 1]);
|
||||
new_buf.extend_from_slice(new_url.as_bytes());
|
||||
new_buf.extend_from_slice(&buf[s2..]);
|
||||
buf = new_buf;
|
||||
} else {
|
||||
buf[s1 + 1..s2][0..new_url.len()].copy_from_slice(new_url.as_bytes());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match records.entry(conn_id) {
|
||||
btree_map::Entry::Occupied(mut entry) => {
|
||||
(insert_id)(entry.get().1.len());
|
||||
entry.get_mut().1.push((direction, buf));
|
||||
}
|
||||
btree_map::Entry::Vacant(entry) => {
|
||||
(insert_id)(0);
|
||||
entry.insert((server_name, vec![(direction, buf)]));
|
||||
}
|
||||
}
|
||||
|
|
@ -198,8 +246,13 @@ pub fn read_record_file(path: &str) -> Records {
|
|||
records
|
||||
}
|
||||
|
||||
pub fn print_records(records: &Records, print_packets: bool) {
|
||||
pub fn print_records(records: &Records, print_packets: bool, number: Option<u64>) {
|
||||
for (id, (server_name, records)) in records {
|
||||
if let Some(number) = number
|
||||
&& number != *id
|
||||
{
|
||||
continue;
|
||||
}
|
||||
let server_name = str::from_utf8(server_name.as_slice()).unwrap();
|
||||
println!("{id} {server_name}");
|
||||
for (direction, data) in records {
|
||||
|
|
@ -212,7 +265,7 @@ pub fn print_records(records: &Records, print_packets: bool) {
|
|||
}
|
||||
}
|
||||
if print_packets {
|
||||
let data_tr = if data.len() >= 256 && *direction == Direction::ServerToClient {
|
||||
/*let data_tr = if data.len() >= 256 && *direction == Direction::ServerToClient {
|
||||
&data[0..256]
|
||||
} else {
|
||||
data.as_slice()
|
||||
|
|
@ -224,8 +277,21 @@ pub fn print_records(records: &Records, print_packets: bool) {
|
|||
}
|
||||
if let Some(header_end) = memchr::memmem::find(data, b"\r\n\r\n") {
|
||||
println!(" --> body len: {}", data.len() - header_end - 4);
|
||||
}
|
||||
}*/
|
||||
print_bin(&data[0..data.len().min(8192)]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn make_test_record(path: &str) {
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.open(path)
|
||||
.unwrap();
|
||||
for (conn_id, server_name, direction, data) in TEST_RECORD {
|
||||
write_record(&mut file, *direction, *conn_id, *server_name, *data);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
183
src/server.rs
183
src/server.rs
|
|
@ -1,6 +1,7 @@
|
|||
use crate::{
|
||||
TlsMode,
|
||||
record::{Direction, Records},
|
||||
util::print_bin,
|
||||
};
|
||||
|
||||
use futures_util::stream::StreamExt;
|
||||
|
|
@ -37,10 +38,19 @@ pub async fn play(
|
|||
response_map.insert((server_name.to_vec(), hash), (id, responses));
|
||||
responses = Vec::new();
|
||||
}
|
||||
hash = Some(
|
||||
let mut slashes = data
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None });
|
||||
let s1 = slashes.next();
|
||||
let s2 = slashes.next();
|
||||
hash = Some(if let (Some(s1), Some(s2)) = (s1, s2) {
|
||||
data[s1 + 1..s2].to_vec()
|
||||
} else {
|
||||
panic!("Did not find URL: {:?}", &data[0..256]);
|
||||
tlsh::hash_buf(data)
|
||||
.map_or_else(|_| data.clone(), |h| h.to_string().into_bytes()),
|
||||
);
|
||||
.map_or_else(|_| data.clone(), |h| h.to_string().into_bytes())
|
||||
});
|
||||
}
|
||||
Direction::ServerToClient => {
|
||||
responses.push(data);
|
||||
|
|
@ -113,7 +123,7 @@ pub async fn play(
|
|||
.iter()
|
||||
{
|
||||
if let GeneralName::DNSName(name) = name {
|
||||
resolver.add(name, cert_key.clone()).unwrap();
|
||||
resolver.add(name, cert_key.clone()).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -189,40 +199,69 @@ pub async fn play(
|
|||
.unwrap()
|
||||
.trim_end_matches(".localhost")
|
||||
.to_string();
|
||||
let stream = accepted.into_stream(config).await.unwrap();
|
||||
let mut stream = Framed::new(stream, crate::http::HttpCodec {});
|
||||
let req = stream.next().await.unwrap().unwrap();
|
||||
let req_hash = tlsh::hash_buf(&req)
|
||||
.map_or_else(|_| req.clone(), |h| h.to_string().into_bytes());
|
||||
let mut best = None;
|
||||
for (i_server_name, hash) in response_map.keys() {
|
||||
if i_server_name != server_name.as_bytes() {
|
||||
continue;
|
||||
let stream = accepted
|
||||
.into_stream(config)
|
||||
.await
|
||||
.map_err(|e| panic!("{e:?} with name `{server_name}`"))
|
||||
.unwrap();
|
||||
let mut stream = Framed::new(stream, crate::http::HttpServerCodec::new());
|
||||
//let mut previous = Vec::new();
|
||||
while let Some(req) = stream.next().await {
|
||||
let req = req.unwrap();
|
||||
//println!("REQUEST");
|
||||
//print_bin(&req);
|
||||
let req_hash = {
|
||||
let mut slashes = req
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None });
|
||||
let s1 = slashes.next();
|
||||
let s2 = slashes.next();
|
||||
if let (Some(s1), Some(s2)) = (s1, s2) {
|
||||
req[s1 + 1..s2].to_vec()
|
||||
} else {
|
||||
//println!("Previous: {:?}", &previous);
|
||||
println!("Did not find URL: {:?}", &req[0..req.len().min(255)]);
|
||||
tlsh::hash_buf(&req)
|
||||
.map_or_else(|_| req.clone(), |h| h.to_string().into_bytes())
|
||||
}
|
||||
};
|
||||
//previous = req.clone();
|
||||
let mut best = None;
|
||||
for (i_server_name, hash) in response_map.keys() {
|
||||
if i_server_name != server_name.as_bytes() {
|
||||
continue;
|
||||
}
|
||||
let diff = if &req_hash == hash {
|
||||
0
|
||||
} else {
|
||||
compare(&req_hash, hash)
|
||||
};
|
||||
if let Some((best_hash, best_diff)) = &mut best {
|
||||
if diff < *best_diff {
|
||||
*best_hash = hash;
|
||||
*best_diff = diff;
|
||||
}
|
||||
} else {
|
||||
best = Some((hash, diff));
|
||||
}
|
||||
}
|
||||
let diff = compare(&req_hash, hash);
|
||||
if let Some((best_hash, best_diff)) = &mut best {
|
||||
if diff < *best_diff {
|
||||
*best_hash = hash;
|
||||
*best_diff = diff;
|
||||
let stream = stream.get_mut();
|
||||
if let Some((hash, _diff)) = best {
|
||||
let (id, responses) = response_map
|
||||
.get(&(server_name.as_bytes().to_vec(), hash.clone()))
|
||||
.unwrap();
|
||||
//dbg!(id);
|
||||
for &res in responses {
|
||||
//println!("[SRV] response for ({}): {} bytes", id, res.len());
|
||||
stream.write_all(res).await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
}
|
||||
} else {
|
||||
best = Some((hash, diff));
|
||||
println!("No response found for SNI=`{server_name}`");
|
||||
}
|
||||
}
|
||||
let stream = stream.get_mut();
|
||||
if let Some((hash, _diff)) = best {
|
||||
let (id, responses) = response_map
|
||||
.get(&(server_name.as_bytes().to_vec(), hash.clone()))
|
||||
.unwrap();
|
||||
for &res in responses {
|
||||
//println!("[SRV] response for ({}): {} bytes", id, res.len());
|
||||
stream.write_all(res).await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
}
|
||||
} else {
|
||||
println!("No response found for SNI=`{server_name}`");
|
||||
}
|
||||
stream.shutdown().await.unwrap();
|
||||
stream.get_mut().shutdown().await.unwrap();
|
||||
};
|
||||
tokio::spawn(async move {
|
||||
fut.await;
|
||||
|
|
@ -274,39 +313,61 @@ pub async fn play(
|
|||
};*/
|
||||
let fut = async move {
|
||||
//println!("[SRV] New task");
|
||||
let mut stream = Framed::new(stream, crate::http::HttpCodec {});
|
||||
let req = stream.next().await.unwrap().unwrap();
|
||||
//println!("[SRV] << {}", str::from_utf8(&req[..req.len().min(255)]).unwrap());
|
||||
let req_hash = tlsh::hash_buf(&req)
|
||||
.map_or_else(|_| req.clone(), |h| h.to_string().into_bytes());
|
||||
let mut best = None;
|
||||
for (i_server_name, hash) in response_map.keys() {
|
||||
let diff = compare(&req_hash, hash);
|
||||
if let Some((best_server_name, best_hash, best_diff)) = &mut best {
|
||||
if diff < *best_diff {
|
||||
*best_server_name = i_server_name;
|
||||
*best_hash = hash;
|
||||
*best_diff = diff;
|
||||
//let mut stream = crate::http::Framer::new(stream, crate::http::HttpServerCodec::new());
|
||||
let mut stream = Framed::new(stream, crate::http::HttpServerCodec::new());
|
||||
//let mut previous = Vec::new();
|
||||
while let Some(req) = stream.next().await {
|
||||
let req = req.unwrap();
|
||||
//println!("REQUEST");
|
||||
//print_bin(&req);
|
||||
//println!("[SRV] << {}", str::from_utf8(&req[..req.len().min(255)]).unwrap());
|
||||
let req_hash = {
|
||||
let mut slashes = req
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, c)| if *c == b'/' { Some(i) } else { None });
|
||||
let s1 = slashes.next();
|
||||
let s2 = slashes.next();
|
||||
if let (Some(s1), Some(s2)) = (s1, s2) {
|
||||
req[s1 + 1..s2].to_vec()
|
||||
} else {
|
||||
//println!("Previous: {:?}", &previous);
|
||||
println!("Did not find URL: {:?}", &req[0..req.len().min(255)]);
|
||||
tlsh::hash_buf(&req)
|
||||
.map_or_else(|_| req.clone(), |h| h.to_string().into_bytes())
|
||||
}
|
||||
};
|
||||
//previous = req.clone();
|
||||
let mut best = None;
|
||||
for (i_server_name, hash) in response_map.keys() {
|
||||
let diff = compare(&req_hash, hash);
|
||||
if let Some((best_server_name, best_hash, best_diff)) = &mut best {
|
||||
if diff < *best_diff {
|
||||
*best_server_name = i_server_name;
|
||||
*best_hash = hash;
|
||||
*best_diff = diff;
|
||||
}
|
||||
} else {
|
||||
best = Some((i_server_name, hash, diff));
|
||||
}
|
||||
}
|
||||
let stream = stream.get_mut();
|
||||
if let Some((server_name, hash, _diff)) = best {
|
||||
let (id, responses) = response_map
|
||||
.get(&(server_name.clone(), hash.clone()))
|
||||
.unwrap();
|
||||
//dbg!(id);
|
||||
for &res in responses {
|
||||
//println!("[SRV] response for ({}): {} bytes", id, res.len());
|
||||
stream.write_all(res).await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
}
|
||||
} else {
|
||||
best = Some((i_server_name, hash, diff));
|
||||
println!("[SRV] No response found");
|
||||
}
|
||||
}
|
||||
let stream = stream.get_mut();
|
||||
if let Some((server_name, hash, _diff)) = best {
|
||||
let (id, responses) = response_map
|
||||
.get(&(server_name.clone(), hash.clone()))
|
||||
.unwrap();
|
||||
for &res in responses {
|
||||
//println!("[SRV] response for ({}): {} bytes", id, res.len());
|
||||
stream.write_all(res).await.unwrap();
|
||||
stream.flush().await.unwrap();
|
||||
}
|
||||
} else {
|
||||
println!("[SRV] No response found");
|
||||
}
|
||||
//println!("Server shutdown");
|
||||
stream.shutdown().await.unwrap();
|
||||
stream.get_mut().shutdown().await.unwrap();
|
||||
};
|
||||
// Using a variable for the future allows it to be detected by tokio-console
|
||||
tokio::spawn(async move {
|
||||
|
|
|
|||
130
src/util.rs
130
src/util.rs
|
|
@ -0,0 +1,130 @@
|
|||
use std::iter::Peekable;
|
||||
|
||||
use crate::record::Direction;
|
||||
|
||||
fn hex_digit(c: u8) -> u32 {
|
||||
((c & !(16 | 32 | 64)) + ((c & 64) >> 6) * 9) as _
|
||||
}
|
||||
|
||||
pub fn parse_hex(s: &[u8]) -> u32 {
|
||||
let mut r = 0;
|
||||
for i in s.iter() {
|
||||
r <<= 4;
|
||||
r |= hex_digit(*i);
|
||||
}
|
||||
r
|
||||
}
|
||||
|
||||
pub fn is_hex(s: &[u8]) -> bool {
|
||||
s.iter().all(|c| {
|
||||
let c = *c | 32;
|
||||
(c >= b'a' && c <= b'f') || (c >= b'0' && c <= b'9')
|
||||
})
|
||||
}
|
||||
|
||||
/// Print ASCII if possible
|
||||
pub fn print_bin(s: &[u8]) {
|
||||
if let Ok(s) = str::from_utf8(s) {
|
||||
println!("{s}");
|
||||
} else {
|
||||
let mut buf = String::new();
|
||||
for c in s {
|
||||
if c.is_ascii_control() && *c != b'\n' {
|
||||
continue;
|
||||
}
|
||||
if let Some(c) = c.as_ascii() {
|
||||
buf.push_str(c.as_str());
|
||||
} else {
|
||||
for c in std::ascii::escape_default(*c) {
|
||||
buf.push(c.as_ascii().unwrap().into());
|
||||
}
|
||||
}
|
||||
}
|
||||
println!("{buf}");
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ResponseStreamer<I: Iterator>(Peekable<I>);
|
||||
|
||||
impl<'a, I: Iterator> ResponseStreamer<I> {
|
||||
pub fn new(inner: I) -> Self {
|
||||
Self(inner.peekable())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, I: Iterator<Item = &'a (Direction, Vec<u8>)>> Iterator for ResponseStreamer<I> {
|
||||
type Item = (&'a Direction, Vec<&'a Vec<u8>>);
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
let (direction, first_item) = self.0.next()?;
|
||||
let mut items = vec![first_item];
|
||||
while let Some((item_direction, _item)) = self.0.peek()
|
||||
&& item_direction == direction
|
||||
{
|
||||
items.push(&self.0.next().unwrap().1);
|
||||
}
|
||||
Some((direction, items))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_hex_digit() {
|
||||
assert_eq!(hex_digit(b'0'), 0);
|
||||
assert_eq!(hex_digit(b'1'), 1);
|
||||
assert_eq!(hex_digit(b'2'), 2);
|
||||
assert_eq!(hex_digit(b'3'), 3);
|
||||
assert_eq!(hex_digit(b'4'), 4);
|
||||
assert_eq!(hex_digit(b'5'), 5);
|
||||
assert_eq!(hex_digit(b'6'), 6);
|
||||
assert_eq!(hex_digit(b'7'), 7);
|
||||
assert_eq!(hex_digit(b'8'), 8);
|
||||
assert_eq!(hex_digit(b'9'), 9);
|
||||
assert_eq!(hex_digit(b'a'), 10);
|
||||
assert_eq!(hex_digit(b'b'), 11);
|
||||
assert_eq!(hex_digit(b'c'), 12);
|
||||
assert_eq!(hex_digit(b'd'), 13);
|
||||
assert_eq!(hex_digit(b'e'), 14);
|
||||
assert_eq!(hex_digit(b'f'), 15);
|
||||
assert_eq!(hex_digit(b'A'), 10);
|
||||
assert_eq!(hex_digit(b'B'), 11);
|
||||
assert_eq!(hex_digit(b'C'), 12);
|
||||
assert_eq!(hex_digit(b'D'), 13);
|
||||
assert_eq!(hex_digit(b'E'), 14);
|
||||
assert_eq!(hex_digit(b'F'), 15);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_hex() {
|
||||
assert_eq!(parse_hex(b"abc123"), 0xabc123);
|
||||
assert_eq!(parse_hex(b"1"), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_hex() {
|
||||
assert!(is_hex(b"0"));
|
||||
assert!(is_hex(b"1"));
|
||||
assert!(is_hex(b"2"));
|
||||
assert!(is_hex(b"3"));
|
||||
assert!(is_hex(b"4"));
|
||||
assert!(is_hex(b"5"));
|
||||
assert!(is_hex(b"6"));
|
||||
assert!(is_hex(b"7"));
|
||||
assert!(is_hex(b"8"));
|
||||
assert!(is_hex(b"9"));
|
||||
assert!(is_hex(b"a"));
|
||||
assert!(is_hex(b"b"));
|
||||
assert!(is_hex(b"c"));
|
||||
assert!(is_hex(b"d"));
|
||||
assert!(is_hex(b"e"));
|
||||
assert!(is_hex(b"f"));
|
||||
assert!(is_hex(b"A"));
|
||||
assert!(is_hex(b"B"));
|
||||
assert!(is_hex(b"C"));
|
||||
assert!(is_hex(b"D"));
|
||||
assert!(is_hex(b"E"));
|
||||
assert!(is_hex(b"F"));
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue