// sing-box/transport/wireguard/client_bind.go

package wireguard
import (
"context"
"net"
"net/netip"
2022-09-06 00:15:09 +08:00
"sync"
2023-08-07 17:46:51 +08:00
"time"
2022-09-06 00:15:09 +08:00
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
2023-04-13 16:02:28 +08:00
E "github.com/sagernet/sing/common/exceptions"
2024-11-02 00:39:02 +08:00
"github.com/sagernet/sing/common/logger"
2022-09-06 00:15:09 +08:00
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
2024-06-03 16:59:13 +08:00
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
2022-09-06 00:15:09 +08:00
)
var _ conn.Bind = (*ClientBind)(nil)
type ClientBind struct {
ctx context.Context
2024-11-02 00:39:02 +08:00
logger logger.Logger
2024-06-03 16:59:13 +08:00
pauseManager pause.Manager
bindCtx context.Context
bindDone context.CancelFunc
dialer N.Dialer
reservedForEndpoint map[netip.AddrPort][3]uint8
connAccess sync.Mutex
conn *wireConn
done chan struct{}
isConnect bool
connectAddr netip.AddrPort
reserved [3]uint8
2022-09-06 00:15:09 +08:00
}
2024-11-02 00:39:02 +08:00
func NewClientBind(ctx context.Context, logger logger.Logger, dialer N.Dialer, isConnect bool, connectAddr netip.AddrPort, reserved [3]uint8) *ClientBind {
2022-09-06 00:15:09 +08:00
return &ClientBind{
ctx: ctx,
2024-11-02 00:39:02 +08:00
logger: logger,
2024-06-03 16:59:13 +08:00
pauseManager: service.FromContext[pause.Manager](ctx),
dialer: dialer,
reservedForEndpoint: make(map[netip.AddrPort][3]uint8),
2024-03-20 10:46:54 +08:00
done: make(chan struct{}),
isConnect: isConnect,
connectAddr: connectAddr,
reserved: reserved,
2022-09-06 00:15:09 +08:00
}
}
func (c *ClientBind) connect() (*wireConn, error) {
serverConn := c.conn
if serverConn != nil {
select {
case <-serverConn.done:
serverConn = nil
default:
return serverConn, nil
}
}
c.connAccess.Lock()
defer c.connAccess.Unlock()
2024-06-03 16:59:13 +08:00
select {
case <-c.done:
return nil, net.ErrClosed
default:
}
2022-09-06 00:15:09 +08:00
serverConn = c.conn
if serverConn != nil {
select {
case <-serverConn.done:
serverConn = nil
default:
return serverConn, nil
}
}
if c.isConnect {
2024-06-03 16:59:13 +08:00
udpConn, err := c.dialer.DialContext(c.bindCtx, N.NetworkUDP, M.SocksaddrFromNetIP(c.connectAddr))
if err != nil {
2023-04-13 16:02:28 +08:00
return nil, err
}
c.conn = &wireConn{
PacketConn: bufio.NewUnbindPacketConn(udpConn),
done: make(chan struct{}),
}
} else {
2024-06-03 16:59:13 +08:00
udpConn, err := c.dialer.ListenPacket(c.bindCtx, M.Socksaddr{Addr: netip.IPv4Unspecified()})
if err != nil {
2023-04-13 16:02:28 +08:00
return nil, err
}
c.conn = &wireConn{
2023-04-13 16:02:28 +08:00
PacketConn: bufio.NewPacketConn(udpConn),
done: make(chan struct{}),
}
2022-09-06 00:15:09 +08:00
}
return c.conn, nil
}
func (c *ClientBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
2022-09-23 13:14:31 +08:00
select {
case <-c.done:
2024-03-20 10:46:54 +08:00
c.done = make(chan struct{})
2022-09-23 13:14:31 +08:00
default:
}
2024-06-03 16:59:13 +08:00
c.bindCtx, c.bindDone = context.WithCancel(c.ctx)
2022-09-06 00:15:09 +08:00
return []conn.ReceiveFunc{c.receive}, 0, nil
}
2023-04-20 13:16:31 +08:00
func (c *ClientBind) receive(packets [][]byte, sizes []int, eps []conn.Endpoint) (count int, err error) {
2022-09-06 00:15:09 +08:00
udpConn, err := c.connect()
if err != nil {
2023-04-13 16:02:28 +08:00
select {
case <-c.done:
return
default:
}
2024-11-02 00:39:02 +08:00
c.logger.Error(E.Cause(err, "connect to server"))
2023-04-13 16:02:28 +08:00
err = nil
2024-06-03 16:59:13 +08:00
c.pauseManager.WaitActive()
2023-08-07 17:46:51 +08:00
time.Sleep(time.Second)
2022-09-06 00:15:09 +08:00
return
}
2023-04-20 13:16:31 +08:00
n, addr, err := udpConn.ReadFrom(packets[0])
2022-09-06 00:15:09 +08:00
if err != nil {
udpConn.Close()
2022-09-23 13:14:31 +08:00
select {
case <-c.done:
default:
2024-11-21 18:10:41 +08:00
c.logger.Error(E.Cause(err, "read packet"))
2023-04-20 13:16:31 +08:00
err = nil
2022-09-23 13:14:31 +08:00
}
return
2022-09-06 00:15:09 +08:00
}
2023-04-20 13:16:31 +08:00
sizes[0] = n
if n > 3 {
2023-04-20 13:16:31 +08:00
b := packets[0]
common.ClearArray(b[1:4])
}
2024-11-21 18:10:41 +08:00
eps[0] = remoteEndpoint(M.AddrPortFromNet(addr))
2023-04-20 13:16:31 +08:00
count = 1
2022-09-06 00:15:09 +08:00
return
}
func (c *ClientBind) Close() error {
2022-09-23 13:14:31 +08:00
select {
case <-c.done:
default:
close(c.done)
}
2024-06-03 16:59:13 +08:00
if c.bindDone != nil {
c.bindDone()
}
c.connAccess.Lock()
defer c.connAccess.Unlock()
common.Close(common.PtrOrNil(c.conn))
2022-09-06 00:15:09 +08:00
return nil
}
func (c *ClientBind) SetMark(mark uint32) error {
return nil
}
2023-04-20 13:16:31 +08:00
func (c *ClientBind) Send(bufs [][]byte, ep conn.Endpoint) error {
2022-09-06 00:15:09 +08:00
udpConn, err := c.connect()
if err != nil {
2024-06-03 16:59:13 +08:00
c.pauseManager.WaitActive()
time.Sleep(time.Second)
2022-09-06 00:15:09 +08:00
return err
}
2024-11-21 18:10:41 +08:00
destination := netip.AddrPort(ep.(remoteEndpoint))
2023-04-20 13:16:31 +08:00
for _, b := range bufs {
if len(b) > 3 {
reserved, loaded := c.reservedForEndpoint[destination]
if !loaded {
reserved = c.reserved
}
copy(b[1:4], reserved[:])
2023-04-20 13:16:31 +08:00
}
2024-03-20 10:46:54 +08:00
_, err = udpConn.WriteToUDPAddrPort(b, destination)
2023-04-20 13:16:31 +08:00
if err != nil {
udpConn.Close()
return err
}
2022-09-06 00:15:09 +08:00
}
2023-04-20 13:16:31 +08:00
return nil
2022-09-06 00:15:09 +08:00
}
func (c *ClientBind) ParseEndpoint(s string) (conn.Endpoint, error) {
ap, err := netip.ParseAddrPort(s)
if err != nil {
return nil, err
}
2024-11-21 18:10:41 +08:00
return remoteEndpoint(ap), nil
2022-09-06 00:15:09 +08:00
}
2023-04-20 13:16:31 +08:00
// BatchSize implements conn.Bind. receive fills only packets[0] per call, so
// the bind reports a batch size of one.
func (c *ClientBind) BatchSize() int { return 1 }
func (c *ClientBind) SetReservedForEndpoint(destination netip.AddrPort, reserved [3]byte) {
c.reservedForEndpoint[destination] = reserved
}
2022-09-06 00:15:09 +08:00
// wireConn is the packet connection held by ClientBind. Its done channel is
// closed once by Close, which lets ClientBind.connect detect a dead connection
// before taking connAccess.
type wireConn struct {
	net.PacketConn

	// conn, when non-nil, receives writes directly (see WriteToUDPAddrPort).
	// NOTE(review): nothing in this file assigns it — confirm it is still needed.
	conn net.Conn

	access sync.Mutex    // guards the check-and-close sequence in Close
	done   chan struct{} // closed by Close
}
func (w *wireConn) WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (int, error) {
if w.conn != nil {
return w.conn.Write(b)
}
return w.PacketConn.WriteTo(b, M.SocksaddrFromNetIP(addr).UDPAddr())
}
2022-09-06 00:15:09 +08:00
func (w *wireConn) Close() error {
w.access.Lock()
defer w.access.Unlock()
select {
case <-w.done:
return net.ErrClosed
default:
}
2023-04-13 16:02:28 +08:00
w.PacketConn.Close()
2022-09-06 00:15:09 +08:00
close(w.done)
return nil
}
2024-11-21 18:10:41 +08:00
// Compile-time check that remoteEndpoint (and thus *remoteEndpoint) satisfies conn.Endpoint.
var _ conn.Endpoint = remoteEndpoint{}
// remoteEndpoint is the endpoint type ClientBind produces in receive and
// ParseEndpoint. It holds only a destination address and port, so the Src*
// methods report empty values and ClearSrc has nothing to clear.
type remoteEndpoint netip.AddrPort

// ClearSrc is a no-op: remoteEndpoint keeps no source address.
func (e remoteEndpoint) ClearSrc() {}

// SrcToString always returns the empty string.
func (e remoteEndpoint) SrcToString() string { return "" }

// DstToString returns the destination in netip.AddrPort's String form.
func (e remoteEndpoint) DstToString() string {
	dst := netip.AddrPort(e)
	return dst.String()
}

// DstToBytes returns the destination encoded by netip.AddrPort.MarshalBinary.
// The error result of MarshalBinary is ignored.
func (e remoteEndpoint) DstToBytes() []byte {
	dst := netip.AddrPort(e)
	encoded, _ := dst.MarshalBinary()
	return encoded
}

// DstIP returns the destination address without the port.
func (e remoteEndpoint) DstIP() netip.Addr {
	dst := netip.AddrPort(e)
	return dst.Addr()
}

// SrcIP always returns the zero (invalid) netip.Addr.
func (e remoteEndpoint) SrcIP() netip.Addr {
	var none netip.Addr
	return none
}