mirror of
https://github.com/EasyTier/EasyTier.git
synced 2024-11-16 03:32:43 +08:00
fix ring tunnel cannot close (#51)
This commit is contained in:
parent
727ef37ae4
commit
50e14798d6
|
@ -42,13 +42,22 @@ pub type PeerConnId = uuid::Uuid;
|
||||||
|
|
||||||
macro_rules! wait_response {
|
macro_rules! wait_response {
|
||||||
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
|
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
|
||||||
let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await;
|
let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else {
|
||||||
if rsp_vec.is_err() {
|
|
||||||
return Err(TunnelError::WaitRespError(
|
return Err(TunnelError::WaitRespError(
|
||||||
"wait handshake response timeout".to_owned(),
|
"wait handshake response timeout".to_owned(),
|
||||||
));
|
));
|
||||||
}
|
};
|
||||||
let rsp_vec = rsp_vec.unwrap().unwrap()?;
|
let Some(rsp_vec) = rsp_vec else {
|
||||||
|
return Err(TunnelError::WaitRespError(
|
||||||
|
"wait handshake response get none".to_owned(),
|
||||||
|
));
|
||||||
|
};
|
||||||
|
let Ok(rsp_vec) = rsp_vec else {
|
||||||
|
return Err(TunnelError::WaitRespError(format!(
|
||||||
|
"wait handshake response get error {}",
|
||||||
|
rsp_vec.err().unwrap()
|
||||||
|
)));
|
||||||
|
};
|
||||||
|
|
||||||
let $out_var;
|
let $out_var;
|
||||||
let rsp_bytes = Packet::decode(&rsp_vec);
|
let rsp_bytes = Packet::decode(&rsp_vec);
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
use std::{
|
use std::{
|
||||||
collections::HashMap,
|
collections::HashMap,
|
||||||
sync::{atomic::AtomicBool, Arc},
|
sync::{
|
||||||
|
atomic::{AtomicBool, AtomicU32},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
task::Poll,
|
task::Poll,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -22,28 +25,55 @@ use uuid::Uuid;
|
||||||
use crate::tunnels::{SinkError, SinkItem};
|
use crate::tunnels::{SinkError, SinkItem};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream,
|
build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::FramedTunnel,
|
||||||
Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
|
DatagramSink, DatagramStream, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
|
||||||
};
|
};
|
||||||
|
|
||||||
static RING_TUNNEL_CAP: usize = 1000;
|
static RING_TUNNEL_CAP: usize = 1000;
|
||||||
|
|
||||||
|
struct Ring {
|
||||||
|
id: Uuid,
|
||||||
|
ring: ArrayQueue<SinkItem>,
|
||||||
|
consume_notify: Notify,
|
||||||
|
produce_notify: Notify,
|
||||||
|
closed: AtomicBool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Ring {
|
||||||
|
fn new(cap: usize, id: uuid::Uuid) -> Self {
|
||||||
|
Self {
|
||||||
|
id,
|
||||||
|
ring: ArrayQueue::new(cap),
|
||||||
|
consume_notify: Notify::new(),
|
||||||
|
produce_notify: Notify::new(),
|
||||||
|
closed: AtomicBool::new(false),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn close(&self) {
|
||||||
|
self.closed
|
||||||
|
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
self.produce_notify.notify_one();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn closed(&self) -> bool {
|
||||||
|
self.closed.load(std::sync::atomic::Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub struct RingTunnel {
|
pub struct RingTunnel {
|
||||||
id: Uuid,
|
id: Uuid,
|
||||||
ring: Arc<ArrayQueue<SinkItem>>,
|
ring: Arc<Ring>,
|
||||||
consume_notify: Arc<Notify>,
|
sender_counter: Arc<AtomicU32>,
|
||||||
produce_notify: Arc<Notify>,
|
|
||||||
closed: Arc<AtomicBool>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RingTunnel {
|
impl RingTunnel {
|
||||||
pub fn new(cap: usize) -> Self {
|
pub fn new(cap: usize) -> Self {
|
||||||
|
let id = Uuid::new_v4();
|
||||||
RingTunnel {
|
RingTunnel {
|
||||||
id: Uuid::new_v4(),
|
id: id.clone(),
|
||||||
ring: Arc::new(ArrayQueue::new(cap)),
|
ring: Arc::new(Ring::new(cap, id)),
|
||||||
consume_notify: Arc::new(Notify::new()),
|
sender_counter: Arc::new(AtomicU32::new(1)),
|
||||||
produce_notify: Arc::new(Notify::new()),
|
|
||||||
closed: Arc::new(AtomicBool::new(false)),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -55,27 +85,24 @@ impl RingTunnel {
|
||||||
|
|
||||||
fn recv_stream(&self) -> impl DatagramStream {
|
fn recv_stream(&self) -> impl DatagramStream {
|
||||||
let ring = self.ring.clone();
|
let ring = self.ring.clone();
|
||||||
let produce_notify = self.produce_notify.clone();
|
|
||||||
let consume_notify = self.consume_notify.clone();
|
|
||||||
let closed = self.closed.clone();
|
|
||||||
let id = self.id;
|
let id = self.id;
|
||||||
stream! {
|
stream! {
|
||||||
loop {
|
loop {
|
||||||
if closed.load(std::sync::atomic::Ordering::Relaxed) {
|
match ring.ring.pop() {
|
||||||
log::warn!("ring recv tunnel {:?} closed", id);
|
|
||||||
yield Err(TunnelError::CommonError("Closed".to_owned()));
|
|
||||||
}
|
|
||||||
match ring.pop() {
|
|
||||||
Some(v) => {
|
Some(v) => {
|
||||||
let mut out = BytesMut::new();
|
let mut out = BytesMut::new();
|
||||||
out.extend_from_slice(&v);
|
out.extend_from_slice(&v);
|
||||||
consume_notify.notify_one();
|
ring.consume_notify.notify_one();
|
||||||
log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
|
log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
|
||||||
yield Ok(out);
|
yield Ok(out);
|
||||||
},
|
},
|
||||||
None => {
|
None => {
|
||||||
|
if ring.closed() {
|
||||||
|
log::warn!("ring recv tunnel {:?} closed", id);
|
||||||
|
yield Err(TunnelError::CommonError("ring closed".to_owned()));
|
||||||
|
}
|
||||||
log::trace!("waiting recv buffer, id: {}", id);
|
log::trace!("waiting recv buffer, id: {}", id);
|
||||||
produce_notify.notified().await;
|
ring.produce_notify.notified().await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -84,18 +111,13 @@ impl RingTunnel {
|
||||||
|
|
||||||
fn send_sink(&self) -> impl DatagramSink {
|
fn send_sink(&self) -> impl DatagramSink {
|
||||||
let ring = self.ring.clone();
|
let ring = self.ring.clone();
|
||||||
let produce_notify = self.produce_notify.clone();
|
let sender_counter = self.sender_counter.clone();
|
||||||
let consume_notify = self.consume_notify.clone();
|
|
||||||
let closed = self.closed.clone();
|
|
||||||
let id = self.id;
|
|
||||||
|
|
||||||
// type T = RingTunnel;
|
|
||||||
|
|
||||||
use tokio::task::JoinHandle;
|
use tokio::task::JoinHandle;
|
||||||
|
|
||||||
struct T {
|
struct T {
|
||||||
ring: RingTunnel,
|
ring: Arc<Ring>,
|
||||||
wait_consume_task: Option<JoinHandle<()>>,
|
wait_consume_task: Option<JoinHandle<()>>,
|
||||||
|
sender_counter: Arc<AtomicU32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl T {
|
impl T {
|
||||||
|
@ -110,16 +132,15 @@ impl RingTunnel {
|
||||||
}
|
}
|
||||||
if self_mut.wait_consume_task.is_none() {
|
if self_mut.wait_consume_task.is_none() {
|
||||||
let id = self_mut.ring.id;
|
let id = self_mut.ring.id;
|
||||||
let consume_notify = self_mut.ring.consume_notify.clone();
|
let ring = self_mut.ring.clone();
|
||||||
let ring = self_mut.ring.ring.clone();
|
|
||||||
let task = async move {
|
let task = async move {
|
||||||
log::trace!(
|
log::trace!(
|
||||||
"waiting ring consume done, expected_size: {}, id: {}",
|
"waiting ring consume done, expected_size: {}, id: {}",
|
||||||
expected_size,
|
expected_size,
|
||||||
id
|
id
|
||||||
);
|
);
|
||||||
while ring.len() > expected_size {
|
while ring.ring.len() > expected_size {
|
||||||
consume_notify.notified().await;
|
ring.consume_notify.notified().await;
|
||||||
}
|
}
|
||||||
log::trace!(
|
log::trace!(
|
||||||
"ring consume done, expected_size: {}, id: {}",
|
"ring consume done, expected_size: {}, id: {}",
|
||||||
|
@ -147,6 +168,12 @@ impl RingTunnel {
|
||||||
self: std::pin::Pin<&mut Self>,
|
self: std::pin::Pin<&mut Self>,
|
||||||
cx: &mut std::task::Context<'_>,
|
cx: &mut std::task::Context<'_>,
|
||||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||||
|
if self.ring.closed() {
|
||||||
|
return Poll::Ready(Err(TunnelError::CommonError(
|
||||||
|
"ring closed during ready".to_owned(),
|
||||||
|
)
|
||||||
|
.into()));
|
||||||
|
}
|
||||||
let expected_size = self.ring.ring.capacity() - 1;
|
let expected_size = self.ring.ring.capacity() - 1;
|
||||||
match self.wait_ring_consume(cx, expected_size) {
|
match self.wait_ring_consume(cx, expected_size) {
|
||||||
Poll::Ready(_) => Poll::Ready(Ok(())),
|
Poll::Ready(_) => Poll::Ready(Ok(())),
|
||||||
|
@ -158,6 +185,11 @@ impl RingTunnel {
|
||||||
self: std::pin::Pin<&mut Self>,
|
self: std::pin::Pin<&mut Self>,
|
||||||
item: SinkItem,
|
item: SinkItem,
|
||||||
) -> Result<(), Self::Error> {
|
) -> Result<(), Self::Error> {
|
||||||
|
if self.ring.closed() {
|
||||||
|
return Err(
|
||||||
|
TunnelError::CommonError("ring closed during send".to_owned()).into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
|
log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
|
||||||
self.ring.ring.push(item).unwrap();
|
self.ring.ring.push(item).unwrap();
|
||||||
self.ring.produce_notify.notify_one();
|
self.ring.produce_notify.notify_one();
|
||||||
|
@ -168,6 +200,12 @@ impl RingTunnel {
|
||||||
self: std::pin::Pin<&mut Self>,
|
self: std::pin::Pin<&mut Self>,
|
||||||
_cx: &mut std::task::Context<'_>,
|
_cx: &mut std::task::Context<'_>,
|
||||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||||
|
if self.ring.closed() {
|
||||||
|
return Poll::Ready(Err(TunnelError::CommonError(
|
||||||
|
"ring closed during flush".to_owned(),
|
||||||
|
)
|
||||||
|
.into()));
|
||||||
|
}
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -175,24 +213,38 @@ impl RingTunnel {
|
||||||
self: std::pin::Pin<&mut Self>,
|
self: std::pin::Pin<&mut Self>,
|
||||||
_cx: &mut std::task::Context<'_>,
|
_cx: &mut std::task::Context<'_>,
|
||||||
) -> std::task::Poll<Result<(), Self::Error>> {
|
) -> std::task::Poll<Result<(), Self::Error>> {
|
||||||
self.ring
|
self.ring.close();
|
||||||
.closed
|
|
||||||
.store(true, std::sync::atomic::Ordering::Relaxed);
|
|
||||||
log::warn!("ring tunnel send {:?} closed", self.ring.id);
|
|
||||||
self.ring.produce_notify.notify_one();
|
|
||||||
Poll::Ready(Ok(()))
|
Poll::Ready(Ok(()))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Drop for T {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let rem = self
|
||||||
|
.sender_counter
|
||||||
|
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
if rem == 1 {
|
||||||
|
self.ring.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sender_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
T {
|
T {
|
||||||
ring: RingTunnel {
|
ring,
|
||||||
id,
|
|
||||||
ring,
|
|
||||||
consume_notify,
|
|
||||||
produce_notify,
|
|
||||||
closed,
|
|
||||||
},
|
|
||||||
wait_consume_task: None,
|
wait_consume_task: None,
|
||||||
|
sender_counter,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for RingTunnel {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
let rem = self
|
||||||
|
.sender_counter
|
||||||
|
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
|
||||||
|
if rem == 1 {
|
||||||
|
self.ring.close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -213,11 +265,6 @@ impl Tunnel for RingTunnel {
|
||||||
|
|
||||||
fn info(&self) -> Option<TunnelInfo> {
|
fn info(&self) -> Option<TunnelInfo> {
|
||||||
None
|
None
|
||||||
// Some(TunnelInfo {
|
|
||||||
// tunnel_type: "ring".to_owned(),
|
|
||||||
// local_addr: format!("ring://{}", self.id),
|
|
||||||
// remote_addr: format!("ring://{}", self.id),
|
|
||||||
// })
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -241,48 +288,29 @@ impl RingTunnelListener {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
struct ConnectionForServer {
|
|
||||||
conn: Arc<Connection>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Tunnel for ConnectionForServer {
|
fn get_tunnel_for_client(conn: Arc<Connection>) -> Box<dyn Tunnel> {
|
||||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
FramedTunnel::new_tunnel_with_info(
|
||||||
Box::new(self.conn.server.recv_stream())
|
Box::pin(conn.client.recv_stream()),
|
||||||
}
|
conn.server.send_sink(),
|
||||||
|
TunnelInfo {
|
||||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
|
||||||
Box::new(self.conn.client.send_sink())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn info(&self) -> Option<TunnelInfo> {
|
|
||||||
Some(TunnelInfo {
|
|
||||||
tunnel_type: "ring".to_owned(),
|
tunnel_type: "ring".to_owned(),
|
||||||
local_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
|
local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
|
||||||
remote_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
|
remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
|
||||||
})
|
},
|
||||||
}
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ConnectionForClient {
|
fn get_tunnel_for_server(conn: Arc<Connection>) -> Box<dyn Tunnel> {
|
||||||
conn: Arc<Connection>,
|
FramedTunnel::new_tunnel_with_info(
|
||||||
}
|
Box::pin(conn.server.recv_stream()),
|
||||||
|
conn.client.send_sink(),
|
||||||
impl Tunnel for ConnectionForClient {
|
TunnelInfo {
|
||||||
fn stream(&self) -> Box<dyn DatagramStream> {
|
|
||||||
Box::new(self.conn.client.recv_stream())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sink(&self) -> Box<dyn DatagramSink> {
|
|
||||||
Box::new(self.conn.server.send_sink())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn info(&self) -> Option<TunnelInfo> {
|
|
||||||
Some(TunnelInfo {
|
|
||||||
tunnel_type: "ring".to_owned(),
|
tunnel_type: "ring".to_owned(),
|
||||||
local_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
|
local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
|
||||||
remote_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
|
remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
|
||||||
})
|
},
|
||||||
}
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RingTunnelListener {
|
impl RingTunnelListener {
|
||||||
|
@ -308,7 +336,7 @@ impl TunnelListener for RingTunnelListener {
|
||||||
if let Some(conn) = self.conn_receiver.recv().await {
|
if let Some(conn) = self.conn_receiver.recv().await {
|
||||||
if conn.server.id == my_addr {
|
if conn.server.id == my_addr {
|
||||||
log::info!("accept new conn of key: {}", self.listerner_addr);
|
log::info!("accept new conn of key: {}", self.listerner_addr);
|
||||||
return Ok(Box::new(ConnectionForServer { conn }));
|
return Ok(get_tunnel_for_server(conn));
|
||||||
} else {
|
} else {
|
||||||
tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
|
tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
|
||||||
return Err(TunnelError::CommonError(
|
return Err(TunnelError::CommonError(
|
||||||
|
@ -353,7 +381,7 @@ impl TunnelConnector for RingTunnelConnector {
|
||||||
entry
|
entry
|
||||||
.send(conn.clone())
|
.send(conn.clone())
|
||||||
.map_err(|_| TunnelError::CommonError("send conn to listner failed".to_owned()))?;
|
.map_err(|_| TunnelError::CommonError("send conn to listner failed".to_owned()))?;
|
||||||
Ok(Box::new(ConnectionForClient { conn }))
|
Ok(get_tunnel_for_client(conn))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remote_url(&self) -> url::Url {
|
fn remote_url(&self) -> url::Url {
|
||||||
|
@ -367,13 +395,15 @@ pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
|
||||||
server: RingTunnel::new(RING_TUNNEL_CAP),
|
server: RingTunnel::new(RING_TUNNEL_CAP),
|
||||||
});
|
});
|
||||||
(
|
(
|
||||||
Box::new(ConnectionForServer { conn: conn.clone() }),
|
Box::new(get_tunnel_for_server(conn.clone())),
|
||||||
Box::new(ConnectionForClient { conn }),
|
Box::new(get_tunnel_for_client(conn)),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
|
use futures::StreamExt;
|
||||||
|
|
||||||
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
@ -393,4 +423,14 @@ mod tests {
|
||||||
let connector = RingTunnelConnector::new(id);
|
let connector = RingTunnelConnector::new(id);
|
||||||
_tunnel_bench(listener, connector).await
|
_tunnel_bench(listener, connector).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn ring_close() {
|
||||||
|
let (stunnel, ctunnel) = create_ring_tunnel_pair();
|
||||||
|
drop(stunnel);
|
||||||
|
|
||||||
|
let mut stream = ctunnel.pin_stream();
|
||||||
|
let ret = stream.next().await;
|
||||||
|
assert!(ret.as_ref().unwrap().is_err(), "expect Err, got {:?}", ret);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,7 +4,7 @@ use std::{
|
||||||
hash::Hasher,
|
hash::Hasher,
|
||||||
net::SocketAddr,
|
net::SocketAddr,
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::Arc,
|
sync::{atomic::AtomicBool, Arc},
|
||||||
time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -120,6 +120,7 @@ struct WgPeerData {
|
||||||
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
|
||||||
stream: Arc<Mutex<Pin<Box<dyn DatagramStream>>>>,
|
stream: Arc<Mutex<Pin<Box<dyn DatagramStream>>>>,
|
||||||
wg_type: WgType,
|
wg_type: WgType,
|
||||||
|
stopped: Arc<AtomicBool>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Debug for WgPeerData {
|
impl Debug for WgPeerData {
|
||||||
|
@ -366,6 +367,8 @@ impl WgPeer {
|
||||||
tracing::error!("Failed to handle packet from me: {}", e);
|
tracing::error!("Failed to handle packet from me: {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
data.stopped
|
||||||
|
.store(true, std::sync::atomic::Ordering::Relaxed);
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
|
async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
|
||||||
|
@ -395,6 +398,7 @@ impl WgPeer {
|
||||||
sink: Arc::new(Mutex::new(stunnel.pin_sink())),
|
sink: Arc::new(Mutex::new(stunnel.pin_sink())),
|
||||||
stream: Arc::new(Mutex::new(stunnel.pin_stream())),
|
stream: Arc::new(Mutex::new(stunnel.pin_stream())),
|
||||||
wg_type: self.config.wg_type.clone(),
|
wg_type: self.config.wg_type.clone(),
|
||||||
|
stopped: Arc::new(AtomicBool::new(false)),
|
||||||
};
|
};
|
||||||
|
|
||||||
self.data = Some(data.clone());
|
self.data = Some(data.clone());
|
||||||
|
@ -403,6 +407,14 @@ impl WgPeer {
|
||||||
|
|
||||||
ctunnel
|
ctunnel
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn stopped(&self) -> bool {
|
||||||
|
self.data
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.stopped
|
||||||
|
.load(std::sync::atomic::Ordering::Relaxed)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for WgPeer {
|
impl Drop for WgPeer {
|
||||||
|
@ -427,6 +439,8 @@ pub struct WgTunnelListener {
|
||||||
conn_recv: ConnReceiver,
|
conn_recv: ConnReceiver,
|
||||||
conn_send: Option<ConnSender>,
|
conn_send: Option<ConnSender>,
|
||||||
|
|
||||||
|
wg_peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
|
||||||
|
|
||||||
tasks: JoinSet<()>,
|
tasks: JoinSet<()>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -441,6 +455,8 @@ impl WgTunnelListener {
|
||||||
conn_recv,
|
conn_recv,
|
||||||
conn_send: Some(conn_send),
|
conn_send: Some(conn_send),
|
||||||
|
|
||||||
|
wg_peer_map: Arc::new(DashMap::new()),
|
||||||
|
|
||||||
tasks: JoinSet::new(),
|
tasks: JoinSet::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -453,15 +469,16 @@ impl WgTunnelListener {
|
||||||
socket: Arc<UdpSocket>,
|
socket: Arc<UdpSocket>,
|
||||||
config: WgConfig,
|
config: WgConfig,
|
||||||
conn_sender: ConnSender,
|
conn_sender: ConnSender,
|
||||||
|
peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
|
||||||
) {
|
) {
|
||||||
let mut tasks = JoinSet::new();
|
let mut tasks = JoinSet::new();
|
||||||
let peer_map: Arc<DashMap<SocketAddr, WgPeer>> = Arc::new(DashMap::new());
|
|
||||||
|
|
||||||
let peer_map_clone = peer_map.clone();
|
let peer_map_clone = peer_map.clone();
|
||||||
tasks.spawn(async move {
|
tasks.spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
peer_map_clone.retain(|_, peer| peer.access_time.elapsed().as_secs() < 600);
|
peer_map_clone
|
||||||
tokio::time::sleep(Duration::from_secs(60)).await;
|
.retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped());
|
||||||
|
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
@ -524,6 +541,7 @@ impl TunnelListener for WgTunnelListener {
|
||||||
self.get_udp_socket(),
|
self.get_udp_socket(),
|
||||||
self.config.clone(),
|
self.config.clone(),
|
||||||
self.conn_send.take().unwrap(),
|
self.conn_send.take().unwrap(),
|
||||||
|
self.wg_peer_map.clone(),
|
||||||
));
|
));
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -788,4 +806,36 @@ pub mod tests {
|
||||||
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
|
||||||
_tunnel_pingpong(listener, connector).await
|
_tunnel_pingpong(listener, connector).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn wg_server_erase_from_map_after_close() {
|
||||||
|
let (server_cfg, client_cfg) = create_wg_config();
|
||||||
|
let mut listener =
|
||||||
|
WgTunnelListener::new("wg://127.0.0.1:5595".parse().unwrap(), server_cfg);
|
||||||
|
listener.listen().await.unwrap();
|
||||||
|
|
||||||
|
const CONN_COUNT: usize = 10;
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
for _ in 0..CONN_COUNT {
|
||||||
|
let mut connector = WgTunnelConnector::new(
|
||||||
|
"wg://127.0.0.1:5595".parse().unwrap(),
|
||||||
|
client_cfg.clone(),
|
||||||
|
);
|
||||||
|
let ret = connector.connect().await;
|
||||||
|
assert!(ret.is_ok());
|
||||||
|
drop(ret);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
for _ in 0..CONN_COUNT {
|
||||||
|
let conn = listener.accept().await;
|
||||||
|
assert!(conn.is_ok());
|
||||||
|
drop(conn);
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
|
||||||
|
|
||||||
|
assert_eq!(0, listener.wg_peer_map.len());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user