mirror of
https://github.com/EasyTier/EasyTier.git
synced 2024-11-15 19:22:30 +08:00
support compress for rpc and tun data
This commit is contained in:
parent
4fc3ff8ce8
commit
6b0976381c
37
Cargo.lock
generated
37
Cargo.lock
generated
|
@ -262,6 +262,20 @@ dependencies = [
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "async-compression"
|
||||||
|
version = "0.4.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0cb8f1d480b0ea3783ab015936d2a55c87e219676f0c0b7dec61494043f21857"
|
||||||
|
dependencies = [
|
||||||
|
"futures-core",
|
||||||
|
"memchr",
|
||||||
|
"pin-project-lite",
|
||||||
|
"tokio",
|
||||||
|
"zstd 0.13.2",
|
||||||
|
"zstd-safe 7.2.1",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "async-event"
|
name = "async-event"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
|
@ -1809,6 +1823,7 @@ version = "2.0.3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"aes-gcm",
|
"aes-gcm",
|
||||||
"anyhow",
|
"anyhow",
|
||||||
|
"async-compression",
|
||||||
"async-recursion",
|
"async-recursion",
|
||||||
"async-ringbuf",
|
"async-ringbuf",
|
||||||
"async-stream",
|
"async-stream",
|
||||||
|
@ -9474,7 +9489,7 @@ dependencies = [
|
||||||
"pbkdf2",
|
"pbkdf2",
|
||||||
"sha1",
|
"sha1",
|
||||||
"time",
|
"time",
|
||||||
"zstd",
|
"zstd 0.11.2+zstd.1.5.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -9483,7 +9498,16 @@ version = "0.11.2+zstd.1.5.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
|
checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"zstd-safe",
|
"zstd-safe 5.0.2+zstd.1.5.2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zstd"
|
||||||
|
version = "0.13.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fcf2b778a664581e31e389454a7072dab1647606d44f7feea22cd5abb9c9f3f9"
|
||||||
|
dependencies = [
|
||||||
|
"zstd-safe 7.2.1",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -9496,6 +9520,15 @@ dependencies = [
|
||||||
"zstd-sys",
|
"zstd-sys",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "zstd-safe"
|
||||||
|
version = "7.2.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "54a3ab4db68cea366acc5c897c7b4d4d1b8994a9cd6e6f841f8964566a419059"
|
||||||
|
dependencies = [
|
||||||
|
"zstd-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zstd-sys"
|
name = "zstd-sys"
|
||||||
version = "2.0.13+zstd.1.5.6"
|
version = "2.0.13+zstd.1.5.6"
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
|
import type { NetworkTypes } from 'easytier-frontend-lib'
|
||||||
import { addPluginListener } from '@tauri-apps/api/core'
|
import { addPluginListener } from '@tauri-apps/api/core'
|
||||||
|
import { Utils } from 'easytier-frontend-lib'
|
||||||
import { prepare_vpn, start_vpn, stop_vpn } from 'tauri-plugin-vpnservice-api'
|
import { prepare_vpn, start_vpn, stop_vpn } from 'tauri-plugin-vpnservice-api'
|
||||||
import { NetworkTypes, Utils } from 'easytier-frontend-lib'
|
|
||||||
|
|
||||||
type Route = NetworkTypes.Route
|
type Route = NetworkTypes.Route
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
|
import type { NetworkTypes } from 'easytier-frontend-lib'
|
||||||
import { invoke } from '@tauri-apps/api/core'
|
import { invoke } from '@tauri-apps/api/core'
|
||||||
import { NetworkTypes } from 'easytier-frontend-lib'
|
|
||||||
|
|
||||||
type NetworkConfig = NetworkTypes.NetworkConfig
|
type NetworkConfig = NetworkTypes.NetworkConfig
|
||||||
type NetworkInstanceRunningInfo = NetworkTypes.NetworkInstanceRunningInfo
|
type NetworkInstanceRunningInfo = NetworkTypes.NetworkInstanceRunningInfo
|
||||||
|
|
|
@ -181,6 +181,8 @@ sys-locale = "0.3"
|
||||||
ringbuf = "0.4.5"
|
ringbuf = "0.4.5"
|
||||||
async-ringbuf = "0.3.1"
|
async-ringbuf = "0.3.1"
|
||||||
|
|
||||||
|
async-compression = { version = "0.4.17", default-features = false, features = ["zstd", "tokio"] }
|
||||||
|
|
||||||
[target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies]
|
[target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows", target_os = "freebsd"))'.dependencies]
|
||||||
machine-uid = "0.5.3"
|
machine-uid = "0.5.3"
|
||||||
|
|
||||||
|
|
174
easytier/src/common/compressor.rs
Normal file
174
easytier/src/common/compressor.rs
Normal file
|
@ -0,0 +1,174 @@
|
||||||
|
use async_compression::tokio::write::{ZstdDecoder, ZstdEncoder};
|
||||||
|
use tokio::io::AsyncWriteExt;
|
||||||
|
|
||||||
|
use zerocopy::{AsBytes as _, FromBytes as _};
|
||||||
|
|
||||||
|
use crate::tunnel::packet_def::{
|
||||||
|
CompressorAlgo, CompressorTail, PacketType, ZCPacket, COMPRESSOR_TAIL_SIZE,
|
||||||
|
};
|
||||||
|
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
pub trait Compressor {
|
||||||
|
async fn compress(
|
||||||
|
&self,
|
||||||
|
packet: &mut ZCPacket,
|
||||||
|
compress_algo: CompressorAlgo,
|
||||||
|
) -> Result<(), Error>;
|
||||||
|
async fn decompress(&self, packet: &mut ZCPacket) -> Result<(), Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DefaultCompressor {}
|
||||||
|
|
||||||
|
impl DefaultCompressor {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
DefaultCompressor {}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn compress_raw(
|
||||||
|
&self,
|
||||||
|
data: &[u8],
|
||||||
|
compress_algo: CompressorAlgo,
|
||||||
|
) -> Result<Vec<u8>, Error> {
|
||||||
|
let buf = match compress_algo {
|
||||||
|
CompressorAlgo::ZstdDefault => {
|
||||||
|
let mut o = ZstdEncoder::new(Vec::new());
|
||||||
|
o.write_all(data).await?;
|
||||||
|
o.shutdown().await?;
|
||||||
|
o.into_inner()
|
||||||
|
}
|
||||||
|
CompressorAlgo::None => data.to_vec(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn decompress_raw(
|
||||||
|
&self,
|
||||||
|
data: &[u8],
|
||||||
|
compress_algo: CompressorAlgo,
|
||||||
|
) -> Result<Vec<u8>, Error> {
|
||||||
|
let buf = match compress_algo {
|
||||||
|
CompressorAlgo::ZstdDefault => {
|
||||||
|
let mut o = ZstdDecoder::new(Vec::new());
|
||||||
|
o.write_all(data).await?;
|
||||||
|
o.shutdown().await?;
|
||||||
|
o.into_inner()
|
||||||
|
}
|
||||||
|
CompressorAlgo::None => data.to_vec(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(buf)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Compressor for DefaultCompressor {
|
||||||
|
async fn compress(
|
||||||
|
&self,
|
||||||
|
zc_packet: &mut ZCPacket,
|
||||||
|
compress_algo: CompressorAlgo,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
if matches!(compress_algo, CompressorAlgo::None) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let pm_header = zc_packet.peer_manager_header().unwrap();
|
||||||
|
if pm_header.is_compressed() || pm_header.packet_type != PacketType::Data as u8 {
|
||||||
|
// only compress data packets
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let tail = CompressorTail::new(compress_algo);
|
||||||
|
let buf = self
|
||||||
|
.compress_raw(zc_packet.payload(), compress_algo)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if buf.len() + COMPRESSOR_TAIL_SIZE > pm_header.len.get() as usize {
|
||||||
|
// Compressed data is larger than original data, don't compress
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
zc_packet
|
||||||
|
.mut_peer_manager_header()
|
||||||
|
.unwrap()
|
||||||
|
.set_compressed(true);
|
||||||
|
|
||||||
|
let payload_offset = zc_packet.payload_offset();
|
||||||
|
zc_packet.mut_inner().truncate(payload_offset);
|
||||||
|
zc_packet.mut_inner().extend_from_slice(&buf);
|
||||||
|
zc_packet.mut_inner().extend_from_slice(tail.as_bytes());
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn decompress(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
|
||||||
|
let pm_header = zc_packet.peer_manager_header().unwrap();
|
||||||
|
if !pm_header.is_compressed() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let payload_len = zc_packet.payload().len();
|
||||||
|
if payload_len < COMPRESSOR_TAIL_SIZE {
|
||||||
|
return Err(anyhow::anyhow!("Packet too short: {}", payload_len));
|
||||||
|
}
|
||||||
|
|
||||||
|
let text_len = payload_len - COMPRESSOR_TAIL_SIZE;
|
||||||
|
|
||||||
|
let tail = CompressorTail::ref_from_suffix(zc_packet.payload())
|
||||||
|
.unwrap()
|
||||||
|
.clone();
|
||||||
|
|
||||||
|
let algo = tail
|
||||||
|
.get_algo()
|
||||||
|
.ok_or(anyhow::anyhow!("Unknown algo: {:?}", tail))?;
|
||||||
|
|
||||||
|
let buf = self
|
||||||
|
.decompress_raw(&zc_packet.payload()[..text_len], algo)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if buf.len() != pm_header.len.get() as usize {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Decompressed length mismatch: decompressed len {} != pm header len {}",
|
||||||
|
buf.len(),
|
||||||
|
pm_header.len.get()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
zc_packet
|
||||||
|
.mut_peer_manager_header()
|
||||||
|
.unwrap()
|
||||||
|
.set_compressed(false);
|
||||||
|
|
||||||
|
let payload_offset = zc_packet.payload_offset();
|
||||||
|
zc_packet.mut_inner().truncate(payload_offset);
|
||||||
|
zc_packet.mut_inner().extend_from_slice(&buf);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
pub mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_compress() {
|
||||||
|
let text = b"12345670000000000000000000";
|
||||||
|
let mut packet = ZCPacket::new_with_payload(text);
|
||||||
|
packet.fill_peer_manager_hdr(0, 0, 0);
|
||||||
|
|
||||||
|
let compressor = DefaultCompressor {};
|
||||||
|
|
||||||
|
compressor
|
||||||
|
.compress(&mut packet, CompressorAlgo::ZstdDefault)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), true);
|
||||||
|
|
||||||
|
compressor.decompress(&mut packet).await.unwrap();
|
||||||
|
assert_eq!(packet.payload(), text);
|
||||||
|
assert_eq!(packet.peer_manager_header().unwrap().is_compressed(), false);
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,6 +6,7 @@ use std::{
|
||||||
use tokio::task::JoinSet;
|
use tokio::task::JoinSet;
|
||||||
use tracing::Instrument;
|
use tracing::Instrument;
|
||||||
|
|
||||||
|
pub mod compressor;
|
||||||
pub mod config;
|
pub mod config;
|
||||||
pub mod constants;
|
pub mod constants;
|
||||||
pub mod defer;
|
pub mod defer;
|
||||||
|
|
|
@ -20,6 +20,7 @@ use tokio::{
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{
|
common::{
|
||||||
|
compressor::{Compressor as _, DefaultCompressor},
|
||||||
constants::EASYTIER_VERSION,
|
constants::EASYTIER_VERSION,
|
||||||
error::Error,
|
error::Error,
|
||||||
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
global_ctx::{ArcGlobalCtx, NetworkIdentity},
|
||||||
|
@ -41,7 +42,7 @@ use crate::{
|
||||||
},
|
},
|
||||||
tunnel::{
|
tunnel::{
|
||||||
self,
|
self,
|
||||||
packet_def::{PacketType, ZCPacket},
|
packet_def::{CompressorAlgo, PacketType, ZCPacket},
|
||||||
SinkItem, Tunnel, TunnelConnector,
|
SinkItem, Tunnel, TunnelConnector,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
@ -61,6 +62,7 @@ use super::{
|
||||||
struct RpcTransport {
|
struct RpcTransport {
|
||||||
my_peer_id: PeerId,
|
my_peer_id: PeerId,
|
||||||
peers: Weak<PeerMap>,
|
peers: Weak<PeerMap>,
|
||||||
|
// TODO: this seems can be removed
|
||||||
foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>,
|
foreign_peers: Mutex<Option<Weak<ForeignNetworkClient>>>,
|
||||||
|
|
||||||
packet_recv: Mutex<UnboundedReceiver<ZCPacket>>,
|
packet_recv: Mutex<UnboundedReceiver<ZCPacket>>,
|
||||||
|
@ -76,48 +78,14 @@ impl PeerRpcManagerTransport for RpcTransport {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send(&self, mut msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
|
async fn send(&self, mut msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
|
||||||
let foreign_peers = self
|
|
||||||
.foreign_peers
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.as_ref()
|
|
||||||
.ok_or(Error::Unknown)?
|
|
||||||
.upgrade()
|
|
||||||
.ok_or(Error::Unknown)?;
|
|
||||||
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
|
let peers = self.peers.upgrade().ok_or(Error::Unknown)?;
|
||||||
|
if peers.need_relay_by_foreign_network(dst_peer_id).await? {
|
||||||
if foreign_peers.has_next_hop(dst_peer_id) {
|
|
||||||
// do not encrypt for data sending to public server
|
|
||||||
tracing::debug!(
|
|
||||||
?dst_peer_id,
|
|
||||||
?self.my_peer_id,
|
|
||||||
"failed to send msg to peer, try foreign network",
|
|
||||||
);
|
|
||||||
foreign_peers.send_msg(msg, dst_peer_id).await
|
|
||||||
} else if let Some(gateway_id) = peers
|
|
||||||
.get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
tracing::trace!(
|
|
||||||
?dst_peer_id,
|
|
||||||
?gateway_id,
|
|
||||||
?self.my_peer_id,
|
|
||||||
"send msg to peer via gateway",
|
|
||||||
);
|
|
||||||
self.encryptor
|
self.encryptor
|
||||||
.encrypt(&mut msg)
|
.encrypt(&mut msg)
|
||||||
.with_context(|| "encrypt failed")?;
|
.with_context(|| "encrypt failed")?;
|
||||||
if peers.has_peer(gateway_id) {
|
|
||||||
peers.send_msg_directly(msg, gateway_id).await
|
|
||||||
} else {
|
|
||||||
foreign_peers.send_msg(msg, gateway_id).await
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
Err(Error::RouteError(Some(format!(
|
|
||||||
"peermgr RpcTransport no route for dst_peer_id: {}",
|
|
||||||
dst_peer_id
|
|
||||||
))))
|
|
||||||
}
|
}
|
||||||
|
// send to self and this packet will be forwarded in peer_recv loop
|
||||||
|
peers.send_msg_directly(msg, self.my_peer_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recv(&self) -> Result<ZCPacket, Error> {
|
async fn recv(&self) -> Result<ZCPacket, Error> {
|
||||||
|
@ -163,6 +131,7 @@ pub struct PeerManager {
|
||||||
foreign_network_client: Arc<ForeignNetworkClient>,
|
foreign_network_client: Arc<ForeignNetworkClient>,
|
||||||
|
|
||||||
encryptor: Arc<Box<dyn Encryptor>>,
|
encryptor: Arc<Box<dyn Encryptor>>,
|
||||||
|
data_compress_algo: CompressorAlgo,
|
||||||
|
|
||||||
exit_nodes: Vec<Ipv4Addr>,
|
exit_nodes: Vec<Ipv4Addr>,
|
||||||
}
|
}
|
||||||
|
@ -272,6 +241,8 @@ impl PeerManager {
|
||||||
foreign_network_client,
|
foreign_network_client,
|
||||||
|
|
||||||
encryptor,
|
encryptor,
|
||||||
|
data_compress_algo: CompressorAlgo::None,
|
||||||
|
|
||||||
exit_nodes,
|
exit_nodes,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -465,6 +436,12 @@ impl PeerManager {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let compressor = DefaultCompressor {};
|
||||||
|
if let Err(e) = compressor.decompress(&mut ret).await {
|
||||||
|
tracing::error!(?e, "decompress failed");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
let mut processed = false;
|
let mut processed = false;
|
||||||
let mut zc_packet = Some(ret);
|
let mut zc_packet = Some(ret);
|
||||||
let mut idx = 0;
|
let mut idx = 0;
|
||||||
|
@ -768,6 +745,11 @@ impl PeerManager {
|
||||||
tunnel::packet_def::PacketType::Data as u8,
|
tunnel::packet_def::PacketType::Data as u8,
|
||||||
);
|
);
|
||||||
self.run_nic_packet_process_pipeline(&mut msg).await;
|
self.run_nic_packet_process_pipeline(&mut msg).await;
|
||||||
|
let compressor = DefaultCompressor {};
|
||||||
|
compressor
|
||||||
|
.compress(&mut msg, self.data_compress_algo)
|
||||||
|
.await
|
||||||
|
.with_context(|| "compress failed")?;
|
||||||
self.encryptor
|
self.encryptor
|
||||||
.encrypt(&mut msg)
|
.encrypt(&mut msg)
|
||||||
.with_context(|| "encrypt failed")?;
|
.with_context(|| "encrypt failed")?;
|
||||||
|
|
|
@ -250,6 +250,19 @@ impl PeerMap {
|
||||||
}
|
}
|
||||||
route_map
|
route_map
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn need_relay_by_foreign_network(&self, dst_peer_id: PeerId) -> Result<bool, Error> {
|
||||||
|
// if gateway_peer_id is not connected to me, means need relay by foreign network
|
||||||
|
let gateway_id = self
|
||||||
|
.get_gateway_peer_id(dst_peer_id, NextHopPolicy::LeastHop)
|
||||||
|
.await
|
||||||
|
.ok_or(Error::RouteError(Some(format!(
|
||||||
|
"peer map need_relay_by_foreign_network no gateway for dst_peer_id: {}",
|
||||||
|
dst_peer_id
|
||||||
|
))))?;
|
||||||
|
|
||||||
|
Ok(!self.has_peer(gateway_id))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for PeerMap {
|
impl Drop for PeerMap {
|
||||||
|
|
|
@ -32,7 +32,7 @@ message RpcDescriptor {
|
||||||
}
|
}
|
||||||
|
|
||||||
message RpcRequest {
|
message RpcRequest {
|
||||||
RpcDescriptor descriptor = 1;
|
RpcDescriptor descriptor = 1 [ deprecated = true ];
|
||||||
|
|
||||||
bytes request = 2;
|
bytes request = 2;
|
||||||
int32 timeout_ms = 3;
|
int32 timeout_ms = 3;
|
||||||
|
@ -45,6 +45,21 @@ message RpcResponse {
|
||||||
uint64 runtime_us = 3;
|
uint64 runtime_us = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum CompressionAlgoPb {
|
||||||
|
Invalid = 0;
|
||||||
|
None = 1;
|
||||||
|
Zstd = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message RpcCompressionInfo {
|
||||||
|
// use this to compress the content
|
||||||
|
CompressionAlgoPb algo = 1;
|
||||||
|
|
||||||
|
// tell the peer which compression algo is used to compress the next
|
||||||
|
// response/request
|
||||||
|
CompressionAlgoPb accepted_algo = 2;
|
||||||
|
}
|
||||||
|
|
||||||
message RpcPacket {
|
message RpcPacket {
|
||||||
uint32 from_peer = 1;
|
uint32 from_peer = 1;
|
||||||
uint32 to_peer = 2;
|
uint32 to_peer = 2;
|
||||||
|
@ -58,6 +73,8 @@ message RpcPacket {
|
||||||
uint32 piece_idx = 8;
|
uint32 piece_idx = 8;
|
||||||
|
|
||||||
int32 trace_id = 9;
|
int32 trace_id = 9;
|
||||||
|
|
||||||
|
RpcCompressionInfo compression_info = 10;
|
||||||
}
|
}
|
||||||
|
|
||||||
message Void {}
|
message Void {}
|
||||||
|
|
|
@ -2,6 +2,8 @@ use std::{fmt::Display, str::FromStr};
|
||||||
|
|
||||||
use anyhow::Context;
|
use anyhow::Context;
|
||||||
|
|
||||||
|
use crate::tunnel::packet_def::CompressorAlgo;
|
||||||
|
|
||||||
include!(concat!(env!("OUT_DIR"), "/common.rs"));
|
include!(concat!(env!("OUT_DIR"), "/common.rs"));
|
||||||
|
|
||||||
impl From<uuid::Uuid> for Uuid {
|
impl From<uuid::Uuid> for Uuid {
|
||||||
|
@ -180,3 +182,26 @@ impl From<SocketAddr> for std::net::SocketAddr {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl TryFrom<CompressionAlgoPb> for CompressorAlgo {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
fn try_from(value: CompressionAlgoPb) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
CompressionAlgoPb::Zstd => Ok(CompressorAlgo::ZstdDefault),
|
||||||
|
CompressionAlgoPb::None => Ok(CompressorAlgo::None),
|
||||||
|
_ => Err(anyhow::anyhow!("Invalid CompressionAlgoPb")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TryFrom<CompressorAlgo> for CompressionAlgoPb {
|
||||||
|
type Error = anyhow::Error;
|
||||||
|
|
||||||
|
fn try_from(value: CompressorAlgo) -> Result<Self, Self::Error> {
|
||||||
|
match value {
|
||||||
|
CompressorAlgo::ZstdDefault => Ok(CompressionAlgoPb::Zstd),
|
||||||
|
CompressorAlgo::None => Ok(CompressionAlgoPb::None),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -12,8 +12,10 @@ use tokio_stream::StreamExt;
|
||||||
|
|
||||||
use crate::common::PeerId;
|
use crate::common::PeerId;
|
||||||
use crate::defer;
|
use crate::defer;
|
||||||
use crate::proto::common::{RpcDescriptor, RpcPacket, RpcRequest, RpcResponse};
|
use crate::proto::common::{
|
||||||
use crate::proto::rpc_impl::packet::build_rpc_packet;
|
CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse,
|
||||||
|
};
|
||||||
|
use crate::proto::rpc_impl::packet::{build_rpc_packet, compress_packet, decompress_packet};
|
||||||
use crate::proto::rpc_types::controller::Controller;
|
use crate::proto::rpc_types::controller::Controller;
|
||||||
use crate::proto::rpc_types::descriptor::MethodDescriptor;
|
use crate::proto::rpc_types::descriptor::MethodDescriptor;
|
||||||
use crate::proto::rpc_types::{
|
use crate::proto::rpc_types::{
|
||||||
|
@ -48,12 +50,21 @@ struct InflightRequest {
|
||||||
start_time: std::time::Instant,
|
start_time: std::time::Instant,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default)]
|
||||||
|
struct PeerInfo {
|
||||||
|
peer_id: PeerId,
|
||||||
|
compression_info: RpcCompressionInfo,
|
||||||
|
last_active: Option<std::time::Instant>,
|
||||||
|
}
|
||||||
|
|
||||||
type InflightRequestTable = Arc<DashMap<InflightRequestKey, InflightRequest>>;
|
type InflightRequestTable = Arc<DashMap<InflightRequestKey, InflightRequest>>;
|
||||||
|
type PeerInfoTable = Arc<DashMap<PeerId, PeerInfo>>;
|
||||||
|
|
||||||
pub struct Client {
|
pub struct Client {
|
||||||
mpsc: Mutex<MpscTunnel<Box<dyn Tunnel>>>,
|
mpsc: Mutex<MpscTunnel<Box<dyn Tunnel>>>,
|
||||||
transport: Mutex<Transport>,
|
transport: Mutex<Transport>,
|
||||||
inflight_requests: InflightRequestTable,
|
inflight_requests: InflightRequestTable,
|
||||||
|
peer_info: PeerInfoTable,
|
||||||
tasks: Arc<Mutex<JoinSet<()>>>,
|
tasks: Arc<Mutex<JoinSet<()>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,6 +75,7 @@ impl Client {
|
||||||
mpsc: Mutex::new(MpscTunnel::new(ring_a, None)),
|
mpsc: Mutex::new(MpscTunnel::new(ring_a, None)),
|
||||||
transport: Mutex::new(MpscTunnel::new(ring_b, None)),
|
transport: Mutex::new(MpscTunnel::new(ring_b, None)),
|
||||||
inflight_requests: Arc::new(DashMap::new()),
|
inflight_requests: Arc::new(DashMap::new()),
|
||||||
|
peer_info: Arc::new(DashMap::new()),
|
||||||
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
tasks: Arc::new(Mutex::new(JoinSet::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -79,6 +91,21 @@ impl Client {
|
||||||
pub fn run(&self) {
|
pub fn run(&self) {
|
||||||
let mut tasks = self.tasks.lock().unwrap();
|
let mut tasks = self.tasks.lock().unwrap();
|
||||||
|
|
||||||
|
let peer_infos = self.peer_info.clone();
|
||||||
|
tasks.spawn(async move {
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
||||||
|
let now = std::time::Instant::now();
|
||||||
|
peer_infos.retain(|_, v| {
|
||||||
|
if let Some(last_active) = v.last_active {
|
||||||
|
return now.duration_since(last_active)
|
||||||
|
< std::time::Duration::from_secs(120);
|
||||||
|
}
|
||||||
|
true
|
||||||
|
});
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
let mut rx = self.mpsc.lock().unwrap().get_stream();
|
let mut rx = self.mpsc.lock().unwrap().get_stream();
|
||||||
let inflight_requests = self.inflight_requests.clone();
|
let inflight_requests = self.inflight_requests.clone();
|
||||||
tasks.spawn(async move {
|
tasks.spawn(async move {
|
||||||
|
@ -111,6 +138,8 @@ impl Client {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
tracing::trace!(?packet, "Received response packet");
|
||||||
|
|
||||||
let ret = inflight_request.merger.feed(packet);
|
let ret = inflight_request.merger.feed(packet);
|
||||||
match ret {
|
match ret {
|
||||||
Ok(Some(rpc_packet)) => {
|
Ok(Some(rpc_packet)) => {
|
||||||
|
@ -138,6 +167,7 @@ impl Client {
|
||||||
to_peer_id: PeerId,
|
to_peer_id: PeerId,
|
||||||
zc_packet_sender: MpscTunnelSender,
|
zc_packet_sender: MpscTunnelSender,
|
||||||
inflight_requests: InflightRequestTable,
|
inflight_requests: InflightRequestTable,
|
||||||
|
peer_info: PeerInfoTable,
|
||||||
_phan: PhantomData<F>,
|
_phan: PhantomData<F>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -194,23 +224,53 @@ impl Client {
|
||||||
};
|
};
|
||||||
|
|
||||||
let rpc_req = RpcRequest {
|
let rpc_req = RpcRequest {
|
||||||
descriptor: Some(rpc_desc.clone()),
|
|
||||||
request: input.into(),
|
request: input.into(),
|
||||||
timeout_ms: ctrl.timeout_ms(),
|
timeout_ms: ctrl.timeout_ms(),
|
||||||
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let peer_info = self
|
||||||
|
.peer_info
|
||||||
|
.get(&self.to_peer_id)
|
||||||
|
.map(|v| v.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
let (buf, c_algo) = compress_packet(
|
||||||
|
peer_info.compression_info.accepted_algo(),
|
||||||
|
&rpc_req.encode_to_vec(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let packets = build_rpc_packet(
|
let packets = build_rpc_packet(
|
||||||
self.from_peer_id,
|
self.from_peer_id,
|
||||||
self.to_peer_id,
|
self.to_peer_id,
|
||||||
rpc_desc,
|
rpc_desc,
|
||||||
transaction_id,
|
transaction_id,
|
||||||
true,
|
true,
|
||||||
&rpc_req.encode_to_vec(),
|
&buf,
|
||||||
ctrl.trace_id(),
|
ctrl.trace_id(),
|
||||||
|
RpcCompressionInfo {
|
||||||
|
algo: c_algo.into(),
|
||||||
|
accepted_algo: CompressionAlgoPb::Zstd.into(),
|
||||||
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64);
|
let timeout_dur = std::time::Duration::from_millis(ctrl.timeout_ms() as u64);
|
||||||
let rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??;
|
let mut rpc_packet = timeout(timeout_dur, self.do_rpc(packets, &mut rx)).await??;
|
||||||
|
|
||||||
|
if let Some(compression_info) = rpc_packet.compression_info {
|
||||||
|
self.peer_info.insert(
|
||||||
|
self.to_peer_id,
|
||||||
|
PeerInfo {
|
||||||
|
peer_id: self.to_peer_id,
|
||||||
|
compression_info: compression_info.clone(),
|
||||||
|
last_active: Some(std::time::Instant::now()),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
rpc_packet.body =
|
||||||
|
decompress_packet(compression_info.algo(), &rpc_packet.body).await?;
|
||||||
|
}
|
||||||
|
|
||||||
assert_eq!(rpc_packet.transaction_id, transaction_id);
|
assert_eq!(rpc_packet.transaction_id, transaction_id);
|
||||||
|
|
||||||
|
@ -230,6 +290,7 @@ impl Client {
|
||||||
to_peer_id,
|
to_peer_id,
|
||||||
zc_packet_sender: self.mpsc.lock().unwrap().get_sink(),
|
zc_packet_sender: self.mpsc.lock().unwrap().get_sink(),
|
||||||
inflight_requests: self.inflight_requests.clone(),
|
inflight_requests: self.inflight_requests.clone(),
|
||||||
|
peer_info: self.peer_info.clone(),
|
||||||
_phan: PhantomData,
|
_phan: PhantomData,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,18 +1,44 @@
|
||||||
use prost::Message as _;
|
use prost::Message as _;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
common::PeerId,
|
common::{compressor::DefaultCompressor, PeerId},
|
||||||
proto::{
|
proto::{
|
||||||
common::{RpcDescriptor, RpcPacket},
|
common::{CompressionAlgoPb, RpcCompressionInfo, RpcDescriptor, RpcPacket},
|
||||||
rpc_types::error::Error,
|
rpc_types::error::Error,
|
||||||
},
|
},
|
||||||
tunnel::packet_def::{PacketType, ZCPacket},
|
tunnel::packet_def::{CompressorAlgo, PacketType, ZCPacket},
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::RpcTransactId;
|
use super::RpcTransactId;
|
||||||
|
|
||||||
const RPC_PACKET_CONTENT_MTU: usize = 1300;
|
const RPC_PACKET_CONTENT_MTU: usize = 1300;
|
||||||
|
|
||||||
|
pub async fn compress_packet(
|
||||||
|
accepted_compression_algo: CompressionAlgoPb,
|
||||||
|
content: &[u8],
|
||||||
|
) -> Result<(Vec<u8>, CompressionAlgoPb), Error> {
|
||||||
|
let compressor = DefaultCompressor::new();
|
||||||
|
let algo = accepted_compression_algo
|
||||||
|
.try_into()
|
||||||
|
.unwrap_or(CompressorAlgo::None);
|
||||||
|
let compressed = compressor.compress_raw(&content, algo).await?;
|
||||||
|
if compressed.len() >= content.len() {
|
||||||
|
Ok((content.to_vec(), CompressionAlgoPb::None))
|
||||||
|
} else {
|
||||||
|
Ok((compressed, algo.try_into().unwrap()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn decompress_packet(
|
||||||
|
compression_algo: CompressionAlgoPb,
|
||||||
|
content: &[u8],
|
||||||
|
) -> Result<Vec<u8>, Error> {
|
||||||
|
let compressor = DefaultCompressor::new();
|
||||||
|
let algo = compression_algo.try_into()?;
|
||||||
|
let decompressed = compressor.decompress_raw(&content, algo).await?;
|
||||||
|
Ok(decompressed)
|
||||||
|
}
|
||||||
|
|
||||||
pub struct PacketMerger {
|
pub struct PacketMerger {
|
||||||
first_piece: Option<RpcPacket>,
|
first_piece: Option<RpcPacket>,
|
||||||
pieces: Vec<RpcPacket>,
|
pieces: Vec<RpcPacket>,
|
||||||
|
@ -46,7 +72,8 @@ impl PacketMerger {
|
||||||
body.extend_from_slice(&p.body);
|
body.extend_from_slice(&p.body);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut tmpl_packet = self.first_piece.as_ref().unwrap().clone();
|
// only the first packet contains the complete info
|
||||||
|
let mut tmpl_packet = self.pieces[0].clone();
|
||||||
tmpl_packet.total_pieces = 1;
|
tmpl_packet.total_pieces = 1;
|
||||||
tmpl_packet.piece_idx = 0;
|
tmpl_packet.piece_idx = 0;
|
||||||
tmpl_packet.body = body;
|
tmpl_packet.body = body;
|
||||||
|
@ -58,17 +85,17 @@ impl PacketMerger {
|
||||||
let total_pieces = rpc_packet.total_pieces;
|
let total_pieces = rpc_packet.total_pieces;
|
||||||
let piece_idx = rpc_packet.piece_idx;
|
let piece_idx = rpc_packet.piece_idx;
|
||||||
|
|
||||||
if rpc_packet.descriptor.is_none() {
|
|
||||||
return Err(Error::MalformatRpcPacket(
|
|
||||||
"descriptor is missing".to_owned(),
|
|
||||||
));
|
|
||||||
}
|
|
||||||
|
|
||||||
// for compatibility with old version
|
// for compatibility with old version
|
||||||
if total_pieces == 0 && piece_idx == 0 {
|
if total_pieces == 0 && piece_idx == 0 {
|
||||||
return Ok(Some(rpc_packet));
|
return Ok(Some(rpc_packet));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rpc_packet.piece_idx == 0 && rpc_packet.descriptor.is_none() {
|
||||||
|
return Err(Error::MalformatRpcPacket(
|
||||||
|
"descriptor is missing".to_owned(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
// about 32MB max size
|
// about 32MB max size
|
||||||
if total_pieces > 32 * 1024 || total_pieces == 0 {
|
if total_pieces > 32 * 1024 || total_pieces == 0 {
|
||||||
return Err(Error::MalformatRpcPacket(format!(
|
return Err(Error::MalformatRpcPacket(format!(
|
||||||
|
@ -89,6 +116,7 @@ impl PacketMerger {
|
||||||
{
|
{
|
||||||
self.first_piece = Some(rpc_packet.clone());
|
self.first_piece = Some(rpc_packet.clone());
|
||||||
self.pieces.clear();
|
self.pieces.clear();
|
||||||
|
tracing::trace!(?rpc_packet, "got first piece");
|
||||||
}
|
}
|
||||||
|
|
||||||
self.pieces
|
self.pieces
|
||||||
|
@ -113,6 +141,7 @@ pub fn build_rpc_packet(
|
||||||
is_req: bool,
|
is_req: bool,
|
||||||
content: &Vec<u8>,
|
content: &Vec<u8>,
|
||||||
trace_id: i32,
|
trace_id: i32,
|
||||||
|
compression_info: RpcCompressionInfo,
|
||||||
) -> Vec<ZCPacket> {
|
) -> Vec<ZCPacket> {
|
||||||
let mut ret = Vec::new();
|
let mut ret = Vec::new();
|
||||||
let content_mtu = RPC_PACKET_CONTENT_MTU;
|
let content_mtu = RPC_PACKET_CONTENT_MTU;
|
||||||
|
@ -130,13 +159,22 @@ pub fn build_rpc_packet(
|
||||||
let cur_packet = RpcPacket {
|
let cur_packet = RpcPacket {
|
||||||
from_peer,
|
from_peer,
|
||||||
to_peer,
|
to_peer,
|
||||||
descriptor: Some(rpc_desc.clone()),
|
descriptor: if cur_offset == 0 {
|
||||||
|
Some(rpc_desc.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
is_request: is_req,
|
is_request: is_req,
|
||||||
total_pieces: total_pieces as u32,
|
total_pieces: total_pieces as u32,
|
||||||
piece_idx: (cur_offset / content_mtu) as u32,
|
piece_idx: (cur_offset / content_mtu) as u32,
|
||||||
transaction_id,
|
transaction_id,
|
||||||
body: cur_content,
|
body: cur_content,
|
||||||
trace_id,
|
trace_id,
|
||||||
|
compression_info: if cur_offset == 0 {
|
||||||
|
Some(compression_info.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
};
|
};
|
||||||
cur_offset += cur_len;
|
cur_offset += cur_len;
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,10 @@ use tokio_stream::StreamExt;
|
||||||
use crate::{
|
use crate::{
|
||||||
common::{join_joinset_background, PeerId},
|
common::{join_joinset_background, PeerId},
|
||||||
proto::{
|
proto::{
|
||||||
common::{self, RpcDescriptor, RpcPacket, RpcRequest, RpcResponse},
|
common::{
|
||||||
|
self, CompressionAlgoPb, RpcCompressionInfo, RpcPacket, RpcRequest,
|
||||||
|
RpcResponse,
|
||||||
|
},
|
||||||
rpc_types::error::Result,
|
rpc_types::error::Result,
|
||||||
},
|
},
|
||||||
tunnel::{
|
tunnel::{
|
||||||
|
@ -23,7 +26,7 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
packet::{build_rpc_packet, PacketMerger},
|
packet::{build_rpc_packet, compress_packet, decompress_packet, PacketMerger},
|
||||||
service_registry::ServiceRegistry,
|
service_registry::ServiceRegistry,
|
||||||
RpcController, Transport,
|
RpcController, Transport,
|
||||||
};
|
};
|
||||||
|
@ -31,7 +34,6 @@ use super::{
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||||
struct PacketMergerKey {
|
struct PacketMergerKey {
|
||||||
from_peer_id: PeerId,
|
from_peer_id: PeerId,
|
||||||
rpc_desc: RpcDescriptor,
|
|
||||||
transaction_id: i64,
|
transaction_id: i64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -108,10 +110,11 @@ impl Server {
|
||||||
|
|
||||||
let key = PacketMergerKey {
|
let key = PacketMergerKey {
|
||||||
from_peer_id: packet.from_peer,
|
from_peer_id: packet.from_peer,
|
||||||
rpc_desc: packet.descriptor.clone().unwrap_or_default(),
|
|
||||||
transaction_id: packet.transaction_id,
|
transaction_id: packet.transaction_id,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
tracing::trace!(?key, ?packet, "Received request packet");
|
||||||
|
|
||||||
let ret = packet_merges
|
let ret = packet_merges
|
||||||
.entry(key.clone())
|
.entry(key.clone())
|
||||||
.or_insert_with(PacketMerger::new)
|
.or_insert_with(PacketMerger::new)
|
||||||
|
@ -144,7 +147,16 @@ impl Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_rpc_request(packet: RpcPacket, reg: Arc<ServiceRegistry>) -> Result<Bytes> {
|
async fn handle_rpc_request(packet: RpcPacket, reg: Arc<ServiceRegistry>) -> Result<Bytes> {
|
||||||
let rpc_request = RpcRequest::decode(Bytes::from(packet.body))?;
|
let body = if let Some(compression_info) = packet.compression_info {
|
||||||
|
decompress_packet(
|
||||||
|
compression_info.algo.try_into().unwrap_or_default(),
|
||||||
|
&packet.body,
|
||||||
|
)
|
||||||
|
.await?
|
||||||
|
} else {
|
||||||
|
packet.body
|
||||||
|
};
|
||||||
|
let rpc_request = RpcRequest::decode(Bytes::from(body))?;
|
||||||
let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64);
|
let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64);
|
||||||
let ctrl = RpcController::default();
|
let ctrl = RpcController::default();
|
||||||
Ok(timeout(
|
Ok(timeout(
|
||||||
|
@ -168,6 +180,7 @@ impl Server {
|
||||||
let mut resp_msg = RpcResponse::default();
|
let mut resp_msg = RpcResponse::default();
|
||||||
let now = std::time::Instant::now();
|
let now = std::time::Instant::now();
|
||||||
|
|
||||||
|
let compression_info = packet.compression_info.clone();
|
||||||
let resp_bytes = Self::handle_rpc_request(packet, reg).await;
|
let resp_bytes = Self::handle_rpc_request(packet, reg).await;
|
||||||
|
|
||||||
match &resp_bytes {
|
match &resp_bytes {
|
||||||
|
@ -180,14 +193,25 @@ impl Server {
|
||||||
};
|
};
|
||||||
resp_msg.runtime_us = now.elapsed().as_micros() as u64;
|
resp_msg.runtime_us = now.elapsed().as_micros() as u64;
|
||||||
|
|
||||||
|
let (compressed_resp, algo) = compress_packet(
|
||||||
|
compression_info.unwrap_or_default().accepted_algo(),
|
||||||
|
&resp_msg.encode_to_vec(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let packets = build_rpc_packet(
|
let packets = build_rpc_packet(
|
||||||
to_peer,
|
to_peer,
|
||||||
from_peer,
|
from_peer,
|
||||||
desc,
|
desc,
|
||||||
transaction_id,
|
transaction_id,
|
||||||
false,
|
false,
|
||||||
&resp_msg.encode_to_vec(),
|
&compressed_resp,
|
||||||
trace_id,
|
trace_id,
|
||||||
|
RpcCompressionInfo {
|
||||||
|
algo: algo.into(),
|
||||||
|
accepted_algo: CompressionAlgoPb::Zstd.into(),
|
||||||
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
for packet in packets {
|
for packet in packets {
|
||||||
|
|
|
@ -107,6 +107,7 @@ fn random_string(len: usize) -> String {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn rpc_basic_test() {
|
async fn rpc_basic_test() {
|
||||||
|
// enable_log();
|
||||||
let ctx = TestContext::new();
|
let ctx = TestContext::new();
|
||||||
|
|
||||||
let server = GreetingServer::new(GreetingService {
|
let server = GreetingServer::new(GreetingService {
|
||||||
|
@ -119,7 +120,7 @@ async fn rpc_basic_test() {
|
||||||
.client
|
.client
|
||||||
.scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "".to_string());
|
.scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "".to_string());
|
||||||
|
|
||||||
// small size req and resp
|
// // small size req and resp
|
||||||
|
|
||||||
let ctrl = RpcController::default();
|
let ctrl = RpcController::default();
|
||||||
let input = SayHelloRequest {
|
let input = SayHelloRequest {
|
||||||
|
|
|
@ -603,7 +603,7 @@ pub mod tests {
|
||||||
|
|
||||||
pub fn enable_log() {
|
pub fn enable_log() {
|
||||||
let filter = tracing_subscriber::EnvFilter::builder()
|
let filter = tracing_subscriber::EnvFilter::builder()
|
||||||
.with_default_directive(tracing::level_filters::LevelFilter::DEBUG.into())
|
.with_default_directive(tracing::level_filters::LevelFilter::TRACE.into())
|
||||||
.from_env()
|
.from_env()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.add_directive("tarpc=error".parse().unwrap());
|
.add_directive("tarpc=error".parse().unwrap());
|
||||||
|
|
|
@ -7,6 +7,10 @@ use zerocopy::FromZeroes;
|
||||||
|
|
||||||
type DefaultEndian = LittleEndian;
|
type DefaultEndian = LittleEndian;
|
||||||
|
|
||||||
|
const fn max(a: usize, b: usize) -> usize {
|
||||||
|
[a, b][(a < b) as usize]
|
||||||
|
}
|
||||||
|
|
||||||
// TCP TunnelHeader
|
// TCP TunnelHeader
|
||||||
#[repr(C, packed)]
|
#[repr(C, packed)]
|
||||||
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||||
|
@ -49,11 +53,11 @@ pub enum PacketType {
|
||||||
Invalid = 0,
|
Invalid = 0,
|
||||||
Data = 1,
|
Data = 1,
|
||||||
HandShake = 2,
|
HandShake = 2,
|
||||||
RoutePacket = 3,
|
RoutePacket = 3, // deprecated
|
||||||
Ping = 4,
|
Ping = 4,
|
||||||
Pong = 5,
|
Pong = 5,
|
||||||
TaRpc = 6,
|
TaRpc = 6, // deprecated
|
||||||
Route = 7,
|
Route = 7, // deprecated
|
||||||
RpcReq = 8,
|
RpcReq = 8,
|
||||||
RpcResp = 9,
|
RpcResp = 9,
|
||||||
ForeignNetworkPacket = 10,
|
ForeignNetworkPacket = 10,
|
||||||
|
@ -65,6 +69,7 @@ bitflags::bitflags! {
|
||||||
const LATENCY_FIRST = 0b0000_0010;
|
const LATENCY_FIRST = 0b0000_0010;
|
||||||
const EXIT_NODE = 0b0000_0100;
|
const EXIT_NODE = 0b0000_0100;
|
||||||
const NO_PROXY = 0b0000_1000;
|
const NO_PROXY = 0b0000_1000;
|
||||||
|
const COMPRESSED = 0b0001_0000;
|
||||||
|
|
||||||
const _ = !0;
|
const _ = !0;
|
||||||
}
|
}
|
||||||
|
@ -118,6 +123,12 @@ impl PeerManagerHeader {
|
||||||
.contains(PeerManagerHeaderFlags::NO_PROXY)
|
.contains(PeerManagerHeaderFlags::NO_PROXY)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn is_compressed(&self) -> bool {
|
||||||
|
PeerManagerHeaderFlags::from_bits(self.flags)
|
||||||
|
.unwrap()
|
||||||
|
.contains(PeerManagerHeaderFlags::COMPRESSED)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn set_latency_first(&mut self, latency_first: bool) -> &mut Self {
|
pub fn set_latency_first(&mut self, latency_first: bool) -> &mut Self {
|
||||||
let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap();
|
let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap();
|
||||||
if latency_first {
|
if latency_first {
|
||||||
|
@ -150,6 +161,17 @@ impl PeerManagerHeader {
|
||||||
self.flags = flags.bits();
|
self.flags = flags.bits();
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn set_compressed(&mut self, compressed: bool) -> &mut Self {
|
||||||
|
let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap();
|
||||||
|
if compressed {
|
||||||
|
flags.insert(PeerManagerHeaderFlags::COMPRESSED);
|
||||||
|
} else {
|
||||||
|
flags.remove(PeerManagerHeaderFlags::COMPRESSED);
|
||||||
|
}
|
||||||
|
self.flags = flags.bits();
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[repr(C, packed)]
|
#[repr(C, packed)]
|
||||||
|
@ -201,12 +223,35 @@ pub struct AesGcmTail {
|
||||||
}
|
}
|
||||||
pub const AES_GCM_ENCRYPTION_RESERVED: usize = std::mem::size_of::<AesGcmTail>();
|
pub const AES_GCM_ENCRYPTION_RESERVED: usize = std::mem::size_of::<AesGcmTail>();
|
||||||
|
|
||||||
pub const TAIL_RESERVED_SIZE: usize = AES_GCM_ENCRYPTION_RESERVED;
|
#[derive(AsBytes, FromZeroes, Clone, Debug, Copy)]
|
||||||
|
#[repr(u8)]
|
||||||
const fn max(a: usize, b: usize) -> usize {
|
pub enum CompressorAlgo {
|
||||||
[a, b][(a < b) as usize]
|
None = 0,
|
||||||
|
ZstdDefault = 1,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[repr(C, packed)]
|
||||||
|
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
|
||||||
|
pub struct CompressorTail {
|
||||||
|
pub algo: u8,
|
||||||
|
}
|
||||||
|
pub const COMPRESSOR_TAIL_SIZE: usize = std::mem::size_of::<CompressorTail>();
|
||||||
|
|
||||||
|
impl CompressorTail {
|
||||||
|
pub fn get_algo(&self) -> Option<CompressorAlgo> {
|
||||||
|
match self.algo {
|
||||||
|
1 => Some(CompressorAlgo::ZstdDefault),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(algo: CompressorAlgo) -> Self {
|
||||||
|
Self { algo: algo as u8 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const TAIL_RESERVED_SIZE: usize = max(AES_GCM_ENCRYPTION_RESERVED, COMPRESSOR_TAIL_SIZE);
|
||||||
|
|
||||||
#[derive(Default, Debug)]
|
#[derive(Default, Debug)]
|
||||||
pub struct ZCPacketOffsets {
|
pub struct ZCPacketOffsets {
|
||||||
pub payload_offset: usize,
|
pub payload_offset: usize,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user