// sing-box/route/conn.go
// 2024-12-21 19:40:16 +08:00
//
// 282 lines
// 9.4 KiB
// Go

package route
import (
"context"
"io"
"net"
"net/netip"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
"github.com/sagernet/sing/common/canceler"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
)
// Compile-time assertion that *ConnectionManager satisfies adapter.ConnectionManager.
var _ adapter.ConnectionManager = (*ConnectionManager)(nil)
// ConnectionManager opens outbound connections for inbound TCP and UDP
// traffic, relays data between the two sides, and keeps a list of the inbound
// connections it is currently relaying so that Close can shut them down.
type ConnectionManager struct {
// logger receives the per-connection error, debug and trace messages.
logger logger.ContextLogger
// access guards connections; every read or write of the list in this file
// happens with it held.
access sync.Mutex
// connections holds the inbound side of each relay that is still tracked.
connections list.List[io.Closer]
}
// NewConnectionManager returns a ConnectionManager that logs through logger
// and starts out with no tracked connections.
func NewConnectionManager(logger logger.ContextLogger) *ConnectionManager {
	manager := new(ConnectionManager)
	manager.logger = logger
	return manager
}
// Start is a no-op: it ignores stage, performs no work and always returns nil.
func (m *ConnectionManager) Start(stage adapter.StartStage) error {
return nil
}
// Close closes every connection that is still tracked by the manager and
// leaves the tracking list empty. It always returns nil.
//
// Elements are unlinked one by one with Remove rather than by resetting the
// list with Init afterwards. The onClose handlers installed by NewConnection
// and NewPacketConnection keep a reference to their element and call
// m.connections.Remove(element) once the relay goroutines notice the close,
// which happens after this method has released the lock.
// NOTE(review): this assumes list.List follows container/list semantics, where
// Init leaves old elements still attributed to the list (so a late Remove
// decrements the length below zero and rewrites the root links), while Remove
// detaches the element so that a second Remove is a no-op — confirm against
// the list package.
func (m *ConnectionManager) Close() error {
	m.access.Lock()
	defer m.access.Unlock()
	for element := m.connections.Front(); element != nil; {
		// Fetch the successor first: Remove detaches element from the list.
		next := element.Next()
		common.Close(element.Value)
		m.connections.Remove(element)
		element = next
	}
	return nil
}
// NewConnection dials the outbound TCP connection for conn through this and,
// once the handshake has been reported to conn, relays data in both
// directions on two goroutines.
//
// metadata is attached to ctx before dialing. When the destination is an IP
// or resolved addresses are present, dialing goes through
// dialer.DialSerialNetwork; otherwise this.DialContext is used with the
// destination as-is.
//
// On a dial or handshake-report failure the error is wrapped, handed to
// N.CloseOnHandshakeFailure together with conn and onClose, logged, and the
// method returns without starting a relay. On success conn is added to the
// manager's tracking list and onClose is extended so the entry is removed
// again when the relay ends.
func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
	ctx = adapter.WithContext(ctx, &metadata)

	var (
		outbound net.Conn
		err      error
	)
	switch {
	case len(metadata.DestinationAddresses) > 0, metadata.Destination.IsIP():
		outbound, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
	default:
		outbound, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination)
	}
	if err != nil {
		err = E.Cause(err, "open outbound connection")
		N.CloseOnHandshakeFailure(conn, onClose, err)
		m.logger.ErrorContext(ctx, err)
		return
	}

	if err = N.ReportConnHandshakeSuccess(conn, outbound); err != nil {
		err = E.Cause(err, "report handshake success")
		// The outbound side is already open at this point, so release it too.
		outbound.Close()
		N.CloseOnHandshakeFailure(conn, onClose, err)
		m.logger.ErrorContext(ctx, err)
		return
	}

	// Track the inbound connection so Close can terminate the relay.
	m.access.Lock()
	tracked := m.connections.PushBack(conn)
	m.access.Unlock()
	untrack := func(it error) {
		m.access.Lock()
		defer m.access.Unlock()
		m.connections.Remove(tracked)
	}
	onClose = N.AppendClose(onClose, untrack)

	// Both directions share one flag so they can tell which finished last.
	finished := new(atomic.Bool)
	go m.connectionCopy(ctx, conn, outbound, false, finished, onClose)
	go m.connectionCopy(ctx, outbound, conn, true, finished, onClose)
}
// NewPacketConnection opens the outbound UDP side for conn through this and
// relays packets in both directions on two goroutines.
//
// metadata is attached to ctx first. The outbound side is obtained in one of
// two ways:
//   - metadata.UDPConnect set: a connected socket is dialed and wrapped with
//     bufio.NewUnbindPacketConn. If the socket's remote address differs from
//     metadata.Destination.Addr, that address is remembered for the rewrite
//     step below.
//   - otherwise: a packet socket is opened with ListenSerialNetworkPacket
//     (when resolved addresses are present, which also yields the chosen
//     address) or this.ListenPacket.
//
// Every failure exit notifies the caller through onClose: dial and listen
// failures via N.CloseOnHandshakeFailure, and a failed handshake report by
// closing both sides and then invoking onClose directly.
//
// When a concrete destination address was chosen and it differs from
// metadata.Destination, the outbound side is wrapped in a NAT packet conn
// (the unidirectional variant if metadata.UDPDisableDomainUnmapping is set),
// and conn is told about the address if it is a bufio.NATPacketConn.
//
// The idle timeout is metadata.UDPTimeout when positive; otherwise it is
// looked up in C.ProtocolTimeouts using metadata.Protocol, falling back to
// C.PortProtocols for the destination port. A positive timeout wraps conn
// (and ctx) with canceler.NewPacketConn.
func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
	ctx = adapter.WithContext(ctx, &metadata)
	var (
		remotePacketConn   net.PacketConn
		remoteConn         net.Conn
		destinationAddress netip.Addr
		err                error
	)
	if metadata.UDPConnect {
		parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer)
		hasAddresses := len(metadata.DestinationAddresses) > 0
		switch {
		case isParallelDialer && (hasAddresses || metadata.Destination.IsIP()):
			remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
		case hasAddresses:
			remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses)
		default:
			remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination)
		}
		if err != nil {
			N.CloseOnHandshakeFailure(conn, onClose, err)
			m.logger.ErrorContext(ctx, "open outbound packet connection: ", err)
			return
		}
		remotePacketConn = bufio.NewUnbindPacketConn(remoteConn)
		// Remember the address actually connected to when it is not the one
		// carried in the destination.
		connRemoteAddr := M.AddrFromNet(remoteConn.RemoteAddr())
		if connRemoteAddr != metadata.Destination.Addr {
			destinationAddress = connRemoteAddr
		}
	} else {
		if len(metadata.DestinationAddresses) > 0 {
			remotePacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, this, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay)
		} else {
			remotePacketConn, err = this.ListenPacket(ctx, metadata.Destination)
		}
		if err != nil {
			N.CloseOnHandshakeFailure(conn, onClose, err)
			m.logger.ErrorContext(ctx, "listen outbound packet connection: ", err)
			return
		}
	}

	err = N.ReportPacketConnHandshakeSuccess(conn, remotePacketConn)
	if err != nil {
		conn.Close()
		remotePacketConn.Close()
		// The caller must learn that the connection is gone on this exit as
		// well; every other failure path, and the matching branch of
		// NewConnection, reports through onClose.
		if onClose != nil {
			onClose(err)
		}
		m.logger.ErrorContext(ctx, "report handshake success: ", err)
		return
	}

	if destinationAddress.IsValid() {
		var originDestination M.Socksaddr
		if metadata.RouteOriginalDestination.IsValid() {
			originDestination = metadata.RouteOriginalDestination
		} else {
			originDestination = metadata.Destination
		}
		resolvedDestination := M.SocksaddrFrom(destinationAddress, metadata.Destination.Port)
		if metadata.Destination != resolvedDestination {
			if metadata.UDPDisableDomainUnmapping {
				remotePacketConn = bufio.NewUnidirectionalNATPacketConn(bufio.NewPacketConn(remotePacketConn), resolvedDestination, originDestination)
			} else {
				remotePacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(remotePacketConn), resolvedDestination, originDestination)
			}
		}
		if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded {
			natConn.UpdateDestination(destinationAddress)
		}
	}

	var udpTimeout time.Duration
	if metadata.UDPTimeout > 0 {
		udpTimeout = metadata.UDPTimeout
	} else {
		protocol := metadata.Protocol
		if protocol == "" {
			protocol = C.PortProtocols[metadata.Destination.Port]
		}
		if protocol != "" {
			udpTimeout = C.ProtocolTimeouts[protocol]
		}
	}
	if udpTimeout > 0 {
		ctx, conn = canceler.NewPacketConn(ctx, conn, udpTimeout)
	}

	destination := bufio.NewPacketConn(remotePacketConn)

	// Track the (possibly timeout-wrapped) inbound side so Close can end the
	// relay, and drop the entry again when the relay reports completion.
	m.access.Lock()
	element := m.connections.PushBack(conn)
	m.access.Unlock()
	onClose = N.AppendClose(onClose, func(it error) {
		m.access.Lock()
		defer m.access.Unlock()
		m.connections.Remove(element)
	})

	var done atomic.Bool
	go m.packetConnectionCopy(ctx, conn, destination, false, &done, onClose)
	go m.packetConnectionCopy(ctx, destination, conn, true, &done, onClose)
}
// connectionCopy relays one direction of a TCP connection from source to
// destination. direction only selects the log wording: false is logged as
// "upload", true as "download".
//
// Before the bulk copy, counting wrappers are peeled off both ends and any
// payload a source exposes through N.CachedReader is written out first, with
// the collected counters updated by the number of bytes flushed.
//
// done is shared by the two directions. Swap reports whether the other
// direction already finished, so onClose is invoked by whichever direction
// completes second. A direction that ends cleanly only half-closes the
// destination when it supports CloseWrite; otherwise, or on error, the
// destination is closed.
func (m *ConnectionManager) connectionCopy(ctx context.Context, source io.Reader, destination io.Writer, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
	// Keep the outermost values: those are what must be closed at the end.
	rawSource, rawDestination := source, destination
	var (
		readCounters  []N.CountFunc
		writeCounters []N.CountFunc
	)
	for {
		source, readCounters = N.UnwrapCountReader(source, readCounters)
		destination, writeCounters = N.UnwrapCountWriter(destination, writeCounters)
		cachedSource, isCached := source.(N.CachedReader)
		if !isCached {
			break
		}
		pending := cachedSource.ReadCached()
		if pending == nil {
			// NOTE(review): the loop goes around again and relies on the next
			// unwrap step yielding a different source; if a CachedReader kept
			// returning nil without being unwrapped this would spin — confirm
			// against the N.CachedReader implementations.
			continue
		}
		flushed := int64(pending.Len())
		_, writeErr := destination.Write(pending.Bytes())
		pending.Release()
		if writeErr != nil {
			if done.Swap(true) {
				onClose(writeErr)
			}
			common.Close(rawSource, rawDestination)
			if direction {
				m.logger.ErrorContext(ctx, "connection download payload: ", writeErr)
			} else {
				m.logger.ErrorContext(ctx, "connection upload payload: ", writeErr)
			}
			return
		}
		for _, count := range readCounters {
			count(flushed)
		}
		for _, count := range writeCounters {
			count(flushed)
		}
	}

	_, err := bufio.CopyWithCounters(destination, source, rawSource, readCounters, writeCounters)
	if err == nil {
		if halfCloser, canHalfClose := destination.(N.WriteCloser); canHalfClose {
			if err = halfCloser.CloseWrite(); err != nil {
				common.Close(rawSource, rawDestination)
			}
		} else {
			common.Close(rawDestination)
		}
	} else {
		common.Close(rawDestination)
	}

	// A true result means the opposite direction is already done.
	if done.Swap(true) {
		onClose(err)
		common.Close(rawSource, rawDestination)
	}

	if direction {
		switch {
		case err == nil:
			m.logger.DebugContext(ctx, "connection download finished")
		case !E.IsClosedOrCanceled(err):
			m.logger.ErrorContext(ctx, "connection download closed: ", err)
		default:
			m.logger.TraceContext(ctx, "connection download closed")
		}
		return
	}
	switch {
	case err == nil:
		m.logger.DebugContext(ctx, "connection upload finished")
	case !E.IsClosedOrCanceled(err):
		m.logger.ErrorContext(ctx, "connection upload closed: ", err)
	default:
		m.logger.TraceContext(ctx, "connection upload closed")
	}
}
// packetConnectionCopy relays one direction of a UDP association from source
// to destination until bufio.CopyPacket returns. direction only selects the
// log wording: false is logged as "upload", true as "download".
//
// done is shared by the two directions. The first one to finish (Swap
// returns false) invokes onClose; each direction closes both ends on its way
// out.
func (m *ConnectionManager) packetConnectionCopy(ctx context.Context, source N.PacketReader, destination N.PacketWriter, direction bool, done *atomic.Bool, onClose N.CloseHandlerFunc) {
	_, err := bufio.CopyPacket(destination, source)

	expected := E.IsClosedOrCanceled(err)
	switch {
	case !direction && expected:
		m.logger.TraceContext(ctx, "packet upload closed")
	case !direction:
		m.logger.DebugContext(ctx, "packet upload closed: ", err)
	case expected:
		m.logger.TraceContext(ctx, "packet download closed")
	default:
		m.logger.DebugContext(ctx, "packet download closed: ", err)
	}

	if alreadyDone := done.Swap(true); !alreadyDone {
		onClose(err)
	}
	common.Close(source, destination)
}