From 6b0976381c602166bcf3cfb3d940a78f2bb1640b Mon Sep 17 00:00:00 2001 From: "sijie.sun" Date: Thu, 14 Nov 2024 23:07:12 +0800 Subject: [PATCH] support compress for rpc and tun data --- Cargo.lock | 37 ++++- easytier-gui/src/composables/mobile_vpn.ts | 3 +- easytier-gui/src/composables/network.ts | 2 +- easytier/Cargo.toml | 2 + easytier/src/common/compressor.rs | 174 +++++++++++++++++++++ easytier/src/common/mod.rs | 1 + easytier/src/peers/peer_manager.rs | 58 +++---- easytier/src/peers/peer_map.rs | 13 ++ easytier/src/proto/common.proto | 19 ++- easytier/src/proto/common.rs | 25 +++ easytier/src/proto/rpc_impl/client.rs | 71 ++++++++- easytier/src/proto/rpc_impl/packet.rs | 60 +++++-- easytier/src/proto/rpc_impl/server.rs | 36 ++++- easytier/src/proto/tests.rs | 3 +- easytier/src/tunnel/common.rs | 2 +- easytier/src/tunnel/packet_def.rs | 59 ++++++- 16 files changed, 491 insertions(+), 74 deletions(-) create mode 100644 easytier/src/common/compressor.rs diff --git a/Cargo.lock b/Cargo.lock index 4c43aa2..e57b7ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -262,6 +262,20 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "async-compression" +version = "0.4.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857" +dependencies = [ + "futures-core", + "memchr", + "pin-project-lite", + "tokio", + "zstd 0.13.2", + "zstd-safe 7.2.1", +] + [[package]] name = "async-event" version = "0.2.1" @@ -1809,6 +1823,7 @@ version = "2.0.3" dependencies = [ "aes-gcm", "anyhow", + "async-compression", "async-recursion", "async-ringbuf", "async-stream", @@ -9474,7 +9489,7 @@ dependencies = [ "pbkdf2", "sha1", "time", - "zstd", + "zstd 0.11.2+zstd.1.5.2", ] [[package]] @@ -9483,7 +9498,16 @@ version = "0.11.2+zstd.1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" dependencies = [ - "zstd-safe", + "zstd-safe 5.0.2+zstd.1.5.2", +] + +[[package]] +name = "zstd" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9" +dependencies = [ + "zstd-safe 7.2.1", ] [[package]] @@ -9496,6 +9520,15 @@ dependencies = [ "zstd-sys", ] +[[package]] +name = "zstd-safe" +version = "7.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059" +dependencies = [ + "zstd-sys", +] + [[package]] name = "zstd-sys" version = "2.0.13+zstd.1.5.6" diff --git a/easytier-gui/src/composables/mobile_vpn.ts b/easytier-gui/src/composables/mobile_vpn.ts index 70a09a5..72a7478 100644 --- a/easytier-gui/src/composables/mobile_vpn.ts +++ b/easytier-gui/src/composables/mobile_vpn.ts @@ -1,6 +1,7 @@ +import type { NetworkTypes } from 'easytier-frontend-lib' import { addPluginListener } from '@tauri-apps/api/core' +import { Utils } from 'easytier-frontend-lib' import { prepare_vpn, start_vpn, stop_vpn } from 'tauri-plugin-vpnservice-api' -import { NetworkTypes, Utils } from 'easytier-frontend-lib' type Route = NetworkTypes.Route diff --git a/easytier-gui/src/composables/network.ts b/easytier-gui/src/composables/network.ts index ade1b1a..d603695 100644 --- a/easytier-gui/src/composables/network.ts +++ b/easytier-gui/src/composables/network.ts @@ -1,5 +1,5 @@ +import type { NetworkTypes } from 'easytier-frontend-lib' import { invoke } from '@tauri-apps/api/core' -import { NetworkTypes } from 'easytier-frontend-lib' type NetworkConfig = NetworkTypes.NetworkConfig type NetworkInstanceRunningInfo = NetworkTypes.NetworkInstanceRunningInfo diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 57e6146..26b682d 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -181,6 +181,8 @@ sys-locale = "0.3" ringbuf = "0.4.5" async-ringbuf = "0.3.1" +async-compression = { version = "0.4.17", default-features = false, features = ["zstd", "tokio"] } + [target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies] machine-uid = "0.5.3" diff --git a/easytier/src/common/compressor.rs b/easytier/src/common/compressor.rs new file mode 100644 index 0000000..8c82e6c --- /dev/null +++ b/easytier/src/common/compressor.rs @@ -0,0 +1,174 @@ +use async_compression::tokio::write::{ZstdDecoder, ZstdEncoder}; +use tokio::io::AsyncWriteExt; + +use zerocopy::{AsBytes as _, FromBytes as _}; + +use crate::tunnel::packet_def::{ + CompressorAlgo, CompressorTail, PacketType, ZCPacket, COMPRESSOR_TAIL_SIZE, +}; + +type Error = anyhow::Error; + +#[async_trait::async_trait] +pub trait Compressor { + async fn compress( + &self, + packet: &mut ZCPacket, + compress_algo: CompressorAlgo, + ) -> Result<(), Error>; + async fn decompress(&self, packet: &mut ZCPacket) -> Result<(), Error>; +} + +pub struct DefaultCompressor {} + +impl DefaultCompressor { + pub fn new() -> Self { + DefaultCompressor {} + } + + pub async fn compress_raw( + &self, + data: &[u8], + compress_algo: CompressorAlgo, + ) -> Result, Error> { + let buf = match compress_algo { + CompressorAlgo::ZstdDefault => { + let mut o = ZstdEncoder::new(Vec::new()); + o.write_all(data).await?; + o.shutdown().await?; + o.into_inner() + } + CompressorAlgo::None => data.to_vec(), + }; + + Ok(buf) + } + + pub async fn decompress_raw( + &self, + data: &[u8], + compress_algo: CompressorAlgo, + ) -> Result, Error> { + let buf = match compress_algo { + CompressorAlgo::ZstdDefault => { + let mut o = ZstdDecoder::new(Vec::new()); + o.write_all(data).await?; + o.shutdown().await?; + o.into_inner() + } + CompressorAlgo::None => data.to_vec(), + }; + + Ok(buf) + } +} + +#[async_trait::async_trait] +impl Compressor for DefaultCompressor { + async fn compress( + &self, + zc_packet: &mut ZCPacket, + compress_algo: CompressorAlgo, + ) -> Result<(), Error> { + if matches!(compress_algo, CompressorAlgo::None) { + return Ok(()); + } + + let pm_header = zc_packet.peer_manager_header().unwrap(); + if pm_header.is_compressed() || pm_header.packet_type != PacketType::Data as u8 { + // only compress data packets + return Ok(()); + } + + let tail = CompressorTail::new(compress_algo); + let buf = self + .compress_raw(zc_packet.payload(), compress_algo) + .await?; + + if buf.len() + COMPRESSOR_TAIL_SIZE > pm_header.len.get() as usize { + // Compressed data is larger than original data, don't compress + return Ok(()); + } + + zc_packet + .mut_peer_manager_header() + .unwrap() + .set_compressed(true); + + let payload_offset = zc_packet.payload_offset(); + zc_packet.mut_inner().truncate(payload_offset); + zc_packet.mut_inner().extend_from_slice(&buf); + zc_packet.mut_inner().extend_from_slice(tail.as_bytes()); + + Ok(()) + } + + async fn decompress(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> { + let pm_header = zc_packet.peer_manager_header().unwrap(); + if !pm_header.is_compressed() { + return Ok(()); + } + + let payload_len = zc_packet.payload().len(); + if payload_len < COMPRESSOR_TAIL_SIZE { + return Err(anyhow::anyhow!("Packet too short: {}", payload_len)); + } + + let text_len = payload_len - COMPRESSOR_TAIL_SIZE; + + let tail = CompressorTail::ref_from_suffix(zc_packet.payload()) + .unwrap() + .clone(); + + let algo = tail + .get_algo() + .ok_or(anyhow::anyhow!("Unknown algo: {:?}", tail))?; + + let buf = self + .decompress_raw(&zc_packet.payload()[..text_len], algo) + .await?; + + if buf.len() != pm_header.len.get() as usize { + anyhow::bail!( + "Decompressed length mismatch: decompressed len {} != pm header len {}", + buf.len(), + pm_header.len.get() + ); + } + + zc_packet + .mut_peer_manager_header() + .unwrap() + .set_compressed(false); + + let payload_offset = zc_packet.payload_offset(); + zc_packet.mut_inner().truncate(payload_offset); + zc_packet.mut_inner().extend_from_slice(&buf); + + Ok(()) + } +} + +#[cfg(test)] +pub mod tests { + use super::*; + + #[tokio::test] + async fn test_compress() { + let text = b"12345670000000000000000000"; + let mut packet = ZCPacket::new_with_payload(text); + packet.fill_peer_manager_hdr(0, 0, 0); + + let compressor = DefaultCompressor {}; + + compressor + .compress(&mut packet, CompressorAlgo::ZstdDefault) + .await + .unwrap(); + assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), true); + + compressor.decompress(&mut packet).await.unwrap(); + assert_eq!(packet.payload(), text); + assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false); + } +} diff --git a/easytier/src/common/mod.rs b/easytier/src/common/mod.rs index dd9a72e..7c09493 100644 --- a/easytier/src/common/mod.rs +++ b/easytier/src/common/mod.rs @@ -6,6 +6,7 @@ use std::{ use tokio::task::JoinSet; use tracing::Instrument; +pub mod compressor; pub mod config; pub mod constants; pub mod defer; diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index 7472e04..b9053a8 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -20,6 +20,7 @@ use tokio::{ use crate::{ common::{ + compressor::{Compressor as _, DefaultCompressor}, constants::EASYTIER_VERSION, error::Error, global_ctx::{ArcGlobalCtx, NetworkIdentity}, @@ -41,7 +42,7 @@ use crate::{ }, tunnel::{ self, - packet_def::{PacketType, ZCPacket}, + packet_def::{CompressorAlgo, PacketType, ZCPacket}, SinkItem, Tunnel, TunnelConnector, }, }; @@ -61,6 +62,7 @@ use super::{ struct RpcTransport { my_peer_id: PeerId, peers: Weak, + // TODO: this seems can be removed foreign_peers: Mutex>>, packet_recv: Mutex>, @@ -76,48 +78,14 @@ impl PeerRpcManagerTransport for RpcTransport { } async fn send(&self, mut msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { - let foreign_peers = self - .foreign_peers - .lock() - .await - .as_ref() - .ok_or(Error::Unknown)? - .upgrade() - .ok_or(Error::Unknown)?; let peers = self.peers.upgrade().ok_or(Error::Unknown)?; - - if foreign_peers.has_next_hop(dst_peer_id) { - // do not encrypt for data sending to public server - tracing::debug!( - ?dst_peer_id, - ?self.my_peer_id, - "failed to send msg to peer, try foreign network", - ); - foreign_peers.send_msg(msg, dst_peer_id).await - } else if let Some(gateway_id) = peers - .get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop) - .await - { - tracing::trace!( - ?dst_peer_id, - ?gateway_id, - ?self.my_peer_id, - "send msg to peer via gateway", - ); + if peers.need_relay_by_foreign_network(dst_peer_id).await? { self.encryptor .encrypt(&mut msg) .with_context(|| "encrypt failed")?; - if peers.has_peer(gateway_id) { - peers.send_msg_directly(msg, gateway_id).await - } else { - foreign_peers.send_msg(msg, gateway_id).await - } - } else { - Err(Error::RouteError(Some(format!( - "peermgr RpcTransport no route for dst_peer_id: {}", - dst_peer_id - )))) } + // send to self and this packet will be forwarded in peer_recv loop + peers.send_msg_directly(msg, self.my_peer_id).await } async fn recv(&self) -> Result { @@ -163,6 +131,7 @@ pub struct PeerManager { foreign_network_client: Arc, encryptor: Arc>, + data_compress_algo: CompressorAlgo, exit_nodes: Vec, } @@ -272,6 +241,8 @@ impl PeerManager { foreign_network_client, encryptor, + data_compress_algo: CompressorAlgo::None, + exit_nodes, } } @@ -465,6 +436,12 @@ impl PeerManager { continue; } + let compressor = DefaultCompressor {}; + if let Err(e) = compressor.decompress(&mut ret).await { + tracing::error!(?e, "decompress failed"); + continue; + } + let mut processed = false; let mut zc_packet = Some(ret); let mut idx = 0; @@ -768,6 +745,11 @@ impl PeerManager { tunnel::packet_def::PacketType::Data as u8, ); self.run_nic_packet_process_pipeline(&mut msg).await; + let compressor = DefaultCompressor {}; + compressor + .compress(&mut msg, self.data_compress_algo) + .await + .with_context(|| "compress failed")?; self.encryptor .encrypt(&mut msg) .with_context(|| "encrypt failed")?; diff --git a/easytier/src/peers/peer_map.rs b/easytier/src/peers/peer_map.rs index c104931..b5ec5d4 100644 --- a/easytier/src/peers/peer_map.rs +++ b/easytier/src/peers/peer_map.rs @@ -250,6 +250,19 @@ impl PeerMap { } route_map } + + pub async fn need_relay_by_foreign_network(&self, dst_peer_id: PeerId) -> Result { + // if gateway_peer_id is not connected to me, means need relay by foreign network + let gateway_id = self + .get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop) + .await + .ok_or(Error::RouteError(Some(format!( + "peer map need_relay_by_foreign_network no gateway for dst_peer_id: {}", + dst_peer_id + ))))?; + + Ok(!self.has_peer(gateway_id)) + } } impl Drop for PeerMap { diff --git a/easytier/src/proto/common.proto b/easytier/src/proto/common.proto index 4af791a..629fa36 100644 --- a/easytier/src/proto/common.proto +++ b/easytier/src/proto/common.proto @@ -32,7 +32,7 @@ message RpcDescriptor { } message RpcRequest { - RpcDescriptor descriptor = 1; + RpcDescriptor descriptor = 1 [ deprecated = true ]; bytes request = 2; int32 timeout_ms = 3; @@ -45,6 +45,21 @@ message RpcResponse { uint64 runtime_us = 3; } +enum CompressionAlgoPb { + Invalid = 0; + None = 1; + Zstd = 2; +} + +message RpcCompressionInfo { + // use this to compress the content + CompressionAlgoPb algo = 1; + + // tell the peer which compression algo is used to compress the next + // response/request + CompressionAlgoPb accepted_algo = 2; +} + message RpcPacket { uint32 from_peer = 1; uint32 to_peer = 2; @@ -58,6 +73,8 @@ message RpcPacket { uint32 piece_idx = 8; int32 trace_id = 9; + + RpcCompressionInfo compression_info = 10; } message Void {} diff --git a/easytier/src/proto/common.rs b/easytier/src/proto/common.rs index 482a5d9..d1b9d38 100644 --- a/easytier/src/proto/common.rs +++ b/easytier/src/proto/common.rs @@ -2,6 +2,8 @@ use std::{fmt::Display, str::FromStr}; use anyhow::Context; +use crate::tunnel::packet_def::CompressorAlgo; + include!(concat!(env!("OUT_DIR"), "/common.rs")); impl From for Uuid { @@ -180,3 +182,26 @@ impl From for std::net::SocketAddr { } } } + +impl TryFrom for CompressorAlgo { + type Error = anyhow::Error; + + fn try_from(value: CompressionAlgoPb) -> Result { + match value { + CompressionAlgoPb::Zstd => Ok(CompressorAlgo::ZstdDefault), + CompressionAlgoPb::None => Ok(CompressorAlgo::None), + _ => Err(anyhow::anyhow!("Invalid CompressionAlgoPb")), + } + } +} + +impl TryFrom for CompressionAlgoPb { + type Error = anyhow::Error; + + fn try_from(value: CompressorAlgo) -> Result { + match value { + CompressorAlgo::ZstdDefault => Ok(CompressionAlgoPb::Zstd), + CompressorAlgo::None => Ok(CompressionAlgoPb::None), + } + } +} diff --git a/easytier/src/proto/rpc_impl/client.rs b/easytier/src/proto/rpc_impl/client.rs index e8e161c..632e77c 100644 --- a/easytier/src/proto/rpc_impl/client.rs +++ b/easytier/src/proto/rpc_impl/client.rs @@ -12,8 +12,10 @@ use tokio_stream::StreamExt; use crate::common::PeerId; use crate::defer; -use crate::proto::common::{RpcDescriptor, RpcPacket, RpcRequest, RpcResponse}; -use crate::proto::rpc_impl::packet::build_rpc_packet; +use crate::proto::common::{ + CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse, +}; +use crate::proto::rpc_impl::packet::{build_rpc_packet, compress_packet, decompress_packet}; use crate::proto::rpc_types::controller::Controller; use crate::proto::rpc_types::descriptor::MethodDescriptor; use crate::proto::rpc_types::{ @@ -48,12 +50,21 @@ struct InflightRequest { start_time: std::time::Instant, } +#[derive(Debug, Clone, Default)] +struct PeerInfo { + peer_id: PeerId, + compression_info: RpcCompressionInfo, + last_active: Option, +} + type InflightRequestTable = Arc>; +type PeerInfoTable = Arc>; pub struct Client { mpsc: Mutex>>, transport: Mutex, inflight_requests: InflightRequestTable, + peer_info: PeerInfoTable, tasks: Arc>>, } @@ -64,6 +75,7 @@ impl Client { mpsc: Mutex::new(MpscTunnel::new(ring_a, None)), transport: Mutex::new(MpscTunnel::new(ring_b, None)), inflight_requests: Arc::new(DashMap::new()), + peer_info: Arc::new(DashMap::new()), tasks: Arc::new(Mutex::new(JoinSet::new())), } } @@ -79,6 +91,21 @@ impl Client { pub fn run(&self) { let mut tasks = self.tasks.lock().unwrap(); + let peer_infos = self.peer_info.clone(); + tasks.spawn(async move { + loop { + tokio::time::sleep(std::time::Duration::from_secs(30)).await; + let now = std::time::Instant::now(); + peer_infos.retain(|_, v| { + if let Some(last_active) = v.last_active { + return now.duration_since(last_active) + < std::time::Duration::from_secs(120); + } + true + }); + } + }); + let mut rx = self.mpsc.lock().unwrap().get_stream(); let inflight_requests = self.inflight_requests.clone(); tasks.spawn(async move { @@ -111,6 +138,8 @@ impl Client { continue; }; + tracing::trace!(?packet, "Received response packet"); + let ret = inflight_request.merger.feed(packet); match ret { Ok(Some(rpc_packet)) => { @@ -138,6 +167,7 @@ impl Client { to_peer_id: PeerId, zc_packet_sender: MpscTunnelSender, inflight_requests: InflightRequestTable, + peer_info: PeerInfoTable, _phan: PhantomData, } @@ -194,23 +224,53 @@ impl Client { }; let rpc_req = RpcRequest { - descriptor: Some(rpc_desc.clone()), request: input.into(), timeout_ms: ctrl.timeout_ms(), + ..Default::default() }; + let peer_info = self + .peer_info + .get(&self.to_peer_id) + .map(|v| v.clone()) + .unwrap_or_default(); + let (buf, c_algo) = compress_packet( + peer_info.compression_info.accepted_algo(), + &rpc_req.encode_to_vec(), + ) + .await + .unwrap(); + let packets = build_rpc_packet( self.from_peer_id, self.to_peer_id, rpc_desc, transaction_id, true, - &rpc_req.encode_to_vec(), + &buf, ctrl.trace_id(), + RpcCompressionInfo { + algo: c_algo.into(), + accepted_algo: CompressionAlgoPb::Zstd.into(), + }, ); let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64); - let rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??; + let mut rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??; + + if let Some(compression_info) = rpc_packet.compression_info { + self.peer_info.insert( + self.to_peer_id, + PeerInfo { + peer_id: self.to_peer_id, + compression_info: compression_info.clone(), + last_active: Some(std::time::Instant::now()), + }, + ); + + rpc_packet.body = + decompress_packet(compression_info.algo(), &rpc_packet.body).await?; + } assert_eq!(rpc_packet.transaction_id, transaction_id); @@ -230,6 +290,7 @@ impl Client { to_peer_id, zc_packet_sender: self.mpsc.lock().unwrap().get_sink(), inflight_requests: self.inflight_requests.clone(), + peer_info: self.peer_info.clone(), _phan: PhantomData, }) } diff --git a/easytier/src/proto/rpc_impl/packet.rs b/easytier/src/proto/rpc_impl/packet.rs index 527627e..fa97535 100644 --- a/easytier/src/proto/rpc_impl/packet.rs +++ b/easytier/src/proto/rpc_impl/packet.rs @@ -1,18 +1,44 @@ use prost::Message as _; use crate::{ - common::PeerId, + common::{compressor::DefaultCompressor, PeerId}, proto::{ - common::{RpcDescriptor, RpcPacket}, + common::{CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket}, rpc_types::error::Error, }, - tunnel::packet_def::{PacketType, ZCPacket}, + tunnel::packet_def::{CompressorAlgo, PacketType, ZCPacket}, }; use super::RpcTransactId; const RPC_PACKET_CONTENT_MTU: usize = 1300; +pub async fn compress_packet( + accepted_compression_algo: CompressionAlgoPb, + content: &[u8], +) -> Result<(Vec, CompressionAlgoPb), Error> { + let compressor = DefaultCompressor::new(); + let algo = accepted_compression_algo + .try_into() + .unwrap_or(CompressorAlgo::None); + let compressed = compressor.compress_raw(&content, algo).await?; + if compressed.len() >= content.len() { + Ok((content.to_vec(), CompressionAlgoPb::None)) + } else { + Ok((compressed, algo.try_into().unwrap())) + } +} + +pub async fn decompress_packet( + compression_algo: CompressionAlgoPb, + content: &[u8], +) -> Result, Error> { + let compressor = DefaultCompressor::new(); + let algo = compression_algo.try_into()?; + let decompressed = compressor.decompress_raw(&content, algo).await?; + Ok(decompressed) +} + pub struct PacketMerger { first_piece: Option, pieces: Vec, @@ -46,7 +72,8 @@ impl PacketMerger { body.extend_from_slice(&p.body); } - let mut tmpl_packet = self.first_piece.as_ref().unwrap().clone(); + // only the first packet contains the complete info + let mut tmpl_packet = self.pieces[0].clone(); tmpl_packet.total_pieces = 1; tmpl_packet.piece_idx = 0; tmpl_packet.body = body; @@ -58,17 +85,17 @@ impl PacketMerger { let total_pieces = rpc_packet.total_pieces; let piece_idx = rpc_packet.piece_idx; - if rpc_packet.descriptor.is_none() { - return Err(Error::MalformatRpcPacket( - "descriptor is missing".to_owned(), - )); - } - // for compatibility with old version if total_pieces == 0 && piece_idx == 0 { return Ok(Some(rpc_packet)); } + if rpc_packet.piece_idx == 0 && rpc_packet.descriptor.is_none() { + return Err(Error::MalformatRpcPacket( + "descriptor is missing".to_owned(), + )); + } + // about 32MB max size if total_pieces > 32 * 1024 || total_pieces == 0 { return Err(Error::MalformatRpcPacket(format!( @@ -89,6 +116,7 @@ impl PacketMerger { { self.first_piece = Some(rpc_packet.clone()); self.pieces.clear(); + tracing::trace!(?rpc_packet, "got first piece"); } self.pieces @@ -113,6 +141,7 @@ pub fn build_rpc_packet( is_req: bool, content: &Vec, trace_id: i32, + compression_info: RpcCompressionInfo, ) -> Vec { let mut ret = Vec::new(); let content_mtu = RPC_PACKET_CONTENT_MTU; @@ -130,13 +159,22 @@ pub fn build_rpc_packet( let cur_packet = RpcPacket { from_peer, to_peer, - descriptor: Some(rpc_desc.clone()), + descriptor: if cur_offset == 0 { + Some(rpc_desc.clone()) + } else { + None + }, is_request: is_req, total_pieces: total_pieces as u32, piece_idx: (cur_offset / content_mtu) as u32, transaction_id, body: cur_content, trace_id, + compression_info: if cur_offset == 0 { + Some(compression_info.clone()) + } else { + None + }, }; cur_offset += cur_len; diff --git a/easytier/src/proto/rpc_impl/server.rs b/easytier/src/proto/rpc_impl/server.rs index bf14a7b..384954a 100644 --- a/easytier/src/proto/rpc_impl/server.rs +++ b/easytier/src/proto/rpc_impl/server.rs @@ -12,7 +12,10 @@ use tokio_stream::StreamExt; use crate::{ common::{join_joinset_background, PeerId}, proto::{ - common::{self, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse}, + common::{ + self, CompressionAlgoPb, RpcCompressionInfo, RpcPacket, RpcRequest, + RpcResponse, + }, rpc_types::error::Result, }, tunnel::{ @@ -23,7 +26,7 @@ use crate::{ }; use super::{ - packet::{build_rpc_packet, PacketMerger}, + packet::{build_rpc_packet, compress_packet, decompress_packet, PacketMerger}, service_registry::ServiceRegistry, RpcController, Transport, }; @@ -31,7 +34,6 @@ use super::{ #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct PacketMergerKey { from_peer_id: PeerId, - rpc_desc: RpcDescriptor, transaction_id: i64, } @@ -108,10 +110,11 @@ impl Server { let key = PacketMergerKey { from_peer_id: packet.from_peer, - rpc_desc: packet.descriptor.clone().unwrap_or_default(), transaction_id: packet.transaction_id, }; + tracing::trace!(?key, ?packet, "Received request packet"); + let ret = packet_merges .entry(key.clone()) .or_insert_with(PacketMerger::new) @@ -144,7 +147,16 @@ impl Server { } async fn handle_rpc_request(packet: RpcPacket, reg: Arc) -> Result { - let rpc_request = RpcRequest::decode(Bytes::from(packet.body))?; + let body = if let Some(compression_info) = packet.compression_info { + decompress_packet( + compression_info.algo.try_into().unwrap_or_default(), + &packet.body, + ) + .await? + } else { + packet.body + }; + let rpc_request = RpcRequest::decode(Bytes::from(body))?; let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64); let ctrl = RpcController::default(); Ok(timeout( @@ -168,6 +180,7 @@ impl Server { let mut resp_msg = RpcResponse::default(); let now = std::time::Instant::now(); + let compression_info = packet.compression_info.clone(); let resp_bytes = Self::handle_rpc_request(packet, reg).await; match &resp_bytes { @@ -180,14 +193,25 @@ impl Server { }; resp_msg.runtime_us = now.elapsed().as_micros() as u64; + let (compressed_resp, algo) = compress_packet( + compression_info.unwrap_or_default().accepted_algo(), + &resp_msg.encode_to_vec(), + ) + .await + .unwrap(); + let packets = build_rpc_packet( to_peer, from_peer, desc, transaction_id, false, - &resp_msg.encode_to_vec(), + &compressed_resp, trace_id, + RpcCompressionInfo { + algo: algo.into(), + accepted_algo: CompressionAlgoPb::Zstd.into(), + }, ); for packet in packets { diff --git a/easytier/src/proto/tests.rs b/easytier/src/proto/tests.rs index 93c872f..1a2d45f 100644 --- a/easytier/src/proto/tests.rs +++ b/easytier/src/proto/tests.rs @@ -107,6 +107,7 @@ fn random_string(len: usize) -> String { #[tokio::test] async fn rpc_basic_test() { + // enable_log(); let ctx = TestContext::new(); let server = GreetingServer::new(GreetingService { @@ -119,7 +120,7 @@ async fn rpc_basic_test() { .client .scoped_client::>(1, 1, "".to_string()); - // small size req and resp + // // small size req and resp let ctrl = RpcController::default(); let input = SayHelloRequest { diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index da0bcbc..7d66b2e 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -603,7 +603,7 @@ pub mod tests { pub fn enable_log() { let filter = tracing_subscriber::EnvFilter::builder() - .with_default_directive(tracing::level_filters::LevelFilter::DEBUG.into()) + .with_default_directive(tracing::level_filters::LevelFilter::TRACE.into()) .from_env() .unwrap() .add_directive("tarpc=error".parse().unwrap()); diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs index 7e93c03..7f6c6a9 100644 --- a/easytier/src/tunnel/packet_def.rs +++ b/easytier/src/tunnel/packet_def.rs @@ -7,6 +7,10 @@ use zerocopy::FromZeroes; type DefaultEndian = LittleEndian; +const fn max(a: usize, b: usize) -> usize { + [a, b][(a < b) as usize] +} + // TCP TunnelHeader #[repr(C, packed)] #[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)] @@ -49,11 +53,11 @@ pub enum PacketType { Invalid = 0, Data = 1, HandShake = 2, - RoutePacket = 3, + RoutePacket = 3, // deprecated Ping = 4, Pong = 5, - TaRpc = 6, - Route = 7, + TaRpc = 6, // deprecated + Route = 7, // deprecated RpcReq = 8, RpcResp = 9, ForeignNetworkPacket = 10, @@ -65,6 +69,7 @@ bitflags::bitflags! { const LATENCY_FIRST = 0b0000_0010; const EXIT_NODE = 0b0000_0100; const NO_PROXY = 0b0000_1000; + const COMPRESSED = 0b0001_0000; const _ = !0; } @@ -118,6 +123,12 @@ impl PeerManagerHeader { .contains(PeerManagerHeaderFlags::NO_PROXY) } + pub fn is_compressed(&self) -> bool { + PeerManagerHeaderFlags::from_bits(self.flags) + .unwrap() + .contains(PeerManagerHeaderFlags::COMPRESSED) + } + pub fn set_latency_first(&mut self, latency_first: bool) -> &mut Self { let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap(); if latency_first { @@ -150,6 +161,17 @@ impl PeerManagerHeader { self.flags = flags.bits(); self } + + pub fn set_compressed(&mut self, compressed: bool) -> &mut Self { + let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap(); + if compressed { + flags.insert(PeerManagerHeaderFlags::COMPRESSED); + } else { + flags.remove(PeerManagerHeaderFlags::COMPRESSED); + } + self.flags = flags.bits(); + self + } } #[repr(C, packed)] @@ -201,12 +223,35 @@ pub struct AesGcmTail { } pub const AES_GCM_ENCRYPTION_RESERVED: usize = std::mem::size_of::(); -pub const TAIL_RESERVED_SIZE: usize = AES_GCM_ENCRYPTION_RESERVED; - -const fn max(a: usize, b: usize) -> usize { - [a, b][(a < b) as usize] +#[derive(AsBytes, FromZeroes, Clone, Debug, Copy)] +#[repr(u8)] +pub enum CompressorAlgo { + None = 0, + ZstdDefault = 1, } +#[repr(C, packed)] +#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)] +pub struct CompressorTail { + pub algo: u8, +} +pub const COMPRESSOR_TAIL_SIZE: usize = std::mem::size_of::(); + +impl CompressorTail { + pub fn get_algo(&self) -> Option { + match self.algo { + 1 => Some(CompressorAlgo::ZstdDefault), + _ => None, + } + } + + pub fn new(algo: CompressorAlgo) -> Self { + Self { algo: algo as u8 } + } +} + +pub const TAIL_RESERVED_SIZE: usize = max(AES_GCM_ENCRYPTION_RESERVED, COMPRESSOR_TAIL_SIZE); + #[derive(Default, Debug)] pub struct ZCPacketOffsets { pub payload_offset: usize,