websocket support bind addr (#129)

This commit is contained in:
Sijie.Sun 2024-06-02 21:48:16 +08:00 committed by GitHub
parent 360691276c
commit c1b725e64e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 113 additions and 36 deletions

View File

@ -11,7 +11,7 @@ use crate::{
common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector}, common::{error::Error, global_ctx::ArcGlobalCtx, network::IPCollector},
tunnel::{ tunnel::{
check_scheme_and_get_socket_addr, ring::RingTunnelConnector, tcp::TcpTunnelConnector, check_scheme_and_get_socket_addr, ring::RingTunnelConnector, tcp::TcpTunnelConnector,
udp::UdpTunnelConnector, TunnelConnector, udp::UdpTunnelConnector, FromUrl, IpVersion, TunnelConnector,
}, },
}; };
@ -107,7 +107,14 @@ pub async fn create_connector_by_url(
} }
#[cfg(feature = "websocket")] #[cfg(feature = "websocket")]
"ws" | "wss" => { "ws" | "wss" => {
let connector = crate::tunnel::websocket::WSTunnelConnector::new(url); let dst_addr = SocketAddr::from_url(url.clone(), IpVersion::Both)?;
let mut connector = crate::tunnel::websocket::WSTunnelConnector::new(url);
set_bind_addr_for_peer_connector(
&mut connector,
dst_addr.is_ipv4(),
&global_ctx.get_ip_collector(),
)
.await;
return Ok(Box::new(connector)); return Ok(Box::new(connector));
} }
_ => { _ => {

View File

@ -2,16 +2,16 @@ use std::{net::SocketAddr, sync::Arc};
use anyhow::Context; use anyhow::Context;
use bytes::BytesMut; use bytes::BytesMut;
use futures::{SinkExt, StreamExt}; use futures::{stream::FuturesUnordered, SinkExt, StreamExt};
use tokio::net::{TcpListener, TcpSocket, TcpStream}; use tokio::net::{TcpListener, TcpSocket, TcpStream};
use tokio_rustls::TlsAcceptor; use tokio_rustls::TlsAcceptor;
use tokio_websockets::{ClientBuilder, Limits, Message}; use tokio_websockets::{ClientBuilder, Limits, MaybeTlsStream, Message};
use zerocopy::AsBytes; use zerocopy::AsBytes;
use crate::{rpc::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config}; use crate::{rpc::TunnelInfo, tunnel::insecure_tls::get_insecure_tls_client_config};
use super::{ use super::{
common::{setup_sokcet2, TunnelWrapper}, common::{setup_sokcet2, wait_for_connect_futures, TunnelWrapper},
insecure_tls::{get_insecure_tls_cert, init_crypto_provider}, insecure_tls::{get_insecure_tls_cert, init_crypto_provider},
packet_def::{ZCPacket, ZCPacketType}, packet_def::{ZCPacket, ZCPacketType},
FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener, FromUrl, IpVersion, Tunnel, TunnelConnector, TunnelError, TunnelListener,
@ -156,6 +156,8 @@ impl TunnelListener for WSTunnelListener {
pub struct WSTunnelConnector { pub struct WSTunnelConnector {
addr: url::Url, addr: url::Url,
ip_version: IpVersion, ip_version: IpVersion,
bind_addrs: Vec<SocketAddr>,
} }
impl WSTunnelConnector { impl WSTunnelConnector {
@ -163,47 +165,100 @@ impl WSTunnelConnector {
WSTunnelConnector { WSTunnelConnector {
addr, addr,
ip_version: IpVersion::Both, ip_version: IpVersion::Both,
bind_addrs: vec![],
} }
} }
async fn connect_with(
addr: url::Url,
ip_version: IpVersion,
tcp_socket: TcpSocket,
) -> Result<Box<dyn Tunnel>, TunnelError> {
let is_wss = is_wss(&addr)?;
let socket_addr = SocketAddr::from_url(addr.clone(), ip_version)?;
let host = socket_addr.ip();
let stream = tcp_socket.connect(socket_addr).await?;
let info = TunnelInfo {
tunnel_type: addr.scheme().to_owned(),
local_addr: super::build_url_from_socket_addr(
&stream.local_addr()?.to_string(),
addr.scheme().to_string().as_str(),
)
.into(),
remote_addr: addr.to_string(),
};
let c = ClientBuilder::from_uri(http::Uri::try_from(addr.to_string()).unwrap());
let stream: MaybeTlsStream<TcpStream> = if is_wss {
init_crypto_provider();
let tls_conn =
tokio_rustls::TlsConnector::from(Arc::new(get_insecure_tls_client_config()));
let stream = tls_conn
.connect(host.to_string().try_into().unwrap(), stream)
.await?;
MaybeTlsStream::Rustls(stream)
} else {
MaybeTlsStream::Plain(stream)
};
let (client, _) = c.connect_on(stream).await?;
let (write, read) = client.split();
let read = read.filter_map(move |msg| map_from_ws_message(msg));
let write = write.with(move |msg| sink_from_zc_packet(msg));
Ok(Box::new(TunnelWrapper::new(read, write, Some(info))))
}
async fn connect_with_default_bind(
&mut self,
addr: SocketAddr,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let socket = if addr.is_ipv4() {
TcpSocket::new_v4()?
} else {
TcpSocket::new_v6()?
};
Self::connect_with(self.addr.clone(), self.ip_version, socket).await
}
async fn connect_with_custom_bind(
&mut self,
addr: SocketAddr,
) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let futures = FuturesUnordered::new();
for bind_addr in self.bind_addrs.iter() {
tracing::info!(bind_addr = ?bind_addr, ?addr, "bind addr");
let socket2_socket = socket2::Socket::new(
socket2::Domain::for_address(addr),
socket2::Type::STREAM,
Some(socket2::Protocol::TCP),
)?;
setup_sokcet2(&socket2_socket, bind_addr)?;
let socket = TcpSocket::from_std_stream(socket2_socket.into());
futures.push(Self::connect_with(
self.addr.clone(),
self.ip_version,
socket,
))
}
wait_for_connect_futures(futures).await
}
} }
#[async_trait::async_trait] #[async_trait::async_trait]
impl TunnelConnector for WSTunnelConnector { impl TunnelConnector for WSTunnelConnector {
async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> { async fn connect(&mut self) -> Result<Box<dyn Tunnel>, super::TunnelError> {
let is_wss = is_wss(&self.addr)?;
let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version)?; let addr = SocketAddr::from_url(self.addr.clone(), self.ip_version)?;
let local_addr = if addr.is_ipv4() { if self.bind_addrs.is_empty() || addr.is_ipv6() {
"0.0.0.0:0" self.connect_with_default_bind(addr).await
} else { } else {
"[::]:0" self.connect_with_custom_bind(addr).await
};
let info = TunnelInfo {
tunnel_type: self.addr.scheme().to_owned(),
local_addr: super::build_url_from_socket_addr(
&local_addr.to_string(),
self.addr.scheme().to_string().as_str(),
)
.into(),
remote_addr: self.addr.to_string(),
};
let connector =
tokio_websockets::Connector::Rustls(Arc::new(get_insecure_tls_client_config()).into());
let mut client_builder =
ClientBuilder::from_uri(http::Uri::try_from(self.addr.to_string()).unwrap());
if is_wss {
init_crypto_provider();
client_builder = client_builder.connector(&connector);
} }
let (client, _) = client_builder.connect().await?;
let (write, read) = client.split();
let read = read.filter_map(move |msg| map_from_ws_message(msg));
let write = write.with(move |msg| sink_from_zc_packet(msg));
Ok(Box::new(TunnelWrapper::new(read, write, Some(info))))
} }
fn remote_url(&self) -> url::Url { fn remote_url(&self) -> url::Url {
@ -213,6 +268,10 @@ impl TunnelConnector for WSTunnelConnector {
fn set_ip_version(&mut self, ip_version: IpVersion) { fn set_ip_version(&mut self, ip_version: IpVersion) {
self.ip_version = ip_version; self.ip_version = ip_version;
} }
fn set_bind_addrs(&mut self, addrs: Vec<SocketAddr>) {
self.bind_addrs = addrs;
}
} }
#[cfg(test)] #[cfg(test)]
@ -231,6 +290,17 @@ pub mod tests {
_tunnel_pingpong(listener, connector).await _tunnel_pingpong(listener, connector).await
} }
#[rstest::rstest]
#[tokio::test]
#[serial_test::serial]
async fn ws_pingpong_bind(#[values("ws", "wss")] proto: &str) {
let listener = WSTunnelListener::new(format!("{}://0.0.0.0:25557", proto).parse().unwrap());
let mut connector =
WSTunnelConnector::new(format!("{}://127.0.0.1:25557", proto).parse().unwrap());
connector.set_bind_addrs(vec!["127.0.0.1:0".parse().unwrap()]);
_tunnel_pingpong(listener, connector).await
}
// TODO: tokio-websockets cannot correctly handle close, benchmark case is disabled // TODO: tokio-websockets cannot correctly handle close, benchmark case is disabled
// #[rstest::rstest] // #[rstest::rstest]
// #[tokio::test] // #[tokio::test]