sing-box/protocol/wireguard/outbound.go

233 lines
6.4 KiB
Go
Raw Normal View History

2024-11-02 00:39:02 +08:00
package wireguard
2022-08-16 23:37:51 +08:00
import (
"context"
"encoding/base64"
"encoding/hex"
"fmt"
"net"
"net/netip"
2022-08-16 23:37:51 +08:00
"strings"
"github.com/sagernet/sing-box/adapter"
2024-11-02 00:39:02 +08:00
"github.com/sagernet/sing-box/adapter/outbound"
2022-08-16 23:37:51 +08:00
"github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
2022-09-06 00:15:09 +08:00
"github.com/sagernet/sing-box/transport/wireguard"
2022-09-15 12:20:38 +08:00
"github.com/sagernet/sing-tun"
2024-06-06 20:51:21 +08:00
"github.com/sagernet/sing/common"
2022-08-16 23:37:51 +08:00
E "github.com/sagernet/sing/common/exceptions"
2024-11-02 00:39:02 +08:00
"github.com/sagernet/sing/common/logger"
2022-08-16 23:37:51 +08:00
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
2023-12-16 15:40:14 +08:00
"github.com/sagernet/sing/service"
"github.com/sagernet/sing/service/pause"
"github.com/sagernet/wireguard-go/conn"
"github.com/sagernet/wireguard-go/device"
2022-08-16 23:37:51 +08:00
)
2024-11-02 00:39:02 +08:00
// RegisterOutbound adds the WireGuard outbound type to registry, binding
// option.WireGuardOutboundOptions to the NewOutbound constructor.
func RegisterOutbound(registry *outbound.Registry) {
	outbound.Register[option.WireGuardOutboundOptions](
		registry,
		C.TypeWireGuard,
		NewOutbound,
	)
}
2022-08-16 23:37:51 +08:00
2024-11-02 00:39:02 +08:00
var _ adapter.InterfaceUpdateListener = (*Outbound)(nil)
type Outbound struct {
outbound.Adapter
ctx context.Context
2024-11-02 00:39:02 +08:00
router adapter.Router
logger logger.ContextLogger
workers int
peers []wireguard.PeerConfig
useStdNetBind bool
listener N.Dialer
ipcConf string
pauseManager pause.Manager
pauseCallback *list.Element[pause.Callback]
bind conn.Bind
device *device.Device
tunDevice wireguard.Device
2022-08-16 23:37:51 +08:00
}
2024-11-02 00:39:02 +08:00
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.WireGuardOutboundOptions) (adapter.Outbound, error) {
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, options.Network.Build(), tag, options.DialerOptions),
ctx: ctx,
2024-11-02 00:39:02 +08:00
router: router,
logger: logger,
workers: options.Workers,
2023-12-16 15:40:14 +08:00
pauseManager: service.FromContext[pause.Manager](ctx),
2022-08-16 23:37:51 +08:00
}
peers, err := wireguard.ParsePeers(options)
2023-08-08 16:14:03 +08:00
if err != nil {
return nil, err
}
outbound.peers = peers
2023-10-21 12:00:00 +08:00
if len(options.LocalAddress) == 0 {
2022-08-16 23:37:51 +08:00
return nil, E.New("missing local address")
}
if options.GSO {
if options.GSO && options.Detour != "" {
return nil, E.New("gso is conflict with detour")
}
options.IsWireGuardListener = true
outbound.useStdNetBind = true
}
listener, err := dialer.New(ctx, options.DialerOptions)
if err != nil {
return nil, err
}
outbound.listener = listener
var privateKey string
2022-08-16 23:37:51 +08:00
{
bytes, err := base64.StdEncoding.DecodeString(options.PrivateKey)
if err != nil {
return nil, E.Cause(err, "decode private key")
}
privateKey = hex.EncodeToString(bytes)
}
outbound.ipcConf = "private_key=" + privateKey
2022-08-16 23:37:51 +08:00
mtu := options.MTU
if mtu == 0 {
2022-08-17 15:19:10 +08:00
mtu = 1408
2022-08-16 23:37:51 +08:00
}
2022-09-06 00:15:09 +08:00
var wireTunDevice wireguard.Device
2022-09-15 12:20:38 +08:00
if !options.SystemInterface && tun.WithGVisor {
2023-10-21 12:00:00 +08:00
wireTunDevice, err = wireguard.NewStackDevice(options.LocalAddress, mtu)
2022-09-06 00:15:09 +08:00
} else {
2024-11-10 12:11:21 +08:00
wireTunDevice, err = wireguard.NewSystemDevice(service.FromContext[adapter.NetworkManager](ctx), options.InterfaceName, options.LocalAddress, mtu, options.GSO)
2022-09-06 00:15:09 +08:00
}
2022-08-16 23:37:51 +08:00
if err != nil {
2022-09-06 00:15:09 +08:00
return nil, E.Cause(err, "create WireGuard device")
2022-08-16 23:37:51 +08:00
}
outbound.tunDevice = wireTunDevice
return outbound, nil
}
2024-11-02 00:39:02 +08:00
func (w *Outbound) Start() error {
2024-06-06 20:51:21 +08:00
if common.Any(w.peers, func(peer wireguard.PeerConfig) bool {
return !peer.Endpoint.IsValid()
}) {
// wait for all outbounds to be started and continue in PortStart
return nil
}
return w.start()
}
2024-11-02 00:39:02 +08:00
func (w *Outbound) PostStart() error {
2024-06-06 20:51:21 +08:00
if common.All(w.peers, func(peer wireguard.PeerConfig) bool {
return peer.Endpoint.IsValid()
}) {
return nil
}
return w.start()
}
2024-11-02 00:39:02 +08:00
func (w *Outbound) start() error {
err := wireguard.ResolvePeers(w.ctx, w.router, w.peers)
if err != nil {
return err
}
var bind conn.Bind
if w.useStdNetBind {
bind = conn.NewStdNetBind(w.listener.(dialer.WireGuardListener))
} else {
var (
isConnect bool
connectAddr netip.AddrPort
reserved [3]uint8
)
peerLen := len(w.peers)
if peerLen == 1 {
isConnect = true
connectAddr = w.peers[0].Endpoint
reserved = w.peers[0].Reserved
}
2024-11-02 00:39:02 +08:00
bind = wireguard.NewClientBind(w.ctx, w.logger, w.listener, isConnect, connectAddr, reserved)
}
2024-08-26 14:01:32 +08:00
err = w.tunDevice.Start()
if err != nil {
return err
}
wgDevice := device.NewDevice(w.tunDevice, bind, &device.Logger{
2022-08-16 23:37:51 +08:00
Verbosef: func(format string, args ...interface{}) {
w.logger.Debug(fmt.Sprintf(strings.ToLower(format), args...))
2022-08-16 23:37:51 +08:00
},
Errorf: func(format string, args ...interface{}) {
w.logger.Error(fmt.Sprintf(strings.ToLower(format), args...))
2022-08-16 23:37:51 +08:00
},
}, w.workers)
ipcConf := w.ipcConf
for _, peer := range w.peers {
ipcConf += peer.GenerateIpcLines()
2022-08-16 23:37:51 +08:00
}
err = wgDevice.IpcSet(ipcConf)
if err != nil {
return E.Cause(err, "setup wireguard: \n", ipcConf)
2022-08-16 23:37:51 +08:00
}
w.device = wgDevice
w.pauseCallback = w.pauseManager.RegisterCallback(w.onPauseUpdated)
2024-08-26 14:01:32 +08:00
return nil
}
2024-11-02 00:39:02 +08:00
func (w *Outbound) Close() error {
if w.device != nil {
w.device.Close()
}
if w.pauseCallback != nil {
w.pauseManager.UnregisterCallback(w.pauseCallback)
}
return nil
2022-08-16 23:37:51 +08:00
}
2024-11-02 00:39:02 +08:00
func (w *Outbound) InterfaceUpdated() {
w.device.BindUpdate()
2023-07-23 14:42:19 +08:00
return
2022-11-06 10:36:19 +08:00
}
2024-11-02 00:39:02 +08:00
// onPauseUpdated is the pause.Manager callback registered by start. It takes
// the WireGuard device down on pause.EventDevicePaused and brings it back up
// on pause.EventDeviceWake. Any other event is ignored.
func (w *Outbound) onPauseUpdated(event int) {
	if event == pause.EventDevicePaused {
		w.device.Down()
	} else if event == pause.EventDeviceWake {
		w.device.Up()
	}
}
2024-11-02 00:39:02 +08:00
func (w *Outbound) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
2022-08-16 23:37:51 +08:00
switch network {
case N.NetworkTCP:
w.logger.InfoContext(ctx, "outbound connection to ", destination)
case N.NetworkUDP:
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsFqdn() {
2023-09-06 19:13:39 +08:00
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
2022-08-16 23:37:51 +08:00
if err != nil {
return nil, err
}
2023-09-06 19:13:39 +08:00
return N.DialSerial(ctx, w.tunDevice, network, destination, destinationAddresses)
2022-08-16 23:37:51 +08:00
}
2022-09-06 00:15:09 +08:00
return w.tunDevice.DialContext(ctx, network, destination)
2022-08-16 23:37:51 +08:00
}
2024-11-02 00:39:02 +08:00
func (w *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
2022-08-16 23:37:51 +08:00
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
2023-09-06 19:13:39 +08:00
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
if err != nil {
return nil, err
}
packetConn, _, err := N.ListenSerial(ctx, w.tunDevice, destination, destinationAddresses)
if err != nil {
return nil, err
}
return packetConn, err
}
2022-09-06 00:15:09 +08:00
return w.tunDevice.ListenPacket(ctx, destination)
2022-08-16 23:37:51 +08:00
}