refactor: Modular inbound/outbound manager

This commit is contained in:
世界 2024-11-09 21:16:11 +08:00
parent 1ee7a4a272
commit beaab2e4db
No known key found for this signature in database
GPG Key ID: CD109927C34A63C4
52 changed files with 982 additions and 680 deletions

View File

@ -15,7 +15,7 @@ import (
type ClashServer interface { type ClashServer interface {
Service Service
PreStarter LegacyPreStarter
Mode() string Mode() string
ModeList() []string ModeList() []string
HistoryStorage() *urltest.HistoryStorage HistoryStorage() *urltest.HistoryStorage
@ -25,7 +25,7 @@ type ClashServer interface {
type CacheFile interface { type CacheFile interface {
Service Service
PreStarter LegacyPreStarter
StoreFakeIP() bool StoreFakeIP() bool
FakeIPStorage FakeIPStorage

View File

@ -28,7 +28,15 @@ type UDPInjectableInbound interface {
type InboundRegistry interface { type InboundRegistry interface {
option.InboundOptionsRegistry option.InboundOptionsRegistry
CreateInbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Inbound, error) Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, inboundType string, options any) (Inbound, error)
}
type InboundManager interface {
NewService
Inbounds() []Inbound
Get(tag string) (Inbound, bool)
Remove(tag string) error
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, inboundType string, options any) error
} }
type InboundContext struct { type InboundContext struct {

143
adapter/inbound/manager.go Normal file
View File

@ -0,0 +1,143 @@
package inbound
import (
"context"
"os"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
)
var _ adapter.InboundManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.InboundRegistry
access sync.Mutex
started bool
stage adapter.StartStage
inbounds []adapter.Inbound
inboundByTag map[string]adapter.Inbound
}
func NewManager(logger log.ContextLogger, registry adapter.InboundRegistry) *Manager {
return &Manager{
logger: logger,
registry: registry,
inboundByTag: make(map[string]adapter.Inbound),
}
}
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
for _, inbound := range m.inbounds {
err := adapter.LegacyStart(inbound, stage)
if err != nil {
return E.Cause(err, stage.Action(), " inbound/", inbound.Type(), "[", inbound.Tag(), "]")
}
}
return nil
}
func (m *Manager) Close() error {
m.access.Lock()
if !m.started {
panic("not started")
}
m.started = false
inbounds := m.inbounds
m.inbounds = nil
m.access.Unlock()
monitor := taskmonitor.New(m.logger, C.StopTimeout)
var err error
for _, inbound := range inbounds {
monitor.Start("close inbound/", inbound.Type(), "[", inbound.Tag(), "]")
err = E.Append(err, inbound.Close(), func(err error) error {
return E.Cause(err, "close inbound/", inbound.Type(), "[", inbound.Tag(), "]")
})
monitor.Finish()
}
return nil
}
func (m *Manager) Inbounds() []adapter.Inbound {
m.access.Lock()
defer m.access.Unlock()
return m.inbounds
}
func (m *Manager) Get(tag string) (adapter.Inbound, bool) {
m.access.Lock()
defer m.access.Unlock()
inbound, found := m.inboundByTag[tag]
return inbound, found
}
func (m *Manager) Remove(tag string) error {
m.access.Lock()
inbound, found := m.inboundByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.inboundByTag, tag)
index := common.Index(m.inbounds, func(it adapter.Inbound) bool {
return it == inbound
})
if index == -1 {
panic("invalid inbound index")
}
m.inbounds = append(m.inbounds[:index], m.inbounds[index+1:]...)
started := m.started
m.access.Unlock()
if started {
return inbound.Close()
}
return nil
}
func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) error {
inbound, err := m.registry.Create(ctx, router, logger, tag, outboundType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(inbound, stage)
if err != nil {
return E.Cause(err, stage.Action(), " inbound/", inbound.Type(), "[", inbound.Tag(), "]")
}
}
}
if existsInbound, loaded := m.inboundByTag[tag]; loaded {
if m.started {
err = existsInbound.Close()
if err != nil {
return E.Cause(err, "close inbound/", existsInbound.Type(), "[", existsInbound.Tag(), "]")
}
}
existsIndex := common.Index(m.inbounds, func(it adapter.Inbound) bool {
return it == existsInbound
})
if existsIndex == -1 {
panic("invalid inbound index")
}
m.inbounds = append(m.inbounds[:existsIndex], m.inbounds[existsIndex+1:]...)
}
m.inbounds = append(m.inbounds, inbound)
m.inboundByTag[tag] = inbound
return nil
}

View File

@ -28,41 +28,41 @@ type (
) )
type Registry struct { type Registry struct {
access sync.Mutex access sync.Mutex
optionsType map[string]optionsConstructorFunc optionsType map[string]optionsConstructorFunc
constructors map[string]constructorFunc constructor map[string]constructorFunc
} }
func NewRegistry() *Registry { func NewRegistry() *Registry {
return &Registry{ return &Registry{
optionsType: make(map[string]optionsConstructorFunc), optionsType: make(map[string]optionsConstructorFunc),
constructors: make(map[string]constructorFunc), constructor: make(map[string]constructorFunc),
} }
} }
func (r *Registry) CreateOptions(outboundType string) (any, bool) { func (m *Registry) CreateOptions(outboundType string) (any, bool) {
r.access.Lock() m.access.Lock()
defer r.access.Unlock() defer m.access.Unlock()
optionsConstructor, loaded := r.optionsType[outboundType] optionsConstructor, loaded := m.optionsType[outboundType]
if !loaded { if !loaded {
return nil, false return nil, false
} }
return optionsConstructor(), true return optionsConstructor(), true
} }
func (r *Registry) CreateInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Inbound, error) { func (m *Registry) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, outboundType string, options any) (adapter.Inbound, error) {
r.access.Lock() m.access.Lock()
defer r.access.Unlock() defer m.access.Unlock()
constructor, loaded := r.constructors[outboundType] constructor, loaded := m.constructor[outboundType]
if !loaded { if !loaded {
return nil, E.New("outbound type not found: " + outboundType) return nil, E.New("outbound type not found: " + outboundType)
} }
return constructor(ctx, router, logger, tag, options) return constructor(ctx, router, logger, tag, options)
} }
func (r *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) { func (m *Registry) register(outboundType string, optionsConstructor optionsConstructorFunc, constructor constructorFunc) {
r.access.Lock() m.access.Lock()
defer r.access.Unlock() defer m.access.Unlock()
r.optionsType[outboundType] = optionsConstructor m.optionsType[outboundType] = optionsConstructor
r.constructors[outboundType] = constructor m.constructor[outboundType] = constructor
} }

41
adapter/lifecycle.go Normal file
View File

@ -0,0 +1,41 @@
package adapter
type StartStage uint8
const (
StartStateInitialize StartStage = iota
StartStateStart
StartStatePostStart
StartStateStarted
)
var ListStartStages = []StartStage{
StartStateInitialize,
StartStateStart,
StartStatePostStart,
StartStateStarted,
}
func (s StartStage) Action() string {
switch s {
case StartStateInitialize:
return "initialize"
case StartStateStart:
return "start"
case StartStatePostStart:
return "post-start"
case StartStateStarted:
return "start-after-started"
default:
panic("unknown stage")
}
}
type NewService interface {
NewStarter
Close() error
}
type NewStarter interface {
Start(stage StartStage) error
}

View File

@ -0,0 +1,33 @@
package adapter
type LegacyPreStarter interface {
PreStart() error
}
type LegacyPostStarter interface {
PostStart() error
}
func LegacyStart(starter any, stage StartStage) error {
switch stage {
case StartStateInitialize:
if preStarter, isPreStarter := starter.(interface {
PreStart() error
}); isPreStarter {
return preStarter.PreStart()
}
case StartStateStart:
if starter, isStarter := starter.(interface {
Start() error
}); isStarter {
return starter.Start()
}
case StartStatePostStart:
if postStarter, isPostStarter := starter.(interface {
PostStart() error
}); isPostStarter {
return postStarter.PostStart()
}
}
return nil
}

View File

@ -22,3 +22,12 @@ type OutboundRegistry interface {
option.OutboundOptionsRegistry option.OutboundOptionsRegistry
CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error) CreateOutbound(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) (Outbound, error)
} }
type OutboundManager interface {
NewService
Outbounds() []Outbound
Outbound(tag string) (Outbound, bool)
Default() Outbound
Remove(tag string) error
Create(ctx context.Context, router Router, logger log.ContextLogger, tag string, outboundType string, options any) error
}

265
adapter/outbound/manager.go Normal file
View File

