mihomo/adapter/outbound/wireguard.go

587 lines
16 KiB
Go
Raw Normal View History

2022-11-09 18:44:06 +08:00
package outbound
import (
"context"
"encoding/base64"
"encoding/hex"
"errors"
"fmt"
"net"
"net/netip"
"runtime"
"strconv"
"strings"
"sync"
2024-01-30 15:29:48 +08:00
"time"
2022-11-09 18:44:06 +08:00
2024-01-30 15:29:48 +08:00
"github.com/metacubex/mihomo/common/atomic"
2023-11-03 21:01:45 +08:00
CN "github.com/metacubex/mihomo/common/net"
"github.com/metacubex/mihomo/component/dialer"
"github.com/metacubex/mihomo/component/proxydialer"
"github.com/metacubex/mihomo/component/resolver"
C "github.com/metacubex/mihomo/constant"
"github.com/metacubex/mihomo/dns"
"github.com/metacubex/mihomo/log"
2022-11-09 18:44:06 +08:00
2022-11-09 19:42:56 +08:00
wireguard "github.com/metacubex/sing-wireguard"
2022-11-09 18:44:06 +08:00
2024-01-30 15:29:48 +08:00
"github.com/jpillora/backoff"
2022-11-09 18:44:06 +08:00
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/debug"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"github.com/sagernet/wireguard-go/device"
)
type WireGuard struct {
*Base
bind *wireguard.ClientBind
device *device.Device
tunDevice wireguard.Device
2023-09-25 09:10:43 +08:00
dialer proxydialer.SingDialer
2022-11-09 18:44:06 +08:00
startOnce sync.Once
2022-11-18 19:40:39 +08:00
startErr error
resolver *dns.Resolver
2023-04-26 09:33:58 +08:00
refP *refProxyAdapter
2022-11-09 18:44:06 +08:00
}
// WireGuardOption is the user-facing configuration of a WireGuard outbound.
// The embedded WireGuardPeerOption describes the peer when Peers is empty.
type WireGuardOption struct {
	BasicOption
	WireGuardPeerOption

	Name                string `proxy:"name"`
	PrivateKey          string `proxy:"private-key"`
	Workers             int    `proxy:"workers,omitempty"`
	MTU                 int    `proxy:"mtu,omitempty"`
	UDP                 bool   `proxy:"udp,omitempty"`
	PersistentKeepalive int    `proxy:"persistent-keepalive,omitempty"`

	// Peers, when non-empty, replaces the embedded single-peer settings.
	Peers []WireGuardPeerOption `proxy:"peers,omitempty"`

	// RemoteDnsResolve together with a non-empty Dns list makes NewWireGuard
	// create a dedicated resolver for this outbound.
	RemoteDnsResolve bool     `proxy:"remote-dns-resolve,omitempty"`
	Dns              []string `proxy:"dns,omitempty"`
}
// WireGuardPeerOption holds the settings of a single WireGuard peer.
type WireGuardPeerOption struct {
	// Server and Port form the peer endpoint (see Addr).
	Server string `proxy:"server"`
	Port   int    `proxy:"port"`

	// Ip and Ipv6 are the local tunnel addresses; a missing prefix length is
	// filled in by Prefixes.
	Ip   string `proxy:"ip,omitempty"`
	Ipv6 string `proxy:"ipv6,omitempty"`

	// PublicKey and PreSharedKey are base64 encoded keys.
	PublicKey    string `proxy:"public-key,omitempty"`
	PreSharedKey string `proxy:"pre-shared-key,omitempty"`

	// Reserved must be empty or exactly 3 bytes long.
	Reserved   []uint8  `proxy:"reserved,omitempty"`
	AllowedIPs []string `proxy:"allowed-ips,omitempty"`
}
2023-04-26 09:33:58 +08:00
// wgSingErrorHandler logs errors reported by the WireGuard client bind,
// tagging each message with the outbound's name.
type wgSingErrorHandler struct {
	name string // outbound name used as the log prefix
}
// Compile-time assertion that *wgSingErrorHandler satisfies E.Handler.
var _ E.Handler = (*wgSingErrorHandler)(nil)
// NewError logs err with the outbound name as prefix. Errors that represent
// a closed or canceled connection are logged at debug level, everything else
// at error level.
func (w wgSingErrorHandler) NewError(ctx context.Context, err error) {
	if !E.IsClosedOrCanceled(err) {
		log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", w.name, err))
		return
	}
	log.SingLogger.Debug(fmt.Sprintf("[WG](%s) connection closed: %s", w.name, err))
}
2023-03-07 09:30:51 +08:00
// wgNetDialer adapts a WireGuard tunnel device to the dialer.NetDialer
// interface so that connections are opened through the tunnel.
type wgNetDialer struct {
	tunDevice wireguard.Device // tunnel device used for dialing
}
// Compile-time assertion that *wgNetDialer satisfies dialer.NetDialer.
var _ dialer.NetDialer = (*wgNetDialer)(nil)
2023-03-07 09:30:51 +08:00
func (d wgNetDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.tunDevice.DialContext(ctx, network, M.ParseSocksaddr(address).Unwrap())
2022-11-09 18:44:06 +08:00
}
// Addr returns the peer endpoint built from Server and Port. Port is
// converted to uint16 before being handed to the parser.
func (option WireGuardPeerOption) Addr() M.Socksaddr {
	port := uint16(option.Port)
	return M.ParseSocksaddrHostPort(option.Server, port)
}
2022-11-09 18:44:06 +08:00
func (option WireGuardPeerOption) Prefixes() ([]netip.Prefix, error) {
2022-11-09 18:44:06 +08:00
localPrefixes := make([]netip.Prefix, 0, 2)
2022-11-09 19:35:03 +08:00
if len(option.Ip) > 0 {
if !strings.Contains(option.Ip, "/") {
option.Ip = option.Ip + "/32"
}
if prefix, err := netip.ParsePrefix(option.Ip); err == nil {
localPrefixes = append(localPrefixes, prefix)
} else {
return nil, E.Cause(err, "ip address parse error")
}
2022-11-09 18:44:06 +08:00
}
if len(option.Ipv6) > 0 {
if !strings.Contains(option.Ipv6, "/") {
option.Ipv6 = option.Ipv6 + "/128"
}
if prefix, err := netip.ParsePrefix(option.Ipv6); err == nil {
localPrefixes = append(localPrefixes, prefix)
} else {
return nil, E.Cause(err, "ipv6 address parse error")
}
}
2022-11-09 19:35:03 +08:00
if len(localPrefixes) == 0 {
return nil, E.New("missing local address")
}
return localPrefixes, nil
}
2024-01-30 15:29:48 +08:00
// wgSingDialer wraps a SingDialer and slows down new attempts after
// repeated failures: errTimes counts consecutive errors and backoff supplies
// the delay applied once that count exceeds the threshold used by its
// DialContext and ListenPacket methods.
type wgSingDialer struct {
	proxydialer.SingDialer

	errTimes atomic.Int64
	backoff  *backoff.Backoff
}
// DialContext dials through the wrapped SingDialer. After more than 10
// consecutive failures it first waits for the next backoff interval (or
// returns ctx.Err() if ctx ends first). A successful dial resets both the
// failure counter and the backoff; a failed one increments the counter.
func (d *wgSingDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
	if d.errTimes.Load() > 10 {
		delay := time.After(d.backoff.Duration())
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-delay:
		}
	}

	conn, dialErr := d.SingDialer.DialContext(ctx, network, destination)
	if dialErr == nil {
		d.errTimes.Store(0)
		d.backoff.Reset()
		return conn, nil
	}
	d.errTimes.Add(1)
	return nil, dialErr
}
// ListenPacket opens a packet connection through the wrapped SingDialer.
// After more than 10 consecutive failures it first waits for the next
// backoff interval (or returns ctx.Err() if ctx ends first). Success resets
// the failure counter and the backoff; failure increments the counter.
func (d *wgSingDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
	if d.errTimes.Load() > 10 {
		delay := time.After(d.backoff.Duration())
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-delay:
		}
	}

	packetConn, listenErr := d.SingDialer.ListenPacket(ctx, destination)
	if listenErr == nil {
		d.errTimes.Store(0)
		d.backoff.Reset()
		return packetConn, nil
	}
	d.errTimes.Add(1)
	return nil, listenErr
}
func NewWireGuard(option WireGuardOption) (*WireGuard, error) {
outbound := &WireGuard{
Base: &Base{
name: option.Name,
addr: net.JoinHostPort(option.Server, strconv.Itoa(option.Port)),
tp: C.WireGuard,
udp: option.UDP,
iface: option.Interface,
rmark: option.RoutingMark,
prefer: C.NewDNSPrefer(option.IPVersion),
},
2024-01-30 15:29:48 +08:00
dialer: &wgSingDialer{
SingDialer: proxydialer.NewByNameSingDialer(option.DialerProxy, dialer.NewDialer()),
errTimes: atomic.NewInt64(0),
backoff: &backoff.Backoff{
Min: 10 * time.Millisecond,
Max: 1 * time.Second,
Factor: 2,
Jitter: true,
},
},
}
runtime.SetFinalizer(outbound, closeWireGuard)
var reserved [3]uint8
if len(option.Reserved) > 0 {
if len(option.Reserved) != 3 {
return nil, E.New("invalid reserved value, required 3 bytes, got ", len(option.Reserved))
}
copy(reserved[:], option.Reserved)
}
var isConnect bool
var connectAddr M.Socksaddr
if len(option.Peers) < 2 {
isConnect = true
if len(option.Peers) == 1 {
connectAddr = option.Peers[0].Addr()
} else {
connectAddr = option.Addr()
}
}
2023-04-26 09:33:58 +08:00
outbound.bind = wireguard.NewClientBind(context.Background(), wgSingErrorHandler{outbound.Name()}, outbound.dialer, isConnect, connectAddr, reserved)
var localPrefixes []netip.Prefix
var privateKey string
2022-11-09 18:44:06 +08:00
{
bytes, err := base64.StdEncoding.DecodeString(option.PrivateKey)
if err != nil {
return nil, E.Cause(err, "decode private key")
}
privateKey = hex.EncodeToString(bytes)
}
ipcConf := "private_key=" + privateKey
if peersLen := len(option.Peers); peersLen > 0 {
localPrefixes = make([]netip.Prefix, 0, peersLen*2)
for i, peer := range option.Peers {
var peerPublicKey, preSharedKey string
{
bytes, err := base64.StdEncoding.DecodeString(peer.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode public key for peer ", i)
}
peerPublicKey = hex.EncodeToString(bytes)
}
if peer.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(peer.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key for peer ", i)
}
preSharedKey = hex.EncodeToString(bytes)
}
destination := peer.Addr()
ipcConf += "\npublic_key=" + peerPublicKey
ipcConf += "\nendpoint=" + destination.String()
if preSharedKey != "" {
ipcConf += "\npreshared_key=" + preSharedKey
}
if len(peer.AllowedIPs) == 0 {
return nil, E.New("missing allowed_ips for peer ", i)
}
for _, allowedIP := range peer.AllowedIPs {
ipcConf += "\nallowed_ip=" + allowedIP
}
if len(peer.Reserved) > 0 {
if len(peer.Reserved) != 3 {
return nil, E.New("invalid reserved value for peer ", i, ", required 3 bytes, got ", len(peer.Reserved))
}
copy(reserved[:], option.Reserved)
outbound.bind.SetReservedForEndpoint(destination, reserved)
}
prefixes, err := peer.Prefixes()
if err != nil {
return nil, err
}
localPrefixes = append(localPrefixes, prefixes...)
2022-11-09 18:44:06 +08:00
}
} else {
var peerPublicKey, preSharedKey string
{
bytes, err := base64.StdEncoding.DecodeString(option.PublicKey)
if err != nil {
return nil, E.Cause(err, "decode peer public key")
}
peerPublicKey = hex.EncodeToString(bytes)
}
if option.PreSharedKey != "" {
bytes, err := base64.StdEncoding.DecodeString(option.PreSharedKey)
if err != nil {
return nil, E.Cause(err, "decode pre shared key")
}
preSharedKey = hex.EncodeToString(bytes)
}
ipcConf += "\npublic_key=" + peerPublicKey
ipcConf += "\nendpoint=" + connectAddr.String()
if preSharedKey != "" {
ipcConf += "\npreshared_key=" + preSharedKey
}
var err error
localPrefixes, err = option.Prefixes()
2022-11-09 18:44:06 +08:00
if err != nil {
return nil, err
2022-11-09 18:44:06 +08:00
}
var has4, has6 bool
for _, address := range localPrefixes {
if address.Addr().Is4() {
has4 = true
} else {
has6 = true
}
}
if has4 {
ipcConf += "\nallowed_ip=0.0.0.0/0"
}
if has6 {
ipcConf += "\nallowed_ip=::/0"
2022-11-09 18:44:06 +08:00
}
}
if option.PersistentKeepalive != 0 {
ipcConf += fmt.Sprintf("\npersistent_keepalive_interval=%d", option.PersistentKeepalive)
}
2022-11-09 18:44:06 +08:00
mtu := option.MTU
if mtu == 0 {
mtu = 1408
}
if len(localPrefixes) == 0 {
return nil, E.New("missing local address")
}
2022-11-09 18:44:06 +08:00
var err error
outbound.tunDevice, err = wireguard.NewStackDevice(localPrefixes, uint32(mtu))
if err != nil {
return nil, E.Cause(err, "create WireGuard device")
}
2023-08-16 21:30:12 +08:00
outbound.device = device.NewDevice(context.Background(), outbound.tunDevice, outbound.bind, &device.Logger{
2022-11-09 18:44:06 +08:00
Verbosef: func(format string, args ...interface{}) {
2023-04-10 10:13:36 +08:00
log.SingLogger.Debug(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...)))
2022-11-09 18:44:06 +08:00
},
Errorf: func(format string, args ...interface{}) {
2023-04-10 10:13:36 +08:00
log.SingLogger.Error(fmt.Sprintf("[WG](%s) %s", option.Name, fmt.Sprintf(format, args...)))
2022-11-09 18:44:06 +08:00
},
}, option.Workers)
if debug.Enabled {
2023-04-10 10:13:36 +08:00
log.SingLogger.Trace(fmt.Sprintf("[WG](%s) created wireguard ipc conf: \n %s", option.Name, ipcConf))
2022-11-09 18:44:06 +08:00
}
err = outbound.device.IpcSet(ipcConf)
if err != nil {
return nil, E.Cause(err, "setup wireguard")
}
//err = outbound.tunDevice.Start()
var has6 bool
for _, address := range localPrefixes {
if !address.Addr().Unmap().Is4() {
has6 = true
break
}
}
2023-04-26 09:33:58 +08:00
refP := &refProxyAdapter{}
outbound.refP = refP
if option.RemoteDnsResolve && len(option.Dns) > 0 {
nss, err := dns.ParseNameServer(option.Dns)
if err != nil {
return nil, err
}
for i := range nss {
2023-04-26 09:33:58 +08:00
nss[i].ProxyAdapter = refP
}
outbound.resolver = dns.NewResolver(dns.Config{
Main: nss,
IPv6: has6,
})
}
2022-11-09 18:44:06 +08:00
return outbound, nil
}
// closeWireGuard shuts down the WireGuard device (when one was created) and
// closes the tunnel device. It is registered as the finalizer of outbounds
// created by NewWireGuard.
func closeWireGuard(w *WireGuard) {
	if dev := w.device; dev != nil {
		dev.Close()
	}
	// Best effort: there is nobody to report a close error to.
	_ = common.Close(w.tunDevice)
}
func (w *WireGuard) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.Conn, err error) {
2023-03-07 09:30:51 +08:00
options := w.Base.DialOptions(opts...)
2023-09-25 09:10:43 +08:00
w.dialer.SetDialer(dialer.NewDialer(options...))
2022-11-09 18:44:06 +08:00
var conn net.Conn
w.startOnce.Do(func() {
2022-11-18 19:40:39 +08:00
w.startErr = w.tunDevice.Start()
2022-11-09 18:44:06 +08:00
})
2022-11-18 19:40:39 +08:00
if w.startErr != nil {
return nil, w.startErr
2022-11-09 18:44:06 +08:00
}
if !metadata.Resolved() || w.resolver != nil {
r := resolver.DefaultResolver
if w.resolver != nil {
2023-04-26 09:33:58 +08:00
w.refP.SetProxyAdapter(w)
defer w.refP.ClearProxyAdapter()
r = w.resolver
}
options = append(options, dialer.WithResolver(r))
2023-03-07 09:30:51 +08:00
options = append(options, dialer.WithNetDialer(wgNetDialer{tunDevice: w.tunDevice}))
conn, err = dialer.NewDialer(options...).DialContext(ctx, "tcp", metadata.RemoteAddress())
2022-11-09 18:44:06 +08:00
} else {
conn, err = w.tunDevice.DialContext(ctx, "tcp", M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
2022-11-09 18:44:06 +08:00
}
if err != nil {
return nil, err
}
2022-11-18 19:32:12 +08:00
if conn == nil {
return nil, E.New("conn is nil")
}
return NewConn(CN.NewRefConn(conn, w), w), nil
2022-11-09 18:44:06 +08:00
}
func (w *WireGuard) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (_ C.PacketConn, err error) {
2023-03-07 09:30:51 +08:00
options := w.Base.DialOptions(opts...)
2023-09-25 09:10:43 +08:00
w.dialer.SetDialer(dialer.NewDialer(options...))
2022-11-09 18:44:06 +08:00
var pc net.PacketConn
w.startOnce.Do(func() {
2022-11-18 19:40:39 +08:00
w.startErr = w.tunDevice.Start()
2022-11-09 18:44:06 +08:00
})
2022-11-18 19:40:39 +08:00
if w.startErr != nil {
return nil, w.startErr
}
2022-11-09 18:44:06 +08:00
if err != nil {
return nil, err
}
if (!metadata.Resolved() || w.resolver != nil) && metadata.Host != "" {
r := resolver.DefaultResolver
if w.resolver != nil {
2023-04-26 09:33:58 +08:00
w.refP.SetProxyAdapter(w)
defer w.refP.ClearProxyAdapter()
r = w.resolver
}
ip, err := resolver.ResolveIPWithResolver(ctx, metadata.Host, r)
2022-11-09 18:44:06 +08:00
if err != nil {
return nil, errors.New("can't resolve ip")
}
metadata.DstIP = ip
}
pc, err = w.tunDevice.ListenPacket(ctx, M.SocksaddrFrom(metadata.DstIP, metadata.DstPort).Unwrap())
2022-11-09 18:44:06 +08:00
if err != nil {
return nil, err
}
2022-11-18 19:40:39 +08:00
if pc == nil {
return nil, E.New("packetConn is nil")
}
return newPacketConn(CN.NewRefPacketConn(pc, w), w), nil
2022-11-09 18:44:06 +08:00
}
// IsL3Protocol implements C.ProxyAdapter. It reports true for every
// metadata value.
func (w *WireGuard) IsL3Protocol(metadata *C.Metadata) bool {
	const isLayer3 = true
	return isLayer3
}
2023-04-26 09:33:58 +08:00
// refProxyAdapter is a C.ProxyAdapter placeholder that forwards every call
// to proxyAdapter while one is attached and otherwise answers with zero
// values or C.ErrNotSupport.
type refProxyAdapter struct {
	proxyAdapter C.ProxyAdapter // current target, nil when detached

	// count is the number of outstanding SetProxyAdapter calls; mutex guards
	// the updates made by SetProxyAdapter and ClearProxyAdapter.
	count int
	mutex sync.Mutex
}
// SetProxyAdapter attaches proxyAdapter as the forwarding target and
// increments the usage count. Each call is expected to be paired with a
// ClearProxyAdapter call.
func (r *refProxyAdapter) SetProxyAdapter(proxyAdapter C.ProxyAdapter) {
	r.mutex.Lock()
	r.proxyAdapter = proxyAdapter
	r.count++
	r.mutex.Unlock()
}
// ClearProxyAdapter decrements the usage count and detaches the forwarding
// target once the count reaches zero.
func (r *refProxyAdapter) ClearProxyAdapter() {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	r.count--
	if r.count != 0 {
		return
	}
	r.proxyAdapter = nil
}
// Name returns the attached adapter's name, or "" when none is attached.
func (r *refProxyAdapter) Name() string {
	adapter := r.proxyAdapter
	if adapter == nil {
		return ""
	}
	return adapter.Name()
}
// Type returns the attached adapter's type, or the zero C.AdapterType when
// none is attached.
func (r *refProxyAdapter) Type() C.AdapterType {
	adapter := r.proxyAdapter
	if adapter == nil {
		return C.AdapterType(0)
	}
	return adapter.Type()
}
// Addr returns the attached adapter's address, or "" when none is attached.
func (r *refProxyAdapter) Addr() string {
	adapter := r.proxyAdapter
	if adapter == nil {
		return ""
	}
	return adapter.Addr()
}
// SupportUDP forwards to the attached adapter and reports false when none
// is attached.
func (r *refProxyAdapter) SupportUDP() bool {
	adapter := r.proxyAdapter
	if adapter == nil {
		return false
	}
	return adapter.SupportUDP()
}
// SupportXUDP forwards to the attached adapter and reports false when none
// is attached.
func (r *refProxyAdapter) SupportXUDP() bool {
	adapter := r.proxyAdapter
	if adapter == nil {
		return false
	}
	return adapter.SupportXUDP()
}
// SupportTFO forwards to the attached adapter and reports false when none
// is attached.
func (r *refProxyAdapter) SupportTFO() bool {
	adapter := r.proxyAdapter
	if adapter == nil {
		return false
	}
	return adapter.SupportTFO()
}
// MarshalJSON forwards to the attached adapter and returns C.ErrNotSupport
// when none is attached.
func (r *refProxyAdapter) MarshalJSON() ([]byte, error) {
	adapter := r.proxyAdapter
	if adapter == nil {
		return nil, C.ErrNotSupport
	}
	return adapter.MarshalJSON()
}
func (r *refProxyAdapter) StreamConnContext(ctx context.Context, c net.Conn, metadata *C.Metadata) (net.Conn, error) {
2023-04-26 09:33:58 +08:00
if r.proxyAdapter != nil {
return r.proxyAdapter.StreamConnContext(ctx, c, metadata)
2023-04-26 09:33:58 +08:00
}
return nil, C.ErrNotSupport
}
// DialContext forwards to the attached adapter and returns C.ErrNotSupport
// when none is attached.
func (r *refProxyAdapter) DialContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.Conn, error) {
	adapter := r.proxyAdapter
	if adapter == nil {
		return nil, C.ErrNotSupport
	}
	return adapter.DialContext(ctx, metadata, opts...)
}
// ListenPacketContext forwards to the attached adapter and returns
// C.ErrNotSupport when none is attached.
func (r *refProxyAdapter) ListenPacketContext(ctx context.Context, metadata *C.Metadata, opts ...dialer.Option) (C.PacketConn, error) {
	adapter := r.proxyAdapter
	if adapter == nil {
		return nil, C.ErrNotSupport
	}
	return adapter.ListenPacketContext(ctx, metadata, opts...)
}
// SupportUOT forwards to the attached adapter and reports false when none
// is attached.
func (r *refProxyAdapter) SupportUOT() bool {
	adapter := r.proxyAdapter
	if adapter == nil {
		return false
	}
	return adapter.SupportUOT()
}
// SupportWithDialer forwards to the attached adapter and returns
// C.InvalidNet when none is attached.
func (r *refProxyAdapter) SupportWithDialer() C.NetWork {
	adapter := r.proxyAdapter
	if adapter == nil {
		return C.InvalidNet
	}
	return adapter.SupportWithDialer()
}
// DialContextWithDialer forwards to the attached adapter and returns
// C.ErrNotSupport when none is attached.
func (r *refProxyAdapter) DialContextWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.Conn, error) {
	adapter := r.proxyAdapter
	if adapter == nil {
		return nil, C.ErrNotSupport
	}
	return adapter.DialContextWithDialer(ctx, dialer, metadata)
}
// ListenPacketWithDialer forwards to the attached adapter and returns
// C.ErrNotSupport when none is attached.
func (r *refProxyAdapter) ListenPacketWithDialer(ctx context.Context, dialer C.Dialer, metadata *C.Metadata) (C.PacketConn, error) {
	adapter := r.proxyAdapter
	if adapter == nil {
		return nil, C.ErrNotSupport
	}
	return adapter.ListenPacketWithDialer(ctx, dialer, metadata)
}
// IsL3Protocol forwards to the attached adapter and reports false when none
// is attached.
func (r *refProxyAdapter) IsL3Protocol(metadata *C.Metadata) bool {
	adapter := r.proxyAdapter
	if adapter == nil {
		return false
	}
	return adapter.IsL3Protocol(metadata)
}
// Unwrap forwards to the attached adapter and returns nil when none is
// attached.
func (r *refProxyAdapter) Unwrap(metadata *C.Metadata, touch bool) C.Proxy {
	adapter := r.proxyAdapter
	if adapter == nil {
		return nil
	}
	return adapter.Unwrap(metadata, touch)
}
// Compile-time assertion that *refProxyAdapter satisfies C.ProxyAdapter.
var _ C.ProxyAdapter = (*refProxyAdapter)(nil)