mirror of
https://github.com/EasyTier/EasyTier.git
synced 2024-11-15 19:22:30 +08:00
nat4-nat4 punch (#388)
Some checks are pending
EasyTier Core / pre_job (push) Waiting to run
EasyTier Core / build (freebsd-13.2-x86_64, 13.2, ubuntu-latest, x86_64-unknown-freebsd) (push) Blocked by required conditions
EasyTier Core / build (linux-aarch64, ubuntu-latest, aarch64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (linux-arm, ubuntu-latest, arm-unknown-linux-musleabi) (push) Blocked by required conditions
EasyTier Core / build (linux-armhf, ubuntu-latest, arm-unknown-linux-musleabihf) (push) Blocked by required conditions
EasyTier Core / build (linux-armv7, ubuntu-latest, armv7-unknown-linux-musleabi) (push) Blocked by required conditions
EasyTier Core / build (linux-armv7hf, ubuntu-latest, armv7-unknown-linux-musleabihf) (push) Blocked by required conditions
EasyTier Core / build (linux-mips, ubuntu-latest, mips-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (linux-mipsel, ubuntu-latest, mipsel-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (linux-x86_64, ubuntu-latest, x86_64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (macos-aarch64, macos-latest, aarch64-apple-darwin) (push) Blocked by required conditions
EasyTier Core / build (macos-x86_64, macos-latest, x86_64-apple-darwin) (push) Blocked by required conditions
EasyTier Core / build (windows-x86_64, windows-latest, x86_64-pc-windows-msvc) (push) Blocked by required conditions
EasyTier Core / core-result (push) Blocked by required conditions
EasyTier GUI / pre_job (push) Waiting to run
EasyTier GUI / build-gui (linux-aarch64, aarch64-unknown-linux-gnu, ubuntu-latest, aarch64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier GUI / build-gui (linux-x86_64, x86_64-unknown-linux-gnu, ubuntu-latest, x86_64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier GUI / build-gui (macos-aarch64, aarch64-apple-darwin, macos-latest, aarch64-apple-darwin) (push) Blocked by required conditions
EasyTier GUI / build-gui (macos-x86_64, x86_64-apple-darwin, macos-latest, x86_64-apple-darwin) (push) Blocked by required conditions
EasyTier GUI / build-gui (windows-x86_64, x86_64-pc-windows-msvc, windows-latest, x86_64-pc-windows-msvc) (push) Blocked by required conditions
EasyTier GUI / gui-result (push) Blocked by required conditions
EasyTier Mobile / pre_job (push) Waiting to run
EasyTier Mobile / build-mobile (android, ubuntu-latest, android) (push) Blocked by required conditions
EasyTier Mobile / mobile-result (push) Blocked by required conditions
EasyTier Test / pre_job (push) Waiting to run
EasyTier Test / test (push) Blocked by required conditions
Some checks are pending
EasyTier Core / pre_job (push) Waiting to run
EasyTier Core / build (freebsd-13.2-x86_64, 13.2, ubuntu-latest, x86_64-unknown-freebsd) (push) Blocked by required conditions
EasyTier Core / build (linux-aarch64, ubuntu-latest, aarch64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (linux-arm, ubuntu-latest, arm-unknown-linux-musleabi) (push) Blocked by required conditions
EasyTier Core / build (linux-armhf, ubuntu-latest, arm-unknown-linux-musleabihf) (push) Blocked by required conditions
EasyTier Core / build (linux-armv7, ubuntu-latest, armv7-unknown-linux-musleabi) (push) Blocked by required conditions
EasyTier Core / build (linux-armv7hf, ubuntu-latest, armv7-unknown-linux-musleabihf) (push) Blocked by required conditions
EasyTier Core / build (linux-mips, ubuntu-latest, mips-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (linux-mipsel, ubuntu-latest, mipsel-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (linux-x86_64, ubuntu-latest, x86_64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier Core / build (macos-aarch64, macos-latest, aarch64-apple-darwin) (push) Blocked by required conditions
EasyTier Core / build (macos-x86_64, macos-latest, x86_64-apple-darwin) (push) Blocked by required conditions
EasyTier Core / build (windows-x86_64, windows-latest, x86_64-pc-windows-msvc) (push) Blocked by required conditions
EasyTier Core / core-result (push) Blocked by required conditions
EasyTier GUI / pre_job (push) Waiting to run
EasyTier GUI / build-gui (linux-aarch64, aarch64-unknown-linux-gnu, ubuntu-latest, aarch64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier GUI / build-gui (linux-x86_64, x86_64-unknown-linux-gnu, ubuntu-latest, x86_64-unknown-linux-musl) (push) Blocked by required conditions
EasyTier GUI / build-gui (macos-aarch64, aarch64-apple-darwin, macos-latest, aarch64-apple-darwin) (push) Blocked by required conditions
EasyTier GUI / build-gui (macos-x86_64, x86_64-apple-darwin, macos-latest, x86_64-apple-darwin) (push) Blocked by required conditions
EasyTier GUI / build-gui (windows-x86_64, x86_64-pc-windows-msvc, windows-latest, x86_64-pc-windows-msvc) (push) Blocked by required conditions
EasyTier GUI / gui-result (push) Blocked by required conditions
EasyTier Mobile / pre_job (push) Waiting to run
EasyTier Mobile / build-mobile (android, ubuntu-latest, android) (push) Blocked by required conditions
EasyTier Mobile / mobile-result (push) Blocked by required conditions
EasyTier Test / pre_job (push) Waiting to run
EasyTier Test / test (push) Blocked by required conditions
this patch optimize the udp hole punch logic: 1. allow start punch hole before stun test complete. 2. add lock to symmetric punch, avoid conflict between concurrent hole punching task. 3. support punching hole for predictable nat4-nat4. 4. make backoff of retry reasonable
This commit is contained in:
parent
ba3da97ad4
commit
37ceb77bf6
|
@ -343,6 +343,8 @@ impl StunClientBuilder {
|
|||
pub struct UdpNatTypeDetectResult {
|
||||
source_addr: SocketAddr,
|
||||
stun_resps: Vec<BindRequestResponse>,
|
||||
// if we are easy symmetric nat, we need to test with another port to check inc or dec
|
||||
extra_bind_test: Option<BindRequestResponse>,
|
||||
}
|
||||
|
||||
impl UdpNatTypeDetectResult {
|
||||
|
@ -350,6 +352,7 @@ impl UdpNatTypeDetectResult {
|
|||
Self {
|
||||
source_addr,
|
||||
stun_resps,
|
||||
extra_bind_test: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -406,7 +409,7 @@ impl UdpNatTypeDetectResult {
|
|||
.filter_map(|x| x.mapped_socket_addr)
|
||||
.collect::<BTreeSet<_>>()
|
||||
.len();
|
||||
mapped_addr_count < self.stun_server_count()
|
||||
mapped_addr_count == 1
|
||||
}
|
||||
|
||||
pub fn nat_type(&self) -> NatType {
|
||||
|
@ -429,7 +432,32 @@ impl UdpNatTypeDetectResult {
|
|||
return NatType::PortRestricted;
|
||||
}
|
||||
} else if !self.stun_resps.is_empty() {
|
||||
return NatType::Symmetric;
|
||||
if self.public_ips().len() != 1
|
||||
|| self.usable_stun_resp_count() <= 1
|
||||
|| self.max_port() - self.min_port() > 15
|
||||
|| self.extra_bind_test.is_none()
|
||||
|| self
|
||||
.extra_bind_test
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.mapped_socket_addr
|
||||
.is_none()
|
||||
{
|
||||
return NatType::Symmetric;
|
||||
} else {
|
||||
let extra_bind_test = self.extra_bind_test.as_ref().unwrap();
|
||||
let extra_port = extra_bind_test.mapped_socket_addr.unwrap().port();
|
||||
|
||||
let max_port_diff = extra_port.saturating_sub(self.max_port());
|
||||
let min_port_diff = self.min_port().saturating_sub(extra_port);
|
||||
if max_port_diff != 0 && max_port_diff < 100 {
|
||||
return NatType::SymmetricEasyInc;
|
||||
} else if min_port_diff != 0 && min_port_diff < 100 {
|
||||
return NatType::SymmetricEasyDec;
|
||||
} else {
|
||||
return NatType::Symmetric;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return NatType::Unknown;
|
||||
}
|
||||
|
@ -477,6 +505,13 @@ impl UdpNatTypeDetectResult {
|
|||
.max()
|
||||
.unwrap_or(u16::MAX)
|
||||
}
|
||||
|
||||
pub fn usable_stun_resp_count(&self) -> usize {
|
||||
self.stun_resps
|
||||
.iter()
|
||||
.filter(|x| x.mapped_socket_addr.is_some())
|
||||
.count()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpNatTypeDetector {
|
||||
|
@ -492,6 +527,19 @@ impl UdpNatTypeDetector {
|
|||
}
|
||||
}
|
||||
|
||||
async fn get_extra_bind_result(
|
||||
&self,
|
||||
source_port: u16,
|
||||
stun_server: SocketAddr,
|
||||
) -> Result<BindRequestResponse, Error> {
|
||||
let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?);
|
||||
let client_builder = StunClientBuilder::new(udp.clone());
|
||||
client_builder
|
||||
.new_stun_client(stun_server)
|
||||
.bind_request(false, false)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn detect_nat_type(&self, source_port: u16) -> Result<UdpNatTypeDetectResult, Error> {
|
||||
let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", source_port)).await?);
|
||||
self.detect_nat_type_with_socket(udp).await
|
||||
|
@ -578,13 +626,28 @@ impl StunInfoCollectorTrait for StunInfoCollector {
|
|||
async fn get_udp_port_mapping(&self, local_port: u16) -> Result<SocketAddr, Error> {
|
||||
self.start_stun_routine();
|
||||
|
||||
let stun_servers = self
|
||||
let mut stun_servers = self
|
||||
.udp_nat_test_result
|
||||
.read()
|
||||
.unwrap()
|
||||
.clone()
|
||||
.map(|x| x.collect_available_stun_server())
|
||||
.ok_or(Error::NotFound)?;
|
||||
.unwrap_or(vec![]);
|
||||
|
||||
if stun_servers.is_empty() {
|
||||
let mut host_resolver =
|
||||
HostResolverIter::new(self.stun_servers.read().unwrap().clone(), 2);
|
||||
while let Some(addr) = host_resolver.next().await {
|
||||
stun_servers.push(addr);
|
||||
if stun_servers.len() >= 2 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if stun_servers.is_empty() {
|
||||
return Err(Error::NotFound);
|
||||
}
|
||||
|
||||
let udp = Arc::new(UdpSocket::bind(format!("0.0.0.0:{}", local_port)).await?);
|
||||
let mut client_builder = StunClientBuilder::new(udp.clone());
|
||||
|
@ -630,9 +693,9 @@ impl StunInfoCollector {
|
|||
// stun server cross nation may return a external ip address with high latency and loss rate
|
||||
vec![
|
||||
"stun.miwifi.com",
|
||||
"stun.cdnbye.com",
|
||||
"stun.hitv.com",
|
||||
"stun.chat.bilibili.com",
|
||||
"stun.hitv.com",
|
||||
"stun.cdnbye.com",
|
||||
"stun.douyucdn.cn:18000",
|
||||
"fwa.lifesizecloud.com",
|
||||
"global.turn.twilio.com",
|
||||
|
@ -673,38 +736,41 @@ impl StunInfoCollector {
|
|||
.map(|x| x.to_string())
|
||||
.collect();
|
||||
let detector = UdpNatTypeDetector::new(servers, 1);
|
||||
let ret = detector.detect_nat_type(0).await;
|
||||
let mut ret = detector.detect_nat_type(0).await;
|
||||
tracing::debug!(?ret, "finish udp nat type detect");
|
||||
|
||||
let mut nat_type = NatType::Unknown;
|
||||
let sleep_sec = match &ret {
|
||||
Ok(resp) => {
|
||||
*udp_nat_test_result.write().unwrap() = Some(resp.clone());
|
||||
udp_test_time.store(Local::now());
|
||||
nat_type = resp.nat_type();
|
||||
if nat_type == NatType::Unknown {
|
||||
15
|
||||
} else {
|
||||
600
|
||||
}
|
||||
}
|
||||
_ => 15,
|
||||
};
|
||||
if let Ok(resp) = &ret {
|
||||
tracing::debug!(?resp, "got udp nat type detect result");
|
||||
nat_type = resp.nat_type();
|
||||
}
|
||||
|
||||
// if nat type is symmtric, detect with another port to gather more info
|
||||
if nat_type == NatType::Symmetric {
|
||||
let old_resp = ret.unwrap();
|
||||
let old_local_port = old_resp.local_addr().port();
|
||||
let new_port = if old_local_port >= 65535 {
|
||||
old_local_port - 1
|
||||
} else {
|
||||
old_local_port + 1
|
||||
};
|
||||
let ret = detector.detect_nat_type(new_port).await;
|
||||
tracing::debug!(?ret, "finish udp nat type detect with another port");
|
||||
if let Ok(resp) = ret {
|
||||
udp_nat_test_result.write().unwrap().as_mut().map(|x| {
|
||||
x.extend_result(resp);
|
||||
});
|
||||
let old_resp = ret.as_mut().unwrap();
|
||||
tracing::debug!(?old_resp, "start get extra bind result");
|
||||
let available_stun_servers = old_resp.collect_available_stun_server();
|
||||
for server in available_stun_servers.iter() {
|
||||
let ret = detector
|
||||
.get_extra_bind_result(0, *server)
|
||||
.await
|
||||
.with_context(|| "get extra bind result failed");
|
||||
tracing::debug!(?ret, "finish udp nat type detect with another port");
|
||||
if let Ok(resp) = ret {
|
||||
old_resp.extra_bind_test = Some(resp);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut sleep_sec = 10;
|
||||
if let Ok(resp) = &ret {
|
||||
udp_test_time.store(Local::now());
|
||||
*udp_nat_test_result.write().unwrap() = Some(resp.clone());
|
||||
if nat_type != NatType::Unknown
|
||||
&& (nat_type != NatType::Symmetric || resp.extra_bind_test.is_some())
|
||||
{
|
||||
sleep_sec = 600
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -734,7 +800,7 @@ impl StunInfoCollectorTrait for MockStunInfoCollector {
|
|||
last_update_time: std::time::Instant::now().elapsed().as_secs() as i64,
|
||||
min_port: 100,
|
||||
max_port: 200,
|
||||
..Default::default()
|
||||
public_ip: vec!["127.0.0.1".to_string()],
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -425,7 +425,7 @@ impl DirectConnectorManager {
|
|||
);
|
||||
|
||||
let ip_list = rpc_stub
|
||||
.get_ip_list(BaseController {}, GetIpListRequest {})
|
||||
.get_ip_list(BaseController::default(), GetIpListRequest {})
|
||||
.await
|
||||
.with_context(|| format!("get ip list from peer {}", dst_peer_id))?;
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
396
easytier/src/connector/udp_hole_punch/both_easy_sym.rs
Normal file
396
easytier/src/connector/udp_hole_punch/both_easy_sym.rs
Normal file
|
@ -0,0 +1,396 @@
|
|||
use std::{
|
||||
net::{IpAddr, SocketAddr, SocketAddrV4},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::{
|
||||
common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId},
|
||||
connector::udp_hole_punch::common::{
|
||||
try_connect_with_socket, UdpHolePunchListener, HOLE_PUNCH_PACKET_BODY_LEN,
|
||||
},
|
||||
peers::peer_manager::PeerManager,
|
||||
proto::{
|
||||
peer_rpc::{
|
||||
SendPunchPacketBothEasySymRequest, SendPunchPacketBothEasySymResponse,
|
||||
UdpHolePunchRpcClientFactory,
|
||||
},
|
||||
rpc_types::{self, controller::BaseController},
|
||||
},
|
||||
tunnel::{udp::new_hole_punch_packet, Tunnel},
|
||||
};
|
||||
|
||||
use super::common::{PunchHoleServerCommon, UdpNatType, UdpSocketArray};
|
||||
|
||||
const UDP_ARRAY_SIZE_FOR_BOTH_EASY_SYM: usize = 25;
|
||||
const DST_PORT_OFFSET: u16 = 20;
|
||||
const REMOTE_WAIT_TIME_MS: u64 = 5000;
|
||||
|
||||
pub(crate) struct PunchBothEasySymHoleServer {
|
||||
common: Arc<PunchHoleServerCommon>,
|
||||
task: Mutex<Option<ScopedTask<()>>>,
|
||||
}
|
||||
|
||||
impl PunchBothEasySymHoleServer {
|
||||
pub(crate) fn new(common: Arc<PunchHoleServerCommon>) -> Self {
|
||||
Self {
|
||||
common,
|
||||
task: Mutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
// hard sym means public port is random and cannot be predicted
|
||||
#[tracing::instrument(skip(self), ret, err)]
|
||||
pub(crate) async fn send_punch_packet_both_easy_sym(
|
||||
&self,
|
||||
request: SendPunchPacketBothEasySymRequest,
|
||||
) -> Result<SendPunchPacketBothEasySymResponse, rpc_types::error::Error> {
|
||||
tracing::info!("send_punch_packet_both_easy_sym start");
|
||||
let busy_resp = Ok(SendPunchPacketBothEasySymResponse {
|
||||
is_busy: true,
|
||||
..Default::default()
|
||||
});
|
||||
let Ok(mut locked_task) = self.task.try_lock() else {
|
||||
return busy_resp;
|
||||
};
|
||||
if locked_task.is_some() && !locked_task.as_ref().unwrap().is_finished() {
|
||||
return busy_resp;
|
||||
}
|
||||
|
||||
let global_ctx = self.common.get_global_ctx();
|
||||
let cur_mapped_addr = global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_udp_port_mapping(0)
|
||||
.await
|
||||
.with_context(|| "failed to get udp port mapping")?;
|
||||
|
||||
tracing::info!("send_punch_packet_hard_sym start");
|
||||
let socket_count = request.udp_socket_count as usize;
|
||||
let public_ips = request
|
||||
.public_ip
|
||||
.ok_or(anyhow::anyhow!("public_ip is required"))?;
|
||||
let transaction_id = request.transaction_id;
|
||||
|
||||
let udp_array =
|
||||
UdpSocketArray::new(socket_count, self.common.get_global_ctx().net_ns.clone());
|
||||
udp_array.start().await?;
|
||||
udp_array.add_intreast_tid(transaction_id);
|
||||
let peer_mgr = self.common.get_peer_mgr();
|
||||
|
||||
let punch_packet =
|
||||
new_hole_punch_packet(transaction_id, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes();
|
||||
let mut punched = vec![];
|
||||
let common = self.common.clone();
|
||||
|
||||
let task = tokio::spawn(async move {
|
||||
let mut listeners = Vec::new();
|
||||
let start_time = Instant::now();
|
||||
let wait_time_ms = request.wait_time_ms.min(8000);
|
||||
while start_time.elapsed() < Duration::from_millis(wait_time_ms as u64) {
|
||||
if let Err(e) = udp_array
|
||||
.send_with_all(
|
||||
&punch_packet,
|
||||
SocketAddr::V4(SocketAddrV4::new(
|
||||
public_ips.into(),
|
||||
request.dst_port_num as u16,
|
||||
)),
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!(?e, "failed to send hole punch packet");
|
||||
break;
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
if let Some(s) = udp_array.try_fetch_punched_socket(transaction_id) {
|
||||
tracing::info!(?s, ?transaction_id, "got punched socket in both easy sym");
|
||||
assert!(Arc::strong_count(&s.socket) == 1);
|
||||
let Some(port) = s.socket.local_addr().ok().map(|addr| addr.port()) else {
|
||||
tracing::warn!("failed to get local addr from punched socket");
|
||||
continue;
|
||||
};
|
||||
let remote_addr = s.remote_addr;
|
||||
drop(s);
|
||||
|
||||
let listener =
|
||||
match UdpHolePunchListener::new_ext(peer_mgr.clone(), false, Some(port))
|
||||
.await
|
||||
{
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
tracing::warn!(?e, "failed to create listener");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
punched.push((listener.get_socket().await, remote_addr));
|
||||
listeners.push(listener);
|
||||
}
|
||||
|
||||
// if any listener is punched, we can break the loop
|
||||
for l in &listeners {
|
||||
if l.get_conn_count().await > 0 {
|
||||
tracing::info!(?l, "got punched listener");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !punched.is_empty() {
|
||||
tracing::debug!(?punched, "got punched socket and keep sending punch packet");
|
||||
}
|
||||
|
||||
for p in &punched {
|
||||
let (socket, remote_addr) = p;
|
||||
let send_remote_ret = socket.send_to(&punch_packet, remote_addr).await;
|
||||
tracing::debug!(
|
||||
?send_remote_ret,
|
||||
?socket,
|
||||
"send hole punch packet to punched remote"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for l in listeners {
|
||||
if l.get_conn_count().await > 0 {
|
||||
common.add_listener(l).await;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
*locked_task = Some(task.into());
|
||||
return Ok(SendPunchPacketBothEasySymResponse {
|
||||
is_busy: false,
|
||||
base_mapped_addr: Some(cur_mapped_addr.into()),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct PunchBothEasySymHoleClient {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
}
|
||||
|
||||
impl PunchBothEasySymHoleClient {
|
||||
pub(crate) fn new(peer_mgr: Arc<PeerManager>) -> Self {
|
||||
Self { peer_mgr }
|
||||
}
|
||||
|
||||
#[tracing::instrument(ret)]
|
||||
pub(crate) async fn do_hole_punching(
|
||||
&self,
|
||||
dst_peer_id: PeerId,
|
||||
my_nat_info: UdpNatType,
|
||||
peer_nat_info: UdpNatType,
|
||||
is_busy: &mut bool,
|
||||
) -> Result<Box<dyn Tunnel>, anyhow::Error> {
|
||||
*is_busy = false;
|
||||
|
||||
let udp_array = UdpSocketArray::new(
|
||||
UDP_ARRAY_SIZE_FOR_BOTH_EASY_SYM,
|
||||
self.peer_mgr.get_global_ctx().net_ns.clone(),
|
||||
);
|
||||
udp_array.start().await?;
|
||||
|
||||
let global_ctx = self.peer_mgr.get_global_ctx();
|
||||
let cur_mapped_addr = global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_udp_port_mapping(0)
|
||||
.await
|
||||
.with_context(|| "failed to get udp port mapping")?;
|
||||
let my_public_ip = match cur_mapped_addr.ip() {
|
||||
IpAddr::V4(v4) => v4,
|
||||
_ => {
|
||||
anyhow::bail!("ipv6 is not supported");
|
||||
}
|
||||
};
|
||||
let me_is_incremental = my_nat_info
|
||||
.get_inc_of_easy_sym()
|
||||
.ok_or(anyhow::anyhow!("me_is_incremental is required"))?;
|
||||
let peer_is_incremental = peer_nat_info
|
||||
.get_inc_of_easy_sym()
|
||||
.ok_or(anyhow::anyhow!("peer_is_incremental is required"))?;
|
||||
|
||||
let rpc_stub = self
|
||||
.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.rpc_client()
|
||||
.scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
|
||||
self.peer_mgr.my_peer_id(),
|
||||
dst_peer_id,
|
||||
global_ctx.get_network_name(),
|
||||
);
|
||||
|
||||
let tid = rand::random();
|
||||
udp_array.add_intreast_tid(tid);
|
||||
|
||||
let remote_ret = rpc_stub
|
||||
.send_punch_packet_both_easy_sym(
|
||||
BaseController {
|
||||
timeout_ms: 2000,
|
||||
..Default::default()
|
||||
},
|
||||
SendPunchPacketBothEasySymRequest {
|
||||
transaction_id: tid,
|
||||
public_ip: Some(my_public_ip.into()),
|
||||
dst_port_num: if me_is_incremental {
|
||||
cur_mapped_addr.port().saturating_add(DST_PORT_OFFSET)
|
||||
} else {
|
||||
cur_mapped_addr.port().saturating_sub(DST_PORT_OFFSET)
|
||||
} as u32,
|
||||
udp_socket_count: UDP_ARRAY_SIZE_FOR_BOTH_EASY_SYM as u32,
|
||||
wait_time_ms: REMOTE_WAIT_TIME_MS as u32,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
if remote_ret.is_busy {
|
||||
*is_busy = true;
|
||||
anyhow::bail!("remote is busy");
|
||||
}
|
||||
|
||||
let mut remote_mapped_addr = remote_ret
|
||||
.base_mapped_addr
|
||||
.ok_or(anyhow::anyhow!("remote_mapped_addr is required"))?;
|
||||
|
||||
let now = Instant::now();
|
||||
remote_mapped_addr.port = if peer_is_incremental {
|
||||
remote_mapped_addr
|
||||
.port
|
||||
.saturating_add(DST_PORT_OFFSET as u32)
|
||||
} else {
|
||||
remote_mapped_addr
|
||||
.port
|
||||
.saturating_sub(DST_PORT_OFFSET as u32)
|
||||
};
|
||||
tracing::debug!(
|
||||
?remote_mapped_addr,
|
||||
?remote_ret,
|
||||
"start send hole punch packet for both easy sym"
|
||||
);
|
||||
|
||||
while now.elapsed().as_millis() < (REMOTE_WAIT_TIME_MS + 1000).into() {
|
||||
udp_array
|
||||
.send_with_all(
|
||||
&new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(),
|
||||
remote_mapped_addr.into(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
let Some(socket) = udp_array.try_fetch_punched_socket(tid) else {
|
||||
tracing::trace!(
|
||||
?remote_mapped_addr,
|
||||
?tid,
|
||||
"no punched socket found, send some more hole punch packets"
|
||||
);
|
||||
continue;
|
||||
};
|
||||
|
||||
tracing::info!(
|
||||
?socket,
|
||||
?remote_mapped_addr,
|
||||
?tid,
|
||||
"got punched socket in both easy sym"
|
||||
);
|
||||
|
||||
for _ in 0..2 {
|
||||
match try_connect_with_socket(socket.socket.clone(), remote_mapped_addr.into())
|
||||
.await
|
||||
{
|
||||
Ok(tunnel) => {
|
||||
return Ok(tunnel);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(?e, "failed to connect with socket");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
udp_array.add_new_socket(socket.socket).await?;
|
||||
}
|
||||
|
||||
anyhow::bail!("failed to punch hole for both easy sym");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use std::{
|
||||
sync::{atomic::AtomicU32, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
use crate::{
|
||||
connector::udp_hole_punch::{
|
||||
tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector,
|
||||
},
|
||||
peers::tests::{connect_peer_manager, wait_route_appear},
|
||||
proto::common::NatType,
|
||||
tunnel::common::tests::wait_for_condition,
|
||||
};
|
||||
|
||||
#[rstest::rstest]
|
||||
#[tokio::test]
|
||||
#[serial_test::serial(hole_punch)]
|
||||
async fn hole_punching_easy_sym(#[values("true", "false")] is_inc: bool) {
|
||||
let p_a = create_mock_peer_manager_with_mock_stun(if is_inc {
|
||||
NatType::SymmetricEasyInc
|
||||
} else {
|
||||
NatType::SymmetricEasyDec
|
||||
})
|
||||
.await;
|
||||
let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
let p_c = create_mock_peer_manager_with_mock_stun(if !is_inc {
|
||||
NatType::SymmetricEasyInc
|
||||
} else {
|
||||
NatType::SymmetricEasyDec
|
||||
})
|
||||
.await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone());
|
||||
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone());
|
||||
|
||||
hole_punching_a.run().await.unwrap();
|
||||
hole_punching_c.run().await.unwrap();
|
||||
|
||||
// 144 + DST_PORT_OFFSET = 164
|
||||
let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40164").await.unwrap());
|
||||
// 144 - DST_PORT_OFFSET = 124
|
||||
let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40124").await.unwrap());
|
||||
let udps = vec![udp1, udp2];
|
||||
|
||||
let counter = Arc::new(AtomicU32::new(0));
|
||||
|
||||
// all these sockets should receive hole punching packet
|
||||
for udp in udps.iter().map(Arc::clone) {
|
||||
let counter = counter.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = [0u8; 1024];
|
||||
let (len, addr) = udp.recv_from(&mut buf).await.unwrap();
|
||||
println!(
|
||||
"got predictable punch packet, {:?} {:?} {:?}",
|
||||
len,
|
||||
addr,
|
||||
udp.local_addr()
|
||||
);
|
||||
counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
});
|
||||
}
|
||||
|
||||
hole_punching_a.client.run_immediately().await;
|
||||
let udp_len = udps.len();
|
||||
wait_for_condition(
|
||||
|| async { counter.load(std::sync::atomic::Ordering::Relaxed) == udp_len as u32 },
|
||||
Duration::from_secs(30),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
573
easytier/src/connector/udp_hole_punch/common.rs
Normal file
573
easytier/src/connector/udp_hole_punch/common.rs
Normal file
|
@ -0,0 +1,573 @@
|
|||
use std::{
|
||||
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use crossbeam::atomic::AtomicCell;
|
||||
use dashmap::{DashMap, DashSet};
|
||||
use rand::seq::SliceRandom as _;
|
||||
use tokio::{net::UdpSocket, sync::Mutex, task::JoinSet};
|
||||
use tracing::{instrument, Instrument, Level};
|
||||
use zerocopy::FromBytes as _;
|
||||
|
||||
use crate::{
|
||||
common::{
|
||||
error::Error, global_ctx::ArcGlobalCtx, join_joinset_background, netns::NetNS,
|
||||
stun::StunInfoCollectorTrait as _,
|
||||
},
|
||||
defer,
|
||||
peers::peer_manager::PeerManager,
|
||||
proto::common::NatType,
|
||||
tunnel::{
|
||||
packet_def::{UDPTunnelHeader, UdpPacketType, UDP_TUNNEL_HEADER_SIZE},
|
||||
udp::{new_hole_punch_packet, UdpTunnelConnector, UdpTunnelListener},
|
||||
Tunnel, TunnelConnCounter, TunnelListener as _,
|
||||
},
|
||||
};
|
||||
|
||||
pub(crate) const HOLE_PUNCH_PACKET_BODY_LEN: u16 = 16;
|
||||
|
||||
fn generate_shuffled_port_vec() -> Vec<u16> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut port_vec: Vec<u16> = (1..=65535).collect();
|
||||
port_vec.shuffle(&mut rng);
|
||||
port_vec
|
||||
}
|
||||
|
||||
pub(crate) enum UdpPunchClientMethod {
|
||||
None,
|
||||
ConeToCone,
|
||||
SymToCone,
|
||||
EasySymToEasySym,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub(crate) enum UdpNatType {
|
||||
Unknown,
|
||||
Open(NatType),
|
||||
Cone(NatType),
|
||||
// bool means if it is incremental
|
||||
EasySymmetric(NatType, bool),
|
||||
HardSymmetric(NatType),
|
||||
}
|
||||
|
||||
impl From<NatType> for UdpNatType {
|
||||
fn from(nat_type: NatType) -> Self {
|
||||
match nat_type {
|
||||
NatType::Unknown => UdpNatType::Unknown,
|
||||
NatType::NoPat | NatType::OpenInternet => UdpNatType::Open(nat_type),
|
||||
NatType::FullCone | NatType::Restricted | NatType::PortRestricted => {
|
||||
UdpNatType::Cone(nat_type)
|
||||
}
|
||||
NatType::Symmetric | NatType::SymUdpFirewall => UdpNatType::HardSymmetric(nat_type),
|
||||
NatType::SymmetricEasyInc => UdpNatType::EasySymmetric(nat_type, true),
|
||||
NatType::SymmetricEasyDec => UdpNatType::EasySymmetric(nat_type, false),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<NatType> for UdpNatType {
|
||||
fn into(self) -> NatType {
|
||||
match self {
|
||||
UdpNatType::Unknown => NatType::Unknown,
|
||||
UdpNatType::Open(nat_type) => nat_type,
|
||||
UdpNatType::Cone(nat_type) => nat_type,
|
||||
UdpNatType::EasySymmetric(nat_type, _) => nat_type,
|
||||
UdpNatType::HardSymmetric(nat_type) => nat_type,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl UdpNatType {
|
||||
pub(crate) fn is_open(&self) -> bool {
|
||||
matches!(self, UdpNatType::Open(_))
|
||||
}
|
||||
|
||||
pub(crate) fn is_unknown(&self) -> bool {
|
||||
matches!(self, UdpNatType::Unknown)
|
||||
}
|
||||
|
||||
pub(crate) fn is_sym(&self) -> bool {
|
||||
self.is_hard_sym() || self.is_easy_sym()
|
||||
}
|
||||
|
||||
pub(crate) fn is_hard_sym(&self) -> bool {
|
||||
matches!(self, UdpNatType::HardSymmetric(_))
|
||||
}
|
||||
|
||||
pub(crate) fn is_easy_sym(&self) -> bool {
|
||||
matches!(self, UdpNatType::EasySymmetric(_, _))
|
||||
}
|
||||
|
||||
pub(crate) fn is_cone(&self) -> bool {
|
||||
matches!(self, UdpNatType::Cone(_))
|
||||
}
|
||||
|
||||
pub(crate) fn get_inc_of_easy_sym(&self) -> Option<bool> {
|
||||
match self {
|
||||
UdpNatType::EasySymmetric(_, inc) => Some(*inc),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn get_punch_hole_method(&self, other: Self) -> UdpPunchClientMethod {
|
||||
if other.is_unknown() {
|
||||
if self.is_sym() {
|
||||
return UdpPunchClientMethod::SymToCone;
|
||||
} else {
|
||||
return UdpPunchClientMethod::ConeToCone;
|
||||
}
|
||||
}
|
||||
|
||||
if self.is_unknown() {
|
||||
if other.is_sym() {
|
||||
return UdpPunchClientMethod::None;
|
||||
} else {
|
||||
return UdpPunchClientMethod::ConeToCone;
|
||||
}
|
||||
}
|
||||
|
||||
if self.is_open() || other.is_open() {
|
||||
// open nat does not need to punch hole
|
||||
return UdpPunchClientMethod::None;
|
||||
}
|
||||
|
||||
if self.is_cone() {
|
||||
if other.is_sym() {
|
||||
return UdpPunchClientMethod::None;
|
||||
} else {
|
||||
return UdpPunchClientMethod::ConeToCone;
|
||||
}
|
||||
} else if self.is_easy_sym() {
|
||||
if other.is_hard_sym() {
|
||||
return UdpPunchClientMethod::None;
|
||||
} else if other.is_easy_sym() {
|
||||
return UdpPunchClientMethod::EasySymToEasySym;
|
||||
} else {
|
||||
return UdpPunchClientMethod::SymToCone;
|
||||
}
|
||||
} else if self.is_hard_sym() {
|
||||
if other.is_sym() {
|
||||
return UdpPunchClientMethod::None;
|
||||
} else {
|
||||
return UdpPunchClientMethod::SymToCone;
|
||||
}
|
||||
}
|
||||
|
||||
unreachable!("invalid nat type");
|
||||
}
|
||||
|
||||
pub(crate) fn can_punch_hole_as_client(&self, other: Self) -> bool {
|
||||
!matches!(
|
||||
self.get_punch_hole_method(other),
|
||||
UdpPunchClientMethod::None
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct PunchedUdpSocket {
|
||||
pub(crate) socket: Arc<UdpSocket>,
|
||||
pub(crate) tid: u32,
|
||||
pub(crate) remote_addr: SocketAddr,
|
||||
}
|
||||
|
||||
// used for symmetric hole punching, binding to multiple ports to increase the chance of success
|
||||
pub(crate) struct UdpSocketArray {
|
||||
sockets: Arc<DashMap<SocketAddr, Arc<UdpSocket>>>,
|
||||
max_socket_count: usize,
|
||||
net_ns: NetNS,
|
||||
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
|
||||
intreast_tids: Arc<DashSet<u32>>,
|
||||
tid_to_socket: Arc<DashMap<u32, Vec<PunchedUdpSocket>>>,
|
||||
}
|
||||
|
||||
impl UdpSocketArray {
|
||||
pub fn new(max_socket_count: usize, net_ns: NetNS) -> Self {
|
||||
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
|
||||
join_joinset_background(tasks.clone(), "UdpSocketArray".to_owned());
|
||||
|
||||
Self {
|
||||
sockets: Arc::new(DashMap::new()),
|
||||
max_socket_count,
|
||||
net_ns,
|
||||
tasks,
|
||||
|
||||
intreast_tids: Arc::new(DashSet::new()),
|
||||
tid_to_socket: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn started(&self) -> bool {
|
||||
!self.sockets.is_empty()
|
||||
}
|
||||
|
||||
pub async fn add_new_socket(&self, socket: Arc<UdpSocket>) -> Result<(), anyhow::Error> {
|
||||
let socket_map = self.sockets.clone();
|
||||
let local_addr = socket.local_addr()?;
|
||||
let intreast_tids = self.intreast_tids.clone();
|
||||
let tid_to_socket = self.tid_to_socket.clone();
|
||||
socket_map.insert(local_addr, socket.clone());
|
||||
self.tasks.lock().unwrap().spawn(
|
||||
async move {
|
||||
defer!(socket_map.remove(&local_addr););
|
||||
let mut buf = [0u8; UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize];
|
||||
tracing::trace!(?local_addr, "udp socket added");
|
||||
loop {
|
||||
let Ok((len, addr)) = socket.recv_from(&mut buf).await else {
|
||||
break;
|
||||
};
|
||||
|
||||
tracing::debug!(?len, ?addr, "got raw packet");
|
||||
|
||||
if len != UDP_TUNNEL_HEADER_SIZE + HOLE_PUNCH_PACKET_BODY_LEN as usize {
|
||||
continue;
|
||||
}
|
||||
|
||||
let Some(p) = UDPTunnelHeader::ref_from_prefix(&buf) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let tid = p.conn_id.get();
|
||||
let valid = p.msg_type == UdpPacketType::HolePunch as u8
|
||||
&& p.len.get() == HOLE_PUNCH_PACKET_BODY_LEN;
|
||||
tracing::debug!(?p, ?addr, ?tid, ?valid, ?p, "got udp hole punch packet");
|
||||
|
||||
if !valid {
|
||||
continue;
|
||||
}
|
||||
|
||||
if intreast_tids.contains(&tid) {
|
||||
tracing::info!(?addr, ?tid, "got hole punching packet with intreast tid");
|
||||
tid_to_socket
|
||||
.entry(tid)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(PunchedUdpSocket {
|
||||
socket: socket.clone(),
|
||||
tid,
|
||||
remote_addr: addr,
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
tracing::debug!(?local_addr, "udp socket recv loop end");
|
||||
}
|
||||
.instrument(tracing::info_span!("udp array socket recv loop")),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(err)]
|
||||
pub async fn start(&self) -> Result<(), anyhow::Error> {
|
||||
tracing::info!("starting udp socket array");
|
||||
|
||||
while self.sockets.len() < self.max_socket_count {
|
||||
let socket = {
|
||||
let _g = self.net_ns.guard();
|
||||
Arc::new(UdpSocket::bind("0.0.0.0:0").await?)
|
||||
};
|
||||
|
||||
self.add_new_socket(socket).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(err)]
|
||||
pub async fn send_with_all(&self, data: &[u8], addr: SocketAddr) -> Result<(), anyhow::Error> {
|
||||
tracing::info!(?addr, "sending hole punching packet");
|
||||
|
||||
for socket in self.sockets.iter() {
|
||||
let socket = socket.value();
|
||||
socket.send_to(data, addr).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[instrument(ret(level = Level::DEBUG))]
|
||||
pub fn try_fetch_punched_socket(&self, tid: u32) -> Option<PunchedUdpSocket> {
|
||||
tracing::debug!(?tid, "try fetch punched socket");
|
||||
self.tid_to_socket.get_mut(&tid)?.value_mut().pop()
|
||||
}
|
||||
|
||||
pub fn add_intreast_tid(&self, tid: u32) {
|
||||
self.intreast_tids.insert(tid);
|
||||
}
|
||||
|
||||
pub fn remove_intreast_tid(&self, tid: u32) {
|
||||
self.intreast_tids.remove(&tid);
|
||||
self.tid_to_socket.remove(&tid);
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for UdpSocketArray {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("UdpSocketArray")
|
||||
.field("sockets", &self.sockets.len())
|
||||
.field("max_socket_count", &self.max_socket_count)
|
||||
.field("started", &self.started())
|
||||
.field("intreast_tids", &self.intreast_tids.len())
|
||||
.field("tid_to_socket", &self.tid_to_socket.len())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct UdpHolePunchListener {
|
||||
socket: Arc<UdpSocket>,
|
||||
tasks: JoinSet<()>,
|
||||
running: Arc<AtomicCell<bool>>,
|
||||
mapped_addr: SocketAddr,
|
||||
conn_counter: Arc<Box<dyn TunnelConnCounter>>,
|
||||
|
||||
listen_time: std::time::Instant,
|
||||
last_select_time: AtomicCell<std::time::Instant>,
|
||||
last_active_time: Arc<AtomicCell<std::time::Instant>>,
|
||||
}
|
||||
|
||||
impl UdpHolePunchListener {
|
||||
async fn get_avail_port() -> Result<u16, Error> {
|
||||
let socket = UdpSocket::bind("0.0.0.0:0").await?;
|
||||
Ok(socket.local_addr()?.port())
|
||||
}
|
||||
|
||||
#[instrument(err)]
|
||||
pub async fn new(peer_mgr: Arc<PeerManager>) -> Result<Self, Error> {
|
||||
Self::new_ext(peer_mgr, true, None).await
|
||||
}
|
||||
|
||||
#[instrument(err)]
|
||||
pub async fn new_ext(
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
with_mapped_addr: bool,
|
||||
port: Option<u16>,
|
||||
) -> Result<Self, Error> {
|
||||
let port = port.unwrap_or(Self::get_avail_port().await?);
|
||||
let listen_url = format!("udp://0.0.0.0:{}", port);
|
||||
|
||||
let mapped_addr = if with_mapped_addr {
|
||||
let gctx = peer_mgr.get_global_ctx();
|
||||
let stun_info_collect = gctx.get_stun_info_collector();
|
||||
stun_info_collect.get_udp_port_mapping(port).await?
|
||||
} else {
|
||||
SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), port))
|
||||
};
|
||||
|
||||
let mut listener = UdpTunnelListener::new(listen_url.parse().unwrap());
|
||||
|
||||
{
|
||||
let _g = peer_mgr.get_global_ctx().net_ns.guard();
|
||||
listener.listen().await?;
|
||||
}
|
||||
let socket = listener.get_socket().unwrap();
|
||||
|
||||
let running = Arc::new(AtomicCell::new(true));
|
||||
let running_clone = running.clone();
|
||||
|
||||
let conn_counter = listener.get_conn_counter();
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
tasks.spawn(async move {
|
||||
while let Ok(conn) = listener.accept().await {
|
||||
tracing::warn!(?conn, "udp hole punching listener got peer connection");
|
||||
let peer_mgr = peer_mgr.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = peer_mgr.add_tunnel_as_server(conn).await {
|
||||
tracing::error!(
|
||||
?e,
|
||||
"failed to add tunnel as server in hole punch listener"
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
running_clone.store(false);
|
||||
});
|
||||
|
||||
let last_active_time = Arc::new(AtomicCell::new(std::time::Instant::now()));
|
||||
let conn_counter_clone = conn_counter.clone();
|
||||
let last_active_time_clone = last_active_time.clone();
|
||||
tasks.spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
|
||||
if conn_counter_clone.get().unwrap_or(0) != 0 {
|
||||
last_active_time_clone.store(std::time::Instant::now());
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
tracing::warn!(?mapped_addr, ?socket, "udp hole punching listener started");
|
||||
|
||||
Ok(Self {
|
||||
tasks,
|
||||
socket,
|
||||
running,
|
||||
mapped_addr,
|
||||
conn_counter,
|
||||
|
||||
listen_time: std::time::Instant::now(),
|
||||
last_select_time: AtomicCell::new(std::time::Instant::now()),
|
||||
last_active_time,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn get_socket(&self) -> Arc<UdpSocket> {
|
||||
self.last_select_time.store(std::time::Instant::now());
|
||||
self.socket.clone()
|
||||
}
|
||||
|
||||
pub async fn get_conn_count(&self) -> usize {
|
||||
self.conn_counter.get().unwrap_or(0) as usize
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PunchHoleServerCommon {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
|
||||
listeners: Arc<Mutex<Vec<UdpHolePunchListener>>>,
|
||||
tasks: Arc<std::sync::Mutex<JoinSet<()>>>,
|
||||
}
|
||||
|
||||
impl PunchHoleServerCommon {
|
||||
pub(crate) fn new(peer_mgr: Arc<PeerManager>) -> Self {
|
||||
let tasks = Arc::new(std::sync::Mutex::new(JoinSet::new()));
|
||||
join_joinset_background(tasks.clone(), "PunchHoleServerCommon".to_owned());
|
||||
|
||||
let listeners = Arc::new(Mutex::new(Vec::<UdpHolePunchListener>::new()));
|
||||
|
||||
let l = listeners.clone();
|
||||
tasks.lock().unwrap().spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
{
|
||||
// remove listener that is not active for 40 seconds but keep listeners that are selected less than 30 seconds
|
||||
l.lock().await.retain(|listener| {
|
||||
listener.last_active_time.load().elapsed().as_secs() < 40
|
||||
|| listener.last_select_time.load().elapsed().as_secs() < 30
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
peer_mgr,
|
||||
|
||||
listeners,
|
||||
tasks,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn add_listener(&self, listener: UdpHolePunchListener) {
|
||||
self.listeners.lock().await.push(listener);
|
||||
}
|
||||
|
||||
pub(crate) async fn find_listener(&self, addr: &SocketAddr) -> Option<Arc<UdpSocket>> {
|
||||
let all_listener_sockets = self.listeners.lock().await;
|
||||
|
||||
let listener = all_listener_sockets
|
||||
.iter()
|
||||
.find(|listener| listener.mapped_addr == *addr && listener.running.load())?;
|
||||
|
||||
Some(listener.get_socket().await)
|
||||
}
|
||||
|
||||
pub(crate) async fn my_udp_nat_type(&self) -> i32 {
|
||||
self.peer_mgr
|
||||
.get_global_ctx()
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type
|
||||
}
|
||||
|
||||
pub(crate) async fn select_listener(
|
||||
&self,
|
||||
use_new_listener: bool,
|
||||
) -> Option<(Arc<UdpSocket>, SocketAddr)> {
|
||||
let all_listener_sockets = &self.listeners;
|
||||
|
||||
let mut use_last = false;
|
||||
if all_listener_sockets.lock().await.len() < 16 || use_new_listener {
|
||||
tracing::warn!("creating new udp hole punching listener");
|
||||
all_listener_sockets.lock().await.push(
|
||||
UdpHolePunchListener::new(self.peer_mgr.clone())
|
||||
.await
|
||||
.ok()?,
|
||||
);
|
||||
use_last = true;
|
||||
}
|
||||
|
||||
let locked = all_listener_sockets.lock().await;
|
||||
|
||||
let listener = if use_last {
|
||||
locked.last()?
|
||||
} else {
|
||||
// use the listener that is active most recently
|
||||
locked
|
||||
.iter()
|
||||
.max_by_key(|listener| listener.last_active_time.load())?
|
||||
};
|
||||
|
||||
Some((listener.get_socket().await, listener.mapped_addr))
|
||||
}
|
||||
|
||||
pub(crate) fn get_joinset(&self) -> Arc<std::sync::Mutex<JoinSet<()>>> {
|
||||
self.tasks.clone()
|
||||
}
|
||||
|
||||
pub(crate) fn get_global_ctx(&self) -> ArcGlobalCtx {
|
||||
self.peer_mgr.get_global_ctx()
|
||||
}
|
||||
|
||||
pub(crate) fn get_peer_mgr(&self) -> Arc<PeerManager> {
|
||||
self.peer_mgr.clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(err, ret(level=Level::DEBUG), skip(ports))]
|
||||
pub(crate) async fn send_symmetric_hole_punch_packet(
|
||||
ports: &Vec<u16>,
|
||||
udp: Arc<UdpSocket>,
|
||||
transaction_id: u32,
|
||||
public_ips: &Vec<Ipv4Addr>,
|
||||
port_start_idx: usize,
|
||||
max_packets: usize,
|
||||
) -> Result<usize, Error> {
|
||||
tracing::debug!("sending hard symmetric hole punching packet");
|
||||
let mut sent_packets = 0;
|
||||
let mut cur_port_idx = port_start_idx;
|
||||
while sent_packets < max_packets {
|
||||
let port = ports[cur_port_idx % ports.len()];
|
||||
for pub_ip in public_ips {
|
||||
let addr = SocketAddr::V4(SocketAddrV4::new(*pub_ip, port));
|
||||
let packet = new_hole_punch_packet(transaction_id, HOLE_PUNCH_PACKET_BODY_LEN);
|
||||
udp.send_to(&packet.into_bytes(), addr).await?;
|
||||
sent_packets += 1;
|
||||
}
|
||||
cur_port_idx = cur_port_idx.wrapping_add(1);
|
||||
tokio::time::sleep(Duration::from_millis(3)).await;
|
||||
}
|
||||
Ok(cur_port_idx % ports.len())
|
||||
}
|
||||
|
||||
pub(crate) async fn try_connect_with_socket(
|
||||
socket: Arc<UdpSocket>,
|
||||
remote_mapped_addr: SocketAddr,
|
||||
) -> Result<Box<dyn Tunnel>, Error> {
|
||||
let connector = UdpTunnelConnector::new(
|
||||
format!(
|
||||
"udp://{}:{}",
|
||||
remote_mapped_addr.ip(),
|
||||
remote_mapped_addr.port()
|
||||
)
|
||||
.to_string()
|
||||
.parse()
|
||||
.unwrap(),
|
||||
);
|
||||
connector
|
||||
.try_connect_with_socket(socket, remote_mapped_addr)
|
||||
.await
|
||||
.map_err(|e| Error::from(e))
|
||||
}
|
258
easytier/src/connector/udp_hole_punch/cone.rs
Normal file
258
easytier/src/connector/udp_hole_punch/cone.rs
Normal file
|
@ -0,0 +1,258 @@
|
|||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
use crate::{
|
||||
common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId},
|
||||
connector::udp_hole_punch::common::{
|
||||
try_connect_with_socket, UdpSocketArray, HOLE_PUNCH_PACKET_BODY_LEN,
|
||||
},
|
||||
peers::peer_manager::PeerManager,
|
||||
proto::{
|
||||
common::Void,
|
||||
peer_rpc::{
|
||||
SelectPunchListenerRequest, SendPunchPacketConeRequest, UdpHolePunchRpcClientFactory,
|
||||
},
|
||||
rpc_types::{self, controller::BaseController},
|
||||
},
|
||||
tunnel::{udp::new_hole_punch_packet, Tunnel},
|
||||
};
|
||||
|
||||
use super::common::PunchHoleServerCommon;
|
||||
|
||||
pub(crate) struct PunchConeHoleServer {
|
||||
common: Arc<PunchHoleServerCommon>,
|
||||
}
|
||||
|
||||
impl PunchConeHoleServer {
|
||||
pub(crate) fn new(common: Arc<PunchHoleServerCommon>) -> Self {
|
||||
Self { common }
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self), ret, err)]
|
||||
pub(crate) async fn send_punch_packet_cone(
|
||||
&self,
|
||||
_: BaseController,
|
||||
request: SendPunchPacketConeRequest,
|
||||
) -> Result<Void, rpc_types::error::Error> {
|
||||
let listener_addr = request.listener_mapped_addr.ok_or(anyhow::anyhow!(
|
||||
"send_punch_packet_for_cone request missing listener_mapped_addr"
|
||||
))?;
|
||||
let listener_addr = std::net::SocketAddr::from(listener_addr);
|
||||
let listener = self
|
||||
.common
|
||||
.find_listener(&listener_addr)
|
||||
.await
|
||||
.ok_or(anyhow::anyhow!(
|
||||
"send_punch_packet_for_cone failed to find listener"
|
||||
))?;
|
||||
|
||||
let dest_addr = request.dest_addr.ok_or(anyhow::anyhow!(
|
||||
"send_punch_packet_for_cone request missing dest_addr"
|
||||
))?;
|
||||
let dest_addr = std::net::SocketAddr::from(dest_addr);
|
||||
let dest_ip = dest_addr.ip();
|
||||
if dest_ip.is_unspecified() || dest_ip.is_multicast() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"send_punch_packet_for_cone dest_ip is malformed, {:?}",
|
||||
request
|
||||
)
|
||||
.into());
|
||||
}
|
||||
|
||||
for _ in 0..request.packet_batch_count {
|
||||
tracing::info!(?request, "sending hole punching packet");
|
||||
|
||||
for _ in 0..request.packet_count_per_batch {
|
||||
let udp_packet =
|
||||
new_hole_punch_packet(request.transaction_id, HOLE_PUNCH_PACKET_BODY_LEN);
|
||||
if let Err(e) = listener.send_to(&udp_packet.into_bytes(), &dest_addr).await {
|
||||
tracing::error!(?e, "failed to send hole punch packet to dest addr");
|
||||
}
|
||||
}
|
||||
tokio::time::sleep(Duration::from_millis(request.packet_interval_ms as u64)).await;
|
||||
}
|
||||
|
||||
Ok(Void::default())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PunchConeHoleClient {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
}
|
||||
|
||||
impl PunchConeHoleClient {
|
||||
pub(crate) fn new(peer_mgr: Arc<PeerManager>) -> Self {
|
||||
Self { peer_mgr }
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub(crate) async fn do_hole_punching(
|
||||
&self,
|
||||
dst_peer_id: PeerId,
|
||||
) -> Result<Box<dyn Tunnel>, anyhow::Error> {
|
||||
tracing::info!(?dst_peer_id, "start hole punching");
|
||||
let tid = rand::random();
|
||||
|
||||
let global_ctx = self.peer_mgr.get_global_ctx();
|
||||
let udp_array = UdpSocketArray::new(1, global_ctx.net_ns.clone());
|
||||
let local_socket = {
|
||||
let _g = self.peer_mgr.get_global_ctx().net_ns.guard();
|
||||
Arc::new(UdpSocket::bind("0.0.0.0:0").await?)
|
||||
};
|
||||
|
||||
let local_addr = local_socket
|
||||
.local_addr()
|
||||
.with_context(|| anyhow::anyhow!("failed to get local port from udp array"))?;
|
||||
let local_port = local_addr.port();
|
||||
|
||||
let local_mapped_addr = global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_udp_port_mapping(local_port)
|
||||
.await
|
||||
.with_context(|| "failed to get udp port mapping")?;
|
||||
|
||||
// client -> server: tell server the mapped port, server will return the mapped address of listening port.
|
||||
let rpc_stub = self
|
||||
.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.rpc_client()
|
||||
.scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
|
||||
self.peer_mgr.my_peer_id(),
|
||||
dst_peer_id,
|
||||
global_ctx.get_network_name(),
|
||||
);
|
||||
|
||||
let resp = rpc_stub
|
||||
.select_punch_listener(
|
||||
BaseController::default(),
|
||||
SelectPunchListenerRequest { force_new: false },
|
||||
)
|
||||
.await
|
||||
.with_context(|| "failed to select punch listener")?;
|
||||
let remote_mapped_addr = resp.listener_mapped_addr.ok_or(anyhow::anyhow!(
|
||||
"select_punch_listener response missing listener_mapped_addr"
|
||||
))?;
|
||||
|
||||
tracing::debug!(
|
||||
?local_mapped_addr,
|
||||
?remote_mapped_addr,
|
||||
"hole punch got remote listener"
|
||||
);
|
||||
|
||||
udp_array.add_new_socket(local_socket).await?;
|
||||
udp_array.add_intreast_tid(tid);
|
||||
let send_from_local = || async {
|
||||
udp_array
|
||||
.send_with_all(
|
||||
&new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes(),
|
||||
remote_mapped_addr.clone().into(),
|
||||
)
|
||||
.await
|
||||
.with_context(|| "failed to send hole punch packet from local")
|
||||
};
|
||||
|
||||
send_from_local().await?;
|
||||
|
||||
let scoped_punch_task: ScopedTask<()> = tokio::spawn(async move {
|
||||
if let Err(e) = rpc_stub
|
||||
.send_punch_packet_cone(
|
||||
BaseController {
|
||||
timeout_ms: 4000,
|
||||
..Default::default()
|
||||
},
|
||||
SendPunchPacketConeRequest {
|
||||
listener_mapped_addr: Some(remote_mapped_addr.into()),
|
||||
dest_addr: Some(local_mapped_addr.into()),
|
||||
transaction_id: tid,
|
||||
packet_count_per_batch: 2,
|
||||
packet_batch_count: 5,
|
||||
packet_interval_ms: 400,
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
tracing::error!(?e, "failed to call remote send punch packet");
|
||||
}
|
||||
})
|
||||
.into();
|
||||
|
||||
// server: will send some punching resps, total 10 packets.
|
||||
// client: use the socket to create UdpTunnel with UdpTunnelConnector
|
||||
// NOTICE: UdpTunnelConnector will ignore the punching resp packet sent by remote.
|
||||
let mut finish_time: Option<Instant> = None;
|
||||
while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 {
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
|
||||
if finish_time.is_none() && (*scoped_punch_task).is_finished() {
|
||||
finish_time = Some(Instant::now());
|
||||
}
|
||||
|
||||
let Some(socket) = udp_array.try_fetch_punched_socket(tid) else {
|
||||
tracing::debug!("no punched socket found, send some more hole punch packets");
|
||||
send_from_local().await?;
|
||||
continue;
|
||||
};
|
||||
|
||||
tracing::debug!(?socket, ?tid, "punched socket found, try connect with it");
|
||||
|
||||
for _ in 0..2 {
|
||||
match try_connect_with_socket(socket.socket.clone(), remote_mapped_addr.into())
|
||||
.await
|
||||
{
|
||||
Ok(tunnel) => {
|
||||
tracing::info!(?tunnel, "hole punched");
|
||||
return Ok(tunnel);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(?e, "failed to connect with socket");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return Err(anyhow::anyhow!("punch task finished but no hole punched"));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
|
||||
use crate::{
|
||||
connector::udp_hole_punch::{
|
||||
tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector,
|
||||
},
|
||||
peers::tests::{connect_peer_manager, wait_route_appear, wait_route_appear_with_cost},
|
||||
proto::common::NatType,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn hole_punching_cone() {
|
||||
let p_a = create_mock_peer_manager_with_mock_stun(NatType::Restricted).await;
|
||||
let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
let p_c = create_mock_peer_manager_with_mock_stun(NatType::Restricted).await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
println!("{:?}", p_a.list_routes().await);
|
||||
|
||||
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone());
|
||||
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone());
|
||||
|
||||
hole_punching_a.run_as_client().await.unwrap();
|
||||
hole_punching_c.run_as_server().await.unwrap();
|
||||
|
||||
hole_punching_a.client.run_immediately().await;
|
||||
|
||||
wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
|
||||
.await
|
||||
.unwrap();
|
||||
println!("{:?}", p_a.list_routes().await);
|
||||
}
|
||||
}
|
482
easytier/src/connector/udp_hole_punch/mod.rs
Normal file
482
easytier/src/connector/udp_hole_punch/mod.rs
Normal file
|
@ -0,0 +1,482 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Error;
|
||||
use both_easy_sym::{PunchBothEasySymHoleClient, PunchBothEasySymHoleServer};
|
||||
use common::{PunchHoleServerCommon, UdpNatType, UdpPunchClientMethod};
|
||||
use cone::{PunchConeHoleClient, PunchConeHoleServer};
|
||||
use sym_to_cone::{PunchSymToConeHoleClient, PunchSymToConeHoleServer};
|
||||
use tokio::{sync::Mutex, task::JoinHandle};
|
||||
|
||||
use crate::{
|
||||
common::{stun::StunInfoCollectorTrait, PeerId},
|
||||
connector::direct::PeerManagerForDirectConnector,
|
||||
peers::{
|
||||
peer_manager::PeerManager,
|
||||
peer_task::{PeerTaskLauncher, PeerTaskManager},
|
||||
},
|
||||
proto::{
|
||||
common::{NatType, Void},
|
||||
peer_rpc::{
|
||||
SelectPunchListenerRequest, SelectPunchListenerResponse,
|
||||
SendPunchPacketBothEasySymRequest, SendPunchPacketBothEasySymResponse,
|
||||
SendPunchPacketConeRequest, SendPunchPacketEasySymRequest,
|
||||
SendPunchPacketHardSymRequest, SendPunchPacketHardSymResponse, UdpHolePunchRpc,
|
||||
UdpHolePunchRpcServer,
|
||||
},
|
||||
rpc_types::{self, controller::BaseController},
|
||||
},
|
||||
};
|
||||
|
||||
pub(crate) mod both_easy_sym;
|
||||
pub(crate) mod common;
|
||||
pub(crate) mod cone;
|
||||
pub(crate) mod sym_to_cone;
|
||||
|
||||
struct UdpHolePunchServer {
|
||||
common: Arc<PunchHoleServerCommon>,
|
||||
cone_server: PunchConeHoleServer,
|
||||
sym_to_cone_server: PunchSymToConeHoleServer,
|
||||
both_easy_sym_server: PunchBothEasySymHoleServer,
|
||||
}
|
||||
|
||||
impl UdpHolePunchServer {
|
||||
pub fn new(peer_mgr: Arc<PeerManager>) -> Arc<Self> {
|
||||
let common = Arc::new(PunchHoleServerCommon::new(peer_mgr.clone()));
|
||||
let cone_server = PunchConeHoleServer::new(common.clone());
|
||||
let sym_to_cone_server = PunchSymToConeHoleServer::new(common.clone());
|
||||
let both_easy_sym_server = PunchBothEasySymHoleServer::new(common.clone());
|
||||
|
||||
Arc::new(Self {
|
||||
common,
|
||||
cone_server,
|
||||
sym_to_cone_server,
|
||||
both_easy_sym_server,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl UdpHolePunchRpc for UdpHolePunchServer {
|
||||
type Controller = BaseController;
|
||||
|
||||
async fn select_punch_listener(
|
||||
&self,
|
||||
_ctrl: Self::Controller,
|
||||
input: SelectPunchListenerRequest,
|
||||
) -> rpc_types::error::Result<SelectPunchListenerResponse> {
|
||||
let (_, addr) = self
|
||||
.common
|
||||
.select_listener(input.force_new)
|
||||
.await
|
||||
.ok_or(anyhow::anyhow!("no listener available"))?;
|
||||
|
||||
Ok(SelectPunchListenerResponse {
|
||||
listener_mapped_addr: Some(addr.into()),
|
||||
})
|
||||
}
|
||||
|
||||
/// send packet to one remote_addr, used by nat1-3 to nat1-3
|
||||
async fn send_punch_packet_cone(
|
||||
&self,
|
||||
ctrl: Self::Controller,
|
||||
input: SendPunchPacketConeRequest,
|
||||
) -> rpc_types::error::Result<Void> {
|
||||
self.cone_server.send_punch_packet_cone(ctrl, input).await
|
||||
}
|
||||
|
||||
/// send packet to multiple remote_addr (birthday attack), used by nat4 to nat1-3
|
||||
async fn send_punch_packet_hard_sym(
|
||||
&self,
|
||||
_ctrl: Self::Controller,
|
||||
input: SendPunchPacketHardSymRequest,
|
||||
) -> rpc_types::error::Result<SendPunchPacketHardSymResponse> {
|
||||
self.sym_to_cone_server
|
||||
.send_punch_packet_hard_sym(input)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn send_punch_packet_easy_sym(
|
||||
&self,
|
||||
_ctrl: Self::Controller,
|
||||
input: SendPunchPacketEasySymRequest,
|
||||
) -> rpc_types::error::Result<Void> {
|
||||
self.sym_to_cone_server
|
||||
.send_punch_packet_easy_sym(input)
|
||||
.await
|
||||
.map(|_| Void {})
|
||||
}
|
||||
|
||||
/// nat4 to nat4 (both predictably)
|
||||
async fn send_punch_packet_both_easy_sym(
|
||||
&self,
|
||||
_ctrl: Self::Controller,
|
||||
input: SendPunchPacketBothEasySymRequest,
|
||||
) -> rpc_types::error::Result<SendPunchPacketBothEasySymResponse> {
|
||||
self.both_easy_sym_server
|
||||
.send_punch_packet_both_easy_sym(input)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
struct BackOff {
|
||||
backoffs_ms: Vec<u64>,
|
||||
current_idx: usize,
|
||||
}
|
||||
|
||||
impl BackOff {
|
||||
pub fn new(backoffs_ms: Vec<u64>) -> Self {
|
||||
Self {
|
||||
backoffs_ms,
|
||||
current_idx: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn next_backoff(&mut self) -> u64 {
|
||||
let backoff = self.backoffs_ms[self.current_idx];
|
||||
self.current_idx = (self.current_idx + 1).min(self.backoffs_ms.len() - 1);
|
||||
backoff
|
||||
}
|
||||
|
||||
pub fn rollback(&mut self) {
|
||||
self.current_idx = self.current_idx.saturating_sub(1);
|
||||
}
|
||||
|
||||
pub async fn sleep_for_next_backoff(&mut self) {
|
||||
let backoff = self.next_backoff();
|
||||
if backoff > 0 {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(backoff)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct UdpHoePunchConnectorData {
|
||||
cone_client: PunchConeHoleClient,
|
||||
sym_to_cone_client: PunchSymToConeHoleClient,
|
||||
both_easy_sym_client: PunchBothEasySymHoleClient,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
|
||||
// sym punch should be serialized
|
||||
sym_punch_lock: Mutex<()>,
|
||||
}
|
||||
|
||||
impl UdpHoePunchConnectorData {
|
||||
pub fn new(peer_mgr: Arc<PeerManager>) -> Arc<Self> {
|
||||
let cone_client = PunchConeHoleClient::new(peer_mgr.clone());
|
||||
let sym_to_cone_client = PunchSymToConeHoleClient::new(peer_mgr.clone());
|
||||
let both_easy_sym_client = PunchBothEasySymHoleClient::new(peer_mgr.clone());
|
||||
|
||||
Arc::new(Self {
|
||||
cone_client,
|
||||
sym_to_cone_client,
|
||||
both_easy_sym_client,
|
||||
peer_mgr,
|
||||
sym_punch_lock: Mutex::new(()),
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
async fn cone_to_cone(self: Arc<Self>, task_info: PunchTaskInfo) -> Result<(), Error> {
|
||||
let mut backoff = BackOff::new(vec![0, 1000, 2000, 4000, 4000, 8000, 8000, 16000]);
|
||||
|
||||
loop {
|
||||
backoff.sleep_for_next_backoff().await;
|
||||
|
||||
let ret = self
|
||||
.cone_client
|
||||
.do_hole_punching(task_info.dst_peer_id)
|
||||
.await;
|
||||
if let Err(e) = ret {
|
||||
tracing::info!(?e, "cone_to_cone hole punching failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = self.peer_mgr.add_client_tunnel(ret.unwrap()).await {
|
||||
tracing::warn!(?e, "cone_to_cone add client tunnel failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
tracing::info!("cone_to_cone hole punching success");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
async fn sym_to_cone(self: Arc<Self>, task_info: PunchTaskInfo) -> Result<(), Error> {
|
||||
let mut backoff = BackOff::new(vec![0, 1000, 2000, 4000, 4000, 8000, 8000, 16000, 64000]);
|
||||
let mut round = 0;
|
||||
let mut port_idx = rand::random();
|
||||
|
||||
loop {
|
||||
backoff.sleep_for_next_backoff().await;
|
||||
|
||||
let ret = {
|
||||
let _lock = self.sym_punch_lock.lock().await;
|
||||
self.sym_to_cone_client
|
||||
.do_hole_punching(
|
||||
task_info.dst_peer_id,
|
||||
round,
|
||||
&mut port_idx,
|
||||
task_info.my_nat_type,
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
round += 1;
|
||||
|
||||
if let Err(e) = ret {
|
||||
tracing::info!(?e, "sym_to_cone hole punching failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = self.peer_mgr.add_client_tunnel(ret.unwrap()).await {
|
||||
tracing::warn!(?e, "sym_to_cone add client tunnel failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
async fn both_easy_sym(self: Arc<Self>, task_info: PunchTaskInfo) -> Result<(), Error> {
|
||||
let mut backoff = BackOff::new(vec![0, 1000, 2000, 4000, 4000, 8000, 8000, 16000, 64000]);
|
||||
|
||||
loop {
|
||||
backoff.sleep_for_next_backoff().await;
|
||||
|
||||
let mut is_busy = false;
|
||||
|
||||
let ret = {
|
||||
let _lock = self.sym_punch_lock.lock().await;
|
||||
self.both_easy_sym_client
|
||||
.do_hole_punching(
|
||||
task_info.dst_peer_id,
|
||||
task_info.my_nat_type,
|
||||
task_info.dst_nat_type,
|
||||
&mut is_busy,
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
if is_busy {
|
||||
backoff.rollback();
|
||||
}
|
||||
|
||||
if let Err(e) = ret {
|
||||
tracing::info!(?e, "both_easy_sym hole punching failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Err(e) = self.peer_mgr.add_client_tunnel(ret.unwrap()).await {
|
||||
tracing::warn!(?e, "both_easy_sym add client tunnel failed");
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct UdpHolePunchPeerTaskLauncher {}
|
||||
|
||||
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
|
||||
struct PunchTaskInfo {
|
||||
dst_peer_id: PeerId,
|
||||
dst_nat_type: UdpNatType,
|
||||
my_nat_type: UdpNatType,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl PeerTaskLauncher for UdpHolePunchPeerTaskLauncher {
|
||||
type Data = Arc<UdpHoePunchConnectorData>;
|
||||
type CollectPeerItem = PunchTaskInfo;
|
||||
type TaskRet = ();
|
||||
|
||||
fn new_data(&self, peer_mgr: Arc<PeerManager>) -> Self::Data {
|
||||
UdpHoePunchConnectorData::new(peer_mgr)
|
||||
}
|
||||
|
||||
async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec<Self::CollectPeerItem> {
|
||||
let my_nat_type = data
|
||||
.peer_mgr
|
||||
.get_global_ctx()
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type;
|
||||
let my_nat_type: UdpNatType = NatType::try_from(my_nat_type)
|
||||
.unwrap_or(NatType::Unknown)
|
||||
.into();
|
||||
if !my_nat_type.is_sym() {
|
||||
data.sym_to_cone_client.clear_udp_array().await;
|
||||
}
|
||||
|
||||
let mut peers_to_connect: Vec<Self::CollectPeerItem> = Vec::new();
|
||||
// do not do anything if:
|
||||
// 1. our nat type is OpenInternet or NoPat, which means we can wait other peers to connect us
|
||||
// notice that if we are unknown, we treat ourselves as cone
|
||||
if my_nat_type.is_open() {
|
||||
return peers_to_connect;
|
||||
}
|
||||
|
||||
// collect peer list from peer manager and do some filter:
|
||||
// 1. peers without direct conns;
|
||||
// 2. peers is full cone (any restricted type);
|
||||
for route in data.peer_mgr.list_routes().await.iter() {
|
||||
if route
|
||||
.feature_flag
|
||||
.map(|x| x.is_public_server)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let peer_nat_type = route
|
||||
.stun_info
|
||||
.as_ref()
|
||||
.map(|x| x.udp_nat_type)
|
||||
.unwrap_or(0);
|
||||
let Ok(peer_nat_type) = NatType::try_from(peer_nat_type) else {
|
||||
continue;
|
||||
};
|
||||
let peer_nat_type = peer_nat_type.into();
|
||||
|
||||
let peer_id: PeerId = route.peer_id;
|
||||
let conns = data.peer_mgr.list_peer_conns(peer_id).await;
|
||||
if conns.is_some() && conns.unwrap().len() > 0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
if !my_nat_type.can_punch_hole_as_client(peer_nat_type) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::info!(
|
||||
?peer_id,
|
||||
?peer_nat_type,
|
||||
?my_nat_type,
|
||||
"found peer to do hole punching"
|
||||
);
|
||||
|
||||
peers_to_connect.push(PunchTaskInfo {
|
||||
dst_peer_id: peer_id,
|
||||
dst_nat_type: peer_nat_type,
|
||||
my_nat_type,
|
||||
});
|
||||
}
|
||||
|
||||
peers_to_connect
|
||||
}
|
||||
|
||||
async fn launch_task(
|
||||
&self,
|
||||
data: &Self::Data,
|
||||
item: Self::CollectPeerItem,
|
||||
) -> JoinHandle<Result<Self::TaskRet, Error>> {
|
||||
let data = data.clone();
|
||||
let punch_method = item.my_nat_type.get_punch_hole_method(item.dst_nat_type);
|
||||
match punch_method {
|
||||
UdpPunchClientMethod::ConeToCone => tokio::spawn(data.cone_to_cone(item)),
|
||||
UdpPunchClientMethod::SymToCone => tokio::spawn(data.sym_to_cone(item)),
|
||||
UdpPunchClientMethod::EasySymToEasySym => tokio::spawn(data.both_easy_sym(item)),
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn all_task_done(&self, data: &Self::Data) {
|
||||
data.sym_to_cone_client.clear_udp_array().await;
|
||||
}
|
||||
|
||||
fn loop_interval_ms(&self) -> u64 {
|
||||
5000
|
||||
}
|
||||
}
|
||||
|
||||
pub struct UdpHolePunchConnector {
|
||||
server: Arc<UdpHolePunchServer>,
|
||||
client: PeerTaskManager<UdpHolePunchPeerTaskLauncher>,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
}
|
||||
|
||||
// Currently support:
|
||||
// Symmetric -> Full Cone
|
||||
// Any Type of Full Cone -> Any Type of Full Cone
|
||||
|
||||
// if same level of full cone, node with smaller peer_id will be the initiator
|
||||
// if different level of full cone, node with more strict level will be the initiator
|
||||
|
||||
impl UdpHolePunchConnector {
|
||||
pub fn new(peer_mgr: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
server: UdpHolePunchServer::new(peer_mgr.clone()),
|
||||
client: PeerTaskManager::new(UdpHolePunchPeerTaskLauncher {}, peer_mgr.clone()),
|
||||
peer_mgr,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_as_client(&mut self) -> Result<(), Error> {
|
||||
self.client.start();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_as_server(&mut self) -> Result<(), Error> {
|
||||
self.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.rpc_server()
|
||||
.registry()
|
||||
.register(
|
||||
UdpHolePunchRpcServer::new(self.server.clone()),
|
||||
&self.peer_mgr.get_global_ctx().get_network_name(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run(&mut self) -> Result<(), Error> {
|
||||
let global_ctx = self.peer_mgr.get_global_ctx();
|
||||
|
||||
if global_ctx.get_flags().disable_p2p {
|
||||
return Ok(());
|
||||
}
|
||||
if global_ctx.get_flags().disable_udp_hole_punching {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.run_as_client().await?;
|
||||
self.run_as_server().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::common::stun::MockStunInfoCollector;
|
||||
use crate::proto::common::NatType;
|
||||
|
||||
use crate::peers::{peer_manager::PeerManager, tests::create_mock_peer_manager};
|
||||
|
||||
pub fn replace_stun_info_collector(peer_mgr: Arc<PeerManager>, udp_nat_type: NatType) {
|
||||
let collector = Box::new(MockStunInfoCollector { udp_nat_type });
|
||||
peer_mgr
|
||||
.get_global_ctx()
|
||||
.replace_stun_info_collector(collector);
|
||||
}
|
||||
|
||||
pub async fn create_mock_peer_manager_with_mock_stun(
|
||||
udp_nat_type: NatType,
|
||||
) -> Arc<PeerManager> {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
replace_stun_info_collector(p_a.clone(), udp_nat_type);
|
||||
p_a
|
||||
}
|
||||
}
|
591
easytier/src/connector/udp_hole_punch/sym_to_cone.rs
Normal file
591
easytier/src/connector/udp_hole_punch/sym_to_cone.rs
Normal file
|
@ -0,0 +1,591 @@
|
|||
use std::{
|
||||
net::Ipv4Addr,
|
||||
ops::{Div, Mul},
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use rand::{seq::SliceRandom, Rng};
|
||||
use tokio::{net::UdpSocket, sync::RwLock};
|
||||
use tracing::Level;
|
||||
|
||||
use crate::{
|
||||
common::{scoped_task::ScopedTask, stun::StunInfoCollectorTrait, PeerId},
|
||||
connector::udp_hole_punch::common::{
|
||||
send_symmetric_hole_punch_packet, try_connect_with_socket, HOLE_PUNCH_PACKET_BODY_LEN,
|
||||
},
|
||||
defer,
|
||||
peers::peer_manager::PeerManager,
|
||||
proto::{
|
||||
peer_rpc::{
|
||||
SelectPunchListenerRequest, SendPunchPacketEasySymRequest,
|
||||
SendPunchPacketHardSymRequest, SendPunchPacketHardSymResponse,
|
||||
UdpHolePunchRpcClientFactory,
|
||||
},
|
||||
rpc_types::{self, controller::BaseController},
|
||||
},
|
||||
tunnel::{udp::new_hole_punch_packet, Tunnel},
|
||||
};
|
||||
|
||||
use super::common::{PunchHoleServerCommon, UdpNatType, UdpSocketArray};
|
||||
|
||||
const UDP_ARRAY_SIZE_FOR_HARD_SYM: usize = 84;
|
||||
|
||||
pub(crate) struct PunchSymToConeHoleServer {
|
||||
common: Arc<PunchHoleServerCommon>,
|
||||
|
||||
shuffled_port_vec: Arc<Vec<u16>>,
|
||||
}
|
||||
|
||||
impl PunchSymToConeHoleServer {
|
||||
pub(crate) fn new(common: Arc<PunchHoleServerCommon>) -> Self {
|
||||
let mut shuffled_port_vec: Vec<u16> = (1..=65535).collect();
|
||||
shuffled_port_vec.shuffle(&mut rand::thread_rng());
|
||||
|
||||
Self {
|
||||
common,
|
||||
shuffled_port_vec: Arc::new(shuffled_port_vec),
|
||||
}
|
||||
}
|
||||
|
||||
// hard sym means public port is random and cannot be predicted
|
||||
#[tracing::instrument(skip(self), ret)]
|
||||
pub(crate) async fn send_punch_packet_easy_sym(
|
||||
&self,
|
||||
request: SendPunchPacketEasySymRequest,
|
||||
) -> Result<(), rpc_types::error::Error> {
|
||||
tracing::info!("send_punch_packet_easy_sym start");
|
||||
|
||||
let listener_addr = request.listener_mapped_addr.ok_or(anyhow::anyhow!(
|
||||
"send_punch_packet_easy_sym request missing listener_addr"
|
||||
))?;
|
||||
let listener_addr = std::net::SocketAddr::from(listener_addr);
|
||||
let listener = self
|
||||
.common
|
||||
.find_listener(&listener_addr)
|
||||
.await
|
||||
.ok_or(anyhow::anyhow!(
|
||||
"send_punch_packet_easy_sym failed to find listener"
|
||||
))?;
|
||||
|
||||
let public_ips = request
|
||||
.public_ips
|
||||
.into_iter()
|
||||
.map(|ip| std::net::Ipv4Addr::from(ip))
|
||||
.collect::<Vec<_>>();
|
||||
if public_ips.len() == 0 {
|
||||
tracing::warn!("send_punch_packet_easy_sym got zero len public ip");
|
||||
return Err(
|
||||
anyhow::anyhow!("send_punch_packet_easy_sym got zero len public ip").into(),
|
||||
);
|
||||
}
|
||||
|
||||
let transaction_id = request.transaction_id;
|
||||
let base_port_num = request.base_port_num;
|
||||
let max_port_num = request.max_port_num.max(1);
|
||||
let is_incremental = request.is_incremental;
|
||||
|
||||
let port_start = if is_incremental {
|
||||
base_port_num.saturating_add(1)
|
||||
} else {
|
||||
base_port_num.saturating_sub(max_port_num)
|
||||
};
|
||||
|
||||
let port_end = if is_incremental {
|
||||
base_port_num.saturating_add(max_port_num)
|
||||
} else {
|
||||
base_port_num.saturating_sub(1)
|
||||
};
|
||||
|
||||
if port_end <= port_start {
|
||||
return Err(anyhow::anyhow!("send_punch_packet_easy_sym invalid port range").into());
|
||||
}
|
||||
|
||||
let ports = (port_start..=port_end)
|
||||
.map(|x| x as u16)
|
||||
.collect::<Vec<_>>();
|
||||
tracing::debug!(
|
||||
?ports,
|
||||
?public_ips,
|
||||
"send_punch_packet_easy_sym send to ports"
|
||||
);
|
||||
send_symmetric_hole_punch_packet(
|
||||
&ports,
|
||||
listener,
|
||||
transaction_id,
|
||||
&public_ips,
|
||||
0,
|
||||
ports.len(),
|
||||
)
|
||||
.await
|
||||
.with_context(|| "failed to send symmetric hole punch packet")?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// hard sym means public port is random and cannot be predicted
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub(crate) async fn send_punch_packet_hard_sym(
|
||||
&self,
|
||||
request: SendPunchPacketHardSymRequest,
|
||||
) -> Result<SendPunchPacketHardSymResponse, rpc_types::error::Error> {
|
||||
tracing::info!("try_punch_symmetric start");
|
||||
|
||||
let listener_addr = request.listener_mapped_addr.ok_or(anyhow::anyhow!(
|
||||
"try_punch_symmetric request missing listener_addr"
|
||||
))?;
|
||||
let listener_addr = std::net::SocketAddr::from(listener_addr);
|
||||
let listener = self
|
||||
.common
|
||||
.find_listener(&listener_addr)
|
||||
.await
|
||||
.ok_or(anyhow::anyhow!(
|
||||
"send_punch_packet_for_cone failed to find listener"
|
||||
))?;
|
||||
|
||||
let public_ips = request
|
||||
.public_ips
|
||||
.into_iter()
|
||||
.map(|ip| std::net::Ipv4Addr::from(ip))
|
||||
.collect::<Vec<_>>();
|
||||
if public_ips.len() == 0 {
|
||||
tracing::warn!("try_punch_symmetric got zero len public ip");
|
||||
return Err(anyhow::anyhow!("try_punch_symmetric got zero len public ip").into());
|
||||
}
|
||||
|
||||
let transaction_id = request.transaction_id;
|
||||
let last_port_index = request.port_index as usize;
|
||||
|
||||
let round = std::cmp::max(request.round, 1);
|
||||
|
||||
// send max k1 packets if we are predicting the dst port
|
||||
let max_k1: u32 = 180;
|
||||
// send max k2 packets if we are sending to random port
|
||||
let mut max_k2: u32 = rand::thread_rng().gen_range(600..800);
|
||||
if round > 2 {
|
||||
max_k2 = max_k2.mul(2).div(round).max(max_k1);
|
||||
}
|
||||
|
||||
let next_port_index = send_symmetric_hole_punch_packet(
|
||||
&self.shuffled_port_vec,
|
||||
listener.clone(),
|
||||
transaction_id,
|
||||
&public_ips,
|
||||
last_port_index,
|
||||
max_k2 as usize,
|
||||
)
|
||||
.await
|
||||
.with_context(|| "failed to send symmetric hole punch packet randomly")?;
|
||||
|
||||
return Ok(SendPunchPacketHardSymResponse {
|
||||
next_port_index: next_port_index as u32,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PunchSymToConeHoleClient {
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
udp_array: RwLock<Option<Arc<UdpSocketArray>>>,
|
||||
try_direct_connect: AtomicBool,
|
||||
punch_predicablely: AtomicBool,
|
||||
punch_randomly: AtomicBool,
|
||||
}
|
||||
|
||||
impl PunchSymToConeHoleClient {
|
||||
pub(crate) fn new(peer_mgr: Arc<PeerManager>) -> Self {
|
||||
Self {
|
||||
peer_mgr,
|
||||
udp_array: RwLock::new(None),
|
||||
try_direct_connect: AtomicBool::new(true),
|
||||
punch_predicablely: AtomicBool::new(true),
|
||||
punch_randomly: AtomicBool::new(true),
|
||||
}
|
||||
}
|
||||
|
||||
async fn prepare_udp_array(&self) -> Result<Arc<UdpSocketArray>, anyhow::Error> {
|
||||
let rlocked = self.udp_array.read().await;
|
||||
if let Some(udp_array) = rlocked.clone() {
|
||||
return Ok(udp_array);
|
||||
}
|
||||
|
||||
drop(rlocked);
|
||||
let mut wlocked = self.udp_array.write().await;
|
||||
if let Some(udp_array) = wlocked.clone() {
|
||||
return Ok(udp_array);
|
||||
}
|
||||
|
||||
let udp_array = Arc::new(UdpSocketArray::new(
|
||||
UDP_ARRAY_SIZE_FOR_HARD_SYM,
|
||||
self.peer_mgr.get_global_ctx().net_ns.clone(),
|
||||
));
|
||||
udp_array.start().await?;
|
||||
wlocked.replace(udp_array.clone());
|
||||
Ok(udp_array)
|
||||
}
|
||||
|
||||
pub(crate) async fn clear_udp_array(&self) {
|
||||
let mut wlocked = self.udp_array.write().await;
|
||||
wlocked.take();
|
||||
}
|
||||
|
||||
async fn get_base_port_for_easy_sym(&self, my_nat_info: UdpNatType) -> Option<u16> {
|
||||
let global_ctx = self.peer_mgr.get_global_ctx();
|
||||
if my_nat_info.is_easy_sym() {
|
||||
match global_ctx
|
||||
.get_stun_info_collector()
|
||||
.get_udp_port_mapping(0)
|
||||
.await
|
||||
{
|
||||
Ok(addr) => Some(addr.port()),
|
||||
ret => {
|
||||
tracing::warn!(?ret, "failed to get udp port mapping for easy sym");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(err(level = Level::ERROR), skip(self))]
|
||||
pub(crate) async fn do_hole_punching(
|
||||
&self,
|
||||
dst_peer_id: PeerId,
|
||||
round: u32,
|
||||
last_port_idx: &mut usize,
|
||||
my_nat_info: UdpNatType,
|
||||
) -> Result<Box<dyn Tunnel>, anyhow::Error> {
|
||||
let udp_array = self.prepare_udp_array().await?;
|
||||
let global_ctx = self.peer_mgr.get_global_ctx();
|
||||
|
||||
let rpc_stub = self
|
||||
.peer_mgr
|
||||
.get_peer_rpc_mgr()
|
||||
.rpc_client()
|
||||
.scoped_client::<UdpHolePunchRpcClientFactory<BaseController>>(
|
||||
self.peer_mgr.my_peer_id(),
|
||||
dst_peer_id,
|
||||
global_ctx.get_network_name(),
|
||||
);
|
||||
|
||||
let resp = rpc_stub
|
||||
.select_punch_listener(
|
||||
BaseController::default(),
|
||||
SelectPunchListenerRequest { force_new: false },
|
||||
)
|
||||
.await
|
||||
.with_context(|| "failed to select punch listener")?;
|
||||
let remote_mapped_addr = resp.listener_mapped_addr.ok_or(anyhow::anyhow!(
|
||||
"select_punch_listener response missing listener_mapped_addr"
|
||||
))?;
|
||||
|
||||
// try direct connect first
|
||||
if self.try_direct_connect.load(Ordering::Relaxed) {
|
||||
if let Ok(tunnel) = try_connect_with_socket(
|
||||
Arc::new(UdpSocket::bind("0.0.0.0:0").await?),
|
||||
remote_mapped_addr.into(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
return Ok(tunnel);
|
||||
}
|
||||
}
|
||||
|
||||
let stun_info = global_ctx.get_stun_info_collector().get_stun_info();
|
||||
let public_ips: Vec<Ipv4Addr> = stun_info
|
||||
.public_ip
|
||||
.iter()
|
||||
.map(|x| x.parse().unwrap())
|
||||
.collect();
|
||||
if public_ips.is_empty() {
|
||||
return Err(anyhow::anyhow!("failed to get public ips"));
|
||||
}
|
||||
|
||||
let tid = rand::thread_rng().gen();
|
||||
let packet = new_hole_punch_packet(tid, HOLE_PUNCH_PACKET_BODY_LEN).into_bytes();
|
||||
udp_array.add_intreast_tid(tid);
|
||||
defer! { udp_array.remove_intreast_tid(tid);}
|
||||
udp_array
|
||||
.send_with_all(&packet, remote_mapped_addr.into())
|
||||
.await?;
|
||||
|
||||
let port_index = *last_port_idx as u32;
|
||||
let base_port_for_easy_sym = self.get_base_port_for_easy_sym(my_nat_info).await;
|
||||
let punch_random = self.punch_randomly.load(Ordering::Relaxed);
|
||||
let punch_predicable = self.punch_predicablely.load(Ordering::Relaxed);
|
||||
let scoped_punch_task: ScopedTask<Option<u32>> = tokio::spawn(async move {
|
||||
if punch_predicable {
|
||||
if let Some(inc) = my_nat_info.get_inc_of_easy_sym() {
|
||||
let req = SendPunchPacketEasySymRequest {
|
||||
listener_mapped_addr: remote_mapped_addr.clone().into(),
|
||||
public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(),
|
||||
transaction_id: tid,
|
||||
base_port_num: base_port_for_easy_sym.unwrap() as u32,
|
||||
max_port_num: 50,
|
||||
is_incremental: inc,
|
||||
};
|
||||
tracing::debug!(?req, "send punch packet for easy sym start");
|
||||
let ret = rpc_stub
|
||||
.send_punch_packet_easy_sym(
|
||||
BaseController {
|
||||
timeout_ms: 4000,
|
||||
trace_id: 0,
|
||||
},
|
||||
req,
|
||||
)
|
||||
.await;
|
||||
tracing::debug!(?ret, "send punch packet for easy sym return");
|
||||
}
|
||||
}
|
||||
|
||||
if punch_random {
|
||||
let req = SendPunchPacketHardSymRequest {
|
||||
listener_mapped_addr: remote_mapped_addr.clone().into(),
|
||||
public_ips: public_ips.clone().into_iter().map(|x| x.into()).collect(),
|
||||
transaction_id: tid,
|
||||
round,
|
||||
port_index,
|
||||
};
|
||||
tracing::debug!(?req, "send punch packet for hard sym start");
|
||||
match rpc_stub
|
||||
.send_punch_packet_hard_sym(
|
||||
BaseController {
|
||||
timeout_ms: 4000,
|
||||
trace_id: 0,
|
||||
},
|
||||
req,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Err(e) => {
|
||||
tracing::error!(?e, "failed to send punch packet for hard sym");
|
||||
return None;
|
||||
}
|
||||
Ok(resp) => return Some(resp.next_port_index),
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
})
|
||||
.into();
|
||||
|
||||
// no matter what the result is, we should check if we received any hole punching packet
|
||||
let mut ret_tunnel: Option<Box<dyn Tunnel>> = None;
|
||||
let mut finish_time: Option<Instant> = None;
|
||||
while finish_time.is_none() || finish_time.as_ref().unwrap().elapsed().as_millis() < 1000 {
|
||||
tokio::time::sleep(Duration::from_millis(200)).await;
|
||||
|
||||
if finish_time.is_none() && (*scoped_punch_task).is_finished() {
|
||||
finish_time = Some(Instant::now());
|
||||
}
|
||||
|
||||
let Some(socket) = udp_array.try_fetch_punched_socket(tid) else {
|
||||
tracing::debug!("no punched socket found, wait for more time");
|
||||
continue;
|
||||
};
|
||||
|
||||
// if hole punched but tunnel creation failed, need to retry entire process.
|
||||
match try_connect_with_socket(socket.socket.clone(), remote_mapped_addr.into()).await {
|
||||
Ok(tunnel) => {
|
||||
ret_tunnel.replace(tunnel);
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(?e, "failed to connect with socket");
|
||||
udp_array.add_new_socket(socket.socket).await?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let punch_task_result = scoped_punch_task.await;
|
||||
tracing::debug!(?punch_task_result, ?ret_tunnel, "punch task got result");
|
||||
|
||||
if let Ok(Some(next_port_idx)) = punch_task_result {
|
||||
*last_port_idx = next_port_idx as usize;
|
||||
} else {
|
||||
*last_port_idx = rand::random();
|
||||
}
|
||||
|
||||
if let Some(tunnel) = ret_tunnel {
|
||||
Ok(tunnel)
|
||||
} else {
|
||||
anyhow::bail!(
|
||||
"failed to hole punch, punch task result: {:?}",
|
||||
punch_task_result
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
use std::{
|
||||
sync::{atomic::AtomicU32, Arc},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use tokio::net::UdpSocket;
|
||||
|
||||
use crate::{
|
||||
connector::udp_hole_punch::{
|
||||
tests::create_mock_peer_manager_with_mock_stun, UdpHolePunchConnector,
|
||||
},
|
||||
peers::tests::{connect_peer_manager, wait_route_appear, wait_route_appear_with_cost},
|
||||
proto::common::NatType,
|
||||
tunnel::common::tests::wait_for_condition,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn hole_punching_symmetric_only_random() {
|
||||
let p_a = create_mock_peer_manager_with_mock_stun(NatType::Symmetric).await;
|
||||
let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone());
|
||||
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone());
|
||||
|
||||
hole_punching_a
|
||||
.client
|
||||
.data()
|
||||
.sym_to_cone_client
|
||||
.try_direct_connect
|
||||
.store(false, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
hole_punching_a
|
||||
.client
|
||||
.data()
|
||||
.sym_to_cone_client
|
||||
.punch_predicablely
|
||||
.store(false, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
hole_punching_a.run().await.unwrap();
|
||||
hole_punching_c.run().await.unwrap();
|
||||
|
||||
hole_punching_a.client.run_immediately().await;
|
||||
|
||||
wait_for_condition(
|
||||
|| async {
|
||||
hole_punching_a
|
||||
.client
|
||||
.data()
|
||||
.sym_to_cone_client
|
||||
.udp_array
|
||||
.read()
|
||||
.await
|
||||
.is_some()
|
||||
},
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
|
||||
wait_for_condition(
|
||||
|| async {
|
||||
wait_route_appear_with_cost(p_a.clone(), p_c.my_peer_id(), Some(1))
|
||||
.await
|
||||
.is_ok()
|
||||
},
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
println!("{:?}", p_a.list_routes().await);
|
||||
|
||||
wait_for_condition(
|
||||
|| async {
|
||||
hole_punching_a
|
||||
.client
|
||||
.data()
|
||||
.sym_to_cone_client
|
||||
.udp_array
|
||||
.read()
|
||||
.await
|
||||
.is_none()
|
||||
},
|
||||
Duration::from_secs(10),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[rstest::rstest]
|
||||
#[tokio::test]
|
||||
#[serial_test::serial(hole_punch)]
|
||||
async fn hole_punching_symmetric_only_predict(#[values("true", "false")] is_inc: bool) {
|
||||
let p_a = create_mock_peer_manager_with_mock_stun(if is_inc {
|
||||
NatType::SymmetricEasyInc
|
||||
} else {
|
||||
NatType::SymmetricEasyDec
|
||||
})
|
||||
.await;
|
||||
let p_b = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
let p_c = create_mock_peer_manager_with_mock_stun(NatType::PortRestricted).await;
|
||||
connect_peer_manager(p_a.clone(), p_b.clone()).await;
|
||||
connect_peer_manager(p_b.clone(), p_c.clone()).await;
|
||||
wait_route_appear(p_a.clone(), p_c.clone()).await.unwrap();
|
||||
|
||||
let mut hole_punching_a = UdpHolePunchConnector::new(p_a.clone());
|
||||
let mut hole_punching_c = UdpHolePunchConnector::new(p_c.clone());
|
||||
|
||||
hole_punching_a
|
||||
.client
|
||||
.data()
|
||||
.sym_to_cone_client
|
||||
.try_direct_connect
|
||||
.store(false, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
hole_punching_a
|
||||
.client
|
||||
.data()
|
||||
.sym_to_cone_client
|
||||
.punch_randomly
|
||||
.store(false, std::sync::atomic::Ordering::Relaxed);
|
||||
|
||||
hole_punching_a.run().await.unwrap();
|
||||
hole_punching_c.run().await.unwrap();
|
||||
|
||||
let udps = if is_inc {
|
||||
let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40147").await.unwrap());
|
||||
let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40194").await.unwrap());
|
||||
vec![udp1, udp2]
|
||||
} else {
|
||||
let udp1 = Arc::new(UdpSocket::bind("0.0.0.0:40141").await.unwrap());
|
||||
let udp2 = Arc::new(UdpSocket::bind("0.0.0.0:40100").await.unwrap());
|
||||
vec![udp1, udp2]
|
||||
};
|
||||
// let udp_dec = Arc::new(UdpSocket::bind("0.0.0.0:40140").await.unwrap());
|
||||
// let udp_dec2 = Arc::new(UdpSocket::bind("0.0.0.0:40050").await.unwrap());
|
||||
|
||||
let counter = Arc::new(AtomicU32::new(0));
|
||||
|
||||
// all these sockets should receive hole punching packet
|
||||
for udp in udps.iter().map(Arc::clone) {
|
||||
let counter = counter.clone();
|
||||
tokio::spawn(async move {
|
||||
let mut buf = [0u8; 1024];
|
||||
let (len, addr) = udp.recv_from(&mut buf).await.unwrap();
|
||||
println!(
|
||||
"got predictable punch packet, {:?} {:?} {:?}",
|
||||
len,
|
||||
addr,
|
||||
udp.local_addr()
|
||||
);
|
||||
counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||
});
|
||||
}
|
||||
|
||||
hole_punching_a.client.run_immediately().await;
|
||||
|
||||
let udp_len = udps.len();
|
||||
wait_for_condition(
|
||||
|| async { counter.load(std::sync::atomic::Ordering::Relaxed) == udp_len as u32 },
|
||||
Duration::from_secs(30),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
|
@ -179,14 +179,16 @@ impl CommandHandler {
|
|||
async fn list_peers(&self) -> Result<ListPeerResponse, Error> {
|
||||
let client = self.get_peer_manager_client().await?;
|
||||
let request = ListPeerRequest::default();
|
||||
let response = client.list_peer(BaseController {}, request).await?;
|
||||
let response = client.list_peer(BaseController::default(), request).await?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
async fn list_routes(&self) -> Result<ListRouteResponse, Error> {
|
||||
let client = self.get_peer_manager_client().await?;
|
||||
let request = ListRouteRequest::default();
|
||||
let response = client.list_route(BaseController {}, request).await?;
|
||||
let response = client
|
||||
.list_route(BaseController::default(), request)
|
||||
.await?;
|
||||
Ok(response)
|
||||
}
|
||||
|
||||
|
@ -275,7 +277,7 @@ impl CommandHandler {
|
|||
|
||||
let client = self.get_peer_manager_client().await?;
|
||||
let node_info = client
|
||||
.show_node_info(BaseController {}, ShowNodeInfoRequest::default())
|
||||
.show_node_info(BaseController::default(), ShowNodeInfoRequest::default())
|
||||
.await?
|
||||
.node_info
|
||||
.ok_or(anyhow::anyhow!("node info not found"))?;
|
||||
|
@ -296,7 +298,9 @@ impl CommandHandler {
|
|||
async fn handle_route_dump(&self) -> Result<(), Error> {
|
||||
let client = self.get_peer_manager_client().await?;
|
||||
let request = DumpRouteRequest::default();
|
||||
let response = client.dump_route(BaseController {}, request).await?;
|
||||
let response = client
|
||||
.dump_route(BaseController::default(), request)
|
||||
.await?;
|
||||
println!("response: {}", response.result);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -305,7 +309,7 @@ impl CommandHandler {
|
|||
let client = self.get_peer_manager_client().await?;
|
||||
let request = ListForeignNetworkRequest::default();
|
||||
let response = client
|
||||
.list_foreign_network(BaseController {}, request)
|
||||
.list_foreign_network(BaseController::default(), request)
|
||||
.await?;
|
||||
let network_map = response;
|
||||
if self.verbose {
|
||||
|
@ -347,7 +351,7 @@ impl CommandHandler {
|
|||
let client = self.get_peer_manager_client().await?;
|
||||
let request = ListGlobalForeignNetworkRequest::default();
|
||||
let response = client
|
||||
.list_global_foreign_network(BaseController {}, request)
|
||||
.list_global_foreign_network(BaseController::default(), request)
|
||||
.await?;
|
||||
if self.verbose {
|
||||
println!("{:#?}", response);
|
||||
|
@ -383,7 +387,7 @@ impl CommandHandler {
|
|||
let mut items: Vec<RouteTableItem> = vec![];
|
||||
let client = self.get_peer_manager_client().await?;
|
||||
let node_info = client
|
||||
.show_node_info(BaseController {}, ShowNodeInfoRequest::default())
|
||||
.show_node_info(BaseController::default(), ShowNodeInfoRequest::default())
|
||||
.await?
|
||||
.node_info
|
||||
.ok_or(anyhow::anyhow!("node info not found"))?;
|
||||
|
@ -451,7 +455,9 @@ impl CommandHandler {
|
|||
async fn handle_connector_list(&self) -> Result<(), Error> {
|
||||
let client = self.get_connector_manager_client().await?;
|
||||
let request = ListConnectorRequest::default();
|
||||
let response = client.list_connector(BaseController {}, request).await?;
|
||||
let response = client
|
||||
.list_connector(BaseController::default(), request)
|
||||
.await?;
|
||||
println!("response: {:#?}", response);
|
||||
Ok(())
|
||||
}
|
||||
|
@ -515,7 +521,7 @@ async fn main() -> Result<(), Error> {
|
|||
Some(RouteSubCommand::Dump) => handler.handle_route_dump().await?,
|
||||
},
|
||||
SubCommand::Stun => {
|
||||
timeout(Duration::from_secs(5), async move {
|
||||
timeout(Duration::from_secs(25), async move {
|
||||
let collector = StunInfoCollector::new_with_default_servers();
|
||||
loop {
|
||||
let ret = collector.get_stun_info();
|
||||
|
@ -532,7 +538,10 @@ async fn main() -> Result<(), Error> {
|
|||
SubCommand::PeerCenter => {
|
||||
let peer_center_client = handler.get_peer_center_client().await?;
|
||||
let resp = peer_center_client
|
||||
.get_global_peer_map(BaseController {}, GetGlobalPeerMapRequest::default())
|
||||
.get_global_peer_map(
|
||||
BaseController::default(),
|
||||
GetGlobalPeerMapRequest::default(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
#[derive(tabled::Tabled)]
|
||||
|
@ -565,7 +574,10 @@ async fn main() -> Result<(), Error> {
|
|||
SubCommand::VpnPortal => {
|
||||
let vpn_portal_client = handler.get_vpn_portal_client().await?;
|
||||
let resp = vpn_portal_client
|
||||
.get_vpn_portal_info(BaseController {}, GetVpnPortalInfoRequest::default())
|
||||
.get_vpn_portal_info(
|
||||
BaseController::default(),
|
||||
GetVpnPortalInfoRequest::default(),
|
||||
)
|
||||
.await?
|
||||
.vpn_portal_info
|
||||
.unwrap_or_default();
|
||||
|
@ -583,7 +595,7 @@ async fn main() -> Result<(), Error> {
|
|||
SubCommand::Node(sub_cmd) => {
|
||||
let client = handler.get_peer_manager_client().await?;
|
||||
let node_info = client
|
||||
.show_node_info(BaseController {}, ShowNodeInfoRequest::default())
|
||||
.show_node_info(BaseController::default(), ShowNodeInfoRequest::default())
|
||||
.await?
|
||||
.node_info
|
||||
.ok_or(anyhow::anyhow!("node info not found"))?;
|
||||
|
|
|
@ -161,7 +161,7 @@ impl Instance {
|
|||
DirectConnectorManager::new(global_ctx.clone(), peer_manager.clone());
|
||||
direct_conn_manager.run();
|
||||
|
||||
let udp_hole_puncher = UdpHolePunchConnector::new(global_ctx.clone(), peer_manager.clone());
|
||||
let udp_hole_puncher = UdpHolePunchConnector::new(peer_manager.clone());
|
||||
|
||||
let peer_center = Arc::new(PeerCenterInstance::new(peer_manager.clone()));
|
||||
|
||||
|
|
|
@ -230,7 +230,7 @@ impl PeerCenterInstance {
|
|||
|
||||
let ret = client
|
||||
.get_global_peer_map(
|
||||
BaseController {},
|
||||
BaseController::default(),
|
||||
GetGlobalPeerMapRequest {
|
||||
digest: ctx.job_ctx.global_peer_map_digest.load(),
|
||||
},
|
||||
|
@ -307,7 +307,7 @@ impl PeerCenterInstance {
|
|||
|
||||
let ret = client
|
||||
.report_peers(
|
||||
BaseController {},
|
||||
BaseController::default(),
|
||||
ReportPeersRequest {
|
||||
my_peer_id: my_node_id,
|
||||
peer_infos: Some(peers),
|
||||
|
|
|
@ -15,6 +15,8 @@ pub mod foreign_network_manager;
|
|||
|
||||
pub mod encrypt;
|
||||
|
||||
pub mod peer_task;
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests;
|
||||
|
||||
|
|
|
@ -1058,7 +1058,7 @@ mod tests {
|
|||
|
||||
let ret = stub
|
||||
.say_hello(
|
||||
RpcController {},
|
||||
RpcController::default(),
|
||||
SayHelloRequest {
|
||||
name: "abc".to_string(),
|
||||
},
|
||||
|
|
|
@ -539,7 +539,7 @@ impl RouteTable {
|
|||
fn get_nat_type(&self, peer_id: PeerId) -> Option<NatType> {
|
||||
self.peer_infos
|
||||
.get(&peer_id)
|
||||
.map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap())
|
||||
.map(|x| NatType::try_from(x.udp_stun_info as i32).unwrap_or_default())
|
||||
}
|
||||
|
||||
fn build_peer_graph_from_synced_info<T: RouteCostCalculatorInterface>(
|
||||
|
@ -1322,7 +1322,7 @@ impl PeerRouteServiceImpl {
|
|||
self.global_ctx.get_network_name(),
|
||||
);
|
||||
|
||||
let mut ctrl = BaseController {};
|
||||
let mut ctrl = BaseController::default();
|
||||
ctrl.set_timeout_ms(3000);
|
||||
let ret = rpc_stub
|
||||
.sync_route_info(
|
||||
|
|
|
@ -224,7 +224,10 @@ pub mod tests {
|
|||
|
||||
let msg = random_string(8192);
|
||||
let ret = stub
|
||||
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
|
||||
.say_hello(
|
||||
RpcController::default(),
|
||||
SayHelloRequest { name: msg.clone() },
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -233,7 +236,10 @@ pub mod tests {
|
|||
|
||||
let msg = random_string(10);
|
||||
let ret = stub
|
||||
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
|
||||
.say_hello(
|
||||
RpcController::default(),
|
||||
SayHelloRequest { name: msg.clone() },
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
@ -281,7 +287,10 @@ pub mod tests {
|
|||
);
|
||||
|
||||
let ret = stub
|
||||
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
|
||||
.say_hello(
|
||||
RpcController::default(),
|
||||
SayHelloRequest { name: msg.clone() },
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(ret.greeting, format!("Hello {}!", msg));
|
||||
|
@ -289,14 +298,20 @@ pub mod tests {
|
|||
// call again
|
||||
let msg = random_string(16 * 1024);
|
||||
let ret = stub
|
||||
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
|
||||
.say_hello(
|
||||
RpcController::default(),
|
||||
SayHelloRequest { name: msg.clone() },
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(ret.greeting, format!("Hello {}!", msg));
|
||||
|
||||
let msg = random_string(16 * 1024);
|
||||
let ret = stub
|
||||
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
|
||||
.say_hello(
|
||||
RpcController::default(),
|
||||
SayHelloRequest { name: msg.clone() },
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(ret.greeting, format!("Hello {}!", msg));
|
||||
|
@ -340,13 +355,19 @@ pub mod tests {
|
|||
|
||||
let msg = random_string(16 * 1024);
|
||||
let ret = stub1
|
||||
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
|
||||
.say_hello(
|
||||
RpcController::default(),
|
||||
SayHelloRequest { name: msg.clone() },
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(ret.greeting, format!("Hello {}!", msg));
|
||||
|
||||
let ret = stub2
|
||||
.say_hello(RpcController {}, SayHelloRequest { name: msg.clone() })
|
||||
.say_hello(
|
||||
RpcController::default(),
|
||||
SayHelloRequest { name: msg.clone() },
|
||||
)
|
||||
.await;
|
||||
assert!(ret.is_err() && ret.unwrap_err().to_string().contains("Timeout"));
|
||||
}
|
||||
|
|
138
easytier/src/peers/peer_task.rs
Normal file
138
easytier/src/peers/peer_task.rs
Normal file
|
@ -0,0 +1,138 @@
|
|||
use std::result::Result;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use dashmap::DashMap;
|
||||
use tokio::select;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use crate::common::scoped_task::ScopedTask;
|
||||
use anyhow::Error;
|
||||
|
||||
use super::peer_manager::PeerManager;
|
||||
|
||||
#[async_trait]
|
||||
pub trait PeerTaskLauncher: Send + Sync + Clone + 'static {
|
||||
type Data;
|
||||
type CollectPeerItem;
|
||||
type TaskRet;
|
||||
|
||||
fn new_data(&self, peer_mgr: Arc<PeerManager>) -> Self::Data;
|
||||
async fn collect_peers_need_task(&self, data: &Self::Data) -> Vec<Self::CollectPeerItem>;
|
||||
async fn launch_task(
|
||||
&self,
|
||||
data: &Self::Data,
|
||||
item: Self::CollectPeerItem,
|
||||
) -> JoinHandle<Result<Self::TaskRet, Error>>;
|
||||
|
||||
async fn all_task_done(&self, _data: &Self::Data) {}
|
||||
|
||||
fn loop_interval_ms(&self) -> u64 {
|
||||
5000
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PeerTaskManager<Launcher: PeerTaskLauncher> {
|
||||
launcher: Launcher,
|
||||
peer_mgr: Arc<PeerManager>,
|
||||
main_loop_task: Mutex<Option<ScopedTask<()>>>,
|
||||
run_signal: Arc<Notify>,
|
||||
data: Launcher::Data,
|
||||
}
|
||||
|
||||
impl<D, C, T, L> PeerTaskManager<L>
|
||||
where
|
||||
D: Send + Sync + Clone + 'static,
|
||||
C: std::fmt::Debug + Send + Sync + Clone + core::hash::Hash + Eq + 'static,
|
||||
T: Send + 'static,
|
||||
L: PeerTaskLauncher<Data = D, CollectPeerItem = C, TaskRet = T> + 'static,
|
||||
{
|
||||
pub fn new(launcher: L, peer_mgr: Arc<PeerManager>) -> Self {
|
||||
let data = launcher.new_data(peer_mgr.clone());
|
||||
Self {
|
||||
launcher,
|
||||
peer_mgr,
|
||||
main_loop_task: Mutex::new(None),
|
||||
run_signal: Arc::new(Notify::new()),
|
||||
data,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(&self) {
|
||||
let task = tokio::spawn(Self::main_loop(
|
||||
self.launcher.clone(),
|
||||
self.data.clone(),
|
||||
self.run_signal.clone(),
|
||||
))
|
||||
.into();
|
||||
self.main_loop_task.lock().unwrap().replace(task);
|
||||
}
|
||||
|
||||
async fn main_loop(launcher: L, data: D, signal: Arc<Notify>) {
|
||||
let peer_task_map = Arc::new(DashMap::<C, ScopedTask<Result<T, Error>>>::new());
|
||||
|
||||
loop {
|
||||
let peers_to_connect = launcher.collect_peers_need_task(&data).await;
|
||||
|
||||
// remove task not in peers_to_connect
|
||||
let mut to_remove = vec![];
|
||||
for item in peer_task_map.iter() {
|
||||
if !peers_to_connect.contains(item.key()) || item.value().is_finished() {
|
||||
to_remove.push(item.key().clone());
|
||||
}
|
||||
}
|
||||
|
||||
tracing::debug!(
|
||||
?peers_to_connect,
|
||||
?to_remove,
|
||||
"got peers to connect and remove"
|
||||
);
|
||||
|
||||
for key in to_remove {
|
||||
if let Some((_, task)) = peer_task_map.remove(&key) {
|
||||
task.abort();
|
||||
match task.await {
|
||||
Ok(Ok(_)) => {}
|
||||
Ok(Err(task_ret)) => {
|
||||
tracing::error!(?task_ret, "hole punching task failed");
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!(?e, "hole punching task aborted");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !peers_to_connect.is_empty() {
|
||||
for item in peers_to_connect {
|
||||
if peer_task_map.contains_key(&item) {
|
||||
continue;
|
||||
}
|
||||
|
||||
tracing::debug!(?item, "launch hole punching task");
|
||||
peer_task_map
|
||||
.insert(item.clone(), launcher.launch_task(&data, item).await.into());
|
||||
}
|
||||
} else if peer_task_map.is_empty() {
|
||||
tracing::debug!("all task done");
|
||||
launcher.all_task_done(&data).await;
|
||||
}
|
||||
|
||||
select! {
|
||||
_ = tokio::time::sleep(std::time::Duration::from_millis(
|
||||
launcher.loop_interval_ms(),
|
||||
)) => {},
|
||||
_ = signal.notified() => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_immediately(&self) {
|
||||
self.run_signal.notify_one();
|
||||
}
|
||||
|
||||
pub fn data(&self) -> D {
|
||||
self.data.clone()
|
||||
}
|
||||
}
|
|
@ -42,6 +42,8 @@ message RpcPacket {
|
|||
int32 trace_id = 9;
|
||||
}
|
||||
|
||||
message Void {}
|
||||
|
||||
message UUID {
|
||||
uint64 high = 1;
|
||||
uint64 low = 2;
|
||||
|
@ -57,6 +59,8 @@ enum NatType {
|
|||
PortRestricted = 5;
|
||||
Symmetric = 6;
|
||||
SymUdpFirewall = 7;
|
||||
SymmetricEasyInc = 8;
|
||||
SymmetricEasyDec = 9;
|
||||
}
|
||||
|
||||
message Ipv4Addr { uint32 addr = 1; }
|
||||
|
|
|
@ -93,27 +93,78 @@ service DirectConnectorRpc {
|
|||
rpc GetIpList(GetIpListRequest) returns (GetIpListResponse);
|
||||
}
|
||||
|
||||
message TryPunchHoleRequest { common.SocketAddr local_mapped_addr = 1; }
|
||||
|
||||
message TryPunchHoleResponse { common.SocketAddr remote_mapped_addr = 1; }
|
||||
|
||||
message TryPunchSymmetricRequest {
|
||||
common.SocketAddr listener_addr = 1;
|
||||
uint32 port = 2;
|
||||
repeated common.Ipv4Addr public_ips = 3;
|
||||
uint32 min_port = 4;
|
||||
uint32 max_port = 5;
|
||||
uint32 transaction_id = 6;
|
||||
uint32 round = 7;
|
||||
uint32 last_port_index = 8;
|
||||
message SelectPunchListenerRequest {
|
||||
bool force_new = 1;
|
||||
}
|
||||
|
||||
message TryPunchSymmetricResponse { uint32 last_port_index = 1; }
|
||||
message SelectPunchListenerResponse {
|
||||
common.SocketAddr listener_mapped_addr = 1;
|
||||
}
|
||||
|
||||
message SendPunchPacketConeRequest {
|
||||
common.SocketAddr listener_mapped_addr = 1;
|
||||
common.SocketAddr dest_addr = 2;
|
||||
uint32 transaction_id = 3;
|
||||
// send this many packets in a batch
|
||||
uint32 packet_count_per_batch = 4;
|
||||
// send total this batch count, total packet count = packet_batch_size * packet_batch_count
|
||||
uint32 packet_batch_count = 5;
|
||||
// interval between each batch
|
||||
uint32 packet_interval_ms = 6;
|
||||
}
|
||||
|
||||
message SendPunchPacketHardSymRequest {
|
||||
common.SocketAddr listener_mapped_addr = 1;
|
||||
|
||||
repeated common.Ipv4Addr public_ips = 2;
|
||||
uint32 transaction_id = 3;
|
||||
uint32 port_index = 4;
|
||||
uint32 round = 5;
|
||||
}
|
||||
|
||||
message SendPunchPacketHardSymResponse { uint32 next_port_index = 1; }
|
||||
|
||||
message SendPunchPacketEasySymRequest {
|
||||
common.SocketAddr listener_mapped_addr = 1;
|
||||
repeated common.Ipv4Addr public_ips = 2;
|
||||
uint32 transaction_id = 3;
|
||||
|
||||
uint32 base_port_num = 4;
|
||||
uint32 max_port_num = 5;
|
||||
bool is_incremental = 6;
|
||||
}
|
||||
|
||||
message SendPunchPacketBothEasySymRequest {
|
||||
uint32 udp_socket_count = 1;
|
||||
common.Ipv4Addr public_ip = 2;
|
||||
uint32 transaction_id = 3;
|
||||
|
||||
uint32 dst_port_num = 4;
|
||||
uint32 wait_time_ms = 5;
|
||||
}
|
||||
|
||||
message SendPunchPacketBothEasySymResponse {
|
||||
// is doing punch with other peer
|
||||
bool is_busy = 1;
|
||||
common.SocketAddr base_mapped_addr = 2;
|
||||
}
|
||||
|
||||
service UdpHolePunchRpc {
|
||||
rpc TryPunchHole(TryPunchHoleRequest) returns (TryPunchHoleResponse);
|
||||
rpc TryPunchSymmetric(TryPunchSymmetricRequest)
|
||||
returns (TryPunchSymmetricResponse);
|
||||
rpc SelectPunchListener(SelectPunchListenerRequest)
|
||||
returns (SelectPunchListenerResponse);
|
||||
|
||||
// send packet to one remote_addr, used by nat1-3 to nat1-3
|
||||
rpc SendPunchPacketCone(SendPunchPacketConeRequest) returns (common.Void);
|
||||
|
||||
// send packet to multiple remote_addr (birthday attack), used by nat4 to nat1-3
|
||||
rpc SendPunchPacketHardSym(SendPunchPacketHardSymRequest)
|
||||
returns (SendPunchPacketHardSymResponse);
|
||||
rpc SendPunchPacketEasySym(SendPunchPacketEasySymRequest)
|
||||
returns (common.Void);
|
||||
|
||||
// nat4 to nat4 (both predictably)
|
||||
rpc SendPunchPacketBothEasySym(SendPunchPacketBothEasySymRequest)
|
||||
returns (SendPunchPacketBothEasySymResponse);
|
||||
}
|
||||
|
||||
message DirectConnectedPeerInfo { int32 latency_ms = 1; }
|
||||
|
|
|
@ -146,7 +146,7 @@ impl Server {
|
|||
async fn handle_rpc_request(packet: RpcPacket, reg: Arc<ServiceRegistry>) -> Result<Bytes> {
|
||||
let rpc_request = RpcRequest::decode(Bytes::from(packet.body))?;
|
||||
let timeout_duration = std::time::Duration::from_millis(rpc_request.timeout_ms as u64);
|
||||
let ctrl = RpcController {};
|
||||
let ctrl = RpcController::default();
|
||||
Ok(timeout(
|
||||
timeout_duration,
|
||||
reg.call_method(
|
||||
|
|
|
@ -13,6 +13,34 @@ pub trait Controller: Send + Sync + 'static {
|
|||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BaseController {}
|
||||
pub struct BaseController {
|
||||
pub timeout_ms: i32,
|
||||
pub trace_id: i32,
|
||||
}
|
||||
|
||||
impl Controller for BaseController {}
|
||||
impl Controller for BaseController {
|
||||
fn timeout_ms(&self) -> i32 {
|
||||
self.timeout_ms
|
||||
}
|
||||
|
||||
fn set_timeout_ms(&mut self, timeout_ms: i32) {
|
||||
self.timeout_ms = timeout_ms;
|
||||
}
|
||||
|
||||
fn set_trace_id(&mut self, trace_id: i32) {
|
||||
self.trace_id = trace_id;
|
||||
}
|
||||
|
||||
fn trace_id(&self) -> i32 {
|
||||
self.trace_id
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BaseController {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
timeout_ms: 5000,
|
||||
trace_id: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -121,14 +121,14 @@ async fn rpc_basic_test() {
|
|||
|
||||
// small size req and resp
|
||||
|
||||
let ctrl = RpcController {};
|
||||
let ctrl = RpcController::default();
|
||||
let input = SayHelloRequest {
|
||||
name: "world".to_string(),
|
||||
};
|
||||
let ret = out.say_hello(ctrl, input).await;
|
||||
assert_eq!(ret.unwrap().greeting, "Hello world!");
|
||||
|
||||
let ctrl = RpcController {};
|
||||
let ctrl = RpcController::default();
|
||||
let input = SayGoodbyeRequest {
|
||||
name: "world".to_string(),
|
||||
};
|
||||
|
@ -136,7 +136,7 @@ async fn rpc_basic_test() {
|
|||
assert_eq!(ret.unwrap().greeting, "Goodbye, world!");
|
||||
|
||||
// large size req and resp
|
||||
let ctrl = RpcController {};
|
||||
let ctrl = RpcController::default();
|
||||
let name = random_string(20 * 1024 * 1024);
|
||||
let input = SayGoodbyeRequest { name: name.clone() };
|
||||
let ret = out.say_goodbye(ctrl, input).await;
|
||||
|
@ -160,7 +160,7 @@ async fn rpc_timeout_test() {
|
|||
.client
|
||||
.scoped_client::<GreetingClientFactory<RpcController>>(1, 1, "test".to_string());
|
||||
|
||||
let ctrl = RpcController {};
|
||||
let ctrl = RpcController::default();
|
||||
let input = SayHelloRequest {
|
||||
name: "world".to_string(),
|
||||
};
|
||||
|
@ -199,7 +199,7 @@ async fn standalone_rpc_test() {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let ctrl = RpcController {};
|
||||
let ctrl = RpcController::default();
|
||||
let input = SayHelloRequest {
|
||||
name: "world".to_string(),
|
||||
};
|
||||
|
@ -211,7 +211,7 @@ async fn standalone_rpc_test() {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let ctrl = RpcController {};
|
||||
let ctrl = RpcController::default();
|
||||
let input = SayGoodbyeRequest {
|
||||
name: "world".to_string(),
|
||||
};
|
||||
|
|
|
@ -94,7 +94,7 @@ pub trait Tunnel: Send {
|
|||
|
||||
#[auto_impl::auto_impl(Arc)]
|
||||
pub trait TunnelConnCounter: 'static + Send + Sync + Debug {
|
||||
fn get(&self) -> u32;
|
||||
fn get(&self) -> Option<u32>;
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||
|
@ -114,8 +114,8 @@ pub trait TunnelListener: Send {
|
|||
#[derive(Debug)]
|
||||
struct FakeTunnelConnCounter {}
|
||||
impl TunnelConnCounter for FakeTunnelConnCounter {
|
||||
fn get(&self) -> u32 {
|
||||
0
|
||||
fn get(&self) -> Option<u32> {
|
||||
None
|
||||
}
|
||||
}
|
||||
Arc::new(Box::new(FakeTunnelConnCounter {}))
|
||||
|
|
|
@ -43,6 +43,10 @@ impl TunnelListener for TcpTunnelListener {
|
|||
setup_sokcet2(&socket2_socket, &addr)?;
|
||||
let socket = TcpSocket::from_std_stream(socket2_socket.into());
|
||||
|
||||
if let Err(e) = socket.set_nodelay(true) {
|
||||
tracing::warn!(?e, "set_nodelay fail in listen");
|
||||
}
|
||||
|
||||
self.addr
|
||||
.set_port(Some(socket.local_addr()?.port()))
|
||||
.unwrap();
|
||||
|
@ -54,7 +58,11 @@ impl TunnelListener for TcpTunnelListener {
|
|||
async fn accept(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
let listener = self.listener.as_ref().unwrap();
|
||||
let (stream, _) = listener.accept().await?;
|
||||
stream.set_nodelay(true).unwrap();
|
||||
|
||||
if let Err(e) = stream.set_nodelay(true) {
|
||||
tracing::warn!(?e, "set_nodelay fail in accept");
|
||||
}
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
local_addr: Some(self.local_url().into()),
|
||||
|
@ -80,7 +88,9 @@ fn get_tunnel_with_tcp_stream(
|
|||
stream: TcpStream,
|
||||
remote_url: url::Url,
|
||||
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
|
||||
stream.set_nodelay(true).unwrap();
|
||||
if let Err(e) = stream.set_nodelay(true) {
|
||||
tracing::warn!(?e, "set_nodelay fail in get_tunnel_with_tcp_stream");
|
||||
}
|
||||
|
||||
let info = TunnelInfo {
|
||||
tunnel_type: "tcp".to_owned(),
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use std::{fmt::Debug, sync::Arc};
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
sync::{Arc, Weak},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use bytes::BytesMut;
|
||||
|
@ -445,25 +448,25 @@ impl TunnelListener for UdpTunnelListener {
|
|||
|
||||
fn get_conn_counter(&self) -> Arc<Box<dyn TunnelConnCounter>> {
|
||||
struct UdpTunnelConnCounter {
|
||||
sock_map: Arc<DashMap<SocketAddr, UdpConnection>>,
|
||||
sock_map: Weak<DashMap<SocketAddr, UdpConnection>>,
|
||||
}
|
||||
|
||||
impl TunnelConnCounter for UdpTunnelConnCounter {
|
||||
fn get(&self) -> Option<u32> {
|
||||
self.sock_map.upgrade().map(|x| x.len() as u32)
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for UdpTunnelConnCounter {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("UdpTunnelConnCounter")
|
||||
.field("sock_map_len", &self.sock_map.len())
|
||||
.field("sock_map_len", &self.get())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl TunnelConnCounter for UdpTunnelConnCounter {
|
||||
fn get(&self) -> u32 {
|
||||
self.sock_map.len() as u32
|
||||
}
|
||||
}
|
||||
|
||||
Arc::new(Box::new(UdpTunnelConnCounter {
|
||||
sock_map: self.data.sock_map.clone(),
|
||||
sock_map: Arc::downgrade(&self.data.sock_map.clone()),
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
@ -942,14 +945,22 @@ mod tests {
|
|||
|
||||
listener.listen().await.unwrap();
|
||||
let c1 = listener.accept().await.unwrap();
|
||||
assert_eq!(conn_counter.get(), 1);
|
||||
assert_eq!(conn_counter.get(), Some(1));
|
||||
let c2 = listener.accept().await.unwrap();
|
||||
assert_eq!(conn_counter.get(), 2);
|
||||
assert_eq!(conn_counter.get(), Some(2));
|
||||
|
||||
drop(c2);
|
||||
wait_for_condition(|| async { conn_counter.get() == 1 }, Duration::from_secs(1)).await;
|
||||
wait_for_condition(
|
||||
|| async { conn_counter.get() == Some(1) },
|
||||
Duration::from_secs(1),
|
||||
)
|
||||
.await;
|
||||
|
||||
drop(c1);
|
||||
wait_for_condition(|| async { conn_counter.get() == 0 }, Duration::from_secs(1)).await;
|
||||
wait_for_condition(
|
||||
|| async { conn_counter.get().unwrap_or(0) == 0 },
|
||||
Duration::from_secs(1),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue
Block a user