diff --git a/component/domain-trie/tire.go b/component/domain-trie/tire.go index 13567a98..26062fc2 100644 --- a/component/domain-trie/tire.go +++ b/component/domain-trie/tire.go @@ -21,17 +21,21 @@ type Trie struct { root *Node } +func isValidDomain(domain string) bool { + return domain[0] != '.' && domain[len(domain)-1] != '.' +} + // Insert adds a node to the trie. // Support // 1. www.example.com // 2. *.example.com // 3. subdomain.*.example.com func (t *Trie) Insert(domain string, data interface{}) error { - parts := strings.Split(domain, domainStep) - if len(parts) < 2 { + if !isValidDomain(domain) { return ErrInvalidDomain } + parts := strings.Split(domain, domainStep) node := t.root // reverse storage domain part to save space for i := len(parts) - 1; i >= 0; i-- { @@ -52,10 +56,10 @@ func (t *Trie) Insert(domain string, data interface{}) error { // 1. static part // 2. wildcard domain func (t *Trie) Search(domain string) *Node { - parts := strings.Split(domain, domainStep) - if len(parts) < 2 { + if !isValidDomain(domain) { return nil } + parts := strings.Split(domain, domainStep) n := t.root for i := len(parts) - 1; i >= 0; i-- { diff --git a/component/domain-trie/trie_test.go b/component/domain-trie/trie_test.go index bd594e97..5303596b 100644 --- a/component/domain-trie/trie_test.go +++ b/component/domain-trie/trie_test.go @@ -65,11 +65,19 @@ func TestTrie_Boundary(t *testing.T) { tree := New() tree.Insert("*.dev", localIP) - if err := tree.Insert("com", localIP); err == nil { + if err := tree.Insert(".", localIP); err == nil { + t.Error("should recv err") + } + + if err := tree.Insert(".com", localIP); err == nil { t.Error("should recv err") } if tree.Search("dev") != nil { t.Error("should recv nil") } + + if tree.Search(".dev") != nil { + t.Error("should recv nil") + } }