sing-box/transport/wireguard/device_stack.go

285 lines
6.8 KiB
Go
Raw Normal View History

2022-09-15 12:20:38 +08:00
//go:build with_gvisor
2022-09-06 00:15:09 +08:00
package wireguard
import (
"context"
"net"
"os"
2023-06-11 22:07:22 +08:00
"github.com/sagernet/gvisor/pkg/buffer"
"github.com/sagernet/gvisor/pkg/tcpip"
"github.com/sagernet/gvisor/pkg/tcpip/adapters/gonet"
"github.com/sagernet/gvisor/pkg/tcpip/header"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv4"
"github.com/sagernet/gvisor/pkg/tcpip/network/ipv6"
"github.com/sagernet/gvisor/pkg/tcpip/stack"
"github.com/sagernet/gvisor/pkg/tcpip/transport/tcp"
"github.com/sagernet/gvisor/pkg/tcpip/transport/udp"
2023-06-13 22:38:05 +08:00
"github.com/sagernet/sing-tun"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
2024-11-21 18:10:41 +08:00
"github.com/sagernet/wireguard-go/device"
2023-06-13 22:38:05 +08:00
wgTun "github.com/sagernet/wireguard-go/tun"
2022-09-06 00:15:09 +08:00
)
2024-11-21 18:10:41 +08:00
// Compile-time assertion that stackDevice satisfies the package's Device interface.
var _ Device = (*stackDevice)(nil)