@ -0,0 +1,265 @@
package outbound
import (
"context"
"io"
"os"
"strings"
"sync"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
)
var _ adapter.OutboundManager = (*Manager)(nil)
type Manager struct {
logger log.ContextLogger
registry adapter.OutboundRegistry
defaultTag string
access sync.Mutex
started bool
stage adapter.StartStage
outbounds []adapter.Outbound
outboundByTag map[string]adapter.Outbound
dependByTag map[string][]string
defaultOutbound adapter.Outbound
defaultOutboundFallback adapter.Outbound
}
func NewManager(logger logger.ContextLogger, registry adapter.OutboundRegistry, defaultTag string) *Manager {
return &Manager{
logger: logger,
registry: registry,
defaultTag: defaultTag,
outboundByTag: make(map[string]adapter.Outbound),
dependByTag: make(map[string][]string),
}
}
func (m *Manager) Initialize(defaultOutboundFallback adapter.Outbound) {
m.defaultOutboundFallback = defaultOutboundFallback
}
func (m *Manager) Start(stage adapter.StartStage) error {
m.access.Lock()
defer m.access.Unlock()
if m.started && m.stage >= stage {
panic("already started")
}
m.started = true
m.stage = stage
if stage == adapter.StartStateStart {
m.startOutbounds()
} else {
for _, outbound := range m.outbounds {
err := adapter.LegacyStart(outbound, stage)
if err != nil {
return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
}
}
}
return nil
}
func (m *Manager) startOutbounds() error {
monitor := taskmonitor.New(m.logger, C.StartTimeout)
started := make(map[string]bool)
for {
canContinue := false
startOne:
for _, outboundToStart := range m.outbounds {
outboundTag := outboundToStart.Tag()
if started[outboundTag] {
continue
}
dependencies := outboundToStart.Dependencies()
for _, dependency := range dependencies {
if !started[dependency] {
continue startOne
}
}
started[outboundTag] = true
canContinue = true
if starter, isStarter := outboundToStart.(interface {
Start() error
}); isStarter {
monitor.Start("start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
err := starter.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start outbound/", outboundToStart.Type(), "[", outboundTag, "]")
}
}
}
if len(started) == len(m.outbounds) {
break
}
if canContinue {
continue
}
currentOutbound := common.Find(m.outbounds, func(it adapter.Outbound) bool {
return !started[it.Tag()]
})
var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
return !started[it]
})
if common.Contains(oTree, problemOutboundTag) {
return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
}
problemOutbound := m.outboundByTag[problemOutboundTag]
if problemOutbound == nil {
return E.New("dependency[", problemOutboundTag, "] not found for outbound[", oCurrent.Tag(), "]")
}
return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
}
return lintOutbound([]string{currentOutbound.Tag()}, currentOutbound)
}
return nil
}
func (m *Manager) Close() error {
monitor := taskmonitor.New(m.logger, C.StopTimeout)
m.access.Lock()
if !m.started {
panic("not started")
}
m.started = false
outbounds := m.outbounds
m.outbounds = nil
m.access.Unlock()
var err error
for _, outbound := range outbounds {
if closer, isCloser := outbound.(io.Closer); isCloser {
monitor.Start("close outbound/", outbound.Type(), "[", outbound.Tag(), "]")
err = E.Append(err, closer.Close(), func(err error) error {
return E.Cause(err, "close outbound/", outbound.Type(), "[", outbound.Tag(), "]")
})
monitor.Finish()
}
}
return nil
}
func (m *Manager) Outbounds() []adapter.Outbound {
m.access.Lock()
defer m.access.Unlock()
return m.outbounds
}
func (m *Manager) Outbound(tag string) (adapter.Outbound, bool) {
m.access.Lock()
defer m.access.Unlock()
outbound, found := m.outboundByTag[tag]
return outbound, found
}
func (m *Manager) Default() adapter.Outbound {
m.access.Lock()
defer m.access.Unlock()
if m.defaultOutbound != nil {
return m.defaultOutbound
} else {
return m.defaultOutboundFallback
}
}
func (m *Manager) Remove(tag string) error {
m.access.Lock()
outbound, found := m.outboundByTag[tag]
if !found {
m.access.Unlock()
return os.ErrInvalid
}
delete(m.outboundByTag, tag)
index := common.Index(m.outbounds, func(it adapter.Outbound) bool {
return it == outbound
})
if index == -1 {
panic("invalid inbound index")
}
m.outbounds = append(m.outbounds[:index], m.outbounds[index+1:]...)
started := m.started
if m.defaultOutbound == outbound {
if len(m.outbounds) > 0 {
m.defaultOutbound = m.outbounds[0]
m.logger.Info("updated default outbound to ", m.defaultOutbound.Tag())
} else {
m.defaultOutbound = nil
}
}
dependBy := m.dependByTag[tag]
if len(dependBy) > 0 {
return E.New("outbound[", tag, "] is depended by ", strings.Join(dependBy, ", "))
}
dependencies := outbound.Dependencies()
for _, dependency := range dependencies {
if len(m.dependByTag[dependency]) == 1 {
delete(m.dependByTag, dependency)
} else {
m.dependByTag[dependency] = common.Filter(m.dependByTag[dependency], func(it string) bool {
return it != tag
})
}
}
m.access.Unlock()
if started {
return common.Close(outbound)
}
return nil
}
func (m *Manager) Create(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, inboundType string, options any) error {
if tag == "" {
return os.ErrInvalid
}
outbound, err := m.registry.CreateOutbound(ctx, router, logger, tag, inboundType, options)
if err != nil {
return err
}
m.access.Lock()
defer m.access.Unlock()
if m.started {
for _, stage := range adapter.ListStartStages {
err = adapter.LegacyStart(outbound, stage)
if err != nil {
return E.Cause(err, stage.Action(), " outbound/", outbound.Type(), "[", outbound.Tag(), "]")
}
}
}
if existsOutbound, loaded := m.outboundByTag[tag]; loaded {
if m.started {
err = common.Close(existsOutbound)
if err != nil {
return E.Cause(err, "close outbound/", existsOutbound.Type(), "[", existsOutbound.Tag(), "]")
}
}
existsIndex := common.Index(m.outbounds, func(it adapter.Outbound) bool {
return it == existsOutbound
})
if existsIndex == -1 {
panic("invalid inbound index")
}
m.outbounds = append(m.outbounds[:existsIndex], m.outbounds[existsIndex+1:]...)
}
m.outbounds = append(m.outbounds, outbound)
m.outboundByTag[tag] = outbound
dependencies := outbound.Dependencies()
for _, dependency := range dependencies {
m.dependByTag[dependency] = append(m.dependByTag[dependency], tag)
}
if tag == m.defaultTag || (m.defaultTag == "" && m.defaultOutbound == nil) {
m.defaultOutbound = outbound
if m.started {
m.logger.Info("updated default outbound to ", outbound.Tag())
}
}
return nil
}

View File

@ -1,9 +1 @@
package adapter package adapter
type PreStarter interface {
PreStart() error
}
type PostStarter interface {
PostStart() error
}

View File

