Add disable_cache option to dns rule

This commit is contained in:
世界 2022-07-24 14:05:06 +08:00
parent 8666631732
commit af19ba6119
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
6 changed files with 127 additions and 82 deletions

View File

@ -50,3 +50,8 @@ type Rule interface {
Outbound() string Outbound() string
String() string String() string
} }
type DNSRule interface {
Rule
DisableCache() bool
}

View File

@ -55,6 +55,7 @@ func (r DNSRule) MarshalJSON() ([]byte, error) {
var v any var v any
switch r.Type { switch r.Type {
case C.RuleTypeDefault: case C.RuleTypeDefault:
r.Type = ""
v = r.DefaultOptions v = r.DefaultOptions
case C.RuleTypeLogical: case C.RuleTypeLogical:
v = r.LogicalOptions v = r.LogicalOptions
@ -109,6 +110,7 @@ type DefaultDNSRule struct {
Outbound Listable[string] `json:"outbound,omitempty"` Outbound Listable[string] `json:"outbound,omitempty"`
Invert bool `json:"invert,omitempty"` Invert bool `json:"invert,omitempty"`
Server string `json:"server,omitempty"` Server string `json:"server,omitempty"`
DisableCache bool `json:"disable_cache,omitempty"`
} }
func (r DefaultDNSRule) IsValid() bool { func (r DefaultDNSRule) IsValid() bool {
@ -135,13 +137,17 @@ func (r DefaultDNSRule) Equals(other DefaultDNSRule) bool {
common.ComparableSliceEquals(r.UserID, other.UserID) && common.ComparableSliceEquals(r.UserID, other.UserID) &&
common.ComparableSliceEquals(r.PackageName, other.PackageName) && common.ComparableSliceEquals(r.PackageName, other.PackageName) &&
common.ComparableSliceEquals(r.Outbound, other.Outbound) && common.ComparableSliceEquals(r.Outbound, other.Outbound) &&
r.Server == other.Server r.Invert == other.Invert &&
r.Server == other.Server &&
r.DisableCache == other.DisableCache
} }
type LogicalDNSRule struct { type LogicalDNSRule struct {
Mode string `json:"mode"` Mode string `json:"mode"`
Rules []DefaultDNSRule `json:"rules,omitempty"` Rules []DefaultDNSRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"`
Server string `json:"server,omitempty"` Server string `json:"server,omitempty"`
DisableCache bool `json:"disable_cache,omitempty"`
} }
func (r LogicalDNSRule) IsValid() bool { func (r LogicalDNSRule) IsValid() bool {
@ -151,5 +157,7 @@ func (r LogicalDNSRule) IsValid() bool {
func (r LogicalDNSRule) Equals(other LogicalDNSRule) bool { func (r LogicalDNSRule) Equals(other LogicalDNSRule) bool {
return r.Mode == other.Mode && return r.Mode == other.Mode &&
common.SliceEquals(r.Rules, other.Rules) && common.SliceEquals(r.Rules, other.Rules) &&
r.Server == other.Server r.Invert == other.Invert &&
r.Server == other.Server &&
r.DisableCache == other.DisableCache
} }

View File

@ -145,6 +145,7 @@ func (r DefaultRule) Equals(other DefaultRule) bool {
type LogicalRule struct { type LogicalRule struct {
Mode string `json:"mode"` Mode string `json:"mode"`
Rules []DefaultRule `json:"rules,omitempty"` Rules []DefaultRule `json:"rules,omitempty"`
Invert bool `json:"invert,omitempty"`
Outbound string `json:"outbound,omitempty"` Outbound string `json:"outbound,omitempty"`
} }
@ -155,5 +156,6 @@ func (r LogicalRule) IsValid() bool {
func (r LogicalRule) Equals(other LogicalRule) bool { func (r LogicalRule) Equals(other LogicalRule) bool {
return r.Mode == other.Mode && return r.Mode == other.Mode &&
common.SliceEquals(r.Rules, other.Rules) && common.SliceEquals(r.Rules, other.Rules) &&
r.Invert == other.Invert &&
r.Outbound == other.Outbound r.Outbound == other.Outbound
} }

View File

@ -59,7 +59,7 @@ type Router struct {
geositeCache map[string]adapter.Rule geositeCache map[string]adapter.Rule
dnsClient *dns.Client dnsClient *dns.Client
defaultDomainStrategy dns.DomainStrategy defaultDomainStrategy dns.DomainStrategy
dnsRules []adapter.Rule dnsRules []adapter.DNSRule
defaultTransport dns.Transport defaultTransport dns.Transport
transports []dns.Transport transports []dns.Transport
transportMap map[string]dns.Transport transportMap map[string]dns.Transport
@ -80,7 +80,7 @@ func NewRouter(ctx context.Context, logger log.ContextLogger, dnsLogger log.Cont
dnsLogger: dnsLogger, dnsLogger: dnsLogger,
outboundByTag: make(map[string]adapter.Outbound), outboundByTag: make(map[string]adapter.Outbound),
rules: make([]adapter.Rule, 0, len(options.Rules)), rules: make([]adapter.Rule, 0, len(options.Rules)),
dnsRules: make([]adapter.Rule, 0, len(dnsOptions.Rules)), dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)),
needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule), needGeoIPDatabase: hasRule(options.Rules, isGeoIPRule) || hasDNSRule(dnsOptions.Rules, isGeoIPDNSRule),
needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule), needGeositeDatabase: hasRule(options.Rules, isGeositeRule) || hasDNSRule(dnsOptions.Rules, isGeositeDNSRule),
geoIPOptions: common.PtrValueOrDefault(options.GeoIP), geoIPOptions: common.PtrValueOrDefault(options.GeoIP),
@ -536,15 +536,18 @@ func (r *Router) RoutePacketConnection(ctx context.Context, conn N.PacketConn, m
} }
func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) { func (r *Router) Exchange(ctx context.Context, message *dnsmessage.Message) (*dnsmessage.Message, error) {
return r.dnsClient.Exchange(ctx, r.matchDNS(ctx), message) ctx, transport := r.matchDNS(ctx)
return r.dnsClient.Exchange(ctx, transport, message)
} }
func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) { func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) {
return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, strategy) ctx, transport := r.matchDNS(ctx)
return r.dnsClient.Lookup(ctx, transport, domain, strategy)
} }
func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) { func (r *Router) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) {
return r.dnsClient.Lookup(ctx, r.matchDNS(ctx), domain, r.defaultDomainStrategy) ctx, transport := r.matchDNS(ctx)
return r.dnsClient.Lookup(ctx, transport, domain, r.defaultDomainStrategy)
} }
func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) { func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, defaultOutbound adapter.Outbound) (adapter.Rule, adapter.Outbound) {
@ -586,23 +589,26 @@ func (r *Router) match(ctx context.Context, metadata *adapter.InboundContext, de
return nil, defaultOutbound return nil, defaultOutbound
} }
func (r *Router) matchDNS(ctx context.Context) dns.Transport { func (r *Router) matchDNS(ctx context.Context) (context.Context, dns.Transport) {
metadata := adapter.ContextFrom(ctx) metadata := adapter.ContextFrom(ctx)
if metadata == nil { if metadata == nil {
r.dnsLogger.WarnContext(ctx, "no context: ", reflect.TypeOf(ctx)) r.dnsLogger.WarnContext(ctx, "no context: ", reflect.TypeOf(ctx))
return r.defaultTransport return ctx, r.defaultTransport
} }
for i, rule := range r.dnsRules { for i, rule := range r.dnsRules {
if rule.Match(metadata) { if rule.Match(metadata) {
if rule.DisableCache() {
ctx = dns.ContextWithDisableCache(ctx, true)
}
detour := rule.Outbound() detour := rule.Outbound()
r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour) r.dnsLogger.DebugContext(ctx, "match[", i, "] ", rule.String(), " => ", detour)
if transport, loaded := r.transportMap[detour]; loaded { if transport, loaded := r.transportMap[detour]; loaded {
return transport return ctx, transport
} }
r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour) r.dnsLogger.ErrorContext(ctx, "transport not found: ", detour)
} }
} }
return r.defaultTransport return ctx, r.defaultTransport
} }
func (r *Router) InterfaceBindManager() control.BindManager { func (r *Router) InterfaceBindManager() control.BindManager {

View File

@ -49,10 +49,6 @@ type DefaultRule struct {
outbound string outbound string
} }
func (r *DefaultRule) Type() string {
return C.RuleTypeDefault
}
type RuleItem interface { type RuleItem interface {
Match(metadata *adapter.InboundContext) bool Match(metadata *adapter.InboundContext) bool
String() string String() string
@ -180,6 +176,10 @@ func NewDefaultRule(router adapter.Router, logger log.ContextLogger, options opt
return rule, nil return rule, nil
} }
func (r *DefaultRule) Type() string {
return C.RuleTypeDefault
}
func (r *DefaultRule) Start() error { func (r *DefaultRule) Start() error {
for _, item := range r.allItems { for _, item := range r.allItems {
err := common.Start(item) err := common.Start(item)
@ -261,9 +261,34 @@ var _ adapter.Rule = (*LogicalRule)(nil)
type LogicalRule struct { type LogicalRule struct {
mode string mode string
rules []*DefaultRule rules []*DefaultRule
invert bool
outbound string outbound string
} }
func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
r := &LogicalRule{
rules: make([]*DefaultRule, len(options.Rules)),
invert: options.Invert,
outbound: options.Outbound,
}
switch options.Mode {
case C.LogicalTypeAnd:
r.mode = C.LogicalTypeAnd
case C.LogicalTypeOr:
r.mode = C.LogicalTypeOr
default:
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subRule := range options.Rules {
rule, err := NewDefaultRule(router, logger, subRule)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
r.rules[i] = rule
}
return r, nil
}
func (r *LogicalRule) Type() string { func (r *LogicalRule) Type() string {
return C.RuleTypeLogical return C.RuleTypeLogical
} }
@ -298,38 +323,15 @@ func (r *LogicalRule) Close() error {
return nil return nil
} }
func NewLogicalRule(router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
r := &LogicalRule{
rules: make([]*DefaultRule, len(options.Rules)),
outbound: options.Outbound,
}
switch options.Mode {
case C.LogicalTypeAnd:
r.mode = C.LogicalTypeAnd
case C.LogicalTypeOr:
r.mode = C.LogicalTypeOr
default:
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subRule := range options.Rules {
rule, err := NewDefaultRule(router, logger, subRule)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
r.rules[i] = rule
}
return r, nil
}
func (r *LogicalRule) Match(metadata *adapter.InboundContext) bool { func (r *LogicalRule) Match(metadata *adapter.InboundContext) bool {
if r.mode == C.LogicalTypeAnd { if r.mode == C.LogicalTypeAnd {
return common.All(r.rules, func(it *DefaultRule) bool { return common.All(r.rules, func(it *DefaultRule) bool {
return it.Match(metadata) return it.Match(metadata)
}) }) != r.invert
} else { } else {
return common.Any(r.rules, func(it *DefaultRule) bool { return common.Any(r.rules, func(it *DefaultRule) bool {
return it.Match(metadata) return it.Match(metadata)
}) }) != r.invert
} }
} }
@ -345,5 +347,9 @@ func (r *LogicalRule) String() string {
case C.LogicalTypeOr: case C.LogicalTypeOr:
op = "||" op = "||"
} }
return "logical(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" if !r.invert {
return strings.Join(F.MapToString(r.rules), " "+op+" ")
} else {
return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
}
} }

View File

@ -12,7 +12,7 @@ import (
F "github.com/sagernet/sing/common/format" F "github.com/sagernet/sing/common/format"
) )
func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.Rule, error) { func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.DNSRule) (adapter.DNSRule, error) {
if common.IsEmptyByEquals(options) { if common.IsEmptyByEquals(options) {
return nil, E.New("empty rule config") return nil, E.New("empty rule config")
} }
@ -38,7 +38,7 @@ func NewDNSRule(router adapter.Router, logger log.ContextLogger, options option.
} }
} }
var _ adapter.Rule = (*DefaultDNSRule)(nil) var _ adapter.DNSRule = (*DefaultDNSRule)(nil)
type DefaultDNSRule struct { type DefaultDNSRule struct {
items []RuleItem items []RuleItem
@ -46,16 +46,14 @@ type DefaultDNSRule struct {
allItems []RuleItem allItems []RuleItem
invert bool invert bool
outbound string outbound string
} disableCache bool
func (r *DefaultDNSRule) Type() string {
return C.RuleTypeDefault
} }
func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) {
rule := &DefaultDNSRule{ rule := &DefaultDNSRule{
invert: true, invert: options.Invert,
outbound: options.Server, outbound: options.Server,
disableCache: options.DisableCache,
} }
if len(options.Inbound) > 0 { if len(options.Inbound) > 0 {
item := NewInboundRule(options.Inbound) item := NewInboundRule(options.Inbound)
@ -156,6 +154,10 @@ func NewDefaultDNSRule(router adapter.Router, logger log.ContextLogger, options
return rule, nil return rule, nil
} }
func (r *DefaultDNSRule) Type() string {
return C.RuleTypeDefault
}
func (r *DefaultDNSRule) Start() error { func (r *DefaultDNSRule) Start() error {
for _, item := range r.allItems { for _, item := range r.allItems {
err := common.Start(item) err := common.Start(item)
@ -213,16 +215,47 @@ func (r *DefaultDNSRule) Outbound() string {
return r.outbound return r.outbound
} }
func (r *DefaultDNSRule) DisableCache() bool {
return r.disableCache
}
func (r *DefaultDNSRule) String() string { func (r *DefaultDNSRule) String() string {
return strings.Join(F.MapToString(r.allItems), " ") return strings.Join(F.MapToString(r.allItems), " ")
} }
var _ adapter.Rule = (*LogicalRule)(nil) var _ adapter.DNSRule = (*LogicalDNSRule)(nil)
type LogicalDNSRule struct { type LogicalDNSRule struct {
mode string mode string
rules []*DefaultDNSRule rules []*DefaultDNSRule
invert bool
outbound string outbound string
disableCache bool
}
func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) {
r := &LogicalDNSRule{
rules: make([]*DefaultDNSRule, len(options.Rules)),
invert: options.Invert,
outbound: options.Server,
disableCache: options.DisableCache,
}
switch options.Mode {
case C.LogicalTypeAnd:
r.mode = C.LogicalTypeAnd
case C.LogicalTypeOr:
r.mode = C.LogicalTypeOr
default:
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subRule := range options.Rules {
rule, err := NewDefaultDNSRule(router, logger, subRule)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
r.rules[i] = rule
}
return r, nil
} }
func (r *LogicalDNSRule) Type() string { func (r *LogicalDNSRule) Type() string {
@ -259,38 +292,15 @@ func (r *LogicalDNSRule) Close() error {
return nil return nil
} }
func NewLogicalDNSRule(router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) {
r := &LogicalDNSRule{
rules: make([]*DefaultDNSRule, len(options.Rules)),
outbound: options.Server,
}
switch options.Mode {
case C.LogicalTypeAnd:
r.mode = C.LogicalTypeAnd
case C.LogicalTypeOr:
r.mode = C.LogicalTypeOr
default:
return nil, E.New("unknown logical mode: ", options.Mode)
}
for i, subRule := range options.Rules {
rule, err := NewDefaultDNSRule(router, logger, subRule)
if err != nil {
return nil, E.Cause(err, "sub rule[", i, "]")
}
r.rules[i] = rule
}
return r, nil
}
func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool { func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
if r.mode == C.LogicalTypeAnd { if r.mode == C.LogicalTypeAnd {
return common.All(r.rules, func(it *DefaultDNSRule) bool { return common.All(r.rules, func(it *DefaultDNSRule) bool {
return it.Match(metadata) return it.Match(metadata)
}) }) != r.invert
} else { } else {
return common.Any(r.rules, func(it *DefaultDNSRule) bool { return common.Any(r.rules, func(it *DefaultDNSRule) bool {
return it.Match(metadata) return it.Match(metadata)
}) }) != r.invert
} }
} }
@ -298,6 +308,10 @@ func (r *LogicalDNSRule) Outbound() string {
return r.outbound return r.outbound
} }
func (r *LogicalDNSRule) DisableCache() bool {
return r.disableCache
}
func (r *LogicalDNSRule) String() string { func (r *LogicalDNSRule) String() string {
var op string var op string
switch r.mode { switch r.mode {
@ -306,5 +320,9 @@ func (r *LogicalDNSRule) String() string {
case C.LogicalTypeOr: case C.LogicalTypeOr:
op = "||" op = "||"
} }
return "logical(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")" if !r.invert {
return strings.Join(F.MapToString(r.rules), " "+op+" ")
} else {
return "!(" + strings.Join(F.MapToString(r.rules), " "+op+" ") + ")"
}
} }