diff --git a/rpxy-lib/src/error.rs b/rpxy-lib/src/error.rs index 343cf04..f63a06c 100644 --- a/rpxy-lib/src/error.rs +++ b/rpxy-lib/src/error.rs @@ -84,35 +84,15 @@ pub enum RpxyError { #[error("Failed to fetch from upstream: {0}")] FailedToFetchFromUpstream(String), - // Cache errors, - #[cfg(feature = "cache")] - #[error("Invalid null request and/or response")] - NullRequestOrResponse, - - #[cfg(feature = "cache")] - #[error("Failed to write byte buffer")] - FailedToWriteByteBufferForCache, - - #[cfg(feature = "cache")] - #[error("Failed to acquire mutex lock for cache")] - FailedToAcquiredMutexLockForCache, - - #[cfg(feature = "cache")] - #[error("Failed to create file cache")] - FailedToCreateFileCache, - - #[cfg(feature = "cache")] - #[error("Failed to write file cache")] - FailedToWriteFileCache, - - #[cfg(feature = "cache")] - #[error("Failed to open cache file")] - FailedToOpenCacheFile, - // Upstream connection setting errors #[error("Unsupported upstream option")] UnsupportedUpstreamOption, + // Cache error map + #[cfg(feature = "cache")] + #[error("Cache error: {0}")] + CacheError(#[from] crate::forwarder::CacheError), + // Others #[error("Infallible")] Infallible(#[from] std::convert::Infallible), diff --git a/rpxy-lib/src/forwarder/cache/cache_error.rs b/rpxy-lib/src/forwarder/cache/cache_error.rs new file mode 100644 index 0000000..bb2ffa6 --- /dev/null +++ b/rpxy-lib/src/forwarder/cache/cache_error.rs @@ -0,0 +1,35 @@ +use thiserror::Error; + +pub type CacheResult = std::result::Result; + +/// Describes things that can go wrong in the Rpxy +#[derive(Debug, Error)] +pub enum CacheError { + // Cache errors, + #[error("Invalid null request and/or response")] + NullRequestOrResponse, + + #[error("Failed to write byte buffer")] + FailedToWriteByteBufferForCache, + + #[error("Failed to acquire mutex lock for cache")] + FailedToAcquiredMutexLockForCache, + + #[error("Failed to create file cache")] + FailedToCreateFileCache, + + #[error("Failed to write file cache")] + FailedToWriteFileCache, + + #[error("Failed to open cache file")] + FailedToOpenCacheFile, + + #[error("Too large to cache")] + TooLargeToCache, + + #[error("Failed to cache bytes: {0}")] + FailedToCacheBytes(String), + + #[error("Failed to send frame to cache {0}")] + FailedToSendFrameToCache(String), +} diff --git a/rpxy-lib/src/forwarder/cache.rs b/rpxy-lib/src/forwarder/cache/cache_main.rs similarity index 75% rename from rpxy-lib/src/forwarder/cache.rs rename to rpxy-lib/src/forwarder/cache/cache_main.rs index 03755e6..2bc4548 100644 --- a/rpxy-lib/src/forwarder/cache.rs +++ b/rpxy-lib/src/forwarder/cache/cache_main.rs @@ -1,8 +1,11 @@ -use crate::{error::*, globals::Globals, log::*}; +use super::cache_error::*; +use crate::{globals::Globals, hyper_ext::body::UnboundedStreamBody, log::*}; use bytes::{Buf, Bytes, BytesMut}; +use futures::channel::mpsc; use http::{Request, Response}; -use http_body_util::StreamBody; +use http_body_util::{BodyExt, StreamBody}; use http_cache_semantics::CachePolicy; +use hyper::body::{Body, Frame, Incoming}; use lru::LruCache; use std::{ convert::Infallible, @@ -69,6 +72,73 @@ impl RpxyCache { let on_memory = total - file; (total, on_memory, file) } + + /// Put response into the cache + pub async fn put( + &self, + uri: &hyper::Uri, + mut body: Incoming, + policy: &CachePolicy, + ) -> CacheResult { + let my_cache = self.inner.clone(); + let mut file_store = self.file_store.clone(); + let uri = uri.clone(); + let policy_clone = policy.clone(); + let max_each_size = self.max_each_size; + let max_each_size_on_memory = self.max_each_size_on_memory; + + let (body_tx, body_rx) = mpsc::unbounded::, hyper::Error>>(); + + self.runtime_handle.spawn(async move { + let mut size = 0usize; + loop { + let frame = match body.frame().await { + Some(frame) => frame, + None => { + debug!("Response body finished"); + break; + } + }; + let frame_size = frame.as_ref().map(|f| { + if f.is_data() { + f.data_ref().map(|bytes| bytes.remaining()).unwrap_or_default() + } else { + 0 + } + }); + size += frame_size.unwrap_or_default(); + + // check size + if size > max_each_size { + warn!("Too large to cache"); + return Err(CacheError::TooLargeToCache); + } + frame + .as_ref() + .map(|f| { + if f.is_data() { + let data_bytes = f.data_ref().unwrap().clone(); + println!("ddddde"); + // TODO: cache data bytes as file or on memory + // fileにするかmemoryにするかの判断はある程度までバッファしてやってという手を使うことになる。途中までキャッシュしたやつはどうするかとかいう判断も必要。 + // ファイルとObjectのbindをどうやってするか + } + }) + .map_err(|e| CacheError::FailedToCacheBytes(e.to_string()))?; + + // send data to use response downstream + body_tx + .unbounded_send(frame) + .map_err(|e| CacheError::FailedToSendFrameToCache(e.to_string()))?; + } + + Ok(()) as CacheResult<()> + }); + + let stream_body = StreamBody::new(body_rx); + + Ok(stream_body) + } } /* ---------------------------------------------- */ @@ -93,7 +163,7 @@ impl FileStore { inner.cnt } /// Create a temporary file cache - async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> RpxyResult { + async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> CacheResult { let mut inner = self.inner.write().await; inner.create(cache_filename, body_bytes).await } @@ -106,7 +176,7 @@ impl FileStore { // }; // } // /// Read a temporary file cache - // async fn read(&self, path: impl AsRef) -> RpxyResult { + // async fn read(&self, path: impl AsRef) -> CacheResult { // let inner = self.inner.read().await; // inner.read(&path).await // } @@ -141,16 +211,16 @@ impl FileStoreInner { } /// Create a new temporary file cache - async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> RpxyResult { + async fn create(&mut self, cache_filename: &str, body_bytes: &Bytes) -> CacheResult { let cache_filepath = self.cache_dir.join(cache_filename); let Ok(mut file) = File::create(&cache_filepath).await else { - return Err(RpxyError::FailedToCreateFileCache); + return Err(CacheError::FailedToCreateFileCache); }; let mut bytes_clone = body_bytes.clone(); while bytes_clone.has_remaining() { if let Err(e) = file.write_buf(&mut bytes_clone).await { error!("Failed to write file cache: {e}"); - return Err(RpxyError::FailedToWriteFileCache); + return Err(CacheError::FailedToWriteFileCache); }; } self.cnt += 1; @@ -158,15 +228,14 @@ impl FileStoreInner { } /// Retrieve a stored temporary file cache - async fn read(&self, path: impl AsRef) -> RpxyResult<()> { + async fn read(&self, path: impl AsRef) -> CacheResult<()> { let Ok(mut file) = File::open(&path).await else { warn!("Cache file object cannot be opened"); - return Err(RpxyError::FailedToOpenCacheFile); + return Err(CacheError::FailedToOpenCacheFile); }; /* ----------------------------- */ // PoC for streaming body - use futures::channel::mpsc; let (tx, rx) = mpsc::unbounded::, Infallible>>(); // let (body_sender, res_body) = Body::channel(); @@ -263,10 +332,10 @@ impl LruCacheManager { } /// Push an entry - fn push(&self, cache_key: &str, cache_object: CacheObject) -> RpxyResult> { + fn push(&self, cache_key: &str, cache_object: CacheObject) -> CacheResult> { let Ok(mut lock) = self.inner.lock() else { error!("Failed to acquire mutex lock for writing cache entry"); - return Err(RpxyError::FailedToAcquiredMutexLockForCache); + return Err(CacheError::FailedToAcquiredMutexLockForCache); }; let res = Ok(lock.push(cache_key.to_string(), cache_object)); // This may be inconsistent with the actual number of entries @@ -280,13 +349,13 @@ impl LruCacheManager { pub fn get_policy_if_cacheable( req: Option<&Request>, res: Option<&Response>, -) -> RpxyResult> +) -> CacheResult> // where // B1: core::fmt::Debug, { // deduce cache policy from req and res let (Some(req), Some(res)) = (req, res) else { - return Err(RpxyError::NullRequestOrResponse); + return Err(CacheError::NullRequestOrResponse); }; let new_policy = CachePolicy::new(req, res); diff --git a/rpxy-lib/src/forwarder/cache/mod.rs b/rpxy-lib/src/forwarder/cache/mod.rs new file mode 100644 index 0000000..cfe5a1b --- /dev/null +++ b/rpxy-lib/src/forwarder/cache/mod.rs @@ -0,0 +1,5 @@ +mod cache_error; +mod cache_main; + +pub use cache_error::CacheError; +pub use cache_main::{get_policy_if_cacheable, CacheFileOrOnMemory, RpxyCache}; diff --git a/rpxy-lib/src/forwarder/client.rs b/rpxy-lib/src/forwarder/client.rs index 8b86f9f..8d2e307 100644 --- a/rpxy-lib/src/forwarder/client.rs +++ b/rpxy-lib/src/forwarder/client.rs @@ -1,14 +1,10 @@ use crate::{ error::{RpxyError, RpxyResult}, globals::Globals, - hyper_ext::{ - body::{wrap_incoming_body_response, BoxBody, IncomingOr}, - rt::LocalExecutor, - }, + hyper_ext::{body::ResponseBody, rt::LocalExecutor}, log::*, }; use async_trait::async_trait; -use chrono::Duration; use http::{Request, Response, Version}; use hyper::body::{Body, Incoming}; use hyper_util::client::legacy::{ @@ -19,10 +15,6 @@ use std::sync::Arc; #[cfg(feature = "cache")] use super::cache::{get_policy_if_cacheable, RpxyCache}; -#[cfg(feature = "cache")] -use crate::hyper_ext::body::{full, wrap_synthetic_body_response}; -#[cfg(feature = "cache")] -use http_body_util::BodyExt; #[async_trait] /// Definition of the forwarder that simply forward requests from downstream client to upstream app servers. @@ -40,7 +32,7 @@ pub struct Forwarder { } #[async_trait] -impl ForwardRequest> for Forwarder +impl ForwardRequest for Forwarder where C: Send + Sync + Connect + Clone + 'static, B1: Body + Send + Sync + Unpin + 'static, @@ -49,7 +41,7 @@ where { type Error = RpxyError; - async fn request(&self, req: Request) -> Result>, Self::Error> { + async fn request(&self, req: Request) -> Result, Self::Error> { // TODO: cache handling #[cfg(feature = "cache")] { @@ -67,38 +59,27 @@ where let res = self.request_directly(req).await; if self.cache.is_none() { - return res.map(wrap_incoming_body_response::); + return res.map(|inner| inner.map(ResponseBody::Incoming)); } // check cacheability and store it if cacheable let Ok(Some(cache_policy)) = get_policy_if_cacheable(synth_req.as_ref(), res.as_ref().ok()) else { - return res.map(wrap_incoming_body_response::); + return res.map(|inner| inner.map(ResponseBody::Incoming)); }; let (parts, body) = res.unwrap().into_parts(); - // TODO: This is inefficient since current strategy needs to copy the whole body onto memory to cache it. - // This should be handled by copying buffer simultaneously while forwarding response to downstream. - let Ok(bytes) = body.collect().await.map(|v| v.to_bytes()) else { - return Err(RpxyError::FailedToWriteByteBufferForCache); - }; - let bytes_clone = bytes.clone(); + // Get streamed body without waiting for the arrival of the body, + // which is done simultaneously with caching. + let stream_body = self + .cache + .as_ref() + .unwrap() + .put(synth_req.unwrap().uri(), body, &cache_policy) + .await?; - // TODO: this is inefficient. needs to be reconsidered to avoid unnecessary copy and should spawn async task to store cache. - // We may need to use the same logic as h3. - // Is bytes.clone() enough? - - // if let Err(cache_err) = self - // .cache - // .as_ref() - // .unwrap() - // .put(synth_req.unwrap().uri(), &bytes, &cache_policy) - // .await - // { - // error!("{:?}", cache_err); - // }; - - // response with cached body - Ok(wrap_synthetic_body_response(Response::from_parts(parts, full(bytes)))) + // response with body being cached in background + let new_res = Response::from_parts(parts, ResponseBody::Streamed(stream_body)); + Ok(new_res) } // No cache handling @@ -107,7 +88,7 @@ where self .request_directly(req) .await - .map(wrap_incoming_body_response::) + .map(|inner| inner.map(ResponseBody::Incoming)) } } } diff --git a/rpxy-lib/src/forwarder/mod.rs b/rpxy-lib/src/forwarder/mod.rs index 286cb40..d53cd73 100644 --- a/rpxy-lib/src/forwarder/mod.rs +++ b/rpxy-lib/src/forwarder/mod.rs @@ -3,6 +3,9 @@ mod cache; mod client; use crate::hyper_ext::body::{IncomingLike, IncomingOr}; -pub type Forwarder = client::Forwarder>; -pub use client::ForwardRequest; +pub(crate) type Forwarder = client::Forwarder>; +pub(crate) use client::ForwardRequest; + +#[cfg(feature = "cache")] +pub(crate) use cache::CacheError; diff --git a/rpxy-lib/src/hyper_ext/body_type.rs b/rpxy-lib/src/hyper_ext/body_type.rs index 9616306..c1eb54b 100644 --- a/rpxy-lib/src/hyper_ext/body_type.rs +++ b/rpxy-lib/src/hyper_ext/body_type.rs @@ -1,24 +1,25 @@ -use http::Response; +// use http::Response; use http_body_util::{combinators, BodyExt, Either, Empty, Full}; -use hyper::body::{Bytes, Incoming}; +use hyper::body::{Body, Bytes, Incoming}; +use std::pin::Pin; /// Type for synthetic boxed body pub(crate) type BoxBody = combinators::BoxBody; /// Type for either passthrough body or given body type, specifically synthetic boxed body pub(crate) type IncomingOr = Either; -/// helper function to build http response with passthrough body -pub(crate) fn wrap_incoming_body_response(response: Response) -> Response> -where - B: hyper::body::Body, -{ - response.map(IncomingOr::Left) -} +// /// helper function to build http response with passthrough body +// pub(crate) fn wrap_incoming_body_response(response: Response) -> Response> +// where +// B: hyper::body::Body, +// { +// response.map(IncomingOr::Left) +// } -/// helper function to build http response with synthetic body -pub(crate) fn wrap_synthetic_body_response(response: Response) -> Response> { - response.map(IncomingOr::Right) -} +// /// helper function to build http response with synthetic body +// pub(crate) fn wrap_synthetic_body_response(response: Response) -> Response> { +// response.map(IncomingOr::Right) +// } /// helper function to build a empty body pub(crate) fn empty() -> BoxBody { @@ -29,3 +30,43 @@ pub(crate) fn empty() -> BoxBody { pub(crate) fn full(body: Bytes) -> BoxBody { Full::new(body).map_err(|never| match never {}).boxed() } + +/* ------------------------------------ */ +#[cfg(feature = "cache")] +use futures::channel::mpsc::UnboundedReceiver; +#[cfg(feature = "cache")] +use http_body_util::StreamBody; +#[cfg(feature = "cache")] +use hyper::body::Frame; + +#[cfg(feature = "cache")] +pub(crate) type UnboundedStreamBody = StreamBody, hyper::Error>>>; + +/// Response body use in this project +/// - Incoming: just a type that only forwards the upstream response body to downstream. +/// - BoxedCache: a type that is generated from cache, e.g.,, small byte object. +/// - StreamedCache: another type that is generated from cache as stream, e.g., large byte object. +pub(crate) enum ResponseBody { + Incoming(Incoming), + Boxed(BoxBody), + #[cfg(feature = "cache")] + Streamed(UnboundedStreamBody), +} + +impl Body for ResponseBody { + type Data = bytes::Bytes; + type Error = hyper::Error; + + fn poll_frame( + self: Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll, Self::Error>>> { + match self.get_mut() { + ResponseBody::Incoming(incoming) => Pin::new(incoming).poll_frame(cx), + #[cfg(feature = "cache")] + ResponseBody::Boxed(boxed) => Pin::new(boxed).poll_frame(cx), + #[cfg(feature = "cache")] + ResponseBody::Streamed(streamed) => Pin::new(streamed).poll_frame(cx), + } + } +} diff --git a/rpxy-lib/src/hyper_ext/mod.rs b/rpxy-lib/src/hyper_ext/mod.rs index 922776c..8b3776c 100644 --- a/rpxy-lib/src/hyper_ext/mod.rs +++ b/rpxy-lib/src/hyper_ext/mod.rs @@ -12,7 +12,5 @@ pub(crate) mod rt { #[allow(unused)] pub(crate) mod body { pub(crate) use super::body_incoming_like::IncomingLike; - pub(crate) use super::body_type::{ - empty, full, wrap_incoming_body_response, wrap_synthetic_body_response, BoxBody, IncomingOr, - }; + pub(crate) use super::body_type::{empty, full, BoxBody, IncomingOr, ResponseBody, UnboundedStreamBody}; } diff --git a/rpxy-lib/src/message_handler/handler_main.rs b/rpxy-lib/src/message_handler/handler_main.rs index 251411b..b5ae87d 100644 --- a/rpxy-lib/src/message_handler/handler_main.rs +++ b/rpxy-lib/src/message_handler/handler_main.rs @@ -11,7 +11,7 @@ use crate::{ error::*, forwarder::{ForwardRequest, Forwarder}, globals::Globals, - hyper_ext::body::{BoxBody, IncomingLike, IncomingOr}, + hyper_ext::body::{IncomingLike, IncomingOr, ResponseBody}, log::*, name_exp::ServerName, }; @@ -58,7 +58,7 @@ where listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, - ) -> RpxyResult>> { + ) -> RpxyResult> { // preparing log data let mut log_data = HttpMessageLog::from(&req); log_data.client_addr(&client_addr); @@ -99,7 +99,7 @@ where listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, - ) -> HttpResult>> { + ) -> HttpResult> { // Here we start to inspect and parse with server_name let server_name = req .inspect_parse_host() diff --git a/rpxy-lib/src/message_handler/synthetic_response.rs b/rpxy-lib/src/message_handler/synthetic_response.rs index 60aeeec..a955a2d 100644 --- a/rpxy-lib/src/message_handler/synthetic_response.rs +++ b/rpxy-lib/src/message_handler/synthetic_response.rs @@ -1,16 +1,16 @@ use super::http_result::{HttpError, HttpResult}; use crate::{ error::*, - hyper_ext::body::{empty, BoxBody, IncomingOr}, + hyper_ext::body::{empty, ResponseBody}, name_exp::ServerName, }; use http::{Request, Response, StatusCode, Uri}; /// build http response with status code of 4xx and 5xx -pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult>> { +pub(crate) fn synthetic_error_response(status_code: StatusCode) -> RpxyResult> { let res = Response::builder() .status(status_code) - .body(IncomingOr::Right(empty())) + .body(ResponseBody::Boxed(empty())) .unwrap(); Ok(res) } @@ -20,7 +20,7 @@ pub(super) fn secure_redirection_response( server_name: &ServerName, tls_port: Option, req: &Request, -) -> HttpResult>> { +) -> HttpResult> { let server_name: String = server_name.try_into().unwrap_or_default(); let pq = match req.uri().path_and_query() { Some(x) => x.as_str(), @@ -36,7 +36,7 @@ pub(super) fn secure_redirection_response( let response = Response::builder() .status(StatusCode::MOVED_PERMANENTLY) .header("Location", dest_uri.to_string()) - .body(IncomingOr::Right(empty())) + .body(ResponseBody::Boxed(empty())) .map_err(|e| HttpError::FailedToRedirect(e.to_string()))?; Ok(response) } diff --git a/rpxy-lib/src/proxy/proxy_h3.rs b/rpxy-lib/src/proxy/proxy_h3.rs index 342c995..61328b2 100644 --- a/rpxy-lib/src/proxy/proxy_h3.rs +++ b/rpxy-lib/src/proxy/proxy_h3.rs @@ -138,7 +138,6 @@ where }); let new_req: Request> = Request::from_parts(req_parts, IncomingOr::Right(req_body)); - // Response> wrapped by RpxyResult let res = self .message_handler .handle_request( diff --git a/rpxy-lib/src/proxy/proxy_main.rs b/rpxy-lib/src/proxy/proxy_main.rs index 96ec0be..2d7a649 100644 --- a/rpxy-lib/src/proxy/proxy_main.rs +++ b/rpxy-lib/src/proxy/proxy_main.rs @@ -5,7 +5,7 @@ use crate::{ error::*, globals::Globals, hyper_ext::{ - body::{BoxBody, IncomingOr}, + body::{IncomingOr, ResponseBody}, rt::LocalExecutor, }, log::*, @@ -32,7 +32,7 @@ async fn serve_request( listen_addr: SocketAddr, tls_enabled: bool, tls_server_name: Option, -) -> RpxyResult>> +) -> RpxyResult> where T: Send + Sync + Connect + Clone, U: CryptoSource + Clone,