From 95f06ab9b9d64c486368537d37cd76f79cbf8d4f Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Sat, 28 Dec 2019 18:44:01 +0800 Subject: [PATCH] Improve: UDP relay refactor (#441) Co-authored-by: Dreamacro --- adapters/inbound/packet.go | 33 +++++++++++++++++++ adapters/outbound/shadowsocks.go | 9 +++-- component/fakeip/pool.go | 20 ++++++++++-- component/socks5/socks5.go | 54 ++++++++++++++++++++++++++++++ constant/adapters.go | 18 ++++++++++ dns/middleware.go | 2 +- dns/resolver.go | 11 ++++++- proxy/socks/udp.go | 7 ++-- proxy/socks/utils.go | 20 ++++++++---- tunnel/connection.go | 24 +++++++------- tunnel/tunnel.go | 56 ++++++++++++++++++++------------ 11 files changed, 202 insertions(+), 52 deletions(-) create mode 100644 adapters/inbound/packet.go diff --git a/adapters/inbound/packet.go b/adapters/inbound/packet.go new file mode 100644 index 00000000..59ccce85 --- /dev/null +++ b/adapters/inbound/packet.go @@ -0,0 +1,33 @@ +package inbound + +import ( + "github.com/Dreamacro/clash/component/socks5" + C "github.com/Dreamacro/clash/constant" +) + +// PacketAdapter is a UDP Packet adapter for socks/redir/tun +type PacketAdapter struct { + C.UDPPacket + metadata *C.Metadata +} + +// Metadata returns destination metadata +func (s *PacketAdapter) Metadata() *C.Metadata { + return s.metadata +} + +// NewPacket is PacketAdapter generator +func NewPacket(target socks5.Addr, packet C.UDPPacket, source C.Type, netType C.NetWork) *PacketAdapter { + metadata := parseSocksAddr(target) + metadata.NetWork = netType + metadata.Type = source + if ip, port, err := parseAddr(packet.LocalAddr().String()); err == nil { + metadata.SrcIP = ip + metadata.SrcPort = port + } + + return &PacketAdapter{ + UDPPacket: packet, + metadata: metadata, + } +} diff --git a/adapters/outbound/shadowsocks.go b/adapters/outbound/shadowsocks.go index 80aa6659..83e1a8b8 100644 --- a/adapters/outbound/shadowsocks.go +++ b/adapters/outbound/shadowsocks.go @@ -201,8 +201,13 @@ func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) { } func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { - n, a, e := uc.PacketConn.ReadFrom(b) + n, _, e := uc.PacketConn.ReadFrom(b) addr := socks5.SplitAddr(b[:n]) + var from net.Addr + if e == nil { + // Get the source IP/Port of packet. + from = addr.UDPAddr() + } copy(b, b[len(addr):]) - return n - len(addr), a, e + return n - len(addr), from, e } diff --git a/component/fakeip/pool.go b/component/fakeip/pool.go index ec84d8f8..b92b55b3 100644 --- a/component/fakeip/pool.go +++ b/component/fakeip/pool.go @@ -62,12 +62,26 @@ func (p *Pool) LookBack(ip net.IP) (string, bool) { return "", false } -// LookupHost return if host in host -func (p *Pool) LookupHost(host string) bool { +// LookupHost return if domain in host +func (p *Pool) LookupHost(domain string) bool { if p.host == nil { return false } - return p.host.Search(host) != nil + return p.host.Search(domain) != nil +} + +// Exist returns if given ip exists in fake-ip pool +func (p *Pool) Exist(ip net.IP) bool { + p.mux.Lock() + defer p.mux.Unlock() + + if ip = ip.To4(); ip == nil { + return false + } + + n := ipToUint(ip.To4()) + offset := n - p.min + 1 + return p.cache.Exist(offset) } // Gateway return gateway ip diff --git a/component/socks5/socks5.go b/component/socks5/socks5.go index 243f34cb..b150d1f7 100644 --- a/component/socks5/socks5.go +++ b/component/socks5/socks5.go @@ -2,6 +2,7 @@ package socks5 import ( "bytes" + "encoding/binary" "errors" "io" "net" @@ -62,6 +63,25 @@ func (a Addr) String() string { return net.JoinHostPort(host, port) } +// UDPAddr converts a socks5.Addr to *net.UDPAddr +func (a Addr) UDPAddr() *net.UDPAddr { + if len(a) == 0 { + return nil + } + switch a[0] { + case AtypIPv4: + var ip [net.IPv4len]byte + copy(ip[0:], a[1:1+net.IPv4len]) + return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))} + case AtypIPv6: + var ip [net.IPv6len]byte + copy(ip[0:], a[1:1+net.IPv6len]) + return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))} + } + // Other Atyp + return nil +} + // SOCKS errors as defined in RFC 1928 section 6. const ( ErrGeneralFailure = Error(1) @@ -338,6 +358,40 @@ func ParseAddr(s string) Addr { return addr } +// ParseAddrToSocksAddr parse a socks addr from net.addr +// This is a fast path of ParseAddr(addr.String()) +func ParseAddrToSocksAddr(addr net.Addr) Addr { + var hostip net.IP + var port int + if udpaddr, ok := addr.(*net.UDPAddr); ok { + hostip = udpaddr.IP + port = udpaddr.Port + } else if tcpaddr, ok := addr.(*net.TCPAddr); ok { + hostip = tcpaddr.IP + port = tcpaddr.Port + } + + // fallback parse + if hostip == nil { + return ParseAddr(addr.String()) + } + + var parsed Addr + if ip4 := hostip.To4(); ip4.DefaultMask() != nil { + parsed = make([]byte, 1+net.IPv4len+2) + parsed[0] = AtypIPv4 + copy(parsed[1:], ip4) + binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port)) + + } else { + parsed = make([]byte, 1+net.IPv6len+2) + parsed[0] = AtypIPv6 + copy(parsed[1:], hostip) + binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port)) + } + return parsed +} + // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet` func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { if len(packet) < 5 { diff --git a/constant/adapters.go b/constant/adapters.go index b51d083b..f05a23fa 100644 --- a/constant/adapters.go +++ b/constant/adapters.go @@ -109,3 +109,21 @@ func (at AdapterType) String() string { return "Unknown" } } + +// UDPPacket contains the data of UDP packet, and offers control/info of UDP packet's source +type UDPPacket interface { + // Data get the payload of UDP Packet + Data() []byte + + // WriteBack writes the payload with source IP/Port equals addr + // - variable source IP/Port is important to STUN + // - if addr is not provided, WriteBack will wirte out UDP packet with SourceIP/Prot equals to origional Target, + // this is important when using Fake-IP. + WriteBack(b []byte, addr net.Addr) (n int, err error) + + // Close closes the underlaying connection. + Close() error + + // LocalAddr returns the source IP/Port of packet + LocalAddr() net.Addr +} diff --git a/dns/middleware.go b/dns/middleware.go index 2a83c648..5aa691ad 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -75,7 +75,7 @@ func compose(middlewares []middleware, endpoint handler) handler { func newHandler(resolver *Resolver) handler { middlewares := []middleware{} - if resolver.IsFakeIP() { + if resolver.FakeIPEnabled() { middlewares = append(middlewares, withFakeIP(resolver.pool)) } diff --git a/dns/resolver.go b/dns/resolver.go index 1b68d7d6..a3540817 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -166,10 +166,19 @@ func (r *Resolver) IsMapping() bool { return r.mapping } -func (r *Resolver) IsFakeIP() bool { +// FakeIPEnabled returns if fake-ip is enabled +func (r *Resolver) FakeIPEnabled() bool { return r.fakeip } +// IsFakeIP determine if given ip is a fake-ip +func (r *Resolver) IsFakeIP(ip net.IP) bool { + if r.FakeIPEnabled() { + return r.pool.Exist(ip) + } + return false +} + func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) { fast, ctx := picker.WithTimeout(context.Background(), time.Second) for _, client := range clients { diff --git a/proxy/socks/udp.go b/proxy/socks/udp.go index 36228ad5..0e26cc14 100644 --- a/proxy/socks/udp.go +++ b/proxy/socks/udp.go @@ -1,7 +1,6 @@ package socks import ( - "bytes" "net" adapters "github.com/Dreamacro/clash/adapters/inbound" @@ -57,12 +56,12 @@ func handleSocksUDP(pc net.PacketConn, buf []byte, addr net.Addr) { pool.BufPool.Put(buf[:cap(buf)]) return } - conn := &fakeConn{ + packet := &fakeConn{ PacketConn: pc, remoteAddr: addr, targetAddr: target, - buffer: bytes.NewBuffer(payload), + payload: payload, bufRef: buf, } - tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.UDP)) + tun.AddPacket(adapters.NewPacket(target, packet, C.SOCKS, C.UDP)) } diff --git a/proxy/socks/utils.go b/proxy/socks/utils.go index ea05961e..8cbe57e9 100644 --- a/proxy/socks/utils.go +++ b/proxy/socks/utils.go @@ -1,7 +1,6 @@ package socks import ( - "bytes" "net" "github.com/Dreamacro/clash/common/pool" @@ -12,23 +11,30 @@ type fakeConn struct { net.PacketConn remoteAddr net.Addr targetAddr socks5.Addr - buffer *bytes.Buffer + payload []byte bufRef []byte } -func (c *fakeConn) Read(b []byte) (n int, err error) { - return c.buffer.Read(b) +func (c *fakeConn) Data() []byte { + return c.payload } -func (c *fakeConn) Write(b []byte) (n int, err error) { - packet, err := socks5.EncodeUDPPacket(c.targetAddr, b) +// WriteBack wirtes UDP packet with source(ip, port) = `addr` +func (c *fakeConn) WriteBack(b []byte, addr net.Addr) (n int, err error) { + from := c.targetAddr + if addr != nil { + // if addr is provided, use the parsed addr + from = socks5.ParseAddrToSocksAddr(addr) + } + packet, err := socks5.EncodeUDPPacket(from, b) if err != nil { return } return c.PacketConn.WriteTo(packet, c.remoteAddr) } -func (c *fakeConn) RemoteAddr() net.Addr { +// LocalAddr returns the source IP/Port of UDP Packet +func (c *fakeConn) LocalAddr() net.Addr { return c.remoteAddr } diff --git a/tunnel/connection.go b/tunnel/connection.go index 2849bd9f..825d0b34 100644 --- a/tunnel/connection.go +++ b/tunnel/connection.go @@ -9,6 +9,8 @@ import ( "time" adapters "github.com/Dreamacro/clash/adapters/inbound" + C "github.com/Dreamacro/clash/constant" + "github.com/Dreamacro/clash/common/pool" ) @@ -79,21 +81,14 @@ func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) { } } -func (t *Tunnel) handleUDPToRemote(conn net.Conn, pc net.PacketConn, addr net.Addr) { - buf := pool.BufPool.Get().([]byte) - defer pool.BufPool.Put(buf[:cap(buf)]) - - n, err := conn.Read(buf) - if err != nil { +func (t *Tunnel) handleUDPToRemote(packet C.UDPPacket, pc net.PacketConn, addr net.Addr) { + if _, err := pc.WriteTo(packet.Data(), addr); err != nil { return } - if _, err = pc.WriteTo(buf[:n], addr); err != nil { - return - } - DefaultManager.Upload() <- int64(n) + DefaultManager.Upload() <- int64(len(packet.Data())) } -func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string, timeout time.Duration) { +func (t *Tunnel) handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, omitSrcAddr bool, timeout time.Duration) { buf := pool.BufPool.Get().([]byte) defer pool.BufPool.Put(buf[:cap(buf)]) defer t.natTable.Delete(key) @@ -101,12 +96,15 @@ func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string, for { pc.SetReadDeadline(time.Now().Add(timeout)) - n, _, err := pc.ReadFrom(buf) + n, from, err := pc.ReadFrom(buf) if err != nil { return } + if from != nil && omitSrcAddr { + from = nil + } - n, err = conn.Write(buf[:n]) + n, err = packet.WriteBack(buf[:n], from) if err != nil { return } diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 69765e70..3a8285bc 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -3,6 +3,7 @@ package tunnel import ( "fmt" "net" + "runtime" "sync" "time" @@ -43,12 +44,12 @@ type Tunnel struct { // Add request to queue func (t *Tunnel) Add(req C.ServerAdapter) { - switch req.Metadata().NetWork { - case C.TCP: - t.tcpQueue.In() <- req - case C.UDP: - t.udpQueue.In() <- req - } + t.tcpQueue.In() <- req +} + +// AddPacket add udp Packet to queue +func (t *Tunnel) AddPacket(packet *inbound.PacketAdapter) { + t.udpQueue.In() <- packet } // Rules return all rules @@ -98,14 +99,23 @@ func (t *Tunnel) SetMode(mode Mode) { t.mode = mode } +// processUDP starts a loop to handle udp packet +func (t *Tunnel) processUDP() { + queue := t.udpQueue.Out() + for elm := range queue { + conn := elm.(*inbound.PacketAdapter) + t.handleUDPConn(conn) + } +} + func (t *Tunnel) process() { - go func() { - queue := t.udpQueue.Out() - for elm := range queue { - conn := elm.(C.ServerAdapter) - t.handleUDPConn(conn) - } - }() + numUDPWorkers := 4 + if runtime.NumCPU() > numUDPWorkers { + numUDPWorkers = runtime.NumCPU() + } + for i := 0; i < numUDPWorkers; i++ { + go t.processUDP() + } queue := t.tcpQueue.Out() for elm := range queue { @@ -119,7 +129,7 @@ func (t *Tunnel) resolveIP(host string) (net.IP, error) { } func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool { - return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.IsFakeIP()) && metadata.Host == "" && metadata.DstIP != nil + return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.FakeIPEnabled()) && metadata.Host == "" && metadata.DstIP != nil } func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { @@ -134,7 +144,7 @@ func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) if exist { metadata.Host = host metadata.AddrType = C.AtypDomainName - if dns.DefaultResolver.IsFakeIP() { + if dns.DefaultResolver.FakeIPEnabled() { metadata.DstIP = nil } } @@ -158,25 +168,28 @@ func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) return proxy, rule, nil } -func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter) { - metadata := localConn.Metadata() +func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) { + metadata := packet.Metadata() if !metadata.Valid() { log.Warnln("[Metadata] not valid: %#v", metadata) return } - src := localConn.RemoteAddr().String() + src := packet.LocalAddr().String() dst := metadata.RemoteAddress() key := src + "-" + dst pc, addr := t.natTable.Get(key) if pc != nil { - t.handleUDPToRemote(localConn, pc, addr) + t.handleUDPToRemote(packet, pc, addr) return } lockKey := key + "-lock" wg, loaded := t.natTable.GetOrCreateLock(lockKey) + + isFakeIP := dns.DefaultResolver.IsFakeIP(metadata.DstIP) + go func() { if !loaded { wg.Add(1) @@ -207,13 +220,14 @@ func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter) { t.natTable.Set(key, pc, addr) t.natTable.Delete(lockKey) wg.Done() - go t.handleUDPToLocal(localConn, pc, key, udpTimeout) + // in fake-ip mode, Full-Cone NAT can never achieve, fallback to omitting src Addr + go t.handleUDPToLocal(packet.UDPPacket, pc, key, isFakeIP, udpTimeout) } wg.Wait() pc, addr := t.natTable.Get(key) if pc != nil { - t.handleUDPToRemote(localConn, pc, addr) + t.handleUDPToRemote(packet, pc, addr) } }() }