diff --git a/config/config.go b/config/config.go index b9e42dd4..9c39b7ef 100644 --- a/config/config.go +++ b/config/config.go @@ -69,6 +69,7 @@ type DNS struct { type FallbackFilter struct { GeoIP bool `yaml:"geoip"` IPCIDR []*net.IPNet `yaml:"ipcidr"` + Domain []string `yaml:"domain"` } // Experimental config @@ -103,6 +104,7 @@ type RawDNS struct { type RawFallbackFilter struct { GeoIP bool `yaml:"geoip"` IPCIDR []string `yaml:"ipcidr"` + Domain []string `yaml:"domain"` } type RawConfig struct { @@ -561,6 +563,7 @@ func parseDNS(cfg RawDNS, hosts *trie.DomainTrie) (*DNS, error) { if fallbackip, err := parseFallbackIPCIDR(cfg.FallbackFilter.IPCIDR); err == nil { dnsCfg.FallbackFilter.IPCIDR = fallbackip } + dnsCfg.FallbackFilter.Domain = cfg.FallbackFilter.Domain if cfg.UseHosts { dnsCfg.Hosts = hosts diff --git a/dns/filters.go b/dns/filters.go index 03089d4c..583883fa 100644 --- a/dns/filters.go +++ b/dns/filters.go @@ -4,9 +4,10 @@ import ( "net" "github.com/Dreamacro/clash/component/mmdb" + "github.com/Dreamacro/clash/component/trie" ) -type fallbackFilter interface { +type fallbackIPFilter interface { Match(net.IP) bool } @@ -24,3 +25,22 @@ type ipnetFilter struct { func (inf *ipnetFilter) Match(ip net.IP) bool { return inf.ipnet.Contains(ip) } + +type fallbackDomainFilter interface { + Match(domain string) bool +} +type domainFilter struct { + tree *trie.DomainTrie +} + +func NewDomainFilter(domains []string) *domainFilter { + df := domainFilter{tree: trie.New()} + for _, domain := range domains { + df.tree.Insert(domain, "") + } + return &df +} + +func (df *domainFilter) Match(domain string) bool { + return df.tree.Search(domain) != nil +} diff --git a/dns/resolver.go b/dns/resolver.go index b7d20b5c..93e3ca6e 100644 --- a/dns/resolver.go +++ b/dns/resolver.go @@ -7,6 +7,7 @@ import ( "fmt" "math/rand" "net" + "strings" "time" "github.com/Dreamacro/clash/common/cache" @@ -34,13 +35,14 @@ type result struct { } type Resolver struct { - ipv6 bool - hosts *trie.DomainTrie - main []dnsClient - fallback []dnsClient - fallbackFilters []fallbackFilter - group singleflight.Group - lruCache *cache.LruCache + ipv6 bool + hosts *trie.DomainTrie + main []dnsClient + fallback []dnsClient + fallbackDomainFilters []fallbackDomainFilter + fallbackIPFilters []fallbackIPFilter + group singleflight.Group + lruCache *cache.LruCache } // ResolveIP request with TypeA and TypeAAAA, priority return TypeA @@ -78,8 +80,8 @@ func (r *Resolver) ResolveIPv6(host string) (ip net.IP, err error) { return r.resolveIP(host, D.TypeAAAA) } -func (r *Resolver) shouldFallback(ip net.IP) bool { - for _, filter := range r.fallbackFilters { +func (r *Resolver) shouldIPFallback(ip net.IP) bool { + for _, filter := range r.fallbackIPFilters { if filter.Match(ip) { return true } @@ -126,7 +128,7 @@ func (r *Resolver) exchangeWithoutCache(m *D.Msg) (msg *D.Msg, err error) { isIPReq := isIPRequest(q) if isIPReq { - return r.fallbackExchange(m) + return r.ipExchange(m) } return r.batchExchange(r.main, m) @@ -170,19 +172,49 @@ func (r *Resolver) batchExchange(clients []dnsClient, m *D.Msg) (msg *D.Msg, err return } -func (r *Resolver) fallbackExchange(m *D.Msg) (msg *D.Msg, err error) { +func (r *Resolver) shouldOnlyQueryFallback(m *D.Msg) bool { + if r.fallback == nil || len(r.fallbackDomainFilters) == 0 { + return false + } + + domain := r.msgToDomain(m) + + if domain == "" { + return false + } + + for _, df := range r.fallbackDomainFilters { + if df.Match(domain) { + return true + } + } + + return false +} + +func (r *Resolver) ipExchange(m *D.Msg) (msg *D.Msg, err error) { + + onlyFallback := r.shouldOnlyQueryFallback(m) + + if onlyFallback { + res := <-r.asyncExchange(r.fallback, m) + return res.Msg, res.Error + } + msgCh := r.asyncExchange(r.main, m) - if r.fallback == nil { + + if r.fallback == nil { // directly return if no fallback servers are available res := <-msgCh msg, err = res.Msg, res.Error return } + fallbackMsg := r.asyncExchange(r.fallback, m) res := <-msgCh if res.Error == nil { if ips := r.msgToIP(res.Msg); len(ips) != 0 { - if !r.shouldFallback(ips[0]) { - msg = res.Msg + if !r.shouldIPFallback(ips[0]) { + msg = res.Msg // no need to wait for fallback result err = res.Error return msg, err } @@ -240,6 +272,14 @@ func (r *Resolver) msgToIP(msg *D.Msg) []net.IP { return ips } +func (r *Resolver) msgToDomain(msg *D.Msg) string { + if len(msg.Question) > 0 { + return strings.TrimRight(msg.Question[0].Name, ".") + } + + return "" +} + func (r *Resolver) asyncExchange(client []dnsClient, msg *D.Msg) <-chan *result { ch := make(chan *result, 1) go func() { @@ -257,6 +297,7 @@ type NameServer struct { type FallbackFilter struct { GeoIP bool IPCIDR []*net.IPNet + Domain []string } type Config struct { @@ -286,14 +327,19 @@ func NewResolver(config Config) *Resolver { r.fallback = transform(config.Fallback, defaultResolver) } - fallbackFilters := []fallbackFilter{} + fallbackIPFilters := []fallbackIPFilter{} if config.FallbackFilter.GeoIP { - fallbackFilters = append(fallbackFilters, &geoipFilter{}) + fallbackIPFilters = append(fallbackIPFilters, &geoipFilter{}) } for _, ipnet := range config.FallbackFilter.IPCIDR { - fallbackFilters = append(fallbackFilters, &ipnetFilter{ipnet: ipnet}) + fallbackIPFilters = append(fallbackIPFilters, &ipnetFilter{ipnet: ipnet}) + } + r.fallbackIPFilters = fallbackIPFilters + + if len(config.FallbackFilter.Domain) != 0 { + fallbackDomainFilters := []fallbackDomainFilter{NewDomainFilter(config.FallbackFilter.Domain)} + r.fallbackDomainFilters = fallbackDomainFilters } - r.fallbackFilters = fallbackFilters return r } diff --git a/hub/executor/executor.go b/hub/executor/executor.go index c24d6123..822b43e5 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -118,6 +118,7 @@ func updateDNS(c *config.DNS) { FallbackFilter: dns.FallbackFilter{ GeoIP: c.FallbackFilter.GeoIP, IPCIDR: c.FallbackFilter.IPCIDR, + Domain: c.FallbackFilter.Domain, }, Default: c.DefaultNameserver, }