sing-box/transport/vless/vision.go

366 lines
9.9 KiB
Go
Raw Permalink Normal View History

package vless
import (
"bytes"
"crypto/rand"
"crypto/tls"
"io"
"math/big"
"net"
"reflect"
"time"
"unsafe"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
2023-02-27 15:07:15 +08:00
"github.com/sagernet/sing/common/logger"
N "github.com/sagernet/sing/common/network"
)
// tlsRegistry holds the probes NewVisionConn tries, in order, to recognise a
// TLS connection. A probe reports loaded == true when it recognises conn and
// then also returns the underlying connection, the reflect type of the TLS
// connection struct and the address of that struct.
var tlsRegistry []func(conn net.Conn) (loaded bool, netConn net.Conn, reflectType reflect.Type, reflectPointer uintptr)
// init registers the probe for the standard library *tls.Conn.
func init() {
	stdlibTLS := func(conn net.Conn) (bool, net.Conn, reflect.Type, uintptr) {
		tlsConn, isTLS := conn.(*tls.Conn)
		if !isTLS {
			return false, nil, nil, 0
		}
		connType := reflect.TypeOf(tlsConn).Elem()
		connAddr := uintptr(unsafe.Pointer(tlsConn))
		return true, tlsConn.NetConn(), connType, connAddr
	}
	tlsRegistry = append(tlsRegistry, stdlibTLS)
}
2023-03-05 14:56:09 +08:00
// xrayChunkSize is the chunk size given to the chunk reader in NewVisionConn.
// Read goes through that chunk reader unless the caller's buffer is larger
// than this value, in which case it reads straight into the caller's buffer.
const xrayChunkSize = 8192
type VisionConn struct {
net.Conn
2023-03-05 14:56:09 +08:00
reader *bufio.ChunkReader
writer N.VectorisedWriter
input *bytes.Reader
rawInput *bytes.Buffer
netConn net.Conn
2023-02-27 15:07:15 +08:00
logger logger.Logger
2023-02-27 15:07:15 +08:00
userUUID [16]byte
isTLS bool
numberOfPacketToFilter int
isTLS12orAbove bool
remainingServerHello int32
cipher uint16
enableXTLS bool
isPadding bool
directWrite bool
writeUUID bool
withinPaddingBuffers bool
remainingContent int
remainingPadding int
currentCommand int
directRead bool
remainingReader io.Reader
}
2023-02-27 15:07:15 +08:00
func NewVisionConn(conn net.Conn, userUUID [16]byte, logger logger.Logger) (*VisionConn, error) {
var (
loaded bool
reflectType reflect.Type
reflectPointer uintptr
netConn net.Conn
)
for _, tlsCreator := range tlsRegistry {
loaded, netConn, reflectType, reflectPointer = tlsCreator(conn)
if loaded {
break
}
}
if !loaded {
return nil, C.ErrTLSRequired
}
input, _ := reflectType.FieldByName("input")
rawInput, _ := reflectType.FieldByName("rawInput")
return &VisionConn{
2023-02-27 15:07:15 +08:00
Conn: conn,
2023-03-05 14:56:09 +08:00
reader: bufio.NewChunkReader(conn, xrayChunkSize),
2023-02-27 15:07:15 +08:00
writer: bufio.NewVectorisedWriter(conn),
input: (*bytes.Reader)(unsafe.Pointer(reflectPointer + input.Offset)),
rawInput: (*bytes.Buffer)(unsafe.Pointer(reflectPointer + rawInput.Offset)),
netConn: netConn,
logger: logger,
userUUID: userUUID,
numberOfPacketToFilter: 8,
remainingServerHello: -1,
isPadding: true,
writeUUID: true,
withinPaddingBuffers: true,
remainingContent: -1,
remainingPadding: -1,
}, nil
}
func (c *VisionConn) Read(p []byte) (n int, err error) {
if c.remainingReader != nil {
n, err = c.remainingReader.Read(p)
if err == io.EOF {
c.remainingReader = nil
2023-03-05 14:56:09 +08:00
}
if n > 0 {
return
}
}
if c.directRead {
return c.netConn.Read(p)
}
2023-03-05 14:56:09 +08:00
var bufferBytes []byte
if len(p) > xrayChunkSize {
n, err = c.Conn.Read(p)
if err != nil {
return
}
bufferBytes = p[:n]
} else {
buffer, err := c.reader.ReadChunk()
if err != nil {
return 0, err
}
defer buffer.FullReset()
bufferBytes = buffer.Bytes()
}
2023-02-27 15:07:15 +08:00
if c.withinPaddingBuffers || c.numberOfPacketToFilter > 0 {
2023-03-05 14:56:09 +08:00
buffers := c.unPadding(bufferBytes)
if c.remainingContent == 0 && c.remainingPadding == 0 {
if c.currentCommand == 1 {
2023-02-27 15:07:15 +08:00
c.withinPaddingBuffers = false
c.remainingContent = -1
c.remainingPadding = -1
} else if c.currentCommand == 2 {
2023-02-27 15:07:15 +08:00
c.withinPaddingBuffers = false
c.directRead = true
inputBuffer, err := io.ReadAll(c.input)
if err != nil {
return 0, err
}
buffers = append(buffers, inputBuffer)
rawInputBuffer, err := io.ReadAll(c.rawInput)
if err != nil {
return 0, err
}
buffers = append(buffers, rawInputBuffer)
2023-02-27 15:07:15 +08:00
c.logger.Trace("XtlsRead readV")
} else if c.currentCommand == 0 {
c.withinPaddingBuffers = true
} else {
return 0, E.New("unknown command ", c.currentCommand)
}
2023-02-27 15:07:15 +08:00
} else if c.remainingContent > 0 || c.remainingPadding > 0 {
c.withinPaddingBuffers = true
} else {
c.withinPaddingBuffers = false
}
if c.numberOfPacketToFilter > 0 {
c.filterTLS(buffers)
}
c.remainingReader = io.MultiReader(common.Map(buffers, func(it []byte) io.Reader { return bytes.NewReader(it) })...)
2023-03-06 11:19:23 +08:00
return c.Read(p)
} else {
if c.numberOfPacketToFilter > 0 {
2023-03-05 14:56:09 +08:00
c.filterTLS([][]byte{bufferBytes})
}
return
}
}
func (c *VisionConn) Write(p []byte) (n int, err error) {
if c.numberOfPacketToFilter > 0 {
c.filterTLS([][]byte{p})
}
2023-02-27 15:07:15 +08:00
if c.isPadding {
inputLen := len(p)
buffers := reshapeBuffer(p)
var specIndex int
for i, buffer := range buffers {
2023-02-27 15:07:15 +08:00
if c.isTLS && buffer.Len() > 6 && bytes.Equal(tlsApplicationDataStart, buffer.To(3)) {
var command byte = commandPaddingEnd
if c.enableXTLS {
c.directWrite = true
specIndex = i
2023-02-27 15:07:15 +08:00
command = commandPaddingDirect
}
2023-02-27 15:07:15 +08:00
c.isPadding = false
buffers[i] = c.padding(buffer, command)
break
2023-02-27 15:07:15 +08:00
} else if !c.isTLS12orAbove && c.numberOfPacketToFilter <= 1 {
c.isPadding = false
buffers[i] = c.padding(buffer, commandPaddingEnd)
break
}
2023-02-27 15:07:15 +08:00
buffers[i] = c.padding(buffer, commandPaddingContinue)
}
if c.directWrite {
encryptedBuffer := buffers[:specIndex+1]
err = c.writer.WriteVectorised(encryptedBuffer)
if err != nil {
return
}
buffers = buffers[specIndex+1:]
c.writer = bufio.NewVectorisedWriter(c.netConn)
2023-02-27 15:07:15 +08:00
c.logger.Trace("XtlsWrite writeV ", specIndex, " ", buf.LenMulti(encryptedBuffer), " ", len(buffers))
time.Sleep(5 * time.Millisecond) // wtf
}
err = c.writer.WriteVectorised(buffers)
if err == nil {
n = inputLen
}
return
}
if c.directWrite {
return c.netConn.Write(p)
} else {
return c.Conn.Write(p)
}
}
func (c *VisionConn) filterTLS(buffers [][]byte) {
for _, buffer := range buffers {
c.numberOfPacketToFilter--
if len(buffer) > 6 {
if buffer[0] == 22 && buffer[1] == 3 && buffer[2] == 3 {
c.isTLS = true
if buffer[5] == 2 {
c.isTLS12orAbove = true
c.remainingServerHello = (int32(buffer[3])<<8 | int32(buffer[4])) + 5
if len(buffer) >= 79 && c.remainingServerHello >= 79 {
sessionIdLen := int32(buffer[43])
cipherSuite := buffer[43+sessionIdLen+1 : 43+sessionIdLen+3]
c.cipher = uint16(cipherSuite[0])<<8 | uint16(cipherSuite[1])
2023-02-27 15:07:15 +08:00
} else {
c.logger.Trace("XtlsFilterTls short server hello, tls 1.2 or older? ", len(buffer), " ", c.remainingServerHello)
}
}
} else if bytes.Equal(tlsClientHandShakeStart, buffer[:2]) && buffer[5] == 1 {
c.isTLS = true
2023-02-27 15:07:15 +08:00
c.logger.Trace("XtlsFilterTls found tls client hello! ", len(buffer))
}
}
if c.remainingServerHello > 0 {
end := int(c.remainingServerHello)
if end > len(buffer) {
end = len(buffer)
}
c.remainingServerHello -= int32(end)
if bytes.Contains(buffer[:end], tls13SupportedVersions) {
cipher, ok := tls13CipherSuiteDic[c.cipher]
if ok && cipher != "TLS_AES_128_CCM_8_SHA256" {
c.enableXTLS = true
}
2023-02-27 15:07:15 +08:00
c.logger.Trace("XtlsFilterTls found tls 1.3! ", len(buffer), " ", c.cipher, " ", c.enableXTLS)
c.numberOfPacketToFilter = 0
return
} else if c.remainingServerHello == 0 {
2023-02-27 15:07:15 +08:00
c.logger.Trace("XtlsFilterTls found tls 1.2! ", len(buffer))
c.numberOfPacketToFilter = 0
return
}
}
2023-02-27 15:07:15 +08:00
if c.numberOfPacketToFilter == 0 {
c.logger.Trace("XtlsFilterTls stop filtering ", len(buffer))
}
}
}
func (c *VisionConn) padding(buffer *buf.Buffer, command byte) *buf.Buffer {
contentLen := 0
paddingLen := 0
if buffer != nil {
contentLen = buffer.Len()
}
2023-02-27 15:07:15 +08:00
if contentLen < 900 && c.isTLS {
l, _ := rand.Int(rand.Reader, big.NewInt(500))
paddingLen = int(l.Int64()) + 900 - contentLen
2023-02-27 15:07:15 +08:00
} else {
l, _ := rand.Int(rand.Reader, big.NewInt(256))
paddingLen = int(l.Int64())
}
2023-03-19 10:25:26 +08:00
var bufferLen int
if c.writeUUID {
2023-03-19 10:25:26 +08:00
bufferLen += 16
}
bufferLen += 5
if buffer != nil {
bufferLen += buffer.Len()
}
bufferLen += paddingLen
newBuffer := buf.NewSize(bufferLen)
if c.writeUUID {
common.Must1(newBuffer.Write(c.userUUID[:]))
c.writeUUID = false
}
2023-03-19 10:25:26 +08:00
common.Must1(newBuffer.Write([]byte{command, byte(contentLen >> 8), byte(contentLen), byte(paddingLen >> 8), byte(paddingLen)}))
if buffer != nil {
2023-03-19 10:25:26 +08:00
common.Must1(newBuffer.Write(buffer.Bytes()))
buffer.Release()
}
newBuffer.Extend(paddingLen)
2023-02-27 15:07:15 +08:00
c.logger.Trace("XtlsPadding ", contentLen, " ", paddingLen, " ", command)
return newBuffer
}
func (c *VisionConn) unPadding(buffer []byte) [][]byte {
var bufferIndex int
if c.remainingContent == -1 && c.remainingPadding == -1 {
if len(buffer) >= 21 && bytes.Equal(c.userUUID[:], buffer[:16]) {
bufferIndex = 16
c.remainingContent = 0
c.remainingPadding = 0
2023-02-27 15:07:15 +08:00
c.currentCommand = 0
}
}
if c.remainingContent == -1 && c.remainingPadding == -1 {
return [][]byte{buffer}
}
var buffers [][]byte
for bufferIndex < len(buffer) {
if c.remainingContent <= 0 && c.remainingPadding <= 0 {
if c.currentCommand == 1 {
buffers = append(buffers, buffer[bufferIndex:])
break
} else {
paddingInfo := buffer[bufferIndex : bufferIndex+5]
c.currentCommand = int(paddingInfo[0])
c.remainingContent = int(paddingInfo[1])<<8 | int(paddingInfo[2])
c.remainingPadding = int(paddingInfo[3])<<8 | int(paddingInfo[4])
bufferIndex += 5
2023-02-27 15:07:15 +08:00
c.logger.Trace("Xtls Unpadding new block ", bufferIndex, " ", c.remainingContent, " padding ", c.remainingPadding, " ", c.currentCommand)
}
} else if c.remainingContent > 0 {
end := c.remainingContent
if end > len(buffer)-bufferIndex {
end = len(buffer) - bufferIndex
}
buffers = append(buffers, buffer[bufferIndex:bufferIndex+end])
c.remainingContent -= end
bufferIndex += end
} else {
end := c.remainingPadding
if end > len(buffer)-bufferIndex {
end = len(buffer) - bufferIndex
}
c.remainingPadding -= end
bufferIndex += end
}
if bufferIndex == len(buffer) {
break
}
}
return buffers
}
2023-03-03 16:31:07 +08:00
// Upstream returns the wrapped connection (the embedded Conn).
func (c *VisionConn) Upstream() any {
return c.Conn
}