mirror of
https://github.com/EasyTier/EasyTier.git
synced 2024-11-16 03:32:43 +08:00
support compress for rpc and tun data
This commit is contained in:
parent
4fc3ff8ce8
commit
6b0976381c
37
Cargo.lock
generated
37
Cargo.lock
generated
|
@ -262,6 +262,20 @@ dependencies = [
|
|||
"pin-project-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-compression"
|
||||
version = "0.4.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"tokio",
|
||||
"zstd 0.13.2",
|
||||
"zstd-safe 7.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "async-event"
|
||||
version = "0.2.1"
|
||||
|
@ -1809,6 +1823,7 @@ version = "2.0.3"
|
|||
dependencies = [
|
||||
"aes-gcm",
|
||||
"anyhow",
|
||||
"async-compression",
|
||||
"async-recursion",
|
||||
"async-ringbuf",
|
||||
"async-stream",
|
||||
|
@ -9474,7 +9489,7 @@ dependencies = [
|
|||
"pbkdf2",
|
||||
"sha1",
|
||||
"time",
|
||||
"zstd",
|
||||
"zstd 0.11.2+zstd.1.5.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -9483,7 +9498,16 @@ version = "0.11.2+zstd.1.5.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
|
||||
dependencies = [
|
||||
"zstd-safe",
|
||||
"zstd-safe 5.0.2+zstd.1.5.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
|
||||
dependencies = [
|
||||
"zstd-safe 7.2.1",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -9496,6 +9520,15 @@ dependencies = [
|
|||
"zstd-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd-safe"
|
||||
version = "7.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
|
||||
dependencies = [
|
||||
"zstd-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zstd-sys"
|
||||
version = "2.0.13+zstd.1.5.6"
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import type { NetworkTypes } from 'easytier-frontend-lib'
|
||||
import { addPluginListener } from '@tauri-apps/api/core'
|
||||
import { Utils } from 'easytier-frontend-lib'
|
||||
import { prepare_vpn, start_vpn, stop_vpn } from 'tauri-plugin-vpnservice-api'
|
||||
import { NetworkTypes, Utils } from 'easytier-frontend-lib'
|
||||
|
||||
type Route = NetworkTypes.Route
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import type { NetworkTypes } from 'easytier-frontend-lib'
|
||||
import { invoke } from '@tauri-apps/api/core'
|
||||
import { NetworkTypes } from 'easytier-frontend-lib'
|
||||
|
||||
type NetworkConfig = NetworkTypes.NetworkConfig
|
||||
type NetworkInstanceRunningInfo = NetworkTypes.NetworkInstanceRunningInfo
|
||||
|
|
|
@ -181,6 +181,8 @@ sys-locale = "0.3"
|
|||
ringbuf = "0.4.5"
|
||||
async-ringbuf = "0.3.1"
|
||||
|
||||
async-compression = { version = "0.4.17", default-features = false, features = ["zstd", "tokio"] }
|
||||
|
||||
[target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies]
|
||||
machine-uid = "0.5.3"
|
||||
|
||||
|
|
174
easytier/src/common/compressor.rs
Normal file
174
easytier/src/common/compressor.rs
Normal file
|
@ -0,0 +1,174 @@
|
|||
use async_compression::tokio::write::{ZstdDecoder, ZstdEncoder};
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use zerocopy::{AsBytes as _, FromBytes as _};
|
||||
|
||||
use crate::tunnel::packet_def::{
|
||||
CompressorAlgo, CompressorTail, PacketType, ZCPacket, COMPRESSOR_TAIL_SIZE,
|
||||
};
|
||||
|
||||
type Error = anyhow::Error;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait Compressor {
|
||||
async fn compress(
|
||||
&self,
|
||||
packet: &mut ZCPacket,
|
||||
compress_algo: CompressorAlgo,
|
||||
) -> Result<(), Error>;
|
||||
async fn decompress(&self, packet: &mut ZCPacket) -> Result<(), Error>;
|
||||
}
|
||||
|
||||
pub struct DefaultCompressor {}
|
||||
|
||||
impl DefaultCompressor {
|
||||
pub fn new() -> Self {
|
||||
DefaultCompressor {}
|
||||
}
|
||||
|
||||
pub async fn compress_raw(
|
||||
&self,
|
||||
data: &[u8],
|
||||
compress_algo: CompressorAlgo,
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
let buf = match compress_algo {
|
||||
CompressorAlgo::ZstdDefault => {
|
||||
let mut o = ZstdEncoder::new(Vec::new());
|
||||
o.write_all(data).await?;
|
||||
o.shutdown().await?;
|
||||
o.into_inner()
|
||||
}
|
||||
CompressorAlgo::None => data.to_vec(),
|
||||
};
|
||||
|
||||
Ok(buf)
|
||||
}
|
||||
|
||||
pub async fn decompress_raw(
|
||||
&self,
|
||||
data: &[u8],
|
||||
compress_algo: CompressorAlgo,
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
let buf = match compress_algo {
|
||||
CompressorAlgo::ZstdDefault => {
|
||||
let mut o = ZstdDecoder::new(Vec::new());
|
||||
o.write_all(data).await?;
|
||||
o.shutdown().await?;
|
||||
o.into_inner()
|
||||
}
|
||||
CompressorAlgo::None => data.to_vec(),
|
||||
};
|
||||
|
||||
Ok(buf)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Compressor for DefaultCompressor {
|
||||
async fn compress(
|
||||
&self,
|
||||
zc_packet: &mut ZCPacket,
|
||||
compress_algo: CompressorAlgo,
|
||||
) -> Result<(), Error> {
|
||||
if matches!(compress_algo, CompressorAlgo::None) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let pm_header = zc_packet.peer_manager_header().unwrap();
|
||||
if pm_header.is_compressed() || pm_header.packet_type != PacketType::Data as u8 {
|
||||
// only compress data packets
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let tail = CompressorTail::new(compress_algo);
|
||||
let buf = self
|
||||
.compress_raw(zc_packet.payload(), compress_algo)
|
||||
.await?;
|
||||
|
||||
if buf.len() + COMPRESSOR_TAIL_SIZE > pm_header.len.get() as usize {
|
||||
// Compressed data is larger than original data, don't compress
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
zc_packet
|
||||
.mut_peer_manager_header()
|
||||
.unwrap()
|
||||
.set_compressed(true);
|
||||
|
||||
let payload_offset = zc_packet.payload_offset();
|
||||
zc_packet.mut_inner().truncate(payload_offset);
|
||||
zc_packet.mut_inner().extend_from_slice(&buf);
|
||||
zc_packet.mut_inner().extend_from_slice(tail.as_bytes());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn decompress(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
|
||||
let pm_header = zc_packet.peer_manager_header().unwrap();
|
||||
if !pm_header.is_compressed() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let payload_len = zc_packet.payload().len();
|
||||
if payload_len < COMPRESSOR_TAIL_SIZE {
|
||||
return Err(anyhow::anyhow!("Packet too short: {}", payload_len));
|
||||
}
|
||||
|
||||
let text_len = payload_len - COMPRESSOR_TAIL_SIZE;
|
||||
|
||||
let tail = CompressorTail::ref_from_suffix(zc_packet.payload())
|
||||
.unwrap()
|
||||
.clone();
|
||||
|
||||
let algo = tail
|
||||
.get_algo()
|
||||
.ok_or(anyhow::anyhow!("Unknown algo: {:?}", tail))?;
|
||||
|
||||
let buf = self
|
||||
.decompress_raw(&zc_packet.payload()[..text_len], algo)
|
||||
.await?;
|
||||
|
||||
if buf.len() != pm_header.len.get() as usize {
|
||||
anyhow::bail!(
|
||||
"Decompressed length mismatch: decompressed len {} != pm header len {}",
|
||||
buf.len(),
|
||||
pm_header.len.get()
|
||||
);
|
||||
}
|
||||
|
||||
zc_packet
|
||||
.mut_peer_manager_header()
|
||||
.unwrap()
|
||||
.set_compressed(false);
|
||||
|
||||
let payload_offset = zc_packet.payload_offset();
|
||||
zc_packet.mut_inner().truncate(payload_offset);
|
||||
zc_packet.mut_inner().extend_from_slice(&buf);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_compress() {
|
||||
let text = b"12345670000000000000000000";
|
||||
let mut packet = ZCPacket::new_with_payload(text);
|
||||
packet.fill_peer_manager_hdr(0, 0, 0);
|
||||
|
||||
let compressor = DefaultCompressor {};
|
||||
|
||||
compressor
|
||||
.compress(&mut packet, CompressorAlgo::ZstdDefault)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), true);
|
||||
|
||||
compressor.decompress(&mut packet).await.unwrap();
|
||||
assert_eq!(packet.payload(), text);
|
||||
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false);
|
||||
}
|
||||
}
|
|
@ -6,6 +6,7 @@ use std::{
|
|||
use tokio::task::JoinSet;
|
||||
use tracing::Instrument;
|
||||
|
||||
pub mod compressor;
|
||||
pub mod config;
|
||||
pub mod constants;
|
||||
pub mod defer;
|
||||
|
|
|
@ -20,6 +20,7 @@ use tokio::{
|
|||
|
||||
use crate::{
|
||||
common::{
|
||||
compressor::{Compressor as _, DefaultCompressor},
|
||||
constants::EASYTIER_VERSION,
|
||||
error::Error,
|
||||
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||
|
@ -41,7 +42,7 @@ use crate::{
|
|||
},
|
||||
tunnel::{
|
||||
self,
|
||||
packet_def::{PacketType, ZCPacket},
|
||||
packet_def::{CompressorAlgo, PacketType, ZCPacket},
|
||||
SinkItem, Tunnel, TunnelConnector,
|
||||
},
|
||||
};
|
||||
|
@ -61,6 +62,7 @@ use super::{
|
|||
struct RpcTransport {
|
||||
my_peer_id: PeerId,
|
||||
peers: Weak<PeerMap>,
|
||||
// TODO: this seems can be removed
|
||||
foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>,
|
||||
|
||||
packet_recv: Mutex<UnboundedReceiver<ZCPacket>>,
|
||||
|
@ -76,48 +78,14 @@ impl PeerRpcManagerTransport for RpcTransport {
|
|||
}
|
||||
|
||||
async fn send(&self, mut msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||
let foreign_peers = self
|
||||
.foreign_peers
|
||||
.lock()
|
||||
.await
|
||||
.as_ref()
|
||||
.ok_or(Error::Unknown)?
|
||||
.upgrade()
|
||||
.ok_or(Error::Unknown)?;
|
||||
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
|
||||
|
||||
if foreign_peers.has_next_hop(dst_peer_id) {
|
||||
// do not encrypt for data sending to public server
|
||||
tracing::debug!(
|
||||
?dst_peer_id,
|
||||
?self.my_peer_id,
|
||||
"failed to send msg to peer, try foreign network",
|
||||
);
|
||||
foreign_peers.send_msg(msg, dst_peer_id).await
|
||||
} else if let Some(gateway_id) = peers
|
||||
.get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop)
|
||||
.await
|
||||
{
|
||||
tracing::trace!(
|
||||
?dst_peer_id,
|
||||
?gateway_id,
|
||||
?self.my_peer_id,
|
||||
"send msg to peer via gateway",
|
||||
);
|
||||
if peers.need_relay_by_foreign_network(dst_peer_id).await? {
|
||||
self.encryptor
|
||||
.encrypt(&mut msg)
|
||||
.with_context(|| "encrypt failed")?;
|
||||
if peers.has_peer(gateway_id) {
|
||||
peers.send_msg_directly(msg, gateway_id).await
|
||||
} else {
|
||||
foreign_peers.send_msg(msg, gateway_id).await
|
||||
}
|
||||
} else {
|
||||
Err(Error::RouteError(Some(format!(
|
||||
"peermgr RpcTransport no route for dst_peer_id: {}",
|
||||
dst_peer_id
|
||||
))))
|
||||
}
|
||||
// send to self and this packet will be forwarded in peer_recv loop
|
||||
peers.send_msg_directly(msg, self.my_peer_id).await
|
||||
}
|
||||
|
||||
async fn recv(&self) -> Result<ZCPacket, Error> {
|
||||
|
@ -163,6 +131,7 @@ pub struct PeerManager {
|
|||
foreign_network_client: Arc<ForeignNetworkClient>,
|
||||
|
||||
encryptor: Arc<Box<dyn Encryptor>>,
|
||||
data_compress_algo: CompressorAlgo,
|
||||
|
||||
exit_nodes: Vec<Ipv4Addr>,
|
||||
}
|
||||
|
@ -272,6 +241,8 @@ impl PeerManager {
|
|||
foreign_network_client,
|
||||
|
||||
encryptor,
|
||||
data_compress_algo: CompressorAlgo::None,
|
||||
|
||||
exit_nodes,
|
||||
}
|
||||
}
|
||||
|
@ -465,6 +436,12 @@ impl PeerManager {
|
|||
continue;
|
||||
}
|
||||
|
||||
let compressor = DefaultCompressor {};
|
||||
if let Err(e) = compressor.decompress(&mut ret).await {
|
||||
tracing::error!(?e, "decompress failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut processed = false;
|
||||
let mut zc_packet = Some(ret);
|
||||
let mut idx = 0;
|
||||
|
@ -768,6 +745,11 @@ impl PeerManager {
|
|||
tunnel::packet_def::PacketType::Data as u8,
|
||||
);
|
||||
self.run_nic_packet_process_pipeline(&mut msg).await;
|
||||
let compressor = DefaultCompressor {};
|
||||
compressor
|
||||
.compress(&mut msg, self.data_compress_algo)
|
||||
.await
|
||||
.with_context(|| "compress failed")?;
|
||||
self.encryptor
|
||||
.encrypt(&mut msg)
|
||||
.with_context(|| "encrypt failed")?;
|
||||
|
|
|
@ -250,6 +250,19 @@ impl PeerMap {
|
|||
}
|
||||
route_map
|
||||
}
|
||||
|
||||
pub async fn need_relay_by_foreign_network(&self, dst_peer_id: PeerId) -> Result<bool, Error> {
|
||||
// if gateway_peer_id is not connected to me, means need relay by foreign network
|
||||
let gateway_id = self
|
||||
.get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop)
|
||||
.await
|
||||
.ok_or(Error::RouteError(Some(format!(
|
||||
"peer map need_relay_by_foreign_network no gateway for dst_peer_id: {}",
|
||||
dst_peer_id
|
||||
))))?;
|
||||
|
||||
Ok(!self.has_peer(gateway_id))
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PeerMap {
|
||||
|
|
|
@ -32,7 +32,7 @@ message RpcDescriptor {
|
|||
}
|
||||
|
||||
message RpcRequest {
|
||||
RpcDescriptor descriptor = 1;
|
||||
RpcDescriptor descriptor = 1 [ deprecated = true ];
|
||||
|
||||
bytes request = 2;
|
||||
int32 timeout_ms = 3;
|
||||
|
@ -45,6 +45,21 @@ message RpcResponse {
|
|||
uint64 runtime_us = 3;
|
||||
}
|
||||
|
||||
enum CompressionAlgoPb {
|
||||
Invalid = 0;
|
||||
None = 1;
|
||||
Zstd = 2;
|
||||
}
|
||||
|
||||
message RpcCompressionInfo {
|
||||
// use this to compress the content
|
||||
CompressionAlgoPb algo = 1;
|
||||
|
||||
// tell the peer which compression algo is used to compress the next
|
||||
// response/request
|
||||
CompressionAlgoPb accepted_algo = 2;
|
||||
}
|
||||
|
||||
message RpcPacket {
|
||||
uint32 from_peer = 1;
|
||||
uint32 to_peer = 2;
|
||||
|
@ -58,6 +73,8 @@ message RpcPacket {
|
|||
uint32 piece_idx = 8;
|
||||
|
||||
int32 trace_id = 9;
|
||||
|
||||
RpcCompressionInfo compression_info = 10;
|
||||
}
|
||||
|
||||
message Void {}
|
||||
|
|
|
@ -2,6 +2,8 @@ use std::{fmt::Display, str::FromStr};
|
|||
|
||||
use anyhow::Context;
|
||||
|
||||
use crate::tunnel::packet_def::CompressorAlgo;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/common.rs"));
|
||||
|
||||
impl From<uuid::Uuid> for Uuid {
|
||||
|
@ -180,3 +182,26 @@ impl From<SocketAddr> for std::net::SocketAddr {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<CompressionAlgoPb> for CompressorAlgo {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: CompressionAlgoPb) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
CompressionAlgoPb::Zstd => Ok(CompressorAlgo::ZstdDefault),
|
||||
CompressionAlgoPb::None => Ok(CompressorAlgo::None),
|
||||
_ => Err(anyhow::anyhow!("Invalid CompressionAlgoPb")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<CompressorAlgo> for CompressionAlgoPb {
|
||||
type Error = anyhow::Error;
|
||||
|
||||
fn try_from(value: CompressorAlgo) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
CompressorAlgo::ZstdDefault => Ok(CompressionAlgoPb::Zstd),
|
||||
CompressorAlgo::None => Ok(CompressionAlgoPb::None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -12,8 +12,10 @@ use tokio_stream::StreamExt;
|
|||
|
||||
use crate::common::PeerId;
|
||||
use crate::defer;
|
||||
use crate::proto::common::{RpcDescriptor, RpcPacket, RpcRequest, RpcResponse};
|
||||
use crate::proto::rpc_impl::packet::build_rpc_packet;
|
||||
use crate::proto::common::{
|
||||
CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse,
|
||||
};
|
||||
use crate::proto::rpc_impl::packet::{build_rpc_packet, compress_packet, decompress_packet};
|
||||
use crate::proto::rpc_types::controller::Controller;
|
||||
use crate::proto::rpc_types::descriptor::MethodDescriptor;
|
||||
use crate::proto::rpc_types::{
|
||||
|
@ -48,12 +50,21 @@ struct InflightRequest {
|
|||
start_time: std::time::Instant,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct PeerInfo {
|
||||
peer_id: PeerId,
|
||||
compression_info: RpcCompressionInfo,
|
||||
last_active: Option<std::time::Instant>,
|
||||
}
|
||||
|
||||
type InflightRequestTable = Arc<DashMap<InflightRequestKey, InflightRequest>>;
|
||||
type PeerInfoTable = Arc<DashMap<PeerId, PeerInfo>>;
|
||||
|
||||
pub struct Client {
|
||||
mpsc: Mutex<MpscTunnel<Box<dyn Tunnel>>>,
|
||||
transport: Mutex<Transport>,
|
||||
inflight_requests: InflightRequestTable,
|
||||
peer_info: PeerInfoTable,
|
||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
||||
}
|
||||
|
||||
|
@ -64,6 +75,7 @@ impl Client {
|
|||
mpsc: Mutex::new(MpscTunnel::new(ring_a, None)),
|
||||
transport: Mutex::new(MpscTunnel::new(ring_b, None)),
|
||||
inflight_requests: Arc::new(DashMap::new()),
|
||||
peer_info: Arc::new(DashMap::new()),
|
||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||
}
|
||||
}
|
||||
|
@ -79,6 +91,21 @@ impl Client {
|
|||
pub fn run(&self) {
|
||||
let mut tasks = self.tasks.lock().unwrap();
|
||||
|
||||
let peer_infos = self.peer_info.clone();
|
||||
tasks.spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
||||
let now = std::time::Instant::now();
|
||||
peer_infos.retain(|_, v| {
|
||||
if let Some(last_active) = v.last_active {
|
||||
return now.duration_since(last_active)
|
||||
< std::time::Duration::from_secs(120);
|
||||
}
|
||||
true
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
let mut rx = self.mpsc.lock().unwrap().get_stream();
|
||||
let inflight_requests = self.inflight_requests.clone();
|
||||
tasks.spawn(async move {
|
||||
|
@ -111,6 +138,8 @@ impl Client {
|
|||
continue;
|
||||
};
|
||||
|
||||
tracing::trace!(?packet, "Received response packet");
|
||||
|
||||
let ret = inflight_request.merger.feed(packet);
|
||||
match ret {
|
||||
Ok(Some(rpc_packet)) => {
|
||||
|
@ -138,6 +167,7 @@ impl Client {
|
|||
to_peer_id: PeerId,
|
||||
zc_packet_sender: MpscTunnelSender,
|
||||
inflight_requests: InflightRequestTable,
|
||||
peer_info: PeerInfoTable,
|
||||
_phan: PhantomData<F>,
|
||||
}
|
||||
|
||||
|
@ -194,23 +224,53 @@ impl Client {
|
|||
};
|
||||
|
||||
let rpc_req = RpcRequest {
|
||||
descriptor: Some(rpc_desc.clone()),
|
||||
request: input.into(),
|
||||
timeout_ms: ctrl.timeout_ms(),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let peer_info = self
|
||||
.peer_info
|
||||
.get(&self.to_peer_id)
|
||||
.map(|v| v.clone())
|
||||
.unwrap_or_default();
|
||||
let (buf, c_algo) = compress_packet(
|
||||
peer_info.compression_info.accepted_algo(),
|
||||
&rpc_req.encode_to_vec(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let packets = build_rpc_packet(
|
||||
self.from_peer_id,
|
||||
self.to_peer_id,
|
||||
rpc_desc,
|
||||
transaction_id,
|
||||
true,
|
||||
&rpc_req.encode_to_vec(),
|
||||
&buf,
|
||||
ctrl.trace_id(),
|
||||
RpcCompressionInfo {
|
||||
algo: c_algo.into(),
|
||||
accepted_algo: CompressionAlgoPb::Zstd.into(),
|
||||
},
|
||||
);
|
||||
|
||||
let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64);
|
||||
let rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??;
|
||||
let mut rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??;
|
||||
|
||||
if let Some(compression_info) = rpc_packet.compression_info {
|
||||
self.peer_info.insert(
|
||||
self.to_peer_id,
|
||||
PeerInfo {
|
||||
peer_id: self.to_peer_id,
|
||||
compression_info: compression_info.clone(),
|
||||
last_active: Some(std::time::Instant::now()),
|
||||
},
|
||||
);
|
||||
|
||||
rpc_packet.body =
|
||||
decompress_packet(compression_info.algo(), &rpc_packet.body).await?;
|
||||
}
|
||||
|
||||
assert_eq!(rpc_packet.transaction_id, transaction_id);
|
||||
|
||||
|
@ -230,6 +290,7 @@ impl Client {
|
|||
to_peer_id,
|
||||
zc_packet_sender: self.mpsc.lock().unwrap().get_sink(),
|
||||
inflight_requests: self.inflight_requests.clone(),
|
||||
peer_info: self.peer_info.clone(),
|
||||
_phan: PhantomData,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,18 +1,44 @@
|
|||
use prost::Message as _;
|
||||
|
||||
use crate::{
|
||||
common::PeerId,
|
||||
common::{compressor::DefaultCompressor, PeerId},
|
||||
proto::{
|
||||
common::{RpcDescriptor, RpcPacket},
|
||||
common::{CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket},
|
||||
rpc_types::error::Error,
|
||||
},
|
||||
tunnel::packet_def::{PacketType, ZCPacket},
|
||||
tunnel::packet_def::{CompressorAlgo, PacketType, ZCPacket},
|
||||
};
|
||||
|
||||
use super::RpcTransactId;
|
||||
|
||||
const RPC_PACKET_CONTENT_MTU: usize = 1300;
|
||||
|
||||
pub async fn compress_packet(
|
||||
accepted_compression_algo: CompressionAlgoPb,
|
||||
content: &[u8],
|
||||
) -> Result<(Vec<u8>, CompressionAlgoPb), Error> {
|
||||
let compressor = DefaultCompressor::new();
|
||||
let algo = accepted_compression_algo
|
||||
.try_into()
|
||||
.unwrap_or(CompressorAlgo::None);
|
||||
let compressed = compressor.compress_raw(&content, algo).await?;
|
||||
if compressed.len() >= content.len() {
|
||||
Ok((content.to_vec(), CompressionAlgoPb::None))
|
||||
} else {
|
||||
Ok((compressed, algo.try_into().unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn decompress_packet(
|
||||
compression_algo: CompressionAlgoPb,
|
||||
content: &[u8],
|
||||
) -> Result<Vec<u8>, Error> {
|
||||
let compressor = DefaultCompressor::new();
|
||||
let algo = compression_algo.try_into()?;
|
||||
let decompressed = compressor.decompress_raw(&content, algo).await?;
|
||||
Ok(decompressed)
|
||||
}
|
||||
|
||||
pub struct PacketMerger {
|
||||
first_piece: Option<RpcPacket>,
|
||||
pieces: Vec<RpcPacket>,
|
||||
|
@ -46,7 +72,8 @@ impl PacketMerger {
|
|||
body.extend_from_slice(&p.body);
|
||||
}
|
||||
|
||||
let mut tmpl_packet = self.first_piece.as_ref().unwrap().clone();
|
||||
// only the first packet contains the complete info
|
||||
let mut tmpl_packet = self.pieces[0].clone();
|
||||
tmpl_packet.total_pieces = 1;
|
||||
tmpl_packet.piece_idx = 0;
|
||||
tmpl_packet.body = body;
|
||||
|
@ -58,17 +85,17 @@ impl PacketMerger {
|
|||
let total_pieces = rpc_packet.total_pieces;
|
||||
let piece_idx = rpc_packet.piece_idx;
|
||||
|
||||
if rpc_packet.descriptor.is_none() {
|
||||
return Err(Error::MalformatRpcPacket(
|
||||
"descriptor is missing".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
// for compatibility with old version
|
||||
if total_pieces == 0 && piece_idx == 0 {
|
||||
return Ok(Some(rpc_packet));
|
||||
}
|
||||
|
||||
if rpc_packet.piece_idx == 0 && rpc_packet.descriptor.is_none() {
|
||||
return Err(Error::MalformatRpcPacket(
|
||||
"descriptor is missing".to_owned(),
|
||||
));
|
||||
}
|
||||
|
||||
// about 32MB max size
|
||||
if total_pieces > 32 * 1024 || total_pieces == 0 {
|
||||
return Err(Error::MalformatRpcPacket(format!(
|
||||
|
@ -89,6 +116,7 @@ impl PacketMerger {
|
|||
{
|
||||
self.first_piece = Some(rpc_packet.clone());
|
||||
self.pieces.clear();
|
||||
tracing::trace!(?rpc_packet, "got first piece");
|
||||
}
|
||||
|
||||
self.pieces
|
||||
|
@ -113,6 +141,7 @@ pub fn build_rpc_packet(
|
|||
is_req: bool,
|
||||
content: &Vec<u8>,
|
||||
trace_id: i32,
|
||||
compression_info: RpcCompressionInfo,
|
||||
) -> Vec<ZCPacket> {
|
||||
let mut ret = Vec::new();
|
||||
let content_mtu = RPC_PACKET_CONTENT_MTU;
|
||||
|
@ -130,13 +159,22 @@ pub fn build_rpc_packet(
|
|||
let cur_packet = RpcPacket {
|
||||
from_peer,
|
||||
to_peer,
|
||||
descriptor: Some(rpc_desc.clone()),
|
||||
descriptor: if cur_offset == 0 {
|
||||
Some(rpc_desc.clone())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
is_request: is_req,
|
||||
total_pieces: total_pieces as u32,
|
||||
piece_idx: (cur_offset / content_mtu) as u32,
|
||||
transaction_id,
|
||||
body: cur_content,
|
||||
trace_id,
|
||||
compression_info: if cur_offset == 0 {
|
||||
Some(compression_info.clone())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
};
|
||||
cur_offset += cur_len;
|
||||
|
||||
|
|
|
@ -12,7 +12,10 @@ use tokio_stream::StreamExt;
|
|||
use crate::{
|
||||
common::{join_joinset_background, PeerId},
|
||||
proto::{
|
||||
common::{self, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse},
|
||||
common::{
|
||||
self, CompressionAlgoPb, RpcCompressionInfo, RpcPacket, RpcRequest,
|
||||
RpcResponse,
|
||||
},
|
||||
rpc_types::error::Result,
|
||||
},
|
||||
tunnel::{
|
||||
|
@ -23,7 +26,7 @@ use crate::{
|
|||
};
|
||||
|
||||
use super::{
|
||||
packet::{build_rpc_packet, PacketMerger},
|
||||
packet::{build_rpc_packet, compress_packet, decompress_packet, PacketMerger},
|
||||
service_registry::ServiceRegistry,
|
||||
RpcController, Transport,
|
||||
};
|
||||
|
@ -31,7 +34,6 @@ use super::{
|
|||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
struct PacketMergerKey {
|
||||
from_peer_id: PeerId,
|
||||
rpc_desc: RpcDescriptor,
|
||||
transaction_id: i64,
|
||||
}
|
||||
|
||||
|
@ -108,10 +110,11 @@ impl Server {
|
|||
|
||||
let key = PacketMergerKey {
|
||||
from_peer_id: packet.from_peer,
|
||||
rpc_desc: packet.descriptor.clone().unwrap_or_default(),
|
||||
transaction_id: packet.transaction_id,
|
||||
};
|
||||
|
||||
tracing::trace!(?key, ?packet, "Received request packet");
|
||||
|
||||
let ret = packet_merges
|
||||
.entry(key.clone())
|
||||
.or_insert_with(PacketMerger::new)
|
||||
|
@ -144,7 +147,16 @@ impl Server {
|
|||
}
|
||||
|
||||
async fn handle_rpc_request(packet: RpcPacket, reg: Arc<ServiceRegistry>) -> Result<Bytes> {
|
||||
let rpc_request = RpcRequest::decode(Bytes::from(packet.body))?;
|
||||
let body = if let Some(compression_info) = packet.compression_info {
|
||||
decompress_packet(
|
||||
compression_info.algo.try_into().unwrap_or_default(),
|
||||
&packet.body,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
packet.body
|
||||
};
|
||||
let rpc_request = RpcRequest::decode(Bytes::from(body))?;
|
||||
let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64);
|
||||
let ctrl = RpcController::default();
|
||||
Ok(timeout(
|
||||
|
@ -168,6 +180,7 @@ impl Server {
|
|||
let mut resp_msg = RpcResponse::default();
|
||||
let now = std::time::Instant::now();
|
||||
|
||||
let compression_info = packet.compression_info.clone();
|
||||
let resp_bytes = Self::handle_rpc_request(packet, reg).await;
|
||||
|
||||
match &resp_bytes {
|
||||
|
@ -180,14 +193,25 @@ impl Server {
|
|||
};
|
||||
resp_msg.runtime_us = now.elapsed().as_micros() as u64;
|
||||
|
||||
let (compressed_resp, algo) = compress_packet(
|
||||
compression_info.unwrap_or_default().accepted_algo(),
|
||||
&resp_msg.encode_to_vec(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let packets = build_rpc_packet(
|
||||
to_peer,
|
||||
from_peer,
|
||||
desc,
|
||||
transaction_id,
|
||||
false,
|
||||
&resp_msg.encode_to_vec(),
|
||||
&compressed_resp,
|
||||
trace_id,
|
||||
RpcCompressionInfo {
|
||||
algo: algo.into(),
|
||||
accepted_algo: CompressionAlgoPb::Zstd.into(),
|
||||
},
|
||||
);
|
||||
|
||||
for packet in packets {
|
||||
|
|
|
@ -107,6 +107,7 @@ fn random_string(len: usize) -> String {
|
|||
|
||||
#[tokio::test]
|
||||
async fn rpc_basic_test() {
|
||||
// enable_log();
|
||||
let ctx = TestContext::new();
|
||||
|
||||
let server = GreetingServer::new(GreetingService {
|
||||
|
@ -119,7 +120,7 @@ async fn rpc_basic_test() {
|
|||
.client
|
||||
.scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "".to_string());
|
||||
|
||||
// small size req and resp
|
||||
// // small size req and resp
|
||||
|
||||
let ctrl = RpcController::default();
|
||||
let input = SayHelloRequest {
|
||||
|
|
|
@ -603,7 +603,7 @@ pub mod tests {
|
|||
|
||||
pub fn enable_log() {
|
||||
let filter = tracing_subscriber::EnvFilter::builder()
|
||||
.with_default_directive(tracing::level_filters::LevelFilter::DEBUG.into())
|
||||
.with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
|
||||
.from_env()
|
||||
.unwrap()
|
||||
.add_directive("tarpc=error".parse().unwrap());
|
||||
|
|
|
@ -7,6 +7,10 @@ use zerocopy::FromZeroes;
|
|||
|
||||
type DefaultEndian = LittleEndian;
|
||||
|
||||
const fn max(a: usize, b: usize) -> usize {
|
||||
[a, b][(a < b) as usize]
|
||||
}
|
||||
|
||||
// TCP TunnelHeader
|
||||
#[repr(C, packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||
|
@ -49,11 +53,11 @@ pub enum PacketType {
|
|||
Invalid = 0,
|
||||
Data = 1,
|
||||
HandShake = 2,
|
||||
RoutePacket = 3,
|
||||
RoutePacket = 3, // deprecated
|
||||
Ping = 4,
|
||||
Pong = 5,
|
||||
TaRpc = 6,
|
||||
Route = 7,
|
||||
TaRpc = 6, // deprecated
|
||||
Route = 7, // deprecated
|
||||
RpcReq = 8,
|
||||
RpcResp = 9,
|
||||
ForeignNetworkPacket = 10,
|
||||
|
@ -65,6 +69,7 @@ bitflags::bitflags! {
|
|||
const LATENCY_FIRST = 0b0000_0010;
|
||||
const EXIT_NODE = 0b0000_0100;
|
||||
const NO_PROXY = 0b0000_1000;
|
||||
const COMPRESSED = 0b0001_0000;
|
||||
|
||||
const _ = !0;
|
||||
}
|
||||
|
@ -118,6 +123,12 @@ impl PeerManagerHeader {
|
|||
.contains(PeerManagerHeaderFlags::NO_PROXY)
|
||||
}
|
||||
|
||||
pub fn is_compressed(&self) -> bool {
|
||||
PeerManagerHeaderFlags::from_bits(self.flags)
|
||||
.unwrap()
|
||||
.contains(PeerManagerHeaderFlags::COMPRESSED)
|
||||
}
|
||||
|
||||
pub fn set_latency_first(&mut self, latency_first: bool) -> &mut Self {
|
||||
let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap();
|
||||
if latency_first {
|
||||
|
@ -150,6 +161,17 @@ impl PeerManagerHeader {
|
|||
self.flags = flags.bits();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn set_compressed(&mut self, compressed: bool) -> &mut Self {
|
||||
let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap();
|
||||
if compressed {
|
||||
flags.insert(PeerManagerHeaderFlags::COMPRESSED);
|
||||
} else {
|
||||
flags.remove(PeerManagerHeaderFlags::COMPRESSED);
|
||||
}
|
||||
self.flags = flags.bits();
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C, packed)]
|
||||
|
@ -201,12 +223,35 @@ pub struct AesGcmTail {
|
|||
}
|
||||
pub const AES_GCM_ENCRYPTION_RESERVED: usize = std::mem::size_of::<AesGcmTail>();
|
||||
|
||||
pub const TAIL_RESERVED_SIZE: usize = AES_GCM_ENCRYPTION_RESERVED;
|
||||
|
||||
const fn max(a: usize, b: usize) -> usize {
|
||||
[a, b][(a < b) as usize]
|
||||
#[derive(AsBytes, FromZeroes, Clone, Debug, Copy)]
|
||||
#[repr(u8)]
|
||||
pub enum CompressorAlgo {
|
||||
None = 0,
|
||||
ZstdDefault = 1,
|
||||
}
|
||||
|
||||
#[repr(C, packed)]
|
||||
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||
pub struct CompressorTail {
|
||||
pub algo: u8,
|
||||
}
|
||||
pub const COMPRESSOR_TAIL_SIZE: usize = std::mem::size_of::<CompressorTail>();
|
||||
|
||||
impl CompressorTail {
|
||||
pub fn get_algo(&self) -> Option<CompressorAlgo> {
|
||||
match self.algo {
|
||||
1 => Some(CompressorAlgo::ZstdDefault),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(algo: CompressorAlgo) -> Self {
|
||||
Self { algo: algo as u8 }
|
||||
}
|
||||
}
|
||||
|
||||
pub const TAIL_RESERVED_SIZE: usize = max(AES_GCM_ENCRYPTION_RESERVED, COMPRESSOR_TAIL_SIZE);
|
||||
|
||||
#[derive(Default, Debug)]
|
||||
pub struct ZCPacketOffsets {
|
||||
pub payload_offset: usize,
|
||||
|
|
Loading…
Reference in New Issue
Block a user