fix session_task and session mismatch

This commit is contained in:
sijie.sun 2024-08-03 11:01:17 +08:00 committed by Sijie.Sun
parent 4fea3a60d6
commit d5bc15cf7a
2 changed files with 115 additions and 67 deletions

View File

@ -1192,22 +1192,4 @@ pub mod tests {
)
.await;
}
#[tokio::test]
async fn udp_listener() {
    // Smoke test: once the mock peer manager reports a UDP NAT type other than
    // `Unknown`, a hole-punch listener can be created and its mapped address printed.
    let peer_mgr = create_mock_peer_manager().await;

    wait_for_condition(
        || async {
            let nat_type = peer_mgr
                .get_global_ctx()
                .get_stun_info_collector()
                .get_stun_info()
                .udp_nat_type;
            nat_type != NatType::Unknown as i32
        },
        Duration::from_secs(20),
    )
    .await;

    let listener = UdpHolePunchListener::new(peer_mgr.clone())
        .await
        .unwrap();
    println!("{:#?}", listener.mapped_addr);
}
}

View File

@ -16,7 +16,11 @@ use petgraph::{
Directed, Graph,
};
use serde::{Deserialize, Serialize};
use tokio::{select, sync::Mutex, task::JoinSet};
use tokio::{
select,
sync::Mutex,
task::{JoinHandle, JoinSet},
};
use crate::{
common::{global_ctx::ArcGlobalCtx, stun::StunInfoCollectorTrait, PeerId},
@ -602,6 +606,48 @@ type SessionId = u64;
type AtomicSessionId = atomic_shim::AtomicU64;
/// Owner of the background task attached to a sync session.
///
/// At most one task handle is tracked at a time: `set_task` aborts the handle
/// it replaces, and dropping the `SessionTask` aborts the handle it still holds.
struct SessionTask {
/// The currently tracked task handle, or `None` when no task has been set.
task: Arc<std::sync::Mutex<Option<JoinHandle<()>>>>,
}
impl SessionTask {
    /// Creates a `SessionTask` that tracks no task yet.
    fn new() -> Self {
        let empty_slot = std::sync::Mutex::new(None);
        SessionTask {
            task: Arc::new(empty_slot),
        }
    }

    /// Starts tracking `task`, aborting the previously tracked task if there was one.
    fn set_task(&self, task: JoinHandle<()>) {
        // The guard stays alive until the end of the function, so the old
        // handle is aborted while the slot is still locked.
        let mut slot = self.task.lock().unwrap();
        let previous = slot.replace(task);
        if let Some(old) = previous {
            old.abort();
        }
    }

    /// Returns `true` when a task is tracked and it has not finished yet.
    fn is_running(&self) -> bool {
        let slot = self.task.lock().unwrap();
        slot.as_ref().map_or(false, |handle| !handle.is_finished())
    }
}
impl Drop for SessionTask {
    /// Aborts the tracked task, if any, when the `SessionTask` is dropped.
    fn drop(&mut self) {
        // A destructor must not panic: panicking while another panic is
        // unwinding aborts the whole process. Recover the guard from a
        // poisoned mutex instead of unwrapping, so the tracked task is still
        // aborted here (dropping a tokio `JoinHandle` without aborting is
        // expected to leave the task running — see tokio docs).
        let mut slot = self
            .task
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(task) = slot.take() {
            task.abort();
        }
    }
}
impl Debug for SessionTask {
    /// Formats the task as `SessionTask { is_running: <bool> }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let running = self.is_running();
        let mut builder = f.debug_struct("SessionTask");
        builder.field("is_running", &running);
        builder.finish()
    }
}
// if we need to sync route info with one peer, we create a SyncRouteSession with that peer.
#[derive(Debug)]
struct SyncRouteSession {
@ -620,6 +666,8 @@ struct SyncRouteSession {
rpc_tx_count: AtomicU32,
rpc_rx_count: AtomicU32,
task: SessionTask,
}
impl SyncRouteSession {
@ -639,6 +687,8 @@ impl SyncRouteSession {
rpc_tx_count: AtomicU32::new(0),
rpc_rx_count: AtomicU32::new(0),
task: SessionTask::new(),
}
}
@ -684,6 +734,20 @@ impl SyncRouteSession {
self.dst_saved_peer_info_versions.clear();
}
}
fn short_debug_string(&self) -> String {
format!(
"session_dst_peer: {:?}, my_session_id: {:?}, dst_session_id: {:?}, we_are_initiator: {:?}, dst_is_initiator: {:?}, rpc_tx_count: {:?}, rpc_rx_count: {:?}, task: {:?}",
self.dst_peer_id,
self.my_session_id,
self.dst_session_id,
self.we_are_initiator,
self.dst_is_initiator,
self.rpc_tx_count,
self.rpc_rx_count,
self.task
)
}
}
struct PeerRouteServiceImpl {
@ -756,6 +820,10 @@ impl PeerRouteServiceImpl {
self.sessions.remove(&dst_peer_id);
}
/// Collects the key (peer id) of every entry currently stored in `self.sessions`.
fn list_session_peers(&self) -> Vec<PeerId> {
    let mut peer_ids = Vec::new();
    for entry in self.sessions.iter() {
        peer_ids.push(*entry.key());
    }
    peer_ids
}
async fn list_peers_from_interface<T: FromIterator<PeerId>>(&self) -> T {
self.interface
.lock()
@ -944,7 +1012,11 @@ impl PeerRouteServiceImpl {
dst_peer_id: PeerId,
peer_rpc: Arc<PeerRpcManager>,
) -> bool {
let session = self.get_or_create_session(dst_peer_id);
let Some(session) = self.get_session(dst_peer_id) else {
// if the session does not exist, exit the sync loop.
return true;
};
let my_peer_id = self.my_peer_id;
let (peer_infos, conn_bitmap) = self.build_sync_request(&session);
@ -1018,7 +1090,6 @@ impl PeerRouteServiceImpl {
struct RouteSessionManager {
service_impl: Weak<PeerRouteServiceImpl>,
peer_rpc: Weak<PeerRpcManager>,
session_tasks: Arc<DashMap<PeerId, JoinSet<()>>>,
sync_now_broadcast: tokio::sync::broadcast::Sender<()>,
}
@ -1026,14 +1097,6 @@ struct RouteSessionManager {
impl Debug for RouteSessionManager {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RouteSessionManager")
.field(
"session_tasks",
&self
.session_tasks
.iter()
.map(|x| *x.key())
.collect::<Vec<_>>(),
)
.field("dump_sessions", &self.dump_sessions())
.finish()
}
@ -1101,7 +1164,6 @@ impl RouteSessionManager {
RouteSessionManager {
service_impl: Arc::downgrade(&service_impl),
peer_rpc: Arc::downgrade(&peer_rpc),
session_tasks: Arc::new(DashMap::new()),
sync_now_broadcast: tokio::sync::broadcast::channel(100).0,
}
@ -1143,7 +1205,6 @@ impl RouteSessionManager {
fn stop_session(&self, peer_id: PeerId) -> Result<(), Error> {
tracing::warn!(?peer_id, "stop ospf sync session");
self.session_tasks.remove(&peer_id);
let Some(service_impl) = self.service_impl.upgrade() else {
return Err(Error::Stopped);
};
@ -1151,24 +1212,15 @@ impl RouteSessionManager {
Ok(())
}
fn start_session(&self, peer_id: PeerId) -> Result<Arc<SyncRouteSession>, Error> {
let Some(service_impl) = self.service_impl.upgrade() else {
return Err(Error::Stopped);
};
tracing::warn!(?service_impl.my_peer_id, ?peer_id, "start ospf sync session");
let mut tasks = JoinSet::new();
tasks.spawn(Self::session_task(
self.peer_rpc.clone(),
self.service_impl.clone(),
peer_id,
self.sync_now_broadcast.subscribe(),
));
let session = service_impl.get_or_create_session(peer_id);
self.session_tasks.insert(peer_id, tasks);
Ok(session)
/// Spawns the sync task for `session` and stores its handle on the session,
/// unless the session already has a task that is still running.
fn start_session_task(&self, session: &Arc<SyncRouteSession>) {
    if session.task.is_running() {
        return;
    }

    let sync_future = Self::session_task(
        self.peer_rpc.clone(),
        self.service_impl.clone(),
        session.dst_peer_id,
        self.sync_now_broadcast.subscribe(),
    );
    session.task.set_task(tokio::spawn(sync_future));
}
fn get_or_start_session(&self, peer_id: PeerId) -> Result<Arc<SyncRouteSession>, Error> {
@ -1176,11 +1228,11 @@ impl RouteSessionManager {
return Err(Error::Stopped);
};
if let Some(session) = service_impl.get_session(peer_id) {
return Ok(session);
}
tracing::info!(?service_impl.my_peer_id, ?peer_id, "start ospf sync session");
self.start_session(peer_id)
let session = service_impl.get_or_create_session(peer_id);
self.start_session_task(&session);
Ok(session)
}
#[tracing::instrument(skip(self))]
@ -1267,9 +1319,10 @@ impl RouteSessionManager {
// clear sessions that are neither dst_initiator nor we_are_initiator.
for peer_id in session_peers.iter() {
if let Some(session) = service_impl.get_session(*peer_id) {
if session.dst_is_initiator.load(Ordering::Relaxed)
if (session.dst_is_initiator.load(Ordering::Relaxed)
|| session.we_are_initiator.load(Ordering::Relaxed)
|| session.need_sync_initiator_info.load(Ordering::Relaxed)
|| session.need_sync_initiator_info.load(Ordering::Relaxed))
&& session.task.is_running()
{
continue;
}
@ -1283,10 +1336,11 @@ impl RouteSessionManager {
}
fn list_session_peers(&self) -> Vec<PeerId> {
self.session_tasks
.iter()
.map(|x| *x.key())
.collect::<Vec<_>>()
let Some(service_impl) = self.service_impl.upgrade() else {
return vec![];
};
service_impl.list_session_peers()
}
fn dump_sessions(&self) -> Result<String, Error> {
@ -1296,10 +1350,12 @@ impl RouteSessionManager {
let mut ret = format!("my_peer_id: {:?}\n", service_impl.my_peer_id);
for item in service_impl.sessions.iter() {
ret += format!(" session: {:?}, we_are_initiator: {:?}, dst_is_initiator: {:?}, need_sync_initiator_info: {:?}\n",
item.key(), item.value().we_are_initiator.load(Ordering::Relaxed),
item.value().dst_is_initiator.load(Ordering::Relaxed),
item.value().need_sync_initiator_info.load(Ordering::Relaxed)).as_str();
ret += format!(
" session: {}, {}\n",
item.key(),
item.value().short_debug_string()
)
.as_str();
}
Ok(ret.to_string())
@ -1582,8 +1638,9 @@ mod tests {
assert_eq!(2, r_a.service_impl.synced_route_info.peer_infos.len());
assert_eq!(2, r_b.service_impl.synced_route_info.peer_infos.len());
assert_eq!(1, r_a.session_mgr.session_tasks.len());
assert_eq!(1, r_b.session_mgr.session_tasks.len());
for s in r_a.service_impl.sessions.iter() {
assert!(s.value().task.is_running());
}
assert_eq!(
r_a.service_impl
@ -1619,7 +1676,12 @@ mod tests {
Duration::from_secs(5),
)
.await;
assert_eq!(0, r_a.session_mgr.session_tasks.len());
wait_for_condition(
|| async { r_a.service_impl.sessions.is_empty() },
Duration::from_secs(5),
)
.await;
}
#[tokio::test]
@ -1687,11 +1749,15 @@ mod tests {
connect_peer_manager(p_e.clone(), last_p.clone()).await;
wait_for_condition(
|| async { r_e.session_mgr.session_tasks.len() == 1 },
|| async { r_e.session_mgr.list_session_peers().len() == 1 },
Duration::from_secs(3),
)
.await;
for s in r_e.service_impl.sessions.iter() {
assert!(s.value().task.is_running());
}
tokio::time::sleep(Duration::from_secs(2)).await;
check_rpc_counter(&r_e, last_p.my_peer_id(), 2, 2);