fix ring tunnel cannot close (#51)

This commit is contained in:
Sijie.Sun 2024-04-07 11:35:22 +08:00 committed by GitHub
parent 727ef37ae4
commit 50e14798d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 199 additions and 100 deletions

View File

@ -42,13 +42,22 @@ pub type PeerConnId = uuid::Uuid;
macro_rules! wait_response { macro_rules! wait_response {
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => { ($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await; let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else {
if rsp_vec.is_err() {
return Err(TunnelError::WaitRespError( return Err(TunnelError::WaitRespError(
"wait handshake response timeout".to_owned(), "wait handshake response timeout".to_owned(),
)); ));
} };
let rsp_vec = rsp_vec.unwrap().unwrap()?; let Some(rsp_vec) = rsp_vec else {
return Err(TunnelError::WaitRespError(
"wait handshake response get none".to_owned(),
));
};
let Ok(rsp_vec) = rsp_vec else {
return Err(TunnelError::WaitRespError(format!(
"wait handshake response get error {}",
rsp_vec.err().unwrap()
)));
};
let $out_var; let $out_var;
let rsp_bytes = Packet::decode(&rsp_vec); let rsp_bytes = Packet::decode(&rsp_vec);

View File

@ -1,6 +1,9 @@
use std::{ use std::{
collections::HashMap, collections::HashMap,
sync::{atomic::AtomicBool, Arc}, sync::{
atomic::{AtomicBool, AtomicU32},
Arc,
},
task::Poll, task::Poll,
}; };
@ -22,28 +25,55 @@ use uuid::Uuid;
use crate::tunnels::{SinkError, SinkItem}; use crate::tunnels::{SinkError, SinkItem};
use super::{ use super::{
build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream, build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::FramedTunnel,
Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener, DatagramSink, DatagramStream, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
}; };
static RING_TUNNEL_CAP: usize = 1000; static RING_TUNNEL_CAP: usize = 1000;
struct Ring {
id: Uuid,
ring: ArrayQueue<SinkItem>,
consume_notify: Notify,
produce_notify: Notify,
closed: AtomicBool,
}
impl Ring {
fn new(cap: usize, id: uuid::Uuid) -> Self {
Self {
id,
ring: ArrayQueue::new(cap),
consume_notify: Notify::new(),
produce_notify: Notify::new(),
closed: AtomicBool::new(false),
}
}
fn close(&self) {
self.closed
.store(true, std::sync::atomic::Ordering::Relaxed);
self.produce_notify.notify_one();
}
fn closed(&self) -> bool {
self.closed.load(std::sync::atomic::Ordering::Relaxed)
}
}
pub struct RingTunnel { pub struct RingTunnel {
id: Uuid, id: Uuid,
ring: Arc<ArrayQueue<SinkItem>>, ring: Arc<Ring>,
consume_notify: Arc<Notify>, sender_counter: Arc<AtomicU32>,
produce_notify: Arc<Notify>,
closed: Arc<AtomicBool>,
} }
impl RingTunnel { impl RingTunnel {
pub fn new(cap: usize) -> Self { pub fn new(cap: usize) -> Self {
let id = Uuid::new_v4();
RingTunnel { RingTunnel {
id: Uuid::new_v4(), id: id.clone(),
ring: Arc::new(ArrayQueue::new(cap)), ring: Arc::new(Ring::new(cap, id)),
consume_notify: Arc::new(Notify::new()), sender_counter: Arc::new(AtomicU32::new(1)),
produce_notify: Arc::new(Notify::new()),
closed: Arc::new(AtomicBool::new(false)),
} }
} }
@ -55,27 +85,24 @@ impl RingTunnel {
fn recv_stream(&self) -> impl DatagramStream { fn recv_stream(&self) -> impl DatagramStream {
let ring = self.ring.clone(); let ring = self.ring.clone();
let produce_notify = self.produce_notify.clone();
let consume_notify = self.consume_notify.clone();
let closed = self.closed.clone();
let id = self.id; let id = self.id;
stream! { stream! {
loop { loop {
if closed.load(std::sync::atomic::Ordering::Relaxed) { match ring.ring.pop() {
log::warn!("ring recv tunnel {:?} closed", id);
yield Err(TunnelError::CommonError("Closed".to_owned()));
}
match ring.pop() {
Some(v) => { Some(v) => {
let mut out = BytesMut::new(); let mut out = BytesMut::new();
out.extend_from_slice(&v); out.extend_from_slice(&v);
consume_notify.notify_one(); ring.consume_notify.notify_one();
log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v); log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
yield Ok(out); yield Ok(out);
}, },
None => { None => {
if ring.closed() {
log::warn!("ring recv tunnel {:?} closed", id);
yield Err(TunnelError::CommonError("ring closed".to_owned()));
}
log::trace!("waiting recv buffer, id: {}", id); log::trace!("waiting recv buffer, id: {}", id);
produce_notify.notified().await; ring.produce_notify.notified().await;
} }
} }
} }
@ -84,18 +111,13 @@ impl RingTunnel {
fn send_sink(&self) -> impl DatagramSink { fn send_sink(&self) -> impl DatagramSink {
let ring = self.ring.clone(); let ring = self.ring.clone();
let produce_notify = self.produce_notify.clone(); let sender_counter = self.sender_counter.clone();
let consume_notify = self.consume_notify.clone();
let closed = self.closed.clone();
let id = self.id;
// type T = RingTunnel;
use tokio::task::JoinHandle; use tokio::task::JoinHandle;
struct T { struct T {
ring: RingTunnel, ring: Arc<Ring>,
wait_consume_task: Option<JoinHandle<()>>, wait_consume_task: Option<JoinHandle<()>>,
sender_counter: Arc<AtomicU32>,
} }
impl T { impl T {
@ -110,16 +132,15 @@ impl RingTunnel {
} }
if self_mut.wait_consume_task.is_none() { if self_mut.wait_consume_task.is_none() {
let id = self_mut.ring.id; let id = self_mut.ring.id;
let consume_notify = self_mut.ring.consume_notify.clone(); let ring = self_mut.ring.clone();
let ring = self_mut.ring.ring.clone();
let task = async move { let task = async move {
log::trace!( log::trace!(
"waiting ring consume done, expected_size: {}, id: {}", "waiting ring consume done, expected_size: {}, id: {}",
expected_size, expected_size,
id id
); );
while ring.len() > expected_size { while ring.ring.len() > expected_size {
consume_notify.notified().await; ring.consume_notify.notified().await;
} }
log::trace!( log::trace!(
"ring consume done, expected_size: {}, id: {}", "ring consume done, expected_size: {}, id: {}",
@ -147,6 +168,12 @@ impl RingTunnel {
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>, cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> { ) -> std::task::Poll<Result<(), Self::Error>> {
if self.ring.closed() {
return Poll::Ready(Err(TunnelError::CommonError(
"ring closed during ready".to_owned(),
)
.into()));
}
let expected_size = self.ring.ring.capacity() - 1; let expected_size = self.ring.ring.capacity() - 1;
match self.wait_ring_consume(cx, expected_size) { match self.wait_ring_consume(cx, expected_size) {
Poll::Ready(_) => Poll::Ready(Ok(())), Poll::Ready(_) => Poll::Ready(Ok(())),
@ -158,6 +185,11 @@ impl RingTunnel {
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
item: SinkItem, item: SinkItem,
) -> Result<(), Self::Error> { ) -> Result<(), Self::Error> {
if self.ring.closed() {
return Err(
TunnelError::CommonError("ring closed during send".to_owned()).into(),
);
}
log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item); log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
self.ring.ring.push(item).unwrap(); self.ring.ring.push(item).unwrap();
self.ring.produce_notify.notify_one(); self.ring.produce_notify.notify_one();
@ -168,6 +200,12 @@ impl RingTunnel {
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>, _cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> { ) -> std::task::Poll<Result<(), Self::Error>> {
if self.ring.closed() {
return Poll::Ready(Err(TunnelError::CommonError(
"ring closed during flush".to_owned(),
)
.into()));
}
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
@ -175,24 +213,38 @@ impl RingTunnel {
self: std::pin::Pin<&mut Self>, self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>, _cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> { ) -> std::task::Poll<Result<(), Self::Error>> {
self.ring self.ring.close();
.closed
.store(true, std::sync::atomic::Ordering::Relaxed);
log::warn!("ring tunnel send {:?} closed", self.ring.id);
self.ring.produce_notify.notify_one();
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
} }
impl Drop for T {
fn drop(&mut self) {
let rem = self
.sender_counter
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
if rem == 1 {
self.ring.close()
}
}
}
sender_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
T { T {
ring: RingTunnel { ring,
id,
ring,
consume_notify,
produce_notify,
closed,
},
wait_consume_task: None, wait_consume_task: None,
sender_counter,
}
}
}
impl Drop for RingTunnel {
fn drop(&mut self) {
let rem = self
.sender_counter
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
if rem == 1 {
self.ring.close()
} }
} }
} }
@ -213,11 +265,6 @@ impl Tunnel for RingTunnel {
fn info(&self) -> Option<TunnelInfo> { fn info(&self) -> Option<TunnelInfo> {
None None
// Some(TunnelInfo {
// tunnel_type: "ring".to_owned(),
// local_addr: format!("ring://{}", self.id),
// remote_addr: format!("ring://{}", self.id),
// })
} }
} }
@ -241,48 +288,29 @@ impl RingTunnelListener {
} }
} }
} }
struct ConnectionForServer {
conn: Arc<Connection>,
}
impl Tunnel for ConnectionForServer { fn get_tunnel_for_client(conn: Arc<Connection>) -> Box<dyn Tunnel> {
fn stream(&self) -> Box<dyn DatagramStream> { FramedTunnel::new_tunnel_with_info(
Box::new(self.conn.server.recv_stream()) Box::pin(conn.client.recv_stream()),
} conn.server.send_sink(),
TunnelInfo {
fn sink(&self) -> Box<dyn DatagramSink> {
Box::new(self.conn.client.send_sink())
}
fn info(&self) -> Option<TunnelInfo> {
Some(TunnelInfo {
tunnel_type: "ring".to_owned(), tunnel_type: "ring".to_owned(),
local_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(), local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
remote_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(), remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
}) },
} )
} }
struct ConnectionForClient { fn get_tunnel_for_server(conn: Arc<Connection>) -> Box<dyn Tunnel> {
conn: Arc<Connection>, FramedTunnel::new_tunnel_with_info(
} Box::pin(conn.server.recv_stream()),
conn.client.send_sink(),
impl Tunnel for ConnectionForClient { TunnelInfo {
fn stream(&self) -> Box<dyn DatagramStream> {
Box::new(self.conn.client.recv_stream())
}
fn sink(&self) -> Box<dyn DatagramSink> {
Box::new(self.conn.server.send_sink())
}
fn info(&self) -> Option<TunnelInfo> {
Some(TunnelInfo {
tunnel_type: "ring".to_owned(), tunnel_type: "ring".to_owned(),
local_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(), local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
remote_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(), remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
}) },
} )
} }
impl RingTunnelListener { impl RingTunnelListener {
@ -308,7 +336,7 @@ impl TunnelListener for RingTunnelListener {
if let Some(conn) = self.conn_receiver.recv().await { if let Some(conn) = self.conn_receiver.recv().await {
if conn.server.id == my_addr { if conn.server.id == my_addr {
log::info!("accept new conn of key: {}", self.listerner_addr); log::info!("accept new conn of key: {}", self.listerner_addr);
return Ok(Box::new(ConnectionForServer { conn })); return Ok(get_tunnel_for_server(conn));
} else { } else {
tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id"); tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
return Err(TunnelError::CommonError( return Err(TunnelError::CommonError(
@ -353,7 +381,7 @@ impl TunnelConnector for RingTunnelConnector {
entry entry
.send(conn.clone()) .send(conn.clone())
.map_err(|_| TunnelError::CommonError("send conn to listner failed".to_owned()))?; .map_err(|_| TunnelError::CommonError("send conn to listner failed".to_owned()))?;
Ok(Box::new(ConnectionForClient { conn })) Ok(get_tunnel_for_client(conn))
} }
fn remote_url(&self) -> url::Url { fn remote_url(&self) -> url::Url {
@ -367,13 +395,15 @@ pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
server: RingTunnel::new(RING_TUNNEL_CAP), server: RingTunnel::new(RING_TUNNEL_CAP),
}); });
( (
Box::new(ConnectionForServer { conn: conn.clone() }), Box::new(get_tunnel_for_server(conn.clone())),
Box::new(ConnectionForClient { conn }), Box::new(get_tunnel_for_client(conn)),
) )
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use futures::StreamExt;
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong}; use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
use super::*; use super::*;
@ -393,4 +423,14 @@ mod tests {
let connector = RingTunnelConnector::new(id); let connector = RingTunnelConnector::new(id);
_tunnel_bench(listener, connector).await _tunnel_bench(listener, connector).await
} }
#[tokio::test]
async fn ring_close() {
let (stunnel, ctunnel) = create_ring_tunnel_pair();
drop(stunnel);
let mut stream = ctunnel.pin_stream();
let ret = stream.next().await;
assert!(ret.as_ref().unwrap().is_err(), "expect Err, got {:?}", ret);
}
} }

View File

@ -4,7 +4,7 @@ use std::{
hash::Hasher, hash::Hasher,
net::SocketAddr, net::SocketAddr,
pin::Pin, pin::Pin,
sync::Arc, sync::{atomic::AtomicBool, Arc},
time::Duration, time::Duration,
}; };
@ -120,6 +120,7 @@ struct WgPeerData {
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>, sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
stream: Arc<Mutex<Pin<Box<dyn DatagramStream>>>>, stream: Arc<Mutex<Pin<Box<dyn DatagramStream>>>>,
wg_type: WgType, wg_type: WgType,
stopped: Arc<AtomicBool>,
} }
impl Debug for WgPeerData { impl Debug for WgPeerData {
@ -366,6 +367,8 @@ impl WgPeer {
tracing::error!("Failed to handle packet from me: {}", e); tracing::error!("Failed to handle packet from me: {}", e);
} }
} }
data.stopped
.store(true, std::sync::atomic::Ordering::Relaxed);
} }
async fn handle_packet_from_peer(&mut self, packet: &[u8]) { async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
@ -395,6 +398,7 @@ impl WgPeer {
sink: Arc::new(Mutex::new(stunnel.pin_sink())), sink: Arc::new(Mutex::new(stunnel.pin_sink())),
stream: Arc::new(Mutex::new(stunnel.pin_stream())), stream: Arc::new(Mutex::new(stunnel.pin_stream())),
wg_type: self.config.wg_type.clone(), wg_type: self.config.wg_type.clone(),
stopped: Arc::new(AtomicBool::new(false)),
}; };
self.data = Some(data.clone()); self.data = Some(data.clone());
@ -403,6 +407,14 @@ impl WgPeer {
ctunnel ctunnel
} }
fn stopped(&self) -> bool {
self.data
.as_ref()
.unwrap()
.stopped
.load(std::sync::atomic::Ordering::Relaxed)
}
} }
impl Drop for WgPeer { impl Drop for WgPeer {
@ -427,6 +439,8 @@ pub struct WgTunnelListener {
conn_recv: ConnReceiver, conn_recv: ConnReceiver,
conn_send: Option<ConnSender>, conn_send: Option<ConnSender>,
wg_peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
tasks: JoinSet<()>, tasks: JoinSet<()>,
} }
@ -441,6 +455,8 @@ impl WgTunnelListener {
conn_recv, conn_recv,
conn_send: Some(conn_send), conn_send: Some(conn_send),
wg_peer_map: Arc::new(DashMap::new()),
tasks: JoinSet::new(), tasks: JoinSet::new(),
} }
} }
@ -453,15 +469,16 @@ impl WgTunnelListener {
socket: Arc<UdpSocket>, socket: Arc<UdpSocket>,
config: WgConfig, config: WgConfig,
conn_sender: ConnSender, conn_sender: ConnSender,
peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
) { ) {
let mut tasks = JoinSet::new(); let mut tasks = JoinSet::new();
let peer_map: Arc<DashMap<SocketAddr, WgPeer>> = Arc::new(DashMap::new());
let peer_map_clone = peer_map.clone(); let peer_map_clone = peer_map.clone();
tasks.spawn(async move { tasks.spawn(async move {
loop { loop {
peer_map_clone.retain(|_, peer| peer.access_time.elapsed().as_secs() < 600); peer_map_clone
tokio::time::sleep(Duration::from_secs(60)).await; .retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped());
tokio::time::sleep(Duration::from_secs(1)).await;
} }
}); });
@ -524,6 +541,7 @@ impl TunnelListener for WgTunnelListener {
self.get_udp_socket(), self.get_udp_socket(),
self.config.clone(), self.config.clone(),
self.conn_send.take().unwrap(), self.conn_send.take().unwrap(),
self.wg_peer_map.clone(),
)); ));
Ok(()) Ok(())
@ -788,4 +806,36 @@ pub mod tests {
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]); connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await _tunnel_pingpong(listener, connector).await
} }
#[tokio::test]
async fn wg_server_erase_from_map_after_close() {
let (server_cfg, client_cfg) = create_wg_config();
let mut listener =
WgTunnelListener::new("wg://127.0.0.1:5595".parse().unwrap(), server_cfg);
listener.listen().await.unwrap();
const CONN_COUNT: usize = 10;
tokio::spawn(async move {
for _ in 0..CONN_COUNT {
let mut connector = WgTunnelConnector::new(
"wg://127.0.0.1:5595".parse().unwrap(),
client_cfg.clone(),
);
let ret = connector.connect().await;
assert!(ret.is_ok());
drop(ret);
}
});
for _ in 0..CONN_COUNT {
let conn = listener.accept().await;
assert!(conn.is_ok());
drop(conn);
}
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
assert_eq!(0, listener.wg_peer_map.len());
}
} }