Improve: UDP relay refactor (#441)

Co-authored-by: Dreamacro <Dreamacro@vip.qq.com>
This commit is contained in:
gVisor bot 2019-12-28 18:44:01 +08:00
parent 7fd4645237
commit 95f06ab9b9
11 changed files with 202 additions and 52 deletions

View File

@ -0,0 +1,33 @@
package inbound
import (
"github.com/Dreamacro/clash/component/socks5"
C "github.com/Dreamacro/clash/constant"
)
// PacketAdapter is a UDP Packet adapter for socks/redir/tun
type PacketAdapter struct {
C.UDPPacket
metadata *C.Metadata
}
// Metadata returns destination metadata
func (s *PacketAdapter) Metadata() *C.Metadata {
return s.metadata
}
// NewPacket is PacketAdapter generator
func NewPacket(target socks5.Addr, packet C.UDPPacket, source C.Type, netType C.NetWork) *PacketAdapter {
metadata := parseSocksAddr(target)
metadata.NetWork = netType
metadata.Type = source
if ip, port, err := parseAddr(packet.LocalAddr().String()); err == nil {
metadata.SrcIP = ip
metadata.SrcPort = port
}
return &PacketAdapter{
UDPPacket: packet,
metadata: metadata,
}
}

View File

@ -201,8 +201,13 @@ func (uc *ssUDPConn) WriteTo(b []byte, addr net.Addr) (n int, err error) {
} }
func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) { func (uc *ssUDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, a, e := uc.PacketConn.ReadFrom(b) n, _, e := uc.PacketConn.ReadFrom(b)
addr := socks5.SplitAddr(b[:n]) addr := socks5.SplitAddr(b[:n])
var from net.Addr
if e == nil {
// Get the source IP/Port of packet.
from = addr.UDPAddr()
}
copy(b, b[len(addr):]) copy(b, b[len(addr):])
return n - len(addr), a, e return n - len(addr), from, e
} }

View File

@ -62,12 +62,26 @@ func (p *Pool) LookBack(ip net.IP) (string, bool) {
return "", false return "", false
} }
// LookupHost return if host in host // LookupHost return if domain in host
func (p *Pool) LookupHost(host string) bool { func (p *Pool) LookupHost(domain string) bool {
if p.host == nil { if p.host == nil {
return false return false
} }
return p.host.Search(host) != nil return p.host.Search(domain) != nil
}
// Exist returns if given ip exists in fake-ip pool
func (p *Pool) Exist(ip net.IP) bool {
p.mux.Lock()
defer p.mux.Unlock()
if ip = ip.To4(); ip == nil {
return false
}
n := ipToUint(ip.To4())
offset := n - p.min + 1
return p.cache.Exist(offset)
} }
// Gateway return gateway ip // Gateway return gateway ip

View File

