mod challenge; mod cli; mod config; mod http; mod policy; use http::HeaderLineIterator; use policy::CompiledPolicies; use rand::Rng; use realm_syscall::socket2::TcpKeepalive; use regex::bytes::Regex; use std::{net::SocketAddr, time::Duration}; use tokio::{ io::{AsyncWriteExt, ReadBuf}, net::{TcpSocket, TcpStream}, time::timeout, }; const SALT_LEN: usize = 16; const SECRET_LEN: usize = 32; const MAC_LEN: usize = 32; const TARGET_ZEROS: u32 = 15; static CHALLENGE_BODY: &str = include_str!("challenge.html"); macro_rules! mk_static { ($t:ty, $val:expr) => {{ static STATIC_CELL: static_cell::StaticCell<$t> = static_cell::StaticCell::new(); #[deny(unused_attributes)] let x = STATIC_CELL.uninit().write(($val)); x }}; } #[tokio::main] async fn main() { let cli: cli::Cli = argp::parse_args_or_exit(argp::DEFAULT); let config = &*mk_static!(config::Config, config::Config::from_file(&cli.config)); let mut rng = rand::thread_rng(); let secret: [u8; SECRET_LEN] = rng.r#gen(); let policy_groups = &*mk_static!( Vec, config .policy_groups .iter() .map(|policy| CompiledPolicies::new(policy)) .collect() ); let socket = realm_syscall::new_tcp_socket(&config.listen_addr).unwrap(); socket.set_reuse_address(true).ok(); socket.bind(&config.listen_addr.into()).unwrap(); socket.listen(1024).unwrap(); let listener = tokio::net::TcpListener::from_std(socket.into()).unwrap(); let proof_regex = Regex::new(r"^Cookie: *(?:[^;=]+=[^;=]* *; *)*mesozoa-proof *= *([0-9a-zA-Z_-]{8})") .unwrap(); let challenge_regex = Regex::new(r"^Cookie: *(?:[^;=]+=[^;=]* *; *)*mesozoa-challenge *= *([0-9a-zA-Z_-]{75})") .unwrap(); let ip_regex = Regex::new(r"^X-Forwarded-For: *([a-fA-F0-9.:]+)$").unwrap(); let user_agent_regex = Regex::new(r"^User-Agent: *([a-zA-Z0-9.,:;/ _()-]+)$").unwrap(); let response_begin = &*mk_static!( String, format!( "HTTP/1.1 200\r\n\ content-type: text/html\r\n\ content-length: {}\r\n", CHALLENGE_BODY.len(), ) ); loop { let Ok((mut client_stream, _client_addr)) = listener.accept().await else { continue; }; //client_stream.set_nodelay(true).ok(); let proof_regex = proof_regex.clone(); let challenge_regex = challenge_regex.clone(); let ip_regex = ip_regex.clone(); let user_agent_regex = user_agent_regex.clone(); tokio::spawn(async move { let mut buf = [0u8; 1024]; let mut buf_reader = ReadBuf::new(&mut buf); if timeout( Duration::from_millis(100), std::future::poll_fn(|cx| client_stream.poll_peek(cx, &mut buf_reader)), ) .await .is_err() { // Peek timeout return; } let mut header_line_iter = HeaderLineIterator::new(&buf); let Some(first_line) = header_line_iter.next() else { // Not HTTP, or too long line return; }; let mut action = config.default_action; for policy_group in policy_groups.iter() { if let Some(policy) = policy_group.evaluate(first_line) { action = policy.action; break; } } match action { policy::Action::Drop => {} policy::Action::Allow => { do_proxy(config.pass_addr, client_stream).await; } policy::Action::Challenge => { let mut req_challenge = None; let mut req_proof = None; let mut req_user_agent: &[u8] = &[]; let mut req_ip: &[u8] = &[]; for line in header_line_iter { if let Some(Some(m)) = challenge_regex.captures(line).map(|c| c.get(1)) { req_challenge = Some(m.as_bytes()); } if let Some(Some(m)) = proof_regex.captures(line).map(|c| c.get(1)) { req_proof = Some(m.as_bytes()); } if let Some(Some(m)) = user_agent_regex.captures(line).map(|c| c.get(1)) { req_user_agent = m.as_bytes(); } if let Some(Some(m)) = ip_regex.captures(line).map(|c| c.get(1)) { req_ip = m.as_bytes(); } } let mut valid_challenge = false; let mut allow = false; if let Some(req_challenge) = req_challenge { valid_challenge = challenge::verify_challenge_cookie( req_challenge, &secret, req_user_agent, req_ip, config.challenge_timeout, ); if let Some(req_proof) = req_proof { allow = valid_challenge && challenge::check_challenge( req_challenge, req_proof, TARGET_ZEROS, ); } } if allow { do_proxy(config.pass_addr, client_stream).await; } else { let salt: [u8; SALT_LEN] = rand::thread_rng().r#gen(); let timestamp = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap() .as_secs(); let timestamp_bytes = timestamp.to_be_bytes(); let challenge_mac = challenge::compute_challenge_mac( &secret, &salt, timestamp_bytes, req_ip, req_user_agent, ); let challenge_cookie = challenge::format_challenge_cookie( &salt, timestamp_bytes, &challenge_mac, ); client_stream.writable().await.unwrap(); client_stream .write_all(response_begin.as_bytes()) .await .unwrap(); if !valid_challenge { client_stream .write_all(b"set-cookie: mesozoa-challenge=") .await .unwrap(); client_stream .write_all(challenge_cookie.as_bytes()) .await .unwrap(); client_stream.write_all(b"; domain=").await.unwrap(); client_stream .write_all(config.domain.as_bytes()) .await .unwrap(); client_stream .write_all(b"; path=/; max-age=3600; samesite=strict\r\n") .await .unwrap(); } client_stream.write_all(b"\r\n").await.unwrap(); client_stream .write_all(CHALLENGE_BODY.as_bytes()) .await .unwrap(); } } } }); } } async fn do_proxy(pass_addr: SocketAddr, mut client_stream: TcpStream) { let keepalive_dur = Duration::from_secs(15); let mut keepalive = TcpKeepalive::new().with_time(keepalive_dur); keepalive = TcpKeepalive::with_interval(keepalive, keepalive_dur); keepalive = TcpKeepalive::with_retries(keepalive, 3); let pass_socket = realm_syscall::new_tcp_socket(&pass_addr).unwrap(); pass_socket.set_reuse_address(true).ok(); pass_socket.set_tcp_keepalive(&keepalive).ok(); let pass_socket = TcpSocket::from_std_stream(pass_socket.into()); let mut pass_stream = pass_socket.connect(pass_addr).await.unwrap(); match realm_io::bidi_zero_copy(&mut client_stream, &mut pass_stream).await { Ok(_) => {} Err(ref e) if e.kind() == tokio::io::ErrorKind::InvalidInput => { realm_io::bidi_copy(&mut client_stream, &mut pass_stream) .await .unwrap(); } Err(e) => panic!("err {}", e), } }