// mihomo/transport/hysteria/core/client.go
// (423 lines, 9.8 KiB, Go)

package core
import (
"bytes"
"context"
"crypto/tls"
"errors"
"fmt"
"github.com/Dreamacro/clash/transport/hysteria/obfs"
"github.com/Dreamacro/clash/transport/hysteria/pmtud_fix"
"github.com/Dreamacro/clash/transport/hysteria/transport"
"github.com/Dreamacro/clash/transport/hysteria/utils"
"github.com/lunixbochs/struc"
"github.com/metacubex/quic-go"
"github.com/metacubex/quic-go/congestion"
"math/rand"
"net"
"strconv"
"sync"
"time"
)
// ErrClosed is returned when dialing through a Client that has been closed,
// and by a UDP session's ReadFrom once the session has been torn down.
var ErrClosed = errors.New("closed")
// CongestionFactory builds the congestion controller that is installed on a
// connection after a successful handshake. It is called with the RecvBPS
// value taken from the server's hello (see handleControlStream).
type CongestionFactory func(refBPS uint64) congestion.CongestionControl
// Client is a Hysteria client. It keeps at most one QUIC connection to the
// server (quicSession). That connection is established on first use and
// re-established when opening a stream on it fails permanently. TCP relays
// are individual streams on that connection. UDP sessions share it through
// QUIC messages, which are dispatched to per-session channels via
// udpSessionMap.
type Client struct {
	// Passed to transport.QUICDial on every (re)connect.
	transport  *transport.ClientTransport
	serverAddr string
	protocol   string

	// Sent to the server in the client hello.
	sendBPS uint64
	recvBPS uint64
	auth    []byte

	// Optional (may be nil): builds the congestion controller once the
	// server has accepted the handshake.
	congestionFactory CongestionFactory

	// Also passed to transport.QUICDial.
	obfuscator obfs.Obfuscator
	tlsConfig  *tls.Config
	quicConfig *quic.Config

	// Current connection; nil until the first successful connect.
	// quicSession and closed are guarded by reconnectMutex.
	quicSession    quic.Connection
	reconnectMutex sync.Mutex
	closed         bool

	// udpSessionMap maps a server-assigned UDP session ID to the channel
	// feeding that session's ReadFrom. It is guarded by udpSessionMutex and
	// replaced with a fresh map on every connect.
	udpSessionMutex sync.RWMutex
	udpSessionMap   map[uint32]chan *udpMessage

	// Fed every decoded incoming message. It yields a message only when one
	// is ready for delivery (presumably after fragment reassembly).
	udpDefragger defragger
}
// NewClient creates a Client for serverAddr. No connection is made here: the
// QUIC connection is established lazily by the first DialTCP/DialUDP call.
//
// quicConfig is copied rather than modified in place, so the caller's value
// is left untouched. A nil quicConfig is treated as an empty config. Path MTU
// discovery is disabled on the copy when it already is in quicConfig, or when
// pmtud_fix.DisablePathMTUDiscovery says so.
//
// congestionFactory may be nil, in which case the connection's default
// congestion control is kept. The returned error is currently always nil.
func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
	transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory,
	obfuscator obfs.Obfuscator) (*Client, error) {
	// Work on a private (shallow) copy: the caller owns quicConfig and may
	// share it with other clients.
	var qc quic.Config
	if quicConfig != nil {
		qc = *quicConfig
	}
	qc.DisablePathMTUDiscovery = qc.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery
	c := &Client{
		transport:         transport,
		serverAddr:        serverAddr,
		protocol:          protocol,
		sendBPS:           sendBPS,
		recvBPS:           recvBPS,
		auth:              auth,
		congestionFactory: congestionFactory,
		obfuscator:        obfuscator,
		tlsConfig:         tlsConfig,
		quicConfig:        &qc,
	}
	return c, nil
}
// connectToServer dials the server and runs the control-stream handshake.
// On success it installs the new connection as c.quicSession, replaces the
// UDP session map with an empty one, and starts the datagram dispatcher
// goroutine for that connection.
//
// On failure, the half-open connection (if any) is closed with a protocol or
// auth error code, the error is returned, and c.quicSession is left
// unchanged. An auth rejection is reported as "auth error: <server message>".
//
// Its only caller, openStreamWithReconnect, holds c.reconnectMutex.
func (c *Client) connectToServer(dialer transport.PacketDialer) error {
	qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.tlsConfig, c.quicConfig, c.obfuscator, dialer)
	if err != nil {
		return err
	}
	// Control stream
	ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
	stream, err := qs.OpenStreamSync(ctx)
	ctxCancel()
	if err != nil {
		_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
		return err
	}
	// Bound the handshake too. Without a deadline, a server that never
	// answers would block this call forever while reconnectMutex is held.
	// A failure to set it only loses that bound, so it is not fatal.
	_ = stream.SetDeadline(time.Now().Add(protocolTimeout))
	ok, msg, err := c.handleControlStream(qs, stream)
	if err != nil {
		_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
		return err
	}
	if !ok {
		_ = qs.CloseWithError(closeErrorCodeAuth, "auth error")
		return fmt.Errorf("auth error: %s", msg)
	}
	// The handshake is over. Clear the deadline so it cannot fire later on
	// the control stream, which is left open.
	_ = stream.SetDeadline(time.Time{})
	// All good. Swap the map under udpSessionMutex: the dispatcher goroutine
	// of a previous connection (RLock) and DialUDP (Lock) access it
	// concurrently.
	c.udpSessionMutex.Lock()
	c.udpSessionMap = make(map[uint32]chan *udpMessage)
	c.udpSessionMutex.Unlock()
	go c.handleMessage(qs)
	c.quicSession = qs
	return nil
}
// handleControlStream performs the client side of the handshake on the
// control stream. It writes the protocol version byte, then a clientHello
// carrying the configured rates and auth payload, and finally reads the
// serverHello.
//
// When the server accepts and a congestionFactory is configured, the
// connection's congestion controller is replaced with one built from the
// server's RecvBPS.
//
// It returns whether the server accepted, the server's message, and any
// write or decode error (in which case the first two results are zero).
func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bool, string, error) {
	if _, err := stream.Write([]byte{protocolVersion}); err != nil {
		return false, "", err
	}

	hello := clientHello{
		Rate: transmissionRate{SendBPS: c.sendBPS, RecvBPS: c.recvBPS},
		Auth: c.auth,
	}
	if err := struc.Pack(stream, &hello); err != nil {
		return false, "", err
	}

	var reply serverHello
	if err := struc.Unpack(stream, &reply); err != nil {
		return false, "", err
	}
	if !reply.OK {
		return false, reply.Message, nil
	}

	if c.congestionFactory != nil {
		qs.SetCongestionControl(c.congestionFactory(reply.Rate.RecvBPS))
	}
	return true, reply.Message, nil
}
// handleMessage is the datagram dispatcher for one connection. It runs until
// qs.ReceiveMessage fails.
//
// Each message is decoded as a udpMessage (undecodable ones are skipped) and
// fed to the defragger. Every message the defragger yields is handed, without
// blocking, to the channel registered for its session ID. Messages for
// unknown sessions, or for sessions whose channel is full, are dropped.
//
// NOTE(review): the map is read from c.udpSessionMap on every message. A
// goroutine still draining an old connection after a reconnect would
// therefore deliver into the new connection's map.
func (c *Client) handleMessage(qs quic.Connection) {
	for {
		raw, err := qs.ReceiveMessage()
		if err != nil {
			return
		}

		var m udpMessage
		if err := struc.Unpack(bytes.NewBuffer(raw), &m); err != nil {
			continue
		}
		full := c.udpDefragger.Feed(m)
		if full == nil {
			continue
		}

		c.udpSessionMutex.RLock()
		if ch, found := c.udpSessionMap[full.SessionID]; found {
			// Non-blocking send: never stall the dispatcher on a slow reader.
			select {
			case ch <- full:
			default:
			}
		}
		c.udpSessionMutex.RUnlock()
	}
}
// openStreamWithReconnect opens a new stream on the current connection,
// connecting first if there is none yet. It returns ErrClosed after Close.
//
// If OpenStream fails with a temporary net.Error, that error is returned
// as-is. Any other failure is taken to mean the connection is dead: a new
// connection is established and the open is retried exactly once. The
// previous connection is not explicitly closed here.
//
// On success it returns the connection the stream belongs to and the stream
// wrapped in wrappedQUICStream. On failure both are nil.
func (c *Client) openStreamWithReconnect(dialer transport.PacketDialer) (quic.Connection, quic.Stream, error) {
	c.reconnectMutex.Lock()
	defer c.reconnectMutex.Unlock()
	if c.closed {
		return nil, nil, ErrClosed
	}
	if c.quicSession == nil {
		if err := c.connectToServer(dialer); err != nil {
			return nil, nil, err
		}
	}
	stream, err := c.quicSession.OpenStream()
	if err == nil {
		// All good
		return c.quicSession, &wrappedQUICStream{stream}, nil
	}
	// Temporary errors do not mean the connection is gone, so report them
	// without reconnecting. Temporary() is deprecated on net.Error in general,
	// but it is the signal this code relies on to tell the two cases apart.
	var nErr net.Error
	if errors.As(err, &nErr) && nErr.Temporary() {
		return nil, nil, err
	}
	// Permanent error, need to reconnect
	if err := c.connectToServer(dialer); err != nil {
		return nil, nil, err
	}
	// We are not going to try again even if it still fails the second time
	stream, err = c.quicSession.OpenStream()
	if err != nil {
		return nil, nil, err
	}
	return c.quicSession, &wrappedQUICStream{stream}, nil
}
// DialTCP asks the server to relay a TCP connection to addr, which is split
// with utils.SplitHostPort.
//
// The request is sent on a new stream from openStreamWithReconnect, so a
// dead connection is re-established first. The returned net.Conn reads and
// writes that stream, and reports the parent connection's local and remote
// addresses. If sending the request, reading the response, or the server's
// verdict fails, the stream is closed and the error returned.
func (c *Client) DialTCP(addr string, dialer transport.PacketDialer) (net.Conn, error) {
	host, port, err := utils.SplitHostPort(addr)
	if err != nil {
		return nil, err
	}
	conn, stream, err := c.openStreamWithReconnect(dialer)
	if err != nil {
		return nil, err
	}

	// Every failure past this point must release the stream.
	fail := func(cause error) (net.Conn, error) {
		_ = stream.Close()
		return nil, cause
	}

	req := clientRequest{UDP: false, Host: host, Port: port}
	if err := struc.Pack(stream, &req); err != nil {
		return fail(err)
	}
	var resp serverResponse
	if err := struc.Unpack(stream, &resp); err != nil {
		return fail(err)
	}
	if !resp.OK {
		return fail(fmt.Errorf("connection rejected: %s", resp.Message))
	}

	return &quicConn{
		Orig:             stream,
		PseudoLocalAddr:  conn.LocalAddr(),
		PseudoRemoteAddr: conn.RemoteAddr(),
	}, nil
}
// DialUDP asks the server for a UDP relay session on a new stream.
//
// On acceptance, a channel (buffer of 1024 messages) is registered under the
// server-assigned session ID so the dispatcher can deliver to it. A goroutine
// then holds the request stream open for the session's lifetime. Closing the
// returned UDPConn unregisters the session and closes its channel. If the
// request fails or is rejected, the stream is closed and the error returned.
func (c *Client) DialUDP(dialer transport.PacketDialer) (UDPConn, error) {
	conn, stream, err := c.openStreamWithReconnect(dialer)
	if err != nil {
		return nil, err
	}

	// Every failure past this point must release the stream.
	fail := func(cause error) (UDPConn, error) {
		_ = stream.Close()
		return nil, cause
	}

	if err := struc.Pack(stream, &clientRequest{UDP: true}); err != nil {
		return fail(err)
	}
	var resp serverResponse
	if err := struc.Unpack(stream, &resp); err != nil {
		return fail(err)
	}
	if !resp.OK {
		return fail(fmt.Errorf("connection rejected: %s", resp.Message))
	}

	id := resp.UDPSessionID
	msgCh := make(chan *udpMessage, 1024)

	// Remember which map the session went into. A reconnect replaces
	// c.udpSessionMap, and the close callback must remove the entry from the
	// same map it was added to.
	c.udpSessionMutex.Lock()
	sessions := c.udpSessionMap
	sessions[id] = msgCh
	c.udpSessionMutex.Unlock()

	pc := &quicPktConn{
		Session: conn,
		Stream:  stream,
		CloseFunc: func() {
			c.udpSessionMutex.Lock()
			defer c.udpSessionMutex.Unlock()
			// The presence check makes repeated calls harmless.
			if ch, registered := sessions[id]; registered {
				close(ch)
				delete(sessions, id)
			}
		},
		UDPSessionID: id,
		MsgCh:        msgCh,
	}
	go pc.Hold()
	return pc, nil
}
// Close marks the client closed, so later dials fail with ErrClosed, and
// closes the current connection with the generic close code if one has been
// established. It is safe to call before any connection exists. It returns
// the error from closing the connection, or nil if there was none.
func (c *Client) Close() error {
	c.reconnectMutex.Lock()
	defer c.reconnectMutex.Unlock()
	c.closed = true
	// quicSession stays nil until the first successful connect; calling a
	// method on the nil interface would panic.
	if c.quicSession == nil {
		return nil
	}
	return c.quicSession.CloseWithError(closeErrorCodeGeneric, "")
}
// quicConn adapts a QUIC stream to net.Conn. Read, Write, Close and the
// deadline setters delegate to Orig. LocalAddr and RemoteAddr return the
// Pseudo* values supplied at construction (DialTCP uses the parent
// connection's addresses).
type quicConn struct {
	Orig             quic.Stream
	PseudoLocalAddr  net.Addr
	PseudoRemoteAddr net.Addr
}

