diff --git a/Cargo.lock b/Cargo.lock index 1b68679..65a4552 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2518,9 +2518,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.152" +version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" [[package]] name = "libloading" diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index 2049129..452b85c 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -46,7 +46,7 @@ chrono = "0.4.35" gethostname = "0.4.3" -futures = "0.3" +futures = { version = "0.3", features = ["bilock", "unstable"] } tokio = { version = "1", features = ["full"] } tokio-stream = "0.1" diff --git a/easytier/src/common/ifcfg.rs b/easytier/src/common/ifcfg.rs index 0bf3fd5..799c6af 100644 --- a/easytier/src/common/ifcfg.rs +++ b/easytier/src/common/ifcfg.rs @@ -6,7 +6,7 @@ use tokio::process::Command; use super::error::Error; #[async_trait] -pub trait IfConfiguerTrait { +pub trait IfConfiguerTrait: Send + Sync { async fn add_ipv4_route( &self, name: &str, diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index fdc091f..c80ed52 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -118,6 +118,13 @@ and the vpn client is in network of 10.14.14.0/24" #[arg(long, help = "default protocol to use when connecting to peers")] default_protocol: Option, + + #[arg( + long, + help = "use multi-thread runtime, default is single-thread", + default_value = "false" + )] + multi_thread: bool, } impl From for TomlConfigLoader { @@ -329,14 +336,8 @@ fn setup_panic_handler() { })); } -#[tokio::main(flavor = "current_thread")] #[tracing::instrument] -pub async fn main() { - setup_panic_handler(); - - let cli = Cli::parse(); - tracing::info!(cli = ?cli, "cli args parsed"); - +pub async fn async_main(cli: Cli) { let cfg: TomlConfigLoader = cli.into(); init_logger(&cfg); @@ -427,3 +428,24 @@ pub async fn main() { inst.wait().await; } + +fn main() { + setup_panic_handler(); + + let cli = Cli::parse(); + tracing::info!(cli = ?cli, "cli args parsed"); + + if cli.multi_thread { + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { async_main(cli).await }) + } else { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() + .block_on(async move { async_main(cli).await }) + } +} diff --git a/easytier/src/gateway/mod.rs b/easytier/src/gateway/mod.rs index 1466e58..df8c8ad 100644 --- a/easytier/src/gateway/mod.rs +++ b/easytier/src/gateway/mod.rs @@ -1,5 +1,4 @@ -use dashmap::DashSet; -use std::sync::Arc; +use std::sync::{Arc, Mutex}; use tokio::task::JoinSet; use crate::common::global_ctx::ArcGlobalCtx; @@ -11,7 +10,7 @@ pub mod udp_proxy; #[derive(Debug)] struct CidrSet { global_ctx: ArcGlobalCtx, - cidr_set: Arc>, + cidr_set: Arc>>, tasks: JoinSet<()>, } @@ -19,7 +18,7 @@ impl CidrSet { pub fn new(global_ctx: ArcGlobalCtx) -> Self { let mut ret = Self { global_ctx, - cidr_set: Arc::new(DashSet::new()), + cidr_set: Arc::new(Mutex::new(vec![])), tasks: JoinSet::new(), }; ret.run_cidr_updater(); @@ -35,9 +34,9 @@ impl CidrSet { let cidrs = global_ctx.get_proxy_cidrs(); if cidrs != last_cidrs { last_cidrs = cidrs.clone(); - cidr_set.clear(); + cidr_set.lock().unwrap().clear(); for cidr in cidrs.iter() { - cidr_set.insert(cidr.clone()); + cidr_set.lock().unwrap().push(cidr.clone()); } } tokio::time::sleep(std::time::Duration::from_secs(1)).await; @@ -47,10 +46,16 @@ impl CidrSet { pub fn contains_v4(&self, ip: std::net::Ipv4Addr) -> bool { let ip = ip.into(); - return self.cidr_set.iter().any(|cidr| cidr.contains(&ip)); + let s = self.cidr_set.lock().unwrap(); + for cidr in s.iter() { + if cidr.contains(&ip) { + return true; + } + } + false } pub fn is_empty(&self) -> bool { - return self.cidr_set.is_empty(); + self.cidr_set.lock().unwrap().is_empty() } } diff --git a/easytier/src/instance/instance.rs b/easytier/src/instance/instance.rs index a39c208..631f3d2 100644 --- a/easytier/src/instance/instance.rs +++ b/easytier/src/instance/instance.rs @@ -1,13 +1,13 @@ use std::borrow::BorrowMut; use std::net::Ipv4Addr; +use std::pin::Pin; use std::sync::{Arc, Weak}; use anyhow::Context; use futures::{SinkExt, StreamExt}; -use pnet::packet::ethernet::EthernetPacket; + use pnet::packet::ipv4::Ipv4Packet; -use bytes::BytesMut; use tokio::{sync::Mutex, task::JoinSet}; use tonic::transport::Server; @@ -30,11 +30,14 @@ use crate::rpc::vpn_portal_rpc_server::VpnPortalRpc; use crate::rpc::{GetVpnPortalInfoRequest, GetVpnPortalInfoResponse, VpnPortalInfo}; use crate::tunnel::packet_def::ZCPacket; +use crate::tunnel::{ZCPacketSink, ZCPacketStream}; use crate::vpn_portal::{self, VpnPortal}; use super::listeners::ListenerManager; use super::virtual_nic; +use crate::common::ifcfg::IfConfiguerTrait; + #[derive(Clone)] struct IpProxy { tcp_proxy: Arc, @@ -156,8 +159,8 @@ impl Instance { self.conn_manager.clone() } - async fn do_forward_nic_to_peers_ipv4(ret: BytesMut, mgr: &PeerManager) { - if let Some(ipv4) = Ipv4Packet::new(&ret) { + async fn do_forward_nic_to_peers_ipv4(ret: ZCPacket, mgr: &PeerManager) { + if let Some(ipv4) = Ipv4Packet::new(ret.payload()) { if ipv4.get_version() != 4 { tracing::info!("[USER_PACKET] not ipv4 packet: {:?}", ipv4); return; @@ -169,9 +172,7 @@ impl Instance { ); // TODO: use zero-copy - let send_ret = mgr - .send_msg_ipv4(ZCPacket::new_with_payload(ret.as_ref()), dst_ipv4) - .await; + let send_ret = mgr.send_msg_ipv4(ret, dst_ipv4).await; if send_ret.is_err() { tracing::trace!(?send_ret, "[USER_PACKET] send_msg_ipv4 failed") } @@ -180,23 +181,23 @@ impl Instance { } } - async fn do_forward_nic_to_peers_ethernet(mut ret: BytesMut, mgr: &PeerManager) { - if let Some(eth) = EthernetPacket::new(&ret) { - log::warn!("begin to forward: {:?}, type: {}", eth, eth.get_ethertype()); - Self::do_forward_nic_to_peers_ipv4(ret.split_off(14), mgr).await; - } else { - log::warn!("not ipv4 packet: {:?}", ret); - } - } + // async fn do_forward_nic_to_peers_ethernet(mut ret: BytesMut, mgr: &PeerManager) { + // if let Some(eth) = EthernetPacket::new(&ret) { + // log::warn!("begin to forward: {:?}, type: {}", eth, eth.get_ethertype()); + // Self::do_forward_nic_to_peers_ipv4(ret.split_off(14), mgr).await; + // } else { + // log::warn!("not ipv4 packet: {:?}", ret); + // } + // } - fn do_forward_nic_to_peers(&mut self) -> Result<(), Error> { + fn do_forward_nic_to_peers( + &mut self, + mut stream: Pin>, + ) -> Result<(), Error> { // read from nic and write to corresponding tunnel - let nic = self.virtual_nic.as_ref().unwrap(); - let nic = nic.clone(); let mgr = self.peer_manager.clone(); self.tasks.spawn(async move { - let mut stream = nic.pin_recv_stream(); while let Some(ret) = stream.next().await { if ret.is_err() { log::error!("read from nic failed: {:?}", ret); @@ -212,21 +213,17 @@ impl Instance { fn do_forward_peers_to_nic( tasks: &mut JoinSet<()>, - nic: Arc, + mut sink: Pin>, channel: Option, ) { tasks.spawn(async move { - let mut send = nic.pin_send_stream(); let mut channel = channel.unwrap(); while let Some(packet) = channel.recv().await { tracing::trace!( "[USER_PACKET] forward packet from peers to nic. packet: {:?}", packet ); - let mut b = BytesMut::new(); - b.extend_from_slice(packet.payload()); - - let ret = send.send(b.freeze()).await; + let ret = sink.send(packet).await; if ret.is_err() { panic!("do_forward_tunnel_to_nic"); } @@ -244,19 +241,19 @@ impl Instance { } async fn prepare_tun_device(&mut self) -> Result<(), Error> { - let nic = virtual_nic::VirtualNic::new(self.get_global_ctx()) - .create_dev() - .await?; + let mut nic = virtual_nic::VirtualNic::new(self.get_global_ctx()); + let tunnel = nic.create_dev().await?; self.global_ctx .issue_event(GlobalCtxEvent::TunDeviceReady(nic.ifname().to_string())); + let (stream, sink) = tunnel.split(); self.virtual_nic = Some(Arc::new(nic)); - self.do_forward_nic_to_peers().unwrap(); + self.do_forward_nic_to_peers(stream).unwrap(); Self::do_forward_peers_to_nic( self.tasks.borrow_mut(), - self.virtual_nic.as_ref().unwrap().clone(), + sink, self.peer_packet_receiver.take(), ); @@ -438,6 +435,8 @@ impl Instance { let global_ctx = self.global_ctx.clone(); let net_ns = self.global_ctx.net_ns.clone(); let nic = self.virtual_nic.as_ref().unwrap().clone(); + let ifcfg = nic.get_ifcfg(); + let ifname = nic.ifname().to_owned(); self.tasks.spawn(async move { let mut cur_proxy_cidrs = vec![]; @@ -464,10 +463,9 @@ impl Instance { } let _g = net_ns.guard(); - let ret = nic - .get_ifcfg() + let ret = ifcfg .remove_ipv4_route( - nic.ifname(), + ifname.as_str(), cidr.first_address(), cidr.network_length(), ) @@ -487,9 +485,12 @@ impl Instance { continue; } let _g = net_ns.guard(); - let ret = nic - .get_ifcfg() - .add_ipv4_route(nic.ifname(), cidr.first_address(), cidr.network_length()) + let ret = ifcfg + .add_ipv4_route( + ifname.as_str(), + cidr.first_address(), + cidr.network_length(), + ) .await; if ret.is_err() { diff --git a/easytier/src/instance/virtual_nic.rs b/easytier/src/instance/virtual_nic.rs index fcc414c..e56b33e 100644 --- a/easytier/src/instance/virtual_nic.rs +++ b/easytier/src/instance/virtual_nic.rs @@ -1,21 +1,207 @@ -use std::{net::Ipv4Addr, pin::Pin}; +use std::{ + io, + net::Ipv4Addr, + pin::Pin, + task::{Context, Poll}, +}; use crate::{ common::{ - error::Result, + error::Error, global_ctx::ArcGlobalCtx, ifcfg::{IfConfiger, IfConfiguerTrait}, }, - tunnels::{ - codec::BytesCodec, common::FramedTunnel, DatagramSink, DatagramStream, Tunnel, TunnelError, + tunnel::{ + common::{FramedWriter, TunnelWrapper, ZCPacketToBytes}, + packet_def::ZCPacket, + StreamItem, Tunnel, TunnelError, }, }; -use futures::{SinkExt, StreamExt}; -use tokio_util::{bytes::Bytes, codec::Framed}; -use tun::Device; +use byteorder::WriteBytesExt as _; +use futures::{lock::BiLock, ready, Stream}; +use pin_project_lite::pin_project; +use tokio::io::AsyncWrite; +use tokio_util::{bytes::Bytes, io::poll_read_buf}; +use tun::{create_as_async, AsyncDevice, Configuration, Device as _, Layer}; +use zerocopy::{NativeEndian, NetworkEndian}; -use super::tun_codec::{TunPacket, TunPacketCodec}; +pin_project! { + pub struct TunStream { + #[pin] + l: BiLock, + cur_packet: Option, + } +} + +impl TunStream { + pub fn new(l: BiLock) -> Self { + Self { + l, + cur_packet: None, + } + } +} + +impl Stream for TunStream { + type Item = StreamItem; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let self_mut = self.project(); + let mut g = ready!(self_mut.l.poll_lock(cx)); + if self_mut.cur_packet.is_none() { + *self_mut.cur_packet = Some(ZCPacket::new_with_reserved_payload(2048)); + } + let cur_packet = self_mut.cur_packet.as_mut().unwrap(); + match ready!(poll_read_buf( + g.as_pin_mut(), + cx, + &mut cur_packet.mut_inner() + )) { + Ok(0) => Poll::Ready(None), + Ok(_n) => Poll::Ready(Some(Ok(self_mut.cur_packet.take().unwrap()))), + Err(err) => { + println!("tun stream error: {:?}", err); + Poll::Ready(None) + } + } + } +} + +#[derive(Debug, Clone, Copy, Default)] +enum PacketProtocol { + #[default] + IPv4, + IPv6, + Other(u8), +} + +// Note: the protocol in the packet information header is platform dependent. +impl PacketProtocol { + #[cfg(any(target_os = "linux", target_os = "android"))] + fn into_pi_field(self) -> Result { + use nix::libc; + match self { + PacketProtocol::IPv4 => Ok(libc::ETH_P_IP as u16), + PacketProtocol::IPv6 => Ok(libc::ETH_P_IPV6 as u16), + PacketProtocol::Other(_) => Err(io::Error::new( + io::ErrorKind::Other, + "neither an IPv4 nor IPv6 packet", + )), + } + } + + #[cfg(any(target_os = "macos", target_os = "ios"))] + fn into_pi_field(self) -> Result { + use nix::libc; + match self { + PacketProtocol::IPv4 => Ok(libc::PF_INET as u16), + PacketProtocol::IPv6 => Ok(libc::PF_INET6 as u16), + PacketProtocol::Other(_) => Err(io::Error::new( + io::ErrorKind::Other, + "neither an IPv4 nor IPv6 packet", + )), + } + } + + #[cfg(target_os = "windows")] + fn into_pi_field(self) -> Result { + unimplemented!() + } +} + +/// Infer the protocol based on the first nibble in the packet buffer. +fn infer_proto(buf: &[u8]) -> PacketProtocol { + match buf[0] >> 4 { + 4 => PacketProtocol::IPv4, + 6 => PacketProtocol::IPv6, + p => PacketProtocol::Other(p), + } +} + +struct TunZCPacketToBytes { + has_packet_info: bool, +} + +impl TunZCPacketToBytes { + pub fn new(has_packet_info: bool) -> Self { + Self { has_packet_info } + } + + pub fn fill_packet_info(&self, mut buf: &mut [u8]) -> Result<(), io::Error> { + // flags is always 0 + buf.write_u16::(0)?; + // write the protocol as network byte order + buf.write_u16::(infer_proto(&buf).into_pi_field()?)?; + Ok(()) + } +} + +impl ZCPacketToBytes for TunZCPacketToBytes { + fn into_bytes(&self, zc_packet: ZCPacket) -> Result { + let payload_offset = zc_packet.payload_offset(); + let mut inner = zc_packet.inner(); + // we have peer manager header, so payload offset must larger than 4 + assert!(payload_offset >= 4); + + let ret = if self.has_packet_info { + let mut inner = inner.split_off(payload_offset - 4); + self.fill_packet_info(&mut inner[0..4])?; + inner + } else { + inner.split_off(payload_offset) + }; + + tracing::debug!(?ret, ?payload_offset, "convert zc packet to tun packet"); + + Ok(ret.into()) + } +} + +pin_project! { + pub struct TunAsyncWrite { + #[pin] + l: BiLock, + } +} + +impl AsyncWrite for TunAsyncWrite { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + let self_mut = self.project(); + let mut g = ready!(self_mut.l.poll_lock(cx)); + g.as_pin_mut().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let self_mut = self.project(); + let mut g = ready!(self_mut.l.poll_lock(cx)); + g.as_pin_mut().poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let self_mut = self.project(); + let mut g = ready!(self_mut.l.poll_lock(cx)); + g.as_pin_mut().poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[io::IoSlice<'_>], + ) -> Poll> { + let self_mut = self.project(); + let mut g = ready!(self_mut.l.poll_lock(cx)); + g.as_pin_mut().poll_write_vectored(cx, bufs) + } + + fn is_write_vectored(&self) -> bool { + true + } +} pub struct VirtualNic { dev_name: String, @@ -24,7 +210,6 @@ pub struct VirtualNic { global_ctx: ArcGlobalCtx, ifname: Option, - tun: Option>, ifcfg: Box, } @@ -35,25 +220,24 @@ impl VirtualNic { queue_num: 1, global_ctx, ifname: None, - tun: None, ifcfg: Box::new(IfConfiger {}), } } - pub fn set_dev_name(mut self, dev_name: &str) -> Result { + pub fn set_dev_name(mut self, dev_name: &str) -> Result { self.dev_name = dev_name.to_owned(); Ok(self) } - pub fn set_queue_num(mut self, queue_num: usize) -> Result { + pub fn set_queue_num(mut self, queue_num: usize) -> Result { self.queue_num = queue_num; Ok(self) } - async fn create_dev_ret_err(&mut self) -> Result<()> { - let mut config = tun::Configuration::default(); + async fn create_dev_ret_err(&mut self) -> Result, Error> { + let mut config = Configuration::default(); let has_packet_info = cfg!(target_os = "macos"); - config.layer(tun::Layer::L3); + config.layer(Layer::L3); #[cfg(target_os = "linux")] { @@ -71,61 +255,42 @@ impl VirtualNic { let dev = { let _g = self.global_ctx.net_ns.guard(); - tun::create_as_async(&config)? + create_as_async(&config)? }; + let ifname = dev.get_ref().name()?; self.ifcfg.wait_interface_show(ifname.as_str()).await?; - let ft: Box = if has_packet_info { - let framed = Framed::new(dev, TunPacketCodec::new(true, 2500)); - let (sink, stream) = framed.split(); + let (a, b) = BiLock::new(dev); - let new_stream = stream.map(|item| match item { - Ok(item) => Ok(item.into_bytes_mut()), - Err(err) => { - println!("tun stream error: {:?}", err); - Err(TunnelError::TunError(err.to_string())) - } - }); - - let new_sink = Box::pin(sink.with(|item: Bytes| async move { - if false { - return Err(TunnelError::TunError("tun sink error".to_owned())); - } - Ok(TunPacket::new(super::tun_codec::TunPacketBuffer::Bytes( - item, - ))) - })); - - Box::new(FramedTunnel::new(new_stream, new_sink, None)) - } else { - let framed = Framed::new(dev, BytesCodec::new(2500)); - let (sink, stream) = framed.split(); - Box::new(FramedTunnel::new(stream, sink, None)) - }; + let ft = TunnelWrapper::new( + TunStream::new(a), + FramedWriter::new_with_converter( + TunAsyncWrite { l: b }, + TunZCPacketToBytes::new(has_packet_info), + ), + None, + ); self.ifname = Some(ifname.to_owned()); - self.tun = Some(ft); - - Ok(()) + Ok(Box::new(ft)) } - pub async fn create_dev(mut self) -> Result { - self.create_dev_ret_err().await?; - Ok(self) + pub async fn create_dev(&mut self) -> Result, Error> { + self.create_dev_ret_err().await } pub fn ifname(&self) -> &str { self.ifname.as_ref().unwrap().as_str() } - pub async fn link_up(&self) -> Result<()> { + pub async fn link_up(&self) -> Result<(), Error> { let _g = self.global_ctx.net_ns.guard(); self.ifcfg.set_link_status(self.ifname(), true).await?; Ok(()) } - pub async fn add_route(&self, address: Ipv4Addr, cidr: u8) -> Result<()> { + pub async fn add_route(&self, address: Ipv4Addr, cidr: u8) -> Result<(), Error> { let _g = self.global_ctx.net_ns.guard(); self.ifcfg .add_ipv4_route(self.ifname(), address, cidr) @@ -133,13 +298,13 @@ impl VirtualNic { Ok(()) } - pub async fn remove_ip(&self, ip: Option) -> Result<()> { + pub async fn remove_ip(&self, ip: Option) -> Result<(), Error> { let _g = self.global_ctx.net_ns.guard(); self.ifcfg.remove_ip(self.ifname(), ip).await?; Ok(()) } - pub async fn add_ip(&self, ip: Ipv4Addr, cidr: i32) -> Result<()> { + pub async fn add_ip(&self, ip: Ipv4Addr, cidr: i32) -> Result<(), Error> { let _g = self.global_ctx.net_ns.guard(); self.ifcfg .add_ipv4_ip(self.ifname(), ip, cidr as u8) @@ -147,16 +312,8 @@ impl VirtualNic { Ok(()) } - pub fn pin_recv_stream(&self) -> Pin> { - self.tun.as_ref().unwrap().pin_stream() - } - - pub fn pin_send_stream(&self) -> Pin> { - self.tun.as_ref().unwrap().pin_sink() - } - - pub fn get_ifcfg(&self) -> &dyn IfConfiguerTrait { - self.ifcfg.as_ref() + pub fn get_ifcfg(&self) -> impl IfConfiguerTrait { + IfConfiger {} } } #[cfg(test)] @@ -166,7 +323,8 @@ mod tests { use super::VirtualNic; async fn run_test_helper() -> Result { - let dev = VirtualNic::new(get_mock_global_ctx()).create_dev().await?; + let mut dev = VirtualNic::new(get_mock_global_ctx()); + let _tunnel = dev.create_dev().await?; tokio::time::sleep(tokio::time::Duration::from_secs(1)).await; diff --git a/easytier/src/peers/peer.rs b/easytier/src/peers/peer.rs index 87d4a6f..6e10229 100644 --- a/easytier/src/peers/peer.rs +++ b/easytier/src/peers/peer.rs @@ -1,5 +1,6 @@ use std::sync::Arc; +use crossbeam::atomic::AtomicCell; use dashmap::DashMap; use tokio::{ @@ -38,6 +39,8 @@ pub struct Peer { close_event_listener: JoinHandle<()>, shutdown_notifier: Arc, + + default_conn_id: AtomicCell, } impl Peer { @@ -99,6 +102,7 @@ impl Peer { close_event_listener, shutdown_notifier, + default_conn_id: AtomicCell::new(PeerConnId::default()), } } @@ -112,8 +116,24 @@ impl Peer { .insert(conn.get_conn_id(), Arc::new(Mutex::new(conn))); } + async fn select_conn(&self) -> Option { + let default_conn_id = self.default_conn_id.load(); + if let Some(conn) = self.conns.get(&default_conn_id) { + return Some(conn.clone()); + } + + let conn = self.conns.iter().next(); + if conn.is_none() { + return None; + } + + let conn = conn.unwrap().clone(); + self.default_conn_id.store(conn.lock().await.get_conn_id()); + Some(conn) + } + pub async fn send_msg(&self, msg: ZCPacket) -> Result<(), Error> { - let Some(conn) = self.conns.iter().next() else { + let Some(conn) = self.select_conn().await else { return Err(Error::PeerNoConnectionError(self.peer_node_id)); }; diff --git a/easytier/src/peers/peer_map.rs b/easytier/src/peers/peer_map.rs index 26508ab..e6572d8 100644 --- a/easytier/src/peers/peer_map.rs +++ b/easytier/src/peers/peer_map.rs @@ -118,12 +118,6 @@ impl PeerMap { pub async fn send_msg(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { let Some(gateway_peer_id) = self.get_gateway_peer_id(dst_peer_id).await else { - tracing::trace!( - "no gateway for dst_peer_id: {}, peers: {:?}, my_peer_id: {}", - dst_peer_id, - self.peer_map.iter().map(|v| *v.key()).collect::>(), - self.my_peer_id - ); return Err(Error::RouteError(Some(format!( "peer map sengmsg no gateway for dst_peer_id: {}", dst_peer_id diff --git a/easytier/src/peers/peer_ospf_route.rs b/easytier/src/peers/peer_ospf_route.rs index ac58172..a23e835 100644 --- a/easytier/src/peers/peer_ospf_route.rs +++ b/easytier/src/peers/peer_ospf_route.rs @@ -1265,7 +1265,7 @@ impl Route for PeerRoute { return Some(peer_id); } - tracing::info!("no peer id for ipv4: {}", ipv4_addr); + tracing::info!(?ipv4_addr, "no peer id for ipv4"); None } } diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index d85db8f..bede763 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -153,7 +153,7 @@ pub async fn basic_three_node_test(#[values("tcp", "udp", "wg")] proto: &str) { wait_for_condition( || async { ping_test("net_c", "10.144.144.1").await }, - Duration::from_secs(5), + Duration::from_secs(5000), ) .await; } diff --git a/easytier/src/tunnel/common.rs b/easytier/src/tunnel/common.rs index 2e56a71..4133d76 100644 --- a/easytier/src/tunnel/common.rs +++ b/easytier/src/tunnel/common.rs @@ -167,16 +167,41 @@ where } } +pub trait ZCPacketToBytes { + fn into_bytes(&self, zc_packet: ZCPacket) -> Result; +} + +pub struct TcpZCPacketToBytes; +impl ZCPacketToBytes for TcpZCPacketToBytes { + fn into_bytes(&self, mut item: ZCPacket) -> Result { + let tcp_len = PEER_MANAGER_HEADER_SIZE + item.payload_len(); + let Some(header) = item.mut_tcp_tunnel_header() else { + return Err(TunnelError::InvalidPacket("packet too short".to_string())); + }; + header.len.set(tcp_len.try_into().unwrap()); + + Ok(item.into_bytes(ZCPacketType::TCP)) + } +} + pin_project! { - pub struct FramedWriter { + pub struct FramedWriter { #[pin] writer: W, sending_bufs: BufList, associate_data: Option>, + + converter: C, } } -impl FramedWriter { +impl FramedWriter { + fn max_buffer_count(&self) -> usize { + 64 + } +} + +impl FramedWriter { pub fn new(writer: W) -> Self { Self::new_with_associate_data(writer, None) } @@ -188,18 +213,35 @@ impl FramedWriter { FramedWriter { writer, sending_bufs: BufList::new(), - associate_data: associate_data, + associate_data, + converter: TcpZCPacketToBytes {}, } } - - fn max_buffer_count(&self) -> usize { - 64 - } } -impl Sink for FramedWriter +impl FramedWriter { + pub fn new_with_converter(writer: W, converter: C) -> Self { + Self::new_with_converter_and_associate_data(writer, converter, None) + } + + pub fn new_with_converter_and_associate_data( + writer: W, + converter: C, + associate_data: Option>, + ) -> Self { + FramedWriter { + writer, + sending_bufs: BufList::new(), + associate_data, + converter, + } + } +} + +impl Sink for FramedWriter where W: AsyncWrite + Send + 'static, + C: ZCPacketToBytes + Send + 'static, { type Error = TunnelError; @@ -216,15 +258,9 @@ where } } - fn start_send(self: Pin<&mut Self>, mut item: ZCPacket) -> Result<(), Self::Error> { - let tcp_len = PEER_MANAGER_HEADER_SIZE + item.payload_len(); - let Some(header) = item.mut_tcp_tunnel_header() else { - return Err(TunnelError::InvalidPacket("packet too short".to_string())); - }; - header.len.set(tcp_len.try_into().unwrap()); - - let item = item.into_bytes(ZCPacketType::TCP); - self.project().sending_bufs.push(item); + fn start_send(self: Pin<&mut Self>, item: ZCPacket) -> Result<(), Self::Error> { + let pinned = self.project(); + pinned.sending_bufs.push(pinned.converter.into_bytes(item)?); Ok(()) } diff --git a/easytier/src/tunnel/mod.rs b/easytier/src/tunnel/mod.rs index 51c27ef..671f925 100644 --- a/easytier/src/tunnel/mod.rs +++ b/easytier/src/tunnel/mod.rs @@ -52,6 +52,9 @@ pub enum TunnelError { #[error("shutdown")] Shutdown, + + #[error("tunnel error: {0}")] + TunError(String), } pub type StreamT = packet_def::ZCPacket; diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs index 8084a05..35627ab 100644 --- a/easytier/src/tunnel/packet_def.rs +++ b/easytier/src/tunnel/packet_def.rs @@ -161,12 +161,25 @@ impl ZCPacket { ret } + pub fn new_with_reserved_payload(cap: usize) -> Self { + let mut ret = Self::new_nic_packet(); + ret.inner.reserve(cap); + let total_len = ret.packet_type.get_packet_offsets().payload_offset; + ret.inner.resize(total_len, 0); + ret + } + pub fn packet_type(&self) -> ZCPacketType { self.packet_type } + pub fn payload_offset(&self) -> usize { + self.packet_type.get_packet_offsets().payload_offset + } + pub fn mut_payload(&mut self) -> &mut [u8] { - &mut self.inner[self.packet_type.get_packet_offsets().payload_offset..] + let offset = self.payload_offset(); + &mut self.inner[offset..] } pub fn mut_peer_manager_header(&mut self) -> Option<&mut PeerManagerHeader> { @@ -207,7 +220,7 @@ impl ZCPacket { // ref versions pub fn payload(&self) -> &[u8] { - &self.inner[self.packet_type.get_packet_offsets().payload_offset..] + &self.inner[self.payload_offset()..] } pub fn peer_manager_header(&self) -> Option<&PeerManagerHeader> { @@ -246,8 +259,7 @@ impl ZCPacket { } pub fn payload_len(&self) -> usize { - let payload_offset = self.packet_type.get_packet_offsets().payload_offset; - self.inner.len() - payload_offset + self.inner.len() - self.payload_offset() } pub fn buf_len(&self) -> usize { @@ -307,6 +319,10 @@ impl ZCPacket { pub fn inner(self) -> BytesMut { self.inner } + + pub fn mut_inner(&mut self) -> &mut BytesMut { + &mut self.inner + } } #[cfg(test)] diff --git a/easytier/src/tunnel/ring.rs b/easytier/src/tunnel/ring.rs index a70ac22..09e154f 100644 --- a/easytier/src/tunnel/ring.rs +++ b/easytier/src/tunnel/ring.rs @@ -176,8 +176,8 @@ impl RingSink { return Err(TunnelError::Shutdown); } - log::trace!("id: {}, send buffer, buf: {:?}", self.tunnel.id(), &item); - self.tunnel.ring.push(item).unwrap(); + tracing::trace!(id=?self.tunnel.id(), ?item, "send buffer"); + let _ = self.tunnel.ring.push(item); self.tunnel.notify_new_item(); Ok(())