diff --git a/adapter/experimental.go b/adapter/experimental.go index 0cab5ed5..bee24c4f 100644 --- a/adapter/experimental.go +++ b/adapter/experimental.go @@ -15,7 +15,7 @@ import ( type ClashServer interface { Service - PreStarter + LegacyPreStarter Mode() string ModeList() []string HistoryStorage() *urltest.HistoryStorage @@ -25,7 +25,7 @@ type ClashServer interface { type CacheFile interface { Service - PreStarter + LegacyPreStarter StoreFakeIP() bool FakeIPStorage diff --git a/adapter/inbound.go b/adapter/inbound.go index f9ed1708..d80e59f7 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -28,7 +28,15 @@ type UDPInjectableInbound interface { type InboundRegistry interface { option.InboundOptionsRegistry - CreateInbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Inbound, error) + Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, inboundType string, options any) (Inbound, error) +} + +type InboundManager interface { + NewService + Inbounds() []Inbound + Get(tag string) (Inbound, bool) + Remove(tag string) error + Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, inboundType string, options any) error } type InboundContext struct { diff --git a/adapter/inbound/manager.go b/adapter/inbound/manager.go new file mode 100644 index 00000000..d2be4b57 --- /dev/null +++ b/adapter/inbound/manager.go @@ -0,0 +1,143 @@ +package inbound + +import ( + "context" + "os" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/taskmonitor" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" +) + +var _ adapter.InboundManager = (*Manager)(nil) + +type Manager struct { + logger log.ContextLogger + registry adapter.InboundRegistry + access sync.Mutex + started bool + stage adapter.StartStage + inbounds []adapter.Inbound + inboundByTag map[string]adapter.Inbound +} + +func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry) *Manager { + return &Manager{ + logger: logger, + registry: registry, + inboundByTag: make(map[string]adapter.Inbound), + } +} + +func (m *Manager) Start(stage adapter.StartStage) error { + m.access.Lock() + defer m.access.Unlock() + if m.started && m.stage >= stage { + panic("already started") + } + m.started = true + m.stage = stage + for _, inbound := range m.inbounds { + err := adapter.LegacyStart(inbound, stage) + if err != nil { + return E.Cause(err, stage.Action(), " inbound/", inbound.Type(), "[", inbound.Tag(), "]") + } + } + return nil +} + +func (m *Manager) Close() error { + m.access.Lock() + if !m.started { + panic("not started") + } + m.started = false + inbounds := m.inbounds + m.inbounds = nil + m.access.Unlock() + monitor := taskmonitor.New(m.logger, C.StopTimeout) + var err error + for _, inbound := range inbounds { + monitor.Start("close inbound/", inbound.Type(), "[", inbound.Tag(), "]") + err = E.Append(err, inbound.Close(), func(err error) error { + return E.Cause(err, "close inbound/", inbound.Type(), "[", inbound.Tag(), "]") + }) + monitor.Finish() + } + return nil +} + +func (m *Manager) Inbounds() []adapter.Inbound { + m.access.Lock() + defer m.access.Unlock() + return m.inbounds +} + +func (m *Manager) Get(tag string) (adapter.Inbound, bool) { + m.access.Lock() + defer m.access.Unlock() + inbound, found := m.inboundByTag[tag] + return inbound, found +} + +func (m *Manager) Remove(tag string) error { + m.access.Lock() + inbound, found := m.inboundByTag[tag] + if !found { + m.access.Unlock() + return os.ErrInvalid + } + delete(m.inboundByTag, tag) + index := common.Index(m.inbounds, func(it adapter.Inbound) bool { + return it == inbound + }) + if index == -1 { + panic("invalid inbound index") + } + m.inbounds = append(m.inbounds[:index], m.inbounds[index+1:]...) + started := m.started + m.access.Unlock() + if started { + return inbound.Close() + } + return nil +} + +func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error { + inbound, err := m.registry.Create(ctx, router, logger, tag, outboundType, options) + if err != nil { + return err + } + m.access.Lock() + defer m.access.Unlock() + if m.started { + for _, stage := range adapter.ListStartStages { + err = adapter.LegacyStart(inbound, stage) + if err != nil { + return E.Cause(err, stage.Action(), " inbound/", inbound.Type(), "[", inbound.Tag(), "]") + } + } + } + if existsInbound, loaded := m.inboundByTag[tag]; loaded { + if m.started { + err = existsInbound.Close() + if err != nil { + return E.Cause(err, "close inbound/", existsInbound.Type(), "[", existsInbound.Tag(), "]") + } + } + existsIndex := common.Index(m.inbounds, func(it adapter.Inbound) bool { + return it == existsInbound + }) + if existsIndex == -1 { + panic("invalid inbound index") + } + m.inbounds = append(m.inbounds[:existsIndex], m.inbounds[existsIndex+1:]...) + } + m.inbounds = append(m.inbounds, inbound) + m.inboundByTag[tag] = inbound + return nil +} diff --git a/adapter/inbound/registry.go b/adapter/inbound/registry.go index 9f678c90..622e01c7 100644 --- a/adapter/inbound/registry.go +++ b/adapter/inbound/registry.go @@ -28,41 +28,41 @@ type ( ) type Registry struct { - access sync.Mutex - optionsType map[string]optionsConstructorFunc - constructors map[string]constructorFunc + access sync.Mutex + optionsType map[string]optionsConstructorFunc + constructor map[string]constructorFunc } func NewRegistry() *Registry { return &Registry{ - optionsType: make(map[string]optionsConstructorFunc), - constructors: make(map[string]constructorFunc), + optionsType: make(map[string]optionsConstructorFunc), + constructor: make(map[string]constructorFunc), } } -func (r *Registry) CreateOptions(outboundType string) (any, bool) { - r.access.Lock() - defer r.access.Unlock() - optionsConstructor, loaded := r.optionsType[outboundType] +func (m *Registry) CreateOptions(outboundType string) (any, bool) { + m.access.Lock() + defer m.access.Unlock() + optionsConstructor, loaded := m.optionsType[outboundType] if !loaded { return nil, false } return optionsConstructor(), true } -func (r *Registry) CreateInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Inbound, error) { - r.access.Lock() - defer r.access.Unlock() - constructor, loaded := r.constructors[outboundType] +func (m *Registry) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Inbound, error) { + m.access.Lock() + defer m.access.Unlock() + constructor, loaded := m.constructor[outboundType] if !loaded { return nil, E.New("outbound type not found: " + outboundType) } return constructor(ctx, router, logger, tag, options) } -func (r *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) { - r.access.Lock() - defer r.access.Unlock() - r.optionsType[outboundType] = optionsConstructor - r.constructors[outboundType] = constructor +func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) { + m.access.Lock() + defer m.access.Unlock() + m.optionsType[outboundType] = optionsConstructor + m.constructor[outboundType] = constructor } diff --git a/adapter/lifecycle.go b/adapter/lifecycle.go new file mode 100644 index 00000000..85de425d --- /dev/null +++ b/adapter/lifecycle.go @@ -0,0 +1,41 @@ +package adapter + +type StartStage uint8 + +const ( + StartStateInitialize StartStage = iota + StartStateStart + StartStatePostStart + StartStateStarted +) + +var ListStartStages = []StartStage{ + StartStateInitialize, + StartStateStart, + StartStatePostStart, + StartStateStarted, +} + +func (s StartStage) Action() string { + switch s { + case StartStateInitialize: + return "initialize" + case StartStateStart: + return "start" + case StartStatePostStart: + return "post-start" + case StartStateStarted: + return "start-after-started" + default: + panic("unknown stage") + } +} + +type NewService interface { + NewStarter + Close() error +} + +type NewStarter interface { + Start(stage StartStage) error +} diff --git a/adapter/lifecycle_legacy.go b/adapter/lifecycle_legacy.go new file mode 100644 index 00000000..5968131b --- /dev/null +++ b/adapter/lifecycle_legacy.go @@ -0,0 +1,33 @@ +package adapter + +type LegacyPreStarter interface { + PreStart() error +} + +type LegacyPostStarter interface { + PostStart() error +} + +func LegacyStart(starter any, stage StartStage) error { + switch stage { + case StartStateInitialize: + if preStarter, isPreStarter := starter.(interface { + PreStart() error + }); isPreStarter { + return preStarter.PreStart() + } + case StartStateStart: + if starter, isStarter := starter.(interface { + Start() error + }); isStarter { + return starter.Start() + } + case StartStatePostStart: + if postStarter, isPostStarter := starter.(interface { + PostStart() error + }); isPostStarter { + return postStarter.PostStart() + } + } + return nil +} diff --git a/adapter/outbound.go b/adapter/outbound.go index df11ed61..b170398a 100644 --- a/adapter/outbound.go +++ b/adapter/outbound.go @@ -22,3 +22,12 @@ type OutboundRegistry interface { option.OutboundOptionsRegistry CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error) } + +type OutboundManager interface { + NewService + Outbounds() []Outbound + Outbound(tag string) (Outbound, bool) + Default() Outbound + Remove(tag string) error + Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) error +} diff --git a/adapter/outbound/manager.go b/adapter/outbound/manager.go new file mode 100644 index 00000000..b3e1a170 --- /dev/null +++ b/adapter/outbound/manager.go @@ -0,0 +1,265 @@ +package outbound + +import ( + "context" + "io" + "os" + "strings" + "sync" + + "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/taskmonitor" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/log" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" +) + +var _ adapter.OutboundManager = (*Manager)(nil) + +type Manager struct { + logger log.ContextLogger + registry adapter.OutboundRegistry + defaultTag string + access sync.Mutex + started bool + stage adapter.StartStage + outbounds []adapter.Outbound + outboundByTag map[string]adapter.Outbound + dependByTag map[string][]string + defaultOutbound adapter.Outbound + defaultOutboundFallback adapter.Outbound +} + +func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, defaultTag string) *Manager { + return &Manager{ + logger: logger, + registry: registry, + defaultTag: defaultTag, + outboundByTag: make(map[string]adapter.Outbound), + dependByTag: make(map[string][]string), + } +} + +func (m *Manager) Initialize(defaultOutboundFallback adapter.Outbound) { + m.defaultOutboundFallback = defaultOutboundFallback +} + +func (m *Manager) Start(stage adapter.StartStage) error { + m.access.Lock() + defer m.access.Unlock() + if m.started && m.stage >= stage { + panic("already started") + } + m.started = true + m.stage = stage + if stage == adapter.StartStateStart { + m.startOutbounds() + } else { + for _, outbound := range m.outbounds { + err := adapter.LegacyStart(outbound, stage) + if err != nil { + return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]") + } + } + } + return nil +} + +func (m *Manager) startOutbounds() error { + monitor := taskmonitor.New(m.logger, C.StartTimeout) + started := make(map[string]bool) + for { + canContinue := false + startOne: + for _, outboundToStart := range m.outbounds { + outboundTag := outboundToStart.Tag() + if started[outboundTag] { + continue + } + dependencies := outboundToStart.Dependencies() + for _, dependency := range dependencies { + if !started[dependency] { + continue startOne + } + } + started[outboundTag] = true + canContinue = true + if starter, isStarter := outboundToStart.(interface { + Start() error + }); isStarter { + monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]") + err := starter.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]") + } + } + } + if len(started) == len(m.outbounds) { + break + } + if canContinue { + continue + } + currentOutbound := common.Find(m.outbounds, func(it adapter.Outbound) bool { + return !started[it.Tag()] + }) + var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error + lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error { + problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool { + return !started[it] + }) + if common.Contains(oTree, problemOutboundTag) { + return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag) + } + problemOutbound := m.outboundByTag[problemOutboundTag] + if problemOutbound == nil { + return E.New("dependency[", problemOutboundTag, "] not found for outbound[", oCurrent.Tag(), "]") + } + return lintOutbound(append(oTree, problemOutboundTag), problemOutbound) + } + return lintOutbound([]string{currentOutbound.Tag()}, currentOutbound) + } + return nil +} + +func (m *Manager) Close() error { + monitor := taskmonitor.New(m.logger, C.StopTimeout) + m.access.Lock() + if !m.started { + panic("not started") + } + m.started = false + outbounds := m.outbounds + m.outbounds = nil + m.access.Unlock() + var err error + for _, outbound := range outbounds { + if closer, isCloser := outbound.(io.Closer); isCloser { + monitor.Start("close outbound/", outbound.Type(), "[", outbound.Tag(), "]") + err = E.Append(err, closer.Close(), func(err error) error { + return E.Cause(err, "close outbound/", outbound.Type(), "[", outbound.Tag(), "]") + }) + monitor.Finish() + } + } + return nil +} + +func (m *Manager) Outbounds() []adapter.Outbound { + m.access.Lock() + defer m.access.Unlock() + return m.outbounds +} + +func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) { + m.access.Lock() + defer m.access.Unlock() + outbound, found := m.outboundByTag[tag] + return outbound, found +} + +func (m *Manager) Default() adapter.Outbound { + m.access.Lock() + defer m.access.Unlock() + if m.defaultOutbound != nil { + return m.defaultOutbound + } else { + return m.defaultOutboundFallback + } +} + +func (m *Manager) Remove(tag string) error { + m.access.Lock() + outbound, found := m.outboundByTag[tag] + if !found { + m.access.Unlock() + return os.ErrInvalid + } + delete(m.outboundByTag, tag) + index := common.Index(m.outbounds, func(it adapter.Outbound) bool { + return it == outbound + }) + if index == -1 { + panic("invalid inbound index") + } + m.outbounds = append(m.outbounds[:index], m.outbounds[index+1:]...) + started := m.started + if m.defaultOutbound == outbound { + if len(m.outbounds) > 0 { + m.defaultOutbound = m.outbounds[0] + m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag()) + } else { + m.defaultOutbound = nil + } + } + dependBy := m.dependByTag[tag] + if len(dependBy) > 0 { + return E.New("outbound[", tag, "] is depended by ", strings.Join(dependBy, ", ")) + } + dependencies := outbound.Dependencies() + for _, dependency := range dependencies { + if len(m.dependByTag[dependency]) == 1 { + delete(m.dependByTag, dependency) + } else { + m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool { + return it != tag + }) + } + } + m.access.Unlock() + if started { + return common.Close(outbound) + } + return nil +} + +func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, inboundType string, options any) error { + if tag == "" { + return os.ErrInvalid + } + outbound, err := m.registry.CreateOutbound(ctx, router, logger, tag, inboundType, options) + if err != nil { + return err + } + m.access.Lock() + defer m.access.Unlock() + if m.started { + for _, stage := range adapter.ListStartStages { + err = adapter.LegacyStart(outbound, stage) + if err != nil { + return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]") + } + } + } + if existsOutbound, loaded := m.outboundByTag[tag]; loaded { + if m.started { + err = common.Close(existsOutbound) + if err != nil { + return E.Cause(err, "close outbound/", existsOutbound.Type(), "[", existsOutbound.Tag(), "]") + } + } + existsIndex := common.Index(m.outbounds, func(it adapter.Outbound) bool { + return it == existsOutbound + }) + if existsIndex == -1 { + panic("invalid inbound index") + } + m.outbounds = append(m.outbounds[:existsIndex], m.outbounds[existsIndex+1:]...) + } + m.outbounds = append(m.outbounds, outbound) + m.outboundByTag[tag] = outbound + dependencies := outbound.Dependencies() + for _, dependency := range dependencies { + m.dependByTag[dependency] = append(m.dependByTag[dependency], tag) + } + if tag == m.defaultTag || (m.defaultTag == "" && m.defaultOutbound == nil) { + m.defaultOutbound = outbound + if m.started { + m.logger.Info("updated default outbound to ", outbound.Tag()) + } + } + return nil +} diff --git a/adapter/prestart.go b/adapter/prestart.go index 6a39aec3..b8e8da30 100644 --- a/adapter/prestart.go +++ b/adapter/prestart.go @@ -1,9 +1 @@ package adapter - -type PreStarter interface { - PreStart() error -} - -type PostStarter interface { - PostStart() error -} diff --git a/adapter/router.go b/adapter/router.go index c9cd46e9..b8ac51f5 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -15,21 +15,13 @@ import ( M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" - "github.com/sagernet/sing/service" mdns "github.com/miekg/dns" "go4.org/netipx" ) type Router interface { - Service - PreStarter - PostStarter - Cleanup() error - - Outbounds() []Outbound - Outbound(tag string) (Outbound, bool) - DefaultOutbound(network string) (Outbound, error) + NewService FakeIPStore() FakeIPStore @@ -84,14 +76,6 @@ type ConnectionRouterEx interface { RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata InboundContext, onClose N.CloseHandlerFunc) } -func ContextWithRouter(ctx context.Context, router Router) context.Context { - return service.ContextWith(ctx, router) -} - -func RouterFromContext(ctx context.Context) Router { - return service.FromContext[Router](ctx) -} - type RuleSet interface { Name() string StartContext(ctx context.Context, startContext *HTTPStartContext) error diff --git a/box.go b/box.go index 84da77c0..a4ca530f 100644 --- a/box.go +++ b/box.go @@ -9,6 +9,8 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/adapter/inbound" + "github.com/sagernet/sing-box/adapter/outbound" "github.com/sagernet/sing-box/common/taskmonitor" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/experimental" @@ -30,8 +32,8 @@ var _ adapter.Service = (*Box)(nil) type Box struct { createdAt time.Time router adapter.Router - inbounds []adapter.Inbound - outbounds []adapter.Outbound + inbound *inbound.Manager + outbound *outbound.Manager logFactory log.Factory logger log.ContextLogger preServices1 map[string]adapter.Service @@ -66,6 +68,7 @@ func New(options Options) (*Box, error) { if ctx == nil { ctx = context.Background() } + ctx = service.ContextWithDefaultRegistry(ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) if inboundRegistry == nil { return nil, E.New("missing inbound registry in context") @@ -74,7 +77,6 @@ func New(options Options) (*Box, error) { if outboundRegistry == nil { return nil, E.New("missing outbound registry in context") } - ctx = service.ContextWithDefaultRegistry(ctx) ctx = pause.WithDefaultManager(ctx) experimentalOptions := common.PtrValueOrDefault(options.Experimental) applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug)) @@ -106,10 +108,15 @@ func New(options Options) (*Box, error) { if err != nil { return nil, E.Cause(err, "create log factory") } + routeOptions := common.PtrValueOrDefault(options.Route) + inboundManager := inbound.NewManager(logFactory.NewLogger("inbound-manager"), inboundRegistry) + outboundManager := outbound.NewManager(logFactory.NewLogger("outbound-manager"), outboundRegistry, routeOptions.Final) + ctx = service.ContextWith[adapter.InboundManager](ctx, inboundManager) + ctx = service.ContextWith[adapter.OutboundManager](ctx, outboundManager) router, err := route.NewRouter( ctx, logFactory, - common.PtrValueOrDefault(options.Route), + routeOptions, common.PtrValueOrDefault(options.DNS), common.PtrValueOrDefault(options.NTP), options.Inbounds, @@ -127,7 +134,6 @@ func New(options Options) (*Box, error) { }) } } - inbounds := make([]adapter.Inbound, 0, len(options.Inbounds)) //nolint:staticcheck if len(options.LegacyOutbounds) > 0 { for _, legacyOutbound := range options.LegacyOutbounds { @@ -138,17 +144,14 @@ func New(options Options) (*Box, error) { }) } } - outbounds := make([]adapter.Outbound, 0, len(options.Outbounds)) for i, inboundOptions := range options.Inbounds { - var currentInbound adapter.Inbound var tag string if inboundOptions.Tag != "" { tag = inboundOptions.Tag } else { tag = F.ToString(i) } - currentInbound, err = inboundRegistry.CreateInbound( - ctx, + err = inboundManager.Create(ctx, router, logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")), tag, @@ -156,12 +159,10 @@ func New(options Options) (*Box, error) { inboundOptions.Options, ) if err != nil { - return nil, E.Cause(err, "parse inbound[", i, "]") + return nil, E.Cause(err, "initialize inbound[", i, "]") } - inbounds = append(inbounds, currentInbound) } for i, outboundOptions := range options.Outbounds { - var currentOutbound adapter.Outbound var tag string if outboundOptions.Tag != "" { tag = outboundOptions.Tag @@ -175,7 +176,7 @@ func New(options Options) (*Box, error) { Outbound: tag, }) } - currentOutbound, err = outboundRegistry.CreateOutbound( + err = outboundManager.Create( outboundCtx, router, logFactory.NewLogger(F.ToString("outbound/", outboundOptions.Type, "[", tag, "]")), @@ -184,16 +185,18 @@ func New(options Options) (*Box, error) { outboundOptions.Options, ) if err != nil { - return nil, E.Cause(err, "parse outbound[", i, "]") + return nil, E.Cause(err, "initialize outbound[", i, "]") } - outbounds = append(outbounds, currentOutbound) } - err = router.Initialize(inbounds, outbounds, func() adapter.Outbound { - defaultOutbound, cErr := direct.NewOutbound(ctx, router, logFactory.NewLogger("outbound/direct"), "direct", option.DirectOutboundOptions{}) - common.Must(cErr) - outbounds = append(outbounds, defaultOutbound) - return defaultOutbound - }) + outboundManager.Initialize(common.Must1( + direct.NewOutbound( + ctx, + router, + logFactory.NewLogger("outbound/direct"), + "direct", + option.DirectOutboundOptions{}, + ), + )) if err != nil { return nil, err } @@ -217,7 +220,7 @@ func New(options Options) (*Box, error) { if needClashAPI { clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI) clashAPIOptions.ModeList = experimental.CalculateClashModeList(options.Options) - clashServer, err := experimental.NewClashServer(ctx, router, logFactory.(log.ObservableFactory), clashAPIOptions) + clashServer, err := experimental.NewClashServer(ctx, logFactory.(log.ObservableFactory), clashAPIOptions) if err != nil { return nil, E.Cause(err, "create clash api server") } @@ -234,8 +237,8 @@ func New(options Options) (*Box, error) { } return &Box{ router: router, - inbounds: inbounds, - outbounds: outbounds, + inbound: inboundManager, + outbound: outboundManager, createdAt: createdAt, logFactory: logFactory, logger: logFactory.Logger(), @@ -293,7 +296,7 @@ func (s *Box) preStart() error { return E.Cause(err, "start logger") } for serviceName, service := range s.preServices1 { - if preService, isPreService := service.(adapter.PreStarter); isPreService { + if preService, isPreService := service.(adapter.LegacyPreStarter); isPreService { monitor.Start("pre-start ", serviceName) err := preService.PreStart() monitor.Finish() @@ -303,7 +306,7 @@ func (s *Box) preStart() error { } } for serviceName, service := range s.preServices2 { - if preService, isPreService := service.(adapter.PreStarter); isPreService { + if preService, isPreService := service.(adapter.LegacyPreStarter); isPreService { monitor.Start("pre-start ", serviceName) err := preService.PreStart() monitor.Finish() @@ -312,15 +315,15 @@ func (s *Box) preStart() error { } } } - err = s.router.PreStart() + err = s.router.Start(adapter.StartStateInitialize) if err != nil { - return E.Cause(err, "pre-start router") + return E.Cause(err, "initialize router") } - err = s.startOutbounds() + err = s.outbound.Start(adapter.StartStateStart) if err != nil { return err } - return s.router.Start() + return s.router.Start(adapter.StartStateStart) } func (s *Box) start() error { @@ -340,52 +343,39 @@ func (s *Box) start() error { return E.Cause(err, "start ", serviceName) } } - for i, in := range s.inbounds { - var tag string - if in.Tag() == "" { - tag = F.ToString(i) - } else { - tag = in.Tag() - } - err = in.Start() - if err != nil { - return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]") - } - } - err = s.postStart() + err = s.inbound.Start(adapter.StartStateStart) if err != nil { return err } - return s.router.Cleanup() -} - -func (s *Box) postStart() error { for serviceName, service := range s.postServices { err := service.Start() if err != nil { return E.Cause(err, "start ", serviceName) } } - // TODO: reorganize ALL start order - for _, out := range s.outbounds { - if lateOutbound, isLateOutbound := out.(adapter.PostStarter); isLateOutbound { - err := lateOutbound.PostStart() - if err != nil { - return E.Cause(err, "post-start outbound/", out.Tag()) - } - } - } - err := s.router.PostStart() + err = s.outbound.Start(adapter.StartStatePostStart) if err != nil { return err } - for _, in := range s.inbounds { - if lateInbound, isLateInbound := in.(adapter.PostStarter); isLateInbound { - err = lateInbound.PostStart() - if err != nil { - return E.Cause(err, "post-start inbound/", in.Tag()) - } - } + err = s.router.Start(adapter.StartStatePostStart) + if err != nil { + return err + } + err = s.inbound.Start(adapter.StartStatePostStart) + if err != nil { + return err + } + err = s.router.Start(adapter.StartStateStarted) + if err != nil { + return err + } + err = s.outbound.Start(adapter.StartStateStarted) + if err != nil { + return err + } + err = s.inbound.Start(adapter.StartStateStarted) + if err != nil { + return err } return nil } @@ -406,20 +396,8 @@ func (s *Box) Close() error { }) monitor.Finish() } - for i, in := range s.inbounds { - monitor.Start("close inbound/", in.Type(), "[", i, "]") - errors = E.Append(errors, in.Close(), func(err error) error { - return E.Cause(err, "close inbound/", in.Type(), "[", i, "]") - }) - monitor.Finish() - } - for i, out := range s.outbounds { - monitor.Start("close outbound/", out.Type(), "[", i, "]") - errors = E.Append(errors, common.Close(out), func(err error) error { - return E.Cause(err, "close outbound/", out.Type(), "[", i, "]") - }) - monitor.Finish() - } + errors = E.Errors(errors, s.inbound.Close()) + errors = E.Errors(errors, s.outbound.Close()) monitor.Start("close router") if err := common.Close(s.router); err != nil { errors = E.Append(errors, err, func(err error) error { @@ -449,6 +427,14 @@ func (s *Box) Close() error { return errors } +func (s *Box) Inbound() adapter.InboundManager { + return s.inbound +} + +func (s *Box) Outbound() adapter.OutboundManager { + return s.outbound +} + func (s *Box) Router() adapter.Router { return s.router } diff --git a/box_outbound.go b/box_outbound.go deleted file mode 100644 index f03f3b7d..00000000 --- a/box_outbound.go +++ /dev/null @@ -1,85 +0,0 @@ -package box - -import ( - "strings" - - "github.com/sagernet/sing-box/adapter" - "github.com/sagernet/sing-box/common/taskmonitor" - C "github.com/sagernet/sing-box/constant" - "github.com/sagernet/sing/common" - E "github.com/sagernet/sing/common/exceptions" - F "github.com/sagernet/sing/common/format" -) - -func (s *Box) startOutbounds() error { - monitor := taskmonitor.New(s.logger, C.StartTimeout) - outboundTags := make(map[adapter.Outbound]string) - outbounds := make(map[string]adapter.Outbound) - for i, outboundToStart := range s.outbounds { - var outboundTag string - if outboundToStart.Tag() == "" { - outboundTag = F.ToString(i) - } else { - outboundTag = outboundToStart.Tag() - } - if _, exists := outbounds[outboundTag]; exists { - return E.New("outbound tag ", outboundTag, " duplicated") - } - outboundTags[outboundToStart] = outboundTag - outbounds[outboundTag] = outboundToStart - } - started := make(map[string]bool) - for { - canContinue := false - startOne: - for _, outboundToStart := range s.outbounds { - outboundTag := outboundTags[outboundToStart] - if started[outboundTag] { - continue - } - dependencies := outboundToStart.Dependencies() - for _, dependency := range dependencies { - if !started[dependency] { - continue startOne - } - } - started[outboundTag] = true - canContinue = true - if starter, isStarter := outboundToStart.(interface { - Start() error - }); isStarter { - monitor.Start("initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]") - err := starter.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]") - } - } - } - if len(started) == len(s.outbounds) { - break - } - if canContinue { - continue - } - currentOutbound := common.Find(s.outbounds, func(it adapter.Outbound) bool { - return !started[outboundTags[it]] - }) - var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error - lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error { - problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool { - return !started[it] - }) - if common.Contains(oTree, problemOutboundTag) { - return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag) - } - problemOutbound := outbounds[problemOutboundTag] - if problemOutbound == nil { - return E.New("dependency[", problemOutboundTag, "] not found for outbound[", outboundTags[oCurrent], "]") - } - return lintOutbound(append(oTree, problemOutboundTag), problemOutbound) - } - return lintOutbound([]string{outboundTags[currentOutbound]}, currentOutbound) - } - return nil -} diff --git a/cmd/sing-box/cmd_tools.go b/cmd/sing-box/cmd_tools.go index 86b9302e..8f30e054 100644 --- a/cmd/sing-box/cmd_tools.go +++ b/cmd/sing-box/cmd_tools.go @@ -41,11 +41,11 @@ func createPreStartedClient() (*box.Box, error) { return instance, nil } -func createDialer(instance *box.Box, network string, outboundTag string) (N.Dialer, error) { +func createDialer(instance *box.Box, outboundTag string) (N.Dialer, error) { if outboundTag == "" { - return instance.Router().DefaultOutbound(N.NetworkName(network)) + return instance.Outbound().Default(), nil } else { - outbound, loaded := instance.Router().Outbound(outboundTag) + outbound, loaded := instance.Outbound().Outbound(outboundTag) if !loaded { return nil, E.New("outbound not found: ", outboundTag) } diff --git a/cmd/sing-box/cmd_tools_connect.go b/cmd/sing-box/cmd_tools_connect.go index 3ea04bcd..d352d533 100644 --- a/cmd/sing-box/cmd_tools_connect.go +++ b/cmd/sing-box/cmd_tools_connect.go @@ -45,7 +45,7 @@ func connect(address string) error { return err } defer instance.Close() - dialer, err := createDialer(instance, commandConnectFlagNetwork, commandToolsFlagOutbound) + dialer, err := createDialer(instance, commandToolsFlagOutbound) if err != nil { return err } diff --git a/cmd/sing-box/cmd_tools_fetch.go b/cmd/sing-box/cmd_tools_fetch.go index 3f62424a..5ee3b875 100644 --- a/cmd/sing-box/cmd_tools_fetch.go +++ b/cmd/sing-box/cmd_tools_fetch.go @@ -48,7 +48,7 @@ func fetch(args []string) error { httpClient = &http.Client{ Transport: &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - dialer, err := createDialer(instance, network, commandToolsFlagOutbound) + dialer, err := createDialer(instance, commandToolsFlagOutbound) if err != nil { return nil, err } diff --git a/cmd/sing-box/cmd_tools_fetch_http3.go b/cmd/sing-box/cmd_tools_fetch_http3.go index 5dc3d915..d72ed99d 100644 --- a/cmd/sing-box/cmd_tools_fetch_http3.go +++ b/cmd/sing-box/cmd_tools_fetch_http3.go @@ -16,7 +16,7 @@ import ( ) func initializeHTTP3Client(instance *box.Box) error { - dialer, err := createDialer(instance, N.NetworkUDP, commandToolsFlagOutbound) + dialer, err := createDialer(instance, commandToolsFlagOutbound) if err != nil { return err } diff --git a/cmd/sing-box/cmd_tools_synctime.go b/cmd/sing-box/cmd_tools_synctime.go index 20d73a6d..38510eb8 100644 --- a/cmd/sing-box/cmd_tools_synctime.go +++ b/cmd/sing-box/cmd_tools_synctime.go @@ -9,7 +9,6 @@ import ( "github.com/sagernet/sing-box/log" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/ntp" "github.com/spf13/cobra" @@ -45,7 +44,7 @@ func syncTime() error { if err != nil { return err } - dialer, err := createDialer(instance, N.NetworkUDP, commandToolsFlagOutbound) + dialer, err := createDialer(instance, commandToolsFlagOutbound) if err != nil { return err } diff --git a/common/dialer/detour.go b/common/dialer/detour.go index 81600913..c1d40faa 100644 --- a/common/dialer/detour.go +++ b/common/dialer/detour.go @@ -12,15 +12,15 @@ import ( ) type DetourDialer struct { - router adapter.Router - detour string - dialer N.Dialer - initOnce sync.Once - initErr error + outboundManager adapter.OutboundManager + detour string + dialer N.Dialer + initOnce sync.Once + initErr error } -func NewDetour(router adapter.Router, detour string) N.Dialer { - return &DetourDialer{router: router, detour: detour} +func NewDetour(outboundManager adapter.OutboundManager, detour string) N.Dialer { + return &DetourDialer{outboundManager: outboundManager, detour: detour} } func (d *DetourDialer) Start() error { @@ -31,7 +31,7 @@ func (d *DetourDialer) Start() error { func (d *DetourDialer) Dialer() (N.Dialer, error) { d.initOnce.Do(func() { var loaded bool - d.dialer, loaded = d.router.Outbound(d.detour) + d.dialer, loaded = d.outboundManager.Outbound(d.detour) if !loaded { d.initErr = E.New("outbound detour not found: ", d.detour) } diff --git a/common/dialer/dialer.go b/common/dialer/dialer.go index 56c5f2ad..fe4c7c12 100644 --- a/common/dialer/dialer.go +++ b/common/dialer/dialer.go @@ -1,21 +1,22 @@ package dialer import ( + "context" "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing-dns" + E "github.com/sagernet/sing/common/exceptions" N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/service" ) -func New(router adapter.Router, options option.DialerOptions) (N.Dialer, error) { +func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) { + router := service.FromContext[adapter.Router](ctx) if options.IsWireGuardListener { return NewDefault(router, options) } - if router == nil { - return NewDefault(nil, options) - } var ( dialer N.Dialer err error @@ -26,7 +27,14 @@ func New(router adapter.Router, options option.DialerOptions) (N.Dialer, error) return nil, err } } else { - dialer = NewDetour(router, options.Detour) + outboundManager := service.FromContext[adapter.OutboundManager](ctx) + if outboundManager == nil { + return nil, E.New("missing outbound manager") + } + dialer = NewDetour(outboundManager, options.Detour) + } + if router == nil { + return NewDefault(router, options) } if options.Detour == "" { dialer = NewResolveDialer( diff --git a/common/dialer/router.go b/common/dialer/router.go index 25316077..3edce65b 100644 --- a/common/dialer/router.go +++ b/common/dialer/router.go @@ -9,30 +9,22 @@ import ( N "github.com/sagernet/sing/common/network" ) -type RouterDialer struct { - router adapter.Router +type DefaultOutboundDialer struct { + outboundManager adapter.OutboundManager } -func NewRouter(router adapter.Router) N.Dialer { - return &RouterDialer{router: router} +func NewDefaultOutbound(outboundManager adapter.OutboundManager) N.Dialer { + return &DefaultOutboundDialer{outboundManager: outboundManager} } -func (d *RouterDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - dialer, err := d.router.DefaultOutbound(network) - if err != nil { - return nil, err - } - return dialer.DialContext(ctx, network, destination) +func (d *DefaultOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + return d.outboundManager.Default().DialContext(ctx, network, destination) } -func (d *RouterDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - dialer, err := d.router.DefaultOutbound(N.NetworkUDP) - if err != nil { - return nil, err - } - return dialer.ListenPacket(ctx, destination) +func (d *DefaultOutboundDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + return d.outboundManager.Default().ListenPacket(ctx, destination) } -func (d *RouterDialer) Upstream() any { - return d.router +func (d *DefaultOutboundDialer) Upstream() any { + return d.outboundManager.Default() } diff --git a/common/settings/proxy_darwin.go b/common/settings/proxy_darwin.go index f03658a8..2d86fa2d 100644 --- a/common/settings/proxy_darwin.go +++ b/common/settings/proxy_darwin.go @@ -12,6 +12,7 @@ import ( M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/shell" "github.com/sagernet/sing/common/x/list" + "github.com/sagernet/sing/service" ) type DarwinSystemProxy struct { @@ -24,7 +25,7 @@ type DarwinSystemProxy struct { } func NewSystemProxy(ctx context.Context, serverAddr M.Socksaddr, supportSOCKS bool) (*DarwinSystemProxy, error) { - interfaceMonitor := adapter.RouterFromContext(ctx).InterfaceMonitor() + interfaceMonitor := service.FromContext[adapter.Router](ctx).InterfaceMonitor() if interfaceMonitor == nil { return nil, E.New("missing interface monitor") } diff --git a/common/tls/ech_client.go b/common/tls/ech_client.go index 7f72b4d8..0433cebd 100644 --- a/common/tls/ech_client.go +++ b/common/tls/ech_client.go @@ -19,6 +19,7 @@ import ( "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/ntp" + "github.com/sagernet/sing/service" mDNS "github.com/miekg/dns" ) @@ -213,7 +214,7 @@ func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverNam }, }, } - response, err := adapter.RouterFromContext(ctx).Exchange(ctx, message) + response, err := service.FromContext[adapter.Router](ctx).Exchange(ctx, message) if err != nil { return nil, err } diff --git a/common/tls/reality_server.go b/common/tls/reality_server.go index a9318798..06501d9e 100644 --- a/common/tls/reality_server.go +++ b/common/tls/reality_server.go @@ -11,7 +11,6 @@ import ( "time" "github.com/sagernet/reality" - "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" @@ -102,7 +101,7 @@ func NewRealityServer(ctx context.Context, logger log.Logger, options option.Inb tlsConfig.ShortIds[shortID] = true } - handshakeDialer, err := dialer.New(adapter.RouterFromContext(ctx), options.Reality.Handshake.DialerOptions) + handshakeDialer, err := dialer.New(ctx, options.Reality.Handshake.DialerOptions) if err != nil { return nil, err } diff --git a/experimental/clashapi.go b/experimental/clashapi.go index 872d9b99..4ad07c8b 100644 --- a/experimental/clashapi.go +++ b/experimental/clashapi.go @@ -12,7 +12,7 @@ import ( "github.com/sagernet/sing/common" ) -type ClashServerConstructor = func(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) +type ClashServerConstructor = func(ctx context.Context, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) var clashServerConstructor ClashServerConstructor @@ -20,11 +20,11 @@ func RegisterClashServerConstructor(constructor ClashServerConstructor) { clashServerConstructor = constructor } -func NewClashServer(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { +func NewClashServer(ctx context.Context, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { if clashServerConstructor == nil { return nil, os.ErrInvalid } - return clashServerConstructor(ctx, router, logFactory, options) + return clashServerConstructor(ctx, logFactory, options) } func CalculateClashModeList(options option.Options) []string { diff --git a/experimental/clashapi/api_meta_group.go b/experimental/clashapi/api_meta_group.go index 531311f4..c5c07ba6 100644 --- a/experimental/clashapi/api_meta_group.go +++ b/experimental/clashapi/api_meta_group.go @@ -23,7 +23,7 @@ func groupRouter(server *Server) http.Handler { r := chi.NewRouter() r.Get("/", getGroups(server)) r.Route("/{name}", func(r chi.Router) { - r.Use(parseProxyName, findProxyByName(server.router)) + r.Use(parseProxyName, findProxyByName(server)) r.Get("/", getGroup(server)) r.Get("/delay", getGroupDelay(server)) }) @@ -32,7 +32,7 @@ func groupRouter(server *Server) http.Handler { func getGroups(server *Server) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - groups := common.Map(common.Filter(server.router.Outbounds(), func(it adapter.Outbound) bool { + groups := common.Map(common.Filter(server.outboundManager.Outbounds(), func(it adapter.Outbound) bool { _, isGroup := it.(adapter.OutboundGroup) return isGroup }), func(it adapter.Outbound) *badjson.JSONObject { @@ -86,7 +86,7 @@ func getGroupDelay(server *Server) func(w http.ResponseWriter, r *http.Request) result, err = urlTestGroup.URLTest(ctx) } else { outbounds := common.FilterNotNil(common.Map(outboundGroup.All(), func(it string) adapter.Outbound { - itOutbound, _ := server.router.Outbound(it) + itOutbound, _ := server.outboundManager.Outbound(it) return itOutbound })) b, _ := batch.New(ctx, batch.WithConcurrencyNum[any](10)) @@ -100,7 +100,7 @@ func getGroupDelay(server *Server) func(w http.ResponseWriter, r *http.Request) continue } checked[realTag] = true - p, loaded := server.router.Outbound(realTag) + p, loaded := server.outboundManager.Outbound(realTag) if !loaded { continue } diff --git a/experimental/clashapi/proxies.go b/experimental/clashapi/proxies.go index 4a9564ee..8d8ecb38 100644 --- a/experimental/clashapi/proxies.go +++ b/experimental/clashapi/proxies.go @@ -23,10 +23,10 @@ import ( func proxyRouter(server *Server, router adapter.Router) http.Handler { r := chi.NewRouter() - r.Get("/", getProxies(server, router)) + r.Get("/", getProxies(server)) r.Route("/{name}", func(r chi.Router) { - r.Use(parseProxyName, findProxyByName(router)) + r.Use(parseProxyName, findProxyByName(server)) r.Get("/", getProxy(server)) r.Get("/delay", getProxyDelay(server)) r.Put("/", updateProxy) @@ -42,11 +42,11 @@ func parseProxyName(next http.Handler) http.Handler { }) } -func findProxyByName(router adapter.Router) func(next http.Handler) http.Handler { +func findProxyByName(server *Server) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { name := r.Context().Value(CtxKeyProxyName).(string) - proxy, exist := router.Outbound(name) + proxy, exist := server.outboundManager.Outbound(name) if !exist { render.Status(r, http.StatusNotFound) render.JSON(w, r, ErrNotFound) @@ -83,10 +83,10 @@ func proxyInfo(server *Server, detour adapter.Outbound) *badjson.JSONObject { return &info } -func getProxies(server *Server, router adapter.Router) func(w http.ResponseWriter, r *http.Request) { +func getProxies(server *Server) func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) { var proxyMap badjson.JSONObject - outbounds := common.Filter(router.Outbounds(), func(detour adapter.Outbound) bool { + outbounds := common.Filter(server.outboundManager.Outbounds(), func(detour adapter.Outbound) bool { return detour.Tag() != "" }) @@ -100,12 +100,7 @@ func getProxies(server *Server, router adapter.Router) func(w http.ResponseWrite allProxies = append(allProxies, detour.Tag()) } - var defaultTag string - if defaultOutbound, err := router.DefaultOutbound(N.NetworkTCP); err == nil { - defaultTag = defaultOutbound.Tag() - } else { - defaultTag = allProxies[0] - } + defaultTag := server.outboundManager.Default().Tag() sort.SliceStable(allProxies, func(i, j int) bool { return allProxies[i] == defaultTag diff --git a/experimental/clashapi/server.go b/experimental/clashapi/server.go index 889d191e..ef08a4be 100644 --- a/experimental/clashapi/server.go +++ b/experimental/clashapi/server.go @@ -40,15 +40,16 @@ func init() { var _ adapter.ClashServer = (*Server)(nil) type Server struct { - ctx context.Context - router adapter.Router - logger log.Logger - httpServer *http.Server - trafficManager *trafficontrol.Manager - urlTestHistory *urltest.HistoryStorage - mode string - modeList []string - modeUpdateHook chan<- struct{} + ctx context.Context + router adapter.Router + outboundManager adapter.OutboundManager + logger log.Logger + httpServer *http.Server + trafficManager *trafficontrol.Manager + urlTestHistory *urltest.HistoryStorage + mode string + modeList []string + modeUpdateHook chan<- struct{} externalController bool externalUI string @@ -56,13 +57,14 @@ type Server struct { externalUIDownloadDetour string } -func NewServer(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { +func NewServer(ctx context.Context, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { trafficManager := trafficontrol.NewManager() chiRouter := chi.NewRouter() - server := &Server{ - ctx: ctx, - router: router, - logger: logFactory.NewLogger("clash-api"), + s := &Server{ + ctx: ctx, + router: service.FromContext[adapter.Router](ctx), + outboundManager: service.FromContext[adapter.OutboundManager](ctx), + logger: logFactory.NewLogger("clash-api"), httpServer: &http.Server{ Addr: options.ExternalController, Handler: chiRouter, @@ -73,18 +75,18 @@ func NewServer(ctx context.Context, router adapter.Router, logFactory log.Observ externalUIDownloadURL: options.ExternalUIDownloadURL, externalUIDownloadDetour: options.ExternalUIDownloadDetour, } - server.urlTestHistory = service.PtrFromContext[urltest.HistoryStorage](ctx) - if server.urlTestHistory == nil { - server.urlTestHistory = urltest.NewHistoryStorage() + s.urlTestHistory = service.PtrFromContext[urltest.HistoryStorage](ctx) + if s.urlTestHistory == nil { + s.urlTestHistory = urltest.NewHistoryStorage() } defaultMode := "Rule" if options.DefaultMode != "" { defaultMode = options.DefaultMode } - if !common.Contains(server.modeList, defaultMode) { - server.modeList = append([]string{defaultMode}, server.modeList...) + if !common.Contains(s.modeList, defaultMode) { + s.modeList = append([]string{defaultMode}, s.modeList...) } - server.mode = defaultMode + s.mode = defaultMode //goland:noinspection GoDeprecation //nolint:staticcheck if options.StoreMode || options.StoreSelected || options.StoreFakeIP || options.CacheFile != "" || options.CacheID != "" { @@ -108,30 +110,30 @@ func NewServer(ctx context.Context, router adapter.Router, logFactory log.Observ r.Get("/logs", getLogs(logFactory)) r.Get("/traffic", traffic(trafficManager)) r.Get("/version", version) - r.Mount("/configs", configRouter(server, logFactory)) - r.Mount("/proxies", proxyRouter(server, router)) - r.Mount("/rules", ruleRouter(router)) - r.Mount("/connections", connectionRouter(router, trafficManager)) + r.Mount("/configs", configRouter(s, logFactory)) + r.Mount("/proxies", proxyRouter(s, s.router)) + r.Mount("/rules", ruleRouter(s.router)) + r.Mount("/connections", connectionRouter(s.router, trafficManager)) r.Mount("/providers/proxies", proxyProviderRouter()) r.Mount("/providers/rules", ruleProviderRouter()) r.Mount("/script", scriptRouter()) r.Mount("/profile", profileRouter()) r.Mount("/cache", cacheRouter(ctx)) - r.Mount("/dns", dnsRouter(router)) + r.Mount("/dns", dnsRouter(s.router)) - server.setupMetaAPI(r) + s.setupMetaAPI(r) }) if options.ExternalUI != "" { - server.externalUI = filemanager.BasePath(ctx, os.ExpandEnv(options.ExternalUI)) + s.externalUI = filemanager.BasePath(ctx, os.ExpandEnv(options.ExternalUI)) chiRouter.Group(func(r chi.Router) { - fs := http.StripPrefix("/ui", http.FileServer(http.Dir(server.externalUI))) + fs := http.StripPrefix("/ui", http.FileServer(http.Dir(s.externalUI))) r.Get("/ui", http.RedirectHandler("/ui/", http.StatusTemporaryRedirect).ServeHTTP) r.Get("/ui/*", func(w http.ResponseWriter, r *http.Request) { fs.ServeHTTP(w, r) }) }) } - return server, nil + return s, nil } func (s *Server) PreStart() error { @@ -235,12 +237,12 @@ func (s *Server) TrafficManager() *trafficontrol.Manager { } func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) { - tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.router, matchedRule) + tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule) return tracker, tracker } func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) { - tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.router, matchedRule) + tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule) return tracker, tracker } diff --git a/experimental/clashapi/server_resources.go b/experimental/clashapi/server_resources.go index a5c79e0c..e5b28e30 100644 --- a/experimental/clashapi/server_resources.go +++ b/experimental/clashapi/server_resources.go @@ -15,7 +15,6 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" - N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/service/filemanager" ) @@ -45,16 +44,13 @@ func (s *Server) downloadExternalUI() error { s.logger.Info("downloading external ui") var detour adapter.Outbound if s.externalUIDownloadDetour != "" { - outbound, loaded := s.router.Outbound(s.externalUIDownloadDetour) + outbound, loaded := s.outboundManager.Outbound(s.externalUIDownloadDetour) if !loaded { return E.New("detour outbound not found: ", s.externalUIDownloadDetour) } detour = outbound } else { - outbound, err := s.router.DefaultOutbound(N.NetworkTCP) - if err != nil { - return err - } + outbound := s.outboundManager.Default() detour = outbound } httpClient := &http.Client{ diff --git a/experimental/clashapi/trafficontrol/tracker.go b/experimental/clashapi/trafficontrol/tracker.go index 9c18abeb..df5437fa 100644 --- a/experimental/clashapi/trafficontrol/tracker.go +++ b/experimental/clashapi/trafficontrol/tracker.go @@ -124,7 +124,7 @@ func (tt *TCPConn) WriterReplaceable() bool { return true } -func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, router adapter.Router, rule adapter.Rule) *TCPConn { +func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *TCPConn { id, _ := uuid.NewV4() var ( chain []string @@ -138,11 +138,11 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundCont } if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction { next = routeAction.Outbound - } else if defaultOutbound, err := router.DefaultOutbound(N.NetworkTCP); err == nil { - next = defaultOutbound.Tag() + } else { + next = outboundManager.Default().Tag() } for { - detour, loaded := router.Outbound(next) + detour, loaded := outboundManager.Outbound(next) if !loaded { break } @@ -213,7 +213,7 @@ func (ut *UDPConn) WriterReplaceable() bool { return true } -func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, router adapter.Router, rule adapter.Rule) *UDPConn { +func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *UDPConn { id, _ := uuid.NewV4() var ( chain []string @@ -227,11 +227,11 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.Inbound } if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction { next = routeAction.Outbound - } else if defaultOutbound, err := router.DefaultOutbound(N.NetworkUDP); err == nil { - next = defaultOutbound.Tag() + } else { + next = outboundManager.Default().Tag() } for { - detour, loaded := router.Outbound(next) + detour, loaded := outboundManager.Outbound(next) if !loaded { break } diff --git a/protocol/direct/outbound.go b/protocol/direct/outbound.go index 32c1ed8f..27b334c9 100644 --- a/protocol/direct/outbound.go +++ b/protocol/direct/outbound.go @@ -12,7 +12,7 @@ import ( C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" - "github.com/sagernet/sing-dns" + dns "github.com/sagernet/sing-dns" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" @@ -39,7 +39,7 @@ type Outbound struct { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (adapter.Outbound, error) { options.UDPFragmentDefault = true - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/group/selector.go b/protocol/group/selector.go index 32ab8b2a..dc0b7cce 100644 --- a/protocol/group/selector.go +++ b/protocol/group/selector.go @@ -26,7 +26,7 @@ var _ adapter.OutboundGroup = (*Selector)(nil) type Selector struct { outbound.Adapter ctx context.Context - router adapter.Router + outboundManager adapter.OutboundManager logger logger.ContextLogger tags []string defaultTag string @@ -40,7 +40,7 @@ func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextL outbound := &Selector{ Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds), ctx: ctx, - router: router, + outboundManager: service.FromContext[adapter.OutboundManager](ctx), logger: logger, tags: options.Outbounds, defaultTag: options.Default, @@ -63,7 +63,7 @@ func (s *Selector) Network() []string { func (s *Selector) Start() error { for i, tag := range s.tags { - detour, loaded := s.router.Outbound(tag) + detour, loaded := s.outboundManager.Outbound(tag) if !loaded { return E.New("outbound ", i, " not found: ", tag) } diff --git a/protocol/group/urltest.go b/protocol/group/urltest.go index ccdf809d..4d76a31c 100644 --- a/protocol/group/urltest.go +++ b/protocol/group/urltest.go @@ -36,6 +36,7 @@ type URLTest struct { outbound.Adapter ctx context.Context router adapter.Router + outboundManager adapter.OutboundManager logger log.ContextLogger tags []string link string @@ -51,6 +52,7 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds), ctx: ctx, router: router, + outboundManager: service.FromContext[adapter.OutboundManager](ctx), logger: logger, tags: options.Outbounds, link: options.URL, @@ -68,7 +70,7 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo func (s *URLTest) Start() error { outbounds := make([]adapter.Outbound, 0, len(s.tags)) for i, tag := range s.tags { - detour, loaded := s.router.Outbound(tag) + detour, loaded := s.outboundManager.Outbound(tag) if !loaded { return E.New("outbound ", i, " not found: ", tag) } @@ -77,6 +79,7 @@ func (s *URLTest) Start() error { group, err := NewURLTestGroup( s.ctx, s.router, + s.outboundManager, s.logger, outbounds, s.link, @@ -190,6 +193,7 @@ func (s *URLTest) InterfaceUpdated() { type URLTestGroup struct { ctx context.Context router adapter.Router + outboundManager adapter.OutboundManager logger log.Logger outbounds []adapter.Outbound link string @@ -214,6 +218,7 @@ type URLTestGroup struct { func NewURLTestGroup( ctx context.Context, router adapter.Router, + outboundManager adapter.OutboundManager, logger log.Logger, outbounds []adapter.Outbound, link string, @@ -244,6 +249,7 @@ func NewURLTestGroup( return &URLTestGroup{ ctx: ctx, router: router, + outboundManager: outboundManager, logger: logger, outbounds: outbounds, link: link, @@ -385,7 +391,7 @@ func (g *URLTestGroup) urlTest(ctx context.Context, force bool) (map[string]uint continue } checked[realTag] = true - p, loaded := g.router.Outbound(realTag) + p, loaded := g.outboundManager.Outbound(realTag) if !loaded { continue } diff --git a/protocol/http/outbound.go b/protocol/http/outbound.go index 4c930591..81fd0246 100644 --- a/protocol/http/outbound.go +++ b/protocol/http/outbound.go @@ -30,7 +30,7 @@ type Outbound struct { } func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) (adapter.Outbound, error) { - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/hysteria/outbound.go b/protocol/hysteria/outbound.go index 4722f4f0..e4c8775f 100644 --- a/protocol/hysteria/outbound.go +++ b/protocol/hysteria/outbound.go @@ -47,7 +47,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/hysteria2/outbound.go b/protocol/hysteria2/outbound.go index 5ebc6c91..4cabb475 100644 --- a/protocol/hysteria2/outbound.go +++ b/protocol/hysteria2/outbound.go @@ -59,7 +59,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL return nil, E.New("unknown obfs type: ", options.Obfs.Type) } } - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/shadowsocks/outbound.go b/protocol/shadowsocks/outbound.go index 73b38385..8771fa8e 100644 --- a/protocol/shadowsocks/outbound.go +++ b/protocol/shadowsocks/outbound.go @@ -44,7 +44,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/shadowtls/inbound.go b/protocol/shadowtls/inbound.go index 6887e838..ce0431a6 100644 --- a/protocol/shadowtls/inbound.go +++ b/protocol/shadowtls/inbound.go @@ -46,7 +46,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo if options.Version > 1 { handshakeForServerName = make(map[string]shadowtls.HandshakeConfig) for serverName, serverOptions := range options.HandshakeForServerName { - handshakeDialer, err := dialer.New(router, serverOptions.DialerOptions) + handshakeDialer, err := dialer.New(ctx, serverOptions.DialerOptions) if err != nil { return nil, err } @@ -56,7 +56,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo } } } - handshakeDialer, err := dialer.New(router, options.Handshake.DialerOptions) + handshakeDialer, err := dialer.New(ctx, options.Handshake.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/shadowtls/outbound.go b/protocol/shadowtls/outbound.go index 7d46a8f6..e979dba2 100644 --- a/protocol/shadowtls/outbound.go +++ b/protocol/shadowtls/outbound.go @@ -68,7 +68,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL tlsHandshakeFunc = shadowtls.DefaultTLSHandshakeFunc(options.Password, stdTLSConfig) } } - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/socks/outbound.go b/protocol/socks/outbound.go index 0194800a..dbb5ab61 100644 --- a/protocol/socks/outbound.go +++ b/protocol/socks/outbound.go @@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL if err != nil { return nil, err } - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/ssh/outbound.go b/protocol/ssh/outbound.go index 62a2a8d9..1dfc1f6d 100644 --- a/protocol/ssh/outbound.go +++ b/protocol/ssh/outbound.go @@ -49,7 +49,7 @@ type Outbound struct { } func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (adapter.Outbound, error) { - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/tor/outbound.go b/protocol/tor/outbound.go index 89a295b8..3d217011 100644 --- a/protocol/tor/outbound.go +++ b/protocol/tor/outbound.go @@ -75,7 +75,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL } startConf.TorrcFile = torrcFile } - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/trojan/outbound.go b/protocol/trojan/outbound.go index f64c48c3..68b00690 100644 --- a/protocol/trojan/outbound.go +++ b/protocol/trojan/outbound.go @@ -38,7 +38,7 @@ type Outbound struct { } func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrojanOutboundOptions) (adapter.Outbound, error) { - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/tuic/outbound.go b/protocol/tuic/outbound.go index 691d1658..177f21fc 100644 --- a/protocol/tuic/outbound.go +++ b/protocol/tuic/outbound.go @@ -60,7 +60,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL case "quic": tuicUDPStream = true } - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/vless/outbound.go b/protocol/vless/outbound.go index 1074549e..de655230 100644 --- a/protocol/vless/outbound.go +++ b/protocol/vless/outbound.go @@ -41,7 +41,7 @@ type Outbound struct { } func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VLESSOutboundOptions) (adapter.Outbound, error) { - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/vmess/outbound.go b/protocol/vmess/outbound.go index 759ea8ba..1e84639f 100644 --- a/protocol/vmess/outbound.go +++ b/protocol/vmess/outbound.go @@ -41,7 +41,7 @@ type Outbound struct { } func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VMessOutboundOptions) (adapter.Outbound, error) { - outboundDialer, err := dialer.New(router, options.DialerOptions) + outboundDialer, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/protocol/wireguard/outbound.go b/protocol/wireguard/outbound.go index 7251de9e..90f76c1a 100644 --- a/protocol/wireguard/outbound.go +++ b/protocol/wireguard/outbound.go @@ -78,7 +78,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL options.IsWireGuardListener = true outbound.useStdNetBind = true } - listener, err := dialer.New(router, options.DialerOptions) + listener, err := dialer.New(ctx, options.DialerOptions) if err != nil { return nil, err } diff --git a/route/router.go b/route/router.go index 9a57328a..557b043a 100644 --- a/route/router.go +++ b/route/router.go @@ -87,7 +87,7 @@ type Router struct { v2rayServer adapter.V2RayServer platformInterface platform.Interface needWIFIState bool - needPackageManager bool + enforcePackageManager bool wifiState adapter.WIFIState started bool } @@ -123,7 +123,7 @@ func NewRouter( pauseManager: service.FromContext[pause.Manager](ctx), platformInterface: service.FromContext[platform.Interface](ctx), needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule), - needPackageManager: common.Any(inbounds, func(inbound option.Inbound) bool { + enforcePackageManager: common.Any(inbounds, func(inbound option.Inbound) bool { if tunOptions, isTUN := inbound.Options.(*option.TunInboundOptions); isTUN && tunOptions.AutoRoute { return true } @@ -191,7 +191,8 @@ func NewRouter( transportTags[i] = tag transportTagMap[tag] = true } - ctx = adapter.ContextWithRouter(ctx, router) + ctx = service.ContextWith[adapter.Router](ctx, router) + outboundManager := service.FromContext[adapter.OutboundManager](ctx) for { lastLen := len(dummyTransportMap) for i, server := range dnsOptions.Servers { @@ -201,9 +202,9 @@ func NewRouter( } var detour N.Dialer if server.Detour == "" { - detour = dialer.NewRouter(router) + detour = dialer.NewDefaultOutbound(outboundManager) } else { - detour = dialer.NewDetour(router, server.Detour) + detour = dialer.NewDetour(outboundManager, server.Detour) } var serverProtocol string switch server.Address { @@ -327,7 +328,7 @@ func NewRouter( } usePlatformDefaultInterfaceMonitor := router.platformInterface != nil && router.platformInterface.UsePlatformDefaultInterfaceMonitor() - needInterfaceMonitor := options.AutoDetectInterface || common.Any(inbounds, func(inbound option.Inbound) bool { + enforceInterfaceMonitor := options.AutoDetectInterface || common.Any(inbounds, func(inbound option.Inbound) bool { if httpMixedOptions, isHTTPMixed := inbound.Options.(*option.HTTPMixedInboundOptions); isHTTPMixed && httpMixedOptions.SetSystemProxy { return true } @@ -339,7 +340,7 @@ func NewRouter( if !usePlatformDefaultInterfaceMonitor { networkMonitor, err := tun.NewNetworkUpdateMonitor(router.logger) - if !((err != nil && !needInterfaceMonitor) || errors.Is(err, os.ErrInvalid)) { + if !((err != nil && !enforceInterfaceMonitor) || errors.Is(err, os.ErrInvalid)) { if err != nil { return nil, err } @@ -365,7 +366,7 @@ func NewRouter( } if ntpOptions.Enabled { - ntpDialer, err := dialer.New(router, ntpOptions.DialerOptions) + ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions) if err != nil { return nil, E.Cause(err, "create NTP service") } @@ -383,73 +384,6 @@ func NewRouter( return router, nil } -func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error { - inboundByTag := make(map[string]adapter.Inbound) - for _, inbound := range inbounds { - inboundByTag[inbound.Tag()] = inbound - } - outboundByTag := make(map[string]adapter.Outbound) - for _, detour := range outbounds { - outboundByTag[detour.Tag()] = detour - } - var defaultOutboundForConnection adapter.Outbound - var defaultOutboundForPacketConnection adapter.Outbound - if r.defaultDetour != "" { - detour, loaded := outboundByTag[r.defaultDetour] - if !loaded { - return E.New("default detour not found: ", r.defaultDetour) - } - if common.Contains(detour.Network(), N.NetworkTCP) { - defaultOutboundForConnection = detour - } - if common.Contains(detour.Network(), N.NetworkUDP) { - defaultOutboundForPacketConnection = detour - } - } - if defaultOutboundForConnection == nil { - for _, detour := range outbounds { - if common.Contains(detour.Network(), N.NetworkTCP) { - defaultOutboundForConnection = detour - break - } - } - } - if defaultOutboundForPacketConnection == nil { - for _, detour := range outbounds { - if common.Contains(detour.Network(), N.NetworkUDP) { - defaultOutboundForPacketConnection = detour - break - } - } - } - if defaultOutboundForConnection == nil || defaultOutboundForPacketConnection == nil { - detour := defaultOutbound() - if defaultOutboundForConnection == nil { - defaultOutboundForConnection = detour - } - if defaultOutboundForPacketConnection == nil { - defaultOutboundForPacketConnection = detour - } - outbounds = append(outbounds, detour) - outboundByTag[detour.Tag()] = detour - } - r.inboundByTag = inboundByTag - r.outbounds = outbounds - r.defaultOutboundForConnection = defaultOutboundForConnection - r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection - r.outboundByTag = outboundByTag - for i, rule := range r.rules { - routeAction, isRoute := rule.Action().(*R.RuleActionRoute) - if !isRoute { - continue - } - if _, loaded := outboundByTag[routeAction.Outbound]; !loaded { - return E.New("outbound not found for rule[", i, "]: ", routeAction.Outbound) - } - } - return nil -} - func (r *Router) Outbounds() []adapter.Outbound { if !r.started { return nil @@ -457,140 +391,240 @@ func (r *Router) Outbounds() []adapter.Outbound { return r.outbounds } -func (r *Router) PreStart() error { +func (r *Router) Start(stage adapter.StartStage) error { monitor := taskmonitor.New(r.logger, C.StartTimeout) - if r.interfaceMonitor != nil { - monitor.Start("initialize interface monitor") - err := r.interfaceMonitor.Start() - monitor.Finish() - if err != nil { - return err - } - } - if r.networkMonitor != nil { - monitor.Start("initialize network monitor") - err := r.networkMonitor.Start() - monitor.Finish() - if err != nil { - return err - } - } - if r.fakeIPStore != nil { - monitor.Start("initialize fakeip store") - err := r.fakeIPStore.Start() - monitor.Finish() - if err != nil { - return err - } - } - return nil -} - -func (r *Router) Start() error { - monitor := taskmonitor.New(r.logger, C.StartTimeout) - if r.needGeoIPDatabase { - monitor.Start("initialize geoip database") - err := r.prepareGeoIPDatabase() - monitor.Finish() - if err != nil { - return err - } - } - if r.needGeositeDatabase { - monitor.Start("initialize geosite database") - err := r.prepareGeositeDatabase() - monitor.Finish() - if err != nil { - return err - } - } - if r.needGeositeDatabase { - for _, rule := range r.rules { - err := rule.UpdateGeosite() - if err != nil { - r.logger.Error("failed to initialize geosite: ", err) - } - } - for _, rule := range r.dnsRules { - err := rule.UpdateGeosite() - if err != nil { - r.logger.Error("failed to initialize geosite: ", err) - } - } - err := common.Close(r.geositeReader) - if err != nil { - return err - } - r.geositeCache = nil - r.geositeReader = nil - } - - if runtime.GOOS == "windows" { - powerListener, err := winpowrprof.NewEventListener(r.notifyWindowsPowerEvent) - if err == nil { - r.powerListener = powerListener - } else { - r.logger.Warn("initialize power listener: ", err) - } - } - - if r.powerListener != nil { - monitor.Start("start power listener") - err := r.powerListener.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "start power listener") - } - } - - monitor.Start("initialize DNS client") - r.dnsClient.Start() - monitor.Finish() - - if C.IsAndroid && r.platformInterface == nil { - monitor.Start("initialize package manager") - packageManager, err := tun.NewPackageManager(tun.PackageManagerOptions{ - Callback: r, - Logger: r.logger, - }) - monitor.Finish() - if err != nil { - return E.Cause(err, "create package manager") - } - if r.needPackageManager { - monitor.Start("start package manager") - err = packageManager.Start() + switch stage { + case adapter.StartStateInitialize: + if r.interfaceMonitor != nil { + monitor.Start("initialize interface monitor") + err := r.interfaceMonitor.Start() monitor.Finish() if err != nil { - return E.Cause(err, "start package manager") + return err } } - r.packageManager = packageManager - } + if r.networkMonitor != nil { + monitor.Start("initialize network monitor") + err := r.networkMonitor.Start() + monitor.Finish() + if err != nil { + return err + } + } + if r.fakeIPStore != nil { + monitor.Start("initialize fakeip store") + err := r.fakeIPStore.Start() + monitor.Finish() + if err != nil { + return err + } + } + case adapter.StartStateStart: + if r.needGeoIPDatabase { + monitor.Start("initialize geoip database") + err := r.prepareGeoIPDatabase() + monitor.Finish() + if err != nil { + return err + } + } + if r.needGeositeDatabase { + monitor.Start("initialize geosite database") + err := r.prepareGeositeDatabase() + monitor.Finish() + if err != nil { + return err + } + } + if r.needGeositeDatabase { + for _, rule := range r.rules { + err := rule.UpdateGeosite() + if err != nil { + r.logger.Error("failed to initialize geosite: ", err) + } + } + for _, rule := range r.dnsRules { + err := rule.UpdateGeosite() + if err != nil { + r.logger.Error("failed to initialize geosite: ", err) + } + } + err := common.Close(r.geositeReader) + if err != nil { + return err + } + r.geositeCache = nil + r.geositeReader = nil + } - for i, rule := range r.dnsRules { - monitor.Start("initialize DNS rule[", i, "]") - err := rule.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "initialize DNS rule[", i, "]") + if runtime.GOOS == "windows" { + powerListener, err := winpowrprof.NewEventListener(r.notifyWindowsPowerEvent) + if err == nil { + r.powerListener = powerListener + } else { + r.logger.Warn("initialize power listener: ", err) + } } - } - for i, transport := range r.transports { - monitor.Start("initialize DNS transport[", i, "]") - err := transport.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "initialize DNS server[", i, "]") + + if r.powerListener != nil { + monitor.Start("start power listener") + err := r.powerListener.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "start power listener") + } } - } - if r.timeService != nil { - monitor.Start("initialize time service") - err := r.timeService.Start() + + monitor.Start("initialize DNS client") + r.dnsClient.Start() monitor.Finish() - if err != nil { - return E.Cause(err, "initialize time service") + + if C.IsAndroid && r.platformInterface == nil { + monitor.Start("initialize package manager") + packageManager, err := tun.NewPackageManager(tun.PackageManagerOptions{ + Callback: r, + Logger: r.logger, + }) + monitor.Finish() + if err != nil { + return E.Cause(err, "create package manager") + } + if r.enforcePackageManager { + monitor.Start("start package manager") + err = packageManager.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "start package manager") + } + } + r.packageManager = packageManager } + + for i, rule := range r.dnsRules { + monitor.Start("initialize DNS rule[", i, "]") + err := rule.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "initialize DNS rule[", i, "]") + } + } + for i, transport := range r.transports { + monitor.Start("initialize DNS transport[", i, "]") + err := transport.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "initialize DNS server[", i, "]") + } + } + if r.timeService != nil { + monitor.Start("initialize time service") + err := r.timeService.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "initialize time service") + } + } + case adapter.StartStatePostStart: + var cacheContext *adapter.HTTPStartContext + if len(r.ruleSets) > 0 { + monitor.Start("initialize rule-set") + cacheContext = adapter.NewHTTPStartContext() + var ruleSetStartGroup task.Group + for i, ruleSet := range r.ruleSets { + ruleSetInPlace := ruleSet + ruleSetStartGroup.Append0(func(ctx context.Context) error { + err := ruleSetInPlace.StartContext(ctx, cacheContext) + if err != nil { + return E.Cause(err, "initialize rule-set[", i, "]") + } + return nil + }) + } + ruleSetStartGroup.Concurrency(5) + ruleSetStartGroup.FastFail() + err := ruleSetStartGroup.Run(r.ctx) + monitor.Finish() + if err != nil { + return err + } + } + if cacheContext != nil { + cacheContext.Close() + } + needFindProcess := r.needFindProcess + needWIFIState := r.needWIFIState + for _, ruleSet := range r.ruleSets { + metadata := ruleSet.Metadata() + if metadata.ContainsProcessRule { + needFindProcess = true + } + if metadata.ContainsWIFIRule { + needWIFIState = true + } + } + if C.IsAndroid && r.platformInterface == nil && !r.enforcePackageManager { + if needFindProcess { + monitor.Start("start package manager") + err := r.packageManager.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "start package manager") + } + } else { + r.packageManager = nil + } + } + if needFindProcess { + if r.platformInterface != nil { + r.processSearcher = r.platformInterface + } else { + monitor.Start("initialize process searcher") + searcher, err := process.NewSearcher(process.Config{ + Logger: r.logger, + PackageManager: r.packageManager, + }) + monitor.Finish() + if err != nil { + if err != os.ErrInvalid { + r.logger.Warn(E.Cause(err, "create process searcher")) + } + } else { + r.processSearcher = searcher + } + } + } + if needWIFIState && r.platformInterface != nil { + monitor.Start("initialize WIFI state") + r.needWIFIState = true + r.interfaceMonitor.RegisterCallback(func(_ int) { + r.updateWIFIState() + }) + r.updateWIFIState() + monitor.Finish() + } + for i, rule := range r.rules { + monitor.Start("initialize rule[", i, "]") + err := rule.Start() + monitor.Finish() + if err != nil { + return E.Cause(err, "initialize rule[", i, "]") + } + } + for _, ruleSet := range r.ruleSets { + monitor.Start("post start rule_set[", ruleSet.Name(), "]") + err := ruleSet.PostStart() + monitor.Finish() + if err != nil { + return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]") + } + } + r.started = true + return nil + case adapter.StartStateStarted: + for _, ruleSet := range r.ruleSetMap { + ruleSet.Cleanup() + } + runtime.GC() } return nil } @@ -671,113 +705,6 @@ func (r *Router) Close() error { return err } -func (r *Router) PostStart() error { - monitor := taskmonitor.New(r.logger, C.StopTimeout) - var cacheContext *adapter.HTTPStartContext - if len(r.ruleSets) > 0 { - monitor.Start("initialize rule-set") - cacheContext = adapter.NewHTTPStartContext() - var ruleSetStartGroup task.Group - for i, ruleSet := range r.ruleSets { - ruleSetInPlace := ruleSet - ruleSetStartGroup.Append0(func(ctx context.Context) error { - err := ruleSetInPlace.StartContext(ctx, cacheContext) - if err != nil { - return E.Cause(err, "initialize rule-set[", i, "]") - } - return nil - }) - } - ruleSetStartGroup.Concurrency(5) - ruleSetStartGroup.FastFail() - err := ruleSetStartGroup.Run(r.ctx) - monitor.Finish() - if err != nil { - return err - } - } - if cacheContext != nil { - cacheContext.Close() - } - needFindProcess := r.needFindProcess - needWIFIState := r.needWIFIState - for _, ruleSet := range r.ruleSets { - metadata := ruleSet.Metadata() - if metadata.ContainsProcessRule { - needFindProcess = true - } - if metadata.ContainsWIFIRule { - needWIFIState = true - } - } - if C.IsAndroid && r.platformInterface == nil && !r.needPackageManager { - if needFindProcess { - monitor.Start("start package manager") - err := r.packageManager.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "start package manager") - } - } else { - r.packageManager = nil - } - } - if needFindProcess { - if r.platformInterface != nil { - r.processSearcher = r.platformInterface - } else { - monitor.Start("initialize process searcher") - searcher, err := process.NewSearcher(process.Config{ - Logger: r.logger, - PackageManager: r.packageManager, - }) - monitor.Finish() - if err != nil { - if err != os.ErrInvalid { - r.logger.Warn(E.Cause(err, "create process searcher")) - } - } else { - r.processSearcher = searcher - } - } - } - if needWIFIState && r.platformInterface != nil { - monitor.Start("initialize WIFI state") - r.needWIFIState = true - r.interfaceMonitor.RegisterCallback(func(_ int) { - r.updateWIFIState() - }) - r.updateWIFIState() - monitor.Finish() - } - for i, rule := range r.rules { - monitor.Start("initialize rule[", i, "]") - err := rule.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "initialize rule[", i, "]") - } - } - for _, ruleSet := range r.ruleSets { - monitor.Start("post start rule_set[", ruleSet.Name(), "]") - err := ruleSet.PostStart() - monitor.Finish() - if err != nil { - return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]") - } - } - r.started = true - return nil -} - -func (r *Router) Cleanup() error { - for _, ruleSet := range r.ruleSetMap { - ruleSet.Cleanup() - } - runtime.GC() - return nil -} - func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { outbound, loaded := r.outboundByTag[tag] return outbound, loaded diff --git a/route/rule/rule_action.go b/route/rule/rule_action.go index 0bf45ba2..00579c18 100644 --- a/route/rule/rule_action.go +++ b/route/rule/rule_action.go @@ -22,7 +22,7 @@ import ( N "github.com/sagernet/sing/common/network" ) -func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) { +func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) { switch action.Action { case "": return nil, nil @@ -36,7 +36,7 @@ func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action op UDPConnect: action.RouteOptionsOptions.UDPConnect, }, nil case C.RuleActionTypeDirect: - directDialer, err := dialer.New(router, option.DialerOptions(action.DirectOptions)) + directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions)) if err != nil { return nil, err } diff --git a/route/rule/rule_default.go b/route/rule/rule_default.go index 566c816e..a12c63ef 100644 --- a/route/rule/rule_default.go +++ b/route/rule/rule_default.go @@ -52,7 +52,7 @@ type RuleItem interface { } func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { - action, err := NewRuleAction(router, logger, options.RuleAction) + action, err := NewRuleAction(ctx, logger, options.RuleAction) if err != nil { return nil, E.Cause(err, "action") } @@ -254,7 +254,7 @@ type LogicalRule struct { } func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { - action, err := NewRuleAction(router, logger, options.RuleAction) + action, err := NewRuleAction(ctx, logger, options.RuleAction) if err != nil { return nil, E.Cause(err, "action") } diff --git a/route/rule/rule_set_remote.go b/route/rule/rule_set_remote.go index bdafa656..55c863e9 100644 --- a/route/rule/rule_set_remote.go +++ b/route/rule/rule_set_remote.go @@ -33,23 +33,24 @@ import ( var _ adapter.RuleSet = (*RemoteRuleSet)(nil) type RemoteRuleSet struct { - ctx context.Context - cancel context.CancelFunc - router adapter.Router - logger logger.ContextLogger - options option.RuleSet - metadata adapter.RuleSetMetadata - updateInterval time.Duration - dialer N.Dialer - rules []adapter.HeadlessRule - lastUpdated time.Time - lastEtag string - updateTicker *time.Ticker - cacheFile adapter.CacheFile - pauseManager pause.Manager - callbackAccess sync.Mutex - callbacks list.List[adapter.RuleSetUpdateCallback] - refs atomic.Int32 + ctx context.Context + cancel context.CancelFunc + router adapter.Router + outboundManager adapter.OutboundManager + logger logger.ContextLogger + options option.RuleSet + metadata adapter.RuleSetMetadata + updateInterval time.Duration + dialer N.Dialer + rules []adapter.HeadlessRule + lastUpdated time.Time + lastEtag string + updateTicker *time.Ticker + cacheFile adapter.CacheFile + pauseManager pause.Manager + callbackAccess sync.Mutex + callbacks list.List[adapter.RuleSetUpdateCallback] + refs atomic.Int32 } func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet { @@ -61,13 +62,14 @@ func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger. updateInterval = 24 * time.Hour } return &RemoteRuleSet{ - ctx: ctx, - cancel: cancel, - router: router, - logger: logger, - options: options, - updateInterval: updateInterval, - pauseManager: service.FromContext[pause.Manager](ctx), + ctx: ctx, + cancel: cancel, + router: router, + outboundManager: service.FromContext[adapter.OutboundManager](ctx), + logger: logger, + options: options, + updateInterval: updateInterval, + pauseManager: service.FromContext[pause.Manager](ctx), } } @@ -83,17 +85,13 @@ func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext *adapter. s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx) var dialer N.Dialer if s.options.RemoteOptions.DownloadDetour != "" { - outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour) + outbound, loaded := s.outboundManager.Outbound(s.options.RemoteOptions.DownloadDetour) if !loaded { return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour) } dialer = outbound } else { - outbound, err := s.router.DefaultOutbound(N.NetworkTCP) - if err != nil { - return err - } - dialer = outbound + dialer = s.outboundManager.Default() } s.dialer = dialer if s.cacheFile != nil { diff --git a/transport/dhcp/server.go b/transport/dhcp/server.go index 8325c37b..d5603e9e 100644 --- a/transport/dhcp/server.go +++ b/transport/dhcp/server.go @@ -23,6 +23,7 @@ import ( E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/x/list" + "github.com/sagernet/sing/service" "github.com/insomniacslk/dhcp/dhcpv4" mDNS "github.com/miekg/dns" @@ -53,7 +54,7 @@ func NewTransport(options dns.TransportOptions) (*Transport, error) { if linkURL.Host == "" { return nil, E.New("missing interface name for DHCP") } - router := adapter.RouterFromContext(options.Context) + router := service.FromContext[adapter.Router](options.Context) if router == nil { return nil, E.New("missing router in context") } diff --git a/transport/fakeip/server.go b/transport/fakeip/server.go index 5e0c7eef..d1bbb2aa 100644 --- a/transport/fakeip/server.go +++ b/transport/fakeip/server.go @@ -9,6 +9,7 @@ import ( "github.com/sagernet/sing-dns" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" + "github.com/sagernet/sing/service" mDNS "github.com/miekg/dns" ) @@ -32,7 +33,7 @@ type Transport struct { } func NewTransport(options dns.TransportOptions) (*Transport, error) { - router := adapter.RouterFromContext(options.Context) + router := service.FromContext[adapter.Router](options.Context) if router == nil { return nil, E.New("missing router in context") }