// Read reads from the underlying stream.
func (qc *quicConn) Read(b []byte) (int, error) { return qc.Orig.Read(b) }

// Write writes to the underlying stream.
func (qc *quicConn) Write(b []byte) (int, error) { return qc.Orig.Write(b) }

// Close closes the underlying stream.
func (qc *quicConn) Close() error { return qc.Orig.Close() }

// LocalAddr returns the address recorded when the conn was created.
func (qc *quicConn) LocalAddr() net.Addr { return qc.PseudoLocalAddr }

// RemoteAddr returns the address recorded when the conn was created.
func (qc *quicConn) RemoteAddr() net.Addr { return qc.PseudoRemoteAddr }

// SetDeadline sets the stream's read and write deadlines.
func (qc *quicConn) SetDeadline(t time.Time) error { return qc.Orig.SetDeadline(t) }

// SetReadDeadline sets the stream's read deadline.
func (qc *quicConn) SetReadDeadline(t time.Time) error { return qc.Orig.SetReadDeadline(t) }

// SetWriteDeadline sets the stream's write deadline.
func (qc *quicConn) SetWriteDeadline(t time.Time) error { return qc.Orig.SetWriteDeadline(t) }
// UDPConn is a relayed UDP session as returned by Client.DialUDP. Addresses
// are exchanged as "host:port" strings rather than net.Addr values.
type UDPConn interface {
	// ReadFrom returns the next datagram and the address it came from.
	ReadFrom() (data []byte, addr string, err error)
	// WriteTo sends data to addr.
	WriteTo(data []byte, addr string) error
	// Close tears the session down.
	Close() error
	// LocalAddr reports the local address of the underlying connection.
	LocalAddr() net.Addr
	// SetDeadline, SetReadDeadline and SetWriteDeadline set I/O deadlines.
	SetDeadline(t time.Time) error
	SetReadDeadline(t time.Time) error
	SetWriteDeadline(t time.Time) error
}
// quicPktConn is the UDPConn returned by DialUDP.
//
// Outgoing datagrams are sent as QUIC messages on Session. Incoming ones
// arrive on MsgCh, fed by the client's dispatcher goroutine. Stream is the
// request stream of this UDP session: Hold keeps reading it, and when it
// fails the whole session is closed.
type quicPktConn struct {
	Session      quic.Connection
	Stream       quic.Stream
	CloseFunc    func() // set by DialUDP: unregisters the session and closes MsgCh
	UDPSessionID uint32
	MsgCh        <-chan *udpMessage
}

