From c320be75a70a3fcacdb664ee4fb4914b0d056c66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Fri, 15 Sep 2023 00:07:07 +0800 Subject: [PATCH] Add interrupt support for outbound groups --- common/interrupt/conn.go | 75 ++++++++++++++ common/interrupt/context.go | 13 +++ common/interrupt/group.go | 52 ++++++++++ docs/configuration/outbound/selector.md | 11 ++- docs/configuration/outbound/selector.zh.md | 9 +- docs/configuration/outbound/urltest.md | 9 +- docs/configuration/outbound/urltest.zh.md | 9 +- option/clash.go | 14 +-- outbound/selector.go | 37 +++++-- outbound/urltest.go | 108 ++++++++++++++------- 10 files changed, 282 insertions(+), 55 deletions(-) create mode 100644 common/interrupt/conn.go create mode 100644 common/interrupt/context.go create mode 100644 common/interrupt/group.go diff --git a/common/interrupt/conn.go b/common/interrupt/conn.go new file mode 100644 index 00000000..6a6d31c6 --- /dev/null +++ b/common/interrupt/conn.go @@ -0,0 +1,75 @@ +package interrupt + +import ( + "net" + + "github.com/sagernet/sing/common/x/list" +) + +/*type GroupedConn interface { + MarkAsInternal() +} + +func MarkAsInternal(conn any) { + if groupedConn, isGroupConn := common.Cast[GroupedConn](conn); isGroupConn { + groupedConn.MarkAsInternal() + } +}*/ + +type Conn struct { + net.Conn + group *Group + element *list.Element[*groupConnItem] +} + +/*func (c *Conn) MarkAsInternal() { + c.element.Value.internal = true +}*/ + +func (c *Conn) Close() error { + c.group.access.Lock() + defer c.group.access.Unlock() + c.group.connections.Remove(c.element) + return c.Conn.Close() +} + +func (c *Conn) ReaderReplaceable() bool { + return true +} + +func (c *Conn) WriterReplaceable() bool { + return true +} + +func (c *Conn) Upstream() any { + return c.Conn +} + +type PacketConn struct { + net.PacketConn + group *Group + element *list.Element[*groupConnItem] +} + +/*func (c *PacketConn) MarkAsInternal() { + c.element.Value.internal = true +}*/ + +func (c *PacketConn) Close() error { + c.group.access.Lock() + defer c.group.access.Unlock() + c.group.connections.Remove(c.element) + return c.PacketConn.Close() +} + +func (c *PacketConn) ReaderReplaceable() bool { + return true +} + +func (c *PacketConn) WriterReplaceable() bool { + return true +} + +func (c *PacketConn) Upstream() any { + return c.PacketConn +} diff --git a/common/interrupt/context.go b/common/interrupt/context.go new file mode 100644 index 00000000..44726b2d --- /dev/null +++ b/common/interrupt/context.go @@ -0,0 +1,13 @@ +package interrupt + +import "context" + +type contextKeyIsExternalConnection struct{} + +func ContextWithIsExternalConnection(ctx context.Context) context.Context { + return context.WithValue(ctx, contextKeyIsExternalConnection{}, true) +} + +func IsExternalConnectionFromContext(ctx context.Context) bool { + return ctx.Value(contextKeyIsExternalConnection{}) != nil +} diff --git a/common/interrupt/group.go b/common/interrupt/group.go new file mode 100644 index 00000000..ba2e7f73 --- /dev/null +++ b/common/interrupt/group.go @@ -0,0 +1,52 @@ +package interrupt + +import ( + "io" + "net" + "sync" + + "github.com/sagernet/sing/common/x/list" +) + +type Group struct { + access sync.Mutex + connections list.List[*groupConnItem] +} + +type groupConnItem struct { + conn io.Closer + isExternal bool +} + +func NewGroup() *Group { + return &Group{} +} + +func (g *Group) NewConn(conn net.Conn, isExternal bool) net.Conn { + g.access.Lock() + defer g.access.Unlock() + item := g.connections.PushBack(&groupConnItem{conn, isExternal}) + return &Conn{Conn: conn, group: g, element: item} +} + +func (g *Group) NewPacketConn(conn net.PacketConn, isExternal bool) net.PacketConn { + g.access.Lock() + defer g.access.Unlock() + item := g.connections.PushBack(&groupConnItem{conn, isExternal}) + return &PacketConn{PacketConn: conn, group: g, element: item} +} + +func (g *Group) Interrupt(interruptExternalConnections bool) { + g.access.Lock() + defer g.access.Unlock() + var toDelete []*list.Element[*groupConnItem] + for element := g.connections.Front(); element != nil; element = element.Next() { + if !element.Value.isExternal || interruptExternalConnections { + element.Value.conn.Close() + toDelete = append(toDelete, element) + } + } + for _, element := range toDelete { + g.connections.Remove(element) + } +} diff --git a/docs/configuration/outbound/selector.md b/docs/configuration/outbound/selector.md index 35be3041..1d2c74a9 100644 --- a/docs/configuration/outbound/selector.md +++ b/docs/configuration/outbound/selector.md @@ -10,7 +10,8 @@ "proxy-b", "proxy-c" ], - "default": "proxy-c" + "default": "proxy-c", + "interrupt_exist_connections": false } ``` @@ -28,4 +29,10 @@ List of outbound tags to select. #### default -The default outbound tag. The first outbound will be used if empty. \ No newline at end of file +The default outbound tag. The first outbound will be used if empty. + +#### interrupt_exist_connections + +Interrupt existing connections when the selected outbound has changed. + +Only inbound connections are affected by this setting, internal connections will always be interrupted. diff --git a/docs/configuration/outbound/selector.zh.md b/docs/configuration/outbound/selector.zh.md index adfbf3bf..9e985ab1 100644 --- a/docs/configuration/outbound/selector.zh.md +++ b/docs/configuration/outbound/selector.zh.md @@ -10,7 +10,8 @@ "proxy-b", "proxy-c" ], - "default": "proxy-c" + "default": "proxy-c", + "interrupt_exist_connections": false } ``` @@ -29,3 +30,9 @@ #### default 默认的出站标签。默认使用第一个出站。 + +#### interrupt_exist_connections + +当选定的出站发生更改时,中断现有连接。 + +仅入站连接受此设置影响,内部连接将始终被中断。 \ No newline at end of file diff --git a/docs/configuration/outbound/urltest.md b/docs/configuration/outbound/urltest.md index cdfed7b9..d905068d 100644 --- a/docs/configuration/outbound/urltest.md +++ b/docs/configuration/outbound/urltest.md @@ -12,7 +12,8 @@ ], "url": "https://www.gstatic.com/generate_204", "interval": "1m", - "tolerance": 50 + "tolerance": 50, + "interrupt_exist_connections": false } ``` @@ -35,3 +36,9 @@ The test interval. `1m` will be used if empty. #### tolerance The test tolerance in milliseconds. `50` will be used if empty. + +#### interrupt_exist_connections + +Interrupt existing connections when the selected outbound has changed. + +Only inbound connections are affected by this setting, internal connections will always be interrupted. diff --git a/docs/configuration/outbound/urltest.zh.md b/docs/configuration/outbound/urltest.zh.md index 210eadb5..0ad891f6 100644 --- a/docs/configuration/outbound/urltest.zh.md +++ b/docs/configuration/outbound/urltest.zh.md @@ -12,7 +12,8 @@ ], "url": "https://www.gstatic.com/generate_204", "interval": "1m", - "tolerance": 50 + "tolerance": 50, + "interrupt_exist_connections": false } ``` @@ -35,3 +36,9 @@ #### tolerance 以毫秒为单位的测试容差。 默认使用 `50`。 + +#### interrupt_exist_connections + +当选定的出站发生更改时,中断现有连接。 + +仅入站连接受此设置影响,内部连接将始终被中断。 \ No newline at end of file diff --git a/option/clash.go b/option/clash.go index 175a2c50..63ee2aeb 100644 --- a/option/clash.go +++ b/option/clash.go @@ -17,13 +17,15 @@ type ClashAPIOptions struct { } type SelectorOutboundOptions struct { - Outbounds []string `json:"outbounds"` - Default string `json:"default,omitempty"` + Outbounds []string `json:"outbounds"` + Default string `json:"default,omitempty"` + InterruptExistConnections bool `json:"interrupt_exist_connections,omitempty"` } type URLTestOutboundOptions struct { - Outbounds []string `json:"outbounds"` - URL string `json:"url,omitempty"` - Interval Duration `json:"interval,omitempty"` - Tolerance uint16 `json:"tolerance,omitempty"` + Outbounds []string `json:"outbounds"` + URL string `json:"url,omitempty"` + Interval Duration `json:"interval,omitempty"` + Tolerance uint16 `json:"tolerance,omitempty"` + InterruptExistConnections bool `json:"interrupt_exist_connections,omitempty"` } diff --git a/outbound/selector.go b/outbound/selector.go index c99d7af9..c66591cd 100644 --- a/outbound/selector.go +++ b/outbound/selector.go @@ -5,6 +5,7 @@ import ( "net" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/interrupt" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" @@ -20,10 +21,12 @@ var ( type Selector struct { myOutboundAdapter - tags []string - defaultTag string - outbounds map[string]adapter.Outbound - selected adapter.Outbound + tags []string + defaultTag string + outbounds map[string]adapter.Outbound + selected adapter.Outbound + interruptGroup *interrupt.Group + interruptExternalConnections bool } func NewSelector(router adapter.Router, logger log.ContextLogger, tag string, options option.SelectorOutboundOptions) (*Selector, error) { @@ -35,9 +38,11 @@ func NewSelector(router adapter.Router, logger log.ContextLogger, tag string, op tag: tag, dependencies: options.Outbounds, }, - tags: options.Outbounds, - defaultTag: options.Default, - outbounds: make(map[string]adapter.Outbound), + tags: options.Outbounds, + defaultTag: options.Default, + outbounds: make(map[string]adapter.Outbound), + interruptGroup: interrupt.NewGroup(), + interruptExternalConnections: options.InterruptExistConnections, } if len(outbound.tags) == 0 { return nil, E.New("missing tags") @@ -100,6 +105,9 @@ func (s *Selector) SelectOutbound(tag string) bool { if !loaded { return false } + if s.selected == detour { + return true + } s.selected = detour if s.tag != "" { if clashServer := s.router.ClashServer(); clashServer != nil && clashServer.StoreSelected() { @@ -109,22 +117,33 @@ func (s *Selector) SelectOutbound(tag string) bool { } } } + s.interruptGroup.Interrupt(s.interruptExternalConnections) return true } func (s *Selector) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { - return s.selected.DialContext(ctx, network, destination) + conn, err := s.selected.DialContext(ctx, network, destination) + if err != nil { + return nil, err + } + return s.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil } func (s *Selector) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { - return s.selected.ListenPacket(ctx, destination) + conn, err := s.selected.ListenPacket(ctx, destination) + if err != nil { + return nil, err + } + return s.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil } func (s *Selector) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + ctx = interrupt.ContextWithIsExternalConnection(ctx) return s.selected.NewConnection(ctx, conn, metadata) } func (s *Selector) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + ctx = interrupt.ContextWithIsExternalConnection(ctx) return s.selected.NewPacketConnection(ctx, conn, metadata) } diff --git a/outbound/urltest.go b/outbound/urltest.go index 79ab0bf8..45ce90b6 100644 --- a/outbound/urltest.go +++ b/outbound/urltest.go @@ -8,6 +8,7 @@ import ( "time" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/interrupt" "github.com/sagernet/sing-box/common/urltest" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" @@ -30,12 +31,13 @@ var ( type URLTest struct { myOutboundAdapter - ctx context.Context - tags []string - link string - interval time.Duration - tolerance uint16 - group *URLTestGroup + ctx context.Context + tags []string + link string + interval time.Duration + tolerance uint16 + group *URLTestGroup + interruptExternalConnections bool } func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.URLTestOutboundOptions) (*URLTest, error) { @@ -47,11 +49,12 @@ func NewURLTest(ctx context.Context, router adapter.Router, logger log.ContextLo tag: tag, dependencies: options.Outbounds, }, - ctx: ctx, - tags: options.Outbounds, - link: options.URL, - interval: time.Duration(options.Interval), - tolerance: options.Tolerance, + ctx: ctx, + tags: options.Outbounds, + link: options.URL, + interval: time.Duration(options.Interval), + tolerance: options.Tolerance, + interruptExternalConnections: options.InterruptExistConnections, } if len(outbound.tags) == 0 { return nil, E.New("missing tags") @@ -75,7 +78,7 @@ func (s *URLTest) Start() error { } outbounds = append(outbounds, detour) } - s.group = NewURLTestGroup(s.ctx, s.router, s.logger, outbounds, s.link, s.interval, s.tolerance) + s.group = NewURLTestGroup(s.ctx, s.router, s.logger, outbounds, s.link, s.interval, s.tolerance, s.interruptExternalConnections) return nil } @@ -111,7 +114,7 @@ func (s *URLTest) DialContext(ctx context.Context, network string, destination M outbound := s.group.Select(network) conn, err := outbound.DialContext(ctx, network, destination) if err == nil { - return conn, nil + return s.group.interruptGroup.NewConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil } s.logger.ErrorContext(ctx, err) s.group.history.DeleteURLTestHistory(outbound.Tag()) @@ -123,7 +126,7 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne outbound := s.group.Select(N.NetworkUDP) conn, err := outbound.ListenPacket(ctx, destination) if err == nil { - return conn, nil + return s.group.interruptGroup.NewPacketConn(conn, interrupt.IsExternalConnectionFromContext(ctx)), nil } s.logger.ErrorContext(ctx, err) s.group.history.DeleteURLTestHistory(outbound.Tag()) @@ -131,10 +134,12 @@ func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (ne } func (s *URLTest) NewConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error { + ctx = interrupt.ContextWithIsExternalConnection(ctx) return NewConnection(ctx, s, conn, metadata) } func (s *URLTest) NewPacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error { + ctx = interrupt.ContextWithIsExternalConnection(ctx) return NewPacketConnection(ctx, s, conn, metadata) } @@ -144,23 +149,36 @@ func (s *URLTest) InterfaceUpdated() { } type URLTestGroup struct { - ctx context.Context - router adapter.Router - logger log.Logger - outbounds []adapter.Outbound - link string - interval time.Duration - tolerance uint16 - history *urltest.HistoryStorage - checking atomic.Bool - pauseManager pause.Manager + ctx context.Context + router adapter.Router + logger log.Logger + outbounds []adapter.Outbound + link string + interval time.Duration + tolerance uint16 + history *urltest.HistoryStorage + checking atomic.Bool + pauseManager pause.Manager + selectedOutboundTCP adapter.Outbound + selectedOutboundUDP adapter.Outbound + interruptGroup *interrupt.Group + interruptExternalConnections bool access sync.Mutex ticker *time.Ticker close chan struct{} } -func NewURLTestGroup(ctx context.Context, router adapter.Router, logger log.Logger, outbounds []adapter.Outbound, link string, interval time.Duration, tolerance uint16) *URLTestGroup { +func NewURLTestGroup( + ctx context.Context, + router adapter.Router, + logger log.Logger, + outbounds []adapter.Outbound, + link string, + interval time.Duration, + tolerance uint16, + interruptExternalConnections bool, +) *URLTestGroup { if interval == 0 { interval = C.DefaultURLTestInterval } @@ -175,16 +193,18 @@ func NewURLTestGroup(ctx context.Context, router adapter.Router, logger log.Logg history = urltest.NewHistoryStorage() } return &URLTestGroup{ - ctx: ctx, - router: router, - logger: logger, - outbounds: outbounds, - link: link, - interval: interval, - tolerance: tolerance, - history: history, - close: make(chan struct{}), - pauseManager: pause.ManagerFromContext(ctx), + ctx: ctx, + router: router, + logger: logger, + outbounds: outbounds, + link: link, + interval: interval, + tolerance: tolerance, + history: history, + close: make(chan struct{}), + pauseManager: pause.ManagerFromContext(ctx), + interruptGroup: interrupt.NewGroup(), + interruptExternalConnections: interruptExternalConnections, } } @@ -329,5 +349,23 @@ func (g *URLTestGroup) urlTest(ctx context.Context, link string, force bool) (ma }) } b.Wait() + g.performUpdateCheck() return result, nil } + +func (g *URLTestGroup) performUpdateCheck() { + outbound := g.Select(N.NetworkTCP) + var updated bool + if outbound != g.selectedOutboundTCP { + g.selectedOutboundTCP = outbound + updated = true + } + outbound = g.Select(N.NetworkUDP) + if outbound != g.selectedOutboundUDP { + g.selectedOutboundUDP = outbound + updated = true + } + if updated { + g.interruptGroup.Interrupt(g.interruptExternalConnections) + } +}