support encryption (#60)

This commit is contained in:
Sijie.Sun 2024-04-27 13:44:59 +08:00 committed by GitHub
parent 69651ae3fd
commit fcc73159b3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
23 changed files with 489 additions and 81 deletions

14
Cargo.lock generated
View File

@ -344,9 +344,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
version = "2.4.1"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07"
checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1"
[[package]]
name = "bitvec"
@ -1288,6 +1288,7 @@ dependencies = [
"atomicbox",
"auto_impl",
"base64 0.21.7",
"bitflags 2.5.0",
"boringtun",
"bytecodec",
"byteorder",
@ -1319,6 +1320,7 @@ dependencies = [
"rand 0.8.5",
"rcgen",
"reqwest",
"ring 0.16.20",
"rkyv",
"rstest",
"rustls",
@ -2561,7 +2563,7 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.5.0",
"libc",
]
@ -2897,7 +2899,7 @@ version = "0.27.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.5.0",
"cfg-if",
"libc",
"memoffset",
@ -3039,7 +3041,7 @@ version = "0.10.64"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.5.0",
"cfg-if",
"foreign-types",
"libc",
@ -4157,7 +4159,7 @@ version = "0.38.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949"
dependencies = [
"bitflags 2.4.1",
"bitflags 2.5.0",
"errno",
"libc",
"linux-raw-sys",

View File

@ -71,10 +71,10 @@ impl NetworkConfig {
.with_context(|| format!("failed to parse instance id: {}", self.instance_id))?,
);
cfg.set_inst_name(self.network_name.clone());
cfg.set_network_identity(NetworkIdentity {
network_name: self.network_name.clone(),
network_secret: self.network_secret.clone(),
});
cfg.set_network_identity(NetworkIdentity::new(
self.network_name.clone(),
self.network_secret.clone(),
));
if self.virtual_ipv4.len() > 0 {
cfg.set_ipv4(

View File

@ -131,6 +131,8 @@ pathfinding = "4.9.1"
# for encryption
boringtun = { version = "0.6.0" }
ring = { version = "0.16" }
bitflags = "2.5"
# for cli
tabled = "0.15.*"

View File

@ -164,7 +164,7 @@ message HandshakeRequest {
uint32 version = 3;
repeated string features = 4;
string network_name = 5;
string network_secret = 6;
bytes network_secret_digrest = 6;
}
message TaRpcPacket {

View File

@ -6,6 +6,8 @@ use std::{
use anyhow::Context;
use serde::{Deserialize, Serialize};
use crate::tunnel::generate_digest_from_str;
#[auto_impl::auto_impl(Box, &)]
pub trait ConfigLoader: Send + Sync {
fn get_id(&self) -> uuid::Uuid;
@ -52,17 +54,49 @@ pub trait ConfigLoader: Send + Sync {
fn dump(&self) -> String;
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]
pub type NetworkSecretDigest = [u8; 32];
#[derive(Debug, Clone, Deserialize, Serialize, Default)]
pub struct NetworkIdentity {
pub network_name: String,
pub network_secret: String,
pub network_secret: Option<String>,
#[serde(skip)]
pub network_secret_digest: Option<NetworkSecretDigest>,
}
impl PartialEq for NetworkIdentity {
fn eq(&self, other: &Self) -> bool {
if self.network_name != other.network_name {
return false;
}
if self.network_secret.is_some()
&& other.network_secret.is_some()
&& self.network_secret != other.network_secret
{
return false;
}
if self.network_secret_digest.is_some()
&& other.network_secret_digest.is_some()
&& self.network_secret_digest != other.network_secret_digest
{
return false;
}
return true;
}
}
impl NetworkIdentity {
pub fn new(network_name: String, network_secret: String) -> Self {
let mut network_secret_digest = [0u8; 32];
generate_digest_from_str(&network_name, &network_secret, &mut network_secret_digest);
NetworkIdentity {
network_name,
network_secret,
network_secret: Some(network_secret),
network_secret_digest: Some(network_secret_digest),
}
}
@ -106,6 +140,8 @@ pub struct VpnPortalConfig {
pub struct Flags {
#[derivative(Default(value = "\"tcp\".to_string()"))]
pub default_protocol: String,
#[derivative(Default(value = "true"))]
pub enable_encryption: bool,
}
#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)]

View File

@ -47,6 +47,9 @@ pub enum Error {
#[error("message decode error: {0}")]
MessageDecodeError(String),
#[error("secret key error: {0}")]
SecretKeyError(String),
}
pub type Result<T> = result::Result<T, Error>;

View File

@ -1,4 +1,8 @@
use std::sync::{Arc, Mutex};
use std::collections::hash_map::DefaultHasher;
use std::{
hash::Hasher,
sync::{Arc, Mutex},
};
use crate::rpc::PeerConnInfo;
use crossbeam::atomic::AtomicCell;
@ -203,6 +207,23 @@ impl GlobalCtx {
pub fn get_flags(&self) -> Flags {
self.config.get_flags()
}
pub fn get_128_key(&self) -> [u8; 16] {
let mut key = [0u8; 16];
let secret = self
.config
.get_network_identity()
.network_secret
.unwrap_or_default();
// fill key according to network secret
let mut hasher = DefaultHasher::new();
hasher.write(secret.as_bytes());
key[0..8].copy_from_slice(&hasher.finish().to_be_bytes());
hasher.write(&key[0..8]);
key[8..16].copy_from_slice(&hasher.finish().to_be_bytes());
hasher.write(&key[0..16]);
key
}
}
#[cfg(test)]

View File

@ -94,8 +94,10 @@ pub async fn create_connector_by_url(
let dst_addr =
crate::tunnels::check_scheme_and_get_socket_addr::<SocketAddr>(&url, "wg")?;
let nid = global_ctx.get_network_identity();
let wg_config =
WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret);
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
);
let mut connector = WgTunnelConnector::new(url, wg_config);
set_bind_addr_for_peer_connector(
&mut connector,

View File

@ -124,6 +124,14 @@ and the vpn client is in network of 10.14.14.0/24"
#[arg(long, help = "default protocol to use when connecting to peers")]
default_protocol: Option<String>,
#[arg(
short = 'u',
long,
help = "disable encryption for peers communication, default is false, must be same with peers",
default_value = "false"
)]
disable_encryption: bool,
#[arg(
long,
help = "use multi-thread runtime, default is single-thread",
@ -136,10 +144,10 @@ impl From<Cli> for TomlConfigLoader {
fn from(cli: Cli) -> Self {
let cfg = TomlConfigLoader::default();
cfg.set_inst_name(cli.instance_name.clone());
cfg.set_network_identity(NetworkIdentity {
network_name: cli.network_name.clone(),
network_secret: cli.network_secret.clone(),
});
cfg.set_network_identity(NetworkIdentity::new(
cli.network_name.clone(),
cli.network_secret.clone(),
));
cfg.set_netns(cli.net_ns.clone());
if let Some(ipv4) = &cli.ipv4 {
@ -254,11 +262,12 @@ impl From<Cli> for TomlConfigLoader {
});
}
let mut f = cfg.get_flags();
if cli.default_protocol.is_some() {
let mut f = cfg.get_flags();
f.default_protocol = cli.default_protocol.as_ref().unwrap().clone();
cfg.set_flags(f);
}
f.enable_encryption = !cli.disable_encryption;
cfg.set_flags(f);
cfg
}

View File

@ -30,8 +30,10 @@ pub fn get_listener_by_url(
"udp" => Box::new(UdpTunnelListener::new(l.clone())),
"wg" => {
let nid = ctx.get_network_identity();
let wg_config =
WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret);
let wg_config = WgConfig::new_from_network_identity(
&nid.network_name,
&nid.network_secret.unwrap_or_default(),
);
Box::new(WgTunnelListener::new(l.clone(), wg_config))
}
"quic" => Box::new(QUICTunnelListener::new(l.clone())),

View File

@ -13,7 +13,7 @@ use crate::{
},
tunnel::{
common::{reserve_buf, FramedWriter, TunnelWrapper, ZCPacketToBytes},
packet_def::{ZCPacket, ZCPacketType},
packet_def::{ZCPacket, ZCPacketType, TAIL_RESERVED_SIZE},
StreamItem, Tunnel, TunnelError,
},
};
@ -70,18 +70,17 @@ impl Stream for TunStream {
let ret = ready!(g.as_pin_mut().poll_read(cx, &mut buf));
let len = buf.filled().len();
unsafe { self_mut.cur_buf.advance_mut(len) };
if len == 0 {
return Poll::Ready(None);
}
unsafe { self_mut.cur_buf.advance_mut(len + TAIL_RESERVED_SIZE) };
let mut ret_buf = self_mut.cur_buf.split();
let cur_len = ret_buf.len();
ret_buf.truncate(cur_len - TAIL_RESERVED_SIZE);
match ret {
Ok(_) => {
if len == 0 {
return Poll::Ready(None);
}
Poll::Ready(Some(Ok(ZCPacket::new_from_buf(
self_mut.cur_buf.split(),
ZCPacketType::NIC,
))))
}
Ok(_) => Poll::Ready(Some(Ok(ZCPacket::new_from_buf(ret_buf, ZCPacketType::NIC)))),
Err(err) => {
println!("tun stream error: {:?}", err);
Poll::Ready(None)

View File

@ -0,0 +1,34 @@
use crate::tunnel::packet_def::ZCPacket;
pub mod ring_aes_gcm;
#[derive(thiserror::Error, Debug)]
pub enum Error {
#[error("packet is not encrypted")]
NotEcrypted,
#[error("packet is too short. len: {0}")]
PacketTooShort(usize),
#[error("decryption failed")]
DecryptionFailed,
#[error("encryption failed")]
EncryptionFailed,
#[error("invalid tag. tag: {0:?}")]
InvalidTag(Vec<u8>),
}
pub trait Encryptor: Send + Sync + 'static {
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error>;
fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error>;
}
pub struct NullCipher;
impl Encryptor for NullCipher {
fn encrypt(&self, _zc_packet: &mut ZCPacket) -> Result<(), Error> {
Ok(())
}
fn decrypt(&self, _zc_packet: &mut ZCPacket) -> Result<(), Error> {
Ok(())
}
}

View File

@ -0,0 +1,161 @@
use rand::RngCore;
use ring::aead::{self};
use ring::aead::{LessSafeKey, UnboundKey};
use zerocopy::{AsBytes, FromBytes};
use crate::tunnel::packet_def::{AesGcmTail, ZCPacket, AES_GCM_ENCRYPTION_RESERVED};
use super::{Encryptor, Error};
#[derive(Clone)]
pub struct AesGcmCipher {
pub(crate) cipher: AesGcmEnum,
}
pub enum AesGcmEnum {
AesGCM128(LessSafeKey, [u8; 16]),
AesGCM256(LessSafeKey, [u8; 32]),
}
impl Clone for AesGcmEnum {
fn clone(&self) -> Self {
match &self {
AesGcmEnum::AesGCM128(_, key) => {
let c =
LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, key.as_slice()).unwrap());
AesGcmEnum::AesGCM128(c, *key)
}
AesGcmEnum::AesGCM256(_, key) => {
let c =
LessSafeKey::new(UnboundKey::new(&aead::AES_256_GCM, key.as_slice()).unwrap());
AesGcmEnum::AesGCM256(c, *key)
}
}
}
}
impl AesGcmCipher {
pub fn new_128(key: [u8; 16]) -> Self {
let cipher = LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, &key).unwrap());
Self {
cipher: AesGcmEnum::AesGCM128(cipher, key),
}
}
pub fn new_256(key: [u8; 32]) -> Self {
let cipher = LessSafeKey::new(UnboundKey::new(&aead::AES_256_GCM, &key).unwrap());
Self {
cipher: AesGcmEnum::AesGCM256(cipher, key),
}
}
}
impl Encryptor for AesGcmCipher {
fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if !pm_header.is_encrypted() {
return Err(Error::NotEcrypted);
}
let payload_len = zc_packet.payload().len();
if payload_len < AES_GCM_ENCRYPTION_RESERVED {
return Err(Error::PacketTooShort(zc_packet.payload().len()));
}
let text_and_tag_len = payload_len - AES_GCM_ENCRYPTION_RESERVED + 16;
let aes_tail = AesGcmTail::ref_from_suffix(zc_packet.payload()).unwrap();
let nonce = aead::Nonce::assume_unique_for_key(aes_tail.nonce.clone());
let rs = match &self.cipher {
AesGcmEnum::AesGCM128(cipher, _) => cipher.open_in_place(
nonce,
aead::Aad::empty(),
&mut zc_packet.mut_payload()[..text_and_tag_len],
),
AesGcmEnum::AesGCM256(cipher, _) => cipher.open_in_place(
nonce,
aead::Aad::empty(),
&mut zc_packet.mut_payload()[..text_and_tag_len],
),
};
if let Err(_) = rs {
return Err(Error::DecryptionFailed);
}
let pm_header = zc_packet.mut_peer_manager_header().unwrap();
pm_header.set_encrypted(false);
let old_len = zc_packet.buf_len();
zc_packet
.mut_inner()
.truncate(old_len - AES_GCM_ENCRYPTION_RESERVED);
return Ok(());
}
fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> {
let pm_header = zc_packet.peer_manager_header().unwrap();
if pm_header.is_encrypted() {
tracing::warn!(?zc_packet, "packet is already encrypted");
return Ok(());
}
let mut tail = AesGcmTail::default();
rand::thread_rng().fill_bytes(&mut tail.nonce);
let nonce = aead::Nonce::assume_unique_for_key(tail.nonce.clone());
let rs = match &self.cipher {
AesGcmEnum::AesGCM128(cipher, _) => cipher.seal_in_place_separate_tag(
nonce,
aead::Aad::empty(),
zc_packet.mut_payload(),
),
AesGcmEnum::AesGCM256(cipher, _) => cipher.seal_in_place_separate_tag(
nonce,
aead::Aad::empty(),
zc_packet.mut_payload(),
),
};
return match rs {
Ok(tag) => {
let tag = tag.as_ref();
if tag.len() != 16 {
return Err(Error::InvalidTag(tag.to_vec()));
}
tail.tag.copy_from_slice(tag);
let pm_header = zc_packet.mut_peer_manager_header().unwrap();
pm_header.set_encrypted(true);
zc_packet.mut_inner().extend_from_slice(tail.as_bytes());
Ok(())
}
Err(_) => Err(Error::EncryptionFailed),
};
}
}
#[cfg(test)]
mod tests {
use crate::{
peers::encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor},
tunnel::packet_def::{ZCPacket, ZCPacketType, AES_GCM_ENCRYPTION_RESERVED},
};
#[test]
fn test_aes_gcm_cipher() {
let key = [0u8; 16];
let cipher = AesGcmCipher::new_128(key);
let text = b"1234567";
let mut packet = ZCPacket::new_with_payload(text);
packet.fill_peer_manager_hdr(0, 0, 0);
cipher.encrypt(&mut packet).unwrap();
assert_eq!(
packet.payload().len(),
text.len() + AES_GCM_ENCRYPTION_RESERVED
);
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), true);
cipher.decrypt(&mut packet).unwrap();
assert_eq!(packet.payload(), text);
assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), false);
}
}

