From 806047488ad3f6f9c8453a3efd08c9065f549538 Mon Sep 17 00:00:00 2001 From: gVisor bot Date: Wed, 24 Jun 2020 18:41:23 +0800 Subject: [PATCH] Fix: domain trie should backtrack to parent if match fail (#758) --- component/trie/domain.go | 58 +++++++++++++++-------------------- component/trie/domain_test.go | 5 +++ 2 files changed, 30 insertions(+), 33 deletions(-) diff --git a/component/trie/domain.go b/component/trie/domain.go index e8dcf213..8df22792 100644 --- a/component/trie/domain.go +++ b/component/trie/domain.go @@ -96,39 +96,7 @@ func (t *DomainTrie) Search(domain string) *Node { return nil } - n := t.root - var dotWildcardNode *Node - var wildcardNode *Node - for i := len(parts) - 1; i >= 0; i-- { - part := parts[i] - - if node := n.getChild(dotWildcard); node != nil { - dotWildcardNode = node - } - - child := n.getChild(part) - if child == nil && wildcardNode != nil { - child = wildcardNode.getChild(part) - } - wildcardNode = n.getChild(wildcard) - - n = child - if n == nil { - n = wildcardNode - wildcardNode = nil - } - - if n == nil { - break - } - } - - if n == nil { - if dotWildcardNode != nil { - return dotWildcardNode - } - return nil - } + n := t.search(t.root, parts) if n.Data == nil { return nil @@ -137,6 +105,30 @@ func (t *DomainTrie) Search(domain string) *Node { return n } +func (t *DomainTrie) search(node *Node, parts []string) *Node { + if len(parts) == 0 { + return node + } + + if c := node.getChild(parts[len(parts)-1]); c != nil { + if n := t.search(c, parts[:len(parts)-1]); n != nil { + return n + } + } + + if c := node.getChild(wildcard); c != nil { + if n := t.search(c, parts[:len(parts)-1]); n != nil { + return n + } + } + + if c := node.getChild(dotWildcard); c != nil { + return c + } + + return nil +} + // New returns a new, empty Trie. func New() *DomainTrie { return &DomainTrie{root: newNode(nil)} diff --git a/component/trie/domain_test.go b/component/trie/domain_test.go index 38b347e1..3e150717 100644 --- a/component/trie/domain_test.go +++ b/component/trie/domain_test.go @@ -39,6 +39,10 @@ func TestTrie_Wildcard(t *testing.T) { ".example.net", ".apple.*", "+.foo.com", + "+.stun.*.*", + "+.stun.*.*.*", + "+.stun.*.*.*.*", + "stun.l.google.com", } for _, domain := range domains { @@ -52,6 +56,7 @@ func TestTrie_Wildcard(t *testing.T) { assert.NotNil(t, tree.Search("test.apple.com")) assert.NotNil(t, tree.Search("test.foo.com")) assert.NotNil(t, tree.Search("foo.com")) + assert.NotNil(t, tree.Search("global.stun.website.com")) assert.Nil(t, tree.Search("foo.sub.example.com")) assert.Nil(t, tree.Search("foo.example.dev")) assert.Nil(t, tree.Search("example.com"))