252 lines
6.6 KiB
Rust
252 lines
6.6 KiB
Rust
mod challenge;
|
|
mod cli;
|
|
mod config;
|
|
mod http;
|
|
mod policy;
|
|
|
|
use http::HeaderLineIterator;
|
|
use policy::CompiledPolicies;
|
|
|
|
use rand::Rng;
|
|
use realm_syscall::socket2::TcpKeepalive;
|
|
use regex::bytes::Regex;
|
|
use std::{net::SocketAddr, time::Duration};
|
|
use tokio::{
|
|
io::{AsyncWriteExt, ReadBuf},
|
|
net::{TcpSocket, TcpStream},
|
|
time::timeout,
|
|
};
|
|
|
|
/// Length in bytes of the random per-process secret generated at startup; it is
/// passed to both `challenge::compute_challenge_mac` and
/// `challenge::verify_challenge_cookie`.
const SECRET_LEN: usize = 32;

/// Length in bytes of a challenge MAC. NOTE(review): not referenced in `main.rs`
/// itself; presumably used by the `challenge` module — confirm there.
const MAC_LEN: usize = 32;

/// Length in bytes of the random salt generated for every freshly issued
/// challenge cookie.
const SALT_LEN: usize = 16;

/// Difficulty argument handed to `challenge::check_challenge`. Presumably the
/// number of leading zero bits a valid proof must produce — confirm in `challenge`.
const TARGET_ZEROS: u32 = 15;
|
|
|
|
/// HTML of the challenge page, embedded at compile time from `challenge.html`
/// (resolved relative to this source file). It is the body of every challenge
/// response, and its byte length is what the `content-length` header advertises.
static CHALLENGE_BODY: &str = include_str!("challenge.html");
|
|
|
|
/// Move `$val` into a `static` cell private to this call site and yield a
/// `&'static mut $t` pointing at it.
///
/// Each expansion declares its own `StaticCell`, so a call site can hand out exactly
/// one value. static_cell panics if the same cell is filled twice, which means a given
/// call site must run at most once (never inside a loop).
macro_rules! mk_static {
    ($t:ty, $val:expr) => {{
        static SLOT: static_cell::StaticCell<$t> = static_cell::StaticCell::new();
        #[deny(unused_attributes)]
        let stored = SLOT.uninit().write(($val));
        stored
    }};
}
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
let cli: cli::Cli = argp::parse_args_or_exit(argp::DEFAULT);
|
|
|
|
let config = &*mk_static!(config::Config, config::Config::from_file(&cli.config));
|
|
|
|
let mut rng = rand::thread_rng();
|
|
|
|
let secret: [u8; SECRET_LEN] = rng.r#gen();
|
|
|
|
let policy_groups = &*mk_static!(
|
|
Vec<CompiledPolicies>,
|
|
config
|
|
.policy_groups
|
|
.iter()
|
|
.map(|policy| CompiledPolicies::new(policy))
|
|
.collect()
|
|
);
|
|
|
|
let socket = realm_syscall::new_tcp_socket(&config.listen_addr).unwrap();
|
|
|
|
socket.set_reuse_address(true).ok();
|
|
|
|
socket.bind(&config.listen_addr.into()).unwrap();
|
|
socket.listen(1024).unwrap();
|
|
|
|
let listener = tokio::net::TcpListener::from_std(socket.into()).unwrap();
|
|
|
|
let proof_regex =
|
|
Regex::new(r"^Cookie: *(?:[^;=]+=[^;=]* *; *)*mesozoa-proof *= *([0-9a-zA-Z_-]{8})")
|
|
.unwrap();
|
|
let challenge_regex =
|
|
Regex::new(r"^Cookie: *(?:[^;=]+=[^;=]* *; *)*mesozoa-challenge *= *([0-9a-zA-Z_-]{75})")
|
|
.unwrap();
|
|
let ip_regex = Regex::new(r"^X-Forwarded-For: *([a-fA-F0-9.:]+)$").unwrap();
|
|
let user_agent_regex = Regex::new(r"^User-Agent: *([a-zA-Z0-9.,:;/ _()-]+)$").unwrap();
|
|
|
|
let response_begin = &*mk_static!(
|
|
String,
|
|
format!(
|
|
"HTTP/1.1 200\r\n\
|
|
content-type: text/html\r\n\
|
|
content-length: {}\r\n",
|
|
CHALLENGE_BODY.len(),
|
|
)
|
|
);
|
|
|
|
loop {
|
|
let Ok((mut client_stream, _client_addr)) = listener.accept().await else {
|
|
continue;
|
|
};
|
|
//client_stream.set_nodelay(true).ok();
|
|
|
|
let proof_regex = proof_regex.clone();
|
|
let challenge_regex = challenge_regex.clone();
|
|
let ip_regex = ip_regex.clone();
|
|
let user_agent_regex = user_agent_regex.clone();
|
|
tokio::spawn(async move {
|
|
let mut buf = [0u8; 1024];
|
|
let mut buf_reader = ReadBuf::new(&mut buf);
|
|
if timeout(
|
|
Duration::from_millis(100),
|
|
std::future::poll_fn(|cx| client_stream.poll_peek(cx, &mut buf_reader)),
|
|
)
|
|
.await
|
|
.is_err()
|
|
{
|
|
// Peek timeout
|
|
return;
|
|
}
|
|
|
|
let mut header_line_iter = HeaderLineIterator::new(&buf);
|
|
let Some(first_line) = header_line_iter.next() else {
|
|
// Not HTTP, or too long line
|
|
return;
|
|
};
|
|
|
|
let mut action = config.default_action;
|
|
for policy_group in policy_groups.iter() {
|
|
if let Some(policy) = policy_group.evaluate(first_line) {
|
|
action = policy.action;
|
|
break;
|
|
}
|
|
}
|
|
|
|
match action {
|
|
policy::Action::Drop => {}
|
|
policy::Action::Allow => {
|
|
do_proxy(config.pass_addr, client_stream).await;
|
|
}
|
|
policy::Action::Challenge => {
|
|
let mut req_challenge = None;
|
|
let mut req_proof = None;
|
|
let mut req_user_agent: &[u8] = &[];
|
|
let mut req_ip: &[u8] = &[];
|
|
for line in header_line_iter {
|
|
if let Some(Some(m)) = challenge_regex.captures(line).map(|c| c.get(1)) {
|
|
req_challenge = Some(m.as_bytes());
|
|
}
|
|
if let Some(Some(m)) = proof_regex.captures(line).map(|c| c.get(1)) {
|
|
req_proof = Some(m.as_bytes());
|
|
}
|
|
if let Some(Some(m)) = user_agent_regex.captures(line).map(|c| c.get(1)) {
|
|
req_user_agent = m.as_bytes();
|
|
}
|
|
if let Some(Some(m)) = ip_regex.captures(line).map(|c| c.get(1)) {
|
|
req_ip = m.as_bytes();
|
|
}
|
|
}
|
|
|
|
let mut valid_challenge = false;
|
|
let mut allow = false;
|
|
if let Some(req_challenge) = req_challenge {
|
|
valid_challenge = challenge::verify_challenge_cookie(
|
|
req_challenge,
|
|
&secret,
|
|
req_user_agent,
|
|
req_ip,
|
|
config.challenge_timeout,
|
|
);
|
|
|
|
if let Some(req_proof) = req_proof {
|
|
allow = valid_challenge
|
|
&& challenge::check_challenge(
|
|
req_challenge,
|
|
req_proof,
|
|
TARGET_ZEROS,
|
|
);
|
|
}
|
|
}
|
|
|
|
if allow {
|
|
do_proxy(config.pass_addr, client_stream).await;
|
|
} else {
|
|
let salt: [u8; SALT_LEN] = rand::thread_rng().r#gen();
|
|
|
|
let timestamp = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap()
|
|
.as_secs();
|
|
let timestamp_bytes = timestamp.to_be_bytes();
|
|
|
|
let challenge_mac = challenge::compute_challenge_mac(
|
|
&secret,
|
|
&salt,
|
|
timestamp_bytes,
|
|
req_ip,
|
|
req_user_agent,
|
|
);
|
|
|
|
let challenge_cookie = challenge::format_challenge_cookie(
|
|
&salt,
|
|
timestamp_bytes,
|
|
&challenge_mac,
|
|
);
|
|
|
|
client_stream.writable().await.unwrap();
|
|
client_stream
|
|
.write_all(response_begin.as_bytes())
|
|
.await
|
|
.unwrap();
|
|
if !valid_challenge {
|
|
client_stream
|
|
.write_all(b"set-cookie: mesozoa-challenge=")
|
|
.await
|
|
.unwrap();
|
|
client_stream
|
|
.write_all(challenge_cookie.as_bytes())
|
|
.await
|
|
.unwrap();
|
|
client_stream.write_all(b"; domain=").await.unwrap();
|
|
client_stream
|
|
.write_all(config.domain.as_bytes())
|
|
.await
|
|
.unwrap();
|
|
client_stream
|
|
.write_all(b"; path=/; max-age=3600; samesite=strict\r\n")
|
|
.await
|
|
.unwrap();
|
|
}
|
|
client_stream.write_all(b"\r\n").await.unwrap();
|
|
client_stream
|
|
.write_all(CHALLENGE_BODY.as_bytes())
|
|
.await
|
|
.unwrap();
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|
|
}
|
|
|
|
async fn do_proxy(pass_addr: SocketAddr, mut client_stream: TcpStream) {
|
|
let keepalive_dur = Duration::from_secs(15);
|
|
let mut keepalive = TcpKeepalive::new().with_time(keepalive_dur);
|
|
keepalive = TcpKeepalive::with_interval(keepalive, keepalive_dur);
|
|
keepalive = TcpKeepalive::with_retries(keepalive, 3);
|
|
|
|
let pass_socket = realm_syscall::new_tcp_socket(&pass_addr).unwrap();
|
|
|
|
pass_socket.set_reuse_address(true).ok();
|
|
pass_socket.set_tcp_keepalive(&keepalive).ok();
|
|
|
|
let pass_socket = TcpSocket::from_std_stream(pass_socket.into());
|
|
|
|
let mut pass_stream = pass_socket.connect(pass_addr).await.unwrap();
|
|
|
|
match realm_io::bidi_zero_copy(&mut client_stream, &mut pass_stream).await {
|
|
Ok(_) => {}
|
|
Err(ref e) if e.kind() == tokio::io::ErrorKind::InvalidInput => {
|
|
realm_io::bidi_copy(&mut client_stream, &mut pass_stream)
|
|
.await
|
|
.unwrap();
|
|
}
|
|
Err(e) => panic!("err {}", e),
|
|
}
|
|
}
|