@ -2,6 +2,7 @@ package socks5
import ( import (
"bytes" "bytes"
"encoding/binary"
"errors" "errors"
"io" "io"
"net" "net"
@ -62,6 +63,25 @@ func (a Addr) String() string {
return net.JoinHostPort(host, port) return net.JoinHostPort(host, port)
} }
// UDPAddr converts a socks5.Addr to *net.UDPAddr
func (a Addr) UDPAddr() *net.UDPAddr {
if len(a) == 0 {
return nil
}
switch a[0] {
case AtypIPv4:
var ip [net.IPv4len]byte
copy(ip[0:], a[1:1+net.IPv4len])
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv4len : 1+net.IPv4len+2]))}
case AtypIPv6:
var ip [net.IPv6len]byte
copy(ip[0:], a[1:1+net.IPv6len])
return &net.UDPAddr{IP: net.IP(ip[:]), Port: int(binary.BigEndian.Uint16(a[1+net.IPv6len : 1+net.IPv6len+2]))}
}
// Other Atyp
return nil
}
// SOCKS errors as defined in RFC 1928 section 6. // SOCKS errors as defined in RFC 1928 section 6.
const ( const (
ErrGeneralFailure = Error(1) ErrGeneralFailure = Error(1)
@ -338,6 +358,40 @@ func ParseAddr(s string) Addr {
return addr return addr
} }
// ParseAddrToSocksAddr parse a socks addr from net.addr
// This is a fast path of ParseAddr(addr.String())
func ParseAddrToSocksAddr(addr net.Addr) Addr {
var hostip net.IP
var port int
if udpaddr, ok := addr.(*net.UDPAddr); ok {
hostip = udpaddr.IP
port = udpaddr.Port
} else if tcpaddr, ok := addr.(*net.TCPAddr); ok {
hostip = tcpaddr.IP
port = tcpaddr.Port
}
// fallback parse
if hostip == nil {
return ParseAddr(addr.String())
}
var parsed Addr
if ip4 := hostip.To4(); ip4.DefaultMask() != nil {
parsed = make([]byte, 1+net.IPv4len+2)
parsed[0] = AtypIPv4
copy(parsed[1:], ip4)
binary.BigEndian.PutUint16(parsed[1+net.IPv4len:], uint16(port))
} else {
parsed = make([]byte, 1+net.IPv6len+2)
parsed[0] = AtypIPv6
copy(parsed[1:], hostip)
binary.BigEndian.PutUint16(parsed[1+net.IPv6len:], uint16(port))
}
return parsed
}
// DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet` // DecodeUDPPacket split `packet` to addr payload, and this function is mutable with `packet`
func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) { func DecodeUDPPacket(packet []byte) (addr Addr, payload []byte, err error) {
if len(packet) < 5 { if len(packet) < 5 {

View File

@ -109,3 +109,21 @@ func (at AdapterType) String() string {
return "Unknown" return "Unknown"
} }
} }
// UDPPacket contains the data of UDP packet, and offers control/info of UDP packet's source
type UDPPacket interface {
// Data get the payload of UDP Packet
Data() []byte
// WriteBack writes the payload with source IP/Port equals addr
// - variable source IP/Port is important to STUN
// - if addr is not provided, WriteBack will wirte out UDP packet with SourceIP/Prot equals to origional Target,
// this is important when using Fake-IP.
WriteBack(b []byte, addr net.Addr) (n int, err error)
// Close closes the underlaying connection.
Close() error
// LocalAddr returns the source IP/Port of packet
LocalAddr() net.Addr
}

View File

@ -75,7 +75,7 @@ func compose(middlewares []middleware, endpoint handler) handler {
func newHandler(resolver *Resolver) handler { func newHandler(resolver *Resolver) handler {
middlewares := []middleware{} middlewares := []middleware{}
if resolver.IsFakeIP() { if resolver.FakeIPEnabled() {
middlewares = append(middlewares, withFakeIP(resolver.pool)) middlewares = append(middlewares, withFakeIP(resolver.pool))
} }

View File

@ -166,10 +166,19 @@ func (r *Resolver) IsMapping() bool {
return r.mapping return r.mapping
} }
func (r *Resolver) IsFakeIP() bool { // FakeIPEnabled returns if fake-ip is enabled
func (r *Resolver) FakeIPEnabled() bool {
return r.fakeip return r.fakeip
} }
// IsFakeIP determine if given ip is a fake-ip
func (r *Resolver) IsFakeIP(ip net.IP) bool {
if r.FakeIPEnabled() {
return r.pool.Exist(ip)
}
return false
}
func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) { func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) {
fast, ctx := picker.WithTimeout(context.Background(), time.Second) fast, ctx := picker.WithTimeout(context.Background(), time.Second)
for _, client := range clients { for _, client := range clients {

View File

@ -1,7 +1,6 @@
package socks package socks
import ( import (
"bytes"
"net" "net"
adapters "github.com/Dreamacro/clash/adapters/inbound" adapters "github.com/Dreamacro/clash/adapters/inbound"
@ -57,12 +56,12 @@ func handleSocksUDP(pc net.PacketConn, buf []byte, addr net.Addr) {
pool.BufPool.Put(buf[:cap(buf)]) pool.BufPool.Put(buf[:cap(buf)])
return return
} }
conn := &fakeConn{ packet := &fakeConn{
PacketConn: pc, PacketConn: pc,
remoteAddr: addr, remoteAddr: addr,
targetAddr: target, targetAddr: target,
buffer: bytes.NewBuffer(payload), payload: payload,
bufRef: buf, bufRef: buf,
} }
tun.Add(adapters.NewSocket(target, conn, C.SOCKS, C.UDP)) tun.AddPacket(adapters.NewPacket(target, packet, C.SOCKS, C.UDP))
} }

View File

@ -1,7 +1,6 @@
package socks package socks
import ( import (
"bytes"
"net" "net"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
@ -12,23 +11,30 @@ type fakeConn struct {
net.PacketConn net.PacketConn
remoteAddr net.Addr remoteAddr net.Addr
targetAddr socks5.Addr targetAddr socks5.Addr
buffer *bytes.Buffer payload []byte
bufRef []byte bufRef []byte
} }
func (c *fakeConn) Read(b []byte) (n int, err error) { func (c *fakeConn) Data() []byte {
return c.buffer.Read(b) return c.payload
} }
func (c *fakeConn) Write(b []byte) (n int, err error) { // WriteBack wirtes UDP packet with source(ip, port) = `addr`
packet, err := socks5.EncodeUDPPacket(c.targetAddr, b) func (c *fakeConn) WriteBack(b []byte, addr net.Addr) (n int, err error) {
from := c.targetAddr
if addr != nil {
// if addr is provided, use the parsed addr
from = socks5.ParseAddrToSocksAddr(addr)
}
packet, err := socks5.EncodeUDPPacket(from, b)
if err != nil { if err != nil {
return return
} }
return c.PacketConn.WriteTo(packet, c.remoteAddr) return c.PacketConn.WriteTo(packet, c.remoteAddr)
} }
func (c *fakeConn) RemoteAddr() net.Addr { // LocalAddr returns the source IP/Port of UDP Packet
func (c *fakeConn) LocalAddr() net.Addr {
return c.remoteAddr return c.remoteAddr
} }

View File

@ -9,6 +9,8 @@ import (
"time" "time"
adapters "github.com/Dreamacro/clash/adapters/inbound" adapters "github.com/Dreamacro/clash/adapters/inbound"
C "github.com/Dreamacro/clash/constant"
"github.com/Dreamacro/clash/common/pool" "github.com/Dreamacro/clash/common/pool"
) )
@ -79,21 +81,14 @@ func (t *Tunnel) handleHTTP(request *adapters.HTTPAdapter, outbound net.Conn) {
} }
} }
func (t *Tunnel) handleUDPToRemote(conn net.Conn, pc net.PacketConn, addr net.Addr) { func (t *Tunnel) handleUDPToRemote(packet C.UDPPacket, pc net.PacketConn, addr net.Addr) {
buf := pool.BufPool.Get().([]byte) if _, err := pc.WriteTo(packet.Data(), addr); err != nil {
defer pool.BufPool.Put(buf[:cap(buf)])
n, err := conn.Read(buf)
if err != nil {
return return
} }
if _, err = pc.WriteTo(buf[:n], addr); err != nil { DefaultManager.Upload() <- int64(len(packet.Data()))
return
}
DefaultManager.Upload() <- int64(n)
} }
func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string, timeout time.Duration) { func (t *Tunnel) handleUDPToLocal(packet C.UDPPacket, pc net.PacketConn, key string, omitSrcAddr bool, timeout time.Duration) {
buf := pool.BufPool.Get().([]byte) buf := pool.BufPool.Get().([]byte)
defer pool.BufPool.Put(buf[:cap(buf)]) defer pool.BufPool.Put(buf[:cap(buf)])
defer t.natTable.Delete(key) defer t.natTable.Delete(key)
@ -101,12 +96,15 @@ func (t *Tunnel) handleUDPToLocal(conn net.Conn, pc net.PacketConn, key string,
for { for {
pc.SetReadDeadline(time.Now().Add(timeout)) pc.SetReadDeadline(time.Now().Add(timeout))
n, _, err := pc.ReadFrom(buf) n, from, err := pc.ReadFrom(buf)
if err != nil { if err != nil {
return return
} }
if from != nil && omitSrcAddr {
from = nil
}
n, err = conn.Write(buf[:n]) n, err = packet.WriteBack(buf[:n], from)
if err != nil { if err != nil {
return return
} }

View File

@ -3,6 +3,7 @@ package tunnel
import ( import (
"fmt" "fmt"
"net" "net"
"runtime"
"sync" "sync"
"time" "time"
@ -43,12 +44,12 @@ type Tunnel struct {
// Add request to queue // Add request to queue
func (t *Tunnel) Add(req C.ServerAdapter) { func (t *Tunnel) Add(req C.ServerAdapter) {
switch req.Metadata().NetWork { t.tcpQueue.In() <- req
case C.TCP: }
t.tcpQueue.In() <- req
case C.UDP: // AddPacket add udp Packet to queue
t.udpQueue.In() <- req func (t *Tunnel) AddPacket(packet *inbound.PacketAdapter) {
} t.udpQueue.In() <- packet
} }
// Rules return all rules // Rules return all rules
@ -98,14 +99,23 @@ func (t *Tunnel) SetMode(mode Mode) {
t.mode = mode t.mode = mode
} }
// processUDP starts a loop to handle udp packet
func (t *Tunnel) processUDP() {
queue := t.udpQueue.Out()
for elm := range queue {
conn := elm.(*inbound.PacketAdapter)
t.handleUDPConn(conn)
}
}
func (t *Tunnel) process() { func (t *Tunnel) process() {
go func() { numUDPWorkers := 4
queue := t.udpQueue.Out() if runtime.NumCPU() > numUDPWorkers {
for elm := range queue { numUDPWorkers = runtime.NumCPU()
conn := elm.(C.ServerAdapter) }
t.handleUDPConn(conn) for i := 0; i < numUDPWorkers; i++ {
} go t.processUDP()
}() }
queue := t.tcpQueue.Out() queue := t.tcpQueue.Out()
for elm := range queue { for elm := range queue {
@ -119,7 +129,7 @@ func (t *Tunnel) resolveIP(host string) (net.IP, error) {
} }
func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool { func (t *Tunnel) needLookupIP(metadata *C.Metadata) bool {
return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.IsFakeIP()) && metadata.Host == "" && metadata.DstIP != nil return dns.DefaultResolver != nil && (dns.DefaultResolver.IsMapping() || dns.DefaultResolver.FakeIPEnabled()) && metadata.Host == "" && metadata.DstIP != nil
} }
func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) { func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error) {
@ -134,7 +144,7 @@ func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error)
if exist { if exist {
metadata.Host = host metadata.Host = host
metadata.AddrType = C.AtypDomainName metadata.AddrType = C.AtypDomainName
if dns.DefaultResolver.IsFakeIP() { if dns.DefaultResolver.FakeIPEnabled() {
metadata.DstIP = nil metadata.DstIP = nil
} }
} }
@ -158,25 +168,28 @@ func (t *Tunnel) resolveMetadata(metadata *C.Metadata) (C.Proxy, C.Rule, error)
return proxy, rule, nil return proxy, rule, nil
} }
func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter) { func (t *Tunnel) handleUDPConn(packet *inbound.PacketAdapter) {
metadata := localConn.Metadata() metadata := packet.Metadata()
if !metadata.Valid() { if !metadata.Valid() {
log.Warnln("[Metadata] not valid: %#v", metadata) log.Warnln("[Metadata] not valid: %#v", metadata)
return return
} }
src := localConn.RemoteAddr().String() src := packet.LocalAddr().String()
dst := metadata.RemoteAddress() dst := metadata.RemoteAddress()
key := src + "-" + dst key := src + "-" + dst
pc, addr := t.natTable.Get(key) pc, addr := t.natTable.Get(key)
if pc != nil { if pc != nil {
t.handleUDPToRemote(localConn, pc, addr) t.handleUDPToRemote(packet, pc, addr)
return return
} }
lockKey := key + "-lock" lockKey := key + "-lock"
wg, loaded := t.natTable.GetOrCreateLock(lockKey) wg, loaded := t.natTable.GetOrCreateLock(lockKey)
isFakeIP := dns.DefaultResolver.IsFakeIP(metadata.DstIP)
go func() { go func() {
if !loaded { if !loaded {
wg.Add(1) wg.Add(1)
@ -207,13 +220,14 @@ func (t *Tunnel) handleUDPConn(localConn C.ServerAdapter) {
t.natTable.Set(key, pc, addr) t.natTable.Set(key, pc, addr)
t.natTable.Delete(lockKey) t.natTable.Delete(lockKey)
wg.Done() wg.Done()
go t.handleUDPToLocal(localConn, pc, key, udpTimeout) // in fake-ip mode, Full-Cone NAT can never achieve, fallback to omitting src Addr
go t.handleUDPToLocal(packet.UDPPacket, pc, key, isFakeIP, udpTimeout)
} }
wg.Wait() wg.Wait()
pc, addr := t.natTable.Get(key) pc, addr := t.natTable.Get(key)
if pc != nil { if pc != nil {
t.handleUDPToRemote(localConn, pc, addr) t.handleUDPToRemote(packet, pc, addr)
} }
}() }()
} }