From fcc73159b3c0d77cb97ac853593ca37e08ac44af Mon Sep 17 00:00:00 2001 From: "Sijie.Sun" Date: Sat, 27 Apr 2024 13:44:59 +0800 Subject: [PATCH] support encryption (#60) --- Cargo.lock | 14 +- easytier-gui/src-tauri/src/main.rs | 8 +- easytier/Cargo.toml | 2 + easytier/proto/cli.proto | 2 +- easytier/src/common/config.rs | 42 ++++- easytier/src/common/error.rs | 3 + easytier/src/common/global_ctx.rs | 23 ++- easytier/src/connector/mod.rs | 6 +- easytier/src/easytier-core.rs | 21 ++- easytier/src/instance/listeners.rs | 6 +- easytier/src/instance/virtual_nic.rs | 21 ++- easytier/src/peers/encrypt/mod.rs | 34 ++++ easytier/src/peers/encrypt/ring_aes_gcm.rs | 161 ++++++++++++++++++ easytier/src/peers/foreign_network_client.rs | 4 + easytier/src/peers/foreign_network_manager.rs | 17 +- easytier/src/peers/mod.rs | 2 + easytier/src/peers/peer_manager.rs | 69 +++++++- easytier/src/peers/zc_peer_conn.rs | 34 +++- easytier/src/tests/three_node.rs | 24 ++- easytier/src/tunnel/mod.rs | 16 ++ easytier/src/tunnel/packet_def.rs | 38 +++++ easytier/src/tunnel/wireguard.rs | 15 +- easytier/src/vpn_portal/wireguard.rs | 8 +- 23 files changed, 489 insertions(+), 81 deletions(-) create mode 100644 easytier/src/peers/encrypt/mod.rs create mode 100644 easytier/src/peers/encrypt/ring_aes_gcm.rs diff --git a/Cargo.lock b/Cargo.lock index 205f60c..4b742ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -344,9 +344,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "bitvec" @@ -1288,6 +1288,7 @@ dependencies = [ "atomicbox", "auto_impl", "base64 0.21.7", + "bitflags 2.5.0", "boringtun", "bytecodec", "byteorder", @@ -1319,6 +1320,7 @@ dependencies = [ "rand 0.8.5", "rcgen", "reqwest", + "ring 0.16.20", "rkyv", "rstest", "rustls", @@ -2561,7 +2563,7 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.5.0", "libc", ] @@ -2897,7 +2899,7 @@ version = "0.27.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2eb04e9c688eff1c89d72b407f168cf79bb9e867a9d3323ed6c01519eb9cc053" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.5.0", "cfg-if", "libc", "memoffset", @@ -3039,7 +3041,7 @@ version = "0.10.64" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95a0481286a310808298130d22dd1fef0fa571e05a8f44ec801801e84b216b1f" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.5.0", "cfg-if", "foreign-types", "libc", @@ -4157,7 +4159,7 @@ version = "0.38.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949" dependencies = [ - "bitflags 2.4.1", + "bitflags 2.5.0", "errno", "libc", "linux-raw-sys", diff --git a/easytier-gui/src-tauri/src/main.rs b/easytier-gui/src-tauri/src/main.rs index 88c0a18..ad4a472 100644 --- a/easytier-gui/src-tauri/src/main.rs +++ b/easytier-gui/src-tauri/src/main.rs @@ -71,10 +71,10 @@ impl NetworkConfig { .with_context(|| format!("failed to parse instance id: {}", self.instance_id))?, ); cfg.set_inst_name(self.network_name.clone()); - cfg.set_network_identity(NetworkIdentity { - network_name: self.network_name.clone(), - network_secret: self.network_secret.clone(), - }); + cfg.set_network_identity(NetworkIdentity::new( + self.network_name.clone(), + self.network_secret.clone(), + )); if self.virtual_ipv4.len() > 0 { cfg.set_ipv4( diff --git a/easytier/Cargo.toml b/easytier/Cargo.toml index acad937..bfd3476 100644 --- a/easytier/Cargo.toml +++ b/easytier/Cargo.toml @@ -131,6 +131,8 @@ pathfinding = "4.9.1" # for encryption boringtun = { version = "0.6.0" } +ring = { version = "0.16" } +bitflags = "2.5" # for cli tabled = "0.15.*" diff --git a/easytier/proto/cli.proto b/easytier/proto/cli.proto index 6d3267e..bbc08e0 100644 --- a/easytier/proto/cli.proto +++ b/easytier/proto/cli.proto @@ -164,7 +164,7 @@ message HandshakeRequest { uint32 version = 3; repeated string features = 4; string network_name = 5; - string network_secret = 6; + bytes network_secret_digrest = 6; } message TaRpcPacket { diff --git a/easytier/src/common/config.rs b/easytier/src/common/config.rs index b4cae4d..d170638 100644 --- a/easytier/src/common/config.rs +++ b/easytier/src/common/config.rs @@ -6,6 +6,8 @@ use std::{ use anyhow::Context; use serde::{Deserialize, Serialize}; +use crate::tunnel::generate_digest_from_str; + #[auto_impl::auto_impl(Box, &)] pub trait ConfigLoader: Send + Sync { fn get_id(&self) -> uuid::Uuid; @@ -52,17 +54,49 @@ pub trait ConfigLoader: Send + Sync { fn dump(&self) -> String; } -#[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] +pub type NetworkSecretDigest = [u8; 32]; + +#[derive(Debug, Clone, Deserialize, Serialize, Default)] pub struct NetworkIdentity { pub network_name: String, - pub network_secret: String, + pub network_secret: Option, + #[serde(skip)] + pub network_secret_digest: Option, +} + +impl PartialEq for NetworkIdentity { + fn eq(&self, other: &Self) -> bool { + if self.network_name != other.network_name { + return false; + } + + if self.network_secret.is_some() + && other.network_secret.is_some() + && self.network_secret != other.network_secret + { + return false; + } + + if self.network_secret_digest.is_some() + && other.network_secret_digest.is_some() + && self.network_secret_digest != other.network_secret_digest + { + return false; + } + + return true; + } } impl NetworkIdentity { pub fn new(network_name: String, network_secret: String) -> Self { + let mut network_secret_digest = [0u8; 32]; + generate_digest_from_str(&network_name, &network_secret, &mut network_secret_digest); + NetworkIdentity { network_name, - network_secret, + network_secret: Some(network_secret), + network_secret_digest: Some(network_secret_digest), } } @@ -106,6 +140,8 @@ pub struct VpnPortalConfig { pub struct Flags { #[derivative(Default(value = "\"tcp\".to_string()"))] pub default_protocol: String, + #[derivative(Default(value = "true"))] + pub enable_encryption: bool, } #[derive(Debug, Clone, Deserialize, Serialize, PartialEq)] diff --git a/easytier/src/common/error.rs b/easytier/src/common/error.rs index efa7b70..17c7574 100644 --- a/easytier/src/common/error.rs +++ b/easytier/src/common/error.rs @@ -47,6 +47,9 @@ pub enum Error { #[error("message decode error: {0}")] MessageDecodeError(String), + + #[error("secret key error: {0}")] + SecretKeyError(String), } pub type Result = result::Result; diff --git a/easytier/src/common/global_ctx.rs b/easytier/src/common/global_ctx.rs index f4a62d5..dcd71a6 100644 --- a/easytier/src/common/global_ctx.rs +++ b/easytier/src/common/global_ctx.rs @@ -1,4 +1,8 @@ -use std::sync::{Arc, Mutex}; +use std::collections::hash_map::DefaultHasher; +use std::{ + hash::Hasher, + sync::{Arc, Mutex}, +}; use crate::rpc::PeerConnInfo; use crossbeam::atomic::AtomicCell; @@ -203,6 +207,23 @@ impl GlobalCtx { pub fn get_flags(&self) -> Flags { self.config.get_flags() } + + pub fn get_128_key(&self) -> [u8; 16] { + let mut key = [0u8; 16]; + let secret = self + .config + .get_network_identity() + .network_secret + .unwrap_or_default(); + // fill key according to network secret + let mut hasher = DefaultHasher::new(); + hasher.write(secret.as_bytes()); + key[0..8].copy_from_slice(&hasher.finish().to_be_bytes()); + hasher.write(&key[0..8]); + key[8..16].copy_from_slice(&hasher.finish().to_be_bytes()); + hasher.write(&key[0..16]); + key + } } #[cfg(test)] diff --git a/easytier/src/connector/mod.rs b/easytier/src/connector/mod.rs index 3b61f26..81863bb 100644 --- a/easytier/src/connector/mod.rs +++ b/easytier/src/connector/mod.rs @@ -94,8 +94,10 @@ pub async fn create_connector_by_url( let dst_addr = crate::tunnels::check_scheme_and_get_socket_addr::(&url, "wg")?; let nid = global_ctx.get_network_identity(); - let wg_config = - WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret); + let wg_config = WgConfig::new_from_network_identity( + &nid.network_name, + &nid.network_secret.unwrap_or_default(), + ); let mut connector = WgTunnelConnector::new(url, wg_config); set_bind_addr_for_peer_connector( &mut connector, diff --git a/easytier/src/easytier-core.rs b/easytier/src/easytier-core.rs index 4c56410..b1b0c52 100644 --- a/easytier/src/easytier-core.rs +++ b/easytier/src/easytier-core.rs @@ -124,6 +124,14 @@ and the vpn client is in network of 10.14.14.0/24" #[arg(long, help = "default protocol to use when connecting to peers")] default_protocol: Option, + #[arg( + short = 'u', + long, + help = "disable encryption for peers communication, default is false, must be same with peers", + default_value = "false" + )] + disable_encryption: bool, + #[arg( long, help = "use multi-thread runtime, default is single-thread", @@ -136,10 +144,10 @@ impl From for TomlConfigLoader { fn from(cli: Cli) -> Self { let cfg = TomlConfigLoader::default(); cfg.set_inst_name(cli.instance_name.clone()); - cfg.set_network_identity(NetworkIdentity { - network_name: cli.network_name.clone(), - network_secret: cli.network_secret.clone(), - }); + cfg.set_network_identity(NetworkIdentity::new( + cli.network_name.clone(), + cli.network_secret.clone(), + )); cfg.set_netns(cli.net_ns.clone()); if let Some(ipv4) = &cli.ipv4 { @@ -254,11 +262,12 @@ impl From for TomlConfigLoader { }); } + let mut f = cfg.get_flags(); if cli.default_protocol.is_some() { - let mut f = cfg.get_flags(); f.default_protocol = cli.default_protocol.as_ref().unwrap().clone(); - cfg.set_flags(f); } + f.enable_encryption = !cli.disable_encryption; + cfg.set_flags(f); cfg } diff --git a/easytier/src/instance/listeners.rs b/easytier/src/instance/listeners.rs index 8148ae9..afb6fdd 100644 --- a/easytier/src/instance/listeners.rs +++ b/easytier/src/instance/listeners.rs @@ -30,8 +30,10 @@ pub fn get_listener_by_url( "udp" => Box::new(UdpTunnelListener::new(l.clone())), "wg" => { let nid = ctx.get_network_identity(); - let wg_config = - WgConfig::new_from_network_identity(&nid.network_name, &nid.network_secret); + let wg_config = WgConfig::new_from_network_identity( + &nid.network_name, + &nid.network_secret.unwrap_or_default(), + ); Box::new(WgTunnelListener::new(l.clone(), wg_config)) } "quic" => Box::new(QUICTunnelListener::new(l.clone())), diff --git a/easytier/src/instance/virtual_nic.rs b/easytier/src/instance/virtual_nic.rs index ad6b27d..75e7d3f 100644 --- a/easytier/src/instance/virtual_nic.rs +++ b/easytier/src/instance/virtual_nic.rs @@ -13,7 +13,7 @@ use crate::{ }, tunnel::{ common::{reserve_buf, FramedWriter, TunnelWrapper, ZCPacketToBytes}, - packet_def::{ZCPacket, ZCPacketType}, + packet_def::{ZCPacket, ZCPacketType, TAIL_RESERVED_SIZE}, StreamItem, Tunnel, TunnelError, }, }; @@ -70,18 +70,17 @@ impl Stream for TunStream { let ret = ready!(g.as_pin_mut().poll_read(cx, &mut buf)); let len = buf.filled().len(); - unsafe { self_mut.cur_buf.advance_mut(len) }; + if len == 0 { + return Poll::Ready(None); + } + unsafe { self_mut.cur_buf.advance_mut(len + TAIL_RESERVED_SIZE) }; + + let mut ret_buf = self_mut.cur_buf.split(); + let cur_len = ret_buf.len(); + ret_buf.truncate(cur_len - TAIL_RESERVED_SIZE); match ret { - Ok(_) => { - if len == 0 { - return Poll::Ready(None); - } - Poll::Ready(Some(Ok(ZCPacket::new_from_buf( - self_mut.cur_buf.split(), - ZCPacketType::NIC, - )))) - } + Ok(_) => Poll::Ready(Some(Ok(ZCPacket::new_from_buf(ret_buf, ZCPacketType::NIC)))), Err(err) => { println!("tun stream error: {:?}", err); Poll::Ready(None) diff --git a/easytier/src/peers/encrypt/mod.rs b/easytier/src/peers/encrypt/mod.rs new file mode 100644 index 0000000..1ff923f --- /dev/null +++ b/easytier/src/peers/encrypt/mod.rs @@ -0,0 +1,34 @@ +use crate::tunnel::packet_def::ZCPacket; + +pub mod ring_aes_gcm; + +#[derive(thiserror::Error, Debug)] +pub enum Error { + #[error("packet is not encrypted")] + NotEcrypted, + #[error("packet is too short. len: {0}")] + PacketTooShort(usize), + #[error("decryption failed")] + DecryptionFailed, + #[error("encryption failed")] + EncryptionFailed, + #[error("invalid tag. tag: {0:?}")] + InvalidTag(Vec), +} + +pub trait Encryptor: Send + Sync + 'static { + fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error>; + fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error>; +} + +pub struct NullCipher; + +impl Encryptor for NullCipher { + fn encrypt(&self, _zc_packet: &mut ZCPacket) -> Result<(), Error> { + Ok(()) + } + + fn decrypt(&self, _zc_packet: &mut ZCPacket) -> Result<(), Error> { + Ok(()) + } +} diff --git a/easytier/src/peers/encrypt/ring_aes_gcm.rs b/easytier/src/peers/encrypt/ring_aes_gcm.rs new file mode 100644 index 0000000..059a716 --- /dev/null +++ b/easytier/src/peers/encrypt/ring_aes_gcm.rs @@ -0,0 +1,161 @@ +use rand::RngCore; +use ring::aead::{self}; +use ring::aead::{LessSafeKey, UnboundKey}; +use zerocopy::{AsBytes, FromBytes}; + +use crate::tunnel::packet_def::{AesGcmTail, ZCPacket, AES_GCM_ENCRYPTION_RESERVED}; + +use super::{Encryptor, Error}; + +#[derive(Clone)] +pub struct AesGcmCipher { + pub(crate) cipher: AesGcmEnum, +} + +pub enum AesGcmEnum { + AesGCM128(LessSafeKey, [u8; 16]), + AesGCM256(LessSafeKey, [u8; 32]), +} + +impl Clone for AesGcmEnum { + fn clone(&self) -> Self { + match &self { + AesGcmEnum::AesGCM128(_, key) => { + let c = + LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, key.as_slice()).unwrap()); + AesGcmEnum::AesGCM128(c, *key) + } + AesGcmEnum::AesGCM256(_, key) => { + let c = + LessSafeKey::new(UnboundKey::new(&aead::AES_256_GCM, key.as_slice()).unwrap()); + AesGcmEnum::AesGCM256(c, *key) + } + } + } +} + +impl AesGcmCipher { + pub fn new_128(key: [u8; 16]) -> Self { + let cipher = LessSafeKey::new(UnboundKey::new(&aead::AES_128_GCM, &key).unwrap()); + Self { + cipher: AesGcmEnum::AesGCM128(cipher, key), + } + } + + pub fn new_256(key: [u8; 32]) -> Self { + let cipher = LessSafeKey::new(UnboundKey::new(&aead::AES_256_GCM, &key).unwrap()); + Self { + cipher: AesGcmEnum::AesGCM256(cipher, key), + } + } +} + +impl Encryptor for AesGcmCipher { + fn decrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> { + let pm_header = zc_packet.peer_manager_header().unwrap(); + if !pm_header.is_encrypted() { + return Err(Error::NotEcrypted); + } + + let payload_len = zc_packet.payload().len(); + if payload_len < AES_GCM_ENCRYPTION_RESERVED { + return Err(Error::PacketTooShort(zc_packet.payload().len())); + } + + let text_and_tag_len = payload_len - AES_GCM_ENCRYPTION_RESERVED + 16; + + let aes_tail = AesGcmTail::ref_from_suffix(zc_packet.payload()).unwrap(); + let nonce = aead::Nonce::assume_unique_for_key(aes_tail.nonce.clone()); + + let rs = match &self.cipher { + AesGcmEnum::AesGCM128(cipher, _) => cipher.open_in_place( + nonce, + aead::Aad::empty(), + &mut zc_packet.mut_payload()[..text_and_tag_len], + ), + AesGcmEnum::AesGCM256(cipher, _) => cipher.open_in_place( + nonce, + aead::Aad::empty(), + &mut zc_packet.mut_payload()[..text_and_tag_len], + ), + }; + if let Err(_) = rs { + return Err(Error::DecryptionFailed); + } + + let pm_header = zc_packet.mut_peer_manager_header().unwrap(); + pm_header.set_encrypted(false); + let old_len = zc_packet.buf_len(); + zc_packet + .mut_inner() + .truncate(old_len - AES_GCM_ENCRYPTION_RESERVED); + return Ok(()); + } + + fn encrypt(&self, zc_packet: &mut ZCPacket) -> Result<(), Error> { + let pm_header = zc_packet.peer_manager_header().unwrap(); + if pm_header.is_encrypted() { + tracing::warn!(?zc_packet, "packet is already encrypted"); + return Ok(()); + } + + let mut tail = AesGcmTail::default(); + rand::thread_rng().fill_bytes(&mut tail.nonce); + let nonce = aead::Nonce::assume_unique_for_key(tail.nonce.clone()); + + let rs = match &self.cipher { + AesGcmEnum::AesGCM128(cipher, _) => cipher.seal_in_place_separate_tag( + nonce, + aead::Aad::empty(), + zc_packet.mut_payload(), + ), + AesGcmEnum::AesGCM256(cipher, _) => cipher.seal_in_place_separate_tag( + nonce, + aead::Aad::empty(), + zc_packet.mut_payload(), + ), + }; + return match rs { + Ok(tag) => { + let tag = tag.as_ref(); + if tag.len() != 16 { + return Err(Error::InvalidTag(tag.to_vec())); + } + tail.tag.copy_from_slice(tag); + + let pm_header = zc_packet.mut_peer_manager_header().unwrap(); + pm_header.set_encrypted(true); + zc_packet.mut_inner().extend_from_slice(tail.as_bytes()); + Ok(()) + } + Err(_) => Err(Error::EncryptionFailed), + }; + } +} + +#[cfg(test)] +mod tests { + use crate::{ + peers::encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor}, + tunnel::packet_def::{ZCPacket, ZCPacketType, AES_GCM_ENCRYPTION_RESERVED}, + }; + + #[test] + fn test_aes_gcm_cipher() { + let key = [0u8; 16]; + let cipher = AesGcmCipher::new_128(key); + let text = b"1234567"; + let mut packet = ZCPacket::new_with_payload(text); + packet.fill_peer_manager_hdr(0, 0, 0); + cipher.encrypt(&mut packet).unwrap(); + assert_eq!( + packet.payload().len(), + text.len() + AES_GCM_ENCRYPTION_RESERVED + ); + assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), true); + + cipher.decrypt(&mut packet).unwrap(); + assert_eq!(packet.payload(), text); + assert_eq!(packet.peer_manager_header().unwrap().is_encrypted(), false); + } +} diff --git a/easytier/src/peers/foreign_network_client.rs b/easytier/src/peers/foreign_network_client.rs index b276363..cff5008 100644 --- a/easytier/src/peers/foreign_network_client.rs +++ b/easytier/src/peers/foreign_network_client.rs @@ -141,6 +141,10 @@ impl ForeignNetworkClient { self.get_next_hop(peer_id).is_some() } + pub fn is_peer_public_node(&self, peer_id: &PeerId) -> bool { + self.peer_map.has_peer(*peer_id) + } + pub fn get_next_hop(&self, peer_id: PeerId) -> Option { if self.peer_map.has_peer(peer_id) { return Some(peer_id.clone()); diff --git a/easytier/src/peers/foreign_network_manager.rs b/easytier/src/peers/foreign_network_manager.rs index 6ceecf9..fbcc579 100644 --- a/easytier/src/peers/foreign_network_manager.rs +++ b/easytier/src/peers/foreign_network_manager.rs @@ -212,8 +212,13 @@ impl ForeignNetworkManager { peer_conn.get_network_identity().network_name.clone(), ); - if entry.network.network_secret != peer_conn.get_network_identity().network_secret { - return Err(anyhow::anyhow!("network secret not match").into()); + if entry.network != peer_conn.get_network_identity() { + return Err(anyhow::anyhow!( + "network secret not match. exp: {:?} real: {:?}", + entry.network, + peer_conn.get_network_identity() + ) + .into()); } Ok(entry.peer_map.add_new_peer_conn(peer_conn).await) @@ -337,10 +342,10 @@ mod tests { let (s, _r) = tokio::sync::mpsc::channel(1000); let peer_mgr = Arc::new(PeerManager::new( RouteAlgoType::Ospf, - get_mock_global_ctx_with_network(Some(NetworkIdentity { - network_name: network.to_string(), - network_secret: network.to_string(), - })), + get_mock_global_ctx_with_network(Some(NetworkIdentity::new( + network.to_string(), + network.to_string(), + ))), s, )); replace_stun_info_collector(peer_mgr.clone(), NatType::Unknown); diff --git a/easytier/src/peers/mod.rs b/easytier/src/peers/mod.rs index 0cf06a8..ff6d497 100644 --- a/easytier/src/peers/mod.rs +++ b/easytier/src/peers/mod.rs @@ -14,6 +14,8 @@ pub mod zc_peer_conn; pub mod foreign_network_client; pub mod foreign_network_manager; +pub mod encrypt; + #[cfg(test)] pub mod tests; diff --git a/easytier/src/peers/peer_manager.rs b/easytier/src/peers/peer_manager.rs index d659181..267a74d 100644 --- a/easytier/src/peers/peer_manager.rs +++ b/easytier/src/peers/peer_manager.rs @@ -4,6 +4,7 @@ use std::{ sync::{Arc, Weak}, }; +use anyhow::Context; use async_trait::async_trait; use futures::StreamExt; @@ -31,6 +32,7 @@ use crate::{ }; use super::{ + encrypt::{ring_aes_gcm::AesGcmCipher, Encryptor, NullCipher}, foreign_network_client::ForeignNetworkClient, foreign_network_manager::ForeignNetworkManager, peer_map::PeerMap, @@ -49,6 +51,8 @@ struct RpcTransport { packet_recv: Mutex>, peer_rpc_tspt_sender: UnboundedSender, + + encryptor: Arc>, } #[async_trait::async_trait] @@ -57,7 +61,7 @@ impl PeerRpcManagerTransport for RpcTransport { self.my_peer_id } - async fn send(&self, msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { + async fn send(&self, mut msg: ZCPacket, dst_peer_id: PeerId) -> Result<(), Error> { let foreign_peers = self .foreign_peers .lock() @@ -75,8 +79,17 @@ impl PeerRpcManagerTransport for RpcTransport { ?self.my_peer_id, "send msg to peer via gateway", ); + self.encryptor + .encrypt(&mut msg) + .with_context(|| "encrypt failed")?; peers.send_msg_directly(msg, gateway_id).await } else if foreign_peers.has_next_hop(dst_peer_id) { + if !foreign_peers.is_peer_public_node(&dst_peer_id) { + // do not encrypt for msg sending to public node + self.encryptor + .encrypt(&mut msg) + .with_context(|| "encrypt failed")?; + } tracing::debug!( ?dst_peer_id, ?self.my_peer_id, @@ -134,6 +147,8 @@ pub struct PeerManager { foreign_network_manager: Arc, foreign_network_client: Arc, + + encryptor: Arc>, } impl Debug for PeerManager { @@ -161,6 +176,13 @@ impl PeerManager { my_peer_id, )); + let encryptor: Arc> = + Arc::new(if global_ctx.get_flags().enable_encryption { + Box::new(AesGcmCipher::new_128(global_ctx.get_128_key())) + } else { + Box::new(NullCipher) + }); + // TODO: remove these because we have impl pipeline processor. let (peer_rpc_tspt_sender, peer_rpc_tspt_recv) = mpsc::unbounded_channel(); let rpc_tspt = Arc::new(RpcTransport { @@ -169,6 +191,7 @@ impl PeerManager { foreign_peers: Mutex::new(None), packet_recv: Mutex::new(peer_rpc_tspt_recv), peer_rpc_tspt_sender, + encryptor: encryptor.clone(), }); let peer_rpc_mgr = Arc::new(PeerRpcManager::new(rpc_tspt.clone())); @@ -218,9 +241,20 @@ impl PeerManager { foreign_network_manager, foreign_network_client, + + encryptor, } } + async fn add_new_peer_conn(&self, peer_conn: PeerConn) -> Result<(), Error> { + if self.global_ctx.get_network_identity() != peer_conn.get_network_identity() { + return Err(Error::SecretKeyError( + "network identity not match".to_string(), + )); + } + Ok(self.peers.add_new_peer_conn(peer_conn).await) + } + pub async fn add_client_tunnel( &self, tunnel: Box, @@ -229,8 +263,10 @@ impl PeerManager { peer.do_handshake_as_client().await?; let conn_id = peer.get_conn_id(); let peer_id = peer.get_peer_id(); - if peer.get_network_identity() == self.global_ctx.get_network_identity() { - self.peers.add_new_peer_conn(peer).await; + if peer.get_network_identity().network_name + == self.global_ctx.get_network_identity().network_name + { + self.add_new_peer_conn(peer).await?; } else { self.foreign_network_client.add_new_peer_conn(peer).await; } @@ -254,8 +290,10 @@ impl PeerManager { tracing::info!("add tunnel as server start"); let mut peer = PeerConn::new(self.my_peer_id, self.global_ctx.clone(), tunnel); peer.do_handshake_as_server().await?; - if peer.get_network_identity() == self.global_ctx.get_network_identity() { - self.peers.add_new_peer_conn(peer).await; + if peer.get_network_identity().network_name + == self.global_ctx.get_network_identity().network_name + { + self.add_new_peer_conn(peer).await?; } else { self.foreign_network_manager.add_peer_conn(peer).await?; } @@ -268,9 +306,10 @@ impl PeerManager { let my_peer_id = self.my_peer_id; let peers = self.peers.clone(); let pipe_line = self.peer_packet_process_pipeline.clone(); + let encryptor = self.encryptor.clone(); self.tasks.lock().await.spawn(async move { log::trace!("start_peer_recv"); - while let Some(ret) = recv.next().await { + while let Some(mut ret) = recv.next().await { let Some(hdr) = ret.peer_manager_header() else { tracing::warn!(?ret, "invalid packet, skip"); continue; @@ -285,6 +324,13 @@ impl PeerManager { tracing::error!(?ret, ?to_peer_id, ?from_peer_id, "forward packet error"); } } else { + if let Err(e) = encryptor + .decrypt(&mut ret) + .with_context(|| "decrypt failed") + { + tracing::error!(?e, "decrypt failed"); + } + let mut processed = false; let mut zc_packet = Some(ret); let mut idx = 0; @@ -490,7 +536,12 @@ impl PeerManager { return Ok(()); } + msg.fill_peer_manager_hdr(self.my_peer_id, 0, packet::PacketType::Data as u8); self.run_nic_packet_process_pipeline(&mut msg).await; + self.encryptor + .encrypt(&mut msg) + .with_context(|| "encrypt failed")?; + let mut errs: Vec = vec![]; let mut msg = Some(msg); @@ -503,8 +554,10 @@ impl PeerManager { }; let peer_id = &dst_peers[i]; - - msg.fill_peer_manager_hdr(self.my_peer_id, *peer_id, packet::PacketType::Data as u8); + msg.mut_peer_manager_header() + .unwrap() + .to_peer_id + .set(*peer_id); if let Some(gateway) = self.peers.get_gateway_peer_id(*peer_id).await { if let Err(e) = self.peers.send_msg_directly(msg, gateway).await { diff --git a/easytier/src/peers/zc_peer_conn.rs b/easytier/src/peers/zc_peer_conn.rs index 5d29b26..c79c1e3 100644 --- a/easytier/src/peers/zc_peer_conn.rs +++ b/easytier/src/peers/zc_peer_conn.rs @@ -24,8 +24,9 @@ use zerocopy::AsBytes; use crate::{ common::{ + config::{NetworkIdentity, NetworkSecretDigest}, error::Error, - global_ctx::{ArcGlobalCtx, NetworkIdentity}, + global_ctx::ArcGlobalCtx, PeerId, }, peers::packet::PacketType, @@ -129,10 +130,17 @@ impl PeerConn { )); }; let rsp = rsp?; - let rsp = HandshakeRequest::decode(rsp.payload()) - .map_err(|e| Error::WaitRespError(format!("decode handshake response error: {:?}", e))); + let rsp = HandshakeRequest::decode(rsp.payload()).map_err(|e| { + Error::WaitRespError(format!("decode handshake response error: {:?}", e)) + })?; - return Ok(rsp.unwrap()); + if rsp.network_secret_digrest.len() != std::mem::size_of::() { + return Err(Error::WaitRespError( + "invalid network secret digest".to_owned(), + )); + } + + return Ok(rsp); } async fn wait_handshake_loop(&mut self) -> Result { @@ -152,14 +160,16 @@ impl PeerConn { async fn send_handshake(&mut self) -> Result<(), Error> { let network = self.global_ctx.get_network_identity(); - let req = HandshakeRequest { + let mut req = HandshakeRequest { magic: MAGIC, my_peer_id: self.my_peer_id, version: VERSION, features: Vec::new(), network_name: network.network_name.clone(), - network_secret: network.network_secret.clone(), + ..Default::default() }; + req.network_secret_digrest + .extend_from_slice(&network.network_secret_digest.unwrap_or_default()); let hs_req = req.encode_to_vec(); let mut zc_packet = ZCPacket::new_with_payload(hs_req.as_bytes()); @@ -297,10 +307,16 @@ impl PeerConn { pub fn get_network_identity(&self) -> NetworkIdentity { let info = self.info.as_ref().unwrap(); - NetworkIdentity { + let mut ret = NetworkIdentity { network_name: info.network_name.clone(), - network_secret: info.network_secret.clone(), - } + ..Default::default() + }; + ret.network_secret_digest = Some([0u8; 32]); + ret.network_secret_digest + .as_mut() + .unwrap() + .copy_from_slice(&info.network_secret_digrest); + ret } pub fn set_close_event_sender(&mut self, sender: mpsc::Sender) { diff --git a/easytier/src/tests/three_node.rs b/easytier/src/tests/three_node.rs index bede763..3fc979c 100644 --- a/easytier/src/tests/three_node.rs +++ b/easytier/src/tests/three_node.rs @@ -87,7 +87,11 @@ pub async fn init_three_node(proto: &str) -> Vec { "wg://10.1.1.1:11011".parse().unwrap(), WgConfig::new_from_network_identity( &inst1.get_global_ctx().get_network_identity().network_name, - &inst1.get_global_ctx().get_network_identity().network_secret, + &inst1 + .get_global_ctx() + .get_network_identity() + .network_secret + .unwrap_or_default(), ), )); } @@ -243,7 +247,11 @@ pub async fn proxy_three_node_disconnect_test(#[values("tcp", "wg")] proto: &str "wg://10.1.2.3:11011".parse().unwrap(), WgConfig::new_from_network_identity( &inst4.get_global_ctx().get_network_identity().network_name, - &inst4.get_global_ctx().get_network_identity().network_secret, + &inst4 + .get_global_ctx() + .get_network_identity() + .network_secret + .unwrap_or_default(), ), )); } else { @@ -376,10 +384,8 @@ pub async fn foreign_network_forward_nic_data() { prepare_linux_namespaces(); let center_node_config = get_inst_config("inst1", Some("net_a"), "10.144.144.1"); - center_node_config.set_network_identity(NetworkIdentity { - network_name: "center".to_string(), - network_secret: "".to_string(), - }); + center_node_config + .set_network_identity(NetworkIdentity::new("center".to_string(), "".to_string())); let mut center_inst = Instance::new(center_node_config); let mut inst1 = Instance::new(get_inst_config("inst1", Some("net_b"), "10.144.145.1")); @@ -450,7 +456,7 @@ fn run_wireguard_client( log::info!("endpoint"); // Peer endpoint and interval peer.endpoint = Some(endpoint); - peer.persistent_keepalive_interval = Some(25); + peer.persistent_keepalive_interval = Some(1); for ip in allowed_ips { peer.allowed_ips.push(IpAddrMask::from_str(ip.as_str())?); } @@ -502,7 +508,7 @@ pub async fn wireguard_vpn_portal() { // ping other node in network wait_for_condition( || async { ping_test("net_d", "10.144.144.1").await }, - Duration::from_secs(5), + Duration::from_secs(5000), ) .await; wait_for_condition( @@ -514,7 +520,7 @@ pub async fn wireguard_vpn_portal() { // ping portal node wait_for_condition( || async { ping_test("net_d", "10.144.144.3").await }, - Duration::from_secs(500), + Duration::from_secs(5), ) .await; } diff --git a/easytier/src/tunnel/mod.rs b/easytier/src/tunnel/mod.rs index 671f925..1fa782f 100644 --- a/easytier/src/tunnel/mod.rs +++ b/easytier/src/tunnel/mod.rs @@ -1,3 +1,5 @@ +use std::collections::hash_map::DefaultHasher; +use std::hash::Hasher; use std::{net::SocketAddr, pin::Pin, sync::Arc}; use async_trait::async_trait; @@ -197,3 +199,17 @@ impl TunnelUrl { }) } } + +pub fn generate_digest_from_str(str1: &str, str2: &str, digest: &mut [u8]) { + let mut hasher = DefaultHasher::new(); + hasher.write(str1.as_bytes()); + hasher.write(str2.as_bytes()); + + assert_eq!(digest.len() % 8, 0, "digest length must be multiple of 8"); + + let shard_count = digest.len() / 8; + for i in 0..shard_count { + digest[i * 8..(i + 1) * 8].copy_from_slice(&hasher.finish().to_be_bytes()); + hasher.write(&digest[..(i + 1) * 8]); + } +} diff --git a/easytier/src/tunnel/packet_def.rs b/easytier/src/tunnel/packet_def.rs index 1979572..dec7787 100644 --- a/easytier/src/tunnel/packet_def.rs +++ b/easytier/src/tunnel/packet_def.rs @@ -56,16 +56,53 @@ pub enum PacketType { Route = 7, } +bitflags::bitflags! { + struct PeerManagerHeaderFlags: u8 { + const ENCRYPTED = 0b0000_0001; + } +} + #[repr(C, packed)] #[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)] pub struct PeerManagerHeader { pub from_peer_id: U32, pub to_peer_id: U32, pub packet_type: u8, + pub flags: u8, + reserved: U16, pub len: U32, } pub const PEER_MANAGER_HEADER_SIZE: usize = std::mem::size_of::(); +impl PeerManagerHeader { + pub fn is_encrypted(&self) -> bool { + PeerManagerHeaderFlags::from_bits(self.flags) + .unwrap() + .contains(PeerManagerHeaderFlags::ENCRYPTED) + } + + pub fn set_encrypted(&mut self, encrypted: bool) { + let mut flags = PeerManagerHeaderFlags::from_bits(self.flags).unwrap(); + if encrypted { + flags.insert(PeerManagerHeaderFlags::ENCRYPTED); + } else { + flags.remove(PeerManagerHeaderFlags::ENCRYPTED); + } + self.flags = flags.bits(); + } +} + +// reserve the space for aes tag and nonce +#[repr(C, packed)] +#[derive(AsBytes, FromBytes, FromZeroes, Clone, Debug, Default)] +pub struct AesGcmTail { + pub tag: [u8; 16], + pub nonce: [u8; 12], +} +pub const AES_GCM_ENCRYPTION_RESERVED: usize = std::mem::size_of::(); + +pub const TAIL_RESERVED_SIZE: usize = AES_GCM_ENCRYPTION_RESERVED; + const fn max(a: usize, b: usize) -> usize { [a, b][(a < b) as usize] } @@ -308,6 +345,7 @@ impl ZCPacket { hdr.from_peer_id.set(from_peer_id); hdr.to_peer_id.set(to_peer_id); hdr.packet_type = packet_type; + hdr.flags = 0; hdr.len.set(payload_len as u32); } diff --git a/easytier/src/tunnel/wireguard.rs b/easytier/src/tunnel/wireguard.rs index da5d85b..f753bde 100644 --- a/easytier/src/tunnel/wireguard.rs +++ b/easytier/src/tunnel/wireguard.rs @@ -1,7 +1,5 @@ use std::{ - collections::hash_map::DefaultHasher, fmt::{Debug, Formatter}, - hash::Hasher, net::SocketAddr, pin::Pin, sync::{atomic::AtomicBool, Arc}, @@ -33,6 +31,7 @@ use crate::{ use super::{ check_scheme_and_get_socket_addr, common::{setup_sokcet2, setup_sokcet2_ext, wait_for_connect_futures}, + generate_digest_from_str, packet_def::{ZCPacketType, PEER_MANAGER_HEADER_SIZE}, ring::create_ring_tunnel_pair, Tunnel, TunnelError, TunnelListener, TunnelUrl, ZCPacketSink, ZCPacketStream, @@ -62,16 +61,7 @@ pub struct WgConfig { impl WgConfig { pub fn new_from_network_identity(network_name: &str, network_secret: &str) -> Self { let mut my_sec = [0u8; 32]; - let mut hasher = DefaultHasher::new(); - hasher.write(network_name.as_bytes()); - hasher.write(network_secret.as_bytes()); - my_sec[0..8].copy_from_slice(&hasher.finish().to_be_bytes()); - hasher.write(&my_sec[0..8]); - my_sec[8..16].copy_from_slice(&hasher.finish().to_be_bytes()); - hasher.write(&my_sec[0..16]); - my_sec[16..24].copy_from_slice(&hasher.finish().to_be_bytes()); - hasher.write(&my_sec[0..24]); - my_sec[24..32].copy_from_slice(&hasher.finish().to_be_bytes()); + generate_digest_from_str(network_name, network_secret, &mut my_sec); let my_secret_key = StaticSecret::from(my_sec); let my_public_key = PublicKey::from(&my_secret_key); @@ -491,6 +481,7 @@ impl WgTunnelListener { let mut buf = vec![0u8; MAX_PACKET]; loop { + tracing::info!("Waiting for incoming UDP packet"); let Ok((n, addr)) = socket.recv_from(&mut buf).await else { tracing::error!("Failed to receive from UDP socket"); break; diff --git a/easytier/src/vpn_portal/wireguard.rs b/easytier/src/vpn_portal/wireguard.rs index 118fcb5..7cbcd9b 100644 --- a/easytier/src/vpn_portal/wireguard.rs +++ b/easytier/src/vpn_portal/wireguard.rs @@ -10,6 +10,7 @@ use dashmap::DashMap; use futures::StreamExt; use pnet::packet::ipv4::Ipv4Packet; use tokio::task::JoinSet; +use tracing::Level; use crate::{ common::{ @@ -31,7 +32,11 @@ use super::VpnPortal; type WgPeerIpTable = Arc>>; pub(crate) fn get_wg_config_for_portal(nid: &NetworkIdentity) -> WgConfig { - let key_seed = format!("{}{}", nid.network_name, nid.network_secret); + let key_seed = format!( + "{}{}", + nid.network_name, + nid.network_secret.as_ref().unwrap_or(&"".to_string()) + ); WgConfig::new_for_portal(&key_seed, &key_seed) } @@ -166,6 +171,7 @@ impl WireGuardImpl { .await; } + #[tracing::instrument(skip(self), err(level = Level::WARN))] async fn start(&self) -> anyhow::Result<()> { let mut l = WgTunnelListener::new( format!("wg://{}", self.listenr_addr).parse().unwrap(),