// Hold reads and discards data from Stream until Read returns an error (for
// example because the stream was closed), then closes the session.
func (c *quicPktConn) Hold() {
	buf := make([]byte, 1024)
	for {
		if _, err := c.Stream.Read(buf); err != nil {
			break
		}
	}
	_ = c.Close()
}

// ReadFrom blocks until the next datagram for this session arrives. It
// returns the payload and the sender's address as "host:port", or ErrClosed
// once MsgCh has been closed.
//
// NOTE(review): deadlines set through SetReadDeadline apply to Stream only
// and do not interrupt this channel receive.
func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
	msg := <-c.MsgCh
	if msg == nil {
		// The dispatcher never sends nil, so nil means MsgCh was closed.
		return nil, "", ErrClosed
	}
	return msg.Data, net.JoinHostPort(msg.Host, strconv.Itoa(int(msg.Port))), nil
}

// WriteTo sends p to addr ("host:port") as a single QUIC message.
//
// If the connection reports the message as too large, the datagram is split
// with fragUDPMessage under a random non-zero message ID, and each fragment
// is sent in turn. Encoding and send errors are returned; sending stops at
// the first failing fragment.
func (c *quicPktConn) WriteTo(p []byte, addr string) error {
	host, port, err := utils.SplitHostPort(addr)
	if err != nil {
		return err
	}
	msg := udpMessage{
		SessionID: c.UDPSessionID,
		Host:      host,
		Port:      port,
		FragCount: 1,
		Data:      p,
	}
	// Try to send the datagram unfragmented first.
	var msgBuf bytes.Buffer
	if err := struc.Pack(&msgBuf, &msg); err != nil {
		return err
	}
	err = c.Session.SendMessage(msgBuf.Bytes())
	var errSize quic.ErrMessageToLarge
	if err == nil || !errors.As(err, &errSize) {
		// Either it was sent, or it failed for a reason fragmenting cannot fix.
		return err
	}
	// Too large: fragment using the size carried by the error value.
	// msgID must be > 0 when fragCount > 1; this yields 1..0xFFFF.
	msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1
	for _, fragMsg := range fragUDPMessage(msg, int(errSize)) {
		msgBuf.Reset()
		if err := struc.Pack(&msgBuf, &fragMsg); err != nil {
			return err
		}
		if err := c.Session.SendMessage(msgBuf.Bytes()); err != nil {
			return err
		}
	}
	return nil
}

// Close unregisters the session via CloseFunc and closes Stream. Hold also
// calls it when the stream fails, so it can run more than once for the same
// session.
func (c *quicPktConn) Close() error {
	c.CloseFunc()
	return c.Stream.Close()
}

// LocalAddr returns the local address of the underlying QUIC connection.
func (c *quicPktConn) LocalAddr() net.Addr {
	return c.Session.LocalAddr()
}

// SetDeadline sets read and write deadlines on Stream.
func (c *quicPktConn) SetDeadline(t time.Time) error {
	return c.Stream.SetDeadline(t)
}

// SetReadDeadline sets the read deadline on Stream (see the ReadFrom note).
func (c *quicPktConn) SetReadDeadline(t time.Time) error {
	return c.Stream.SetReadDeadline(t)
}

// SetWriteDeadline sets the write deadline on Stream.
func (c *quicPktConn) SetWriteDeadline(t time.Time) error {
	return c.Stream.SetWriteDeadline(t)
}