Refactor: MainResolver

This commit is contained in:
gVisor bot 2022-03-28 00:44:13 +08:00
parent 17da2d36a5
commit 67d04485ca
8 changed files with 128 additions and 75 deletions

View File

@ -14,6 +14,7 @@ type Direct struct {
// DialContext implements C.ProxyAdapter // DialContext implements C.ProxyAdapter
func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) { func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
opts = append(opts, dialer.WithDirect())
c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...) c, err := dialer.DialContext(ctx, "tcp", metadata.RemoteAddress(), d.Base.DialOptions(opts...)...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -24,6 +25,7 @@ func (d *Direct) DialContext(ctx context.Context, metadata *C.Metadata, opts ...
// ListenPacketContext implements C.ProxyAdapter // ListenPacketContext implements C.ProxyAdapter
func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) { func (d *Direct) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
opts = append(opts, dialer.WithDirect())
pc, err := dialer.ListenPacket(ctx, "udp", "", d.Base.DialOptions(opts...)...) pc, err := dialer.ListenPacket(ctx, "udp", "", d.Base.DialOptions(opts...)...)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -44,7 +44,7 @@ func resolveUDPAddr(network, address string) (*net.UDPAddr, error) {
return nil, err return nil, err
} }
ip, err := resolver.ResolveIP(host) ip, err := resolver.ResolveProxyServerHost(host)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -32,14 +32,14 @@ func DialContext(ctx context.Context, network, address string, options ...Option
var ip net.IP var ip net.IP
switch network { switch network {
case "tcp4", "udp4": case "tcp4", "udp4":
if opt.interfaceName != "" { if !opt.direct {
ip, err = resolver.ResolveIPv4WithMain(host) ip, err = resolver.ResolveIPv4ProxyServerHost(host)
} else { } else {
ip, err = resolver.ResolveIPv4(host) ip, err = resolver.ResolveIPv4(host)
} }
default: default:
if opt.interfaceName != "" { if !opt.direct {
ip, err = resolver.ResolveIPv6WithMain(host) ip, err = resolver.ResolveIPv6ProxyServerHost(host)
} else { } else {
ip, err = resolver.ResolveIPv6(host) ip, err = resolver.ResolveIPv6(host)
} }
@ -121,7 +121,7 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt
results := make(chan dialResult) results := make(chan dialResult)
var primary, fallback dialResult var primary, fallback dialResult
startRacer := func(ctx context.Context, network, host string, ipv6 bool) { startRacer := func(ctx context.Context, network, host string, direct bool, ipv6 bool) {
result := dialResult{ipv6: ipv6, done: true} result := dialResult{ipv6: ipv6, done: true}
defer func() { defer func() {
select { select {
@ -135,14 +135,14 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt
var ip net.IP var ip net.IP
if ipv6 { if ipv6 {
if opt.interfaceName != "" { if !direct {
ip, result.error = resolver.ResolveIPv6WithMain(host) ip, result.error = resolver.ResolveIPv6ProxyServerHost(host)
} else { } else {
ip, result.error = resolver.ResolveIPv6(host) ip, result.error = resolver.ResolveIPv6(host)
} }
} else { } else {
if opt.interfaceName != "" { if !direct {
ip, result.error = resolver.ResolveIPv4WithMain(host) ip, result.error = resolver.ResolveIPv4ProxyServerHost(host)
} else { } else {
ip, result.error = resolver.ResolveIPv4(host) ip, result.error = resolver.ResolveIPv4(host)
} }
@ -155,8 +155,8 @@ func dualStackDialContext(ctx context.Context, network, address string, opt *opt
result.Conn, result.error = dialContext(ctx, network, ip, port, opt) result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
} }
go startRacer(ctx, network+"4", host, false) go startRacer(ctx, network+"4", host, opt.direct, false)
go startRacer(ctx, network+"6", host, true) go startRacer(ctx, network+"6", host, opt.direct, true)
for res := range results { for res := range results {
if res.error == nil { if res.error == nil {

View File

@ -12,6 +12,7 @@ type option struct {
interfaceName string interfaceName string
addrReuse bool addrReuse bool
routingMark int routingMark int
direct bool
} }
type Option func(opt *option) type Option func(opt *option)
@ -33,3 +34,9 @@ func WithRoutingMark(mark int) Option {
opt.routingMark = mark opt.routingMark = mark
} }
} }
func WithDirect() Option {
return func(opt *option) {
opt.direct = true
}
}

View File

@ -15,8 +15,8 @@ var (
// DefaultResolver aim to resolve ip // DefaultResolver aim to resolve ip
DefaultResolver Resolver DefaultResolver Resolver
// MainResolver resolve ip with main domain server // ProxyServerHostResolver resolve ip to proxies server host
MainResolver Resolver ProxyServerHostResolver Resolver
// DisableIPv6 means don't resolve ipv6 host // DisableIPv6 means don't resolve ipv6 host
// default value is true // default value is true
@ -46,10 +46,6 @@ func ResolveIPv4(host string) (net.IP, error) {
return ResolveIPv4WithResolver(host, DefaultResolver) return ResolveIPv4WithResolver(host, DefaultResolver)
} }
func ResolveIPv4WithMain(host string) (net.IP, error) {
return ResolveIPv4WithResolver(host, MainResolver)
}
func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) { func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) {
if node := DefaultHosts.Search(host); node != nil { if node := DefaultHosts.Search(host); node != nil {
if ip := node.Data.(net.IP).To4(); ip != nil { if ip := node.Data.(net.IP).To4(); ip != nil {
@ -69,6 +65,7 @@ func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) {
return r.ResolveIPv4(host) return r.ResolveIPv4(host)
} }
if DefaultResolver == nil {
ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout)
defer cancel() defer cancel()
ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", host) ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip4", host)
@ -79,6 +76,9 @@ func ResolveIPv4WithResolver(host string, r Resolver) (net.IP, error) {
} }
return ipAddrs[rand.Intn(len(ipAddrs))], nil return ipAddrs[rand.Intn(len(ipAddrs))], nil
}
return nil, ErrIPNotFound
} }
// ResolveIPv6 with a host, return ipv6 // ResolveIPv6 with a host, return ipv6
@ -86,10 +86,6 @@ func ResolveIPv6(host string) (net.IP, error) {
return ResolveIPv6WithResolver(host, DefaultResolver) return ResolveIPv6WithResolver(host, DefaultResolver)
} }
func ResolveIPv6WithMain(host string) (net.IP, error) {
return ResolveIPv6WithResolver(host, MainResolver)
}
func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) { func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) {
if DisableIPv6 { if DisableIPv6 {
return nil, ErrIPv6Disabled return nil, ErrIPv6Disabled
@ -113,6 +109,7 @@ func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) {
return r.ResolveIPv6(host) return r.ResolveIPv6(host)
} }
if DefaultResolver == nil {
ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout) ctx, cancel := context.WithTimeout(context.Background(), DefaultDNSTimeout)
defer cancel() defer cancel()
ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip6", host) ipAddrs, err := net.DefaultResolver.LookupIP(ctx, "ip6", host)
@ -123,6 +120,9 @@ func ResolveIPv6WithResolver(host string, r Resolver) (net.IP, error) {
} }
return ipAddrs[rand.Intn(len(ipAddrs))], nil return ipAddrs[rand.Intn(len(ipAddrs))], nil
}
return nil, ErrIPNotFound
} }
// ResolveIPWithResolver same as ResolveIP, but with a resolver // ResolveIPWithResolver same as ResolveIP, but with a resolver
@ -145,12 +145,16 @@ func ResolveIPWithResolver(host string, r Resolver) (net.IP, error) {
return ip, nil return ip, nil
} }
if DefaultResolver == nil {
ipAddr, err := net.ResolveIPAddr("ip", host) ipAddr, err := net.ResolveIPAddr("ip", host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return ipAddr.IP, nil return ipAddr.IP, nil
}
return nil, ErrIPNotFound
} }
// ResolveIP with a host, return ip // ResolveIP with a host, return ip
@ -158,7 +162,26 @@ func ResolveIP(host string) (net.IP, error) {
return ResolveIPWithResolver(host, DefaultResolver) return ResolveIPWithResolver(host, DefaultResolver)
} }
// ResolveIPWithMainResolver with a host, use main resolver, return ip // ResolveIPv4ProxyServerHost proxies server host only
func ResolveIPWithMainResolver(host string) (net.IP, error) { func ResolveIPv4ProxyServerHost(host string) (net.IP, error) {
return ResolveIPWithResolver(host, MainResolver) if ProxyServerHostResolver != nil {
return ResolveIPv4WithResolver(host, ProxyServerHostResolver)
}
return ResolveIPv4(host)
}
// ResolveIPv6ProxyServerHost proxies server host only
func ResolveIPv6ProxyServerHost(host string) (net.IP, error) {
if ProxyServerHostResolver != nil {
return ResolveIPv6WithResolver(host, ProxyServerHostResolver)
}
return ResolveIPv6(host)
}
// ResolveProxyServerHost proxies server host only
func ResolveProxyServerHost(host string) (net.IP, error) {
if ProxyServerHostResolver != nil {
return ResolveIPWithResolver(host, ProxyServerHostResolver)
}
return ResolveIP(host)
} }

