sing-box/common/sniff/quic.go

378 lines
9.7 KiB
Go
Raw Normal View History

2022-07-06 12:39:44 +08:00
package sniff
import (
"bytes"
"context"
"crypto"
"crypto/aes"
2024-07-07 15:45:50 +08:00
"crypto/tls"
2022-07-06 12:39:44 +08:00
"encoding/binary"
"io"
"os"
"github.com/sagernet/sing-box/adapter"
2024-07-07 15:45:50 +08:00
"github.com/sagernet/sing-box/common/ja3"
2022-07-06 12:39:44 +08:00
"github.com/sagernet/sing-box/common/sniff/internal/qtls"
C "github.com/sagernet/sing-box/constant"
2024-07-07 15:45:50 +08:00
"github.com/sagernet/sing/common/buf"
2022-07-10 07:52:33 +08:00
E "github.com/sagernet/sing/common/exceptions"
2022-07-06 15:01:09 +08:00
2022-07-06 12:39:44 +08:00
"golang.org/x/crypto/hkdf"
)
2024-07-07 15:45:50 +08:00
var (
	// ErrClientHelloFragmented is returned by QUICClientHello when the CRYPTO
	// data collected so far does not yet parse as a ClientHello. The collected
	// fragments are kept in the sniff metadata so a later call with the next
	// packet can complete the handshake message.
	ErrClientHelloFragmented = E.New("need more packet for chromium QUIC connection")
)
2022-07-06 12:39:44 +08:00
2024-07-07 15:45:50 +08:00
func QUICClientHello(ctx context.Context, metadata *adapter.InboundContext, packet []byte) error {
reader := bytes.NewReader(packet)
2022-07-06 12:39:44 +08:00
typeByte, err := reader.ReadByte()
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
2022-10-12 13:37:06 +08:00
if typeByte&0x40 == 0 {
2024-07-07 15:45:50 +08:00
return E.New("bad type byte")
2022-07-06 12:39:44 +08:00
}
var versionNumber uint32
err = binary.Read(reader, binary.BigEndian, &versionNumber)
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
if versionNumber != qtls.VersionDraft29 && versionNumber != qtls.Version1 && versionNumber != qtls.Version2 {
2024-07-07 15:45:50 +08:00
return E.New("bad version")
2022-07-10 07:52:33 +08:00
}
2022-07-10 14:22:28 +08:00
packetType := (typeByte & 0x30) >> 4
if packetType == 0 && versionNumber == qtls.Version2 || packetType == 2 && versionNumber != qtls.Version2 || packetType > 2 {
2024-07-07 15:45:50 +08:00
return E.New("bad packet type")
2022-07-06 12:39:44 +08:00
}
destConnIDLen, err := reader.ReadByte()
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
2022-07-10 07:52:33 +08:00
if destConnIDLen == 0 || destConnIDLen > 20 {
2024-07-07 15:45:50 +08:00
return E.New("bad destination connection id length")
2022-07-10 07:52:33 +08:00
}
2022-07-06 12:39:44 +08:00
destConnID := make([]byte, destConnIDLen)
_, err = io.ReadFull(reader, destConnID)
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
srcConnIDLen, err := reader.ReadByte()
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
_, err = io.CopyN(io.Discard, reader, int64(srcConnIDLen))
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
tokenLen, err := qtls.ReadUvarint(reader)
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
_, err = io.CopyN(io.Discard, reader, int64(tokenLen))
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
packetLen, err := qtls.ReadUvarint(reader)
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
hdrLen := int(reader.Size()) - reader.Len()
2022-07-10 07:52:33 +08:00
if hdrLen+int(packetLen) > len(packet) {
2024-07-07 15:45:50 +08:00
return os.ErrInvalid
2022-07-06 12:39:44 +08:00
}
_, err = io.CopyN(io.Discard, reader, 4)
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
pnBytes := make([]byte, aes.BlockSize)
_, err = io.ReadFull(reader, pnBytes)
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
var salt []byte
switch versionNumber {
case qtls.Version1:
salt = qtls.SaltV1
case qtls.Version2:
salt = qtls.SaltV2
default:
salt = qtls.SaltOld
}
var hkdfHeaderProtectionLabel string
switch versionNumber {
case qtls.Version2:
hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV2
default:
hkdfHeaderProtectionLabel = qtls.HKDFLabelHeaderProtectionV1
}
initialSecret := hkdf.Extract(crypto.SHA256.New, destConnID, salt)
secret := qtls.HKDFExpandLabel(crypto.SHA256, initialSecret, []byte{}, "client in", crypto.SHA256.Size())
hpKey := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, hkdfHeaderProtectionLabel, 16)
block, err := aes.NewCipher(hpKey)
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
mask := make([]byte, aes.BlockSize)
block.Encrypt(mask, pnBytes)
newPacket := make([]byte, len(packet))
copy(newPacket, packet)
newPacket[0] ^= mask[0] & 0xf
for i := range newPacket[hdrLen : hdrLen+4] {
newPacket[hdrLen+i] ^= mask[i+1]
}
packetNumberLength := newPacket[0]&0x3 + 1
2022-07-10 07:52:33 +08:00
if hdrLen+int(packetNumberLength) > int(packetLen)+hdrLen {
2024-07-07 15:45:50 +08:00
return os.ErrInvalid
2022-07-06 12:39:44 +08:00
}
2022-07-10 07:52:33 +08:00
var packetNumber uint32
switch packetNumberLength {
case 1:
packetNumber = uint32(newPacket[hdrLen])
case 2:
packetNumber = uint32(binary.BigEndian.Uint16(newPacket[hdrLen:]))
case 3:
packetNumber = uint32(newPacket[hdrLen+2]) | uint32(newPacket[hdrLen+1])<<8 | uint32(newPacket[hdrLen])<<16
case 4:
packetNumber = binary.BigEndian.Uint32(newPacket[hdrLen:])
default:
2024-07-07 15:45:50 +08:00
return E.New("bad packet number length")
2022-07-06 12:39:44 +08:00
}
extHdrLen := hdrLen + int(packetNumberLength)
copy(newPacket[extHdrLen:hdrLen+4], packet[extHdrLen:])
data := newPacket[extHdrLen : int(packetLen)+hdrLen]
var keyLabel string
var ivLabel string
switch versionNumber {
case qtls.Version2:
keyLabel = qtls.HKDFLabelKeyV2
ivLabel = qtls.HKDFLabelIVV2
default:
keyLabel = qtls.HKDFLabelKeyV1
ivLabel = qtls.HKDFLabelIVV1
}
key := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, keyLabel, 16)
iv := qtls.HKDFExpandLabel(crypto.SHA256, secret, []byte{}, ivLabel, 12)
cipher := qtls.AEADAESGCMTLS13(key, iv)
nonce := make([]byte, int32(cipher.NonceSize()))
binary.BigEndian.PutUint64(nonce[len(nonce)-8:], uint64(packetNumber))
decrypted, err := cipher.Open(newPacket[extHdrLen:extHdrLen], nonce, data, newPacket[:extHdrLen])
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-07-06 12:39:44 +08:00
}
2022-10-12 13:37:06 +08:00
var frameType byte
2024-07-07 15:45:50 +08:00
var fragments []qCryptoFragment
2022-10-12 13:37:06 +08:00
decryptedReader := bytes.NewReader(decrypted)
2024-07-07 15:45:50 +08:00
const (
frameTypePadding = 0x00
frameTypePing = 0x01
frameTypeAck = 0x02
frameTypeAck2 = 0x03
frameTypeCrypto = 0x06
frameTypeConnectionClose = 0x1c
)
var frameTypeList []uint8
2022-10-12 13:37:06 +08:00
for {
2022-07-10 07:52:33 +08:00
frameType, err = decryptedReader.ReadByte()
2022-10-12 13:37:06 +08:00
if err == io.EOF {
break
}
2024-07-07 15:45:50 +08:00
frameTypeList = append(frameTypeList, frameType)
2022-10-12 13:37:06 +08:00
switch frameType {
2024-07-07 15:45:50 +08:00
case frameTypePadding:
2022-10-12 13:37:06 +08:00
continue
2024-07-07 15:45:50 +08:00
case frameTypePing:
2022-10-12 13:37:06 +08:00
continue
2024-07-07 15:45:50 +08:00
case frameTypeAck, frameTypeAck2:
2023-11-16 22:47:20 +08:00
_, err = qtls.ReadUvarint(decryptedReader) // Largest Acknowledged
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
_, err = qtls.ReadUvarint(decryptedReader) // ACK Delay
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
ackRangeCount, err := qtls.ReadUvarint(decryptedReader) // ACK Range Count
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
_, err = qtls.ReadUvarint(decryptedReader) // First ACK Range
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
for i := 0; i < int(ackRangeCount); i++ {
_, err = qtls.ReadUvarint(decryptedReader) // Gap
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
_, err = qtls.ReadUvarint(decryptedReader) // ACK Range Length
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
}
if frameType == 0x03 {
_, err = qtls.ReadUvarint(decryptedReader) // ECT0 Count
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
_, err = qtls.ReadUvarint(decryptedReader) // ECT1 Count
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
_, err = qtls.ReadUvarint(decryptedReader) // ECN-CE Count
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
}
2024-07-07 15:45:50 +08:00
case frameTypeCrypto:
2022-10-12 13:37:06 +08:00
var offset uint64
offset, err = qtls.ReadUvarint(decryptedReader)
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-10-12 13:37:06 +08:00
}
var length uint64
length, err = qtls.ReadUvarint(decryptedReader)
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-10-12 13:37:06 +08:00
}
index := len(decrypted) - decryptedReader.Len()
2024-07-07 15:45:50 +08:00
fragments = append(fragments, qCryptoFragment{offset, length, decrypted[index : index+int(length)]})
2022-10-12 13:37:06 +08:00
_, err = decryptedReader.Seek(int64(length), io.SeekCurrent)
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2022-10-12 13:37:06 +08:00
}
2024-07-07 15:45:50 +08:00
case frameTypeConnectionClose:
2023-11-16 22:47:20 +08:00
_, err = qtls.ReadUvarint(decryptedReader) // Error Code
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
_, err = qtls.ReadUvarint(decryptedReader) // Frame Type
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
var length uint64
length, err = qtls.ReadUvarint(decryptedReader) // Reason Phrase Length
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
_, err = decryptedReader.Seek(int64(length), io.SeekCurrent) // Reason Phrase
if err != nil {
2024-07-07 15:45:50 +08:00
return err
2023-11-16 22:47:20 +08:00
}
2022-10-12 13:37:06 +08:00
default:
2024-07-07 15:45:50 +08:00
return os.ErrInvalid
2022-07-10 07:52:33 +08:00
}
2022-07-06 12:39:44 +08:00
}
2024-07-07 15:45:50 +08:00
if metadata.SniffContext != nil {
fragments = append(fragments, metadata.SniffContext.([]qCryptoFragment)...)
metadata.SniffContext = nil
}
var frameLen uint64
for _, fragment := range fragments {
frameLen += fragment.length
}
buffer := buf.NewSize(5 + int(frameLen))
defer buffer.Release()
buffer.WriteByte(0x16)
binary.Write(buffer, binary.BigEndian, uint16(0x0303))
binary.Write(buffer, binary.BigEndian, uint16(frameLen))
2022-10-12 13:37:06 +08:00
var index uint64
var length int
find:
for {
for _, fragment := range fragments {
if fragment.offset == index {
2024-07-07 15:45:50 +08:00
buffer.Write(fragment.payload)
2022-10-12 13:37:06 +08:00
index = fragment.offset + fragment.length
length++
continue find
}
}
2024-07-07 15:45:50 +08:00
break
}
metadata.Protocol = C.ProtocolQUIC
fingerprint, err := ja3.Compute(buffer.Bytes())
if err != nil {
metadata.Protocol = C.ProtocolQUIC
metadata.Client = C.ClientChromium
metadata.SniffContext = fragments
return ErrClientHelloFragmented
}
metadata.Domain = fingerprint.ServerName
for metadata.Client == "" {
if len(frameTypeList) == 1 {
metadata.Client = C.ClientFirefox
break
}
if frameTypeList[0] == frameTypeCrypto && isZero(frameTypeList[1:]) {
if len(fingerprint.Versions) == 2 && fingerprint.Versions[0]&ja3.GreaseBitmask == 0x0A0A &&
len(fingerprint.EllipticCurves) == 5 && fingerprint.EllipticCurves[0]&ja3.GreaseBitmask == 0x0A0A {
metadata.Client = C.ClientSafari
break
}
if len(fingerprint.CipherSuites) == 1 && fingerprint.CipherSuites[0] == tls.TLS_AES_256_GCM_SHA384 &&
len(fingerprint.EllipticCurves) == 1 && fingerprint.EllipticCurves[0] == uint16(tls.X25519) &&
len(fingerprint.SignatureAlgorithms) == 1 && fingerprint.SignatureAlgorithms[0] == uint16(tls.ECDSAWithP256AndSHA256) {
metadata.Client = C.ClientSafari
break
}
}
if frameTypeList[len(frameTypeList)-1] == frameTypeCrypto && isZero(frameTypeList[:len(frameTypeList)-1]) {
metadata.Client = C.ClientQUICGo
2022-10-12 13:37:06 +08:00
break
}
2024-07-07 15:45:50 +08:00
if count(frameTypeList, frameTypeCrypto) > 1 || count(frameTypeList, frameTypePing) > 0 {
if maybeUQUIC(fingerprint) {
metadata.Client = C.ClientQUICGo
} else {
metadata.Client = C.ClientChromium
}
break
}
metadata.Client = C.ClientUnknown
//nolint:staticcheck
break
2022-10-12 13:37:06 +08:00
}
2024-07-07 15:45:50 +08:00
return nil
}
// isZero reports whether every element of values is 0. For a frame type
// list this means every entry is 0x00 (PADDING). An empty or nil slice
// reports true.
func isZero(values []uint8) bool {
	for i := range values {
		if values[i] != 0 {
			return false
		}
	}
	return true
}
// count returns how many elements of values are equal to value.
func count(values []uint8, value uint8) int {
	// A single-byte separator matches every occurrence of that byte.
	return bytes.Count(values, []byte{value})
}
// qCryptoFragment is the data of one CRYPTO frame. offset is the frame's
// offset field, length its length field, and payload the length bytes that
// followed it.
type qCryptoFragment struct {
	offset, length uint64
	payload        []byte
}