From 15a9876a10f55b9c6dddabda196bd1d7b97e3493 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 12 Nov 2024 19:37:10 +0800 Subject: [PATCH] Add multi network dialing --- adapter/inbound.go | 4 + adapter/network.go | 14 +- adapter/outbound/default.go | 116 ++-------- common/dialer/default.go | 214 ++++++++++++----- common/dialer/default_go1.20.go | 4 + common/dialer/default_nongo1.20.go | 4 + common/dialer/default_parallel_interface.go | 241 ++++++++++++++++++++ common/dialer/default_parallel_network.go | 122 ++++++++++ common/dialer/dialer.go | 36 +++ common/dialer/resolve.go | 89 +++++++- common/settings/proxy_darwin.go | 10 +- constant/network.go | 42 ++++ experimental/libbox/monitor.go | 4 +- go.mod | 4 +- go.sum | 8 +- option/outbound.go | 32 +-- option/route.go | 24 +- option/rule_action.go | 14 +- option/types.go | 21 ++ protocol/direct/outbound.go | 165 ++++++++++---- protocol/socks/outbound.go | 21 -- protocol/wireguard/outbound.go | 13 -- route/network.go | 121 ++++++---- route/route.go | 4 + route/rule/rule_action.go | 6 + transport/dhcp/server.go | 2 +- 26 files changed, 994 insertions(+), 341 deletions(-) create mode 100644 common/dialer/default_parallel_interface.go create mode 100644 common/dialer/default_parallel_network.go diff --git a/adapter/inbound.go b/adapter/inbound.go index 7932237d..33b1b4d1 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -3,8 +3,10 @@ package adapter import ( "context" "net/netip" + "time" "github.com/sagernet/sing-box/common/process" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" M "github.com/sagernet/sing/common/metadata" @@ -66,6 +68,8 @@ type InboundContext struct { InboundOptions option.InboundOptions UDPDisableDomainUnmapping bool UDPConnect bool + NetworkStrategy C.NetworkStrategy + FallbackDelay time.Duration DNSServer string diff --git a/adapter/network.go b/adapter/network.go index 6c09a0a3..dd924c33 100644 --- a/adapter/network.go +++ b/adapter/network.go @@ -1,6 +1,9 @@ package adapter import ( + "time" + + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/control" ) @@ -11,10 +14,10 @@ type NetworkManager interface { UpdateInterfaces() error DefaultNetworkInterface() *NetworkInterface NetworkInterfaces() []NetworkInterface - DefaultInterface() string AutoDetectInterface() bool AutoDetectInterfaceFunc() control.Func - DefaultMark() uint32 + ProtectFunc() control.Func + DefaultOptions() NetworkOptions RegisterAutoRedirectOutputMark(mark uint32) error AutoRedirectOutputMark() uint32 NetworkMonitor() tun.NetworkUpdateMonitor @@ -24,6 +27,13 @@ type NetworkManager interface { ResetNetwork() } +type NetworkOptions struct { + DefaultNetworkStrategy C.NetworkStrategy + DefaultFallbackDelay time.Duration + DefaultInterface string + DefaultMark uint32 +} + type InterfaceUpdateListener interface { InterfaceUpdated() } diff --git a/adapter/outbound/default.go b/adapter/outbound/default.go index bb58ff54..84be8aba 100644 --- a/adapter/outbound/default.go +++ b/adapter/outbound/default.go @@ -8,8 +8,8 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/buf" "github.com/sagernet/sing/common/bufio" @@ -25,35 +25,11 @@ func NewConnection(ctx context.Context, this N.Dialer, conn net.Conn, metadata a var outConn net.Conn var err error if len(metadata.DestinationAddresses) > 0 { - outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) - } else { - outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) - } - if err != nil { - return N.ReportHandshakeFailure(conn, err) - } - err = N.ReportConnHandshakeSuccess(conn, outConn) - if err != nil { - outConn.Close() - return err - } - return CopyEarlyConn(ctx, conn, outConn) -} - -func NewDirectConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn net.Conn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error { - defer conn.Close() - ctx = adapter.WithContext(ctx, &metadata) - var outConn net.Conn - var err error - if len(metadata.DestinationAddresses) > 0 { - outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) - } else if metadata.Destination.IsFqdn() { - var destinationAddresses []netip.Addr - destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) - if err != nil { - return N.ReportHandshakeFailure(conn, err) + if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer { + outConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.FallbackDelay) + } else { + outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, metadata.DestinationAddresses) } - outConn, err = N.DialSerial(ctx, this, N.NetworkTCP, metadata.Destination, destinationAddresses) } else { outConn, err = this.DialContext(ctx, N.NetworkTCP, metadata.Destination) } @@ -79,7 +55,11 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, ) if metadata.UDPConnect { if len(metadata.DestinationAddresses) > 0 { - outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) + if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer { + outConn, err = dialer.DialSerialNetwork(ctx, parallelDialer, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.FallbackDelay) + } else { + outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) + } } else { outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) } @@ -93,7 +73,11 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, } } else { if len(metadata.DestinationAddresses) > 0 { - outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) + if parallelDialer, isParallelDialer := this.(dialer.ParallelInterfaceDialer); isParallelDialer { + outPacketConn, destinationAddress, err = dialer.ListenSerialNetworkPacket(ctx, parallelDialer, metadata.Destination, metadata.DestinationAddresses, metadata.NetworkStrategy, metadata.FallbackDelay) + } else { + outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) + } } else { outPacketConn, err = this.ListenPacket(ctx, metadata.Destination) } @@ -129,76 +113,6 @@ func NewPacketConnection(ctx context.Context, this N.Dialer, conn N.PacketConn, return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn)) } -func NewDirectPacketConnection(ctx context.Context, router adapter.Router, this N.Dialer, conn N.PacketConn, metadata adapter.InboundContext, domainStrategy dns.DomainStrategy) error { - defer conn.Close() - ctx = adapter.WithContext(ctx, &metadata) - var ( - outPacketConn net.PacketConn - outConn net.Conn - destinationAddress netip.Addr - err error - ) - if metadata.UDPConnect { - if len(metadata.DestinationAddresses) > 0 { - outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, metadata.DestinationAddresses) - } else if metadata.Destination.IsFqdn() { - var destinationAddresses []netip.Addr - destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) - if err != nil { - return N.ReportHandshakeFailure(conn, err) - } - outConn, err = N.DialSerial(ctx, this, N.NetworkUDP, metadata.Destination, destinationAddresses) - } else { - outConn, err = this.DialContext(ctx, N.NetworkUDP, metadata.Destination) - } - if err != nil { - return N.ReportHandshakeFailure(conn, err) - } - connRemoteAddr := M.AddrFromNet(outConn.RemoteAddr()) - if connRemoteAddr != metadata.Destination.Addr { - destinationAddress = connRemoteAddr - } - } else { - if len(metadata.DestinationAddresses) > 0 { - outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, metadata.DestinationAddresses) - } else if metadata.Destination.IsFqdn() { - var destinationAddresses []netip.Addr - destinationAddresses, err = router.Lookup(ctx, metadata.Destination.Fqdn, domainStrategy) - if err != nil { - return N.ReportHandshakeFailure(conn, err) - } - outPacketConn, destinationAddress, err = N.ListenSerial(ctx, this, metadata.Destination, destinationAddresses) - } else { - outPacketConn, err = this.ListenPacket(ctx, metadata.Destination) - } - if err != nil { - return N.ReportHandshakeFailure(conn, err) - } - } - err = N.ReportPacketConnHandshakeSuccess(conn, outPacketConn) - if err != nil { - outPacketConn.Close() - return err - } - if destinationAddress.IsValid() { - if metadata.Destination.IsFqdn() { - outPacketConn = bufio.NewNATPacketConn(bufio.NewPacketConn(outPacketConn), M.SocksaddrFrom(destinationAddress, metadata.Destination.Port), metadata.Destination) - } - if natConn, loaded := common.Cast[bufio.NATPacketConn](conn); loaded { - natConn.UpdateDestination(destinationAddress) - } - } - switch metadata.Protocol { - case C.ProtocolSTUN: - ctx, conn = canceler.NewPacketConn(ctx, conn, C.STUNTimeout) - case C.ProtocolQUIC: - ctx, conn = canceler.NewPacketConn(ctx, conn, C.QUICTimeout) - case C.ProtocolDNS: - ctx, conn = canceler.NewPacketConn(ctx, conn, C.DNSTimeout) - } - return bufio.CopyPacketConn(ctx, conn, bufio.NewPacketConn(outPacketConn)) -} - func CopyEarlyConn(ctx context.Context, conn net.Conn, serverConn net.Conn) error { if cachedReader, isCached := conn.(N.CachedReader); isCached { payload := cachedReader.ReadCached() diff --git a/common/dialer/default.go b/common/dialer/default.go index a4e4290e..3b2ffd76 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -10,66 +10,93 @@ import ( "github.com/sagernet/sing-box/common/conntrack" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" ) -var _ WireGuardListener = (*DefaultDialer)(nil) +var ( + _ ParallelInterfaceDialer = (*DefaultDialer)(nil) + _ WireGuardListener = (*DefaultDialer)(nil) +) type DefaultDialer struct { - dialer4 tcpDialer - dialer6 tcpDialer - udpDialer4 net.Dialer - udpDialer6 net.Dialer - udpListener net.ListenConfig - udpAddr4 string - udpAddr6 string - isWireGuardListener bool + dialer4 tcpDialer + dialer6 tcpDialer + udpDialer4 net.Dialer + udpDialer6 net.Dialer + udpListener net.ListenConfig + udpAddr4 string + udpAddr6 string + isWireGuardListener bool + networkManager adapter.NetworkManager + networkStrategy C.NetworkStrategy + networkFallbackDelay time.Duration + networkLastFallback atomic.TypedValue[time.Time] } func NewDefault(networkManager adapter.NetworkManager, options option.DialerOptions) (*DefaultDialer, error) { - var dialer net.Dialer - var listener net.ListenConfig + var ( + dialer net.Dialer + listener net.ListenConfig + interfaceFinder control.InterfaceFinder + networkStrategy C.NetworkStrategy + networkFallbackDelay time.Duration + ) + if networkManager != nil { + interfaceFinder = networkManager.InterfaceFinder() + } else { + interfaceFinder = control.NewDefaultInterfaceFinder() + } if options.BindInterface != "" { - var interfaceFinder control.InterfaceFinder - if networkManager != nil { - interfaceFinder = networkManager.InterfaceFinder() - } else { - interfaceFinder = control.NewDefaultInterfaceFinder() - } bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1) dialer.Control = control.Append(dialer.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc) - } else if networkManager != nil && networkManager.AutoDetectInterface() { - bindFunc := networkManager.AutoDetectInterfaceFunc() - dialer.Control = control.Append(dialer.Control, bindFunc) - listener.Control = control.Append(listener.Control, bindFunc) - } else if networkManager != nil && networkManager.DefaultInterface() != "" { - bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), networkManager.DefaultInterface(), -1) - dialer.Control = control.Append(dialer.Control, bindFunc) - listener.Control = control.Append(listener.Control, bindFunc) - } - var autoRedirectOutputMark uint32 - if networkManager != nil { - autoRedirectOutputMark = networkManager.AutoRedirectOutputMark() - } - if autoRedirectOutputMark > 0 { - dialer.Control = control.Append(dialer.Control, control.RoutingMark(autoRedirectOutputMark)) - listener.Control = control.Append(listener.Control, control.RoutingMark(autoRedirectOutputMark)) } if options.RoutingMark > 0 { dialer.Control = control.Append(dialer.Control, control.RoutingMark(options.RoutingMark)) listener.Control = control.Append(listener.Control, control.RoutingMark(options.RoutingMark)) + } + if networkManager != nil { + autoRedirectOutputMark := networkManager.AutoRedirectOutputMark() if autoRedirectOutputMark > 0 { - return nil, E.New("`auto_redirect` with `route_[_exclude]_address_set is conflict with `routing_mark`") + if options.RoutingMark > 0 { + return nil, E.New("`routing_mark` is conflict with `tun.auto_redirect` with `tun.route_[_exclude]_address_set") + } + dialer.Control = control.Append(dialer.Control, control.RoutingMark(autoRedirectOutputMark)) + listener.Control = control.Append(listener.Control, control.RoutingMark(autoRedirectOutputMark)) } - } else if networkManager != nil && networkManager.DefaultMark() > 0 { - dialer.Control = control.Append(dialer.Control, control.RoutingMark(networkManager.DefaultMark())) - listener.Control = control.Append(listener.Control, control.RoutingMark(networkManager.DefaultMark())) - if autoRedirectOutputMark > 0 { - return nil, E.New("`auto_redirect` with `route_[_exclude]_address_set is conflict with `default_mark`") + } + if C.NetworkStrategy(options.NetworkStrategy) != C.NetworkStrategyDefault { + if options.BindInterface != "" || options.Inet4BindAddress != nil || options.Inet6BindAddress != nil { + return nil, E.New("`network_strategy` is conflict with `bind_interface`, `inet4_bind_address` and `inet6_bind_address`") + } + networkStrategy = C.NetworkStrategy(options.NetworkStrategy) + networkFallbackDelay = time.Duration(options.NetworkFallbackDelay) + if networkManager == nil || !networkManager.AutoDetectInterface() { + return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`") + } + } + if networkManager != nil && options.BindInterface == "" && options.Inet4BindAddress == nil && options.Inet6BindAddress == nil { + defaultOptions := networkManager.DefaultOptions() + if defaultOptions.DefaultInterface != "" { + bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), defaultOptions.DefaultInterface, -1) + dialer.Control = control.Append(dialer.Control, bindFunc) + listener.Control = control.Append(listener.Control, bindFunc) + } else if networkManager.AutoDetectInterface() { + if defaultOptions.DefaultNetworkStrategy != C.NetworkStrategyDefault && C.NetworkStrategy(options.NetworkStrategy) == C.NetworkStrategyDefault { + networkStrategy = defaultOptions.DefaultNetworkStrategy + networkFallbackDelay = defaultOptions.DefaultFallbackDelay + bindFunc := networkManager.ProtectFunc() + dialer.Control = control.Append(dialer.Control, bindFunc) + listener.Control = control.Append(listener.Control, bindFunc) + } else { + bindFunc := networkManager.AutoDetectInterfaceFunc() + dialer.Control = control.Append(dialer.Control, bindFunc) + listener.Control = control.Append(listener.Control, bindFunc) + } } } if options.ReuseAddr { @@ -130,6 +157,9 @@ func NewDefault(networkManager adapter.NetworkManager, options option.DialerOpti listener.Control = control.Append(listener.Control, controlFn) } } + if networkStrategy != C.NetworkStrategyDefault && options.TCPFastOpen { + return nil, E.New("`tcp_fast_open` is conflict with `network_strategy` or `route.default_network_strategy`") + } tcpDialer4, err := newTCPDialer(dialer4, options.TCPFastOpen) if err != nil { return nil, err @@ -139,14 +169,17 @@ func NewDefault(networkManager adapter.NetworkManager, options option.DialerOpti return nil, err } return &DefaultDialer{ - tcpDialer4, - tcpDialer6, - udpDialer4, - udpDialer6, - listener, - udpAddr4, - udpAddr6, - options.IsWireGuardListener, + dialer4: tcpDialer4, + dialer6: tcpDialer6, + udpDialer4: udpDialer4, + udpDialer6: udpDialer6, + udpListener: listener, + udpAddr4: udpAddr4, + udpAddr6: udpAddr6, + isWireGuardListener: options.IsWireGuardListener, + networkManager: networkManager, + networkStrategy: networkStrategy, + networkFallbackDelay: networkFallbackDelay, }, nil } @@ -154,33 +187,88 @@ func (d *DefaultDialer) DialContext(ctx context.Context, network string, address if !address.IsValid() { return nil, E.New("invalid address") } - switch N.NetworkName(network) { - case N.NetworkUDP: - if !address.IsIPv6() { - return trackConn(d.udpDialer4.DialContext(ctx, network, address.String())) - } else { - return trackConn(d.udpDialer6.DialContext(ctx, network, address.String())) + if d.networkStrategy == C.NetworkStrategyDefault { + switch N.NetworkName(network) { + case N.NetworkUDP: + if !address.IsIPv6() { + return trackConn(d.udpDialer4.DialContext(ctx, network, address.String())) + } else { + return trackConn(d.udpDialer6.DialContext(ctx, network, address.String())) + } + } + if !address.IsIPv6() { + return trackConn(DialSlowContext(&d.dialer4, ctx, network, address)) + } else { + return trackConn(DialSlowContext(&d.dialer6, ctx, network, address)) } - } - if !address.IsIPv6() { - return trackConn(DialSlowContext(&d.dialer4, ctx, network, address)) } else { - return trackConn(DialSlowContext(&d.dialer6, ctx, network, address)) + return d.DialParallelInterface(ctx, network, address, d.networkStrategy, d.networkFallbackDelay) } } +func (d *DefaultDialer) DialParallelInterface(ctx context.Context, network string, address M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) { + if strategy == C.NetworkStrategyDefault { + return d.DialContext(ctx, network, address) + } + if !d.networkManager.AutoDetectInterface() { + return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`") + } + var dialer net.Dialer + if N.NetworkName(network) == N.NetworkTCP { + dialer = dialerFromTCPDialer(d.dialer4) + } else { + dialer = d.udpDialer4 + } + fastFallback := time.Now().Sub(d.networkLastFallback.Load()) < C.TCPTimeout + var ( + conn net.Conn + isPrimary bool + err error + ) + if !fastFallback { + conn, isPrimary, err = d.dialParallelInterface(ctx, dialer, network, address.String(), strategy, fallbackDelay) + } else { + conn, isPrimary, err = d.dialParallelInterfaceFastFallback(ctx, dialer, network, address.String(), strategy, fallbackDelay, d.networkLastFallback.Store) + } + if err != nil { + return nil, err + } + if !fastFallback && !isPrimary { + d.networkLastFallback.Store(time.Now()) + } + return trackConn(conn, nil) +} + func (d *DefaultDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - if destination.IsIPv6() { - return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6)) - } else if destination.IsIPv4() && !destination.Addr.IsUnspecified() { - return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4)) + if d.networkStrategy == C.NetworkStrategyDefault { + if destination.IsIPv6() { + return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr6)) + } else if destination.IsIPv4() && !destination.Addr.IsUnspecified() { + return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP+"4", d.udpAddr4)) + } else { + return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4)) + } } else { - return trackPacketConn(d.udpListener.ListenPacket(ctx, N.NetworkUDP, d.udpAddr4)) + return d.ListenSerialInterfacePacket(ctx, destination, d.networkStrategy, d.networkFallbackDelay) } } +func (d *DefaultDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, error) { + if strategy == C.NetworkStrategyDefault { + return d.ListenPacket(ctx, destination) + } + if !d.networkManager.AutoDetectInterface() { + return nil, E.New("`route.auto_detect_interface` is require by `network_strategy`") + } + network := N.NetworkUDP + if destination.IsIPv4() && !destination.Addr.IsUnspecified() { + network += "4" + } + return trackPacketConn(d.listenSerialInterfacePacket(ctx, d.udpListener, network, "", strategy, fallbackDelay)) +} + func (d *DefaultDialer) ListenPacketCompat(network, address string) (net.PacketConn, error) { - return trackPacketConn(d.udpListener.ListenPacket(context.Background(), network, address)) + return trackPacketConn(d.listenSerialInterfacePacket(context.Background(), d.udpListener, network, address, d.networkStrategy, d.networkFallbackDelay)) } func trackConn(conn net.Conn, err error) (net.Conn, error) { diff --git a/common/dialer/default_go1.20.go b/common/dialer/default_go1.20.go index a9f7b612..9dde955f 100644 --- a/common/dialer/default_go1.20.go +++ b/common/dialer/default_go1.20.go @@ -13,3 +13,7 @@ type tcpDialer = tfo.Dialer func newTCPDialer(dialer net.Dialer, tfoEnabled bool) (tcpDialer, error) { return tfo.Dialer{Dialer: dialer, DisableTFO: !tfoEnabled}, nil } + +func dialerFromTCPDialer(dialer tcpDialer) net.Dialer { + return dialer.Dialer +} diff --git a/common/dialer/default_nongo1.20.go b/common/dialer/default_nongo1.20.go index 21502424..b2e4638d 100644 --- a/common/dialer/default_nongo1.20.go +++ b/common/dialer/default_nongo1.20.go @@ -16,3 +16,7 @@ func newTCPDialer(dialer net.Dialer, tfoEnabled bool) (tcpDialer, error) { } return dialer, nil } + +func dialerFromTCPDialer(dialer tcpDialer) net.Dialer { + return dialer +} diff --git a/common/dialer/default_parallel_interface.go b/common/dialer/default_parallel_interface.go new file mode 100644 index 00000000..baf0349e --- /dev/null +++ b/common/dialer/default_parallel_interface.go @@ -0,0 +1,241 @@ +package dialer + +import ( + "context" + "net" + "time" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + F "github.com/sagernet/sing/common/format" + N "github.com/sagernet/sing/common/network" +) + +func (d *DefaultDialer) dialParallelInterface(ctx context.Context, dialer net.Dialer, network string, addr string, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, bool, error) { + primaryInterfaces, fallbackInterfaces := selectInterfaces(d.networkManager, strategy) + if len(primaryInterfaces)+len(fallbackInterfaces) == 0 { + return nil, false, E.New("no available network interface") + } + if fallbackDelay == 0 { + fallbackDelay = N.DefaultFallbackDelay + } + returned := make(chan struct{}) + defer close(returned) + type dialResult struct { + net.Conn + error + primary bool + } + results := make(chan dialResult) // unbuffered + startRacer := func(ctx context.Context, primary bool, iif adapter.NetworkInterface) { + perNetDialer := dialer + perNetDialer.Control = control.Append(perNetDialer.Control, control.BindToInterface(nil, iif.Name, iif.Index)) + conn, err := perNetDialer.DialContext(ctx, network, addr) + if err != nil { + select { + case results <- dialResult{error: E.Cause(err, "dial ", iif.Name, " (", iif.Name, ")"), primary: primary}: + case <-returned: + } + } else { + select { + case results <- dialResult{Conn: conn}: + case <-returned: + conn.Close() + } + } + } + primaryCtx, primaryCancel := context.WithCancel(ctx) + defer primaryCancel() + for _, iif := range primaryInterfaces { + go startRacer(primaryCtx, true, iif) + } + var ( + fallbackTimer *time.Timer + fallbackChan <-chan time.Time + ) + if len(fallbackInterfaces) > 0 { + fallbackTimer = time.NewTimer(fallbackDelay) + defer fallbackTimer.Stop() + fallbackChan = fallbackTimer.C + } + var errors []error + for { + select { + case <-fallbackChan: + fallbackCtx, fallbackCancel := context.WithCancel(ctx) + defer fallbackCancel() + for _, iif := range fallbackInterfaces { + go startRacer(fallbackCtx, false, iif) + } + case res := <-results: + if res.error == nil { + return res.Conn, res.primary, nil + } + errors = append(errors, res.error) + if len(errors) == len(primaryInterfaces)+len(fallbackInterfaces) { + return nil, false, E.Errors(errors...) + } + if res.primary && fallbackTimer != nil && fallbackTimer.Stop() { + fallbackTimer.Reset(0) + } + } + } +} + +func (d *DefaultDialer) dialParallelInterfaceFastFallback(ctx context.Context, dialer net.Dialer, network string, addr string, strategy C.NetworkStrategy, fallbackDelay time.Duration, resetFastFallback func(time.Time)) (net.Conn, bool, error) { + primaryInterfaces, fallbackInterfaces := selectInterfaces(d.networkManager, strategy) + if len(primaryInterfaces)+len(fallbackInterfaces) == 0 { + return nil, false, E.New("no available network interface") + } + if fallbackDelay == 0 { + fallbackDelay = N.DefaultFallbackDelay + } + returned := make(chan struct{}) + defer close(returned) + type dialResult struct { + net.Conn + error + primary bool + } + startAt := time.Now() + results := make(chan dialResult) // unbuffered + startRacer := func(ctx context.Context, primary bool, iif adapter.NetworkInterface) { + perNetDialer := dialer + perNetDialer.Control = control.Append(perNetDialer.Control, control.BindToInterface(nil, iif.Name, iif.Index)) + conn, err := perNetDialer.DialContext(ctx, network, addr) + if err != nil { + select { + case results <- dialResult{error: E.Cause(err, "dial ", iif.Name, " (", iif.Name, ")"), primary: primary}: + case <-returned: + } + } else { + select { + case results <- dialResult{Conn: conn}: + case <-returned: + if primary && time.Since(startAt) <= fallbackDelay { + resetFastFallback(time.Time{}) + } + conn.Close() + } + } + } + for _, iif := range primaryInterfaces { + go startRacer(ctx, true, iif) + } + fallbackCtx, fallbackCancel := context.WithCancel(ctx) + defer fallbackCancel() + for _, iif := range fallbackInterfaces { + go startRacer(fallbackCtx, false, iif) + } + var errors []error + for { + select { + case res := <-results: + if res.error == nil { + return res.Conn, res.primary, nil + } + errors = append(errors, res.error) + if len(errors) == len(primaryInterfaces)+len(fallbackInterfaces) { + return nil, false, E.Errors(errors...) + } + } + } +} + +func (d *DefaultDialer) listenSerialInterfacePacket(ctx context.Context, listener net.ListenConfig, network string, addr string, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, error) { + primaryInterfaces, fallbackInterfaces := selectInterfaces(d.networkManager, strategy) + if len(primaryInterfaces)+len(fallbackInterfaces) == 0 { + return nil, E.New("no available network interface") + } + if fallbackDelay == 0 { + fallbackDelay = N.DefaultFallbackDelay + } + var errors []error + for _, primaryInterface := range primaryInterfaces { + perNetListener := listener + perNetListener.Control = control.Append(perNetListener.Control, control.BindToInterface(nil, primaryInterface.Name, primaryInterface.Index)) + conn, err := perNetListener.ListenPacket(ctx, network, addr) + if err == nil { + return conn, nil + } + errors = append(errors, E.Cause(err, "listen ", primaryInterface.Name, " (", primaryInterface.Name, ")")) + } + for _, fallbackInterface := range fallbackInterfaces { + perNetListener := listener + perNetListener.Control = control.Append(perNetListener.Control, control.BindToInterface(nil, fallbackInterface.Name, fallbackInterface.Index)) + conn, err := perNetListener.ListenPacket(ctx, network, addr) + if err == nil { + return conn, nil + } + errors = append(errors, E.Cause(err, "listen ", fallbackInterface.Name, " (", fallbackInterface.Name, ")")) + } + return nil, E.Errors(errors...) +} + +func selectInterfaces(networkManager adapter.NetworkManager, strategy C.NetworkStrategy) (primaryInterfaces []adapter.NetworkInterface, fallbackInterfaces []adapter.NetworkInterface) { + interfaces := networkManager.NetworkInterfaces() + switch strategy { + case C.NetworkStrategyFallback: + defaultIf := networkManager.InterfaceMonitor().DefaultInterface() + if defaultIf != nil { + for _, iif := range interfaces { + if iif.Index == defaultIf.Index { + primaryInterfaces = append(primaryInterfaces, iif) + } else { + fallbackInterfaces = append(fallbackInterfaces, iif) + } + } + } else { + primaryInterfaces = interfaces + } + case C.NetworkStrategyHybrid: + primaryInterfaces = interfaces + case C.NetworkStrategyWIFI: + for _, iif := range interfaces { + if iif.Type == C.InterfaceTypeWIFI { + primaryInterfaces = append(primaryInterfaces, iif) + } else { + fallbackInterfaces = append(fallbackInterfaces, iif) + } + } + case C.NetworkStrategyCellular: + for _, iif := range interfaces { + if iif.Type == C.InterfaceTypeCellular { + primaryInterfaces = append(primaryInterfaces, iif) + } else { + fallbackInterfaces = append(fallbackInterfaces, iif) + } + } + case C.NetworkStrategyEthernet: + for _, iif := range interfaces { + if iif.Type == C.InterfaceTypeEthernet { + primaryInterfaces = append(primaryInterfaces, iif) + } else { + fallbackInterfaces = append(fallbackInterfaces, iif) + } + } + case C.NetworkStrategyWIFIOnly: + for _, iif := range interfaces { + if iif.Type == C.InterfaceTypeWIFI { + primaryInterfaces = append(primaryInterfaces, iif) + } + } + case C.NetworkStrategyCellularOnly: + for _, iif := range interfaces { + if iif.Type == C.InterfaceTypeCellular { + primaryInterfaces = append(primaryInterfaces, iif) + } + } + case C.NetworkStrategyEthernetOnly: + for _, iif := range interfaces { + if iif.Type == C.InterfaceTypeEthernet { + primaryInterfaces = append(primaryInterfaces, iif) + } + } + default: + panic(F.ToString("unknown network strategy: ", strategy)) + } + return primaryInterfaces, fallbackInterfaces +} diff --git a/common/dialer/default_parallel_network.go b/common/dialer/default_parallel_network.go new file mode 100644 index 00000000..f42d9330 --- /dev/null +++ b/common/dialer/default_parallel_network.go @@ -0,0 +1,122 @@ +package dialer + +import ( + "context" + "net" + "net/netip" + "time" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" + N "github.com/sagernet/sing/common/network" +) + +func DialSerialNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) { + if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel { + return parallelDialer.DialParallelNetwork(ctx, network, destination, destinationAddresses, strategy, fallbackDelay) + } + var errors []error + for _, address := range destinationAddresses { + conn, err := dialer.DialParallelInterface(ctx, network, M.SocksaddrFrom(address, destination.Port), strategy, fallbackDelay) + if err == nil { + return conn, nil + } + errors = append(errors, err) + } + return nil, E.Errors(errors...) +} + +func DialParallelNetwork(ctx context.Context, dialer ParallelInterfaceDialer, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, preferIPv6 bool, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) { + if fallbackDelay == 0 { + fallbackDelay = N.DefaultFallbackDelay + } + + returned := make(chan struct{}) + defer close(returned) + + addresses4 := common.Filter(destinationAddresses, func(address netip.Addr) bool { + return address.Is4() || address.Is4In6() + }) + addresses6 := common.Filter(destinationAddresses, func(address netip.Addr) bool { + return address.Is6() && !address.Is4In6() + }) + if len(addresses4) == 0 || len(addresses6) == 0 { + return DialSerialNetwork(ctx, dialer, network, destination, destinationAddresses, strategy, fallbackDelay) + } + var primaries, fallbacks []netip.Addr + if preferIPv6 { + primaries = addresses6 + fallbacks = addresses4 + } else { + primaries = addresses4 + fallbacks = addresses6 + } + type dialResult struct { + net.Conn + error + primary bool + done bool + } + results := make(chan dialResult) // unbuffered + startRacer := func(ctx context.Context, primary bool) { + ras := primaries + if !primary { + ras = fallbacks + } + c, err := DialSerialNetwork(ctx, dialer, network, destination, ras, strategy, fallbackDelay) + select { + case results <- dialResult{Conn: c, error: err, primary: primary, done: true}: + case <-returned: + if c != nil { + c.Close() + } + } + } + var primary, fallback dialResult + primaryCtx, primaryCancel := context.WithCancel(ctx) + defer primaryCancel() + go startRacer(primaryCtx, true) + fallbackTimer := time.NewTimer(fallbackDelay) + defer fallbackTimer.Stop() + for { + select { + case <-fallbackTimer.C: + fallbackCtx, fallbackCancel := context.WithCancel(ctx) + defer fallbackCancel() + go startRacer(fallbackCtx, false) + + case res := <-results: + if res.error == nil { + return res.Conn, nil + } + if res.primary { + primary = res + } else { + fallback = res + } + if primary.done && fallback.done { + return nil, primary.error + } + if res.primary && fallbackTimer.Stop() { + fallbackTimer.Reset(0) + } + } + } +} + +func ListenSerialNetworkPacket(ctx context.Context, dialer ParallelInterfaceDialer, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) { + if parallelDialer, isParallel := dialer.(ParallelNetworkDialer); isParallel { + return parallelDialer.ListenSerialNetworkPacket(ctx, destination, destinationAddresses, strategy, fallbackDelay) + } + var errors []error + for _, address := range destinationAddresses { + conn, err := dialer.ListenSerialInterfacePacket(ctx, M.SocksaddrFrom(address, destination.Port), strategy, fallbackDelay) + if err == nil { + return conn, address, nil + } + errors = append(errors, err) + } + return nil, netip.Addr{}, E.Errors(errors...) +} diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index 047a2514..b3305d73 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -2,12 +2,16 @@ package dialer import ( "context" + "net" + "net/netip" "time" "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/service" ) @@ -49,3 +53,35 @@ func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) { } return dialer, nil } + +func NewDirect(ctx context.Context, options option.DialerOptions) (ParallelInterfaceDialer, error) { + if options.Detour != "" { + return nil, E.New("`detour` is not supported in direct context") + } + networkManager := service.FromContext[adapter.NetworkManager](ctx) + if options.IsWireGuardListener { + return NewDefault(networkManager, options) + } + dialer, err := NewDefault(networkManager, options) + if err != nil { + return nil, err + } + return NewResolveParallelInterfaceDialer( + service.FromContext[adapter.Router](ctx), + dialer, + true, + dns.DomainStrategy(options.DomainStrategy), + time.Duration(options.FallbackDelay), + ), nil +} + +type ParallelInterfaceDialer interface { + N.Dialer + DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) + ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, error) +} + +type ParallelNetworkDialer interface { + DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) + ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) +} diff --git a/common/dialer/resolve.go b/common/dialer/resolve.go index f2ee50db..ce17923c 100644 --- a/common/dialer/resolve.go +++ b/common/dialer/resolve.go @@ -7,6 +7,7 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common/bufio" @@ -14,7 +15,12 @@ import ( N "github.com/sagernet/sing/common/network" ) -type ResolveDialer struct { +var ( + _ N.Dialer = (*resolveDialer)(nil) + _ ParallelInterfaceDialer = (*resolveParallelNetworkDialer)(nil) +) + +type resolveDialer struct { dialer N.Dialer parallel bool router adapter.Router @@ -22,8 +28,8 @@ type ResolveDialer struct { fallbackDelay time.Duration } -func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) *ResolveDialer { - return &ResolveDialer{ +func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) N.Dialer { + return &resolveDialer{ dialer, parallel, router, @@ -32,7 +38,25 @@ func NewResolveDialer(router adapter.Router, dialer N.Dialer, parallel bool, str } } -func (d *ResolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { +type resolveParallelNetworkDialer struct { + resolveDialer + dialer ParallelInterfaceDialer +} + +func NewResolveParallelInterfaceDialer(router adapter.Router, dialer ParallelInterfaceDialer, parallel bool, strategy dns.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer { + return &resolveParallelNetworkDialer{ + resolveDialer{ + dialer, + parallel, + router, + strategy, + fallbackDelay, + }, + dialer, + } +} + +func (d *resolveDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { if !destination.IsFqdn() { return d.dialer.DialContext(ctx, network, destination) } @@ -57,7 +81,7 @@ func (d *ResolveDialer) DialContext(ctx context.Context, network string, destina } } -func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { +func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { if !destination.IsFqdn() { return d.dialer.ListenPacket(ctx, destination) } @@ -82,6 +106,59 @@ func (d *ResolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil } -func (d *ResolveDialer) Upstream() any { +func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context, network string, destination M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) { + if !destination.IsFqdn() { + return d.dialer.DialContext(ctx, network, destination) + } + ctx, metadata := adapter.ExtendContext(ctx) + ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) + metadata.Destination = destination + metadata.Domain = "" + var addresses []netip.Addr + var err error + if d.strategy == dns.DomainStrategyAsIS { + addresses, err = d.router.LookupDefault(ctx, destination.Fqdn) + } else { + addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy) + } + if err != nil { + return nil, err + } + if fallbackDelay == 0 { + fallbackDelay = d.fallbackDelay + } + if d.parallel { + return DialParallelNetwork(ctx, d.dialer, network, destination, addresses, d.strategy == dns.DomainStrategyPreferIPv6, strategy, fallbackDelay) + } else { + return DialSerialNetwork(ctx, d.dialer, network, destination, addresses, strategy, fallbackDelay) + } +} + +func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.Context, destination M.Socksaddr, strategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, error) { + if !destination.IsFqdn() { + return d.dialer.ListenPacket(ctx, destination) + } + ctx, metadata := adapter.ExtendContext(ctx) + ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug) + metadata.Destination = destination + metadata.Domain = "" + var addresses []netip.Addr + var err error + if d.strategy == dns.DomainStrategyAsIS { + addresses, err = d.router.LookupDefault(ctx, destination.Fqdn) + } else { + addresses, err = d.router.Lookup(ctx, destination.Fqdn, d.strategy) + } + if err != nil { + return nil, err + } + conn, destinationAddress, err := ListenSerialNetworkPacket(ctx, d.dialer, destination, addresses, strategy, fallbackDelay) + if err != nil { + return nil, err + } + return bufio.NewNATPacketConn(bufio.NewPacketConn(conn), M.SocksaddrFrom(destinationAddress, destination.Port), destination), nil +} + +func (d *resolveDialer) Upstream() any { return d.dialer } diff --git a/common/settings/proxy_darwin.go b/common/settings/proxy_darwin.go index 3c06a853..53ed0fe0 100644 --- a/common/settings/proxy_darwin.go +++ b/common/settings/proxy_darwin.go @@ -7,6 +7,7 @@ import ( "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/shell" @@ -33,7 +34,7 @@ func NewSystemProxy(ctx context.Context, serverAddr M.Socksaddr, supportSOCKS bo serverAddr: serverAddr, supportSOCKS: supportSOCKS, } - proxy.element = interfaceMonitor.RegisterCallback(proxy.update) + proxy.element = interfaceMonitor.RegisterCallback(proxy.routeUpdate) return proxy, nil } @@ -65,11 +66,8 @@ func (p *DarwinSystemProxy) Disable() error { return err } -func (p *DarwinSystemProxy) update(event int) { - if event&tun.EventInterfaceUpdate == 0 { - return - } - if !p.isEnabled { +func (p *DarwinSystemProxy) routeUpdate(defaultInterface *control.Interface, flags int) { + if !p.isEnabled || defaultInterface == nil { return } _ = p.update0() diff --git a/constant/network.go b/constant/network.go index f5ac2a4e..c026b7b1 100644 --- a/constant/network.go +++ b/constant/network.go @@ -1,8 +1,50 @@ package constant +import ( + "github.com/sagernet/sing/common" + F "github.com/sagernet/sing/common/format" +) + const ( InterfaceTypeWIFI = "wifi" InterfaceTypeCellular = "cellular" InterfaceTypeEthernet = "ethernet" InterfaceTypeOther = "other" ) + +type NetworkStrategy int + +const ( + NetworkStrategyDefault NetworkStrategy = iota + NetworkStrategyFallback + NetworkStrategyHybrid + NetworkStrategyWIFI + NetworkStrategyCellular + NetworkStrategyEthernet + NetworkStrategyWIFIOnly + NetworkStrategyCellularOnly + NetworkStrategyEthernetOnly +) + +var ( + NetworkStrategyToString = map[NetworkStrategy]string{ + NetworkStrategyDefault: "default", + NetworkStrategyFallback: "fallback", + NetworkStrategyHybrid: "hybrid", + NetworkStrategyWIFI: "wifi", + NetworkStrategyCellular: "cellular", + NetworkStrategyEthernet: "ethernet", + NetworkStrategyWIFIOnly: "wifi_only", + NetworkStrategyCellularOnly: "cellular_only", + NetworkStrategyEthernetOnly: "ethernet_only", + } + StringToNetworkStrategy = common.ReverseMap(NetworkStrategyToString) +) + +func (s NetworkStrategy) String() string { + name, loaded := NetworkStrategyToString[s] + if !loaded { + return F.ToString(int(s)) + } + return name +} diff --git a/experimental/libbox/monitor.go b/experimental/libbox/monitor.go index 8f401971..d237d534 100644 --- a/experimental/libbox/monitor.go +++ b/experimental/libbox/monitor.go @@ -67,7 +67,7 @@ func (m *platformDefaultInterfaceMonitor) UpdateDefaultInterface(interfaceName s callbacks := m.callbacks.Array() m.defaultInterfaceAccess.Unlock() for _, callback := range callbacks { - callback(tun.EventInterfaceUpdate) + callback(nil, 0) } return } @@ -86,6 +86,6 @@ func (m *platformDefaultInterfaceMonitor) UpdateDefaultInterface(interfaceName s callbacks := m.callbacks.Array() m.defaultInterfaceAccess.Unlock() for _, callback := range callbacks { - callback(tun.EventInterfaceUpdate) + callback(newInterface, 0) } } diff --git a/go.mod b/go.mod index f2683f84..95c865f1 100644 --- a/go.mod +++ b/go.mod @@ -25,14 +25,14 @@ require ( github.com/sagernet/gvisor v0.0.0-20241021032506-a4324256e4a3 github.com/sagernet/quic-go v0.48.1-beta.1 github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 - github.com/sagernet/sing v0.6.0-alpha.6 + github.com/sagernet/sing v0.6.0-alpha.7 github.com/sagernet/sing-dns v0.4.0-alpha.1 github.com/sagernet/sing-mux v0.3.0-alpha.1 github.com/sagernet/sing-quic v0.3.0-rc.2 github.com/sagernet/sing-shadowsocks v0.2.7 github.com/sagernet/sing-shadowsocks2 v0.2.0 github.com/sagernet/sing-shadowtls v0.1.4 - github.com/sagernet/sing-tun v0.6.0-alpha.7 + github.com/sagernet/sing-tun v0.6.0-alpha.8 github.com/sagernet/sing-vmess v0.1.12 github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 github.com/sagernet/utls v1.6.7 diff --git a/go.sum b/go.sum index f1f16aea..1e423087 100644 --- a/go.sum +++ b/go.sum @@ -110,8 +110,8 @@ github.com/sagernet/quic-go v0.48.1-beta.1/go.mod h1:1WgdDIVD1Gybp40JTWketeSfKA/ github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691 h1:5Th31OC6yj8byLGkEnIYp6grlXfo1QYUfiYFGjewIdc= github.com/sagernet/reality v0.0.0-20230406110435-ee17307e7691/go.mod h1:B8lp4WkQ1PwNnrVMM6KyuFR20pU8jYBD+A4EhJovEXU= github.com/sagernet/sing v0.2.18/go.mod h1:OL6k2F0vHmEzXz2KW19qQzu172FDgSbUSODylighuVo= -github.com/sagernet/sing v0.6.0-alpha.6 h1:R0abM8ZeazyAKo9d3DNxtrgW17g3tZAD8al7O5+ADOw= -github.com/sagernet/sing v0.6.0-alpha.6/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= +github.com/sagernet/sing v0.6.0-alpha.7 h1:C77ZlUxdSJiHLCLXbmWAVvyllFaNdRl0nUkdbWZyFcU= +github.com/sagernet/sing v0.6.0-alpha.7/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak= github.com/sagernet/sing-dns v0.4.0-alpha.1 h1:2KlP8DeqtGkULFiZtvG2r7SuoJP6orANFzJwC5vDKvg= github.com/sagernet/sing-dns v0.4.0-alpha.1/go.mod h1:vgHATsm4wdymwpvBZPei8RY+546iGXS6hlWv2x6YKcM= github.com/sagernet/sing-mux v0.3.0-alpha.1 h1:IgNX5bJBpL41gGbp05pdDOvh/b5eUQ6cv9240+Ngipg= @@ -124,8 +124,8 @@ github.com/sagernet/sing-shadowsocks2 v0.2.0 h1:wpZNs6wKnR7mh1wV9OHwOyUr21VkS3wK github.com/sagernet/sing-shadowsocks2 v0.2.0/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ= github.com/sagernet/sing-shadowtls v0.1.4 h1:aTgBSJEgnumzFenPvc+kbD9/W0PywzWevnVpEx6Tw3k= github.com/sagernet/sing-shadowtls v0.1.4/go.mod h1:F8NBgsY5YN2beQavdgdm1DPlhaKQlaL6lpDdcBglGK4= -github.com/sagernet/sing-tun v0.6.0-alpha.7 h1:h5Fqg+H5VggJq/LGdc/hOctNEcYAdkmfKY83lYIDHUg= -github.com/sagernet/sing-tun v0.6.0-alpha.7/go.mod h1:JkgiLLnQUXln1zLGVoJqUwAulJGT0xoiPU4/pYF1fhU= +github.com/sagernet/sing-tun v0.6.0-alpha.8 h1:HhXyUvXxtaXgT+IILZMq6kbrAyDbUwbN+Df/XxpL7Vo= +github.com/sagernet/sing-tun v0.6.0-alpha.8/go.mod h1:JkgiLLnQUXln1zLGVoJqUwAulJGT0xoiPU4/pYF1fhU= github.com/sagernet/sing-vmess v0.1.12 h1:2gFD8JJb+eTFMoa8FIVMnknEi+vCSfaiTXTfEYAYAPg= github.com/sagernet/sing-vmess v0.1.12/go.mod h1:luTSsfyBGAc9VhtCqwjR+dt1QgqBhuYBCONB/POhF8I= github.com/sagernet/smux v0.0.0-20231208180855-7041f6ea79e7 h1:DImB4lELfQhplLTxeq2z31Fpv8CQqqrUwTbrIRumZqQ= diff --git a/option/outbound.go b/option/outbound.go index 0e2d5874..5791802c 100644 --- a/option/outbound.go +++ b/option/outbound.go @@ -65,21 +65,23 @@ type DialerOptionsWrapper interface { } type DialerOptions struct { - Detour string `json:"detour,omitempty"` - BindInterface string `json:"bind_interface,omitempty"` - Inet4BindAddress *badoption.Addr `json:"inet4_bind_address,omitempty"` - Inet6BindAddress *badoption.Addr `json:"inet6_bind_address,omitempty"` - ProtectPath string `json:"protect_path,omitempty"` - RoutingMark uint32 `json:"routing_mark,omitempty"` - ReuseAddr bool `json:"reuse_addr,omitempty"` - ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` - TCPFastOpen bool `json:"tcp_fast_open,omitempty"` - TCPMultiPath bool `json:"tcp_multi_path,omitempty"` - UDPFragment *bool `json:"udp_fragment,omitempty"` - UDPFragmentDefault bool `json:"-"` - DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` - FallbackDelay badoption.Duration `json:"fallback_delay,omitempty"` - IsWireGuardListener bool `json:"-"` + Detour string `json:"detour,omitempty"` + BindInterface string `json:"bind_interface,omitempty"` + Inet4BindAddress *badoption.Addr `json:"inet4_bind_address,omitempty"` + Inet6BindAddress *badoption.Addr `json:"inet6_bind_address,omitempty"` + ProtectPath string `json:"protect_path,omitempty"` + RoutingMark uint32 `json:"routing_mark,omitempty"` + ReuseAddr bool `json:"reuse_addr,omitempty"` + ConnectTimeout badoption.Duration `json:"connect_timeout,omitempty"` + TCPFastOpen bool `json:"tcp_fast_open,omitempty"` + TCPMultiPath bool `json:"tcp_multi_path,omitempty"` + UDPFragment *bool `json:"udp_fragment,omitempty"` + UDPFragmentDefault bool `json:"-"` + DomainStrategy DomainStrategy `json:"domain_strategy,omitempty"` + NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"` + FallbackDelay badoption.Duration `json:"fallback_delay,omitempty"` + NetworkFallbackDelay badoption.Duration `json:"network_fallback_delay,omitempty"` + IsWireGuardListener bool `json:"-"` } func (o *DialerOptions) TakeDialerOptions() DialerOptions { diff --git a/option/route.go b/option/route.go index dfd72986..236e56f7 100644 --- a/option/route.go +++ b/option/route.go @@ -1,16 +1,20 @@ package option +import "github.com/sagernet/sing/common/json/badoption" + type RouteOptions struct { - GeoIP *GeoIPOptions `json:"geoip,omitempty"` - Geosite *GeositeOptions `json:"geosite,omitempty"` - Rules []Rule `json:"rules,omitempty"` - RuleSet []RuleSet `json:"rule_set,omitempty"` - Final string `json:"final,omitempty"` - FindProcess bool `json:"find_process,omitempty"` - AutoDetectInterface bool `json:"auto_detect_interface,omitempty"` - OverrideAndroidVPN bool `json:"override_android_vpn,omitempty"` - DefaultInterface string `json:"default_interface,omitempty"` - DefaultMark uint32 `json:"default_mark,omitempty"` + GeoIP *GeoIPOptions `json:"geoip,omitempty"` + Geosite *GeositeOptions `json:"geosite,omitempty"` + Rules []Rule `json:"rules,omitempty"` + RuleSet []RuleSet `json:"rule_set,omitempty"` + Final string `json:"final,omitempty"` + FindProcess bool `json:"find_process,omitempty"` + AutoDetectInterface bool `json:"auto_detect_interface,omitempty"` + OverrideAndroidVPN bool `json:"override_android_vpn,omitempty"` + DefaultInterface string `json:"default_interface,omitempty"` + DefaultMark uint32 `json:"default_mark,omitempty"` + DefaultNetworkStrategy NetworkStrategy `json:"default_network_strategy,omitempty"` + DefaultFallbackDelay badoption.Duration `json:"default_fallback_delay,omitempty"` } type GeoIPOptions struct { diff --git a/option/rule_action.go b/option/rule_action.go index 9bc13039..7c31ea7a 100644 --- a/option/rule_action.go +++ b/option/rule_action.go @@ -137,14 +137,18 @@ func (r *DNSRuleAction) UnmarshalJSONContext(ctx context.Context, data []byte) e } type RouteActionOptions struct { - Outbound string `json:"outbound,omitempty"` - UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` - UDPConnect bool `json:"udp_connect,omitempty"` + Outbound string `json:"outbound,omitempty"` + NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"` + FallbackDelay uint32 `json:"fallback_delay,omitempty"` + UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` + UDPConnect bool `json:"udp_connect,omitempty"` } type _RouteOptionsActionOptions struct { - UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` - UDPConnect bool `json:"udp_connect,omitempty"` + NetworkStrategy NetworkStrategy `json:"network_strategy,omitempty"` + FallbackDelay uint32 `json:"fallback_delay,omitempty"` + UDPDisableDomainUnmapping bool `json:"udp_disable_domain_unmapping,omitempty"` + UDPConnect bool `json:"udp_connect,omitempty"` } type RouteOptionsActionOptions _RouteOptionsActionOptions diff --git a/option/types.go b/option/types.go index 04e3f10e..8ed06250 100644 --- a/option/types.go +++ b/option/types.go @@ -3,6 +3,7 @@ package option import ( "strings" + C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" @@ -150,3 +151,23 @@ func DNSQueryTypeToString(queryType uint16) string { } return F.ToString(queryType) } + +type NetworkStrategy C.NetworkStrategy + +func (n NetworkStrategy) MarshalJSON() ([]byte, error) { + return json.Marshal(C.NetworkStrategy(n).String()) +} + +func (n *NetworkStrategy) UnmarshalJSON(content []byte) error { + var value string + err := json.Unmarshal(content, &value) + if err != nil { + return err + } + strategy, loaded := C.StringToNetworkStrategy[value] + if !loaded { + return E.New("unknown network strategy: ", value) + } + *n = NetworkStrategy(strategy) + return nil +} diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go index 27b334c9..4251c366 100644 --- a/protocol/direct/outbound.go +++ b/protocol/direct/outbound.go @@ -13,6 +13,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" dns "github.com/sagernet/sing-dns" + "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" @@ -24,31 +25,38 @@ func RegisterOutbound(registry *outbound.Registry) { outbound.Register[option.DirectOutboundOptions](registry, C.TypeDirect, NewOutbound) } -var _ N.ParallelDialer = (*Outbound)(nil) +var ( + _ N.ParallelDialer = (*Outbound)(nil) + _ dialer.ParallelNetworkDialer = (*Outbound)(nil) +) type Outbound struct { outbound.Adapter - logger logger.ContextLogger - dialer N.Dialer - domainStrategy dns.DomainStrategy - fallbackDelay time.Duration - overrideOption int - overrideDestination M.Socksaddr + logger logger.ContextLogger + dialer dialer.ParallelInterfaceDialer + domainStrategy dns.DomainStrategy + fallbackDelay time.Duration + networkStrategy C.NetworkStrategy + networkFallbackDelay time.Duration + overrideOption int + overrideDestination M.Socksaddr // loopBack *loopBackDetector } func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (adapter.Outbound, error) { options.UDPFragmentDefault = true - outboundDialer, err := dialer.New(ctx, options.DialerOptions) + outboundDialer, err := dialer.NewDirect(ctx, options.DialerOptions) if err != nil { return nil, err } outbound := &Outbound{ - Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.DialerOptions), - logger: logger, - domainStrategy: dns.DomainStrategy(options.DomainStrategy), - fallbackDelay: time.Duration(options.FallbackDelay), - dialer: outboundDialer, + Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.DialerOptions), + logger: logger, + domainStrategy: dns.DomainStrategy(options.DomainStrategy), + fallbackDelay: time.Duration(options.FallbackDelay), + networkStrategy: C.NetworkStrategy(options.NetworkStrategy), + networkFallbackDelay: time.Duration(options.NetworkFallbackDelay), + dialer: outboundDialer, // loopBack: newLoopBackDetector(router), } if options.ProxyProtocol != 0 { @@ -96,33 +104,6 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination return h.dialer.DialContext(ctx, network, destination) } -func (h *Outbound) DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) { - ctx, metadata := adapter.ExtendContext(ctx) - metadata.Outbound = h.Tag() - metadata.Destination = destination - switch h.overrideOption { - case 1, 2: - // override address - return h.DialContext(ctx, network, destination) - case 3: - destination.Port = h.overrideDestination.Port - } - network = N.NetworkName(network) - switch network { - case N.NetworkTCP: - h.logger.InfoContext(ctx, "outbound connection to ", destination) - case N.NetworkUDP: - h.logger.InfoContext(ctx, "outbound packet connection to ", destination) - } - var domainStrategy dns.DomainStrategy - if h.domainStrategy != dns.DomainStrategyAsIS { - domainStrategy = h.domainStrategy - } else { - domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) - } - return N.DialParallel(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, h.fallbackDelay) -} - func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { ctx, metadata := adapter.ExtendContext(ctx) metadata.Outbound = h.Tag() @@ -154,6 +135,110 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n return conn, nil } +func (h *Outbound) DialParallel(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr) (net.Conn, error) { + ctx, metadata := adapter.ExtendContext(ctx) + metadata.Outbound = h.Tag() + metadata.Destination = destination + switch h.overrideOption { + case 1, 2: + // override address + return h.DialContext(ctx, network, destination) + case 3: + destination.Port = h.overrideDestination.Port + } + network = N.NetworkName(network) + switch network { + case N.NetworkTCP: + h.logger.InfoContext(ctx, "outbound connection to ", destination) + case N.NetworkUDP: + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + } + var domainStrategy dns.DomainStrategy + if h.domainStrategy != dns.DomainStrategyAsIS { + domainStrategy = h.domainStrategy + } else { + domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) + } + switch domainStrategy { + case dns.DomainStrategyUseIPv4: + destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is4) + if len(destinationAddresses) == 0 { + return nil, E.New("no IPv4 address available for ", destination) + } + case dns.DomainStrategyUseIPv6: + destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is6) + if len(destinationAddresses) == 0 { + return nil, E.New("no IPv6 address available for ", destination) + } + } + return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, h.networkStrategy, h.fallbackDelay) +} + +func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy C.NetworkStrategy, fallbackDelay time.Duration) (net.Conn, error) { + ctx, metadata := adapter.ExtendContext(ctx) + metadata.Outbound = h.Tag() + metadata.Destination = destination + switch h.overrideOption { + case 1, 2: + // override address + return h.DialContext(ctx, network, destination) + case 3: + destination.Port = h.overrideDestination.Port + } + network = N.NetworkName(network) + switch network { + case N.NetworkTCP: + h.logger.InfoContext(ctx, "outbound connection to ", destination) + case N.NetworkUDP: + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + } + var domainStrategy dns.DomainStrategy + if h.domainStrategy != dns.DomainStrategyAsIS { + domainStrategy = h.domainStrategy + } else { + domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy) + } + switch domainStrategy { + case dns.DomainStrategyUseIPv4: + destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is4) + if len(destinationAddresses) == 0 { + return nil, E.New("no IPv4 address available for ", destination) + } + case dns.DomainStrategyUseIPv6: + destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is6) + if len(destinationAddresses) == 0 { + return nil, E.New("no IPv6 address available for ", destination) + } + } + return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, networkStrategy, fallbackDelay) +} + +func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy C.NetworkStrategy, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) { + ctx, metadata := adapter.ExtendContext(ctx) + metadata.Outbound = h.Tag() + metadata.Destination = destination + switch h.overrideOption { + case 1: + destination = h.overrideDestination + case 2: + newDestination := h.overrideDestination + newDestination.Port = destination.Port + destination = newDestination + case 3: + destination.Port = h.overrideDestination.Port + } + if h.overrideOption == 0 { + h.logger.InfoContext(ctx, "outbound packet connection") + } else { + h.logger.InfoContext(ctx, "outbound packet connection to ", destination) + } + conn, newDestination, err := dialer.ListenSerialNetworkPacket(ctx, h.dialer, destination, destinationAddresses, networkStrategy, fallbackDelay) + if err != nil { + return nil, netip.Addr{}, err + } + return conn, newDestination, nil +} + /*func (h *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { if h.loopBack.CheckConn(metadata.Source.AddrPort(), M.AddrPortFromNet(conn.LocalAddr())) { return E.New("reject loopback connection to ", metadata.Destination) diff --git a/protocol/socks/outbound.go b/protocol/socks/outbound.go index dbb5ab61..70a5a5ed 100644 --- a/protocol/socks/outbound.go +++ b/protocol/socks/outbound.go @@ -10,7 +10,6 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" @@ -115,23 +114,3 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n h.logger.InfoContext(ctx, "outbound packet connection to ", destination) return h.client.ListenPacket(ctx, destination) } - -// TODO -// Deprecated -func (h *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - if h.resolve { - return outbound.NewDirectConnection(ctx, h.router, h, conn, metadata, dns.DomainStrategyUseIPv4) - } else { - return outbound.NewConnection(ctx, h, conn, metadata) - } -} - -// TODO -// Deprecated -func (h *Outbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { - if h.resolve { - return outbound.NewDirectPacketConnection(ctx, h.router, h, conn, metadata, dns.DomainStrategyUseIPv4) - } else { - return outbound.NewPacketConnection(ctx, h, conn, metadata) - } -} diff --git a/protocol/wireguard/outbound.go b/protocol/wireguard/outbound.go index 7f33bf60..7b2f8a6c 100644 --- a/protocol/wireguard/outbound.go +++ b/protocol/wireguard/outbound.go @@ -16,7 +16,6 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/transport/wireguard" - "github.com/sagernet/sing-dns" "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" @@ -231,15 +230,3 @@ func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n } return w.tunDevice.ListenPacket(ctx, destination) } - -// TODO -// Deprecated -func (w *Outbound) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { - return outbound.NewDirectConnection(ctx, w.router, w, conn, metadata, dns.DomainStrategyAsIS) -} - -// TODO -// Deprecated -func (w *Outbound) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { - return outbound.NewDirectPacketConnection(ctx, w.router, w, conn, metadata, dns.DomainStrategyAsIS) -} diff --git a/route/network.go b/route/network.go index c6253b91..51fbdf05 100644 --- a/route/network.go +++ b/route/network.go @@ -8,6 +8,7 @@ import ( "runtime" "strings" "syscall" + "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/conntrack" @@ -38,8 +39,7 @@ type NetworkManager struct { networkInterfaces atomic.TypedValue[[]adapter.NetworkInterface] autoDetectInterface bool - defaultInterface string - defaultMark uint32 + defaultOptions adapter.NetworkOptions autoRedirectOutputMark uint32 networkMonitor tun.NetworkUpdateMonitor @@ -58,11 +58,23 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp logger: logger, interfaceFinder: control.NewDefaultInterfaceFinder(), autoDetectInterface: routeOptions.AutoDetectInterface, - defaultInterface: routeOptions.DefaultInterface, - defaultMark: routeOptions.DefaultMark, - pauseManager: service.FromContext[pause.Manager](ctx), - platformInterface: service.FromContext[platform.Interface](ctx), - outboundManager: service.FromContext[adapter.OutboundManager](ctx), + defaultOptions: adapter.NetworkOptions{ + DefaultInterface: routeOptions.DefaultInterface, + DefaultMark: routeOptions.DefaultMark, + DefaultNetworkStrategy: C.NetworkStrategy(routeOptions.DefaultNetworkStrategy), + DefaultFallbackDelay: time.Duration(routeOptions.DefaultFallbackDelay), + }, + pauseManager: service.FromContext[pause.Manager](ctx), + platformInterface: service.FromContext[platform.Interface](ctx), + outboundManager: service.FromContext[adapter.OutboundManager](ctx), + } + if C.NetworkStrategy(routeOptions.DefaultNetworkStrategy) != C.NetworkStrategyDefault { + if routeOptions.DefaultInterface != "" { + return nil, E.New("`default_network_strategy` is conflict with `default_interface`") + } + if !routeOptions.AutoDetectInterface { + return nil, E.New("`auto_detect_interface` is required by `default_network_strategy`") + } } usePlatformDefaultInterfaceMonitor := nm.platformInterface != nil enforceInterfaceMonitor := routeOptions.AutoDetectInterface @@ -84,12 +96,12 @@ func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOp if err != nil { return nil, E.New("auto_detect_interface unsupported on current platform") } - interfaceMonitor.RegisterCallback(nm.notifyNetworkUpdate) + interfaceMonitor.RegisterCallback(nm.notifyInterfaceUpdate) nm.interfaceMonitor = interfaceMonitor } } else { interfaceMonitor := nm.platformInterface.CreateDefaultInterfaceMonitor(logger) - interfaceMonitor.RegisterCallback(nm.notifyNetworkUpdate) + interfaceMonitor.RegisterCallback(nm.notifyInterfaceUpdate) nm.interfaceMonitor = interfaceMonitor } return nm, nil @@ -265,10 +277,6 @@ func (r *NetworkManager) NetworkInterfaces() []adapter.NetworkInterface { return r.networkInterfaces.Load() } -func (r *NetworkManager) DefaultInterface() string { - return r.defaultInterface -} - func (r *NetworkManager) AutoDetectInterface() bool { return r.autoDetectInterface } @@ -301,8 +309,19 @@ func (r *NetworkManager) AutoDetectInterfaceFunc() control.Func { } } -func (r *NetworkManager) DefaultMark() uint32 { - return r.defaultMark +func (r *NetworkManager) ProtectFunc() control.Func { + if r.platformInterface != nil && r.platformInterface.UsePlatformAutoDetectInterfaceControl() { + return func(network, address string, conn syscall.RawConn) error { + return control.Raw(conn, func(fd uintptr) error { + return r.platformInterface.AutoDetectInterfaceControl(int(fd)) + }) + } + } + return nil +} + +func (r *NetworkManager) DefaultOptions() adapter.NetworkOptions { + return r.defaultOptions } func (r *NetworkManager) RegisterAutoRedirectOutputMark(mark uint32) error { @@ -344,45 +363,47 @@ func (r *NetworkManager) ResetNetwork() { } } -func (r *NetworkManager) notifyNetworkUpdate(event int) { - if event == tun.EventNoRoute { +func (r *NetworkManager) notifyInterfaceUpdate(defaultInterface *control.Interface, flags int) { + if defaultInterface == nil { r.pauseManager.NetworkPause() r.logger.Error("missing default interface") - } else { - r.pauseManager.NetworkWake() - defaultInterface := r.DefaultNetworkInterface() - if defaultInterface == nil { - panic("invalid interface context") - } - var options []string - options = append(options, F.ToString("index ", defaultInterface.Index)) - if C.IsAndroid && r.platformInterface == nil { - var vpnStatus string - if r.interfaceMonitor.AndroidVPNEnabled() { - vpnStatus = "enabled" - } else { - vpnStatus = "disabled" - } - options = append(options, "vpn "+vpnStatus) + return + } + + r.pauseManager.NetworkWake() + var options []string + options = append(options, F.ToString("index ", defaultInterface.Index)) + if C.IsAndroid && r.platformInterface == nil { + var vpnStatus string + if r.interfaceMonitor.AndroidVPNEnabled() { + vpnStatus = "enabled" } else { - if defaultInterface.Type != "" { - options = append(options, F.ToString("type ", defaultInterface.Type)) - } - if defaultInterface.Expensive { - options = append(options, "expensive") - } - if defaultInterface.Constrained { - options = append(options, "constrained") - } + vpnStatus = "disabled" } - r.logger.Info("updated default interface ", defaultInterface.Name, ", ", strings.Join(options, ", ")) - if r.platformInterface != nil { - state := r.platformInterface.ReadWIFIState() - if state != r.wifiState { - r.wifiState = state - if state.SSID != "" { - r.logger.Info("updated WIFI state: SSID=", state.SSID, ", BSSID=", state.BSSID) - } + options = append(options, "vpn "+vpnStatus) + } else if r.platformInterface != nil { + networkInterface := common.Find(r.networkInterfaces.Load(), func(it adapter.NetworkInterface) bool { + return it.Interface.Index == defaultInterface.Index + }) + if networkInterface.Type == "" { + // race + return + } + options = append(options, F.ToString("type ", networkInterface.Type)) + if networkInterface.Expensive { + options = append(options, "expensive") + } + if networkInterface.Constrained { + options = append(options, "constrained") + } + } + r.logger.Info("updated default interface ", defaultInterface.Name, ", ", strings.Join(options, ", ")) + if r.platformInterface != nil { + state := r.platformInterface.ReadWIFIState() + if state != r.wifiState { + r.wifiState = state + if state.SSID != "" { + r.logger.Info("updated WIFI state: SSID=", state.SSID, ", BSSID=", state.BSSID) } } } diff --git a/route/route.go b/route/route.go index 1c4da4b7..051ee403 100644 --- a/route/route.go +++ b/route/route.go @@ -424,9 +424,13 @@ match: } switch action := currentRule.Action().(type) { case *rule.RuleActionRoute: + metadata.NetworkStrategy = action.NetworkStrategy + metadata.FallbackDelay = action.FallbackDelay metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping metadata.UDPConnect = action.UDPConnect case *rule.RuleActionRouteOptions: + metadata.NetworkStrategy = action.NetworkStrategy + metadata.FallbackDelay = action.FallbackDelay metadata.UDPDisableDomainUnmapping = action.UDPDisableDomainUnmapping metadata.UDPConnect = action.UDPConnect case *rule.RuleActionSniff: diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index d44e36ee..f9b2e641 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -30,12 +30,16 @@ func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action opti return &RuleActionRoute{ Outbound: action.RouteOptions.Outbound, RuleActionRouteOptions: RuleActionRouteOptions{ + NetworkStrategy: C.NetworkStrategy(action.RouteOptions.NetworkStrategy), + FallbackDelay: time.Duration(action.RouteOptions.FallbackDelay), UDPDisableDomainUnmapping: action.RouteOptions.UDPDisableDomainUnmapping, UDPConnect: action.RouteOptions.UDPConnect, }, }, nil case C.RuleActionTypeRouteOptions: return &RuleActionRouteOptions{ + NetworkStrategy: C.NetworkStrategy(action.RouteOptionsOptions.NetworkStrategy), + FallbackDelay: time.Duration(action.RouteOptionsOptions.FallbackDelay), UDPDisableDomainUnmapping: action.RouteOptionsOptions.UDPDisableDomainUnmapping, UDPConnect: action.RouteOptionsOptions.UDPConnect, }, nil @@ -135,6 +139,8 @@ func (r *RuleActionRoute) String() string { } type RuleActionRouteOptions struct { + NetworkStrategy C.NetworkStrategy + FallbackDelay time.Duration UDPDisableDomainUnmapping bool UDPConnect bool } diff --git a/transport/dhcp/server.go b/transport/dhcp/server.go index 9a06ac17..29c6bbe0 100644 --- a/transport/dhcp/server.go +++ b/transport/dhcp/server.go @@ -166,7 +166,7 @@ func (t *Transport) updateServers() error { } } -func (t *Transport) interfaceUpdated(int) { +func (t *Transport) interfaceUpdated(defaultInterface *control.Interface, flags int) { err := t.updateServers() if err != nil { t.options.Logger.Error("update servers: ", err)