View File

@ -141,6 +141,10 @@ impl ForeignNetworkClient {
self.get_next_hop(peer_id).is_some()
}
pub fn is_peer_public_node(&self, peer_id: &PeerId) -> bool {
self.peer_map.has_peer(*peer_id)
}
pub fn get_next_hop(&self, peer_id: PeerId) -> Option<PeerId> {
if self.peer_map.has_peer(peer_id) {
return Some(peer_id.clone());

View File

@ -212,8 +212,13 @@ impl ForeignNetworkManager {
peer_conn.get_network_identity().network_name.clone(),
);
if entry.network.network_secret != peer_conn.get_network_identity().network_secret {
return Err(anyhow::anyhow!("network secret not match").into());
if entry.network != peer_conn.get_network_identity() {
return Err(anyhow::anyhow!(
"network secret not match. exp: {:?} real: {:?}",
entry.network,
peer_conn.get_network_identity()
)
.into());
}
Ok(entry.peer_map.add_new_peer_conn(peer_conn).await)
@ -337,10 +342,10 @@ mod tests {
let (s, _r) = tokio::sync::mpsc::channel(1000);
let peer_mgr = Arc::new(PeerManager::new(
RouteAlgoType::Ospf,
get_mock_global_ctx_with_network(Some(NetworkIdentity {
network_name: network.to_string(),
network_secret: network.to_string(),
})),
get_mock_global_ctx_with_network(Some(NetworkIdentity::new(
network.to_string(),
network.to_string(),
))),
s,
));
replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown);

View File

@ -14,6 +14,8 @@ pub mod zc_peer_conn;
pub mod foreign_network_client;
pub mod foreign_network_manager;
pub mod encrypt;
#[cfg(test)]
pub mod tests;

View File

@ -4,6 +4,7 @@ use std::{
sync::{Arc, Weak},
};
use anyhow::Context;
use async_trait::async_trait;
use futures::StreamExt;
@ -31,6 +32,7 @@ use crate::{
};
use super::{
encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor, NullCipher},
foreign_network_client::ForeignNetworkClient,
foreign_network_manager::ForeignNetworkManager,
peer_map::PeerMap,
@ -49,6 +51,8 @@ struct RpcTransport {
packet_recv: Mutex<UnboundedReceiver<ZCPacket>>,
peer_rpc_tspt_sender: UnboundedSender<ZCPacket>,
encryptor: Arc<Box<dyn Encryptor>>,
}
#[async_trait::async_trait]
@ -57,7 +61,7 @@ impl PeerRpcManagerTransport for RpcTransport {
self.my_peer_id
}
async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
async fn send(&self, mut msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> {
let foreign_peers = self
.foreign_peers
.lock()
@ -75,8 +79,17 @@ impl PeerRpcManagerTransport for RpcTransport {
?self.my_peer_id,
"send msg to peer via gateway",
);
self.encryptor
.encrypt(&mut msg)
.with_context(|| "encrypt failed")?;
peers.send_msg_directly(msg, gateway_id).await
} else if foreign_peers.has_next_hop(dst_peer_id) {
if !foreign_peers.is_peer_public_node(&dst_peer_id) {
// do not encrypt for msg sending to public node
self.encryptor
.encrypt(&mut msg)
.with_context(|| "encrypt failed")?;
}
tracing::debug!(
?dst_peer_id,
?self.my_peer_id,
@ -134,6 +147,8 @@ pub struct PeerManager {
foreign_network_manager: Arc<ForeignNetworkManager>,
foreign_network_client: Arc<ForeignNetworkClient>,
encryptor: Arc<Box<dyn Encryptor>>,
}
impl Debug for PeerManager {
@ -161,6 +176,13 @@ impl PeerManager {
my_peer_id,
));
let encryptor: Arc<Box<dyn Encryptor>> =
Arc::new(if global_ctx.get_flags().enable_encryption {
Box::new(AesGcmCipher::new_128(global_ctx.get_128_key()))
} else {
Box::new(NullCipher)
});
// TODO: remove these because we have impl pipeline processor.
let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel();
let rpc_tspt = Arc::new(RpcTransport {
@ -169,6 +191,7 @@ impl PeerManager {
foreign_peers: Mutex::new(None),
packet_recv: Mutex::new(peer_rpc_tspt_recv),
peer_rpc_tspt_sender,
encryptor: encryptor.clone(),
});
let peer_rpc_mgr = Arc::new(PeerRpcManager::new(rpc_tspt.clone()));
@ -218,9 +241,20 @@ impl PeerManager {
foreign_network_manager,
foreign_network_client,
encryptor,
}
}
async fn add_new_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> {
if self.global_ctx.get_network_identity() != peer_conn.get_network_identity() {
return Err(Error::SecretKeyError(
"network identity not match".to_string(),
));
}
Ok(self.peers.add_new_peer_conn(peer_conn).await)
}
pub async fn add_client_tunnel(
&self,
tunnel: Box<dyn Tunnel>,
@ -229,8 +263,10 @@ impl PeerManager {
peer.do_handshake_as_client().await?;
let conn_id = peer.get_conn_id();
let peer_id = peer.get_peer_id();
if peer.get_network_identity() == self.global_ctx.get_network_identity() {
self.peers.add_new_peer_conn(peer).await;
if peer.get_network_identity().network_name
== self.global_ctx.get_network_identity().network_name
{
self.add_new_peer_conn(peer).await?;
} else {
self.foreign_network_client.add_new_peer_conn(peer).await;
}
@ -254,8 +290,10 @@ impl PeerManager {
tracing::info!("add tunnel as server start");
let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel);
peer.do_handshake_as_server().await?;
if peer.get_network_identity() == self.global_ctx.get_network_identity() {
self.peers.add_new_peer_conn(peer).await;
if peer.get_network_identity().network_name
== self.global_ctx.get_network_identity().network_name
{
self.add_new_peer_conn(peer).await?;
} else {
self.foreign_network_manager.add_peer_conn(peer).await?;
}
@ -268,9 +306,10 @@ impl PeerManager {
let my_peer_id = self.my_peer_id;
let peers = self.peers.clone();
let pipe_line = self.peer_packet_process_pipeline.clone();
let encryptor = self.encryptor.clone();
self.tasks.lock().await.spawn(async move {
log::trace!("start_peer_recv");
while let Some(ret) = recv.next().await {
while let Some(mut ret) = recv.next().await {
let Some(hdr) = ret.peer_manager_header() else {
tracing::warn!(?ret, "invalid packet, skip");
continue;
@ -285,6 +324,13 @@ impl PeerManager {
tracing::error!(?ret, ?to_peer_id, ?from_peer_id, "forward packet error");
}
} else {
if let Err(e) = encryptor
.decrypt(&mut ret)
.with_context(|| "decrypt failed")
{
tracing::error!(?e, "decrypt failed");
}
let mut processed = false;
let mut zc_packet = Some(ret);
let mut idx = 0;
@ -490,7 +536,12 @@ impl PeerManager {
return Ok(());
}
msg.fill_peer_manager_hdr(self.my_peer_id, 0, packet::PacketType::Data as u8);
self.run_nic_packet_process_pipeline(&mut msg).await;
self.encryptor
.encrypt(&mut msg)
.with_context(|| "encrypt failed")?;
let mut errs: Vec<Error> = vec![];
let mut msg = Some(msg);
@ -503,8 +554,10 @@ impl PeerManager {
};
let peer_id = &dst_peers[i];
msg.fill_peer_manager_hdr(self.my_peer_id, *peer_id, packet::PacketType::Data as u8);
msg.mut_peer_manager_header()
.unwrap()
.to_peer_id
.set(*peer_id);
if let Some(gateway) = self.peers.get_gateway_peer_id(*peer_id).await {
if let Err(e) = self.peers.send_msg_directly(msg, gateway).await {

View File

@ -24,8 +24,9 @@ use zerocopy::AsBytes;
use crate::{
common::{
config::{NetworkIdentity, NetworkSecretDigest},
error::Error,
global_ctx::{ArcGlobalCtx, NetworkIdentity},
global_ctx::ArcGlobalCtx,
PeerId,
},
peers::packet::PacketType,
@ -129,10 +130,17 @@ impl PeerConn {
));
};
let rsp = rsp?;
let rsp = HandshakeRequest::decode(rsp.payload())
.map_err(|e| Error::WaitRespError(format!("decode handshake response error: {:?}", e)));
let rsp = HandshakeRequest::decode(rsp.payload()).map_err(|e| {
Error::WaitRespError(format!("decode handshake response error: {:?}", e))
})?;
return Ok(rsp.unwrap());
if rsp.network_secret_digrest.len() != std::mem::size_of::<NetworkSecretDigest>() {
return Err(Error::WaitRespError(
"invalid network secret digest".to_owned(),
));
}
return Ok(rsp);
}
async fn wait_handshake_loop(&mut self) -> Result<HandshakeRequest, Error> {
@ -152,14 +160,16 @@ impl PeerConn {
async fn send_handshake(&mut self) -> Result<(), Error> {
let network = self.global_ctx.get_network_identity();
let req = HandshakeRequest {
let mut req = HandshakeRequest {
magic: MAGIC,
my_peer_id: self.my_peer_id,
version: VERSION,
features: Vec::new(),
network_name: network.network_name.clone(),
network_secret: network.network_secret.clone(),
..Default::default()
};
req.network_secret_digrest
.extend_from_slice(&network.network_secret_digest.unwrap_or_default());
let hs_req = req.encode_to_vec();
let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes());
@ -297,10 +307,16 @@ impl PeerConn {
pub fn get_network_identity(&self) -> NetworkIdentity {
let info = self.info.as_ref().unwrap();
NetworkIdentity {
let mut ret = NetworkIdentity {
network_name: info.network_name.clone(),
network_secret: info.network_secret.clone(),
}
..Default::default()
};
ret.network_secret_digest = Some([0u8; 32]);
ret.network_secret_digest
.as_mut()
.unwrap()
.copy_from_slice(&info.network_secret_digrest);
ret
}
pub fn set_close_event_sender(&mut self, sender: mpsc::Sender<PeerConnId>) {

View File

@ -87,7 +87,11 @@ pub async fn init_three_node(proto: &str) -> Vec<Instance> {
"wg://10.1.1.1:11011".parse().unwrap(),
WgConfig::new_from_network_identity(
&inst1.get_global_ctx().get_network_identity().network_name,
&inst1.get_global_ctx().get_network_identity().network_secret,
&inst1
.get_global_ctx()
.get_network_identity()
.network_secret
.unwrap_or_default(),
),
));
}
@ -243,7 +247,11 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str
"wg://10.1.2.3:11011".parse().unwrap(),
WgConfig::new_from_network_identity(
&inst4.get_global_ctx().get_network_identity().network_name,
&inst4.get_global_ctx().get_network_identity().network_secret,
&inst4
.get_global_ctx()
.get_network_identity()
.network_secret
.unwrap_or_default(),
),
));
} else {
@ -376,10 +384,8 @@ pub async fn foreign_network_forward_nic_data() {
prepare_linux_namespaces();
let center_node_config = get_inst_config("inst1", Some("net_a"), "10.144.144.1");
center_node_config.set_network_identity(NetworkIdentity {
network_name: "center".to_string(),
network_secret: "".to_string(),
});
center_node_config
.set_network_identity(NetworkIdentity::new("center".to_string(), "".to_string()));
let mut center_inst = Instance::new(center_node_config);
let mut inst1 = Instance::new(get_inst_config("inst1", Some("net_b"), "10.144.145.1"));
@ -450,7 +456,7 @@ fn run_wireguard_client(
log::info!("endpoint");
// Peer endpoint and interval
peer.endpoint = Some(endpoint);
peer.persistent_keepalive_interval = Some(25);
peer.persistent_keepalive_interval = Some(1);
for ip in allowed_ips {
peer.allowed_ips.push(IpAddrMask::from_str(ip.as_str())?);
}
@ -502,7 +508,7 @@ pub async fn wireguard_vpn_portal() {
// ping other node in network
wait_for_condition(
|| async { ping_test("net_d", "10.144.144.1").await },
Duration::from_secs(5),
Duration::from_secs(5000),
)
.await;
wait_for_condition(
@ -514,7 +520,7 @@ pub async fn wireguard_vpn_portal() {
// ping portal node
wait_for_condition(
|| async { ping_test("net_d", "10.144.144.3").await },
Duration::from_secs(500),
Duration::from_secs(5),
)
.await;
}

View File

@ -1,3 +1,5 @@
use std::collections::hash_map::DefaultHasher;
use std::hash::Hasher;
use std::{net::SocketAddr, pin::Pin, sync::Arc};
use async_trait::async_trait;
@ -197,3 +199,17 @@ impl TunnelUrl {
})
}
}
pub fn generate_digest_from_str(str1: &str, str2: &str, digest: &mut [u8]) {
let mut hasher = DefaultHasher::new();
hasher.write(str1.as_bytes());
hasher.write(str2.as_bytes());
assert_eq!(digest.len() % 8, 0, "digest length must be multiple of 8");
let shard_count = digest.len() / 8;
for i in 0..shard_count {
digest[i * 8..(i + 1) * 8].copy_from_slice(&hasher.finish().to_be_bytes());
hasher.write(&digest[..(i + 1) * 8]);
}
}

View File

@ -56,16 +56,53 @@ pub enum PacketType {
Route = 7,
}
bitflags::bitflags! {
struct PeerManagerHeaderFlags: u8 {
const ENCRYPTED = 0b0000_0001;
}
}
#[repr(C, packed)]
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
pub struct PeerManagerHeader {
pub from_peer_id: U32<DefaultEndian>,
pub to_peer_id: U32<DefaultEndian>,
pub packet_type: u8,
pub flags: u8,
reserved: U16<DefaultEndian>,
pub len: U32<DefaultEndian>,
}
pub const PEER_MANAGER_HEADER_SIZE: usize = std::mem::size_of::<PeerManagerHeader>();
impl PeerManagerHeader {
pub fn is_encrypted(&self) -> bool {
PeerManagerHeaderFlags::from_bits(self.flags)
.unwrap()
.contains(PeerManagerHeaderFlags::ENCRYPTED)
}
pub fn set_encrypted(&mut self, encrypted: bool) {
let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap();
if encrypted {
flags.insert(PeerManagerHeaderFlags::ENCRYPTED);
} else {
flags.remove(PeerManagerHeaderFlags::ENCRYPTED);
}
self.flags = flags.bits();
}
}
// reserve the space for aes tag and nonce
#[repr(C, packed)]
#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)]
pub struct AesGcmTail {
pub tag: [u8; 16],
pub nonce: [u8; 12],
}
pub const AES_GCM_ENCRYPTION_RESERVED: usize = std::mem::size_of::<AesGcmTail>();
pub const TAIL_RESERVED_SIZE: usize = AES_GCM_ENCRYPTION_RESERVED;
const fn max(a: usize, b: usize) -> usize {
[a, b][(a < b) as usize]
}
@ -308,6 +345,7 @@ impl ZCPacket {
hdr.from_peer_id.set(from_peer_id);
hdr.to_peer_id.set(to_peer_id);
hdr.packet_type = packet_type;
hdr.flags = 0;
hdr.len.set(payload_len as u32);
}

View File

@ -1,7 +1,5 @@
use std::{
collections::hash_map::DefaultHasher,
fmt::{Debug, Formatter},
hash::Hasher,
net::SocketAddr,
pin::Pin,
sync::{atomic::AtomicBool, Arc},
@ -33,6 +31,7 @@ use crate::{
use super::{
check_scheme_and_get_socket_addr,
common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures},
generate_digest_from_str,
packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE},
ring::create_ring_tunnel_pair,
Tunnel, TunnelError, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream,
@ -62,16 +61,7 @@ pub struct WgConfig {
impl WgConfig {
pub fn new_from_network_identity(network_name: &str, network_secret: &str) -> Self {
let mut my_sec = [0u8; 32];
let mut hasher = DefaultHasher::new();
hasher.write(network_name.as_bytes());
hasher.write(network_secret.as_bytes());
my_sec[0..8].copy_from_slice(&hasher.finish().to_be_bytes());
hasher.write(&my_sec[0..8]);
my_sec[8..16].copy_from_slice(&hasher.finish().to_be_bytes());
hasher.write(&my_sec[0..16]);
my_sec[16..24].copy_from_slice(&hasher.finish().to_be_bytes());
hasher.write(&my_sec[0..24]);
my_sec[24..32].copy_from_slice(&hasher.finish().to_be_bytes());
generate_digest_from_str(network_name, network_secret, &mut my_sec);
let my_secret_key = StaticSecret::from(my_sec);
let my_public_key = PublicKey::from(&my_secret_key);
@ -491,6 +481,7 @@ impl WgTunnelListener {
let mut buf = vec![0u8; MAX_PACKET];
loop {
tracing::info!("Waiting for incoming UDP packet");
let Ok((n, addr)) = socket.recv_from(&mut buf).await else {
tracing::error!("Failed to receive from UDP socket");
break;

View File

@ -10,6 +10,7 @@ use dashmap::DashMap;
use futures::StreamExt;
use pnet::packet::ipv4::Ipv4Packet;
use tokio::task::JoinSet;
use tracing::Level;
use crate::{
common::{
@ -31,7 +32,11 @@ use super::VpnPortal;
type WgPeerIpTable = Arc<DashMap<Ipv4Addr, Arc<ClientEntry>>>;
pub(crate) fn get_wg_config_for_portal(nid: &NetworkIdentity) -> WgConfig {
let key_seed = format!("{}{}", nid.network_name, nid.network_secret);
let key_seed = format!(
"{}{}",
nid.network_name,
nid.network_secret.as_ref().unwrap_or(&"".to_string())
);
WgConfig::new_for_portal(&key_seed, &key_seed)
}
@ -166,6 +171,7 @@ impl WireGuardImpl {
.await;
}
#[tracing::instrument(skip(self), err(level = Level::WARN))]
async fn start(&self) -> anyhow::Result<()> {
let mut l = WgTunnelListener::new(
format!("wg://{}", self.listenr_addr).parse().unwrap(),