// stackDevice is a WireGuard TUN device whose "interface" is an in-process
// gVisor network stack instead of an OS TUN device. Packets emitted by the
// stack are queued on outbound and handed to wireguard-go through Read;
// packets written by wireguard-go through Write are injected into the stack
// via dispatcher.
type stackDevice struct {
	// The gVisor stack and the dispatcher it attached to this device's
	// link endpoint (see wireEndpoint.Attach).
	stack      *stack.Stack
	dispatcher stack.NetworkDispatcher

	// MTU reported to both wireguard-go and gVisor.
	mtu uint32

	// Local bind addresses for outgoing connections; left as the zero value
	// when no address of that family was configured.
	addr4 tcpip.Address
	addr6 tcpip.Address

	// events carries TUN events to wireguard-go, outbound queues packets from
	// the stack until Read consumes them, and done is closed by Close.
	events   chan wgTun.Event
	outbound chan *stack.PacketBuffer
	done     chan struct{}
}
func newStackDevice(options DeviceOptions) (*stackDevice, error) {
tunDevice := &stackDevice{
mtu: options.MTU,
events: make(chan wgTun.Event, 1),
outbound: make(chan *stack.PacketBuffer, 256),
done: make(chan struct{}),
2022-09-06 00:15:09 +08:00
}
2024-11-21 18:10:41 +08:00
ipStack, err := tun.NewGVisorStack((*wireEndpoint)(tunDevice))
2022-09-06 00:15:09 +08:00
if err != nil {
2024-11-21 18:10:41 +08:00
return nil, err
2022-09-06 00:15:09 +08:00
}
2024-11-21 18:10:41 +08:00
for _, prefix := range options.Address {
2023-06-11 15:07:32 +08:00
addr := tun.AddressFromAddr(prefix.Addr())
2022-09-06 00:15:09 +08:00
protoAddr := tcpip.ProtocolAddress{
AddressWithPrefix: tcpip.AddressWithPrefix{
Address: addr,
PrefixLen: prefix.Bits(),
},
}
if prefix.Addr().Is4() {
tunDevice.addr4 = addr
protoAddr.Protocol = ipv4.ProtocolNumber
} else {
tunDevice.addr6 = addr
protoAddr.Protocol = ipv6.ProtocolNumber
}
2024-11-21 18:10:41 +08:00
gErr := ipStack.AddProtocolAddress(tun.DefaultNIC, protoAddr, stack.AddressProperties{})
if gErr != nil {
return nil, E.New("parse local address ", protoAddr.AddressWithPrefix, ": ", gErr.String())
2022-09-06 00:15:09 +08:00
}
}
2024-11-21 18:10:41 +08:00
tunDevice.stack = ipStack
if options.Handler != nil {
ipStack.SetTransportProtocolHandler(tcp.ProtocolNumber, tun.NewTCPForwarder(options.Context, ipStack, options.Handler).HandlePacket)
ipStack.SetTransportProtocolHandler(udp.ProtocolNumber, tun.NewUDPForwarder(options.Context, ipStack, options.Handler, options.UDPTimeout).HandlePacket)
}
2022-09-06 00:15:09 +08:00
return tunDevice, nil
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
2022-09-06 00:15:09 +08:00
addr := tcpip.FullAddress{
2024-11-21 18:10:41 +08:00
NIC: tun.DefaultNIC,
2022-09-06 00:15:09 +08:00
Port: destination.Port,
2023-06-11 15:07:32 +08:00
Addr: tun.AddressFromAddr(destination.Addr),
2022-09-06 00:15:09 +08:00
}
bind := tcpip.FullAddress{
2024-11-21 18:10:41 +08:00
NIC: tun.DefaultNIC,
2022-09-06 00:15:09 +08:00
}
var networkProtocol tcpip.NetworkProtocolNumber
if destination.IsIPv4() {
networkProtocol = header.IPv4ProtocolNumber
bind.Addr = w.addr4
} else {
networkProtocol = header.IPv6ProtocolNumber
bind.Addr = w.addr6
}
switch N.NetworkName(network) {
case N.NetworkTCP:
tcpConn, err := DialTCPWithBind(ctx, w.stack, bind, addr, networkProtocol)
if err != nil {
return nil, err
}
return tcpConn, nil
2022-09-06 00:15:09 +08:00
case N.NetworkUDP:
udpConn, err := gonet.DialUDP(w.stack, &bind, &addr, networkProtocol)
if err != nil {
return nil, err
}
return udpConn, nil
2022-09-06 00:15:09 +08:00
default:
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
2022-09-06 00:15:09 +08:00
bind := tcpip.FullAddress{
2024-11-21 18:10:41 +08:00
NIC: tun.DefaultNIC,
2022-09-06 00:15:09 +08:00
}
var networkProtocol tcpip.NetworkProtocolNumber
2023-06-11 15:07:32 +08:00
if destination.IsIPv4() {
2022-09-06 00:15:09 +08:00
networkProtocol = header.IPv4ProtocolNumber
bind.Addr = w.addr4
} else {
networkProtocol = header.IPv6ProtocolNumber
bind.Addr = w.addr6
}
udpConn, err := gonet.DialUDP(w.stack, &bind, nil, networkProtocol)
if err != nil {
return nil, err
}
return udpConn, nil
2022-09-06 00:15:09 +08:00
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) SetDevice(device *device.Device) {
2023-04-20 13:16:31 +08:00
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) Start() error {
2023-04-20 13:16:31 +08:00
w.events <- wgTun.EventUp
2022-09-06 00:15:09 +08:00
return nil
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) File() *os.File {
2022-09-06 00:15:09 +08:00
return nil
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) Read(bufs [][]byte, sizes []int, offset int) (count int, err error) {
2022-11-06 10:20:23 +08:00
select {
case packetBuffer, ok := <-w.outbound:
if !ok {
return 0, os.ErrClosed
}
defer packetBuffer.DecRef()
2023-04-20 13:16:31 +08:00
p := bufs[0]
2022-11-06 10:20:23 +08:00
p = p[offset:]
2023-04-20 13:16:31 +08:00
n := 0
2022-11-06 10:20:23 +08:00
for _, slice := range packetBuffer.AsSlices() {
n += copy(p[n:], slice)
}
2023-04-20 13:16:31 +08:00
sizes[0] = n
count = 1
return
2022-11-06 10:20:23 +08:00
case <-w.done:
2022-09-06 00:15:09 +08:00
return 0, os.ErrClosed
}
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) Write(bufs [][]byte, offset int) (count int, err error) {
2023-04-20 13:16:31 +08:00
for _, b := range bufs {
b = b[offset:]
if len(b) == 0 {
continue
}
var networkProtocol tcpip.NetworkProtocolNumber
switch header.IPVersion(b) {
case header.IPv4Version:
networkProtocol = header.IPv4ProtocolNumber
case header.IPv6Version:
networkProtocol = header.IPv6ProtocolNumber
}
packetBuffer := stack.NewPacketBuffer(stack.PacketBufferOptions{
2023-06-11 15:07:32 +08:00
Payload: buffer.MakeWithData(b),
2023-04-20 13:16:31 +08:00
})
w.dispatcher.DeliverNetworkPacket(networkProtocol, packetBuffer)
packetBuffer.DecRef()
count++
2022-09-06 00:15:09 +08:00
}
return
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) Flush() error {
2022-09-06 00:15:09 +08:00
return nil
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) MTU() (int, error) {
2022-09-06 00:15:09 +08:00
return int(w.mtu), nil
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) Name() (string, error) {
2022-09-06 00:15:09 +08:00
return "sing-box", nil
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) Events() <-chan wgTun.Event {
2022-09-06 00:15:09 +08:00
return w.events
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) Close() error {
2024-08-26 14:01:32 +08:00
close(w.done)
close(w.events)
2022-09-06 00:15:09 +08:00
w.stack.Close()
for _, endpoint := range w.stack.CleanupEndpoints() {
endpoint.Abort()
}
w.stack.Wait()
return nil
}
2024-11-21 18:10:41 +08:00
func (w *stackDevice) BatchSize() int {
2023-04-20 13:16:31 +08:00
return 1
}
2022-09-06 00:15:09 +08:00
var _ stack.LinkEndpoint = (*wireEndpoint)(nil)
2024-11-21 18:10:41 +08:00
type wireEndpoint stackDevice
2022-09-06 00:15:09 +08:00
func (ep *wireEndpoint) MTU() uint32 {
return ep.mtu
}
2024-11-18 12:51:58 +08:00
func (ep *wireEndpoint) SetMTU(mtu uint32) {
}
2022-09-06 00:15:09 +08:00
func (ep *wireEndpoint) MaxHeaderLength() uint16 {
return 0
}
func (ep *wireEndpoint) LinkAddress() tcpip.LinkAddress {
return ""
}
2024-11-18 12:51:58 +08:00
func (ep *wireEndpoint) SetLinkAddress(addr tcpip.LinkAddress) {
}
2022-09-06 00:15:09 +08:00
func (ep *wireEndpoint) Capabilities() stack.LinkEndpointCapabilities {
return stack.CapabilityRXChecksumOffload
2022-09-06 00:15:09 +08:00
}
func (ep *wireEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
ep.dispatcher = dispatcher
}
func (ep *wireEndpoint) IsAttached() bool {
return ep.dispatcher != nil
}
func (ep *wireEndpoint) Wait() {
}
func (ep *wireEndpoint) ARPHardwareType() header.ARPHardwareType {
return header.ARPHardwareNone
}
2024-02-14 13:08:08 +08:00
func (ep *wireEndpoint) AddHeader(buffer *stack.PacketBuffer) {
2022-09-06 00:15:09 +08:00
}
2024-02-14 13:08:08 +08:00
func (ep *wireEndpoint) ParseHeader(ptr *stack.PacketBuffer) bool {
2023-10-21 12:00:00 +08:00
return true
}
2022-09-06 00:15:09 +08:00
func (ep *wireEndpoint) WritePackets(list stack.PacketBufferList) (int, tcpip.Error) {
for _, packetBuffer := range list.AsSlice() {
packetBuffer.IncRef()
2022-11-06 10:20:23 +08:00
select {
case <-ep.done:
return 0, &tcpip.ErrClosedForSend{}
case ep.outbound <- packetBuffer:
}
2022-09-06 00:15:09 +08:00
}
return list.Len(), nil
}
2024-11-18 12:51:58 +08:00
// Close is a no-op; stackDevice.Close performs the actual teardown.
func (ep *wireEndpoint) Close() {}

// SetOnCloseAction discards the callback because Close runs no actions.
func (ep *wireEndpoint) SetOnCloseAction(func()) {}