From dda692e95550dc044eb3cdd087ef59414ea60dda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 15 Dec 2024 00:45:41 +0800 Subject: [PATCH] Fix domain strategy --- adapter/inbound.go | 2 +- adapter/network.go | 2 +- common/dialer/default.go | 98 +++++++++++++-------- common/dialer/default_parallel_interface.go | 55 +++++++----- common/dialer/default_parallel_network.go | 25 +++++- common/dialer/dialer.go | 21 ++--- common/dialer/resolve.go | 4 +- option/outbound.go | 37 ++++---- option/route.go | 2 +- option/rule_action.go | 4 +- protocol/direct/outbound.go | 36 +++----- protocol/wireguard/endpoint.go | 3 +- protocol/wireguard/outbound.go | 3 +- route/conn.go | 11 ++- route/network.go | 4 +- route/route.go | 14 ++- route/router.go | 2 +- route/rule/rule_action.go | 6 +- transport/dhcp/server.go | 2 +- 19 files changed, 195 insertions(+), 136 deletions(-) diff --git a/adapter/inbound.go b/adapter/inbound.go index f5d5c95b..93d2ec60 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -72,7 +72,7 @@ type InboundContext struct { UDPConnect bool UDPTimeout time.Duration - NetworkStrategy C.NetworkStrategy + NetworkStrategy *C.NetworkStrategy NetworkType []C.InterfaceType FallbackNetworkType []C.InterfaceType FallbackDelay time.Duration diff --git a/adapter/network.go b/adapter/network.go index 08fc00fa..00ef54b8 100644 --- a/adapter/network.go +++ b/adapter/network.go @@ -28,7 +28,7 @@ type NetworkManager interface { } type NetworkOptions struct { - NetworkStrategy C.NetworkStrategy + NetworkStrategy *C.NetworkStrategy NetworkType []C.InterfaceType FallbackNetworkType []C.InterfaceType FallbackDelay time.Duration diff --git a/common/dialer/default.go b/common/dialer/default.go index bf553618..49bd145c 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -9,6 +9,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/conntrack" C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/libbox/platform" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/atomic" @@ -16,6 +17,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" ) var ( @@ -33,19 +35,22 @@ type DefaultDialer struct { udpAddr6 string isWireGuardListener bool networkManager adapter.NetworkManager - networkStrategy C.NetworkStrategy + networkStrategy *C.NetworkStrategy networkType []C.InterfaceType fallbackNetworkType []C.InterfaceType networkFallbackDelay time.Duration networkLastFallback atomic.TypedValue[time.Time] } -func NewDefault(networkManager adapter.NetworkManager, options option.DialerOptions) (*DefaultDialer, error) { +func NewDefault(ctx context.Context, options option.DialerOptions) (*DefaultDialer, error) { + networkManager := service.FromContext[adapter.NetworkManager](ctx) + platformInterface := service.FromContext[platform.Interface](ctx) + var ( dialer net.Dialer listener net.ListenConfig interfaceFinder control.InterfaceFinder - networkStrategy C.NetworkStrategy + networkStrategy *C.NetworkStrategy networkType []C.InterfaceType fallbackNetworkType []C.InterfaceType networkFallbackDelay time.Duration @@ -74,31 +79,37 @@ func NewDefault(networkManager adapter.NetworkManager, options option.DialerOpti listener.Control = control.Append(listener.Control, control.RoutingMark(autoRedirectOutputMark)) } } - if C.NetworkStrategy(options.NetworkStrategy) != C.NetworkStrategyDefault { - if options.BindInterface != "" || options.Inet4BindAddress != nil || options.Inet6BindAddress != nil { - return nil, E.New("`network_strategy` is conflict with `bind_interface`, `inet4_bind_address` and `inet6_bind_address`") - } - networkStrategy = C.NetworkStrategy(options.NetworkStrategy) - networkType = common.Map(options.NetworkType, option.InterfaceType.Build) - fallbackNetworkType = common.Map(options.FallbackNetworkType, option.InterfaceType.Build) - networkFallbackDelay = time.Duration(options.NetworkFallbackDelay) - if networkManager == nil || !networkManager.AutoDetectInterface() { - return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`") + disableDefaultBind := options.BindInterface != "" || options.Inet4BindAddress != nil || options.Inet6BindAddress != nil + if disableDefaultBind || options.TCPFastOpen { + if options.NetworkStrategy != nil || len(options.NetworkType) > 0 && options.FallbackNetworkType == nil && options.FallbackDelay == 0 { + return nil, E.New("`network_strategy` is conflict with `bind_interface`, `inet4_bind_address`, `inet6_bind_address` and `tcp_fast_open`") } } - if networkManager != nil && options.BindInterface == "" && options.Inet4BindAddress == nil && options.Inet6BindAddress == nil { + + if networkManager != nil { defaultOptions := networkManager.DefaultOptions() - if options.BindInterface == "" { + if !disableDefaultBind { if defaultOptions.BindInterface != "" { bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), defaultOptions.BindInterface, -1) dialer.Control = control.Append(dialer.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc) } else if networkManager.AutoDetectInterface() { - if defaultOptions.NetworkStrategy != C.NetworkStrategyDefault && C.NetworkStrategy(options.NetworkStrategy) == C.NetworkStrategyDefault { - networkStrategy = defaultOptions.NetworkStrategy - networkType = defaultOptions.NetworkType - fallbackNetworkType = defaultOptions.FallbackNetworkType - networkFallbackDelay = defaultOptions.FallbackDelay + if platformInterface != nil { + networkStrategy = (*C.NetworkStrategy)(options.NetworkStrategy) + if networkStrategy == nil { + networkStrategy = common.Ptr(C.NetworkStrategyDefault) + } + networkType = common.Map(options.NetworkType, option.InterfaceType.Build) + fallbackNetworkType = common.Map(options.FallbackNetworkType, option.InterfaceType.Build) + if networkStrategy == nil && len(networkType) == 0 && len(fallbackNetworkType) == 0 { + networkStrategy = defaultOptions.NetworkStrategy + networkType = defaultOptions.NetworkType + fallbackNetworkType = defaultOptions.FallbackNetworkType + } + networkFallbackDelay = time.Duration(options.FallbackDelay) + if networkFallbackDelay == 0 && defaultOptions.FallbackDelay != 0 { + networkFallbackDelay = defaultOptions.FallbackDelay + } bindFunc := networkManager.ProtectFunc() dialer.Control = control.Append(dialer.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc) @@ -172,9 +183,6 @@ func NewDefault(networkManager adapter.NetworkManager, options option.DialerOpti listener.Control = control.Append(listener.Control, controlFn) } } - if networkStrategy != C.NetworkStrategyDefault && options.TCPFastOpen { - return nil, E.New("`tcp_fast_open` is conflict with `network_strategy` or `route.default_network_strategy`") - } tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen) if err != nil { return nil, err @@ -204,7 +212,7 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address if !address.IsValid() { return nil, E.New("invalid address") } - if d.networkStrategy == C.NetworkStrategyDefault { + if d.networkStrategy == nil { switch N.NetworkName(network) { case N.NetworkUDP: if !address.IsIPv6() { @@ -223,12 +231,21 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address } } -func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network string, address M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { - if strategy == C.NetworkStrategyDefault { +func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network string, address M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { + if strategy == nil { + strategy = d.networkStrategy + } + if strategy == nil { return d.DialContext(ctx, network, address) } - if !d.networkManager.AutoDetectInterface() { - return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`") + if len(interfaceType) == 0 { + interfaceType = d.networkType + } + if len(fallbackInterfaceType) == 0 { + fallbackInterfaceType = d.fallbackNetworkType + } + if fallbackDelay == 0 { + fallbackDelay = d.networkFallbackDelay } var dialer net.Dialer if N.NetworkName(network) == N.NetworkTCP { @@ -243,9 +260,9 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin err error ) if !fastFallback { - conn, isPrimary, err = d.dialParallelInterface(ctx, dialer, network, address.String(), strategy, interfaceType, fallbackInterfaceType, fallbackDelay) + conn, isPrimary, err = d.dialParallelInterface(ctx, dialer, network, address.String(), *strategy, interfaceType, fallbackInterfaceType, fallbackDelay) } else { - conn, isPrimary, err = d.dialParallelInterfaceFastFallback(ctx, dialer, network, address.String(), strategy, interfaceType, fallbackInterfaceType, fallbackDelay, d.networkLastFallback.Store) + conn, isPrimary, err = d.dialParallelInterfaceFastFallback(ctx, dialer, network, address.String(), *strategy, interfaceType, fallbackInterfaceType, fallbackDelay, d.networkLastFallback.Store) } if err != nil { return nil, err @@ -257,7 +274,7 @@ func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network strin } func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - if d.networkStrategy == C.NetworkStrategyDefault { + if d.networkStrategy == nil { if destination.IsIPv6() { return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6)) } else if destination.IsIPv4() && !destination.Addr.IsUnspecified() { @@ -270,18 +287,27 @@ func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksadd } } -func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) { - if strategy == C.NetworkStrategyDefault { +func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) { + if strategy == nil { + strategy = d.networkStrategy + } + if strategy == nil { return d.ListenPacket(ctx, destination) } - if !d.networkManager.AutoDetectInterface() { - return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`") + if len(interfaceType) == 0 { + interfaceType = d.networkType + } + if len(fallbackInterfaceType) == 0 { + fallbackInterfaceType = d.fallbackNetworkType + } + if fallbackDelay == 0 { + fallbackDelay = d.networkFallbackDelay } network := N.NetworkUDP if destination.IsIPv4() && !destination.Addr.IsUnspecified() { network += "4" } - return trackPacketConn(d.listenSerialInterfacePacket(ctx, d.udpListener, network, "", strategy, interfaceType, fallbackInterfaceType, fallbackDelay)) + return trackPacketConn(d.listenSerialInterfacePacket(ctx, d.udpListener, network, "", *strategy, interfaceType, fallbackInterfaceType, fallbackDelay)) } func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) { diff --git a/common/dialer/default_parallel_interface.go b/common/dialer/default_parallel_interface.go index 37a1f79c..269546a4 100644 --- a/common/dialer/default_parallel_interface.go +++ b/common/dialer/default_parallel_interface.go @@ -40,7 +40,7 @@ func (d *DefaultDialer) dialParallelInterface(ctx context.Context, dialer net.Di } } else { select { - case results <- dialResult{Conn: conn}: + case results <- dialResult{Conn: conn, primary: primary}: case <-returned: conn.Close() } @@ -112,7 +112,7 @@ func (d *DefaultDialer) dialParallelInterfaceFastFallback(ctx context.Context, d } } else { select { - case results <- dialResult{Conn: conn}: + case results <- dialResult{Conn: conn, primary: primary}: case <-returned: if primary && time.Since(startAt) <= fallbackDelay { resetFastFallback(time.Time{}) @@ -177,44 +177,57 @@ func selectInterfaces(networkManager adapter.NetworkManager, strategy C.NetworkS case C.NetworkStrategyDefault: if len(interfaceType) == 0 { defaultIf := networkManager.InterfaceMonitor().DefaultInterface() - for _, iif := range interfaces { - if iif.Index == defaultIf.Index { - primaryInterfaces = append(primaryInterfaces, iif) - } else { - fallbackInterfaces = append(fallbackInterfaces, iif) + if defaultIf != nil { + for _, iif := range interfaces { + if iif.Index == defaultIf.Index { + primaryInterfaces = append(primaryInterfaces, iif) + } } + } else { + primaryInterfaces = interfaces } } else { - primaryInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool { - return common.Contains(interfaceType, iif.Type) + primaryInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool { + return common.Contains(interfaceType, it.Type) }) } case C.NetworkStrategyHybrid: if len(interfaceType) == 0 { primaryInterfaces = interfaces } else { - primaryInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool { - return common.Contains(interfaceType, iif.Type) + primaryInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool { + return common.Contains(interfaceType, it.Type) }) } case C.NetworkStrategyFallback: if len(interfaceType) == 0 { defaultIf := networkManager.InterfaceMonitor().DefaultInterface() - for _, iif := range interfaces { - if iif.Index == defaultIf.Index { - primaryInterfaces = append(primaryInterfaces, iif) - } else { - fallbackInterfaces = append(fallbackInterfaces, iif) + if defaultIf != nil { + for _, iif := range interfaces { + if iif.Index == defaultIf.Index { + primaryInterfaces = append(primaryInterfaces, iif) + break + } } + } else { + primaryInterfaces = interfaces } } else { - primaryInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool { - return common.Contains(interfaceType, iif.Type) + primaryInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool { + return common.Contains(interfaceType, it.Type) + }) + } + if len(fallbackInterfaceType) == 0 { + fallbackInterfaces = common.Filter(interfaces, func(it adapter.NetworkInterface) bool { + return !common.Any(primaryInterfaces, func(iif adapter.NetworkInterface) bool { + return it.Index == iif.Index + }) + }) + } else { + fallbackInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool { + return common.Contains(fallbackInterfaceType, iif.Type) }) } - fallbackInterfaces = common.Filter(interfaces, func(iif adapter.NetworkInterface) bool { - return common.Contains(fallbackInterfaceType, iif.Type) - }) } return primaryInterfaces, fallbackInterfaces } diff --git a/common/dialer/default_parallel_network.go b/common/dialer/default_parallel_network.go index 5145656b..006e5747 100644 --- a/common/dialer/default_parallel_network.go +++ b/common/dialer/default_parallel_network.go @@ -13,7 +13,13 @@ import ( N "github.com/sagernet/sing/common/network" ) -func DialSerialNetwork(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { +func DialSerialNetwork(ctx context.Context, dialer N.Dialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { + if len(destinationAddresses) == 0 { + if !destination.IsIP() { + panic("invalid usage") + } + destinationAddresses = []netip.Addr{destination.Addr} + } if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel { return parallelDialer.DialParallelNetwork(ctx, network, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) } @@ -38,7 +44,14 @@ func DialSerialNetwork(ctx context.Context, dialer N.Dialer, network string, des return nil, E.Errors(errors...) } -func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { +func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { + if len(destinationAddresses) == 0 { + if !destination.IsIP() { + panic("invalid usage") + } + destinationAddresses = []netip.Addr{destination.Addr} + } + if fallbackDelay == 0 { fallbackDelay = N.DefaultFallbackDelay } @@ -116,7 +129,13 @@ func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, ne } } -func ListenSerialNetworkPacket(ctx context.Context, dialer N.Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) { +func ListenSerialNetworkPacket(ctx context.Context, dialer N.Dialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) { + if len(destinationAddresses) == 0 { + if !destination.IsIP() { + panic("invalid usage") + } + destinationAddresses = []netip.Addr{destination.Addr} + } if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel { return parallelDialer.ListenSerialNetworkPacket(ctx, destination, destinationAddresses, strategy, interfaceType, fallbackInterfaceType, fallbackDelay) } diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index b307a330..89d1eeab 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -17,16 +17,15 @@ import ( ) func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) { - networkManager := service.FromContext[adapter.NetworkManager](ctx) if options.IsWireGuardListener { - return NewDefault(networkManager, options) + return NewDefault(ctx, options) } var ( dialer N.Dialer err error ) if options.Detour == "" { - dialer, err = NewDefault(networkManager, options) + dialer, err = NewDefault(ctx, options) if err != nil { return nil, err } @@ -37,9 +36,6 @@ func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) { } dialer = NewDetour(outboundManager, options.Detour) } - if networkManager == nil { - return NewDefault(networkManager, options) - } if options.Detour == "" { router := service.FromContext[adapter.Router](ctx) if router != nil { @@ -58,11 +54,10 @@ func NewDirect(ctx context.Context, options option.DialerOptions) (ParallelInter if options.Detour != "" { return nil, E.New("`detour` is not supported in direct context") } - networkManager := service.FromContext[adapter.NetworkManager](ctx) if options.IsWireGuardListener { - return NewDefault(networkManager, options) + return NewDefault(ctx, options) } - dialer, err := NewDefault(networkManager, options) + dialer, err := NewDefault(ctx, options) if err != nil { return nil, err } @@ -77,11 +72,11 @@ func NewDirect(ctx context.Context, options option.DialerOptions) (ParallelInter type ParallelInterfaceDialer interface { N.Dialer - DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) - ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) + DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) + ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) } type ParallelNetworkDialer interface { - DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) - ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) + DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) + ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) } diff --git a/common/dialer/resolve.go b/common/dialer/resolve.go index b5d922b3..ede1afd6 100644 --- a/common/dialer/resolve.go +++ b/common/dialer/resolve.go @@ -106,7 +106,7 @@ func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil } -func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { +func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { if !destination.IsFqdn() { return d.dialer.DialContext(ctx, network, destination) } @@ -134,7 +134,7 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context } } -func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) { +func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy *C.NetworkStrategy, interfaceType []C.InterfaceType, fallbackInterfaceType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, error) { if !destination.IsFqdn() { return d.dialer.ListenPacket(ctx, destination) } diff --git a/option/outbound.go b/option/outbound.go index 833a2d20..5cadd3e2 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -65,25 +65,24 @@ type DialerOptionsWrapper interface { } type DialerOptions struct { - Detour string `json:"detour,omitempty"` - BindInterface string `json:"bind_interface,omitempty"` - Inet4BindAddress *badoption.Addr `json:"inet4_bind_address,omitempty"` - Inet6BindAddress *badoption.Addr `json:"inet6_bind_address,omitempty"` - ProtectPath string `json:"protect_path,omitempty"` - RoutingMark FwMark `json:"routing_mark,omitempty"` - ReuseAddr bool `json:"reuse_addr,omitempty"` - ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` - TCPFastOpen bool `json:"tcp_fast_open,omitempty"` - TCPMultiPath bool `json:"tcp_multi_path,omitempty"` - UDPFragment *bool `json:"udp_fragment,omitempty"` - UDPFragmentDefault bool `json:"-"` - DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` - NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"` - NetworkType badoption.Listable[InterfaceType] `json:"network_type,omitempty"` - FallbackNetworkType badoption.Listable[InterfaceType] `json:"fallback_network_type,omitempty"` - FallbackDelay badoption.Duration `json:"fallback_delay,omitempty"` - NetworkFallbackDelay badoption.Duration `json:"network_fallback_delay,omitempty"` - IsWireGuardListener bool `json:"-"` + Detour string `json:"detour,omitempty"` + BindInterface string `json:"bind_interface,omitempty"` + Inet4BindAddress *badoption.Addr `json:"inet4_bind_address,omitempty"` + Inet6BindAddress *badoption.Addr `json:"inet6_bind_address,omitempty"` + ProtectPath string `json:"protect_path,omitempty"` + RoutingMark FwMark `json:"routing_mark,omitempty"` + ReuseAddr bool `json:"reuse_addr,omitempty"` + ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` + TCPFastOpen bool `json:"tcp_fast_open,omitempty"` + TCPMultiPath bool `json:"tcp_multi_path,omitempty"` + UDPFragment *bool `json:"udp_fragment,omitempty"` + UDPFragmentDefault bool `json:"-"` + DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` + NetworkStrategy *NetworkStrategy `json:"network_strategy,omitempty"` + NetworkType badoption.Listable[InterfaceType] `json:"network_type,omitempty"` + FallbackNetworkType badoption.Listable[InterfaceType] `json:"fallback_network_type,omitempty"` + FallbackDelay badoption.Duration `json:"fallback_delay,omitempty"` + IsWireGuardListener bool `json:"-"` } func (o *DialerOptions) TakeDialerOptions() DialerOptions { diff --git a/option/route.go b/option/route.go index 0eb1cbf1..1eb2294b 100644 --- a/option/route.go +++ b/option/route.go @@ -13,7 +13,7 @@ type RouteOptions struct { OverrideAndroidVPN bool `json:"override_android_vpn,omitempty"` DefaultInterface string `json:"default_interface,omitempty"` DefaultMark FwMark `json:"default_mark,omitempty"` - DefaultNetworkStrategy NetworkStrategy `json:"default_network_strategy,omitempty"` + DefaultNetworkStrategy *NetworkStrategy `json:"default_network_strategy,omitempty"` DefaultNetworkType badoption.Listable[InterfaceType] `json:"default_network_type,omitempty"` DefaultFallbackNetworkType badoption.Listable[InterfaceType] `json:"default_fallback_network_type,omitempty"` DefaultFallbackDelay badoption.Duration `json:"default_fallback_delay,omitempty"` diff --git a/option/rule_action.go b/option/rule_action.go index 29c5a0c3..b7003628 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -145,8 +145,8 @@ type RawRouteOptionsActionOptions struct { OverrideAddress string `json:"override_address,omitempty"` OverridePort uint16 `json:"override_port,omitempty"` - NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"` - FallbackDelay uint32 `json:"fallback_delay,omitempty"` + NetworkStrategy *NetworkStrategy `json:"network_strategy,omitempty"` + FallbackDelay uint32 `json:"fallback_delay,omitempty"` UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` UDPConnect bool `json:"udp_connect,omitempty"` diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go index 42fd284d..aba56336 100644 --- a/protocol/direct/outbound.go +++ b/protocol/direct/outbound.go @@ -32,16 +32,12 @@ var ( type Outbound struct { outbound.Adapter - logger logger.ContextLogger - dialer dialer.ParallelInterfaceDialer - domainStrategy dns.DomainStrategy - fallbackDelay time.Duration - networkStrategy C.NetworkStrategy - networkType []C.InterfaceType - fallbackNetworkType []C.InterfaceType - networkFallbackDelay time.Duration - overrideOption int - overrideDestination M.Socksaddr + logger logger.ContextLogger + dialer dialer.ParallelInterfaceDialer + domainStrategy dns.DomainStrategy + fallbackDelay time.Duration + overrideOption int + overrideDestination M.Socksaddr // loopBack *loopBackDetector } @@ -52,15 +48,11 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), - logger: logger, - domainStrategy: dns.DomainStrategy(options.DomainStrategy), - fallbackDelay: time.Duration(options.FallbackDelay), - networkStrategy: C.NetworkStrategy(options.NetworkStrategy), - networkType: common.Map(options.NetworkType, option.InterfaceType.Build), - fallbackNetworkType: common.Map(options.FallbackNetworkType, option.InterfaceType.Build), - networkFallbackDelay: time.Duration(options.NetworkFallbackDelay), - dialer: outboundDialer, + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions), + logger: logger, + domainStrategy: dns.DomainStrategy(options.DomainStrategy), + fallbackDelay: time.Duration(options.FallbackDelay), + dialer: outboundDialer, // loopBack: newLoopBackDetector(router), } //nolint:staticcheck @@ -178,10 +170,10 @@ func (h *Outbound) DialParallel(ctx context.Context, network string, destination return nil, E.New("no IPv6 address available for ", destination) } } - return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, h.networkStrategy, h.networkType, h.fallbackNetworkType, h.fallbackDelay) + return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, nil, nil, nil, h.fallbackDelay) } -func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { +func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy *C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) { ctx, metadata := adapter.ExtendContext(ctx) metadata.Outbound = h.Tag() metadata.Destination = destination @@ -221,7 +213,7 @@ func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, dest return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, networkStrategy, networkType, fallbackNetworkType, fallbackDelay) } -func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) { +func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy *C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) { ctx, metadata := adapter.ExtendContext(ctx) metadata.Outbound = h.Tag() metadata.Destination = destination diff --git a/protocol/wireguard/endpoint.go b/protocol/wireguard/endpoint.go index 937f84dd..21d72bd9 100644 --- a/protocol/wireguard/endpoint.go +++ b/protocol/wireguard/endpoint.go @@ -20,7 +20,6 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/service" ) func RegisterEndpoint(registry *endpoint.Registry) { @@ -70,7 +69,7 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL UDPTimeout: udpTimeout, Dialer: outboundDialer, CreateDialer: func(interfaceName string) N.Dialer { - return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{ + return common.Must1(dialer.NewDefault(ctx, option.DialerOptions{ BindInterface: interfaceName, })) }, diff --git a/protocol/wireguard/outbound.go b/protocol/wireguard/outbound.go index a1fce796..3e299705 100644 --- a/protocol/wireguard/outbound.go +++ b/protocol/wireguard/outbound.go @@ -19,7 +19,6 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" - "github.com/sagernet/sing/service" ) func RegisterOutbound(registry *outbound.Registry) { @@ -86,7 +85,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL System: options.SystemInterface, Dialer: outboundDialer, CreateDialer: func(interfaceName string) N.Dialer { - return common.Must1(dialer.NewDefault(service.FromContext[adapter.NetworkManager](ctx), option.DialerOptions{ + return common.Must1(dialer.NewDefault(ctx, option.DialerOptions{ BindInterface: interfaceName, })) }, diff --git a/route/conn.go b/route/conn.go index 93ac33e3..e010c2cd 100644 --- a/route/conn.go +++ b/route/conn.go @@ -56,7 +56,7 @@ func (m *ConnectionManager) NewConnection(ctx context.Context, this N.Dialer, co remoteConn net.Conn err error ) - if len(metadata.DestinationAddresses) > 0 { + if len(metadata.DestinationAddresses) > 0 || metadata.Destination.IsIP() { remoteConn, err = dialer.DialSerialNetwork(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { remoteConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) @@ -97,12 +97,19 @@ func (m *ConnectionManager) NewPacketConnection(ctx context.Context, this N.Dial err error ) if metadata.UDPConnect { + parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer) if len(metadata.DestinationAddresses) > 0 { - if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer { + if isParallelDialer { remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) } else { remoteConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) } + } else if metadata.Destination.IsIP() { + if isParallelDialer { + remoteConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.NetworkType, metadata.FallbackNetworkType, metadata.FallbackDelay) + } else { + remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) + } } else { remoteConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) } diff --git a/route/network.go b/route/network.go index 2065c389..d82701fe 100644 --- a/route/network.go +++ b/route/network.go @@ -62,7 +62,7 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp defaultOptions: adapter.NetworkOptions{ BindInterface: routeOptions.DefaultInterface, RoutingMark: uint32(routeOptions.DefaultMark), - NetworkStrategy: C.NetworkStrategy(routeOptions.DefaultNetworkStrategy), + NetworkStrategy: (*C.NetworkStrategy)(routeOptions.DefaultNetworkStrategy), NetworkType: common.Map(routeOptions.DefaultNetworkType, option.InterfaceType.Build), FallbackNetworkType: common.Map(routeOptions.DefaultFallbackNetworkType, option.InterfaceType.Build), FallbackDelay: time.Duration(routeOptions.DefaultFallbackDelay), @@ -73,7 +73,7 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp inbound: service.FromContext[adapter.InboundManager](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx), } - if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault { + if routeOptions.DefaultNetworkStrategy != nil { if routeOptions.DefaultInterface != "" { return nil, E.New("`default_network_strategy` is conflict with `default_interface`") } diff --git a/route/route.go b/route/route.go index fb2de85d..ca747462 100644 --- a/route/route.go +++ b/route/route.go @@ -415,8 +415,18 @@ match: Fqdn: metadata.Destination.Fqdn, } } - metadata.NetworkStrategy = routeOptions.NetworkStrategy - metadata.FallbackDelay = routeOptions.FallbackDelay + if routeOptions.NetworkStrategy != nil { + metadata.NetworkStrategy = routeOptions.NetworkStrategy + } + if len(routeOptions.NetworkType) > 0 { + metadata.NetworkType = routeOptions.NetworkType + } + if len(routeOptions.FallbackNetworkType) > 0 { + metadata.FallbackNetworkType = routeOptions.FallbackNetworkType + } + if routeOptions.FallbackDelay != 0 { + metadata.FallbackDelay = routeOptions.FallbackDelay + } if routeOptions.UDPDisableDomainUnmapping { metadata.UDPDisableDomainUnmapping = true } diff --git a/route/router.go b/route/router.go index 792391e2..6526778b 100644 --- a/route/router.go +++ b/route/router.go @@ -262,7 +262,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.Route Context: ctx, Name: "local", Address: "local", - Dialer: common.Must1(dialer.NewDefault(router.network, option.DialerOptions{})), + Dialer: common.Must1(dialer.NewDefault(ctx, option.DialerOptions{})), }))) } defaultTransport = transports[0] diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index 34354cc0..f4f2299a 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -33,7 +33,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti RuleActionRouteOptions: RuleActionRouteOptions{ OverrideAddress: M.ParseSocksaddrHostPort(action.RouteOptions.OverrideAddress, 0), OverridePort: action.RouteOptions.OverridePort, - NetworkStrategy: C.NetworkStrategy(action.RouteOptions.NetworkStrategy), + NetworkStrategy: (*C.NetworkStrategy)(action.RouteOptions.NetworkStrategy), FallbackDelay: time.Duration(action.RouteOptions.FallbackDelay), UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping, UDPConnect: action.RouteOptions.UDPConnect, @@ -43,7 +43,7 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti return &RuleActionRouteOptions{ OverrideAddress: M.ParseSocksaddrHostPort(action.RouteOptionsOptions.OverrideAddress, 0), OverridePort: action.RouteOptionsOptions.OverridePort, - NetworkStrategy: C.NetworkStrategy(action.RouteOptionsOptions.NetworkStrategy), + NetworkStrategy: (*C.NetworkStrategy)(action.RouteOptionsOptions.NetworkStrategy), FallbackDelay: time.Duration(action.RouteOptionsOptions.FallbackDelay), UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping, UDPConnect: action.RouteOptionsOptions.UDPConnect, @@ -147,7 +147,7 @@ func (r *RuleActionRoute) String() string { type RuleActionRouteOptions struct { OverrideAddress M.Socksaddr OverridePort uint16 - NetworkStrategy C.NetworkStrategy + NetworkStrategy *C.NetworkStrategy NetworkType []C.InterfaceType FallbackNetworkType []C.InterfaceType FallbackDelay time.Duration diff --git a/transport/dhcp/server.go b/transport/dhcp/server.go index 29c6bbe0..8b9187f0 100644 --- a/transport/dhcp/server.go +++ b/transport/dhcp/server.go @@ -253,7 +253,7 @@ func (t *Transport) recreateServers(iface *control.Interface, serverAddrs []neti return it.String() }), ","), "]") } - serverDialer := common.Must1(dialer.NewDefault(t.networkManager, option.DialerOptions{ + serverDialer := common.Must1(dialer.NewDefault(t.options.Context, option.DialerOptions{ BindInterface: iface.Name, UDPFragmentDefault: true, }))