From 81e03ec9042110f359389c8bc66352ef6aab2197 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Mon, 28 Nov 2022 18:18:51 +0800 Subject: [PATCH] chore: tuic-server support disassociate command --- listener/tuic/server.go | 8 ++++---- transport/tuic/conn.go | 22 ++++++++++++++-------- transport/tuic/protocol.go | 15 +++++++++++++++ transport/tuic/server.go | 25 ++++++++++++++++++++----- 4 files changed, 53 insertions(+), 17 deletions(-) diff --git a/listener/tuic/server.go b/listener/tuic/server.go index 3ea2dd80..870c9f80 100644 --- a/listener/tuic/server.go +++ b/listener/tuic/server.go @@ -60,13 +60,13 @@ func New(config config.TuicServer, tcpIn chan<- C.ConnContext, udpIn chan<- *inb } option := &tuic.ServerOption{ - HandleTcpFn: func(conn net.Conn, addr string) error { - tcpIn <- inbound.NewSocket(socks5.ParseAddr(addr), conn, C.TUIC) + HandleTcpFn: func(conn net.Conn, addr socks5.Addr) error { + tcpIn <- inbound.NewSocket(addr, conn, C.TUIC) return nil }, - HandleUdpFn: func(addr *net.UDPAddr, packet C.UDPPacket) error { + HandleUdpFn: func(addr socks5.Addr, packet C.UDPPacket) error { select { - case udpIn <- inbound.NewPacket(socks5.ParseAddrToSocksAddr(addr), packet, C.TUIC): + case udpIn <- inbound.NewPacket(addr, packet, C.TUIC): default: } return nil diff --git a/transport/tuic/conn.go b/transport/tuic/conn.go index 4f914e26..b030c3d8 100644 --- a/transport/tuic/conn.go +++ b/transport/tuic/conn.go @@ -5,6 +5,7 @@ import ( "net" "net/netip" "sync" + "sync/atomic" "time" "github.com/metacubex/quic-go" @@ -115,6 +116,7 @@ type quicStreamPacketConn struct { deferQuicConnFn func(quicConn quic.Connection, err error) closeDeferFn func() + writeClosed *atomic.Bool closeOnce sync.Once closeErr error @@ -133,11 +135,11 @@ func (q *quicStreamPacketConn) close() (err error) { if q.closeDeferFn != nil { defer q.closeDeferFn() } - defer func() { - if q.deferQuicConnFn != nil { + if q.deferQuicConnFn != nil { + defer func() { q.deferQuicConnFn(q.quicConn, err) - } - }() + }() + } if q.inputConn != nil { _ = q.inputConn.Close() q.inputConn = nil @@ -204,11 +206,15 @@ func (q *quicStreamPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err erro if q.closed { return 0, net.ErrClosed } - defer func() { - if q.deferQuicConnFn != nil { + if q.writeClosed != nil && q.writeClosed.Load() { + _ = q.Close() + return 0, net.ErrClosed + } + if q.deferQuicConnFn != nil { + defer func() { q.deferQuicConnFn(q.quicConn, err) - } - }() + }() + } addr.String() buf := pool.GetBuffer() defer pool.PutBuffer(buf) diff --git a/transport/tuic/protocol.go b/transport/tuic/protocol.go index 66a01e54..f7456e95 100644 --- a/transport/tuic/protocol.go +++ b/transport/tuic/protocol.go @@ -554,6 +554,21 @@ func (c Address) String() string { } } +func (c Address) SocksAddr() socks5.Addr { + addr := make([]byte, 1+len(c.ADDR)+2) + switch c.TYPE { + case AtypIPv4: + addr[0] = socks5.AtypIPv4 + case AtypIPv6: + addr[0] = socks5.AtypIPv6 + case AtypDomainName: + addr[0] = socks5.AtypDomainName + } + copy(addr[1:], c.ADDR) + binary.BigEndian.PutUint16(addr[len(addr)-2:], c.PORT) + return addr +} + func (c Address) UDPAddr() *net.UDPAddr { return &net.UDPAddr{ IP: c.ADDR, diff --git a/transport/tuic/server.go b/transport/tuic/server.go index 4da86f9b..07b4c6c1 100644 --- a/transport/tuic/server.go +++ b/transport/tuic/server.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" "time" "github.com/gofrs/uuid" @@ -16,11 +17,12 @@ import ( N "github.com/Dreamacro/clash/common/net" "github.com/Dreamacro/clash/common/pool" C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/transport/socks5" ) type ServerOption struct { - HandleTcpFn func(conn net.Conn, addr string) error - HandleUdpFn func(addr *net.UDPAddr, packet C.UDPPacket) error + HandleTcpFn func(conn net.Conn, addr socks5.Addr) error + HandleUdpFn func(addr socks5.Addr, packet C.UDPPacket) error TlsConfig *tls.Config QuicConfig *quic.Config @@ -78,6 +80,8 @@ type serverHandler struct { authCh chan struct{} authOk bool authOnce sync.Once + + udpInputMap sync.Map } func (s *serverHandler) handle() { @@ -125,6 +129,13 @@ func (s *serverHandler) parsePacket(packet Packet, udpRelayMode string) (err err var assocId uint32 assocId = packet.ASSOC_ID + + v, _ := s.udpInputMap.LoadOrStore(assocId, &atomic.Bool{}) + writeClosed := v.(*atomic.Bool) + if writeClosed.Load() { + return nil + } + pc := &quicStreamPacketConn{ connId: assocId, quicConn: s.quicConn, @@ -135,9 +146,10 @@ func (s *serverHandler) parsePacket(packet Packet, udpRelayMode string) (err err ref: s, deferQuicConnFn: nil, closeDeferFn: nil, + writeClosed: writeClosed, } - return s.HandleUdpFn(packet.ADDR.UDPAddr(), &serverUDPPacket{ + return s.HandleUdpFn(packet.ADDR.SocksAddr(), &serverUDPPacket{ pc: pc, packet: &packet, rAddr: s.genServerAssocIdAddr(assocId), @@ -175,7 +187,7 @@ func (s *serverHandler) handleStream() (err error) { buf := pool.GetBuffer() defer pool.PutBuffer(buf) - err = s.HandleTcpFn(conn, connect.ADDR.String()) + err = s.HandleTcpFn(conn, connect.ADDR.SocksAddr()) if err != nil { err = NewResponseFailed().WriteTo(buf) } @@ -245,7 +257,10 @@ func (s *serverHandler) handleUniStream() (err error) { if err != nil { return } - disassociate.BytesLen() + if v, loaded := s.udpInputMap.LoadAndDelete(disassociate.ASSOC_ID); loaded { + writeClosed := v.(*atomic.Bool) + writeClosed.Store(true) + } } return }()