diff --git a/inbound/tproxy.go b/inbound/tproxy.go index 55dcac56..ff12b93c 100644 --- a/inbound/tproxy.go +++ b/inbound/tproxy.go @@ -12,6 +12,7 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" + "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -54,21 +55,17 @@ func (t *TProxy) Start() error { return err } if t.tcpListener != nil { - tcpFd, err := common.GetFileDescriptor(t.tcpListener) - if err != nil { - return err - } - err = redir.TProxy(tcpFd, M.SocksaddrFromNet(t.tcpListener.Addr()).Addr.Is6()) + err = control.Conn(t.tcpListener, func(fd uintptr) error { + return redir.TProxy(fd, M.SocksaddrFromNet(t.tcpListener.Addr()).Addr.Is6()) + }) if err != nil { return E.Cause(err, "configure tproxy TCP listener") } } if t.udpConn != nil { - udpFd, err := common.GetFileDescriptor(t.udpConn) - if err != nil { - return err - } - err = redir.TProxy(udpFd, M.SocksaddrFromNet(t.udpConn.LocalAddr()).Addr.Is6()) + err = control.Conn(t.udpConn, func(fd uintptr) error { + return redir.TProxy(fd, M.SocksaddrFromNet(t.udpConn.LocalAddr()).Addr.Is6()) + }) if err != nil { return E.Cause(err, "configure tproxy UDP listener") } @@ -88,21 +85,40 @@ func (t *TProxy) NewPacket(ctx context.Context, conn N.PacketConn, buffer *buf.B } metadata.Destination = M.SocksaddrFromNetIP(destination) t.udpNat.NewContextPacket(ctx, metadata.Source.AddrPort(), buffer, adapter.UpstreamMetadata(metadata), func(natConn N.PacketConn) (context.Context, N.PacketWriter) { - return adapter.WithContext(log.ContextWithNewID(ctx), &metadata), &tproxyPacketWriter{natConn} + return adapter.WithContext(log.ContextWithNewID(ctx), &metadata), &tproxyPacketWriter{source: natConn} }) return nil } type tproxyPacketWriter struct { - source N.PacketConn + source N.PacketConn + destination M.Socksaddr + conn *net.UDPConn } func (w *tproxyPacketWriter) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { defer buffer.Release() - udpConn, err := redir.DialUDP(destination.UDPAddr(), M.SocksaddrFromNet(w.source.LocalAddr()).UDPAddr()) - if err != nil { - return E.Cause(err, "tproxy udp write back") + var udpConn *net.UDPConn + if w.destination == destination { + if w.conn != nil { + udpConn = w.conn + } + } + if udpConn == nil { + var err error + udpConn, err = redir.DialUDP(destination.UDPAddr(), M.SocksaddrFromNet(w.source.LocalAddr()).UDPAddr()) + if err != nil { + return E.Cause(err, "tproxy udp write back") + } + if w.destination == destination { + w.conn = udpConn + } else { + defer udpConn.Close() + } } - defer udpConn.Close() return common.Error(udpConn.Write(buffer.Bytes())) } + +func (w *tproxyPacketWriter) Close() error { + return common.Close(common.PtrOrNil(w.conn)) +}