From 1df8dfcadea24834b3f5d904dafbd3b414fae3b1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Sun, 10 Nov 2024 12:11:21 +0800 Subject: [PATCH] refactor: Modular network manager --- adapter/inbound/manager.go | 4 +- adapter/network.go | 23 ++ adapter/outbound/manager.go | 20 +- adapter/router.go | 19 +- box.go | 56 +-- cmd/sing-box/cmd_rule_set_match.go | 3 +- common/dialer/default.go | 24 +- common/dialer/dialer.go | 25 +- common/settings/proxy_darwin.go | 2 +- experimental/libbox/command_group.go | 4 +- experimental/libbox/command_select.go | 2 +- experimental/libbox/command_urltest.go | 4 +- experimental/libbox/config.go | 2 +- experimental/libbox/monitor.go | 2 +- experimental/libbox/platform/interface.go | 4 +- experimental/libbox/service.go | 10 +- experimental/libbox/service_pause.go | 2 +- include/clashapi_stub.go | 2 +- protocol/direct/loopback_detect.go | 16 +- protocol/tun/inbound.go | 24 +- protocol/wireguard/outbound.go | 2 +- route/geo_resources.go | 10 +- route/network.go | 337 ++++++++++++++++ route/route.go | 34 +- route/router.go | 455 +++------------------- route/rule/rule_default.go | 19 +- route/rule/rule_dns.go | 19 +- route/rule/rule_headless.go | 24 +- route/rule/rule_item_wifi_bssid.go | 12 +- route/rule/rule_item_wifi_ssid.go | 12 +- route/rule/rule_set.go | 6 +- route/rule/rule_set_local.go | 8 +- route/rule/rule_set_remote.go | 6 +- transport/dhcp/server.go | 25 +- transport/wireguard/device_system.go | 4 +- 35 files changed, 615 insertions(+), 606 deletions(-) create mode 100644 adapter/network.go create mode 100644 route/network.go diff --git a/adapter/inbound/manager.go b/adapter/inbound/manager.go index d2be4b57..69a3ad46 100644 --- a/adapter/inbound/manager.go +++ b/adapter/inbound/manager.go @@ -52,13 +52,13 @@ func (m *Manager) Start(stage adapter.StartStage) error { func (m *Manager) Close() error { m.access.Lock() + defer m.access.Unlock() if !m.started { - panic("not started") + return nil } m.started = false inbounds := m.inbounds m.inbounds = nil - m.access.Unlock() monitor := taskmonitor.New(m.logger, C.StopTimeout) var err error for _, inbound := range inbounds { diff --git a/adapter/network.go b/adapter/network.go new file mode 100644 index 00000000..0ce27411 --- /dev/null +++ b/adapter/network.go @@ -0,0 +1,23 @@ +package adapter + +import ( + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/control" +) + +type NetworkManager interface { + NewService + InterfaceFinder() control.InterfaceFinder + UpdateInterfaces() error + DefaultInterface() string + AutoDetectInterface() bool + AutoDetectInterfaceFunc() control.Func + DefaultMark() uint32 + RegisterAutoRedirectOutputMark(mark uint32) error + AutoRedirectOutputMark() uint32 + NetworkMonitor() tun.NetworkUpdateMonitor + InterfaceMonitor() tun.DefaultInterfaceMonitor + PackageManager() tun.PackageManager + WIFIState() WIFIState + ResetNetwork() +} diff --git a/adapter/outbound/manager.go b/adapter/outbound/manager.go index b3e1a170..10a89a1c 100644 --- a/adapter/outbound/manager.go +++ b/adapter/outbound/manager.go @@ -48,16 +48,17 @@ func (m *Manager) Initialize(defaultOutboundFallback adapter.Outbound) { func (m *Manager) Start(stage adapter.StartStage) error { m.access.Lock() - defer m.access.Unlock() if m.started && m.stage >= stage { panic("already started") } m.started = true m.stage = stage + outbounds := m.outbounds + m.access.Unlock() if stage == adapter.StartStateStart { - m.startOutbounds() + return m.startOutbounds(outbounds) } else { - for _, outbound := range m.outbounds { + for _, outbound := range outbounds { err := adapter.LegacyStart(outbound, stage) if err != nil { return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]") @@ -67,13 +68,13 @@ func (m *Manager) Start(stage adapter.StartStage) error { return nil } -func (m *Manager) startOutbounds() error { +func (m *Manager) startOutbounds(outbounds []adapter.Outbound) error { monitor := taskmonitor.New(m.logger, C.StartTimeout) started := make(map[string]bool) for { canContinue := false startOne: - for _, outboundToStart := range m.outbounds { + for _, outboundToStart := range outbounds { outboundTag := outboundToStart.Tag() if started[outboundTag] { continue @@ -97,13 +98,13 @@ func (m *Manager) startOutbounds() error { } } } - if len(started) == len(m.outbounds) { + if len(started) == len(outbounds) { break } if canContinue { continue } - currentOutbound := common.Find(m.outbounds, func(it adapter.Outbound) bool { + currentOutbound := common.Find(outbounds, func(it adapter.Outbound) bool { return !started[it.Tag()] }) var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error @@ -114,7 +115,9 @@ func (m *Manager) startOutbounds() error { if common.Contains(oTree, problemOutboundTag) { return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag) } + m.access.Lock() problemOutbound := m.outboundByTag[problemOutboundTag] + m.access.Unlock() if problemOutbound == nil { return E.New("dependency[", problemOutboundTag, "] not found for outbound[", oCurrent.Tag(), "]") } @@ -129,7 +132,8 @@ func (m *Manager) Close() error { monitor := taskmonitor.New(m.logger, C.StopTimeout) m.access.Lock() if !m.started { - panic("not started") + m.access.Unlock() + return nil } m.started = false outbounds := m.outbounds diff --git a/adapter/router.go b/adapter/router.go index b8ac51f5..40a461a7 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -10,8 +10,6 @@ import ( "github.com/sagernet/sing-box/common/geoip" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-dns" - "github.com/sagernet/sing-tun" - "github.com/sagernet/sing/common/control" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" @@ -31,28 +29,13 @@ type Router interface { GeoIPReader() *geoip.Reader LoadGeosite(code string) (Rule, error) - RuleSet(tag string) (RuleSet, bool) - NeedWIFIState() bool Exchange(ctx context.Context, message *mdns.Msg) (*mdns.Msg, error) Lookup(ctx context.Context, domain string, strategy dns.DomainStrategy) ([]netip.Addr, error) LookupDefault(ctx context.Context, domain string) ([]netip.Addr, error) ClearDNSCache() - - InterfaceFinder() control.InterfaceFinder - UpdateInterfaces() error - DefaultInterface() string - AutoDetectInterface() bool - AutoDetectInterfaceFunc() control.Func - DefaultMark() uint32 - RegisterAutoRedirectOutputMark(mark uint32) error - AutoRedirectOutputMark() uint32 - NetworkMonitor() tun.NetworkUpdateMonitor - InterfaceMonitor() tun.DefaultInterfaceMonitor - PackageManager() tun.PackageManager - WIFIState() WIFIState Rules() []Rule ClashServer() ClashServer @@ -61,7 +44,7 @@ type Router interface { V2RayServer() V2RayServer SetV2RayServer(server V2RayServer) - ResetNetwork() error + ResetNetwork() } // Deprecated: Use ConnectionRouterEx instead. diff --git a/box.go b/box.go index a4ca530f..d3fe5e46 100644 --- a/box.go +++ b/box.go @@ -34,6 +34,7 @@ type Box struct { router adapter.Router inbound *inbound.Manager outbound *outbound.Manager + network *route.NetworkManager logFactory log.Factory logger log.ContextLogger preServices1 map[string]adapter.Service @@ -109,20 +110,18 @@ func New(options Options) (*Box, error) { return nil, E.Cause(err, "create log factory") } routeOptions := common.PtrValueOrDefault(options.Route) - inboundManager := inbound.NewManager(logFactory.NewLogger("inbound-manager"), inboundRegistry) - outboundManager := outbound.NewManager(logFactory.NewLogger("outbound-manager"), outboundRegistry, routeOptions.Final) + inboundManager := inbound.NewManager(logFactory.NewLogger("inbound"), inboundRegistry) + outboundManager := outbound.NewManager(logFactory.NewLogger("outbound"), outboundRegistry, routeOptions.Final) ctx = service.ContextWith[adapter.InboundManager](ctx, inboundManager) ctx = service.ContextWith[adapter.OutboundManager](ctx, outboundManager) - router, err := route.NewRouter( - ctx, - logFactory, - routeOptions, - common.PtrValueOrDefault(options.DNS), - common.PtrValueOrDefault(options.NTP), - options.Inbounds, - ) + networkManager, err := route.NewNetworkManager(ctx, logFactory.NewLogger("network"), routeOptions) if err != nil { - return nil, E.Cause(err, "parse route options") + return nil, E.Cause(err, "initialize network manager") + } + ctx = service.ContextWith[adapter.NetworkManager](ctx, networkManager) + router, err := route.NewRouter(ctx, logFactory, routeOptions, common.PtrValueOrDefault(options.DNS), common.PtrValueOrDefault(options.NTP)) + if err != nil { + return nil, E.Cause(err, "initialize router") } //nolint:staticcheck if len(options.LegacyInbounds) > 0 { @@ -197,11 +196,8 @@ func New(options Options) (*Box, error) { option.DirectOutboundOptions{}, ), )) - if err != nil { - return nil, err - } if platformInterface != nil { - err = platformInterface.Initialize(ctx, router) + err = platformInterface.Initialize(networkManager) if err != nil { return nil, E.Cause(err, "initialize platform interface") } @@ -239,6 +235,7 @@ func New(options Options) (*Box, error) { router: router, inbound: inboundManager, outbound: outboundManager, + network: networkManager, createdAt: createdAt, logFactory: logFactory, logger: logFactory.Logger(), @@ -315,6 +312,10 @@ func (s *Box) preStart() error { } } } + err = s.network.Start(adapter.StartStateInitialize) + if err != nil { + return E.Cause(err, "initialize network manager") + } err = s.router.Start(adapter.StartStateInitialize) if err != nil { return E.Cause(err, "initialize router") @@ -323,6 +324,10 @@ func (s *Box) preStart() error { if err != nil { return err } + err = s.network.Start(adapter.StartStateStart) + if err != nil { + return err + } return s.router.Start(adapter.StartStateStart) } @@ -357,6 +362,10 @@ func (s *Box) start() error { if err != nil { return err } + err = s.network.Start(adapter.StartStatePostStart) + if err != nil { + return err + } err = s.router.Start(adapter.StartStatePostStart) if err != nil { return err @@ -365,6 +374,10 @@ func (s *Box) start() error { if err != nil { return err } + err = s.network.Start(adapter.StartStateStarted) + if err != nil { + return err + } err = s.router.Start(adapter.StartStateStarted) if err != nil { return err @@ -398,13 +411,8 @@ func (s *Box) Close() error { } errors = E.Errors(errors, s.inbound.Close()) errors = E.Errors(errors, s.outbound.Close()) - monitor.Start("close router") - if err := common.Close(s.router); err != nil { - errors = E.Append(errors, err, func(err error) error { - return E.Cause(err, "close router") - }) - } - monitor.Finish() + errors = E.Errors(errors, s.network.Close()) + errors = E.Errors(errors, s.router.Close()) for serviceName, service := range s.preServices1 { monitor.Start("close ", serviceName) errors = E.Append(errors, service.Close(), func(err error) error { @@ -435,6 +443,10 @@ func (s *Box) Outbound() adapter.OutboundManager { return s.outbound } +func (s *Box) Network() adapter.NetworkManager { + return s.network +} + func (s *Box) Router() adapter.Router { return s.router } diff --git a/cmd/sing-box/cmd_rule_set_match.go b/cmd/sing-box/cmd_rule_set_match.go index bfa340cf..7a3a0423 100644 --- a/cmd/sing-box/cmd_rule_set_match.go +++ b/cmd/sing-box/cmd_rule_set_match.go @@ -2,6 +2,7 @@ package main import ( "bytes" + "context" "io" "os" @@ -84,7 +85,7 @@ func ruleSetMatch(sourcePath string, domain string) error { } for i, ruleOptions := range plainRuleSet.Rules { var currentRule adapter.HeadlessRule - currentRule, err = rule.NewHeadlessRule(nil, ruleOptions) + currentRule, err = rule.NewHeadlessRule(context.Background(), ruleOptions) if err != nil { return E.Cause(err, "parse rule_set.rules.[", i, "]") } diff --git a/common/dialer/default.go b/common/dialer/default.go index fcc5e60e..a4e4290e 100644 --- a/common/dialer/default.go +++ b/common/dialer/default.go @@ -29,31 +29,31 @@ type DefaultDialer struct { isWireGuardListener bool } -func NewDefault(router adapter.Router, options option.DialerOptions) (*DefaultDialer, error) { +func NewDefault(networkManager adapter.NetworkManager, options option.DialerOptions) (*DefaultDialer, error) { var dialer net.Dialer var listener net.ListenConfig if options.BindInterface != "" { var interfaceFinder control.InterfaceFinder - if router != nil { - interfaceFinder = router.InterfaceFinder() + if networkManager != nil { + interfaceFinder = networkManager.InterfaceFinder() } else { interfaceFinder = control.NewDefaultInterfaceFinder() } bindFunc := control.BindToInterface(interfaceFinder, options.BindInterface, -1) dialer.Control = control.Append(dialer.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc) - } else if router != nil && router.AutoDetectInterface() { - bindFunc := router.AutoDetectInterfaceFunc() + } else if networkManager != nil && networkManager.AutoDetectInterface() { + bindFunc := networkManager.AutoDetectInterfaceFunc() dialer.Control = control.Append(dialer.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc) - } else if router != nil && router.DefaultInterface() != "" { - bindFunc := control.BindToInterface(router.InterfaceFinder(), router.DefaultInterface(), -1) + } else if networkManager != nil && networkManager.DefaultInterface() != "" { + bindFunc := control.BindToInterface(networkManager.InterfaceFinder(), networkManager.DefaultInterface(), -1) dialer.Control = control.Append(dialer.Control, bindFunc) listener.Control = control.Append(listener.Control, bindFunc) } var autoRedirectOutputMark uint32 - if router != nil { - autoRedirectOutputMark = router.AutoRedirectOutputMark() + if networkManager != nil { + autoRedirectOutputMark = networkManager.AutoRedirectOutputMark() } if autoRedirectOutputMark > 0 { dialer.Control = control.Append(dialer.Control, control.RoutingMark(autoRedirectOutputMark)) @@ -65,9 +65,9 @@ func NewDefault(router adapter.Router, options option.DialerOptions) (*DefaultDi if autoRedirectOutputMark > 0 { return nil, E.New("`auto_redirect` with `route_[_exclude]_address_set is conflict with `routing_mark`") } - } else if router != nil && router.DefaultMark() > 0 { - dialer.Control = control.Append(dialer.Control, control.RoutingMark(router.DefaultMark())) - listener.Control = control.Append(listener.Control, control.RoutingMark(router.DefaultMark())) + } else if networkManager != nil && networkManager.DefaultMark() > 0 { + dialer.Control = control.Append(dialer.Control, control.RoutingMark(networkManager.DefaultMark())) + listener.Control = control.Append(listener.Control, control.RoutingMark(networkManager.DefaultMark())) if autoRedirectOutputMark > 0 { return nil, E.New("`auto_redirect` with `route_[_exclude]_address_set is conflict with `default_mark`") } diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index fe4c7c12..047a2514 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -13,16 +13,16 @@ import ( ) func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) { - router := service.FromContext[adapter.Router](ctx) + networkManager := service.FromContext[adapter.NetworkManager](ctx) if options.IsWireGuardListener { - return NewDefault(router, options) + return NewDefault(networkManager, options) } var ( dialer N.Dialer err error ) if options.Detour == "" { - dialer, err = NewDefault(router, options) + dialer, err = NewDefault(networkManager, options) if err != nil { return nil, err } @@ -33,16 +33,19 @@ func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) { } dialer = NewDetour(outboundManager, options.Detour) } - if router == nil { - return NewDefault(router, options) + if networkManager == nil { + return NewDefault(networkManager, options) } if options.Detour == "" { - dialer = NewResolveDialer( - router, - dialer, - options.Detour == "" && !options.TCPFastOpen, - dns.DomainStrategy(options.DomainStrategy), - time.Duration(options.FallbackDelay)) + router := service.FromContext[adapter.Router](ctx) + if router != nil { + dialer = NewResolveDialer( + router, + dialer, + options.Detour == "" && !options.TCPFastOpen, + dns.DomainStrategy(options.DomainStrategy), + time.Duration(options.FallbackDelay)) + } } return dialer, nil } diff --git a/common/settings/proxy_darwin.go b/common/settings/proxy_darwin.go index 2d86fa2d..17e82cf5 100644 --- a/common/settings/proxy_darwin.go +++ b/common/settings/proxy_darwin.go @@ -25,7 +25,7 @@ type DarwinSystemProxy struct { } func NewSystemProxy(ctx context.Context, serverAddr M.Socksaddr, supportSOCKS bool) (*DarwinSystemProxy, error) { - interfaceMonitor := service.FromContext[adapter.Router](ctx).InterfaceMonitor() + interfaceMonitor := service.FromContext[adapter.NetworkManager](ctx).InterfaceMonitor() if interfaceMonitor == nil { return nil, E.New("missing interface monitor") } diff --git a/experimental/libbox/command_group.go b/experimental/libbox/command_group.go index 0915f56b..684cac62 100644 --- a/experimental/libbox/command_group.go +++ b/experimental/libbox/command_group.go @@ -109,7 +109,7 @@ func readGroups(reader io.Reader) (OutboundGroupIterator, error) { func writeGroups(writer io.Writer, boxService *BoxService) error { historyStorage := service.PtrFromContext[urltest.HistoryStorage](boxService.ctx) cacheFile := service.FromContext[adapter.CacheFile](boxService.ctx) - outbounds := boxService.instance.Router().Outbounds() + outbounds := boxService.instance.Outbound().Outbounds() var iGroups []adapter.OutboundGroup for _, it := range outbounds { if group, isGroup := it.(adapter.OutboundGroup); isGroup { @@ -130,7 +130,7 @@ func writeGroups(writer io.Writer, boxService *BoxService) error { } for _, itemTag := range iGroup.All() { - itemOutbound, isLoaded := boxService.instance.Router().Outbound(itemTag) + itemOutbound, isLoaded := boxService.instance.Outbound().Outbound(itemTag) if !isLoaded { continue } diff --git a/experimental/libbox/command_select.go b/experimental/libbox/command_select.go index f352005d..6dd74a2d 100644 --- a/experimental/libbox/command_select.go +++ b/experimental/libbox/command_select.go @@ -43,7 +43,7 @@ func (s *CommandServer) handleSelectOutbound(conn net.Conn) error { if service == nil { return writeError(conn, E.New("service not ready")) } - outboundGroup, isLoaded := service.instance.Router().Outbound(groupTag) + outboundGroup, isLoaded := service.instance.Outbound().Outbound(groupTag) if !isLoaded { return writeError(conn, E.New("selector not found: ", groupTag)) } diff --git a/experimental/libbox/command_urltest.go b/experimental/libbox/command_urltest.go index c72ded8a..c30d996e 100644 --- a/experimental/libbox/command_urltest.go +++ b/experimental/libbox/command_urltest.go @@ -41,7 +41,7 @@ func (s *CommandServer) handleURLTest(conn net.Conn) error { if serviceNow == nil { return nil } - abstractOutboundGroup, isLoaded := serviceNow.instance.Router().Outbound(groupTag) + abstractOutboundGroup, isLoaded := serviceNow.instance.Outbound().Outbound(groupTag) if !isLoaded { return writeError(conn, E.New("outbound group not found: ", groupTag)) } @@ -55,7 +55,7 @@ func (s *CommandServer) handleURLTest(conn net.Conn) error { } else { historyStorage := service.PtrFromContext[urltest.HistoryStorage](serviceNow.ctx) outbounds := common.Filter(common.Map(outboundGroup.All(), func(it string) adapter.Outbound { - itOutbound, _ := serviceNow.instance.Router().Outbound(it) + itOutbound, _ := serviceNow.instance.Outbound().Outbound(it) return itOutbound }), func(it adapter.Outbound) bool { if it == nil { diff --git a/experimental/libbox/config.go b/experimental/libbox/config.go index 53889d30..a102a3d8 100644 --- a/experimental/libbox/config.go +++ b/experimental/libbox/config.go @@ -50,7 +50,7 @@ func CheckConfig(configContent string) error { type platformInterfaceStub struct{} -func (s *platformInterfaceStub) Initialize(ctx context.Context, router adapter.Router) error { +func (s *platformInterfaceStub) Initialize(networkManager adapter.NetworkManager) error { return nil } diff --git a/experimental/libbox/monitor.go b/experimental/libbox/monitor.go index 60fc2b8e..18c8b009 100644 --- a/experimental/libbox/monitor.go +++ b/experimental/libbox/monitor.go @@ -115,7 +115,7 @@ func (m *platformDefaultInterfaceMonitor) UpdateDefaultInterface(interfaceName s err = m.updateInterfaces() } if err == nil { - err = m.router.UpdateInterfaces() + err = m.networkManager.UpdateInterfaces() } if err != nil { m.logger.Error(E.Cause(err, "update interfaces")) diff --git a/experimental/libbox/platform/interface.go b/experimental/libbox/platform/interface.go index 3dc00081..16eab5ab 100644 --- a/experimental/libbox/platform/interface.go +++ b/experimental/libbox/platform/interface.go @@ -1,8 +1,6 @@ package platform import ( - "context" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/process" "github.com/sagernet/sing-box/option" @@ -12,7 +10,7 @@ import ( ) type Interface interface { - Initialize(ctx context.Context, router adapter.Router) error + Initialize(networkManager adapter.NetworkManager) error UsePlatformAutoDetectInterfaceControl() bool AutoDetectInterfaceControl(fd int) error OpenTun(options *tun.Options, platformOptions option.TunPlatformOptions) (tun.Tun, error) diff --git a/experimental/libbox/service.go b/experimental/libbox/service.go index dcb0370f..d1c2b4d2 100644 --- a/experimental/libbox/service.go +++ b/experimental/libbox/service.go @@ -104,13 +104,13 @@ var ( ) type platformInterfaceWrapper struct { - iif PlatformInterface - useProcFS bool - router adapter.Router + iif PlatformInterface + useProcFS bool + networkManager adapter.NetworkManager } -func (w *platformInterfaceWrapper) Initialize(ctx context.Context, router adapter.Router) error { - w.router = router +func (w *platformInterfaceWrapper) Initialize(networkManager adapter.NetworkManager) error { + w.networkManager = networkManager return nil } diff --git a/experimental/libbox/service_pause.go b/experimental/libbox/service_pause.go index c4aa8daa..a7599a76 100644 --- a/experimental/libbox/service_pause.go +++ b/experimental/libbox/service_pause.go @@ -29,5 +29,5 @@ func (s *BoxService) Wake() { } func (s *BoxService) ResetNetwork() { - _ = s.instance.Router().ResetNetwork() + s.instance.Router().ResetNetwork() } diff --git a/include/clashapi_stub.go b/include/clashapi_stub.go index 26e199b3..e7d5304f 100644 --- a/include/clashapi_stub.go +++ b/include/clashapi_stub.go @@ -13,7 +13,7 @@ import ( ) func init() { - experimental.RegisterClashServerConstructor(func(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { + experimental.RegisterClashServerConstructor(func(ctx context.Context, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { return nil, E.New(`clash api is not included in this build, rebuild with -tags with_clash_api`) }) } diff --git a/protocol/direct/loopback_detect.go b/protocol/direct/loopback_detect.go index 5a184e69..4bc1be3b 100644 --- a/protocol/direct/loopback_detect.go +++ b/protocol/direct/loopback_detect.go @@ -11,18 +11,18 @@ import ( ) type loopBackDetector struct { - router adapter.Router + networkManager adapter.NetworkManager connAccess sync.RWMutex packetConnAccess sync.RWMutex connMap map[netip.AddrPort]netip.AddrPort packetConnMap map[uint16]uint16 } -func newLoopBackDetector(router adapter.Router) *loopBackDetector { +func newLoopBackDetector(networkManager adapter.NetworkManager) *loopBackDetector { return &loopBackDetector{ - router: router, - connMap: make(map[netip.AddrPort]netip.AddrPort), - packetConnMap: make(map[uint16]uint16), + networkManager: networkManager, + connMap: make(map[netip.AddrPort]netip.AddrPort), + packetConnMap: make(map[uint16]uint16), } } @@ -33,7 +33,7 @@ func (l *loopBackDetector) NewConn(conn net.Conn) net.Conn { } if udpConn, isUDPConn := conn.(abstractUDPConn); isUDPConn { if !source.Addr().IsLoopback() { - _, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr()) + _, err := l.networkManager.InterfaceFinder().InterfaceByAddr(source.Addr()) if err != nil { return conn } @@ -59,7 +59,7 @@ func (l *loopBackDetector) NewPacketConn(conn N.NetPacketConn, destination M.Soc return conn } if !source.Addr().IsLoopback() { - _, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr()) + _, err := l.networkManager.InterfaceFinder().InterfaceByAddr(source.Addr()) if err != nil { return conn } @@ -82,7 +82,7 @@ func (l *loopBackDetector) CheckPacketConn(source netip.AddrPort, local netip.Ad return false } if !source.Addr().IsLoopback() { - _, err := l.router.InterfaceFinder().InterfaceByAddr(source.Addr()) + _, err := l.networkManager.InterfaceFinder().InterfaceByAddr(source.Addr()) if err != nil { return false } diff --git a/protocol/tun/inbound.go b/protocol/tun/inbound.go index f2476223..302afb57 100644 --- a/protocol/tun/inbound.go +++ b/protocol/tun/inbound.go @@ -36,10 +36,11 @@ func RegisterInbound(registry *inbound.Registry) { } type Inbound struct { - tag string - ctx context.Context - router adapter.Router - logger log.ContextLogger + tag string + ctx context.Context + router adapter.Router + networkManager adapter.NetworkManager + logger log.ContextLogger // Deprecated inboundOptions option.InboundOptions tunOptions tun.Options @@ -168,11 +169,12 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo if outputMark == 0 { outputMark = tun.DefaultAutoRedirectOutputMark } - + networkManager := service.FromContext[adapter.NetworkManager](ctx) inbound := &Inbound{ tag: tag, ctx: ctx, router: router, + networkManager: networkManager, logger: logger, inboundOptions: options.InboundOptions, tunOptions: tun.Options{ @@ -198,7 +200,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo IncludeAndroidUser: options.IncludeAndroidUser, IncludePackage: options.IncludePackage, ExcludePackage: options.ExcludePackage, - InterfaceMonitor: router.InterfaceMonitor(), + InterfaceMonitor: networkManager.InterfaceMonitor(), }, endpointIndependentNat: options.EndpointIndependentNat, udpTimeout: udpTimeout, @@ -216,8 +218,8 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo Context: ctx, Handler: (*autoRedirectHandler)(inbound), Logger: logger, - NetworkMonitor: router.NetworkMonitor(), - InterfaceFinder: router.InterfaceFinder(), + NetworkMonitor: networkManager.NetworkMonitor(), + InterfaceFinder: networkManager.InterfaceFinder(), TableName: "sing-box", DisableNFTables: dErr == nil && disableNFTables, RouteAddressSet: &inbound.routeAddressSet, @@ -248,7 +250,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo } if markMode { inbound.tunOptions.AutoRedirectMarkMode = true - err = router.RegisterAutoRedirectOutputMark(inbound.tunOptions.AutoRedirectOutputMark) + err = networkManager.RegisterAutoRedirectOutputMark(inbound.tunOptions.AutoRedirectOutputMark) if err != nil { return nil, err } @@ -300,7 +302,7 @@ func (t *Inbound) Tag() string { func (t *Inbound) Start() error { if C.IsAndroid && t.platformInterface == nil { - t.tunOptions.BuildAndroidRules(t.router.PackageManager()) + t.tunOptions.BuildAndroidRules(t.networkManager.PackageManager()) } if t.tunOptions.Name == "" { t.tunOptions.Name = tun.CalculateInterfaceName("") @@ -338,7 +340,7 @@ func (t *Inbound) Start() error { Handler: t, Logger: t.logger, ForwarderBindInterface: forwarderBindInterface, - InterfaceFinder: t.router.InterfaceFinder(), + InterfaceFinder: t.networkManager.InterfaceFinder(), IncludeAllNetworks: includeAllNetworks, }) if err != nil { diff --git a/protocol/wireguard/outbound.go b/protocol/wireguard/outbound.go index 90f76c1a..7f33bf60 100644 --- a/protocol/wireguard/outbound.go +++ b/protocol/wireguard/outbound.go @@ -100,7 +100,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL if !options.SystemInterface && tun.WithGVisor { wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu) } else { - wireTunDevice, err = wireguard.NewSystemDevice(router, options.InterfaceName, options.LocalAddress, mtu, options.GSO) + wireTunDevice, err = wireguard.NewSystemDevice(service.FromContext[adapter.NetworkManager](ctx), options.InterfaceName, options.LocalAddress, mtu, options.GSO) } if err != nil { return nil, E.Cause(err, "create WireGuard device") diff --git a/route/geo_resources.go b/route/geo_resources.go index 91c06796..c5a45ffd 100644 --- a/route/geo_resources.go +++ b/route/geo_resources.go @@ -33,7 +33,7 @@ func (r *Router) LoadGeosite(code string) (adapter.Rule, error) { if err != nil { return nil, err } - rule, err = R.NewDefaultRule(r.ctx, r, nil, geosite.Compile(items)) + rule, err = R.NewDefaultRule(r.ctx, nil, geosite.Compile(items)) if err != nil { return nil, err } @@ -145,13 +145,13 @@ func (r *Router) downloadGeoIPDatabase(savePath string) error { r.logger.Info("downloading geoip database") var detour adapter.Outbound if r.geoIPOptions.DownloadDetour != "" { - outbound, loaded := r.Outbound(r.geoIPOptions.DownloadDetour) + outbound, loaded := r.outboundManager.Outbound(r.geoIPOptions.DownloadDetour) if !loaded { return E.New("detour outbound not found: ", r.geoIPOptions.DownloadDetour) } detour = outbound } else { - detour = r.defaultOutboundForConnection + detour = r.outboundManager.Default() } if parentDir := filepath.Dir(savePath); parentDir != "" { @@ -200,13 +200,13 @@ func (r *Router) downloadGeositeDatabase(savePath string) error { r.logger.Info("downloading geosite database") var detour adapter.Outbound if r.geositeOptions.DownloadDetour != "" { - outbound, loaded := r.Outbound(r.geositeOptions.DownloadDetour) + outbound, loaded := r.outboundManager.Outbound(r.geositeOptions.DownloadDetour) if !loaded { return E.New("detour outbound not found: ", r.geositeOptions.DownloadDetour) } detour = outbound } else { - detour = r.defaultOutboundForConnection + detour = r.outboundManager.Default() } if parentDir := filepath.Dir(savePath); parentDir != "" { diff --git a/route/network.go b/route/network.go new file mode 100644 index 00000000..b0caa24c --- /dev/null +++ b/route/network.go @@ -0,0 +1,337 @@ +package route + +import ( + "context" + "errors" + "net/netip" + "os" + "runtime" + "syscall" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/conntrack" + "github.com/sagernet/sing-box/common/taskmonitor" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/experimental/libbox/platform" + "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common/control" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" + "github.com/sagernet/sing/common/winpowrprof" + "github.com/sagernet/sing/service" + "github.com/sagernet/sing/service/pause" +) + +var _ adapter.NetworkManager = (*NetworkManager)(nil) + +type NetworkManager struct { + logger logger.ContextLogger + interfaceFinder *control.DefaultInterfaceFinder + autoDetectInterface bool + defaultInterface string + defaultMark uint32 + autoRedirectOutputMark uint32 + networkMonitor tun.NetworkUpdateMonitor + interfaceMonitor tun.DefaultInterfaceMonitor + packageManager tun.PackageManager + powerListener winpowrprof.EventListener + pauseManager pause.Manager + platformInterface platform.Interface + outboundManager adapter.OutboundManager + wifiState adapter.WIFIState + started bool +} + +func NewNetworkManager(ctx context.Context, logger logger.ContextLogger, routeOptions option.RouteOptions) (*NetworkManager, error) { + nm := &NetworkManager{ + logger: logger, + interfaceFinder: control.NewDefaultInterfaceFinder(), + autoDetectInterface: routeOptions.AutoDetectInterface, + defaultInterface: routeOptions.DefaultInterface, + defaultMark: routeOptions.DefaultMark, + pauseManager: service.FromContext[pause.Manager](ctx), + platformInterface: service.FromContext[platform.Interface](ctx), + outboundManager: service.FromContext[adapter.OutboundManager](ctx), + } + usePlatformDefaultInterfaceMonitor := nm.platformInterface != nil && nm.platformInterface.UsePlatformDefaultInterfaceMonitor() + enforceInterfaceMonitor := routeOptions.AutoDetectInterface + if !usePlatformDefaultInterfaceMonitor { + networkMonitor, err := tun.NewNetworkUpdateMonitor(logger) + if !((err != nil && !enforceInterfaceMonitor) || errors.Is(err, os.ErrInvalid)) { + if err != nil { + return nil, E.Cause(err, "create network monitor") + } + nm.networkMonitor = networkMonitor + networkMonitor.RegisterCallback(func() { + _ = nm.interfaceFinder.Update() + }) + interfaceMonitor, err := tun.NewDefaultInterfaceMonitor(nm.networkMonitor, logger, tun.DefaultInterfaceMonitorOptions{ + InterfaceFinder: nm.interfaceFinder, + OverrideAndroidVPN: routeOptions.OverrideAndroidVPN, + UnderNetworkExtension: nm.platformInterface != nil && nm.platformInterface.UnderNetworkExtension(), + }) + if err != nil { + return nil, E.New("auto_detect_interface unsupported on current platform") + } + interfaceMonitor.RegisterCallback(nm.notifyNetworkUpdate) + nm.interfaceMonitor = interfaceMonitor + } + } else { + interfaceMonitor := nm.platformInterface.CreateDefaultInterfaceMonitor(logger) + interfaceMonitor.RegisterCallback(nm.notifyNetworkUpdate) + nm.interfaceMonitor = interfaceMonitor + } + return nm, nil +} + +func (r *NetworkManager) Start(stage adapter.StartStage) error { + monitor := taskmonitor.New(r.logger, C.StartTimeout) + switch stage { + case adapter.StartStateInitialize: + if r.interfaceMonitor != nil { + monitor.Start("initialize interface monitor") + err := r.interfaceMonitor.Start() + monitor.Finish() + if err != nil { + return err + } + } + if r.networkMonitor != nil { + monitor.Start("initialize network monitor") + err := r.networkMonitor.Start() + monitor.Finish() + if err != nil { + return err + } + } + case adapter.StartStateStart: + if runtime.GOOS == "windows" { + powerListener, err := winpowrprof.NewEventListener(r.notifyWindowsPowerEvent) + if err == nil { + r.powerListener = powerListener + } else { + r.logger.Warn("initialize power listener: ", err) + } + } + if r.powerListener != nil { + monitor.Start("start power listener") + err := r.powerListener.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "start power listener") + } + } + if C.IsAndroid && r.platformInterface == nil { + monitor.Start("initialize package manager") + packageManager, err := tun.NewPackageManager(tun.PackageManagerOptions{ + Callback: r, + Logger: r.logger, + }) + monitor.Finish() + if err != nil { + return E.Cause(err, "create package manager") + } + monitor.Start("start package manager") + err = packageManager.Start() + monitor.Finish() + if err != nil { + r.logger.Warn("initialize package manager: ", err) + } else { + r.packageManager = packageManager + } + } + case adapter.StartStatePostStart: + r.started = true + } + return nil +} + +func (r *NetworkManager) Close() error { + monitor := taskmonitor.New(r.logger, C.StopTimeout) + var err error + if r.interfaceMonitor != nil { + monitor.Start("close interface monitor") + err = E.Append(err, r.interfaceMonitor.Close(), func(err error) error { + return E.Cause(err, "close interface monitor") + }) + monitor.Finish() + } + if r.networkMonitor != nil { + monitor.Start("close network monitor") + err = E.Append(err, r.networkMonitor.Close(), func(err error) error { + return E.Cause(err, "close network monitor") + }) + monitor.Finish() + } + if r.packageManager != nil { + monitor.Start("close package manager") + err = E.Append(err, r.packageManager.Close(), func(err error) error { + return E.Cause(err, "close package manager") + }) + monitor.Finish() + } + if r.powerListener != nil { + monitor.Start("close power listener") + err = E.Append(err, r.powerListener.Close(), func(err error) error { + return E.Cause(err, "close power listener") + }) + monitor.Finish() + } + return nil +} + +func (r *NetworkManager) InterfaceFinder() control.InterfaceFinder { + return r.interfaceFinder +} + +func (r *NetworkManager) UpdateInterfaces() error { + if r.platformInterface == nil || !r.platformInterface.UsePlatformInterfaceGetter() { + return r.interfaceFinder.Update() + } else { + interfaces, err := r.platformInterface.Interfaces() + if err != nil { + return err + } + r.interfaceFinder.UpdateInterfaces(interfaces) + return nil + } +} + +func (r *NetworkManager) DefaultInterface() string { + return r.defaultInterface +} + +func (r *NetworkManager) AutoDetectInterface() bool { + return r.autoDetectInterface +} + +func (r *NetworkManager) AutoDetectInterfaceFunc() control.Func { + if r.platformInterface != nil && r.platformInterface.UsePlatformAutoDetectInterfaceControl() { + return func(network, address string, conn syscall.RawConn) error { + return control.Raw(conn, func(fd uintptr) error { + return r.platformInterface.AutoDetectInterfaceControl(int(fd)) + }) + } + } else { + if r.interfaceMonitor == nil { + return nil + } + return control.BindToInterfaceFunc(r.interfaceFinder, func(network string, address string) (interfaceName string, interfaceIndex int, err error) { + remoteAddr := M.ParseSocksaddr(address).Addr + if C.IsLinux { + interfaceName, interfaceIndex = r.interfaceMonitor.DefaultInterface(remoteAddr) + if interfaceIndex == -1 { + err = tun.ErrNoRoute + } + } else { + interfaceIndex = r.interfaceMonitor.DefaultInterfaceIndex(remoteAddr) + if interfaceIndex == -1 { + err = tun.ErrNoRoute + } + } + return + }) + } +} + +func (r *NetworkManager) DefaultMark() uint32 { + return r.defaultMark +} + +func (r *NetworkManager) RegisterAutoRedirectOutputMark(mark uint32) error { + if r.autoRedirectOutputMark > 0 { + return E.New("only one auto-redirect can be configured") + } + r.autoRedirectOutputMark = mark + return nil +} + +func (r *NetworkManager) AutoRedirectOutputMark() uint32 { + return r.autoRedirectOutputMark +} + +func (r *NetworkManager) NetworkMonitor() tun.NetworkUpdateMonitor { + return r.networkMonitor +} + +func (r *NetworkManager) InterfaceMonitor() tun.DefaultInterfaceMonitor { + return r.interfaceMonitor +} + +func (r *NetworkManager) PackageManager() tun.PackageManager { + return r.packageManager +} + +func (r *NetworkManager) WIFIState() adapter.WIFIState { + return r.wifiState +} + +func (r *NetworkManager) ResetNetwork() { + conntrack.Close() + + for _, outbound := range r.outboundManager.Outbounds() { + listener, isListener := outbound.(adapter.InterfaceUpdateListener) + if isListener { + listener.InterfaceUpdated() + } + } +} + +func (r *NetworkManager) notifyNetworkUpdate(event int) { + if event == tun.EventNoRoute { + r.pauseManager.NetworkPause() + r.logger.Error("missing default interface") + } else { + r.pauseManager.NetworkWake() + if C.IsAndroid && r.platformInterface == nil { + var vpnStatus string + if r.interfaceMonitor.AndroidVPNEnabled() { + vpnStatus = "enabled" + } else { + vpnStatus = "disabled" + } + r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified()), ", vpn ", vpnStatus) + } else { + r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified())) + } + if r.platformInterface != nil { + state := r.platformInterface.ReadWIFIState() + if state != r.wifiState { + r.wifiState = state + if state.SSID == "" && state.BSSID == "" { + r.logger.Info("updated WIFI state: disconnected") + } else { + r.logger.Info("updated WIFI state: SSID=", state.SSID, ", BSSID=", state.BSSID) + } + } + } + } + + if !r.started { + return + } + + r.ResetNetwork() +} + +func (r *NetworkManager) notifyWindowsPowerEvent(event int) { + switch event { + case winpowrprof.EVENT_SUSPEND: + r.pauseManager.DevicePause() + r.ResetNetwork() + case winpowrprof.EVENT_RESUME: + if !r.pauseManager.IsDevicePaused() { + return + } + fallthrough + case winpowrprof.EVENT_RESUME_AUTOMATIC: + r.pauseManager.DeviceWake() + r.ResetNetwork() + } +} + +func (r *NetworkManager) OnPackagesUpdated(packages int, sharedUsers int) { + r.logger.Info("updated packages list: ", packages, " packages, ", sharedUsers, " shared users") +} diff --git a/route/route.go b/route/route.go index 8fcc3f92..3897e4c6 100644 --- a/route/route.go +++ b/route/route.go @@ -58,8 +58,8 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad if metadata.LastInbound == metadata.InboundDetour { return E.New("routing loop on detour: ", metadata.InboundDetour) } - detour := r.inboundByTag[metadata.InboundDetour] - if detour == nil { + detour, loaded := r.inboundManager.Get(metadata.InboundDetour) + if !loaded { return E.New("inbound detour not found: ", metadata.InboundDetour) } injectable, isInjectable := detour.(adapter.TCPInjectableInbound) @@ -100,7 +100,7 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad if selectedRule != nil { switch action := selectedRule.Action().(type) { case *rule.RuleActionRoute: - selectedOutbound, loaded := r.Outbound(action.Outbound) + selectedOutbound, loaded := r.outboundManager.Outbound(action.Outbound) if !loaded { buf.ReleaseMulti(buffers) return E.New("outbound not found: ", action.Outbound) @@ -128,13 +128,14 @@ func (r *Router) routeConnection(ctx context.Context, conn net.Conn, metadata ad } } if selectedRule == nil { - if r.defaultOutboundForConnection == nil { + defaultOutbound := r.outboundManager.Default() + if !common.Contains(defaultOutbound.Network(), N.NetworkTCP) { buf.ReleaseMulti(buffers) - return E.New("missing default outbound with TCP support") + return E.New("TCP is not supported by default outbound: ", defaultOutbound.Tag()) } - selectedDialer = r.defaultOutboundForConnection - selectedTag = r.defaultOutboundForConnection.Tag() - selectedDescription = F.ToString("outbound/", r.defaultOutboundForConnection.Type(), "[", r.defaultOutboundForConnection.Tag(), "]") + selectedDialer = defaultOutbound + selectedTag = defaultOutbound.Tag() + selectedDescription = F.ToString("outbound/", defaultOutbound.Type(), "[", defaultOutbound.Tag(), "]") } for _, buffer := range buffers { @@ -217,8 +218,8 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m if metadata.LastInbound == metadata.InboundDetour { return E.New("routing loop on detour: ", metadata.InboundDetour) } - detour := r.inboundByTag[metadata.InboundDetour] - if detour == nil { + detour, loaded := r.inboundManager.Get(metadata.InboundDetour) + if !loaded { return E.New("inbound detour not found: ", metadata.InboundDetour) } injectable, isInjectable := detour.(adapter.UDPInjectableInbound) @@ -254,7 +255,7 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m if selectedRule != nil { switch action := selectedRule.Action().(type) { case *rule.RuleActionRoute: - selectedOutbound, loaded := r.Outbound(action.Outbound) + selectedOutbound, loaded := r.outboundManager.Outbound(action.Outbound) if !loaded { N.ReleaseMultiPacketBuffer(packetBuffers) return E.New("outbound not found: ", action.Outbound) @@ -279,13 +280,14 @@ func (r *Router) routePacketConnection(ctx context.Context, conn N.PacketConn, m } } if selectedRule == nil || selectReturn { - if r.defaultOutboundForPacketConnection == nil { + defaultOutbound := r.outboundManager.Default() + if !common.Contains(defaultOutbound.Network(), N.NetworkUDP) { N.ReleaseMultiPacketBuffer(packetBuffers) - return E.New("missing default outbound with UDP support") + return E.New("UDP is not supported by outbound: ", defaultOutbound.Tag()) } - selectedDialer = r.defaultOutboundForPacketConnection - selectedTag = r.defaultOutboundForPacketConnection.Tag() - selectedDescription = F.ToString("outbound/", r.defaultOutboundForPacketConnection.Type(), "[", r.defaultOutboundForPacketConnection.Tag(), "]") + selectedDialer = defaultOutbound + selectedTag = defaultOutbound.Tag() + selectedDescription = F.ToString("outbound/", defaultOutbound.Type(), "[", defaultOutbound.Tag(), "]") } for _, buffer := range packetBuffers { conn = bufio.NewCachedPacketConn(conn, buffer.Buffer, buffer.Destination) diff --git a/route/router.go b/route/router.go index 557b043a..62b37447 100644 --- a/route/router.go +++ b/route/router.go @@ -2,17 +2,14 @@ package route import ( "context" - "errors" "net/netip" "net/url" "os" "runtime" "strings" - "syscall" "time" "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/conntrack" "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/geoip" "github.com/sagernet/sing-box/common/geosite" @@ -25,16 +22,13 @@ import ( R "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing-box/transport/fakeip" "github.com/sagernet/sing-dns" - "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common" - "github.com/sagernet/sing/common/control" E "github.com/sagernet/sing/common/exceptions" F "github.com/sagernet/sing/common/format" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/task" - "github.com/sagernet/sing/common/winpowrprof" "github.com/sagernet/sing/service" "github.com/sagernet/sing/service/pause" ) @@ -42,69 +36,50 @@ import ( var _ adapter.Router = (*Router)(nil) type Router struct { - ctx context.Context - logger log.ContextLogger - dnsLogger log.ContextLogger - inboundByTag map[string]adapter.Inbound - outbounds []adapter.Outbound - outboundByTag map[string]adapter.Outbound - rules []adapter.Rule - defaultDetour string - defaultOutboundForConnection adapter.Outbound - defaultOutboundForPacketConnection adapter.Outbound - needGeoIPDatabase bool - needGeositeDatabase bool - geoIPOptions option.GeoIPOptions - geositeOptions option.GeositeOptions - geoIPReader *geoip.Reader - geositeReader *geosite.Reader - geositeCache map[string]adapter.Rule - needFindProcess bool - dnsClient *dns.Client - defaultDomainStrategy dns.DomainStrategy - dnsRules []adapter.DNSRule - ruleSets []adapter.RuleSet - ruleSetMap map[string]adapter.RuleSet - defaultTransport dns.Transport - transports []dns.Transport - transportMap map[string]dns.Transport - transportDomainStrategy map[dns.Transport]dns.DomainStrategy - dnsReverseMapping *DNSReverseMapping - fakeIPStore adapter.FakeIPStore - interfaceFinder *control.DefaultInterfaceFinder - autoDetectInterface bool - defaultInterface string - defaultMark uint32 - autoRedirectOutputMark uint32 - networkMonitor tun.NetworkUpdateMonitor - interfaceMonitor tun.DefaultInterfaceMonitor - packageManager tun.PackageManager - powerListener winpowrprof.EventListener - processSearcher process.Searcher - timeService *ntp.Service - pauseManager pause.Manager - clashServer adapter.ClashServer - v2rayServer adapter.V2RayServer - platformInterface platform.Interface - needWIFIState bool - enforcePackageManager bool - wifiState adapter.WIFIState - started bool + ctx context.Context + logger log.ContextLogger + dnsLogger log.ContextLogger + inboundManager adapter.InboundManager + outboundManager adapter.OutboundManager + networkManager adapter.NetworkManager + rules []adapter.Rule + needGeoIPDatabase bool + needGeositeDatabase bool + geoIPOptions option.GeoIPOptions + geositeOptions option.GeositeOptions + geoIPReader *geoip.Reader + geositeReader *geosite.Reader + geositeCache map[string]adapter.Rule + needFindProcess bool + dnsClient *dns.Client + defaultDomainStrategy dns.DomainStrategy + dnsRules []adapter.DNSRule + ruleSets []adapter.RuleSet + ruleSetMap map[string]adapter.RuleSet + defaultTransport dns.Transport + transports []dns.Transport + transportMap map[string]dns.Transport + transportDomainStrategy map[dns.Transport]dns.DomainStrategy + dnsReverseMapping *DNSReverseMapping + fakeIPStore adapter.FakeIPStore + processSearcher process.Searcher + timeService *ntp.Service + pauseManager pause.Manager + clashServer adapter.ClashServer + v2rayServer adapter.V2RayServer + platformInterface platform.Interface + needWIFIState bool + started bool } -func NewRouter( - ctx context.Context, - logFactory log.Factory, - options option.RouteOptions, - dnsOptions option.DNSOptions, - ntpOptions option.NTPOptions, - inbounds []option.Inbound, -) (*Router, error) { +func NewRouter(ctx context.Context, logFactory log.Factory, options option.RouteOptions, dnsOptions option.DNSOptions, ntpOptions option.NTPOptions) (*Router, error) { router := &Router{ ctx: ctx, logger: logFactory.NewLogger("router"), dnsLogger: logFactory.NewLogger("dns"), - outboundByTag: make(map[string]adapter.Outbound), + inboundManager: service.FromContext[adapter.InboundManager](ctx), + outboundManager: service.FromContext[adapter.OutboundManager](ctx), + networkManager: service.FromContext[adapter.NetworkManager](ctx), rules: make([]adapter.Rule, 0, len(options.Rules)), dnsRules: make([]adapter.DNSRule, 0, len(dnsOptions.Rules)), ruleSetMap: make(map[string]adapter.RuleSet), @@ -114,22 +89,12 @@ func NewRouter( geositeOptions: common.PtrValueOrDefault(options.Geosite), geositeCache: make(map[string]adapter.Rule), needFindProcess: hasRule(options.Rules, isProcessRule) || hasDNSRule(dnsOptions.Rules, isProcessDNSRule) || options.FindProcess, - defaultDetour: options.Final, defaultDomainStrategy: dns.DomainStrategy(dnsOptions.Strategy), - interfaceFinder: control.NewDefaultInterfaceFinder(), - autoDetectInterface: options.AutoDetectInterface, - defaultInterface: options.DefaultInterface, - defaultMark: options.DefaultMark, pauseManager: service.FromContext[pause.Manager](ctx), platformInterface: service.FromContext[platform.Interface](ctx), needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule), - enforcePackageManager: common.Any(inbounds, func(inbound option.Inbound) bool { - if tunOptions, isTUN := inbound.Options.(*option.TunInboundOptions); isTUN && tunOptions.AutoRoute { - return true - } - return false - }), } + ctx = service.ContextWith[adapter.Router](ctx, router) router.dnsClient = dns.NewClient(dns.ClientOptions{ DisableCache: dnsOptions.DNSClientOptions.DisableCache, DisableExpire: dnsOptions.DNSClientOptions.DisableExpire, @@ -147,14 +112,14 @@ func NewRouter( Logger: router.dnsLogger, }) for i, ruleOptions := range options.Rules { - routeRule, err := R.NewRule(ctx, router, router.logger, ruleOptions, true) + routeRule, err := R.NewRule(ctx, router.logger, ruleOptions, true) if err != nil { return nil, E.Cause(err, "parse rule[", i, "]") } router.rules = append(router.rules, routeRule) } for i, dnsRuleOptions := range dnsOptions.Rules { - dnsRule, err := R.NewDNSRule(ctx, router, router.logger, dnsRuleOptions, true) + dnsRule, err := R.NewDNSRule(ctx, router.logger, dnsRuleOptions, true) if err != nil { return nil, E.Cause(err, "parse dns rule[", i, "]") } @@ -164,7 +129,7 @@ func NewRouter( if _, exists := router.ruleSetMap[ruleSetOptions.Tag]; exists { return nil, E.New("duplicate rule-set tag: ", ruleSetOptions.Tag) } - ruleSet, err := R.NewRuleSet(ctx, router, router.logger, ruleSetOptions) + ruleSet, err := R.NewRuleSet(ctx, router.logger, ruleSetOptions) if err != nil { return nil, E.Cause(err, "parse rule-set[", i, "]") } @@ -191,7 +156,6 @@ func NewRouter( transportTags[i] = tag transportTagMap[tag] = true } - ctx = service.ContextWith[adapter.Router](ctx, router) outboundManager := service.FromContext[adapter.OutboundManager](ctx) for { lastLen := len(dummyTransportMap) @@ -298,7 +262,7 @@ func NewRouter( Context: ctx, Name: "local", Address: "local", - Dialer: common.Must1(dialer.NewDefault(router, option.DialerOptions{})), + Dialer: common.Must1(dialer.NewDefault(router.networkManager, option.DialerOptions{})), }))) } defaultTransport = transports[0] @@ -327,44 +291,6 @@ func NewRouter( router.fakeIPStore = fakeip.NewStore(ctx, router.logger, inet4Range, inet6Range) } - usePlatformDefaultInterfaceMonitor := router.platformInterface != nil && router.platformInterface.UsePlatformDefaultInterfaceMonitor() - enforceInterfaceMonitor := options.AutoDetectInterface || common.Any(inbounds, func(inbound option.Inbound) bool { - if httpMixedOptions, isHTTPMixed := inbound.Options.(*option.HTTPMixedInboundOptions); isHTTPMixed && httpMixedOptions.SetSystemProxy { - return true - } - if tunOptions, isTUN := inbound.Options.(*option.TunInboundOptions); isTUN && tunOptions.AutoRoute { - return true - } - return false - }) - - if !usePlatformDefaultInterfaceMonitor { - networkMonitor, err := tun.NewNetworkUpdateMonitor(router.logger) - if !((err != nil && !enforceInterfaceMonitor) || errors.Is(err, os.ErrInvalid)) { - if err != nil { - return nil, err - } - router.networkMonitor = networkMonitor - networkMonitor.RegisterCallback(func() { - _ = router.interfaceFinder.Update() - }) - interfaceMonitor, err := tun.NewDefaultInterfaceMonitor(router.networkMonitor, router.logger, tun.DefaultInterfaceMonitorOptions{ - InterfaceFinder: router.interfaceFinder, - OverrideAndroidVPN: options.OverrideAndroidVPN, - UnderNetworkExtension: router.platformInterface != nil && router.platformInterface.UnderNetworkExtension(), - }) - if err != nil { - return nil, E.New("auto_detect_interface unsupported on current platform") - } - interfaceMonitor.RegisterCallback(router.notifyNetworkUpdate) - router.interfaceMonitor = interfaceMonitor - } - } else { - interfaceMonitor := router.platformInterface.CreateDefaultInterfaceMonitor(router.logger) - interfaceMonitor.RegisterCallback(router.notifyNetworkUpdate) - router.interfaceMonitor = interfaceMonitor - } - if ntpOptions.Enabled { ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions) if err != nil { @@ -384,33 +310,10 @@ func NewRouter( return router, nil } -func (r *Router) Outbounds() []adapter.Outbound { - if !r.started { - return nil - } - return r.outbounds -} - func (r *Router) Start(stage adapter.StartStage) error { monitor := taskmonitor.New(r.logger, C.StartTimeout) switch stage { case adapter.StartStateInitialize: - if r.interfaceMonitor != nil { - monitor.Start("initialize interface monitor") - err := r.interfaceMonitor.Start() - monitor.Finish() - if err != nil { - return err - } - } - if r.networkMonitor != nil { - monitor.Start("initialize network monitor") - err := r.networkMonitor.Start() - monitor.Finish() - if err != nil { - return err - } - } if r.fakeIPStore != nil { monitor.Start("initialize fakeip store") err := r.fakeIPStore.Start() @@ -457,49 +360,10 @@ func (r *Router) Start(stage adapter.StartStage) error { r.geositeReader = nil } - if runtime.GOOS == "windows" { - powerListener, err := winpowrprof.NewEventListener(r.notifyWindowsPowerEvent) - if err == nil { - r.powerListener = powerListener - } else { - r.logger.Warn("initialize power listener: ", err) - } - } - - if r.powerListener != nil { - monitor.Start("start power listener") - err := r.powerListener.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "start power listener") - } - } - monitor.Start("initialize DNS client") r.dnsClient.Start() monitor.Finish() - if C.IsAndroid && r.platformInterface == nil { - monitor.Start("initialize package manager") - packageManager, err := tun.NewPackageManager(tun.PackageManagerOptions{ - Callback: r, - Logger: r.logger, - }) - monitor.Finish() - if err != nil { - return E.Cause(err, "create package manager") - } - if r.enforcePackageManager { - monitor.Start("start package manager") - err = packageManager.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "start package manager") - } - } - r.packageManager = packageManager - } - for i, rule := range r.dnsRules { monitor.Start("initialize DNS rule[", i, "]") err := rule.Start() @@ -552,26 +416,13 @@ func (r *Router) Start(stage adapter.StartStage) error { cacheContext.Close() } needFindProcess := r.needFindProcess - needWIFIState := r.needWIFIState for _, ruleSet := range r.ruleSets { metadata := ruleSet.Metadata() if metadata.ContainsProcessRule { needFindProcess = true } if metadata.ContainsWIFIRule { - needWIFIState = true - } - } - if C.IsAndroid && r.platformInterface == nil && !r.enforcePackageManager { - if needFindProcess { - monitor.Start("start package manager") - err := r.packageManager.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "start package manager") - } - } else { - r.packageManager = nil + r.needWIFIState = true } } if needFindProcess { @@ -581,7 +432,7 @@ func (r *Router) Start(stage adapter.StartStage) error { monitor.Start("initialize process searcher") searcher, err := process.NewSearcher(process.Config{ Logger: r.logger, - PackageManager: r.packageManager, + PackageManager: r.networkManager.PackageManager(), }) monitor.Finish() if err != nil { @@ -593,15 +444,6 @@ func (r *Router) Start(stage adapter.StartStage) error { } } } - if needWIFIState && r.platformInterface != nil { - monitor.Start("initialize WIFI state") - r.needWIFIState = true - r.interfaceMonitor.RegisterCallback(func(_ int) { - r.updateWIFIState() - }) - r.updateWIFIState() - monitor.Finish() - } for i, rule := range r.rules { monitor.Start("initialize rule[", i, "]") err := rule.Start() @@ -660,34 +502,6 @@ func (r *Router) Close() error { }) monitor.Finish() } - if r.interfaceMonitor != nil { - monitor.Start("close interface monitor") - err = E.Append(err, r.interfaceMonitor.Close(), func(err error) error { - return E.Cause(err, "close interface monitor") - }) - monitor.Finish() - } - if r.networkMonitor != nil { - monitor.Start("close network monitor") - err = E.Append(err, r.networkMonitor.Close(), func(err error) error { - return E.Cause(err, "close network monitor") - }) - monitor.Finish() - } - if r.packageManager != nil { - monitor.Start("close package manager") - err = E.Append(err, r.packageManager.Close(), func(err error) error { - return E.Cause(err, "close package manager") - }) - monitor.Finish() - } - if r.powerListener != nil { - monitor.Start("close power listener") - err = E.Append(err, r.powerListener.Close(), func(err error) error { - return E.Cause(err, "close power listener") - }) - monitor.Finish() - } if r.timeService != nil { monitor.Start("close time service") err = E.Append(err, r.timeService.Close(), func(err error) error { @@ -705,25 +519,6 @@ func (r *Router) Close() error { return err } -func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { - outbound, loaded := r.outboundByTag[tag] - return outbound, loaded -} - -func (r *Router) DefaultOutbound(network string) (adapter.Outbound, error) { - if network == N.NetworkTCP { - if r.defaultOutboundForConnection == nil { - return nil, E.New("missing default outbound for TCP connections") - } - return r.defaultOutboundForConnection, nil - } else { - if r.defaultOutboundForPacketConnection == nil { - return nil, E.New("missing default outbound for UDP connections") - } - return r.defaultOutboundForPacketConnection, nil - } -} - func (r *Router) FakeIPStore() adapter.FakeIPStore { return r.fakeIPStore } @@ -737,96 +532,10 @@ func (r *Router) NeedWIFIState() bool { return r.needWIFIState } -func (r *Router) InterfaceFinder() control.InterfaceFinder { - return r.interfaceFinder -} - -func (r *Router) UpdateInterfaces() error { - if r.platformInterface == nil || !r.platformInterface.UsePlatformInterfaceGetter() { - return r.interfaceFinder.Update() - } else { - interfaces, err := r.platformInterface.Interfaces() - if err != nil { - return err - } - r.interfaceFinder.UpdateInterfaces(interfaces) - return nil - } -} - -func (r *Router) AutoDetectInterface() bool { - return r.autoDetectInterface -} - -func (r *Router) AutoDetectInterfaceFunc() control.Func { - if r.platformInterface != nil && r.platformInterface.UsePlatformAutoDetectInterfaceControl() { - return func(network, address string, conn syscall.RawConn) error { - return control.Raw(conn, func(fd uintptr) error { - return r.platformInterface.AutoDetectInterfaceControl(int(fd)) - }) - } - } else { - if r.interfaceMonitor == nil { - return nil - } - return control.BindToInterfaceFunc(r.InterfaceFinder(), func(network string, address string) (interfaceName string, interfaceIndex int, err error) { - remoteAddr := M.ParseSocksaddr(address).Addr - if C.IsLinux { - interfaceName, interfaceIndex = r.InterfaceMonitor().DefaultInterface(remoteAddr) - if interfaceIndex == -1 { - err = tun.ErrNoRoute - } - } else { - interfaceIndex = r.InterfaceMonitor().DefaultInterfaceIndex(remoteAddr) - if interfaceIndex == -1 { - err = tun.ErrNoRoute - } - } - return - }) - } -} - -func (r *Router) RegisterAutoRedirectOutputMark(mark uint32) error { - if r.autoRedirectOutputMark > 0 { - return E.New("only one auto-redirect can be configured") - } - r.autoRedirectOutputMark = mark - return nil -} - -func (r *Router) AutoRedirectOutputMark() uint32 { - return r.autoRedirectOutputMark -} - -func (r *Router) DefaultInterface() string { - return r.defaultInterface -} - -func (r *Router) DefaultMark() uint32 { - return r.defaultMark -} - func (r *Router) Rules() []adapter.Rule { return r.rules } -func (r *Router) WIFIState() adapter.WIFIState { - return r.wifiState -} - -func (r *Router) NetworkMonitor() tun.NetworkUpdateMonitor { - return r.networkMonitor -} - -func (r *Router) InterfaceMonitor() tun.DefaultInterfaceMonitor { - return r.interfaceMonitor -} - -func (r *Router) PackageManager() tun.PackageManager { - return r.packageManager -} - func (r *Router) ClashServer() adapter.ClashServer { return r.clashServer } @@ -843,10 +552,6 @@ func (r *Router) SetV2RayServer(server adapter.V2RayServer) { r.v2rayServer = server } -func (r *Router) OnPackagesUpdated(packages int, sharedUsers int) { - r.logger.Info("updated packages list: ", packages, " packages, ", sharedUsers, " shared users") -} - func (r *Router) NewError(ctx context.Context, err error) { common.Close(err) if E.IsClosedOrCanceled(err) { @@ -856,75 +561,9 @@ func (r *Router) NewError(ctx context.Context, err error) { r.logger.ErrorContext(ctx, err) } -func (r *Router) notifyNetworkUpdate(event int) { - if event == tun.EventNoRoute { - r.pauseManager.NetworkPause() - r.logger.Error("missing default interface") - } else { - r.pauseManager.NetworkWake() - if C.IsAndroid && r.platformInterface == nil { - var vpnStatus string - if r.interfaceMonitor.AndroidVPNEnabled() { - vpnStatus = "enabled" - } else { - vpnStatus = "disabled" - } - r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified()), ", vpn ", vpnStatus) - } else { - r.logger.Info("updated default interface ", r.interfaceMonitor.DefaultInterfaceName(netip.IPv4Unspecified()), ", index ", r.interfaceMonitor.DefaultInterfaceIndex(netip.IPv4Unspecified())) - } - } - - if !r.started { - return - } - - _ = r.ResetNetwork() -} - -func (r *Router) ResetNetwork() error { - conntrack.Close() - - for _, outbound := range r.outbounds { - listener, isListener := outbound.(adapter.InterfaceUpdateListener) - if isListener { - listener.InterfaceUpdated() - } - } - +func (r *Router) ResetNetwork() { + r.networkManager.ResetNetwork() for _, transport := range r.transports { transport.Reset() } - return nil -} - -func (r *Router) updateWIFIState() { - if r.platformInterface == nil { - return - } - state := r.platformInterface.ReadWIFIState() - if state != r.wifiState { - r.wifiState = state - if state.SSID == "" && state.BSSID == "" { - r.logger.Info("updated WIFI state: disconnected") - } else { - r.logger.Info("updated WIFI state: SSID=", state.SSID, ", BSSID=", state.BSSID) - } - } -} - -func (r *Router) notifyWindowsPowerEvent(event int) { - switch event { - case winpowrprof.EVENT_SUSPEND: - r.pauseManager.DevicePause() - _ = r.ResetNetwork() - case winpowrprof.EVENT_RESUME: - if !r.pauseManager.IsDevicePaused() { - return - } - fallthrough - case winpowrprof.EVENT_RESUME_AUTOMATIC: - r.pauseManager.DeviceWake() - _ = r.ResetNetwork() - } } diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index a12c63ef..33a8e16c 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -9,9 +9,10 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/service" ) -func NewRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.Rule, checkOutbound bool) (adapter.Rule, error) { +func NewRule(ctx context.Context, logger log.ContextLogger, options option.Rule, checkOutbound bool) (adapter.Rule, error) { switch options.Type { case "", C.RuleTypeDefault: if !options.DefaultOptions.IsValid() { @@ -23,7 +24,7 @@ func NewRule(ctx context.Context, router adapter.Router, logger log.ContextLogge return nil, E.New("missing outbound field") } } - return NewDefaultRule(ctx, router, logger, options.DefaultOptions) + return NewDefaultRule(ctx, logger, options.DefaultOptions) case C.RuleTypeLogical: if !options.LogicalOptions.IsValid() { return nil, E.New("missing conditions") @@ -34,7 +35,7 @@ func NewRule(ctx context.Context, router adapter.Router, logger log.ContextLogge return nil, E.New("missing outbound field") } } - return NewLogicalRule(ctx, router, logger, options.LogicalOptions) + return NewLogicalRule(ctx, logger, options.LogicalOptions) default: return nil, E.New("unknown rule type: ", options.Type) } @@ -51,7 +52,7 @@ type RuleItem interface { String() string } -func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { +func NewDefaultRule(ctx context.Context, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { action, err := NewRuleAction(ctx, logger, options.RuleAction) if err != nil { return nil, E.Cause(err, "action") @@ -62,6 +63,8 @@ func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.Conte action: action, }, } + router := service.FromContext[adapter.Router](ctx) + networkManager := service.FromContext[adapter.NetworkManager](ctx) if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) rule.items = append(rule.items, item) @@ -221,12 +224,12 @@ func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.Conte rule.allItems = append(rule.allItems, item) } if len(options.WIFISSID) > 0 { - item := NewWIFISSIDItem(router, options.WIFISSID) + item := NewWIFISSIDItem(networkManager, options.WIFISSID) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } if len(options.WIFIBSSID) > 0 { - item := NewWIFIBSSIDItem(router, options.WIFIBSSID) + item := NewWIFIBSSIDItem(networkManager, options.WIFIBSSID) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } @@ -253,7 +256,7 @@ type LogicalRule struct { abstractLogicalRule } -func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { +func NewLogicalRule(ctx context.Context, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { action, err := NewRuleAction(ctx, logger, options.RuleAction) if err != nil { return nil, E.Cause(err, "action") @@ -274,7 +277,7 @@ func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.Conte return nil, E.New("unknown logical mode: ", options.Mode) } for i, subOptions := range options.Rules { - subRule, err := NewRule(ctx, router, logger, subOptions, false) + subRule, err := NewRule(ctx, logger, subOptions, false) if err != nil { return nil, E.Cause(err, "sub rule[", i, "]") } diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index 2218f6a3..df5f3f33 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -10,9 +10,10 @@ import ( "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/service" ) -func NewDNSRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DNSRule, checkServer bool) (adapter.DNSRule, error) { +func NewDNSRule(ctx context.Context, logger log.ContextLogger, options option.DNSRule, checkServer bool) (adapter.DNSRule, error) { switch options.Type { case "", C.RuleTypeDefault: if !options.DefaultOptions.IsValid() { @@ -24,7 +25,7 @@ func NewDNSRule(ctx context.Context, router adapter.Router, logger log.ContextLo return nil, E.New("missing server field") } } - return NewDefaultDNSRule(ctx, router, logger, options.DefaultOptions) + return NewDefaultDNSRule(ctx, logger, options.DefaultOptions) case C.RuleTypeLogical: if !options.LogicalOptions.IsValid() { return nil, E.New("missing conditions") @@ -35,7 +36,7 @@ func NewDNSRule(ctx context.Context, router adapter.Router, logger log.ContextLo return nil, E.New("missing server field") } } - return NewLogicalDNSRule(ctx, router, logger, options.LogicalOptions) + return NewLogicalDNSRule(ctx, logger, options.LogicalOptions) default: return nil, E.New("unknown rule type: ", options.Type) } @@ -47,7 +48,7 @@ type DefaultDNSRule struct { abstractDefaultRule } -func NewDefaultDNSRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { +func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options option.DefaultDNSRule) (*DefaultDNSRule, error) { rule := &DefaultDNSRule{ abstractDefaultRule: abstractDefaultRule{ invert: options.Invert, @@ -59,6 +60,8 @@ func NewDefaultDNSRule(ctx context.Context, router adapter.Router, logger log.Co rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } + router := service.FromContext[adapter.Router](ctx) + networkManager := service.FromContext[adapter.NetworkManager](ctx) if options.IPVersion > 0 { switch options.IPVersion { case 4, 6: @@ -218,12 +221,12 @@ func NewDefaultDNSRule(ctx context.Context, router adapter.Router, logger log.Co rule.allItems = append(rule.allItems, item) } if len(options.WIFISSID) > 0 { - item := NewWIFISSIDItem(router, options.WIFISSID) + item := NewWIFISSIDItem(networkManager, options.WIFISSID) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } if len(options.WIFIBSSID) > 0 { - item := NewWIFIBSSIDItem(router, options.WIFIBSSID) + item := NewWIFIBSSIDItem(networkManager, options.WIFIBSSID) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } @@ -282,7 +285,7 @@ type LogicalDNSRule struct { abstractLogicalRule } -func NewLogicalDNSRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { +func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options option.LogicalDNSRule) (*LogicalDNSRule, error) { r := &LogicalDNSRule{ abstractLogicalRule: abstractLogicalRule{ rules: make([]adapter.HeadlessRule, len(options.Rules)), @@ -299,7 +302,7 @@ func NewLogicalDNSRule(ctx context.Context, router adapter.Router, logger log.Co return nil, E.New("unknown logical mode: ", options.Mode) } for i, subRule := range options.Rules { - rule, err := NewDNSRule(ctx, router, logger, subRule, false) + rule, err := NewDNSRule(ctx, logger, subRule, false) if err != nil { return nil, E.Cause(err, "sub rule[", i, "]") } diff --git a/route/rule/rule_headless.go b/route/rule/rule_headless.go index 9ea357af..99488b20 100644 --- a/route/rule/rule_headless.go +++ b/route/rule/rule_headless.go @@ -1,24 +1,27 @@ package rule import ( + "context" + "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/option" E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/service" ) -func NewHeadlessRule(router adapter.Router, options option.HeadlessRule) (adapter.HeadlessRule, error) { +func NewHeadlessRule(ctx context.Context, options option.HeadlessRule) (adapter.HeadlessRule, error) { switch options.Type { case "", C.RuleTypeDefault: if !options.DefaultOptions.IsValid() { return nil, E.New("missing conditions") } - return NewDefaultHeadlessRule(router, options.DefaultOptions) + return NewDefaultHeadlessRule(ctx, options.DefaultOptions) case C.RuleTypeLogical: if !options.LogicalOptions.IsValid() { return nil, E.New("missing conditions") } - return NewLogicalHeadlessRule(router, options.LogicalOptions) + return NewLogicalHeadlessRule(ctx, options.LogicalOptions) default: return nil, E.New("unknown rule type: ", options.Type) } @@ -30,7 +33,8 @@ type DefaultHeadlessRule struct { abstractDefaultRule } -func NewDefaultHeadlessRule(router adapter.Router, options option.DefaultHeadlessRule) (*DefaultHeadlessRule, error) { +func NewDefaultHeadlessRule(ctx context.Context, options option.DefaultHeadlessRule) (*DefaultHeadlessRule, error) { + networkManager := service.FromContext[adapter.NetworkManager](ctx) rule := &DefaultHeadlessRule{ abstractDefaultRule{ invert: options.Invert, @@ -137,15 +141,15 @@ func NewDefaultHeadlessRule(router adapter.Router, options option.DefaultHeadles rule.allItems = append(rule.allItems, item) } if len(options.WIFISSID) > 0 { - if router != nil { - item := NewWIFISSIDItem(router, options.WIFISSID) + if networkManager != nil { + item := NewWIFISSIDItem(networkManager, options.WIFISSID) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } } if len(options.WIFIBSSID) > 0 { - if router != nil { - item := NewWIFIBSSIDItem(router, options.WIFIBSSID) + if networkManager != nil { + item := NewWIFIBSSIDItem(networkManager, options.WIFIBSSID) rule.items = append(rule.items, item) rule.allItems = append(rule.allItems, item) } @@ -168,7 +172,7 @@ type LogicalHeadlessRule struct { abstractLogicalRule } -func NewLogicalHeadlessRule(router adapter.Router, options option.LogicalHeadlessRule) (*LogicalHeadlessRule, error) { +func NewLogicalHeadlessRule(ctx context.Context, options option.LogicalHeadlessRule) (*LogicalHeadlessRule, error) { r := &LogicalHeadlessRule{ abstractLogicalRule{ rules: make([]adapter.HeadlessRule, len(options.Rules)), @@ -184,7 +188,7 @@ func NewLogicalHeadlessRule(router adapter.Router, options option.LogicalHeadles return nil, E.New("unknown logical mode: ", options.Mode) } for i, subRule := range options.Rules { - rule, err := NewHeadlessRule(router, subRule) + rule, err := NewHeadlessRule(ctx, subRule) if err != nil { return nil, E.Cause(err, "sub rule[", i, "]") } diff --git a/route/rule/rule_item_wifi_bssid.go b/route/rule/rule_item_wifi_bssid.go index ae94bd6d..8f887322 100644 --- a/route/rule/rule_item_wifi_bssid.go +++ b/route/rule/rule_item_wifi_bssid.go @@ -10,12 +10,12 @@ import ( var _ RuleItem = (*WIFIBSSIDItem)(nil) type WIFIBSSIDItem struct { - bssidList []string - bssidMap map[string]bool - router adapter.Router + bssidList []string + bssidMap map[string]bool + networkManager adapter.NetworkManager } -func NewWIFIBSSIDItem(router adapter.Router, bssidList []string) *WIFIBSSIDItem { +func NewWIFIBSSIDItem(networkManager adapter.NetworkManager, bssidList []string) *WIFIBSSIDItem { bssidMap := make(map[string]bool) for _, bssid := range bssidList { bssidMap[bssid] = true @@ -23,12 +23,12 @@ func NewWIFIBSSIDItem(router adapter.Router, bssidList []string) *WIFIBSSIDItem return &WIFIBSSIDItem{ bssidList, bssidMap, - router, + networkManager, } } func (r *WIFIBSSIDItem) Match(metadata *adapter.InboundContext) bool { - return r.bssidMap[r.router.WIFIState().BSSID] + return r.bssidMap[r.networkManager.WIFIState().BSSID] } func (r *WIFIBSSIDItem) String() string { diff --git a/route/rule/rule_item_wifi_ssid.go b/route/rule/rule_item_wifi_ssid.go index 3a928f77..ab9fdd88 100644 --- a/route/rule/rule_item_wifi_ssid.go +++ b/route/rule/rule_item_wifi_ssid.go @@ -10,12 +10,12 @@ import ( var _ RuleItem = (*WIFISSIDItem)(nil) type WIFISSIDItem struct { - ssidList []string - ssidMap map[string]bool - router adapter.Router + ssidList []string + ssidMap map[string]bool + networkManager adapter.NetworkManager } -func NewWIFISSIDItem(router adapter.Router, ssidList []string) *WIFISSIDItem { +func NewWIFISSIDItem(networkManager adapter.NetworkManager, ssidList []string) *WIFISSIDItem { ssidMap := make(map[string]bool) for _, ssid := range ssidList { ssidMap[ssid] = true @@ -23,12 +23,12 @@ func NewWIFISSIDItem(router adapter.Router, ssidList []string) *WIFISSIDItem { return &WIFISSIDItem{ ssidList, ssidMap, - router, + networkManager, } } func (r *WIFISSIDItem) Match(metadata *adapter.InboundContext) bool { - return r.ssidMap[r.router.WIFIState().SSID] + return r.ssidMap[r.networkManager.WIFIState().SSID] } func (r *WIFISSIDItem) String() string { diff --git a/route/rule/rule_set.go b/route/rule/rule_set.go index cdd0fc0a..c0f307b7 100644 --- a/route/rule/rule_set.go +++ b/route/rule/rule_set.go @@ -13,12 +13,12 @@ import ( "go4.org/netipx" ) -func NewRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) (adapter.RuleSet, error) { +func NewRuleSet(ctx context.Context, logger logger.ContextLogger, options option.RuleSet) (adapter.RuleSet, error) { switch options.Type { case C.RuleSetTypeInline, C.RuleSetTypeLocal, "": - return NewLocalRuleSet(ctx, router, logger, options) + return NewLocalRuleSet(ctx, logger, options) case C.RuleSetTypeRemote: - return NewRemoteRuleSet(ctx, router, logger, options), nil + return NewRemoteRuleSet(ctx, logger, options), nil default: return nil, E.New("unknown rule-set type: ", options.Type) } diff --git a/route/rule/rule_set_local.go b/route/rule/rule_set_local.go index 900c8b24..63678de1 100644 --- a/route/rule/rule_set_local.go +++ b/route/rule/rule_set_local.go @@ -26,7 +26,7 @@ import ( var _ adapter.RuleSet = (*LocalRuleSet)(nil) type LocalRuleSet struct { - router adapter.Router + ctx context.Context logger logger.Logger tag string rules []adapter.HeadlessRule @@ -36,9 +36,9 @@ type LocalRuleSet struct { refs atomic.Int32 } -func NewLocalRuleSet(ctx context.Context, router adapter.Router, logger logger.Logger, options option.RuleSet) (*LocalRuleSet, error) { +func NewLocalRuleSet(ctx context.Context, logger logger.Logger, options option.RuleSet) (*LocalRuleSet, error) { ruleSet := &LocalRuleSet{ - router: router, + ctx: ctx, logger: logger, tag: options.Tag, fileFormat: options.Format, @@ -129,7 +129,7 @@ func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error { rules := make([]adapter.HeadlessRule, len(headlessRules)) var err error for i, ruleOptions := range headlessRules { - rules[i], err = NewHeadlessRule(s.router, ruleOptions) + rules[i], err = NewHeadlessRule(s.ctx, ruleOptions) if err != nil { return E.Cause(err, "parse rule_set.rules.[", i, "]") } diff --git a/route/rule/rule_set_remote.go b/route/rule/rule_set_remote.go index 55c863e9..5ecbcbef 100644 --- a/route/rule/rule_set_remote.go +++ b/route/rule/rule_set_remote.go @@ -35,7 +35,6 @@ var _ adapter.RuleSet = (*RemoteRuleSet)(nil) type RemoteRuleSet struct { ctx context.Context cancel context.CancelFunc - router adapter.Router outboundManager adapter.OutboundManager logger logger.ContextLogger options option.RuleSet @@ -53,7 +52,7 @@ type RemoteRuleSet struct { refs atomic.Int32 } -func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet { +func NewRemoteRuleSet(ctx context.Context, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet { ctx, cancel := context.WithCancel(ctx) var updateInterval time.Duration if options.RemoteOptions.UpdateInterval > 0 { @@ -64,7 +63,6 @@ func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger. return &RemoteRuleSet{ ctx: ctx, cancel: cancel, - router: router, outboundManager: service.FromContext[adapter.OutboundManager](ctx), logger: logger, options: options, @@ -181,7 +179,7 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error { } rules := make([]adapter.HeadlessRule, len(plainRuleSet.Rules)) for i, ruleOptions := range plainRuleSet.Rules { - rules[i], err = NewHeadlessRule(s.router, ruleOptions) + rules[i], err = NewHeadlessRule(s.ctx, ruleOptions) if err != nil { return E.Cause(err, "parse rule_set.rules.[", i, "]") } diff --git a/transport/dhcp/server.go b/transport/dhcp/server.go index d5603e9e..dfe33d86 100644 --- a/transport/dhcp/server.go +++ b/transport/dhcp/server.go @@ -38,6 +38,7 @@ func init() { type Transport struct { options dns.TransportOptions router adapter.Router + networkManager adapter.NetworkManager interfaceName string autoInterface bool interfaceCallback *list.Element[tun.DefaultInterfaceUpdateCallback] @@ -54,15 +55,11 @@ func NewTransport(options dns.TransportOptions) (*Transport, error) { if linkURL.Host == "" { return nil, E.New("missing interface name for DHCP") } - router := service.FromContext[adapter.Router](options.Context) - if router == nil { - return nil, E.New("missing router in context") - } transport := &Transport{ - options: options, - router: router, - interfaceName: linkURL.Host, - autoInterface: linkURL.Host == "auto", + options: options, + networkManager: service.FromContext[adapter.NetworkManager](options.Context), + interfaceName: linkURL.Host, + autoInterface: linkURL.Host == "auto", } return transport, nil } @@ -77,7 +74,7 @@ func (t *Transport) Start() error { return err } if t.autoInterface { - t.interfaceCallback = t.router.InterfaceMonitor().RegisterCallback(t.interfaceUpdated) + t.interfaceCallback = t.networkManager.InterfaceMonitor().RegisterCallback(t.interfaceUpdated) } return nil } @@ -93,7 +90,7 @@ func (t *Transport) Close() error { transport.Close() } if t.interfaceCallback != nil { - t.router.InterfaceMonitor().UnregisterCallback(t.interfaceCallback) + t.networkManager.InterfaceMonitor().UnregisterCallback(t.interfaceCallback) } return nil } @@ -125,10 +122,10 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, func (t *Transport) fetchInterface() (*net.Interface, error) { interfaceName := t.interfaceName if t.autoInterface { - if t.router.InterfaceMonitor() == nil { + if t.networkManager.InterfaceMonitor() == nil { return nil, E.New("missing monitor for auto DHCP, set route.auto_detect_interface") } - interfaceName = t.router.InterfaceMonitor().DefaultInterfaceName(netip.Addr{}) + interfaceName = t.networkManager.InterfaceMonitor().DefaultInterfaceName(netip.Addr{}) } if interfaceName == "" { return nil, E.New("missing default interface") @@ -177,7 +174,7 @@ func (t *Transport) interfaceUpdated(int) { func (t *Transport) fetchServers0(ctx context.Context, iface *net.Interface) error { var listener net.ListenConfig - listener.Control = control.Append(listener.Control, control.BindToInterface(t.router.InterfaceFinder(), iface.Name, iface.Index)) + listener.Control = control.Append(listener.Control, control.BindToInterface(t.networkManager.InterfaceFinder(), iface.Name, iface.Index)) listener.Control = control.Append(listener.Control, control.ReuseAddr()) listenAddr := "0.0.0.0:68" if runtime.GOOS == "linux" || runtime.GOOS == "android" { @@ -255,7 +252,7 @@ func (t *Transport) recreateServers(iface *net.Interface, serverAddrs []netip.Ad return it.String() }), ","), "]") } - serverDialer := common.Must1(dialer.NewDefault(t.router, option.DialerOptions{ + serverDialer := common.Must1(dialer.NewDefault(t.networkManager, option.DialerOptions{ BindInterface: iface.Name, UDPFragmentDefault: true, })) diff --git a/transport/wireguard/device_system.go b/transport/wireguard/device_system.go index 36c270f9..8a54a75e 100644 --- a/transport/wireguard/device_system.go +++ b/transport/wireguard/device_system.go @@ -34,7 +34,7 @@ type SystemDevice struct { closeOnce sync.Once } -func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) { +func NewSystemDevice(networkManager adapter.NetworkManager, interfaceName string, localPrefixes []netip.Prefix, mtu uint32, gso bool) (*SystemDevice, error) { var inet4Addresses []netip.Prefix var inet6Addresses []netip.Prefix for _, prefixes := range localPrefixes { @@ -49,7 +49,7 @@ func NewSystemDevice(router adapter.Router, interfaceName string, localPrefixes } return &SystemDevice{ - dialer: common.Must1(dialer.NewDefault(router, option.DialerOptions{ + dialer: common.Must1(dialer.NewDefault(networkManager, option.DialerOptions{ BindInterface: interfaceName, })), name: interfaceName,