From 2adb586a78f7b9ce3570423dacff3f32e0d8ad8f Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Fri, 11 Oct 2019 20:11:18 +0800 Subject: [PATCH] Fix: some UDP issues (#265) --- adapters/outbound/direct.go | 2 +- adapters/outbound/http.go | 2 +- adapters/outbound/shadowsocks.go | 22 ++--- adapters/outbound/socks5.go | 17 ++-- adapters/outbound/util.go | 13 --- adapters/outbound/vmess.go | 15 ++- component/nat-table/nat.go | 98 ------------------- component/nat/table.go | 46 +++++++++ component/socks5/socks5.go | 12 +-- constant/metadata.go | 11 ++- proxy/socks/udp.go | 27 +++--- proxy/socks/utils.go | 22 ++--- tunnel/connection.go | 5 +- tunnel/session.go | 22 ----- tunnel/tunnel.go | 157 ++++++++++++++++++++----------- 15 files changed, 228 insertions(+), 243 deletions(-) delete mode 100644 component/nat-table/nat.go create mode 100644 component/nat/table.go delete mode 100644 tunnel/session.go diff --git a/adapters/outbound/direct.go b/adapters/outbound/direct.go index 0f3e6dc4..2b0a2a47 100644 --- a/adapters/outbound/direct.go +++ b/adapters/outbound/direct.go @@ -30,7 +30,7 @@ func (d *Direct) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { return nil, nil, err } - addr, err := resolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) + addr, err := resolveUDPAddr("udp", metadata.RemoteAddress()) if err != nil { return nil, nil, err } diff --git a/adapters/outbound/http.go b/adapters/outbound/http.go index 12b80a5b..357ed5df 100644 --- a/adapters/outbound/http.go +++ b/adapters/outbound/http.go @@ -58,7 +58,7 @@ func (h *Http) shakeHand(metadata *C.Metadata, rw io.ReadWriter) error { var buf bytes.Buffer var err error - addr := net.JoinHostPort(metadata.String(), metadata.DstPort) + addr := metadata.RemoteAddress() buf.WriteString("CONNECT " + addr + " HTTP/1.1\r\n") buf.WriteString("Host: " + metadata.String() + "\r\n") buf.WriteString("Proxy-Connection: Keep-Alive\r\n") diff --git a/adapters/outbound/shadowsocks.go b/adapters/outbound/shadowsocks.go index 58c76bc9..c46f8fc9 100644 --- a/adapters/outbound/shadowsocks.go +++ b/adapters/outbound/shadowsocks.go @@ -7,7 +7,6 @@ import ( "net" "strconv" - "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/structure" obfs "github.com/Dreamacro/clash/component/simple-obfs" "github.com/Dreamacro/clash/component/socks5" @@ -93,9 +92,9 @@ func (ss *ShadowSocks) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, er return nil, nil, err } - targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) - if err != nil { - return nil, nil, err + targetAddr := socks5.ParseAddr(metadata.RemoteAddress()) + if targetAddr == nil { + return nil, nil, fmt.Errorf("parse address error: %v:%v", metadata.String(), metadata.DstPort) } pc = ss.cipher.PacketConn(pc) @@ -189,16 +188,15 @@ func NewShadowSocks(option ShadowSocksOption) (*ShadowSocks, error) { type ssUDPConn struct { net.PacketConn - rAddr net.Addr + rAddr socks5.Addr } -func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { - buf := pool.BufPool.Get().([]byte) - defer pool.BufPool.Put(buf[:cap(buf)]) - rAddr := socks5.ParseAddr(uc.rAddr.String()) - copy(buf[len(rAddr):], b) - copy(buf, rAddr) - return uc.PacketConn.WriteTo(buf[:len(rAddr)+len(b)], addr) +func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { + packet, err := socks5.EncodeUDPPacket(uc.rAddr, b) + if err != nil { + return + } + return uc.PacketConn.WriteTo(packet[3:], addr) } func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { diff --git a/adapters/outbound/socks5.go b/adapters/outbound/socks5.go index a72c9386..d99daa93 100644 --- a/adapters/outbound/socks5.go +++ b/adapters/outbound/socks5.go @@ -98,9 +98,9 @@ func (ss *Socks5) DialUDP(metadata *C.Metadata) (_ C.PacketConn, _ net.Addr, err return } - targetAddr, err := net.ResolveUDPAddr("udp", net.JoinHostPort(metadata.String(), metadata.DstPort)) - if err != nil { - return + targetAddr := socks5.ParseAddr(metadata.RemoteAddress()) + if targetAddr == nil { + return nil, nil, fmt.Errorf("parse address error: %v:%v", metadata.String(), metadata.DstPort) } pc, err := net.ListenPacket("udp", "") @@ -146,12 +146,12 @@ func NewSocks5(option Socks5Option) *Socks5 { type socksUDPConn struct { net.PacketConn - rAddr net.Addr + rAddr socks5.Addr tcpConn net.Conn } func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { - packet, err := socks5.EncodeUDPPacket(uc.rAddr.String(), b) + packet, err := socks5.EncodeUDPPacket(uc.rAddr, b) if err != nil { return } @@ -160,12 +160,17 @@ func (uc *socksUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { func (uc *socksUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { n, a, e := uc.PacketConn.ReadFrom(b) + if e != nil { + return 0, nil, e + } addr, payload, err := socks5.DecodeUDPPacket(b) if err != nil { return 0, nil, err } + // due to DecodeUDPPacket is mutable, record addr length + addrLength := len(addr) copy(b, payload) - return n - len(addr) - 3, a, e + return n - addrLength - 3, a, nil } func (uc *socksUDPConn) Close() error { diff --git a/adapters/outbound/util.go b/adapters/outbound/util.go index 8d617968..46c4581a 100644 --- a/adapters/outbound/util.go +++ b/adapters/outbound/util.go @@ -86,19 +86,6 @@ func serializesSocksAddr(metadata *C.Metadata) []byte { return bytes.Join(buf, nil) } -type fakeUDPConn struct { - net.Conn -} - -func (fuc *fakeUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { - return fuc.Conn.Write(b) -} - -func (fuc *fakeUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, err := fuc.Conn.Read(b) - return n, fuc.RemoteAddr(), err -} - func dialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { host, port, err := net.SplitHostPort(address) if err != nil { diff --git a/adapters/outbound/vmess.go b/adapters/outbound/vmess.go index 0121feb4..5b5337a0 100644 --- a/adapters/outbound/vmess.go +++ b/adapters/outbound/vmess.go @@ -51,7 +51,7 @@ func (v *Vmess) DialUDP(metadata *C.Metadata) (C.PacketConn, net.Addr, error) { if err != nil { return nil, nil, fmt.Errorf("new vmess client error: %v", err) } - return newPacketConn(&fakeUDPConn{Conn: c}, v), c.RemoteAddr(), nil + return newPacketConn(&vmessUDPConn{Conn: c}, v), c.RemoteAddr(), nil } func NewVmess(option VmessOption) (*Vmess, error) { @@ -111,3 +111,16 @@ func parseVmessAddr(metadata *C.Metadata) *vmess.DstAddr { Port: uint(port), } } + +type vmessUDPConn struct { + net.Conn +} + +func (uc *vmessUDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { + return uc.Conn.Write(b) +} + +func (uc *vmessUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { + n, err := uc.Conn.Read(b) + return n, uc.RemoteAddr(), err +} diff --git a/component/nat-table/nat.go b/component/nat-table/nat.go deleted file mode 100644 index 06b38671..00000000 --- a/component/nat-table/nat.go +++ /dev/null @@ -1,98 +0,0 @@ -package nat - -import ( - "net" - "runtime" - "sync" - "time" -) - -type Table struct { - *table -} - -type table struct { - mapping sync.Map - janitor *janitor - timeout time.Duration -} - -type element struct { - Expired time.Time - RemoteAddr net.Addr - RemoteConn net.PacketConn -} - -func (t *table) Set(key net.Addr, rConn net.PacketConn, rAddr net.Addr) { - // set conn read timeout - rConn.SetReadDeadline(time.Now().Add(t.timeout)) - t.mapping.Store(key, &element{ - RemoteAddr: rAddr, - RemoteConn: rConn, - Expired: time.Now().Add(t.timeout), - }) -} - -func (t *table) Get(key net.Addr) (rConn net.PacketConn, rAddr net.Addr) { - item, exist := t.mapping.Load(key) - if !exist { - return - } - elm := item.(*element) - // expired - if time.Since(elm.Expired) > 0 { - t.mapping.Delete(key) - elm.RemoteConn.Close() - return - } - // reset expired time - elm.Expired = time.Now().Add(t.timeout) - return elm.RemoteConn, elm.RemoteAddr -} - -func (t *table) cleanup() { - t.mapping.Range(func(k, v interface{}) bool { - key := k.(net.Addr) - elm := v.(*element) - if time.Since(elm.Expired) > 0 { - t.mapping.Delete(key) - elm.RemoteConn.Close() - } - return true - }) -} - -type janitor struct { - interval time.Duration - stop chan struct{} -} - -func (j *janitor) process(t *table) { - ticker := time.NewTicker(j.interval) - for { - select { - case <-ticker.C: - t.cleanup() - case <-j.stop: - ticker.Stop() - return - } - } -} - -func stopJanitor(t *Table) { - t.janitor.stop <- struct{}{} -} - -// New return *Cache -func New(interval time.Duration) *Table { - j := &janitor{ - interval: interval, - stop: make(chan struct{}), - } - t := &table{janitor: j, timeout: interval} - go j.process(t) - T := &Table{t} - runtime.SetFinalizer(T, stopJanitor) - return T -} diff --git a/component/nat/table.go b/component/nat/table.go new file mode 100644 index 00000000..eac98467 --- /dev/null +++ b/component/nat/table.go @@ -0,0 +1,46 @@ +package nat + +import ( + "net" + "sync" +) + +type Table struct { + mapping sync.Map +} + +type element struct { + RemoteAddr net.Addr + RemoteConn net.PacketConn +} + +func (t *Table) Set(key string, pc net.PacketConn, addr net.Addr) { + // set conn read timeout + t.mapping.Store(key, &element{ + RemoteConn: pc, + RemoteAddr: addr, + }) +} + +func (t *Table) Get(key string) (net.PacketConn, net.Addr) { + item, exist := t.mapping.Load(key) + if !exist { + return nil, nil + } + elm := item.(*element) + return elm.RemoteConn, elm.RemoteAddr +} + +func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) { + item, loaded := t.mapping.LoadOrStore(key, &sync.WaitGroup{}) + return item.(*sync.WaitGroup), loaded +} + +func (t *Table) Delete(key string) { + t.mapping.Delete(key) +} + +// New return *Cache +func New() *Table { + return &Table{} +} diff --git a/component/socks5/socks5.go b/component/socks5/socks5.go index 1dd60391..243f34cb 100644 --- a/component/socks5/socks5.go +++ b/component/socks5/socks5.go @@ -338,6 +338,7 @@ func ParseAddr(s string) Addr { return addr } +// DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet` func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { if len(packet) < 5 { err = errors.New("insufficient length of packet") @@ -360,16 +361,15 @@ func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { err = errors.New("failed to read UDP header") } - payload = bytes.Join([][]byte{packet[3+len(addr):]}, []byte{}) + payload = packet[3+len(addr):] return } -func EncodeUDPPacket(addr string, payload []byte) (packet []byte, err error) { - rAddr := ParseAddr(addr) - if rAddr == nil { - err = errors.New("cannot parse addr") +func EncodeUDPPacket(addr Addr, payload []byte) (packet []byte, err error) { + if addr == nil { + err = errors.New("address is invalid") return } - packet = bytes.Join([][]byte{{0, 0, 0}, rAddr, payload}, []byte{}) + packet = bytes.Join([][]byte{{0, 0, 0}, addr, payload}, []byte{}) return } diff --git a/constant/metadata.go b/constant/metadata.go index f95bea58..cc870361 100644 --- a/constant/metadata.go +++ b/constant/metadata.go @@ -41,11 +41,18 @@ type Metadata struct { Host string } +func (m *Metadata) RemoteAddress() string { + return net.JoinHostPort(m.String(), m.DstPort) +} + func (m *Metadata) String() string { - if m.Host == "" { + if m.Host != "" { + return m.Host + } else if m.DstIP != nil { return m.DstIP.String() + } else { + return "" } - return m.Host } func (m *Metadata) Valid() bool { diff --git a/proxy/socks/udp.go b/proxy/socks/udp.go index 63b63ba5..36228ad5 100644 --- a/proxy/socks/udp.go +++ b/proxy/socks/udp.go @@ -1,17 +1,13 @@ package socks import ( + "bytes" "net" adapters "github.com/Dreamacro/clash/adapters/inbound" "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/component/socks5" C "github.com/Dreamacro/clash/constant" - "github.com/Dreamacro/clash/tunnel" -) - -var ( - _ = tunnel.NATInstance() ) type SockUDPListener struct { @@ -28,17 +24,17 @@ func NewSocksUDPProxy(addr string) (*SockUDPListener, error) { sl := &SockUDPListener{l, addr, false} go func() { - buf := pool.BufPool.Get().([]byte) - defer pool.BufPool.Put(buf[:cap(buf)]) for { + buf := pool.BufPool.Get().([]byte) n, remoteAddr, err := l.ReadFrom(buf) if err != nil { + pool.BufPool.Put(buf[:cap(buf)]) if sl.closed { break } continue } - go handleSocksUDP(l, buf[:n], remoteAddr) + handleSocksUDP(l, buf[:n], remoteAddr) } }() @@ -54,12 +50,19 @@ func (l *SockUDPListener) Address() string { return l.address } -func handleSocksUDP(c net.PacketConn, packet []byte, remoteAddr net.Addr) { - target, payload, err := socks5.DecodeUDPPacket(packet) +func handleSocksUDP(pc net.PacketConn, buf []byte, addr net.Addr) { + target, payload, err := socks5.DecodeUDPPacket(buf) if err != nil { - // Unresolved UDP packet, do nothing + // Unresolved UDP packet, return buffer to the pool + pool.BufPool.Put(buf[:cap(buf)]) return } - conn := newfakeConn(c, target.String(), remoteAddr, payload) + conn := &fakeConn{ + PacketConn: pc, + remoteAddr: addr, + targetAddr: target, + buffer: bytes.NewBuffer(payload), + bufRef: buf, + } tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.UDP)) } diff --git a/proxy/socks/utils.go b/proxy/socks/utils.go index 2ecee3c3..ea05961e 100644 --- a/proxy/socks/utils.go +++ b/proxy/socks/utils.go @@ -4,24 +4,16 @@ import ( "bytes" "net" + "github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/component/socks5" ) type fakeConn struct { net.PacketConn - target string remoteAddr net.Addr + targetAddr socks5.Addr buffer *bytes.Buffer -} - -func newfakeConn(conn net.PacketConn, target string, remoteAddr net.Addr, buf []byte) *fakeConn { - buffer := bytes.NewBuffer(buf) - return &fakeConn{ - PacketConn: conn, - target: target, - buffer: buffer, - remoteAddr: remoteAddr, - } + bufRef []byte } func (c *fakeConn) Read(b []byte) (n int, err error) { @@ -29,7 +21,7 @@ func (c *fakeConn) Read(b []byte) (n int, err error) { } func (c *fakeConn) Write(b []byte) (n int, err error) { - packet, err := socks5.EncodeUDPPacket(c.target, b) + packet, err := socks5.EncodeUDPPacket(c.targetAddr, b) if err != nil { return } @@ -39,3 +31,9 @@ func (c *fakeConn) Write(b []byte) (n int, err error) { func (c *fakeConn) RemoteAddr() net.Addr { return c.remoteAddr } + +func (c *fakeConn) Close() error { + err := c.PacketConn.Close() + pool.BufPool.Put(c.bufRef[:cap(c.bufRef)]) + return err +} diff --git a/tunnel/connection.go b/tunnel/connection.go index b5beb043..e48167f2 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -86,11 +86,14 @@ func (t *Tunnel) handleUDPToRemote(conn net.Conn, pc net.PacketConn, addr net.Ad t.traffic.Up() <- int64(n) } -func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn) { +func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string, timeout time.Duration) { buf := pool.BufPool.Get().([]byte) defer pool.BufPool.Put(buf[:cap(buf)]) + defer t.natTable.Delete(key) + defer pc.Close() for { + pc.SetReadDeadline(time.Now().Add(timeout)) n, _, err := pc.ReadFrom(buf) if err != nil { return diff --git a/tunnel/session.go b/tunnel/session.go deleted file mode 100644 index 4433deae..00000000 --- a/tunnel/session.go +++ /dev/null @@ -1,22 +0,0 @@ -package tunnel - -import ( - "sync" - "time" - - nat "github.com/Dreamacro/clash/component/nat-table" -) - -var ( - natTable *nat.Table - natOnce sync.Once - - natTimeout = 120 * time.Second -) - -func NATInstance() *nat.Table { - natOnce.Do(func() { - natTable = nat.New(natTimeout) - }) - return natTable -} diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 87aafec3..32647f1a 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -7,6 +7,7 @@ import ( "time" InboundAdapter "github.com/Dreamacro/clash/adapters/inbound" + "github.com/Dreamacro/clash/component/nat" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/dns" "github.com/Dreamacro/clash/log" @@ -17,11 +18,16 @@ import ( var ( tunnel *Tunnel once sync.Once + + // default timeout for UDP session + udpTimeout = 60 * time.Second ) // Tunnel handle relay inbound proxy and outbound proxy type Tunnel struct { - queue *channels.InfiniteChannel + tcpQueue *channels.InfiniteChannel + udpQueue *channels.InfiniteChannel + natTable *nat.Table rules []C.Rule proxies map[string]C.Proxy configMux *sync.RWMutex @@ -36,7 +42,12 @@ type Tunnel struct { // Add request to queue func (t *Tunnel) Add(req C.ServerAdapter) { - t.queue.In() <- req + switch req.Metadata().NetWork { + case C.TCP: + t.tcpQueue.In() <- req + case C.UDP: + t.udpQueue.In() <- req + } } // Traffic return traffic of all connections @@ -86,11 +97,18 @@ func (t *Tunnel) SetMode(mode Mode) { } func (t *Tunnel) process() { - queue := t.queue.Out() - for { - elm := <-queue + go func() { + queue := t.udpQueue.Out() + for elm := range queue { + conn := elm.(C.ServerAdapter) + t.handleUDPConn(conn) + } + }() + + queue := t.tcpQueue.Out() + for elm := range queue { conn := elm.(C.ServerAdapter) - go t.handleConn(conn) + go t.handleTCPConn(conn) } } @@ -102,26 +120,7 @@ func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool { return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.IsFakeIP()) && metadata.Host == "" && metadata.DstIP != nil } -func (t *Tunnel) handleConn(localConn C.ServerAdapter) { - defer func() { - var conn net.Conn - switch adapter := localConn.(type) { - case *InboundAdapter.HTTPAdapter: - conn = adapter.Conn - case *InboundAdapter.SocketAdapter: - conn = adapter.Conn - } - if _, ok := conn.(*net.TCPConn); ok { - localConn.Close() - } - }() - - metadata := localConn.Metadata() - if !metadata.Valid() { - log.Warnln("[Metadata] not valid: %#v", metadata) - return - } - +func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { // preprocess enhanced-mode metadata if t.needLookupIP(metadata) { host, exist := dns.DefaultResolver.IPToHost(*metadata.DstIP) @@ -146,43 +145,87 @@ func (t *Tunnel) handleConn(localConn C.ServerAdapter) { var err error proxy, rule, err = t.match(metadata) if err != nil { - return + return nil, nil, err } } - - switch metadata.NetWork { - case C.TCP: - t.handleTCPConn(localConn, metadata, proxy, rule) - case C.UDP: - t.handleUDPConn(localConn, metadata, proxy, rule) - } + return proxy, rule, nil } -func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter, metadata *C.Metadata, proxy C.Proxy, rule C.Rule) { - pc, addr := natTable.Get(localConn.RemoteAddr()) - if pc == nil { - rawPc, nAddr, err := proxy.DialUDP(metadata) - addr = nAddr - pc = rawPc - if err != nil { - log.Warnln("dial %s error: %s", proxy.Name(), err.Error()) - return - } - - if rule != nil { - log.Infoln("%s --> %v match %s using %s", metadata.SrcIP.String(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String()) - } else { - log.Infoln("%s --> %v doesn't match any rule using DIRECT", metadata.SrcIP.String(), metadata.String()) - } - - natTable.Set(localConn.RemoteAddr(), pc, addr) - go t.handleUDPToLocal(localConn, pc) +func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter) { + metadata := localConn.Metadata() + if !metadata.Valid() { + log.Warnln("[Metadata] not valid: %#v", metadata) + return } - t.handleUDPToRemote(localConn, pc, addr) + src := localConn.RemoteAddr().String() + dst := metadata.RemoteAddress() + key := src + "-" + dst + + pc, addr := t.natTable.Get(key) + if pc != nil { + t.handleUDPToRemote(localConn, pc, addr) + return + } + + lockKey := key + "-lock" + wg, loaded := t.natTable.GetOrCreateLock(lockKey) + go func() { + if !loaded { + wg.Add(1) + proxy, rule, err := t.resolveMetadata(metadata) + if err != nil { + log.Warnln("Parse metadata failed: %s", err.Error()) + t.natTable.Delete(lockKey) + wg.Done() + return + } + + rawPc, nAddr, err := proxy.DialUDP(metadata) + if err != nil { + log.Warnln("dial %s error: %s", proxy.Name(), err.Error()) + t.natTable.Delete(lockKey) + wg.Done() + return + } + pc = rawPc + addr = nAddr + + if rule != nil { + log.Infoln("%s --> %v match %s using %s", metadata.SrcIP.String(), metadata.String(), rule.RuleType().String(), rawPc.Chains().String()) + } else { + log.Infoln("%s --> %v doesn't match any rule using DIRECT", metadata.SrcIP.String(), metadata.String()) + } + + t.natTable.Set(key, pc, addr) + t.natTable.Delete(lockKey) + wg.Done() + go t.handleUDPToLocal(localConn, pc, key, udpTimeout) + } + + wg.Wait() + pc, addr := t.natTable.Get(key) + if pc != nil { + t.handleUDPToRemote(localConn, pc, addr) + } + }() } -func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter, metadata *C.Metadata, proxy C.Proxy, rule C.Rule) { +func (t *Tunnel) handleTCPConn(localConn C.ServerAdapter) { + defer localConn.Close() + + metadata := localConn.Metadata() + if !metadata.Valid() { + log.Warnln("[Metadata] not valid: %#v", metadata) + return + } + + proxy, rule, err := t.resolveMetadata(metadata) + if err != nil { + log.Warnln("Parse metadata failed: %v", err) + return + } + remoteConn, err := proxy.Dial(metadata) if err != nil { log.Warnln("dial %s error: %s", proxy.Name(), err.Error()) @@ -253,7 +296,9 @@ func (t *Tunnel) match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { func newTunnel() *Tunnel { return &Tunnel{ - queue: channels.NewInfiniteChannel(), + tcpQueue: channels.NewInfiniteChannel(), + udpQueue: channels.NewInfiniteChannel(), + natTable: nat.New(), proxies: make(map[string]C.Proxy), configMux: &sync.RWMutex{}, traffic: C.NewTraffic(time.Second),