diff --git a/easytier/src/connector/mod.rs b/easytier/src/connector/mod.rs index 24396bf..085bd8d 100644 --- a/easytier/src/connector/mod.rs +++ b/easytier/src/connector/mod.rs @@ -11,7 +11,7 @@ use crate::{ common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector}, tunnel::{ check_scheme_and_get_socket_addr, ring::RingTunnelConnector, tcp::TcpTunnelConnector, - udp::UdpTunnelConnector, TunnelConnector, + udp::UdpTunnelConnector, FromUrl, IpVersion, TunnelConnector, }, }; @@ -107,7 +107,14 @@ pub async fn create_connector_by_url( } #[cfg(feature = "websocket")] "ws" | "wss" => { - let connector = crate::tunnel::websocket::WSTunnelConnector::new(url); + let dst_addr = SocketAddr::from_url(url.clone(), IpVersion::Both)?; + let mut connector = crate::tunnel::websocket::WSTunnelConnector::new(url); + set_bind_addr_for_peer_connector( + &mut connector, + dst_addr.is_ipv4(), + &global_ctx.get_ip_collector(), + ) + .await; return Ok(Box::new(connector)); } _ => { diff --git a/easytier/src/tunnel/websocket.rs b/easytier/src/tunnel/websocket.rs index 4500540..ada2f9c 100644 --- a/easytier/src/tunnel/websocket.rs +++ b/easytier/src/tunnel/websocket.rs @@ -2,16 +2,16 @@ use std::{net::SocketAddr, sync::Arc}; use anyhow::Context; use bytes::BytesMut; -use futures::{SinkExt, StreamExt}; +use futures::{stream::FuturesUnordered, SinkExt, StreamExt}; use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio_rustls::TlsAcceptor; -use tokio_websockets::{ClientBuilder, Limits, Message}; +use tokio_websockets::{ClientBuilder, Limits, MaybeTlsStream, Message}; use zerocopy::AsBytes; use crate::{rpc::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config}; use super::{ - common::{setup_sokcet2, TunnelWrapper}, + common::{setup_sokcet2, wait_for_connect_futures, TunnelWrapper}, insecure_tls::{get_insecure_tls_cert, init_crypto_provider}, packet_def::{ZCPacket, ZCPacketType}, FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener, @@ -156,6 +156,8 @@ impl TunnelListener for WSTunnelListener { pub struct WSTunnelConnector { addr: url::Url, ip_version: IpVersion, + + bind_addrs: Vec, } impl WSTunnelConnector { @@ -163,47 +165,100 @@ impl WSTunnelConnector { WSTunnelConnector { addr, ip_version: IpVersion::Both, + + bind_addrs: vec![], } } + + async fn connect_with( + addr: url::Url, + ip_version: IpVersion, + tcp_socket: TcpSocket, + ) -> Result, TunnelError> { + let is_wss = is_wss(&addr)?; + let socket_addr = SocketAddr::from_url(addr.clone(), ip_version)?; + let host = socket_addr.ip(); + let stream = tcp_socket.connect(socket_addr).await?; + + let info = TunnelInfo { + tunnel_type: addr.scheme().to_owned(), + local_addr: super::build_url_from_socket_addr( + &stream.local_addr()?.to_string(), + addr.scheme().to_string().as_str(), + ) + .into(), + remote_addr: addr.to_string(), + }; + + let c = ClientBuilder::from_uri(http::Uri::try_from(addr.to_string()).unwrap()); + let stream: MaybeTlsStream = if is_wss { + init_crypto_provider(); + let tls_conn = + tokio_rustls::TlsConnector::from(Arc::new(get_insecure_tls_client_config())); + let stream = tls_conn + .connect(host.to_string().try_into().unwrap(), stream) + .await?; + MaybeTlsStream::Rustls(stream) + } else { + MaybeTlsStream::Plain(stream) + }; + + let (client, _) = c.connect_on(stream).await?; + let (write, read) = client.split(); + let read = read.filter_map(move |msg| map_from_ws_message(msg)); + let write = write.with(move |msg| sink_from_zc_packet(msg)); + Ok(Box::new(TunnelWrapper::new(read, write, Some(info)))) + } + + async fn connect_with_default_bind( + &mut self, + addr: SocketAddr, + ) -> Result, super::TunnelError> { + let socket = if addr.is_ipv4() { + TcpSocket::new_v4()? + } else { + TcpSocket::new_v6()? + }; + Self::connect_with(self.addr.clone(), self.ip_version, socket).await + } + + async fn connect_with_custom_bind( + &mut self, + addr: SocketAddr, + ) -> Result, super::TunnelError> { + let futures = FuturesUnordered::new(); + + for bind_addr in self.bind_addrs.iter() { + tracing::info!(bind_addr = ?bind_addr, ?addr, "bind addr"); + + let socket2_socket = socket2::Socket::new( + socket2::Domain::for_address(addr), + socket2::Type::STREAM, + Some(socket2::Protocol::TCP), + )?; + setup_sokcet2(&socket2_socket, bind_addr)?; + + let socket = TcpSocket::from_std_stream(socket2_socket.into()); + futures.push(Self::connect_with( + self.addr.clone(), + self.ip_version, + socket, + )) + } + + wait_for_connect_futures(futures).await + } } #[async_trait::async_trait] impl TunnelConnector for WSTunnelConnector { async fn connect(&mut self) -> Result, super::TunnelError> { - let is_wss = is_wss(&self.addr)?; let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version)?; - let local_addr = if addr.is_ipv4() { - "0.0.0.0:0" + if self.bind_addrs.is_empty() || addr.is_ipv6() { + self.connect_with_default_bind(addr).await } else { - "[::]:0" - }; - - let info = TunnelInfo { - tunnel_type: self.addr.scheme().to_owned(), - local_addr: super::build_url_from_socket_addr( - &local_addr.to_string(), - self.addr.scheme().to_string().as_str(), - ) - .into(), - remote_addr: self.addr.to_string(), - }; - - let connector = - tokio_websockets::Connector::Rustls(Arc::new(get_insecure_tls_client_config()).into()); - let mut client_builder = - ClientBuilder::from_uri(http::Uri::try_from(self.addr.to_string()).unwrap()); - if is_wss { - init_crypto_provider(); - client_builder = client_builder.connector(&connector); + self.connect_with_custom_bind(addr).await } - - let (client, _) = client_builder.connect().await?; - - let (write, read) = client.split(); - let read = read.filter_map(move |msg| map_from_ws_message(msg)); - let write = write.with(move |msg| sink_from_zc_packet(msg)); - - Ok(Box::new(TunnelWrapper::new(read, write, Some(info)))) } fn remote_url(&self) -> url::Url { @@ -213,6 +268,10 @@ impl TunnelConnector for WSTunnelConnector { fn set_ip_version(&mut self, ip_version: IpVersion) { self.ip_version = ip_version; } + + fn set_bind_addrs(&mut self, addrs: Vec) { + self.bind_addrs = addrs; + } } #[cfg(test)] @@ -231,6 +290,17 @@ pub mod tests { _tunnel_pingpong(listener, connector).await } + #[rstest::rstest] + #[tokio::test] + #[serial_test::serial] + async fn ws_pingpong_bind(#[values("ws", "wss")] proto: &str) { + let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25557", proto).parse().unwrap()); + let mut connector = + WSTunnelConnector::new(format!("{}://127.0.0.1:25557", proto).parse().unwrap()); + connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]); + _tunnel_pingpong(listener, connector).await + } + // TODO: tokio-websockets cannot correctly handle close, benchmark case is disabled // #[rstest::rstest] // #[tokio::test]