From f066687f937aba6916cde9fd59e1a26d8a00ba66 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Wed, 28 Oct 2020 21:26:50 +0800 Subject: [PATCH] Fix: tunnel UDP race condition (#1043) --- component/nat/table.go | 6 +-- tunnel/tunnel.go | 90 +++++++++++++++++++++++------------------- 2 files changed, 52 insertions(+), 44 deletions(-) diff --git a/component/nat/table.go b/component/nat/table.go index 9a8696b7..fbb16dec 100644 --- a/component/nat/table.go +++ b/component/nat/table.go @@ -22,9 +22,9 @@ func (t *Table) Get(key string) C.PacketConn { return item.(C.PacketConn) } -func (t *Table) GetOrCreateLock(key string) (*sync.WaitGroup, bool) { - item, loaded := t.mapping.LoadOrStore(key, &sync.WaitGroup{}) - return item.(*sync.WaitGroup), loaded +func (t *Table) GetOrCreateLock(key string) (*sync.Cond, bool) { + item, loaded := t.mapping.LoadOrStore(key, sync.NewCond(&sync.Mutex{})) + return item.(*sync.Cond), loaded } func (t *Table) Delete(key string) { diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index 19d9f8b2..2e7b9696 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -164,7 +164,7 @@ func handleUDPConn(packet *inbound.PacketAdapter) { return } - // make a fAddr if requset ip is fakeip + // make a fAddr if request ip is fakeip var fAddr net.Addr if resolver.IsExistFakeIP(metadata.DstIP) { fAddr = metadata.UDPAddr() @@ -176,57 +176,65 @@ func handleUDPConn(packet *inbound.PacketAdapter) { } key := packet.LocalAddr().String() - pc := natTable.Get(key) - if pc != nil { - handleUDPToRemote(packet, pc, metadata) + + handle := func() bool { + pc := natTable.Get(key) + if pc != nil { + handleUDPToRemote(packet, pc, metadata) + return true + } + return false + } + + if handle() { return } lockKey := key + "-lock" - wg, loaded := natTable.GetOrCreateLock(lockKey) + cond, loaded := natTable.GetOrCreateLock(lockKey) go func() { - if !loaded { - wg.Add(1) - proxy, rule, err := resolveMetadata(metadata) - if err != nil { - log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) - natTable.Delete(lockKey) - wg.Done() - return - } + if loaded { + cond.L.Lock() + cond.Wait() + handle() + cond.L.Unlock() + return + } - rawPc, err := proxy.DialUDP(metadata) - if err != nil { - log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error()) - natTable.Delete(lockKey) - wg.Done() - return - } - pc = newUDPTracker(rawPc, DefaultManager, metadata, rule) - - switch true { - case rule != nil: - log.Infoln("[UDP] %s --> %v match %s(%s) using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), rule.Payload(), rawPc.Chains().String()) - case mode == Global: - log.Infoln("[UDP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String()) - case mode == Direct: - log.Infoln("[UDP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String()) - default: - log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String()) - } - - natTable.Set(key, pc) + defer func() { natTable.Delete(lockKey) - wg.Done() - go handleUDPToLocal(packet.UDPPacket, pc, key, fAddr) + cond.Broadcast() + }() + + proxy, rule, err := resolveMetadata(metadata) + if err != nil { + log.Warnln("[UDP] Parse metadata failed: %s", err.Error()) + return } - wg.Wait() - pc := natTable.Get(key) - if pc != nil { - handleUDPToRemote(packet, pc, metadata) + rawPc, err := proxy.DialUDP(metadata) + if err != nil { + log.Warnln("[UDP] dial %s error: %s", proxy.Name(), err.Error()) + return } + pc := newUDPTracker(rawPc, DefaultManager, metadata, rule) + + switch true { + case rule != nil: + log.Infoln("[UDP] %s --> %v match %s(%s) using %s", metadata.SourceAddress(), metadata.String(), rule.RuleType().String(), rule.Payload(), rawPc.Chains().String()) + case mode == Global: + log.Infoln("[UDP] %s --> %v using GLOBAL", metadata.SourceAddress(), metadata.String()) + case mode == Direct: + log.Infoln("[UDP] %s --> %v using DIRECT", metadata.SourceAddress(), metadata.String()) + default: + log.Infoln("[UDP] %s --> %v doesn't match any rule using DIRECT", metadata.SourceAddress(), metadata.String()) + } + + go handleUDPToLocal(packet.UDPPacket, pc, key, fAddr) + + natTable.Set(key, pc) + handle() }() }