From 466171b3cf91282980fe9fb5ccb69afe6f9b724f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 6 Nov 2024 17:30:40 +0800 Subject: [PATCH] Improve rule actions --- adapter/inbound.go | 4 +- constant/rule.go | 13 +-- option/rule.go | 14 +-- option/rule_action.go | 182 ++++++++++++++++++++++++++++++++----- option/rule_dns.go | 16 ++-- protocol/dns/handle.go | 1 + route/route.go | 130 ++++++++++++++------------ route/route_dns.go | 107 +++++++++++++++++----- route/rule/rule_action.go | 113 +++++++++++++++++++---- route/rule/rule_default.go | 4 +- 10 files changed, 434 insertions(+), 150 deletions(-) diff --git a/adapter/inbound.go b/adapter/inbound.go index ca3e9e59..f9ed1708 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -57,7 +57,9 @@ type InboundContext struct { // Deprecated InboundOptions option.InboundOptions UDPDisableDomainUnmapping bool - DNSServer string + UDPConnect bool + + DNSServer string DestinationAddresses []netip.Addr SourceGeoIPCode string diff --git a/constant/rule.go b/constant/rule.go index 73227175..b1f91c60 100644 --- a/constant/rule.go +++ b/constant/rule.go @@ -25,12 +25,13 @@ const ( ) const ( - RuleActionTypeRoute = "route" - RuleActionTypeReturn = "return" - RuleActionTypeReject = "reject" - RuleActionTypeHijackDNS = "hijack-dns" - RuleActionTypeSniff = "sniff" - RuleActionTypeResolve = "resolve" + RuleActionTypeRoute = "route" + RuleActionTypeRouteOptions = "route-options" + RuleActionTypeDirect = "direct" + RuleActionTypeReject = "reject" + RuleActionTypeHijackDNS = "hijack-dns" + RuleActionTypeSniff = "sniff" + RuleActionTypeResolve = "resolve" ) const ( diff --git a/option/rule.go b/option/rule.go index 07e6ddbe..952afa61 100644 --- a/option/rule.go +++ b/option/rule.go @@ -109,7 +109,7 @@ type DefaultRule struct { RuleAction } -func (r *DefaultRule) MarshalJSON() ([]byte, error) { +func (r DefaultRule) MarshalJSON() ([]byte, error) { return badjson.MarshallObjects(r.RawDefaultRule, r.RuleAction) } @@ -128,27 +128,27 @@ func (r *DefaultRule) IsValid() bool { return !reflect.DeepEqual(r, defaultValue) } -type _LogicalRule struct { +type RawLogicalRule struct { Mode string `json:"mode"` Rules []Rule `json:"rules,omitempty"` Invert bool `json:"invert,omitempty"` } type LogicalRule struct { - _LogicalRule + RawLogicalRule RuleAction } -func (r *LogicalRule) MarshalJSON() ([]byte, error) { - return badjson.MarshallObjects(r._LogicalRule, r.RuleAction) +func (r LogicalRule) MarshalJSON() ([]byte, error) { + return badjson.MarshallObjects(r.RawLogicalRule, r.RuleAction) } func (r *LogicalRule) UnmarshalJSON(data []byte) error { - err := json.Unmarshal(data, &r._LogicalRule) + err := json.Unmarshal(data, &r.RawLogicalRule) if err != nil { return err } - return badjson.UnmarshallExcluded(data, &r._LogicalRule, &r.RuleAction) + return badjson.UnmarshallExcluded(data, &r.RawLogicalRule, &r.RuleAction) } func (r *LogicalRule) IsValid() bool { diff --git a/option/rule_action.go b/option/rule_action.go index 3a40e1c0..edc197de 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -1,30 +1,41 @@ package option import ( + "fmt" + "time" + C "github.com/sagernet/sing-box/constant" + dns "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json" "github.com/sagernet/sing/common/json/badjson" ) type _RuleAction struct { - Action string `json:"action,omitempty"` - RouteOptions RouteActionOptions `json:"-"` - RejectOptions RejectActionOptions `json:"-"` - SniffOptions RouteActionSniff `json:"-"` - ResolveOptions RouteActionResolve `json:"-"` + Action string `json:"action,omitempty"` + RouteOptions RouteActionOptions `json:"-"` + RouteOptionsOptions RouteOptionsActionOptions `json:"-"` + DirectOptions DirectActionOptions `json:"-"` + RejectOptions RejectActionOptions `json:"-"` + SniffOptions RouteActionSniff `json:"-"` + ResolveOptions RouteActionResolve `json:"-"` } type RuleAction _RuleAction func (r RuleAction) MarshalJSON() ([]byte, error) { + if r.Action == "" { + return json.Marshal(struct{}{}) + } var v any switch r.Action { case C.RuleActionTypeRoute: r.Action = "" v = r.RouteOptions - case C.RuleActionTypeReturn: - v = nil + case C.RuleActionTypeRouteOptions: + v = r.RouteOptionsOptions + case C.RuleActionTypeDirect: + v = r.DirectOptions case C.RuleActionTypeReject: v = r.RejectOptions case C.RuleActionTypeHijackDNS: @@ -52,8 +63,10 @@ func (r *RuleAction) UnmarshalJSON(data []byte) error { case "", C.RuleActionTypeRoute: r.Action = C.RuleActionTypeRoute v = &r.RouteOptions - case C.RuleActionTypeReturn: - v = nil + case C.RuleActionTypeRouteOptions: + v = &r.RouteOptionsOptions + case C.RuleActionTypeDirect: + v = &r.DirectOptions case C.RuleActionTypeReject: v = &r.RejectOptions case C.RuleActionTypeHijackDNS: @@ -73,29 +86,30 @@ func (r *RuleAction) UnmarshalJSON(data []byte) error { } type _DNSRuleAction struct { - Action string `json:"action,omitempty"` - RouteOptions DNSRouteActionOptions `json:"-"` - RejectOptions RejectActionOptions `json:"-"` + Action string `json:"action,omitempty"` + RouteOptions DNSRouteActionOptions `json:"-"` + RouteOptionsOptions DNSRouteOptionsActionOptions `json:"-"` + RejectOptions RejectActionOptions `json:"-"` } type DNSRuleAction _DNSRuleAction func (r DNSRuleAction) MarshalJSON() ([]byte, error) { + if r.Action == "" { + return json.Marshal(struct{}{}) + } var v any switch r.Action { case C.RuleActionTypeRoute: r.Action = "" v = r.RouteOptions - case C.RuleActionTypeReturn: - v = nil + case C.RuleActionTypeRouteOptions: + v = r.RouteOptionsOptions case C.RuleActionTypeReject: v = r.RejectOptions default: return nil, E.New("unknown DNS rule action: " + r.Action) } - if v == nil { - return badjson.MarshallObjects((_DNSRuleAction)(r)) - } return badjson.MarshallObjects((_DNSRuleAction)(r), v) } @@ -109,8 +123,8 @@ func (r *DNSRuleAction) UnmarshalJSON(data []byte) error { case "", C.RuleActionTypeRoute: r.Action = C.RuleActionTypeRoute v = &r.RouteOptions - case C.RuleActionTypeReturn: - v = nil + case C.RuleActionTypeRouteOptions: + v = &r.RouteOptionsOptions case C.RuleActionTypeReject: v = &r.RejectOptions default: @@ -123,18 +137,136 @@ func (r *DNSRuleAction) UnmarshalJSON(data []byte) error { return badjson.UnmarshallExcluded(data, (*_DNSRuleAction)(r), v) } -type RouteActionOptions struct { - Outbound string `json:"outbound"` - UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` +type _RouteActionOptions struct { + Outbound string `json:"outbound,omitempty"` } -type DNSRouteActionOptions struct { - Server string `json:"server"` +type RouteActionOptions _RouteActionOptions + +func (r *RouteActionOptions) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, (*_RouteActionOptions)(r)) + if err != nil { + return err + } + if r.Outbound == "" { + return E.New("missing outbound") + } + return nil +} + +type _RouteOptionsActionOptions struct { + UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` + UDPConnect bool `json:"udp_connect,omitempty"` +} + +type RouteOptionsActionOptions _RouteOptionsActionOptions + +func (r *RouteOptionsActionOptions) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, (*_RouteOptionsActionOptions)(r)) + if err != nil { + return err + } + if *r == (RouteOptionsActionOptions{}) { + return E.New("empty route option action") + } + return nil +} + +type _DNSRouteActionOptions struct { + Server string `json:"server,omitempty"` + // Deprecated: Use DNSRouteOptionsActionOptions instead. + DisableCache bool `json:"disable_cache,omitempty"` + // Deprecated: Use DNSRouteOptionsActionOptions instead. + RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` + // Deprecated: Use DNSRouteOptionsActionOptions instead. + ClientSubnet *AddrPrefix `json:"client_subnet,omitempty"` +} + +type DNSRouteActionOptions _DNSRouteActionOptions + +func (r *DNSRouteActionOptions) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, (*_DNSRouteActionOptions)(r)) + if err != nil { + return err + } + if r.Server == "" { + return E.New("missing server") + } + return nil +} + +type _DNSRouteOptionsActionOptions struct { DisableCache bool `json:"disable_cache,omitempty"` RewriteTTL *uint32 `json:"rewrite_ttl,omitempty"` ClientSubnet *AddrPrefix `json:"client_subnet,omitempty"` } +type DNSRouteOptionsActionOptions _DNSRouteOptionsActionOptions + +func (r *DNSRouteOptionsActionOptions) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, (*_DNSRouteOptionsActionOptions)(r)) + if err != nil { + return err + } + if *r == (DNSRouteOptionsActionOptions{}) { + return E.New("empty DNS route option action") + } + return nil +} + +type _DirectActionOptions DialerOptions + +type DirectActionOptions _DirectActionOptions + +func (d DirectActionOptions) Descriptions() []string { + var descriptions []string + if d.BindInterface != "" { + descriptions = append(descriptions, "bind_interface="+d.BindInterface) + } + if d.Inet4BindAddress != nil { + descriptions = append(descriptions, "inet4_bind_address="+d.Inet4BindAddress.Build().String()) + } + if d.Inet6BindAddress != nil { + descriptions = append(descriptions, "inet6_bind_address="+d.Inet6BindAddress.Build().String()) + } + if d.RoutingMark != 0 { + descriptions = append(descriptions, "routing_mark="+fmt.Sprintf("0x%x", d.RoutingMark)) + } + if d.ReuseAddr { + descriptions = append(descriptions, "reuse_addr") + } + if d.ConnectTimeout != 0 { + descriptions = append(descriptions, "connect_timeout="+time.Duration(d.ConnectTimeout).String()) + } + if d.TCPFastOpen { + descriptions = append(descriptions, "tcp_fast_open") + } + if d.TCPMultiPath { + descriptions = append(descriptions, "tcp_multi_path") + } + if d.UDPFragment != nil { + descriptions = append(descriptions, "udp_fragment="+fmt.Sprint(*d.UDPFragment)) + } + if d.DomainStrategy != DomainStrategy(dns.DomainStrategyAsIS) { + descriptions = append(descriptions, "domain_strategy="+d.DomainStrategy.String()) + } + if d.FallbackDelay != 0 { + descriptions = append(descriptions, "fallback_delay="+time.Duration(d.FallbackDelay).String()) + } + return descriptions +} + +func (d *DirectActionOptions) UnmarshalJSON(data []byte) error { + err := json.Unmarshal(data, (*_DirectActionOptions)(d)) + if err != nil { + return err + } + if d.Detour != "" { + return E.New("detour is not available in the current context") + } + return nil +} + type _RejectActionOptions struct { Method string `json:"method,omitempty"` NoDrop bool `json:"no_drop,omitempty"` @@ -155,7 +287,7 @@ func (r *RejectActionOptions) UnmarshalJSON(bytes []byte) error { return E.New("unknown reject method: " + r.Method) } if r.Method == C.RuleActionRejectMethodDrop && r.NoDrop { - return E.New("no_drop is not allowed when method is drop") + return E.New("no_drop is not available in current context") } return nil } diff --git a/option/rule_dns.go b/option/rule_dns.go index 8c4b6ab8..0683e16a 100644 --- a/option/rule_dns.go +++ b/option/rule_dns.go @@ -111,7 +111,7 @@ type DefaultDNSRule struct { DNSRuleAction } -func (r *DefaultDNSRule) MarshalJSON() ([]byte, error) { +func (r DefaultDNSRule) MarshalJSON() ([]byte, error) { return badjson.MarshallObjects(r.RawDefaultDNSRule, r.DNSRuleAction) } @@ -123,34 +123,34 @@ func (r *DefaultDNSRule) UnmarshalJSON(data []byte) error { return badjson.UnmarshallExcluded(data, &r.RawDefaultDNSRule, &r.DNSRuleAction) } -func (r *DefaultDNSRule) IsValid() bool { +func (r DefaultDNSRule) IsValid() bool { var defaultValue DefaultDNSRule defaultValue.Invert = r.Invert defaultValue.DNSRuleAction = r.DNSRuleAction return !reflect.DeepEqual(r, defaultValue) } -type _LogicalDNSRule struct { +type RawLogicalDNSRule struct { Mode string `json:"mode"` Rules []DNSRule `json:"rules,omitempty"` Invert bool `json:"invert,omitempty"` } type LogicalDNSRule struct { - _LogicalDNSRule + RawLogicalDNSRule DNSRuleAction } -func (r *LogicalDNSRule) MarshalJSON() ([]byte, error) { - return badjson.MarshallObjects(r._LogicalDNSRule, r.DNSRuleAction) +func (r LogicalDNSRule) MarshalJSON() ([]byte, error) { + return badjson.MarshallObjects(r.RawLogicalDNSRule, r.DNSRuleAction) } func (r *LogicalDNSRule) UnmarshalJSON(data []byte) error { - err := json.Unmarshal(data, &r._LogicalDNSRule) + err := json.Unmarshal(data, &r.RawLogicalDNSRule) if err != nil { return err } - return badjson.UnmarshallExcluded(data, &r._LogicalDNSRule, &r.DNSRuleAction) + return badjson.UnmarshallExcluded(data, &r.RawLogicalDNSRule, &r.DNSRuleAction) } func (r *LogicalDNSRule) IsValid() bool { diff --git a/protocol/dns/handle.go b/protocol/dns/handle.go index 23ed1c0c..bc58d9e2 100644 --- a/protocol/dns/handle.go +++ b/protocol/dns/handle.go @@ -43,6 +43,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net go func() error { response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message) if err != nil { + conn.Close() return err } responseBuffer := buf.NewPacket() diff --git a/route/route.go b/route/route.go index 2c199a2d..854fa4f1 100644 --- a/route/route.go +++ b/route/route.go @@ -87,23 +87,34 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad if deadline.NeedAdditionalReadDeadline(conn) { conn = deadline.NewConn(conn) } - selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil, -1) + selectedRule, _, buffers, _, err := r.matchRule(ctx, &metadata, false, conn, nil) if err != nil { return err } - var selectedOutbound adapter.Outbound - var selectReturn bool + var ( + // selectedOutbound adapter.Outbound + selectedDialer N.Dialer + selectedTag string + selectedDescription string + ) if selectedRule != nil { switch action := selectedRule.Action().(type) { case *rule.RuleActionRoute: - var loaded bool - selectedOutbound, loaded = r.Outbound(action.Outbound) + selectedOutbound, loaded := r.Outbound(action.Outbound) if !loaded { buf.ReleaseMulti(buffers) return E.New("outbound not found: ", action.Outbound) } - case *rule.RuleActionReturn: - selectReturn = true + if !common.Contains(selectedOutbound.Network(), N.NetworkTCP) { + buf.ReleaseMulti(buffers) + return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag()) + } + selectedDialer = selectedOutbound + selectedTag = selectedOutbound.Tag() + selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") + case *rule.RuleActionDirect: + selectedDialer = action.Dialer + selectedDescription = action.String() case *rule.RuleActionReject: buf.ReleaseMulti(buffers) N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx)) @@ -116,17 +127,16 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad return nil } } - if selectedRule == nil || selectReturn { + if selectedRule == nil { if r.defaultOutboundForConnection == nil { buf.ReleaseMulti(buffers) return E.New("missing default outbound with TCP support") } - selectedOutbound = r.defaultOutboundForConnection - } - if !common.Contains(selectedOutbound.Network(), N.NetworkTCP) { - buf.ReleaseMulti(buffers) - return E.New("TCP is not supported by outbound: ", selectedOutbound.Tag()) + selectedDialer = r.defaultOutboundForConnection + selectedTag = r.defaultOutboundForConnection.Tag() + selectedDescription = F.ToString("outbound/", r.defaultOutboundForConnection.Type(), "[", r.defaultOutboundForConnection.Tag(), "]") } + for _, buffer := range buffers { conn = bufio.NewCachedConn(conn, buffer) } @@ -137,10 +147,10 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad } if r.v2rayServer != nil { if statsService := r.v2rayServer.StatsService(); statsService != nil { - conn = statsService.RoutedConnection(metadata.Inbound, selectedOutbound.Tag(), metadata.User, conn) + conn = statsService.RoutedConnection(metadata.Inbound, selectedTag, metadata.User, conn) } } - legacyOutbound, isLegacy := selectedOutbound.(adapter.ConnectionHandler) + legacyOutbound, isLegacy := selectedDialer.(adapter.ConnectionHandler) if isLegacy { err = legacyOutbound.NewConnection(ctx, conn, metadata) if err != nil { @@ -148,7 +158,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad if onClose != nil { onClose(err) } - return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") + return E.Cause(err, selectedDescription) } else { if onClose != nil { onClose(nil) @@ -157,13 +167,13 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad return nil } // TODO - err = outbound.NewConnection(ctx, selectedOutbound, conn, metadata) + err = outbound.NewConnection(ctx, selectedDialer, conn, metadata) if err != nil { conn.Close() if onClose != nil { onClose(err) } - return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") + return E.Cause(err, selectedDescription) } else { if onClose != nil { onClose(nil) @@ -231,24 +241,34 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m conn = deadline.NewPacketConn(bufio.NewNetPacketConn(conn)) }*/ - selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn, -1) + selectedRule, _, _, packetBuffers, err := r.matchRule(ctx, &metadata, false, nil, conn) if err != nil { return err } - var selectedOutbound adapter.Outbound + var ( + selectedDialer N.Dialer + selectedTag string + selectedDescription string + ) var selectReturn bool if selectedRule != nil { switch action := selectedRule.Action().(type) { case *rule.RuleActionRoute: - var loaded bool - selectedOutbound, loaded = r.Outbound(action.Outbound) + selectedOutbound, loaded := r.Outbound(action.Outbound) if !loaded { N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("outbound not found: ", action.Outbound) } - metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping - case *rule.RuleActionReturn: - selectReturn = true + if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) { + N.ReleaseMultiPacketBuffer(packetBuffers) + return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag()) + } + selectedDialer = selectedOutbound + selectedTag = selectedOutbound.Tag() + selectedDescription = F.ToString("outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") + case *rule.RuleActionDirect: + selectedDialer = action.Dialer + selectedDescription = action.String() case *rule.RuleActionReject: N.ReleaseMultiPacketBuffer(packetBuffers) N.CloseOnHandshakeFailure(conn, onClose, action.Error(ctx)) @@ -263,11 +283,9 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("missing default outbound with UDP support") } - selectedOutbound = r.defaultOutboundForPacketConnection - } - if !common.Contains(selectedOutbound.Network(), N.NetworkUDP) { - N.ReleaseMultiPacketBuffer(packetBuffers) - return E.New("UDP is not supported by outbound: ", selectedOutbound.Tag()) + selectedDialer = r.defaultOutboundForPacketConnection + selectedTag = r.defaultOutboundForPacketConnection.Tag() + selectedDescription = F.ToString("outbound/", r.defaultOutboundForPacketConnection.Type(), "[", r.defaultOutboundForPacketConnection.Tag(), "]") } for _, buffer := range packetBuffers { conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination) @@ -280,32 +298,32 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m } if r.v2rayServer != nil { if statsService := r.v2rayServer.StatsService(); statsService != nil { - conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedOutbound.Tag(), metadata.User, conn) + conn = statsService.RoutedPacketConnection(metadata.Inbound, selectedTag, metadata.User, conn) } } if metadata.FakeIP { conn = bufio.NewNATPacketConn(bufio.NewNetPacketConn(conn), metadata.OriginDestination, metadata.Destination) } - legacyOutbound, isLegacy := selectedOutbound.(adapter.PacketConnectionHandler) + legacyOutbound, isLegacy := selectedDialer.(adapter.PacketConnectionHandler) if isLegacy { err = legacyOutbound.NewPacketConnection(ctx, conn, metadata) N.CloseOnHandshakeFailure(conn, onClose, err) if err != nil { - return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") + return E.Cause(err, selectedDescription) } return nil } // TODO - err = outbound.NewPacketConnection(ctx, selectedOutbound, conn, metadata) + err = outbound.NewPacketConnection(ctx, selectedDialer, conn, metadata) N.CloseOnHandshakeFailure(conn, onClose, err) if err != nil { - return E.Cause(err, "outbound/", selectedOutbound.Type(), "[", selectedOutbound.Tag(), "]") + return E.Cause(err, selectedDescription) } return nil } func (r *Router) PreMatch(metadata adapter.InboundContext) error { - selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil, -1) + selectedRule, _, _, _, err := r.matchRule(r.ctx, &metadata, true, nil, nil) if err != nil { return err } @@ -321,7 +339,7 @@ func (r *Router) PreMatch(metadata adapter.InboundContext) error { func (r *Router) matchRule( ctx context.Context, metadata *adapter.InboundContext, preMatch bool, - inputConn net.Conn, inputPacketConn N.PacketConn, ruleIndex int, + inputConn net.Conn, inputPacketConn N.PacketConn, ) ( selectedRule adapter.Rule, selectedRuleIndex int, buffers []*buf.Buffer, packetBuffers []*N.PacketBuffer, fatalErr error, @@ -416,24 +434,10 @@ func (r *Router) matchRule( } match: - for ruleIndex < len(r.rules) { - rules := r.rules - if ruleIndex != -1 { - rules = rules[ruleIndex+1:] - } - var ( - currentRule adapter.Rule - currentRuleIndex int - matched bool - ) - for currentRuleIndex, currentRule = range rules { - if currentRule.Match(metadata) { - matched = true - break - } - } - if !matched { - break + for currentRuleIndex, currentRule := range r.rules { + metadata.ResetRuleCache() + if !currentRule.Match(metadata) { + continue } if !preMatch { ruleDescription := currentRule.String() @@ -444,7 +448,7 @@ match: } } else { switch currentRule.Action().Type() { - case C.RuleActionTypeReject, C.RuleActionTypeResolve: + case C.RuleActionTypeReject: ruleDescription := currentRule.String() if ruleDescription != "" { r.logger.DebugContext(ctx, "pre-match[", currentRuleIndex, "] ", currentRule, " => ", currentRule.Action()) @@ -454,6 +458,12 @@ match: } } switch action := currentRule.Action().(type) { + case *rule.RuleActionRoute: + metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping + metadata.UDPConnect = action.UDPConnect + case *rule.RuleActionRouteOptions: + metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping + metadata.UDPConnect = action.UDPConnect case *rule.RuleActionSniff: if !preMatch { newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, action, inputConn, inputPacketConn) @@ -476,12 +486,16 @@ match: if fatalErr != nil { return } - default: + } + actionType := currentRule.Action().Type() + if actionType == C.RuleActionTypeRoute || + actionType == C.RuleActionTypeReject || + actionType == C.RuleActionTypeHijackDNS || + (actionType == C.RuleActionTypeSniff && preMatch) { selectedRule = currentRule selectedRuleIndex = currentRuleIndex break match } - ruleIndex = currentRuleIndex } if !preMatch && metadata.Destination.Addr.IsUnspecified() { newBuffer, newPacketBuffers, newErr := r.actionSniff(ctx, metadata, &rule.RuleActionSniff{}, inputConn, inputPacketConn) diff --git a/route/route_dns.go b/route/route_dns.go index 60aff6a9..c11c07fe 100644 --- a/route/route_dns.go +++ b/route/route_dns.go @@ -8,8 +8,10 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" R "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing-dns" + tun "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/cache" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" @@ -48,38 +50,63 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, if ruleIndex != -1 { dnsRules = dnsRules[ruleIndex+1:] } - for currentRuleIndex, rule := range dnsRules { - if rule.WithAddressLimit() && !isAddressQuery { + for currentRuleIndex, currentRule := range dnsRules { + if currentRule.WithAddressLimit() && !isAddressQuery { continue } metadata.ResetRuleCache() - if rule.Match(metadata) { + if currentRule.Match(metadata) { displayRuleIndex := currentRuleIndex if displayRuleIndex != -1 { displayRuleIndex += displayRuleIndex + 1 } - if routeAction, isRoute := rule.Action().(*R.RuleActionDNSRoute); isRoute { - transport, loaded := r.transportMap[routeAction.Server] + ruleDescription := currentRule.String() + if ruleDescription != "" { + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] ", currentRule, " => ", currentRule.Action()) + } else { + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action()) + } + switch action := currentRule.Action().(type) { + case *R.RuleActionDNSRoute: + transport, loaded := r.transportMap[action.Server] if !loaded { - r.dnsLogger.ErrorContext(ctx, "transport not found: ", routeAction.Server) + r.dnsLogger.ErrorContext(ctx, "transport not found: ", action.Server) continue } _, isFakeIP := transport.(adapter.FakeIPTransport) if isFakeIP && !allowFakeIP { continue } - options.DisableCache = isFakeIP || routeAction.DisableCache - options.RewriteTTL = routeAction.RewriteTTL - options.ClientSubnet = routeAction.ClientSubnet + if isFakeIP || action.DisableCache { + options.DisableCache = true + } + if action.RewriteTTL != nil { + options.RewriteTTL = action.RewriteTTL + } + if action.ClientSubnet.IsValid() { + options.ClientSubnet = action.ClientSubnet + } if domainStrategy, dsLoaded := r.transportDomainStrategy[transport]; dsLoaded { options.Strategy = domainStrategy } else { options.Strategy = r.defaultDomainStrategy } - r.dnsLogger.DebugContext(ctx, "match[", displayRuleIndex, "] ", rule.String(), " => ", rule.Action()) - return transport, options, rule, currentRuleIndex - } else { - return nil, options, rule, currentRuleIndex + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action()) + return transport, options, currentRule, currentRuleIndex + case *R.RuleActionDNSRouteOptions: + if action.DisableCache { + options.DisableCache = true + } + if action.RewriteTTL != nil { + options.RewriteTTL = action.RewriteTTL + } + if action.ClientSubnet.IsValid() { + options.ClientSubnet = action.ClientSubnet + } + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action()) + case *R.RuleActionReject: + r.logger.DebugContext(ctx, "match[", displayRuleIndex, "] => ", currentRule.Action()) + return nil, options, currentRule, currentRuleIndex } } } @@ -93,9 +120,19 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, } func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - if len(message.Question) > 0 { - r.dnsLogger.DebugContext(ctx, "exchange ", formatQuestion(message.Question[0].String())) + if len(message.Question) != 1 { + r.dnsLogger.WarnContext(ctx, "bad question size: ", len(message.Question)) + responseMessage := mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Id: message.Id, + Response: true, + Rcode: mDNS.RcodeFormatError, + }, + Question: message.Question, + } + return &responseMessage, nil } + r.dnsLogger.DebugContext(ctx, "exchange ", formatQuestion(message.Question[0].String())) var ( response *mDNS.Msg cached bool @@ -107,16 +144,14 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er var metadata *adapter.InboundContext ctx, metadata = adapter.ExtendContext(ctx) metadata.Destination = M.Socksaddr{} - if len(message.Question) > 0 { - metadata.QueryType = message.Question[0].Qtype - switch metadata.QueryType { - case mDNS.TypeA: - metadata.IPVersion = 4 - case mDNS.TypeAAAA: - metadata.IPVersion = 6 - } - metadata.Domain = fqdnToDomain(message.Question[0].Name) + metadata.QueryType = message.Question[0].Qtype + switch metadata.QueryType { + case mDNS.TypeA: + metadata.IPVersion = 4 + case mDNS.TypeAAAA: + metadata.IPVersion = 6 } + metadata.Domain = fqdnToDomain(message.Question[0].Name) var ( options dns.QueryOptions rule adapter.DNSRule @@ -127,6 +162,17 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er dnsCtx := adapter.OverrideContext(ctx) var addressLimit bool transport, options, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message)) + if rule != nil { + switch action := rule.Action().(type) { + case *R.RuleActionReject: + switch action.Method { + case C.RuleActionRejectMethodDefault: + return dns.FixedResponse(message.Id, message.Question[0], nil, 0), nil + case C.RuleActionRejectMethodDrop: + return nil, tun.ErrDrop + } + } + } if rule != nil && rule.WithAddressLimit() { addressLimit = true response, err = r.dnsClient.ExchangeWithResponseCheck(dnsCtx, transport, message, options, func(response *mDNS.Msg) bool { @@ -164,7 +210,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, er if err != nil { return nil, err } - if r.dnsReverseMapping != nil && len(message.Question) > 0 && response != nil && len(response.Answer) > 0 { + if r.dnsReverseMapping != nil && response != nil && len(response.Answer) > 0 { if _, isFakeIP := transport.(adapter.FakeIPTransport); !isFakeIP { for _, answer := range response.Answer { switch record := answer.(type) { @@ -238,6 +284,17 @@ func (r *Router) Lookup(ctx context.Context, domain string, strategy dns.DomainS if strategy != dns.DomainStrategyAsIS { options.Strategy = strategy } + if rule != nil { + switch action := rule.Action().(type) { + case *R.RuleActionReject: + switch action.Method { + case C.RuleActionRejectMethodDefault: + return nil, nil + case C.RuleActionRejectMethodDrop: + return nil, tun.ErrDrop + } + } + } if rule != nil && rule.WithAddressLimit() { addressLimit = true responseAddrs, err = r.dnsClient.LookupWithResponseCheck(dnsCtx, transport, domain, options, func(responseAddrs []netip.Addr) bool { diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index 031f181c..620260d0 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -5,9 +5,11 @@ import ( "net/netip" "strings" "sync" + "syscall" "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/sniff" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" @@ -17,19 +19,42 @@ import ( E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" "github.com/sagernet/sing/common/logger" - - "golang.org/x/sys/unix" + N "github.com/sagernet/sing/common/network" ) -func NewRuleAction(logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) { +func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) { switch action.Action { + case "": + return nil, nil case C.RuleActionTypeRoute: return &RuleActionRoute{ - Outbound: action.RouteOptions.Outbound, - UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping, + Outbound: action.RouteOptions.Outbound, + }, nil + case C.RuleActionTypeRouteOptions: + return &RuleActionRouteOptions{ + UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping, + UDPConnect: action.RouteOptionsOptions.UDPConnect, + }, nil + case C.RuleActionTypeDirect: + directDialer, err := dialer.New(router, option.DialerOptions(action.DirectOptions)) + if err != nil { + return nil, err + } + var description string + descriptions := action.DirectOptions.Descriptions() + switch len(descriptions) { + case 0: + case 1: + description = F.ToString("(", descriptions[0], ")") + case 2: + description = F.ToString("(", descriptions[0], ",", descriptions[1], ")") + default: + description = F.ToString("(", descriptions[0], ",", descriptions[1], ",...)") + } + return &RuleActionDirect{ + Dialer: directDialer, + description: description, }, nil - case C.RuleActionTypeReturn: - return &RuleActionReturn{}, nil case C.RuleActionTypeReject: return &RuleActionReject{ Method: action.RejectOptions.Method, @@ -56,6 +81,8 @@ func NewRuleAction(logger logger.ContextLogger, action option.RuleAction) (adapt func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction) adapter.RuleAction { switch action.Action { + case "": + return nil case C.RuleActionTypeRoute: return &RuleActionDNSRoute{ Server: action.RouteOptions.Server, @@ -63,8 +90,12 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction) RewriteTTL: action.RouteOptions.RewriteTTL, ClientSubnet: action.RouteOptions.ClientSubnet.Build(), } - case C.RuleActionTypeReturn: - return &RuleActionReturn{} + case C.RuleActionTypeRouteOptions: + return &RuleActionDNSRouteOptions{ + DisableCache: action.RouteOptionsOptions.DisableCache, + RewriteTTL: action.RouteOptionsOptions.RewriteTTL, + ClientSubnet: action.RouteOptionsOptions.ClientSubnet.Build(), + } case C.RuleActionTypeReject: return &RuleActionReject{ Method: action.RejectOptions.Method, @@ -77,8 +108,7 @@ func NewDNSRuleAction(logger logger.ContextLogger, action option.DNSRuleAction) } type RuleActionRoute struct { - Outbound string - UDPDisableDomainUnmapping bool + Outbound string } func (r *RuleActionRoute) Type() string { @@ -89,6 +119,26 @@ func (r *RuleActionRoute) String() string { return F.ToString("route(", r.Outbound, ")") } +type RuleActionRouteOptions struct { + UDPDisableDomainUnmapping bool + UDPConnect bool +} + +func (r *RuleActionRouteOptions) Type() string { + return C.RuleActionTypeRouteOptions +} + +func (r *RuleActionRouteOptions) String() string { + var descriptions []string + if r.UDPDisableDomainUnmapping { + descriptions = append(descriptions, "udp-disable-domain-unmapping") + } + if r.UDPConnect { + descriptions = append(descriptions, "udp-connect") + } + return F.ToString("route-options(", strings.Join(descriptions, ","), ")") +} + type RuleActionDNSRoute struct { Server string DisableCache bool @@ -104,14 +154,41 @@ func (r *RuleActionDNSRoute) String() string { return F.ToString("route(", r.Server, ")") } -type RuleActionReturn struct{} - -func (r *RuleActionReturn) Type() string { - return C.RuleActionTypeReturn +type RuleActionDNSRouteOptions struct { + DisableCache bool + RewriteTTL *uint32 + ClientSubnet netip.Prefix } -func (r *RuleActionReturn) String() string { - return "return" +func (r *RuleActionDNSRouteOptions) Type() string { + return C.RuleActionTypeRouteOptions +} + +func (r *RuleActionDNSRouteOptions) String() string { + var descriptions []string + if r.DisableCache { + descriptions = append(descriptions, "disable-cache") + } + if r.RewriteTTL != nil { + descriptions = append(descriptions, F.ToString("rewrite-ttl(", *r.RewriteTTL, ")")) + } + if r.ClientSubnet.IsValid() { + descriptions = append(descriptions, F.ToString("client-subnet(", r.ClientSubnet, ")")) + } + return F.ToString("route-options(", strings.Join(descriptions, ","), ")") +} + +type RuleActionDirect struct { + Dialer N.Dialer + description string +} + +func (r *RuleActionDirect) Type() string { + return C.RuleActionTypeDirect +} + +func (r *RuleActionDirect) String() string { + return "direct" + r.description } type RuleActionReject struct { @@ -137,7 +214,7 @@ func (r *RuleActionReject) Error(ctx context.Context) error { var returnErr error switch r.Method { case C.RuleActionRejectMethodDefault: - returnErr = unix.ECONNREFUSED + returnErr = syscall.ECONNREFUSED case C.RuleActionRejectMethodDrop: return tun.ErrDrop default: diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index a337c19f..566c816e 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -52,7 +52,7 @@ type RuleItem interface { } func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { - action, err := NewRuleAction(logger, options.RuleAction) + action, err := NewRuleAction(router, logger, options.RuleAction) if err != nil { return nil, E.Cause(err, "action") } @@ -254,7 +254,7 @@ type LogicalRule struct { } func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { - action, err := NewRuleAction(logger, options.RuleAction) + action, err := NewRuleAction(router, logger, options.RuleAction) if err != nil { return nil, E.Cause(err, "action") }