View File

@ -74,6 +74,7 @@ type DNS struct {
FakeIPRange *fakeip.Pool FakeIPRange *fakeip.Pool
Hosts *trie.DomainTrie Hosts *trie.DomainTrie
NameServerPolicy map[string]dns.NameServer NameServerPolicy map[string]dns.NameServer
ProxyServerNameserver []dns.NameServer
} }
// FallbackFilter config // FallbackFilter config
@ -137,6 +138,7 @@ type RawDNS struct {
FakeIPFilter []string `yaml:"fake-ip-filter"` FakeIPFilter []string `yaml:"fake-ip-filter"`
DefaultNameserver []string `yaml:"default-nameserver"` DefaultNameserver []string `yaml:"default-nameserver"`
NameServerPolicy map[string]string `yaml:"nameserver-policy"` NameServerPolicy map[string]string `yaml:"nameserver-policy"`
ProxyServerNameserver []string `yaml:"proxy-server-nameserver"`
} }
type RawFallbackFilter struct { type RawFallbackFilter struct {
@ -679,6 +681,10 @@ func parseDNS(rawCfg *RawConfig, hosts *trie.DomainTrie, rules []C.Rule) (*DNS,
return nil, err return nil, err
} }
if dnsCfg.ProxyServerNameserver, err = parseNameServer(cfg.ProxyServerNameserver); err != nil {
return nil, err
}
if len(cfg.DefaultNameserver) == 0 { if len(cfg.DefaultNameserver) == 0 {
return nil, errors.New("default nameserver should have at least one nameserver") return nil, errors.New("default nameserver should have at least one nameserver")
} }

View File

@ -41,6 +41,7 @@ type Resolver struct {
group singleflight.Group group singleflight.Group
lruCache *cache.LruCache lruCache *cache.LruCache
policy *trie.DomainTrie policy *trie.DomainTrie
proxyServer []dnsClient
} }
// ResolveIP request with TypeA and TypeAAAA, priority return TypeA // ResolveIP request with TypeA and TypeAAAA, priority return TypeA
@ -301,6 +302,11 @@ func (r *Resolver) asyncExchange(ctx context.Context, client []dnsClient, msg *D
return ch return ch
} }
// HasProxyServer has proxy server dns client
func (r *Resolver) HasProxyServer() bool {
return len(r.main) > 0
}
type NameServer struct { type NameServer struct {
Net string Net string
Addr string Addr string
@ -319,6 +325,7 @@ type FallbackFilter struct {
type Config struct { type Config struct {
Main, Fallback []NameServer Main, Fallback []NameServer
Default []NameServer Default []NameServer
ProxyServer []NameServer
IPv6 bool IPv6 bool
EnhancedMode C.DNSMode EnhancedMode C.DNSMode
FallbackFilter FallbackFilter FallbackFilter FallbackFilter
@ -344,6 +351,10 @@ func NewResolver(config Config) *Resolver {
r.fallback = transform(config.Fallback, defaultResolver) r.fallback = transform(config.Fallback, defaultResolver)
} }
if len(config.ProxyServer) != 0 {
r.proxyServer = transform(config.ProxyServer, defaultResolver)
}
if len(config.Policy) != 0 { if len(config.Policy) != 0 {
r.policy = trie.New() r.policy = trie.New()
for domain, nameserver := range config.Policy { for domain, nameserver := range config.Policy {
@ -377,10 +388,10 @@ func NewResolver(config Config) *Resolver {
return r return r
} }
func NewMainResolver(old *Resolver) *Resolver { func NewProxyServerHostResolver(old *Resolver) *Resolver {
r := &Resolver{ r := &Resolver{
ipv6: old.ipv6, ipv6: old.ipv6,
main: old.main, main: old.proxyServer,
lruCache: old.lruCache, lruCache: old.lruCache,
hosts: old.hosts, hosts: old.hosts,
policy: old.policy, policy: old.policy,

View File

@ -132,10 +132,11 @@ func updateDNS(c *config.DNS, t *config.Tun) {
}, },
Default: c.DefaultNameserver, Default: c.DefaultNameserver,
Policy: c.NameServerPolicy, Policy: c.NameServerPolicy,
ProxyServer: c.ProxyServerNameserver,
} }
r := dns.NewResolver(cfg) r := dns.NewResolver(cfg)
mr := dns.NewMainResolver(r) pr := dns.NewProxyServerHostResolver(r)
m := dns.NewEnhancer(cfg) m := dns.NewEnhancer(cfg)
// reuse cache of old host mapper // reuse cache of old host mapper
@ -144,9 +145,12 @@ func updateDNS(c *config.DNS, t *config.Tun) {
} }
resolver.DefaultResolver = r resolver.DefaultResolver = r
resolver.MainResolver = mr
resolver.DefaultHostMapper = m resolver.DefaultHostMapper = m
if pr.HasProxyServer() {
resolver.ProxyServerHostResolver = pr
}
if t.Enable { if t.Enable {
resolver.DefaultLocalServer = dns.NewLocalServer(r, m) resolver.DefaultLocalServer = dns.NewLocalServer(r, m)
} }
@ -156,9 +160,9 @@ func updateDNS(c *config.DNS, t *config.Tun) {
} else { } else {
if !t.Enable { if !t.Enable {
resolver.DefaultResolver = nil resolver.DefaultResolver = nil
resolver.MainResolver = nil
resolver.DefaultHostMapper = nil resolver.DefaultHostMapper = nil
resolver.DefaultLocalServer = nil resolver.DefaultLocalServer = nil
resolver.ProxyServerHostResolver = nil
} }
dns.ReCreateServer("", nil, nil) dns.ReCreateServer("", nil, nil)
} }