nat4-nat4 punch (#388)
Some checks are pending
EasyTier Core / pre_job (push) Waiting to run
EasyTier Core / build (freebsd-13.2-x86_64, 13.2, ubuntu-latest, x86_64-unknown-freebsd) (push) Blocked by required conditions
EasyTier Core / build (linux-aarch64, ubuntu-latest, aarch64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (linux-arm, ubuntu-latest, arm-unknown-linux-musleabi) (push) Blocked by required conditions
EasyTier Core / build (linux-armhf, ubuntu-latest, arm-unknown-linux-musleabihf) (push) Blocked by required conditions
EasyTier Core / build (linux-armv7, ubuntu-latest, armv7-unknown-linux-musleabi) (push) Blocked by required conditions
EasyTier Core / build (linux-armv7hf, ubuntu-latest, armv7-unknown-linux-musleabihf) (push) Blocked by required conditions
EasyTier Core / build (linux-mips, ubuntu-latest, mips-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (linux-mipsel, ubuntu-latest, mipsel-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (linux-x86_64, ubuntu-latest, x86_64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (macos-aarch64, macos-latest, aarch64-apple-darwin) (push) Blocked by required conditions
EasyTier Core / build (macos-x86_64, macos-latest, x86_64-apple-darwin) (push) Blocked by required conditions
EasyTier Core / build (windows-x86_64, windows-latest, x86_64-pc-windows-msvc) (push) Blocked by required conditions
EasyTier Core / core-result (push) Blocked by required conditions
EasyTier GUI / pre_job (push) Waiting to run
EasyTier GUI / build-gui (linux-aarch64, aarch64-unknown-linux-gnu, ubuntu-latest, aarch64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier GUI / build-gui (linux-x86_64, x86_64-unknown-linux-gnu, ubuntu-latest, x86_64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier GUI / build-gui (macos-aarch64, aarch64-apple-darwin, macos-latest, aarch64-apple-darwin) (push) Blocked by required conditions
EasyTier GUI / build-gui (macos-x86_64, x86_64-apple-darwin, macos-latest, x86_64-apple-darwin) (push) Blocked by required conditions
EasyTier GUI / build-gui (windows-x86_64, x86_64-pc-windows-msvc, windows-latest, x86_64-pc-windows-msvc) (push) Blocked by required conditions
EasyTier GUI / gui-result (push) Blocked by required conditions
EasyTier Mobile / pre_job (push) Waiting to run
EasyTier Mobile / build-mobile (android, ubuntu-latest, android) (push) Blocked by required conditions
EasyTier Mobile / mobile-result (push) Blocked by required conditions
EasyTier Test / pre_job (push) Waiting to run
EasyTier Test / test (push) Blocked by required conditions

this patch optimize the udp hole punch logic:

1. allow start punch hole before stun test complete.
2. add lock to symmetric punch, avoid conflict between concurrent hole punching task.
3. support punching hole for predictable nat4-nat4.
4. make backoff of retry reasonable
This commit is contained in:
Sijie.Sun 2024-10-06 22:49:18 +08:00 committed by GitHub
parent ba3da97ad4
commit 37ceb77bf6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 2748 additions and 1310 deletions

View File

@ -343,6 +343,8 @@ impl StunClientBuilder {
pub struct UdpNatTypeDetectResult {
source_addr: SocketAddr,
stun_resps: Vec<BindRequestResponse>,
// if we are easy symmetric nat, we need to test with another port to check inc or dec
extra_bind_test: Option<BindRequestResponse>,
}
impl UdpNatTypeDetectResult {
@ -350,6 +352,7 @@ impl UdpNatTypeDetectResult {
Self {
source_addr,
stun_resps,
extra_bind_test: None,
}
}
@ -406,7 +409,7 @@ impl UdpNatTypeDetectResult {
.filter_map(|x| x.mapped_socket_addr)
.collect::<BTreeSet<_>>()
.len();
mapped_addr_count < self.stun_server_count()
mapped_addr_count == 1
}
pub fn nat_type(&self) -> NatType {
@ -429,7 +432,32 @@ impl UdpNatTypeDetectResult {
return NatType::PortRestricted;
}
} else if !self.stun_resps.is_empty() {
if self.public_ips().len() != 1
|| self.usable_stun_resp_count() <= 1
|| self.max_port() - self.min_port() > 15
|| self.extra_bind_test.is_none()
|| self
.extra_bind_test
.as_ref()
.unwrap()
.mapped_socket_addr
.is_none()
{
return NatType::Symmetric;
} else {
let extra_bind_test = self.extra_bind_test.as_ref().unwrap();
let extra_port = extra_bind_test.mapped_socket_addr.unwrap().port();
let max_port_diff = extra_port.saturating_sub(self.max_port());
let min_port_diff = self.min_port().saturating_sub(extra_port);
if max_port_diff != 0 && max_port_diff < 100 {
return NatType::SymmetricEasyInc;
} else if min_port_diff != 0 && min_port_diff < 100 {
return NatType::SymmetricEasyDec;
} else {
return NatType::Symmetric;
}
}
} else {
return NatType::Unknown;
}
@ -477,6 +505,13 @@ impl UdpNatTypeDetectResult {
.max()
.unwrap_or(u16::MAX)
}
pub fn usable_stun_resp_count(&self) -> usize {
self.stun_resps
.iter()
.filter(|x| x.mapped_socket_addr.is_some())
.count()
}
}
pub struct UdpNatTypeDetector {
@ -492,6 +527,19 @@ impl UdpNatTypeDetector {
}
}
async fn get_extra_bind_result(
&self,
source_port: u16,
stun_server: SocketAddr,
) -> Result<BindRequestResponse, Error> {
let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?);
let client_builder = StunClientBuilder::new(udp.clone());
client_builder
.new_stun_client(stun_server)
.bind_request(false, false)
.await
}
pub async fn detect_nat_type(&self, source_port: u16) -> Result<UdpNatTypeDetectResult, Error> {
let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?);
self.detect_nat_type_with_socket(udp).await
@ -578,13 +626,28 @@ impl StunInfoCollectorTrait for StunInfoCollector {
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
self.start_stun_routine();
let stun_servers = self
let mut stun_servers = self
.udp_nat_test_result
.read()
.unwrap()
.clone()
.map(|x| x.collect_available_stun_server())
.ok_or(Error::NotFound)?;
.unwrap_or(vec![]);
if stun_servers.is_empty() {
let mut host_resolver =
HostResolverIter::new(self.stun_servers.read().unwrap().clone(), 2);
while let Some(addr) = host_resolver.next().await {
stun_servers.push(addr);
if stun_servers.len() >= 2 {
break;
}
}
}
if stun_servers.is_empty() {
return Err(Error::NotFound);
}
let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", local_port)).await?);
let mut client_builder = StunClientBuilder::new(udp.clone());
@ -630,9 +693,9 @@ impl StunInfoCollector {
// stun server cross nation may return a external ip address with high latency and loss rate
vec![
"stun.miwifi.com",
"stun.cdnbye.com",
"stun.hitv.com",
"stun.chat.bilibili.com",
"stun.hitv.com",
"stun.cdnbye.com",
"stun.douyucdn.cn:18000",
"fwa.lifesizecloud.com",
"global.turn.twilio.com",
@ -673,38 +736,41 @@ impl StunInfoCollector {
.map(|x| x.to_string())
.collect();
let detector = UdpNatTypeDetector::new(servers, 1);
let ret = detector.detect_nat_type(0).await;
let mut ret = detector.detect_nat_type(0).await;
tracing::debug!(?ret, "finish udp nat type detect");
let mut nat_type = NatType::Unknown;
let sleep_sec = match &ret {
Ok(resp) => {
*udp_nat_test_result.write().unwrap() = Some(resp.clone());
udp_test_time.store(Local::now());
if let Ok(resp) = &ret {
tracing::debug!(?resp, "got udp nat type detect result");
nat_type = resp.nat_type();
if nat_type == NatType::Unknown {
15
} else {
600
}
}
_ => 15,
};
// if nat type is symmtric, detect with another port to gather more info
if nat_type == NatType::Symmetric {
let old_resp = ret.unwrap();
let old_local_port = old_resp.local_addr().port();
let new_port = if old_local_port >= 65535 {
old_local_port - 1
} else {
old_local_port + 1
};
let ret = detector.detect_nat_type(new_port).await;
let old_resp = ret.as_mut().unwrap();
tracing::debug!(?old_resp, "start get extra bind result");
let available_stun_servers = old_resp.collect_available_stun_server();
for server in available_stun_servers.iter() {
let ret = detector
.get_extra_bind_result(0, *server)
.await
.with_context(|| "get extra bind result failed");
tracing::debug!(?ret, "finish udp nat type detect with another port");
if let Ok(resp) = ret {
udp_nat_test_result.write().unwrap().as_mut().map(|x| {
x.extend_result(resp);
});
old_resp.extra_bind_test = Some(resp);
break;
}
}
}
let mut sleep_sec = 10;
if let Ok(resp) = &ret {
udp_test_time.store(Local::now());
*udp_nat_test_result.write().unwrap() = Some(resp.clone());
if nat_type != NatType::Unknown
&& (nat_type != NatType::Symmetric || resp.extra_bind_test.is_some())
{
sleep_sec = 600
}
}
@ -734,7 +800,7 @@ impl StunInfoCollectorTrait for MockStunInfoCollector {
last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
min_port: 100,
max_port: 200,
..Default::default()
public_ip: vec!["127.0.0.1".to_string()],
}
}

View File

@ -425,7 +425,7 @@ impl DirectConnectorManager {
);
let ip_list = rpc_stub
.get_ip_list(BaseController {}, GetIpListRequest {})
.get_ip_list(BaseController::default(), GetIpListRequest {})
.await
.with_context(|| format!("get ip list from peer {}", dst_peer_id))?;

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,396 @@
use std::{
net::{IpAddr, SocketAddr, SocketAddrV4},
sync::Arc,
time::{Duration, Instant},
};
use anyhow::Context;
use tokio::sync::Mutex;
use crate::{
common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId},
connector::udp_hole_punch::common::{
try_connect_with_socket, UdpHolePunchListener, HOLE_PUNCH_PACKET_BODY_LEN,
},
peers::peer_manager::PeerManager,
proto::{
peer_rpc::{
SendPunchPacketBothEasySymRequest, SendPunchPacketBothEasySymResponse,
UdpHolePunchRpcClientFactory,
},
rpc_types::{self, controller::BaseController},
},
tunnel::{udp::new_hole_punch_packet, Tunnel},
};
use super::common::{PunchHoleServerCommon, UdpNatType, UdpSocketArray};
const UDP_ARRAY_SIZE_FOR_BOTH_EASY_SYM: usize = 25;
const DST_PORT_OFFSET: u16 = 20;
const REMOTE_WAIT_TIME_MS: u64 = 5000;
pub(crate) struct PunchBothEasySymHoleServer {
common: Arc<PunchHoleServerCommon>,
task: Mutex<Option<ScopedTask<()>>>,
}
impl PunchBothEasySymHoleServer {
pub(crate) fn new(common: Arc<PunchHoleServerCommon>) -> Self {
Self {
common,
task: Mutex::new(None),
}
}
// hard sym means public port is random and cannot be predicted
#[tracing::instrument(skip(self), ret, err)]
pub(crate) async fn send_punch_packet_both_easy_sym(
&self,
request: SendPunchPacketBothEasySymRequest,
) -> Result<SendPunchPacketBothEasySymResponse, rpc_types::error::Error> {
tracing::info!("send_punch_packet_both_easy_sym start");
let busy_resp = Ok(SendPunchPacketBothEasySymResponse {
is_busy: true,
..Default::default()
});
let Ok(mut locked_task) = self.task.try_lock() else {
return busy_resp;
};
if locked_task.is_some() && !locked_task.as_ref().unwrap().is_finished() {
return busy_resp;
}
let global_ctx = self.common.get_global_ctx();
let cur_mapped_addr = global_ctx
.get_stun_info_collector()
.get_udp_port_mapping(0)
.await
.with_context(|| "failed to get udp port mapping")?;
tracing::info!("send_punch_packet_hard_sym start");
let socket_count = request.udp_socket_count as usize;
let public_ips = request
.public_ip
.ok_or(anyhow::anyhow!("public_ip is required"))?;
let transaction_id = request.transaction_id;
let udp_array =
UdpSocketArray::new(socket_count, self.common.get_global_ctx().net_ns.clone());
udp_array.start().await?;
udp_array.add_intreast_tid(transaction_id);
let peer_mgr = self.common.get_peer_mgr();
let punch_packet =
new_hole_punch_packet(transaction_id, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes();
let mut punched = vec![];
let common = self.common.clone();
let task = tokio::spawn(async move {
let mut listeners = Vec::new();
let start_time = Instant::now();
let wait_time_ms = request.wait_time_ms.min(8000);
while start_time.elapsed() < Duration::from_millis(wait_time_ms as u64) {
if let Err(e) = udp_array
.send_with_all(
&punch_packet,
SocketAddr::V4(SocketAddrV4::new(
public_ips.into(),
request.dst_port_num as u16,
)),
)
.await
{
tracing::error!(?e, "failed to send hole punch packet");
break;
}
tokio::time::sleep(Duration::from_millis(100)).await;
if let Some(s) = udp_array.try_fetch_punched_socket(transaction_id) {
tracing::info!(?s, ?transaction_id, "got punched socket in both easy sym");
assert!(Arc::strong_count(&s.socket) == 1);
let Some(port) = s.socket.local_addr().ok().map(|addr| addr.port()) else {
tracing::warn!("failed to get local addr from punched socket");
continue;
};
let remote_addr = s.remote_addr;
drop(s);
let listener =
match UdpHolePunchListener::new_ext(peer_mgr.clone(), false, Some(port))
.await
{
Ok(l) => l,
Err(e) => {
tracing::warn!(?e, "failed to create listener");
continue;
}
};
punched.push((listener.get_socket().await, remote_addr));
listeners.push(listener);
}
// if any listener is punched, we can break the loop
for l in &listeners {
if l.get_conn_count().await > 0 {
tracing::info!(?l, "got punched listener");
break;
}
}
if !punched.is_empty() {
tracing::debug!(?punched, "got punched socket and keep sending punch packet");
}
for p in &punched {
let (socket, remote_addr) = p;
let send_remote_ret = socket.send_to(&punch_packet, remote_addr).await;
tracing::debug!(
?send_remote_ret,
?socket,
"send hole punch packet to punched remote"
);
}
}
for l in listeners {
if l.get_conn_count().await > 0 {
common.add_listener(l).await;
}
}
});
*locked_task = Some(task.into());
return Ok(SendPunchPacketBothEasySymResponse {
is_busy: false,
base_mapped_addr: Some(cur_mapped_addr.into()),
});
}
}
#[derive(Debug)]
pub(crate) struct PunchBothEasySymHoleClient {
peer_mgr: Arc<PeerManager>,
}
impl PunchBothEasySymHoleClient {
pub(crate) fn new(peer_mgr: Arc<PeerManager>) -> Self {
Self { peer_mgr }
}
#[tracing::instrument(ret)]
pub(crate) async fn do_hole_punching(
&self,
dst_peer_id: PeerId,
my_nat_info: UdpNatType,
peer_nat_info: UdpNatType,
is_busy: &mut bool,
) -> Result<Box<dyn Tunnel>, anyhow::Error> {
*is_busy = false;
let udp_array = UdpSocketArray::new(
UDP_ARRAY_SIZE_FOR_BOTH_EASY_SYM,
self.peer_mgr.get_global_ctx().net_ns.clone(),
);
udp_array.start().await?;
let global_ctx = self.peer_mgr.get_global_ctx();
let cur_mapped_addr = global_ctx
.get_stun_info_collector()
.get_udp_port_mapping(0)
.await
.with_context(|| "failed to get udp port mapping")?;
let my_public_ip = match cur_mapped_addr.ip() {
IpAddr::V4(v4) => v4,
_ => {
anyhow::bail!("ipv6 is not supported");
}
};
let me_is_incremental = my_nat_info
.get_inc_of_easy_sym()
.ok_or(anyhow::anyhow!("me_is_incremental is required"))?;
let peer_is_incremental = peer_nat_info
.get_inc_of_easy_sym()
.ok_or(anyhow::anyhow!("peer_is_incremental is required"))?;
let rpc_stub = self
.peer_mgr
.get_peer_rpc_mgr()
.rpc_client()
.scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
self.peer_mgr.my_peer_id(),
dst_peer_id,
global_ctx.get_network_name(),
);
let tid = rand::random();
udp_array.add_intreast_tid(tid);
let remote_ret = rpc_stub
.send_punch_packet_both_easy_sym(
BaseController {
timeout_ms: 2000,
..Default::default()
},
SendPunchPacketBothEasySymRequest {
transaction_id: tid,
public_ip: Some(my_public_ip.into()),
dst_port_num: if me_is_incremental {
cur_mapped_addr.port().saturating_add(DST_PORT_OFFSET)
} else {
cur_mapped_addr.port().saturating_sub(DST_PORT_OFFSET)
} as u32,
udp_socket_count: UDP_ARRAY_SIZE_FOR_BOTH_EASY_SYM as u32,
wait_time_ms: REMOTE_WAIT_TIME_MS as u32,
},
)
.await?;
if remote_ret.is_busy {
*is_busy = true;
anyhow::bail!("remote is busy");
}
let mut remote_mapped_addr = remote_ret
.base_mapped_addr
.ok_or(anyhow::anyhow!("remote_mapped_addr is required"))?;
let now = Instant::now();
remote_mapped_addr.port = if peer_is_incremental {
remote_mapped_addr
.port
.saturating_add(DST_PORT_OFFSET as u32)
} else {
remote_mapped_addr
.port
.saturating_sub(DST_PORT_OFFSET as u32)
};
tracing::debug!(
?remote_mapped_addr,
?remote_ret,
"start send hole punch packet for both easy sym"
);
while now.elapsed().as_millis() < (REMOTE_WAIT_TIME_MS + 1000).into() {
udp_array
.send_with_all(
&new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(),
remote_mapped_addr.into(),
)
.await?;
tokio::time::sleep(Duration::from_millis(100)).await;
let Some(socket) = udp_array.try_fetch_punched_socket(tid) else {
tracing::trace!(
?remote_mapped_addr,
?tid,
"no punched socket found, send some more hole punch packets"
);
continue;
};
tracing::info!(
?socket,
?remote_mapped_addr,
?tid,
"got punched socket in both easy sym"
);
for _ in 0..2 {
match try_connect_with_socket(socket.socket.clone(), remote_mapped_addr.into())
.await
{
Ok(tunnel) => {
return Ok(tunnel);
}
Err(e) => {
tracing::error!(?e, "failed to connect with socket");
continue;
}
}
}
udp_array.add_new_socket(socket.socket).await?;
}
anyhow::bail!("failed to punch hole for both easy sym");
}
}
#[cfg(test)]
pub mod tests {
use std::{
sync::{atomic::AtomicU32, Arc},
time::Duration,
};
use tokio::net::UdpSocket;
use crate::{
connector::udp_hole_punch::{
tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector,
},
peers::tests::{connect_peer_manager, wait_route_appear},
proto::common::NatType,
tunnel::common::tests::wait_for_condition,
};
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial(hole_punch)]
async fn hole_punching_easy_sym(#[values("true", "false")] is_inc: bool) {
let p_a = create_mock_peer_manager_with_mock_stun(if is_inc {
NatType::SymmetricEasyInc
} else {
NatType::SymmetricEasyDec
})
.await;
let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
let p_c = create_mock_peer_manager_with_mock_stun(if !is_inc {
NatType::SymmetricEasyInc
} else {
NatType::SymmetricEasyDec
})
.await;
connect_peer_manager(p_a.clone(), p_b.clone()).await;
connect_peer_manager(p_b.clone(), p_c.clone()).await;
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone());
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone());
hole_punching_a.run().await.unwrap();
hole_punching_c.run().await.unwrap();
// 144 + DST_PORT_OFFSET = 164
let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40164").await.unwrap());
// 144 - DST_PORT_OFFSET = 124
let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40124").await.unwrap());
let udps = vec![udp1, udp2];
let counter = Arc::new(AtomicU32::new(0));
// all these sockets should receive hole punching packet
for udp in udps.iter().map(Arc::clone) {
let counter = counter.clone();
tokio::spawn(async move {
let mut buf = [0u8; 1024];
let (len, addr) = udp.recv_from(&mut buf).await.unwrap();
println!(
"got predictable punch packet, {:?} {:?} {:?}",
len,
addr,
udp.local_addr()
);
counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
});
}
hole_punching_a.client.run_immediately().await;
let udp_len = udps.len();
wait_for_condition(
|| async { counter.load(std::sync::atomic::Ordering::Relaxed) == udp_len as u32 },
Duration::from_secs(30),
)
.await;
}
}

View File

@ -0,0 +1,573 @@
use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
time::Duration,
};
use crossbeam::atomic::AtomicCell;
use dashmap::{DashMap, DashSet};
use rand::seq::SliceRandom as _;
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
use tracing::{instrument, Instrument, Level};
use zerocopy::FromBytes as _;
use crate::{
common::{
error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS,
stun::StunInfoCollectorTrait as _,
},
defer,
peers::peer_manager::PeerManager,
proto::common::NatType,
tunnel::{
packet_def::{UDPTunnelHeader, UdpPacketType, UDP_TUNNEL_HEADER_SIZE},
udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener},
Tunnel, TunnelConnCounter, TunnelListener as _,
},
};
pub(crate) const HOLE_PUNCH_PACKET_BODY_LEN: u16 = 16;
fn generate_shuffled_port_vec() -> Vec<u16> {
let mut rng = rand::thread_rng();
let mut port_vec: Vec<u16> = (1..=65535).collect();
port_vec.shuffle(&mut rng);
port_vec
}
pub(crate) enum UdpPunchClientMethod {
None,
ConeToCone,
SymToCone,
EasySymToEasySym,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) enum UdpNatType {
Unknown,
Open(NatType),
Cone(NatType),
// bool means if it is incremental
EasySymmetric(NatType, bool),
HardSymmetric(NatType),
}
impl From<NatType> for UdpNatType {
fn from(nat_type: NatType) -> Self {
match nat_type {
NatType::Unknown => UdpNatType::Unknown,
NatType::NoPat | NatType::OpenInternet => UdpNatType::Open(nat_type),
NatType::FullCone | NatType::Restricted | NatType::PortRestricted => {
UdpNatType::Cone(nat_type)
}
NatType::Symmetric | NatType::SymUdpFirewall => UdpNatType::HardSymmetric(nat_type),
NatType::SymmetricEasyInc => UdpNatType::EasySymmetric(nat_type, true),
NatType::SymmetricEasyDec => UdpNatType::EasySymmetric(nat_type, false),
}
}
}
impl Into<NatType> for UdpNatType {
fn into(self) -> NatType {
match self {
UdpNatType::Unknown => NatType::Unknown,
UdpNatType::Open(nat_type) => nat_type,
UdpNatType::Cone(nat_type) => nat_type,
UdpNatType::EasySymmetric(nat_type, _) => nat_type,
UdpNatType::HardSymmetric(nat_type) => nat_type,
}
}
}
impl UdpNatType {
pub(crate) fn is_open(&self) -> bool {
matches!(self, UdpNatType::Open(_))
}
pub(crate) fn is_unknown(&self) -> bool {
matches!(self, UdpNatType::Unknown)
}
pub(crate) fn is_sym(&self) -> bool {
self.is_hard_sym() || self.is_easy_sym()
}
pub(crate) fn is_hard_sym(&self) -> bool {
matches!(self, UdpNatType::HardSymmetric(_))
}
pub(crate) fn is_easy_sym(&self) -> bool {
matches!(self, UdpNatType::EasySymmetric(_, _))
}
pub(crate) fn is_cone(&self) -> bool {
matches!(self, UdpNatType::Cone(_))
}
pub(crate) fn get_inc_of_easy_sym(&self) -> Option<bool> {
match self {
UdpNatType::EasySymmetric(_, inc) => Some(*inc),
_ => None,
}
}
pub(crate) fn get_punch_hole_method(&self, other: Self) -> UdpPunchClientMethod {
if other.is_unknown() {
if self.is_sym() {
return UdpPunchClientMethod::SymToCone;
} else {
return UdpPunchClientMethod::ConeToCone;
}
}
if self.is_unknown() {
if other.is_sym() {
return UdpPunchClientMethod::None;
} else {
return UdpPunchClientMethod::ConeToCone;
}
}
if self.is_open() || other.is_open() {
// open nat does not need to punch hole
return UdpPunchClientMethod::None;
}
if self.is_cone() {
if other.is_sym() {
return UdpPunchClientMethod::None;
} else {
return UdpPunchClientMethod::ConeToCone;
}
} else if self.is_easy_sym() {
if other.is_hard_sym() {
return UdpPunchClientMethod::None;
} else if other.is_easy_sym() {
return UdpPunchClientMethod::EasySymToEasySym;
} else {
return UdpPunchClientMethod::SymToCone;
}
} else if self.is_hard_sym() {
if other.is_sym() {
return UdpPunchClientMethod::None;
} else {
return UdpPunchClientMethod::SymToCone;
}
}
unreachable!("invalid nat type");
}
pub(crate) fn can_punch_hole_as_client(&self, other: Self) -> bool {
!matches!(
self.get_punch_hole_method(other),
UdpPunchClientMethod::None
)
}
}
#[derive(Debug)]
pub(crate) struct PunchedUdpSocket {
pub(crate) socket: Arc<UdpSocket>,
pub(crate) tid: u32,
pub(crate) remote_addr: SocketAddr,
}
// used for symmetric hole punching, binding to multiple ports to increase the chance of success
pub(crate) struct UdpSocketArray {
sockets: Arc<DashMap<SocketAddr, Arc<UdpSocket>>>,
max_socket_count: usize,
net_ns: NetNS,
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
intreast_tids: Arc<DashSet<u32>>,
tid_to_socket: Arc<DashMap<u32, Vec<PunchedUdpSocket>>>,
}
impl UdpSocketArray {
pub fn new(max_socket_count: usize, net_ns: NetNS) -> Self {
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
join_joinset_background(tasks.clone(), "UdpSocketArray".to_owned());
Self {
sockets: Arc::new(DashMap::new()),
max_socket_count,
net_ns,
tasks,
intreast_tids: Arc::new(DashSet::new()),
tid_to_socket: Arc::new(DashMap::new()),
}
}
pub fn started(&self) -> bool {
!self.sockets.is_empty()
}
pub async fn add_new_socket(&self, socket: Arc<UdpSocket>) -> Result<(), anyhow::Error> {
let socket_map = self.sockets.clone();
let local_addr = socket.local_addr()?;
let intreast_tids = self.intreast_tids.clone();
let tid_to_socket = self.tid_to_socket.clone();
socket_map.insert(local_addr, socket.clone());
self.tasks.lock().unwrap().spawn(
async move {
defer!(socket_map.remove(&local_addr););
let mut buf = [0u8; UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize];
tracing::trace!(?local_addr, "udp socket added");
loop {
let Ok((len, addr)) = socket.recv_from(&mut buf).await else {
break;
};
tracing::debug!(?len, ?addr, "got raw packet");
if len != UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize {
continue;
}
let Some(p) = UDPTunnelHeader::ref_from_prefix(&buf) else {
continue;
};
let tid = p.conn_id.get();
let valid = p.msg_type == UdpPacketType::HolePunch as u8
&& p.len.get() == HOLE_PUNCH_PACKET_BODY_LEN;
tracing::debug!(?p, ?addr, ?tid, ?valid, ?p, "got udp hole punch packet");
if !valid {
continue;
}
if intreast_tids.contains(&tid) {
tracing::info!(?addr, ?tid, "got hole punching packet with intreast tid");
tid_to_socket
.entry(tid)
.or_insert_with(Vec::new)
.push(PunchedUdpSocket {
socket: socket.clone(),
tid,
remote_addr: addr,
});
break;
}
}
tracing::debug!(?local_addr, "udp socket recv loop end");
}
.instrument(tracing::info_span!("udp array socket recv loop")),
);
Ok(())
}
#[instrument(err)]
pub async fn start(&self) -> Result<(), anyhow::Error> {
tracing::info!("starting udp socket array");
while self.sockets.len() < self.max_socket_count {
let socket = {
let _g = self.net_ns.guard();
Arc::new(UdpSocket::bind("0.0.0.0:0").await?)
};
self.add_new_socket(socket).await?;
}
Ok(())
}
#[instrument(err)]
pub async fn send_with_all(&self, data: &[u8], addr: SocketAddr) -> Result<(), anyhow::Error> {
tracing::info!(?addr, "sending hole punching packet");
for socket in self.sockets.iter() {
let socket = socket.value();
socket.send_to(data, addr).await?;
}
Ok(())
}
#[instrument(ret(level = Level::DEBUG))]
pub fn try_fetch_punched_socket(&self, tid: u32) -> Option<PunchedUdpSocket> {
tracing::debug!(?tid, "try fetch punched socket");
self.tid_to_socket.get_mut(&tid)?.value_mut().pop()
}
pub fn add_intreast_tid(&self, tid: u32) {
self.intreast_tids.insert(tid);
}
pub fn remove_intreast_tid(&self, tid: u32) {
self.intreast_tids.remove(&tid);
self.tid_to_socket.remove(&tid);
}
}
impl std::fmt::Debug for UdpSocketArray {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("UdpSocketArray")
.field("sockets", &self.sockets.len())
.field("max_socket_count", &self.max_socket_count)
.field("started", &self.started())
.field("intreast_tids", &self.intreast_tids.len())
.field("tid_to_socket", &self.tid_to_socket.len())
.finish()
}
}
#[derive(Debug)]
pub(crate) struct UdpHolePunchListener {
socket: Arc<UdpSocket>,
tasks: JoinSet<()>,
running: Arc<AtomicCell<bool>>,
mapped_addr: SocketAddr,
conn_counter: Arc<Box<dyn TunnelConnCounter>>,
listen_time: std::time::Instant,
last_select_time: AtomicCell<std::time::Instant>,
last_active_time: Arc<AtomicCell<std::time::Instant>>,
}
impl UdpHolePunchListener {
async fn get_avail_port() -> Result<u16, Error> {
let socket = UdpSocket::bind("0.0.0.0:0").await?;
Ok(socket.local_addr()?.port())
}
#[instrument(err)]
pub async fn new(peer_mgr: Arc<PeerManager>) -> Result<Self, Error> {
Self::new_ext(peer_mgr, true, None).await
}
#[instrument(err)]
pub async fn new_ext(
peer_mgr: Arc<PeerManager>,
with_mapped_addr: bool,
port: Option<u16>,
) -> Result<Self, Error> {
let port = port.unwrap_or(Self::get_avail_port().await?);
let listen_url = format!("udp://0.0.0.0:{}", port);
let mapped_addr = if with_mapped_addr {
let gctx = peer_mgr.get_global_ctx();
let stun_info_collect = gctx.get_stun_info_collector();
stun_info_collect.get_udp_port_mapping(port).await?
} else {
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), port))
};
let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap());
{
let _g = peer_mgr.get_global_ctx().net_ns.guard();
listener.listen().await?;
}
let socket = listener.get_socket().unwrap();
let running = Arc::new(AtomicCell::new(true));
let running_clone = running.clone();
let conn_counter = listener.get_conn_counter();
let mut tasks = JoinSet::new();
tasks.spawn(async move {
while let Ok(conn) = listener.accept().await {
tracing::warn!(?conn, "udp hole punching listener got peer connection");
let peer_mgr = peer_mgr.clone();
tokio::spawn(async move {
if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
tracing::error!(
?e,
"failed to add tunnel as server in hole punch listener"
);
}
});
}
running_clone.store(false);
});
let last_active_time = Arc::new(AtomicCell::new(std::time::Instant::now()));
let conn_counter_clone = conn_counter.clone();
let last_active_time_clone = last_active_time.clone();
tasks.spawn(async move {
loop {
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
if conn_counter_clone.get().unwrap_or(0) != 0 {
last_active_time_clone.store(std::time::Instant::now());
}
}
});
tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started");
Ok(Self {
tasks,
socket,
running,
mapped_addr,
conn_counter,
listen_time: std::time::Instant::now(),
last_select_time: AtomicCell::new(std::time::Instant::now()),
last_active_time,
})
}
pub async fn get_socket(&self) -> Arc<UdpSocket> {
self.last_select_time.store(std::time::Instant::now());
self.socket.clone()
}
pub async fn get_conn_count(&self) -> usize {
self.conn_counter.get().unwrap_or(0) as usize
}
}
pub(crate) struct PunchHoleServerCommon {
peer_mgr: Arc<PeerManager>,
listeners: Arc<Mutex<Vec<UdpHolePunchListener>>>,
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
}
impl PunchHoleServerCommon {
pub(crate) fn new(peer_mgr: Arc<PeerManager>) -> Self {
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
join_joinset_background(tasks.clone(), "PunchHoleServerCommon".to_owned());
let listeners = Arc::new(Mutex::new(Vec::<UdpHolePunchListener>::new()));
let l = listeners.clone();
tasks.lock().unwrap().spawn(async move {
loop {
tokio::time::sleep(Duration::from_secs(5)).await;
{
// remove listener that is not active for 40 seconds but keep listeners that are selected less than 30 seconds
l.lock().await.retain(|listener| {
listener.last_active_time.load().elapsed().as_secs() < 40
|| listener.last_select_time.load().elapsed().as_secs() < 30
});
}
}
});
Self {
peer_mgr,
listeners,
tasks,
}
}
pub(crate) async fn add_listener(&self, listener: UdpHolePunchListener) {
self.listeners.lock().await.push(listener);
}
pub(crate) async fn find_listener(&self, addr: &SocketAddr) -> Option<Arc<UdpSocket>> {
let all_listener_sockets = self.listeners.lock().await;
let listener = all_listener_sockets
.iter()
.find(|listener| listener.mapped_addr == *addr && listener.running.load())?;
Some(listener.get_socket().await)
}
pub(crate) async fn my_udp_nat_type(&self) -> i32 {
self.peer_mgr
.get_global_ctx()
.get_stun_info_collector()
.get_stun_info()
.udp_nat_type
}
pub(crate) async fn select_listener(
&self,
use_new_listener: bool,
) -> Option<(Arc<UdpSocket>, SocketAddr)> {
let all_listener_sockets = &self.listeners;
let mut use_last = false;
if all_listener_sockets.lock().await.len() < 16 || use_new_listener {
tracing::warn!("creating new udp hole punching listener");
all_listener_sockets.lock().await.push(
UdpHolePunchListener::new(self.peer_mgr.clone())
.await
.ok()?,
);
use_last = true;
}
let locked = all_listener_sockets.lock().await;
let listener = if use_last {
locked.last()?
} else {
// use the listener that is active most recently
locked
.iter()
.max_by_key(|listener| listener.last_active_time.load())?
};
Some((listener.get_socket().await, listener.mapped_addr))
}
pub(crate) fn get_joinset(&self) -> Arc<std::sync::Mutex<JoinSet<()>>> {
self.tasks.clone()
}
pub(crate) fn get_global_ctx(&self) -> ArcGlobalCtx {
self.peer_mgr.get_global_ctx()
}
pub(crate) fn get_peer_mgr(&self) -> Arc<PeerManager> {
self.peer_mgr.clone()
}
}
#[tracing::instrument(err, ret(level=Level::DEBUG), skip(ports))]
pub(crate) async fn send_symmetric_hole_punch_packet(
ports: &Vec<u16>,
udp: Arc<UdpSocket>,
transaction_id: u32,
public_ips: &Vec<Ipv4Addr>,
port_start_idx: usize,
max_packets: usize,
) -> Result<usize, Error> {
tracing::debug!("sending hard symmetric hole punching packet");
let mut sent_packets = 0;
let mut cur_port_idx = port_start_idx;
while sent_packets < max_packets {
let port = ports[cur_port_idx % ports.len()];
for pub_ip in public_ips {
let addr = SocketAddr::V4(SocketAddrV4::new(*pub_ip, port));
let packet = new_hole_punch_packet(transaction_id, HOLE_PUNCH_PACKET_BODY_LEN);
udp.send_to(&packet.into_bytes(), addr).await?;
sent_packets += 1;
}
cur_port_idx = cur_port_idx.wrapping_add(1);
tokio::time::sleep(Duration::from_millis(3)).await;
}
Ok(cur_port_idx % ports.len())
}
pub(crate) async fn try_connect_with_socket(
socket: Arc<UdpSocket>,
remote_mapped_addr: SocketAddr,
) -> Result<Box<dyn Tunnel>, Error> {
let connector = UdpTunnelConnector::new(
format!(
"udp://{}:{}",
remote_mapped_addr.ip(),
remote_mapped_addr.port()
)
.to_string()
.parse()
.unwrap(),
);
connector
.try_connect_with_socket(socket, remote_mapped_addr)
.await
.map_err(|e| Error::from(e))
}

