From b491c350ae2f6b7ca402bab135fde9ef64310ca7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 13 Apr 2023 16:11:46 +0800 Subject: [PATCH] URLTest improvements --- outbound/urltest.go | 52 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 41 insertions(+), 11 deletions(-) diff --git a/outbound/urltest.go b/outbound/urltest.go index d15636d8..9717a86c 100644 --- a/outbound/urltest.go +++ b/outbound/urltest.go @@ -13,6 +13,7 @@ import ( "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" "github.com/sagernet/sing/common" + "github.com/sagernet/sing/common/atomic" "github.com/sagernet/sing/common/batch" E "github.com/sagernet/sing/common/exceptions" M "github.com/sagernet/sing/common/metadata" @@ -20,8 +21,9 @@ import ( ) var ( - _ adapter.Outbound = (*URLTest)(nil) - _ adapter.OutboundGroup = (*URLTest)(nil) + _ adapter.Outbound = (*URLTest)(nil) + _ adapter.OutboundGroup = (*URLTest)(nil) + _ adapter.InterfaceUpdateListener = (*URLTest)(nil) ) type URLTest struct { @@ -71,7 +73,8 @@ func (s *URLTest) Start() error { outbounds = append(outbounds, detour) } s.group = NewURLTestGroup(s.ctx, s.router, s.logger, outbounds, s.link, s.interval, s.tolerance) - return s.group.Start() + go s.group.CheckOutbounds(false) + return nil } func (s *URLTest) Close() error { @@ -93,6 +96,7 @@ func (s *URLTest) URLTest(ctx context.Context, link string) (map[string]uint16, } func (s *URLTest) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) { + s.group.Start() outbound := s.group.Select(network) conn, err := outbound.DialContext(ctx, network, destination) if err == nil { @@ -104,6 +108,7 @@ func (s *URLTest) DialContext(ctx context.Context, network string, destination M } func (s *URLTest) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) { + s.group.Start() outbound := s.group.Select(N.NetworkUDP) conn, err := outbound.ListenPacket(ctx, destination) if err == nil { @@ -122,6 +127,11 @@ func (s *URLTest) NewPacketConnection(ctx context.Context, conn N.PacketConn, me return NewPacketConnection(ctx, s, conn, metadata) } +func (s *URLTest) InterfaceUpdated() error { + go s.group.CheckOutbounds(true) + return nil +} + type URLTestGroup struct { ctx context.Context router adapter.Router @@ -131,7 +141,9 @@ type URLTestGroup struct { interval time.Duration tolerance uint16 history *urltest.HistoryStorage + checking atomic.Bool + access sync.Mutex ticker *time.Ticker close chan struct{} } @@ -162,13 +174,23 @@ func NewURLTestGroup(ctx context.Context, router adapter.Router, logger log.Logg } } -func (g *URLTestGroup) Start() error { +func (g *URLTestGroup) Start() { + if g.ticker != nil { + return + } + g.access.Lock() + defer g.access.Unlock() + if g.ticker != nil { + return + } g.ticker = time.NewTicker(g.interval) go g.loopCheck() - return nil } func (g *URLTestGroup) Close() error { + if g.ticker == nil { + return nil + } g.ticker.Stop() close(g.close) return nil @@ -228,25 +250,33 @@ func (g *URLTestGroup) Fallback(used adapter.Outbound) []adapter.Outbound { } func (g *URLTestGroup) loopCheck() { - go g.checkOutbounds() + go g.CheckOutbounds(true) for { select { case <-g.close: return case <-g.ticker.C: - g.checkOutbounds() + g.CheckOutbounds(false) } } } -func (g *URLTestGroup) checkOutbounds() { - _, _ = g.URLTest(g.ctx, g.link) +func (g *URLTestGroup) CheckOutbounds(force bool) { + _, _ = g.urlTest(g.ctx, g.link, force) } func (g *URLTestGroup) URLTest(ctx context.Context, link string) (map[string]uint16, error) { + return g.urlTest(ctx, link, false) +} + +func (g *URLTestGroup) urlTest(ctx context.Context, link string, force bool) (map[string]uint16, error) { + result := make(map[string]uint16) + if g.checking.Swap(true) { + return result, nil + } + defer g.checking.Store(false) b, _ := batch.New(ctx, batch.WithConcurrencyNum[any](10)) checked := make(map[string]bool) - result := make(map[string]uint16) var resultAccess sync.Mutex for _, detour := range g.outbounds { tag := detour.Tag() @@ -255,7 +285,7 @@ func (g *URLTestGroup) URLTest(ctx context.Context, link string) (map[string]uin continue } history := g.history.LoadURLTestHistory(realTag) - if history != nil && time.Now().Sub(history.Time) < g.interval { + if !force && history != nil && time.Now().Sub(history.Time) < g.interval { continue } checked[realTag] = true