sing-box/route/route_dns.go

282 lines
8.6 KiB
Go
Raw Normal View History

2022-08-03 18:55:39 +08:00
package route
import (
"context"
"errors"
2022-08-03 18:55:39 +08:00
"net/netip"
"strings"
2023-03-23 19:08:48 +08:00
"time"
2022-08-03 18:55:39 +08:00
2022-08-16 23:46:05 +08:00
"github.com/sagernet/sing-box/adapter"
2024-10-21 23:38:34 +08:00
R "github.com/sagernet/sing-box/route/rule"
2022-08-03 18:55:39 +08:00
"github.com/sagernet/sing-dns"
2023-03-23 19:08:48 +08:00
"github.com/sagernet/sing/common/cache"
2022-08-03 18:55:39 +08:00
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
2023-03-23 19:08:48 +08:00
M "github.com/sagernet/sing/common/metadata"
2022-08-03 18:55:39 +08:00
2022-09-13 16:18:39 +08:00
mDNS "github.com/miekg/dns"
2022-08-03 18:55:39 +08:00
)
2023-03-23 19:08:48 +08:00
type DNSReverseMapping struct {
cache *cache.LruCache[netip.Addr, string]
}
func NewDNSReverseMapping() *DNSReverseMapping {
return &DNSReverseMapping{
cache: cache.New[netip.Addr, string](),
}
}
func (m *DNSReverseMapping) Save(address netip.Addr, domain string, ttl int) {
m.cache.StoreWithExpire(address, domain, time.Now().Add(time.Duration(ttl)*time.Second))
}
func (m *DNSReverseMapping) Query(address netip.Addr) (string, bool) {
domain, loaded := m.cache.Load(address)
return domain, loaded
}
2024-10-21 23:38:34 +08:00
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool) (dns.Transport, dns.QueryOptions, adapter.DNSRule, int) {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
panic("no context")
}
2024-10-21 23:38:34 +08:00
var options dns.QueryOptions
if ruleIndex < len(r.dnsRules) {
dnsRules := r.dnsRules
2024-10-21 23:38:34 +08:00
if ruleIndex != -1 {
dnsRules = dnsRules[ruleIndex+1:]
}
for currentRuleIndex, rule := range dnsRules {
if rule.WithAddressLimit() && !isAddressQuery {
continue
}
metadata.ResetRuleCache()
if rule.Match(metadata) {
2024-10-21 23:38:34 +08:00
displayRuleIndex := currentRuleIndex
if displayRuleIndex != -1 {
displayRuleIndex += displayRuleIndex + 1
}
2024-10-21 23:38:34 +08:00
if routeAction, isRoute := rule.Action().(*R.RuleActionDNSRoute); isRoute {
transport, loaded := r.transportMap[routeAction.Server]
if !loaded {
r.dnsLogger.ErrorContext(ctx, "transport not found: ", routeAction.Server)
continue
}
_, isFakeIP := transport.(adapter.FakeIPTransport)
if isFakeIP && !allowFakeIP {
continue
}
options.DisableCache = isFakeIP || routeAction.DisableCache
options.RewriteTTL = routeAction.RewriteTTL
options.ClientSubnet = routeAction.ClientSubnet
if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded {
options.Strategy = domainStrategy
} else {
options.Strategy = r.defaultDomainStrategy
}
r.dnsLogger.DebugContext(ctx, "match[", displayRuleIndex, "] ", rule.String(), " => ", rule.Action())
return transport, options, rule, currentRuleIndex
} else {
2024-10-21 23:38:34 +08:00
return nil, options, rule, currentRuleIndex
}
}
}
}
2022-11-19 22:39:30 +08:00
if domainStrategy, dsLoaded := r.transportDomainStrategy[r.defaultTransport]; dsLoaded {
2024-10-21 23:38:34 +08:00
options.Strategy = domainStrategy
2022-11-19 22:39:30 +08:00
} else {
2024-10-21 23:38:34 +08:00
options.Strategy = r.defaultDomainStrategy
2022-11-19 22:39:30 +08:00
}
2024-10-21 23:38:34 +08:00
return r.defaultTransport, options, nil, -1
}
2022-09-13 16:18:39 +08:00
func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
if len(message.Question) > 0 {
r.dnsLogger.DebugContext(ctx, "exchange ", formatQuestion(message.Question[0].String()))
2022-08-03 18:55:39 +08:00
}
2023-04-26 04:53:25 +08:00
var (
response *mDNS.Msg
cached bool
transport dns.Transport
err error
2023-04-26 04:53:25 +08:00
)
response, cached = r.dnsClient.ExchangeCache(ctx, message)
if !cached {
var metadata *adapter.InboundContext
2024-06-24 09:41:00 +08:00
ctx, metadata = adapter.ExtendContext(ctx)
metadata.Destination = M.Socksaddr{}
2023-04-26 04:53:25 +08:00
if len(message.Question) > 0 {
metadata.QueryType = message.Question[0].Qtype
switch metadata.QueryType {
case mDNS.TypeA:
metadata.IPVersion = 4
case mDNS.TypeAAAA:
metadata.IPVersion = 6
}
metadata.Domain = fqdnToDomain(message.Question[0].Name)
}
var (
2024-10-21 23:38:34 +08:00
options dns.QueryOptions
rule adapter.DNSRule
ruleIndex int
)
ruleIndex = -1
for {
2024-10-21 23:38:34 +08:00
dnsCtx := adapter.OverrideContext(ctx)
var addressLimit bool
transport, options, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message))
if rule != nil && rule.WithAddressLimit() {
addressLimit = true
2024-10-21 23:38:34 +08:00
response, err = r.dnsClient.ExchangeWithResponseCheck(dnsCtx, transport, message, options, func(response *mDNS.Msg) bool {
2024-06-24 09:41:00 +08:00
addresses, addrErr := dns.MessageToAddresses(response)
if addrErr != nil {
return false
}
metadata.DestinationAddresses = addresses
return rule.MatchAddressLimit(metadata)
})
} else {
addressLimit = false
2024-10-21 23:38:34 +08:00
response, err = r.dnsClient.Exchange(dnsCtx, transport, message, options)
}
2024-03-15 17:21:52 +08:00
var rejected bool
if err != nil {
if errors.Is(err, dns.ErrResponseRejectedCached) {
2024-03-15 17:21:52 +08:00
rejected = true
r.dnsLogger.DebugContext(ctx, E.Cause(err, "response rejected for ", formatQuestion(message.Question[0].String())), " (cached)")
} else if errors.Is(err, dns.ErrResponseRejected) {
2024-03-15 17:21:52 +08:00
rejected = true
r.dnsLogger.DebugContext(ctx, E.Cause(err, "response rejected for ", formatQuestion(message.Question[0].String())))
} else if len(message.Question) > 0 {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for ", formatQuestion(message.Question[0].String())))
} else {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "exchange failed for <empty query>"))
}
}
2024-03-15 17:21:52 +08:00
if addressLimit && rejected {
continue
}
2024-03-15 17:21:52 +08:00
break
2022-08-16 23:46:05 +08:00
}
2022-08-03 18:55:39 +08:00
}
if err != nil {
return nil, err
}
2023-03-23 19:08:48 +08:00
if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 {
if _, isFakeIP := transport.(adapter.FakeIPTransport); !isFakeIP {
for _, answer := range response.Answer {
switch record := answer.(type) {
case *mDNS.A:
r.dnsReverseMapping.Save(M.AddrFromIP(record.A), fqdnToDomain(record.Hdr.Name), int(record.Hdr.Ttl))
case *mDNS.AAAA:
r.dnsReverseMapping.Save(M.AddrFromIP(record.AAAA), fqdnToDomain(record.Hdr.Name), int(record.Hdr.Ttl))
}
2023-03-23 19:08:48 +08:00
}
}
}
return response, nil
2022-08-03 18:55:39 +08:00
}
func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
var (
responseAddrs []netip.Addr
cached bool
err error
)
responseAddrs, cached = r.dnsClient.LookupCache(ctx, domain, strategy)
if cached {
if len(responseAddrs) == 0 {
return nil, dns.RCodeNameError
}
return responseAddrs, nil
}
2022-08-03 21:51:34 +08:00
r.dnsLogger.DebugContext(ctx, "lookup domain ", domain)
2024-06-24 09:41:00 +08:00
ctx, metadata := adapter.ExtendContext(ctx)
metadata.Destination = M.Socksaddr{}
2022-09-04 12:39:43 +08:00
metadata.Domain = domain
var (
2024-10-21 23:38:34 +08:00
transport dns.Transport
options dns.QueryOptions
rule adapter.DNSRule
ruleIndex int
)
ruleIndex = -1
for {
2024-10-21 23:38:34 +08:00
dnsCtx := adapter.OverrideContext(ctx)
var addressLimit bool
transport, options, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true)
if strategy != dns.DomainStrategyAsIS {
options.Strategy = strategy
}
if rule != nil && rule.WithAddressLimit() {
addressLimit = true
2024-10-21 23:38:34 +08:00
responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool {
metadata.DestinationAddresses = responseAddrs
return rule.MatchAddressLimit(metadata)
})
} else {
addressLimit = false
2024-10-21 23:38:34 +08:00
responseAddrs, err = r.dnsClient.Lookup(dnsCtx, transport, domain, options)
}
if err != nil {
if errors.Is(err, dns.ErrResponseRejectedCached) {
r.dnsLogger.DebugContext(ctx, "response rejected for ", domain, " (cached)")
} else if errors.Is(err, dns.ErrResponseRejected) {
r.dnsLogger.DebugContext(ctx, "response rejected for ", domain)
} else {
r.dnsLogger.ErrorContext(ctx, E.Cause(err, "lookup failed for ", domain))
}
} else if len(responseAddrs) == 0 {
r.dnsLogger.ErrorContext(ctx, "lookup failed for ", domain, ": empty result")
err = dns.RCodeNameError
}
if !addressLimit || err == nil {
break
}
}
if len(responseAddrs) > 0 {
r.dnsLogger.InfoContext(ctx, "lookup succeed for ", domain, ": ", strings.Join(F.MapToString(responseAddrs), " "))
2022-08-03 18:55:39 +08:00
}
return responseAddrs, err
2022-08-03 18:55:39 +08:00
}
func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) {
return r.Lookup(ctx, domain, dns.DomainStrategyAsIS)
2022-08-03 18:55:39 +08:00
}
// ClearDNSCache flushes the router's DNS client cache and, when a platform
// interface is attached, asks it to clear its DNS cache too.
func (r *Router) ClearDNSCache() {
	r.dnsClient.ClearCache()
	if r.platformInterface == nil {
		return
	}
	r.platformInterface.ClearDNSCache()
}
func isAddressQuery(message *mDNS.Msg) bool {
for _, question := range message.Question {
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA || question.Qtype == mDNS.TypeHTTPS {
return true
}
2022-08-03 18:55:39 +08:00
}
return false
2022-08-03 18:55:39 +08:00
}
2022-09-13 16:18:39 +08:00
func fqdnToDomain(fqdn string) string {
if mDNS.IsFqdn(fqdn) {
return fqdn[:len(fqdn)-1]
2022-08-03 18:55:39 +08:00
}
2022-09-13 16:18:39 +08:00
return fqdn
2022-08-03 18:55:39 +08:00
}
2022-09-13 16:18:39 +08:00
// formatQuestion turns the text form of a DNS question into a compact,
// single-line string for logging. Callers pass a question's String() output,
// e.g. ";example.com.\tIN\t A", which becomes "example.com. IN A".
//
// It performs three steps:
//   - drops one leading ';';
//   - replaces tabs with spaces;
//   - collapses every run of spaces into a single space.
func formatQuestion(question string) string {
	question = strings.TrimPrefix(question, ";")
	question = strings.ReplaceAll(question, "\t", " ")
	// Search for a double space, not a single one. Replacing " " with " "
	// never shrinks the string, so that loop would never terminate.
	for strings.Contains(question, "  ") {
		question = strings.ReplaceAll(question, "  ", " ")
	}
	return question
}