View File

@ -0,0 +1,258 @@
use std::{
sync::Arc,
time::{Duration, Instant},
};
use anyhow::Context;
use tokio::net::UdpSocket;
use crate::{
common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId},
connector::udp_hole_punch::common::{
try_connect_with_socket, UdpSocketArray, HOLE_PUNCH_PACKET_BODY_LEN,
},
peers::peer_manager::PeerManager,
proto::{
common::Void,
peer_rpc::{
SelectPunchListenerRequest, SendPunchPacketConeRequest, UdpHolePunchRpcClientFactory,
},
rpc_types::{self, controller::BaseController},
},
tunnel::{udp::new_hole_punch_packet, Tunnel},
};
use super::common::PunchHoleServerCommon;
pub(crate) struct PunchConeHoleServer {
common: Arc<PunchHoleServerCommon>,
}
impl PunchConeHoleServer {
pub(crate) fn new(common: Arc<PunchHoleServerCommon>) -> Self {
Self { common }
}
#[tracing::instrument(skip(self), ret, err)]
pub(crate) async fn send_punch_packet_cone(
&self,
_: BaseController,
request: SendPunchPacketConeRequest,
) -> Result<Void, rpc_types::error::Error> {
let listener_addr = request.listener_mapped_addr.ok_or(anyhow::anyhow!(
"send_punch_packet_for_cone request missing listener_mapped_addr"
))?;
let listener_addr = std::net::SocketAddr::from(listener_addr);
let listener = self
.common
.find_listener(&listener_addr)
.await
.ok_or(anyhow::anyhow!(
"send_punch_packet_for_cone failed to find listener"
))?;
let dest_addr = request.dest_addr.ok_or(anyhow::anyhow!(
"send_punch_packet_for_cone request missing dest_addr"
))?;
let dest_addr = std::net::SocketAddr::from(dest_addr);
let dest_ip = dest_addr.ip();
if dest_ip.is_unspecified() || dest_ip.is_multicast() {
return Err(anyhow::anyhow!(
"send_punch_packet_for_cone dest_ip is malformed, {:?}",
request
)
.into());
}
for _ in 0..request.packet_batch_count {
tracing::info!(?request, "sending hole punching packet");
for _ in 0..request.packet_count_per_batch {
let udp_packet =
new_hole_punch_packet(request.transaction_id, HOLE_PUNCH_PACKET_BODY_LEN);
if let Err(e) = listener.send_to(&udp_packet.into_bytes(), &dest_addr).await {
tracing::error!(?e, "failed to send hole punch packet to dest addr");
}
}
tokio::time::sleep(Duration::from_millis(request.packet_interval_ms as u64)).await;
}
Ok(Void::default())
}
}
pub(crate) struct PunchConeHoleClient {
peer_mgr: Arc<PeerManager>,
}
impl PunchConeHoleClient {
pub(crate) fn new(peer_mgr: Arc<PeerManager>) -> Self {
Self { peer_mgr }
}
#[tracing::instrument(skip(self))]
pub(crate) async fn do_hole_punching(
&self,
dst_peer_id: PeerId,
) -> Result<Box<dyn Tunnel>, anyhow::Error> {
tracing::info!(?dst_peer_id, "start hole punching");
let tid = rand::random();
let global_ctx = self.peer_mgr.get_global_ctx();
let udp_array = UdpSocketArray::new(1, global_ctx.net_ns.clone());
let local_socket = {
let _g = self.peer_mgr.get_global_ctx().net_ns.guard();
Arc::new(UdpSocket::bind("0.0.0.0:0").await?)
};
let local_addr = local_socket
.local_addr()
.with_context(|| anyhow::anyhow!("failed to get local port from udp array"))?;
let local_port = local_addr.port();
let local_mapped_addr = global_ctx
.get_stun_info_collector()
.get_udp_port_mapping(local_port)
.await
.with_context(|| "failed to get udp port mapping")?;
// client -> server: tell server the mapped port, server will return the mapped address of listening port.
let rpc_stub = self
.peer_mgr
.get_peer_rpc_mgr()
.rpc_client()
.scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
self.peer_mgr.my_peer_id(),
dst_peer_id,
global_ctx.get_network_name(),
);
let resp = rpc_stub
.select_punch_listener(
BaseController::default(),
SelectPunchListenerRequest { force_new: false },
)
.await
.with_context(|| "failed to select punch listener")?;
let remote_mapped_addr = resp.listener_mapped_addr.ok_or(anyhow::anyhow!(
"select_punch_listener response missing listener_mapped_addr"
))?;
tracing::debug!(
?local_mapped_addr,
?remote_mapped_addr,
"hole punch got remote listener"
);
udp_array.add_new_socket(local_socket).await?;
udp_array.add_intreast_tid(tid);
let send_from_local = || async {
udp_array
.send_with_all(
&new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(),
remote_mapped_addr.clone().into(),
)
.await
.with_context(|| "failed to send hole punch packet from local")
};
send_from_local().await?;
let scoped_punch_task: ScopedTask<()> = tokio::spawn(async move {
if let Err(e) = rpc_stub
.send_punch_packet_cone(
BaseController {
timeout_ms: 4000,
..Default::default()
},
SendPunchPacketConeRequest {
listener_mapped_addr: Some(remote_mapped_addr.into()),
dest_addr: Some(local_mapped_addr.into()),
transaction_id: tid,
packet_count_per_batch: 2,
packet_batch_count: 5,
packet_interval_ms: 400,
},
)
.await
{
tracing::error!(?e, "failed to call remote send punch packet");
}
})
.into();
// server: will send some punching resps, total 10 packets.
// client: use the socket to create UdpTunnel with UdpTunnelConnector
// NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote.
let mut finish_time: Option<Instant> = None;
while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 {
tokio::time::sleep(Duration::from_millis(200)).await;
if finish_time.is_none() && (*scoped_punch_task).is_finished() {
finish_time = Some(Instant::now());
}
let Some(socket) = udp_array.try_fetch_punched_socket(tid) else {
tracing::debug!("no punched socket found, send some more hole punch packets");
send_from_local().await?;
continue;
};
tracing::debug!(?socket, ?tid, "punched socket found, try connect with it");
for _ in 0..2 {
match try_connect_with_socket(socket.socket.clone(), remote_mapped_addr.into())
.await
{
Ok(tunnel) => {
tracing::info!(?tunnel, "hole punched");
return Ok(tunnel);
}
Err(e) => {
tracing::error!(?e, "failed to connect with socket");
}
}
}
}
return Err(anyhow::anyhow!("punch task finished but no hole punched"));
}
}
#[cfg(test)]
pub mod tests {
use crate::{
connector::udp_hole_punch::{
tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector,
},
peers::tests::{connect_peer_manager, wait_route_appear, wait_route_appear_with_cost},
proto::common::NatType,
};
#[tokio::test]
async fn hole_punching_cone() {
let p_a = create_mock_peer_manager_with_mock_stun(NatType::Restricted).await;
let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
let p_c = create_mock_peer_manager_with_mock_stun(NatType::Restricted).await;
connect_peer_manager(p_a.clone(), p_b.clone()).await;
connect_peer_manager(p_b.clone(), p_c.clone()).await;
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
println!("{:?}", p_a.list_routes().await);
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone());
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone());
hole_punching_a.run_as_client().await.unwrap();
hole_punching_c.run_as_server().await.unwrap();
hole_punching_a.client.run_immediately().await;
wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
.await
.unwrap();
println!("{:?}", p_a.list_routes().await);
}
}