@ -15,21 +15,13 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
mdns "github.com/miekg/dns" mdns "github.com/miekg/dns"
"go4.org/netipx" "go4.org/netipx"
) )
type Router interface { type Router interface {
Service NewService
PreStarter
PostStarter
Cleanup() error
Outbounds() []Outbound
Outbound(tag string) (Outbound, bool)
DefaultOutbound(network string) (Outbound, error)
FakeIPStore() FakeIPStore FakeIPStore() FakeIPStore
@ -84,14 +76,6 @@ type ConnectionRouterEx interface {
RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata InboundContext, onClose N.CloseHandlerFunc) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata InboundContext, onClose N.CloseHandlerFunc)
} }
func ContextWithRouter(ctx context.Context, router Router) context.Context {
return service.ContextWith(ctx, router)
}
func RouterFromContext(ctx context.Context) Router {
return service.FromContext[Router](ctx)
}
type RuleSet interface { type RuleSet interface {
Name() string Name() string
StartContext(ctx context.Context, startContext *HTTPStartContext) error StartContext(ctx context.Context, startContext *HTTPStartContext) error

142
box.go
View File

@ -9,6 +9,8 @@ import (
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
"github.com/sagernet/sing-box/adapter/outbound"
"github.com/sagernet/sing-box/common/taskmonitor" "github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental" "github.com/sagernet/sing-box/experimental"
@ -30,8 +32,8 @@ var _ adapter.Service = (*Box)(nil)
type Box struct { type Box struct {
createdAt time.Time createdAt time.Time
router adapter.Router router adapter.Router
inbounds []adapter.Inbound inbound *inbound.Manager
outbounds []adapter.Outbound outbound *outbound.Manager
logFactory log.Factory logFactory log.Factory
logger log.ContextLogger logger log.ContextLogger
preServices1 map[string]adapter.Service preServices1 map[string]adapter.Service
@ -66,6 +68,7 @@ func New(options Options) (*Box, error) {
if ctx == nil { if ctx == nil {
ctx = context.Background() ctx = context.Background()
} }
ctx = service.ContextWithDefaultRegistry(ctx)
inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx) inboundRegistry := service.FromContext[adapter.InboundRegistry](ctx)
if inboundRegistry == nil { if inboundRegistry == nil {
return nil, E.New("missing inbound registry in context") return nil, E.New("missing inbound registry in context")
@ -74,7 +77,6 @@ func New(options Options) (*Box, error) {
if outboundRegistry == nil { if outboundRegistry == nil {
return nil, E.New("missing outbound registry in context") return nil, E.New("missing outbound registry in context")
} }
ctx = service.ContextWithDefaultRegistry(ctx)
ctx = pause.WithDefaultManager(ctx) ctx = pause.WithDefaultManager(ctx)
experimentalOptions := common.PtrValueOrDefault(options.Experimental) experimentalOptions := common.PtrValueOrDefault(options.Experimental)
applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug)) applyDebugOptions(common.PtrValueOrDefault(experimentalOptions.Debug))
@ -106,10 +108,15 @@ func New(options Options) (*Box, error) {
if err != nil { if err != nil {
return nil, E.Cause(err, "create log factory") return nil, E.Cause(err, "create log factory")
} }
routeOptions := common.PtrValueOrDefault(options.Route)
inboundManager := inbound.NewManager(logFactory.NewLogger("inbound-manager"), inboundRegistry)
outboundManager := outbound.NewManager(logFactory.NewLogger("outbound-manager"), outboundRegistry, routeOptions.Final)
ctx = service.ContextWith[adapter.InboundManager](ctx, inboundManager)
ctx = service.ContextWith[adapter.OutboundManager](ctx, outboundManager)
router, err := route.NewRouter( router, err := route.NewRouter(
ctx, ctx,
logFactory, logFactory,
common.PtrValueOrDefault(options.Route), routeOptions,
common.PtrValueOrDefault(options.DNS), common.PtrValueOrDefault(options.DNS),
common.PtrValueOrDefault(options.NTP), common.PtrValueOrDefault(options.NTP),
options.Inbounds, options.Inbounds,
@ -127,7 +134,6 @@ func New(options Options) (*Box, error) {
}) })
} }
} }
inbounds := make([]adapter.Inbound, 0, len(options.Inbounds))
//nolint:staticcheck //nolint:staticcheck
if len(options.LegacyOutbounds) > 0 { if len(options.LegacyOutbounds) > 0 {
for _, legacyOutbound := range options.LegacyOutbounds { for _, legacyOutbound := range options.LegacyOutbounds {
@ -138,17 +144,14 @@ func New(options Options) (*Box, error) {
}) })
} }
} }
outbounds := make([]adapter.Outbound, 0, len(options.Outbounds))
for i, inboundOptions := range options.Inbounds { for i, inboundOptions := range options.Inbounds {
var currentInbound adapter.Inbound
var tag string var tag string
if inboundOptions.Tag != "" { if inboundOptions.Tag != "" {
tag = inboundOptions.Tag tag = inboundOptions.Tag
} else { } else {
tag = F.ToString(i) tag = F.ToString(i)
} }
currentInbound, err = inboundRegistry.CreateInbound( err = inboundManager.Create(ctx,
ctx,
router, router,
logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")), logFactory.NewLogger(F.ToString("inbound/", inboundOptions.Type, "[", tag, "]")),
tag, tag,
@ -156,12 +159,10 @@ func New(options Options) (*Box, error) {
inboundOptions.Options, inboundOptions.Options,
) )
if err != nil { if err != nil {
return nil, E.Cause(err, "parse inbound[", i, "]") return nil, E.Cause(err, "initialize inbound[", i, "]")
} }
inbounds = append(inbounds, currentInbound)
} }
for i, outboundOptions := range options.Outbounds { for i, outboundOptions := range options.Outbounds {
var currentOutbound adapter.Outbound
var tag string var tag string
if outboundOptions.Tag != "" { if outboundOptions.Tag != "" {
tag = outboundOptions.Tag tag = outboundOptions.Tag
@ -175,7 +176,7 @@ func New(options Options) (*Box, error) {
Outbound: tag, Outbound: tag,
}) })
} }
currentOutbound, err = outboundRegistry.CreateOutbound( err = outboundManager.Create(
outboundCtx, outboundCtx,
router, router,
logFactory.NewLogger(F.ToString("outbound/", outboundOptions.Type, "[", tag, "]")), logFactory.NewLogger(F.ToString("outbound/", outboundOptions.Type, "[", tag, "]")),
@ -184,16 +185,18 @@ func New(options Options) (*Box, error) {
outboundOptions.Options, outboundOptions.Options,
) )
if err != nil { if err != nil {
return nil, E.Cause(err, "parse outbound[", i, "]") return nil, E.Cause(err, "initialize outbound[", i, "]")
} }
outbounds = append(outbounds, currentOutbound)
} }
err = router.Initialize(inbounds, outbounds, func() adapter.Outbound { outboundManager.Initialize(common.Must1(
defaultOutbound, cErr := direct.NewOutbound(ctx, router, logFactory.NewLogger("outbound/direct"), "direct", option.DirectOutboundOptions{}) direct.NewOutbound(
common.Must(cErr) ctx,
outbounds = append(outbounds, defaultOutbound) router,
return defaultOutbound logFactory.NewLogger("outbound/direct"),
}) "direct",
option.DirectOutboundOptions{},
),
))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -217,7 +220,7 @@ func New(options Options) (*Box, error) {
if needClashAPI { if needClashAPI {
clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI) clashAPIOptions := common.PtrValueOrDefault(experimentalOptions.ClashAPI)
clashAPIOptions.ModeList = experimental.CalculateClashModeList(options.Options) clashAPIOptions.ModeList = experimental.CalculateClashModeList(options.Options)
clashServer, err := experimental.NewClashServer(ctx, router, logFactory.(log.ObservableFactory), clashAPIOptions) clashServer, err := experimental.NewClashServer(ctx, logFactory.(log.ObservableFactory), clashAPIOptions)
if err != nil { if err != nil {
return nil, E.Cause(err, "create clash api server") return nil, E.Cause(err, "create clash api server")
} }
@ -234,8 +237,8 @@ func New(options Options) (*Box, error) {
} }
return &Box{ return &Box{
router: router, router: router,
inbounds: inbounds, inbound: inboundManager,
outbounds: outbounds, outbound: outboundManager,
createdAt: createdAt, createdAt: createdAt,
logFactory: logFactory, logFactory: logFactory,
logger: logFactory.Logger(), logger: logFactory.Logger(),
@ -293,7 +296,7 @@ func (s *Box) preStart() error {
return E.Cause(err, "start logger") return E.Cause(err, "start logger")
} }
for serviceName, service := range s.preServices1 { for serviceName, service := range s.preServices1 {
if preService, isPreService := service.(adapter.PreStarter); isPreService { if preService, isPreService := service.(adapter.LegacyPreStarter); isPreService {
monitor.Start("pre-start ", serviceName) monitor.Start("pre-start ", serviceName)
err := preService.PreStart() err := preService.PreStart()
monitor.Finish() monitor.Finish()
@ -303,7 +306,7 @@ func (s *Box) preStart() error {
} }
} }
for serviceName, service := range s.preServices2 { for serviceName, service := range s.preServices2 {
if preService, isPreService := service.(adapter.PreStarter); isPreService { if preService, isPreService := service.(adapter.LegacyPreStarter); isPreService {
monitor.Start("pre-start ", serviceName) monitor.Start("pre-start ", serviceName)
err := preService.PreStart() err := preService.PreStart()
monitor.Finish() monitor.Finish()
@ -312,15 +315,15 @@ func (s *Box) preStart() error {
} }
} }
} }
err = s.router.PreStart() err = s.router.Start(adapter.StartStateInitialize)
if err != nil { if err != nil {
return E.Cause(err, "pre-start router") return E.Cause(err, "initialize router")
} }
err = s.startOutbounds() err = s.outbound.Start(adapter.StartStateStart)
if err != nil { if err != nil {
return err return err
} }
return s.router.Start() return s.router.Start(adapter.StartStateStart)
} }
func (s *Box) start() error { func (s *Box) start() error {
@ -340,52 +343,39 @@ func (s *Box) start() error {
return E.Cause(err, "start ", serviceName) return E.Cause(err, "start ", serviceName)
} }
} }
for i, in := range s.inbounds { err = s.inbound.Start(adapter.StartStateStart)
var tag string
if in.Tag() == "" {
tag = F.ToString(i)
} else {
tag = in.Tag()
}
err = in.Start()
if err != nil {
return E.Cause(err, "initialize inbound/", in.Type(), "[", tag, "]")
}
}
err = s.postStart()
if err != nil { if err != nil {
return err return err
} }
return s.router.Cleanup()
}
func (s *Box) postStart() error {
for serviceName, service := range s.postServices { for serviceName, service := range s.postServices {
err := service.Start() err := service.Start()
if err != nil { if err != nil {
return E.Cause(err, "start ", serviceName) return E.Cause(err, "start ", serviceName)
} }
} }
// TODO: reorganize ALL start order err = s.outbound.Start(adapter.StartStatePostStart)
for _, out := range s.outbounds {
if lateOutbound, isLateOutbound := out.(adapter.PostStarter); isLateOutbound {
err := lateOutbound.PostStart()
if err != nil {
return E.Cause(err, "post-start outbound/", out.Tag())
}
}
}
err := s.router.PostStart()
if err != nil { if err != nil {
return err return err
} }
for _, in := range s.inbounds { err = s.router.Start(adapter.StartStatePostStart)
if lateInbound, isLateInbound := in.(adapter.PostStarter); isLateInbound { if err != nil {
err = lateInbound.PostStart() return err
if err != nil { }
return E.Cause(err, "post-start inbound/", in.Tag()) err = s.inbound.Start(adapter.StartStatePostStart)
} if err != nil {
} return err
}
err = s.router.Start(adapter.StartStateStarted)
if err != nil {
return err
}
err = s.outbound.Start(adapter.StartStateStarted)
if err != nil {
return err
}
err = s.inbound.Start(adapter.StartStateStarted)
if err != nil {
return err
} }
return nil return nil
} }
@ -406,20 +396,8 @@ func (s *Box) Close() error {
}) })
monitor.Finish() monitor.Finish()
} }
for i, in := range s.inbounds { errors = E.Errors(errors, s.inbound.Close())
monitor.Start("close inbound/", in.Type(), "[", i, "]") errors = E.Errors(errors, s.outbound.Close())
errors = E.Append(errors, in.Close(), func(err error) error {
return E.Cause(err, "close inbound/", in.Type(), "[", i, "]")
})
monitor.Finish()
}
for i, out := range s.outbounds {
monitor.Start("close outbound/", out.Type(), "[", i, "]")
errors = E.Append(errors, common.Close(out), func(err error) error {
return E.Cause(err, "close outbound/", out.Type(), "[", i, "]")
})
monitor.Finish()
}
monitor.Start("close router") monitor.Start("close router")
if err := common.Close(s.router); err != nil { if err := common.Close(s.router); err != nil {
errors = E.Append(errors, err, func(err error) error { errors = E.Append(errors, err, func(err error) error {
@ -449,6 +427,14 @@ func (s *Box) Close() error {
return errors return errors
} }
func (s *Box) Inbound() adapter.InboundManager {
return s.inbound
}
func (s *Box) Outbound() adapter.OutboundManager {
return s.outbound
}
func (s *Box) Router() adapter.Router { func (s *Box) Router() adapter.Router {
return s.router return s.router
} }

View File

@ -1,85 +0,0 @@
package box
import (
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/taskmonitor"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
F "github.com/sagernet/sing/common/format"
)
func (s *Box) startOutbounds() error {
monitor := taskmonitor.New(s.logger, C.StartTimeout)
outboundTags := make(map[adapter.Outbound]string)
outbounds := make(map[string]adapter.Outbound)
for i, outboundToStart := range s.outbounds {
var outboundTag string
if outboundToStart.Tag() == "" {
outboundTag = F.ToString(i)
} else {
outboundTag = outboundToStart.Tag()
}
if _, exists := outbounds[outboundTag]; exists {
return E.New("outbound tag ", outboundTag, " duplicated")
}
outboundTags[outboundToStart] = outboundTag
outbounds[outboundTag] = outboundToStart
}
started := make(map[string]bool)
for {
canContinue := false
startOne:
for _, outboundToStart := range s.outbounds {
outboundTag := outboundTags[outboundToStart]
if started[outboundTag] {
continue
}
dependencies := outboundToStart.Dependencies()
for _, dependency := range dependencies {
if !started[dependency] {
continue startOne
}
}
started[outboundTag] = true
canContinue = true
if starter, isStarter := outboundToStart.(interface {
Start() error
}); isStarter {
monitor.Start("initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]")
err := starter.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize outbound/", outboundToStart.Type(), "[", outboundTag, "]")
}
}
}
if len(started) == len(s.outbounds) {
break
}
if canContinue {
continue
}
currentOutbound := common.Find(s.outbounds, func(it adapter.Outbound) bool {
return !started[outboundTags[it]]
})
var lintOutbound func(oTree []string, oCurrent adapter.Outbound) error
lintOutbound = func(oTree []string, oCurrent adapter.Outbound) error {
problemOutboundTag := common.Find(oCurrent.Dependencies(), func(it string) bool {
return !started[it]
})
if common.Contains(oTree, problemOutboundTag) {
return E.New("circular outbound dependency: ", strings.Join(oTree, " -> "), " -> ", problemOutboundTag)
}
problemOutbound := outbounds[problemOutboundTag]
if problemOutbound == nil {
return E.New("dependency[", problemOutboundTag, "] not found for outbound[", outboundTags[oCurrent], "]")
}
return lintOutbound(append(oTree, problemOutboundTag), problemOutbound)
}
return lintOutbound([]string{outboundTags[currentOutbound]}, currentOutbound)
}
return nil
}

View File

@ -41,11 +41,11 @@ func createPreStartedClient() (*box.Box, error) {
return instance, nil return instance, nil
} }
func createDialer(instance *box.Box, network string, outboundTag string) (N.Dialer, error) { func createDialer(instance *box.Box, outboundTag string) (N.Dialer, error) {
if outboundTag == "" { if outboundTag == "" {
return instance.Router().DefaultOutbound(N.NetworkName(network)) return instance.Outbound().Default(), nil
} else { } else {
outbound, loaded := instance.Router().Outbound(outboundTag) outbound, loaded := instance.Outbound().Outbound(outboundTag)
if !loaded { if !loaded {
return nil, E.New("outbound not found: ", outboundTag) return nil, E.New("outbound not found: ", outboundTag)
} }

View File

@ -45,7 +45,7 @@ func connect(address string) error {
return err return err
} }
defer instance.Close() defer instance.Close()
dialer, err := createDialer(instance, commandConnectFlagNetwork, commandToolsFlagOutbound) dialer, err := createDialer(instance, commandToolsFlagOutbound)
if err != nil { if err != nil {
return err return err
} }

View File

@ -48,7 +48,7 @@ func fetch(args []string) error {
httpClient = &http.Client{ httpClient = &http.Client{
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
dialer, err := createDialer(instance, network, commandToolsFlagOutbound) dialer, err := createDialer(instance, commandToolsFlagOutbound)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -16,7 +16,7 @@ import (
) )
func initializeHTTP3Client(instance *box.Box) error { func initializeHTTP3Client(instance *box.Box) error {
dialer, err := createDialer(instance, N.NetworkUDP, commandToolsFlagOutbound) dialer, err := createDialer(instance, commandToolsFlagOutbound)
if err != nil { if err != nil {
return err return err
} }

View File

@ -9,7 +9,6 @@ import (
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/ntp"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -45,7 +44,7 @@ func syncTime() error {
if err != nil { if err != nil {
return err return err
} }
dialer, err := createDialer(instance, N.NetworkUDP, commandToolsFlagOutbound) dialer, err := createDialer(instance, commandToolsFlagOutbound)
if err != nil { if err != nil {
return err return err
} }

View File

@ -12,15 +12,15 @@ import (
) )
type DetourDialer struct { type DetourDialer struct {
router adapter.Router outboundManager adapter.OutboundManager
detour string detour string
dialer N.Dialer dialer N.Dialer
initOnce sync.Once initOnce sync.Once
initErr error initErr error
} }
func NewDetour(router adapter.Router, detour string) N.Dialer { func NewDetour(outboundManager adapter.OutboundManager, detour string) N.Dialer {
return &DetourDialer{router: router, detour: detour} return &DetourDialer{outboundManager: outboundManager, detour: detour}
} }
func (d *DetourDialer) Start() error { func (d *DetourDialer) Start() error {
@ -31,7 +31,7 @@ func (d *DetourDialer) Start() error {
func (d *DetourDialer) Dialer() (N.Dialer, error) { func (d *DetourDialer) Dialer() (N.Dialer, error) {
d.initOnce.Do(func() { d.initOnce.Do(func() {
var loaded bool var loaded bool
d.dialer, loaded = d.router.Outbound(d.detour) d.dialer, loaded = d.outboundManager.Outbound(d.detour)
if !loaded { if !loaded {
d.initErr = E.New("outbound detour not found: ", d.detour) d.initErr = E.New("outbound detour not found: ", d.detour)
} }

View File

@ -1,21 +1,22 @@
package dialer package dialer
import ( import (
"context"
"time" "time"
"github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions"
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
) )
func New(router adapter.Router, options option.DialerOptions) (N.Dialer, error) { func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) {
router := service.FromContext[adapter.Router](ctx)
if options.IsWireGuardListener { if options.IsWireGuardListener {
return NewDefault(router, options) return NewDefault(router, options)
} }
if router == nil {
return NewDefault(nil, options)
}
var ( var (
dialer N.Dialer dialer N.Dialer
err error err error
@ -26,7 +27,14 @@ func New(router adapter.Router, options option.DialerOptions) (N.Dialer, error)
return nil, err return nil, err
} }
} else { } else {
dialer = NewDetour(router, options.Detour) outboundManager := service.FromContext[adapter.OutboundManager](ctx)
if outboundManager == nil {
return nil, E.New("missing outbound manager")
}
dialer = NewDetour(outboundManager, options.Detour)
}
if router == nil {
return NewDefault(router, options)
} }
if options.Detour == "" { if options.Detour == "" {
dialer = NewResolveDialer( dialer = NewResolveDialer(

View File

@ -9,30 +9,22 @@ import (
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
type RouterDialer struct { type DefaultOutboundDialer struct {
router adapter.Router outboundManager adapter.OutboundManager
} }
func NewRouter(router adapter.Router) N.Dialer { func NewDefaultOutbound(outboundManager adapter.OutboundManager) N.Dialer {
return &RouterDialer{router: router} return &DefaultOutboundDialer{outboundManager: outboundManager}
} }
func (d *RouterDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { func (d *DefaultOutboundDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
dialer, err := d.router.DefaultOutbound(network) return d.outboundManager.Default().DialContext(ctx, network, destination)
if err != nil {
return nil, err
}
return dialer.DialContext(ctx, network, destination)
} }
func (d *RouterDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { func (d *DefaultOutboundDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
dialer, err := d.router.DefaultOutbound(N.NetworkUDP) return d.outboundManager.Default().ListenPacket(ctx, destination)
if err != nil {
return nil, err
}
return dialer.ListenPacket(ctx, destination)
} }
func (d *RouterDialer) Upstream() any { func (d *DefaultOutboundDialer) Upstream() any {
return d.router return d.outboundManager.Default()
} }

View File

@ -12,6 +12,7 @@ import (
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/sing/common/shell" "github.com/sagernet/sing/common/shell"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
) )
type DarwinSystemProxy struct { type DarwinSystemProxy struct {
@ -24,7 +25,7 @@ type DarwinSystemProxy struct {
} }
func NewSystemProxy(ctx context.Context, serverAddr M.Socksaddr, supportSOCKS bool) (*DarwinSystemProxy, error) { func NewSystemProxy(ctx context.Context, serverAddr M.Socksaddr, supportSOCKS bool) (*DarwinSystemProxy, error) {
interfaceMonitor := adapter.RouterFromContext(ctx).InterfaceMonitor() interfaceMonitor := service.FromContext[adapter.Router](ctx).InterfaceMonitor()
if interfaceMonitor == nil { if interfaceMonitor == nil {
return nil, E.New("missing interface monitor") return nil, E.New("missing interface monitor")
} }

View File

@ -19,6 +19,7 @@ import (
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/ntp" "github.com/sagernet/sing/common/ntp"
"github.com/sagernet/sing/service"
mDNS "github.com/miekg/dns" mDNS "github.com/miekg/dns"
) )
@ -213,7 +214,7 @@ func fetchECHClientConfig(ctx context.Context) func(_ context.Context, serverNam
}, },
}, },
} }
response, err := adapter.RouterFromContext(ctx).Exchange(ctx, message) response, err := service.FromContext[adapter.Router](ctx).Exchange(ctx, message)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -11,7 +11,6 @@ import (
"time" "time"
"github.com/sagernet/reality" "github.com/sagernet/reality"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/dialer"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
@ -102,7 +101,7 @@ func NewRealityServer(ctx context.Context, logger log.Logger, options option.Inb
tlsConfig.ShortIds[shortID] = true tlsConfig.ShortIds[shortID] = true
} }
handshakeDialer, err := dialer.New(adapter.RouterFromContext(ctx), options.Reality.Handshake.DialerOptions) handshakeDialer, err := dialer.New(ctx, options.Reality.Handshake.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -12,7 +12,7 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
) )
type ClashServerConstructor = func(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) type ClashServerConstructor = func(ctx context.Context, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error)
var clashServerConstructor ClashServerConstructor var clashServerConstructor ClashServerConstructor
@ -20,11 +20,11 @@ func RegisterClashServerConstructor(constructor ClashServerConstructor) {
clashServerConstructor = constructor clashServerConstructor = constructor
} }
func NewClashServer(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { func NewClashServer(ctx context.Context, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) {
if clashServerConstructor == nil { if clashServerConstructor == nil {
return nil, os.ErrInvalid return nil, os.ErrInvalid
} }
return clashServerConstructor(ctx, router, logFactory, options) return clashServerConstructor(ctx, logFactory, options)
} }
func CalculateClashModeList(options option.Options) []string { func CalculateClashModeList(options option.Options) []string {

View File

@ -23,7 +23,7 @@ func groupRouter(server *Server) http.Handler {
r := chi.NewRouter() r := chi.NewRouter()
r.Get("/", getGroups(server)) r.Get("/", getGroups(server))
r.Route("/{name}", func(r chi.Router) { r.Route("/{name}", func(r chi.Router) {
r.Use(parseProxyName, findProxyByName(server.router)) r.Use(parseProxyName, findProxyByName(server))
r.Get("/", getGroup(server)) r.Get("/", getGroup(server))
r.Get("/delay", getGroupDelay(server)) r.Get("/delay", getGroupDelay(server))
}) })
@ -32,7 +32,7 @@ func groupRouter(server *Server) http.Handler {
func getGroups(server *Server) func(w http.ResponseWriter, r *http.Request) { func getGroups(server *Server) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
groups := common.Map(common.Filter(server.router.Outbounds(), func(it adapter.Outbound) bool { groups := common.Map(common.Filter(server.outboundManager.Outbounds(), func(it adapter.Outbound) bool {
_, isGroup := it.(adapter.OutboundGroup) _, isGroup := it.(adapter.OutboundGroup)
return isGroup return isGroup
}), func(it adapter.Outbound) *badjson.JSONObject { }), func(it adapter.Outbound) *badjson.JSONObject {
@ -86,7 +86,7 @@ func getGroupDelay(server *Server) func(w http.ResponseWriter, r *http.Request)
result, err = urlTestGroup.URLTest(ctx) result, err = urlTestGroup.URLTest(ctx)
} else { } else {
outbounds := common.FilterNotNil(common.Map(outboundGroup.All(), func(it string) adapter.Outbound { outbounds := common.FilterNotNil(common.Map(outboundGroup.All(), func(it string) adapter.Outbound {
itOutbound, _ := server.router.Outbound(it) itOutbound, _ := server.outboundManager.Outbound(it)
return itOutbound return itOutbound
})) }))
b, _ := batch.New(ctx, batch.WithConcurrencyNum[any](10)) b, _ := batch.New(ctx, batch.WithConcurrencyNum[any](10))
@ -100,7 +100,7 @@ func getGroupDelay(server *Server) func(w http.ResponseWriter, r *http.Request)
continue continue
} }
checked[realTag] = true checked[realTag] = true
p, loaded := server.router.Outbound(realTag) p, loaded := server.outboundManager.Outbound(realTag)
if !loaded { if !loaded {
continue continue
} }

View File

@ -23,10 +23,10 @@ import (
func proxyRouter(server *Server, router adapter.Router) http.Handler { func proxyRouter(server *Server, router adapter.Router) http.Handler {
r := chi.NewRouter() r := chi.NewRouter()
r.Get("/", getProxies(server, router)) r.Get("/", getProxies(server))
r.Route("/{name}", func(r chi.Router) { r.Route("/{name}", func(r chi.Router) {
r.Use(parseProxyName, findProxyByName(router)) r.Use(parseProxyName, findProxyByName(server))
r.Get("/", getProxy(server)) r.Get("/", getProxy(server))
r.Get("/delay", getProxyDelay(server)) r.Get("/delay", getProxyDelay(server))
r.Put("/", updateProxy) r.Put("/", updateProxy)
@ -42,11 +42,11 @@ func parseProxyName(next http.Handler) http.Handler {
}) })
} }
func findProxyByName(router adapter.Router) func(next http.Handler) http.Handler { func findProxyByName(server *Server) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
name := r.Context().Value(CtxKeyProxyName).(string) name := r.Context().Value(CtxKeyProxyName).(string)
proxy, exist := router.Outbound(name) proxy, exist := server.outboundManager.Outbound(name)
if !exist { if !exist {
render.Status(r, http.StatusNotFound) render.Status(r, http.StatusNotFound)
render.JSON(w, r, ErrNotFound) render.JSON(w, r, ErrNotFound)
@ -83,10 +83,10 @@ func proxyInfo(server *Server, detour adapter.Outbound) *badjson.JSONObject {
return &info return &info
} }
func getProxies(server *Server, router adapter.Router) func(w http.ResponseWriter, r *http.Request) { func getProxies(server *Server) func(w http.ResponseWriter, r *http.Request) {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var proxyMap badjson.JSONObject var proxyMap badjson.JSONObject
outbounds := common.Filter(router.Outbounds(), func(detour adapter.Outbound) bool { outbounds := common.Filter(server.outboundManager.Outbounds(), func(detour adapter.Outbound) bool {
return detour.Tag() != "" return detour.Tag() != ""
}) })
@ -100,12 +100,7 @@ func getProxies(server *Server, router adapter.Router) func(w http.ResponseWrite
allProxies = append(allProxies, detour.Tag()) allProxies = append(allProxies, detour.Tag())
} }
var defaultTag string defaultTag := server.outboundManager.Default().Tag()
if defaultOutbound, err := router.DefaultOutbound(N.NetworkTCP); err == nil {
defaultTag = defaultOutbound.Tag()
} else {
defaultTag = allProxies[0]
}
sort.SliceStable(allProxies, func(i, j int) bool { sort.SliceStable(allProxies, func(i, j int) bool {
return allProxies[i] == defaultTag return allProxies[i] == defaultTag

View File

@ -40,15 +40,16 @@ func init() {
var _ adapter.ClashServer = (*Server)(nil) var _ adapter.ClashServer = (*Server)(nil)
type Server struct { type Server struct {
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
logger log.Logger outboundManager adapter.OutboundManager
httpServer *http.Server logger log.Logger
trafficManager *trafficontrol.Manager httpServer *http.Server
urlTestHistory *urltest.HistoryStorage trafficManager *trafficontrol.Manager
mode string urlTestHistory *urltest.HistoryStorage
modeList []string mode string
modeUpdateHook chan<- struct{} modeList []string
modeUpdateHook chan<- struct{}
externalController bool externalController bool
externalUI string externalUI string
@ -56,13 +57,14 @@ type Server struct {
externalUIDownloadDetour string externalUIDownloadDetour string
} }
func NewServer(ctx context.Context, router adapter.Router, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) { func NewServer(ctx context.Context, logFactory log.ObservableFactory, options option.ClashAPIOptions) (adapter.ClashServer, error) {
trafficManager := trafficontrol.NewManager() trafficManager := trafficontrol.NewManager()
chiRouter := chi.NewRouter() chiRouter := chi.NewRouter()
server := &Server{ s := &Server{
ctx: ctx, ctx: ctx,
router: router, router: service.FromContext[adapter.Router](ctx),
logger: logFactory.NewLogger("clash-api"), outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logFactory.NewLogger("clash-api"),
httpServer: &http.Server{ httpServer: &http.Server{
Addr: options.ExternalController, Addr: options.ExternalController,
Handler: chiRouter, Handler: chiRouter,
@ -73,18 +75,18 @@ func NewServer(ctx context.Context, router adapter.Router, logFactory log.Observ
externalUIDownloadURL: options.ExternalUIDownloadURL, externalUIDownloadURL: options.ExternalUIDownloadURL,
externalUIDownloadDetour: options.ExternalUIDownloadDetour, externalUIDownloadDetour: options.ExternalUIDownloadDetour,
} }
server.urlTestHistory = service.PtrFromContext[urltest.HistoryStorage](ctx) s.urlTestHistory = service.PtrFromContext[urltest.HistoryStorage](ctx)
if server.urlTestHistory == nil { if s.urlTestHistory == nil {
server.urlTestHistory = urltest.NewHistoryStorage() s.urlTestHistory = urltest.NewHistoryStorage()
} }
defaultMode := "Rule" defaultMode := "Rule"
if options.DefaultMode != "" { if options.DefaultMode != "" {
defaultMode = options.DefaultMode defaultMode = options.DefaultMode
} }
if !common.Contains(server.modeList, defaultMode) { if !common.Contains(s.modeList, defaultMode) {
server.modeList = append([]string{defaultMode}, server.modeList...) s.modeList = append([]string{defaultMode}, s.modeList...)
} }
server.mode = defaultMode s.mode = defaultMode
//goland:noinspection GoDeprecation //goland:noinspection GoDeprecation
//nolint:staticcheck //nolint:staticcheck
if options.StoreMode || options.StoreSelected || options.StoreFakeIP || options.CacheFile != "" || options.CacheID != "" { if options.StoreMode || options.StoreSelected || options.StoreFakeIP || options.CacheFile != "" || options.CacheID != "" {
@ -108,30 +110,30 @@ func NewServer(ctx context.Context, router adapter.Router, logFactory log.Observ
r.Get("/logs", getLogs(logFactory)) r.Get("/logs", getLogs(logFactory))
r.Get("/traffic", traffic(trafficManager)) r.Get("/traffic", traffic(trafficManager))
r.Get("/version", version) r.Get("/version", version)
r.Mount("/configs", configRouter(server, logFactory)) r.Mount("/configs", configRouter(s, logFactory))
r.Mount("/proxies", proxyRouter(server, router)) r.Mount("/proxies", proxyRouter(s, s.router))
r.Mount("/rules", ruleRouter(router)) r.Mount("/rules", ruleRouter(s.router))
r.Mount("/connections", connectionRouter(router, trafficManager)) r.Mount("/connections", connectionRouter(s.router, trafficManager))
r.Mount("/providers/proxies", proxyProviderRouter()) r.Mount("/providers/proxies", proxyProviderRouter())
r.Mount("/providers/rules", ruleProviderRouter()) r.Mount("/providers/rules", ruleProviderRouter())
r.Mount("/script", scriptRouter()) r.Mount("/script", scriptRouter())
r.Mount("/profile", profileRouter()) r.Mount("/profile", profileRouter())
r.Mount("/cache", cacheRouter(ctx)) r.Mount("/cache", cacheRouter(ctx))
r.Mount("/dns", dnsRouter(router)) r.Mount("/dns", dnsRouter(s.router))
server.setupMetaAPI(r) s.setupMetaAPI(r)
}) })
if options.ExternalUI != "" { if options.ExternalUI != "" {
server.externalUI = filemanager.BasePath(ctx, os.ExpandEnv(options.ExternalUI)) s.externalUI = filemanager.BasePath(ctx, os.ExpandEnv(options.ExternalUI))
chiRouter.Group(func(r chi.Router) { chiRouter.Group(func(r chi.Router) {
fs := http.StripPrefix("/ui", http.FileServer(http.Dir(server.externalUI))) fs := http.StripPrefix("/ui", http.FileServer(http.Dir(s.externalUI)))
r.Get("/ui", http.RedirectHandler("/ui/", http.StatusTemporaryRedirect).ServeHTTP) r.Get("/ui", http.RedirectHandler("/ui/", http.StatusTemporaryRedirect).ServeHTTP)
r.Get("/ui/*", func(w http.ResponseWriter, r *http.Request) { r.Get("/ui/*", func(w http.ResponseWriter, r *http.Request) {
fs.ServeHTTP(w, r) fs.ServeHTTP(w, r)
}) })
}) })
} }
return server, nil return s, nil
} }
func (s *Server) PreStart() error { func (s *Server) PreStart() error {
@ -235,12 +237,12 @@ func (s *Server) TrafficManager() *trafficontrol.Manager {
} }
func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) { func (s *Server) RoutedConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, matchedRule adapter.Rule) (net.Conn, adapter.Tracker) {
tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.router, matchedRule) tracker := trafficontrol.NewTCPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule)
return tracker, tracker return tracker, tracker
} }
func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) { func (s *Server) RoutedPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, matchedRule adapter.Rule) (N.PacketConn, adapter.Tracker) {
tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.router, matchedRule) tracker := trafficontrol.NewUDPTracker(conn, s.trafficManager, metadata, s.outboundManager, matchedRule)
return tracker, tracker return tracker, tracker
} }

View File

@ -15,7 +15,6 @@ import (
"github.com/sagernet/sing/common" "github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata" M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service/filemanager" "github.com/sagernet/sing/service/filemanager"
) )
@ -45,16 +44,13 @@ func (s *Server) downloadExternalUI() error {
s.logger.Info("downloading external ui") s.logger.Info("downloading external ui")
var detour adapter.Outbound var detour adapter.Outbound
if s.externalUIDownloadDetour != "" { if s.externalUIDownloadDetour != "" {
outbound, loaded := s.router.Outbound(s.externalUIDownloadDetour) outbound, loaded := s.outboundManager.Outbound(s.externalUIDownloadDetour)
if !loaded { if !loaded {
return E.New("detour outbound not found: ", s.externalUIDownloadDetour) return E.New("detour outbound not found: ", s.externalUIDownloadDetour)
} }
detour = outbound detour = outbound
} else { } else {
outbound, err := s.router.DefaultOutbound(N.NetworkTCP) outbound := s.outboundManager.Default()
if err != nil {
return err
}
detour = outbound detour = outbound
} }
httpClient := &http.Client{ httpClient := &http.Client{

View File

@ -124,7 +124,7 @@ func (tt *TCPConn) WriterReplaceable() bool {
return true return true
} }
func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, router adapter.Router, rule adapter.Rule) *TCPConn { func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *TCPConn {
id, _ := uuid.NewV4() id, _ := uuid.NewV4()
var ( var (
chain []string chain []string
@ -138,11 +138,11 @@ func NewTCPTracker(conn net.Conn, manager *Manager, metadata adapter.InboundCont
} }
if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction { if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction {
next = routeAction.Outbound next = routeAction.Outbound
} else if defaultOutbound, err := router.DefaultOutbound(N.NetworkTCP); err == nil { } else {
next = defaultOutbound.Tag() next = outboundManager.Default().Tag()
} }
for { for {
detour, loaded := router.Outbound(next) detour, loaded := outboundManager.Outbound(next)
if !loaded { if !loaded {
break break
} }
@ -213,7 +213,7 @@ func (ut *UDPConn) WriterReplaceable() bool {
return true return true
} }
func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, router adapter.Router, rule adapter.Rule) *UDPConn { func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.InboundContext, outboundManager adapter.OutboundManager, rule adapter.Rule) *UDPConn {
id, _ := uuid.NewV4() id, _ := uuid.NewV4()
var ( var (
chain []string chain []string
@ -227,11 +227,11 @@ func NewUDPTracker(conn N.PacketConn, manager *Manager, metadata adapter.Inbound
} }
if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction { if routeAction, isRouteAction := action.(*R.RuleActionRoute); isRouteAction {
next = routeAction.Outbound next = routeAction.Outbound
} else if defaultOutbound, err := router.DefaultOutbound(N.NetworkUDP); err == nil { } else {
next = defaultOutbound.Tag() next = outboundManager.Default().Tag()
} }
for { for {
detour, loaded := router.Outbound(next) detour, loaded := outboundManager.Outbound(next)
if !loaded { if !loaded {
break break
} }

View File

@ -12,7 +12,7 @@ import (
C "github.com/sagernet/sing-box/constant" C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option" "github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns" dns "github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common/bufio" "github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
@ -39,7 +39,7 @@ type Outbound struct {
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.DirectOutboundOptions) (adapter.Outbound, error) {
options.UDPFragmentDefault = true options.UDPFragmentDefault = true
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -26,7 +26,7 @@ var _ adapter.OutboundGroup = (*Selector)(nil)
type Selector struct { type Selector struct {
outbound.Adapter outbound.Adapter
ctx context.Context ctx context.Context
router adapter.Router outboundManager adapter.OutboundManager
logger logger.ContextLogger logger logger.ContextLogger
tags []string tags []string
defaultTag string defaultTag string
@ -40,7 +40,7 @@ func NewSelector(ctx context.Context, router adapter.Router, logger log.ContextL
outbound := &Selector{ outbound := &Selector{
Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeSelector, nil, tag, options.Outbounds),
ctx: ctx, ctx: ctx,
router: router, outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logger, logger: logger,
tags: options.Outbounds, tags: options.Outbounds,
defaultTag: options.Default, defaultTag: options.Default,
@ -63,7 +63,7 @@ func (s *Selector) Network() []string {
func (s *Selector) Start() error { func (s *Selector) Start() error {
for i, tag := range s.tags { for i, tag := range s.tags {
detour, loaded := s.router.Outbound(tag) detour, loaded := s.outboundManager.Outbound(tag)
if !loaded { if !loaded {
return E.New("outbound ", i, " not found: ", tag) return E.New("outbound ", i, " not found: ", tag)
} }

View File

@ -36,6 +36,7 @@ type URLTest struct {
outbound.Adapter outbound.Adapter
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
outboundManager adapter.OutboundManager
logger log.ContextLogger logger log.ContextLogger
tags []string tags []string
link string link string
@ -51,6 +52,7 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo
Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds), Adapter: outbound.NewAdapter(C.TypeURLTest, []string{N.NetworkTCP, N.NetworkUDP}, tag, options.Outbounds),
ctx: ctx, ctx: ctx,
router: router, router: router,
outboundManager: service.FromContext[adapter.OutboundManager](ctx),
logger: logger, logger: logger,
tags: options.Outbounds, tags: options.Outbounds,
link: options.URL, link: options.URL,
@ -68,7 +70,7 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo
func (s *URLTest) Start() error { func (s *URLTest) Start() error {
outbounds := make([]adapter.Outbound, 0, len(s.tags)) outbounds := make([]adapter.Outbound, 0, len(s.tags))
for i, tag := range s.tags { for i, tag := range s.tags {
detour, loaded := s.router.Outbound(tag) detour, loaded := s.outboundManager.Outbound(tag)
if !loaded { if !loaded {
return E.New("outbound ", i, " not found: ", tag) return E.New("outbound ", i, " not found: ", tag)
} }
@ -77,6 +79,7 @@ func (s *URLTest) Start() error {
group, err := NewURLTestGroup( group, err := NewURLTestGroup(
s.ctx, s.ctx,
s.router, s.router,
s.outboundManager,
s.logger, s.logger,
outbounds, outbounds,
s.link, s.link,
@ -190,6 +193,7 @@ func (s *URLTest) InterfaceUpdated() {
type URLTestGroup struct { type URLTestGroup struct {
ctx context.Context ctx context.Context
router adapter.Router router adapter.Router
outboundManager adapter.OutboundManager
logger log.Logger logger log.Logger
outbounds []adapter.Outbound outbounds []adapter.Outbound
link string link string
@ -214,6 +218,7 @@ type URLTestGroup struct {
func NewURLTestGroup( func NewURLTestGroup(
ctx context.Context, ctx context.Context,
router adapter.Router, router adapter.Router,
outboundManager adapter.OutboundManager,
logger log.Logger, logger log.Logger,
outbounds []adapter.Outbound, outbounds []adapter.Outbound,
link string, link string,
@ -244,6 +249,7 @@ func NewURLTestGroup(
return &URLTestGroup{ return &URLTestGroup{
ctx: ctx, ctx: ctx,
router: router, router: router,
outboundManager: outboundManager,
logger: logger, logger: logger,
outbounds: outbounds, outbounds: outbounds,
link: link, link: link,
@ -385,7 +391,7 @@ func (g *URLTestGroup) urlTest(ctx context.Context, force bool) (map[string]uint
continue continue
} }
checked[realTag] = true checked[realTag] = true
p, loaded := g.router.Outbound(realTag) p, loaded := g.outboundManager.Outbound(realTag)
if !loaded { if !loaded {
continue continue
} }

View File

@ -30,7 +30,7 @@ type Outbound struct {
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.HTTPOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -47,7 +47,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
if err != nil { if err != nil {
return nil, err return nil, err
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -59,7 +59,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, E.New("unknown obfs type: ", options.Obfs.Type) return nil, E.New("unknown obfs type: ", options.Obfs.Type)
} }
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -44,7 +44,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
if err != nil { if err != nil {
return nil, err return nil, err
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -46,7 +46,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
if options.Version > 1 { if options.Version > 1 {
handshakeForServerName = make(map[string]shadowtls.HandshakeConfig) handshakeForServerName = make(map[string]shadowtls.HandshakeConfig)
for serverName, serverOptions := range options.HandshakeForServerName { for serverName, serverOptions := range options.HandshakeForServerName {
handshakeDialer, err := dialer.New(router, serverOptions.DialerOptions) handshakeDialer, err := dialer.New(ctx, serverOptions.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -56,7 +56,7 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
} }
} }
} }
handshakeDialer, err := dialer.New(router, options.Handshake.DialerOptions) handshakeDialer, err := dialer.New(ctx, options.Handshake.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -68,7 +68,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
tlsHandshakeFunc = shadowtls.DefaultTLSHandshakeFunc(options.Password, stdTLSConfig) tlsHandshakeFunc = shadowtls.DefaultTLSHandshakeFunc(options.Password, stdTLSConfig)
} }
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -46,7 +46,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
if err != nil { if err != nil {
return nil, err return nil, err
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -49,7 +49,7 @@ type Outbound struct {
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.SSHOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -75,7 +75,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
} }
startConf.TorrcFile = torrcFile startConf.TorrcFile = torrcFile
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -38,7 +38,7 @@ type Outbound struct {
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrojanOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.TrojanOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -60,7 +60,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
case "quic": case "quic":
tuicUDPStream = true tuicUDPStream = true
} }
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -41,7 +41,7 @@ type Outbound struct {
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VLESSOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VLESSOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -41,7 +41,7 @@ type Outbound struct {
} }
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VMessOutboundOptions) (adapter.Outbound, error) { func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.VMessOutboundOptions) (adapter.Outbound, error) {
outboundDialer, err := dialer.New(router, options.DialerOptions) outboundDialer, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -78,7 +78,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
options.IsWireGuardListener = true options.IsWireGuardListener = true
outbound.useStdNetBind = true outbound.useStdNetBind = true
} }
listener, err := dialer.New(router, options.DialerOptions) listener, err := dialer.New(ctx, options.DialerOptions)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -87,7 +87,7 @@ type Router struct {
v2rayServer adapter.V2RayServer v2rayServer adapter.V2RayServer
platformInterface platform.Interface platformInterface platform.Interface
needWIFIState bool needWIFIState bool
needPackageManager bool enforcePackageManager bool
wifiState adapter.WIFIState wifiState adapter.WIFIState
started bool started bool
} }
@ -123,7 +123,7 @@ func NewRouter(
pauseManager: service.FromContext[pause.Manager](ctx), pauseManager: service.FromContext[pause.Manager](ctx),
platformInterface: service.FromContext[platform.Interface](ctx), platformInterface: service.FromContext[platform.Interface](ctx),
needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule), needWIFIState: hasRule(options.Rules, isWIFIRule) || hasDNSRule(dnsOptions.Rules, isWIFIDNSRule),
needPackageManager: common.Any(inbounds, func(inbound option.Inbound) bool { enforcePackageManager: common.Any(inbounds, func(inbound option.Inbound) bool {
if tunOptions, isTUN := inbound.Options.(*option.TunInboundOptions); isTUN && tunOptions.AutoRoute { if tunOptions, isTUN := inbound.Options.(*option.TunInboundOptions); isTUN && tunOptions.AutoRoute {
return true return true
} }
@ -191,7 +191,8 @@ func NewRouter(
transportTags[i] = tag transportTags[i] = tag
transportTagMap[tag] = true transportTagMap[tag] = true
} }
ctx = adapter.ContextWithRouter(ctx, router) ctx = service.ContextWith[adapter.Router](ctx, router)
outboundManager := service.FromContext[adapter.OutboundManager](ctx)
for { for {
lastLen := len(dummyTransportMap) lastLen := len(dummyTransportMap)
for i, server := range dnsOptions.Servers { for i, server := range dnsOptions.Servers {
@ -201,9 +202,9 @@ func NewRouter(
} }
var detour N.Dialer var detour N.Dialer
if server.Detour == "" { if server.Detour == "" {
detour = dialer.NewRouter(router) detour = dialer.NewDefaultOutbound(outboundManager)
} else { } else {
detour = dialer.NewDetour(router, server.Detour) detour = dialer.NewDetour(outboundManager, server.Detour)
} }
var serverProtocol string var serverProtocol string
switch server.Address { switch server.Address {
@ -327,7 +328,7 @@ func NewRouter(
} }
usePlatformDefaultInterfaceMonitor := router.platformInterface != nil && router.platformInterface.UsePlatformDefaultInterfaceMonitor() usePlatformDefaultInterfaceMonitor := router.platformInterface != nil && router.platformInterface.UsePlatformDefaultInterfaceMonitor()
needInterfaceMonitor := options.AutoDetectInterface || common.Any(inbounds, func(inbound option.Inbound) bool { enforceInterfaceMonitor := options.AutoDetectInterface || common.Any(inbounds, func(inbound option.Inbound) bool {
if httpMixedOptions, isHTTPMixed := inbound.Options.(*option.HTTPMixedInboundOptions); isHTTPMixed && httpMixedOptions.SetSystemProxy { if httpMixedOptions, isHTTPMixed := inbound.Options.(*option.HTTPMixedInboundOptions); isHTTPMixed && httpMixedOptions.SetSystemProxy {
return true return true
} }
@ -339,7 +340,7 @@ func NewRouter(
if !usePlatformDefaultInterfaceMonitor { if !usePlatformDefaultInterfaceMonitor {
networkMonitor, err := tun.NewNetworkUpdateMonitor(router.logger) networkMonitor, err := tun.NewNetworkUpdateMonitor(router.logger)
if !((err != nil && !needInterfaceMonitor) || errors.Is(err, os.ErrInvalid)) { if !((err != nil && !enforceInterfaceMonitor) || errors.Is(err, os.ErrInvalid)) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -365,7 +366,7 @@ func NewRouter(
} }
if ntpOptions.Enabled { if ntpOptions.Enabled {
ntpDialer, err := dialer.New(router, ntpOptions.DialerOptions) ntpDialer, err := dialer.New(ctx, ntpOptions.DialerOptions)
if err != nil { if err != nil {
return nil, E.Cause(err, "create NTP service") return nil, E.Cause(err, "create NTP service")
} }
@ -383,73 +384,6 @@ func NewRouter(
return router, nil return router, nil
} }
func (r *Router) Initialize(inbounds []adapter.Inbound, outbounds []adapter.Outbound, defaultOutbound func() adapter.Outbound) error {
inboundByTag := make(map[string]adapter.Inbound)
for _, inbound := range inbounds {
inboundByTag[inbound.Tag()] = inbound
}
outboundByTag := make(map[string]adapter.Outbound)
for _, detour := range outbounds {
outboundByTag[detour.Tag()] = detour
}
var defaultOutboundForConnection adapter.Outbound
var defaultOutboundForPacketConnection adapter.Outbound
if r.defaultDetour != "" {
detour, loaded := outboundByTag[r.defaultDetour]
if !loaded {
return E.New("default detour not found: ", r.defaultDetour)
}
if common.Contains(detour.Network(), N.NetworkTCP) {
defaultOutboundForConnection = detour
}
if common.Contains(detour.Network(), N.NetworkUDP) {
defaultOutboundForPacketConnection = detour
}
}
if defaultOutboundForConnection == nil {
for _, detour := range outbounds {
if common.Contains(detour.Network(), N.NetworkTCP) {
defaultOutboundForConnection = detour
break
}
}
}
if defaultOutboundForPacketConnection == nil {
for _, detour := range outbounds {
if common.Contains(detour.Network(), N.NetworkUDP) {
defaultOutboundForPacketConnection = detour
break
}
}
}
if defaultOutboundForConnection == nil || defaultOutboundForPacketConnection == nil {
detour := defaultOutbound()
if defaultOutboundForConnection == nil {
defaultOutboundForConnection = detour
}
if defaultOutboundForPacketConnection == nil {
defaultOutboundForPacketConnection = detour
}
outbounds = append(outbounds, detour)
outboundByTag[detour.Tag()] = detour
}
r.inboundByTag = inboundByTag
r.outbounds = outbounds
r.defaultOutboundForConnection = defaultOutboundForConnection
r.defaultOutboundForPacketConnection = defaultOutboundForPacketConnection
r.outboundByTag = outboundByTag
for i, rule := range r.rules {
routeAction, isRoute := rule.Action().(*R.RuleActionRoute)
if !isRoute {
continue
}
if _, loaded := outboundByTag[routeAction.Outbound]; !loaded {
return E.New("outbound not found for rule[", i, "]: ", routeAction.Outbound)
}
}
return nil
}
func (r *Router) Outbounds() []adapter.Outbound { func (r *Router) Outbounds() []adapter.Outbound {
if !r.started { if !r.started {
return nil return nil
@ -457,140 +391,240 @@ func (r *Router) Outbounds() []adapter.Outbound {
return r.outbounds return r.outbounds
} }
func (r *Router) PreStart() error { func (r *Router) Start(stage adapter.StartStage) error {
monitor := taskmonitor.New(r.logger, C.StartTimeout) monitor := taskmonitor.New(r.logger, C.StartTimeout)
if r.interfaceMonitor != nil { switch stage {
monitor.Start("initialize interface monitor") case adapter.StartStateInitialize:
err := r.interfaceMonitor.Start() if r.interfaceMonitor != nil {
monitor.Finish() monitor.Start("initialize interface monitor")
if err != nil { err := r.interfaceMonitor.Start()
return err
}
}
if r.networkMonitor != nil {
monitor.Start("initialize network monitor")
err := r.networkMonitor.Start()
monitor.Finish()
if err != nil {
return err
}
}
if r.fakeIPStore != nil {
monitor.Start("initialize fakeip store")
err := r.fakeIPStore.Start()
monitor.Finish()
if err != nil {
return err
}
}
return nil
}
func (r *Router) Start() error {
monitor := taskmonitor.New(r.logger, C.StartTimeout)
if r.needGeoIPDatabase {
monitor.Start("initialize geoip database")
err := r.prepareGeoIPDatabase()
monitor.Finish()
if err != nil {
return err
}
}
if r.needGeositeDatabase {
monitor.Start("initialize geosite database")
err := r.prepareGeositeDatabase()
monitor.Finish()
if err != nil {
return err
}
}
if r.needGeositeDatabase {
for _, rule := range r.rules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
for _, rule := range r.dnsRules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
err := common.Close(r.geositeReader)
if err != nil {
return err
}
r.geositeCache = nil
r.geositeReader = nil
}
if runtime.GOOS == "windows" {
powerListener, err := winpowrprof.NewEventListener(r.notifyWindowsPowerEvent)
if err == nil {
r.powerListener = powerListener
} else {
r.logger.Warn("initialize power listener: ", err)
}
}
if r.powerListener != nil {
monitor.Start("start power listener")
err := r.powerListener.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start power listener")
}
}
monitor.Start("initialize DNS client")
r.dnsClient.Start()
monitor.Finish()
if C.IsAndroid && r.platformInterface == nil {
monitor.Start("initialize package manager")
packageManager, err := tun.NewPackageManager(tun.PackageManagerOptions{
Callback: r,
Logger: r.logger,
})
monitor.Finish()
if err != nil {
return E.Cause(err, "create package manager")
}
if r.needPackageManager {
monitor.Start("start package manager")
err = packageManager.Start()
monitor.Finish() monitor.Finish()
if err != nil { if err != nil {
return E.Cause(err, "start package manager") return err
} }
} }
r.packageManager = packageManager if r.networkMonitor != nil {
} monitor.Start("initialize network monitor")
err := r.networkMonitor.Start()
monitor.Finish()
if err != nil {
return err
}
}
if r.fakeIPStore != nil {
monitor.Start("initialize fakeip store")
err := r.fakeIPStore.Start()
monitor.Finish()
if err != nil {
return err
}
}
case adapter.StartStateStart:
if r.needGeoIPDatabase {
monitor.Start("initialize geoip database")
err := r.prepareGeoIPDatabase()
monitor.Finish()
if err != nil {
return err
}
}
if r.needGeositeDatabase {
monitor.Start("initialize geosite database")
err := r.prepareGeositeDatabase()
monitor.Finish()
if err != nil {
return err
}
}
if r.needGeositeDatabase {
for _, rule := range r.rules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
for _, rule := range r.dnsRules {
err := rule.UpdateGeosite()
if err != nil {
r.logger.Error("failed to initialize geosite: ", err)
}
}
err := common.Close(r.geositeReader)
if err != nil {
return err
}
r.geositeCache = nil
r.geositeReader = nil
}
for i, rule := range r.dnsRules { if runtime.GOOS == "windows" {
monitor.Start("initialize DNS rule[", i, "]") powerListener, err := winpowrprof.NewEventListener(r.notifyWindowsPowerEvent)
err := rule.Start() if err == nil {
monitor.Finish() r.powerListener = powerListener
if err != nil { } else {
return E.Cause(err, "initialize DNS rule[", i, "]") r.logger.Warn("initialize power listener: ", err)
}
} }
}
for i, transport := range r.transports { if r.powerListener != nil {
monitor.Start("initialize DNS transport[", i, "]") monitor.Start("start power listener")
err := transport.Start() err := r.powerListener.Start()
monitor.Finish() monitor.Finish()
if err != nil { if err != nil {
return E.Cause(err, "initialize DNS server[", i, "]") return E.Cause(err, "start power listener")
}
} }
}
if r.timeService != nil { monitor.Start("initialize DNS client")
monitor.Start("initialize time service") r.dnsClient.Start()
err := r.timeService.Start()
monitor.Finish() monitor.Finish()
if err != nil {
return E.Cause(err, "initialize time service") if C.IsAndroid && r.platformInterface == nil {
monitor.Start("initialize package manager")
packageManager, err := tun.NewPackageManager(tun.PackageManagerOptions{
Callback: r,
Logger: r.logger,
})
monitor.Finish()
if err != nil {
return E.Cause(err, "create package manager")
}
if r.enforcePackageManager {
monitor.Start("start package manager")
err = packageManager.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start package manager")
}
}
r.packageManager = packageManager
} }
for i, rule := range r.dnsRules {
monitor.Start("initialize DNS rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS rule[", i, "]")
}
}
for i, transport := range r.transports {
monitor.Start("initialize DNS transport[", i, "]")
err := transport.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize DNS server[", i, "]")
}
}
if r.timeService != nil {
monitor.Start("initialize time service")
err := r.timeService.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize time service")
}
}
case adapter.StartStatePostStart:
var cacheContext *adapter.HTTPStartContext
if len(r.ruleSets) > 0 {
monitor.Start("initialize rule-set")
cacheContext = adapter.NewHTTPStartContext()
var ruleSetStartGroup task.Group
for i, ruleSet := range r.ruleSets {
ruleSetInPlace := ruleSet
ruleSetStartGroup.Append0(func(ctx context.Context) error {
err := ruleSetInPlace.StartContext(ctx, cacheContext)
if err != nil {
return E.Cause(err, "initialize rule-set[", i, "]")
}
return nil
})
}
ruleSetStartGroup.Concurrency(5)
ruleSetStartGroup.FastFail()
err := ruleSetStartGroup.Run(r.ctx)
monitor.Finish()
if err != nil {
return err
}
}
if cacheContext != nil {
cacheContext.Close()
}
needFindProcess := r.needFindProcess
needWIFIState := r.needWIFIState
for _, ruleSet := range r.ruleSets {
metadata := ruleSet.Metadata()
if metadata.ContainsProcessRule {
needFindProcess = true
}
if metadata.ContainsWIFIRule {
needWIFIState = true
}
}
if C.IsAndroid && r.platformInterface == nil && !r.enforcePackageManager {
if needFindProcess {
monitor.Start("start package manager")
err := r.packageManager.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start package manager")
}
} else {
r.packageManager = nil
}
}
if needFindProcess {
if r.platformInterface != nil {
r.processSearcher = r.platformInterface
} else {
monitor.Start("initialize process searcher")
searcher, err := process.NewSearcher(process.Config{
Logger: r.logger,
PackageManager: r.packageManager,
})
monitor.Finish()
if err != nil {
if err != os.ErrInvalid {
r.logger.Warn(E.Cause(err, "create process searcher"))
}
} else {
r.processSearcher = searcher
}
}
}
if needWIFIState && r.platformInterface != nil {
monitor.Start("initialize WIFI state")
r.needWIFIState = true
r.interfaceMonitor.RegisterCallback(func(_ int) {
r.updateWIFIState()
})
r.updateWIFIState()
monitor.Finish()
}
for i, rule := range r.rules {
monitor.Start("initialize rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize rule[", i, "]")
}
}
for _, ruleSet := range r.ruleSets {
monitor.Start("post start rule_set[", ruleSet.Name(), "]")
err := ruleSet.PostStart()
monitor.Finish()
if err != nil {
return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]")
}
}
r.started = true
return nil
case adapter.StartStateStarted:
for _, ruleSet := range r.ruleSetMap {
ruleSet.Cleanup()
}
runtime.GC()
} }
return nil return nil
} }
@ -671,113 +705,6 @@ func (r *Router) Close() error {
return err return err
} }
func (r *Router) PostStart() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout)
var cacheContext *adapter.HTTPStartContext
if len(r.ruleSets) > 0 {
monitor.Start("initialize rule-set")
cacheContext = adapter.NewHTTPStartContext()
var ruleSetStartGroup task.Group
for i, ruleSet := range r.ruleSets {
ruleSetInPlace := ruleSet
ruleSetStartGroup.Append0(func(ctx context.Context) error {
err := ruleSetInPlace.StartContext(ctx, cacheContext)
if err != nil {
return E.Cause(err, "initialize rule-set[", i, "]")
}
return nil
})
}
ruleSetStartGroup.Concurrency(5)
ruleSetStartGroup.FastFail()
err := ruleSetStartGroup.Run(r.ctx)
monitor.Finish()
if err != nil {
return err
}
}
if cacheContext != nil {
cacheContext.Close()
}
needFindProcess := r.needFindProcess
needWIFIState := r.needWIFIState
for _, ruleSet := range r.ruleSets {
metadata := ruleSet.Metadata()
if metadata.ContainsProcessRule {
needFindProcess = true
}
if metadata.ContainsWIFIRule {
needWIFIState = true
}
}
if C.IsAndroid && r.platformInterface == nil && !r.needPackageManager {
if needFindProcess {
monitor.Start("start package manager")
err := r.packageManager.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "start package manager")
}
} else {
r.packageManager = nil
}
}
if needFindProcess {
if r.platformInterface != nil {
r.processSearcher = r.platformInterface
} else {
monitor.Start("initialize process searcher")
searcher, err := process.NewSearcher(process.Config{
Logger: r.logger,
PackageManager: r.packageManager,
})
monitor.Finish()
if err != nil {
if err != os.ErrInvalid {
r.logger.Warn(E.Cause(err, "create process searcher"))
}
} else {
r.processSearcher = searcher
}
}
}
if needWIFIState && r.platformInterface != nil {
monitor.Start("initialize WIFI state")
r.needWIFIState = true
r.interfaceMonitor.RegisterCallback(func(_ int) {
r.updateWIFIState()
})
r.updateWIFIState()
monitor.Finish()
}
for i, rule := range r.rules {
monitor.Start("initialize rule[", i, "]")
err := rule.Start()
monitor.Finish()
if err != nil {
return E.Cause(err, "initialize rule[", i, "]")
}
}
for _, ruleSet := range r.ruleSets {
monitor.Start("post start rule_set[", ruleSet.Name(), "]")
err := ruleSet.PostStart()
monitor.Finish()
if err != nil {
return E.Cause(err, "post start rule_set[", ruleSet.Name(), "]")
}
}
r.started = true
return nil
}
func (r *Router) Cleanup() error {
for _, ruleSet := range r.ruleSetMap {
ruleSet.Cleanup()
}
runtime.GC()
return nil
}
func (r *Router) Outbound(tag string) (adapter.Outbound, bool) { func (r *Router) Outbound(tag string) (adapter.Outbound, bool) {
outbound, loaded := r.outboundByTag[tag] outbound, loaded := r.outboundByTag[tag]
return outbound, loaded return outbound, loaded

View File

@ -22,7 +22,7 @@ import (
N "github.com/sagernet/sing/common/network" N "github.com/sagernet/sing/common/network"
) )
func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) { func NewRuleAction(ctx context.Context, logger logger.ContextLogger, action option.RuleAction) (adapter.RuleAction, error) {
switch action.Action { switch action.Action {
case "": case "":
return nil, nil return nil, nil
@ -36,7 +36,7 @@ func NewRuleAction(router adapter.Router, logger logger.ContextLogger, action op
UDPConnect: action.RouteOptionsOptions.UDPConnect, UDPConnect: action.RouteOptionsOptions.UDPConnect,
}, nil }, nil
case C.RuleActionTypeDirect: case C.RuleActionTypeDirect:
directDialer, err := dialer.New(router, option.DialerOptions(action.DirectOptions)) directDialer, err := dialer.New(ctx, option.DialerOptions(action.DirectOptions))
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -52,7 +52,7 @@ type RuleItem interface {
} }
func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) { func NewDefaultRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.DefaultRule) (*DefaultRule, error) {
action, err := NewRuleAction(router, logger, options.RuleAction) action, err := NewRuleAction(ctx, logger, options.RuleAction)
if err != nil { if err != nil {
return nil, E.Cause(err, "action") return nil, E.Cause(err, "action")
} }
@ -254,7 +254,7 @@ type LogicalRule struct {
} }
func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) { func NewLogicalRule(ctx context.Context, router adapter.Router, logger log.ContextLogger, options option.LogicalRule) (*LogicalRule, error) {
action, err := NewRuleAction(router, logger, options.RuleAction) action, err := NewRuleAction(ctx, logger, options.RuleAction)
if err != nil { if err != nil {
return nil, E.Cause(err, "action") return nil, E.Cause(err, "action")
} }

View File

@ -33,23 +33,24 @@ import (
var _ adapter.RuleSet = (*RemoteRuleSet)(nil) var _ adapter.RuleSet = (*RemoteRuleSet)(nil)
type RemoteRuleSet struct { type RemoteRuleSet struct {
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
router adapter.Router router adapter.Router
logger logger.ContextLogger outboundManager adapter.OutboundManager
options option.RuleSet logger logger.ContextLogger
metadata adapter.RuleSetMetadata options option.RuleSet
updateInterval time.Duration metadata adapter.RuleSetMetadata
dialer N.Dialer updateInterval time.Duration
rules []adapter.HeadlessRule dialer N.Dialer
lastUpdated time.Time rules []adapter.HeadlessRule
lastEtag string lastUpdated time.Time
updateTicker *time.Ticker lastEtag string
cacheFile adapter.CacheFile updateTicker *time.Ticker
pauseManager pause.Manager cacheFile adapter.CacheFile
callbackAccess sync.Mutex pauseManager pause.Manager
callbacks list.List[adapter.RuleSetUpdateCallback] callbackAccess sync.Mutex
refs atomic.Int32 callbacks list.List[adapter.RuleSetUpdateCallback]
refs atomic.Int32
} }
func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet { func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.ContextLogger, options option.RuleSet) *RemoteRuleSet {
@ -61,13 +62,14 @@ func NewRemoteRuleSet(ctx context.Context, router adapter.Router, logger logger.
updateInterval = 24 * time.Hour updateInterval = 24 * time.Hour
} }
return &RemoteRuleSet{ return &RemoteRuleSet{
ctx: ctx, ctx: ctx,
cancel: cancel, cancel: cancel,
router: router, router: router,
logger: logger, outboundManager: service.FromContext[adapter.OutboundManager](ctx),
options: options, logger: logger,
updateInterval: updateInterval, options: options,
pauseManager: service.FromContext[pause.Manager](ctx), updateInterval: updateInterval,
pauseManager: service.FromContext[pause.Manager](ctx),
} }
} }
@ -83,17 +85,13 @@ func (s *RemoteRuleSet) StartContext(ctx context.Context, startContext *adapter.
s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx) s.cacheFile = service.FromContext[adapter.CacheFile](s.ctx)
var dialer N.Dialer var dialer N.Dialer
if s.options.RemoteOptions.DownloadDetour != "" { if s.options.RemoteOptions.DownloadDetour != "" {
outbound, loaded := s.router.Outbound(s.options.RemoteOptions.DownloadDetour) outbound, loaded := s.outboundManager.Outbound(s.options.RemoteOptions.DownloadDetour)
if !loaded { if !loaded {
return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour) return E.New("download_detour not found: ", s.options.RemoteOptions.DownloadDetour)
} }
dialer = outbound dialer = outbound
} else { } else {
outbound, err := s.router.DefaultOutbound(N.NetworkTCP) dialer = s.outboundManager.Default()
if err != nil {
return err
}
dialer = outbound
} }
s.dialer = dialer s.dialer = dialer
if s.cacheFile != nil { if s.cacheFile != nil {

View File

@ -23,6 +23,7 @@ import (
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/task" "github.com/sagernet/sing/common/task"
"github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/common/x/list"
"github.com/sagernet/sing/service"
"github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv4"
mDNS "github.com/miekg/dns" mDNS "github.com/miekg/dns"
@ -53,7 +54,7 @@ func NewTransport(options dns.TransportOptions) (*Transport, error) {
if linkURL.Host == "" { if linkURL.Host == "" {
return nil, E.New("missing interface name for DHCP") return nil, E.New("missing interface name for DHCP")
} }
router := adapter.RouterFromContext(options.Context) router := service.FromContext[adapter.Router](options.Context)
if router == nil { if router == nil {
return nil, E.New("missing router in context") return nil, E.New("missing router in context")
} }

View File

@ -9,6 +9,7 @@ import (
"github.com/sagernet/sing-dns" "github.com/sagernet/sing-dns"
E "github.com/sagernet/sing/common/exceptions" E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger" "github.com/sagernet/sing/common/logger"
"github.com/sagernet/sing/service"
mDNS "github.com/miekg/dns" mDNS "github.com/miekg/dns"
) )
@ -32,7 +33,7 @@ type Transport struct {
} }
func NewTransport(options dns.TransportOptions) (*Transport, error) { func NewTransport(options dns.TransportOptions) (*Transport, error) {
router := adapter.RouterFromContext(options.Context) router := service.FromContext[adapter.Router](options.Context)
if router == nil { if router == nil {
return nil, E.New("missing router in context") return nil, E.New("missing router in context")
} }