fix ring tunnel cannot close (#51)

This commit is contained in:
Sijie.Sun 2024-04-07 11:35:22 +08:00 committed by GitHub
parent 727ef37ae4
commit 50e14798d6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 199 additions and 100 deletions

View File

@ -42,13 +42,22 @@ pub type PeerConnId = uuid::Uuid;
macro_rules! wait_response {
($stream: ident, $out_var:ident, $pattern:pat_param => $value:expr) => {
let rsp_vec = timeout(Duration::from_secs(1), $stream.next()).await;
if rsp_vec.is_err() {
let Ok(rsp_vec) = timeout(Duration::from_secs(1), $stream.next()).await else {
return Err(TunnelError::WaitRespError(
"wait handshake response timeout".to_owned(),
));
}
let rsp_vec = rsp_vec.unwrap().unwrap()?;
};
let Some(rsp_vec) = rsp_vec else {
return Err(TunnelError::WaitRespError(
"wait handshake response get none".to_owned(),
));
};
let Ok(rsp_vec) = rsp_vec else {
return Err(TunnelError::WaitRespError(format!(
"wait handshake response get error {}",
rsp_vec.err().unwrap()
)));
};
let $out_var;
let rsp_bytes = Packet::decode(&rsp_vec);

View File

@ -1,6 +1,9 @@
use std::{
collections::HashMap,
sync::{atomic::AtomicBool, Arc},
sync::{
atomic::{AtomicBool, AtomicU32},
Arc,
},
task::Poll,
};
@ -22,28 +25,55 @@ use uuid::Uuid;
use crate::tunnels::{SinkError, SinkItem};
use super::{
build_url_from_socket_addr, check_scheme_and_get_socket_addr, DatagramSink, DatagramStream,
Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
build_url_from_socket_addr, check_scheme_and_get_socket_addr, common::FramedTunnel,
DatagramSink, DatagramStream, Tunnel, TunnelConnector, TunnelError, TunnelInfo, TunnelListener,
};
static RING_TUNNEL_CAP: usize = 1000;
struct Ring {
id: Uuid,
ring: ArrayQueue<SinkItem>,
consume_notify: Notify,
produce_notify: Notify,
closed: AtomicBool,
}
impl Ring {
fn new(cap: usize, id: uuid::Uuid) -> Self {
Self {
id,
ring: ArrayQueue::new(cap),
consume_notify: Notify::new(),
produce_notify: Notify::new(),
closed: AtomicBool::new(false),
}
}
fn close(&self) {
self.closed
.store(true, std::sync::atomic::Ordering::Relaxed);
self.produce_notify.notify_one();
}
fn closed(&self) -> bool {
self.closed.load(std::sync::atomic::Ordering::Relaxed)
}
}
pub struct RingTunnel {
id: Uuid,
ring: Arc<ArrayQueue<SinkItem>>,
consume_notify: Arc<Notify>,
produce_notify: Arc<Notify>,
closed: Arc<AtomicBool>,
ring: Arc<Ring>,
sender_counter: Arc<AtomicU32>,
}
impl RingTunnel {
pub fn new(cap: usize) -> Self {
let id = Uuid::new_v4();
RingTunnel {
id: Uuid::new_v4(),
ring: Arc::new(ArrayQueue::new(cap)),
consume_notify: Arc::new(Notify::new()),
produce_notify: Arc::new(Notify::new()),
closed: Arc::new(AtomicBool::new(false)),
id: id.clone(),
ring: Arc::new(Ring::new(cap, id)),
sender_counter: Arc::new(AtomicU32::new(1)),
}
}
@ -55,27 +85,24 @@ impl RingTunnel {
fn recv_stream(&self) -> impl DatagramStream {
let ring = self.ring.clone();
let produce_notify = self.produce_notify.clone();
let consume_notify = self.consume_notify.clone();
let closed = self.closed.clone();
let id = self.id;
stream! {
loop {
if closed.load(std::sync::atomic::Ordering::Relaxed) {
log::warn!("ring recv tunnel {:?} closed", id);
yield Err(TunnelError::CommonError("Closed".to_owned()));
}
match ring.pop() {
match ring.ring.pop() {
Some(v) => {
let mut out = BytesMut::new();
out.extend_from_slice(&v);
consume_notify.notify_one();
ring.consume_notify.notify_one();
log::trace!("id: {}, recv buffer, len: {:?}, buf: {:?}", id, v.len(), &v);
yield Ok(out);
},
None => {
if ring.closed() {
log::warn!("ring recv tunnel {:?} closed", id);
yield Err(TunnelError::CommonError("ring closed".to_owned()));
}
log::trace!("waiting recv buffer, id: {}", id);
produce_notify.notified().await;
ring.produce_notify.notified().await;
}
}
}
@ -84,18 +111,13 @@ impl RingTunnel {
fn send_sink(&self) -> impl DatagramSink {
let ring = self.ring.clone();
let produce_notify = self.produce_notify.clone();
let consume_notify = self.consume_notify.clone();
let closed = self.closed.clone();
let id = self.id;
// type T = RingTunnel;
let sender_counter = self.sender_counter.clone();
use tokio::task::JoinHandle;
struct T {
ring: RingTunnel,
ring: Arc<Ring>,
wait_consume_task: Option<JoinHandle<()>>,
sender_counter: Arc<AtomicU32>,
}
impl T {
@ -110,16 +132,15 @@ impl RingTunnel {
}
if self_mut.wait_consume_task.is_none() {
let id = self_mut.ring.id;
let consume_notify = self_mut.ring.consume_notify.clone();
let ring = self_mut.ring.ring.clone();
let ring = self_mut.ring.clone();
let task = async move {
log::trace!(
"waiting ring consume done, expected_size: {}, id: {}",
expected_size,
id
);
while ring.len() > expected_size {
consume_notify.notified().await;
while ring.ring.len() > expected_size {
ring.consume_notify.notified().await;
}
log::trace!(
"ring consume done, expected_size: {}, id: {}",
@ -147,6 +168,12 @@ impl RingTunnel {
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
if self.ring.closed() {
return Poll::Ready(Err(TunnelError::CommonError(
"ring closed during ready".to_owned(),
)
.into()));
}
let expected_size = self.ring.ring.capacity() - 1;
match self.wait_ring_consume(cx, expected_size) {
Poll::Ready(_) => Poll::Ready(Ok(())),
@ -158,6 +185,11 @@ impl RingTunnel {
self: std::pin::Pin<&mut Self>,
item: SinkItem,
) -> Result<(), Self::Error> {
if self.ring.closed() {
return Err(
TunnelError::CommonError("ring closed during send".to_owned()).into(),
);
}
log::trace!("id: {}, send buffer, buf: {:?}", self.ring.id, &item);
self.ring.ring.push(item).unwrap();
self.ring.produce_notify.notify_one();
@ -168,6 +200,12 @@ impl RingTunnel {
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
if self.ring.closed() {
return Poll::Ready(Err(TunnelError::CommonError(
"ring closed during flush".to_owned(),
)
.into()));
}
Poll::Ready(Ok(()))
}
@ -175,24 +213,38 @@ impl RingTunnel {
self: std::pin::Pin<&mut Self>,
_cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.ring
.closed
.store(true, std::sync::atomic::Ordering::Relaxed);
log::warn!("ring tunnel send {:?} closed", self.ring.id);
self.ring.produce_notify.notify_one();
self.ring.close();
Poll::Ready(Ok(()))
}
}
impl Drop for T {
fn drop(&mut self) {
let rem = self
.sender_counter
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
if rem == 1 {
self.ring.close()
}
}
}
sender_counter.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
T {
ring: RingTunnel {
id,
ring,
consume_notify,
produce_notify,
closed,
},
ring,
wait_consume_task: None,
sender_counter,
}
}
}
impl Drop for RingTunnel {
fn drop(&mut self) {
let rem = self
.sender_counter
.fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
if rem == 1 {
self.ring.close()
}
}
}
@ -213,11 +265,6 @@ impl Tunnel for RingTunnel {
fn info(&self) -> Option<TunnelInfo> {
None
// Some(TunnelInfo {
// tunnel_type: "ring".to_owned(),
// local_addr: format!("ring://{}", self.id),
// remote_addr: format!("ring://{}", self.id),
// })
}
}
@ -241,48 +288,29 @@ impl RingTunnelListener {
}
}
}
struct ConnectionForServer {
conn: Arc<Connection>,
}
impl Tunnel for ConnectionForServer {
fn stream(&self) -> Box<dyn DatagramStream> {
Box::new(self.conn.server.recv_stream())
}
fn sink(&self) -> Box<dyn DatagramSink> {
Box::new(self.conn.client.send_sink())
}
fn info(&self) -> Option<TunnelInfo> {
Some(TunnelInfo {
fn get_tunnel_for_client(conn: Arc<Connection>) -> Box<dyn Tunnel> {
FramedTunnel::new_tunnel_with_info(
Box::pin(conn.client.recv_stream()),
conn.server.send_sink(),
TunnelInfo {
tunnel_type: "ring".to_owned(),
local_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
remote_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
})
}
local_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
remote_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
},
)
}
struct ConnectionForClient {
conn: Arc<Connection>,
}
impl Tunnel for ConnectionForClient {
fn stream(&self) -> Box<dyn DatagramStream> {
Box::new(self.conn.client.recv_stream())
}
fn sink(&self) -> Box<dyn DatagramSink> {
Box::new(self.conn.server.send_sink())
}
fn info(&self) -> Option<TunnelInfo> {
Some(TunnelInfo {
fn get_tunnel_for_server(conn: Arc<Connection>) -> Box<dyn Tunnel> {
FramedTunnel::new_tunnel_with_info(
Box::pin(conn.server.recv_stream()),
conn.client.send_sink(),
TunnelInfo {
tunnel_type: "ring".to_owned(),
local_addr: build_url_from_socket_addr(&self.conn.client.id.into(), "ring").into(),
remote_addr: build_url_from_socket_addr(&self.conn.server.id.into(), "ring").into(),
})
}
local_addr: build_url_from_socket_addr(&conn.server.id.into(), "ring").into(),
remote_addr: build_url_from_socket_addr(&conn.client.id.into(), "ring").into(),
},
)
}
impl RingTunnelListener {
@ -308,7 +336,7 @@ impl TunnelListener for RingTunnelListener {
if let Some(conn) = self.conn_receiver.recv().await {
if conn.server.id == my_addr {
log::info!("accept new conn of key: {}", self.listerner_addr);
return Ok(Box::new(ConnectionForServer { conn }));
return Ok(get_tunnel_for_server(conn));
} else {
tracing::error!(?conn.server.id, ?my_addr, "got new conn with wrong id");
return Err(TunnelError::CommonError(
@ -353,7 +381,7 @@ impl TunnelConnector for RingTunnelConnector {
entry
.send(conn.clone())
.map_err(|_| TunnelError::CommonError("send conn to listner failed".to_owned()))?;
Ok(Box::new(ConnectionForClient { conn }))
Ok(get_tunnel_for_client(conn))
}
fn remote_url(&self) -> url::Url {
@ -367,13 +395,15 @@ pub fn create_ring_tunnel_pair() -> (Box<dyn Tunnel>, Box<dyn Tunnel>) {
server: RingTunnel::new(RING_TUNNEL_CAP),
});
(
Box::new(ConnectionForServer { conn: conn.clone() }),
Box::new(ConnectionForClient { conn }),
Box::new(get_tunnel_for_server(conn.clone())),
Box::new(get_tunnel_for_client(conn)),
)
}
#[cfg(test)]
mod tests {
use futures::StreamExt;
use crate::tunnels::common::tests::{_tunnel_bench, _tunnel_pingpong};
use super::*;
@ -393,4 +423,14 @@ mod tests {
let connector = RingTunnelConnector::new(id);
_tunnel_bench(listener, connector).await
}
#[tokio::test]
async fn ring_close() {
let (stunnel, ctunnel) = create_ring_tunnel_pair();
drop(stunnel);
let mut stream = ctunnel.pin_stream();
let ret = stream.next().await;
assert!(ret.as_ref().unwrap().is_err(), "expect Err, got {:?}", ret);
}
}

View File

@ -4,7 +4,7 @@ use std::{
hash::Hasher,
net::SocketAddr,
pin::Pin,
sync::Arc,
sync::{atomic::AtomicBool, Arc},
time::Duration,
};
@ -120,6 +120,7 @@ struct WgPeerData {
sink: Arc<Mutex<Pin<Box<dyn DatagramSink>>>>,
stream: Arc<Mutex<Pin<Box<dyn DatagramStream>>>>,
wg_type: WgType,
stopped: Arc<AtomicBool>,
}
impl Debug for WgPeerData {
@ -366,6 +367,8 @@ impl WgPeer {
tracing::error!("Failed to handle packet from me: {}", e);
}
}
data.stopped
.store(true, std::sync::atomic::Ordering::Relaxed);
}
async fn handle_packet_from_peer(&mut self, packet: &[u8]) {
@ -395,6 +398,7 @@ impl WgPeer {
sink: Arc::new(Mutex::new(stunnel.pin_sink())),
stream: Arc::new(Mutex::new(stunnel.pin_stream())),
wg_type: self.config.wg_type.clone(),
stopped: Arc::new(AtomicBool::new(false)),
};
self.data = Some(data.clone());
@ -403,6 +407,14 @@ impl WgPeer {
ctunnel
}
fn stopped(&self) -> bool {
self.data
.as_ref()
.unwrap()
.stopped
.load(std::sync::atomic::Ordering::Relaxed)
}
}
impl Drop for WgPeer {
@ -427,6 +439,8 @@ pub struct WgTunnelListener {
conn_recv: ConnReceiver,
conn_send: Option<ConnSender>,
wg_peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
tasks: JoinSet<()>,
}
@ -441,6 +455,8 @@ impl WgTunnelListener {
conn_recv,
conn_send: Some(conn_send),
wg_peer_map: Arc::new(DashMap::new()),
tasks: JoinSet::new(),
}
}
@ -453,15 +469,16 @@ impl WgTunnelListener {
socket: Arc<UdpSocket>,
config: WgConfig,
conn_sender: ConnSender,
peer_map: Arc<DashMap<SocketAddr, WgPeer>>,
) {
let mut tasks = JoinSet::new();
let peer_map: Arc<DashMap<SocketAddr, WgPeer>> = Arc::new(DashMap::new());
let peer_map_clone = peer_map.clone();
tasks.spawn(async move {
loop {
peer_map_clone.retain(|_, peer| peer.access_time.elapsed().as_secs() < 600);
tokio::time::sleep(Duration::from_secs(60)).await;
peer_map_clone
.retain(|_, peer| peer.access_time.elapsed().as_secs() < 61 && !peer.stopped());
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
@ -524,6 +541,7 @@ impl TunnelListener for WgTunnelListener {
self.get_udp_socket(),
self.config.clone(),
self.conn_send.take().unwrap(),
self.wg_peer_map.clone(),
));
Ok(())
@ -788,4 +806,36 @@ pub mod tests {
connector.set_bind_addrs(vec!["10.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
#[tokio::test]
async fn wg_server_erase_from_map_after_close() {
let (server_cfg, client_cfg) = create_wg_config();
let mut listener =
WgTunnelListener::new("wg://127.0.0.1:5595".parse().unwrap(), server_cfg);
listener.listen().await.unwrap();
const CONN_COUNT: usize = 10;
tokio::spawn(async move {
for _ in 0..CONN_COUNT {
let mut connector = WgTunnelConnector::new(
"wg://127.0.0.1:5595".parse().unwrap(),
client_cfg.clone(),
);
let ret = connector.connect().await;
assert!(ret.is_ok());
drop(ret);
}
});
for _ in 0..CONN_COUNT {
let conn = listener.accept().await;
assert!(conn.is_ok());
drop(conn);
}
tokio::time::sleep(tokio::time::Duration::from_secs(2)).await;
assert_eq!(0, listener.wg_peer_map.len());
}
}