mihomo/component/dialer/dialer.go

452 lines
10 KiB
Go
Raw Normal View History

2020-02-09 17:02:48 +08:00
package dialer
import (
"context"
"errors"
"fmt"
2020-02-09 17:02:48 +08:00
"net"
2022-04-20 01:52:51 +08:00
"net/netip"
2022-08-28 13:41:19 +08:00
"strings"
2022-04-27 21:37:20 +08:00
"sync"
"github.com/Dreamacro/clash/component/resolver"
"go.uber.org/atomic"
2020-02-09 17:02:48 +08:00
)
2022-04-27 21:37:20 +08:00
var (
dialMux sync.Mutex
actualSingleDialContext = singleDialContext
actualDualStackDialContext = dualStackDialContext
tcpConcurrent = false
2022-04-27 21:37:20 +08:00
DisableIPv6 = false
2022-08-28 13:41:19 +08:00
ErrorInvalidedNetworkStack = errors.New("invalided network stack")
ErrorDisableIPv6 = errors.New("IPv6 is disabled, dialer cancel")
2022-04-27 21:37:20 +08:00
)
2022-12-22 09:53:11 +08:00
func applyOptions(options ...Option) *option {
opt := &option{
interfaceName: DefaultInterface.Load(),
routingMark: int(DefaultRoutingMark.Load()),
}
for _, o := range DefaultOptions {
o(opt)
}
for _, o := range options {
o(opt)
}
2022-11-25 08:08:14 +08:00
return opt
}
func DialContext(ctx context.Context, network, address string, options ...Option) (net.Conn, error) {
2022-12-22 09:53:11 +08:00
opt := applyOptions(options...)
2022-11-25 08:08:14 +08:00
2022-08-28 13:41:19 +08:00
if opt.network == 4 || opt.network == 6 {
if strings.Contains(network, "tcp") {
network = "tcp"
} else {
network = "udp"
}
network = fmt.Sprintf("%s%d", network, opt.network)
}
switch network {
case "tcp4", "tcp6", "udp4", "udp6":
2022-04-27 21:37:20 +08:00
return actualSingleDialContext(ctx, network, address, opt)
case "tcp", "udp":
2022-04-27 21:37:20 +08:00
return actualDualStackDialContext(ctx, network, address, opt)
default:
2022-08-28 13:41:19 +08:00
return nil, ErrorInvalidedNetworkStack
}
2020-02-09 17:02:48 +08:00
}
func ListenPacket(ctx context.Context, network, address string, options ...Option) (net.PacketConn, error) {
cfg := &option{
interfaceName: DefaultInterface.Load(),
routingMark: int(DefaultRoutingMark.Load()),
}
for _, o := range DefaultOptions {
o(cfg)
}
for _, o := range options {
o(cfg)
}
lc := &net.ListenConfig{}
if cfg.interfaceName != "" {
addr, err := bindIfaceToListenConfig(cfg.interfaceName, lc, network, address)
if err != nil {
return nil, err
}
address = addr
}
if cfg.addrReuse {
addrReuseToListenConfig(lc)
}
2021-11-08 16:59:48 +08:00
if cfg.routingMark != 0 {
bindMarkToListenConfig(cfg.routingMark, lc, network, address)
}
return lc.ListenPacket(ctx, network, address)
2020-02-09 17:02:48 +08:00
}
2022-04-27 21:37:20 +08:00
func SetDial(concurrent bool) {
dialMux.Lock()
tcpConcurrent = concurrent
2022-04-27 21:37:20 +08:00
if concurrent {
actualSingleDialContext = concurrentSingleDialContext
actualDualStackDialContext = concurrentDualStackDialContext
} else {
actualSingleDialContext = singleDialContext
2022-05-27 20:43:39 +08:00
actualDualStackDialContext = dualStackDialContext
2022-04-27 21:37:20 +08:00
}
dialMux.Unlock()
}
func GetDial() bool {
return tcpConcurrent
}
2022-04-20 01:52:51 +08:00
func dialContext(ctx context.Context, network string, destination netip.Addr, port string, opt *option) (net.Conn, error) {
dialer := &net.Dialer{}
if opt.interfaceName != "" {
if err := bindIfaceToDialer(opt.interfaceName, dialer, network, destination); err != nil {
return nil, err
}
}
2021-11-08 16:59:48 +08:00
if opt.routingMark != 0 {
bindMarkToDialer(opt.routingMark, dialer, network, destination)
}
if DisableIPv6 && destination.Is6() {
2022-08-28 13:41:19 +08:00
return nil, ErrorDisableIPv6
}
return dialer.DialContext(ctx, network, net.JoinHostPort(destination.String(), port))
}
func dualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
returned := make(chan struct{})
defer close(returned)
type dialResult struct {
net.Conn
error
resolved bool
ipv6 bool
done bool
}
results := make(chan dialResult)
var primary, fallback dialResult
startRacer := func(ctx context.Context, network, host string, r resolver.Resolver, ipv6 bool) {
result := dialResult{ipv6: ipv6, done: true}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil {
2022-04-20 01:52:51 +08:00
_ = result.Conn.Close()
}
}
}()
2022-04-20 01:52:51 +08:00
var ip netip.Addr
if ipv6 {
if r == nil {
ip, result.error = resolver.ResolveIPv6ProxyServerHost(ctx, host)
} else {
ip, result.error = resolver.ResolveIPv6WithResolver(ctx, host, r)
}
} else {
if r == nil {
ip, result.error = resolver.ResolveIPv4ProxyServerHost(ctx, host)
} else {
ip, result.error = resolver.ResolveIPv4WithResolver(ctx, host, r)
}
}
if result.error != nil {
return
}
result.resolved = true
result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
}
go startRacer(ctx, network+"4", host, opt.resolver, false)
go startRacer(ctx, network+"6", host, opt.resolver, true)
count := 2
for i := 0; i < count; i++ {
select {
case res := <-results:
if res.error == nil {
return res.Conn, nil
}
if !res.ipv6 {
primary = res
2020-08-25 22:19:59 +08:00
} else {
fallback = res
}
if primary.done && fallback.done {
if primary.resolved {
return nil, primary.error
} else if fallback.resolved {
return nil, fallback.error
} else {
return nil, primary.error
}
}
case <-ctx.Done():
2022-11-25 08:08:14 +08:00
err = ctx.Err()
break
}
}
2020-08-25 22:19:59 +08:00
2022-11-25 08:08:14 +08:00
if err == nil {
err = fmt.Errorf("dual stack dial failed")
} else {
err = fmt.Errorf("dual stack dial failed:%w", err)
2022-11-19 10:57:33 +08:00
}
return nil, err
}
2022-04-27 21:37:20 +08:00
func concurrentDualStackDialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
2022-04-27 21:37:20 +08:00
var ips []netip.Addr
if opt.resolver != nil {
ips, err = resolver.LookupIPWithResolver(ctx, host, opt.resolver)
2022-04-27 21:37:20 +08:00
} else {
ips, err = resolver.LookupIPProxyServerHost(ctx, host)
2022-04-27 21:37:20 +08:00
}
if err != nil {
return nil, err
}
2022-04-27 21:37:20 +08:00
return concurrentDialContext(ctx, network, ips, port, opt)
}
func concurrentDialContext(ctx context.Context, network string, ips []netip.Addr, port string, opt *option) (net.Conn, error) {
returned := make(chan struct{})
defer close(returned)
type dialResult struct {
ip netip.Addr
net.Conn
error
2022-08-28 13:41:19 +08:00
isPrimary bool
done bool
}
2022-08-28 13:41:19 +08:00
preferCount := atomic.NewInt32(0)
results := make(chan dialResult)
tcpRacer := func(ctx context.Context, ip netip.Addr) {
2022-08-28 13:41:19 +08:00
result := dialResult{ip: ip, done: true}
defer func() {
select {
case results <- result:
case <-returned:
if result.Conn != nil {
2022-08-28 13:41:19 +08:00
_ = result.Conn.Close()
}
}
}()
2022-08-28 13:41:19 +08:00
if strings.Contains(network, "tcp") {
network = "tcp"
} else {
network = "udp"
}
if ip.Is6() {
2022-08-28 13:41:19 +08:00
network += "6"
if opt.prefer != 4 {
result.isPrimary = true
}
}
if ip.Is4() {
network += "4"
if opt.prefer != 6 {
result.isPrimary = true
}
}
if result.isPrimary {
preferCount.Add(1)
}
2022-04-27 21:37:20 +08:00
2022-08-28 13:41:19 +08:00
result.Conn, result.error = dialContext(ctx, network, ip, port, opt)
}
for _, ip := range ips {
go tcpRacer(ctx, ip)
}
connCount := len(ips)
2022-08-28 13:41:19 +08:00
var fallback dialResult
var primaryError error
2022-11-19 10:57:33 +08:00
var finalError error
for i := 0; i < connCount; i++ {
select {
case res := <-results:
if res.error == nil {
2022-08-28 13:41:19 +08:00
if res.isPrimary {
return res.Conn, nil
} else {
2022-08-28 20:26:13 +08:00
if !fallback.done || fallback.error != nil {
fallback = res
}
2022-08-28 13:41:19 +08:00
}
} else {
if res.isPrimary {
primaryError = res.error
2022-08-28 13:41:19 +08:00
preferCount.Add(-1)
2022-08-28 20:26:13 +08:00
if preferCount.Load() == 0 && fallback.done && fallback.error == nil {
2022-08-28 13:41:19 +08:00
return fallback.Conn, nil
}
}
}
case <-ctx.Done():
2022-08-28 20:26:13 +08:00
if fallback.done && fallback.error == nil {
2022-08-28 13:41:19 +08:00
return fallback.Conn, nil
}
2022-11-25 08:08:14 +08:00
finalError = ctx.Err()
break
}
}
2022-04-27 21:37:20 +08:00
2022-08-28 20:26:13 +08:00
if fallback.done && fallback.error == nil {
return fallback.Conn, nil
}
if primaryError != nil {
return nil, primaryError
}
if fallback.error != nil {
return nil, fallback.error
}
2022-11-25 08:08:14 +08:00
if finalError == nil {
finalError = fmt.Errorf("all ips %v tcp shake hands failed", ips)
} else {
finalError = fmt.Errorf("concurrent dial failed:%w", finalError)
2022-11-19 10:50:13 +08:00
}
2022-11-19 10:57:33 +08:00
return nil, finalError
2022-04-27 21:37:20 +08:00
}
func singleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ip netip.Addr
switch network {
case "tcp4", "udp4":
if opt.resolver == nil {
ip, err = resolver.ResolveIPv4ProxyServerHost(ctx, host)
2022-04-27 21:37:20 +08:00
} else {
ip, err = resolver.ResolveIPv4WithResolver(ctx, host, opt.resolver)
2022-04-27 21:37:20 +08:00
}
default:
if opt.resolver == nil {
ip, err = resolver.ResolveIPv6ProxyServerHost(ctx, host)
2022-04-27 21:37:20 +08:00
} else {
ip, err = resolver.ResolveIPv6WithResolver(ctx, host, opt.resolver)
2022-04-27 21:37:20 +08:00
}
}
if err != nil {
return nil, err
}
return dialContext(ctx, network, ip, port, opt)
}
func concurrentSingleDialContext(ctx context.Context, network string, address string, opt *option) (net.Conn, error) {
2022-08-28 13:41:19 +08:00
switch network {
case "tcp4", "udp4":
return concurrentIPv4DialContext(ctx, network, address, opt)
default:
return concurrentIPv6DialContext(ctx, network, address, opt)
}
}
func concurrentIPv4DialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
2022-04-27 21:37:20 +08:00
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ips []netip.Addr
if opt.resolver == nil {
ips, err = resolver.LookupIPv4ProxyServerHost(ctx, host)
2022-08-28 13:41:19 +08:00
} else {
ips, err = resolver.LookupIPv4WithResolver(ctx, host, opt.resolver)
2022-08-28 13:41:19 +08:00
}
if err != nil {
return nil, err
}
return concurrentDialContext(ctx, network, ips, port, opt)
}
func concurrentIPv6DialContext(ctx context.Context, network, address string, opt *option) (net.Conn, error) {
host, port, err := net.SplitHostPort(address)
if err != nil {
return nil, err
}
var ips []netip.Addr
if opt.resolver == nil {
ips, err = resolver.LookupIPv6ProxyServerHost(ctx, host)
2022-08-28 13:41:19 +08:00
} else {
ips, err = resolver.LookupIPv6WithResolver(ctx, host, opt.resolver)
2022-04-27 21:37:20 +08:00
}
if err != nil {
return nil, err
}
return concurrentDialContext(ctx, network, ips, port, opt)
}
2022-12-19 21:34:07 +08:00
2022-12-22 09:53:11 +08:00
type Dialer struct {
Opt option
2022-12-19 21:34:07 +08:00
}
2022-12-22 09:53:11 +08:00
func (d Dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return DialContext(ctx, network, address, WithOption(d.Opt))
2022-12-19 21:34:07 +08:00
}
2022-12-22 09:53:11 +08:00
func (d Dialer) ListenPacket(ctx context.Context, network, address string, rAddrPort netip.AddrPort) (net.PacketConn, error) {
return ListenPacket(ctx, ParseNetwork(network, rAddrPort.Addr()), address, WithOption(d.Opt))
2022-12-20 00:11:02 +08:00
}
2022-12-22 09:53:11 +08:00
func NewDialer(options ...Option) Dialer {
opt := applyOptions(options...)
return Dialer{Opt: *opt}
2022-12-19 21:34:07 +08:00
}