Optimization: refactor picker

This commit is contained in:
Dreamacro 2019-07-02 19:18:03 +08:00
parent 0eff8516c0
commit 7c6c147a18
9 changed files with 123 additions and 104 deletions

View File

@ -1,6 +1,7 @@
package adapters
import (
"context"
"encoding/json"
"errors"
"net"
@ -99,7 +100,7 @@ func (p *Proxy) MarshalJSON() ([]byte, error) {
}
// URLTest get the delay for the specified URL
func (p *Proxy) URLTest(url string) (t uint16, err error) {
func (p *Proxy) URLTest(ctx context.Context, url string) (t uint16, err error) {
defer func() {
p.alive = err == nil
record := C.DelayHistory{Time: time.Now()}
@ -123,6 +124,13 @@ func (p *Proxy) URLTest(url string) (t uint16, err error) {
return
}
defer instance.Close()
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return
}
req = req.WithContext(ctx)
transport := &http.Transport{
Dial: func(string, string) (net.Conn, error) {
return instance, nil
@ -133,8 +141,9 @@ func (p *Proxy) URLTest(url string) (t uint16, err error) {
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
client := http.Client{Transport: transport}
resp, err := client.Get(url)
resp, err := client.Do(req)
if err != nil {
return
}

View File

@ -1,6 +1,7 @@
package adapters
import (
"context"
"encoding/json"
"errors"
"net"
@ -90,7 +91,7 @@ func (f *Fallback) validTest() {
for _, p := range f.proxies {
go func(p C.Proxy) {
p.URLTest(f.rawURL)
p.URLTest(context.Background(), f.rawURL)
wg.Done()
}(p)
}

View File

@ -1,6 +1,7 @@
package adapters
import (
"context"
"encoding/json"
"errors"
"net"
@ -95,7 +96,7 @@ func (lb *LoadBalance) validTest() {
for _, p := range lb.proxies {
go func(p C.Proxy) {
p.URLTest(lb.rawURL)
p.URLTest(context.Background(), lb.rawURL)
wg.Done()
}(p)
}

View File

@ -5,7 +5,6 @@ import (
"encoding/json"
"errors"
"net"
"sync"
"sync/atomic"
"time"
@ -103,35 +102,22 @@ func (u *URLTest) speedTest() {
}
defer atomic.StoreInt32(&u.once, 0)
wg := sync.WaitGroup{}
wg.Add(len(u.proxies))
c := make(chan interface{})
fast := picker.SelectFast(context.Background(), c)
timer := time.NewTimer(u.interval)
ctx, cancel := context.WithTimeout(context.Background(), u.interval)
defer cancel()
picker, ctx := picker.WithContext(ctx)
for _, p := range u.proxies {
go func(p C.Proxy) {
_, err := p.URLTest(u.rawURL)
if err == nil {
c <- p
picker.Go(func() (interface{}, error) {
_, err := p.URLTest(ctx, u.rawURL)
if err != nil {
return nil, err
}
wg.Done()
}(p)
return p, nil
})
}
go func() {
wg.Wait()
close(c)
}()
select {
case <-timer.C:
// Wait for fast to return or close.
<-fast
case p, open := <-fast:
if open {
u.fast = p.(C.Proxy)
}
fast := picker.Wait()
if fast != nil {
u.fast = fast.(C.Proxy)
}
}

View File

@ -1,22 +1,53 @@
package picker
import "context"
import (
"context"
"sync"
)
// Picker provides synchronization, and Context cancelation
// for groups of goroutines working on subtasks of a common task.
// Inspired by errGroup
type Picker struct {
cancel func()
wg sync.WaitGroup
once sync.Once
result interface{}
}
// WithContext returns a new Picker and an associated Context derived from ctx.
func WithContext(ctx context.Context) (*Picker, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &Picker{cancel: cancel}, ctx
}
// Wait blocks until all function calls from the Go method have returned,
// then returns the first nil error result (if any) from them.
func (p *Picker) Wait() interface{} {
p.wg.Wait()
if p.cancel != nil {
p.cancel()
}
return p.result
}
// Go calls the given function in a new goroutine.
// The first call to return a nil error cancels the group; its result will be returned by Wait.
func (p *Picker) Go(f func() (interface{}, error)) {
p.wg.Add(1)
func SelectFast(ctx context.Context, in <-chan interface{}) <-chan interface{} {
out := make(chan interface{})
go func() {
select {
case p, open := <-in:
if open {
out <- p
}
case <-ctx.Done():
}
defer p.wg.Done()
close(out)
for range in {
if ret, err := f(); err == nil {
p.once.Do(func() {
p.result = ret
if p.cancel != nil {
p.cancel()
}
})
}
}()
return out
}

View File

@ -6,39 +6,37 @@ import (
"time"
)
func sleepAndSend(delay int, in chan<- interface{}, input interface{}) {
time.Sleep(time.Millisecond * time.Duration(delay))
in <- input
func sleepAndSend(ctx context.Context, delay int, input interface{}) func() (interface{}, error) {
return func() (interface{}, error) {
timer := time.NewTimer(time.Millisecond * time.Duration(delay))
select {
case <-timer.C:
return input, nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func sleepAndClose(delay int, in chan interface{}) {
time.Sleep(time.Millisecond * time.Duration(delay))
close(in)
}
func TestPicker_Basic(t *testing.T) {
in := make(chan interface{})
fast := SelectFast(context.Background(), in)
go sleepAndSend(20, in, 1)
go sleepAndSend(30, in, 2)
go sleepAndClose(40, in)
picker, ctx := WithContext(context.Background())
picker.Go(sleepAndSend(ctx, 30, 2))
picker.Go(sleepAndSend(ctx, 20, 1))
number, exist := <-fast
if !exist || number != 1 {
t.Error("should recv 1", exist, number)
number := picker.Wait()
if number != nil && number.(int) != 1 {
t.Error("should recv 1", number)
}
}
func TestPicker_Timeout(t *testing.T) {
in := make(chan interface{})
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*5)
defer cancel()
fast := SelectFast(ctx, in)
go sleepAndSend(20, in, 1)
go sleepAndClose(30, in)
picker, ctx := WithContext(ctx)
picker.Go(sleepAndSend(ctx, 20, 1))
_, exist := <-fast
if exist {
t.Error("should recv false")
number := picker.Wait()
if number != nil {
t.Error("should recv nil")
}
}

View File

@ -1,6 +1,7 @@
package constant
import (
"context"
"net"
"time"
)
@ -44,7 +45,7 @@ type Proxy interface {
Alive() bool
DelayHistory() []DelayHistory
LastDelay() uint16
URLTest(url string) (uint16, error)
URLTest(ctx context.Context, url string) (uint16, error)
}
// AdapterType is enum of adapter type

View File

@ -163,32 +163,22 @@ func (r *Resolver) IsFakeIP() bool {
}
func (r *Resolver) batchExchange(clients []resolver, m *D.Msg) (msg *D.Msg, err error) {
in := make(chan interface{})
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
fast := picker.SelectFast(ctx, in)
fast, ctx := picker.WithContext(ctx)
wg := sync.WaitGroup{}
wg.Add(len(clients))
for _, r := range clients {
go func(r resolver) {
defer wg.Done()
fast.Go(func() (interface{}, error) {
msg, err := r.ExchangeContext(ctx, m)
if err != nil || msg.Rcode != D.RcodeSuccess {
return
return nil, errors.New("resolve error")
}
in <- msg
}(r)
return msg, nil
})
}
// release in channel
go func() {
wg.Wait()
close(in)
}()
elm, exist := <-fast
if !exist {
elm := fast.Wait()
if elm == nil {
return nil, errors.New("All DNS requests failed")
}

View File

@ -9,6 +9,7 @@ import (
"time"
A "github.com/Dreamacro/clash/adapters/outbound"
"github.com/Dreamacro/clash/common/picker"
C "github.com/Dreamacro/clash/constant"
T "github.com/Dreamacro/clash/tunnel"
@ -110,27 +111,28 @@ func getProxyDelay(w http.ResponseWriter, r *http.Request) {
proxy := r.Context().Value(CtxKeyProxy).(C.Proxy)
sigCh := make(chan uint16)
go func() {
t, err := proxy.URLTest(url)
if err != nil {
sigCh <- 0
}
sigCh <- t
}()
ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond*time.Duration(timeout))
defer cancel()
picker, ctx := picker.WithContext(ctx)
picker.Go(func() (interface{}, error) {
return proxy.URLTest(ctx, url)
})
select {
case <-time.After(time.Millisecond * time.Duration(timeout)):
elm := picker.Wait()
if elm == nil {
render.Status(r, http.StatusRequestTimeout)
render.JSON(w, r, ErrRequestTimeout)
case t := <-sigCh:
if t == 0 {
return
}
delay := elm.(uint16)
if delay == 0 {
render.Status(r, http.StatusServiceUnavailable)
render.JSON(w, r, newError("An error occurred in the delay test"))
} else {
return
}
render.JSON(w, r, render.M{
"delay": t,
"delay": delay,
})
}
}
}