mesozoa/src/main.rs

252 lines
6.6 KiB
Rust

mod challenge;
mod cli;
mod config;
mod http;
mod policy;
use http::HeaderLineIterator;
use policy::CompiledPolicies;
use rand::Rng;
use realm_syscall::socket2::TcpKeepalive;
use regex::bytes::Regex;
use std::{net::SocketAddr, time::Duration};
use tokio::{
io::{AsyncWriteExt, ReadBuf},
net::{TcpSocket, TcpStream},
time::timeout,
};
/// Length in bytes of the random salt mixed into every freshly issued challenge cookie.
const SALT_LEN: usize = 16;
/// Length in bytes of the random secret generated at startup and used as the
/// challenge-cookie MAC key.
const SECRET_LEN: usize = 32;
/// Length in bytes of a challenge MAC (presumably the output size of
/// `challenge::compute_challenge_mac` — confirm in the `challenge` module).
const MAC_LEN: usize = 32;
/// Difficulty passed to `challenge::check_challenge` when validating a proof
/// (presumably the required number of leading zero bits — confirm in `challenge`).
const TARGET_ZEROS: u32 = 15;
/// HTML page sent as the body of every challenge response.
static CHALLENGE_BODY: &str = include_str!("challenge.html");

/// Moves `$val` into a `StaticCell` owned by this expansion site and evaluates
/// to a `&'static mut $t` pointing at it.
///
/// Each expansion has its own cell, and `static_cell` hands a cell out only
/// once (it panics if the cell is already full), so any given expansion must
/// run at most once per process.
macro_rules! mk_static {
    ($t:ty, $val:expr) => {{
        static MK_STATIC_CELL: static_cell::StaticCell<$t> = static_cell::StaticCell::new();
        #[deny(unused_attributes)]
        let value_ref = MK_STATIC_CELL.uninit().write(($val));
        value_ref
    }};
}
#[tokio::main]
async fn main() {
let cli: cli::Cli = argp::parse_args_or_exit(argp::DEFAULT);
let config = &*mk_static!(config::Config, config::Config::from_file(&cli.config));
let mut rng = rand::thread_rng();
let secret: [u8; SECRET_LEN] = rng.r#gen();
let policy_groups = &*mk_static!(
Vec<CompiledPolicies>,
config
.policy_groups
.iter()
.map(|policy| CompiledPolicies::new(policy))
.collect()
);
let socket = realm_syscall::new_tcp_socket(&config.listen_addr).unwrap();
socket.set_reuse_address(true).ok();
socket.bind(&config.listen_addr.into()).unwrap();
socket.listen(1024).unwrap();
let listener = tokio::net::TcpListener::from_std(socket.into()).unwrap();
let proof_regex =
Regex::new(r"^Cookie: *(?:[^;=]+=[^;=]* *; *)*mesozoa-proof *= *([0-9a-zA-Z_-]{8})")
.unwrap();
let challenge_regex =
Regex::new(r"^Cookie: *(?:[^;=]+=[^;=]* *; *)*mesozoa-challenge *= *([0-9a-zA-Z_-]{75})")
.unwrap();
let ip_regex = Regex::new(r"^X-Forwarded-For: *([a-fA-F0-9.:]+)$").unwrap();
let user_agent_regex = Regex::new(r"^User-Agent: *([a-zA-Z0-9.,:;/ _()-]+)$").unwrap();
let response_begin = &*mk_static!(
String,
format!(
"HTTP/1.1 200\r\n\
content-type: text/html\r\n\
content-length: {}\r\n",
CHALLENGE_BODY.len(),
)
);
loop {
let Ok((mut client_stream, _client_addr)) = listener.accept().await else {
continue;
};
//client_stream.set_nodelay(true).ok();
let proof_regex = proof_regex.clone();
let challenge_regex = challenge_regex.clone();
let ip_regex = ip_regex.clone();
let user_agent_regex = user_agent_regex.clone();
tokio::spawn(async move {
let mut buf = [0u8; 1024];
let mut buf_reader = ReadBuf::new(&mut buf);
if timeout(
Duration::from_millis(100),
std::future::poll_fn(|cx| client_stream.poll_peek(cx, &mut buf_reader)),
)
.await
.is_err()
{
// Peek timeout
return;
}
let mut header_line_iter = HeaderLineIterator::new(&buf);
let Some(first_line) = header_line_iter.next() else {
// Not HTTP, or too long line
return;
};
let mut action = config.default_action;
for policy_group in policy_groups.iter() {
if let Some(policy) = policy_group.evaluate(first_line) {
action = policy.action;
break;
}
}
match action {
policy::Action::Drop => {}
policy::Action::Allow => {
do_proxy(config.pass_addr, client_stream).await;
}
policy::Action::Challenge => {
let mut req_challenge = None;
let mut req_proof = None;
let mut req_user_agent: &[u8] = &[];
let mut req_ip: &[u8] = &[];
for line in header_line_iter {
if let Some(Some(m)) = challenge_regex.captures(line).map(|c| c.get(1)) {
req_challenge = Some(m.as_bytes());
}
if let Some(Some(m)) = proof_regex.captures(line).map(|c| c.get(1)) {
req_proof = Some(m.as_bytes());
}
if let Some(Some(m)) = user_agent_regex.captures(line).map(|c| c.get(1)) {
req_user_agent = m.as_bytes();
}
if let Some(Some(m)) = ip_regex.captures(line).map(|c| c.get(1)) {
req_ip = m.as_bytes();
}
}
let mut valid_challenge = false;
let mut allow = false;
if let Some(req_challenge) = req_challenge {
valid_challenge = challenge::verify_challenge_cookie(
req_challenge,
&secret,
req_user_agent,
req_ip,
config.challenge_timeout,
);
if let Some(req_proof) = req_proof {
allow = valid_challenge
&& challenge::check_challenge(
req_challenge,
req_proof,
TARGET_ZEROS,
);
}
}
if allow {
do_proxy(config.pass_addr, client_stream).await;
} else {
let salt: [u8; SALT_LEN] = rand::thread_rng().r#gen();
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
let timestamp_bytes = timestamp.to_be_bytes();
let challenge_mac = challenge::compute_challenge_mac(
&secret,
&salt,
timestamp_bytes,
req_ip,
req_user_agent,
);
let challenge_cookie = challenge::format_challenge_cookie(
&salt,
timestamp_bytes,
&challenge_mac,
);
client_stream.writable().await.unwrap();
client_stream
.write_all(response_begin.as_bytes())
.await
.unwrap();
if !valid_challenge {
client_stream
.write_all(b"set-cookie: mesozoa-challenge=")
.await
.unwrap();
client_stream
.write_all(challenge_cookie.as_bytes())
.await
.unwrap();
client_stream.write_all(b"; domain=").await.unwrap();
client_stream
.write_all(config.domain.as_bytes())
.await
.unwrap();
client_stream
.write_all(b"; path=/; max-age=3600; samesite=strict\r\n")
.await
.unwrap();
}
client_stream.write_all(b"\r\n").await.unwrap();
client_stream
.write_all(CHALLENGE_BODY.as_bytes())
.await
.unwrap();
}
}
}
});
}
}
async fn do_proxy(pass_addr: SocketAddr, mut client_stream: TcpStream) {
let keepalive_dur = Duration::from_secs(15);
let mut keepalive = TcpKeepalive::new().with_time(keepalive_dur);
keepalive = TcpKeepalive::with_interval(keepalive, keepalive_dur);
keepalive = TcpKeepalive::with_retries(keepalive, 3);
let pass_socket = realm_syscall::new_tcp_socket(&pass_addr).unwrap();
pass_socket.set_reuse_address(true).ok();
pass_socket.set_tcp_keepalive(&keepalive).ok();
let pass_socket = TcpSocket::from_std_stream(pass_socket.into());
let mut pass_stream = pass_socket.connect(pass_addr).await.unwrap();
match realm_io::bidi_zero_copy(&mut client_stream, &mut pass_stream).await {
Ok(_) => {}
Err(ref e) if e.kind() == tokio::io::ErrorKind::InvalidInput => {
realm_io::bidi_copy(&mut client_stream, &mut pass_stream)
.await
.unwrap();
}
Err(e) => panic!("err {}", e),
}
}