Feature: support domain in fallback filter (#964)

This commit is contained in:
Melvin 2020-09-28 22:17:10 +08:00 committed by GitHub
parent e09931dcf7
commit a6444bb449
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 89 additions and 19 deletions

View File

@ -69,6 +69,7 @@ type DNS struct {
type FallbackFilter struct { type FallbackFilter struct {
GeoIP bool `yaml:"geoip"` GeoIP bool `yaml:"geoip"`
IPCIDR []*net.IPNet `yaml:"ipcidr"` IPCIDR []*net.IPNet `yaml:"ipcidr"`
Domain []string `yaml:"domain"`
} }
// Experimental config // Experimental config
@ -103,6 +104,7 @@ type RawDNS struct {
type RawFallbackFilter struct { type RawFallbackFilter struct {
GeoIP bool `yaml:"geoip"` GeoIP bool `yaml:"geoip"`
IPCIDR []string `yaml:"ipcidr"` IPCIDR []string `yaml:"ipcidr"`
Domain []string `yaml:"domain"`
} }
type RawConfig struct { type RawConfig struct {
@ -561,6 +563,7 @@ func parseDNS(cfg RawDNS, hosts *trie.DomainTrie) (*DNS, error) {
if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil { if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil {
dnsCfg.FallbackFilter.IPCIDR = fallbackip dnsCfg.FallbackFilter.IPCIDR = fallbackip
} }
dnsCfg.FallbackFilter.Domain = cfg.FallbackFilter.Domain
if cfg.UseHosts { if cfg.UseHosts {
dnsCfg.Hosts = hosts dnsCfg.Hosts = hosts

View File

@ -4,9 +4,10 @@ import (
"net" "net"
"github.com/Dreamacro/clash/component/mmdb" "github.com/Dreamacro/clash/component/mmdb"
"github.com/Dreamacro/clash/component/trie"
) )
type fallbackFilter interface { type fallbackIPFilter interface {
Match(net.IP) bool Match(net.IP) bool
} }
@ -24,3 +25,22 @@ type ipnetFilter struct {
func (inf *ipnetFilter) Match(ip net.IP) bool { func (inf *ipnetFilter) Match(ip net.IP) bool {
return inf.ipnet.Contains(ip) return inf.ipnet.Contains(ip)
} }
type fallbackDomainFilter interface {
Match(domain string) bool
}
type domainFilter struct {
tree *trie.DomainTrie
}
func NewDomainFilter(domains []string) *domainFilter {
df := domainFilter{tree: trie.New()}
for _, domain := range domains {
df.tree.Insert(domain, "")
}
return &df
}
func (df *domainFilter) Match(domain string) bool {
return df.tree.Search(domain) != nil
}

View File

@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"math/rand" "math/rand"
"net" "net"
"strings"
"time" "time"
"github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/cache"
@ -34,13 +35,14 @@ type result struct {
} }
type Resolver struct { type Resolver struct {
ipv6 bool ipv6 bool
hosts *trie.DomainTrie hosts *trie.DomainTrie
main []dnsClient main []dnsClient
fallback []dnsClient fallback []dnsClient
fallbackFilters []fallbackFilter fallbackDomainFilters []fallbackDomainFilter
group singleflight.Group fallbackIPFilters []fallbackIPFilter
lruCache *cache.LruCache group singleflight.Group
lruCache *cache.LruCache
} }
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA // ResolveIP request with TypeA and TypeAAAA, priority return TypeA
@ -78,8 +80,8 @@ func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) {
return r.resolveIP(host, D.TypeAAAA) return r.resolveIP(host, D.TypeAAAA)
} }
func (r *Resolver) shouldFallback(ip net.IP) bool { func (r *Resolver) shouldIPFallback(ip net.IP) bool {
for _, filter := range r.fallbackFilters { for _, filter := range r.fallbackIPFilters {
if filter.Match(ip) { if filter.Match(ip) {
return true return true
} }
@ -126,7 +128,7 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) {
isIPReq := isIPRequest(q) isIPReq := isIPRequest(q)
if isIPReq { if isIPReq {
return r.fallbackExchange(m) return r.ipExchange(m)
} }
return r.batchExchange(r.main, m) return r.batchExchange(r.main, m)
@ -170,19 +172,49 @@ func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err
return return
} }
func (r *Resolver) fallbackExchange(m *D.Msg) (msg *D.Msg, err error) { func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool {
if r.fallback == nil || len(r.fallbackDomainFilters) == 0 {
return false
}
domain := r.msgToDomain(m)
if domain == "" {
return false
}
for _, df := range r.fallbackDomainFilters {
if df.Match(domain) {
return true
}
}
return false
}
func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) {
onlyFallback := r.shouldOnlyQueryFallback(m)
if onlyFallback {
res := <-r.asyncExchange(r.fallback, m)
return res.Msg, res.Error
}
msgCh := r.asyncExchange(r.main, m) msgCh := r.asyncExchange(r.main, m)
if r.fallback == nil {
if r.fallback == nil { // directly return if no fallback servers are available
res := <-msgCh res := <-msgCh
msg, err = res.Msg, res.Error msg, err = res.Msg, res.Error
return return
} }
fallbackMsg := r.asyncExchange(r.fallback, m) fallbackMsg := r.asyncExchange(r.fallback, m)
res := <-msgCh res := <-msgCh
if res.Error == nil { if res.Error == nil {
if ips := r.msgToIP(res.Msg); len(ips) != 0 { if ips := r.msgToIP(res.Msg); len(ips) != 0 {
if !r.shouldFallback(ips[0]) { if !r.shouldIPFallback(ips[0]) {
msg = res.Msg msg = res.Msg // no need to wait for fallback result
err = res.Error err = res.Error
return msg, err return msg, err
} }
@ -240,6 +272,14 @@ func (r *Resolver) msgToIP(msg *D.Msg) []net.IP {
return ips return ips
} }
func (r *Resolver) msgToDomain(msg *D.Msg) string {
if len(msg.Question) > 0 {
return strings.TrimRight(msg.Question[0].Name, ".")
}
return ""
}
func (r *Resolver) asyncExchange(client []dnsClient, msg *D.Msg) <-chan *result { func (r *Resolver) asyncExchange(client []dnsClient, msg *D.Msg) <-chan *result {
ch := make(chan *result, 1) ch := make(chan *result, 1)
go func() { go func() {
@ -257,6 +297,7 @@ type NameServer struct {
type FallbackFilter struct { type FallbackFilter struct {
GeoIP bool GeoIP bool
IPCIDR []*net.IPNet IPCIDR []*net.IPNet
Domain []string
} }
type Config struct { type Config struct {
@ -286,14 +327,19 @@ func NewResolver(config Config) *Resolver {
r.fallback = transform(config.Fallback, defaultResolver) r.fallback = transform(config.Fallback, defaultResolver)
} }
fallbackFilters := []fallbackFilter{} fallbackIPFilters := []fallbackIPFilter{}
if config.FallbackFilter.GeoIP { if config.FallbackFilter.GeoIP {
fallbackFilters = append(fallbackFilters, &geoipFilter{}) fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{})
} }
for _, ipnet := range config.FallbackFilter.IPCIDR { for _, ipnet := range config.FallbackFilter.IPCIDR {
fallbackFilters = append(fallbackFilters, &ipnetFilter{ipnet: ipnet}) fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet})
}
r.fallbackIPFilters = fallbackIPFilters
if len(config.FallbackFilter.Domain) != 0 {
fallbackDomainFilters := []fallbackDomainFilter{NewDomainFilter(config.FallbackFilter.Domain)}
r.fallbackDomainFilters = fallbackDomainFilters
} }
r.fallbackFilters = fallbackFilters
return r return r
} }

View File

@ -118,6 +118,7 @@ func updateDNS(c *config.DNS) {
FallbackFilter: dns.FallbackFilter{ FallbackFilter: dns.FallbackFilter{
GeoIP: c.FallbackFilter.GeoIP, GeoIP: c.FallbackFilter.GeoIP,
IPCIDR: c.FallbackFilter.IPCIDR, IPCIDR: c.FallbackFilter.IPCIDR,
Domain: c.FallbackFilter.Domain,
}, },
Default: c.DefaultNameserver, Default: c.DefaultNameserver,
} }