fix ring tunnel cannot close (#51)

commit 50e14798d6, parent 727ef37ae4
(mirror of https://github.com/EasyTier/EasyTier.git, synced 2024-11-16 03:32:43 +08:00)

@@ -42,13 +42,22 @@ pub type PeerConnId = uuid::Uuid;
 macro_rules! wait_response {
     ($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
-        let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await;
-        if rsp_vec.is_err() {
+        let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else {
             return Err(TunnelError::WaitRespError(
                 "wait handshake response timeout".to_owned(),
             ));
-        }
-        let rsp_vec = rsp_vec.unwrap().unwrap()?;
+        };
+        let Some(rsp_vec) = rsp_vec else {
+            return Err(TunnelError::WaitRespError(
+                "wait handshake response get none".to_owned(),
+            ));
+        };
+        let Ok(rsp_vec) = rsp_vec else {
+            return Err(TunnelError::WaitRespError(format!(
+                "wait handshake response get error {}",
+                rsp_vec.err().unwrap()
+            )));
+        };

         let $out_var;
         let rsp_bytes = Packet::decode(&rsp_vec);
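
A note on the macro change above: the single `unwrap().unwrap()?` is split into three `let ... else` layers, so a timeout, an ended stream, and a framed-read error each produce their own `WaitRespError`. Outside the macro, the same shape looks roughly like this sketch (`wait_handshake` and `WaitRespError` are stand-ins for illustration, not the crate's actual items):

    use std::time::Duration;

    use futures::{Stream, StreamExt};
    use tokio::time::timeout;

    // Stand-in for TunnelError::WaitRespError in this sketch.
    #[derive(Debug)]
    struct WaitRespError(String);

    // `stream` yields length-delimited frames as Result<Vec<u8>, std::io::Error>.
    async fn wait_handshake<S>(stream: &mut S) -> Result<Vec<u8>, WaitRespError>
    where
        S: Stream<Item = Result<Vec<u8>, std::io::Error>> + Unpin,
    {
        // Layer 1: bound the whole wait by one second.
        let Ok(next) = timeout(Duration::from_secs(1), stream.next()).await else {
            return Err(WaitRespError("wait handshake response timeout".to_owned()));
        };
        // Layer 2: the stream may end (None), e.g. because the peer closed the tunnel.
        let Some(frame) = next else {
            return Err(WaitRespError("wait handshake response get none".to_owned()));
        };
        // Layer 3: the frame itself may carry a read/decode error.
        let Ok(bytes) = frame else {
            return Err(WaitRespError(format!(
                "wait handshake response get error {}",
                frame.err().unwrap()
            )));
        };
        Ok(bytes)
    }
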
@@ -1,6 +1,9 @@
 use std::{
     collections::HashMap,
-    sync::{atomic::AtomicBool, Arc},
+    sync::{
+        atomic::{AtomicBool, AtomicU32},
+        Arc,
+    },
     task::Poll,
 };
@@ -22,28 +25,55 @@ use uuid::Uuid;
 use crate::tunnels::{SinkError, SinkItem};

 use super::{
-    build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream,
-    Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
+    build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::FramedTunnel,
+    DatagramSink, DatagramStream, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
 };

 static RING_TUNNEL_CAP: usize = 1000;

+struct Ring {
+    id: Uuid,
+    ring: ArrayQueue<SinkItem>,
+    consume_notify: Notify,
+    produce_notify: Notify,
+    closed: AtomicBool,
+}
+
+impl Ring {
+    fn new(cap: usize, id: uuid::Uuid) -> Self {
+        Self {
+            id,
+            ring: ArrayQueue::new(cap),
+            consume_notify: Notify::new(),
+            produce_notify: Notify::new(),
+            closed: AtomicBool::new(false),
+        }
+    }
+
+    fn close(&self) {
+        self.closed
+            .store(true, std::sync::atomic::Ordering::Relaxed);
+        self.produce_notify.notify_one();
+    }
+
+    fn closed(&self) -> bool {
+        self.closed.load(std::sync::atomic::Ordering::Relaxed)
+    }
+}
+
 pub struct RingTunnel {
     id: Uuid,
-    ring: Arc<ArrayQueue<SinkItem>>,
-    consume_notify: Arc<Notify>,
-    produce_notify: Arc<Notify>,
-    closed: Arc<AtomicBool>,
+    ring: Arc<Ring>,
+    sender_counter: Arc<AtomicU32>,
 }

 impl RingTunnel {
     pub fn new(cap: usize) -> Self {
+        let id = Uuid::new_v4();
         RingTunnel {
-            id: Uuid::new_v4(),
-            ring: Arc::new(ArrayQueue::new(cap)),
-            consume_notify: Arc::new(Notify::new()),
-            produce_notify: Arc::new(Notify::new()),
-            closed: Arc::new(AtomicBool::new(false)),
+            id: id.clone(),
+            ring: Arc::new(Ring::new(cap, id)),
+            sender_counter: Arc::new(AtomicU32::new(1)),
         }
     }
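
The new `Ring` type is the heart of the fix: the queue, both `Notify` handles, and the `closed` flag now live behind one `Arc`, so a close performed by the sending side is immediately visible to a receiver parked on `produce_notify`. A self-contained sketch of that mechanism, using illustrative names (`SharedRing` is not the crate's type) and assuming the `tokio` and `crossbeam` crates:

    use std::sync::{
        atomic::{AtomicBool, Ordering},
        Arc,
    };

    use crossbeam::queue::ArrayQueue;
    use tokio::sync::Notify;

    struct SharedRing {
        queue: ArrayQueue<Vec<u8>>,
        produce_notify: Notify,
        closed: AtomicBool,
    }

    impl SharedRing {
        fn close(&self) {
            self.closed.store(true, Ordering::Relaxed);
            // Wake a receiver that may be parked waiting for the next item.
            self.produce_notify.notify_one();
        }
    }

    #[tokio::main]
    async fn main() {
        let ring = Arc::new(SharedRing {
            queue: ArrayQueue::new(8),
            produce_notify: Notify::new(),
            closed: AtomicBool::new(false),
        });

        let rx = ring.clone();
        let receiver = tokio::spawn(async move {
            loop {
                match rx.queue.pop() {
                    Some(buf) => println!("received {} bytes", buf.len()),
                    // Only report "closed" once the queue has been drained.
                    None if rx.closed.load(Ordering::Relaxed) => break,
                    None => rx.produce_notify.notified().await,
                }
            }
        });

        ring.queue.push(b"hello".to_vec()).unwrap();
        ring.produce_notify.notify_one();
        ring.close(); // receiver drains the pending item, then exits
        receiver.await.unwrap();
    }
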
@@ -55,27 +85,24 @@ impl RingTunnel

     fn recv_stream(&self) -> impl DatagramStream {
         let ring = self.ring.clone();
-        let produce_notify = self.produce_notify.clone();
-        let consume_notify = self.consume_notify.clone();
-        let closed = self.closed.clone();
         let id = self.id;
         stream! {
             loop {
-                if closed.load(std::sync::atomic::Ordering::Relaxed) {
-                    log::warn!("ring recv tunnel {:?} closed", id);
-                    yield Err(TunnelError::CommonError("Closed".to_owned()));
-                }
-                match ring.pop() {
+                match ring.ring.pop() {
                     Some(v) => {
                         let mut out = BytesMut::new();
                         out.extend_from_slice(&v);
-                        consume_notify.notify_one();
+                        ring.consume_notify.notify_one();
                         log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
                         yield Ok(out);
                     },
                     None => {
+                        if ring.closed() {
+                            log::warn!("ring recv tunnel {:?} closed", id);
+                            yield Err(TunnelError::CommonError("ring closed".to_owned()));
+                        }
                         log::trace!("waiting recv buffer, id: {}", id);
-                        produce_notify.notified().await;
+                        ring.produce_notify.notified().await;
                     }
                 }
             }
@@ -84,18 +111,13 @@ impl RingTunnel

     fn send_sink(&self) -> impl DatagramSink {
         let ring = self.ring.clone();
-        let produce_notify = self.produce_notify.clone();
-        let consume_notify = self.consume_notify.clone();
-        let closed = self.closed.clone();
-        let id = self.id;
-
-        // type T = RingTunnel;
-
+        let sender_counter = self.sender_counter.clone();
         use tokio::task::JoinHandle;

         struct T {
-            ring: RingTunnel,
+            ring: Arc<Ring>,
             wait_consume_task: Option<JoinHandle<()>>,
+            sender_counter: Arc<AtomicU32>,
         }

         impl T {
@@ -110,16 +132,15 @@ impl RingTunnel
                }
                if self_mut.wait_consume_task.is_none() {
                    let id = self_mut.ring.id;
-                   let consume_notify = self_mut.ring.consume_notify.clone();
-                   let ring = self_mut.ring.ring.clone();
+                   let ring = self_mut.ring.clone();
                    let task = async move {
                        log::trace!(
                            "waiting ring consume done, expected_size: {}, id: {}",
                            expected_size,
                            id
                        );
-                       while ring.len() > expected_size {
-                           consume_notify.notified().await;
+                       while ring.ring.len() > expected_size {
+                           ring.consume_notify.notified().await;
                        }
                        log::trace!(
                            "ring consume done, expected_size: {}, id: {}",
@@ -147,6 +168,12 @@ impl RingTunnel
                self: std::pin::Pin<&mut Self>,
                cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
+               if self.ring.closed() {
+                   return Poll::Ready(Err(TunnelError::CommonError(
+                       "ring closed during ready".to_owned(),
+                   )
+                   .into()));
+               }
                let expected_size = self.ring.ring.capacity() - 1;
                match self.wait_ring_consume(cx, expected_size) {
                    Poll::Ready(_) => Poll::Ready(Ok(())),
@@ -158,6 +185,11 @@ impl RingTunnel
                self: std::pin::Pin<&mut Self>,
                item: SinkItem,
            ) -> Result<(), Self::Error> {
+               if self.ring.closed() {
+                   return Err(
+                       TunnelError::CommonError("ring closed during send".to_owned()).into(),
+                   );
+               }
                log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
                self.ring.ring.push(item).unwrap();
                self.ring.produce_notify.notify_one();
@@ -168,6 +200,12 @@ impl RingTunnel
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
+               if self.ring.closed() {
+                   return Poll::Ready(Err(TunnelError::CommonError(
+                       "ring closed during flush".to_owned(),
+                   )
+                   .into()));
+               }
                Poll::Ready(Ok(()))
            }

@@ -175,24 +213,38 @@ impl RingTunnel
                self: std::pin::Pin<&mut Self>,
                _cx: &mut std::task::Context<'_>,
            ) -> std::task::Poll<Result<(), Self::Error>> {
-               self.ring
-                   .closed
-                   .store(true, std::sync::atomic::Ordering::Relaxed);
-               log::warn!("ring tunnel send {:?} closed", self.ring.id);
-               self.ring.produce_notify.notify_one();
+               self.ring.close();
                Poll::Ready(Ok(()))
            }
        }

+       impl Drop for T {
+           fn drop(&mut self) {
+               let rem = self
+                   .sender_counter
+                   .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+               if rem == 1 {
+                   self.ring.close()
+               }
+           }
+       }
+
+       sender_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        T {
-           ring: RingTunnel {
-               id,
-               ring,
-               consume_notify,
-               produce_notify,
-               closed,
-           },
+           ring,
            wait_consume_task: None,
+           sender_counter,
        }
    }
 }

+impl Drop for RingTunnel {
+    fn drop(&mut self) {
+        let rem = self
+            .sender_counter
+            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
+        if rem == 1 {
+            self.ring.close()
+        }
+    }
+}
+
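
The two `Drop` impls above are what actually fix the "cannot close" problem: each sink handle and the `RingTunnel` itself share an `AtomicU32` sender counter, and whichever handle drops it to zero closes the ring. Since `fetch_sub` returns the value before the subtraction, seeing `1` means "this was the last sender". A stripped-down sketch of that counting, with illustrative names rather than the crate's types:

    use std::sync::{
        atomic::{AtomicBool, AtomicU32, Ordering},
        Arc,
    };

    #[derive(Default)]
    struct SharedState {
        closed: AtomicBool,
    }

    struct SenderHandle {
        state: Arc<SharedState>,
        sender_counter: Arc<AtomicU32>,
    }

    impl SenderHandle {
        // Creating another sender bumps the counter, mirroring the
        // `sender_counter.fetch_add(1, ...)` done in send_sink().
        fn duplicate(&self) -> Self {
            self.sender_counter.fetch_add(1, Ordering::Relaxed);
            SenderHandle {
                state: self.state.clone(),
                sender_counter: self.sender_counter.clone(),
            }
        }
    }

    impl Drop for SenderHandle {
        fn drop(&mut self) {
            // fetch_sub returns the value *before* the subtraction,
            // so 1 means "this was the last live sender".
            if self.sender_counter.fetch_sub(1, Ordering::Relaxed) == 1 {
                self.state.closed.store(true, Ordering::Relaxed);
            }
        }
    }

    fn main() {
        let first = SenderHandle {
            state: Arc::new(SharedState::default()),
            sender_counter: Arc::new(AtomicU32::new(1)),
        };
        let state = first.state.clone();
        let second = first.duplicate();

        drop(first);
        assert!(!state.closed.load(Ordering::Relaxed)); // one sender still alive

        drop(second);
        assert!(state.closed.load(Ordering::Relaxed)); // last drop closed the ring
    }
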
@@ -213,11 +265,6 @@ impl Tunnel for RingTunnel {

     fn info(&self) -> Option<TunnelInfo> {
         None
-        // Some(TunnelInfo {
-        //     tunnel_type: "ring".to_owned(),
-        //     local_addr: format!("ring://{}", self.id),
-        //     remote_addr: format!("ring://{}", self.id),
-        // })
     }
 }

@@ -241,48 +288,29 @@ impl RingTunnelListener
         }
     }
 }
+
-struct ConnectionForServer {
-    conn: Arc<Connection>,
-}
-
-impl Tunnel for ConnectionForServer {
-    fn stream(&self) -> Box<dyn DatagramStream> {
-        Box::new(self.conn.server.recv_stream())
-    }
-
-    fn sink(&self) -> Box<dyn DatagramSink> {
-        Box::new(self.conn.client.send_sink())
-    }
-
-    fn info(&self) -> Option<TunnelInfo> {
-        Some(TunnelInfo {
+fn get_tunnel_for_client(conn: Arc<Connection>) -> Box<dyn Tunnel> {
+    FramedTunnel::new_tunnel_with_info(
+        Box::pin(conn.client.recv_stream()),
+        conn.server.send_sink(),
+        TunnelInfo {
             tunnel_type: "ring".to_owned(),
-            local_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
-            remote_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
-        })
-    }
-}
+            local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
+            remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
+        },
+    )
+}

-struct ConnectionForClient {
-    conn: Arc<Connection>,
-}
-
-impl Tunnel for ConnectionForClient {
-    fn stream(&self) -> Box<dyn DatagramStream> {
-        Box::new(self.conn.client.recv_stream())
-    }
-
-    fn sink(&self) -> Box<dyn DatagramSink> {
-        Box::new(self.conn.server.send_sink())
-    }
-
-    fn info(&self) -> Option<TunnelInfo> {
-        Some(TunnelInfo {
+fn get_tunnel_for_server(conn: Arc<Connection>) -> Box<dyn Tunnel> {
+    FramedTunnel::new_tunnel_with_info(
+        Box::pin(conn.server.recv_stream()),
+        conn.client.send_sink(),
+        TunnelInfo {
             tunnel_type: "ring".to_owned(),
-            local_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
-            remote_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
-        })
-    }
-}
+            local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
+            remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
+        },
+    )
+}

 impl RingTunnelListener {
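
The two helpers above replace the hand-written `ConnectionForServer` / `ConnectionForClient` tunnels with `FramedTunnel::new_tunnel_with_info`, keeping the cross-wiring: the client side reads the client ring and writes into the server ring, and vice versa. A toy sketch of that pairing using plain channels (not the crate's `RingTunnel`/`FramedTunnel` types, just the wiring):

    use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};

    struct TunnelHalf {
        tx: UnboundedSender<Vec<u8>>,   // frames we send to the peer
        rx: UnboundedReceiver<Vec<u8>>, // frames the peer sent to us
    }

    fn create_pair() -> (TunnelHalf, TunnelHalf) {
        let (client_tx, client_rx) = unbounded_channel();
        let (server_tx, server_rx) = unbounded_channel();
        (
            // server half: reads what the client wrote, writes into the client's queue
            TunnelHalf { tx: client_tx, rx: server_rx },
            // client half: reads what the server wrote, writes into the server's queue
            TunnelHalf { tx: server_tx, rx: client_rx },
        )
    }

    #[tokio::main]
    async fn main() {
        let (mut server, mut client) = create_pair();
        client.tx.send(b"ping".to_vec()).unwrap();
        assert_eq!(server.rx.recv().await.unwrap(), b"ping");
        server.tx.send(b"pong".to_vec()).unwrap();
        assert_eq!(client.rx.recv().await.unwrap(), b"pong");
    }
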
@@ -308,7 +336,7 @@ impl TunnelListener for RingTunnelListener {
         if let Some(conn) = self.conn_receiver.recv().await {
             if conn.server.id == my_addr {
                 log::info!("accept new conn of key: {}", self.listerner_addr);
-                return Ok(Box::new(ConnectionForServer { conn }));
+                return Ok(get_tunnel_for_server(conn));
             } else {
                 tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
                 return Err(TunnelError::CommonError(
@@ -353,7 +381,7 @@ impl TunnelConnector for RingTunnelConnector {
         entry
             .send(conn.clone())
             .map_err(|_| TunnelError::CommonError("send conn to listner failed".to_owned()))?;
-        Ok(Box::new(ConnectionForClient { conn }))
+        Ok(get_tunnel_for_client(conn))
     }

     fn remote_url(&self) -> url::Url {
@@ -367,13 +395,15 @@ pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
         server: RingTunnel::new(RING_TUNNEL_CAP),
     });
     (
-        Box::new(ConnectionForServer { conn: conn.clone() }),
-        Box::new(ConnectionForClient { conn }),
+        Box::new(get_tunnel_for_server(conn.clone())),
+        Box::new(get_tunnel_for_client(conn)),
     )
 }

 #[cfg(test)]
 mod tests {
+    use futures::StreamExt;
+
     use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};

     use super::*;
@@ -393,4 +423,14 @@ mod tests {
         let connector = RingTunnelConnector::new(id);
         _tunnel_bench(listener, connector).await
     }
+
+    #[tokio::test]
+    async fn ring_close() {
+        let (stunnel, ctunnel) = create_ring_tunnel_pair();
+        drop(stunnel);
+
+        let mut stream = ctunnel.pin_stream();
+        let ret = stream.next().await;
+        assert!(ret.as_ref().unwrap().is_err(), "expect Err, got {:?}", ret);
+    }
 }
@@ -4,7 +4,7 @@ use std::{
     hash::Hasher,
     net::SocketAddr,
     pin::Pin,
-    sync::Arc,
+    sync::{atomic::AtomicBool, Arc},
     time::Duration,
 };

@@ -120,6 +120,7 @@ struct WgPeerData {
     sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
     stream: Arc<Mutex<Pin<Box<dyn DatagramStream>>>>,
     wg_type: WgType,
+    stopped: Arc<AtomicBool>,
 }

 impl Debug for WgPeerData {
@@ -366,6 +367,8 @@ impl WgPeer {
                tracing::error!("Failed to handle packet from me: {}", e);
            }
        }
+       data.stopped
+           .store(true, std::sync::atomic::Ordering::Relaxed);
    }

    async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
@@ -395,6 +398,7 @@ impl WgPeer {
            sink: Arc::new(Mutex::new(stunnel.pin_sink())),
            stream: Arc::new(Mutex::new(stunnel.pin_stream())),
            wg_type: self.config.wg_type.clone(),
+           stopped: Arc::new(AtomicBool::new(false)),
        };

        self.data = Some(data.clone());
@@ -403,6 +407,14 @@ impl WgPeer {

        ctunnel
    }
+
+   fn stopped(&self) -> bool {
+       self.data
+           .as_ref()
+           .unwrap()
+           .stopped
+           .load(std::sync::atomic::Ordering::Relaxed)
+   }
 }

 impl Drop for WgPeer {
@@ -427,6 +439,8 @@ pub struct WgTunnelListener {
     conn_recv: ConnReceiver,
     conn_send: Option<ConnSender>,

+    wg_peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
+
     tasks: JoinSet<()>,
 }

|
@ -441,6 +455,8 @@ impl WgTunnelListener {
|
|||
conn_recv,
|
||||
conn_send: Some(conn_send),
|
||||
|
||||
wg_peer_map: Arc::new(DashMap::new()),
|
||||
|
||||
tasks: JoinSet::new(),
|
||||
}
|
||||
}
|
||||
|
@@ -453,15 +469,16 @@ impl WgTunnelListener {
        socket: Arc<UdpSocket>,
        config: WgConfig,
        conn_sender: ConnSender,
+       peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
    ) {
        let mut tasks = JoinSet::new();
-       let peer_map: Arc<DashMap<SocketAddr, WgPeer>> = Arc::new(DashMap::new());

        let peer_map_clone = peer_map.clone();
        tasks.spawn(async move {
            loop {
-               peer_map_clone.retain(|_, peer| peer.access_time.elapsed().as_secs() < 600);
-               tokio::time::sleep(Duration::from_secs(60)).await;
+               peer_map_clone
+                   .retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped());
+               tokio::time::sleep(Duration::from_secs(1)).await;
            }
        });

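
With the peer map owned by `WgTunnelListener` and passed into the handler task, the sweep above now runs every second and drops peers that are stale (not accessed for 61 seconds) or whose forwarding task has set its `stopped` flag. A compact sketch of such a retain-based sweep over a `DashMap` (the `PeerEntry` fields are simplified stand-ins for `WgPeer`):

    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::{Duration, Instant},
    };

    use dashmap::DashMap;

    struct PeerEntry {
        access_time: Instant,
        stopped: Arc<AtomicBool>,
    }

    impl PeerEntry {
        fn stopped(&self) -> bool {
            self.stopped.load(Ordering::Relaxed)
        }
    }

    // Periodic sweep: keep only peers that were used recently and are still running.
    async fn reap_loop(peer_map: Arc<DashMap<std::net::SocketAddr, PeerEntry>>) {
        loop {
            peer_map.retain(|_, peer| {
                peer.access_time.elapsed().as_secs() < 61 && !peer.stopped()
            });
            tokio::time::sleep(Duration::from_secs(1)).await;
        }
    }

    #[tokio::main]
    async fn main() {
        let peers: Arc<DashMap<std::net::SocketAddr, PeerEntry>> = Arc::new(DashMap::new());
        peers.insert(
            "127.0.0.1:11010".parse().unwrap(),
            PeerEntry {
                access_time: Instant::now(),
                stopped: Arc::new(AtomicBool::new(false)),
            },
        );

        let reaper = tokio::spawn(reap_loop(peers.clone()));

        // Simulate the forwarding task exiting: the next sweep removes the entry.
        peers.iter().next().unwrap().stopped.store(true, Ordering::Relaxed);
        tokio::time::sleep(Duration::from_secs(2)).await;
        assert_eq!(peers.len(), 0);

        reaper.abort();
    }
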
@@ -524,6 +541,7 @@ impl TunnelListener for WgTunnelListener {
            self.get_udp_socket(),
            self.config.clone(),
            self.conn_send.take().unwrap(),
+           self.wg_peer_map.clone(),
        ));

        Ok(())
@@ -788,4 +806,36 @@ pub mod tests {
        connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
        _tunnel_pingpong(listener, connector).await
    }
+
+   #[tokio::test]
+   async fn wg_server_erase_from_map_after_close() {
+       let (server_cfg, client_cfg) = create_wg_config();
+       let mut listener =
+           WgTunnelListener::new("wg://127.0.0.1:5595".parse().unwrap(), server_cfg);
+       listener.listen().await.unwrap();
+
+       const CONN_COUNT: usize = 10;
+
+       tokio::spawn(async move {
+           for _ in 0..CONN_COUNT {
+               let mut connector = WgTunnelConnector::new(
+                   "wg://127.0.0.1:5595".parse().unwrap(),
+                   client_cfg.clone(),
+               );
+               let ret = connector.connect().await;
+               assert!(ret.is_ok());
+               drop(ret);
+           }
+       });
+
+       for _ in 0..CONN_COUNT {
+           let conn = listener.accept().await;
+           assert!(conn.is_ok());
+           drop(conn);
+       }
+
+       tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
+
+       assert_eq!(0, listener.wg_peer_map.len());
+   }
 }