diff --git a/component/resolver/host.go b/component/resolver/host.go index 0b98bd68..ca90cd27 100644 --- a/component/resolver/host.go +++ b/component/resolver/host.go @@ -7,8 +7,42 @@ import ( "strings" "github.com/Dreamacro/clash/common/utils" + "github.com/Dreamacro/clash/component/trie" ) +type Hosts struct { + *trie.DomainTrie[HostValue] +} + +func NewHosts(hosts *trie.DomainTrie[HostValue]) Hosts { + return Hosts{ + hosts, + } +} + +func (h *Hosts) Search(domain string, isDomain bool) (*HostValue, bool) { + value := h.DomainTrie.Search(domain) + if value == nil { + return nil, false + } + hostValue := value.Data() + for { + if isDomain && hostValue.IsDomain { + return &hostValue, true + } else { + if node := h.DomainTrie.Search(hostValue.Domain); node != nil { + hostValue = node.Data() + } else { + break + } + } + } + if isDomain == hostValue.IsDomain { + return &hostValue, true + } + return &hostValue, false +} + type HostValue struct { IsDomain bool IPs []netip.Addr @@ -23,7 +57,7 @@ func NewHostValue(value any) (HostValue, error) { return HostValue{}, err } else { if len(valueArr) > 1 { - isDomain=false + isDomain = false for _, str := range valueArr { if ip, err := netip.ParseAddr(str); err == nil { ips = append(ips, ip) diff --git a/component/resolver/resolver.go b/component/resolver/resolver.go index 322a9224..f5872ad7 100644 --- a/component/resolver/resolver.go +++ b/component/resolver/resolver.go @@ -28,7 +28,7 @@ var ( DisableIPv6 = true // DefaultHosts aim to resolve hosts - DefaultHosts = trie.New[HostValue]() + DefaultHosts = NewHosts(trie.New[HostValue]()) // DefaultDNSTimeout defined the default dns request timeout DefaultDNSTimeout = time.Second * 5 @@ -52,15 +52,11 @@ type Resolver interface { // LookupIPv4WithResolver same as LookupIPv4, but with a resolver func LookupIPv4WithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) { - if node := DefaultHosts.Search(host); node != nil { - if value := node.Data(); !value.IsDomain { - if addrs := utils.Filter(value.IPs, func(ip netip.Addr) bool { - return ip.Is4() - }); len(addrs) > 0 { - return addrs, nil - } - }else{ - return LookupIPv4WithResolver(ctx,value.Domain,r) + if node, ok := DefaultHosts.Search(host, false); ok { + if addrs := utils.Filter(node.IPs, func(ip netip.Addr) bool { + return ip.Is4() + }); len(addrs) > 0 { + return addrs, nil } } @@ -113,15 +109,11 @@ func LookupIPv6WithResolver(ctx context.Context, host string, r Resolver) ([]net return nil, ErrIPv6Disabled } - if node := DefaultHosts.Search(host); node != nil { - if value := node.Data(); !value.IsDomain { - if addrs := utils.Filter(value.IPs, func(ip netip.Addr) bool { - return ip.Is6() - }); len(addrs) > 0 { - return addrs, nil - } - }else{ - return LookupIPv6WithResolver(ctx,value.Domain,r) + if node, ok := DefaultHosts.Search(host, false); ok { + if addrs := utils.Filter(node.IPs, func(ip netip.Addr) bool { + return ip.Is6() + }); len(addrs) > 0 { + return addrs, nil } } @@ -168,12 +160,8 @@ func ResolveIPv6(ctx context.Context, host string) (netip.Addr, error) { // LookupIPWithResolver same as LookupIP, but with a resolver func LookupIPWithResolver(ctx context.Context, host string, r Resolver) ([]netip.Addr, error) { - if node := DefaultHosts.Search(host); node != nil { - if !node.Data().IsDomain{ - return node.Data().IPs, nil - }else{ - return LookupIPWithResolver(ctx,node.Data().Domain,r) - } + if node, ok := DefaultHosts.Search(host, false); ok { + return node.IPs, nil } if r != nil { diff --git a/dns/middleware.go b/dns/middleware.go index 66db1405..f45c73b5 100644 --- a/dns/middleware.go +++ b/dns/middleware.go @@ -8,8 +8,7 @@ import ( "github.com/Dreamacro/clash/common/cache" "github.com/Dreamacro/clash/common/nnip" "github.com/Dreamacro/clash/component/fakeip" - "github.com/Dreamacro/clash/component/resolver" - "github.com/Dreamacro/clash/component/trie" + R "github.com/Dreamacro/clash/component/resolver" C "github.com/Dreamacro/clash/constant" "github.com/Dreamacro/clash/context" "github.com/Dreamacro/clash/log" @@ -22,7 +21,7 @@ type ( middleware func(next handler) handler ) -func withHosts(hosts *trie.DomainTrie[resolver.HostValue], mapping *cache.LruCache[netip.Addr, string]) middleware { +func withHosts(hosts R.Hosts, mapping *cache.LruCache[netip.Addr, string]) middleware { return func(next handler) handler { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { q := r.Question[0] @@ -33,15 +32,25 @@ func withHosts(hosts *trie.DomainTrie[resolver.HostValue], mapping *cache.LruCac host := strings.TrimRight(q.Name, ".") - record := hosts.Search(host) - if record == nil { - return next(ctx, r) + record, ok := hosts.Search(host, q.Qtype != D.TypeA && q.Qtype != D.TypeAAAA) + if !ok { + if record != nil && record.IsDomain { + // replace request domain + newR := r.Copy() + newR.Question[0].Name = record.Domain+"." + resp, err := next(ctx, newR) + if err==nil{ + resp.Id=r.Id + resp.Question=r.Question + } + return resp,err + } + return next(ctx,r) } - hostValue := record.Data() msg := r.Copy() handleIPs := func() { - for _, ipAddr := range hostValue.IPs { + for _, ipAddr := range record.IPs { if ipAddr.Is4() && q.Qtype == D.TypeA { rr := &D.A{} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeA, Class: D.ClassINET, Ttl: 10} @@ -62,35 +71,16 @@ func withHosts(hosts *trie.DomainTrie[resolver.HostValue], mapping *cache.LruCac } } } - fillMsg := func() { - if !hostValue.IsDomain { - handleIPs() - } else { - for { - if hostValue.IsDomain { - if node := hosts.Search(hostValue.Domain); node != nil { - hostValue = node.Data() - } else { - break - } - }else{ - break - } - } - if !hostValue.IsDomain { - handleIPs() - } - } - } + switch q.Qtype { case D.TypeA: - fillMsg() + handleIPs() case D.TypeAAAA: - fillMsg() + handleIPs() case D.TypeCNAME: rr := &D.CNAME{} rr.Hdr = D.RR_Header{Name: q.Name, Rrtype: D.TypeCNAME, Class: D.ClassINET, Ttl: 10} - rr.Target = hostValue.Domain + "." + rr.Target = record.Domain + "." msg.Answer = append(msg.Answer, rr) default: return next(ctx, r) @@ -100,7 +90,6 @@ func withHosts(hosts *trie.DomainTrie[resolver.HostValue], mapping *cache.LruCac msg.SetRcode(r, D.RcodeSuccess) msg.Authoritative = true msg.RecursionAvailable = true - return msg, nil } } @@ -185,6 +174,7 @@ func withFakeIP(fakePool *fakeip.Pool) middleware { func withResolver(resolver *Resolver) handler { return func(ctx *context.DNSContext, r *D.Msg) (*D.Msg, error) { ctx.SetType(context.DNSTypeRaw) + q := r.Question[0] // return a empty AAAA msg when ipv6 disabled @@ -219,7 +209,7 @@ func NewHandler(resolver *Resolver, mapper *ResolverEnhancer) handler { middlewares := []middleware{} if resolver.hosts != nil { - middlewares = append(middlewares, withHosts(resolver.hosts, mapper.mapping)) + middlewares = append(middlewares, withHosts(R.NewHosts(resolver.hosts), mapper.mapping)) } if mapper.mode == C.DNSFakeIP { diff --git a/hub/executor/executor.go b/hub/executor/executor.go index 0790f9a5..21a25ecd 100644 --- a/hub/executor/executor.go +++ b/hub/executor/executor.go @@ -227,7 +227,7 @@ func updateDNS(c *config.DNS, generalIPv6 bool) { } func updateHosts(tree *trie.DomainTrie[resolver.HostValue]) { - resolver.DefaultHosts = tree + resolver.DefaultHosts = resolver.NewHosts(tree) } func updateProxies(proxies map[string]C.Proxy, providers map[string]provider.ProxyProvider) { diff --git a/tunnel/tunnel.go b/tunnel/tunnel.go index a4b92ea1..48e8a2c0 100644 --- a/tunnel/tunnel.go +++ b/tunnel/tunnel.go @@ -3,7 +3,6 @@ package tunnel import ( "context" "fmt" - "math/rand" "net" "net/netip" "path/filepath" @@ -202,13 +201,9 @@ func preHandleMetadata(metadata *C.Metadata) error { if resolver.FakeIPEnabled() { metadata.DstIP = netip.Addr{} metadata.DNSMode = C.DNSFakeIP - } else if node := resolver.DefaultHosts.Search(host); node != nil { + } else if node, ok := resolver.DefaultHosts.Search(host, false); ok { // redir-host should lookup the hosts - if !node.Data().IsDomain { - metadata.DstIP,_ = node.Data().RandIP() - } else { - metadata.Host = node.Data().Domain - } + metadata.DstIP, _ = node.RandIP() } } else if resolver.IsFakeIP(metadata.DstIP) { return fmt.Errorf("fake DNS record %s missing", metadata.DstIP) @@ -397,14 +392,11 @@ func handleTCPConn(connCtx C.ConnContext) { dialMetadata := metadata if len(metadata.Host) > 0 { - if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { - hostValue := node.Data() - if !hostValue.IsDomain { - if dstIp, _ := hostValue.RandIP(); !FakeIPRange().Contains(dstIp) { - dialMetadata.DstIP = dstIp - dialMetadata.DNSMode = C.DNSHosts - dialMetadata = dialMetadata.Pure() - } + if node, ok := resolver.DefaultHosts.Search(metadata.Host, false); ok { + if dstIp, _ := node.RandIP(); !FakeIPRange().Contains(dstIp) { + dialMetadata.DstIP = dstIp + dialMetadata.DNSMode = C.DNSHosts + dialMetadata = dialMetadata.Pure() } } } @@ -506,12 +498,9 @@ func match(metadata *C.Metadata) (C.Proxy, C.Rule, error) { processFound bool ) - if node := resolver.DefaultHosts.Search(metadata.Host); node != nil { - if !node.Data().IsDomain { - metadata.DstIP = node.Data().IPs[rand.Intn(len(node.Data().IPs)-1)] - resolved = true - } - + if node, ok := resolver.DefaultHosts.Search(metadata.Host, false); ok { + metadata.DstIP, _ = node.RandIP() + resolved = true } for _, rule := range getRules(metadata) {