View File

@ -0,0 +1,482 @@
use std::sync::Arc;
use anyhow::Error;
use both_easy_sym::{PunchBothEasySymHoleClient, PunchBothEasySymHoleServer};
use common::{PunchHoleServerCommon, UdpNatType, UdpPunchClientMethod};
use cone::{PunchConeHoleClient, PunchConeHoleServer};
use sym_to_cone::{PunchSymToConeHoleClient, PunchSymToConeHoleServer};
use tokio::{sync::Mutex, task::JoinHandle};
use crate::{
common::{stun::StunInfoCollectorTrait, PeerId},
connector::direct::PeerManagerForDirectConnector,
peers::{
peer_manager::PeerManager,
peer_task::{PeerTaskLauncher, PeerTaskManager},
},
proto::{
common::{NatType, Void},
peer_rpc::{
SelectPunchListenerRequest, SelectPunchListenerResponse,
SendPunchPacketBothEasySymRequest, SendPunchPacketBothEasySymResponse,
SendPunchPacketConeRequest, SendPunchPacketEasySymRequest,
SendPunchPacketHardSymRequest, SendPunchPacketHardSymResponse, UdpHolePunchRpc,
UdpHolePunchRpcServer,
},
rpc_types::{self, controller::BaseController},
},
};
pub(crate) mod both_easy_sym;
pub(crate) mod common;
pub(crate) mod cone;
pub(crate) mod sym_to_cone;
struct UdpHolePunchServer {
common: Arc<PunchHoleServerCommon>,
cone_server: PunchConeHoleServer,
sym_to_cone_server: PunchSymToConeHoleServer,
both_easy_sym_server: PunchBothEasySymHoleServer,
}
impl UdpHolePunchServer {
pub fn new(peer_mgr: Arc<PeerManager>) -> Arc<Self> {
let common = Arc::new(PunchHoleServerCommon::new(peer_mgr.clone()));
let cone_server = PunchConeHoleServer::new(common.clone());
let sym_to_cone_server = PunchSymToConeHoleServer::new(common.clone());
let both_easy_sym_server = PunchBothEasySymHoleServer::new(common.clone());
Arc::new(Self {
common,
cone_server,
sym_to_cone_server,
both_easy_sym_server,
})
}
}
#[async_trait::async_trait]
impl UdpHolePunchRpc for UdpHolePunchServer {
type Controller = BaseController;
async fn select_punch_listener(
&self,
_ctrl: Self::Controller,
input: SelectPunchListenerRequest,
) -> rpc_types::error::Result<SelectPunchListenerResponse> {
let (_, addr) = self
.common
.select_listener(input.force_new)
.await
.ok_or(anyhow::anyhow!("no listener available"))?;
Ok(SelectPunchListenerResponse {
listener_mapped_addr: Some(addr.into()),
})
}
/// send packet to one remote_addr, used by nat1-3 to nat1-3
async fn send_punch_packet_cone(
&self,
ctrl: Self::Controller,
input: SendPunchPacketConeRequest,
) -> rpc_types::error::Result<Void> {
self.cone_server.send_punch_packet_cone(ctrl, input).await
}
/// send packet to multiple remote_addr (birthday attack), used by nat4 to nat1-3
async fn send_punch_packet_hard_sym(
&self,
_ctrl: Self::Controller,
input: SendPunchPacketHardSymRequest,
) -> rpc_types::error::Result<SendPunchPacketHardSymResponse> {
self.sym_to_cone_server
.send_punch_packet_hard_sym(input)
.await
}
async fn send_punch_packet_easy_sym(
&self,
_ctrl: Self::Controller,
input: SendPunchPacketEasySymRequest,
) -> rpc_types::error::Result<Void> {
self.sym_to_cone_server
.send_punch_packet_easy_sym(input)
.await
.map(|_| Void {})
}
/// nat4 to nat4 (both predictably)
async fn send_punch_packet_both_easy_sym(
&self,
_ctrl: Self::Controller,
input: SendPunchPacketBothEasySymRequest,
) -> rpc_types::error::Result<SendPunchPacketBothEasySymResponse> {
self.both_easy_sym_server
.send_punch_packet_both_easy_sym(input)
.await
}
}
struct BackOff {
backoffs_ms: Vec<u64>,
current_idx: usize,
}
impl BackOff {
pub fn new(backoffs_ms: Vec<u64>) -> Self {
Self {
backoffs_ms,
current_idx: 0,
}
}
pub fn next_backoff(&mut self) -> u64 {
let backoff = self.backoffs_ms[self.current_idx];
self.current_idx = (self.current_idx + 1).min(self.backoffs_ms.len() - 1);
backoff
}
pub fn rollback(&mut self) {
self.current_idx = self.current_idx.saturating_sub(1);
}
pub async fn sleep_for_next_backoff(&mut self) {
let backoff = self.next_backoff();
if backoff > 0 {
tokio::time::sleep(tokio::time::Duration::from_millis(backoff)).await;
}
}
}
struct UdpHoePunchConnectorData {
cone_client: PunchConeHoleClient,
sym_to_cone_client: PunchSymToConeHoleClient,
both_easy_sym_client: PunchBothEasySymHoleClient,
peer_mgr: Arc<PeerManager>,
// sym punch should be serialized
sym_punch_lock: Mutex<()>,
}
impl UdpHoePunchConnectorData {
pub fn new(peer_mgr: Arc<PeerManager>) -> Arc<Self> {
let cone_client = PunchConeHoleClient::new(peer_mgr.clone());
let sym_to_cone_client = PunchSymToConeHoleClient::new(peer_mgr.clone());
let both_easy_sym_client = PunchBothEasySymHoleClient::new(peer_mgr.clone());
Arc::new(Self {
cone_client,
sym_to_cone_client,
both_easy_sym_client,
peer_mgr,
sym_punch_lock: Mutex::new(()),
})
}
#[tracing::instrument(skip(self))]
async fn cone_to_cone(self: Arc<Self>, task_info: PunchTaskInfo) -> Result<(), Error> {
let mut backoff = BackOff::new(vec![0, 1000, 2000, 4000, 4000, 8000, 8000, 16000]);
loop {
backoff.sleep_for_next_backoff().await;
let ret = self
.cone_client
.do_hole_punching(task_info.dst_peer_id)
.await;
if let Err(e) = ret {
tracing::info!(?e, "cone_to_cone hole punching failed");
continue;
}
if let Err(e) = self.peer_mgr.add_client_tunnel(ret.unwrap()).await {
tracing::warn!(?e, "cone_to_cone add client tunnel failed");
continue;
}
break;
}
tracing::info!("cone_to_cone hole punching success");
Ok(())
}
#[tracing::instrument(skip(self))]
async fn sym_to_cone(self: Arc<Self>, task_info: PunchTaskInfo) -> Result<(), Error> {
let mut backoff = BackOff::new(vec![0, 1000, 2000, 4000, 4000, 8000, 8000, 16000, 64000]);
let mut round = 0;
let mut port_idx = rand::random();
loop {
backoff.sleep_for_next_backoff().await;
let ret = {
let _lock = self.sym_punch_lock.lock().await;
self.sym_to_cone_client
.do_hole_punching(
task_info.dst_peer_id,
round,
&mut port_idx,
task_info.my_nat_type,
)
.await
};
round += 1;
if let Err(e) = ret {
tracing::info!(?e, "sym_to_cone hole punching failed");
continue;
}
if let Err(e) = self.peer_mgr.add_client_tunnel(ret.unwrap()).await {
tracing::warn!(?e, "sym_to_cone add client tunnel failed");
continue;
}
break;
}
Ok(())
}
#[tracing::instrument(skip(self))]
async fn both_easy_sym(self: Arc<Self>, task_info: PunchTaskInfo) -> Result<(), Error> {
let mut backoff = BackOff::new(vec![0, 1000, 2000, 4000, 4000, 8000, 8000, 16000, 64000]);
loop {
backoff.sleep_for_next_backoff().await;
let mut is_busy = false;
let ret = {
let _lock = self.sym_punch_lock.lock().await;
self.both_easy_sym_client
.do_hole_punching(
task_info.dst_peer_id,
task_info.my_nat_type,
task_info.dst_nat_type,
&mut is_busy,
)
.await
};
if is_busy {
backoff.rollback();
}
if let Err(e) = ret {
tracing::info!(?e, "both_easy_sym hole punching failed");
continue;
}
if let Err(e) = self.peer_mgr.add_client_tunnel(ret.unwrap()).await {
tracing::warn!(?e, "both_easy_sym add client tunnel failed");
continue;
}
break;
}
Ok(())
}
}
#[derive(Clone)]
struct UdpHolePunchPeerTaskLauncher {}
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
struct PunchTaskInfo {
dst_peer_id: PeerId,
dst_nat_type: UdpNatType,
my_nat_type: UdpNatType,
}
#[async_trait::async_trait]
impl PeerTaskLauncher for UdpHolePunchPeerTaskLauncher {
type Data = Arc<UdpHoePunchConnectorData>;
type CollectPeerItem = PunchTaskInfo;
type TaskRet = ();
fn new_data(&self, peer_mgr: Arc<PeerManager>) -> Self::Data {
UdpHoePunchConnectorData::new(peer_mgr)
}
async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec<Self::CollectPeerItem> {
let my_nat_type = data
.peer_mgr
.get_global_ctx()
.get_stun_info_collector()
.get_stun_info()
.udp_nat_type;
let my_nat_type: UdpNatType = NatType::try_from(my_nat_type)
.unwrap_or(NatType::Unknown)
.into();
if !my_nat_type.is_sym() {
data.sym_to_cone_client.clear_udp_array().await;
}
let mut peers_to_connect: Vec<Self::CollectPeerItem> = Vec::new();
// do not do anything if:
// 1. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us
// notice that if we are unknown, we treat ourselves as cone
if my_nat_type.is_open() {
return peers_to_connect;
}
// collect peer list from peer manager and do some filter:
// 1. peers without direct conns;
// 2. peers is full cone (any restricted type);
for route in data.peer_mgr.list_routes().await.iter() {
if route
.feature_flag
.map(|x| x.is_public_server)
.unwrap_or(false)
{
continue;
}
let peer_nat_type = route
.stun_info
.as_ref()
.map(|x| x.udp_nat_type)
.unwrap_or(0);
let Ok(peer_nat_type) = NatType::try_from(peer_nat_type) else {
continue;
};
let peer_nat_type = peer_nat_type.into();
let peer_id: PeerId = route.peer_id;
let conns = data.peer_mgr.list_peer_conns(peer_id).await;
if conns.is_some() && conns.unwrap().len() > 0 {
continue;
}
if !my_nat_type.can_punch_hole_as_client(peer_nat_type) {
continue;
}
tracing::info!(
?peer_id,
?peer_nat_type,
?my_nat_type,
"found peer to do hole punching"
);
peers_to_connect.push(PunchTaskInfo {
dst_peer_id: peer_id,
dst_nat_type: peer_nat_type,
my_nat_type,
});
}
peers_to_connect
}
async fn launch_task(
&self,
data: &Self::Data,
item: Self::CollectPeerItem,
) -> JoinHandle<Result<Self::TaskRet, Error>> {
let data = data.clone();
let punch_method = item.my_nat_type.get_punch_hole_method(item.dst_nat_type);
match punch_method {
UdpPunchClientMethod::ConeToCone => tokio::spawn(data.cone_to_cone(item)),
UdpPunchClientMethod::SymToCone => tokio::spawn(data.sym_to_cone(item)),
UdpPunchClientMethod::EasySymToEasySym => tokio::spawn(data.both_easy_sym(item)),
_ => unreachable!(),
}
}
async fn all_task_done(&self, data: &Self::Data) {
data.sym_to_cone_client.clear_udp_array().await;
}
fn loop_interval_ms(&self) -> u64 {
5000
}
}
pub struct UdpHolePunchConnector {
server: Arc<UdpHolePunchServer>,
client: PeerTaskManager<UdpHolePunchPeerTaskLauncher>,
peer_mgr: Arc<PeerManager>,
}
// Currently support:
// Symmetric -> Full Cone
// Any Type of Full Cone -> Any Type of Full Cone
// if same level of full cone, node with smaller peer_id will be the initiator
// if different level of full cone, node with more strict level will be the initiator
impl UdpHolePunchConnector {
pub fn new(peer_mgr: Arc<PeerManager>) -> Self {
Self {
server: UdpHolePunchServer::new(peer_mgr.clone()),
client: PeerTaskManager::new(UdpHolePunchPeerTaskLauncher {}, peer_mgr.clone()),
peer_mgr,
}
}
pub async fn run_as_client(&mut self) -> Result<(), Error> {
self.client.start();
Ok(())
}
pub async fn run_as_server(&mut self) -> Result<(), Error> {
self.peer_mgr
.get_peer_rpc_mgr()
.rpc_server()
.registry()
.register(
UdpHolePunchRpcServer::new(self.server.clone()),
&self.peer_mgr.get_global_ctx().get_network_name(),
);
Ok(())
}
pub async fn run(&mut self) -> Result<(), Error> {
let global_ctx = self.peer_mgr.get_global_ctx();
if global_ctx.get_flags().disable_p2p {
return Ok(());
}
if global_ctx.get_flags().disable_udp_hole_punching {
return Ok(());
}
self.run_as_client().await?;
self.run_as_server().await?;
Ok(())
}
}
#[cfg(test)]
pub mod tests {
use std::sync::Arc;
use crate::common::stun::MockStunInfoCollector;
use crate::proto::common::NatType;
use crate::peers::{peer_manager::PeerManager, tests::create_mock_peer_manager};
pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) {
let collector = Box::new(MockStunInfoCollector { udp_nat_type });
peer_mgr
.get_global_ctx()
.replace_stun_info_collector(collector);
}
pub async fn create_mock_peer_manager_with_mock_stun(
udp_nat_type: NatType,
) -> Arc<PeerManager> {
let p_a = create_mock_peer_manager().await;
replace_stun_info_collector(p_a.clone(), udp_nat_type);
p_a
}
}

