mirror of
https://github.com/EasyTier/EasyTier.git
synced 2024-11-16 03:32:43 +08:00
fix session_task and session mismatch
This commit is contained in:
parent
4fea3a60d6
commit
d5bc15cf7a
|
@ -1192,22 +1192,4 @@ pub mod tests {
|
|||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn udp_listener() {
|
||||
let p_a = create_mock_peer_manager().await;
|
||||
wait_for_condition(
|
||||
|| async {
|
||||
p_a.get_global_ctx()
|
||||
.get_stun_info_collector()
|
||||
.get_stun_info()
|
||||
.udp_nat_type
|
||||
!= NatType::Unknown as i32
|
||||
},
|
||||
Duration::from_secs(20),
|
||||
)
|
||||
.await;
|
||||
let l = UdpHolePunchListener::new(p_a.clone()).await.unwrap();
|
||||
println!("{:#?}", l.mapped_addr);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -16,7 +16,11 @@ use petgraph::{
|
|||
Directed, Graph,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tokio::{select, sync::Mutex, task::JoinSet};
|
||||
use tokio::{
|
||||
select,
|
||||
sync::Mutex,
|
||||
task::{JoinHandle, JoinSet},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
common::{global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
|
||||
|
@ -602,6 +606,48 @@ type SessionId = u64;
|
|||
|
||||
type AtomicSessionId = atomic_shim::AtomicU64;
|
||||
|
||||
struct SessionTask {
|
||||
task: Arc<std::sync::Mutex<Option<JoinHandle<()>>>>,
|
||||
}
|
||||
|
||||
impl SessionTask {
|
||||
fn new() -> Self {
|
||||
SessionTask {
|
||||
task: Arc::new(std::sync::Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_task(&self, task: JoinHandle<()>) {
|
||||
if let Some(old) = self.task.lock().unwrap().replace(task) {
|
||||
old.abort();
|
||||
}
|
||||
}
|
||||
|
||||
fn is_running(&self) -> bool {
|
||||
if let Some(task) = self.task.lock().unwrap().as_ref() {
|
||||
!task.is_finished()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for SessionTask {
|
||||
fn drop(&mut self) {
|
||||
if let Some(task) = self.task.lock().unwrap().take() {
|
||||
task.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Debug for SessionTask {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("SessionTask")
|
||||
.field("is_running", &self.is_running())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
// if we need to sync route info with one peer, we create a SyncRouteSession with that peer.
|
||||
#[derive(Debug)]
|
||||
struct SyncRouteSession {
|
||||
|
@ -620,6 +666,8 @@ struct SyncRouteSession {
|
|||
|
||||
rpc_tx_count: AtomicU32,
|
||||
rpc_rx_count: AtomicU32,
|
||||
|
||||
task: SessionTask,
|
||||
}
|
||||
|
||||
impl SyncRouteSession {
|
||||
|
@ -639,6 +687,8 @@ impl SyncRouteSession {
|
|||
|
||||
rpc_tx_count: AtomicU32::new(0),
|
||||
rpc_rx_count: AtomicU32::new(0),
|
||||
|
||||
task: SessionTask::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -684,6 +734,20 @@ impl SyncRouteSession {
|
|||
self.dst_saved_peer_info_versions.clear();
|
||||
}
|
||||
}
|
||||
|
||||
fn short_debug_string(&self) -> String {
|
||||
format!(
|
||||
"session_dst_peer: {:?}, my_session_id: {:?}, dst_session_id: {:?}, we_are_initiator: {:?}, dst_is_initiator: {:?}, rpc_tx_count: {:?}, rpc_rx_count: {:?}, task: {:?}",
|
||||
self.dst_peer_id,
|
||||
self.my_session_id,
|
||||
self.dst_session_id,
|
||||
self.we_are_initiator,
|
||||
self.dst_is_initiator,
|
||||
self.rpc_tx_count,
|
||||
self.rpc_rx_count,
|
||||
self.task
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
struct PeerRouteServiceImpl {
|
||||
|
@ -756,6 +820,10 @@ impl PeerRouteServiceImpl {
|
|||
self.sessions.remove(&dst_peer_id);
|
||||
}
|
||||
|
||||
fn list_session_peers(&self) -> Vec<PeerId> {
|
||||
self.sessions.iter().map(|x| *x.key()).collect()
|
||||
}
|
||||
|
||||
async fn list_peers_from_interface<T: FromIterator<PeerId>>(&self) -> T {
|
||||
self.interface
|
||||
.lock()
|
||||
|
@ -944,7 +1012,11 @@ impl PeerRouteServiceImpl {
|
|||
dst_peer_id: PeerId,
|
||||
peer_rpc: Arc<PeerRpcManager>,
|
||||
) -> bool {
|
||||
let session = self.get_or_create_session(dst_peer_id);
|
||||
let Some(session) = self.get_session(dst_peer_id) else {
|
||||
// if session not exist, exit the sync loop.
|
||||
return true;
|
||||
};
|
||||
|
||||
let my_peer_id = self.my_peer_id;
|
||||
|
||||
let (peer_infos, conn_bitmap) = self.build_sync_request(&session);
|
||||
|
@ -1018,7 +1090,6 @@ impl PeerRouteServiceImpl {
|
|||
struct RouteSessionManager {
|
||||
service_impl: Weak<PeerRouteServiceImpl>,
|
||||
peer_rpc: Weak<PeerRpcManager>,
|
||||
session_tasks: Arc<DashMap<PeerId, JoinSet<()>>>,
|
||||
|
||||
sync_now_broadcast: tokio::sync::broadcast::Sender<()>,
|
||||
}
|
||||
|
@ -1026,14 +1097,6 @@ struct RouteSessionManager {
|
|||
impl Debug for RouteSessionManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RouteSessionManager")
|
||||
.field(
|
||||
"session_tasks",
|
||||
&self
|
||||
.session_tasks
|
||||
.iter()
|
||||
.map(|x| *x.key())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
.field("dump_sessions", &self.dump_sessions())
|
||||
.finish()
|
||||
}
|
||||
|
@ -1101,7 +1164,6 @@ impl RouteSessionManager {
|
|||
RouteSessionManager {
|
||||
service_impl: Arc::downgrade(&service_impl),
|
||||
peer_rpc: Arc::downgrade(&peer_rpc),
|
||||
session_tasks: Arc::new(DashMap::new()),
|
||||
|
||||
sync_now_broadcast: tokio::sync::broadcast::channel(100).0,
|
||||
}
|
||||
|
@ -1143,7 +1205,6 @@ impl RouteSessionManager {
|
|||
|
||||
fn stop_session(&self, peer_id: PeerId) -> Result<(), Error> {
|
||||
tracing::warn!(?peer_id, "stop ospf sync session");
|
||||
self.session_tasks.remove(&peer_id);
|
||||
let Some(service_impl) = self.service_impl.upgrade() else {
|
||||
return Err(Error::Stopped);
|
||||
};
|
||||
|
@ -1151,24 +1212,15 @@ impl RouteSessionManager {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn start_session(&self, peer_id: PeerId) -> Result<Arc<SyncRouteSession>, Error> {
|
||||
let Some(service_impl) = self.service_impl.upgrade() else {
|
||||
return Err(Error::Stopped);
|
||||
};
|
||||
|
||||
tracing::warn!(?service_impl.my_peer_id, ?peer_id, "start ospf sync session");
|
||||
|
||||
let mut tasks = JoinSet::new();
|
||||
tasks.spawn(Self::session_task(
|
||||
self.peer_rpc.clone(),
|
||||
self.service_impl.clone(),
|
||||
peer_id,
|
||||
self.sync_now_broadcast.subscribe(),
|
||||
));
|
||||
|
||||
let session = service_impl.get_or_create_session(peer_id);
|
||||
self.session_tasks.insert(peer_id, tasks);
|
||||
Ok(session)
|
||||
fn start_session_task(&self, session: &Arc<SyncRouteSession>) {
|
||||
if !session.task.is_running() {
|
||||
session.task.set_task(tokio::spawn(Self::session_task(
|
||||
self.peer_rpc.clone(),
|
||||
self.service_impl.clone(),
|
||||
session.dst_peer_id,
|
||||
self.sync_now_broadcast.subscribe(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
fn get_or_start_session(&self, peer_id: PeerId) -> Result<Arc<SyncRouteSession>, Error> {
|
||||
|
@ -1176,11 +1228,11 @@ impl RouteSessionManager {
|
|||
return Err(Error::Stopped);
|
||||
};
|
||||
|
||||
if let Some(session) = service_impl.get_session(peer_id) {
|
||||
return Ok(session);
|
||||
}
|
||||
tracing::info!(?service_impl.my_peer_id, ?peer_id, "start ospf sync session");
|
||||
|
||||
self.start_session(peer_id)
|
||||
let session = service_impl.get_or_create_session(peer_id);
|
||||
self.start_session_task(&session);
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
|
@ -1267,9 +1319,10 @@ impl RouteSessionManager {
|
|||
// clear sessions that are neither dst_initiator or we_are_initiator.
|
||||
for peer_id in session_peers.iter() {
|
||||
if let Some(session) = service_impl.get_session(*peer_id) {
|
||||
if session.dst_is_initiator.load(Ordering::Relaxed)
|
||||
if (session.dst_is_initiator.load(Ordering::Relaxed)
|
||||
|| session.we_are_initiator.load(Ordering::Relaxed)
|
||||
|| session.need_sync_initiator_info.load(Ordering::Relaxed)
|
||||
|| session.need_sync_initiator_info.load(Ordering::Relaxed))
|
||||
&& session.task.is_running()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
@ -1283,10 +1336,11 @@ impl RouteSessionManager {
|
|||
}
|
||||
|
||||
fn list_session_peers(&self) -> Vec<PeerId> {
|
||||
self.session_tasks
|
||||
.iter()
|
||||
.map(|x| *x.key())
|
||||
.collect::<Vec<_>>()
|
||||
let Some(service_impl) = self.service_impl.upgrade() else {
|
||||
return vec![];
|
||||
};
|
||||
|
||||
service_impl.list_session_peers()
|
||||
}
|
||||
|
||||
fn dump_sessions(&self) -> Result<String, Error> {
|
||||
|
@ -1296,10 +1350,12 @@ impl RouteSessionManager {
|
|||
|
||||
let mut ret = format!("my_peer_id: {:?}\n", service_impl.my_peer_id);
|
||||
for item in service_impl.sessions.iter() {
|
||||
ret += format!(" session: {:?}, we_are_initiator: {:?}, dst_is_initiator: {:?}, need_sync_initiator_info: {:?}\n",
|
||||
item.key(), item.value().we_are_initiator.load(Ordering::Relaxed),
|
||||
item.value().dst_is_initiator.load(Ordering::Relaxed),
|
||||
item.value().need_sync_initiator_info.load(Ordering::Relaxed)).as_str();
|
||||
ret += format!(
|
||||
" session: {}, {}\n",
|
||||
item.key(),
|
||||
item.value().short_debug_string()
|
||||
)
|
||||
.as_str();
|
||||
}
|
||||
|
||||
Ok(ret.to_string())
|
||||
|
@ -1582,8 +1638,9 @@ mod tests {
|
|||
assert_eq!(2, r_a.service_impl.synced_route_info.peer_infos.len());
|
||||
assert_eq!(2, r_b.service_impl.synced_route_info.peer_infos.len());
|
||||
|
||||
assert_eq!(1, r_a.session_mgr.session_tasks.len());
|
||||
assert_eq!(1, r_b.session_mgr.session_tasks.len());
|
||||
for s in r_a.service_impl.sessions.iter() {
|
||||
assert!(s.value().task.is_running());
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
r_a.service_impl
|
||||
|
@ -1619,7 +1676,12 @@ mod tests {
|
|||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
assert_eq!(0, r_a.session_mgr.session_tasks.len());
|
||||
|
||||
wait_for_condition(
|
||||
|| async { r_a.service_impl.sessions.is_empty() },
|
||||
Duration::from_secs(5),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
@ -1687,11 +1749,15 @@ mod tests {
|
|||
connect_peer_manager(p_e.clone(), last_p.clone()).await;
|
||||
|
||||
wait_for_condition(
|
||||
|| async { r_e.session_mgr.session_tasks.len() == 1 },
|
||||
|| async { r_e.session_mgr.list_session_peers().len() == 1 },
|
||||
Duration::from_secs(3),
|
||||
)
|
||||
.await;
|
||||
|
||||
for s in r_e.service_impl.sessions.iter() {
|
||||
assert!(s.value().task.is_running());
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_secs(2)).await;
|
||||
|
||||
check_rpc_counter(&r_e, last_p.my_peer_id(), 2, 2);
|
||||
|
|
Loading…
Reference in New Issue
Block a user