View File

@ -0,0 +1,591 @@
use std::{
net::Ipv4Addr,
ops::{Div, Mul},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
time::{Duration, Instant},
};
use anyhow::Context;
use rand::{seq::SliceRandom, Rng};
use tokio::{net::UdpSocket, sync::RwLock};
use tracing::Level;
use crate::{
common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId},
connector::udp_hole_punch::common::{
send_symmetric_hole_punch_packet, try_connect_with_socket, HOLE_PUNCH_PACKET_BODY_LEN,
},
defer,
peers::peer_manager::PeerManager,
proto::{
peer_rpc::{
SelectPunchListenerRequest, SendPunchPacketEasySymRequest,
SendPunchPacketHardSymRequest, SendPunchPacketHardSymResponse,
UdpHolePunchRpcClientFactory,
},
rpc_types::{self, controller::BaseController},
},
tunnel::{udp::new_hole_punch_packet, Tunnel},
};
use super::common::{PunchHoleServerCommon, UdpNatType, UdpSocketArray};
const UDP_ARRAY_SIZE_FOR_HARD_SYM: usize = 84;
pub(crate) struct PunchSymToConeHoleServer {
common: Arc<PunchHoleServerCommon>,
shuffled_port_vec: Arc<Vec<u16>>,
}
impl PunchSymToConeHoleServer {
pub(crate) fn new(common: Arc<PunchHoleServerCommon>) -> Self {
let mut shuffled_port_vec: Vec<u16> = (1..=65535).collect();
shuffled_port_vec.shuffle(&mut rand::thread_rng());
Self {
common,
shuffled_port_vec: Arc::new(shuffled_port_vec),
}
}
// hard sym means public port is random and cannot be predicted
#[tracing::instrument(skip(self), ret)]
pub(crate) async fn send_punch_packet_easy_sym(
&self,
request: SendPunchPacketEasySymRequest,
) -> Result<(), rpc_types::error::Error> {
tracing::info!("send_punch_packet_easy_sym start");
let listener_addr = request.listener_mapped_addr.ok_or(anyhow::anyhow!(
"send_punch_packet_easy_sym request missing listener_addr"
))?;
let listener_addr = std::net::SocketAddr::from(listener_addr);
let listener = self
.common
.find_listener(&listener_addr)
.await
.ok_or(anyhow::anyhow!(
"send_punch_packet_easy_sym failed to find listener"
))?;
let public_ips = request
.public_ips
.into_iter()
.map(|ip| std::net::Ipv4Addr::from(ip))
.collect::<Vec<_>>();
if public_ips.len() == 0 {
tracing::warn!("send_punch_packet_easy_sym got zero len public ip");
return Err(
anyhow::anyhow!("send_punch_packet_easy_sym got zero len public ip").into(),
);
}
let transaction_id = request.transaction_id;
let base_port_num = request.base_port_num;
let max_port_num = request.max_port_num.max(1);
let is_incremental = request.is_incremental;
let port_start = if is_incremental {
base_port_num.saturating_add(1)
} else {
base_port_num.saturating_sub(max_port_num)
};
let port_end = if is_incremental {
base_port_num.saturating_add(max_port_num)
} else {
base_port_num.saturating_sub(1)
};
if port_end <= port_start {
return Err(anyhow::anyhow!("send_punch_packet_easy_sym invalid port range").into());
}
let ports = (port_start..=port_end)
.map(|x| x as u16)
.collect::<Vec<_>>();
tracing::debug!(
?ports,
?public_ips,
"send_punch_packet_easy_sym send to ports"
);
send_symmetric_hole_punch_packet(
&ports,
listener,
transaction_id,
&public_ips,
0,
ports.len(),
)
.await
.with_context(|| "failed to send symmetric hole punch packet")?;
Ok(())
}
// hard sym means public port is random and cannot be predicted
#[tracing::instrument(skip(self))]
pub(crate) async fn send_punch_packet_hard_sym(
&self,
request: SendPunchPacketHardSymRequest,
) -> Result<SendPunchPacketHardSymResponse, rpc_types::error::Error> {
tracing::info!("try_punch_symmetric start");
let listener_addr = request.listener_mapped_addr.ok_or(anyhow::anyhow!(
"try_punch_symmetric request missing listener_addr"
))?;
let listener_addr = std::net::SocketAddr::from(listener_addr);
let listener = self
.common
.find_listener(&listener_addr)
.await
.ok_or(anyhow::anyhow!(
"send_punch_packet_for_cone failed to find listener"
))?;
let public_ips = request
.public_ips
.into_iter()
.map(|ip| std::net::Ipv4Addr::from(ip))
.collect::<Vec<_>>();
if public_ips.len() == 0 {
tracing::warn!("try_punch_symmetric got zero len public ip");
return Err(anyhow::anyhow!("try_punch_symmetric got zero len public ip").into());
}
let transaction_id = request.transaction_id;
let last_port_index = request.port_index as usize;
let round = std::cmp::max(request.round, 1);
// send max k1 packets if we are predicting the dst port
let max_k1: u32 = 180;
// send max k2 packets if we are sending to random port
let mut max_k2: u32 = rand::thread_rng().gen_range(600..800);
if round > 2 {
max_k2 = max_k2.mul(2).div(round).max(max_k1);
}
let next_port_index = send_symmetric_hole_punch_packet(
&self.shuffled_port_vec,
listener.clone(),
transaction_id,
&public_ips,
last_port_index,
max_k2 as usize,
)
.await
.with_context(|| "failed to send symmetric hole punch packet randomly")?;
return Ok(SendPunchPacketHardSymResponse {
next_port_index: next_port_index as u32,
});
}
}
pub(crate) struct PunchSymToConeHoleClient {
peer_mgr: Arc<PeerManager>,
udp_array: RwLock<Option<Arc<UdpSocketArray>>>,
try_direct_connect: AtomicBool,
punch_predicablely: AtomicBool,
punch_randomly: AtomicBool,
}
impl PunchSymToConeHoleClient {
pub(crate) fn new(peer_mgr: Arc<PeerManager>) -> Self {
Self {
peer_mgr,
udp_array: RwLock::new(None),
try_direct_connect: AtomicBool::new(true),
punch_predicablely: AtomicBool::new(true),
punch_randomly: AtomicBool::new(true),
}
}
async fn prepare_udp_array(&self) -> Result<Arc<UdpSocketArray>, anyhow::Error> {
let rlocked = self.udp_array.read().await;
if let Some(udp_array) = rlocked.clone() {
return Ok(udp_array);
}
drop(rlocked);
let mut wlocked = self.udp_array.write().await;
if let Some(udp_array) = wlocked.clone() {
return Ok(udp_array);
}
let udp_array = Arc::new(UdpSocketArray::new(
UDP_ARRAY_SIZE_FOR_HARD_SYM,
self.peer_mgr.get_global_ctx().net_ns.clone(),
));
udp_array.start().await?;
wlocked.replace(udp_array.clone());
Ok(udp_array)
}
pub(crate) async fn clear_udp_array(&self) {
let mut wlocked = self.udp_array.write().await;
wlocked.take();
}
async fn get_base_port_for_easy_sym(&self, my_nat_info: UdpNatType) -> Option<u16> {
let global_ctx = self.peer_mgr.get_global_ctx();
if my_nat_info.is_easy_sym() {
match global_ctx
.get_stun_info_collector()
.get_udp_port_mapping(0)
.await
{
Ok(addr) => Some(addr.port()),
ret => {
tracing::warn!(?ret, "failed to get udp port mapping for easy sym");
None
}
}
} else {
None
}
}
#[tracing::instrument(err(level = Level::ERROR), skip(self))]
pub(crate) async fn do_hole_punching(
&self,
dst_peer_id: PeerId,
round: u32,
last_port_idx: &mut usize,
my_nat_info: UdpNatType,
) -> Result<Box<dyn Tunnel>, anyhow::Error> {
let udp_array = self.prepare_udp_array().await?;
let global_ctx = self.peer_mgr.get_global_ctx();
let rpc_stub = self
.peer_mgr
.get_peer_rpc_mgr()
.rpc_client()
.scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
self.peer_mgr.my_peer_id(),
dst_peer_id,
global_ctx.get_network_name(),
);
let resp = rpc_stub
.select_punch_listener(
BaseController::default(),
SelectPunchListenerRequest { force_new: false },
)
.await
.with_context(|| "failed to select punch listener")?;
let remote_mapped_addr = resp.listener_mapped_addr.ok_or(anyhow::anyhow!(
"select_punch_listener response missing listener_mapped_addr"
))?;
// try direct connect first
if self.try_direct_connect.load(Ordering::Relaxed) {
if let Ok(tunnel) = try_connect_with_socket(
Arc::new(UdpSocket::bind("0.0.0.0:0").await?),
remote_mapped_addr.into(),
)
.await
{
return Ok(tunnel);
}
}
let stun_info = global_ctx.get_stun_info_collector().get_stun_info();
let public_ips: Vec<Ipv4Addr> = stun_info
.public_ip
.iter()
.map(|x| x.parse().unwrap())
.collect();
if public_ips.is_empty() {
return Err(anyhow::anyhow!("failed to get public ips"));
}
let tid = rand::thread_rng().gen();
let packet = new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes();
udp_array.add_intreast_tid(tid);
defer! { udp_array.remove_intreast_tid(tid);}
udp_array
.send_with_all(&packet, remote_mapped_addr.into())
.await?;
let port_index = *last_port_idx as u32;
let base_port_for_easy_sym = self.get_base_port_for_easy_sym(my_nat_info).await;
let punch_random = self.punch_randomly.load(Ordering::Relaxed);
let punch_predicable = self.punch_predicablely.load(Ordering::Relaxed);
let scoped_punch_task: ScopedTask<Option<u32>> = tokio::spawn(async move {
if punch_predicable {
if let Some(inc) = my_nat_info.get_inc_of_easy_sym() {
let req = SendPunchPacketEasySymRequest {
listener_mapped_addr: remote_mapped_addr.clone().into(),
public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(),
transaction_id: tid,
base_port_num: base_port_for_easy_sym.unwrap() as u32,
max_port_num: 50,
is_incremental: inc,
};
tracing::debug!(?req, "send punch packet for easy sym start");
let ret = rpc_stub
.send_punch_packet_easy_sym(
BaseController {
timeout_ms: 4000,
trace_id: 0,
},
req,
)
.await;
tracing::debug!(?ret, "send punch packet for easy sym return");
}
}
if punch_random {
let req = SendPunchPacketHardSymRequest {
listener_mapped_addr: remote_mapped_addr.clone().into(),
public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(),
transaction_id: tid,
round,
port_index,
};
tracing::debug!(?req, "send punch packet for hard sym start");
match rpc_stub
.send_punch_packet_hard_sym(
BaseController {
timeout_ms: 4000,
trace_id: 0,
},
req,
)
.await
{
Err(e) => {
tracing::error!(?e, "failed to send punch packet for hard sym");
return None;
}
Ok(resp) => return Some(resp.next_port_index),
}
}
None
})
.into();
// no matter what the result is, we should check if we received any hole punching packet
let mut ret_tunnel: Option<Box<dyn Tunnel>> = None;
let mut finish_time: Option<Instant> = None;
while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 {
tokio::time::sleep(Duration::from_millis(200)).await;
if finish_time.is_none() && (*scoped_punch_task).is_finished() {
finish_time = Some(Instant::now());
}
let Some(socket) = udp_array.try_fetch_punched_socket(tid) else {
tracing::debug!("no punched socket found, wait for more time");
continue;
};
// if hole punched but tunnel creation failed, need to retry entire process.
match try_connect_with_socket(socket.socket.clone(), remote_mapped_addr.into()).await {
Ok(tunnel) => {
ret_tunnel.replace(tunnel);
break;
}
Err(e) => {
tracing::error!(?e, "failed to connect with socket");
udp_array.add_new_socket(socket.socket).await?;
continue;
}
}
}
let punch_task_result = scoped_punch_task.await;
tracing::debug!(?punch_task_result, ?ret_tunnel, "punch task got result");
if let Ok(Some(next_port_idx)) = punch_task_result {
*last_port_idx = next_port_idx as usize;
} else {
*last_port_idx = rand::random();
}
if let Some(tunnel) = ret_tunnel {
Ok(tunnel)
} else {
anyhow::bail!(
"failed to hole punch, punch task result: {:?}",
punch_task_result
)
}
}
}
#[cfg(test)]
pub mod tests {
use std::{
sync::{atomic::AtomicU32, Arc},
time::Duration,
};
use tokio::net::UdpSocket;
use crate::{
connector::udp_hole_punch::{
tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector,
},
peers::tests::{connect_peer_manager, wait_route_appear, wait_route_appear_with_cost},
proto::common::NatType,
tunnel::common::tests::wait_for_condition,
};
#[tokio::test]
async fn hole_punching_symmetric_only_random() {
let p_a = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await;
let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
connect_peer_manager(p_a.clone(), p_b.clone()).await;
connect_peer_manager(p_b.clone(), p_c.clone()).await;
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone());
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone());
hole_punching_a
.client
.data()
.sym_to_cone_client
.try_direct_connect
.store(false, std::sync::atomic::Ordering::Relaxed);
hole_punching_a
.client
.data()
.sym_to_cone_client
.punch_predicablely
.store(false, std::sync::atomic::Ordering::Relaxed);
hole_punching_a.run().await.unwrap();
hole_punching_c.run().await.unwrap();
hole_punching_a.client.run_immediately().await;
wait_for_condition(
|| async {
hole_punching_a
.client
.data()
.sym_to_cone_client
.udp_array
.read()
.await
.is_some()
},
Duration::from_secs(5),
)
.await;
wait_for_condition(
|| async {
wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
.await
.is_ok()
},
Duration::from_secs(5),
)
.await;
println!("{:?}", p_a.list_routes().await);
wait_for_condition(
|| async {
hole_punching_a
.client
.data()
.sym_to_cone_client
.udp_array
.read()
.await
.is_none()
},
Duration::from_secs(10),
)
.await;
}
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial(hole_punch)]
async fn hole_punching_symmetric_only_predict(#[values("true", "false")] is_inc: bool) {
let p_a = create_mock_peer_manager_with_mock_stun(if is_inc {
NatType::SymmetricEasyInc
} else {
NatType::SymmetricEasyDec
})
.await;
let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
connect_peer_manager(p_a.clone(), p_b.clone()).await;
connect_peer_manager(p_b.clone(), p_c.clone()).await;
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone());
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone());
hole_punching_a
.client
.data()
.sym_to_cone_client
.try_direct_connect
.store(false, std::sync::atomic::Ordering::Relaxed);
hole_punching_a
.client
.data()
.sym_to_cone_client
.punch_randomly
.store(false, std::sync::atomic::Ordering::Relaxed);
hole_punching_a.run().await.unwrap();
hole_punching_c.run().await.unwrap();
let udps = if is_inc {
let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40147").await.unwrap());
let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40194").await.unwrap());
vec![udp1, udp2]
} else {
let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40141").await.unwrap());
let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40100").await.unwrap());
vec![udp1, udp2]
};
// let udp_dec = Arc::new(UdpSocket::bind("0.0.0.0:40140").await.unwrap());
// let udp_dec2 = Arc::new(UdpSocket::bind("0.0.0.0:40050").await.unwrap());
let counter = Arc::new(AtomicU32::new(0));
// all these sockets should receive hole punching packet
for udp in udps.iter().map(Arc::clone) {
let counter = counter.clone();
tokio::spawn(async move {
let mut buf = [0u8; 1024];
let (len, addr) = udp.recv_from(&mut buf).await.unwrap();
println!(
"got predictable punch packet, {:?} {:?} {:?}",
len,
addr,
udp.local_addr()
);
counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
});
}
hole_punching_a.client.run_immediately().await;
let udp_len = udps.len();
wait_for_condition(
|| async { counter.load(std::sync::atomic::Ordering::Relaxed) == udp_len as u32 },
Duration::from_secs(30),
)
.await;
}
}

View File

@ -179,14 +179,16 @@ impl CommandHandler {
async fn list_peers(&self) -> Result<ListPeerResponse, Error> {
let client = self.get_peer_manager_client().await?;
let request = ListPeerRequest::default();
let response = client.list_peer(BaseController {}, request).await?;
let response = client.list_peer(BaseController::default(), request).await?;
Ok(response)
}
async fn list_routes(&self) -> Result<ListRouteResponse, Error> {
let client = self.get_peer_manager_client().await?;
let request = ListRouteRequest::default();
let response = client.list_route(BaseController {}, request).await?;
let response = client
.list_route(BaseController::default(), request)
.await?;
Ok(response)
}
@ -275,7 +277,7 @@ impl CommandHandler {
let client = self.get_peer_manager_client().await?;
let node_info = client
.show_node_info(BaseController {}, ShowNodeInfoRequest::default())
.show_node_info(BaseController::default(), ShowNodeInfoRequest::default())
.await?
.node_info
.ok_or(anyhow::anyhow!("node info not found"))?;
@ -296,7 +298,9 @@ impl CommandHandler {
async fn handle_route_dump(&self) -> Result<(), Error> {
let client = self.get_peer_manager_client().await?;
let request = DumpRouteRequest::default();
let response = client.dump_route(BaseController {}, request).await?;
let response = client
.dump_route(BaseController::default(), request)
.await?;
println!("response: {}", response.result);
Ok(())
}
@ -305,7 +309,7 @@ impl CommandHandler {
let client = self.get_peer_manager_client().await?;
let request = ListForeignNetworkRequest::default();
let response = client
.list_foreign_network(BaseController {}, request)
.list_foreign_network(BaseController::default(), request)
.await?;
let network_map = response;
if self.verbose {
@ -347,7 +351,7 @@ impl CommandHandler {
let client = self.get_peer_manager_client().await?;
let request = ListGlobalForeignNetworkRequest::default();
let response = client
.list_global_foreign_network(BaseController {}, request)
.list_global_foreign_network(BaseController::default(), request)
.await?;
if self.verbose {
println!("{:#?}", response);
@ -383,7 +387,7 @@ impl CommandHandler {
let mut items: Vec<RouteTableItem> = vec![];
let client = self.get_peer_manager_client().await?;
let node_info = client
.show_node_info(BaseController {}, ShowNodeInfoRequest::default())
.show_node_info(BaseController::default(), ShowNodeInfoRequest::default())
.await?
.node_info
.ok_or(anyhow::anyhow!("node info not found"))?;
@ -451,7 +455,9 @@ impl CommandHandler {
async fn handle_connector_list(&self) -> Result<(), Error> {
let client = self.get_connector_manager_client().await?;
let request = ListConnectorRequest::default();
let response = client.list_connector(BaseController {}, request).await?;
let response = client
.list_connector(BaseController::default(), request)
.await?;
println!("response: {:#?}", response);
Ok(())
}
@ -515,7 +521,7 @@ async fn main() -> Result<(), Error> {
Some(RouteSubCommand::Dump) => handler.handle_route_dump().await?,
},
SubCommand::Stun => {
timeout(Duration::from_secs(5), async move {
timeout(Duration::from_secs(25), async move {
let collector = StunInfoCollector::new_with_default_servers();
loop {
let ret = collector.get_stun_info();
@ -532,7 +538,10 @@ async fn main() -> Result<(), Error> {
SubCommand::PeerCenter => {
let peer_center_client = handler.get_peer_center_client().await?;
let resp = peer_center_client
.get_global_peer_map(BaseController {}, GetGlobalPeerMapRequest::default())
.get_global_peer_map(
BaseController::default(),
GetGlobalPeerMapRequest::default(),
)
.await?;
#[derive(tabled::Tabled)]
@ -565,7 +574,10 @@ async fn main() -> Result<(), Error> {
SubCommand::VpnPortal => {
let vpn_portal_client = handler.get_vpn_portal_client().await?;
let resp = vpn_portal_client
.get_vpn_portal_info(BaseController {}, GetVpnPortalInfoRequest::default())
.get_vpn_portal_info(
BaseController::default(),
GetVpnPortalInfoRequest::default(),
)
.await?
.vpn_portal_info
.unwrap_or_default();
@ -583,7 +595,7 @@ async fn main() -> Result<(), Error> {
SubCommand::Node(sub_cmd) => {
let client = handler.get_peer_manager_client().await?;
let node_info = client
.show_node_info(BaseController {}, ShowNodeInfoRequest::default())
.show_node_info(BaseController::default(), ShowNodeInfoRequest::default())
.await?
.node_info
.ok_or(anyhow::anyhow!("node info not found"))?;

View File

@ -161,7 +161,7 @@ impl Instance {
DirectConnectorManager::new(global_ctx.clone(), peer_manager.clone());
direct_conn_manager.run();
let udp_hole_puncher = UdpHolePunchConnector::new(global_ctx.clone(), peer_manager.clone());
let udp_hole_puncher = UdpHolePunchConnector::new(peer_manager.clone());
let peer_center = Arc::new(PeerCenterInstance::new(peer_manager.clone()));

View File

@ -230,7 +230,7 @@ impl PeerCenterInstance {
let ret = client
.get_global_peer_map(
BaseController {},
BaseController::default(),
GetGlobalPeerMapRequest {
digest: ctx.job_ctx.global_peer_map_digest.load(),
},
@ -307,7 +307,7 @@ impl PeerCenterInstance {
let ret = client
.report_peers(
BaseController {},
BaseController::default(),
ReportPeersRequest {
my_peer_id: my_node_id,
peer_infos: Some(peers),

View File

@ -15,6 +15,8 @@ pub mod foreign_network_manager;
pub mod encrypt;
pub mod peer_task;
#[cfg(test)]
pub mod tests;

View File

@ -1058,7 +1058,7 @@ mod tests {
let ret = stub
.say_hello(
RpcController {},
RpcController::default(),
SayHelloRequest {
name: "abc".to_string(),
},

View File

@ -539,7 +539,7 @@ impl RouteTable {
fn get_nat_type(&self, peer_id: PeerId) -> Option<NatType> {
self.peer_infos
.get(&peer_id)
.map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap())
.map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap_or_default())
}
fn build_peer_graph_from_synced_info<T: RouteCostCalculatorInterface>(
@ -1322,7 +1322,7 @@ impl PeerRouteServiceImpl {
self.global_ctx.get_network_name(),
);
let mut ctrl = BaseController {};
let mut ctrl = BaseController::default();
ctrl.set_timeout_ms(3000);
let ret = rpc_stub
.sync_route_info(

View File

@ -224,7 +224,10 @@ pub mod tests {
let msg = random_string(8192);
let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.say_hello(
RpcController::default(),
SayHelloRequest { name: msg.clone() },
)
.await
.unwrap();
@ -233,7 +236,10 @@ pub mod tests {
let msg = random_string(10);
let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.say_hello(
RpcController::default(),
SayHelloRequest { name: msg.clone() },
)
.await
.unwrap();
@ -281,7 +287,10 @@ pub mod tests {
);
let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.say_hello(
RpcController::default(),
SayHelloRequest { name: msg.clone() },
)
.await
.unwrap();
assert_eq!(ret.greeting, format!("Hello {}!", msg));
@ -289,14 +298,20 @@ pub mod tests {
// call again
let msg = random_string(16 * 1024);
let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.say_hello(
RpcController::default(),
SayHelloRequest { name: msg.clone() },
)
.await
.unwrap();
assert_eq!(ret.greeting, format!("Hello {}!", msg));
let msg = random_string(16 * 1024);
let ret = stub
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.say_hello(
RpcController::default(),
SayHelloRequest { name: msg.clone() },
)
.await
.unwrap();
assert_eq!(ret.greeting, format!("Hello {}!", msg));
@ -340,13 +355,19 @@ pub mod tests {
let msg = random_string(16 * 1024);
let ret = stub1
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.say_hello(
RpcController::default(),
SayHelloRequest { name: msg.clone() },
)
.await
.unwrap();
assert_eq!(ret.greeting, format!("Hello {}!", msg));
let ret = stub2
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
.say_hello(
RpcController::default(),
SayHelloRequest { name: msg.clone() },
)
.await;
assert!(ret.is_err() && ret.unwrap_err().to_string().contains("Timeout"));
}

View File

@ -0,0 +1,138 @@
use std::result::Result;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::select;
use tokio::sync::Notify;
use tokio::task::JoinHandle;
use crate::common::scoped_task::ScopedTask;
use anyhow::Error;
use super::peer_manager::PeerManager;
#[async_trait]
pub trait PeerTaskLauncher: Send + Sync + Clone + 'static {
type Data;
type CollectPeerItem;
type TaskRet;
fn new_data(&self, peer_mgr: Arc<PeerManager>) -> Self::Data;
async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec<Self::CollectPeerItem>;
async fn launch_task(
&self,
data: &Self::Data,
item: Self::CollectPeerItem,
) -> JoinHandle<Result<Self::TaskRet, Error>>;
async fn all_task_done(&self, _data: &Self::Data) {}
fn loop_interval_ms(&self) -> u64 {
5000
}
}
pub struct PeerTaskManager<Launcher: PeerTaskLauncher> {
launcher: Launcher,
peer_mgr: Arc<PeerManager>,
main_loop_task: Mutex<Option<ScopedTask<()>>>,
run_signal: Arc<Notify>,
data: Launcher::Data,
}
impl<D, C, T, L> PeerTaskManager<L>
where
D: Send + Sync + Clone + 'static,
C: std::fmt::Debug + Send + Sync + Clone + core::hash::Hash + Eq + 'static,
T: Send + 'static,
L: PeerTaskLauncher<Data = D, CollectPeerItem = C, TaskRet = T> + 'static,
{
pub fn new(launcher: L, peer_mgr: Arc<PeerManager>) -> Self {
let data = launcher.new_data(peer_mgr.clone());
Self {
launcher,
peer_mgr,
main_loop_task: Mutex::new(None),
run_signal: Arc::new(Notify::new()),
data,
}
}
pub fn start(&self) {
let task = tokio::spawn(Self::main_loop(
self.launcher.clone(),
self.data.clone(),
self.run_signal.clone(),
))
.into();
self.main_loop_task.lock().unwrap().replace(task);
}
async fn main_loop(launcher: L, data: D, signal: Arc<Notify>) {
let peer_task_map = Arc::new(DashMap::<C, ScopedTask<Result<T, Error>>>::new());
loop {
let peers_to_connect = launcher.collect_peers_need_task(&data).await;
// remove task not in peers_to_connect
let mut to_remove = vec![];
for item in peer_task_map.iter() {
if !peers_to_connect.contains(item.key()) || item.value().is_finished() {
to_remove.push(item.key().clone());
}
}
tracing::debug!(
?peers_to_connect,
?to_remove,
"got peers to connect and remove"
);
for key in to_remove {
if let Some((_, task)) = peer_task_map.remove(&key) {
task.abort();
match task.await {
Ok(Ok(_)) => {}
Ok(Err(task_ret)) => {
tracing::error!(?task_ret, "hole punching task failed");
}
Err(e) => {
tracing::error!(?e, "hole punching task aborted");
}
}
}
}
if !peers_to_connect.is_empty() {
for item in peers_to_connect {
if peer_task_map.contains_key(&item) {
continue;
}
tracing::debug!(?item, "launch hole punching task");
peer_task_map
.insert(item.clone(), launcher.launch_task(&data, item).await.into());
}
} else if peer_task_map.is_empty() {
tracing::debug!("all task done");
launcher.all_task_done(&data).await;
}
select! {
_ = tokio::time::sleep(std::time::Duration::from_millis(
launcher.loop_interval_ms(),
)) => {},
_ = signal.notified() => {}
}
}
}
pub async fn run_immediately(&self) {
self.run_signal.notify_one();
}
pub fn data(&self) -> D {
self.data.clone()
}
}

View File

@ -42,6 +42,8 @@ message RpcPacket {
int32 trace_id = 9;
}
message Void {}
message UUID {
uint64 high = 1;
uint64 low = 2;
@ -57,6 +59,8 @@ enum NatType {
PortRestricted = 5;
Symmetric = 6;
SymUdpFirewall = 7;
SymmetricEasyInc = 8;
SymmetricEasyDec = 9;
}
message Ipv4Addr { uint32 addr = 1; }

View File

@ -93,27 +93,78 @@ service DirectConnectorRpc {
rpc GetIpList(GetIpListRequest) returns (GetIpListResponse);
}
message TryPunchHoleRequest { common.SocketAddr local_mapped_addr = 1; }
message TryPunchHoleResponse { common.SocketAddr remote_mapped_addr = 1; }
message TryPunchSymmetricRequest {
common.SocketAddr listener_addr = 1;
uint32 port = 2;
repeated common.Ipv4Addr public_ips = 3;
uint32 min_port = 4;
uint32 max_port = 5;
uint32 transaction_id = 6;
uint32 round = 7;
uint32 last_port_index = 8;
message SelectPunchListenerRequest {
bool force_new = 1;
}
message TryPunchSymmetricResponse { uint32 last_port_index = 1; }
message SelectPunchListenerResponse {
common.SocketAddr listener_mapped_addr = 1;
}
message SendPunchPacketConeRequest {
common.SocketAddr listener_mapped_addr = 1;
common.SocketAddr dest_addr = 2;
uint32 transaction_id = 3;
// send this many packets in a batch
uint32 packet_count_per_batch = 4;
// send total this batch count, total packet count = packet_batch_size * packet_batch_count
uint32 packet_batch_count = 5;
// interval between each batch
uint32 packet_interval_ms = 6;
}
message SendPunchPacketHardSymRequest {
common.SocketAddr listener_mapped_addr = 1;
repeated common.Ipv4Addr public_ips = 2;
uint32 transaction_id = 3;
uint32 port_index = 4;
uint32 round = 5;
}
message SendPunchPacketHardSymResponse { uint32 next_port_index = 1; }
message SendPunchPacketEasySymRequest {
common.SocketAddr listener_mapped_addr = 1;
repeated common.Ipv4Addr public_ips = 2;
uint32 transaction_id = 3;
uint32 base_port_num = 4;
uint32 max_port_num = 5;
bool is_incremental = 6;
}
message SendPunchPacketBothEasySymRequest {
uint32 udp_socket_count = 1;
common.Ipv4Addr public_ip = 2;
uint32 transaction_id = 3;
uint32 dst_port_num = 4;
uint32 wait_time_ms = 5;
}
message SendPunchPacketBothEasySymResponse {
// is doing punch with other peer
bool is_busy = 1;
common.SocketAddr base_mapped_addr = 2;
}
service UdpHolePunchRpc {
rpc TryPunchHole(TryPunchHoleRequest) returns (TryPunchHoleResponse);
rpc TryPunchSymmetric(TryPunchSymmetricRequest)
returns (TryPunchSymmetricResponse);
rpc SelectPunchListener(SelectPunchListenerRequest)
returns (SelectPunchListenerResponse);
// send packet to one remote_addr, used by nat1-3 to nat1-3
rpc SendPunchPacketCone(SendPunchPacketConeRequest) returns (common.Void);
// send packet to multiple remote_addr (birthday attack), used by nat4 to nat1-3
rpc SendPunchPacketHardSym(SendPunchPacketHardSymRequest)
returns (SendPunchPacketHardSymResponse);
rpc SendPunchPacketEasySym(SendPunchPacketEasySymRequest)
returns (common.Void);
// nat4 to nat4 (both predictably)
rpc SendPunchPacketBothEasySym(SendPunchPacketBothEasySymRequest)
returns (SendPunchPacketBothEasySymResponse);
}
message DirectConnectedPeerInfo { int32 latency_ms = 1; }

View File

@ -146,7 +146,7 @@ impl Server {
async fn handle_rpc_request(packet: RpcPacket, reg: Arc<ServiceRegistry>) -> Result<Bytes> {
let rpc_request = RpcRequest::decode(Bytes::from(packet.body))?;
let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64);
let ctrl = RpcController {};
let ctrl = RpcController::default();
Ok(timeout(
timeout_duration,
reg.call_method(

View File

@ -13,6 +13,34 @@ pub trait Controller: Send + Sync + 'static {
}
#[derive(Debug)]
pub struct BaseController {}
pub struct BaseController {
pub timeout_ms: i32,
pub trace_id: i32,
}
impl Controller for BaseController {}
impl Controller for BaseController {
fn timeout_ms(&self) -> i32 {
self.timeout_ms
}
fn set_timeout_ms(&mut self, timeout_ms: i32) {
self.timeout_ms = timeout_ms;
}
fn set_trace_id(&mut self, trace_id: i32) {
self.trace_id = trace_id;
}
fn trace_id(&self) -> i32 {
self.trace_id
}
}
impl Default for BaseController {
fn default() -> Self {
Self {
timeout_ms: 5000,
trace_id: 0,
}
}
}

View File

@ -121,14 +121,14 @@ async fn rpc_basic_test() {
// small size req and resp
let ctrl = RpcController {};
let ctrl = RpcController::default();
let input = SayHelloRequest {
name: "world".to_string(),
};
let ret = out.say_hello(ctrl, input).await;
assert_eq!(ret.unwrap().greeting, "Hello world!");
let ctrl = RpcController {};
let ctrl = RpcController::default();
let input = SayGoodbyeRequest {
name: "world".to_string(),
};
@ -136,7 +136,7 @@ async fn rpc_basic_test() {
assert_eq!(ret.unwrap().greeting, "Goodbye, world!");
// large size req and resp
let ctrl = RpcController {};
let ctrl = RpcController::default();
let name = random_string(20 * 1024 * 1024);
let input = SayGoodbyeRequest { name: name.clone() };
let ret = out.say_goodbye(ctrl, input).await;
@ -160,7 +160,7 @@ async fn rpc_timeout_test() {
.client
.scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "test".to_string());
let ctrl = RpcController {};
let ctrl = RpcController::default();
let input = SayHelloRequest {
name: "world".to_string(),
};
@ -199,7 +199,7 @@ async fn standalone_rpc_test() {
.await
.unwrap();
let ctrl = RpcController {};
let ctrl = RpcController::default();
let input = SayHelloRequest {
name: "world".to_string(),
};
@ -211,7 +211,7 @@ async fn standalone_rpc_test() {
.await
.unwrap();
let ctrl = RpcController {};
let ctrl = RpcController::default();
let input = SayGoodbyeRequest {
name: "world".to_string(),
};

View File

@ -94,7 +94,7 @@ pub trait Tunnel: Send {
#[auto_impl::auto_impl(Arc)]
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
fn get(&self) -> u32;
fn get(&self) -> Option<u32>;
}
#[derive(Debug, Clone, Copy, PartialEq)]
@ -114,8 +114,8 @@ pub trait TunnelListener: Send {
#[derive(Debug)]
struct FakeTunnelConnCounter {}
impl TunnelConnCounter for FakeTunnelConnCounter {
fn get(&self) -> u32 {
0
fn get(&self) -> Option<u32> {
None
}
}
Arc::new(Box::new(FakeTunnelConnCounter {}))

View File

@ -43,6 +43,10 @@ impl TunnelListener for TcpTunnelListener {
setup_sokcet2(&socket2_socket, &addr)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
if let Err(e) = socket.set_nodelay(true) {
tracing::warn!(?e, "set_nodelay fail in listen");
}
self.addr
.set_port(Some(socket.local_addr()?.port()))
.unwrap();
@ -54,7 +58,11 @@ impl TunnelListener for TcpTunnelListener {
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let listener = self.listener.as_ref().unwrap();
let (stream, _) = listener.accept().await?;
stream.set_nodelay(true).unwrap();
if let Err(e) = stream.set_nodelay(true) {
tracing::warn!(?e, "set_nodelay fail in accept");
}
let info = TunnelInfo {
tunnel_type: "tcp".to_owned(),
local_addr: Some(self.local_url().into()),
@ -80,7 +88,9 @@ fn get_tunnel_with_tcp_stream(
stream: TcpStream,
remote_url: url::Url,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
stream.set_nodelay(true).unwrap();
if let Err(e) = stream.set_nodelay(true) {
tracing::warn!(?e, "set_nodelay fail in get_tunnel_with_tcp_stream");
}
let info = TunnelInfo {
tunnel_type: "tcp".to_owned(),

View File

@ -1,4 +1,7 @@
use std::{fmt::Debug, sync::Arc};
use std::{
fmt::Debug,
sync::{Arc, Weak},
};
use async_trait::async_trait;
use bytes::BytesMut;
@ -445,25 +448,25 @@ impl TunnelListener for UdpTunnelListener {
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
struct UdpTunnelConnCounter {
sock_map: Arc<DashMap<SocketAddr, UdpConnection>>,
sock_map: Weak<DashMap<SocketAddr, UdpConnection>>,
}
impl TunnelConnCounter for UdpTunnelConnCounter {
fn get(&self) -> Option<u32> {
self.sock_map.upgrade().map(|x| x.len() as u32)
}
}
impl Debug for UdpTunnelConnCounter {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("UdpTunnelConnCounter")
.field("sock_map_len", &self.sock_map.len())
.field("sock_map_len", &self.get())
.finish()
}
}
impl TunnelConnCounter for UdpTunnelConnCounter {
fn get(&self) -> u32 {
self.sock_map.len() as u32
}
}
Arc::new(Box::new(UdpTunnelConnCounter {
sock_map: self.data.sock_map.clone(),
sock_map: Arc::downgrade(&self.data.sock_map.clone()),
}))
}
}
@ -942,14 +945,22 @@ mod tests {
listener.listen().await.unwrap();
let c1 = listener.accept().await.unwrap();
assert_eq!(conn_counter.get(), 1);
assert_eq!(conn_counter.get(), Some(1));
let c2 = listener.accept().await.unwrap();
assert_eq!(conn_counter.get(), 2);
assert_eq!(conn_counter.get(), Some(2));
drop(c2);
wait_for_condition(|| async { conn_counter.get() == 1 }, Duration::from_secs(1)).await;
wait_for_condition(
|| async { conn_counter.get() == Some(1) },
Duration::from_secs(1),
)
.await;
drop(c1);
wait_for_condition(|| async { conn_counter.get() == 0 }, Duration::from_secs(1)).await;
wait_for_condition(
|| async { conn_counter.get().unwrap_or(0) == 0 },
Duration::from_secs(1),
)
.await;
}
}